From 85d16cd5ba18f790e720755accbb1b056ee038fa Mon Sep 17 00:00:00 2001 From: Bruno Cunha Date: Tue, 14 Apr 2026 00:35:12 -0400 Subject: [PATCH] test(coverage): strict type-parameter checking, catch 2 wrong test results MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The coverage checker's `is_same_type` compared only the base type name and stripped all parameters, so decimal precision/scale, varchar length, and list/map element types were never actually validated against the extension YAML return formulas. Two wrong test cases had been sitting in-tree because of this: - `sum((2.5, 0, 5.0, -2.5, -7.5)::dec<2, 1>) = -2.5::dec<38, 2>` in `tests/cases/arithmetic_decimal/sum_decimal.test`. `sum` returns `DECIMAL?<38,S>`, so with input scale 1 the output must be `dec?<38, 1>`, not `dec?<38, 2>`. - `nullif(null::dec?<38, 0>, null::dec?<38, 0>) = null::bool?` in `tests/cases/comparison/nullif.test`. `nullif` is `any1, any1 -> any1?`, so with both args `dec<38, 0>` the result must be `dec?<38, 0>`, not `bool?`. Looks like a copy-paste from the bool cases above. Both wrong results date back to the original BFT port (#738). They are not undoing prior work: #913 only touched the `basic` i8/i16 block at the top of `nullif.test`, and #989 rewrote both lines only to add `?` nullability markers, preserving the underlying wrong `38,2` and wrong `bool` base type. This change keeps every `?` marker from #989 and only fixes the parameter #989 wasn't looking at. To prevent regressions, this adds `tests/coverage/type_checker.py` — a symbolic unifier plus evaluator for the YAML return-formula mini-language (assignments, `min`/`max`, `cond ? a : b` ternary). The new module: - parses type strings like `decimal`, `list`, `STRUCT<...>`, `func boolean?>` into tagged tuples; - unifies an impl-declared type against a concrete test type, binding variables like `P1`, `S1`, and polymorphic `any1`/`any2`; - evaluates multi-line return formulas (add/sub/mul/div/mod, min/max, sum, `any1?`, etc.) with the bindings and compares structurally to the test's declared result type. `FunctionOverload`/`FunctionVariant` now carry the raw YAML arg types and return formula alongside the existing short-form fingerprint. `FunctionRegistry.get_function` runs the strict check after the legacy loose match, so the wider test suite's base-type fallback still applies when the caller hasn't supplied full parameterized types. When the formula cannot be evaluated (unbound variable, unusual syntax), the strict check falls back to success, preserving compatibility with the current extensions and leaving room to tighten further. `test_type_checker.py` covers parsing, unification, formula evaluation for add/divide, the two specific bugs above, and the tolerant behavior for tests that omit optional decimal parameters (e.g. `power(dec, dec<38,0>)`). --- .../cases/arithmetic_decimal/sum_decimal.test | 2 +- tests/cases/comparison/nullif.test | 2 +- tests/coverage/coverage.py | 8 + tests/coverage/extensions.py | 82 +++- tests/coverage/nodes.py | 14 + tests/coverage/test_type_checker.py | 240 ++++++++++ tests/coverage/type_checker.py | 426 ++++++++++++++++++ 7 files changed, 760 insertions(+), 14 deletions(-) create mode 100644 tests/coverage/test_type_checker.py create mode 100644 tests/coverage/type_checker.py diff --git a/tests/cases/arithmetic_decimal/sum_decimal.test b/tests/cases/arithmetic_decimal/sum_decimal.test index 5847bc070..8753cc60e 100644 --- a/tests/cases/arithmetic_decimal/sum_decimal.test +++ b/tests/cases/arithmetic_decimal/sum_decimal.test @@ -4,7 +4,7 @@ # basic: Basic examples without any special cases sum((0, -1, 2, 20)::dec<2, 0>) = 21::dec?<38, 0> sum((2000000, -3217908, 629000, -100000, 0, 987654)::dec<7, 0>) = 298746::dec?<38, 0> -sum((2.5, 0, 5.0, -2.5, -7.5)::dec<2, 1>) = -2.5::dec?<38, 2> +sum((2.5, 0, 5.0, -2.5, -7.5)::dec<2, 1>) = -2.5::dec?<38, 1> sum((2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875)::dec<23, 22>) = 16.5000021457672119140625::dec?<38, 22> # overflow: Examples demonstrating overflow behavior diff --git a/tests/cases/comparison/nullif.test b/tests/cases/comparison/nullif.test index 1ef97538d..092231959 100644 --- a/tests/cases/comparison/nullif.test +++ b/tests/cases/comparison/nullif.test @@ -18,4 +18,4 @@ nullif(null::bool?, true::bool) = null::bool? nullif(true::bool, null::bool?) = true::bool? nullif(null::bool?, null::bool?) = null::bool? nullif(10::dec<38, 0>, null::dec?<38, 0>) = 10::dec?<38, 0> -nullif(null::dec?<38, 0>, null::dec?<38, 0>) = null::bool? +nullif(null::dec?<38, 0>, null::dec?<38, 0>) = null::dec?<38, 0> diff --git a/tests/coverage/coverage.py b/tests/coverage/coverage.py index 69b028d22..a471688bc 100755 --- a/tests/coverage/coverage.py +++ b/tests/coverage/coverage.py @@ -115,6 +115,12 @@ def update_test_count(test_case_files: list, function_registry: FunctionRegistry test_file.include, test_case.get_arg_types(), test_case.get_return_type(), + full_arg_types=test_case.get_full_arg_types(), + full_return_type=( + None + if test_case.is_return_type_error() + else test_case.get_return_type() + ), ) if function_variant: if ( @@ -251,6 +257,8 @@ def validate_nullability(test_file, function_registry): test_file.include, test_case.get_arg_types(), test_case.get_return_type(), + full_arg_types=test_case.get_full_arg_types(), + full_return_type=test_case.get_return_type(), ) if variant is None: continue diff --git a/tests/coverage/extensions.py b/tests/coverage/extensions.py index 0a16918dc..3d0cca6d3 100644 --- a/tests/coverage/extensions.py +++ b/tests/coverage/extensions.py @@ -4,6 +4,7 @@ from tests.coverage.antlr_parser.FuncTestCaseLexer import FuncTestCaseLexer from tests.coverage.nodes import SubstraitError, type_str_is_outer_nullable +from tests.coverage import type_checker enable_debug = False @@ -110,10 +111,12 @@ def get_supported_kernels_from_impls(func): overloads = [] for impl in func["impls"]: args = [] + raw_args = [] if "args" in impl: for arg in impl["args"]: if "value" in arg: arg_type = arg["value"] + raw_args.append(arg_type) if arg_type.endswith("?"): arg_type = arg_type[:-1] args.append(Extension.get_short_type(arg_type)) @@ -122,6 +125,7 @@ def get_supported_kernels_from_impls(func): f"arg is not a value type for function: {func['name']} arg must be enum options {arg['options']}" ) args.append("enum") + raw_args.append(None) nullability = impl.get( "nullability", "MIRROR" ) # MIRROR is the spec default @@ -134,6 +138,8 @@ def get_supported_kernels_from_impls(func): "variadic" in impl, nullability=nullability, is_return_nullable=is_return_nullable, + raw_args=raw_args, + raw_return=return_type_raw, ) ) return overloads @@ -241,6 +247,8 @@ def __init__( func_type, nullability="MIRROR", is_return_nullable=False, + raw_args=None, + raw_return=None, ): self.name = name self.urn = urn @@ -251,6 +259,8 @@ def __init__( self.func_type = func_type self.nullability = nullability self.is_return_nullable = is_return_nullable + self.raw_args = raw_args if raw_args is not None else [] + self.raw_return = raw_return self.test_count = 0 def __str__(self): @@ -268,12 +278,16 @@ def __init__( variadic, nullability="MIRROR", is_return_nullable=False, + raw_args=None, + raw_return=None, ): self.args = args self.return_type = return_type self.variadic = variadic self.nullability = nullability self.is_return_nullable = is_return_nullable + self.raw_args = raw_args if raw_args is not None else [] + self.raw_return = raw_return def __str__(self): return f"FunctionOverload(args={self.args}, result={self.return_type}, variadic={self.variadic}, nullability={self.nullability}, is_return_nullable={self.is_return_nullable})" @@ -316,6 +330,8 @@ def add_functions(self, functions, func_type): func_type, nullability=overload.nullability, is_return_nullable=overload.is_return_nullable, + raw_args=overload.raw_args, + raw_return=overload.raw_return, ) fun_arr.append(function) self.registry[f_name] = fun_arr @@ -346,12 +362,32 @@ def is_same_type(func_arg_type, arg_type): return True return FunctionRegistry.is_type_any(func_arg_type) + def _strict_signature_check(self, function, full_arg_types, full_return_type): + # Variadic impls apply their last declared arg to trailing test args. + impl_args = list(function.raw_args) + if function.variadic and impl_args: + while len(impl_args) < len(full_arg_types): + impl_args.append(impl_args[-1]) + return type_checker.check_signature( + impl_args, + function.raw_return, + list(full_arg_types), + full_return_type, + ) + def get_function( - self, name: str, urn: str, args: object, return_type - ) -> [FunctionVariant]: + self, + name: str, + urn: str, + args, + return_type, + full_arg_types=None, + full_return_type=None, + ): functions = self.registry.get(name, None) if functions is None: return None + strict_failures = [] for function in functions: if urn != function.urn: continue @@ -359,20 +395,42 @@ def get_function( function.return_type, return_type ): continue + # Loose base-type match (legacy fast path). + base_match = False if function.args == args: - return function - if len(function.args) != len(args) and not ( + base_match = True + elif len(function.args) == len(args) or ( function.variadic and len(args) >= len(function.args) ): + base_match = True + for i, arg in enumerate(args): + j = i if i < len(function.args) else len(function.args) - 1 + if not self.is_same_type(function.args[j], arg): + base_match = False + break + if not base_match: continue - is_match = True - for i, arg in enumerate(args): - j = i if i < len(function.args) else len(function.args) - 1 - if not self.is_same_type(function.args[j], arg): - is_match = False - break - if is_match: - return function + + if ( + full_arg_types is not None + and full_return_type is not None + and not isinstance(return_type, SubstraitError) + ): + ok, reason = self._strict_signature_check( + function, full_arg_types, full_return_type + ) + if not ok: + strict_failures.append((function, reason)) + continue + + return function + + if strict_failures: + _, reason = strict_failures[0] + error( + f"Strict parameter check failed for {name}" + f"({', '.join(full_arg_types or [])}) -> {full_return_type}: {reason}" + ) return None def get_extension_list(self): diff --git a/tests/coverage/nodes.py b/tests/coverage/nodes.py index d1151bd31..03977d677 100644 --- a/tests/coverage/nodes.py +++ b/tests/coverage/nodes.py @@ -97,6 +97,20 @@ def get_arg_types(self): types.append(arg.scalar_value.get_base_type()) return types + def get_full_arg_types(self): + # Full parameterized types (e.g. ``dec<38,2>``), as opposed to the + # base-only forms returned by ``get_arg_types``. + types = [] + for arg in self.args: + if isinstance(arg, CaseLiteral): + types.append(arg.type) + elif isinstance(arg, AggregateArgument): + if arg.column_type: + types.append(arg.column_type) + elif arg.scalar_value: + types.append(arg.scalar_value.type) + return types + def get_signature(self): arg_types = [] for arg in self.args: diff --git a/tests/coverage/test_type_checker.py b/tests/coverage/test_type_checker.py new file mode 100644 index 000000000..3b5169702 --- /dev/null +++ b/tests/coverage/test_type_checker.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +from tests.coverage.type_checker import ( + parse_type, + unify, + evaluate_return_formula, + check_signature, + structural_equal, +) + + +def test_parse_simple_base_type(): + assert parse_type("i32") == ("i32", False, []) + + +def test_parse_nullable_base_type(): + assert parse_type("i32?") == ("i32", True, []) + + +def test_parse_decimal_with_literals(): + assert parse_type("dec<38, 2>") == ("decimal", False, [38, 2]) + + +def test_parse_decimal_with_variables(): + assert parse_type("DECIMAL") == ("decimal", False, ["P1", "S1"]) + + +def test_parse_decimal_nullable_before_angle(): + assert parse_type("dec?<38, 2>") == ("decimal", True, [38, 2]) + + +def test_parse_decimal_nullable_after_angle(): + assert parse_type("dec<38, 2>?") == ("decimal", True, [38, 2]) + + +def test_parse_list_of_i32(): + assert parse_type("list") == ("list", False, [("i32", False, [])]) + + +def test_parse_nested_list(): + assert parse_type("list>") == ( + "list", + False, + [("list", False, [("i32", False, [])])], + ) + + +def test_parse_map_of_str_to_i32(): + assert parse_type("map") == ( + "map", + False, + [("string", False, []), ("i32", False, [])], + ) + + +def test_parse_struct_with_decimal(): + assert parse_type("STRUCT, i64>") == ( + "struct", + False, + [("decimal", False, [38, "S"]), ("i64", False, [])], + ) + + +def test_parse_func_simple(): + assert parse_type("func bool>") == ( + "func", + False, + [[("i32", False, [])], ("boolean", False, [])], + ) + + +def test_parse_func_polymorphic(): + assert parse_type("func boolean?>") == ( + "func", + False, + [[("any1", False, [])], ("boolean", True, [])], + ) + + +def test_parse_func_multi_arg(): + assert parse_type("func i32>") == ( + "func", + False, + [[("i32", False, []), ("i32", False, [])], ("i32", False, [])], + ) + + +def test_parse_canonicalises_short_to_long(): + assert parse_type("dec<10, 2>")[0] == "decimal" + assert parse_type("vchar<20>")[0] == "varchar" + assert parse_type("bool")[0] == "boolean" + assert parse_type("str")[0] == "string" + + +def test_unify_binds_variables_to_literals(): + b = {} + assert unify(parse_type("DECIMAL"), parse_type("dec<38,2>"), b) + assert b == {"P1": 38, "S1": 2} + + +def test_unify_rejects_mismatched_literal(): + assert not unify(parse_type("DECIMAL<38,0>"), parse_type("dec<10,2>"), {}) + + +def test_unify_binds_any1_to_whole_type(): + b = {} + assert unify(parse_type("any1"), parse_type("dec<10,2>"), b) + assert b["any1"] == ("decimal", False, [10, 2]) + + +def test_unify_any1_must_be_consistent_across_args(): + b = {} + assert unify(parse_type("any1"), parse_type("i32"), b) + assert not unify(parse_type("any1"), parse_type("i64"), b) + + +def test_unify_list_of_any1_with_list_of_i32(): + b = {} + assert unify(parse_type("list"), parse_type("list"), b) + assert b["any1"] == ("i32", False, []) + + +def test_unify_func_polymorphic_with_concrete(): + b = {} + assert unify( + parse_type("func boolean?>"), parse_type("func bool?>"), b + ) + assert b["any1"] == ("i32", False, []) + + +def test_unify_allows_test_to_omit_variable_only_params(): + # ``iday`` is accepted for ``interval_day

