diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6785eb2e8a4..f59752950a0 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -510,6 +510,9 @@ add_library( src/io/json/process_tokens.cu src/io/json/write_json.cpp src/io/json/write_json.cu + src/io/protobuf/builders.cu + src/io/protobuf/decode.cu + src/io/protobuf/kernels.cu src/io/orc/aggregate_orc_metadata.cpp src/io/orc/dict_enc.cu src/io/orc/orc.cpp diff --git a/cpp/include/cudf/io/detail/protobuf.hpp b/cpp/include/cudf/io/detail/protobuf.hpp new file mode 100644 index 00000000000..e99cfd06104 --- /dev/null +++ b/cpp/include/cudf/io/detail/protobuf.hpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include + +namespace CUDF_EXPORT cudf { +namespace io::protobuf::detail { + +/** + * @brief Check if an encoding is compatible with the given field and data type. + */ +bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_type const& type); + +/** + * @brief Validate the decode context (schema consistency, encoding compatibility, etc.). + * + * @throws std::invalid_argument if the context is invalid + */ +void validate_decode_options(decode_protobuf_options const& options); + +/** + * @brief Create a view into a single field's metadata from the decode options. 
+ */ +protobuf_field_meta_view make_field_meta_view(decode_protobuf_options const& options, + int schema_idx); + +/** + * @brief Internal implementation of decode_protobuf. + */ +std::unique_ptr decode_protobuf(cudf::column_view const& binary_input, + decode_protobuf_options const& options, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +} // namespace io::protobuf::detail +} // namespace CUDF_EXPORT cudf diff --git a/cpp/include/cudf/io/protobuf.hpp b/cpp/include/cudf/io/protobuf.hpp new file mode 100644 index 00000000000..c3c8bb8d18f --- /dev/null +++ b/cpp/include/cudf/io/protobuf.hpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace CUDF_EXPORT cudf { +namespace io::protobuf { + +/** + * @brief Protobuf field encoding types. + */ +enum class proto_encoding : int { + DEFAULT = 0, ///< Standard varint encoding + FIXED = 1, ///< Fixed-width encoding (32-bit or 64-bit) + ZIGZAG = 2, ///< ZigZag encoding for signed integers + ENUM_STRING = 3, ///< Enum field decoded as string +}; + +/** + * @brief Get the integer value of a proto_encoding. + */ +CUDF_HOST_DEVICE constexpr int encoding_value(proto_encoding encoding) +{ + return static_cast(encoding); +} + +/// Maximum protobuf field number (29-bit). 
+constexpr int MAX_FIELD_NUMBER = (1 << 29) - 1; + +/** + * @brief Protobuf wire types. + */ +enum class proto_wire_type : int { + VARINT = 0, ///< Variable-length integer + I64BIT = 1, ///< 64-bit fixed + LEN = 2, ///< Length-delimited + SGROUP = 3, ///< Start group (deprecated) + EGROUP = 4, ///< End group (deprecated) + I32BIT = 5, ///< 32-bit fixed +}; + +/** + * @brief Get the integer value of a proto_wire_type. + */ +CUDF_HOST_DEVICE constexpr int wire_type_value(proto_wire_type wire_type) +{ + return static_cast(wire_type); +} + +/// Maximum supported nesting depth for nested protobuf messages. +constexpr int MAX_NESTING_DEPTH = 10; + +/** + * @brief Descriptor for a single field in a (possibly nested) protobuf schema. + * + * Fields are organized in a flat array where parent-child relationships are + * expressed via @p parent_idx. Top-level fields have `parent_idx == -1`. + */ +struct nested_field_descriptor { + int field_number; ///< Protobuf field number + int parent_idx; ///< Index of parent field in schema (-1 for top-level) + int depth; ///< Nesting depth (0 for top-level) + proto_wire_type wire_type; ///< Expected wire type + cudf::type_id output_type; ///< Output cudf type + proto_encoding encoding; ///< Encoding type + bool is_repeated; ///< Whether this field is repeated (array) + bool is_required; ///< Whether this field is required (proto2) + bool has_default_value; ///< Whether this field has a default value +}; + +/** + * @brief Context for decoding protobuf messages. + * + * Contains the schema (as a flat array of field descriptors), default values, + * and enum metadata needed for decoding. 
+ */ +struct decode_protobuf_options { + std::vector schema; ///< Flat array of field descriptors + std::vector default_ints; ///< Default integer values per field + std::vector default_floats; ///< Default float values per field + std::vector default_bools; ///< Default boolean values per field + std::vector> + default_strings; ///< Default string values per field + std::vector> + enum_valid_values; ///< Valid enum numbers per field (strictly ascending for ENUM_STRING fields) + std::vector>> + enum_names; ///< UTF-8 enum names per field + bool fail_on_errors; ///< If true, throw on malformed messages; otherwise return nulls +}; + +/** + * @brief View into a single field's metadata from a decode_protobuf_options. + */ +struct protobuf_field_meta_view { + nested_field_descriptor const& schema; + cudf::data_type const output_type; + int64_t default_int; + double default_float; + bool default_bool; + cudf::detail::host_vector const& default_string; + cudf::detail::host_vector const& enum_valid_values; + std::vector> const& enum_names; +}; + +/** + * @brief Decode serialized protobuf messages into a struct column. + * + * Takes a LIST<INT8> or LIST<UINT8> column where each row contains a serialized protobuf message, + * and decodes it into a STRUCT column according to the provided schema. + * + * Supports nested messages (up to 10 levels), repeated fields (as LIST columns), + * enum-as-string conversion, default values, and required field checking. 
+ * + * @param binary_input LIST<INT8> or LIST<UINT8> column of serialized protobuf messages + * @param options Decoding options including schema, defaults, and enum metadata + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return A STRUCT column containing the decoded protobuf fields + * + * @throws std::invalid_argument if the schema is invalid + * @throws cudf::logic_error if fail_on_errors is true and a message cannot be decoded + */ +std::unique_ptr decode_protobuf( + cudf::column_view const& binary_input, + decode_protobuf_options const& options, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); + +} // namespace io::protobuf +} // namespace CUDF_EXPORT cudf diff --git a/cpp/src/io/protobuf/builders.cu b/cpp/src/io/protobuf/builders.cu new file mode 100644 index 00000000000..2b7409b575d --- /dev/null +++ b/cpp/src/io/protobuf/builders.cu @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "io/protobuf/kernels.cuh" + +#include +#include + +namespace cudf::io::protobuf::detail { + +std::unique_ptr make_null_column(cudf::data_type dtype, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_rows == 0) { return cudf::make_empty_column(dtype); } + + switch (dtype.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT8: + case cudf::type_id::UINT8: + case cudf::type_id::INT16: + case cudf::type_id::UINT16: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT32: + case cudf::type_id::FLOAT64: + return cudf::make_fixed_width_column(dtype, num_rows, cudf::mask_state::ALL_NULL, stream, mr); + case cudf::type_id::STRING: { + rmm::device_uvector pairs(num_rows, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), + pairs.data(), + pairs.end(), + cudf::strings::detail::string_index_pair{nullptr, 0}); + return cudf::strings::detail::make_strings_column(pairs.data(), pairs.end(), stream, mr); + } + case cudf::type_id::LIST: + return cudf::lists::detail::make_all_nulls_lists_column( + num_rows, cudf::data_type{cudf::type_id::UINT8}, stream, mr); + case cudf::type_id::STRUCT: { + std::vector> empty_children; + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(empty_children), num_rows, std::move(null_mask), stream, mr); + } + default: CUDF_FAIL("Unsupported type for null column creation"); + } +} + +std::unique_ptr make_empty_column_safe(cudf::data_type dtype, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + switch (dtype.id()) { + case cudf::type_id::LIST: { + auto offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); + CUDF_CUDA_TRY(cudaMemsetAsync( + 
offsets_col->mutable_view().data(), 0, sizeof(int32_t), stream.value())); + auto child_col = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + return cudf::make_lists_column( + 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); + } + case cudf::type_id::STRUCT: { + std::vector> empty_children; + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + default: return cudf::make_empty_column(dtype); + } +} + +std::unique_ptr make_null_list_column_with_child( + std::unique_ptr child_col, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), num_rows, std::move(null_mask)); +} + +std::unique_ptr make_empty_list_column(std::unique_ptr element_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); + CUDF_CUDA_TRY(cudaMemsetAsync( + offsets_col->mutable_view().data(), 0, sizeof(int32_t), stream.value())); + return cudf::make_lists_column( + 0, std::move(offsets_col), std::move(element_col), 0, rmm::device_buffer{}); +} + +} // namespace cudf::io::protobuf::detail diff --git a/cpp/src/io/protobuf/decode.cu b/cpp/src/io/protobuf/decode.cu new file mode 100644 index 00000000000..3c755aeb76e --- /dev/null +++ b/cpp/src/io/protobuf/decode.cu @@ -0,0 
+1,346 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "io/protobuf/kernels.cuh" + +#include +#include + +#include +#include + +namespace cudf::io::protobuf { + +namespace detail { + +namespace { + +std::unique_ptr make_null_column_with_schema( + std::vector const& schema, + int schema_idx, + int num_fields, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const& field = schema[schema_idx]; + auto const dtype = cudf::data_type{schema[schema_idx].output_type}; + + if (field.is_repeated) { + std::unique_ptr empty_child; + if (dtype.id() == cudf::type_id::STRUCT) { + empty_child = + make_empty_struct_column_with_schema(schema, schema_idx, num_fields, stream, mr); + } else { + empty_child = make_empty_column_safe(dtype, stream, mr); + } + return make_null_list_column_with_child(std::move(empty_child), num_rows, stream, mr); + } + + if (dtype.id() == cudf::type_id::STRUCT) { + auto child_indices = find_child_field_indices(schema, num_fields, schema_idx); + std::vector> children; + for (auto const child_idx : child_indices) { + children.push_back( + make_null_column_with_schema(schema, child_idx, num_fields, num_rows, stream, mr)); + } + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(children), num_rows, std::move(null_mask), stream, mr); 
+ } + + return make_null_column(dtype, num_rows, stream, mr); +} + +} // namespace + +bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_type const& type) +{ + switch (field.encoding) { + case proto_encoding::DEFAULT: + switch (type.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: return field.wire_type == proto_wire_type::VARINT; + case cudf::type_id::FLOAT32: return field.wire_type == proto_wire_type::I32BIT; + case cudf::type_id::FLOAT64: return field.wire_type == proto_wire_type::I64BIT; + case cudf::type_id::STRING: + case cudf::type_id::LIST: + case cudf::type_id::STRUCT: return field.wire_type == proto_wire_type::LEN; + default: return false; + } + case proto_encoding::FIXED: + switch (type.id()) { + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::FLOAT32: return field.wire_type == proto_wire_type::I32BIT; + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT64: return field.wire_type == proto_wire_type::I64BIT; + default: return false; + } + case proto_encoding::ZIGZAG: + return field.wire_type == proto_wire_type::VARINT && + (type.id() == cudf::type_id::INT32 || type.id() == cudf::type_id::INT64); + case proto_encoding::ENUM_STRING: + return field.wire_type == proto_wire_type::VARINT && type.id() == cudf::type_id::STRING; + default: return false; + } +} + +void validate_decode_options(decode_protobuf_options const& context) +{ + auto const num_fields = context.schema.size(); + if (context.default_ints.size() != num_fields) { + CUDF_FAIL("protobuf decode context: default_ints size mismatch with schema (" + + std::to_string(context.default_ints.size()) + " vs " + std::to_string(num_fields) + + ")", + std::invalid_argument); + } + if (context.default_floats.size() != num_fields) { + CUDF_FAIL("protobuf decode context: default_floats size mismatch with schema (" + + 
std::to_string(context.default_floats.size()) + " vs " + + std::to_string(num_fields) + ")", + std::invalid_argument); + } + if (context.default_bools.size() != num_fields) { + CUDF_FAIL("protobuf decode context: default_bools size mismatch with schema (" + + std::to_string(context.default_bools.size()) + " vs " + std::to_string(num_fields) + + ")", + std::invalid_argument); + } + if (context.default_strings.size() != num_fields) { + CUDF_FAIL("protobuf decode context: default_strings size mismatch with schema (" + + std::to_string(context.default_strings.size()) + " vs " + + std::to_string(num_fields) + ")", + std::invalid_argument); + } + if (context.enum_valid_values.size() != num_fields) { + CUDF_FAIL("protobuf decode context: enum_valid_values size mismatch with schema (" + + std::to_string(context.enum_valid_values.size()) + " vs " + + std::to_string(num_fields) + ")", + std::invalid_argument); + } + if (context.enum_names.size() != num_fields) { + CUDF_FAIL("protobuf decode context: enum_names size mismatch with schema (" + + std::to_string(context.enum_names.size()) + " vs " + std::to_string(num_fields) + + ")", + std::invalid_argument); + } + + std::unordered_set seen_field_numbers; + for (size_t i = 0; i < num_fields; ++i) { + auto const& field = context.schema[i]; + auto const type = cudf::data_type{field.output_type}; + if (field.field_number <= 0 || field.field_number > MAX_FIELD_NUMBER) { + CUDF_FAIL("protobuf decode context: invalid field number at field " + std::to_string(i), + std::invalid_argument); + } + if (field.depth < 0 || field.depth >= MAX_NESTING_DEPTH) { + CUDF_FAIL("protobuf decode context: field depth exceeds supported limit at field " + + std::to_string(i), + std::invalid_argument); + } + if (field.parent_idx < -1 || field.parent_idx >= static_cast(i)) { + CUDF_FAIL("protobuf decode context: invalid parent index at field " + std::to_string(i), + std::invalid_argument); + } + auto const key = (static_cast(static_cast(field.parent_idx)) 
<< 32) | + static_cast(field.field_number); + if (!seen_field_numbers.insert(key).second) { + CUDF_FAIL("protobuf decode context: duplicate field number under same parent at field " + + std::to_string(i), + std::invalid_argument); + } + + if (field.parent_idx == -1) { + if (field.depth != 0) { + CUDF_FAIL("protobuf decode context: top-level field must have depth 0 at field " + + std::to_string(i), + std::invalid_argument); + } + } else { + auto const& parent = context.schema[field.parent_idx]; + if (field.depth != parent.depth + 1) { + CUDF_FAIL("protobuf decode context: child depth mismatch at field " + std::to_string(i), + std::invalid_argument); + } + if (cudf::data_type{context.schema[field.parent_idx].output_type}.id() != + cudf::type_id::STRUCT) { + CUDF_FAIL("protobuf decode context: parent must be STRUCT at field " + std::to_string(i), + std::invalid_argument); + } + } + + if (field.wire_type != proto_wire_type::VARINT && field.wire_type != proto_wire_type::I64BIT && + field.wire_type != proto_wire_type::LEN && field.wire_type != proto_wire_type::I32BIT) { + CUDF_FAIL("protobuf decode context: invalid wire type at field " + std::to_string(i), + std::invalid_argument); + } + if (field.encoding < proto_encoding::DEFAULT || field.encoding > proto_encoding::ENUM_STRING) { + CUDF_FAIL("protobuf decode context: invalid encoding at field " + std::to_string(i), + std::invalid_argument); + } + if (field.is_repeated && field.is_required) { + CUDF_FAIL("protobuf decode context: field cannot be both repeated and required at field " + + std::to_string(i), + std::invalid_argument); + } + if (field.is_repeated && field.has_default_value) { + CUDF_FAIL("protobuf decode context: repeated field cannot carry default value at field " + + std::to_string(i), + std::invalid_argument); + } + if (field.has_default_value && + (type.id() == cudf::type_id::STRUCT || type.id() == cudf::type_id::LIST)) { + CUDF_FAIL("protobuf decode context: STRUCT/LIST field cannot carry default value 
at field " + + std::to_string(i), + std::invalid_argument); + } + if (!is_encoding_compatible(field, type)) { + CUDF_FAIL("protobuf decode context: incompatible wire type/encoding/output type at field " + + std::to_string(i), + std::invalid_argument); + } + + if (field.encoding == proto_encoding::ENUM_STRING) { + if (context.enum_valid_values[i].empty() || context.enum_names[i].empty()) { + CUDF_FAIL( + "protobuf decode context: enum-as-string field requires non-empty metadata at field " + + std::to_string(i), + std::invalid_argument); + } + if (context.enum_valid_values[i].size() != context.enum_names[i].size()) { + CUDF_FAIL( + "protobuf decode context: enum-as-string metadata mismatch at field " + std::to_string(i), + std::invalid_argument); + } + auto const& ev = context.enum_valid_values[i]; + for (size_t j = 1; j < ev.size(); ++j) { + if (ev[j] <= ev[j - 1]) { + CUDF_FAIL("protobuf decode context: enum_valid_values must be strictly sorted at field " + + std::to_string(i), + std::invalid_argument); + } + } + } + } +} + +protobuf_field_meta_view make_field_meta_view(decode_protobuf_options const& context, + int schema_idx) +{ + auto const idx = static_cast(schema_idx); + return protobuf_field_meta_view{context.schema.at(idx), + cudf::data_type{context.schema.at(idx).output_type}, + context.default_ints.at(idx), + context.default_floats.at(idx), + context.default_bools.at(idx), + context.default_strings.at(idx), + context.enum_valid_values.at(idx), + context.enum_names.at(idx)}; +} + +std::unique_ptr decode_protobuf(cudf::column_view const& binary_input, + decode_protobuf_options const& context, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + validate_decode_options(context); + auto const& schema = context.schema; + CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, + "binary_input must be a LIST column"); + cudf::lists_column_view const in_list(binary_input); + auto const child_type = in_list.child().type().id(); + 
CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, + "binary_input must be a LIST column"); + + auto const num_rows = binary_input.size(); + auto const num_fields = static_cast(schema.size()); + + if (num_fields == 0) { + auto const input_null_count = binary_input.null_count(); + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_structs_column( + num_rows, {}, input_null_count, std::move(null_mask), stream, mr); + } + return cudf::make_structs_column(num_rows, {}, 0, rmm::device_buffer{}, stream, mr); + } + + if (num_rows == 0) { + std::vector> empty_children; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { + auto field_type = cudf::data_type{schema[i].output_type}; + if (schema[i].is_repeated && field_type.id() == cudf::type_id::STRUCT) { + auto empty_struct = + make_empty_struct_column_with_schema(schema, i, num_fields, stream, mr); + empty_children.push_back(make_empty_list_column(std::move(empty_struct), stream, mr)); + } else if (schema[i].is_repeated) { + auto empty_child = make_empty_column_safe(field_type, stream, mr); + empty_children.push_back(make_empty_list_column(std::move(empty_child), stream, mr)); + } else if (field_type.id() == cudf::type_id::STRUCT) { + empty_children.push_back( + make_empty_struct_column_with_schema(schema, i, num_fields, stream, mr)); + } else { + empty_children.push_back(make_empty_column_safe(field_type, stream, mr)); + } + } + } + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + + std::vector> column_map(num_fields); + + std::vector> top_level_children; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { + if (column_map[i]) { + top_level_children.push_back(std::move(column_map[i])); + } else { + top_level_children.push_back( + make_null_column_with_schema(schema, i, num_fields, num_rows, stream, mr)); + } + } + } + 
+ auto const input_null_count = binary_input.null_count(); + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(top_level_children), input_null_count, std::move(null_mask), stream, mr); + } + + return cudf::make_structs_column( + num_rows, std::move(top_level_children), 0, rmm::device_buffer{}, stream, mr); +} + +} // namespace detail + +std::unique_ptr decode_protobuf(cudf::column_view const& binary_input, + decode_protobuf_options const& context, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + return detail::decode_protobuf(binary_input, context, stream, mr); +} + +} // namespace cudf::io::protobuf diff --git a/cpp/src/io/protobuf/device_helpers.cuh b/cpp/src/io/protobuf/device_helpers.cuh new file mode 100644 index 00000000000..ca73ee12c44 --- /dev/null +++ b/cpp/src/io/protobuf/device_helpers.cuh @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "io/protobuf/types.cuh" + +#include + +#include + +#include +#include + +namespace cudf::io::protobuf::detail { + +// ============================================================================ +// Device helper functions +// ============================================================================ + +__device__ inline bool read_varint(uint8_t const* cur, + uint8_t const* end, + uint64_t& out, + int& bytes) +{ + out = 0; + bytes = 0; + int shift = 0; + // Protobuf varint uses 7 bits per byte with MSB as continuation flag. + // A 64-bit value requires at most ceil(64/7) = 10 bytes. + while (cur < end && bytes < MAX_VARINT_BYTES) { + uint8_t b = *cur++; + // For the 10th byte (bytes == 9, shift == 63), only the lowest bit is valid + if (bytes == 9 && (b & 0xFE) != 0) { + return false; // Invalid: 10th byte has more than 1 significant bit + } + out |= (static_cast(b & 0x7Fu) << shift); + bytes++; + if ((b & 0x80u) == 0) { return true; } + shift += 7; + } + return false; +} + +__device__ inline void set_error_once(int* error_flag, int error_code) +{ + int expected = 0; + cuda::atomic_ref ref(*error_flag); + ref.compare_exchange_strong(expected, error_code, cuda::memory_order_relaxed); +} + +void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream); + +__device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) +{ + switch (wt) { + case wire_type_value(proto_wire_type::VARINT): { + // Need to scan to find the end of varint + int count = 0; + while (cur < end && count < MAX_VARINT_BYTES) { + if ((*cur++ & 0x80u) == 0) { return count + 1; } + count++; + } + return -1; // Invalid varint + } + case wire_type_value(proto_wire_type::I64BIT): + // Check if there's enough data for 8 bytes + if (end - cur < 8) return -1; + return 8; + case wire_type_value(proto_wire_type::I32BIT): + // Check if there's enough data for 4 bytes + if (end - cur < 4) return -1; + return 4; + case 
wire_type_value(proto_wire_type::LEN): { + uint64_t len; + int n; + if (!read_varint(cur, end, len, n)) return -1; + if (len > static_cast(end - cur - n) || + len > static_cast(cuda::std::numeric_limits::max() - n)) { + return -1; + } + return n + static_cast(len); + } + case wire_type_value(proto_wire_type::SGROUP): { + auto const* start = cur; + int depth = 1; + while (cur < end && depth > 0) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) return -1; + cur += key_bytes; + + int inner_wt = static_cast(key & 0x7); + if (inner_wt == wire_type_value(proto_wire_type::EGROUP)) { + --depth; + if (depth == 0) { return static_cast(cur - start); } + } else if (inner_wt == wire_type_value(proto_wire_type::SGROUP)) { + if (++depth > 32) return -1; + } else { + int inner_size = -1; + switch (inner_wt) { + case wire_type_value(proto_wire_type::VARINT): { + uint64_t dummy; + int vbytes; + if (!read_varint(cur, end, dummy, vbytes)) return -1; + inner_size = vbytes; + break; + } + case wire_type_value(proto_wire_type::I64BIT): inner_size = 8; break; + case wire_type_value(proto_wire_type::LEN): { + uint64_t len; + int len_bytes; + if (!read_varint(cur, end, len, len_bytes)) return -1; + if (len > static_cast(cuda::std::numeric_limits::max() - len_bytes)) { + return -1; + } + inner_size = len_bytes + static_cast(len); + break; + } + case wire_type_value(proto_wire_type::I32BIT): inner_size = 4; break; + default: return -1; + } + if (inner_size < 0 || cur + inner_size > end) return -1; + cur += inner_size; + } + } + return -1; + } + case wire_type_value(proto_wire_type::EGROUP): return 0; + default: return -1; + } +} + +__device__ inline bool skip_field(uint8_t const* cur, + uint8_t const* end, + int wt, + uint8_t const*& out_cur) +{ + // A bare end-group is only valid while a start-group payload is being parsed recursively inside + // get_wire_type_size(wire_type_value(proto_wire_type::SGROUP)). 
+ // The scan/count kernels should never accept it as a standalone field because Spark CPU treats + // unmatched end-groups as malformed protobuf. + if (wt == wire_type_value(proto_wire_type::EGROUP)) { return false; } + + int size = get_wire_type_size(wt, cur, end); + if (size < 0) return false; + // Ensure we don't skip past the end of the buffer + if (cur + size > end) return false; + out_cur = cur + size; + return true; +} + +/** + * Get the data offset and length for a field at current position. + * Returns true on success, false on error. + */ +__device__ inline bool get_field_data_location( + uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) +{ + if (wt == wire_type_value(proto_wire_type::LEN)) { + // For length-delimited, read the length prefix + uint64_t len; + int len_bytes; + if (!read_varint(cur, end, len, len_bytes)) return false; + if (len > static_cast(end - cur - len_bytes) || + len > static_cast(cuda::std::numeric_limits::max())) { + return false; + } + data_offset = len_bytes; // offset past the length prefix + data_length = static_cast(len); + } else { + // For fixed-size and varint fields + int field_size = get_wire_type_size(wt, cur, end); + if (field_size < 0) return false; + data_offset = 0; + data_length = field_size; + } + return true; +} + +CUDF_HOST_DEVICE inline size_t flat_index(size_t row, size_t width, size_t col) +{ + return row * width + col; +} + +__device__ inline bool checked_add_int32(int32_t lhs, int32_t rhs, int32_t& out) +{ + auto const sum = static_cast(lhs) + rhs; + if (sum < cuda::std::numeric_limits::min() || + sum > cuda::std::numeric_limits::max()) { + return false; + } + out = static_cast(sum); + return true; +} + +__device__ inline bool check_message_bounds(int32_t start, + int32_t end_pos, + cudf::size_type total_size, + int* error_flag) +{ + if (start < 0 || end_pos < start || end_pos > total_size) { + set_error_once(error_flag, ERR_BOUNDS); + return false; + } + return 
true; +} + +struct proto_tag { + int field_number; + int wire_type; +}; + +__device__ inline bool decode_tag(uint8_t const*& cur, + uint8_t const* end, + proto_tag& tag, + int* error_flag) +{ + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + + cur += key_bytes; + uint64_t fn = key >> 3; + if (fn == 0 || fn > static_cast(MAX_FIELD_NUMBER)) { + set_error_once(error_flag, ERR_FIELD_NUMBER); + return false; + } + tag.field_number = static_cast(fn); + tag.wire_type = static_cast(key & 0x7); + return true; +} + +/** + * Load a little-endian value from unaligned memory. + * Reads bytes individually to avoid unaligned-access issues on GPU. + */ +template +__device__ inline T load_le(uint8_t const* p); + +template <> +__device__ inline uint32_t load_le(uint8_t const* p) +{ + return static_cast(p[0]) | (static_cast(p[1]) << 8) | + (static_cast(p[2]) << 16) | (static_cast(p[3]) << 24); +} + +template <> +__device__ inline uint64_t load_le(uint8_t const* p) +{ + uint64_t v = 0; +#pragma unroll + for (int i = 0; i < 8; ++i) { + v |= (static_cast(p[i]) << (8 * i)); + } + return v; +} + +/** + * O(1) lookup of field_number -> field_index using a direct-mapped table. + * Falls back to linear search when the table is empty (field numbers too large). + */ +// Keep this definition in the header so all CUDA translation units can inline it. 
+__device__ __forceinline__ int lookup_field(int field_number, + int const* lookup_table, + int lookup_table_size, + field_descriptor const* field_descs, + int num_fields) +{ + if (lookup_table != nullptr && field_number > 0 && field_number < lookup_table_size) { + return lookup_table[field_number]; + } + for (int f = 0; f < num_fields; f++) { + if (field_descs[f].field_number == field_number) return f; + } + return -1; +} + +} // namespace cudf::io::protobuf::detail diff --git a/cpp/src/io/protobuf/host_helpers.hpp b/cpp/src/io/protobuf/host_helpers.hpp new file mode 100644 index 00000000000..ccaa59116db --- /dev/null +++ b/cpp/src/io/protobuf/host_helpers.hpp @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "io/protobuf/types.cuh" + +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace cudf::io::protobuf::detail { + +// ============================================================================ +// Field number lookup table helpers +// ============================================================================ + +/** + * Build a host-side direct-mapped lookup table: field_number -> index. + * @param get_field_number Callable: (int i) -> field_number for the i-th entry. + * @param num_entries Number of entries. + * @return Empty vector if the max field number exceeds the threshold. 
+ */ +template +inline std::vector build_lookup_table(FieldNumberFn get_field_number, int num_entries) +{ + int max_fn = 0; + for (int i = 0; i < num_entries; i++) { + max_fn = std::max(max_fn, get_field_number(i)); + } + if (max_fn > FIELD_LOOKUP_TABLE_MAX) { return {}; } + std::vector table(max_fn + 1, -1); + for (int i = 0; i < num_entries; i++) { + table[get_field_number(i)] = i; + } + return table; +} + +inline std::vector build_index_lookup_table(nested_field_descriptor const* schema, + int const* field_indices, + int num_indices) +{ + return build_lookup_table([&](int i) { return schema[field_indices[i]].field_number; }, + num_indices); +} + +inline std::vector build_field_lookup_table(field_descriptor const* descs, int num_fields) +{ + return build_lookup_table([&](int i) { return descs[i].field_number; }, num_fields); +} + +/** + * Find all child field indices for a given parent index in the schema. + * This is a commonly used pattern throughout the codebase. + * + * @param schema The schema vector (either nested_field_descriptor or + * device_nested_field_descriptor) + * @param num_fields Number of fields in the schema + * @param parent_idx The parent index to search for + * @return Vector of child field indices + */ +template +std::vector find_child_field_indices(SchemaT const& schema, int num_fields, int parent_idx) +{ + std::vector child_indices; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == parent_idx) { child_indices.push_back(i); } + } + return child_indices; +} + +// Forward declarations needed by make_empty_struct_column_with_schema +std::unique_ptr make_empty_column_safe(cudf::data_type dtype, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr make_empty_list_column(std::unique_ptr element_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +template +std::unique_ptr make_empty_struct_column_with_schema( + SchemaT const& schema, + int parent_idx, + int 
num_fields, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto child_indices = find_child_field_indices(schema, num_fields, parent_idx); + + std::vector> children; + for (int child_idx : child_indices) { + auto child_type = cudf::data_type{schema[child_idx].output_type}; + + std::unique_ptr child_col; + if (child_type.id() == cudf::type_id::STRUCT) { + child_col = make_empty_struct_column_with_schema(schema, child_idx, num_fields, stream, mr); + } else { + child_col = make_empty_column_safe(child_type, stream, mr); + } + + if (schema[child_idx].is_repeated) { + child_col = make_empty_list_column(std::move(child_col), stream, mr); + } + + children.push_back(std::move(child_col)); + } + + return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); +} + +void maybe_check_required_fields(field_location const* locations, + std::vector const& field_indices, + std::vector const& schema, + int num_rows, + cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, + int* error_flag, + rmm::cuda_stream_view stream); + +void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream); + +void validate_enum_and_propagate_rows(rmm::device_uvector const& values, + rmm::device_uvector& valid, + cudf::detail::host_vector const& valid_enums, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream); + +// ============================================================================ +// Forward declarations of builder/utility functions +// ============================================================================ + +std::unique_ptr make_null_column(cudf::data_type 
dtype, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr make_null_list_column_with_child( + std::unique_ptr child_col, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_enum_string_column( + rmm::device_uvector& enum_values, + rmm::device_uvector& valid, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_force_null, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices = nullptr, + bool propagate_invalid_rows = true); + +// Complex builder forward declarations +std::unique_ptr build_repeated_enum_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_repeated_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + bool is_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_nested_struct_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_parent_locs, + std::vector 
const& child_field_indices, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows = true); + +std::unique_ptr build_repeated_child_list_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_parent_rows, + int child_schema_idx, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows = true); + +std::unique_ptr build_repeated_struct_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + std::vector const& h_device_schema, + std::vector const& child_field_indices, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector const& schema, + std::vector> const& 
enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error_top, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +} // namespace cudf::io::protobuf::detail diff --git a/cpp/src/io/protobuf/kernels.cu b/cpp/src/io/protobuf/kernels.cu new file mode 100644 index 00000000000..f96d4aae838 --- /dev/null +++ b/cpp/src/io/protobuf/kernels.cu @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "io/protobuf/kernels.cuh" + +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace cudf::io::protobuf::detail { + +namespace { + +CUDF_KERNEL void set_error_if_unset_kernel(int* error_flag, int error_code) +{ + if (blockIdx.x == 0 && threadIdx.x == 0) { set_error_once(error_flag, error_code); } +} + +// Stub kernels — replaced with real implementations in follow-up PRs. 
+CUDF_KERNEL void check_required_fields_kernel(field_location const*, + uint8_t const*, + int, + int, + cudf::bitmask_type const*, + cudf::size_type, + field_location const*, + bool*, + int32_t const*, + int*) +{ +} + +CUDF_KERNEL void validate_enum_values_kernel(int32_t const*, bool*, bool*, int32_t const*, int, int) +{ +} + +} // namespace + +void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream) +{ + set_error_if_unset_kernel<<<1, 1, 0, stream.value()>>>(error_flag, error_code); + CUDF_CUDA_TRY(cudaPeekAtLastError()); +} + +void maybe_check_required_fields(field_location const* locations, + std::vector const& field_indices, + std::vector const& schema, + int num_rows, + cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0 || field_indices.empty()) { return; } + + bool has_required = false; + auto h_is_required = + cudf::detail::make_pinned_vector_async(field_indices.size(), stream); + for (size_t i = 0; i < field_indices.size(); ++i) { + h_is_required[i] = schema[field_indices[i]].is_required ? 
1 : 0; + has_required |= (h_is_required[i] != 0); + } + if (!has_required) { return; } + + auto d_is_required = cudf::detail::make_device_uvector_async( + h_is_required, stream, rmm::mr::get_current_device_resource_ref()); + + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + check_required_fields_kernel<<>>( + locations, + d_is_required.data(), + static_cast(field_indices.size()), + num_rows, + input_null_mask, + input_offset, + parent_locs, + row_force_null, + top_row_indices, + error_flag); +} + +void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream) +{ + if (num_items == 0 || row_invalid.size() == 0 || !propagate_to_rows) { return; } + + if (top_row_indices == nullptr) { + CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), + "enum invalid-row propagation exceeded row buffer"); + thrust::transform(rmm::exec_policy_nosync(stream), + row_invalid.begin(), + row_invalid.begin() + num_items, + item_invalid.begin(), + row_invalid.begin(), + [] __device__(bool row_is_invalid, bool item_is_invalid) { + return row_is_invalid || item_is_invalid; + }); + return; + } + + thrust::for_each(rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_items), + [item_invalid = item_invalid.data(), + top_row_indices, + row_invalid = row_invalid.data()] __device__(int idx) { + if (item_invalid[idx]) { row_invalid[top_row_indices[idx]] = true; } + }); +} + +void validate_enum_and_propagate_rows(rmm::device_uvector const& values, + rmm::device_uvector& valid, + cudf::detail::host_vector const& valid_enums, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream) +{ + if (num_items == 0 || valid_enums.empty()) { return; } + + auto 
const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + auto d_valid_enums = cudf::detail::make_device_uvector_async( + valid_enums, stream, rmm::mr::get_current_device_resource_ref()); + + rmm::device_uvector item_invalid( + num_items, stream, rmm::mr::get_current_device_resource_ref()); + thrust::fill(rmm::exec_policy_nosync(stream), item_invalid.begin(), item_invalid.end(), false); + validate_enum_values_kernel<<>>( + values.data(), + valid.data(), + item_invalid.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_items); + + propagate_invalid_enum_flags_to_rows( + item_invalid, row_invalid, num_items, top_row_indices, propagate_to_rows, stream); +} + +} // namespace cudf::io::protobuf::detail diff --git a/cpp/src/io/protobuf/kernels.cuh b/cpp/src/io/protobuf/kernels.cuh new file mode 100644 index 00000000000..e48b1a68746 --- /dev/null +++ b/cpp/src/io/protobuf/kernels.cuh @@ -0,0 +1,936 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "io/protobuf/device_helpers.cuh" +#include "io/protobuf/host_helpers.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace cudf::io::protobuf::detail { + +// ============================================================================ +// Pass 2: Extract data kernels +// ============================================================================ + +// ============================================================================ +// Data Extraction Location Providers +// ============================================================================ + +struct top_level_location_provider { + cudf::size_type const* offsets; + cudf::size_type base_offset; + field_location const* locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto loc = locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; + if (loc.offset >= 0) { data_offset = offsets[thread_idx] - base_offset + loc.offset; } + return loc; + } +}; + +struct repeated_location_provider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + data_offset = row_offsets[occ.row_idx] - base_offset + occ.offset; + return {occ.offset, occ.length}; + } +}; + +struct nested_location_provider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + field_location const* child_locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto ploc = 
parent_locations[thread_idx]; + auto cloc = child_locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; + if (ploc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + ploc.offset + cloc.offset; + } else { + cloc.offset = -1; + } + return cloc; + } +}; + +struct nested_repeated_location_provider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + auto ploc = parent_locations[occ.row_idx]; + if (ploc.offset >= 0) { + data_offset = row_offsets[occ.row_idx] - base_offset + ploc.offset + occ.offset; + return {occ.offset, occ.length}; + } + data_offset = 0; + return {-1, 0}; + } +}; + +struct repeated_msg_child_location_provider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* msg_locations; + field_location const* child_locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto mloc = msg_locations[thread_idx]; + auto cloc = child_locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; + if (mloc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + mloc.offset + cloc.offset; + } else { + cloc.offset = -1; + } + return cloc; + } +}; + +template +CUDF_KERNEL void extract_varint_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + 
// For BOOL8 (uint8_t), protobuf spec says any non-zero varint is true. + // A raw static_cast would silently truncate values >= 256 to 0. + auto const write_value = [](OutputType* dst, uint64_t val) { + if constexpr (cuda::std::is_same_v) { + *dst = static_cast(val != 0 ? 1 : 0); + } else { + *dst = static_cast(val); + } + }; + + if (loc.offset < 0) { + if (has_default) { + write_value(&out[idx], static_cast(default_value)); + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = message_data + data_offset; + uint8_t const* cur_end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, cur_end, v, n)) { + set_error_once(error_flag, ERR_VARINT); + if (valid) valid[idx] = false; + return; + } + + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + write_value(&out[idx], v); + if (valid) valid[idx] = true; +} + +template +CUDF_KERNEL void extract_fixed_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + OutputType default_value = OutputType{}) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset < 0) { + if (has_default) { + out[idx] = default_value; + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = message_data + data_offset; + OutputType value; + + if constexpr (WT == wire_type_value(proto_wire_type::I32BIT)) { + if (loc.length < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (loc.length < 8) { + set_error_once(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint64_t raw = load_le(cur); + 
memcpy(&value, &raw, sizeof(value)); + } + + out[idx] = value; + if (valid) valid[idx] = true; +} + +// ============================================================================ +// Batched scalar extraction — one 2D kernel for N fields of the same type +// ============================================================================ + +struct batched_scalar_desc { + int loc_field_idx; // index into the locations array (column within d_locations) + void* output; // pre-allocated output buffer (T*) + bool* valid; // pre-allocated validity buffer + bool has_default; + int64_t default_int; + double default_float; +}; + +template +CUDF_KERNEL void extract_varint_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) +{ + int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + int fi = static_cast(blockIdx.y); + if (row >= num_rows || fi >= num_descs) return; + + auto const& desc = descs[fi]; + auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; + auto* out = static_cast(desc.output); + + auto const write_value = [](OutputType* dst, uint64_t val) { + if constexpr (cuda::std::is_same_v) { + *dst = static_cast(val != 0 ? 
1 : 0); + } else { + *dst = static_cast(val); + } + }; + + if (loc.offset < 0) { + if (desc.has_default) { + write_value(&out[row], static_cast(desc.default_int)); + desc.valid[row] = true; + } else { + desc.valid[row] = false; + } + return; + } + + int32_t data_offset = row_offsets[row] - base_offset + loc.offset; + uint8_t const* cur = message_data + data_offset; + uint8_t const* end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, end, v, n)) { + set_error_once(error_flag, ERR_VARINT); + desc.valid[row] = false; + return; + } + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + write_value(&out[row], v); + desc.valid[row] = true; +} + +template +CUDF_KERNEL void extract_fixed_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) +{ + int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + int fi = static_cast(blockIdx.y); + if (row >= num_rows || fi >= num_descs) return; + + auto const& desc = descs[fi]; + auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; + auto* out = static_cast(desc.output); + + if (loc.offset < 0) { + if (desc.has_default) { + if constexpr (cuda::std::is_integral_v) { + out[row] = static_cast(desc.default_int); + } else { + out[row] = static_cast(desc.default_float); + } + desc.valid[row] = true; + } else { + desc.valid[row] = false; + } + return; + } + + int32_t data_offset = row_offsets[row] - base_offset + loc.offset; + uint8_t const* cur = message_data + data_offset; + OutputType value; + + if constexpr (WT == wire_type_value(proto_wire_type::I32BIT)) { + if (loc.length < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + desc.valid[row] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (loc.length < 8) { + 
set_error_once(error_flag, ERR_FIXED_LEN); + desc.valid[row] = false; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + out[row] = value; + desc.valid[row] = true; +} + +// ============================================================================ + +template +CUDF_KERNEL void extract_lengths_kernel(LocationProvider loc_provider, + int total_items, + int32_t* out_lengths, + bool has_default = false, + int32_t default_length = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset >= 0) { + out_lengths[idx] = loc.length; + } else if (has_default) { + out_lengths[idx] = default_length; + } else { + out_lengths[idx] = 0; + } +} + +// ============================================================================ +// Host-side template helpers that launch CUDA kernels +// ============================================================================ + +template +inline std::pair make_null_mask_from_valid( + rmm::device_uvector const& valid, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto begin = thrust::make_counting_iterator(0); + auto end = begin + valid.size(); + auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { + return static_cast(ptr[i]); + }; + return cudf::detail::valid_if(begin, end, pred, stream, mr); +} + +template +std::unique_ptr extract_and_build_scalar_column(cudf::data_type dt, + int num_rows, + LaunchFn&& launch_extract, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); + if (num_rows == 0) { + return std::make_unique(dt, 0, out.release(), rmm::device_buffer{}, 0); + } + launch_extract(out.data(), valid.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count); +} + +template +inline void extract_integer_into_buffers(uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + T* out_ptr, + bool* valid_ptr, + int* error_ptr, + rmm::cuda_stream_view stream) +{ + if (enable_zigzag && encoding == encoding_value(proto_encoding::ZIGZAG)) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } else if (encoding == encoding_value(proto_encoding::FIXED)) { + if constexpr (sizeof(T) == 4) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } else { + static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } + } else { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } +} + +template +std::unique_ptr extract_and_build_integer_column(cudf::data_type dt, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + rmm::device_uvector& d_error, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + return extract_and_build_scalar_column( + dt, 
+ num_rows, + [&](T* out_ptr, bool* valid_ptr) { + extract_integer_into_buffers(message_data, + loc_provider, + num_rows, + blocks, + threads, + has_default, + default_value, + encoding, + enable_zigzag, + out_ptr, + valid_ptr, + d_error.data(), + stream); + }, + stream, + mr); +} + +struct extract_strided_count { + repeated_field_info const* info; + int field_idx; + int num_fields; + + __device__ int32_t operator()(int row) const + { + return info[flat_index(static_cast(row), + static_cast(num_fields), + static_cast(field_idx))] + .count; + } +}; + +template +inline std::unique_ptr extract_and_build_string_or_bytes_column( + bool as_bytes, + uint8_t const* message_data, + int num_rows, + LengthProvider const& length_provider, + CopyProvider const& copy_provider, + ValidityFn validity_fn, + bool has_default, + cudf::detail::host_vector const& default_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + int32_t def_len = has_default ? 
static_cast(default_bytes.size()) : 0; + rmm::device_uvector d_default(0, stream, mr); + if (has_default && def_len > 0) { + d_default = cudf::detail::make_device_uvector_async( + default_bytes, stream, rmm::mr::get_current_device_resource_ref()); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((num_rows + threads - 1u) / threads); + extract_lengths_kernel<<>>( + length_provider, num_rows, lengths.data(), has_default, def_len); + + auto [offsets_col, total_size] = + cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); + + rmm::device_uvector chars(total_size, stream, mr); + if (total_size > 0) { + auto const* offsets_data = offsets_col->view().data(); + auto* chars_ptr = chars.data(); + auto const* default_ptr = d_default.data(); + + auto src_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [message_data, copy_provider, has_default, default_ptr, def_len] __device__( + int idx) -> void const* { + int32_t data_offset = 0; + auto loc = copy_provider.get(idx, data_offset); + if (loc.offset < 0) { + return (has_default && def_len > 0) ? static_cast(default_ptr) : nullptr; + } + return static_cast(message_data + data_offset); + })); + auto dst_iter = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { + return static_cast(chars_ptr + offsets_data[idx]); + })); + auto size_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [copy_provider, has_default, def_len] __device__(int idx) -> size_t { + int32_t data_offset = 0; + auto loc = copy_provider.get(idx, data_offset); + if (loc.offset < 0) { + return (has_default && def_len > 0) ? 
static_cast(def_len) : 0; + } + return static_cast(loc.length); + })); + + size_t temp_storage_bytes = 0; + cub::DeviceMemcpy::Batched( + nullptr, temp_storage_bytes, src_iter, dst_iter, size_iter, num_rows, stream.value()); + rmm::device_buffer temp_storage(temp_storage_bytes, stream, mr); + cub::DeviceMemcpy::Batched(temp_storage.data(), + temp_storage_bytes, + src_iter, + dst_iter, + size_iter, + num_rows, + stream.value()); + } + + if (num_rows == 0) { + if (as_bytes) { + auto bytes_child = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + return cudf::make_lists_column( + 0, std::move(offsets_col), std::move(bytes_child), 0, rmm::device_buffer{}); + } + return cudf::make_strings_column( + 0, std::move(offsets_col), chars.release(), 0, rmm::device_buffer{}); + } + + rmm::device_uvector valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.data(), + validity_fn); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + if (as_bytes) { + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_size, + rmm::device_buffer(chars.data(), total_size, stream, mr), + rmm::device_buffer{}, + 0); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask)); + } + + return cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); +} + +template +inline std::unique_ptr extract_typed_column( + cudf::data_type dt, + int encoding, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_items, + int blocks, + int threads_per_block, + bool has_default, + int64_t default_int, + double default_float, + bool default_bool, + cudf::detail::host_vector const& default_string, + int schema_idx, + std::vector> const& 
enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices = nullptr, + bool propagate_invalid_rows = true) +{ + switch (dt.id()) { + case cudf::type_id::BOOL8: { + int64_t def_val = has_default ? (default_bool ? 1 : 0) : 0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](uint8_t* out_ptr, bool* valid_ptr) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_val); + }, + stream, + mr); + } + case cudf::type_id::INT32: { + if (num_items == 0) { + return std::make_unique(dt, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + } + rmm::device_uvector out(num_items, stream, mr); + rmm::device_uvector valid(num_items, stream, mr); + extract_integer_into_buffers(message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + has_default, + default_int, + encoding, + true, + out.data(), + valid.data(), + d_error.data(), + stream); + if (schema_idx < static_cast(enum_valid_values.size())) { + auto const& valid_enums = enum_valid_values[schema_idx]; + if (!valid_enums.empty()) { + validate_enum_and_propagate_rows(out, + valid, + valid_enums, + d_row_force_null, + num_items, + top_row_indices, + propagate_invalid_rows, + stream); + } + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique( + dt, num_items, out.release(), std::move(mask), null_count); + } + case cudf::type_id::UINT32: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::INT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + 
has_default, + default_int, + encoding, + true, + stream, + mr); + case cudf::type_id::UINT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::FLOAT32: { + float def_float_val = has_default ? static_cast(default_float) : 0.0f; + return extract_and_build_scalar_column( + dt, + num_items, + [&](float* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_float_val); + }, + stream, + mr); + } + case cudf::type_id::FLOAT64: { + double def_double = has_default ? default_float : 0.0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](double* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_double); + }, + stream, + mr); + } + default: return make_null_column(dt, num_items, stream, mr); + } +} + +template +inline std::unique_ptr build_repeated_scalar_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const input_null_count = binary_input.null_count(); + + if (total_count == 0) { + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto elem_type = 
field_desc.output_type_id == static_cast(cudf::type_id::LIST) + ? cudf::type_id::UINT8 + : static_cast(field_desc.output_type_id); + auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); + + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } else { + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); + } + } + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + d_field_counts.begin(), + d_field_counts.end(), + list_offs.begin(), + 0); + + int32_t total_count_i32 = static_cast(total_count); + thrust::fill_n(rmm::exec_policy_nosync(stream), list_offs.data() + num_rows, 1, total_count_i32); + + rmm::device_uvector values(total_count, stream, mr); + + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((total_count + threads - 1u) / threads); + + int encoding = field_desc.encoding; + bool zigzag = (encoding == encoding_value(proto_encoding::ZIGZAG)); + + constexpr bool is_floating_point = std::is_same_v || std::is_same_v; + bool use_fixed_kernel = is_floating_point || (encoding == encoding_value(proto_encoding::FIXED)); + + repeated_location_provider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + if (use_fixed_kernel) { + if constexpr (sizeof(T) == 4) { + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } else { + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } + } else if (zigzag) { + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } else { + extract_varint_kernel + <<>>( + message_data, 
loc_provider, total_count, values.data(), nullptr, d_error.data()); + } + + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + auto child_col = std::make_unique( + cudf::data_type{static_cast(field_desc.output_type_id)}, + total_count, + values.release(), + rmm::device_buffer{}, + 0); + + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } + + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); +} + +} // namespace cudf::io::protobuf::detail diff --git a/cpp/src/io/protobuf/types.cuh b/cpp/src/io/protobuf/types.cuh new file mode 100644 index 00000000000..3d1c9324662 --- /dev/null +++ b/cpp/src/io/protobuf/types.cuh @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cudf::io::protobuf::detail { + +// Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. +constexpr int MAX_VARINT_BYTES = 10; + +// CUDA kernel launch configuration. +constexpr int THREADS_PER_BLOCK = 256; + +// Error codes for kernel error reporting. 
// Error codes for kernel error reporting. Values are distinct and start at 1.
constexpr int ERR_BOUNDS                  = 1;
constexpr int ERR_VARINT                  = 2;
constexpr int ERR_FIELD_NUMBER            = 3;
constexpr int ERR_WIRE_TYPE               = 4;
constexpr int ERR_OVERFLOW                = 5;
constexpr int ERR_FIELD_SIZE              = 6;
constexpr int ERR_SKIP                    = 7;
constexpr int ERR_FIXED_LEN               = 8;
constexpr int ERR_REQUIRED                = 9;
constexpr int ERR_SCHEMA_TOO_LARGE        = 10;
constexpr int ERR_MISSING_ENUM_META       = 11;
constexpr int ERR_REPEATED_COUNT_MISMATCH = 12;

