Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
73b76f7
feat(parser): update to Substrait v0.87.0 with enum arg support
benbellick Apr 9, 2026
4604677
chore: go mod tidy after substrait v0.87.0 upgrade
benbellick Apr 9, 2026
f290b00
refactor(parser): route enum args via CaseLiteral.FuncArg() not Value…
benbellick Apr 9, 2026
4e71758
refactor(parser): widen CaseLiteral.Value to types.FuncArg
benbellick Apr 9, 2026
6713850
refactor(parser): remove WithType from VisitDecimalArg
benbellick Apr 9, 2026
0abf3ed
refactor(parser): remove string-to-enum fallback from GetScalarFuncti…
benbellick Apr 9, 2026
8a18d59
chore: skip std_dev/variance tests and revert getAggregateFuncTableSc…
benbellick Apr 9, 2026
4e0469b
fix: EnumType.ShortString() should return "req" per substrait spec
benbellick Apr 9, 2026
10e28f1
fix: getArgTypes returns CommonEnumType for enum args instead of nil
benbellick Apr 9, 2026
7622e18
fix(parser): cleanup from review — drop stale TODO, extract literal()…
benbellick Apr 9, 2026
dff6d3f
fix(parser): use literal() helper consistently, document enum arg lim…
benbellick Apr 9, 2026
9168505
Merge branch 'main' into feat/substrait-v0.87.0-enum-arg-support
benbellick Apr 10, 2026
e518a40
chore(testcases): update std_dev/variance skip to reference upstream fix
benbellick Apr 13, 2026
2aecefa
chore(testcases): remove stale #223 comment and clarify nil-type work…
benbellick Apr 13, 2026
f9f5fa8
test(testcases): add TestParseEnumArg to cover enum arg parsing and S…
benbellick Apr 13, 2026
37d30e2
refactor(testcases): inline literal() helper
benbellick Apr 13, 2026
8d3d478
fix(testcases): fix enum rendering in String/AsAggregateArgumentStrin…
benbellick Apr 13, 2026
fa6bb76
chore: restore original comments in visitor.go to keep diff minimal
benbellick Apr 13, 2026
7da3b46
chore: restore argTypes variable pattern in testGetFunctionInvocation
benbellick Apr 13, 2026
d2134e2
refactor(testcases): simplify TestParseEnumArg
benbellick Apr 13, 2026
b50e9ae
test: cover types.Enum case in getArgTypes and EnumType.ShortString
benbellick Apr 13, 2026
9137a77
fix(testcases): remove unreachable Literal branch in AsAggregateArgum…
benbellick Apr 14, 2026
c4095ef
docs(types): add comment explaining EnumType.ShortString returns "req"
benbellick Apr 14, 2026
3d43a74
fix(testcases): restore general fallback in AsAggregateArgumentString
benbellick Apr 14, 2026
858423d
refactor(testcases): use type switch for arg value assertion
benbellick Apr 14, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions expr/binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,10 @@ func TestBoundExpressions(t *testing.T) {
})
}
}