`` without binding P. + assert unify(parse_type("interval_day

"), parse_type("iday"), {}) + + +def test_evaluate_single_line_formula(): + assert evaluate_return_formula("DECIMAL<38, 2>", {}) == ("decimal", False, [38, 2]) + + +def test_evaluate_any1_resolution(): + got = evaluate_return_formula("any1?", {"any1": ("decimal", False, [38, 0])}) + assert got == ("decimal", True, [38, 0]) + + +ADD_FORMULA = ( + "init_scale = max(S1,S2)\n" + "init_prec = init_scale + max(P1 - S1, P2 - S2) + 1\n" + "min_scale = min(init_scale, 6)\n" + "delta = init_prec - 38\n" + "prec = min(init_prec, 38)\n" + "scale_after_borrow = max(init_scale - delta, min_scale)\n" + "scale = init_prec > 38 ? scale_after_borrow : init_scale\n" + "DECIMAL" +) + +DIVIDE_FORMULA = ( + "init_scale = max(6, S1 + P2 + 1)\n" + "init_prec = P1 - S1 + P2 + init_scale\n" + "min_scale = min(init_scale, 6)\n" + "delta = init_prec - 38\n" + "prec = min(init_prec, 38)\n" + "scale_after_borrow = max(init_scale - delta, min_scale)\n" + "scale = init_prec > 38 ? scale_after_borrow : init_scale\n" + "DECIMAL" +) + + +def test_evaluate_add_formula(): + # dec<10,2> + dec<5,1>: init_scale=2, init_prec=11 → dec<11,2> + got = evaluate_return_formula(ADD_FORMULA, {"P1": 10, "S1": 2, "P2": 5, "S2": 1}) + assert got == ("decimal", False, [11, 2]) + + +def test_evaluate_add_formula_overflow_borrow(): + # dec<38,10> + dec<38,10>: init_prec=39 forces scale_after_borrow=9 + got = evaluate_return_formula(ADD_FORMULA, {"P1": 38, "S1": 10, "P2": 38, "S2": 10}) + assert got == ("decimal", False, [38, 9]) + + +def test_evaluate_divide_formula(): + # dec<10,2> / dec<5,1>: init_scale=8, init_prec=21 + got = evaluate_return_formula(DIVIDE_FORMULA, {"P1": 10, "S1": 2, "P2": 5, "S2": 1}) + assert got == ("decimal", False, [21, 8]) + + +def test_check_signature_ok_for_correct_decimal_add(): + ok, reason = check_signature( + ["decimal", "decimal"], + ADD_FORMULA, + ["dec<10,2>", "dec<5,1>"], + "dec<11,2>", + ) + assert ok, reason + + +def test_check_signature_catches_sum_decimal_bug(): + # sum returns DECIMAL?<38,S>; input scale 1 → scale 1, not 2. + ok, reason = check_signature( + ["DECIMAL"], "DECIMAL?<38,S>", ["dec<2,1>"], "dec<38,2>" + ) + assert not ok + assert "decimal<38, 1>" in reason and "decimal<38, 2>" in reason + + +def test_check_signature_catches_nullif_bug(): + # nullif is any1 -> any1?; with dec<38,0> args the result cannot be bool. + ok, reason = check_signature( + ["any1", "any1"], "any1?", ["dec<38,0>", "dec<38,0>"], "bool" + ) + assert not ok + assert "decimal" in reason and "boolean" in reason + + +def test_check_signature_falls_back_when_formula_unevaluable(): + # Unbound variable → accept (test opted out of strict on that dim). + ok, _ = check_signature( + ["i64"], + "fp_precision = UNBOUND + 1\nDECIMAL", + ["i64"], + "dec<5,0>", + ) + assert ok + + +def test_check_signature_tolerates_test_omitting_decimal_params(): + # ``power(dec, dec<38,0>) -> fp64`` — first arg drops precision/scale. + ok, reason = check_signature( + ["DECIMAL", "DECIMAL"], + "fp64", + ["dec", "dec<38,0>"], + "fp64", + ) + assert ok, reason + + +def test_structural_equal_ignores_outer_nullable(): + assert structural_equal(parse_type("dec<38,2>"), parse_type("dec?<38,2>")) + + +def test_structural_equal_rejects_param_differences(): + assert not structural_equal(parse_type("dec<38,2>"), parse_type("dec<38,1>")) diff --git a/tests/coverage/type_checker.py b/tests/coverage/type_checker.py new file mode 100644 index 000000000..350db4448 --- /dev/null +++ b/tests/coverage/type_checker.py @@ -0,0 +1,426 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Parameterised type unifier and return-formula evaluator for the coverage +checker. ``FunctionRegistry.is_same_type`` only matches base names; this +module checks the full decimal/varchar/list/map/any1 parameters against the +YAML ``return`` formula.""" +import re + + +# Short-form aliases used by test files. Long forms come from the YAML. +_SHORT_TO_LONG = { + "dec": "decimal", + "vchar": "varchar", + "fchar": "fixedchar", + "fbin": "fixedbinary", + "vbin": "binary", + "str": "string", + "bool": "boolean", + "ts": "timestamp", + "tstz": "timestamp_tz", + "pts": "precision_timestamp", + "ptstz": "precision_timestamp_tz", + "pt": "precision_time", + "iyear": "interval_year", + "iday": "interval_day", + "icompound": "interval_compound", +} + +# Known concrete base names. Any simple identifier not in this set (and not +# a short alias) is treated as a variable such as ``P1`` or ``prec``. +_KNOWN_BASES = frozenset( + "i8 i16 i32 i64 fp32 fp64 boolean string binary date time timestamp " + "timestamp_tz interval_year interval_day interval_compound uuid " + "decimal varchar fixedchar fixedbinary precision_time precision_timestamp " + "precision_timestamp_tz list map struct any any1 any2 enum func".split() +) + +# Types whose parameters are element/field types (recursed into). +_TYPE_PARAM_BASES = frozenset(["list", "map", "struct"]) + +# Types whose parameters are integer literals or variables (precision, +# scale, length). +_INT_PARAM_BASES = frozenset( + [ + "decimal", + "varchar", + "fixedchar", + "fixedbinary", + "precision_time", + "precision_timestamp", + "precision_timestamp_tz", + ] +) + + +def canon_base(base): + base = base.lower() + return _SHORT_TO_LONG.get(base, base) + + +def parse_type(s): + """Parse a type string into ``(base, nullable, params)``. + + ``params`` items are ints, strings (variable or unparsed literal), or + nested ``parse_type`` tuples. ``func<... -> ...>`` is parsed into + ``(arg_types_list, return_type_tuple)``. Returns ``None`` on empty or + malformed input. + """ + if s is None: + return None + s = s.strip() + if not s: + return None + lt = s.find("<") + if lt == -1: + nullable = s.endswith("?") + return (canon_base(s.rstrip("?")), nullable, []) + + head = s[:lt] + close = _find_matching_angle(s, lt) + if close == -1: + return None + inner = s[lt + 1 : close] + tail = s[close + 1 :].strip() + nullable = head.endswith("?") or tail == "?" + base = canon_base(head.rstrip("?")) + + if base == "func": + parsed_func = _parse_func_inner(inner) + if parsed_func is None: + return (base, nullable, [inner.strip()]) + return (base, nullable, list(parsed_func)) + + raw_params = _split_top_level_commas(inner) + params = [] + if base in _TYPE_PARAM_BASES: + for p in raw_params: + nested = parse_type(p) + if nested is None: + return None + params.append(nested) + elif base in _INT_PARAM_BASES: + for p in raw_params: + try: + params.append(int(p)) + except ValueError: + params.append(p) + else: + for p in raw_params: + if "<" in p: + nested = parse_type(p) + if nested is None: + return None + params.append(nested) + else: + try: + params.append(int(p)) + except ValueError: + params.append(p) + return (base, nullable, params) + + +def _find_matching_angle(s, open_idx): + """Return the index of the ``>`` that closes ``s[open_idx]`` (a ``<``). + ``>`` that follows ``-`` is part of a lambda ``->`` arrow, not a bracket. + """ + depth = 0 + for i in range(open_idx, len(s)): + c = s[i] + if c == "<": + depth += 1 + elif c == ">": + if i > 0 and s[i - 1] == "-": + continue + depth -= 1 + if depth == 0: + return i + return -1 + + +def _split_top_level_commas(inner): + out = [] + depth = 0 + cur = "" + for i, c in enumerate(inner): + if c == "<": + depth += 1 + elif c == ">" and not (i > 0 and inner[i - 1] == "-"): + depth -= 1 + if c == "," and depth == 0: + out.append(cur.strip()) + cur = "" + else: + cur += c + if cur.strip(): + out.append(cur.strip()) + return out + + +def _parse_func_inner(inner): + """Parse ``any1 -> boolean?`` or ``i32, i32 -> i32`` into + ``(arg_types, return_type)``. Returns ``None`` on failure.""" + depth = 0 + arrow = -1 + for i in range(len(inner) - 1): + c = inner[i] + if c == "<": + depth += 1 + elif c == ">": + depth -= 1 + elif c == "-" and depth == 0 and inner[i + 1] == ">": + arrow = i + break + if arrow == -1: + return None + arg_types = [] + for piece in _split_top_level_commas(inner[:arrow]): + parsed = parse_type(piece) + if parsed is None: + return None + arg_types.append(parsed) + ret_type = parse_type(inner[arrow + 2 :].strip()) + if ret_type is None: + return None + return (arg_types, ret_type) + + +def _is_variable_name(s): + return ( + isinstance(s, str) + and re.fullmatch(r"[A-Za-z_][A-Za-z_0-9]*", s) is not None + and s.lower() not in _KNOWN_BASES + and s.lower() not in _SHORT_TO_LONG + ) + + +def unify(impl_t, test_t, bindings): + """Unify an impl type (possibly containing variables) against a concrete + test type. Updates ``bindings`` and returns True on success.""" + if impl_t is None or test_t is None: + return False + ib, _, iparams = impl_t + tb, _, tparams = test_t + + if ib in ("any", "any1", "any2"): + stripped = (tb, False, _strip_nullable(tparams)) + existing = bindings.get(ib) + if existing is None: + bindings[ib] = stripped + return True + return structural_equal(existing, stripped) + + if ib != tb: + return False + + if ib == "func": + if ( + len(iparams) != 2 + or len(tparams) != 2 + or not isinstance(iparams[0], list) + or not isinstance(tparams[0], list) + ): + return iparams == tparams + i_args, i_ret = iparams + t_args, t_ret = tparams + if len(i_args) != len(t_args): + return False + for ia, ta in zip(i_args, t_args): + if not unify(ia, ta, bindings): + return False + return unify(i_ret, t_ret, bindings) + + if len(iparams) != len(tparams): + # Test may omit numeric parameters (``iday`` for ``interval_day

``). + # Accept without binding; if the return formula depends on the + # missing variables, ``check_signature`` falls back to the loose + # check. + if ( + not tparams + and iparams + and all(isinstance(p, str) and _is_variable_name(p) for p in iparams) + ): + return True + return False + + for ip, tp in zip(iparams, tparams): + if isinstance(ip, tuple): + if not isinstance(tp, tuple) or not unify(ip, tp, bindings): + return False + elif isinstance(ip, int): + if ip != tp: + return False + elif isinstance(ip, str): + if _is_variable_name(ip): + existing = bindings.get(ip) + if existing is None: + bindings[ip] = tp + elif existing != tp: + return False + elif str(ip).lower() != str(tp).lower(): + return False + return True + + +def _strip_nullable(params): + """Return ``params`` with nested types normalised to ``nullable=False``.""" + out = [] + for p in params: + if isinstance(p, tuple): + pb, _, pp = p + out.append((pb, False, _strip_nullable(pp))) + else: + out.append(p) + return out + + +def structural_equal(a, b): + """Structural equality ignoring outer and inner nullable flags.""" + if a is None or b is None: + return a is b + ab, _, ap = a + bb, _, bp = b + if ab != bb or len(ap) != len(bp): + return False + for x, y in zip(ap, bp): + if isinstance(x, tuple) or isinstance(y, tuple): + if not structural_equal(x, y): + return False + elif x != y: + return False + return True + + +def _ternary_to_py(expr): + """Rewrite ``cond ? a : b`` (possibly nested) as ``(a) if (cond) else (b)``.""" + q_pos = _find_at_depth_zero(expr, "?") + if q_pos == -1: + return expr + colon_pos = _find_at_depth_zero(expr, ":", start=q_pos + 1) + if colon_pos == -1: + return expr + cond = expr[:q_pos].strip() + a = expr[q_pos + 1 : colon_pos].strip() + b = expr[colon_pos + 1 :].strip() + return f"(({_ternary_to_py(a)}) if ({cond}) else ({_ternary_to_py(b)}))" + + +def _find_at_depth_zero(expr, target, start=0): + depth = 0 + for i in range(start, len(expr)): + c = expr[i] + if c == "(": + depth += 1 + elif c == ")": + depth -= 1 + elif c == target and depth == 0: + return i + return -1 + + +def _evaluate_param_expr(expr, env): + try: + return eval( + _ternary_to_py(expr.strip()), + {"__builtins__": {}, "min": min, "max": max}, + env, + ) + except Exception: + return None + + +def evaluate_return_formula(formula, bindings): + """Evaluate a YAML return formula with variable ``bindings`` and return + a concrete parsed-type tuple, or ``None`` if evaluation fails.""" + if not formula: + return None + formula = str(formula).strip() + env = {k: v for k, v in bindings.items() if isinstance(v, int)} + lines = [ln.strip() for ln in formula.split("\n") if ln.strip()] + if not lines: + return None + for line in lines[:-1]: + if "=" not in line: + return None + name, expr = line.split("=", 1) + value = _evaluate_param_expr(expr, env) + if value is None: + return None + env[name.strip()] = value + return _resolve_type(lines[-1], env, bindings) + + +def _resolve_type(value, env, any_bindings): + """Evaluate variable params (and ``any1``/``any2`` references) inside a + type expression. ``value`` is either a string to parse or a parsed + tuple. Returns a tuple with fully evaluated params, or ``None``.""" + parsed = parse_type(value) if isinstance(value, str) else value + if parsed is None: + return None + base, nullable, params = parsed + if base in ("any", "any1", "any2"): + repl = any_bindings.get(base) + if repl is None: + return None + r_base, _, r_params = repl + return (r_base, nullable, r_params) + out = [] + for p in params: + if isinstance(p, int): + out.append(p) + elif isinstance(p, tuple): + sub = _resolve_type(p, env, any_bindings) + if sub is None: + return None + out.append(sub) + elif isinstance(p, str): + v = _evaluate_param_expr(p, env) + if v is None: + return None + out.append(v) + return (base, nullable, out) + + +def check_signature(impl_args, impl_return, test_args, test_return): + """Strict signature check. Returns ``(ok, reason)``. When the impl's + return formula can't be evaluated (e.g. test omits a numeric parameter) + the check accepts the signature — the loose match still applies.""" + if len(impl_args) != len(test_args): + return (False, f"arg count {len(impl_args)} vs {len(test_args)}") + + bindings = {} + for i, (ia, ta) in enumerate(zip(impl_args, test_args)): + parsed_impl = parse_type(ia) + if parsed_impl is None: + # impl arg is not a value type (e.g. enum option) — skip. + continue + parsed_test = parse_type(ta) + if parsed_test is None: + return (False, f"arg {i}: failed to parse test type {ta!r}") + if not unify(parsed_impl, parsed_test, bindings): + return (False, f"arg {i}: cannot unify {ia} with {ta}") + + expected = evaluate_return_formula(impl_return, bindings) + if expected is None: + return (True, "") + + parsed_test_ret = parse_type(test_return) + if parsed_test_ret is None: + return (False, f"failed to parse test return type {test_return!r}") + + if not structural_equal(expected, parsed_test_ret): + return ( + False, + f"return: expected {_format_type(expected)} " + f"but test declares {_format_type(parsed_test_ret)}", + ) + return (True, "") + + +def _format_type(t): + if t is None: + return "" + base, _, params = t + if not params: + return base + inner = [_format_type(p) if isinstance(p, tuple) else str(p) for p in params] + return f"{base}<{', '.join(inner)}>"