diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 18766d7056355..0368cf622bfe6 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ - Result, not_impl_err, plan_err, + Diagnostic, Result, Span, not_impl_err, plan_err, tree_node::{TreeNode, TreeNodeRecursion}, }; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource}; @@ -33,15 +34,35 @@ impl SqlToRel<'_, S> { planner_context: &mut PlannerContext, ) -> Result<()> { let is_recursive = with.recursive; + // Track the span of the first definition of each CTE name, + // so we can point to it in diagnostics when a duplicate is found. + let mut cte_name_spans: HashMap> = HashMap::new(); // Process CTEs from top to bottom for cte in with.cte_tables { // A `WITH` block can't use the same name more than once let cte_name = self.ident_normalizer.normalize(cte.alias.name.clone()); if planner_context.contains_cte(&cte_name) { + let dup_span = Span::try_from_sqlparser_span(cte.alias.name.span); + let mut diagnostic = Diagnostic::new_error( + format!("WITH query name {cte_name:?} specified more than once"), + dup_span, + ); + if let Some(Some(first_span)) = cte_name_spans.get(&cte_name) { + diagnostic.add_note( + format!("{cte_name:?} previously defined here"), + Some(*first_span), + ); + } + diagnostic.add_note("WITH query names must be unique", None); return plan_err!( - "WITH query name {cte_name:?} specified more than once" + "WITH query name {cte_name:?} specified more than once"; + diagnostic = diagnostic ); } + cte_name_spans.insert( + cte_name.clone(), + Span::try_from_sqlparser_span(cte.alias.name.span), + ); // Create a logical plan for the CTE let cte_plan = if is_recursive { diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index 7a729739469d3..5456015927bb0 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -369,6 +369,20 @@ fn test_unary_op_plus_with_non_column() -> Result<()> { Ok(()) } +#[test] +fn test_duplicate_cte_name() -> Result<()> { + let query = "WITH /*first*/cte/*first*/ AS (SELECT 1), /*dup*/cte/*dup*/ AS (SELECT 2) SELECT 1"; + let spans = get_spans(query); + let diag = do_query(query); + assert_snapshot!(diag.message, @r#"WITH query name "cte" specified more than once"#); + assert_eq!(diag.span, Some(spans["dup"])); + assert_snapshot!(diag.notes[0].message, @r#""cte" previously defined here"#); + assert_eq!(diag.notes[0].span, Some(spans["first"])); + assert_snapshot!(diag.notes[1].message, @"WITH query names must be unique"); + assert_eq!(diag.notes[1].span, None); + Ok(()) +} + #[test] fn test_syntax_error() -> Result<()> { // create a table with a column of type varchar