diff --git a/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala b/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala
index bbf3459a705..ea108f4fef6 100644
--- a/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala
+++ b/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala
@@ -304,11 +304,15 @@ object DVPredicatePushdown extends ShimPredicateHelper {
def mergeIdenticalProjects(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case p @ GpuProjectExec(projList1,
- GpuProjectExec(projList2, child, enablePreSplit1), enablePreSplit2) =>
+ GpuProjectExec(projList2, child, enablePreSplit1),
+ enablePreSplit2) =>
val projSet1 = projList1.map(_.exprId).toSet
val projSet2 = projList2.map(_.exprId).toSet
if (projSet1 == projSet2) {
- GpuProjectExec(projList1, child, enablePreSplit1 && enablePreSplit2)
+ GpuProjectExec(
+ projList1,
+ child,
+ enablePreSplit1 && enablePreSplit2)
} else {
p
}
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 888f4fa807c..196f518c6a7 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -8267,7 +8267,7 @@ are limited.
|
|
|
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH |
|
|
|
@@ -8290,7 +8290,7 @@ are limited.
|
|
|
-PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH |
|
|
|
diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh
index 0f2e2471b7b..4ede802a22b 100755
--- a/integration_tests/run_pyspark_from_build.sh
+++ b/integration_tests/run_pyspark_from_build.sh
@@ -46,6 +46,9 @@
# To run all tests, including Avro tests:
# INCLUDE_SPARK_AVRO_JAR=true ./run_pyspark_from_build.sh
#
+# To run WITHOUT Protobuf tests (protobuf is included by default):
+# INCLUDE_SPARK_PROTOBUF_JAR=false ./run_pyspark_from_build.sh
+#
# To run a specific test:
# TEST=my_test ./run_pyspark_from_build.sh
#
@@ -141,9 +144,101 @@ else
AVRO_JARS=""
fi
- # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar if they exist
+ # Protobuf support: Include spark-protobuf jar by default for protobuf_test.py
+ # Set INCLUDE_SPARK_PROTOBUF_JAR=false to disable
+ PROTOBUF_JARS=""
+ if [[ $( echo ${INCLUDE_SPARK_PROTOBUF_JAR} | tr '[:upper:]' '[:lower:]' ) != "false" ]];
+ then
+ export INCLUDE_SPARK_PROTOBUF_JAR=true
+ mkdir -p "${TARGET_DIR}/dependency"
+
+ # Download spark-protobuf jar if not already in target/dependency
+ PROTOBUF_JAR_NAME="spark-protobuf_${SCALA_VERSION}-${VERSION_STRING}.jar"
+ PROTOBUF_JAR_PATH="${TARGET_DIR}/dependency/${PROTOBUF_JAR_NAME}"
+
+ if [[ ! -f "$PROTOBUF_JAR_PATH" ]]; then
+ echo "Downloading spark-protobuf jar..."
+ PROTOBUF_MAVEN_URL="https://repo1.maven.org/maven2/org/apache/spark/spark-protobuf_${SCALA_VERSION}/${VERSION_STRING}/${PROTOBUF_JAR_NAME}"
+ if curl -fsL -o "$PROTOBUF_JAR_PATH" "$PROTOBUF_MAVEN_URL"; then
+ echo "Downloaded spark-protobuf jar to $PROTOBUF_JAR_PATH"
+ else
+ echo "WARNING: Failed to download spark-protobuf jar from $PROTOBUF_MAVEN_URL"
+ rm -f "$PROTOBUF_JAR_PATH"
+ fi
+ fi
+
+ # Also download protobuf-java jar (required dependency).
+ # Detect version from the jar bundled with Spark, fall back to version mapping.
+ PROTOBUF_JAVA_VERSION=""
+ BUNDLED_PB_JAR=$(ls "$SPARK_HOME"/jars/protobuf-java-[0-9]*.jar 2>/dev/null | sort -V | tail -1)
+ if [[ -n "$BUNDLED_PB_JAR" ]]; then
+ PROTOBUF_JAVA_VERSION=$(basename "$BUNDLED_PB_JAR" | sed 's/protobuf-java-\(.*\)\.jar/\1/')
+ echo "Detected protobuf-java version $PROTOBUF_JAVA_VERSION from SPARK_HOME"
+ fi
+ if [[ -z "$PROTOBUF_JAVA_VERSION" ]]; then
+ case "$VERSION_STRING" in
+ 3.4.*) PROTOBUF_JAVA_VERSION="3.25.1" ;;
+ 3.5.*) PROTOBUF_JAVA_VERSION="3.25.1" ;;
+ 4.0.*|4.1.*) PROTOBUF_JAVA_VERSION="4.29.3" ;;
+ *) PROTOBUF_JAVA_VERSION="3.25.1" ;;
+ esac
+ echo "Using protobuf-java version $PROTOBUF_JAVA_VERSION based on Spark $VERSION_STRING"
+ fi
+ PROTOBUF_JAVA_JAR_NAME="protobuf-java-${PROTOBUF_JAVA_VERSION}.jar"
+ PROTOBUF_JAVA_JAR_PATH="${TARGET_DIR}/dependency/${PROTOBUF_JAVA_JAR_NAME}"
+
+ if [[ ! -f "$PROTOBUF_JAVA_JAR_PATH" ]]; then
+ echo "Downloading protobuf-java jar..."
+ PROTOBUF_JAVA_MAVEN_URL="https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/${PROTOBUF_JAVA_VERSION}/${PROTOBUF_JAVA_JAR_NAME}"
+ if curl -fsL -o "$PROTOBUF_JAVA_JAR_PATH" "$PROTOBUF_JAVA_MAVEN_URL"; then
+ echo "Downloaded protobuf-java jar to $PROTOBUF_JAVA_JAR_PATH"
+ else
+ echo "WARNING: Failed to download protobuf-java jar from $PROTOBUF_JAVA_MAVEN_URL"
+ rm -f "$PROTOBUF_JAVA_JAR_PATH"
+ fi
+ fi
+
+ SPARK_PROTOBUF_JAR_AVAILABLE=false
+ PROTOBUF_JAVA_AVAILABLE=false
+
+ if [[ -f "$PROTOBUF_JAR_PATH" ]]; then
+ PROTOBUF_JARS="$PROTOBUF_JAR_PATH"
+ echo "Including spark-protobuf jar: $PROTOBUF_JAR_PATH"
+ SPARK_PROTOBUF_JAR_AVAILABLE=true
+ fi
+ if [[ -f "$PROTOBUF_JAVA_JAR_PATH" ]]; then
+ PROTOBUF_JARS="${PROTOBUF_JARS:+$PROTOBUF_JARS }$PROTOBUF_JAVA_JAR_PATH"
+ echo "Including protobuf-java jar: $PROTOBUF_JAVA_JAR_PATH"
+ PROTOBUF_JAVA_AVAILABLE=true
+ elif [[ -n "$BUNDLED_PB_JAR" ]]; then
+ echo "Using bundled protobuf-java jar from SPARK_HOME: $BUNDLED_PB_JAR"
+ PROTOBUF_JAVA_AVAILABLE=true
+ fi
+
+ if [[ "$SPARK_PROTOBUF_JAR_AVAILABLE" == "true" && \
+ "$PROTOBUF_JAVA_AVAILABLE" == "true" ]]; then
+ export PROTOBUF_JARS_AVAILABLE=true
+ else
+ echo "WARNING: Protobuf JAR dependencies incomplete; protobuf tests will be skipped"
+ echo " spark-protobuf available: $SPARK_PROTOBUF_JAR_AVAILABLE"
+ echo " protobuf-java available: $PROTOBUF_JAVA_AVAILABLE"
+ export PROTOBUF_JARS_AVAILABLE=false
+ fi
+ # Also add protobuf jars to the driver classpath so Class.forName() can resolve them:
+ # jars passed via --jars/spark.jars are not on the driver's system classpath at JVM startup.
+ if [[ -n "$PROTOBUF_JARS" ]]; then
+ PROTOBUF_DRIVER_CP=$(echo "$PROTOBUF_JARS" | tr ' ' ':')
+ export PYSP_TEST_spark_driver_extraClassPath="${PYSP_TEST_spark_driver_extraClassPath:+${PYSP_TEST_spark_driver_extraClassPath}:}${PROTOBUF_DRIVER_CP}"
+ echo "Added protobuf jars to driver classpath"
+ fi
+ else
+ export INCLUDE_SPARK_PROTOBUF_JAR=false
+ export PROTOBUF_JARS_AVAILABLE=false
+ fi
+
+ # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar protobuf.jar if they exist
# Remove non-existing paths and canonicalize the paths including get rid of links and `..`
- ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS || true)
+ ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS $PROTOBUF_JARS || true)
# `:` separated jars
ALL_JARS="${ALL_JARS//$'\n'/:}"
@@ -411,6 +506,7 @@ else
export PYSP_TEST_spark_gluten_loadLibFromJar=true
fi
+
SPARK_SHELL_SMOKE_TEST="${SPARK_SHELL_SMOKE_TEST:-0}"
EXPLAIN_ONLY_CPU_SMOKE_TEST="${EXPLAIN_ONLY_CPU_SMOKE_TEST:-0}"
SPARK_CONNECT_SMOKE_TEST="${SPARK_CONNECT_SMOKE_TEST:-0}"
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index fa7decac82d..da8f780a1b6 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2025, NVIDIA CORPORATION.
+# Copyright (c) 2020-2026, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
# limitations under the License.
import copy
+from dataclasses import dataclass, replace
from datetime import date, datetime, timedelta, timezone
from decimal import *
from enum import Enum
@@ -857,6 +858,659 @@ def gen_bytes():
return bytes([ rand.randint(0, 255) for _ in range(length) ])
self._start(rand, gen_bytes)
+
+# -----------------------------------------------------------------------------
+# Protobuf schema-first test modeling / generation / encoding
+# -----------------------------------------------------------------------------
+
+_PROTOBUF_WIRE_VARINT = 0
+_PROTOBUF_WIRE_64BIT = 1
+_PROTOBUF_WIRE_LEN_DELIM = 2
+_PROTOBUF_WIRE_32BIT = 5
+_PB_MISSING = object()
+
+
+def _encode_protobuf_uvarint(value):
+ """Encode a non-negative integer as protobuf varint."""
+ if value is None:
+ raise ValueError("value must not be None")
+ if value < 0:
+ raise ValueError("uvarint only supports non-negative integers")
+ out = bytearray()
+ v = int(value)
+ while True:
+ b = v & 0x7F
+ v >>= 7
+ if v:
+ out.append(b | 0x80)
+ else:
+ out.append(b)
+ break
+ return bytes(out)
+
+
+def _encode_protobuf_key(field_number, wire_type):
+ return _encode_protobuf_uvarint((int(field_number) << 3) | int(wire_type))
+
+
+def _encode_protobuf_zigzag32(value):
+ return (int(value) << 1) ^ (int(value) >> 31)
+
+
+def _encode_protobuf_zigzag64(value):
+ return (int(value) << 1) ^ (int(value) >> 63)
+
+
+class PbCardinality(Enum):
+ OPTIONAL = 'optional'
+ REQUIRED = 'required'
+ REPEATED = 'repeated'
+
+
+class PbScalarKind(Enum):
+ BOOL = 'bool'
+ INT32 = 'int32'
+ INT64 = 'int64'
+ UINT32 = 'uint32'
+ UINT64 = 'uint64'
+ SINT32 = 'sint32'
+ SINT64 = 'sint64'
+ FIXED32 = 'fixed32'
+ FIXED64 = 'fixed64'
+ SFIXED32 = 'sfixed32'
+ SFIXED64 = 'sfixed64'
+ FLOAT = 'float'
+ DOUBLE = 'double'
+ STRING = 'string'
+ BYTES = 'bytes'
+ ENUM = 'enum'
+
+
+def _pb_scalar_kind_spark_type(kind):
+ if kind in {PbScalarKind.BOOL}:
+ return BooleanType()
+ if kind in {PbScalarKind.INT32, PbScalarKind.UINT32, PbScalarKind.SINT32,
+ PbScalarKind.FIXED32, PbScalarKind.SFIXED32, PbScalarKind.ENUM}:
+ return IntegerType()
+ if kind in {PbScalarKind.INT64, PbScalarKind.UINT64, PbScalarKind.SINT64,
+ PbScalarKind.FIXED64, PbScalarKind.SFIXED64}:
+ return LongType()
+ if kind == PbScalarKind.FLOAT:
+ return FloatType()
+ if kind == PbScalarKind.DOUBLE:
+ return DoubleType()
+ if kind == PbScalarKind.STRING:
+ return StringType()
+ if kind == PbScalarKind.BYTES:
+ return BinaryType()
+ raise ValueError(f'Unsupported protobuf scalar kind: {kind}')
+
+
+@dataclass(frozen=True)
+class PbEnumSpec:
+ name: str
+ values: tuple
+
+ def __post_init__(self):
+ values = tuple((str(name), int(number)) for name, number in self.values)
+ if not values:
+ raise ValueError('enum spec must contain at least one value')
+ names = [name for name, _ in values]
+ numbers = [number for _, number in values]
+ if len(names) != len(set(names)):
+ raise ValueError(f'duplicate enum names in {self.name}')
+ if len(numbers) != len(set(numbers)):
+ raise ValueError(f'duplicate enum numbers in {self.name}')
+ object.__setattr__(self, 'values', values)
+
+ def number_for(self, value):
+ if isinstance(value, str):
+ for name, number in self.values:
+ if name == value:
+ return number
+ raise ValueError(f'Unknown enum name {value!r} for enum {self.name}')
+ return int(value)
+
+
+@dataclass(frozen=True)
+class PbScalarFieldSpec:
+ name: str
+ number: int
+ kind: PbScalarKind
+ gen: object = None
+ cardinality: PbCardinality = PbCardinality.OPTIONAL
+ default: object = None
+ packed: bool = False
+ min_len: int = 0
+ max_len: int = 5
+ enum: object = None
+
+ def __post_init__(self):
+ object.__setattr__(self, 'name', str(self.name))
+ object.__setattr__(self, 'number', int(self.number))
+ if self.number <= 0:
+ raise ValueError('field numbers must be positive')
+ if self.cardinality != PbCardinality.REPEATED and self.packed:
+ raise ValueError(f'packed encoding requires repeated cardinality: {self.name}')
+ if self.cardinality == PbCardinality.REPEATED and self.default is not None:
+ raise ValueError(f'repeated fields cannot have defaults: {self.name}')
+ if self.min_len < 0 or self.max_len < self.min_len:
+ raise ValueError(f'invalid repeated length bounds for {self.name}')
+ if self.kind == PbScalarKind.ENUM:
+ if self.enum is None:
+ raise ValueError(f'enum field requires enum spec: {self.name}')
+ elif self.enum is not None:
+ raise ValueError(f'non-enum field cannot carry enum spec: {self.name}')
+ if self.packed and self.kind not in {
+ PbScalarKind.BOOL, PbScalarKind.INT32, PbScalarKind.INT64,
+ PbScalarKind.UINT32, PbScalarKind.UINT64, PbScalarKind.SINT32,
+ PbScalarKind.SINT64, PbScalarKind.FIXED32, PbScalarKind.FIXED64,
+ PbScalarKind.SFIXED32, PbScalarKind.SFIXED64, PbScalarKind.FLOAT,
+ PbScalarKind.DOUBLE, PbScalarKind.ENUM}:
+ raise ValueError(f'packed encoding is not supported for {self.kind.value}: {self.name}')
+
+
+@dataclass(frozen=True)
+class PbMessageFieldSpec:
+ name: str
+ number: int
+ fields: tuple
+ cardinality: PbCardinality = PbCardinality.OPTIONAL
+ min_len: int = 0
+ max_len: int = 5
+
+ def __post_init__(self):
+ object.__setattr__(self, 'name', str(self.name))
+ object.__setattr__(self, 'number', int(self.number))
+ object.__setattr__(self, 'fields', tuple(self.fields))
+ if self.number <= 0:
+ raise ValueError('field numbers must be positive')
+ if self.min_len < 0 or self.max_len < self.min_len:
+ raise ValueError(f'invalid repeated length bounds for {self.name}')
+ if self.cardinality == PbCardinality.REQUIRED and self.min_len != 0:
+ raise ValueError('required message field cannot define repeated bounds')
+
+
+@dataclass(frozen=True)
+class PbMessageSpec:
+ name: str
+ fields: tuple
+
+ def __post_init__(self):
+ object.__setattr__(self, 'name', str(self.name))
+ object.__setattr__(self, 'fields', tuple(self.fields))
+ _validate_pb_fields(self.fields, self.name)
+
+ def as_datagen(self, binary_col_name='bin'):
+ return ProtobufRowGen(self, binary_col_name=binary_col_name)
+
+ def encode(self, value):
+ return encode_pb_message(self, value)
+
+
+def _validate_pb_fields(fields, owner_name):
+ names = [field.name for field in fields]
+ numbers = [field.number for field in fields]
+ if len(names) != len(set(names)):
+ raise ValueError(f'duplicate field names in {owner_name}')
+ if len(numbers) != len(set(numbers)):
+ raise ValueError(f'duplicate field numbers in {owner_name}')
+
+
+class _PbBuilder:
+ def message(self, name, fields):
+ return PbMessageSpec(name, tuple(fields))
+
+ def bool(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.BOOL, gen=gen, default=default)
+
+ def int32(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.INT32, gen=gen, default=default)
+
+ def int64(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.INT64, gen=gen, default=default)
+
+ def uint32(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.UINT32, gen=gen, default=default)
+
+ def uint64(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.UINT64, gen=gen, default=default)
+
+ def sint32(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.SINT32, gen=gen, default=default)
+
+ def sint64(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.SINT64, gen=gen, default=default)
+
+ def fixed32(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.FIXED32, gen=gen, default=default)
+
+ def fixed64(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.FIXED64, gen=gen, default=default)
+
+ def sfixed32(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.SFIXED32, gen=gen, default=default)
+
+ def sfixed64(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.SFIXED64, gen=gen, default=default)
+
+ def float(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.FLOAT, gen=gen, default=default)
+
+ def double(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.DOUBLE, gen=gen, default=default)
+
+ def string(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.STRING, gen=gen, default=default)
+
+ def bytes(self, name, number, gen=None, default=None):
+ return PbScalarFieldSpec(name, number, PbScalarKind.BYTES, gen=gen, default=default)
+
+ def enum_type(self, name, values):
+ return PbEnumSpec(name, tuple(values))
+
+ def enum_field(self, name, number, enum_spec, gen=None, default=None):
+ return PbScalarFieldSpec(
+ name, number, PbScalarKind.ENUM, gen=gen, default=default, enum=enum_spec)
+
+ def message_field(self, name, number, fields):
+ return PbMessageFieldSpec(name, number, tuple(fields))
+
+ def nested(self, name, number, fields):
+ return self.message_field(name, number, fields)
+
+ def repeated(self, field_spec, min_len=0, max_len=5, packed=False):
+ if isinstance(field_spec, PbScalarFieldSpec):
+ return replace(
+ field_spec,
+ cardinality=PbCardinality.REPEATED,
+ min_len=min_len,
+ max_len=max_len,
+ packed=packed)
+ if isinstance(field_spec, PbMessageFieldSpec):
+ if packed:
+ raise ValueError('message fields cannot be packed')
+ return replace(
+ field_spec,
+ cardinality=PbCardinality.REPEATED,
+ min_len=min_len,
+ max_len=max_len)
+ raise TypeError(f'Unsupported protobuf field for repeated(): {type(field_spec)}')
+
+ def repeated_message(self, name, number, fields, min_len=0, max_len=5):
+ return self.repeated(self.message_field(name, number, fields), min_len=min_len, max_len=max_len)
+
+ def required(self, field_spec):
+ if isinstance(field_spec, PbScalarFieldSpec):
+ if field_spec.default is not None:
+ raise ValueError('required fields cannot have defaults')
+ return replace(field_spec, cardinality=PbCardinality.REQUIRED)
+ if isinstance(field_spec, PbMessageFieldSpec):
+ return replace(field_spec, cardinality=PbCardinality.REQUIRED)
+ raise TypeError(f'Unsupported protobuf field for required(): {type(field_spec)}')
+
+
+pb = _PbBuilder()
+
+
+def _pb_gen_cache_repr(gen):
+ return 'None' if gen is None else gen._cache_repr()
+
+
+def _pb_enum_cache_repr(enum_spec):
+ return 'None' if enum_spec is None else str(enum_spec.values)
+
+
+def _pb_field_cache_repr(field_spec):
+ if isinstance(field_spec, PbScalarFieldSpec):
+ return ('Scalar(' + field_spec.name + ',' + str(field_spec.number) + ',' +
+ field_spec.kind.value + ',' + field_spec.cardinality.value + ',' +
+ str(field_spec.default) + ',' + str(field_spec.packed) + ',' +
+ str(field_spec.min_len) + ',' + str(field_spec.max_len) + ',' +
+ _pb_enum_cache_repr(field_spec.enum) + ',' +
+ _pb_gen_cache_repr(field_spec.gen) + ')')
+ children = ','.join(_pb_field_cache_repr(child) for child in field_spec.fields)
+ return ('Message(' + field_spec.name + ',' + str(field_spec.number) + ',' +
+ field_spec.cardinality.value + ',' + str(field_spec.min_len) + ',' +
+ str(field_spec.max_len) + ',[' + children + '])')
+
+
+def _pb_message_cache_repr(message_spec):
+ children = ','.join(_pb_field_cache_repr(field_spec) for field_spec in message_spec.fields)
+ return 'PbMessage(' + message_spec.name + ',[' + children + '])'
+
+
+class ProtobufEncoder:
+ def encode_message(self, message_spec, value):
+ if value is None:
+ return b''
+ if not isinstance(message_spec, PbMessageSpec):
+ raise TypeError(f'encode_message expects PbMessageSpec, got {type(message_spec)}')
+ return self._encode_message_fields(message_spec.fields, value)
+
+ def encode_field(self, field_spec, value):
+ if isinstance(field_spec, PbScalarFieldSpec):
+ return self._encode_scalar_field(field_spec, value)
+ return self._encode_message_field(field_spec, value)
+
+ def _encode_message_fields(self, fields, value):
+ if not isinstance(value, dict):
+ raise TypeError(f'protobuf message values must be dicts, got {type(value)}')
+ unknown = set(value.keys()) - {field.name for field in fields}
+ if unknown:
+ raise ValueError(f'unknown protobuf field(s): {sorted(unknown)}')
+ return b''.join(
+ self.encode_field(field, value.get(field.name, _PB_MISSING))
+ for field in fields)
+
+ def _normalize_scalar_input(self, field_spec, value):
+ if field_spec.kind == PbScalarKind.ENUM:
+ return field_spec.enum.number_for(value)
+ if field_spec.kind == PbScalarKind.BOOL:
+ return bool(value)
+ if field_spec.kind in {
+ PbScalarKind.INT32, PbScalarKind.INT64, PbScalarKind.UINT32,
+ PbScalarKind.UINT64, PbScalarKind.SINT32, PbScalarKind.SINT64,
+ PbScalarKind.FIXED32, PbScalarKind.FIXED64, PbScalarKind.SFIXED32,
+ PbScalarKind.SFIXED64}:
+ return int(value)
+ if field_spec.kind in {PbScalarKind.FLOAT, PbScalarKind.DOUBLE}:
+ return float(value)
+ if field_spec.kind == PbScalarKind.STRING:
+ return str(value)
+ if field_spec.kind == PbScalarKind.BYTES:
+ return value if isinstance(value, bytes) else bytes(value)
+ raise ValueError(f'Unsupported scalar kind: {field_spec.kind}')
+
+ def _encode_scalar_payload(self, field_spec, value):
+ scalar_value = self._normalize_scalar_input(field_spec, value)
+ kind = field_spec.kind
+ if kind == PbScalarKind.BOOL:
+ return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(1 if scalar_value else 0)
+ if kind in {PbScalarKind.INT32, PbScalarKind.INT64, PbScalarKind.ENUM}:
+ u64 = int(scalar_value) & 0xFFFFFFFFFFFFFFFF
+ return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(u64)
+ if kind == PbScalarKind.UINT32:
+ scalar_value = int(scalar_value)
+ if scalar_value < 0:
+ raise ValueError(f'uint32 field cannot encode negative value: {field_spec.name}')
+ return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(scalar_value)
+ if kind == PbScalarKind.UINT64:
+ scalar_value = int(scalar_value)
+ if scalar_value < 0:
+ raise ValueError(f'uint64 field cannot encode negative value: {field_spec.name}')
+ return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(scalar_value)
+ if kind == PbScalarKind.SINT32:
+ return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(_encode_protobuf_zigzag32(scalar_value))
+ if kind == PbScalarKind.SINT64:
+ return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(_encode_protobuf_zigzag64(scalar_value))
+ if kind in {PbScalarKind.FIXED32, PbScalarKind.SFIXED32}:
+ return _PROTOBUF_WIRE_32BIT, struct.pack('<I', int(scalar_value) & 0xFFFFFFFF)
+
+
+def _spark_protobuf_jvm_available(spark) -> bool:
+ """
+ `spark-protobuf` is an optional external module. PySpark may have the Python wrappers
+ even when the JVM side isn't present on the classpath, which manifests as:
+ TypeError: 'JavaPackage' object is not callable
+ when calling into `sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf`.
+
+ In the integration harness, Spark jars are often attached dynamically. Using the current
+ thread's context classloader is more reliable than the default `Class.forName()` lookup.
+ """
+ jvm = spark.sparkContext._jvm
+ loader = None
+ try:
+ loader = jvm.Thread.currentThread().getContextClassLoader()
+ except Exception:
+ pass
+ candidates = [
+ # Scala object `functions` compiles to `functions$`
+ "org.apache.spark.sql.protobuf.functions$",
+ # Some environments may expose it differently
+ "org.apache.spark.sql.protobuf.functions",
+ ]
+ for cls in candidates:
+ try:
+ if loader is not None:
+ jvm.java.lang.Class.forName(cls, True, loader)
+ else:
+ jvm.java.lang.Class.forName(cls)
+ return True
+ except Exception:
+ continue
+
+ # Fallback: try to resolve the JVM member through Py4J. A missing optional module typically
+ # stays as a JavaPackage placeholder instead of a callable JavaMember/JavaClass.
+ try:
+ member = jvm.org.apache.spark.sql.protobuf.functions.from_protobuf
+ return type(member).__name__ != "JavaPackage"
+ except Exception:
+ return False
+
+
+# ---------------------------------------------------------------------------
+# Shared fixture and helpers to reduce per-test boilerplate
+# ---------------------------------------------------------------------------
+
+@pytest.fixture(scope="module")
+def from_protobuf_fn():
+ """Skip the entire module if from_protobuf or the JVM module is unavailable."""
+ fn = _try_import_from_protobuf()
+ if fn is None:
+ pytest.skip("from_protobuf not available")
+ if not with_cpu_session(_spark_protobuf_jvm_available):
+ pytest.skip("spark-protobuf JVM not available")
+ return fn
+
+
+def _setup_protobuf_desc(spark_tmp_path, desc_name, build_fn):
+ """Build descriptor bytes via JVM, write to HDFS, return (desc_path, desc_bytes)."""
+ desc_path = spark_tmp_path + "/" + desc_name
+ desc_bytes = with_cpu_session(build_fn)
+ with_cpu_session(
+ lambda spark: _write_bytes_to_hadoop_path(spark, desc_path, desc_bytes))
+ return desc_path, desc_bytes
+
+
+def _call_from_protobuf(from_protobuf_fn, col, message_name,
+ desc_path, desc_bytes, options=None):
+ """Call from_protobuf using the right API variant."""
+ sig = inspect.signature(from_protobuf_fn)
+ if "binaryDescriptorSet" in sig.parameters:
+ kw = dict(binaryDescriptorSet=bytearray(desc_bytes))
+ if options is not None:
+ kw["options"] = options
+ return from_protobuf_fn(col, message_name, **kw)
+ if options is not None:
+ return from_protobuf_fn(col, message_name, desc_path, options)
+ return from_protobuf_fn(col, message_name, desc_path)
+
+
+def test_call_from_protobuf_preserves_options_for_legacy_signature():
+ calls = []
+
+ def fake_from_protobuf(col, message_name, desc_path, *args):
+ calls.append((col, message_name, desc_path, args))
+ return "ok"
+
+ options = {"enums.as.ints": "true"}
+ result = _call_from_protobuf(
+ fake_from_protobuf, "col", "msg", "/tmp/test.desc", b"desc", options=options)
+
+ assert result == "ok"
+ assert calls == [("col", "msg", "/tmp/test.desc", (options,))]
+
+
+def test_encode_protobuf_packed_repeated_fixed_uses_unsigned_twos_complement():
+ i32_encoded = _encode_protobuf_packed_repeated(
+ 1, IntegerType(), [0xFFFFFFFF], encoding='fixed')
+ i64_encoded = _encode_protobuf_packed_repeated(
+ 1, LongType(), [0xFFFFFFFFFFFFFFFF], encoding='fixed')
+
+ assert i32_encoded == b"\x0a\x04" + struct.pack("<I", 0xFFFFFFFF)
+ assert i64_encoded == b"\x0a\x08" + struct.pack("<Q", 0xFFFFFFFFFFFFFFFF)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_signed_integers(spark_tmp_path, from_protobuf_fn):
+ """
+ Test decoding signed integer types (sint32/sint64 zigzag, sfixed32/sfixed64).
+ ZigZag encoding maps: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4, etc.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "signed.desc", _build_signed_int_descriptor_set_bytes)
+ message_name = "test.WithSignedInts"
+
+ data_gen = _as_datagen([
+ _scalar("si32", 1, IntegerGen(
+ special_cases=[-1, 0, 1, -2147483648, 2147483647]), encoding='zigzag'),
+ _scalar("si64", 2, LongGen(
+ special_cases=[-1, 0, 1, -9223372036854775808, 9223372036854775807]),
+ encoding='zigzag'),
+ _scalar("sf32", 3, IntegerGen(
+ special_cases=[0, 1, -1, 2147483647, -2147483648]), encoding='fixed'),
+ _scalar("sf64", 4, LongGen(
+ special_cases=[0, 1, -1]), encoding='fixed'),
+ ])
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+
+ return df.select(
+ decoded.getField("si32").alias("si32"),
+ decoded.getField("si64").alias("si64"),
+ decoded.getField("sf32").alias("sf32"),
+ decoded.getField("sf64").alias("sf64"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_fixed_int_descriptor_set_bytes(spark):
+ """Build a descriptor for fixed-width integer fields."""
+ return _build_proto2_descriptor(spark, "fixed_int.proto", [
+ _msg("WithFixedInts", [
+ _field("fx32", 1, "FIXED32"),
+ _field("fx64", 2, "FIXED64"),
+ ]),
+ ])
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_fixed_integers(spark_tmp_path, from_protobuf_fn):
+ """
+ Test decoding fixed-width unsigned integer types.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "fixed.desc", _build_fixed_int_descriptor_set_bytes)
+ message_name = "test.WithFixedInts"
+
+ data_gen = _as_datagen([
+ _scalar("fx32", 1, IntegerGen(
+ special_cases=[0, 1, -1, 2147483647, -2147483648]), encoding='fixed'),
+ _scalar("fx64", 2, LongGen(
+ special_cases=[0, 1, -1]), encoding='fixed'),
+ ])
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("fx32").alias("fx32"),
+ decoded.getField("fx64").alias("fx64"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_schema_projection_descriptor_set_bytes(spark):
+ """Build a descriptor with nested and repeated struct fields for pruning tests."""
+ return _build_proto2_descriptor(spark, "schema_proj.proto", [
+ _msg("Detail", [
+ _field("a", 1, "INT32"),
+ _field("b", 2, "INT32"),
+ _field("c", 3, "STRING"),
+ ]),
+ _msg("SchemaProj", [
+ _field("id", 1, "INT32"),
+ _field("name", 2, "STRING"),
+ _field("detail", 3, "MESSAGE", type_name=".test.Detail"),
+ _field("items", 4, "MESSAGE", label="repeated", type_name=".test.Detail"),
+ ]),
+ ])
+
+
+# Field descriptors for SchemaProj: {id, name, detail: {a, b, c}, items[]: {a, b, c}}
+_detail_children = [
+ _scalar("a", 1, IntegerGen()),
+ _scalar("b", 2, IntegerGen()),
+ _scalar("c", 3, StringGen()),
+]
+_schema_proj_schema = _schema("SchemaProjManual", [
+ _scalar("id", 1, IntegerGen()),
+ _scalar("name", 2, StringGen()),
+ _nested("detail", 3, _detail_children),
+ _repeated_message("items", 4, _detail_children),
+])
+
+_schema_proj_test_data = [
+ encode_pb_message(_schema_proj_schema, {
+ "id": 1,
+ "name": "alice",
+ "detail": {"a": 10, "b": 20, "c": "d1"},
+ "items": [
+ {"a": 100, "b": 200, "c": "i1"},
+ {"a": 101, "b": 201, "c": "i2"},
+ ],
+ }),
+ encode_pb_message(_schema_proj_schema, {
+ "id": 2,
+ "name": "bob",
+ "detail": {"a": 30, "b": 40, "c": "d2"},
+ "items": [
+ {"a": 300, "b": 400, "c": "i3"},
+ ],
+ }),
+ encode_pb_message(_schema_proj_schema, {
+ "id": 3,
+ "name": "carol",
+ "detail": {"a": 50, "b": 60, "c": "d3"},
+ "items": [],
+ }),
+]
+
+
+_schema_proj_cases = [
+ ("nested_single_field", [("id", ("id",)), ("detail_a", ("detail", "a"))]),
+ ("nested_two_fields", [("detail_a", ("detail", "a")), ("detail_c", ("detail", "c"))]),
+ ("whole_struct_no_pruning", [("id", ("id",)), ("detail", ("detail",))]),
+ ("whole_and_subfield", [("detail", ("detail",)), ("detail_a", ("detail", "a"))]),
+ ("scalar_plus_nested", [("id", ("id",)), ("name", ("name",)), ("detail_a", ("detail", "a"))]),
+ ("repeated_msg_single_subfield", [("id", ("id",)), ("items_a", ("items", "a"))]),
+ ("repeated_msg_two_subfields", [("items_a", ("items", "a")), ("items_c", ("items", "c"))]),
+ ("repeated_whole_no_pruning", [("id", ("id",)), ("items", ("items",))]),
+ ("mix_struct_and_repeated", [("id", ("id",)), ("detail_a", ("detail", "a")), ("items_c", ("items", "c"))]),
+]
+
+
+def _get_field_by_path(expr, path):
+ current = expr
+ for name in path:
+ current = current.getField(name)
+ return current
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@pytest.mark.parametrize("boundary", ["alias", "withcolumn"], ids=idfn)
+@ignore_order(local=True)
+def test_from_protobuf_projection_across_plan_boundary(
+        spark_tmp_path, from_protobuf_fn, boundary):
+    """Schema projection across alias (select→select) and withColumn plan boundaries.
+
+    The decoded struct is produced in one plan node and its nested fields are
+    accessed in a later select, so field pruning must work across a plan
+    boundary rather than within a single projection.
+    """
+    desc_path, desc_bytes = _setup_protobuf_desc(
+        spark_tmp_path, "schema_proj_boundary.desc",
+        _build_schema_projection_descriptor_set_bytes)
+    message_name = "test.SchemaProj"
+
+    def run_on_spark(spark):
+        df = spark.createDataFrame([(row,) for row in _schema_proj_test_data], schema="bin binary")
+        decoded = _call_from_protobuf(
+            from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+        if boundary == "alias":
+            # select(decoded.alias(...)) first, then a second select over the alias.
+            aliased = df.select(decoded.alias("decoded"))
+            return aliased.select(
+                f.col("decoded").getField("detail").getField("a").alias("detail_a"),
+                f.col("decoded").getField("id").alias("id"))
+        else:
+            # withColumn introduces the decoded struct; the following select prunes it.
+            with_decoded = df.withColumn("decoded", decoded)
+            return with_decoded.select(
+                f.col("decoded").getField("items").getField("a").alias("items_a"),
+                f.col("decoded").getField("id").alias("id"))
+
+    assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_dual_message_projection_descriptor_set_bytes(spark):
+    """Build descriptors for two logical views over the same binary payload column.
+
+    BytesView and NestedView use the same field numbers (1/2) but disagree on
+    the type of field 2 (raw BYTES vs. a nested NestedPayload message), so
+    decoding one column through both must not share or conflate schema state.
+    """
+    return _build_proto2_descriptor(spark, "dual_projection.proto", [
+        _msg("NestedPayload", [_field("count", 1, "INT32")]),
+        _msg("BytesView", [
+            _field("status", 1, "INT32"),
+            _field("payload", 2, "BYTES"),
+        ]),
+        _msg("NestedView", [
+            _field("status", 1, "INT32"),
+            _field("payload", 2, "MESSAGE", type_name=".test.NestedPayload"),
+        ]),
+    ])
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_different_messages_same_binary_column_do_not_interfere(
+        spark_tmp_path, from_protobuf_fn):
+    """Decode one binary column through two different message types in one query.
+
+    The filter decodes via BytesView while the projection decodes via
+    NestedView; the two decode expressions must not interfere with each other.
+    """
+    desc_path, desc_bytes = _setup_protobuf_desc(
+        spark_tmp_path, "dual_projection.desc",
+        _build_dual_message_projection_descriptor_set_bytes)
+
+    # Hand-crafted wire bytes: status (field 1, varint) then payload (field 2,
+    # length-delimited). The "keep" row has status=1 / inner count=7, the
+    # "drop" row has status=0 / inner count=9, so the filter below retains
+    # only the first row.
+    payload_keep = _encode_tag(1, 0) + _encode_varint(7)
+    payload_drop = _encode_tag(1, 0) + _encode_varint(9)
+    row_keep = (_encode_tag(1, 0) + _encode_varint(1) +
+                _encode_tag(2, 2) + _encode_varint(len(payload_keep)) + payload_keep)
+    row_drop = (_encode_tag(1, 0) + _encode_varint(0) +
+                _encode_tag(2, 2) + _encode_varint(len(payload_drop)) + payload_drop)
+
+    def run_on_spark(spark):
+        df = spark.createDataFrame([(row_keep,), (row_drop,)], schema="bin binary")
+        bytes_view = _call_from_protobuf(
+            from_protobuf_fn, f.col("bin"), "test.BytesView", desc_path, desc_bytes)
+        nested_view = _call_from_protobuf(
+            from_protobuf_fn, f.col("bin"), "test.NestedView", desc_path, desc_bytes)
+        return df.filter(bytes_view.getField("status") == 1).select(
+            nested_view.getField("payload").getField("count").alias("payload_count"))
+
+    assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_deep_nested_5_level_descriptor_set_bytes(spark):
+ """Build a descriptor for a five-level nested message chain."""
+ return _build_proto2_descriptor(spark, "deep_nested_5_level.proto", [
+ _msg("Level5", [_field("val5", 1, "INT32")]),
+ _msg("Level4", [
+ _field("val4", 1, "INT32"),
+ _field("level5", 2, "MESSAGE", type_name=".test.Level5"),
+ ]),
+ _msg("Level3", [
+ _field("val3", 1, "INT32"),
+ _field("level4", 2, "MESSAGE", type_name=".test.Level4"),
+ ]),
+ _msg("Level2", [
+ _field("val2", 1, "INT32"),
+ _field("level3", 2, "MESSAGE", type_name=".test.Level3"),
+ ]),
+ _msg("Level1", [
+ _field("val1", 1, "INT32"),
+ _field("level2", 2, "MESSAGE", type_name=".test.Level2"),
+ ]),
+ ])
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_deep_nesting_5_levels(spark_tmp_path, from_protobuf_fn):
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "deep_nested_5_level.desc",
+ _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _as_datagen([
+ _scalar("val1", 1, IntegerGen()),
+ _nested("level2", 2, [
+ _scalar("val2", 1, IntegerGen()),
+ _nested("level3", 2, [
+ _scalar("val3", 1, IntegerGen()),
+ _nested("level4", 2, [
+ _scalar("val4", 1, IntegerGen()),
+ _nested("level5", 2, [
+ _scalar("val5", 1, IntegerGen()),
+ ]),
+ ]),
+ ]),
+ ]),
+ ])
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("val1").alias("val1"),
+ decoded.getField("level2").alias("level2"),
+ )
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@pytest.mark.parametrize("case_id,select_specs", _schema_proj_cases, ids=lambda c: c[0] if isinstance(c, tuple) else str(c))
+@ignore_order(local=True)
+def test_from_protobuf_schema_projection_cases(
+ spark_tmp_path, from_protobuf_fn, case_id, select_specs):
+ """Parametrized nested-schema projection tests."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "schema_proj.desc", _build_schema_projection_descriptor_set_bytes)
+ message_name = "test.SchemaProj"
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(d,) for d in _schema_proj_test_data], schema="bin binary")
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ selected = [_get_field_by_path(decoded, path).alias(alias)
+ for alias, path in select_specs]
+ return df.select(*selected)
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+def _build_name_collision_descriptor_set_bytes(spark):
+    """Build a regression schema with same-named fields in unrelated nested messages.
+
+    `id` is field number 2 in User but field number 1 in Ad, so any pruning or
+    lookup keyed only on the bare field name would conflate the two.
+    """
+    return _build_proto2_descriptor(spark, "name_collision.proto", [
+        _msg("User", [
+            _field("age", 1, "INT32"),
+            _field("id", 2, "INT32"),
+        ]),
+        _msg("Ad", [_field("id", 1, "INT32")]),
+        _msg("Event", [
+            _field("user_info", 1, "MESSAGE", type_name=".test.User"),
+            _field("ad_info", 2, "MESSAGE", type_name=".test.Ad"),
+        ]),
+    ])
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bug1_name_collision(spark_tmp_path, from_protobuf_fn):
+    """Regression: `id` exists in both User (field 2) and Ad (field 1).
+
+    Selecting user_info.id and ad_info.id in the same query must resolve each
+    `id` against its own parent struct, not by bare field name.
+    """
+    desc_path, desc_bytes = _setup_protobuf_desc(
+        spark_tmp_path, "name_collision.desc",
+        _build_name_collision_descriptor_set_bytes)
+    message_name = "test.Event"
+
+    data_gen = _as_datagen([
+        _nested("user_info", 1, [
+            _scalar("age", 1, IntegerGen()),
+            _scalar("id", 2, IntegerGen()),
+        ]),
+        _nested("ad_info", 2, [
+            _scalar("id", 1, IntegerGen()),
+        ]),
+    ])
+
+    def run_on_spark(spark):
+        df = gen_df(spark, data_gen)
+        decoded = _call_from_protobuf(
+            from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+
+        # Both `id` fields plus `age` are selected so neither struct is fully pruned.
+        return df.select(
+            decoded.getField("user_info").getField("age").alias("age"),
+            decoded.getField("user_info").getField("id").alias("user_id"),
+            decoded.getField("ad_info").getField("id").alias("ad_id")
+        )
+
+    assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_filter_jump_descriptor_set_bytes(spark):
+    """Build a minimal descriptor used by the filter-jump regression test.
+
+    Event carries a varint `status` (consumed by the filter) and a string
+    `ad_info` (consumed by the projection), so each plan node touches a
+    different field of the same message.
+    """
+    return _build_proto2_descriptor(spark, "filter_jump.proto", [
+        _msg("Event", [
+            _field("status", 1, "INT32"),
+            _field("ad_info", 2, "STRING"),
+        ]),
+    ])
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bug2_filter_jump(spark_tmp_path, from_protobuf_fn):
+    """Regression: filter on one from_protobuf expression, project another.
+
+    pb_expr1 and pb_expr2 are deliberately two *separate* decode expressions
+    over the same column -- the filter reads only `status` while the select
+    reads only `ad_info`. Do not "deduplicate" them; that split is the plan
+    shape under test.
+    """
+    desc_path, desc_bytes = _setup_protobuf_desc(
+        spark_tmp_path, "filter_jump.desc",
+        _build_filter_jump_descriptor_set_bytes)
+    message_name = "test.Event"
+
+    # status is pinned to 1 so every generated row survives the filter.
+    data_gen = _as_datagen([
+        _scalar("status", 1, IntegerGen(min_val=1, max_val=1)),
+        _scalar("ad_info", 2, StringGen()),
+    ])
+
+    def run_on_spark(spark):
+        df = gen_df(spark, data_gen)
+        pb_expr1 = _call_from_protobuf(
+            from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+        pb_expr2 = _call_from_protobuf(
+            from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+
+        return df.filter(pb_expr1.getField("status") == 1).select(pb_expr2.getField("ad_info").alias("ad_info"))
+
+    assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_unrelated_struct_name_collision_descriptor_set_bytes(spark):
+    """Build a regression schema where an unrelated nested struct shares a field name.
+
+    `winfoid` sits after `dummy` inside Nested; the companion test pairs it
+    with a non-protobuf struct that also has a `winfoid` member at a different
+    position, so pruned-ordinal bookkeeping must not leak between structs.
+    """
+    return _build_proto2_descriptor(spark, "unrelated_struct.proto", [
+        _msg("Nested", [
+            _field("dummy", 1, "INT32"),
+            _field("winfoid", 2, "INT32"),
+        ]),
+        _msg("Event", [
+            _field("ad_info", 1, "MESSAGE", type_name=".test.Nested"),
+        ]),
+    ])
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bug3_unrelated_struct_name_collision(spark_tmp_path, from_protobuf_fn):
+    """Regression: a pruned protobuf field name must not affect unrelated structs.
+
+    `winfoid` exists both inside the decoded protobuf struct (where pruning can
+    change its ordinal) and inside an ordinary parquet struct; each access must
+    resolve against its own struct's schema.
+    """
+    desc_path, desc_bytes = _setup_protobuf_desc(
+        spark_tmp_path, "unrelated_struct.desc",
+        _build_unrelated_struct_name_collision_descriptor_set_bytes)
+    message_name = "test.Event"
+
+    data_gen = _as_datagen([
+        _nested("ad_info", 1, [
+            _scalar("dummy", 1, IntegerGen()),
+            _scalar("winfoid", 2, IntegerGen()),
+        ]),
+    ])
+
+    def run_on_spark(spark):
+        df = gen_df(spark, data_gen)
+        # Write to parquet to prevent Catalyst from optimizing away the GetStructField,
+        # and to ensure it runs on the GPU.
+        df_with_other = df.withColumn("other_struct",
+            f.struct(f.lit("hello").alias("dummy_str"), f.lit(42).alias("winfoid")))
+
+        path = spark_tmp_path + "/bug3_data.parquet"
+        df_with_other.write.mode("overwrite").parquet(path)
+        read_df = spark.read.parquet(path)
+
+        decoded = _call_from_protobuf(
+            from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+
+        # We only select decoded.ad_info.winfoid, so dummy is pruned.
+        # winfoid gets ordinal 0 in the pruned schema.
+        # But for other_struct, winfoid is ordinal 1.
+        # GpuGetStructFieldMeta will see "winfoid", query the ThreadLocal, get 0,
+        # and extract ordinal 0 ("hello") for other_winfoid, causing a mismatch!
+        return read_df.select(
+            decoded.getField("ad_info").getField("winfoid").alias("pb_winfoid"),
+            f.col("other_struct").getField("winfoid").alias("other_winfoid")
+        )
+
+    assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_max_depth_descriptor_set_bytes(spark):
+ """Build a descriptor for a 12-level nested message chain."""
+ messages = []
+ for i in range(12, 0, -1):
+ fields = [_field(f"val{i}", 1, "INT32")]
+ if i < 12:
+ fields.append(
+ _field(f"level{i + 1}", 2, "MESSAGE", type_name=f".test.Level{i + 1}")
+ )
+ messages.append(_msg(f"Level{i}", fields))
+ return _build_proto2_descriptor(spark, "max_depth.proto", messages)
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bug4_max_depth(spark_tmp_path, from_protobuf_fn):
+    """A 12-level nested message must still decode correctly on the CPU path.
+
+    NOTE(review): only a CPU session is executed below, so this verifies the
+    CPU path works at depth 12 but never actually runs a GPU session to assert
+    the graceful fallback described in the trailing comment -- confirm whether
+    an allow-fallback GPU run was intended here.
+    """
+    desc_path, desc_bytes = _setup_protobuf_desc(
+        spark_tmp_path, "max_depth.desc",
+        _build_max_depth_descriptor_set_bytes)
+    message_name = "test.Level1"
+
+    # Build the deeply nested data gen spec recursively: Level12 is the leaf.
+    def build_nested_gen(level):
+        if level == 12:
+            return [_scalar(f"val{level}", 1, IntegerGen())]
+        return [
+            _scalar(f"val{level}", 1, IntegerGen()),
+            _nested(f"level{level+1}", 2, build_nested_gen(level+1))
+        ]
+
+    data_gen = _as_datagen(build_nested_gen(1))
+
+    def run_on_spark(spark):
+        df = gen_df(spark, data_gen)
+        decoded = _call_from_protobuf(
+            from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+        # Deep access: walk decoded.level2.level3....level12, then read val12.
+        field = decoded
+        for i in range(2, 13):
+            field = field.getField(f"level{i}")
+        return df.select(field.getField("val12").alias("val12"))
+
+    # Depth 12 exceeds GPU max nesting depth (10), so the query should
+    # gracefully fall back to CPU. Verify that it still produces correct
+    # results (CPU path) without crashing.
+    from spark_session import with_cpu_session
+    cpu_result = with_cpu_session(lambda spark: run_on_spark(spark).collect())
+    assert len(cpu_result) > 0
+
+
+# ===========================================================================
+# Regression tests for known bugs found by code review
+# ===========================================================================
+
+def _encode_varint(value):
+ """Encode a non-negative integer as a protobuf varint (for hand-crafting test bytes)."""
+ out = bytearray()
+ v = int(value)
+ while True:
+ b = v & 0x7F
+ v >>= 7
+ if v:
+ out.append(b | 0x80)
+ else:
+ out.append(b)
+ break
+ return bytes(out)
+
+
+def _encode_tag(field_number, wire_type):
+ return _encode_varint((field_number << 3) | wire_type)
+
+
+# ---------------------------------------------------------------------------
+# Bug 1: BOOL8 truncation — non-canonical bool varint values >= 256
+#
+# Protobuf spec: bool is a varint; any non-zero value means true.
+# CPU decoder (protobuf-java): CodedInputStream.readBool() = readRawVarint64() != 0 → true
+# GPU decoder: extract_varint_kernel writes static_cast(v).
+# For v = 256, static_cast(256) == 0 → false. BUG.
+# ---------------------------------------------------------------------------
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bool_noncanonical_varint_scalar(spark_tmp_path, from_protobuf_fn):
+ """Regression test: scalar bool encoded as non-canonical varint (e.g. 256) must decode as true.
+
+ Protobuf allows any non-zero varint for bool true. The GPU decoder previously
+ truncated to uint8_t, causing values >= 256 to wrap to 0 (false).
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "simple_bool_bug.desc", _build_simple_descriptor_set_bytes)
+ message_name = "test.Simple"
+
+ # varint(256) = 0x80 0x02, varint(512) = 0x80 0x04 — valid non-canonical bool true
+ row_bool_256 = _encode_tag(1, 0) + _encode_varint(256) + \
+ _encode_tag(2, 0) + _encode_varint(99)
+
+ # Control row: canonical bool true (varint 1) — should work on both
+ row_bool_1 = _encode_tag(1, 0) + _encode_varint(1) + \
+ _encode_tag(2, 0) + _encode_varint(100)
+
+ # Another non-canonical value: varint(512)
+ row_bool_512 = _encode_tag(1, 0) + _encode_varint(512) + \
+ _encode_tag(2, 0) + _encode_varint(101)
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(row_bool_256,), (row_bool_1,), (row_bool_512,)],
+ schema="bin binary",
+ )
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("b").alias("b"),
+ decoded.getField("i32").alias("i32"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+def _build_repeated_bool_descriptor_set_bytes(spark):
+    """Build a descriptor for an optional id plus repeated bool flags.
+
+    `flags` is a proto2 repeated BOOL; the companion test hand-encodes its
+    elements as individual (unpacked) varint occurrences of field 2.
+    """
+    return _build_proto2_descriptor(spark, "repeated_bool.proto", [
+        _msg("WithRepeatedBool", [
+            _field("id", 1, "INT32"),
+            _field("flags", 2, "BOOL", label="repeated"),
+        ]),
+    ])
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_bool_noncanonical_varint_repeated(spark_tmp_path, from_protobuf_fn):
+ """Regression test: repeated bool with non-canonical varint values must all decode as true.
+
+ Same uint8_t truncation issue as the scalar case, exercised with repeated fields.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "repeated_bool_bug.desc", _build_repeated_bool_descriptor_set_bytes)
+ message_name = "test.WithRepeatedBool"
+
+ # Repeated bool field 2 (wire type 0 = varint), unpacked.
+ # Three elements: varint(256), varint(1), varint(512) — all should decode as true.
+ row = (_encode_tag(1, 0) + _encode_varint(42) +
+ _encode_tag(2, 0) + _encode_varint(256) +
+ _encode_tag(2, 0) + _encode_varint(1) +
+ _encode_tag(2, 0) + _encode_varint(512))
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame([(row,)], schema="bin binary")
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("id").alias("id"),
+ decoded.getField("flags").alias("flags"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+# ---------------------------------------------------------------------------
+# Regression guard: nested message child field default values
+# ---------------------------------------------------------------------------
+
+def _build_nested_with_defaults_descriptor_set_bytes(spark):
+    """Build a descriptor with proto2 defaults inside a nested child struct.
+
+    Inner declares explicit proto2 defaults (count=42, label="hello",
+    flag=True) so absent child fields must surface as those defaults, not null.
+    """
+    return _build_proto2_descriptor(spark, "nested_defaults.proto", [
+        _msg("Inner", [
+            _field("count", 1, "INT32", default=42),
+            _field("label", 2, "STRING", default="hello"),
+            _field("flag", 3, "BOOL", default=True),
+        ]),
+        _msg("OuterWithNestedDefaults", [
+            _field("id", 1, "INT32"),
+            _field("inner", 2, "MESSAGE", type_name=".test.Inner"),
+        ]),
+    ])
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_nested_child_default_values(spark_tmp_path, from_protobuf_fn):
+    """Regression test: proto2 default values for fields inside nested messages must be honored.
+
+    When `inner` is present but its child fields are absent, the decoder must
+    return the proto2 defaults (count=42, label="hello", flag=true), not null.
+    """
+    desc_path, desc_bytes = _setup_protobuf_desc(
+        spark_tmp_path, "nested_defaults.desc",
+        _build_nested_with_defaults_descriptor_set_bytes)
+    message_name = "test.OuterWithNestedDefaults"
+
+    # Row 1: outer.id = 10, inner is present but EMPTY (0-length nested message).
+    # Wire: field 1 varint(10), field 2 length-delimited with length 0.
+    # CPU should fill inner.count=42, inner.label="hello", inner.flag=true.
+    row_empty_inner = (_encode_tag(1, 0) + _encode_varint(10) +
+                       _encode_tag(2, 2) + _encode_varint(0))
+
+    # Row 2: outer.id = 20, inner has only count=7 (label and flag should get defaults).
+    inner_partial = _encode_tag(1, 0) + _encode_varint(7)
+    row_partial_inner = (_encode_tag(1, 0) + _encode_varint(20) +
+                         _encode_tag(2, 2) + _encode_varint(len(inner_partial)) +
+                         inner_partial)
+
+    # Row 3: outer.id = 30, inner is fully absent → inner itself is null.
+    row_no_inner = _encode_tag(1, 0) + _encode_varint(30)
+
+    def run_on_spark(spark):
+        df = spark.createDataFrame(
+            [(row_empty_inner,), (row_partial_inner,), (row_no_inner,)],
+            schema="bin binary",
+        )
+        decoded = _call_from_protobuf(
+            from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+        # Select every inner field so default-filling and null propagation are
+        # both compared between CPU and GPU.
+        return df.select(
+            decoded.getField("id").alias("id"),
+            decoded.getField("inner").getField("count").alias("inner_count"),
+            decoded.getField("inner").getField("label").alias("inner_label"),
+            decoded.getField("inner").getField("flag").alias("inner_flag"),
+        )
+
+    assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+# ===========================================================================
+# Deep nested schema pruning tests
+#
+# These verify that the GPU path correctly prunes nested fields at depth > 2.
+# Previously, collectStructFieldReferences only recognized 2-level
+# GetStructField chains, so accessing decoded.level2.level3.val3 would
+# decode ALL of level3's children instead of only val3.
+# ===========================================================================
+
+def _deep_5_level_data_gen():
+ return _as_datagen([
+ _scalar("val1", 1, IntegerGen()),
+ _nested("level2", 2, [
+ _scalar("val2", 1, IntegerGen()),
+ _nested("level3", 2, [
+ _scalar("val3", 1, IntegerGen()),
+ _nested("level4", 2, [
+ _scalar("val4", 1, IntegerGen()),
+ _nested("level5", 2, [
+ _scalar("val5", 1, IntegerGen()),
+ ]),
+ ]),
+ ]),
+ ]),
+ ])
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_deep_pruning_3_level_leaf(spark_tmp_path, from_protobuf_fn):
+ """Access decoded.level2.level3.val3 -- triggers 3-level deep pruning."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dp3.desc", _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _deep_5_level_data_gen()
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("val1").alias("val1"),
+ decoded.getField("level2").getField("level3").getField("val3").alias("deep_val3"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_deep_pruning_5_level_leaf(spark_tmp_path, from_protobuf_fn):
+ """Access decoded.level2.level3.level4.level5.val5 -- deepest leaf."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dp5.desc", _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _deep_5_level_data_gen()
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ _get_field_by_path(decoded, ["level2", "level3", "level4", "level5", "val5"])
+ .alias("val5"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_deep_pruning_mixed_depths(spark_tmp_path, from_protobuf_fn):
+ """Access leaves at different depths in the same query.
+
+ Select val1 (depth 1), val2 (depth 2), val3 (depth 3), and val5 (depth 5)
+ to exercise pruning at every intermediate level simultaneously.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dp_mix.desc", _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _deep_5_level_data_gen()
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("val1").alias("val1"),
+ decoded.getField("level2").getField("val2").alias("val2"),
+ _get_field_by_path(decoded, ["level2", "level3", "val3"]).alias("val3"),
+ _get_field_by_path(decoded, ["level2", "level3", "level4", "level5", "val5"])
+ .alias("val5"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_deep_pruning_whole_struct_at_depth_3(spark_tmp_path, from_protobuf_fn):
+ """Select the whole level3 struct -- no deep pruning inside level3."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "dp_whole3.desc", _build_deep_nested_5_level_descriptor_set_bytes)
+ message_name = "test.Level1"
+ data_gen = _deep_5_level_data_gen()
+
+ def run_on_spark(spark):
+ df = gen_df(spark, data_gen)
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("level2").getField("level3").alias("level3"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+# ===========================================================================
+# FAILFAST mode tests
+# ===========================================================================
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+def test_from_protobuf_failfast_malformed_data(spark_tmp_path, from_protobuf_fn):
+ """FAILFAST mode should throw on malformed protobuf data (both CPU and GPU)."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "failfast.desc", _build_simple_descriptor_set_bytes)
+ message_name = "test.Simple"
+
+ # Craft a valid row and a malformed row (truncated varint with continuation bit)
+ valid_row = _encode_tag(1, 0) + _encode_varint(1) + \
+ _encode_tag(2, 0) + _encode_varint(42)
+ malformed_row = bytes([0x08, 0x80]) # field 1, varint, but only continuation byte -- no end
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(valid_row,), (malformed_row,)],
+ schema="bin binary",
+ )
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes,
+ options={"mode": "FAILFAST"})
+ # Must call .collect() so the exception surfaces inside with_*_session
+ return df.select(decoded.getField("b").alias("b")).collect()
+
+ assert_gpu_and_cpu_error(run_on_spark, {}, "Malformed")
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_permissive_malformed_returns_null(spark_tmp_path, from_protobuf_fn):
+ """PERMISSIVE mode should return null for malformed rows, not throw.
+
+ Note: Spark's from_protobuf defaults to FAILFAST (unlike JSON/CSV which
+ default to PERMISSIVE), so mode must be set explicitly.
+ """
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "permissive.desc", _build_simple_descriptor_set_bytes)
+ message_name = "test.Simple"
+
+ valid_row = _encode_tag(2, 0) + _encode_varint(99)
+ malformed_row = bytes([0x08, 0x80]) # truncated varint
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(valid_row,), (malformed_row,)],
+ schema="bin binary",
+ )
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes,
+ options={"mode": "PERMISSIVE"})
+ return df.select(
+ decoded.getField("i32").alias("i32"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
+
+
+@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
+@ignore_order(local=True)
+def test_from_protobuf_all_null_input(spark_tmp_path, from_protobuf_fn):
+ """All rows in the input binary column are null (not empty bytes, actual nulls).
+ GPU should produce all-null struct rows matching CPU behavior."""
+ desc_path, desc_bytes = _setup_protobuf_desc(
+ spark_tmp_path, "allnull.desc", _build_simple_descriptor_set_bytes)
+ message_name = "test.Simple"
+
+ def run_on_spark(spark):
+ df = spark.createDataFrame(
+ [(None,), (None,), (None,)],
+ schema="bin binary",
+ )
+ decoded = _call_from_protobuf(
+ from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes)
+ return df.select(
+ decoded.getField("i32").alias("i32"),
+ decoded.getField("s").alias("s"),
+ )
+
+ assert_gpu_and_cpu_are_equal_collect(run_on_spark)
diff --git a/integration_tests/src/main/python/spark_init_internal.py b/integration_tests/src/main/python/spark_init_internal.py
index c0f5b6123ad..5150d2dacdc 100644
--- a/integration_tests/src/main/python/spark_init_internal.py
+++ b/integration_tests/src/main/python/spark_init_internal.py
@@ -45,6 +45,7 @@ def conf_for_env(env_name):
return res
_DRIVER_ENV = env_for_conf('spark.driver.extraJavaOptions')
+_DRIVER_CLASSPATH = env_for_conf('spark.driver.extraClassPath')
_SPARK_JARS = env_for_conf("spark.jars")
_SPARK_JARS_PACKAGES = env_for_conf("spark.jars.packages")
spark_jars_env = {
@@ -64,11 +65,66 @@ def findspark_init():
if spark_jars is not None:
logging.info(f"Adding to findspark jars: {spark_jars}")
findspark.add_jars(spark_jars)
+ # Also add to driver classpath so classes are available to Class.forName()
+ # This is needed for optional modules like spark-protobuf
+ _add_driver_classpath(spark_jars)
if spark_jars_packages is not None:
logging.info(f"Adding to findspark packages: {spark_jars_packages}")
findspark.add_packages(spark_jars_packages)
+
+def _add_driver_classpath(jars):
+    """
+    Add jars to the driver classpath via PYSPARK_SUBMIT_ARGS.
+    findspark.add_jars() only adds --jars, which doesn't make classes available
+    to Class.forName() on the driver. This function adds --driver-class-path.
+
+    No-op when `jars` is empty or every jar is already present either in the
+    driver-class-path env var or in PYSPARK_SUBMIT_ARGS.
+    """
+    if not jars:
+        return
+    current_args = os.environ.get('PYSPARK_SUBMIT_ARGS', '')
+    # Remove trailing 'pyspark-shell' if present; it is re-appended at the end.
+    if current_args.endswith('pyspark-shell'):
+        current_args = current_args[:-len('pyspark-shell')].strip()
+    jar_list = [j.strip() for j in jars.split(',') if j.strip()]
+    # NOTE(review): assumes env_for_conf() always returns a string env-var name
+    # for _DRIVER_CLASSPATH; os.environ.get(None) would raise TypeError -- confirm.
+    existing_driver_cp = {
+        p for p in os.environ.get(_DRIVER_CLASSPATH, '').split(os.pathsep) if p
+    }
+    args_driver_cp_match = re.search(r'--driver-class-path\s+(\S+)', current_args)
+    existing_args_driver_cp = {
+        p for p in (args_driver_cp_match.group(1).split(os.pathsep)
+                    if args_driver_cp_match else []) if p
+    }
+    # Drop jars already present on either classpath so repeated calls stay idempotent.
+    jar_list = [
+        j for j in jar_list
+        if j not in existing_driver_cp and j not in existing_args_driver_cp
+    ]
+    if not jar_list:
+        logging.info("Skipping PYSPARK_SUBMIT_ARGS driver-class-path update; jars already present")
+        return
+    new_cp = os.pathsep.join(jar_list)
+    if '--driver-class-path' in current_args:
+        if args_driver_cp_match:
+            # Flag already has a value: append our jars to it. A lambda
+            # replacement keeps re.sub from interpreting backslashes in paths.
+            existing_cp = args_driver_cp_match.group(1)
+            merged_cp = existing_cp + os.pathsep + new_cp
+            current_args = re.sub(
+                r'--driver-class-path\s+\S+',
+                lambda m: f'--driver-class-path {merged_cp}',
+                current_args,
+                count=1)
+        else:
+            # Bare flag with no value: strip it, then append a well-formed one.
+            current_args = re.sub(
+                r'--driver-class-path(?=\s|$)',
+                '',
+                current_args,
+                count=1).strip()
+            current_args += f' --driver-class-path {new_cp}'
+    else:
+        current_args += f' --driver-class-path {new_cp}'
+    new_args = f"{current_args} pyspark-shell".strip()
+    os.environ['PYSPARK_SUBMIT_ARGS'] = new_args
+    logging.info(f"Updated PYSPARK_SUBMIT_ARGS with driver-class-path")
+
def running_with_xdist(session, is_worker):
try:
import xdist
diff --git a/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh b/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh
new file mode 100755
index 00000000000..2283ac8a751
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+# Convenience script: compile nested_proto .proto files into a descriptor set.
+#
+# Usage:
+# ./gen_nested_proto_data.sh
+#
+# The generated .desc file is checked into the repository and used by
+# integration tests in protobuf_test.py. Re-run this script whenever
+# the .proto definitions under nested_proto/ change.
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+PROTO_DIR="${SCRIPT_DIR}/nested_proto"
+OUTPUT_DIR="${SCRIPT_DIR}/nested_proto/generated"
+
+echo "=== Protobuf Descriptor Compiler ==="
+echo "Proto dir: ${PROTO_DIR}"
+echo ""
+
+# Create output directory
+mkdir -p "${OUTPUT_DIR}"
+
+# Compile proto files into a descriptor set (includes all imports)
+DESC_FILE="${OUTPUT_DIR}/main_log.desc"
+echo "Compiling proto files..."
+protoc \
+ --descriptor_set_out="${DESC_FILE}" \
+ --include_imports \
+ -I"${PROTO_DIR}" \
+ "${PROTO_DIR}/main_log.proto"
+
+echo "Generated: ${DESC_FILE}"
+echo "=== Done ==="
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto
new file mode 100644
index 00000000000..c4d86951d96
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto
@@ -0,0 +1,11 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+option java_outer_classname = "DeviceReqBean";
+
+// Device request field. Imported by main_log.proto (ExtendedReqInfo) to
+// exercise cross-file message references in the compiled descriptor set.
+message DeviceReqField {
+  optional int32 os_type = 1; // int32
+  optional bytes device_id = 2; // bytes
+}
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc b/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc
new file mode 100644
index 00000000000..6e8155238b3
Binary files /dev/null and b/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc differ
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto
new file mode 100644
index 00000000000..f9a325a9f2e
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto
@@ -0,0 +1,103 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+import "module_a_res.proto";
+import "module_b_res.proto";
+import "device_req.proto";
+
+// ========== Enum type tests ==========
+enum SourceType {
+ WEB = 0; // web
+ APP = 1; // application
+ MOBILE = 4; // mobile
+}
+
+enum ChannelType {
+ CHANNEL_A = 0;
+ CHANNEL_B = 1;
+}
+
+// ========== Main log record ==========
+message MainLogRecord {
+ // required fields
+ required SourceType source = 1; // required enum
+ required uint64 timestamp = 2; // required uint64
+
+ // optional scalar types - one of each
+ optional string user_id = 3; // string
+ optional int64 account_id = 4; // int64
+ optional fixed32 client_ip = 5; // fixed32
+
+ // nested message
+ optional LogContent log_content = 6;
+}
+
+// ========== Log content (multi-level nesting) ==========
+message LogContent {
+ optional BasicInfo basic_info = 1;
+ repeated ChannelInfo channel_list = 2;
+ repeated DataSourceField source_list = 3;
+}
+
+// ========== Basic info (three-level nesting) ==========
+message BasicInfo {
+ optional RequestInfo request_info = 1;
+ optional ExtendedReqInfo extended_req_info = 2;
+ optional ServerAddedField server_added_field = 3;
+}
+
+// ========== Request info ==========
+message RequestInfo {
+ optional uint32 page_num = 1; // uint32
+ optional string channel_code = 2; // string
+ repeated uint32 experiment_ids = 3; // repeated uint32
+ optional bool is_filtered = 4; // bool
+}
+
+// ========== Extended request info (cross-file import) ==========
+message ExtendedReqInfo {
+ optional DeviceReqField device_req_field = 1; // reference to external proto
+}
+
+// ========== Server-added fields ==========
+message ServerAddedField {
+ optional uint32 region_code = 1; // uint32
+ optional string flow_type = 2; // string
+ optional int32 filter_result = 3; // int32
+ repeated int32 hit_rule_list = 4; // repeated int32
+ optional uint64 request_time = 5; // uint64
+ optional bool skip_flag = 6; // bool
+}
+
+// ========== Channel info ==========
+message ChannelInfo {
+ optional int32 channel_id = 1;
+ optional ModuleAResField module_a_res = 2; // reference to external proto
+}
+
+// ========== Source channel info ==========
+message SrcChannelInfo {
+ optional int32 channel_id = 1;
+ optional ModuleASrcResField module_a_src_res = 2; // reference to external proto
+}
+
+// ========== Data source field ==========
+message DataSourceField {
+ optional uint32 source_id = 1;
+ repeated SrcChannelInfo src_channel_list = 2;
+ optional string billing_name = 3;
+ repeated ItemDetailField item_list = 4;
+ optional bool is_free = 5;
+}
+
+// ========== Item detail field ==========
+message ItemDetailField {
+ optional uint32 rank = 1; // uint32
+ optional uint64 record_id = 2; // uint64
+ optional string keyword = 3; // string
+
+ // cross-file message references
+ optional ModuleADetailField module_a_detail = 4;
+ optional ModuleBDetailField module_b_detail = 5;
+}
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto
new file mode 100644
index 00000000000..9a2674e5dd5
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto
@@ -0,0 +1,92 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+option java_outer_classname = "ModuleARes";
+
+import "predictor_schema.proto";
+
+// ========== Test default values ==========
+message PartnerInfo {
+ optional string token = 1 [default = ""]; // default empty string
+ optional uint64 partner_id = 2 [default = 0]; // default 0
+}
+
+// ========== Test coordinate structure (simple nesting) ==========
+message Coordinate {
+ optional double x = 1; // x coordinate - double
+ optional double y = 2; // y coordinate - double
+}
+
+// ========== Test multi-level nesting ==========
+message LocationPoint {
+ optional uint32 frequency = 1; // frequency - uint32
+ optional Coordinate coord = 2; // coordinate - nested message
+ optional uint64 timestamp = 3; // timestamp - uint64
+}
+
+// ========== Test change log ==========
+message ChangeLog {
+ optional uint32 value_before = 1; // value before change
+ optional string parameters = 2; // parameters
+}
+
+// ========== Test price change log ==========
+message PriceLog {
+ optional uint32 price_before = 1;
+}
+
+// ========== Module A response-level fields ==========
+message ModuleAResField {
+ optional string route_tag = 1; // route tag - string
+ optional int32 status_tag = 2; // status tag - int32
+ optional uint32 region_id = 3; // region id - uint32
+ repeated string experiment_ids = 4; // experiment id list - repeated string
+ optional double quality_score = 5; // quality score - double
+ repeated LocationPoint location_points = 6; // location points - repeated nested message
+ repeated uint64 interest_ids = 7; // interest id list - repeated uint64
+}
+
+// ========== Module A source response field ==========
+message ModuleASrcResField {
+ optional uint32 match_type = 1; // match type
+}
+
+// ========== Key-value pair ==========
+message KVPair {
+ optional bytes key = 1; // key - bytes
+ optional bytes value = 2; // value - bytes
+}
+
+// ========== Style configuration ==========
+message StyleConfig {
+ optional uint32 style_id = 1; // style id
+ repeated KVPair kv_pairs = 2; // kv pair list - repeated nested message
+}
+
+// ========== Module A detail field (core complex structure) ==========
+message ModuleADetailField {
+ // scalar types - one or two of each
+ optional uint32 type_code = 1; // uint32
+ optional uint64 item_id = 2; // uint64
+ optional int32 strategy_type = 3; // int32
+ optional int64 min_value = 4; // int64
+ optional bytes target_url = 5; // bytes
+ optional string title = 6; // string
+ optional bool is_valid = 7; // bool
+ optional float score_ratio = 8; // float
+
+ // repeated scalar types
+ repeated uint32 template_ids = 9; // repeated uint32
+ repeated uint64 material_ids = 10; // repeated uint64
+
+ // repeated nested messages
+ repeated StyleConfig styles = 11; // repeated message
+ repeated ChangeLog change_logs = 12; // repeated message
+
+ // nested message
+ optional PartnerInfo partner_info = 13; // nested message
+
+ // cross-file import
+ optional PredictorSchema predictor_schema = 14; // reference to external proto
+}
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto
new file mode 100644
index 00000000000..0c742739835
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto
@@ -0,0 +1,29 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+// Module B response field
+message ModuleBResField {
+ optional uint32 type_code = 1; // uint32
+ optional string extra_info = 2; // string
+}
+
+// Block element - tests repeated scalar types
+message BlockElement {
+ optional uint64 element_id = 1; // uint64
+ repeated uint64 ref_ids = 2; // repeated uint64
+}
+
+// Block info - tests repeated nested messages
+message BlockInfo {
+ optional uint64 block_id = 1; // uint64
+ repeated BlockElement elements = 2; // repeated message
+}
+
+// Module B detail field
+message ModuleBDetailField {
+ repeated uint32 tags = 1; // repeated uint32
+ optional uint64 item_id = 2; // uint64
+ optional string name = 3; // string
+ repeated BlockInfo blocks = 4; // repeated message
+}
diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto
new file mode 100644
index 00000000000..bc7d86c67e2
--- /dev/null
+++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto
@@ -0,0 +1,82 @@
+syntax = "proto2";
+
+package com.test.proto.sample;
+
+// Predictor schema - tests multi-level schema nesting
+
+// ========== Main schema structure ==========
+message PredictorSchema {
+ optional SchemaTypeA type_a_schema = 1; // nested
+ optional SchemaTypeB type_b_schema = 2; // nested
+ optional SchemaTypeC type_c_schema = 3; // nested (with repeated)
+}
+
+// ========== TypeA Query Schema ==========
+message TypeAQuerySchema {
+ optional string keyword = 1; // keyword
+ optional string session_id = 2; // session id
+}
+
+// ========== TypeA Pair Schema ==========
+message TypeAPairSchema {
+ optional string record_id = 1; // record id
+ optional string item_id = 2; // item id
+}
+
+// ========== TypeA Schema ==========
+message SchemaTypeA {
+ optional TypeAQuerySchema query_schema = 1; // query-level schema
+ repeated TypeAPairSchema pair_schema = 2; // pair list - repeated nested
+}
+
+// ========== TypeB Query Schema ==========
+message TypeBQuerySchema {
+ optional string profile_tag_id = 1; // profile tag id
+ optional string entity_id = 2; // entity id
+}
+
+// ========== TypeB Style Element ==========
+message TypeBStyleElem {
+ optional string template_id = 1; // template id
+ optional string material_id = 2; // material id
+}
+
+// ========== TypeB Style Schema ==========
+message TypeBStyleSchema {
+ repeated TypeBStyleElem values = 1; // element list - repeated nested
+}
+
+// ========== TypeB Schema ==========
+message SchemaTypeB {
+ optional TypeBQuerySchema query_schema = 1; // query-level schema
+ repeated TypeBStyleSchema style_schema = 2; // style list - repeated nested
+}
+
+// ========== TypeC Query Schema ==========
+message TypeCQuerySchema {
+ optional string keyword = 1; // keyword
+ optional string category = 2; // category
+}
+
+// ========== TypeC Pair Schema ==========
+message TypeCPairSchema {
+ optional string item_id = 1; // item id
+ optional string target_url = 2; // target url
+}
+
+// ========== TypeC Style Element (empty structure test) ==========
+message TypeCStyleElem {
+ // empty message - tests empty structure handling
+}
+
+// ========== TypeC Style Schema ==========
+message TypeCStyleSchema {
+ repeated TypeCStyleElem values = 1; // empty element list
+}
+
+// ========== TypeC Schema (full three-level structure) ==========
+message SchemaTypeC {
+ optional TypeCQuerySchema query_schema = 1; // query schema
+ repeated TypeCPairSchema pair_schema = 2; // pair list
+ repeated TypeCStyleSchema style_schema = 3; // style list
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala
index 3c5b9f26872..3deeb88c737 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala
@@ -71,7 +71,8 @@ object GpuBindReferences extends Logging {
if (ordinal == -1) {
sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
} else {
- GpuBoundReference(ordinal, a.dataType, input(ordinal).nullable)(a.exprId, a.name)
+ GpuBoundReference(ordinal, input(ordinal).dataType, input(ordinal).nullable)(
+ a.exprId, a.name)
}
}
val matchFunc = regularMatch.orElse(partial)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 565e97a40ca..edcee71dc0f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -955,6 +955,7 @@ object GpuOverrides extends Logging {
+ GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(),
TypeSig.all),
(a, conf, p, r) => new UnaryAstExprMeta[Alias](a, conf, p, r) {
+ override def typeMeta: DataTypeMeta = childExprs.head.typeMeta
override def convertToGpu(child: Expression): GpuExpression =
GpuAlias(child, a.name)(a.exprId, a.qualifier, a.explicitMetadata)
}),
@@ -981,7 +982,32 @@ object GpuOverrides extends Logging {
TypeSig.all),
(att, conf, p, r) => new BaseExprMeta[AttributeReference](att, conf, p, r) {
// This is the only NOOP operator. It goes away when things are bound
- override def convertToGpuImpl(): Expression = att
+ override def convertToGpuImpl(): Expression = {
+ def findParentPlan(meta: Option[RapidsMeta[_, _, _]]): Option[SparkPlanMeta[_]] = {
+ meta match {
+ case Some(planMeta: SparkPlanMeta[_]) => Some(planMeta)
+ case Some(other) => findParentPlan(other.parent)
+ case None => None
+ }
+ }
+
+ val maybeResolvedAttr = for {
+ planMeta <- findParentPlan(parent)
+ matched <- planMeta.childPlans.iterator
+ .flatMap(_.outputAttributes.iterator)
+ .find(_.exprId == att.exprId)
+ } yield AttributeReference(
+ att.name,
+ matched.dataType,
+ matched.nullable,
+ att.metadata)(att.exprId, att.qualifier)
+
+ // NOTE: matched.dataType may still reflect the wrapped Spark plan's original
+ // output type (for example, an un-pruned protobuf struct). Correct runtime type
+ // propagation is guaranteed by GpuBoundAttribute, which uses input(ordinal).dataType
+ // at bind time.
+ maybeResolvedAttr.getOrElse(att)
+ }
// There are so many of these that we don't need to print them out, unless it
// will not work on the GPU
@@ -2628,10 +2654,7 @@ object GpuOverrides extends Logging {
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.MAP + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY),
TypeSig.STRUCT.nested(TypeSig.all)),
- (expr, conf, p, r) => new UnaryExprMeta[GetStructField](expr, conf, p, r) {
- override def convertToGpu(arr: Expression): GpuExpression =
- GpuGetStructField(arr, expr.ordinal, expr.name)
- }),
+ (expr, conf, p, r) => new GpuGetStructFieldMeta(expr, conf, p, r)),
expr[GetArrayItem](
"Gets the field at `ordinal` in the Array",
ExprChecks.binaryProject(
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
index a97f830fe3e..015814e94b0 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala
@@ -53,6 +53,9 @@ class GpuProjectExecMeta(
p: Option[RapidsMeta[_, _, _]],
r: DataFromReplacementRule) extends SparkPlanMeta[ProjectExec](proj, conf, p, r)
with Logging {
+ override protected lazy val outputTypeMetas: Option[Seq[DataTypeMeta]] =
+ Some(childExprs.map(_.typeMeta))
+
override def convertToGpu(): GpuExec = {
// Force list to avoid recursive Java serialization of lazy list Seq implementation
val gpuExprs = childExprs.map(_.convertToGpu().asInstanceOf[NamedExpression]).toList
@@ -299,7 +302,6 @@ trait GpuProjectExecLike extends GpuPartitioningPreservingUnaryExecNode with Gpu
override def doExecute(): RDD[InternalRow] =
throw new IllegalStateException(s"Row-based execution should not occur for $this")
- // The same as what feeds us
override def outputBatching: CoalesceGoal = GpuExec.outputBatching(child)
}
@@ -768,8 +770,6 @@ case class GpuProjectExec(
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override def outputBatching: CoalesceGoal = if (enablePreSplit) {
- // Pre-split will make sure the size of each output batch will not be larger
- // than the splitUntilSize.
TargetSize(PreProjectSplitIterator.getSplitUntilSize)
} else {
super.outputBatching
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala
new file mode 100644
index 00000000000..35743bbf336
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import java.util.Arrays
+
+import ai.rapids.cudf
+import ai.rapids.cudf.{CudfException, DType}
+import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression}
+import com.nvidia.spark.rapids.jni.{Protobuf, ProtobufSchemaDescriptor}
+import com.nvidia.spark.rapids.shims.NullIntolerantShim
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.types._
+
+/**
+ * GPU implementation for Spark's `from_protobuf` decode path.
+ *
+ * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when
+ * supported.
+ *
+ * The implementation uses a flattened schema representation where nested fields have parent
+ * indices pointing to their containing message field. For pure scalar schemas, all fields
+ * are top-level (parentIndices == -1, depthLevels == 0, isRepeated == false).
+ *
+ * Schema projection is supported: `decodedSchema` contains only the top-level fields and
+ * nested children that are actually referenced by downstream operators. Downstream
+ * `GetStructField` and `GetArrayStructFields` nodes have their ordinals rewritten via
+ * `PRUNED_ORDINAL_TAG` to index into the pruned schema. Unreferenced fields are never
+ * accessed, so no null-column filling is needed.
+ *
+ * @param decodedSchema The pruned schema containing only the fields decoded by the GPU.
+ * Only fields referenced by downstream operators are included;
+ * ordinal remapping ensures correct field access into the pruned output.
+ * @param fieldNumbers Protobuf field numbers for all fields in flattened schema
+ * @param parentIndices Parent indices for all fields (-1 for top-level)
+ * @param depthLevels Nesting depth for all fields (0 for top-level)
+ * @param wireTypes Wire types for all fields
+ * @param outputTypeIds cuDF type IDs for all fields
+ * @param encodings Encodings for all fields
+ * @param isRepeated Whether each field is repeated
+ * @param isRequired Whether each field is required
+ * @param hasDefaultValue Whether each field has a default value
+ * @param defaultInts Default int/long values
+ * @param defaultFloats Default float/double values
+ * @param defaultBools Default bool values
+ * @param defaultStrings Default string/bytes values
+ * @param enumValidValues Valid enum values for each field
+ * @param enumNames Enum value names for enum-as-string fields. Parallel to enumValidValues.
+ * @param failOnErrors If true, throw exception on malformed data
+ */
+case class GpuFromProtobuf(
+ decodedSchema: StructType,
+ fieldNumbers: Array[Int],
+ parentIndices: Array[Int],
+ depthLevels: Array[Int],
+ wireTypes: Array[Int],
+ outputTypeIds: Array[Int],
+ encodings: Array[Int],
+ isRepeated: Array[Boolean],
+ isRequired: Array[Boolean],
+ hasDefaultValue: Array[Boolean],
+ defaultInts: Array[Long],
+ defaultFloats: Array[Double],
+ defaultBools: Array[Boolean],
+ defaultStrings: Array[Array[Byte]],
+ enumValidValues: Array[Array[Int]],
+ enumNames: Array[Array[Array[Byte]]],
+ failOnErrors: Boolean,
+ child: Expression)
+ extends GpuUnaryExpression with ExpectsInputTypes with NullIntolerantShim with Logging {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+
+ override def dataType: DataType = decodedSchema
+
+ override def nullable: Boolean = true
+
+ override def equals(other: Any): Boolean = other match {
+ case that: GpuFromProtobuf =>
+ decodedSchema == that.decodedSchema &&
+ Arrays.equals(fieldNumbers, that.fieldNumbers) &&
+ Arrays.equals(parentIndices, that.parentIndices) &&
+ Arrays.equals(depthLevels, that.depthLevels) &&
+ Arrays.equals(wireTypes, that.wireTypes) &&
+ Arrays.equals(outputTypeIds, that.outputTypeIds) &&
+ Arrays.equals(encodings, that.encodings) &&
+ Arrays.equals(isRepeated, that.isRepeated) &&
+ Arrays.equals(isRequired, that.isRequired) &&
+ Arrays.equals(hasDefaultValue, that.hasDefaultValue) &&
+ Arrays.equals(defaultInts, that.defaultInts) &&
+ Arrays.equals(defaultFloats, that.defaultFloats) &&
+ Arrays.equals(defaultBools, that.defaultBools) &&
+ GpuFromProtobuf.deepEquals(defaultStrings, that.defaultStrings) &&
+ GpuFromProtobuf.deepEquals(enumValidValues, that.enumValidValues) &&
+ GpuFromProtobuf.deepEquals(enumNames, that.enumNames) &&
+ failOnErrors == that.failOnErrors &&
+ child == that.child
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ var result = decodedSchema.hashCode()
+ result = 31 * result + Arrays.hashCode(fieldNumbers)
+ result = 31 * result + Arrays.hashCode(parentIndices)
+ result = 31 * result + Arrays.hashCode(depthLevels)
+ result = 31 * result + Arrays.hashCode(wireTypes)
+ result = 31 * result + Arrays.hashCode(outputTypeIds)
+ result = 31 * result + Arrays.hashCode(encodings)
+ result = 31 * result + Arrays.hashCode(isRepeated)
+ result = 31 * result + Arrays.hashCode(isRequired)
+ result = 31 * result + Arrays.hashCode(hasDefaultValue)
+ result = 31 * result + Arrays.hashCode(defaultInts)
+ result = 31 * result + Arrays.hashCode(defaultFloats)
+ result = 31 * result + Arrays.hashCode(defaultBools)
+ result = 31 * result + GpuFromProtobuf.deepHashCode(defaultStrings)
+ result = 31 * result + GpuFromProtobuf.deepHashCode(enumValidValues)
+ result = 31 * result + GpuFromProtobuf.deepHashCode(enumNames)
+ result = 31 * result + failOnErrors.hashCode()
+ result = 31 * result + child.hashCode()
+ result
+ }
+
+ // ProtobufSchemaDescriptor is a pure-Java immutable holder for validated schema arrays.
+ // It does not own native resources, so task-scoped close hooks are not required here.
+ @transient private lazy val protobufSchema = new ProtobufSchemaDescriptor(
+ fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, encodings,
+ isRepeated, isRequired, hasDefaultValue, defaultInts, defaultFloats, defaultBools,
+ defaultStrings, enumValidValues, enumNames)
+
+ override protected def doColumnar(input: GpuColumnVector): cudf.ColumnVector = {
+ // Input null mask is propagated to the output struct by the C++ decoder,
+ // so no mergeAndSetValidity call is needed here.
+ try {
+ Protobuf.decodeToStruct(input.getBase, protobufSchema, failOnErrors)
+ } catch {
+ case e: CudfException if failOnErrors =>
+ throw new org.apache.spark.SparkException("Malformed protobuf message", e)
+ case e: CudfException =>
+ logWarning(s"Unexpected CudfException in PERMISSIVE mode: ${e.getMessage}", e)
+ throw e
+ }
+ }
+}
+
+object GpuFromProtobuf {
+ val ENC_DEFAULT = 0
+ val ENC_FIXED = 1
+ val ENC_ZIGZAG = 2
+ val ENC_ENUM_STRING = 3
+
+ /**
+ * Maps a Spark DataType to the corresponding cuDF native type ID.
+ * Note: The encoding (varint/zigzag/fixed) is determined by the protobuf field type,
+ * not the Spark data type, so it must be set separately based on the protobuf schema.
+ *
+ * @return Some(typeId) for supported types, None for unsupported types
+ */
+ def sparkTypeToCudfIdOpt(dt: DataType): Option[Int] = dt match {
+ case BooleanType => Some(DType.BOOL8.getTypeId.getNativeId)
+ case IntegerType => Some(DType.INT32.getTypeId.getNativeId)
+ case LongType => Some(DType.INT64.getTypeId.getNativeId)
+ case FloatType => Some(DType.FLOAT32.getTypeId.getNativeId)
+ case DoubleType => Some(DType.FLOAT64.getTypeId.getNativeId)
+ case StringType => Some(DType.STRING.getTypeId.getNativeId)
+ case BinaryType => Some(DType.LIST.getTypeId.getNativeId)
+ case _ => None
+ }
+
+ /**
+ * Check if a Spark DataType is supported by the GPU protobuf decoder.
+ */
+ def isTypeSupported(dt: DataType): Boolean = sparkTypeToCudfIdOpt(dt).isDefined
+
+ private def deepEquals[T](left: Array[T], right: Array[T]): Boolean =
+ Arrays.deepEquals(left.asInstanceOf[Array[Object]], right.asInstanceOf[Array[Object]])
+
+ private def deepHashCode[T](arr: Array[T]): Int =
+ Arrays.deepHashCode(arr.asInstanceOf[Array[Object]])
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
index 9afb3b854d6..c88ae6a5322 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
@@ -32,7 +32,10 @@ import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, LongType, MapType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
-case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
+case class GpuGetStructField(
+ child: Expression,
+ ordinal: Int,
+ name: Option[String] = None)
extends ShimUnaryExpression
with GpuExpression
with ShimGetStructField
@@ -41,15 +44,23 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin
lazy val childSchema: StructType = child.dataType.asInstanceOf[StructType]
override def dataType: DataType = childSchema(ordinal).dataType
- override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
+
+ override def nullable: Boolean =
+ child.nullable || childSchema(ordinal).nullable
override def toString: String = {
- val fieldName = if (resolved) childSchema(ordinal).name else s"_$ordinal"
+ val fieldName = if (resolved) {
+ childSchema(ordinal).name
+ } else {
+ s"_$ordinal"
+ }
s"$child.${name.getOrElse(fieldName)}"
}
- override def sql: String =
- child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
+ override def sql: String = {
+ val fieldName = childSchema(ordinal).name
+ child.sql + s".${quoteIdentifier(name.getOrElse(fieldName))}"
+ }
override def columnarEvalAny(batch: ColumnarBatch): Any = {
val dt = dataType
@@ -59,7 +70,6 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin
GpuColumnVector.from(view.copyToColumnVector(), dt)
}
case s: GpuScalar =>
- // For a scalar in we want a scalar out.
if (!s.isValid) {
GpuScalar(null, dt)
} else {
@@ -402,6 +412,36 @@ case class GpuArrayPosition(left: Expression, right: Expression)
}
}
+object GpuStructFieldOrdinalTag {
+ val PRUNED_ORDINAL_TAG =
+ new org.apache.spark.sql.catalyst.trees.TreeNodeTag[Int]("GPU_PRUNED_ORDINAL")
+}
+
+class GpuGetStructFieldMeta(
+ expr: GetStructField,
+ conf: RapidsConf,
+ parent: Option[RapidsMeta[_, _, _]],
+ rule: DataFromReplacementRule)
+ extends UnaryExprMeta[GetStructField](expr, conf, parent, rule) {
+
+ def convertToGpu(child: Expression): GpuExpression = {
+ val effectiveOrd = GpuGetStructFieldMeta.effectiveOrdinal(expr)
+ GpuGetStructField(child, effectiveOrd, expr.name)
+ }
+}
+
+object GpuGetStructFieldMeta {
+ def effectiveOrdinal(expr: GetStructField): Int = {
+ val runtimeOrd = expr.getTagValue(
+ GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1)
+ if (runtimeOrd >= 0) {
+ runtimeOrd
+ } else {
+ expr.ordinal
+ }
+ }
+}
+
class GpuGetArrayStructFieldsMeta(
expr: GetArrayStructFields,
conf: RapidsConf,
@@ -409,8 +449,41 @@ class GpuGetArrayStructFieldsMeta(
rule: DataFromReplacementRule)
extends UnaryExprMeta[GetArrayStructFields](expr, conf, parent, rule) {
- def convertToGpu(child: Expression): GpuExpression =
- GpuGetArrayStructFields(child, expr.field, expr.ordinal, expr.numFields, expr.containsNull)
+ def convertToGpu(child: Expression): GpuExpression = {
+ val effectiveOrd = GpuGetArrayStructFieldsMeta.effectiveOrdinal(expr)
+ val runtimeOrd = expr.getTagValue(
+ GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1)
+ val effectiveNumFields =
+ GpuGetArrayStructFieldsMeta.effectiveNumFields(child, expr, runtimeOrd)
+ GpuGetArrayStructFields(child, expr.field,
+ effectiveOrd, effectiveNumFields, expr.containsNull)
+ }
+}
+
+object GpuGetArrayStructFieldsMeta {
+ def effectiveOrdinal(expr: GetArrayStructFields): Int = {
+ val runtimeOrd = expr.getTagValue(
+ GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1)
+ if (runtimeOrd >= 0) {
+ runtimeOrd
+ } else {
+ expr.ordinal
+ }
+ }
+
+ def effectiveNumFields(
+ child: Expression,
+ expr: GetArrayStructFields,
+ runtimeOrd: Int): Int = {
+ if (runtimeOrd >= 0) {
+ child.dataType match {
+ case ArrayType(st: StructType, _) => st.fields.length
+ case _ => expr.numFields
+ }
+ } else {
+ expr.numFields
+ }
+ }
}
/**
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala
new file mode 100644
index 00000000000..eedbc9cd01c
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala
@@ -0,0 +1,236 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.protobuf
+
+import scala.collection.mutable
+
+import com.nvidia.spark.rapids.jni.Protobuf.{WT_32BIT, WT_64BIT, WT_LEN, WT_VARINT}
+
+import org.apache.spark.sql.rapids.GpuFromProtobuf
+import org.apache.spark.sql.types._
+
+object ProtobufSchemaExtractor {
+ def analyzeAllFields(
+ schema: StructType,
+ msgDesc: ProtobufMessageDescriptor,
+ enumsAsInts: Boolean,
+ messageName: String): Either[String, Map[String, ProtobufFieldInfo]] = {
+ val result = mutable.Map[String, ProtobufFieldInfo]()
+
+ schema.fields.foreach { sf =>
+ val fieldInfo = msgDesc.findField(sf.name) match {
+ case None =>
+ unsupportedFieldInfo(
+ sf,
+ None,
+ s"Protobuf field '${sf.name}' not found in message '$messageName'")
+ case Some(fd) =>
+ extractFieldInfo(sf, fd, enumsAsInts) match {
+ case Right(info) =>
+ info
+ case Left(reason) =>
+ unsupportedFieldInfo(sf, Some(fd), reason)
+ }
+ }
+ result(sf.name) = fieldInfo
+ }
+
+ Right(result.toMap)
+ }
+
+ def extractFieldInfo(
+ sparkField: StructField,
+ fieldDescriptor: ProtobufFieldDescriptor,
+ enumsAsInts: Boolean): Either[String, ProtobufFieldInfo] = {
+ val (isSupported, unsupportedReason, encoding) =
+ checkFieldSupport(
+ sparkField.dataType,
+ fieldDescriptor.protoTypeName,
+ fieldDescriptor.isRepeated,
+ enumsAsInts)
+
+ val defaultValue = fieldDescriptor.defaultValueResult match {
+ case Right(value) =>
+ value
+ case Left(_) if !isSupported =>
+ // Preserve the primary unsupported reason from checkFieldSupport for fields that are
+ // already known to be unsupported. Reflection/default extraction errors on those fields
+ // should not mask the more actionable type-support message.
+ None
+ case Left(reason) =>
+ return Left(reason)
+ }
+
+ Right(
+ ProtobufFieldInfo(
+ fieldNumber = fieldDescriptor.fieldNumber,
+ protoTypeName = fieldDescriptor.protoTypeName,
+ sparkType = sparkField.dataType,
+ encoding = encoding,
+ isSupported = isSupported,
+ unsupportedReason = unsupportedReason,
+ isRequired = fieldDescriptor.isRequired,
+ defaultValue = defaultValue,
+ enumMetadata = fieldDescriptor.enumMetadata,
+ isRepeated = fieldDescriptor.isRepeated
+ ))
+ }
+
+ private def unsupportedFieldInfo(
+ sparkField: StructField,
+ fieldDescriptor: Option[ProtobufFieldDescriptor],
+ reason: String): ProtobufFieldInfo = {
+ ProtobufFieldInfo(
+ fieldNumber = fieldDescriptor.map(_.fieldNumber).getOrElse(-1),
+ protoTypeName = fieldDescriptor.map(_.protoTypeName).getOrElse("UNKNOWN"),
+ sparkType = sparkField.dataType,
+ encoding = GpuFromProtobuf.ENC_DEFAULT,
+ isSupported = false,
+ unsupportedReason = Some(reason),
+ isRequired = fieldDescriptor.exists(_.isRequired),
+ defaultValue = None,
+ enumMetadata = fieldDescriptor.flatMap(_.enumMetadata),
+ isRepeated = fieldDescriptor.exists(_.isRepeated)
+ )
+ }
+
+ def checkFieldSupport(
+ sparkType: DataType,
+ protoTypeName: String,
+ isRepeated: Boolean,
+ enumsAsInts: Boolean): (Boolean, Option[String], Int) = {
+
+ if (isRepeated) {
+ sparkType match {
+ case ArrayType(elementType, _) =>
+ elementType match {
+ case BooleanType | IntegerType | LongType | FloatType | DoubleType |
+ StringType | BinaryType =>
+ return checkScalarEncoding(elementType, protoTypeName, enumsAsInts)
+ case _: StructType =>
+ return (true, None, GpuFromProtobuf.ENC_DEFAULT)
+ case _ =>
+ return (
+ false,
+ Some(s"unsupported repeated element type: $elementType"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+ case _ =>
+ return (
+ false,
+ Some(s"repeated field should map to ArrayType, got: $sparkType"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+ }
+
+ if (protoTypeName == "MESSAGE") {
+ sparkType match {
+ case _: StructType =>
+ return (true, None, GpuFromProtobuf.ENC_DEFAULT)
+ case _ =>
+ return (
+ false,
+ Some(s"nested message should map to StructType, got: $sparkType"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+ }
+
+ sparkType match {
+ case BooleanType | IntegerType | LongType | FloatType | DoubleType |
+ StringType | BinaryType =>
+ case other =>
+ return (
+ false,
+ Some(s"unsupported Spark type: $other"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+
+ checkScalarEncoding(sparkType, protoTypeName, enumsAsInts)
+ }
+
+ def checkScalarEncoding(
+ sparkType: DataType,
+ protoTypeName: String,
+ enumsAsInts: Boolean): (Boolean, Option[String], Int) = {
+ val encoding = (sparkType, protoTypeName) match {
+ case (BooleanType, "BOOL") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (IntegerType, "INT32" | "UINT32") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (IntegerType, "SINT32") => Some(GpuFromProtobuf.ENC_ZIGZAG)
+ case (IntegerType, "FIXED32" | "SFIXED32") => Some(GpuFromProtobuf.ENC_FIXED)
+ case (LongType, "INT64" | "UINT64") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (LongType, "SINT64") => Some(GpuFromProtobuf.ENC_ZIGZAG)
+ case (LongType, "FIXED64" | "SFIXED64") => Some(GpuFromProtobuf.ENC_FIXED)
+ case (LongType, "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32") =>
+ val enc = protoTypeName match {
+ case "SINT32" => GpuFromProtobuf.ENC_ZIGZAG
+ case "FIXED32" | "SFIXED32" => GpuFromProtobuf.ENC_FIXED
+ case _ => GpuFromProtobuf.ENC_DEFAULT
+ }
+ Some(enc)
+ case (FloatType, "FLOAT") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (DoubleType, "DOUBLE") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (StringType, "STRING") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (BinaryType, "BYTES") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (IntegerType, "ENUM") if enumsAsInts => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (StringType, "ENUM") if !enumsAsInts => Some(GpuFromProtobuf.ENC_ENUM_STRING)
+ case _ => None
+ }
+
+ encoding match {
+ case Some(enc) => (true, None, enc)
+ case None =>
+ val reason = (sparkType, protoTypeName) match {
+ case (DoubleType, "FLOAT") =>
+ "Spark DoubleType mapped to Protobuf FLOAT is not yet supported on GPU; " +
+ "use FloatType or fall back to CPU"
+ case (FloatType, "DOUBLE") =>
+ "Spark FloatType mapped to Protobuf DOUBLE is not yet supported on GPU; " +
+ "use DoubleType or fall back to CPU"
+ case _ =>
+ s"type mismatch: Spark $sparkType vs Protobuf $protoTypeName"
+ }
+ (false,
+ Some(reason),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+ }
+
+ def getWireType(protoTypeName: String, encoding: Int): Either[String, Int] = {
+ val wireType = protoTypeName match {
+ case "BOOL" | "INT32" | "UINT32" | "SINT32" | "INT64" | "UINT64" | "SINT64" | "ENUM" =>
+ if (encoding == GpuFromProtobuf.ENC_FIXED) {
+ if (protoTypeName.contains("64")) {
+ WT_64BIT
+ } else {
+ WT_32BIT
+ }
+ } else {
+ WT_VARINT
+ }
+ case "FIXED32" | "SFIXED32" | "FLOAT" =>
+ WT_32BIT
+ case "FIXED64" | "SFIXED64" | "DOUBLE" =>
+ WT_64BIT
+ case "STRING" | "BYTES" | "MESSAGE" =>
+ WT_LEN
+ case other =>
+ return Left(
+ s"Unknown protobuf type name '$other' - cannot determine wire type; falling back to CPU")
+ }
+ Right(wireType)
+ }
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala
new file mode 100644
index 00000000000..cb8c7374baf
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala
@@ -0,0 +1,226 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.protobuf
+
+import java.util.Arrays
+
+import org.apache.spark.sql.types.DataType
+
+sealed trait ProtobufDescriptorSource
+
+object ProtobufDescriptorSource {
+ final case class DescriptorPath(path: String) extends ProtobufDescriptorSource
+ final case class DescriptorBytes(bytes: Array[Byte]) extends ProtobufDescriptorSource {
+ override def equals(other: Any): Boolean = other match {
+ case DescriptorBytes(otherBytes) => Arrays.equals(bytes, otherBytes)
+ case _ => false
+ }
+
+ override def hashCode(): Int = Arrays.hashCode(bytes)
+ }
+}
+
+final case class ProtobufExprInfo(
+ messageName: String,
+ descriptorSource: ProtobufDescriptorSource,
+ options: Map[String, String])
+
+final case class ProtobufPlannerOptions(
+ enumsAsInts: Boolean,
+ failOnErrors: Boolean)
+
+sealed trait ProtobufDefaultValue
+
+object ProtobufDefaultValue {
+ final case class BoolValue(value: Boolean) extends ProtobufDefaultValue
+ final case class IntValue(value: Long) extends ProtobufDefaultValue
+ final case class FloatValue(value: Float) extends ProtobufDefaultValue
+ final case class DoubleValue(value: Double) extends ProtobufDefaultValue
+ final case class StringValue(value: String) extends ProtobufDefaultValue
+ final case class BinaryValue(value: Array[Byte]) extends ProtobufDefaultValue {
+ override def equals(other: Any): Boolean = other match {
+ case BinaryValue(otherBytes) => Arrays.equals(value, otherBytes)
+ case _ => false
+ }
+
+ override def hashCode(): Int = Arrays.hashCode(value)
+ }
+ final case class EnumValue(number: Int, name: String) extends ProtobufDefaultValue
+}
+
+final case class ProtobufEnumValue(number: Int, name: String)
+
+final case class ProtobufEnumMetadata(values: Seq[ProtobufEnumValue]) {
+ lazy val validValues: Array[Int] = values.map(_.number).toArray
+ lazy val orderedNames: Array[Array[Byte]] = values.map(_.name.getBytes("UTF-8")).toArray
+ lazy val namesByNumber: Map[Int, String] = values.map(v => v.number -> v.name).toMap
+
+ def enumDefault(number: Int): ProtobufDefaultValue.EnumValue = {
+ val name = namesByNumber.getOrElse(number, s"$number")
+ ProtobufDefaultValue.EnumValue(number, name)
+ }
+}
+
+trait ProtobufMessageDescriptor {
+ def syntax: String
+ def findField(name: String): Option[ProtobufFieldDescriptor]
+}
+
+trait ProtobufFieldDescriptor {
+ def name: String
+ def fieldNumber: Int
+ def protoTypeName: String
+ def isRepeated: Boolean
+ def isRequired: Boolean
+ def defaultValueResult: Either[String, Option[ProtobufDefaultValue]]
+ def enumMetadata: Option[ProtobufEnumMetadata]
+ def messageDescriptor: Option[ProtobufMessageDescriptor]
+}
+
+final case class ProtobufFieldInfo(
+ fieldNumber: Int,
+ protoTypeName: String,
+ sparkType: DataType,
+ encoding: Int,
+ isSupported: Boolean,
+ unsupportedReason: Option[String],
+ isRequired: Boolean,
+ defaultValue: Option[ProtobufDefaultValue],
+ enumMetadata: Option[ProtobufEnumMetadata],
+ isRepeated: Boolean = false) {
+ def hasDefaultValue: Boolean = defaultValue.isDefined
+}
+
+final case class FlattenedFieldDescriptor(
+ fieldNumber: Int,
+ parentIdx: Int,
+ depth: Int,
+ wireType: Int,
+ outputTypeId: Int,
+ encoding: Int,
+ isRepeated: Boolean,
+ isRequired: Boolean,
+ hasDefaultValue: Boolean,
+ defaultInt: Long,
+ defaultFloat: Double,
+ defaultBool: Boolean,
+ defaultString: Array[Byte],
+ enumValidValues: Array[Int],
+ enumNames: Array[Array[Byte]]) {
+ override def equals(other: Any): Boolean = other match {
+ case that: FlattenedFieldDescriptor =>
+ fieldNumber == that.fieldNumber &&
+ parentIdx == that.parentIdx &&
+ depth == that.depth &&
+ wireType == that.wireType &&
+ outputTypeId == that.outputTypeId &&
+ encoding == that.encoding &&
+ isRepeated == that.isRepeated &&
+ isRequired == that.isRequired &&
+ hasDefaultValue == that.hasDefaultValue &&
+ defaultInt == that.defaultInt &&
+ java.lang.Double.compare(defaultFloat, that.defaultFloat) == 0 &&
+ defaultBool == that.defaultBool &&
+ Arrays.equals(defaultString, that.defaultString) &&
+ Arrays.equals(enumValidValues, that.enumValidValues) &&
+ Arrays.deepEquals(
+ enumNames.asInstanceOf[Array[Object]],
+ that.enumNames.asInstanceOf[Array[Object]])
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ var result = fieldNumber
+ result = 31 * result + parentIdx
+ result = 31 * result + depth
+ result = 31 * result + wireType
+ result = 31 * result + outputTypeId
+ result = 31 * result + encoding
+ result = 31 * result + isRepeated.hashCode()
+ result = 31 * result + isRequired.hashCode()
+ result = 31 * result + hasDefaultValue.hashCode()
+ result = 31 * result + defaultInt.hashCode()
+ result = 31 * result + defaultFloat.hashCode()
+ result = 31 * result + defaultBool.hashCode()
+ result = 31 * result + Arrays.hashCode(defaultString)
+ result = 31 * result + Arrays.hashCode(enumValidValues)
+ result = 31 * result + Arrays.deepHashCode(enumNames.asInstanceOf[Array[Object]])
+ result
+ }
+}
+
+final case class FlattenedSchemaArrays(
+ fieldNumbers: Array[Int],
+ parentIndices: Array[Int],
+ depthLevels: Array[Int],
+ wireTypes: Array[Int],
+ outputTypeIds: Array[Int],
+ encodings: Array[Int],
+ isRepeated: Array[Boolean],
+ isRequired: Array[Boolean],
+ hasDefaultValue: Array[Boolean],
+ defaultInts: Array[Long],
+ defaultFloats: Array[Double],
+ defaultBools: Array[Boolean],
+ defaultStrings: Array[Array[Byte]],
+ enumValidValues: Array[Array[Int]],
+ enumNames: Array[Array[Array[Byte]]]) {
+ override def equals(other: Any): Boolean = other match {
+ case that: FlattenedSchemaArrays =>
+ Arrays.equals(fieldNumbers, that.fieldNumbers) &&
+ Arrays.equals(parentIndices, that.parentIndices) &&
+ Arrays.equals(depthLevels, that.depthLevels) &&
+ Arrays.equals(wireTypes, that.wireTypes) &&
+ Arrays.equals(outputTypeIds, that.outputTypeIds) &&
+ Arrays.equals(encodings, that.encodings) &&
+ Arrays.equals(isRepeated, that.isRepeated) &&
+ Arrays.equals(isRequired, that.isRequired) &&
+ Arrays.equals(hasDefaultValue, that.hasDefaultValue) &&
+ Arrays.equals(defaultInts, that.defaultInts) &&
+ Arrays.equals(defaultFloats, that.defaultFloats) &&
+ Arrays.equals(defaultBools, that.defaultBools) &&
+ Arrays.deepEquals(
+ defaultStrings.asInstanceOf[Array[Object]],
+ that.defaultStrings.asInstanceOf[Array[Object]]) &&
+ Arrays.deepEquals(
+ enumValidValues.asInstanceOf[Array[Object]],
+ that.enumValidValues.asInstanceOf[Array[Object]]) &&
+ Arrays.deepEquals(
+ enumNames.asInstanceOf[Array[Object]],
+ that.enumNames.asInstanceOf[Array[Object]])
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ var result = Arrays.hashCode(fieldNumbers)
+ result = 31 * result + Arrays.hashCode(parentIndices)
+ result = 31 * result + Arrays.hashCode(depthLevels)
+ result = 31 * result + Arrays.hashCode(wireTypes)
+ result = 31 * result + Arrays.hashCode(outputTypeIds)
+ result = 31 * result + Arrays.hashCode(encodings)
+ result = 31 * result + Arrays.hashCode(isRepeated)
+ result = 31 * result + Arrays.hashCode(isRequired)
+ result = 31 * result + Arrays.hashCode(hasDefaultValue)
+ result = 31 * result + Arrays.hashCode(defaultInts)
+ result = 31 * result + Arrays.hashCode(defaultFloats)
+ result = 31 * result + Arrays.hashCode(defaultBools)
+ result = 31 * result + Arrays.deepHashCode(defaultStrings.asInstanceOf[Array[Object]])
+ result = 31 * result + Arrays.deepHashCode(enumValidValues.asInstanceOf[Array[Object]])
+ result = 31 * result + Arrays.deepHashCode(enumNames.asInstanceOf[Array[Object]])
+ result
+ }
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala
new file mode 100644
index 00000000000..31758c97236
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala
@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.protobuf
+
+import ai.rapids.cudf.DType
+
+import org.apache.spark.sql.rapids.GpuFromProtobuf
+import org.apache.spark.sql.types._
+
+object ProtobufSchemaValidator {
+ private final case class JniDefaultValues(
+ defaultInt: Long,
+ defaultFloat: Double,
+ defaultBool: Boolean,
+ defaultString: Array[Byte])
+
+ def toFlattenedFieldDescriptor(
+ path: String,
+ field: StructField,
+ fieldInfo: ProtobufFieldInfo,
+ parentIdx: Int,
+ depth: Int,
+ outputTypeId: Int): Either[String, FlattenedFieldDescriptor] = {
+ validateFieldInfo(path, field, fieldInfo).flatMap { _ =>
+ ProtobufSchemaExtractor
+ .getWireType(fieldInfo.protoTypeName, fieldInfo.encoding)
+ .flatMap { wireType =>
+ encodeDefaultValue(path, field.dataType, fieldInfo).map { defaults =>
+ val enumValidValues = fieldInfo.enumMetadata.map(_.validValues).orNull
+ val enumNames =
+ if (fieldInfo.encoding == GpuFromProtobuf.ENC_ENUM_STRING) {
+ fieldInfo.enumMetadata.map(_.orderedNames).orNull
+ } else {
+ null
+ }
+
+ FlattenedFieldDescriptor(
+ fieldNumber = fieldInfo.fieldNumber,
+ parentIdx = parentIdx,
+ depth = depth,
+ wireType = wireType,
+ outputTypeId = outputTypeId,
+ encoding = fieldInfo.encoding,
+ isRepeated = fieldInfo.isRepeated,
+ isRequired = fieldInfo.isRequired,
+ hasDefaultValue = fieldInfo.hasDefaultValue,
+ defaultInt = defaults.defaultInt,
+ defaultFloat = defaults.defaultFloat,
+ defaultBool = defaults.defaultBool,
+ defaultString = defaults.defaultString,
+ enumValidValues = enumValidValues,
+ enumNames = enumNames
+ )
+ }
+ }
+ }
+ }
+
+ def validateFlattenedSchema(flatFields: Seq[FlattenedFieldDescriptor]): Either[String, Unit] = {
+ val structTypeId = DType.STRUCT.getTypeId.getNativeId
+ flatFields.zipWithIndex.foreach { case (field, idx) =>
+ if (field.parentIdx >= idx) {
+ return Left(s"Flattened protobuf schema has invalid parent index at position $idx")
+ }
+ if (field.parentIdx == -1 && field.depth != 0) {
+ return Left(s"Top-level protobuf field at position $idx must have depth 0")
+ }
+ if (field.parentIdx >= 0 && field.depth <= 0) {
+ return Left(s"Nested protobuf field at position $idx must have positive depth")
+ }
+ if (field.parentIdx >= 0 && flatFields(field.parentIdx).outputTypeId != structTypeId) {
+ return Left(
+ s"Protobuf field at position $idx has a non-STRUCT parent at ${field.parentIdx}")
+ }
+ if (field.encoding == GpuFromProtobuf.ENC_ENUM_STRING) {
+ if (field.enumValidValues == null || field.enumNames == null) {
+ return Left(s"Enum-string field at position $idx is missing enum metadata")
+ }
+ if (field.enumValidValues.length != field.enumNames.length) {
+ return Left(s"Enum-string field at position $idx has mismatched enum metadata")
+ }
+ }
+ }
+ Right(())
+ }
+
+ def toFlattenedSchemaArrays(
+ flatFields: Array[FlattenedFieldDescriptor]): FlattenedSchemaArrays = {
+ FlattenedSchemaArrays(
+ fieldNumbers = flatFields.map(_.fieldNumber),
+ parentIndices = flatFields.map(_.parentIdx),
+ depthLevels = flatFields.map(_.depth),
+ wireTypes = flatFields.map(_.wireType),
+ outputTypeIds = flatFields.map(_.outputTypeId),
+ encodings = flatFields.map(_.encoding),
+ isRepeated = flatFields.map(_.isRepeated),
+ isRequired = flatFields.map(_.isRequired),
+ hasDefaultValue = flatFields.map(_.hasDefaultValue),
+ defaultInts = flatFields.map(_.defaultInt),
+ defaultFloats = flatFields.map(_.defaultFloat),
+ defaultBools = flatFields.map(_.defaultBool),
+ defaultStrings = flatFields.map(_.defaultString),
+ enumValidValues = flatFields.map(_.enumValidValues),
+ enumNames = flatFields.map(_.enumNames)
+ )
+ }
+
+ private def validateFieldInfo(
+ path: String,
+ field: StructField,
+ fieldInfo: ProtobufFieldInfo): Either[String, Unit] = {
+ if (fieldInfo.isRepeated && fieldInfo.hasDefaultValue) {
+ return Left(s"Repeated protobuf field '$path' cannot carry a default value")
+ }
+
+ fieldInfo.enumMetadata match {
+ case Some(enumMeta) if enumMeta.values.isEmpty =>
+ return Left(s"Enum field '$path' is missing enum values")
+ case Some(_) if fieldInfo.protoTypeName != "ENUM" =>
+ return Left(s"Non-enum field '$path' should not carry enum metadata")
+ case None if fieldInfo.protoTypeName == "ENUM" =>
+ return Left(s"Enum field '$path' is missing enum metadata")
+ case _ =>
+ }
+
+ if (fieldInfo.encoding == GpuFromProtobuf.ENC_ENUM_STRING &&
+ fieldInfo.enumMetadata.isEmpty) {
+ return Left(s"Enum-string field '$path' is missing enum metadata")
+ }
+
+ Right(())
+ }
+
+ private def encodeDefaultValue(
+ path: String,
+ dataType: DataType,
+ fieldInfo: ProtobufFieldInfo): Either[String, JniDefaultValues] = {
+ val empty = JniDefaultValues(0L, 0.0, defaultBool = false, null)
+ fieldInfo.defaultValue match {
+ case None => Right(empty)
+ case Some(defaultValue) =>
+ val targetType = dataType match {
+ case ArrayType(elementType, _) => elementType
+ case other => other
+ }
+ (targetType, defaultValue) match {
+ case (BooleanType, ProtobufDefaultValue.BoolValue(value)) =>
+ Right(empty.copy(defaultBool = value))
+ case (IntegerType | LongType, ProtobufDefaultValue.IntValue(value)) =>
+ Right(empty.copy(defaultInt = value))
+ case (IntegerType | LongType, ProtobufDefaultValue.EnumValue(number, _)) =>
+ Right(empty.copy(defaultInt = number.toLong))
+ case (FloatType, ProtobufDefaultValue.FloatValue(value)) =>
+ Right(empty.copy(defaultFloat = value.toDouble))
+ case (DoubleType, ProtobufDefaultValue.DoubleValue(value)) =>
+ Right(empty.copy(defaultFloat = value))
+ case (StringType, ProtobufDefaultValue.StringValue(value)) =>
+ Right(empty.copy(defaultString = value.getBytes("UTF-8")))
+ case (StringType, ProtobufDefaultValue.EnumValue(number, name)) =>
+ Right(empty.copy(
+ defaultInt = number.toLong,
+ defaultString = name.getBytes("UTF-8")))
+ case (BinaryType, ProtobufDefaultValue.BinaryValue(value)) =>
+ Right(empty.copy(defaultString = value))
+ case _ =>
+ Left(s"Incompatible default value for protobuf field '$path': $defaultValue")
+ }
+ }
+ }
+}
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
new file mode 100644
index 00000000000..13a5f7f6a0c
--- /dev/null
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
@@ -0,0 +1,910 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "340"}
+{"spark": "341"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "344"}
+{"spark": "350"}
+{"spark": "351"}
+{"spark": "352"}
+{"spark": "353"}
+{"spark": "354"}
+{"spark": "355"}
+{"spark": "356"}
+{"spark": "357"}
+{"spark": "400"}
+{"spark": "401"}
+{"spark": "402"}
+{"spark": "411"}
+spark-rapids-shim-json-lines ***/
+
+package com.nvidia.spark.rapids.shims
+
+import scala.collection.mutable
+
+import ai.rapids.cudf.DType
+import com.nvidia.spark.rapids._
+
+import org.apache.spark.sql.catalyst.expressions.{
+ AttributeReference, Expression, GetArrayStructFields, GetStructField, UnaryExpression
+}
+import org.apache.spark.sql.execution.ProjectExec
+import org.apache.spark.sql.rapids.GpuFromProtobuf
+import org.apache.spark.sql.rapids.protobuf.{
+ FlattenedFieldDescriptor,
+ ProtobufFieldInfo,
+ ProtobufMessageDescriptor,
+ ProtobufSchemaExtractor,
+ ProtobufSchemaValidator
+}
+import org.apache.spark.sql.types._
+
+/**
+ * Spark 3.4+ optional integration for spark-protobuf expressions.
+ *
+ * spark-protobuf is an external module, so these rules must be registered by reflection.
+ */
+object ProtobufExprShims extends org.apache.spark.internal.Logging {
+ private[this] val protobufDataToCatalystClassName =
+ "org.apache.spark.sql.protobuf.ProtobufDataToCatalyst"
+
+ val PRUNED_ORDINAL_TAG =
+ org.apache.spark.sql.rapids.GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG
+
+ def exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
+ try {
+ val clazz = ShimReflectionUtils.loadClass(protobufDataToCatalystClassName)
+ .asInstanceOf[Class[_ <: UnaryExpression]]
+ Map(clazz.asInstanceOf[Class[_ <: Expression]] -> fromProtobufRule)
+ } catch {
+ case _: ClassNotFoundException => Map.empty
+ case e: Exception =>
+ logWarning(s"Failed to load $protobufDataToCatalystClassName: ${e.getMessage}")
+ Map.empty
+ case e: Error =>
+ logWarning(
+ s"JVM error while loading $protobufDataToCatalystClassName: ${e.getMessage}")
+ Map.empty
+ }
+ }
+
+ private def fromProtobufRule: ExprRule[_ <: Expression] = {
+ GpuOverrides.expr[UnaryExpression](
+ "Decode a BinaryType column (protobuf) into a Spark SQL struct",
+ ExprChecks.unaryProject(
+ // Use TypeSig.all here because schema projection determines which fields
+ // actually need GPU support. Detailed type checking is done in tagExprForGpu.
+ TypeSig.all,
+ TypeSig.all,
+ TypeSig.BINARY,
+ TypeSig.BINARY),
+ (e, conf, p, r) => new UnaryExprMeta[UnaryExpression](e, conf, p, r) {
+
+ private var fullSchema: StructType = _
+ private var failOnErrors: Boolean = _
+
+ // Flattened schema variables for GPU decoding
+ private var flatFieldNumbers: Array[Int] = _
+ private var flatParentIndices: Array[Int] = _
+ private var flatDepthLevels: Array[Int] = _
+ private var flatWireTypes: Array[Int] = _
+ private var flatOutputTypeIds: Array[Int] = _
+ private var flatEncodings: Array[Int] = _
+ private var flatIsRepeated: Array[Boolean] = _
+ private var flatIsRequired: Array[Boolean] = _
+ private var flatHasDefaultValue: Array[Boolean] = _
+ private var flatDefaultInts: Array[Long] = _
+ private var flatDefaultFloats: Array[Double] = _
+ private var flatDefaultBools: Array[Boolean] = _
+ private var flatDefaultStrings: Array[Array[Byte]] = _
+ private var flatEnumValidValues: Array[Array[Int]] = _
+ private var flatEnumNames: Array[Array[Array[Byte]]] = _
+ // Indices in fullSchema for top-level fields that were decoded (for schema projection)
+ private var decodedTopLevelIndices: Array[Int] = _
+
+ override def tagExprForGpu(): Unit = {
+ fullSchema = e.dataType match {
+ case st: StructType => st
+ case other =>
+ willNotWorkOnGpu(
+ s"Only StructType output is supported for from_protobuf, got $other")
+ return
+ }
+
+ val exprInfo = SparkProtobufCompat.extractExprInfo(e) match {
+ case Right(info) => info
+ case Left(reason) =>
+ willNotWorkOnGpu(reason)
+ return
+ }
+ val unsupportedOptions = SparkProtobufCompat.unsupportedOptions(exprInfo.options)
+ if (unsupportedOptions.nonEmpty) {
+ val keys = unsupportedOptions.mkString(",")
+ willNotWorkOnGpu(
+ s"from_protobuf options are not supported yet on GPU: $keys")
+ return
+ }
+
+ val plannerOptions = SparkProtobufCompat.parsePlannerOptions(exprInfo.options) match {
+ case Right(opts) => opts
+ case Left(reason) =>
+ willNotWorkOnGpu(reason)
+ return
+ }
+ val enumsAsInts = plannerOptions.enumsAsInts
+ failOnErrors = plannerOptions.failOnErrors
+ val messageName = exprInfo.messageName
+
+ val msgDesc = SparkProtobufCompat.resolveMessageDescriptor(exprInfo) match {
+ case Right(desc) => desc
+ case Left(reason) =>
+ willNotWorkOnGpu(reason)
+ return
+ }
+
+          // Reject proto3/editions descriptors — the GPU decoder only supports proto2 semantics.
+ // proto3 has different null/default-value behavior that the GPU path doesn't handle.
+ val protoSyntax = msgDesc.syntax
+ if (!SparkProtobufCompat.isGpuSupportedProtoSyntax(protoSyntax)) {
+ willNotWorkOnGpu(
+ "proto3/editions syntax is not supported by the GPU protobuf decoder; " +
+ "only proto2 is supported. The query will fall back to CPU.")
+ return
+ }
+
+ // Step 1: Analyze all fields and build field info map
+ val fieldsInfoMap =
+ ProtobufSchemaExtractor.analyzeAllFields(fullSchema, msgDesc, enumsAsInts, messageName)
+ .fold({ reason =>
+ willNotWorkOnGpu(reason)
+ return
+ }, identity)
+
+ // Step 2: Determine which fields are actually required by downstream operations
+ val requiredFieldNames = analyzeRequiredFields(fieldsInfoMap.keySet)
+
+ // Step 2b: Proto2 required fields must always be decoded so the GPU can
+ // detect missing-required and null the struct row (PERMISSIVE) or throw
+ // (FAILFAST), matching CPU behavior. Without this, schema projection
+ // can prune a required field and the GPU silently produces a non-null
+ // struct where CPU would have returned null.
+ val protoRequiredFieldNames = fieldsInfoMap.collect {
+ case (name, info) if info.isRequired => name
+ }.toSet
+ val allFieldsToDecode = requiredFieldNames ++ protoRequiredFieldNames
+
+ // Step 2c: When nested pruning selects only some children of a struct,
+ // proto2 required children must still be included so their presence
+ // can be checked by the GPU decoder.
+ augmentNestedRequirementsWithRequired(msgDesc)
+
+ // Step 3: Check if all fields to be decoded are supported
+ val unsupportedRequired = allFieldsToDecode.filter { name =>
+ fieldsInfoMap.get(name).exists(!_.isSupported)
+ }
+
+ if (unsupportedRequired.nonEmpty) {
+ val reasons = unsupportedRequired.map { name =>
+ val info = fieldsInfoMap(name)
+ s"${name}: ${info.unsupportedReason.getOrElse("unknown reason")}"
+ }
+ willNotWorkOnGpu(
+ s"Required fields not supported for from_protobuf: ${reasons.mkString(", ")}")
+ return
+ }
+
+ // Step 4: Identify which fields in fullSchema need to be decoded
+ val indicesToDecode = fullSchema.fields.zipWithIndex.collect {
+ case (sf, idx) if allFieldsToDecode.contains(sf.name) => idx
+ }
+
+ // Verify all fields to be decoded are actually supported
+          // (Defensive re-check: Step 3 validated the set by name; this re-validates the index-resolved fields)
+ val unsupportedInDecode = indicesToDecode.filter { idx =>
+ val sf = fullSchema.fields(idx)
+ fieldsInfoMap.get(sf.name).exists(!_.isSupported)
+ }
+ if (unsupportedInDecode.nonEmpty) {
+ val reasons = unsupportedInDecode.map { idx =>
+ val sf = fullSchema.fields(idx)
+ val info = fieldsInfoMap(sf.name)
+ s"${sf.name}: ${info.unsupportedReason.getOrElse("unknown reason")}"
+ }
+ willNotWorkOnGpu(
+ s"Fields not supported for from_protobuf: ${reasons.mkString(", ")}")
+ return
+ }
+
+ // Step 5: Build flattened schema for GPU decoding.
+ // The flattened schema represents nested fields with parent indices.
+ // For pure scalar schemas, all fields are top-level (parentIdx == -1, depth == 0).
+ {
+ val flatFields = mutable.ArrayBuffer[FlattenedFieldDescriptor]()
+ var step5Failed = false
+
+ def failStep5(reason: String): Unit = {
+ step5Failed = true
+ willNotWorkOnGpu(reason)
+ }
+
+ // Helper to add a field and its children recursively.
+ // pathPrefix is the dot-path of ancestor fields (empty for top-level).
+ def addFieldWithChildren(
+ sf: StructField,
+ info: ProtobufFieldInfo,
+ parentIdx: Int,
+ depth: Int,
+ containingMsgDesc: ProtobufMessageDescriptor,
+ pathPrefix: String = ""): Unit = {
+
+ val currentIdx = flatFields.size
+
+ if (depth >= 10) {
+ failStep5("Protobuf nesting depth exceeds maximum supported depth of 10")
+ return
+ }
+
+ val outputTypeOpt = sf.dataType match {
+ case ArrayType(elemType, _) =>
+ elemType match {
+ case _: StructType =>
+ // Repeated message field: ArrayType(StructType) - element type is STRUCT
+ Some(DType.STRUCT.getTypeId.getNativeId)
+ case other =>
+ GpuFromProtobuf.sparkTypeToCudfIdOpt(other)
+ }
+ case _: StructType =>
+ Some(DType.STRUCT.getTypeId.getNativeId)
+ case other =>
+ GpuFromProtobuf.sparkTypeToCudfIdOpt(other)
+ }
+ val outputType = outputTypeOpt.getOrElse {
+ failStep5(
+ s"Unsupported Spark type for protobuf field '${sf.name}': ${sf.dataType}")
+ return
+ }
+
+ val path = if (pathPrefix.isEmpty) sf.name else s"$pathPrefix.${sf.name}"
+ ProtobufSchemaValidator.toFlattenedFieldDescriptor(
+ path,
+ sf,
+ info,
+ parentIdx,
+ depth,
+ outputType).fold({ reason =>
+ failStep5(reason)
+ return
+ }, flatFields += _)
+
+ // For nested struct types (including repeated message = ArrayType(StructType)),
+ // add child fields
+ sf.dataType match {
+ case st: StructType if containingMsgDesc != null =>
+ // Repeated message parents and plain struct parents share the same child
+ // expansion path; the flat parent entry's isRepeated flag distinguishes them.
+ addChildFieldsFromStruct(
+ st, containingMsgDesc, sf.name, currentIdx, depth, pathPrefix)
+
+ case ArrayType(st: StructType, _) if containingMsgDesc != null =>
+ addChildFieldsFromStruct(
+ st, containingMsgDesc, sf.name, currentIdx, depth, pathPrefix)
+
+ case _ => // Not a struct, no children to add
+ }
+ }
+
+ // Helper to add child fields from a struct type.
+ // Applies nested schema pruning at arbitrary depth using path-based
+ // lookup into nestedFieldRequirements.
+ def addChildFieldsFromStruct(
+ st: StructType,
+ containingMsgDesc: ProtobufMessageDescriptor,
+ fieldName: String,
+ parentIdx: Int,
+ parentDepth: Int,
+ pathPrefix: String): Unit = {
+ val path = if (pathPrefix.isEmpty) fieldName else s"$pathPrefix.$fieldName"
+ // containingMsgDesc is the descriptor of the message that directly contains
+ // fieldName.
+ val parentField = containingMsgDesc.findField(fieldName)
+ if (parentField.isEmpty) {
+ failStep5(
+ s"Nested field '$fieldName' not found in protobuf descriptor at '$path'")
+ return
+ } else {
+ parentField.get.messageDescriptor match {
+ case Some(childMsgDesc) =>
+ val requiredChildren = nestedFieldRequirements.get(path)
+ val filteredFields = requiredChildren match {
+ case Some(Some(childNames)) =>
+ st.fields.filter(f => childNames.contains(f.name))
+ case _ =>
+ st.fields
+ }
+ filteredFields.foreach { childSf =>
+ childMsgDesc.findField(childSf.name) match {
+ case None =>
+ failStep5(
+ s"Nested field '${childSf.name}' not found in protobuf " +
+ s"descriptor for message at '$path'")
+ return
+ case Some(childFd) =>
+ ProtobufSchemaExtractor
+ .extractFieldInfo(childSf, childFd, enumsAsInts) match {
+ case Left(reason) =>
+ failStep5(reason)
+ return
+ case Right(childInfo) =>
+ if (!childInfo.isSupported) {
+ failStep5(
+ s"Nested field '${childSf.name}' at '$path': " +
+ childInfo.unsupportedReason.getOrElse("unsupported type"))
+ return
+ } else {
+ addFieldWithChildren(
+ childSf, childInfo, parentIdx, parentDepth + 1, childMsgDesc,
+ path)
+ }
+ }
+ }
+ }
+ case None =>
+ failStep5(
+ s"Nested field '$fieldName' at '$path' did not resolve to a message type")
+ return
+ }
+ }
+ }
+
+ // Only add top-level fields that are actually required (schema projection).
+ // This significantly reduces GPU memory and computation for schemas with many
+ // fields when only a few are needed. Downstream GetStructField ordinals are
+ // remapped via PRUNED_ORDINAL_TAG to index into the pruned output.
+ decodedTopLevelIndices = indicesToDecode
+ indicesToDecode.foreach { schemaIdx =>
+ if (!step5Failed) {
+ val sf = fullSchema.fields(schemaIdx)
+ val info = fieldsInfoMap(sf.name)
+ addFieldWithChildren(sf, info, -1, 0, msgDesc)
+ }
+ }
+
+ if (step5Failed) {
+ return
+ }
+
+ // Populate flattened schema variables
+ val flat = flatFields.toArray
+ ProtobufSchemaValidator.validateFlattenedSchema(flat).fold({ reason =>
+ failStep5(reason)
+ return
+ }, identity)
+ val arrays = ProtobufSchemaValidator.toFlattenedSchemaArrays(flat)
+ flatFieldNumbers = arrays.fieldNumbers
+ flatParentIndices = arrays.parentIndices
+ flatDepthLevels = arrays.depthLevels
+ flatWireTypes = arrays.wireTypes
+ flatOutputTypeIds = arrays.outputTypeIds
+ flatEncodings = arrays.encodings
+ flatIsRepeated = arrays.isRepeated
+ flatIsRequired = arrays.isRequired
+ flatHasDefaultValue = arrays.hasDefaultValue
+ flatDefaultInts = arrays.defaultInts
+ flatDefaultFloats = arrays.defaultFloats
+ flatDefaultBools = arrays.defaultBools
+ flatDefaultStrings = arrays.defaultStrings
+ flatEnumValidValues = arrays.enumValidValues
+ flatEnumNames = arrays.enumNames
+
+ val prunedFieldsMap = buildPrunedFieldsMap()
+ // PRUNED_ORDINAL_TAG is set here, after all willNotWorkOnGpu guards succeed.
+ // This is safe because the CPU path never reads the tag, and if a parent later
+ // forces this subtree back to CPU, the decode and its field extractors fall back
+ // together, so no partial-GPU path can misread stale ordinals.
+ targetExprsToRemap.foreach(
+ registerPrunedOrdinals(_, prunedFieldsMap, decodedTopLevelIndices.toSeq))
+ overrideDataType(buildDecodedSchema(prunedFieldsMap))
+ }
+ }
+
+ /**
+ * Analyze which fields are actually required by downstream operations.
+ * Traverses parent plan nodes upward, collecting struct field references from
+ * ProjectExec, FilterExec, and transparent pass-through nodes (AggregateExec,
+ * SortExec, WindowExec, etc.), then returns the set of required top-level
+ * field names.
+ *
+ * @param allFieldNames All field names in the full schema
+ * @return Set of field names that are actually required
+ */
+ private var targetExprsToRemap: Seq[Expression] = Seq.empty
+
+ private def analyzeRequiredFields(allFieldNames: Set[String]): Set[String] = {
+ val fieldReqs = mutable.Map[String, Option[Set[String]]]()
+ protobufOutputExprIds = Set.empty
+ var hasDirectStructRef = false
+ val holder = () => { hasDirectStructRef = true }
+
+ var currentMeta: Option[SparkPlanMeta[_]] = findParentPlanMeta()
+ var safeToPrune = true
+ val collectedExprs = mutable.ArrayBuffer[Expression]()
+ val startingPlanMeta = currentMeta
+
+ def advanceToParent(): Unit = {
+ currentMeta = currentMeta.get.parent match {
+ case Some(pm: SparkPlanMeta[_]) => Some(pm)
+ case _ => None
+ }
+ }
+
+ while (currentMeta.isDefined && safeToPrune) {
+ val allowSemanticReferenceMatch = currentMeta == startingPlanMeta
+ currentMeta.get.wrapped match {
+ case p: ProjectExec =>
+ collectedExprs ++= p.projectList
+ p.projectList.foreach {
+ case alias: org.apache.spark.sql.catalyst.expressions.Alias
+ if isProtobufStructReference(
+ alias.child, allowSemanticReferenceMatch) =>
+ protobufOutputExprIds += alias.exprId
+ case _ =>
+ }
+ p.projectList.foreach(
+ collectStructFieldReferences(
+ _, fieldReqs, holder, allowSemanticReferenceMatch))
+ advanceToParent()
+ case f: org.apache.spark.sql.execution.FilterExec =>
+ collectedExprs += f.condition
+ collectStructFieldReferences(
+ f.condition, fieldReqs, holder, allowSemanticReferenceMatch)
+ advanceToParent()
+ case a: org.apache.spark.sql.execution.aggregate.BaseAggregateExec =>
+ val exprs = a.aggregateExpressions ++ a.groupingExpressions
+ collectedExprs ++= exprs
+ exprs.foreach(
+ collectStructFieldReferences(
+ _, fieldReqs, holder, allowSemanticReferenceMatch))
+ advanceToParent()
+ case s: org.apache.spark.sql.execution.SortExec =>
+ val exprs = s.sortOrder
+ collectedExprs ++= exprs
+ exprs.foreach(
+ collectStructFieldReferences(
+ _, fieldReqs, holder, allowSemanticReferenceMatch))
+ advanceToParent()
+ case w: org.apache.spark.sql.execution.window.WindowExec =>
+ val exprs = w.windowExpression
+ collectedExprs ++= exprs
+ exprs.foreach(
+ collectStructFieldReferences(
+ _, fieldReqs, holder, allowSemanticReferenceMatch))
+ advanceToParent()
+ case other =>
+ // Keep schema projection conservative for less common plan shapes above
+ // from_protobuf. Those plans currently fall back to full-schema decode
+ // until we add dedicated coverage for pruning through them.
+ logDebug(s"Schema pruning disabled: unrecognized plan node " +
+ s"${other.getClass.getSimpleName} above from_protobuf")
+ safeToPrune = false
+ }
+ }
+
+ // An empty fieldReqs also subsumes the "no relevant expressions collected" case.
+ if (!safeToPrune || hasDirectStructRef || fieldReqs.isEmpty) {
+ targetExprsToRemap = Seq.empty
+ allFieldNames
+ } else {
+ nestedFieldRequirements = fieldReqs.toMap
+ targetExprsToRemap = collectedExprs.toSeq
+ fieldReqs.keySet.toSet
+ }
+ }
+
+ /**
+ * Find the parent SparkPlanMeta by traversing up the parent chain.
+ */
+ private def findParentPlanMeta(): Option[SparkPlanMeta[_]] = {
+ def traverse(meta: Option[RapidsMeta[_, _, _]]): Option[SparkPlanMeta[_]] = {
+ meta match {
+ case Some(p: SparkPlanMeta[_]) => Some(p)
+ case Some(p: RapidsMeta[_, _, _]) => traverse(p.parent)
+ case _ => None
+ }
+ }
+ traverse(parent)
+ }
+
+ /**
+ * Nested field requirements: maps a field path to child requirements.
+ * Keys are dot-separated paths from the protobuf root:
+ * - "level1" -> Some(Set("level2")) (top-level struct pruning)
+ * - "level1.level2" -> Some(Set("level3")) (deep nested pruning)
+ * - "field" -> None (whole field needed)
+ *
+ * Top-level names (keys without dots) also determine which fields are decoded.
+ */
+ private var nestedFieldRequirements: Map[String, Option[Set[String]]] = Map.empty
+ private var protobufOutputExprIds: Set[
+ org.apache.spark.sql.catalyst.expressions.ExprId] = Set.empty
+ private lazy val protobufOutputExprId
+ : Option[org.apache.spark.sql.catalyst.expressions.ExprId] =
+ parent.flatMap { meta =>
+ meta.wrapped match {
+ case alias: org.apache.spark.sql.catalyst.expressions
+ .Alias if alias.child.semanticEquals(e) =>
+ Some(alias.exprId)
+ case _ => None
+ }
+ }
+
+ private def getFieldName(ordinal: Int, nameOpt: Option[String],
+ schema: StructType): String = {
+ nameOpt.getOrElse {
+ if (ordinal < schema.fields.length) schema.fields(ordinal).name
+ else s"_$ordinal"
+ }
+ }
+
+ /**
+ * Navigate the protobuf descriptor tree by following a path of field names.
+ * Returns the message descriptor at the end of the path, or None.
+ */
+ private def resolveProtoMsgDesc(
+ rootDesc: ProtobufMessageDescriptor,
+ pathParts: Seq[String]): Option[ProtobufMessageDescriptor] = {
+ pathParts.headOption match {
+ case None => Some(rootDesc)
+ case Some(head) =>
+ rootDesc.findField(head).flatMap(_.messageDescriptor).flatMap { childDesc =>
+ if (pathParts.tail.isEmpty) Some(childDesc)
+ else resolveProtoMsgDesc(childDesc, pathParts.tail)
+ }
+ }
+ }
+
+ /**
+ * Augment nestedFieldRequirements so that proto2 required children are
+ * always included when nested pruning is active. Without this, a required
+ * child that is not referenced downstream would be pruned, and the GPU
+ * decoder would not check its presence.
+ */
+ private def augmentNestedRequirementsWithRequired(
+ rootMsgDesc: ProtobufMessageDescriptor): Unit = {
+ if (nestedFieldRequirements.isEmpty) return
+ nestedFieldRequirements = nestedFieldRequirements.map {
+ case entry @ (pathKey, Some(childNames)) =>
+ val pathParts = pathKey.split("\\.").toSeq
+ resolveProtoMsgDesc(rootMsgDesc, pathParts) match {
+ case Some(childMsgDesc) =>
+ val schemaType = resolveSchemaAtPath(fullSchema, pathParts)
+ if (schemaType != null) {
+ val requiredChildNames = schemaType.fields.flatMap { sf =>
+ childMsgDesc.findField(sf.name).filter(_.isRequired).map(_ => sf.name)
+ }.toSet
+ if (requiredChildNames.subsetOf(childNames)) entry
+ else pathKey -> Some(childNames ++ requiredChildNames)
+ } else {
+ entry
+ }
+ case None => entry
+ }
+ case other => other
+ }
+ }
+
+ /**
+ * Navigate the Spark schema tree by following a dot-separated path of
+ * field names. Returns the StructType at the end of the path, unwrapping
+ * ArrayType(StructType) along the way, or null if the path is invalid.
+ */
+ private def resolveSchemaAtPath(root: StructType, path: Seq[String]): StructType = {
+ var current: StructType = root
+ for (name <- path) {
+ val field = current.fields.find(_.name == name).orNull
+ if (field == null) return null
+ field.dataType match {
+ case st: StructType => current = st
+ case ArrayType(st: StructType, _) => current = st
+ case _ => return null
+ }
+ }
+ current
+ }
+
+ /**
+ * Walk a GetStructField chain upward until it reaches the protobuf
+ * reference expression, returning the sequence of field names forming
+ * the access path. Returns None if the chain does not terminate at a
+ * protobuf reference.
+ *
+ * Example: for `GetStructField(GetStructField(decoded, a_ord), b_ord)`
+ * → Some(Seq("a", "b"))
+ */
+ private def resolveFieldAccessChain(
+ expr: Expression,
+ allowSemanticReferenceMatch: Boolean): Option[Seq[String]] = {
+ expr match {
+ case GetStructField(child, ordinal, nameOpt) =>
+ if (isProtobufStructReference(child, allowSemanticReferenceMatch)) {
+ Some(Seq(getFieldName(ordinal, nameOpt, fullSchema)))
+ } else {
+ resolveFieldAccessChain(child, allowSemanticReferenceMatch).flatMap { parentPath =>
+ val parentSchema = if (parentPath.isEmpty) fullSchema
+ else resolveSchemaAtPath(fullSchema, parentPath)
+ if (parentSchema != null) {
+ Some(parentPath :+ getFieldName(ordinal, nameOpt, parentSchema))
+ } else {
+ None
+ }
+ }
+ }
+ case _ if isProtobufStructReference(expr, allowSemanticReferenceMatch) =>
+ Some(Seq.empty)
+ case _ =>
+ None
+ }
+ }
+
+ private def addNestedFieldReq(
+ fieldReqs: mutable.Map[String, Option[Set[String]]],
+ parentKey: String,
+ childName: String): Unit = {
+ fieldReqs.get(parentKey) match {
+ case Some(None) => // Already need whole field, keep it
+ case Some(Some(existing)) =>
+ fieldReqs(parentKey) = Some(existing + childName)
+ case None =>
+ fieldReqs(parentKey) = Some(Set(childName))
+ }
+ }
+
+ /**
+ * Register pruning requirements at every level of a field access path.
+ * For path = ["a", "b"] with leafName = "c":
+ * "a" -> needs child "b"
+ * "a.b" -> needs child "c"
+ */
+ private def registerPathRequirements(
+ fieldReqs: mutable.Map[String, Option[Set[String]]],
+ path: Seq[String],
+ leafName: String): Unit = {
+ for (i <- path.indices) {
+ val pathKey = path.take(i + 1).mkString(".")
+ val childName = if (i < path.length - 1) path(i + 1) else leafName
+ addNestedFieldReq(fieldReqs, pathKey, childName)
+ }
+ }
+
+ private def collectStructFieldReferences(
+ expr: Expression,
+ fieldReqs: mutable.Map[String, Option[Set[String]]],
+ hasDirectStructRefHolder: () => Unit,
+ allowSemanticReferenceMatch: Boolean): Unit = {
+ expr match {
+ case GetStructField(child, ordinal, nameOpt) =>
+ resolveFieldAccessChain(child, allowSemanticReferenceMatch) match {
+ case Some(parentPath) =>
+ val parentSchema = if (parentPath.isEmpty) fullSchema
+ else resolveSchemaAtPath(fullSchema, parentPath)
+ if (parentSchema != null) {
+ val fieldName = getFieldName(ordinal, nameOpt, parentSchema)
+ if (parentPath.isEmpty) {
+ // Direct top-level access: decoded.field_name (whole field)
+ fieldReqs(fieldName) = None
+ } else {
+ registerPathRequirements(fieldReqs, parentPath, fieldName)
+ }
+ } else {
+ collectStructFieldReferences(
+ child, fieldReqs, hasDirectStructRefHolder, allowSemanticReferenceMatch)
+ }
+ case None =>
+ collectStructFieldReferences(
+ child, fieldReqs, hasDirectStructRefHolder, allowSemanticReferenceMatch)
+ }
+
+ case gasf: GetArrayStructFields =>
+ resolveFieldAccessChain(gasf.child, allowSemanticReferenceMatch) match {
+ case Some(parentPath) if parentPath.nonEmpty =>
+ registerPathRequirements(fieldReqs, parentPath, gasf.field.name)
+ case Some(parentPath) if parentPath.isEmpty =>
+ logDebug("Schema pruning disabled: unexpected direct protobuf reference in " +
+ "GetArrayStructFields")
+ hasDirectStructRefHolder()
+ case _ =>
+ gasf.children.foreach { child =>
+ collectStructFieldReferences(
+ child, fieldReqs, hasDirectStructRefHolder, allowSemanticReferenceMatch)
+ }
+ }
+
+ case alias: org.apache.spark.sql.catalyst.expressions.Alias =>
+ if (!isProtobufStructReference(alias.child, allowSemanticReferenceMatch)) {
+ collectStructFieldReferences(
+ alias.child, fieldReqs, hasDirectStructRefHolder, allowSemanticReferenceMatch)
+ }
+
+ case _ =>
+ if (isProtobufStructReference(expr, allowSemanticReferenceMatch)) {
+ hasDirectStructRefHolder()
+ }
+ expr.children.foreach { child =>
+ collectStructFieldReferences(
+ child, fieldReqs, hasDirectStructRefHolder, allowSemanticReferenceMatch)
+ }
+ }
+ }
+
+ private def buildPrunedFieldsMap(): Map[String, Seq[String]] = {
+ nestedFieldRequirements.collect {
+ case (pathKey, Some(childNames)) =>
+ val pathParts = pathKey.split("\\.").toSeq
+ val childSchema = resolveSchemaAtPath(fullSchema, pathParts)
+ if (childSchema != null) {
+ val orderedNames = childSchema.fields
+ .map(_.name)
+ .filter(childNames.contains)
+ .toSeq
+ pathKey -> orderedNames
+ } else {
+ // This path should be unreachable for valid schema paths, but keep the fallback
+ // deterministic rather than relying on Set iteration order.
+ pathKey -> childNames.toSeq.sorted
+ }
+ }
+ }
+
+ private def buildDecodedSchema(prunedFieldsMap: Map[String, Seq[String]]): StructType = {
+ def applyPruning(field: StructField, prefix: String): StructField = {
+ val path = if (prefix.isEmpty) field.name else s"$prefix.${field.name}"
+ prunedFieldsMap.get(path) match {
+ case Some(childNames) =>
+ field.dataType match {
+ case ArrayType(st: StructType, cn) =>
+ val pruned = StructType(
+ st.fields.filter(f => childNames.contains(f.name))
+ .map(f => applyPruning(f, path)))
+ field.copy(dataType = ArrayType(pruned, cn))
+ case st: StructType =>
+ val pruned = StructType(
+ st.fields.filter(f => childNames.contains(f.name))
+ .map(f => applyPruning(f, path)))
+ field.copy(dataType = pruned)
+ case _ => field
+ }
+ case None =>
+ field.dataType match {
+ case ArrayType(st: StructType, cn) =>
+ val recursed = StructType(st.fields.map(f => applyPruning(f, path)))
+ if (recursed != st) field.copy(dataType = ArrayType(recursed, cn))
+ else field
+ case st: StructType =>
+ val recursed = StructType(st.fields.map(f => applyPruning(f, path)))
+ if (recursed != st) field.copy(dataType = recursed)
+ else field
+ case _ => field
+ }
+ }
+ }
+
+ val decodedFields = decodedTopLevelIndices.map { idx =>
+ applyPruning(fullSchema.fields(idx), "")
+ }
+ StructType(decodedFields.map(f => f.copy(nullable = true)))
+ }
+
+ private def registerPrunedOrdinals(
+ expr: Expression,
+ prunedFieldsMap: Map[String, Seq[String]],
+ topLevelIndices: Seq[Int]): Unit = {
+ expr match {
+ case gsf @ GetStructField(childExpr, ordinal, nameOpt) =>
+ resolveFieldAccessChain(childExpr, allowSemanticReferenceMatch = true) match {
+ case Some(parentPath) if parentPath.nonEmpty =>
+ val parentSchema = resolveSchemaAtPath(fullSchema, parentPath)
+ if (parentSchema != null) {
+ val pathKey = parentPath.mkString(".")
+ val childName = getFieldName(ordinal, nameOpt, parentSchema)
+ prunedFieldsMap.get(pathKey).foreach { orderedChildren =>
+ val runtimeOrd = orderedChildren.indexOf(childName)
+ if (runtimeOrd >= 0) {
+ gsf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd)
+ }
+ }
+ }
+ case Some(parentPath) if parentPath.isEmpty =>
+ val runtimeOrd = topLevelIndices.indexOf(ordinal)
+ if (runtimeOrd >= 0) {
+ gsf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd)
+ }
+ case _ =>
+ }
+ case gasf @ GetArrayStructFields(childExpr, field, _, _, _) =>
+ resolveFieldAccessChain(childExpr, allowSemanticReferenceMatch = true) match {
+ case Some(parentPath) if parentPath.nonEmpty =>
+ val pathKey = parentPath.mkString(".")
+ prunedFieldsMap.get(pathKey).foreach { orderedChildren =>
+ val runtimeOrd = orderedChildren.indexOf(field.name)
+ if (runtimeOrd >= 0) {
+ gasf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd)
+ }
+ }
+ case _ =>
+ }
+ case _ =>
+ }
+ expr.children.foreach(registerPrunedOrdinals(_, prunedFieldsMap, topLevelIndices))
+ }
+
+ /**
+ * Check if an expression references the output of a protobuf decode expression.
+ * This can be either:
+ * 1. The ProtobufDataToCatalyst expression itself
+ * 2. An AttributeReference that references the output of ProtobufDataToCatalyst
+ * (when accessing from a downstream ProjectExec)
+ */
+ private def isProtobufStructReference(
+ expr: Expression,
+ allowSemanticReferenceMatch: Boolean): Boolean = {
+ if ((expr eq e) || expr.semanticEquals(e)) {
+ return true
+ }
+
+    // Catalyst may create a duplicate ProtobufDataToCatalyst
+    // instance for each GetStructField access. Match such
+    // copies by class, identical input child, and identical
+    // decode semantics so that analyzeRequiredFields detects
+    // all field accesses in a single pass, keeping schema
+    // projection correct.
+ if (allowSemanticReferenceMatch &&
+ expr.getClass == e.getClass &&
+ expr.children.nonEmpty &&
+ e.children.nonEmpty &&
+ ((expr.children.head eq e.children.head) ||
+ expr.children.head.semanticEquals(
+ e.children.head)) &&
+ SparkProtobufCompat.sameDecodeSemantics(expr, e)) {
+ return true
+ }
+
+ expr match {
+ case attr: AttributeReference =>
+ protobufOutputExprIds.contains(attr.exprId) ||
+ protobufOutputExprId.exists(_ == attr.exprId)
+ case _ => false
+ }
+ }
+
+ override def convertToGpu(child: Expression): GpuExpression = {
+ val prunedFieldsMap = buildPrunedFieldsMap()
+ val decodedSchema = buildDecodedSchema(prunedFieldsMap)
+
+ GpuFromProtobuf(
+ decodedSchema,
+ flatFieldNumbers, flatParentIndices,
+ flatDepthLevels, flatWireTypes, flatOutputTypeIds, flatEncodings,
+ flatIsRepeated, flatIsRequired, flatHasDefaultValue, flatDefaultInts,
+ flatDefaultFloats, flatDefaultBools, flatDefaultStrings, flatEnumValidValues,
+ flatEnumNames, failOnErrors, child)
+ }
+ }
+ )
+ }
+
+}
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
index 88c62eea41b..afb9c059ccf 100644
--- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
@@ -160,7 +160,7 @@ trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims {
),
GpuElementAtMeta.elementAtRule(true)
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
- super.getExprs ++ shimExprs
+ super.getExprs ++ shimExprs ++ ProtobufExprShims.exprs
}
override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala
new file mode 100644
index 00000000000..85f1c4ac12b
--- /dev/null
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala
@@ -0,0 +1,381 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "340"}
+{"spark": "341"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "344"}
+{"spark": "350"}
+{"spark": "351"}
+{"spark": "352"}
+{"spark": "353"}
+{"spark": "354"}
+{"spark": "355"}
+{"spark": "356"}
+{"spark": "357"}
+{"spark": "400"}
+{"spark": "401"}
+{"spark": "402"}
+{"spark": "411"}
+spark-rapids-shim-json-lines ***/
+
+package com.nvidia.spark.rapids.shims
+
+import java.lang.reflect.Method
+import java.nio.file.{Files, Paths}
+
+import scala.util.Try
+
+import com.nvidia.spark.rapids.ShimReflectionUtils
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.rapids.protobuf._
+
+private[shims] object SparkProtobufCompat extends Logging {
+ private[this] val sparkProtobufUtilsObjectClassName =
+ "org.apache.spark.sql.protobuf.utils.ProtobufUtils$"
+
+ val SupportedOptions: Set[String] = Set("enums.as.ints", "mode")
+
+ def extractExprInfo(e: Expression): Either[String, ProtobufExprInfo] = {
+ for {
+ messageName <- reflectMessageName(e)
+ options <- reflectOptions(e)
+ descriptorSource <- reflectDescriptorSource(e)
+ } yield ProtobufExprInfo(messageName, descriptorSource, options)
+ }
+
+ def sameDecodeSemantics(left: Expression, right: Expression): Boolean = {
+ (extractExprInfo(left), extractExprInfo(right)) match {
+ case (Right(leftInfo), Right(rightInfo)) => leftInfo == rightInfo
+ case _ => false
+ }
+ }
+
+ def parsePlannerOptions(
+ options: Map[String, String]): Either[String, ProtobufPlannerOptions] = {
+ val enumsAsInts = Try(options.getOrElse("enums.as.ints", "false").toBoolean)
+ .toEither
+ .left
+ .map { _ =>
+ "Invalid value for from_protobuf option 'enums.as.ints': " +
+ s"'${options.getOrElse("enums.as.ints", "")}' (expected true/false)"
+ }
+ enumsAsInts.map(v =>
+ ProtobufPlannerOptions(
+ enumsAsInts = v,
+ failOnErrors = options.getOrElse("mode", "FAILFAST").equalsIgnoreCase("FAILFAST")))
+ }
+
+ def unsupportedOptions(options: Map[String, String]): Seq[String] =
+ options.keys.filterNot(SupportedOptions.contains).toSeq.sorted
+
+ def isGpuSupportedProtoSyntax(syntax: String): Boolean =
+ syntax.nonEmpty && syntax != "null" && syntax != "PROTO3" && syntax != "EDITIONS"
+
+ def resolveMessageDescriptor(
+ exprInfo: ProtobufExprInfo): Either[String, ProtobufMessageDescriptor] = {
+ Try(buildMessageDescriptor(exprInfo.messageName, exprInfo.descriptorSource))
+ .toEither
+ .left
+ .map { t =>
+ s"Failed to resolve protobuf descriptor for message '${exprInfo.messageName}': " +
+ s"${t.getMessage}"
+ }
+ .map(new ReflectiveMessageDescriptor(_))
+ }
+
+ private def reflectMessageName(e: Expression): Either[String, String] =
+ Try(PbReflect.invoke0[String](e, "messageName")).toEither.left.map { t =>
+ s"Cannot read from_protobuf messageName via reflection: ${t.getMessage}"
+ }
+
+ private def reflectOptions(e: Expression): Either[String, Map[String, String]] = {
+ Try(PbReflect.invoke0[scala.collection.Map[String, String]](e, "options"))
+ .map(_.toMap)
+ .toEither.left.map { _ =>
+ "Cannot read from_protobuf options via reflection; falling back to CPU"
+ }
+ }
+
+ private def reflectDescriptorSource(e: Expression): Either[String, ProtobufDescriptorSource] = {
+ reflectDescFilePath(e).map(ProtobufDescriptorSource.DescriptorPath).orElse(
+ reflectDescriptorBytes(e).map(ProtobufDescriptorSource.DescriptorBytes)).toRight(
+ "from_protobuf requires a descriptor set (descFilePath or binaryFileDescriptorSet)")
+ }
+
+ private def reflectDescFilePath(e: Expression): Option[String] =
+ Try(PbReflect.invoke0[Option[String]](e, "descFilePath")).toOption.flatten
+
+ private def reflectDescriptorBytes(e: Expression): Option[Array[Byte]] = {
+ val spark35Result = Try(PbReflect.invoke0[Option[Array[Byte]]](e, "binaryFileDescriptorSet"))
+ .toOption.flatten
+ spark35Result.orElse {
+ val direct = Try(PbReflect.invoke0[Array[Byte]](e, "binaryDescriptorSet")).toOption
+ direct.orElse {
+ Try(PbReflect.invoke0[Option[Array[Byte]]](e, "binaryDescriptorSet")).toOption.flatten
+ }
+ }
+ }
+
+ private def buildMessageDescriptor(
+ messageName: String,
+ descriptorSource: ProtobufDescriptorSource): AnyRef = {
+ val cls = ShimReflectionUtils.loadClass(sparkProtobufUtilsObjectClassName)
+ val module = cls.getField("MODULE$").get(null)
+ val buildMethod = cls.getMethod("buildDescriptor", classOf[String], classOf[scala.Option[_]])
+
+ invokeBuildDescriptor(
+ buildMethod,
+ module,
+ messageName,
+ descriptorSource,
+ filePath => Files.readAllBytes(Paths.get(filePath)))
+ }
+
+ private[shims] def invokeBuildDescriptor(
+ buildMethod: Method,
+ module: AnyRef,
+ messageName: String,
+ descriptorSource: ProtobufDescriptorSource,
+ readDescriptorFile: String => Array[Byte]): AnyRef = {
+ descriptorSource match {
+ case ProtobufDescriptorSource.DescriptorBytes(bytes) =>
+ buildMethod.invoke(module, messageName, Some(bytes)).asInstanceOf[AnyRef]
+ case ProtobufDescriptorSource.DescriptorPath(filePath) =>
+ try {
+ buildMethod.invoke(module, messageName, Some(filePath)).asInstanceOf[AnyRef]
+ } catch {
+ case ex: java.lang.reflect.InvocationTargetException =>
+ val cause = ex.getCause
+ // Spark 3.5+ changed the descriptor payload from Option[String] to
+ // Option[Array[Byte]] while keeping the same erased JVM signature.
+ // Retry with file contents when the path-based invocation clearly hit that
+ // binary-descriptor variant.
+ if (cause != null && (cause.isInstanceOf[ClassCastException] ||
+ cause.isInstanceOf[MatchError])) {
+ Try {
+ buildMethod.invoke(
+ module, messageName, Some(readDescriptorFile(filePath))).asInstanceOf[AnyRef]
+ }.recoverWith { case retryEx =>
+ val wrapped = buildDescriptorRetryFailure(cause, retryEx)
+ wrapped.addSuppressed(ex)
+ scala.util.Failure(wrapped)
+ }.get
+ } else {
+ throw ex
+ }
+ }
+ }
+ }
+
+ private def buildDescriptorRetryFailure(
+ originalCause: Throwable,
+ retryFailure: Throwable): RuntimeException = {
+ val retryCause = unwrapInvocationFailure(retryFailure)
+ new RuntimeException(
+ s"Spark 3.5+ descriptor bytes retry failed after initial path invocation error " +
+ s"(${describeThrowable(originalCause)}); retry error (${describeThrowable(retryCause)})",
+ retryCause)
+ }
+
+ private def unwrapInvocationFailure(t: Throwable): Throwable = t match {
+ case ex: java.lang.reflect.InvocationTargetException if ex.getCause != null => ex.getCause
+ case other => other
+ }
+
+ private def describeThrowable(t: Throwable): String = {
+ val suffix = Option(t.getMessage).filter(_.nonEmpty).map(msg => s": $msg").getOrElse("")
+ s"${t.getClass.getSimpleName}$suffix"
+ }
+
+ private def typeName(t: AnyRef): String =
+ if (t == null) "" else Try(PbReflect.invoke0[String](t, "name")).getOrElse(t.toString)
+
+ private final class ReflectiveMessageDescriptor(raw: AnyRef) extends ProtobufMessageDescriptor {
+ override lazy val syntax: String = PbReflect.getFileSyntax(raw, typeName)
+
+ override def findField(name: String): Option[ProtobufFieldDescriptor] =
+ Option(PbReflect.findFieldByName(raw, name)).map(new ReflectiveFieldDescriptor(_))
+ }
+
  // Adapts a reflectively-loaded protobuf FieldDescriptor to ProtobufFieldDescriptor.
  // All members are lazy so reflection only runs for attributes a caller actually reads.
  private final class ReflectiveFieldDescriptor(raw: AnyRef) extends ProtobufFieldDescriptor {
    override lazy val name: String = PbReflect.invoke0[String](raw, "getName")
    override lazy val fieldNumber: Int = PbReflect.getFieldNumber(raw)
    // Protobuf type name (e.g. "INT32", "ENUM", "MESSAGE") read via the type object's name().
    override lazy val protoTypeName: String = typeName(PbReflect.getFieldType(raw))
    override lazy val isRepeated: Boolean = PbReflect.isRepeated(raw)
    override lazy val isRequired: Boolean = PbReflect.isRequired(raw)
    // Enum value list is only materialized for ENUM-typed fields.
    override lazy val enumMetadata: Option[ProtobufEnumMetadata] =
      if (protoTypeName == "ENUM") {
        Some(ProtobufEnumMetadata(PbReflect.getEnumValues(PbReflect.getEnumType(raw))))
      } else {
        None
      }
    // Left(reason) when reading/converting the default fails; Right(None) when the field
    // declares no default. Any reflective throwable is folded into a readable message.
    override lazy val defaultValueResult: Either[String, Option[ProtobufDefaultValue]] =
      Try {
        if (PbReflect.hasDefaultValue(raw)) {
          PbReflect.getDefaultValue(raw) match {
            case Some(default) =>
              toDefaultValue(default, protoTypeName, enumMetadata).map(Some(_))
            case None =>
              Right(None)
          }
        } else {
          Right(None)
        }
      }.toEither.left.map { t =>
        s"Failed to read protobuf default value for field '$name': ${t.getMessage}"
      }.flatMap(identity)
    // Nested message descriptor, present only for MESSAGE-typed fields.
    override lazy val messageDescriptor: Option[ProtobufMessageDescriptor] =
      if (protoTypeName == "MESSAGE") {
        Some(new ReflectiveMessageDescriptor(PbReflect.getMessageType(raw)))
      } else {
        None
      }
  }
+
+ private def toDefaultValue(
+ rawDefault: AnyRef,
+ protoTypeName: String,
+ enumMetadata: Option[ProtobufEnumMetadata]): Either[String, ProtobufDefaultValue] =
+ protoTypeName match {
+ case "BOOL" =>
+ Right(ProtobufDefaultValue.BoolValue(
+ rawDefault.asInstanceOf[java.lang.Boolean].booleanValue()))
+ case "FLOAT" =>
+ Right(ProtobufDefaultValue.FloatValue(
+ rawDefault.asInstanceOf[java.lang.Float].floatValue()))
+ case "DOUBLE" =>
+ Right(ProtobufDefaultValue.DoubleValue(
+ rawDefault.asInstanceOf[java.lang.Double].doubleValue()))
+ case "STRING" =>
+ Right(ProtobufDefaultValue.StringValue(
+ if (rawDefault == null) null else rawDefault.toString))
+ case "BYTES" =>
+ Right(ProtobufDefaultValue.BinaryValue(extractBytes(rawDefault)))
+ case "ENUM" =>
+ val number = extractNumber(rawDefault).intValue()
+ Right(enumMetadata.map(_.enumDefault(number))
+ .getOrElse(ProtobufDefaultValue.EnumValue(number, rawDefault.toString)))
+ case "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32" |
+ "INT64" | "UINT64" | "SINT64" | "FIXED64" | "SFIXED64" =>
+ Right(ProtobufDefaultValue.IntValue(extractNumber(rawDefault).longValue()))
+ case other =>
+ Left(
+ s"Unsupported protobuf default value type '$other' for value ${rawDefault.toString}")
+ }
+
+ private def extractNumber(rawDefault: AnyRef): java.lang.Number = rawDefault match {
+ case n: java.lang.Number => n
+ case ref: AnyRef =>
+ Try {
+ ref.getClass.getMethod("getNumber").invoke(ref).asInstanceOf[java.lang.Number]
+ }.getOrElse {
+ throw new IllegalStateException(
+ s"Unsupported protobuf numeric default value class: ${ref.getClass.getName}")
+ }
+ }
+
+ private def extractBytes(rawDefault: AnyRef): Array[Byte] = rawDefault match {
+ case bytes: Array[Byte] => bytes
+ case ref: AnyRef =>
+ Try {
+ ref.getClass.getMethod("toByteArray").invoke(ref).asInstanceOf[Array[Byte]]
+ }.getOrElse {
+ throw new IllegalStateException(
+ s"Unsupported protobuf bytes default value class: ${ref.getClass.getName}")
+ }
+ }
+
  // Thin reflective facade over protobuf-java descriptor objects. Methods are resolved
  // once per (class, method, arg-types) key and cached; lookup failures are reported with
  // the detected protobuf-java version to aid version-mismatch diagnosis.
  private object PbReflect {
    // Cache of resolved Methods, keyed by "className#method(paramTypes)".
    private val cache = new java.util.concurrent.ConcurrentHashMap[String, Method]()

    // Best-effort protobuf-java version string from com.google.protobuf.RuntimeVersion;
    // "unknown" when the class or its fields are absent (older protobuf-java releases).
    private def protobufJavaVersion: String = Try {
      val rtCls = Class.forName("com.google.protobuf.RuntimeVersion")
      val domain = rtCls.getField("DOMAIN").get(null)
      val major = rtCls.getField("MAJOR").get(null)
      val minor = rtCls.getField("MINOR").get(null)
      val patch = rtCls.getField("PATCH").get(null)
      s"$domain-$major.$minor.$patch"
    }.getOrElse("unknown")

    // Resolves (and memoizes) a public method; a missing method is surfaced as an
    // UnsupportedOperationException carrying the protobuf-java version in its message.
    private def cached(cls: Class[_], name: String, paramTypes: Class[_]*): Method = {
      val key = s"${cls.getName}#$name(${paramTypes.map(_.getName).mkString(",")})"
      cache.computeIfAbsent(key, _ => {
        try {
          cls.getMethod(name, paramTypes: _*)
        } catch {
          case ex: NoSuchMethodException =>
            throw new UnsupportedOperationException(
              s"protobuf-java method not found: ${cls.getSimpleName}.$name " +
              s"(protobuf-java version: $protobufJavaVersion). " +
              s"This may indicate an incompatible protobuf-java library version.",
              ex)
        }
      })
    }

    // Invokes a no-arg method on obj and casts the result (unchecked cast).
    def invoke0[T](obj: AnyRef, method: String): T =
      cached(obj.getClass, method).invoke(obj).asInstanceOf[T]

    // Invokes a one-arg method on obj and casts the result (unchecked cast).
    def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T =
      cached(obj.getClass, method, arg0Cls).invoke(obj, arg0).asInstanceOf[T]

    // Returns null when the message declares no field with the given name.
    def findFieldByName(msgDesc: AnyRef, name: String): AnyRef =
      invoke1[AnyRef](msgDesc, "findFieldByName", classOf[String], name)

    def getFieldNumber(fd: AnyRef): Int =
      invoke0[java.lang.Integer](fd, "getNumber").intValue()

    def getFieldType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getType")

    def isRepeated(fd: AnyRef): Boolean =
      invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue()

    def isRequired(fd: AnyRef): Boolean =
      invoke0[java.lang.Boolean](fd, "isRequired").booleanValue()

    def hasDefaultValue(fd: AnyRef): Boolean =
      invoke0[java.lang.Boolean](fd, "hasDefaultValue").booleanValue()

    def getDefaultValue(fd: AnyRef): Option[AnyRef] =
      Option(invoke0[AnyRef](fd, "getDefaultValue"))

    def getMessageType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getMessageType")

    def getEnumType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getEnumType")

    // Materializes an enum descriptor's values as (number, name) pairs.
    def getEnumValues(enumType: AnyRef): Seq[ProtobufEnumValue] = {
      import scala.collection.JavaConverters._
      val values = invoke0[java.util.List[_]](enumType, "getValues")
      values.asScala.map { v =>
        val ev = v.asInstanceOf[AnyRef]
        val num = invoke0[java.lang.Integer](ev, "getNumber").intValue()
        val enumName = invoke0[String](ev, "getName")
        ProtobufEnumValue(num, enumName)
      }.toSeq
    }

    // Reads the file-level syntax (via getFile().getSyntax()); "" when any step fails,
    // e.g. when the accessor is unavailable in the loaded protobuf-java version.
    def getFileSyntax(msgDesc: AnyRef, typeNameFn: AnyRef => String): String = Try {
      val fileDesc = invoke0[AnyRef](msgDesc, "getFile")
      val syntaxObj = invoke0[AnyRef](fileDesc, "getSyntax")
      typeNameFn(syntaxObj)
    }.getOrElse("")
  }
+}
diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala
new file mode 100644
index 00000000000..4bd85dbc458
--- /dev/null
+++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala
@@ -0,0 +1,633 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims
+
+import ai.rapids.cudf.DType
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.sql.catalyst.expressions.{
+ Expression,
+ GetArrayStructFields,
+ UnaryExpression
+}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.rapids.{
+ GpuFromProtobuf,
+ GpuGetArrayStructFieldsMeta,
+ GpuStructFieldOrdinalTag
+}
+import org.apache.spark.sql.rapids.protobuf._
+import org.apache.spark.sql.types._
+
+class ProtobufExprShimsSuite extends AnyFunSuite {
  // Two-column decoded row shape shared by the fake protobuf expressions below.
  private val outputSchema = StructType(Seq(
    StructField("id", IntegerType, nullable = true),
    StructField("name", StringType, nullable = true)))
+
  // Minimal leaf expression of BinaryType used as the child of every fake protobuf
  // expression; eval always yields null and codegen is intentionally unsupported.
  private case class FakeExprChild() extends Expression {
    override def children: Seq[Expression] = Nil
    override def nullable: Boolean = true
    override def dataType: DataType = BinaryType
    override def eval(input: org.apache.spark.sql.catalyst.InternalRow): Any = null
    override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
      throw new UnsupportedOperationException("not needed")
    override protected def withNewChildrenInternal(
        newChildren: IndexedSeq[Expression]): Expression = {
      assert(newChildren.isEmpty)
      this
    }
  }
+
  // Common shape for the fake from_protobuf-style expressions: unary, nullable, with the
  // shared outputSchema as result type.
  private abstract class FakeBaseProtobufExpr(childExpr: Expression) extends UnaryExpression {
    override def child: Expression = childExpr
    override def nullable: Boolean = true
    override def dataType: DataType = outputSchema
    override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
      throw new UnsupportedOperationException("not needed")
    // NOTE(review): returns `this` and drops newChild — fine for these fixed-child test
    // fixtures, but confirm no test relies on child replacement.
    override protected def withNewChildInternal(newChild: Expression): Expression = this
  }
+
  // Baseline fixture: legacy (Spark 3.4-style) expression exposing a descriptor file path.
  private case class FakePathProtobufExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def descFilePath: Option[String] = Some("/tmp/test.desc")
    def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST")
  }

  // Fixture exposing an inline binary descriptor set (Spark 3.5-style) with extra options.
  private case class FakeBytesProtobufExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def binaryDescriptorSet: Array[Byte] = Array[Byte](1, 2, 3)
    def options: scala.collection.Map[String, String] =
      Map("mode" -> "PERMISSIVE", "enums.as.ints" -> "true")
  }

  // Fixture with no `options` accessor at all, to exercise the reflection-failure path.
  private case class FakeMissingOptionsExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def descFilePath: Option[String] = Some("/tmp/test.desc")
  }

  // Differs from the baseline only in messageName.
  private case class FakeDifferentMessageExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.OtherMessage"
    def descFilePath: Option[String] = Some("/tmp/test.desc")
    def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST")
  }

  // Differs from the baseline only in the descriptor file path.
  private case class FakeDifferentDescriptorExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def descFilePath: Option[String] = Some("/tmp/other.desc")
    def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST")
  }

  // Differs from the baseline only in the options map.
  private case class FakeDifferentOptionsExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def descFilePath: Option[String] = Some("/tmp/test.desc")
    def options: scala.collection.Map[String, String] = Map("mode" -> "PERMISSIVE")
  }
