diff --git a/papermill/tests/test_translators.py b/papermill/tests/test_translators.py index 0edc1f07..a14f5099 100644 --- a/papermill/tests/test_translators.py +++ b/papermill/tests/test_translators.py @@ -107,6 +107,9 @@ def test_translate_comment_python(test_input, expected): Parameter("b", "float", "-2.3432", "My b variable"), ], ), + # Regression test for #864: '=' inside string literals shouldn't trip parsing. + ('s = "a=b"', [Parameter("s", "None", '"a=b"', "")]), + ("s = 'a=b'", [Parameter("s", "None", "'a=b'", "")]), ], ) def test_inspect_python(test_input, expected): diff --git a/papermill/translators.py b/papermill/translators.py index 1cb43d89..6c2466a6 100644 --- a/papermill/translators.py +++ b/papermill/translators.py @@ -1,7 +1,9 @@ +import io import logging import math import re import shlex +import tokenize from .exceptions import PapermillException from .models import Parameter @@ -9,6 +11,21 @@ logger = logging.getLogger(__name__) +def _count_assignment_operators(line): + """Count top-level assignment operators in a Python source line. + + Uses ``tokenize`` so that ``=`` characters appearing inside string + literals (e.g. ``s = "a=b"``) are not counted as assignment + operators. Falls back to a naive ``line.count('=')`` if tokenization + fails (e.g. for incomplete multiline definitions). + """ + try: + tokens = tokenize.tokenize(io.BytesIO(line.encode("utf-8")).readline) + return sum(1 for tok in tokens if tok.type == tokenize.OP and tok.string == "=") + except (tokenize.TokenError, SyntaxError): + return line.count("=") + + class PapermillTranslators: ''' The holder which houses any translator registered with the system. @@ -242,7 +259,7 @@ def flatten_accumulator(accumulator): if len(line.strip()) == 0 or line.strip().startswith('#'): continue # Skip blank and comment - nequal = line.count("=") + nequal = _count_assignment_operators(line) if nequal > 0: grouped_variable.append(flatten_accumulator(accumulator)) accumulator = []