Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions testcases/parser/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,28 @@ const (
WindowFuncType TestFuncType = "window"
)

// TestCaseResult is the expected result of a test case. It is either a
// *CaseLiteral (a concrete value or NULL) or a NonValueOutcome (<!ERROR> or <!UNDEFINED>).
type TestCaseResult interface {
String() string
// isTestCaseResult restricts TestCaseResult to types in this package.
isTestCaseResult()
}

// CaseLiteral is a typed literal value used as a function argument or expected
// result in a Substrait test case (e.g. "120::i8" or "NULL").
type CaseLiteral struct {
Type types.Type
ValueText string
Value types.FuncArg
SubstraitError *SubstraitError
// Type is the Substrait type of the literal (e.g. i8, fp64, string).
Type types.Type
// ValueText is the original text representation of the value as it appeared in the test case.
ValueText string
// Value is the parsed literal value or enum argument. Nil for NULL literals.
Value types.FuncArg
}

func (*CaseLiteral) isTestCaseResult() {}

func (c *CaseLiteral) String() string {
if c.SubstraitError != nil {
return c.SubstraitError.String()
}
if c.Value == nil {
return "NULL"
}
Expand Down Expand Up @@ -70,9 +81,6 @@ func literalToString(literal expr.Literal) string {
}

func (c *CaseLiteral) AsAggregateArgumentString() string {
if c.SubstraitError != nil {
return c.SubstraitError.String()
}
if list, ok := c.Value.(*expr.ListLiteral); ok {
var elements []string
for _, element := range list.Value {
Expand Down Expand Up @@ -118,7 +126,7 @@ type TestCase struct {
FuncName string
Args []*CaseLiteral
AggregateArgs []*AggregateArgument
Result *CaseLiteral
Result TestCaseResult
Options FuncOptions
Columns [][]expr.Literal
TableName string
Expand Down Expand Up @@ -486,10 +494,26 @@ type CompactAggregateFuncCall struct {
AggregateArgs []*AggregateArgument
}

type SubstraitError struct {
Error string
}
// NonValueOutcome represents a test case result that is not a concrete value.
// Per the Substrait test case spec, a result may be <!ERROR> (the operation
// must fail) or <!UNDEFINED> (the operation may return any value).
type NonValueOutcome int

func (e SubstraitError) String() string {
return "<!" + e.Error + ">"
func (NonValueOutcome) isTestCaseResult() {}

const (
// NonValueError indicates the function is expected to fail with an error.
NonValueError NonValueOutcome = iota + 1
// NonValueUndefined indicates the result is implementation-defined; any value is acceptable.
NonValueUndefined
)

func (e NonValueOutcome) String() string {
switch e {
case NonValueError:
return "<!ERROR>"
case NonValueUndefined:
return "<!UNDEFINED>"
}
return "<!UNKNOWN>"
}
54 changes: 33 additions & 21 deletions testcases/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ import (
"github.com/substrait-io/substrait-go/v8/types"
)

func resultLiteral(t *testing.T, r TestCaseResult) *CaseLiteral {
t.Helper()
lit, ok := r.(*CaseLiteral)
require.True(t, ok, "expected *CaseLiteral, got %T", r)
return lit
}

func TestNonValueOutcomeString(t *testing.T) {
assert.Equal(t, "<!ERROR>", NonValueError.String())
assert.Equal(t, "<!UNDEFINED>", NonValueUndefined.String())
}

func makeHeader(version, include string) string {
return fmt.Sprintf("### SUBSTRAIT_SCALAR_TEST: %s\n### SUBSTRAIT_INCLUDE: '%s'\n\n", version, include)
}
Expand Down Expand Up @@ -93,8 +105,8 @@ lt('2016-12-31T13:30:15'::ts, '2017-12-31T13:30:15'::ts) = true::bool
require.NoError(t, err)
assert.Equal(t, tsLiteral, testFile.TestCases[0].Args[1].Value)
boolLiteral := literal.NewBool(true, false)
assert.Equal(t, boolLiteral, testFile.TestCases[0].Result.Value)
assert.Equal(t, &types.BooleanType{Nullability: types.NullabilityRequired}, testFile.TestCases[0].Result.Type)
assert.Equal(t, boolLiteral, resultLiteral(t, testFile.TestCases[0].Result).Value)
assert.Equal(t, &types.BooleanType{Nullability: types.NullabilityRequired}, resultLiteral(t, testFile.TestCases[0].Result).Type)
timestampType := &types.TimestampType{Nullability: types.NullabilityRequired}
assert.Equal(t, timestampType, testFile.TestCases[0].Args[0].Type)
assert.Equal(t, timestampType, testFile.TestCases[0].Args[1].Type)
Expand Down Expand Up @@ -134,13 +146,13 @@ add(0.5::dec<1, 1>, 0.25::dec<2, 2>) = 0.75::dec<5, 2>
f641 := literal.NewFloat64(1, false)
assert.Equal(t, dec8, testFile.TestCases[0].Args[0].Value)
assert.Equal(t, dec2, testFile.TestCases[0].Args[1].Value)
assert.Equal(t, f6464, testFile.TestCases[0].Result.Value)
assert.Equal(t, f6464, resultLiteral(t, testFile.TestCases[0].Result).Value)
assert.Equal(t, dec1, testFile.TestCases[1].Args[0].Value)
assert.Equal(t, decMinus1Point0, testFile.TestCases[1].Args[1].Value)
assert.Equal(t, f641, testFile.TestCases[1].Result.Value)
assert.Equal(t, f641, resultLiteral(t, testFile.TestCases[1].Result).Value)
assert.Equal(t, decMinus1, testFile.TestCases[2].Args[0].Value)
assert.Equal(t, decPoint5, testFile.TestCases[2].Args[1].Value)
assert.Equal(t, "fp64(NaN)", testFile.TestCases[2].Result.Value.String())
assert.Equal(t, "fp64(NaN)", resultLiteral(t, testFile.TestCases[2].Result).Value.String())

decPoint25Value, _ := literal.NewDecimalFromString("0.25", false)
decPoint75Value, _ := literal.NewDecimalFromString("0.75", false)
Expand All @@ -149,7 +161,7 @@ add(0.5::dec<1, 1>, 0.25::dec<2, 2>) = 0.75::dec<5, 2>
decPoint5, _ = decPoint5Value.(expr.WithTypeLiteral).WithType(&types.DecimalType{Precision: 1, Scale: 1, Nullability: types.NullabilityRequired})
assert.Equal(t, decPoint5, testFile.TestCases[3].Args[0].Value)
assert.Equal(t, decPoint25, testFile.TestCases[3].Args[1].Value)
assert.Equal(t, decPoint75, testFile.TestCases[3].Result.Value)
assert.Equal(t, decPoint75, resultLiteral(t, testFile.TestCases[3].Result).Value)
}

func TestParseTestWithVariousTypes(t *testing.T) {
Expand Down Expand Up @@ -218,9 +230,9 @@ func TestParseTestWithVariousTypes(t *testing.T) {
t.Errorf("unexpected arg value type %T", v)
}
}
resultLit, ok := testFile.TestCases[0].Result.Value.(expr.Literal)
require.True(t, ok, "result should be a literal, got %T", testFile.TestCases[0].Result.Value)
checkNullability(t, resultLit, testFile.TestCases[0].Result.Type)
resultLit, ok := resultLiteral(t, testFile.TestCases[0].Result).Value.(expr.Literal)
require.True(t, ok, "result should be a literal, got %T", resultLiteral(t, testFile.TestCases[0].Result).Value)
checkNullability(t, resultLit, resultLiteral(t, testFile.TestCases[0].Result).Type)
})
}
}
Expand Down Expand Up @@ -257,7 +269,7 @@ starts_with('abcd'::str, 'AB'::str) [case_sensitivity:CASE_INSENSITIVE] = true::
strRes := literal.NewString("abcdef", false)
assert.Equal(t, strAbc, testFile.TestCases[0].Args[0].Value)
assert.Equal(t, strDef, testFile.TestCases[0].Args[1].Value)
assert.Equal(t, strRes, testFile.TestCases[0].Result.Value)
assert.Equal(t, strRes, resultLiteral(t, testFile.TestCases[0].Result).Value)

strArg1 := literal.NewString("HHHelloooo", false)
strArg2 := literal.NewString("Hel+", false)
Expand All @@ -267,17 +279,17 @@ starts_with('abcd'::str, 'AB'::str) [case_sensitivity:CASE_INSENSITIVE] = true::
strRes1 := literal.NewString("HH", false)
strRes2 := literal.NewString("oooo", false)
result, _ := literal.NewList([]expr.Literal{strRes1, strRes2}, false)
assert.Equal(t, result, testFile.TestCases[1].Result.Value)
assert.Equal(t, result, resultLiteral(t, testFile.TestCases[1].Result).Value)

str1 := literal.NewString("à", false)
i642 := literal.NewInt64(2, false)
assert.Equal(t, str1, testFile.TestCases[2].Args[0].Value)
assert.Equal(t, i642, testFile.TestCases[2].Result.Value)
assert.Equal(t, i642, resultLiteral(t, testFile.TestCases[2].Result).Value)

str2 := literal.NewString("😄", false)
i644 := literal.NewInt64(4, false)
assert.Equal(t, str2, testFile.TestCases[3].Args[0].Value)
assert.Equal(t, i644, testFile.TestCases[3].Result.Value)
assert.Equal(t, i644, resultLiteral(t, testFile.TestCases[3].Result).Value)

}