+
+ private case class FakeTypedUnaryExpr(
+ dt: DataType,
+ override val child: Expression = FakeExprChild()) extends UnaryExpression {
+ override def nullable: Boolean = true
+ override def dataType: DataType = dt
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ throw new UnsupportedOperationException("not needed")
+ override protected def withNewChildInternal(newChild: Expression): Expression = copy(child =
+ newChild)
+ }
+
+ private object FakeSpark34ProtobufUtils {
+ def buildDescriptor(messageName: String, descFilePath: Option[String]): String =
+ s"$messageName:${descFilePath.getOrElse("none")}"
+ }
+
+ private object FakeSpark35ProtobufUtils {
+ def buildDescriptor(messageName: String, binaryFileDescriptorSet: Option[Array[Byte]]): String =
+ s"$messageName:${binaryFileDescriptorSet.map(_.mkString(",")).getOrElse("none")}"
+ }
+
+ private object FakeSpark35RetryFailureProtobufUtils {
+ def buildDescriptor(
+ messageName: String,
+ binaryFileDescriptorSet: Option[Array[Byte]]): String = {
+ val bytes = binaryFileDescriptorSet.getOrElse(Array.emptyByteArray)
+ if (bytes.sameElements(Array[Byte](1, 2, 3))) {
+ throw new IllegalArgumentException(s"Unknown message $messageName")
+ }
+ s"$messageName:${bytes.mkString(",")}"
+ }
+ }
+
  // Hand-built message descriptor: field lookup is a plain Map get, syntax is injected.
  private case class FakeMessageDescriptor(
      syntax: String,
      fields: Map[String, ProtobufFieldDescriptor]) extends ProtobufMessageDescriptor {
    override def findField(name: String): Option[ProtobufFieldDescriptor] = fields.get(name)
  }
