diff --git a/CHANGELOG.md b/CHANGELOG.md index e72e349e..95f6c0ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## UNRELEASED +### New Features +- Add an in-process `chdb` backend. `clickhouse_connect.get_client(interface="chdb")` (and `get_async_client`) returns a client backed by the embedded ClickHouse engine via the `chdb` Python package, with no server required. + ### Bug Fixes - Async client: `ca_cert="certifi"` shorthand now resolves to `certifi.where()`, matching the sync client. Previously the async path passed the literal string to `ssl_context.load_verify_locations`, producing `FileNotFoundError`. Closes [#742](https://github.com/ClickHouse/clickhouse-connect/issues/742) - Fix SQLAlchemy dialect rendering for `ILIKE` and `NOT ILIKE` expressions to use native ClickHouse syntax instead of the generic SQLAlchemy `lower(...) LIKE lower(...)` fallback. diff --git a/clickhouse_connect/driver/__init__.py b/clickhouse_connect/driver/__init__.py index d0fe9f1d..4e616911 100644 --- a/clickhouse_connect/driver/__init__.py +++ b/clickhouse_connect/driver/__init__.py @@ -159,6 +159,14 @@ def create_client( limits. Only available for query operations (not inserts). Default: False :return: ClickHouse Connect Client instance """ + if interface == "chdb": + return _create_chdb_client( + database=database, + settings=settings, + generic_args=generic_args, + kwargs=kwargs, + ) + host, username, password, port, database, interface = _parse_connection_params( host, username, password, port, database, interface, secure, dsn, kwargs ) @@ -264,6 +272,14 @@ async def create_async_client( limits. Only available for query operations (not inserts). 
def _create_chdb_client(
    *,
    database: str,
    settings: dict[str, Any] | None,
    generic_args: dict[str, Any] | None,
    kwargs: dict[str, Any],
) -> Client:
    """
    Build a synchronous ChdbClient for `interface="chdb"`.

    :param database: database name forwarded to ChdbClient
    :param settings: optional ClickHouse settings; copied, never mutated
    :param generic_args: optional extra arguments that are folded into the settings,
      with a leading "ch_" removed from each name
    :param kwargs: remaining keyword arguments, forwarded to ChdbClient unchanged
    :return: the constructed ChdbClient
    """
    # Imported lazily so the chdb backend module is only loaded when requested.
    from clickhouse_connect.driver.chdbclient import ChdbClient

    merged = dict(settings) if settings else {}
    for raw_name, value in (generic_args or {}).items():
        key = raw_name[3:] if raw_name.startswith("ch_") else raw_name
        merged[key] = value
    return ChdbClient(database=database, settings=merged, **kwargs)
class AsyncChdbClient(Client):
    """
    Asynchronous facade over a synchronous ChdbClient.

    Every coroutine hands the matching ChdbClient method to the running event
    loop's default executor and awaits the outcome. The setting accessors,
    `min_version` and `create_query_context` are plain synchronous delegates.
    """

    def __init__(self, sync: ChdbClient):
        self._sync = sync
        # Copy the attributes callers commonly read off a client object. The
        # values are captured once, at construction time.
        for attr in (
            "server_tz",
            "server_version",
            "server_settings",
            "database",
            "uri",
            "query_limit",
            "query_retries",
            "tz_mode",
            "_tz_source",
            "_apply_server_tz",
            "_dst_safe",
            "show_clickhouse_errors",
            "protocol_version",
            "write_compression",
            "compression",
            "_read_format",
            "_write_format",
            "_transform",
        ):
            setattr(self, attr, getattr(sync, attr))

    @property
    def chdb_connection(self):
        """The wrapped client's underlying chdb connection."""
        return self._sync.chdb_connection

    async def _run(self, func, *args, **kwargs):
        """Call `func(*args, **kwargs)` on the loop's default executor and await it."""

        def call():
            return func(*args, **kwargs)

        return await asyncio.get_running_loop().run_in_executor(None, call)

    # ---- synchronous delegates (not scheduled on the executor) ----

    def set_client_setting(self, key: str, value: Any) -> None:
        self._sync.set_client_setting(key, value)

    def get_client_setting(self, key: str) -> str | None:
        return self._sync.get_client_setting(key)

    def set_access_token(self, access_token: str) -> None:
        self._sync.set_access_token(access_token)

    def min_version(self, version_str: str) -> bool:
        return self._sync.min_version(version_str)

    def create_query_context(self, *args, **kwargs) -> QueryContext:
        return self._sync.create_query_context(*args, **kwargs)

    # ---- coroutine wrappers ----

    async def _query_with_context(self, context: QueryContext) -> QueryResult:  # type: ignore[override]
        return await self._run(self._sync._query_with_context, context)

    async def query(  # type: ignore[override]
        self,
        query: str | None = None,
        parameters: Sequence | dict[str, Any] | None = None,
        settings: dict[str, Any] | None = None,
        query_formats: dict[str, str] | None = None,
        column_formats: dict[str, str | dict[str, str]] | None = None,
        encoding: str | None = None,
        use_none: bool | None = None,
        column_oriented: bool | None = None,
        use_numpy: bool | None = None,
        max_str_len: int | None = None,
        context: QueryContext | None = None,
        query_tz: str | tzinfo | None = None,
        column_tzs: dict[str, str | tzinfo] | None = None,
        external_data: ExternalData | None = None,
        transport_settings: dict[str, str] | None = None,
        tz_mode: TzMode | None = None,
    ) -> QueryResult:
        return await self._run(
            self._sync.query,
            query=query,
            parameters=parameters,
            settings=settings,
            query_formats=query_formats,
            column_formats=column_formats,
            encoding=encoding,
            use_none=use_none,
            column_oriented=column_oriented,
            use_numpy=use_numpy,
            max_str_len=max_str_len,
            context=context,
            query_tz=query_tz,
            column_tzs=column_tzs,
            external_data=external_data,
            transport_settings=transport_settings,
            tz_mode=tz_mode,
        )

    async def query_column_block_stream(self, *args, **kwargs) -> StreamContext:  # type: ignore[override]
        return await self._run(self._sync.query_column_block_stream, *args, **kwargs)

    async def query_row_block_stream(self, *args, **kwargs) -> StreamContext:  # type: ignore[override]
        return await self._run(self._sync.query_row_block_stream, *args, **kwargs)

    async def query_rows_stream(self, *args, **kwargs) -> StreamContext:  # type: ignore[override]
        return await self._run(self._sync.query_rows_stream, *args, **kwargs)

    async def query_np(self, *args, **kwargs) -> numpy.ndarray:
        return await self._run(self._sync.query_np, *args, **kwargs)

    async def query_np_stream(self, *args, **kwargs) -> StreamContext:  # type: ignore[override]
        return await self._run(self._sync.query_np_stream, *args, **kwargs)

    async def query_df(self, *args, **kwargs) -> pandas.DataFrame:
        return await self._run(self._sync.query_df, *args, **kwargs)

    async def query_df_stream(self, *args, **kwargs) -> StreamContext:  # type: ignore[override]
        return await self._run(self._sync.query_df_stream, *args, **kwargs)

    async def query_arrow(self, *args, **kwargs) -> pyarrow.Table:
        return await self._run(self._sync.query_arrow, *args, **kwargs)

    async def query_arrow_stream(self, *args, **kwargs) -> StreamContext:  # type: ignore[override]
        return await self._run(self._sync.query_arrow_stream, *args, **kwargs)

    async def query_df_arrow(self, *args, **kwargs) -> pandas.DataFrame | polars.DataFrame:
        return await self._run(self._sync.query_df_arrow, *args, **kwargs)

    async def query_df_arrow_stream(self, *args, **kwargs) -> StreamContext:  # type: ignore[override]
        return await self._run(self._sync.query_df_arrow_stream, *args, **kwargs)

    async def command(  # type: ignore[override]
        self,
        cmd: str,
        parameters: Sequence | dict[str, Any] | None = None,
        data: str | bytes | None = None,
        settings: dict[str, Any] | None = None,
        use_database: bool = True,
        external_data: ExternalData | None = None,
        transport_settings: dict[str, str] | None = None,
    ) -> str | int | Sequence[str] | QuerySummary:
        return await self._run(
            self._sync.command,
            cmd,
            parameters=parameters,
            data=data,
            settings=settings,
            use_database=use_database,
            external_data=external_data,
            transport_settings=transport_settings,
        )

    async def ping(self) -> bool:  # type: ignore[override]
        return await self._run(self._sync.ping)

    async def raw_query(  # type: ignore[override]
        self,
        query: str,
        parameters: Sequence | dict[str, Any] | None = None,
        settings: dict[str, Any] | None = None,
        fmt: str | None = None,
        use_database: bool = True,
        external_data: ExternalData | None = None,
        transport_settings: dict[str, str] | None = None,
    ) -> bytes:
        return await self._run(
            self._sync.raw_query,
            query,
            parameters=parameters,
            settings=settings,
            fmt=fmt,
            use_database=use_database,
            external_data=external_data,
            transport_settings=transport_settings,
        )

    async def raw_stream(  # type: ignore[override]
        self,
        query: str,
        parameters: Sequence | dict[str, Any] | None = None,
        settings: dict[str, Any] | None = None,
        fmt: str | None = None,
        use_database: bool = True,
        external_data: ExternalData | None = None,
        transport_settings: dict[str, str] | None = None,
    ) -> io.IOBase:
        return await self._run(
            self._sync.raw_stream,
            query,
            parameters=parameters,
            settings=settings,
            fmt=fmt,
            use_database=use_database,
            external_data=external_data,
            transport_settings=transport_settings,
        )

    async def insert(  # type: ignore[override]
        self,
        table: str | None = None,
        data=None,
        column_names: str | Iterable[str] = "*",
        database: str | None = None,
        column_types: Sequence[ClickHouseType] | None = None,
        column_type_names: Sequence[str] | None = None,
        column_oriented: bool = False,
        settings: dict[str, Any] | None = None,
        context: InsertContext | None = None,
        transport_settings: dict[str, str] | None = None,
    ) -> QuerySummary:
        return await self._run(
            self._sync.insert,
            table=table,
            data=data,
            column_names=column_names,
            database=database,
            column_types=column_types,
            column_type_names=column_type_names,
            column_oriented=column_oriented,
            settings=settings,
            context=context,
            transport_settings=transport_settings,
        )

    async def insert_df(self, *args, **kwargs) -> QuerySummary:  # type: ignore[override]
        return await self._run(self._sync.insert_df, *args, **kwargs)

    async def insert_arrow(self, *args, **kwargs) -> QuerySummary:  # type: ignore[override]
        return await self._run(self._sync.insert_arrow, *args, **kwargs)

    async def insert_df_arrow(self, *args, **kwargs) -> QuerySummary:  # type: ignore[override]
        return await self._run(self._sync.insert_df_arrow, *args, **kwargs)

    async def data_insert(self, context: InsertContext) -> QuerySummary:  # type: ignore[override]
        return await self._run(self._sync.data_insert, context)

    async def raw_insert(  # type: ignore[override]
        self,
        table: str | None = None,
        column_names: Sequence[str] | None = None,
        insert_block: str | bytes | Generator[bytes, None, None] | BinaryIO | None = None,
        settings: dict[str, Any] | None = None,
        fmt: str | None = None,
        compression: str | None = None,
        transport_settings: dict[str, str] | None = None,
    ) -> QuerySummary:
        return await self._run(
            self._sync.raw_insert,
            table=table,
            column_names=column_names,
            insert_block=insert_block,
            settings=settings,
            fmt=fmt,
            compression=compression,
            transport_settings=transport_settings,
        )

    async def create_insert_context(self, *args, **kwargs) -> InsertContext:  # type: ignore[override]
        return await self._run(self._sync.create_insert_context, *args, **kwargs)

    async def close(self) -> None:  # type: ignore[override]
        await self._run(self._sync.close)

    async def close_connections(self) -> None:  # type: ignore[override]
        await self._run(self._sync.close_connections)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Close on exit; returning False lets any in-flight exception propagate.
        await self.close()
        return False
