diff --git a/README.md b/README.md index d7f7723..bd84e8a 100644 --- a/README.md +++ b/README.md @@ -255,12 +255,19 @@ paperbanana generate \ --input paper.pdf \ --caption "Overview of our method" \ --pdf-pages "3-8" + +# Guide generation with a reference/sketch image (repeatable) +paperbanana generate \ + --input method.txt \ + --caption "Overview of our framework" \ + --image sketch.png --image prior_figure.png ``` | Flag | Short | Description | |------|-------|-------------| | `--input` | `-i` | Path to methodology text file or PDF (required for new runs) | | `--caption` | `-c` | Figure caption / communicative intent (required for new runs) | +| `--image` | | Reference/sketch image (hand-drawn sketch, whiteboard photo, prior figure) that guides the Planner. Repeatable for multiple images | | `--output` | `-o` | Output image path (default: auto-generated in `outputs/`) | | `--iterations` | `-n` | Number of Visualizer-Critic refinement rounds (default: 3) | | `--num-candidates` | `-k` | Generate N candidate images in parallel, 1-8 (default: 1). Planning runs once; refinement fans out per candidate with seed offsets. Outputs land in `candidates/cand_/`; the run-root `final_output` is candidate 1. Cost estimates and `--budget` account for the fan-out | diff --git a/mcp_server/server.py b/mcp_server/server.py index 10b4f16..401cded 100644 --- a/mcp_server/server.py +++ b/mcp_server/server.py @@ -172,6 +172,28 @@ def _embed_caption(image_path: str, caption: str) -> None: mcp = FastMCP("PaperBanana") +def _validate_input_images(input_images: list[str] | None) -> list[str]: + """Validate user-provided reference/sketch image paths before the pipeline starts. + + Each path must exist and be a PIL-openable raster image. + + Raises: + ValueError: If a path is missing or not a valid raster image. + """ + validated: list[str] = [] + for image_path in input_images or []: + path = Path(image_path) + if not path.is_file(): + raise ValueError(f"Image file not found: {image_path}") + try: + with PILImage.open(path) as im: + im.verify() + except Exception: + raise ValueError(f"Not a valid raster image (e.g. PNG, JPEG, WebP): {image_path}") + validated.append(str(path)) + return validated + + @mcp.tool async def generate_diagram( source_context: str, @@ -181,6 +203,7 @@ async def generate_diagram( optimize: bool = False, auto_refine: bool = False, generate_caption: bool = False, + input_images: list[str] | None = None, ) -> Image: """Generate a publication-quality methodology diagram from text. @@ -197,10 +220,15 @@ async def generate_diagram( generate_caption: Auto-generate a publication-ready figure caption after generation. When True, the caption is embedded in the image metadata (PNG tEXt chunk, key "Caption") and logged. + input_images: Optional file paths to user-provided reference/sketch + images (hand-drawn sketch, whiteboard photo, prior figure) that + guide the layout and content of the generated diagram. Returns: The generated diagram as a PNG image. """ + validated_images = _validate_input_images(input_images) + settings = Settings( refinement_iterations=iterations, optimize_inputs=optimize, @@ -223,6 +251,7 @@ def _on_progress(event: str, payload: dict) -> None: communicative_intent=caption, diagram_type=DiagramType.METHODOLOGY, aspect_ratio=aspect_ratio, + input_images=validated_images, ) result = await pipeline.generate(gen_input) diff --git a/paperbanana/agents/planner.py b/paperbanana/agents/planner.py index fe22184..f38dd01 100644 --- a/paperbanana/agents/planner.py +++ b/paperbanana/agents/planner.py @@ -46,6 +46,7 @@ async def run( examples: list[ReferenceExample], diagram_type: DiagramType = DiagramType.METHODOLOGY, supported_ratios: list[str] | None = None, + input_images: list[str] | None = None, ) -> tuple[str, str | None]: """Generate a detailed textual description of the target diagram. @@ -55,6 +56,8 @@ async def run( examples: Retrieved reference examples for in-context learning. diagram_type: Type of diagram being generated. supported_ratios: Aspect ratios the image provider supports. + input_images: Paths to user-provided reference/sketch images that + guide the plan alongside the retrieved exemplars. Returns: Tuple of (description, recommended_ratio). @@ -66,8 +69,18 @@ async def run( # Load reference images for visual in-context learning example_images = await asyncio.to_thread(self._load_example_images, examples) + # Load user-provided reference/sketch images (attached after the + # exemplar images so "reference image N" indexing stays valid). + user_images: list = [] + if input_images: + user_images = await asyncio.to_thread(self._load_input_images, input_images) + prompt_type = "diagram" if diagram_type == DiagramType.METHODOLOGY else "plot" template = self.load_prompt(prompt_type) + if user_images: + # Appended pre-format so the prompt recorder captures it; the note + # is brace-free, keeping str.format() on the template intact. + template += "\n\n" + self._format_user_image_note(len(user_images)) # Inject supported ratios into the prompt template ratios_str = ", ".join(supported_ratios) if supported_ratios else "1:1, 16:9" prompt = self.format_prompt( @@ -83,12 +96,14 @@ async def run( "Running planner agent", num_examples=len(examples), num_images=len(example_images), + num_user_images=len(user_images), context_length=len(source_context), ) + all_images = example_images + user_images raw_output = await self.vlm.generate( prompt=prompt, - images=example_images if example_images else None, + images=all_images if all_images else None, temperature=0.7, max_tokens=4096, ) @@ -242,6 +257,36 @@ def _load_example_images(self, examples: list[ReferenceExample]) -> list: ) return images + @staticmethod + def _format_user_image_note(count: int) -> str: + """Label for user-provided reference/sketch images attached to the prompt.""" + return ( + "## User-Provided Reference/Sketch\n" + f"The final {count} attached image(s), after the reference example images, " + "are user-provided reference/sketch images (e.g. a hand-drawn sketch, " + "whiteboard photo, or a prior version of the figure). Use them as guidance " + "for the layout and content of the target diagram while staying faithful " + "to the source text." + ) + + def _load_input_images(self, paths: list[str]) -> list: + """Load user-provided reference/sketch images from local paths. + + Returns PIL Image objects; unreadable files are skipped with a warning + (the CLI/MCP entry points validate them before the pipeline starts). + """ + images = [] + for path in paths: + try: + images.append(load_image(path)) + except Exception as e: + logger.warning( + "Failed to load user-provided reference image", + image_path=path, + error=str(e), + ) + return images + _VALID_RATIOS = {"1:1", "2:3", "3:2", "3:4", "4:3", "9:16", "16:9", "21:9"} @classmethod diff --git a/paperbanana/agents/visualizer.py b/paperbanana/agents/visualizer.py index 808acac..96b3a85 100644 --- a/paperbanana/agents/visualizer.py +++ b/paperbanana/agents/visualizer.py @@ -61,6 +61,7 @@ async def run( seed: Optional[int] = None, aspect_ratio: Optional[str] = None, vector_formats: Optional[list[str]] = None, + sketch_guided: bool = False, ) -> str: """Generate an image from a description. @@ -74,6 +75,8 @@ async def run( aspect_ratio: Target aspect ratio (e.g., '16:9', '1:1'). vector_formats: Vector formats to export alongside raster (e.g., ['svg', 'pdf']). Only applies to statistical plots; ignored for methodology diagrams. + sketch_guided: When True, the diagram prompt notes that a + user-provided reference sketch guided the plan. Returns: Path to the generated raster image. @@ -90,8 +93,14 @@ async def run( iteration, seed, aspect_ratio, + sketch_guided=sketch_guided, ) + _SKETCH_GUIDED_NOTE = ( + "Note: this plan was guided by a user-provided reference sketch; " + "follow the description above faithfully." + ) + async def _generate_diagram( self, description: str, @@ -99,9 +108,12 @@ async def _generate_diagram( iteration: int, seed: Optional[int], aspect_ratio: Optional[str] = None, + sketch_guided: bool = False, ) -> str: """Generate a methodology diagram using the image generation model.""" template = self.load_prompt("diagram") + if sketch_guided: + template += "\n\n" + self._SKETCH_GUIDED_NOTE prompt = self.format_prompt( template, prompt_label=f"visualizer_diagram_iter_{iteration}", diff --git a/paperbanana/cli.py b/paperbanana/cli.py index f824ada..7d3b076 100644 --- a/paperbanana/cli.py +++ b/paperbanana/cli.py @@ -245,6 +245,14 @@ def generate( caption: Optional[str] = typer.Option( None, "--caption", "-c", help="Figure caption / communicative intent" ), + image: Optional[list[str]] = typer.Option( + None, + "--image", + help=( + "Path to a reference/sketch image (hand-drawn sketch, whiteboard photo, " + "prior figure) that guides generation. Repeatable for multiple images." + ), + ), output: Optional[str] = typer.Option(None, "--output", "-o", help="Output image path"), output_dir: Optional[str] = typer.Option( None, @@ -455,6 +463,31 @@ def generate( "[red]Error: --pdf-pages cannot be used with --continue or --continue-run[/red]" ) raise typer.Exit(1) + if image and (continue_last or continue_run): + console.print("[red]Error: --image cannot be used with --continue or --continue-run[/red]") + raise typer.Exit(1) + + # Validate reference/sketch images before any pipeline work starts. + input_images: list[str] = [] + if image: + from PIL import Image as PILImage + from PIL import UnidentifiedImageError + + for image_path in image: + img_file = Path(image_path) + if not img_file.is_file(): + console.print(f"[red]Error: Image file not found: {image_path}[/red]") + raise typer.Exit(1) + try: + with PILImage.open(img_file) as im: + im.verify() + except (UnidentifiedImageError, OSError, ValueError): + console.print( + f"[red]Error: Not a valid raster image (e.g. PNG, JPEG, WebP): " + f"{image_path}[/red]" + ) + raise typer.Exit(1) + input_images.append(str(img_file)) _valid_categories = { "agent_reasoning", @@ -721,6 +754,7 @@ async def _run_continue(): diagram_type=DiagramType.METHODOLOGY, aspect_ratio=aspect_ratio, reference_ids=ref_id_list, + input_images=input_images, ) # Determine expected output file extension based on settings.output_format @@ -764,6 +798,8 @@ async def _run_continue(): pdf_note = "" if input_path.suffix.lower() == ".pdf": pdf_note = f"\nPDF pages: {pdf_pages.strip() if pdf_pages else 'all'}" + if input_images: + pdf_note += f"\nReference images: {', '.join(input_images)}" console.print( Panel.fit( "[bold]PaperBanana[/bold] - Dry Run\n\n" diff --git a/paperbanana/core/pipeline.py b/paperbanana/core/pipeline.py index c2997a8..17f6739 100644 --- a/paperbanana/core/pipeline.py +++ b/paperbanana/core/pipeline.py @@ -560,6 +560,7 @@ def _extra(**payload: Any) -> Dict[str, Any]: seed=seed, aspect_ratio=effective_ratio, vector_formats=vector_formats, + sketch_guided=bool(input.input_images), ) visualizer_seconds = time.perf_counter() - visualizer_start if image_path is None: @@ -1267,6 +1268,7 @@ async def generate( "raw_data": input.raw_data, "aspect_ratio": input.aspect_ratio, "vector_export": self._effective_vector_export(input), + "input_images": input.input_images, }, self._run_dir / "run_input.json", ) @@ -1342,6 +1344,7 @@ async def generate( diagram_type=input.diagram_type, raw_data=input.raw_data, aspect_ratio=input.aspect_ratio, + input_images=input.input_images, ) except Exception: optimize_seconds = time.perf_counter() - optimize_start @@ -1471,6 +1474,7 @@ async def generate( examples=examples, diagram_type=input.diagram_type, supported_ratios=getattr(self.visualizer.image_gen, "supported_ratios", None), + input_images=input.input_images, ) planning_seconds = time.perf_counter() - planning_start _emit_progress( diff --git a/paperbanana/core/types.py b/paperbanana/core/types.py index 5b37227..3249212 100644 --- a/paperbanana/core/types.py +++ b/paperbanana/core/types.py @@ -91,6 +91,14 @@ class GenerationInput(BaseModel): default=None, description="Optional vector export (svg/pdf/both); None uses Settings.vector_export", ) + input_images: list[str] = Field( + default_factory=list, + description=( + "Paths to user-provided reference/sketch images (e.g. a hand-drawn " + "sketch, whiteboard photo, or prior figure version) that guide the " + "Planner alongside retrieved exemplars." + ), + ) @field_validator("aspect_ratio") @classmethod diff --git a/tests/test_agents/test_planner.py b/tests/test_agents/test_planner.py index 4630697..16e78e9 100644 --- a/tests/test_agents/test_planner.py +++ b/tests/test_agents/test_planner.py @@ -8,6 +8,7 @@ from paperbanana.agents.planner import PlannerAgent from paperbanana.core.types import ReferenceExample +from paperbanana.core.utils import find_prompt_dir class _MockVLM: @@ -18,6 +19,21 @@ async def generate(self, *args, **kwargs): return "ok" +class _CapturingVLM: + """Mock VLM that records the prompt and image parts it receives.""" + + name = "mock-vlm" + model_name = "mock-model" + + def __init__(self): + self.captured: dict = {} + + async def generate(self, prompt, images=None, **kwargs): + self.captured["prompt"] = prompt + self.captured["images"] = images + return "a detailed description\nRECOMMENDED_RATIO: 16:9" + + def test_format_examples_includes_structure_hints(): agent = PlannerAgent(_MockVLM()) text = agent._format_examples( @@ -74,6 +90,66 @@ def test_has_valid_image_rejects_insecure_or_local_urls(): assert agent._has_valid_image(private_ip) is False +async def test_planner_attaches_user_sketch_images_after_exemplars(tmp_path): + """User sketch images are attached after exemplar images and labeled in the prompt.""" + ref_img = tmp_path / "ref.png" + Image.new("RGB", (2, 2), color=(255, 0, 0)).save(ref_img) + sketch = tmp_path / "sketch.png" + Image.new("RGB", (4, 4), color=(0, 0, 255)).save(sketch) + + vlm = _CapturingVLM() + agent = PlannerAgent(vlm, prompt_dir=find_prompt_dir()) + + description, ratio = await agent.run( + source_context="methodology text", + caption="figure caption", + examples=[ + ReferenceExample( + id="ref_001", + source_context="ctx", + caption="cap", + image_path=str(ref_img), + ) + ], + input_images=[str(sketch)], + ) + + assert description == "a detailed description" + assert ratio == "16:9" + # Exemplar image first, user sketch attached last. + images = vlm.captured["images"] + assert len(images) == 2 + assert images[0].size == (2, 2) + assert images[-1].size == (4, 4) + # The prompt labels the trailing image parts as user-provided. + assert "User-Provided Reference/Sketch" in vlm.captured["prompt"] + + +async def test_planner_without_user_images_keeps_prompt_unchanged(tmp_path): + """No sketch label or extra image parts appear when input_images is absent.""" + ref_img = tmp_path / "ref.png" + Image.new("RGB", (2, 2), color=(255, 0, 0)).save(ref_img) + + vlm = _CapturingVLM() + agent = PlannerAgent(vlm, prompt_dir=find_prompt_dir()) + + await agent.run( + source_context="methodology text", + caption="figure caption", + examples=[ + ReferenceExample( + id="ref_001", + source_context="ctx", + caption="cap", + image_path=str(ref_img), + ) + ], + ) + + assert len(vlm.captured["images"]) == 1 + assert "User-Provided Reference/Sketch" not in vlm.captured["prompt"] + + def test_load_example_images_loads_from_url(monkeypatch): """_load_example_images fetches and loads images from http(s) URLs.""" agent = PlannerAgent(_MockVLM()) diff --git a/tests/test_agents/test_visualizer.py b/tests/test_agents/test_visualizer.py index a8eb5de..73eb275 100644 --- a/tests/test_agents/test_visualizer.py +++ b/tests/test_agents/test_visualizer.py @@ -135,3 +135,36 @@ def test_ratio_to_dimensions_supports_high_resolution_landscape(): def test_ratio_to_dimensions_supports_2k_landscape(): assert VisualizerAgent._ratio_to_dimensions("16:9", output_resolution="2k") == (2048, 1152) + + +# ── Sketch-guided prompt note ───────────────────────────────────────────────── + + +class _CapturingImageGen: + def __init__(self): + self.captured = {} + + async def generate(self, prompt, **kwargs): + from PIL import Image + + self.captured["prompt"] = prompt + return Image.new("RGB", (8, 8), color=(255, 255, 255)) + + +async def test_generate_diagram_notes_user_sketch_when_guided(tmp_path): + """sketch_guided=True adds a one-line mention to the diagram prompt.""" + from paperbanana.core.utils import find_prompt_dir + + gen = _CapturingImageGen() + agent = VisualizerAgent( + image_gen=gen, + vlm_provider=_DummyVLM(), + prompt_dir=find_prompt_dir(), + output_dir=str(tmp_path), + ) + + await agent.run(description="a diagram description", sketch_guided=True) + assert "user-provided reference sketch" in gen.captured["prompt"] + + await agent.run(description="a diagram description") + assert "user-provided reference sketch" not in gen.captured["prompt"] diff --git a/tests/test_cli.py b/tests/test_cli.py index 2a04d93..6e06245 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1405,3 +1405,124 @@ def test_continue_run_missing_path_reports_resolved_path(tmp_path): assert result.exit_code == 1 assert "Run directory not found" in flat assert "run_x" in flat + + +# ── --image (reference/sketch) flag tests ───────────────────────────────────── + + +def _strip_ansi(output: str) -> str: + """Strip ANSI color escapes so assertions are stable under rich/CI.""" + import re + + return re.sub(r"\x1b\[[0-9;]*m", "", output) + + +def _write_png(path: Path, size=(4, 4)) -> Path: + from PIL import Image + + Image.new("RGB", size, color=(0, 0, 255)).save(path) + return path + + +def test_generate_image_flag_round_trip(tmp_path, monkeypatch): + """Repeatable --image paths propagate into GenerationInput.input_images.""" + import paperbanana.core.types as types_mod + + captured: dict[str, object] = {} + real_init = types_mod.GenerationInput.__init__ + + def _spy_init(self, **kwargs): + captured.update(kwargs) + real_init(self, **kwargs) + + monkeypatch.setattr(types_mod.GenerationInput, "__init__", _spy_init) + + input_path = tmp_path / "method.txt" + input_path.write_text("Sample methodology text for testing.", encoding="utf-8") + sketch1 = _write_png(tmp_path / "sketch1.png") + sketch2 = _write_png(tmp_path / "sketch2.png") + + result = runner.invoke( + app, + [ + "generate", + "--input", + str(input_path), + "--caption", + "test", + "--image", + str(sketch1), + "--image", + str(sketch2), + "--dry-run", + ], + ) + + output = _strip_ansi(result.output) + assert result.exit_code == 0 + assert captured["input_images"] == [str(sketch1), str(sketch2)] + assert "Reference images:" in output + + +def test_generate_image_flag_missing_file_errors(tmp_path): + """--image with a nonexistent path fails before the pipeline starts.""" + input_path = tmp_path / "method.txt" + input_path.write_text("Sample methodology text for testing.", encoding="utf-8") + + result = runner.invoke( + app, + [ + "generate", + "--input", + str(input_path), + "--caption", + "test", + "--image", + str(tmp_path / "missing_sketch.png"), + "--dry-run", + ], + ) + + output = _strip_ansi(result.output) + assert result.exit_code == 1 + assert "Image file not found" in output + + +def test_generate_image_flag_rejects_non_raster_file(tmp_path): + """--image with a non-image file (e.g. text with .png extension) errors clearly.""" + input_path = tmp_path / "method.txt" + input_path.write_text("Sample methodology text for testing.", encoding="utf-8") + fake_image = tmp_path / "fake.png" + fake_image.write_text("this is not an image", encoding="utf-8") + + result = runner.invoke( + app, + [ + "generate", + "--input", + str(input_path), + "--caption", + "test", + "--image", + str(fake_image), + "--dry-run", + ], + ) + + output = _strip_ansi(result.output) + assert result.exit_code == 1 + assert "Not a valid raster image" in output + + +def test_generate_image_flag_rejected_with_continue(tmp_path): + """--image cannot be combined with --continue / --continue-run.""" + sketch = _write_png(tmp_path / "sketch.png") + + result = runner.invoke( + app, + ["generate", "--continue", "--image", str(sketch)], + ) + + output = _strip_ansi(result.output) + assert result.exit_code == 1 + assert "--image cannot be used with --continue" in output diff --git a/tests/test_mcp/test_server_tools.py b/tests/test_mcp/test_server_tools.py index 75c3b0b..3315926 100644 --- a/tests/test_mcp/test_server_tools.py +++ b/tests/test_mcp/test_server_tools.py @@ -53,3 +53,31 @@ def test_no_unexpected_tools(): def test_every_tool_has_description(): undocumented = [t.name for t in _list_tools() if not (t.description or "").strip()] assert not undocumented, f"MCP tools without descriptions: {undocumented}" + + +def test_generate_diagram_exposes_input_images_param(): + """generate_diagram accepts optional input_images (reference/sketch paths).""" + tool = next(t for t in _list_tools() if t.name == "generate_diagram") + assert "input_images" in tool.parameters.get("properties", {}) + + +def test_validate_input_images_rejects_missing_and_non_raster(tmp_path): + from mcp_server.server import _validate_input_images + + # Missing file + with pytest.raises(ValueError, match="not found"): + _validate_input_images([str(tmp_path / "missing.png")]) + + # Non-raster file with image extension + fake = tmp_path / "fake.png" + fake.write_text("not an image", encoding="utf-8") + with pytest.raises(ValueError, match="raster image"): + _validate_input_images([str(fake)]) + + # Valid tiny PNG passes + from PIL import Image + + real = tmp_path / "real.png" + Image.new("RGB", (2, 2), color=(255, 0, 0)).save(real) + assert _validate_input_images([str(real)]) == [str(real)] + assert _validate_input_images(None) == [] diff --git a/tests/test_pipeline/test_input_images.py b/tests/test_pipeline/test_input_images.py new file mode 100644 index 0000000..99c744e --- /dev/null +++ b/tests/test_pipeline/test_input_images.py @@ -0,0 +1,115 @@ +"""Pipeline wiring tests for user-provided reference/sketch images.""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock + +import pytest +from PIL import Image + +from paperbanana.core.config import Settings +from paperbanana.core.pipeline import PaperBananaPipeline +from paperbanana.core.types import CritiqueResult, DiagramType, GenerationInput + + +class _MockVLM: + name = "mock-vlm" + model_name = "mock-model" + + def __init__(self, responses: list[str]): + self._responses = responses + self._idx = 0 + + async def generate(self, *args, **kwargs): + idx = min(self._idx, len(self._responses) - 1) + self._idx += 1 + return self._responses[idx] + + +class _MockImageGen: + name = "mock-image-gen" + model_name = "mock-image-model" + + async def generate(self, *args, **kwargs): + return Image.new("RGB", (128, 128), color=(255, 255, 255)) + + +def _make_sketch(tmp_path) -> str: + sketch = tmp_path / "sketch.png" + Image.new("RGB", (4, 4), color=(0, 0, 255)).save(sketch) + return str(sketch) + + +def _make_pipeline(tmp_path) -> PaperBananaPipeline: + settings = Settings( + output_dir=str(tmp_path / "outputs"), + reference_set_path=str(tmp_path / "empty_refs"), + refinement_iterations=1, + ) + vlm = _MockVLM( + responses=[ + "planner description", + "styled description", + json.dumps({"critic_suggestions": [], "revised_description": None}), + ] + ) + return PaperBananaPipeline(settings=settings, vlm_client=vlm, image_gen_fn=_MockImageGen()) + + +@pytest.mark.asyncio +async def test_planner_receives_input_images_and_critic_does_not(tmp_path): + """The sketch reaches the Planner; the Critic judges against source text only.""" + sketch_path = _make_sketch(tmp_path) + pipeline = _make_pipeline(tmp_path) + + pipeline.retriever.run = AsyncMock(return_value=[]) + pipeline.planner.run = AsyncMock(return_value=("planner description", None)) + pipeline.critic.run = AsyncMock(return_value=CritiqueResult()) + + await pipeline.generate( + GenerationInput( + source_context="source context", + communicative_intent="caption", + diagram_type=DiagramType.METHODOLOGY, + input_images=[sketch_path], + ) + ) + + # Planner got the sketch paths. + planner_kwargs = pipeline.planner.run.await_args.kwargs + assert planner_kwargs["input_images"] == [sketch_path] + + # Critic was called, but never with the sketch (no image/path argument + # other than the generated image; no mention of the sketch path at all). + pipeline.critic.run.assert_awaited() + critic_kwargs = pipeline.critic.run.await_args.kwargs + assert "input_images" not in critic_kwargs + assert "images" not in critic_kwargs + for value in critic_kwargs.values(): + if isinstance(value, str): + assert sketch_path not in value + elif isinstance(value, (list, tuple)): + assert sketch_path not in value + + +@pytest.mark.asyncio +async def test_run_input_json_records_input_images(tmp_path): + """run_input.json persists input_images for reproducibility.""" + sketch_path = _make_sketch(tmp_path) + pipeline = _make_pipeline(tmp_path) + pipeline.retriever.run = AsyncMock(return_value=[]) + pipeline.planner.run = AsyncMock(return_value=("planner description", None)) + pipeline.critic.run = AsyncMock(return_value=CritiqueResult()) + + await pipeline.generate( + GenerationInput( + source_context="source context", + communicative_intent="caption", + diagram_type=DiagramType.METHODOLOGY, + input_images=[sketch_path], + ) + ) + + run_input = json.loads((pipeline._run_dir / "run_input.json").read_text(encoding="utf-8")) + assert run_input["input_images"] == [sketch_path] diff --git a/tests/test_pipeline/test_types.py b/tests/test_pipeline/test_types.py index 9be74c2..ce0fcf6 100644 --- a/tests/test_pipeline/test_types.py +++ b/tests/test_pipeline/test_types.py @@ -44,6 +44,35 @@ def test_generation_input_with_invalid_aspect_ratio_raises(): ) +def test_generation_input_input_images_default_empty(): + """input_images defaults to an empty list.""" + gi = GenerationInput( + source_context="Test methodology", + communicative_intent="Test caption", + ) + assert gi.input_images == [] + + +def test_generation_input_accepts_input_images(): + """input_images accepts a list of path strings.""" + gi = GenerationInput( + source_context="Test methodology", + communicative_intent="Test caption", + input_images=["sketch.png", "prior_figure.jpg"], + ) + assert gi.input_images == ["sketch.png", "prior_figure.jpg"] + + +def test_generation_input_rejects_non_list_input_images(): + """input_images must be a list of strings.""" + with pytest.raises(ValueError): + GenerationInput( + source_context="Test methodology", + communicative_intent="Test caption", + input_images={"not": "a list"}, + ) + + def test_generation_input_plot(): """Test GenerationInput for statistical plots.""" gi = GenerationInput(