+
+ private case class FakeFieldDescriptor(
+ name: String,
+ fieldNumber: Int,
+ protoTypeName: String,
+ isRepeated: Boolean = false,
+ isRequired: Boolean = false,
+ defaultValue: Option[ProtobufDefaultValue] = None,
+ defaultValueError: Option[String] = None,
+ enumMetadata: Option[ProtobufEnumMetadata] = None,
+ messageDescriptor: Option[ProtobufMessageDescriptor] = None) extends ProtobufFieldDescriptor {
+ override lazy val defaultValueResult: Either[String, Option[ProtobufDefaultValue]] =
+ defaultValueError match {
+ case Some(reason) => Left(reason)
+ case None => Right(defaultValue)
+ }
+ }
+
  // Reflective accessor path for legacy (descFilePath-based) expressions.
  test("compat extracts descriptor path and options from legacy expression") {
    val exprInfo = SparkProtobufCompat.extractExprInfo(FakePathProtobufExpr(FakeExprChild()))
    assert(exprInfo.isRight)
    val info = exprInfo.toOption.get
    assert(info.messageName == "test.Message")
    assert(info.options == Map("mode" -> "FAILFAST"))
    assert(info.descriptorSource ==
      ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"))
  }

  // Binary descriptor extraction plus planner-option parsing in one pass.
  test("compat extracts binary descriptor source and planner options") {
    val exprInfo = SparkProtobufCompat.extractExprInfo(FakeBytesProtobufExpr(FakeExprChild()))
    assert(exprInfo.isRight)
    val info = exprInfo.toOption.get
    info.descriptorSource match {
      case ProtobufDescriptorSource.DescriptorBytes(bytes) =>
        assert(bytes.sameElements(Array[Byte](1, 2, 3)))
      case other =>
        fail(s"Unexpected descriptor source: $other")
    }
    val plannerOptions = SparkProtobufCompat.parsePlannerOptions(info.options)
    assert(plannerOptions ==
      Right(ProtobufPlannerOptions(enumsAsInts = true, failOnErrors = false)))
  }

  // A Spark 3.4-shaped builder must receive the path directly; no bytes fallback allowed.
  test("compat invokes Spark 3.4 descriptor builder with descriptor path") {
    val buildMethod = FakeSpark34ProtobufUtils.getClass.getMethod(
      "buildDescriptor", classOf[String], classOf[scala.Option[_]])

    val result = SparkProtobufCompat.invokeBuildDescriptor(
      buildMethod,
      FakeSpark34ProtobufUtils,
      "test.Message",
      ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
      _ => fail("path-to-bytes fallback should not be needed for Spark 3.4"))

    assert(result == "test.Message:/tmp/test.desc")
  }

  // A Spark 3.5-shaped builder triggers exactly one path-to-bytes read, then succeeds.
  test("compat retries descriptor path as bytes for Spark 3.5 descriptor builder") {
    val buildMethod = FakeSpark35ProtobufUtils.getClass.getMethod(
      "buildDescriptor", classOf[String], classOf[scala.Option[_]])
    var readCalls = 0

    val result = SparkProtobufCompat.invokeBuildDescriptor(
      buildMethod,
      FakeSpark35ProtobufUtils,
      "test.Message",
      ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
      _ => {
        readCalls += 1
        Array[Byte](1, 2, 3)
      })

    assert(readCalls == 1)
    assert(result == "test.Message:1,2,3")
  }

  // When bytes are already in hand there must be no file read at all.
  test("compat passes bytes directly to Spark 3.5 descriptor builder") {
    val buildMethod = FakeSpark35ProtobufUtils.getClass.getMethod(
      "buildDescriptor", classOf[String], classOf[scala.Option[_]])

    val result = SparkProtobufCompat.invokeBuildDescriptor(
      buildMethod,
      FakeSpark35ProtobufUtils,
      "test.Message",
      ProtobufDescriptorSource.DescriptorBytes(Array[Byte](4, 5, 6)),
      _ => fail("binary descriptor source should not read a file"))

    assert(result == "test.Message:4,5,6")
  }

  // When the bytes retry also fails, the error must keep both failures: the retry error as
  // cause, the original invocation failure suppressed, and both described in the message.
  test("compat preserves retry context when descriptor bytes fallback also fails") {
    val buildMethod = FakeSpark35RetryFailureProtobufUtils.getClass.getMethod(
      "buildDescriptor", classOf[String], classOf[scala.Option[_]])

    val ex = intercept[RuntimeException] {
      SparkProtobufCompat.invokeBuildDescriptor(
        buildMethod,
        FakeSpark35RetryFailureProtobufUtils,
        "test.Message",
        ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
        _ => Array[Byte](1, 2, 3))
    }

    assert(ex.getMessage.contains("descriptor bytes retry failed"))
    assert(ex.getMessage.contains("ClassCastException"))
    assert(ex.getMessage.contains("Unknown message test.Message"))
    assert(ex.getCause.isInstanceOf[IllegalArgumentException])
    assert(ex.getSuppressed.exists(_.isInstanceOf[java.lang.reflect.InvocationTargetException]))
  }
+
  // Decode-semantics equality must track message name, descriptor source, and options.
  test("compat distinguishes decode semantics across message descriptor and options") {
    val child = FakeExprChild()

    assert(SparkProtobufCompat.sameDecodeSemantics(
      FakePathProtobufExpr(child), FakePathProtobufExpr(child)))
    assert(SparkProtobufCompat.sameDecodeSemantics(
      FakeBytesProtobufExpr(child), FakeBytesProtobufExpr(child)))
    assert(!SparkProtobufCompat.sameDecodeSemantics(
      FakePathProtobufExpr(child), FakeDifferentMessageExpr(child)))
    assert(!SparkProtobufCompat.sameDecodeSemantics(
      FakePathProtobufExpr(child), FakeDifferentDescriptorExpr(child)))
    assert(!SparkProtobufCompat.sameDecodeSemantics(
      FakePathProtobufExpr(child), FakeDifferentOptionsExpr(child)))
  }

  // A fixture without an options accessor yields a readable CPU-fallback reason.
  test("compat reports missing options accessor as cpu fallback reason") {
    val exprInfo = SparkProtobufCompat.extractExprInfo(FakeMissingOptionsExpr(FakeExprChild()))
    assert(exprInfo.left.toOption.exists(
      _.contains("Cannot read from_protobuf options via reflection")))
  }

  // Unknown option keys and non-PROTO2 syntaxes are rejected for GPU execution.
  test("compat detects unsupported options and proto3 syntax") {
    assert(SparkProtobufCompat.unsupportedOptions(Map("mode" -> "FAILFAST", "foo" -> "bar")) ==
      Seq("foo"))
    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("PROTO3"))
    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("EDITIONS"))
    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax(""))
    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("null"))
    assert(SparkProtobufCompat.isGpuSupportedProtoSyntax("PROTO2"))
  }

  // Drives the private toDefaultValue reflectively to pin the unsupported-type Left path.
  test("compat returns Left for unsupported default value types") {
    val method = SparkProtobufCompat.getClass.getDeclaredMethods
      .find(_.getName.endsWith("toDefaultValue"))
      .getOrElse(fail("toDefaultValue method not found"))
    method.setAccessible(true)

    val result = method.invoke(
      SparkProtobufCompat,
      "opaque-default",
      "MESSAGE",
      scala.None).asInstanceOf[Either[String, ProtobufDefaultValue]]

    assert(result.left.toOption.exists(_.contains("Unsupported protobuf default value type")))
  }
