Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ import org.apache.gluten.execution.{DeltaScanTransformer, ProjectExecTransformer
import org.apache.gluten.extension.columnar.transition.RemoveTransitions

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CreateNamedStruct, Expression, GetStructField, If, InputFileBlockLength, InputFileBlockStart, InputFileName, IsNull, LambdaFunction, Literal, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, TransformKeys, TransformValues}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.delta.{DeltaColumnMapping, DeltaParquetFileFormat, NoMapping}
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

object DeltaPostTransformRules {
Expand Down Expand Up @@ -93,6 +96,73 @@ object DeltaPostTransformRules {
}
}

/**
 * Returns true when two structurally compatible DataTypes disagree on a struct field name at
 * any nesting level: top-level struct fields, array elements, or map keys/values.
 *
 * Structs of differing arity (or any other mismatched shapes) are treated as "no difference"
 * here, since they are not candidates for positional field-name reconciliation anyway.
 */
private def nestedFieldNamesDiffer(logical: DataType, physical: DataType): Boolean =
  (logical, physical) match {
    case (l: StructType, p: StructType) if l.length == p.length =>
      l.fields.indices.exists { i =>
        val lf = l.fields(i)
        val pf = p.fields(i)
        lf.name != pf.name || nestedFieldNamesDiffer(lf.dataType, pf.dataType)
      }
    case (l: ArrayType, p: ArrayType) =>
      nestedFieldNamesDiffer(l.elementType, p.elementType)
    case (l: MapType, p: MapType) =>
      nestedFieldNamesDiffer(l.keyType, p.keyType) ||
        nestedFieldNamesDiffer(l.valueType, p.valueType)
    // Leaf types or incompatible shapes: nothing to compare.
    case _ => false
  }

/**
 * Rebuilds an expression tree so that nested struct field names match the logical schema. Uses
 * positional extraction (GetStructField) and reconstruction (CreateNamedStruct) instead of Cast,
 * so correctness does not depend on Velox's cast_match_struct_by_name config.
 *
 * @param expr expression producing a value of type `physical`
 * @param logical the logical (table-schema) type whose field names should be surfaced
 * @param physical the physical (on-disk, column-mapped) type that `expr` currently carries
 * @return an expression of the same shape as `expr` whose struct field names follow `logical`
 */
private def reconcileFieldNames(
expr: Expression,
logical: DataType,
physical: DataType): Expression = {
(logical, physical) match {
// Structs: extract every field by ordinal from the physical value and re-wrap it under the
// logical field name; the recursive call renames any deeper levels.
case (l: StructType, p: StructType) if l.length == p.length =>
val rebuiltFields = l.zip(p).zipWithIndex.flatMap {
case ((lf, pf), i) =>
val extracted = GetStructField(expr, i, None)
val reconciled = reconcileFieldNames(extracted, lf.dataType, pf.dataType)
// CreateNamedStruct takes alternating (name, value) children.
Seq(Literal(lf.name), reconciled)
}
val rebuilt = CreateNamedStruct(rebuiltFields)
// CreateNamedStruct itself never evaluates to null, so an explicit IsNull guard is
// needed to preserve the nullability of the original struct value.
If(IsNull(expr), Literal.create(null, l), rebuilt)
// Arrays: only wrap in ArrayTransform when the element type actually needs renaming,
// avoiding no-op lambdas in the plan.
case (l: ArrayType, p: ArrayType) if nestedFieldNamesDiffer(l.elementType, p.elementType) =>
val lambdaVar = NamedLambdaVariable("element", p.elementType, p.containsNull)
val body = reconcileFieldNames(lambdaVar, l.elementType, p.elementType)
ArrayTransform(expr, LambdaFunction(body, Seq(lambdaVar)))
// Maps: keys and values may each need renaming. Values are transformed first, then keys;
// the order matters for the lambda-variable typing below.
case (l: MapType, p: MapType) =>
val needKeys = nestedFieldNamesDiffer(l.keyType, p.keyType)
val needValues = nestedFieldNamesDiffer(l.valueType, p.valueType)
var result = expr
if (needValues) {
// Map keys are never null, hence the `false` nullability on the key lambda variable.
val keyVar = NamedLambdaVariable("key", p.keyType, false)
val valueVar = NamedLambdaVariable("value", p.valueType, p.valueContainsNull)
val body = reconcileFieldNames(valueVar, l.valueType, p.valueType)
result = TransformValues(result, LambdaFunction(body, Seq(keyVar, valueVar)))
}
if (needKeys) {
val keyVar = NamedLambdaVariable("key", p.keyType, false)
// If TransformValues already ran above, the map's value type has become the logical
// one, so the (unused-but-bound) value lambda variable must be typed accordingly.
val valueVar = NamedLambdaVariable(
"value",
if (needValues) l.valueType else p.valueType,
p.valueContainsNull)
val body = reconcileFieldNames(keyVar, l.keyType, p.keyType)
result = TransformKeys(result, LambdaFunction(body, Seq(keyVar, valueVar)))
}
result
// Leaf types, or shapes nestedFieldNamesDiffer treats as incompatible: nothing to rename.
case _ => expr
}
}

