Skip to content
95 changes: 62 additions & 33 deletions datafusion/core/benches/aggregate_query_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

mod data_utils;

use criterion::{Criterion, criterion_group, criterion_main};
use data_utils::create_table_provider;
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use data_utils::{Utf8PayloadProfile, create_table_provider_with_payload};
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use parking_lot::Mutex;
Expand All @@ -36,13 +36,38 @@ fn create_context(
partitions_len: usize,
array_len: usize,
batch_size: usize,
) -> Result<Arc<Mutex<SessionContext>>> {
create_context_with_payload(
partitions_len,
array_len,
batch_size,
Utf8PayloadProfile::Small,
)
}

/// Build a [`SessionContext`] whose in-memory table `t` is generated with the
/// requested `utf8` payload profile.
///
/// `partitions_len`, `array_len`, and `batch_size` are forwarded unchanged to
/// the table-provider builder; the context is wrapped in `Arc<Mutex<..>>` so
/// the benchmark harness can share it across iterations.
fn create_context_with_payload(
    partitions_len: usize,
    array_len: usize,
    batch_size: usize,
    utf8_payload_profile: Utf8PayloadProfile,
) -> Result<Arc<Mutex<SessionContext>>> {
    let ctx = SessionContext::new();
    // Build the table data with the payload profile baked into the `utf8`
    // column, then expose it under the name every benchmark query uses.
    let provider = create_table_provider_with_payload(
        partitions_len,
        array_len,
        batch_size,
        utf8_payload_profile,
    )?;
    ctx.register_table("t", provider)?;
    Ok(Arc::new(Mutex::new(ctx)))
}

/// SQL text for a grouped `string_agg` over table `t`, keyed by the given
/// group-by column.
fn string_agg_sql(group_by_column: &str) -> String {
    let col = group_by_column;
    format!("SELECT {col}, string_agg(utf8, ',') FROM t GROUP BY {col}")
}

fn criterion_benchmark(c: &mut Criterion) {
let partitions_len = 8;
let array_len = 32768 * 2; // 2^16
Expand Down Expand Up @@ -296,38 +321,42 @@ fn criterion_benchmark(c: &mut Criterion) {
})
});

c.bench_function("string_agg_query_group_by_few_groups", |b| {
b.iter(|| {
query(
ctx.clone(),
&rt,
"SELECT u64_narrow, string_agg(utf8, ',') \
FROM t GROUP BY u64_narrow",
)
})
});
// These payload sizes keep the original 4-value cardinality while changing
// only the bytes copied into grouped `string_agg` state:
// - small_3b preserves the existing `hi0`..`hi3` baseline
// - medium_64b makes copy costs measurable without overwhelming the query
// - large_1024b stresses both CPU and memory behavior
let string_agg_profiles = [
(Utf8PayloadProfile::Small, "small_3b"),
(Utf8PayloadProfile::Medium, "medium_64b"),
(Utf8PayloadProfile::Large, "large_1024b"),
]
.into_iter()
.map(|(profile, label)| {
(
label,
create_context_with_payload(partitions_len, array_len, batch_size, profile)
.unwrap(),
)
})
.collect::<Vec<_>>();

c.bench_function("string_agg_query_group_by_mid_groups", |b| {
b.iter(|| {
query(
ctx.clone(),
&rt,
"SELECT u64_mid, string_agg(utf8, ',') \
FROM t GROUP BY u64_mid",
)
})
});
let string_agg_queries = [
("few_groups", string_agg_sql("u64_narrow")),
("mid_groups", string_agg_sql("u64_mid")),
("many_groups", string_agg_sql("u64_wide")),
];

c.bench_function("string_agg_query_group_by_many_groups", |b| {
b.iter(|| {
query(
ctx.clone(),
&rt,
"SELECT u64_wide, string_agg(utf8, ',') \
FROM t GROUP BY u64_wide",
)
})
});
let mut string_agg_group = c.benchmark_group("string_agg_payloads");
for (query_name, sql) in &string_agg_queries {
for (payload_name, payload_ctx) in &string_agg_profiles {
string_agg_group
.bench_function(BenchmarkId::new(*query_name, *payload_name), |b| {
b.iter(|| query(payload_ctx.clone(), &rt, sql))
});
}
}
string_agg_group.finish();
}

criterion_group!(benches, criterion_benchmark);
Expand Down
78 changes: 73 additions & 5 deletions datafusion/core/benches/data_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ use rand_distr::{Normal, Pareto};
use std::fmt::Write;
use std::sync::Arc;

/// Payload profile for the benchmark `utf8` column.
///
/// The small profile preserves the existing `hi0`..`hi3` baseline. Medium and
/// large profiles keep the same low cardinality (four distinct key values) but
/// scale each value's byte width so string aggregation can expose the cost of
/// copying larger payloads into grouped aggregate state.
#[derive(Clone, Copy, Debug)]
pub enum Utf8PayloadProfile {
    /// 3-byte baseline values such as `hi0`.
    Small,
    /// 64-byte payloads that are large enough to make copying noticeable
    /// without dominating the benchmark with allocator churn.
    Medium,
    /// 1024-byte payloads that amplify both CPU and memory pressure in
    /// grouped `string_agg` workloads.
    Large,
}