+
  // Enum defaults must survive extraction as typed EnumValue payloads.
  test("extractor preserves typed enum defaults") {
    val enumMeta = ProtobufEnumMetadata(Seq(
      ProtobufEnumValue(0, "UNKNOWN"),
      ProtobufEnumValue(1, "EN"),
      ProtobufEnumValue(2, "ZH")))
    val msgDesc = FakeMessageDescriptor(
      syntax = "PROTO2",
      fields = Map(
        "language" -> FakeFieldDescriptor(
          name = "language",
          fieldNumber = 1,
          protoTypeName = "ENUM",
          defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
          enumMetadata = Some(enumMeta))))
    val schema = StructType(Seq(StructField("language", StringType, nullable = true)))

    val infos = ProtobufSchemaExtractor.analyzeAllFields(
      schema, msgDesc, enumsAsInts = false, "test.Message")

    assert(infos.isRight)
    assert(infos.toOption.get("language").defaultValue.contains(
      ProtobufDefaultValue.EnumValue(1, "EN")))
  }

  // A default-value read failure marks only that field unsupported; siblings stay usable.
  test("extractor records reflection failures as unsupported field info") {
    val msgDesc = FakeMessageDescriptor(
      syntax = "PROTO2",
      fields = Map(
        "ok" -> FakeFieldDescriptor(
          name = "ok",
          fieldNumber = 1,
          protoTypeName = "INT32"),
        "id" -> FakeFieldDescriptor(
          name = "id",
          fieldNumber = 2,
          protoTypeName = "INT32",
          defaultValueError =
            Some("Failed to read protobuf default value for field 'id': unsupported type"))))
    val schema = StructType(Seq(
      StructField("ok", IntegerType, nullable = true),
      StructField("id", IntegerType, nullable = true)))

    val infos = ProtobufSchemaExtractor.analyzeAllFields(
      schema, msgDesc, enumsAsInts = true, "test.Message")

    assert(infos.isRight)
    assert(infos.toOption.get("ok").isSupported)
    assert(!infos.toOption.get("id").isSupported)
    assert(infos.toOption.get("id").unsupportedReason.exists(
      _.contains("Failed to read protobuf default value for field 'id'")))
  }

  // A Spark/protobuf type mismatch takes precedence over a default-value read failure.
  test("extractor preserves type mismatch reason over default reflection failure") {
    val fieldInfo = ProtobufSchemaExtractor.extractFieldInfo(
      StructField("id", StringType, nullable = true),
      FakeFieldDescriptor(
        name = "id",
        fieldNumber = 1,
        protoTypeName = "INT32",
        defaultValueError =
          Some("Failed to read protobuf default value for field 'id': unsupported type")),
      enumsAsInts = true)

    assert(fieldInfo.isRight)
    assert(!fieldInfo.toOption.get.isSupported)
    assert(fieldInfo.toOption.get.unsupportedReason.contains(
      "type mismatch: Spark StringType vs Protobuf INT32"))
  }

  // FLOAT<->DOUBLE widening/narrowing mismatches must be rejected with actionable reasons.
  test("extractor gives explicit reason for unsupported FLOAT/DOUBLE widening mismatches") {
    val doubleFromFloat = ProtobufSchemaExtractor.extractFieldInfo(
      StructField("score", DoubleType, nullable = true),
      FakeFieldDescriptor(
        name = "score",
        fieldNumber = 1,
        protoTypeName = "FLOAT"),
      enumsAsInts = true)
    val floatFromDouble = ProtobufSchemaExtractor.extractFieldInfo(
      StructField("score", FloatType, nullable = true),
      FakeFieldDescriptor(
        name = "score",
        fieldNumber = 1,
        protoTypeName = "DOUBLE"),
      enumsAsInts = true)

    assert(doubleFromFloat.isRight)
    assert(!doubleFromFloat.toOption.get.isSupported)
    assert(doubleFromFloat.toOption.get.unsupportedReason.contains(
      "Spark DoubleType mapped to Protobuf FLOAT is not yet supported on GPU; " +
      "use FloatType or fall back to CPU"))
    assert(floatFromDouble.isRight)
    assert(!floatFromDouble.toOption.get.isSupported)
    assert(floatFromDouble.toOption.get.unsupportedReason.contains(
      "Spark FloatType mapped to Protobuf DOUBLE is not yet supported on GPU; " +
      "use DoubleType or fall back to CPU"))
  }