func TestGetArgTypesWithEnum(t *testing.T) {
f, err := NewScalarFunc(extReg, extractID, nil, types.Enum("YEAR"),
MustExpr(NewRootFieldRef(NewStructFieldRef(9), types.NewRecordTypeFromStruct(boringSchema.Struct))))
assert.NoError(t, err)
assert.Equal(t, types.CommonEnumType, f.GetArgTypes()[0])
}
2 changes: 2 additions & 0 deletions expr/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,8 @@ func getArgTypes(args []types.FuncArg) []types.Type {
argTypes := make([]types.Type, len(args))
for i, arg := range args {
switch a := arg.(type) {
case types.Enum:
argTypes[i] = types.CommonEnumType
case Expression:
argTypes[i] = a.GetType()
case types.Type:
Expand Down
2 changes: 1 addition & 1 deletion extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type TypeVariationFunctions string
const (
TypeVariationInheritsFuncs TypeVariationFunctions = "INHERITS"
TypeVariationSeparateFuncs TypeVariationFunctions = "SEPARATE"
EnumTypeString = "req" // TODO change this to "enum"
EnumTypeString = "req"
)

type TypeVariation struct {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ require (
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
github.com/stretchr/testify v1.10.0
github.com/substrait-io/substrait v0.85.0
github.com/substrait-io/substrait v0.87.0
github.com/substrait-io/substrait-protobuf/go v0.85.0
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
google.golang.org/protobuf v1.36.6
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/substrait-io/substrait v0.85.0 h1:ur2VBFhOpx/3RjVG0w5i8SpLHciLjATh27kAc2HPf5A=
github.com/substrait-io/substrait v0.85.0/go.mod h1:MPFNw6sToJgpD5Z2rj0rQrdP/Oq8HG7Z2t3CAEHtkHw=
github.com/substrait-io/substrait v0.87.0 h1:40rP4LejyK6SNQlWz7NX6kQELf8cmScWMBGruWhN4io=
github.com/substrait-io/substrait v0.87.0/go.mod h1:MPFNw6sToJgpD5Z2rj0rQrdP/Oq8HG7Z2t3CAEHtkHw=
github.com/substrait-io/substrait-protobuf/go v0.85.0 h1:zk6MtNWLtDSl8a7qCZRFH0+EIIXVrrd/hsgYK/SQTgM=
github.com/substrait-io/substrait-protobuf/go v0.85.0/go.mod h1:hn+Szm1NmZZc91FwWK9EXD/lmuGBSRTJ5IvHhlG1YnQ=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
Expand Down
10 changes: 5 additions & 5 deletions grammar/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

package grammar

// using substrait v0.85.0
// using substrait v0.87.0

//go:generate wget -nc https://www.antlr.org/download/antlr-4.13.2-complete.jar
//go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/v0.85.0/grammar/SubstraitLexer.g4
//go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/v0.85.0/grammar/SubstraitType.g4
//go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/v0.85.0/grammar/FuncTestCaseLexer.g4
//go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/v0.85.0/grammar/FuncTestCaseParser.g4
//go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/v0.87.0/grammar/SubstraitLexer.g4
//go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/v0.87.0/grammar/SubstraitType.g4
//go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/v0.87.0/grammar/FuncTestCaseLexer.g4
//go:generate wget https://raw.githubusercontent.com/substrait-io/substrait/v0.87.0/grammar/FuncTestCaseParser.g4
//go:generate -command antlr java -Xmx500M -cp "./antlr-4.13.2-complete.jar:$CLASSPATH" org.antlr.v4.Tool
//go:generate antlr -Dlanguage=Go -visitor -Dlanguage=Go -package baseparser -o "../types/parser/baseparser" SubstraitLexer.g4 SubstraitType.g4
//go:generate antlr -Dlanguage=Go -visitor -no-listener -Dlanguage=Go -package baseparser -o "../testcases/parser/baseparser" FuncTestCaseLexer.g4 FuncTestCaseParser.g4
1,316 changes: 660 additions & 656 deletions testcases/parser/baseparser/functestcase_lexer.go

Large diffs are not rendered by default.

2,023 changes: 1,086 additions & 937 deletions testcases/parser/baseparser/functestcase_parser.go

Large diffs are not rendered by default.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions testcases/parser/baseparser/functestcaseparser_visitor.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 12 additions & 3 deletions testcases/parser/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const (
type CaseLiteral struct {
Type types.Type
ValueText string
Value expr.Literal
Value types.FuncArg
SubstraitError *SubstraitError
}

Expand All @@ -34,7 +34,11 @@ func (c *CaseLiteral) String() string {
if c.Value == nil {
return "NULL"
}
return literalToString(c.Value) + "::" + c.Type.String()
if lit, ok := c.Value.(expr.Literal); ok {
return literalToString(lit) + "::" + c.Type.String()
}
// Enum args use CommonEnumType whose String() is empty; render as "enum"
return c.ValueText + "::enum"
}

func literalToString(literal expr.Literal) string {
Expand Down Expand Up @@ -76,13 +80,17 @@ func (c *CaseLiteral) AsAggregateArgumentString() string {
}
return "(" + strings.Join(elements, ", ") + ")::" + c.Type.String()
}
return c.Value.ValueString() + "::" + c.Type.String()
if lit, ok := c.Value.(expr.Literal); ok {
return lit.ValueString() + "::" + c.Type.String()
}
return c.ValueText + "::enum"
}

// updateLiteralType updates the type of the literal CaseLiteral.Value to use the CaseLiteral.Type
// Parser creates a literal with a type using existing util functions.
// For ParameterizedTypes utils functions use minimum required values for the parameters.
// This function changes the type to use requested type, so that the function invocation object is created correctly.
// Enum args are excluded: CommonEnumType has no parameters, so they return early.
func (c *CaseLiteral) updateLiteralType() error {
if len(c.Type.GetParameters()) == 0 {
return nil
Expand Down Expand Up @@ -396,6 +404,7 @@ func (tc *TestCase) GetAggregateFunctionInvocation(reg *expr.ExtensionRegistry,
return nil, fmt.Errorf("%w: no matching function found or %s", substraitgo.ErrNotFound, id)
}

// GetAggregateColumnsData returns column data for aggregate test cases.
func (tc *TestCase) GetAggregateColumnsData() ([][]expr.Literal, error) {
if tc.FuncType != AggregateFuncType {
return nil, fmt.Errorf("expected function type %v, but got %v", AggregateFuncType, tc.FuncType)
Expand Down
46 changes: 37 additions & 9 deletions testcases/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ func TestParseTestWithVariousTypes(t *testing.T) {
{testCaseStr: "concat('abcd'::vchar<9>, Null::varchar<9>) = Null::vchar<9>", expTestStr: "concat('abcd'::varchar<9>, null::varchar?<9>) = null::varchar?<9>"},
{testCaseStr: "concat('abcd'::vchar<9>, Null::fixedchar<9>) = Null::fchar<9>", expTestStr: "concat('abcd'::varchar<9>, null::fixedchar?<9>) = null::fixedchar?<9>"},
{testCaseStr: "concat('abcd'::fbin<9>, Null::fixedbinary<9>) = Null::fbin<9>", expTestStr: "concat('0x61626364'::fixedbinary<9>, null::fixedbinary?<9>) = null::fixedbinary?<9>"},
{testCaseStr: "extract(YEAR::enum, '2016-12-31T13:30:15'::ts) = 2016::i64", expTestStr: "extract(YEAR::enum, '2016-12-31T13:30:15'::timestamp) = 2016::i64"},
{testCaseStr: "f35('1991-01-01T01:02:03.456'::pts<3>) = '1991-01-01T01:02:30.123123'::precision_timestamp<3>", expTestStr: "f35('1991-01-01T01:02:03.456'::precision_timestamp<3>) = '1991-01-01T01:02:30.123'::precision_timestamp<3>"},
{testCaseStr: "f36('1991-01-01T01:02:03.456'::pts<3>, '1991-01-01T01:02:30.123123'::precision_timestamp<3>) = 123456::i64", expTestStr: "f36('1991-01-01T01:02:03.456'::precision_timestamp<3>, '1991-01-01T01:02:30.123'::precision_timestamp<3>) = 123456::i64"},
{testCaseStr: "f37('1991-01-01T01:02:03.123456'::pts<6>, '1991-01-01T04:05:06.456'::precision_timestamp<6>) = 123456::i64", expTestStr: "f37('1991-01-01T01:02:03.123456'::precision_timestamp<6>, '1991-01-01T04:05:06.456'::precision_timestamp<6>) = 123456::i64"},
Expand All @@ -208,10 +209,15 @@ func TestParseTestWithVariousTypes(t *testing.T) {
}
for _, arg := range testFile.TestCases[0].Args {
assert.NotNil(t, arg.Value)
checkNullability(t, arg.Value, arg.Type)
if lit, ok := arg.Value.(expr.Literal); ok {
checkNullability(t, lit, arg.Type)
} else {
assert.Equal(t, types.CommonEnumType, arg.Type)
}
}
assert.NotNil(t, testFile.TestCases[0].Result.Value)
checkNullability(t, testFile.TestCases[0].Result.Value, testFile.TestCases[0].Result.Type)
resultLit, ok := testFile.TestCases[0].Result.Value.(expr.Literal)
require.True(t, ok, "result should be a literal, got %T", testFile.TestCases[0].Result.Value)
checkNullability(t, resultLit, testFile.TestCases[0].Result.Type)
})
}
}
Expand Down Expand Up @@ -427,7 +433,9 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ER
"sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ERROR>",
}
assert.Equal(t, newFloat32List(1, 2, 3), tc.AggregateArgs[0].Argument.Value)
assert.Equal(t, listType, tc.AggregateArgs[0].Argument.Value.GetType())
if lit, ok := tc.AggregateArgs[0].Argument.Value.(expr.Literal); ok {
assert.Equal(t, listType, lit.GetType())
}
assert.Equal(t, "fp64", tc.Result.Type.String())
assert.Equal(t, literal.NewFloat64(2, false), tc.Result.Value)
assert.Equal(t, AggregateFuncType, tc.FuncType)
Expand Down Expand Up @@ -816,8 +824,8 @@ func TestParseTestWithBadAggregateTests(t *testing.T) {
corr(t1.col0, t2.col1) = 1::fp64`,
"table name in argument t2, does not match the table name in the function call t1",
},
{"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(my_col::fp32, col0::fp32) = 1::fp64", "mismatched input '::' expecting ')"},
{"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(col0::fp32, column1::fp32) = 1::fp64", "mismatched input '::' expecting ')"},
{"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(my_col::fp32, col0::fp32) = 1::fp64", "mismatched input 'fp32' expecting 'enum'"},
{"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(col0::fp32, column1::fp32) = 1::fp64", "mismatched input 'fp32' expecting 'enum'"},
{"f8('13:01:01.234'::time) = 123::i32", "expected aggregate testcase based on test file header, but got scalar function testcase"},
}
for _, test := range tests {
Expand Down Expand Up @@ -888,6 +896,8 @@ count(t1.col0) = 4::fp64`, expTestStr: "(('cat'), ('bat'), ('rat'), (null)) coun
expTestStr: "f38(('1990-12-31T19:32:03.456')::precision_timestamp_tz?<3>) = '1990-12-31T08:30:00.000+00:00'::precision_timestamp_tz<3>"},
{testCaseStr: "f39(('1991-01-01T01:02:03.456+05:30', '1991-01-01T01:02:03.123456+05:30')::ptstz<6>) = '1991-01-01T00:00:00+15:30'::ptstz<6>",
expTestStr: "f39(('1990-12-31T19:32:03.456', '1990-12-31T19:32:03.123456')::precision_timestamp_tz<6>) = '1990-12-31T08:30:00.000+00:00'::precision_timestamp_tz<6>"},
{testCaseStr: "((1.0), (2.0), (3.0)) std_dev(SAMPLE::enum, col0::fp32) = 1.0::fp32?",
expTestStr: "((1), (2), (3)) std_dev(SAMPLE::enum, col0::fp32) = 1::fp32?"},
}
for _, test := range tests {
t.Run(test.testCaseStr, func(t *testing.T) {
Expand Down Expand Up @@ -926,6 +936,22 @@ func TestParseTestCaseFile(t *testing.T) {
assert.Len(t, testFile.TestCases, 13)
}

func TestParseEnumArg(t *testing.T) {
header := makeHeader("v1.0", "/extensions/functions_datetime.yaml")
testFile, err := ParseTestCasesFromString(header + "# timestamps\nextract(YEAR::enum, '2016-12-31T13:30:15'::ts) = 2016::i64\n")
require.NoError(t, err)
require.Len(t, testFile.TestCases, 1)

tc := testFile.TestCases[0]
assert.Equal(t, types.CommonEnumType, tc.Args[0].Type)
assert.Equal(t, "extract(YEAR::enum, '2016-12-31T13:30:15'::timestamp) = 2016::i64", tc.String())

reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(extensions.GetDefaultCollectionWithNoError())
invocation, err := tc.GetScalarFunctionInvocation(&reg, funcRegistry)
require.NoError(t, err)
assert.Equal(t, []types.Type{types.CommonEnumType, &types.TimestampType{Nullability: types.NullabilityRequired}}, invocation.GetArgTypes())
}

func TestLoadAllSubstraitTestFiles(t *testing.T) {
got := substrait.GetSubstraitTestsFS()
filePaths, err := listFiles(got, ".")
Expand All @@ -935,9 +961,11 @@ func TestLoadAllSubstraitTestFiles(t *testing.T) {
for _, filePath := range filePaths {
t.Run(filePath, func(t *testing.T) {
switch filePath {
case "tests/cases/datetime/extract.test":
// TODO deal with enum arguments in testcase
t.Skip("Skipping extract.test")
case "tests/cases/arithmetic/std_dev.test",
"tests/cases/arithmetic/variance.test":
// Skipping: upstream test files use an invalid single-column compact format,
// fixed in substrait v0.88.0 (substrait-io/substrait#1043).
t.Skip("Skipping until substrait dependency is updated to v0.88.0+")
case "tests/cases/list/all_match.test",
"tests/cases/list/any_match.test",
"tests/cases/list/filter.test",
Expand Down
8 changes: 8 additions & 0 deletions testcases/parser/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,9 @@ func (v *TestCaseVisitor) VisitArgument(ctx *baseparser.ArgumentContext) interfa
if ctx.ListArg() != nil {
return v.Visit(ctx.ListArg())
}
if ctx.EnumArg() != nil {
return v.Visit(ctx.EnumArg())
}
if ctx.IntervalCompoundArg() != nil {
// TODO(#209): implement when substrait test cases use interval compound args
v.ErrorListener.ReportVisitError(ctx, fmt.Errorf("interval compound argument not yet implemented"))
Expand Down Expand Up @@ -643,6 +646,11 @@ func (v *TestCaseVisitor) VisitDecimalArg(ctx *baseparser.DecimalArgContext) int
return &CaseLiteral{Value: decimal, ValueText: ctx.NumericLiteral().GetText(), Type: decType}
}

func (v *TestCaseVisitor) VisitEnumArg(ctx *baseparser.EnumArgContext) interface{} {
identifier := ctx.Identifier().GetText()
return &CaseLiteral{Value: types.Enum(identifier), ValueText: identifier, Type: types.CommonEnumType}
}

func (v *TestCaseVisitor) VisitPrecisionTimeArg(ctx *baseparser.PrecisionTimeArgContext) interface{} {
ptsType := v.Visit(ctx.PrecisionTimeType()).(*types.PrecisionTimeType)
timestampStr := getRawStringFromStringLiteral(ctx.TimeLiteral().GetText())
Expand Down
2 changes: 1 addition & 1 deletion types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ func (e *EnumType) MatchWithoutNullability(ot Type) bool {
}

func (e *EnumType) ShortString() string {
return "enum"
Comment thread
benbellick marked this conversation as resolved.
return "req"
}

func (e *EnumType) GetNullability() Nullability {
Expand Down
4 changes: 4 additions & 0 deletions types/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,7 @@ func TestStructTypeDepthFirstNameCount(t *testing.T) {
})
}
}

func TestEnumTypeShortString(t *testing.T) {
assert.Equal(t, "req", CommonEnumType.ShortString())
}
Loading