From 2e8f21358ae1048d9fd58f0f46c2ca8997b481ea Mon Sep 17 00:00:00 2001 From: bvolpato Date: Mon, 16 Mar 2026 12:56:21 -0400 Subject: [PATCH 1/3] feat: add DynamicParameter expression and plan bindings support Add DynamicParameter expression type with full Expression interface implementation, ExprBuilder support, and plan-level DynamicParameterBinding for parameterized queries. --- expr/builder.go | 27 +++ expr/dynamic_parameter_internal_test.go | 20 ++ expr/dynamic_parameter_test.go | 274 ++++++++++++++++++++++ expr/expression.go | 57 +++++ plan/builders.go | 23 +- plan/common.go | 13 + plan/dynamic_parameter_test.go | 224 ++++++++++++++++++ plan/plan.go | 42 +++- plan/testdata/dynamic_parameter_plan.json | 55 +++++ 9 files changed, 725 insertions(+), 10 deletions(-) create mode 100644 expr/dynamic_parameter_internal_test.go create mode 100644 expr/dynamic_parameter_test.go create mode 100644 plan/dynamic_parameter_test.go create mode 100644 plan/testdata/dynamic_parameter_plan.json diff --git a/expr/builder.go b/expr/builder.go index 37daca22..4f41c811 100644 --- a/expr/builder.go +++ b/expr/builder.go @@ -180,6 +180,15 @@ func (e *ExprBuilder) Cast(from Builder, to types.Type) *castBuilder { } } +// DynamicParam returns a builder for constructing a DynamicParameter expression. +// The paramRef identifies the parameter binding in the plan. +func (e *ExprBuilder) DynamicParam(outputType types.Type, paramRef uint32) *dynamicParamBuilder { + return &dynamicParamBuilder{ + outputType: outputType, + paramRef: paramRef, + } +} + // Lambda returns a builder for constructing a Lambda expression with the // given parameters. // @@ -301,6 +310,24 @@ func (cb *castBuilder) FailBehavior(b types.CastFailBehavior) *castBuilder { return cb } +type dynamicParamBuilder struct { + outputType types.Type + paramRef uint32 +} + +func (dpb *dynamicParamBuilder) Build() (*DynamicParameter, error) { + if dpb.outputType == nil { + return nil, fmt.Errorf("%w: dynamic parameter must have an output type", substraitgo.ErrInvalidExpr) + } + return &DynamicParameter{ + OutputType: dpb.outputType, + ParameterReference: dpb.paramRef, + }, nil +} + +func (dpb *dynamicParamBuilder) BuildExpr() (Expression, error) { return dpb.Build() } +func (dpb *dynamicParamBuilder) BuildFuncArg() (types.FuncArg, error) { return dpb.Build() } + type scalarFuncBuilder struct { b *ExprBuilder diff --git a/expr/dynamic_parameter_internal_test.go b/expr/dynamic_parameter_internal_test.go new file mode 100644 index 00000000..22ac77c6 --- /dev/null +++ b/expr/dynamic_parameter_internal_test.go @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expr + +import ( + "testing" + + "github.com/substrait-io/substrait-go/v7/types" +) + +func TestDynamicParameterIsRootRef(t *testing.T) { + dp := &DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + // Verify DynamicParameter satisfies RootRefType + var _ RootRefType = dp + dp.isRootRef() +} diff --git a/expr/dynamic_parameter_test.go b/expr/dynamic_parameter_test.go new file mode 100644 index 00000000..030c4a8a --- /dev/null +++ b/expr/dynamic_parameter_test.go @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expr_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/v7/expr" + "github.com/substrait-io/substrait-go/v7/extensions" + "github.com/substrait-io/substrait-go/v7/types" + proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" + pb "google.golang.org/protobuf/proto" +) + +func TestDynamicParameterEquals(t *testing.T) { + dp1 := &expr.DynamicParameter{ + OutputType: &types.Int64Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + dp2 := &expr.DynamicParameter{ + OutputType: &types.Int64Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + dp3 := &expr.DynamicParameter{ + OutputType: &types.Int64Type{Nullability: types.NullabilityRequired}, + ParameterReference: 1, + } + + dp4 := &expr.DynamicParameter{ + OutputType: &types.Float64Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + assert.True(t, dp1.Equals(dp2), "same type and ref should be equal") + assert.False(t, dp1.Equals(dp3), "different ref should not be equal") + assert.False(t, dp1.Equals(dp4), "different type should not be equal") + assert.False(t, dp1.Equals(expr.NewPrimitiveLiteral(int64(42), false)), "different expression type should not be equal") +} + +func TestDynamicParameterVisit(t *testing.T) { + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 5, + } + + // Visit should return the same expression since DynamicParameter has no children + visited := dp.Visit(func(e expr.Expression) expr.Expression { + return e + }) + + assert.Same(t, dp, visited, "Visit should return same pointer for leaf expression") +} + +// TestDynamicParameterToProtoRoundtrip tests construction, interface compliance, +// and proto roundtrip for various DynamicParameter configurations. +// The $N:type String() format (e.g. "$0:i32") is an internal debugging +// representation used by this library; it is not part of the Substrait spec. +func TestDynamicParameterToProtoRoundtrip(t *testing.T) { + tests := []struct { + name string + dp *expr.DynamicParameter + }{ + { + name: "required i32", + dp: &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + }, + }, + { + name: "nullable string", + dp: &expr.DynamicParameter{ + OutputType: &types.StringType{Nullability: types.NullabilityNullable}, + ParameterReference: 1, + }, + }, + { + name: "required fp64", + dp: &expr.DynamicParameter{ + OutputType: &types.Float64Type{Nullability: types.NullabilityRequired}, + ParameterReference: 5, + }, + }, + { + name: "required boolean", + dp: &expr.DynamicParameter{ + OutputType: &types.BooleanType{Nullability: types.NullabilityRequired}, + ParameterReference: 10, + }, + }, + { + name: "nullable i64", + dp: &expr.DynamicParameter{ + OutputType: &types.Int64Type{Nullability: types.NullabilityNullable}, + ParameterReference: 42, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.True(t, tt.dp.IsScalar()) + assert.True(t, tt.dp.GetType().Equals(tt.dp.OutputType)) + + // Proto roundtrip: the plan should equal itself after a roundtrip + protoExpr := tt.dp.ToProto() + require.NotNil(t, protoExpr) + + fromProto, err := expr.ExprFromProto(protoExpr, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) + require.NoError(t, err) + assert.True(t, tt.dp.Equals(fromProto), "roundtrip should produce equal expression") + + protoRoundTrip := fromProto.ToProto() + assert.True(t, pb.Equal(protoExpr, protoRoundTrip), "proto roundtrip should be equal") + }) + } +} + +func TestDynamicParameterToProtoFuncArg(t *testing.T) { + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + arg := dp.ToProtoFuncArg() + require.NotNil(t, arg) + require.NotNil(t, arg.GetValue(), "should be a value argument") + require.NotNil(t, arg.GetValue().GetDynamicParameter(), "value should be a dynamic parameter") +} + +func TestDynamicParameterFromProtoNilDynamicParam(t *testing.T) { + // Test ExprFromProto with a DynamicParameter that has nil inner + protoExpr := &proto.Expression{ + RexType: &proto.Expression_DynamicParameter{ + DynamicParameter: nil, + }, + } + + _, err := expr.ExprFromProto(protoExpr, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) + require.Error(t, err) + assert.Contains(t, err.Error(), "dynamic parameter is nil") +} + +func TestDynamicParameterBuilder(t *testing.T) { + b := expr.ExprBuilder{ + Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()), + } + + tests := []struct { + name string + build func() (expr.Expression, error) + expect string + err string + }{ + { + name: "basic i32", + build: func() (expr.Expression, error) { + return b.DynamicParam(&types.Int32Type{Nullability: types.NullabilityRequired}, 0).BuildExpr() + }, + expect: "$0:i32", + }, + { + name: "nullable string param 3", + build: func() (expr.Expression, error) { + return b.DynamicParam(&types.StringType{Nullability: types.NullabilityNullable}, 3).BuildExpr() + }, + expect: "$3:string?", + }, + { + name: "nil type should error", + build: func() (expr.Expression, error) { + return b.DynamicParam(nil, 0).BuildExpr() + }, + err: "dynamic parameter must have an output type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e, err := tt.build() + if tt.err != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expect, e.String()) + // Verify proto roundtrip + e.ToProto() + } + }) + } +} + +func TestDynamicParameterBuilderAsFuncArg(t *testing.T) { + b := expr.ExprBuilder{ + Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()), + BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct), + } + + // Use DynamicParam as a function argument via the builder + dpBuilder := b.DynamicParam(&types.Int8Type{Nullability: types.NullabilityRequired}, 0) + + // Verify it implements FuncArgBuilder + funcArg, err := dpBuilder.BuildFuncArg() + require.NoError(t, err) + assert.NotNil(t, funcArg) + + dp, ok := funcArg.(*expr.DynamicParameter) + require.True(t, ok) + assert.Equal(t, uint32(0), dp.ParameterReference) + + // Build as a function argument in a scalar function + e, err := b.ScalarFunc(addID).Args( + dpBuilder, + b.Wrap(expr.NewLiteral(int8(5), false)), + ).BuildExpr() + require.NoError(t, err) + assert.Contains(t, e.String(), "$0:i8") +} + +func TestDynamicParameterInProject(t *testing.T) { + // Test using dynamic parameter in a project expression through builders + + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + // Verify it can be used as a project expression + protoExpr := dp.ToProto() + require.NotNil(t, protoExpr) + + // Roundtrip + fromProto, err := expr.ExprFromProto(protoExpr, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) + require.NoError(t, err) + require.IsType(t, &expr.DynamicParameter{}, fromProto) + + roundtripped := fromProto.(*expr.DynamicParameter) + assert.Equal(t, uint32(0), roundtripped.ParameterReference) + assert.True(t, roundtripped.GetType().Equals(&types.Int32Type{Nullability: types.NullabilityRequired})) +} + +func TestDynamicParameterMultipleInExpression(t *testing.T) { + dp0 := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + dp1 := &expr.DynamicParameter{ + OutputType: &types.StringType{Nullability: types.NullabilityNullable}, + ParameterReference: 1, + } + + // Both should work independently + proto0 := dp0.ToProto() + proto1 := dp1.ToProto() + require.NotNil(t, proto0) + require.NotNil(t, proto1) + + from0, err := expr.ExprFromProto(proto0, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) + require.NoError(t, err) + from1, err := expr.ExprFromProto(proto1, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) + require.NoError(t, err) + + // They should not be equal to each other + assert.False(t, from0.Equals(from1)) + // But should be equal to themselves + assert.True(t, from0.Equals(dp0)) + assert.True(t, from1.Equals(dp1)) +} diff --git a/expr/expression.go b/expr/expression.go index a52a03ca..48860fa3 100644 --- a/expr/expression.go +++ b/expr/expression.go @@ -316,6 +316,14 @@ func ExprFromProto(e *proto.Expression, baseSchema *types.RecordType, reg Extens return nil, fmt.Errorf("%w: nested expression: %s", substraitgo.ErrInvalidExpr, n) } + case *proto.Expression_DynamicParameter: + if et.DynamicParameter == nil { + return nil, fmt.Errorf("%w: dynamic parameter is nil", substraitgo.ErrInvalidExpr) + } + return &DynamicParameter{ + OutputType: types.TypeFromProto(et.DynamicParameter.Type), + ParameterReference: et.DynamicParameter.ParameterReference, + }, nil case *proto.Expression_Enum_: return nil, fmt.Errorf("%w: deprecated", substraitgo.ErrNotImplemented) case *proto.Expression_Subquery_: @@ -372,6 +380,7 @@ type VisitFunc func(Expression) Expression // - A Cast expression // - A Subquery // - A Nested expression +// - A Dynamic Parameter type Expression interface { // an Expression can also be a function argument types.FuncArg @@ -654,6 +663,54 @@ func (ex *Cast) Visit(visit VisitFunc) Expression { return &newCast } +type DynamicParameter struct { + OutputType types.Type + ParameterReference uint32 +} + +func (dp *DynamicParameter) String() string { + return fmt.Sprintf("$%d:%s", dp.ParameterReference, dp.OutputType) +} + +func (dp *DynamicParameter) ToProtoFuncArg() *proto.FunctionArgument { + return &proto.FunctionArgument{ + ArgType: &proto.FunctionArgument_Value{ + Value: dp.ToProto(), + }, + } +} + +func (dp *DynamicParameter) isRootRef() {} + +func (dp *DynamicParameter) IsScalar() bool { return true } + +func (dp *DynamicParameter) GetType() types.Type { return dp.OutputType } + +func (dp *DynamicParameter) ToProto() *proto.Expression { + return &proto.Expression{ + RexType: &proto.Expression_DynamicParameter{ + DynamicParameter: &proto.DynamicParameter{ + Type: types.TypeToProto(dp.OutputType), + ParameterReference: dp.ParameterReference, + }, + }, + } +} + +func (dp *DynamicParameter) Equals(other Expression) bool { + rhs, ok := other.(*DynamicParameter) + if !ok { + return false + } + + return dp.ParameterReference == rhs.ParameterReference && + dp.OutputType.Equals(rhs.OutputType) +} + +func (dp *DynamicParameter) Visit(visit VisitFunc) Expression { + return dp +} + type SwitchExpr struct { match Expression ifs []struct { diff --git a/plan/builders.go b/plan/builders.go index edf8af87..1595c7be 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -142,6 +142,10 @@ type Builder interface { // that may be in use with this plan for advanced extensions, optimizations, // and so on. PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs []string, others ...Rel) (*Plan, error) + // PlanWithBindings is the same as PlanWithTypes, but additionally accepts + // dynamic parameter bindings. These bind DynamicParameter expressions in + // the plan tree to concrete literal values at runtime. + PlanWithBindings(root Rel, rootNames []string, expectedTypeURLs []string, bindings []DynamicParameterBinding, others ...Rel) (*Plan, error) // GetExprBuilder returns an expr.ExprBuilder that shares the extension // registry that this Builder uses. @@ -739,6 +743,14 @@ func (b *builder) Set(op SetOp, inputs ...Rel) (*SetRel, error) { } func (b *builder) PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs []string, others ...Rel) (*Plan, error) { + return b.PlanWithBindings(root, rootNames, expectedTypeURLs, nil, others...) +} + +// PlanWithBindings constructs a new plan with the provided root relation, +// expected type URLs, and dynamic parameter bindings. This allows creating +// parameterized plans where expressions contain DynamicParameter references +// that are bound to concrete values. +func (b *builder) PlanWithBindings(root Rel, rootNames []string, expectedTypeURLs []string, bindings []DynamicParameterBinding, others ...Rel) (*Plan, error) { if root == nil { return nil, fmt.Errorf("%w: must provide non-nil root relation for plan", substraitgo.ErrInvalidRel) @@ -758,11 +770,12 @@ func (b *builder) PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs [ } return &Plan{ - version: &CurrentVersion, - extensions: b.extSet, - reg: b.reg, - expectedTypeURLs: expectedTypeURLs, - relations: relations, + version: &CurrentVersion, + extensions: b.extSet, + reg: b.reg, + expectedTypeURLs: expectedTypeURLs, + relations: relations, + parameterBindings: bindings, }, nil } diff --git a/plan/common.go b/plan/common.go index de1cd193..09fddf6b 100644 --- a/plan/common.go +++ b/plan/common.go @@ -3,11 +3,24 @@ package plan import ( + "github.com/substrait-io/substrait-go/v7/expr" "github.com/substrait-io/substrait-go/v7/extensions" "github.com/substrait-io/substrait-go/v7/types" proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" ) +// DynamicParameterBinding maps a parameter anchor to a literal value +// for use with DynamicParameter expressions in a plan. +// +// NOTE: this library does not currently validate that the type of the +// literal Value matches the OutputType declared on the corresponding +// DynamicParameter expression. Consumers should perform their own +// type-checking if needed. +type DynamicParameterBinding struct { + ParameterAnchor uint32 + Value expr.Literal +} + type ( Hint = proto.RelCommon_Hint Stats = proto.RelCommon_Hint_Stats diff --git a/plan/dynamic_parameter_test.go b/plan/dynamic_parameter_test.go new file mode 100644 index 00000000..1b85ebce --- /dev/null +++ b/plan/dynamic_parameter_test.go @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: Apache-2.0 + +package plan_test + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/v7/expr" + "github.com/substrait-io/substrait-go/v7/extensions" + "github.com/substrait-io/substrait-go/v7/plan" + "github.com/substrait-io/substrait-go/v7/types" + substraitproto "github.com/substrait-io/substrait-protobuf/go/substraitpb" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +func TestDynamicParameterInFilterPlan(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + ref, err := b.RootFieldRef(scan, 0) + require.NoError(t, err) + + gt, err := b.ScalarFn(extensions.SubstraitDefaultURNPrefix+"functions_comparison", "gt", nil, ref, dp) + require.NoError(t, err) + + filter, err := b.Filter(scan, gt) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral(int32(42), false), + }, + } + + p, err := b.PlanWithBindings(filter, []string{"x", "y"}, nil, bindings) + require.NoError(t, err) + + protoPlan, err := p.ToProto() + require.NoError(t, err) + + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) + require.NoError(t, err) + + roundTripProto, err := roundTrip.ToProto() + require.NoError(t, err) + + assert.Truef(t, proto.Equal(protoPlan, roundTripProto), "plan expected: %s\ngot: %s", + protojson.Format(protoPlan), protojson.Format(roundTripProto)) +} + +func TestDynamicParameterInProjectPlan(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp := &expr.DynamicParameter{ + OutputType: &types.StringType{Nullability: types.NullabilityNullable}, + ParameterReference: 0, + } + + project, err := b.Project(scan, dp) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral("hello", false), + }, + } + + p, err := b.PlanWithBindings(project, []string{"x", "y", "param_val"}, nil, bindings) + require.NoError(t, err) + + protoPlan, err := p.ToProto() + require.NoError(t, err) + + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) + require.NoError(t, err) + + roundTripProto, err := roundTrip.ToProto() + require.NoError(t, err) + assert.True(t, proto.Equal(protoPlan, roundTripProto)) +} + +func TestDynamicParameterMultipleBindings(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp0 := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + dp1 := &expr.DynamicParameter{ + OutputType: &types.StringType{Nullability: types.NullabilityNullable}, + ParameterReference: 1, + } + + project, err := b.Project(scan, dp0, dp1) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral(int32(100), false), + }, + { + ParameterAnchor: 1, + Value: expr.NewPrimitiveLiteral("world", true), + }, + } + + p, err := b.PlanWithBindings(project, []string{"x", "y", "p0", "p1"}, nil, bindings) + require.NoError(t, err) + + protoPlan, err := p.ToProto() + require.NoError(t, err) + + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) + require.NoError(t, err) + + roundTripProto, err := roundTrip.ToProto() + require.NoError(t, err) + assert.True(t, proto.Equal(protoPlan, roundTripProto)) +} + +func TestDynamicParameterPlanWithoutBindings(t *testing.T) { + // Plan with dynamic parameters but no bindings (valid use case) + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + project, err := b.Project(scan, dp) + require.NoError(t, err) + + // Use regular Plan (no bindings) + p, err := b.Plan(project, []string{"x", "y", "param"}) + require.NoError(t, err) + + // Verify no bindings + assert.Empty(t, p.ParameterBindings()) + + // Proto roundtrip should still work + protoPlan, err := p.ToProto() + require.NoError(t, err) + assert.Empty(t, protoPlan.ParameterBindings) + + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) + require.NoError(t, err) + assert.Empty(t, roundTrip.ParameterBindings()) +} + +func TestDynamicParameterFromProtoJSON(t *testing.T) { + planJSON, err := os.ReadFile("testdata/dynamic_parameter_plan.json") + require.NoError(t, err) + + var protoPlan substraitproto.Plan + require.NoError(t, protojson.Unmarshal(planJSON, &protoPlan)) + + p, err := plan.FromProto(&protoPlan, extensions.GetDefaultCollectionWithNoError()) + require.NoError(t, err) + + backToProto, err := p.ToProto() + require.NoError(t, err) + assert.Truef(t, proto.Equal(&protoPlan, backToProto), "expected: %s\ngot: %s", + protojson.Format(&protoPlan), protojson.Format(backToProto)) +} + +func TestDynamicParameterBuilderInPlanBuilder(t *testing.T) { + b := plan.NewBuilderDefault() + eb := b.GetExprBuilder() + + scan := b.NamedScan([]string{"employees"}, types.NamedStruct{ + Names: []string{"id", "salary"}, + Struct: types.StructType{ + Nullability: types.NullabilityRequired, + Types: []types.Type{ + &types.Int64Type{Nullability: types.NullabilityRequired}, + &types.Float64Type{Nullability: types.NullabilityRequired}, + }, + }, + }) + + dpExpr, err := eb.DynamicParam( + &types.Float64Type{Nullability: types.NullabilityRequired}, 0, + ).BuildExpr() + require.NoError(t, err) + + project, err := b.Project(scan, dpExpr) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral(float64(50000.0), false), + }, + } + + p, err := b.PlanWithBindings(project, []string{"id", "salary", "threshold"}, nil, bindings) + require.NoError(t, err) + + protoPlan, err := p.ToProto() + require.NoError(t, err) + + roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) + require.NoError(t, err) + + roundTripProto, err := roundTrip.ToProto() + require.NoError(t, err) + assert.True(t, proto.Equal(protoPlan, roundTripProto)) +} diff --git a/plan/plan.go b/plan/plan.go index b7f26fff..9decc41d 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -152,11 +152,12 @@ type AdvancedExtension interface { // Plan describes a set of operations to complete. For // compactness, identifiers are normalized at the plan level. type Plan struct { - version *types.Version - extensions extensions.Set - expectedTypeURLs []string - advExtension *extensions.AdvancedExtension - relations []Relation + version *types.Version + extensions extensions.Set + expectedTypeURLs []string + advExtension *extensions.AdvancedExtension + relations []Relation + parameterBindings []DynamicParameterBinding reg expr.ExtensionRegistry } @@ -216,6 +217,14 @@ func (p *Plan) GetNonRootRelations() (rels []Rel) { return rels } +// ParameterBindings returns the list of dynamic parameter bindings for this plan. +// Each binding maps a parameter anchor to a runtime literal value. +// +// This returns a clone of the internal slice so that the plan itself remains immutable. +func (p *Plan) ParameterBindings() []DynamicParameterBinding { + return slices.Clone(p.parameterBindings) +} + func FromProto(plan *proto.Plan, c *extensions.Collection) (*Plan, error) { extSet, err := extensions.GetExtensionSet(plan, c) if err != nil { @@ -236,6 +245,16 @@ func FromProto(plan *proto.Plan, c *extensions.Collection) (*Plan, error) { } } + if len(plan.ParameterBindings) > 0 { + ret.parameterBindings = make([]DynamicParameterBinding, len(plan.ParameterBindings)) + for i, pb := range plan.ParameterBindings { + ret.parameterBindings[i] = DynamicParameterBinding{ + ParameterAnchor: pb.ParameterAnchor, + Value: expr.LiteralFromProto(pb.Value), + } + } + } + return ret, nil } @@ -245,6 +264,18 @@ func (p *Plan) ToProto() (*proto.Plan, error) { for i, r := range p.relations { relations[i] = r.ToProto() } + + var bindings []*proto.DynamicParameterBinding + if len(p.parameterBindings) > 0 { + bindings = make([]*proto.DynamicParameterBinding, len(p.parameterBindings)) + for i, b := range p.parameterBindings { + bindings[i] = &proto.DynamicParameterBinding{ + ParameterAnchor: b.ParameterAnchor, + Value: b.Value.ToProtoLiteral(), + } + } + } + return &proto.Plan{ Version: p.version, ExpectedTypeUrls: p.expectedTypeURLs, @@ -252,6 +283,7 @@ func (p *Plan) ToProto() (*proto.Plan, error) { Relations: relations, Extensions: decls, ExtensionUrns: urns, + ParameterBindings: bindings, }, nil } diff --git a/plan/testdata/dynamic_parameter_plan.json b/plan/testdata/dynamic_parameter_plan.json new file mode 100644 index 00000000..ff62f98b --- /dev/null +++ b/plan/testdata/dynamic_parameter_plan.json @@ -0,0 +1,55 @@ +{ + "version": {"majorNumber": 0, "minorNumber": 29, "patchNumber": 0}, + "relations": [ + { + "root": { + "input": { + "project": { + "common": {"direct": {}}, + "input": { + "read": { + "common": {"direct": {}}, + "baseSchema": { + "names": ["id", "name"], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + {"i64": {"nullability": "NULLABILITY_REQUIRED"}}, + {"string": {"nullability": "NULLABILITY_NULLABLE"}} + ] + } + }, + "namedTable": {"names": ["users"]} + } + }, + "expressions": [ + { + "dynamicParameter": { + "type": {"i32": {"nullability": "NULLABILITY_REQUIRED"}}, + "parameterReference": 0 + } + }, + { + "dynamicParameter": { + "type": {"string": {"nullability": "NULLABILITY_NULLABLE"}}, + "parameterReference": 1 + } + } + ] + } + }, + "names": ["id", "name", "param0", "param1"] + } + } + ], + "parameterBindings": [ + { + "parameterAnchor": 0, + "value": {"i32": 99} + }, + { + "parameterAnchor": 1, + "value": {"string": "test_value"} + } + ] +} From 88e37e8d245262019403670e9ab7ea4d868b387e Mon Sep 17 00:00:00 2001 From: bvolpato Date: Sat, 21 Mar 2026 19:36:15 -0400 Subject: [PATCH 2/3] feat: add type validation for DynamicParameter bindings Validate that each DynamicParameterBinding's literal type matches the OutputType declared on the corresponding DynamicParameter expression in the plan tree. Type comparison ignores nullability. Validation runs automatically in PlanWithBindings and is also available as the exported ValidateParameterBindings function. New tests cover type mismatches, missing anchors, nullability tolerance, and validation through filter conditions. --- plan/builders.go | 4 + plan/common.go | 104 ++++++++++++++++++++++- plan/dynamic_parameter_test.go | 147 +++++++++++++++++++++++++++++++++ 3 files changed, 251 insertions(+), 4 deletions(-) diff --git a/plan/builders.go b/plan/builders.go index 1595c7be..66388f5e 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -760,6 +760,10 @@ func (b *builder) PlanWithBindings(root Rel, rootNames []string, expectedTypeURL return nil, err } + if err := ValidateParameterBindings(root, bindings); err != nil { + return nil, err + } + relations := make([]Relation, len(others)+1) relations[0].root = &Root{ input: root, names: rootNames, diff --git a/plan/common.go b/plan/common.go index 09fddf6b..d6d117d8 100644 --- a/plan/common.go +++ b/plan/common.go @@ -3,6 +3,9 @@ package plan import ( + "fmt" + + substraitgo "github.com/substrait-io/substrait-go/v7" "github.com/substrait-io/substrait-go/v7/expr" "github.com/substrait-io/substrait-go/v7/extensions" "github.com/substrait-io/substrait-go/v7/types" @@ -12,15 +15,108 @@ import ( // DynamicParameterBinding maps a parameter anchor to a literal value // for use with DynamicParameter expressions in a plan. // -// NOTE: this library does not currently validate that the type of the -// literal Value matches the OutputType declared on the corresponding -// DynamicParameter expression. Consumers should perform their own -// type-checking if needed. +// When bindings are provided via PlanWithBindings, the builder validates +// that each binding's literal type matches (ignoring nullability) the +// OutputType declared on the corresponding DynamicParameter expression. type DynamicParameterBinding struct { ParameterAnchor uint32 Value expr.Literal } +// ValidateParameterBindings checks that every binding's literal type matches +// the OutputType of the corresponding DynamicParameter expression found in +// the relation tree. Type comparison ignores nullability so that a required +// parameter can be bound to a nullable literal and vice-versa. +// +// Returns an error for: +// - A binding whose anchor does not correspond to any DynamicParameter in the tree. +// - A binding whose value type does not match the parameter's declared type +// (ignoring nullability). +func ValidateParameterBindings(root Rel, bindings []DynamicParameterBinding) error { + if len(bindings) == 0 { + return nil + } + + // Collect all DynamicParameter output types keyed by anchor. + paramTypes := make(map[uint32]types.Type) + collectDynamicParams(root, paramTypes) + + for _, b := range bindings { + declaredType, ok := paramTypes[b.ParameterAnchor] + if !ok { + return fmt.Errorf("%w: parameter binding references anchor %d, "+ + "but no DynamicParameter with that reference exists in the plan", + substraitgo.ErrInvalidPlan, b.ParameterAnchor) + } + + // Compare ignoring nullability. + bindingType := b.Value.GetType().WithNullability(types.NullabilityUnspecified) + expectedType := declaredType.WithNullability(types.NullabilityUnspecified) + if !bindingType.Equals(expectedType) { + return fmt.Errorf("%w: parameter binding for anchor %d has type %s, "+ + "but DynamicParameter declares type %s", + substraitgo.ErrInvalidPlan, b.ParameterAnchor, b.Value.GetType(), declaredType) + } + } + + return nil +} + +// collectDynamicParams walks a relation tree and records the OutputType +// of every DynamicParameter expression it encounters. +func collectDynamicParams(rel Rel, out map[uint32]types.Type) { + if rel == nil { + return + } + + // Walk child relations first. + for _, child := range rel.GetInputs() { + collectDynamicParams(child, out) + } + + // Walk expressions owned by this relation. + walkRelExprs(rel, func(e expr.Expression) { + walkExpr(e, func(inner expr.Expression) { + if dp, ok := inner.(*expr.DynamicParameter); ok { + out[dp.ParameterReference] = dp.OutputType + } + }) + }) +} + +// walkRelExprs invokes fn for every top-level expression in a relation. +func walkRelExprs(rel Rel, fn func(expr.Expression)) { + switch r := rel.(type) { + case *FilterRel: + fn(r.Condition()) + case *ProjectRel: + for _, e := range r.Expressions() { + fn(e) + } + case *JoinRel: + fn(r.Expr()) + if pjf := r.PostJoinFilter(); pjf != nil { + fn(pjf) + } + case *SortRel: + for _, sf := range r.Sorts() { + fn(sf.Expr) + } + } +} + +// walkExpr recursively visits every node in an expression tree. +func walkExpr(e expr.Expression, fn func(expr.Expression)) { + if e == nil { + return + } + fn(e) + e.Visit(func(child expr.Expression) expr.Expression { + walkExpr(child, fn) + return child + }) +} + type ( Hint = proto.RelCommon_Hint Stats = proto.RelCommon_Hint_Stats diff --git a/plan/dynamic_parameter_test.go b/plan/dynamic_parameter_test.go index 1b85ebce..599813cf 100644 --- a/plan/dynamic_parameter_test.go +++ b/plan/dynamic_parameter_test.go @@ -222,3 +222,150 @@ func TestDynamicParameterBuilderInPlanBuilder(t *testing.T) { require.NoError(t, err) assert.True(t, proto.Equal(protoPlan, roundTripProto)) } + +func TestDynamicParameterBindingTypeMismatch(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + tests := []struct { + name string + dpType types.Type + bindValue expr.Literal + errMsg string + }{ + { + name: "i32 param bound to string literal", + dpType: &types.Int32Type{Nullability: types.NullabilityRequired}, + bindValue: expr.NewPrimitiveLiteral("hello", false), + errMsg: "parameter binding for anchor 0 has type", + }, + { + name: "string param bound to i32 literal", + dpType: &types.StringType{Nullability: types.NullabilityNullable}, + bindValue: expr.NewPrimitiveLiteral(int32(42), false), + errMsg: "parameter binding for anchor 0 has type", + }, + { + name: "fp64 param bound to i64 literal", + dpType: &types.Float64Type{Nullability: types.NullabilityRequired}, + bindValue: expr.NewPrimitiveLiteral(int64(100), false), + errMsg: "parameter binding for anchor 0 has type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dp := &expr.DynamicParameter{ + OutputType: tt.dpType, + ParameterReference: 0, + } + + project, err := b.Project(scan, dp) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: tt.bindValue, + }, + } + + _, err = b.PlanWithBindings(project, []string{"x", "y", "p"}, nil, bindings) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + }) + } +} + +func TestDynamicParameterBindingMissingAnchor(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + // Plan has dp with anchor 0, but binding references anchor 99 + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + project, err := b.Project(scan, dp) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 99, + Value: expr.NewPrimitiveLiteral(int32(42), false), + }, + } + + _, err = b.PlanWithBindings(project, []string{"x", "y", "p"}, nil, bindings) + require.Error(t, err) + assert.Contains(t, err.Error(), "no DynamicParameter with that reference exists") +} + +func TestDynamicParameterBindingNullabilityMismatch(t *testing.T) { + // Nullability differences should be allowed — a required param can be + // bound to a nullable literal and vice-versa. + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + project, err := b.Project(scan, dp) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral(int32(42), true), // nullable literal + }, + } + + p, err := b.PlanWithBindings(project, []string{"x", "y", "p"}, nil, bindings) + require.NoError(t, err) + assert.NotNil(t, p) +} + +func TestDynamicParameterBindingInFilter(t *testing.T) { + // Validate that type validation also works through filter conditions + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + ref, err := b.RootFieldRef(scan, 0) + require.NoError(t, err) + + gt, err := b.ScalarFn(extensions.SubstraitDefaultURNPrefix+"functions_comparison", "gt", nil, ref, dp) + require.NoError(t, err) + + filter, err := b.Filter(scan, gt) + require.NoError(t, err) + + // Wrong type binding should fail + wrongBindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral("not-a-number", false), + }, + } + _, err = b.PlanWithBindings(filter, []string{"x", "y"}, nil, wrongBindings) + require.Error(t, err) + assert.Contains(t, err.Error(), "parameter binding for anchor 0 has type") + + // Correct type binding should succeed + goodBindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral(int32(42), false), + }, + } + p, err := b.PlanWithBindings(filter, []string{"x", "y"}, nil, goodBindings) + require.NoError(t, err) + assert.NotNil(t, p) +} From db9f3f74b157aedcc47e41e650d0349e520d8234 Mon Sep 17 00:00:00 2001 From: bvolpato Date: Sat, 21 Mar 2026 19:39:51 -0400 Subject: [PATCH 3/3] refactor: simplify DynamicParameter tests per reviewer feedback Expr tests: - Convert Equals test to table-driven subtests - Extract nil-type builder test into its own function - Remove redundant FuncArg, InProject, and MultipleInExpression tests (covered by the roundtrip and builder-as-func-arg tests) - Add TestDynamicParameterTypeMismatchInFunction Plan tests: - Consolidate JSON roundtrip tests into a single loop over testdata files - Add dynamic_parameter_filter.json testdata (generated from builder) - Remove redundant programmatic builder tests that duplicated the JSON roundtrip coverage --- expr/dynamic_parameter_test.go | 237 ++++++-------------- plan/dynamic_parameter_test.go | 202 ++--------------- plan/testdata/dynamic_parameter_filter.json | 111 +++++++++ 3 files changed, 194 insertions(+), 356 deletions(-) create mode 100644 plan/testdata/dynamic_parameter_filter.json diff --git a/expr/dynamic_parameter_test.go b/expr/dynamic_parameter_test.go index 030c4a8a..2e0c5d84 100644 --- a/expr/dynamic_parameter_test.go +++ b/expr/dynamic_parameter_test.go @@ -15,30 +15,27 @@ import ( ) func TestDynamicParameterEquals(t *testing.T) { - dp1 := &expr.DynamicParameter{ - OutputType: &types.Int64Type{Nullability: types.NullabilityRequired}, - ParameterReference: 0, - } + i64Req := &types.Int64Type{Nullability: types.NullabilityRequired} + fp64Req := &types.Float64Type{Nullability: types.NullabilityRequired} - dp2 := &expr.DynamicParameter{ - OutputType: &types.Int64Type{Nullability: types.NullabilityRequired}, - ParameterReference: 0, - } + base := &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 0} - dp3 := &expr.DynamicParameter{ - OutputType: &types.Int64Type{Nullability: types.NullabilityRequired}, - ParameterReference: 1, + tests := []struct { + name string + other expr.Expression + want bool + }{ + {"same type and ref", &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 0}, true}, + {"different ref", &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 1}, false}, + {"different type", &expr.DynamicParameter{OutputType: fp64Req, ParameterReference: 0}, false}, + {"different expression kind", expr.NewPrimitiveLiteral(int64(42), false), false}, } - dp4 := &expr.DynamicParameter{ - OutputType: &types.Float64Type{Nullability: types.NullabilityRequired}, - ParameterReference: 0, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, base.Equals(tt.other)) + }) } - - assert.True(t, dp1.Equals(dp2), "same type and ref should be equal") - assert.False(t, dp1.Equals(dp3), "different ref should not be equal") - assert.False(t, dp1.Equals(dp4), "different type should not be equal") - assert.False(t, dp1.Equals(expr.NewPrimitiveLiteral(int64(42), false)), "different expression type should not be equal") } func TestDynamicParameterVisit(t *testing.T) { @@ -47,11 +44,7 @@ func TestDynamicParameterVisit(t *testing.T) { ParameterReference: 5, } - // Visit should return the same expression since DynamicParameter has no children - visited := dp.Visit(func(e expr.Expression) expr.Expression { - return e - }) - + visited := dp.Visit(func(e expr.Expression) expr.Expression { return e }) assert.Same(t, dp, visited, "Visit should return same pointer for leaf expression") } @@ -64,53 +57,29 @@ func TestDynamicParameterToProtoRoundtrip(t *testing.T) { name string dp *expr.DynamicParameter }{ - { - name: "required i32", - dp: &expr.DynamicParameter{ - OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, - ParameterReference: 0, - }, - }, - { - name: "nullable string", - dp: &expr.DynamicParameter{ - OutputType: &types.StringType{Nullability: types.NullabilityNullable}, - ParameterReference: 1, - }, - }, - { - name: "required fp64", - dp: &expr.DynamicParameter{ - OutputType: &types.Float64Type{Nullability: types.NullabilityRequired}, - ParameterReference: 5, - }, - }, - { - name: "required boolean", - dp: &expr.DynamicParameter{ - OutputType: &types.BooleanType{Nullability: types.NullabilityRequired}, - ParameterReference: 10, - }, - }, - { - name: "nullable i64", - dp: &expr.DynamicParameter{ - OutputType: &types.Int64Type{Nullability: types.NullabilityNullable}, - ParameterReference: 42, - }, - }, + {"required i32", &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, ParameterReference: 0}}, + {"nullable string", &expr.DynamicParameter{ + OutputType: &types.StringType{Nullability: types.NullabilityNullable}, ParameterReference: 1}}, + {"required fp64", &expr.DynamicParameter{ + OutputType: &types.Float64Type{Nullability: types.NullabilityRequired}, ParameterReference: 5}}, + {"required boolean", &expr.DynamicParameter{ + OutputType: &types.BooleanType{Nullability: types.NullabilityRequired}, ParameterReference: 10}}, + {"nullable i64", &expr.DynamicParameter{ + OutputType: &types.Int64Type{Nullability: types.NullabilityNullable}, ParameterReference: 42}}, } + reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert.True(t, tt.dp.IsScalar()) assert.True(t, tt.dp.GetType().Equals(tt.dp.OutputType)) - // Proto roundtrip: the plan should equal itself after a roundtrip protoExpr := tt.dp.ToProto() require.NotNil(t, protoExpr) - fromProto, err := expr.ExprFromProto(protoExpr, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) + fromProto, err := expr.ExprFromProto(protoExpr, nil, reg) require.NoError(t, err) assert.True(t, tt.dp.Equals(fromProto), "roundtrip should produce equal expression") @@ -120,20 +89,7 @@ func TestDynamicParameterToProtoRoundtrip(t *testing.T) { } } -func TestDynamicParameterToProtoFuncArg(t *testing.T) { - dp := &expr.DynamicParameter{ - OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, - ParameterReference: 0, - } - - arg := dp.ToProtoFuncArg() - require.NotNil(t, arg) - require.NotNil(t, arg.GetValue(), "should be a value argument") - require.NotNil(t, arg.GetValue().GetDynamicParameter(), "value should be a dynamic parameter") -} - func TestDynamicParameterFromProtoNilDynamicParam(t *testing.T) { - // Test ExprFromProto with a DynamicParameter that has nil inner protoExpr := &proto.Expression{ RexType: &proto.Expression_DynamicParameter{ DynamicParameter: nil, @@ -145,54 +101,14 @@ func TestDynamicParameterFromProtoNilDynamicParam(t *testing.T) { assert.Contains(t, err.Error(), "dynamic parameter is nil") } -func TestDynamicParameterBuilder(t *testing.T) { +func TestDynamicParameterBuilderNilType(t *testing.T) { b := expr.ExprBuilder{ Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()), } - tests := []struct { - name string - build func() (expr.Expression, error) - expect string - err string - }{ - { - name: "basic i32", - build: func() (expr.Expression, error) { - return b.DynamicParam(&types.Int32Type{Nullability: types.NullabilityRequired}, 0).BuildExpr() - }, - expect: "$0:i32", - }, - { - name: "nullable string param 3", - build: func() (expr.Expression, error) { - return b.DynamicParam(&types.StringType{Nullability: types.NullabilityNullable}, 3).BuildExpr() - }, - expect: "$3:string?", - }, - { - name: "nil type should error", - build: func() (expr.Expression, error) { - return b.DynamicParam(nil, 0).BuildExpr() - }, - err: "dynamic parameter must have an output type", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - e, err := tt.build() - if tt.err != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.err) - } else { - require.NoError(t, err) - assert.Equal(t, tt.expect, e.String()) - // Verify proto roundtrip - e.ToProto() - } - }) - } + _, err := b.DynamicParam(nil, 0).BuildExpr() + require.Error(t, err) + assert.Contains(t, err.Error(), "dynamic parameter must have an output type") } func TestDynamicParameterBuilderAsFuncArg(t *testing.T) { @@ -201,19 +117,8 @@ func TestDynamicParameterBuilderAsFuncArg(t *testing.T) { BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct), } - // Use DynamicParam as a function argument via the builder dpBuilder := b.DynamicParam(&types.Int8Type{Nullability: types.NullabilityRequired}, 0) - // Verify it implements FuncArgBuilder - funcArg, err := dpBuilder.BuildFuncArg() - require.NoError(t, err) - assert.NotNil(t, funcArg) - - dp, ok := funcArg.(*expr.DynamicParameter) - require.True(t, ok) - assert.Equal(t, uint32(0), dp.ParameterReference) - - // Build as a function argument in a scalar function e, err := b.ScalarFunc(addID).Args( dpBuilder, b.Wrap(expr.NewLiteral(int8(5), false)), @@ -222,53 +127,39 @@ func TestDynamicParameterBuilderAsFuncArg(t *testing.T) { assert.Contains(t, e.String(), "$0:i8") } -func TestDynamicParameterInProject(t *testing.T) { - // Test using dynamic parameter in a project expression through builders - - dp := &expr.DynamicParameter{ - OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, - ParameterReference: 0, +func TestDynamicParameterTypeMismatchInFunction(t *testing.T) { + b := expr.ExprBuilder{ + Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()), + BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct), } - // Verify it can be used as a project expression - protoExpr := dp.ToProto() - require.NotNil(t, protoExpr) - - // Roundtrip - fromProto, err := expr.ExprFromProto(protoExpr, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) - require.NoError(t, err) - require.IsType(t, &expr.DynamicParameter{}, fromProto) - - roundtripped := fromProto.(*expr.DynamicParameter) - assert.Equal(t, uint32(0), roundtripped.ParameterReference) - assert.True(t, roundtripped.GetType().Equals(&types.Int32Type{Nullability: types.NullabilityRequired})) -} - -func TestDynamicParameterMultipleInExpression(t *testing.T) { - dp0 := &expr.DynamicParameter{ - OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, - ParameterReference: 0, + tests := []struct { + name string + funcID extensions.ID + dpType types.Type + lit func() (expr.Literal, error) + }{ + { + name: "i32 where i8 expected", + funcID: extensions.ID{URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic", Name: "add:i8_i8"}, + dpType: &types.Int32Type{Nullability: types.NullabilityRequired}, + lit: func() (expr.Literal, error) { return expr.NewLiteral(int8(5), false) }, + }, + { + name: "string where numeric expected", + funcID: addID, + dpType: &types.StringType{Nullability: types.NullabilityRequired}, + lit: func() (expr.Literal, error) { return expr.NewLiteral(int32(5), false) }, + }, } - dp1 := &expr.DynamicParameter{ - OutputType: &types.StringType{Nullability: types.NullabilityNullable}, - ParameterReference: 1, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := b.ScalarFunc(tt.funcID).Args( + b.DynamicParam(tt.dpType, 0), + b.Wrap(tt.lit()), + ).BuildExpr() + require.Error(t, err) + }) } - - // Both should work independently - proto0 := dp0.ToProto() - proto1 := dp1.ToProto() - require.NotNil(t, proto0) - require.NotNil(t, proto1) - - from0, err := expr.ExprFromProto(proto0, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) - require.NoError(t, err) - from1, err := expr.ExprFromProto(proto1, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) - require.NoError(t, err) - - // They should not be equal to each other - assert.False(t, from0.Equals(from1)) - // But should be equal to themselves - assert.True(t, from0.Equals(dp0)) - assert.True(t, from1.Equals(dp1)) } diff --git a/plan/dynamic_parameter_test.go b/plan/dynamic_parameter_test.go index 599813cf..440013ec 100644 --- a/plan/dynamic_parameter_test.go +++ b/plan/dynamic_parameter_test.go @@ -3,7 +3,7 @@ package plan_test import ( - "os" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -17,124 +17,31 @@ import ( "google.golang.org/protobuf/proto" ) -func TestDynamicParameterInFilterPlan(t *testing.T) { - b := plan.NewBuilderDefault() - scan := b.NamedScan([]string{"test"}, baseSchema2) - - dp := &expr.DynamicParameter{ - OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, - ParameterReference: 0, - } - - ref, err := b.RootFieldRef(scan, 0) - require.NoError(t, err) - - gt, err := b.ScalarFn(extensions.SubstraitDefaultURNPrefix+"functions_comparison", "gt", nil, ref, dp) - require.NoError(t, err) - - filter, err := b.Filter(scan, gt) - require.NoError(t, err) - - bindings := []plan.DynamicParameterBinding{ - { - ParameterAnchor: 0, - Value: expr.NewPrimitiveLiteral(int32(42), false), - }, - } - - p, err := b.PlanWithBindings(filter, []string{"x", "y"}, nil, bindings) - require.NoError(t, err) - - protoPlan, err := p.ToProto() - require.NoError(t, err) - - roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) - require.NoError(t, err) - - roundTripProto, err := roundTrip.ToProto() - require.NoError(t, err) - - assert.Truef(t, proto.Equal(protoPlan, roundTripProto), "plan expected: %s\ngot: %s", - protojson.Format(protoPlan), protojson.Format(roundTripProto)) -} - -func TestDynamicParameterInProjectPlan(t *testing.T) { - b := plan.NewBuilderDefault() - scan := b.NamedScan([]string{"test"}, baseSchema2) - - dp := &expr.DynamicParameter{ - OutputType: &types.StringType{Nullability: types.NullabilityNullable}, - ParameterReference: 0, - } - - project, err := b.Project(scan, dp) - require.NoError(t, err) - - bindings := []plan.DynamicParameterBinding{ - { - ParameterAnchor: 0, - Value: expr.NewPrimitiveLiteral("hello", false), - }, - } - - p, err := b.PlanWithBindings(project, []string{"x", "y", "param_val"}, nil, bindings) - require.NoError(t, err) - - protoPlan, err := p.ToProto() - require.NoError(t, err) - - roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) - require.NoError(t, err) - - roundTripProto, err := roundTrip.ToProto() - require.NoError(t, err) - assert.True(t, proto.Equal(protoPlan, roundTripProto)) -} - -func TestDynamicParameterMultipleBindings(t *testing.T) { - b := plan.NewBuilderDefault() - scan := b.NamedScan([]string{"test"}, baseSchema2) - - dp0 := &expr.DynamicParameter{ - OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, - ParameterReference: 0, - } +func TestDynamicParameterPlanRoundtrip(t *testing.T) { + for _, name := range []string{ + "dynamic_parameter_plan", + "dynamic_parameter_filter", + } { + t.Run(name, func(t *testing.T) { + planJSON, err := testdata.ReadFile(fmt.Sprintf("testdata/%s.json", name)) + require.NoError(t, err) - dp1 := &expr.DynamicParameter{ - OutputType: &types.StringType{Nullability: types.NullabilityNullable}, - ParameterReference: 1, - } + var protoPlan substraitproto.Plan + require.NoError(t, protojson.Unmarshal(planJSON, &protoPlan)) - project, err := b.Project(scan, dp0, dp1) - require.NoError(t, err) + p, err := plan.FromProto(&protoPlan, extensions.GetDefaultCollectionWithNoError()) + require.NoError(t, err) - bindings := []plan.DynamicParameterBinding{ - { - ParameterAnchor: 0, - Value: expr.NewPrimitiveLiteral(int32(100), false), - }, - { - ParameterAnchor: 1, - Value: expr.NewPrimitiveLiteral("world", true), - }, + backToProto, err := p.ToProto() + require.NoError(t, err) + assert.Truef(t, proto.Equal(&protoPlan, backToProto), + "expected: %s\ngot: %s", + protojson.Format(&protoPlan), protojson.Format(backToProto)) + }) } - - p, err := b.PlanWithBindings(project, []string{"x", "y", "p0", "p1"}, nil, bindings) - require.NoError(t, err) - - protoPlan, err := p.ToProto() - require.NoError(t, err) - - roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) - require.NoError(t, err) - - roundTripProto, err := roundTrip.ToProto() - require.NoError(t, err) - assert.True(t, proto.Equal(protoPlan, roundTripProto)) } func TestDynamicParameterPlanWithoutBindings(t *testing.T) { - // Plan with dynamic parameters but no bindings (valid use case) b := plan.NewBuilderDefault() scan := b.NamedScan([]string{"test"}, baseSchema2) @@ -146,81 +53,14 @@ func TestDynamicParameterPlanWithoutBindings(t *testing.T) { project, err := b.Project(scan, dp) require.NoError(t, err) - // Use regular Plan (no bindings) p, err := b.Plan(project, []string{"x", "y", "param"}) require.NoError(t, err) - // Verify no bindings assert.Empty(t, p.ParameterBindings()) - // Proto roundtrip should still work protoPlan, err := p.ToProto() require.NoError(t, err) assert.Empty(t, protoPlan.ParameterBindings) - - roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) - require.NoError(t, err) - assert.Empty(t, roundTrip.ParameterBindings()) -} - -func TestDynamicParameterFromProtoJSON(t *testing.T) { - planJSON, err := os.ReadFile("testdata/dynamic_parameter_plan.json") - require.NoError(t, err) - - var protoPlan substraitproto.Plan - require.NoError(t, protojson.Unmarshal(planJSON, &protoPlan)) - - p, err := plan.FromProto(&protoPlan, extensions.GetDefaultCollectionWithNoError()) - require.NoError(t, err) - - backToProto, err := p.ToProto() - require.NoError(t, err) - assert.Truef(t, proto.Equal(&protoPlan, backToProto), "expected: %s\ngot: %s", - protojson.Format(&protoPlan), protojson.Format(backToProto)) -} - -func TestDynamicParameterBuilderInPlanBuilder(t *testing.T) { - b := plan.NewBuilderDefault() - eb := b.GetExprBuilder() - - scan := b.NamedScan([]string{"employees"}, types.NamedStruct{ - Names: []string{"id", "salary"}, - Struct: types.StructType{ - Nullability: types.NullabilityRequired, - Types: []types.Type{ - &types.Int64Type{Nullability: types.NullabilityRequired}, - &types.Float64Type{Nullability: types.NullabilityRequired}, - }, - }, - }) - - dpExpr, err := eb.DynamicParam( - &types.Float64Type{Nullability: types.NullabilityRequired}, 0, - ).BuildExpr() - require.NoError(t, err) - - project, err := b.Project(scan, dpExpr) - require.NoError(t, err) - - bindings := []plan.DynamicParameterBinding{ - { - ParameterAnchor: 0, - Value: expr.NewPrimitiveLiteral(float64(50000.0), false), - }, - } - - p, err := b.PlanWithBindings(project, []string{"id", "salary", "threshold"}, nil, bindings) - require.NoError(t, err) - - protoPlan, err := p.ToProto() - require.NoError(t, err) - - roundTrip, err := plan.FromProto(protoPlan, extensions.GetDefaultCollectionWithNoError()) - require.NoError(t, err) - - roundTripProto, err := roundTrip.ToProto() - require.NoError(t, err) - assert.True(t, proto.Equal(protoPlan, roundTripProto)) } func TestDynamicParameterBindingTypeMismatch(t *testing.T) { @@ -281,7 +121,6 @@ func TestDynamicParameterBindingMissingAnchor(t *testing.T) { b := plan.NewBuilderDefault() scan := b.NamedScan([]string{"test"}, baseSchema2) - // Plan has dp with anchor 0, but binding references anchor 99 dp := &expr.DynamicParameter{ OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, ParameterReference: 0, @@ -303,8 +142,6 @@ func TestDynamicParameterBindingMissingAnchor(t *testing.T) { } func TestDynamicParameterBindingNullabilityMismatch(t *testing.T) { - // Nullability differences should be allowed — a required param can be - // bound to a nullable literal and vice-versa. b := plan.NewBuilderDefault() scan := b.NamedScan([]string{"test"}, baseSchema2) @@ -329,7 +166,6 @@ func TestDynamicParameterBindingNullabilityMismatch(t *testing.T) { } func TestDynamicParameterBindingInFilter(t *testing.T) { - // Validate that type validation also works through filter conditions b := plan.NewBuilderDefault() scan := b.NamedScan([]string{"test"}, baseSchema2) diff --git a/plan/testdata/dynamic_parameter_filter.json b/plan/testdata/dynamic_parameter_filter.json new file mode 100644 index 00000000..2ba906b3 --- /dev/null +++ b/plan/testdata/dynamic_parameter_filter.json @@ -0,0 +1,111 @@ +{ + "version": { + "minorNumber": 29, + "producer": "substrait-go (devel) darwin/arm64" + }, + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_comparison" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "functionAnchor": 1, + "name": "gt:any_any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "x", + "y" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test" + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 1, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + }, + { + "value": { + "dynamicParameter": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + } + } + } + ], + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + } + } + } + }, + "names": [ + "x", + "y" + ] + } + } + ], + "parameterBindings": [ + { + "value": { + "i32": 42 + } + } + ] +}