Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions expr/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ func (e *ExprBuilder) Cast(from Builder, to types.Type) *castBuilder {
}
}

// DynamicParam returns a builder for constructing a DynamicParameter expression.
// The paramRef identifies the parameter binding in the plan.
func (e *ExprBuilder) DynamicParam(outputType types.Type, paramRef uint32) *dynamicParamBuilder {
return &dynamicParamBuilder{
outputType: outputType,
paramRef: paramRef,
}
}

// Lambda returns a builder for constructing a Lambda expression with the
// given parameters.
//
Expand Down Expand Up @@ -301,6 +310,24 @@ func (cb *castBuilder) FailBehavior(b types.CastFailBehavior) *castBuilder {
return cb
}

type dynamicParamBuilder struct {
outputType types.Type
paramRef uint32
}

func (dpb *dynamicParamBuilder) Build() (*DynamicParameter, error) {
if dpb.outputType == nil {
return nil, fmt.Errorf("%w: dynamic parameter must have an output type", substraitgo.ErrInvalidExpr)
}
return &DynamicParameter{
OutputType: dpb.outputType,
ParameterReference: dpb.paramRef,
}, nil
}

func (dpb *dynamicParamBuilder) BuildExpr() (Expression, error) { return dpb.Build() }
func (dpb *dynamicParamBuilder) BuildFuncArg() (types.FuncArg, error) { return dpb.Build() }

type scalarFuncBuilder struct {
b *ExprBuilder

Expand Down
20 changes: 20 additions & 0 deletions expr/dynamic_parameter_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-License-Identifier: Apache-2.0

package expr

import (
"testing"

"github.com/substrait-io/substrait-go/v7/types"
)

func TestDynamicParameterIsRootRef(t *testing.T) {
dp := &DynamicParameter{
OutputType: &types.Int32Type{Nullability: types.NullabilityRequired},
ParameterReference: 0,
}

// Verify DynamicParameter satisfies RootRefType
var _ RootRefType = dp
dp.isRootRef()
}
165 changes: 165 additions & 0 deletions expr/dynamic_parameter_test.go
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I asked claude to help reduce the size of these tests to make them a bit easier to follow and it produced this. What do you think? I think its a bit easier to see the intended test.

// SPDX-License-Identifier: Apache-2.0

package expr_test

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/substrait-io/substrait-go/v7/expr"
	"github.com/substrait-io/substrait-go/v7/extensions"
	"github.com/substrait-io/substrait-go/v7/types"
	proto "github.com/substrait-io/substrait-protobuf/go/substraitpb"
	pb "google.golang.org/protobuf/proto"
)

func TestDynamicParameterEquals(t *testing.T) {
	i64Req := &types.Int64Type{Nullability: types.NullabilityRequired}
	fp64Req := &types.Float64Type{Nullability: types.NullabilityRequired}

	base := &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 0}

	tests := []struct {
		name  string
		other expr.Expression
		want  bool
	}{
		{"same type and ref", &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 0}, true},
		{"different ref", &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 1}, false},
		{"different type", &expr.DynamicParameter{OutputType: fp64Req, ParameterReference: 0}, false},
		{"different expression kind", expr.NewPrimitiveLiteral(int64(42), false), false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.want, base.Equals(tt.other))
		})
	}
}

func TestDynamicParameterVisit(t *testing.T) {
	dp := &expr.DynamicParameter{
		OutputType:         &types.Int32Type{Nullability: types.NullabilityRequired},
		ParameterReference: 5,
	}

	visited := dp.Visit(func(e expr.Expression) expr.Expression { return e })
	assert.Same(t, dp, visited, "Visit should return same pointer for leaf expression")
}

func TestDynamicParameterToProtoRoundtrip(t *testing.T) {
	tests := []struct {
		name string
		dp   *expr.DynamicParameter
	}{
		{"required i32", &expr.DynamicParameter{
			OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, ParameterReference: 0}},
		{"nullable string", &expr.DynamicParameter{
			OutputType: &types.StringType{Nullability: types.NullabilityNullable}, ParameterReference: 1}},
	}

	reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.True(t, tt.dp.IsScalar())
			assert.True(t, tt.dp.GetType().Equals(tt.dp.OutputType))

			protoExpr := tt.dp.ToProto()
			require.NotNil(t, protoExpr)

			fromProto, err := expr.ExprFromProto(protoExpr, nil, reg)
			require.NoError(t, err)
			assert.True(t, tt.dp.Equals(fromProto), "roundtrip should produce equal expression")

			protoRoundTrip := fromProto.ToProto()
			assert.True(t, pb.Equal(protoExpr, protoRoundTrip), "proto roundtrip should be equal")
		})
	}
}