/// create an in-memory table given the partition len, array len, and batch size,
/// and the result table will be of array_len in total, and then partitioned, and batched.
#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly
Expand All @@ -44,9 +61,32 @@ pub fn create_table_provider(
array_len: usize,
batch_size: usize,
) -> Result<Arc<MemTable>> {
create_table_provider_with_payload(
partitions_len,
array_len,
batch_size,
Utf8PayloadProfile::Small,
)
}

/// Create an in-memory table with a configurable `utf8` payload size.
///
/// The result table holds `array_len` rows in total, split across
/// `partitions_len` partitions, each partition further divided into
/// `batch_size`-row record batches.
#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly
#[allow(dead_code)]
pub fn create_table_provider_with_payload(
    partitions_len: usize,
    array_len: usize,
    batch_size: usize,
    utf8_payload_profile: Utf8PayloadProfile,
) -> Result<Arc<MemTable>> {
    // NOTE(review): this no-op call appears to exist only to keep
    // `Utf8PayloadProfile::all()` from tripping dead-code lints — confirm
    // whether `all()` (and this line) can be removed together.
    let _ = Utf8PayloadProfile::all();
    let schema = Arc::new(create_schema());
    let partitions = create_record_batches(
        &schema,
        array_len,
        partitions_len,
        batch_size,
        utf8_payload_profile,
    );
    // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
    MemTable::try_new(schema, partitions).map(Arc::new)
}
Expand Down Expand Up @@ -91,12 +131,14 @@ fn create_record_batch(
rng: &mut StdRng,
batch_size: usize,
batch_index: usize,
payloads: &[String; 4],
) -> RecordBatch {
// Randomly choose from 4 distinct key values; a higher number increases sparseness.
let key_suffixes = [0, 1, 2, 3];
let keys = StringArray::from_iter_values(
(0..batch_size).map(|_| format!("hi{}", key_suffixes.choose(rng).unwrap())),
);
let keys = StringArray::from_iter_values((0..batch_size).map(|_| {
let suffix = *key_suffixes.choose(rng).unwrap();
payloads[suffix].as_str()
}));

let values = create_data(rng, batch_size, 0.5);

Expand Down Expand Up @@ -146,10 +188,12 @@ pub fn create_record_batches(
array_len: usize,
partitions_len: usize,
batch_size: usize,
utf8_payload_profile: Utf8PayloadProfile,
) -> Vec<Vec<RecordBatch>> {
let mut rng = StdRng::seed_from_u64(42);
let mut partitions = Vec::with_capacity(partitions_len);
let batches_per_partition = array_len / batch_size / partitions_len;
let payloads = utf8_payload_profile.payloads();

for _ in 0..partitions_len {
let mut batches = Vec::with_capacity(batches_per_partition);
Expand All @@ -159,13 +203,37 @@ pub fn create_record_batches(
&mut rng,
batch_size,
batch_index,
&payloads,
));
}
partitions.push(batches);
}
partitions
}

impl Utf8PayloadProfile {
    /// Every profile, ordered from smallest to largest payload.
    fn all() -> [Self; 3] {
        [Self::Small, Self::Medium, Self::Large]
    }

    /// Produce the four distinct `utf8` values used as group keys for this
    /// profile; entry `i` is the `i`-th key.
    fn payloads(self) -> [String; 4] {
        match self {
            Self::Small => std::array::from_fn(|i| format!("hi{i}")),
            Self::Medium => std::array::from_fn(|i| payload_string("mid", i, 64)),
            Self::Large => std::array::from_fn(|i| payload_string("large", i, 1024)),
        }
    }
}

/// Build a deterministic payload of `target_len` bytes.
///
/// The value starts with a `{prefix}{suffix}_` header (keeping the key values
/// distinct across suffixes) and is padded to length with a suffix-specific
/// ASCII letter. Robustness over the original: if `target_len` is smaller
/// than the header, the header is returned as-is instead of panicking on
/// `usize` underflow, and the pad letter wraps within `a..=z` instead of
/// overflowing `u8` for large suffixes.
fn payload_string(prefix: &str, suffix: usize, target_len: usize) -> String {
    let mut value = format!("{prefix}{suffix}_");
    let pad = char::from(b'a' + (suffix % 26) as u8);
    let pad_len = target_len.saturating_sub(value.len());
    value.extend(std::iter::repeat_n(pad, pad_len));
    value
}

/// An enum that wraps either a regular StringBuilder or a GenericByteViewBuilder
/// so that both can be used interchangeably.
enum TraceIdBuilder {
Expand Down
Loading
Loading