Skip to content

Commit 6083883

Browse files
rustyconover and claude committed
Add scalar function overloading, type casting, and review fixes
- Add function overloading by ConstParam count for scalar functions
- Add format_number and make_series overloaded example functions
- Add type casting with debug logging in _validate_single_param_type
- Extract shared _make_series_emit helper to deduplicate process() methods
- Remove dead TopN table-in-out code (never registered, DuckDB unsupported)
- Update catalog interface to support multiple functions per name

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent cdb42b9 commit 6083883

8 files changed

Lines changed: 447 additions & 42 deletions

File tree

tests/test_worker.py

Lines changed: 103 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,10 @@
88
from vgi_rpc.rpc import OutputCollector
99

1010
from vgi import Arg, TableInOutFunction, TableInput
11-
from vgi.arguments import Arguments
11+
from vgi.arguments import Arguments, ConstParam, Param, Returns
1212
from vgi.invocation import FunctionType
1313
from vgi.protocol import BindRequest, TableFunctionCardinalityRequest
14+
from vgi.scalar_function import ScalarFunction
1415
from vgi.table_function import (
1516
BindParams,
1617
ProcessParams,
@@ -589,3 +590,104 @@ class MyWorker(Worker):
589590
)
590591
result = worker.table_function_cardinality(request)
591592
assert result.estimate == 30
593+
594+
595+
class TestScalarOverloading:
596+
"""Tests for scalar function overloading by ConstParam count."""
597+
598+
def _make_scalar_candidates(self) -> list[type]:
599+
"""Create three scalar overloads with 0, 1, and 2 ConstParams."""
600+
601+
class ZeroConst(ScalarFunction):
602+
class Meta:
603+
name = "fmt"
604+
605+
@classmethod
606+
def compute(
607+
cls,
608+
val: Annotated[pa.DoubleArray, Param(doc="Value")],
609+
) -> Annotated[pa.StringArray, Returns()]:
610+
return pa.array([str(v) for v in val.to_pylist()], type=pa.string())
611+
612+
class OneConst(ScalarFunction):
613+
class Meta:
614+
name = "fmt"
615+
616+
@classmethod
617+
def compute(
618+
cls,
619+
prec: Annotated[int, ConstParam("Precision")],
620+
val: Annotated[pa.DoubleArray, Param(doc="Value")],
621+
) -> Annotated[pa.StringArray, Returns()]:
622+
return pa.array([f"{v:.{prec}f}" for v in val.to_pylist()], type=pa.string())
623+
624+
class TwoConst(ScalarFunction):
625+
class Meta:
626+
name = "fmt"
627+
628+
@classmethod
629+
def compute(
630+
cls,
631+
prec: Annotated[int, ConstParam("Precision")],
632+
pfx: Annotated[str, ConstParam("Prefix")],
633+
val: Annotated[pa.DoubleArray, Param(doc="Value")],
634+
) -> Annotated[pa.StringArray, Returns()]:
635+
return pa.array([f"{pfx}{v:.{prec}f}" for v in val.to_pylist()], type=pa.string())
636+
637+
return [ZeroConst, OneConst, TwoConst]
638+
639+
def test_match_by_const_param_count(self) -> None:
640+
"""Scalar overloads are matched by ConstParam count."""
641+
candidates = self._make_scalar_candidates()
642+
643+
# 0 const args -> ZeroConst
644+
result = Worker._match_function_arguments(
645+
function_name="fmt",
646+
arguments=Arguments(positional=()),
647+
input_schema=pa.schema([("val", pa.float64())]),
648+
candidates=candidates,
649+
)
650+
assert result is candidates[0]
651+
652+
# 1 const arg -> OneConst
653+
result = Worker._match_function_arguments(
654+
function_name="fmt",
655+
arguments=Arguments(positional=(pa.scalar(2),)),
656+
input_schema=pa.schema([("val", pa.float64())]),
657+
candidates=candidates,
658+
)
659+
assert result is candidates[1]
660+
661+
# 2 const args -> TwoConst
662+
result = Worker._match_function_arguments(
663+
function_name="fmt",
664+
arguments=Arguments(positional=(pa.scalar(2), pa.scalar("$"))),
665+
input_schema=pa.schema([("val", pa.float64())]),
666+
candidates=candidates,
667+
)
668+
assert result is candidates[2]
669+
670+
def test_zero_const_params_matches(self) -> None:
671+
"""A scalar function with 0 ConstParams correctly matches 0 positional args."""
672+
candidates = self._make_scalar_candidates()
673+
674+
result = Worker._match_function_arguments(
675+
function_name="fmt",
676+
arguments=Arguments(positional=()),
677+
input_schema=pa.schema([("val", pa.float64())]),
678+
candidates=candidates,
679+
)
680+
# Should match ZeroConst (0 ConstParams), not fail
681+
assert result is candidates[0]
682+
683+
def test_no_match_error_scalar(self) -> None:
684+
"""Too many const args gives helpful error for scalar overloads."""
685+
candidates = self._make_scalar_candidates()
686+
687+
with pytest.raises(ValueError, match="No matching function"):
688+
Worker._match_function_arguments(
689+
function_name="fmt",
690+
arguments=Arguments(positional=(pa.scalar(1), pa.scalar(2), pa.scalar(3))),
691+
input_schema=pa.schema([("val", pa.float64())]),
692+
candidates=candidates,
693+
)

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vgi/catalog/catalog_interface.py