func TestDynamicParameterFromProtoNilDynamicParam(t *testing.T) {
	protoExpr := &proto.Expression{
		RexType: &proto.Expression_DynamicParameter{
			DynamicParameter: nil,
		},
	}

	_, err := expr.ExprFromProto(protoExpr, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "dynamic parameter is nil")
}

func TestDynamicParameterBuilderNilType(t *testing.T) {
	b := expr.ExprBuilder{
		Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()),
	}

	_, err := b.DynamicParam(nil, 0).BuildExpr()
	require.Error(t, err)
	assert.Contains(t, err.Error(), "dynamic parameter must have an output type")
}

func TestDynamicParameterBuilderAsFuncArg(t *testing.T) {
	b := expr.ExprBuilder{
		Reg:        expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()),
		BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct),
	}

	dpBuilder := b.DynamicParam(&types.Int8Type{Nullability: types.NullabilityRequired}, 0)

	e, err := b.ScalarFunc(addID).Args(
		dpBuilder,
		b.Wrap(expr.NewLiteral(int8(5), false)),
	).BuildExpr()
	require.NoError(t, err)
	assert.Contains(t, e.String(), "$0:i8")
}

func TestDynamicParameterTypeMismatchInFunction(t *testing.T) {
	b := expr.ExprBuilder{
		Reg:        expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()),
		BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct),
	}

	tests := []struct {
		name   string
		funcID extensions.ID
		dpType types.Type
		lit    func() (expr.Literal, error)
	}{
		{
			name:   "i32 where i8 expected",
			funcID: extensions.ID{URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic", Name: "add:i8_i8"},
			dpType: &types.Int32Type{Nullability: types.NullabilityRequired},
			lit:    func() (expr.Literal, error) { return expr.NewLiteral(int8(5), false) },
		},
		{
			name:   "string where numeric expected",
			funcID: addID,
			dpType: &types.StringType{Nullability: types.NullabilityRequired},
			lit:    func() (expr.Literal, error) { return expr.NewLiteral(int32(5), false) },
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := b.ScalarFunc(tt.funcID).Args(
				b.DynamicParam(tt.dpType, 0),
				b.Wrap(tt.lit()),
			).BuildExpr()
			require.Error(t, err)
		})
	}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
// SPDX-License-Identifier: Apache-2.0

package expr_test

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/v7/expr"
"github.com/substrait-io/substrait-go/v7/extensions"
"github.com/substrait-io/substrait-go/v7/types"
proto "github.com/substrait-io/substrait-protobuf/go/substraitpb"
pb "google.golang.org/protobuf/proto"
)

func TestDynamicParameterEquals(t *testing.T) {
i64Req := &types.Int64Type{Nullability: types.NullabilityRequired}
fp64Req := &types.Float64Type{Nullability: types.NullabilityRequired}

base := &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 0}

tests := []struct {
name string
other expr.Expression
want bool
}{
{"same type and ref", &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 0}, true},
{"different ref", &expr.DynamicParameter{OutputType: i64Req, ParameterReference: 1}, false},
{"different type", &expr.DynamicParameter{OutputType: fp64Req, ParameterReference: 0}, false},
{"different expression kind", expr.NewPrimitiveLiteral(int64(42), false), false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, base.Equals(tt.other))
})
}
}

func TestDynamicParameterVisit(t *testing.T) {
dp := &expr.DynamicParameter{
OutputType: &types.Int32Type{Nullability: types.NullabilityRequired},
ParameterReference: 5,
}

visited := dp.Visit(func(e expr.Expression) expr.Expression { return e })
assert.Same(t, dp, visited, "Visit should return same pointer for leaf expression")
}