// Threshold for using a direct-mapped lookup table for field_number -> field_index.
// Field numbers above this threshold fall back to linear search.
constexpr int FIELD_LOOKUP_TABLE_MAX = 4096;

/**
 * @brief Location of one field's data within a message.
 *
 * A negative `offset` means the field was not found; consumers test `offset < 0`
 * before reading `length`.
 */
struct field_location {
  int32_t offset;  // Offset of field data within the message (-1 if not found)
  int32_t length;  // Length of field data in bytes
};

/**
 * @brief Field descriptor passed to the scanning kernel.
 */
struct field_descriptor {
  int field_number;        // Protobuf field number
  int expected_wire_type;  // Expected wire type for this field
  bool is_repeated;        // Repeated children are scanned via count/scan kernels
};

/**
 * @brief Per-row summary of a repeated field's occurrences.
 */
struct repeated_field_info {
  int32_t count;         // Number of occurrences in this row
  int32_t total_length;  // Total bytes for all occurrences (for varlen fields)
};

/**
 * @brief Location of a single occurrence of a repeated field.
 */
struct repeated_occurrence {
  int32_t row_idx;  // Which row this occurrence belongs to
  int32_t offset;   // Offset within the message
  int32_t length;   // Length of the field data
};
/**
 * @brief Per-field descriptor passed to the combined occurrence scan kernel.
 *
 * Contains device pointers so the kernel can write to each field's output.
 */
