diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index e7aca4c4ae3..2e53f0e936d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -955,6 +955,7 @@ object GpuOverrides extends Logging { + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all), (a, conf, p, r) => new UnaryAstExprMeta[Alias](a, conf, p, r) { + override def typeMeta: DataTypeMeta = childExprs.head.typeMeta override def convertToGpu(child: Expression): GpuExpression = GpuAlias(child, a.name)(a.exprId, a.qualifier, a.explicitMetadata) }), @@ -2628,10 +2629,7 @@ object GpuOverrides extends Logging { TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY), TypeSig.STRUCT.nested(TypeSig.all)), - (expr, conf, p, r) => new UnaryExprMeta[GetStructField](expr, conf, p, r) { - override def convertToGpu(arr: Expression): GpuExpression = - GpuGetStructField(arr, expr.ordinal, expr.name) - }), + (expr, conf, p, r) => new GpuGetStructFieldMeta(expr, conf, p, r)), expr[GetArrayItem]( "Gets the field at `ordinal` in the Array", ExprChecks.binaryProject( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 9afb3b854d6..fde29b0afe2 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -402,6 +402,36 @@ case class GpuArrayPosition(left: Expression, right: Expression) } } +object GpuStructFieldOrdinalTag { + val PRUNED_ORDINAL_TAG = + new org.apache.spark.sql.catalyst.trees.TreeNodeTag[Int]("GPU_PRUNED_ORDINAL") +} + +class GpuGetStructFieldMeta( + expr: GetStructField, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends UnaryExprMeta[GetStructField](expr, conf, parent, rule) { + + override def convertToGpu(child: Expression): GpuExpression = { + val effectiveOrd = GpuGetStructFieldMeta.effectiveOrdinal(expr) + GpuGetStructField(child, effectiveOrd, expr.name) + } +} + +object GpuGetStructFieldMeta { + def effectiveOrdinal(expr: GetStructField): Int = { + val runtimeOrd = expr.getTagValue( + GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1) + if (runtimeOrd >= 0) { + runtimeOrd + } else { + expr.ordinal + } + } +} + class GpuGetArrayStructFieldsMeta( expr: GetArrayStructFields, conf: RapidsConf, @@ -409,8 +439,31 @@ class GpuGetArrayStructFieldsMeta( rule: DataFromReplacementRule) extends UnaryExprMeta[GetArrayStructFields](expr, conf, parent, rule) { - def convertToGpu(child: Expression): GpuExpression = - GpuGetArrayStructFields(child, expr.field, expr.ordinal, expr.numFields, expr.containsNull) + override def convertToGpu(child: Expression): GpuExpression = { + val runtimeOrd = expr.getTagValue( + GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1) + val effectiveOrd = if (runtimeOrd >= 0) runtimeOrd else expr.ordinal + val effectiveNumFields = + GpuGetArrayStructFieldsMeta.effectiveNumFields(child, expr, runtimeOrd) + GpuGetArrayStructFields(child, expr.field, + effectiveOrd, effectiveNumFields, expr.containsNull) + } +} + +object GpuGetArrayStructFieldsMeta { + def effectiveNumFields( + child: Expression, + expr: GetArrayStructFields, + runtimeOrd: Int): Int = { + if (runtimeOrd >= 0) { + child.dataType match { + case ArrayType(st: StructType, _) => st.fields.length + case _ => expr.numFields + } + } else { + expr.numFields + } + } } /** diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/StructFieldOrdinalTagSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/StructFieldOrdinalTagSuite.scala new file mode 100644 index 00000000000..0c722b24486 --- /dev/null +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/StructFieldOrdinalTagSuite.scala @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.sql.catalyst.expressions.{ + AttributeReference, + GetArrayStructFields, + GetStructField +} +import org.apache.spark.sql.rapids.{ + GpuGetArrayStructFieldsMeta, + GpuGetStructFieldMeta, + GpuStructFieldOrdinalTag +} +import org.apache.spark.sql.types._ + +class StructFieldOrdinalTagSuite extends AnyFunSuite { + + private val threeFieldStruct = StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = true), + StructField("c", DoubleType, nullable = true))) + + private val structAttr = AttributeReference("s", threeFieldStruct, nullable = true)() + + private val arrayOfStruct = ArrayType(threeFieldStruct, containsNull = false) + private val arrayAttr = AttributeReference("arr", arrayOfStruct, nullable = true)() + + // ---------- GetStructField ordinal tag ---------- + + test("effectiveOrdinal returns original ordinal when no tag is set") { + val gsf = GetStructField(structAttr, 2, Some("c")) + assert(GpuGetStructFieldMeta.effectiveOrdinal(gsf) === 2) + } + + test("effectiveOrdinal returns tagged ordinal when tag is set") { + val gsf = GetStructField(structAttr, 2, Some("c")) + gsf.setTagValue(GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG, 0) + assert(GpuGetStructFieldMeta.effectiveOrdinal(gsf) === 0) + } + + // ---------- effectiveNumFields ---------- + + test("effectiveNumFields returns original numFields when no tag") { + val gasf = GetArrayStructFields(arrayAttr, threeFieldStruct(1), 1, 3, false) + val result = GpuGetArrayStructFieldsMeta.effectiveNumFields(arrayAttr, gasf, -1) + assert(result === 3) + } + + test("effectiveNumFields derives from child type when tag is active") { + val prunedStruct = StructType(Seq( + StructField("b", StringType, nullable = true))) + val prunedArrayType = ArrayType(prunedStruct, containsNull = false) + val prunedChild = AttributeReference("arr", prunedArrayType, nullable = true)() + + val gasf = GetArrayStructFields(arrayAttr, threeFieldStruct(1), 1, 3, false) + val result = GpuGetArrayStructFieldsMeta.effectiveNumFields(prunedChild, gasf, 0) + assert(result === 1) + } + + test("effectiveNumFields falls back to expr.numFields for non-array child type") { + val nonArrayChild = AttributeReference("x", IntegerType, nullable = true)() + val gasf = GetArrayStructFields(arrayAttr, threeFieldStruct(0), 0, 3, false) + val result = GpuGetArrayStructFieldsMeta.effectiveNumFields(nonArrayChild, gasf, 0) + assert(result === 3) + } + +}