From 390d40a5f640bb14144dd527c9da3cdb1cb89b2a Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Wed, 15 Apr 2026 14:50:36 +0200 Subject: [PATCH 01/30] explicit `on` and cross-dataset joins in Relation.join() --- dlt/common/libs/sqlglot.py | 8 +- dlt/dataset/_join.py | 118 ++++++++- dlt/dataset/dataset.py | 16 +- dlt/dataset/relation.py | 246 ++++++++++++++---- .../impl/clickhouse/sql_client.py | 15 +- .../impl/sqlalchemy/db_api_client.py | 14 +- dlt/destinations/queries.py | 13 +- dlt/destinations/sql_client.py | 17 +- tests/dataset/conftest.py | 44 ++++ tests/dataset/test_relation_join.py | 211 ++++++++++++++- tests/dataset/utils.py | 118 +++++++++ tests/destinations/test_queries.py | 13 +- 12 files changed, 729 insertions(+), 104 deletions(-) diff --git a/dlt/common/libs/sqlglot.py b/dlt/common/libs/sqlglot.py index 58599344b7..af2d9435c5 100644 --- a/dlt/common/libs/sqlglot.py +++ b/dlt/common/libs/sqlglot.py @@ -948,7 +948,7 @@ def bind_query( qualified_query: sge.Query, sqlglot_schema: Any, # SQLGlotSchema *, - expand_table_name: Callable[[str], List[str]], + expand_table_name: Callable[[str, Optional[str]], List[str]], casefold_identifier: Callable[[str], str], ) -> sge.Query: """Binds a logical query (compliant with dlt schema) to physical tables in the destination dataset. @@ -971,7 +971,9 @@ def bind_query( Args: qualified_query: SQLGlot query expression with qualified table/column references sqlglot_schema: Schema mapping for name validation and column resolution - expand_table_name: Function that expands table name to fully qualified path [catalog, schema, table] + expand_table_name: Function ``(table_name, dataset_name | None) -> [catalog, schema, table]`` + that expands a table name to a fully qualified path. The second argument is the + dataset qualifier from the query (``node.db``), or `None` for the default dataset. 
casefold_identifier: Case transformation function (`str`, `str.upper`, or `str.lower`) Returns: @@ -993,7 +995,7 @@ def bind_query( # expand named of known tables. this is currently clickhouse things where # we use dataset.table in queries but render those as dataset___table if sqlglot_schema.column_names(node): - expanded_path = expand_table_name(node.name) + expanded_path = expand_table_name(node.name, node.db or None) # set the table name if node.name != expanded_path[-1]: node.this.set("this", expanded_path[-1]) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 240116fec9..39be0b12b0 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -1,15 +1,18 @@ from __future__ import annotations from functools import reduce -from typing import TYPE_CHECKING, Optional, Sequence +from typing import TYPE_CHECKING, Optional, Sequence, Union +import sqlglot import sqlglot.expressions as sge + from dlt.common.typing import TypedDict from dlt.common.schema import Schema, utils as schema_utils -from dlt.common.schema.typing import TTableReference +from dlt.common.schema.typing import TTableReference, TTableSchemaColumns +from dlt.common.libs.sqlglot import TSqlGlotDialect if TYPE_CHECKING: - from dlt.dataset.relation import TJoinType + from dlt.dataset.relation import Relation, TJoinType _INTERMEDIATE_JOIN_ALIAS_PREFIX = "_dlt_int_t" @@ -268,17 +271,16 @@ def _discover_join_params( def _apply_join_projection( query: sge.Select, *, - schema: Schema, left_table: str, - target_table: str, + target_columns: TTableSchemaColumns, target_qualifier: str, projection_prefix: str, allow_existing_target_projection: bool, ) -> None: """Apply join projection contract onto ``query``. - Preserves the left-side projection and appends only columns from the explicitly - joined ``target_table`` as ``{projection_prefix}__{column}`` aliases. + Preserves the left-side projection and appends only columns from the + joined target as ``{projection_prefix}__{column}`` aliases. 
``allow_existing_target_projection`` is used for idempotent re-joins: when a join call contributes no new join edges, all target-prefixed columns may already @@ -304,7 +306,6 @@ def _apply_join_projection( if expr.output_name not in {"", "*"} } - target_columns = schema.tables[target_table]["columns"] target_output_names = { f"{projection_prefix}__{column_name}" for column_name in target_columns.keys() } @@ -376,11 +377,108 @@ def _apply_join( _apply_join_projection( query, - schema=schema, left_table=left_table, - target_table=right_table, + target_columns=schema.tables[right_table]["columns"], target_qualifier=target_qualifier, projection_prefix=projection_prefix, allow_existing_target_projection=not join_params, ) return query + + +def _rewrite_on_qualifiers( + on_expr: sge.Expression, + target_table: str, + internal_alias: str, +) -> sge.Expression: + """Rewrite column qualifiers in the ON expression that reference the target table. + + The user writes ``on="users.id = orders.user_id"`` using logical table names. + Once the target is aliased internally, those references must point to the alias + so the SQL engine can resolve them. + """ + on_expr = on_expr.copy() + for col in on_expr.find_all(sge.Column): + table_node = col.args.get("table") + if isinstance(table_node, sge.Identifier) and table_node.name == target_table: + table_node.set("this", internal_alias) + return on_expr + + +def _apply_explicit_join( + expression: sge.Query, + *, + target: Optional["Relation"] = None, + target_table: str, + target_dataset_name: Optional[str], + target_columns: TTableSchemaColumns, + on: Union[str, sge.Expression], + projection_prefix: str, + kind: "TJoinType", + destination_dialect: TSqlGlotDialect, +) -> sge.Select: + """Apply an explicit-ON join to ``expression`` and return the new query. + + Args: + expression: Left-side query to join onto. + target: Right-hand Relation object (if transformed/subquery), or None for + string / base-table targets. 
+ target_table: Bare table name for schema lookups and projection. + target_dataset_name: Foreign dataset qualifier, or None for local. + target_columns: Columns from the right-hand side for projection. + on: Join condition as a SQL string or sqlglot expression. + projection_prefix: Prefix for appended column aliases. + kind: SQL join type. + destination_dialect: Dialect for parsing string ON expressions. + """ + query = expression.copy() + if not isinstance(query, sge.Select): + raise ValueError(f"Join query `{query}` must be an SQL SELECT statement.") + + internal_alias = f"_dlt_jt_{projection_prefix}" + + # build target expression + target_expr: sge.Expression + if target is not None: + # transformed Relation -> subquery (preserves WHERE, SELECT, etc.) + target_expr = sge.Subquery( + this=target.sqlglot_expression, + alias=sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)), + ) + else: + # base-table target (Relation with _table_name, or str) + table_node_args: dict[str, sge.Expression] = { + "this": sge.to_identifier(target_table, quoted=True), + "alias": sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)), + } + if target_dataset_name: + table_node_args["db"] = sge.to_identifier(target_dataset_name, quoted=False) + target_expr = sge.Table(**table_node_args) + + if isinstance(on, str): + on_expr = sqlglot.parse_one(on, dialect=destination_dialect) + else: + on_expr = on + + on_expr = _rewrite_on_qualifiers(on_expr, target_table, internal_alias) + + join_expr = sge.Join(this=target_expr, kind=kind.upper()).on(on_expr) + query = query.join(join_expr) + + from_expr = query.args.get("from_") or query.args.get("from") + if not isinstance(from_expr, sge.From) or not isinstance(from_expr.this, sge.Table): + raise ValueError( + "Cannot apply explicit join: left-side query must have a base table " + "in its FROM clause (not a subquery or derived table)." 
+ ) + left_table = from_expr.this.this.name + + _apply_join_projection( + query, + left_table=left_table, + target_columns=target_columns, + target_qualifier=internal_alias, + projection_prefix=projection_prefix, + allow_existing_target_projection=False, + ) + return query diff --git a/dlt/dataset/dataset.py b/dlt/dataset/dataset.py index f336a4e128..451a6e45d6 100644 --- a/dlt/dataset/dataset.py +++ b/dlt/dataset/dataset.py @@ -69,6 +69,8 @@ def __init__( self._default_schema_name: Optional[str] = None self._resolved: bool = False + self._foreign_schemas: Dict[str, List[dlt.Schema]] = {} + self._sql_client: SqlClientBase[Any] = None self._opened_sql_client: SqlClientBase[Any] = None self._table_client: SupportsOpenTables = None @@ -158,14 +160,22 @@ def _ipython_key_completions_(self) -> list[str]: """Provide table names as completion suggestion in interactive environments.""" return self.tables + def _add_foreign_schemas(self, dataset_name: str, schemas: Sequence[dlt.Schema]) -> None: + """Register schemas from a foreign dataset for cross-dataset joins.""" + if dataset_name == self.dataset_name: + return + self._foreign_schemas[dataset_name] = list(schemas) + @property def sqlglot_schema(self) -> SQLGlotSchema: """SQLGlot schema of the dataset derived from all dlt schemas.""" # NOTE: no cache for now, it is probably more expensive to compute the current schema hash # to see wether this is stale than to compute a new sqlglot schema - return lineage.create_sqlglot_schema( - {self.dataset_name: list(self.schemas)}, dialect=self.destination_dialect - ) + schema_map: Dict[str, Sequence[dlt.Schema]] = { + self.dataset_name: list(self.schemas), + **self._foreign_schemas, + } + return lineage.create_sqlglot_schema(schema_map, dialect=self.destination_dialect) @property def destination_dialect(self) -> TSqlGlotDialect: diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 102bf42307..1162278cb3 100644 --- a/dlt/dataset/relation.py +++ 
b/dlt/dataset/relation.py @@ -1,6 +1,5 @@ from __future__ import annotations from collections.abc import Collection, Sequence -from functools import partial from typing import ( overload, Union, @@ -35,10 +34,10 @@ from dlt.common.typing import Self, TSortOrder, TypedDict from dlt.common.exceptions import ValueErrorWithKnownValues from dlt.dataset import lineage -from dlt.destinations.sql_client import SqlClientBase, WithSqlClient +from dlt.destinations.sql_client import SqlClientBase, WithSchemas, WithSqlClient from dlt.destinations.queries import bind_query, build_select_expr from dlt.common.destination.dataset import SupportsDataAccess -from dlt.dataset._join import _apply_join +from dlt.dataset._join import _apply_join, _apply_explicit_join if TYPE_CHECKING: @@ -254,14 +253,16 @@ def to_sql(self, pretty: bool = False, *, _raw_query: bool = False) -> str: query = self.sqlglot_expression else: _, _qualified_query = _get_relation_output_columns_schema(self) + + def _expand(table_name: str, db: Optional[str] = None) -> list[str]: + return self.sql_client.make_qualified_table_name_path( + table_name, quote=False, casefold=False, dataset_name=db + ) + query = bind_query( qualified_query=_qualified_query, sqlglot_schema=self._dataset.sqlglot_schema, - expand_table_name=partial( - self.sql_client.make_qualified_table_name_path, - quote=False, - casefold=False, - ), + expand_table_name=_expand, casefold_identifier=self.sql_client.capabilities.casefold_identifier, ) @@ -358,33 +359,46 @@ def order_by(self, column_name: str, direction: TSortOrder = "asc") -> Self: rel._sqlglot_expression = rel.sqlglot_expression.order_by(order_expr) return rel + @overload def join( self, other: str | Self, *, kind: TJoinType = "inner", alias: Optional[str] = None, - ) -> Self: - """Join this relation to another table using dlt schema references. + ) -> Self: ... 
- Join conditions are discovered automatically from the schema's reference - chain (parent/child/root relationships created by dlt during loading). - Both the current relation and ``other`` must be base-table relations - (i.e., created via ``dataset[table_name]``, not transformed with - ``.select()``/``.where()`` etc.). + @overload + def join( + self, + other: str | Self, + on: str | sge.Expression, + *, + kind: TJoinType = "inner", + alias: Optional[str] = None, + ) -> Self: ... - This method is designed for the common case of navigating dlt's - built-in table hierarchy. For more complex join scenarios — such as - custom join predicates, joining on non-reference columns, self-joins, - or multi-way joins with mixed conditions — use ``Relation.to_ibis()`` - to obtain an ibis table expression and construct the join manually:: + def join( + self, + other: str | Self, + on: str | sge.Expression | None = None, + *, + kind: TJoinType = "inner", + alias: Optional[str] = None, + ) -> Self: + """Join this relation to another table. - t1 = dataset["orders"].to_ibis() - t2 = dataset["products"].to_ibis() - joined = t1.join(t2, t1.product_id == t2.id, how="left") + Without ``on``, join conditions are discovered automatically from the + schema's reference chain (parent/child/root relationships created by + dlt during loading). With ``on``, an explicit join predicate is used + instead — this also enables cross-dataset joins. Args: - other: Table name or base-table relation to join. + other: Table name or Relation to join. For cross-dataset joins, + pass a Relation from a different ``dlt.Dataset``. + on: Explicit join condition as an SQL string or sqlglot expression. + Required for cross-dataset joins and joins between tables + without dlt schema references. kind: Type of SQL join: ``"inner"``, ``"left"``, ``"right"``, or ``"full"``. alias: Projection prefix for the joined table's columns. Columns @@ -392,53 +406,152 @@ def join( the target table name. 
Returns: - A new relation with the join(s) applied and the target table's + A new relation with the join applied and the target table's columns appended to the projection. Raises: - ValueError: If schema references between the two tables cannot be - resolved, or if either relation is not join-eligible. + ValueError: If the join cannot be resolved. + + Example:: + + # auto join (schema references) + dataset["orders"].join("users") + + # explicit ON + dataset["orders"].join("users", on="orders._dlt_parent_id = users._dlt_id") + + # cross-dataset join + local["orders"].join( + foreign["products"], + on="orders.product_id = products.id", + ) """ if alias == "": raise ValueError("`alias` must be a non-empty string when provided.") - if not self._table_name: - raise ValueError("This relation has no base table to resolve references.") - - if isinstance(other, dlt.Relation): - # TODO: remove once we allow cross-dataset joins - if not ( - self._dataset.is_same_physical_destination(other._dataset) - and self._dataset.dataset_name == other._dataset.dataset_name - ): - raise ValueError( - "Cannot join relations from different datasets: " - f"'{other._dataset.dataset_name}' vs '{self._dataset.dataset_name}'" - ) - target_table = other._table_name - if not target_table: - raise ValueError(f"Relation `{other}` has no base table to resolve references.") - else: - target_table = other + target_dataset, target_table, target_columns = self._resolve_join_target(other, on=on) - if not target_table or not isinstance(target_table, str): - raise ValueError("`other` must be a table name or a base table relation.") - if target_table not in self._dataset.schema.tables: - raise ValueError(f"Table `{target_table}` not found in dataset schema") + # self-join detection + if target_table == self._table_name and target_dataset is self._dataset: + raise ValueError("Self-joins are not supported.") projection_prefix = alias or target_table - query = _apply_join( - self.sqlglot_expression, - 
schema=self._dataset.schema, - left_table=self._table_name, - right_table=target_table, - projection_prefix=projection_prefix, - kind=kind, - ) + + if on is None: + if not self._table_name: + raise ValueError("This relation has no base table to resolve references.") + if target_dataset is not self._dataset: + raise ValueError("`on` is required when joining relations from different datasets.") + if target_table not in self._dataset.schema.tables: + raise ValueError(f"Table `{target_table}` not found in dataset schema") + query = _apply_join( + self.sqlglot_expression, + schema=self._dataset.schema, + left_table=self._table_name, + right_table=target_table, + projection_prefix=projection_prefix, + kind=kind, + ) + else: + if target_dataset is not self._dataset: + self._dataset._add_foreign_schemas( + target_dataset.dataset_name, + list(target_dataset.schemas), + ) + # pass Relation as target when it's been transformed so it + # becomes a subquery (preserving WHERE, SELECT, LIMIT, etc.) + subquery_rhs: Optional[Relation] = ( + other if isinstance(other, dlt.Relation) and other._query is not None else None + ) + query = _apply_explicit_join( + self.sqlglot_expression, + target=subquery_rhs, + target_table=target_table, + target_dataset_name=( + target_dataset.dataset_name if target_dataset is not self._dataset else None + ), + target_columns=target_columns, + on=on, + projection_prefix=projection_prefix, + kind=kind, + destination_dialect=self.destination_dialect, + ) + rel = self.__copy__() rel._sqlglot_expression = query return rel + def _resolve_join_target( + self, + other: Union[str, Self], + *, + on: Union[str, sge.Expression, None] = None, + ) -> tuple[dlt.Dataset, str, TTableSchemaColumns]: + """Resolve the target dataset, table name, and columns for a join. + + Returns: + Tuple of (target_dataset, target_table_name, target_columns). 
+ """ + if isinstance(other, dlt.Relation): + target_dataset = other._dataset + + # physical destination check + if target_dataset is not self._dataset: + if not self._dataset.is_same_physical_destination(target_dataset): + raise ValueError( + "Cannot join relations from different physical destinations: " + f"'{target_dataset.dataset_name}' vs '{self._dataset.dataset_name}'" + ) + # cross-dataset filesystem not supported + if isinstance(self.sql_client, WithSchemas): + raise ValueError( + "Cross-dataset joins are not supported on filesystem destinations." + ) + + target_table = other._table_name + is_transformed = other._query is not None + if target_table and not is_transformed: + # pristine base-table Relation: look up columns from schema + target_columns = _find_table_columns(target_dataset.schemas, target_table) + elif target_table and is_transformed: + # transformed Relation that still tracks its origin table + # (e.g., .where(), .select()); use its actual output columns + target_columns = other.columns_schema + else: + # no base table at all (e.g., from .query()) + if on is None: + raise ValueError(f"Relation `{other}` has no base table to resolve references.") + target_table = _extract_subquery_alias(other) + target_columns = other.columns_schema + elif isinstance(other, str): + if "." in other: + ds_name, tbl_name = other.split(".", 1) + if ds_name == self._dataset.dataset_name: + target_dataset = self._dataset + elif ds_name in self._dataset._foreign_schemas: + target_dataset = self._dataset + # columns come from the foreign schemas already registered + target_table = tbl_name + target_columns = _find_table_columns( + self._dataset._foreign_schemas[ds_name], tbl_name + ) + return target_dataset, target_table, target_columns + else: + raise ValueError( + f"Dataset `{ds_name}` is not registered. Pass a Relation from the " + "foreign dataset to automatically register its schema." 
+ ) + target_table = tbl_name + target_columns = _find_table_columns(target_dataset.schemas, target_table) + else: + target_dataset = self._dataset + target_table = other + target_columns = _find_table_columns(target_dataset.schemas, target_table) + else: + raise ValueError("`other` must be a table name or a base table relation.") + + return target_dataset, target_table, target_columns + # NOTE we currently force to have one column selected; we could be more flexible # and rewrite the query to compute the AGG of all selected columns # `SELECT AGG(col1), AGG(col2), ... FROM table`` @@ -877,3 +990,22 @@ def _add_load_id_via_parent_key(relation: dlt.Relation) -> dlt.Relation: rel = relation.__copy__() rel._sqlglot_expression = query return rel + + +def _find_table_columns(schemas: Sequence[dlt.Schema], table_name: str) -> TTableSchemaColumns: + """Find the columns schema for a table across a sequence of schemas.""" + for schema in schemas: + if table_name in schema.tables: + return schema.tables[table_name]["columns"] + raise ValueError(f"Table `{table_name}` not found in dataset schema") + + +def _extract_subquery_alias(relation: dlt.Relation) -> str: + """Extract a stable alias for a transformed Relation without a base table.""" + expr = relation.sqlglot_expression + from_expr = expr.args.get("from_") or expr.args.get("from") + if isinstance(from_expr, sge.From) and isinstance(from_expr.this, sge.Table): + table_id = from_expr.this.this + if isinstance(table_id, sge.Identifier): + return table_id.name + return "subquery" diff --git a/dlt/destinations/impl/clickhouse/sql_client.py b/dlt/destinations/impl/clickhouse/sql_client.py index d94611aa8f..9e264e11e8 100644 --- a/dlt/destinations/impl/clickhouse/sql_client.py +++ b/dlt/destinations/impl/clickhouse/sql_client.py @@ -287,14 +287,21 @@ def catalog_name(self, quote: bool = True, casefold: bool = True) -> Optional[st return database_name def make_qualified_table_name_path( - self, table_name: Optional[str], quote: 
bool = True, casefold: bool = True + self, + table_name: Optional[str], + quote: bool = True, + casefold: bool = True, + dataset_name: Optional[str] = None, ) -> List[str]: # get catalog and dataset - path = super().make_qualified_table_name_path(None, quote=quote, casefold=casefold) + path = super().make_qualified_table_name_path( + None, quote=quote, casefold=casefold, dataset_name=dataset_name + ) + effective_dataset = dataset_name or self.dataset_name if table_name: # table name combines dataset name and table name - if self.dataset_name: - table_name = f"{self.dataset_name}{self.config.dataset_table_separator}{table_name}" + if effective_dataset: + table_name = f"{effective_dataset}{self.config.dataset_table_separator}{table_name}" else: # without dataset just use the table name pass diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 1e924d4e65..5119bc51ee 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -337,19 +337,23 @@ def create_table(self, table_obj: sa.Table) -> None: table_obj.create(self._current_connection) def make_qualified_table_name_path( - self, table_name: Optional[str], quote: bool = True, casefold: bool = True + self, + table_name: Optional[str], + quote: bool = True, + casefold: bool = True, + dataset_name: Optional[str] = None, ) -> List[str]: path: List[str] = [] # no catalog for sqlalchemy if catalog_name := self.catalog_name(quote=quote, casefold=casefold): path.append(catalog_name) - dataset_name = self.dataset_name + effective_dataset = dataset_name or self.dataset_name if self.dialect.requires_name_normalize and casefold: # type: ignore[attr-defined] - dataset_name = str(self.dialect.normalize_name(dataset_name)) # type: ignore[func-returns-value] + effective_dataset = str(self.dialect.normalize_name(effective_dataset)) # type: ignore[func-returns-value] if quote: - dataset_name = 
self.dialect.identifier_preparer.quote_identifier(dataset_name) # type: ignore[attr-defined] - path.append(dataset_name) + effective_dataset = self.dialect.identifier_preparer.quote_identifier(effective_dataset) # type: ignore[attr-defined] + path.append(effective_dataset) if table_name: if self.dialect.requires_name_normalize and casefold: # type: ignore[attr-defined] table_name = str(self.dialect.normalize_name(table_name)) # type: ignore[func-returns-value] diff --git a/dlt/destinations/queries.py b/dlt/destinations/queries.py index e171e2a5e8..c3ef72afb5 100644 --- a/dlt/destinations/queries.py +++ b/dlt/destinations/queries.py @@ -1,5 +1,4 @@ -from functools import partial -from typing import Any, List +from typing import Any, List, Optional import sqlglot.expressions as sge from sqlglot.schema import Schema as SQLGlotSchema @@ -20,12 +19,16 @@ def _normalize_query( TODO: remove after next dlthub release """ + + def _expand(table_name: str, db: Optional[str] = None) -> List[str]: + return sql_client.make_qualified_table_name_path( + table_name, quote=False, casefold=False, dataset_name=db + ) + return bind_query( qualified_query, sqlglot_schema, - expand_table_name=partial( - sql_client.make_qualified_table_name_path, quote=False, casefold=False - ), + expand_table_name=_expand, casefold_identifier=casefold_identifier, ) diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index c5198724b6..c77bb5f324 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -225,20 +225,27 @@ def make_qualified_table_name( # TODO make it a staticmethod to avoid passing SQLClient instances all around def make_qualified_table_name_path( - self, table_name: Optional[str], quote: bool = True, casefold: bool = True + self, + table_name: Optional[str], + quote: bool = True, + casefold: bool = True, + dataset_name: Optional[str] = None, ) -> List[str]: """Returns a list with path components leading from catalog to table_name. 
Used to construct fully qualified names. `table_name` is optional. + + Args: + dataset_name: Override the default dataset name for cross-dataset references. """ path: List[str] = [] if catalog_name := self.catalog_name(quote=quote, casefold=casefold): path.append(catalog_name) - dataset_name = self.dataset_name + effective_dataset = dataset_name or self.dataset_name if casefold: - dataset_name = self.capabilities.casefold_identifier(self.dataset_name) + effective_dataset = self.capabilities.casefold_identifier(effective_dataset) if quote: - dataset_name = self.capabilities.escape_identifier(dataset_name) - path.append(dataset_name) + effective_dataset = self.capabilities.escape_identifier(effective_dataset) + path.append(effective_dataset) if table_name: if casefold: table_name = self.capabilities.casefold_identifier(table_name) diff --git a/tests/dataset/conftest.py b/tests/dataset/conftest.py index e2c34178f0..ec9ad4cf66 100644 --- a/tests/dataset/conftest.py +++ b/tests/dataset/conftest.py @@ -8,9 +8,12 @@ from tests.dataset.utils import ( LOAD_0_STATS, LOAD_1_STATS, + TCrossDsFixture, TLoadsFixture, annotated_references, crm, + inventory, + relational_tables, ) from tests.utils import ( auto_test_run_context, @@ -86,6 +89,47 @@ def dataset_with_loads( raise ValueError(f"Unknown dataset fixture: {request.param}") +@pytest.fixture(scope="module") +def dataset_with_relational_tables(module_tmp_path: pathlib.Path) -> dlt.Dataset: + pipeline = dlt.pipeline( + pipeline_name="relational_tables", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(str(module_tmp_path / "relational.db")), + dev_mode=True, + ) + pipeline.run(relational_tables()) + return pipeline.dataset() + + +@pytest.fixture(scope="module") +def cross_dataset_duckdb(module_tmp_path: pathlib.Path) -> TCrossDsFixture: + db_path = str(module_tmp_path / "cross_dataset.db") + + # dataset A: CRM data (users + orders) + pipeline_a = dlt.pipeline( + 
pipeline_name="cross_ds_a", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="crm_data", + dev_mode=True, + ) + source_a = crm(0) + source_a.root_key = True + pipeline_a.run(source_a) + + # dataset B: inventory data (products + warehouses) + pipeline_b = dlt.pipeline( + pipeline_name="cross_ds_b", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="inv_data", + dev_mode=True, + ) + pipeline_b.run(inventory()) + + return pipeline_a.dataset(), pipeline_b.dataset() + + @pytest.fixture(scope="module") def dataset_with_annotated_references(module_tmp_path: pathlib.Path) -> dlt.Dataset: pipeline = dlt.pipeline( diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index d83c0ef39c..4747895e0f 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -13,7 +13,7 @@ _to_join_ref, ) from dlt.dataset.relation import TJoinType -from tests.dataset.utils import TLoadsFixture +from tests.dataset.utils import TCrossDsFixture, TLoadsFixture class _ColumnRef(TypedDict): @@ -208,8 +208,7 @@ def test_resolve_reference_chain_rejects_self_join(dataset_with_loads: TLoadsFix @pytest.mark.parametrize("dataset_with_loads", ["with_root_key"], indirect=True) -def test_join_rejects_cross_dataset(dataset_with_loads: TLoadsFixture) -> None: - """Test that joining relations from different datasets raises an error.""" +def test_join_rejects_different_physical_destination(dataset_with_loads: TLoadsFixture) -> None: dataset, _, _ = dataset_with_loads with tempfile.TemporaryDirectory() as tmp: @@ -227,12 +226,11 @@ def other_data(): pipeline.run([other_data]) other_dataset = pipeline.dataset() - # Try to join with a relation from the other dataset rel = dataset.table("users") other_rel = other_dataset.table("other_data") - with pytest.raises(ValueError, match="different datasets"): - 
rel.join(other_rel) + with pytest.raises(ValueError, match="different physical destinations"): + rel.join(other_rel, on="users._dlt_id = other_data._dlt_id") @pytest.mark.parametrize( @@ -300,7 +298,7 @@ def test_resolve_reference_chain_rejects_unrelated_tables( pytest.param( lambda ds: ds.table("users"), "users", - "Cannot join a table to itself", + "Self-joins are not supported", id="self-join", ), pytest.param( @@ -920,3 +918,202 @@ def test_join_columns_schema_resolves_with_name_mutating_normalizer( for column_name in normalized_dataset.schema.tables[normalized_right]["columns"].keys() } assert expected_right_aliases.issubset(schema_cols) + + +def test_explicit_on_joins_relational_tables( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + joined = ds.table("customers").join("orders", on="customers.customer_id = orders.customer_id") + df = joined.df() + assert len(df) == 4 + assert "orders__amount" in df.columns + assert list(df["orders__amount"]) == [50.0, 75.0, 200.0, 30.0] + + # auto join should fail: no dlt reference between customers and orders + with pytest.raises(ValueError, match="Unable to resolve reference chain"): + ds.table("customers").join("orders") + + +def test_explicit_on_accepts_sqlglot_expression( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + on_expr = sge.EQ( + this=sge.Column( + table=sge.to_identifier("customers"), + this=sge.to_identifier("country_code"), + ), + expression=sge.Column( + table=sge.to_identifier("countries"), + this=sge.to_identifier("code"), + ), + ) + joined = ds.table("customers").join("countries", on=on_expr) + df = joined.df() + assert len(df) == 3 + assert list(df["countries__name"]) == ["Germany", "France", "Germany"] + + +def test_explicit_on_non_eq_predicate( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + joined = ds.table("customers").join( + "orders", + 
on="customers.customer_id = orders.customer_id AND orders.amount > 50", + ) + df = joined.df() + assert len(df) == 2 + assert list(df["orders__amount"]) == [75.0, 200.0] + + +def test_explicit_on_projection_prefix( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + joined = ds.table("customers").join( + "orders", on="customers.customer_id = orders.customer_id", alias="o" + ) + selects = joined.sqlglot_expression.selects + right_aliases = {expr.output_name for expr in selects if expr.output_name.startswith("o__")} + assert right_aliases + expected = {f"o__{col}" for col in ds.schema.tables["orders"]["columns"].keys()} + assert right_aliases == expected + + +def test_explicit_on_rejects_empty_alias( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + with pytest.raises(ValueError, match="must be a non-empty string"): + ds.table("customers").join( + "orders", on="customers.customer_id = orders.customer_id", alias="" + ) + + +def test_explicit_on_rejects_self_join( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + with pytest.raises(ValueError, match="Self-joins are not supported"): + ds.table("customers").join( + "customers", + on="customers.customer_id = customers.customer_id", + alias="c2", + ) + + +def test_explicit_on_with_filtered_rhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + expensive_orders = ds.table("orders").where("amount", "gt", 50.0) + joined = ds.table("customers").join( + expensive_orders, on="customers.customer_id = orders.customer_id" + ) + df = joined.df() + assert len(df) == 2 + assert list(df["name"]) == ["Alice", "Bob"] + assert list(df["orders__amount"]) == [75.0, 200.0] + + +def test_explicit_on_with_projected_rhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + narrow_orders = 
ds.table("orders").select("order_id", "customer_id") + joined = ds.table("customers").join( + narrow_orders, on="customers.customer_id = orders.customer_id" + ) + df = joined.df() + assert len(df) == 4 + rhs_cols = {c for c in df.columns if c.startswith("orders__")} + assert rhs_cols == {"orders__order_id", "orders__customer_id"} + assert "orders__amount" not in df.columns + + +def test_cross_dataset_join_registers_foreign_schemas( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + """Cross-dataset join registers the foreign dataset's schemas.""" + ds_a, ds_b = cross_dataset_duckdb + users = ds_a.table("users") + purchases = ds_b.table("purchases") + + assert ds_b.dataset_name not in ds_a._foreign_schemas + + users.join(purchases, on="users.id = purchases.user_id") + + assert ds_b.dataset_name in ds_a._foreign_schemas + foreign_schemas = ds_a._foreign_schemas[ds_b.dataset_name] + assert len(foreign_schemas) >= 1 + + +def test_cross_dataset_join_requires_on( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_a, ds_b = cross_dataset_duckdb + users = ds_a.table("users") + purchases = ds_b.table("purchases") + + with pytest.raises(ValueError, match="`on` is required"): + users.join(purchases) + + +def test_cross_dataset_join_e2e( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_a, ds_b = cross_dataset_duckdb + users = ds_a.table("users") + purchases = ds_b.table("purchases") + + joined = users.join(purchases, on="users.id = purchases.user_id") + df = joined.df() + assert len(df) == 3 + assert "purchases__sku" in df.columns + assert "purchases__quantity" in df.columns + assert sorted(df["purchases__sku"]) == ["G-001", "W-001", "W-001"] + + +_MATCHED = { + "purchases__purchase_id": [1, 2, 3], + "purchases__user_id": [1, 1, 2], + "purchases__sku": ["W-001", "G-001", "W-001"], + "purchases__quantity": [2, 1, 1], + "name": ["Alice", "Alice", "Bob"], +} +_MATCHED_PLUS_ORPHAN = { + "purchases__purchase_id": [1, 2, 3, 4], + "purchases__user_id": [1, 1, 2, 
99], + "purchases__sku": ["W-001", "G-001", "W-001", "D-001"], + "purchases__quantity": [2, 1, 1, 5], + "name": ["Alice", "Alice", "Bob", None], # orphan's matched user name is NULL +} + + +@pytest.mark.parametrize( + "kind,expected", + [ + # inner + left: both users match, so LEFT adds no extra rows + pytest.param("inner", _MATCHED, id="inner"), + pytest.param("left", _MATCHED, id="left"), + # right + full: orphan purchase appears with NULL on the user side + pytest.param("right", _MATCHED_PLUS_ORPHAN, id="right"), + pytest.param("full", _MATCHED_PLUS_ORPHAN, id="full"), + ], +) +def test_cross_dataset_join_kind_parameter( + cross_dataset_duckdb: TCrossDsFixture, + kind: TJoinType, + expected: dict[str, list[Any]], +) -> None: + ds_a, ds_b = cross_dataset_duckdb + users = ds_a.table("users") + purchases = ds_b.table("purchases") + + joined = users.join(purchases, on="users.id = purchases.user_id", kind=kind) + df = joined.df() + + for col, expected_values in expected.items(): + assert list(df[col]) == expected_values, f"column `{col}` mismatch" diff --git a/tests/dataset/utils.py b/tests/dataset/utils.py index f866289a3d..46fdf242bb 100644 --- a/tests/dataset/utils.py +++ b/tests/dataset/utils.py @@ -51,8 +51,44 @@ class AccountMembershipRow(TypedDict): user_name: str +class WarehouseRow(TypedDict): + warehouse_id: int + city: str + + +class InventoryItemRow(TypedDict): + sku: str + warehouse_id: int + quantity: int + + +class PurchaseRow(TypedDict): + purchase_id: int + user_id: int + sku: str + quantity: int + + +class CustomerRow(TypedDict): + customer_id: int + name: str + country_code: str + + +class CustomerOrderRow(TypedDict): + order_id: int + customer_id: int + amount: float + + +class CountryRow(TypedDict): + code: str + name: str + + TLoadStats = dict[str, int] TLoadsFixture = tuple[dlt.Dataset, tuple[str, str], tuple[TLoadStats, TLoadStats]] +TCrossDsFixture = tuple[dlt.Dataset, dlt.Dataset] USERS_DATA_0: list[UserRow] = [ @@ -144,6 +180,88 @@ def 
products(batch_idx: int): return [users(i), products(i)] +WAREHOUSES: list[WarehouseRow] = [ + {"warehouse_id": 1, "city": "Berlin"}, + {"warehouse_id": 2, "city": "Paris"}, +] + +INVENTORY_ITEMS: list[InventoryItemRow] = [ + {"sku": "W-001", "warehouse_id": 1, "quantity": 50}, + {"sku": "G-001", "warehouse_id": 2, "quantity": 30}, + {"sku": "D-001", "warehouse_id": 1, "quantity": 10}, +] + +PURCHASES: list[PurchaseRow] = [ + {"purchase_id": 1, "user_id": 1, "sku": "W-001", "quantity": 2}, + {"purchase_id": 2, "user_id": 1, "sku": "G-001", "quantity": 1}, + {"purchase_id": 3, "user_id": 2, "sku": "W-001", "quantity": 1}, + {"purchase_id": 4, "user_id": 99, "sku": "D-001", "quantity": 5}, +] + + +@dlt.source +def inventory(): + @dlt.resource(name="warehouses") + def warehouses(): + yield WAREHOUSES + + @dlt.resource( + name="inventory_items", + references=[ + { + "referenced_table": "warehouses", + "columns": ["warehouse_id"], + "referenced_columns": ["warehouse_id"], + } + ], + ) + def inventory_items(): + yield INVENTORY_ITEMS + + @dlt.resource(name="purchases") + def purchases(): + yield PURCHASES + + return [warehouses(), inventory_items(), purchases()] + + +CUSTOMERS: list[CustomerRow] = [ + {"customer_id": 1, "name": "Alice", "country_code": "DE"}, + {"customer_id": 2, "name": "Bob", "country_code": "FR"}, + {"customer_id": 3, "name": "Charlie", "country_code": "DE"}, +] + +CUSTOMER_ORDERS: list[CustomerOrderRow] = [ + {"order_id": 100, "customer_id": 1, "amount": 50.0}, + {"order_id": 101, "customer_id": 1, "amount": 75.0}, + {"order_id": 102, "customer_id": 2, "amount": 200.0}, + {"order_id": 103, "customer_id": 3, "amount": 30.0}, +] + +COUNTRIES: list[CountryRow] = [ + {"code": "DE", "name": "Germany"}, + {"code": "FR", "name": "France"}, + {"code": "ES", "name": "Spain"}, +] + + +@dlt.source +def relational_tables(): + @dlt.resource(name="customers") + def customers(): + yield CUSTOMERS + + @dlt.resource(name="orders") + def orders(): + yield 
CUSTOMER_ORDERS + + @dlt.resource(name="countries") + def countries(): + yield COUNTRIES + + return [customers(), orders(), countries()] + + @dlt.source def annotated_references(): @dlt.resource(name="users") diff --git a/tests/destinations/test_queries.py b/tests/destinations/test_queries.py index a41c65e880..757e614701 100644 --- a/tests/destinations/test_queries.py +++ b/tests/destinations/test_queries.py @@ -1,5 +1,4 @@ -from functools import partial -from typing import cast +from typing import List, Optional, cast import duckdb import pytest @@ -130,12 +129,16 @@ def test_normalize_query(): ) with duckdb_destination_client.sql_client as sql_client: + + def _expand(table_name: str, db: Optional[str] = None) -> List[str]: + return sql_client.make_qualified_table_name_path( + table_name, quote=False, casefold=False, dataset_name=db + ) + normalized_query_expr = bind_query( qualified_query=cast(sge.Query, qualified_query_expr), sqlglot_schema=sqlglot_schema, - expand_table_name=partial( - sql_client.make_qualified_table_name_path, quote=False, casefold=False - ), + expand_table_name=_expand, casefold_identifier=sql_client.capabilities.casefold_identifier, ) normalized_query = normalized_query_expr.sql() From 5b2b25922e2a63a5e1c11a854651f41e84d2760a Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Thu, 23 Apr 2026 17:57:40 +0200 Subject: [PATCH 02/30] change the dataset equality check --- dlt/dataset/dataset.py | 6 ++++++ dlt/dataset/relation.py | 14 +++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/dlt/dataset/dataset.py b/dlt/dataset/dataset.py index 451a6e45d6..71494efeb0 100644 --- a/dlt/dataset/dataset.py +++ b/dlt/dataset/dataset.py @@ -160,6 +160,12 @@ def _ipython_key_completions_(self) -> list[str]: """Provide table names as completion suggestion in interactive environments.""" return self.tables + def _is_same_dataset(self, other: dlt.Dataset) -> bool: + """Whether `other` represents the same logical dataset.""" + # TODO 
currently only compares dataset name, + # once harderned, conside implementing __eq__ based on this method + return self.dataset_name == other.dataset_name + def _add_foreign_schemas(self, dataset_name: str, schemas: Sequence[dlt.Schema]) -> None: """Register schemas from a foreign dataset for cross-dataset joins.""" if dataset_name == self.dataset_name: diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 1162278cb3..9d6b4e1037 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -431,8 +431,10 @@ def join( target_dataset, target_table, target_columns = self._resolve_join_target(other, on=on) + is_same_dataset = self._dataset._is_same_dataset(target_dataset) + # self-join detection - if target_table == self._table_name and target_dataset is self._dataset: + if target_table == self._table_name and is_same_dataset: raise ValueError("Self-joins are not supported.") projection_prefix = alias or target_table @@ -440,7 +442,7 @@ def join( if on is None: if not self._table_name: raise ValueError("This relation has no base table to resolve references.") - if target_dataset is not self._dataset: + if not is_same_dataset: raise ValueError("`on` is required when joining relations from different datasets.") if target_table not in self._dataset.schema.tables: raise ValueError(f"Table `{target_table}` not found in dataset schema") @@ -453,7 +455,7 @@ def join( kind=kind, ) else: - if target_dataset is not self._dataset: + if not is_same_dataset: self._dataset._add_foreign_schemas( target_dataset.dataset_name, list(target_dataset.schemas), @@ -467,9 +469,7 @@ def join( self.sqlglot_expression, target=subquery_rhs, target_table=target_table, - target_dataset_name=( - target_dataset.dataset_name if target_dataset is not self._dataset else None - ), + target_dataset_name=(None if is_same_dataset else target_dataset.dataset_name), target_columns=target_columns, on=on, projection_prefix=projection_prefix, @@ -496,7 +496,7 @@ def _resolve_join_target( 
target_dataset = other._dataset # physical destination check - if target_dataset is not self._dataset: + if not self._dataset._is_same_dataset(target_dataset): if not self._dataset.is_same_physical_destination(target_dataset): raise ValueError( "Cannot join relations from different physical destinations: " From 4f43ef74dc51eeeb75e0e862f6631b480929b1af Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 26 May 2026 18:17:47 +0200 Subject: [PATCH 03/30] change foreign schema handling --- dlt/dataset/_join.py | 188 ++++++++++++++++++++-------- dlt/dataset/dataset.py | 16 +-- dlt/dataset/relation.py | 41 +++--- tests/dataset/conftest.py | 29 +++++ tests/dataset/test_relation_join.py | 66 ++++++---- tests/dataset/utils.py | 21 ++++ 6 files changed, 260 insertions(+), 101 deletions(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 860177b3cd..b3d9b1e1eb 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import reduce -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Optional, Sequence, Set, TypeVar, Union import sqlglot import sqlglot.expressions as sge @@ -16,6 +16,8 @@ _INTERMEDIATE_JOIN_ALIAS_PREFIX = "_dlt_int_t" +_TExpr = TypeVar("_TExpr", bound=sge.Expression) + class _JoinRef(TypedDict): """A resolved join step from currently attached table to a target table.""" @@ -154,25 +156,36 @@ def _build_join_condition_from_pairs( return reduce(lambda x, y: sge.And(this=x, expression=y), conditions) +def _identifier_name(node: Any) -> Optional[str]: + """Return the string name of an sqlglot identifier-or-string node.""" + if isinstance(node, sge.Identifier): + return node.name + if isinstance(node, str): + return node + return None + + +def _subquery_alias_name(subquery: sge.Subquery) -> Optional[str]: + """Return the alias name of a subquery, or `None`.""" + alias_expr = subquery.args.get("alias") + if not 
isinstance(alias_expr, sge.TableAlias): + return None + return _identifier_name(alias_expr.this) + + def _extract_table_qualifier(table_expr: sge.Expression) -> Optional[tuple[str, str]]: if not isinstance(table_expr, sge.Table): return None - table_identifier = table_expr.args.get("this") - if isinstance(table_identifier, sge.Identifier): - table_name = table_identifier.name - elif isinstance(table_identifier, str): - table_name = table_identifier - else: + table_name = _identifier_name(table_expr.args.get("this")) + if table_name is None: return None alias_expr = table_expr.args.get("alias") if isinstance(alias_expr, sge.TableAlias): - alias_identifier = alias_expr.this - if isinstance(alias_identifier, sge.Identifier): - return table_name, alias_identifier.name - if isinstance(alias_identifier, str): - return table_name, alias_identifier + alias_name = _identifier_name(alias_expr.this) + if alias_name is not None: + return table_name, alias_name return table_name, table_name @@ -268,20 +281,18 @@ def _discover_join_params( return joins, target_qualifier -def _normalize_left_projection(query: sge.Select, left_table: str) -> list[sge.Expression]: - """Qualify the left-side projection so an added JOIN can't leak right-side columns. - - Bare `Star` becomes `.*`; unqualified `Column`s get their - `table` set to ``. 
- """ - origin_identifier = sge.to_identifier(left_table, quoted=False) +def _normalize_left_projection( + query: sge.Select, left_source_qualifier: str +) -> list[sge.Expression]: + """Qualify the left-side projection so an added JOIN cannot leak right-side columns.""" + origin_identifier = sge.to_identifier(left_source_qualifier, quoted=False) normalized: list[sge.Expression] = [] for expr in query.selects: if isinstance(expr, sge.Star): - normalized.append(sge.Column(table=origin_identifier, this=sge.Star())) + normalized.append(sge.Column(table=origin_identifier.copy(), this=sge.Star())) elif isinstance(expr, sge.Column) and expr.args.get("table") is None: expr_copy = expr.copy() - expr_copy.set("table", origin_identifier) + expr_copy.set("table", origin_identifier.copy()) normalized.append(expr_copy) else: normalized.append(expr) @@ -291,7 +302,7 @@ def _normalize_left_projection(query: sge.Select, left_table: str) -> list[sge.E def _apply_join_projection( query: sge.Select, *, - left_table: str, + left_source_qualifier: str, target_columns: TTableSchemaColumns, target_qualifier: str, projection_prefix: str, @@ -307,7 +318,7 @@ def _apply_join_projection( exist in the left projection and should be accepted as a no-op instead of raising a collision error. 
""" - normalized_left_expressions = _normalize_left_projection(query, left_table) + normalized_left_expressions = _normalize_left_projection(query, left_source_qualifier) existing_projection_column_names = { expr.output_name @@ -387,10 +398,12 @@ def _apply_join( ) query = query.join(join_expr) + left_source_qualifier = _left_source_qualifier(query) or left_table + if project: _apply_join_projection( query, - left_table=left_table, + left_source_qualifier=left_source_qualifier, target_columns=schema.get_table_columns(right_table), target_qualifier=target_qualifier, projection_prefix=projection_prefix, @@ -399,26 +412,82 @@ def _apply_join( else: # filter-only join: qualify the left projection so a bare `*` does not # expand across the joined table and leak right-side columns at runtime. - query.set("expressions", _normalize_left_projection(query, left_table)) + query.set("expressions", _normalize_left_projection(query, left_source_qualifier)) return query -def _rewrite_on_qualifiers( +def _qualify_physical_tables_with_dataset(expression: _TExpr, dataset_name: str) -> _TExpr: + """Bind every physical table reference in ``expression`` to ``dataset_name``.""" + expression = expression.copy() + cte_names = {cte.alias_or_name for cte in expression.find_all(sge.CTE)} + db_identifier = sge.to_identifier(dataset_name, quoted=False) + for table in expression.find_all(sge.Table): + if table.name in cte_names: + continue + if table.args.get("db"): + continue + table.set("db", db_identifier.copy()) + return expression + + +def _left_source_qualifier(query: sge.Query) -> Optional[str]: + """Return the qualifier used to reference the FROM source (alias or table name).""" + from_expr = query.args.get("from_") or query.args.get("from") + if not isinstance(from_expr, sge.From): + return None + from_this = from_expr.this + if isinstance(from_this, sge.Table): + result = _extract_table_qualifier(from_this) + return result[1] if result else None + if isinstance(from_this, 
sge.Subquery): + return _subquery_alias_name(from_this) + return None + + +def _collect_left_qualifiers(query: sge.Query) -> Set[str]: + """Collect qualifiers (table names or aliases) the LHS exposes to ON binding.""" + qualifiers: Set[str] = set() + sources: list[sge.Expression] = [] + + from_expr = query.args.get("from_") or query.args.get("from") + if isinstance(from_expr, sge.From) and from_expr.this is not None: + sources.append(from_expr.this) + + for join in query.args.get("joins") or []: + if join.this is not None: + sources.append(join.this) + + for source in sources: + if isinstance(source, sge.Table): + result = _extract_table_qualifier(source) + if result: + qualifiers.add(result[1]) + elif isinstance(source, sge.Subquery): + alias_name = _subquery_alias_name(source) + if alias_name is not None: + qualifiers.add(alias_name) + + return qualifiers + + +def _bind_on_predicate( on_expr: sge.Expression, - target_table: str, - internal_alias: str, + *, + left_qualifiers: Set[str], + right_qualifiers: Set[str], + right_internal_alias: str, ) -> sge.Expression: - """Rewrite column qualifiers in the ON expression that reference the target table. - - The user writes ``on="users.id = orders.user_id"`` using logical table names. - Once the target is aliased internally, those references must point to the alias - so the SQL engine can resolve them. 
- """ + """Rewrite RHS-side column qualifiers in ``on_expr`` to the internal RHS alias.""" on_expr = on_expr.copy() for col in on_expr.find_all(sge.Column): table_node = col.args.get("table") - if isinstance(table_node, sge.Identifier) and table_node.name == target_table: - table_node.set("this", internal_alias) + if not isinstance(table_node, sge.Identifier): + continue + qualifier = table_node.name + if qualifier in left_qualifiers: + continue + if qualifier in right_qualifiers: + col.set("table", sge.to_identifier(right_internal_alias, quoted=False)) return on_expr @@ -433,6 +502,7 @@ def _apply_explicit_join( projection_prefix: str, kind: "TJoinType", destination_dialect: TSqlGlotDialect, + left_dataset_name: str, ) -> sge.Select: """Apply an explicit-ON join to ``expression`` and return the new query. @@ -441,29 +511,44 @@ def _apply_explicit_join( target: Right-hand Relation object (if transformed/subquery), or None for string / base-table targets. target_table: Bare table name for schema lookups and projection. - target_dataset_name: Foreign dataset qualifier, or None for local. + target_dataset_name: Dataset name for the right-hand side. target_columns: Columns from the right-hand side for projection. on: Join condition as a SQL string or sqlglot expression. projection_prefix: Prefix for appended column aliases. kind: SQL join type. destination_dialect: Dialect for parsing string ON expressions. + left_dataset_name: Dataset name for the left-hand side. """ query = expression.copy() if not isinstance(query, sge.Select): raise ValueError(f"Join query `{query}` must be an SQL SELECT statement.") + # bind LHS physical tables to the LHS dataset before composing the join. 
+ # otherwise, adding the RHS dataset to the resolver makes bare LHS tables + # ambiguous + query = _qualify_physical_tables_with_dataset(query, left_dataset_name) + + from_expr = query.args.get("from_") or query.args.get("from") + if not isinstance(from_expr, sge.From) or not isinstance(from_expr.this, sge.Table): + raise ValueError( + "Cannot apply explicit join: left-side query must have a base table " + "in its FROM clause (not a subquery or derived table)." + ) + left_source_qualifier = _left_source_qualifier(query) or from_expr.this.name + internal_alias = f"_dlt_jt_{projection_prefix}" - # build target expression target_expr: sge.Expression - if target is not None: - # transformed Relation -> subquery (preserves WHERE, SELECT, etc.) + if target is not None and target._query is not None: + # transformed Relation: embed as subquery + rhs_inner = target.sqlglot_expression + if target_dataset_name: + rhs_inner = _qualify_physical_tables_with_dataset(rhs_inner, target_dataset_name) target_expr = sge.Subquery( - this=target.sqlglot_expression, + this=rhs_inner, alias=sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)), ) else: - # base-table target (Relation with _table_name, or str) table_node_args: dict[str, sge.Expression] = { "this": sge.to_identifier(target_table, quoted=True), "alias": sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)), @@ -477,22 +562,21 @@ def _apply_explicit_join( else: on_expr = on - on_expr = _rewrite_on_qualifiers(on_expr, target_table, internal_alias) + left_qualifiers = _collect_left_qualifiers(query) + right_qualifiers = {target_table, projection_prefix} + on_expr = _bind_on_predicate( + on_expr, + left_qualifiers=left_qualifiers, + right_qualifiers=right_qualifiers, + right_internal_alias=internal_alias, + ) join_expr = sge.Join(this=target_expr, kind=kind.upper()).on(on_expr) query = query.join(join_expr) - from_expr = query.args.get("from_") or query.args.get("from") - if not 
isinstance(from_expr, sge.From) or not isinstance(from_expr.this, sge.Table): - raise ValueError( - "Cannot apply explicit join: left-side query must have a base table " - "in its FROM clause (not a subquery or derived table)." - ) - left_table = from_expr.this.this.name - _apply_join_projection( query, - left_table=left_table, + left_source_qualifier=left_source_qualifier, target_columns=target_columns, target_qualifier=internal_alias, projection_prefix=projection_prefix, diff --git a/dlt/dataset/dataset.py b/dlt/dataset/dataset.py index 6338727872..02f3df0bda 100644 --- a/dlt/dataset/dataset.py +++ b/dlt/dataset/dataset.py @@ -70,8 +70,6 @@ def __init__( self._default_schema_name: Optional[str] = None self._resolved: bool = False - self._foreign_schemas: Dict[str, List[dlt.Schema]] = {} - self._sql_client: SqlClientBase[Any] = None self._opened_sql_client: SqlClientBase[Any] = None self._table_client: SupportsOpenTables = None @@ -167,22 +165,14 @@ def _is_same_dataset(self, other: dlt.Dataset) -> bool: # once harderned, conside implementing __eq__ based on this method return self.dataset_name == other.dataset_name - def _add_foreign_schemas(self, dataset_name: str, schemas: Sequence[dlt.Schema]) -> None: - """Register schemas from a foreign dataset for cross-dataset joins.""" - if dataset_name == self.dataset_name: - return - self._foreign_schemas[dataset_name] = list(schemas) - @property def sqlglot_schema(self) -> SQLGlotSchema: """SQLGlot schema of the dataset derived from all dlt schemas.""" # NOTE: no cache for now, it is probably more expensive to compute the current schema hash # to see wether this is stale than to compute a new sqlglot schema - schema_map: Dict[str, Sequence[dlt.Schema]] = { - self.dataset_name: list(self.schemas), - **self._foreign_schemas, - } - return lineage.create_sqlglot_schema(schema_map, dialect=self.destination_dialect) + return lineage.create_sqlglot_schema( + {self.dataset_name: list(self.schemas)}, 
dialect=self.destination_dialect + ) @property def destination_dialect(self) -> TSqlGlotDialect: diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 83709aee87..87a4e1f46d 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -17,6 +17,7 @@ from sqlglot import maybe_parse from sqlglot.optimizer.merge_subqueries import merge_subqueries from sqlglot.expressions import ExpOrStr as SqlglotExprOrStr +from sqlglot.schema import Schema as SQLGlotSchema import sqlglot.expressions as sge @@ -116,6 +117,7 @@ def __init__( self._sqlglot_expression: sge.Query = None self._schema: Optional[TTableSchemaColumns] = None self._incremental_ctx: Optional[_RelationIncrementalContext] = None + self._foreign_schemas: dict[str, list[dlt.Schema]] = {} def df(self, *args: Any, **kwargs: Any) -> pd.DataFrame | None: with self._cursor() as cursor: @@ -271,7 +273,7 @@ def _expand(table_name: str, db: Optional[str] = None) -> list[str]: query = bind_query( qualified_query=_qualified_query, - sqlglot_schema=self._dataset.sqlglot_schema, + sqlglot_schema=self._relation_sqlglot_schema(), expand_table_name=_expand, casefold_identifier=self.sql_client.capabilities.casefold_identifier, ) @@ -465,11 +467,6 @@ def join( kind=kind, ) else: - if not is_same_dataset: - self._dataset._add_foreign_schemas( - target_dataset.dataset_name, - list(target_dataset.schemas), - ) # pass Relation as target when it's been transformed so it # becomes a subquery (preserving WHERE, SELECT, LIMIT, etc.) 
subquery_rhs: Optional[Relation] = ( @@ -485,10 +482,21 @@ def join( projection_prefix=projection_prefix, kind=kind, destination_dialect=self.destination_dialect, + left_dataset_name=self._dataset.dataset_name, ) rel = self.__copy__() rel._sqlglot_expression = query + + # carry the RHS relation's foreign datasets + if isinstance(other, dlt.Relation): + for ds_name, schemas in other._foreign_schemas.items(): + if ds_name == self._dataset.dataset_name: + continue + rel._foreign_schemas[ds_name] = list(schemas) + if not is_same_dataset: + rel._foreign_schemas[target_dataset.dataset_name] = list(target_dataset.schemas) + return rel def _resolve_join_target( @@ -538,21 +546,18 @@ def _resolve_join_target( ds_name, tbl_name = other.split(".", 1) if ds_name == self._dataset.dataset_name: target_dataset = self._dataset - elif ds_name in self._dataset._foreign_schemas: + target_table = tbl_name + target_columns = _find_table_columns(target_dataset.schemas, target_table) + elif ds_name in self._foreign_schemas: target_dataset = self._dataset - # columns come from the foreign schemas already registered target_table = tbl_name - target_columns = _find_table_columns( - self._dataset._foreign_schemas[ds_name], tbl_name - ) + target_columns = _find_table_columns(self._foreign_schemas[ds_name], tbl_name) return target_dataset, target_table, target_columns else: raise ValueError( f"Dataset `{ds_name}` is not registered. Pass a Relation from the " "foreign dataset to automatically register its schema." 
) - target_table = tbl_name - target_columns = _find_table_columns(target_dataset.schemas, target_table) else: target_dataset = self._dataset target_table = other @@ -981,8 +986,16 @@ def __copy__(self) -> Self: rel = self.__class__(dataset=self._dataset, query=self.sqlglot_expression) rel._table_name = self._table_name rel._incremental_ctx = self._incremental_ctx + rel._foreign_schemas = {k: list(v) for k, v in self._foreign_schemas.items()} return rel + def _relation_sqlglot_schema(self) -> SQLGlotSchema: + schema_map: dict[str, Sequence[dlt.Schema]] = { + self._dataset.dataset_name: list(self._dataset.schemas), + **self._foreign_schemas, + } + return lineage.create_sqlglot_schema(schema_map, dialect=self.destination_dialect) + def _get_relation_output_columns_schema( relation: dlt.Relation, @@ -994,7 +1007,7 @@ def _get_relation_output_columns_schema( columns_schema, normalized_query = lineage.compute_columns_schema( # use dlt schema compliant query so lineage will work correctly on non case folded identifiers relation.sqlglot_expression, - relation._dataset.sqlglot_schema, + relation._relation_sqlglot_schema(), dialect=relation.destination_dialect, infer_sqlglot_schema=infer_sqlglot_schema, allow_anonymous_columns=allow_anonymous_columns, diff --git a/tests/dataset/conftest.py b/tests/dataset/conftest.py index 0b0bfd88d8..42c0e34426 100644 --- a/tests/dataset/conftest.py +++ b/tests/dataset/conftest.py @@ -13,6 +13,7 @@ annotated_references, crm, inventory, + marketing_users, relational_tables, ) from tests.utils import ( @@ -130,6 +131,34 @@ def cross_dataset_duckdb(module_tmp_path: pathlib.Path) -> TCrossDsFixture: return pipeline_a.dataset(), pipeline_b.dataset() +@pytest.fixture(scope="module") +def same_named_cross_dataset_duckdb(module_tmp_path: pathlib.Path) -> TCrossDsFixture: + # Below both datasets have a `users` table, but with different schema and data + db_path = str(module_tmp_path / "same_named_cross_dataset.db") + + pipeline_a = dlt.pipeline( + 
pipeline_name="same_name_cross_ds_a", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="crm_data", + dev_mode=True, + ) + source_a = crm(0) + source_a.root_key = True + pipeline_a.run(source_a) + + pipeline_b = dlt.pipeline( + pipeline_name="same_name_cross_ds_b", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="marketing_data", + dev_mode=True, + ) + pipeline_b.run(marketing_users()) + + return pipeline_a.dataset(), pipeline_b.dataset() + + @pytest.fixture(scope="module") def dataset_with_annotated_references(module_tmp_path: pathlib.Path) -> dlt.Dataset: pipeline = dlt.pipeline( diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index a52271ea57..d3b7be5eb9 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -1033,22 +1033,26 @@ def test_explicit_on_with_projected_rhs( assert "orders__amount" not in df.columns -def test_cross_dataset_join_registers_foreign_schemas( +def test_cross_dataset_join( cross_dataset_duckdb: TCrossDsFixture, ) -> None: - """Cross-dataset join registers the foreign dataset's schemas.""" ds_a, ds_b = cross_dataset_duckdb users = ds_a.table("users") purchases = ds_b.table("purchases") - assert ds_b.dataset_name not in ds_a._foreign_schemas - - users.join(purchases, on="users.id = purchases.user_id") + joined = users.join(purchases, on="users.id = purchases.user_id") - assert ds_b.dataset_name in ds_a._foreign_schemas - foreign_schemas = ds_a._foreign_schemas[ds_b.dataset_name] + assert ds_b.dataset_name in joined._foreign_schemas + assert ds_b.dataset_name not in users._foreign_schemas + foreign_schemas = joined._foreign_schemas[ds_b.dataset_name] assert len(foreign_schemas) >= 1 + df = joined.df() + assert len(df) == 3 + assert "purchases__sku" in df.columns + assert "purchases__quantity" in df.columns + assert 
sorted(df["purchases__sku"]) == ["G-001", "W-001", "W-001"] + def test_cross_dataset_join_requires_on( cross_dataset_duckdb: TCrossDsFixture, @@ -1061,21 +1065,6 @@ def test_cross_dataset_join_requires_on( users.join(purchases) -def test_cross_dataset_join_e2e( - cross_dataset_duckdb: TCrossDsFixture, -) -> None: - ds_a, ds_b = cross_dataset_duckdb - users = ds_a.table("users") - purchases = ds_b.table("purchases") - - joined = users.join(purchases, on="users.id = purchases.user_id") - df = joined.df() - assert len(df) == 3 - assert "purchases__sku" in df.columns - assert "purchases__quantity" in df.columns - assert sorted(df["purchases__sku"]) == ["G-001", "W-001", "W-001"] - - _MATCHED = { "purchases__purchase_id": [1, 2, 3], "purchases__user_id": [1, 1, 2], @@ -1127,3 +1116,36 @@ def test_join_does_not_project_incomplete_target_columns( assert rows is not None # 3 products inner-joined to 2 categories on category_id → 3 rows assert len(rows) == 3 + + +def test_cross_dataset_join_with_transformed_rhs_preserves_foreign_dataset_binding( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_a, ds_b = cross_dataset_duckdb + users = ds_a.table("users") + filtered_purchases = ds_b.table("purchases").where("quantity", "gt", 1) + + joined = users.join(filtered_purchases, on="users.id = purchases.user_id").order_by("id") + df = joined.df() + + assert len(df) == 1 + assert list(df["name"]) == ["Alice"] + assert list(df["purchases__purchase_id"]) == [1] + assert list(df["purchases__sku"]) == ["W-001"] + assert list(df["purchases__quantity"]) == [2] + + +def test_cross_dataset_join_with_same_table_names_keeps_sources_unambiguous( + same_named_cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_a, ds_b = same_named_cross_dataset_duckdb + crm_users = ds_a.query("SELECT * FROM users AS crm_users") + marketing_users = ds_b.table("users") + + joined = crm_users.join(marketing_users, on="crm_users.id = users.id", alias="marketing") + df = joined.order_by("id").df() + + 
assert len(df) == 2 + assert list(df["id"]) == [1, 2] + assert list(df["name"]) == ["Alice", "Bob"] + assert list(df["marketing__segment"]) == ["pro", "free"] diff --git a/tests/dataset/utils.py b/tests/dataset/utils.py index 46fdf242bb..3f44e41ebc 100644 --- a/tests/dataset/utils.py +++ b/tests/dataset/utils.py @@ -69,6 +69,11 @@ class PurchaseRow(TypedDict): quantity: int +class MarketingUserRow(TypedDict): + id: int + segment: str + + class CustomerRow(TypedDict): customer_id: int name: str @@ -199,6 +204,13 @@ def products(batch_idx: int): ] +MARKETING_USERS: list[MarketingUserRow] = [ + {"id": 1, "segment": "pro"}, + {"id": 2, "segment": "free"}, + {"id": 4, "segment": "trial"}, +] + + @dlt.source def inventory(): @dlt.resource(name="warehouses") @@ -225,6 +237,15 @@ def purchases(): return [warehouses(), inventory_items(), purchases()] +@dlt.source +def marketing_users(): + @dlt.resource(name="users") + def users(): + yield MARKETING_USERS + + return [users()] + + CUSTOMERS: list[CustomerRow] = [ {"customer_id": 1, "name": "Alice", "country_code": "DE"}, {"customer_id": 2, "name": "Bob", "country_code": "FR"}, From 07baafa4b13e5f88ca9aac15169923d93c075810 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 26 May 2026 19:25:56 +0200 Subject: [PATCH 04/30] draft docs --- .../general-usage/dataset-access/dataset.md | 51 ++++++++++--- .../dataset_snippets/dataset_snippets.py | 73 ++++++++++++++++++- tests/dataset/utils.py | 2 +- 3 files changed, 111 insertions(+), 15 deletions(-) diff --git a/docs/website/docs/general-usage/dataset-access/dataset.md b/docs/website/docs/general-usage/dataset-access/dataset.md index 64d66bc49d..a2223bb6f2 100644 --- a/docs/website/docs/general-usage/dataset-access/dataset.md +++ b/docs/website/docs/general-usage/dataset-access/dataset.md @@ -175,35 +175,64 @@ See [Incremental transformations](../../hub/features/transformations/index.md#in ### Join related tables -The `join()` method follows relationships already defined in 
the dlt schema. It can resolve direct schema references between tables as well as multi-hop parent/child paths when one table is an ancestor or descendant of the other. This makes `join()` well suited for navigating nested tables created by dlt and tables connected by explicit references. Joined columns are appended from the target table only and are prefixed with the target table name, or with the alias you provide. +The `join()` method appends a related table to the current relation. It works in two modes: + +- [Auto-join via schema references](#auto-join-via-schema-references): dlt builds the join condition from parent/child relationships dlt creates during loading, plus any `references` you declared on a resource. +- [Explicit `on` predicate](#explicit-join-condition): when you pass `on=`, you write the join condition yourself. Use it for any join the auto mode cannot do, including [joins across two datasets](#cross-dataset-joins) on the same physical destination. By default, `join()` creates an `inner` join. Use `kind="left"`, `"right"`, or `"full"` to choose another SQL join type. -When you do not specify an alias, joined columns use the joined table name as their prefix. For example, `dataset["users"].join("users__orders")` adds columns such as `users__orders__order_id`. When you pass `alias="orders"`, the same column is projected as `orders__order_id` instead. Use `alias` to make result columns easier to read or to avoid output name conflicts. +When you do not specify an `alias`, joined columns use the target table name as their prefix. For example, `dataset["users"].join("users__orders")` adds columns such as `users__orders__order_id`. When you pass `alias="orders"`, the same column is projected as `orders__order_id` instead. Use `alias` to make result columns easier to read or to avoid output name conflicts. + +#### Auto-join via schema references + +With no `on` argument, `join()` follows relationships already defined in the dlt schema. 
It can resolve direct schema references between tables as well as multi-hop parent/child paths when one table is an ancestor or descendant of the other. This makes the auto mode well suited for navigating nested tables created by dlt and tables connected by explicit references. Joined columns are appended from the target table only and are prefixed with the target table name, or with the alias you provide. -**Limits:** `join()` only works when dlt can resolve a supported schema-defined path between the current relation's base table and the target table. Both sides must be base-table relations, for example `dataset["users"].join("users__orders")`. You cannot call `join()` after transforming a relation with methods such as `select()` or `where()`. +The auto mode works on relations from `dataset[name]` or `dataset.table(name)`, and on relations chained from them with `where()`, `select()`, `order_by()`, and similar methods. It does not work on relations from `dataset.query("...")`; use the explicit form below for those cases. -`join()` does not support: +The auto mode does not support: - arbitrary join conditions -- joins on columns that are not defined as schema references +- joins on columns that you pick yourself - self-joins - joins across different datasets - joins between tables that are only related indirectly through a shared ancestor or another non-linear schema path -In practice, this means `join()` supports ancestor/descendant navigation, but not general graph traversal across the schema. 
- -For example: +In practice, this means the auto mode supports ancestor/descendant navigation, but not general graph traversal across the schema: - `dataset["users__orders__items"].join("users")` works because `users` is an ancestor in the nested table hierarchy - joining two sibling tables just because both descend from `users` does not work -- joining two tables on a custom predicate such as `orders.customer_email = customers.email` does not work unless that relationship is defined in the schema +- joining two tables on a custom predicate such as `orders.customer_email = customers.email` does not work: use the explicit form below instead + +When the auto mode needs intermediate tables to reach the target, those tables are used only to build the join path. Their columns are not added to the result automatically. Only columns from the explicitly joined target table are appended. + +#### Explicit join condition + +Pass `on=` to write the join condition yourself, as a SQL string or a `sqlglot` expression. Use this form whenever the auto mode does not work for your tables: for example, when joining two top-level tables that dlt did not create from a parent/child relationship. + + + +The right-hand side can be a table name, a table relation, or a relation you already transformed with `select()`, `where()`, etc. When you pass a transformed relation, its filters and column selection carry over to the joined result. + +The left-hand side can be a table relation, a relation chained from one with `where()`, `select()`, `order_by()`, and similar methods, or a `dataset.query("...")` that reads from a single table. + +Self-joins are not supported, even with explicit `on`. For self-joins, multi-way joins with mixed conditions, or fully programmatic join construction, use [Ibis](#modifying-queries-with-ibis-expressions). 
+ +#### Cross-dataset joins + +When you pass `on`, the right-hand side may be a `Relation` from a different `dlt.Dataset`, as long as both datasets share the same physical destination — for example, two pipelines that write to the same DuckDB file, or to the same database server with different dataset/schema names. + + + +Cross-dataset joins: -When `join()` needs intermediate tables to reach the target, those tables are used only to build the join path. Their columns are not added to the result automatically. Only columns from the explicitly joined target table are appended. +- require an explicit `on` condition: the auto mode does not span datasets +- are rejected when the two relations live on different physical destinations +- are not supported on filesystem destinations -For arbitrary join logic, use Ibis. +When two datasets share table names that would otherwise clash in the join (for example, both have a `users` table), give one side a stable table alias in your SQL — for example, with `dataset.query("SELECT * FROM users AS alias_name")` — and refer to that alias in `on`. Without an alias, `join()` cannot tell the two tables apart and will raise. ### Chain operations diff --git a/docs/website/docs/general-usage/dataset-access/dataset_snippets/dataset_snippets.py b/docs/website/docs/general-usage/dataset-access/dataset_snippets/dataset_snippets.py index f0adcdbaaa..e9cebb5393 100644 --- a/docs/website/docs/general-usage/dataset-access/dataset_snippets/dataset_snippets.py +++ b/docs/website/docs/general-usage/dataset-access/dataset_snippets/dataset_snippets.py @@ -235,6 +235,73 @@ def users() -> Generator[list[dict[str, Any]], None, None]: # @@@DLT_SNIPPET_END join_related_tables +def join_explicit_on_snippet(dataset: dlt.Dataset) -> None: + # @@@DLT_SNIPPET_START join_explicit_on + # `customers` and `purchases` are two top-level tables connected + # by `purchases.customer_id` and `customers.id`. 
There is no schema + # reference between them, so we provide the join condition ourselves. + customers_with_purchases = dataset["customers"].join( + "purchases", + on="customers.id = purchases.customer_id", + kind="left", + ) + + # the right-hand side can also be a transformed relation; its filters + # are preserved when it is embedded as a subquery. + big_purchases = dataset["purchases"].where("quantity", "gt", 3) + customers_with_big_purchases = dataset["customers"].join( + big_purchases, + on="customers.id = purchases.customer_id", + alias="big", + ) + + df = customers_with_big_purchases.select("name", "big__id", "big__quantity").df() + # @@@DLT_SNIPPET_END join_explicit_on + + +def join_cross_dataset_snippet(tmp_path: Path) -> None: + # @@@DLT_SNIPPET_START join_cross_dataset + # two pipelines that write to the same DuckDB file under different + # dataset names — both datasets share one physical destination. + db_path = str(tmp_path / "shop.duckdb") + + crm_pipeline = dlt.pipeline( + pipeline_name="crm", + destination=dlt.destinations.duckdb(db_path), + dataset_name="crm_data", + ) + crm_pipeline.run( + [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], + table_name="users", + ) + + sales_pipeline = dlt.pipeline( + pipeline_name="sales", + destination=dlt.destinations.duckdb(db_path), + dataset_name="sales_data", + ) + sales_pipeline.run( + [ + {"id": 10, "user_id": 1, "sku": "W-001", "quantity": 2}, + {"id": 11, "user_id": 1, "sku": "G-001", "quantity": 1}, + {"id": 12, "user_id": 2, "sku": "W-001", "quantity": 1}, + ], + table_name="purchases", + ) + + crm = crm_pipeline.dataset() + sales = sales_pipeline.dataset() + + # pass the right-hand side as a Relation from the other dataset; + # `on` is required for cross-dataset joins. 
+ users_with_purchases = crm["users"].join( + sales["purchases"], + on="users.id = purchases.user_id", + ) + df = users_with_purchases.df() + # @@@DLT_SNIPPET_END join_cross_dataset + + def chain_operations_snippet(dataset: dlt.Dataset) -> None: customers_relation = dataset.table("customers") @@ -350,9 +417,9 @@ def custom_sql_snippet(dataset: dlt.Dataset) -> None: # @@@DLT_SNIPPET_START custom_sql # Join 'customers' and 'purchases' tables and filter by quantity query = """ - SELECT * - FROM customers - JOIN purchases + SELECT * + FROM customers + JOIN purchases ON customers.id = purchases.customer_id WHERE purchases.quantity > 1 """ diff --git a/tests/dataset/utils.py b/tests/dataset/utils.py index 3f44e41ebc..2bf1918417 100644 --- a/tests/dataset/utils.py +++ b/tests/dataset/utils.py @@ -70,7 +70,7 @@ class PurchaseRow(TypedDict): class MarketingUserRow(TypedDict): - id: int + id: int # noqa: A003 segment: str From 7ca43159ba232fae9e30e06b7962c0570d40fa39 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Thu, 28 May 2026 08:38:10 +0200 Subject: [PATCH 05/30] new failing tests + new fixtures --- tests/dataset/conftest.py | 48 +++++++- tests/dataset/test_relation_join.py | 181 +++++++++++++++++++++++++++- tests/dataset/utils.py | 23 ++++ 3 files changed, 246 insertions(+), 6 deletions(-) diff --git a/tests/dataset/conftest.py b/tests/dataset/conftest.py index 42c0e34426..3ca7be7aa5 100644 --- a/tests/dataset/conftest.py +++ b/tests/dataset/conftest.py @@ -8,9 +8,11 @@ from tests.dataset.utils import ( LOAD_0_STATS, LOAD_1_STATS, + TCrossDs3Fixture, TCrossDsFixture, TLoadsFixture, annotated_references, + billing, crm, inventory, marketing_users, @@ -107,7 +109,7 @@ def cross_dataset_duckdb(module_tmp_path: pathlib.Path) -> TCrossDsFixture: db_path = str(module_tmp_path / "cross_dataset.db") # dataset A: CRM data (users + orders) - pipeline_a = dlt.pipeline( + pipeline_crm = dlt.pipeline( pipeline_name="cross_ds_a", pipelines_dir=str(module_tmp_path / 
"pipelines_dir"), destination=dlt.destinations.duckdb(db_path), @@ -116,19 +118,55 @@ def cross_dataset_duckdb(module_tmp_path: pathlib.Path) -> TCrossDsFixture: ) source_a = crm(0) source_a.root_key = True - pipeline_a.run(source_a) + pipeline_crm.run(source_a) # dataset B: inventory data (products + warehouses) - pipeline_b = dlt.pipeline( + pipeline_inv = dlt.pipeline( pipeline_name="cross_ds_b", pipelines_dir=str(module_tmp_path / "pipelines_dir"), destination=dlt.destinations.duckdb(db_path), dataset_name="inv_data", dev_mode=True, ) - pipeline_b.run(inventory()) + pipeline_inv.run(inventory()) - return pipeline_a.dataset(), pipeline_b.dataset() + return pipeline_crm.dataset(), pipeline_inv.dataset() + + +@pytest.fixture(scope="module") +def three_way_cross_dataset_duckdb(module_tmp_path: pathlib.Path) -> TCrossDs3Fixture: + db_path = str(module_tmp_path / "three_way_cross_dataset.db") + + pipeline_crm = dlt.pipeline( + pipeline_name="three_way_ds_a", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="crm_data", + dev_mode=True, + ) + source_a = crm(0) + source_a.root_key = True + pipeline_crm.run(source_a) + + pipeline_inv = dlt.pipeline( + pipeline_name="three_way_ds_b", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="inv_data", + dev_mode=True, + ) + pipeline_inv.run(inventory()) + + pipeline_billing = dlt.pipeline( + pipeline_name="three_way_ds_c", + pipelines_dir=str(module_tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="billing_data", + dev_mode=True, + ) + pipeline_billing.run(billing()) + + return pipeline_crm.dataset(), pipeline_inv.dataset(), pipeline_billing.dataset() @pytest.fixture(scope="module") diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index d3b7be5eb9..a4df030b39 100644 --- a/tests/dataset/test_relation_join.py +++ 
b/tests/dataset/test_relation_join.py @@ -13,7 +13,7 @@ _to_join_ref, ) from dlt.dataset.relation import TJoinType -from tests.dataset.utils import TCrossDsFixture, TLoadsFixture +from tests.dataset.utils import TCrossDs3Fixture, TCrossDsFixture, TLoadsFixture class _ColumnRef(TypedDict): @@ -1149,3 +1149,182 @@ def test_cross_dataset_join_with_same_table_names_keeps_sources_unambiguous( assert list(df["id"]) == [1, 2] assert list(df["name"]) == ["Alice", "Bob"] assert list(df["marketing__segment"]) == ["pro", "free"] + + +@pytest.mark.xfail(reason="Ambiguous qualifier should be rejected") +def test_cross_dataset_same_named_join_rejects_ambiguous_on_qualifier( + same_named_cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_marketing = same_named_cross_dataset_duckdb + + with pytest.raises(ValueError): + ds_crm.table("users").join( + ds_marketing.table("users"), + on="users.id = users.id", + alias="marketing", + ) + + +def test_cross_dataset_join_chain_three_tables( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_inv = cross_dataset_duckdb + + joined = ( + ds_inv.table("purchases") + .join(ds_crm.table("users"), on="purchases.user_id = users.id") + .join("inventory_items", on="purchases.sku = inventory_items.sku") + ) + df = joined.order_by("purchase_id").df() + + # orphan purchase (user_id=99) is dropped by the inner join to users + assert len(df) == 3 + assert "purchase_id" in df.columns + assert "users__name" in df.columns + assert "inventory_items__quantity" in df.columns + assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] + assert list(df["inventory_items__quantity"]) == [50, 30, 50] + + +def test_cross_dataset_join_chain_magic_then_cross( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_inv = cross_dataset_duckdb + + joined = ( + ds_crm.table("users__orders") + .join("users") + .join(ds_inv.table("purchases"), on="users.id = purchases.user_id") + ) + df = joined.df() + + assert len(df) == 5 + assert 
"order_id" in df.columns # base, unprefixed + assert "users__name" in df.columns + assert "purchases__sku" in df.columns + assert sorted(df["users__name"]) == ["Alice", "Alice", "Alice", "Alice", "Bob"] + assert list(df["users__id"]) == list(df["purchases__user_id"]) + + +@pytest.mark.xfail( + reason=( + "Column 'inventory_items.warehouse_id' could not be resolved for table: 'inventory_items'" + ) +) +def test_cross_dataset_join_chain_four_tables( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + """Star-schema joined to three dimensions across two datasets""" + ds_crm, ds_inv = cross_dataset_duckdb + + joined = ( + ds_inv.table("purchases") + .join(ds_crm.table("users"), on="purchases.user_id = users.id") + .join("inventory_items", on="purchases.sku = inventory_items.sku") + .join("warehouses", on="inventory_items.warehouse_id = warehouses.warehouse_id") + ) + df = joined.order_by("purchase_id").df() + + assert len(df) == 3 + assert "warehouses__city" in df.columns + assert list(df["warehouses__city"]) == ["Berlin", "Paris", "Berlin"] + + +@pytest.mark.xfail(reason="Column 'users.id' could not be resolved for table: 'users'") +def test_cross_dataset_join_chain_three_datasets( + three_way_cross_dataset_duckdb: TCrossDs3Fixture, +) -> None: + ds_crm, ds_inv, ds_billing = three_way_cross_dataset_duckdb + + joined = ( + ds_inv.table("purchases") + .join(ds_crm.table("users"), on="purchases.user_id = users.id") + .join(ds_billing.table("subscriptions"), on="users.id = subscriptions.user_id") + ) + df = joined.order_by("purchase_id").df() + + assert len(df) == 3 + assert "users__name" in df.columns + assert "subscriptions__plan" in df.columns + assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] + assert list(df["subscriptions__plan"]) == ["enterprise", "enterprise", "free"] + + +def test_cross_dataset_join_chain_does_not_mutate_sources( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_inv = cross_dataset_duckdb + + purchases = 
ds_inv.table("purchases") + users = ds_crm.table("users") + inventory_items = ds_inv.table("inventory_items") + + purchases_sql = purchases.to_sql() + users_sql = users.to_sql() + inventory_items_sql = inventory_items.to_sql() + + step1 = purchases.join(users, on="purchases.user_id = users.id") + step1_sql = step1.to_sql() + + assert purchases.to_sql() == purchases_sql + assert users.to_sql() == users_sql + assert inventory_items.to_sql() == inventory_items_sql + assert step1.to_sql() == step1_sql + # check if rebuild of the first step is identical + assert purchases.join(users, on="purchases.user_id = users.id").to_sql() == step1_sql + + +def test_cross_dataset_join_chain_with_filtered_step( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_inv = cross_dataset_duckdb + + alice_purchases = ds_inv.table("purchases").where("user_id", "eq", 1) + joined = alice_purchases.join(ds_crm.table("users"), on="purchases.user_id = users.id").join( + "inventory_items", on="purchases.sku = inventory_items.sku" + ) + df = joined.order_by("purchase_id").df() + + assert len(df) == 2 + assert list(df["purchase_id"]) == [1, 2] + assert list(df["users__name"]) == ["Alice", "Alice"] + assert list(df["inventory_items__quantity"]) == [50, 30] + + +@pytest.mark.xfail(reason="unqualified where column `quantity` can't be resolved") +def test_cross_dataset_join_chain_filter_on_later_colliding_column( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_inv = cross_dataset_duckdb + + high_value = ds_inv.table("purchases").where("quantity", "gt", 1) + joined = high_value.join(ds_crm.table("users"), on="purchases.user_id = users.id").join( + "inventory_items", on="purchases.sku = inventory_items.sku" + ) + + df = joined.order_by("purchase_id").df() + assert len(df) == 1 + assert list(df["users__name"]) == ["Alice"] + assert list(df["inventory_items__quantity"]) == [50] + + +@pytest.mark.xfail(reason="Column 'mkt_users.id' could not be resolved for table: 
'mkt_users'") +def test_cross_dataset_chain_same_named_tables_disambiguated( + same_named_cross_dataset_duckdb: TCrossDsFixture, +) -> None: + """CRM and marketing both expose a `users` table.""" + ds_crm, ds_mkt = same_named_cross_dataset_duckdb + + marketing = ds_mkt.query("SELECT * FROM users AS mkt_users") + joined = ( + ds_crm.table("users__orders") + .join("users") + .join(marketing, on="users.id = mkt_users.id", alias="marketing") + ) + df = joined.order_by("order_id").df() + + assert len(df) == 3 + assert "users__name" in df.columns + assert "marketing__segment" in df.columns + assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] + assert list(df["marketing__segment"]) == ["pro", "pro", "free"] diff --git a/tests/dataset/utils.py b/tests/dataset/utils.py index 2bf1918417..dc5c40be27 100644 --- a/tests/dataset/utils.py +++ b/tests/dataset/utils.py @@ -74,6 +74,12 @@ class MarketingUserRow(TypedDict): segment: str +class SubscriptionRow(TypedDict): + subscription_id: int + user_id: int + plan: str + + class CustomerRow(TypedDict): customer_id: int name: str @@ -94,6 +100,7 @@ class CountryRow(TypedDict): TLoadStats = dict[str, int] TLoadsFixture = tuple[dlt.Dataset, tuple[str, str], tuple[TLoadStats, TLoadStats]] TCrossDsFixture = tuple[dlt.Dataset, dlt.Dataset] +TCrossDs3Fixture = tuple[dlt.Dataset, dlt.Dataset, dlt.Dataset] USERS_DATA_0: list[UserRow] = [ @@ -211,6 +218,13 @@ def products(batch_idx: int): ] +SUBSCRIPTIONS: list[SubscriptionRow] = [ + {"subscription_id": 1, "user_id": 1, "plan": "enterprise"}, + {"subscription_id": 2, "user_id": 2, "plan": "free"}, + {"subscription_id": 3, "user_id": 3, "plan": "pro"}, +] + + @dlt.source def inventory(): @dlt.resource(name="warehouses") @@ -246,6 +260,15 @@ def users(): return [users()] +@dlt.source +def billing(): + @dlt.resource(name="subscriptions") + def subscriptions(): + yield SUBSCRIPTIONS + + return [subscriptions()] + + CUSTOMERS: list[CustomerRow] = [ {"customer_id": 1, "name": "Alice", 
"country_code": "DE"}, {"customer_id": 2, "name": "Bob", "country_code": "FR"}, From 1055ab34d0a4d962511600c98674017ec8da6f6e Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Thu, 28 May 2026 14:59:45 +0200 Subject: [PATCH 06/30] Add e2e tests for Relation.join() --- tests/load/test_relation_join.py | 320 +++++++++++++++++++++++++++++++ 1 file changed, 320 insertions(+) create mode 100644 tests/load/test_relation_join.py diff --git a/tests/load/test_relation_join.py b/tests/load/test_relation_join.py new file mode 100644 index 0000000000..2a7735a640 --- /dev/null +++ b/tests/load/test_relation_join.py @@ -0,0 +1,320 @@ +"""End-to-end tests for ``Relation.join()`` across destinations.""" + +import os +from typing import Any, cast, Tuple + +import pytest + +import dlt +from dlt import Pipeline +from dlt.common.destination import Destination +from dlt.dataset.relation import TJoinType + +from tests.dataset.utils import ( + annotated_references, + crm, + inventory, + relational_tables, +) +from tests.load.lance_utils import module_lance_rest_server +from tests.load.utils import ( + DestinationTestConfiguration, + MEMORY_BUCKET, + SFTP_BUCKET, + destinations_configs, + drop_pipeline_data, +) +from tests.utils import ( + _preserve_environ, + auto_module_test_run_context, + auto_module_test_storage, + get_test_storage_root, +) + + +# TODO: same as in test_read_interfaces.py: factor out into a shared helper +@pytest.fixture( + scope="module", + params=destinations_configs( + default_sql_configs=True, + read_only_sqlclient_configs=True, + bucket_exclude=[SFTP_BUCKET, MEMORY_BUCKET], + ), + ids=lambda x: x.name, +) +def destination_config( + request: pytest.FixtureRequest, +) -> DestinationTestConfiguration: + return cast(DestinationTestConfiguration, request.param) + + +# TODO: same code in test_read_interfaces.py: factor out into a shared helper +@pytest.fixture(scope="module") +def preserve_module_environ_per_destination_config( + destination_config: 
DestinationTestConfiguration, +) -> Any: + yield from _preserve_environ() + + +# TODO: same code in test_read_interfaces.py: factor out into a shared helper +def _skip_unsupported_filesystem(destination_config: DestinationTestConfiguration) -> None: + if ( + destination_config.file_format not in ["parquet", "jsonl"] + and destination_config.destination_type == "filesystem" + ): + pytest.skip( + "filesystem read-only sqlclient requires jsonl or parquet; got" + f" {destination_config.file_format}" + ) + + +@pytest.fixture(scope="module") +def relational_pipeline( + destination_config: DestinationTestConfiguration, + module_lance_rest_server: None, + auto_module_test_storage: Any, + preserve_module_environ_per_destination_config: Any, + auto_module_test_run_context: Any, +) -> Any: + _skip_unsupported_filesystem(destination_config) + pipeline = destination_config.setup_pipeline( + "join_relational_pipeline", dataset_name="join_relational", dev_mode=True + ) + pipeline.run(relational_tables(), **destination_config.run_kwargs) + pipeline.run(annotated_references(), **destination_config.run_kwargs) + try: + yield pipeline + finally: + drop_pipeline_data(pipeline) + + +@pytest.fixture(scope="module") +def crm_pipeline( + destination_config: DestinationTestConfiguration, + module_lance_rest_server: None, + auto_module_test_storage: Any, + preserve_module_environ_per_destination_config: Any, + auto_module_test_run_context: Any, +) -> Any: + _skip_unsupported_filesystem(destination_config) + pipeline = destination_config.setup_pipeline( + "join_crm_pipeline", dataset_name="join_crm", dev_mode=True + ) + source = crm(0) + source.root_key = True + pipeline.run(source, **destination_config.run_kwargs) + pipeline.run(inventory(), **destination_config.run_kwargs) + try: + yield pipeline + finally: + drop_pipeline_data(pipeline) + + +@pytest.fixture(scope="module") +def cross_dataset_pipelines( + destination_config: DestinationTestConfiguration, + module_lance_rest_server: None, + 
auto_module_test_storage: Any, + preserve_module_environ_per_destination_config: Any, + auto_module_test_run_context: Any, +) -> Any: + """Two pipelines on the same physical destination, distinct dataset names.""" + _skip_unsupported_filesystem(destination_config) + if destination_config.destination_type == "filesystem": + pytest.skip( + "cross-dataset joins are not supported on filesystem destinations" + " (see dlt/dataset/relation.py:_resolve_join_target)" + ) + destination: Destination[Any, Any] + if destination_config.destination_type == "duckdb": + destination_config.setup() + # explicitly shared path to ensure the two pipelines see each other's datasets + shared_db = os.path.join(get_test_storage_root(), "cross_ds.duckdb") + destination = dlt.destinations.duckdb(shared_db) + else: + # assume that shared credentials + dataset_name differentiation are enough + destination = destination_config.destination_factory() + + pipeline_crm = destination_config.setup_pipeline( + "cross_crm_pipeline", + dataset_name="cross_crm", + dev_mode=True, + destination=destination, + ) + source = crm(0) + source.root_key = True + pipeline_crm.run(source, **destination_config.run_kwargs) + + pipeline_inv = destination_config.setup_pipeline( + "cross_inv_pipeline", + dataset_name="cross_inv", + dev_mode=True, + destination=destination, + ) + pipeline_inv.run(inventory(), **destination_config.run_kwargs) + + try: + yield pipeline_crm, pipeline_inv + finally: + drop_pipeline_data(pipeline_crm) + drop_pipeline_data(pipeline_inv) + + +@pytest.mark.essential +def test_magic_join_child_to_parent(crm_pipeline: Pipeline) -> None: + dataset = crm_pipeline.dataset() + df = dataset.table("users__orders").join("users").df() + + assert df is not None + assert len(df) == 3 + assert "users__name" in df.columns + assert sorted(df["users__name"].tolist()) == ["Alice", "Alice", "Bob"] + + +@pytest.mark.essential +def test_magic_join_multi_hop_to_root_via_root_key(crm_pipeline: Pipeline) -> None: + 
dataset = crm_pipeline.dataset() + df = dataset.table("users__orders__items").join("users").df() + + assert df is not None + assert len(df) == 4 + assert "users__name" in df.columns + assert sorted(df["users__name"].tolist()) == ["Alice", "Alice", "Alice", "Bob"] + + +@pytest.mark.essential +def test_explicit_on_basic(relational_pipeline: Pipeline) -> None: + dataset = relational_pipeline.dataset() + df = ( + dataset.table("customers") + .join("orders", on="customers.customer_id = orders.customer_id") + .order_by("orders__order_id") + .df() + ) + + assert df is not None + assert len(df) == 4 + assert "orders__amount" in df.columns + assert [float(x) for x in df["orders__amount"]] == [50.0, 75.0, 200.0, 30.0] + + +def test_explicit_on_composite_key(relational_pipeline: Pipeline) -> None: + dataset = relational_pipeline.dataset() + df = ( + dataset.table("account_memberships") + .join( + "accounts", + on=( + "account_memberships.account_id = accounts.account_id " + "AND account_memberships.tenant_id = accounts.tenant_id" + ), + ) + .order_by("accounts__name") + .df() + ) + + assert df is not None + assert len(df) == 3 + assert list(df["accounts__name"]) == ["Acme", "Globex", "Initech"] + + +@pytest.mark.parametrize( + "kind,expected_rows", + [ + pytest.param("inner", 3, id="inner"), + pytest.param("left", 3, id="left"), + pytest.param("right", 4, id="right"), + pytest.param("full", 4, id="full"), + ], +) +def test_join_kind_matrix( + relational_pipeline: Pipeline, kind: TJoinType, expected_rows: int +) -> None: + dataset = relational_pipeline.dataset() + df = ( + dataset.table("customers") + .join( + "countries", + on="customers.country_code = countries.code", + kind=kind, + ) + .df() + ) + + assert df is not None + assert len(df) == expected_rows + + +def test_chained_three_table_join(relational_pipeline: Pipeline) -> None: + dataset = relational_pipeline.dataset() + df = ( + dataset.table("customers") + .join("orders", on="customers.customer_id = 
orders.customer_id") + .join("countries", on="customers.country_code = countries.code") + .order_by("orders__order_id") + .df() + ) + + assert df is not None + assert len(df) == 4 + assert "name" in df.columns # left base column + assert "orders__amount" in df.columns # first join + assert "countries__name" in df.columns # second join + assert [float(x) for x in df["orders__amount"]] == [50.0, 75.0, 200.0, 30.0] + assert list(df["countries__name"]) == ["Germany", "Germany", "France", "Germany"] + + +def test_join_with_filtered_lhs(relational_pipeline: Pipeline) -> None: + dataset = relational_pipeline.dataset() + df = ( + dataset.table("customers") + .where("country_code", "eq", "DE") + .join("orders", on="customers.customer_id = orders.customer_id") + .order_by("orders__order_id") + .df() + ) + + assert df is not None + assert len(df) == 3 + assert list(df["name"]) == ["Alice", "Alice", "Charlie"] + assert [float(x) for x in df["orders__amount"]] == [50.0, 75.0, 30.0] + + +def test_join_alias_prefix_in_output_columns(relational_pipeline: Pipeline) -> None: + dataset = relational_pipeline.dataset() + joined = dataset.table("customers").join( + "orders", on="customers.customer_id = orders.customer_id", alias="o" + ) + df = joined.df() + + assert df is not None + o_cols = {c for c in df.columns if c.startswith("o__")} + assert o_cols, f"no `o__`-prefixed columns in {list(df.columns)}" + + expected = {f"o__{col}" for col in dataset.schema.tables["orders"]["columns"].keys()} + assert o_cols == expected + # the default prefix (table name) must not appear when `alias=` overrides it + assert not any(c.startswith("orders__") for c in df.columns) + + +def test_cross_dataset_explicit_join( + cross_dataset_pipelines: Tuple[Pipeline, Pipeline], +) -> None: + pipeline_a, pipeline_b = cross_dataset_pipelines + ds_a = pipeline_a.dataset() + ds_b = pipeline_b.dataset() + + joined = ds_a.table("users").join(ds_b.table("purchases"), on="users.id = purchases.user_id") + + sql = 
joined.to_sql() + assert ds_a.dataset_name in sql, sql + assert ds_b.dataset_name in sql, sql + + df = joined.order_by("purchases__purchase_id").df() + assert df is not None + # orphan user_id=99 dropped by INNER + assert len(df) == 3 + assert "purchases__sku" in df.columns + assert "name" in df.columns + assert list(df["name"]) == ["Alice", "Alice", "Bob"] + assert list(df["purchases__sku"]) == ["W-001", "G-001", "W-001"] From 92eaa8ba660883abaf030b1f5b970508ac77b460 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Thu, 28 May 2026 16:55:21 +0200 Subject: [PATCH 07/30] Skip tests for unsupported cross-dataset joins in SQLite and MySQL --- tests/load/ducklake/test_ducklake_client.py | 2 +- tests/load/test_relation_join.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/load/ducklake/test_ducklake_client.py b/tests/load/ducklake/test_ducklake_client.py index 8c942755ae..319ae51461 100644 --- a/tests/load/ducklake/test_ducklake_client.py +++ b/tests/load/ducklake/test_ducklake_client.py @@ -438,4 +438,4 @@ def test_ducklake_factory_instantiation() -> None: credentials = DuckLakeCredentials( "lake_catalog", catalog=catalog_credentials, - ) \ No newline at end of file + ) diff --git a/tests/load/test_relation_join.py b/tests/load/test_relation_join.py index 2a7735a640..4bd23c709f 100644 --- a/tests/load/test_relation_join.py +++ b/tests/load/test_relation_join.py @@ -125,6 +125,9 @@ def cross_dataset_pipelines( "cross-dataset joins are not supported on filesystem destinations" " (see dlt/dataset/relation.py:_resolve_join_target)" ) + if destination_config.destination_name == "sqlalchemy_sqlite": + # TODO: remove when we attach foreign datasets in sqlite + pytest.skip("sqlite cross-dataset joins require ATTACH DATABASE for both datasets") destination: Destination[Any, Any] if destination_config.destination_type == "duckdb": destination_config.setup() @@ -228,8 +231,13 @@ def test_explicit_on_composite_key(relational_pipeline: 
Pipeline) -> None: ], ) def test_join_kind_matrix( - relational_pipeline: Pipeline, kind: TJoinType, expected_rows: int + relational_pipeline: Pipeline, + destination_config: DestinationTestConfiguration, + kind: TJoinType, + expected_rows: int, ) -> None: + if kind == "full" and destination_config.destination_name == "sqlalchemy_mysql": + pytest.skip("MySQL does not support FULL JOIN") dataset = relational_pipeline.dataset() df = ( dataset.table("customers") From bb5d7b9f82590d1047f42b4e3ea15d80e94e4ce2 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Thu, 28 May 2026 17:48:49 +0200 Subject: [PATCH 08/30] more tests for explicit join --- tests/dataset/test_relation_join.py | 174 ++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index a4df030b39..6df2def5fd 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -1328,3 +1328,177 @@ def test_cross_dataset_chain_same_named_tables_disambiguated( assert "marketing__segment" in df.columns assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] assert list(df["marketing__segment"]) == ["pro", "pro", "free"] + + +def test_explicit_on_left_join_keeps_unmatched_left_rows( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + joined = ds.table("countries").join( + "customers", kind="left", on="countries.code = customers.country_code" + ) + df = joined.order_by("code").df() + assert len(df) == 4 + assert list(df["code"]) == ["DE", "DE", "ES", "FR"] + assert list(df["customers__name"]) == ["Alice", "Charlie", None, "Bob"] + es_row = df[df["code"] == "ES"].iloc[0] + assert es_row["name"] == "Spain" + customers_cols = [c for c in df.columns if c.startswith("customers__")] + assert es_row[customers_cols].isna().all() + + +def test_explicit_on_composite_key( + dataset_with_annotated_references: dlt.Dataset, +) -> None: + ds = 
dataset_with_annotated_references + joined = ds.table("account_memberships").join( + "accounts", + on=( + "account_memberships.account_id = accounts.account_id " + "AND account_memberships.tenant_id = accounts.tenant_id" + ), + ) + df = joined.order_by("accounts__name").df() + + assert len(df) == 3 + assert list(df["accounts__name"]) == ["Acme", "Globex", "Initech"] + + +def test_explicit_on_with_filtered_lhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + german_customers = ds.table("customers").where("country_code", "eq", "DE") + joined = german_customers.join("orders", on="customers.customer_id = orders.customer_id") + df = joined.df() + assert len(df) == 3 + assert list(df["name"]) == ["Alice", "Alice", "Charlie"] + assert list(df["orders__amount"]) == [50.0, 75.0, 30.0] + + +@pytest.mark.xfail(reason="ON expression must be non-empty") +def test_explicit_on_rejects_invalid_on_expression( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + with pytest.raises(ValueError, match="non-empty SQL expression"): + ds.table("customers").join("orders", on="") + + +@pytest.mark.xfail(reason="Unsupported join kind should be rejected") +def test_explicit_on_rejects_unknown_kind( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + + with pytest.raises(ValueError, match="kind=outer"): + ds.table("customers").join( + "orders", + kind="outer", # type: ignore[arg-type] + on="customers.customer_id = orders.customer_id", + ) + + +def test_explicit_on_with_projected_lhs_preserves_left_projection( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + narrow_customers = ds.table("customers").select("customer_id", "name") + joined = narrow_customers.join("orders", on="customers.customer_id = orders.customer_id") + df = joined.df() + assert len(df) == 4 + lhs_cols = {c for c in df.columns if not 
c.startswith("orders__")} + assert lhs_cols == {"customer_id", "name"} + assert "country_code" not in df.columns + assert "orders__amount" in df.columns + assert list(df["orders__amount"]) == [50.0, 75.0, 200.0, 30.0] + + +@pytest.mark.xfail(reason="Column 'o.customer_id' could not be resolved for table: 'o'") +def test_explicit_on_with_aliased_query_relations( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + customers = ds.query("SELECT * FROM customers AS c") + orders = ds.query("SELECT * FROM orders AS o") + + joined = customers.join(orders, on="c.customer_id = o.customer_id") + df = joined.order_by("o__order_id").df() + + assert len(df) == 4 + assert list(df["customer_id"]) == [1, 1, 2, 3] + assert list(df["name"]) == ["Alice", "Alice", "Bob", "Charlie"] + assert list(df["o__amount"]) == [50.0, 75.0, 200.0, 30.0] + + +def test_explicit_on_with_aggregated_rhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + order_totals = ds.query( + "SELECT customer_id, SUM(amount) AS total_amount FROM orders GROUP BY customer_id" + ) + + joined = ds.table("customers").join( + order_totals, + on="customers.customer_id = orders.customer_id", + alias="order_totals", + ) + df = joined.order_by("customer_id").df() + + assert len(df) == 3 + assert list(df["customer_id"]) == [1, 2, 3] + assert list(df["name"]) == ["Alice", "Bob", "Charlie"] + assert "order_totals__total_amount" in df.columns + assert list(df["order_totals__total_amount"]) == [125.0, 200.0, 30.0] + assert "order_totals__amount" not in df.columns + + +def test_explicit_on_projection_alias_collision_rejected( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + left = ds.query("SELECT customer_id, 1 AS orders__amount FROM customers") + + with pytest.raises(ValueError, match="conflict with existing columns"): + left.join("orders", on="customers.customer_id = 
orders.customer_id") + + +def test_cross_dataset_join_to_sql_uses_each_dataset_name( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_a, ds_b = cross_dataset_duckdb + + joined = ds_a.table("users").join( + ds_b.table("purchases"), + on="users.id = purchases.user_id", + ) + sql = joined.to_sql() + + assert f'"{ds_a.dataset_name}"."users"' in sql + assert f'"{ds_b.dataset_name}"."purchases"' in sql + assert f'"{ds_b.dataset_name}"."users"' not in sql + assert f'"{ds_a.dataset_name}"."purchases"' not in sql + + +def test_cross_dataset_join_with_aggregated_rhs( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_a, ds_b = cross_dataset_duckdb + + purchase_totals = ds_b.query( + "SELECT user_id, SUM(quantity) AS total_quantity FROM purchases GROUP BY user_id" + ) + joined = ds_a.table("users").join( + purchase_totals, + on="users.id = purchases.user_id", + alias="purchase_totals", + ) + df = joined.order_by("id").df() + + assert len(df) == 2 + assert list(df["id"]) == [1, 2] + assert list(df["name"]) == ["Alice", "Bob"] + assert "purchase_totals__total_quantity" in df.columns + assert [int(x) for x in df["purchase_totals__total_quantity"]] == [3, 1] + assert "purchase_totals__quantity" not in df.columns From 5328c020490608a25ac5c0820822548da26f35f2 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Thu, 28 May 2026 19:10:20 +0200 Subject: [PATCH 09/30] preserve dlt-namespace case in order by, group by, etc after bind_query, fixes snowflake --- dlt/common/libs/sqlglot.py | 20 +++++++ tests/destinations/test_queries.py | 93 ++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+) diff --git a/dlt/common/libs/sqlglot.py b/dlt/common/libs/sqlglot.py index 5550bd933b..b00cfc7f0f 100644 --- a/dlt/common/libs/sqlglot.py +++ b/dlt/common/libs/sqlglot.py @@ -1044,6 +1044,21 @@ def normalize_query_identifiers( return query +def _restore_alias_case_in_clauses(query: sge.Query, alias_rename_map: Dict[str, str]) -> None: + """Rewrite bare-column references 
in ORDER BY / GROUP BY / HAVING back to the original + (un-casefolded) alias case so they match SELECT aliases preserved by `bind_query`.""" + for clause_key in ("order", "group", "having"): + clause = query.args.get(clause_key) + if clause is None: + continue + for col in clause.find_all(sge.Column): + if col.args.get("table") is not None: + continue + name_node = col.this + if name_node.name in alias_rename_map: + name_node.set("this", alias_rename_map[name_node.name]) + + def bind_query( qualified_query: sge.Query, sqlglot_schema: Any, # SQLGlotSchema @@ -1113,12 +1128,17 @@ def bind_query( node.set("quoted", True) # add aliases to output selects to stay compatible with dlt schema after the query + alias_rename_map: Dict[str, str] = {} if orig_selects: for i, orig in orig_selects.items(): case_folded_orig = casefold_identifier(orig) if case_folded_orig != orig: + alias_rename_map[case_folded_orig] = orig # somehow we need to alias just top select in UNION (tested on Snowflake) sel_expr = qualified_query.selects[i] qualified_query.selects[i] = sge.alias_(sel_expr, orig, quoted=True) + if alias_rename_map: + _restore_alias_case_in_clauses(qualified_query, alias_rename_map) + return qualified_query diff --git a/tests/destinations/test_queries.py b/tests/destinations/test_queries.py index 757e614701..bc635c4447 100644 --- a/tests/destinations/test_queries.py +++ b/tests/destinations/test_queries.py @@ -13,6 +13,16 @@ from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration +_BIND_QUERY_SCHEMA = SQLGlotSchema( + { + "my_dataset": { + "customers": {"customer_id": str, "country_code": str}, + "orders": {"order_id": str, "customer_id": str}, + } + } +) + + def test_basic() -> None: stmt = build_row_counts_expr("my_table", quoted_identifiers=True) expected = ( @@ -144,3 +154,86 @@ def _expand(table_name: str, db: Optional[str] = None) -> List[str]: normalized_query = normalized_query_expr.sql() assert normalized_query == expected_normalized_query 
+ + +def _bind_query_expand(table_name: str, db: Optional[str] = None) -> List[str]: + return [db, table_name] + + +@pytest.mark.parametrize( + "clause_sql", + [ + pytest.param('ORDER BY "orders__order_id" ASC', id="order_by"), + pytest.param('GROUP BY "orders__order_id"', id="group_by"), + pytest.param('HAVING "orders__order_id" > 0', id="having"), + ], +) +def test_bind_query_preserves_alias_case_for_clause_references(clause_sql: str) -> None: + query = cast( + sge.Query, + sqlglot.parse_one(f""" + SELECT + customers.customer_id AS customer_id, + orders.order_id AS "orders__order_id" + FROM my_dataset.customers AS customers + INNER JOIN my_dataset.orders AS orders + ON customers.customer_id = orders.customer_id + {clause_sql} + """), + ) + + bound = bind_query( + qualified_query=query, + sqlglot_schema=_BIND_QUERY_SCHEMA, + expand_table_name=_bind_query_expand, + casefold_identifier=str.upper, + ) + sql = bound.sql() + + # SELECT alias is preserved in original + assert 'AS "orders__order_id"' in sql + # clause reference matches the preserved alias case, not the casefolded form + assert clause_sql in sql + + +def test_bind_query_casefolds_qualified_columns_in_order_by() -> None: + """Table-qualified column references in ORDER BY must still be casefolded.""" + query = cast( + sge.Query, + sqlglot.parse_one(""" + SELECT customers.customer_id AS customer_id + FROM my_dataset.customers AS customers + ORDER BY customers.customer_id ASC + """), + ) + + bound = bind_query( + qualified_query=query, + sqlglot_schema=_BIND_QUERY_SCHEMA, + expand_table_name=_bind_query_expand, + casefold_identifier=str.upper, + ) + sql = bound.sql() + + assert 'ORDER BY "CUSTOMERS"."CUSTOMER_ID"' in sql + + +def test_bind_query_casefolds_unrelated_bare_order_by_identifiers() -> None: + query = cast( + sge.Query, + sqlglot.parse_one(""" + SELECT customers.customer_id AS customer_id + FROM my_dataset.customers AS customers + ORDER BY country_code ASC + """), + ) + + bound = bind_query( + 
qualified_query=query, + sqlglot_schema=_BIND_QUERY_SCHEMA, + expand_table_name=_bind_query_expand, + casefold_identifier=str.upper, + ) + sql = bound.sql() + + assert 'ORDER BY "COUNTRY_CODE"' in sql From e0ed02d615fefdee395aa36797ff6243c651f3d0 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Fri, 29 May 2026 11:51:49 +0200 Subject: [PATCH 10/30] fix mypy --- tests/dataset/test_relation_join.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index 6df2def5fd..dc2a8c8998 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -1392,9 +1392,9 @@ def test_explicit_on_rejects_unknown_kind( ds = dataset_with_relational_tables with pytest.raises(ValueError, match="kind=outer"): - ds.table("customers").join( + ds.table("customers").join( # type: ignore[call-overload] "orders", - kind="outer", # type: ignore[arg-type] + kind="outer", on="customers.customer_id = orders.customer_id", ) From d19bf3579e6fae2675bf0265f7430a36dbe47319 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Fri, 29 May 2026 12:58:52 +0200 Subject: [PATCH 11/30] normalize var names --- tests/dataset/test_relation_join.py | 634 +++++++++++++++++----------- 1 file changed, 381 insertions(+), 253 deletions(-) diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index dc2a8c8998..ad6d623333 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -438,6 +438,15 @@ def test_join_projection_prefix_rejects_colliding_alias( joined.join("users__orders__items", alias="shared") +def test_join_does_not_project_incomplete_target_columns( + dataset_with_incomplete_join_target: dlt.Dataset, +) -> None: + relation = dataset_with_incomplete_join_target.table("products").join("categories") + rows = relation.fetchall() + assert rows is not None + assert len(rows) == 3 + + def 
test_join_rejects_empty_alias(dataset_with_loads: TLoadsFixture) -> None: dataset, _, _ = dataset_with_loads with pytest.raises(ValueError, match="must be a non-empty string"): @@ -930,7 +939,6 @@ def test_explicit_on_joins_relational_tables( assert "orders__amount" in df.columns assert list(df["orders__amount"]) == [50.0, 75.0, 200.0, 30.0] - # auto join should fail: no dlt reference between customers and orders with pytest.raises(ValueError, match="Unable to resolve reference chain"): ds.table("customers").join("orders") @@ -968,6 +976,40 @@ def test_explicit_on_non_eq_predicate( assert list(df["orders__amount"]) == [75.0, 200.0] +def test_explicit_on_composite_key( + dataset_with_annotated_references: dlt.Dataset, +) -> None: + ds = dataset_with_annotated_references + joined = ds.table("account_memberships").join( + "accounts", + on=( + "account_memberships.account_id = accounts.account_id " + "AND account_memberships.tenant_id = accounts.tenant_id" + ), + ) + df = joined.order_by("accounts__name").df() + + assert len(df) == 3 + assert list(df["accounts__name"]) == ["Acme", "Globex", "Initech"] + + +def test_explicit_on_left_join_keeps_unmatched_left_rows( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + joined = ds.table("countries").join( + "customers", kind="left", on="countries.code = customers.country_code" + ) + df = joined.order_by("code").df() + assert len(df) == 4 + assert list(df["code"]) == ["DE", "DE", "ES", "FR"] + assert list(df["customers__name"]) == ["Alice", "Charlie", None, "Bob"] + es_row = df[df["code"] == "ES"].iloc[0] + assert es_row["name"] == "Spain" + customers_cols = [c for c in df.columns if c.startswith("customers__")] + assert es_row[customers_cols].isna().all() + + def test_explicit_on_projection_prefix( dataset_with_relational_tables: dlt.Dataset, ) -> None: @@ -982,26 +1024,26 @@ def test_explicit_on_projection_prefix( assert right_aliases == expected -def 
test_explicit_on_rejects_empty_alias( +def test_explicit_on_projection_alias_collision_rejected( dataset_with_relational_tables: dlt.Dataset, ) -> None: ds = dataset_with_relational_tables - with pytest.raises(ValueError, match="must be a non-empty string"): - ds.table("customers").join( - "orders", on="customers.customer_id = orders.customer_id", alias="" - ) + left = ds.query("SELECT customer_id, 1 AS orders__amount FROM customers") + + with pytest.raises(ValueError, match="conflict with existing columns"): + left.join("orders", on="customers.customer_id = orders.customer_id") -def test_explicit_on_rejects_self_join( +def test_explicit_on_with_filtered_lhs( dataset_with_relational_tables: dlt.Dataset, ) -> None: ds = dataset_with_relational_tables - with pytest.raises(ValueError, match="Self-joins are not supported"): - ds.table("customers").join( - "customers", - on="customers.customer_id = customers.customer_id", - alias="c2", - ) + german_customers = ds.table("customers").where("country_code", "eq", "DE") + joined = german_customers.join("orders", on="customers.customer_id = orders.customer_id") + df = joined.df() + assert len(df) == 3 + assert list(df["name"]) == ["Alice", "Alice", "Charlie"] + assert list(df["orders__amount"]) == [50.0, 75.0, 30.0] def test_explicit_on_with_filtered_rhs( @@ -1018,6 +1060,21 @@ def test_explicit_on_with_filtered_rhs( assert list(df["orders__amount"]) == [75.0, 200.0] +def test_explicit_on_with_projected_lhs_preserves_left_projection( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + narrow_customers = ds.table("customers").select("customer_id", "name") + joined = narrow_customers.join("orders", on="customers.customer_id = orders.customer_id") + df = joined.df() + assert len(df) == 4 + lhs_cols = {c for c in df.columns if not c.startswith("orders__")} + assert lhs_cols == {"customer_id", "name"} + assert "country_code" not in df.columns + assert "orders__amount" in df.columns + 
assert list(df["orders__amount"]) == [50.0, 75.0, 200.0, 30.0] + + def test_explicit_on_with_projected_rhs( dataset_with_relational_tables: dlt.Dataset, ) -> None: @@ -1033,18 +1090,135 @@ def test_explicit_on_with_projected_rhs( assert "orders__amount" not in df.columns +def test_explicit_on_with_aggregated_rhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + order_totals = ds.query( + "SELECT customer_id, SUM(amount) AS total_amount FROM orders GROUP BY customer_id" + ) + + joined = ds.table("customers").join( + order_totals, + on="customers.customer_id = orders.customer_id", + alias="order_totals", + ) + df = joined.order_by("customer_id").df() + + assert len(df) == 3 + assert list(df["customer_id"]) == [1, 2, 3] + assert list(df["name"]) == ["Alice", "Bob", "Charlie"] + assert "order_totals__total_amount" in df.columns + assert list(df["order_totals__total_amount"]) == [125.0, 200.0, 30.0] + assert "order_totals__amount" not in df.columns + + +@pytest.mark.xfail(reason="Column 'o.customer_id' could not be resolved for table: 'o'") +def test_explicit_on_with_aliased_query_relations( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + customers = ds.query("SELECT * FROM customers AS c") + orders = ds.query("SELECT * FROM orders AS o") + + joined = customers.join(orders, on="c.customer_id = o.customer_id") + df = joined.order_by("o__order_id").df() + + assert len(df) == 4 + assert list(df["customer_id"]) == [1, 1, 2, 3] + assert list(df["name"]) == ["Alice", "Alice", "Bob", "Charlie"] + assert list(df["o__amount"]) == [50.0, 75.0, 200.0, 30.0] + + +def test_explicit_on_rejects_empty_alias( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + with pytest.raises(ValueError, match="must be a non-empty string"): + ds.table("customers").join( + "orders", on="customers.customer_id = orders.customer_id", alias="" + ) + + +def 
test_explicit_on_rejects_self_join( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + with pytest.raises(ValueError, match="Self-joins are not supported"): + ds.table("customers").join( + "customers", + on="customers.customer_id = customers.customer_id", + alias="c2", + ) + + +@pytest.mark.xfail(reason="ON expression must be non-empty") +def test_explicit_on_rejects_invalid_on_expression( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + with pytest.raises(ValueError, match="non-empty SQL expression"): + ds.table("customers").join("orders", on="") + + +@pytest.mark.xfail(reason="Unsupported join kind should be rejected") +def test_explicit_on_rejects_unknown_kind( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + + with pytest.raises(ValueError, match="kind=outer"): + ds.table("customers").join( # type: ignore[call-overload] + "orders", + kind="outer", + on="customers.customer_id = orders.customer_id", + ) + + +@pytest.mark.parametrize( + "name_normalizer_ref", + ( + "tests.common.cases.normalizers.title_case", + "tests.common.cases.normalizers.sql_upper", + "tests.common.cases.normalizers.snake_no_x", + ), +) +def test_explicit_on_columns_schema_resolves_with_name_mutating_normalizer( + dataset_with_relational_tables: dlt.Dataset, + name_normalizer_ref: str, +) -> None: + normalized_dataset = _dataset_with_name_normalizer( + dataset_with_relational_tables, name_normalizer_ref + ) + naming = normalized_dataset.schema.naming + customers = naming.normalize_tables_path("customers") + orders = naming.normalize_tables_path("orders") + customer_id = naming.normalize_identifier("customer_id") + + on_predicate = f'"{customers}"."{customer_id}" = "{orders}"."{customer_id}"' + joined = normalized_dataset.table(customers).join(orders, on=on_predicate) + + schema_cols = set(joined.columns_schema.keys()) + assert schema_cols + 
expected_right_aliases = { + f"{orders}__{column_name}" + for column_name in normalized_dataset.schema.tables[orders]["columns"].keys() + } + assert expected_right_aliases.issubset(schema_cols) + + def test_cross_dataset_join( cross_dataset_duckdb: TCrossDsFixture, ) -> None: - ds_a, ds_b = cross_dataset_duckdb - users = ds_a.table("users") - purchases = ds_b.table("purchases") + ds_crm, ds_inv = cross_dataset_duckdb + users = ds_crm.table("users") + purchases = ds_inv.table("purchases") joined = users.join(purchases, on="users.id = purchases.user_id") - assert ds_b.dataset_name in joined._foreign_schemas - assert ds_b.dataset_name not in users._foreign_schemas - foreign_schemas = joined._foreign_schemas[ds_b.dataset_name] + assert ds_inv.dataset_name in joined._foreign_schemas + assert ds_inv.dataset_name not in users._foreign_schemas + foreign_schemas = joined._foreign_schemas[ds_inv.dataset_name] assert len(foreign_schemas) >= 1 df = joined.df() @@ -1057,9 +1231,9 @@ def test_cross_dataset_join( def test_cross_dataset_join_requires_on( cross_dataset_duckdb: TCrossDsFixture, ) -> None: - ds_a, ds_b = cross_dataset_duckdb - users = ds_a.table("users") - purchases = ds_b.table("purchases") + ds_crm, ds_inv = cross_dataset_duckdb + users = ds_crm.table("users") + purchases = ds_inv.table("purchases") with pytest.raises(ValueError, match="`on` is required"): users.join(purchases) @@ -1084,10 +1258,8 @@ def test_cross_dataset_join_requires_on( @pytest.mark.parametrize( "kind,expected", [ - # inner + left: both users match, so LEFT adds no extra rows pytest.param("inner", _MATCHED, id="inner"), pytest.param("left", _MATCHED, id="left"), - # right + full: orphan purchase appears with NULL on the user side pytest.param("right", _MATCHED_PLUS_ORPHAN, id="right"), pytest.param("full", _MATCHED_PLUS_ORPHAN, id="full"), ], @@ -1097,9 +1269,9 @@ def test_cross_dataset_join_kind_parameter( kind: TJoinType, expected: dict[str, list[Any]], ) -> None: - ds_a, ds_b = 
cross_dataset_duckdb - users = ds_a.table("users") - purchases = ds_b.table("purchases") + ds_crm, ds_inv = cross_dataset_duckdb + users = ds_crm.table("users") + purchases = ds_inv.table("purchases") joined = users.join(purchases, on="users.id = purchases.user_id", kind=kind) df = joined.df() @@ -1108,22 +1280,29 @@ def test_cross_dataset_join_kind_parameter( assert list(df[col]) == expected_values, f"column `{col}` mismatch" -def test_join_does_not_project_incomplete_target_columns( - dataset_with_incomplete_join_target: dlt.Dataset, +def test_cross_dataset_join_to_sql_uses_each_dataset_name( + cross_dataset_duckdb: TCrossDsFixture, ) -> None: - relation = dataset_with_incomplete_join_target.table("products").join("categories") - rows = relation.fetchall() - assert rows is not None - # 3 products inner-joined to 2 categories on category_id → 3 rows - assert len(rows) == 3 + ds_crm, ds_inv = cross_dataset_duckdb + + joined = ds_crm.table("users").join( + ds_inv.table("purchases"), + on="users.id = purchases.user_id", + ) + sql = joined.to_sql() + + assert f'"{ds_crm.dataset_name}"."users"' in sql + assert f'"{ds_inv.dataset_name}"."purchases"' in sql + assert f'"{ds_inv.dataset_name}"."users"' not in sql + assert f'"{ds_crm.dataset_name}"."purchases"' not in sql def test_cross_dataset_join_with_transformed_rhs_preserves_foreign_dataset_binding( cross_dataset_duckdb: TCrossDsFixture, ) -> None: - ds_a, ds_b = cross_dataset_duckdb - users = ds_a.table("users") - filtered_purchases = ds_b.table("purchases").where("quantity", "gt", 1) + ds_crm, ds_inv = cross_dataset_duckdb + users = ds_crm.table("users") + filtered_purchases = ds_inv.table("purchases").where("quantity", "gt", 1) joined = users.join(filtered_purchases, on="users.id = purchases.user_id").order_by("id") df = joined.df() @@ -1135,12 +1314,60 @@ def test_cross_dataset_join_with_transformed_rhs_preserves_foreign_dataset_bindi assert list(df["purchases__quantity"]) == [2] +def 
test_cross_dataset_join_with_aggregated_rhs( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_inv = cross_dataset_duckdb + + purchase_totals = ds_inv.query( + "SELECT user_id, SUM(quantity) AS total_quantity FROM purchases GROUP BY user_id" + ) + joined = ds_crm.table("users").join( + purchase_totals, + on="users.id = purchases.user_id", + alias="purchase_totals", + ) + df = joined.order_by("id").df() + + assert len(df) == 2 + assert list(df["id"]) == [1, 2] + assert list(df["name"]) == ["Alice", "Bob"] + assert "purchase_totals__total_quantity" in df.columns + assert [int(x) for x in df["purchase_totals__total_quantity"]] == [3, 1] + assert "purchase_totals__quantity" not in df.columns + + +def test_cross_dataset_join_with_cte_qualifies_body_but_not_alias( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_inv = cross_dataset_duckdb + + recent_purchases = ds_inv.query( + "WITH recent AS (SELECT * FROM purchases WHERE quantity > 1) SELECT * FROM recent" + ) + joined = ds_crm.table("users").join(recent_purchases, on="users.id = recent.user_id") + + table_qualifiers = { + (node.name, node.db or None) for node in joined.sqlglot_expression.find_all(sge.Table) + } + assert ("users", ds_crm.dataset_name) in table_qualifiers + assert ("purchases", ds_inv.dataset_name) in table_qualifiers + assert ("recent", None) in table_qualifiers + + df = joined.order_by("id").df() + assert len(df) == 1 + assert list(df["name"]) == ["Alice"] + assert list(df["recent__purchase_id"]) == [1] + assert list(df["recent__sku"]) == ["W-001"] + assert list(df["recent__quantity"]) == [2] + + def test_cross_dataset_join_with_same_table_names_keeps_sources_unambiguous( same_named_cross_dataset_duckdb: TCrossDsFixture, ) -> None: - ds_a, ds_b = same_named_cross_dataset_duckdb - crm_users = ds_a.query("SELECT * FROM users AS crm_users") - marketing_users = ds_b.table("users") + ds_crm, ds_marketing = same_named_cross_dataset_duckdb + crm_users = ds_crm.query("SELECT * 
FROM users AS crm_users") + marketing_users = ds_marketing.table("users") joined = crm_users.join(marketing_users, on="crm_users.id = users.id", alias="marketing") df = joined.order_by("id").df() @@ -1177,7 +1404,6 @@ def test_cross_dataset_join_chain_three_tables( ) df = joined.order_by("purchase_id").df() - # orphan purchase (user_id=99) is dropped by the inner join to users assert len(df) == 3 assert "purchase_id" in df.columns assert "users__name" in df.columns @@ -1206,48 +1432,77 @@ def test_cross_dataset_join_chain_magic_then_cross( assert list(df["users__id"]) == list(df["purchases__user_id"]) -@pytest.mark.xfail( - reason=( - "Column 'inventory_items.warehouse_id' could not be resolved for table: 'inventory_items'" - ) -) -def test_cross_dataset_join_chain_four_tables( - cross_dataset_duckdb: TCrossDsFixture, +def test_cross_dataset_join_chain_magic_then_two_crossings( + three_way_cross_dataset_duckdb: TCrossDs3Fixture, +) -> None: + """A magic (schema-resolved) join followed by two cross-dataset joins across three + systems: nested CRM orders are enriched with their customer (magic, via the parent + reference), then bridged to that customer's purchases (inventory) and subscription + (billing). 
+ + The magic join emits unqualified CRM tables; the first crossing must retroactively + db-qualify them, and the final query interleaves all three datasets.""" + ds_crm, ds_inv, ds_billing = three_way_cross_dataset_duckdb + + joined = ( + ds_crm.table("users__orders") + .join("users") + .join(ds_inv.table("purchases"), on="users.id = purchases.user_id") + .join(ds_billing.table("subscriptions"), on="users.id = subscriptions.user_id") + ) + df = joined.df() + + assert len(df) == 5 + assert "order_id" in df.columns + assert "users__name" in df.columns + assert "purchases__sku" in df.columns + assert "subscriptions__plan" in df.columns + assert sorted(df["users__name"]) == ["Alice", "Alice", "Alice", "Alice", "Bob"] + assert sorted(df["subscriptions__plan"]) == [ + "enterprise", + "enterprise", + "enterprise", + "enterprise", + "free", + ] + assert list(df["users__id"]) == list(df["purchases__user_id"]) + assert list(df["users__id"]) == list(df["subscriptions__user_id"]) + + +def test_cross_dataset_join_then_foreign_dataset_local_hop_with_relation( + cross_dataset_duckdb: TCrossDsFixture, ) -> None: - """Star-schema joined to three dimensions across two datasets""" ds_crm, ds_inv = cross_dataset_duckdb joined = ( - ds_inv.table("purchases") - .join(ds_crm.table("users"), on="purchases.user_id = users.id") - .join("inventory_items", on="purchases.sku = inventory_items.sku") - .join("warehouses", on="inventory_items.warehouse_id = warehouses.warehouse_id") + ds_crm.table("users") + .join(ds_inv.table("purchases"), on="users.id = purchases.user_id") + .join(ds_inv.table("inventory_items"), on="purchases.sku = inventory_items.sku") ) - df = joined.order_by("purchase_id").df() + df = joined.order_by("purchases__purchase_id").df() assert len(df) == 3 - assert "warehouses__city" in df.columns - assert list(df["warehouses__city"]) == ["Berlin", "Paris", "Berlin"] + assert list(df["purchases__purchase_id"]) == [1, 2, 3] + assert list(df["name"]) == ["Alice", "Alice", "Bob"] + 
assert list(df["purchases__sku"]) == ["W-001", "G-001", "W-001"] + assert list(df["inventory_items__quantity"]) == [50, 30, 50] -@pytest.mark.xfail(reason="Column 'users.id' could not be resolved for table: 'users'") -def test_cross_dataset_join_chain_three_datasets( - three_way_cross_dataset_duckdb: TCrossDs3Fixture, +def test_cross_dataset_join_chain_with_filtered_step( + cross_dataset_duckdb: TCrossDsFixture, ) -> None: - ds_crm, ds_inv, ds_billing = three_way_cross_dataset_duckdb + ds_crm, ds_inv = cross_dataset_duckdb - joined = ( - ds_inv.table("purchases") - .join(ds_crm.table("users"), on="purchases.user_id = users.id") - .join(ds_billing.table("subscriptions"), on="users.id = subscriptions.user_id") + alice_purchases = ds_inv.table("purchases").where("user_id", "eq", 1) + joined = alice_purchases.join(ds_crm.table("users"), on="purchases.user_id = users.id").join( + "inventory_items", on="purchases.sku = inventory_items.sku" ) df = joined.order_by("purchase_id").df() - assert len(df) == 3 - assert "users__name" in df.columns - assert "subscriptions__plan" in df.columns - assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] - assert list(df["subscriptions__plan"]) == ["enterprise", "enterprise", "free"] + assert len(df) == 2 + assert list(df["purchase_id"]) == [1, 2] + assert list(df["users__name"]) == ["Alice", "Alice"] + assert list(df["inventory_items__quantity"]) == [50, 30] def test_cross_dataset_join_chain_does_not_mutate_sources( @@ -1274,21 +1529,68 @@ def test_cross_dataset_join_chain_does_not_mutate_sources( assert purchases.join(users, on="purchases.user_id = users.id").to_sql() == step1_sql -def test_cross_dataset_join_chain_with_filtered_step( +def test_cross_dataset_join_chain_columns_schema_matches_df( + three_way_cross_dataset_duckdb: TCrossDs3Fixture, +) -> None: + ds_crm, ds_inv, ds_billing = three_way_cross_dataset_duckdb + + joined = ( + ds_inv.table("purchases") + .join(ds_crm.table("users"), on="purchases.user_id = users.id") + 
.join(ds_billing.table("subscriptions"), on="users.id = subscriptions.user_id") + ) + + schema_cols = set(joined.columns_schema.keys()) + assert schema_cols, "columns_schema must not be empty" + + df = joined.df() + df_cols = set(df.columns) + + assert schema_cols == df_cols + + +@pytest.mark.xfail( + reason=( + "Column 'inventory_items.warehouse_id' could not be resolved for table: 'inventory_items'" + ) +) +def test_cross_dataset_join_chain_four_tables( cross_dataset_duckdb: TCrossDsFixture, ) -> None: + """Star-schema joined to three dimensions across two datasets""" ds_crm, ds_inv = cross_dataset_duckdb - alice_purchases = ds_inv.table("purchases").where("user_id", "eq", 1) - joined = alice_purchases.join(ds_crm.table("users"), on="purchases.user_id = users.id").join( - "inventory_items", on="purchases.sku = inventory_items.sku" + joined = ( + ds_inv.table("purchases") + .join(ds_crm.table("users"), on="purchases.user_id = users.id") + .join("inventory_items", on="purchases.sku = inventory_items.sku") + .join("warehouses", on="inventory_items.warehouse_id = warehouses.warehouse_id") ) df = joined.order_by("purchase_id").df() - assert len(df) == 2 - assert list(df["purchase_id"]) == [1, 2] - assert list(df["users__name"]) == ["Alice", "Alice"] - assert list(df["inventory_items__quantity"]) == [50, 30] + assert len(df) == 3 + assert "warehouses__city" in df.columns + assert list(df["warehouses__city"]) == ["Berlin", "Paris", "Berlin"] + + +@pytest.mark.xfail(reason="Column 'users.id' could not be resolved for table: 'users'") +def test_cross_dataset_join_chain_three_datasets( + three_way_cross_dataset_duckdb: TCrossDs3Fixture, +) -> None: + ds_crm, ds_inv, ds_billing = three_way_cross_dataset_duckdb + + joined = ( + ds_inv.table("purchases") + .join(ds_crm.table("users"), on="purchases.user_id = users.id") + .join(ds_billing.table("subscriptions"), on="users.id = subscriptions.user_id") + ) + df = joined.order_by("purchase_id").df() + + assert len(df) == 3 + 
assert "users__name" in df.columns + assert "subscriptions__plan" in df.columns + assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] + assert list(df["subscriptions__plan"]) == ["enterprise", "enterprise", "free"] @pytest.mark.xfail(reason="unqualified where column `quantity` can't be resolved") @@ -1313,9 +1615,9 @@ def test_cross_dataset_chain_same_named_tables_disambiguated( same_named_cross_dataset_duckdb: TCrossDsFixture, ) -> None: """CRM and marketing both expose a `users` table.""" - ds_crm, ds_mkt = same_named_cross_dataset_duckdb + ds_crm, ds_marketing = same_named_cross_dataset_duckdb - marketing = ds_mkt.query("SELECT * FROM users AS mkt_users") + marketing = ds_marketing.query("SELECT * FROM users AS mkt_users") joined = ( ds_crm.table("users__orders") .join("users") @@ -1328,177 +1630,3 @@ def test_cross_dataset_chain_same_named_tables_disambiguated( assert "marketing__segment" in df.columns assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] assert list(df["marketing__segment"]) == ["pro", "pro", "free"] - - -def test_explicit_on_left_join_keeps_unmatched_left_rows( - dataset_with_relational_tables: dlt.Dataset, -) -> None: - ds = dataset_with_relational_tables - joined = ds.table("countries").join( - "customers", kind="left", on="countries.code = customers.country_code" - ) - df = joined.order_by("code").df() - assert len(df) == 4 - assert list(df["code"]) == ["DE", "DE", "ES", "FR"] - assert list(df["customers__name"]) == ["Alice", "Charlie", None, "Bob"] - es_row = df[df["code"] == "ES"].iloc[0] - assert es_row["name"] == "Spain" - customers_cols = [c for c in df.columns if c.startswith("customers__")] - assert es_row[customers_cols].isna().all() - - -def test_explicit_on_composite_key( - dataset_with_annotated_references: dlt.Dataset, -) -> None: - ds = dataset_with_annotated_references - joined = ds.table("account_memberships").join( - "accounts", - on=( - "account_memberships.account_id = accounts.account_id " - "AND 
account_memberships.tenant_id = accounts.tenant_id" - ), - ) - df = joined.order_by("accounts__name").df() - - assert len(df) == 3 - assert list(df["accounts__name"]) == ["Acme", "Globex", "Initech"] - - -def test_explicit_on_with_filtered_lhs( - dataset_with_relational_tables: dlt.Dataset, -) -> None: - ds = dataset_with_relational_tables - german_customers = ds.table("customers").where("country_code", "eq", "DE") - joined = german_customers.join("orders", on="customers.customer_id = orders.customer_id") - df = joined.df() - assert len(df) == 3 - assert list(df["name"]) == ["Alice", "Alice", "Charlie"] - assert list(df["orders__amount"]) == [50.0, 75.0, 30.0] - - -@pytest.mark.xfail(reason="ON expression must be non-empty") -def test_explicit_on_rejects_invalid_on_expression( - dataset_with_relational_tables: dlt.Dataset, -) -> None: - ds = dataset_with_relational_tables - with pytest.raises(ValueError, match="non-empty SQL expression"): - ds.table("customers").join("orders", on="") - - -@pytest.mark.xfail(reason="Unsupported join kind should be rejected") -def test_explicit_on_rejects_unknown_kind( - dataset_with_relational_tables: dlt.Dataset, -) -> None: - ds = dataset_with_relational_tables - - with pytest.raises(ValueError, match="kind=outer"): - ds.table("customers").join( # type: ignore[call-overload] - "orders", - kind="outer", - on="customers.customer_id = orders.customer_id", - ) - - -def test_explicit_on_with_projected_lhs_preserves_left_projection( - dataset_with_relational_tables: dlt.Dataset, -) -> None: - ds = dataset_with_relational_tables - narrow_customers = ds.table("customers").select("customer_id", "name") - joined = narrow_customers.join("orders", on="customers.customer_id = orders.customer_id") - df = joined.df() - assert len(df) == 4 - lhs_cols = {c for c in df.columns if not c.startswith("orders__")} - assert lhs_cols == {"customer_id", "name"} - assert "country_code" not in df.columns - assert "orders__amount" in df.columns - assert 
list(df["orders__amount"]) == [50.0, 75.0, 200.0, 30.0] - - -@pytest.mark.xfail(reason="Column 'o.customer_id' could not be resolved for table: 'o'") -def test_explicit_on_with_aliased_query_relations( - dataset_with_relational_tables: dlt.Dataset, -) -> None: - ds = dataset_with_relational_tables - customers = ds.query("SELECT * FROM customers AS c") - orders = ds.query("SELECT * FROM orders AS o") - - joined = customers.join(orders, on="c.customer_id = o.customer_id") - df = joined.order_by("o__order_id").df() - - assert len(df) == 4 - assert list(df["customer_id"]) == [1, 1, 2, 3] - assert list(df["name"]) == ["Alice", "Alice", "Bob", "Charlie"] - assert list(df["o__amount"]) == [50.0, 75.0, 200.0, 30.0] - - -def test_explicit_on_with_aggregated_rhs( - dataset_with_relational_tables: dlt.Dataset, -) -> None: - ds = dataset_with_relational_tables - order_totals = ds.query( - "SELECT customer_id, SUM(amount) AS total_amount FROM orders GROUP BY customer_id" - ) - - joined = ds.table("customers").join( - order_totals, - on="customers.customer_id = orders.customer_id", - alias="order_totals", - ) - df = joined.order_by("customer_id").df() - - assert len(df) == 3 - assert list(df["customer_id"]) == [1, 2, 3] - assert list(df["name"]) == ["Alice", "Bob", "Charlie"] - assert "order_totals__total_amount" in df.columns - assert list(df["order_totals__total_amount"]) == [125.0, 200.0, 30.0] - assert "order_totals__amount" not in df.columns - - -def test_explicit_on_projection_alias_collision_rejected( - dataset_with_relational_tables: dlt.Dataset, -) -> None: - ds = dataset_with_relational_tables - left = ds.query("SELECT customer_id, 1 AS orders__amount FROM customers") - - with pytest.raises(ValueError, match="conflict with existing columns"): - left.join("orders", on="customers.customer_id = orders.customer_id") - - -def test_cross_dataset_join_to_sql_uses_each_dataset_name( - cross_dataset_duckdb: TCrossDsFixture, -) -> None: - ds_a, ds_b = cross_dataset_duckdb - - 
joined = ds_a.table("users").join( - ds_b.table("purchases"), - on="users.id = purchases.user_id", - ) - sql = joined.to_sql() - - assert f'"{ds_a.dataset_name}"."users"' in sql - assert f'"{ds_b.dataset_name}"."purchases"' in sql - assert f'"{ds_b.dataset_name}"."users"' not in sql - assert f'"{ds_a.dataset_name}"."purchases"' not in sql - - -def test_cross_dataset_join_with_aggregated_rhs( - cross_dataset_duckdb: TCrossDsFixture, -) -> None: - ds_a, ds_b = cross_dataset_duckdb - - purchase_totals = ds_b.query( - "SELECT user_id, SUM(quantity) AS total_quantity FROM purchases GROUP BY user_id" - ) - joined = ds_a.table("users").join( - purchase_totals, - on="users.id = purchases.user_id", - alias="purchase_totals", - ) - df = joined.order_by("id").df() - - assert len(df) == 2 - assert list(df["id"]) == [1, 2] - assert list(df["name"]) == ["Alice", "Bob"] - assert "purchase_totals__total_quantity" in df.columns - assert [int(x) for x in df["purchase_totals__total_quantity"]] == [3, 1] - assert "purchase_totals__quantity" not in df.columns From 2576389045cdb05ded575045acf1b20057d58953 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Fri, 29 May 2026 15:19:14 +0200 Subject: [PATCH 12/30] validate join kind, ambiguous `on` qualifier, and empty `on` --- dlt/dataset/_join.py | 6 ++++++ dlt/dataset/relation.py | 9 +++++++++ tests/dataset/test_relation_join.py | 14 +++----------- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index b3d9b1e1eb..9b4ffa8c83 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -484,6 +484,12 @@ def _bind_on_predicate( if not isinstance(table_node, sge.Identifier): continue qualifier = table_node.name + if qualifier in left_qualifiers and qualifier in right_qualifiers: + raise ValueError( + f"Ambiguous qualifier `{qualifier}` in join `on` expression: it matches both " + "the left and right side of the join. Alias one side (e.g. 
via `query(...)` " + "or the join `alias`) so each `on` qualifier is unambiguous." + ) if qualifier in left_qualifiers: continue if qualifier in right_qualifiers: diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 87a4e1f46d..c996f026f0 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -9,6 +9,7 @@ Type, TYPE_CHECKING, Literal, + get_args, ) from textwrap import indent from contextlib import contextmanager @@ -441,6 +442,14 @@ def join( if alias == "": raise ValueError("`alias` must be a non-empty string when provided.") + if kind not in get_args(TJoinType): + raise ValueErrorWithKnownValues( + key="kind", value_received=kind, valid_values=list(get_args(TJoinType)) + ) + + if isinstance(on, str) and not on.strip(): + raise ValueError("`on` must be a non-empty SQL expression.") + target_dataset, target_table, target_columns = self._resolve_join_target(other, on=on) is_same_dataset = self._dataset._is_same_dataset(target_dataset) diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index ad6d623333..364d53342b 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -1152,16 +1152,16 @@ def test_explicit_on_rejects_self_join( ) -@pytest.mark.xfail(reason="ON expression must be non-empty") +@pytest.mark.parametrize("on", ["", " "], ids=["empty", "whitespace"]) def test_explicit_on_rejects_invalid_on_expression( dataset_with_relational_tables: dlt.Dataset, + on: str, ) -> None: ds = dataset_with_relational_tables with pytest.raises(ValueError, match="non-empty SQL expression"): - ds.table("customers").join("orders", on="") + ds.table("customers").join("orders", on=on) -@pytest.mark.xfail(reason="Unsupported join kind should be rejected") def test_explicit_on_rejects_unknown_kind( dataset_with_relational_tables: dlt.Dataset, ) -> None: @@ -1378,7 +1378,6 @@ def test_cross_dataset_join_with_same_table_names_keeps_sources_unambiguous( assert 
list(df["marketing__segment"]) == ["pro", "free"] -@pytest.mark.xfail(reason="Ambiguous qualifier should be rejected") def test_cross_dataset_same_named_join_rejects_ambiguous_on_qualifier( same_named_cross_dataset_duckdb: TCrossDsFixture, ) -> None: @@ -1435,13 +1434,6 @@ def test_cross_dataset_join_chain_magic_then_cross( def test_cross_dataset_join_chain_magic_then_two_crossings( three_way_cross_dataset_duckdb: TCrossDs3Fixture, ) -> None: - """A magic (schema-resolved) join followed by two cross-dataset joins across three - systems: nested CRM orders are enriched with their customer (magic, via the parent - reference), then bridged to that customer's purchases (inventory) and subscription - (billing). - - The magic join emits unqualified CRM tables; the first crossing must retroactively - db-qualify them, and the final query interleaves all three datasets.""" ds_crm, ds_inv, ds_billing = three_way_cross_dataset_duckdb joined = ( From 491bd4bf0f5297e68d5325a51aeea045d11443eb Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Fri, 29 May 2026 16:24:57 +0200 Subject: [PATCH 13/30] fix t-sql errors and qualifier matching issues --- dlt/dataset/_join.py | 82 +++++++++++++++++-------- dlt/dataset/relation.py | 31 ++++++---- tests/dataset/test_relation_join.py | 92 ++++++++++++++++++++++++++--- 3 files changed, 160 insertions(+), 45 deletions(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 9b4ffa8c83..3bcff12f46 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -15,10 +15,15 @@ from dlt.dataset.relation import Relation, TJoinType _INTERMEDIATE_JOIN_ALIAS_PREFIX = "_dlt_int_t" +_EXPLICIT_JOIN_ALIAS_PREFIX = "_dlt_jt_" _TExpr = TypeVar("_TExpr", bound=sge.Expression) +def _is_internal_join_alias(qualifier: str) -> bool: + return qualifier.startswith((_INTERMEDIATE_JOIN_ALIAS_PREFIX, _EXPLICIT_JOIN_ALIAS_PREFIX)) + + class _JoinRef(TypedDict): """A resolved join step from currently attached table to a target table.""" @@ -378,6 +383,11 
@@ def _apply_join( if not isinstance(query, sge.Select): raise ValueError(f"Join query `{query}` must be an SQL SELECT statement.") + # qualify its bare WHERE/ORDER BY columns so they survive a later join + # that introduces a same-named column + if not query.args.get("joins"): + _qualify_unscoped_predicate_columns(query, _left_source_qualifier(query) or left_table) + join_params, target_qualifier = _discover_join_params( query, schema=schema, @@ -444,9 +454,9 @@ def _left_source_qualifier(query: sge.Query) -> Optional[str]: return None -def _collect_left_qualifiers(query: sge.Query) -> Set[str]: - """Collect qualifiers (table names or aliases) the LHS exposes to ON binding.""" - qualifiers: Set[str] = set() +def _existing_source_qualifier_map(query: sge.Query) -> dict[str, str]: + """Map each existing join input's user-facing qualifier to its SQL qualifier.""" + qualifier_map: dict[str, str] = {} sources: list[sge.Expression] = [] from_expr = query.args.get("from_") or query.args.get("from") @@ -460,43 +470,66 @@ def _collect_left_qualifiers(query: sge.Query) -> Set[str]: for source in sources: if isinstance(source, sge.Table): result = _extract_table_qualifier(source) - if result: - qualifiers.add(result[1]) + if not result: + continue + table_name, sql_qualifier = result + if sql_qualifier == table_name or _is_internal_join_alias(sql_qualifier): + # unaliased or internally aliased: the user references it by table name + qualifier_map[table_name] = sql_qualifier + else: + # an explicit alias replaces the table name as the usable qualifier + qualifier_map[sql_qualifier] = sql_qualifier elif isinstance(source, sge.Subquery): - alias_name = _subquery_alias_name(source) - if alias_name is not None: - qualifiers.add(alias_name) + sql_qualifier = _subquery_alias_name(source) + if sql_qualifier is None: + continue + source_qualifier = _left_source_qualifier(source.this) or sql_qualifier + qualifier_map[source_qualifier] = sql_qualifier - return qualifiers + return 
qualifier_map def _bind_on_predicate( on_expr: sge.Expression, *, - left_qualifiers: Set[str], - right_qualifiers: Set[str], - right_internal_alias: str, + existing_qualifier_map: dict[str, str], + new_right_qualifiers: Set[str], + new_right_alias: str, ) -> sge.Expression: - """Rewrite RHS-side column qualifiers in ``on_expr`` to the internal RHS alias.""" + """Rewrite column qualifiers in ``on_expr`` to the SQL qualifiers of the join inputs.""" on_expr = on_expr.copy() for col in on_expr.find_all(sge.Column): table_node = col.args.get("table") if not isinstance(table_node, sge.Identifier): continue qualifier = table_node.name - if qualifier in left_qualifiers and qualifier in right_qualifiers: + in_existing = qualifier in existing_qualifier_map + in_new = qualifier in new_right_qualifiers + if in_existing and in_new: raise ValueError( f"Ambiguous qualifier `{qualifier}` in join `on` expression: it matches both " "the left and right side of the join. Alias one side (e.g. via `query(...)` " "or the join `alias`) so each `on` qualifier is unambiguous." 
) - if qualifier in left_qualifiers: - continue - if qualifier in right_qualifiers: - col.set("table", sge.to_identifier(right_internal_alias, quoted=False)) + if in_new: + col.set("table", sge.to_identifier(new_right_alias, quoted=False)) + elif in_existing: + col.set("table", sge.to_identifier(existing_qualifier_map[qualifier], quoted=False)) return on_expr +def _qualify_unscoped_predicate_columns(query: sge.Select, source_qualifier: str) -> None: + """Bind unqualified columns in pre-join WHERE/ORDER BY clauses to the single source.""" + qualifier_identifier = sge.to_identifier(source_qualifier, quoted=False) + for clause_key in ("where", "order"): + clause = query.args.get(clause_key) + if clause is None: + continue + for col in clause.find_all(sge.Column): + if col.args.get("table") is None and col.parent_select is query: + col.set("table", qualifier_identifier.copy()) + + def _apply_explicit_join( expression: sge.Query, *, @@ -542,7 +575,10 @@ def _apply_explicit_join( ) left_source_qualifier = _left_source_qualifier(query) or from_expr.this.name - internal_alias = f"_dlt_jt_{projection_prefix}" + if not query.args.get("joins"): + _qualify_unscoped_predicate_columns(query, left_source_qualifier) + + internal_alias = f"{_EXPLICIT_JOIN_ALIAS_PREFIX}{projection_prefix}" target_expr: sge.Expression if target is not None and target._query is not None: @@ -568,13 +604,11 @@ def _apply_explicit_join( else: on_expr = on - left_qualifiers = _collect_left_qualifiers(query) - right_qualifiers = {target_table, projection_prefix} on_expr = _bind_on_predicate( on_expr, - left_qualifiers=left_qualifiers, - right_qualifiers=right_qualifiers, - right_internal_alias=internal_alias, + existing_qualifier_map=_existing_source_qualifier_map(query), + new_right_qualifiers={target_table, projection_prefix}, + new_right_alias=internal_alias, ) join_expr = sge.Join(this=target_expr, kind=kind.upper()).on(on_expr) diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 
c996f026f0..ce08f676c8 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -47,7 +47,12 @@ _RelationIncrementalContext, _sqlglot_type_for_column, ) -from dlt.dataset._join import _apply_join, _apply_explicit_join, _extract_joined_table_aliases +from dlt.dataset._join import ( + _apply_join, + _apply_explicit_join, + _extract_joined_table_aliases, + _left_source_qualifier, +) if TYPE_CHECKING: @@ -363,15 +368,23 @@ def order_by(self, column_name: str, direction: TSortOrder = "asc") -> Self: f"`{direction}` is an invalid sort order, allowed values are: `asc` and `desc`" ) order_expr = sge.Ordered( - this=sge.Column( - this=sge.to_identifier(column_name, quoted=True), - ), + this=self._resolve_output_column(column_name), desc=(direction == "desc"), ) rel = self.__copy__() rel._sqlglot_expression = rel.sqlglot_expression.order_by(order_expr) return rel + def _resolve_output_column(self, column_name: str) -> sge.Expression: + """Resolve an output column name to its projected source expression.""" + for proj in self.sqlglot_expression.selects: + if isinstance(proj, sge.Star) or proj.output_name != column_name: + continue + source = proj.this if isinstance(proj, sge.Alias) else proj + if not isinstance(source, sge.Star): + return source.copy() + return sge.Column(this=sge.to_identifier(column_name, quoted=True)) + @overload def join( self, @@ -1034,11 +1047,5 @@ def _find_table_columns(schemas: Sequence[dlt.Schema], table_name: str) -> TTabl def _extract_subquery_alias(relation: dlt.Relation) -> str: - """Extract a stable alias for a transformed Relation without a base table.""" - expr = relation.sqlglot_expression - from_expr = expr.args.get("from_") or expr.args.get("from") - if isinstance(from_expr, sge.From) and isinstance(from_expr.this, sge.Table): - table_id = from_expr.this.this - if isinstance(table_id, sge.Identifier): - return table_id.name - return "subquery" + """Extract the source qualifier of a transformed Relation without a base table.""" 
+ return _left_source_qualifier(relation.sqlglot_expression) or "subquery" diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index 364d53342b..f95923425b 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -3,6 +3,7 @@ from typing import Any, Sequence, Callable, TypedDict, Optional import pytest +import sqlglot import sqlglot.expressions as sge import dlt @@ -1024,6 +1025,66 @@ def test_explicit_on_projection_prefix( assert right_aliases == expected +def _order_by_sort_key(rel: dlt.Relation) -> sge.Column: + """Return the single ORDER BY sort-key column of a relation.""" + order = rel.sqlglot_expression.args.get("order") + assert order is not None and len(order.expressions) == 1 + sort_key = order.expressions[0].this + assert isinstance(sort_key, sge.Column) + return sort_key + + +@pytest.mark.parametrize( + "build_join,order_column,expected_qualifier,expected_column", + [ + pytest.param( + lambda ds: ds.table("customers").join( + "orders", on="customers.customer_id = orders.customer_id" + ), + "orders__order_id", + "_dlt_jt_orders", + "order_id", + id="default-prefix", + ), + pytest.param( + lambda ds: ds.table("customers").join( + "orders", on="customers.customer_id = orders.customer_id", alias="o" + ), + "o__order_id", + "_dlt_jt_o", + "order_id", + id="custom-alias", + ), + ], +) +def test_order_by_join_output_resolves_to_source_column( + dataset_with_relational_tables: dlt.Dataset, + build_join: Callable[[dlt.Dataset], dlt.Relation], + order_column: str, + expected_qualifier: str, + expected_column: str, +) -> None: + rel = build_join(dataset_with_relational_tables).order_by(order_column) + sort_key = _order_by_sort_key(rel) + assert sort_key.table == expected_qualifier, f"bare alias leaked: {sort_key.sql()}" + assert sort_key.name == expected_column + + +def test_order_by_join_output_renders_resolvable_tsql( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + rel = ( + 
dataset_with_relational_tables.table("customers") + .join("orders", on="customers.customer_id = orders.customer_id") + .order_by("orders__order_id") + ) + order_by = sqlglot.transpile(rel.to_sql(), read="duckdb", write="tsql")[0].split("ORDER BY", 1)[ + 1 + ] + assert "[orders__order_id]" not in order_by + assert "[_dlt_jt_orders].[order_id]" in order_by + + def test_explicit_on_projection_alias_collision_rejected( dataset_with_relational_tables: dlt.Dataset, ) -> None: @@ -1113,7 +1174,6 @@ def test_explicit_on_with_aggregated_rhs( assert "order_totals__amount" not in df.columns -@pytest.mark.xfail(reason="Column 'o.customer_id' could not be resolved for table: 'o'") def test_explicit_on_with_aliased_query_relations( dataset_with_relational_tables: dlt.Dataset, ) -> None: @@ -1541,11 +1601,6 @@ def test_cross_dataset_join_chain_columns_schema_matches_df( assert schema_cols == df_cols -@pytest.mark.xfail( - reason=( - "Column 'inventory_items.warehouse_id' could not be resolved for table: 'inventory_items'" - ) -) def test_cross_dataset_join_chain_four_tables( cross_dataset_duckdb: TCrossDsFixture, ) -> None: @@ -1565,7 +1620,6 @@ def test_cross_dataset_join_chain_four_tables( assert list(df["warehouses__city"]) == ["Berlin", "Paris", "Berlin"] -@pytest.mark.xfail(reason="Column 'users.id' could not be resolved for table: 'users'") def test_cross_dataset_join_chain_three_datasets( three_way_cross_dataset_duckdb: TCrossDs3Fixture, ) -> None: @@ -1585,7 +1639,28 @@ def test_cross_dataset_join_chain_three_datasets( assert list(df["subscriptions__plan"]) == ["enterprise", "enterprise", "free"] -@pytest.mark.xfail(reason="unqualified where column `quantity` can't be resolved") +def test_cross_dataset_join_chain_references_aliased_join_by_table_name( + three_way_cross_dataset_duckdb: TCrossDs3Fixture, +) -> None: + """A later `on` references an earlier join target by its table name even when that join + used a custom `alias`""" + ds_crm, ds_inv, ds_billing = 
three_way_cross_dataset_duckdb + + joined = ( + ds_inv.table("purchases") + .join(ds_crm.table("users"), on="purchases.user_id = users.id", alias="u") + .join(ds_billing.table("subscriptions"), on="users.id = subscriptions.user_id") + ) + df = joined.order_by("purchase_id").df() + + assert len(df) == 3 + # custom alias drives the projection prefix + assert "u__name" in df.columns + assert "users__name" not in df.columns + assert list(df["u__name"]) == ["Alice", "Alice", "Bob"] + assert list(df["subscriptions__plan"]) == ["enterprise", "enterprise", "free"] + + def test_cross_dataset_join_chain_filter_on_later_colliding_column( cross_dataset_duckdb: TCrossDsFixture, ) -> None: @@ -1602,7 +1677,6 @@ def test_cross_dataset_join_chain_filter_on_later_colliding_column( assert list(df["inventory_items__quantity"]) == [50] -@pytest.mark.xfail(reason="Column 'mkt_users.id' could not be resolved for table: 'mkt_users'") def test_cross_dataset_chain_same_named_tables_disambiguated( same_named_cross_dataset_duckdb: TCrossDsFixture, ) -> None: From 5bb6bf570b45905f1bd2a91fd10064a3428a25a9 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 1 Jun 2026 18:16:19 +0200 Subject: [PATCH 14/30] resolve explicit-join columns via natural qualifiers --- dlt/dataset/_join.py | 81 +++++++---------------------- tests/dataset/test_relation_join.py | 61 ++++++++-------------- 2 files changed, 41 insertions(+), 101 deletions(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 3bcff12f46..9357bcf672 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -15,15 +15,10 @@ from dlt.dataset.relation import Relation, TJoinType _INTERMEDIATE_JOIN_ALIAS_PREFIX = "_dlt_int_t" -_EXPLICIT_JOIN_ALIAS_PREFIX = "_dlt_jt_" _TExpr = TypeVar("_TExpr", bound=sge.Expression) -def _is_internal_join_alias(qualifier: str) -> bool: - return qualifier.startswith((_INTERMEDIATE_JOIN_ALIAS_PREFIX, _EXPLICIT_JOIN_ALIAS_PREFIX)) - - class _JoinRef(TypedDict): """A resolved join step from 
currently attached table to a target table.""" @@ -454,9 +449,9 @@ def _left_source_qualifier(query: sge.Query) -> Optional[str]: return None -def _existing_source_qualifier_map(query: sge.Query) -> dict[str, str]: - """Map each existing join input's user-facing qualifier to its SQL qualifier.""" - qualifier_map: dict[str, str] = {} +def _collect_source_qualifiers(query: sge.Query) -> Set[str]: + """Collect the SQL qualifiers (aliases or table names) of every FROM/JOIN source.""" + qualifiers: Set[str] = set() sources: list[sge.Expression] = [] from_expr = query.args.get("from_") or query.args.get("from") @@ -470,52 +465,14 @@ def _existing_source_qualifier_map(query: sge.Query) -> dict[str, str]: for source in sources: if isinstance(source, sge.Table): result = _extract_table_qualifier(source) - if not result: - continue - table_name, sql_qualifier = result - if sql_qualifier == table_name or _is_internal_join_alias(sql_qualifier): - # unaliased or internally aliased: the user references it by table name - qualifier_map[table_name] = sql_qualifier - else: - # an explicit alias replaces the table name as the usable qualifier - qualifier_map[sql_qualifier] = sql_qualifier + if result: + qualifiers.add(result[1]) elif isinstance(source, sge.Subquery): - sql_qualifier = _subquery_alias_name(source) - if sql_qualifier is None: - continue - source_qualifier = _left_source_qualifier(source.this) or sql_qualifier - qualifier_map[source_qualifier] = sql_qualifier - - return qualifier_map + alias_name = _subquery_alias_name(source) + if alias_name is not None: + qualifiers.add(alias_name) - -def _bind_on_predicate( - on_expr: sge.Expression, - *, - existing_qualifier_map: dict[str, str], - new_right_qualifiers: Set[str], - new_right_alias: str, -) -> sge.Expression: - """Rewrite column qualifiers in ``on_expr`` to the SQL qualifiers of the join inputs.""" - on_expr = on_expr.copy() - for col in on_expr.find_all(sge.Column): - table_node = col.args.get("table") - if not 
isinstance(table_node, sge.Identifier): - continue - qualifier = table_node.name - in_existing = qualifier in existing_qualifier_map - in_new = qualifier in new_right_qualifiers - if in_existing and in_new: - raise ValueError( - f"Ambiguous qualifier `{qualifier}` in join `on` expression: it matches both " - "the left and right side of the join. Alias one side (e.g. via `query(...)` " - "or the join `alias`) so each `on` qualifier is unambiguous." - ) - if in_new: - col.set("table", sge.to_identifier(new_right_alias, quoted=False)) - elif in_existing: - col.set("table", sge.to_identifier(existing_qualifier_map[qualifier], quoted=False)) - return on_expr + return qualifiers def _qualify_unscoped_predicate_columns(query: sge.Select, source_qualifier: str) -> None: @@ -578,7 +535,13 @@ def _apply_explicit_join( if not query.args.get("joins"): _qualify_unscoped_predicate_columns(query, left_source_qualifier) - internal_alias = f"{_EXPLICIT_JOIN_ALIAS_PREFIX}{projection_prefix}" + target_qualifier = target_table + if target_qualifier in _collect_source_qualifiers(query): + raise ValueError( + f"Join target qualifier `{target_qualifier}` already names a source in the query. " + "Alias one side (e.g. via `query('SELECT * FROM ... AS alias')`) so each `on` " + "qualifier is unambiguous." 
+ ) target_expr: sge.Expression if target is not None and target._query is not None: @@ -588,12 +551,11 @@ def _apply_explicit_join( rhs_inner = _qualify_physical_tables_with_dataset(rhs_inner, target_dataset_name) target_expr = sge.Subquery( this=rhs_inner, - alias=sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)), + alias=sge.TableAlias(this=sge.to_identifier(target_qualifier, quoted=False)), ) else: table_node_args: dict[str, sge.Expression] = { "this": sge.to_identifier(target_table, quoted=True), - "alias": sge.TableAlias(this=sge.to_identifier(internal_alias, quoted=False)), } if target_dataset_name: table_node_args["db"] = sge.to_identifier(target_dataset_name, quoted=False) @@ -604,13 +566,6 @@ def _apply_explicit_join( else: on_expr = on - on_expr = _bind_on_predicate( - on_expr, - existing_qualifier_map=_existing_source_qualifier_map(query), - new_right_qualifiers={target_table, projection_prefix}, - new_right_alias=internal_alias, - ) - join_expr = sge.Join(this=target_expr, kind=kind.upper()).on(on_expr) query = query.join(join_expr) @@ -618,7 +573,7 @@ def _apply_explicit_join( query, left_source_qualifier=left_source_qualifier, target_columns=target_columns, - target_qualifier=internal_alias, + target_qualifier=target_qualifier, projection_prefix=projection_prefix, allow_existing_target_projection=False, ) diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index f95923425b..b99d03e465 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -1035,15 +1035,13 @@ def _order_by_sort_key(rel: dlt.Relation) -> sge.Column: @pytest.mark.parametrize( - "build_join,order_column,expected_qualifier,expected_column", + "build_join,order_column", [ pytest.param( lambda ds: ds.table("customers").join( "orders", on="customers.customer_id = orders.customer_id" ), "orders__order_id", - "_dlt_jt_orders", - "order_id", id="default-prefix", ), pytest.param( @@ -1051,8 +1049,6 @@ 
def _order_by_sort_key(rel: dlt.Relation) -> sge.Column: "orders", on="customers.customer_id = orders.customer_id", alias="o" ), "o__order_id", - "_dlt_jt_o", - "order_id", id="custom-alias", ), ], @@ -1061,13 +1057,11 @@ def test_order_by_join_output_resolves_to_source_column( dataset_with_relational_tables: dlt.Dataset, build_join: Callable[[dlt.Dataset], dlt.Relation], order_column: str, - expected_qualifier: str, - expected_column: str, ) -> None: rel = build_join(dataset_with_relational_tables).order_by(order_column) sort_key = _order_by_sort_key(rel) - assert sort_key.table == expected_qualifier, f"bare alias leaked: {sort_key.sql()}" - assert sort_key.name == expected_column + assert sort_key.table == "orders" + assert sort_key.name == "order_id" def test_order_by_join_output_renders_resolvable_tsql( @@ -1082,7 +1076,7 @@ def test_order_by_join_output_renders_resolvable_tsql( 1 ] assert "[orders__order_id]" not in order_by - assert "[_dlt_jt_orders].[order_id]" in order_by + assert "[orders].[order_id]" in order_by def test_explicit_on_projection_alias_collision_rejected( @@ -1438,12 +1432,11 @@ def test_cross_dataset_join_with_same_table_names_keeps_sources_unambiguous( assert list(df["marketing__segment"]) == ["pro", "free"] -def test_cross_dataset_same_named_join_rejects_ambiguous_on_qualifier( +def test_cross_dataset_same_named_join_rejects_colliding_target( same_named_cross_dataset_duckdb: TCrossDsFixture, ) -> None: ds_crm, ds_marketing = same_named_cross_dataset_duckdb - - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="already names a source"): ds_crm.table("users").join( ds_marketing.table("users"), on="users.id = users.id", @@ -1620,44 +1613,36 @@ def test_cross_dataset_join_chain_four_tables( assert list(df["warehouses__city"]) == ["Berlin", "Paris", "Berlin"] +@pytest.mark.parametrize( + "alias,name_col,absent_name_col", + [ + pytest.param(None, "users__name", "u__name", id="default-prefix"), + pytest.param("u", "u__name", 
"users__name", id="custom-alias"), + ], +) def test_cross_dataset_join_chain_three_datasets( three_way_cross_dataset_duckdb: TCrossDs3Fixture, + alias: Optional[str], + name_col: str, + absent_name_col: str, ) -> None: + """A chain whose later `on` resolves an earlier join target by table name. An explicit + `alias` only prefixes the projected columns; it is not itself a join qualifier, so + `on="users.id = ..."` binds regardless of the alias chosen for the `users` join.""" ds_crm, ds_inv, ds_billing = three_way_cross_dataset_duckdb joined = ( ds_inv.table("purchases") - .join(ds_crm.table("users"), on="purchases.user_id = users.id") + .join(ds_crm.table("users"), on="purchases.user_id = users.id", alias=alias) .join(ds_billing.table("subscriptions"), on="users.id = subscriptions.user_id") ) df = joined.order_by("purchase_id").df() assert len(df) == 3 - assert "users__name" in df.columns + assert name_col in df.columns + assert absent_name_col not in df.columns assert "subscriptions__plan" in df.columns - assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] - assert list(df["subscriptions__plan"]) == ["enterprise", "enterprise", "free"] - - -def test_cross_dataset_join_chain_references_aliased_join_by_table_name( - three_way_cross_dataset_duckdb: TCrossDs3Fixture, -) -> None: - """A later `on` references an earlier join target by its table name even when that join - used a custom `alias`""" - ds_crm, ds_inv, ds_billing = three_way_cross_dataset_duckdb - - joined = ( - ds_inv.table("purchases") - .join(ds_crm.table("users"), on="purchases.user_id = users.id", alias="u") - .join(ds_billing.table("subscriptions"), on="users.id = subscriptions.user_id") - ) - df = joined.order_by("purchase_id").df() - - assert len(df) == 3 - # custom alias drives the projection prefix - assert "u__name" in df.columns - assert "users__name" not in df.columns - assert list(df["u__name"]) == ["Alice", "Alice", "Bob"] + assert list(df[name_col]) == ["Alice", "Alice", "Bob"] assert 
list(df["subscriptions__plan"]) == ["enterprise", "enterprise", "free"] From 5cc3c8ca4df3aea03a2f48ea281f7782044b9743 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 1 Jun 2026 18:34:07 +0200 Subject: [PATCH 15/30] skip lance and lancedb for now --- tests/load/test_relation_join.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/test_relation_join.py b/tests/load/test_relation_join.py index 4bd23c709f..4c7d4813a2 100644 --- a/tests/load/test_relation_join.py +++ b/tests/load/test_relation_join.py @@ -120,7 +120,7 @@ def cross_dataset_pipelines( ) -> Any: """Two pipelines on the same physical destination, distinct dataset names.""" _skip_unsupported_filesystem(destination_config) - if destination_config.destination_type == "filesystem": + if destination_config.destination_type in ("filesystem", "lance", "lancedb"): pytest.skip( "cross-dataset joins are not supported on filesystem destinations" " (see dlt/dataset/relation.py:_resolve_join_target)" From 3cb307fca9016cb837eba4cb9f89f34ef7296e43 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 1 Jun 2026 18:40:16 +0200 Subject: [PATCH 16/30] skip databricks until this is merged https://github.com/dlt-hub/dlt/pull/4011 --- tests/load/test_relation_join.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/load/test_relation_join.py b/tests/load/test_relation_join.py index 4c7d4813a2..95a2bae335 100644 --- a/tests/load/test_relation_join.py +++ b/tests/load/test_relation_join.py @@ -45,7 +45,11 @@ def destination_config( request: pytest.FixtureRequest, ) -> DestinationTestConfiguration: - return cast(DestinationTestConfiguration, request.param) + config = cast(DestinationTestConfiguration, request.param) + # TODO: remove once https://github.com/dlt-hub/dlt/pull/4011 is merged + if config.destination_type == "databricks": + pytest.skip("databricks foreign-key emission breaks this fixture. 
see dlt-hub/dlt#4011") + return config # TODO: same code in test_read_interfaces.py: factor out into a shared helper From 885b13b223c1b609c31c625b9e827d1fbb00e6b9 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 2 Jun 2026 13:42:09 +0200 Subject: [PATCH 17/30] fix dotted-string foreign join qualifier + test --- dlt/common/libs/sqlglot.py | 4 +- dlt/dataset/_join.py | 168 +++++++++--------- dlt/dataset/dataset.py | 2 +- dlt/dataset/relation.py | 132 +++++++------- .../general-usage/dataset-access/dataset.md | 2 + tests/dataset/test_relation_join.py | 123 +++++++++++++ tests/load/test_relation_join.py | 2 +- 7 files changed, 272 insertions(+), 161 deletions(-) diff --git a/dlt/common/libs/sqlglot.py b/dlt/common/libs/sqlglot.py index b00cfc7f0f..179c72557a 100644 --- a/dlt/common/libs/sqlglot.py +++ b/dlt/common/libs/sqlglot.py @@ -1086,9 +1086,9 @@ def bind_query( Args: qualified_query: SQLGlot query expression with qualified table/column references sqlglot_schema: Schema mapping for name validation and column resolution - expand_table_name: Function ``(table_name, dataset_name | None) -> [catalog, schema, table]`` + expand_table_name: Function `(table_name, dataset_name | None) -> [catalog, schema, table]` that expands a table name to a fully qualified path. The second argument is the - dataset qualifier from the query (``node.db``), or `None` for the default dataset. + dataset qualifier from the query (`node.db`), or `None` for the default dataset. 
casefold_identifier: Case transformation function (`str`, `str.upper`, or `str.lower`) Returns: diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 9357bcf672..34f5bfc78f 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import reduce -from typing import TYPE_CHECKING, Any, Optional, Sequence, Set, TypeVar, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, Set, TypeVar, Union import sqlglot import sqlglot.expressions as sge @@ -12,13 +12,26 @@ from dlt.common.libs.sqlglot import TSqlGlotDialect if TYPE_CHECKING: - from dlt.dataset.relation import Relation, TJoinType + from dlt.dataset.relation import TJoinType _INTERMEDIATE_JOIN_ALIAS_PREFIX = "_dlt_int_t" _TExpr = TypeVar("_TExpr", bound=sge.Expression) +class _JoinTarget(NamedTuple): + """Resolved right-hand side of a `Relation.join()`.""" + + dataset_name: str + is_foreign: bool + """`True` when the target lives in a different dataset than the left-hand side.""" + table_name: str + columns: TTableSchemaColumns + schemas: Sequence[Schema] + subquery: Optional[sge.Query] = None + """RHS query embedded as a derived table for transformed relations; `None` for base tables.""" + + class _JoinRef(TypedDict): """A resolved join step from currently attached table to a target table.""" @@ -190,14 +203,33 @@ def _extract_table_qualifier(table_expr: sge.Expression) -> Optional[tuple[str, return table_name, table_name -def _extract_joined_table_aliases(query: sge.Query) -> dict[str, str]: - alias_map: dict[str, str] = {} +def _from_source(query: sge.Query) -> Optional[sge.Expression]: + """Return the FROM source expression (table or subquery), or `None`.""" # sqlglot >= 28 renamed `from` to `from_` internally from_expr = query.args.get("from_") or query.args.get("from") - if not isinstance(from_expr, sge.From) or not isinstance(from_expr.this, sge.Table): + if not isinstance(from_expr, sge.From): + return None 
+ source: Optional[sge.Expression] = from_expr.this + return source + + +def _source_qualifier(source: Optional[sge.Expression]) -> Optional[str]: + """Return the SQL qualifier (alias or table name) of a FROM/JOIN source.""" + if isinstance(source, sge.Table): + result = _extract_table_qualifier(source) + return result[1] if result else None + if isinstance(source, sge.Subquery): + return _subquery_alias_name(source) + return None + + +def _extract_joined_table_aliases(query: sge.Query) -> dict[str, str]: + alias_map: dict[str, str] = {} + from_this = _from_source(query) + if not isinstance(from_this, sge.Table): return alias_map - tables: list[sge.Table] = [from_expr.this] + tables: list[sge.Table] = [from_this] for join in query.args.get("joins") or []: if isinstance(join.this, sge.Table): tables.append(join.this) @@ -356,6 +388,14 @@ def _apply_join_projection( query.set("expressions", [*normalized_left_expressions, *appended_target_columns]) +def _copy_as_select(expression: sge.Query) -> sge.Select: + """Copy `expression` and assert it is a SELECT so a join can be applied.""" + query = expression.copy() + if not isinstance(query, sge.Select): + raise ValueError(f"Join query `{query}` must be an SQL SELECT statement.") + return query + + def _apply_join( expression: sge.Query, *, @@ -367,21 +407,15 @@ def _apply_join( project: bool = True, ) -> sge.Select: """Apply schema-driven join(s) to `expression` and return the new query.""" - # `project=False` adds the JOIN without touching the SELECT list — for join targets whose - # columns are referenced in WHERE/ON predicates but should not appear in the output if left_table not in schema.tables: raise ValueError(f"Table `{left_table}` not found in dataset schema") if right_table not in schema.tables: raise ValueError(f"Table `{right_table}` not found in dataset schema") - query = expression.copy() - if not isinstance(query, sge.Select): - raise ValueError(f"Join query `{query}` must be an SQL SELECT statement.") + 
query = _copy_as_select(expression) - # qualify its bare WHERE/ORDER BY columns so they survive a later join - # that introduces a same-named column - if not query.args.get("joins"): - _qualify_unscoped_predicate_columns(query, _left_source_qualifier(query) or left_table) + left_source_qualifier = _left_source_qualifier(query) or left_table + _qualify_unscoped_predicate_columns(query, left_source_qualifier) join_params, target_qualifier = _discover_join_params( query, @@ -403,8 +437,6 @@ def _apply_join( ) query = query.join(join_expr) - left_source_qualifier = _left_source_qualifier(query) or left_table - if project: _apply_join_projection( query, @@ -415,14 +447,12 @@ def _apply_join( allow_existing_target_projection=not join_params, ) else: - # filter-only join: qualify the left projection so a bare `*` does not - # expand across the joined table and leak right-side columns at runtime. query.set("expressions", _normalize_left_projection(query, left_source_qualifier)) return query def _qualify_physical_tables_with_dataset(expression: _TExpr, dataset_name: str) -> _TExpr: - """Bind every physical table reference in ``expression`` to ``dataset_name``.""" + """Bind every physical table reference in `expression` to `dataset_name`.""" expression = expression.copy() cte_names = {cte.alias_or_name for cte in expression.find_all(sge.CTE)} db_identifier = sge.to_identifier(dataset_name, quoted=False) @@ -437,46 +467,19 @@ def _qualify_physical_tables_with_dataset(expression: _TExpr, dataset_name: str) def _left_source_qualifier(query: sge.Query) -> Optional[str]: """Return the qualifier used to reference the FROM source (alias or table name).""" - from_expr = query.args.get("from_") or query.args.get("from") - if not isinstance(from_expr, sge.From): - return None - from_this = from_expr.this - if isinstance(from_this, sge.Table): - result = _extract_table_qualifier(from_this) - return result[1] if result else None - if isinstance(from_this, sge.Subquery): - return 
_subquery_alias_name(from_this) - return None + return _source_qualifier(_from_source(query)) def _collect_source_qualifiers(query: sge.Query) -> Set[str]: """Collect the SQL qualifiers (aliases or table names) of every FROM/JOIN source.""" - qualifiers: Set[str] = set() - sources: list[sge.Expression] = [] - - from_expr = query.args.get("from_") or query.args.get("from") - if isinstance(from_expr, sge.From) and from_expr.this is not None: - sources.append(from_expr.this) - - for join in query.args.get("joins") or []: - if join.this is not None: - sources.append(join.this) - - for source in sources: - if isinstance(source, sge.Table): - result = _extract_table_qualifier(source) - if result: - qualifiers.add(result[1]) - elif isinstance(source, sge.Subquery): - alias_name = _subquery_alias_name(source) - if alias_name is not None: - qualifiers.add(alias_name) - - return qualifiers + sources = [_from_source(query), *(join.this for join in query.args.get("joins") or [])] + return {qualifier for source in sources if (qualifier := _source_qualifier(source)) is not None} def _qualify_unscoped_predicate_columns(query: sge.Select, source_qualifier: str) -> None: - """Bind unqualified columns in pre-join WHERE/ORDER BY clauses to the single source.""" + """Bind unqualified WHERE/ORDER BY columns to the single source.""" + if query.args.get("joins"): + return qualifier_identifier = sge.to_identifier(source_qualifier, quoted=False) for clause_key in ("where", "order"): clause = query.args.get(clause_key) @@ -489,53 +492,38 @@ def _qualify_unscoped_predicate_columns(query: sge.Select, source_qualifier: str def _apply_explicit_join( expression: sge.Query, + target: _JoinTarget, *, - target: Optional["Relation"] = None, - target_table: str, - target_dataset_name: Optional[str], - target_columns: TTableSchemaColumns, on: Union[str, sge.Expression], projection_prefix: str, - kind: "TJoinType", + kind: TJoinType, destination_dialect: TSqlGlotDialect, left_dataset_name: str, ) -> 
sge.Select: - """Apply an explicit-ON join to ``expression`` and return the new query. + """Apply an explicit-ON join to `expression` and return the new query. Args: expression: Left-side query to join onto. - target: Right-hand Relation object (if transformed/subquery), or None for - string / base-table targets. - target_table: Bare table name for schema lookups and projection. - target_dataset_name: Dataset name for the right-hand side. - target_columns: Columns from the right-hand side for projection. + target: Resolved right-hand side of the join. on: Join condition as a SQL string or sqlglot expression. projection_prefix: Prefix for appended column aliases. kind: SQL join type. destination_dialect: Dialect for parsing string ON expressions. left_dataset_name: Dataset name for the left-hand side. """ - query = expression.copy() - if not isinstance(query, sge.Select): - raise ValueError(f"Join query `{query}` must be an SQL SELECT statement.") + query = _qualify_physical_tables_with_dataset(_copy_as_select(expression), left_dataset_name) - # bind LHS physical tables to the LHS dataset before composing the join. - # otherwise, adding the RHS dataset to the resolver makes bare LHS tables - # ambiguous - query = _qualify_physical_tables_with_dataset(query, left_dataset_name) - - from_expr = query.args.get("from_") or query.args.get("from") - if not isinstance(from_expr, sge.From) or not isinstance(from_expr.this, sge.Table): + from_this = _from_source(query) + if not isinstance(from_this, sge.Table): raise ValueError( "Cannot apply explicit join: left-side query must have a base table " "in its FROM clause (not a subquery or derived table)." 
) - left_source_qualifier = _left_source_qualifier(query) or from_expr.this.name + left_source_qualifier = _source_qualifier(from_this) or from_this.name - if not query.args.get("joins"): - _qualify_unscoped_predicate_columns(query, left_source_qualifier) + _qualify_unscoped_predicate_columns(query, left_source_qualifier) - target_qualifier = target_table + target_qualifier = target.table_name if target_qualifier in _collect_source_qualifiers(query): raise ValueError( f"Join target qualifier `{target_qualifier}` already names a source in the query. " @@ -543,10 +531,12 @@ def _apply_explicit_join( "qualifier is unambiguous." ) + target_dataset_name = target.dataset_name if target.is_foreign else None + target_expr: sge.Expression - if target is not None and target._query is not None: - # transformed Relation: embed as subquery - rhs_inner = target.sqlglot_expression + if target.subquery is not None: + # transformed relation: embed its query as a subquery + rhs_inner = target.subquery if target_dataset_name: rhs_inner = _qualify_physical_tables_with_dataset(rhs_inner, target_dataset_name) target_expr = sge.Subquery( @@ -554,12 +544,14 @@ def _apply_explicit_join( alias=sge.TableAlias(this=sge.to_identifier(target_qualifier, quoted=False)), ) else: - table_node_args: dict[str, sge.Expression] = { - "this": sge.to_identifier(target_table, quoted=True), - } - if target_dataset_name: - table_node_args["db"] = sge.to_identifier(target_dataset_name, quoted=False) - target_expr = sge.Table(**table_node_args) + target_expr = sge.Table( + this=sge.to_identifier(target.table_name, quoted=True), + db=( + sge.to_identifier(target_dataset_name, quoted=False) + if target_dataset_name + else None + ), + ) if isinstance(on, str): on_expr = sqlglot.parse_one(on, dialect=destination_dialect) @@ -572,7 +564,7 @@ def _apply_explicit_join( _apply_join_projection( query, left_source_qualifier=left_source_qualifier, - target_columns=target_columns, + target_columns=target.columns, 
target_qualifier=target_qualifier, projection_prefix=projection_prefix, allow_existing_target_projection=False, diff --git a/dlt/dataset/dataset.py b/dlt/dataset/dataset.py index fe478f4a44..36d2d2cfab 100644 --- a/dlt/dataset/dataset.py +++ b/dlt/dataset/dataset.py @@ -162,7 +162,7 @@ def _ipython_key_completions_(self) -> list[str]: def _is_same_dataset(self, other: dlt.Dataset) -> bool: """Whether `other` represents the same logical dataset.""" # TODO currently only compares dataset name, - # once harderned, conside implementing __eq__ based on this method + # once hardened, consider implementing __eq__ based on this method return self.dataset_name == other.dataset_name @property diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index ce08f676c8..dca8bc571c 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -51,6 +51,7 @@ _apply_join, _apply_explicit_join, _extract_joined_table_aliases, + _JoinTarget, _left_source_qualifier, ) @@ -414,14 +415,14 @@ def join( ) -> Self: """Join this relation to another table. - Without ``on``, join conditions are discovered automatically from the + Without `on`, join conditions are discovered automatically from the schema's reference chain (parent/child/root relationships created by - dlt during loading). With ``on``, an explicit join predicate is used + dlt during loading). With `on`, an explicit join predicate is used instead — this also enables cross-dataset joins. Args: other: Table name or Relation to join. For cross-dataset joins, - pass a Relation from a different ``dlt.Dataset``. + pass a Relation from a different `dlt.Dataset`. on: Explicit join condition as an SQL string or sqlglot expression. Required for cross-dataset joins and joins between tables without dlt schema references. @@ -438,7 +439,7 @@ def join( Raises: ValueError: If the join cannot be resolved. 
- Example:: + Example: # auto join (schema references) dataset["orders"].join("users") @@ -463,43 +464,33 @@ def join( if isinstance(on, str) and not on.strip(): raise ValueError("`on` must be a non-empty SQL expression.") - target_dataset, target_table, target_columns = self._resolve_join_target(other, on=on) - - is_same_dataset = self._dataset._is_same_dataset(target_dataset) + target = self._resolve_join_target(other, on=on) # self-join detection - if target_table == self._table_name and is_same_dataset: + if target.table_name == self._table_name and not target.is_foreign: raise ValueError("Self-joins are not supported.") - projection_prefix = alias or target_table + projection_prefix = alias or target.table_name if on is None: if not self._table_name: raise ValueError("This relation has no base table to resolve references.") - if not is_same_dataset: + if target.is_foreign: raise ValueError("`on` is required when joining relations from different datasets.") - if target_table not in self._dataset.schema.tables: - raise ValueError(f"Table `{target_table}` not found in dataset schema") + if target.table_name not in self._dataset.schema.tables: + raise ValueError(f"Table `{target.table_name}` not found in dataset schema") query = _apply_join( self.sqlglot_expression, schema=self._dataset.schema, left_table=self._table_name, - right_table=target_table, + right_table=target.table_name, projection_prefix=projection_prefix, kind=kind, ) else: - # pass Relation as target when it's been transformed so it - # becomes a subquery (preserving WHERE, SELECT, LIMIT, etc.) 
- subquery_rhs: Optional[Relation] = ( - other if isinstance(other, dlt.Relation) and other._query is not None else None - ) query = _apply_explicit_join( self.sqlglot_expression, - target=subquery_rhs, - target_table=target_table, - target_dataset_name=(None if is_same_dataset else target_dataset.dataset_name), - target_columns=target_columns, + target, on=on, projection_prefix=projection_prefix, kind=kind, @@ -516,8 +507,8 @@ def join( if ds_name == self._dataset.dataset_name: continue rel._foreign_schemas[ds_name] = list(schemas) - if not is_same_dataset: - rel._foreign_schemas[target_dataset.dataset_name] = list(target_dataset.schemas) + if target.is_foreign: + rel._foreign_schemas[target.dataset_name] = list(target.schemas) return rel @@ -526,27 +517,22 @@ def _resolve_join_target( other: Union[str, Self], *, on: Union[str, sge.Expression, None] = None, - ) -> tuple[dlt.Dataset, str, TTableSchemaColumns]: - """Resolve the target dataset, table name, and columns for a join. - - Returns: - Tuple of (target_dataset, target_table_name, target_columns). - """ + ) -> _JoinTarget: + """Resolve the right-hand side of a join into a `_JoinTarget`.""" if isinstance(other, dlt.Relation): target_dataset = other._dataset - # physical destination check - if not self._dataset._is_same_dataset(target_dataset): - if not self._dataset.is_same_physical_destination(target_dataset): - raise ValueError( - "Cannot join relations from different physical destinations: " - f"'{target_dataset.dataset_name}' vs '{self._dataset.dataset_name}'" - ) - # cross-dataset filesystem not supported - if isinstance(self.sql_client, WithSchemas): - raise ValueError( - "Cross-dataset joins are not supported on filesystem destinations." 
- ) + if not self._dataset.is_same_physical_destination(target_dataset): + raise ValueError( + "Cannot join relations from different physical destinations: " + f"'{target_dataset.dataset_name}' vs '{self._dataset.dataset_name}'" + ) + + is_foreign = not self._dataset._is_same_dataset(target_dataset) + if is_foreign and isinstance(self.sql_client, WithSchemas): + raise ValueError( + "Cross-dataset joins are not supported on filesystem destinations." + ) target_table = other._table_name is_transformed = other._query is not None @@ -561,33 +547,46 @@ def _resolve_join_target( # no base table at all (e.g., from .query()) if on is None: raise ValueError(f"Relation `{other}` has no base table to resolve references.") - target_table = _extract_subquery_alias(other) + target_table = _left_source_qualifier(other.sqlglot_expression) or "subquery" target_columns = other.columns_schema - elif isinstance(other, str): + return _JoinTarget( + target_dataset.dataset_name, + is_foreign, + target_table, + target_columns, + target_dataset.schemas, + subquery=other.sqlglot_expression if is_transformed else None, + ) + + if isinstance(other, str): if "." in other: ds_name, tbl_name = other.split(".", 1) - if ds_name == self._dataset.dataset_name: - target_dataset = self._dataset - target_table = tbl_name - target_columns = _find_table_columns(target_dataset.schemas, target_table) - elif ds_name in self._foreign_schemas: - target_dataset = self._dataset - target_table = tbl_name - target_columns = _find_table_columns(self._foreign_schemas[ds_name], tbl_name) - return target_dataset, target_table, target_columns - else: - raise ValueError( - f"Dataset `{ds_name}` is not registered. Pass a Relation from the " - "foreign dataset to automatically register its schema." 
- ) else: - target_dataset = self._dataset - target_table = other - target_columns = _find_table_columns(target_dataset.schemas, target_table) - else: - raise ValueError("`other` must be a table name or a base table relation.") + ds_name, tbl_name = self._dataset.dataset_name, other + + if ds_name == self._dataset.dataset_name: + return _JoinTarget( + ds_name, + False, + tbl_name, + _find_table_columns(self._dataset.schemas, tbl_name), + self._dataset.schemas, + ) + if ds_name in self._foreign_schemas: + foreign_schemas = self._foreign_schemas[ds_name] + return _JoinTarget( + ds_name, + True, + tbl_name, + _find_table_columns(foreign_schemas, tbl_name), + foreign_schemas, + ) + raise ValueError( + f"Dataset `{ds_name}` is not registered. Pass a Relation from the " + "foreign dataset to automatically register its schema." + ) - return target_dataset, target_table, target_columns + raise ValueError("`other` must be a table name or a base table relation.") def incremental(self, incremental: Incremental[Any]) -> Self: """Filter this relation to a cursor range using an Incremental. 
@@ -1044,8 +1043,3 @@ def _find_table_columns(schemas: Sequence[dlt.Schema], table_name: str) -> TTabl if table_name in schema.tables: return schema.tables[table_name]["columns"] raise ValueError(f"Table `{table_name}` not found in dataset schema") - - -def _extract_subquery_alias(relation: dlt.Relation) -> str: - """Extract the source qualifier of a transformed Relation without a base table.""" - return _left_source_qualifier(relation.sqlglot_expression) or "subquery" diff --git a/docs/website/docs/general-usage/dataset-access/dataset.md b/docs/website/docs/general-usage/dataset-access/dataset.md index 03273ae725..06353b7e30 100644 --- a/docs/website/docs/general-usage/dataset-access/dataset.md +++ b/docs/website/docs/general-usage/dataset-access/dataset.md @@ -216,6 +216,8 @@ Pass `on=` to write the join condition yourself, as a SQL string or a `sqlglot` The right-hand side can be a table name, a table relation, or a relation you already transformed with `select()`, `where()`, etc. When you pass a transformed relation, its filters and column selection carry over to the joined result. +Refer to the right-hand side in `on` by its source qualifier: the joined table's name, or the alias you gave it in a `dataset.query(...)`. A relation with no identifiable source, for example a constant `dataset.query("SELECT 1 AS id")` that has no `FROM`, is exposed under the qualifier `subquery`, so prefix its columns with `subquery.` in `on` (for example `subquery.id`). + The left-hand side can be a table relation, a relation chained from one with `where()`, `select()`, `order_by()`, and similar methods, or a `dataset.query("...")` that reads from a single table. Self-joins are not supported, even with explicit `on`. For self-joins, multi-way joins with mixed conditions, or fully programmatic join construction, use [Ibis](#modifying-queries-with-ibis-expressions).
diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index b99d03e465..2ac6877b3a 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -234,6 +234,36 @@ def other_data(): rel.join(other_rel, on="users._dlt_id = other_data._dlt_id") +def test_join_rejects_same_name_on_different_physical_destinations() -> None: + with tempfile.TemporaryDirectory() as tmp: + tmp_path = pathlib.Path(tmp) + shared_dataset_name = "same_name_diff_dest" + + pipeline_a = dlt.pipeline( + pipeline_name="same_name_diff_dest_a", + pipelines_dir=str(tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(str(tmp_path / "a.duckdb")), + dataset_name=shared_dataset_name, + ) + pipeline_a.run([{"id": 1}], table_name="users") + + pipeline_b = dlt.pipeline( + pipeline_name="same_name_diff_dest_b", + pipelines_dir=str(tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(str(tmp_path / "b.duckdb")), + dataset_name=shared_dataset_name, + ) + pipeline_b.run([{"oid": 10}], table_name="orders") + + ds_a = pipeline_a.dataset() + ds_b = pipeline_b.dataset() + assert ds_a.dataset_name == ds_b.dataset_name + assert not ds_a.is_same_physical_destination(ds_b) + + with pytest.raises(ValueError, match="different physical destinations"): + ds_a.table("users").join(ds_b.table("orders"), on="users.id = orders.user_id") + + @pytest.mark.parametrize( "dataset_with_loads,left,right,expected_targets", [ @@ -944,6 +974,20 @@ def test_explicit_on_joins_relational_tables( ds.table("customers").join("orders") +def test_explicit_on_join_via_local_dotted_string( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + joined = ds.table("customers").join( + f"{ds.dataset_name}.orders", on="customers.customer_id = orders.customer_id" + ) + assert not joined._foreign_schemas + df = joined.df() + assert len(df) == 4 + assert "orders__amount" in df.columns + assert list(df["orders__amount"]) 
== [50.0, 75.0, 200.0, 30.0] + + def test_explicit_on_accepts_sqlglot_expression( dataset_with_relational_tables: dlt.Dataset, ) -> None: @@ -1184,6 +1228,22 @@ def test_explicit_on_with_aliased_query_relations( assert list(df["o__amount"]) == [50.0, 75.0, 200.0, 30.0] +def test_explicit_on_with_constant_rhs_uses_subquery_fallback_qualifier( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + const = ds.query( + "SELECT 1 AS customer_id, 'x' AS tag" + ) # no FROM clause, no qualifier, falls back to subquery + joined = ds.table("customers").join(const, on="customers.customer_id = subquery.customer_id") + + assert "subquery__tag" in joined.columns_schema + df = joined.df() + assert len(df) == 1 + assert list(df["name"]) == ["Alice"] + assert list(df["subquery__tag"]) == ["x"] + + def test_explicit_on_rejects_empty_alias( dataset_with_relational_tables: dlt.Dataset, ) -> None: @@ -1229,6 +1289,25 @@ def test_explicit_on_rejects_unknown_kind( ) +def test_explicit_on_rejects_unknown_dotted_string_dataset( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + with pytest.raises(ValueError, match="is not registered"): + ds.table("customers").join( + "unknown_ds.orders", on="customers.customer_id = orders.customer_id" + ) + + +def test_explicit_on_rejects_subquery_from_lhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + derived = ds.query("SELECT * FROM (SELECT * FROM customers) AS sub") + with pytest.raises(ValueError, match="must have a base table"): + derived.join("orders", on="sub.customer_id = orders.customer_id") + + @pytest.mark.parametrize( "name_normalizer_ref", ( @@ -1282,6 +1361,24 @@ def test_cross_dataset_join( assert sorted(df["purchases__sku"]) == ["G-001", "W-001", "W-001"] +def test_cross_dataset_join_accepts_sqlglot_expression( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_inv = 
cross_dataset_duckdb + on_expr = sge.EQ( + this=sge.Column(table=sge.to_identifier("users"), this=sge.to_identifier("id")), + expression=sge.Column( + table=sge.to_identifier("purchases"), this=sge.to_identifier("user_id") + ), + ) + joined = ds_crm.table("users").join(ds_inv.table("purchases"), on=on_expr) + + assert ds_inv.dataset_name in joined._foreign_schemas + df = joined.df() + assert len(df) == 3 + assert sorted(df["purchases__sku"]) == ["G-001", "W-001", "W-001"] + + def test_cross_dataset_join_requires_on( cross_dataset_duckdb: TCrossDsFixture, ) -> None: @@ -1533,6 +1630,32 @@ def test_cross_dataset_join_then_foreign_dataset_local_hop_with_relation( assert list(df["inventory_items__quantity"]) == [50, 30, 50] +def test_cross_dataset_join_via_dotted_string_qualifies_foreign_dataset( + cross_dataset_duckdb: TCrossDsFixture, +) -> None: + ds_crm, ds_inv = cross_dataset_duckdb + + joined = ds_crm.table("users").join( + ds_inv.table("purchases"), on="users.id = purchases.user_id" + ) + assert ds_inv.dataset_name in joined._foreign_schemas + + chained = joined.join( + f"{ds_inv.dataset_name}.inventory_items", + on="purchases.sku = inventory_items.sku", + ) + sql = chained.to_sql() + + assert f'"{ds_inv.dataset_name}"."inventory_items"' in sql, sql + assert f'"{ds_crm.dataset_name}"."inventory_items"' not in sql, sql + + df = chained.order_by("purchases__purchase_id").df() + assert list(df["purchases__purchase_id"]) == [1, 2, 3] + assert list(df["name"]) == ["Alice", "Alice", "Bob"] + assert list(df["purchases__sku"]) == ["W-001", "G-001", "W-001"] + assert list(df["inventory_items__quantity"]) == [50, 30, 50] + + def test_cross_dataset_join_chain_with_filtered_step( cross_dataset_duckdb: TCrossDsFixture, ) -> None: diff --git a/tests/load/test_relation_join.py b/tests/load/test_relation_join.py index 95a2bae335..e9413e5c2b 100644 --- a/tests/load/test_relation_join.py +++ b/tests/load/test_relation_join.py @@ -1,4 +1,4 @@ -"""End-to-end tests for 
``Relation.join()`` across destinations.""" +"""End-to-end tests for `Relation.join()` across destinations.""" import os from typing import Any, cast, Tuple From b992b8daf2421da2811668f49d86188ebd2c175e Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 2 Jun 2026 15:56:20 +0200 Subject: [PATCH 18/30] support self-joins via distinct qualifiers --- dlt/dataset/_join.py | 6 +- dlt/dataset/relation.py | 4 -- .../general-usage/dataset-access/dataset.md | 12 +++- tests/dataset/test_relation_join.py | 58 ++++++++++++++++--- tests/dataset/utils.py | 20 ++++++- 5 files changed, 83 insertions(+), 17 deletions(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 34f5bfc78f..adf9dc794a 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -127,7 +127,11 @@ def _resolve_parent_reference_chain(schema: Schema, left: str, right: str) -> li def _resolve_reference_chain(schema: Schema, left: str, right: str) -> list[_JoinRef]: """Resolve ordered join steps between two tables.""" if left == right: - raise ValueError(f"Cannot join a table to itself: {left}") + raise ValueError( + f"Cannot join table `{left}` to itself via schema references. Use an explicit " + "`on=` predicate and alias one side (e.g. via `query('SELECT * FROM ... AS alias')`) " + "to self-join." 
+ ) # Check direct references first for ref in schema.references: diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index dca8bc571c..687a9c80d1 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -466,10 +466,6 @@ def join( target = self._resolve_join_target(other, on=on) - # self-join detection - if target.table_name == self._table_name and not target.is_foreign: - raise ValueError("Self-joins are not supported.") - projection_prefix = alias or target.table_name if on is None: diff --git a/docs/website/docs/general-usage/dataset-access/dataset.md b/docs/website/docs/general-usage/dataset-access/dataset.md index 06353b7e30..f3850fb6a4 100644 --- a/docs/website/docs/general-usage/dataset-access/dataset.md +++ b/docs/website/docs/general-usage/dataset-access/dataset.md @@ -220,7 +220,17 @@ Refer to the right-hand side in `on` by its source qualifier: the joined table's The left-hand side can be a table relation, a relation chained from one with `where()`, `select()`, `order_by()`, and similar methods, or a `dataset.query("...")` that reads from a single table. -Self-joins are not supported, even with explicit `on`. For self-joins, multi-way joins with mixed conditions, or fully programmatic join construction, use [Ibis](#modifying-queries-with-ibis-expressions). +Self-joins work with explicit `on`, but the two instances of the table need distinct SQL qualifiers so the predicate can tell them apart. Alias one side with a `dataset.query(...)` and refer to that alias in `on`: + +```py +# attach each employee's manager from the same table +managers = dataset.query("SELECT * FROM employees AS managers") +with_managers = dataset["employees"].join( + managers, on="employees.manager_id = managers.id", kind="left" +) +``` + +Joining a base table directly to itself (as in `dataset["employees"].join("employees", ...)`) is rejected, because both sides would share the `employees` qualifier. 
#### Cross-dataset joins diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index 2ac6877b3a..db4c504888 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -204,7 +204,7 @@ def test_build_join_condition_rejects_empty_pairs() -> None: def test_resolve_reference_chain_rejects_self_join(dataset_with_loads: TLoadsFixture) -> None: dataset, _, _ = dataset_with_loads - with pytest.raises(ValueError, match="Cannot join a table to itself"): + with pytest.raises(ValueError, match="to itself"): _resolve_reference_chain(dataset.schema, "users", "users") @@ -329,7 +329,7 @@ def test_resolve_reference_chain_rejects_unrelated_tables( pytest.param( lambda ds: ds.table("users"), "users", - "Self-joins are not supported", + "to itself", id="self-join", ), pytest.param( @@ -1254,16 +1254,56 @@ def test_explicit_on_rejects_empty_alias( ) -def test_explicit_on_rejects_self_join( +@pytest.mark.parametrize( + "build_join,expected_error", + [ + pytest.param( + lambda ds: ds.table("employees").join( + "employees", on="employees.manager_id = employees.id", alias="mgr" + ), + "already names a source", + id="direct-base-self-join-rejected", + ), + pytest.param( + lambda ds: ds.query("SELECT * FROM employees AS e1").join( + "employees", on="e1.manager_id = employees.id", alias="mgr" + ), + None, + id="aliased-query-lhs-base-rhs", + ), + pytest.param( + lambda ds: ds.query("SELECT * FROM employees AS e1").join( + ds.query("SELECT * FROM employees AS e2"), + on="e1.manager_id = e2.id", + alias="mgr", + ), + None, + id="both-aliased-query", + ), + pytest.param( + lambda ds: ds.table("employees").join( + ds.query("SELECT * FROM employees AS mgr"), + on="employees.manager_id = mgr.id", + ), + None, + id="base-lhs-aliased-query-rhs", + ), + ], +) +def test_self_join_requires_distinct_qualifiers( dataset_with_relational_tables: dlt.Dataset, + build_join: Callable[[dlt.Dataset], dlt.Relation], + expected_error: 
Optional[str], ) -> None: ds = dataset_with_relational_tables - with pytest.raises(ValueError, match="Self-joins are not supported"): - ds.table("customers").join( - "customers", - on="customers.customer_id = customers.customer_id", - alias="c2", - ) + if expected_error is not None: + with pytest.raises(ValueError, match=expected_error): + build_join(ds) + return + + df = build_join(ds).df() + assert sorted(df["name"]) == ["Bob", "Carol"] + assert sorted(df["mgr__name"]) == ["Alice", "Alice"] @pytest.mark.parametrize("on", ["", " "], ids=["empty", "whitespace"]) diff --git a/tests/dataset/utils.py b/tests/dataset/utils.py index dc5c40be27..2626708885 100644 --- a/tests/dataset/utils.py +++ b/tests/dataset/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TypedDict +from typing import Optional, TypedDict import dlt @@ -97,6 +97,12 @@ class CountryRow(TypedDict): name: str +class EmployeeRow(TypedDict): + id: int # noqa: A003 + name: str + manager_id: Optional[int] + + TLoadStats = dict[str, int] TLoadsFixture = tuple[dlt.Dataset, tuple[str, str], tuple[TLoadStats, TLoadStats]] TCrossDsFixture = tuple[dlt.Dataset, dlt.Dataset] @@ -288,6 +294,12 @@ def subscriptions(): {"code": "ES", "name": "Spain"}, ] +EMPLOYEES: list[EmployeeRow] = [ + {"id": 1, "name": "Alice", "manager_id": None}, + {"id": 2, "name": "Bob", "manager_id": 1}, + {"id": 3, "name": "Carol", "manager_id": 1}, +] + @dlt.source def relational_tables(): @@ -303,7 +315,11 @@ def orders(): def countries(): yield COUNTRIES - return [customers(), orders(), countries()] + @dlt.resource(name="employees") + def employees(): + yield EMPLOYEES + + return [customers(), orders(), countries(), employees()] @dlt.source From 6b2a2c669ef540836b37010cb6f7ceeb18024808 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 2 Jun 2026 16:00:42 +0200 Subject: [PATCH 19/30] Revert "preserve dlt-namespace case in order by, group by, etc after bind_query, 
fixes snowflake" This reverts commit 5328c020490608a25ac5c0820822548da26f35f2. --- dlt/common/libs/sqlglot.py | 20 ------- tests/destinations/test_queries.py | 93 ------------------------------ 2 files changed, 113 deletions(-) diff --git a/dlt/common/libs/sqlglot.py b/dlt/common/libs/sqlglot.py index 179c72557a..434610b596 100644 --- a/dlt/common/libs/sqlglot.py +++ b/dlt/common/libs/sqlglot.py @@ -1044,21 +1044,6 @@ def normalize_query_identifiers( return query -def _restore_alias_case_in_clauses(query: sge.Query, alias_rename_map: Dict[str, str]) -> None: - """Rewrite bare-column references in ORDER BY / GROUP BY / HAVING back to the original - (un-casefolded) alias case so they match SELECT aliases preserved by `bind_query`.""" - for clause_key in ("order", "group", "having"): - clause = query.args.get(clause_key) - if clause is None: - continue - for col in clause.find_all(sge.Column): - if col.args.get("table") is not None: - continue - name_node = col.this - if name_node.name in alias_rename_map: - name_node.set("this", alias_rename_map[name_node.name]) - - def bind_query( qualified_query: sge.Query, sqlglot_schema: Any, # SQLGlotSchema @@ -1128,17 +1113,12 @@ def bind_query( node.set("quoted", True) # add aliases to output selects to stay compatible with dlt schema after the query - alias_rename_map: Dict[str, str] = {} if orig_selects: for i, orig in orig_selects.items(): case_folded_orig = casefold_identifier(orig) if case_folded_orig != orig: - alias_rename_map[case_folded_orig] = orig # somehow we need to alias just top select in UNION (tested on Snowflake) sel_expr = qualified_query.selects[i] qualified_query.selects[i] = sge.alias_(sel_expr, orig, quoted=True) - if alias_rename_map: - _restore_alias_case_in_clauses(qualified_query, alias_rename_map) - return qualified_query diff --git a/tests/destinations/test_queries.py b/tests/destinations/test_queries.py index bc635c4447..757e614701 100644 --- a/tests/destinations/test_queries.py +++ 
b/tests/destinations/test_queries.py @@ -13,16 +13,6 @@ from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration -_BIND_QUERY_SCHEMA = SQLGlotSchema( - { - "my_dataset": { - "customers": {"customer_id": str, "country_code": str}, - "orders": {"order_id": str, "customer_id": str}, - } - } -) - - def test_basic() -> None: stmt = build_row_counts_expr("my_table", quoted_identifiers=True) expected = ( @@ -154,86 +144,3 @@ def _expand(table_name: str, db: Optional[str] = None) -> List[str]: normalized_query = normalized_query_expr.sql() assert normalized_query == expected_normalized_query - - -def _bind_query_expand(table_name: str, db: Optional[str] = None) -> List[str]: - return [db, table_name] - - -@pytest.mark.parametrize( - "clause_sql", - [ - pytest.param('ORDER BY "orders__order_id" ASC', id="order_by"), - pytest.param('GROUP BY "orders__order_id"', id="group_by"), - pytest.param('HAVING "orders__order_id" > 0', id="having"), - ], -) -def test_bind_query_preserves_alias_case_for_clause_references(clause_sql: str) -> None: - query = cast( - sge.Query, - sqlglot.parse_one(f""" - SELECT - customers.customer_id AS customer_id, - orders.order_id AS "orders__order_id" - FROM my_dataset.customers AS customers - INNER JOIN my_dataset.orders AS orders - ON customers.customer_id = orders.customer_id - {clause_sql} - """), - ) - - bound = bind_query( - qualified_query=query, - sqlglot_schema=_BIND_QUERY_SCHEMA, - expand_table_name=_bind_query_expand, - casefold_identifier=str.upper, - ) - sql = bound.sql() - - # SELECT alias is preserved in original - assert 'AS "orders__order_id"' in sql - # clause reference matches the preserved alias case, not the casefolded form - assert clause_sql in sql - - -def test_bind_query_casefolds_qualified_columns_in_order_by() -> None: - """Table-qualified column references in ORDER BY must still be casefolded.""" - query = cast( - sge.Query, - sqlglot.parse_one(""" - SELECT customers.customer_id AS customer_id - FROM 
my_dataset.customers AS customers - ORDER BY customers.customer_id ASC - """), - ) - - bound = bind_query( - qualified_query=query, - sqlglot_schema=_BIND_QUERY_SCHEMA, - expand_table_name=_bind_query_expand, - casefold_identifier=str.upper, - ) - sql = bound.sql() - - assert 'ORDER BY "CUSTOMERS"."CUSTOMER_ID"' in sql - - -def test_bind_query_casefolds_unrelated_bare_order_by_identifiers() -> None: - query = cast( - sge.Query, - sqlglot.parse_one(""" - SELECT customers.customer_id AS customer_id - FROM my_dataset.customers AS customers - ORDER BY country_code ASC - """), - ) - - bound = bind_query( - qualified_query=query, - sqlglot_schema=_BIND_QUERY_SCHEMA, - expand_table_name=_bind_query_expand, - casefold_identifier=str.upper, - ) - sql = bound.sql() - - assert 'ORDER BY "COUNTRY_CODE"' in sql From 63b7dfcdc18b955702a9436ef0a20974b40b5c8e Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 2 Jun 2026 16:13:34 +0200 Subject: [PATCH 20/30] cleanup --- dlt/dataset/relation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 687a9c80d1..2edec7682a 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -497,7 +497,7 @@ def join( rel = self.__copy__() rel._sqlglot_expression = query - # carry the RHS relation's foreign datasets + # carry the RHS relation's foreign schemas if isinstance(other, dlt.Relation): for ds_name, schemas in other._foreign_schemas.items(): if ds_name == self._dataset.dataset_name: From 4113d0c107ec39e0656cb0a8fc2c6956e09d4900 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 2 Jun 2026 16:21:10 +0200 Subject: [PATCH 21/30] add a note on names in `on` --- dlt/dataset/relation.py | 3 ++- docs/website/docs/general-usage/dataset-access/dataset.md | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 2edec7682a..e2f40ecec4 100644 --- a/dlt/dataset/relation.py +++ 
b/dlt/dataset/relation.py @@ -425,7 +425,8 @@ def join( pass a Relation from a different `dlt.Dataset`. on: Explicit join condition as an SQL string or sqlglot expression. Required for cross-dataset joins and joins between tables - without dlt schema references. + without dlt schema references. Column and table names in the + predicate must use their dlt schema (normalized) names. kind: Type of SQL join: ``"inner"``, ``"left"``, ``"right"``, or ``"full"``. alias: Projection prefix for the joined table's columns. Columns diff --git a/docs/website/docs/general-usage/dataset-access/dataset.md b/docs/website/docs/general-usage/dataset-access/dataset.md index f3850fb6a4..928949e87c 100644 --- a/docs/website/docs/general-usage/dataset-access/dataset.md +++ b/docs/website/docs/general-usage/dataset-access/dataset.md @@ -220,6 +220,10 @@ Refer to the right-hand side in `on` by its source qualifier: the joined table's The left-hand side can be a table relation, a relation chained from one with `where()`, `select()`, `order_by()`, and similar methods, or a `dataset.query("...")` that reads from a single table. +:::note +Write the column and table names in `on` using their dlt schema names: the normalized identifiers you pass to `dataset.table(...)` and see in the dataset's schema, not the original field names from your source. With the default snake_case naming the two usually match, but under a name-mutating [naming convention](../naming-convention.md) you must use the normalized form. +::: + Self-joins work with explicit `on`, but the two instances of the table need distinct SQL qualifiers so the predicate can tell them apart. 
Alias one side with a `dataset.query(...)` and refer to that alias in `on`: ```py From 7ae2e189c08fd6c56686f4a5e4ad3ddd1ed42e98 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 2 Jun 2026 16:36:46 +0200 Subject: [PATCH 22/30] extract shared read-dataset test fixtures into a module --- tests/libs/test_ibis.py | 6 ++-- tests/load/read_dataset_fixtures.py | 45 +++++++++++++++++++++++ tests/load/test_read_interfaces.py | 48 ++++--------------------- tests/load/test_relation_join.py | 56 +++++++---------------------- 4 files changed, 66 insertions(+), 89 deletions(-) create mode 100644 tests/load/read_dataset_fixtures.py diff --git a/tests/libs/test_ibis.py b/tests/libs/test_ibis.py index 0cb1eaa9be..fd7c8ce51e 100644 --- a/tests/libs/test_ibis.py +++ b/tests/libs/test_ibis.py @@ -19,10 +19,8 @@ from tests.load.lance_utils import ( module_lance_rest_server, # consumed by `populated_pipeline` fixture ) -from tests.load.test_read_interfaces import ( - populated_pipeline, - preserve_module_environ_per_destination_config, -) +from tests.load.read_dataset_fixtures import preserve_module_environ_per_destination_config +from tests.load.test_read_interfaces import populated_pipeline from tests.load.utils import DestinationTestConfiguration, destinations_configs from tests.utils import ( auto_module_test_run_context, diff --git a/tests/load/read_dataset_fixtures.py b/tests/load/read_dataset_fixtures.py new file mode 100644 index 0000000000..310edd6d4f --- /dev/null +++ b/tests/load/read_dataset_fixtures.py @@ -0,0 +1,45 @@ +"""Shared fixtures for tests that read datasets and relations across destinations.""" +from typing import Any, cast + +import pytest + +from tests.load.utils import ( + DestinationTestConfiguration, + MEMORY_BUCKET, + SFTP_BUCKET, + destinations_configs, +) +from tests.utils import _preserve_environ + + +@pytest.fixture( + scope="module", + params=destinations_configs( + default_sql_configs=True, + read_only_sqlclient_configs=True, + 
bucket_exclude=[SFTP_BUCKET, MEMORY_BUCKET], + ), + ids=lambda x: x.name, +) +def destination_config(request: pytest.FixtureRequest) -> DestinationTestConfiguration: + return cast(DestinationTestConfiguration, request.param) + + +@pytest.fixture(scope="module") +def preserve_module_environ_per_destination_config( + destination_config: DestinationTestConfiguration, +) -> Any: + yield from _preserve_environ() + + +def skip_if_unsupported_filesystem_format( + destination_config: DestinationTestConfiguration, +) -> None: + if ( + destination_config.file_format not in ["parquet", "jsonl"] + and destination_config.destination_type == "filesystem" + ): + pytest.skip( + "filesystem read-only sql_client requires jsonl or parquet; got" + f" {destination_config.file_format}" + ) diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index 7defcfee14..8753aa2c55 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -25,15 +25,17 @@ from dlt.dataset.exceptions import LineageFailedException from tests.load.lance_utils import module_lance_rest_server +from tests.load.read_dataset_fixtures import ( + destination_config, + preserve_module_environ_per_destination_config, + skip_if_unsupported_filesystem_format, +) from tests.load.utils import ( DestinationTestConfiguration, - MEMORY_BUCKET, - SFTP_BUCKET, destinations_configs, drop_pipeline_data, ) from tests.utils import ( - _preserve_environ, auto_module_test_run_context, auto_module_test_storage, ) @@ -146,28 +148,6 @@ def orderable_in_chain(): return source() -@pytest.fixture( - scope="module", - params=destinations_configs( - default_sql_configs=True, - read_only_sqlclient_configs=True, - bucket_exclude=[SFTP_BUCKET, MEMORY_BUCKET], - ), - ids=lambda x: x.name, -) -def destination_config( - request: pytest.FixtureRequest, -) -> DestinationTestConfiguration: - return cast(DestinationTestConfiguration, request.param) - - -@pytest.fixture(scope="module") -def 
preserve_module_environ_per_destination_config( - destination_config: DestinationTestConfiguration, -) -> Any: - yield from _preserve_environ() - - @pytest.fixture(scope="module") def populated_pipeline( destination_config: DestinationTestConfiguration, @@ -178,14 +158,7 @@ def populated_pipeline( ) -> Any: """fixture that returns a pipeline object populated with the example data""" - if ( - destination_config.file_format not in ["parquet", "jsonl"] - and destination_config.destination_type == "filesystem" - ): - pytest.skip( - "Test only works for jsonl and parquet on filesystem destination, given:" - f" {destination_config.file_format}" - ) + skip_if_unsupported_filesystem_format(destination_config) pipeline = destination_config.setup_pipeline( "read_pipeline", dataset_name="read_test", dev_mode=True @@ -1682,14 +1655,7 @@ def overlap_pipeline( 'events' table with shared columns (id, name) and unique columns. Tests select schema subsets via ``pipeline.dataset(schema=[...])``. """ - if ( - destination_config.file_format not in ["parquet", "jsonl"] - and destination_config.destination_type == "filesystem" - ): - pytest.skip( - "Test only works for jsonl and parquet on filesystem destination, given:" - f" {destination_config.file_format}" - ) + skip_if_unsupported_filesystem_format(destination_config) pipeline = destination_config.setup_pipeline( "overlap_pipeline", dataset_name="overlap_test", dev_mode=True diff --git a/tests/load/test_relation_join.py b/tests/load/test_relation_join.py index e9413e5c2b..7d2676c1e8 100644 --- a/tests/load/test_relation_join.py +++ b/tests/load/test_relation_join.py @@ -1,7 +1,7 @@ """End-to-end tests for `Relation.join()` across destinations.""" import os -from typing import Any, cast, Tuple +from typing import Any, Tuple import pytest @@ -17,59 +17,27 @@ relational_tables, ) from tests.load.lance_utils import module_lance_rest_server +from tests.load.read_dataset_fixtures import ( + destination_config, + 
preserve_module_environ_per_destination_config, + skip_if_unsupported_filesystem_format, +) from tests.load.utils import ( DestinationTestConfiguration, - MEMORY_BUCKET, - SFTP_BUCKET, - destinations_configs, drop_pipeline_data, ) from tests.utils import ( - _preserve_environ, auto_module_test_run_context, auto_module_test_storage, get_test_storage_root, ) -# TODO: same as in test_read_interfaces.py: factor out into a shared helper -@pytest.fixture( - scope="module", - params=destinations_configs( - default_sql_configs=True, - read_only_sqlclient_configs=True, - bucket_exclude=[SFTP_BUCKET, MEMORY_BUCKET], - ), - ids=lambda x: x.name, -) -def destination_config( - request: pytest.FixtureRequest, -) -> DestinationTestConfiguration: - config = cast(DestinationTestConfiguration, request.param) +def _skip_unsupported(destination_config: DestinationTestConfiguration) -> None: + skip_if_unsupported_filesystem_format(destination_config) # TODO: remove once https://github.com/dlt-hub/dlt/pull/4011 is merged - if config.destination_type == "databricks": + if destination_config.destination_type == "databricks": pytest.skip("databricks foreign-key emission breaks this fixture. 
see dlt-hub/dlt#4011") - return config - - -# TODO: same code in test_read_interfaces.py: factor out into a shared helper -@pytest.fixture(scope="module") -def preserve_module_environ_per_destination_config( - destination_config: DestinationTestConfiguration, -) -> Any: - yield from _preserve_environ() - - -# TODO: same code in test_read_interfaces.py: factor out into a shared helper -def _skip_unsupported_filesystem(destination_config: DestinationTestConfiguration) -> None: - if ( - destination_config.file_format not in ["parquet", "jsonl"] - and destination_config.destination_type == "filesystem" - ): - pytest.skip( - "filesystem read-only sqlclient requires jsonl or parquet; got" - f" {destination_config.file_format}" - ) @pytest.fixture(scope="module") @@ -80,7 +48,7 @@ def relational_pipeline( preserve_module_environ_per_destination_config: Any, auto_module_test_run_context: Any, ) -> Any: - _skip_unsupported_filesystem(destination_config) + _skip_unsupported(destination_config) pipeline = destination_config.setup_pipeline( "join_relational_pipeline", dataset_name="join_relational", dev_mode=True ) @@ -100,7 +68,7 @@ def crm_pipeline( preserve_module_environ_per_destination_config: Any, auto_module_test_run_context: Any, ) -> Any: - _skip_unsupported_filesystem(destination_config) + _skip_unsupported(destination_config) pipeline = destination_config.setup_pipeline( "join_crm_pipeline", dataset_name="join_crm", dev_mode=True ) @@ -123,7 +91,7 @@ def cross_dataset_pipelines( auto_module_test_run_context: Any, ) -> Any: """Two pipelines on the same physical destination, distinct dataset names.""" - _skip_unsupported_filesystem(destination_config) + _skip_unsupported(destination_config) if destination_config.destination_type in ("filesystem", "lance", "lancedb"): pytest.skip( "cross-dataset joins are not supported on filesystem destinations" From f7a1647fb36aab80c24d959312604fe4fb57f53d Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 2 Jun 2026 17:39:55 
+0200 Subject: [PATCH 23/30] support aggregated/distinct left side in explicit joins + tests --- dlt/dataset/_join.py | 31 +++++++++++++-- dlt/dataset/relation.py | 49 ++++++++++++----------- tests/dataset/test_relation_join.py | 60 +++++++++++++++++++++++++++-- 3 files changed, 108 insertions(+), 32 deletions(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index adf9dc794a..2bef4ac1ff 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -480,6 +480,12 @@ def _collect_source_qualifiers(query: sge.Query) -> Set[str]: return {qualifier for source in sources if (qualifier := _source_qualifier(source)) is not None} +def _is_flat_select(query: sge.Select) -> bool: + if any(query.args.get(key) for key in ("group", "having", "qualify", "distinct")): + return False + return not any(sel.find(sge.AggFunc) for sel in query.selects) + + def _qualify_unscoped_predicate_columns(query: sge.Select, source_qualifier: str) -> None: """Bind unqualified WHERE/ORDER BY columns to the single source.""" if query.args.get("joins"): @@ -494,6 +500,23 @@ def _qualify_unscoped_predicate_columns(query: sge.Select, source_qualifier: str col.set("table", qualifier_identifier.copy()) +def _aliased_subquery(query: sge.Query, qualifier: str) -> sge.Subquery: + """Wrap `query` as a derived table exposed under `qualifier`.""" + return sge.Subquery( + this=query, + alias=sge.TableAlias(this=sge.to_identifier(qualifier, quoted=False)), + ) + + +def _wrap_as_derived_table(query: sge.Select, qualifier: str) -> sge.Select: + """Re-select all of `query`'s columns from it embedded as a derived table.""" + return ( + sge.Select() + .select(sge.Column(table=sge.to_identifier(qualifier), this=sge.Star())) + .from_(_aliased_subquery(query, qualifier)) + ) + + def _apply_explicit_join( expression: sge.Query, target: _JoinTarget, @@ -525,6 +548,9 @@ def _apply_explicit_join( ) left_source_qualifier = _source_qualifier(from_this) or from_this.name + if not _is_flat_select(query): + query = 
_wrap_as_derived_table(query, left_source_qualifier) + _qualify_unscoped_predicate_columns(query, left_source_qualifier) target_qualifier = target.table_name @@ -543,10 +569,7 @@ def _apply_explicit_join( rhs_inner = target.subquery if target_dataset_name: rhs_inner = _qualify_physical_tables_with_dataset(rhs_inner, target_dataset_name) - target_expr = sge.Subquery( - this=rhs_inner, - alias=sge.TableAlias(this=sge.to_identifier(target_qualifier, quoted=False)), - ) + target_expr = _aliased_subquery(rhs_inner, target_qualifier) else: target_expr = sge.Table( this=sge.to_identifier(target.table_name, quoted=True), diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index e2f40ecec4..7edbe3b34f 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -441,18 +441,17 @@ def join( ValueError: If the join cannot be resolved. Example: + >>> # auto join (schema references) + >>> dataset["orders"].join("users") - # auto join (schema references) - dataset["orders"].join("users") + >>> # explicit ON + >>> dataset["orders"].join("users", on="orders._dlt_parent_id = users._dlt_id") - # explicit ON - dataset["orders"].join("users", on="orders._dlt_parent_id = users._dlt_id") - - # cross-dataset join - local["orders"].join( - foreign["products"], - on="orders.product_id = products.id", - ) + >>> # cross-dataset join + >>> local["orders"].join( + ... foreign["products"], + ... on="orders.product_id = products.id", + ... 
) """ if alias == "": raise ValueError("`alias` must be a non-empty string when provided.") @@ -547,11 +546,11 @@ def _resolve_join_target( target_table = _left_source_qualifier(other.sqlglot_expression) or "subquery" target_columns = other.columns_schema return _JoinTarget( - target_dataset.dataset_name, - is_foreign, - target_table, - target_columns, - target_dataset.schemas, + dataset_name=target_dataset.dataset_name, + is_foreign=is_foreign, + table_name=target_table, + columns=target_columns, + schemas=target_dataset.schemas, subquery=other.sqlglot_expression if is_transformed else None, ) @@ -563,20 +562,20 @@ def _resolve_join_target( if ds_name == self._dataset.dataset_name: return _JoinTarget( - ds_name, - False, - tbl_name, - _find_table_columns(self._dataset.schemas, tbl_name), - self._dataset.schemas, + dataset_name=ds_name, + is_foreign=False, + table_name=tbl_name, + columns=_find_table_columns(self._dataset.schemas, tbl_name), + schemas=self._dataset.schemas, ) if ds_name in self._foreign_schemas: foreign_schemas = self._foreign_schemas[ds_name] return _JoinTarget( - ds_name, - True, - tbl_name, - _find_table_columns(foreign_schemas, tbl_name), - foreign_schemas, + dataset_name=ds_name, + is_foreign=True, + table_name=tbl_name, + columns=_find_table_columns(foreign_schemas, tbl_name), + schemas=foreign_schemas, ) raise ValueError( f"Dataset `{ds_name}` is not registered. 
Pass a Relation from the " diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index db4c504888..0896a28e47 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -1212,6 +1212,63 @@ def test_explicit_on_with_aggregated_rhs( assert "order_totals__amount" not in df.columns +@pytest.mark.parametrize( + "lhs_query,expected_ids,expected_totals,expected_names", + [ + pytest.param( + "SELECT customer_id, SUM(amount) AS total FROM orders GROUP BY customer_id", + [1, 2, 3], + [125.0, 200.0, 30.0], + ["Alice", "Bob", "Charlie"], + id="group-by", + ), + pytest.param( + "SELECT customer_id, SUM(amount) AS total FROM orders " + "GROUP BY customer_id HAVING SUM(amount) > 40", + [1, 2], + [125.0, 200.0], + ["Alice", "Bob"], + id="group-by-having", + ), + ], +) +def test_explicit_on_with_aggregated_lhs( + dataset_with_relational_tables: dlt.Dataset, + lhs_query: str, + expected_ids: list[int], + expected_totals: list[float], + expected_names: list[str], +) -> None: + ds = dataset_with_relational_tables + agg_lhs = ds.query(lhs_query) + joined = agg_lhs.join("customers", on="orders.customer_id = customers.customer_id").order_by( + "customer_id" + ) + df = joined.df() + + assert list(df["customer_id"]) == expected_ids + assert [float(x) for x in df["total"]] == expected_totals + assert list(df["customers__name"]) == expected_names + assert "amount" not in df.columns + + +def test_explicit_on_with_distinct_lhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + distinct_codes = ds.query("SELECT DISTINCT country_code FROM customers") + joined = distinct_codes.join("countries", on="customers.country_code = countries.code") + + outer = joined.sqlglot_expression + assert outer.args.get("distinct") is None + derived = outer.find(sge.Subquery) + assert derived is not None and derived.this.args.get("distinct") is not None + + df = joined.order_by("country_code").df() + 
assert list(df["country_code"]) == ["DE", "FR"] + assert list(df["countries__name"]) == ["Germany", "France"] + + def test_explicit_on_with_aliased_query_relations( dataset_with_relational_tables: dlt.Dataset, ) -> None: @@ -1789,9 +1846,6 @@ def test_cross_dataset_join_chain_three_datasets( name_col: str, absent_name_col: str, ) -> None: - """A chain whose later `on` resolves an earlier join target by table name. An explicit - `alias` only prefixes the projected columns; it is not itself a join qualifier, so - `on="users.id = ..."` binds regardless of the alias chosen for the `users` join.""" ds_crm, ds_inv, ds_billing = three_way_cross_dataset_duckdb joined = ( From a5a0cbeab4794fd298d90d593c4ba178c5cb053d Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Tue, 2 Jun 2026 22:54:18 +0200 Subject: [PATCH 24/30] exclude incomplete columns from explicit join projection --- dlt/dataset/relation.py | 2 +- tests/dataset/test_relation_join.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 7edbe3b34f..bcdc9ba35c 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -1037,5 +1037,5 @@ def _find_table_columns(schemas: Sequence[dlt.Schema], table_name: str) -> TTabl """Find the columns schema for a table across a sequence of schemas.""" for schema in schemas: if table_name in schema.tables: - return schema.tables[table_name]["columns"] + return schema.get_table_columns(table_name) raise ValueError(f"Table `{table_name}` not found in dataset schema") diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index 0896a28e47..8f64658fe6 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -469,10 +469,24 @@ def test_join_projection_prefix_rejects_colliding_alias( joined.join("users__orders__items", alias="shared") +@pytest.mark.parametrize( + "build_join", + [ + pytest.param(lambda ds: 
ds.table("products").join("categories"), id="magic"), + pytest.param( + lambda ds: ds.table("products").join( + "categories", on="products.category_id = categories.id" + ), + id="explicit-on", + ), + ], +) def test_join_does_not_project_incomplete_target_columns( dataset_with_incomplete_join_target: dlt.Dataset, + build_join: Callable[[dlt.Dataset], dlt.Relation], ) -> None: - relation = dataset_with_incomplete_join_target.table("products").join("categories") + relation = build_join(dataset_with_incomplete_join_target) + assert "categories__phantom_field" not in relation.columns_schema rows = relation.fetchall() assert rows is not None assert len(rows) == 3 From 88b6dca2d4cb579256f15ed21b5d50086246b90f Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Wed, 3 Jun 2026 09:51:49 +0200 Subject: [PATCH 25/30] copy transformed rhs subquery so explicit join doesn't mutate the source relation --- dlt/dataset/_join.py | 15 ++++++--------- tests/dataset/test_relation_join.py | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 2bef4ac1ff..c8f2c9449c 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import reduce -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, Set, TypeVar, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, Set, Union import sqlglot import sqlglot.expressions as sge @@ -16,8 +16,6 @@ _INTERMEDIATE_JOIN_ALIAS_PREFIX = "_dlt_int_t" -_TExpr = TypeVar("_TExpr", bound=sge.Expression) - class _JoinTarget(NamedTuple): """Resolved right-hand side of a `Relation.join()`.""" @@ -455,9 +453,8 @@ def _apply_join( return query -def _qualify_physical_tables_with_dataset(expression: _TExpr, dataset_name: str) -> _TExpr: +def _qualify_physical_tables_with_dataset(expression: sge.Expression, dataset_name: str) -> None: """Bind every physical table reference in 
`expression` to `dataset_name`.""" - expression = expression.copy() cte_names = {cte.alias_or_name for cte in expression.find_all(sge.CTE)} db_identifier = sge.to_identifier(dataset_name, quoted=False) for table in expression.find_all(sge.Table): @@ -466,7 +463,6 @@ def _qualify_physical_tables_with_dataset(expression: _TExpr, dataset_name: str) if table.args.get("db"): continue table.set("db", db_identifier.copy()) - return expression def _left_source_qualifier(query: sge.Query) -> Optional[str]: @@ -538,7 +534,8 @@ def _apply_explicit_join( destination_dialect: Dialect for parsing string ON expressions. left_dataset_name: Dataset name for the left-hand side. """ - query = _qualify_physical_tables_with_dataset(_copy_as_select(expression), left_dataset_name) + query = _copy_as_select(expression) + _qualify_physical_tables_with_dataset(query, left_dataset_name) from_this = _from_source(query) if not isinstance(from_this, sge.Table): @@ -566,9 +563,9 @@ def _apply_explicit_join( target_expr: sge.Expression if target.subquery is not None: # transformed relation: embed its query as a subquery - rhs_inner = target.subquery + rhs_inner = target.subquery.copy() if target_dataset_name: - rhs_inner = _qualify_physical_tables_with_dataset(rhs_inner, target_dataset_name) + _qualify_physical_tables_with_dataset(rhs_inner, target_dataset_name) target_expr = _aliased_subquery(rhs_inner, target_qualifier) else: target_expr = sge.Table( diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index 8f64658fe6..e4951277f1 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -1173,6 +1173,29 @@ def test_explicit_on_with_filtered_rhs( assert list(df["orders__amount"]) == [75.0, 200.0] +def test_explicit_on_does_not_mutate_transformed_rhs( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + expensive_orders = ds.table("orders").where("amount", "gt", 50.0) + 
rhs_sql_before = expensive_orders.to_sql() + assert expensive_orders.sqlglot_expression.parent is None + + joined = ds.table("customers").join( + expensive_orders, on="customers.customer_id = orders.customer_id" + ) + + # the join leaves the RHS relation untouched + assert expensive_orders.sqlglot_expression.parent is None + assert expensive_orders.to_sql() == rhs_sql_before + + joined_again = ds.table("customers").join( + expensive_orders, on="customers.customer_id = orders.customer_id", alias="o2" + ) + assert list(joined.df()["orders__amount"]) == [75.0, 200.0] + assert list(joined_again.df()["o2__amount"]) == [75.0, 200.0] + + def test_explicit_on_with_projected_lhs_preserves_left_projection( dataset_with_relational_tables: dlt.Dataset, ) -> None: From d2fea134f75f0fcc50d28c2ca37544aae095044d Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Wed, 3 Jun 2026 10:34:49 +0200 Subject: [PATCH 26/30] wrap limited left side as derived table so limit doesn't leak past explicit joins --- dlt/dataset/_join.py | 4 ++- tests/dataset/test_relation_join.py | 48 +++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index c8f2c9449c..6ca0ea2a8a 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -477,7 +477,9 @@ def _collect_source_qualifiers(query: sge.Query) -> Set[str]: def _is_flat_select(query: sge.Select) -> bool: - if any(query.args.get(key) for key in ("group", "having", "qualify", "distinct")): + if any( + query.args.get(key) for key in ("group", "having", "qualify", "distinct", "limit", "offset") + ): return False return not any(sel.find(sge.AggFunc) for sel in query.selects) diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index e4951277f1..e568aed6e7 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -492,6 +492,54 @@ def test_join_does_not_project_incomplete_target_columns( assert 
len(rows) == 3 +_MAGIC_LIMIT_LEAKS_PAST_JOIN = pytest.mark.xfail( + reason=( + "magic join does not wrap a limited left relation as a derived table, so LIMIT is " + "rendered on the joined query instead of the limited left relation" + ), + strict=True, +) + + +@pytest.mark.parametrize( + "build_join,expected_product_ids", + [ + pytest.param( + lambda ds: ds.table("categories").order_by("id").limit(1).join("products"), + [10, 12], + id="magic", + marks=_MAGIC_LIMIT_LEAKS_PAST_JOIN, + ), + pytest.param( + lambda ds: ds.table("categories") + .order_by("id") + .limit(1) + .join("products", on="categories.id = products.category_id"), + [10, 12], + id="explicit-on-limit", + ), + pytest.param( + lambda ds: ds.query("SELECT * FROM categories ORDER BY id LIMIT 1 OFFSET 1").join( + "products", on="categories.id = products.category_id" + ), + [11], + id="explicit-on-limit-offset", + ), + ], +) +def test_limit_then_join_applies_limit_before_join( + dataset_with_incomplete_join_target: dlt.Dataset, + build_join: Callable[[dlt.Dataset], dlt.Relation], + expected_product_ids: list[int], +) -> None: + """`.limit(n)` must bound the left relation before joining, not cap the joined result.""" + relation = build_join(dataset_with_incomplete_join_target) + df = relation.df() + + assert len(df) == len(expected_product_ids) + assert sorted(df["products__id"]) == expected_product_ids + + def test_join_rejects_empty_alias(dataset_with_loads: TLoadsFixture) -> None: dataset, _, _ = dataset_with_loads with pytest.raises(ValueError, match="must be a non-empty string"): From 7cf2ca75aaa65043a6bf275631abd17de18af829 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Wed, 3 Jun 2026 10:56:10 +0200 Subject: [PATCH 27/30] name the destination in the cross-dataset join error --- dlt/dataset/relation.py | 3 ++- tests/dataset/test_relation_join.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 
bcdc9ba35c..d260f540bf 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -527,7 +527,8 @@ def _resolve_join_target( is_foreign = not self._dataset._is_same_dataset(target_dataset) if is_foreign and isinstance(self.sql_client, WithSchemas): raise ValueError( - "Cross-dataset joins are not supported on filesystem destinations." + "Cross-dataset joins are not supported on the" + f" `{self._dataset._destination.destination_name}` destination." ) target_table = other._table_name diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index e568aed6e7..1fcc26bae5 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -264,6 +264,35 @@ def test_join_rejects_same_name_on_different_physical_destinations() -> None: ds_a.table("users").join(ds_b.table("orders"), on="users.id = orders.user_id") +def test_join_rejects_cross_dataset_on_filesystem() -> None: + with tempfile.TemporaryDirectory() as tmp: + tmp_path = pathlib.Path(tmp) + destination = dlt.destinations.filesystem(str(tmp_path / "data")) + + pipeline_a = dlt.pipeline( + pipeline_name="fs_cross_ds_a", + pipelines_dir=str(tmp_path / "pipelines_dir"), + destination=destination, + dataset_name="fs_crm", + ) + pipeline_a.run([{"id": 1, "name": "Alice"}], table_name="users") + + pipeline_b = dlt.pipeline( + pipeline_name="fs_cross_ds_b", + pipelines_dir=str(tmp_path / "pipelines_dir"), + destination=destination, + dataset_name="fs_inv", + ) + pipeline_b.run([{"order_id": 10, "user_id": 1}], table_name="orders") + + ds_a = pipeline_a.dataset() + ds_b = pipeline_b.dataset() + assert ds_a.is_same_physical_destination(ds_b) + + with pytest.raises(ValueError, match="not supported on the `filesystem` destination"): + ds_a.table("users").join(ds_b.table("orders"), on="users.id = orders.user_id") + + @pytest.mark.parametrize( "dataset_with_loads,left,right,expected_targets", [ From ab0b7f118f7623dc7c736f842c1ca06d4d382563 Mon Sep 17 00:00:00 
2001 From: Anton Burnashev Date: Wed, 3 Jun 2026 15:50:45 +0200 Subject: [PATCH 28/30] omit dataset_name in table expansion when unset for legacy sql client compat --- dlt/dataset/relation.py | 10 +---- dlt/destinations/queries.py | 28 +++++++++---- tests/destinations/test_queries.py | 65 +++++++++++++++++++++--------- 3 files changed, 68 insertions(+), 35 deletions(-) diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index d260f540bf..00cda26529 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -36,7 +36,7 @@ from dlt.common.exceptions import ValueErrorWithKnownValues from dlt.dataset import lineage from dlt.destinations.sql_client import SqlClientBase, WithSchemas, WithSqlClient -from dlt.destinations.queries import bind_query, build_select_expr +from dlt.destinations.queries import bind_query, build_select_expr, make_expand_table_name from dlt.common.destination.dataset import SupportsDataAccess from dlt.dataset._incremental import ( _build_incremental_aggregate, @@ -272,16 +272,10 @@ def to_sql(self, pretty: bool = False, *, _raw_query: bool = False) -> str: query = self.sqlglot_expression else: _, _qualified_query = _get_relation_output_columns_schema(self) - - def _expand(table_name: str, db: Optional[str] = None) -> list[str]: - return self.sql_client.make_qualified_table_name_path( - table_name, quote=False, casefold=False, dataset_name=db - ) - query = bind_query( qualified_query=_qualified_query, sqlglot_schema=self._relation_sqlglot_schema(), - expand_table_name=_expand, + expand_table_name=make_expand_table_name(self.sql_client), casefold_identifier=self.sql_client.capabilities.casefold_identifier, ) diff --git a/dlt/destinations/queries.py b/dlt/destinations/queries.py index c3ef72afb5..7481190d8d 100644 --- a/dlt/destinations/queries.py +++ b/dlt/destinations/queries.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional import sqlglot.expressions as sge from 
sqlglot.schema import Schema as SQLGlotSchema @@ -8,6 +8,24 @@ from dlt.destinations.sql_client import SqlClientBase +def make_expand_table_name( + sql_client: SqlClientBase[Any], +) -> Callable[[str, Optional[str]], List[str]]: + """Create a `bind_query` table name expander bound to `sql_client`.""" + + def _expand(table_name: str, db: Optional[str] = None) -> List[str]: + if db is None: + # omit dataset name if not provided for backward compatibility + return sql_client.make_qualified_table_name_path( + table_name, quote=False, casefold=False + ) + return sql_client.make_qualified_table_name_path( + table_name, quote=False, casefold=False, dataset_name=db + ) + + return _expand + + def _normalize_query( qualified_query: sge.Query, sqlglot_schema: SQLGlotSchema, @@ -19,16 +37,10 @@ def _normalize_query( TODO: remove after next dlthub release """ - - def _expand(table_name: str, db: Optional[str] = None) -> List[str]: - return sql_client.make_qualified_table_name_path( - table_name, quote=False, casefold=False, dataset_name=db - ) - return bind_query( qualified_query, sqlglot_schema, - expand_table_name=_expand, + expand_table_name=make_expand_table_name(sql_client), casefold_identifier=casefold_identifier, ) diff --git a/tests/destinations/test_queries.py b/tests/destinations/test_queries.py index 757e614701..b7e7fbfbe5 100644 --- a/tests/destinations/test_queries.py +++ b/tests/destinations/test_queries.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast +from typing import Any, Iterator, List, Optional, cast import duckdb import pytest @@ -9,10 +9,27 @@ import dlt from dlt.common.schema.typing import C_DLT_LOAD_ID from dlt.dataset.lineage import compute_columns_schema -from dlt.destinations.queries import build_row_counts_expr, build_select_expr, bind_query +from dlt.destinations.queries import ( + build_row_counts_expr, + build_select_expr, + bind_query, + make_expand_table_name, +) +from dlt.destinations.sql_client import SqlClientBase from 
dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration +@pytest.fixture +def duckdb_sql_client() -> Iterator[SqlClientBase[Any]]: + """In-memory duckdb sql client bound to `dataset_name`.""" + con = duckdb.connect(":memory:") + destination_client = dlt.destinations.duckdb(con).client( + dlt.Schema("foobar"), DuckDbClientConfiguration()._bind_dataset_name("dataset_name") + ) + with destination_client.sql_client as sql_client: + yield sql_client + + def test_basic() -> None: stmt = build_row_counts_expr("my_table", quoted_identifiers=True) expected = ( @@ -101,7 +118,7 @@ def test_qualified_query(): assert qualified_query == expected_qualified_query -def test_normalize_query(): +def test_normalize_query(duckdb_sql_client: SqlClientBase[Any]) -> None: sqlglot_schema = SQLGlotSchema( {"dataset_name": {"items": {"id": str}, "double_items": {"double_id": str, "id": str}}} ) @@ -122,25 +139,35 @@ def test_normalize_query(): ' "i"."id" < 20 ORDER BY "i"."id" ASC' ) - con = duckdb.connect(":memory:") - duckdb_dest = dlt.destinations.duckdb(con) - duckdb_destination_client = duckdb_dest.client( - dlt.Schema("foobar"), DuckDbClientConfiguration()._bind_dataset_name("dataset_name") + normalized_query_expr = bind_query( + qualified_query=cast(sge.Query, qualified_query_expr), + sqlglot_schema=sqlglot_schema, + expand_table_name=make_expand_table_name(duckdb_sql_client), + casefold_identifier=duckdb_sql_client.capabilities.casefold_identifier, ) - with duckdb_destination_client.sql_client as sql_client: + assert normalized_query_expr.sql() == expected_normalized_query + - def _expand(table_name: str, db: Optional[str] = None) -> List[str]: - return sql_client.make_qualified_table_name_path( - table_name, quote=False, casefold=False, dataset_name=db +def test_expand_table_name_with_legacy_path_signature( + duckdb_sql_client: SqlClientBase[Any], +) -> None: + """Sql clients overriding `make_qualified_table_name_path` without the `dataset_name` + parameter keep 
working for tables without a dataset qualifier.""" + + class _LegacyPathClient: + def make_qualified_table_name_path( + self, table_name: Optional[str], quote: bool = True, casefold: bool = True + ) -> List[str]: + return duckdb_sql_client.make_qualified_table_name_path( + table_name, quote=quote, casefold=casefold ) - normalized_query_expr = bind_query( - qualified_query=cast(sge.Query, qualified_query_expr), - sqlglot_schema=sqlglot_schema, - expand_table_name=_expand, - casefold_identifier=sql_client.capabilities.casefold_identifier, - ) - normalized_query = normalized_query_expr.sql() + expand = make_expand_table_name(cast(SqlClientBase[Any], _LegacyPathClient())) - assert normalized_query == expected_normalized_query + assert expand("items", None) == duckdb_sql_client.make_qualified_table_name_path( + "items", quote=False, casefold=False + ) + # a dataset qualifier requires the `dataset_name` parameter on the override + with pytest.raises(TypeError, match="dataset_name"): + expand("items", "other_dataset") From 38b02fdd9a5c3c67c8be920649ef0860378ed07c Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Wed, 3 Jun 2026 17:23:26 +0200 Subject: [PATCH 29/30] qualify local join tables with dataset name, wrap windowed left side in explicit joins --- dlt/dataset/_join.py | 13 +-- dlt/dataset/lineage.py | 4 +- dlt/dataset/relation.py | 5 ++ tests/dataset/test_relation_join.py | 123 ++++++++++++++++++++++++++++ 4 files changed, 133 insertions(+), 12 deletions(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 6ca0ea2a8a..0a4c9c950d 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -481,7 +481,7 @@ def _is_flat_select(query: sge.Select) -> bool: query.args.get(key) for key in ("group", "having", "qualify", "distinct", "limit", "offset") ): return False - return not any(sel.find(sge.AggFunc) for sel in query.selects) + return not any(sel.find(sge.AggFunc, sge.Window) for sel in query.selects) def _qualify_unscoped_predicate_columns(query: 
sge.Select, source_qualifier: str) -> None: @@ -560,23 +560,16 @@ def _apply_explicit_join( "qualifier is unambiguous." ) - target_dataset_name = target.dataset_name if target.is_foreign else None - target_expr: sge.Expression if target.subquery is not None: # transformed relation: embed its query as a subquery rhs_inner = target.subquery.copy() - if target_dataset_name: - _qualify_physical_tables_with_dataset(rhs_inner, target_dataset_name) + _qualify_physical_tables_with_dataset(rhs_inner, target.dataset_name) target_expr = _aliased_subquery(rhs_inner, target_qualifier) else: target_expr = sge.Table( this=sge.to_identifier(target.table_name, quoted=True), - db=( - sge.to_identifier(target_dataset_name, quoted=False) - if target_dataset_name - else None - ), + db=sge.to_identifier(target.dataset_name, quoted=False), ) if isinstance(on, str): diff --git a/dlt/dataset/lineage.py b/dlt/dataset/lineage.py index 3bd07d2078..4e5a6b3003 100644 --- a/dlt/dataset/lineage.py +++ b/dlt/dataset/lineage.py @@ -2,7 +2,7 @@ import sqlglot.expressions as sge -from sqlglot.errors import OptimizeError +from sqlglot.errors import OptimizeError, SchemaError from sqlglot.schema import Schema as SQLGlotSchema, ensure_schema from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.qualify import qualify @@ -137,7 +137,7 @@ def compute_columns_schema( expand_stars=True, ), ) - except OptimizeError as e: + except (OptimizeError, SchemaError) as e: raise LineageFailedException( f"Failed to resolve SQL query against the schema received: {e}" ) from e diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 00cda26529..6aea70cded 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -53,6 +53,7 @@ _extract_joined_table_aliases, _JoinTarget, _left_source_qualifier, + _qualify_physical_tables_with_dataset, ) @@ -488,6 +489,10 @@ def join( left_dataset_name=self._dataset.dataset_name, ) + # bind tables left unqualified (e.g. 
magic join targets) to the local dataset so + # lineage stays unambiguous once foreign schemas are registered + _qualify_physical_tables_with_dataset(query, self._dataset.dataset_name) + rel = self.__copy__() rel._sqlglot_expression = query diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index 1fcc26bae5..b2d6b20ffd 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -569,6 +569,20 @@ def test_limit_then_join_applies_limit_before_join( assert sorted(df["products__id"]) == expected_product_ids +def test_windowed_lhs_join_applies_window_before_join( + dataset_with_relational_tables: dlt.Dataset, +) -> None: + ds = dataset_with_relational_tables + numbered = ds.query( + "SELECT customer_id, name, ROW_NUMBER() OVER (ORDER BY name) AS rn FROM customers" + ) + joined = numbered.join("orders", on="customers.customer_id = orders.customer_id") + df = joined.order_by("orders__order_id").df() + + assert len(df) == 4 + assert [int(x) for x in df["rn"]] == [1, 1, 2, 3] + + def test_join_rejects_empty_alias(dataset_with_loads: TLoadsFixture) -> None: dataset, _, _ = dataset_with_loads with pytest.raises(ValueError, match="must be a non-empty string"): @@ -2012,3 +2026,112 @@ def test_cross_dataset_chain_same_named_tables_disambiguated( assert "marketing__segment" in df.columns assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] assert list(df["marketing__segment"]) == ["pro", "pro", "free"] + + +@pytest.mark.parametrize( + "build_local_join,local_table,check_column,expected_values", + [ + pytest.param( + lambda ds, rel: rel.join("users", on="orders.user_id = users.id"), + "users", + "users__name", + ["Alice", "Bob"], + id="bare-table-name", + ), + pytest.param( + lambda ds, rel: rel.join(f"{ds.dataset_name}.users", on="orders.user_id = users.id"), + "users", + "users__name", + ["Alice", "Bob"], + id="dataset-qualified-string", + ), + pytest.param( + lambda ds, rel: rel.join( + 
ds.query("SELECT * FROM users AS u"), on="orders.user_id = u.id" + ), + "users", + "u__name", + ["Alice", "Bob"], + id="aliased-local-query", + ), + pytest.param( + lambda ds, rel: rel.join("_dlt_loads", on="orders._dlt_load_id = _dlt_loads.load_id"), + "_dlt_loads", + "_dlt_loads__status", + [0, 0], + id="dlt-loads-system-table", + ), + ], +) +def test_cross_dataset_join_then_local_join_to_same_named_table( + build_local_join: Callable[[dlt.Dataset, dlt.Relation], dlt.Relation], + local_table: str, + check_column: str, + expected_values: list[Any], +) -> None: + """A local join target shadowed by a same-named foreign table must bind to the local dataset.""" + with tempfile.TemporaryDirectory() as tmp: + tmp_path = pathlib.Path(tmp) + db_path = str(tmp_path / "shadowed.duckdb") + + pipeline_crm = dlt.pipeline( + pipeline_name="shadowed_local_target_a", + pipelines_dir=str(tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="crm_data", + ) + pipeline_crm.run( + [{"order_id": 1, "user_id": 1}, {"order_id": 2, "user_id": 2}], + table_name="orders", + ) + pipeline_crm.run( + [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}], + table_name="users", + ) + + pipeline_mkt = dlt.pipeline( + pipeline_name="shadowed_local_target_b", + pipelines_dir=str(tmp_path / "pipelines_dir"), + destination=dlt.destinations.duckdb(db_path), + dataset_name="mkt_data", + ) + pipeline_mkt.run( + [{"id": 1, "segment": "pro"}, {"id": 2, "segment": "free"}], + table_name="users", + ) + + ds_crm = pipeline_crm.dataset() + ds_mkt = pipeline_mkt.dataset() + + foreign_joined = ds_crm.table("orders").join( + ds_mkt.query("SELECT * FROM users AS mkt_users"), + on="orders.user_id = mkt_users.id", + alias="marketing", + ) + joined = build_local_join(ds_crm, foreign_joined) + + sql = joined.to_sql() + assert f'"{ds_crm.dataset_name}"."{local_table}"' in sql, sql + + df = joined.order_by("order_id").df() + assert list(df[check_column]) == expected_values + 
assert list(df["marketing__segment"]) == ["pro", "free"] + + +def test_magic_join_after_cross_dataset_resolves_local_target( + same_named_cross_dataset_duckdb: TCrossDsFixture, +) -> None: + """A magic join target shadowed by a same-named foreign table must bind to the local dataset.""" + ds_crm, ds_marketing = same_named_cross_dataset_duckdb + marketing = ds_marketing.query("SELECT * FROM users AS mkt_users") + + joined = ( + ds_crm.table("users__orders") + .join(marketing, on="mkt_users.id = 1", alias="marketing", kind="left") + .join("users") + ) + df = joined.order_by("order_id").df() + + assert len(df) == 3 + assert list(df["users__name"]) == ["Alice", "Alice", "Bob"] + assert list(df["marketing__segment"]) == ["pro", "pro", "pro"] From 3d3e04681d0532cf26c158be5f8acd892a0cd889 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Wed, 3 Jun 2026 17:45:36 +0200 Subject: [PATCH 30/30] validate on= join condition and improve join error messages --- dlt/dataset/_join.py | 11 ++++++++++- dlt/dataset/relation.py | 11 ++++++++--- tests/dataset/test_relation_join.py | 27 +++++++++++++++++++++------ 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/dlt/dataset/_join.py b/dlt/dataset/_join.py index 0a4c9c950d..ed23d26626 100644 --- a/dlt/dataset/_join.py +++ b/dlt/dataset/_join.py @@ -5,6 +5,7 @@ import sqlglot import sqlglot.expressions as sge +from sqlglot.errors import ParseError, TokenError from dlt.common.typing import TypedDict from dlt.common.schema import Schema, utils as schema_utils @@ -573,9 +574,17 @@ def _apply_explicit_join( ) if isinstance(on, str): - on_expr = sqlglot.parse_one(on, dialect=destination_dialect) + try: + on_expr = sqlglot.parse_one(on, dialect=destination_dialect) + except (ParseError, TokenError) as e: + raise ValueError(f"Cannot parse `on` join condition `{on}`: {e}") from e else: on_expr = on + if not isinstance(on_expr, sge.Condition): + raise ValueError( + f"`on` join condition `{on_expr.sql(destination_dialect)}` must be 
an SQL boolean" + " expression (e.g. `left.col = right.col`)." + ) join_expr = sge.Join(this=target_expr, kind=kind.upper()).on(on_expr) query = query.join(join_expr) diff --git a/dlt/dataset/relation.py b/dlt/dataset/relation.py index 6aea70cded..651c613455 100644 --- a/dlt/dataset/relation.py +++ b/dlt/dataset/relation.py @@ -519,8 +519,11 @@ def _resolve_join_target( if not self._dataset.is_same_physical_destination(target_dataset): raise ValueError( - "Cannot join relations from different physical destinations: " - f"'{target_dataset.dataset_name}' vs '{self._dataset.dataset_name}'" + "Cannot join relations from different physical destinations: dataset" + f" '{self._dataset.dataset_name}' on" + f" '{self._dataset.destination_client.config}' vs dataset" + f" '{target_dataset.dataset_name}' on" + f" '{target_dataset.destination_client.config}'" ) is_foreign = not self._dataset._is_same_dataset(target_dataset) @@ -582,7 +585,9 @@ def _resolve_join_target( "foreign dataset to automatically register its schema." ) - raise ValueError("`other` must be a table name or a base table relation.") + raise ValueError( + f"`other` must be a table name or a `dlt.Relation`, got `{type(other).__name__}`." + ) def incremental(self, incremental: Incremental[Any]) -> Self: """Filter this relation to a cursor range using an Incremental. 
diff --git a/tests/dataset/test_relation_join.py b/tests/dataset/test_relation_join.py index b2d6b20ffd..c8a3d9a078 100644 --- a/tests/dataset/test_relation_join.py +++ b/tests/dataset/test_relation_join.py @@ -1,6 +1,6 @@ import tempfile import pathlib -from typing import Any, Sequence, Callable, TypedDict, Optional +from typing import Any, Sequence, Callable, TypedDict, Optional, Union import pytest import sqlglot @@ -260,9 +260,12 @@ def test_join_rejects_same_name_on_different_physical_destinations() -> None: assert ds_a.dataset_name == ds_b.dataset_name assert not ds_a.is_same_physical_destination(ds_b) - with pytest.raises(ValueError, match="different physical destinations"): + with pytest.raises(ValueError, match="different physical destinations") as exc_info: ds_a.table("users").join(ds_b.table("orders"), on="users.id = orders.user_id") + assert "a.duckdb" in str(exc_info.value) + assert "b.duckdb" in str(exc_info.value) + def test_join_rejects_cross_dataset_on_filesystem() -> None: with tempfile.TemporaryDirectory() as tmp: @@ -370,7 +373,7 @@ def test_resolve_reference_chain_rejects_unrelated_tables( pytest.param( lambda ds: ds.table("users"), 123, - "`other` must be a table name or a base table relation", + "`other` must be a table name or a `dlt.Relation`, got `int`", id="invalid-other-type", ), pytest.param( @@ -1491,13 +1494,25 @@ def test_self_join_requires_distinct_qualifiers( assert sorted(df["mgr__name"]) == ["Alice", "Alice"] -@pytest.mark.parametrize("on", ["", " "], ids=["empty", "whitespace"]) +@pytest.mark.parametrize( + "on,match", + [ + pytest.param("", "non-empty SQL expression", id="empty"), + pytest.param(" ", "non-empty SQL expression", id="whitespace"), + pytest.param("customers.id = (((", "Cannot parse `on`", id="unparsable"), + pytest.param("SELECT 1", "must be an SQL boolean expression", id="select-string"), + pytest.param( + sqlglot.select("1"), "must be an SQL boolean expression", id="select-expression" + ), + ], +) def 
test_explicit_on_rejects_invalid_on_expression( dataset_with_relational_tables: dlt.Dataset, - on: str, + on: Union[str, sge.Expression], + match: str, ) -> None: ds = dataset_with_relational_tables - with pytest.raises(ValueError, match="non-empty SQL expression"): + with pytest.raises(ValueError, match=match): ds.table("customers").join("orders", on=on)