struct repeated_field_scan_desc {
  int field_number;                  // Protobuf field number to match
  int wire_type;                     // Wire type stored as a plain int
  int32_t const* row_offsets;        // Pre-computed prefix-sum offsets [num_rows + 1]
  repeated_occurrence* occurrences;  // Output buffer [total_count]
};

/**
 * @brief Device-side descriptor for nested schema fields.
 *
 * Mirrors `nested_field_descriptor` with its enum members (wire type, output type,
 * encoding) flattened to `int`.
 */
struct device_nested_field_descriptor {
  int field_number;
  int parent_idx;
  int depth;
  int wire_type;
  int output_type_id;
  int encoding;
  bool is_repeated;
  bool is_required;
  bool has_default_value;

  // Defaulted: members are left uninitialized unless the object is value-initialized.
  device_nested_field_descriptor() = default;

  // Wire type and encoding are stored as int (not typed enums) because CUDA device code
  // historically had limited constexpr enum support, and the kernel comparison sites use
  // int-typed wire_type_value()/encoding_value() helpers throughout.
  explicit device_nested_field_descriptor(nested_field_descriptor const& src)
    : field_number(src.field_number),
      parent_idx(src.parent_idx),
      depth(src.depth),
      wire_type(static_cast<int>(src.wire_type)),
      output_type_id(static_cast<int>(src.output_type)),
      encoding(static_cast<int>(src.encoding)),
      is_repeated(src.is_repeated),
      is_required(src.is_required),
      has_default_value(src.has_default_value)
  {
  }
};
// Protobuf wire types used by the encoders below.
constexpr int WT_VARINT = 0;
constexpr int WT_LEN    = 2;