Lines changed: 16 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -989,7 +989,7 @@ class ReadOnlyCatalogInterface(CatalogInterface):
989989
_schema_registry: "dict[str, Schema] | None" = None
990990
_table_registry: "dict[tuple[str, str], Table] | None" = None
991991
_view_registry: "dict[tuple[str, str], View] | None" = None
992-
_function_registry: "dict[tuple[str, str], type] | None" = None
992+
_function_registry: "dict[tuple[str, str], list[type]] | None" = None
993993
_macro_registry: "dict[tuple[str, str], Macro] | None" = None
994994

995995
def _build_registries(self) -> None:
@@ -1025,9 +1025,9 @@ def _register_view(schema_key: str, view: "View") -> None:
10251025
def _register_function(schema_key: str, func_cls: type) -> None:
10261026
meta = func_cls.get_metadata() # type: ignore[attr-defined]
10271027
key = (schema_key, meta.name.lower())
1028-
if key in self._function_registry: # type: ignore[operator]
1029-
raise ValueError(f"Duplicate function '{meta.name}' in schema '{schema_key}'")
1030-
self._function_registry[key] = func_cls # type: ignore[index]
1028+
if key not in self._function_registry: # type: ignore[operator]
1029+
self._function_registry[key] = [] # type: ignore[index]
1030+
self._function_registry[key].append(func_cls) # type: ignore[index]
10311031