/**
* This method is only used for Delta ColumnMapping FileFormat(e.g. nameMapping and idMapping)
* transform the metadata of Delta into Parquet's, each plan should only be transformed once.
Expand All @@ -115,8 +185,9 @@ object DeltaPostTransformRules {
)(SparkSession.active)
// transform output's name into physical name so Reader can read data correctly
// should keep the columns order the same as the origin output
val originColumnNames = ListBuffer.empty[String]
val transformedAttrs = ListBuffer.empty[Attribute]
case class ColumnMapping(logicalName: String, logicalType: DataType, physicalAttr: Attribute)
val columnMappings = ListBuffer.empty[ColumnMapping]
val seenNames = mutable.Set.empty[String]
def mapAttribute(attr: Attribute) = {
val newAttr = if (plan.isMetadataColumn(attr)) {
attr
Expand All @@ -127,9 +198,8 @@ object DeltaPostTransformRules {
.createPhysicalAttributes(Seq(attr), fmt.referenceSchema, fmt.columnMappingMode)
.head
}
if (!originColumnNames.contains(attr.name)) {
transformedAttrs += newAttr
originColumnNames += attr.name
if (seenNames.add(attr.name)) {
columnMappings += ColumnMapping(attr.name, attr.dataType, newAttr)
}
newAttr
}
Expand Down Expand Up @@ -169,9 +239,20 @@ object DeltaPostTransformRules {
scanExecTransformer.copyTagsFrom(plan)
tagColumnMappingRule(scanExecTransformer)

// alias physicalName into tableName
val expr = (transformedAttrs, originColumnNames).zipped.map {
(attr, columnName) => Alias(attr, columnName)(exprId = attr.exprId)
// Alias physical names back to logical names. For struct-typed columns, Delta column
// mapping renames internal field names to physical UUIDs. A top-level Alias only restores
// the column name, not the struct's internal field names. We rebuild the struct with
// logical field names using positional extraction (GetStructField/CreateNamedStruct)
// instead of Cast, so correctness does not depend on any Velox cast config.
val expr = columnMappings.map {
cm =>
val projectedExpr: Expression =
if (nestedFieldNamesDiffer(cm.logicalType, cm.physicalAttr.dataType)) {
reconcileFieldNames(cm.physicalAttr, cm.logicalType, cm.physicalAttr.dataType)
} else {
cm.physicalAttr
}
Alias(projectedExpr, cm.logicalName)(exprId = cm.physicalAttr.exprId)
}
val projectExecTransformer = ProjectExecTransformer(expr.toSeq, scanExecTransformer)
projectExecTransformer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,4 +399,180 @@ abstract class DeltaSuite extends WholeStageTransformerSuite {
checkAnswer(df, Seq(Row(2), Row(3)))
}
}

// Regression test for Delta column mapping ('name' mode): on disk, columns and struct
// fields are stored under physical names, so the scan must restore the logical names.
// Covers a struct-typed column, a null struct row, and the NOT MATCHED BY SOURCE path.
testWithMinSparkVersion(
"merge with column mapping handles struct field metadata correctly",
"3.4") {
withTable("merge_struct_source", "merge_struct_target") {
spark.sql("""
|CREATE TABLE merge_struct_target(
| key INT NOT NULL,
| value INT NOT NULL,
| cstruct STRUCT<foo: INT>)
|USING DELTA
|TBLPROPERTIES (
| 'delta.minReaderVersion' = '2',
| 'delta.minWriterVersion' = '5',
| 'delta.columnMapping.mode' = 'name')
""".stripMargin)
spark.sql("INSERT INTO merge_struct_target VALUES (0, 0, null)")
spark.sql("INSERT INTO merge_struct_target VALUES (100, 100, named_struct('foo', 42))")

spark.sql(
"CREATE TABLE merge_struct_source (key INT NOT NULL, value INT NOT NULL) USING DELTA")
spark.sql("INSERT INTO merge_struct_source VALUES (1, 1)")

// MERGE with updateNotMatched to test CaseWhen else branch
spark.sql("""
|MERGE INTO merge_struct_target AS target
|USING merge_struct_source AS source
|ON source.key = target.key
|WHEN MATCHED THEN
| UPDATE SET target.value = source.value
|WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN
| UPDATE SET target.value = 22
""".stripMargin)

// Key 100 hits the NOT-MATCHED-BY-SOURCE branch (value 100 -> 22); the struct column
// must read back as Row(42) under its logical field name.
val df = runQueryAndCompare(
"SELECT key, value, cstruct FROM merge_struct_target ORDER BY key") { _ => }
checkAnswer(df, Row(0, 0, null) :: Row(100, 22, Row(42)) :: Nil)
}
}

// Same column-mapping regression scenario as above, but for an ARRAY<STRUCT<...>> column:
// the array's struct elements are stored under physical field names, and the unchanged
// column must survive the MERGE and read back with logical field names.
testWithMinSparkVersion(
"merge with column mapping handles array-of-struct field metadata correctly",
"3.4") {
withTable("merge_arraystruct_source", "merge_arraystruct_target") {
spark.sql("""
|CREATE TABLE merge_arraystruct_target(
| key INT NOT NULL,
| tags ARRAY<STRUCT<label: STRING, score: INT>>)
|USING DELTA
|TBLPROPERTIES (
| 'delta.minReaderVersion' = '2',
| 'delta.minWriterVersion' = '5',
| 'delta.columnMapping.mode' = 'name')
""".stripMargin)
spark.sql("INSERT INTO merge_arraystruct_target VALUES (0, null)")
spark.sql(
"INSERT INTO merge_arraystruct_target VALUES " +
"(100, array(named_struct('label', 'a', 'score', 10)))")
spark.sql("CREATE TABLE merge_arraystruct_source (key INT NOT NULL) USING DELTA")
spark.sql("INSERT INTO merge_arraystruct_source VALUES (1)")
// MERGE that leaves the array-of-struct column unchanged via CaseWhen
spark.sql("""
|MERGE INTO merge_arraystruct_target AS target
|USING merge_arraystruct_source AS source
|ON source.key = target.key
|WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN
| UPDATE SET target.key = 101
""".stripMargin)
val df = runQueryAndCompare("SELECT key, tags FROM merge_arraystruct_target ORDER BY key") {
_ =>
}
checkAnswer(df, Row(0, null) :: Row(101, Seq(Row("a", 10))) :: Nil)
}
}

// Column-mapping regression scenario for a MAP<STRING, STRUCT<...>> column: struct-typed
// map values are stored under physical field names; the unchanged column must read back
// with the logical field name after a MERGE rewrites the row.
testWithMinSparkVersion(
"merge with column mapping handles map-of-struct field metadata correctly",
"3.4") {
withTable("merge_mapstruct_source", "merge_mapstruct_target") {
spark.sql("""
|CREATE TABLE merge_mapstruct_target(
| key INT NOT NULL,
| props MAP<STRING, STRUCT<val: INT>>)
|USING DELTA
|TBLPROPERTIES (
| 'delta.minReaderVersion' = '2',
| 'delta.minWriterVersion' = '5',
| 'delta.columnMapping.mode' = 'name')
""".stripMargin)
spark.sql("INSERT INTO merge_mapstruct_target VALUES (0, null)")
spark.sql(
"INSERT INTO merge_mapstruct_target VALUES " +
"(100, map('x', named_struct('val', 99)))")
spark.sql("CREATE TABLE merge_mapstruct_source (key INT NOT NULL) USING DELTA")
spark.sql("INSERT INTO merge_mapstruct_source VALUES (1)")
// MERGE that leaves the map-of-struct column unchanged via CaseWhen
spark.sql("""
|MERGE INTO merge_mapstruct_target AS target
|USING merge_mapstruct_source AS source
|ON source.key = target.key
|WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN
| UPDATE SET target.key = 101
""".stripMargin)
val df = runQueryAndCompare("SELECT key, props FROM merge_mapstruct_target ORDER BY key") {
_ =>
}
checkAnswer(df, Row(0, null) :: Row(101, Map("x" -> Row(99))) :: Nil)
}
}

// Column-mapping regression scenario for a struct nested inside another struct: field-name
// reconciliation must recurse through every nesting level, not just the top-level struct.
testWithMinSparkVersion(
"merge with column mapping handles nested struct-within-struct field metadata correctly",
"3.4") {
withTable("merge_nestedstruct_source", "merge_nestedstruct_target") {
spark.sql("""
|CREATE TABLE merge_nestedstruct_target(
| key INT NOT NULL,
| nested STRUCT<outer_val: STRUCT<inner_val: INT>>)
|USING DELTA
|TBLPROPERTIES (
| 'delta.minReaderVersion' = '2',
| 'delta.minWriterVersion' = '5',
| 'delta.columnMapping.mode' = 'name')
""".stripMargin)
spark.sql("INSERT INTO merge_nestedstruct_target VALUES (0, null)")
spark.sql(
"INSERT INTO merge_nestedstruct_target VALUES " +
"(100, named_struct('outer_val', named_struct('inner_val', 42)))")
spark.sql("CREATE TABLE merge_nestedstruct_source (key INT NOT NULL) USING DELTA")
spark.sql("INSERT INTO merge_nestedstruct_source VALUES (1)")
// MERGE rewrites key 100 -> 101 but leaves the nested struct column untouched.
spark.sql("""
|MERGE INTO merge_nestedstruct_target AS target
|USING merge_nestedstruct_source AS source
|ON source.key = target.key
|WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN
| UPDATE SET target.key = 101
""".stripMargin)
val df = runQueryAndCompare(
"SELECT key, nested FROM merge_nestedstruct_target ORDER BY key") { _ => }
checkAnswer(df, Row(0, null) :: Row(101, Row(Row(42))) :: Nil)
}
}

// Column-mapping regression scenario for an array that contains a NULL struct element:
// the per-element struct rebuild must preserve element-level nulls instead of turning
// them into structs of nulls.
testWithMinSparkVersion(
"merge with column mapping handles array with null struct elements correctly",
"3.4") {
withTable("merge_arraynull_source", "merge_arraynull_target") {
spark.sql("""
|CREATE TABLE merge_arraynull_target(
| key INT NOT NULL,
| items ARRAY<STRUCT<name: STRING, qty: INT>>)
|USING DELTA
|TBLPROPERTIES (
| 'delta.minReaderVersion' = '2',
| 'delta.minWriterVersion' = '5',
| 'delta.columnMapping.mode' = 'name')
""".stripMargin)
spark.sql("INSERT INTO merge_arraynull_target VALUES (0, null)")
spark.sql(
"INSERT INTO merge_arraynull_target VALUES " +
"(100, array(named_struct('name', 'a', 'qty', 1), null))")
spark.sql("CREATE TABLE merge_arraynull_source (key INT NOT NULL) USING DELTA")
spark.sql("INSERT INTO merge_arraynull_source VALUES (1)")
// MERGE rewrites key 100 -> 101 but leaves the array column untouched.
spark.sql("""
|MERGE INTO merge_arraynull_target AS target
|USING merge_arraynull_source AS source
|ON source.key = target.key
|WHEN NOT MATCHED BY SOURCE AND target.key = 100 THEN
| UPDATE SET target.key = 101
""".stripMargin)
// The trailing null element must survive as null, not become Row(null, null).
val df = runQueryAndCompare("SELECT key, items FROM merge_arraynull_target ORDER BY key") {
_ =>
}
checkAnswer(df, Row(0, null) :: Row(101, Seq(Row("a", 1), null)) :: Nil)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar
import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.execution.ProjectExecTransformer

import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, Expression, If, NamedExpression}
import org.apache.spark.sql.catalyst.optimizer.CollapseProjectShim
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
Expand Down Expand Up @@ -56,11 +56,22 @@ object CollapseProjectExecTransformer extends Rule[SparkPlan] {

/**
 * In Velox, CreateNamedStruct generates a special output named obj, so we cannot collapse
 * such a project transformer — doing so would result in a bind reference failure. Detects
 * CreateNamedStruct either as the direct child of an Alias or wrapped in an If (the
 * null-guard pattern from reconcileFieldNames), but not deeply nested inside arbitrary
 * expressions.
 */
private def containsNamedStructAlias(projectList: Seq[NamedExpression]): Boolean =
  projectList.exists {
    case Alias(child, _) => isOrWrapsCreateNamedStruct(child)
    case _ => false
  }

/**
 * Returns true if the expression is a CreateNamedStruct, possibly sitting behind one or
 * more If expressions (checked recursively on both branches); any other shape is false.
 */
private def isOrWrapsCreateNamedStruct(expr: Expression): Boolean =
  expr match {
    case _: CreateNamedStruct => true
    case branch: If =>
      isOrWrapsCreateNamedStruct(branch.trueValue) ||
        isOrWrapsCreateNamedStruct(branch.falseValue)
    case _ => false
  }
Expand Down
Loading