diff --git a/extensions/variants_test.go b/extensions/variants_test.go index f5fc2c5..ec585f7 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -25,6 +25,7 @@ func TestEvaluateTypeExpression(t *testing.T) { strNull, _ = parser.ParseType("string?") strNonNull, _ = parser.ParseType("string") any1NonNull, _ = parser.ParseType("any1") + any1Nullable, _ = parser.ParseType("any1?") any1listNonNull = mkFuncArgList(any1NonNull) // Few shortcut type definitions. @@ -89,6 +90,17 @@ func TestEvaluateTypeExpression(t *testing.T) { args: []types.Type{i64TypeReq, i64TypeReq}, expected: i64TypeReq, }, + { + // nullif with DECLARED_OUTPUT: return type is any1? (nullable). + // Even when both arguments are non-nullable, the result must be + // nullable because NULLIF returns NULL when args are equal. + name: "nullif(any1, any1) -> any1? [DECLARED_OUTPUT]", + nulls: extensions.DeclaredOutputNullability, + ret: any1Nullable, + extArgs: extensions.FuncParameterList{valArg(any1NonNull), valArg(any1NonNull)}, + args: []types.Type{i64TypeReq, i64TypeReq}, + expected: &types.Int64Type{Nullability: types.NullabilityNullable}, + }, { name: "element_at(list, i64) -> any1", nulls: extensions.DeclaredOutputNullability, diff --git a/types/any_type.go b/types/any_type.go index 4f11e29..4b065b7 100644 --- a/types/any_type.go +++ b/types/any_type.go @@ -128,6 +128,16 @@ func (m *AnyType) ReturnType(funcParameters []FuncDefArgType, argumentTypes []Ty return nil, err } if typ != nil { + // Apply the declared nullability from the return type definition. + // For example, nullif declares `return: any1?` — the `?` means + // the result is always nullable regardless of argument nullability. + // + // TODO: It is unclear if a declared non-nullable return type (e.g. `any1`) + // should override a nullable argument type to produce a non-nullable result. + // This is being resolved in https://github.com/substrait-io/substrait/issues/943 + if m.Nullability == NullabilityNullable { + typ = typ.WithNullability(NullabilityNullable) + } return typ, nil } } diff --git a/types/any_type_test.go b/types/any_type_test.go index cce888b..49eee57 100644 --- a/types/any_type_test.go +++ b/types/any_type_test.go @@ -25,7 +25,7 @@ func TestAnyType(t *testing.T) { argName: "any", parameters: []FuncDefArgType{&AnyType{Name: "any"}}, args: []Type{decP30S9}, - concreteReturnType: decP30S9, + concreteReturnType: &DecimalType{Precision: 30, Scale: 9, Nullability: NullabilityNullable}, nullability: NullabilityNullable, expectedString: "any?", }, @@ -325,7 +325,7 @@ func TestAnyType(t *testing.T) { argName: "any1", parameters: []FuncDefArgType{&AnyType{Name: "any1"}, &Int32Type{}}, args: []Type{varchar37, &Int32Type{}}, - concreteReturnType: varchar37, + concreteReturnType: &VarCharType{Length: 37, Nullability: NullabilityNullable}, nullability: NullabilityNullable, expectedString: "any1?", },