Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6852,3 +6852,38 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {

Ok(())
}

/// Regression test for https://github.com/apache/datafusion/issues/21411
/// grouping() should work when wrapped in an alias via the DataFrame API.
///
/// This bug only manifests through the DataFrame API because `.alias()` wraps
/// the `grouping()` call in an `Expr::Alias` node at the aggregate expression
/// level. The SQL planner handles aliasing separately (via projection), so the
/// `ResolveGroupingFunction` analyzer rule never sees an `Expr::Alias` wrapper
/// around the aggregate function in SQL queries — making SQL-based tests
/// insufficient to cover this case.
#[tokio::test]
async fn test_grouping_with_alias() -> Result<()> {
use datafusion_functions_aggregate::expr_fn::grouping;

let df = create_test_table("test")
.await?
.aggregate(vec![col("a")], vec![grouping(col("a")).alias("g")])?
.sort(vec![Sort::new(col("a"), true, false)])?;

let results = df.collect().await?;

let expected = [
"+-----------+---+",
"| a | g |",
"+-----------+---+",
"| 123AbcDef | 0 |",
"| CBAdef | 0 |",
"| abc123 | 0 |",
"| abcDEF | 0 |",
"+-----------+---+",
];
assert_batches_eq!(expected, &results);

Ok(())
}
32 changes: 29 additions & 3 deletions datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ fn replace_grouping_exprs(
.into_iter()
.zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
{
match expr {
Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => {
match &expr {
Expr::AggregateFunction(function) if is_grouping_function(&expr) => {
let grouping_expr = grouping_function_on_id(
function,
&group_expr_to_bitmap_index,
Expand All @@ -110,6 +110,26 @@ fn replace_grouping_exprs(
column.name,
)));
}
Expr::Alias(Alias {
expr: inner_expr,
relation,
name,
..
}) if is_grouping_function(&expr) => {
if let Expr::AggregateFunction(function) = inner_expr.as_ref() {
let grouping_expr = grouping_function_on_id(
function,
&group_expr_to_bitmap_index,
is_grouping_set,
)?;
// Preserve the user-provided alias
projection_exprs.push(Expr::Alias(Alias::new(
grouping_expr,
relation.clone(),
name.clone(),
)));
}
}
_ => {
projection_exprs.push(Expr::Column(column));
new_agg_expr.push(expr);
Expand Down Expand Up @@ -151,7 +171,13 @@ fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
fn is_grouping_function(expr: &Expr) -> bool {
// TODO: Do something better than name here should grouping be a built
// in expression?
matches!(expr, Expr::AggregateFunction(AggregateFunction { func, .. }) if func.name() == "grouping")
match expr {
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
func.name() == "grouping"
}
Expr::Alias(Alias { expr, .. }) => is_grouping_function(expr),
_ => false,
}
}

fn contains_grouping_function(exprs: &[Expr]) -> bool {
Expand Down
Loading