+
  // Enum-string defaults must be flattened into both the numeric and UTF-8 string payloads.
  test("validator encodes enum-string defaults into both numeric and string payloads") {
    val enumMeta = ProtobufEnumMetadata(Seq(
      ProtobufEnumValue(0, "UNKNOWN"),
      ProtobufEnumValue(1, "EN")))
    val info = ProtobufFieldInfo(
      fieldNumber = 2,
      protoTypeName = "ENUM",
      sparkType = StringType,
      encoding = GpuFromProtobuf.ENC_ENUM_STRING,
      isSupported = true,
      unsupportedReason = None,
      isRequired = false,
      defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
      enumMetadata = Some(enumMeta),
      isRepeated = false)

    val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
      path = "common.language",
      field = StructField("language", StringType, nullable = true),
      fieldInfo = info,
      parentIdx = 0,
      depth = 1,
      outputTypeId = 6)

    assert(flat.isRight)
    assert(flat.toOption.get.defaultInt == 1L)
    assert(new String(flat.toOption.get.defaultString, "UTF-8") == "EN")
    assert(flat.toOption.get.enumValidValues.sameElements(Array(0, 1)))
    assert(flat.toOption.get.enumNames
      .map(new String(_, "UTF-8"))
      .sameElements(Array("UNKNOWN", "EN")))
  }

  // ENUM-as-string without enum metadata cannot be flattened and must be rejected.
  test("validator rejects enum-string field without enum metadata") {
    val info = ProtobufFieldInfo(
      fieldNumber = 2,
      protoTypeName = "ENUM",
      sparkType = StringType,
      encoding = GpuFromProtobuf.ENC_ENUM_STRING,
      isSupported = true,
      unsupportedReason = None,
      isRequired = false,
      defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
      enumMetadata = None,
      isRepeated = false)

    val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
      path = "common.language",
      field = StructField("language", StringType, nullable = true),
      fieldInfo = info,
      parentIdx = 0,
      depth = 1,
      outputTypeId = 6)

    assert(flat.left.toOption.exists(_.contains("missing enum metadata")))
  }

  // A FLOAT default on a DoubleType field must surface as Left, not as a thrown error.
  test("validator returns Left for incompatible default type instead of throwing") {
    val info = ProtobufFieldInfo(
      fieldNumber = 3,
      protoTypeName = "FLOAT",
      sparkType = DoubleType,
      encoding = GpuFromProtobuf.ENC_DEFAULT,
      isSupported = true,
      unsupportedReason = None,
      isRequired = false,
      defaultValue = Some(ProtobufDefaultValue.FloatValue(1.5f)),
      enumMetadata = None,
      isRepeated = false)

    val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
      path = "common.score",
      field = StructField("score", DoubleType, nullable = true),
      fieldInfo = info,
      parentIdx = 0,
      depth = 1,
      outputTypeId = 6)

    assert(flat.left.toOption.exists(
      _.contains("Incompatible default value for protobuf field 'common.score'")))
  }

  // A child whose parent slot is INT32 (not STRUCT) is an inconsistent flattened schema.
  test("validator rejects flattened schema with non-STRUCT parent") {
    val flatFields = Seq(
      FlattenedFieldDescriptor(
        fieldNumber = 1,
        parentIdx = -1,
        depth = 0,
        wireType = 0,
        outputTypeId = DType.INT32.getTypeId.getNativeId,
        encoding = GpuFromProtobuf.ENC_DEFAULT,
        isRepeated = false,
        isRequired = false,
        hasDefaultValue = false,
        defaultInt = 0L,
        defaultFloat = 0.0,
        defaultBool = false,
        defaultString = Array.emptyByteArray,
        enumValidValues = null,
        enumNames = null),
      FlattenedFieldDescriptor(
        fieldNumber = 2,
        parentIdx = 0,
        depth = 1,
        wireType = 0,
        outputTypeId = DType.INT32.getTypeId.getNativeId,
        encoding = GpuFromProtobuf.ENC_DEFAULT,
        isRepeated = false,
        isRequired = false,
        hasDefaultValue = false,
        defaultInt = 0L,
        defaultFloat = 0.0,
        defaultBool = false,
        defaultString = Array.emptyByteArray,
        enumValidValues = null,
        enumNames = null))

    val validation = ProtobufSchemaValidator.validateFlattenedSchema(flatFields)
    assert(validation.left.toOption.exists(_.contains("non-STRUCT parent")))
  }
