Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions paperbanana/agents/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ async def run(
if user_images:
# Appended pre-format so the prompt recorder captures it; the note
# is brace-free, keeping str.format() on the template intact.
template += "\n\n" + self._format_user_image_note(len(user_images))
template += "\n\n" + self._format_user_image_note(
len(user_images), offset=len(example_images)
)
# Inject supported ratios into the prompt template
ratios_str = ", ".join(supported_ratios) if supported_ratios else "1:1, 16:9"
prompt = self.format_prompt(
Expand Down Expand Up @@ -258,16 +260,27 @@ def _load_example_images(self, examples: list[ReferenceExample]) -> list:
return images

@staticmethod
def _format_user_image_note(count: int, offset: int = 0) -> str:
    """Build the prompt section labelling user-provided reference/sketch images.

    The section states how many trailing attachments are user-provided and
    lists the 1-based attachment position of each one. The returned text
    contains no ``{`` or ``}`` characters.

    Args:
        count: Number of user-provided images attached to the prompt.
        offset: Number of images attached before the user-provided ones;
            user image positions start at ``offset + 1``.

    Returns:
        The markdown section as a single string, or ``""`` when ``count``
        is zero or negative.
    """
    if count <= 0:
        return ""

    first = offset + 1
    if count == 1:
        positions = f"attached image {first}"
    else:
        positions = f"attached images {first}-{offset + count}"

    lines = [
        "## User-Provided Reference/Sketch Images",
        (
            f"The final {count} attached image(s) ({positions}) are user-provided "
            "reference/sketch images (e.g. a hand-drawn sketch, whiteboard photo, "
            "or a prior version of the figure). Any earlier attached images are "
            "retrieved reference examples. Use the user-provided images as guidance "
            "for layout and content while staying faithful to the source text."
        ),
    ]
    # One bullet per user image, mapping its ordinal to its attachment position.
    lines.extend(
        f"- User reference/sketch image {i + 1}: attached image {first + i}"
        for i in range(count)
    )
    return "\n".join(lines)

def _load_input_images(self, paths: list[str]) -> list:
"""Load user-provided reference/sketch images from local paths.
Expand Down
2 changes: 1 addition & 1 deletion paperbanana/agents/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def run(
)

# Note text for plans that were guided by a user-provided reference
# sketch/image. Adjacent string literals are concatenated, so exactly one
# opening sentence is kept here; a second one would duplicate the note.
_SKETCH_GUIDED_NOTE = (
    "Note: this plan was guided by a user-provided reference sketch/image; "
    "follow the description above faithfully."
)

Expand Down
50 changes: 30 additions & 20 deletions paperbanana/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,35 @@ def _check_pdf_dep(path: Path) -> None:
_require_pdf_dep()


def _validate_input_image_paths(image_paths: Optional[list[str]]) -> list[str]:
    """Validate repeatable --image paths and return absolute path strings.

    Each path is user-expanded (``~``), must exist, must be a file, and must
    pass Pillow's ``Image.verify()``. Input order is preserved.

    Args:
        image_paths: Raw path strings from the command line, or ``None``.

    Returns:
        The resolved absolute path of every input, in input order. An empty
        list when ``image_paths`` is ``None`` or empty.

    Raises:
        typer.Exit: With code 1, after printing an error for the first path
            that is missing, is not a file, or cannot be verified as an image.
    """
    if not image_paths:
        return []

    # Imported after the empty-input return, so that path never imports Pillow.
    from PIL import Image, UnidentifiedImageError

    validated: list[str] = []
    for raw in image_paths:
        path = Path(raw).expanduser()
        if not path.exists():
            console.print(f"[red]Error: Reference image not found: {raw}[/red]")
            raise typer.Exit(1)
        if not path.is_file():
            console.print(f"[red]Error: Reference image is not a file: {raw}[/red]")
            raise typer.Exit(1)
        try:
            with Image.open(path) as img:
                img.verify()
        except (UnidentifiedImageError, OSError, ValueError) as e:
            console.print(
                f"[red]Error: Reference image is not a readable raster image: {raw}[/red]"
            )
            console.print(f"[dim]{e}[/dim]")
            # Chain the cause so the original failure stays attached to the exit.
            raise typer.Exit(1) from e
        validated.append(str(path.resolve()))
    return validated


def _require_studio_dep() -> None:
"""Raise a clean error if Gradio is not installed."""
try:
Expand Down Expand Up @@ -494,26 +523,7 @@ def generate(
raise typer.Exit(1)

# Validate reference/sketch images before any pipeline work starts.
input_images: list[str] = []
if image:
from PIL import Image as PILImage
from PIL import UnidentifiedImageError

for image_path in image:
img_file = Path(image_path)
if not img_file.is_file():
console.print(f"[red]Error: Image file not found: {image_path}[/red]")
raise typer.Exit(1)
try:
with PILImage.open(img_file) as im:
im.verify()
except (UnidentifiedImageError, OSError, ValueError):
console.print(
f"[red]Error: Not a valid raster image (e.g. PNG, JPEG, WebP): "
f"{image_path}[/red]"
)
raise typer.Exit(1)
input_images.append(str(img_file))
input_images = _validate_input_image_paths(image)

_valid_categories = {
"agent_reasoning",
Expand Down
3 changes: 3 additions & 0 deletions paperbanana/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,7 @@ async def generate(
run_id=self.run_id,
diagram_type=input.diagram_type.value,
context_length=len(input.source_context),
input_image_count=len(input.input_images),
)

# Save input for resume/continue support
Expand Down Expand Up @@ -1559,6 +1560,7 @@ async def generate(
save_json(
{
"retrieved_examples": [e.id for e in examples],
"input_images": input.input_images,
"initial_description": description,
"optimized_description": optimized_description,
"planner_recommended_ratio": planner_ratio,
Expand Down Expand Up @@ -1806,6 +1808,7 @@ async def generate(
"external_enabled": self.settings.exemplar_retrieval_enabled,
"external_candidate_ids": external_candidate_ids,
}
metadata_dict["input_images"] = {"count": len(input.input_images)}
if rollback_info is not None:
metadata_dict["rollback"] = rollback_info
if candidates_meta is not None:
Expand Down
76 changes: 76 additions & 0 deletions tests/test_agents/test_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ async def test_planner_attaches_user_sketch_images_after_exemplars(tmp_path):
assert images[-1].size == (4, 4)
# The prompt labels the trailing image parts as user-provided.
assert "User-Provided Reference/Sketch" in vlm.captured["prompt"]
assert "attached image 2" in vlm.captured["prompt"]
assert "User reference/sketch image 1: attached image 2" in vlm.captured["prompt"]


async def test_planner_without_user_images_keeps_prompt_unchanged(tmp_path):
Expand Down Expand Up @@ -169,3 +171,77 @@ def test_load_example_images_loads_from_url(monkeypatch):
images = agent._load_example_images(examples)
assert len(images) == 1
assert images[0].size == (1, 1)


# ── user-provided reference/sketch images (issue #223) ──────────────


def test_format_input_image_guidance_identifies_last_images():
    """The note names each user image and its position after the offset."""
    note = PlannerAgent(_MockVLM())._format_user_image_note(2, offset=3)

    expected_fragments = (
        "User-Provided Reference/Sketch Images",
        "final 2 attached image(s)",
        "attached images 4-5",
        "User reference/sketch image 1: attached image 4",
        "User reference/sketch image 2: attached image 5",
    )
    for fragment in expected_fragments:
        assert fragment in note


def test_format_input_image_guidance_single_image_no_offset():
    """A lone user image with offset 0 is reported as attached image 1."""
    note = PlannerAgent(_MockVLM())._format_user_image_note(1, offset=0)

    for fragment in (
        "attached image 1",
        "User reference/sketch image 1: attached image 1",
    ):
        assert fragment in note


def test_format_input_image_guidance_empty():
    """A zero count produces an empty string, i.e. no guidance section."""
    planner = PlannerAgent(_MockVLM())
    note = planner._format_user_image_note(0)
    assert note == ""


def test_load_input_images_skips_unreadable_paths(tmp_path):
    """Only the existing image is returned; the missing path raises nothing."""
    existing = tmp_path / "sketch.png"
    Image.new("RGB", (2, 2), color=(0, 255, 0)).save(existing)
    missing = tmp_path / "missing.png"

    planner = PlannerAgent(_MockVLM())
    loaded = planner._load_input_images([str(existing), str(missing)])

    assert len(loaded) == 1
    assert loaded[0].size == (2, 2)


def test_run_passes_user_images_after_examples(tmp_path):
    """run() forwards the user image to the VLM and adds the guidance section."""
    import asyncio

    seen = {}

    class _RecordingVLM(_MockVLM):
        # Records what the planner sends instead of calling a real model.
        async def generate(self, prompt, images=None, **kwargs):
            seen["prompt"] = prompt
            seen["images"] = images
            return "a diagram description"

    planner = PlannerAgent(_RecordingVLM())
    sketch_path = tmp_path / "sketch.png"
    Image.new("RGB", (3, 3), color=(0, 0, 255)).save(sketch_path)

    coroutine = planner.run(
        source_context="Our method has two stages.",
        caption="Overview of our framework",
        examples=[],
        input_images=[str(sketch_path)],
    )
    description, _ratio = asyncio.run(coroutine)

    assert description == "a diagram description"

    sent_images = seen["images"]
    assert len(sent_images) == 1
    assert sent_images[0].size == (3, 3)

    sent_prompt = seen["prompt"]
    assert "User-Provided Reference/Sketch Images" in sent_prompt
    assert "attached image 1" in sent_prompt
65 changes: 53 additions & 12 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import tempfile
from pathlib import Path

import pytest
import typer
from typer.testing import CliRunner

from paperbanana.cli import app
from paperbanana.cli import _validate_input_image_paths, app

runner = CliRunner()
HELP_TERMINAL_WIDTH = 200
Expand Down Expand Up @@ -1420,6 +1422,7 @@ def _strip_ansi(output: str) -> str:
def _write_png(path: Path, size=(4, 4)) -> Path:
    """Save a solid-blue RGB image of ``size`` at ``path`` and return ``path``.

    Missing parent directories are created first.
    """
    from PIL import Image

    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)

    blue_image = Image.new("RGB", size, color=(0, 0, 255))
    blue_image.save(path)
    return path

Expand Down Expand Up @@ -1460,10 +1463,47 @@ def _spy_init(self, **kwargs):

output = _strip_ansi(result.output)
assert result.exit_code == 0
assert captured["input_images"] == [str(sketch1), str(sketch2)]
assert captured["input_images"] == [str(sketch1.resolve()), str(sketch2.resolve())]
assert "Reference images:" in output


def test_validate_input_image_paths_returns_absolute_paths(tmp_path):
    """Validated paths come back resolved and in the order they were given."""
    sketches = [_write_png(tmp_path / name) for name in ("sketch1.png", "sketch2.png")]

    validated = _validate_input_image_paths([str(p) for p in sketches])

    assert validated == [str(p.resolve()) for p in sketches]


def test_validate_input_image_paths_empty_input():
    """Both None and an empty list validate to an empty list."""
    for empty_value in (None, []):
        assert _validate_input_image_paths(empty_value) == []


def test_validate_input_image_paths_rejects_missing_file(tmp_path):
    """A path that does not exist makes the validator raise typer.Exit."""
    missing = tmp_path / "missing.png"

    with pytest.raises(typer.Exit):
        _validate_input_image_paths([str(missing)])


def test_validate_input_image_paths_rejects_directory(tmp_path):
    """Passing a directory instead of a file makes the validator raise typer.Exit."""
    directory = str(tmp_path)

    with pytest.raises(typer.Exit):
        _validate_input_image_paths([directory])


def test_validate_input_image_paths_rejects_non_image(tmp_path):
    """A text file with a .png name is rejected with typer.Exit."""
    not_an_image = tmp_path / "fake.png"
    not_an_image.write_text("this is not an image", encoding="utf-8")

    candidate_paths = [str(not_an_image)]
    with pytest.raises(typer.Exit):
        _validate_input_image_paths(candidate_paths)


def test_generate_image_flag_missing_file_errors(tmp_path):
"""--image with a nonexistent path fails before the pipeline starts."""
input_path = tmp_path / "method.txt"
Expand All @@ -1485,7 +1525,7 @@ def test_generate_image_flag_missing_file_errors(tmp_path):

output = _strip_ansi(result.output)
assert result.exit_code == 1
assert "Image file not found" in output
assert "Reference image not found" in output


def test_generate_image_flag_rejects_non_raster_file(tmp_path):
Expand All @@ -1511,18 +1551,19 @@ def test_generate_image_flag_rejects_non_raster_file(tmp_path):

output = _strip_ansi(result.output)
assert result.exit_code == 1
assert "Not a valid raster image" in output
assert "not a readable raster image" in output


def test_generate_image_flag_rejected_with_continue(tmp_path):
    """--image cannot be combined with --continue / --continue-run."""
    sketch = _write_png(tmp_path / "sketch.png")

    # Both resume flags must be rejected, so each is exercised in turn.
    for continue_args in (["--continue"], ["--continue-run", "run_x"]):
        result = runner.invoke(
            app,
            ["generate", *continue_args, "--image", str(sketch)],
            terminal_width=HELP_TERMINAL_WIDTH,
        )
        # Drop newlines so the message still matches if the output was wrapped.
        flat = result.output.replace("\n", "")
        assert result.exit_code == 1
        assert "--image cannot be used with --continue" in flat
Loading