+""" + +from __future__ import annotations + +import io +import json +import logging +import os +import re +import sys +import tempfile +import threading +from collections.abc import Generator, Sequence +from typing import TYPE_CHECKING, Any, BinaryIO + +from clickhouse_connect import common +from clickhouse_connect.datatypes.registry import get_from_name +from clickhouse_connect.driver.binding import bind_query, quote_identifier +from clickhouse_connect.driver.client import Client +from clickhouse_connect.driver.common import coerce_int +from clickhouse_connect.driver.ctypes import RespBuffCls +from clickhouse_connect.driver.exceptions import ( + DatabaseError, + NotSupportedError, + ProgrammingError, + StreamFailureError, +) +from clickhouse_connect.driver.external import ExternalData +from clickhouse_connect.driver.insert import InsertContext +from clickhouse_connect.driver.query import QueryContext, QueryResult, TzMode, TzSource +from clickhouse_connect.driver.summary import QuerySummary +from clickhouse_connect.driver.transform import NativeTransform + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +_columns_only_re = re.compile(r"LIMIT 0\s*$", re.IGNORECASE) + +# chdb's `send_query` emits each ClickHouse block as a self-contained encoding in the +# requested format. For formats that have row-level (or block-level) self-description +# and no global header/footer/file structure, concatenating chunks yields a valid +# stream the caller's parser can consume directly. Other formats (Arrow, Parquet, +# JSON, *WithNames variants, ...) would emit duplicated headers / multiple file +# markers per chunk, which is not a valid larger stream. For those we fall back to a +# single non-streaming query so the result is one well-formed payload. 
class _ChdbStreamSource:
    """
    Byte source that feeds a response buffer from a chdb streaming result.

    Each non-empty block is yielded as bytes. Any error raised while fetching a
    block is re-raised as `StreamFailureError`. The lock passed in is expected to
    be held already by the caller; it is released exactly once, by `close()`.
    """

    __slots__ = ("_sr", "_lock", "_released", "last_message", "exception_tag")

    def __init__(self, streaming_result, lock: threading.Lock):
        self._sr = streaming_result
        self._lock = lock
        self._released = False
        self.last_message = None
        self.exception_tag = None

    def _iter_payloads(self):
        """Generator body behind `gen`; always calls `close()` when it finishes."""
        try:
            while True:
                try:
                    block = next(self._sr)
                except StopIteration:
                    break
                except Exception as err:  # noqa: BLE001
                    raise StreamFailureError(_format_error_message(str(err))) from err
                # chdb chunks expose `.bytes()`; anything else is coerced with bytes().
                data = block.bytes() if hasattr(block, "bytes") else bytes(block)
                if data:
                    yield data
        finally:
            self.close()

    @property
    def gen(self):
        return self._iter_payloads()

    def close(self):
        """Close the streaming result (best effort) and release the lock once."""
        if self._released:
            return
        self._released = True
        try:
            closer = getattr(self._sr, "close", None)
            if closer:
                closer()
        except Exception:  # noqa: BLE001
            logger.debug("Error closing chdb StreamingResult", exc_info=True)
        finally:
            # The lock may already be free; releasing an unlocked Lock raises RuntimeError.
            try:
                self._lock.release()
            except RuntimeError:
                pass
def _drain_to_bytes(block) -> bytes:
    """
    Collect any supported insert_block shape into a single bytes value.

    Accepted shapes, checked in this order: bytes-like objects, `str`, objects
    with a `to_pybytes()` method, file-like objects with `read()`, and finally
    any iterable of chunks.

    :param block: the insert payload in one of the shapes above
    :return: the whole payload as `bytes`
    """

    def as_bytes(piece) -> bytes:
        # bytes-like pieces are copied as-is; everything else (str in practice)
        # is encoded. memoryview has no .encode(), so it must be handled here.
        if isinstance(piece, (bytes, bytearray, memoryview)):
            return bytes(piece)
        return piece.encode()

    if isinstance(block, (bytes, bytearray, memoryview)):
        return bytes(block)
    if isinstance(block, str):
        return block.encode()
    if hasattr(block, "to_pybytes"):
        return block.to_pybytes()
    if hasattr(block, "read"):
        # A text-mode file object returns str from read(); encode it so the
        # declared bytes return type always holds.
        return as_bytes(block.read())
    return b"".join(as_bytes(chunk) for chunk in block)
+ valid_transport_settings: set[str] = { + "database", + "client_protocol_version", + "session_id", + "session_timeout", + "session_check", + "query_id", + "quota_key", + "compress", + "decompress", + "wait_end_of_query", + "buffer_size", + "role", + "send_progress_in_http_headers", + "http_headers_progress_interval_ms", + "enable_http_compression", + } + + def __init__( + self, + chdb_path: str = ":memory:", + chdb_options: dict[str, Any] | None = None, + database: str | None = None, + settings: dict[str, Any] | None = None, + query_limit: int = 0, + tz_source: TzSource | None = None, + tz_mode: TzMode | None = None, + show_clickhouse_errors: bool | None = None, + **ignored, + ): + if sys.platform.startswith("win"): + raise NotSupportedError("chdb backend is not supported on Windows") + + import chdb + + self._chdb_path = chdb_path or ":memory:" + self._chdb_options = dict(chdb_options) if chdb_options else {} + self._connection_string = _build_conn_string(self._chdb_path, self._chdb_options) + self._chdb_module = chdb + self._conn = chdb.connect(self._connection_string) + self._lock = threading.Lock() + self._closed = False + self._client_settings: dict[str, str] = {} + self._initial_settings = dict(settings or {}) + self._read_format = "Native" + self._write_format = "Native" + self._transform = NativeTransform() + self._integration_libs: set[str] = set() + self.uri = f"chdb://{self._chdb_path}" + self.write_compression = None + self.compression = None + + # coerce_int handles None-or-string flexibility + super().__init__( + database=database, + uri=self.uri, + query_limit=coerce_int(query_limit), + query_retries=0, + server_host_name=None, + tz_source=tz_source, + tz_mode=tz_mode, + show_clickhouse_errors=show_clickhouse_errors, + autoconnect=True, + ) + + for k, v in self._initial_settings.items(): + self.set_client_setting(k, v) + + if self.database: + self._exec_raw_query(f"USE {quote_identifier(self.database)}") + + logger.info( + "ChdbClient connected: 
def _append_settings_clause(self, sql, settings):
    """
    Append per-call settings to a SQL statement as a SETTINGS clause.

    :param sql: statement as `str` or `bytes`; the same type is returned
    :param settings: mapping of setting name to string value; when empty or None
      the statement is returned unchanged
    :return: the statement with the settings attached

    If the statement already contains ` SETTINGS ` and that clause is the last
    one, the new settings are comma-appended to it. If a FORMAT clause follows
    the existing SETTINGS clause (callers in this class append `FORMAT <name>`
    before calling this method), comma-appending would attach the settings to
    the FORMAT clause and produce invalid SQL, so a new trailing SETTINGS clause
    is started instead.
    """
    if not settings:
        return sql
    extras = ", ".join(f"{k} = {self._quote_setting_value(v)}" for k, v in settings.items())
    # raw_query can pass bytes SQL when binary parameter substitution produced
    # non-UTF-8 byte sequences, so keyword detection has to work on both types.
    is_bytes = isinstance(sql, bytes)
    marker = b" SETTINGS " if is_bytes else " SETTINGS "
    format_kw = rb"\sFORMAT\s" if is_bytes else r"\sFORMAT\s"
    upper = sql.upper()
    last_settings = upper.rfind(marker)
    extend_existing = last_settings >= 0 and re.search(format_kw, upper[last_settings:]) is None
    sep = ", " if extend_existing else " SETTINGS "
    if is_bytes:
        return sql + sep.encode() + extras.encode()
    return f"{sql}{sep}{extras}"
def _query_with_context(self, context: QueryContext) -> QueryResult:
    """
    Execute the query described by `context` against chdb and parse the result.

    :param context: query context; its `block_info` flag is set to False here
    :return: parsed QueryResult (empty for inserts; schema-only for `LIMIT 0` queries)
    :raises NotSupportedError: if the context carries external data
    :raises ProgrammingError: if the client is closed (via `_ensure_open`)
    :raises DatabaseError: if chdb rejects the query (via `_wrap_exception`)
    """
    self._ensure_open()
    if context.external_data is not None:
        raise NotSupportedError("external_data is not supported by the chdb backend")
    # chdb's Native output does not include the 8-byte block_info prefix that the
    # HTTP server emits when client_protocol_version is set.
    context.block_info = False
    final_query = self._prep_query(context)
    if isinstance(final_query, bytes):
        final_query = final_query.decode()
    params = self._strip_param_prefix(context.bind_params)
    if not context.is_insert and _columns_only_re.search(context.uncommented_query):
        # chdb emits zero Native bytes for a LIMIT 0 query, so the Native parser
        # would return an empty result with no column metadata. Fetch the schema
        # via JSON instead, matching the HTTP client's columns-only fast path.
        return self._fetch_columns_only(context, final_query, params)
    if context.is_insert:
        # INSERT ... VALUES carries its data inline and has no result block to parse;
        # appending `FORMAT Native` to a VALUES statement is a syntax error.
        sql = self._append_settings_clause(final_query, self._filter_per_call_settings(context.settings))
        self._exec_raw_query(sql, "TabSeparated", params=params)
        return QueryResult([])
    sql = f"{final_query}\n FORMAT Native"
    sql = self._append_settings_clause(sql, self._filter_per_call_settings(context.settings))
    if context.streaming:
        # Use chdb's streaming `send_query` so mid-execution engine errors
        # (e.g. throwIf, division by zero on row N) surface during result
        # iteration as `StreamFailureError`, matching HTTP's contract. The
        # non-streaming `conn.query` would raise eagerly and lose lazy-error
        # semantics — we only opt into that for true streaming results, since
        # holding the per-client lock for the lifetime of a non-iterated
        # QueryResult would deadlock subsequent calls.
        self._ensure_open()
        self._lock.acquire()
        try:
            streaming = self._conn.send_query(sql, "Native", params=params or {})
        except Exception as ex:  # noqa: BLE001
            self._lock.release()
            raise self._wrap_exception(ex) from ex
        # From here on the stream source owns the lock and releases it in close().
        source = _ChdbStreamSource(streaming, self._lock)
    else:
        source = _BytesSource(self._exec_raw_query(sql, "Native", params=params))
    try:
        query_result = self._transform.parse_response(RespBuffCls(source), context)
    except BaseException:
        # The stream source only releases the lock from its generator's `finally`
        # or from close(). If buffer construction or parsing fails before that
        # generator has been advanced, nothing else would release the lock, so
        # close explicitly. close() is idempotent on the stream source and a
        # no-op on the in-memory source.
        source.close()
        raise
    query_result.summary = {}
    return query_result
+ final_query = self._append_settings_clause(final_query, self._filter_per_call_settings(settings)) + # HTTP path defaults to server's TabSeparated when no fmt is provided. + return self._exec_raw_query(final_query, fmt or "TabSeparated", params=self._strip_param_prefix(bound)) + + def raw_stream( + self, + query: str, + parameters: Sequence | dict[str, Any] | None = None, + settings: dict[str, Any] | None = None, + fmt: str | None = None, + use_database: bool = True, + external_data: ExternalData | None = None, + transport_settings: dict[str, str] | None = None, + ) -> io.IOBase: + if external_data is not None: + raise NotSupportedError("external_data is not supported by the chdb backend") + final_query, bound = bind_query(query, parameters, self.server_tz) + if isinstance(final_query, bytes): + final_query = final_query.decode() + final_query = self._append_settings_clause(final_query, self._filter_per_call_settings(settings)) + params = self._strip_param_prefix(bound) + output_fmt = fmt or "TabSeparated" + if output_fmt not in _STREAM_SAFE_FORMATS: + # Formats with global structure (Arrow IPC, Parquet, JSON, *WithNames, ...) + # can't be assembled from chdb's per-block chunks. Fetch as a single + # well-formed payload and wrap as an in-memory stream. + data = self._exec_raw_query(final_query, output_fmt, params=params) + return io.BytesIO(data) + self._ensure_open() + # Acquire the lock for the lifetime of the streaming read so concurrent + # callers don't interleave queries on the same chdb connection. 
def command(
    self,
    cmd: str,
    parameters: Sequence | dict[str, Any] | None = None,
    data: str | bytes | None = None,
    settings: dict[str, Any] | None = None,
    use_database: bool = True,
    external_data: ExternalData | None = None,
    transport_settings: dict[str, str] | None = None,
) -> str | int | Sequence[str] | QuerySummary:
    """
    Run a command against the embedded chdb session and return its result.

    :param cmd: SQL command text; parameters are bound with ``bind_query``.
    :param parameters: Optional values to bind into ``cmd``.
    :param data: Optional payload appended to the command after a newline
        (bytes are decoded first).
    :param settings: Per-call settings. They are applied to the session with
        SET for the duration of the command and restored afterwards.
    :param use_database: Accepted for signature compatibility; not read here.
    :param external_data: Not supported by this backend.
    :param transport_settings: Accepted for signature compatibility; not read here.
    :return: ``QuerySummary({})`` when the response body is empty; otherwise the
        body decoded and split on tabs. A single token is returned as ``int``
        when it parses as one, else as ``str``; several tokens are returned as
        a list of strings. A body that is not valid UTF-8 is returned as
        ``str(body)``.
    :raises NotSupportedError: if ``external_data`` is provided.
    """
    if external_data is not None:
        raise NotSupportedError("external_data is not supported by the chdb backend")
    cmd, bound = bind_query(cmd, parameters, self.server_tz)
    if isinstance(cmd, bytes):
        cmd = cmd.decode()
    params = self._strip_param_prefix(bound)
    if data is not None:
        data_str = data.decode() if isinstance(data, bytes) else data
        cmd = f"{cmd}\n{data_str}"
    per_call = self._filter_per_call_settings(settings)
    # ClickHouse DDL doesn't accept a SETTINGS clause; apply per-call settings to
    # the chdb session via SET before running the command, then restore them
    # afterwards so they don't leak into the session.
    snapshot: dict[str, tuple[str, bool]] = {}
    try:
        if per_call:
            snapshot = self._snapshot_settings(list(per_call.keys()))
            # Applying the settings happens inside the try block so that a
            # failure on the second or later SET still restores the ones that
            # were already applied.
            for k, v in per_call.items():
                self._persist_setting(k, v)
        body = self._exec_raw_query(cmd, self._format_for_command(), params=params)
    finally:
        if snapshot:
            self._restore_settings(snapshot)
    if not body:
        return QuerySummary({})
    try:
        text = body.decode()
    except UnicodeDecodeError:
        return str(body)
    # Match HTTP client semantics: strip trailing newline, split by tab, single
    # token tries to coerce to int.
    if text.endswith("\n"):
        text = text[:-1]
    result = text.split("\t")
    if len(result) == 1:
        try:
            return int(result[0])
        except ValueError:
            return result[0]
    return result
def close(self) -> None:
    """
    Close the embedded chdb connection; calling it again is a no-op.

    The connection is closed while holding the per-client lock. Any error
    raised while closing is logged at debug level and swallowed, and the
    client is marked closed either way.
    """
    if not self._closed:
        try:
            self._lock.acquire()
            try:
                self._conn.close()
            finally:
                self._lock.release()
        except Exception:  # noqa: BLE001
            logger.debug("Error closing chdb connection", exc_info=True)
        self._closed = True
We're going to + # write only the Native bytes to a file and INSERT FROM INFILE, so the + # prefix must be skipped. + for chunk in self._transform.build_insert(context): + if context.insert_exception is not None: + ex = context.insert_exception + context.insert_exception = None + raise ex + if first_chunk: + nl = chunk.find(b"\n") + if nl >= 0: + chunk = chunk[nl + 1 :] + first_chunk = False + tmp.write(chunk) + finally: + tmp.close() + + cols = ", ".join(quote_identifier(c) for c in context.column_names) + per_call = self._filter_per_call_settings(context.settings) + settings_clause = ( + f" SETTINGS {', '.join(f'{k} = {self._quote_setting_value(v)}' for k, v in per_call.items())}" if per_call else "" + ) + sql = f"INSERT INTO {context.table} ({cols}) FROM INFILE '{tmp.name}'{settings_clause} FORMAT Native" + self._exec_raw_query(sql, "TabSeparated") + return QuerySummary({}) + finally: + try: + os.unlink(tmp.name) + except OSError: + pass + context.data = None + + # ---- integration tagging ------------------------------------------ + + def _add_integration_tag(self, name: str) -> None: + # No User-Agent header to update for in-process chdb; just record for + # potential future use. + self._integration_libs.add(name) + + +class _ChdbStreamFile(io.RawIOBase): + """ + File-like adapter wrapping chdb's StreamingResult iterator so callers in + clickhouse-connect (which expect an io.IOBase / aiohttp-style stream) can + iterate bytes block-by-block. + + Holds a per-client lock for its lifetime so the chdb connection is not used + concurrently by another caller while a stream is in flight. 
+ """ + + def __init__(self, streaming_result, lock: threading.Lock): + super().__init__() + self._sr = streaming_result + self._lock = lock + self._buf = b"" + self._eof = False + self._closed_flag = False + + def readable(self) -> bool: + return True + + def _pull(self) -> bytes: + while True: + try: + chunk = next(self._sr) + except StopIteration: + self._eof = True + return b"" + except Exception as ex: # noqa: BLE001 + # chdb wraps mid-stream engine errors as RuntimeError. Surface them + # as StreamFailureError so callers can catch them with the same + # exception type used by the HTTP backend's mid-stream failures. + msg = _format_error_message(str(ex)) + self._eof = True + raise StreamFailureError(msg) from ex + payload = chunk.bytes() if hasattr(chunk, "bytes") else bytes(chunk) + if payload: + return payload + + def read(self, size: int | None = -1) -> bytes: + if self._closed_flag: + return b"" + if size is None or size < 0: + parts = [self._buf] + self._buf = b"" + while not self._eof: + chunk = self._pull() + if not chunk: + break + parts.append(chunk) + return b"".join(parts) + while len(self._buf) < size and not self._eof: + chunk = self._pull() + if not chunk: + break + self._buf += chunk + if not self._buf: + return b"" + out = self._buf[:size] + self._buf = self._buf[size:] + return out + + def readinto(self, buf) -> int: + data = self.read(len(buf)) + n = len(data) + if n: + buf[:n] = data + return n + + def close(self) -> None: + if self._closed_flag: + return + self._closed_flag = True + try: + close = getattr(self._sr, "close", None) + if close: + close() + except Exception: # noqa: BLE001 + logger.debug("Error closing chdb StreamingResult", exc_info=True) + finally: + try: + self._lock.release() + except RuntimeError: + pass + super().close() diff --git a/setup.py b/setup.py index 856e0a19..d111a410 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ def run_setup(try_c: bool = True): install_requires=[ "certifi", "urllib3>=1.26", + 
'chdb>=4.1.7; sys_platform != "win32"', 'tzdata; sys_platform == "win32"', 'zstandard; python_version<"3.14"', 'zstandard>=0.25.0; python_version>="3.14"', diff --git a/tests/unit_tests/test_driver/test_chdbclient.py b/tests/unit_tests/test_driver/test_chdbclient.py new file mode 100644 index 00000000..b6dfd3cf --- /dev/null +++ b/tests/unit_tests/test_driver/test_chdbclient.py @@ -0,0 +1,1031 @@ +""" +Unit tests for the in-process chdb client backend. + +These tests do not require a ClickHouse server — chdb is the embedded engine. +Skipped automatically if `chdb` is not installable (e.g. Windows or bare +install). +""" + +from __future__ import annotations + +import asyncio +import io +import os +from datetime import date, datetime +from decimal import Decimal +from uuid import UUID + +import pytest + +chdb = pytest.importorskip("chdb") + +import clickhouse_connect # noqa: E402 +from clickhouse_connect.driver.chdbclient import _build_conn_string, _format_error_message # noqa: E402 +from clickhouse_connect.driver.exceptions import ( # noqa: E402 + DatabaseError, + NotSupportedError, + ProgrammingError, +) + + +@pytest.fixture +def client(): + c = clickhouse_connect.get_client(interface="chdb") + yield c + c.close() + + +# ---- basic protocol ---- + + +def test_ping(client): + assert client.ping() is True + + +def test_server_version_populated(client): + assert client.server_version + assert client.server_version.split(".")[0].isdigit() + + +def test_uri_shape(): + c = clickhouse_connect.get_client(interface="chdb", chdb_path=":memory:") + try: + assert c.uri.startswith("chdb://") + finally: + c.close() + + +def test_chdb_connection_escape_hatch_exposed(client): + assert client.chdb_connection is not None + + +# ---- query / command ---- + + +def test_command_returns_scalar(client): + assert client.command("SELECT 13") == 13 + assert client.command("SELECT 'user_1'") == "user_1" + + +def test_command_returns_tuple_for_multiple_columns(client): + result = 
client.command("SELECT 79, 'user_2'") + assert result == ["79", "user_2"] + + +def test_query_primitives(client): + r = client.query( + "SELECT toInt32(13) AS i, toString('user_1') AS s, toFloat64(3.14) AS f", + ) + assert r.column_names == ("i", "s", "f") + assert r.result_rows == [(13, "user_1", 3.14)] + + +def test_query_nullable_and_low_cardinality(client): + r = client.query("SELECT CAST(NULL AS Nullable(Int64)) AS n, CAST('user_2' AS LowCardinality(String)) AS lc") + row = r.result_rows[0] + assert row[0] is None + assert row[1] == "user_2" + + +def test_query_dates_decimals(client): + r = client.query("SELECT toDate('2026-05-19') AS d, toDateTime('2026-05-19 10:30:00', 'UTC') AS dt, toDecimal64(123.456, 3) AS dec") + d, dt, dec = r.result_rows[0] + assert d == date(2026, 5, 19) + assert dt == datetime(2026, 5, 19, 10, 30, 0) + assert dec == Decimal("123.456") + + +def test_query_array_and_map(client): + r = client.query("SELECT [1, 2, 3]::Array(UInt32) AS arr, map('user_1', 13, 'user_2', 79) AS m") + arr, m = r.result_rows[0] + assert list(arr) == [1, 2, 3] + assert m == {"user_1": 13, "user_2": 79} + + +def test_query_multi_row(client): + r = client.query("SELECT number FROM numbers(5)") + assert [row[0] for row in r.result_rows] == [0, 1, 2, 3, 4] + + +def test_query_empty(client): + r = client.query("SELECT 1 WHERE 0") + assert r.result_rows == [] + + +def test_raw_query_pass_through(client): + body = client.raw_query("SELECT 13 AS x", fmt="TabSeparated") + assert body == b"13\n" + + +# ---- insert paths ---- + + +def test_insert_row_data(client): + client.command("CREATE TABLE row_insert_test (id UInt32, name String) ENGINE = Memory") + client.insert( + "row_insert_test", + [[13, "user_1"], [79, "user_2"]], + column_names=["id", "name"], + ) + r = client.query("SELECT id, name FROM row_insert_test ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_insert_dataframe(client): + pd = pytest.importorskip("pandas") + 
client.command("CREATE TABLE df_insert_test (id UInt32, v Float64) ENGINE = Memory") + df = pd.DataFrame({"id": [13, 79, 103], "v": [1.5, 2.5, 3.5]}) + client.insert_df("df_insert_test", df) + r = client.query("SELECT id, v FROM df_insert_test ORDER BY id") + assert r.result_rows == [(13, 1.5), (79, 2.5), (103, 3.5)] + + +def test_insert_dataframe_reordered_columns(client): + pd = pytest.importorskip("pandas") + client.command("CREATE TABLE df_reorder (id UInt32, v Float64) ENGINE = Memory") + df = pd.DataFrame({"v": [9.5, 10.5], "id": [13, 79]}) # reversed + client.insert_df("df_reorder", df) + r = client.query("SELECT id, v FROM df_reorder ORDER BY id") + assert r.result_rows == [(13, 9.5), (79, 10.5)] + + +def test_raw_insert_bytes_round_trip(client): + client.command("CREATE TABLE raw_insert_test (id UInt32, v String) ENGINE = Memory") + csv = b"13,user_1\n79,user_2\n" + client.raw_insert("raw_insert_test", insert_block=csv, fmt="CSV") + r = client.query("SELECT id, v FROM raw_insert_test ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +# ---- session semantics ---- + + +def test_session_persistence_within_client(client): + client.command("CREATE TEMPORARY TABLE temp_persist (id Int32)") + client.command("INSERT INTO temp_persist VALUES (13), (79)") + r = client.query("SELECT count() FROM temp_persist") + assert r.result_rows[0][0] == 2 + + +def test_set_client_setting_persists(client): + client.set_client_setting("max_block_size", 1000) + assert client.get_client_setting("max_block_size") == "1000" + + +def _read_session_setting(client, name: str) -> str: + body = client.raw_query(f"SELECT value FROM system.settings WHERE name = '{name}'", fmt="TabSeparated") + return body.decode().strip() + + +def test_command_per_call_setting_does_not_leak(client): + before = _read_session_setting(client, "max_block_size") + client.command("SELECT 1", settings={"max_block_size": 13}) + after = _read_session_setting(client, "max_block_size") + 
def _row_count(client, sql, fmt):
    """
    Run ``sql`` through ``raw_query`` and return the whole response payload.

    Despite the name, this returns the raw bytes of the single non-streamed
    payload, not a row count; callers compare it (or its length) with the
    bytes read from ``raw_stream``.
    """
    return client.raw_query(sql, fmt=fmt)
+_LARGE_QUERY = "SELECT number AS id FROM numbers(200000)" + + +@pytest.mark.parametrize("fmt", ["Native", "TabSeparated", "CSV", "RowBinary", "JSONEachRow"]) +def test_raw_stream_safe_format_full_data(client, fmt): + """Stream-safe formats: concatenated chunks must equal the single-query payload.""" + streamed = _stream_full_bytes(client, _LARGE_QUERY, fmt) + full = _row_count(client, _LARGE_QUERY, fmt) + assert len(streamed) == len(full), f"{fmt}: streamed {len(streamed)} != full {len(full)}" + + +@pytest.mark.parametrize( + "fmt", + [ + "Arrow", + "ArrowStream", + "Parquet", + "TabSeparatedWithNames", + "CSVWithNames", + "RowBinaryWithNamesAndTypes", + ], +) +def test_raw_stream_unsafe_format_falls_back_to_single_payload(client, fmt): + """Unsafe formats fall back to non-streaming: result must equal single-query bytes.""" + streamed = _stream_full_bytes(client, _LARGE_QUERY, fmt) + full = _row_count(client, _LARGE_QUERY, fmt) + assert streamed == full, f"{fmt}: bytes differ — streamed={len(streamed)} vs full={len(full)}" + + +def test_raw_stream_unsafe_format_json_yields_one_object(client): + """JSON includes per-run statistics, so check structural equality rather than bytes.""" + import json as _json + + streamed = _json.loads(_stream_full_bytes(client, _LARGE_QUERY, "JSON")) + full = _json.loads(_row_count(client, _LARGE_QUERY, "JSON")) + assert streamed["meta"] == full["meta"] + assert streamed["data"] == full["data"] + assert "statistics" in streamed and "statistics" in full + + +def test_arrow_stream_yields_all_record_batches(client): + """Regression: large Arrow stream must surface every RecordBatch, not just the first.""" + pa = pytest.importorskip("pyarrow") + stream = client.raw_stream(_LARGE_QUERY, fmt="ArrowStream") + try: + reader = pa.ipc.open_stream(stream) + batches = list(reader) + finally: + stream.close() + total_rows = sum(b.num_rows for b in batches) + assert total_rows == 200000, f"Lost rows in arrow stream: got {total_rows}" + + +def 
test_parquet_stream_is_single_file(client): + """Regression: Parquet output must be one valid file, not multiple concatenated.""" + pa = pytest.importorskip("pyarrow") + import pyarrow.parquet as pq + + stream = client.raw_stream(_LARGE_QUERY, fmt="Parquet") + try: + data = stream.read() + finally: + stream.close() + table = pq.read_table(pa.BufferReader(data)) + assert table.num_rows == 200000 + + +def test_jsoneachrow_stream_iterates_chunks(client): + """JSONEachRow stays on the streaming path (per-line format), verify chunked read.""" + stream = client.raw_stream(_LARGE_QUERY, fmt="JSONEachRow") + try: + first = stream.read(1024) + rest = stream.read() + finally: + stream.close() + # First chunk should start with valid JSON object + assert first.startswith(b'{"id":'), f"unexpected start: {first[:40]!r}" + # Total bytes equal the non-streaming version + assert len(first) + len(rest) == len(_row_count(client, _LARGE_QUERY, "JSONEachRow")) + + +# ---- error mapping ---- + + +def test_unknown_function_maps_to_database_error(client): + with pytest.raises(DatabaseError) as ex_info: + client.query("SELECT bad_function()") + assert "UNKNOWN_FUNCTION" in str(ex_info.value) or "bad_function" in str(ex_info.value) + + +def test_external_data_not_supported(client): + from clickhouse_connect.driver.external import ExternalData + + ext = ExternalData(file_name="x.csv", data=b"1\n2\n", fmt="CSV", structure="id UInt32") + with pytest.raises(NotSupportedError): + client.query("SELECT * FROM x", external_data=ext) + + +def test_mid_stream_exception_surfaces_as_stream_failure(client): + """Mid-stream chdb errors must be raised as StreamFailureError to match HTTP semantics.""" + from clickhouse_connect.driver.exceptions import StreamFailureError + + query = "SELECT throwIf(number = 100) FROM numbers(1000) SETTINGS max_block_size = 10" + with pytest.raises(StreamFailureError) as ex_info: + with client.query_row_block_stream(query) as stream: + for _ in stream: + pass + assert 
"throwIf" in str(ex_info.value) or "Code: 395" in str(ex_info.value) + + +# ---- HTTP-only kwargs accepted silently ---- + + +def test_http_only_kwargs_silently_ignored(): + c = clickhouse_connect.get_client( + interface="chdb", + username="default", + password="ignored", + compress=True, + connect_timeout=10, + verify=True, + http_proxy="http://localhost:3128", + ) + try: + assert c.ping() is True + finally: + c.close() + + +def test_set_access_token_silent_noop(client): + client.set_access_token("not-a-real-token") # must not raise + + +# ---- pyarrow / numpy round-trips ---- + + +def test_query_arrow(client): + pytest.importorskip("pyarrow") + client.command("CREATE TABLE arrow_q (id UInt32, name String) ENGINE = Memory") + client.insert("arrow_q", [[13, "user_1"], [79, "user_2"]], column_names=["id", "name"]) + table = client.query_arrow("SELECT id, name FROM arrow_q ORDER BY id") + assert table.column_names == ["id", "name"] + assert table.column("id").to_pylist() == [13, 79] + assert table.column("name").to_pylist() == ["user_1", "user_2"] + + +def test_query_arrow_stream(client): + pytest.importorskip("pyarrow") + client.command("CREATE TABLE arrow_qs (id UInt32) ENGINE = Memory") + client.insert("arrow_qs", [[i] for i in range(20)], column_names=["id"]) + with client.query_arrow_stream("SELECT id FROM arrow_qs SETTINGS max_block_size = 5") as stream: + batches = list(stream) + assert sum(b.num_rows for b in batches) == 20 + + +def test_insert_arrow_round_trip(client): + pa = pytest.importorskip("pyarrow") + client.command("CREATE TABLE arrow_ins (id UInt32, name String) ENGINE = Memory") + table = pa.table({"id": pa.array([13, 79], type=pa.uint32()), "name": pa.array(["user_1", "user_2"])}) + client.insert_arrow("arrow_ins", table) + r = client.query("SELECT id, name FROM arrow_ins ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_query_np(client): + pytest.importorskip("numpy") + client.command("CREATE TABLE np_q (id 
UInt32, v Float64) ENGINE = Memory") + client.insert("np_q", [[13, 1.5], [79, 2.5]], column_names=["id", "v"]) + arr = client.query_np("SELECT id, v FROM np_q ORDER BY id") + assert list(arr["id"]) == [13, 79] + assert list(arr["v"]) == [1.5, 2.5] + + +def test_query_np_stream(client): + pytest.importorskip("numpy") + client.command("CREATE TABLE np_qs (id UInt32) ENGINE = Memory") + client.insert("np_qs", [[i] for i in range(20)], column_names=["id"]) + with client.query_np_stream("SELECT id FROM np_qs SETTINGS max_block_size = 7") as stream: + chunks = list(stream) + assert sum(len(c) for c in chunks) == 20 + + +# ---- additional streaming flavors ---- + + +def test_query_column_block_stream(client): + client.command("CREATE TABLE col_stream (id UInt32, v String) ENGINE = Memory") + client.insert("col_stream", [[i, f"row_{i}"] for i in range(15)], column_names=["id", "v"]) + with client.query_column_block_stream("SELECT id, v FROM col_stream SETTINGS max_block_size = 5") as stream: + blocks = list(stream) + # Each block is column-oriented: a tuple of columns + total_rows = sum(len(block[0]) for block in blocks) + assert total_rows == 15 + + +def test_query_rows_stream(client): + client.command("CREATE TABLE rows_stream (id UInt32) ENGINE = Memory") + client.insert("rows_stream", [[i] for i in range(10)], column_names=["id"]) + with client.query_rows_stream("SELECT id FROM rows_stream ORDER BY id") as stream: + rows = list(stream) + assert [r[0] for r in rows] == list(range(10)) + + +# ---- insert variations ---- + + +def test_insert_column_oriented(client): + client.command("CREATE TABLE col_oriented (id UInt32, v Float64) ENGINE = Memory") + columns = [[13, 79, 103], [1.5, 2.5, 3.5]] + client.insert("col_oriented", columns, column_names=["id", "v"], column_oriented=True) + r = client.query("SELECT id, v FROM col_oriented ORDER BY id") + assert r.result_rows == [(13, 1.5), (79, 2.5), (103, 3.5)] + + +def test_reusable_insert_context(client): + 
client.command("CREATE TABLE reuse_ctx (id UInt32, name String) ENGINE = Memory") + ctx = client.create_insert_context("reuse_ctx", column_names=["id", "name"]) + client.insert(data=[[13, "first"]], context=ctx) + client.insert(data=[[79, "second"]], context=ctx) + r = client.query("SELECT id, name FROM reuse_ctx ORDER BY id") + assert r.result_rows == [(13, "first"), (79, "second")] + + +# ---- database parameter ---- + + +def test_database_parameter_switches_default(): + c = clickhouse_connect.get_client(interface="chdb") + try: + c.command("CREATE DATABASE other_db") + c.command("CREATE TABLE other_db.scoped (id UInt32) ENGINE = Memory") + c.command("INSERT INTO other_db.scoped VALUES (13)") + finally: + c.close() + # Note: chdb :memory: is per-connection, so this test only checks the USE + # mechanism — can't cross sessions on :memory:. Instead verify USE works inline: + c2 = clickhouse_connect.get_client(interface="chdb") + try: + c2.command("CREATE DATABASE scoped_test") + c2.command("USE scoped_test") + c2.command("CREATE TABLE local_t (id UInt32) ENGINE = Memory") + # unqualified reference should resolve into scoped_test + c2.command("INSERT INTO local_t VALUES (13)") + assert c2.query("SELECT count() FROM local_t").result_rows[0][0] == 1 + assert c2.query("SELECT count() FROM scoped_test.local_t").result_rows[0][0] == 1 + finally: + c2.close() + + +def test_database_param_forwarded_to_use(tmp_path): + db = str(tmp_path / "dbparam.db") + # First connection creates DB + table + a = clickhouse_connect.get_client(interface="chdb", chdb_path=db) + try: + a.command("CREATE DATABASE analytics") + a.command("CREATE TABLE analytics.events (id UInt32) ENGINE = MergeTree ORDER BY id") + a.command("INSERT INTO analytics.events VALUES (13)") + finally: + a.close() + # Second connection uses the database= kwarg; unqualified table reference must work + b = clickhouse_connect.get_client(interface="chdb", chdb_path=db, database="analytics") + try: + assert b.query("SELECT 
count() FROM events").result_rows[0][0] == 1 + finally: + b.close() + + +# ---- DBAPI on top of chdb ---- + + +def test_dbapi_cursor_round_trip(): + import clickhouse_connect.dbapi as dbapi + + conn = dbapi.connect(interface="chdb") + try: + cur = conn.cursor() + try: + cur.execute("CREATE TABLE dba_round_trip (id UInt32, name String) ENGINE = Memory") + cur.execute("INSERT INTO dba_round_trip VALUES (13, 'user_1'), (79, 'user_2')") + cur.execute("SELECT id, name FROM dba_round_trip ORDER BY id") + rows = cur.fetchall() + assert rows == [(13, "user_1"), (79, "user_2")] + assert [c[0] for c in cur.description] == ["id", "name"] + finally: + cur.close() + finally: + conn.close() + + +def test_dbapi_executemany(): + import clickhouse_connect.dbapi as dbapi + + conn = dbapi.connect(interface="chdb") + try: + cur = conn.cursor() + try: + cur.execute("CREATE TABLE dba_many (id UInt32, name String) ENGINE = Memory") + cur.executemany( + "INSERT INTO dba_many (id, name) VALUES", + [{"id": 13, "name": "user_1"}, {"id": 79, "name": "user_2"}, {"id": 103, "name": "user_3"}], + ) + cur.execute("SELECT id, name FROM dba_many ORDER BY id") + assert cur.fetchall() == [(13, "user_1"), (79, "user_2"), (103, "user_3")] + finally: + cur.close() + finally: + conn.close() + + +# ---- async client ---- + + +def test_async_client_basic_flow(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + assert await c.ping() is True + r = await c.query("SELECT 13 AS x") + assert r.result_rows == [(13,)] + await c.command("CREATE TABLE async_smoke (id UInt32) ENGINE = Memory") + await c.insert("async_smoke", [[13], [79]], column_names=["id"]) + r = await c.query("SELECT count() FROM async_smoke") + assert r.result_rows[0][0] == 2 + finally: + await c.close() + + asyncio.run(run()) + + +def test_async_client_gather_serializes_without_error(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + results = await 
asyncio.gather( + c.query("SELECT 13"), + c.query("SELECT 79"), + c.query("SELECT 103"), + ) + values = [r.result_rows[0][0] for r in results] + assert sorted(values) == [13, 79, 103] + finally: + await c.close() + + asyncio.run(run()) + + +def test_async_dataframe_insert(): + pd = pytest.importorskip("pandas") + + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + await c.command("CREATE TABLE async_df (id UInt32, v Float64) ENGINE = Memory") + df = pd.DataFrame({"id": [13, 79], "v": [1.5, 2.5]}) + await c.insert_df("async_df", df) + out = await c.query_df("SELECT id, v FROM async_df ORDER BY id") + assert list(out["id"]) == [13, 79] + assert list(out["v"]) == [1.5, 2.5] + finally: + await c.close() + + asyncio.run(run()) + + +# ---- factory / dispatch ---- + + +def test_factory_dispatches_on_interface(): + c = clickhouse_connect.get_client(interface="chdb") + try: + from clickhouse_connect.driver.chdbclient import ChdbClient + + assert isinstance(c, ChdbClient) + finally: + c.close() + + +# ---- pure helper unit tests (no chdb instance needed) ---- + + +def test_build_conn_string_default_memory(): + assert _build_conn_string("", None) == ":memory:" + assert _build_conn_string(None, None) == ":memory:" # type: ignore[arg-type] + + +def test_build_conn_string_path_unchanged_without_options(): + assert _build_conn_string("/data/db", None) == "/data/db" + assert _build_conn_string("file:/data/db?mode=ro", None) == "file:/data/db?mode=ro" + + +def test_build_conn_string_appends_options(): + assert _build_conn_string("/data/db", {"mode": "ro"}) == "/data/db?mode=ro" + + +def test_build_conn_string_merges_with_existing_query(): + result = _build_conn_string("file:/data/db?already=set", {"max_threads": 4}) + assert "already=set" in result and "max_threads=4" in result and "&" in result + + +def test_format_error_message_extracts_code_prefix(): + raw = "Some prefix\nCode: 46. DB::Exception: Function with name `bad` does not exist." 
def test_per_call_settings_appended_to_select(client):
    # Passes a per-call setting (`max_block_size`) together with a SELECT and
    # checks that the query still returns every row in order. The assertion only
    # covers the returned rows, so it shows that a query carrying per-call
    # settings is accepted; it does not prove the setting took effect.
    r = client.query("SELECT number FROM numbers(10)", settings={"max_block_size": 3})
    assert [row[0] for row in r.result_rows] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
test_raw_stream_readinto(client): + stream = client.raw_stream("SELECT number FROM numbers(3)", fmt="CSV") + try: + buf = bytearray(64) + n = stream.readinto(buf) + assert buf[:n] == b"0\n1\n2\n" + finally: + stream.close() + + +def test_stream_release_lock_on_close(client): + # If close() doesn't release the lock, the next query would deadlock. + stream = client.raw_stream("SELECT 1", fmt="CSV") + stream.close() + # Should return immediately, no deadlock: + assert client.query("SELECT 1").result_rows == [(1,)] + + +# ---- raw_insert input shapes ---- + + +def test_raw_insert_accepts_str(client): + client.command("CREATE TABLE raw_str (id UInt32, v String) ENGINE = Memory") + client.raw_insert("raw_str", insert_block="13,user_1\n79,user_2\n", fmt="CSV") + r = client.query("SELECT id, v FROM raw_str ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_raw_insert_accepts_file_like(client): + client.command("CREATE TABLE raw_file (id UInt32, v String) ENGINE = Memory") + buf = io.BytesIO(b"13,user_1\n79,user_2\n") + client.raw_insert("raw_file", insert_block=buf, fmt="CSV") + r = client.query("SELECT id, v FROM raw_file ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_raw_insert_accepts_generator(client): + client.command("CREATE TABLE raw_gen (id UInt32, v String) ENGINE = Memory") + + def chunks(): + yield b"13,user_1\n" + yield b"79,user_2\n" + + client.raw_insert("raw_gen", insert_block=chunks(), fmt="CSV") + r = client.query("SELECT id, v FROM raw_gen ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +@pytest.mark.parametrize("compression", ["lz4", "zstd", "gzip"]) +def test_raw_insert_decompresses_pre_compressed_payload(client, compression): + """raw_insert with `compression=` accepts compressed bytes and decompresses client-side.""" + import gzip + + import lz4.frame + import zstandard + + csv = b"13,user_1\n79,user_2\n" + encoded = { + "lz4": 
lz4.frame.compress(csv), + "zstd": zstandard.ZstdCompressor().compress(csv), + "gzip": gzip.compress(csv), + }[compression] + client.command(f"CREATE TABLE raw_compress_{compression} (id UInt32, v String) ENGINE = Memory") + client.raw_insert( + f"raw_compress_{compression}", + insert_block=encoded, + fmt="CSV", + compression=compression, + ) + r = client.query(f"SELECT id, v FROM raw_compress_{compression} ORDER BY id") + assert r.result_rows == [(13, "user_1"), (79, "user_2")] + + +def test_raw_insert_unsupported_compression_raises(client): + with pytest.raises(NotSupportedError): + client.raw_insert("t", insert_block=b"1\n", fmt="CSV", compression="snappy") + + +def test_raw_insert_missing_args(client): + with pytest.raises(ProgrammingError): + client.raw_insert(None, insert_block=b"x") # type: ignore[arg-type] + with pytest.raises(ProgrammingError): + client.raw_insert("t", insert_block=None) + + +def test_raw_insert_cleans_up_temp_file(client, monkeypatch): + """Verify the temp file is deleted even when chdb errors.""" + client.command("CREATE TABLE raw_cleanup (id UInt32) ENGINE = Memory") + seen_paths = [] + + import tempfile as _tempfile + + original = _tempfile.NamedTemporaryFile + + def tracking(*args, **kwargs): + f = original(*args, **kwargs) + seen_paths.append(f.name) + return f + + monkeypatch.setattr(_tempfile, "NamedTemporaryFile", tracking) + + # Bad CSV content for an UInt32 column will cause chdb to error. 
+ with pytest.raises(DatabaseError): + client.raw_insert("raw_cleanup", insert_block=b"not_a_number\n", fmt="CSV") + + assert seen_paths, "temp file path not captured" + for p in seen_paths: + assert not os.path.exists(p), f"temp file leaked: {p}" + + +# ---- additional types ---- + + +def test_query_tuple_and_fixed_string(client): + r = client.query("SELECT tuple(1, 'a', 3.14) AS t, toFixedString('xyz', 4) AS fs") + t, fs = r.result_rows[0] + assert t == (1, "a", 3.14) + assert fs == b"xyz\x00" + + +def test_query_uuid(client): + val = "550e8400-e29b-41d4-a716-446655440000" + r = client.query(f"SELECT toUUID('{val}') AS u") + assert r.result_rows == [(UUID(val),)] + + +def test_query_ipv4_ipv6(client): + r = client.query("SELECT toIPv4('127.0.0.1') AS v4, toIPv6('::1') AS v6") + v4, v6 = r.result_rows[0] + import ipaddress + + assert v4 == ipaddress.IPv4Address("127.0.0.1") + assert v6 == ipaddress.IPv6Address("::1") + + +def test_query_enum(client): + r = client.query("SELECT CAST('a' AS Enum8('a' = 1, 'b' = 2)) AS e") + assert r.result_rows == [("a",)] + + +def test_query_datetime64_with_tz(client): + r = client.query("SELECT toDateTime64('2026-05-19 10:30:00.123456', 6, 'America/New_York') AS dt") + (dt,) = r.result_rows[0] + assert dt.year == 2026 and dt.microsecond == 123456 + + +def test_query_nan_handling(client): + r = client.query("SELECT CAST('nan' AS Float64) AS x, CAST('-inf' AS Float64) AS y") + x, y = r.result_rows[0] + assert x != x # NaN + assert y == float("-inf") + + +# ---- parameter binding ---- + + +def test_query_with_parameters(client): + r = client.query("SELECT {x:Int32} AS x, {name:String} AS name", parameters={"x": 13, "name": "user_1"}) + assert r.result_rows == [(13, "user_1")] + + +def test_raw_query_with_embedded_binary_parameter(client): + """`$name$` placeholders inline raw bytes — chdb accepts bytes SQL, no decode.""" + binary_params = {"$xx$": b"col1,col2\n100,700"} + result = client.raw_query("SELECT col2, col1 FROM 
format(CSVWithNames, $xx$)", parameters=binary_params) + assert result == b"700\t100\n" + + +def test_raw_query_embedded_binary_with_non_utf8_bytes(client): + """Non-UTF-8 bytes (e.g. binary file content) embedded in SQL must round-trip.""" + payload = b"col1,col2\n100,\xff\x92" + result = client.raw_query("SELECT col2 FROM format(CSVWithNames, $xx$)", parameters={"$xx$": payload}) + # The non-UTF-8 byte sequence must come back intact in the output. + assert b"\xff" in result or b"\xc3\xbf" in result + + +# ---- transport-only settings don't get persisted ---- + + +def test_transport_only_setting_not_persisted_to_session(client): + # session_id is a transport-only key; ChdbClient should accept it but NOT emit + # SET session_id=... to chdb (which would either error or apply a meaningless setting). + before = _read_session_setting(client, "session_id") + client.set_client_setting("session_id", "abc-123") + after = _read_session_setting(client, "session_id") + assert after == before + # But the recorded client-side value is kept for inspection + assert client.get_client_setting("session_id") == "abc-123" + + +# ---- DataFrame stream ---- + + +def test_query_df_stream(client): + pytest.importorskip("pandas") + client.command("CREATE TABLE df_stream (id UInt32) ENGINE = Memory") + client.insert("df_stream", [[i] for i in range(20)], column_names=["id"]) + with client.query_df_stream("SELECT id FROM df_stream SETTINGS max_block_size = 5") as stream: + frames = list(stream) + total = sum(len(f) for f in frames) + assert total == 20 + + +# ---- async additional coverage ---- + + +def test_async_external_data_rejected(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + from clickhouse_connect.driver.external import ExternalData + + ext = ExternalData(file_name="x.csv", data=b"1\n", fmt="CSV", structure="id UInt32") + with pytest.raises(NotSupportedError): + await c.query("SELECT * FROM x", external_data=ext) + finally: + await 
c.close() + + asyncio.run(run()) + + +def test_async_query_error_propagates_as_database_error(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + with pytest.raises(DatabaseError): + await c.query("SELECT bad_function()") + finally: + await c.close() + + asyncio.run(run()) + + +def test_async_closed_client_query_raises(): + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + await c.close() + with pytest.raises(ProgrammingError): + await c.query("SELECT 1") + + asyncio.run(run()) + + +def test_async_set_client_setting_is_sync(client): + # Async client's set_client_setting is intentionally sync (no I/O wrap) for + # symmetry with HTTP AsyncClient. + async def run(): + c = await clickhouse_connect.get_async_client(interface="chdb") + try: + c.set_client_setting("max_block_size", 99) # NOT awaited + assert c.get_client_setting("max_block_size") == "99" + finally: + await c.close() + + asyncio.run(run())