+
  // After ordinal remapping via PRUNED_ORDINAL_TAG, numFields must reflect the pruned
  // child struct (1 field), not the original 3-field struct.
  test("array struct field meta uses pruned child field count after ordinal remap") {
    val originalStruct = StructType(Seq(
      StructField("a", IntegerType, nullable = true),
      StructField("b", IntegerType, nullable = true),
      StructField("c", IntegerType, nullable = true)))
    val prunedStruct = StructType(Seq(StructField("b", IntegerType, nullable = true)))
    val originalChild = FakeTypedUnaryExpr(ArrayType(originalStruct, containsNull = true))
    val sparkExpr = GetArrayStructFields(
      child = originalChild,
      field = originalStruct.fields(1),
      ordinal = 1,
      numFields = originalStruct.fields.length,
      containsNull = true)
    sparkExpr.setTagValue(GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG, 0)

    val prunedChild = FakeTypedUnaryExpr(ArrayType(prunedStruct, containsNull = true))
    val runtimeOrd = sparkExpr.getTagValue(GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).get

    assert(runtimeOrd == 0)
    assert(
      GpuGetArrayStructFieldsMeta.effectiveNumFields(prunedChild, sparkExpr, runtimeOrd) == 1)
  }

  // Two expressions built from equal-content (but distinct) arrays must be semantically
  // equal with matching semantic hashes, i.e. array fields compare by content.
  test("GpuFromProtobuf semantic equality is content-based for schema arrays") {
    def emptyEnumNames: Array[Array[Byte]] = Array.empty[Array[Byte]]

    val expr1 = GpuFromProtobuf(
      decodedSchema = outputSchema,
      fieldNumbers = Array(1, 2),
      parentIndices = Array(-1, -1),
      depthLevels = Array(0, 0),
      wireTypes = Array(0, 2),
      outputTypeIds = Array(3, 6),
      encodings = Array(0, 0),
      isRepeated = Array(false, false),
      isRequired = Array(false, false),
      hasDefaultValue = Array(false, false),
      defaultInts = Array(0L, 0L),
      defaultFloats = Array(0.0, 0.0),
      defaultBools = Array(false, false),
      defaultStrings = Array(Array.emptyByteArray, Array.emptyByteArray),
      enumValidValues = Array(Array.emptyIntArray, Array.emptyIntArray),
      enumNames = Array(emptyEnumNames, emptyEnumNames),
      failOnErrors = true,
      child = FakeExprChild())

    val expr2 = GpuFromProtobuf(
      decodedSchema = outputSchema,
      fieldNumbers = Array(1, 2),
      parentIndices = Array(-1, -1),
      depthLevels = Array(0, 0),
      wireTypes = Array(0, 2),
      outputTypeIds = Array(3, 6),
      encodings = Array(0, 0),
      isRepeated = Array(false, false),
      isRequired = Array(false, false),
      hasDefaultValue = Array(false, false),
      defaultInts = Array(0L, 0L),
      defaultFloats = Array(0.0, 0.0),
      defaultBools = Array(false, false),
      defaultStrings = Array(Array.emptyByteArray, Array.emptyByteArray),
      enumValidValues = Array(Array.emptyIntArray, Array.emptyIntArray),
      // .map(identity) forces fresh array instances so reference equality cannot pass.
      enumNames = Array(emptyEnumNames.map(identity), emptyEnumNames.map(identity)),
      failOnErrors = true,
      child = FakeExprChild())

    assert(expr1.semanticEquals(expr2))
    assert(expr1.semanticHash() == expr2.semanticHash())
  }
