diff --git a/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala b/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala index bbf3459a705..ea108f4fef6 100644 --- a/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala +++ b/delta-lake/common/src/main/delta-33x-40x/scala/com/nvidia/spark/rapids/delta/common/DeltaProviderBase.scala @@ -304,11 +304,15 @@ object DVPredicatePushdown extends ShimPredicateHelper { def mergeIdenticalProjects(plan: SparkPlan): SparkPlan = { plan.transformUp { case p @ GpuProjectExec(projList1, - GpuProjectExec(projList2, child, enablePreSplit1), enablePreSplit2) => + GpuProjectExec(projList2, child, enablePreSplit1), + enablePreSplit2) => val projSet1 = projList1.map(_.exprId).toSet val projSet2 = projList2.map(_.exprId).toSet if (projSet1 == projSet2) { - GpuProjectExec(projList1, child, enablePreSplit1 && enablePreSplit2) + GpuProjectExec( + projList1, + child, + enablePreSplit1 && enablePreSplit2) } else { p } diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 888f4fa807c..196f518c6a7 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -8267,7 +8267,7 @@ are limited. -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
@@ -8290,7 +8290,7 @@ are limited. -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, UDT, DAYTIME, YEARMONTH
diff --git a/integration_tests/run_pyspark_from_build.sh b/integration_tests/run_pyspark_from_build.sh index 0f2e2471b7b..4ede802a22b 100755 --- a/integration_tests/run_pyspark_from_build.sh +++ b/integration_tests/run_pyspark_from_build.sh @@ -46,6 +46,9 @@ # To run all tests, including Avro tests: # INCLUDE_SPARK_AVRO_JAR=true ./run_pyspark_from_build.sh # +# To run tests WITHOUT Protobuf tests (protobuf is included by default): +# INCLUDE_SPARK_PROTOBUF_JAR=false ./run_pyspark_from_build.sh +# # To run a specific test: # TEST=my_test ./run_pyspark_from_build.sh # @@ -141,9 +144,101 @@ else AVRO_JARS="" fi - # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar if they exist + # Protobuf support: Include spark-protobuf jar by default for protobuf_test.py + # Set INCLUDE_SPARK_PROTOBUF_JAR=false to disable + PROTOBUF_JARS="" + if [[ $( echo ${INCLUDE_SPARK_PROTOBUF_JAR} | tr '[:upper:]' '[:lower:]' ) != "false" ]]; + then + export INCLUDE_SPARK_PROTOBUF_JAR=true + mkdir -p "${TARGET_DIR}/dependency" + + # Download spark-protobuf jar if not already in target/dependency + PROTOBUF_JAR_NAME="spark-protobuf_${SCALA_VERSION}-${VERSION_STRING}.jar" + PROTOBUF_JAR_PATH="${TARGET_DIR}/dependency/${PROTOBUF_JAR_NAME}" + + if [[ ! -f "$PROTOBUF_JAR_PATH" ]]; then + echo "Downloading spark-protobuf jar..." + PROTOBUF_MAVEN_URL="https://repo1.maven.org/maven2/org/apache/spark/spark-protobuf_${SCALA_VERSION}/${VERSION_STRING}/${PROTOBUF_JAR_NAME}" + if curl -fsL -o "$PROTOBUF_JAR_PATH" "$PROTOBUF_MAVEN_URL"; then + echo "Downloaded spark-protobuf jar to $PROTOBUF_JAR_PATH" + else + echo "WARNING: Failed to download spark-protobuf jar from $PROTOBUF_MAVEN_URL" + rm -f "$PROTOBUF_JAR_PATH" + fi + fi + + # Also download protobuf-java jar (required dependency). + # Detect version from the jar bundled with Spark, fall back to version mapping. 
+ PROTOBUF_JAVA_VERSION="" + BUNDLED_PB_JAR=$(ls "$SPARK_HOME"/jars/protobuf-java-[0-9]*.jar 2>/dev/null | sort -V | tail -1) + if [[ -n "$BUNDLED_PB_JAR" ]]; then + PROTOBUF_JAVA_VERSION=$(basename "$BUNDLED_PB_JAR" | sed 's/protobuf-java-\(.*\)\.jar/\1/') + echo "Detected protobuf-java version $PROTOBUF_JAVA_VERSION from SPARK_HOME" + fi + if [[ -z "$PROTOBUF_JAVA_VERSION" ]]; then + case "$VERSION_STRING" in + 3.4.*) PROTOBUF_JAVA_VERSION="3.25.1" ;; + 3.5.*) PROTOBUF_JAVA_VERSION="3.25.1" ;; + 4.0.*|4.1.*) PROTOBUF_JAVA_VERSION="4.29.3" ;; + *) PROTOBUF_JAVA_VERSION="3.25.1" ;; + esac + echo "Using protobuf-java version $PROTOBUF_JAVA_VERSION based on Spark $VERSION_STRING" + fi + PROTOBUF_JAVA_JAR_NAME="protobuf-java-${PROTOBUF_JAVA_VERSION}.jar" + PROTOBUF_JAVA_JAR_PATH="${TARGET_DIR}/dependency/${PROTOBUF_JAVA_JAR_NAME}" + + if [[ ! -f "$PROTOBUF_JAVA_JAR_PATH" ]]; then + echo "Downloading protobuf-java jar..." + PROTOBUF_JAVA_MAVEN_URL="https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/${PROTOBUF_JAVA_VERSION}/${PROTOBUF_JAVA_JAR_NAME}" + if curl -fsL -o "$PROTOBUF_JAVA_JAR_PATH" "$PROTOBUF_JAVA_MAVEN_URL"; then + echo "Downloaded protobuf-java jar to $PROTOBUF_JAVA_JAR_PATH" + else + echo "WARNING: Failed to download protobuf-java jar from $PROTOBUF_JAVA_MAVEN_URL" + rm -f "$PROTOBUF_JAVA_JAR_PATH" + fi + fi + + SPARK_PROTOBUF_JAR_AVAILABLE=false + PROTOBUF_JAVA_AVAILABLE=false + + if [[ -f "$PROTOBUF_JAR_PATH" ]]; then + PROTOBUF_JARS="$PROTOBUF_JAR_PATH" + echo "Including spark-protobuf jar: $PROTOBUF_JAR_PATH" + SPARK_PROTOBUF_JAR_AVAILABLE=true + fi + if [[ -f "$PROTOBUF_JAVA_JAR_PATH" ]]; then + PROTOBUF_JARS="${PROTOBUF_JARS:+$PROTOBUF_JARS }$PROTOBUF_JAVA_JAR_PATH" + echo "Including protobuf-java jar: $PROTOBUF_JAVA_JAR_PATH" + PROTOBUF_JAVA_AVAILABLE=true + elif [[ -n "$BUNDLED_PB_JAR" ]]; then + echo "Using bundled protobuf-java jar from SPARK_HOME: $BUNDLED_PB_JAR" + PROTOBUF_JAVA_AVAILABLE=true + fi + + if [[ 
"$SPARK_PROTOBUF_JAR_AVAILABLE" == "true" && \ + "$PROTOBUF_JAVA_AVAILABLE" == "true" ]]; then + export PROTOBUF_JARS_AVAILABLE=true + else + echo "WARNING: Protobuf JAR dependencies incomplete; protobuf tests will be skipped" + echo " spark-protobuf available: $SPARK_PROTOBUF_JAR_AVAILABLE" + echo " protobuf-java available: $PROTOBUF_JAVA_AVAILABLE" + export PROTOBUF_JARS_AVAILABLE=false + fi + # Also add protobuf jars to driver classpath for Class.forName() to work + # This is needed because --jars only adds to executor classpath + if [[ -n "$PROTOBUF_JARS" ]]; then + PROTOBUF_DRIVER_CP=$(echo "$PROTOBUF_JARS" | tr ' ' ':') + export PYSP_TEST_spark_driver_extraClassPath="${PYSP_TEST_spark_driver_extraClassPath:+${PYSP_TEST_spark_driver_extraClassPath}:}${PROTOBUF_DRIVER_CP}" + echo "Added protobuf jars to driver classpath" + fi + else + export INCLUDE_SPARK_PROTOBUF_JAR=false + export PROTOBUF_JARS_AVAILABLE=false + fi + + # ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar protobuf.jar if they exist # Remove non-existing paths and canonicalize the paths including get rid of links and `..` - ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS || true) + ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS $PROTOBUF_JARS || true) # `:` separated jars ALL_JARS="${ALL_JARS//$'\n'/:}" @@ -411,6 +506,7 @@ else export PYSP_TEST_spark_gluten_loadLibFromJar=true fi + SPARK_SHELL_SMOKE_TEST="${SPARK_SHELL_SMOKE_TEST:-0}" EXPLAIN_ONLY_CPU_SMOKE_TEST="${EXPLAIN_ONLY_CPU_SMOKE_TEST:-0}" SPARK_CONNECT_SMOKE_TEST="${SPARK_CONNECT_SMOKE_TEST:-0}" diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py index fa7decac82d..da8f780a1b6 100644 --- a/integration_tests/src/main/python/data_gen.py +++ b/integration_tests/src/main/python/data_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# Copyright (c) 2020-2026, NVIDIA CORPORATION. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ # limitations under the License. import copy +from dataclasses import dataclass, replace from datetime import date, datetime, timedelta, timezone from decimal import * from enum import Enum @@ -857,6 +858,659 @@ def gen_bytes(): return bytes([ rand.randint(0, 255) for _ in range(length) ]) self._start(rand, gen_bytes) + +# ----------------------------------------------------------------------------- +# Protobuf schema-first test modeling / generation / encoding +# ----------------------------------------------------------------------------- + +_PROTOBUF_WIRE_VARINT = 0 +_PROTOBUF_WIRE_64BIT = 1 +_PROTOBUF_WIRE_LEN_DELIM = 2 +_PROTOBUF_WIRE_32BIT = 5 +_PB_MISSING = object() + + +def _encode_protobuf_uvarint(value): + """Encode a non-negative integer as protobuf varint.""" + if value is None: + raise ValueError("value must not be None") + if value < 0: + raise ValueError("uvarint only supports non-negative integers") + out = bytearray() + v = int(value) + while True: + b = v & 0x7F + v >>= 7 + if v: + out.append(b | 0x80) + else: + out.append(b) + break + return bytes(out) + + +def _encode_protobuf_key(field_number, wire_type): + return _encode_protobuf_uvarint((int(field_number) << 3) | int(wire_type)) + + +def _encode_protobuf_zigzag32(value): + return (int(value) << 1) ^ (int(value) >> 31) + + +def _encode_protobuf_zigzag64(value): + return (int(value) << 1) ^ (int(value) >> 63) + + +class PbCardinality(Enum): + OPTIONAL = 'optional' + REQUIRED = 'required' + REPEATED = 'repeated' + + +class PbScalarKind(Enum): + BOOL = 'bool' + INT32 = 'int32' + INT64 = 'int64' + UINT32 = 'uint32' + UINT64 = 'uint64' + SINT32 = 'sint32' + SINT64 = 'sint64' + FIXED32 = 'fixed32' + FIXED64 = 'fixed64' + SFIXED32 = 'sfixed32' + SFIXED64 = 'sfixed64' + FLOAT = 'float' + DOUBLE = 'double' + STRING = 'string' + BYTES = 'bytes' + ENUM = 
'enum' + + +def _pb_scalar_kind_spark_type(kind): + if kind in {PbScalarKind.BOOL}: + return BooleanType() + if kind in {PbScalarKind.INT32, PbScalarKind.UINT32, PbScalarKind.SINT32, + PbScalarKind.FIXED32, PbScalarKind.SFIXED32, PbScalarKind.ENUM}: + return IntegerType() + if kind in {PbScalarKind.INT64, PbScalarKind.UINT64, PbScalarKind.SINT64, + PbScalarKind.FIXED64, PbScalarKind.SFIXED64}: + return LongType() + if kind == PbScalarKind.FLOAT: + return FloatType() + if kind == PbScalarKind.DOUBLE: + return DoubleType() + if kind == PbScalarKind.STRING: + return StringType() + if kind == PbScalarKind.BYTES: + return BinaryType() + raise ValueError(f'Unsupported protobuf scalar kind: {kind}') + + +@dataclass(frozen=True) +class PbEnumSpec: + name: str + values: tuple + + def __post_init__(self): + values = tuple((str(name), int(number)) for name, number in self.values) + if not values: + raise ValueError('enum spec must contain at least one value') + names = [name for name, _ in values] + numbers = [number for _, number in values] + if len(names) != len(set(names)): + raise ValueError(f'duplicate enum names in {self.name}') + if len(numbers) != len(set(numbers)): + raise ValueError(f'duplicate enum numbers in {self.name}') + object.__setattr__(self, 'values', values) + + def number_for(self, value): + if isinstance(value, str): + for name, number in self.values: + if name == value: + return number + raise ValueError(f'Unknown enum name {value!r} for enum {self.name}') + return int(value) + + +@dataclass(frozen=True) +class PbScalarFieldSpec: + name: str + number: int + kind: PbScalarKind + gen: object = None + cardinality: PbCardinality = PbCardinality.OPTIONAL + default: object = None + packed: bool = False + min_len: int = 0 + max_len: int = 5 + enum: object = None + + def __post_init__(self): + object.__setattr__(self, 'name', str(self.name)) + object.__setattr__(self, 'number', int(self.number)) + if self.number <= 0: + raise ValueError('field numbers must be 
positive') + if self.cardinality != PbCardinality.REPEATED and self.packed: + raise ValueError(f'packed encoding requires repeated cardinality: {self.name}') + if self.cardinality == PbCardinality.REPEATED and self.default is not None: + raise ValueError(f'repeated fields cannot have defaults: {self.name}') + if self.min_len < 0 or self.max_len < self.min_len: + raise ValueError(f'invalid repeated length bounds for {self.name}') + if self.kind == PbScalarKind.ENUM: + if self.enum is None: + raise ValueError(f'enum field requires enum spec: {self.name}') + elif self.enum is not None: + raise ValueError(f'non-enum field cannot carry enum spec: {self.name}') + if self.packed and self.kind not in { + PbScalarKind.BOOL, PbScalarKind.INT32, PbScalarKind.INT64, + PbScalarKind.UINT32, PbScalarKind.UINT64, PbScalarKind.SINT32, + PbScalarKind.SINT64, PbScalarKind.FIXED32, PbScalarKind.FIXED64, + PbScalarKind.SFIXED32, PbScalarKind.SFIXED64, PbScalarKind.FLOAT, + PbScalarKind.DOUBLE, PbScalarKind.ENUM}: + raise ValueError(f'packed encoding is not supported for {self.kind.value}: {self.name}') + + +@dataclass(frozen=True) +class PbMessageFieldSpec: + name: str + number: int + fields: tuple + cardinality: PbCardinality = PbCardinality.OPTIONAL + min_len: int = 0 + max_len: int = 5 + + def __post_init__(self): + object.__setattr__(self, 'name', str(self.name)) + object.__setattr__(self, 'number', int(self.number)) + object.__setattr__(self, 'fields', tuple(self.fields)) + if self.number <= 0: + raise ValueError('field numbers must be positive') + if self.min_len < 0 or self.max_len < self.min_len: + raise ValueError(f'invalid repeated length bounds for {self.name}') + if self.cardinality == PbCardinality.REQUIRED and self.min_len != 0: + raise ValueError('required message field cannot define repeated bounds') + + +@dataclass(frozen=True) +class PbMessageSpec: + name: str + fields: tuple + + def __post_init__(self): + object.__setattr__(self, 'name', str(self.name)) + 
object.__setattr__(self, 'fields', tuple(self.fields)) + _validate_pb_fields(self.fields, self.name) + + def as_datagen(self, binary_col_name='bin'): + return ProtobufRowGen(self, binary_col_name=binary_col_name) + + def encode(self, value): + return encode_pb_message(self, value) + + +def _validate_pb_fields(fields, owner_name): + names = [field.name for field in fields] + numbers = [field.number for field in fields] + if len(names) != len(set(names)): + raise ValueError(f'duplicate field names in {owner_name}') + if len(numbers) != len(set(numbers)): + raise ValueError(f'duplicate field numbers in {owner_name}') + + +class _PbBuilder: + def message(self, name, fields): + return PbMessageSpec(name, tuple(fields)) + + def bool(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.BOOL, gen=gen, default=default) + + def int32(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.INT32, gen=gen, default=default) + + def int64(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.INT64, gen=gen, default=default) + + def uint32(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.UINT32, gen=gen, default=default) + + def uint64(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.UINT64, gen=gen, default=default) + + def sint32(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.SINT32, gen=gen, default=default) + + def sint64(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.SINT64, gen=gen, default=default) + + def fixed32(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.FIXED32, gen=gen, default=default) + + def fixed64(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, 
number, PbScalarKind.FIXED64, gen=gen, default=default) + + def sfixed32(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.SFIXED32, gen=gen, default=default) + + def sfixed64(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.SFIXED64, gen=gen, default=default) + + def float(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.FLOAT, gen=gen, default=default) + + def double(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.DOUBLE, gen=gen, default=default) + + def string(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.STRING, gen=gen, default=default) + + def bytes(self, name, number, gen=None, default=None): + return PbScalarFieldSpec(name, number, PbScalarKind.BYTES, gen=gen, default=default) + + def enum_type(self, name, values): + return PbEnumSpec(name, tuple(values)) + + def enum_field(self, name, number, enum_spec, gen=None, default=None): + return PbScalarFieldSpec( + name, number, PbScalarKind.ENUM, gen=gen, default=default, enum=enum_spec) + + def message_field(self, name, number, fields): + return PbMessageFieldSpec(name, number, tuple(fields)) + + def nested(self, name, number, fields): + return self.message_field(name, number, fields) + + def repeated(self, field_spec, min_len=0, max_len=5, packed=False): + if isinstance(field_spec, PbScalarFieldSpec): + return replace( + field_spec, + cardinality=PbCardinality.REPEATED, + min_len=min_len, + max_len=max_len, + packed=packed) + if isinstance(field_spec, PbMessageFieldSpec): + if packed: + raise ValueError('message fields cannot be packed') + return replace( + field_spec, + cardinality=PbCardinality.REPEATED, + min_len=min_len, + max_len=max_len) + raise TypeError(f'Unsupported protobuf field for repeated(): {type(field_spec)}') + + def repeated_message(self, name, 
number, fields, min_len=0, max_len=5): + return self.repeated(self.message_field(name, number, fields), min_len=min_len, max_len=max_len) + + def required(self, field_spec): + if isinstance(field_spec, PbScalarFieldSpec): + if field_spec.default is not None: + raise ValueError('required fields cannot have defaults') + return replace(field_spec, cardinality=PbCardinality.REQUIRED) + if isinstance(field_spec, PbMessageFieldSpec): + return replace(field_spec, cardinality=PbCardinality.REQUIRED) + raise TypeError(f'Unsupported protobuf field for required(): {type(field_spec)}') + + +pb = _PbBuilder() + + +def _pb_gen_cache_repr(gen): + return 'None' if gen is None else gen._cache_repr() + + +def _pb_enum_cache_repr(enum_spec): + return 'None' if enum_spec is None else str(enum_spec.values) + + +def _pb_field_cache_repr(field_spec): + if isinstance(field_spec, PbScalarFieldSpec): + return ('Scalar(' + field_spec.name + ',' + str(field_spec.number) + ',' + + field_spec.kind.value + ',' + field_spec.cardinality.value + ',' + + str(field_spec.default) + ',' + str(field_spec.packed) + ',' + + str(field_spec.min_len) + ',' + str(field_spec.max_len) + ',' + + _pb_enum_cache_repr(field_spec.enum) + ',' + + _pb_gen_cache_repr(field_spec.gen) + ')') + children = ','.join(_pb_field_cache_repr(child) for child in field_spec.fields) + return ('Message(' + field_spec.name + ',' + str(field_spec.number) + ',' + + field_spec.cardinality.value + ',' + str(field_spec.min_len) + ',' + + str(field_spec.max_len) + ',[' + children + '])') + + +def _pb_message_cache_repr(message_spec): + children = ','.join(_pb_field_cache_repr(field_spec) for field_spec in message_spec.fields) + return 'PbMessage(' + message_spec.name + ',[' + children + '])' + + +class ProtobufEncoder: + def encode_message(self, message_spec, value): + if value is None: + return b'' + if not isinstance(message_spec, PbMessageSpec): + raise TypeError(f'encode_message expects PbMessageSpec, got {type(message_spec)}') + 
return self._encode_message_fields(message_spec.fields, value) + + def encode_field(self, field_spec, value): + if isinstance(field_spec, PbScalarFieldSpec): + return self._encode_scalar_field(field_spec, value) + return self._encode_message_field(field_spec, value) + + def _encode_message_fields(self, fields, value): + if not isinstance(value, dict): + raise TypeError(f'protobuf message values must be dicts, got {type(value)}') + unknown = set(value.keys()) - {field.name for field in fields} + if unknown: + raise ValueError(f'unknown protobuf field(s): {sorted(unknown)}') + return b''.join( + self.encode_field(field, value.get(field.name, _PB_MISSING)) + for field in fields) + + def _normalize_scalar_input(self, field_spec, value): + if field_spec.kind == PbScalarKind.ENUM: + return field_spec.enum.number_for(value) + if field_spec.kind == PbScalarKind.BOOL: + return bool(value) + if field_spec.kind in { + PbScalarKind.INT32, PbScalarKind.INT64, PbScalarKind.UINT32, + PbScalarKind.UINT64, PbScalarKind.SINT32, PbScalarKind.SINT64, + PbScalarKind.FIXED32, PbScalarKind.FIXED64, PbScalarKind.SFIXED32, + PbScalarKind.SFIXED64}: + return int(value) + if field_spec.kind in {PbScalarKind.FLOAT, PbScalarKind.DOUBLE}: + return float(value) + if field_spec.kind == PbScalarKind.STRING: + return str(value) + if field_spec.kind == PbScalarKind.BYTES: + return value if isinstance(value, bytes) else bytes(value) + raise ValueError(f'Unsupported scalar kind: {field_spec.kind}') + + def _encode_scalar_payload(self, field_spec, value): + scalar_value = self._normalize_scalar_input(field_spec, value) + kind = field_spec.kind + if kind == PbScalarKind.BOOL: + return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(1 if scalar_value else 0) + if kind in {PbScalarKind.INT32, PbScalarKind.INT64, PbScalarKind.ENUM}: + u64 = int(scalar_value) & 0xFFFFFFFFFFFFFFFF + return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(u64) + if kind == PbScalarKind.UINT32: + scalar_value = 
int(scalar_value) + if scalar_value < 0: + raise ValueError(f'uint32 field cannot encode negative value: {field_spec.name}') + return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(scalar_value) + if kind == PbScalarKind.UINT64: + scalar_value = int(scalar_value) + if scalar_value < 0: + raise ValueError(f'uint64 field cannot encode negative value: {field_spec.name}') + return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(scalar_value) + if kind == PbScalarKind.SINT32: + return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(_encode_protobuf_zigzag32(scalar_value)) + if kind == PbScalarKind.SINT64: + return _PROTOBUF_WIRE_VARINT, _encode_protobuf_uvarint(_encode_protobuf_zigzag64(scalar_value)) + if kind in {PbScalarKind.FIXED32, PbScalarKind.SFIXED32}: + return _PROTOBUF_WIRE_32BIT, struct.pack(' bool: + """ + `spark-protobuf` is an optional external module. PySpark may have the Python wrappers + even when the JVM side isn't present on the classpath, which manifests as: + TypeError: 'JavaPackage' object is not callable + when calling into `sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf`. + + In the integration harness, Spark jars are often attached dynamically. Using the current + thread's context classloader is more reliable than the default `Class.forName()` lookup. + """ + jvm = spark.sparkContext._jvm + loader = None + try: + loader = jvm.Thread.currentThread().getContextClassLoader() + except Exception: + pass + candidates = [ + # Scala object `functions` compiles to `functions$` + "org.apache.spark.sql.protobuf.functions$", + # Some environments may expose it differently + "org.apache.spark.sql.protobuf.functions", + ] + for cls in candidates: + try: + if loader is not None: + jvm.java.lang.Class.forName(cls, True, loader) + else: + jvm.java.lang.Class.forName(cls) + return True + except Exception: + continue + + # Fallback: try to resolve the JVM member through Py4J. 
A missing optional module typically + # stays as a JavaPackage placeholder instead of a callable JavaMember/JavaClass. + try: + member = jvm.org.apache.spark.sql.protobuf.functions.from_protobuf + return type(member).__name__ != "JavaPackage" + except Exception: + return False + + +# --------------------------------------------------------------------------- +# Shared fixture and helpers to reduce per-test boilerplate +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +def from_protobuf_fn(): + """Skip the entire module if from_protobuf or the JVM module is unavailable.""" + fn = _try_import_from_protobuf() + if fn is None: + pytest.skip("from_protobuf not available") + if not with_cpu_session(_spark_protobuf_jvm_available): + pytest.skip("spark-protobuf JVM not available") + return fn + + +def _setup_protobuf_desc(spark_tmp_path, desc_name, build_fn): + """Build descriptor bytes via JVM, write to HDFS, return (desc_path, desc_bytes).""" + desc_path = spark_tmp_path + "/" + desc_name + desc_bytes = with_cpu_session(build_fn) + with_cpu_session( + lambda spark: _write_bytes_to_hadoop_path(spark, desc_path, desc_bytes)) + return desc_path, desc_bytes + + +def _call_from_protobuf(from_protobuf_fn, col, message_name, + desc_path, desc_bytes, options=None): + """Call from_protobuf using the right API variant.""" + sig = inspect.signature(from_protobuf_fn) + if "binaryDescriptorSet" in sig.parameters: + kw = dict(binaryDescriptorSet=bytearray(desc_bytes)) + if options is not None: + kw["options"] = options + return from_protobuf_fn(col, message_name, **kw) + if options is not None: + return from_protobuf_fn(col, message_name, desc_path, options) + return from_protobuf_fn(col, message_name, desc_path) + + +def test_call_from_protobuf_preserves_options_for_legacy_signature(): + calls = [] + + def fake_from_protobuf(col, message_name, desc_path, *args): + calls.append((col, message_name, desc_path, args)) + 
return "ok" + + options = {"enums.as.ints": "true"} + result = _call_from_protobuf( + fake_from_protobuf, "col", "msg", "/tmp/test.desc", b"desc", options=options) + + assert result == "ok" + assert calls == [("col", "msg", "/tmp/test.desc", (options,))] + + +def test_encode_protobuf_packed_repeated_fixed_uses_unsigned_twos_complement(): + i32_encoded = _encode_protobuf_packed_repeated( + 1, IntegerType(), [0xFFFFFFFF], encoding='fixed') + i64_encoded = _encode_protobuf_packed_repeated( + 1, LongType(), [0xFFFFFFFFFFFFFFFF], encoding='fixed') + + assert i32_encoded == b"\x0a\x04" + struct.pack(" 1, 1 -> 2, -2 -> 3, 2 -> 4, etc. + """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "signed.desc", _build_signed_int_descriptor_set_bytes) + message_name = "test.WithSignedInts" + + data_gen = _as_datagen([ + _scalar("si32", 1, IntegerGen( + special_cases=[-1, 0, 1, -2147483648, 2147483647]), encoding='zigzag'), + _scalar("si64", 2, LongGen( + special_cases=[-1, 0, 1, -9223372036854775808, 9223372036854775807]), + encoding='zigzag'), + _scalar("sf32", 3, IntegerGen( + special_cases=[0, 1, -1, 2147483647, -2147483648]), encoding='fixed'), + _scalar("sf64", 4, LongGen( + special_cases=[0, 1, -1]), encoding='fixed'), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + + return df.select( + decoded.getField("si32").alias("si32"), + decoded.getField("si64").alias("si64"), + decoded.getField("sf32").alias("sf32"), + decoded.getField("sf64").alias("sf64"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_fixed_int_descriptor_set_bytes(spark): + """Build a descriptor for fixed-width integer fields.""" + return _build_proto2_descriptor(spark, "fixed_int.proto", [ + _msg("WithFixedInts", [ + _field("fx32", 1, "FIXED32"), + _field("fx64", 2, "FIXED64"), + ]), + ]) + + +@pytest.mark.skipif(is_before_spark_340(), 
reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_fixed_integers(spark_tmp_path, from_protobuf_fn): + """ + Test decoding fixed-width unsigned integer types. + """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "fixed.desc", _build_fixed_int_descriptor_set_bytes) + message_name = "test.WithFixedInts" + + data_gen = _as_datagen([ + _scalar("fx32", 1, IntegerGen( + special_cases=[0, 1, -1, 2147483647, -2147483648]), encoding='fixed'), + _scalar("fx64", 2, LongGen( + special_cases=[0, 1, -1]), encoding='fixed'), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("fx32").alias("fx32"), + decoded.getField("fx64").alias("fx64"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_schema_projection_descriptor_set_bytes(spark): + """Build a descriptor with nested and repeated struct fields for pruning tests.""" + return _build_proto2_descriptor(spark, "schema_proj.proto", [ + _msg("Detail", [ + _field("a", 1, "INT32"), + _field("b", 2, "INT32"), + _field("c", 3, "STRING"), + ]), + _msg("SchemaProj", [ + _field("id", 1, "INT32"), + _field("name", 2, "STRING"), + _field("detail", 3, "MESSAGE", type_name=".test.Detail"), + _field("items", 4, "MESSAGE", label="repeated", type_name=".test.Detail"), + ]), + ]) + + +# Field descriptors for SchemaProj: {id, name, detail: {a, b, c}, items[]: {a, b, c}} +_detail_children = [ + _scalar("a", 1, IntegerGen()), + _scalar("b", 2, IntegerGen()), + _scalar("c", 3, StringGen()), +] +_schema_proj_schema = _schema("SchemaProjManual", [ + _scalar("id", 1, IntegerGen()), + _scalar("name", 2, StringGen()), + _nested("detail", 3, _detail_children), + _repeated_message("items", 4, _detail_children), +]) + +_schema_proj_test_data = [ + encode_pb_message(_schema_proj_schema, { + "id": 1, + "name": "alice", + 
"detail": {"a": 10, "b": 20, "c": "d1"}, + "items": [ + {"a": 100, "b": 200, "c": "i1"}, + {"a": 101, "b": 201, "c": "i2"}, + ], + }), + encode_pb_message(_schema_proj_schema, { + "id": 2, + "name": "bob", + "detail": {"a": 30, "b": 40, "c": "d2"}, + "items": [ + {"a": 300, "b": 400, "c": "i3"}, + ], + }), + encode_pb_message(_schema_proj_schema, { + "id": 3, + "name": "carol", + "detail": {"a": 50, "b": 60, "c": "d3"}, + "items": [], + }), +] + + +_schema_proj_cases = [ + ("nested_single_field", [("id", ("id",)), ("detail_a", ("detail", "a"))]), + ("nested_two_fields", [("detail_a", ("detail", "a")), ("detail_c", ("detail", "c"))]), + ("whole_struct_no_pruning", [("id", ("id",)), ("detail", ("detail",))]), + ("whole_and_subfield", [("detail", ("detail",)), ("detail_a", ("detail", "a"))]), + ("scalar_plus_nested", [("id", ("id",)), ("name", ("name",)), ("detail_a", ("detail", "a"))]), + ("repeated_msg_single_subfield", [("id", ("id",)), ("items_a", ("items", "a"))]), + ("repeated_msg_two_subfields", [("items_a", ("items", "a")), ("items_c", ("items", "c"))]), + ("repeated_whole_no_pruning", [("id", ("id",)), ("items", ("items",))]), + ("mix_struct_and_repeated", [("id", ("id",)), ("detail_a", ("detail", "a")), ("items_c", ("items", "c"))]), +] + + +def _get_field_by_path(expr, path): + current = expr + for name in path: + current = current.getField(name) + return current + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@pytest.mark.parametrize("boundary", ["alias", "withcolumn"], ids=idfn) +@ignore_order(local=True) +def test_from_protobuf_projection_across_plan_boundary( + spark_tmp_path, from_protobuf_fn, boundary): + """Schema projection across alias (select→select) and withColumn plan boundaries.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "schema_proj_boundary.desc", + _build_schema_projection_descriptor_set_bytes) + message_name = "test.SchemaProj" + + def run_on_spark(spark): + df = 
spark.createDataFrame([(row,) for row in _schema_proj_test_data], schema="bin binary") + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + if boundary == "alias": + aliased = df.select(decoded.alias("decoded")) + return aliased.select( + f.col("decoded").getField("detail").getField("a").alias("detail_a"), + f.col("decoded").getField("id").alias("id")) + else: + with_decoded = df.withColumn("decoded", decoded) + return with_decoded.select( + f.col("decoded").getField("items").getField("a").alias("items_a"), + f.col("decoded").getField("id").alias("id")) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_dual_message_projection_descriptor_set_bytes(spark): + """Build descriptors for two logical views over the same binary payload column.""" + return _build_proto2_descriptor(spark, "dual_projection.proto", [ + _msg("NestedPayload", [_field("count", 1, "INT32")]), + _msg("BytesView", [ + _field("status", 1, "INT32"), + _field("payload", 2, "BYTES"), + ]), + _msg("NestedView", [ + _field("status", 1, "INT32"), + _field("payload", 2, "MESSAGE", type_name=".test.NestedPayload"), + ]), + ]) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_different_messages_same_binary_column_do_not_interfere( + spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dual_projection.desc", + _build_dual_message_projection_descriptor_set_bytes) + + payload_keep = _encode_tag(1, 0) + _encode_varint(7) + payload_drop = _encode_tag(1, 0) + _encode_varint(9) + row_keep = (_encode_tag(1, 0) + _encode_varint(1) + + _encode_tag(2, 2) + _encode_varint(len(payload_keep)) + payload_keep) + row_drop = (_encode_tag(1, 0) + _encode_varint(0) + + _encode_tag(2, 2) + _encode_varint(len(payload_drop)) + payload_drop) + + def run_on_spark(spark): + df = spark.createDataFrame([(row_keep,), (row_drop,)], 
schema="bin binary") + bytes_view = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), "test.BytesView", desc_path, desc_bytes) + nested_view = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), "test.NestedView", desc_path, desc_bytes) + return df.filter(bytes_view.getField("status") == 1).select( + nested_view.getField("payload").getField("count").alias("payload_count")) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_deep_nested_5_level_descriptor_set_bytes(spark): + """Build a descriptor for a five-level nested message chain.""" + return _build_proto2_descriptor(spark, "deep_nested_5_level.proto", [ + _msg("Level5", [_field("val5", 1, "INT32")]), + _msg("Level4", [ + _field("val4", 1, "INT32"), + _field("level5", 2, "MESSAGE", type_name=".test.Level5"), + ]), + _msg("Level3", [ + _field("val3", 1, "INT32"), + _field("level4", 2, "MESSAGE", type_name=".test.Level4"), + ]), + _msg("Level2", [ + _field("val2", 1, "INT32"), + _field("level3", 2, "MESSAGE", type_name=".test.Level3"), + ]), + _msg("Level1", [ + _field("val1", 1, "INT32"), + _field("level2", 2, "MESSAGE", type_name=".test.Level2"), + ]), + ]) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_deep_nesting_5_levels(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "deep_nested_5_level.desc", + _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _as_datagen([ + _scalar("val1", 1, IntegerGen()), + _nested("level2", 2, [ + _scalar("val2", 1, IntegerGen()), + _nested("level3", 2, [ + _scalar("val3", 1, IntegerGen()), + _nested("level4", 2, [ + _scalar("val4", 1, IntegerGen()), + _nested("level5", 2, [ + _scalar("val5", 1, IntegerGen()), + ]), + ]), + ]), + ]), + ]) + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), 
message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("val1").alias("val1"), + decoded.getField("level2").alias("level2"), + ) + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@pytest.mark.parametrize("case_id,select_specs", _schema_proj_cases, ids=lambda c: c[0] if isinstance(c, tuple) else str(c)) +@ignore_order(local=True) +def test_from_protobuf_schema_projection_cases( + spark_tmp_path, from_protobuf_fn, case_id, select_specs): + """Parametrized nested-schema projection tests.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "schema_proj.desc", _build_schema_projection_descriptor_set_bytes) + message_name = "test.SchemaProj" + + def run_on_spark(spark): + df = spark.createDataFrame( + [(d,) for d in _schema_proj_test_data], schema="bin binary") + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + selected = [_get_field_by_path(decoded, path).alias(alias) + for alias, path in select_specs] + return df.select(*selected) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + +def _build_name_collision_descriptor_set_bytes(spark): + """Build a regression schema with same-named fields in unrelated nested messages.""" + return _build_proto2_descriptor(spark, "name_collision.proto", [ + _msg("User", [ + _field("age", 1, "INT32"), + _field("id", 2, "INT32"), + ]), + _msg("Ad", [_field("id", 1, "INT32")]), + _msg("Event", [ + _field("user_info", 1, "MESSAGE", type_name=".test.User"), + _field("ad_info", 2, "MESSAGE", type_name=".test.Ad"), + ]), + ]) + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bug1_name_collision(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "name_collision.desc", + _build_name_collision_descriptor_set_bytes) + 
message_name = "test.Event" + + data_gen = _as_datagen([ + _nested("user_info", 1, [ + _scalar("age", 1, IntegerGen()), + _scalar("id", 2, IntegerGen()), + ]), + _nested("ad_info", 2, [ + _scalar("id", 1, IntegerGen()), + ]), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + + return df.select( + decoded.getField("user_info").getField("age").alias("age"), + decoded.getField("user_info").getField("id").alias("user_id"), + decoded.getField("ad_info").getField("id").alias("ad_id") + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_filter_jump_descriptor_set_bytes(spark): + """Build a minimal descriptor used by the filter-jump regression test.""" + return _build_proto2_descriptor(spark, "filter_jump.proto", [ + _msg("Event", [ + _field("status", 1, "INT32"), + _field("ad_info", 2, "STRING"), + ]), + ]) + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bug2_filter_jump(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "filter_jump.desc", + _build_filter_jump_descriptor_set_bytes) + message_name = "test.Event" + + data_gen = _as_datagen([ + _scalar("status", 1, IntegerGen(min_val=1, max_val=1)), + _scalar("ad_info", 2, StringGen()), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + pb_expr1 = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + pb_expr2 = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + + return df.filter(pb_expr1.getField("status") == 1).select(pb_expr2.getField("ad_info").alias("ad_info")) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_unrelated_struct_name_collision_descriptor_set_bytes(spark): + """Build a regression schema where an unrelated 
nested struct shares a field name.""" + return _build_proto2_descriptor(spark, "unrelated_struct.proto", [ + _msg("Nested", [ + _field("dummy", 1, "INT32"), + _field("winfoid", 2, "INT32"), + ]), + _msg("Event", [ + _field("ad_info", 1, "MESSAGE", type_name=".test.Nested"), + ]), + ]) + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bug3_unrelated_struct_name_collision(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "unrelated_struct.desc", + _build_unrelated_struct_name_collision_descriptor_set_bytes) + message_name = "test.Event" + + data_gen = _as_datagen([ + _nested("ad_info", 1, [ + _scalar("dummy", 1, IntegerGen()), + _scalar("winfoid", 2, IntegerGen()), + ]), + ]) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + # Write to parquet to prevent Catalyst from optimizing away the GetStructField, + # and to ensure it runs on the GPU. + df_with_other = df.withColumn("other_struct", + f.struct(f.lit("hello").alias("dummy_str"), f.lit(42).alias("winfoid"))) + + path = spark_tmp_path + "/bug3_data.parquet" + df_with_other.write.mode("overwrite").parquet(path) + read_df = spark.read.parquet(path) + + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + + # We only select decoded.ad_info.winfoid, so dummy is pruned. + # winfoid gets ordinal 0 in the pruned schema. + # But for other_struct, winfoid is ordinal 1. + # GpuGetStructFieldMeta will see "winfoid", query the ThreadLocal, get 0, + # and extract ordinal 0 ("hello") for other_winfoid, causing a mismatch! 
+ return read_df.select( + decoded.getField("ad_info").getField("winfoid").alias("pb_winfoid"), + f.col("other_struct").getField("winfoid").alias("other_winfoid") + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_max_depth_descriptor_set_bytes(spark): + """Build a descriptor for a 12-level nested message chain.""" + messages = [] + for i in range(12, 0, -1): + fields = [_field(f"val{i}", 1, "INT32")] + if i < 12: + fields.append( + _field(f"level{i + 1}", 2, "MESSAGE", type_name=f".test.Level{i + 1}") + ) + messages.append(_msg(f"Level{i}", fields)) + return _build_proto2_descriptor(spark, "max_depth.proto", messages) + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bug4_max_depth(spark_tmp_path, from_protobuf_fn): + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "max_depth.desc", + _build_max_depth_descriptor_set_bytes) + message_name = "test.Level1" + + # Build the deeply nested data gen spec + def build_nested_gen(level): + if level == 12: + return [_scalar(f"val{level}", 1, IntegerGen())] + return [ + _scalar(f"val{level}", 1, IntegerGen()), + _nested(f"level{level+1}", 2, build_nested_gen(level+1)) + ] + + data_gen = _as_datagen(build_nested_gen(1)) + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + # Deep access + field = decoded + for i in range(2, 13): + field = field.getField(f"level{i}") + return df.select(field.getField("val12").alias("val12")) + + # Depth 12 exceeds GPU max nesting depth (10), so the query should + # gracefully fall back to CPU. Verify that it still produces correct + # results (CPU path) without crashing. 
+ + from spark_session import with_cpu_session + cpu_result = with_cpu_session(lambda spark: run_on_spark(spark).collect()) + assert len(cpu_result) > 0 + + + # =========================================================================== + # Regression tests for known bugs found by code review + # =========================================================================== + + def _encode_varint(value): + """Encode a non-negative integer as a protobuf varint (for hand-crafting test bytes).""" + out = bytearray() + v = int(value) + while True: + b = v & 0x7F + v >>= 7 + if v: + out.append(b | 0x80) + else: + out.append(b) + break + return bytes(out) + + + def _encode_tag(field_number, wire_type): + return _encode_varint((field_number << 3) | wire_type) + + + # --------------------------------------------------------------------------- + # Bug 1: BOOL8 truncation — non-canonical bool varint values >= 256 + # + # Protobuf spec: bool is a varint; any non-zero value means true. + # CPU decoder (protobuf-java): CodedInputStream.readBool() = readRawVarint64() != 0 → true + # GPU decoder: extract_varint_kernel writes static_cast<uint8_t>(v). + # For v = 256, static_cast<uint8_t>(256) == 0 → false. BUG. + # --------------------------------------------------------------------------- + + @pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") + @ignore_order(local=True) + def test_from_protobuf_bool_noncanonical_varint_scalar(spark_tmp_path, from_protobuf_fn): + """Regression test: scalar bool encoded as non-canonical varint (e.g. 256) must decode as true. + + Protobuf allows any non-zero varint for bool true. The GPU decoder previously + truncated to uint8_t, causing values >= 256 to wrap to 0 (false). 
+ """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "simple_bool_bug.desc", _build_simple_descriptor_set_bytes) + message_name = "test.Simple" + + # varint(256) = 0x80 0x02, varint(512) = 0x80 0x04 — valid non-canonical bool true + row_bool_256 = _encode_tag(1, 0) + _encode_varint(256) + \ + _encode_tag(2, 0) + _encode_varint(99) + + # Control row: canonical bool true (varint 1) — should work on both + row_bool_1 = _encode_tag(1, 0) + _encode_varint(1) + \ + _encode_tag(2, 0) + _encode_varint(100) + + # Another non-canonical value: varint(512) + row_bool_512 = _encode_tag(1, 0) + _encode_varint(512) + \ + _encode_tag(2, 0) + _encode_varint(101) + + def run_on_spark(spark): + df = spark.createDataFrame( + [(row_bool_256,), (row_bool_1,), (row_bool_512,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("b").alias("b"), + decoded.getField("i32").alias("i32"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +def _build_repeated_bool_descriptor_set_bytes(spark): + """Build a descriptor for an optional id plus repeated bool flags.""" + return _build_proto2_descriptor(spark, "repeated_bool.proto", [ + _msg("WithRepeatedBool", [ + _field("id", 1, "INT32"), + _field("flags", 2, "BOOL", label="repeated"), + ]), + ]) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_bool_noncanonical_varint_repeated(spark_tmp_path, from_protobuf_fn): + """Regression test: repeated bool with non-canonical varint values must all decode as true. + + Same uint8_t truncation issue as the scalar case, exercised with repeated fields. 
+ """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "repeated_bool_bug.desc", _build_repeated_bool_descriptor_set_bytes) + message_name = "test.WithRepeatedBool" + + # Repeated bool field 2 (wire type 0 = varint), unpacked. + # Three elements: varint(256), varint(1), varint(512) — all should decode as true. + row = (_encode_tag(1, 0) + _encode_varint(42) + + _encode_tag(2, 0) + _encode_varint(256) + + _encode_tag(2, 0) + _encode_varint(1) + + _encode_tag(2, 0) + _encode_varint(512)) + + def run_on_spark(spark): + df = spark.createDataFrame([(row,)], schema="bin binary") + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("id").alias("id"), + decoded.getField("flags").alias("flags"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +# --------------------------------------------------------------------------- +# Regression guard: nested message child field default values +# --------------------------------------------------------------------------- + +def _build_nested_with_defaults_descriptor_set_bytes(spark): + """Build a descriptor with proto2 defaults inside a nested child struct.""" + return _build_proto2_descriptor(spark, "nested_defaults.proto", [ + _msg("Inner", [ + _field("count", 1, "INT32", default=42), + _field("label", 2, "STRING", default="hello"), + _field("flag", 3, "BOOL", default=True), + ]), + _msg("OuterWithNestedDefaults", [ + _field("id", 1, "INT32"), + _field("inner", 2, "MESSAGE", type_name=".test.Inner"), + ]), + ]) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_nested_child_default_values(spark_tmp_path, from_protobuf_fn): + """Regression test: proto2 default values for fields inside nested messages must be honored. 
+ + When `inner` is present but its child fields are absent, the decoder must + return the proto2 defaults (count=42, label="hello", flag=true), not null. + """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "nested_defaults.desc", + _build_nested_with_defaults_descriptor_set_bytes) + message_name = "test.OuterWithNestedDefaults" + + # Row 1: outer.id = 10, inner is present but EMPTY (0-length nested message). + # Wire: field 1 varint(10), field 2 length-delimited with length 0. + # CPU should fill inner.count=42, inner.label="hello", inner.flag=true. + row_empty_inner = (_encode_tag(1, 0) + _encode_varint(10) + + _encode_tag(2, 2) + _encode_varint(0)) + + # Row 2: outer.id = 20, inner has only count=7 (label and flag should get defaults). + inner_partial = _encode_tag(1, 0) + _encode_varint(7) + row_partial_inner = (_encode_tag(1, 0) + _encode_varint(20) + + _encode_tag(2, 2) + _encode_varint(len(inner_partial)) + + inner_partial) + + # Row 3: outer.id = 30, inner is fully absent → inner itself is null. + row_no_inner = _encode_tag(1, 0) + _encode_varint(30) + + def run_on_spark(spark): + df = spark.createDataFrame( + [(row_empty_inner,), (row_partial_inner,), (row_no_inner,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("id").alias("id"), + decoded.getField("inner").getField("count").alias("inner_count"), + decoded.getField("inner").getField("label").alias("inner_label"), + decoded.getField("inner").getField("flag").alias("inner_flag"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +# =========================================================================== +# Deep nested schema pruning tests +# +# These verify that the GPU path correctly prunes nested fields at depth > 2. 
+# Previously, collectStructFieldReferences only recognized 2-level +# GetStructField chains, so accessing decoded.level2.level3.val3 would +# decode ALL of level3's children instead of only val3. +# =========================================================================== + +def _deep_5_level_data_gen(): + return _as_datagen([ + _scalar("val1", 1, IntegerGen()), + _nested("level2", 2, [ + _scalar("val2", 1, IntegerGen()), + _nested("level3", 2, [ + _scalar("val3", 1, IntegerGen()), + _nested("level4", 2, [ + _scalar("val4", 1, IntegerGen()), + _nested("level5", 2, [ + _scalar("val5", 1, IntegerGen()), + ]), + ]), + ]), + ]), + ]) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_deep_pruning_3_level_leaf(spark_tmp_path, from_protobuf_fn): + """Access decoded.level2.level3.val3 -- triggers 3-level deep pruning.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dp3.desc", _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _deep_5_level_data_gen() + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("val1").alias("val1"), + decoded.getField("level2").getField("level3").getField("val3").alias("deep_val3"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_deep_pruning_5_level_leaf(spark_tmp_path, from_protobuf_fn): + """Access decoded.level2.level3.level4.level5.val5 -- deepest leaf.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dp5.desc", _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _deep_5_level_data_gen() + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = 
_call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + _get_field_by_path(decoded, ["level2", "level3", "level4", "level5", "val5"]) + .alias("val5"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_deep_pruning_mixed_depths(spark_tmp_path, from_protobuf_fn): + """Access leaves at different depths in the same query. + + Select val1 (depth 1), val2 (depth 2), val3 (depth 3), and val5 (depth 5) + to exercise pruning at every intermediate level simultaneously. + """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dp_mix.desc", _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _deep_5_level_data_gen() + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("val1").alias("val1"), + decoded.getField("level2").getField("val2").alias("val2"), + _get_field_by_path(decoded, ["level2", "level3", "val3"]).alias("val3"), + _get_field_by_path(decoded, ["level2", "level3", "level4", "level5", "val5"]) + .alias("val5"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_deep_pruning_whole_struct_at_depth_3(spark_tmp_path, from_protobuf_fn): + """Select the whole level3 struct -- no deep pruning inside level3.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "dp_whole3.desc", _build_deep_nested_5_level_descriptor_set_bytes) + message_name = "test.Level1" + data_gen = _deep_5_level_data_gen() + + def run_on_spark(spark): + df = gen_df(spark, data_gen) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, 
desc_path, desc_bytes) + return df.select( + decoded.getField("level2").getField("level3").alias("level3"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +# =========================================================================== +# FAILFAST mode tests +# =========================================================================== + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +def test_from_protobuf_failfast_malformed_data(spark_tmp_path, from_protobuf_fn): + """FAILFAST mode should throw on malformed protobuf data (both CPU and GPU).""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "failfast.desc", _build_simple_descriptor_set_bytes) + message_name = "test.Simple" + + # Craft a valid row and a malformed row (truncated varint with continuation bit) + valid_row = _encode_tag(1, 0) + _encode_varint(1) + \ + _encode_tag(2, 0) + _encode_varint(42) + malformed_row = bytes([0x08, 0x80]) # field 1, varint, but only continuation byte -- no end + + def run_on_spark(spark): + df = spark.createDataFrame( + [(valid_row,), (malformed_row,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes, + options={"mode": "FAILFAST"}) + # Must call .collect() so the exception surfaces inside with_*_session + return df.select(decoded.getField("b").alias("b")).collect() + + assert_gpu_and_cpu_error(run_on_spark, {}, "Malformed") + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_permissive_malformed_returns_null(spark_tmp_path, from_protobuf_fn): + """PERMISSIVE mode should return null for malformed rows, not throw. + + Note: Spark's from_protobuf defaults to FAILFAST (unlike JSON/CSV which + default to PERMISSIVE), so mode must be set explicitly. 
+ """ + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "permissive.desc", _build_simple_descriptor_set_bytes) + message_name = "test.Simple" + + valid_row = _encode_tag(2, 0) + _encode_varint(99) + malformed_row = bytes([0x08, 0x80]) # truncated varint + + def run_on_spark(spark): + df = spark.createDataFrame( + [(valid_row,), (malformed_row,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes, + options={"mode": "PERMISSIVE"}) + return df.select( + decoded.getField("i32").alias("i32"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) + + +@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+") +@ignore_order(local=True) +def test_from_protobuf_all_null_input(spark_tmp_path, from_protobuf_fn): + """All rows in the input binary column are null (not empty bytes, actual nulls). + GPU should produce all-null struct rows matching CPU behavior.""" + desc_path, desc_bytes = _setup_protobuf_desc( + spark_tmp_path, "allnull.desc", _build_simple_descriptor_set_bytes) + message_name = "test.Simple" + + def run_on_spark(spark): + df = spark.createDataFrame( + [(None,), (None,), (None,)], + schema="bin binary", + ) + decoded = _call_from_protobuf( + from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes) + return df.select( + decoded.getField("i32").alias("i32"), + decoded.getField("s").alias("s"), + ) + + assert_gpu_and_cpu_are_equal_collect(run_on_spark) diff --git a/integration_tests/src/main/python/spark_init_internal.py b/integration_tests/src/main/python/spark_init_internal.py index c0f5b6123ad..5150d2dacdc 100644 --- a/integration_tests/src/main/python/spark_init_internal.py +++ b/integration_tests/src/main/python/spark_init_internal.py @@ -45,6 +45,7 @@ def conf_for_env(env_name): return res _DRIVER_ENV = env_for_conf('spark.driver.extraJavaOptions') +_DRIVER_CLASSPATH = env_for_conf('spark.driver.extraClassPath') 
_SPARK_JARS = env_for_conf("spark.jars") _SPARK_JARS_PACKAGES = env_for_conf("spark.jars.packages") spark_jars_env = { @@ -64,11 +65,66 @@ def findspark_init(): if spark_jars is not None: logging.info(f"Adding to findspark jars: {spark_jars}") findspark.add_jars(spark_jars) + # Also add to driver classpath so classes are available to Class.forName() + # This is needed for optional modules like spark-protobuf + _add_driver_classpath(spark_jars) if spark_jars_packages is not None: logging.info(f"Adding to findspark packages: {spark_jars_packages}") findspark.add_packages(spark_jars_packages) + +def _add_driver_classpath(jars): + """ + Add jars to the driver classpath via PYSPARK_SUBMIT_ARGS. + findspark.add_jars() only adds --jars, which doesn't make classes available + to Class.forName() on the driver. This function adds --driver-class-path. + """ + if not jars: + return + current_args = os.environ.get('PYSPARK_SUBMIT_ARGS', '') + # Remove trailing 'pyspark-shell' if present + if current_args.endswith('pyspark-shell'): + current_args = current_args[:-len('pyspark-shell')].strip() + jar_list = [j.strip() for j in jars.split(',') if j.strip()] + existing_driver_cp = { + p for p in os.environ.get(_DRIVER_CLASSPATH, '').split(os.pathsep) if p + } + args_driver_cp_match = re.search(r'--driver-class-path\s+(\S+)', current_args) + existing_args_driver_cp = { + p for p in (args_driver_cp_match.group(1).split(os.pathsep) + if args_driver_cp_match else []) if p + } + jar_list = [ + j for j in jar_list + if j not in existing_driver_cp and j not in existing_args_driver_cp + ] + if not jar_list: + logging.info("Skipping PYSPARK_SUBMIT_ARGS driver-class-path update; jars already present") + return + new_cp = os.pathsep.join(jar_list) + if '--driver-class-path' in current_args: + if args_driver_cp_match: + existing_cp = args_driver_cp_match.group(1) + merged_cp = existing_cp + os.pathsep + new_cp + current_args = re.sub( + r'--driver-class-path\s+\S+', + lambda m: 
f'--driver-class-path {merged_cp}', + current_args, + count=1) + else: + current_args = re.sub( + r'--driver-class-path(?=\s|$)', + '', + current_args, + count=1).strip() + current_args += f' --driver-class-path {new_cp}' + else: + current_args += f' --driver-class-path {new_cp}' + new_args = f"{current_args} pyspark-shell".strip() + os.environ['PYSPARK_SUBMIT_ARGS'] = new_args + logging.info(f"Updated PYSPARK_SUBMIT_ARGS with driver-class-path") + def running_with_xdist(session, is_worker): try: import xdist diff --git a/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh b/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh new file mode 100755 index 00000000000..2283ac8a751 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/gen_nested_proto_data.sh @@ -0,0 +1,34 @@ +#!/bin/bash +# Convenience script: compile nested_proto .proto files into a descriptor set. +# +# Usage: +# ./gen_nested_proto_data.sh +# +# The generated .desc file is checked into the repository and used by +# integration tests in protobuf_test.py. Re-run this script whenever +# the .proto definitions under nested_proto/ change. + +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROTO_DIR="${SCRIPT_DIR}/nested_proto" +OUTPUT_DIR="${SCRIPT_DIR}/nested_proto/generated" + +echo "=== Protobuf Descriptor Compiler ===" +echo "Proto dir: ${PROTO_DIR}" +echo "" + +# Create output directory +mkdir -p "${OUTPUT_DIR}" + +# Compile proto files into a descriptor set (includes all imports) +DESC_FILE="${OUTPUT_DIR}/main_log.desc" +echo "Compiling proto files..." 
+protoc \ + --descriptor_set_out="${DESC_FILE}" \ + --include_imports \ + -I"${PROTO_DIR}" \ + "${PROTO_DIR}/main_log.proto" + +echo "Generated: ${DESC_FILE}" +echo "=== Done ===" diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto new file mode 100644 index 00000000000..c4d86951d96 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/device_req.proto @@ -0,0 +1,11 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +option java_outer_classname = "DeviceReqBean"; + +// Device request field +message DeviceReqField { + optional int32 os_type = 1; // int32 + optional bytes device_id = 2; // bytes +} diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc b/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc new file mode 100644 index 00000000000..6e8155238b3 Binary files /dev/null and b/integration_tests/src/test/resources/protobuf_test/nested_proto/generated/main_log.desc differ diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto new file mode 100644 index 00000000000..f9a325a9f2e --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/main_log.proto @@ -0,0 +1,103 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +import "module_a_res.proto"; +import "module_b_res.proto"; +import "device_req.proto"; + +// ========== Enum type tests ========== +enum SourceType { + WEB = 0; // web + APP = 1; // application + MOBILE = 4; // mobile +} + +enum ChannelType { + CHANNEL_A = 0; + CHANNEL_B = 1; +} + +// ========== Main log record ========== +message MainLogRecord { + // required fields + required SourceType source = 1; // required enum + required uint64 timestamp = 2; // required 
uint64 + + // optional scalar types - one of each + optional string user_id = 3; // string + optional int64 account_id = 4; // int64 + optional fixed32 client_ip = 5; // fixed32 + + // nested message + optional LogContent log_content = 6; +} + +// ========== Log content (multi-level nesting) ========== +message LogContent { + optional BasicInfo basic_info = 1; + repeated ChannelInfo channel_list = 2; + repeated DataSourceField source_list = 3; +} + +// ========== Basic info (three-level nesting) ========== +message BasicInfo { + optional RequestInfo request_info = 1; + optional ExtendedReqInfo extended_req_info = 2; + optional ServerAddedField server_added_field = 3; +} + +// ========== Request info ========== +message RequestInfo { + optional uint32 page_num = 1; // uint32 + optional string channel_code = 2; // string + repeated uint32 experiment_ids = 3; // repeated uint32 + optional bool is_filtered = 4; // bool +} + +// ========== Extended request info (cross-file import) ========== +message ExtendedReqInfo { + optional DeviceReqField device_req_field = 1; // reference to external proto +} + +// ========== Server-added fields ========== +message ServerAddedField { + optional uint32 region_code = 1; // uint32 + optional string flow_type = 2; // string + optional int32 filter_result = 3; // int32 + repeated int32 hit_rule_list = 4; // repeated int32 + optional uint64 request_time = 5; // uint64 + optional bool skip_flag = 6; // bool +} + +// ========== Channel info ========== +message ChannelInfo { + optional int32 channel_id = 1; + optional ModuleAResField module_a_res = 2; // reference to external proto +} + +// ========== Source channel info ========== +message SrcChannelInfo { + optional int32 channel_id = 1; + optional ModuleASrcResField module_a_src_res = 2; // reference to external proto +} + +// ========== Data source field ========== +message DataSourceField { + optional uint32 source_id = 1; + repeated SrcChannelInfo src_channel_list = 2; + optional 
string billing_name = 3; + repeated ItemDetailField item_list = 4; + optional bool is_free = 5; +} + +// ========== Item detail field ========== +message ItemDetailField { + optional uint32 rank = 1; // uint32 + optional uint64 record_id = 2; // uint64 + optional string keyword = 3; // string + + // cross-file message references + optional ModuleADetailField module_a_detail = 4; + optional ModuleBDetailField module_b_detail = 5; +} diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto new file mode 100644 index 00000000000..9a2674e5dd5 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_a_res.proto @@ -0,0 +1,92 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +option java_outer_classname = "ModuleARes"; + +import "predictor_schema.proto"; + +// ========== Test default values ========== +message PartnerInfo { + optional string token = 1 [default = ""]; // default empty string + optional uint64 partner_id = 2 [default = 0]; // default 0 +} + +// ========== Test coordinate structure (simple nesting) ========== +message Coordinate { + optional double x = 1; // x coordinate - double + optional double y = 2; // y coordinate - double +} + +// ========== Test multi-level nesting ========== +message LocationPoint { + optional uint32 frequency = 1; // frequency - uint32 + optional Coordinate coord = 2; // coordinate - nested message + optional uint64 timestamp = 3; // timestamp - uint64 +} + +// ========== Test change log ========== +message ChangeLog { + optional uint32 value_before = 1; // value before change + optional string parameters = 2; // parameters +} + +// ========== Test price change log ========== +message PriceLog { + optional uint32 price_before = 1; +} + +// ========== Module A response-level fields ========== +message ModuleAResField { + optional string route_tag = 1; // route tag - 
string + optional int32 status_tag = 2; // status tag - int32 + optional uint32 region_id = 3; // region id - uint32 + repeated string experiment_ids = 4; // experiment id list - repeated string + optional double quality_score = 5; // quality score - double + repeated LocationPoint location_points = 6; // location points - repeated nested message + repeated uint64 interest_ids = 7; // interest id list - repeated uint64 +} + +// ========== Module A source response field ========== +message ModuleASrcResField { + optional uint32 match_type = 1; // match type +} + +// ========== Key-value pair ========== +message KVPair { + optional bytes key = 1; // key - bytes + optional bytes value = 2; // value - bytes +} + +// ========== Style configuration ========== +message StyleConfig { + optional uint32 style_id = 1; // style id + repeated KVPair kv_pairs = 2; // kv pair list - repeated nested message +} + +// ========== Module A detail field (core complex structure) ========== +message ModuleADetailField { + // scalar types - one or two of each + optional uint32 type_code = 1; // uint32 + optional uint64 item_id = 2; // uint64 + optional int32 strategy_type = 3; // int32 + optional int64 min_value = 4; // int64 + optional bytes target_url = 5; // bytes + optional string title = 6; // string + optional bool is_valid = 7; // bool + optional float score_ratio = 8; // float + + // repeated scalar types + repeated uint32 template_ids = 9; // repeated uint32 + repeated uint64 material_ids = 10; // repeated uint64 + + // repeated nested messages + repeated StyleConfig styles = 11; // repeated message + repeated ChangeLog change_logs = 12; // repeated message + + // nested message + optional PartnerInfo partner_info = 13; // nested message + + // cross-file import + optional PredictorSchema predictor_schema = 14; // reference to external proto +} diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto 
b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto new file mode 100644 index 00000000000..0c742739835 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/module_b_res.proto @@ -0,0 +1,29 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +// Module B response field +message ModuleBResField { + optional uint32 type_code = 1; // uint32 + optional string extra_info = 2; // string +} + +// Block element - tests repeated scalar types +message BlockElement { + optional uint64 element_id = 1; // uint64 + repeated uint64 ref_ids = 2; // repeated uint64 +} + +// Block info - tests repeated nested messages +message BlockInfo { + optional uint64 block_id = 1; // uint64 + repeated BlockElement elements = 2; // repeated message +} + +// Module B detail field +message ModuleBDetailField { + repeated uint32 tags = 1; // repeated uint32 + optional uint64 item_id = 2; // uint64 + optional string name = 3; // string + repeated BlockInfo blocks = 4; // repeated message +} diff --git a/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto b/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto new file mode 100644 index 00000000000..bc7d86c67e2 --- /dev/null +++ b/integration_tests/src/test/resources/protobuf_test/nested_proto/predictor_schema.proto @@ -0,0 +1,82 @@ +syntax = "proto2"; + +package com.test.proto.sample; + +// Predictor schema - tests multi-level schema nesting + +// ========== Main schema structure ========== +message PredictorSchema { + optional SchemaTypeA type_a_schema = 1; // nested + optional SchemaTypeB type_b_schema = 2; // nested + optional SchemaTypeC type_c_schema = 3; // nested (with repeated) +} + +// ========== TypeA Query Schema ========== +message TypeAQuerySchema { + optional string keyword = 1; // keyword + optional string session_id = 2; // session id +} + +// ========== TypeA Pair Schema ========== +message 
TypeAPairSchema { + optional string record_id = 1; // record id + optional string item_id = 2; // item id +} + +// ========== TypeA Schema ========== +message SchemaTypeA { + optional TypeAQuerySchema query_schema = 1; // query-level schema + repeated TypeAPairSchema pair_schema = 2; // pair list - repeated nested +} + +// ========== TypeB Query Schema ========== +message TypeBQuerySchema { + optional string profile_tag_id = 1; // profile tag id + optional string entity_id = 2; // entity id +} + +// ========== TypeB Style Element ========== +message TypeBStyleElem { + optional string template_id = 1; // template id + optional string material_id = 2; // material id +} + +// ========== TypeB Style Schema ========== +message TypeBStyleSchema { + repeated TypeBStyleElem values = 1; // element list - repeated nested +} + +// ========== TypeB Schema ========== +message SchemaTypeB { + optional TypeBQuerySchema query_schema = 1; // query-level schema + repeated TypeBStyleSchema style_schema = 2; // style list - repeated nested +} + +// ========== TypeC Query Schema ========== +message TypeCQuerySchema { + optional string keyword = 1; // keyword + optional string category = 2; // category +} + +// ========== TypeC Pair Schema ========== +message TypeCPairSchema { + optional string item_id = 1; // item id + optional string target_url = 2; // target url +} + +// ========== TypeC Style Element (empty structure test) ========== +message TypeCStyleElem { + // empty message - tests empty structure handling +} + +// ========== TypeC Style Schema ========== +message TypeCStyleSchema { + repeated TypeCStyleElem values = 1; // empty element list +} + +// ========== TypeC Schema (full three-level structure) ========== +message SchemaTypeC { + optional TypeCQuerySchema query_schema = 1; // query schema + repeated TypeCPairSchema pair_schema = 2; // pair list + repeated TypeCStyleSchema style_schema = 3; // style list +} diff --git 
a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala index 3c5b9f26872..3deeb88c737 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala @@ -71,7 +71,8 @@ object GpuBindReferences extends Logging { if (ordinal == -1) { sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}") } else { - GpuBoundReference(ordinal, a.dataType, input(ordinal).nullable)(a.exprId, a.name) + GpuBoundReference(ordinal, input(ordinal).dataType, input(ordinal).nullable)( + a.exprId, a.name) } } val matchFunc = regularMatch.orElse(partial) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 565e97a40ca..edcee71dc0f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -955,6 +955,7 @@ object GpuOverrides extends Logging { + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all), (a, conf, p, r) => new UnaryAstExprMeta[Alias](a, conf, p, r) { + override def typeMeta: DataTypeMeta = childExprs.head.typeMeta override def convertToGpu(child: Expression): GpuExpression = GpuAlias(child, a.name)(a.exprId, a.qualifier, a.explicitMetadata) }), @@ -981,7 +982,32 @@ object GpuOverrides extends Logging { TypeSig.all), (att, conf, p, r) => new BaseExprMeta[AttributeReference](att, conf, p, r) { // This is the only NOOP operator. 
It goes away when things are bound - override def convertToGpuImpl(): Expression = att + override def convertToGpuImpl(): Expression = { + def findParentPlan(meta: Option[RapidsMeta[_, _, _]]): Option[SparkPlanMeta[_]] = { + meta match { + case Some(planMeta: SparkPlanMeta[_]) => Some(planMeta) + case Some(other) => findParentPlan(other.parent) + case None => None + } + } + + val maybeResolvedAttr = for { + planMeta <- findParentPlan(parent) + matched <- planMeta.childPlans.iterator + .flatMap(_.outputAttributes.iterator) + .find(_.exprId == att.exprId) + } yield AttributeReference( + att.name, + matched.dataType, + matched.nullable, + att.metadata)(att.exprId, att.qualifier) + + // NOTE: matched.dataType may still reflect the wrapped Spark plan's original + // output type (for example, an un-pruned protobuf struct). Correct runtime type + // propagation is guaranteed by GpuBoundAttribute, which uses input(ordinal).dataType + // at bind time. + maybeResolvedAttr.getOrElse(att) + } // There are so many of these that we don't need to print them out, unless it // will not work on the GPU @@ -2628,10 +2654,7 @@ object GpuOverrides extends Logging { TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY), TypeSig.STRUCT.nested(TypeSig.all)), - (expr, conf, p, r) => new UnaryExprMeta[GetStructField](expr, conf, p, r) { - override def convertToGpu(arr: Expression): GpuExpression = - GpuGetStructField(arr, expr.ordinal, expr.name) - }), + (expr, conf, p, r) => new GpuGetStructFieldMeta(expr, conf, p, r)), expr[GetArrayItem]( "Gets the field at `ordinal` in the Array", ExprChecks.binaryProject( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index a97f830fe3e..015814e94b0 100644 --- 
a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -53,6 +53,9 @@ class GpuProjectExecMeta( p: Option[RapidsMeta[_, _, _]], r: DataFromReplacementRule) extends SparkPlanMeta[ProjectExec](proj, conf, p, r) with Logging { + override protected lazy val outputTypeMetas: Option[Seq[DataTypeMeta]] = + Some(childExprs.map(_.typeMeta)) + override def convertToGpu(): GpuExec = { // Force list to avoid recursive Java serialization of lazy list Seq implementation val gpuExprs = childExprs.map(_.convertToGpu().asInstanceOf[NamedExpression]).toList @@ -299,7 +302,6 @@ trait GpuProjectExecLike extends GpuPartitioningPreservingUnaryExecNode with Gpu override def doExecute(): RDD[InternalRow] = throw new IllegalStateException(s"Row-based execution should not occur for $this") - // The same as what feeds us override def outputBatching: CoalesceGoal = GpuExec.outputBatching(child) } @@ -768,8 +770,6 @@ case class GpuProjectExec( override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def outputBatching: CoalesceGoal = if (enablePreSplit) { - // Pre-split will make sure the size of each output batch will not be larger - // than the splitUntilSize. TargetSize(PreProjectSplitIterator.getSplitUntilSize) } else { super.outputBatching diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala new file mode 100644 index 00000000000..35743bbf336 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2025-2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids + +import java.util.Arrays + +import ai.rapids.cudf +import ai.rapids.cudf.{CudfException, DType} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression} +import com.nvidia.spark.rapids.jni.{Protobuf, ProtobufSchemaDescriptor} +import com.nvidia.spark.rapids.shims.NullIntolerantShim + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression} +import org.apache.spark.sql.types._ + +/** + * GPU implementation for Spark's `from_protobuf` decode path. + * + * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when + * supported. + * + * The implementation uses a flattened schema representation where nested fields have parent + * indices pointing to their containing message field. For pure scalar schemas, all fields + * are top-level (parentIndices == -1, depthLevels == 0, isRepeated == false). + * + * Schema projection is supported: `decodedSchema` contains only the top-level fields and + * nested children that are actually referenced by downstream operators. Downstream + * `GetStructField` and `GetArrayStructFields` nodes have their ordinals rewritten via + * `PRUNED_ORDINAL_TAG` to index into the pruned schema. Unreferenced fields are never + * accessed, so no null-column filling is needed. + * + * @param decodedSchema The pruned schema containing only the fields decoded by the GPU. 
+ * Only fields referenced by downstream operators are included; + * ordinal remapping ensures correct field access into the pruned output. + * @param fieldNumbers Protobuf field numbers for all fields in flattened schema + * @param parentIndices Parent indices for all fields (-1 for top-level) + * @param depthLevels Nesting depth for all fields (0 for top-level) + * @param wireTypes Wire types for all fields + * @param outputTypeIds cuDF type IDs for all fields + * @param encodings Encodings for all fields + * @param isRepeated Whether each field is repeated + * @param isRequired Whether each field is required + * @param hasDefaultValue Whether each field has a default value + * @param defaultInts Default int/long values + * @param defaultFloats Default float/double values + * @param defaultBools Default bool values + * @param defaultStrings Default string/bytes values + * @param enumValidValues Valid enum values for each field + * @param enumNames Enum value names for enum-as-string fields. Parallel to enumValidValues. 
+ * @param failOnErrors If true, throw exception on malformed data + */ +case class GpuFromProtobuf( + decodedSchema: StructType, + fieldNumbers: Array[Int], + parentIndices: Array[Int], + depthLevels: Array[Int], + wireTypes: Array[Int], + outputTypeIds: Array[Int], + encodings: Array[Int], + isRepeated: Array[Boolean], + isRequired: Array[Boolean], + hasDefaultValue: Array[Boolean], + defaultInts: Array[Long], + defaultFloats: Array[Double], + defaultBools: Array[Boolean], + defaultStrings: Array[Array[Byte]], + enumValidValues: Array[Array[Int]], + enumNames: Array[Array[Array[Byte]]], + failOnErrors: Boolean, + child: Expression) + extends GpuUnaryExpression with ExpectsInputTypes with NullIntolerantShim with Logging { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override def dataType: DataType = decodedSchema + + override def nullable: Boolean = true + + override def equals(other: Any): Boolean = other match { + case that: GpuFromProtobuf => + decodedSchema == that.decodedSchema && + Arrays.equals(fieldNumbers, that.fieldNumbers) && + Arrays.equals(parentIndices, that.parentIndices) && + Arrays.equals(depthLevels, that.depthLevels) && + Arrays.equals(wireTypes, that.wireTypes) && + Arrays.equals(outputTypeIds, that.outputTypeIds) && + Arrays.equals(encodings, that.encodings) && + Arrays.equals(isRepeated, that.isRepeated) && + Arrays.equals(isRequired, that.isRequired) && + Arrays.equals(hasDefaultValue, that.hasDefaultValue) && + Arrays.equals(defaultInts, that.defaultInts) && + Arrays.equals(defaultFloats, that.defaultFloats) && + Arrays.equals(defaultBools, that.defaultBools) && + GpuFromProtobuf.deepEquals(defaultStrings, that.defaultStrings) && + GpuFromProtobuf.deepEquals(enumValidValues, that.enumValidValues) && + GpuFromProtobuf.deepEquals(enumNames, that.enumNames) && + failOnErrors == that.failOnErrors && + child == that.child + case _ => false + } + + override def hashCode(): Int = { + var result = 
decodedSchema.hashCode() + result = 31 * result + Arrays.hashCode(fieldNumbers) + result = 31 * result + Arrays.hashCode(parentIndices) + result = 31 * result + Arrays.hashCode(depthLevels) + result = 31 * result + Arrays.hashCode(wireTypes) + result = 31 * result + Arrays.hashCode(outputTypeIds) + result = 31 * result + Arrays.hashCode(encodings) + result = 31 * result + Arrays.hashCode(isRepeated) + result = 31 * result + Arrays.hashCode(isRequired) + result = 31 * result + Arrays.hashCode(hasDefaultValue) + result = 31 * result + Arrays.hashCode(defaultInts) + result = 31 * result + Arrays.hashCode(defaultFloats) + result = 31 * result + Arrays.hashCode(defaultBools) + result = 31 * result + GpuFromProtobuf.deepHashCode(defaultStrings) + result = 31 * result + GpuFromProtobuf.deepHashCode(enumValidValues) + result = 31 * result + GpuFromProtobuf.deepHashCode(enumNames) + result = 31 * result + failOnErrors.hashCode() + result = 31 * result + child.hashCode() + result + } + + // ProtobufSchemaDescriptor is a pure-Java immutable holder for validated schema arrays. + // It does not own native resources, so task-scoped close hooks are not required here. + @transient private lazy val protobufSchema = new ProtobufSchemaDescriptor( + fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, encodings, + isRepeated, isRequired, hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, enumNames) + + override protected def doColumnar(input: GpuColumnVector): cudf.ColumnVector = { + // Input null mask is propagated to the output struct by the C++ decoder, + // so no mergeAndSetValidity call is needed here. 
+ try { + Protobuf.decodeToStruct(input.getBase, protobufSchema, failOnErrors) + } catch { + case e: CudfException if failOnErrors => + throw new org.apache.spark.SparkException("Malformed protobuf message", e) + case e: CudfException => + logWarning(s"Unexpected CudfException in PERMISSIVE mode: ${e.getMessage}", e) + throw e + } + } +} + +object GpuFromProtobuf { + val ENC_DEFAULT = 0 + val ENC_FIXED = 1 + val ENC_ZIGZAG = 2 + val ENC_ENUM_STRING = 3 + + /** + * Maps a Spark DataType to the corresponding cuDF native type ID. + * Note: The encoding (varint/zigzag/fixed) is determined by the protobuf field type, + * not the Spark data type, so it must be set separately based on the protobuf schema. + * + * @return Some(typeId) for supported types, None for unsupported types + */ + def sparkTypeToCudfIdOpt(dt: DataType): Option[Int] = dt match { + case BooleanType => Some(DType.BOOL8.getTypeId.getNativeId) + case IntegerType => Some(DType.INT32.getTypeId.getNativeId) + case LongType => Some(DType.INT64.getTypeId.getNativeId) + case FloatType => Some(DType.FLOAT32.getTypeId.getNativeId) + case DoubleType => Some(DType.FLOAT64.getTypeId.getNativeId) + case StringType => Some(DType.STRING.getTypeId.getNativeId) + case BinaryType => Some(DType.LIST.getTypeId.getNativeId) + case _ => None + } + + /** + * Check if a Spark DataType is supported by the GPU protobuf decoder. 
+ */ + def isTypeSupported(dt: DataType): Boolean = sparkTypeToCudfIdOpt(dt).isDefined + + private def deepEquals[T](left: Array[T], right: Array[T]): Boolean = + Arrays.deepEquals(left.asInstanceOf[Array[Object]], right.asInstanceOf[Array[Object]]) + + private def deepHashCode[T](arr: Array[T]): Int = + Arrays.deepHashCode(arr.asInstanceOf[Array[Object]]) +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 9afb3b854d6..c88ae6a5322 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -32,7 +32,10 @@ import org.apache.spark.sql.rapids.shims.RapidsErrorUtils import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, IntegralType, LongType, MapType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch -case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[String] = None) +case class GpuGetStructField( + child: Expression, + ordinal: Int, + name: Option[String] = None) extends ShimUnaryExpression with GpuExpression with ShimGetStructField @@ -41,15 +44,23 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin lazy val childSchema: StructType = child.dataType.asInstanceOf[StructType] override def dataType: DataType = childSchema(ordinal).dataType - override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable + + override def nullable: Boolean = + child.nullable || childSchema(ordinal).nullable override def toString: String = { - val fieldName = if (resolved) childSchema(ordinal).name else s"_$ordinal" + val fieldName = if (resolved) { + childSchema(ordinal).name + } else { + s"_$ordinal" + } s"$child.${name.getOrElse(fieldName)}" } - override def sql: String = - child.sql 
+ s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}" + override def sql: String = { + val fieldName = childSchema(ordinal).name + child.sql + s".${quoteIdentifier(name.getOrElse(fieldName))}" + } override def columnarEvalAny(batch: ColumnarBatch): Any = { val dt = dataType @@ -59,7 +70,6 @@ case class GpuGetStructField(child: Expression, ordinal: Int, name: Option[Strin GpuColumnVector.from(view.copyToColumnVector(), dt) } case s: GpuScalar => - // For a scalar in we want a scalar out. if (!s.isValid) { GpuScalar(null, dt) } else { @@ -402,6 +412,36 @@ case class GpuArrayPosition(left: Expression, right: Expression) } } +object GpuStructFieldOrdinalTag { + val PRUNED_ORDINAL_TAG = + new org.apache.spark.sql.catalyst.trees.TreeNodeTag[Int]("GPU_PRUNED_ORDINAL") +} + +class GpuGetStructFieldMeta( + expr: GetStructField, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends UnaryExprMeta[GetStructField](expr, conf, parent, rule) { + + def convertToGpu(child: Expression): GpuExpression = { + val effectiveOrd = GpuGetStructFieldMeta.effectiveOrdinal(expr) + GpuGetStructField(child, effectiveOrd, expr.name) + } +} + +object GpuGetStructFieldMeta { + def effectiveOrdinal(expr: GetStructField): Int = { + val runtimeOrd = expr.getTagValue( + GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1) + if (runtimeOrd >= 0) { + runtimeOrd + } else { + expr.ordinal + } + } +} + class GpuGetArrayStructFieldsMeta( expr: GetArrayStructFields, conf: RapidsConf, @@ -409,8 +449,41 @@ class GpuGetArrayStructFieldsMeta( rule: DataFromReplacementRule) extends UnaryExprMeta[GetArrayStructFields](expr, conf, parent, rule) { - def convertToGpu(child: Expression): GpuExpression = - GpuGetArrayStructFields(child, expr.field, expr.ordinal, expr.numFields, expr.containsNull) + def convertToGpu(child: Expression): GpuExpression = { + val effectiveOrd = GpuGetArrayStructFieldsMeta.effectiveOrdinal(expr) + val runtimeOrd = 
expr.getTagValue( + GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1) + val effectiveNumFields = + GpuGetArrayStructFieldsMeta.effectiveNumFields(child, expr, runtimeOrd) + GpuGetArrayStructFields(child, expr.field, + effectiveOrd, effectiveNumFields, expr.containsNull) + } +} + +object GpuGetArrayStructFieldsMeta { + def effectiveOrdinal(expr: GetArrayStructFields): Int = { + val runtimeOrd = expr.getTagValue( + GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).getOrElse(-1) + if (runtimeOrd >= 0) { + runtimeOrd + } else { + expr.ordinal + } + } + + def effectiveNumFields( + child: Expression, + expr: GetArrayStructFields, + runtimeOrd: Int): Int = { + if (runtimeOrd >= 0) { + child.dataType match { + case ArrayType(st: StructType, _) => st.fields.length + case _ => expr.numFields + } + } else { + expr.numFields + } + } } /** diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala new file mode 100644 index 00000000000..eedbc9cd01c --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaExtractor.scala @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids.protobuf + +import scala.collection.mutable + +import com.nvidia.spark.rapids.jni.Protobuf.{WT_32BIT, WT_64BIT, WT_LEN, WT_VARINT} + +import org.apache.spark.sql.rapids.GpuFromProtobuf +import org.apache.spark.sql.types._ + +object ProtobufSchemaExtractor { + def analyzeAllFields( + schema: StructType, + msgDesc: ProtobufMessageDescriptor, + enumsAsInts: Boolean, + messageName: String): Either[String, Map[String, ProtobufFieldInfo]] = { + val result = mutable.Map[String, ProtobufFieldInfo]() + + schema.fields.foreach { sf => + val fieldInfo = msgDesc.findField(sf.name) match { + case None => + unsupportedFieldInfo( + sf, + None, + s"Protobuf field '${sf.name}' not found in message '$messageName'") + case Some(fd) => + extractFieldInfo(sf, fd, enumsAsInts) match { + case Right(info) => + info + case Left(reason) => + unsupportedFieldInfo(sf, Some(fd), reason) + } + } + result(sf.name) = fieldInfo + } + + Right(result.toMap) + } + + def extractFieldInfo( + sparkField: StructField, + fieldDescriptor: ProtobufFieldDescriptor, + enumsAsInts: Boolean): Either[String, ProtobufFieldInfo] = { + val (isSupported, unsupportedReason, encoding) = + checkFieldSupport( + sparkField.dataType, + fieldDescriptor.protoTypeName, + fieldDescriptor.isRepeated, + enumsAsInts) + + val defaultValue = fieldDescriptor.defaultValueResult match { + case Right(value) => + value + case Left(_) if !isSupported => + // Preserve the primary unsupported reason from checkFieldSupport for fields that are + // already known to be unsupported. Reflection/default extraction errors on those fields + // should not mask the more actionable type-support message. 
+ None + case Left(reason) => + return Left(reason) + } + + Right( + ProtobufFieldInfo( + fieldNumber = fieldDescriptor.fieldNumber, + protoTypeName = fieldDescriptor.protoTypeName, + sparkType = sparkField.dataType, + encoding = encoding, + isSupported = isSupported, + unsupportedReason = unsupportedReason, + isRequired = fieldDescriptor.isRequired, + defaultValue = defaultValue, + enumMetadata = fieldDescriptor.enumMetadata, + isRepeated = fieldDescriptor.isRepeated + )) + } + + private def unsupportedFieldInfo( + sparkField: StructField, + fieldDescriptor: Option[ProtobufFieldDescriptor], + reason: String): ProtobufFieldInfo = { + ProtobufFieldInfo( + fieldNumber = fieldDescriptor.map(_.fieldNumber).getOrElse(-1), + protoTypeName = fieldDescriptor.map(_.protoTypeName).getOrElse("UNKNOWN"), + sparkType = sparkField.dataType, + encoding = GpuFromProtobuf.ENC_DEFAULT, + isSupported = false, + unsupportedReason = Some(reason), + isRequired = fieldDescriptor.exists(_.isRequired), + defaultValue = None, + enumMetadata = fieldDescriptor.flatMap(_.enumMetadata), + isRepeated = fieldDescriptor.exists(_.isRepeated) + ) + } + + def checkFieldSupport( + sparkType: DataType, + protoTypeName: String, + isRepeated: Boolean, + enumsAsInts: Boolean): (Boolean, Option[String], Int) = { + + if (isRepeated) { + sparkType match { + case ArrayType(elementType, _) => + elementType match { + case BooleanType | IntegerType | LongType | FloatType | DoubleType | + StringType | BinaryType => + return checkScalarEncoding(elementType, protoTypeName, enumsAsInts) + case _: StructType => + return (true, None, GpuFromProtobuf.ENC_DEFAULT) + case _ => + return ( + false, + Some(s"unsupported repeated element type: $elementType"), + GpuFromProtobuf.ENC_DEFAULT) + } + case _ => + return ( + false, + Some(s"repeated field should map to ArrayType, got: $sparkType"), + GpuFromProtobuf.ENC_DEFAULT) + } + } + + if (protoTypeName == "MESSAGE") { + sparkType match { + case _: StructType => + return 
(true, None, GpuFromProtobuf.ENC_DEFAULT) + case _ => + return ( + false, + Some(s"nested message should map to StructType, got: $sparkType"), + GpuFromProtobuf.ENC_DEFAULT) + } + } + + sparkType match { + case BooleanType | IntegerType | LongType | FloatType | DoubleType | + StringType | BinaryType => + case other => + return ( + false, + Some(s"unsupported Spark type: $other"), + GpuFromProtobuf.ENC_DEFAULT) + } + + checkScalarEncoding(sparkType, protoTypeName, enumsAsInts) + } + + def checkScalarEncoding( + sparkType: DataType, + protoTypeName: String, + enumsAsInts: Boolean): (Boolean, Option[String], Int) = { + val encoding = (sparkType, protoTypeName) match { + case (BooleanType, "BOOL") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (IntegerType, "INT32" | "UINT32") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (IntegerType, "SINT32") => Some(GpuFromProtobuf.ENC_ZIGZAG) + case (IntegerType, "FIXED32" | "SFIXED32") => Some(GpuFromProtobuf.ENC_FIXED) + case (LongType, "INT64" | "UINT64") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (LongType, "SINT64") => Some(GpuFromProtobuf.ENC_ZIGZAG) + case (LongType, "FIXED64" | "SFIXED64") => Some(GpuFromProtobuf.ENC_FIXED) + case (LongType, "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32") => + val enc = protoTypeName match { + case "SINT32" => GpuFromProtobuf.ENC_ZIGZAG + case "FIXED32" | "SFIXED32" => GpuFromProtobuf.ENC_FIXED + case _ => GpuFromProtobuf.ENC_DEFAULT + } + Some(enc) + case (FloatType, "FLOAT") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (DoubleType, "DOUBLE") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (StringType, "STRING") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (BinaryType, "BYTES") => Some(GpuFromProtobuf.ENC_DEFAULT) + case (IntegerType, "ENUM") if enumsAsInts => Some(GpuFromProtobuf.ENC_DEFAULT) + case (StringType, "ENUM") if !enumsAsInts => Some(GpuFromProtobuf.ENC_ENUM_STRING) + case _ => None + } + + encoding match { + case Some(enc) => (true, None, enc) + case None => + val reason = 
(sparkType, protoTypeName) match { + case (DoubleType, "FLOAT") => + "Spark DoubleType mapped to Protobuf FLOAT is not yet supported on GPU; " + + "use FloatType or fall back to CPU" + case (FloatType, "DOUBLE") => + "Spark FloatType mapped to Protobuf DOUBLE is not yet supported on GPU; " + + "use DoubleType or fall back to CPU" + case _ => + s"type mismatch: Spark $sparkType vs Protobuf $protoTypeName" + } + (false, + Some(reason), + GpuFromProtobuf.ENC_DEFAULT) + } + } + + def getWireType(protoTypeName: String, encoding: Int): Either[String, Int] = { + val wireType = protoTypeName match { + case "BOOL" | "INT32" | "UINT32" | "SINT32" | "INT64" | "UINT64" | "SINT64" | "ENUM" => + if (encoding == GpuFromProtobuf.ENC_FIXED) { + if (protoTypeName.contains("64")) { + WT_64BIT + } else { + WT_32BIT + } + } else { + WT_VARINT + } + case "FIXED32" | "SFIXED32" | "FLOAT" => + WT_32BIT + case "FIXED64" | "SFIXED64" | "DOUBLE" => + WT_64BIT + case "STRING" | "BYTES" | "MESSAGE" => + WT_LEN + case other => + return Left( + s"Unknown protobuf type name '$other' - cannot determine wire type; falling back to CPU") + } + Right(wireType) + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala new file mode 100644 index 00000000000..cb8c7374baf --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaModel.scala @@ -0,0 +1,226 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.protobuf + +import java.util.Arrays + +import org.apache.spark.sql.types.DataType + +sealed trait ProtobufDescriptorSource + +object ProtobufDescriptorSource { + final case class DescriptorPath(path: String) extends ProtobufDescriptorSource + final case class DescriptorBytes(bytes: Array[Byte]) extends ProtobufDescriptorSource { + override def equals(other: Any): Boolean = other match { + case DescriptorBytes(otherBytes) => Arrays.equals(bytes, otherBytes) + case _ => false + } + + override def hashCode(): Int = Arrays.hashCode(bytes) + } +} + +final case class ProtobufExprInfo( + messageName: String, + descriptorSource: ProtobufDescriptorSource, + options: Map[String, String]) + +final case class ProtobufPlannerOptions( + enumsAsInts: Boolean, + failOnErrors: Boolean) + +sealed trait ProtobufDefaultValue + +object ProtobufDefaultValue { + final case class BoolValue(value: Boolean) extends ProtobufDefaultValue + final case class IntValue(value: Long) extends ProtobufDefaultValue + final case class FloatValue(value: Float) extends ProtobufDefaultValue + final case class DoubleValue(value: Double) extends ProtobufDefaultValue + final case class StringValue(value: String) extends ProtobufDefaultValue + final case class BinaryValue(value: Array[Byte]) extends ProtobufDefaultValue { + override def equals(other: Any): Boolean = other match { + case BinaryValue(otherBytes) => Arrays.equals(value, otherBytes) + case _ => false + } + + override def hashCode(): Int = Arrays.hashCode(value) + } + final case class 
EnumValue(number: Int, name: String) extends ProtobufDefaultValue +} + +final case class ProtobufEnumValue(number: Int, name: String) + +final case class ProtobufEnumMetadata(values: Seq[ProtobufEnumValue]) { + lazy val validValues: Array[Int] = values.map(_.number).toArray + lazy val orderedNames: Array[Array[Byte]] = values.map(_.name.getBytes("UTF-8")).toArray + lazy val namesByNumber: Map[Int, String] = values.map(v => v.number -> v.name).toMap + + def enumDefault(number: Int): ProtobufDefaultValue.EnumValue = { + val name = namesByNumber.getOrElse(number, s"$number") + ProtobufDefaultValue.EnumValue(number, name) + } +} + +trait ProtobufMessageDescriptor { + def syntax: String + def findField(name: String): Option[ProtobufFieldDescriptor] +} + +trait ProtobufFieldDescriptor { + def name: String + def fieldNumber: Int + def protoTypeName: String + def isRepeated: Boolean + def isRequired: Boolean + def defaultValueResult: Either[String, Option[ProtobufDefaultValue]] + def enumMetadata: Option[ProtobufEnumMetadata] + def messageDescriptor: Option[ProtobufMessageDescriptor] +} + +final case class ProtobufFieldInfo( + fieldNumber: Int, + protoTypeName: String, + sparkType: DataType, + encoding: Int, + isSupported: Boolean, + unsupportedReason: Option[String], + isRequired: Boolean, + defaultValue: Option[ProtobufDefaultValue], + enumMetadata: Option[ProtobufEnumMetadata], + isRepeated: Boolean = false) { + def hasDefaultValue: Boolean = defaultValue.isDefined +} + +final case class FlattenedFieldDescriptor( + fieldNumber: Int, + parentIdx: Int, + depth: Int, + wireType: Int, + outputTypeId: Int, + encoding: Int, + isRepeated: Boolean, + isRequired: Boolean, + hasDefaultValue: Boolean, + defaultInt: Long, + defaultFloat: Double, + defaultBool: Boolean, + defaultString: Array[Byte], + enumValidValues: Array[Int], + enumNames: Array[Array[Byte]]) { + override def equals(other: Any): Boolean = other match { + case that: FlattenedFieldDescriptor => + fieldNumber == 
that.fieldNumber && + parentIdx == that.parentIdx && + depth == that.depth && + wireType == that.wireType && + outputTypeId == that.outputTypeId && + encoding == that.encoding && + isRepeated == that.isRepeated && + isRequired == that.isRequired && + hasDefaultValue == that.hasDefaultValue && + defaultInt == that.defaultInt && + java.lang.Double.compare(defaultFloat, that.defaultFloat) == 0 && + defaultBool == that.defaultBool && + Arrays.equals(defaultString, that.defaultString) && + Arrays.equals(enumValidValues, that.enumValidValues) && + Arrays.deepEquals( + enumNames.asInstanceOf[Array[Object]], + that.enumNames.asInstanceOf[Array[Object]]) + case _ => false + } + + override def hashCode(): Int = { + var result = fieldNumber + result = 31 * result + parentIdx + result = 31 * result + depth + result = 31 * result + wireType + result = 31 * result + outputTypeId + result = 31 * result + encoding + result = 31 * result + isRepeated.hashCode() + result = 31 * result + isRequired.hashCode() + result = 31 * result + hasDefaultValue.hashCode() + result = 31 * result + defaultInt.hashCode() + result = 31 * result + defaultFloat.hashCode() + result = 31 * result + defaultBool.hashCode() + result = 31 * result + Arrays.hashCode(defaultString) + result = 31 * result + Arrays.hashCode(enumValidValues) + result = 31 * result + Arrays.deepHashCode(enumNames.asInstanceOf[Array[Object]]) + result + } +} + +final case class FlattenedSchemaArrays( + fieldNumbers: Array[Int], + parentIndices: Array[Int], + depthLevels: Array[Int], + wireTypes: Array[Int], + outputTypeIds: Array[Int], + encodings: Array[Int], + isRepeated: Array[Boolean], + isRequired: Array[Boolean], + hasDefaultValue: Array[Boolean], + defaultInts: Array[Long], + defaultFloats: Array[Double], + defaultBools: Array[Boolean], + defaultStrings: Array[Array[Byte]], + enumValidValues: Array[Array[Int]], + enumNames: Array[Array[Array[Byte]]]) { + override def equals(other: Any): Boolean = other match { + case that: 
FlattenedSchemaArrays => + Arrays.equals(fieldNumbers, that.fieldNumbers) && + Arrays.equals(parentIndices, that.parentIndices) && + Arrays.equals(depthLevels, that.depthLevels) && + Arrays.equals(wireTypes, that.wireTypes) && + Arrays.equals(outputTypeIds, that.outputTypeIds) && + Arrays.equals(encodings, that.encodings) && + Arrays.equals(isRepeated, that.isRepeated) && + Arrays.equals(isRequired, that.isRequired) && + Arrays.equals(hasDefaultValue, that.hasDefaultValue) && + Arrays.equals(defaultInts, that.defaultInts) && + Arrays.equals(defaultFloats, that.defaultFloats) && + Arrays.equals(defaultBools, that.defaultBools) && + Arrays.deepEquals( + defaultStrings.asInstanceOf[Array[Object]], + that.defaultStrings.asInstanceOf[Array[Object]]) && + Arrays.deepEquals( + enumValidValues.asInstanceOf[Array[Object]], + that.enumValidValues.asInstanceOf[Array[Object]]) && + Arrays.deepEquals( + enumNames.asInstanceOf[Array[Object]], + that.enumNames.asInstanceOf[Array[Object]]) + case _ => false + } + + override def hashCode(): Int = { + var result = Arrays.hashCode(fieldNumbers) + result = 31 * result + Arrays.hashCode(parentIndices) + result = 31 * result + Arrays.hashCode(depthLevels) + result = 31 * result + Arrays.hashCode(wireTypes) + result = 31 * result + Arrays.hashCode(outputTypeIds) + result = 31 * result + Arrays.hashCode(encodings) + result = 31 * result + Arrays.hashCode(isRepeated) + result = 31 * result + Arrays.hashCode(isRequired) + result = 31 * result + Arrays.hashCode(hasDefaultValue) + result = 31 * result + Arrays.hashCode(defaultInts) + result = 31 * result + Arrays.hashCode(defaultFloats) + result = 31 * result + Arrays.hashCode(defaultBools) + result = 31 * result + Arrays.deepHashCode(defaultStrings.asInstanceOf[Array[Object]]) + result = 31 * result + Arrays.deepHashCode(enumValidValues.asInstanceOf[Array[Object]]) + result = 31 * result + Arrays.deepHashCode(enumNames.asInstanceOf[Array[Object]]) + result + } +} diff --git 
a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala new file mode 100644 index 00000000000..31758c97236 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.protobuf + +import ai.rapids.cudf.DType + +import org.apache.spark.sql.rapids.GpuFromProtobuf +import org.apache.spark.sql.types._ + +object ProtobufSchemaValidator { + private final case class JniDefaultValues( + defaultInt: Long, + defaultFloat: Double, + defaultBool: Boolean, + defaultString: Array[Byte]) + + def toFlattenedFieldDescriptor( + path: String, + field: StructField, + fieldInfo: ProtobufFieldInfo, + parentIdx: Int, + depth: Int, + outputTypeId: Int): Either[String, FlattenedFieldDescriptor] = { + validateFieldInfo(path, field, fieldInfo).flatMap { _ => + ProtobufSchemaExtractor + .getWireType(fieldInfo.protoTypeName, fieldInfo.encoding) + .flatMap { wireType => + encodeDefaultValue(path, field.dataType, fieldInfo).map { defaults => + val enumValidValues = fieldInfo.enumMetadata.map(_.validValues).orNull + val enumNames = + if (fieldInfo.encoding == GpuFromProtobuf.ENC_ENUM_STRING) { + fieldInfo.enumMetadata.map(_.orderedNames).orNull + } 
else { + null + } + + FlattenedFieldDescriptor( + fieldNumber = fieldInfo.fieldNumber, + parentIdx = parentIdx, + depth = depth, + wireType = wireType, + outputTypeId = outputTypeId, + encoding = fieldInfo.encoding, + isRepeated = fieldInfo.isRepeated, + isRequired = fieldInfo.isRequired, + hasDefaultValue = fieldInfo.hasDefaultValue, + defaultInt = defaults.defaultInt, + defaultFloat = defaults.defaultFloat, + defaultBool = defaults.defaultBool, + defaultString = defaults.defaultString, + enumValidValues = enumValidValues, + enumNames = enumNames + ) + } + } + } + } + + def validateFlattenedSchema(flatFields: Seq[FlattenedFieldDescriptor]): Either[String, Unit] = { + val structTypeId = DType.STRUCT.getTypeId.getNativeId + flatFields.zipWithIndex.foreach { case (field, idx) => + if (field.parentIdx >= idx) { + return Left(s"Flattened protobuf schema has invalid parent index at position $idx") + } + if (field.parentIdx == -1 && field.depth != 0) { + return Left(s"Top-level protobuf field at position $idx must have depth 0") + } + if (field.parentIdx >= 0 && field.depth <= 0) { + return Left(s"Nested protobuf field at position $idx must have positive depth") + } + if (field.parentIdx >= 0 && flatFields(field.parentIdx).outputTypeId != structTypeId) { + return Left( + s"Protobuf field at position $idx has a non-STRUCT parent at ${field.parentIdx}") + } + if (field.encoding == GpuFromProtobuf.ENC_ENUM_STRING) { + if (field.enumValidValues == null || field.enumNames == null) { + return Left(s"Enum-string field at position $idx is missing enum metadata") + } + if (field.enumValidValues.length != field.enumNames.length) { + return Left(s"Enum-string field at position $idx has mismatched enum metadata") + } + } + } + Right(()) + } + + def toFlattenedSchemaArrays( + flatFields: Array[FlattenedFieldDescriptor]): FlattenedSchemaArrays = { + FlattenedSchemaArrays( + fieldNumbers = flatFields.map(_.fieldNumber), + parentIndices = flatFields.map(_.parentIdx), + depthLevels = 
flatFields.map(_.depth), + wireTypes = flatFields.map(_.wireType), + outputTypeIds = flatFields.map(_.outputTypeId), + encodings = flatFields.map(_.encoding), + isRepeated = flatFields.map(_.isRepeated), + isRequired = flatFields.map(_.isRequired), + hasDefaultValue = flatFields.map(_.hasDefaultValue), + defaultInts = flatFields.map(_.defaultInt), + defaultFloats = flatFields.map(_.defaultFloat), + defaultBools = flatFields.map(_.defaultBool), + defaultStrings = flatFields.map(_.defaultString), + enumValidValues = flatFields.map(_.enumValidValues), + enumNames = flatFields.map(_.enumNames) + ) + } + + private def validateFieldInfo( + path: String, + field: StructField, + fieldInfo: ProtobufFieldInfo): Either[String, Unit] = { + if (fieldInfo.isRepeated && fieldInfo.hasDefaultValue) { + return Left(s"Repeated protobuf field '$path' cannot carry a default value") + } + + fieldInfo.enumMetadata match { + case Some(enumMeta) if enumMeta.values.isEmpty => + return Left(s"Enum field '$path' is missing enum values") + case Some(_) if fieldInfo.protoTypeName != "ENUM" => + return Left(s"Non-enum field '$path' should not carry enum metadata") + case None if fieldInfo.protoTypeName == "ENUM" => + return Left(s"Enum field '$path' is missing enum metadata") + case _ => + } + + if (fieldInfo.encoding == GpuFromProtobuf.ENC_ENUM_STRING && + fieldInfo.enumMetadata.isEmpty) { + return Left(s"Enum-string field '$path' is missing enum metadata") + } + + Right(()) + } + + private def encodeDefaultValue( + path: String, + dataType: DataType, + fieldInfo: ProtobufFieldInfo): Either[String, JniDefaultValues] = { + val empty = JniDefaultValues(0L, 0.0, defaultBool = false, null) + fieldInfo.defaultValue match { + case None => Right(empty) + case Some(defaultValue) => + val targetType = dataType match { + case ArrayType(elementType, _) => elementType + case other => other + } + (targetType, defaultValue) match { + case (BooleanType, ProtobufDefaultValue.BoolValue(value)) => + 
Right(empty.copy(defaultBool = value)) + case (IntegerType | LongType, ProtobufDefaultValue.IntValue(value)) => + Right(empty.copy(defaultInt = value)) + case (IntegerType | LongType, ProtobufDefaultValue.EnumValue(number, _)) => + Right(empty.copy(defaultInt = number.toLong)) + case (FloatType, ProtobufDefaultValue.FloatValue(value)) => + Right(empty.copy(defaultFloat = value.toDouble)) + case (DoubleType, ProtobufDefaultValue.DoubleValue(value)) => + Right(empty.copy(defaultFloat = value)) + case (StringType, ProtobufDefaultValue.StringValue(value)) => + Right(empty.copy(defaultString = value.getBytes("UTF-8"))) + case (StringType, ProtobufDefaultValue.EnumValue(number, name)) => + Right(empty.copy( + defaultInt = number.toLong, + defaultString = name.getBytes("UTF-8"))) + case (BinaryType, ProtobufDefaultValue.BinaryValue(value)) => + Right(empty.copy(defaultString = value)) + case _ => + Left(s"Incompatible default value for protobuf field '$path': $defaultValue") + } + } + } +} diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala new file mode 100644 index 00000000000..13a5f7f6a0c --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala @@ -0,0 +1,910 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +{"spark": "342"} +{"spark": "343"} +{"spark": "344"} +{"spark": "350"} +{"spark": "351"} +{"spark": "352"} +{"spark": "353"} +{"spark": "354"} +{"spark": "355"} +{"spark": "356"} +{"spark": "357"} +{"spark": "400"} +{"spark": "401"} +{"spark": "402"} +{"spark": "411"} +spark-rapids-shim-json-lines ***/ + +package com.nvidia.spark.rapids.shims + +import scala.collection.mutable + +import ai.rapids.cudf.DType +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.catalyst.expressions.{ + AttributeReference, Expression, GetArrayStructFields, GetStructField, UnaryExpression +} +import org.apache.spark.sql.execution.ProjectExec +import org.apache.spark.sql.rapids.GpuFromProtobuf +import org.apache.spark.sql.rapids.protobuf.{ + FlattenedFieldDescriptor, + ProtobufFieldInfo, + ProtobufMessageDescriptor, + ProtobufSchemaExtractor, + ProtobufSchemaValidator +} +import org.apache.spark.sql.types._ + +/** + * Spark 3.4+ optional integration for spark-protobuf expressions. + * + * spark-protobuf is an external module, so these rules must be registered by reflection. 
+ */ +object ProtobufExprShims extends org.apache.spark.internal.Logging { + private[this] val protobufDataToCatalystClassName = + "org.apache.spark.sql.protobuf.ProtobufDataToCatalyst" + + val PRUNED_ORDINAL_TAG = + org.apache.spark.sql.rapids.GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG + + def exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + try { + val clazz = ShimReflectionUtils.loadClass(protobufDataToCatalystClassName) + .asInstanceOf[Class[_ <: UnaryExpression]] + Map(clazz.asInstanceOf[Class[_ <: Expression]] -> fromProtobufRule) + } catch { + case _: ClassNotFoundException => Map.empty + case e: Exception => + logWarning(s"Failed to load $protobufDataToCatalystClassName: ${e.getMessage}") + Map.empty + case e: Error => + logWarning( + s"JVM error while loading $protobufDataToCatalystClassName: ${e.getMessage}") + Map.empty + } + } + + private def fromProtobufRule: ExprRule[_ <: Expression] = { + GpuOverrides.expr[UnaryExpression]( + "Decode a BinaryType column (protobuf) into a Spark SQL struct", + ExprChecks.unaryProject( + // Use TypeSig.all here because schema projection determines which fields + // actually need GPU support. Detailed type checking is done in tagExprForGpu. 
+ TypeSig.all, + TypeSig.all, + TypeSig.BINARY, + TypeSig.BINARY), + (e, conf, p, r) => new UnaryExprMeta[UnaryExpression](e, conf, p, r) { + + private var fullSchema: StructType = _ + private var failOnErrors: Boolean = _ + + // Flattened schema variables for GPU decoding + private var flatFieldNumbers: Array[Int] = _ + private var flatParentIndices: Array[Int] = _ + private var flatDepthLevels: Array[Int] = _ + private var flatWireTypes: Array[Int] = _ + private var flatOutputTypeIds: Array[Int] = _ + private var flatEncodings: Array[Int] = _ + private var flatIsRepeated: Array[Boolean] = _ + private var flatIsRequired: Array[Boolean] = _ + private var flatHasDefaultValue: Array[Boolean] = _ + private var flatDefaultInts: Array[Long] = _ + private var flatDefaultFloats: Array[Double] = _ + private var flatDefaultBools: Array[Boolean] = _ + private var flatDefaultStrings: Array[Array[Byte]] = _ + private var flatEnumValidValues: Array[Array[Int]] = _ + private var flatEnumNames: Array[Array[Array[Byte]]] = _ + // Indices in fullSchema for top-level fields that were decoded (for schema projection) + private var decodedTopLevelIndices: Array[Int] = _ + + override def tagExprForGpu(): Unit = { + fullSchema = e.dataType match { + case st: StructType => st + case other => + willNotWorkOnGpu( + s"Only StructType output is supported for from_protobuf, got $other") + return + } + + val exprInfo = SparkProtobufCompat.extractExprInfo(e) match { + case Right(info) => info + case Left(reason) => + willNotWorkOnGpu(reason) + return + } + val unsupportedOptions = SparkProtobufCompat.unsupportedOptions(exprInfo.options) + if (unsupportedOptions.nonEmpty) { + val keys = unsupportedOptions.mkString(",") + willNotWorkOnGpu( + s"from_protobuf options are not supported yet on GPU: $keys") + return + } + + val plannerOptions = SparkProtobufCompat.parsePlannerOptions(exprInfo.options) match { + case Right(opts) => opts + case Left(reason) => + willNotWorkOnGpu(reason) + return + } + 
val enumsAsInts = plannerOptions.enumsAsInts + failOnErrors = plannerOptions.failOnErrors + val messageName = exprInfo.messageName + + val msgDesc = SparkProtobufCompat.resolveMessageDescriptor(exprInfo) match { + case Right(desc) => desc + case Left(reason) => + willNotWorkOnGpu(reason) + return + } + + // Reject proto3 descriptors — GPU decoder only supports proto2 semantics. + // proto3 has different null/default-value behavior that the GPU path doesn't handle. + val protoSyntax = msgDesc.syntax + if (!SparkProtobufCompat.isGpuSupportedProtoSyntax(protoSyntax)) { + willNotWorkOnGpu( + "proto3/editions syntax is not supported by the GPU protobuf decoder; " + + "only proto2 is supported. The query will fall back to CPU.") + return + } + + // Step 1: Analyze all fields and build field info map + val fieldsInfoMap = + ProtobufSchemaExtractor.analyzeAllFields(fullSchema, msgDesc, enumsAsInts, messageName) + .fold({ reason => + willNotWorkOnGpu(reason) + return + }, identity) + + // Step 2: Determine which fields are actually required by downstream operations + val requiredFieldNames = analyzeRequiredFields(fieldsInfoMap.keySet) + + // Step 2b: Proto2 required fields must always be decoded so the GPU can + // detect missing-required and null the struct row (PERMISSIVE) or throw + // (FAILFAST), matching CPU behavior. Without this, schema projection + // can prune a required field and the GPU silently produces a non-null + // struct where CPU would have returned null. + val protoRequiredFieldNames = fieldsInfoMap.collect { + case (name, info) if info.isRequired => name + }.toSet + val allFieldsToDecode = requiredFieldNames ++ protoRequiredFieldNames + + // Step 2c: When nested pruning selects only some children of a struct, + // proto2 required children must still be included so their presence + // can be checked by the GPU decoder. 
+ augmentNestedRequirementsWithRequired(msgDesc) + + // Step 3: Check if all fields to be decoded are supported + val unsupportedRequired = allFieldsToDecode.filter { name => + fieldsInfoMap.get(name).exists(!_.isSupported) + } + + if (unsupportedRequired.nonEmpty) { + val reasons = unsupportedRequired.map { name => + val info = fieldsInfoMap(name) + s"${name}: ${info.unsupportedReason.getOrElse("unknown reason")}" + } + willNotWorkOnGpu( + s"Required fields not supported for from_protobuf: ${reasons.mkString(", ")}") + return + } + + // Step 4: Identify which fields in fullSchema need to be decoded + val indicesToDecode = fullSchema.fields.zipWithIndex.collect { + case (sf, idx) if allFieldsToDecode.contains(sf.name) => idx + } + + // Verify all fields to be decoded are actually supported + // (This catches edge cases where field analysis might have issues) + val unsupportedInDecode = indicesToDecode.filter { idx => + val sf = fullSchema.fields(idx) + fieldsInfoMap.get(sf.name).exists(!_.isSupported) + } + if (unsupportedInDecode.nonEmpty) { + val reasons = unsupportedInDecode.map { idx => + val sf = fullSchema.fields(idx) + val info = fieldsInfoMap(sf.name) + s"${sf.name}: ${info.unsupportedReason.getOrElse("unknown reason")}" + } + willNotWorkOnGpu( + s"Fields not supported for from_protobuf: ${reasons.mkString(", ")}") + return + } + + // Step 5: Build flattened schema for GPU decoding. + // The flattened schema represents nested fields with parent indices. + // For pure scalar schemas, all fields are top-level (parentIdx == -1, depth == 0). + { + val flatFields = mutable.ArrayBuffer[FlattenedFieldDescriptor]() + var step5Failed = false + + def failStep5(reason: String): Unit = { + step5Failed = true + willNotWorkOnGpu(reason) + } + + // Helper to add a field and its children recursively. + // pathPrefix is the dot-path of ancestor fields (empty for top-level). 
+ def addFieldWithChildren( + sf: StructField, + info: ProtobufFieldInfo, + parentIdx: Int, + depth: Int, + containingMsgDesc: ProtobufMessageDescriptor, + pathPrefix: String = ""): Unit = { + + val currentIdx = flatFields.size + + if (depth >= 10) { + failStep5("Protobuf nesting depth exceeds maximum supported depth of 10") + return + } + + val outputTypeOpt = sf.dataType match { + case ArrayType(elemType, _) => + elemType match { + case _: StructType => + // Repeated message field: ArrayType(StructType) - element type is STRUCT + Some(DType.STRUCT.getTypeId.getNativeId) + case other => + GpuFromProtobuf.sparkTypeToCudfIdOpt(other) + } + case _: StructType => + Some(DType.STRUCT.getTypeId.getNativeId) + case other => + GpuFromProtobuf.sparkTypeToCudfIdOpt(other) + } + val outputType = outputTypeOpt.getOrElse { + failStep5( + s"Unsupported Spark type for protobuf field '${sf.name}': ${sf.dataType}") + return + } + + val path = if (pathPrefix.isEmpty) sf.name else s"$pathPrefix.${sf.name}" + ProtobufSchemaValidator.toFlattenedFieldDescriptor( + path, + sf, + info, + parentIdx, + depth, + outputType).fold({ reason => + failStep5(reason) + return + }, flatFields += _) + + // For nested struct types (including repeated message = ArrayType(StructType)), + // add child fields + sf.dataType match { + case st: StructType if containingMsgDesc != null => + // Repeated message parents and plain struct parents share the same child + // expansion path; the flat parent entry's isRepeated flag distinguishes them. + addChildFieldsFromStruct( + st, containingMsgDesc, sf.name, currentIdx, depth, pathPrefix) + + case ArrayType(st: StructType, _) if containingMsgDesc != null => + addChildFieldsFromStruct( + st, containingMsgDesc, sf.name, currentIdx, depth, pathPrefix) + + case _ => // Not a struct, no children to add + } + } + + // Helper to add child fields from a struct type. 
+ // Applies nested schema pruning at arbitrary depth using path-based + // lookup into nestedFieldRequirements. + def addChildFieldsFromStruct( + st: StructType, + containingMsgDesc: ProtobufMessageDescriptor, + fieldName: String, + parentIdx: Int, + parentDepth: Int, + pathPrefix: String): Unit = { + val path = if (pathPrefix.isEmpty) fieldName else s"$pathPrefix.$fieldName" + // containingMsgDesc is the descriptor of the message that directly contains + // fieldName. + val parentField = containingMsgDesc.findField(fieldName) + if (parentField.isEmpty) { + failStep5( + s"Nested field '$fieldName' not found in protobuf descriptor at '$path'") + return + } else { + parentField.get.messageDescriptor match { + case Some(childMsgDesc) => + val requiredChildren = nestedFieldRequirements.get(path) + val filteredFields = requiredChildren match { + case Some(Some(childNames)) => + st.fields.filter(f => childNames.contains(f.name)) + case _ => + st.fields + } + filteredFields.foreach { childSf => + childMsgDesc.findField(childSf.name) match { + case None => + failStep5( + s"Nested field '${childSf.name}' not found in protobuf " + + s"descriptor for message at '$path'") + return + case Some(childFd) => + ProtobufSchemaExtractor + .extractFieldInfo(childSf, childFd, enumsAsInts) match { + case Left(reason) => + failStep5(reason) + return + case Right(childInfo) => + if (!childInfo.isSupported) { + failStep5( + s"Nested field '${childSf.name}' at '$path': " + + childInfo.unsupportedReason.getOrElse("unsupported type")) + return + } else { + addFieldWithChildren( + childSf, childInfo, parentIdx, parentDepth + 1, childMsgDesc, + path) + } + } + } + } + case None => + failStep5( + s"Nested field '$fieldName' at '$path' did not resolve to a message type") + return + } + } + } + + // Only add top-level fields that are actually required (schema projection). + // This significantly reduces GPU memory and computation for schemas with many + // fields when only a few are needed. 
Downstream GetStructField ordinals are + // remapped via PRUNED_ORDINAL_TAG to index into the pruned output. + decodedTopLevelIndices = indicesToDecode + indicesToDecode.foreach { schemaIdx => + if (!step5Failed) { + val sf = fullSchema.fields(schemaIdx) + val info = fieldsInfoMap(sf.name) + addFieldWithChildren(sf, info, -1, 0, msgDesc) + } + } + + if (step5Failed) { + return + } + + // Populate flattened schema variables + val flat = flatFields.toArray + ProtobufSchemaValidator.validateFlattenedSchema(flat).fold({ reason => + failStep5(reason) + return + }, identity) + val arrays = ProtobufSchemaValidator.toFlattenedSchemaArrays(flat) + flatFieldNumbers = arrays.fieldNumbers + flatParentIndices = arrays.parentIndices + flatDepthLevels = arrays.depthLevels + flatWireTypes = arrays.wireTypes + flatOutputTypeIds = arrays.outputTypeIds + flatEncodings = arrays.encodings + flatIsRepeated = arrays.isRepeated + flatIsRequired = arrays.isRequired + flatHasDefaultValue = arrays.hasDefaultValue + flatDefaultInts = arrays.defaultInts + flatDefaultFloats = arrays.defaultFloats + flatDefaultBools = arrays.defaultBools + flatDefaultStrings = arrays.defaultStrings + flatEnumValidValues = arrays.enumValidValues + flatEnumNames = arrays.enumNames + + val prunedFieldsMap = buildPrunedFieldsMap() + // PRUNED_ORDINAL_TAG is set here, after all willNotWorkOnGpu guards succeed. + // This is safe because the CPU path never reads the tag, and if a parent later + // forces this subtree back to CPU, the decode and its field extractors fall back + // together, so no partial-GPU path can misread stale ordinals. + targetExprsToRemap.foreach( + registerPrunedOrdinals(_, prunedFieldsMap, decodedTopLevelIndices.toSeq)) + overrideDataType(buildDecodedSchema(prunedFieldsMap)) + } + } + + /** + * Analyze which fields are actually required by downstream operations. 
+ * Traverses parent plan nodes upward, collecting struct field references from + * ProjectExec, FilterExec, and transparent pass-through nodes (AggregateExec, + * SortExec, WindowExec, etc.), then returns the set of required top-level + * field names. + * + * @param allFieldNames All field names in the full schema + * @return Set of field names that are actually required + */ + private var targetExprsToRemap: Seq[Expression] = Seq.empty + + private def analyzeRequiredFields(allFieldNames: Set[String]): Set[String] = { + val fieldReqs = mutable.Map[String, Option[Set[String]]]() + protobufOutputExprIds = Set.empty + var hasDirectStructRef = false + val holder = () => { hasDirectStructRef = true } + + var currentMeta: Option[SparkPlanMeta[_]] = findParentPlanMeta() + var safeToPrune = true + val collectedExprs = mutable.ArrayBuffer[Expression]() + val startingPlanMeta = currentMeta + + def advanceToParent(): Unit = { + currentMeta = currentMeta.get.parent match { + case Some(pm: SparkPlanMeta[_]) => Some(pm) + case _ => None + } + } + + while (currentMeta.isDefined && safeToPrune) { + val allowSemanticReferenceMatch = currentMeta == startingPlanMeta + currentMeta.get.wrapped match { + case p: ProjectExec => + collectedExprs ++= p.projectList + p.projectList.foreach { + case alias: org.apache.spark.sql.catalyst.expressions.Alias + if isProtobufStructReference( + alias.child, allowSemanticReferenceMatch) => + protobufOutputExprIds += alias.exprId + case _ => + } + p.projectList.foreach( + collectStructFieldReferences( + _, fieldReqs, holder, allowSemanticReferenceMatch)) + advanceToParent() + case f: org.apache.spark.sql.execution.FilterExec => + collectedExprs += f.condition + collectStructFieldReferences( + f.condition, fieldReqs, holder, allowSemanticReferenceMatch) + advanceToParent() + case a: org.apache.spark.sql.execution.aggregate.BaseAggregateExec => + val exprs = a.aggregateExpressions ++ a.groupingExpressions + collectedExprs ++= exprs + exprs.foreach( + 
collectStructFieldReferences( + _, fieldReqs, holder, allowSemanticReferenceMatch)) + advanceToParent() + case s: org.apache.spark.sql.execution.SortExec => + val exprs = s.sortOrder + collectedExprs ++= exprs + exprs.foreach( + collectStructFieldReferences( + _, fieldReqs, holder, allowSemanticReferenceMatch)) + advanceToParent() + case w: org.apache.spark.sql.execution.window.WindowExec => + val exprs = w.windowExpression + collectedExprs ++= exprs + exprs.foreach( + collectStructFieldReferences( + _, fieldReqs, holder, allowSemanticReferenceMatch)) + advanceToParent() + case other => + // Keep schema projection conservative for less common plan shapes above + // from_protobuf. Those plans currently fall back to full-schema decode + // until we add dedicated coverage for pruning through them. + logDebug(s"Schema pruning disabled: unrecognized plan node " + + s"${other.getClass.getSimpleName} above from_protobuf") + safeToPrune = false + } + } + + // An empty fieldReqs also subsumes the "no relevant expressions collected" case. + if (!safeToPrune || hasDirectStructRef || fieldReqs.isEmpty) { + targetExprsToRemap = Seq.empty + allFieldNames + } else { + nestedFieldRequirements = fieldReqs.toMap + targetExprsToRemap = collectedExprs.toSeq + fieldReqs.keySet.toSet + } + } + + /** + * Find the parent SparkPlanMeta by traversing up the parent chain. + */ + private def findParentPlanMeta(): Option[SparkPlanMeta[_]] = { + def traverse(meta: Option[RapidsMeta[_, _, _]]): Option[SparkPlanMeta[_]] = { + meta match { + case Some(p: SparkPlanMeta[_]) => Some(p) + case Some(p: RapidsMeta[_, _, _]) => traverse(p.parent) + case _ => None + } + } + traverse(parent) + } + + /** + * Nested field requirements: maps a field path to child requirements. 
+ * Keys are dot-separated paths from the protobuf root: + * - "level1" -> Some(Set("level2")) (top-level struct pruning) + * - "level1.level2" -> Some(Set("level3")) (deep nested pruning) + * - "field" -> None (whole field needed) + * + * Top-level names (keys without dots) also determine which fields are decoded. + */ + private var nestedFieldRequirements: Map[String, Option[Set[String]]] = Map.empty + private var protobufOutputExprIds: Set[ + org.apache.spark.sql.catalyst.expressions.ExprId] = Set.empty + private lazy val protobufOutputExprId + : Option[org.apache.spark.sql.catalyst.expressions.ExprId] = + parent.flatMap { meta => + meta.wrapped match { + case alias: org.apache.spark.sql.catalyst.expressions + .Alias if alias.child.semanticEquals(e) => + Some(alias.exprId) + case _ => None + } + } + + private def getFieldName(ordinal: Int, nameOpt: Option[String], + schema: StructType): String = { + nameOpt.getOrElse { + if (ordinal < schema.fields.length) schema.fields(ordinal).name + else s"_$ordinal" + } + } + + /** + * Navigate the protobuf descriptor tree by following a path of field names. + * Returns the message descriptor at the end of the path, or None. + */ + private def resolveProtoMsgDesc( + rootDesc: ProtobufMessageDescriptor, + pathParts: Seq[String]): Option[ProtobufMessageDescriptor] = { + pathParts.headOption match { + case None => Some(rootDesc) + case Some(head) => + rootDesc.findField(head).flatMap(_.messageDescriptor).flatMap { childDesc => + if (pathParts.tail.isEmpty) Some(childDesc) + else resolveProtoMsgDesc(childDesc, pathParts.tail) + } + } + } + + /** + * Augment nestedFieldRequirements so that proto2 required children are + * always included when nested pruning is active. Without this, a required + * child that is not referenced downstream would be pruned, and the GPU + * decoder would not check its presence. 
+ */ + private def augmentNestedRequirementsWithRequired( + rootMsgDesc: ProtobufMessageDescriptor): Unit = { + if (nestedFieldRequirements.isEmpty) return + nestedFieldRequirements = nestedFieldRequirements.map { + case entry @ (pathKey, Some(childNames)) => + val pathParts = pathKey.split("\\.").toSeq + resolveProtoMsgDesc(rootMsgDesc, pathParts) match { + case Some(childMsgDesc) => + val schemaType = resolveSchemaAtPath(fullSchema, pathParts) + if (schemaType != null) { + val requiredChildNames = schemaType.fields.flatMap { sf => + childMsgDesc.findField(sf.name).filter(_.isRequired).map(_ => sf.name) + }.toSet + if (requiredChildNames.subsetOf(childNames)) entry + else pathKey -> Some(childNames ++ requiredChildNames) + } else { + entry + } + case None => entry + } + case other => other + } + } + + /** + * Navigate the Spark schema tree by following a dot-separated path of + * field names. Returns the StructType at the end of the path, unwrapping + * ArrayType(StructType) along the way, or null if the path is invalid. + */ + private def resolveSchemaAtPath(root: StructType, path: Seq[String]): StructType = { + var current: StructType = root + for (name <- path) { + val field = current.fields.find(_.name == name).orNull + if (field == null) return null + field.dataType match { + case st: StructType => current = st + case ArrayType(st: StructType, _) => current = st + case _ => return null + } + } + current + } + + /** + * Walk a GetStructField chain upward until it reaches the protobuf + * reference expression, returning the sequence of field names forming + * the access path. Returns None if the chain does not terminate at a + * protobuf reference. 
+ * + * Example: for `GetStructField(GetStructField(decoded, a_ord), b_ord)` + * → Some(Seq("a", "b")) + */ + private def resolveFieldAccessChain( + expr: Expression, + allowSemanticReferenceMatch: Boolean): Option[Seq[String]] = { + expr match { + case GetStructField(child, ordinal, nameOpt) => + if (isProtobufStructReference(child, allowSemanticReferenceMatch)) { + Some(Seq(getFieldName(ordinal, nameOpt, fullSchema))) + } else { + resolveFieldAccessChain(child, allowSemanticReferenceMatch).flatMap { parentPath => + val parentSchema = if (parentPath.isEmpty) fullSchema + else resolveSchemaAtPath(fullSchema, parentPath) + if (parentSchema != null) { + Some(parentPath :+ getFieldName(ordinal, nameOpt, parentSchema)) + } else { + None + } + } + } + case _ if isProtobufStructReference(expr, allowSemanticReferenceMatch) => + Some(Seq.empty) + case _ => + None + } + } + + private def addNestedFieldReq( + fieldReqs: mutable.Map[String, Option[Set[String]]], + parentKey: String, + childName: String): Unit = { + fieldReqs.get(parentKey) match { + case Some(None) => // Already need whole field, keep it + case Some(Some(existing)) => + fieldReqs(parentKey) = Some(existing + childName) + case None => + fieldReqs(parentKey) = Some(Set(childName)) + } + } + + /** + * Register pruning requirements at every level of a field access path. 
+ * For path = ["a", "b"] with leafName = "c": + * "a" -> needs child "b" + * "a.b" -> needs child "c" + */ + private def registerPathRequirements( + fieldReqs: mutable.Map[String, Option[Set[String]]], + path: Seq[String], + leafName: String): Unit = { + for (i <- path.indices) { + val pathKey = path.take(i + 1).mkString(".") + val childName = if (i < path.length - 1) path(i + 1) else leafName + addNestedFieldReq(fieldReqs, pathKey, childName) + } + } + + private def collectStructFieldReferences( + expr: Expression, + fieldReqs: mutable.Map[String, Option[Set[String]]], + hasDirectStructRefHolder: () => Unit, + allowSemanticReferenceMatch: Boolean): Unit = { + expr match { + case GetStructField(child, ordinal, nameOpt) => + resolveFieldAccessChain(child, allowSemanticReferenceMatch) match { + case Some(parentPath) => + val parentSchema = if (parentPath.isEmpty) fullSchema + else resolveSchemaAtPath(fullSchema, parentPath) + if (parentSchema != null) { + val fieldName = getFieldName(ordinal, nameOpt, parentSchema) + if (parentPath.isEmpty) { + // Direct top-level access: decoded.field_name (whole field) + fieldReqs(fieldName) = None + } else { + registerPathRequirements(fieldReqs, parentPath, fieldName) + } + } else { + collectStructFieldReferences( + child, fieldReqs, hasDirectStructRefHolder, allowSemanticReferenceMatch) + } + case None => + collectStructFieldReferences( + child, fieldReqs, hasDirectStructRefHolder, allowSemanticReferenceMatch) + } + + case gasf: GetArrayStructFields => + resolveFieldAccessChain(gasf.child, allowSemanticReferenceMatch) match { + case Some(parentPath) if parentPath.nonEmpty => + registerPathRequirements(fieldReqs, parentPath, gasf.field.name) + case Some(parentPath) if parentPath.isEmpty => + logDebug("Schema pruning disabled: unexpected direct protobuf reference in " + + "GetArrayStructFields") + hasDirectStructRefHolder() + case _ => + gasf.children.foreach { child => + collectStructFieldReferences( + child, fieldReqs, 
hasDirectStructRefHolder, allowSemanticReferenceMatch) + } + } + + case alias: org.apache.spark.sql.catalyst.expressions.Alias => + if (!isProtobufStructReference(alias.child, allowSemanticReferenceMatch)) { + collectStructFieldReferences( + alias.child, fieldReqs, hasDirectStructRefHolder, allowSemanticReferenceMatch) + } + + case _ => + if (isProtobufStructReference(expr, allowSemanticReferenceMatch)) { + hasDirectStructRefHolder() + } + expr.children.foreach { child => + collectStructFieldReferences( + child, fieldReqs, hasDirectStructRefHolder, allowSemanticReferenceMatch) + } + } + } + + private def buildPrunedFieldsMap(): Map[String, Seq[String]] = { + nestedFieldRequirements.collect { + case (pathKey, Some(childNames)) => + val pathParts = pathKey.split("\\.").toSeq + val childSchema = resolveSchemaAtPath(fullSchema, pathParts) + if (childSchema != null) { + val orderedNames = childSchema.fields + .map(_.name) + .filter(childNames.contains) + .toSeq + pathKey -> orderedNames + } else { + // This path should be unreachable for valid schema paths, but keep the fallback + // deterministic rather than relying on Set iteration order. 
+ pathKey -> childNames.toSeq.sorted + } + } + } + + private def buildDecodedSchema(prunedFieldsMap: Map[String, Seq[String]]): StructType = { + def applyPruning(field: StructField, prefix: String): StructField = { + val path = if (prefix.isEmpty) field.name else s"$prefix.${field.name}" + prunedFieldsMap.get(path) match { + case Some(childNames) => + field.dataType match { + case ArrayType(st: StructType, cn) => + val pruned = StructType( + st.fields.filter(f => childNames.contains(f.name)) + .map(f => applyPruning(f, path))) + field.copy(dataType = ArrayType(pruned, cn)) + case st: StructType => + val pruned = StructType( + st.fields.filter(f => childNames.contains(f.name)) + .map(f => applyPruning(f, path))) + field.copy(dataType = pruned) + case _ => field + } + case None => + field.dataType match { + case ArrayType(st: StructType, cn) => + val recursed = StructType(st.fields.map(f => applyPruning(f, path))) + if (recursed != st) field.copy(dataType = ArrayType(recursed, cn)) + else field + case st: StructType => + val recursed = StructType(st.fields.map(f => applyPruning(f, path))) + if (recursed != st) field.copy(dataType = recursed) + else field + case _ => field + } + } + } + + val decodedFields = decodedTopLevelIndices.map { idx => + applyPruning(fullSchema.fields(idx), "") + } + StructType(decodedFields.map(f => f.copy(nullable = true))) + } + + private def registerPrunedOrdinals( + expr: Expression, + prunedFieldsMap: Map[String, Seq[String]], + topLevelIndices: Seq[Int]): Unit = { + expr match { + case gsf @ GetStructField(childExpr, ordinal, nameOpt) => + resolveFieldAccessChain(childExpr, allowSemanticReferenceMatch = true) match { + case Some(parentPath) if parentPath.nonEmpty => + val parentSchema = resolveSchemaAtPath(fullSchema, parentPath) + if (parentSchema != null) { + val pathKey = parentPath.mkString(".") + val childName = getFieldName(ordinal, nameOpt, parentSchema) + prunedFieldsMap.get(pathKey).foreach { orderedChildren => + val runtimeOrd 
= orderedChildren.indexOf(childName) + if (runtimeOrd >= 0) { + gsf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd) + } + } + } + case Some(parentPath) if parentPath.isEmpty => + val runtimeOrd = topLevelIndices.indexOf(ordinal) + if (runtimeOrd >= 0) { + gsf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd) + } + case _ => + } + case gasf @ GetArrayStructFields(childExpr, field, _, _, _) => + resolveFieldAccessChain(childExpr, allowSemanticReferenceMatch = true) match { + case Some(parentPath) if parentPath.nonEmpty => + val pathKey = parentPath.mkString(".") + prunedFieldsMap.get(pathKey).foreach { orderedChildren => + val runtimeOrd = orderedChildren.indexOf(field.name) + if (runtimeOrd >= 0) { + gasf.setTagValue(ProtobufExprShims.PRUNED_ORDINAL_TAG, runtimeOrd) + } + } + case _ => + } + case _ => + } + expr.children.foreach(registerPrunedOrdinals(_, prunedFieldsMap, topLevelIndices)) + } + + /** + * Check if an expression references the output of a protobuf decode expression. + * This can be either: + * 1. The ProtobufDataToCatalyst expression itself + * 2. An AttributeReference that references the output of ProtobufDataToCatalyst + * (when accessing from a downstream ProjectExec) + */ + private def isProtobufStructReference( + expr: Expression, + allowSemanticReferenceMatch: Boolean): Boolean = { + if ((expr eq e) || expr.semanticEquals(e)) { + return true + } + + // Catalyst may create duplicate ProtobufDataToCatalyst + // instances for each GetStructField access. Match copies + // by class + identical input child + identical decode + // semantics so that + // analyzeRequiredFields detects all field accesses in one + // pass, keeping schema projection correct. 
+ if (allowSemanticReferenceMatch && + expr.getClass == e.getClass && + expr.children.nonEmpty && + e.children.nonEmpty && + ((expr.children.head eq e.children.head) || + expr.children.head.semanticEquals( + e.children.head)) && + SparkProtobufCompat.sameDecodeSemantics(expr, e)) { + return true + } + + expr match { + case attr: AttributeReference => + protobufOutputExprIds.contains(attr.exprId) || + protobufOutputExprId.exists(_ == attr.exprId) + case _ => false + } + } + + override def convertToGpu(child: Expression): GpuExpression = { + val prunedFieldsMap = buildPrunedFieldsMap() + val decodedSchema = buildDecodedSchema(prunedFieldsMap) + + GpuFromProtobuf( + decodedSchema, + flatFieldNumbers, flatParentIndices, + flatDepthLevels, flatWireTypes, flatOutputTypeIds, flatEncodings, + flatIsRepeated, flatIsRequired, flatHasDefaultValue, flatDefaultInts, + flatDefaultFloats, flatDefaultBools, flatDefaultStrings, flatEnumValidValues, + flatEnumNames, failOnErrors, child) + } + } + ) + } + +} diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala index 88c62eea41b..afb9c059ccf 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala @@ -160,7 +160,7 @@ trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims { ), GpuElementAtMeta.elementAtRule(true) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap - super.getExprs ++ shimExprs + super.getExprs ++ shimExprs ++ ProtobufExprShims.exprs } override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand], diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala new file 
mode 100644 index 00000000000..85f1c4ac12b --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala @@ -0,0 +1,381 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +{"spark": "342"} +{"spark": "343"} +{"spark": "344"} +{"spark": "350"} +{"spark": "351"} +{"spark": "352"} +{"spark": "353"} +{"spark": "354"} +{"spark": "355"} +{"spark": "356"} +{"spark": "357"} +{"spark": "400"} +{"spark": "401"} +{"spark": "402"} +{"spark": "411"} +spark-rapids-shim-json-lines ***/ + +package com.nvidia.spark.rapids.shims + +import java.lang.reflect.Method +import java.nio.file.{Files, Paths} + +import scala.util.Try + +import com.nvidia.spark.rapids.ShimReflectionUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.rapids.protobuf._ + +private[shims] object SparkProtobufCompat extends Logging { + private[this] val sparkProtobufUtilsObjectClassName = + "org.apache.spark.sql.protobuf.utils.ProtobufUtils$" + + val SupportedOptions: Set[String] = Set("enums.as.ints", "mode") + + def extractExprInfo(e: Expression): Either[String, ProtobufExprInfo] = { + for { + messageName <- reflectMessageName(e) + options <- reflectOptions(e) + descriptorSource <- reflectDescriptorSource(e) + } yield ProtobufExprInfo(messageName, 
descriptorSource, options) + } + + def sameDecodeSemantics(left: Expression, right: Expression): Boolean = { + (extractExprInfo(left), extractExprInfo(right)) match { + case (Right(leftInfo), Right(rightInfo)) => leftInfo == rightInfo + case _ => false + } + } + + def parsePlannerOptions( + options: Map[String, String]): Either[String, ProtobufPlannerOptions] = { + val enumsAsInts = Try(options.getOrElse("enums.as.ints", "false").toBoolean) + .toEither + .left + .map { _ => + "Invalid value for from_protobuf option 'enums.as.ints': " + + s"'${options.getOrElse("enums.as.ints", "")}' (expected true/false)" + } + enumsAsInts.map(v => + ProtobufPlannerOptions( + enumsAsInts = v, + failOnErrors = options.getOrElse("mode", "FAILFAST").equalsIgnoreCase("FAILFAST"))) + } + + def unsupportedOptions(options: Map[String, String]): Seq[String] = + options.keys.filterNot(SupportedOptions.contains).toSeq.sorted + + def isGpuSupportedProtoSyntax(syntax: String): Boolean = + syntax.nonEmpty && syntax != "null" && syntax != "PROTO3" && syntax != "EDITIONS" + + def resolveMessageDescriptor( + exprInfo: ProtobufExprInfo): Either[String, ProtobufMessageDescriptor] = { + Try(buildMessageDescriptor(exprInfo.messageName, exprInfo.descriptorSource)) + .toEither + .left + .map { t => + s"Failed to resolve protobuf descriptor for message '${exprInfo.messageName}': " + + s"${t.getMessage}" + } + .map(new ReflectiveMessageDescriptor(_)) + } + + private def reflectMessageName(e: Expression): Either[String, String] = + Try(PbReflect.invoke0[String](e, "messageName")).toEither.left.map { t => + s"Cannot read from_protobuf messageName via reflection: ${t.getMessage}" + } + + private def reflectOptions(e: Expression): Either[String, Map[String, String]] = { + Try(PbReflect.invoke0[scala.collection.Map[String, String]](e, "options")) + .map(_.toMap) + .toEither.left.map { _ => + "Cannot read from_protobuf options via reflection; falling back to CPU" + } + } + + private def 
reflectDescriptorSource(e: Expression): Either[String, ProtobufDescriptorSource] = { + reflectDescFilePath(e).map(ProtobufDescriptorSource.DescriptorPath).orElse( + reflectDescriptorBytes(e).map(ProtobufDescriptorSource.DescriptorBytes)).toRight( + "from_protobuf requires a descriptor set (descFilePath or binaryFileDescriptorSet)") + } + + private def reflectDescFilePath(e: Expression): Option[String] = + Try(PbReflect.invoke0[Option[String]](e, "descFilePath")).toOption.flatten + + private def reflectDescriptorBytes(e: Expression): Option[Array[Byte]] = { + val spark35Result = Try(PbReflect.invoke0[Option[Array[Byte]]](e, "binaryFileDescriptorSet")) + .toOption.flatten + spark35Result.orElse { + val direct = Try(PbReflect.invoke0[Array[Byte]](e, "binaryDescriptorSet")).toOption + direct.orElse { + Try(PbReflect.invoke0[Option[Array[Byte]]](e, "binaryDescriptorSet")).toOption.flatten + } + } + } + + private def buildMessageDescriptor( + messageName: String, + descriptorSource: ProtobufDescriptorSource): AnyRef = { + val cls = ShimReflectionUtils.loadClass(sparkProtobufUtilsObjectClassName) + val module = cls.getField("MODULE$").get(null) + val buildMethod = cls.getMethod("buildDescriptor", classOf[String], classOf[scala.Option[_]]) + + invokeBuildDescriptor( + buildMethod, + module, + messageName, + descriptorSource, + filePath => Files.readAllBytes(Paths.get(filePath))) + } + + private[shims] def invokeBuildDescriptor( + buildMethod: Method, + module: AnyRef, + messageName: String, + descriptorSource: ProtobufDescriptorSource, + readDescriptorFile: String => Array[Byte]): AnyRef = { + descriptorSource match { + case ProtobufDescriptorSource.DescriptorBytes(bytes) => + buildMethod.invoke(module, messageName, Some(bytes)).asInstanceOf[AnyRef] + case ProtobufDescriptorSource.DescriptorPath(filePath) => + try { + buildMethod.invoke(module, messageName, Some(filePath)).asInstanceOf[AnyRef] + } catch { + case ex: java.lang.reflect.InvocationTargetException => + val 
cause = ex.getCause + // Spark 3.5+ changed the descriptor payload from Option[String] to + // Option[Array[Byte]] while keeping the same erased JVM signature. + // Retry with file contents when the path-based invocation clearly hit that + // binary-descriptor variant. + if (cause != null && (cause.isInstanceOf[ClassCastException] || + cause.isInstanceOf[MatchError])) { + Try { + buildMethod.invoke( + module, messageName, Some(readDescriptorFile(filePath))).asInstanceOf[AnyRef] + }.recoverWith { case retryEx => + val wrapped = buildDescriptorRetryFailure(cause, retryEx) + wrapped.addSuppressed(ex) + scala.util.Failure(wrapped) + }.get + } else { + throw ex + } + } + } + } + + private def buildDescriptorRetryFailure( + originalCause: Throwable, + retryFailure: Throwable): RuntimeException = { + val retryCause = unwrapInvocationFailure(retryFailure) + new RuntimeException( + s"Spark 3.5+ descriptor bytes retry failed after initial path invocation error " + + s"(${describeThrowable(originalCause)}); retry error (${describeThrowable(retryCause)})", + retryCause) + } + + private def unwrapInvocationFailure(t: Throwable): Throwable = t match { + case ex: java.lang.reflect.InvocationTargetException if ex.getCause != null => ex.getCause + case other => other + } + + private def describeThrowable(t: Throwable): String = { + val suffix = Option(t.getMessage).filter(_.nonEmpty).map(msg => s": $msg").getOrElse("") + s"${t.getClass.getSimpleName}$suffix" + } + + private def typeName(t: AnyRef): String = + if (t == null) "" else Try(PbReflect.invoke0[String](t, "name")).getOrElse(t.toString) + + private final class ReflectiveMessageDescriptor(raw: AnyRef) extends ProtobufMessageDescriptor { + override lazy val syntax: String = PbReflect.getFileSyntax(raw, typeName) + + override def findField(name: String): Option[ProtobufFieldDescriptor] = + Option(PbReflect.findFieldByName(raw, name)).map(new ReflectiveFieldDescriptor(_)) + } + + private final class 
ReflectiveFieldDescriptor(raw: AnyRef) extends ProtobufFieldDescriptor { + override lazy val name: String = PbReflect.invoke0[String](raw, "getName") + override lazy val fieldNumber: Int = PbReflect.getFieldNumber(raw) + override lazy val protoTypeName: String = typeName(PbReflect.getFieldType(raw)) + override lazy val isRepeated: Boolean = PbReflect.isRepeated(raw) + override lazy val isRequired: Boolean = PbReflect.isRequired(raw) + override lazy val enumMetadata: Option[ProtobufEnumMetadata] = + if (protoTypeName == "ENUM") { + Some(ProtobufEnumMetadata(PbReflect.getEnumValues(PbReflect.getEnumType(raw)))) + } else { + None + } + override lazy val defaultValueResult: Either[String, Option[ProtobufDefaultValue]] = + Try { + if (PbReflect.hasDefaultValue(raw)) { + PbReflect.getDefaultValue(raw) match { + case Some(default) => + toDefaultValue(default, protoTypeName, enumMetadata).map(Some(_)) + case None => + Right(None) + } + } else { + Right(None) + } + }.toEither.left.map { t => + s"Failed to read protobuf default value for field '$name': ${t.getMessage}" + }.flatMap(identity) + override lazy val messageDescriptor: Option[ProtobufMessageDescriptor] = + if (protoTypeName == "MESSAGE") { + Some(new ReflectiveMessageDescriptor(PbReflect.getMessageType(raw))) + } else { + None + } + } + + private def toDefaultValue( + rawDefault: AnyRef, + protoTypeName: String, + enumMetadata: Option[ProtobufEnumMetadata]): Either[String, ProtobufDefaultValue] = + protoTypeName match { + case "BOOL" => + Right(ProtobufDefaultValue.BoolValue( + rawDefault.asInstanceOf[java.lang.Boolean].booleanValue())) + case "FLOAT" => + Right(ProtobufDefaultValue.FloatValue( + rawDefault.asInstanceOf[java.lang.Float].floatValue())) + case "DOUBLE" => + Right(ProtobufDefaultValue.DoubleValue( + rawDefault.asInstanceOf[java.lang.Double].doubleValue())) + case "STRING" => + Right(ProtobufDefaultValue.StringValue( + if (rawDefault == null) null else rawDefault.toString)) + case "BYTES" => + 
Right(ProtobufDefaultValue.BinaryValue(extractBytes(rawDefault))) + case "ENUM" => + val number = extractNumber(rawDefault).intValue() + Right(enumMetadata.map(_.enumDefault(number)) + .getOrElse(ProtobufDefaultValue.EnumValue(number, rawDefault.toString))) + case "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32" | + "INT64" | "UINT64" | "SINT64" | "FIXED64" | "SFIXED64" => + Right(ProtobufDefaultValue.IntValue(extractNumber(rawDefault).longValue())) + case other => + Left( + s"Unsupported protobuf default value type '$other' for value ${rawDefault.toString}") + } + + private def extractNumber(rawDefault: AnyRef): java.lang.Number = rawDefault match { + case n: java.lang.Number => n + case ref: AnyRef => + Try { + ref.getClass.getMethod("getNumber").invoke(ref).asInstanceOf[java.lang.Number] + }.getOrElse { + throw new IllegalStateException( + s"Unsupported protobuf numeric default value class: ${ref.getClass.getName}") + } + } + + private def extractBytes(rawDefault: AnyRef): Array[Byte] = rawDefault match { + case bytes: Array[Byte] => bytes + case ref: AnyRef => + Try { + ref.getClass.getMethod("toByteArray").invoke(ref).asInstanceOf[Array[Byte]] + }.getOrElse { + throw new IllegalStateException( + s"Unsupported protobuf bytes default value class: ${ref.getClass.getName}") + } + } + + private object PbReflect { + private val cache = new java.util.concurrent.ConcurrentHashMap[String, Method]() + + private def protobufJavaVersion: String = Try { + val rtCls = Class.forName("com.google.protobuf.RuntimeVersion") + val domain = rtCls.getField("DOMAIN").get(null) + val major = rtCls.getField("MAJOR").get(null) + val minor = rtCls.getField("MINOR").get(null) + val patch = rtCls.getField("PATCH").get(null) + s"$domain-$major.$minor.$patch" + }.getOrElse("unknown") + + private def cached(cls: Class[_], name: String, paramTypes: Class[_]*): Method = { + val key = s"${cls.getName}#$name(${paramTypes.map(_.getName).mkString(",")})" + cache.computeIfAbsent(key, _ => { + 
try { + cls.getMethod(name, paramTypes: _*) + } catch { + case ex: NoSuchMethodException => + throw new UnsupportedOperationException( + s"protobuf-java method not found: ${cls.getSimpleName}.$name " + + s"(protobuf-java version: $protobufJavaVersion). " + + s"This may indicate an incompatible protobuf-java library version.", + ex) + } + }) + } + + def invoke0[T](obj: AnyRef, method: String): T = + cached(obj.getClass, method).invoke(obj).asInstanceOf[T] + + def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T = + cached(obj.getClass, method, arg0Cls).invoke(obj, arg0).asInstanceOf[T] + + def findFieldByName(msgDesc: AnyRef, name: String): AnyRef = + invoke1[AnyRef](msgDesc, "findFieldByName", classOf[String], name) + + def getFieldNumber(fd: AnyRef): Int = + invoke0[java.lang.Integer](fd, "getNumber").intValue() + + def getFieldType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getType") + + def isRepeated(fd: AnyRef): Boolean = + invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue() + + def isRequired(fd: AnyRef): Boolean = + invoke0[java.lang.Boolean](fd, "isRequired").booleanValue() + + def hasDefaultValue(fd: AnyRef): Boolean = + invoke0[java.lang.Boolean](fd, "hasDefaultValue").booleanValue() + + def getDefaultValue(fd: AnyRef): Option[AnyRef] = + Option(invoke0[AnyRef](fd, "getDefaultValue")) + + def getMessageType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getMessageType") + + def getEnumType(fd: AnyRef): AnyRef = invoke0[AnyRef](fd, "getEnumType") + + def getEnumValues(enumType: AnyRef): Seq[ProtobufEnumValue] = { + import scala.collection.JavaConverters._ + val values = invoke0[java.util.List[_]](enumType, "getValues") + values.asScala.map { v => + val ev = v.asInstanceOf[AnyRef] + val num = invoke0[java.lang.Integer](ev, "getNumber").intValue() + val enumName = invoke0[String](ev, "getName") + ProtobufEnumValue(num, enumName) + }.toSeq + } + + def getFileSyntax(msgDesc: AnyRef, typeNameFn: AnyRef => String): String = Try { + 
val fileDesc = invoke0[AnyRef](msgDesc, "getFile") + val syntaxObj = invoke0[AnyRef](fileDesc, "getSyntax") + typeNameFn(syntaxObj) + }.getOrElse("") + } +} diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala new file mode 100644 index 00000000000..4bd85dbc458 --- /dev/null +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala @@ -0,0 +1,633 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */

package com.nvidia.spark.rapids.shims

import ai.rapids.cudf.DType
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.sql.catalyst.expressions.{
  Expression,
  GetArrayStructFields,
  UnaryExpression
}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.rapids.{
  GpuFromProtobuf,
  GpuGetArrayStructFieldsMeta,
  GpuStructFieldOrdinalTag
}
import org.apache.spark.sql.rapids.protobuf._
import org.apache.spark.sql.types._

/**
 * Unit tests for the protobuf expression shims:
 *  - `SparkProtobufCompat` reflection helpers that read `from_protobuf` expression
 *    internals (message name, descriptor source, options) across Spark versions,
 *  - `ProtobufSchemaExtractor` / `ProtobufSchemaValidator` field analysis and
 *    flattened-descriptor construction,
 *  - content-based equality of `GpuFromProtobuf` and related array-holding types.
 *
 * The suite uses hand-rolled fake expressions and descriptors so no real protobuf
 * runtime or Spark session is required.
 */
class ProtobufExprShimsSuite extends AnyFunSuite {
  // Decoded output schema shared by the fake protobuf expressions below.
  private val outputSchema = StructType(Seq(
    StructField("id", IntegerType, nullable = true),
    StructField("name", StringType, nullable = true)))

  // Minimal leaf expression used as the binary child of the fake protobuf
  // expressions. Codegen is unsupported on purpose: these fakes are only
  // inspected via reflection, never evaluated/compiled.
  private case class FakeExprChild() extends Expression {
    override def children: Seq[Expression] = Nil
    override def nullable: Boolean = true
    override def dataType: DataType = BinaryType
    override def eval(input: org.apache.spark.sql.catalyst.InternalRow): Any = null
    override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
      throw new UnsupportedOperationException("not needed")
    override protected def withNewChildrenInternal(
        newChildren: IndexedSeq[Expression]): Expression = {
      assert(newChildren.isEmpty)
      this
    }
  }

  // Base for fake `from_protobuf`-like expressions. Subclasses expose the
  // public members (messageName / descFilePath / binaryDescriptorSet / options)
  // that SparkProtobufCompat is expected to discover reflectively.
  private abstract class FakeBaseProtobufExpr(childExpr: Expression) extends UnaryExpression {
    override def child: Expression = childExpr
    override def nullable: Boolean = true
    override def dataType: DataType = outputSchema
    override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
      throw new UnsupportedOperationException("not needed")
    override protected def withNewChildInternal(newChild: Expression): Expression = this
  }

  // Legacy shape: descriptor identified by a file path (Spark 3.4 style).
  private case class FakePathProtobufExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def descFilePath: Option[String] = Some("/tmp/test.desc")
    def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST")
  }

  // Newer shape: descriptor carried inline as serialized bytes (Spark 3.5 style).
  private case class FakeBytesProtobufExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def binaryDescriptorSet: Array[Byte] = Array[Byte](1, 2, 3)
    def options: scala.collection.Map[String, String] =
      Map("mode" -> "PERMISSIVE", "enums.as.ints" -> "true")
  }

  // Deliberately omits `options` so reflection-based extraction must fail.
  private case class FakeMissingOptionsExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def descFilePath: Option[String] = Some("/tmp/test.desc")
  }

  // Variants of FakePathProtobufExpr differing in exactly one decode-relevant
  // attribute each, used to probe sameDecodeSemantics.
  private case class FakeDifferentMessageExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.OtherMessage"
    def descFilePath: Option[String] = Some("/tmp/test.desc")
    def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST")
  }

  private case class FakeDifferentDescriptorExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def descFilePath: Option[String] = Some("/tmp/other.desc")
    def options: scala.collection.Map[String, String] = Map("mode" -> "FAILFAST")
  }

  private case class FakeDifferentOptionsExpr(override val child: Expression)
    extends FakeBaseProtobufExpr(child) {
    def messageName: String = "test.Message"
    def descFilePath: Option[String] = Some("/tmp/test.desc")
    def options: scala.collection.Map[String, String] = Map("mode" -> "PERMISSIVE")
  }

  // Unary expression with a configurable data type, used to stand in for a
  // (possibly schema-pruned) array-of-struct child.
  private case class FakeTypedUnaryExpr(
      dt: DataType,
      override val child: Expression = FakeExprChild()) extends UnaryExpression {
    override def nullable: Boolean = true
    override def dataType: DataType = dt
    override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
      throw new UnsupportedOperationException("not needed")
    override protected def withNewChildInternal(newChild: Expression): Expression =
      copy(child = newChild)
  }

  // Fake ProtobufUtils objects mimicking the per-Spark-version `buildDescriptor`
  // signatures that invokeBuildDescriptor must dispatch to reflectively:
  // Spark 3.4 takes Option[String] (path), Spark 3.5 takes Option[Array[Byte]].
  private object FakeSpark34ProtobufUtils {
    def buildDescriptor(messageName: String, descFilePath: Option[String]): String =
      s"$messageName:${descFilePath.getOrElse("none")}"
  }

  private object FakeSpark35ProtobufUtils {
    def buildDescriptor(messageName: String, binaryFileDescriptorSet: Option[Array[Byte]]): String =
      s"$messageName:${binaryFileDescriptorSet.map(_.mkString(",")).getOrElse("none")}"
  }

  // Like the 3.5 fake, but throws for the bytes produced by the path-to-bytes
  // retry so we can check how retry failures are surfaced.
  private object FakeSpark35RetryFailureProtobufUtils {
    def buildDescriptor(
        messageName: String,
        binaryFileDescriptorSet: Option[Array[Byte]]): String = {
      val bytes = binaryFileDescriptorSet.getOrElse(Array.emptyByteArray)
      if (bytes.sameElements(Array[Byte](1, 2, 3))) {
        throw new IllegalArgumentException(s"Unknown message $messageName")
      }
      s"$messageName:${bytes.mkString(",")}"
    }
  }

  // In-memory message/field descriptors so extractor tests need no compiled
  // .proto artifacts.
  private case class FakeMessageDescriptor(
      syntax: String,
      fields: Map[String, ProtobufFieldDescriptor]) extends ProtobufMessageDescriptor {
    override def findField(name: String): Option[ProtobufFieldDescriptor] = fields.get(name)
  }

  private case class FakeFieldDescriptor(
      name: String,
      fieldNumber: Int,
      protoTypeName: String,
      isRepeated: Boolean = false,
      isRequired: Boolean = false,
      defaultValue: Option[ProtobufDefaultValue] = None,
      defaultValueError: Option[String] = None,
      enumMetadata: Option[ProtobufEnumMetadata] = None,
      messageDescriptor: Option[ProtobufMessageDescriptor] = None) extends ProtobufFieldDescriptor {
    // Setting defaultValueError simulates a reflection failure while reading
    // the field's default value.
    override lazy val defaultValueResult: Either[String, Option[ProtobufDefaultValue]] =
      defaultValueError match {
        case Some(reason) => Left(reason)
        case None => Right(defaultValue)
      }
  }

  test("compat extracts descriptor path and options from legacy expression") {
    val exprInfo = SparkProtobufCompat.extractExprInfo(FakePathProtobufExpr(FakeExprChild()))
    assert(exprInfo.isRight)
    val info = exprInfo.toOption.get
    assert(info.messageName == "test.Message")
    assert(info.options == Map("mode" -> "FAILFAST"))
    assert(info.descriptorSource ==
      ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"))
  }

  test("compat extracts binary descriptor source and planner options") {
    val exprInfo = SparkProtobufCompat.extractExprInfo(FakeBytesProtobufExpr(FakeExprChild()))
    assert(exprInfo.isRight)
    val info = exprInfo.toOption.get
    info.descriptorSource match {
      case ProtobufDescriptorSource.DescriptorBytes(bytes) =>
        assert(bytes.sameElements(Array[Byte](1, 2, 3)))
      case other =>
        fail(s"Unexpected descriptor source: $other")
    }
    // PERMISSIVE mode maps to failOnErrors = false; enums.as.ints is honored.
    val plannerOptions = SparkProtobufCompat.parsePlannerOptions(info.options)
    assert(plannerOptions ==
      Right(ProtobufPlannerOptions(enumsAsInts = true, failOnErrors = false)))
  }

  test("compat invokes Spark 3.4 descriptor builder with descriptor path") {
    val buildMethod = FakeSpark34ProtobufUtils.getClass.getMethod(
      "buildDescriptor", classOf[String], classOf[scala.Option[_]])

    val result = SparkProtobufCompat.invokeBuildDescriptor(
      buildMethod,
      FakeSpark34ProtobufUtils,
      "test.Message",
      ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
      _ => fail("path-to-bytes fallback should not be needed for Spark 3.4"))

    assert(result == "test.Message:/tmp/test.desc")
  }

  test("compat retries descriptor path as bytes for Spark 3.5 descriptor builder") {
    val buildMethod = FakeSpark35ProtobufUtils.getClass.getMethod(
      "buildDescriptor", classOf[String], classOf[scala.Option[_]])
    var readCalls = 0

    // The 3.5 builder expects bytes, so the path must be read exactly once
    // via the supplied path-to-bytes callback.
    val result = SparkProtobufCompat.invokeBuildDescriptor(
      buildMethod,
      FakeSpark35ProtobufUtils,
      "test.Message",
      ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
      _ => {
        readCalls += 1
        Array[Byte](1, 2, 3)
      })

    assert(readCalls == 1)
    assert(result == "test.Message:1,2,3")
  }

  test("compat passes bytes directly to Spark 3.5 descriptor builder") {
    val buildMethod = FakeSpark35ProtobufUtils.getClass.getMethod(
      "buildDescriptor",
      classOf[String], classOf[scala.Option[_]])

    val result = SparkProtobufCompat.invokeBuildDescriptor(
      buildMethod,
      FakeSpark35ProtobufUtils,
      "test.Message",
      ProtobufDescriptorSource.DescriptorBytes(Array[Byte](4, 5, 6)),
      _ => fail("binary descriptor source should not read a file"))

    assert(result == "test.Message:4,5,6")
  }

  test("compat preserves retry context when descriptor bytes fallback also fails") {
    val buildMethod = FakeSpark35RetryFailureProtobufUtils.getClass.getMethod(
      "buildDescriptor", classOf[String], classOf[scala.Option[_]])

    val ex = intercept[RuntimeException] {
      SparkProtobufCompat.invokeBuildDescriptor(
        buildMethod,
        FakeSpark35RetryFailureProtobufUtils,
        "test.Message",
        ProtobufDescriptorSource.DescriptorPath("/tmp/test.desc"),
        _ => Array[Byte](1, 2, 3))
    }

    // The thrown error must carry: the retry marker, the original failure's
    // class, the underlying message, the cause chain, and the suppressed
    // reflective invocation failure.
    assert(ex.getMessage.contains("descriptor bytes retry failed"))
    assert(ex.getMessage.contains("ClassCastException"))
    assert(ex.getMessage.contains("Unknown message test.Message"))
    assert(ex.getCause.isInstanceOf[IllegalArgumentException])
    assert(ex.getSuppressed.exists(_.isInstanceOf[java.lang.reflect.InvocationTargetException]))
  }

  test("compat distinguishes decode semantics across message descriptor and options") {
    val child = FakeExprChild()

    // Same shape => same semantics; any difference in message name, descriptor
    // source, or options => different semantics.
    assert(SparkProtobufCompat.sameDecodeSemantics(
      FakePathProtobufExpr(child), FakePathProtobufExpr(child)))
    assert(SparkProtobufCompat.sameDecodeSemantics(
      FakeBytesProtobufExpr(child), FakeBytesProtobufExpr(child)))
    assert(!SparkProtobufCompat.sameDecodeSemantics(
      FakePathProtobufExpr(child), FakeDifferentMessageExpr(child)))
    assert(!SparkProtobufCompat.sameDecodeSemantics(
      FakePathProtobufExpr(child), FakeDifferentDescriptorExpr(child)))
    assert(!SparkProtobufCompat.sameDecodeSemantics(
      FakePathProtobufExpr(child), FakeDifferentOptionsExpr(child)))
  }

  test("compat reports missing options accessor as cpu fallback reason") {
    val exprInfo =
      SparkProtobufCompat.extractExprInfo(FakeMissingOptionsExpr(FakeExprChild()))
    assert(exprInfo.left.toOption.exists(
      _.contains("Cannot read from_protobuf options via reflection")))
  }

  test("compat detects unsupported options and proto3 syntax") {
    assert(SparkProtobufCompat.unsupportedOptions(Map("mode" -> "FAILFAST", "foo" -> "bar")) ==
      Seq("foo"))
    // Only PROTO2 syntax is GPU-supported; everything else (including
    // empty/"null" syntax strings) must be rejected.
    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("PROTO3"))
    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("EDITIONS"))
    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax(""))
    assert(!SparkProtobufCompat.isGpuSupportedProtoSyntax("null"))
    assert(SparkProtobufCompat.isGpuSupportedProtoSyntax("PROTO2"))
  }

  test("compat returns Left for unsupported default value types") {
    // toDefaultValue is private; look it up by (mangled) name suffix and open
    // it up for this white-box check.
    val method = SparkProtobufCompat.getClass.getDeclaredMethods
      .find(_.getName.endsWith("toDefaultValue"))
      .getOrElse(fail("toDefaultValue method not found"))
    method.setAccessible(true)

    val result = method.invoke(
      SparkProtobufCompat,
      "opaque-default",
      "MESSAGE",
      scala.None).asInstanceOf[Either[String, ProtobufDefaultValue]]

    assert(result.left.toOption.exists(_.contains("Unsupported protobuf default value type")))
  }

  test("extractor preserves typed enum defaults") {
    val enumMeta = ProtobufEnumMetadata(Seq(
      ProtobufEnumValue(0, "UNKNOWN"),
      ProtobufEnumValue(1, "EN"),
      ProtobufEnumValue(2, "ZH")))
    val msgDesc = FakeMessageDescriptor(
      syntax = "PROTO2",
      fields = Map(
        "language" -> FakeFieldDescriptor(
          name = "language",
          fieldNumber = 1,
          protoTypeName = "ENUM",
          defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
          enumMetadata = Some(enumMeta))))
    val schema = StructType(Seq(StructField("language", StringType, nullable = true)))

    val infos = ProtobufSchemaExtractor.analyzeAllFields(
      schema, msgDesc, enumsAsInts = false, "test.Message")

    assert(infos.isRight)
    assert(infos.toOption.get("language").defaultValue.contains(
      ProtobufDefaultValue.EnumValue(1, "EN")))
  }

  test("extractor records reflection failures as unsupported field info") {
    val msgDesc = FakeMessageDescriptor(
      syntax = "PROTO2",
      fields = Map(
        "ok" -> FakeFieldDescriptor(
          name = "ok",
          fieldNumber = 1,
          protoTypeName = "INT32"),
        "id" -> FakeFieldDescriptor(
          name = "id",
          fieldNumber = 2,
          protoTypeName = "INT32",
          defaultValueError =
            Some("Failed to read protobuf default value for field 'id': unsupported type"))))
    val schema = StructType(Seq(
      StructField("ok", IntegerType, nullable = true),
      StructField("id", IntegerType, nullable = true)))

    val infos = ProtobufSchemaExtractor.analyzeAllFields(
      schema, msgDesc, enumsAsInts = true, "test.Message")

    // A per-field reflection failure marks only that field unsupported; the
    // analysis as a whole still succeeds and healthy siblings stay supported.
    assert(infos.isRight)
    assert(infos.toOption.get("ok").isSupported)
    assert(!infos.toOption.get("id").isSupported)
    assert(infos.toOption.get("id").unsupportedReason.exists(
      _.contains("Failed to read protobuf default value for field 'id'")))
  }

  test("extractor preserves type mismatch reason over default reflection failure") {
    val fieldInfo = ProtobufSchemaExtractor.extractFieldInfo(
      StructField("id", StringType, nullable = true),
      FakeFieldDescriptor(
        name = "id",
        fieldNumber = 1,
        protoTypeName = "INT32",
        defaultValueError =
          Some("Failed to read protobuf default value for field 'id': unsupported type")),
      enumsAsInts = true)

    // When both problems are present, the type mismatch wins as the reported
    // reason.
    assert(fieldInfo.isRight)
    assert(!fieldInfo.toOption.get.isSupported)
    assert(fieldInfo.toOption.get.unsupportedReason.contains(
      "type mismatch: Spark StringType vs Protobuf INT32"))
  }

  test("extractor gives explicit reason for unsupported FLOAT/DOUBLE widening mismatches") {
    val doubleFromFloat = ProtobufSchemaExtractor.extractFieldInfo(
      StructField("score", DoubleType, nullable = true),
      FakeFieldDescriptor(
        name = "score",
        fieldNumber = 1,
        protoTypeName = "FLOAT"),
      enumsAsInts = true)
    val floatFromDouble = ProtobufSchemaExtractor.extractFieldInfo(
      StructField("score", FloatType, nullable = true),
      FakeFieldDescriptor(
        name = "score",
        fieldNumber = 1,
        protoTypeName = "DOUBLE"),
      enumsAsInts = true)

    assert(doubleFromFloat.isRight)
    assert(!doubleFromFloat.toOption.get.isSupported)
    assert(doubleFromFloat.toOption.get.unsupportedReason.contains(
      "Spark DoubleType mapped to Protobuf FLOAT is not yet supported on GPU; " +
        "use FloatType or fall back to CPU"))
    assert(floatFromDouble.isRight)
    assert(!floatFromDouble.toOption.get.isSupported)
    assert(floatFromDouble.toOption.get.unsupportedReason.contains(
      "Spark FloatType mapped to Protobuf DOUBLE is not yet supported on GPU; " +
        "use DoubleType or fall back to CPU"))
  }

  test("validator encodes enum-string defaults into both numeric and string payloads") {
    val enumMeta = ProtobufEnumMetadata(Seq(
      ProtobufEnumValue(0, "UNKNOWN"),
      ProtobufEnumValue(1, "EN")))
    val info = ProtobufFieldInfo(
      fieldNumber = 2,
      protoTypeName = "ENUM",
      sparkType = StringType,
      encoding = GpuFromProtobuf.ENC_ENUM_STRING,
      isSupported = true,
      unsupportedReason = None,
      isRequired = false,
      defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
      enumMetadata = Some(enumMeta),
      isRepeated = false)

    val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
      path = "common.language",
      field = StructField("language", StringType, nullable = true),
      fieldInfo = info,
      parentIdx = 0,
      depth = 1,
      outputTypeId = 6)

    // An enum default must be materialized twice: as its numeric value and as
    // its UTF-8 name, alongside the full valid-values/names tables.
    assert(flat.isRight)
    assert(flat.toOption.get.defaultInt == 1L)
    assert(new String(flat.toOption.get.defaultString, "UTF-8") == "EN")
    assert(flat.toOption.get.enumValidValues.sameElements(Array(0, 1)))
    assert(flat.toOption.get.enumNames
      .map(new String(_, "UTF-8"))
      .sameElements(Array("UNKNOWN", "EN")))
  }

  test("validator rejects enum-string field without enum metadata") {
    val info = ProtobufFieldInfo(
      fieldNumber = 2,
      protoTypeName = "ENUM",
      sparkType = StringType,
      encoding =
        GpuFromProtobuf.ENC_ENUM_STRING,
      isSupported = true,
      unsupportedReason = None,
      isRequired = false,
      defaultValue = Some(ProtobufDefaultValue.EnumValue(1, "EN")),
      enumMetadata = None,
      isRepeated = false)

    val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
      path = "common.language",
      field = StructField("language", StringType, nullable = true),
      fieldInfo = info,
      parentIdx = 0,
      depth = 1,
      outputTypeId = 6)

    assert(flat.left.toOption.exists(_.contains("missing enum metadata")))
  }

  test("validator returns Left for incompatible default type instead of throwing") {
    // FloatValue default paired with a DoubleType field: should surface as a
    // Left, not an exception.
    val info = ProtobufFieldInfo(
      fieldNumber = 3,
      protoTypeName = "FLOAT",
      sparkType = DoubleType,
      encoding = GpuFromProtobuf.ENC_DEFAULT,
      isSupported = true,
      unsupportedReason = None,
      isRequired = false,
      defaultValue = Some(ProtobufDefaultValue.FloatValue(1.5f)),
      enumMetadata = None,
      isRepeated = false)

    val flat = ProtobufSchemaValidator.toFlattenedFieldDescriptor(
      path = "common.score",
      field = StructField("score", DoubleType, nullable = true),
      fieldInfo = info,
      parentIdx = 0,
      depth = 1,
      outputTypeId = 6)

    assert(flat.left.toOption.exists(
      _.contains("Incompatible default value for protobuf field 'common.score'")))
  }

  test("validator rejects flattened schema with non-STRUCT parent") {
    // Field #2 claims field #0 (an INT32, not a STRUCT) as its parent, which
    // makes the flattened tree structurally invalid.
    val flatFields = Seq(
      FlattenedFieldDescriptor(
        fieldNumber = 1,
        parentIdx = -1,
        depth = 0,
        wireType = 0,
        outputTypeId = DType.INT32.getTypeId.getNativeId,
        encoding = GpuFromProtobuf.ENC_DEFAULT,
        isRepeated = false,
        isRequired = false,
        hasDefaultValue = false,
        defaultInt = 0L,
        defaultFloat = 0.0,
        defaultBool = false,
        defaultString = Array.emptyByteArray,
        enumValidValues = null,
        enumNames = null),
      FlattenedFieldDescriptor(
        fieldNumber = 2,
        parentIdx = 0,
        depth = 1,
        wireType = 0,
        outputTypeId = DType.INT32.getTypeId.getNativeId,
        encoding = GpuFromProtobuf.ENC_DEFAULT,
        isRepeated = false,
        isRequired = false,
        hasDefaultValue = false,
        defaultInt = 0L,
        defaultFloat = 0.0,
        defaultBool = false,
        defaultString = Array.emptyByteArray,
        enumValidValues = null,
        enumNames = null))

    val validation = ProtobufSchemaValidator.validateFlattenedSchema(flatFields)
    assert(validation.left.toOption.exists(_.contains("non-STRUCT parent")))
  }

  test("array struct field meta uses pruned child field count after ordinal remap") {
    val originalStruct = StructType(Seq(
      StructField("a", IntegerType, nullable = true),
      StructField("b", IntegerType, nullable = true),
      StructField("c", IntegerType, nullable = true)))
    val prunedStruct = StructType(Seq(StructField("b", IntegerType, nullable = true)))
    val originalChild = FakeTypedUnaryExpr(ArrayType(originalStruct, containsNull = true))
    // Catalyst expression selecting "b" (ordinal 1 of 3) before pruning.
    val sparkExpr = GetArrayStructFields(
      child = originalChild,
      field = originalStruct.fields(1),
      ordinal = 1,
      numFields = originalStruct.fields.length,
      containsNull = true)
    // After pruning, "b" becomes ordinal 0 — recorded on the expression via
    // the tree-node tag.
    sparkExpr.setTagValue(GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG, 0)

    val prunedChild = FakeTypedUnaryExpr(ArrayType(prunedStruct, containsNull = true))
    val runtimeOrd = sparkExpr.getTagValue(GpuStructFieldOrdinalTag.PRUNED_ORDINAL_TAG).get

    assert(runtimeOrd == 0)
    // effectiveNumFields must reflect the pruned child (1 field), not the
    // original numFields (3).
    assert(
      GpuGetArrayStructFieldsMeta.effectiveNumFields(prunedChild, sparkExpr, runtimeOrd) == 1)
  }

  test("GpuFromProtobuf semantic equality is content-based for schema arrays") {
    def emptyEnumNames: Array[Array[Byte]] = Array.empty[Array[Byte]]

    val expr1 = GpuFromProtobuf(
      decodedSchema = outputSchema,
      fieldNumbers = Array(1, 2),
      parentIndices = Array(-1, -1),
      depthLevels = Array(0, 0),
      wireTypes = Array(0, 2),
      outputTypeIds = Array(3, 6),
      encodings = Array(0, 0),
      isRepeated = Array(false, false),
      isRequired = Array(false, false),
      hasDefaultValue = Array(false, false),
      defaultInts = Array(0L, 0L),
      defaultFloats = Array(0.0, 0.0),
      defaultBools = Array(false, false),
      defaultStrings = Array(Array.emptyByteArray, Array.emptyByteArray),
      enumValidValues = Array(Array.emptyIntArray, Array.emptyIntArray),
      enumNames = Array(emptyEnumNames, emptyEnumNames),
      failOnErrors = true,
      child = FakeExprChild())

    // Same content in freshly-allocated arrays (note .map(identity) forcing
    // distinct array instances) — equality must compare contents, not
    // references.
    val expr2 = GpuFromProtobuf(
      decodedSchema = outputSchema,
      fieldNumbers = Array(1, 2),
      parentIndices = Array(-1, -1),
      depthLevels = Array(0, 0),
      wireTypes = Array(0, 2),
      outputTypeIds = Array(3, 6),
      encodings = Array(0, 0),
      isRepeated = Array(false, false),
      isRequired = Array(false, false),
      hasDefaultValue = Array(false, false),
      defaultInts = Array(0L, 0L),
      defaultFloats = Array(0.0, 0.0),
      defaultBools = Array(false, false),
      defaultStrings = Array(Array.emptyByteArray, Array.emptyByteArray),
      enumValidValues = Array(Array.emptyIntArray, Array.emptyIntArray),
      enumNames = Array(emptyEnumNames.map(identity), emptyEnumNames.map(identity)),
      failOnErrors = true,
      child = FakeExprChild())

    assert(expr1.semanticEquals(expr2))
    assert(expr1.semanticHash() == expr2.semanticHash())
  }

  test("protobuf binary defaults use content-based equality") {
    val left = ProtobufDefaultValue.BinaryValue(Array[Byte](1, 2, 3))
    val right = ProtobufDefaultValue.BinaryValue(Array[Byte](1, 2, 3))

    assert(left == right)
    assert(left.hashCode() == right.hashCode())
  }

  test("flattened field descriptor uses content-based equality for array fields") {
    val left = FlattenedFieldDescriptor(
      fieldNumber = 1,
      parentIdx = -1,
      depth = 0,
      wireType = 2,
      outputTypeId = 6,
      encoding = 0,
      isRepeated = false,
      isRequired = false,
      hasDefaultValue = true,
      defaultInt = 0L,
      defaultFloat = 0.0,
      defaultBool = false,
      defaultString = Array[Byte](1, 2),
      enumValidValues = Array(0, 1),
      enumNames = Array("A".getBytes("UTF-8"), "B".getBytes("UTF-8")))
    val right = FlattenedFieldDescriptor(
      fieldNumber = 1,
      parentIdx = -1,
      depth = 0,
      wireType = 2,
      outputTypeId = 6,
      encoding = 0,
      isRepeated = false,
      isRequired = false,
      hasDefaultValue = true,
      defaultInt = 0L,
      defaultFloat = 0.0,
      defaultBool = false,
      defaultString = Array[Byte](1, 2),
      enumValidValues = Array(0, 1),
      enumNames = Array("A".getBytes("UTF-8"), "B".getBytes("UTF-8")))

    assert(left == right)
    assert(left.hashCode() == right.hashCode())
  }
}