Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 62 additions & 33 deletions datafusion/core/benches/aggregate_query_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

mod data_utils;

use criterion::{Criterion, criterion_group, criterion_main};
use data_utils::create_table_provider;
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use data_utils::{Utf8PayloadProfile, create_table_provider_with_payload};
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use parking_lot::Mutex;
Expand All @@ -36,13 +36,38 @@ fn create_context(
partitions_len: usize,
array_len: usize,
batch_size: usize,
) -> Result<Arc<Mutex<SessionContext>>> {
create_context_with_payload(
partitions_len,
array_len,
batch_size,
Utf8PayloadProfile::Small,
)
}

/// Build a benchmark [`SessionContext`] whose registered table `t` carries
/// `utf8` values sized by `utf8_payload_profile`.
///
/// The context is wrapped in `Arc<Mutex<..>>` so the criterion closures can
/// share it across iterations.
fn create_context_with_payload(
    partitions_len: usize,
    array_len: usize,
    batch_size: usize,
    utf8_payload_profile: Utf8PayloadProfile,
) -> Result<Arc<Mutex<SessionContext>>> {
    // Build the in-memory provider first; context creation is independent.
    let provider = create_table_provider_with_payload(
        partitions_len,
        array_len,
        batch_size,
        utf8_payload_profile,
    )?;
    let ctx = SessionContext::new();
    ctx.register_table("t", provider)?;
    Ok(Arc::new(Mutex::new(ctx)))
}

/// Render the grouped `string_agg` benchmark query for `group_by_column`.
fn string_agg_sql(group_by_column: &str) -> String {
    let mut sql = String::from("SELECT ");
    sql.push_str(group_by_column);
    sql.push_str(", string_agg(utf8, ',') FROM t GROUP BY ");
    sql.push_str(group_by_column);
    sql
}

fn criterion_benchmark(c: &mut Criterion) {
let partitions_len = 8;
let array_len = 32768 * 2; // 2^16
Expand Down Expand Up @@ -296,38 +321,42 @@ fn criterion_benchmark(c: &mut Criterion) {
})
});

c.bench_function("string_agg_query_group_by_few_groups", |b| {
b.iter(|| {
query(
ctx.clone(),
&rt,
"SELECT u64_narrow, string_agg(utf8, ',') \
FROM t GROUP BY u64_narrow",
)
})
});
// These payload sizes keep the original 4-value cardinality while changing
// only the bytes copied into grouped `string_agg` state:
// - small_3b preserves the existing `hi0`..`hi3` baseline
// - medium_64b makes copy costs measurable without overwhelming the query
// - large_1024b stresses both CPU and memory behavior
let string_agg_profiles = [
(Utf8PayloadProfile::Small, "small_3b"),
(Utf8PayloadProfile::Medium, "medium_64b"),
(Utf8PayloadProfile::Large, "large_1024b"),
]
.into_iter()
.map(|(profile, label)| {
(
label,
create_context_with_payload(partitions_len, array_len, batch_size, profile)
.unwrap(),
)
})
.collect::<Vec<_>>();

c.bench_function("string_agg_query_group_by_mid_groups", |b| {
b.iter(|| {
query(
ctx.clone(),
&rt,
"SELECT u64_mid, string_agg(utf8, ',') \
FROM t GROUP BY u64_mid",
)
})
});
let string_agg_queries = [
("few_groups", string_agg_sql("u64_narrow")),
("mid_groups", string_agg_sql("u64_mid")),
("many_groups", string_agg_sql("u64_wide")),
];

c.bench_function("string_agg_query_group_by_many_groups", |b| {
b.iter(|| {
query(
ctx.clone(),
&rt,
"SELECT u64_wide, string_agg(utf8, ',') \
FROM t GROUP BY u64_wide",
)
})
});
let mut string_agg_group = c.benchmark_group("string_agg_payloads");
for (query_name, sql) in &string_agg_queries {
for (payload_name, payload_ctx) in &string_agg_profiles {
string_agg_group
.bench_function(BenchmarkId::new(*query_name, *payload_name), |b| {
b.iter(|| query(payload_ctx.clone(), &rt, sql))
});
}
}
string_agg_group.finish();
}

criterion_group!(benches, criterion_benchmark);
Expand Down
78 changes: 73 additions & 5 deletions datafusion/core/benches/data_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ use rand_distr::{Normal, Pareto};
use std::fmt::Write;
use std::sync::Arc;

/// Payload profile for the benchmark `utf8` column.
///
/// Every profile yields exactly 4 distinct values (the original `hi0`..`hi3`
/// cardinality); only the byte width of each value changes. This lets grouped
/// `string_agg` benchmarks isolate the cost of copying larger payloads into
/// accumulator state without perturbing group counts.
///
/// The small profile preserves the existing `hi0`..`hi3` baseline. Medium and
/// large profiles keep the same low cardinality but scale each value's byte
/// width so string aggregation can expose the cost of copying larger payloads.
#[derive(Clone, Copy, Debug)]
pub enum Utf8PayloadProfile {
    /// 3-byte baseline values such as `hi0`.
    Small,
    /// 64-byte payloads that are large enough to make copying noticeable
    /// without dominating the benchmark with allocator churn.
    Medium,
    /// 1024-byte payloads that amplify both CPU and memory pressure in
    /// grouped `string_agg` workloads.
    Large,
}

