diff --git a/.gitignore b/.gitignore index ad08cdc..8c1d2ee 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,4 @@ build/ # Internal docs (local only) docs/internal/ - -# 論文資料(ローカルのみ) -mem/ +.strata \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..5de059f --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,31 @@ +# codeatrium — Agent Usage Guide + +`codeatrium` is a CLI-first memory layer for AI coding agents. The command is `loci`. It lets agents search past conversations, retrieve code locations (file + line + symbol), and link conversation history to code symbols. + +Primary user is **the agent itself**, not a human. The tool is invoked via `loci search "..." --json` from within agent prompts. + +## When to use + +- When asked "where did we implement X?" or "where is X?" +- When checking if a similar bug was fixed before +- When verifying if a feature already exists +- When looking up the reasoning behind a past design decision +- Before editing code you lack context about — use `loci context --symbol` to review past discussions +- Before refactoring or changing the behavior of a function — use `loci context --symbol` to check past design decisions +- When recalling work done on a specific branch — use `loci context --branch` to find past conversations + +## CLI Commands + +```bash +loci init # Initialize .codeatrium/ in project root +loci index # Index new .jsonl files +loci distill [--limit N] # Distill queued exchanges via claude --print +loci search "query" --json --limit 5 # Semantic search (agent-facing) +loci search "query" --branch NAME --json # Branch-filtered semantic search +loci context --symbol "Foo.bar" --json # Reverse lookup: code -> past conversations (lightweight; use loci show for full text) +loci context --branch NAME --json # Branch reverse lookup (undistilled exchanges included) +loci show "~/.claude/.../abc.jsonl:ply=42" # Fetch verbatim exchange +loci status # Show index state +loci server start / stop / status # Embedding server management +loci hook install # Register hooks to ~/.claude/settings.json +``` diff --git a/CLAUDE.md b/CLAUDE.md index 2134791..34f87c0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -101,7 +101,9 @@ loci init # Initialize .codeatrium/ in projec loci index # Index new .jsonl files loci distill [--limit N] # Distill queued exchanges via claude --print loci search "query" --json --limit 5 # Semantic search (agent-facing) +loci search "query" --branch NAME --json # Branch-filtered semantic search loci context --symbol "Foo.bar" --json # Reverse lookup: code -> past conversations +loci context --branch NAME --json # Branch reverse lookup (undistilled exchanges included) loci show "~/.claude/.../abc.jsonl:ply=42" # Fetch verbatim exchange loci status # Show index state loci server start / stop / status # Embedding server management diff --git a/src/codeatrium/cli/prime_cmd.py b/src/codeatrium/cli/prime_cmd.py index c21bfde..52837c6 100644 --- a/src/codeatrium/cli/prime_cmd.py +++ b/src/codeatrium/cli/prime_cmd.py @@ -19,6 +19,7 @@ - **Before editing or refactoring a function** — recall past design decisions and known constraints for that symbol. - **Before starting a new implementation** — check if similar work was done before; reuse decisions and avoid re-debating settled choices. - **When you encounter a known or recurring error** — search for past fixes; the solution may already be documented. +- **When asked about work on a specific branch** — recall what was done and discussed on that branch. ### Search — semantic query over past conversations @@ -30,13 +31,16 @@ loci show "" --json ``` -### Context — reverse lookup from code symbol to past conversations +### Context — reverse lookup from code symbol or git branch to past conversations Touching a symbol = recalling memory about that symbol. Before changing any function or class, look up what was decided about it. ```bash # Retrieve all past conversations that involved this symbol loci context --symbol "SymbolResolver.extract" --json + +# Retrieve past conversations from work on a specific branch +loci context --branch "feature/foo" --json ```\ """ diff --git a/src/codeatrium/cli/search_cmd.py b/src/codeatrium/cli/search_cmd.py index c4b3bbf..0c27807 100644 --- a/src/codeatrium/cli/search_cmd.py +++ b/src/codeatrium/cli/search_cmd.py @@ -12,6 +12,7 @@ def search( query: Annotated[str, typer.Argument(help="検索クエリ")], limit: Annotated[int, typer.Option("--limit", "-n", help="返す件数")] = 5, json_output: Annotated[bool, typer.Option("--json", help="JSON で出力")] = False, + branch: Annotated[str | None, typer.Option("--branch", "-b", help="ブランチ名で絞り込む(部分一致)")] = None, ) -> None: """BM25(V) + HNSW(D) RRF でクエリに近い過去会話を返す""" from codeatrium.embedder import Embedder @@ -32,7 +33,7 @@ def search( embedder = Embedder() query_vec = embedder.embed(query) - results = search_combined(db, query, query_vec, limit=limit) + results = search_combined(db, query, query_vec, limit=limit, branch=branch) if not results: typer.echo("No results found.") @@ -46,6 +47,7 @@ def search( "rooms": r.rooms, "symbols": r.symbols, "verbatim_ref": r.verbatim_ref, + "git_branch": r.git_branch, } for r in results ] @@ -62,14 +64,17 @@ def search( def context( - symbol: Annotated[ - str, typer.Option("--symbol", "-s", help="シンボル名(部分一致)") - ], + symbol: Annotated[str | None, typer.Option("--symbol", "-s", help="シンボル名(部分一致)")] = None, limit: Annotated[int, typer.Option("--limit", "-n", help="返す件数")] = 5, json_output: Annotated[bool, typer.Option("--json", help="JSON で出力")] = False, full: Annotated[bool, typer.Option("--full", help="全文(user_content / agent_content)を含める")] = False, + branch: Annotated[str | None, typer.Option("--branch", "-b", help="ブランチ名で絞り込む(部分一致)")] = None, ) -> None: """シンボル名から関連する過去会話を逆引きする""" + if symbol is None and branch is None: + typer.echo("Error: --symbol or --branch is required.", err=True) + raise typer.Exit(1) + from codeatrium.db import get_connection from codeatrium.paths import db_path, find_project_root @@ -81,30 +86,83 @@ def context( raise typer.Exit(1) con = get_connection(db) - rows = con.execute( - """ - SELECT - s.symbol_name, - s.symbol_kind, - s.file_path, - s.signature, - s.line, - e.id AS exchange_id, - e.user_content, - e.agent_content, - p.exchange_core, - p.specific_context, - c.source_path, - e.ply_start - FROM symbols s - JOIN palace_objects p ON p.id = s.palace_object_id - JOIN exchanges e ON e.id = p.exchange_id - JOIN conversations c ON c.id = e.conversation_id - WHERE s.symbol_name LIKE ? - LIMIT ? - """, - (f"%{symbol}%", limit), - ).fetchall() + + if symbol is not None and branch is not None: + # Both symbol and branch specified + rows = con.execute( + """ + SELECT + s.symbol_name, + s.symbol_kind, + s.file_path, + s.signature, + s.line, + e.id AS exchange_id, + e.user_content, + e.agent_content, + p.exchange_core, + p.specific_context, + c.source_path, + e.ply_start, + e.git_branch + FROM symbols s + JOIN palace_objects p ON p.id = s.palace_object_id + JOIN exchanges e ON e.id = p.exchange_id + JOIN conversations c ON c.id = e.conversation_id + WHERE s.symbol_name LIKE ? AND e.git_branch LIKE ? + LIMIT ? + """, + (f"%{symbol}%", f"%{branch}%", limit), + ).fetchall() + elif symbol is not None: + # Symbol only (existing behavior with git_branch added) + rows = con.execute( + """ + SELECT + s.symbol_name, + s.symbol_kind, + s.file_path, + s.signature, + s.line, + e.id AS exchange_id, + e.user_content, + e.agent_content, + p.exchange_core, + p.specific_context, + c.source_path, + e.ply_start, + e.git_branch + FROM symbols s + JOIN palace_objects p ON p.id = s.palace_object_id + JOIN exchanges e ON e.id = p.exchange_id + JOIN conversations c ON c.id = e.conversation_id + WHERE s.symbol_name LIKE ? + LIMIT ? + """, + (f"%{symbol}%", limit), + ).fetchall() + else: + # Branch only (LEFT JOIN to include undistilled exchanges) + rows = con.execute( + """ + SELECT + e.id AS exchange_id, + e.git_branch, + e.user_content, + e.agent_content, + p.exchange_core, + p.specific_context, + c.source_path, + e.ply_start + FROM exchanges e + JOIN conversations c ON c.id = e.conversation_id + LEFT JOIN palace_objects p ON p.exchange_id = e.id + WHERE e.git_branch LIKE ? + ORDER BY c.started_at, e.ply_start + LIMIT ? + """, + (f"%{branch}%", limit), + ).fetchall() con.close() if not rows: @@ -114,27 +172,50 @@ def context( if json_output: output = [] for r in rows: - base = { - "symbol_name": r["symbol_name"], - "symbol_kind": r["symbol_kind"], - "file_path": r["file_path"], - "signature": r["signature"], - "line": r["line"], - "exchange_id": r["exchange_id"], - "exchange_core": r["exchange_core"], - "specific_context": r["specific_context"], - "verbatim_ref": f"{r['source_path']}:ply={r['ply_start']}", - } - if full: - base["user_content"] = r["user_content"] - base["agent_content"] = r["agent_content"] + if symbol is not None: + # Symbol mode (symbol only or both) + base = { + "symbol_name": r["symbol_name"], + "symbol_kind": r["symbol_kind"], + "file_path": r["file_path"], + "signature": r["signature"], + "line": r["line"], + "exchange_id": r["exchange_id"], + "exchange_core": r["exchange_core"], + "specific_context": r["specific_context"], + "verbatim_ref": f"{r['source_path']}:ply={r['ply_start']}", + "git_branch": r["git_branch"] if "git_branch" in r.keys() else None, + } + if full: + base["user_content"] = r["user_content"] + base["agent_content"] = r["agent_content"] + else: + # Branch-only mode + base = { + "exchange_id": r["exchange_id"], + "git_branch": r["git_branch"], + "exchange_core": r["exchange_core"], + "specific_context": r["specific_context"], + "verbatim_ref": f"{r['source_path']}:ply={r['ply_start']}", + } + if full: + base["user_content"] = r["user_content"] + base["agent_content"] = r["agent_content"] output.append(base) typer.echo(json.dumps(output, ensure_ascii=False, indent=2)) else: for i, r in enumerate(rows, 1): - typer.echo(f"\n[{i}] {r['symbol_kind']} {r['symbol_name']}") - typer.echo(f" {r['file_path']}:{r['line']}") - typer.echo(f" {r['signature']}") - if r["exchange_core"]: - typer.echo(f" Core: {r['exchange_core']}") - typer.echo(f" {r['source_path']}:ply={r['ply_start']}") + if symbol is not None: + # Symbol mode display + typer.echo(f"\n[{i}] {r['symbol_kind']} {r['symbol_name']}") + typer.echo(f" {r['file_path']}:{r['line']}") + typer.echo(f" {r['signature']}") + if r["exchange_core"]: + typer.echo(f" Core: {r['exchange_core']}") + typer.echo(f" {r['source_path']}:ply={r['ply_start']}") + else: + # Branch-only mode display + typer.echo(f"\n[{i}] exchange_id={r['exchange_id']} git_branch={r['git_branch']}") + if r["exchange_core"]: + typer.echo(f" Core: {r['exchange_core']}") + typer.echo(f" {r['source_path']}:ply={r['ply_start']}") diff --git a/src/codeatrium/db.py b/src/codeatrium/db.py index a9ca8b6..a707a5f 100644 --- a/src/codeatrium/db.py +++ b/src/codeatrium/db.py @@ -161,6 +161,84 @@ def _migrate_v7_repair_distill(con: sqlite3.Connection) -> None: ) +def _migrate_v8_add_git_branch(con: sqlite3.Connection) -> None: + """Migration v8: exchanges に git_branch カラムを追加し既存 exchange を jsonl 再パースでバックフィルする""" + import json + + # Guard: check if exchanges table exists + exchanges_table = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='exchanges'").fetchone() + if exchanges_table is None: + return + + # CHECK if git_branch column already exists + columns = con.execute("PRAGMA table_info(exchanges)").fetchall() + column_names = [col[1] for col in columns] + + if "git_branch" not in column_names: + con.execute("ALTER TABLE exchanges ADD COLUMN git_branch TEXT") + + # Nested function to extract git_branch from ply targets in a jsonl file + def _extract_git_branch_from_ply(jsonl_path_str: str, ply_targets: list[int]) -> dict[int, str | None]: + """ + Open jsonl file, iterate lines counting only successful json.loads. + For each ply index in ply_targets, look for entry.get('gitBranch'). + Return mapping ply_index -> gitBranch (None if missing or empty string). + """ + result: dict[int, str | None] = {} + try: + with open(jsonl_path_str, encoding='utf-8') as f: + ply_index = 0 + for line in f: + try: + entry = json.loads(line) + if ply_index in ply_targets: + git_branch_raw = entry.get('gitBranch', '') + git_branch = git_branch_raw if isinstance(git_branch_raw, str) and git_branch_raw.strip() else None + result[ply_index] = git_branch + ply_index += 1 + except json.JSONDecodeError: + # Skip malformed lines without incrementing ply_index + pass + except Exception: + # If file cannot be read, silently return empty mapping + pass + + return result + + # QUERY existing exchanges with conversation info + exchanges_exist = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='exchanges'").fetchone() + conversations_exist = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='conversations'").fetchone() + + if exchanges_exist is None or conversations_exist is None: + return + + rows = con.execute( + "SELECT e.id, e.conversation_id, e.ply_start, c.source_path FROM exchanges e JOIN conversations c ON c.id = e.conversation_id" + ).fetchall() + + # GROUP by source_path and backfill + by_path: dict[str, list[tuple[str, int]]] = {} + for ex_id, conv_id, ply_start, source_path in rows: + if source_path not in by_path: + by_path[source_path] = [] + by_path[source_path].append((ex_id, ply_start)) + + for source_path, exchanges_for_path in by_path.items(): + ply_targets = [ply_start for _, ply_start in exchanges_for_path] + branch_map = _extract_git_branch_from_ply(source_path, ply_targets) + + for ex_id, ply_start in exchanges_for_path: + try: + branch_value = branch_map.get(ply_start) + con.execute( + "UPDATE exchanges SET git_branch = ? WHERE id = ?", + (branch_value, ex_id), + ) + except Exception: + # Silently continue on any exception so migration never aborts + pass + + _MIGRATIONS: list[Callable[[sqlite3.Connection], None]] = [ _migrate_v1_add_last_ply_end, _migrate_v2_add_distill_status, @@ -169,6 +247,7 @@ def _migrate_v7_repair_distill(con: sqlite3.Connection) -> None: _migrate_v5_add_exchange_files, _migrate_v6_recompute_symbol_ids, _migrate_v7_repair_distill, + _migrate_v8_add_git_branch, ] @@ -235,7 +314,8 @@ def init_db(db_path: Path) -> None: user_content TEXT NOT NULL, agent_content TEXT NOT NULL, distilled_at TIMESTAMP, -- NULL = 未蒸留 - distill_status TEXT NOT NULL DEFAULT 'pending' + distill_status TEXT NOT NULL DEFAULT 'pending', + git_branch TEXT ); CREATE VIRTUAL TABLE IF NOT EXISTS exchanges_fts USING fts5( diff --git a/src/codeatrium/indexer.py b/src/codeatrium/indexer.py index 75f3d30..70fbba2 100644 --- a/src/codeatrium/indexer.py +++ b/src/codeatrium/indexer.py @@ -32,6 +32,7 @@ class Exchange: user_content: str agent_content: str files: list[str] = field(default_factory=list) + git_branch: str | None = None # ---- 内部ヘルパー ---- @@ -255,6 +256,8 @@ def parse_exchanges(jsonl_path: Path, min_chars: int = 50, last_ply_end: int = - continue user_uuid = user_entry.get("uuid", f"{start}") + git_branch_raw = user_entry.get("gitBranch", "") + git_branch = git_branch_raw if isinstance(git_branch_raw, str) and git_branch_raw.strip() else None exchange_id = sha256(f"{conversation_id}:{user_uuid}") # tool_use から file パスを抽出 @@ -269,6 +272,7 @@ def parse_exchanges(jsonl_path: Path, min_chars: int = 50, last_ply_end: int = - user_content=user_text, agent_content=agent_text, files=tool_files, + git_branch=git_branch, ) ) @@ -317,8 +321,8 @@ def index_file(jsonl_path: Path, db_path: Path, min_chars: int = 50) -> int: con.execute( """ INSERT OR IGNORE INTO exchanges - (id, conversation_id, ply_start, ply_end, user_content, agent_content) - VALUES (?, ?, ?, ?, ?, ?) + (id, conversation_id, ply_start, ply_end, user_content, agent_content, git_branch) + VALUES (?, ?, ?, ?, ?, ?, ?) """, ( ex.id, @@ -327,6 +331,7 @@ def index_file(jsonl_path: Path, db_path: Path, min_chars: int = 50) -> int: ex.ply_end, ex.user_content, ex.agent_content, + ex.git_branch, ), ) diff --git a/src/codeatrium/models.py b/src/codeatrium/models.py index f3dc7e0..d75ca0d 100644 --- a/src/codeatrium/models.py +++ b/src/codeatrium/models.py @@ -51,3 +51,4 @@ class FusedResult: verbatim_ref: str | None = None rooms: list[dict[str, Any]] = field(default_factory=list) symbols: list[dict[str, Any]] = field(default_factory=list) + git_branch: str | None = None diff --git a/src/codeatrium/search.py b/src/codeatrium/search.py index 51d23ff..6df0f4c 100644 --- a/src/codeatrium/search.py +++ b/src/codeatrium/search.py @@ -97,10 +97,17 @@ def _enrich_results(con: sqlite3.Connection, results: list[FusedResult]) -> None } ) + branch_rows = con.execute( + f'SELECT e.id, e.git_branch FROM exchanges e WHERE e.id IN ({placeholders})', + exchange_ids, + ).fetchall() + branch_map = {r['id']: r['git_branch'] for r in branch_rows} + for r in results: r.verbatim_ref = ref_map.get(r.exchange_id) r.rooms = rooms_map.get(r.exchange_id, []) r.symbols = symbols_map.get(r.exchange_id, []) + r.git_branch = branch_map.get(r.exchange_id) # ---- BM25 verbatim ---- @@ -114,14 +121,16 @@ def _fts5_query(text: str) -> str: def search_bm25( - db_path: Path, query_text: str, limit: int = 10, min_exchanges: int = 2 + db_path: Path, query_text: str, limit: int = 10, min_exchanges: int = 2, branch: str | None = None ) -> list[BM25Result]: """FTS5 BM25 で exchanges_fts を検索する""" fts_query = _fts5_query(query_text) + branch_clause = 'AND e.git_branch LIKE ?' if branch is not None else '' + branch_params: list = [f'%{branch}%'] if branch is not None else [] with closing(get_connection(db_path)) as con: try: rows = con.execute( - """ + f""" SELECT e.id AS exchange_id, e.user_content, @@ -132,10 +141,11 @@ def search_bm25( WHERE exchanges_fts MATCH ? AND (SELECT COUNT(*) FROM exchanges e2 WHERE e2.conversation_id = e.conversation_id) >= ? + {branch_clause} ORDER BY score DESC LIMIT ? """, - (fts_query, min_exchanges, limit), + (fts_query, min_exchanges, *branch_params, limit), ).fetchall() except sqlite3.OperationalError: rows = [] @@ -154,15 +164,17 @@ def search_bm25( def search_hnsw_palace( - db_path: Path, query_vec: np.ndarray, limit: int = 10, min_exchanges: int = 2 + db_path: Path, query_vec: np.ndarray, limit: int = 10, min_exchanges: int = 2, branch: str | None = None ) -> list[HNSWPalaceResult]: """sqlite-vec HNSW で vec_palace を検索する(distilled embedding)""" + branch_clause = 'AND e.git_branch LIKE ?' if branch is not None else '' + branch_params: list = [f'%{branch}%'] if branch is not None else [] with closing(get_connection(db_path)) as con: blob = _serialize(query_vec) try: rows = con.execute( - """ + f""" SELECT p.exchange_id, e.user_content, @@ -180,9 +192,10 @@ def search_hnsw_palace( JOIN exchanges e ON e.id = p.exchange_id WHERE (SELECT COUNT(*) FROM exchanges e2 WHERE e2.conversation_id = e.conversation_id) >= ? + {branch_clause} ORDER BY v.distance """, - (blob, limit, min_exchanges), + (blob, limit, min_exchanges, *branch_params), ).fetchall() except sqlite3.OperationalError: rows = [] @@ -254,13 +267,14 @@ def search_combined( query_vec: np.ndarray, limit: int = 5, min_exchanges: int = 2, + branch: str | None = None, ) -> list[FusedResult]: """BM25(V) + HNSW(D) の RRF 融合検索。""" bm25_results = search_bm25( - db_path, query_text, limit=limit * 2, min_exchanges=min_exchanges + db_path, query_text, limit=limit * 2, min_exchanges=min_exchanges, branch=branch ) hnsw_results = search_hnsw_palace( - db_path, query_vec, limit=limit * 2, min_exchanges=min_exchanges + db_path, query_vec, limit=limit * 2, min_exchanges=min_exchanges, branch=branch ) fused = rrf(bm25_results, hnsw_results, limit=limit) diff --git a/tests/test_db.py b/tests/test_db.py index f153b68..d5332ea 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -937,3 +937,431 @@ def test_migration_v7_removes_bm25_text_column(tmp_path: Path) -> None: ).fetchone() assert row is not None con.close() + + +def test_migration_v8_adds_git_branch_column(tmp_path: Path) -> None: + """Test that v8 migration adds git_branch column to exchanges table.""" + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with user_version=7 (post-v7, pre-v8) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + "CREATE TABLE exchange_files (exchange_id TEXT, file_path TEXT, PRIMARY KEY(exchange_id, file_path))" + ) + + raw_con.execute("PRAGMA user_version = 7") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v8 migration + init_db(db_path) + + # Verify git_branch column exists + con = sqlite3.connect(db_path) + table_info = con.execute("PRAGMA table_info(exchanges)").fetchall() + column_names = [col[1] for col in table_info] + assert "git_branch" in column_names + con.close() + + +def test_migration_v8_backfills_git_branch(tmp_path: Path) -> None: + """Test that v8 migration backfills git_branch from jsonl file.""" + import json + + db_path = tmp_path / "memory.db" + jsonl_path = tmp_path / "session.jsonl" + + # Create a jsonl file with a user entry containing gitBranch + user_entry = { + "uuid": "user1", + "gitBranch": "main", + "type": "user", + "content": "This is a long user content string that should be stored in the database", + } + agent_entry = { + "type": "assistant", + "content": "This is a long agent response string that should be stored in the database", + } + jsonl_path.write_text(json.dumps(user_entry) + "\n" + json.dumps(agent_entry) + "\n") + + # Create a v7 DB with a conversation pointing to the jsonl file + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute( + "INSERT INTO conversations(id, source_path) VALUES ('conv1', ?)", + (str(jsonl_path),), + ) + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + # Insert an exchange at ply_start=0 (matching the user entry in jsonl) + raw_con.execute( + "INSERT INTO exchanges VALUES ('ex1', 'conv1', 0, 1, 'user', 'agent', NULL, 'pending')" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + "CREATE TABLE exchange_files (exchange_id TEXT, file_path TEXT, PRIMARY KEY(exchange_id, file_path))" + ) + + raw_con.execute("PRAGMA user_version = 7") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v8 migration and backfill git_branch + init_db(db_path) + + # Verify git_branch was backfilled to 'main' + con = sqlite3.connect(db_path) + row = con.execute( + "SELECT git_branch FROM exchanges WHERE id='ex1'" + ).fetchone() + assert row is not None + assert row[0] == "main" + con.close() + + +def test_migration_v8_missing_jsonl_stays_null(tmp_path: Path) -> None: + """Test that v8 migration handles missing jsonl file gracefully, leaving git_branch NULL.""" + db_path = tmp_path / "memory.db" + + # Create a v7 DB with a conversation pointing to a non-existent jsonl file + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute( + "INSERT INTO conversations(id, source_path) VALUES ('conv1', '/nonexistent/path.jsonl')" + ) + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + # Insert an exchange at ply_start=0 + raw_con.execute( + "INSERT INTO exchanges VALUES ('ex1', 'conv1', 0, 1, 'user', 'agent', NULL, 'pending')" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + "CREATE TABLE exchange_files (exchange_id TEXT, file_path TEXT, PRIMARY KEY(exchange_id, file_path))" + ) + + raw_con.execute("PRAGMA user_version = 7") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v8 migration without raising an error + init_db(db_path) + + # Verify git_branch is NULL (not filled in due to missing jsonl) + con = sqlite3.connect(db_path) + row = con.execute( + "SELECT git_branch FROM exchanges WHERE id='ex1'" + ).fetchone() + assert row is not None + assert row[0] is None + con.close() + + +def test_migration_v8_malformed_line_coordinate(tmp_path: Path) -> None: + """Test that malformed jsonl lines do NOT shift ply coordinates during backfill. + + Verifies that when a malformed JSON line is at index 0 and a valid gitBranch + entry is at index 1, the exchange with ply_start=1 gets the correct gitBranch, + proving the malformed line did not consume a ply slot. + """ + import json + + db_path = tmp_path / "memory.db" + jsonl_path = tmp_path / "session.jsonl" + + # Create a jsonl file with: + # Line 0: malformed JSON (not valid JSON) + # Line 1: valid JSON user entry with gitBranch='feature-x' + malformed_line = "not-json\n" + user_entry = { + "uuid": "user1", + "gitBranch": "feature-x", + "type": "human", + "content": "This is a long user content string that is valid for the database", + } + jsonl_path.write_text(malformed_line + json.dumps(user_entry) + "\n") + + # Create a v7 DB with a conversation pointing to the jsonl file + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute( + "INSERT INTO conversations(id, source_path) VALUES ('conv1', ?)", + (str(jsonl_path),), + ) + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + # Insert an exchange at ply_start=0 (after the malformed line) + raw_con.execute( + "INSERT INTO exchanges VALUES ('ex1', 'conv1', 0, 1, 'user', 'agent', NULL, 'pending')" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + "CREATE TABLE exchange_files (exchange_id TEXT, file_path TEXT, PRIMARY KEY(exchange_id, file_path))" + ) + + raw_con.execute("PRAGMA user_version = 7") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v8 migration and backfill git_branch + init_db(db_path) + + # Verify git_branch was backfilled to 'feature-x' (not NULL) + # This proves ply_index was 1 when the valid line was processed + con = sqlite3.connect(db_path) + row = con.execute( + "SELECT git_branch FROM exchanges WHERE id='ex1'" + ).fetchone() + assert row is not None + assert row[0] == "feature-x" + con.close() + + +def test_migration_v8_idempotent(tmp_path: Path) -> None: + """Test that v8 migration is idempotent (can be run multiple times).""" + db_path = tmp_path / "memory.db" + + # Initialize a fresh DB (which runs v8) + init_db(db_path) + + # Check user_version after first init + con = sqlite3.connect(db_path) + user_version_1 = con.execute("PRAGMA user_version").fetchone()[0] + con.close() + + # Run init_db again (should be idempotent) + init_db(db_path) + + # Check user_version after second init + con = sqlite3.connect(db_path) + user_version_2 = con.execute("PRAGMA user_version").fetchone()[0] + con.close() + + # Verify user_version equals the number of migrations both times + assert user_version_1 == len(_MIGRATIONS) + assert user_version_2 == len(_MIGRATIONS) diff --git a/tests/test_indexer.py b/tests/test_indexer.py index 35d21e0..063d308 100644 --- a/tests/test_indexer.py +++ b/tests/test_indexer.py @@ -12,9 +12,9 @@ def make_user_entry( - uuid: str, text: str, parent_uuid: str | None = None, is_meta: bool = False + uuid: str, text: str, parent_uuid: str | None = None, is_meta: bool = False, git_branch: str | None = None ) -> dict: - return { + entry = { "type": "user", "uuid": uuid, "parentUuid": parent_uuid, @@ -22,6 +22,9 @@ def make_user_entry( "timestamp": "2026-03-26T00:00:00.000Z", "message": {"role": "user", "content": text}, } + if git_branch is not None: + entry["gitBranch"] = git_branch + return entry def make_assistant_entry(uuid: str, text: str, parent_uuid: str) -> dict: @@ -81,6 +84,102 @@ def test_parse_exchanges_single(tmp_path: Path) -> None: assert "pool_size" in exchanges[0].agent_content +def test_parse_exchanges_git_branch_captured(tmp_path: Path) -> None: + """git_branch が capture される""" + f = tmp_path / "session.jsonl" + write_jsonl( + f, + [ + make_user_entry("u1", "connection pool の修正を教えてください。" * 5, git_branch="main"), + make_assistant_entry( + "a1", "pool_size=5 を DATABASE_URL に追加してください。" * 5, "u1" + ), + ], + ) + exchanges = parse_exchanges(f) + assert len(exchanges) == 1 + assert exchanges[0].git_branch == "main" + + +def test_parse_exchanges_git_branch_missing_is_none(tmp_path: Path) -> None: + """git_branch が missing の場合は None になる""" + f = tmp_path / "session.jsonl" + write_jsonl( + f, + [ + make_user_entry("u1", "connection pool の修正を教えてください。" * 5), + make_assistant_entry( + "a1", "pool_size=5 を DATABASE_URL に追加してください。" * 5, "u1" + ), + ], + ) + exchanges = parse_exchanges(f) + assert len(exchanges) == 1 + assert exchanges[0].git_branch is None + + +def test_parse_exchanges_git_branch_empty_string_is_none(tmp_path: Path) -> None: + """git_branch が empty string の場合は None になる""" + f = tmp_path / "session.jsonl" + write_jsonl( + f, + [ + make_user_entry("u1", "connection pool の修正を教えてください。" * 5, git_branch=""), + make_assistant_entry( + "a1", "pool_size=5 を DATABASE_URL に追加してください。" * 5, "u1" + ), + ], + ) + exchanges = parse_exchanges(f) + assert len(exchanges) == 1 + assert exchanges[0].git_branch is None + + +def test_parse_exchanges_branch_per_exchange(tmp_path: Path) -> None: + """各 exchange が異なる git_branch を持つことができる""" + f = tmp_path / "session.jsonl" + write_jsonl( + f, + [ + make_user_entry("u1", "最初の質問です。よろしくお願いします。" * 5, git_branch="main"), + make_assistant_entry("a1", "了解しました。詳しく説明します。" * 5, "u1"), + make_user_entry("u2", "次の質問です。詳しく教えてください。" * 5, "a1", git_branch="release/1.0-hardening"), + make_assistant_entry( + "a2", "詳しく説明します。ご参考になれば幸いです。" * 5, "u2" + ), + ], + ) + exchanges = parse_exchanges(f) + assert len(exchanges) == 2 + assert exchanges[0].git_branch == "main" + assert exchanges[1].git_branch == "release/1.0-hardening" + + +def test_index_file_persists_git_branch(tmp_path: Path) -> None: + """index_file が git_branch を DB に保存する""" + db_path = tmp_path / ".codeatrium" / "memory.db" + init_db(db_path) + + jsonl = tmp_path / "session.jsonl" + write_jsonl( + jsonl, + [ + make_user_entry("u1", "connection pool の修正を教えてください。" * 5, git_branch="feature-x"), + make_assistant_entry( + "a1", "pool_size=5 を DATABASE_URL に追加してください。" * 5, "u1" + ), + ], + ) + + index_file(jsonl, db_path) + + con = get_connection(db_path) + rows = con.execute("SELECT git_branch FROM exchanges").fetchall() + assert len(rows) == 1 + assert rows[0]["git_branch"] == "feature-x" + con.close() + + def test_parse_exchanges_multiple(tmp_path: Path) -> None: """2 user turn = 2 exchange""" f = tmp_path / "session.jsonl" diff --git a/tests/test_search_cmd.py b/tests/test_search_cmd.py index 06b1769..827ea95 100644 --- a/tests/test_search_cmd.py +++ b/tests/test_search_cmd.py @@ -91,7 +91,7 @@ def test_context_full_flag_includes_content(tmp_path, monkeypatch): assert "agent_content" in data[0] -def test_context_default_nine_fields(tmp_path, monkeypatch): +def test_context_default_symbol_fields(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) db, con = _setup(tmp_path) _insert_fixture(con) @@ -110,6 +110,7 @@ def test_context_default_nine_fields(tmp_path, monkeypatch): "exchange_core", "specific_context", "verbatim_ref", + "git_branch", } @@ -123,3 +124,79 @@ def test_context_verbatim_ref_format(tmp_path, monkeypatch): assert result.exit_code == 0 data = json.loads(result.output) assert data[0]["verbatim_ref"] == "/fake/session.jsonl:ply=10" + + +def test_context_branch_only_returns_results(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + db, con = _setup(tmp_path) + _insert_fixture(con) + # Update git_branch directly + con.execute("UPDATE exchanges SET git_branch=? WHERE id=?", ("main", "ex1")) + con.commit() + con.close() + + result = runner.invoke(app, ["context", "--branch", "main", "--json"]) + assert result.exit_code == 0 + data = json.loads(result.output) + assert len(data) > 0 + assert data[0]["git_branch"] == "main" + assert "exchange_id" in data[0] + + +def test_context_no_args_exits_1(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + db, con = _setup(tmp_path) + _insert_fixture(con) + con.close() + + result = runner.invoke(app, ["context", "--json"]) + assert result.exit_code == 1 + + +def test_context_symbol_has_git_branch_field(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + db, con = _setup(tmp_path) + _insert_fixture(con) + con.close() + + result = runner.invoke(app, ["context", "--symbol", "MyFunc", "--json"]) + assert result.exit_code == 0 + data = json.loads(result.output) + assert "git_branch" in data[0] + + +def test_search_json_has_git_branch_field(tmp_path, monkeypatch): + from unittest.mock import MagicMock, patch + + from codeatrium.models import FusedResult + + monkeypatch.chdir(tmp_path) + db, con = _setup(tmp_path) + _insert_fixture(con) + con.close() + + # Create a mock FusedResult with git_branch set + mock_result = FusedResult( + exchange_id="ex1", + user_content="user content", + agent_content="agent content", + score=0.95, + exchange_core="core summary", + specific_context="specific detail", + verbatim_ref="/fake/session.jsonl:ply=0", + rooms=[], + symbols=[], + git_branch="feature-branch", + ) + + # Embedder の実体はモデルロードが走り出力を汚すためモックする + with ( + patch("codeatrium.embedder.Embedder", return_value=MagicMock()), + patch("codeatrium.search.search_combined", return_value=[mock_result]), + ): + result = runner.invoke(app, ["search", "test query", "--json"]) + assert result.exit_code == 0 + data = json.loads(result.output) + assert len(data) > 0 + assert "git_branch" in data[0] + assert data[0]["git_branch"] == "feature-branch" diff --git a/tests/test_search_phase2.py b/tests/test_search_phase2.py index 9c87e29..15dc8cd 100644 --- a/tests/test_search_phase2.py +++ b/tests/test_search_phase2.py @@ -24,7 +24,7 @@ LONG_TEXT = "connection pool " * 10 -def _insert_exchange(con, ex_id, user_content, agent_content, conv_id="conv1"): +def _insert_exchange(con, ex_id, user_content, agent_content, conv_id="conv1", git_branch: str | None = None): con.execute( "INSERT OR IGNORE INTO conversations (id, source_path) VALUES (?,?)", (conv_id, f"/path/{conv_id}.jsonl"), @@ -41,10 +41,10 @@ def _insert_exchange(con, ex_id, user_content, agent_content, conv_id="conv1"): con.execute( """ INSERT OR IGNORE INTO exchanges - (id, conversation_id, ply_start, ply_end, user_content, agent_content) - VALUES (?,?,?,?,?,?) + (id, conversation_id, ply_start, ply_end, user_content, agent_content, git_branch) + VALUES (?,?,?,?,?,?,?) """, - (ex_id, conv_id, 2, 5, user_content, agent_content), + (ex_id, conv_id, 2, 5, user_content, agent_content, git_branch), ) con.commit() @@ -326,6 +326,39 @@ def test_search_hnsw_connection_leak(tmp_path: Path) -> None: assert fake_con.close.called +def test_search_hnsw_branch_filter_correct_binding(tmp_path: Path) -> None: + """search_hnsw_palace with branch filter binds parameters correctly. + + Verifies that the branch string is bound to the LIKE ? placeholder, + not to the min_exchanges >= ? integer placeholder. If parameter order + is wrong, SQLite will attempt type coercion and either fail or produce + incorrect results. + """ + db_path = tmp_path / "memory.db" + init_db(db_path) + con = get_connection(db_path) + + # Insert two exchanges with different git_branch values + _insert_exchange(con, "hnsw-branch-a", "connection pool management system " * 3, "response a", conv_id="hnsw-conv-branch", git_branch="branch-a") + _insert_exchange(con, "hnsw-branch-b", "database query optimization " * 3, "response b", conv_id="hnsw-conv-branch", git_branch="branch-b") + + # Insert palace_objects and vec_palace for both exchanges + query_vec = np.zeros(384, dtype=np.float32) + _insert_palace(con, "palace-a", "hnsw-branch-a", "palace core a", query_vec) + _insert_palace(con, "palace-b", "hnsw-branch-b", "palace core b", query_vec) + + con.close() + + # Call search_hnsw_palace with branch filter + # If parameter binding is wrong, this will either raise an exception + # or return results from both branches + results = search_hnsw_palace(db_path, query_vec, limit=10, min_exchanges=2, branch="branch-a") + + # Verify that only branch-a results are returned (or empty if no match) + for result in results: + assert result.exchange_id == "hnsw-branch-a", f"Expected only 'hnsw-branch-a', got {result.exchange_id}" + + def test_search_combined_enrich_connection_leak(tmp_path: Path) -> None: """search_combined は enrich 時に exception が出ても enrich con を close する""" db_path = tmp_path / "memory.db" @@ -359,3 +392,59 @@ def mock_get_connection(path): assert fake_enrich_con.close.called if stored_con is not None: stored_con.close() + + +def test_search_bm25_branch_filter(tmp_path: Path) -> None: + """branch フィルタで指定されたブランチのみ返される""" + db_path = tmp_path / "memory.db" + init_db(db_path) + con = get_connection(db_path) + _insert_exchange(con, "ex1", LONG_TEXT, "pool response", git_branch="branch-a") + _insert_exchange(con, "ex2", LONG_TEXT, "pool response", conv_id="conv2", git_branch="branch-b") + con.close() + + results = search_bm25(db_path, "connection pool", branch="branch-a") + assert len(results) == 1 + assert results[0].exchange_id == "ex1" + + +def test_search_bm25_branch_partial_match(tmp_path: Path) -> None: + """branch フィルタは部分一致で動作する""" + db_path = tmp_path / "memory.db" + init_db(db_path) + con = get_connection(db_path) + _insert_exchange(con, "ex1", LONG_TEXT, "pool response", git_branch="release/1.0-hardening") + con.close() + + results = search_bm25(db_path, "connection pool", branch="1.0-hardening") + assert len(results) == 1 + assert results[0].exchange_id == "ex1" + + +def test_search_combined_branch_filter(tmp_path: Path) -> None: + """search_combined で branch フィルタが機能する""" + db_path = tmp_path / "memory.db" + init_db(db_path) + con = get_connection(db_path) + _insert_exchange(con, "ex1", LONG_TEXT, "pool response", git_branch="target-branch") + _insert_exchange(con, "ex2", LONG_TEXT, "pool response", conv_id="conv2", git_branch="other-branch") + con.close() + + vec = np.ones(384, dtype=np.float32) + results = search_combined(db_path, "connection pool", vec, branch="target-branch") + assert any(r.exchange_id == "ex1" for r in results) + assert not any(r.exchange_id == "ex2" for r in results) + + +def test_enrich_results_populates_git_branch(tmp_path: Path) -> None: + """_enrich_results が git_branch を正しく付加する""" + db_path = tmp_path / "memory.db" + init_db(db_path) + con = get_connection(db_path) + _insert_exchange(con, "ex1", LONG_TEXT, "pool response", git_branch="feature-x") + con.close() + + vec = np.ones(384, dtype=np.float32) + results = search_combined(db_path, "connection pool", vec, limit=5) + assert len(results) >= 1 + assert results[0].git_branch == "feature-x" diff --git a/tests/test_status_hook.py b/tests/test_status_hook.py index 9f205ac..d0dfcfd 100644 --- a/tests/test_status_hook.py +++ b/tests/test_status_hook.py @@ -182,6 +182,15 @@ def test_prime_outputs_instructions(tmp_path, monkeypatch): assert "loci show" in result.output +def test_prime_outputs_branch_usage(tmp_path, monkeypatch): + _setup_db(tmp_path) + monkeypatch.chdir(tmp_path) + result = runner.invoke(app, ["prime"]) + assert result.exit_code == 0 + assert "--branch" in result.output + assert "loci context --branch" in result.output + + def test_prime_silent_when_uninitialized(tmp_path, monkeypatch): """.codeatrium/ がないディレクトリでは何も出力せず exit 0 で抜ける""" monkeypatch.chdir(tmp_path)