/**
 * @brief Encode `value` as a protobuf base-128 varint (little-endian groups of 7 bits).
 *
 * Every byte except the last has its high bit set. Zero encodes as the single byte 0x00.
 */
std::vector<uint8_t> encode_varint(uint64_t value)
{
  std::vector<uint8_t> out;
  out.reserve(10);  // a 64-bit value never needs more than 10 varint bytes
  while (value > 0x7F) {
    out.push_back(static_cast<uint8_t>((value & 0x7F) | 0x80));
    value >>= 7;
  }
  out.push_back(static_cast<uint8_t>(value));
  return out;
}

/**
 * @brief Concatenate byte vectors in order into one buffer.
 */
std::vector<uint8_t> concat(std::initializer_list<std::vector<uint8_t>> parts)
{
  // Size the output once instead of letting each insert trigger a reallocation.
  std::size_t total = 0;
  for (auto const& p : parts) {
    total += p.size();
  }

  std::vector<uint8_t> out;
  out.reserve(total);
  for (auto const& p : parts) {
    out.insert(out.end(), p.begin(), p.end());
  }
  return out;
}

/**
 * @brief Encode a field tag: `(field_number << 3) | wire_type`, as a varint.
 */
std::vector<uint8_t> tag(int field_number, int wire_type)
{
  return encode_varint((static_cast<uint64_t>(field_number) << 3) |
                       static_cast<uint64_t>(wire_type));
}

