From 2badbf534c7fb9bb048ca8643b968ad12f72443f Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Wed, 15 Apr 2026 10:47:56 +0100 Subject: [PATCH 1/3] =?UTF-8?q?fix(spark):=20correct=20handling=20of=20?= =?UTF-8?q?=E2=80=98remap=E2=80=99=20property?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When running the substrait to spark converter on plans that were generated by another system, it was found that the emit/outputMappings (remap) property was not being processed correctly. This commit fixes that. Extra tests were added, containing substrait plans in JSON proto format that were generated by a third party system. Signed-off-by: Andrew Coleman --- spark/spark-3.4_2.12/build.gradle.kts | 1 + spark/spark-3.5_2.12/build.gradle.kts | 1 + spark/spark-4.0_2.13/build.gradle.kts | 1 + .../spark/logical/ToLogicalPlan.scala | 67 +- spark/src/test/resources/nba_team.csv | 31 + spark/src/test/resources/nba_team_history.csv | 53 ++ .../substrait_plan_nba_california.json | 822 ++++++++++++++++ .../substrait_plan_with_aggregate_op.json | 893 ++++++++++++++++++ .../src/test/resources/tests_subset_2023.csv | 30 + .../test/resources/vehicles_subset_2023.csv | 31 + .../substrait/spark/SubstraitJsonSuite.scala | 123 +++ 11 files changed, 2035 insertions(+), 18 deletions(-) create mode 100644 spark/src/test/resources/nba_team.csv create mode 100644 spark/src/test/resources/nba_team_history.csv create mode 100644 spark/src/test/resources/substrait_plan_nba_california.json create mode 100644 spark/src/test/resources/substrait_plan_with_aggregate_op.json create mode 100644 spark/src/test/resources/tests_subset_2023.csv create mode 100644 spark/src/test/resources/vehicles_subset_2023.csv create mode 100644 spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala diff --git a/spark/spark-3.4_2.12/build.gradle.kts b/spark/spark-3.4_2.12/build.gradle.kts index f6317b8ad..da15de2a0 100644 --- a/spark/spark-3.4_2.12/build.gradle.kts +++ 
b/spark/spark-3.4_2.12/build.gradle.kts @@ -135,6 +135,7 @@ dependencies { implementation(platform(libs.jackson.bom)) implementation(libs.bundles.jackson) implementation(libs.json.schema.validator) + testImplementation(libs.protobuf.java.util) testImplementation(platform(libs.junit.bom)) testRuntimeOnly(libs.junit.platform.engine) diff --git a/spark/spark-3.5_2.12/build.gradle.kts b/spark/spark-3.5_2.12/build.gradle.kts index 031447f79..7f9e37e87 100644 --- a/spark/spark-3.5_2.12/build.gradle.kts +++ b/spark/spark-3.5_2.12/build.gradle.kts @@ -136,6 +136,7 @@ dependencies { implementation(platform(libs.jackson.bom)) implementation(libs.bundles.jackson) implementation(libs.json.schema.validator) + testImplementation(libs.protobuf.java.util) testImplementation(platform(libs.junit.bom)) testRuntimeOnly(libs.junit.platform.engine) diff --git a/spark/spark-4.0_2.13/build.gradle.kts b/spark/spark-4.0_2.13/build.gradle.kts index cd39d39ca..0f4aca567 100644 --- a/spark/spark-4.0_2.13/build.gradle.kts +++ b/spark/spark-4.0_2.13/build.gradle.kts @@ -136,6 +136,7 @@ dependencies { implementation(platform(libs.jackson.bom)) implementation(libs.bundles.jackson) implementation(libs.json.schema.validator) + testImplementation(libs.protobuf.java.util) testImplementation(platform(libs.junit.bom)) testRuntimeOnly(libs.junit.platform.engine) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 99f43f0f3..0f823cc4f 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -54,6 +54,7 @@ import io.substrait.util.EmptyVisitationContext import org.apache.hadoop.fs.Path import java.net.URI +import java.util.Optional import scala.annotation.nowarn import scala.collection.mutable.ArrayBuffer @@ -141,7 +142,8 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes 
val outputs = groupBy.map(toNamedExpression) val aggregateExpressions = aggregate.getMeasures.asScala.map(fromMeasure).map(toNamedExpression).toSeq - Aggregate(groupBy, outputs ++ aggregateExpressions, child) + val plan = Aggregate(groupBy, outputs ++ aggregateExpressions, child) + remap(plan, aggregate.getRemap) } } @@ -199,7 +201,8 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes }) .map(toNamedExpression(_)) .toSeq - Window(windowExpressions, partitions, sortOrders, child) + val plan = Window(windowExpressions, partitions, sortOrders, child) + remap(plan, window.getRemap) } } @@ -225,7 +228,8 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes case other => throw new UnsupportedOperationException(s"Unsupported join type $other") } - Join(left, right, joinType, condition, hint = JoinHint.NONE) + val plan = Join(left, right, joinType, condition, hint = JoinHint.NONE) + remap(plan, join.getRemap) } } @@ -235,7 +239,8 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes withChild(left, right) { // TODO: Support different join types here when join types are added to cross rel for BNLJ // Currently, this will change both cross and inner join types to inner join - Join(left, right, Inner, Option(null), hint = JoinHint.NONE) + val plan = Join(left, right, Inner, Option(null), hint = JoinHint.NONE) + remap(plan, join.getRemap) } } @@ -258,7 +263,7 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes val limit = fetch.getCount.orElse(-1).intValue() // -1 means unassigned here val offset = fetch.getOffset.intValue() val toLiteral = (i: Int) => Literal(i, IntegerType) - if (limit >= 0) { + val plan = if (limit >= 0) { val limitExpr = toLiteral(limit) if (offset > 0) { GlobalLimit( @@ -270,13 +275,15 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes } else { Offset(toLiteral(offset), child) } + remap(plan, 
fetch.getRemap) } override def visit(sort: relation.Sort, context: EmptyVisitationContext): LogicalPlan = { val child = sort.getInput.accept(this, context) withChild(child) { val sortOrders = sort.getSortFields.asScala.map(toSortOrder).toSeq - Sort(sortOrders, global = true, child) + val plan = Sort(sortOrders, global = true, child) + remap(plan, sort.getRemap) } } @@ -310,17 +317,24 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes val names = fieldNames(project).getOrElse(List.empty) withOutput(output) { - val projectExprs = + val projectExprs = { project.getExpressions.asScala .map(expr => expr.accept(expressionConverter, context)) .toSeq + } val projectList = if (names.size == projectExprs.size) { projectExprs.zip(names).map { case (expr, name) => Alias(expr, name)() } } else { projectExprs.map(toNamedExpression) } if (createProject) { - Project(projectList, child) + val ps = output.map(_.toAttribute) ++ projectList + val remapped = if (project.getRemap.isPresent) { + project.getRemap.get().indices().asScala.map(i => ps(i)).toSeq + } else { + ps + } + Project(remapped, child) } else { val aggregate: Aggregate = child.asInstanceOf[Aggregate] aggregate.copy(aggregateExpressions = projectList) @@ -352,7 +366,8 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes .map { case (t, name) => StructField(name, t._1, t._2) } .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) - Expand(projections.transpose, output, child) + val plan = Expand(projections.transpose, output, child) + remap(plan, expand.getRemap) } } @@ -360,18 +375,20 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes val child = filter.getInput.accept(this, context) withChild(child) { val condition = filter.getCondition.accept(expressionConverter, context) - Filter(condition, child) + val plan = Filter(condition, child) + remap(plan, filter.getRemap) } } override def visit(set: 
relation.Set, context: EmptyVisitationContext): LogicalPlan = { val children = set.getInputs.asScala.map(_.accept(this, context)).toSeq withOutput(children.flatMap(_.output)) { - set.getSetOp match { + val plan = set.getSetOp match { case SetOp.UNION_ALL => Union(children, byName = false, allowMissingCol = false) case op => throw new UnsupportedOperationException(s"Operation not currently supported: $op") } + remap(plan, set.getRemap) } } @@ -386,21 +403,23 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes .toSeq ) }.toSeq - virtualTableScan.getInitialSchema match { + val plan = virtualTableScan.getInitialSchema match { case ns: NamedStruct if ns.names().isEmpty && rows.length == 1 => OneRowRelation() case _ => LocalRelation(ToSparkType.toAttributeSeq(virtualTableScan.getInitialSchema), rows) } + remap(plan, virtualTableScan.getRemap) } override def visit( namedScan: relation.NamedScan, context: EmptyVisitationContext): LogicalPlan = { - resolve(UnresolvedRelation(namedScan.getNames.asScala.toSeq)) match { + val plan = resolve(UnresolvedRelation(namedScan.getNames.asScala.toSeq)) match { case m: MultiInstanceRelation => m.newInstance() case other => other } + remap(plan, namedScan.getRemap) } override def visit(localFiles: LocalFiles, context: EmptyVisitationContext): LogicalPlan = { @@ -427,12 +446,13 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes format, options ) - SparkCompat.instance.createLogicalRelation( + val plan = SparkCompat.instance.createLogicalRelation( relation = hadoopFsRelation, output = output, catalogTable = None, isStreaming = false ) + remap(plan, localFiles.getRemap) } def convertFileFormat(fileFormat: FileFormat): (SparkFileFormat, Map[String, String]) = { @@ -467,7 +487,7 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes case "hive" => true case _ => false } - write.getOperation match { + val plan = write.getOperation match { case 
WriteOp.CTAS => withChild(child) { if (isHive) { @@ -501,7 +521,7 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes } case op => throw new UnsupportedOperationException(s"Write mode $op not supported") } - + remap(plan, write.getRemap) } override def visit(write: ExtensionWrite, context: EmptyVisitationContext): LogicalPlan = { @@ -529,7 +549,7 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes val name = file.getPath.get.split('/').reverse.head val table = catalogTable(Seq(name)) - withChild(child) { + val plan = withChild(child) { V1Writes.apply( InsertIntoHadoopFsRelationCommand( outputPath = new Path(file.getPath.get), @@ -546,19 +566,21 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes outputColumnNames = write.getTableSchema.names.asScala.toSeq )) } + remap(plan, write.getRemap) } override def visit(ddl: NamedDdl, context: EmptyVisitationContext): LogicalPlan = { val table = catalogTable(ddl.getNames.asScala.toSeq, ToSparkType.toStructType(ddl.getTableSchema)) - (ddl.getOperation, ddl.getObject) match { + val plan = (ddl.getOperation, ddl.getObject) match { case (DdlOp.CREATE, DdlObject.TABLE) => CreateTableCommand(table, false) case (DdlOp.DROP, DdlObject.TABLE) => DropTableCommand(table.identifier, false, false, false) case (DdlOp.DROP_IF_EXIST, DdlObject.TABLE) => DropTableCommand(table.identifier, true, false, false) case op => throw new UnsupportedOperationException(s"Ddl operation $op not supported") } + remap(plan, ddl.getRemap) } private def catalogTable( @@ -613,6 +635,15 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes expressionConverter.popOutput() } } + + private def remap(plan: LogicalPlan, remap: Optional[relation.Rel.Remap]): LogicalPlan = { + if (remap.isEmpty) { + return plan + } + val projectExprs = plan.output.map { case ne: NamedExpression => ne.toAttribute }.map(toNamedExpression) + 
Project(remap.get().indices().asScala.map(i => projectExprs(i)).toSeq, plan) + } + private def resolve(plan: LogicalPlan): LogicalPlan = { val qe = SparkCompat.instance.createQueryExecution(spark, plan) qe.analyzed match { diff --git a/spark/src/test/resources/nba_team.csv b/spark/src/test/resources/nba_team.csv new file mode 100644 index 000000000..825bc7aa5 --- /dev/null +++ b/spark/src/test/resources/nba_team.csv @@ -0,0 +1,31 @@ +id,full_name,abbreviation,nickname,city,state,year_founded +1610612737,Atlanta Hawks,ATL,Hawks,Atlanta,Atlanta,1949.0 +1610612738,Boston Celtics,BOS,Celtics,Boston,Massachusetts,1946.0 +1610612739,Cleveland Cavaliers,CLE,Cavaliers,Cleveland,Ohio,1970.0 +1610612740,New Orleans Pelicans,NOP,Pelicans,New Orleans,Louisiana,2002.0 +1610612741,Chicago Bulls,CHI,Bulls,Chicago,Illinois,1966.0 +1610612742,Dallas Mavericks,DAL,Mavericks,Dallas,Texas,1980.0 +1610612743,Denver Nuggets,DEN,Nuggets,Denver,Colorado,1976.0 +1610612744,Golden State Warriors,GSW,Warriors,Golden State,California,1946.0 +1610612745,Houston Rockets,HOU,Rockets,Houston,Texas,1967.0 +1610612746,Los Angeles Clippers,LAC,Clippers,Los Angeles,California,1970.0 +1610612747,Los Angeles Lakers,LAL,Lakers,Los Angeles,California,1948.0 +1610612748,Miami Heat,MIA,Heat,Miami,Florida,1988.0 +1610612749,Milwaukee Bucks,MIL,Bucks,Milwaukee,Wisconsin,1968.0 +1610612750,Minnesota Timberwolves,MIN,Timberwolves,Minnesota,Minnesota,1989.0 +1610612751,Brooklyn Nets,BKN,Nets,Brooklyn,New York,1976.0 +1610612752,New York Knicks,NYK,Knicks,New York,New York,1946.0 +1610612753,Orlando Magic,ORL,Magic,Orlando,Florida,1989.0 +1610612754,Indiana Pacers,IND,Pacers,Indiana,Indiana,1976.0 +1610612755,Philadelphia 76ers,PHI,76ers,Philadelphia,Pennsylvania,1949.0 +1610612756,Phoenix Suns,PHX,Suns,Phoenix,Arizona,1968.0 +1610612757,Portland Trail Blazers,POR,Trail Blazers,Portland,Oregon,1970.0 +1610612758,Sacramento Kings,SAC,Kings,Sacramento,California,1948.0 +1610612759,San Antonio Spurs,SAS,Spurs,San 
Antonio,Texas,1976.0 +1610612760,Oklahoma City Thunder,OKC,Thunder,Oklahoma City,Oklahoma,1967.0 +1610612761,Toronto Raptors,TOR,Raptors,Toronto,Ontario,1995.0 +1610612762,Utah Jazz,UTA,Jazz,Utah,Utah,1974.0 +1610612763,Memphis Grizzlies,MEM,Grizzlies,Memphis,Tennessee,1995.0 +1610612764,Washington Wizards,WAS,Wizards,Washington,District of Columbia,1961.0 +1610612765,Detroit Pistons,DET,Pistons,Detroit,Michigan,1948.0 +1610612766,Charlotte Hornets,CHA,Hornets,Charlotte,North Carolina,1988.0 diff --git a/spark/src/test/resources/nba_team_history.csv b/spark/src/test/resources/nba_team_history.csv new file mode 100644 index 000000000..6fdb7ee52 --- /dev/null +++ b/spark/src/test/resources/nba_team_history.csv @@ -0,0 +1,53 @@ +team_id,city,nickname,year_founded,year_active_till +1610612737,Atlanta,Hawks,1968,2019 +1610612737,St. Louis,Hawks,1955,1967 +1610612737,Milwaukee,Hawks,1951,1954 +1610612737,Tri-Cities,Blackhawks,1949,1950 +1610612741,Chicago,Bulls,1966,2019 +1610612742,Dallas,Mavericks,1980,2019 +1610612743,Denver,Nuggets,1976,2019 +1610612744,Golden State,Warriors,1971,2019 +1610612744,San Francisco,Warriors,1962,1970 +1610612744,Philadelphia,Warriors,1946,1961 +1610612745,Houston,Rockets,1971,2019 +1610612745,San Diego,Rockets,1967,1970 +1610612746,Los Angeles,Clippers,1984,2019 +1610612746,San Diego,Clippers,1978,1983 +1610612746,Buffalo,Braves,1970,1977 +1610612747,Los Angeles,Lakers,1960,2019 +1610612747,Minneapolis,Lakers,1948,1959 +1610612748,Miami,Heat,1988,2019 +1610612749,Milwaukee,Bucks,1968,2019 +1610612750,Minnesota,Timberwolves,1989,2019 +1610612751,Brooklyn,Nets,2012,2019 +1610612751,New Jersey,Nets,1977,2011 +1610612751,New York,Nets,1976,1976 +1610612754,Indiana,Pacers,1976,2019 +1610612755,Philadelphia,76ers,1963,2019 +1610612755,Syracuse,Nationals,1949,1962 +1610612756,Phoenix,Suns,1968,2019 +1610612757,Portland,Trail Blazers,1970,2019 +1610612758,Sacramento,Kings,1985,2019 +1610612758,Kansas City,Kings,1975,1984 +1610612758,Kansas 
City-Omaha,Kings,1972,1974 +1610612758,Cincinnati,Royals,1957,1971 +1610612758,Rochester,Royals,1948,1956 +1610612759,San Antonio,Spurs,1976,2019 +1610612760,Oklahoma City,Thunder,2008,2019 +1610612760,Seattle,SuperSonics,1967,2007 +1610612761,Toronto,Raptors,1995,2019 +1610612762,Utah,Jazz,1979,2019 +1610612762,New Orleans,Jazz,1974,1978 +1610612763,Memphis,Grizzlies,2001,2019 +1610612763,Vancouver,Grizzlies,1995,2000 +1610612764,Washington,Wizards,1997,2019 +1610612764,Washington,Bullets,1974,1996 +1610612764,Capital,Bullets,1973,1973 +1610612764,Baltimore,Bullets,1963,1972 +1610612764,Chicago,Zephyrs,1962,1962 +1610612764,Chicago,Packers,1961,1961 +1610612765,Detroit,Pistons,1957,2019 +1610612765,Ft. Wayne Zollner,Pistons,1948,1956 +1610612766,Charlotte,Hornets,2014,2019 +1610612766,Charlotte,Bobcats,2004,2013 +1610612766,Charlotte,Hornets,1988,2001 diff --git a/spark/src/test/resources/substrait_plan_nba_california.json b/spark/src/test/resources/substrait_plan_nba_california.json new file mode 100644 index 000000000..353ee5a5b --- /dev/null +++ b/spark/src/test/resources/substrait_plan_nba_california.json @@ -0,0 +1,822 @@ +{ + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 277, + "name": "equal:any_any", + "extensionUrnReference": 7 + } + }, + { + "extensionFunction": { + "functionAnchor": 159, + "name": "concat:str", + "extensionUrnReference": 5 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "input": { + "filter": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33 + ] + } + }, + "input": { + "join": { + "left": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 7, + 8, + 9, + 10, + 11, + 12, + 13 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "id", + "full_name", + "abbreviation", + "nickname", + "city", + "state", 
+ "year_founded" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "nba", + "team" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + } + ] + } + }, + "right": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19 + ] + } + }, + "input": { + "project": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "team_id", + "city", + "nickname", + "year_founded", + "year_active_till" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + 
"varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "nba", + "team_history" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + } + ] + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + 
"field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + } + ] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 277, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + } + } + ] + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + 
"field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + } + ] + } + }, + "condition": { + "scalarFunction": { + "functionReference": 277, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": {} + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_RETURN_NULL" + } + } + }, + { + "value": { + "literal": { + "string": "California" + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 159, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_RETURN_NULL" + } + } + }, + { + "value": { + "literal": { + "string": " " + } + } + }, + { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_RETURN_NULL" + 
} + } + } + ] + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": {} + } + } + ] + } + }, + "names": [ + "team_id", + "city", + "nickname", + "year_founded", + "year_active_till", + "team_id", + "hist_city", + "hist_nickname", + "hist_year_founded", + "hist_year_active_till", + "id", + "full_name", + "abbreviation", + "nickname", + "city", + "state", + "year_founded", + "team_name", + "city", + "year_founded", + "year_active_till" + ] + } + } + ], + "version": { + "minorNumber": 79 + }, + "extensionUrns": [ + { + "extensionUrnAnchor": 7, + "urn": "extension:io.substrait:functions_comparison" + }, + { + "extensionUrnAnchor": 5, + "urn": "extension:io.substrait:functions_string" + } + ] +} diff --git a/spark/src/test/resources/substrait_plan_with_aggregate_op.json b/spark/src/test/resources/substrait_plan_with_aggregate_op.json new file mode 100644 index 000000000..8fb0cc8c9 --- /dev/null +++ b/spark/src/test/resources/substrait_plan_with_aggregate_op.json @@ -0,0 +1,893 @@ +{ + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 277, + "name": "equal:any_any", + "extensionUrnReference": 7 + } + }, + { + "extensionFunction": { + "functionAnchor": 415, + "name": "sum:i32", + "extensionUrnReference": 9 + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "aggregate": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39 + ] + } + }, + "input": { + "join": { + "left": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 11, + 12, + 13, + 14, + 15, + 
16, + 17, + 18, + 19, + 20, + 21 + ] + } + }, + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "test_id", + "vehicle_id", + "test_date", + "test_class", + "test_type", + "test_result", + "test_mileage", + "postcode_area" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test1", + "tests" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + } + ] + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + 
"rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + } + ] + } + }, + "right": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17 + ] + } + }, + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "vehicle_id", + "make", + "model", + "colour", + "fuel_type", + "cylinder_capacity", + "first_use_date" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 256, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "test1", + "vehicles" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + 
"structField": { + "field": 4 + } + }, + "rootReference": {} + } + } + ] + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + } + ] + } + }, + "expression": { + "scalarFunction": { + "functionReference": 277, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 18 + } + }, + "rootReference": {} + } + } + } + ] + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + 
"directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 18 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": {} + } + } + ] + } + }, + "groupings": [ + { + 
"groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + } + ], + "expressionReferences": [ + 0, + 1 + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 415, + "outputType": { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": {} + } + } + } + ], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT" + } + } + ], + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": {} + } + } + ] + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + { + "cast": { + "type": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_RETURN_NULL" + } + } + ] + } + }, + "names": [ + "fuel_type", + "postcode_area", + "total_miles_i32", + "fuel_type", + "postcode_area", + "total_miles" + ] + } + } + ], + "version": { + "minorNumber": 79 + }, + "extensionUrns": [ + { + "extensionUrnAnchor": 7, + "urn": "extension:io.substrait:functions_comparison" + }, + { + "extensionUrnAnchor": 9, + "urn": "extension:io.substrait:functions_arithmetic" + } + ] +} diff --git a/spark/src/test/resources/tests_subset_2023.csv b/spark/src/test/resources/tests_subset_2023.csv new file mode 100644 index 
000000000..762d53491 --- /dev/null +++ b/spark/src/test/resources/tests_subset_2023.csv @@ -0,0 +1,30 @@ +test_id,vehicle_id,test_date,test_class,test_type,test_result,test_mileage,postcode_area +539514409,17113014,2023-01-09,4,NT,F,69934,PA +1122718877,986649781,2023-01-16,4,NT,F,57376,SG +1104881351,424684356,2023-03-06,4,NT,F,81853,SG +1487493049,1307056703,2023-03-07,4,NT,P,20763,SA +1107861883,130747047,2023-03-27,4,RT,P,125910,SA +472789285,777757523,2023-03-29,4,NT,P,68399,CO +1105082521,840180863,2023-04-15,4,NT,P,54240,NN +1172953135,917255260,2023-04-27,4,NT,P,60918,SM +127807783,888103385,2023-05-08,4,NT,P,112090,EH +1645970709,816803134,2023-06-03,4,NT,P,134858,RG +1355347761,919820431,2023-06-21,4,NT,P,37336,ST +1750209849,544950855,2023-06-23,4,NT,F,120034,NR +1376930435,439876988,2023-07-19,4,NT,P,109927,PO +582729949,1075446447,2023-07-19,4,NT,P,72986,SA +127953451,105663799,2023-07-31,4,NT,F,35824,ME +759291679,931759350,2023-08-07,4,NT,P,65353,DY +1629819891,335780567,2023-08-08,4,NT,PRS,103365,CF +1120026477,1153361746,2023-08-11,4,NT,P,286881,RM +1331300969,644861283,2023-08-15,4,NT,P,52173,LE +990694587,449899992,2023-08-16,4,NT,F,124891,SA +193460599,759696266,2023-08-29,4,NT,P,83554,LU +1337337679,1110416764,2023-10-09,4,NT,PRS,71093,SS +1885237527,137785384,2023-11-04,4,NT,P,88730,BH +1082642803,1291985882,2023-11-15,4,NT,PRS,160717,BA +896066743,615735063,2023-11-15,4,RT,P,107710,NR +1022666841,474362449,2023-11-20,4,NT,P,56296,HP +1010400923,1203222226,2023-12-04,4,NT,F,89255,TW +866705687,605696575,2023-12-06,4,NT,P,14674,YO +621751843,72093448,2023-12-14,4,NT,F,230280,TR diff --git a/spark/src/test/resources/vehicles_subset_2023.csv b/spark/src/test/resources/vehicles_subset_2023.csv new file mode 100644 index 000000000..087b54c84 --- /dev/null +++ b/spark/src/test/resources/vehicles_subset_2023.csv @@ -0,0 +1,31 @@ +vehicle_id,make,model,colour,fuel_type,cylinder_capacity,first_use_date +17113014,VAUXHALL,VIVARO,BLACK,DI,1995,2011-09-29 
+986649781,VAUXHALL,INSIGNIA,WHITE,DI,1956,2017-07-19 +424684356,RENAULT,GRAND SCENIC,GREY,PE,1997,2010-07-19 +1307056703,RENAULT,CLIO,BLACK,DI,1461,2014-05-30 +130747047,FORD,FOCUS,SILVER,DI,1560,2013-07-10 +777757523,HYUNDAI,I10,WHITE,PE,998,2016-05-21 +840180863,BMW,1 SERIES,WHITE,PE,2979,2016-03-11 +917255260,VAUXHALL,ASTRA,WHITE,PE,1364,2012-04-21 +888103385,FORD,GALAXY,SILVER,DI,1997,2014-09-12 +816803134,FORD,FIESTA,BLUE,PE,1299,2002-10-24 +697184031,BMW,X1,WHITE,DI,1995,2016-03-31 +919820431,TOYOTA,AURIS,BRONZE,PE,1329,2015-06-29 +544950855,VAUXHALL,ASTRA,RED,DI,1956,2012-09-17 +439876988,MINI,MINI,GREEN,PE,1598,2010-03-31 +1075446447,CITROEN,C4,RED,DI,1560,2015-10-05 +105663799,RENAULT,KADJAR,BLACK,PE,1332,2020-07-23 +931759350,FIAT,DUCATO,WHITE,DI,2199,2008-04-18 +335780567,HYUNDAI,I20,BLUE,PE,1396,2013-08-13 +1153361746,TOYOTA,PRIUS,SILVER,HY,1800,2010-06-23 +644861283,FORD,FIESTA,BLACK,PE,998,2015-09-03 +449899992,BMW,3 SERIES,GREEN,DI,2926,2006-09-30 +759696266,CITROEN,C4,BLUE,DI,1997,2011-12-19 +1110416764,CITROEN,XSARA,SILVER,DI,1997,1999-06-30 +137785384,MINI,MINI,GREY,DI,1598,2011-11-29 +1291985882,LAND ROVER,DEFENDER,BLUE,DI,2495,2002-06-12 +615735063,VOLKSWAGEN,CADDY,WHITE,DI,1598,2013-03-01 +474362449,VAUXHALL,GRANDLAND,GREY,PE,1199,2018-11-12 +1203222226,VAUXHALL,ASTRA,BLUE,PE,1598,2010-06-03 +605696575,SUZUKI,SWIFT SZ-T DUALJET MHEV CVT,RED,HY,1197,2020-12-18 +72093448,AUDI,A4,SILVER,DI,1896,2001-03-19 diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala b/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala new file mode 100644 index 000000000..363ee904b --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.substrait.spark + +import io.substrait.plan.ProtoPlanConverter +import io.substrait.spark.logical.ToLogicalPlan +import org.apache.hadoop.shaded.com.google.gson.{Gson, JsonObject} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +import java.io.IOException +import java.nio.file.{Files, Path, Paths} + +/** + * Contains test cases for the Substrait to Spark converter where the substrait plan is generated + * by a (trusted) third party tool. This deviates from the 'round-trip' methodology whereby + * all the substrait plans are generated by the Spark to Substrait converter in this repo. + * + * The Substrait plan is loaded from a proto file (JSON format). 
+ */ +class SubstraitJsonSuite extends SparkFunSuite { + var spark: SparkSession = + SparkSession.builder().config("spark.master", "local").enableHiveSupport().getOrCreate() + + test("Aggregate") { + val substraitPlanPath = Path.of("../src/test/resources/substrait_plan_with_aggregate_op.json") + val defaultSchema = "test1" + + // Create table test1.vehicles + val vehiclesFile = "../src/test/resources/vehicles_subset_2023.csv" + val dsVehicles = spark.read.option("header", true).option("inferSchema", true).csv(Paths.get(vehiclesFile).toAbsolutePath.toString) + spark.sql("CREATE NAMESPACE IF NOT EXISTS test1") + spark.sql("DROP TABLE IF EXISTS test1.vehicles") + try dsVehicles.writeTo("test1.vehicles").create() + catch { + case _: TableAlreadyExistsException => + // Table already exists, ignore + } + + // Create table test1.tests + val testsFile = "../src/test/resources/tests_subset_2023.csv" + val dsTests = spark.read.option("header", true).option("inferSchema", true).csv(Paths.get(testsFile).toAbsolutePath.toString) + spark.sql("DROP TABLE IF EXISTS test1.tests") + try dsTests.writeTo("test1.tests").create() + catch { + case _: TableAlreadyExistsException => + // Table already exists, ignore + } + + val sparkPlan = toSpark(substraitPlanPath, defaultSchema) + val output = spark.sessionState.executePlan(sparkPlan).executedPlan.execute() + // there should be 25 rows and 3 columns + assertResult(25)(output.count()) + assertResult(3)(output.take(1)(0).numFields) + } + + test("nba") { + val substraitPlanPath = Path.of("../src/test/resources/substrait_plan_nba_california.json") + val defaultSchema = "nba" + + // Create table nba.team + val vehiclesFile = "../src/test/resources/nba_team.csv" + val dsVehicles = spark.read.option("header", true).option("inferSchema", true).csv(Paths.get(vehiclesFile).toAbsolutePath.toString) + spark.sql("CREATE NAMESPACE IF NOT EXISTS nba") + spark.sql("DROP TABLE IF EXISTS nba.team") + try dsVehicles.writeTo("nba.team").create() + catch { + 
case _: TableAlreadyExistsException => + // Table already exists, ignore + } + + // Create table nba.team_history + val testsFile = "../src/test/resources/nba_team_history.csv" + val dsTests = spark.read.option("header", true).option("inferSchema", true).csv(Paths.get(testsFile).toAbsolutePath.toString) + spark.sql("DROP TABLE IF EXISTS nba.team_history") + try dsTests.writeTo("nba.team_history").create() + catch { + case _: TableAlreadyExistsException => + // Table already exists, ignore + } + + val sparkPlan = toSpark(substraitPlanPath, defaultSchema) + val output = spark.sessionState.executePlan(sparkPlan).executedPlan.execute() + // there should be 13 rows and 21 columns + assertResult(13)(output.count()) + assertResult(21)(output.take(1)(0).numFields) + } + + @throws[IOException] + def toSpark(substraitPlanPath: Path, defaultSchema: String): LogicalPlan = { + val defaultCatalog = "spark_catalog" + val jsonContent = Files.readString(substraitPlanPath) + val gson = new Gson + val jsonObj = gson.fromJson(jsonContent, classOf[JsonObject]) + val builder = io.substrait.proto.Plan.newBuilder + com.google.protobuf.util.JsonFormat.parser.ignoringUnknownFields.merge(jsonObj.toString, builder) + val proto = builder.build + val protoToPlan = new ProtoPlanConverter + val plan = protoToPlan.from(proto) + spark.catalog.setCurrentCatalog(defaultCatalog) + spark.catalog.setCurrentDatabase(defaultSchema) + val substraitConverter = new ToLogicalPlan(spark) + substraitConverter.convert(plan) + } + +} From b476ac49927c0251bb9d5ad254c864eb92757a38 Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Wed, 15 Apr 2026 15:46:54 +0100 Subject: [PATCH 2/3] fix: address review comments Signed-off-by: Andrew Coleman --- .../substrait/spark/logical/ToLogicalPlan.scala | 8 ++++---- .../io/substrait/spark/SubstraitJsonSuite.scala | 17 ----------------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala 
b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 0f823cc4f..d2abf13c5 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -319,7 +319,7 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes withOutput(output) { val projectExprs = { project.getExpressions.asScala - .map(expr => expr.accept(expressionConverter, context)) + .map(_.accept(expressionConverter, context)) .toSeq } val projectList = if (names.size == projectExprs.size) { @@ -328,11 +328,11 @@ class ToLogicalPlan(val spark: AnyRef = SparkCompat.instance.getOrCreateSparkSes projectExprs.map(toNamedExpression) } if (createProject) { - val ps = output.map(_.toAttribute) ++ projectList + val allExpressions = output.map(_.toAttribute) ++ projectList val remapped = if (project.getRemap.isPresent) { - project.getRemap.get().indices().asScala.map(i => ps(i)).toSeq + project.getRemap.get().indices().asScala.map(allExpressions(_)).toSeq } else { - ps + allExpressions } Project(remapped, child) } else { diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala b/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala index 363ee904b..ba017c6e2 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala @@ -1,20 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to you under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package io.substrait.spark import io.substrait.plan.ProtoPlanConverter From a160e10e3602da1d8fe3e3572fe3d4175721076d Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Mon, 20 Apr 2026 16:07:42 +0100 Subject: [PATCH 3/3] fix: rewrite remap test Assisted by AI Signed-off-by: Andrew Coleman --- spark/spark-3.4_2.12/build.gradle.kts | 1 - spark/spark-3.5_2.12/build.gradle.kts | 1 - spark/spark-4.0_2.13/build.gradle.kts | 1 - spark/src/test/resources/nba_team.csv | 31 - spark/src/test/resources/nba_team_history.csv | 53 -- .../substrait_plan_nba_california.json | 822 ---------------- .../substrait_plan_with_aggregate_op.json | 893 ------------------ .../src/test/resources/tests_subset_2023.csv | 30 - .../test/resources/vehicles_subset_2023.csv | 31 - .../spark/AggregateWithJoinSuite.scala | 275 ++++++ .../substrait/spark/SubstraitJsonSuite.scala | 106 --- 11 files changed, 275 insertions(+), 1969 deletions(-) delete mode 100644 spark/src/test/resources/nba_team.csv delete mode 100644 spark/src/test/resources/nba_team_history.csv delete mode 100644 spark/src/test/resources/substrait_plan_nba_california.json delete mode 100644 spark/src/test/resources/substrait_plan_with_aggregate_op.json delete mode 100644 spark/src/test/resources/tests_subset_2023.csv delete mode 100644 spark/src/test/resources/vehicles_subset_2023.csv create mode 100644 spark/src/test/scala/io/substrait/spark/AggregateWithJoinSuite.scala delete mode 100644 spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala diff --git a/spark/spark-3.4_2.12/build.gradle.kts 
b/spark/spark-3.4_2.12/build.gradle.kts index da15de2a0..f6317b8ad 100644 --- a/spark/spark-3.4_2.12/build.gradle.kts +++ b/spark/spark-3.4_2.12/build.gradle.kts @@ -135,7 +135,6 @@ dependencies { implementation(platform(libs.jackson.bom)) implementation(libs.bundles.jackson) implementation(libs.json.schema.validator) - testImplementation(libs.protobuf.java.util) testImplementation(platform(libs.junit.bom)) testRuntimeOnly(libs.junit.platform.engine) diff --git a/spark/spark-3.5_2.12/build.gradle.kts b/spark/spark-3.5_2.12/build.gradle.kts index 7f9e37e87..031447f79 100644 --- a/spark/spark-3.5_2.12/build.gradle.kts +++ b/spark/spark-3.5_2.12/build.gradle.kts @@ -136,7 +136,6 @@ dependencies { implementation(platform(libs.jackson.bom)) implementation(libs.bundles.jackson) implementation(libs.json.schema.validator) - testImplementation(libs.protobuf.java.util) testImplementation(platform(libs.junit.bom)) testRuntimeOnly(libs.junit.platform.engine) diff --git a/spark/spark-4.0_2.13/build.gradle.kts b/spark/spark-4.0_2.13/build.gradle.kts index 0f4aca567..cd39d39ca 100644 --- a/spark/spark-4.0_2.13/build.gradle.kts +++ b/spark/spark-4.0_2.13/build.gradle.kts @@ -136,7 +136,6 @@ dependencies { implementation(platform(libs.jackson.bom)) implementation(libs.bundles.jackson) implementation(libs.json.schema.validator) - testImplementation(libs.protobuf.java.util) testImplementation(platform(libs.junit.bom)) testRuntimeOnly(libs.junit.platform.engine) diff --git a/spark/src/test/resources/nba_team.csv b/spark/src/test/resources/nba_team.csv deleted file mode 100644 index 825bc7aa5..000000000 --- a/spark/src/test/resources/nba_team.csv +++ /dev/null @@ -1,31 +0,0 @@ -id,full_name,abbreviation,nickname,city,state,year_founded -1610612737,Atlanta Hawks,ATL,Hawks,Atlanta,Atlanta,1949.0 -1610612738,Boston Celtics,BOS,Celtics,Boston,Massachusetts,1946.0 -1610612739,Cleveland Cavaliers,CLE,Cavaliers,Cleveland,Ohio,1970.0 -1610612740,New Orleans Pelicans,NOP,Pelicans,New 
Orleans,Louisiana,2002.0 -1610612741,Chicago Bulls,CHI,Bulls,Chicago,Illinois,1966.0 -1610612742,Dallas Mavericks,DAL,Mavericks,Dallas,Texas,1980.0 -1610612743,Denver Nuggets,DEN,Nuggets,Denver,Colorado,1976.0 -1610612744,Golden State Warriors,GSW,Warriors,Golden State,California,1946.0 -1610612745,Houston Rockets,HOU,Rockets,Houston,Texas,1967.0 -1610612746,Los Angeles Clippers,LAC,Clippers,Los Angeles,California,1970.0 -1610612747,Los Angeles Lakers,LAL,Lakers,Los Angeles,California,1948.0 -1610612748,Miami Heat,MIA,Heat,Miami,Florida,1988.0 -1610612749,Milwaukee Bucks,MIL,Bucks,Milwaukee,Wisconsin,1968.0 -1610612750,Minnesota Timberwolves,MIN,Timberwolves,Minnesota,Minnesota,1989.0 -1610612751,Brooklyn Nets,BKN,Nets,Brooklyn,New York,1976.0 -1610612752,New York Knicks,NYK,Knicks,New York,New York,1946.0 -1610612753,Orlando Magic,ORL,Magic,Orlando,Florida,1989.0 -1610612754,Indiana Pacers,IND,Pacers,Indiana,Indiana,1976.0 -1610612755,Philadelphia 76ers,PHI,76ers,Philadelphia,Pennsylvania,1949.0 -1610612756,Phoenix Suns,PHX,Suns,Phoenix,Arizona,1968.0 -1610612757,Portland Trail Blazers,POR,Trail Blazers,Portland,Oregon,1970.0 -1610612758,Sacramento Kings,SAC,Kings,Sacramento,California,1948.0 -1610612759,San Antonio Spurs,SAS,Spurs,San Antonio,Texas,1976.0 -1610612760,Oklahoma City Thunder,OKC,Thunder,Oklahoma City,Oklahoma,1967.0 -1610612761,Toronto Raptors,TOR,Raptors,Toronto,Ontario,1995.0 -1610612762,Utah Jazz,UTA,Jazz,Utah,Utah,1974.0 -1610612763,Memphis Grizzlies,MEM,Grizzlies,Memphis,Tennessee,1995.0 -1610612764,Washington Wizards,WAS,Wizards,Washington,District of Columbia,1961.0 -1610612765,Detroit Pistons,DET,Pistons,Detroit,Michigan,1948.0 -1610612766,Charlotte Hornets,CHA,Hornets,Charlotte,North Carolina,1988.0 diff --git a/spark/src/test/resources/nba_team_history.csv b/spark/src/test/resources/nba_team_history.csv deleted file mode 100644 index 6fdb7ee52..000000000 --- a/spark/src/test/resources/nba_team_history.csv +++ /dev/null @@ -1,53 +0,0 @@ 
-team_id,city,nickname,year_founded,year_active_till -1610612737,Atlanta,Hawks,1968,2019 -1610612737,St. Louis,Hawks,1955,1967 -1610612737,Milwaukee,Hawks,1951,1954 -1610612737,Tri-Cities,Blackhawks,1949,1950 -1610612741,Chicago,Bulls,1966,2019 -1610612742,Dallas,Mavericks,1980,2019 -1610612743,Denver,Nuggets,1976,2019 -1610612744,Golden State,Warriors,1971,2019 -1610612744,San Francisco,Warriors,1962,1970 -1610612744,Philadelphia,Warriors,1946,1961 -1610612745,Houston,Rockets,1971,2019 -1610612745,San Diego,Rockets,1967,1970 -1610612746,Los Angeles,Clippers,1984,2019 -1610612746,San Diego,Clippers,1978,1983 -1610612746,Buffalo,Braves,1970,1977 -1610612747,Los Angeles,Lakers,1960,2019 -1610612747,Minneapolis,Lakers,1948,1959 -1610612748,Miami,Heat,1988,2019 -1610612749,Milwaukee,Bucks,1968,2019 -1610612750,Minnesota,Timberwolves,1989,2019 -1610612751,Brooklyn,Nets,2012,2019 -1610612751,New Jersey,Nets,1977,2011 -1610612751,New York,Nets,1976,1976 -1610612754,Indiana,Pacers,1976,2019 -1610612755,Philadelphia,76ers,1963,2019 -1610612755,Syracuse,Nationals,1949,1962 -1610612756,Phoenix,Suns,1968,2019 -1610612757,Portland,Trail Blazers,1970,2019 -1610612758,Sacramento,Kings,1985,2019 -1610612758,Kansas City,Kings,1975,1984 -1610612758,Kansas City-Omaha,Kings,1972,1974 -1610612758,Cincinnati,Royals,1957,1971 -1610612758,Rochester,Royals,1948,1956 -1610612759,San Antonio,Spurs,1976,2019 -1610612760,Oklahoma City,Thunder,2008,2019 -1610612760,Seattle,SuperSonics,1967,2007 -1610612761,Toronto,Raptors,1995,2019 -1610612762,Utah,Jazz,1979,2019 -1610612762,New Orleans,Jazz,1974,1978 -1610612763,Memphis,Grizzlies,2001,2019 -1610612763,Vancouver,Grizzlies,1995,2000 -1610612764,Washington,Wizards,1997,2019 -1610612764,Washington,Bullets,1974,1996 -1610612764,Capital,Bullets,1973,1973 -1610612764,Baltimore,Bullets,1963,1972 -1610612764,Chicago,Zephyrs,1962,1962 -1610612764,Chicago,Packers,1961,1961 -1610612765,Detroit,Pistons,1957,2019 -1610612765,Ft. 
Wayne Zollner,Pistons,1948,1956 -1610612766,Charlotte,Hornets,2014,2019 -1610612766,Charlotte,Bobcats,2004,2013 -1610612766,Charlotte,Hornets,1988,2001 diff --git a/spark/src/test/resources/substrait_plan_nba_california.json b/spark/src/test/resources/substrait_plan_nba_california.json deleted file mode 100644 index 353ee5a5b..000000000 --- a/spark/src/test/resources/substrait_plan_nba_california.json +++ /dev/null @@ -1,822 +0,0 @@ -{ - "extensions": [ - { - "extensionFunction": { - "functionAnchor": 277, - "name": "equal:any_any", - "extensionUrnReference": 7 - } - }, - { - "extensionFunction": { - "functionAnchor": 159, - "name": "concat:str", - "extensionUrnReference": 5 - } - } - ], - "relations": [ - { - "root": { - "input": { - "project": { - "input": { - "filter": { - "input": { - "project": { - "common": { - "emit": { - "outputMapping": [ - 17, - 18, - 19, - 20, - 21, - 22, - 23, - 24, - 25, - 26, - 27, - 28, - 29, - 30, - 31, - 32, - 33 - ] - } - }, - "input": { - "join": { - "left": { - "project": { - "common": { - "emit": { - "outputMapping": [ - 7, - 8, - 9, - 10, - 11, - 12, - 13 - ] - } - }, - "input": { - "read": { - "common": { - "direct": {} - }, - "baseSchema": { - "names": [ - "id", - "full_name", - "abbreviation", - "nickname", - "city", - "state", - "year_founded" - ], - "struct": { - "types": [ - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "fp64": { - "nullability": "NULLABILITY_NULLABLE" - } - } - ], - "nullability": "NULLABILITY_REQUIRED" - } - }, - "namedTable": { - "names": [ - "nba", - "team" - ] 
- } - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 5 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - } - ] - } - }, - "right": { - "project": { - "common": { - "emit": { - "outputMapping": [ - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19 - ] - } - }, - "input": { - "project": { - "input": { - "read": { - "common": { - "direct": {} - }, - "baseSchema": { - "names": [ - "team_id", - "city", - "nickname", - "year_founded", - "year_active_till" - ], - "struct": { - "types": [ - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - } - ], - "nullability": "NULLABILITY_REQUIRED" - } - }, - "namedTable": { - "names": [ - "nba", - "team_history" - ] - } - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { 
- "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - } - ] - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - } - ] - } - }, - "expression": { - "scalarFunction": { - "functionReference": 277, - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [ - { - "value": { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - } - }, - { - "value": { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - } - } - ] - } - }, - "type": "JOIN_TYPE_INNER" - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": 
{ - "field": 7 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 8 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 9 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 10 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 11 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 13 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 14 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 15 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 16 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 5 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - } - ] - } - }, - "condition": { - "scalarFunction": { - "functionReference": 277, - "outputType": { - 
"bool": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [ - { - "value": { - "cast": { - "type": { - "string": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "input": { - "selection": { - "directReference": { - "structField": { - "field": 15 - } - }, - "rootReference": {} - } - }, - "failureBehavior": "FAILURE_BEHAVIOR_RETURN_NULL" - } - } - }, - { - "value": { - "literal": { - "string": "California" - } - } - } - ] - } - } - } - }, - "expressions": [ - { - "scalarFunction": { - "functionReference": 159, - "outputType": { - "string": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [ - { - "value": { - "cast": { - "type": { - "string": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "input": { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - }, - "failureBehavior": "FAILURE_BEHAVIOR_RETURN_NULL" - } - } - }, - { - "value": { - "literal": { - "string": " " - } - } - }, - { - "value": { - "cast": { - "type": { - "string": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "input": { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - }, - "failureBehavior": "FAILURE_BEHAVIOR_RETURN_NULL" - } - } - } - ] - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 8 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 9 - } - }, - "rootReference": {} - } - } - ] - } - }, - "names": [ - "team_id", - "city", - "nickname", - "year_founded", - "year_active_till", - "team_id", - "hist_city", - "hist_nickname", - "hist_year_founded", - "hist_year_active_till", - "id", - "full_name", - "abbreviation", - "nickname", - "city", - "state", - "year_founded", - "team_name", - "city", - "year_founded", - "year_active_till" - 
] - } - } - ], - "version": { - "minorNumber": 79 - }, - "extensionUrns": [ - { - "extensionUrnAnchor": 7, - "urn": "extension:io.substrait:functions_comparison" - }, - { - "extensionUrnAnchor": 5, - "urn": "extension:io.substrait:functions_string" - } - ] -} diff --git a/spark/src/test/resources/substrait_plan_with_aggregate_op.json b/spark/src/test/resources/substrait_plan_with_aggregate_op.json deleted file mode 100644 index 8fb0cc8c9..000000000 --- a/spark/src/test/resources/substrait_plan_with_aggregate_op.json +++ /dev/null @@ -1,893 +0,0 @@ -{ - "extensions": [ - { - "extensionFunction": { - "functionAnchor": 277, - "name": "equal:any_any", - "extensionUrnReference": 7 - } - }, - { - "extensionFunction": { - "functionAnchor": 415, - "name": "sum:i32", - "extensionUrnReference": 9 - } - } - ], - "relations": [ - { - "root": { - "input": { - "project": { - "common": { - "direct": {} - }, - "input": { - "aggregate": { - "input": { - "project": { - "common": { - "emit": { - "outputMapping": [ - 20, - 21, - 22, - 23, - 24, - 25, - 26, - 27, - 28, - 29, - 30, - 31, - 32, - 33, - 34, - 35, - 36, - 37, - 38, - 39 - ] - } - }, - "input": { - "join": { - "left": { - "project": { - "common": { - "emit": { - "outputMapping": [ - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21 - ] - } - }, - "input": { - "project": { - "common": { - "direct": {} - }, - "input": { - "read": { - "common": { - "direct": {} - }, - "baseSchema": { - "names": [ - "test_id", - "vehicle_id", - "test_date", - "test_class", - "test_type", - "test_result", - "test_mileage", - "postcode_area" - ], - "struct": { - "types": [ - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "date": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - 
"length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - } - ], - "nullability": "NULLABILITY_REQUIRED" - } - }, - "namedTable": { - "names": [ - "test1", - "tests" - ] - } - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - } - ] - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 5 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 8 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - } - ] - } 
- }, - "right": { - "project": { - "common": { - "emit": { - "outputMapping": [ - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17 - ] - } - }, - "input": { - "project": { - "common": { - "direct": {} - }, - "input": { - "read": { - "common": { - "direct": {} - }, - "baseSchema": { - "names": [ - "vehicle_id", - "make", - "model", - "colour", - "fuel_type", - "cylinder_capacity", - "first_use_date" - ], - "struct": { - "types": [ - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "varchar": { - "length": 256, - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "i32": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - { - "date": { - "nullability": "NULLABILITY_NULLABLE" - } - } - ], - "nullability": "NULLABILITY_REQUIRED" - } - }, - "namedTable": { - "names": [ - "test1", - "vehicles" - ] - } - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - } - ] - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 5 
- } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - } - ] - } - }, - "expression": { - "scalarFunction": { - "functionReference": 277, - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [ - { - "value": { - "selection": { - "directReference": { - "structField": { - "field": 8 - } - }, - "rootReference": {} - } - } - }, - { - "value": { - "selection": { - "directReference": { - "structField": { - "field": 18 - } - }, - "rootReference": {} - } - } - } - ] - } - }, - "type": "JOIN_TYPE_INNER" - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 3 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 4 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 5 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 8 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - 
"field": 6 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 11 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 12 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 13 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 14 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 15 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 16 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 17 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 18 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 15 - } - }, - "rootReference": {} - } - } - ] - } - }, - "groupings": [ - { - "groupingExpressions": [ - { - "selection": { - "directReference": { - "structField": { - "field": 15 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - } - ], - "expressionReferences": [ - 0, - 1 - ] - } - ], - "measures": [ - { - "measure": { - "functionReference": 415, - "outputType": { - "i64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [ - { - "value": { - "selection": { - "directReference": { - "structField": { - "field": 6 - } - }, - "rootReference": {} - } - } - } - ], - "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT" - } - } - ], - "groupingExpressions": [ - { - "selection": { - "directReference": { - 
"structField": { - "field": 15 - } - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 7 - } - }, - "rootReference": {} - } - } - ] - } - }, - "expressions": [ - { - "selection": { - "directReference": { - "structField": {} - }, - "rootReference": {} - } - }, - { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - }, - "rootReference": {} - } - }, - { - "cast": { - "type": { - "i64": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "input": { - "selection": { - "directReference": { - "structField": { - "field": 2 - } - }, - "rootReference": {} - } - }, - "failureBehavior": "FAILURE_BEHAVIOR_RETURN_NULL" - } - } - ] - } - }, - "names": [ - "fuel_type", - "postcode_area", - "total_miles_i32", - "fuel_type", - "postcode_area", - "total_miles" - ] - } - } - ], - "version": { - "minorNumber": 79 - }, - "extensionUrns": [ - { - "extensionUrnAnchor": 7, - "urn": "extension:io.substrait:functions_comparison" - }, - { - "extensionUrnAnchor": 9, - "urn": "extension:io.substrait:functions_arithmetic" - } - ] -} diff --git a/spark/src/test/resources/tests_subset_2023.csv b/spark/src/test/resources/tests_subset_2023.csv deleted file mode 100644 index 762d53491..000000000 --- a/spark/src/test/resources/tests_subset_2023.csv +++ /dev/null @@ -1,30 +0,0 @@ -test_id,vehicle_id,test_date,test_class,test_type,test_result,test_mileage,postcode_area -539514409,17113014,2023-01-09,4,NT,F,69934,PA -1122718877,986649781,2023-01-16,4,NT,F,57376,SG -1104881351,424684356,2023-03-06,4,NT,F,81853,SG -1487493049,1307056703,2023-03-07,4,NT,P,20763,SA -1107861883,130747047,2023-03-27,4,RT,P,125910,SA -472789285,777757523,2023-03-29,4,NT,P,68399,CO -1105082521,840180863,2023-04-15,4,NT,P,54240,NN -1172953135,917255260,2023-04-27,4,NT,P,60918,SM -127807783,888103385,2023-05-08,4,NT,P,112090,EH -1645970709,816803134,2023-06-03,4,NT,P,134858,RG -1355347761,919820431,2023-06-21,4,NT,P,37336,ST 
-1750209849,544950855,2023-06-23,4,NT,F,120034,NR -1376930435,439876988,2023-07-19,4,NT,P,109927,PO -582729949,1075446447,2023-07-19,4,NT,P,72986,SA -127953451,105663799,2023-07-31,4,NT,F,35824,ME -759291679,931759350,2023-08-07,4,NT,P,65353,DY -1629819891,335780567,2023-08-08,4,NT,PRS,103365,CF -1120026477,1153361746,2023-08-11,4,NT,P,286881,RM -1331300969,644861283,2023-08-15,4,NT,P,52173,LE -990694587,449899992,2023-08-16,4,NT,F,124891,SA -193460599,759696266,2023-08-29,4,NT,P,83554,LU -1337337679,1110416764,2023-10-09,4,NT,PRS,71093,SS -1885237527,137785384,2023-11-04,4,NT,P,88730,BH -1082642803,1291985882,2023-11-15,4,NT,PRS,160717,BA -896066743,615735063,2023-11-15,4,RT,P,107710,NR -1022666841,474362449,2023-11-20,4,NT,P,56296,HP -1010400923,1203222226,2023-12-04,4,NT,F,89255,TW -866705687,605696575,2023-12-06,4,NT,P,14674,YO -621751843,72093448,2023-12-14,4,NT,F,230280,TR diff --git a/spark/src/test/resources/vehicles_subset_2023.csv b/spark/src/test/resources/vehicles_subset_2023.csv deleted file mode 100644 index 087b54c84..000000000 --- a/spark/src/test/resources/vehicles_subset_2023.csv +++ /dev/null @@ -1,31 +0,0 @@ -vehicle_id,make,model,colour,fuel_type,cylinder_capacity,first_use_date -17113014,VAUXHALL,VIVARO,BLACK,DI,1995,2011-09-29 -986649781,VAUXHALL,INSIGNIA,WHITE,DI,1956,2017-07-19 -424684356,RENAULT,GRAND SCENIC,GREY,PE,1997,2010-07-19 -1307056703,RENAULT,CLIO,BLACK,DI,1461,2014-05-30 -130747047,FORD,FOCUS,SILVER,DI,1560,2013-07-10 -777757523,HYUNDAI,I10,WHITE,PE,998,2016-05-21 -840180863,BMW,1 SERIES,WHITE,PE,2979,2016-03-11 -917255260,VAUXHALL,ASTRA,WHITE,PE,1364,2012-04-21 -888103385,FORD,GALAXY,SILVER,DI,1997,2014-09-12 -816803134,FORD,FIESTA,BLUE,PE,1299,2002-10-24 -697184031,BMW,X1,WHITE,DI,1995,2016-03-31 -919820431,TOYOTA,AURIS,BRONZE,PE,1329,2015-06-29 -544950855,VAUXHALL,ASTRA,RED,DI,1956,2012-09-17 -439876988,MINI,MINI,GREEN,PE,1598,2010-03-31 -1075446447,CITROEN,C4,RED,DI,1560,2015-10-05 
-105663799,RENAULT,KADJAR,BLACK,PE,1332,2020-07-23 -931759350,FIAT,DUCATO,WHITE,DI,2199,2008-04-18 -335780567,HYUNDAI,I20,BLUE,PE,1396,2013-08-13 -1153361746,TOYOTA,PRIUS,SILVER,HY,1800,2010-06-23 -644861283,FORD,FIESTA,BLACK,PE,998,2015-09-03 -449899992,BMW,3 SERIES,GREEN,DI,2926,2006-09-30 -759696266,CITROEN,C4,BLUE,DI,1997,2011-12-19 -1110416764,CITROEN,XSARA,SILVER,DI,1997,1999-06-30 -137785384,MINI,MINI,GREY,DI,1598,2011-11-29 -1291985882,LAND ROVER,DEFENDER,BLUE,DI,2495,2002-06-12 -615735063,VOLKSWAGEN,CADDY,WHITE,DI,1598,2013-03-01 -474362449,VAUXHALL,GRANDLAND,GREY,PE,1199,2018-11-12 -1203222226,VAUXHALL,ASTRA,BLUE,PE,1598,2010-06-03 -605696575,SUZUKI,SWIFT SZ-T DUALJET MHEV CVT,RED,HY,1197,2020-12-18 -72093448,AUDI,A4,SILVER,DI,1896,2001-03-19 diff --git a/spark/src/test/scala/io/substrait/spark/AggregateWithJoinSuite.scala b/spark/src/test/scala/io/substrait/spark/AggregateWithJoinSuite.scala new file mode 100644 index 000000000..e65a4e769 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/AggregateWithJoinSuite.scala @@ -0,0 +1,275 @@ +package io.substrait.spark + +import io.substrait.dsl.SubstraitBuilder +import io.substrait.expression.{Expression, ExpressionCreator} +import io.substrait.extension.DefaultExtensionCatalog +import io.substrait.plan.Plan +import io.substrait.relation.{Aggregate, Join, Project, Rel, VirtualTableScan} +import io.substrait.spark.logical.ToLogicalPlan +import io.substrait.`type`.{NamedStruct, Type, TypeCreator} +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.classic.DatasetUtil +import org.apache.spark.sql.test.SharedSparkSession + +import java.util +import java.util.Arrays +import scala.jdk.CollectionConverters._ + +/** + * Test case that constructs a Substrait plan with aggregate operations using the builder syntax. 
+ * This test creates a plan equivalent to substrait_plan_with_aggregate_op.json but uses + * VirtualTableScan instead of NamedTable for input data. + * + * The plan structure is: + * - Root with output mapping [3, 4, 5] and names ["fuel_type", "postcode_area", "total_test_mileage"] + * - Project with output mapping [3, 4, 5] (selecting fields 0, 1, 2 from aggregate) + * - Aggregate grouping by fields 4 and 14, with sum(field 13) + * - Project with output mapping [15-29] (selecting specific fields from join) + * - Inner Join on tests.vehicle_id = vehicles.vehicle_id + * - Left: Project selecting all 8 fields from tests VirtualTableScan + * - Right: Project selecting all 7 fields from vehicles VirtualTableScan + */ +class AggregateWithJoinSuite extends SparkFunSuite with SharedSparkSession with SubstraitPlanTestBase { + + private val R = TypeCreator.REQUIRED + private val N = TypeCreator.NULLABLE + private val extensions = DefaultExtensionCatalog.DEFAULT_COLLECTION + private val sb = new SubstraitBuilder(extensions) + + private def createTestsRow( + testId: Int, + vehicleId: Int, + testDate: Int, + testClass: Int, + testType: String, + testResult: String, + testMileage: Int, + postcodeArea: String): Expression.NestedStruct = { + Expression.NestedStruct.builder() + .addFields(ExpressionCreator.i32(true, testId)) + .addFields(ExpressionCreator.i32(true, vehicleId)) + .addFields(ExpressionCreator.date(true, testDate)) + .addFields(ExpressionCreator.i32(true, testClass)) + .addFields(ExpressionCreator.string(true, testType)) + .addFields(ExpressionCreator.string(true, testResult)) + .addFields(ExpressionCreator.i32(true, testMileage)) + .addFields(ExpressionCreator.string(true, postcodeArea)) + .build() + } + + /** + * Creates a VirtualTableScan for the tests table with schema: + * test_id (i32), vehicle_id (i32), test_date (date), test_class (i32), + * test_type (varchar), test_result (varchar), test_mileage (i32), postcode_area (varchar) + */ + private def 
createTestsTable(): VirtualTableScan = { + val columnNames = Arrays.asList( + "test_id", "vehicle_id", "test_date", "test_class", + "test_type", "test_result", "test_mileage", "postcode_area") + + val struct = Type.Struct.builder() + .addFields(N.I32) // test_id + .addFields(N.I32) // vehicle_id + .addFields(N.DATE) // test_date + .addFields(N.I32) // test_class + .addFields(N.STRING) // test_type + .addFields(N.STRING) // test_result + .addFields(N.I32) // test_mileage + .addFields(N.STRING) // postcode_area + .nullable(false) + .build() + + val namedStruct = NamedStruct.of(columnNames, struct) + + VirtualTableScan.builder() + .initialSchema(namedStruct) + .addRows(createTestsRow(151, 100, 19000, 4, "MOT", "PASS", 50000, "GU")) + .addRows(createTestsRow(385, 102, 20000, 4, "MOT", "PASS", 15000, "GU")) + .addRows(createTestsRow(222, 101, 20000, 4, "MOT", "PASS", 20000, "PO")) + .addRows(createTestsRow(164, 101, 20000, 4, "MOT", "FAIL", 35000, "GU")) + .build() + } + + private def createVehiclesRow( + vehicleId: Int, + make: String, + model: String, + colour: String, + fuelType: String, + cylinderCapacity: Int, + firstUseDate: Int): Expression.NestedStruct = { + Expression.NestedStruct.builder() + .addFields(ExpressionCreator.i32(true, vehicleId)) + .addFields(ExpressionCreator.string(true, make)) + .addFields(ExpressionCreator.string(true, model)) + .addFields(ExpressionCreator.string(true, colour)) + .addFields(ExpressionCreator.string(true, fuelType)) + .addFields(ExpressionCreator.i32(true, cylinderCapacity)) + .addFields(ExpressionCreator.date(true, firstUseDate)) + .build() + } + + /** + * Creates a VirtualTableScan for the vehicles table with schema: + * vehicle_id (i32), make (varchar), model (varchar), colour (varchar), + * fuel_type (varchar), cylinder_capacity (i32), first_use_date (date) + */ + private def createVehiclesTable(): VirtualTableScan = { + val columnNames = Arrays.asList( + "vehicle_id", "make", "model", "colour", + "fuel_type", 
"cylinder_capacity", "first_use_date") + + val struct = Type.Struct.builder() + .addFields(N.I32) // vehicle_id + .addFields(N.STRING) // make + .addFields(N.STRING) // model + .addFields(N.STRING) // colour + .addFields(N.STRING) // fuel_type + .addFields(N.I32) // cylinder_capacity + .addFields(N.DATE) // first_use_date + .nullable(false) + .build() + + val namedStruct = NamedStruct.of(columnNames, struct) + + VirtualTableScan.builder() + .initialSchema(namedStruct) + .addRows(createVehiclesRow(100, "Ford", "Focus", "Blue", "Petrol", 1600, 18000)) + .addRows(createVehiclesRow(101, "VW", "Golf", "Red", "Diesel", 2000, 19000)) + .addRows(createVehiclesRow(100, "Ford", "Fiesta", "White", "Petrol", 1200, 18000)) + .build() + } + + test("aggregate with join plan structure") { + // Create base tables + val testsTable = createTestsTable() + val vehiclesTable = createVehiclesTable() + + // Project all fields from tests table with output mapping [8-15] + val leftProject = Project.builder() + .input(testsTable) + .addExpressions(sb.fieldReference(testsTable, 0)) // test_id + .addExpressions(sb.fieldReference(testsTable, 1)) // vehicle_id + .addExpressions(sb.fieldReference(testsTable, 2)) // test_date + .addExpressions(sb.fieldReference(testsTable, 3)) // test_class + .addExpressions(sb.fieldReference(testsTable, 4)) // test_type + .addExpressions(sb.fieldReference(testsTable, 5)) // test_result + .addExpressions(sb.fieldReference(testsTable, 6)) // test_mileage + .addExpressions(sb.fieldReference(testsTable, 7)) // postcode_area + .remap(Rel.Remap.of(Arrays.asList(8, 9, 10, 11, 12, 13, 14, 15))) + .build() + + // Project all fields from vehicles table with output mapping [7-13] + val rightProject = Project.builder() + .input(vehiclesTable) + .addExpressions(sb.fieldReference(vehiclesTable, 0)) // vehicle_id + .addExpressions(sb.fieldReference(vehiclesTable, 1)) // make + .addExpressions(sb.fieldReference(vehiclesTable, 2)) // model + 
.addExpressions(sb.fieldReference(vehiclesTable, 3)) // colour + .addExpressions(sb.fieldReference(vehiclesTable, 4)) // fuel_type + .addExpressions(sb.fieldReference(vehiclesTable, 5)) // cylinder_capacity + .addExpressions(sb.fieldReference(vehiclesTable, 6)) // first_use_date + .remap(Rel.Remap.of(Arrays.asList(7, 8, 9, 10, 11, 12, 13))) + .build() + + // Inner join on tests.vehicle_id = vehicles.vehicle_id + val join = sb.innerJoin( + (joinInput: SubstraitBuilder.JoinInput) => { + sb.equal( + sb.fieldReference(joinInput, 1), // tests.vehicle_id (field 1 from left) + sb.fieldReference(joinInput, 8)) // vehicles.vehicle_id (field 8 in combined output) + }, + leftProject, + rightProject + ) + + // Project after join with output mapping [15-29] + // Reorders to: vehicles fields (8-14) + tests fields (0-7) + val postJoinProject = Project.builder() + .input(join) + .addExpressions(sb.fieldReference(join, 8)) // vehicles.vehicle_id + .addExpressions(sb.fieldReference(join, 9)) // vehicles.make + .addExpressions(sb.fieldReference(join, 10)) // vehicles.model + .addExpressions(sb.fieldReference(join, 11)) // vehicles.colour + .addExpressions(sb.fieldReference(join, 12)) // vehicles.fuel_type + .addExpressions(sb.fieldReference(join, 13)) // vehicles.cylinder_capacity + .addExpressions(sb.fieldReference(join, 14)) // vehicles.first_use_date + .addExpressions(sb.fieldReference(join, 0)) // tests.test_id + .addExpressions(sb.fieldReference(join, 1)) // tests.vehicle_id + .addExpressions(sb.fieldReference(join, 2)) // tests.test_date + .addExpressions(sb.fieldReference(join, 3)) // tests.test_class + .addExpressions(sb.fieldReference(join, 4)) // tests.test_type + .addExpressions(sb.fieldReference(join, 5)) // tests.test_result + .addExpressions(sb.fieldReference(join, 6)) // tests.test_mileage + .addExpressions(sb.fieldReference(join, 7)) // tests.postcode_area + .remap(Rel.Remap.of(Arrays.asList(15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29))) + .build() + + 
// Create aggregate with grouping by fuel_type (field 4) and postcode_area (field 14) + // and sum of test_mileage (field 13) + val aggregate = sb.aggregate( + (input: Rel) => { + // Create grouping by fuel_type and postcode_area + Aggregate.Grouping.builder() + .addExpressions(sb.fieldReference(input, 4)) // fuel_type (field 4 in postJoinProject) + .addExpressions(sb.fieldReference(input, 14)) // postcode_area (field 14 in postJoinProject) + .build() + }, + (input: Rel) => { + // Create measure for sum of test_mileage + val sumMeasure = sb.sum(input, 13) // sum of test_mileage (field 13 in postJoinProject) + util.Arrays.asList(sumMeasure) + }, + postJoinProject + ) + + // Project to select final output fields with output mapping [3, 4, 5] + val finalProject = Project.builder() + .input(aggregate) + .addExpressions(sb.fieldReference(aggregate, 0)) // fuel_type (grouping key 0) + .addExpressions(sb.fieldReference(aggregate, 1)) // postcode_area (grouping key 1) + .addExpressions(sb.fieldReference(aggregate, 2)) // total_test_mileage (measure 0) + .remap(Rel.Remap.of(Arrays.asList(3, 4, 5))) + .build() + + // Wrap in a Plan with Root and output column names + val root = Plan.Root.builder() + .input(finalProject) + .addNames("fuel_type", "postcode_area", "total_test_mileage") + .build() + + val plan = Plan.builder() + .addRoots(root) + .build() + + val converter = new ToLogicalPlan(spark) + val sparkPlan = converter.convert(plan) + val output = DatasetUtil.fromLogicalPlan(spark, sparkPlan) + // val output = spark.sessionState.executePlan(sparkPlan).executedPlan.execute() + // there should be 3 rows and 3 columns + assertResult(3)(output.count()) + val rows = output.take(3) + println(sparkPlan) + + // should produce the output: + // +---------+-------------+------------------+ + // |fuel_type|postcode_area|total_test_mileage| + // +---------+-------------+------------------+ + // | Petrol| GU| 100000| + // | Diesel| PO| 20000| + // | Diesel| GU| 35000| + // 
+---------+-------------+------------------+ + + assertRow(rows(0), "Petrol", "GU", 100000) + assertRow(rows(1), "Diesel", "PO", 20000) + assertRow(rows(2), "Diesel", "GU", 35000) + } + + def assertRow(row: Row, fuelType: String, postcodeArea: String, totalTestMileage: Long): Unit = { + assertResult(fuelType)(row.getString(0)) + assertResult(postcodeArea)(row.getString(1)) + assertResult(totalTestMileage)(row.getLong(2)) + } +} diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala b/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala deleted file mode 100644 index ba017c6e2..000000000 --- a/spark/src/test/scala/io/substrait/spark/SubstraitJsonSuite.scala +++ /dev/null @@ -1,106 +0,0 @@ -package io.substrait.spark - -import io.substrait.plan.ProtoPlanConverter -import io.substrait.spark.logical.ToLogicalPlan -import org.apache.hadoop.shaded.com.google.gson.{Gson, JsonObject} -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -import java.io.IOException -import java.nio.file.{Files, Path, Paths} - -/** - * Contains test cases for the Substrait to Spark converter where the substrait plan is generated - * by a (trusted) third party tool. This deviates from the 'round-trip' methodology whereby - * all the substrait plans are generated by the Spark to Substrait converter in this repo. - * - * The Substrait plan is loaded from a proto file (JSON format). 
- */ -class SubstraitJsonSuite extends SparkFunSuite { - var spark: SparkSession = - SparkSession.builder().config("spark.master", "local").enableHiveSupport().getOrCreate() - - test("Aggregate") { - val substraitPlanPath = Path.of("../src/test/resources/substrait_plan_with_aggregate_op.json") - val defaultSchema = "test1" - - // Create table test1.vehicles - val vehiclesFile = "../src/test/resources/vehicles_subset_2023.csv" - val dsVehicles = spark.read.option("header", true).option("inferSchema", true).csv(Paths.get(vehiclesFile).toAbsolutePath.toString) - spark.sql("CREATE NAMESPACE IF NOT EXISTS test1") - spark.sql("DROP TABLE IF EXISTS test1.vehicles") - try dsVehicles.writeTo("test1.vehicles").create() - catch { - case _: TableAlreadyExistsException => - // Table already exists, ignore - } - - // Create table test1.tests - val testsFile = "../src/test/resources/tests_subset_2023.csv" - val dsTests = spark.read.option("header", true).option("inferSchema", true).csv(Paths.get(testsFile).toAbsolutePath.toString) - spark.sql("DROP TABLE IF EXISTS test1.tests") - try dsTests.writeTo("test1.tests").create() - catch { - case _: TableAlreadyExistsException => - // Table already exists, ignore - } - - val sparkPlan = toSpark(substraitPlanPath, defaultSchema) - val output = spark.sessionState.executePlan(sparkPlan).executedPlan.execute() - // there should be 25 rows and 3 columns - assertResult(25)(output.count()) - assertResult(3)(output.take(1)(0).numFields) - } - - test("nba") { - val substraitPlanPath = Path.of("../src/test/resources/substrait_plan_nba_california.json") - val defaultSchema = "nba" - - // Create table nba.team - val vehiclesFile = "../src/test/resources/nba_team.csv" - val dsVehicles = spark.read.option("header", true).option("inferSchema", true).csv(Paths.get(vehiclesFile).toAbsolutePath.toString) - spark.sql("CREATE NAMESPACE IF NOT EXISTS nba") - spark.sql("DROP TABLE IF EXISTS nba.team") - try dsVehicles.writeTo("nba.team").create() - catch { - 
case _: TableAlreadyExistsException => - // Table already exists, ignore - } - - // Create table nba.team_history - val testsFile = "../src/test/resources/nba_team_history.csv" - val dsTests = spark.read.option("header", true).option("inferSchema", true).csv(Paths.get(testsFile).toAbsolutePath.toString) - spark.sql("DROP TABLE IF EXISTS nba.team_history") - try dsTests.writeTo("nba.team_history").create() - catch { - case _: TableAlreadyExistsException => - // Table already exists, ignore - } - - val sparkPlan = toSpark(substraitPlanPath, defaultSchema) - val output = spark.sessionState.executePlan(sparkPlan).executedPlan.execute() - // there should be 13 rows and 21 columns - assertResult(13)(output.count()) - assertResult(21)(output.take(1)(0).numFields) - } - - @throws[IOException] - def toSpark(substraitPlanPath: Path, defaultSchema: String): LogicalPlan = { - val defaultCatalog = "spark_catalog" - val jsonContent = Files.readString(substraitPlanPath) - val gson = new Gson - val jsonObj = gson.fromJson(jsonContent, classOf[JsonObject]) - val builder = io.substrait.proto.Plan.newBuilder - com.google.protobuf.util.JsonFormat.parser.ignoringUnknownFields.merge(jsonObj.toString, builder) - val proto = builder.build - val protoToPlan = new ProtoPlanConverter - val plan = protoToPlan.from(proto) - spark.catalog.setCurrentCatalog(defaultCatalog) - spark.catalog.setCurrentDatabase(defaultSchema) - val substraitConverter = new ToLogicalPlan(spark) - substraitConverter.convert(plan) - } - -}