diff --git a/CLAUDE.md b/CLAUDE.md index 110ce49..2134791 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -122,15 +122,8 @@ loci hook install # Register hooks to ~/.claude/setti ## Past Memory Search (codeatrium) -IMPORTANT: Command usage is injected automatically at session start via `loci prime` (SessionStart hook). +IMPORTANT: Full usage instructions are injected automatically at session start via `loci prime` (SessionStart hook). If not in context, run `loci prime`. - -### Rules - -1. **Search before implementing** — always check if something was discussed or built before starting work. -2. **Check symbols when you lack context** — run `loci context --symbol` before changing a function you don't have enough background on. -3. **Use technical terms** — queries with exact symbol names, error messages, or parameter names yield better results. -4. **Follow up with `loci show`** — when `exchange_core` is ambiguous, fetch the full verbatim conversation. --- diff --git a/README.md b/README.md index 4b78cb7..667edb5 100644 --- a/README.md +++ b/README.md @@ -87,11 +87,12 @@ Agent instructions are injected automatically — no manual setup required: | `loci index` | Index new session logs | | `loci distill [--limit N]` | Distill undistilled exchanges via LLM | | `loci search "query" --json` | Semantic search (agent-facing) | -| `loci context --symbol "name" --json` | Code symbol → past conversations | +| `loci context --symbol "name" --json` | Code symbol → past conversations (lightweight; add `--full` for verbatim text) | | `loci show "" --json` | Retrieve verbatim conversation | | `loci status` | Show index state | | `loci server start/stop/status` | Embedding server management | | `loci hook install` | Re-register hooks (normally already done by `loci init`) | +| `loci hook uninstall` | Remove codeatrium hooks from `settings.json` | ## Automation (Claude Code Hooks) diff --git a/src/codeatrium/cli/__init__.py b/src/codeatrium/cli/__init__.py index 85a31b0..bece7ac 100644 --- a/src/codeatrium/cli/__init__.py +++ b/src/codeatrium/cli/__init__.py @@ -167,7 +167,7 @@ def init( try: con.execute( f""" - UPDATE exchanges SET distilled_at = 'skipped' + UPDATE exchanges SET distilled_at = 'skipped', distill_status = 'skipped' WHERE distilled_at IS NULL AND id IN ( SELECT id FROM exchanges @@ -237,7 +237,7 @@ def _on_progress( else: typer.echo(f" [{cur}/{tot}] distilled", err=True) - count = distill_all( + count, err_count = distill_all( db, model=cfg.distill_model, on_progress=_on_progress, @@ -245,6 +245,8 @@ def _on_progress( distill_min_chars=cfg.distill_min_chars, ) typer.echo(f"Distilled {count} exchange(s).") + if err_count > 0: + typer.echo(f"{err_count} exchange(s) failed — see errors above.", err=True) except KeyboardInterrupt: typer.echo( "\n⚠ Distillation interrupted. " diff --git a/src/codeatrium/cli/distill_cmd.py b/src/codeatrium/cli/distill_cmd.py index b94566c..72ce819 100644 --- a/src/codeatrium/cli/distill_cmd.py +++ b/src/codeatrium/cli/distill_cmd.py @@ -14,6 +14,7 @@ def distill( ] = None, ) -> None: """未蒸留の exchange を claude -p で蒸留して palace_objects を生成する""" + import fcntl import os from codeatrium.config import load_config @@ -30,31 +31,14 @@ def distill( lock_path = db.parent / "distill.lock" - # ロック取得: O_CREAT | O_EXCL で原子的に作成(TOCTOU 防止) + # ロック取得: fcntl.flock で排他ロック(LOCK_NB: 非ブロッキング) + fd = os.open(str(lock_path), os.O_CREAT | os.O_RDWR, 0o600) try: - fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) - os.write(fd, str(os.getpid()).encode()) + fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: os.close(fd) - except FileExistsError: - # 既存ロックのプロセスが生きているか確認 - try: - existing_pid = int(lock_path.read_text().strip()) - os.kill(existing_pid, 0) - typer.echo( - f"loci distill is already running (PID {existing_pid}). Exiting.", - err=True, - ) - raise typer.Exit(0) - except (ValueError, ProcessLookupError, PermissionError): - # stale lock — 再取得 - lock_path.unlink(missing_ok=True) - try: - fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) - os.write(fd, str(os.getpid()).encode()) - os.close(fd) - except FileExistsError: - typer.echo("loci distill: lost lock race after stale cleanup. Exiting.", err=True) - raise typer.Exit(0) + typer.echo("loci distill is already running. Exiting.", err=True) + raise typer.Exit(0) def _on_progress(cur: int, tot: int, error: str | None = None) -> None: if error: @@ -63,7 +47,16 @@ def _on_progress(cur: int, tot: int, error: str | None = None) -> None: typer.echo(f" [{cur}/{tot}] distilled", err=True) try: - count = distill_all( + from codeatrium.db import check_drift + + drifts = check_drift(db) + for key, recorded, current in drifts: + typer.echo( + f"[warn] {key} changed ({recorded} -> {current}). Re-index recommended.", + err=True, + ) + + count, err_count = distill_all( db, limit=limit, model=cfg.distill_model, @@ -72,5 +65,8 @@ def _on_progress(cur: int, tot: int, error: str | None = None) -> None: distill_min_chars=cfg.distill_min_chars, ) typer.echo(f"Distilled {count} exchange(s).") + if err_count > 0: + typer.echo(f"{err_count} exchange(s) failed — see errors above.", err=True) finally: - lock_path.unlink(missing_ok=True) + fcntl.flock(fd, fcntl.LOCK_UN) + os.close(fd) diff --git a/src/codeatrium/cli/hook_cmd.py b/src/codeatrium/cli/hook_cmd.py index 26045a8..8f74dd9 100644 --- a/src/codeatrium/cli/hook_cmd.py +++ b/src/codeatrium/cli/hook_cmd.py @@ -22,3 +22,12 @@ def hook_install() -> None: cfg = load_config(find_project_root()) _changed, message = install_hooks(batch_limit=cfg.distill_batch_limit) typer.echo(message) + + +@hook_app.command("uninstall") +def hook_uninstall() -> None: + """Claude Code の settings.json から codeatrium 関連フックをすべて除去する。""" + from codeatrium.hooks import uninstall_hooks + + _changed, message = uninstall_hooks() + typer.echo(message) diff --git a/src/codeatrium/cli/index_cmd.py b/src/codeatrium/cli/index_cmd.py index dd2b777..a8dec2b 100644 --- a/src/codeatrium/cli/index_cmd.py +++ b/src/codeatrium/cli/index_cmd.py @@ -32,6 +32,10 @@ def index( raise typer.Exit(1) init_db(db) + from codeatrium.db import check_drift + drifts = check_drift(db) + for key, recorded, current in drifts: + typer.echo(f"[warn] {key} changed ({recorded} -> {current}). Re-index recommended.", err=True) cfg = load_config(root) target_dir = path or resolve_claude_projects_path(root) diff --git a/src/codeatrium/cli/prime_cmd.py b/src/codeatrium/cli/prime_cmd.py index afa8372..c21bfde 100644 --- a/src/codeatrium/cli/prime_cmd.py +++ b/src/codeatrium/cli/prime_cmd.py @@ -12,28 +12,31 @@ PRIME_TEXT = """\ ## Past Memory Search (codeatrium) -Use `loci search` to find past implementations, decisions, and code locations. +codeatrium records every past conversation turn, decision, and code location. Retrieve that memory **before** acting — not after. -### When to use +### When to act (agent-initiated triggers) -- When asked "where did we implement X?" or "where is X?" -- When checking if a similar bug was fixed before -- When verifying if a feature already exists -- When looking up the reasoning behind a past design decision -- Before editing code you lack context about — use `loci context --symbol` to review past discussions -- Before refactoring or changing the behavior of a function — use `loci context --symbol` to check past design decisions +- **Before editing or refactoring a function** — recall past design decisions and known constraints for that symbol. +- **Before starting a new implementation** — check if similar work was done before; reuse decisions and avoid re-debating settled choices. +- **When you encounter a known or recurring error** — search for past fixes; the solution may already be documented. -### Commands +### Search — semantic query over past conversations ```bash -# Semantic search -loci search "query" --json --limit 5 +# Find past discussions, decisions, or implementations +loci search "BM25 RRF fusion ranking" --json --limit 5 -# Reverse lookup: code symbol -> past conversations -loci context --symbol "symbol_name" --json - -# Retrieve verbatim conversation (use verbatim_ref from search results) +# Retrieve verbatim exchange (use verbatim_ref from search results) loci show "" --json +``` + +### Context — reverse lookup from code symbol to past conversations + +Touching a symbol = recalling memory about that symbol. Before changing any function or class, look up what was decided about it. + +```bash +# Retrieve all past conversations that involved this symbol +loci context --symbol "SymbolResolver.extract" --json ```\ """ @@ -41,15 +44,8 @@ {BEGIN_MARKER} ## Past Memory Search (codeatrium) -IMPORTANT: Command usage is injected automatically at session start via `loci prime` (SessionStart hook). +IMPORTANT: Full usage instructions are injected automatically at session start via `loci prime` (SessionStart hook). If not in context, run `loci prime`. - -### Rules - -1. **Search before implementing** — always check if something was discussed or built before starting work. -2. **Check symbols when you lack context** — run `loci context --symbol` before changing a function you don't have enough background on. -3. **Use technical terms** — queries with exact symbol names, error messages, or parameter names yield better results. -4. **Follow up with `loci show`** — when `exchange_core` is ambiguous, fetch the full verbatim conversation. {END_MARKER}\ """ diff --git a/src/codeatrium/cli/search_cmd.py b/src/codeatrium/cli/search_cmd.py index ab134ff..c4b3bbf 100644 --- a/src/codeatrium/cli/search_cmd.py +++ b/src/codeatrium/cli/search_cmd.py @@ -25,6 +25,11 @@ def search( typer.echo("Not initialized. Run `loci init` first.", err=True) raise typer.Exit(1) + from codeatrium.db import check_drift + drifts = check_drift(db) + for key, recorded, current in drifts: + typer.echo(f"[warn] {key} changed ({recorded} -> {current}). Re-index recommended.", err=True) + embedder = Embedder() query_vec = embedder.embed(query) results = search_combined(db, query, query_vec, limit=limit) @@ -62,6 +67,7 @@ def context( ], limit: Annotated[int, typer.Option("--limit", "-n", help="返す件数")] = 5, json_output: Annotated[bool, typer.Option("--json", help="JSON で出力")] = False, + full: Annotated[bool, typer.Option("--full", help="全文(user_content / agent_content)を含める")] = False, ) -> None: """シンボル名から関連する過去会話を逆引きする""" from codeatrium.db import get_connection @@ -87,10 +93,13 @@ def context( e.user_content, e.agent_content, p.exchange_core, - p.specific_context + p.specific_context, + c.source_path, + e.ply_start FROM symbols s JOIN palace_objects p ON p.id = s.palace_object_id JOIN exchanges e ON e.id = p.exchange_id + JOIN conversations c ON c.id = e.conversation_id WHERE s.symbol_name LIKE ? LIMIT ? """, @@ -103,8 +112,9 @@ def context( return if json_output: - output = [ - { + output = [] + for r in rows: + base = { "symbol_name": r["symbol_name"], "symbol_kind": r["symbol_kind"], "file_path": r["file_path"], @@ -113,11 +123,12 @@ def context( "exchange_id": r["exchange_id"], "exchange_core": r["exchange_core"], "specific_context": r["specific_context"], - "user_content": r["user_content"], - "agent_content": r["agent_content"], + "verbatim_ref": f"{r['source_path']}:ply={r['ply_start']}", } - for r in rows - ] + if full: + base["user_content"] = r["user_content"] + base["agent_content"] = r["agent_content"] + output.append(base) typer.echo(json.dumps(output, ensure_ascii=False, indent=2)) else: for i, r in enumerate(rows, 1): @@ -126,3 +137,4 @@ def context( typer.echo(f" {r['signature']}") if r["exchange_core"]: typer.echo(f" Core: {r['exchange_core']}") + typer.echo(f" {r['source_path']}:ply={r['ply_start']}") diff --git a/src/codeatrium/cli/server_cmd.py b/src/codeatrium/cli/server_cmd.py index e8072f9..6e2946c 100644 --- a/src/codeatrium/cli/server_cmd.py +++ b/src/codeatrium/cli/server_cmd.py @@ -6,11 +6,14 @@ server_app = typer.Typer(help="embedding サーバー管理") +_SERVER_STARTUP_POLL_ATTEMPTS: int = 150 # サーバー起動確認のポーリング回数(0.2秒 × 150 = 最大30秒待機) + @server_app.command("start") def server_start() -> None: """embedding サーバーをバックグラウンドで起動する""" import json as _json + import os import socket as _socket import subprocess @@ -36,8 +39,20 @@ def server_start() -> None: return except Exception: sock.unlink(missing_ok=True) + server_pid_path(root).unlink(missing_ok=True) pid_path = server_pid_path(root) + if pid_path.exists(): + try: + _pid = int(pid_path.read_text().strip()) + os.kill(_pid, 0) + except (ProcessLookupError, ValueError): + # プロセス不在 or 不正な PID → stale とみなして除去 + pid_path.unlink(missing_ok=True) + except PermissionError: + # 別ユーザーのプロセスが生存 → 触らない + pass + proc = subprocess.Popen( [_loci_python(), "-m", "codeatrium.embedder_server", str(sock)], stdout=subprocess.DEVNULL, @@ -48,7 +63,7 @@ def server_start() -> None: import time - for i in range(150): + for i in range(_SERVER_STARTUP_POLL_ATTEMPTS): if sock.exists(): typer.echo(f"Server started (PID {proc.pid})") return diff --git a/src/codeatrium/cli/status_cmd.py b/src/codeatrium/cli/status_cmd.py index a60662e..30d3d76 100644 --- a/src/codeatrium/cli/status_cmd.py +++ b/src/codeatrium/cli/status_cmd.py @@ -25,12 +25,26 @@ def status( con = get_connection(db) total = con.execute("SELECT COUNT(*) FROM exchanges").fetchone()[0] distilled = con.execute( - "SELECT COUNT(*) FROM exchanges WHERE distilled_at IS NOT NULL" + "SELECT COUNT(*) FROM exchanges WHERE distill_status = 'distilled'" + ).fetchone()[0] + skipped = con.execute( + "SELECT COUNT(*) FROM exchanges WHERE distill_status = 'skipped'" + ).fetchone()[0] + pending = con.execute( + "SELECT COUNT(*) FROM exchanges WHERE distill_status = 'pending'" ).fetchone()[0] palace_count = con.execute("SELECT COUNT(*) FROM palace_objects").fetchone()[0] symbol_count = con.execute("SELECT COUNT(*) FROM symbols").fetchone()[0] con.close() + from codeatrium.db import check_drift + drifts = check_drift(db) + for key, recorded, current in drifts: + typer.echo( + f"[drift] {key}: recorded={recorded}, current={current} — re-index recommended", + err=True, + ) + db_size_bytes = db.stat().st_size db_size_kb = db_size_bytes / 1024 @@ -41,7 +55,8 @@ def status( "db_path": str(db), "exchanges": total, "distilled": distilled, - "undistilled": total - distilled, + "skipped": skipped, + "pending": pending, "palace_objects": palace_count, "symbols": symbol_count, "db_size_kb": round(db_size_kb, 1), @@ -53,7 +68,7 @@ def status( else: typer.echo(f"DB: {db} ({db_size_kb:.1f} KB)") typer.echo( - f"Exchanges : {total} total, {distilled} distilled, {total - distilled} pending" + f"Exchanges : {total} total | {distilled} distilled, {skipped} skipped, {pending} pending" ) typer.echo(f"Palace : {palace_count}") typer.echo(f"Symbols : {symbol_count}") diff --git a/src/codeatrium/config.py b/src/codeatrium/config.py index 686f4af..ae7f09a 100644 --- a/src/codeatrium/config.py +++ b/src/codeatrium/config.py @@ -2,6 +2,7 @@ from __future__ import annotations +import sys import tomllib from dataclasses import dataclass from pathlib import Path @@ -38,9 +39,7 @@ def load_config(project_root: Path) -> Config: try: with config_path.open("rb") as f: data = tomllib.load(f) - except Exception as e: - import sys - + except (FileNotFoundError, tomllib.TOMLDecodeError, OSError) as e: print(f"Warning: failed to parse {config_path}: {e}", file=sys.stderr) return Config() @@ -48,8 +47,6 @@ def load_config(project_root: Path) -> Config: model = distill.get("model", DEFAULT_DISTILL_MODEL) if not isinstance(model, str) or not model.strip(): - import sys - print( "Warning: distill.model must be a non-empty string, using default.", file=sys.stderr, @@ -58,8 +55,6 @@ def load_config(project_root: Path) -> Config: batch_limit = distill.get("batch_limit", DEFAULT_DISTILL_BATCH_LIMIT) if not isinstance(batch_limit, int) or batch_limit < 1: - import sys - print( "Warning: distill.batch_limit must be a positive integer, using default.", file=sys.stderr, @@ -70,8 +65,6 @@ def load_config(project_root: Path) -> Config: min_chars = index.get("min_chars", DEFAULT_INDEX_MIN_CHARS) if not isinstance(min_chars, int) or min_chars < 1: - import sys - print( "Warning: index.min_chars must be a positive integer, using default.", file=sys.stderr, @@ -80,8 +73,6 @@ def load_config(project_root: Path) -> Config: distill_min_chars = distill.get("min_chars", DEFAULT_DISTILL_MIN_CHARS) if not isinstance(distill_min_chars, int) or distill_min_chars < 1: - import sys - print( "Warning: distill.min_chars must be a positive integer, using default.", file=sys.stderr, diff --git a/src/codeatrium/db.py b/src/codeatrium/db.py index d075118..a9ca8b6 100644 --- a/src/codeatrium/db.py +++ b/src/codeatrium/db.py @@ -10,20 +10,192 @@ rooms - palace object の room_assignments vec_palace - sqlite-vec HNSW インデックス(Phase2 distilled ベクトル検索用) symbols - tree-sitter 解決済みシンボル(Phase3 コード逆引き用) + _MIGRATIONS - 逐次マイグレーション関数リスト(user_version ベース) """ +import hashlib +import os import sqlite3 +from collections.abc import Callable from pathlib import Path import sqlite_vec +def _migrate_v1_add_last_ply_end(con: sqlite3.Connection) -> None: + """Migration v1: Add last_ply_end column to conversations table if absent.""" + columns = con.execute("PRAGMA table_info(conversations)").fetchall() + column_names = [col[1] for col in columns] + + if "last_ply_end" not in column_names: + con.execute( + "ALTER TABLE conversations ADD COLUMN last_ply_end INT NOT NULL DEFAULT -1" + ) + + +def _migrate_v2_add_distill_status(con: sqlite3.Connection) -> None: + """Migration v2: exchanges に distill_status カラムを追加し既存データを変換する""" + table_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='exchanges'").fetchone() + if table_exists is None: + return + columns = con.execute("PRAGMA table_info(exchanges)").fetchall() + column_names = [col[1] for col in columns] + + if "distill_status" not in column_names: + con.execute( + "ALTER TABLE exchanges ADD COLUMN distill_status TEXT NOT NULL DEFAULT 'pending'" + ) + + con.execute( + "UPDATE exchanges SET distill_status='skipped', distilled_at=NULL WHERE distilled_at='skipped'" + ) + con.execute( + "UPDATE exchanges SET distill_status='distilled' WHERE distilled_at IS NOT NULL AND distilled_at != 'skipped'" + ) + + +def _migrate_v3_add_meta(con: sqlite3.Connection) -> None: + """Migration v3: meta テーブルを新設し embedding_model と prompt_version を初期化する""" + con.execute( + "CREATE TABLE IF NOT EXISTS meta (key TEXT PRIMARY KEY, value TEXT)" + ) + + from codeatrium.embedder import MODEL_NAME + from codeatrium.llm import DISTILL_PROMPT_VERSION + + con.execute( + "INSERT OR IGNORE INTO meta(key,value) VALUES (?,?)", + ("embedding_model", MODEL_NAME), + ) + con.execute( + "INSERT OR IGNORE INTO meta(key,value) VALUES (?,?)", + ("prompt_version", DISTILL_PROMPT_VERSION), + ) + + +def _migrate_v4_add_indexes(con: sqlite3.Connection) -> None: + """Migration v4: rooms/symbols/palace_objects に検索用インデックスを追加する""" + rooms_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='rooms'").fetchone() + if rooms_exists is not None: + con.execute( + "CREATE INDEX IF NOT EXISTS idx_rooms_palace_object_id ON rooms(palace_object_id)" + ) + + symbols_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='symbols'").fetchone() + if symbols_exists is not None: + con.execute( + "CREATE INDEX IF NOT EXISTS idx_symbols_palace_object_id ON symbols(palace_object_id)" + ) + + palace_objects_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='palace_objects'").fetchone() + if palace_objects_exists is not None: + con.execute( + "CREATE INDEX IF NOT EXISTS idx_palace_objects_exchange_id ON palace_objects(exchange_id)" + ) + + +def _migrate_v5_add_exchange_files(con: sqlite3.Connection) -> None: + """Migration v5: exchange_files テーブルを新設する""" + table_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='exchange_files'").fetchone() + if table_exists is None: + con.execute( + "CREATE TABLE exchange_files (exchange_id TEXT, file_path TEXT, PRIMARY KEY(exchange_id, file_path))" + ) + + +def _migrate_v6_recompute_symbol_ids(con: sqlite3.Connection) -> None: + """Migration v6: symbols テーブルの id カラムを hash 再計算する""" + symbols_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='symbols'").fetchone() + if symbols_exists is None: + return + + rows = con.execute("SELECT rowid, symbol_name, file_path, palace_object_id FROM symbols").fetchall() + for rowid, symbol_name, file_path, palace_object_id in rows: + new_id = hashlib.sha256((symbol_name + ":" + file_path + ":" + palace_object_id).encode()).hexdigest() + con.execute("UPDATE symbols SET id=? WHERE rowid=?", (new_id, rowid)) + + +def _migrate_v7_repair_distill(con: sqlite3.Connection) -> None: + """Migration v7: palace_objects テーブルから bm25_text を削除・distill ステータス修復・orphan クリーンアップ""" + # STEP1: bm25_text カラム削除 + palace_objects_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='palace_objects'").fetchone() + if palace_objects_exists is not None: + columns = con.execute("PRAGMA table_info(palace_objects)").fetchall() + column_names = [col[1] for col in columns] + if "bm25_text" in column_names: + con.execute( + "CREATE TABLE palace_objects_new (id TEXT PRIMARY KEY, exchange_id TEXT NOT NULL, exchange_core TEXT NOT NULL, specific_context TEXT NOT NULL, distill_text TEXT NOT NULL)" + ) + con.execute( + "INSERT INTO palace_objects_new (id, exchange_id, exchange_core, specific_context, distill_text) SELECT id, exchange_id, exchange_core, specific_context, distill_text FROM palace_objects" + ) + con.execute("DROP TABLE palace_objects") + con.execute("ALTER TABLE palace_objects_new RENAME TO palace_objects") + con.execute("CREATE INDEX IF NOT EXISTS idx_palace_objects_exchange_id ON palace_objects(exchange_id)") + + # STEP2: re-distill reset + exchanges_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='exchanges'").fetchone() + palace_objects_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='palace_objects'").fetchone() + if exchanges_exists is not None and palace_objects_exists is not None: + con.execute( + "UPDATE exchanges SET distill_status='pending', distilled_at=NULL WHERE distill_status='distilled' AND id NOT IN (SELECT exchange_id FROM palace_objects)" + ) + + # STEP3: orphan cleanup + rooms_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='rooms'").fetchone() + if rooms_exists is not None: + con.execute( + "DELETE FROM rooms WHERE palace_object_id NOT IN (SELECT id FROM palace_objects)" + ) + + symbols_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='symbols'").fetchone() + if symbols_exists is not None: + con.execute( + "DELETE FROM symbols WHERE palace_object_id NOT IN (SELECT id FROM palace_objects)" + ) + + vec_palace_exists = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='vec_palace'").fetchone() + if vec_palace_exists is not None: + con.execute( + "DELETE FROM vec_palace WHERE palace_id NOT IN (SELECT id FROM palace_objects)" + ) + + +_MIGRATIONS: list[Callable[[sqlite3.Connection], None]] = [ + _migrate_v1_add_last_ply_end, + _migrate_v2_add_distill_status, + _migrate_v3_add_meta, + _migrate_v4_add_indexes, + _migrate_v5_add_exchange_files, + _migrate_v6_recompute_symbol_ids, + _migrate_v7_repair_distill, +] + + +def _run_migrations(con: sqlite3.Connection) -> None: + """Run pending migrations based on PRAGMA user_version.""" + current_version: int = con.execute("PRAGMA user_version").fetchone()[0] + + for target_version, fn in enumerate(_MIGRATIONS, start=1): + if target_version > current_version: + con.execute("BEGIN") + try: + fn(con) + con.execute(f"PRAGMA user_version = {target_version}") + con.execute("COMMIT") + except Exception: + con.rollback() + raise + + def get_connection(db_path: Path) -> sqlite3.Connection: - """sqlite-vec 拡張をロードした接続を返す""" - con = sqlite3.connect(db_path) + """sqlite-vec 拡張をロードし WAL モード・busy_timeout を設定した接続を返す""" + con = sqlite3.connect(db_path, timeout=10.0) con.enable_load_extension(True) sqlite_vec.load(con) con.enable_load_extension(False) + con.execute("PRAGMA journal_mode=WAL") + con.execute("PRAGMA busy_timeout=10000") con.row_factory = sqlite3.Row return con @@ -32,88 +204,125 @@ def init_db(db_path: Path) -> None: """DB を初期化してスキーマを作成する(冪等)""" db_path.parent.mkdir(parents=True, exist_ok=True) con = get_connection(db_path) + # memory.db 本体と WAL サイドカー(-wal / -shm)は会話・コードの逐語データを含むため + # 所有者のみ読み書き可(0o600)にする。WAL モードでサイドカーが生成される。 + os.chmod(db_path, 0o600) + for suffix in ("-wal", "-shm"): + sidecar = db_path.parent / (db_path.name + suffix) + if sidecar.exists(): + os.chmod(sidecar, 0o600) - con.executescript(""" - CREATE TABLE IF NOT EXISTS conversations ( - id TEXT PRIMARY KEY, -- sha256(source_path) - source_path TEXT NOT NULL UNIQUE, - started_at TIMESTAMP, - last_ply_end INT NOT NULL DEFAULT -1 -- 最後にインデックスした ply_end(差分用) - ); - - CREATE TABLE IF NOT EXISTS exchanges ( - id TEXT PRIMARY KEY, -- sha256(conversation_id + ":" + user_uuid) - conversation_id TEXT NOT NULL, - ply_start INT NOT NULL, - ply_end INT NOT NULL, - user_content TEXT NOT NULL, - agent_content TEXT NOT NULL, - distilled_at TIMESTAMP -- NULL = 未蒸留 - ); - - CREATE VIRTUAL TABLE IF NOT EXISTS exchanges_fts USING fts5( - user_content, - agent_content, - content=exchanges, - content_rowid=rowid - ); - - CREATE TRIGGER IF NOT EXISTS exchanges_ai - AFTER INSERT ON exchanges BEGIN - INSERT INTO exchanges_fts(rowid, user_content, agent_content) - VALUES (new.rowid, new.user_content, new.agent_content); - END; - - CREATE TRIGGER IF NOT EXISTS exchanges_ad - AFTER DELETE ON exchanges BEGIN - INSERT INTO exchanges_fts(exchanges_fts, rowid, user_content, agent_content) - VALUES ('delete', old.rowid, old.user_content, old.agent_content); - END; - - CREATE TRIGGER IF NOT EXISTS exchanges_au - AFTER UPDATE ON exchanges BEGIN - INSERT INTO exchanges_fts(exchanges_fts, rowid, user_content, agent_content) - VALUES ('delete', old.rowid, old.user_content, old.agent_content); - INSERT INTO exchanges_fts(rowid, user_content, agent_content) - VALUES (new.rowid, new.user_content, new.agent_content); - END; - - CREATE TABLE IF NOT EXISTS palace_objects ( - id TEXT PRIMARY KEY, - exchange_id TEXT NOT NULL, - exchange_core TEXT NOT NULL, - specific_context TEXT NOT NULL, - distill_text TEXT NOT NULL -- exchange_core + newline + specific_context - ); - - CREATE TABLE IF NOT EXISTS rooms ( - id TEXT PRIMARY KEY, - palace_object_id TEXT NOT NULL, - room_type TEXT NOT NULL, -- "file" / "concept" / "workflow" - room_key TEXT NOT NULL, - room_label TEXT NOT NULL, - relevance REAL NOT NULL, - dedup_hash TEXT NOT NULL -- hash(room_type, room_key) - ); - - CREATE TABLE IF NOT EXISTS symbols ( - id TEXT PRIMARY KEY, -- sha256(symbol_name + file_path) - palace_object_id TEXT NOT NULL, - symbol_name TEXT NOT NULL, -- "AuthMiddleware.validate" - symbol_kind TEXT NOT NULL, -- "function" / "class" / "method" - file_path TEXT NOT NULL, - signature TEXT NOT NULL, - line INT NOT NULL, - dedup_hash TEXT NOT NULL -- sha256(symbol_name + file_path) - ); - """) + # Check if conversations table exists (indicates existing DB) + table_exists = con.execute( + 'SELECT name FROM sqlite_master WHERE type="table" AND name="conversations"' + ).fetchone() - # マイグレーション: last_ply_end カラムが無い既存 DB に追加 - try: - con.execute("ALTER TABLE conversations ADD COLUMN last_ply_end INT NOT NULL DEFAULT -1") - con.commit() - except Exception: - pass # カラムが既に存在する場合は無視 + if table_exists is None: + # New DB: run core schema + con.executescript(""" + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, -- sha256(source_path) + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 -- 最後にインデックスした ply_end(差分用) + ); + + CREATE TABLE IF NOT EXISTS exchanges ( + id TEXT PRIMARY KEY, -- sha256(conversation_id + ":" + user_uuid) + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, -- NULL = 未蒸留 + distill_status TEXT NOT NULL DEFAULT 'pending' + ); + + CREATE VIRTUAL TABLE IF NOT EXISTS exchanges_fts USING fts5( + user_content, + agent_content, + content=exchanges, + content_rowid=rowid + ); + + CREATE TRIGGER IF NOT EXISTS exchanges_ai + AFTER INSERT ON exchanges BEGIN + INSERT INTO exchanges_fts(rowid, user_content, agent_content) + VALUES (new.rowid, new.user_content, new.agent_content); + END; + + CREATE TRIGGER IF NOT EXISTS exchanges_ad + AFTER DELETE ON exchanges BEGIN + INSERT INTO exchanges_fts(exchanges_fts, rowid, user_content, agent_content) + VALUES ('delete', old.rowid, old.user_content, old.agent_content); + END; + + CREATE TRIGGER IF NOT EXISTS exchanges_au + AFTER UPDATE ON exchanges BEGIN + INSERT INTO exchanges_fts(exchanges_fts, rowid, user_content, agent_content) + VALUES ('delete', old.rowid, old.user_content, old.agent_content); + INSERT INTO exchanges_fts(rowid, user_content, agent_content) + VALUES (new.rowid, new.user_content, new.agent_content); + END; + + CREATE TABLE IF NOT EXISTS palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL -- exchange_core + newline + specific_context + ); + + CREATE TABLE IF NOT EXISTS rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, -- "file" / "concept" / "workflow" + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL -- hash(room_type, room_key) + ); + + CREATE TABLE IF NOT EXISTS symbols ( + id TEXT PRIMARY KEY, -- sha256(symbol_name + file_path) + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, -- "AuthMiddleware.validate" + symbol_kind TEXT NOT NULL, -- "function" / "class" / "method" + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL -- sha256(symbol_name + file_path) + ); + + CREATE TABLE IF NOT EXISTS exchange_files ( + exchange_id TEXT, + file_path TEXT, + PRIMARY KEY (exchange_id, file_path) + ); + + CREATE TABLE IF NOT EXISTS meta (key TEXT PRIMARY KEY, value TEXT); + + CREATE INDEX IF NOT EXISTS idx_rooms_palace_object_id ON rooms(palace_object_id); + CREATE INDEX IF NOT EXISTS idx_symbols_palace_object_id ON symbols(palace_object_id); + CREATE INDEX IF NOT EXISTS idx_palace_objects_exchange_id ON palace_objects(exchange_id); + """) + + from codeatrium.embedder import MODEL_NAME + from codeatrium.llm import DISTILL_PROMPT_VERSION + + con.execute( + "INSERT OR IGNORE INTO meta(key,value) VALUES (?,?)", + ("embedding_model", MODEL_NAME), + ) + con.execute( + "INSERT OR IGNORE INTO meta(key,value) VALUES (?,?)", + ("prompt_version", DISTILL_PROMPT_VERSION), + ) + + con.execute(f"PRAGMA user_version = {len(_MIGRATIONS)}") + else: + # Existing DB: run migrations + _run_migrations(con) # sqlite-vec の仮想テーブル(HNSW, Phase1 verbatim embedding 用) con.execute(""" @@ -133,3 +342,38 @@ def init_db(db_path: Path) -> None: con.commit() con.close() + + +def check_drift(db_path: Path) -> list[tuple[str, str, str]]: + """meta テーブルの記録値と現行値を比較し不一致の (key, recorded, current) タプルリストを返す""" + con = get_connection(db_path) + try: + # Check if meta table exists + meta_exists = con.execute( + 'SELECT name FROM sqlite_master WHERE type="table" AND name="meta"' + ).fetchone() + + if meta_exists is None: + return [] + + from codeatrium.embedder import MODEL_NAME + from codeatrium.llm import DISTILL_PROMPT_VERSION + + # Get recorded values from meta table + meta_rows = con.execute( + "SELECT key, value FROM meta WHERE key IN ('embedding_model', 'prompt_version')" + ).fetchall() + recorded = {row[0]: row[1] for row in meta_rows} + + # Compare with current values + drifts: list[tuple[str, str, str]] = [] + + if "embedding_model" in recorded and recorded["embedding_model"] != MODEL_NAME: + drifts.append(("embedding_model", recorded["embedding_model"], MODEL_NAME)) + + if "prompt_version" in recorded and recorded["prompt_version"] != DISTILL_PROMPT_VERSION: + drifts.append(("prompt_version", recorded["prompt_version"], DISTILL_PROMPT_VERSION)) + + return drifts + finally: + con.close() diff --git a/src/codeatrium/distiller.py b/src/codeatrium/distiller.py index df3c2b1..928fc79 100644 --- a/src/codeatrium/distiller.py +++ b/src/codeatrium/distiller.py @@ -10,19 +10,21 @@ from __future__ import annotations import datetime -import hashlib import os import re -import struct from collections.abc import Callable from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") from codeatrium.embedder import Embedder, EmbedderSetupError from codeatrium.llm import DISTILL_PROMPT_TEMPLATE, call_claude from codeatrium.models import PalaceObject +from codeatrium.utils import sha256 + +if TYPE_CHECKING: + from codeatrium.resolver import SymbolResolver # ---- ファイルパス抽出 ---- @@ -32,13 +34,6 @@ ) -# ---- 内部ヘルパー ---- - - -def _sha256(text: str) -> str: - return hashlib.sha256(text.encode()).hexdigest() - - # ---- 公開 API ---- @@ -92,6 +87,7 @@ def extract_files_touched( def distill_exchange( exchange_id: str, + db_path: Path, user_content: str, agent_content: str, ply_start: int, @@ -99,7 +95,20 @@ def distill_exchange( model: str | None = None, project_root: str | None = None, ) -> PalaceObject: - """1つの exchange を蒸留して PalaceObject を返す""" + """1つの exchange を蒸留して PalaceObject を返す + + Parameters: + exchange_id: exchange の ID + db_path: データベースファイルパス + user_content: ユーザーコンテンツ + agent_content: エージェントコンテンツ + ply_start: 開始ply + ply_end: 終了ply + model: 蒸留に使うモデル(デフォルトはconfig.toml から) + project_root: プロジェクトルート(ファイルパスフィルタ用) + """ + from codeatrium.db import get_connection + messages_text = (user_content + "\n" + agent_content)[:4000] prompt = DISTILL_PROMPT_TEMPLATE.format( ply_start=ply_start, @@ -107,9 +116,35 @@ def distill_exchange( messages_text=messages_text, ) raw = call_claude(prompt, model=model) - files_touched = extract_files_touched( + + # PRIMARY: exchange_files から読み込み + con = get_connection(db_path) + rows = con.execute( + "SELECT file_path FROM exchange_files WHERE exchange_id=?", (exchange_id,) + ).fetchall() + con.close() + primary_paths = [row["file_path"] for row in rows] + + # FALLBACK: regex 抽出 + fallback_paths = extract_files_touched( user_content, agent_content, project_root=project_root ) + + # Merge: primary パスをフィルタして、fallback との重複排除 + root_prefix = (project_root.rstrip("/") + "/") if project_root else None + seen: set[str] = set() + files_touched: list[str] = [] + + for path in primary_paths: + if path not in seen and not _is_external_path(path, root_prefix): + seen.add(path) + files_touched.append(path) + + for path in fallback_paths: + if path not in seen: + seen.add(path) + files_touched.append(path) + return PalaceObject( exchange_core=raw["exchange_core"], specific_context=raw["specific_context"], @@ -123,96 +158,145 @@ def save_palace_object( exchange_id: str, palace: PalaceObject, embedding: Any, # np.ndarray + resolver: SymbolResolver | None = None, + symbol_cache: dict[str, list[Any]] | None = None, ) -> None: """PalaceObject を DB に保存し exchange の distilled_at を更新する""" import numpy as np from codeatrium.db import get_connection + from codeatrium.resolver import SymbolResolver - palace_id = _sha256(f"palace:{exchange_id}") + palace_id = sha256(f"palace:{exchange_id}") distill_text = palace.exchange_core + "\n" + palace.specific_context con = get_connection(db_path) - - con.execute( - """ - INSERT OR IGNORE INTO palace_objects - (id, exchange_id, exchange_core, specific_context, distill_text) - VALUES (?, ?, ?, ?, ?) - """, - ( - palace_id, - exchange_id, - palace.exchange_core, - palace.specific_context, - distill_text, - ), - ) - - for room in palace.room_assignments: - dedup = _sha256(f"{room['room_type']}:{room['room_key']}") - room_id = _sha256(f"{palace_id}:{dedup}") - con.execute( - """ - INSERT OR IGNORE INTO rooms - (id, palace_object_id, room_type, room_key, room_label, relevance, dedup_hash) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, - ( - room_id, - palace_id, - room["room_type"], - room["room_key"], - room["room_label"], - room["relevance"], - dedup, - ), - ) - - arr = embedding.astype(np.float32) - blob = struct.pack(f"{len(arr)}f", *arr.tolist()) - exists = con.execute( - "SELECT 1 FROM vec_palace WHERE palace_id = ?", (palace_id,) - ).fetchone() - if not exists: - con.execute( - "INSERT INTO vec_palace (palace_id, embedding) VALUES (?, ?)", - (palace_id, blob), - ) - - # ⑤ tree-sitter シンボル解決 - from codeatrium.resolver import SymbolResolver - - resolver = SymbolResolver() - for file_str in palace.files_touched: - for sym in resolver.extract(Path(file_str)): - sym_id = _sha256(f"{sym.symbol_name}:{sym.file_path}") + con.execute("BEGIN") + try: + existing = con.execute( + "SELECT 1 FROM palace_objects WHERE id = ?", (palace_id,) + ).fetchone() + if existing is None: con.execute( """ - INSERT OR IGNORE INTO symbols - (id, palace_object_id, symbol_name, symbol_kind, - file_path, signature, line, dedup_hash) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO palace_objects + (id, exchange_id, exchange_core, specific_context, distill_text) + VALUES (?, ?, ?, ?, ?) """, ( - sym_id, palace_id, - sym.symbol_name, - sym.symbol_kind, - sym.file_path, - sym.signature, - sym.line, - sym_id, + exchange_id, + palace.exchange_core, + palace.specific_context, + distill_text, ), ) + verify = con.execute( + "SELECT 1 FROM palace_objects WHERE id = ?", (palace_id,) + ).fetchone() + if verify is None: + raise RuntimeError(f"palace_objects INSERT failed for id={palace_id}") + + for room in palace.room_assignments: + dedup = sha256(f"{room['room_type']}:{room['room_key']}") + room_id = sha256(f"{palace_id}:{dedup}") + room_exists = con.execute( + "SELECT 1 FROM rooms WHERE id = ?", (room_id,) + ).fetchone() + if room_exists is None: + con.execute( + """ + INSERT INTO rooms + (id, palace_object_id, room_type, room_key, room_label, relevance, dedup_hash) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + room_id, + palace_id, + room["room_type"], + room["room_key"], + room["room_label"], + room["relevance"], + dedup, + ), + ) + + blob = embedding.astype(np.float32).tobytes() + exists = con.execute( + "SELECT 1 FROM vec_palace WHERE palace_id = ?", (palace_id,) + ).fetchone() + if not exists: + con.execute( + "INSERT INTO vec_palace (palace_id, embedding) VALUES (?, ?)", + (palace_id, blob), + ) - con.execute( - "UPDATE exchanges SET distilled_at = ? WHERE id = ?", - (datetime.datetime.utcnow().isoformat(), exchange_id), - ) + # ⑤ tree-sitter シンボル解決 + if resolver is None: + from codeatrium.resolver import SymbolResolver + resolver = SymbolResolver() + + # Fetch exchange body text for symbol body-mention filter + ex_row = con.execute( + "SELECT user_content, agent_content FROM exchanges WHERE id = ?", + (exchange_id,), + ).fetchone() + body_text = ( + (ex_row["user_content"] + ex_row["agent_content"]) + if ex_row is not None + else "" + ) - con.commit() - con.close() + for file_str in palace.files_touched: + if symbol_cache is not None and file_str in symbol_cache: + syms = symbol_cache[file_str] + else: + syms = resolver.extract(Path(file_str)) + if symbol_cache is not None: + symbol_cache[file_str] = syms + for sym in syms: + # Body-mention filter: skip symbol if not mentioned in exchange + if sym.symbol_name not in body_text: + continue + + # Compute sym_id with palace_id, dedup_hash separate + sym_id = sha256(f"{sym.symbol_name}:{sym.file_path}:{palace_id}") + dedup_hash = sha256(f"{sym.symbol_name}:{sym.file_path}") + + sym_exists = con.execute( + "SELECT 1 FROM symbols WHERE id = ?", (sym_id,) + ).fetchone() + if sym_exists is None: + con.execute( + """ + INSERT INTO symbols + (id, palace_object_id, symbol_name, symbol_kind, + file_path, signature, line, dedup_hash) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + sym_id, + palace_id, + sym.symbol_name, + sym.symbol_kind, + sym.file_path, + sym.signature, + sym.line, + dedup_hash, + ), + ) + + con.execute( + "UPDATE exchanges SET distilled_at = ?, distill_status = 'distilled' WHERE id = ?", + (datetime.datetime.now(datetime.UTC).isoformat(), exchange_id), + ) + + con.execute("COMMIT") + except Exception: + con.execute("ROLLBACK") + raise + finally: + con.close() def distill_all( @@ -222,14 +306,15 @@ def distill_all( on_progress: Callable[..., None] | None = None, project_root: str | None = None, distill_min_chars: int = 100, -) -> int: +) -> tuple[int, int]: """未蒸留の exchange を処理する。 distill_min_chars: この文字数未満の exchange は蒸留スキップ(デフォルト100) on_progress: (current, total, error=None) を受け取るコールバック - Returns: 処理した exchange 数 + Returns: (処理した exchange 数, エラー数) """ from codeatrium.db import get_connection + from codeatrium.resolver import SymbolResolver con = get_connection(db_path) @@ -237,7 +322,7 @@ def distill_all( # - 1-exchange セッション # - distill_min_chars 未満(ワンフレーズ指示・システムメッセージ等) con.execute(""" - UPDATE exchanges SET distilled_at = 'skipped' + UPDATE exchanges SET distilled_at = 'skipped', distill_status = 'skipped' WHERE distilled_at IS NULL AND ((SELECT COUNT(*) FROM exchanges e2 WHERE e2.conversation_id = exchanges.conversation_id) < 2 @@ -248,7 +333,7 @@ def distill_all( query = """ SELECT e.id, e.user_content, e.agent_content, e.ply_start, e.ply_end FROM exchanges e - WHERE e.distilled_at IS NULL + WHERE e.distill_status = 'pending' """ params: list[int] = [] if limit is not None: @@ -258,16 +343,19 @@ def distill_all( con.close() if not rows: - return 0 + return 0, 0 total = len(rows) embedder = Embedder() + resolver = SymbolResolver() + symbol_cache: dict[str, list[Any]] = {} count = 0 errors = 0 for row in rows: try: palace = distill_exchange( row["id"], + db_path, row["user_content"], row["agent_content"], row["ply_start"], @@ -277,7 +365,7 @@ def distill_all( ) distill_text = palace.exchange_core + "\n" + palace.specific_context vec = embedder.embed_passage(distill_text) - save_palace_object(db_path, row["id"], palace, vec) + save_palace_object(db_path, row["id"], palace, vec, resolver=resolver, symbol_cache=symbol_cache) count += 1 except EmbedderSetupError: # 環境レベルの失敗: per-row でなくループ全体を中断する @@ -290,4 +378,4 @@ def distill_all( if on_progress is not None: on_progress(count, total) - return count + return count, errors diff --git a/src/codeatrium/embedder.py b/src/codeatrium/embedder.py index 16488bc..af78388 100644 --- a/src/codeatrium/embedder.py +++ b/src/codeatrium/embedder.py @@ -157,7 +157,10 @@ def _embed_via_socket_or_direct( # ② 直接ロード self._ensure_model() - assert self._model is not None + if self._model is None: + raise RuntimeError( + "モデルがロードされていません: _ensure_model() の呼び出しに失敗した可能性があります" + ) result = self._model.encode( [f"{prefix}{text}"], normalize_embeddings=True, diff --git a/src/codeatrium/hooks.py b/src/codeatrium/hooks.py index 87d201e..a10fd6b 100644 --- a/src/codeatrium/hooks.py +++ b/src/codeatrium/hooks.py @@ -2,8 +2,12 @@ from __future__ import annotations +import copy import json +import os import shlex +import shutil +import tempfile from pathlib import Path from typing import Any, cast @@ -11,6 +15,33 @@ from codeatrium.paths import loci_bin +def _write_settings(settings_path: Path, settings: dict[str, Any]) -> None: + """settings.json をアトミックに書き込む。書き込み前に .bak を作成し os.replace で差し替える。""" + settings_path.parent.mkdir(parents=True, exist_ok=True) + if settings_path.exists(): + shutil.copy2(settings_path, settings_path.with_suffix(".json.bak")) + + tmp_file = tempfile.NamedTemporaryFile( + mode="w", + dir=settings_path.parent, + delete=False, + suffix=".tmp", + encoding="utf-8", + ) + tmp_path = tmp_file.name + try: + json.dump(settings, tmp_file, ensure_ascii=False, indent=2) + tmp_file.flush() + os.fsync(tmp_file.fileno()) + tmp_file.close() + os.replace(tmp_path, settings_path) + except Exception: + tmp_file.close() + if os.path.exists(tmp_path): + os.unlink(tmp_path) + raise + + def install_hooks(batch_limit: int = DEFAULT_DISTILL_BATCH_LIMIT) -> tuple[bool, str]: """Claude Code の Stop / SessionStart フックに loci を登録する。 @@ -19,7 +50,7 @@ def install_hooks(batch_limit: int = DEFAULT_DISTILL_BATCH_LIMIT) -> tuple[bool, settings_path = Path.home() / ".claude" / "settings.json" if settings_path.exists(): - with settings_path.open() as f: + with settings_path.open(encoding="utf-8") as f: settings: dict[str, Any] = json.load(f) else: settings = {} @@ -141,9 +172,7 @@ def install_hooks(batch_limit: int = DEFAULT_DISTILL_BATCH_LIMIT) -> tuple[bool, if not changed: return False, "Hooks already up to date." - settings_path.parent.mkdir(parents=True, exist_ok=True) - with settings_path.open("w") as f: - json.dump(settings, f, ensure_ascii=False, indent=2) + _write_settings(settings_path, settings) lines = [ f"Hooks installed: {settings_path}", @@ -154,3 +183,50 @@ def install_hooks(batch_limit: int = DEFAULT_DISTILL_BATCH_LIMIT) -> tuple[bool, " (matcher: startup|clear|resume|compact)", ] return True, "\n".join(lines) + + +def uninstall_hooks() -> tuple[bool, str]: + """codeatrium に関連する Claude Code フックを削除する。 + + Returns: (changed, message) — 変更の有無と結果メッセージ + """ + settings_path = Path.home() / ".claude" / "settings.json" + + if not settings_path.exists(): + return False, "No settings.json found. Nothing to uninstall." + + with settings_path.open(encoding="utf-8") as f: + settings: dict[str, Any] = json.load(f) + + settings_before = copy.deepcopy(settings) + + hooks = settings.get("hooks", {}) + if not hooks: + return False, "No hooks section found. Nothing to uninstall." + + def _is_loci(h: dict[str, Any]) -> bool: + cmd = h.get("command", "") + return "loci" in cmd and any( + kw in cmd for kw in ("index", "server", "distill", "prime") + ) + + for section in ("Stop", "SessionStart", "SessionEnd"): + if section not in hooks: + continue + entries: list[dict[str, Any]] = hooks[section] + for entry in entries[:]: + hooks_list = entry.get("hooks", []) + entry["hooks"] = [h for h in hooks_list if not _is_loci(h)] + hooks[section] = [e for e in entries if e.get("hooks")] + if not hooks[section]: + del hooks[section] + + if not hooks: + if "hooks" in settings: + del settings["hooks"] + + if settings == settings_before: + return False, "No codeatrium hooks found. Nothing to uninstall." + + _write_settings(settings_path, settings) + return True, f"Hooks uninstalled: {settings_path}" diff --git a/src/codeatrium/indexer.py b/src/codeatrium/indexer.py index 0bab94f..75f3d30 100644 --- a/src/codeatrium/indexer.py +++ b/src/codeatrium/indexer.py @@ -12,13 +12,14 @@ from __future__ import annotations -import hashlib import json -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import UTC, datetime from pathlib import Path from typing import Any +from codeatrium.utils import sha256 + @dataclass class Exchange: @@ -30,13 +31,86 @@ class Exchange: ply_end: int user_content: str agent_content: str + files: list[str] = field(default_factory=list) # ---- 内部ヘルパー ---- +# 外部パス(サイトパッケージ等)の判定用マーカー +_EXTERNAL_PATH_MARKERS = ( + 'site-packages/', + 'dist-packages/', + '/lib/python', + '/opt/', + '/usr/lib/', + '/usr/local/lib/', + '.venv/', + '/venv/', + 'node_modules/', +) + + +def _is_external_path_indexer(path: str) -> bool: + """パスが外部ライブラリ(site-packages など)を指しているか判定する""" + return any(marker in path for marker in _EXTERNAL_PATH_MARKERS) + -def _sha256(text: str) -> str: - return hashlib.sha256(text.encode()).hexdigest() +def _extract_tool_use_files(entries: list[dict | None]) -> list[str]: + """ + assistant エントリから tool_use ブロックの file_path を抽出する。 + 外部パスは除外し、重複をトリムしたリストを返す(順序保持)。 + """ + seen: set[str] = set() + result: list[str] = [] + + for entry in entries: + if entry is None: + continue + if entry.get("type") != "assistant": + continue + + msg = entry.get("message") + if not isinstance(msg, dict): + continue + + content = msg.get("content") + if not isinstance(content, list): + continue + + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") != "tool_use": + continue + + name = block.get("name") + if name not in {'Edit', 'Write', 'Read', 'MultiEdit', 'NotebookEdit'}: + continue + + input_dict = block.get("input") + if not isinstance(input_dict, dict): + continue + + # NotebookEdit の場合は notebook_path、その他は file_path + path = None + if name == 'NotebookEdit': + path = input_dict.get("notebook_path") + else: + path = input_dict.get("file_path") + + if not path or not isinstance(path, str): + continue + + # 外部パスはスキップ + if _is_external_path_indexer(path): + continue + + # 重複排除(順序保持) + if path not in seen: + seen.add(path) + result.append(path) + + return result def _extract_text(content: Any) -> str: @@ -104,44 +178,62 @@ def _is_real_user_entry(entry: dict) -> bool: # ---- 公開API ---- -def parse_exchanges(jsonl_path: Path, min_chars: int = 50) -> list[Exchange]: +def parse_exchanges(jsonl_path: Path, min_chars: int = 50, last_ply_end: int = -1) -> list[Exchange]: """ .jsonl ファイルを読んで exchange リストを返す。 trivial(min_chars 文字未満)は除外する。 + last_ply_end: ply インデックス以前の行はスキップ(デフォルト -1 = 全行パース、既にインデックスされた部分の再処理を避けるため)。 """ - entries: list[dict] = [] + raw_entries: list[dict | None] = [] if not jsonl_path.exists(): return [] + ply = 0 with jsonl_path.open(encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue - try: - entries.append(json.loads(line)) - except json.JSONDecodeError: - continue - - conversation_id = _sha256(str(jsonl_path)) + if ply <= last_ply_end: + # 既インデックス領域: 古い exchange は再構築しない(None プレースホルダ)。 + # ただし malformed 行は位置に数えない — last_ply_end は成功パース行のみを + # 数えた座標系なので、検証パースして同じ座標系を維持する。 + try: + json.loads(line) + except json.JSONDecodeError: + continue + raw_entries.append(None) + ply += 1 + else: + try: + raw_entries.append(json.loads(line)) + ply += 1 + except json.JSONDecodeError: + continue + + conversation_id = sha256(str(jsonl_path)) # exchange の境界インデックスを収集 - boundaries: list[int] = [i for i, e in enumerate(entries) if _is_real_user_entry(e)] + boundaries: list[int] = [i for i, e in enumerate(raw_entries) if e is not None and _is_real_user_entry(e)] exchanges: list[Exchange] = [] for b_idx, start in enumerate(boundaries): end = ( boundaries[b_idx + 1] - 1 if b_idx + 1 < len(boundaries) - else len(entries) - 1 + else len(raw_entries) - 1 ) - user_entry = entries[start] + user_entry = raw_entries[start] + if user_entry is None: + continue user_text = _extract_text(user_entry["message"]["content"]) # assistant の発話を連結(コンパクション要約ゾーンは除外) agent_parts: list[str] = [] in_compaction_zone = False - for e in entries[start + 1 : end + 1]: + for e in raw_entries[start + 1 : end + 1]: + if e is None: + continue if e.get("type") == "user": msg = e.get("message", {}) if isinstance(msg, dict): @@ -163,7 +255,10 @@ def parse_exchanges(jsonl_path: Path, min_chars: int = 50) -> list[Exchange]: continue user_uuid = user_entry.get("uuid", f"{start}") - exchange_id = _sha256(f"{conversation_id}:{user_uuid}") + exchange_id = sha256(f"{conversation_id}:{user_uuid}") + + # tool_use から file パスを抽出 + tool_files = _extract_tool_use_files(raw_entries[start : end + 1]) exchanges.append( Exchange( @@ -173,6 +268,7 @@ def parse_exchanges(jsonl_path: Path, min_chars: int = 50) -> list[Exchange]: ply_end=end, user_content=user_text, agent_content=agent_text, + files=tool_files, ) ) @@ -187,7 +283,7 @@ def index_file(jsonl_path: Path, db_path: Path, min_chars: int = 50) -> int: """ from codeatrium.db import get_connection - conversation_id = _sha256(str(jsonl_path)) + conversation_id = sha256(str(jsonl_path)) con = get_connection(db_path) # 既存 conversation の last_ply_end を取得 @@ -196,7 +292,7 @@ def index_file(jsonl_path: Path, db_path: Path, min_chars: int = 50) -> int: ).fetchone() last_ply_end = row["last_ply_end"] if row is not None else -1 - exchanges = parse_exchanges(jsonl_path, min_chars=min_chars) + exchanges = parse_exchanges(jsonl_path, min_chars=min_chars, last_ply_end=last_ply_end) new_exchanges = [ex for ex in exchanges if ex.ply_start > last_ply_end] if not new_exchanges: @@ -234,6 +330,14 @@ def index_file(jsonl_path: Path, db_path: Path, min_chars: int = 50) -> int: ), ) + # exchange_files を登録 + for ex in new_exchanges: + for file_path in ex.files: + con.execute( + "INSERT OR IGNORE INTO exchange_files (exchange_id, file_path) VALUES (?, ?)", + (ex.id, file_path), + ) + con.commit() con.close() return len(new_exchanges) diff --git a/src/codeatrium/llm.py b/src/codeatrium/llm.py index d27a8d0..468252a 100644 --- a/src/codeatrium/llm.py +++ b/src/codeatrium/llm.py @@ -2,6 +2,7 @@ from __future__ import annotations +import hashlib import json import subprocess from pathlib import Path @@ -31,6 +32,9 @@ JSONのみで回答してください。""" +# プロンプト sha256 先頭8桁(B5: drift 検出用) +DISTILL_PROMPT_VERSION = hashlib.sha256(DISTILL_PROMPT_TEMPLATE.encode()).hexdigest()[:8] + JSON_SCHEMA = json.dumps( { "type": "object", diff --git a/src/codeatrium/search.py b/src/codeatrium/search.py index 29cf3f1..51d23ff 100644 --- a/src/codeatrium/search.py +++ b/src/codeatrium/search.py @@ -16,6 +16,7 @@ import sqlite3 import struct +from contextlib import closing from pathlib import Path from typing import Any @@ -116,29 +117,28 @@ def search_bm25( db_path: Path, query_text: str, limit: int = 10, min_exchanges: int = 2 ) -> list[BM25Result]: """FTS5 BM25 で exchanges_fts を検索する""" - con = get_connection(db_path) fts_query = _fts5_query(query_text) - try: - rows = con.execute( - """ - SELECT - e.id AS exchange_id, - e.user_content, - e.agent_content, - -bm25(exchanges_fts) AS score - FROM exchanges_fts - JOIN exchanges e ON e.rowid = exchanges_fts.rowid - WHERE exchanges_fts MATCH ? - AND (SELECT COUNT(*) FROM exchanges e2 - WHERE e2.conversation_id = e.conversation_id) >= ? - ORDER BY score DESC - LIMIT ? - """, - (fts_query, min_exchanges, limit), - ).fetchall() - except sqlite3.OperationalError: - rows = [] - con.close() + with closing(get_connection(db_path)) as con: + try: + rows = con.execute( + """ + SELECT + e.id AS exchange_id, + e.user_content, + e.agent_content, + -bm25(exchanges_fts) AS score + FROM exchanges_fts + JOIN exchanges e ON e.rowid = exchanges_fts.rowid + WHERE exchanges_fts MATCH ? + AND (SELECT COUNT(*) FROM exchanges e2 + WHERE e2.conversation_id = e.conversation_id) >= ? + ORDER BY score DESC + LIMIT ? + """, + (fts_query, min_exchanges, limit), + ).fetchall() + except sqlite3.OperationalError: + rows = [] return [ BM25Result( exchange_id=row["exchange_id"], @@ -157,37 +157,36 @@ def search_hnsw_palace( db_path: Path, query_vec: np.ndarray, limit: int = 10, min_exchanges: int = 2 ) -> list[HNSWPalaceResult]: """sqlite-vec HNSW で vec_palace を検索する(distilled embedding)""" - con = get_connection(db_path) - blob = _serialize(query_vec) - - try: - rows = con.execute( - """ - SELECT - p.exchange_id, - e.user_content, - e.agent_content, - p.exchange_core, - p.specific_context, - v.distance - FROM ( - SELECT palace_id, distance - FROM vec_palace - WHERE embedding MATCH ? - AND k = ? - ) v - JOIN palace_objects p ON p.id = v.palace_id - JOIN exchanges e ON e.id = p.exchange_id - WHERE (SELECT COUNT(*) FROM exchanges e2 - WHERE e2.conversation_id = e.conversation_id) >= ? - ORDER BY v.distance - """, - (blob, limit, min_exchanges), - ).fetchall() - except sqlite3.OperationalError: - rows = [] - - con.close() + + with closing(get_connection(db_path)) as con: + blob = _serialize(query_vec) + try: + rows = con.execute( + """ + SELECT + p.exchange_id, + e.user_content, + e.agent_content, + p.exchange_core, + p.specific_context, + v.distance + FROM ( + SELECT palace_id, distance + FROM vec_palace + WHERE embedding MATCH ? + AND k = ? + ) v + JOIN palace_objects p ON p.id = v.palace_id + JOIN exchanges e ON e.id = p.exchange_id + WHERE (SELECT COUNT(*) FROM exchanges e2 + WHERE e2.conversation_id = e.conversation_id) >= ? + ORDER BY v.distance + """, + (blob, limit, min_exchanges), + ).fetchall() + except sqlite3.OperationalError: + rows = [] + return [ HNSWPalaceResult( exchange_id=row["exchange_id"], @@ -266,8 +265,7 @@ def search_combined( fused = rrf(bm25_results, hnsw_results, limit=limit) if fused: - con = get_connection(db_path) - _enrich_results(con, fused) - con.close() + with closing(get_connection(db_path)) as con: + _enrich_results(con, fused) return fused diff --git a/src/codeatrium/utils.py b/src/codeatrium/utils.py new file mode 100644 index 0000000..4df08c5 --- /dev/null +++ b/src/codeatrium/utils.py @@ -0,0 +1,8 @@ +"""共通ユーティリティ: プロジェクト横断で再利用される小さなヘルパー関数""" + +import hashlib + + +def sha256(text: str) -> str: + """テキストの SHA-256 ハッシュ(hex 文字列)を返す""" + return hashlib.sha256(text.encode()).hexdigest() diff --git a/tests/test_config.py b/tests/test_config.py index ac68da3..9550bfe 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,6 +3,7 @@ from __future__ import annotations from pathlib import Path +from unittest.mock import patch from codeatrium.config import ( DEFAULT_DISTILL_BATCH_LIMIT, @@ -115,3 +116,26 @@ def test_load_config_distill_min_chars_invalid_fallback(tmp_path: Path) -> None: (codeatrium_dir / "config.toml").write_text("[distill]\nmin_chars = -5\n") cfg = load_config(tmp_path) assert cfg.distill_min_chars == DEFAULT_DISTILL_MIN_CHARS + + +def test_load_config_toml_decode_error_fallback(tmp_path: Path, capsys) -> None: + """不正な TOML 内容(TOMLDecodeError)はデフォルトにフォールバック、警告出力""" + codeatrium_dir = tmp_path / ".codeatrium" + codeatrium_dir.mkdir() + (codeatrium_dir / "config.toml").write_text("not valid toml [") + cfg = load_config(tmp_path) + assert cfg == Config() + captured = capsys.readouterr() + assert "Warning" in captured.err + + +def test_load_config_oserror_fallback(tmp_path: Path) -> None: + """OSError(ファイルアクセスエラー)はデフォルトにフォールバック""" + codeatrium_dir = tmp_path / ".codeatrium" + codeatrium_dir.mkdir() + config_file = codeatrium_dir / "config.toml" + config_file.write_text("[distill]\nmodel = 'test'\n") + # exists() は True のまま、open() だけ OSError を引き起こす + with patch("pathlib.Path.open", side_effect=OSError("disk error")): + cfg = load_config(tmp_path) + assert cfg == Config() diff --git a/tests/test_db.py b/tests/test_db.py index 53a5cd3..f153b68 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -2,10 +2,11 @@ DB 初期化・スキーマのテスト """ +import hashlib import sqlite3 from pathlib import Path -from codeatrium.db import get_connection, init_db +from codeatrium.db import _MIGRATIONS, check_drift, get_connection, init_db def test_init_db_creates_conversations_table(tmp_path: Path) -> None: @@ -67,3 +68,872 @@ def test_init_db_creates_vec_table(tmp_path: Path) -> None: cur = con.execute("SELECT name FROM sqlite_master WHERE name='vec_exchanges'") assert cur.fetchone() is not None con.close() + + +def test_init_db_stamps_user_version_on_new_db(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + con = sqlite3.connect(db_path) + user_version = con.execute("PRAGMA user_version").fetchone()[0] + assert user_version == len(_MIGRATIONS) + con.close() + + +def test_init_db_migration_adds_last_ply_end_to_legacy_db(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with legacy schema (no last_ply_end) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('row1', '/a')") + raw_con.execute("PRAGMA user_version = 0") + raw_con.commit() + raw_con.close() + + # Now call init_db which should run migrations + init_db(db_path) + + # Verify migration was applied + con = sqlite3.connect(db_path) + table_info = con.execute("PRAGMA table_info(conversations)").fetchall() + column_names = [col[1] for col in table_info] + assert "last_ply_end" in column_names + + # Verify inserted row still exists + row = con.execute( + "SELECT id, source_path FROM conversations WHERE id='row1'" + ).fetchone() + assert row is not None + assert row[0] == "row1" + assert row[1] == "/a" + + # Verify user_version was stamped + user_version = con.execute("PRAGMA user_version").fetchone()[0] + assert user_version == len(_MIGRATIONS) + con.close() + + +def test_init_db_migration_idempotent(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + init_db(db_path) + + con = sqlite3.connect(db_path) + user_version = con.execute("PRAGMA user_version").fetchone()[0] + assert user_version == len(_MIGRATIONS) + con.close() + + +def test_get_connection_journal_mode_is_wal(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + con = get_connection(db_path) + journal_mode = con.execute("PRAGMA journal_mode").fetchone()[0] + assert journal_mode.lower() == "wal" + con.close() + + +def test_concurrent_writes_do_not_raise_locked(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + con1 = get_connection(db_path) + con2 = get_connection(db_path) + + con1.execute("INSERT INTO conversations(id, source_path) VALUES ('id1', '/path1')") + con1.commit() + + con2.execute("INSERT INTO conversations(id, source_path) VALUES ('id2', '/path2')") + con2.commit() + + con1.close() + con2.close() + + +def test_init_db_new_db_has_distill_status_column(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + con = sqlite3.connect(db_path) + table_info = con.execute("PRAGMA table_info(exchanges)").fetchall() + column_names = [col[1] for col in table_info] + assert "distill_status" in column_names + con.close() + + +def test_init_db_new_db_has_meta_table(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + con = sqlite3.connect(db_path) + cur = con.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='meta'" + ) + assert cur.fetchone() is not None + con.close() + + +def test_init_db_new_db_has_indexes(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + con = sqlite3.connect(db_path) + indexes = con.execute( + "SELECT name FROM sqlite_master WHERE type='index' AND name IN ('idx_rooms_palace_object_id', 'idx_symbols_palace_object_id', 'idx_palace_objects_exchange_id')" + ).fetchall() + index_names = [idx[0] for idx in indexes] + assert "idx_rooms_palace_object_id" in index_names + assert "idx_symbols_palace_object_id" in index_names + assert "idx_palace_objects_exchange_id" in index_names + con.close() + + +def test_init_db_new_db_meta_has_embedding_model(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + con = get_connection(db_path) + row = con.execute("SELECT value FROM meta WHERE key='embedding_model'").fetchone() + assert row is not None + assert row[0] and len(row[0]) > 0 + con.close() + + +def test_init_db_new_db_meta_has_prompt_version(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + con = get_connection(db_path) + row = con.execute("SELECT value FROM meta WHERE key='prompt_version'").fetchone() + assert row is not None + assert len(row[0]) == 8 + con.close() + + +def test_migration_v2_converts_skipped(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with legacy schema (user_version=1, pre-v2) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP + )""" + ) + raw_con.execute( + "INSERT INTO exchanges VALUES ('ex1', 'conv1', 0, 1, 'user1', 'agent1', 'skipped')" + ) + raw_con.execute( + "INSERT INTO exchanges VALUES ('ex2', 'conv1', 1, 2, 'user2', 'agent2', '2026-01-01T00:00:00')" + ) + raw_con.execute("PRAGMA user_version = 1") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v2 migration + init_db(db_path) + + # Verify v2 migration: 'skipped' becomes distill_status='skipped' with NULL distilled_at + con = sqlite3.connect(db_path) + row1 = con.execute( + "SELECT distill_status, distilled_at FROM exchanges WHERE id='ex1'" + ).fetchone() + assert row1[0] == "skipped" + assert row1[1] is None + + # Verify timestamp row: distill_status='distilled' with distilled_at preserved + row2 = con.execute( + "SELECT distill_status, distilled_at FROM exchanges WHERE id='ex2'" + ).fetchone() + assert row2[0] == "distilled" + assert row2[1] == "2026-01-01T00:00:00" + con.close() + + +def test_migration_v3_creates_meta(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with user_version=2 (post-v2, pre-v3) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + raw_con.execute("PRAGMA user_version = 2") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v3 migration + init_db(db_path) + + # Verify meta table exists with embedding_model and prompt_version + con = sqlite3.connect(db_path) + meta_exists = con.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='meta'" + ).fetchone() + assert meta_exists is not None + + embedding_model = con.execute( + "SELECT value FROM meta WHERE key='embedding_model'" + ).fetchone() + assert embedding_model is not None + + prompt_version = con.execute( + "SELECT value FROM meta WHERE key='prompt_version'" + ).fetchone() + assert prompt_version is not None + con.close() + + +def test_migration_v4_creates_indexes(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with user_version=3 (post-v3, pre-v4) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute("PRAGMA user_version = 3") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v4 migration + init_db(db_path) + + # Verify indexes exist + con = sqlite3.connect(db_path) + indexes = con.execute( + "SELECT name FROM sqlite_master WHERE type='index' AND name IN ('idx_rooms_palace_object_id', 'idx_symbols_palace_object_id', 'idx_palace_objects_exchange_id')" + ).fetchall() + index_names = [idx[0] for idx in indexes] + assert "idx_rooms_palace_object_id" in index_names + assert "idx_symbols_palace_object_id" in index_names + assert "idx_palace_objects_exchange_id" in index_names + con.close() + + +def test_init_db_idempotent_user_version_4(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + init_db(db_path) + + con = sqlite3.connect(db_path) + user_version = con.execute("PRAGMA user_version").fetchone()[0] + assert user_version == len(_MIGRATIONS) + con.close() + + +def test_check_drift_no_drift(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + drifts = check_drift(db_path) + assert drifts == [] + + +def test_check_drift_detects_mismatch(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + # Modify meta to introduce drift + con = get_connection(db_path) + con.execute("UPDATE meta SET value='old_value' WHERE key='prompt_version'") + con.commit() + con.close() + + # Check drift should detect the mismatch + drifts = check_drift(db_path) + assert len(drifts) > 0 + drift_keys = [d[0] for d in drifts] + assert "prompt_version" in drift_keys + + +def test_check_drift_absent_meta_returns_empty(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a legacy DB without meta table (pre-v3) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP + )""" + ) + raw_con.execute("PRAGMA user_version = 1") + raw_con.commit() + raw_con.close() + + # check_drift should return [] without raising an exception + drifts = check_drift(db_path) + assert drifts == [] + + +def test_init_db_chmod_600(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + + mode_str = oct(db_path.stat().st_mode)[-3:] + assert mode_str == "600" + + +def test_migration_v5_creates_exchange_files(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with user_version=4 (post-v4, pre-v5) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute("PRAGMA user_version = 4") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v5 migration + init_db(db_path) + + # Verify exchange_files table exists + con = sqlite3.connect(db_path) + exchange_files_exists = con.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='exchange_files'" + ).fetchone() + assert exchange_files_exists is not None + + # Verify exchange_files columns + table_info = con.execute("PRAGMA table_info(exchange_files)").fetchall() + column_names = [col[1] for col in table_info] + assert "exchange_id" in column_names + assert "file_path" in column_names + con.close() + + +def test_migration_v6_recomputes_symbol_ids(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with user_version=5 (so only v6 and v7 run) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + + # Insert palace_objects row so v7 doesn't delete the symbol as orphan + raw_con.execute( + "INSERT INTO palace_objects VALUES ('po1', 'ex1', 'c', 's', 'c\ns')" + ) + + # Insert symbol with OLD id formula: sha256("Sym:file.py") + old_id = hashlib.sha256(b"Sym:file.py").hexdigest() + raw_con.execute( + "INSERT INTO symbols VALUES (?, 'po1', 'Sym', 'function', 'file.py', 'def Sym', 1, ?)", + (old_id, old_id), + ) + + raw_con.execute("PRAGMA user_version = 5") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v6 and v7 migrations + init_db(db_path) + + # Verify symbol id was recomputed with NEW formula: sha256("Sym:file.py:po1") + con = sqlite3.connect(db_path) + expected_new_id = hashlib.sha256(b"Sym:file.py:po1").hexdigest() + row = con.execute( + "SELECT id FROM symbols WHERE symbol_name='Sym'" + ).fetchone() + assert row is not None + assert row[0] == expected_new_id + con.close() + + +def test_migration_v7_resets_orphan_distilled(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with user_version=6 (so only v7 runs) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + + # Insert exchanges row with no palace_objects referencing it + raw_con.execute( + "INSERT INTO exchanges VALUES ('exX', 'conv1', 0, 1, 'user', 'agent', '2026-01-01', 'distilled')" + ) + + raw_con.execute("PRAGMA user_version = 6") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v7 migration + init_db(db_path) + + # Verify exchanges row distill_status and distilled_at were reset + con = sqlite3.connect(db_path) + row = con.execute( + "SELECT distill_status, distilled_at FROM exchanges WHERE id='exX'" + ).fetchone() + assert row is not None + assert row[0] == "pending" + assert row[1] is None + con.close() + + +def test_migration_v7_removes_orphan_symbols(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with user_version=6 (so only v7 runs) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + + # Insert symbol with palace_object_id="ghost" but NO palace_objects row with that id + ghost_id = hashlib.sha256(b"Ghost:test.py").hexdigest() + raw_con.execute( + "INSERT INTO symbols VALUES (?, 'ghost', 'Ghost', 'function', 'test.py', 'def Ghost', 1, ?)", + (ghost_id, ghost_id), + ) + + raw_con.execute("PRAGMA user_version = 6") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v7 migration + init_db(db_path) + + # Verify orphan symbol was deleted + con = sqlite3.connect(db_path) + count = con.execute("SELECT COUNT(*) FROM symbols").fetchone()[0] + assert count == 0 + con.close() + + +def test_migration_v7_removes_bm25_text_column(tmp_path: Path) -> None: + db_path = tmp_path / "memory.db" + + # Create a raw sqlite3 DB with user_version=6 (so only v7 runs) + raw_con = sqlite3.connect(db_path) + raw_con.execute( + """CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + source_path TEXT NOT NULL UNIQUE, + started_at TIMESTAMP, + last_ply_end INT NOT NULL DEFAULT -1 + )""" + ) + raw_con.execute("INSERT INTO conversations(id, source_path) VALUES ('conv1', '/src')") + raw_con.execute( + """CREATE TABLE exchanges ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL, + ply_start INT NOT NULL, + ply_end INT NOT NULL, + user_content TEXT NOT NULL, + agent_content TEXT NOT NULL, + distilled_at TIMESTAMP, + distill_status TEXT NOT NULL DEFAULT 'pending' + )""" + ) + raw_con.execute("CREATE TABLE meta (key TEXT PRIMARY KEY, value TEXT)") + raw_con.execute( + "INSERT INTO meta VALUES ('embedding_model', 'test-model')" + ) + raw_con.execute( + "INSERT INTO meta VALUES ('prompt_version', 'v0000001')" + ) + raw_con.execute( + """CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL, + bm25_text TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE rooms ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + room_type TEXT NOT NULL, + room_key TEXT NOT NULL, + room_label TEXT NOT NULL, + relevance REAL NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + raw_con.execute( + """CREATE TABLE symbols ( + id TEXT PRIMARY KEY, + palace_object_id TEXT NOT NULL, + symbol_name TEXT NOT NULL, + symbol_kind TEXT NOT NULL, + file_path TEXT NOT NULL, + signature TEXT NOT NULL, + line INT NOT NULL, + dedup_hash TEXT NOT NULL + )""" + ) + + # Insert palace_objects row with bm25_text + raw_con.execute( + "INSERT INTO palace_objects VALUES ('po1', 'ex1', 'c', 's', 'c\ns', 'legacy')" + ) + + raw_con.execute("PRAGMA user_version = 6") + raw_con.commit() + raw_con.close() + + # Run init_db which should run v7 migration + init_db(db_path) + + # Verify bm25_text column was removed + con = sqlite3.connect(db_path) + table_info = con.execute("PRAGMA table_info(palace_objects)").fetchall() + column_names = [col[1] for col in table_info] + assert "bm25_text" not in column_names + + # Verify palace_objects row still exists + row = con.execute( + "SELECT id FROM palace_objects WHERE id='po1'" + ).fetchone() + assert row is not None + con.close() diff --git a/tests/test_distiller.py b/tests/test_distiller.py index 80be21b..a55c927 100644 --- a/tests/test_distiller.py +++ b/tests/test_distiller.py @@ -4,9 +4,11 @@ call_claude・Embedder はモックしてモデルロードを避ける """ +import hashlib from unittest.mock import MagicMock, patch import numpy as np +import pytest from codeatrium.db import get_connection, init_db from codeatrium.distiller import ( @@ -45,10 +47,10 @@ def _make_exchange(db_path, ex_id, user_text=LONG_TEXT, agent_text=LONG_TEXT): con.execute( """ INSERT OR IGNORE INTO exchanges - (id, conversation_id, ply_start, ply_end, user_content, agent_content, distilled_at) - VALUES (?,?,?,?,?,?,?) + (id, conversation_id, ply_start, ply_end, user_content, agent_content, distilled_at, distill_status) + VALUES (?,?,?,?,?,?,?,?) """, - ("_pad_conv1", "conv1", 0, 1, "padding", "padding", "2026-01-01"), + ("_pad_conv1", "conv1", 0, 1, "padding", "padding", "2026-01-01", "distilled"), ) con.execute( """ @@ -162,28 +164,84 @@ def test_extract_files_relative_paths_unaffected_by_root() -> None: @patch("codeatrium.distiller.call_claude", return_value=MOCK_PALACE_RESPONSE) -def test_distill_exchange_returns_palace(mock_call) -> None: - palace = distill_exchange("ex1", "pool の設定", "pool_size=5 を追加した", 0, 3) +def test_distill_exchange_returns_palace(mock_call, tmp_path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + palace = distill_exchange("ex1", db_path, "pool の設定", "pool_size=5 を追加した", 0, 3) assert palace.exchange_core == "pool_size を 5 に設定した" assert palace.specific_context == "pool_size=5" assert len(palace.room_assignments) == 1 @patch("codeatrium.distiller.call_claude", return_value=MOCK_PALACE_RESPONSE) -def test_distill_exchange_calls_claude_once(mock_call) -> None: - distill_exchange("ex1", "pool の設定", "pool_size=5", 0, 3) +def test_distill_exchange_calls_claude_once(mock_call, tmp_path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + distill_exchange("ex1", db_path, "pool の設定", "pool_size=5", 0, 3) mock_call.assert_called_once() @patch("codeatrium.distiller.call_claude", return_value=MOCK_PALACE_RESPONSE) -def test_distill_exchange_extracts_files(mock_call) -> None: - palace = distill_exchange("ex1", "src/db/pool.py を修正", "pool_size=5", 0, 3) +def test_distill_exchange_extracts_files(mock_call, tmp_path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + palace = distill_exchange("ex1", db_path, "src/db/pool.py を修正", "pool_size=5", 0, 3) assert "src/db/pool.py" in palace.files_touched +@patch("codeatrium.distiller.call_claude", return_value=MOCK_PALACE_RESPONSE) +def test_distill_exchange_merges_exchange_files(mock_call, tmp_path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1") + con = get_connection(db_path) + con.execute( + "INSERT INTO exchange_files (exchange_id, file_path) VALUES (?,?)", + ("ex1", "src/tool.py"), + ) + con.commit() + con.close() + palace = distill_exchange("ex1", db_path, "no paths here", "none", 0, 3) + assert "src/tool.py" in palace.files_touched + + # --- save_palace_object --- +def test_save_palace_object_symbol_id_uses_3part_hash(tmp_path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1", user_text="Foo.bar method", agent_text="Foo.bar implementation") + + resolver = MagicMock() + sym = MagicMock() + sym.symbol_name = "Foo.bar" + sym.file_path = "src/foo.py" + sym.symbol_kind = "method" + sym.signature = "def bar" + sym.line = 1 + resolver.extract.return_value = [sym] + + palace = PalaceObject( + exchange_core="c", + specific_context="s", + room_assignments=[], + files_touched=["src/foo.py"], + ) + save_palace_object(db_path, "ex1", palace, np.zeros(384, dtype=np.float32), resolver=resolver) + + con = get_connection(db_path) + row = con.execute("SELECT id, dedup_hash FROM symbols").fetchone() + con.close() + + palace_id = hashlib.sha256(b"palace:ex1").hexdigest() + expected_id = hashlib.sha256(f"Foo.bar:src/foo.py:{palace_id}".encode()).hexdigest() + expected_dedup = hashlib.sha256(b"Foo.bar:src/foo.py").hexdigest() + + assert row["id"] == expected_id + assert row["dedup_hash"] == expected_dedup + + def test_save_palace_object_stores_in_db(tmp_path) -> None: db_path = tmp_path / "memory.db" init_db(db_path) @@ -213,6 +271,61 @@ def test_save_palace_object_stores_in_db(tmp_path) -> None: con.close() +def test_save_palace_object_skips_symbol_not_in_body(tmp_path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1", user_text="some unrelated text " * 5, agent_text="more text " * 5) + + resolver = MagicMock() + sym = MagicMock() + sym.symbol_name = "NotMentioned" + sym.file_path = "src/foo.py" + resolver.extract.return_value = [sym] + + palace = PalaceObject( + exchange_core="c", + specific_context="s", + room_assignments=[], + files_touched=["src/foo.py"], + ) + save_palace_object(db_path, "ex1", palace, np.zeros(384, dtype=np.float32), resolver=resolver) + + con = get_connection(db_path) + count = con.execute("SELECT COUNT(*) FROM symbols").fetchone()[0] + con.close() + + assert count == 0 + + +def test_save_palace_object_includes_symbol_in_body(tmp_path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1", user_text="Foo.bar method " * 5, agent_text="more text " * 5) + + resolver = MagicMock() + sym = MagicMock() + sym.symbol_name = "Foo.bar" + sym.file_path = "src/foo.py" + sym.symbol_kind = "method" + sym.signature = "def bar" + sym.line = 1 + resolver.extract.return_value = [sym] + + palace = PalaceObject( + exchange_core="c", + specific_context="s", + room_assignments=[], + files_touched=["src/foo.py"], + ) + save_palace_object(db_path, "ex1", palace, np.zeros(384, dtype=np.float32), resolver=resolver) + + con = get_connection(db_path) + count = con.execute("SELECT COUNT(*) FROM symbols").fetchone()[0] + con.close() + + assert count == 1 + + def test_save_palace_object_sets_distilled_at(tmp_path) -> None: db_path = tmp_path / "memory.db" init_db(db_path) @@ -259,6 +372,42 @@ def test_save_palace_object_saves_rooms(tmp_path) -> None: con.close() +def test_save_palace_object_two_palace_objects_same_symbol(tmp_path) -> None: + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1", user_text="Foo.bar method " * 5, agent_text="implementation " * 5) + _make_exchange(db_path, "ex2", user_text="Foo.bar method " * 5, agent_text="implementation " * 5) + + resolver = MagicMock() + sym = MagicMock() + sym.symbol_name = "Foo.bar" + sym.file_path = "src/foo.py" + sym.symbol_kind = "method" + sym.signature = "def bar" + sym.line = 1 + resolver.extract.return_value = [sym] + + palace = PalaceObject( + exchange_core="c", + specific_context="s", + room_assignments=[], + files_touched=["src/foo.py"], + ) + + save_palace_object(db_path, "ex1", palace, np.zeros(384, dtype=np.float32), resolver=resolver) + save_palace_object(db_path, "ex2", palace, np.zeros(384, dtype=np.float32), resolver=resolver) + + con = get_connection(db_path) + all_rows = con.execute("SELECT id, dedup_hash FROM symbols").fetchall() + con.close() + + assert len(all_rows) == 2 + ids = {row["id"] for row in all_rows} + dedup_hashes = {row["dedup_hash"] for row in all_rows} + assert len(ids) == 2 # Two different ids + assert len(dedup_hashes) == 1 # Same dedup_hash + + def test_save_palace_object_saves_vec(tmp_path) -> None: db_path = tmp_path / "memory.db" init_db(db_path) @@ -290,7 +439,7 @@ def test_distill_all_processes_undistilled(mock_call, tmp_path) -> None: mock_embedder.embed_passage.return_value = np.zeros(384, dtype=np.float32) with patch("codeatrium.distiller.Embedder", return_value=mock_embedder): - count = distill_all(db_path) + count, _ = distill_all(db_path) assert count == 1 @@ -302,13 +451,13 @@ def test_distill_all_skips_distilled(mock_call, tmp_path) -> None: _make_exchange(db_path, "ex1") con = get_connection(db_path) - con.execute("UPDATE exchanges SET distilled_at = '2026-01-01' WHERE id = 'ex1'") + con.execute("UPDATE exchanges SET distilled_at = '2026-01-01', distill_status = 'distilled' WHERE id = 'ex1'") con.commit() con.close() mock_embedder = MagicMock() with patch("codeatrium.distiller.Embedder", return_value=mock_embedder): - count = distill_all(db_path) + count, _ = distill_all(db_path) assert count == 0 @@ -324,6 +473,143 @@ def test_distill_all_returns_count(mock_call, tmp_path) -> None: mock_embedder.embed_passage.return_value = np.zeros(384, dtype=np.float32) with patch("codeatrium.distiller.Embedder", return_value=mock_embedder): - count = distill_all(db_path) + count, _ = distill_all(db_path) assert count == 2 + + +@patch("codeatrium.distiller.call_claude", return_value=MOCK_PALACE_RESPONSE) +def test_distill_all_returns_tuple(mock_call, tmp_path) -> None: + """distill_all は tuple を返す""" + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1") + + mock_embedder = MagicMock() + mock_embedder.embed_passage.return_value = np.zeros(384, dtype=np.float32) + + with patch("codeatrium.distiller.Embedder", return_value=mock_embedder): + result = distill_all(db_path) + + assert isinstance(result, tuple) + assert len(result) == 2 + + +@patch("codeatrium.distiller.call_claude") +def test_distill_all_error_count(mock_call, tmp_path) -> None: + """distill_all はエラー数をカウントして返す""" + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1") + _make_exchange(db_path, "ex2") + + # 1回目は失敗、2回目は成功 + mock_call.side_effect = [RuntimeError("Test error"), MOCK_PALACE_RESPONSE] + + mock_embedder = MagicMock() + mock_embedder.embed_passage.return_value = np.zeros(384, dtype=np.float32) + + with patch("codeatrium.distiller.Embedder", return_value=mock_embedder): + count, errors = distill_all(db_path) + + assert count == 1 + assert errors == 1 + + +def test_save_palace_object_sets_distill_status_distilled(tmp_path) -> None: + """save_palace_object は distill_status を 'distilled' にセットする""" + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1") + + palace = PalaceObject( + exchange_core="蒸留済み", + specific_context="detail", + room_assignments=[], + ) + save_palace_object(db_path, "ex1", palace, np.zeros(384, dtype=np.float32)) + + con = get_connection(db_path) + row = con.execute( + "SELECT distill_status FROM exchanges WHERE id=?", ("ex1",) + ).fetchone() + assert row["distill_status"] == "distilled" + con.close() + + +def test_save_palace_object_raises_on_palace_insert_failure(tmp_path) -> None: + """save_palace_object が palace_objects INSERT 失敗時に例外を raise する""" + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1") + + # Recreate palace_objects table with NOT NULL bm25_text to simulate legacy schema + con = get_connection(db_path) + con.execute("DROP TABLE palace_objects") + con.execute(""" + CREATE TABLE palace_objects ( + id TEXT PRIMARY KEY, + exchange_id TEXT NOT NULL, + exchange_core TEXT NOT NULL, + specific_context TEXT NOT NULL, + distill_text TEXT NOT NULL, + bm25_text TEXT NOT NULL + ) + """) + con.commit() + con.close() + + palace = PalaceObject( + exchange_core="テスト", + specific_context="detail", + room_assignments=[], + files_touched=[], + ) + + with pytest.raises(Exception): + save_palace_object(db_path, "ex1", palace, np.zeros(384, dtype=np.float32)) + + con = get_connection(db_path) + # distill_status should still be 'pending' due to rollback + row = con.execute( + "SELECT distill_status FROM exchanges WHERE id=?", ("ex1",) + ).fetchone() + assert row["distill_status"] == "pending" + + # palace_objects should have 0 rows + palace_rows = con.execute("SELECT * FROM palace_objects").fetchall() + assert len(palace_rows) == 0 + con.close() + + +def test_save_palace_object_rollback_on_error(tmp_path) -> None: + """save_palace_object がエラーで失敗した場合、distill_status は 'pending' で palace_objects は空""" + db_path = tmp_path / "memory.db" + init_db(db_path) + _make_exchange(db_path, "ex1") + + palace = PalaceObject( + exchange_core="テスト", + specific_context="detail", + room_assignments=[], + ) + + # 不正な次元の embedding を渡して struct.pack を失敗させる + bad_embedding = np.zeros(100, dtype=np.float32) + + try: + save_palace_object(db_path, "ex1", palace, bad_embedding) + except Exception: + pass # エラーは予期されている + + con = get_connection(db_path) + # distill_status は 'pending' のまま + row = con.execute( + "SELECT distill_status FROM exchanges WHERE id=?", ("ex1",) + ).fetchone() + assert row["distill_status"] == "pending" + + # palace_objects テーブルは空 + palace_rows = con.execute("SELECT * FROM palace_objects").fetchall() + assert len(palace_rows) == 0 + con.close() diff --git a/tests/test_embedder.py b/tests/test_embedder.py index c94b6da..ede11df 100644 --- a/tests/test_embedder.py +++ b/tests/test_embedder.py @@ -3,9 +3,11 @@ モデルロードを避けるため embed() は mock する """ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch -from codeatrium.embedder import Embedder +import pytest + +from codeatrium.embedder import Embedder, _try_socket_embed def test_embedder_returns_384_dim() -> None: @@ -45,3 +47,34 @@ def test_embedder_returns_float32() -> None: vec = embedder.embed("テスト") assert vec.dtype == np.float32 + + +def test_embed_via_socket_or_direct_raises_when_model_none() -> None: + """_ensure_model 後も _model が None なら RuntimeError(assert 撤去・Q4)""" + embedder = Embedder.__new__(Embedder) + embedder._sock_path = None + embedder._model = None + # _ensure_model を no-op にして _model を None のままにする + embedder._ensure_model = MagicMock() # type: ignore[method-assign] + + with pytest.raises(RuntimeError): + embedder._embed_via_socket_or_direct("テスト", "query", "query: ") + + +def test_try_socket_embed_chunked_response() -> None: + """改行終端レスポンスが複数チャンクで届いても再構成される(H5 client)""" + import numpy as np + + fake_sock = MagicMock() + fake_sock.__enter__.return_value = fake_sock + fake_sock.__exit__.return_value = False + fake_sock.recv.side_effect = [b'{"embedding":[0.1,0.2]}', b"\n", b""] + + mock_path = MagicMock() + mock_path.exists.return_value = True + + with patch("codeatrium.embedder.socket.socket", return_value=fake_sock): + vec = _try_socket_embed(mock_path, "query", "hello") + + assert vec is not None + np.testing.assert_allclose(vec, np.array([0.1, 0.2], dtype=np.float32)) diff --git a/tests/test_embedder_server.py b/tests/test_embedder_server.py new file mode 100644 index 0000000..ca06a1a --- /dev/null +++ b/tests/test_embedder_server.py @@ -0,0 +1,116 @@ +""" +embedder_server._handle_client のテスト(H5 ソケット堅牢性 / Q7 テスト空白埋め) + +実モデルはロードせず embedder は MagicMock。接続は socket.socketpair() を使う。 +""" + +from __future__ import annotations + +import json +import socket +import threading +import time + +import numpy as np +import pytest + +from codeatrium.embedder_server import _handle_client + + +def _read_line(sock: socket.socket) -> dict: + """改行が現れるまで recv して JSON を 1 行パースする""" + buf = b"" + sock.settimeout(2.0) + while b"\n" not in buf: + chunk = sock.recv(4096) + if not chunk: + break + buf += chunk + return json.loads(buf.split(b"\n")[0]) + + +def _start_handler( + server_conn: socket.socket, + embedder: object, + stop_event: threading.Event, +) -> threading.Thread: + """_handle_client を daemon スレッドで起動する""" + t = threading.Thread( + target=_handle_client, + args=(server_conn, embedder, [0.0], stop_event), + daemon=True, + ) + t.start() + return t + + +def test_handle_client_chunked_request() -> None: + """リクエスト JSON が複数 recv に分割到着しても 1 リクエストとして処理される""" + server_conn, client_conn = socket.socketpair() + embedder = type("E", (), {})() + embedder.embed = lambda text: np.array([0.1, 0.2, 0.3], dtype=np.float32) # type: ignore[attr-defined] + stop_event = threading.Event() + t = _start_handler(server_conn, embedder, stop_event) + try: + client_conn.sendall(b'{"type":"query","te') + time.sleep(0.05) + client_conn.sendall(b'xt":"hello"}\n') + resp = _read_line(client_conn) + assert resp["embedding"] == pytest.approx([0.1, 0.2, 0.3]) + finally: + client_conn.close() + server_conn.close() + t.join(timeout=2.0) + + +def test_handle_client_invalid_json_then_valid() -> None: + """不正 JSON は error 応答を返し、接続を切らず後続を処理する""" + server_conn, client_conn = socket.socketpair() + embedder = type("E", (), {})() + embedder.embed = lambda text: np.array([1.0], dtype=np.float32) # type: ignore[attr-defined] + stop_event = threading.Event() + t = _start_handler(server_conn, embedder, stop_event) + try: + client_conn.sendall(b"not valid json\n") + first = _read_line(client_conn) + assert first["error"] == "invalid json" + client_conn.sendall(b'{"type":"query","text":"hi"}\n') + second = _read_line(client_conn) + assert second["embedding"] == [1.0] + finally: + client_conn.close() + server_conn.close() + t.join(timeout=2.0) + + +def test_handle_client_ping() -> None: + """ping に status ok を返す""" + server_conn, client_conn = socket.socketpair() + embedder = type("E", (), {})() + stop_event = threading.Event() + t = _start_handler(server_conn, embedder, stop_event) + try: + client_conn.sendall(b'{"type":"ping"}\n') + resp = _read_line(client_conn) + assert resp == {"status": "ok"} + finally: + client_conn.close() + server_conn.close() + t.join(timeout=2.0) + + +def test_handle_client_stop() -> None: + """stop で stopping を返し stop_event がセットされる""" + server_conn, client_conn = socket.socketpair() + embedder = type("E", (), {})() + stop_event = threading.Event() + t = _start_handler(server_conn, embedder, stop_event) + try: + client_conn.sendall(b'{"type":"stop"}\n') + resp = _read_line(client_conn) + assert resp == {"status": "stopping"} + t.join(timeout=2.0) + assert stop_event.is_set() + finally: + client_conn.close() + server_conn.close() diff --git a/tests/test_indexer.py b/tests/test_indexer.py index 9e4bb5c..35d21e0 100644 --- a/tests/test_indexer.py +++ b/tests/test_indexer.py @@ -37,6 +37,23 @@ def make_assistant_entry(uuid: str, text: str, parent_uuid: str) -> dict: } +def make_assistant_entry_with_tool_use( + uuid: str, file_path: str, parent_uuid: str | None = None, tool_name: str = "Edit" +) -> dict: + """Assistant entry with a tool_use block (no text content)""" + key = "notebook_path" if tool_name == "NotebookEdit" else "file_path" + return { + "type": "assistant", + "uuid": uuid, + "parentUuid": parent_uuid, + "timestamp": "2026-03-26T00:00:01.000Z", + "message": { + "role": "assistant", + "content": [{"type": "tool_use", "name": tool_name, "input": {key: file_path}}], + }, + } + + def write_jsonl(path: Path, entries: list[dict]) -> None: with path.open("w") as f: for e in entries: @@ -126,6 +143,57 @@ def test_parse_exchanges_ply_range(tmp_path: Path) -> None: assert exchanges[0].ply_end == 1 +def test_parse_exchanges_skips_old_exchanges(tmp_path: Path) -> None: + """last_ply_end 以前の ply は exchange として再構築しない(None プレースホルダ)""" + f = tmp_path / "session.jsonl" + write_jsonl( + f, + [ + make_user_entry("u1", "最初の質問です。よろしくお願いします。" * 5), + make_assistant_entry("a1", "了解しました。詳しく説明します。" * 5, "u1"), + make_user_entry("u2", "次の質問です。詳しく教えてください。" * 5, "a1"), + make_assistant_entry( + "a2", "詳しく説明します。ご参考になれば幸いです。" * 5, "u2" + ), + ], + ) + exchanges = parse_exchanges(f, last_ply_end=1) + # ply 0, 1 は再構築されず、ply 2, 3 の1件の exchange のみ返る + assert len(exchanges) == 1 + assert exchanges[0].ply_start == 2 + assert exchanges[0].ply_end == 3 + + +def test_parse_exchanges_malformed_line_in_indexed_region_no_drift( + tmp_path: Path, +) -> None: + """既インデックス領域に壊れた JSON 行があってもスキップ境界がズレない(回帰)。 + + 壊れた行は成功パース座標系に位置を持たないため、skip 領域でも数えてはならない。 + """ + f = tmp_path / "session.jsonl" + # ply 0(user), 1(assistant) を全行パース時の last_ply_end とする。 + # 行頭に壊れた JSON を 1 行混ぜると、座標系を誤ると境界が 1 つ早くズレる。 + lines = [ + json.dumps(make_user_entry("u1", "最初の質問です。" * 6)), + "{ this is not valid json", + json.dumps(make_assistant_entry("a1", "了解しました。" * 6, "u1")), + json.dumps(make_user_entry("u2", "次の質問です。" * 6, "a1")), + json.dumps(make_assistant_entry("a2", "説明します。" * 6, "u2")), + ] + f.write_text("\n".join(lines) + "\n", encoding="utf-8") + + # 全行パース(last_ply_end=-1)での境界を基準にする + full = parse_exchanges(f, last_ply_end=-1) + # 壊れた行は無視され、u1/a1 と u2/a2 の 2 exchange になる + assert [(e.ply_start, e.ply_end) for e in full] == [(0, 1), (2, 3)] + + # 最初の exchange(ply 0,1)までインデックス済みとして再パース + incremental = parse_exchanges(f, last_ply_end=1) + # 2 番目の exchange(ply 2,3)だけが返り、座標がドリフトしない + assert [(e.ply_start, e.ply_end) for e in incremental] == [(2, 3)] + + def test_parse_exchanges_deterministic_id(tmp_path: Path) -> None: """同じファイルを2回パースすると同じ exchange_id になる""" f = tmp_path / "session.jsonl" @@ -301,3 +369,101 @@ def test_index_file_fts_populated(tmp_path: Path) -> None: ).fetchall() assert len(rows) == 1 con.close() + + +# ---- tool_use ファイル抽出テスト ---- + + +def test_parse_exchanges_captures_tool_use_files(tmp_path: Path) -> None: + """tool_use Edit ブロックのファイルパスが exchange.files に含まれる""" + f = tmp_path / "session.jsonl" + write_jsonl( + f, + [ + make_user_entry("u1", "Edit the source file please. " * 10), + make_assistant_entry_with_tool_use("a1", "src/foo.py", "u1"), + ], + ) + exchanges = parse_exchanges(f) + assert len(exchanges) == 1 + assert "src/foo.py" in exchanges[0].files + + +def test_parse_exchanges_tool_use_dedup(tmp_path: Path) -> None: + """複数の assistant エントリが同じファイルを参照する場合、重複を除外する""" + f = tmp_path / "session.jsonl" + write_jsonl( + f, + [ + make_user_entry("u1", "Edit the source file please. " * 10), + make_assistant_entry_with_tool_use("a1", "src/foo.py", "u1"), + make_assistant_entry_with_tool_use("a2", "src/foo.py", "a1"), + ], + ) + exchanges = parse_exchanges(f) + assert len(exchanges) == 1 + assert exchanges[0].files.count("src/foo.py") == 1 + assert len(exchanges[0].files) == 1 + + +def test_parse_exchanges_tool_use_excludes_external(tmp_path: Path) -> None: + """外部パス(.venv など)は除外される""" + f = tmp_path / "session.jsonl" + write_jsonl( + f, + [ + make_user_entry("u1", "Edit the source file please. " * 10), + make_assistant_entry_with_tool_use( + "a1", "/Users/u/.venv/lib/python3.12/site-packages/foo.py", "u1" + ), + ], + ) + exchanges = parse_exchanges(f) + assert len(exchanges) == 1 + assert "/Users/u/.venv/lib/python3.12/site-packages/foo.py" not in exchanges[0].files + assert len(exchanges[0].files) == 0 + + +def test_index_file_writes_exchange_files(tmp_path: Path) -> None: + """index_file が exchange_files テーブルに書き込む""" + db_path = tmp_path / ".codeatrium" / "memory.db" + init_db(db_path) + + jsonl = tmp_path / "session.jsonl" + write_jsonl( + jsonl, + [ + make_user_entry("u1", "Edit the source file please. " * 10), + make_assistant_entry_with_tool_use("a1", "src/bar.py", "u1"), + ], + ) + + index_file(jsonl, db_path) + + con = get_connection(db_path) + rows = con.execute("SELECT file_path FROM exchange_files").fetchall() + assert "src/bar.py" in [r[0] for r in rows] + con.close() + + +def test_index_file_exchange_files_dedup(tmp_path: Path) -> None: + """同じ exchange 内の複数の tool_use ブロックでも exchange_files は重複しない""" + db_path = tmp_path / ".codeatrium" / "memory.db" + init_db(db_path) + + jsonl = tmp_path / "session.jsonl" + write_jsonl( + jsonl, + [ + make_user_entry("u1", "Edit the source file please. " * 10), + make_assistant_entry_with_tool_use("a1", "src/baz.py", "u1"), + make_assistant_entry_with_tool_use("a2", "src/baz.py", "a1"), + ], + ) + + index_file(jsonl, db_path) + + con = get_connection(db_path) + count = con.execute("SELECT COUNT(*) FROM exchange_files").fetchone()[0] + assert count == 1 + con.close() diff --git a/tests/test_llm.py b/tests/test_llm.py new file mode 100644 index 0000000..1be016f --- /dev/null +++ b/tests/test_llm.py @@ -0,0 +1,148 @@ +"""call_claude のユニットテスト: subprocess.run をモックして振る舞いを検証する""" + +from __future__ import annotations + +import json +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from codeatrium.llm import call_claude + +# ---- テストデータ ---- + + +MOCK_JSON_RESPONSE = { + "structured_output": { + "exchange_core": "テスト交換: パラメータを設定した", + "specific_context": "timeout=300", + "room_assignments": [ + { + "room_type": "concept", + "room_key": "test-param", + "room_label": "Test Parameter", + "relevance": 0.85, + } + ], + } +} + + +# ---- テスト ---- + + +def test_call_claude_command_args() -> None: + """ + subprocess.run をモックし、call_claude 実行時のコマンドリストに + --no-session-persistence, --setting-sources, --output-format (json), + --model が含まれることを assert する + """ + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = json.dumps(MOCK_JSON_RESPONSE) + + with patch("codeatrium.llm.subprocess.run", return_value=mock_result) as mock_run: + with patch("shutil.which", return_value="/usr/bin/claude"): + call_claude("test prompt") + + # subprocess.run が呼ばれたことを確認 + assert mock_run.called + + # 呼び出し時のコマンドリスト取得 + call_args = mock_run.call_args + assert call_args is not None + cmd_list = call_args[0][0] # 第一引数のコマンドリスト + + # 必須フラグが含まれていることを確認 + assert "--no-session-persistence" in cmd_list + assert "--setting-sources" in cmd_list + assert "--output-format" in cmd_list + assert "--model" in cmd_list + assert "json" in cmd_list + + # claude コマンドパスと --print フラグ + assert cmd_list[0] == "/usr/bin/claude" + assert "--print" in cmd_list + + +def test_call_claude_returns_dict() -> None: + """ + モックして call_claude の戻り値が期待する dict + (structured_output を含む) であることを assert する + """ + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = json.dumps(MOCK_JSON_RESPONSE) + + with patch("codeatrium.llm.subprocess.run", return_value=mock_result): + with patch("shutil.which", return_value="/usr/bin/claude"): + result = call_claude("test prompt") + + # 戻り値が dict で、期待するキーを含むことを確認 + assert isinstance(result, dict) + assert "exchange_core" in result + assert "specific_context" in result + assert "room_assignments" in result + + # 値の確認 + assert result["exchange_core"] == "テスト交換: パラメータを設定した" + assert result["specific_context"] == "timeout=300" + assert isinstance(result["room_assignments"], list) + assert len(result["room_assignments"]) == 1 + + +def test_call_claude_cleanup_on_success(tmp_path: Path) -> None: + """ + _session_dir を tmp_path に向け、副作用 .jsonl が + 正常終了時にクリーンアップされることを確認する + """ + side_jsonl = tmp_path / "side.jsonl" + + def fake_run(*args, **kwargs): + # subprocess.run 呼び出し時(before スナップショット取得後)に副作用ファイルを作成 + side_jsonl.write_text('{"key": "value"}\n') + mock_result = MagicMock() + mock_result.returncode = 0 + mock_result.stdout = json.dumps(MOCK_JSON_RESPONSE) + return mock_result + + with patch("codeatrium.llm.subprocess.run", side_effect=fake_run): + with patch("shutil.which", return_value="/usr/bin/claude"): + with patch( + "codeatrium.llm._session_dir", return_value=tmp_path + ): + call_claude("test prompt") + + # 正常終了後は副作用 .jsonl がクリーンアップされたことを確認 + assert not side_jsonl.exists() + + +def test_call_claude_cleanup_on_timeout(tmp_path: Path) -> None: + """ + subprocess.run が subprocess.TimeoutExpired を投げるようモックし: + (1) call_claude が例外を送出する (pytest.raises) + (2) それでも副作用 .jsonl のクリーンアップが走る (finally 経路) ことを確認する + """ + side_jsonl = tmp_path / "side.jsonl" + + def fake_run(*args, **kwargs): + # subprocess.run 呼び出し時(before スナップショット取得後)に副作用ファイルを作成 + side_jsonl.write_text('{"key": "value"}\n') + raise subprocess.TimeoutExpired("claude", 300) + + with patch( + "codeatrium.llm.subprocess.run", + side_effect=fake_run, + ): + with patch("shutil.which", return_value="/usr/bin/claude"): + with patch( + "codeatrium.llm._session_dir", return_value=tmp_path + ): + # TimeoutExpired が発生することを確認 + with pytest.raises(subprocess.TimeoutExpired): + call_claude("test prompt") + + # タイムアウト時にも副作用 .jsonl がクリーンアップされたことを確認 + assert not side_jsonl.exists() diff --git a/tests/test_prime_cmd.py b/tests/test_prime_cmd.py new file mode 100644 index 0000000..3722bb9 --- /dev/null +++ b/tests/test_prime_cmd.py @@ -0,0 +1,75 @@ +"""tests for prime_cmd — PRIME_TEXT contract + inject_claude_md idempotency""" + +from __future__ import annotations + +from pathlib import Path + +from codeatrium.cli.prime_cmd import ( + BEGIN_MARKER, + END_MARKER, + PRIME_TEXT, + inject_claude_md, +) + +# ---- PRIME_TEXT contract ---- + + +def test_prime_text_has_loci_context_section_heading(): + """PRIME_TEXT must contain an independent section heading for loci context""" + assert "### Context" in PRIME_TEXT + + +def test_prime_text_has_agent_action_triggers(): + """PRIME_TEXT must list agent-initiated action triggers for edit/refactor, new impl, and error""" + assert "Before editing or refactoring" in PRIME_TEXT + assert "Before starting a new implementation" in PRIME_TEXT + assert "encounter" in PRIME_TEXT + + +def test_prime_text_has_concrete_search_example(): + """PRIME_TEXT must contain a concrete loci search example (not a bare placeholder)""" + assert 'loci search "BM25 RRF fusion ranking"' in PRIME_TEXT + + +def test_prime_text_has_concrete_context_example(): + """PRIME_TEXT must contain a concrete loci context example with a real symbol""" + assert 'loci context --symbol "SymbolResolver.extract"' in PRIME_TEXT + + +def test_prime_text_context_section_explains_bidirectional_recall(): + """PRIME_TEXT context section must convey the symbol-to-memory design intent""" + text_lower = PRIME_TEXT.lower() + assert any( + phrase in text_lower for phrase in ["recall", "reverse lookup", "memory about"] + ) + + +# ---- inject_claude_md idempotency ---- + + +def test_inject_claude_md_creates_file_when_absent(tmp_path: Path): + """inject_claude_md creates CLAUDE.md when it does not exist""" + result = inject_claude_md(tmp_path) + assert result is True + claude_md = tmp_path / "CLAUDE.md" + assert claude_md.exists() + content = claude_md.read_text() + assert BEGIN_MARKER in content + assert END_MARKER in content + + +def test_inject_claude_md_idempotent_on_second_call(tmp_path: Path): + """inject_claude_md returns False on second call when content is already up-to-date""" + inject_claude_md(tmp_path) + result2 = inject_claude_md(tmp_path) + assert result2 is False + + +def test_inject_claude_md_second_call_does_not_modify_content(tmp_path: Path): + """inject_claude_md leaves file content unchanged on second call""" + inject_claude_md(tmp_path) + claude_md = tmp_path / "CLAUDE.md" + content_after_first = claude_md.read_text() + inject_claude_md(tmp_path) + content_after_second = claude_md.read_text() + assert content_after_first == content_after_second diff --git a/tests/test_search_cmd.py b/tests/test_search_cmd.py new file mode 100644 index 0000000..06b1769 --- /dev/null +++ b/tests/test_search_cmd.py @@ -0,0 +1,125 @@ +"""loci context コマンドの出力契約テスト + +C5: 既定出力から会話全文を外し verbatim_ref を返す。--full で全文復元。 +""" + +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path + +from typer.testing import CliRunner + +from codeatrium.cli import app +from codeatrium.db import get_connection, init_db + +runner = CliRunner() + +LONG = "x" * 200 + + +def _setup(tmp_path: Path) -> tuple[Path, sqlite3.Connection]: + codeatrium_dir = tmp_path / ".codeatrium" + codeatrium_dir.mkdir() + db = codeatrium_dir / "memory.db" + init_db(db) + con = get_connection(db) + return db, con + + +def _insert_fixture( + con, + ex_id="ex1", + conv_id="conv1", + source_path="/fake/session.jsonl", + ply_start=0, + symbol_name="MyFunc", +) -> None: + con.execute( + "INSERT OR IGNORE INTO conversations (id, source_path) VALUES (?,?)", + (conv_id, source_path), + ) + con.execute( + """INSERT OR IGNORE INTO exchanges + (id, conversation_id, ply_start, ply_end, user_content, agent_content) + VALUES (?,?,?,?,?,?)""", + (ex_id, conv_id, ply_start, ply_start + 3, "user " + LONG, "agent " + LONG), + ) + con.execute( + """INSERT OR IGNORE INTO palace_objects + (id, exchange_id, exchange_core, specific_context, distill_text) + VALUES (?,?,?,?,?)""", + ("p1", ex_id, "core summary", "specific detail", "core summary"), + ) + con.execute( + """INSERT OR IGNORE INTO symbols + (id, palace_object_id, symbol_name, symbol_kind, file_path, + signature, line, dedup_hash) + VALUES (?,?,?,?,?,?,?,?)""", + ("s1", "p1", symbol_name, "function", "src/foo.py", "def MyFunc()", 42, "hash1"), + ) + con.commit() + + +def test_context_default_no_full_content(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + db, con = _setup(tmp_path) + _insert_fixture(con) + con.close() + + result = runner.invoke(app, ["context", "--symbol", "MyFunc", "--json"]) + assert result.exit_code == 0 + data = json.loads(result.output) + assert "user_content" not in data[0] + assert "agent_content" not in data[0] + assert "verbatim_ref" in data[0] + + +def test_context_full_flag_includes_content(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + db, con = _setup(tmp_path) + _insert_fixture(con) + con.close() + + result = runner.invoke( + app, ["context", "--symbol", "MyFunc", "--json", "--full"] + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert "user_content" in data[0] + assert "agent_content" in data[0] + + +def test_context_default_nine_fields(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + db, con = _setup(tmp_path) + _insert_fixture(con) + con.close() + + result = runner.invoke(app, ["context", "--symbol", "MyFunc", "--json"]) + assert result.exit_code == 0 + data = json.loads(result.output) + assert set(data[0].keys()) == { + "symbol_name", + "symbol_kind", + "file_path", + "signature", + "line", + "exchange_id", + "exchange_core", + "specific_context", + "verbatim_ref", + } + + +def test_context_verbatim_ref_format(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + db, con = _setup(tmp_path) + _insert_fixture(con, source_path="/fake/session.jsonl", ply_start=10) + con.close() + + result = runner.invoke(app, ["context", "--symbol", "MyFunc", "--json"]) + assert result.exit_code == 0 + data = json.loads(result.output) + assert data[0]["verbatim_ref"] == "/fake/session.jsonl:ply=10" diff --git a/tests/test_search_phase2.py b/tests/test_search_phase2.py index d831591..9c87e29 100644 --- a/tests/test_search_phase2.py +++ b/tests/test_search_phase2.py @@ -5,8 +5,10 @@ import struct from pathlib import Path +from unittest.mock import MagicMock, patch import numpy as np +import pytest from codeatrium.db import get_connection, init_db from codeatrium.search import ( @@ -16,6 +18,7 @@ rrf, search_bm25, search_combined, + search_hnsw_palace, ) LONG_TEXT = "connection pool " * 10 @@ -288,3 +291,71 @@ def test_search_combined_with_palace(tmp_path: Path) -> None: vec = np.ones(384, dtype=np.float32) results = search_combined(db_path, "connection pool", vec, limit=5) assert any(r.exchange_id == "ex1" for r in results) + + +# --- connection leak tests --- + + +def test_search_bm25_connection_leak(tmp_path: Path) -> None: + """search_bm25 は例外時も connection を close する""" + db_path = tmp_path / "memory.db" + init_db(db_path) + + fake_con = MagicMock() + fake_con.execute.side_effect = RuntimeError("test error") + + with patch("codeatrium.search.get_connection", return_value=fake_con): + with pytest.raises(RuntimeError): + search_bm25(db_path, "query") + + assert fake_con.close.called + + +def test_search_hnsw_connection_leak(tmp_path: Path) -> None: + """search_hnsw_palace は例外時も connection を close する""" + db_path = tmp_path / "memory.db" + init_db(db_path) + + fake_con = MagicMock() + fake_con.execute.side_effect = RuntimeError("test error") + + with patch("codeatrium.search.get_connection", return_value=fake_con): + with pytest.raises(RuntimeError): + search_hnsw_palace(db_path, np.ones(384, dtype=np.float32)) + + assert fake_con.close.called + + +def test_search_combined_enrich_connection_leak(tmp_path: Path) -> None: + """search_combined は enrich 時に exception が出ても enrich con を close する""" + db_path = tmp_path / "memory.db" + init_db(db_path) + con = get_connection(db_path) + _insert_exchange(con, "ex1", LONG_TEXT, "pool response") + con.close() + + fake_enrich_con = MagicMock() + stored_con = None + + def mock_get_connection(path): + """最初の呼び出しは real con (search_bm25/search_hnsw 用)、 + 次の呼び出しで fake con (enrich 用) を返す""" + nonlocal stored_con + if stored_con is None: + stored_con = get_connection(path) + return stored_con + return fake_enrich_con + + with patch( + "codeatrium.search.get_connection", side_effect=mock_get_connection + ), patch( + "codeatrium.search._enrich_results", side_effect=RuntimeError("enrich failed") + ): + with pytest.raises(RuntimeError): + search_combined( + db_path, "connection pool", np.ones(384, dtype=np.float32), limit=5 + ) + + assert fake_enrich_con.close.called + if stored_con is not None: + stored_con.close() diff --git a/tests/test_security.py b/tests/test_security.py index b58fd50..144e789 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -9,13 +9,20 @@ from __future__ import annotations +import fcntl import os import stat from pathlib import Path from unittest.mock import patch +from typer.testing import CliRunner + +from codeatrium.cli import app +from codeatrium.db import init_db from codeatrium.hooks import install_hooks +runner = CliRunner() + # --- #1: hooks.py — shlex.quote でパスをクオート --- @@ -24,9 +31,10 @@ def test_hooks_quotes_loci_path_with_spaces() -> None: fake_path = "/Users/test user/venvs/my env/bin/loci" with patch("codeatrium.hooks.loci_bin", return_value=fake_path): with patch("codeatrium.hooks.Path") as mock_path_cls: - mock_settings = mock_path_cls.home.return_value / ".claude" / "settings.json" - mock_settings.exists.return_value = False - _, msg = install_hooks() + with patch("codeatrium.hooks._write_settings"): + mock_settings = mock_path_cls.home.return_value / ".claude" / "settings.json" + mock_settings.exists.return_value = False + _, msg = install_hooks() # shlex.quote はシングルクオートでラップする assert "'" in msg or "\\" in msg @@ -35,9 +43,10 @@ def test_hooks_batch_limit_cast_to_int() -> None: """batch_limit が int にキャストされることを確認""" with patch("codeatrium.hooks.loci_bin", return_value="/usr/bin/loci"): with patch("codeatrium.hooks.Path") as mock_path_cls: - mock_settings = mock_path_cls.home.return_value / ".claude" / "settings.json" - mock_settings.exists.return_value = False - _, msg = install_hooks(batch_limit=20) + with patch("codeatrium.hooks._write_settings"): + mock_settings = mock_path_cls.home.return_value / ".claude" / "settings.json" + mock_settings.exists.return_value = False + _, msg = install_hooks(batch_limit=20) assert "--limit 20" in msg @@ -82,7 +91,7 @@ def test_distill_all_limit_parameterized(tmp_path: Path) -> None: patch("codeatrium.distiller.call_claude", return_value=mock_response), patch("codeatrium.distiller.Embedder", return_value=mock_embedder), ): - count = distill_all(db_path, limit=1) + count, _ = distill_all(db_path, limit=1) assert count == 1 @@ -144,41 +153,41 @@ def _start_and_stop() -> None: def test_distill_lock_atomic_creation(tmp_path: Path) -> None: - """ロックファイルが O_CREAT | O_EXCL で原子的に作成される""" + """2つ目の flock(LOCK_NB) は BlockingIOError になる""" lock_path = tmp_path / "distill.lock" - # O_CREAT | O_EXCL で作成 - fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) - os.write(fd, b"12345") - os.close(fd) + fd1 = os.open(str(lock_path), os.O_CREAT | os.O_RDWR, 0o600) + fcntl.flock(fd1, fcntl.LOCK_EX | fcntl.LOCK_NB) - assert lock_path.exists() - assert lock_path.read_text() == "12345" - - # 2回目は FileExistsError + fd2 = os.open(str(lock_path), os.O_CREAT | os.O_RDWR, 0o600) try: - fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) - os.close(fd) - raise AssertionError("Expected FileExistsError") - except FileExistsError: + fcntl.flock(fd2, fcntl.LOCK_EX | fcntl.LOCK_NB) + raise AssertionError("Expected BlockingIOError") + except BlockingIOError: pass + finally: + fcntl.flock(fd1, fcntl.LOCK_UN) + os.close(fd1) + os.close(fd2) -def test_distill_lock_stale_cleanup(tmp_path: Path) -> None: - """死んだプロセスの stale lock を検出してクリーンアップできる""" - lock_path = tmp_path / "distill.lock" - # 存在しない PID を書き込む - lock_path.write_text("999999999") +def test_distill_lock_already_running(tmp_path: Path, monkeypatch) -> None: + """ロック保持中の distill は already running で exit 0""" + codeatrium_dir = tmp_path / ".codeatrium" + codeatrium_dir.mkdir(parents=True) + init_db(codeatrium_dir / "memory.db") + + lock_path = codeatrium_dir / "distill.lock" + fd = os.open(str(lock_path), os.O_CREAT | os.O_RDWR, 0o600) + fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB) try: - existing_pid = int(lock_path.read_text().strip()) - os.kill(existing_pid, 0) - raise AssertionError("PID should not exist") - except ProcessLookupError: - # stale lock — 削除して再取得 - lock_path.unlink() - fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) - os.write(fd, str(os.getpid()).encode()) + monkeypatch.chdir(tmp_path) + result = runner.invoke(app, ["distill"]) + + assert result.exit_code == 0 + output = result.output + (getattr(result, "stderr", "") or "") + assert "already running" in output + finally: + fcntl.flock(fd, fcntl.LOCK_UN) os.close(fd) - - assert lock_path.read_text() == str(os.getpid()) diff --git a/tests/test_server_cmd.py b/tests/test_server_cmd.py index 254031f..1c10b85 100644 --- a/tests/test_server_cmd.py +++ b/tests/test_server_cmd.py @@ -3,6 +3,7 @@ from __future__ import annotations from pathlib import Path +from unittest.mock import MagicMock, patch from typer.testing import CliRunner @@ -19,3 +20,61 @@ def test_server_start_rejects_uninitialized_repo(tmp_path: Path, monkeypatch) -> assert "loci init" in result.output # .codeatrium ディレクトリが作成されていないこと assert not (tmp_path / ".codeatrium").exists() + + +def _make_initialized_repo(tmp_path: Path) -> Path: + """db_path(root).exists() が True になる最小リポジトリを作る""" + cdir = tmp_path / ".codeatrium" + cdir.mkdir(parents=True, exist_ok=True) + (cdir / "memory.db").touch() + return tmp_path + + +def _fake_ok_socket() -> MagicMock: + """ping に {"status":"ok"} を返す context-manager 対応の偽ソケット""" + s = MagicMock() + s.__enter__.return_value = s + s.__exit__.return_value = False + s.recv.return_value = b'{"status":"ok"}\n' + return s + + +def test_server_start_already_running(tmp_path: Path, monkeypatch) -> None: + """稼働中サーバーがいる状態で start を再実行しても二重起動しない(H3)""" + _make_initialized_repo(tmp_path) + monkeypatch.chdir(tmp_path) + sock = tmp_path / ".codeatrium" / "embedder.sock" + sock.touch() # exists() を True にする + + popen = MagicMock() + with patch("codeatrium.paths.git_root", return_value=None), \ + patch("socket.socket", return_value=_fake_ok_socket()), \ + patch("subprocess.Popen", popen): + result = runner.invoke(app, ["server", "start"]) + + assert "already running" in result.output + popen.assert_not_called() + + +def test_server_start_stale_cleanup(tmp_path: Path, monkeypatch) -> None: + """死亡 PID の pid ファイルが残っていても os.kill 生存確認で掃除して起動する(H3)""" + _make_initialized_repo(tmp_path) + monkeypatch.chdir(tmp_path) + cdir = tmp_path / ".codeatrium" + pid_file = cdir / "embedder.pid" + pid_file.write_text("999999999") # 存在しない PID(socket ファイルは作らない) + + popen = MagicMock() + popen.return_value.pid = 12345 + + # socket が無いので ping はスキップされ pid 生存確認パスを通る。 + # Popen 後の wait ループを抜けるため time.sleep を無効化する。 + with patch("codeatrium.paths.git_root", return_value=None), \ + patch("subprocess.Popen", popen), \ + patch("time.sleep", lambda *a, **k: None): + runner.invoke(app, ["server", "start"]) + + # 死亡 PID は os.kill(pid, 0) の ProcessLookupError で除去され、 + # 新サーバーが起動して新しい PID が pid ファイルに書かれる + assert popen.called + assert pid_file.read_text().strip() == "12345" diff --git a/tests/test_status_hook.py b/tests/test_status_hook.py index 6a9ae37..9f205ac 100644 --- a/tests/test_status_hook.py +++ b/tests/test_status_hook.py @@ -9,7 +9,9 @@ import json from pathlib import Path +from unittest.mock import patch +import pytest from typer.testing import CliRunner from codeatrium.cli import app @@ -56,7 +58,8 @@ def test_status_json_output(tmp_path, monkeypatch): data = json.loads(result.output) assert "exchanges" in data assert "distilled" in data - assert "undistilled" in data + assert "skipped" in data + assert "pending" in data assert "palace_objects" in data assert "symbols" in data assert "db_size_kb" in data @@ -83,8 +86,8 @@ def test_status_counts_exchanges(tmp_path, monkeypatch): (ex_id1, conv_id, 0, 1, "hello world", "hi there"), ) con.execute( - "INSERT INTO exchanges (id, conversation_id, ply_start, ply_end, user_content, agent_content, distilled_at) VALUES (?, ?, ?, ?, ?, ?, ?)", - (ex_id2, conv_id, 2, 3, "foo bar", "baz qux", "2026-01-01T00:00:00"), + "INSERT INTO exchanges (id, conversation_id, ply_start, ply_end, user_content, agent_content, distilled_at, distill_status) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (ex_id2, conv_id, 2, 3, "foo bar", "baz qux", "2026-01-01T00:00:00", "distilled"), ) con.commit() con.close() @@ -93,7 +96,7 @@ def test_status_counts_exchanges(tmp_path, monkeypatch): data = json.loads(result.output) assert data["exchanges"] == 2 assert data["distilled"] == 1 - assert data["undistilled"] == 1 + assert data["pending"] == 1 # ---- hook install ---- @@ -169,7 +172,9 @@ def test_hook_install_prime_idempotent(tmp_path, monkeypatch): assert len(prime_hooks) == 1 -def test_prime_outputs_instructions(): +def test_prime_outputs_instructions(tmp_path, monkeypatch): + _setup_db(tmp_path) + monkeypatch.chdir(tmp_path) result = runner.invoke(app, ["prime"]) assert result.exit_code == 0 assert "loci search" in result.output @@ -196,3 +201,133 @@ def test_hook_install_merges_existing_settings(tmp_path, monkeypatch): # 既存設定が保持されている assert data.get("model") == "opus" assert "hooks" in data + + +# ---- hook install atomic + backup ---- + + +def test_write_settings_atomic_bak(tmp_path, monkeypatch): + """install 時に既存 settings.json を .bak にバックアップする""" + monkeypatch.setattr("codeatrium.hooks.Path.home", lambda: tmp_path) + settings_path = tmp_path / ".claude" / "settings.json" + settings_path.parent.mkdir(parents=True) + settings_path.write_text(json.dumps({"model": "opus"})) + + result = runner.invoke(app, ["hook", "install"]) + assert result.exit_code == 0 + + bak_path = settings_path.with_suffix(".json.bak") + assert bak_path.exists() + bak_data = json.loads(bak_path.read_text()) + assert bak_data.get("model") == "opus" + + +def test_write_settings_failure_keeps_original_intact(tmp_path, monkeypatch): + """書き込み失敗(例外注入)時に元 settings.json が無傷であることを確認""" + monkeypatch.setattr("codeatrium.hooks.Path.home", lambda: tmp_path) + settings_path = tmp_path / ".claude" / "settings.json" + settings_path.parent.mkdir(parents=True) + initial_content = {"model": "opus", "existing": True} + settings_path.write_text(json.dumps(initial_content)) + + # os.replace を例外を投げる mock に patch する + with patch("codeatrium.hooks.os.replace", side_effect=OSError("disk full")): + from codeatrium.hooks import install_hooks + # install_hooks() が OSError を送出することを確認 + with pytest.raises(OSError): + install_hooks() + + # 元の settings.json が無傷であることを assert + assert settings_path.exists() + original_data = json.loads(settings_path.read_text()) + assert original_data == initial_content + + +def test_write_settings_atomic_no_bak_when_missing(tmp_path, monkeypatch): + """settings.json が存在しない場合は .bak は作成されない""" + monkeypatch.setattr("codeatrium.hooks.Path.home", lambda: tmp_path) + settings_path = tmp_path / ".claude" / "settings.json" + + result = runner.invoke(app, ["hook", "install"]) + assert result.exit_code == 0 + + bak_path = settings_path.with_suffix(".json.bak") + assert not bak_path.exists() + + +# ---- hook uninstall ---- + + +def test_hook_uninstall_removes_codeatrium_hooks(tmp_path, monkeypatch): + """uninstall は codeatrium フックを削除する""" + monkeypatch.setattr("codeatrium.hooks.Path.home", lambda: tmp_path) + settings_path = tmp_path / ".claude" / "settings.json" + + runner.invoke(app, ["hook", "install"]) + result = runner.invoke(app, ["hook", "uninstall"]) + assert result.exit_code == 0 + + data = json.loads(settings_path.read_text()) + # hooks がないか、Stop/SessionStart/SessionEnd に loci コマンドを含むエントリが無いこと + if "hooks" in data: + for hook_type in ["Stop", "SessionStart", "SessionEnd"]: + if hook_type in data["hooks"]: + entries = data["hooks"][hook_type] + for entry in entries: + for h in entry.get("hooks", []): + assert "loci" not in h.get("command", "") + + +def test_hook_uninstall_preserves_user_hooks(tmp_path, monkeypatch): + """uninstall はユーザーフックを保持する""" + monkeypatch.setattr("codeatrium.hooks.Path.home", lambda: tmp_path) + settings_path = tmp_path / ".claude" / "settings.json" + settings_path.parent.mkdir(parents=True) + settings_path.write_text( + json.dumps({ + "hooks": { + "Stop": [{"hooks": [{"type": "command", "command": "my-tool run"}]}] + } + }) + ) + + runner.invoke(app, ["hook", "install"]) + runner.invoke(app, ["hook", "uninstall"]) + + data = json.loads(settings_path.read_text()) + assert "Stop" in data["hooks"] + stop_entries = data["hooks"]["Stop"] + all_commands = [h for entry in stop_entries for h in entry.get("hooks", [])] + assert any("my-tool run" in h.get("command", "") for h in all_commands) + + +def test_hook_uninstall_idempotent(tmp_path, monkeypatch): + """uninstall は複数回実行しても安全(べき等)""" + monkeypatch.setattr("codeatrium.hooks.Path.home", lambda: tmp_path) + + # install なしで直接 uninstall + result1 = runner.invoke(app, ["hook", "uninstall"]) + assert result1.exit_code == 0 + assert "Nothing to uninstall" in result1.output or "No" in result1.output + + # 2回目も同じ + result2 = runner.invoke(app, ["hook", "uninstall"]) + assert result2.exit_code == 0 + assert "Nothing to uninstall" in result2.output or "No" in result2.output + + +def test_hook_uninstall_empty_matcher_removed(tmp_path, monkeypatch): + """uninstall 後、空の matcher を持つエントリは削除される""" + monkeypatch.setattr("codeatrium.hooks.Path.home", lambda: tmp_path) + settings_path = tmp_path / ".claude" / "settings.json" + + runner.invoke(app, ["hook", "install"]) + runner.invoke(app, ["hook", "uninstall"]) + + data = json.loads(settings_path.read_text()) + if "hooks" in data and "SessionStart" in data["hooks"]: + entries = data["hooks"]["SessionStart"] + for entry in entries: + # 各エントリは空でない hooks を持つこと + hooks = entry.get("hooks", []) + assert len(hooks) > 0