diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e2531ba..18e3e080 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ ### Bug Fixes - A `datetime` bound to a server-side `{name:DateTime64(...)}` placeholder now keeps its sub-second precision instead of being truncated to seconds. The declared parameter type drives this, so no `_64` name suffix or manual `DT64Param` wrapper is needed, and it applies through `Array` and `Tuple` hints. Plain `DateTime` binds are unchanged. Closes [#739](https://github.com/ClickHouse/clickhouse-connect/issues/739). - Strip `--` line comments that have no following space when classifying queries, so a DDL with a leading `--sql`-style comment is routed as a command instead of raising `StreamFailureError`. Closes [#499](https://github.com/ClickHouse/clickhouse-connect/issues/499). +- SQLAlchemy: implement reflection on the dialect itself so `MetaData.reflect()` and `Inspector.get_multi_columns()` work. +- SQLAlchemy: UDT-based types (`UUID`, `IPv4`/`IPv6`, `JSON`, `Nested`, geometry types, `AggregateFunction`, etc.) now return concrete `python_type` classes instead of `None`, matching SQLAlchemy's `TypeEngine.python_type` contract. +- SQLAlchemy: `Array` now subclasses `sqlalchemy.types.ARRAY` and exposes `item_type`. - `bytes`/`bytearray` query parameters now render as ClickHouse string literals (each byte as `\xHH`) instead of the Python repr, fixing inserts into `FixedString`/`String` columns through the SQLAlchemy dialect. Closes [#777](https://github.com/ClickHouse/clickhouse-connect/issues/777). - The `dsn` passed to `create_client`/`create_async_client` now percent-decodes the username, password, and database, so credentials containing reserved characters can be supplied URL-encoded (`pass%20word` becomes `pass word`). A literal `%` in a DSN must now be written as `%25`. A DSN with a username and no password now sends an empty password rather than the literal string `None`. Closes [#713](https://github.com/ClickHouse/clickhouse-connect/issues/713). diff --git a/clickhouse_connect/cc_sqlalchemy/datatypes/sqltypes.py b/clickhouse_connect/cc_sqlalchemy/datatypes/sqltypes.py index 94d8e67f..53561505 100644 --- a/clickhouse_connect/cc_sqlalchemy/datatypes/sqltypes.py +++ b/clickhouse_connect/cc_sqlalchemy/datatypes/sqltypes.py @@ -1,7 +1,17 @@ +import ipaddress +import uuid from collections.abc import Sequence from enum import Enum as PyEnum from sqlalchemy.exc import ArgumentError +from sqlalchemy.types import ( + ARRAY, + Float, + Integer, + Interval, + Numeric, + UserDefinedType, +) from sqlalchemy.types import ( Boolean as SqlaBoolean, ) @@ -11,18 +21,11 @@ from sqlalchemy.types import ( DateTime as SqlaDateTime, ) -from sqlalchemy.types import ( - Float, - Integer, - Interval, - Numeric, - UserDefinedType, -) from sqlalchemy.types import ( String as SqlaString, ) -from clickhouse_connect.cc_sqlalchemy.datatypes.base import ChSqlaType, schema_types +from clickhouse_connect.cc_sqlalchemy.datatypes.base import ChSqlaType, schema_types, sqla_type_from_name from clickhouse_connect.datatypes.base import EMPTY_TYPE_DEF, LC_TYPE_DEF, NULLABLE_TYPE_DEF, TypeDef from clickhouse_connect.datatypes.numeric import Enum8 as ChEnum8 from clickhouse_connect.datatypes.numeric import Enum16 as ChEnum16 @@ -209,43 +212,43 @@ def __init__(self, size: int = -1, type_def: TypeDef = None): class IPv4(ChSqlaType, UserDefinedType): - python_type = None + python_type = ipaddress.IPv4Address class IPv6(ChSqlaType, UserDefinedType): - python_type = None + python_type = ipaddress.IPv6Address class UUID(ChSqlaType, UserDefinedType): - python_type = None + python_type = uuid.UUID class Nothing(ChSqlaType, UserDefinedType): - python_type = None + python_type = type(None) class Point(ChSqlaType, UserDefinedType): - python_type = None + python_type = tuple class Ring(ChSqlaType, UserDefinedType): - python_type = None + python_type = list class Polygon(ChSqlaType, UserDefinedType): - python_type = None + python_type = list class MultiPolygon(ChSqlaType, UserDefinedType): - python_type = None + python_type = list class LineString(ChSqlaType, UserDefinedType): - python_type = None + python_type = list class MultiLineString(ChSqlaType, UserDefinedType): - python_type = None + python_type = list class Date(ChSqlaType, SqlaDate): @@ -412,8 +415,9 @@ def __new__(cls, element: ChSqlaType | type[ChSqlaType]): return element.__class__(type_def=TypeDef(wrappers, orig.keys, orig.values)) -class Array(ChSqlaType, UserDefinedType): +class Array(ChSqlaType, ARRAY): python_type = list + dimensions = 1 def __init__(self, element: ChSqlaType | type[ChSqlaType] = None, type_def: TypeDef = None): """ @@ -425,7 +429,12 @@ def __init__(self, element: ChSqlaType | type[ChSqlaType] = None, type_def: Type if callable(element): element = element() type_def = TypeDef(values=(element.name,)) - super().__init__(type_def) + ChSqlaType.__init__(self, type_def) + # Set item_type directly; calling ARRAY.__init__ would reject nested Array(Array(T)), + # which CH supports natively (CH expresses dimensions via nesting, not a dim count). + # as_tuple has no class-level default, so set it here to satisfy ARRAY result processing. + self.item_type = sqla_type_from_name(type_def.values[0]) + self.as_tuple = False class Map(ChSqlaType, UserDefinedType): @@ -489,7 +498,7 @@ class JSON(ChSqlaType, UserDefinedType): Note this isn't currently supported for insert/select, only table definitions """ - python_type = None + python_type = dict class Nested(ChSqlaType, UserDefinedType): @@ -497,11 +506,11 @@ class Nested(ChSqlaType, UserDefinedType): Note this isn't currently supported for insert/select, only table definitions """ - python_type = None + python_type = list class SimpleAggregateFunction(ChSqlaType, UserDefinedType): - python_type = None + python_type = str def __init__( self, @@ -532,7 +541,7 @@ class AggregateFunction(ChSqlaType, UserDefinedType): Note this isn't currently supported for insert/select, only table definitions """ - python_type = None + python_type = str def __init__(self, *params, type_def: TypeDef = None): """ diff --git a/clickhouse_connect/cc_sqlalchemy/dialect.py b/clickhouse_connect/cc_sqlalchemy/dialect.py index 4e47b96d..368e8042 100644 --- a/clickhouse_connect/cc_sqlalchemy/dialect.py +++ b/clickhouse_connect/cc_sqlalchemy/dialect.py @@ -5,7 +5,7 @@ from clickhouse_connect import dbapi from clickhouse_connect.cc_sqlalchemy import dialect_name, ischema_names -from clickhouse_connect.cc_sqlalchemy.inspector import ChInspector, get_table_metadata +from clickhouse_connect.cc_sqlalchemy.inspector import ChInspector, get_columns, get_table_metadata from clickhouse_connect.cc_sqlalchemy.sql import full_table from clickhouse_connect.cc_sqlalchemy.sql.compiler import ChStatementCompiler from clickhouse_connect.cc_sqlalchemy.sql.ddlcompiler import ChDDLCompiler @@ -94,6 +94,9 @@ def get_table_names(self, connection, schema=None, **kw): cmd += " FROM " + quote_identifier(schema) return [row.name for row in connection.execute(text(cmd))] + def get_columns(self, connection, table_name, schema=None, **kw): + return get_columns(connection, table_name, schema) + def get_primary_keys(self, connection, table_name, schema=None, **kw): return [] diff --git a/clickhouse_connect/cc_sqlalchemy/inspector.py b/clickhouse_connect/cc_sqlalchemy/inspector.py index 68875208..e566fec9 100644 --- a/clickhouse_connect/cc_sqlalchemy/inspector.py +++ b/clickhouse_connect/cc_sqlalchemy/inspector.py @@ -121,6 +121,36 @@ def get_dictionary_metadata(connection, table_name: str, schema: str | None = No return metadata +def get_columns(connection, table_name: str, schema: str | None = None) -> list[dict[str, Any]]: + table_metadata = get_table_metadata(connection, table_name, schema) + if table_metadata.engine == "Dictionary": + return get_dictionary_columns(connection, table_name, schema) + table_id = full_table(table_name, schema) + result_set = connection.execute(text(f"DESCRIBE TABLE {table_id}")) + if not result_set: + raise NoResultFound(f"Table {table_id} does not exist") + columns = [] + for row in result_set: + sqla_type = sqla_type_from_name(row.type.replace("\n", "")) + col = { + "name": row.name, + "type": sqla_type, + "nullable": sqla_type.nullable, + "autoincrement": False, + "comment": row.comment or None, + "clickhouse_codec": row.codec_expression or None, + "clickhouse_ttl": text(row.ttl_expression) if row.ttl_expression else None, + } + if row.default_type == "DEFAULT" and row.default_expression: + col["server_default"] = text(row.default_expression) + elif row.default_type == "MATERIALIZED" and row.default_expression: + col["clickhouse_materialized"] = text(row.default_expression) + elif row.default_type == "ALIAS" and row.default_expression: + col["clickhouse_alias"] = text(row.default_expression) + columns.append(col) + return columns + + class ChInspector(Inspector): def reflect_table( self, @@ -157,30 +187,4 @@ def reflect_table( table.kwargs["clickhouse_engine"] = table.engine def get_columns(self, table_name, schema=None, **_kwargs): - table_metadata = get_table_metadata(self.bind, table_name, schema) - if table_metadata.engine == "Dictionary": - return get_dictionary_columns(self.bind, table_name, schema) - table_id = full_table(table_name, schema) - result_set = self.bind.execute(text(f"DESCRIBE TABLE {table_id}")) - if not result_set: - raise NoResultFound(f"Table {table_id} does not exist") - columns = [] - for row in result_set: - sqla_type = sqla_type_from_name(row.type.replace("\n", "")) - col = { - "name": row.name, - "type": sqla_type, - "nullable": sqla_type.nullable, - "autoincrement": False, - "comment": row.comment or None, - "clickhouse_codec": row.codec_expression or None, - "clickhouse_ttl": text(row.ttl_expression) if row.ttl_expression else None, - } - if row.default_type == "DEFAULT" and row.default_expression: - col["server_default"] = text(row.default_expression) - elif row.default_type == "MATERIALIZED" and row.default_expression: - col["clickhouse_materialized"] = text(row.default_expression) - elif row.default_type == "ALIAS" and row.default_expression: - col["clickhouse_alias"] = text(row.default_expression) - columns.append(col) - return columns + return get_columns(self.bind, table_name, schema) diff --git a/tests/integration_tests/test_sqlalchemy/test_reflect.py b/tests/integration_tests/test_sqlalchemy/test_reflect.py index 311a7675..6da2669f 100644 --- a/tests/integration_tests/test_sqlalchemy/test_reflect.py +++ b/tests/integration_tests/test_sqlalchemy/test_reflect.py @@ -70,3 +70,55 @@ def test_get_table_names(test_engine: Engine, test_db: str): assert isinstance(system_tables, list) assert "columns" in system_tables assert "fake_table" not in system_tables + + +def test_metadata_reflect(test_engine: Engine, test_db: str): + """Dialect-level reflection. MetaData.reflect() exercises the + Dialect.get_multi_columns -> Dialect.get_columns path (not + Inspector.get_columns), which previously raised NotImplementedError. + The dialect does not reflect a primary key: ClickHouse PRIMARY KEY / + ORDER BY is not a uniqueness guarantee, so the identity key is left for + application code to declare explicitly.""" + common.set_setting("invalid_setting_action", "drop") + with test_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.reflect_pk_test")) + conn.execute( + text( + f"CREATE TABLE {test_db}.reflect_pk_test (org_id UInt32, id UInt64, payload String) ENGINE MergeTree ORDER BY (org_id, id)" + ) + ) + + metadata = db.MetaData(schema=test_db) + metadata.reflect(bind=test_engine, only=["reflect_pk_test"]) + table = metadata.tables[f"{test_db}.reflect_pk_test"] + + assert {c.name for c in table.columns} == {"org_id", "id", "payload"} + assert list(table.primary_key.columns) == [] + + # Direct autoload should also populate columns without a reflected PK. + table2 = db.Table("reflect_pk_test", db.MetaData(schema=test_db), autoload_with=test_engine) + assert {c.name for c in table2.columns} == {"org_id", "id", "payload"} + assert list(table2.primary_key.columns) == [] + + +def test_user_declared_primary_key(test_engine: Engine, test_db: str): + """A user-declared primary key on a pre-declared column survives reflection.""" + common.set_setting("invalid_setting_action", "drop") + with test_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.reflect_user_pk_test")) + conn.execute( + text( + f"CREATE TABLE {test_db}.reflect_user_pk_test (org_id UInt32, id UInt64, payload String) " + "ENGINE MergeTree ORDER BY (org_id, id)" + ) + ) + + table = db.Table( + "reflect_user_pk_test", + db.MetaData(schema=test_db), + db.Column("org_id", UInt32, primary_key=True), + db.Column("id", db.BigInteger, primary_key=True), + autoload_with=test_engine, + ) + assert [c.name for c in table.primary_key.columns] == ["org_id", "id"] + assert {c.name for c in table.columns} == {"org_id", "id", "payload"}