+
  // BinaryValue wraps Array[Byte]; equality and hashCode must compare contents, not refs.
  test("protobuf binary defaults use content-based equality") {
    val left = ProtobufDefaultValue.BinaryValue(Array[Byte](1, 2, 3))
    val right = ProtobufDefaultValue.BinaryValue(Array[Byte](1, 2, 3))

    assert(left == right)
    assert(left.hashCode() == right.hashCode())
  }

  // Same content-equality contract for FlattenedFieldDescriptor's array-typed fields
  // (defaultString, enumValidValues, nested enumNames).
  test("flattened field descriptor uses content-based equality for array fields") {
    val left = FlattenedFieldDescriptor(
      fieldNumber = 1,
      parentIdx = -1,
      depth = 0,
      wireType = 2,
      outputTypeId = 6,
      encoding = 0,
      isRepeated = false,
      isRequired = false,
      hasDefaultValue = true,
      defaultInt = 0L,
      defaultFloat = 0.0,
      defaultBool = false,
      defaultString = Array[Byte](1, 2),
      enumValidValues = Array(0, 1),
      enumNames = Array("A".getBytes("UTF-8"), "B".getBytes("UTF-8")))
    val right = FlattenedFieldDescriptor(
      fieldNumber = 1,
      parentIdx = -1,
      depth = 0,
      wireType = 2,
      outputTypeId = 6,
      encoding = 0,
      isRepeated = false,
      isRequired = false,
      hasDefaultValue = true,
      defaultInt = 0L,
      defaultFloat = 0.0,
      defaultBool = false,
      defaultString = Array[Byte](1, 2),
      enumValidValues = Array(0, 1),
      enumNames = Array("A".getBytes("UTF-8"), "B".getBytes("UTF-8")))

    assert(left == right)
    assert(left.hashCode() == right.hashCode())
  }
+}