// TestDynamicParameterToProtoRoundtrip tests construction, interface compliance,
// and proto roundtrip for various DynamicParameter configurations.
// The $N:type String() format (e.g. "$0:i32") is an internal debugging
// representation used by this library; it is not part of the Substrait spec.
func TestDynamicParameterToProtoRoundtrip(t *testing.T) {
tests := []struct {
name string
dp *expr.DynamicParameter
}{
{"required i32", &expr.DynamicParameter{
OutputType: &types.Int32Type{Nullability: types.NullabilityRequired}, ParameterReference: 0}},
{"nullable string", &expr.DynamicParameter{
OutputType: &types.StringType{Nullability: types.NullabilityNullable}, ParameterReference: 1}},
{"required fp64", &expr.DynamicParameter{
OutputType: &types.Float64Type{Nullability: types.NullabilityRequired}, ParameterReference: 5}},
{"required boolean", &expr.DynamicParameter{
OutputType: &types.BooleanType{Nullability: types.NullabilityRequired}, ParameterReference: 10}},
{"nullable i64", &expr.DynamicParameter{
OutputType: &types.Int64Type{Nullability: types.NullabilityNullable}, ParameterReference: 42}},
}

reg := expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.True(t, tt.dp.IsScalar())
assert.True(t, tt.dp.GetType().Equals(tt.dp.OutputType))

protoExpr := tt.dp.ToProto()
require.NotNil(t, protoExpr)

fromProto, err := expr.ExprFromProto(protoExpr, nil, reg)
require.NoError(t, err)
assert.True(t, tt.dp.Equals(fromProto), "roundtrip should produce equal expression")

protoRoundTrip := fromProto.ToProto()
assert.True(t, pb.Equal(protoExpr, protoRoundTrip), "proto roundtrip should be equal")
})
}
}

func TestDynamicParameterFromProtoNilDynamicParam(t *testing.T) {
protoExpr := &proto.Expression{
RexType: &proto.Expression_DynamicParameter{
DynamicParameter: nil,
},
}

_, err := expr.ExprFromProto(protoExpr, nil, expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()))
require.Error(t, err)
assert.Contains(t, err.Error(), "dynamic parameter is nil")
}

func TestDynamicParameterBuilderNilType(t *testing.T) {
b := expr.ExprBuilder{
Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()),
}

_, err := b.DynamicParam(nil, 0).BuildExpr()
require.Error(t, err)
assert.Contains(t, err.Error(), "dynamic parameter must have an output type")
}

func TestDynamicParameterBuilderAsFuncArg(t *testing.T) {
b := expr.ExprBuilder{
Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()),
BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct),
}

dpBuilder := b.DynamicParam(&types.Int8Type{Nullability: types.NullabilityRequired}, 0)

e, err := b.ScalarFunc(addID).Args(
dpBuilder,
b.Wrap(expr.NewLiteral(int8(5), false)),
).BuildExpr()
require.NoError(t, err)
assert.Contains(t, e.String(), "$0:i8")
}

func TestDynamicParameterTypeMismatchInFunction(t *testing.T) {
b := expr.ExprBuilder{
Reg: expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()),
BaseSchema: types.NewRecordTypeFromStruct(boringSchema.Struct),
}

