diff --git a/quickwit/quickwit-datafusion/src/sources/metrics/factory.rs b/quickwit/quickwit-datafusion/src/sources/metrics/factory.rs index b95c9cbe8e0..ae0c37a61ed 100644 --- a/quickwit/quickwit-datafusion/src/sources/metrics/factory.rs +++ b/quickwit/quickwit-datafusion/src/sources/metrics/factory.rs @@ -93,7 +93,8 @@ impl TableProviderFactory for MetricsTableProviderFactory { .resolve(&index_name, self.split_kind) .await?; - let arrow_schema: SchemaRef = Arc::new(cmd.schema.as_arrow().clone()); + let arrow_schema: SchemaRef = + super::metrics_table_provider_schema(Arc::new(cmd.schema.as_arrow().clone())); if arrow_schema.fields().is_empty() { return Err(DataFusionError::Plan(format!( diff --git a/quickwit/quickwit-datafusion/src/sources/metrics/mod.rs b/quickwit/quickwit-datafusion/src/sources/metrics/mod.rs index b8f41ad0747..c3e5756b6a0 100644 --- a/quickwit/quickwit-datafusion/src/sources/metrics/mod.rs +++ b/quickwit/quickwit-datafusion/src/sources/metrics/mod.rs @@ -115,6 +115,7 @@ async fn resolve_metrics_table_provider( match index_resolver.resolve(index_name, split_kind).await { Ok((split_provider, index_uri)) => { + let schema = metrics_table_provider_schema(schema); let provider = MetricsTableProvider::new(schema, split_provider, index_uri)?; Ok(Some(Arc::new(provider))) } @@ -128,6 +129,10 @@ async fn resolve_metrics_table_provider( } } +fn dict_encoded_string_type() -> DataType { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) +} + fn split_kind_from_index_name(index_name: &str) -> Option { if is_metrics_index(index_name) { Some(ParquetSplitKind::Metrics) @@ -140,7 +145,7 @@ fn split_kind_from_index_name(index_name: &str) -> Option { /// Minimal 4-column schema — always present in every OSS metrics parquet file. fn minimal_base_schema() -> SchemaRef { - let dict = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let dict = dict_encoded_string_type(); Arc::new(ArrowSchema::new(vec![ Field::new("metric_name", dict, false), Field::new("metric_type", DataType::UInt8, false), @@ -166,6 +171,43 @@ fn minimal_schema_for_kind(split_kind: ParquetSplitKind) -> SchemaRef { } } +fn is_arrow_string_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} + +fn metrics_table_provider_schema(schema: SchemaRef) -> SchemaRef { + let mut has_string_fields = false; + let fields = schema + .fields() + .iter() + .map(|field| { + if is_arrow_string_type(field.data_type()) { + has_string_fields = true; + Arc::new( + field + .as_ref() + .clone() + .with_data_type(dict_encoded_string_type()), + ) + } else { + Arc::clone(field) + } + }) + .collect::>(); + + if has_string_fields { + Arc::new(ArrowSchema::new_with_metadata( + fields, + schema.metadata().clone(), + )) + } else { + schema + } +} + /// Native OSS `SchemaProvider` for metrics indexes. pub struct MetricsSchemaProvider { index_resolver: Arc, @@ -336,3 +378,62 @@ impl QuickwitSubstraitConsumerExt for MetricsDataSource { Ok(provider.map(|provider| (index_name.to_string(), provider))) } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::*; + + #[test] + fn table_provider_schema_dictionary_encodes_string_fields() { + let field_metadata = HashMap::from([("field_key".to_string(), "field_value".to_string())]); + let schema_metadata = + HashMap::from([("schema_key".to_string(), "schema_value".to_string())]); + let schema = Arc::new(ArrowSchema::new_with_metadata( + vec![ + Field::new("metric_name", DataType::Utf8, false), + Field::new("service", DataType::LargeUtf8, true) + .with_metadata(field_metadata.clone()), + Field::new("env", DataType::Utf8View, true), + Field::new("value", DataType::Float64, false), + Field::new("already_dict", dict_encoded_string_type(), true), + ], + schema_metadata.clone(), + )); + + let normalized = metrics_table_provider_schema(schema); + + assert_eq!( + normalized + .field_with_name("metric_name") + .unwrap() + .data_type(), + &dict_encoded_string_type() + ); + assert_eq!( + normalized.field_with_name("service").unwrap().data_type(), + &dict_encoded_string_type() + ); + assert_eq!( + normalized.field_with_name("env").unwrap().data_type(), + &dict_encoded_string_type() + ); + assert_eq!( + normalized.field_with_name("value").unwrap().data_type(), + &DataType::Float64 + ); + assert_eq!( + normalized + .field_with_name("already_dict") + .unwrap() + .data_type(), + &dict_encoded_string_type() + ); + assert_eq!(normalized.metadata(), &schema_metadata); + assert_eq!( + normalized.field_with_name("service").unwrap().metadata(), + &field_metadata + ); + } +} diff --git a/quickwit/quickwit-datafusion/tests/metrics.rs b/quickwit/quickwit-datafusion/tests/metrics.rs index 910f4d79db1..bb819ddf158 100644 --- a/quickwit/quickwit-datafusion/tests/metrics.rs +++ b/quickwit/quickwit-datafusion/tests/metrics.rs @@ -73,6 +73,41 @@ fn total_rows(batches: &[RecordBatch]) -> usize { batches.iter().map(|b| b.num_rows()).sum() } +fn string_value(array: &dyn Array, row: usize) -> String { + assert!(!array.is_null(row), "unexpected NULL string at row {row}"); + if let Some(strings) = array + .as_any() + .downcast_ref::() + { + strings.value(row).to_string() + } else if let Some(strings) = array.as_any().downcast_ref::() { + strings.value(row).to_string() + } else if let Some(dict) = array + .as_any() + .downcast_ref::>() + { + let key = usize::try_from(dict.keys().value(row)).unwrap(); + string_value(dict.values().as_ref(), key) + } else { + panic!("unexpected string column type {:?}", array.data_type()); + } +} + +fn assert_dict_encoded_string_column(batches: &[RecordBatch], column_name: &str) { + let expected = arrow::datatypes::DataType::Dictionary( + Box::new(arrow::datatypes::DataType::Int32), + Box::new(arrow::datatypes::DataType::Utf8), + ); + for batch in batches { + let column = batch.column_by_name(column_name).unwrap(); + assert_eq!( + column.data_type(), + &expected, + "expected `{column_name}` to stay dictionary encoded" + ); + } +} + // ═══════════════════════════════════════════════════════════════════ // Tests // ═══════════════════════════════════════════════════════════════════ @@ -692,33 +727,28 @@ async fn test_rollup_nested_aggregation() { 6, "expected 6 rows (3 bins × 2 services); staging rows must be excluded" ); + assert_dict_encoded_string_column(&batches, "service"); // Collect (service, avg_val) pairs in ORDER BY time_bin, service order. - // After GROUP BY, DataFusion casts dict-encoded strings to plain Utf8. - let results: Vec<(String, f64)> = batches.iter().flat_map(|batch| { - let svc_raw = batch.column_by_name("service").unwrap(); - let avg_col = batch.column_by_name("avg_val").unwrap() - .as_any().downcast_ref::().unwrap(); - (0..batch.num_rows()).map(|i| { - // After GROUP BY, DataFusion 52 may return Utf8View, Utf8, or Dict. - let svc = if let Some(sa) = svc_raw.as_any() - .downcast_ref::() { - sa.value(i).to_string() - } else if let Some(sa) = svc_raw.as_any() - .downcast_ref::() { - sa.value(i).to_string() - } else { - let dict = svc_raw.as_any() - .downcast_ref::>() - .unwrap_or_else(|| panic!("service column: unexpected type {:?}", svc_raw.data_type())); - let keys = dict.keys().as_any().downcast_ref::().unwrap(); - let vals = dict.values().as_any().downcast_ref::().unwrap(); - vals.value(keys.value(i) as usize).to_string() - }; - let avg = avg_col.value(i); - (svc, avg) - }).collect::>() - }).collect(); + let results: Vec<(String, f64)> = batches + .iter() + .flat_map(|batch| { + let svc_raw = batch.column_by_name("service").unwrap(); + let avg_col = batch + .column_by_name("avg_val") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + (0..batch.num_rows()) + .map(|i| { + let svc = string_value(svc_raw.as_ref(), i); + let avg = avg_col.value(i); + (svc, avg) + }) + .collect::>() + }) + .collect(); // Expected: [(api,200), (web,11), (api,400), (web,22), (api,600), (web,33)] let expected = [ @@ -828,11 +858,12 @@ async fn test_substrait_named_table_query() { .await .unwrap(); - let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + let num_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); assert_eq!( - total_rows, 2, + num_rows, 2, "expected 2 metric names (cpu.usage, memory.used)" ); + assert_dict_encoded_string_column(&batches, "metric_name"); // Verify SUM values: cpu.usage = 1+2+3 = 6, memory.used = 10+20+30 = 60 let metric_col = batches[0].column_by_name("metric_name").unwrap(); @@ -843,23 +874,9 @@ async fn test_substrait_named_table_query() { .downcast_ref::() .unwrap(); - // metric_name may come back as StringViewArray or StringArray after aggregation + // metric_name may come back as Utf8View, Utf8, or Dictionary after aggregation. let names: Vec = (0..batches[0].num_rows()) - .map(|i| { - if let Some(sv) = metric_col - .as_any() - .downcast_ref::() - { - sv.value(i).to_string() - } else { - metric_col - .as_any() - .downcast_ref::() - .unwrap() - .value(i) - .to_string() - } - }) + .map(|i| string_value(metric_col.as_ref(), i)) .collect(); assert_eq!(names, vec!["cpu.usage", "memory.used"]); @@ -873,6 +890,39 @@ async fn test_substrait_named_table_query() { "memory.used SUM expected 60.0, got {}", total_col.value(1) ); + + let like_df = ctx + .sql( + r#"SELECT metric_name, SUM(value) as total + FROM "metrics-substrait-test" + WHERE metric_name LIKE 'cpu.%' + GROUP BY metric_name + ORDER BY metric_name"#, + ) + .await + .unwrap(); + let like_plan = like_df.into_optimized_plan().unwrap(); + let like_substrait_plan = to_substrait_plan(&like_plan, &ctx.state()).unwrap(); + let like_stream = builder + .execute_substrait(&like_substrait_plan.encode_to_vec()) + .await + .unwrap(); + let like_batches = datafusion::physical_plan::common::collect(like_stream) + .await + .unwrap(); + + assert_eq!(total_rows(&like_batches), 1); + assert_dict_encoded_string_column(&like_batches, "metric_name"); + assert_eq!( + string_value( + like_batches[0] + .column_by_name("metric_name") + .unwrap() + .as_ref(), + 0 + ), + "cpu.usage" + ); } /// Executes the user-provided Substrait rollup plan directly against real @@ -1022,6 +1072,7 @@ async fn test_rollup_substrait_from_file() { // 3 bins × 2 services (api, web) = 6 rows. let total: usize = batches.iter().map(|b| b.num_rows()).sum(); assert_eq!(total, 6, "expected 6 rows (3 bins × 2 services)"); + assert_dict_encoded_string_column(&batches, "service"); // Expected order: (api,bin0,200), (web,bin0,11), (api,bin30,400), // (web,bin30,22), (api,bin60,600), (web,bin60,33) diff --git a/quickwit/quickwit-datafusion/tests/sketches.rs b/quickwit/quickwit-datafusion/tests/sketches.rs index ea14fbb986b..2af72e16d52 100644 --- a/quickwit/quickwit-datafusion/tests/sketches.rs +++ b/quickwit/quickwit-datafusion/tests/sketches.rs @@ -247,7 +247,8 @@ async fn test_sketch_merge_and_quantile_substrait() { SELECT dd_quantile(dd_sketch(keys, counts, "count", "min", "max", flags), 0.50) AS p50 FROM "datadog-sketches" - WHERE metric_name = 'req.latency' AND timestamp_secs = 600 + -- DataFusion's Substrait producer cannot serialize dictionary literals yet. + WHERE CAST(metric_name AS VARCHAR) = 'req.latency' AND timestamp_secs = 600 "#, ) .await