/// create an in-memory table given the partition len, array len, and batch size,
/// and the result table will be of array_len in total, and then partitioned, and batched.
#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly
Expand All @@ -44,9 +61,32 @@ pub fn create_table_provider(
array_len: usize,
batch_size: usize,
) -> Result<Arc<MemTable>> {
create_table_provider_with_payload(
partitions_len,
array_len,
batch_size,
Utf8PayloadProfile::Small,
)
}

/// Create an in-memory table with a configurable `utf8` payload size.
///
/// Behaves like `create_table_provider` except that the byte width of each
/// `utf8` value is selected by `utf8_payload_profile`; the number of distinct
/// values (4) is the same across profiles. The table holds `array_len` rows
/// total, split into `partitions_len` partitions of `batch_size`-row batches.
#[expect(clippy::allow_attributes)] // some issue where expect(dead_code) doesn't fire properly
#[allow(dead_code)]
pub fn create_table_provider_with_payload(
    partitions_len: usize,
    array_len: usize,
    batch_size: usize,
    utf8_payload_profile: Utf8PayloadProfile,
) -> Result<Arc<MemTable>> {
    // NOTE(review): dropped the no-op `let _ = Utf8PayloadProfile::all();`,
    // which only existed to mark `all()` as used; `all()` itself can likely
    // be deleted as well.
    let schema = Arc::new(create_schema());
    let partitions = create_record_batches(
        &schema,
        array_len,
        partitions_len,
        batch_size,
        utf8_payload_profile,
    );
    // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
    MemTable::try_new(schema, partitions).map(Arc::new)
}
Expand Down Expand Up @@ -91,12 +131,14 @@ fn create_record_batch(
rng: &mut StdRng,
batch_size: usize,
batch_index: usize,
payloads: &[String; 4],
) -> RecordBatch {
// Randomly choose from 4 distinct key values; a higher number increases sparseness.
let key_suffixes = [0, 1, 2, 3];
let keys = StringArray::from_iter_values(
(0..batch_size).map(|_| format!("hi{}", key_suffixes.choose(rng).unwrap())),
);
let keys = StringArray::from_iter_values((0..batch_size).map(|_| {
let suffix = *key_suffixes.choose(rng).unwrap();
payloads[suffix].as_str()
}));

let values = create_data(rng, batch_size, 0.5);

Expand Down Expand Up @@ -146,10 +188,12 @@ pub fn create_record_batches(
array_len: usize,
partitions_len: usize,
batch_size: usize,
utf8_payload_profile: Utf8PayloadProfile,
) -> Vec<Vec<RecordBatch>> {
let mut rng = StdRng::seed_from_u64(42);
let mut partitions = Vec::with_capacity(partitions_len);
let batches_per_partition = array_len / batch_size / partitions_len;
let payloads = utf8_payload_profile.payloads();

for _ in 0..partitions_len {
let mut batches = Vec::with_capacity(batches_per_partition);
Expand All @@ -159,13 +203,37 @@ pub fn create_record_batches(
&mut rng,
batch_size,
batch_index,
&payloads,
));
}
partitions.push(batches);
}
partitions
}

impl Utf8PayloadProfile {
    /// Every profile, ordered by increasing payload width.
    fn all() -> [Self; 3] {
        [Self::Small, Self::Medium, Self::Large]
    }

    /// Materialize the 4 distinct `utf8` values for this profile.
    ///
    /// `Small` reproduces the historical `hi0`..`hi3` values; the other
    /// profiles pad each value out to a fixed byte width.
    fn payloads(self) -> [String; 4] {
        match self {
            Self::Small => std::array::from_fn(|idx| format!("hi{idx}")),
            Self::Medium => std::array::from_fn(|idx| payload_string("mid", idx, 64)),
            Self::Large => std::array::from_fn(|idx| payload_string("large", idx, 1024)),
        }
    }
}

/// Build a payload value of (at least the header's, normally exactly)
/// `target_len` bytes: the header `"{prefix}{suffix}_"` padded with a
/// suffix-specific filler character (`'a' + suffix`) so each of the 4
/// benchmark values stays distinct even at large widths.
///
/// Uses `saturating_sub` so a header that already reaches `target_len` is
/// returned unpadded instead of panicking on subtraction underflow (the
/// original `target_len - value.len()` would panic in that case).
fn payload_string(prefix: &str, suffix: usize, target_len: usize) -> String {
    // Filler must stay within 'a'..='z' for the arithmetic below to make sense.
    debug_assert!(suffix < 26, "suffix {suffix} exceeds filler alphabet");
    let mut value = format!("{prefix}{suffix}_");
    let filler = (b'a' + suffix as u8) as char;
    let padding = target_len.saturating_sub(value.len());
    value.extend(std::iter::repeat_n(filler, padding));
    value
}

/// An enum that wraps either a regular StringBuilder or a GenericByteViewBuilder
/// so that both can be used interchangeably.
enum TraceIdBuilder {
Expand Down
Loading