tests := []struct {
name string
funcID extensions.ID
dpType types.Type
lit func() (expr.Literal, error)
}{
{
name: "i32 where i8 expected",
funcID: extensions.ID{URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic", Name: "add:i8_i8"},
dpType: &types.Int32Type{Nullability: types.NullabilityRequired},
lit: func() (expr.Literal, error) { return expr.NewLiteral(int8(5), false) },
},
{
name: "string where numeric expected",
funcID: addID,
dpType: &types.StringType{Nullability: types.NullabilityRequired},
lit: func() (expr.Literal, error) { return expr.NewLiteral(int32(5), false) },
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := b.ScalarFunc(tt.funcID).Args(
b.DynamicParam(tt.dpType, 0),
b.Wrap(tt.lit()),
).BuildExpr()
require.Error(t, err)
})
}
}
57 changes: 57 additions & 0 deletions expr/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,14 @@ func ExprFromProto(e *proto.Expression, baseSchema *types.RecordType, reg Extens
return nil, fmt.Errorf("%w: nested expression: %s",
substraitgo.ErrInvalidExpr, n)
}
case *proto.Expression_DynamicParameter:
if et.DynamicParameter == nil {
return nil, fmt.Errorf("%w: dynamic parameter is nil", substraitgo.ErrInvalidExpr)
}
return &DynamicParameter{
OutputType: types.TypeFromProto(et.DynamicParameter.Type),
ParameterReference: et.DynamicParameter.ParameterReference,
}, nil
case *proto.Expression_Enum_:
return nil, fmt.Errorf("%w: deprecated", substraitgo.ErrNotImplemented)
case *proto.Expression_Subquery_:
Expand Down Expand Up @@ -372,6 +380,7 @@ type VisitFunc func(Expression) Expression
// - A Cast expression
// - A Subquery
// - A Nested expression
// - A Dynamic Parameter
type Expression interface {
// an Expression can also be a function argument
types.FuncArg
Expand Down Expand Up @@ -654,6 +663,54 @@ func (ex *Cast) Visit(visit VisitFunc) Expression {
return &newCast
}

type DynamicParameter struct {
OutputType types.Type
ParameterReference uint32
}

func (dp *DynamicParameter) String() string {
return fmt.Sprintf("$%d:%s", dp.ParameterReference, dp.OutputType)
}

func (dp *DynamicParameter) ToProtoFuncArg() *proto.FunctionArgument {
return &proto.FunctionArgument{
ArgType: &proto.FunctionArgument_Value{
Value: dp.ToProto(),
},
}
}

func (dp *DynamicParameter) isRootRef() {}

func (dp *DynamicParameter) IsScalar() bool { return true }

func (dp *DynamicParameter) GetType() types.Type { return dp.OutputType }

func (dp *DynamicParameter) ToProto() *proto.Expression {
return &proto.Expression{
RexType: &proto.Expression_DynamicParameter{
DynamicParameter: &proto.DynamicParameter{
Type: types.TypeToProto(dp.OutputType),
ParameterReference: dp.ParameterReference,
},
},
}
}

func (dp *DynamicParameter) Equals(other Expression) bool {
rhs, ok := other.(*DynamicParameter)
if !ok {
return false
}

return dp.ParameterReference == rhs.ParameterReference &&
dp.OutputType.Equals(rhs.OutputType)
}

func (dp *DynamicParameter) Visit(visit VisitFunc) Expression {
return dp
}

type SwitchExpr struct {
match Expression
ifs []struct {
Expand Down
27 changes: 22 additions & 5 deletions plan/builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ type Builder interface {
// that may be in use with this plan for advanced extensions, optimizations,
// and so on.
PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs []string, others ...Rel) (*Plan, error)
// PlanWithBindings is the same as PlanWithTypes, but additionally accepts
// dynamic parameter bindings. These bind DynamicParameter expressions in
// the plan tree to concrete literal values at runtime.
PlanWithBindings(root Rel, rootNames []string, expectedTypeURLs []string, bindings []DynamicParameterBinding, others ...Rel) (*Plan, error)

// GetExprBuilder returns an expr.ExprBuilder that shares the extension
// registry that this Builder uses.
Expand Down Expand Up @@ -739,6 +743,14 @@ func (b *builder) Set(op SetOp, inputs ...Rel) (*SetRel, error) {
}

func (b *builder) PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs []string, others ...Rel) (*Plan, error) {
return b.PlanWithBindings(root, rootNames, expectedTypeURLs, nil, others...)
}

// PlanWithBindings constructs a new plan with the provided root relation,
// expected type URLs, and dynamic parameter bindings. This allows creating
// parameterized plans where expressions contain DynamicParameter references
// that are bound to concrete values.
func (b *builder) PlanWithBindings(root Rel, rootNames []string, expectedTypeURLs []string, bindings []DynamicParameterBinding, others ...Rel) (*Plan, error) {
if root == nil {
return nil, fmt.Errorf("%w: must provide non-nil root relation for plan",
substraitgo.ErrInvalidRel)
Expand All @@ -748,6 +760,10 @@ func (b *builder) PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs [
return nil, err
}

if err := ValidateParameterBindings(root, bindings); err != nil {
return nil, err
}

relations := make([]Relation, len(others)+1)
relations[0].root = &Root{
input: root, names: rootNames,
Expand All @@ -758,11 +774,12 @@ func (b *builder) PlanWithTypes(root Rel, rootNames []string, expectedTypeURLs [
}

return &Plan{
version: &CurrentVersion,
extensions: b.extSet,
reg: b.reg,
expectedTypeURLs: expectedTypeURLs,
relations: relations,
version: &CurrentVersion,
extensions: b.extSet,
reg: b.reg,
expectedTypeURLs: expectedTypeURLs,
relations: relations,
parameterBindings: bindings,
}, nil
}

Expand Down
Loading
Loading