Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["."]
include = ["vectordb_bench", "vectordb_bench.cli"]
include = ["vectordb_bench", "vectordb_bench.*"]

[project]
name = "vectordb-bench"
Expand All @@ -26,6 +26,7 @@ classifiers = [
]
dependencies = [
"click",
"pyyaml",
"pytz",
"streamlit>=1.47,<2", # 1.47 fixes streamlit#11660
"tqdm",
Expand Down
38 changes: 37 additions & 1 deletion vectordb_bench/backend/clients/oceanbase/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,40 @@ class OceanBaseTypedDict(CommonTypedDict):
]
database: Annotated[str, click.option("--database", type=str, help="DataBase name", required=True)]
port: Annotated[int, click.option("--port", type=int, help="OceanBase port", required=True)]
# Index-build parallelism, exposed on the CLI as --create-index-parallel
# (integer, default 16, default shown in --help).
create_index_parallel: Annotated[
    int,
    click.option(
        "--create-index-parallel",
        help="PARALLEL hint degree for CREATE VECTOR INDEX",
        type=int,
        default=16,
        show_default=True,
    ),
]
# Table partition count, exposed on the CLI as --partitions (integer,
# default 0). Per the help text, 0 and 1 both mean "do not partition".
partitions: Annotated[
    int,
    click.option(
        "--partitions",
        help="Number of KEY partitions for the table. 0 or 1 means no partitioning.",
        type=int,
        default=0,
        show_default=True,
    ),
]


class OceanBaseHNSWTypedDict(CommonTypedDict, OceanBaseTypedDict, HNSWFlavor4): ...
class OceanBaseHNSWTypedDict(CommonTypedDict, OceanBaseTypedDict, HNSWFlavor4):
    """Typed CLI parameters accepted by the OceanBase HNSW command.

    Inherits every option declared by its three bases and adds the
    HNSW-only ``--extra-info-max-size`` option.
    """

    # Optional integer, default 32. The help text tells users to pass 0 when
    # the setting should be left out of the index definition.
    extra_info_max_size: Annotated[
        int | None,
        click.option(
            "--extra-info-max-size",
            help="extra_info_max_size for HNSW index. Set to 0 to omit.",
            type=int,
            default=32,
            show_default=True,
            required=False,
        ),
    ]


@cli.command()
Expand All @@ -55,6 +86,9 @@ def OceanBaseHNSW(**parameters: Unpack[OceanBaseHNSWTypedDict]):
m=parameters["m"],
efConstruction=parameters["ef_construction"],
ef_search=parameters["ef_search"],
extra_info_max_size=parameters["extra_info_max_size"] or None,
create_index_parallel=parameters["create_index_parallel"],
partitions=parameters["partitions"],
index=parameters["index_type"],
),
**parameters,
Expand Down Expand Up @@ -94,6 +128,8 @@ def OceanBaseIVF(**parameters: Unpack[OceanBaseIVFTypedDict]):
nlist=parameters["nlist"],
sample_per_nlist=parameters["sample_per_nlist"],
nbits=parameters["nbits"],
create_index_parallel=parameters["create_index_parallel"],
partitions=parameters["partitions"],
index=input_index_type,
ivf_nprobes=parameters["ivf_nprobes"],
),
Expand Down
11 changes: 5 additions & 6 deletions vectordb_bench/backend/clients/oceanbase/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,18 @@ class OceanBaseIndexConfig(BaseModel):
index: IndexType
metric_type: MetricType | None = None
lib: str = "vsag"
create_index_parallel: int = 16
partitions: int = 0

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2 or (
self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE
):
if self.metric_type == MetricType.L2:
return "l2"
if self.metric_type == MetricType.IP:
return "inner_product"
return "cosine"

def parse_metric_func_str(self) -> str:
if self.metric_type == MetricType.L2 or (
self.index == IndexType.HNSW_BQ and self.metric_type == MetricType.COSINE
):
if self.metric_type == MetricType.L2:
return "l2_distance"
if self.metric_type == MetricType.IP:
return "negative_inner_product"
Expand All @@ -60,6 +58,7 @@ class OceanBaseHNSWConfig(OceanBaseIndexConfig, DBCaseConfig):
m: int
efConstruction: int
ef_search: int | None = None
extra_info_max_size: int | None = 32
index: IndexType

def index_param(self) -> dict:
Expand Down
40 changes: 23 additions & 17 deletions vectordb_bench/backend/clients/oceanbase/oceanbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class OceanBase(VectorDB):
FilterOp.NumGE,
FilterOp.StrEqual,
]
# insert path is GIL-bound; multi-threading cannot improve load throughput
thread_safe: bool = False

