Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/cases/arithmetic_decimal/sum_decimal.test
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# basic: Basic examples without any special cases
sum((0, -1, 2, 20)::dec<2, 0>) = 21::dec?<38, 0>
sum((2000000, -3217908, 629000, -100000, 0, 987654)::dec<7, 0>) = 298746::dec?<38, 0>
sum((2.5, 0, 5.0, -2.5, -7.5)::dec<2, 1>) = -2.5::dec?<38, 2>
sum((2.5, 0, 5.0, -2.5, -7.5)::dec<2, 1>) = -2.5::dec?<38, 1>
sum((2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875)::dec<23, 22>) = 16.5000021457672119140625::dec?<38, 22>

# overflow: Examples demonstrating overflow behavior
Expand Down
2 changes: 1 addition & 1 deletion tests/cases/comparison/nullif.test
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ nullif(null::bool?, true::bool) = null::bool?
nullif(true::bool, null::bool?) = true::bool?
nullif(null::bool?, null::bool?) = null::bool?
nullif(10::dec<38, 0>, null::dec?<38, 0>) = 10::dec?<38, 0>
nullif(null::dec?<38, 0>, null::dec?<38, 0>) = null::bool?
nullif(null::dec?<38, 0>, null::dec?<38, 0>) = null::dec?<38, 0>
8 changes: 8 additions & 0 deletions tests/coverage/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ def update_test_count(test_case_files: list, function_registry: FunctionRegistry
test_file.include,
test_case.get_arg_types(),
test_case.get_return_type(),
full_arg_types=test_case.get_full_arg_types(),
full_return_type=(
None
if test_case.is_return_type_error()
else test_case.get_return_type()
),
)
if function_variant:
if (
Expand Down Expand Up @@ -251,6 +257,8 @@ def validate_nullability(test_file, function_registry):
test_file.include,
test_case.get_arg_types(),
test_case.get_return_type(),
full_arg_types=test_case.get_full_arg_types(),
full_return_type=test_case.get_return_type(),
)
if variant is None:
continue
Expand Down
82 changes: 70 additions & 12 deletions tests/coverage/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from tests.coverage.antlr_parser.FuncTestCaseLexer import FuncTestCaseLexer
from tests.coverage.nodes import SubstraitError, type_str_is_outer_nullable
from tests.coverage import type_checker

enable_debug = False

Expand Down Expand Up @@ -110,10 +111,12 @@ def get_supported_kernels_from_impls(func):
overloads = []
for impl in func["impls"]:
args = []
raw_args = []
if "args" in impl:
for arg in impl["args"]:
if "value" in arg:
arg_type = arg["value"]
raw_args.append(arg_type)
if arg_type.endswith("?"):
arg_type = arg_type[:-1]
args.append(Extension.get_short_type(arg_type))
Expand All @@ -122,6 +125,7 @@ def get_supported_kernels_from_impls(func):
f"arg is not a value type for function: {func['name']} arg must be enum options {arg['options']}"
)
args.append("enum")
raw_args.append(None)
nullability = impl.get(
"nullability", "MIRROR"
) # MIRROR is the spec default
Expand All @@ -134,6 +138,8 @@ def get_supported_kernels_from_impls(func):
"variadic" in impl,
nullability=nullability,
is_return_nullable=is_return_nullable,
raw_args=raw_args,
raw_return=return_type_raw,
)
)
return overloads
Expand Down Expand Up @@ -241,6 +247,8 @@ def __init__(
func_type,
nullability="MIRROR",
is_return_nullable=False,
raw_args=None,
raw_return=None,
):
self.name = name
self.urn = urn
Expand All @@ -251,6 +259,8 @@ def __init__(
self.func_type = func_type
self.nullability = nullability
self.is_return_nullable = is_return_nullable
self.raw_args = raw_args if raw_args is not None else []
self.raw_return = raw_return
self.test_count = 0

def __str__(self):
Expand All @@ -268,12 +278,16 @@ def __init__(
variadic,
nullability="MIRROR",
is_return_nullable=False,
raw_args=None,
raw_return=None,
):
self.args = args
self.return_type = return_type
self.variadic = variadic
self.nullability = nullability
self.is_return_nullable = is_return_nullable
self.raw_args = raw_args if raw_args is not None else []
self.raw_return = raw_return

def __str__(self):
return f"FunctionOverload(args={self.args}, result={self.return_type}, variadic={self.variadic}, nullability={self.nullability}, is_return_nullable={self.is_return_nullable})"
Expand Down Expand Up @@ -316,6 +330,8 @@ def add_functions(self, functions, func_type):
func_type,
nullability=overload.nullability,
is_return_nullable=overload.is_return_nullable,
raw_args=overload.raw_args,
raw_return=overload.raw_return,
)
fun_arr.append(function)
self.registry[f_name] = fun_arr
Expand Down Expand Up @@ -346,33 +362,75 @@ def is_same_type(func_arg_type, arg_type):
return True
return FunctionRegistry.is_type_any(func_arg_type)

def _strict_signature_check(self, function, full_arg_types, full_return_type):
# Variadic impls apply their last declared arg to trailing test args.
impl_args = list(function.raw_args)
if function.variadic and impl_args:
while len(impl_args) < len(full_arg_types):
impl_args.append(impl_args[-1])
return type_checker.check_signature(
impl_args,
function.raw_return,
list(full_arg_types),
full_return_type,
)

def get_function(
self, name: str, urn: str, args: object, return_type
) -> [FunctionVariant]:
self,
name: str,
urn: str,
args,
return_type,
full_arg_types=None,
full_return_type=None,
):
functions = self.registry.get(name, None)
if functions is None:
return None
strict_failures = []
for function in functions:
if urn != function.urn:
continue
if not isinstance(return_type, SubstraitError) and not self.is_same_type(
function.return_type, return_type
):
continue
# Loose base-type match (legacy fast path).
base_match = False
if function.args == args:
return function
if len(function.args) != len(args) and not (
base_match = True
elif len(function.args) == len(args) or (
function.variadic and len(args) >= len(function.args)
):
base_match = True
for i, arg in enumerate(args):
j = i if i < len(function.args) else len(function.args) - 1
if not self.is_same_type(function.args[j], arg):
base_match = False
break
if not base_match:
continue
is_match = True
for i, arg in enumerate(args):
j = i if i < len(function.args) else len(function.args) - 1
if not self.is_same_type(function.args[j], arg):
is_match = False
break
if is_match:
return function

if (
full_arg_types is not None
and full_return_type is not None
and not isinstance(return_type, SubstraitError)
):
ok, reason = self._strict_signature_check(
function, full_arg_types, full_return_type
)
if not ok:
strict_failures.append((function, reason))
continue

return function

if strict_failures:
_, reason = strict_failures[0]
error(
f"Strict parameter check failed for {name}"
f"({', '.join(full_arg_types or [])}) -> {full_return_type}: {reason}"
)
return None

def get_extension_list(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/coverage/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@ def get_arg_types(self):
types.append(arg.scalar_value.get_base_type())
return types

def get_full_arg_types(self):
# Full parameterized types (e.g. ``dec<38,2>``), as opposed to the
# base-only forms returned by ``get_arg_types``.
types = []
for arg in self.args:
if isinstance(arg, CaseLiteral):
types.append(arg.type)
elif isinstance(arg, AggregateArgument):
if arg.column_type:
types.append(arg.column_type)
elif arg.scalar_value:
types.append(arg.scalar_value.type)
return types

def get_signature(self):
arg_types = []
for arg in self.args:
Expand Down
Loading