diff --git a/README.md b/README.md index 5b40427b..0e089fc1 100644 --- a/README.md +++ b/README.md @@ -783,7 +783,6 @@ Availability: +def f(x: queue.Queue[int]) -> C: ``` - ### use `datetime.UTC` alias Availability: @@ -795,3 +794,16 @@ Availability: -datetime.timezone.utc +datetime.UTC ``` + +### Fold nested context managers + +Availability: +- `--py310-plus` and higher + +```diff +- with foo: +- with bar: +- body ++ with foo, bar: ++ body +``` diff --git a/pyupgrade/_plugins/fold_nested_context_managers.py b/pyupgrade/_plugins/fold_nested_context_managers.py new file mode 100644 index 00000000..d18b9d15 --- /dev/null +++ b/pyupgrade/_plugins/fold_nested_context_managers.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import ast +import functools +import itertools +from collections.abc import Iterable +from typing import Any + +from tokenize_rt import Offset +from tokenize_rt import Token + +from pyupgrade._ast_helpers import ast_to_offset +from pyupgrade._data import register +from pyupgrade._data import State +from pyupgrade._data import TokenFunc +from pyupgrade._token_helpers import Block + + +def _expand_item(indent: int, item: ast.AST) -> str: + return '{}{}'.format(' ' * indent, ast.unparse(item)) + + +def _replace_context_managers( + i: int, + tokens: list[Token], + *, + with_items: list[ast.withitem], + body: Iterable[ast.AST], +) -> None: + block = Block.find(tokens, i, trim_end=True) + block_indent = block._minimum_indent(tokens) + replacement = '{}with ({}):\n{}\n'.format( + ' ' * block._initial_indent(tokens), + ', '.join(ast.unparse(item) for item in with_items), + '\n'.join(_expand_item(block_indent, item) for item in body), + ) + tokens[block.start:block.end] = [Token('CODE', replacement)] + + +def flatten(xs: Iterable[Any]) -> list[Any]: + return list(itertools.chain.from_iterable(xs)) + + +@register(ast.With) +def visit_With_fold_nested( + state: State, + node: ast.With, + parent: ast.AST, +) -> Iterable[tuple[Offset, TokenFunc]]: + """ + Fold nested with statements into one statement. + + with foo: + with bar: + body + + becomes + + with (foo, bar): + body + """ + if state.settings.min_version < (3, 10): + return + if isinstance(parent, ast.With): + # The top most with statement will handle all of the children. + return + + with_stmts = [] + current: ast.AST = node + while isinstance(current, ast.With): + with_stmts.append(current) + if len(current.body) == 1: + current = current.body[0] + else: + break + + if len(with_stmts) > 1: + with_items = flatten(n.items for n in with_stmts) + yield ast_to_offset(node), functools.partial( + _replace_context_managers, + body=with_stmts[-1].body, + with_items=with_items, + ) diff --git a/tests/features/fold_nested_context_managers_test.py b/tests/features/fold_nested_context_managers_test.py new file mode 100644 index 00000000..9e8b8678 --- /dev/null +++ b/tests/features/fold_nested_context_managers_test.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import pytest + +from pyupgrade._data import Settings +from pyupgrade._main import _fix_plugins + + +@pytest.mark.parametrize( + ('s', 'version'), + ( + pytest.param( + 'with foo:\n' + " print('something')\n" + '\n', + (3, 10), + id='simple with expression', + ), + pytest.param( + 'with foo as bar:\n' + " print('something')\n" + '\n', + (3, 10), + id='simple with expression and captured name', + ), + pytest.param( + 'with foo as thing1, bar as thing2:\n' + " print('something')\n" + '\n', + (3, 9), + id='nested with expression and captured names', + ), + pytest.param( + 'with foo:\n' + ' with bar:\n' + " print('something')\n" + " print('another')\n" + '\n', + (3, 9), + id='nested with expression with empty name capture workaround', + ), + ), +) +def test_fold_nested_context_managers_noop(s, version): + assert _fix_plugins(s, settings=Settings(min_version=version)) == s + + +@pytest.mark.parametrize( + ('s', 'expected', 'version'), + ( + pytest.param( + 'with foo:\n' + ' with bar:\n' + " print('something')\n" + " print('another')\n" + '\n', + 'with (foo, bar):\n' + " print('something')\n" + " print('another')\n" + '\n', + (3, 10), + id='nested with expression', + ), + pytest.param( + 'if value:\n' + ' with foo:\n' + ' with bar:\n' + ' with baz:\n' + " print('something')\n" + " print('another')\n" + '\n', + 'if value:\n' + ' with (foo, bar, baz):\n' + " print('something')\n" + " print('another')\n" + '\n', + (3, 10), + id='nested with expression inside of an if', + ), + pytest.param( + 'with foo as thing1:\n' + ' with bar as thing2:\n' + " print('something')\n" + " print('another')\n" + '\n', + 'with (foo as thing1, bar as thing2):\n' + " print('something')\n" + " print('another')\n" + '\n', + (3, 10), + id='nested with expression with named capture', + ), + pytest.param( + 'with foo as thing1:\n' + ' with bar:\n' + " print('something')\n" + " print('another')\n" + '\n', + 'with (foo as thing1, bar):\n' + " print('something')\n" + " print('another')\n" + '\n', + (3, 10), + id='nested with expression with only one named capture', + ), + pytest.param( + 'with foo as thing1:\n' + ' with bar:\n' + " print('something')\n" + " print('another')\n" + ' with other:\n' + " print('yet enother')\n" + '\n', + 'with foo as thing1:\n' + ' with bar:\n' + " print('something')\n" + " print('another')\n" + ' with other:\n' + " print('yet enother')\n" + '\n', + (3, 10), + id='nested with expression that is semantically meaningful', + ), + ), +) +def test_fold_nested_context_managers(s, expected, version): + ret = _fix_plugins(s, settings=Settings(min_version=version)) + assert ret == expected