/**
 * @brief Encode a complete varint field: tag followed by the varint payload.
 */
std::vector<uint8_t> encode_varint_field(int field_number, uint64_t value)
{
  return concat({tag(field_number, WT_VARINT), encode_varint(value)});
}

/**
 * @brief Encode a complete length-delimited field: tag, byte length, then the bytes of `s`.
 */
std::vector<uint8_t> encode_string_field(int field_number, std::string const& s)
{
  auto const header = concat({tag(field_number, WT_LEN), encode_varint(s.size())});

  std::vector<uint8_t> out;
  out.reserve(header.size() + s.size());
  out.insert(out.end(), header.begin(), header.end());
  out.insert(out.end(), s.begin(), s.end());
  return out;
}
offsets.end()).release(); + auto data_col = + cudf::test::fixed_width_column_wrapper(flat_data.begin(), flat_data.end()).release(); + + auto num_rows = static_cast(messages.size()); + + if (!validity.empty()) { + auto [null_mask, null_count] = + cudf::test::detail::make_null_mask(validity.begin(), validity.end()); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(data_col), null_count, std::move(null_mask)); + } + + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(data_col), 0, rmm::device_buffer{}); +} + +pb::decode_protobuf_options make_scalar_options(std::vector const& field_numbers, + std::vector const& types, + std::vector const& encodings, + bool fail_on_errors = true) +{ + int const n = static_cast(field_numbers.size()); + + auto derive_wire_type = [](cudf::type_id type, int enc) -> pb::proto_wire_type { + if (enc == static_cast(pb::proto_encoding::FIXED)) { + if (type == cudf::type_id::INT64 || type == cudf::type_id::UINT64 || + type == cudf::type_id::FLOAT64) { + return pb::proto_wire_type::I64BIT; + } + return pb::proto_wire_type::I32BIT; + } + switch (type) { + case cudf::type_id::FLOAT32: return pb::proto_wire_type::I32BIT; + case cudf::type_id::FLOAT64: return pb::proto_wire_type::I64BIT; + case cudf::type_id::STRING: + case cudf::type_id::LIST: + case cudf::type_id::STRUCT: return pb::proto_wire_type::LEN; + default: return pb::proto_wire_type::VARINT; + } + }; + + std::vector schema; + schema.reserve(n); + for (int i = 0; i < n; ++i) { + schema.push_back({field_numbers[i], + -1, + 0, + derive_wire_type(types[i], encodings[i]), + types[i], + static_cast(encodings[i]), + false, + false, + false}); + } + + std::vector> default_strings; + std::vector> enum_valid; + default_strings.reserve(n); + enum_valid.reserve(n); + for (int i = 0; i < n; ++i) { + default_strings.push_back( + cudf::detail::make_host_vector(0, cudf::get_default_stream())); + 
enum_valid.push_back(cudf::detail::make_host_vector(0, cudf::get_default_stream())); + } + + return pb::decode_protobuf_options{ + std::move(schema), + std::vector(n, 0), + std::vector(n, 0.0), + std::vector(n, false), + std::move(default_strings), + std::move(enum_valid), + std::vector>>(n), + fail_on_errors, + }; +} + +auto make_empty_host_vectors(int count) +{ + struct result { + std::vector> hv; + std::vector> iv; + }; + result r; + r.hv.reserve(count); + r.iv.reserve(count); + for (int i = 0; i < count; ++i) { + r.hv.push_back(cudf::detail::make_host_vector(0, cudf::get_default_stream())); + r.iv.push_back(cudf::detail::make_host_vector(0, cudf::get_default_stream())); + } + return r; +} + +} // anonymous namespace + +// ============================================================================ +// Test fixture +// ============================================================================ + +struct ProtobufReaderTest : public cudf::test::BaseFixture {}; + +// ============================================================================ +// Part0 tests: output shape, type structure, and null propagation +// (Stub decode returns all-null columns with correct types) +// ============================================================================ + +TEST_F(ProtobufReaderTest, EmptySchema) +{ + auto input = make_binary_column({encode_varint_field(1, 42), encode_varint_field(1, 7)}); + + pb::decode_protobuf_options options{{}, {}, {}, {}, {}, {}, {}, true}; + + auto result = pb::decode_protobuf(*input, options); + + ASSERT_EQ(result->type().id(), cudf::type_id::STRUCT); + ASSERT_EQ(result->size(), 2); + ASSERT_EQ(result->num_children(), 0); +} + +TEST_F(ProtobufReaderTest, ZeroRows) +{ + auto input = make_binary_column({}); + auto options = make_scalar_options({1, 2}, {cudf::type_id::INT64, cudf::type_id::STRING}, {0, 0}); + + auto result = pb::decode_protobuf(*input, options); + + ASSERT_EQ(result->type().id(), cudf::type_id::STRUCT); + ASSERT_EQ(result->size(), 
0); + ASSERT_EQ(result->num_children(), 2); + EXPECT_EQ(result->child(0).type().id(), cudf::type_id::INT64); + EXPECT_EQ(result->child(1).type().id(), cudf::type_id::STRING); +} + +TEST_F(ProtobufReaderTest, ZeroRowsNestedSchema) +{ + // [0: id(INT32), 1: inner(STRUCT), 2: name(STRING, parent=1)] + int const n = 3; + std::vector schema = { + {1, + -1, + 0, + pb::proto_wire_type::VARINT, + cudf::type_id::INT32, + pb::proto_encoding::DEFAULT, + false, + false, + false}, + {2, + -1, + 0, + pb::proto_wire_type::LEN, + cudf::type_id::STRUCT, + pb::proto_encoding::DEFAULT, + false, + false, + false}, + {1, + 1, + 1, + pb::proto_wire_type::LEN, + cudf::type_id::STRING, + pb::proto_encoding::DEFAULT, + false, + false, + false}, + }; + + auto [hv, iv] = make_empty_host_vectors(n); + + pb::decode_protobuf_options options{ + std::move(schema), + std::vector(n, 0), + std::vector(n, 0.0), + std::vector(n, false), + std::move(hv), + std::move(iv), + std::vector>>(n), + true}; + + auto result = pb::decode_protobuf(*make_binary_column({}), options); + + ASSERT_EQ(result->size(), 0); + ASSERT_EQ(result->num_children(), 2); + EXPECT_EQ(result->child(0).type().id(), cudf::type_id::INT32); + EXPECT_EQ(result->child(1).type().id(), cudf::type_id::STRUCT); + EXPECT_EQ(result->child(1).num_children(), 1); + EXPECT_EQ(result->child(1).child(0).type().id(), cudf::type_id::STRING); +} + +TEST_F(ProtobufReaderTest, ZeroRowsRepeatedSchema) +{ + int const n = 1; + std::vector schema = { + {1, + -1, + 0, + pb::proto_wire_type::VARINT, + cudf::type_id::INT32, + pb::proto_encoding::DEFAULT, + true, + false, + false}, + }; + + auto [hv, iv] = make_empty_host_vectors(n); + + pb::decode_protobuf_options options{ + std::move(schema), + std::vector(n, 0), + std::vector(n, 0.0), + std::vector(n, false), + std::move(hv), + std::move(iv), + std::vector>>(n), + true}; + + auto result = pb::decode_protobuf(*make_binary_column({}), options); + + ASSERT_EQ(result->size(), 0); + 
ASSERT_EQ(result->num_children(), 1); + EXPECT_EQ(result->child(0).type().id(), cudf::type_id::LIST); +} + +TEST_F(ProtobufReaderTest, StubReturnsAllNullWithCorrectTypes) +{ + auto input = make_binary_column({encode_varint_field(1, 42), encode_string_field(2, "hello")}); + + auto options = make_scalar_options({1, 2}, {cudf::type_id::INT64, cudf::type_id::STRING}, {0, 0}); + + auto result = pb::decode_protobuf(*input, options); + + ASSERT_EQ(result->type().id(), cudf::type_id::STRUCT); + ASSERT_EQ(result->size(), 2); + ASSERT_EQ(result->num_children(), 2); + EXPECT_EQ(result->child(0).type().id(), cudf::type_id::INT64); + EXPECT_EQ(result->child(1).type().id(), cudf::type_id::STRING); + EXPECT_EQ(result->child(0).null_count(), 2); + EXPECT_EQ(result->child(1).null_count(), 2); +} + +TEST_F(ProtobufReaderTest, NullInputRowsPropagateToStruct) +{ + auto msg = encode_varint_field(1, 42); + auto input = make_binary_column({msg, {}, msg}, {true, false, true}); + + auto options = make_scalar_options({1}, {cudf::type_id::INT64}, {0}); + + auto result = pb::decode_protobuf(*input, options); + + ASSERT_EQ(result->size(), 3); + EXPECT_EQ(result->null_count(), 1); +} + +TEST_F(ProtobufReaderTest, MultipleNumericTypesShape) +{ + auto input = make_binary_column({encode_varint_field(1, 1)}); + + auto options = make_scalar_options({1, 2, 3, 4, 5}, + {cudf::type_id::BOOL8, + cudf::type_id::INT32, + cudf::type_id::INT64, + cudf::type_id::FLOAT32, + cudf::type_id::FLOAT64}, + {0, 0, 0, 0, 0}); + + auto result = pb::decode_protobuf(*input, options); + + ASSERT_EQ(result->num_children(), 5); + EXPECT_EQ(result->child(0).type().id(), cudf::type_id::BOOL8); + EXPECT_EQ(result->child(1).type().id(), cudf::type_id::INT32); + EXPECT_EQ(result->child(2).type().id(), cudf::type_id::INT64); + EXPECT_EQ(result->child(3).type().id(), cudf::type_id::FLOAT32); + EXPECT_EQ(result->child(4).type().id(), cudf::type_id::FLOAT64); +} + +CUDF_TEST_PROGRAM_MAIN() diff --git 
  /**
   * Decode serialized protobuf messages into a STRUCT column.
   *
   * <p>Takes a LIST column (INT8 or UINT8 elements) where each row contains a serialized
   * protobuf message and decodes it into a STRUCT column according to the provided schema.
   *
   * <p>Supports nested messages (up to 10 levels), repeated fields (as LIST columns),
   * enum-as-string conversion, default values, and required field checking.
   *
   * <p>The LIST type requirement is checked with a Java {@code assert}, so it is only
   * enforced when assertions are enabled. The element type of the list is not checked here.
   *
   * @param schema descriptor containing the flattened protobuf schema
   * @param failOnErrors if true, throw on malformed messages; otherwise return nulls
   * @return a STRUCT column with decoded protobuf fields
   */
  public final ColumnVector decodeProtobuf(ProtobufSchemaDescriptor schema,
                                           boolean failOnErrors) {
    assert type.equals(DType.LIST) : "column type must be a LIST";
    // The descriptor's parallel arrays are passed straight through to the native decoder.
    return new ColumnVector(decodeProtobuf(getNativeView(),
        schema.fieldNumbers, schema.parentIndices, schema.depthLevels,
        schema.wireTypes, schema.outputTypeIds, schema.encodings,
        schema.isRepeated, schema.isRequired, schema.hasDefaultValue,
        schema.defaultInts, schema.defaultFloats, schema.defaultBools,
        schema.defaultStrings, schema.enumValidValues, schema.enumNames,
        failOnErrors));
  }
