diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4d83cbd907c..354560998c5 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -516,6 +516,7 @@ add_library( src/datetime/timezone.cpp src/io/orc/writer_impl.cu src/io/parquet/arrow_schema_writer.cpp + src/io/parquet/bloom_filter_reader.cu src/io/parquet/compact_protocol_reader.cpp src/io/parquet/compact_protocol_writer.cpp src/io/parquet/decode_preprocess.cu diff --git a/cpp/src/io/parquet/bloom_filter_reader.cu b/cpp/src/io/parquet/bloom_filter_reader.cu new file mode 100644 index 00000000000..8c404950efa --- /dev/null +++ b/cpp/src/io/parquet/bloom_filter_reader.cu @@ -0,0 +1,683 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "compact_protocol_reader.hpp" +#include "io/parquet/parquet.hpp" +#include "reader_impl_helpers.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace cudf::io::parquet::detail { +namespace { + +/** + * @brief Converts bloom filter membership results (for each column chunk) to a device column. 
+ * + */ +struct bloom_filter_caster { + cudf::device_span const> bloom_filter_spans; + host_span parquet_types; + size_t total_row_groups; + size_t num_equality_columns; + + enum class is_int96_timestamp : bool { YES, NO }; + + template + std::unique_ptr query_bloom_filter(cudf::size_type equality_col_idx, + cudf::data_type dtype, + ast::literal const* const literal, + rmm::cuda_stream_view stream) const + { + using key_type = T; + using policy_type = cuco::arrow_filter_policy; + using word_type = typename policy_type::word_type; + + // Boolean, List, Struct, Dictionary types are not supported + if constexpr (std::is_same_v or + (cudf::is_compound() and not std::is_same_v)) { + CUDF_FAIL("Bloom filters do not support boolean or compound types"); + } else { + // Check if the literal has the same type as the predicate column + CUDF_EXPECTS( + dtype == literal->get_data_type() and + cudf::have_same_types( + cudf::column_view{dtype, 0, {}, {}, 0, 0, {}}, + cudf::scalar_type_t(T{}, false, stream, cudf::get_current_device_resource_ref())), + "Mismatched predicate column and literal types"); + } + + // Filter properties + auto constexpr bytes_per_block = sizeof(word_type) * policy_type::words_per_block; + + rmm::device_buffer results{total_row_groups, stream, cudf::get_current_device_resource_ref()}; + cudf::device_span results_span{static_cast(results.data()), total_row_groups}; + + // Query literal in bloom filters from each column chunk (row group). 
+ thrust::tabulate( + rmm::exec_policy_nosync(stream), + results_span.begin(), + results_span.end(), + [filter_span = bloom_filter_spans.data(), + d_scalar = literal->get_value(), + col_idx = equality_col_idx, + num_equality_columns = num_equality_columns] __device__(auto row_group_idx) { + // Filter bitset buffer index + auto const filter_idx = col_idx + (num_equality_columns * row_group_idx); + auto const filter_size = filter_span[filter_idx].size(); + + // If no bloom filter, then fill in `true` as membership cannot be determined + if (filter_size == 0) { return true; } + + // Number of filter blocks + auto const num_filter_blocks = filter_size / bytes_per_block; + + // Create a bloom filter view. + cuco::bloom_filter_ref, + cuco::thread_scope_thread, + policy_type> + filter{reinterpret_cast(filter_span[filter_idx].data()), + num_filter_blocks, + {}, // Thread scope as the same literal is being searched across different bitsets + // per thread + {}}; // Arrow policy with cudf::hashing::detail::XXHash_64 seeded with 0 for Arrow + // compatibility + + // If int96_timestamp type, convert literal to string_view and query bloom + // filter + if constexpr (cuda::std::is_same_v and + IS_INT96_TIMESTAMP == is_int96_timestamp::YES) { + auto const int128_key = static_cast<__int128_t>(d_scalar.value()); + cudf::string_view probe_key{reinterpret_cast(&int128_key), 12}; + return filter.contains(probe_key); + } else { + // Query the bloom filter and store results + return filter.contains(d_scalar.value()); + } + }); + + return std::make_unique(cudf::data_type{cudf::type_id::BOOL8}, + static_cast(total_row_groups), + std::move(results), + rmm::device_buffer{}, + 0); + } + + // Creates device columns from bloom filter membership + template + std::unique_ptr operator()(cudf::size_type equality_col_idx, + cudf::data_type dtype, + ast::literal* const literal, + rmm::cuda_stream_view stream) const + { + // For INT96 timestamps, use cudf::string_view type and set is_int96_timestamp 
to YES + if constexpr (cudf::is_timestamp()) { + if (parquet_types[equality_col_idx] == Type::INT96) { + // For INT96 timestamps, use cudf::string_view type and set is_int96_timestamp to YES + return query_bloom_filter( + equality_col_idx, dtype, literal, stream); + } + } + + // For all other cases + return query_bloom_filter(equality_col_idx, dtype, literal, stream); + } +}; + +/** + * @brief Collects lists of equality predicate literals in the AST expression, one list per input + * table column. This is used in row group filtering based on bloom filters. + */ +class equality_literals_collector : public ast::detail::expression_transformer { + public: + equality_literals_collector() = default; + + equality_literals_collector(ast::expression const& expr, cudf::size_type num_input_columns) + : _num_input_columns{num_input_columns} + { + _equality_literals.resize(_num_input_columns); + expr.accept(*this); + } + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::literal const& ) + */ + std::reference_wrapper visit(ast::literal const& expr) override + { + return expr; + } + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::column_reference const& ) + */ + std::reference_wrapper visit(ast::column_reference const& expr) override + { + CUDF_EXPECTS(expr.get_table_source() == ast::table_reference::LEFT, + "BloomfilterAST supports only left table"); + CUDF_EXPECTS(expr.get_column_index() < _num_input_columns, + "Column index cannot be more than number of columns in the table"); + return expr; + } + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::column_name_reference const& ) + */ + std::reference_wrapper visit( + ast::column_name_reference const& expr) override + { + CUDF_FAIL("Column name reference is not supported in BloomfilterAST"); + } + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::operation const& ) + */ + std::reference_wrapper visit(ast::operation const& expr) override + { + using 
cudf::ast::ast_operator; + auto const operands = expr.get_operands(); + auto const op = expr.get_operator(); + + if (auto* v = dynamic_cast(&operands[0].get())) { + // First operand should be column reference, second should be literal. + CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2, + "Only binary operations are supported on column reference"); + auto const literal_ptr = dynamic_cast(&operands[1].get()); + CUDF_EXPECTS(literal_ptr != nullptr, + "Second operand of binary operation with column reference must be a literal"); + v->accept(*this); + + // Push to the corresponding column's literals list iff equality predicate is seen + if (op == ast_operator::EQUAL) { + auto const col_idx = v->get_column_index(); + _equality_literals[col_idx].emplace_back(const_cast(literal_ptr)); + } + } else { + // Just visit the operands and ignore any output + std::ignore = visit_operands(operands); + } + + return expr; + } + + /** + * @brief Vectors of equality literals in the AST expression, one per input table column + * + * @return Vectors of equality literals, one per input table column + */ + [[nodiscard]] std::vector> get_equality_literals() && + { + return std::move(_equality_literals); + } + + private: + std::vector> _equality_literals; + + protected: + std::vector> visit_operands( + cudf::host_span const> operands) + { + std::vector> transformed_operands; + for (auto const& operand : operands) { + auto const new_operand = operand.get().accept(*this); + transformed_operands.push_back(new_operand); + } + return transformed_operands; + } + size_type _num_input_columns; +}; + +/** + * @brief Converts AST expression to bloom filter membership (BloomfilterAST) expression. + * This is used in row group filtering based on equality predicate. 
+ */ +class bloom_filter_expression_converter : public equality_literals_collector { + public: + bloom_filter_expression_converter( + ast::expression const& expr, + size_type num_input_columns, + cudf::host_span const> equality_literals) + : _equality_literals{equality_literals} + { + // Set the num columns + _num_input_columns = num_input_columns; + + // Compute and store columns literals offsets + _col_literals_offsets.reserve(_num_input_columns + 1); + _col_literals_offsets.emplace_back(0); + + std::transform(equality_literals.begin(), + equality_literals.end(), + std::back_inserter(_col_literals_offsets), + [&](auto const& col_literal_map) { + return _col_literals_offsets.back() + + static_cast(col_literal_map.size()); + }); + + // Add this visitor + expr.accept(*this); + } + + /** + * @brief Delete equality literals getter as it's not needed in the derived class + */ + [[nodiscard]] std::vector> get_equality_literals() && = delete; + + // Bring all overloads of `visit` from equality_predicate_collector into scope + using equality_literals_collector::visit; + + /** + * @copydoc ast::detail::expression_transformer::visit(ast::operation const& ) + */ + std::reference_wrapper visit(ast::operation const& expr) override + { + using cudf::ast::ast_operator; + auto const operands = expr.get_operands(); + auto const op = expr.get_operator(); + + if (auto* v = dynamic_cast(&operands[0].get())) { + // First operand should be column reference, second should be literal. + CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2, + "Only binary operations are supported on column reference"); + CUDF_EXPECTS(dynamic_cast(&operands[1].get()) != nullptr, + "Second operand of binary operation with column reference must be a literal"); + v->accept(*this); + + if (op == ast_operator::EQUAL) { + // Search the literal in this input column's equality literals list and add to the offset. 
+ auto const col_idx = v->get_column_index(); + auto const& equality_literals = _equality_literals[col_idx]; + auto col_literal_offset = _col_literals_offsets[col_idx]; + auto const literal_iter = std::find(equality_literals.cbegin(), + equality_literals.cend(), + dynamic_cast(&operands[1].get())); + CUDF_EXPECTS(literal_iter != equality_literals.end(), "Could not find the literal ptr"); + col_literal_offset += std::distance(equality_literals.cbegin(), literal_iter); + + // Evaluate boolean is_true(value) expression as NOT(NOT(value)) + auto const& value = _bloom_filter_expr.push(ast::column_reference{col_literal_offset}); + _bloom_filter_expr.push(ast::operation{ + ast_operator::NOT, _bloom_filter_expr.push(ast::operation{ast_operator::NOT, value})}); + } + // For all other expressions, push an always true expression + else { + _bloom_filter_expr.push( + ast::operation{ast_operator::NOT, + _bloom_filter_expr.push(ast::operation{ast_operator::NOT, _always_true})}); + } + } else { + auto new_operands = visit_operands(operands); + if (cudf::ast::detail::ast_operator_arity(op) == 2) { + _bloom_filter_expr.push(ast::operation{op, new_operands.front(), new_operands.back()}); + } else if (cudf::ast::detail::ast_operator_arity(op) == 1) { + _bloom_filter_expr.push(ast::operation{op, new_operands.front()}); + } + } + return _bloom_filter_expr.back(); + } + + /** + * @brief Returns the AST to apply on bloom filter membership. + * + * @return AST operation expression + */ + [[nodiscard]] std::reference_wrapper get_bloom_filter_expr() const + { + return _bloom_filter_expr.back(); + } + + private: + std::vector _col_literals_offsets; + cudf::host_span const> _equality_literals; + ast::tree _bloom_filter_expr; + cudf::numeric_scalar _always_true_scalar{true}; + ast::literal const _always_true{_always_true_scalar}; +}; + +/** + * @brief Reads bloom filter data to device. 
+ * + * @param sources Dataset sources + * @param num_chunks Number of total column chunks to read + * @param bloom_filter_data Device buffers to hold bloom filter bitsets for each chunk + * @param bloom_filter_offsets Bloom filter offsets for all chunks + * @param bloom_filter_sizes Bloom filter sizes for all chunks + * @param chunk_source_map Association between each column chunk and its source + * @param stream CUDA stream used for device memory operations and kernel launches + */ +void read_bloom_filter_data(host_span const> sources, + size_t num_chunks, + cudf::host_span bloom_filter_data, + cudf::host_span> bloom_filter_offsets, + cudf::host_span> bloom_filter_sizes, + std::vector const& chunk_source_map, + rmm::cuda_stream_view stream) +{ + // Read tasks for bloom filter data + std::vector> read_tasks; + + // Read bloom filters for all column chunks + std::for_each( + thrust::counting_iterator(0), + thrust::counting_iterator(num_chunks), + [&](auto const chunk) { + // If bloom filter offset absent, fill in an empty buffer and skip ahead + if (not bloom_filter_offsets[chunk].has_value()) { + bloom_filter_data[chunk] = {}; + return; + } + // Read bloom filter iff present + auto const bloom_filter_offset = bloom_filter_offsets[chunk].value(); + + // If Bloom filter size (header + bitset) is available, just read the entire thing. + // Else just read 256 bytes which will contain the entire header and may contain the + // entire bitset as well. + auto constexpr bloom_filter_size_guess = 256; + auto const initial_read_size = + static_cast(bloom_filter_sizes[chunk].value_or(bloom_filter_size_guess)); + + // Read an initial buffer from source + auto& source = sources[chunk_source_map[chunk]]; + auto buffer = source->host_read(bloom_filter_offset, initial_read_size); + + // Deserialize the Bloom filter header from the buffer. 
+ BloomFilterHeader header; + CompactProtocolReader cp{buffer->data(), buffer->size()}; + cp.read(&header); + + // Get the hardcoded words_per_block value from `cuco::arrow_filter_policy` using a temporary + // `std::byte` key type. + auto constexpr words_per_block = + cuco::arrow_filter_policy::words_per_block; + + // Check if the bloom filter header is valid. + auto const is_header_valid = + (header.num_bytes % words_per_block) == 0 and + header.compression.compression == BloomFilterCompression::Compression::UNCOMPRESSED and + header.algorithm.algorithm == BloomFilterAlgorithm::Algorithm::SPLIT_BLOCK and + header.hash.hash == BloomFilterHash::Hash::XXHASH; + + // Do not read if the bloom filter is invalid + if (not is_header_valid) { + bloom_filter_data[chunk] = {}; + CUDF_LOG_WARN("Encountered an invalid bloom filter header. Skipping"); + return; + } + + // Bloom filter header size + auto const bloom_filter_header_size = static_cast(cp.bytecount()); + auto const bitset_size = static_cast(header.num_bytes); + + // Check if we already read in the filter bitset in the initial read. + if (initial_read_size >= bloom_filter_header_size + bitset_size) { + bloom_filter_data[chunk] = + rmm::device_buffer{buffer->data() + bloom_filter_header_size, bitset_size, stream}; + } + // Read the bitset from datasource. 
+ else { + auto const bitset_offset = bloom_filter_offset + bloom_filter_header_size; + // Directly read to device if preferred + if (source->is_device_read_preferred(bitset_size)) { + bloom_filter_data[chunk] = rmm::device_buffer{bitset_size, stream}; + auto future_read_size = + source->device_read_async(bitset_offset, + bitset_size, + static_cast(bloom_filter_data[chunk].data()), + stream); + + read_tasks.emplace_back(std::move(future_read_size)); + } else { + buffer = source->host_read(bitset_offset, bitset_size); + bloom_filter_data[chunk] = rmm::device_buffer{buffer->data(), buffer->size(), stream}; + } + } + }); + + // Read task sync function + for (auto& task : read_tasks) { + task.wait(); + } +} + +} // namespace + +std::vector aggregate_reader_metadata::read_bloom_filters( + host_span const> sources, + host_span const> row_group_indices, + host_span column_schemas, + size_type total_row_groups, + rmm::cuda_stream_view stream) const +{ + // Descriptors for all the chunks that make up the selected columns + auto const num_input_columns = column_schemas.size(); + auto const num_chunks = total_row_groups * num_input_columns; + + // Association between each column chunk and its source + std::vector chunk_source_map(num_chunks); + + // Keep track of column chunk file offsets + std::vector> bloom_filter_offsets(num_chunks); + std::vector> bloom_filter_sizes(num_chunks); + + // Gather all bloom filter offsets and sizes. 
+ size_type chunk_count = 0; + + // Flag to check if we have at least one valid bloom filter offset + auto have_bloom_filters = false; + + // For all data sources + std::for_each(thrust::counting_iterator(0), + thrust::counting_iterator(row_group_indices.size()), + [&](auto const src_index) { + // Get all row group indices in the data source + auto const& rg_indices = row_group_indices[src_index]; + // For all row groups + std::for_each(rg_indices.cbegin(), rg_indices.cend(), [&](auto const rg_index) { + // For all column chunks + std::for_each( + column_schemas.begin(), column_schemas.end(), [&](auto const schema_idx) { + auto& col_meta = get_column_metadata(rg_index, src_index, schema_idx); + + // Get bloom filter offsets and sizes + bloom_filter_offsets[chunk_count] = col_meta.bloom_filter_offset; + bloom_filter_sizes[chunk_count] = col_meta.bloom_filter_length; + + // Set `have_bloom_filters` if `bloom_filter_offset` is valid + if (col_meta.bloom_filter_offset.has_value()) { have_bloom_filters = true; } + + // Map each column chunk to its source index + chunk_source_map[chunk_count] = src_index; + chunk_count++; + }); + }); + }); + + // Exit early if we don't have any bloom filters + if (not have_bloom_filters) { return {}; } + + // Vector to hold bloom filter data + std::vector bloom_filter_data(num_chunks); + + // Read bloom filter data + read_bloom_filter_data(sources, + num_chunks, + bloom_filter_data, + bloom_filter_offsets, + bloom_filter_sizes, + chunk_source_map, + stream); + + // Return bloom filter data + return bloom_filter_data; +} + +std::vector aggregate_reader_metadata::get_parquet_types( + host_span const> row_group_indices, + host_span column_schemas) const +{ + std::vector parquet_types(column_schemas.size()); + // Find a source with at least one row group + auto const src_iter = std::find_if(row_group_indices.begin(), + row_group_indices.end(), + [](auto const& rg) { return rg.size() > 0; }); + CUDF_EXPECTS(src_iter != 
row_group_indices.end(), ""); + + // Source index + auto const src_index = std::distance(row_group_indices.begin(), src_iter); + std::transform(column_schemas.begin(), + column_schemas.end(), + parquet_types.begin(), + [&](auto const schema_idx) { + // Use the first row group in this source + auto constexpr row_group_index = 0; + return get_column_metadata(row_group_index, src_index, schema_idx).type; + }); + + return parquet_types; +} + +std::optional>> aggregate_reader_metadata::apply_bloom_filters( + host_span const> sources, + host_span const> input_row_group_indices, + host_span output_dtypes, + host_span output_column_schemas, + std::reference_wrapper filter, + rmm::cuda_stream_view stream) const +{ + // Number of input table columns + auto const num_input_columns = static_cast(output_dtypes.size()); + + // Total number of row groups after StatsAST filtration + auto const total_row_groups = std::accumulate( + input_row_group_indices.begin(), + input_row_group_indices.end(), + size_t{0}, + [](size_t sum, auto const& per_file_row_groups) { return sum + per_file_row_groups.size(); }); + + // Check if we have less than 2B total row groups. 
+ CUDF_EXPECTS(total_row_groups <= std::numeric_limits::max(), + "Total number of row groups exceed the size_type's limit"); + + // Collect equality literals for each input table column + auto const equality_literals = + equality_literals_collector{filter.get(), num_input_columns}.get_equality_literals(); + + // Collect schema indices of columns with equality predicate(s) + std::vector equality_col_schemas; + thrust::copy_if(thrust::host, + output_column_schemas.begin(), + output_column_schemas.end(), + equality_literals.begin(), + std::back_inserter(equality_col_schemas), + [](auto& eq_literals) { return not eq_literals.empty(); }); + + // Return early if no column with equality predicate(s) + if (equality_col_schemas.empty()) { return std::nullopt; } + + // Read a vector of bloom filter bitset device buffers for all columns with equality + // predicate(s) across all row groups + auto bloom_filter_data = read_bloom_filters( + sources, input_row_group_indices, equality_col_schemas, total_row_groups, stream); + + // No bloom filter buffers, return the original row group indices + if (bloom_filter_data.empty()) { return std::nullopt; } + + // Get parquet types for the predicate columns + auto const parquet_types = get_parquet_types(input_row_group_indices, equality_col_schemas); + + // Create spans from bloom filter bitset buffers to use in cuco::bloom_filter_ref. 
+ std::vector> h_bloom_filter_spans; + h_bloom_filter_spans.reserve(bloom_filter_data.size()); + std::transform(bloom_filter_data.begin(), + bloom_filter_data.end(), + std::back_inserter(h_bloom_filter_spans), + [&](auto& buffer) { + return cudf::device_span{ + static_cast(buffer.data()), buffer.size()}; + }); + + // Copy bloom filter bitset spans to device + auto const bloom_filter_spans = cudf::detail::make_device_uvector_async( + h_bloom_filter_spans, stream, cudf::get_current_device_resource_ref()); + + // Create a bloom filter query table caster + bloom_filter_caster const bloom_filter_col{ + bloom_filter_spans, parquet_types, total_row_groups, equality_col_schemas.size()}; + + // Converts bloom filter membership for equality predicate columns to a table + // containing a column for each `col[i] == literal` predicate to be evaluated. + // The table contains #sources * #column_chunks_per_src rows. + std::vector> bloom_filter_membership_columns; + size_t equality_col_idx = 0; + std::for_each( + thrust::counting_iterator(0), + thrust::counting_iterator(output_dtypes.size()), + [&](auto input_col_idx) { + auto const& dtype = output_dtypes[input_col_idx]; + + // Skip if no equality literals for this column + if (equality_literals[input_col_idx].empty()) { return; } + + // Skip if non-comparable (compound) type except string + if (cudf::is_compound(dtype) and dtype.id() != cudf::type_id::STRING) { return; } + + // Add a column for all literals associated with an equality column + for (auto const& literal : equality_literals[input_col_idx]) { + bloom_filter_membership_columns.emplace_back(cudf::type_dispatcher( + dtype, bloom_filter_col, equality_col_idx, dtype, literal, stream)); + } + equality_col_idx++; + }); + + // Create a table from columns + auto bloom_filter_membership_table = cudf::table(std::move(bloom_filter_membership_columns)); + + // Convert AST to BloomfilterAST expression with reference to bloom filter membership + // in above 
`bloom_filter_membership_table` + bloom_filter_expression_converter bloom_filter_expr{ + filter.get(), num_input_columns, {equality_literals}}; + + // Filter bloom filter membership table with the BloomfilterAST expression and collect + // filtered row group indices + return collect_filtered_row_group_indices(bloom_filter_membership_table, + bloom_filter_expr.get_bloom_filter_expr(), + input_row_group_indices, + stream); +} + +} // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/compact_protocol_reader.cpp b/cpp/src/io/parquet/compact_protocol_reader.cpp index f1ecf66c29f..b8e72aaac88 100644 --- a/cpp/src/io/parquet/compact_protocol_reader.cpp +++ b/cpp/src/io/parquet/compact_protocol_reader.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * Copyright (c) 2018-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -658,6 +658,33 @@ void CompactProtocolReader::read(ColumnChunk* c) function_builder(this, op); } +void CompactProtocolReader::read(BloomFilterAlgorithm* alg) +{ + auto op = std::make_tuple(parquet_field_union_enumerator(1, alg->algorithm)); + function_builder(this, op); +} + +void CompactProtocolReader::read(BloomFilterHash* hash) +{ + auto op = std::make_tuple(parquet_field_union_enumerator(1, hash->hash)); + function_builder(this, op); +} + +void CompactProtocolReader::read(BloomFilterCompression* comp) +{ + auto op = std::make_tuple(parquet_field_union_enumerator(1, comp->compression)); + function_builder(this, op); +} + +void CompactProtocolReader::read(BloomFilterHeader* bf) +{ + auto op = std::make_tuple(parquet_field_int32(1, bf->num_bytes), + parquet_field_struct(2, bf->algorithm), + parquet_field_struct(3, bf->hash), + parquet_field_struct(4, bf->compression)); + function_builder(this, op); +} + void CompactProtocolReader::read(ColumnChunkMetaData* c) { using optional_size_statistics = @@ -665,7 
+692,9 @@ void CompactProtocolReader::read(ColumnChunkMetaData* c) using optional_list_enc_stats = parquet_field_optional, parquet_field_struct_list>; - auto op = std::make_tuple(parquet_field_enum(1, c->type), + using optional_i64 = parquet_field_optional; + using optional_i32 = parquet_field_optional; + auto op = std::make_tuple(parquet_field_enum(1, c->type), parquet_field_enum_list(2, c->encodings), parquet_field_string_list(3, c->path_in_schema), parquet_field_enum(4, c->codec), @@ -677,6 +706,8 @@ void CompactProtocolReader::read(ColumnChunkMetaData* c) parquet_field_int64(11, c->dictionary_page_offset), parquet_field_struct(12, c->statistics), optional_list_enc_stats(13, c->encoding_stats), + optional_i64(14, c->bloom_filter_offset), + optional_i32(15, c->bloom_filter_length), optional_size_statistics(16, c->size_statistics)); function_builder(this, op); } diff --git a/cpp/src/io/parquet/compact_protocol_reader.hpp b/cpp/src/io/parquet/compact_protocol_reader.hpp index b87f2e9c692..360197b19ad 100644 --- a/cpp/src/io/parquet/compact_protocol_reader.hpp +++ b/cpp/src/io/parquet/compact_protocol_reader.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * Copyright (c) 2018-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -108,6 +108,10 @@ class CompactProtocolReader { void read(IntType* t); void read(RowGroup* r); void read(ColumnChunk* c); + void read(BloomFilterAlgorithm* bf); + void read(BloomFilterHash* bf); + void read(BloomFilterCompression* bf); + void read(BloomFilterHeader* bf); void read(ColumnChunkMetaData* c); void read(PageHeader* p); void read(DataPageHeader* d); diff --git a/cpp/src/io/parquet/parquet.hpp b/cpp/src/io/parquet/parquet.hpp index 2851ef67a65..dc0c4b1540e 100644 --- a/cpp/src/io/parquet/parquet.hpp +++ b/cpp/src/io/parquet/parquet.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * Copyright (c) 2018-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -382,12 +382,62 @@ struct ColumnChunkMetaData { // Set of all encodings used for pages in this column chunk. This information can be used to // determine if all data pages are dictionary encoded for example. std::optional> encoding_stats; + // Byte offset from beginning of file to Bloom filter data. + std::optional bloom_filter_offset; + // Size of Bloom filter data including the serialized header, in bytes. Added in 2.10 so readers + // may not read this field from old files and it can be obtained after the BloomFilterHeader has + // been deserialized. Writers should write this field so readers can read the bloom filter in a + // single I/O. + std::optional bloom_filter_length; // Optional statistics to help estimate total memory when converted to in-memory representations. // The histograms contained in these statistics can also be useful in some cases for more // fine-grained nullability/list length filter pushdown. std::optional size_statistics; }; +/** + * @brief The algorithm used in bloom filter + */ +struct BloomFilterAlgorithm { + // Block-based Bloom filter. 
+ enum class Algorithm { UNDEFINED, SPLIT_BLOCK }; + Algorithm algorithm{Algorithm::SPLIT_BLOCK}; +}; + +/** + * @brief The hash function used in Bloom filter + */ +struct BloomFilterHash { + // xxHash_64 + enum class Hash { UNDEFINED, XXHASH }; + Hash hash{Hash::XXHASH}; +}; + +/** + * @brief The compression used in the bloom filter + */ +struct BloomFilterCompression { + enum class Compression { UNDEFINED, UNCOMPRESSED }; + Compression compression{Compression::UNCOMPRESSED}; +}; + +/** + * @brief Bloom filter header struct + * + * The bloom filter data of a column chunk stores this header at the beginning + * following by the filter bitset. + */ +struct BloomFilterHeader { + // The size of bitset in bytes + int32_t num_bytes; + // The algorithm for setting bits + BloomFilterAlgorithm algorithm; + // The hash function used for bloom filter + BloomFilterHash hash; + // The compression used in the bloom filter + BloomFilterCompression compression; +}; + /** * @brief Thrift-derived struct describing a chunk of data for a particular * column diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index 9047ff9169b..0e307bac097 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2023-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -32,6 +32,7 @@ #include #include +#include #include #include #include @@ -388,6 +389,7 @@ class stats_expression_converter : public ast::detail::expression_transformer { } // namespace std::optional>> aggregate_reader_metadata::filter_row_groups( + host_span const> sources, host_span const> row_group_indices, host_span output_dtypes, host_span output_column_schemas, @@ -396,7 +398,6 @@ std::optional>> aggregate_reader_metadata::fi { auto mr = cudf::get_current_device_resource_ref(); // Create row group indices. - std::vector> filtered_row_group_indices; std::vector> all_row_group_indices; host_span const> input_row_group_indices; if (row_group_indices.empty()) { @@ -412,18 +413,22 @@ std::optional>> aggregate_reader_metadata::fi } else { input_row_group_indices = row_group_indices; } - auto const total_row_groups = std::accumulate(input_row_group_indices.begin(), - input_row_group_indices.end(), - 0, - [](size_type sum, auto const& per_file_row_groups) { - return sum + per_file_row_groups.size(); - }); + auto const total_row_groups = std::accumulate( + input_row_group_indices.begin(), + input_row_group_indices.end(), + size_t{0}, + [](size_t sum, auto const& per_file_row_groups) { return sum + per_file_row_groups.size(); }); + + // Check if we have less than 2B total row groups. + CUDF_EXPECTS(total_row_groups <= std::numeric_limits::max(), + "Total number of row groups exceed the size_type's limit"); // Converts Column chunk statistics to a table // where min(col[i]) = columns[i*2], max(col[i])=columns[i*2+1] // For each column, it contains #sources * #column_chunks_per_src rows. 
std::vector> columns; - stats_caster const stats_col{total_row_groups, per_file_metadata, input_row_group_indices}; + stats_caster const stats_col{ + static_cast(total_row_groups), per_file_metadata, input_row_group_indices}; for (size_t col_idx = 0; col_idx < output_dtypes.size(); col_idx++) { auto const schema_idx = output_column_schemas[col_idx]; auto const& dtype = output_dtypes[col_idx]; @@ -452,44 +457,23 @@ std::optional>> aggregate_reader_metadata::fi CUDF_EXPECTS(predicate.type().id() == cudf::type_id::BOOL8, "Filter expression must return a boolean column"); - auto const host_bitmask = [&] { - auto const num_bitmasks = num_bitmask_words(predicate.size()); - if (predicate.nullable()) { - return cudf::detail::make_host_vector_sync( - device_span(predicate.null_mask(), num_bitmasks), stream); - } else { - auto bitmask = cudf::detail::make_host_vector(num_bitmasks, stream); - std::fill(bitmask.begin(), bitmask.end(), ~bitmask_type{0}); - return bitmask; - } - }(); + // Filter stats table with StatsAST expression and collect filtered row group indices + auto const filtered_row_group_indices = collect_filtered_row_group_indices( + stats_table, stats_expr.get_stats_expr(), input_row_group_indices, stream); - auto validity_it = cudf::detail::make_counting_transform_iterator( - 0, [bitmask = host_bitmask.data()](auto bit_index) { return bit_is_set(bitmask, bit_index); }); + // Span of row groups to apply bloom filtering on. + auto const bloom_filter_input_row_groups = + filtered_row_group_indices.has_value() + ? 
host_span const>(filtered_row_group_indices.value()) + : input_row_group_indices; - auto const is_row_group_required = cudf::detail::make_host_vector_sync( - device_span(predicate.data(), predicate.size()), stream); + // Apply bloom filtering on the bloom filter input row groups + auto const bloom_filtered_row_groups = apply_bloom_filters( + sources, bloom_filter_input_row_groups, output_dtypes, output_column_schemas, filter, stream); - // Return only filtered row groups based on predicate - // if all are required or all are nulls, return. - if (std::all_of(is_row_group_required.cbegin(), - is_row_group_required.cend(), - [](auto i) { return bool(i); }) or - predicate.null_count() == predicate.size()) { - return std::nullopt; - } - size_type is_required_idx = 0; - for (auto const& input_row_group_index : input_row_group_indices) { - std::vector filtered_row_groups; - for (auto const rg_idx : input_row_group_index) { - if ((!validity_it[is_required_idx]) || is_row_group_required[is_required_idx]) { - filtered_row_groups.push_back(rg_idx); - } - ++is_required_idx; - } - filtered_row_group_indices.push_back(std::move(filtered_row_groups)); - } - return {std::move(filtered_row_group_indices)}; + // Return bloom filtered row group indices iff collected + return bloom_filtered_row_groups.has_value() ? 
bloom_filtered_row_groups + : filtered_row_group_indices; } // convert column named expression to column index reference expression @@ -510,14 +494,14 @@ named_to_reference_converter::named_to_reference_converter( std::reference_wrapper named_to_reference_converter::visit( ast::literal const& expr) { - _stats_expr = std::reference_wrapper(expr); + _converted_expr = std::reference_wrapper(expr); return expr; } std::reference_wrapper named_to_reference_converter::visit( ast::column_reference const& expr) { - _stats_expr = std::reference_wrapper(expr); + _converted_expr = std::reference_wrapper(expr); return expr; } @@ -531,7 +515,7 @@ std::reference_wrapper named_to_reference_converter::visi } auto col_index = col_index_it->second; _col_ref.emplace_back(col_index); - _stats_expr = std::reference_wrapper(_col_ref.back()); + _converted_expr = std::reference_wrapper(_col_ref.back()); return std::reference_wrapper(_col_ref.back()); } @@ -546,7 +530,7 @@ std::reference_wrapper named_to_reference_converter::visi } else if (cudf::ast::detail::ast_operator_arity(op) == 1) { _operators.emplace_back(op, new_operands.front()); } - _stats_expr = std::reference_wrapper(_operators.back()); + _converted_expr = std::reference_wrapper(_operators.back()); return std::reference_wrapper(_operators.back()); } @@ -640,4 +624,60 @@ class names_from_expression : public ast::detail::expression_transformer { return names_from_expression(expr, skip_names).to_vector(); } +std::optional>> collect_filtered_row_group_indices( + cudf::table_view table, + std::reference_wrapper ast_expr, + host_span const> input_row_group_indices, + rmm::cuda_stream_view stream) +{ + // Filter the input table using AST expression + auto predicate_col = cudf::detail::compute_column( + table, ast_expr.get(), stream, cudf::get_current_device_resource_ref()); + auto predicate = predicate_col->view(); + CUDF_EXPECTS(predicate.type().id() == cudf::type_id::BOOL8, + "Filter expression must return a boolean column"); + + 
auto const host_bitmask = [&] { + auto const num_bitmasks = num_bitmask_words(predicate.size()); + if (predicate.nullable()) { + return cudf::detail::make_host_vector_sync( + device_span(predicate.null_mask(), num_bitmasks), stream); + } else { + auto bitmask = cudf::detail::make_host_vector(num_bitmasks, stream); + std::fill(bitmask.begin(), bitmask.end(), ~bitmask_type{0}); + return bitmask; + } + }(); + + auto validity_it = cudf::detail::make_counting_transform_iterator( + 0, [bitmask = host_bitmask.data()](auto bit_index) { return bit_is_set(bitmask, bit_index); }); + + // Return only filtered row groups based on predicate + auto const is_row_group_required = cudf::detail::make_host_vector_sync( + device_span(predicate.data(), predicate.size()), stream); + + // Return if all are required, or all are nulls. + if (predicate.null_count() == predicate.size() or std::all_of(is_row_group_required.cbegin(), + is_row_group_required.cend(), + [](auto i) { return bool(i); })) { + return std::nullopt; + } + + // Collect indices of the filtered row groups + size_type is_required_idx = 0; + std::vector> filtered_row_group_indices; + for (auto const& input_row_group_index : input_row_group_indices) { + std::vector filtered_row_groups; + for (auto const rg_idx : input_row_group_index) { + if ((!validity_it[is_required_idx]) || is_row_group_required[is_required_idx]) { + filtered_row_groups.push_back(rg_idx); + } + ++is_required_idx; + } + filtered_row_group_indices.push_back(std::move(filtered_row_groups)); + } + + return {filtered_row_group_indices}; +} + } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl_helpers.cpp b/cpp/src/io/parquet/reader_impl_helpers.cpp index 0dd1aff41e9..25baa1e0ec8 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.cpp +++ b/cpp/src/io/parquet/reader_impl_helpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2022-2025, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1030,6 +1030,7 @@ std::vector aggregate_reader_metadata::get_pandas_index_names() con std::tuple, std::vector> aggregate_reader_metadata::select_row_groups( + host_span const> sources, host_span const> row_group_indices, int64_t skip_rows_opt, std::optional const& num_rows_opt, @@ -1042,7 +1043,7 @@ aggregate_reader_metadata::select_row_groups( // if filter is not empty, then gather row groups to read after predicate pushdown if (filter.has_value()) { filtered_row_group_indices = filter_row_groups( - row_group_indices, output_dtypes, output_column_schemas, filter.value(), stream); + sources, row_group_indices, output_dtypes, output_column_schemas, filter.value(), stream); if (filtered_row_group_indices.has_value()) { row_group_indices = host_span const>(filtered_row_group_indices.value()); diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index fd692c0cdd6..a28ce616e2c 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2022-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -195,6 +195,38 @@ class aggregate_reader_metadata { */ void column_info_for_row_group(row_group_info& rg_info, size_type chunk_start_row) const; + /** + * @brief Reads bloom filter bitsets for the specified columns from the given lists of row + * groups. 
+ * + * @param sources Dataset sources + * @param row_group_indices Lists of row groups to read bloom filters from, one per source + * @param num_row_groups Total number of input row groups + * @param column_schemas Schema indices of columns whose bloom filters will be read + * @param stream CUDA stream used for device memory operations and kernel launches + * + * @return A flattened list of bloom filter bitset device buffers for each predicate column across + * row groups + */ + [[nodiscard]] std::vector read_bloom_filters( + host_span const> sources, + host_span const> row_group_indices, + host_span column_schemas, + size_type num_row_groups, + rmm::cuda_stream_view stream) const; + + /** + * @brief Collects Parquet types for the columns with the specified schema indices + * + * @param row_group_indices Lists of row groups, one per source + * @param column_schemas Schema indices of columns whose types will be collected + * + * @return A list of parquet types for the columns matching the provided schema indices + */ + [[nodiscard]] std::vector get_parquet_types( + host_span const> row_group_indices, + host_span column_schemas) const; + public: aggregate_reader_metadata(host_span const> sources, bool use_arrow_schema, @@ -323,26 +355,49 @@ class aggregate_reader_metadata { /** * @brief Filters the row groups based on predicate filter * + * @param sources Lists of input datasources * @param row_group_indices Lists of row groups to read, one per source - * @param output_dtypes Datatypes of of output columns + * @param output_dtypes Datatypes of output columns * @param output_column_schemas schema indices of output columns * @param filter AST expression to filter row groups based on Column chunk statistics * @param stream CUDA stream used for device memory operations and kernel launches - * @return Filtered row group indices, if any is filtered.
+ * @return Filtered row group indices, if any is filtered */ [[nodiscard]] std::optional>> filter_row_groups( + host_span const> sources, host_span const> row_group_indices, host_span output_dtypes, host_span output_column_schemas, std::reference_wrapper filter, rmm::cuda_stream_view stream) const; + /** + * @brief Filters the row groups using bloom filters + * + * @param sources Dataset sources + * @param row_group_indices Lists of input row groups to read, one per source + * @param output_dtypes Datatypes of output columns + * @param output_column_schemas schema indices of output columns + * @param filter AST expression to filter row groups based on bloom filter membership + * @param stream CUDA stream used for device memory operations and kernel launches + * + * @return Filtered row group indices, if any is filtered + */ + [[nodiscard]] std::optional>> apply_bloom_filters( + host_span const> sources, + host_span const> input_row_group_indices, + host_span output_dtypes, + host_span output_column_schemas, + std::reference_wrapper filter, + rmm::cuda_stream_view stream) const; + /** * @brief Filters and reduces down to a selection of row groups * * The input `row_start` and `row_count` parameters will be recomputed and output as the valid * values based on the input row group list. * + * @param sources Lists of input datasources * @param row_group_indices Lists of row groups to read, one per source * @param row_start Starting row of the selection * @param row_count Total number of rows selected @@ -351,10 +406,11 @@ class aggregate_reader_metadata { * @param filter Optional AST expression to filter row groups based on Column chunk statistics * @param stream CUDA stream used for device memory operations and kernel launches * @return A tuple of corrected row_start, row_count, list of row group indexes and its - * starting row, and list of number of rows per source. 
+ * starting row, and list of number of rows per source */ [[nodiscard]] std::tuple, std::vector> - select_row_groups(host_span const> row_group_indices, + select_row_groups(host_span const> sources, + host_span const> row_group_indices, int64_t row_start, std::optional const& row_count, host_span output_dtypes, @@ -413,14 +469,14 @@ class named_to_reference_converter : public ast::detail::expression_transformer std::reference_wrapper visit(ast::operation const& expr) override; /** - * @brief Returns the AST to apply on Column chunk statistics. + * @brief Returns the converted AST expression * * @return AST operation expression */ [[nodiscard]] std::optional> get_converted_expr() const { - return _stats_expr; + return _converted_expr; } private: @@ -428,7 +484,7 @@ class named_to_reference_converter : public ast::detail::expression_transformer cudf::host_span const> operands); std::unordered_map column_name_to_index; - std::optional> _stats_expr; + std::optional> _converted_expr; // Using std::list or std::deque to avoid reference invalidation std::list _col_ref; std::list _operators; @@ -445,4 +501,22 @@ class named_to_reference_converter : public ast::detail::expression_transformer std::optional> expr, std::vector const& skip_names); +/** + * @brief Filter table using the provided (StatsAST or BloomfilterAST) expression and + * collect filtered row group indices + * + * @param table Table of stats or bloom filter membership columns + * @param ast_expr StatsAST or BloomfilterAST expression to filter with + * @param input_row_group_indices Lists of input row groups to read, one per source + * @param stream CUDA stream used for device memory operations and kernel launches + * + * @return Collected filtered row group indices, one vector per source, if any. 
A std::nullopt if + * all row groups are required or if the computed predicate is all nulls + */ +[[nodiscard]] std::optional>> collect_filtered_row_group_indices( + cudf::table_view ast_table, + std::reference_wrapper ast_expr, + host_span const> input_row_group_indices, + rmm::cuda_stream_view stream); + } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 326232ced60..43666f9e42d 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * Copyright (c) 2022-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1286,7 +1286,8 @@ void reader::impl::preprocess_file(read_mode mode) _file_itm_data.global_num_rows, _file_itm_data.row_groups, _file_itm_data.num_rows_per_source) = - _metadata->select_row_groups(_options.row_group_indices, + _metadata->select_row_groups(_sources, + _options.row_group_indices, _options.skip_rows, _options.num_rows, output_dtypes, diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 35877ac34b9..6a89b1e48d6 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -318,14 +318,15 @@ ConfigureTest( ) ConfigureTest( PARQUET_TEST - io/parquet_test.cpp + io/parquet_bloom_filter_test.cu io/parquet_chunked_reader_test.cu io/parquet_chunked_writer_test.cpp io/parquet_common.cpp io/parquet_misc_test.cpp io/parquet_reader_test.cpp - io/parquet_writer_test.cpp + io/parquet_test.cpp io/parquet_v2_test.cpp + io/parquet_writer_test.cpp GPUS 1 PERCENT 30 ) diff --git a/cpp/tests/io/parquet_bloom_filter_test.cu b/cpp/tests/io/parquet_bloom_filter_test.cu new file mode 100644 index 00000000000..d858f58fa56 --- /dev/null +++ b/cpp/tests/io/parquet_bloom_filter_test.cu @@ -0,0 +1,90 @@ +/* + * 
Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +using StringType = cudf::string_view; + +class ParquetBloomFilterTest : public cudf::test::BaseFixture {}; + +TEST_F(ParquetBloomFilterTest, TestStrings) +{ + using key_type = StringType; + using policy_type = cuco::arrow_filter_policy; + using word_type = policy_type::word_type; + + std::size_t constexpr num_filter_blocks = 4; + auto stream = cudf::get_default_stream(); + + // strings keys to insert + auto keys = cudf::test::strings_column_wrapper( + {"seventh", "fifteenth", "second", "tenth", "fifth", "first", + "seventh", "tenth", "ninth", "ninth", "seventeenth", "eighteenth", + "thirteenth", "fifth", "fourth", "twelfth", "second", "second", + "fourth", "seventh", "seventh", "tenth", "thirteenth", "seventeenth", + "fifth", "seventeenth", "eighth", "fourth", "second", "eighteenth", + "fifteenth", "second", "seventeenth", "thirteenth", "eighteenth", "fifth", + "seventh", "tenth", "fourteenth", "first", "fifth", "fifth", + "tenth", "thirteenth", "fourteenth", "third", "third", "sixth", + "first", "third"}); + + auto d_keys = cudf::column_device_view::create(keys); + + // Spawn a bloom filter + cuco::bloom_filter, + cuda::thread_scope_device, + policy_type, + cudf::detail::cuco_allocator> + filter{num_filter_blocks, + cuco::thread_scope_device, + 
{{cudf::DEFAULT_HASH_SEED}}, + cudf::detail::cuco_allocator{rmm::mr::polymorphic_allocator{}, stream}, + stream}; + + // Add strings to the bloom filter + filter.add(d_keys->begin(), d_keys->end(), stream); + + // Number of words in the filter + cudf::size_type const num_words = filter.block_extent() * filter.words_per_block; + + // Filter bitset as a column + auto const bitset = cudf::column_view{ + cudf::data_type{cudf::type_id::UINT32}, num_words, filter.data(), nullptr, 0, 0, {}}; + + // Expected filter bitset words computed using Arrow's implementation here: + // https://godbolt.org/z/oKfqcPWbY + auto expected = cudf::test::fixed_width_column_wrapper( + {4194306U, 4194305U, 2359296U, 1073774592U, 524544U, 1024U, 268443648U, + 8519680U, 2147500040U, 8421380U, 269500416U, 4202624U, 8396802U, 100665344U, + 2147747840U, 5243136U, 131146U, 655364U, 285345792U, 134222340U, 545390596U, + 2281717768U, 51201U, 41943553U, 1619656708U, 67441680U, 8462730U, 361220U, + 2216738864U, 587333888U, 4219272U, 873463873U}); + + // Check the bitset for equality + CUDF_TEST_EXPECT_COLUMNS_EQUAL(bitset, expected); +} diff --git a/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_100_bf_fpp0.1_nostats.snappy.parquet b/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_100_bf_fpp0.1_nostats.snappy.parquet new file mode 100644 index 00000000000..4123545a6e0 Binary files /dev/null and b/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_100_bf_fpp0.1_nostats.snappy.parquet differ diff --git a/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_100_chunk_stats.snappy.parquet b/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_100_chunk_stats.snappy.parquet new file mode 100644 index 00000000000..7dc2cee21ae Binary files /dev/null and b/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_100_chunk_stats.snappy.parquet differ diff --git a/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_500_bf_fpp0.1_nostats.snappy.parquet 
b/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_500_bf_fpp0.1_nostats.snappy.parquet new file mode 100644 index 00000000000..e898f1d7d1b Binary files /dev/null and b/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_500_bf_fpp0.1_nostats.snappy.parquet differ diff --git a/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_500_chunk_stats.snappy.parquet b/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_500_chunk_stats.snappy.parquet new file mode 100644 index 00000000000..3060234d499 Binary files /dev/null and b/python/cudf/cudf/tests/data/parquet/mixed_card_ndv_500_chunk_stats.snappy.parquet differ diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 77d1f77d30b..9d5f32c7ab9 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -1,6 +1,7 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. import datetime +import decimal import glob import hashlib import math @@ -4370,3 +4371,57 @@ def test_parquet_reader_mismatched_nullability_structs(tmpdir): cudf.read_parquet([buf2, buf1]), cudf.concat([df2, df1]).reset_index(drop=True), ) + + +@pytest.mark.parametrize( + "stats_fname,bloom_filter_fname", + [ + ( + "mixed_card_ndv_100_chunk_stats.snappy.parquet", + "mixed_card_ndv_100_bf_fpp0.1_nostats.snappy.parquet", + ), + ( + "mixed_card_ndv_500_chunk_stats.snappy.parquet", + "mixed_card_ndv_500_bf_fpp0.1_nostats.snappy.parquet", + ), + ], +) +@pytest.mark.parametrize( + "predicate,expected_len", + [ + ([[("str", "==", "FINDME")], [("fp64", "==", float(500))]], 2), + ([("fixed_pt", "==", decimal.Decimal(float(500)))], 2), + ([[("ui32", "==", np.uint32(500)), ("str", "==", "FINDME")]], 2), + ([[("str", "==", "FINDME")], [("ui32", ">=", np.uint32(0))]], 1000), + ( + [ + ("str", "!=", "FINDME"), + ("fixed_pt", "==", decimal.Decimal(float(500))), + ], + 0, + ), + ], +) +def test_parquet_bloom_filters( + datadir, stats_fname, 
bloom_filter_fname, predicate, expected_len +): + fname_stats = datadir / stats_fname + fname_bf = datadir / bloom_filter_fname + df_stats = cudf.read_parquet(fname_stats, filters=predicate).reset_index( + drop=True + ) + df_bf = cudf.read_parquet(fname_bf, filters=predicate).reset_index( + drop=True + ) + + # Check if tables equal + assert_eq( + df_stats, + df_bf, + ) + + # Check for table length + assert_eq( + len(df_stats), + expected_len, + )