diff --git a/testcases/parser/nodes.go b/testcases/parser/nodes.go index f005f9d..cd02876 100644 --- a/testcases/parser/nodes.go +++ b/testcases/parser/nodes.go @@ -20,17 +20,28 @@ const ( WindowFuncType TestFuncType = "window" ) +// TestCaseResult is the expected result of a test case. It is either a +// *CaseLiteral (a concrete value or NULL) or a NonValueOutcome ( or ). +type TestCaseResult interface { + String() string + // isTestCaseResult restricts TestCaseResult to types in this package. + isTestCaseResult() +} + +// CaseLiteral is a typed literal value used as a function argument or expected +// result in a Substrait test case (e.g. "120::i8" or "NULL"). type CaseLiteral struct { - Type types.Type - ValueText string - Value types.FuncArg - SubstraitError *SubstraitError + // Type is the Substrait type of the literal (e.g. i8, fp64, string). + Type types.Type + // ValueText is the original text representation of the value as it appeared in the test case. + ValueText string + // Value is the parsed literal value or enum argument. Nil for NULL literals. + Value types.FuncArg } +func (*CaseLiteral) isTestCaseResult() {} + func (c *CaseLiteral) String() string { - if c.SubstraitError != nil { - return c.SubstraitError.String() - } if c.Value == nil { return "NULL" } @@ -70,9 +81,6 @@ func literalToString(literal expr.Literal) string { } func (c *CaseLiteral) AsAggregateArgumentString() string { - if c.SubstraitError != nil { - return c.SubstraitError.String() - } if list, ok := c.Value.(*expr.ListLiteral); ok { var elements []string for _, element := range list.Value { @@ -118,7 +126,7 @@ type TestCase struct { FuncName string Args []*CaseLiteral AggregateArgs []*AggregateArgument - Result *CaseLiteral + Result TestCaseResult Options FuncOptions Columns [][]expr.Literal TableName string @@ -486,10 +494,26 @@ type CompactAggregateFuncCall struct { AggregateArgs []*AggregateArgument } -type SubstraitError struct { - Error string -} +// NonValueOutcome represents a test case result that is not a concrete value. +// Per the Substrait test case spec, a result may be (the operation +// must fail) or (the operation may return any value). +type NonValueOutcome int -func (e SubstraitError) String() string { - return "" +func (NonValueOutcome) isTestCaseResult() {} + +const ( + // NonValueError indicates the function is expected to fail with an error. + NonValueError NonValueOutcome = iota + 1 + // NonValueUndefined indicates the result is implementation-defined; any value is acceptable. + NonValueUndefined +) + +func (e NonValueOutcome) String() string { + switch e { + case NonValueError: + return "" + case NonValueUndefined: + return "" + } + return "" } diff --git a/testcases/parser/parse_test.go b/testcases/parser/parse_test.go index 7e77056..a845c83 100644 --- a/testcases/parser/parse_test.go +++ b/testcases/parser/parse_test.go @@ -16,6 +16,18 @@ import ( "github.com/substrait-io/substrait-go/v8/types" ) +func resultLiteral(t *testing.T, r TestCaseResult) *CaseLiteral { + t.Helper() + lit, ok := r.(*CaseLiteral) + require.True(t, ok, "expected *CaseLiteral, got %T", r) + return lit +} + +func TestNonValueOutcomeString(t *testing.T) { + assert.Equal(t, "", NonValueError.String()) + assert.Equal(t, "", NonValueUndefined.String()) +} + func makeHeader(version, include string) string { return fmt.Sprintf("### SUBSTRAIT_SCALAR_TEST: %s\n### SUBSTRAIT_INCLUDE: '%s'\n\n", version, include) } @@ -93,8 +105,8 @@ lt('2016-12-31T13:30:15'::ts, '2017-12-31T13:30:15'::ts) = true::bool require.NoError(t, err) assert.Equal(t, tsLiteral, testFile.TestCases[0].Args[1].Value) boolLiteral := literal.NewBool(true, false) - assert.Equal(t, boolLiteral, testFile.TestCases[0].Result.Value) - assert.Equal(t, &types.BooleanType{Nullability: types.NullabilityRequired}, testFile.TestCases[0].Result.Type) + assert.Equal(t, boolLiteral, resultLiteral(t, testFile.TestCases[0].Result).Value) + assert.Equal(t, &types.BooleanType{Nullability: types.NullabilityRequired}, resultLiteral(t, testFile.TestCases[0].Result).Type) timestampType := &types.TimestampType{Nullability: types.NullabilityRequired} assert.Equal(t, timestampType, testFile.TestCases[0].Args[0].Type) assert.Equal(t, timestampType, testFile.TestCases[0].Args[1].Type) @@ -134,13 +146,13 @@ add(0.5::dec<1, 1>, 0.25::dec<2, 2>) = 0.75::dec<5, 2> f641 := literal.NewFloat64(1, false) assert.Equal(t, dec8, testFile.TestCases[0].Args[0].Value) assert.Equal(t, dec2, testFile.TestCases[0].Args[1].Value) - assert.Equal(t, f6464, testFile.TestCases[0].Result.Value) + assert.Equal(t, f6464, resultLiteral(t, testFile.TestCases[0].Result).Value) assert.Equal(t, dec1, testFile.TestCases[1].Args[0].Value) assert.Equal(t, decMinus1Point0, testFile.TestCases[1].Args[1].Value) - assert.Equal(t, f641, testFile.TestCases[1].Result.Value) + assert.Equal(t, f641, resultLiteral(t, testFile.TestCases[1].Result).Value) assert.Equal(t, decMinus1, testFile.TestCases[2].Args[0].Value) assert.Equal(t, decPoint5, testFile.TestCases[2].Args[1].Value) - assert.Equal(t, "fp64(NaN)", testFile.TestCases[2].Result.Value.String()) + assert.Equal(t, "fp64(NaN)", resultLiteral(t, testFile.TestCases[2].Result).Value.String()) decPoint25Value, _ := literal.NewDecimalFromString("0.25", false) decPoint75Value, _ := literal.NewDecimalFromString("0.75", false) @@ -149,7 +161,7 @@ add(0.5::dec<1, 1>, 0.25::dec<2, 2>) = 0.75::dec<5, 2> decPoint5, _ = decPoint5Value.(expr.WithTypeLiteral).WithType(&types.DecimalType{Precision: 1, Scale: 1, Nullability: types.NullabilityRequired}) assert.Equal(t, decPoint5, testFile.TestCases[3].Args[0].Value) assert.Equal(t, decPoint25, testFile.TestCases[3].Args[1].Value) - assert.Equal(t, decPoint75, testFile.TestCases[3].Result.Value) + assert.Equal(t, decPoint75, resultLiteral(t, testFile.TestCases[3].Result).Value) } func TestParseTestWithVariousTypes(t *testing.T) { @@ -218,9 +230,9 @@ func TestParseTestWithVariousTypes(t *testing.T) { t.Errorf("unexpected arg value type %T", v) } } - resultLit, ok := testFile.TestCases[0].Result.Value.(expr.Literal) - require.True(t, ok, "result should be a literal, got %T", testFile.TestCases[0].Result.Value) - checkNullability(t, resultLit, testFile.TestCases[0].Result.Type) + resultLit, ok := resultLiteral(t, testFile.TestCases[0].Result).Value.(expr.Literal) + require.True(t, ok, "result should be a literal, got %T", resultLiteral(t, testFile.TestCases[0].Result).Value) + checkNullability(t, resultLit, resultLiteral(t, testFile.TestCases[0].Result).Type) }) } } @@ -257,7 +269,7 @@ starts_with('abcd'::str, 'AB'::str) [case_sensitivity:CASE_INSENSITIVE] = true:: strRes := literal.NewString("abcdef", false) assert.Equal(t, strAbc, testFile.TestCases[0].Args[0].Value) assert.Equal(t, strDef, testFile.TestCases[0].Args[1].Value) - assert.Equal(t, strRes, testFile.TestCases[0].Result.Value) + assert.Equal(t, strRes, resultLiteral(t, testFile.TestCases[0].Result).Value) strArg1 := literal.NewString("HHHelloooo", false) strArg2 := literal.NewString("Hel+", false) @@ -267,17 +279,17 @@ starts_with('abcd'::str, 'AB'::str) [case_sensitivity:CASE_INSENSITIVE] = true:: strRes1 := literal.NewString("HH", false) strRes2 := literal.NewString("oooo", false) result, _ := literal.NewList([]expr.Literal{strRes1, strRes2}, false) - assert.Equal(t, result, testFile.TestCases[1].Result.Value) + assert.Equal(t, result, resultLiteral(t, testFile.TestCases[1].Result).Value) str1 := literal.NewString("à", false) i642 := literal.NewInt64(2, false) assert.Equal(t, str1, testFile.TestCases[2].Args[0].Value) - assert.Equal(t, i642, testFile.TestCases[2].Result.Value) + assert.Equal(t, i642, resultLiteral(t, testFile.TestCases[2].Result).Value) str2 := literal.NewString("😄", false) i644 := literal.NewInt64(4, false) assert.Equal(t, str2, testFile.TestCases[3].Args[0].Value) - assert.Equal(t, i644, testFile.TestCases[3].Result.Value) + assert.Equal(t, i644, resultLiteral(t, testFile.TestCases[3].Result).Value) } @@ -301,8 +313,8 @@ some_func('abc'::str, 'def'::str) = [1, 2, 3, 4, 5, 6]::List` literal.NewInt8(1, false), literal.NewInt8(2, false), literal.NewInt8(3, false), literal.NewInt8(4, false), literal.NewInt8(5, false), literal.NewInt8(6, false), }, false) - assert.Equal(t, list, testFile.TestCases[0].Result.Value) - assert.Equal(t, i8List, testFile.TestCases[0].Result.Type) + assert.Equal(t, list, resultLiteral(t, testFile.TestCases[0].Result).Value) + assert.Equal(t, i8List, resultLiteral(t, testFile.TestCases[0].Result).Type) } func TestParseNestedListLiteral(t *testing.T) { @@ -439,8 +451,8 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] =