diff --git a/pyupgrade/_plugins/typing_pep604.py b/pyupgrade/_plugins/typing_pep604.py index 9b4039f2..56a3cfd4 100644 --- a/pyupgrade/_plugins/typing_pep604.py +++ b/pyupgrade/_plugins/typing_pep604.py @@ -14,7 +14,7 @@ from pyupgrade._data import register from pyupgrade._data import State from pyupgrade._data import TokenFunc -from pyupgrade._token_helpers import find_closing_bracket +from pyupgrade._token_helpers import _OPENING from pyupgrade._token_helpers import find_op from pyupgrade._token_helpers import is_close from pyupgrade._token_helpers import is_open @@ -22,14 +22,28 @@ def _fix_optional(i: int, tokens: list[Token]) -> None: j = find_op(tokens, i, '[') - k = find_closing_bracket(tokens, j) + k, contains_none = _find_closing_bracket_and_if_contains_none(tokens, j) if tokens[j].line == tokens[k].line: - tokens[k] = Token('CODE', ' | None') + if contains_none: + del tokens[k] + else: + tokens[k:k + 1] = [ + Token('UNIMPORTANT_WS', ' '), + Token('CODE', '| '), + Token('CODE', 'None'), + ] del tokens[i:j + 1] else: tokens[j] = tokens[j]._replace(src='(') tokens[k] = tokens[k]._replace(src=')') - tokens[i:j] = [Token('CODE', 'None | ')] + if contains_none: + del tokens[i:j] + else: + tokens[i:j] = [ + Token('CODE', 'None'), + Token('UNIMPORTANT_WS', ' '), + Token('CODE', '| '), + ] def _fix_union( @@ -43,6 +57,8 @@ def _fix_union( open_parens = [] commas = [] coding_depth = None + top_level_breaks = [] + lines_with_comments = [] j = find_op(tokens, i, '[') k = j + 1 @@ -70,11 +86,16 @@ def _fix_union( parens_done.append((paren_depth, (open_paren, k))) depth -= 1 - elif tokens[k].src == ',': - commas.append((depth, k)) - + elif tokens[k].src.strip() in [',', '|']: + if tokens[k].src.strip() == ',': + commas.append((depth, k)) + if depth == 1: + top_level_breaks.append(k) + elif tokens[k].name == 'COMMENT': + lines_with_comments.append(tokens[k].line) k += 1 k -= 1 + top_level_breaks.append(k) assert coding_depth is not None assert not open_parens, open_parens @@ -95,12 +116,15 @@ def _fix_union( else: comma_positions = [] - to_delete.sort() + to_delete += _find_duplicated_types( + tokens, j, top_level_breaks, lines_with_comments, + ) if tokens[j].line == tokens[k].line: del tokens[k] for comma in comma_positions: tokens[comma] = Token('CODE', ' |') + to_delete.sort() for paren in reversed(to_delete): del tokens[paren] del tokens[i:j + 1] @@ -110,11 +134,80 @@ def _fix_union( for comma in comma_positions: tokens[comma] = Token('CODE', ' |') + to_delete += _remove_consecutive_unimportant_ws( + tokens, [x for x in range(j, k) if x not in to_delete], + ) + to_delete.sort() for paren in reversed(to_delete): del tokens[paren] del tokens[i:j] +def _find_closing_bracket_and_if_contains_none( + tokens: list[Token], + i: int, +) -> tuple[int, bool]: + assert tokens[i].src in _OPENING + depth = 1 + i += 1 + contains_none = False + while depth: + if is_open(tokens[i]): + depth += 1 + elif is_close(tokens[i]): + depth -= 1 + elif depth == 1 and tokens[i].matches(name='NAME', src='None'): + contains_none = True + i += 1 + return i - 1, contains_none + + +def _find_duplicated_types( + tokens: list[Token], + opening_bracket: int, + depth_1_commas: list[int], + lines_with_comments: list[int], +) -> list[int]: + unique_names = [] + to_delete = [] + i = opening_bracket + 1 + for d1c in depth_1_commas: + important_tokens = [ + x + for x in range(i, d1c) + if tokens[x].name + not in ( + ['COMMENT'] + if tokens[x].line not in lines_with_comments + else ['COMMENT', 'NL', 'UNIMPORTANT_WS'] + ) + ] + type_ = ''.join([tokens[k].src.lstrip() for k in important_tokens]) + if type_[0] in [',', '|']: + type_ = type_[1:].lstrip() + if type_ in unique_names: + to_delete += important_tokens + else: + unique_names.append(type_) + i = d1c + return to_delete + + +def _remove_consecutive_unimportant_ws( + tokens: list[Token], idxs: list[int], +) -> list[int]: + to_delete = [] + prev_name = '' + for kk in idxs: + if prev_name == 'UNIMPORTANT_WS': + if tokens[kk].name == 'UNIMPORTANT_WS': + to_delete.append(kk) + elif tokens[kk].src == ' |': + tokens[kk] = Token('CODE', '|') + prev_name = tokens[kk].name + return to_delete + + def _supported_version(state: State) -> bool: return ( state.in_annotation and ( diff --git a/tests/features/typing_pep604_test.py b/tests/features/typing_pep604_test.py index eff5bf31..b13daa1e 100644 --- a/tests/features/typing_pep604_test.py +++ b/tests/features/typing_pep604_test.py @@ -242,6 +242,86 @@ def f(x: int | str) -> None: ... id='optional, 3.12: ignore close brace in fstring', ), + pytest.param( + 'from typing import Optional, Union\n' + 'def f(x: Optional[Union[int, None]]): pass\n' + 'def g(x: Union[Optional[int], None]): pass\n' + 'def h(x: Union[Union[int, None], None]): pass\n' + 'def i(x: Union[int, int, None]): pass\n' + 'def j(x: Union[Union[int, None], int]): pass\n' + 'def k(x: Union[Union[int, None], int]): pass # comment\n' + 'def l(x: Union[Union[Union[Union[a, b], c], d], a]): pass\n' + 'def m(x: Union[a.b | a.c, a.b, list[str], str]): pass\n', + + 'from typing import Optional, Union\n' + 'def f(x: int | None): pass\n' + 'def g(x: int | None): pass\n' + 'def h(x: int | None): pass\n' + 'def i(x: int | None): pass\n' + 'def j(x: int | None): pass\n' + 'def k(x: int | None): pass # comment\n' + 'def l(x: a | b | c | d): pass\n' + 'def m(x: a.b | a.c | list[str] | str): pass\n', + + id='duplicated types in nested unions or optionals', + ), + pytest.param( + 'from typing import Optional, Union\n' + 'f: Optional[\n' + ' Union[int, None]\n' + ']\n' + 'g: Union[\n' + ' int,\n' + ' int,\n' + ' None,\n' + ']\n' + 'h: Union[\n' + ' Union[int, None],\n' + ' int,\n' + ' None,\n' + ' Optional[int],\n' + ']\n' + 'i: Union[\n' + ' Union[int, None], # comment 1\n' + ' int, # comment 2\n' + ' None, # comment 3\n' + ' Optional[int], # comment 4\n' + ' Optional[str], # comment 5\n' + ']\n', + + 'from typing import Optional, Union\n' + 'f: (\n' + ' int | None\n' + ')\n' + 'g: (\n' + ' int |\n' + ' None\n' + ')\n' + 'h: (\n' + ' int | None\n' + ')\n' + 'i: (\n' + ' int | None # comment 1\n' + ' # comment 2\n' + ' # comment 3\n' + ' | # comment 4\n' + ' str # comment 5\n' + ')\n', + + id='duplicated types in multi-line nested unions or optionals', + ), + pytest.param( + 'from typing import Union\n' + 'def f(x: Union[list[Union[int, str]], list[Union[str, int]]]):\n' + ' pass\n', + + 'from typing import Union\n' + 'def f(x: list[int | str]):\n' + ' pass\n', + + id='general duplicated types', + marks=pytest.mark.xfail(reason='requires deeper type evaluation'), + ), ), ) def test_fix_pep604_types(s, expected):