Expand All @@ -301,8 +313,8 @@ some_func('abc'::str, 'def'::str) = [1, 2, 3, 4, 5, 6]::List<i8>`
literal.NewInt8(1, false), literal.NewInt8(2, false), literal.NewInt8(3, false),
literal.NewInt8(4, false), literal.NewInt8(5, false), literal.NewInt8(6, false),
}, false)
assert.Equal(t, list, testFile.TestCases[0].Result.Value)
assert.Equal(t, i8List, testFile.TestCases[0].Result.Type)
assert.Equal(t, list, resultLiteral(t, testFile.TestCases[0].Result).Value)
assert.Equal(t, i8List, resultLiteral(t, testFile.TestCases[0].Result).Type)
}

func TestParseNestedListLiteral(t *testing.T) {
Expand Down Expand Up @@ -439,8 +451,8 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ER
lit, ok := tc.AggregateArgs[0].Argument.Value.(expr.Literal)
require.True(t, ok, "aggregate arg should be a literal, got %T", tc.AggregateArgs[0].Argument.Value)
assert.Equal(t, listType, lit.GetType())
assert.Equal(t, "fp64", tc.Result.Type.String())
assert.Equal(t, literal.NewFloat64(2, false), tc.Result.Value)
assert.Equal(t, "fp64", resultLiteral(t, tc.Result).Type.String())
assert.Equal(t, literal.NewFloat64(2, false), resultLiteral(t, tc.Result).Value)
assert.Equal(t, AggregateFuncType, tc.FuncType)
_, err = tc.GetScalarFunctionInvocation(nil, nil)
require.Error(t, err)
Expand All @@ -466,7 +478,7 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ER
require.Equal(t, 1, aggregateFunc.NArgs())
assert.Equal(t, &types.Float32Type{Nullability: types.NullabilityNullable}, aggArg.GetType())
assert.Equal(t, []types.Type{&types.Float32Type{Nullability: types.NullabilityNullable}}, tc.GetArgTypes())
assert.Equal(t, tc.Result.Type, &types.Float64Type{Nullability: types.NullabilityNullable})
assert.Equal(t, resultLiteral(t, tc.Result).Type, &types.Float64Type{Nullability: types.NullabilityNullable})
argValues := newFloat32Values(true, 1, 2, 3)
argValues = append(argValues, &expr.NullLiteral{Type: &types.Float32Type{Nullability: types.NullabilityNullable}})
argList, _ := literal.NewList(argValues, false)
Expand Down Expand Up @@ -572,14 +584,14 @@ func TestParseAggregateFuncCompact(t *testing.T) {
createAggregateArg(t, "", "col1", f32Type),
}
assert.Equal(t, args, tc.AggregateArgs)
assert.Equal(t, "fp64", tc.Result.Type.String())
assert.Equal(t, literal.NewFloat64(1, false), tc.Result.Value)
assert.Equal(t, "fp64", resultLiteral(t, tc.Result).Type.String())
assert.Equal(t, literal.NewFloat64(1, false), resultLiteral(t, tc.Result).Value)
assert.Equal(t, testString, tc.String())

tc = testFile.TestCases[1]
args[1] = createAggregateArg(t, "", "col1", f32Type.WithNullability(types.NullabilityNullable))
assert.Equal(t, args, tc.AggregateArgs)
assert.Equal(t, "fp64?", tc.Result.Type.String())
assert.Equal(t, "fp64?", resultLiteral(t, tc.Result).Type.String())
}

func createAggregateArg(t *testing.T, tableName, columnName string, columnType types.Type) *AggregateArgument {
Expand Down
19 changes: 6 additions & 13 deletions testcases/parser/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (v *TestCaseVisitor) VisitAggregateFuncTestGroup(ctx *baseparser.AggregateF

func (v *TestCaseVisitor) VisitAggFuncTestCase(ctx *baseparser.AggFuncTestCaseContext) interface{} {
testcase := v.Visit(ctx.AggFuncCall()).(*TestCase)
testcase.Result = v.Visit(ctx.Result()).(*CaseLiteral)
testcase.Result = v.Visit(ctx.Result()).(TestCaseResult)
if ctx.FuncOptions() != nil {
testcase.Options = v.Visit(ctx.FuncOptions()).(FuncOptions)
}
Expand All @@ -136,7 +136,7 @@ func (v *TestCaseVisitor) VisitSingleArgAggregateFuncCall(ctx *baseparser.Single
return &TestCase{
FuncName: ctx.Identifier().GetText(),
AggregateArgs: []*AggregateArgument{{Argument: arg, ColumnType: arg.Type}},
Result: &CaseLiteral{SubstraitError: &SubstraitError{Error: "uninitialized"}},
Result: &CaseLiteral{},
}
}

Expand Down Expand Up @@ -342,7 +342,7 @@ func (v *TestCaseVisitor) VisitTestCase(ctx *baseparser.TestCaseContext) interfa
return &TestCase{
FuncName: ctx.Identifier().GetText(),
Args: v.Visit(ctx.Arguments()).([]*CaseLiteral),
Result: v.Visit(ctx.Result()).(*CaseLiteral),
Result: v.Visit(ctx.Result()).(TestCaseResult),
Options: options,
}
}
Expand Down Expand Up @@ -842,17 +842,10 @@ func (v *TestCaseVisitor) VisitResult(ctx *baseparser.ResultContext) interface{}
}

func (v *TestCaseVisitor) VisitSubstraitError(ctx *baseparser.SubstraitErrorContext) interface{} {
err := &SubstraitError{Error: "UNKNOWN"}
if ctx.ErrorResult() != nil {
err.Error = "ERROR"
} else if ctx.UndefineResult() != nil {
err.Error = "UNDEFINED"
}
return &CaseLiteral{
Type: nil,
ValueText: ctx.GetText(),
SubstraitError: err,
if ctx.UndefineResult() != nil {
return NonValueUndefined
}
return NonValueError
}

func (v *TestCaseVisitor) VisitBoolean(ctx *baseparser.BooleanContext) interface{} {
Expand Down
Loading