def __init__(
self,
Expand Down Expand Up @@ -58,6 +60,15 @@ def __init__(
finally:
self._disconnect()

def __getstate__(self):
    """Return a picklable snapshot of the instance attributes.

    The snapshot is a shallow copy of the instance ``__dict__`` in which
    ``_conn`` and ``_cursor`` are replaced by ``None``; the live instance
    itself is left untouched.
    """
    # Later keys win in a dict display, so existing entries are overwritten
    # in place and missing ones are appended, just like copy-then-assign.
    return {**self.__dict__, "_conn": None, "_cursor": None}

def __setstate__(self, state: dict) -> None:
    """Restore instance attributes from a previously captured *state* mapping.

    Entries in *state* are merged into the instance ``__dict__``; attributes
    not mentioned in *state* are left as they are.
    """
    attributes = self.__dict__
    attributes.update(state)

def _connect(self):
try:
self._conn = mysql.connect(
Expand Down Expand Up @@ -109,27 +120,28 @@ def _create_table(self):
if not self._cursor:
raise ValueError("Cursor is not initialized")

log.info(f"Creating table {self.table_name}")
create_table_query = f"""
CREATE TABLE {self.table_name} (
id INT PRIMARY KEY,
embedding VECTOR({self.dim})
);
"""
partitions = getattr(self.db_case_config, "partitions", 0)
log.info(f"Creating table {self.table_name} (partitions={partitions})")

create_table_query = f"CREATE TABLE {self.table_name} (id INT PRIMARY KEY, embedding VECTOR({self.dim}))"
if partitions > 1:
create_table_query += f" PARTITION BY KEY(id) PARTITIONS {partitions}"
create_table_query += ";"
self._cursor.execute(create_table_query)

def optimize(self, data_size: int):
index_params = self.db_case_config.index_param()
index_args = ", ".join(f"{k}={v}" for k, v in index_params["params"].items())
index_query = (
f"CREATE /*+ PARALLEL(18) */ VECTOR INDEX idx1 "
f"CREATE /*+ PARALLEL({self.db_case_config.create_index_parallel}) */ VECTOR INDEX idx1 "
f"ON {self.table_name}(embedding) "
f"WITH (distance={self.db_case_config.parse_metric()}, "
f"type={index_params['index_type']}, lib={index_params['lib']}, {index_args}"
)

if self.db_case_config.index in {IndexType.HNSW, IndexType.HNSW_SQ, IndexType.HNSW_BQ}:
index_query += ", extra_info_max_size=32"
extra_info = getattr(self.db_case_config, "extra_info_max_size", None)
if extra_info is not None:
index_query += f", extra_info_max_size={extra_info}"

index_query += ")"

Expand All @@ -153,10 +165,6 @@ def optimize(self, data_size: int):
raise

def need_normalize_cosine(self) -> bool:
if self.db_case_config.index == IndexType.HNSW_BQ:
log.info("current HNSW_BQ only supports L2, cosine dataset need normalize.")
return True

return False

def _wait_for_major_compaction(self):
Expand Down Expand Up @@ -185,9 +193,7 @@ def insert_embeddings(
batch_end = min(batch_start + self.load_batch_size, len(embeddings))
batch = [(metadata[i], embeddings[i]) for i in range(batch_start, batch_end)]
values = ", ".join(f"({item_id}, '[{','.join(map(str, embedding))}]')" for item_id, embedding in batch)
self._cursor.execute(
f"INSERT /*+ ENABLE_PARALLEL_DML PARALLEL(32) */ INTO {self.table_name} VALUES {values}"
)
self._cursor.execute(f"INSERT INTO {self.table_name} VALUES {values}")
insert_count += len(batch)
except mysql.Error:
log.exception("Failed to insert embeddings")
Expand Down
2 changes: 2 additions & 0 deletions vectordb_bench/backend/runner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .concurrent_runner import ConcurrentInsertRunner
from .mp_runner import MultiProcessingSearchRunner
from .multiprocess_load_runner import MultiprocessInsertRunner
from .read_write_runner import ReadWriteRunner
from .serial_runner import SerialInsertRunner, SerialSearchRunner

__all__ = [
"ConcurrentInsertRunner",
"MultiProcessingSearchRunner",
"MultiprocessInsertRunner",
"ReadWriteRunner",
"SerialInsertRunner",
"SerialSearchRunner",
Expand Down
5 changes: 4 additions & 1 deletion vectordb_bench/backend/runner/concurrent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def __init__(

effective_workers = max_workers or min(mp.cpu_count(), 4)
if not db.thread_safe:
log.info(f"DB {db.name} is not thread-safe, falling back to max_workers=1")
log.info(
f"DB {db.name} declared thread_safe=False, falling back to max_workers=1"
" (use --load-processes for parallel loading)"
)
effective_workers = 1
self.max_workers = effective_workers
assert db.thread_safe or self.max_workers == 1, (
Expand Down
Loading
Loading