diff --git a/core/src/main/java/io/substrait/plan/Plan.java b/core/src/main/java/io/substrait/plan/Plan.java index 6436993f4..2f4deb8f0 100644 --- a/core/src/main/java/io/substrait/plan/Plan.java +++ b/core/src/main/java/io/substrait/plan/Plan.java @@ -3,9 +3,12 @@ import io.substrait.SubstraitVersion; import io.substrait.extension.AdvancedExtension; import io.substrait.relation.Rel; +import io.substrait.type.NamedFieldCountingTypeVisitor; import java.util.List; import java.util.Optional; import org.immutables.value.Value; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @Value.Immutable public abstract class Plan { @@ -61,10 +64,31 @@ private static Version loadVersion() { @Value.Immutable public abstract static class Root { + private static final Logger LOGGER = LoggerFactory.getLogger(Root.class); + public abstract Rel getInput(); public abstract List getNames(); + @Value.Check + protected void check() { + final int actualNameCount = getNames().size(); + if (actualNameCount == 0) { + LOGGER.warn( + "Plan.Root built without output names; this will be an error in the next release"); + return; + } + + final int expectedFieldCount = + NamedFieldCountingTypeVisitor.countNames(getInput().getRecordType()); + if (actualNameCount != expectedFieldCount) { + throw new IllegalArgumentException( + String.format( + "Plan.Root names count (%d) must match input record type depth-first named-field count (%d)", + actualNameCount, expectedFieldCount)); + } + } + public static ImmutableRoot.Builder builder() { return ImmutableRoot.builder(); } diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index 36e5b6dcc..28646a07f 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -1,8 +1,8 @@ package io.substrait.relation; import io.substrait.expression.Expression; +import io.substrait.type.NamedFieldCountingTypeVisitor; import io.substrait.type.Type; -import io.substrait.type.TypeVisitor; import io.substrait.util.VisitationContext; import java.util.List; import java.util.Objects; @@ -90,162 +90,4 @@ public O accept( public static ImmutableVirtualTableScan.Builder builder() { return ImmutableVirtualTableScan.builder(); } - - private static class NamedFieldCountingTypeVisitor - implements TypeVisitor { - - private static final NamedFieldCountingTypeVisitor VISITOR = - new NamedFieldCountingTypeVisitor(); - - private static Integer countNames(Type type) { - return type.accept(VISITOR); - } - - @Override - public Integer visit(Type.Bool type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.I8 type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.I16 type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.I32 type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.I64 type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.FP32 type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.FP64 type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.Str type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.Binary type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.Date type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.Time type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.TimestampTZ type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.Timestamp type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.PrecisionTimestamp type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.PrecisionTime type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.PrecisionTimestampTZ type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.IntervalYear type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.IntervalDay type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.IntervalCompound type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.UUID type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.FixedChar type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.VarChar type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.FixedBinary type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.Decimal type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.Struct type) throws RuntimeException { - // Only struct fields have names - the top level column names are also - // captured by this since the whole schema is wrapped in a Struct type - return type.fields().stream().mapToInt(field -> 1 + field.accept(this)).sum(); - } - - @Override - public Integer visit(Type.ListType type) throws RuntimeException { - return type.elementType().accept(this); - } - - @Override - public Integer visit(Type.Map type) throws RuntimeException { - return type.key().accept(this) + type.value().accept(this); - } - - @Override - public Integer visit(Type.UserDefined type) throws RuntimeException { - return 0; - } - - @Override - public Integer visit(Type.Func type) throws RuntimeException { - return 0; - } - } } diff --git a/core/src/main/java/io/substrait/type/NamedFieldCountingTypeVisitor.java b/core/src/main/java/io/substrait/type/NamedFieldCountingTypeVisitor.java new file mode 100644 index 000000000..4aa689869 --- /dev/null +++ b/core/src/main/java/io/substrait/type/NamedFieldCountingTypeVisitor.java @@ -0,0 +1,192 @@ +package io.substrait.type; + +/** + * Counts the number of field names required for a {@link Type} using Substrait's depth-first naming + * rules. + * + *

This is the same counting scheme used by {@link NamedStruct#names()}: top-level struct fields + * contribute one name each, and nested struct fields inside structs, lists, and maps also + * contribute names in depth-first order. Scalar types and other non-structural types do not + * contribute additional names. + * + *

Examples: + * + *

    + *
  • {@code struct} requires 2 names + *
  • {@code list>} requires 2 names + *
  • {@code map, struct>} requires 5 names + *
+ * + *

This utility is used anywhere the library needs to validate or reason about name counts + * without carrying the names themselves, such as {@code Plan.Root} and {@code VirtualTableScan} + * validation. + */ +public final class NamedFieldCountingTypeVisitor implements TypeVisitor { + + private static final NamedFieldCountingTypeVisitor VISITOR = new NamedFieldCountingTypeVisitor(); + + private NamedFieldCountingTypeVisitor() {} + + /** + * Returns the number of names required to describe {@code type} in Substrait's depth-first naming + * order. + * + *

For a top-level struct, this includes both the top-level field names and any nested struct + * field names required by compound child types. + * + * @param type the type to inspect + * @return the number of required names + */ + public static int countNames(Type type) { + return type.accept(VISITOR); + } + + @Override + public Integer visit(Type.Bool type) { + return 0; + } + + @Override + public Integer visit(Type.I8 type) { + return 0; + } + + @Override + public Integer visit(Type.I16 type) { + return 0; + } + + @Override + public Integer visit(Type.I32 type) { + return 0; + } + + @Override + public Integer visit(Type.I64 type) { + return 0; + } + + @Override + public Integer visit(Type.FP32 type) { + return 0; + } + + @Override + public Integer visit(Type.FP64 type) { + return 0; + } + + @Override + public Integer visit(Type.Str type) { + return 0; + } + + @Override + public Integer visit(Type.Binary type) { + return 0; + } + + @Override + public Integer visit(Type.Date type) { + return 0; + } + + @Override + public Integer visit(Type.Time type) { + return 0; + } + + @Override + public Integer visit(Type.TimestampTZ type) { + return 0; + } + + @Override + public Integer visit(Type.Timestamp type) { + return 0; + } + + @Override + public Integer visit(Type.PrecisionTime type) { + return 0; + } + + @Override + public Integer visit(Type.PrecisionTimestamp type) { + return 0; + } + + @Override + public Integer visit(Type.PrecisionTimestampTZ type) { + return 0; + } + + @Override + public Integer visit(Type.IntervalYear type) { + return 0; + } + + @Override + public Integer visit(Type.IntervalDay type) { + return 0; + } + + @Override + public Integer visit(Type.IntervalCompound type) { + return 0; + } + + @Override + public Integer visit(Type.UUID type) { + return 0; + } + + @Override + public Integer visit(Type.FixedChar type) { + return 0; + } + + @Override + public Integer visit(Type.VarChar type) { + return 0; + } + + @Override + public Integer visit(Type.FixedBinary type) { + return 0; + } + + @Override + public Integer visit(Type.Decimal type) { + return 0; + } + + @Override + public Integer visit(Type.Func type) { + return 0; + } + + @Override + public Integer visit(Type.Struct type) { + // Each struct field contributes its own name, plus any nested names required by that field's + // type. + return type.fields().stream().mapToInt(field -> 1 + countNames(field)).sum(); + } + + @Override + public Integer visit(Type.ListType type) { + // Lists do not add a name themselves, but list elements may contain nested structs. + return countNames(type.elementType()); + } + + @Override + public Integer visit(Type.Map type) { + // Maps do not add names themselves; any required names come from struct keys and/or values. + return countNames(type.key()) + countNames(type.value()); + } + + @Override + public Integer visit(Type.UserDefined type) { + return 0; + } +} diff --git a/core/src/test/java/io/substrait/plan/PlanConverterTest.java b/core/src/test/java/io/substrait/plan/PlanConverterTest.java index ab0a321b4..62ab1d8ee 100644 --- a/core/src/test/java/io/substrait/plan/PlanConverterTest.java +++ b/core/src/test/java/io/substrait/plan/PlanConverterTest.java @@ -9,6 +9,7 @@ import io.substrait.extension.SimpleExtension; import io.substrait.plan.Plan.Root; import io.substrait.relation.ImmutableVirtualTableScan; +import io.substrait.relation.NamedScan; import io.substrait.relation.VirtualTableScan; import io.substrait.type.NamedStruct; import io.substrait.type.Type; @@ -18,9 +19,27 @@ import io.substrait.utils.StringHolderHandlingProtoExtensionConverter; import java.util.Arrays; import java.util.Collections; +import java.util.List; import org.junit.jupiter.api.Test; class PlanConverterTest { + @Test + void rootNamesMustMatchInputFieldCount() { + final NamedScan scan = + NamedScan.builder() + .addNames("test_table") + .initialSchema( + NamedStruct.builder() + .addNames("only_column") + .struct(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I32)) + .build()) + .build(); + + assertThrows( + IllegalArgumentException.class, + () -> Root.builder().input(scan).names(List.of("col1", "col2")).build()); + } + @Test void emptyAdvancedExtensionTest() { final Plan plan = Plan.builder().advancedExtension(AdvancedExtension.builder().build()).build(); @@ -304,7 +323,14 @@ void nestedUserDefinedTypesShareExtensionCollector() { false, nullablePointLiteral, pointLiteral, vectorOfPointLiteral)) .build(); - Plan plan = Plan.builder().addRoots(Root.builder().input(virtualTable).build()).build(); + Plan plan = + Plan.builder() + .addRoots( + Root.builder() + .input(virtualTable) + .names(List.of("nullable_point_col", "point_col", "vector_col")) + .build()) + .build(); PlanProtoConverter toProtoConverter = new PlanProtoConverter(); io.substrait.proto.Plan protoPlan = toProtoConverter.toProto(plan); diff --git a/core/src/test/java/io/substrait/type/NamedFieldCountingTypeVisitorTest.java b/core/src/test/java/io/substrait/type/NamedFieldCountingTypeVisitorTest.java new file mode 100644 index 000000000..c6a40bdaf --- /dev/null +++ b/core/src/test/java/io/substrait/type/NamedFieldCountingTypeVisitorTest.java @@ -0,0 +1,41 @@ +package io.substrait.type; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +class NamedFieldCountingTypeVisitorTest { + private static final TypeCreator R = TypeCreator.REQUIRED; + + @Test + void primitiveTypeHasNoNames() { + assertEquals(0, NamedFieldCountingTypeVisitor.countNames(R.I32)); + } + + @Test + void simpleNamedStructDocExample() { + // Doc example from https://substrait.io/types/named_structs/ "Simple Named Struct". + Type.Struct type = R.struct(R.I64, R.list(R.I64), R.map(R.I64, R.I64), R.I64); + + assertEquals(4, NamedFieldCountingTypeVisitor.countNames(type)); + } + + @Test + void structsInCompoundTypesDocExample() { + // Doc example from https://substrait.io/types/named_structs/ "Structs in Compound Types". + Type.Struct type = + R.struct( + R.I64, R.list(R.struct(R.I64, R.I64)), R.map(R.I64, R.struct(R.I64, R.I64)), R.I64); + + assertEquals(8, NamedFieldCountingTypeVisitor.countNames(type)); + } + + @Test + void structsInStructsDocExample() { + // Doc example from https://substrait.io/types/named_structs/ "Structs in Structs". + Type.Struct type = + R.struct(R.I64, R.struct(R.I64, R.struct(R.I64, R.I64), R.I64, R.struct(R.I64, R.I64))); + + assertEquals(10, NamedFieldCountingTypeVisitor.countNames(type)); + } +} diff --git a/core/src/test/resources/plan-roundtrip/complex-expected-plan.json b/core/src/test/resources/plan-roundtrip/complex-expected-plan.json index 9f4c47a18..feae19f72 100644 --- a/core/src/test/resources/plan-roundtrip/complex-expected-plan.json +++ b/core/src/test/resources/plan-roundtrip/complex-expected-plan.json @@ -141,6 +141,7 @@ } }, "names": [ + "dummy", "result", "product" ] diff --git a/core/src/test/resources/plan-roundtrip/complex-input-plan.json b/core/src/test/resources/plan-roundtrip/complex-input-plan.json index 17a1984bd..b7393a572 100644 --- a/core/src/test/resources/plan-roundtrip/complex-input-plan.json +++ b/core/src/test/resources/plan-roundtrip/complex-input-plan.json @@ -126,6 +126,7 @@ } }, "names": [ + "dummy", "result", "product" ] diff --git a/core/src/test/resources/plan-roundtrip/simple-expected-plan.json b/core/src/test/resources/plan-roundtrip/simple-expected-plan.json index 5ae8daa3f..6ad7d4d07 100644 --- a/core/src/test/resources/plan-roundtrip/simple-expected-plan.json +++ b/core/src/test/resources/plan-roundtrip/simple-expected-plan.json @@ -114,6 +114,8 @@ } }, "names": [ + "value1", + "value2", "sum_result" ] } diff --git a/core/src/test/resources/plan-roundtrip/simple-input-plan.json b/core/src/test/resources/plan-roundtrip/simple-input-plan.json index b5f5af06e..441995bee 100644 --- a/core/src/test/resources/plan-roundtrip/simple-input-plan.json +++ b/core/src/test/resources/plan-roundtrip/simple-input-plan.json @@ -99,6 +99,8 @@ } }, "names": [ + "value1", + "value2", "sum_result" ] } diff --git a/core/src/test/resources/plan-roundtrip/zero-anchor-expected-plan.json b/core/src/test/resources/plan-roundtrip/zero-anchor-expected-plan.json index ff6b8ab0e..ce8445ef6 100644 --- a/core/src/test/resources/plan-roundtrip/zero-anchor-expected-plan.json +++ b/core/src/test/resources/plan-roundtrip/zero-anchor-expected-plan.json @@ -86,6 +86,8 @@ } }, "names": [ + "a", + "b", "sum_result" ] } diff --git a/core/src/test/resources/plan-roundtrip/zero-anchor-input-plan.json b/core/src/test/resources/plan-roundtrip/zero-anchor-input-plan.json index 202814110..6c57c4305 100644 --- a/core/src/test/resources/plan-roundtrip/zero-anchor-input-plan.json +++ b/core/src/test/resources/plan-roundtrip/zero-anchor-input-plan.json @@ -73,6 +73,8 @@ } }, "names": [ + "a", + "b", "sum_result" ] }