Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,19 @@ paperbanana generate \
--input paper.pdf \
--caption "Overview of our method" \
--pdf-pages "3-8"

# Guide generation with a reference/sketch image (repeatable)
paperbanana generate \
--input method.txt \
--caption "Overview of our framework" \
--image sketch.png --image prior_figure.png
```

| Flag | Short | Description |
|------|-------|-------------|
| `--input` | `-i` | Path to methodology text file or PDF (required for new runs) |
| `--caption` | `-c` | Figure caption / communicative intent (required for new runs) |
| `--image` | | Reference/sketch image (hand-drawn sketch, whiteboard photo, prior figure) that guides the Planner. Repeatable for multiple images |
| `--output` | `-o` | Output image path (default: auto-generated in `outputs/`) |
| `--iterations` | `-n` | Number of Visualizer-Critic refinement rounds (default: 3) |
| `--num-candidates` | `-k` | Generate N candidate images in parallel, 1-8 (default: 1). Planning runs once; refinement fans out per candidate with seed offsets. Outputs land in `candidates/cand_<i>/`; the run-root `final_output` is candidate 1. Cost estimates and `--budget` account for the fan-out |
Expand Down
29 changes: 29 additions & 0 deletions mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,28 @@ def _embed_caption(image_path: str, caption: str) -> None:
mcp = FastMCP("PaperBanana")


def _validate_input_images(input_images: list[str] | None) -> list[str]:
"""Validate user-provided reference/sketch image paths before the pipeline starts.

Each path must exist and be a PIL-openable raster image.

Raises:
ValueError: If a path is missing or not a valid raster image.
"""
validated: list[str] = []
for image_path in input_images or []:
path = Path(image_path)
if not path.is_file():
raise ValueError(f"Image file not found: {image_path}")
try:
with PILImage.open(path) as im:
im.verify()
except Exception:
raise ValueError(f"Not a valid raster image (e.g. PNG, JPEG, WebP): {image_path}")
validated.append(str(path))
return validated


@mcp.tool
async def generate_diagram(
source_context: str,
Expand All @@ -181,6 +203,7 @@ async def generate_diagram(
optimize: bool = False,
auto_refine: bool = False,
generate_caption: bool = False,
input_images: list[str] | None = None,
) -> Image:
"""Generate a publication-quality methodology diagram from text.

Expand All @@ -197,10 +220,15 @@ async def generate_diagram(
generate_caption: Auto-generate a publication-ready figure caption
after generation. When True, the caption is embedded in the
image metadata (PNG tEXt chunk, key "Caption") and logged.
input_images: Optional file paths to user-provided reference/sketch
images (hand-drawn sketch, whiteboard photo, prior figure) that
guide the layout and content of the generated diagram.

Returns:
The generated diagram as a PNG image.
"""
validated_images = _validate_input_images(input_images)

settings = Settings(
refinement_iterations=iterations,
optimize_inputs=optimize,
Expand All @@ -223,6 +251,7 @@ def _on_progress(event: str, payload: dict) -> None:
communicative_intent=caption,
diagram_type=DiagramType.METHODOLOGY,
aspect_ratio=aspect_ratio,
input_images=validated_images,
)

result = await pipeline.generate(gen_input)
Expand Down
47 changes: 46 additions & 1 deletion paperbanana/agents/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def run(
examples: list[ReferenceExample],
diagram_type: DiagramType = DiagramType.METHODOLOGY,
supported_ratios: list[str] | None = None,
input_images: list[str] | None = None,
) -> tuple[str, str | None]:
"""Generate a detailed textual description of the target diagram.

Expand All @@ -55,6 +56,8 @@ async def run(
examples: Retrieved reference examples for in-context learning.
diagram_type: Type of diagram being generated.
supported_ratios: Aspect ratios the image provider supports.
input_images: Paths to user-provided reference/sketch images that
guide the plan alongside the retrieved exemplars.

Returns:
Tuple of (description, recommended_ratio).
Expand All @@ -66,8 +69,18 @@ async def run(
# Load reference images for visual in-context learning
example_images = await asyncio.to_thread(self._load_example_images, examples)

# Load user-provided reference/sketch images (attached after the
# exemplar images so "reference image N" indexing stays valid).
user_images: list = []
if input_images:
user_images = await asyncio.to_thread(self._load_input_images, input_images)

prompt_type = "diagram" if diagram_type == DiagramType.METHODOLOGY else "plot"
template = self.load_prompt(prompt_type)
if user_images:
# Appended pre-format so the prompt recorder captures it; the note
# is brace-free, keeping str.format() on the template intact.
template += "\n\n" + self._format_user_image_note(len(user_images))
# Inject supported ratios into the prompt template
ratios_str = ", ".join(supported_ratios) if supported_ratios else "1:1, 16:9"
prompt = self.format_prompt(
Expand All @@ -83,12 +96,14 @@ async def run(
"Running planner agent",
num_examples=len(examples),
num_images=len(example_images),
num_user_images=len(user_images),
context_length=len(source_context),
)

all_images = example_images + user_images
raw_output = await self.vlm.generate(
prompt=prompt,
images=example_images if example_images else None,
images=all_images if all_images else None,
temperature=0.7,
max_tokens=4096,
)
Expand Down Expand Up @@ -242,6 +257,36 @@ def _load_example_images(self, examples: list[ReferenceExample]) -> list:
)
return images

@staticmethod
def _format_user_image_note(count: int) -> str:
"""Label for user-provided reference/sketch images attached to the prompt."""
return (
"## User-Provided Reference/Sketch\n"
f"The final {count} attached image(s), after the reference example images, "
"are user-provided reference/sketch images (e.g. a hand-drawn sketch, "
"whiteboard photo, or a prior version of the figure). Use them as guidance "
"for the layout and content of the target diagram while staying faithful "
"to the source text."
)

def _load_input_images(self, paths: list[str]) -> list:
"""Load user-provided reference/sketch images from local paths.

Returns PIL Image objects; unreadable files are skipped with a warning
(the CLI/MCP entry points validate them before the pipeline starts).
"""
images = []
for path in paths:
try:
images.append(load_image(path))
except Exception as e:
logger.warning(
"Failed to load user-provided reference image",
image_path=path,
error=str(e),
)
return images

_VALID_RATIOS = {"1:1", "2:3", "3:2", "3:4", "4:3", "9:16", "16:9", "21:9"}

@classmethod
Expand Down
12 changes: 12 additions & 0 deletions paperbanana/agents/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ async def run(
seed: Optional[int] = None,
aspect_ratio: Optional[str] = None,
vector_formats: Optional[list[str]] = None,
sketch_guided: bool = False,
) -> str:
"""Generate an image from a description.

Expand All @@ -74,6 +75,8 @@ async def run(
aspect_ratio: Target aspect ratio (e.g., '16:9', '1:1').
vector_formats: Vector formats to export alongside raster (e.g., ['svg', 'pdf']).
Only applies to statistical plots; ignored for methodology diagrams.
sketch_guided: When True, the diagram prompt notes that a
user-provided reference sketch guided the plan.

Returns:
Path to the generated raster image.
Expand All @@ -90,18 +93,27 @@ async def run(
iteration,
seed,
aspect_ratio,
sketch_guided=sketch_guided,
)

# Sentence appended (after a blank line) to the diagram prompt template when
# ``sketch_guided`` is true, telling the image model to follow the description
# faithfully because the plan was guided by a user-provided sketch.
_SKETCH_GUIDED_NOTE = (
"Note: this plan was guided by a user-provided reference sketch; "
"follow the description above faithfully."
)

async def _generate_diagram(
self,
description: str,
output_path: Optional[str],
iteration: int,
seed: Optional[int],
aspect_ratio: Optional[str] = None,
sketch_guided: bool = False,
) -> str:
"""Generate a methodology diagram using the image generation model."""
template = self.load_prompt("diagram")
if sketch_guided:
template += "\n\n" + self._SKETCH_GUIDED_NOTE
prompt = self.format_prompt(
template,
prompt_label=f"visualizer_diagram_iter_{iteration}",
Expand Down
36 changes: 36 additions & 0 deletions paperbanana/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,14 @@ def generate(
caption: Optional[str] = typer.Option(
None, "--caption", "-c", help="Figure caption / communicative intent"
),
image: Optional[list[str]] = typer.Option(
None,
"--image",
help=(
"Path to a reference/sketch image (hand-drawn sketch, whiteboard photo, "
"prior figure) that guides generation. Repeatable for multiple images."
),
),
output: Optional[str] = typer.Option(None, "--output", "-o", help="Output image path"),
output_dir: Optional[str] = typer.Option(
None,
Expand Down Expand Up @@ -455,6 +463,31 @@ def generate(
"[red]Error: --pdf-pages cannot be used with --continue or --continue-run[/red]"
)
raise typer.Exit(1)
if image and (continue_last or continue_run):
console.print("[red]Error: --image cannot be used with --continue or --continue-run[/red]")
raise typer.Exit(1)

# Validate reference/sketch images before any pipeline work starts.
input_images: list[str] = []
if image:
from PIL import Image as PILImage
from PIL import UnidentifiedImageError

for image_path in image:
img_file = Path(image_path)
if not img_file.is_file():
console.print(f"[red]Error: Image file not found: {image_path}[/red]")
raise typer.Exit(1)
try:
with PILImage.open(img_file) as im:
im.verify()
except (UnidentifiedImageError, OSError, ValueError):
console.print(
f"[red]Error: Not a valid raster image (e.g. PNG, JPEG, WebP): "
f"{image_path}[/red]"
)
raise typer.Exit(1)
input_images.append(str(img_file))

_valid_categories = {
"agent_reasoning",
Expand Down Expand Up @@ -721,6 +754,7 @@ async def _run_continue():
diagram_type=DiagramType.METHODOLOGY,
aspect_ratio=aspect_ratio,
reference_ids=ref_id_list,
input_images=input_images,
)

# Determine expected output file extension based on settings.output_format
Expand Down Expand Up @@ -764,6 +798,8 @@ async def _run_continue():
pdf_note = ""
if input_path.suffix.lower() == ".pdf":
pdf_note = f"\nPDF pages: {pdf_pages.strip() if pdf_pages else 'all'}"
if input_images:
pdf_note += f"\nReference images: {', '.join(input_images)}"
console.print(
Panel.fit(
"[bold]PaperBanana[/bold] - Dry Run\n\n"
Expand Down
4 changes: 4 additions & 0 deletions paperbanana/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def _extra(**payload: Any) -> Dict[str, Any]:
seed=seed,
aspect_ratio=effective_ratio,
vector_formats=vector_formats,
sketch_guided=bool(input.input_images),
)
visualizer_seconds = time.perf_counter() - visualizer_start
if image_path is None:
Expand Down Expand Up @@ -1267,6 +1268,7 @@ async def generate(
"raw_data": input.raw_data,
"aspect_ratio": input.aspect_ratio,
"vector_export": self._effective_vector_export(input),
"input_images": input.input_images,
},
self._run_dir / "run_input.json",
)
Expand Down Expand Up @@ -1342,6 +1344,7 @@ async def generate(
diagram_type=input.diagram_type,
raw_data=input.raw_data,
aspect_ratio=input.aspect_ratio,
input_images=input.input_images,
)
except Exception:
optimize_seconds = time.perf_counter() - optimize_start
Expand Down Expand Up @@ -1471,6 +1474,7 @@ async def generate(
examples=examples,
diagram_type=input.diagram_type,
supported_ratios=getattr(self.visualizer.image_gen, "supported_ratios", None),
input_images=input.input_images,
)
planning_seconds = time.perf_counter() - planning_start
_emit_progress(
Expand Down
8 changes: 8 additions & 0 deletions paperbanana/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ class GenerationInput(BaseModel):
default=None,
description="Optional vector export (svg/pdf/both); None uses Settings.vector_export",
)
input_images: list[str] = Field(
default_factory=list,
description=(
"Paths to user-provided reference/sketch images (e.g. a hand-drawn "
"sketch, whiteboard photo, or prior figure version) that guide the "
"Planner alongside retrieved exemplars."
),
)

@field_validator("aspect_ratio")
@classmethod
Expand Down
Loading
Loading