diff --git a/pyproject.toml b/pyproject.toml
index 8bd5de2ad..038070c83 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"
 
 [tool.setuptools.packages.find]
 where = ["."]
-include = ["vectordb_bench", "vectordb_bench.cli"]
+include = ["vectordb_bench", "vectordb_bench.*"]
 
 [project]
 name = "vectordb-bench"
@@ -26,6 +26,7 @@ classifiers = [
 ]
 dependencies = [
     "click",
+    "pyyaml",
     "pytz",
     "streamlit>=1.47,<2", # 1.47 fixes streamlit#11660
     "tqdm",
diff --git a/vectordb_bench/backend/clients/oceanbase/cli.py b/vectordb_bench/backend/clients/oceanbase/cli.py
index 61583cc82..81ccacc79 100644
--- a/vectordb_bench/backend/clients/oceanbase/cli.py
+++ b/vectordb_bench/backend/clients/oceanbase/cli.py
@@ -31,9 +31,40 @@ class OceanBaseTypedDict(CommonTypedDict):
     ]
     database: Annotated[str, click.option("--database", type=str, help="DataBase name", required=True)]
     port: Annotated[int, click.option("--port", type=int, help="OceanBase port", required=True)]
+    create_index_parallel: Annotated[
+        int,
+        click.option(
+            "--create-index-parallel",
+            type=int,
+            default=16,
+            show_default=True,
+            help="PARALLEL hint degree for CREATE VECTOR INDEX",
+        ),
+    ]
+    partitions: Annotated[
+        int,
+        click.option(
+            "--partitions",
+            type=int,
+            default=0,
+            show_default=True,
+            help="Number of KEY partitions for the table. 0 or 1 means no partitioning.",
+        ),
+    ]
 
 
-class OceanBaseHNSWTypedDict(CommonTypedDict, OceanBaseTypedDict, HNSWFlavor4): ...
+class OceanBaseHNSWTypedDict(CommonTypedDict, OceanBaseTypedDict, HNSWFlavor4):
+    extra_info_max_size: Annotated[
+        int | None,
+        click.option(
+            "--extra-info-max-size",
+            type=int,
+            default=32,
+            show_default=True,
+            help="extra_info_max_size for HNSW index. Set to 0 to omit.",
+            required=False,
+        ),
+    ]
 
 
 @cli.command()
@@ -55,6 +86,9 @@ def OceanBaseHNSW(**parameters: Unpack[OceanBaseHNSWTypedDict]):
             m=parameters["m"],
             efConstruction=parameters["ef_construction"],
             ef_search=parameters["ef_search"],
+            extra_info_max_size=parameters["extra_info_max_size"] or None,
+            create_index_parallel=parameters["create_index_parallel"],
+            partitions=parameters["partitions"],
             index=parameters["index_type"],
         ),
         **parameters,
@@ -94,6 +128,8 @@ def OceanBaseIVF(**parameters: Unpack[OceanBaseIVFTypedDict]):
             nlist=parameters["nlist"],
             sample_per_nlist=parameters["sample_per_nlist"],
             nbits=parameters["nbits"],
+            create_index_parallel=parameters["create_index_parallel"],
+            partitions=parameters["partitions"],
             index=input_index_type,
             ivf_nprobes=parameters["ivf_nprobes"],
         ),
diff --git a/vectordb_bench/backend/clients/oceanbase/config.py b/vectordb_bench/backend/clients/oceanbase/config.py
index 1f37cfc75..384a7c263 100644
--- a/vectordb_bench/backend/clients/oceanbase/config.py
+++ b/vectordb_bench/backend/clients/oceanbase/config.py
@@ -36,20 +36,18 @@ class OceanBaseIndexConfig(BaseModel):
     index: IndexType
     metric_type: MetricType | None = None
     lib: str = "vsag"
+    create_index_parallel: int = 16
+    partitions: int = 0
 
     def parse_metric(self) -> str:
-        if self.metric_type == MetricType.L2 or (
-            self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE
-        ):
+        if self.metric_type == MetricType.L2:
             return "l2"
         if self.metric_type == MetricType.IP:
             return "inner_product"
         return "cosine"
 
     def parse_metric_func_str(self) -> str:
-        if self.metric_type == MetricType.L2 or (
-            self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE
-        ):
+        if self.metric_type == MetricType.L2:
             return "l2_distance"
         if self.metric_type == MetricType.IP:
             return "negative_inner_product"
@@ -60,6 +58,7 @@ class OceanBaseHNSWConfig(OceanBaseIndexConfig, DBCaseConfig):
     m: int
     efConstruction: int
     ef_search: int | None = None
+    extra_info_max_size: int | None = 32
     index: IndexType
 
     def index_param(self) -> dict:
diff --git a/vectordb_bench/backend/clients/oceanbase/oceanbase.py b/vectordb_bench/backend/clients/oceanbase/oceanbase.py
index bf615e4d0..3281b1e3d 100644
--- a/vectordb_bench/backend/clients/oceanbase/oceanbase.py
+++ b/vectordb_bench/backend/clients/oceanbase/oceanbase.py
@@ -23,6 +23,8 @@ class OceanBase(VectorDB):
         FilterOp.NumGE,
         FilterOp.StrEqual,
     ]
+    # insert path is GIL-bound; multi-threading cannot improve load throughput
+    thread_safe: bool = False
 
     def __init__(
         self,
@@ -58,6 +60,15 @@ def __init__(
         finally:
             self._disconnect()
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["_conn"] = None
+        state["_cursor"] = None
+        return state
+
+    def __setstate__(self, state: dict) -> None:
+        self.__dict__.update(state)
+
     def _connect(self):
         try:
             self._conn = mysql.connect(
@@ -109,27 +120,28 @@ def _create_table(self):
         if not self._cursor:
             raise ValueError("Cursor is not initialized")
 
-        log.info(f"Creating table {self.table_name}")
-        create_table_query = f"""
-        CREATE TABLE {self.table_name} (
-            id INT PRIMARY KEY,
-            embedding VECTOR({self.dim})
-        );
-        """
+        partitions = getattr(self.db_case_config, "partitions", 0)
+        log.info(f"Creating table {self.table_name} (partitions={partitions})")
+
+        create_table_query = f"CREATE TABLE {self.table_name} (id INT PRIMARY KEY, embedding VECTOR({self.dim}))"
+        if partitions > 1:
+            create_table_query += f" PARTITION BY KEY(id) PARTITIONS {partitions}"
+        create_table_query += ";"
         self._cursor.execute(create_table_query)
 
     def optimize(self, data_size: int):
         index_params = self.db_case_config.index_param()
         index_args = ", ".join(f"{k}={v}" for k, v in index_params["params"].items())
         index_query = (
-            f"CREATE /*+ PARALLEL(18) */ VECTOR INDEX idx1 "
+            f"CREATE /*+ PARALLEL({self.db_case_config.create_index_parallel}) */ VECTOR INDEX idx1 "
             f"ON {self.table_name}(embedding) "
             f"WITH (distance={self.db_case_config.parse_metric()}, "
             f"type={index_params['index_type']}, lib={index_params['lib']}, {index_args}"
         )
 
-        if self.db_case_config.index in {IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ}:
-            index_query += ", extra_info_max_size=32"
+        extra_info = getattr(self.db_case_config, "extra_info_max_size", None)
+        if extra_info is not None:
+            index_query += f", extra_info_max_size={extra_info}"
 
         index_query += ")"
 
@@ -153,10 +165,6 @@ def optimize(self, data_size: int):
             raise
 
     def need_normalize_cosine(self) -> bool:
-        if self.db_case_config.index == IndexType.HNSW_BQ:
-            log.info("current HNSW_BQ only supports L2, cosine dataset need normalize.")
-            return True
-
         return False
 
     def _wait_for_major_compaction(self):
@@ -185,9 +193,7 @@ def insert_embeddings(
                 batch_end = min(batch_start + self.load_batch_size, len(embeddings))
                 batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
                 values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
-                self._cursor.execute(
-                    f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"
-                )
+                self._cursor.execute(f"INSERT INTO {self.table_name} VALUES {values}")
                 insert_count += len(batch)
         except mysql.Error:
             log.exception("Failed to insert embeddings")
diff --git a/vectordb_bench/backend/runner/__init__.py b/vectordb_bench/backend/runner/__init__.py
index d56fe0ff8..4a7ccb0c5 100644
--- a/vectordb_bench/backend/runner/__init__.py
+++ b/vectordb_bench/backend/runner/__init__.py
@@ -1,11 +1,13 @@
 from .concurrent_runner import ConcurrentInsertRunner
 from .mp_runner import MultiProcessingSearchRunner
+from .multiprocess_load_runner import MultiprocessInsertRunner
 from .read_write_runner import ReadWriteRunner
 from .serial_runner import SerialInsertRunner, SerialSearchRunner
 
 __all__ = [
     "ConcurrentInsertRunner",
     "MultiProcessingSearchRunner",
+    "MultiprocessInsertRunner",
     "ReadWriteRunner",
     "SerialInsertRunner",
     "SerialSearchRunner",
diff --git a/vectordb_bench/backend/runner/concurrent_runner.py b/vectordb_bench/backend/runner/concurrent_runner.py
index 7c8aeb24f..c989a69da 100644
--- a/vectordb_bench/backend/runner/concurrent_runner.py
+++ b/vectordb_bench/backend/runner/concurrent_runner.py
@@ -74,7 +74,10 @@ def __init__(
 
         effective_workers = max_workers or min(mp.cpu_count(), 4)
         if not db.thread_safe:
-            log.info(f"DB {db.name} is not thread-safe, falling back to max_workers=1")
+            log.info(
+                f"DB {db.name} declared thread_safe=False, falling back to max_workers=1"
+                " (use --load-processes for parallel loading)"
+            )
             effective_workers = 1
         self.max_workers = effective_workers
         assert db.thread_safe or self.max_workers == 1, (
diff --git a/vectordb_bench/backend/runner/multiprocess_load_runner.py b/vectordb_bench/backend/runner/multiprocess_load_runner.py
new file mode 100644
index 000000000..bc1e7c84c
--- /dev/null
+++ b/vectordb_bench/backend/runner/multiprocess_load_runner.py
@@ -0,0 +1,303 @@
+"""Multi-process loader for the performance load stage."""
+
+from __future__ import annotations
+
+import contextlib
+import logging
+import multiprocessing as mp
+import time
+from queue import Full
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+
+from ... import config
+from ...models import PerformanceTimeoutError
+from ..filter import Filter, FilterOp, non_filter
+from ..utils import time_it
+
+if TYPE_CHECKING:
+    from multiprocessing.queues import Queue as MPQueue
+    from multiprocessing.sharedctypes import Synchronized
+
+    from ..clients import api
+    from ..dataset import DatasetManager
+
+
+log = logging.getLogger(__name__)
+
+
+# Sent from the producer to signal "no more work".
+_SENTINEL = None
+
+# Minimum number of seconds between two "Load progress" log lines.
+_PROGRESS_INTERVAL_SEC = 5
+
+
+def _insert_worker(
+    db: api.VectorDB,
+    task_queue: MPQueue,
+    error_queue: MPQueue,
+    counter: Synchronized,
+    normalize: bool,
+    worker_id: int,
+) -> None:
+    """Drain `task_queue` in a dedicated process.
+
+    Each item is a `(ids, emb_np, labels)` tuple. Workers share no state
+    beyond the queues + counter; a failure from any worker is surfaced via
+    `error_queue` and causes the runner to abort.
+    """
+    try:
+        with db.init():
+            while True:
+                item = task_queue.get()
+                if item is _SENTINEL:
+                    return
+
+                ids, emb_np, labels = item
+                if normalize:
+                    emb_np = emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]
+                embeddings = emb_np.tolist()
+
+                kwargs: dict[str, Any] = {}
+                if labels is not None:
+                    kwargs["labels_data"] = labels
+                inserted, err = db.insert_embeddings(
+                    embeddings=embeddings,
+                    metadata=ids,
+                    **kwargs,
+                )
+                if err is not None:
+                    raise err  # noqa: TRY301
+                with counter.get_lock():
+                    counter.value += inserted
+    except Exception as exc:
+        log.exception(f"[worker-{worker_id}] insert failed")
+        with contextlib.suppress(Exception):
+            error_queue.put(f"worker-{worker_id}: {exc!r}")
+
+
+class MultiprocessInsertRunner:
+    """Run insert_embeddings across a pool of worker processes."""
+
+    def __init__(
+        self,
+        db: api.VectorDB,
+        dataset: DatasetManager,
+        normalize: bool,
+        workers: int,
+        filters: Filter = non_filter,
+        timeout: float | None = None,
+        queue_size: int | None = None,
+    ):
+        self.db = db
+        self.dataset = dataset
+        self.normalize = normalize
+        self.filters = filters
+        self.timeout = timeout if isinstance(timeout, int | float) else None
+
+        if workers <= 0:
+            workers = mp.cpu_count()
+        self.workers = workers
+        # A couple slots of slack per worker lets the producer run a step ahead
+        # without unbounded memory growth.
+        self.queue_size = queue_size if queue_size and queue_size > 0 else max(workers * 2, 4)
+
+    def _pull_labels(self, data_df: Any, ids: list[int]) -> list | None:
+        """Extract per-row scalar labels for StrEqual filter, if enabled."""
+        if self.filters.type != FilterOp.StrEqual:
+            return None
+        if self.dataset.data.scalar_labels_file_separated:
+            return self.dataset.scalar_labels[self.filters.label_field][ids].to_list()
+        return data_df[self.filters.label_field].tolist()
+
+    @time_it
+    def _run(self) -> int:
+        """Feed every dataset batch to the workers and return the row count they reported.
+
+        Raises:
+            PerformanceTimeoutError: the load ran longer than `self.timeout`.
+            RuntimeError: a worker reported an error or exited abnormally.
+        """
+        ctx = mp.get_context("spawn")
+        task_queue: MPQueue = ctx.Queue(maxsize=self.queue_size)
+        error_queue: MPQueue = ctx.Queue()
+        counter: Synchronized = ctx.Value("q", 0)
+
+        log.info(
+            f"Multiprocess load start: workers={self.workers}, "
+            f"queue_size={self.queue_size}, batch_size={config.NUM_PER_BATCH}",
+        )
+
+        procs: list[mp.Process] = []
+        for i in range(self.workers):
+            p = ctx.Process(
+                target=_insert_worker,
+                args=(
+                    self.db,
+                    task_queue,
+                    error_queue,
+                    counter,
+                    self.normalize,
+                    i,
+                ),
+                name=f"vdb-load-{i}",
+                # daemon so workers die with the parent subprocess on SIGTERM/SIGKILL.
+                daemon=True,
+            )
+            p.start()
+            procs.append(p)
+
+        id_field = self.dataset.data.train_id_field
+        vec_field = self.dataset.data.train_vector_field
+
+        produced = 0
+        start = time.perf_counter()
+        deadline = start + self.timeout if self.timeout is not None else None
+        last_log = start
+        interrupted = False
+
+        try:
+            for data_df in self.dataset:
+                if self._abort_if_worker_died(error_queue, procs):
+                    break
+
+                ids = data_df[id_field].tolist()
+                emb_np = np.stack(data_df[vec_field])
+                labels = self._pull_labels(data_df, ids)
+                # Short timeout on put so a stuck queue doesn't block SIGTERM/Ctrl+C;
+                # the deadline keeps the load timeout enforced while the queue is full.
+                self._put_with_interruptible_wait(
+                    task_queue,
+                    (ids, emb_np, labels),
+                    procs,
+                    deadline=deadline,
+                )
+                produced += len(ids)
+
+                now = time.perf_counter()
+                if now - last_log >= _PROGRESS_INTERVAL_SEC:
+                    done = counter.value
+                    elapsed = now - start
+                    rate = done / elapsed if elapsed > 0 else 0
+                    log.info(
+                        f"Load progress: produced={produced} inserted={done} "
+                        f"rate={rate:.0f}/s elapsed={elapsed:.0f}s",
+                    )
+                    last_log = now
+
+                if self.timeout is not None and (now - start) > self.timeout:
+                    msg = f"Multiprocess load exceeded timeout {self.timeout}s"
+                    log.warning(msg)
+                    raise PerformanceTimeoutError(msg)
+        except KeyboardInterrupt:
+            log.warning("Multiprocess load interrupted by user")
+            interrupted = True
+            raise
+        finally:
+            self._shutdown(procs, task_queue, graceful=not interrupted)
+
+        if not error_queue.empty():
+            err_detail = error_queue.get_nowait()
+            msg = f"Multiprocess load failed: {err_detail}"
+            raise RuntimeError(msg)
+
+        # A worker that was killed, crashed before it could report, or had to
+        # be terminated while still draining leaves no message behind; without
+        # this check a partial load would be returned as a success.
+        crashed = [f"{p.name}(exitcode={p.exitcode})" for p in procs if p.exitcode not in (None, 0)]
+        if crashed:
+            msg = f"Multiprocess load failed: workers exited abnormally: {', '.join(crashed)}"
+            raise RuntimeError(msg)
+
+        inserted = counter.value
+        elapsed = time.perf_counter() - start
+        rate = inserted / elapsed if elapsed > 0 else 0
+        log.info(
+            f"Multiprocess load done: produced={produced} inserted={inserted} "
+            f"elapsed={elapsed:.2f}s rate={rate:.0f}/s",
+        )
+        return inserted
+
+    @staticmethod
+    def _put_with_interruptible_wait(
+        queue: MPQueue,
+        item: Any,
+        procs: list[mp.Process],
+        chunk_timeout: float = 1.0,
+        deadline: float | None = None,
+    ) -> None:
+        """`queue.put` that polls so a Ctrl+C or worker crash doesn't block it forever.
+
+        Args:
+            queue: bounded task queue shared with the workers.
+            item: payload to enqueue.
+            procs: worker processes; the wait is abandoned once none is alive.
+            chunk_timeout: seconds to block in each `put` attempt.
+            deadline: optional `time.perf_counter()` value after which the
+                wait is abandoned with `PerformanceTimeoutError`.
+
+        Raises:
+            RuntimeError: every worker exited while the item was still pending.
+            PerformanceTimeoutError: `deadline` passed before a slot became free.
+        """
+        while True:
+            try:
+                queue.put(item, timeout=chunk_timeout)
+            except Full:
+                # Only a full queue is retried; any other error propagates
+                # instead of spinning here forever.
+                if all(not p.is_alive() for p in procs):
+                    msg = "All workers exited while producer was waiting to enqueue"
+                    raise RuntimeError(msg) from None
+                if deadline is not None and time.perf_counter() > deadline:
+                    msg = "Multiprocess load exceeded its timeout while waiting to enqueue"
+                    log.warning(msg)
+                    raise PerformanceTimeoutError(msg) from None
+            else:
+                return
+
+    @staticmethod
+    def _abort_if_worker_died(error_queue: MPQueue, procs: list[mp.Process]) -> bool:
+        """Surface worker failures or unexpected exits early."""
+        if not error_queue.empty():
+            return True
+        return any(not p.is_alive() and p.exitcode not in (None, 0) for p in procs)
+
+    @staticmethod
+    def _shutdown(
+        procs: list[mp.Process],
+        task_queue: MPQueue,
+        graceful: bool,
+        graceful_join: float = 10.0,
+        hard_join: float = 3.0,
+    ) -> None:
+        """Wind down workers. On interrupt, skip draining and terminate fast."""
+        if graceful:
+            # Let workers finish what's already queued, then exit on sentinel.
+            for _ in procs:
+                try:
+                    task_queue.put(_SENTINEL, timeout=5)
+                except Exception:
+                    break
+            for p in procs:
+                p.join(timeout=graceful_join)
+
+        # Anyone still alive gets terminated — includes the interrupted path.
+        for p in procs:
+            if p.is_alive():
+                if graceful:
+                    log.warning(f"worker {p.name} did not exit cleanly; terminating")
+                with contextlib.suppress(Exception):
+                    p.terminate()
+        for p in procs:
+            p.join(timeout=hard_join)
+            if p.is_alive():
+                with contextlib.suppress(Exception):
+                    p.kill()
+                p.join(timeout=1)
+
+    def run(self) -> int:
+        count, _ = self._run()
+        return count
diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py
index 6b51d1277..4723415e8 100644
--- a/vectordb_bench/backend/task_runner.py
+++ b/vectordb_bench/backend/task_runner.py
@@ -17,6 +17,7 @@
 from .runner import (
     ConcurrentInsertRunner,
     MultiProcessingSearchRunner,
+    MultiprocessInsertRunner,
     ReadWriteRunner,
     SerialInsertRunner,
     SerialSearchRunner,
@@ -249,20 +250,47 @@ def _run_streaming_case(self) -> Metric:
     def _load_train_data(self):
         """Insert train data concurrently and get the insert_duration"""
         try:
-            runner = ConcurrentInsertRunner(
-                self.db,
-                self.ca.dataset,
-                self.normalize,
-                self.ca.filters,
-                self.ca.load_timeout,
-                max_workers=self.config.load_concurrency or None,
-            )
+            workers = self._pick_load_processes()
+            if workers > 0:
+                log.info(f"Using MultiprocessInsertRunner (workers={workers})")
+                runner = MultiprocessInsertRunner(
+                    db=self.db,
+                    dataset=self.ca.dataset,
+                    normalize=self.normalize,
+                    workers=workers,
+                    filters=self.ca.filters,
+                    timeout=self.ca.load_timeout,
+                )
+            else:
+                runner = ConcurrentInsertRunner(
+                    self.db,
+                    self.ca.dataset,
+                    self.normalize,
+                    self.ca.filters,
+                    self.ca.load_timeout,
+                    max_workers=self.config.load_concurrency or None,
+                )
             runner.run()
         except Exception as e:
             raise e from None
         finally:
             runner = None
 
+    def _pick_load_processes(self) -> int:
+        """Return worker count for multi-process loader, or 0 to use threaded loader."""
+        if self.config.load_processes > 0:
+            return self.config.load_processes
+
+        if not getattr(self.db, "thread_safe", True) and self.config.load_concurrency > 1:
+            log.info(
+                f"{self.db.name} declares thread_safe=False; auto-switching to "
+                f"multi-process loader with {self.config.load_concurrency} workers "
+                f"(set --load-processes to override)",
+            )
+            return self.config.load_concurrency
+
+        return 0
+
     def _serial_search(self) -> tuple[float, float, float, float]:
         """Performance serial tests, search the entire test data once,
         calculate the recall, serial_latency_p99, serial_latency_p95
diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py
index 94b13762a..16cbf8dd7 100644
--- a/vectordb_bench/cli/cli.py
+++ b/vectordb_bench/cli/cli.py
@@ -240,6 +240,23 @@ class CommonTypedDict(TypedDict):
             help="Number of concurrent workers for data loading in performance cases (0 = cpu_count)",
         ),
     ]
+    load_processes: Annotated[
+        int,
+        click.option(
+            "--load-processes",
+            type=int,
+            default=0,
+            show_default=True,
+            help=(
+                "Use an N-worker multi-process loader instead of the threaded "
+                "one (--load-concurrency). 0 disables it. Recommended for "
+                "SQL-based vector DBs (OceanBase / TiDB / pgvector / Doris / "
+                "VectorChord) where the insert path is GIL-bound. When a client "
+                "declares thread_safe=False, this is auto-enabled with "
+                "--load-concurrency workers even if this flag is 0"
+            ),
+        ),
+    ]
     search_serial: Annotated[
         bool,
         click.option(
@@ -653,6 +670,7 @@ def run(
             parameters["search_concurrent"],
         ),
         load_concurrency=parameters["load_concurrency"],
+        load_processes=parameters["load_processes"],
     )
 
     task_label = parameters["task_label"]
diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py
index a7e7c09f1..380ca5026 100644
--- a/vectordb_bench/models.py
+++ b/vectordb_bench/models.py
@@ -244,6 +244,7 @@ class TaskConfig(BaseModel):
     case_config: CaseConfig
     stages: list[TaskStage] = ALL_TASK_STAGES
     load_concurrency: int = config.LOAD_CONCURRENCY
+    load_processes: int = 0
 
     @property
     def db_name(self):