@@ -4588,6 +4613,14 @@ private static native long repeatStringsWithColumnRepeatTimes(long stringsHandle private static native long getJSONObject(long viewHandle, long scalarHandle, boolean allowSingleQuotes, boolean stripQuotesFromSingleStrings, boolean missingFieldsAsNulls) throws CudfException; + private static native long decodeProtobuf(long viewHandle, + int[] fieldNumbers, int[] parentIndices, int[] depthLevels, + int[] wireTypes, int[] outputTypeIds, int[] encodings, + boolean[] isRepeated, boolean[] isRequired, boolean[] hasDefaultValue, + long[] defaultInts, double[] defaultFloats, boolean[] defaultBools, + byte[][] defaultStrings, int[][] enumValidValues, byte[][][] enumNames, + boolean failOnErrors) throws CudfException; + /** * Native method to parse and convert a timestamp column vector to string column vector. A unix * timestamp is a long value representing how many units since 1970-01-01 00:00:00:000 in either diff --git a/java/src/main/java/ai/rapids/cudf/ProtobufSchemaDescriptor.java b/java/src/main/java/ai/rapids/cudf/ProtobufSchemaDescriptor.java new file mode 100644 index 00000000000..086549b9e2a --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ProtobufSchemaDescriptor.java @@ -0,0 +1,213 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +package ai.rapids.cudf; + +import java.util.HashSet; +import java.util.Set; + +/** + * Immutable descriptor for a flattened protobuf schema, grouping the parallel arrays + * that describe field structure, types, defaults, and enum metadata. + * + *

All arrays provided to the constructor are defensively copied to guarantee immutability. + */ +public final class ProtobufSchemaDescriptor implements java.io.Serializable { + private static final long serialVersionUID = 1L; + private static final int MAX_FIELD_NUMBER = (1 << 29) - 1; + private static final int MAX_NESTING_DEPTH = 10; + private static final int STRUCT_TYPE_ID = DType.STRUCT.getTypeId().getNativeId(); + private static final int STRING_TYPE_ID = DType.STRING.getTypeId().getNativeId(); + private static final int LIST_TYPE_ID = DType.LIST.getTypeId().getNativeId(); + private static final int BOOL8_TYPE_ID = DType.BOOL8.getTypeId().getNativeId(); + private static final int INT32_TYPE_ID = DType.INT32.getTypeId().getNativeId(); + private static final int UINT32_TYPE_ID = DType.UINT32.getTypeId().getNativeId(); + private static final int INT64_TYPE_ID = DType.INT64.getTypeId().getNativeId(); + private static final int UINT64_TYPE_ID = DType.UINT64.getTypeId().getNativeId(); + private static final int FLOAT32_TYPE_ID = DType.FLOAT32.getTypeId().getNativeId(); + private static final int FLOAT64_TYPE_ID = DType.FLOAT64.getTypeId().getNativeId(); + + // Encoding constants + public static final int ENC_DEFAULT = 0; + public static final int ENC_FIXED = 1; + public static final int ENC_ZIGZAG = 2; + public static final int ENC_ENUM_STRING = 3; + + // Wire type constants + public static final int WT_VARINT = 0; + public static final int WT_64BIT = 1; + public static final int WT_LEN = 2; + public static final int WT_32BIT = 5; + + final int[] fieldNumbers; + final int[] parentIndices; + final int[] depthLevels; + final int[] wireTypes; + final int[] outputTypeIds; + final int[] encodings; + final boolean[] isRepeated; + final boolean[] isRequired; + final boolean[] hasDefaultValue; + final long[] defaultInts; + final double[] defaultFloats; + final boolean[] defaultBools; + final byte[][] defaultStrings; + final int[][] enumValidValues; + final byte[][][] 
enumNames;
+
+  /**
+   * Creates a schema descriptor from parallel per-field arrays.
+   *
+   * The arguments are validated first and then copied (nested arrays are copied level by
+   * level), so the descriptor does not keep references to the caller's arrays.
+   *
+   * @throws IllegalArgumentException if any array is null, arrays have mismatched lengths,
+   *         field numbers are out of range, a nested field's depth is not its parent's
+   *         depth plus one, a field's enum names and valid values differ in length, or
+   *         encoding values are invalid.
+   */
+  public ProtobufSchemaDescriptor(
+      int[] fieldNumbers,
+      int[] parentIndices,
+      int[] depthLevels,
+      int[] wireTypes,
+      int[] outputTypeIds,
+      int[] encodings,
+      boolean[] isRepeated,
+      boolean[] isRequired,
+      boolean[] hasDefaultValue,
+      long[] defaultInts,
+      double[] defaultFloats,
+      boolean[] defaultBools,
+      byte[][] defaultStrings,
+      int[][] enumValidValues,
+      byte[][][] enumNames) {
+
+    validate(fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds,
+        encodings, isRepeated, isRequired, hasDefaultValue, defaultInts,
+        defaultFloats, defaultBools, defaultStrings, enumValidValues, enumNames);
+
+    this.fieldNumbers = fieldNumbers.clone();
+    this.parentIndices = parentIndices.clone();
+    this.depthLevels = depthLevels.clone();
+    this.wireTypes = wireTypes.clone();
+    this.outputTypeIds = outputTypeIds.clone();
+    this.encodings = encodings.clone();
+    this.isRepeated = isRepeated.clone();
+    this.isRequired = isRequired.clone();
+    this.hasDefaultValue = hasDefaultValue.clone();
+    this.defaultInts = defaultInts.clone();
+    this.defaultFloats = defaultFloats.clone();
+    this.defaultBools = defaultBools.clone();
+    this.defaultStrings = deepCopy(defaultStrings);
+    this.enumValidValues = deepCopy(enumValidValues);
+    this.enumNames = deepCopy(enumNames);
+  }
+
+  /** Returns the number of fields, i.e. the length of the field-number array. */
+  public int numFields() { return fieldNumbers.length; }
+
+  /**
+   * Re-runs {@code validate} after default deserialization so that a stream with inconsistent
+   * arrays is rejected. A validation failure is rethrown as
+   * {@link java.io.InvalidObjectException} with the original exception as its cause.
+   */
+  private void readObject(java.io.ObjectInputStream in)
+      throws java.io.IOException, ClassNotFoundException {
+    in.defaultReadObject();
+    try {
+      validate(fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds,
+          encodings, isRepeated, isRequired, hasDefaultValue, defaultInts,
+          defaultFloats, defaultBools, defaultStrings, enumValidValues, enumNames);
+    } catch (IllegalArgumentException e) {
+      java.io.InvalidObjectException ioe = new java.io.InvalidObjectException(e.getMessage());
+      ioe.initCause(e);
+      throw ioe;
+    }
+  }
+
+  /** Copies the outer array and clones every non-null row; null rows stay null. */
+  private static byte[][] deepCopy(byte[][] src) {
+    byte[][] dst = new byte[src.length][];
+    for (int i = 0; i < src.length; i++) {
+      dst[i] = src[i] != null ? src[i].clone() : null;
+    }
+    return dst;
+  }
+
+  /** Copies the outer array and clones every non-null row; null rows stay null. */
+  private static int[][] deepCopy(int[][] src) {
+    int[][] dst = new int[src.length][];
+    for (int i = 0; i < src.length; i++) {
+      dst[i] = src[i] != null ? src[i].clone() : null;
+    }
+    return dst;
+  }
+
+  /** Copies all three levels; null entries at the second or third level stay null. */
+  private static byte[][][] deepCopy(byte[][][] src) {
+    byte[][][] dst = new byte[src.length][][];
+    for (int i = 0; i < src.length; i++) {
+      // A freshly allocated dst[i] is already null, so null rows need no work.
+      if (src[i] == null) continue;
+      dst[i] = new byte[src[i].length][];
+      for (int j = 0; j < src[i].length; j++) {
+        dst[i][j] = src[i][j] != null ? src[i][j].clone() : null;
+      }
+    }
+    return dst;
+  }
+
+  /**
+   * Checks the per-field arrays for structural consistency.
+   *
+   * For every field this verifies: the field number range, the depth range, that the parent
+   * index is -1 or refers to an earlier STRUCT field, that a top-level field has depth 0 and
+   * a nested field has its parent's depth plus one, that field numbers are unique under one
+   * parent, that the wire type and encoding are known values, and that a field's enum names
+   * and valid values have equal lengths when both are present.
+   *
+   * @throws IllegalArgumentException describing the first violation found
+   */
+  private static void validate(
+      int[] fieldNumbers, int[] parentIndices, int[] depthLevels,
+      int[] wireTypes, int[] outputTypeIds, int[] encodings,
+      boolean[] isRepeated, boolean[] isRequired, boolean[] hasDefaultValue,
+      long[] defaultInts, double[] defaultFloats, boolean[] defaultBools,
+      byte[][] defaultStrings, int[][] enumValidValues, byte[][][] enumNames) {
+
+    if (fieldNumbers == null || parentIndices == null || depthLevels == null ||
+        wireTypes == null || outputTypeIds == null || encodings == null ||
+        isRepeated == null || isRequired == null || hasDefaultValue == null ||
+        defaultInts == null || defaultFloats == null || defaultBools == null ||
+        defaultStrings == null || enumValidValues == null || enumNames == null) {
+      throw new IllegalArgumentException("All schema arrays must be non-null");
+    }
+
+    int n = fieldNumbers.length;
+    if (parentIndices.length != n || depthLevels.length != n ||
+        wireTypes.length != n || outputTypeIds.length != n ||
+        encodings.length != n || isRepeated.length != n ||
+        isRequired.length != n || hasDefaultValue.length != n ||
+        defaultInts.length != n || defaultFloats.length != n ||
+        defaultBools.length != n || defaultStrings.length != n ||
+        enumValidValues.length != n || enumNames.length != n) {
+      throw new IllegalArgumentException("All schema arrays must have the same length");
+    }
+
+    Set<Long> seenFieldNumbers = new HashSet<>();
+    for (int i = 0; i < n; i++) {
+      if (fieldNumbers[i] <= 0 || fieldNumbers[i] > MAX_FIELD_NUMBER) {
+        throw new IllegalArgumentException(
+            "Invalid field number at index " + i + ": " + fieldNumbers[i]);
+      }
+      if (depthLevels[i] < 0 || depthLevels[i] >= MAX_NESTING_DEPTH) {
+        throw new IllegalArgumentException(
+            "Invalid depth at index " + i + ": " + depthLevels[i]);
+      }
+      int pi = parentIndices[i];
+      // pi >= i is rejected, so a parent always appears before its children.
+      if (pi < -1 || pi >= i) {
+        throw new IllegalArgumentException(
+            "Invalid parent index at index " + i + ": " + pi);
+      }
+      if (pi == -1) {
+        if (depthLevels[i] != 0) {
+          throw new IllegalArgumentException(
+              "Top-level field at index " + i + " must have depth 0");
+        }
+      } else {
+        if (outputTypeIds[pi] != STRUCT_TYPE_ID) {
+          throw new IllegalArgumentException(
+              "Parent at index " + pi + " for field " + i + " must be STRUCT");
+        }
+        // Without this check a nested field could claim any depth below the limit,
+        // which would make the depth bound above meaningless.
+        if (depthLevels[i] != depthLevels[pi] + 1) {
+          throw new IllegalArgumentException(
+              "Field at index " + i + " must have depth " + (depthLevels[pi] + 1) +
+              " (parent depth + 1) but has depth " + depthLevels[i]);
+        }
+      }
+      // Pack (parent index, field number) into one long so uniqueness is checked per parent.
+      long fieldKey = (((long) pi) << 32) | (fieldNumbers[i] & 0xFFFFFFFFL);
+      if (!seenFieldNumbers.add(fieldKey)) {
+        throw new IllegalArgumentException(
+            "Duplicate field number " + fieldNumbers[i] + " under parent " + pi);
+      }
+      int wt = wireTypes[i];
+      if (wt != WT_VARINT && wt != WT_64BIT && wt != WT_LEN && wt != WT_32BIT) {
+        throw new IllegalArgumentException("Invalid wire type at index " + i + ": " + wt);
+      }
+      int enc = encodings[i];
+      if (enc < ENC_DEFAULT || enc > ENC_ENUM_STRING) {
+        throw new IllegalArgumentException("Invalid encoding at index " + i + ": " + enc);
+      }
+      // Names and numbers are supplied as two per-field arrays; when both are given they
+      // must pair up one to one.
+      if (enumValidValues[i] != null && enumNames[i] != null &&
+          enumValidValues[i].length != enumNames[i].length) {
+        throw new IllegalArgumentException(
+            "Enum names and valid values at index " + i + " must have the same length");
+      }
+    }
+  }
+}
diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt
index 1e7df3802b9..373ef2a1354 100644
--- a/java/src/main/native/CMakeLists.txt
+++ b/java/src/main/native/CMakeLists.txt
@@ -162,6 +162,7 @@ add_library(
   src/ColumnViewJni.cu
   src/CompiledExpression.cpp
+  src/ContiguousTableJni.cpp
+  src/ProtobufJni.cpp
   src/DataSourceHelperJni.cpp
   src/DeletionVectorJni.cpp
   src/HashJoinJni.cpp
diff --git a/java/src/main/native/src/ProtobufJni.cpp b/java/src/main/native/src/ProtobufJni.cpp
new file mode 100644
index 00000000000..f6cc7492f72
--- /dev/null
+++ b/java/src/main/native/src/ProtobufJni.cpp
@@ -0,0 +1,226 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#include "cudf_jni_apis.hpp"
+#include "jni_utils.hpp"
+
+// NOTE(review): the four header names below and the template arguments in this file were
+// reconstructed from how the symbols are used here; confirm against the original change.
+#include <cudf/detail/utilities/vector_factories.hpp>
+#include <cudf/io/protobuf.hpp>
+#include <cudf/utilities/default_stream.hpp>
+#include <cudf/utilities/memory_resource.hpp>
+
+namespace {
+
+/**
+ * @brief Copy a Java `byte[]` into a host vector.
+ *
+ * Returns an empty vector when `obj` is null or when the JVM does not hand out the array
+ * elements; the caller in this file checks `env->ExceptionCheck()` after each conversion.
+ */
+cudf::detail::host_vector<uint8_t> jni_byte_array_to_host_vector(JNIEnv* env, jobject obj)
+{
+  if (obj == nullptr) {
+    return cudf::detail::make_host_vector<uint8_t>(0, cudf::get_default_stream());
+  }
+  auto byte_arr = static_cast<jbyteArray>(obj);
+  jsize len     = env->GetArrayLength(byte_arr);
+  jbyte* bytes  = env->GetByteArrayElements(byte_arr, nullptr);
+  if (bytes == nullptr) {
+    return cudf::detail::make_host_vector<uint8_t>(0, cudf::get_default_stream());
+  }
+  auto vec = cudf::detail::make_host_vector<uint8_t>(len, cudf::get_default_stream());
+  std::copy(
+    reinterpret_cast<uint8_t*>(bytes), reinterpret_cast<uint8_t*>(bytes) + len, vec.begin());
+  // JNI_ABORT: release without copying the (unmodified) elements back to the Java array.
+  env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT);
+  return vec;
+}
+
+/**
+ * @brief Copy a Java `int[]` into a host vector.
+ *
+ * Null and failure handling are the same as in jni_byte_array_to_host_vector().
+ */
+cudf::detail::host_vector<int32_t> jni_int_array_to_host_vector(JNIEnv* env, jobject obj)
+{
+  if (obj == nullptr) {
+    return cudf::detail::make_host_vector<int32_t>(0, cudf::get_default_stream());
+  }
+  auto int_arr = static_cast<jintArray>(obj);
+  jsize len    = env->GetArrayLength(int_arr);
+  jint* ints   = env->GetIntArrayElements(int_arr, nullptr);
+  if (ints == nullptr) {
+    return cudf::detail::make_host_vector<int32_t>(0, cudf::get_default_stream());
+  }
+  auto vec = cudf::detail::make_host_vector<int32_t>(len, cudf::get_default_stream());
+  std::copy(ints, ints + len, vec.begin());
+  env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT);
+  return vec;
+}
+
+/**
+ * @brief Convert the first `num_elements` entries of a Java object array using `convert`.
+ *
+ * Returns an empty vector as soon as a Java exception is pending, so callers must check
+ * `env->ExceptionCheck()` before using the result.
+ */
+template <typename T, typename ConvertFn>
+std::vector<T> jni_object_array_to_vectors(JNIEnv* env,
+                                           jobjectArray arr,
+                                           int num_elements,
+                                           ConvertFn convert)
+{
+  std::vector<T> result;
+  result.reserve(num_elements);
+  for (int i = 0; i < num_elements; ++i) {
+    jobject elem = env->GetObjectArrayElement(arr, i);
+    if (env->ExceptionCheck()) { return {}; }
+    auto vec = convert(env, elem);
+    if (elem != nullptr) { env->DeleteLocalRef(elem); }
+    if (env->ExceptionCheck()) { return {}; }
+    result.push_back(std::move(vec));
+  }
+  return result;
+}
+
+}  // anonymous namespace
+
+extern "C" {
+
+/**
+ * @brief JNI entry point that decodes a column of serialized protobuf messages.
+ *
+ * Builds `decode_protobuf_options` from the per-field Java arrays and returns the handle of
+ * the decoded column, or 0 when a Java exception is pending.
+ */
+JNIEXPORT jlong JNICALL
+Java_ai_rapids_cudf_ColumnView_decodeProtobuf(JNIEnv* env,
+                                              jclass,
+                                              jlong j_view_handle,
+                                              jintArray field_numbers,
+                                              jintArray parent_indices,
+                                              jintArray depth_levels,
+                                              jintArray wire_types,
+                                              jintArray output_type_ids,
+                                              jintArray encodings,
+                                              jbooleanArray is_repeated,
+                                              jbooleanArray is_required,
+                                              jbooleanArray has_default_value,
+                                              jlongArray default_ints,
+                                              jdoubleArray default_floats,
+                                              jbooleanArray default_bools,
+                                              jobjectArray default_strings,
+                                              jobjectArray enum_valid_values,
+                                              jobjectArray enum_names,
+                                              jboolean fail_on_errors)
+{
+  JNI_NULL_CHECK(env, j_view_handle, "column view cannot be null", 0);
+  JNI_NULL_CHECK(env, field_numbers, "field_numbers cannot be null", 0);
+  // Every per-field array is read below, and passing a null jobjectArray to
+  // GetObjectArrayElement is undefined behavior, so reject nulls up front.
+  JNI_NULL_CHECK(env, parent_indices, "parent_indices cannot be null", 0);
+  JNI_NULL_CHECK(env, depth_levels, "depth_levels cannot be null", 0);
+  JNI_NULL_CHECK(env, wire_types, "wire_types cannot be null", 0);
+  JNI_NULL_CHECK(env, output_type_ids, "output_type_ids cannot be null", 0);
+  JNI_NULL_CHECK(env, encodings, "encodings cannot be null", 0);
+  JNI_NULL_CHECK(env, is_repeated, "is_repeated cannot be null", 0);
+  JNI_NULL_CHECK(env, is_required, "is_required cannot be null", 0);
+  JNI_NULL_CHECK(env, has_default_value, "has_default_value cannot be null", 0);
+  JNI_NULL_CHECK(env, default_ints, "default_ints cannot be null", 0);
+  JNI_NULL_CHECK(env, default_floats, "default_floats cannot be null", 0);
+  JNI_NULL_CHECK(env, default_bools, "default_bools cannot be null", 0);
+  JNI_NULL_CHECK(env, default_strings, "default_strings cannot be null", 0);
+  JNI_NULL_CHECK(env, enum_valid_values, "enum_valid_values cannot be null", 0);
+  JNI_NULL_CHECK(env, enum_names, "enum_names cannot be null", 0);
+
+  try {
+    cudf::jni::auto_set_device(env);
+    auto const* input = reinterpret_cast<cudf::column_view const*>(j_view_handle);
+
+    cudf::jni::native_jintArray n_field_numbers(env, field_numbers);
+    cudf::jni::native_jintArray n_parent_indices(env, parent_indices);
+    cudf::jni::native_jintArray n_depth_levels(env, depth_levels);
+    cudf::jni::native_jintArray n_wire_types(env, wire_types);
+    cudf::jni::native_jintArray n_output_type_ids(env, output_type_ids);
+    cudf::jni::native_jintArray n_encodings(env, encodings);
+    cudf::jni::native_jbooleanArray n_is_repeated(env, is_repeated);
+    cudf::jni::native_jbooleanArray n_is_required(env, is_required);
+    cudf::jni::native_jbooleanArray n_has_default(env, has_default_value);
+    cudf::jni::native_jlongArray n_default_ints(env, default_ints);
+    cudf::jni::native_jdoubleArray n_default_floats(env, default_floats);
+    cudf::jni::native_jbooleanArray n_default_bools(env, default_bools);
+
+    int const num_fields = n_field_numbers.size();
+
+    // Build schema descriptors
+    std::vector<cudf::io::protobuf::nested_field_descriptor> schema;
+    schema.reserve(num_fields);
+    for (int i = 0; i < num_fields; ++i) {
+      schema.push_back({n_field_numbers[i],
+                        n_parent_indices[i],
+                        n_depth_levels[i],
+                        static_cast<cudf::io::protobuf::proto_wire_type>(n_wire_types[i]),
+                        static_cast<cudf::type_id>(n_output_type_ids[i]),
+                        static_cast<cudf::io::protobuf::proto_encoding>(n_encodings[i]),
+                        n_is_repeated[i] != 0,
+                        n_is_required[i] != 0,
+                        n_has_default[i] != 0});
+    }
+
+    // Convert default values
+    std::vector<int64_t> default_int_values(n_default_ints.begin(), n_default_ints.end());
+    std::vector<double> default_float_values(n_default_floats.begin(), n_default_floats.end());
+
+    std::vector<bool> default_bool_values;
+    default_bool_values.reserve(num_fields);
+    for (int i = 0; i < num_fields; ++i) {
+      default_bool_values.push_back(n_default_bools[i] != 0);
+    }
+
+    // Convert default strings (byte[][])
+    auto default_string_values = jni_object_array_to_vectors<cudf::detail::host_vector<uint8_t>>(
+      env, default_strings, num_fields, jni_byte_array_to_host_vector);
+    if (env->ExceptionCheck()) { return 0; }
+
+    // Convert enum valid values (int[][])
+    auto enum_values = jni_object_array_to_vectors<cudf::detail::host_vector<int32_t>>(
+      env, enum_valid_values, num_fields, jni_int_array_to_host_vector);
+    if (env->ExceptionCheck()) { return 0; }
+
+    // Convert enum names (byte[][][])
+    auto enum_name_values =
+      jni_object_array_to_vectors<std::vector<cudf::detail::host_vector<uint8_t>>>(
+        env,
+        enum_names,
+        num_fields,
+        [](JNIEnv* e, jobject obj) -> std::vector<cudf::detail::host_vector<uint8_t>> {
+          if (obj == nullptr) { return {}; }
+          auto inner_arr = static_cast<jobjectArray>(obj);
+          jsize num      = e->GetArrayLength(inner_arr);
+          return jni_object_array_to_vectors<cudf::detail::host_vector<uint8_t>>(
+            e, inner_arr, num, jni_byte_array_to_host_vector);
+        });
+    if (env->ExceptionCheck()) { return 0; }
+
+    cudf::io::protobuf::decode_protobuf_options options{std::move(schema),
+                                                        std::move(default_int_values),
+                                                        std::move(default_float_values),
+                                                        std::move(default_bool_values),
+                                                        std::move(default_string_values),
+                                                        std::move(enum_values),
+                                                        std::move(enum_name_values),
+                                                        static_cast<bool>(fail_on_errors)};
+
+    auto result = cudf::io::protobuf::decode_protobuf(
+      *input, options, cudf::get_default_stream(), cudf::get_current_device_resource_ref());
+
+    return cudf::jni::release_as_jlong(result);
+  }
+  CATCH_STD(env, 0);
+}
+
+}  // extern "C"
diff --git a/python/pylibcudf/pylibcudf/io/CMakeLists.txt b/python/pylibcudf/pylibcudf/io/CMakeLists.txt
index 089ea8d0e8d..3f0e62d0341 100644
--- a/python/pylibcudf/pylibcudf/io/CMakeLists.txt
+++ b/python/pylibcudf/pylibcudf/io/CMakeLists.txt
@@ -6,7 +6,7 @@
 # =============================================================================
 
 set(cython_sources avro.pyx csv.pyx datasource.pyx json.pyx orc.pyx parquet.pyx
-    parquet_metadata.pyx text.pyx timezone.pyx types.pyx
+    parquet_metadata.pyx protobuf.pyx text.pyx timezone.pyx types.pyx
 )
 
 set(linked_libraries cudf::cudf)
