diff --git a/expr/builder.go b/expr/builder.go index 37daca22..4f41c811 100644 --- a/expr/builder.go +++ b/expr/builder.go @@ -180,6 +180,15 @@ func (e *ExprBuilder) Cast(from Builder, to types.Type) *castBuilder { } } +// DynamicParam returns a builder for constructing a DynamicParameter expression. +// The paramRef identifies the parameter binding in the plan. +func (e *ExprBuilder) DynamicParam(outputType types.Type, paramRef uint32) *dynamicParamBuilder { + return &dynamicParamBuilder{ + outputType: outputType, + paramRef: paramRef, + } +} + // Lambda returns a builder for constructing a Lambda expression with the // given parameters. // @@ -301,6 +310,24 @@ func (cb *castBuilder) FailBehavior(b types.CastFailBehavior) *castBuilder { return cb } +type dynamicParamBuilder struct { + outputType types.Type + paramRef uint32 +} + +func (dpb *dynamicParamBuilder) Build() (*DynamicParameter, error) { + if dpb.outputType == nil { + return nil, fmt.Errorf("%w: dynamic parameter must have an output type", substraitgo.ErrInvalidExpr) + } + return &DynamicParameter{ + OutputType: dpb.outputType, + ParameterReference: dpb.paramRef, + }, nil +} + +func (dpb *dynamicParamBuilder) BuildExpr() (Expression, error) { return dpb.Build() } +func (dpb *dynamicParamBuilder) BuildFuncArg() (types.FuncArg, error) { return dpb.Build() } + type scalarFuncBuilder struct { b *ExprBuilder diff --git a/expr/dynamic_parameter_internal_test.go b/expr/dynamic_parameter_internal_test.go new file mode 100644 index 00000000..22ac77c6 --- /dev/null +++ b/expr/dynamic_parameter_internal_test.go @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expr + +import ( + "testing" + + "github.com/substrait-io/substrait-go/v7/types" +) + +func TestDynamicParameterIsRootRef(t *testing.T) { + dp := &DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + // Verify DynamicParameter satisfies RootRefType + var _ RootRefType = dp + dp.isRootRef() +} diff --git a/expr/dynamic_parameter_test.go b/expr/dynamic_parameter_test.go new file mode 100644 index 00000000..2e0c5d84 --- /dev/null +++ b/expr/dynamic_parameter_test.go @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 + +package expr_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/v7/expr" + "github.com/substrait-io/substrait-go/v7/extensions" + "github.com/substrait-io/substrait-go/v7/types" + proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" + pb "google.golang.org/protobuf/proto" +) + +func TestDynamicParameterEquals(t *testing.T) { + i64Req := &types.Int64Type{Nullability: types.NullabilityRequired} + fp64Req := &types.Float64Type{Nullability: types.NullabilityRequired} + + base := &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 0} + + tests := []struct { + name string + other expr.Expression + want bool + }{ + {"same type and ref", &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 0}, true}, + {"different ref", &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 1}, false}, + {"different type", &expr.DynamicParameter{OutputType: fp64Req, ParameterReference: 0}, false}, + {"different expression kind", expr.NewPrimitiveLiteral(int64(42), false), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, base.Equals(tt.other)) + }) + } +} + +func TestDynamicParameterVisit(t *testing.T) { + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 5, + } + + visited := dp.Visit(func(e expr.Expression) expr.Expression { return e }) + assert.Same(t, dp, visited, "Visit should return same pointer for leaf expression") +} + +// TestDynamicParameterToProtoRoundtrip tests construction, interface compliance, +// and proto roundtrip for various DynamicParameter configurations. +// The $N:type String() format (e.g. "$0:i32") is an internal debugging +// representation used by this library; it is not part of the Substrait spec. +func TestDynamicParameterToProtoRoundtrip(t *testing.T) { + tests := []struct { + name string + dp *expr.DynamicParameter + }{ + {"required i32", &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, ParameterReference: 0}}, + {"nullable string", &expr.DynamicParameter{ + OutputType: &types.StringType{Nullability: types.NullabilityNullable}, ParameterReference: 1}}, + {"required fp64", &expr.DynamicParameter{ + OutputType: &types.Float64Type{Nullability: types.NullabilityRequired}, ParameterReference: 5}}, + {"required boolean", &expr.DynamicParameter{ + OutputType: &types.BooleanType{Nullability: types.NullabilityRequired}, ParameterReference: 10}}, + {"nullable i64", &expr.DynamicParameter{ + OutputType: &types.Int64Type{Nullability: types.NullabilityNullable}, ParameterReference: 42}}, + } + + reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.True(t, tt.dp.IsScalar()) + assert.True(t, tt.dp.GetType().Equals(tt.dp.OutputType)) + + protoExpr := tt.dp.ToProto() + require.NotNil(t, protoExpr) + + fromProto, err := expr.ExprFromProto(protoExpr, nil, reg) + require.NoError(t, err) + assert.True(t, tt.dp.Equals(fromProto), "roundtrip should produce equal expression") + + protoRoundTrip := fromProto.ToProto() + assert.True(t, pb.Equal(protoExpr, protoRoundTrip), "proto roundtrip should be equal") + }) + } +} + +func TestDynamicParameterFromProtoNilDynamicParam(t *testing.T) { + protoExpr := &proto.Expression{ + RexType: &proto.Expression_DynamicParameter{ + DynamicParameter: nil, + }, + } + + _, err := expr.ExprFromProto(protoExpr, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())) + require.Error(t, err) + assert.Contains(t, err.Error(), "dynamic parameter is nil") +} + +func TestDynamicParameterBuilderNilType(t *testing.T) { + b := expr.ExprBuilder{ + Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()), + } + + _, err := b.DynamicParam(nil, 0).BuildExpr() + require.Error(t, err) + assert.Contains(t, err.Error(), "dynamic parameter must have an output type") +} + +func TestDynamicParameterBuilderAsFuncArg(t *testing.T) { + b := expr.ExprBuilder{ + Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()), + BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct), + } + + dpBuilder := b.DynamicParam(&types.Int8Type{Nullability: types.NullabilityRequired}, 0) + + e, err := b.ScalarFunc(addID).Args( + dpBuilder, + b.Wrap(expr.NewLiteral(int8(5), false)), + ).BuildExpr() + require.NoError(t, err) + assert.Contains(t, e.String(), "$0:i8") +} + +func TestDynamicParameterTypeMismatchInFunction(t *testing.T) { + b := expr.ExprBuilder{ + Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()), + BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct), + } + + tests := []struct { + name string + funcID extensions.ID + dpType types.Type + lit func() (expr.Literal, error) + }{ + { + name: "i32 where i8 expected", + funcID: extensions.ID{URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic", Name: "add:i8_i8"}, + dpType: &types.Int32Type{Nullability: types.NullabilityRequired}, + lit: func() (expr.Literal, error) { return expr.NewLiteral(int8(5), false) }, + }, + { + name: "string where numeric expected", + funcID: addID, + dpType: &types.StringType{Nullability: types.NullabilityRequired}, + lit: func() (expr.Literal, error) { return expr.NewLiteral(int32(5), false) }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := b.ScalarFunc(tt.funcID).Args( + b.DynamicParam(tt.dpType, 0), + b.Wrap(tt.lit()), + ).BuildExpr() + require.Error(t, err) + }) + } +} diff --git a/expr/expression.go b/expr/expression.go index a52a03ca..48860fa3 100644 --- a/expr/expression.go +++ b/expr/expression.go @@ -316,6 +316,14 @@ func ExprFromProto(e *proto.Expression, baseSchema *types.RecordType, reg Extens return nil, fmt.Errorf("%w: nested expression: %s", substraitgo.ErrInvalidExpr, n) } + case *proto.Expression_DynamicParameter: + if et.DynamicParameter == nil { + return nil, fmt.Errorf("%w: dynamic parameter is nil", substraitgo.ErrInvalidExpr) + } + return &DynamicParameter{ + OutputType: types.TypeFromProto(et.DynamicParameter.Type), + ParameterReference: et.DynamicParameter.ParameterReference, + }, nil case *proto.Expression_Enum_: return nil, fmt.Errorf("%w: deprecated", substraitgo.ErrNotImplemented) case *proto.Expression_Subquery_: @@ -372,6 +380,7 @@ type VisitFunc func(Expression) Expression // - A Cast expression // - A Subquery // - A Nested expression +// - A Dynamic Parameter type Expression interface { // an Expression can also be a function argument types.FuncArg @@ -654,6 +663,54 @@ func (ex *Cast) Visit(visit VisitFunc) Expression { return &newCast } +type DynamicParameter struct { + OutputType types.Type + ParameterReference uint32 +} + +func (dp *DynamicParameter) String() string { + return fmt.Sprintf("$%d:%s", dp.ParameterReference, dp.OutputType) +} + +func (dp *DynamicParameter) ToProtoFuncArg() *proto.FunctionArgument { + return &proto.FunctionArgument{ + ArgType: &proto.FunctionArgument_Value{ + Value: dp.ToProto(), + }, + } +} + +func (dp *DynamicParameter) isRootRef() {} + +func (dp *DynamicParameter) IsScalar() bool { return true } + +func (dp *DynamicParameter) GetType() types.Type { return dp.OutputType } + +func (dp *DynamicParameter) ToProto() *proto.Expression { + return &proto.Expression{ + RexType: &proto.Expression_DynamicParameter{ + DynamicParameter: &proto.DynamicParameter{ + Type: types.TypeToProto(dp.OutputType), + ParameterReference: dp.ParameterReference, + }, + }, + } +} + +func (dp *DynamicParameter) Equals(other Expression) bool { + rhs, ok := other.(*DynamicParameter) + if !ok { + return false + } + + return dp.ParameterReference == rhs.ParameterReference && + dp.OutputType.Equals(rhs.OutputType) +} + +func (dp *DynamicParameter) Visit(visit VisitFunc) Expression { + return dp +} + type SwitchExpr struct { match Expression ifs []struct { diff --git a/plan/builders.go b/plan/builders.go index edf8af87..66388f5e 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -142,6 +142,10 @@ type Builder interface { // that may be in use with this plan for advanced extensions, optimizations, // and so on. PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs []string, others ...Rel) (*Plan, error) + // PlanWithBindings is the same as PlanWithTypes, but additionally accepts + // dynamic parameter bindings. These bind DynamicParameter expressions in + // the plan tree to concrete literal values at runtime. + PlanWithBindings(root Rel, rootNames []string, expectedTypeURLs []string, bindings []DynamicParameterBinding, others ...Rel) (*Plan, error) // GetExprBuilder returns an expr.ExprBuilder that shares the extension // registry that this Builder uses. @@ -739,6 +743,14 @@ func (b *builder) Set(op SetOp, inputs ...Rel) (*SetRel, error) { } func (b *builder) PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs []string, others ...Rel) (*Plan, error) { + return b.PlanWithBindings(root, rootNames, expectedTypeURLs, nil, others...) +} + +// PlanWithBindings constructs a new plan with the provided root relation, +// expected type URLs, and dynamic parameter bindings. This allows creating +// parameterized plans where expressions contain DynamicParameter references +// that are bound to concrete values. +func (b *builder) PlanWithBindings(root Rel, rootNames []string, expectedTypeURLs []string, bindings []DynamicParameterBinding, others ...Rel) (*Plan, error) { if root == nil { return nil, fmt.Errorf("%w: must provide non-nil root relation for plan", substraitgo.ErrInvalidRel) @@ -748,6 +760,10 @@ func (b *builder) PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs [ return nil, err } + if err := ValidateParameterBindings(root, bindings); err != nil { + return nil, err + } + relations := make([]Relation, len(others)+1) relations[0].root = &Root{ input: root, names: rootNames, @@ -758,11 +774,12 @@ func (b *builder) PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs [ } return &Plan{ - version: &CurrentVersion, - extensions: b.extSet, - reg: b.reg, - expectedTypeURLs: expectedTypeURLs, - relations: relations, + version: &CurrentVersion, + extensions: b.extSet, + reg: b.reg, + expectedTypeURLs: expectedTypeURLs, + relations: relations, + parameterBindings: bindings, }, nil } diff --git a/plan/common.go b/plan/common.go index de1cd193..d6d117d8 100644 --- a/plan/common.go +++ b/plan/common.go @@ -3,11 +3,120 @@ package plan import ( + "fmt" + + substraitgo "github.com/substrait-io/substrait-go/v7" + "github.com/substrait-io/substrait-go/v7/expr" "github.com/substrait-io/substrait-go/v7/extensions" "github.com/substrait-io/substrait-go/v7/types" proto "github.com/substrait-io/substrait-protobuf/go/substraitpb" ) +// DynamicParameterBinding maps a parameter anchor to a literal value +// for use with DynamicParameter expressions in a plan. +// +// When bindings are provided via PlanWithBindings, the builder validates +// that each binding's literal type matches (ignoring nullability) the +// OutputType declared on the corresponding DynamicParameter expression. +type DynamicParameterBinding struct { + ParameterAnchor uint32 + Value expr.Literal +} + +// ValidateParameterBindings checks that every binding's literal type matches +// the OutputType of the corresponding DynamicParameter expression found in +// the relation tree. Type comparison ignores nullability so that a required +// parameter can be bound to a nullable literal and vice-versa. +// +// Returns an error for: +// - A binding whose anchor does not correspond to any DynamicParameter in the tree. +// - A binding whose value type does not match the parameter's declared type +// (ignoring nullability). +func ValidateParameterBindings(root Rel, bindings []DynamicParameterBinding) error { + if len(bindings) == 0 { + return nil + } + + // Collect all DynamicParameter output types keyed by anchor. + paramTypes := make(map[uint32]types.Type) + collectDynamicParams(root, paramTypes) + + for _, b := range bindings { + declaredType, ok := paramTypes[b.ParameterAnchor] + if !ok { + return fmt.Errorf("%w: parameter binding references anchor %d, "+ + "but no DynamicParameter with that reference exists in the plan", + substraitgo.ErrInvalidPlan, b.ParameterAnchor) + } + + // Compare ignoring nullability. + bindingType := b.Value.GetType().WithNullability(types.NullabilityUnspecified) + expectedType := declaredType.WithNullability(types.NullabilityUnspecified) + if !bindingType.Equals(expectedType) { + return fmt.Errorf("%w: parameter binding for anchor %d has type %s, "+ + "but DynamicParameter declares type %s", + substraitgo.ErrInvalidPlan, b.ParameterAnchor, b.Value.GetType(), declaredType) + } + } + + return nil +} + +// collectDynamicParams walks a relation tree and records the OutputType +// of every DynamicParameter expression it encounters. +func collectDynamicParams(rel Rel, out map[uint32]types.Type) { + if rel == nil { + return + } + + // Walk child relations first. + for _, child := range rel.GetInputs() { + collectDynamicParams(child, out) + } + + // Walk expressions owned by this relation. + walkRelExprs(rel, func(e expr.Expression) { + walkExpr(e, func(inner expr.Expression) { + if dp, ok := inner.(*expr.DynamicParameter); ok { + out[dp.ParameterReference] = dp.OutputType + } + }) + }) +} + +// walkRelExprs invokes fn for every top-level expression in a relation. +func walkRelExprs(rel Rel, fn func(expr.Expression)) { + switch r := rel.(type) { + case *FilterRel: + fn(r.Condition()) + case *ProjectRel: + for _, e := range r.Expressions() { + fn(e) + } + case *JoinRel: + fn(r.Expr()) + if pjf := r.PostJoinFilter(); pjf != nil { + fn(pjf) + } + case *SortRel: + for _, sf := range r.Sorts() { + fn(sf.Expr) + } + } +} + +// walkExpr recursively visits every node in an expression tree. +func walkExpr(e expr.Expression, fn func(expr.Expression)) { + if e == nil { + return + } + fn(e) + e.Visit(func(child expr.Expression) expr.Expression { + walkExpr(child, fn) + return child + }) +} + type ( Hint = proto.RelCommon_Hint Stats = proto.RelCommon_Hint_Stats diff --git a/plan/dynamic_parameter_test.go b/plan/dynamic_parameter_test.go new file mode 100644 index 00000000..440013ec --- /dev/null +++ b/plan/dynamic_parameter_test.go @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: Apache-2.0 + +package plan_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/v7/expr" + "github.com/substrait-io/substrait-go/v7/extensions" + "github.com/substrait-io/substrait-go/v7/plan" + "github.com/substrait-io/substrait-go/v7/types" + substraitproto "github.com/substrait-io/substrait-protobuf/go/substraitpb" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +func TestDynamicParameterPlanRoundtrip(t *testing.T) { + for _, name := range []string{ + "dynamic_parameter_plan", + "dynamic_parameter_filter", + } { + t.Run(name, func(t *testing.T) { + planJSON, err := testdata.ReadFile(fmt.Sprintf("testdata/%s.json", name)) + require.NoError(t, err) + + var protoPlan substraitproto.Plan + require.NoError(t, protojson.Unmarshal(planJSON, &protoPlan)) + + p, err := plan.FromProto(&protoPlan, extensions.GetDefaultCollectionWithNoError()) + require.NoError(t, err) + + backToProto, err := p.ToProto() + require.NoError(t, err) + assert.Truef(t, proto.Equal(&protoPlan, backToProto), + "expected: %s\ngot: %s", + protojson.Format(&protoPlan), protojson.Format(backToProto)) + }) + } +} + +func TestDynamicParameterPlanWithoutBindings(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + project, err := b.Project(scan, dp) + require.NoError(t, err) + + p, err := b.Plan(project, []string{"x", "y", "param"}) + require.NoError(t, err) + + assert.Empty(t, p.ParameterBindings()) + + protoPlan, err := p.ToProto() + require.NoError(t, err) + assert.Empty(t, protoPlan.ParameterBindings) +} + +func TestDynamicParameterBindingTypeMismatch(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + tests := []struct { + name string + dpType types.Type + bindValue expr.Literal + errMsg string + }{ + { + name: "i32 param bound to string literal", + dpType: &types.Int32Type{Nullability: types.NullabilityRequired}, + bindValue: expr.NewPrimitiveLiteral("hello", false), + errMsg: "parameter binding for anchor 0 has type", + }, + { + name: "string param bound to i32 literal", + dpType: &types.StringType{Nullability: types.NullabilityNullable}, + bindValue: expr.NewPrimitiveLiteral(int32(42), false), + errMsg: "parameter binding for anchor 0 has type", + }, + { + name: "fp64 param bound to i64 literal", + dpType: &types.Float64Type{Nullability: types.NullabilityRequired}, + bindValue: expr.NewPrimitiveLiteral(int64(100), false), + errMsg: "parameter binding for anchor 0 has type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dp := &expr.DynamicParameter{ + OutputType: tt.dpType, + ParameterReference: 0, + } + + project, err := b.Project(scan, dp) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: tt.bindValue, + }, + } + + _, err = b.PlanWithBindings(project, []string{"x", "y", "p"}, nil, bindings) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + }) + } +} + +func TestDynamicParameterBindingMissingAnchor(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + project, err := b.Project(scan, dp) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 99, + Value: expr.NewPrimitiveLiteral(int32(42), false), + }, + } + + _, err = b.PlanWithBindings(project, []string{"x", "y", "p"}, nil, bindings) + require.Error(t, err) + assert.Contains(t, err.Error(), "no DynamicParameter with that reference exists") +} + +func TestDynamicParameterBindingNullabilityMismatch(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + project, err := b.Project(scan, dp) + require.NoError(t, err) + + bindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral(int32(42), true), // nullable literal + }, + } + + p, err := b.PlanWithBindings(project, []string{"x", "y", "p"}, nil, bindings) + require.NoError(t, err) + assert.NotNil(t, p) +} + +func TestDynamicParameterBindingInFilter(t *testing.T) { + b := plan.NewBuilderDefault() + scan := b.NamedScan([]string{"test"}, baseSchema2) + + dp := &expr.DynamicParameter{ + OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, + ParameterReference: 0, + } + + ref, err := b.RootFieldRef(scan, 0) + require.NoError(t, err) + + gt, err := b.ScalarFn(extensions.SubstraitDefaultURNPrefix+"functions_comparison", "gt", nil, ref, dp) + require.NoError(t, err) + + filter, err := b.Filter(scan, gt) + require.NoError(t, err) + + // Wrong type binding should fail + wrongBindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral("not-a-number", false), + }, + } + _, err = b.PlanWithBindings(filter, []string{"x", "y"}, nil, wrongBindings) + require.Error(t, err) + assert.Contains(t, err.Error(), "parameter binding for anchor 0 has type") + + // Correct type binding should succeed + goodBindings := []plan.DynamicParameterBinding{ + { + ParameterAnchor: 0, + Value: expr.NewPrimitiveLiteral(int32(42), false), + }, + } + p, err := b.PlanWithBindings(filter, []string{"x", "y"}, nil, goodBindings) + require.NoError(t, err) + assert.NotNil(t, p) +} diff --git a/plan/plan.go b/plan/plan.go index b7f26fff..9decc41d 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -152,11 +152,12 @@ type AdvancedExtension interface { // Plan describes a set of operations to complete. For // compactness, identifiers are normalized at the plan level. type Plan struct { - version *types.Version - extensions extensions.Set - expectedTypeURLs []string - advExtension *extensions.AdvancedExtension - relations []Relation + version *types.Version + extensions extensions.Set + expectedTypeURLs []string + advExtension *extensions.AdvancedExtension + relations []Relation + parameterBindings []DynamicParameterBinding reg expr.ExtensionRegistry } @@ -216,6 +217,14 @@ func (p *Plan) GetNonRootRelations() (rels []Rel) { return rels } +// ParameterBindings returns the list of dynamic parameter bindings for this plan. +// Each binding maps a parameter anchor to a runtime literal value. +// +// This returns a clone of the internal slice so that the plan itself remains immutable. +func (p *Plan) ParameterBindings() []DynamicParameterBinding { + return slices.Clone(p.parameterBindings) +} + func FromProto(plan *proto.Plan, c *extensions.Collection) (*Plan, error) { extSet, err := extensions.GetExtensionSet(plan, c) if err != nil { @@ -236,6 +245,16 @@ func FromProto(plan *proto.Plan, c *extensions.Collection) (*Plan, error) { } } + if len(plan.ParameterBindings) > 0 { + ret.parameterBindings = make([]DynamicParameterBinding, len(plan.ParameterBindings)) + for i, pb := range plan.ParameterBindings { + ret.parameterBindings[i] = DynamicParameterBinding{ + ParameterAnchor: pb.ParameterAnchor, + Value: expr.LiteralFromProto(pb.Value), + } + } + } + return ret, nil } @@ -245,6 +264,18 @@ func (p *Plan) ToProto() (*proto.Plan, error) { for i, r := range p.relations { relations[i] = r.ToProto() } + + var bindings []*proto.DynamicParameterBinding + if len(p.parameterBindings) > 0 { + bindings = make([]*proto.DynamicParameterBinding, len(p.parameterBindings)) + for i, b := range p.parameterBindings { + bindings[i] = &proto.DynamicParameterBinding{ + ParameterAnchor: b.ParameterAnchor, + Value: b.Value.ToProtoLiteral(), + } + } + } + return &proto.Plan{ Version: p.version, ExpectedTypeUrls: p.expectedTypeURLs, @@ -252,6 +283,7 @@ func (p *Plan) ToProto() (*proto.Plan, error) { Relations: relations, Extensions: decls, ExtensionUrns: urns, + ParameterBindings: bindings, }, nil } diff --git a/plan/testdata/dynamic_parameter_filter.json b/plan/testdata/dynamic_parameter_filter.json new file mode 100644 index 00000000..2ba906b3 --- /dev/null +++ b/plan/testdata/dynamic_parameter_filter.json @@ -0,0 +1,111 @@ +{ + "version": { + "minorNumber": 29, + "producer": "substrait-go (devel) darwin/arm64" + }, + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_comparison" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "functionAnchor": 1, + "name": "gt:any_any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "x", + "y" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test" + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 1, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + }, + { + "value": { + "dynamicParameter": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + } + } + } + ], + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + } + } + } + }, + "names": [ + "x", + "y" + ] + } + } + ], + "parameterBindings": [ + { + "value": { + "i32": 42 + } + } + ] +} diff --git a/plan/testdata/dynamic_parameter_plan.json b/plan/testdata/dynamic_parameter_plan.json new file mode 100644 index 00000000..ff62f98b --- /dev/null +++ b/plan/testdata/dynamic_parameter_plan.json @@ -0,0 +1,55 @@ +{ + "version": {"majorNumber": 0, "minorNumber": 29, "patchNumber": 0}, + "relations": [ + { + "root": { + "input": { + "project": { + "common": {"direct": {}}, + "input": { + "read": { + "common": {"direct": {}}, + "baseSchema": { + "names": ["id", "name"], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + {"i64": {"nullability": "NULLABILITY_REQUIRED"}}, + {"string": {"nullability": "NULLABILITY_NULLABLE"}} + ] + } + }, + "namedTable": {"names": ["users"]} + } + }, + "expressions": [ + { + "dynamicParameter": { + "type": {"i32": {"nullability": "NULLABILITY_REQUIRED"}}, + "parameterReference": 0 + } + }, + { + "dynamicParameter": { + "type": {"string": {"nullability": "NULLABILITY_NULLABLE"}}, + "parameterReference": 1 + } + } + ] + } + }, + "names": ["id", "name", "param0", "param1"] + } + } + ], + "parameterBindings": [ + { + "parameterAnchor": 0, + "value": {"i32": 99} + }, + { + "parameterAnchor": 1, + "value": {"string": "test_value"} + } + ] +}