10321032
def _register_macro(schema_key: str, macro: "Macro") -> None:
10331033
key = (schema_key, macro.name.lower())
@@ -1318,19 +1318,20 @@ def schema_contents(
13181318
results.append(macro.to_macro_info(schema_name))
13191319
else:
13201320
# SCALAR_FUNCTION or TABLE_FUNCTION
1321-
for (sn, _), func_cls in self._function_registry.items():
1321+
for (sn, _), func_classes in self._function_registry.items():
13221322
if sn != name_lower:
13231323
continue
1324-
func_info = self._function_to_info(func_cls, schema_name)
1325-
# Filter by function type
1326-
if type_enum == SchemaObjectType.SCALAR_FUNCTION and func_info.function_type != FunctionType.SCALAR:
1327-
continue
1328-
if type_enum == SchemaObjectType.TABLE_FUNCTION and func_info.function_type not in (
1329-
FunctionType.TABLE,
1330-
FunctionType.AGGREGATE,
1331-
):
1332-
continue
1333-
results.append(func_info)
1324+
for func_cls in func_classes:
1325+
func_info = self._function_to_info(func_cls, schema_name)
1326+
# Filter by function type
1327+
if type_enum == SchemaObjectType.SCALAR_FUNCTION and func_info.function_type != FunctionType.SCALAR:
1328+
continue
1329+
if type_enum == SchemaObjectType.TABLE_FUNCTION and func_info.function_type not in (
1330+
FunctionType.TABLE,
1331+
FunctionType.AGGREGATE,
1332+
):
1333+
continue
1334+
results.append(func_info)
13341335

13351336
return results
13361337

vgi/examples/scalar.py

Lines changed: 121 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,9 @@
4646
__all__ = [
4747
"AddValuesFunction",
4848
"BernoulliFunction",
49+
"FormatNumberDefaultFunction",
50+
"FormatNumberFullFunction",
51+
"FormatNumberPrecisionFunction",
4952
"HashSeedFunction",
5053
"BinaryPacketFunction",
5154
"ConditionalMessageFunction",
@@ -732,3 +735,121 @@ def compute(
732735
[json.dumps(secret_dict) for _ in range(_length)],
733736
type=pa.string(),
734737
)
738+
739+
740+
# ============================================================================
741+
# format_number — overloaded scalar function (3 overloads by ConstParam count)
742+
# ============================================================================
743+
744+
745+
class FormatNumberDefaultFunction(ScalarFunction):
746+
"""Format a number with default precision (0 decimal places).
747+
748+
Overload with 0 ConstParams: just a column input.
749+
750+
Example:
751+
SQL: SELECT format_number(price) FROM products
752+
Input: price=[3.14, 2.718, 100.5]
753+
Output: result=['3', '3', '100']
754+
755+
"""
756+
757+
class Meta:
758+
"""Function metadata."""
759+
760+
name = "format_number"
761+
description = "Format number with default precision (0 decimals)"
762+
examples = [
763+
FunctionExample(
764+
sql="SELECT format_number(price) FROM products",
765+
description="Format prices with no decimal places",
766+
),
767+
]
768+
769+
@classmethod
770+
def compute(
771+
cls,
772+
value: Annotated[pa.DoubleArray, Param(doc="Number to format")],
773+
) -> Annotated[pa.StringArray, Returns()]:
774+
"""Format each value with 0 decimal places."""
775+
return pa.array(
776+
[f"{v:.0f}" if v is not None else None for v in value.to_pylist()],
777+
type=pa.string(),
778+
)
779+
780+
781+
class FormatNumberPrecisionFunction(ScalarFunction):
782+
"""Format a number with specified precision.
783+
784+
Overload with 1 ConstParam: precision.
785+
786+
Example:
787+
SQL: SELECT format_number(2, price) FROM products
788+
Input: price=[3.14159, 2.718, 100.5]
789+
Args: precision=2
790+
Output: result=['3.14', '2.72', '100.50']
791+
792+
"""
793+
794+
class Meta:
795+
"""Function metadata."""
796+
797+
name = "format_number"
798+
description = "Format number with specified precision"
799+
examples = [
800+
FunctionExample(
801+
sql="SELECT format_number(2, price) FROM products",
802+
description="Format prices with 2 decimal places",
803+
),
804+
]
805+
806+
@classmethod
807+
def compute(
808+
cls,
809+
precision: Annotated[int, ConstParam("Number of decimal places")],
810+
value: Annotated[pa.DoubleArray, Param(doc="Number to format")],
811+
) -> Annotated[pa.StringArray, Returns()]:
812+
"""Format each value with the specified precision."""
813+
return pa.array(
814+
[f"{v:.{precision}f}" if v is not None else None for v in value.to_pylist()],
815+
type=pa.string(),
816+
)
817+
818+
819+
class FormatNumberFullFunction(ScalarFunction):
820+
"""Format a number with precision and prefix.
821+
822+
Overload with 2 ConstParams: precision and prefix.
823+
824+
Example:
825+
SQL: SELECT format_number(2, '$', price) FROM products
826+
Input: price=[3.14, 2.718, 100.5]
827+
Args: precision=2, prefix='$'
828+
Output: result=['$3.14', '$2.72', '$100.50']
829+
830+
"""
831+
832+
class Meta:
833+
"""Function metadata."""
834+
835+
name = "format_number"
836+
description = "Format number with precision and prefix"
837+
examples = [
838+
FunctionExample(
839+
sql="SELECT format_number(2, '$', price) FROM products",
840+
description="Format prices with dollar sign and 2 decimals",
841+
),
842+
]
843+
844+
@classmethod
845+
def compute(
846+
cls,
847+
precision: Annotated[int, ConstParam("Number of decimal places")],
848+
prefix: Annotated[str, ConstParam("Prefix string")],
849+
value: Annotated[pa.DoubleArray, Param(doc="Number to format")],
850+
) -> Annotated[pa.StringArray, Returns()]:
851+
"""Format each value with prefix and specified precision."""
852+
return pa.array(
853+
[f"{prefix}{v:.{precision}f}" if v is not None else None for v in value.to_pylist()],
854+
type=pa.string(),
855+
)

0 commit comments

Comments
 (0)