diff --git a/python/pylibcudf/pylibcudf/io/__init__.pxd b/python/pylibcudf/pylibcudf/io/__init__.pxd
index d8a3c42d4c1..22970861806 100644
--- a/python/pylibcudf/pylibcudf/io/__init__.pxd
+++ b/python/pylibcudf/pylibcudf/io/__init__.pxd
@@ -10,6 +10,7 @@ from . cimport (
     orc,
     parquet,
     parquet_metadata,
+    protobuf,
     text,
     timezone,
     types,
diff --git a/python/pylibcudf/pylibcudf/io/__init__.py b/python/pylibcudf/pylibcudf/io/__init__.py
index d5410d20482..bee62c65571 100644
--- a/python/pylibcudf/pylibcudf/io/__init__.py
+++ b/python/pylibcudf/pylibcudf/io/__init__.py
@@ -10,6 +10,7 @@
     orc,
     parquet,
     parquet_metadata,
+    protobuf,
     text,
     timezone,
     types,
@@ -28,6 +29,7 @@
     "orc",
     "parquet",
     "parquet_metadata",
+    "protobuf",
     "text",
     "timezone",
     "types",
diff --git a/python/pylibcudf/pylibcudf/io/protobuf.pxd b/python/pylibcudf/pylibcudf/io/protobuf.pxd
new file mode 100644
index 00000000000..43e1957591b
--- /dev/null
+++ b/python/pylibcudf/pylibcudf/io/protobuf.pxd
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+# SPDX-License-Identifier: Apache-2.0
+
+from pylibcudf.column cimport Column
+from rmm.pylibrmm.memory_resource cimport DeviceMemoryResource
+from rmm.pylibrmm.stream cimport Stream
+
+
+cpdef Column decode_protobuf(
+    Column binary_input,
+    list schema,
+    list default_ints,
+    list default_floats,
+    list default_bools,
+    list default_strings,
+    list enum_valid_values,
+    list enum_names,
+    bint fail_on_errors,
+    Stream stream = *,
+    DeviceMemoryResource mr = *,
+)
diff --git a/python/pylibcudf/pylibcudf/io/protobuf.pyx b/python/pylibcudf/pylibcudf/io/protobuf.pyx
new file mode 100644
index 00000000000..28b437c11e8
--- /dev/null
+++ b/python/pylibcudf/pylibcudf/io/protobuf.pyx
@@ -0,0 +1,141 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+# SPDX-License-Identifier: Apache-2.0
+
+from libc.stdint cimport int64_t
+from libcpp cimport bool
+from libcpp.memory cimport unique_ptr
+from libcpp.utility cimport move
+from libcpp.vector cimport vector
+
+from rmm.pylibrmm.memory_resource cimport DeviceMemoryResource
+from rmm.pylibrmm.stream cimport Stream
+
+from pylibcudf.column cimport Column
+from pylibcudf.utils cimport _get_memory_resource, _get_stream
+
+from pylibcudf.libcudf.column.column cimport column
+from pylibcudf.libcudf.io.protobuf cimport (
+    decode_protobuf as cpp_decode_protobuf,
+    decode_protobuf_options,
+    nested_field_descriptor,
+    proto_encoding,
+    proto_wire_type,
+)
+from pylibcudf.libcudf.types cimport type_id
+
+__all__ = [
+    "decode_protobuf",
+]
+
+
+cpdef Column decode_protobuf(
+    Column binary_input,
+    list schema,
+    list default_ints,
+    list default_floats,
+    list default_bools,
+    list default_strings,
+    list enum_valid_values,
+    list enum_names,
+    bint fail_on_errors,
+    Stream stream = None,
+    DeviceMemoryResource mr = None,
+):
+    """
+    Decode serialized protobuf messages from a LIST column into a STRUCT column.
+
+    Parameters
+    ----------
+    binary_input : Column
+        LIST column of serialized protobuf messages.
+    schema : list of tuples
+        Each tuple is (field_number, parent_idx, depth, wire_type, output_type_id,
+        encoding, is_repeated, is_required, has_default_value).
+    default_ints : list of int
+        Default integer values per field.
+    default_floats : list of float
+        Default float values per field.
+    default_bools : list of bool
+        Default boolean values per field.
+    default_strings : list of bytes
+        Default string values per field (as raw bytes).
+    enum_valid_values : list of list of int
+        Valid enum numbers per field.
+    enum_names : list of list of bytes
+        UTF-8 enum names per field.
+    fail_on_errors : bool
+        If True, raise on malformed messages. If False, return nulls.
+    stream : Stream, optional
+        CUDA stream for device operations.
+    mr : DeviceMemoryResource, optional
+        Device memory resource.
+
+    Returns
+    -------
+    Column
+        A STRUCT column containing decoded protobuf fields.
+
+    Notes
+    -----
+    ``default_strings``, ``enum_valid_values`` and ``enum_names`` are accepted
+    but are not forwarded to libcudf by this function.
+    """
+    cdef decode_protobuf_options options
+    cdef int n = len(schema)
+
+    # Build schema vector
+    cdef vector[nested_field_descriptor] c_schema
+    c_schema.reserve(n)
+    cdef nested_field_descriptor desc
+    # The loop variable must not be ``s``: ``s`` is declared as a Stream below,
+    # and a cdef local cannot also be used as an untyped loop target.
+    for field in schema:
+        desc.field_number = field[0]
+        desc.parent_idx = field[1]
+        desc.depth = field[2]
+        desc.wire_type = <proto_wire_type>(field[3])
+        desc.output_type = <type_id>(field[4])
+        desc.encoding = <proto_encoding>(field[5])
+        desc.is_repeated = field[6]
+        desc.is_required = field[7]
+        desc.has_default_value = field[8]
+        c_schema.push_back(desc)
+
+    options.schema = move(c_schema)
+
+    # Default values
+    cdef vector[int64_t] c_default_ints
+    c_default_ints.reserve(n)
+    for v in default_ints:
+        c_default_ints.push_back(v)
+    options.default_ints = move(c_default_ints)
+
+    cdef vector[double] c_default_floats
+    c_default_floats.reserve(n)
+    for v in default_floats:
+        c_default_floats.push_back(v)
+    options.default_floats = move(c_default_floats)
+
+    cdef vector[bool] c_default_bools
+    c_default_bools.reserve(n)
+    for v in default_bools:
+        c_default_bools.push_back(v)
+    options.default_bools = move(c_default_bools)
+
+    options.fail_on_errors = fail_on_errors
+
+    cdef Stream s = _get_stream(stream)
+    mr = _get_memory_resource(mr)
+
+    cdef unique_ptr[column] c_result
+    with nogil:
+        c_result = move(
+            cpp_decode_protobuf(
+                binary_input.view(),
+                options,
+                s.view(),
+                mr.get_mr(),
+            )
+        )
+
+    return Column.from_libcudf(move(c_result), s, mr)
diff --git a/python/pylibcudf/pylibcudf/libcudf/io/protobuf.pxd b/python/pylibcudf/pylibcudf/libcudf/io/protobuf.pxd
new file mode 100644
index 00000000000..064520a2c9d
--- /dev/null
+++ b/python/pylibcudf/pylibcudf/libcudf/io/protobuf.pxd
@@ -0,0 +1,58 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+# SPDX-License-Identifier: Apache-2.0
+
+from libc.stdint cimport int32_t, int64_t, uint8_t
+from libcpp cimport bool
+from libcpp.memory cimport unique_ptr
+from libcpp.vector cimport vector
+from pylibcudf.exception_handler cimport libcudf_exception_handler
+from pylibcudf.libcudf.column.column cimport column
+from pylibcudf.libcudf.column.column_view cimport column_view
+from pylibcudf.libcudf.types cimport type_id
+from rmm.librmm.cuda_stream_view cimport cuda_stream_view
+from rmm.librmm.device_buffer cimport device_buffer
+from rmm.librmm.memory_resource cimport device_memory_resource
+
+
+cdef extern from "cudf/io/protobuf.hpp" namespace "cudf::io::protobuf" nogil:
+
+    cpdef enum class proto_encoding(int):
+        DEFAULT
+        FIXED
+        ZIGZAG
+        ENUM_STRING
+
+    cpdef enum class proto_wire_type(int):
+        VARINT
+        I64BIT
+        LEN
+        SGROUP
+        EGROUP
+        I32BIT
+
+    cdef struct nested_field_descriptor:
+        int field_number
+        int parent_idx
+        int depth
+        proto_wire_type wire_type
+        type_id output_type
+        proto_encoding encoding
+        bool is_repeated
+        bool is_required
+        bool has_default_value
+
+    cdef struct decode_protobuf_options:
+        vector[nested_field_descriptor] schema
+        vector[int64_t] default_ints
+        vector[double] default_floats
+        vector[bool] default_bools
+        # Note: host_vector types are not easily bindable through Cython.
+        # The Python layer will need to handle conversion.
+        bool fail_on_errors
+
+    cdef unique_ptr[column] decode_protobuf(
+        column_view binary_input,
+        decode_protobuf_options options,
+        cuda_stream_view stream,
+        device_memory_resource* mr,
+    ) except +libcudf_exception_handler