From 7c0c6da6ad3c115a05e39fe0a81838e9485d9f80 Mon Sep 17 00:00:00 2001 From: Max R Date: Tue, 7 Apr 2026 09:00:53 -0400 Subject: [PATCH] Fold nested with statements --- README.md | 16 ++++ pyupgrade/_plugins/nested_with.py | 141 +++++++++++++++++++++++++++++ tests/features/nested_with_test.py | 130 ++++++++++++++++++++++++++ 3 files changed, 287 insertions(+) create mode 100644 pyupgrade/_plugins/nested_with.py create mode 100644 tests/features/nested_with_test.py diff --git a/README.md b/README.md index 1bfa3241..5aa3b695 100644 --- a/README.md +++ b/README.md @@ -795,3 +795,19 @@ Availability: -datetime.timezone.utc +datetime.UTC ``` + +### fold nested context managers + +Availability: +- `--py310-plus` is passed on the commandline. + +```diff +-with a: +- with b: +- pass ++with ( ++ a, ++ b, ++): ++ pass +``` diff --git a/pyupgrade/_plugins/nested_with.py b/pyupgrade/_plugins/nested_with.py new file mode 100644 index 00000000..7d42ddc1 --- /dev/null +++ b/pyupgrade/_plugins/nested_with.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import ast +from collections.abc import Iterable + +from tokenize_rt import Offset +from tokenize_rt import Token +from tokenize_rt import tokens_to_src + +from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._data import register +from pyupgrade._data import State +from pyupgrade._data import TokenFunc +from pyupgrade._token_helpers import Block + +_NEWLINES = frozenset(('NL', 'NEWLINE')) +_WITH_PREFIX_TOKENS = frozenset(('INDENT', 'UNIMPORTANT_WS')) + + +def _header_has_comment(tokens: list[Token], block: Block) -> bool: + return any( + tokens[i].name == 'COMMENT' for i in range(block.start, block.block) + ) + + +def _with_token_index(tokens: list[Token], block: Block) -> int: + i = block.start + while tokens[i].name in _WITH_PREFIX_TOKENS: + i += 1 + return i + + +def _header_is_single_line(tokens: list[Token], block: Block) -> bool: + start = _with_token_index(tokens, block) + return all( + tokens[i].name not in _NEWLINES + for i in range(start, block.colon) + ) + + +def _item_src(tokens: list[Token], block: Block) -> str: + i = _with_token_index(tokens, block) + return tokens_to_src(tokens[i + 1:block.colon]).strip() + + +def _fix_nested_with(i: int, tokens: list[Token], item_count: int) -> None: + blocks = [Block.find(tokens, i)] + + while ( + len(blocks) < item_count and + (block := blocks[-1]).block + 1 < len(tokens) and + tokens[block.block].name == 'INDENT' and + tokens[block.block + 1].matches(name='NAME', src='with') + ): + blocks.append(Block.find(tokens, block.block + 1)) + + if ( + len(blocks) < item_count or + any( + block.line or + not _header_is_single_line(tokens, block) or + _header_has_comment(tokens, block) + for block in blocks + ) + ): + return + + indent = ( + tokens[blocks[0].start].src + if tokens[blocks[0].start].src.isspace() else '' + ) + newline = tokens[blocks[0].block - 1].src + header = ''.join(( + indent, 'with (', newline, + *( + f'{indent} {_item_src(tokens, block)},{newline}' + for block in blocks + ), + indent, '):', newline, + )) + + for j in range(len(blocks) - 2, -1, -1): + blocks[j].dedent(tokens) + + for j in range(len(blocks) - 1, 0, -1): + del tokens[blocks[j].start:blocks[j].block] + + tokens[blocks[0].start:blocks[0].block] = [Token('CODE', header)] + + +def _parent_wraps_with(node: ast.With, parent: ast.AST) -> bool: + return ( + isinstance(parent, ast.With) and + len(parent.items) == 1 and + len(parent.body) == 1 and + parent.body[0] is node + ) + + +def _single_line_item(item: ast.withitem) -> bool: + return ( + item.context_expr.end_lineno is not None and + item.context_expr.lineno == item.context_expr.end_lineno and + ( + item.optional_vars is None or ( + item.optional_vars.end_lineno is not None and + item.optional_vars.lineno == item.optional_vars.end_lineno + ) + ) + ) + + +@register(ast.With) +def visit_With( + state: State, + node: ast.With, + parent: ast.AST, +) -> Iterable[tuple[Offset, TokenFunc]]: + if ( + state.settings.min_version < (3, 10) or + _parent_wraps_with(node, parent) + ): + return + + cur = node + item_count = 1 + while ( + len(cur.items) == 1 and + _single_line_item(cur.items[0]) and + len(cur.body) == 1 and + isinstance((nxt := cur.body[0]), ast.With) + ): + cur = nxt + item_count += 1 + + if item_count < 2: + return + + yield ast_to_offset(node), ( + lambda i, tokens: _fix_nested_with(i, tokens, item_count) + ) diff --git a/tests/features/nested_with_test.py b/tests/features/nested_with_test.py new file mode 100644 index 00000000..745943a3 --- /dev/null +++ b/tests/features/nested_with_test.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import pytest + +from pyupgrade._data import Settings +from pyupgrade._main import _fix_plugins + + +@pytest.mark.parametrize( + ('s', 'expected'), + ( + pytest.param( + 'with a:\n' + ' with b:\n' + ' pass\n', + 'with (\n' + ' a,\n' + ' b,\n' + '):\n' + ' pass\n', + id='two-level rewrite', + ), + pytest.param( + 'with a as x:\n' + ' with b as y:\n' + ' pass\n', + 'with (\n' + ' a as x,\n' + ' b as y,\n' + '):\n' + ' pass\n', + id='rewrite preserves as-targets', + ), + pytest.param( + 'def f() -> None:\n' + ' with a:\n' + ' with b:\n' + ' with c:\n' + ' pass\n', + 'def f() -> None:\n' + ' with (\n' + ' a,\n' + ' b,\n' + ' c,\n' + ' ):\n' + ' pass\n', + id='three-level rewrite inside function', + ), + pytest.param( + 'with x:\n' + ' with y:\n' + ' foo()\n' + ' # blah\n' + ' if z:\n' + ' pass\n', + 'with (\n' + ' x,\n' + ' y,\n' + '):\n' + ' foo()\n' + ' # blah\n' + ' if z:\n' + ' pass\n', + id='rewrite preserves dedent with nested body and comment', + ), + ), +) +def test_fix_nested_with(s, expected): + assert _fix_plugins(s, settings=Settings(min_version=(3, 9))) == s + assert _fix_plugins(s, settings=Settings(min_version=(3, 10))) == expected + + +@pytest.mark.parametrize( + ('s',), + ( + pytest.param( + 'async def f() -> None:\n' + ' async with a:\n' + ' async with b:\n' + ' pass\n', + id='skip async-with chain', + ), + pytest.param( + 'with a:\n' + ' with b:\n' + ' pass\n' + ' x = 1\n', + id='skip when outer body has extra statements', + ), + pytest.param( + 'with a, b:\n' + ' with c:\n' + ' pass\n', + id='skip when outer with already has multiple items', + ), + pytest.param( + 'with a:\n' + ' with b: pass\n', + id='skip single-line nested body', + ), + pytest.param( + 'with a:\n' + '\n' + ' with b:\n' + ' pass\n', + id='skip blank line between nested headers', + ), + pytest.param( + 'with a:\n' + ' # keep this comment\n' + ' with b:\n' + ' pass\n', + id='skip comment-only line between headers', + ), + pytest.param( + 'with a: # keep this comment\n' + ' with b:\n' + ' pass\n', + id='skip comment on outer header', + ), + pytest.param( + 'with a:\n' + ' with b: # keep this comment\n' + ' pass\n', + id='skip comment on inner header', + ), + ), +) +def test_fix_nested_with_noop(s): + assert _fix_plugins(s, settings=Settings(min_version=(3, 10))) == s