diff --git a/gluten-delta/src/main/scala/org/apache/gluten/extension/DeltaPostTransformRules.scala b/gluten-delta/src/main/scala/org/apache/gluten/extension/DeltaPostTransformRules.scala index da81bdf83f32..e16a6d12fdab 100644 --- a/gluten-delta/src/main/scala/org/apache/gluten/extension/DeltaPostTransformRules.scala +++ b/gluten-delta/src/main/scala/org/apache/gluten/extension/DeltaPostTransformRules.scala @@ -20,13 +20,16 @@ import org.apache.gluten.execution.{DeltaScanTransformer, ProjectExecTransformer import org.apache.gluten.extension.columnar.transition.RemoveTransitions import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CreateNamedStruct, Expression, GetStructField, If, InputFileBlockLength, InputFileBlockStart, InputFileName, IsNull, LambdaFunction, Literal, NamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, TransformKeys, TransformValues} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaParquetFileFormat, NoMapping} import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import scala.collection.mutable import scala.collection.mutable.ListBuffer object DeltaPostTransformRules { @@ -93,6 +96,73 @@ object DeltaPostTransformRules { } } + /** + * Checks whether two structurally compatible DataTypes have different struct field names at any + * nesting level. 
+ */ + private def nestedFieldNamesDiffer(logical: DataType, physical: DataType): Boolean = { + (logical, physical) match { + case (l: StructType, p: StructType) if l.length == p.length => + l.zip(p).exists { + case (lf, pf) => + lf.name != pf.name || nestedFieldNamesDiffer(lf.dataType, pf.dataType) + } + case (l: ArrayType, p: ArrayType) => + nestedFieldNamesDiffer(l.elementType, p.elementType) + case (l: MapType, p: MapType) => + nestedFieldNamesDiffer(l.keyType, p.keyType) || + nestedFieldNamesDiffer(l.valueType, p.valueType) + case _ => false + } + } + + /** + * Rebuilds an expression tree so that nested struct field names match the logical schema. Uses + * positional extraction (GetStructField) and reconstruction (CreateNamedStruct) instead of Cast, + * so correctness does not depend on Velox's cast_match_struct_by_name config. + */ + private def reconcileFieldNames( + expr: Expression, + logical: DataType, + physical: DataType): Expression = { + (logical, physical) match { + case (l: StructType, p: StructType) if l.length == p.length => + val rebuiltFields = l.zip(p).zipWithIndex.flatMap { + case ((lf, pf), i) => + val extracted = GetStructField(expr, i, None) + val reconciled = reconcileFieldNames(extracted, lf.dataType, pf.dataType) + Seq(Literal(lf.name), reconciled) + } + val rebuilt = CreateNamedStruct(rebuiltFields) + If(IsNull(expr), Literal.create(null, l), rebuilt) + case (l: ArrayType, p: ArrayType) if nestedFieldNamesDiffer(l.elementType, p.elementType) => + val lambdaVar = NamedLambdaVariable("element", p.elementType, p.containsNull) + val body = reconcileFieldNames(lambdaVar, l.elementType, p.elementType) + ArrayTransform(expr, LambdaFunction(body, Seq(lambdaVar))) + case (l: MapType, p: MapType) => + val needKeys = nestedFieldNamesDiffer(l.keyType, p.keyType) + val needValues = nestedFieldNamesDiffer(l.valueType, p.valueType) + var result = expr + if (needValues) { + val keyVar = NamedLambdaVariable("key", p.keyType, false) + val valueVar = 
NamedLambdaVariable("value", p.valueType, p.valueContainsNull) + val body = reconcileFieldNames(valueVar, l.valueType, p.valueType) + result = TransformValues(result, LambdaFunction(body, Seq(keyVar, valueVar))) + } + if (needKeys) { + val keyVar = NamedLambdaVariable("key", p.keyType, false) + val valueVar = NamedLambdaVariable( + "value", + if (needValues) l.valueType else p.valueType, + p.valueContainsNull) + val body = reconcileFieldNames(keyVar, l.keyType, p.keyType) + result = TransformKeys(result, LambdaFunction(body, Seq(keyVar, valueVar))) + } + result + case _ => expr + } + } + /** * This method is only used for Delta ColumnMapping FileFormat(e.g. nameMapping and idMapping) * transform the metadata of Delta into Parquet's, each plan should only be transformed once. @@ -115,8 +185,9 @@ object DeltaPostTransformRules { )(SparkSession.active) // transform output's name into physical name so Reader can read data correctly // should keep the columns order the same as the origin output - val originColumnNames = ListBuffer.empty[String] - val transformedAttrs = ListBuffer.empty[Attribute] + case class ColumnMapping(logicalName: String, logicalType: DataType, physicalAttr: Attribute) + val columnMappings = ListBuffer.empty[ColumnMapping] + val seenNames = mutable.Set.empty[String] def mapAttribute(attr: Attribute) = { val newAttr = if (plan.isMetadataColumn(attr)) { attr @@ -127,9 +198,8 @@ object DeltaPostTransformRules { .createPhysicalAttributes(Seq(attr), fmt.referenceSchema, fmt.columnMappingMode) .head } - if (!originColumnNames.contains(attr.name)) { - transformedAttrs += newAttr - originColumnNames += attr.name + if (seenNames.add(attr.name)) { + columnMappings += ColumnMapping(attr.name, attr.dataType, newAttr) } newAttr } @@ -169,9 +239,20 @@ object DeltaPostTransformRules { scanExecTransformer.copyTagsFrom(plan) tagColumnMappingRule(scanExecTransformer) - // alias physicalName into tableName - val expr = (transformedAttrs, originColumnNames).zipped.map 
{ - (attr, columnName) => Alias(attr, columnName)(exprId = attr.exprId) + // Alias physical names back to logical names. For struct-typed columns, Delta column + // mapping renames internal field names to physical UUIDs. A top-level Alias only restores + // the column name, not the struct's internal field names. We rebuild the struct with + // logical field names using positional extraction (GetStructField/CreateNamedStruct) + // instead of Cast, so correctness does not depend on any Velox cast config. + val expr = columnMappings.map { + cm => + val projectedExpr: Expression = + if (nestedFieldNamesDiffer(cm.logicalType, cm.physicalAttr.dataType)) { + reconcileFieldNames(cm.physicalAttr, cm.logicalType, cm.physicalAttr.dataType) + } else { + cm.physicalAttr + } + Alias(projectedExpr, cm.logicalName)(exprId = cm.physicalAttr.exprId) } val projectExecTransformer = ProjectExecTransformer(expr.toSeq, scanExecTransformer) projectExecTransformer diff --git a/gluten-delta/src/test/scala/org/apache/gluten/execution/DeltaSuite.scala b/gluten-delta/src/test/scala/org/apache/gluten/execution/DeltaSuite.scala index 8b4d7b374d5b..031bf460347d 100644 --- a/gluten-delta/src/test/scala/org/apache/gluten/execution/DeltaSuite.scala +++ b/gluten-delta/src/test/scala/org/apache/gluten/execution/DeltaSuite.scala @@ -399,4 +399,180 @@ abstract class DeltaSuite extends WholeStageTransformerSuite { checkAnswer(df, Seq(Row(2), Row(3))) } } + + testWithMinSparkVersion( + "merge with column mapping handles struct field metadata correctly", + "3.4") { + withTable("merge_struct_source", "merge_struct_target") { + spark.sql(""" + |CREATE TABLE merge_struct_target( + | key INT NOT NULL, + | value INT NOT NULL, + | cstruct STRUCT<foo: INT>) + |USING DELTA + |TBLPROPERTIES ( + | 'delta.minReaderVersion' = '2', + | 'delta.minWriterVersion' = '5', + | 'delta.columnMapping.mode' = 'name') + """.stripMargin) + spark.sql("INSERT INTO merge_struct_target VALUES (0, 0, null)") + spark.sql("INSERT INTO 
merge_struct_target VALUES (100, 100, named_struct('foo', 42))") + + spark.sql( + "CREATE TABLE merge_struct_source (key INT NOT NULL, value INT NOT NULL) USING DELTA") + spark.sql("INSERT INTO merge_struct_source VALUES (1, 1)") + + // MERGE with updateNotMatched to test CaseWhen else branch + spark.sql(""" + |MERGE INTO merge_struct_target AS target + |USING merge_struct_source AS source + |ON source.key = target.key + |WHEN MATCHED THEN + | UPDATE SET target.value = source.value + |WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN + | UPDATE SET target.value = 22 + """.stripMargin) + + val df = runQueryAndCompare( + "SELECT key, value, cstruct FROM merge_struct_target ORDER BY key") { _ => } + checkAnswer(df, Row(0, 0, null) :: Row(100, 22, Row(42)) :: Nil) + } + } + + testWithMinSparkVersion( + "merge with column mapping handles array-of-struct field metadata correctly", + "3.4") { + withTable("merge_arraystruct_source", "merge_arraystruct_target") { + spark.sql(""" + |CREATE TABLE merge_arraystruct_target( + | key INT NOT NULL, + | tags ARRAY<STRUCT<label: STRING, score: INT>>) + |USING DELTA + |TBLPROPERTIES ( + | 'delta.minReaderVersion' = '2', + | 'delta.minWriterVersion' = '5', + | 'delta.columnMapping.mode' = 'name') + """.stripMargin) + spark.sql("INSERT INTO merge_arraystruct_target VALUES (0, null)") + spark.sql( + "INSERT INTO merge_arraystruct_target VALUES " + + "(100, array(named_struct('label', 'a', 'score', 10)))") + spark.sql("CREATE TABLE merge_arraystruct_source (key INT NOT NULL) USING DELTA") + spark.sql("INSERT INTO merge_arraystruct_source VALUES (1)") + // MERGE that leaves the array-of-struct column unchanged via CaseWhen + spark.sql(""" + |MERGE INTO merge_arraystruct_target AS target + |USING merge_arraystruct_source AS source + |ON source.key = target.key + |WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN + | UPDATE SET target.key = 101 + """.stripMargin) + val df = runQueryAndCompare("SELECT key, tags FROM merge_arraystruct_target ORDER BY key") { + _ => + } 
+ checkAnswer(df, Row(0, null) :: Row(101, Seq(Row("a", 10))) :: Nil) + } + } + + testWithMinSparkVersion( + "merge with column mapping handles map-of-struct field metadata correctly", + "3.4") { + withTable("merge_mapstruct_source", "merge_mapstruct_target") { + spark.sql(""" + |CREATE TABLE merge_mapstruct_target( + | key INT NOT NULL, + | props MAP<STRING, STRUCT<val: INT>>) + |USING DELTA + |TBLPROPERTIES ( + | 'delta.minReaderVersion' = '2', + | 'delta.minWriterVersion' = '5', + | 'delta.columnMapping.mode' = 'name') + """.stripMargin) + spark.sql("INSERT INTO merge_mapstruct_target VALUES (0, null)") + spark.sql( + "INSERT INTO merge_mapstruct_target VALUES " + + "(100, map('x', named_struct('val', 99)))") + spark.sql("CREATE TABLE merge_mapstruct_source (key INT NOT NULL) USING DELTA") + spark.sql("INSERT INTO merge_mapstruct_source VALUES (1)") + // MERGE that leaves the map-of-struct column unchanged via CaseWhen + spark.sql(""" + |MERGE INTO merge_mapstruct_target AS target + |USING merge_mapstruct_source AS source + |ON source.key = target.key + |WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN + | UPDATE SET target.key = 101 + """.stripMargin) + val df = runQueryAndCompare("SELECT key, props FROM merge_mapstruct_target ORDER BY key") { + _ => + } + checkAnswer(df, Row(0, null) :: Row(101, Map("x" -> Row(99))) :: Nil) + } + } + + testWithMinSparkVersion( + "merge with column mapping handles nested struct-within-struct field metadata correctly", + "3.4") { + withTable("merge_nestedstruct_source", "merge_nestedstruct_target") { + spark.sql(""" + |CREATE TABLE merge_nestedstruct_target( + | key INT NOT NULL, + | nested STRUCT<outer_val: STRUCT<inner_val: INT>>) + |USING DELTA + |TBLPROPERTIES ( + | 'delta.minReaderVersion' = '2', + | 'delta.minWriterVersion' = '5', + | 'delta.columnMapping.mode' = 'name') + """.stripMargin) + spark.sql("INSERT INTO merge_nestedstruct_target VALUES (0, null)") + spark.sql( + "INSERT INTO merge_nestedstruct_target VALUES " + + "(100, named_struct('outer_val', 
named_struct('inner_val', 42)))") + spark.sql("CREATE TABLE merge_nestedstruct_source (key INT NOT NULL) USING DELTA") + spark.sql("INSERT INTO merge_nestedstruct_source VALUES (1)") + spark.sql(""" + |MERGE INTO merge_nestedstruct_target AS target + |USING merge_nestedstruct_source AS source + |ON source.key = target.key + |WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN + | UPDATE SET target.key = 101 + """.stripMargin) + val df = runQueryAndCompare( + "SELECT key, nested FROM merge_nestedstruct_target ORDER BY key") { _ => } + checkAnswer(df, Row(0, null) :: Row(101, Row(Row(42))) :: Nil) + } + } + + testWithMinSparkVersion( + "merge with column mapping handles array with null struct elements correctly", + "3.4") { + withTable("merge_arraynull_source", "merge_arraynull_target") { + spark.sql(""" + |CREATE TABLE merge_arraynull_target( + | key INT NOT NULL, + | items ARRAY<STRUCT<name: STRING, qty: INT>>) + |USING DELTA + |TBLPROPERTIES ( + | 'delta.minReaderVersion' = '2', + | 'delta.minWriterVersion' = '5', + | 'delta.columnMapping.mode' = 'name') + """.stripMargin) + spark.sql("INSERT INTO merge_arraynull_target VALUES (0, null)") + spark.sql( + "INSERT INTO merge_arraynull_target VALUES " + + "(100, array(named_struct('name', 'a', 'qty', 1), null))") + spark.sql("CREATE TABLE merge_arraynull_source (key INT NOT NULL) USING DELTA") + spark.sql("INSERT INTO merge_arraynull_source VALUES (1)") + spark.sql(""" + |MERGE INTO merge_arraynull_target AS target + |USING merge_arraynull_source AS source + |ON source.key = target.key + |WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN + | UPDATE SET target.key = 101 + """.stripMargin) + val df = runQueryAndCompare("SELECT key, items FROM merge_arraynull_target ORDER BY key") { + _ => + } + checkAnswer(df, Row(0, null) :: Row(101, Seq(Row("a", 1), null)) :: Nil) + } + } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala index b15d32a6e12a..91e3112c4aed 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseProjectExecTransformer.scala @@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar import org.apache.gluten.config.GlutenConfig import org.apache.gluten.execution.ProjectExecTransformer -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, Expression, If, NamedExpression} import org.apache.spark.sql.catalyst.optimizer.CollapseProjectShim import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan @@ -56,11 +56,22 @@ object CollapseProjectExecTransformer extends Rule[SparkPlan] { /** * In Velox, CreateNamedStruct will generate a special output named obj, We cannot collapse such - * project transformer, otherwise it will result in a bind reference failure. + * project transformer, otherwise it will result in a bind reference failure. Checks for + * CreateNamedStruct as direct Alias child or wrapped in If (null-guard pattern from + * reconcileFieldNames), but not deeply nested inside arbitrary expressions. */ private def containsNamedStructAlias(projectList: Seq[NamedExpression]): Boolean = { projectList.exists { - case a: Alias => a.child.exists(_.isInstanceOf[CreateNamedStruct]) + case a: Alias => isOrWrapsCreateNamedStruct(a.child) + case _ => false + } + } + + private def isOrWrapsCreateNamedStruct(expr: Expression): Boolean = { + expr match { + case _: CreateNamedStruct => true + case If(_, trueValue, falseValue) => + isOrWrapsCreateNamedStruct(trueValue) || isOrWrapsCreateNamedStruct(falseValue) case _ => false } }