diff --git a/AGENT.md b/AGENT.md index 9482997..f6b1a28 100644 --- a/AGENT.md +++ b/AGENT.md @@ -75,7 +75,7 @@ Columns: `workflow_id`, `workflow_name`, `stages_state` (JSON), `status`, `creat * **Datetime:** All timestamps use `datetime.now(timezone.utc)` (timezone-aware). `datetime.utcnow()` is deprecated in Python 3.12+ and must not be re-introduced. * **Log path propagation pattern:** When an executor knows the log path at submit time, it writes it to `task_spec.metadata["log_path"]`. The runner reads this after `submit_with_retry()` and passes it to `db.update_status(..., log_path=...)`. The runner then retrieves `record.log_path` and passes it as `executor.logs(remote_id, log_path=record.log_path)`. All `logs()` implementations accept the optional `log_path` kwarg. * **set_e passthrough:** Tasks that need `set -x` without `set -e` (e.g., retry loops) set `metadata["set_e"] = False`. `SlurmExecutor.submit()` reads this and passes it to `generate_sbatch_script(set_e=...)`. The default is `True` (preserving `set -ex` for all existing tasks). -* **Workflow OmegaConf overrides:** `workflow run` resolves `${params.X}` interpolations via OmegaConf. Merge order: YAML base → `--from-job` extracted params → CLI trailing overrides. The `_PARAM_MAPPING` dict in `WorkflowRunner.extract_workflow_params()` maps task-level param keys to workflow-level dotlist keys. +* **Workflow OmegaConf overrides:** `workflow run` resolves `${params.X}` interpolations via OmegaConf. Merge order: YAML base → `--from-job` extracted params → CLI trailing overrides. From-job params use `OmegaConf.merge()` (dict-only keys), while CLI overrides use `OmegaConf.update()` per-key to correctly handle list-indexed paths like `stages.0.params.X`. The `_PARAM_MAPPING` dict in `WorkflowRunner.extract_workflow_params()` maps task-level param keys to workflow-level dotlist keys. * **Workflow placeholder validation:** Required params use `` markers. `WorkflowRunner._validate_no_placeholders()` matches `^$` and raises `ValueError` listing all unfilled fields before any submission. * **Workflow detach pattern:** `run_detached()` validates and creates the DB record synchronously, then forks via `subprocess.Popen([sys.executable, "-m", "devrun.workflow", "--state-file", ...])` with `start_new_session=True`. The child process drives the heartbeat loop on the pre-created record. @@ -125,7 +125,7 @@ python -m pytest tests/ -v ### Test Coverage -- **713 tests passing**, **10 skipped** (infrastructure-dependent: require real SSH/Slurm connectivity) +- **728 tests passing**, **10 skipped** (infrastructure-dependent: require real SSH/Slurm connectivity) - Unit tests for all major components (models, registry, database, router, runner, tasks, executors, workflow engine) - Integration tests between modules - End-to-end workflow tests diff --git a/devrun/cli.py b/devrun/cli.py index 67f23e1..7462d54 100644 --- a/devrun/cli.py +++ b/devrun/cli.py @@ -4,7 +4,6 @@ import json import logging -import sys from pathlib import Path from typing import Optional @@ -446,15 +445,88 @@ def fetch( app.add_typer(workflow_app, name="workflow") +def _show_workflow_help(target: str) -> None: + """Show help for a specific workflow based on its configuration.""" + from devrun.runner import load_merged_config + from rich.panel import Panel + from rich.text import Text + + try: + raw = load_merged_config(target) + except FileNotFoundError: + console.print(f"[red]Error:[/red] No config found for workflow '{target}'.") + console.print("Ensure the workflow config exists in one of the config search directories.") + raise typer.Exit(code=1) + except Exception as e: + console.print(f"[red]Failed to load configuration for '{target}':[/red] {e}") + raise typer.Exit(code=1) + + workflow_name = raw.get("workflow", target) + + console.print(Panel(f"Workflow: [bold cyan]{workflow_name}[/bold cyan] (config: {target})", expand=False)) + console.print() + + # Workflow-level params + params = raw.get("params", {}) + if params: + param_table = Table(title="Workflow Parameters", show_edge=False, title_justify="left", header_style="bold cyan") + param_table.add_column("Override") + param_table.add_column("Default Value") + + for k, v in params.items(): + val_str = str(v) + if val_str.startswith("<") and val_str.endswith(">"): + val_str = f"[yellow]{val_str}[/yellow]" + param_table.add_row(f"params.[bold]{k}[/bold]", val_str) + + console.print(param_table) + console.print() + + # Stages + stages = raw.get("stages", []) + if stages: + stage_table = Table(title="Stages", show_edge=False, title_justify="left", header_style="bold cyan") + stage_table.add_column("Name") + stage_table.add_column("Task", style="cyan") + stage_table.add_column("Executor", style="green") + stage_table.add_column("Depends On", style="dim") + + for s in stages: + deps = s.get("depends_on", None) + if isinstance(deps, list): + deps_str = ", ".join(deps) + elif deps: + deps_str = str(deps) + else: + deps_str = "—" + stage_table.add_row(s.get("name", "?"), s.get("task", "?"), s.get("executor", "?"), deps_str) + + console.print(stage_table) + console.print() + + # Usage example + console.print("[dim]Usage Example:[/dim]") + example_cmd = Text("devrun workflow run ", style="bold") + example_cmd.append(target, style="bold cyan") + if params: + first_param = next(iter(params.keys())) + example_cmd.append(f" params.{first_param}=value", style="green") + console.print(example_cmd) + + @workflow_app.command( "run", - context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, + context_settings={"allow_extra_args": True, "ignore_unknown_options": True, "help_option_names": []}, ) def workflow_run( ctx: typer.Context, - target: str = typer.Argument(..., help="Workflow config path, name, or name/variation"), + target: Optional[str] = typer.Argument(None, help="Workflow config path, name, or name/variation"), dry_run: bool = typer.Option(False, "--dry-run", help="Show execution plan without submitting"), + start_after: Optional[str] = typer.Option(None, "--start-after", help="Skip this stage and its dependencies, start from the next"), + from_job: Optional[str] = typer.Option(None, "--from-job", help="Extract workflow params from an existing job"), + detach: bool = typer.Option(False, "--detach", "-d", help="Run workflow in background, return immediately"), verbose: bool = typer.Option(False, "--verbose", "-v"), + help: bool = typer.Option(False, "--help", "-h", help="Show this message and exit."), ) -> None: """Run a multi-stage workflow from a YAML config. @@ -462,38 +534,110 @@ def workflow_run( or name/variation. Configs are resolved through the same hierarchical search path as task configs. Trailing arguments are OmegaConf overrides. """ + if help: + if not target: + console.print(ctx.get_help()) + raise typer.Exit() + else: + _show_workflow_help(target) + raise typer.Exit() + + if not target: + console.print("[red]Missing argument 'TARGET'.[/red]\n") + console.print(ctx.get_help()) + raise typer.Exit(code=2) + _setup_logging(verbose) - from devrun.runner import load_merged_config + from omegaconf import OmegaConf + from devrun.runner import find_configs + import devrun.keystore # noqa: F401 — registers ${key:…} resolver + import devrun.presets # noqa: F401 — registers ${preset:…} resolver from devrun.models import WorkflowConfig + from devrun.workflow import WorkflowRunner - overrides = ctx.args - if overrides: - console.print(f"[dim]Using overrides: {overrides}[/dim]") + runner = WorkflowRunner() + task_name: Optional[str] = None + # Merge order: YAML base (hierarchical) → from-job params → CLI overrides (highest priority) try: - raw = load_merged_config(target, overrides=overrides) + config_paths = find_configs(target) except FileNotFoundError: console.print(f"[red]Error:[/red] Config not found for '{target}'.") console.print("Ensure the workflow config exists in one of the config search directories.") raise typer.Exit(code=1) try: - cfg = WorkflowConfig(**raw) + raw_cfg = OmegaConf.load(config_paths[0]) + for extra_path in config_paths[1:]: + raw_cfg = OmegaConf.merge(raw_cfg, OmegaConf.load(extra_path)) + + if from_job: + try: + job_params, task_name = runner.extract_workflow_params(from_job) + except ValueError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(code=1) + if job_params: + console.print(f"[dim]From job {from_job}: {list(job_params.keys())}[/dim]") + job_overrides = [f"{k}={v}" for k, v in job_params.items()] + raw_cfg = OmegaConf.merge(raw_cfg, OmegaConf.from_dotlist(job_overrides)) + + if ctx.args: + console.print(f"[dim]Using overrides: {ctx.args}[/dim]") + for arg in ctx.args: + key, _, value = arg.partition("=") + if key and _ == "=": + # Parse value type (e.g. "30" → int, "true" → bool) + # so numeric/boolean overrides aren't stored as strings. + parsed = yaml.safe_load(value) + OmegaConf.update(raw_cfg, key, parsed) + else: + console.print(f"[yellow]Warning:[/yellow] ignoring malformed override: {arg}") + + resolved = OmegaConf.to_container(raw_cfg, resolve=True) + except typer.Exit: + raise except Exception as exc: - console.print(f"[red]Error parsing workflow config:[/red] {exc}") + console.print(f"[red]Error loading/resolving workflow config:[/red] {exc}") raise typer.Exit(code=1) - from devrun.workflow import WorkflowRunner + try: + cfg = WorkflowConfig(**resolved) + except Exception as exc: + console.print(f"[red]Error parsing workflow config:[/red] {exc}") + raise typer.Exit(code=1) - runner = WorkflowRunner() - result = runner.run(cfg, dry_run=dry_run) + # Auto-detect stage to skip when --from-job is used without --start-after + if from_job and not start_after and task_name is not None: + detected_stage = runner.detect_stage_for_task(task_name, cfg) + if detected_stage: + start_after = detected_stage + console.print( + f"[dim]Auto-detected: skipping stage '{detected_stage}' " + f"based on job task type '{task_name}'[/dim]" + ) - if dry_run: - console.print(result) - console.print("[yellow]Dry-run complete. No jobs were submitted.[/yellow]") - else: - console.print(f"[green]Workflow completed:[/green] {result}") + try: + if detach: + if dry_run: + console.print("[red]Error:[/red] --detach and --dry-run cannot be used together.") + raise typer.Exit(code=1) + wf_id = runner.run_detached(cfg, start_after=start_after) + console.print( + f"[green]Workflow {wf_id} started in background.[/green]\n" + f"Use [bold]devrun workflow status {wf_id}[/bold] to monitor." + ) + else: + result = runner.run(cfg, dry_run=dry_run, start_after=start_after) + if dry_run: + console.print(result) + console.print("[yellow]Dry-run complete. No jobs were submitted.[/yellow]") + else: + console.print(f"[green]Workflow completed:[/green] {result}") + except ValueError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(code=1) @workflow_app.command("status") diff --git a/devrun/configs/swe_bench_workflow/default.yaml b/devrun/configs/swe_bench_workflow/default.yaml index 57d113b..48d2b8b 100644 --- a/devrun/configs/swe_bench_workflow/default.yaml +++ b/devrun/configs/swe_bench_workflow/default.yaml @@ -1,14 +1,26 @@ # configs/swe_bench_workflow/default.yaml # Full SWE-bench pipeline: inference → collect → evaluate +# +# Usage: +# devrun workflow run devrun/configs/swe_bench_workflow/default.yaml \ +# params.model_name=openai/gpt-4 \ +# params.dataset=/data/swebench/SWE-bench_Lite \ +# params.working_dir=/home/user/project +# +# Run collect+eval from existing inference: +# devrun workflow run devrun/configs/swe_bench_workflow/default.yaml \ +# --from-job +# +# See docs/swe-bench-workflow-guide.md for full documentation. workflow: swe_bench params: - model_name: "" # REQUIRED: model identifier - dataset: "" # REQUIRED: absolute path to dataset + model_name: "" + dataset: "" split: test run_name: "run1" output_dir: "logs/run1" - working_dir: "" # REQUIRED: remote project root + working_dir: "" stages: - name: inference diff --git a/devrun/workflow.py b/devrun/workflow.py index 77e1f63..36fcebc 100644 --- a/devrun/workflow.py +++ b/devrun/workflow.py @@ -3,6 +3,11 @@ import json import logging +import os +import re +import subprocess +import sys +import tempfile import time from datetime import datetime, timezone from pathlib import Path @@ -27,25 +32,140 @@ def __init__( self._db = JobStore(db_path) self._executors_path = executors_path - def run(self, config: WorkflowConfig, dry_run: bool = False) -> str: + # Statuses that satisfy downstream dependency checks. + _SATISFIED_STATUSES = frozenset({"completed", "skipped_by_user"}) + + def run( + self, + config: WorkflowConfig, + dry_run: bool = False, + start_after: str | None = None, + ) -> str: """Execute (or dry-run) a workflow. Returns the workflow_id for real runs, or a plan string for dry_run. """ stages_by_name = {s.name: s for s in config.stages} + # Validate start_after early (applies to both dry-run and real runs) + skip_set: set[str] = set() + if start_after: + skip_set = self._compute_skip_set(start_after, stages_by_name) + if dry_run: - return self._dry_run(config, stages_by_name) + return self._dry_run(config, stages_by_name, skip_set=skip_set) + + # Fail fast on unfilled placeholders + self._validate_no_placeholders(config) # Initialise per-stage state stages_state: dict[str, dict[str, Any]] = {} for stage in config.stages: - stages_state[stage.name] = {"status": "pending", "job_id": None} + stages_state[stage.name] = {"status": "pending", "remote_job_id": None, "db_job_id": None} + + # Pre-mark skipped stages + if skip_set: + for name in skip_set: + stages_state[name] = {"status": "skipped_by_user", "remote_job_id": None, "db_job_id": None} + logger.info("Stage %s skipped (--start-after %s)", name, start_after) wf_id = self._db.insert_workflow(config.workflow, stages_state) self._db.update_workflow(wf_id, status="running") logger.info("Workflow %s started: %s", wf_id, config.workflow) + return self._heartbeat_loop(wf_id, config, stages_state) + + def run_detached( + self, + config: WorkflowConfig, + start_after: str | None = None, + ) -> str: + """Start a workflow in a background process and return immediately. + + Returns the workflow_id. The background process runs the heartbeat + loop and updates the DB. Use ``status()`` / ``logs()`` to monitor. + """ + stages_by_name = {s.name: s for s in config.stages} + + # Validate early so the user gets errors before we fork + self._validate_no_placeholders(config) + skip_set: set[str] = set() + if start_after: + skip_set = self._compute_skip_set(start_after, stages_by_name) + + # Create workflow record so the caller can query it immediately + stages_state: dict[str, dict[str, Any]] = {} + for stage in config.stages: + stages_state[stage.name] = {"status": "pending", "remote_job_id": None, "db_job_id": None} + if skip_set: + for name in skip_set: + stages_state[name] = {"status": "skipped_by_user", "remote_job_id": None, "db_job_id": None} + + wf_id = self._db.insert_workflow(config.workflow, stages_state) + self._db.update_workflow(wf_id, status="pending") + + # Serialise state for the child process. + # skip_set is already applied to stages_state above, so the child + # process does not need start_after — it just drives the heartbeat. + state: dict[str, Any] = { + "config": config.model_dump(mode="json"), + "workflow_id": wf_id, + "db_path": str(self._db._db_path), + } + if self._executors_path is not None: + state["executors_path"] = str(self._executors_path) + fd, state_path = tempfile.mkstemp(prefix="devrun_wf_", suffix=".json") + with os.fdopen(fd, "w") as fh: + json.dump(state, fh) + + # Ensure log directory exists + log_dir = Path.home() / ".devrun" / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / f"workflow_{wf_id}.log" + + with open(log_file, "w") as log_fh: + subprocess.Popen( + [sys.executable, "-m", "devrun.workflow", "--state-file", state_path], + start_new_session=True, + stdout=log_fh, + stderr=subprocess.STDOUT, + close_fds=True, + ) + + logger.info( + "Workflow %s detached (PID forked). Log: %s", wf_id, log_file, + ) + return wf_id + + def _run_existing( + self, + wf_id: str, + config: WorkflowConfig, + ) -> str: + """Drive the heartbeat loop on an already-created workflow record. + + Used by the detached background process. The workflow record and + initial stages_state already exist in the DB (created by + ``run_detached``). The skip_set has already been applied to + stages_state before serialisation, so no ``start_after`` is needed here. + """ + record = self._db.get_workflow(wf_id) + if not record: + raise ValueError(f"Workflow {wf_id} not found in DB") + stages_state: dict[str, dict[str, Any]] = json.loads(record["stages_state"]) + + self._db.update_workflow(wf_id, status="running") + logger.info("Workflow %s resumed (detached): %s", wf_id, config.workflow) + + return self._heartbeat_loop(wf_id, config, stages_state) + + def _heartbeat_loop( + self, + wf_id: str, + config: WorkflowConfig, + stages_state: dict[str, dict[str, Any]], + ) -> str: + """Core heartbeat polling loop shared by ``run()`` and ``_run_existing()``.""" start_time = time.monotonic() try: @@ -67,9 +187,9 @@ def run(self, config: WorkflowConfig, dry_run: bool = False) -> str: for stage in config.stages: state = stages_state[stage.name] - if state["status"] in ("completed", "skipped"): + if state["status"] in ("completed", "skipped", "skipped_by_user"): continue - if state["status"] == "failed": + if state["status"] in ("failed", "cancelled"): any_failed = True continue @@ -80,7 +200,8 @@ def run(self, config: WorkflowConfig, dry_run: bool = False) -> str: if isinstance(deps, str): deps = [deps] deps_met = all( - stages_state[d]["status"] == "completed" for d in deps + stages_state[d]["status"] in self._SATISFIED_STATUSES + for d in deps ) if not deps_met: if any( @@ -93,12 +214,14 @@ def run(self, config: WorkflowConfig, dry_run: bool = False) -> str: continue try: - job_id = self._submit_stage(stage.name, stage) + db_job_id, remote_id = self._submit_stage(stage.name, stage) state["status"] = "submitted" - state["job_id"] = job_id + state["remote_job_id"] = remote_id + state["db_job_id"] = db_job_id state["executor"] = stage.executor logger.info( - "Stage %s submitted: job_id=%s", stage.name, job_id + "Stage %s submitted: remote_job_id=%s, db_job_id=%s", + stage.name, remote_id, db_job_id, ) except Exception: logger.exception("Stage %s failed to submit", stage.name) @@ -106,39 +229,51 @@ def run(self, config: WorkflowConfig, dry_run: bool = False) -> str: any_failed = True elif state["status"] in ("submitted", "running"): - job_id = state["job_id"] - poll_status = self._poll_job_status(job_id, stage.executor) + remote_id = state["remote_job_id"] + poll_status = self._poll_job_status(remote_id, stage.executor) if poll_status == "completed": state["status"] = "completed" logger.info("Stage %s completed", stage.name) + db_jid = state.get("db_job_id") + if db_jid: + self._db.update_status( + db_jid, JobStatus.COMPLETED, + completed_at=datetime.now(timezone.utc), + ) elif poll_status == "failed": state["status"] = "failed" any_failed = True logger.error("Stage %s failed", stage.name) + db_jid = state.get("db_job_id") + if db_jid: + self._db.update_status( + db_jid, JobStatus.FAILED, + completed_at=datetime.now(timezone.utc), + ) elif poll_status == "running": state["status"] = "running" self._db.update_workflow(wf_id, stages_state=stages_state) - if all_done: + if any_failed: self._db.update_workflow( wf_id, - status="completed", + status="failed", stages_state=stages_state, completed_at=datetime.now(timezone.utc), ) - logger.info("Workflow %s completed successfully", wf_id) + logger.error("Workflow %s failed", wf_id) return wf_id - if any_failed: + if all_done: self._db.update_workflow( wf_id, - status="failed", + status="completed", stages_state=stages_state, completed_at=datetime.now(timezone.utc), ) - logger.error("Workflow %s failed", wf_id) + logger.info("Workflow %s completed successfully", wf_id) return wf_id time.sleep(config.heartbeat_interval) @@ -155,8 +290,61 @@ def run(self, config: WorkflowConfig, dry_run: bool = False) -> str: # -- internal helpers --------------------------------------------------- - def _submit_stage(self, stage_name: str, stage: WorkflowStage) -> str: - """Submit a single stage: resolve task + executor, prepare, submit.""" + @staticmethod + def _compute_skip_set( + start_after: str, stages_by_name: dict[str, WorkflowStage] + ) -> set[str]: + """Return the set of stage names to skip: *start_after* plus its transitive deps.""" + if start_after not in stages_by_name: + raise ValueError( + f"--start-after stage '{start_after}' does not exist. " + f"Available stages: {sorted(stages_by_name)}" + ) + skip: set[str] = set() + queue = [start_after] + while queue: + name = queue.pop() + if name in skip: + continue + skip.add(name) + deps = stages_by_name[name].depends_on or [] + if isinstance(deps, str): + deps = [deps] + queue.extend(deps) + return skip + + @staticmethod + def _validate_no_placeholders(config: WorkflowConfig) -> None: + """Raise ``ValueError`` if any param still contains a ```` marker.""" + pattern = re.compile(r"^$") + unfilled: list[str] = [] + + def _check(prefix: str, mapping: dict[str, Any]) -> None: + for key, val in mapping.items(): + if isinstance(val, str) and pattern.match(val): + unfilled.append(f" {prefix}.{key}: {val}") + elif isinstance(val, dict): + _check(f"{prefix}.{key}", val) + + _check("params", config.params) + for stage in config.stages: + _check(f"stages.{stage.name}.params", stage.params) + + if unfilled: + lines = ["Workflow config has unfilled required parameters:"] + lines.extend(unfilled) + lines.append("") + lines.append("Set them via CLI overrides:") + lines.append( + " devrun workflow run config.yaml params.model_name=mymodel params.dataset=/path/to/data" + ) + raise ValueError("\n".join(lines)) + + def _submit_stage(self, stage_name: str, stage: WorkflowStage) -> tuple[str, str]: + """Submit a single stage: resolve task + executor, prepare, submit. + + Returns (db_job_id, remote_job_id). + """ task_cls = get_task_class(stage.task) task = task_cls() task_spec = task.prepare(stage.params) @@ -165,18 +353,18 @@ def _submit_stage(self, stage_name: str, stage: WorkflowStage) -> str: remote_id = executor.submit(task_spec) # Record in jobs table - job_id = self._db.insert( + db_job_id = self._db.insert( task_name=stage.task, executor=stage.executor, parameters=stage.params, ) self._db.update_status( - job_id, + db_job_id, JobStatus.SUBMITTED, remote_job_id=remote_id, log_path=task_spec.metadata.get("log_path"), ) - return remote_id + return db_job_id, remote_id def _poll_job_status(self, job_id: str, executor_name: str) -> str: """Check the live status of a submitted job.""" @@ -193,31 +381,118 @@ def _poll_job_status(self, job_id: str, executor_name: str) -> str: return "running" # treat unknown as still running def _dry_run( - self, config: WorkflowConfig, stages_by_name: dict[str, WorkflowStage] + self, + config: WorkflowConfig, + stages_by_name: dict[str, WorkflowStage], + skip_set: set[str] | None = None, ) -> str: """Print the full execution plan without submitting anything.""" - lines = [f"Workflow: {config.workflow}", f"Timeout: {config.timeout}s", ""] + skip_set = skip_set or set() + timeout_h = config.timeout / 3600 + lines = [ + f"Workflow: {config.workflow}", + f"Timeout: {config.timeout:.0f}s ({timeout_h:.0f}h)", + "", + ] + will_run_count = 0 for i, stage in enumerate(config.stages, 1): + skipped = stage.name in skip_set + tag = " [SKIPPED — start-after]" if skipped else " [WILL RUN]" + lines.append(f"Stage {i}: {stage.name}{tag}") + lines.append(f" Task: {stage.task}") + lines.append(f" Executor: {stage.executor}") deps = stage.depends_on or [] if isinstance(deps, str): deps = [deps] + lines.append(f" Depends on: {', '.join(deps) if deps else '(none)'}") + if skipped: + lines.append("") + continue + will_run_count += 1 task_cls = get_task_class(stage.task) task = task_cls() task_spec = task.prepare(stage.params) - lines.append(f"Stage {i}: {stage.name}") - lines.append(f" Task: {stage.task}") - lines.append(f" Executor: {stage.executor}") - lines.append(f" Depends on: {deps or '(none)'}") lines.append(f" Working dir: {task_spec.working_dir or '(default)'}") - lines.append(f" Command preview (first 200 chars):") - lines.append(f" {task_spec.command[:200]}...") + # Show key params (up to 5) + if stage.params: + param_items = list(stage.params.items())[:5] + param_str = ", ".join(f"{k}={v}" for k, v in param_items) + if len(stage.params) > 5: + param_str += f", ... (+{len(stage.params) - 5} more)" + lines.append(f" Params: {param_str}") + lines.append(f" Command preview (first 500 chars):") + lines.append(f" {task_spec.command[:500]}") lines.append("") + if skip_set: + lines.append( + f"Summary: {len(skip_set)} stage(s) skipped, " + f"{will_run_count} stage(s) will run" + ) plan = "\n".join(lines) logger.info("Dry-run plan:\n%s", plan) return plan # -- public query methods ------------------------------------------------ + def extract_workflow_params(self, job_id: str) -> tuple[dict[str, str], str]: + """Extract workflow-level params from an existing job record. + + Returns (dotlist_dict, task_name) where dotlist_dict has keys like + ``"params.model_name"`` suitable for OmegaConf merging. + + Mapping priority: explicit _PARAM_MAPPING entries take precedence + (allowing key renaming), then any remaining job params are mapped + generically as ``params.{key}``. + """ + record = self._db.get(job_id) + if record is None: + raise ValueError( + f"Job '{job_id}' not found. Use `devrun history` to find job IDs." + ) + job_params = record.params_dict + + # Map task-specific param names → workflow-level param names + _PARAM_MAPPING: dict[str, str] = { + "model_name": "params.model_name", + "dataset": "params.dataset", + "split": "params.split", + "output_dir": "params.output_dir", + "working_dir": "params.working_dir", + "run_name": "params.run_name", + } + + dotlist: dict[str, str] = {} + mapped_keys: set[str] = set() + for job_key, workflow_key in _PARAM_MAPPING.items(): + if job_key in job_params and job_params[job_key]: + dotlist[workflow_key] = str(job_params[job_key]) + mapped_keys.add(job_key) + + # Generic fallback for unmapped params (skip known-sensitive keys) + _SENSITIVE_KEYS = frozenset({"api_key", "token", "secret", "password", "credentials"}) + for key, val in job_params.items(): + wf_key = f"params.{key}" + if key not in mapped_keys and wf_key not in dotlist and val: + if key in _SENSITIVE_KEYS: + logger.debug("Skipping sensitive param %s from generic fallback", key) + continue + dotlist[wf_key] = str(val) + + logger.info( + "Extracted %d params from job %s (%s): %s", + len(dotlist), job_id, record.task_name, list(dotlist.keys()), + ) + return dotlist, record.task_name + + def detect_stage_for_task( + self, task_name: str, config: WorkflowConfig + ) -> str | None: + """Find the stage name whose task type matches *task_name*.""" + for stage in config.stages: + if stage.task == task_name: + return stage.name + return None + def status(self, workflow_id: str) -> dict[str, Any] | None: """Return workflow record or None.""" return self._db.get_workflow(workflow_id) @@ -232,20 +507,20 @@ def cancel(self, workflow_id: str) -> None: raise ValueError(f"Workflow {workflow_id} not found") stages_state = json.loads(record["stages_state"]) for name, state in stages_state.items(): - if state["status"] in ("submitted", "running") and state.get("job_id"): - logger.info("Cancelling stage %s (job %s)", name, state["job_id"]) - # Cancel the actual remote job if executor info is available + remote_id = state.get("remote_job_id") + if state["status"] in ("submitted", "running") and remote_id: + logger.info("Cancelling stage %s (remote_job_id %s)", name, remote_id) executor_name = state.get("executor") if executor_name: try: executor = resolve_executor( executor_name, executors_path=self._executors_path ) - executor.cancel(state["job_id"]) + executor.cancel(remote_id) except Exception: logger.warning( "Failed to cancel remote job %s for stage %s", - state["job_id"], + remote_id, name, exc_info=True, ) @@ -258,19 +533,86 @@ def cancel(self, workflow_id: str) -> None: ) def logs(self, workflow_id: str, stage: str | None = None) -> str: - """Retrieve logs summary for a workflow or specific stage.""" + """Retrieve logs for a workflow or specific stage. + + For specific stages with executor info, delegates to the executor's + ``logs()`` method. Falls back to a status summary otherwise. + For detached workflows, appends the background process log. + """ record = self._db.get_workflow(workflow_id) if not record: raise ValueError(f"Workflow {workflow_id} not found") stages_state = json.loads(record["stages_state"]) + if stage: state = stages_state.get(stage) - if not state or not state.get("job_id"): + if not state or not state.get("remote_job_id"): return f"No logs available for stage '{stage}'" - return f"Stage {stage}: job_id={state['job_id']}, status={state['status']}" + remote_id = state["remote_job_id"] + # Try to delegate to executor for real logs + executor_name = state.get("executor") + if executor_name: + try: + executor = resolve_executor( + executor_name, executors_path=self._executors_path + ) + return executor.logs(remote_id) + except Exception: + logger.debug( + "Could not fetch executor logs for stage %s, falling back", + stage, exc_info=True, + ) + return f"Stage {stage}: remote_job_id={remote_id}, status={state['status']}" + lines = [] for name, state in stages_state.items(): lines.append( - f"{name}: status={state['status']}, job_id={state.get('job_id', 'N/A')}" + f"{name}: status={state['status']}, remote_job_id={state.get('remote_job_id', 'N/A')}" ) + + # Append background process log for detached workflows + bg_log = Path.home() / ".devrun" / "logs" / f"workflow_{workflow_id}.log" + if bg_log.exists(): + log_text = bg_log.read_text().strip() + if log_text: + lines.append("") + lines.append(f"--- Background process log ({bg_log}) ---") + lines.append(log_text) + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Background process entry point (used by run_detached) +# --------------------------------------------------------------------------- + + +def _run_from_state_file(state_path: str) -> None: + """Entry point for ``python -m devrun.workflow --state-file ``.""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + path = Path(state_path) + state = json.loads(path.read_text()) + # Clean up state file now that we've read it + path.unlink(missing_ok=True) + + config = WorkflowConfig(**state["config"]) + wf_id: str = state["workflow_id"] + + runner = WorkflowRunner( + db_path=state.get("db_path"), + executors_path=state.get("executors_path"), + ) + runner._run_existing(wf_id, config) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Detached workflow runner") + parser.add_argument("--state-file", required=True, help="Path to serialised state JSON") + args = parser.parse_args() + _run_from_state_file(args.state_file) diff --git a/docs/swe-bench-workflow-guide.md b/docs/swe-bench-workflow-guide.md new file mode 100644 index 0000000..5ddd67a --- /dev/null +++ b/docs/swe-bench-workflow-guide.md @@ -0,0 +1,379 @@ +# SWE-bench Workflow Guide + +Run the full SWE-bench evaluation pipeline -- inference, result collection, and patch evaluation -- as a single orchestrated workflow. + +## Overview + +The SWE-bench workflow chains three stages: + +``` +inference (swe_bench_agentic) + ↓ Slurm array job: runs an LLM agent against benchmark instances +collect (swe_bench_collect) + ↓ SSH command: aggregates per-instance results into predictions.jsonl +evaluate (swe_bench_eval) + Slurm job: validates produced patches with swebench.harness +``` + +Each stage depends on the previous one. The workflow engine handles dependency resolution, status polling, and failure propagation automatically via a heartbeat loop. + +## Quick Start + +### 1. Configure your executor + +Ensure `executors.yaml` (or your project/user config layer) defines both a `slurm` and `ssh` executor pointing to your compute cluster: + +```yaml +# executors.yaml +slurm: + type: slurm + host: cluster.example.com + user: myuser + python_env: + type: conda + name: openhands + +ssh: + type: ssh + host: cluster.example.com + user: myuser +``` + +### 2. Run the full pipeline + +```bash +devrun workflow run devrun/configs/swe_bench_workflow/default.yaml \ + params.model_name=openai/gpt-4 \ + params.dataset=/data/SWE-bench_Verified \ + params.working_dir=/remote/project/root +``` + +The three required parameters (`model_name`, `dataset`, `working_dir`) must be set -- either in a config overlay or as CLI overrides. All other parameters have sensible defaults. + +The workflow blocks in the foreground, printing stage transitions as they happen. Use `--detach` (see [Background Execution](#background-execution)) to return immediately. + +### 3. Monitor progress + +```bash +# Check stage-by-stage status +devrun workflow status + +# View logs for a specific stage +devrun workflow logs --stage inference + +# List recent workflows +devrun workflow list +``` + +## Running from Existing Inference + +If you already have a completed `swe_bench_agentic` job (inference is done and outputs exist on the remote host), you can launch only the collect and evaluate stages. + +### Using `--from-job` + +The simplest approach: pass the existing job's ID. The workflow engine extracts all relevant parameters (model name, dataset, output directory, etc.) from the job record and auto-detects which stage to skip. + +```bash +# Find the job ID of your completed inference run +devrun history + +# Launch collect + eval using that job's parameters +devrun workflow run devrun/configs/swe_bench_workflow/default.yaml \ + --from-job abc123 +``` + +What happens: +1. Parameters are extracted from job `abc123` (model_name, dataset, split, output_dir, working_dir, run_name). +2. The job's task type (`swe_bench_agentic`) is matched to the `inference` stage. +3. `--start-after inference` is automatically applied -- the inference stage is skipped, and the workflow begins at `collect`. + +You can still apply CLI overrides on top of the extracted parameters. CLI overrides take highest priority: + +```bash +devrun workflow run devrun/configs/swe_bench_workflow/default.yaml \ + --from-job abc123 \ + params.working_dir=/different/remote/path +``` + +### Using `--start-after` manually + +If you prefer explicit control, combine `--start-after` with manual parameter overrides: + +```bash +devrun workflow run devrun/configs/swe_bench_workflow/default.yaml \ + --start-after inference \ + params.model_name=openai/gpt-4 \ + params.dataset=/data/SWE-bench_Verified \ + params.output_dir=logs/run1 \ + params.working_dir=/remote/project/root +``` + +`--start-after ` skips the named stage **and all its transitive dependencies**. For the SWE-bench workflow, `--start-after inference` skips only the inference stage (since it has no dependencies of its own), leaving collect and evaluate to run. + +### Finding existing job IDs + +Use `devrun history` to list recent jobs: + +```bash +# Show the 20 most recent jobs (default) +devrun history + +# Show all jobs +devrun history --all + +# Limit to a specific count +devrun history -n 50 +``` + +The output table shows Job ID, Task, Executor, Status, and Created timestamp. Look for `swe_bench_agentic` jobs with status `completed`. + +## Background Execution + +For long-running workflows, use `--detach` (or `-d`) to start the workflow in the background and return immediately: + +```bash +devrun workflow run devrun/configs/swe_bench_workflow/default.yaml \ + --detach \ + params.model_name=openai/gpt-4 \ + params.dataset=/data/SWE-bench_Verified \ + params.working_dir=/remote/project/root +``` + +Output: +``` +Workflow abc12345 started in background. +Use devrun workflow status abc12345 to monitor. +``` + +The background process logs to `~/.devrun/logs/workflow_.log`. + +Detached mode validates all parameters and creates the workflow DB record **before** forking, so configuration errors appear immediately in your terminal rather than silently failing in the background. + +`--detach` and `--dry-run` cannot be combined. + +### Monitoring a detached workflow + +```bash +# Stage-by-stage status with timing +devrun workflow status + +# Tail logs for a specific stage +devrun workflow logs --stage inference + +# Cancel all active stages +devrun workflow cancel +``` + +## Parameter Overrides + +Trailing arguments on `devrun workflow run` are OmegaConf dotlist overrides. They are merged in this order (last wins): + +1. **YAML config file** -- base configuration +2. **`--from-job` params** -- extracted from an existing job record +3. **CLI overrides** -- trailing arguments + +### Syntax + +```bash +# Simple value +params.model_name=openai/gpt-4 + +# Path with spaces (quote the whole arg) +"params.dataset=/data/my dataset/SWE-bench_Verified" + +# Nested stage params +stages.0.params.max_iterations=200 + +# Multiple overrides +devrun workflow run config.yaml \ + params.model_name=mymodel \ + params.split=dev \ + params.run_name=experiment-2 +``` + +### Common overrides + +| Override | Purpose | +|----------|---------| +| `params.model_name=X` | Model identifier | +| `params.dataset=/path` | Dataset location on remote host | +| `params.working_dir=/path` | Remote project root | +| `params.split=dev` | Dataset split (default: `test`) | +| `params.run_name=X` | Run identifier (default: `run1`) | +| `params.output_dir=logs/X` | Output directory (default: `logs/run1`) | + +## Dry-Run Verification + +Preview the full execution plan without submitting any jobs: + +```bash +devrun workflow run devrun/configs/swe_bench_workflow/default.yaml \ + --dry-run \ + params.model_name=openai/gpt-4 \ + params.dataset=/data/SWE-bench_Verified \ + params.working_dir=/remote/project/root +``` + +The dry-run output shows each stage with: +- Task type and executor +- Dependencies +- Working directory +- Key parameters (up to 5) +- First 500 characters of the rendered command + +When combined with `--start-after`, skipped stages are clearly tagged: + +``` +Workflow: swe_bench +Timeout: 172800s (48h) + +Stage 1: inference [SKIPPED — start-after] + Task: swe_bench_agentic + Executor: slurm + Depends on: (none) + +Stage 2: collect [WILL RUN] + Task: swe_bench_collect + Executor: ssh + Depends on: inference + Working dir: /remote/project/root + Params: output_dir=logs/run1, dataset=/data/SWE-bench_Verified, ... + Command preview (first 500 chars): + ... + +Stage 3: evaluate [WILL RUN] + Task: swe_bench_eval + Executor: slurm + Depends on: collect + ... + +Summary: 1 stage(s) skipped, 2 stage(s) will run +``` + +Always dry-run before launching a real workflow to verify parameter resolution and command generation. + +## Configuration Reference + +### Workflow-level parameters + +Defined under `params:` in the workflow config. All stages reference these via `${params.X}` interpolation. + +| Parameter | Required | Default | Description | +|-----------|----------|---------|-------------| +| `model_name` | Yes | -- | Model identifier (e.g., `openai/gpt-4`) | +| `dataset` | Yes | -- | Absolute path to SWE-bench dataset on remote host | +| `split` | No | `test` | Dataset split | +| `run_name` | No | `run1` | Run identifier, used in output paths | +| `output_dir` | No | `logs/run1` | Output directory (relative to `working_dir`) | +| `working_dir` | Yes | -- | Remote project root directory | + +Required parameters use `` placeholders in the default config. The workflow engine validates that all placeholders are filled before submission and reports which ones are missing. + +### Stage: inference (`swe_bench_agentic`) + +Submits a Slurm array job that runs the OpenHands agent against each benchmark instance. + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `max_iterations` | `100` | Max agent iterations per instance | +| `max_attempts` | `5` | Retry count per array element | +| `array` | `000-499` | Slurm array range | +| `concurrency_limit` | `10` | Max concurrent array elements | +| `cpus_per_task` | `4` | CPUs per array element | +| `mem` | `32G` | Memory per array element | +| `walltime` | `24:00:00` | Slurm time limit | +| `job_name` | `swe-inference` | Slurm job name | + +### Stage: collect (`swe_bench_collect`) + +Runs via SSH. Scans inference output directories and produces `predictions.jsonl` using `jq`. + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `predictions_path` | `${params.output_dir}/predictions.jsonl` | Output file path | +| `model_name_or_path` | `${params.model_name}` | Model identifier for prediction records | + +Instances with missing `git_patch` values are excluded with a warning. The command prints a summary of collected vs. skipped instances. + +### Stage: evaluate (`swe_bench_eval`) + +Submits a Slurm job running `swebench.harness.run_evaluation`. + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `cpus_per_task` | `32` | CPUs for evaluation | +| `mem` | `64G` | Memory for evaluation | +| `max_workers` | `32` | Parallel evaluation workers | +| `walltime` | `24:00:00` | Slurm time limit | + +### Workflow engine settings + +Set at the top level of the workflow config: + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `timeout` | `172800` (48h) | Max wall-clock time in seconds for the entire workflow | +| `heartbeat_interval` | `30.0` | Poll interval in seconds between status checks | + +## Troubleshooting + +### "Workflow config has unfilled required parameters" + +You forgot to set one or more required parameters. The error message lists exactly which ones: + +``` +Workflow config has unfilled required parameters: + params.model_name: + params.dataset: + +Set them via CLI overrides: + devrun workflow run config.yaml params.model_name=mymodel params.dataset=/path/to/data +``` + +Fix: add the missing overrides to your command. + +### "Job 'xyz' not found" + +The `--from-job` flag references a job ID that doesn't exist in the database. + +```bash +# List all jobs to find the correct ID +devrun history --all +``` + +### "Stage 'X' depends_on 'Y' which does not exist" + +A `depends_on` reference in the workflow config points to a stage name that doesn't exist. Check for typos in your config file. Available stage names are listed in the error message. + +### Stage stuck in "submitted" state + +The workflow engine polls executor status at the heartbeat interval (default 30s). If a stage remains in `submitted` for longer than expected: + +- Check the Slurm queue: `squeue -u $USER` +- Verify the executor configuration (host, user, SSH key) +- Check workflow logs: `devrun workflow logs --stage ` + +### Inference succeeded but collect finds no outputs + +Verify that the `output_dir` and `working_dir` parameters match between the inference job and the workflow config. The collect stage looks for files at: + +``` +{working_dir}/{output_dir}/*/{DS_DIR}/*/*/output.jsonl +``` + +where `DS_DIR` is derived from the dataset path (e.g., `/data/SWE-bench_Verified` with split `test` becomes `__data__SWE-bench_Verified-test`). + +### OmegaConf interpolation errors + +If you see errors about unresolved interpolations (`${params.X}`), ensure: +- The referenced key exists in the top-level `params:` section +- The key name matches exactly (case-sensitive) +- You haven't introduced a typo in `${params.…}` syntax + +### Workflow timed out + +The default timeout is 48 hours. Override it for longer runs: + +```bash +devrun workflow run config.yaml timeout=259200 ... # 72 hours +``` diff --git a/tests/test_cli.py b/tests/test_cli.py index 8722785..b01a108 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -490,6 +490,139 @@ def test_workflow_run_dry_run(self, tmp_path): assert "dry-run" in result.stdout.lower() +class TestWorkflowCLINewFeatures: + """Tests for new workflow CLI features: overrides, start-after, from-job, detach.""" + + def test_workflow_run_with_overrides(self, tmp_path): + """Trailing args should be passed as OmegaConf overrides to the workflow config.""" + config = { + "workflow": "override_test", + "stages": [ + { + "name": "s1", + "task": "eval", + "executor": "local", + "params": {"model": "default-model"}, + }, + ], + "heartbeat_interval": 0.001, + } + cfg_path = tmp_path / "wf.yaml" + cfg_path.write_text(yaml.dump(config)) + + # Use --dry-run to verify overrides are applied without needing to mock execution + runner = get_cli_runner() + result = runner.invoke(app, [ + "workflow", "run", str(cfg_path), "--dry-run", + "stages.0.params.model=overridden-model", + ]) + assert result.exit_code == 0 + # The override should appear in the dry-run output + assert "overridden-model" in result.stdout + + def test_workflow_run_start_after_flag(self, tmp_path): + """--start-after flag should be parsed and forwarded to WorkflowRunner.""" + config = { + "workflow": "start_after_test", + "stages": [ + {"name": "inference", "task": "eval", "executor": "local", "params": {"model": "x"}}, + {"name": "collect", "task": "eval", "executor": "local", "depends_on": "inference", "params": {"model": "x"}}, + {"name": "evaluate", "task": "eval", "executor": "local", "depends_on": "collect", "params": {"model": "x"}}, + ], + "heartbeat_interval": 0.001, + } + cfg_path = tmp_path / "wf.yaml" + cfg_path.write_text(yaml.dump(config)) + + # Use --dry-run with --start-after to verify the flag is parsed and produces skip markers + runner = get_cli_runner() + result = runner.invoke(app, [ + "workflow", "run", str(cfg_path), + "--start-after", "inference", "--dry-run", + ]) + assert result.exit_code == 0 + # Should show inference as skipped + assert "SKIPPED" in result.stdout or "skipped" in result.stdout.lower() + + def test_workflow_run_from_job_flag(self, tmp_path): + """--from-job flag should be parsed and extract_workflow_params called.""" + config = { + "workflow": "from_job_test", + "stages": [ + {"name": "inference", "task": "eval", "executor": "local", "params": {"model": "x"}}, + {"name": "collect", "task": "eval", "executor": "local", "depends_on": "inference", "params": {"model": "x"}}, + ], + "heartbeat_interval": 0.001, + } + cfg_path = tmp_path / "wf.yaml" + cfg_path.write_text(yaml.dump(config)) + + with patch("devrun.workflow.WorkflowRunner.extract_workflow_params") as mock_extract: + mock_extract.return_value = ( + {"params.model_name": "from-job-model"}, + "swe_bench_agentic", + ) + with patch("devrun.workflow.WorkflowRunner.detect_stage_for_task", return_value=None): + with patch("devrun.workflow.WorkflowRunner.run", return_value="wf_789"): + runner = get_cli_runner() + result = runner.invoke(app, [ + "workflow", "run", str(cfg_path), + "--from-job", "job_abc123", + ]) + # The flag should be parsed and extract_workflow_params called + assert result.exit_code in [0, 1] + mock_extract.assert_called_once_with("job_abc123") + + def test_workflow_run_detach_flag(self, tmp_path): + """--detach flag should call run_detached instead of run.""" + config = { + "workflow": "detach_test", + "stages": [ + {"name": "s1", "task": "eval", "executor": "local", "params": {"model": "x"}}, + ], + "heartbeat_interval": 0.001, + } + cfg_path = tmp_path / "wf.yaml" + cfg_path.write_text(yaml.dump(config)) + + with patch("devrun.workflow.WorkflowRunner.run_detached", return_value="wf_detach_001") as mock_detach: + runner = get_cli_runner() + result = runner.invoke(app, [ + "workflow", "run", str(cfg_path), + "--detach", + ]) + assert result.exit_code == 0 + mock_detach.assert_called_once() + assert "background" in result.stdout.lower() or "wf_detach_001" in result.stdout + + def test_workflow_run_placeholder_error(self, tmp_path): + """Workflow with unfilled placeholders should show helpful error.""" + config = { + "workflow": "placeholder_test", + "stages": [ + { + "name": "s1", + "task": "eval", + "executor": "local", + "params": { + "model": "", + "dataset": "/data/test", + }, + }, + ], + "heartbeat_interval": 0.001, + } + cfg_path = tmp_path / "wf.yaml" + cfg_path.write_text(yaml.dump(config)) + + runner = get_cli_runner() + result = runner.invoke(app, ["workflow", "run", str(cfg_path)]) + # Should fail with a helpful error about unfilled placeholders + assert result.exit_code == 1 + output_lower = result.stdout.lower() + assert "required" in output_lower or "placeholder" in output_lower or "unfilled" in output_lower + + class TestNoArgsIsHelp: """Tests that all Typer apps show help when invoked with no arguments.""" @@ -523,66 +656,65 @@ def test_presets_no_args_shows_help(self): class TestWorkflowRunResolution: - """Tests for workflow run using hierarchical config resolution.""" + """Tests for workflow run using hierarchical config resolution via find_configs.""" - def test_workflow_run_by_name(self): - """Verify workflow run accepts a name target and passes it to load_merged_config.""" - mock_config = { + def test_workflow_run_by_name(self, tmp_path): + """Verify workflow run accepts a name target and resolves it via find_configs.""" + config = { "workflow": "test_wf", "stages": [ {"name": "s1", "task": "eval", "executor": "local", "params": {"model": "x"}}, ], "heartbeat_interval": 0.001, } + cfg_path = tmp_path / "wf.yaml" + cfg_path.write_text(yaml.dump(config)) - with patch("devrun.runner.load_merged_config", return_value=mock_config) as mock_load: - with patch("devrun.workflow.WorkflowRunner") as mock_wf_cls: - mock_runner = MagicMock() - mock_runner.run.return_value = "wf_123" - mock_wf_cls.return_value = mock_runner - + with patch("devrun.runner.find_configs", return_value=[cfg_path]) as mock_find: + with patch("devrun.workflow.WorkflowRunner.run", return_value="wf_123"): runner = get_cli_runner() result = runner.invoke(app, ["workflow", "run", "my_workflow"]) assert result.exit_code == 0 - mock_load.assert_called_once_with("my_workflow", overrides=[]) - - def test_workflow_run_with_overrides(self): - """Verify trailing args are passed as overrides to load_merged_config.""" - mock_config = { - "workflow": "test_wf", - "stages": [ - {"name": "s1", "task": "eval", "executor": "local", "params": {"model": "x"}}, - ], - "heartbeat_interval": 0.001, - } - - with patch("devrun.runner.load_merged_config", return_value=mock_config) as mock_load: - with patch("devrun.workflow.WorkflowRunner") as mock_wf_cls: - mock_runner = MagicMock() - mock_runner.run.return_value = "wf_123" - mock_wf_cls.return_value = mock_runner - - runner = get_cli_runner() - result = runner.invoke( - app, - ["workflow", "run", "my_workflow", "params.model=new", "params.lr=0.01"], - ) - - assert result.exit_code == 0 - mock_load.assert_called_once_with( - "my_workflow", - overrides=["params.model=new", "params.lr=0.01"], - ) + mock_find.assert_called_once_with("my_workflow") def test_workflow_run_not_found(self): - """When load_merged_config raises FileNotFoundError, exit code 1 with error.""" + """When find_configs raises FileNotFoundError, exit code 1 with error.""" with patch( - "devrun.runner.load_merged_config", + "devrun.runner.find_configs", side_effect=FileNotFoundError("Config for 'bogus' not found."), ): runner = get_cli_runner() result = runner.invoke(app, ["workflow", "run", "bogus"]) assert result.exit_code == 1 - assert "not found" in result.stdout.lower() \ No newline at end of file + assert "not found" in result.stdout.lower() + + def test_workflow_run_help_with_target(self, tmp_path): + """'devrun workflow run --help' shows workflow-specific help.""" + mock_config = { + "workflow": "test_wf", + "params": {"model": "base", "lr": 0.01}, + "stages": [ + {"name": "train", "task": "eval", "executor": "local", "params": {}}, + {"name": "eval", "task": "eval", "executor": "local", "depends_on": "train", "params": {}}, + ], + } + + with patch("devrun.runner.load_merged_config", return_value=mock_config): + runner = get_cli_runner() + result = runner.invoke(app, ["workflow", "run", "my_workflow", "--help"]) + + assert result.exit_code == 0 + assert "test_wf" in result.stdout + assert "params.model" in result.stdout + assert "params.lr" in result.stdout + assert "train" in result.stdout + assert "eval" in result.stdout + + def test_workflow_run_help_without_target(self): + """'devrun workflow run --help' shows generic command help.""" + runner = get_cli_runner() + result = runner.invoke(app, ["workflow", "run", "--help"]) + assert result.exit_code == 0 + assert "usage" in result.stdout.lower() diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 2a0f87c..263742c 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -38,7 +38,7 @@ def test_dry_run_does_not_submit(self, workflow_runner, simple_config): def test_dependency_ordering(self, workflow_runner, simple_config): """step2 depends on step1, so step1 must run first.""" with patch.object(workflow_runner, "_submit_stage") as mock_submit: - mock_submit.return_value = "mock_job_id" + mock_submit.return_value = ("mock_db_id", "mock_remote_id") with patch.object(workflow_runner, "_poll_job_status", return_value="completed"): workflow_runner.run(simple_config) calls = [c.args[0] for c in mock_submit.call_args_list] @@ -48,7 +48,7 @@ def test_dependency_ordering(self, workflow_runner, simple_config): def test_failure_stops_workflow(self, workflow_runner, simple_config): """If step1 fails, step2 should never be submitted.""" with patch.object(workflow_runner, "_submit_stage") as mock_submit: - mock_submit.return_value = "mock_job_id" + mock_submit.return_value = ("mock_db_id", "mock_remote_id") with patch.object(workflow_runner, "_poll_job_status", return_value="failed"): wf_id = workflow_runner.run(simple_config) record = workflow_runner._db.get_workflow(wf_id) @@ -57,7 +57,7 @@ def test_failure_stops_workflow(self, workflow_runner, simple_config): def test_all_stages_complete(self, workflow_runner, simple_config): with patch.object(workflow_runner, "_submit_stage") as mock_submit: - mock_submit.return_value = "mock_job_id" + mock_submit.return_value = ("mock_db_id", "mock_remote_id") with patch.object(workflow_runner, "_poll_job_status", return_value="completed"): wf_id = workflow_runner.run(simple_config) record = workflow_runner._db.get_workflow(wf_id) @@ -71,7 +71,7 @@ def test_timeout_cancels_workflow(self, workflow_runner): heartbeat_interval=0.001, ) with patch.object(workflow_runner, "_submit_stage") as mock_submit: - mock_submit.return_value = "mock_job_id" + mock_submit.return_value = ("mock_db_id", "mock_remote_id") with patch.object(workflow_runner, "_poll_job_status", return_value="running"): wf_id = workflow_runner.run(cfg) record = workflow_runner._db.get_workflow(wf_id) @@ -93,7 +93,7 @@ def test_submit_failure_marks_stage_failed(self, workflow_runner): def test_skipped_stage_when_dep_fails(self, workflow_runner, simple_config): """step2 should be skipped when step1 fails.""" with patch.object(workflow_runner, "_submit_stage") as mock_submit: - mock_submit.return_value = "mock_job_id" + mock_submit.return_value = ("mock_db_id", "mock_remote_id") with patch.object(workflow_runner, "_poll_job_status", return_value="failed"): wf_id = workflow_runner.run(simple_config) record = workflow_runner._db.get_workflow(wf_id) @@ -104,8 +104,8 @@ def test_skipped_stage_when_dep_fails(self, workflow_runner, simple_config): def test_cancel_workflow(self, workflow_runner): wf_id = workflow_runner._db.insert_workflow("test", { - "s1": {"status": "running", "job_id": "j1"}, - "s2": {"status": "pending", "job_id": None}, + "s1": {"status": "running", "remote_job_id": "j1"}, + "s2": {"status": "pending", "remote_job_id": None}, }) workflow_runner._db.update_workflow(wf_id, status="running") workflow_runner.cancel(wf_id) @@ -115,14 +115,47 @@ def test_cancel_workflow(self, workflow_runner): assert stages["s1"]["status"] == "cancelled" assert stages["s2"]["status"] == "pending" # was not running + def test_cancelled_stage_exits_heartbeat_loop(self, workflow_runner): + """A cancelled stage should be treated as terminal so the heartbeat loop exits promptly.""" + import time + + cfg = WorkflowConfig( + workflow="cancel_loop_test", + stages=[ + WorkflowStage(name="s1", task="eval", executor="local"), + ], + timeout=600, + heartbeat_interval=0.001, + ) + # Pre-set stage as cancelled (simulates external cancel() call) + stages_state = { + "s1": { + "status": "cancelled", + "remote_job_id": "j1", + "db_job_id": "db1", + "executor": "local", + }, + } + wf_id = workflow_runner._db.insert_workflow("cancel_loop_test", stages_state) + workflow_runner._db.update_workflow(wf_id, status="running") + + start = time.monotonic() + result = workflow_runner._heartbeat_loop(wf_id, cfg, stages_state) + elapsed = time.monotonic() - start + + record = workflow_runner._db.get_workflow(result) + assert record["status"] == "failed" + # Must exit promptly — not spin until the 600s timeout + assert elapsed < 5.0 + def test_cancel_nonexistent_raises(self, workflow_runner): with pytest.raises(ValueError, match="not found"): workflow_runner.cancel("nonexistent") def test_logs_all_stages(self, workflow_runner): wf_id = workflow_runner._db.insert_workflow("test", { - "s1": {"status": "completed", "job_id": "j1"}, - "s2": {"status": "running", "job_id": "j2"}, + "s1": {"status": "completed", "remote_job_id": "j1"}, + "s2": {"status": "running", "remote_job_id": "j2"}, }) result = workflow_runner.logs(wf_id) assert "s1" in result @@ -132,7 +165,7 @@ def test_logs_all_stages(self, workflow_runner): def test_logs_specific_stage(self, workflow_runner): wf_id = workflow_runner._db.insert_workflow("test", { - "s1": {"status": "completed", "job_id": "j1"}, + "s1": {"status": "completed", "remote_job_id": "j1"}, }) result = workflow_runner.logs(wf_id, stage="s1") assert "j1" in result @@ -171,7 +204,7 @@ def test_independent_stages_both_submitted(self, workflow_runner): heartbeat_interval=0.001, ) with patch.object(workflow_runner, "_submit_stage") as mock_submit: - mock_submit.return_value = "mock_job_id" + mock_submit.return_value = ("mock_db_id", "mock_remote_id") with patch.object(workflow_runner, "_poll_job_status", return_value="completed"): wf_id = workflow_runner.run(cfg) record = workflow_runner._db.get_workflow(wf_id) @@ -191,7 +224,7 @@ def test_diamond_dependency(self, workflow_runner): heartbeat_interval=0.001, ) with patch.object(workflow_runner, "_submit_stage") as mock_submit: - mock_submit.return_value = "mock_job_id" + mock_submit.return_value = ("mock_db_id", "mock_remote_id") with patch.object(workflow_runner, "_poll_job_status", return_value="completed"): wf_id = workflow_runner.run(cfg) record = workflow_runner._db.get_workflow(wf_id) @@ -203,3 +236,419 @@ def test_diamond_dependency(self, workflow_runner): assert names.index("A") < names.index("C") assert names.index("B") < names.index("D") assert names.index("C") < names.index("D") + + +# ============================================================================ +# OmegaConf interpolation tests +# ============================================================================ + + +class TestOmegaConfInterpolation: + """Tests for OmegaConf ${params.X} resolution in workflow configs.""" + + def test_omegaconf_interpolation_resolved(self, tmp_path): + """Workflow config with ${params.X} references should resolve correctly.""" + import yaml + + config_data = { + "workflow": "interp_test", + "params": { + "model_name": "test-model", + "output_dir": "/data/output", + }, + "stages": [ + { + "name": "inference", + "task": "eval", + "executor": "local", + "params": { + "model": "${params.model_name}", + "output": "${params.output_dir}/results", + }, + }, + { + "name": "evaluate", + "task": "eval", + "executor": "local", + "depends_on": "inference", + "params": { + "model": "${params.model_name}", + "predictions_path": "${params.output_dir}/predictions.jsonl", + }, + }, + ], + "heartbeat_interval": 0.001, + } + cfg_path = tmp_path / "workflow.yaml" + cfg_path.write_text(yaml.dump(config_data)) + + from omegaconf import OmegaConf + + raw = OmegaConf.load(str(cfg_path)) + resolved = OmegaConf.to_container(raw, resolve=True) + cfg = WorkflowConfig(**resolved) + + # Verify interpolation resolved correctly + assert cfg.stages[0].params["model"] == "test-model" + assert cfg.stages[0].params["output"] == "/data/output/results" + assert cfg.stages[1].params["model"] == "test-model" + assert cfg.stages[1].params["predictions_path"] == "/data/output/predictions.jsonl" + + +# ============================================================================ +# start_after tests +# ============================================================================ + + +class TestWorkflowStartAfter: + """Tests for the start_after feature that skips completed stages.""" + + @pytest.fixture + def three_stage_config(self): + return WorkflowConfig( + workflow="three_stage", + stages=[ + WorkflowStage( + name="inference", task="eval", executor="local", + params={"model": "x"}, + ), + WorkflowStage( + name="collect", task="eval", executor="local", + depends_on="inference", + params={"output_dir": "/data"}, + ), + WorkflowStage( + name="evaluate", task="eval", executor="local", + depends_on="collect", + params={"predictions": "/data/pred.jsonl"}, + ), + ], + heartbeat_interval=0.001, + ) + + def test_start_after_skips_named_stage(self, workflow_runner, three_stage_config): + """When start_after='inference', the inference stage is pre-marked as skipped_by_user.""" + with patch.object(workflow_runner, "_submit_stage") as mock_submit: + mock_submit.return_value = ("mock_db_id", "mock_remote_id") + with patch.object(workflow_runner, "_poll_job_status", return_value="completed"): + wf_id = workflow_runner.run(three_stage_config, start_after="inference") + record = workflow_runner._db.get_workflow(wf_id) + stages = json.loads(record["stages_state"]) + assert stages["inference"]["status"] == "skipped_by_user" + # collect and evaluate should have been submitted + submitted_names = [c.args[0] for c in mock_submit.call_args_list] + assert "inference" not in submitted_names + assert "collect" in submitted_names + assert "evaluate" in submitted_names + + def test_start_after_skips_transitive_deps(self, workflow_runner): + """Stages that the start_after stage depends on are also skipped.""" + cfg = WorkflowConfig( + workflow="transitive_test", + stages=[ + WorkflowStage(name="prep", task="eval", executor="local"), + WorkflowStage(name="inference", task="eval", executor="local", depends_on="prep"), + WorkflowStage(name="collect", task="eval", executor="local", depends_on="inference"), + WorkflowStage(name="evaluate", task="eval", executor="local", depends_on="collect"), + ], + heartbeat_interval=0.001, + ) + with patch.object(workflow_runner, "_submit_stage") as mock_submit: + mock_submit.return_value = ("mock_db_id", "mock_remote_id") + with patch.object(workflow_runner, "_poll_job_status", return_value="completed"): + wf_id = workflow_runner.run(cfg, start_after="inference") + record = workflow_runner._db.get_workflow(wf_id) + stages = json.loads(record["stages_state"]) + # Both prep and inference should be skipped_by_user + assert stages["prep"]["status"] == "skipped_by_user" + assert stages["inference"]["status"] == "skipped_by_user" + # collect and evaluate should run + submitted_names = [c.args[0] for c in mock_submit.call_args_list] + assert "collect" in submitted_names + assert "evaluate" in submitted_names + + def test_start_after_invalid_stage_raises(self, workflow_runner, three_stage_config): + """start_after with a nonexistent stage name should raise ValueError.""" + with pytest.raises(ValueError, match="not found|does not exist|unknown stage"): + workflow_runner.run(three_stage_config, start_after="nonexistent_stage") + + def test_start_after_last_stage_completes_immediately(self, workflow_runner, three_stage_config): + """start_after the last stage means everything is skipped — workflow completes with all skipped.""" + wf_id = workflow_runner.run(three_stage_config, start_after="evaluate") + record = workflow_runner._db.get_workflow(wf_id) + stages = json.loads(record["stages_state"]) + # All stages should be skipped_by_user + for name in ("inference", "collect", "evaluate"): + assert stages[name]["status"] == "skipped_by_user" + assert record["status"] == "completed" + + def test_start_after_dry_run_shows_skip_info(self, workflow_runner, three_stage_config): + """Dry-run with start_after should indicate which stages are skipped.""" + result = workflow_runner.run(three_stage_config, dry_run=True, start_after="inference") + assert isinstance(result, str) + # The output should indicate inference is skipped + result_lower = result.lower() + assert "skip" in result_lower + assert "inference" in result_lower + # collect and evaluate should still be shown + assert "collect" in result_lower + assert "evaluate" in result_lower + + +# ============================================================================ +# from_job tests +# ============================================================================ + + +class TestWorkflowFromJob: + """Tests for extract_workflow_params and detect_stage_for_task.""" + + @pytest.fixture + def swe_bench_workflow_config(self): + return WorkflowConfig( + workflow="swe_bench", + stages=[ + WorkflowStage( + name="inference", task="swe_bench_agentic", executor="slurm", + params={ + "model_name": "", + "dataset": "/data/SWE-bench_Verified", + "split": "test", + "output_dir": "", + "run_name": "", + }, + ), + WorkflowStage( + name="collect", task="swe_bench_collect", executor="local", + depends_on="inference", + params={ + "output_dir": "", + "dataset": "/data/SWE-bench_Verified", + "split": "test", + }, + ), + WorkflowStage( + name="evaluate", task="swe_bench_eval", executor="local", + depends_on="collect", + params={ + "dataset_name": "/data/SWE-bench_Verified", + "predictions_path": "", + }, + ), + ], + heartbeat_interval=0.001, + ) + + def test_extract_workflow_params_extracts_params(self, workflow_runner): + """extract_workflow_params should return a dotlist dict from a stored job record.""" + job_params = { + "model_name": "test-model", + "dataset": "/data/SWE-bench_Verified", + "split": "test", + "output_dir": "logs/test_run", + "run_name": "test_run", + } + job_id = workflow_runner._db.insert( + task_name="swe_bench_agentic", + executor="slurm", + parameters=job_params, + ) + + dotlist, task_name = workflow_runner.extract_workflow_params(job_id) + assert task_name == "swe_bench_agentic" + assert dotlist["params.model_name"] == "test-model" + assert dotlist["params.output_dir"] == "logs/test_run" + assert dotlist["params.run_name"] == "test_run" + assert dotlist["params.dataset"] == "/data/SWE-bench_Verified" + + def test_extract_workflow_params_nonexistent_raises(self, workflow_runner): + """extract_workflow_params with an invalid job_id should raise ValueError.""" + with pytest.raises(ValueError, match="not found"): + workflow_runner.extract_workflow_params("nonexistent_job_id") + + def test_extract_workflow_params_omits_empty_values(self, workflow_runner): + """extract_workflow_params should skip params with empty/falsy values.""" + job_params = { + "model_name": "test-model", + "dataset": "", + "output_dir": "logs/run", + } + job_id = workflow_runner._db.insert( + task_name="swe_bench_agentic", + executor="slurm", + parameters=job_params, + ) + dotlist, _ = workflow_runner.extract_workflow_params(job_id) + assert "params.model_name" in dotlist + assert "params.output_dir" in dotlist + assert "params.dataset" not in dotlist # empty string skipped + + def test_detect_stage_for_task(self, workflow_runner, swe_bench_workflow_config): + """detect_stage_for_task maps task type to stage name.""" + assert workflow_runner.detect_stage_for_task("swe_bench_agentic", swe_bench_workflow_config) == "inference" + assert workflow_runner.detect_stage_for_task("swe_bench_collect", swe_bench_workflow_config) == "collect" + assert workflow_runner.detect_stage_for_task("swe_bench_eval", swe_bench_workflow_config) == "evaluate" + assert workflow_runner.detect_stage_for_task("unknown_task", swe_bench_workflow_config) is None + + +# ============================================================================ +# Detached mode tests +# ============================================================================ + + +class TestWorkflowDetached: + """Tests for detached (background) workflow execution.""" + + def test_detached_creates_db_record(self, workflow_runner): + """run_detached should write a workflow record to the DB before returning.""" + cfg = WorkflowConfig( + workflow="detach_test", + stages=[ + WorkflowStage(name="slow_stage", task="eval", executor="local", + params={"model": "x"}), + ], + heartbeat_interval=1.0, + ) + with patch("devrun.workflow.subprocess.Popen") as mock_popen: + mock_popen.return_value = MagicMock() + wf_id = workflow_runner.run_detached(cfg) + # Should return immediately with a workflow_id + assert isinstance(wf_id, str) + assert len(wf_id) > 0 + # DB record should exist + record = workflow_runner._db.get_workflow(wf_id) + assert record is not None + assert record["workflow_name"] == "detach_test" + + def test_detached_returns_workflow_id_immediately(self, workflow_runner): + """run_detached should return the workflow ID immediately without blocking.""" + cfg = WorkflowConfig( + workflow="detach_return_test", + stages=[ + WorkflowStage(name="s1", task="eval", executor="local", + params={"model": "x"}), + WorkflowStage(name="s2", task="eval", executor="local", + depends_on="s1", params={"model": "y"}), + ], + heartbeat_interval=10.0, + ) + import time + + with patch("devrun.workflow.subprocess.Popen") as mock_popen: + mock_popen.return_value = MagicMock() + start = time.monotonic() + wf_id = workflow_runner.run_detached(cfg) + elapsed = time.monotonic() - start + # Should return much faster than the heartbeat interval + assert elapsed < 5.0 + assert isinstance(wf_id, str) + # Popen should have been called to spawn background process + mock_popen.assert_called_once() + + +# ============================================================================ +# Placeholder validation tests +# ============================================================================ + + +class TestPlaceholderValidation: + """Tests for placeholder validation.""" + + def test_placeholder_validation_catches_required(self, workflow_runner): + """Config with placeholders should be detected and rejected.""" + cfg = WorkflowConfig( + workflow="placeholder_test", + stages=[ + WorkflowStage( + name="s1", task="eval", executor="local", + params={"model": "", "dataset": "/data/test"}, + ), + ], + heartbeat_interval=0.001, + ) + with pytest.raises((ValueError, RuntimeError), match="REQUIRED|placeholder|unfilled"): + workflow_runner.run(cfg) + + def test_placeholder_validation_passes_clean_config(self, workflow_runner): + """Config without placeholders should pass validation.""" + cfg = WorkflowConfig( + workflow="clean_test", + stages=[ + WorkflowStage( + name="s1", task="eval", executor="local", + params={"model": "test-model", "dataset": "/data/test"}, + ), + ], + heartbeat_interval=0.001, + ) + with patch.object(workflow_runner, "_submit_stage") as mock_submit: + mock_submit.return_value = ("mock_db_id", "mock_remote_id") + with patch.object(workflow_runner, "_poll_job_status", return_value="completed"): + # Should not raise + wf_id = workflow_runner.run(cfg) + assert isinstance(wf_id, str) + + +# ============================================================================ +# Enhanced dry-run output tests +# ============================================================================ + + +class TestImprovedDryRun: + """Tests for enhanced dry-run output format.""" + + def test_improved_dry_run_output(self, workflow_runner): + """Enhanced dry-run should include full params and detailed formatting.""" + cfg = WorkflowConfig( + workflow="dryrun_test", + stages=[ + WorkflowStage( + name="inference", task="eval", executor="local", + params={"model": "test-model", "batch_size": 8}, + ), + WorkflowStage( + name="evaluate", task="eval", executor="local", + depends_on="inference", + params={"dataset": "/data/test"}, + ), + ], + heartbeat_interval=0.001, + ) + result = workflow_runner.run(cfg, dry_run=True) + assert isinstance(result, str) + # Should include stage info + assert "inference" in result + assert "evaluate" in result + # Should include executor info + assert "local" in result + # Should include task info + assert "eval" in result + # Should include some form of parameter display + result_lower = result.lower() + assert "param" in result_lower or "model" in result_lower or "batch_size" in result_lower + + +# ============================================================================ +# Enhanced logs tests +# ============================================================================ + + +class TestEnhancedLogs: + """Tests for enhanced workflow logs that delegate to executor.""" + + def test_enhanced_logs_delegates_to_executor(self, workflow_runner): + """Logs method should attempt to delegate to executor.logs() for actual content.""" + wf_id = workflow_runner._db.insert_workflow("test", { + "s1": {"status": "completed", "remote_job_id": "j1", "executor": "local"}, + }) + # Insert a corresponding job record so executor lookup works + workflow_runner._db.insert(task_name="eval", executor="local") + + mock_executor = MagicMock() + mock_executor.logs.return_value = "real log content from executor" + + with patch("devrun.workflow.resolve_executor", return_value=mock_executor): + result = workflow_runner.logs(wf_id, stage="s1") + # The result should include actual log content or at least delegate + assert isinstance(result, str) + assert len(result) > 0 diff --git a/tests/test_workflow_simulation.py b/tests/test_workflow_simulation.py index e208d58..4657a72 100644 --- a/tests/test_workflow_simulation.py +++ b/tests/test_workflow_simulation.py @@ -173,3 +173,219 @@ def test_all_stages_use_same_dataset(self, swe_bench_config): col_dataset = swe_bench_config.stages[1].params["dataset"] eval_dataset = swe_bench_config.stages[2].params["dataset_name"] assert inf_dataset == col_dataset == eval_dataset + + +class TestWorkflowSimulationStartAfter: + """Simulation tests for start_after and from_job workflows.""" + + @pytest.fixture + def swe_bench_three_stage_config(self): + """A three-stage SWE-bench workflow matching the production pattern.""" + return WorkflowConfig( + workflow="swe_bench", + params={ + "model_name": "test-model", + "dataset": "/mnt/data/SWE-bench_Verified", + "split": "test", + "output_dir": "logs/sim_run", + "run_name": "sim_run", + "working_dir": "/remote/project", + }, + stages=[ + WorkflowStage( + name="inference", + task="swe_bench_agentic", + executor="slurm", + params={ + "model_name": "test-model", + "dataset": "/mnt/data/SWE-bench_Verified", + "split": "test", + "run_name": "sim_run", + "output_dir": "logs/sim_run", + "llm_config": "/fake/config.json", + "max_iterations": 100, + "max_attempts": 5, + "array": "000-004", + "working_dir": "/remote/project", + "base_url": "http://localhost:8000", + "api_key": "sk-test", + "temperature": "0.7", + "top_p": "0.95", + "env_commands": ["source /opt/conda/bin/activate"], + "env": {}, + }, + ), + WorkflowStage( + name="collect", + task="swe_bench_collect", + executor="local", + depends_on="inference", + params={ + "output_dir": "logs/sim_run", + "dataset": "/mnt/data/SWE-bench_Verified", + "split": "test", + "model_name_or_path": "test-model", + "predictions_path": "logs/sim_run/predictions.jsonl", + "working_dir": "/remote/project", + }, + ), + WorkflowStage( + name="evaluate", + task="swe_bench_eval", + executor="local", + depends_on="collect", + params={ + "dataset_name": "/mnt/data/SWE-bench_Verified", + "predictions_path": "logs/sim_run/predictions.jsonl", + "working_dir": "/remote/project", + "run_id": "sim_test", + }, + ), + ], + heartbeat_interval=0.001, + ) + + def test_start_after_inference_runs_collect_eval( + self, swe_bench_three_stage_config, tmp_path + ): + """Core use case: skip inference, run collect and evaluate. + + When start_after='inference', only collect and evaluate should execute. + This simulates the primary scenario where inference was run via an + existing swe_bench_agentic task and we want to continue with + collect + eval. + """ + from unittest.mock import patch + from devrun.workflow import WorkflowRunner + + runner = WorkflowRunner(db_path=tmp_path / "sim_test.db") + + with patch.object(runner, "_submit_stage") as mock_submit: + mock_submit.return_value = ("mock_db_id", "mock_remote_id") + with patch.object(runner, "_poll_job_status", return_value="completed"): + wf_id = runner.run( + swe_bench_three_stage_config, start_after="inference" + ) + record = runner._db.get_workflow(wf_id) + import json + stages = json.loads(record["stages_state"]) + + # Inference skipped, collect + evaluate completed + assert stages["inference"]["status"] == "skipped_by_user" + assert stages["collect"]["status"] == "completed" + assert stages["evaluate"]["status"] == "completed" + + # Only collect and evaluate were submitted (in order) + submitted = [c.args[0] for c in mock_submit.call_args_list] + assert submitted == ["collect", "evaluate"] + + def test_from_job_populates_downstream_stages(self, tmp_path): + """Params extracted from a job should propagate correctly via OmegaConf merge. + + This simulates the full CLI flow: extract_workflow_params returns a + dotlist dict, which is merged into the raw OmegaConf config before + resolution. The merged config should have REQUIRED placeholders + replaced with actual values from the source job, maintaining DS_DIR + consistency across stages. + """ + import yaml as _yaml + from omegaconf import OmegaConf + from devrun.workflow import WorkflowRunner + + runner = WorkflowRunner(db_path=tmp_path / "from_job_sim.db") + + # Insert a source job record simulating a completed swe_bench_agentic run + job_params = { + "model_name": "gpt-4o", + "dataset": "/mnt/data/SWE-bench_Verified", + "split": "test", + "output_dir": "logs/gpt4o_run", + "run_name": "gpt4o_run", + "working_dir": "/remote/project", + "llm_config": "/fake/config.json", + "max_iterations": 100, + } + job_id = runner._db.insert( + task_name="swe_bench_agentic", + executor="slurm", + parameters=job_params, + ) + + # Template YAML config with OmegaConf interpolation and REQUIRED placeholders + template_yaml = { + "workflow": "swe_bench", + "params": { + "model_name": "", + "dataset": "/mnt/data/SWE-bench_Verified", + "split": "test", + "output_dir": "", + "run_name": "", + }, + "stages": [ + { + "name": "inference", + "task": "swe_bench_agentic", + "executor": "slurm", + "params": { + "model_name": "${params.model_name}", + "dataset": "${params.dataset}", + "split": "${params.split}", + "output_dir": "${params.output_dir}", + "run_name": "${params.run_name}", + }, + }, + { + "name": "collect", + "task": "swe_bench_collect", + "executor": "local", + "depends_on": "inference", + "params": { + "output_dir": "${params.output_dir}", + "dataset": "${params.dataset}", + "split": "${params.split}", + }, + }, + { + "name": "evaluate", + "task": "swe_bench_eval", + "executor": "local", + "depends_on": "collect", + "params": { + "dataset_name": "${params.dataset}", + }, + }, + ], + "heartbeat_interval": 0.001, + } + cfg_path = tmp_path / "workflow.yaml" + cfg_path.write_text(_yaml.dump(template_yaml)) + + # Simulate CLI flow: extract params → merge → resolve + dotlist, task_name = runner.extract_workflow_params(job_id) + assert task_name == "swe_bench_agentic" + assert dotlist["params.model_name"] == "gpt-4o" + + raw_cfg = OmegaConf.load(str(cfg_path)) + job_overrides = [f"{k}={v}" for k, v in dotlist.items()] + raw_cfg = OmegaConf.merge(raw_cfg, OmegaConf.from_dotlist(job_overrides)) + resolved = OmegaConf.to_container(raw_cfg, resolve=True) + cfg = WorkflowConfig(**resolved) + + # Auto-detect stage + detected = runner.detect_stage_for_task(task_name, cfg) + assert detected == "inference" + + # Verify params were populated correctly in all stages + inf_params = cfg.stages[0].params + assert inf_params["model_name"] == "gpt-4o" + assert inf_params["output_dir"] == "logs/gpt4o_run" + assert inf_params["run_name"] == "gpt4o_run" + + col_params = cfg.stages[1].params + assert col_params["output_dir"] == "logs/gpt4o_run" + assert col_params["dataset"] == "/mnt/data/SWE-bench_Verified" + + # DS_DIR should be consistent between inference and collect + inf_ds_dir = derive_ds_dir(inf_params["dataset"], inf_params["split"]) + col_ds_dir = derive_ds_dir(col_params["dataset"], col_params["split"]) + assert inf_ds_dir == col_ds_dir