diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index 443b08021d..b8a32aac4a 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -207,6 +207,7 @@ add_library( src/NativeParquetJni.cpp src/NumberConverterJni.cpp src/ParseURIJni.cpp + src/ProtobufJni.cpp src/RegexRewriteUtilsJni.cpp src/RowConversionJni.cpp src/SparkResourceAdaptorJni.cpp @@ -254,6 +255,9 @@ add_library( src/multiply.cu src/number_converter.cu src/parse_uri.cu + src/protobuf/protobuf.cu + src/protobuf/protobuf_builders.cu + src/protobuf/protobuf_kernels.cu src/regex_rewrite_utils.cu src/row_conversion.cu src/round_float.cu diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp new file mode 100644 index 0000000000..c5eb07f4b6 --- /dev/null +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cudf_jni_apis.hpp" +#include "protobuf/protobuf.hpp" + +#include +#include + +namespace { + +/** + * Convert a Java Object[] of primitive arrays into a vector-of-vectors. + * @tparam CppT Element type in the output vectors (e.g. host_vector, + * host_vector). + * @param convert Per-element callback: (JNIEnv*, jobject) -> std::vector. + * Must return an empty vector on null input. Returns std::nullopt on JNI error. + */ +template +std::vector jni_array_of_arrays_to_vectors(JNIEnv* env, + jobjectArray arr, + int num_elements, + ConvertFn convert) +{ + std::vector result; + result.reserve(num_elements); + for (int i = 0; i < num_elements; ++i) { + jobject elem = env->GetObjectArrayElement(arr, i); + if (env->ExceptionCheck()) { return {}; } + auto vec = convert(env, elem); + if (elem != nullptr) { env->DeleteLocalRef(elem); } + if (env->ExceptionCheck()) { return {}; } + result.push_back(std::move(vec)); + } + return result; +} + +cudf::detail::host_vector jni_byte_array_to_vector(JNIEnv* env, jobject obj) +{ + if (obj == nullptr) { + return cudf::detail::make_host_vector(0, cudf::get_default_stream()); + } + auto byte_arr = static_cast(obj); + jsize len = env->GetArrayLength(byte_arr); + jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); + if (bytes == nullptr) { + return cudf::detail::make_host_vector(0, cudf::get_default_stream()); + } + auto vec = cudf::detail::make_host_vector(len, cudf::get_default_stream()); + std::copy( + reinterpret_cast(bytes), reinterpret_cast(bytes) + len, vec.begin()); + env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); + return vec; +} + +cudf::detail::host_vector jni_int_array_to_vector(JNIEnv* env, jobject obj) +{ + if (obj == nullptr) { + return cudf::detail::make_host_vector(0, cudf::get_default_stream()); + } + auto int_arr = static_cast(obj); + jsize len = env->GetArrayLength(int_arr); + jint* ints = env->GetIntArrayElements(int_arr, nullptr); + if (ints == nullptr) { + return cudf::detail::make_host_vector(0, cudf::get_default_stream()); + } + auto vec = cudf::detail::make_host_vector(len, cudf::get_default_stream()); + std::copy(ints, ints + len, vec.begin()); + env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); + return vec; +} + +} // namespace + +extern "C" { + +JNIEXPORT jlong JNICALL +Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, + jclass, + jlong binary_input_view, + jintArray field_numbers, + jintArray parent_indices, + jintArray depth_levels, + jintArray wire_types, + jintArray output_type_ids, + jintArray encodings, + jbooleanArray is_repeated, + jbooleanArray is_required, + jbooleanArray has_default_value, + jlongArray default_ints, + jdoubleArray default_floats, + jbooleanArray default_bools, + jobjectArray default_strings, + jobjectArray enum_valid_values, + jobjectArray enum_names, + jboolean fail_on_errors) +{ + auto const all_inputs_valid = binary_input_view && field_numbers && parent_indices && + depth_levels && wire_types && output_type_ids && encodings && + is_repeated && is_required && has_default_value && default_ints && + default_floats && default_bools && default_strings && + enum_valid_values && enum_names; + JNI_NULL_CHECK(env, all_inputs_valid, "one or more input arrays are null", 0); + + JNI_TRY + { + cudf::jni::auto_set_device(env); + auto const* input = reinterpret_cast(binary_input_view); + + cudf::jni::native_jintArray n_field_numbers(env, field_numbers); + cudf::jni::native_jintArray n_parent_indices(env, parent_indices); + cudf::jni::native_jintArray n_depth_levels(env, depth_levels); + cudf::jni::native_jintArray n_wire_types(env, wire_types); + cudf::jni::native_jintArray n_output_type_ids(env, output_type_ids); + cudf::jni::native_jintArray n_encodings(env, encodings); + cudf::jni::native_jbooleanArray n_is_repeated(env, is_repeated); + cudf::jni::native_jbooleanArray n_is_required(env, is_required); + cudf::jni::native_jbooleanArray n_has_default(env, has_default_value); + cudf::jni::native_jlongArray n_default_ints(env, default_ints); + cudf::jni::native_jdoubleArray n_default_floats(env, default_floats); + cudf::jni::native_jbooleanArray n_default_bools(env, default_bools); + + int num_fields = n_field_numbers.size(); + + // Validate array sizes + if (n_parent_indices.size() != num_fields || n_depth_levels.size() != num_fields || + n_wire_types.size() != num_fields || n_output_type_ids.size() != num_fields || + n_encodings.size() != num_fields || n_is_repeated.size() != num_fields || + n_is_required.size() != num_fields || n_has_default.size() != num_fields || + n_default_ints.size() != num_fields || n_default_floats.size() != num_fields || + n_default_bools.size() != num_fields || + env->GetArrayLength(default_strings) != num_fields || + env->GetArrayLength(enum_valid_values) != num_fields || + env->GetArrayLength(enum_names) != num_fields) { + JNI_THROW_NEW(env, + cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, + "All field arrays must have the same length", + 0); + } + + // Build schema descriptors + std::vector schema; + schema.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + schema.push_back({n_field_numbers[i], + n_parent_indices[i], + n_depth_levels[i], + static_cast(n_wire_types[i]), + static_cast(n_output_type_ids[i]), + static_cast(n_encodings[i]), + n_is_repeated[i] != 0, + n_is_required[i] != 0, + n_has_default[i] != 0}); + } + + // Convert boolean arrays + std::vector default_bool_values; + default_bool_values.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + default_bool_values.push_back(n_default_bools[i] != 0); + } + + // Convert default values + std::vector default_int_values(n_default_ints.begin(), n_default_ints.end()); + std::vector default_float_values(n_default_floats.begin(), n_default_floats.end()); + + auto default_string_values = jni_array_of_arrays_to_vectors>( + env, default_strings, num_fields, jni_byte_array_to_vector); + if (env->ExceptionCheck()) { return 0; } + + auto enum_values = jni_array_of_arrays_to_vectors>( + env, enum_valid_values, num_fields, jni_int_array_to_vector); + if (env->ExceptionCheck()) { return 0; } + + auto enum_name_values = + jni_array_of_arrays_to_vectors>>( + env, + enum_names, + num_fields, + [](JNIEnv* e, jobject obj) -> std::vector> { + if (obj == nullptr) { return {}; } + auto inner_arr = static_cast(obj); + jsize num = e->GetArrayLength(inner_arr); + return jni_array_of_arrays_to_vectors>( + e, inner_arr, num, jni_byte_array_to_vector); + }); + if (env->ExceptionCheck()) { return 0; } + + spark_rapids_jni::protobuf::protobuf_decode_context context{std::move(schema), + std::move(default_int_values), + std::move(default_float_values), + std::move(default_bool_values), + std::move(default_string_values), + std::move(enum_values), + std::move(enum_name_values), + static_cast(fail_on_errors)}; + + auto result = spark_rapids_jni::protobuf::decode_protobuf_to_struct( + *input, context, cudf::get_default_stream(), cudf::get_current_device_resource_ref()); + + return cudf::jni::release_as_jlong(result); + } + JNI_CATCH(env, 0); +} + +} // extern "C" diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu new file mode 100644 index 0000000000..1082843cad --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nvtx_ranges.hpp" +#include "protobuf/protobuf_kernels.cuh" + +#include + +#include +#include + +namespace spark_rapids_jni::protobuf { + +namespace detail { + +namespace { + +std::unique_ptr make_null_column_with_schema( + std::vector const& schema, + int schema_idx, + int num_fields, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const& field = schema[schema_idx]; + auto const dtype = cudf::data_type{schema[schema_idx].output_type}; + + if (field.is_repeated) { + std::unique_ptr empty_child; + if (dtype.id() == cudf::type_id::STRUCT) { + empty_child = + make_empty_struct_column_with_schema(schema, schema_idx, num_fields, stream, mr); + } else { + empty_child = make_empty_column_safe(dtype, stream, mr); + } + return make_null_list_column_with_child(std::move(empty_child), num_rows, stream, mr); + } + + if (dtype.id() == cudf::type_id::STRUCT) { + auto child_indices = find_child_field_indices(schema, num_fields, schema_idx); + std::vector> children; + for (auto const child_idx : child_indices) { + children.push_back( + make_null_column_with_schema(schema, child_idx, num_fields, num_rows, stream, mr)); + } + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(children), num_rows, std::move(null_mask), stream, mr); + } + + return make_null_column(dtype, num_rows, stream, mr); +} + +} // namespace + +bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_type const& type) +{ + switch (field.encoding) { + case proto_encoding::DEFAULT: + switch (type.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: return field.wire_type == proto_wire_type::VARINT; + case cudf::type_id::FLOAT32: return field.wire_type == proto_wire_type::I32BIT; + case cudf::type_id::FLOAT64: return field.wire_type == proto_wire_type::I64BIT; + case cudf::type_id::STRING: + case cudf::type_id::LIST: + case cudf::type_id::STRUCT: return field.wire_type == proto_wire_type::LEN; + default: return false; + } + case proto_encoding::FIXED: + switch (type.id()) { + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::FLOAT32: return field.wire_type == proto_wire_type::I32BIT; + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT64: return field.wire_type == proto_wire_type::I64BIT; + default: return false; + } + case proto_encoding::ZIGZAG: + return field.wire_type == proto_wire_type::VARINT && + (type.id() == cudf::type_id::INT32 || type.id() == cudf::type_id::INT64); + case proto_encoding::ENUM_STRING: + return field.wire_type == proto_wire_type::VARINT && type.id() == cudf::type_id::STRING; + default: return false; + } +} + +void validate_decode_context(protobuf_decode_context const& context) +{ + auto const num_fields = context.schema.size(); + if (context.default_ints.size() != num_fields) { + CUDF_FAIL("protobuf decode context: default_ints size mismatch with schema (" + + std::to_string(context.default_ints.size()) + " vs " + std::to_string(num_fields) + + ")", + std::invalid_argument); + } + if (context.default_floats.size() != num_fields) { + CUDF_FAIL("protobuf decode context: default_floats size mismatch with schema (" + + std::to_string(context.default_floats.size()) + " vs " + + std::to_string(num_fields) + ")", + std::invalid_argument); + } + if (context.default_bools.size() != num_fields) { + CUDF_FAIL("protobuf decode context: default_bools size mismatch with schema (" + + std::to_string(context.default_bools.size()) + " vs " + std::to_string(num_fields) + + ")", + std::invalid_argument); + } + if (context.default_strings.size() != num_fields) { + CUDF_FAIL("protobuf decode context: default_strings size mismatch with schema (" + + std::to_string(context.default_strings.size()) + " vs " + + std::to_string(num_fields) + ")", + std::invalid_argument); + } + if (context.enum_valid_values.size() != num_fields) { + CUDF_FAIL("protobuf decode context: enum_valid_values size mismatch with schema (" + + std::to_string(context.enum_valid_values.size()) + " vs " + + std::to_string(num_fields) + ")", + std::invalid_argument); + } + if (context.enum_names.size() != num_fields) { + CUDF_FAIL("protobuf decode context: enum_names size mismatch with schema (" + + std::to_string(context.enum_names.size()) + " vs " + std::to_string(num_fields) + + ")", + std::invalid_argument); + } + + std::unordered_set seen_field_numbers; + for (size_t i = 0; i < num_fields; ++i) { + auto const& field = context.schema[i]; + auto const type = cudf::data_type{field.output_type}; + if (field.field_number <= 0 || field.field_number > MAX_FIELD_NUMBER) { + CUDF_FAIL("protobuf decode context: invalid field number at field " + std::to_string(i), + std::invalid_argument); + } + if (field.depth < 0 || field.depth >= MAX_NESTING_DEPTH) { + CUDF_FAIL("protobuf decode context: field depth exceeds supported limit at field " + + std::to_string(i), + std::invalid_argument); + } + if (field.parent_idx < -1 || field.parent_idx >= static_cast(i)) { + CUDF_FAIL("protobuf decode context: invalid parent index at field " + std::to_string(i), + std::invalid_argument); + } + auto const key = (static_cast(static_cast(field.parent_idx)) << 32) | + static_cast(field.field_number); + if (!seen_field_numbers.insert(key).second) { + CUDF_FAIL("protobuf decode context: duplicate field number under same parent at field " + + std::to_string(i), + std::invalid_argument); + } + + if (field.parent_idx == -1) { + if (field.depth != 0) { + CUDF_FAIL("protobuf decode context: top-level field must have depth 0 at field " + + std::to_string(i), + std::invalid_argument); + } + } else { + auto const& parent = context.schema[field.parent_idx]; + if (field.depth != parent.depth + 1) { + CUDF_FAIL("protobuf decode context: child depth mismatch at field " + std::to_string(i), + std::invalid_argument); + } + if (cudf::data_type{context.schema[field.parent_idx].output_type}.id() != + cudf::type_id::STRUCT) { + CUDF_FAIL("protobuf decode context: parent must be STRUCT at field " + std::to_string(i), + std::invalid_argument); + } + } + + if (field.wire_type != proto_wire_type::VARINT && field.wire_type != proto_wire_type::I64BIT && + field.wire_type != proto_wire_type::LEN && field.wire_type != proto_wire_type::I32BIT) { + CUDF_FAIL("protobuf decode context: invalid wire type at field " + std::to_string(i), + std::invalid_argument); + } + if (field.encoding < proto_encoding::DEFAULT || field.encoding > proto_encoding::ENUM_STRING) { + CUDF_FAIL("protobuf decode context: invalid encoding at field " + std::to_string(i), + std::invalid_argument); + } + if (field.is_repeated && field.is_required) { + CUDF_FAIL("protobuf decode context: field cannot be both repeated and required at field " + + std::to_string(i), + std::invalid_argument); + } + if (field.is_repeated && field.has_default_value) { + CUDF_FAIL("protobuf decode context: repeated field cannot carry default value at field " + + std::to_string(i), + std::invalid_argument); + } + if (field.has_default_value && + (type.id() == cudf::type_id::STRUCT || type.id() == cudf::type_id::LIST)) { + CUDF_FAIL("protobuf decode context: STRUCT/LIST field cannot carry default value at field " + + std::to_string(i), + std::invalid_argument); + } + if (!is_encoding_compatible(field, type)) { + CUDF_FAIL("protobuf decode context: incompatible wire type/encoding/output type at field " + + std::to_string(i), + std::invalid_argument); + } + + if (field.encoding == proto_encoding::ENUM_STRING) { + if (context.enum_valid_values[i].empty() || context.enum_names[i].empty()) { + CUDF_FAIL( + "protobuf decode context: enum-as-string field requires non-empty metadata at field " + + std::to_string(i), + std::invalid_argument); + } + if (context.enum_valid_values[i].size() != context.enum_names[i].size()) { + CUDF_FAIL( + "protobuf decode context: enum-as-string metadata mismatch at field " + std::to_string(i), + std::invalid_argument); + } + auto const& ev = context.enum_valid_values[i]; + for (size_t j = 1; j < ev.size(); ++j) { + if (ev[j] <= ev[j - 1]) { + CUDF_FAIL("protobuf decode context: enum_valid_values must be strictly sorted at field " + + std::to_string(i), + std::invalid_argument); + } + } + } + } +} + +protobuf_field_meta_view make_field_meta_view(protobuf_decode_context const& context, + int schema_idx) +{ + auto const idx = static_cast(schema_idx); + return protobuf_field_meta_view{context.schema.at(idx), + cudf::data_type{context.schema.at(idx).output_type}, + context.default_ints.at(idx), + context.default_floats.at(idx), + context.default_bools.at(idx), + context.default_strings.at(idx), + context.enum_valid_values.at(idx), + context.enum_names.at(idx)}; +} + +std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, + protobuf_decode_context const& context, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + validate_decode_context(context); + auto const& schema = context.schema; + CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, + "binary_input must be a LIST column"); + cudf::lists_column_view const in_list(binary_input); + auto const child_type = in_list.child().type().id(); + CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, + "binary_input must be a LIST column"); + + auto const num_rows = binary_input.size(); + auto const num_fields = static_cast(schema.size()); + + if (num_fields == 0) { + auto const input_null_count = binary_input.null_count(); + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_structs_column( + num_rows, {}, input_null_count, std::move(null_mask), stream, mr); + } + return cudf::make_structs_column(num_rows, {}, 0, rmm::device_buffer{}, stream, mr); + } + + if (num_rows == 0) { + std::vector> empty_children; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { + auto field_type = cudf::data_type{schema[i].output_type}; + if (schema[i].is_repeated && field_type.id() == cudf::type_id::STRUCT) { + auto empty_struct = + make_empty_struct_column_with_schema(schema, i, num_fields, stream, mr); + empty_children.push_back(make_empty_list_column(std::move(empty_struct), stream, mr)); + } else if (schema[i].is_repeated) { + auto empty_child = make_empty_column_safe(field_type, stream, mr); + empty_children.push_back(make_empty_list_column(std::move(empty_child), stream, mr)); + } else if (field_type.id() == cudf::type_id::STRUCT) { + empty_children.push_back( + make_empty_struct_column_with_schema(schema, i, num_fields, stream, mr)); + } else { + empty_children.push_back(make_empty_column_safe(field_type, stream, mr)); + } + } + } + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + + std::vector> column_map(num_fields); + + std::vector> top_level_children; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { + if (column_map[i]) { + top_level_children.push_back(std::move(column_map[i])); + } else { + top_level_children.push_back( + make_null_column_with_schema(schema, i, num_fields, num_rows, stream, mr)); + } + } + } + + auto const input_null_count = binary_input.null_count(); + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(top_level_children), input_null_count, std::move(null_mask), stream, mr); + } + + return cudf::make_structs_column( + num_rows, std::move(top_level_children), 0, rmm::device_buffer{}, stream, mr); +} + +} // namespace detail + +std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, + protobuf_decode_context const& context, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + SRJ_FUNC_RANGE(); + return detail::decode_protobuf_to_struct(binary_input, context, stream, mr); +} + +} // namespace spark_rapids_jni::protobuf diff --git a/src/main/cpp/src/protobuf/protobuf.hpp b/src/main/cpp/src/protobuf/protobuf.hpp new file mode 100644 index 0000000000..803f0c2d93 --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf.hpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace spark_rapids_jni::protobuf { + +enum class proto_encoding : int { + DEFAULT = 0, + FIXED = 1, + ZIGZAG = 2, + ENUM_STRING = 3, +}; + +CUDF_HOST_DEVICE constexpr int encoding_value(proto_encoding encoding) +{ + return static_cast(encoding); +} + +constexpr int MAX_FIELD_NUMBER = (1 << 29) - 1; + +enum class proto_wire_type : int { + VARINT = 0, + I64BIT = 1, + LEN = 2, + SGROUP = 3, + EGROUP = 4, + I32BIT = 5, +}; + +CUDF_HOST_DEVICE constexpr int wire_type_value(proto_wire_type wire_type) +{ + return static_cast(wire_type); +} + +constexpr int MAX_NESTING_DEPTH = 10; + +struct nested_field_descriptor { + int field_number; // Protobuf field number + int parent_idx; // Index of parent field in schema (-1 for top-level) + int depth; // Nesting depth (0 for top-level) + proto_wire_type wire_type; // Expected wire type + cudf::type_id output_type; // Output cudf type + proto_encoding encoding; // Encoding type + bool is_repeated; // Whether this field is repeated (array) + bool is_required; // Whether this field is required (proto2) + bool has_default_value; // Whether this field has a default value +}; + +struct protobuf_decode_context { + std::vector schema; + std::vector default_ints; + std::vector default_floats; + std::vector default_bools; + std::vector> default_strings; + std::vector> enum_valid_values; + std::vector>> enum_names; + bool fail_on_errors; +}; + +struct protobuf_field_meta_view { + nested_field_descriptor const& schema; + cudf::data_type const output_type; + int64_t default_int; + double default_float; + bool default_bool; + cudf::detail::host_vector const& default_string; + cudf::detail::host_vector const& enum_valid_values; + std::vector> const& enum_names; +}; + +namespace detail { + +bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_type const& type); + +void validate_decode_context(protobuf_decode_context const& context); + +protobuf_field_meta_view make_field_meta_view(protobuf_decode_context const& context, + int schema_idx); + +} // namespace detail + +std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, + protobuf_decode_context const& context, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +} // namespace spark_rapids_jni::protobuf diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu new file mode 100644 index 0000000000..42420acedf --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "protobuf/protobuf_kernels.cuh" + +#include +#include + +namespace spark_rapids_jni::protobuf::detail { + +std::unique_ptr make_null_column(cudf::data_type dtype, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_rows == 0) { return cudf::make_empty_column(dtype); } + + switch (dtype.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT8: + case cudf::type_id::UINT8: + case cudf::type_id::INT16: + case cudf::type_id::UINT16: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT32: + case cudf::type_id::FLOAT64: + return cudf::make_fixed_width_column(dtype, num_rows, cudf::mask_state::ALL_NULL, stream, mr); + case cudf::type_id::STRING: { + rmm::device_uvector pairs(num_rows, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), + pairs.data(), + pairs.end(), + cudf::strings::detail::string_index_pair{nullptr, 0}); + return cudf::strings::detail::make_strings_column(pairs.data(), pairs.end(), stream, mr); + } + case cudf::type_id::LIST: + return cudf::lists::detail::make_all_nulls_lists_column( + num_rows, cudf::data_type{cudf::type_id::UINT8}, stream, mr); + case cudf::type_id::STRUCT: { + std::vector> empty_children; + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(empty_children), num_rows, std::move(null_mask), stream, mr); + } + default: CUDF_FAIL("Unsupported type for null column creation"); + } +} + +std::unique_ptr make_empty_column_safe(cudf::data_type dtype, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + switch (dtype.id()) { + case cudf::type_id::LIST: { + auto offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); + CUDF_CUDA_TRY(cudaMemsetAsync( + offsets_col->mutable_view().data(), 0, sizeof(int32_t), stream.value())); + auto child_col = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + return cudf::make_lists_column( + 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); + } + case cudf::type_id::STRUCT: { + std::vector> empty_children; + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + default: return cudf::make_empty_column(dtype); + } +} + +std::unique_ptr make_null_list_column_with_child( + std::unique_ptr child_col, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), num_rows, std::move(null_mask)); +} + +std::unique_ptr make_empty_list_column(std::unique_ptr element_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); + CUDF_CUDA_TRY(cudaMemsetAsync( + offsets_col->mutable_view().data(), 0, sizeof(int32_t), stream.value())); + return cudf::make_lists_column( + 0, std::move(offsets_col), std::move(element_col), 0, rmm::device_buffer{}); +} + +} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh new file mode 100644 index 0000000000..2ddb2b27fa --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "protobuf/protobuf_types.cuh" + +#include + +#include + +#include +#include + +namespace spark_rapids_jni::protobuf::detail { + +// ============================================================================ +// Device helper functions +// ============================================================================ + +__device__ inline bool read_varint(uint8_t const* cur, + uint8_t const* end, + uint64_t& out, + int& bytes) +{ + out = 0; + bytes = 0; + int shift = 0; + // Protobuf varint uses 7 bits per byte with MSB as continuation flag. + // A 64-bit value requires at most ceil(64/7) = 10 bytes. + while (cur < end && bytes < MAX_VARINT_BYTES) { + uint8_t b = *cur++; + // For the 10th byte (bytes == 9, shift == 63), only the lowest bit is valid + if (bytes == 9 && (b & 0xFE) != 0) { + return false; // Invalid: 10th byte has more than 1 significant bit + } + out |= (static_cast(b & 0x7Fu) << shift); + bytes++; + if ((b & 0x80u) == 0) { return true; } + shift += 7; + } + return false; +} + +__device__ inline void set_error_once(int* error_flag, int error_code) +{ + int expected = 0; + cuda::atomic_ref ref(*error_flag); + ref.compare_exchange_strong(expected, error_code, cuda::memory_order_relaxed); +} + +void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream); + +__device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) +{ + switch (wt) { + case wire_type_value(proto_wire_type::VARINT): { + // Need to scan to find the end of varint + int count = 0; + while (cur < end && count < MAX_VARINT_BYTES) { + if ((*cur++ & 0x80u) == 0) { return count + 1; } + count++; + } + return -1; // Invalid varint + } + case wire_type_value(proto_wire_type::I64BIT): + // Check if there's enough data for 8 bytes + if (end - cur < 8) return -1; + return 8; + case wire_type_value(proto_wire_type::I32BIT): + // Check if there's enough data for 4 bytes + if (end - cur < 4) return -1; + return 4; + case wire_type_value(proto_wire_type::LEN): { + uint64_t len; + int n; + if (!read_varint(cur, end, len, n)) return -1; + if (len > static_cast(end - cur - n) || + len > static_cast(cuda::std::numeric_limits::max() - n)) { + return -1; + } + return n + static_cast(len); + } + case wire_type_value(proto_wire_type::SGROUP): { + auto const* start = cur; + int depth = 1; + while (cur < end && depth > 0) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) return -1; + cur += key_bytes; + + int inner_wt = static_cast(key & 0x7); + if (inner_wt == wire_type_value(proto_wire_type::EGROUP)) { + --depth; + if (depth == 0) { return static_cast(cur - start); } + } else if (inner_wt == wire_type_value(proto_wire_type::SGROUP)) { + if (++depth > 32) return -1; + } else { + int inner_size = -1; + switch (inner_wt) { + case wire_type_value(proto_wire_type::VARINT): { + uint64_t dummy; + int vbytes; + if (!read_varint(cur, end, dummy, vbytes)) return -1; + inner_size = vbytes; + break; + } + case wire_type_value(proto_wire_type::I64BIT): inner_size = 8; break; + case wire_type_value(proto_wire_type::LEN): { + uint64_t len; + int len_bytes; + if (!read_varint(cur, end, len, len_bytes)) return -1; + if (len > static_cast(cuda::std::numeric_limits::max() - len_bytes)) { + return -1; + } + inner_size = len_bytes + static_cast(len); + break; + } + case wire_type_value(proto_wire_type::I32BIT): inner_size = 4; break; + default: return -1; + } + if (inner_size < 0 || cur + inner_size > end) return -1; + cur += inner_size; + } + } + return -1; + } + case wire_type_value(proto_wire_type::EGROUP): return 0; + default: return -1; + } +} + +__device__ inline bool skip_field(uint8_t const* cur, + uint8_t const* end, + int wt, + uint8_t const*& out_cur) +{ + // A bare end-group is only valid while a start-group payload is being parsed recursively inside + // get_wire_type_size(wire_type_value(proto_wire_type::SGROUP)). + // The scan/count kernels should never accept it as a standalone field because Spark CPU treats + // unmatched end-groups as malformed protobuf. + if (wt == wire_type_value(proto_wire_type::EGROUP)) { return false; } + + int size = get_wire_type_size(wt, cur, end); + if (size < 0) return false; + // Ensure we don't skip past the end of the buffer + if (cur + size > end) return false; + out_cur = cur + size; + return true; +} + +/** + * Get the data offset and length for a field at current position. + * Returns true on success, false on error. + */ +__device__ inline bool get_field_data_location( + uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) +{ + if (wt == wire_type_value(proto_wire_type::LEN)) { + // For length-delimited, read the length prefix + uint64_t len; + int len_bytes; + if (!read_varint(cur, end, len, len_bytes)) return false; + if (len > static_cast(end - cur - len_bytes) || + len > static_cast(cuda::std::numeric_limits::max())) { + return false; + } + data_offset = len_bytes; // offset past the length prefix + data_length = static_cast(len); + } else { + // For fixed-size and varint fields + int field_size = get_wire_type_size(wt, cur, end); + if (field_size < 0) return false; + data_offset = 0; + data_length = field_size; + } + return true; +} + +CUDF_HOST_DEVICE inline size_t flat_index(size_t row, size_t width, size_t col) +{ + return row * width + col; +} + +__device__ inline bool checked_add_int32(int32_t lhs, int32_t rhs, int32_t& out) +{ + auto const sum = static_cast(lhs) + rhs; + if (sum < cuda::std::numeric_limits::min() || + sum > cuda::std::numeric_limits::max()) { + return false; + } + out = static_cast(sum); + return true; +} + +__device__ inline bool check_message_bounds(int32_t start, + int32_t end_pos, + cudf::size_type total_size, + int* error_flag) +{ + if (start < 0 || end_pos < start || end_pos > total_size) { + set_error_once(error_flag, ERR_BOUNDS); + return false; + } + return true; +} + +struct proto_tag { + int field_number; + int wire_type; +}; + +__device__ inline bool decode_tag(uint8_t const*& cur, + uint8_t const* end, + proto_tag& tag, + int* error_flag) +{ + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + + cur += key_bytes; + uint64_t fn = key >> 3; + if (fn == 0 || fn > static_cast(MAX_FIELD_NUMBER)) { + set_error_once(error_flag, ERR_FIELD_NUMBER); + return false; + } + tag.field_number = static_cast(fn); + tag.wire_type = static_cast(key & 0x7); + return true; +} + +/** + * Load a little-endian value from unaligned memory. + * Reads bytes individually to avoid unaligned-access issues on GPU. + */ +template +__device__ inline T load_le(uint8_t const* p); + +template <> +__device__ inline uint32_t load_le(uint8_t const* p) +{ + return static_cast(p[0]) | (static_cast(p[1]) << 8) | + (static_cast(p[2]) << 16) | (static_cast(p[3]) << 24); +} + +template <> +__device__ inline uint64_t load_le(uint8_t const* p) +{ + uint64_t v = 0; +#pragma unroll + for (int i = 0; i < 8; ++i) { + v |= (static_cast(p[i]) << (8 * i)); + } + return v; +} + +/** + * O(1) lookup of field_number -> field_index using a direct-mapped table. + * Falls back to linear search when the table is empty (field numbers too large). + */ +// Keep this definition in the header so all CUDA translation units can inline it. +__device__ __forceinline__ int lookup_field(int field_number, + int const* lookup_table, + int lookup_table_size, + field_descriptor const* field_descs, + int num_fields) +{ + if (lookup_table != nullptr && field_number > 0 && field_number < lookup_table_size) { + return lookup_table[field_number]; + } + for (int f = 0; f < num_fields; f++) { + if (field_descs[f].field_number == field_number) return f; + } + return -1; +} + +} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp new file mode 100644 index 0000000000..d1179dc858 --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "protobuf/protobuf_types.cuh" + +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace spark_rapids_jni::protobuf::detail { + +// ============================================================================ +// Field number lookup table helpers +// ============================================================================ + +/** + * Build a host-side direct-mapped lookup table: field_number -> index. + * @param get_field_number Callable: (int i) -> field_number for the i-th entry. + * @param num_entries Number of entries. + * @return Empty vector if the max field number exceeds the threshold. + */ +template +inline std::vector build_lookup_table(FieldNumberFn get_field_number, int num_entries) +{ + int max_fn = 0; + for (int i = 0; i < num_entries; i++) { + max_fn = std::max(max_fn, get_field_number(i)); + } + if (max_fn > FIELD_LOOKUP_TABLE_MAX) { return {}; } + std::vector table(max_fn + 1, -1); + for (int i = 0; i < num_entries; i++) { + table[get_field_number(i)] = i; + } + return table; +} + +inline std::vector build_index_lookup_table(nested_field_descriptor const* schema, + int const* field_indices, + int num_indices) +{ + return build_lookup_table([&](int i) { return schema[field_indices[i]].field_number; }, + num_indices); +} + +inline std::vector build_field_lookup_table(field_descriptor const* descs, int num_fields) +{ + return build_lookup_table([&](int i) { return descs[i].field_number; }, num_fields); +} + +/** + * Find all child field indices for a given parent index in the schema. + * This is a commonly used pattern throughout the codebase. + * + * @param schema The schema vector (either nested_field_descriptor or + * device_nested_field_descriptor) + * @param num_fields Number of fields in the schema + * @param parent_idx The parent index to search for + * @return Vector of child field indices + */ +template +std::vector find_child_field_indices(SchemaT const& schema, int num_fields, int parent_idx) +{ + std::vector child_indices; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == parent_idx) { child_indices.push_back(i); } + } + return child_indices; +} + +// Forward declarations needed by make_empty_struct_column_with_schema +std::unique_ptr make_empty_column_safe(cudf::data_type dtype, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr make_empty_list_column(std::unique_ptr element_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +template +std::unique_ptr make_empty_struct_column_with_schema( + SchemaT const& schema, + int parent_idx, + int num_fields, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto child_indices = find_child_field_indices(schema, num_fields, parent_idx); + + std::vector> children; + for (int child_idx : child_indices) { + auto child_type = cudf::data_type{schema[child_idx].output_type}; + + std::unique_ptr child_col; + if (child_type.id() == cudf::type_id::STRUCT) { + child_col = make_empty_struct_column_with_schema(schema, child_idx, num_fields, stream, mr); + } else { + child_col = make_empty_column_safe(child_type, stream, mr); + } + + if (schema[child_idx].is_repeated) { + child_col = make_empty_list_column(std::move(child_col), stream, mr); + } + + children.push_back(std::move(child_col)); + } + + return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); +} + +void maybe_check_required_fields(field_location const* locations, + std::vector const& field_indices, + std::vector const& schema, + int num_rows, + cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, + int* error_flag, + rmm::cuda_stream_view stream); + +void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream); + +void validate_enum_and_propagate_rows(rmm::device_uvector const& values, + rmm::device_uvector& valid, + cudf::detail::host_vector const& valid_enums, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream); + +// ============================================================================ +// Forward declarations of builder/utility functions +// ============================================================================ + +std::unique_ptr make_null_column(cudf::data_type dtype, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr make_null_list_column_with_child( + std::unique_ptr child_col, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_enum_string_column( + rmm::device_uvector& enum_values, + rmm::device_uvector& valid, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_force_null, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices = nullptr, + bool propagate_invalid_rows = true); + +// Complex builder forward declarations +std::unique_ptr build_repeated_enum_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_repeated_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + bool is_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_nested_struct_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_parent_locs, + std::vector const& child_field_indices, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows = true); + +std::unique_ptr build_repeated_child_list_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_parent_rows, + int child_schema_idx, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows = true); + +std::unique_ptr build_repeated_struct_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + std::vector const& h_device_schema, + std::vector const& child_field_indices, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector const& schema, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error_top, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu new file mode 100644 index 0000000000..ddd09e881f --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "protobuf/protobuf_kernels.cuh" + +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace spark_rapids_jni::protobuf::detail { + +namespace { + +CUDF_KERNEL void set_error_if_unset_kernel(int* error_flag, int error_code) +{ + if (blockIdx.x == 0 && threadIdx.x == 0) { set_error_once(error_flag, error_code); } +} + +// Stub kernels — replaced with real implementations in follow-up PRs. +CUDF_KERNEL void check_required_fields_kernel(field_location const*, + uint8_t const*, + int, + int, + cudf::bitmask_type const*, + cudf::size_type, + field_location const*, + bool*, + int32_t const*, + int*) +{ +} + +CUDF_KERNEL void validate_enum_values_kernel(int32_t const*, bool*, bool*, int32_t const*, int, int) +{ +} + +} // namespace + +void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream) +{ + set_error_if_unset_kernel<<<1, 1, 0, stream.value()>>>(error_flag, error_code); + CUDF_CUDA_TRY(cudaPeekAtLastError()); +} + +void maybe_check_required_fields(field_location const* locations, + std::vector const& field_indices, + std::vector const& schema, + int num_rows, + cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0 || field_indices.empty()) { return; } + + bool has_required = false; + auto h_is_required = + cudf::detail::make_pinned_vector_async(field_indices.size(), stream); + for (size_t i = 0; i < field_indices.size(); ++i) { + h_is_required[i] = schema[field_indices[i]].is_required ? 1 : 0; + has_required |= (h_is_required[i] != 0); + } + if (!has_required) { return; } + + auto d_is_required = cudf::detail::make_device_uvector_async( + h_is_required, stream, cudf::get_current_device_resource_ref()); + + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + check_required_fields_kernel<<>>( + locations, + d_is_required.data(), + static_cast(field_indices.size()), + num_rows, + input_null_mask, + input_offset, + parent_locs, + row_force_null, + top_row_indices, + error_flag); +} + +void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream) +{ + if (num_items == 0 || row_invalid.size() == 0 || !propagate_to_rows) { return; } + + if (top_row_indices == nullptr) { + CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), + "enum invalid-row propagation exceeded row buffer"); + thrust::transform(rmm::exec_policy_nosync(stream), + row_invalid.begin(), + row_invalid.begin() + num_items, + item_invalid.begin(), + row_invalid.begin(), + [] __device__(bool row_is_invalid, bool item_is_invalid) { + return row_is_invalid || item_is_invalid; + }); + return; + } + + thrust::for_each(rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_items), + [item_invalid = item_invalid.data(), + top_row_indices, + row_invalid = row_invalid.data()] __device__(int idx) { + if (item_invalid[idx]) { row_invalid[top_row_indices[idx]] = true; } + }); +} + +void validate_enum_and_propagate_rows(rmm::device_uvector const& values, + rmm::device_uvector& valid, + cudf::detail::host_vector const& valid_enums, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream) +{ + if (num_items == 0 || valid_enums.empty()) { return; } + + auto const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + auto d_valid_enums = cudf::detail::make_device_uvector_async( + valid_enums, stream, cudf::get_current_device_resource_ref()); + + rmm::device_uvector item_invalid( + num_items, stream, cudf::get_current_device_resource_ref()); + thrust::fill(rmm::exec_policy_nosync(stream), item_invalid.begin(), item_invalid.end(), false); + validate_enum_values_kernel<<>>( + values.data(), + valid.data(), + item_invalid.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_items); + + propagate_invalid_enum_flags_to_rows( + item_invalid, row_invalid, num_items, top_row_indices, propagate_to_rows, stream); +} + +} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh new file mode 100644 index 0000000000..38917b6a83 --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -0,0 +1,937 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "protobuf/protobuf_device_helpers.cuh" +#include "protobuf/protobuf_host_helpers.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace spark_rapids_jni::protobuf::detail { + +// ============================================================================ +// Pass 2: Extract data kernels +// ============================================================================ + +// ============================================================================ +// Data Extraction Location Providers +// ============================================================================ + +struct top_level_location_provider { + cudf::size_type const* offsets; + cudf::size_type base_offset; + field_location const* locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto loc = locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; + if (loc.offset >= 0) { data_offset = offsets[thread_idx] - base_offset + loc.offset; } + return loc; + } +}; + +struct repeated_location_provider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + data_offset = row_offsets[occ.row_idx] - base_offset + occ.offset; + return {occ.offset, occ.length}; + } +}; + +struct nested_location_provider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + field_location const* child_locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto ploc = parent_locations[thread_idx]; + auto cloc = child_locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; + if (ploc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + ploc.offset + cloc.offset; + } else { + cloc.offset = -1; + } + return cloc; + } +}; + +struct nested_repeated_location_provider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + auto ploc = parent_locations[occ.row_idx]; + if (ploc.offset >= 0) { + data_offset = row_offsets[occ.row_idx] - base_offset + ploc.offset + occ.offset; + return {occ.offset, occ.length}; + } + data_offset = 0; + return {-1, 0}; + } +}; + +struct repeated_msg_child_location_provider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* msg_locations; + field_location const* child_locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto mloc = msg_locations[thread_idx]; + auto cloc = child_locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; + if (mloc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + mloc.offset + cloc.offset; + } else { + cloc.offset = -1; + } + return cloc; + } +}; + +template +CUDF_KERNEL void extract_varint_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + // For BOOL8 (uint8_t), protobuf spec says any non-zero varint is true. + // A raw static_cast would silently truncate values >= 256 to 0. + auto const write_value = [](OutputType* dst, uint64_t val) { + if constexpr (cuda::std::is_same_v) { + *dst = static_cast(val != 0 ? 1 : 0); + } else { + *dst = static_cast(val); + } + }; + + if (loc.offset < 0) { + if (has_default) { + write_value(&out[idx], static_cast(default_value)); + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = message_data + data_offset; + uint8_t const* cur_end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, cur_end, v, n)) { + set_error_once(error_flag, ERR_VARINT); + if (valid) valid[idx] = false; + return; + } + + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + write_value(&out[idx], v); + if (valid) valid[idx] = true; +} + +template +CUDF_KERNEL void extract_fixed_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + OutputType default_value = OutputType{}) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset < 0) { + if (has_default) { + out[idx] = default_value; + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = message_data + data_offset; + OutputType value; + + if constexpr (WT == wire_type_value(proto_wire_type::I32BIT)) { + if (loc.length < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (loc.length < 8) { + set_error_once(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + + out[idx] = value; + if (valid) valid[idx] = true; +} + +// ============================================================================ +// Batched scalar extraction — one 2D kernel for N fields of the same type +// ============================================================================ + +struct batched_scalar_desc { + int loc_field_idx; // index into the locations array (column within d_locations) + void* output; // pre-allocated output buffer (T*) + bool* valid; // pre-allocated validity buffer + bool has_default; + int64_t default_int; + double default_float; +}; + +template +CUDF_KERNEL void extract_varint_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) +{ + int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + int fi = static_cast(blockIdx.y); + if (row >= num_rows || fi >= num_descs) return; + + auto const& desc = descs[fi]; + auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; + auto* out = static_cast(desc.output); + + auto const write_value = [](OutputType* dst, uint64_t val) { + if constexpr (cuda::std::is_same_v) { + *dst = static_cast(val != 0 ? 1 : 0); + } else { + *dst = static_cast(val); + } + }; + + if (loc.offset < 0) { + if (desc.has_default) { + write_value(&out[row], static_cast(desc.default_int)); + desc.valid[row] = true; + } else { + desc.valid[row] = false; + } + return; + } + + int32_t data_offset = row_offsets[row] - base_offset + loc.offset; + uint8_t const* cur = message_data + data_offset; + uint8_t const* end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, end, v, n)) { + set_error_once(error_flag, ERR_VARINT); + desc.valid[row] = false; + return; + } + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + write_value(&out[row], v); + desc.valid[row] = true; +} + +template +CUDF_KERNEL void extract_fixed_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) +{ + int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + int fi = static_cast(blockIdx.y); + if (row >= num_rows || fi >= num_descs) return; + + auto const& desc = descs[fi]; + auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; + auto* out = static_cast(desc.output); + + if (loc.offset < 0) { + if (desc.has_default) { + if constexpr (cuda::std::is_integral_v) { + out[row] = static_cast(desc.default_int); + } else { + out[row] = static_cast(desc.default_float); + } + desc.valid[row] = true; + } else { + desc.valid[row] = false; + } + return; + } + + int32_t data_offset = row_offsets[row] - base_offset + loc.offset; + uint8_t const* cur = message_data + data_offset; + OutputType value; + + if constexpr (WT == wire_type_value(proto_wire_type::I32BIT)) { + if (loc.length < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + desc.valid[row] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (loc.length < 8) { + set_error_once(error_flag, ERR_FIXED_LEN); + desc.valid[row] = false; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + out[row] = value; + desc.valid[row] = true; +} + +// ============================================================================ + +template +CUDF_KERNEL void extract_lengths_kernel(LocationProvider loc_provider, + int total_items, + int32_t* out_lengths, + bool has_default = false, + int32_t default_length = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset >= 0) { + out_lengths[idx] = loc.length; + } else if (has_default) { + out_lengths[idx] = default_length; + } else { + out_lengths[idx] = 0; + } +} + +// ============================================================================ +// Host-side template helpers that launch CUDA kernels +// ============================================================================ + +template +inline std::pair make_null_mask_from_valid( + rmm::device_uvector const& valid, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto begin = thrust::make_counting_iterator(0); + auto end = begin + valid.size(); + auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { + return static_cast(ptr[i]); + }; + return cudf::detail::valid_if(begin, end, pred, stream, mr); +} + +template +std::unique_ptr extract_and_build_scalar_column(cudf::data_type dt, + int num_rows, + LaunchFn&& launch_extract, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); + if (num_rows == 0) { + return std::make_unique(dt, 0, out.release(), rmm::device_buffer{}, 0); + } + launch_extract(out.data(), valid.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count); +} + +template +inline void extract_integer_into_buffers(uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + T* out_ptr, + bool* valid_ptr, + int* error_ptr, + rmm::cuda_stream_view stream) +{ + if (enable_zigzag && encoding == encoding_value(proto_encoding::ZIGZAG)) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } else if (encoding == encoding_value(proto_encoding::FIXED)) { + if constexpr (sizeof(T) == 4) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } else { + static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } + } else { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } +} + +template +std::unique_ptr extract_and_build_integer_column(cudf::data_type dt, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + rmm::device_uvector& d_error, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + return extract_and_build_scalar_column( + dt, + num_rows, + [&](T* out_ptr, bool* valid_ptr) { + extract_integer_into_buffers(message_data, + loc_provider, + num_rows, + blocks, + threads, + has_default, + default_value, + encoding, + enable_zigzag, + out_ptr, + valid_ptr, + d_error.data(), + stream); + }, + stream, + mr); +} + +struct extract_strided_count { + repeated_field_info const* info; + int field_idx; + int num_fields; + + __device__ int32_t operator()(int row) const + { + return info[flat_index(static_cast(row), + static_cast(num_fields), + static_cast(field_idx))] + .count; + } +}; + +template +inline std::unique_ptr extract_and_build_string_or_bytes_column( + bool as_bytes, + uint8_t const* message_data, + int num_rows, + LengthProvider const& length_provider, + CopyProvider const& copy_provider, + ValidityFn validity_fn, + bool has_default, + cudf::detail::host_vector const& default_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + int32_t def_len = has_default ? static_cast(default_bytes.size()) : 0; + rmm::device_uvector d_default(0, stream, mr); + if (has_default && def_len > 0) { + d_default = cudf::detail::make_device_uvector_async( + default_bytes, stream, cudf::get_current_device_resource_ref()); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((num_rows + threads - 1u) / threads); + extract_lengths_kernel<<>>( + length_provider, num_rows, lengths.data(), has_default, def_len); + + auto [offsets_col, total_size] = + cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); + + rmm::device_uvector chars(total_size, stream, mr); + if (total_size > 0) { + auto const* offsets_data = offsets_col->view().data(); + auto* chars_ptr = chars.data(); + auto const* default_ptr = d_default.data(); + + auto src_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [message_data, copy_provider, has_default, default_ptr, def_len] __device__( + int idx) -> void const* { + int32_t data_offset = 0; + auto loc = copy_provider.get(idx, data_offset); + if (loc.offset < 0) { + return (has_default && def_len > 0) ? static_cast(default_ptr) : nullptr; + } + return static_cast(message_data + data_offset); + })); + auto dst_iter = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { + return static_cast(chars_ptr + offsets_data[idx]); + })); + auto size_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [copy_provider, has_default, def_len] __device__(int idx) -> size_t { + int32_t data_offset = 0; + auto loc = copy_provider.get(idx, data_offset); + if (loc.offset < 0) { + return (has_default && def_len > 0) ? static_cast(def_len) : 0; + } + return static_cast(loc.length); + })); + + size_t temp_storage_bytes = 0; + cub::DeviceMemcpy::Batched( + nullptr, temp_storage_bytes, src_iter, dst_iter, size_iter, num_rows, stream.value()); + rmm::device_buffer temp_storage(temp_storage_bytes, stream, mr); + cub::DeviceMemcpy::Batched(temp_storage.data(), + temp_storage_bytes, + src_iter, + dst_iter, + size_iter, + num_rows, + stream.value()); + } + + if (num_rows == 0) { + if (as_bytes) { + auto bytes_child = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + return cudf::make_lists_column( + 0, std::move(offsets_col), std::move(bytes_child), 0, rmm::device_buffer{}); + } + return cudf::make_strings_column( + 0, std::move(offsets_col), chars.release(), 0, rmm::device_buffer{}); + } + + rmm::device_uvector valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.data(), + validity_fn); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + if (as_bytes) { + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_size, + rmm::device_buffer(chars.data(), total_size, stream, mr), + rmm::device_buffer{}, + 0); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask)); + } + + return cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); +} + +template +inline std::unique_ptr extract_typed_column( + cudf::data_type dt, + int encoding, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_items, + int blocks, + int threads_per_block, + bool has_default, + int64_t default_int, + double default_float, + bool default_bool, + cudf::detail::host_vector const& default_string, + int schema_idx, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices = nullptr, + bool propagate_invalid_rows = true) +{ + switch (dt.id()) { + case cudf::type_id::BOOL8: { + int64_t def_val = has_default ? (default_bool ? 1 : 0) : 0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](uint8_t* out_ptr, bool* valid_ptr) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_val); + }, + stream, + mr); + } + case cudf::type_id::INT32: { + if (num_items == 0) { + return std::make_unique(dt, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + } + rmm::device_uvector out(num_items, stream, mr); + rmm::device_uvector valid(num_items, stream, mr); + extract_integer_into_buffers(message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + has_default, + default_int, + encoding, + true, + out.data(), + valid.data(), + d_error.data(), + stream); + if (schema_idx < static_cast(enum_valid_values.size())) { + auto const& valid_enums = enum_valid_values[schema_idx]; + if (!valid_enums.empty()) { + validate_enum_and_propagate_rows(out, + valid, + valid_enums, + d_row_force_null, + num_items, + top_row_indices, + propagate_invalid_rows, + stream); + } + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique( + dt, num_items, out.release(), std::move(mask), null_count); + } + case cudf::type_id::UINT32: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::INT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + true, + stream, + mr); + case cudf::type_id::UINT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::FLOAT32: { + float def_float_val = has_default ? static_cast(default_float) : 0.0f; + return extract_and_build_scalar_column( + dt, + num_items, + [&](float* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_float_val); + }, + stream, + mr); + } + case cudf::type_id::FLOAT64: { + double def_double = has_default ? default_float : 0.0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](double* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_double); + }, + stream, + mr); + } + default: return make_null_column(dt, num_items, stream, mr); + } +} + +template +inline std::unique_ptr build_repeated_scalar_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const input_null_count = binary_input.null_count(); + + if (total_count == 0) { + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto elem_type = field_desc.output_type_id == static_cast(cudf::type_id::LIST) + ? cudf::type_id::UINT8 + : static_cast(field_desc.output_type_id); + auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); + + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } else { + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); + } + } + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + d_field_counts.begin(), + d_field_counts.end(), + list_offs.begin(), + 0); + + int32_t total_count_i32 = static_cast(total_count); + thrust::fill_n(rmm::exec_policy_nosync(stream), list_offs.data() + num_rows, 1, total_count_i32); + + rmm::device_uvector values(total_count, stream, mr); + + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((total_count + threads - 1u) / threads); + + int encoding = field_desc.encoding; + bool zigzag = (encoding == encoding_value(proto_encoding::ZIGZAG)); + + constexpr bool is_floating_point = std::is_same_v || std::is_same_v; + bool use_fixed_kernel = is_floating_point || (encoding == encoding_value(proto_encoding::FIXED)); + + repeated_location_provider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + if (use_fixed_kernel) { + if constexpr (sizeof(T) == 4) { + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } else { + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } + } else if (zigzag) { + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } else { + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } + + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + auto child_col = std::make_unique( + cudf::data_type{static_cast(field_desc.output_type_id)}, + total_count, + values.release(), + rmm::device_buffer{}, + 0); + + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } + + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); +} + +} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_types.cuh b/src/main/cpp/src/protobuf/protobuf_types.cuh new file mode 100644 index 0000000000..c575447ded --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_types.cuh @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "protobuf/protobuf.hpp" + +namespace spark_rapids_jni::protobuf::detail { + +// Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. +constexpr int MAX_VARINT_BYTES = 10; + +// CUDA kernel launch configuration. +constexpr int THREADS_PER_BLOCK = 256; + +// Error codes for kernel error reporting. +constexpr int ERR_BOUNDS = 1; +constexpr int ERR_VARINT = 2; +constexpr int ERR_FIELD_NUMBER = 3; +constexpr int ERR_WIRE_TYPE = 4; +constexpr int ERR_OVERFLOW = 5; +constexpr int ERR_FIELD_SIZE = 6; +constexpr int ERR_SKIP = 7; +constexpr int ERR_FIXED_LEN = 8; +constexpr int ERR_REQUIRED = 9; +constexpr int ERR_SCHEMA_TOO_LARGE = 10; +constexpr int ERR_MISSING_ENUM_META = 11; +constexpr int ERR_REPEATED_COUNT_MISMATCH = 12; + +// Threshold for using a direct-mapped lookup table for field_number -> field_index. +// Field numbers above this threshold fall back to linear search. +constexpr int FIELD_LOOKUP_TABLE_MAX = 4096; + +/** + * Structure to record field location within a message. + * offset < 0 means field was not found. + */ +struct field_location { + int32_t offset; // Offset of field data within the message (-1 if not found) + int32_t length; // Length of field data in bytes +}; + +/** + * Field descriptor passed to the scanning kernel. + */ +struct field_descriptor { + int field_number; // Protobuf field number + int expected_wire_type; // Expected wire type for this field + bool is_repeated; // Repeated children are scanned via count/scan kernels +}; + +/** + * Information about repeated field occurrences in a row. + */ +struct repeated_field_info { + int32_t count; // Number of occurrences in this row + int32_t total_length; // Total bytes for all occurrences (for varlen fields) +}; + +/** + * Location of a single occurrence of a repeated field. + */ +struct repeated_occurrence { + int32_t row_idx; // Which row this occurrence belongs to + int32_t offset; // Offset within the message + int32_t length; // Length of the field data +}; + +/** + * Per-field descriptor passed to the combined occurrence scan kernel. + * Contains device pointers so the kernel can write to each field's output. + */ +struct repeated_field_scan_desc { + int field_number; + int wire_type; + int32_t const* row_offsets; // Pre-computed prefix-sum offsets [num_rows + 1] + repeated_occurrence* occurrences; // Output buffer [total_count] +}; + +/** + * Device-side descriptor for nested schema fields. + */ +struct device_nested_field_descriptor { + int field_number; + int parent_idx; + int depth; + int wire_type; + int output_type_id; + int encoding; + bool is_repeated; + bool is_required; + bool has_default_value; + + device_nested_field_descriptor() = default; + + // Wire type and encoding are stored as int (not typed enums) because CUDA device code + // historically had limited constexpr enum support, and the kernel comparison sites use + // int-typed wire_type_value()/encoding_value() helpers throughout. + explicit device_nested_field_descriptor(nested_field_descriptor const& src) + : field_number(src.field_number), + parent_idx(src.parent_idx), + depth(src.depth), + wire_type(static_cast(src.wire_type)), + output_type_id(static_cast(src.output_type)), + encoding(static_cast(src.encoding)), + is_repeated(src.is_repeated), + is_required(src.is_required), + has_default_value(src.has_default_value) + { + } +}; + +} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java new file mode 100644 index 0000000000..512ef9e3c7 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.NativeDepsLoader; + +/** + * GPU protobuf decoding utilities. + * + * This API uses a multi-pass approach for efficient decoding: + *
    + *
  • Pass 1: Scan all messages, count nested elements and repeated field occurrences
  • + *
  • Pass 2: Prefix sum to compute output offsets for arrays and nested structs
  • + *
  • Pass 3: Extract data using pre-computed offsets
  • + *
  • Pass 4: Build nested column structure
  • + *
+ * + * The schema is represented as a flattened array of field descriptors with parent-child + * relationships. Top-level fields have parentIndices == -1 and depthLevels == 0. + * For pure scalar schemas, all fields are top-level with isRepeated == false. + * + * Supported protobuf field types include: + *
    + *
  • VARINT: {@code int32}, {@code int64}, {@code uint32}, {@code uint64}, {@code bool}
  • + *
  • ZIGZAG VARINT (encoding=2): {@code sint32}, {@code sint64}
  • + *
  • FIXED32 (encoding=1): {@code fixed32}, {@code sfixed32}, {@code float}
  • + *
  • FIXED64 (encoding=1): {@code fixed64}, {@code sfixed64}, {@code double}
  • + *
  • LENGTH_DELIMITED: {@code string}, {@code bytes}, nested {@code message}
  • + *
  • Nested messages and repeated fields
  • + *
+ * + *

In permissive mode ({@code failOnErrors=false}), if decoding encounters a row-local parse + * error from which it cannot safely recover its cursor position (for example, an unexpected wire + * type or malformed varint), scanning for that row stops at the error position. Fields that appear + * later in the same message are therefore treated as "not found" and follow the normal + * missing-field semantics (nulls or defaults, depending on the schema metadata). + */ +public class Protobuf { + static { + NativeDepsLoader.loadNativeDeps(); + } + + public static final int ENC_DEFAULT = 0; + public static final int ENC_FIXED = 1; + public static final int ENC_ZIGZAG = 2; + public static final int ENC_ENUM_STRING = 3; + + // Wire type constants + public static final int WT_VARINT = 0; + public static final int WT_64BIT = 1; + public static final int WT_LEN = 2; + public static final int WT_32BIT = 5; + + /** + * Decode protobuf messages into a STRUCT column. + * + * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. + * @param schema descriptor containing flattened schema arrays (field numbers, types, defaults, etc.) + * @param failOnErrors if true, throw an exception on malformed protobuf messages. If false, + * malformed rows are handled permissively; when a row-local parse error + * prevents safe resynchronization, later fields in that same row are treated + * as absent rather than continuing from an uncertain cursor position. + * @return a cudf STRUCT column with nested structure. + */ + public static ColumnVector decodeToStruct(ColumnView binaryInput, + ProtobufSchemaDescriptor schema, + boolean failOnErrors) { + if (binaryInput == null) { + throw new IllegalArgumentException("binaryInput must not be null"); + } + if (schema == null) { + throw new IllegalArgumentException("schema must not be null"); + } + long handle = decodeToStruct(binaryInput.getNativeView(), + schema.fieldNumbers, schema.parentIndices, schema.depthLevels, + schema.wireTypes, schema.outputTypeIds, schema.encodings, + schema.isRepeated, schema.isRequired, schema.hasDefaultValue, + schema.defaultInts, schema.defaultFloats, schema.defaultBools, + schema.defaultStrings, schema.enumValidValues, schema.enumNames, failOnErrors); + return new ColumnVector(handle); + } + + private static native long decodeToStruct(long binaryInputView, + int[] fieldNumbers, + int[] parentIndices, + int[] depthLevels, + int[] wireTypes, + int[] outputTypeIds, + int[] encodings, + boolean[] isRepeated, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + byte[][][] enumNames, + boolean failOnErrors); +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java new file mode 100644 index 0000000000..810cfb60a9 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import java.util.HashSet; +import java.util.Set; + +/** + * Immutable descriptor for a flattened protobuf schema, grouping the parallel arrays + * that describe field structure, types, defaults, and enum metadata. + * + *

Use this class instead of passing 15+ individual arrays through the JNI boundary. + * Validation is performed once in the constructor (and again on deserialization). + * + *

All arrays provided to the constructor are defensively copied to guarantee immutability. + * During deserialization, {@code defaultReadObject()} reconstructs a fresh object graph and + * {@link #readObject(java.io.ObjectInputStream)} re-validates the schema invariants before the + * instance becomes visible. Package-private field access from {@link Protobuf} is therefore safe + * because constructor callers cannot retain mutable aliases into the stored arrays. + */ +public final class ProtobufSchemaDescriptor implements java.io.Serializable { + private static final long serialVersionUID = 1L; + private static final int MAX_FIELD_NUMBER = (1 << 29) - 1; + private static final int MAX_NESTING_DEPTH = 10; + private static final int STRUCT_TYPE_ID = ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(); + private static final int STRING_TYPE_ID = ai.rapids.cudf.DType.STRING.getTypeId().getNativeId(); + private static final int LIST_TYPE_ID = ai.rapids.cudf.DType.LIST.getTypeId().getNativeId(); + private static final int BOOL8_TYPE_ID = ai.rapids.cudf.DType.BOOL8.getTypeId().getNativeId(); + private static final int INT32_TYPE_ID = ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(); + private static final int UINT32_TYPE_ID = ai.rapids.cudf.DType.UINT32.getTypeId().getNativeId(); + private static final int INT64_TYPE_ID = ai.rapids.cudf.DType.INT64.getTypeId().getNativeId(); + private static final int UINT64_TYPE_ID = ai.rapids.cudf.DType.UINT64.getTypeId().getNativeId(); + private static final int FLOAT32_TYPE_ID = ai.rapids.cudf.DType.FLOAT32.getTypeId().getNativeId(); + private static final int FLOAT64_TYPE_ID = ai.rapids.cudf.DType.FLOAT64.getTypeId().getNativeId(); + + final int[] fieldNumbers; + final int[] parentIndices; + final int[] depthLevels; + final int[] wireTypes; + final int[] outputTypeIds; + final int[] encodings; + final boolean[] isRepeated; + final boolean[] isRequired; + final boolean[] hasDefaultValue; + final long[] defaultInts; + final double[] defaultFloats; + final boolean[] defaultBools; + final byte[][] defaultStrings; + final int[][] enumValidValues; + final byte[][][] enumNames; + + /** + * @throws IllegalArgumentException if any array is null, arrays have mismatched lengths, + * field numbers are out of range, or encoding values are invalid. + */ + public ProtobufSchemaDescriptor( + int[] fieldNumbers, + int[] parentIndices, + int[] depthLevels, + int[] wireTypes, + int[] outputTypeIds, + int[] encodings, + boolean[] isRepeated, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + byte[][][] enumNames) { + + validate(fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, + encodings, isRepeated, isRequired, hasDefaultValue, defaultInts, + defaultFloats, defaultBools, defaultStrings, enumValidValues, enumNames); + + this.fieldNumbers = fieldNumbers.clone(); + this.parentIndices = parentIndices.clone(); + this.depthLevels = depthLevels.clone(); + this.wireTypes = wireTypes.clone(); + this.outputTypeIds = outputTypeIds.clone(); + this.encodings = encodings.clone(); + this.isRepeated = isRepeated.clone(); + this.isRequired = isRequired.clone(); + this.hasDefaultValue = hasDefaultValue.clone(); + this.defaultInts = defaultInts.clone(); + this.defaultFloats = defaultFloats.clone(); + this.defaultBools = defaultBools.clone(); + this.defaultStrings = deepCopy(defaultStrings); + this.enumValidValues = deepCopy(enumValidValues); + this.enumNames = deepCopy(enumNames); + } + + public int numFields() { return fieldNumbers.length; } + + private void readObject(java.io.ObjectInputStream in) + throws java.io.IOException, ClassNotFoundException { + // defaultReadObject() reconstructs new array objects from the serialized stream; we do not + // receive caller-owned array aliases here. Re-run validate() so deserialization cannot bypass + // the constructor's schema invariants. + in.defaultReadObject(); + try { + validate(fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, + encodings, isRepeated, isRequired, hasDefaultValue, defaultInts, + defaultFloats, defaultBools, defaultStrings, enumValidValues, enumNames); + } catch (IllegalArgumentException e) { + java.io.InvalidObjectException ioe = new java.io.InvalidObjectException(e.getMessage()); + ioe.initCause(e); + throw ioe; + } + } + + private static byte[][] deepCopy(byte[][] src) { + byte[][] dst = new byte[src.length][]; + for (int i = 0; i < src.length; i++) { + dst[i] = src[i] != null ? src[i].clone() : null; + } + return dst; + } + + private static int[][] deepCopy(int[][] src) { + int[][] dst = new int[src.length][]; + for (int i = 0; i < src.length; i++) { + dst[i] = src[i] != null ? src[i].clone() : null; + } + return dst; + } + + private static byte[][][] deepCopy(byte[][][] src) { + byte[][][] dst = new byte[src.length][][]; + for (int i = 0; i < src.length; i++) { + if (src[i] == null) continue; + dst[i] = new byte[src[i].length][]; + for (int j = 0; j < src[i].length; j++) { + dst[i][j] = src[i][j] != null ? src[i][j].clone() : null; + } + } + return dst; + } + + private static void validate( + int[] fieldNumbers, int[] parentIndices, int[] depthLevels, + int[] wireTypes, int[] outputTypeIds, int[] encodings, + boolean[] isRepeated, boolean[] isRequired, boolean[] hasDefaultValue, + long[] defaultInts, double[] defaultFloats, boolean[] defaultBools, + byte[][] defaultStrings, int[][] enumValidValues, byte[][][] enumNames) { + + if (fieldNumbers == null || parentIndices == null || depthLevels == null || + wireTypes == null || outputTypeIds == null || encodings == null || + isRepeated == null || isRequired == null || hasDefaultValue == null || + defaultInts == null || defaultFloats == null || defaultBools == null || + defaultStrings == null || enumValidValues == null || enumNames == null) { + throw new IllegalArgumentException("All schema arrays must be non-null"); + } + + int n = fieldNumbers.length; + if (parentIndices.length != n || depthLevels.length != n || + wireTypes.length != n || outputTypeIds.length != n || + encodings.length != n || isRepeated.length != n || + isRequired.length != n || hasDefaultValue.length != n || + defaultInts.length != n || defaultFloats.length != n || + defaultBools.length != n || defaultStrings.length != n || + enumValidValues.length != n || enumNames.length != n) { + throw new IllegalArgumentException("All schema arrays must have the same length"); + } + + Set seenFieldNumbers = new HashSet<>(); + for (int i = 0; i < n; i++) { + if (fieldNumbers[i] <= 0 || fieldNumbers[i] > MAX_FIELD_NUMBER) { + throw new IllegalArgumentException( + "Invalid field number at index " + i + ": " + fieldNumbers[i] + + " (must be 1-" + MAX_FIELD_NUMBER + ")"); + } + if (depthLevels[i] < 0 || depthLevels[i] >= MAX_NESTING_DEPTH) { + throw new IllegalArgumentException( + "Invalid depth at index " + i + ": " + depthLevels[i] + + " (must be 0-" + (MAX_NESTING_DEPTH - 1) + ")"); + } + int pi = parentIndices[i]; + if (pi < -1 || pi >= i) { + throw new IllegalArgumentException( + "Invalid parent index at index " + i + ": " + pi + + " (must be -1 or a prior index < " + i + ")"); + } + if (pi == -1) { + if (depthLevels[i] != 0) { + throw new IllegalArgumentException( + "Top-level field at index " + i + " must have depth 0, got " + depthLevels[i]); + } + } else { + if (outputTypeIds[pi] != STRUCT_TYPE_ID) { + throw new IllegalArgumentException( + "Parent at index " + pi + " for field " + i + " must be STRUCT, got type id " + + outputTypeIds[pi]); + } + if (depthLevels[i] != depthLevels[pi] + 1) { + throw new IllegalArgumentException( + "Field at index " + i + " depth (" + depthLevels[i] + + ") must be parent depth (" + depthLevels[pi] + ") + 1"); + } + } + long fieldKey = (((long) pi) << 32) | (fieldNumbers[i] & 0xFFFFFFFFL); + if (!seenFieldNumbers.add(fieldKey)) { + throw new IllegalArgumentException( + "Duplicate field number " + fieldNumbers[i] + + " under parent index " + pi + " at schema index " + i); + } + int wt = wireTypes[i]; + if (wt != 0 && wt != 1 && wt != 2 && wt != 5) { + throw new IllegalArgumentException( + "Invalid wire type at index " + i + ": " + wt + + " (must be one of {0, 1, 2, 5})"); + } + int enc = encodings[i]; + if (enc < Protobuf.ENC_DEFAULT || enc > Protobuf.ENC_ENUM_STRING) { + throw new IllegalArgumentException( + "Invalid encoding at index " + i + ": " + enc); + } + if (!isEncodingCompatible(wt, outputTypeIds[i], enc)) { + throw new IllegalArgumentException( + "Incompatible wire type / output type / encoding at index " + i + + ": wireType=" + wt + ", outputTypeId=" + outputTypeIds[i] + ", encoding=" + enc); + } + if (isRepeated[i] && isRequired[i]) { + throw new IllegalArgumentException( + "Field at index " + i + " cannot be both repeated and required"); + } + if (isRepeated[i] && hasDefaultValue[i]) { + throw new IllegalArgumentException( + "Repeated field at index " + i + " cannot carry a default value"); + } + if (hasDefaultValue[i] && + (outputTypeIds[i] == STRUCT_TYPE_ID || outputTypeIds[i] == LIST_TYPE_ID)) { + throw new IllegalArgumentException( + "STRUCT/LIST field at index " + i + " cannot carry a default value"); + } + if (enc == Protobuf.ENC_ENUM_STRING && + (enumValidValues[i] == null || enumValidValues[i].length == 0 || + enumNames[i] == null || enumNames[i].length == 0)) { + throw new IllegalArgumentException( + "Enum-as-string field at index " + i + + " must provide non-empty enumValidValues and enumNames"); + } + if (enumValidValues[i] != null) { + int[] ev = enumValidValues[i]; + for (int j = 1; j < ev.length; j++) { + if (ev[j] <= ev[j - 1]) { + throw new IllegalArgumentException( + "enumValidValues[" + i + "] must be strictly sorted in ascending order " + + "(binary search requires unique values), but found " + ev[j - 1] + + " followed by " + ev[j]); + } + } + if (enumNames[i] != null && enumNames[i].length != ev.length) { + throw new IllegalArgumentException( + "enumNames[" + i + "].length (" + enumNames[i].length + ") must equal " + + "enumValidValues[" + i + "].length (" + ev.length + ")"); + } + } else if (enumNames[i] != null) { + throw new IllegalArgumentException( + "enumNames[" + i + "] is non-null but enumValidValues[" + i + "] is null; " + + "both must be provided together for enum-as-string fields"); + } + } + } + + private static boolean isEncodingCompatible(int wireType, int outputTypeId, int encoding) { + switch (encoding) { + case Protobuf.ENC_DEFAULT: + if (outputTypeId == BOOL8_TYPE_ID || outputTypeId == INT32_TYPE_ID || + outputTypeId == UINT32_TYPE_ID || outputTypeId == INT64_TYPE_ID || + outputTypeId == UINT64_TYPE_ID) { + return wireType == Protobuf.WT_VARINT; + } + if (outputTypeId == FLOAT32_TYPE_ID) { + return wireType == Protobuf.WT_32BIT; + } + if (outputTypeId == FLOAT64_TYPE_ID) { + return wireType == Protobuf.WT_64BIT; + } + if (outputTypeId == STRING_TYPE_ID || outputTypeId == LIST_TYPE_ID || + outputTypeId == STRUCT_TYPE_ID) { + return wireType == Protobuf.WT_LEN; + } + return false; + case Protobuf.ENC_FIXED: + if (outputTypeId == INT32_TYPE_ID || outputTypeId == UINT32_TYPE_ID || + outputTypeId == FLOAT32_TYPE_ID) { + return wireType == Protobuf.WT_32BIT; + } + if (outputTypeId == INT64_TYPE_ID || outputTypeId == UINT64_TYPE_ID || + outputTypeId == FLOAT64_TYPE_ID) { + return wireType == Protobuf.WT_64BIT; + } + return false; + case Protobuf.ENC_ZIGZAG: + return wireType == Protobuf.WT_VARINT && + (outputTypeId == INT32_TYPE_ID || outputTypeId == INT64_TYPE_ID); + case Protobuf.ENC_ENUM_STRING: + return wireType == Protobuf.WT_VARINT && outputTypeId == STRING_TYPE_ID; + default: + return false; + } + } +} diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java new file mode 100644 index 0000000000..eabe8e58d0 --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ProtobufSchemaDescriptorTest { + private ProtobufSchemaDescriptor makeDescriptor( + boolean isRepeated, + boolean hasDefaultValue, + int encoding, + int[] enumValidValues, + byte[][] enumNames) { + int outputType = (encoding == Protobuf.ENC_ENUM_STRING) + ? ai.rapids.cudf.DType.STRING.getTypeId().getNativeId() + : ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(); + return new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_VARINT}, + new int[]{outputType}, + new int[]{encoding}, + new boolean[]{isRepeated}, + new boolean[]{false}, + new boolean[]{hasDefaultValue}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{enumValidValues}, + new byte[][][]{enumNames}); + } + + @Test + void testRepeatedFieldCannotCarryDefaultValue() { + assertThrows(IllegalArgumentException.class, () -> + makeDescriptor(true, true, Protobuf.ENC_DEFAULT, null, null)); + } + + @Test + void testFieldCannotBeBothRepeatedAndRequired() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_VARINT}, + new int[]{ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{true}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null})); + } + + @Test + void testStructFieldCannotCarryDefaultValue() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_LEN}, + new int[]{ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null})); + } + + @Test + void testListFieldCannotCarryDefaultValue() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_LEN}, + new int[]{ai.rapids.cudf.DType.LIST.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null})); + } + + @Test + void testEnumStringRequiresEnumMetadata() { + assertThrows(IllegalArgumentException.class, () -> + makeDescriptor(false, false, Protobuf.ENC_ENUM_STRING, null, null)); + assertThrows(IllegalArgumentException.class, () -> + makeDescriptor(false, false, Protobuf.ENC_ENUM_STRING, new int[]{0, 1}, null)); + assertThrows(IllegalArgumentException.class, () -> + makeDescriptor(false, false, Protobuf.ENC_ENUM_STRING, null, + new byte[][]{"A".getBytes(), "B".getBytes()})); + } + + @Test + void testEnumStringRejectsEmptyEnumArrays() { + assertThrows(IllegalArgumentException.class, () -> + makeDescriptor(false, false, Protobuf.ENC_ENUM_STRING, new int[]{}, new byte[][]{})); + } + + @Test + void testDuplicateFieldNumbersUnderSameParentRejected() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1, 7, 7}, + new int[]{-1, 0, 0}, + new int[]{0, 1, 1}, + new int[]{Protobuf.WT_LEN, Protobuf.WT_VARINT, Protobuf.WT_VARINT}, + new int[]{ + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, null, null}, + new byte[][][]{null, null, null})); + } + + @Test + void testDuplicateFieldNumbersUnderDifferentParentsAllowed() { + assertDoesNotThrow(() -> + new ProtobufSchemaDescriptor( + new int[]{1, 2, 7, 7}, + new int[]{-1, -1, 0, 1}, + new int[]{0, 0, 1, 1}, + new int[]{Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_VARINT, Protobuf.WT_VARINT}, + new int[]{ + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{ + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false}, + new boolean[]{false, false, false, false}, + new boolean[]{false, false, false, false}, + new long[]{0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false}, + new byte[][]{null, null, null, null}, + new int[][]{null, null, null, null}, + new byte[][][]{null, null, null, null})); + } + + @Test + void testChildParentMustBeStruct() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1, 2}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{Protobuf.WT_VARINT, Protobuf.WT_VARINT}, + new int[]{ + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + new byte[][][]{null, null})); + } + + @Test + void testEncodingCompatibilityValidation() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_32BIT}, + new int[]{ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null})); + + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_LEN}, + new int[]{ai.rapids.cudf.DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ENUM_STRING}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{{0, 1}}, + new byte[][][]{new byte[][]{"A".getBytes(), "B".getBytes()}})); + } + + @Test + void testDepthAboveSupportedLimitRejected() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + new int[]{-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + new int[]{Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_LEN, + Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_LEN, + Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_VARINT}, + new int[]{ + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false, false, false, false, false, false, false, false}, + new boolean[]{false, false, false, false, false, false, false, false, false, false, false}, + new boolean[]{false, false, false, false, false, false, false, false, false, false, false}, + new long[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false, false, false, false, false, false, false, false}, + new byte[][]{null, null, null, null, null, null, null, null, null, null, null}, + new int[][]{null, null, null, null, null, null, null, null, null, null, null}, + new byte[][][]{null, null, null, null, null, null, null, null, null, null, null})); + } + + @Test + void testSerializationRoundTripPreservesContentsAndIsolation() throws Exception { + ProtobufSchemaDescriptor original = new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_VARINT}, + new int[]{ai.rapids.cudf.DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ENUM_STRING}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{7}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{"def".getBytes()}, + new int[][]{{0, 1}}, + new byte[][][]{new byte[][]{"A".getBytes(), "B".getBytes()}}); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new ObjectOutputStream(baos)) { + oos.writeObject(original); + } + + ProtobufSchemaDescriptor roundTrip; + try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) { + roundTrip = (ProtobufSchemaDescriptor) ois.readObject(); + } + + assertEquals(original.numFields(), roundTrip.numFields()); + assertArrayEquals(original.fieldNumbers, roundTrip.fieldNumbers); + assertArrayEquals(original.defaultStrings[0], roundTrip.defaultStrings[0]); + assertArrayEquals(original.enumValidValues[0], roundTrip.enumValidValues[0]); + assertArrayEquals(original.enumNames[0][0], roundTrip.enumNames[0][0]); + assertArrayEquals(original.enumNames[0][1], roundTrip.enumNames[0][1]); + assertNotSame(original.defaultStrings, roundTrip.defaultStrings); + assertNotSame(original.defaultStrings[0], roundTrip.defaultStrings[0]); + assertNotSame(original.enumValidValues, roundTrip.enumValidValues); + assertNotSame(original.enumValidValues[0], roundTrip.enumValidValues[0]); + assertNotSame(original.enumNames, roundTrip.enumNames); + assertNotSame(original.enumNames[0], roundTrip.enumNames[0]); + assertNotSame(original.enumNames[0][0], roundTrip.enumNames[0][0]); + } +} diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java new file mode 100644 index 0000000000..16e836bb02 --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -0,0 +1,399 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVector; +import ai.rapids.cudf.Table; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Tests for the Protobuf GPU decoder — framework PR. + * + * These tests verify the decode stub: schema validation, correct output shape, + * null column construction, and empty-row handling. Actual data extraction tests + * are added in follow-up PRs. + */ +public class ProtobufTest { + + private static ProtobufSchemaDescriptor makeScalarSchema(int[] fieldNumbers, int[] typeIds, + int[] encodings) { + int n = fieldNumbers.length; + int[] parentIndices = new int[n]; + int[] depthLevels = new int[n]; + int[] wireTypes = new int[n]; + boolean[] isRepeated = new boolean[n]; + boolean[] isRequired = new boolean[n]; + boolean[] hasDefault = new boolean[n]; + long[] defaultInts = new long[n]; + double[] defaultFloats = new double[n]; + boolean[] defaultBools = new boolean[n]; + byte[][] defaultStrings = new byte[n][]; + int[][] enumValid = new int[n][]; + byte[][][] enumNames = new byte[n][][]; + + java.util.Arrays.fill(parentIndices, -1); + for (int i = 0; i < n; i++) { + wireTypes[i] = deriveWireType(typeIds[i], encodings[i]); + } + return new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, + wireTypes, typeIds, encodings, isRepeated, isRequired, hasDefault, + defaultInts, defaultFloats, defaultBools, defaultStrings, enumValid, enumNames); + } + + private static int deriveWireType(int typeId, int encoding) { + if (encoding == Protobuf.ENC_ENUM_STRING) return Protobuf.WT_VARINT; + if (typeId == DType.FLOAT32.getTypeId().getNativeId()) return Protobuf.WT_32BIT; + if (typeId == DType.FLOAT64.getTypeId().getNativeId()) return Protobuf.WT_64BIT; + if (typeId == DType.STRING.getTypeId().getNativeId()) return Protobuf.WT_LEN; + if (typeId == DType.LIST.getTypeId().getNativeId()) return Protobuf.WT_LEN; + if (typeId == DType.STRUCT.getTypeId().getNativeId()) return Protobuf.WT_LEN; + if (encoding == Protobuf.ENC_FIXED) { + if (typeId == DType.INT64.getTypeId().getNativeId()) return Protobuf.WT_64BIT; + return Protobuf.WT_32BIT; + } + return Protobuf.WT_VARINT; + } + + // ============================================================================ + // Output shape tests — verify the stub produces correctly typed struct columns + // ============================================================================ + + @Test + void testEmptySchemaProducesEmptyStruct() { + Byte[] row = new Byte[]{0x08, 0x01}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), + makeScalarSchema(new int[]{}, new int[]{}, new int[]{}), true)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + assertEquals(1, result.getRowCount()); + assertEquals(0, result.getNumChildren()); + } + } + + @Test + void testSingleScalarFieldOutputShape() { + Byte[] row = new Byte[]{0x08, 0x01}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), + makeScalarSchema( + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}), true)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + assertEquals(1, result.getRowCount()); + assertEquals(1, result.getNumChildren()); + assertEquals(DType.INT64, result.getChildColumnView(0).getType()); + } + } + + @Test + void testMultiFieldOutputShape() { + Byte[] row = new Byte[]{0x08, 0x01}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), + makeScalarSchema( + new int[]{1, 2, 3}, + new int[]{DType.INT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.FLOAT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}), + true)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + assertEquals(1, result.getRowCount()); + assertEquals(3, result.getNumChildren()); + assertEquals(DType.INT64, result.getChildColumnView(0).getType()); + assertEquals(DType.STRING, result.getChildColumnView(1).getType()); + assertEquals(DType.FLOAT32, result.getChildColumnView(2).getType()); + } + } + + @Test + void testMultipleRowsOutputShape() { + Byte[] row0 = new Byte[]{0x08, 0x01}; + Byte[] row1 = new Byte[]{0x08, 0x02}; + Byte[] row2 = new Byte[]{0x08, 0x03}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0, row1, row2}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), + makeScalarSchema( + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}), true)) { + assertEquals(3, result.getRowCount()); + assertEquals(1, result.getNumChildren()); + } + } + + // ============================================================================ + // Null input handling + // ============================================================================ + + @Test + void testNullInputRowProducesNullStructRow() { + Byte[] row0 = new Byte[]{0x08, 0x01}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0, null}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), + makeScalarSchema( + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}), true)) { + assertEquals(2, result.getRowCount()); + try (HostColumnVector hcv = result.copyToHost()) { + assertFalse(hcv.isNull(0), "Row 0 should not be null"); + assertTrue(hcv.isNull(1), "Row 1 (null input) should be null in output struct"); + } + } + } + + @Test + void testAllNullInputRows() { + try (Table input = new Table.TestBuilder().column(new Byte[][]{null, null, null}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), + makeScalarSchema( + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}), true)) { + assertEquals(3, result.getRowCount()); + assertEquals(2, result.getNumChildren()); + try (HostColumnVector hcv = result.copyToHost()) { + for (int row = 0; row < 3; row++) { + assertTrue(hcv.isNull(row), "Row " + row + " should be null"); + } + } + } + } + + // ============================================================================ + // Empty-row (0 rows) handling + // ============================================================================ + + @Test + void testZeroRowInput() { + try (Table input = new Table.TestBuilder().column(new Byte[][]{}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), + makeScalarSchema( + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}), true)) { + assertEquals(0, result.getRowCount()); + assertEquals(DType.STRUCT, result.getType()); + assertEquals(2, result.getNumChildren()); + assertEquals(DType.INT64, result.getChildColumnView(0).getType()); + assertEquals(DType.STRING, result.getChildColumnView(1).getType()); + } + } + + // ============================================================================ + // Nested schema shape tests (verifies correct column types without decode) + // ============================================================================ + + @Test + void testNestedMessageOutputShape() { + // Schema: message Outer { int32 a = 1; Inner b = 2; } message Inner { int32 x = 1; } + int intType = DType.INT32.getTypeId().getNativeId(); + int structType = DType.STRUCT.getTypeId().getNativeId(); + ProtobufSchemaDescriptor schema = new ProtobufSchemaDescriptor( + new int[]{1, 2, 1}, // field numbers + new int[]{-1, -1, 1}, // parent indices + new int[]{0, 0, 1}, // depth levels + new int[]{Protobuf.WT_VARINT, Protobuf.WT_LEN, Protobuf.WT_VARINT}, + new int[]{intType, structType, intType}, + new int[]{0, 0, 0}, // encodings + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new double[]{0, 0, 0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, null, null}, + new byte[][][]{null, null, null} + ); + + Byte[] row = new Byte[]{0x08, 0x01}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), schema, true)) { + assertEquals(DType.STRUCT, result.getType()); + assertEquals(1, result.getRowCount()); + assertEquals(2, result.getNumChildren()); + assertEquals(DType.INT32, result.getChildColumnView(0).getType()); + assertEquals(DType.STRUCT, result.getChildColumnView(1).getType()); + assertEquals(1, result.getChildColumnView(1).getNumChildren()); + assertEquals(DType.INT32, result.getChildColumnView(1).getChildColumnView(0).getType()); + } + } + + @Test + void testRepeatedFieldOutputShape() { + // Schema: message Msg { repeated int32 values = 1; } + int intType = DType.INT32.getTypeId().getNativeId(); + ProtobufSchemaDescriptor schema = new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_VARINT}, + new int[]{intType}, + new int[]{0}, + new boolean[]{true}, // is_repeated = true + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null} + ); + + Byte[] row = new Byte[]{0x08, 0x01}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), schema, true)) { + assertEquals(DType.STRUCT, result.getType()); + assertEquals(1, result.getRowCount()); + assertEquals(1, result.getNumChildren()); + assertEquals(DType.LIST, result.getChildColumnView(0).getType()); + } + } + + @Test + void testZeroRowNestedSchemaShape() { + // 0 rows with nested schema — verify correct type hierarchy + int intType = DType.INT32.getTypeId().getNativeId(); + int structType = DType.STRUCT.getTypeId().getNativeId(); + ProtobufSchemaDescriptor schema = new ProtobufSchemaDescriptor( + new int[]{1, 2, 1}, + new int[]{-1, -1, 1}, + new int[]{0, 0, 1}, + new int[]{Protobuf.WT_VARINT, Protobuf.WT_LEN, Protobuf.WT_VARINT}, + new int[]{intType, structType, intType}, + new int[]{0, 0, 0}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new double[]{0, 0, 0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, null, null}, + new byte[][][]{null, null, null} + ); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), schema, true)) { + assertEquals(0, result.getRowCount()); + assertEquals(2, result.getNumChildren()); + assertEquals(DType.INT32, result.getChildColumnView(0).getType()); + assertEquals(DType.STRUCT, result.getChildColumnView(1).getType()); + assertEquals(1, result.getChildColumnView(1).getNumChildren()); + assertEquals(DType.INT32, result.getChildColumnView(1).getChildColumnView(0).getType()); + } + } + + @Test + void testZeroRowRepeatedMessageShape() { + // 0 rows with repeated message schema: repeated Inner inner = 1; + int structType = DType.STRUCT.getTypeId().getNativeId(); + int intType = DType.INT32.getTypeId().getNativeId(); + ProtobufSchemaDescriptor schema = new ProtobufSchemaDescriptor( + new int[]{1, 1}, // field numbers + new int[]{-1, 0}, // parent indices: inner's child has parent=0 + new int[]{0, 1}, // depth levels + new int[]{Protobuf.WT_LEN, Protobuf.WT_VARINT}, + new int[]{structType, intType}, + new int[]{0, 0}, + new boolean[]{true, false}, // inner is repeated + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0, 0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + new byte[][][]{null, null} + ); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), schema, true)) { + assertEquals(0, result.getRowCount()); + assertEquals(1, result.getNumChildren()); + assertEquals(DType.LIST, result.getChildColumnView(0).getType()); + } + } + + @Test + void testZeroRowRepeatedScalarShape() { + int intType = DType.INT32.getTypeId().getNativeId(); + ProtobufSchemaDescriptor schema = new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_VARINT}, + new int[]{intType}, + new int[]{0}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null} + ); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), schema, true)) { + assertEquals(0, result.getRowCount()); + assertEquals(1, result.getNumChildren()); + assertEquals(DType.LIST, result.getChildColumnView(0).getType()); + } + } + + // ============================================================================ + // Input validation tests + // ============================================================================ + + @Test + void testNullBinaryInputThrows() { + assertThrows(IllegalArgumentException.class, () -> + Protobuf.decodeToStruct(null, + makeScalarSchema(new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}), true)); + } + + @Test + void testNullSchemaThrows() { + Byte[] row = new Byte[]{0x08, 0x01}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(IllegalArgumentException.class, () -> + Protobuf.decodeToStruct(input.getColumn(0), null, true)); + } + } +}