diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 1c661744e0867..c09efc88d02ea 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -6852,3 +6852,50 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { Ok(()) } + +/// Regression test for https://github.com/apache/datafusion/issues/21411 +/// grouping() should work when wrapped in an alias via the DataFrame API. +/// +/// This bug only manifests through the DataFrame API because `.alias()` wraps +/// the `grouping()` call in an `Expr::Alias` node at the aggregate expression +/// level. The SQL planner handles aliasing separately (via projection), so the +/// `ResolveGroupingFunction` analyzer rule never sees an `Expr::Alias` wrapper +/// around the aggregate function in SQL queries — making SQL-based tests +/// insufficient to cover this case. +#[tokio::test] +async fn test_grouping_with_alias() -> Result<()> { + use datafusion_functions_aggregate::expr_fn::grouping; + + let df = create_test_table("test") + .await? + .aggregate(vec![col("a")], vec![grouping(col("a")).alias("g")])? + .sort(vec![Sort::new(col("a"), true, false)])?; + + let results = df.collect().await?; + + let expected = [ + "+-----------+---+", + "| a | g |", + "+-----------+---+", + "| 123AbcDef | 0 |", + "| CBAdef | 0 |", + "| abc123 | 0 |", + "| abcDEF | 0 |", + "+-----------+---+", + ]; + assert_batches_eq!(expected, &results); + + // Also verify that nested aliases (e.g. .alias("x").alias("g")) work correctly + let df = create_test_table("test") + .await? + .aggregate( + vec![col("a")], + vec![grouping(col("a")).alias("x").alias("g")], + )? + .sort(vec![Sort::new(col("a"), true, false)])?; + + let results = df.collect().await?; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 6b8ae3e8531bc..e9c24eda89b38 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -97,8 +97,8 @@ fn replace_grouping_exprs( .into_iter() .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) { - match expr { - Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { + match &expr { + Expr::AggregateFunction(function) if is_grouping_function(&expr) => { let grouping_expr = grouping_function_on_id( function, &group_expr_to_bitmap_index, @@ -110,6 +110,20 @@ fn replace_grouping_exprs( column.name, ))); } + Expr::Alias(Alias { relation, name, .. }) if is_grouping_function(&expr) => { + let function = unwrap_alias_to_grouping_function(&expr)?; + let grouping_expr = grouping_function_on_id( + function, + &group_expr_to_bitmap_index, + is_grouping_set, + )?; + // Preserve the outermost user-provided alias + projection_exprs.push(Expr::Alias(Alias::new( + grouping_expr, + relation.clone(), + name.clone(), + ))); + } _ => { projection_exprs.push(Expr::Column(column)); new_agg_expr.push(expr); @@ -148,10 +162,27 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Ok(transformed_plan) } +/// Recursively unwrap `Expr::Alias` nodes to reach the inner `AggregateFunction`. +/// Returns an error if the innermost expression is not an `AggregateFunction`, +/// which should not happen if `is_grouping_function` returned true. +fn unwrap_alias_to_grouping_function(expr: &Expr) -> Result<&AggregateFunction> { + match expr { + Expr::AggregateFunction(function) => Ok(function), + Expr::Alias(Alias { expr, .. }) => unwrap_alias_to_grouping_function(expr), + _ => plan_err!("Expected grouping aggregate function inside alias, got {expr}"), + } +} + fn is_grouping_function(expr: &Expr) -> bool { // TODO: Do something better than name here should grouping be a built // in expression? - matches!(expr, Expr::AggregateFunction(AggregateFunction { func, .. }) if func.name() == "grouping") + match expr { + Expr::AggregateFunction(AggregateFunction { func, .. }) => { + func.name() == "grouping" + } + Expr::Alias(Alias { expr, .. }) => is_grouping_function(expr), + _ => false, + } } fn contains_grouping_function(exprs: &[Expr]) -> bool {