diff --git a/examples/ai/chat/pydantic-ai-chat.py b/examples/ai/chat/pydantic-ai-chat.py index b57b615faad..b3d740e7066 100644 --- a/examples/ai/chat/pydantic-ai-chat.py +++ b/examples/ai/chat/pydantic-ai-chat.py @@ -6,9 +6,10 @@ # "pydantic==2.12.5", # ] # /// + import marimo -__generated_with = "0.21.1" +__generated_with = "0.23.6" app = marimo.App(width="medium") with app.setup(hide_code=True): @@ -16,7 +17,12 @@ import os import httpx - from pydantic_ai import Agent, RunContext, BinaryImage + from pydantic_ai import ( + Agent, + BinaryImage, + DeferredToolRequests, + RunContext, + ) from pydantic_ai.models.google import GoogleModel, GoogleModelSettings from pydantic_ai.providers.google import GoogleProvider from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings @@ -52,6 +58,7 @@ def _(): structured = mo.ui.checkbox(label="Structured outputs") thinking = mo.ui.checkbox(label="Reasoning") fetch_dog_tool = mo.ui.checkbox(label="Fetch dog pics tool") + delete_file_tool = mo.ui.checkbox(label="Delete file tool (requires approval)") models = mo.ui.dropdown( options={ @@ -64,8 +71,8 @@ def _(): label="Choose a model", ) - mo.vstack([models, structured, thinking, fetch_dog_tool]) - return fetch_dog_tool, models, structured, thinking + mo.vstack([models, structured, thinking, fetch_dog_tool, delete_file_tool]) + return delete_file_tool, fetch_dog_tool, models, structured, thinking @app.cell(hide_code=True) @@ -129,6 +136,7 @@ def get_model( @app.cell(hide_code=True) def _( + delete_file_tool, fetch_dog_tool, input_key, model_name, @@ -153,6 +161,15 @@ class CodeOutput(BaseModel): elif structured.value: output_type = [CodeOutput, str] + # Tools that pause for human approval require `DeferredToolRequests` + # in the output type; pydantic-ai returns it whenever a tool flagged + # `requires_approval=True` is called. + if delete_file_tool.value: + if isinstance(output_type, list): + output_type = [*output_type, DeferredToolRequests] + else: + output_type = [output_type, DeferredToolRequests] + agent = Agent( model, output_type=output_type, @@ -172,6 +189,14 @@ def fetch_dog_picture_url(ctx: RunContext[str]) -> str: return response_json["message"] else: return "Error fetching dog URL" + + + if delete_file_tool.value: + + @agent.tool_plain(requires_approval=True) + def delete_file(path: str) -> str: + """Pretend to delete the file at `path`.""" + return f"File {path!r} deleted" return (agent,) @@ -184,6 +209,7 @@ def _(agent): "Who is Ada Lovelace?", "What is marimo?", "I need dogs (render as markdown)", + "Delete the file at path 'secrets.env'", ], allow_attachments=True, show_configuration_controls=True, @@ -198,52 +224,115 @@ def _(chatbot): return -@app.cell +@app.cell(hide_code=True) def _(): - mo.md(""" - ## Custom Model Sample + mo.md(r""" + ## Custom model sample + + `mo.ui.chat` accepts any async generator that yields Vercel AI SDK chunks. + The model below is a hand-rolled showcase of every part the SDK knows + about — reasoning, streamed tool input, file/source/data attachments, + a deliberately failed tool, and a final tool that pauses for human + approval. """) return -@app.cell +@app.cell(hide_code=True) def _(): + import asyncio import uuid + import pydantic_ai.ui.vercel_ai.response_types as vercel + def _new_id(prefix: str) -> str: + return f"{prefix}_{uuid.uuid4().hex[:8]}" + + + def _pending_approval(messages) -> dict | None: + """Find a tool part the user just approved or denied, if any. + + After Approve/Deny, the SDK transitions the tool part on the last + assistant message to `approval-responded` and auto-resumes. We + look for that state on the most recent assistant turn so we know + whether to start a fresh showcase or finish the deletion. + """ + for message in reversed(messages): + if message.role != "assistant": + continue + for part in message.raw_or_dumped_parts(): + if not isinstance(part, dict): + continue + if not str(part.get("type", "")).startswith("tool-"): + continue + if part.get("state") == "approval-responded": + return part + return None + return None + + async def custom_model(messages, config): - # Generate unique IDs for message parts - reasoning_id = f"reasoning_{uuid.uuid4().hex}" - text_id = f"text_{uuid.uuid4().hex}" - tool_id = f"tool_{uuid.uuid4().hex}" + del config + + pending = _pending_approval(messages) + if pending is not None: + async for chunk in _resume_after_approval(pending): + yield chunk + return + + async for chunk in _showcase_turn(): + yield chunk + + + async def _showcase_turn(): + reasoning_id = _new_id("reasoning") + search_id = _new_id("tc") + translate_id = _new_id("tc") + delete_id = _new_id("tc") + approval_id = _new_id("ap") + intro_id = _new_id("text") + followup_id = _new_id("text") + error_text_id = _new_id("text") + ask_id = _new_id("text") + data_id = _new_id("data") + + # Message-level metadata round-trips on `message.metadata` in the UI. + yield vercel.MessageMetadataChunk( + message_metadata={"demo": "vercel-ai-sdk-showcase", "turn": 1} + ) - # --- Stream reasoning/thinking --- + # ── Step 1: think + run a tool that succeeds ────────────────── yield vercel.StartStepChunk() + yield vercel.ReasoningStartChunk(id=reasoning_id) - yield vercel.ReasoningDeltaChunk( - id=reasoning_id, - delta="The user is asking about Van Gogh. I should fetch information about his famous works.", - ) + for chunk in [ + "The user wants the full tour. ", + "I'll search for a famous painting, ", + "compose an answer with citations and an image, ", + "demonstrate an erroring tool, ", + "and finally offer to clean up a temp file ", + "behind a human-approval gate.", + ]: + yield vercel.ReasoningDeltaChunk(id=reasoning_id, delta=chunk) + await asyncio.sleep(0.04) yield vercel.ReasoningEndChunk(id=reasoning_id) - # --- Stream tool call to fetch artwork information --- + yield vercel.ToolInputStartChunk( + tool_call_id=search_id, tool_name="search_artwork" + ) + for delta in ['{"artist":', ' "Vincent van Gogh",', ' "limit": 1}']: + yield vercel.ToolInputDeltaChunk( + tool_call_id=search_id, input_text_delta=delta + ) + await asyncio.sleep(0.04) yield vercel.ToolInputAvailableChunk( - tool_call_id=tool_id, + tool_call_id=search_id, tool_name="search_artwork", input={"artist": "Vincent van Gogh", "limit": 1}, ) - yield vercel.ToolInputStartChunk( - tool_call_id=tool_id, tool_name="search_artwork" - ) - yield vercel.ToolInputDeltaChunk( - tool_call_id=tool_id, - input_text_delta='{"artist": "Vincent van Gogh", "limit": 1}', - ) - - # --- Tool output (simulated artwork search result) --- yield vercel.ToolOutputAvailableChunk( - tool_call_id=tool_id, + tool_call_id=search_id, output={ "title": "The Starry Night", "year": 1889, @@ -251,28 +340,170 @@ async def custom_model(messages, config): }, ) - # --- Stream text response --- - yield vercel.TextStartChunk(id=text_id) - yield vercel.TextDeltaChunk( - id=text_id, - delta="One of Vincent van Gogh's most iconic works is 'The Starry Night', painted in 1889. Here's the painting:\n\n", - ) + yield vercel.FinishStepChunk() + + # ── Step 2: compose the answer with rich media ──────────────── + yield vercel.StartStepChunk() + + yield vercel.TextStartChunk(id=intro_id) + for delta in [ + "One of Vincent van Gogh's most iconic works is ", + "**The Starry Night**, painted in 1889. ", + "Here is the painting:", + ]: + yield vercel.TextDeltaChunk(id=intro_id, delta=delta) + await asyncio.sleep(0.04) + yield vercel.TextEndChunk(id=intro_id) - # --- Embed the artwork image --- yield vercel.FileChunk( - url="https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg", + url=( + "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/" + "Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/" + "1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg" + ), media_type="image/jpeg", ) + + yield vercel.SourceUrlChunk( + source_id=_new_id("src"), + url="https://www.moma.org/collection/works/79802", + title="The Starry Night | MoMA", + ) + yield vercel.SourceDocumentChunk( + source_id=_new_id("src"), + media_type="application/pdf", + title="Faille catalogue raisonné, vol. III", + filename="van-gogh-catalogue.pdf", + ) + + # Custom data-* parts let backends ship arbitrary structured + # payloads to bespoke UI widgets without bending the text channel. + yield vercel.DataChunk( + id=data_id, + type="data-artwork-card", + data={ + "title": "The Starry Night", + "year": 1889, + "movement": "Post-Impressionism", + }, + ) + + yield vercel.TextStartChunk(id=followup_id) + yield vercel.TextDeltaChunk( + id=followup_id, + delta=( + "\n\nNext I'll try a translation tool that's expected to" + " fail — handy for seeing how errors render." + ), + ) + yield vercel.TextEndChunk(id=followup_id) + + yield vercel.FinishStepChunk() + + # ── Step 3: a tool whose execution fails ────────────────────── + yield vercel.StartStepChunk() + + yield vercel.ToolInputStartChunk( + tool_call_id=translate_id, tool_name="translate" + ) + yield vercel.ToolInputAvailableChunk( + tool_call_id=translate_id, + tool_name="translate", + input={"text": "Sterrennacht", "from": "nl", "to": "klingon"}, + ) + yield vercel.ToolOutputErrorChunk( + tool_call_id=translate_id, + error_text="UnsupportedLanguage: 'klingon' is not a supported target.", + ) + + yield vercel.TextStartChunk(id=error_text_id) yield vercel.TextDeltaChunk( - id=text_id, - delta="\nThis masterpiece is now housed at the Museum of Modern Art in New York and remains one of the most recognizable paintings in the world.", + id=error_text_id, + delta="That call failed, as expected — moving on.", ) - yield vercel.TextEndChunk(id=text_id) + yield vercel.TextEndChunk(id=error_text_id) + yield vercel.FinishStepChunk() - yield vercel.FinishChunk() + # ── Step 4: ask for approval, then stop ─────────────────────── + yield vercel.StartStepChunk() - custom_chat = mo.ui.chat(custom_model) + yield vercel.TextStartChunk(id=ask_id) + yield vercel.TextDeltaChunk( + id=ask_id, + delta=( + "I'd like to delete the search cache file. " + "Approve below to proceed, or deny to keep it." + ), + ) + yield vercel.TextEndChunk(id=ask_id) + + yield vercel.ToolInputStartChunk( + tool_call_id=delete_id, tool_name="delete_file" + ) + yield vercel.ToolInputAvailableChunk( + tool_call_id=delete_id, + tool_name="delete_file", + input={"path": "/tmp/van-gogh-search.cache"}, + ) + yield vercel.ToolApprovalRequestChunk( + approval_id=approval_id, tool_call_id=delete_id + ) + + yield vercel.FinishStepChunk() + yield vercel.FinishChunk(finish_reason="tool-calls") + + + async def _resume_after_approval(pending: dict): + tool_call_id = pending["toolCallId"] + approval = pending.get("approval") or {} + approved = bool(approval.get("approved")) + path = (pending.get("input") or {}).get("path", "") + + text_id = _new_id("text") + + yield vercel.MessageMetadataChunk( + message_metadata={ + "demo": "vercel-ai-sdk-showcase", + "turn": 2, + "approval": approval, + } + ) + yield vercel.StartStepChunk() + + if approved: + yield vercel.ToolOutputAvailableChunk( + tool_call_id=tool_call_id, + output={"deleted": True, "path": path}, + ) + yield vercel.TextStartChunk(id=text_id) + yield vercel.TextDeltaChunk( + id=text_id, delta=f"Done — `{path}` has been removed." + ) + yield vercel.TextEndChunk(id=text_id) + else: + yield vercel.ToolOutputDeniedChunk(tool_call_id=tool_call_id) + yield vercel.TextStartChunk(id=text_id) + yield vercel.TextDeltaChunk( + id=text_id, + delta=( + f"No problem — I'll leave `{path}` alone. " + f"Reason: {approval.get('reason') or 'no reason given'}." + ), + ) + yield vercel.TextEndChunk(id=text_id) + + yield vercel.FinishStepChunk() + yield vercel.FinishChunk(finish_reason="stop") + + + custom_chat = mo.ui.chat( + custom_model, + prompts=[ + "Run the full Vercel AI SDK part showcase", + "Show me reasoning, citations, and an approval-gated tool", + ], + ) custom_chat return (custom_chat,) diff --git a/frontend/src/components/chat/__tests__/chat-utils.test.ts b/frontend/src/components/chat/__tests__/chat-utils.test.ts new file mode 100644 index 00000000000..6b4ad171582 --- /dev/null +++ b/frontend/src/components/chat/__tests__/chat-utils.test.ts @@ -0,0 +1,269 @@ +/* Copyright 2026 Marimo. All rights reserved. */ + +import type { UIMessage } from "ai"; +import { describe, expect, it } from "vitest"; +import { hasPendingToolCalls } from "../chat-utils"; + +/** + * `hasPendingToolCalls` powers `sendAutomaticallyWhen` in `mo.ui.chat`: + * returns true only when the last assistant message *ends* with a tool + * call in a ready-to-round-trip state. Any trailing non-tool part (text, + * file, source-*, reasoning, data-*, new step-start) means the assistant + * has already answered and we leave the next turn to the user. The + * approval flow relies on this firing for `approval-responded`. + */ + +const userMessage = (text: string): UIMessage => ({ + id: `user-${text}`, + role: "user", + parts: [{ type: "text", text }], +}); + +const assistantToolMessage = ( + parts: UIMessage["parts"], + id = "assistant-1", +): UIMessage => ({ + id, + role: "assistant", + parts, +}); + +describe("hasPendingToolCalls", () => { + it("returns false when there are no messages", () => { + expect(hasPendingToolCalls([])).toBe(false); + }); + + it("returns false when the last message is a user message", () => { + expect(hasPendingToolCalls([userMessage("hi")])).toBe(false); + }); + + it("returns false when the last assistant message has no tool parts", () => { + expect( + hasPendingToolCalls([ + userMessage("hi"), + assistantToolMessage([{ type: "text", text: "hello!" }]), + ]), + ).toBe(false); + }); + + it("returns false while a tool is still streaming or awaiting approval", () => { + expect( + hasPendingToolCalls([ + userMessage("delete it"), + assistantToolMessage([ + { + type: "tool-delete_file", + toolCallId: "call-1", + state: "approval-requested", + input: { path: "secrets.env" }, + approval: { id: "approval-1" }, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns true when the user has responded to an approval request", () => { + // The chat must auto-resume as soon as Approve/Deny is clicked. + expect( + hasPendingToolCalls([ + userMessage("delete it"), + assistantToolMessage([ + { + type: "tool-delete_file", + toolCallId: "call-1", + state: "approval-responded", + input: { path: "secrets.env" }, + approval: { id: "approval-1", approved: true }, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(true); + }); + + it("returns true when a tool reached a terminal output state", () => { + expect( + hasPendingToolCalls([ + userMessage("run it"), + assistantToolMessage([ + { + type: "tool-run_query", + toolCallId: "call-1", + state: "output-available", + input: { sql: "select 1" }, + output: 1, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(true); + }); + + it("returns false when only some tool calls are ready", () => { + expect( + hasPendingToolCalls([ + userMessage("two things"), + assistantToolMessage([ + { + type: "tool-first", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + { + type: "tool-second", + toolCallId: "call-2", + state: "input-available", + input: {}, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns false once the assistant has appended text after the tool result", () => { + expect( + hasPendingToolCalls([ + userMessage("run it"), + assistantToolMessage([ + { + type: "tool-run_query", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + { type: "text", text: "The query returned 1." }, + ]), + ]), + ).toBe(false); + }); + + it("returns false when a file part trails the completed tool call", () => { + // Regression: tool → text → file used to loop because only trailing + // text counted as "the assistant has answered". + expect( + hasPendingToolCalls([ + userMessage("show me Starry Night"), + assistantToolMessage([ + { type: "step-start" }, + { + type: "tool-search_artwork", + toolCallId: "call-1", + state: "output-available", + input: { artist: "Van Gogh" }, + output: { title: "The Starry Night" }, + } as unknown as UIMessage["parts"][number], + { type: "text", text: "Here is the painting:" }, + { + type: "file", + mediaType: "image/jpeg", + url: "https://example.com/starry-night.jpg", + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns false when a source-url part trails the completed tool call", () => { + expect( + hasPendingToolCalls([ + userMessage("cite your sources"), + assistantToolMessage([ + { + type: "tool-web_search", + toolCallId: "call-1", + state: "output-available", + input: { q: "marimo notebook" }, + output: "found", + } as unknown as UIMessage["parts"][number], + { type: "text", text: "marimo is a reactive notebook." }, + { + type: "source-url", + sourceId: "src-1", + url: "https://marimo.io", + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns false when a reasoning part trails the completed tool call", () => { + expect( + hasPendingToolCalls([ + userMessage("explain"), + assistantToolMessage([ + { + type: "tool-lookup", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + { + type: "reasoning", + text: "Now I'll summarize.", + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns false when a new step-start follows the completed tool call", () => { + expect( + hasPendingToolCalls([ + userMessage("multi-step"), + assistantToolMessage([ + { type: "step-start" }, + { + type: "tool-run_query", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + { type: "step-start" }, + ]), + ]), + ).toBe(false); + }); + + it("ignores providerExecuted tools", () => { + // Provider-side tools are resolved by the model, not the runtime, so + // they must not drive an auto-resume. + expect( + hasPendingToolCalls([ + userMessage("hi"), + assistantToolMessage([ + { + type: "tool-web_search", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + providerExecuted: true, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns true for dynamic-tool parts in a terminal state", () => { + // `dynamic-tool` parts must drive auto-resume alongside `tool-*`. + expect( + hasPendingToolCalls([ + userMessage("run it"), + assistantToolMessage([ + { + type: "dynamic-tool", + toolName: "run_query", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(true); + }); +}); diff --git a/frontend/src/components/chat/chat-utils.ts b/frontend/src/components/chat/chat-utils.ts index fbb3552f30d..c09195345af 100644 --- a/frontend/src/components/chat/chat-utils.ts +++ b/frontend/src/components/chat/chat-utils.ts @@ -5,7 +5,8 @@ import { type ChatAddToolOutputFunction, type FileUIPart, isToolUIPart, - type ToolUIPart, + lastAssistantMessageIsCompleteWithApprovalResponses, + lastAssistantMessageIsCompleteWithToolCalls, type UIMessage, } from "ai"; import { useState } from "react"; @@ -17,7 +18,6 @@ import type { InvokeAiToolRequest, InvokeAiToolResponse, } from "@/core/network/types"; -import { logNever } from "@/utils/assertNever"; import { blobToString } from "@/utils/fileToBase64"; import { Logger } from "@/utils/Logger"; import { getAICompletionBodyWithAttachments } from "../editor/ai/completion-utils"; @@ -169,69 +169,25 @@ export async function handleToolCall({ } /** - * Returns true if a tool call is "ready to be sent back to the server" — i.e. - * either it has reached a terminal output state, or the user has just supplied - * an approval response that the server hasn't seen yet. - */ -function isToolCallReadyToSend(state: ToolUIPart["state"]): boolean { - switch (state) { - case "output-available": - case "output-error": - case "output-denied": - case "approval-responded": - return true; - case "input-streaming": - case "input-available": - case "approval-requested": - return false; - default: - logNever(state); - return false; - } -} - -/** - * Checks if we should send a message automatically based on the messages. - * We auto-send when every tool call on the last assistant message has either - * finished (output-available/error/denied) or has just received a user - * approval response, and the assistant hasn't replied yet. + * Auto-send the next turn when the last assistant message ends with a + * tool call ready to round-trip. Any non-tool trailing part (text, file, + * source-*, reasoning, data-*, new step-start) means the assistant has + * already answered, so we leave the next turn to the user. State checks + * are delegated to the SDK to stay in sync with upstream. */ export function hasPendingToolCalls(messages: UIMessage[]): boolean { - if (messages.length === 0) { - return false; - } - - const lastMessage = messages[messages.length - 1]; - const parts = lastMessage.parts; - - if (parts.length === 0) { - return false; - } - - // Only auto-send if the last message is an assistant message - // Because assistant messages are the ones that can have tool calls - if (lastMessage.role !== "assistant") { + const lastMessage = messages.at(-1); + if (!lastMessage || lastMessage.role !== "assistant") { return false; } - - const toolParts = parts.filter(isToolUIPart); - - if (toolParts.length === 0) { + const lastPart = lastMessage.parts.at(-1); + if (!lastPart || !isToolUIPart(lastPart)) { return false; } - - const allToolCallsReady = toolParts.every((part) => - isToolCallReadyToSend(part.state), + return ( + lastAssistantMessageIsCompleteWithToolCalls({ messages }) || + lastAssistantMessageIsCompleteWithApprovalResponses({ messages }) ); - - // Check if the last part has any text content - const lastPart = parts[parts.length - 1]; - const hasTextContent = - lastPart.type === "text" && lastPart.text?.trim().length > 0; - - Logger.debug("All tool calls ready to send: %s", allToolCallsReady); - - return allToolCallsReady && !hasTextContent; } export function useFileState() { diff --git a/frontend/src/plugins/impl/chat/chat-ui.tsx b/frontend/src/plugins/impl/chat/chat-ui.tsx index fd719e568dd..f4fabb3b8db 100644 --- a/frontend/src/plugins/impl/chat/chat-ui.tsx +++ b/frontend/src/plugins/impl/chat/chat-ui.tsx @@ -27,7 +27,10 @@ import { import React, { useEffect, useRef, useState } from "react"; import { z } from "zod"; import { renderUIMessage } from "@/components/chat/chat-display"; -import { convertToFileUIPart } from "@/components/chat/chat-utils"; +import { + convertToFileUIPart, + hasPendingToolCalls, +} from "@/components/chat/chat-utils"; import { type AdditionalCompletions, PromptInput, @@ -186,7 +189,9 @@ export const Chatbot: React.FC = (props) => { error, regenerate, clearError, + addToolApprovalResponse, } = useChat({ + sendAutomaticallyWhen: ({ messages }) => hasPendingToolCalls(messages), transport: new DefaultChatTransport({ fetch: async ( request: RequestInfo | URL, @@ -440,6 +445,9 @@ export const Chatbot: React.FC = (props) => { message, isStreamingReasoning: status === "streaming", isLast, + addToolApprovalResponse: isLast + ? addToolApprovalResponse + : undefined, })}
diff --git a/marimo/_ai/_convert.py b/marimo/_ai/_convert.py index 3fe7b6507be..198195a62d2 100644 --- a/marimo/_ai/_convert.py +++ b/marimo/_ai/_convert.py @@ -164,8 +164,13 @@ def convert_to_openai_messages( # Reset parts since we've added the messages current_parts = [] - else: + elif dataclasses.is_dataclass(part) and not isinstance(part, type): current_parts.append(dataclasses.asdict(part)) # type: ignore + else: + LOGGER.debug( + "Dropping unsupported part %s during OpenAI conversion", + type(part).__name__, + ) if current_parts: openai_messages.append( diff --git a/marimo/_ai/_pydantic_ai_utils.py b/marimo/_ai/_pydantic_ai_utils.py index c1653ad21d4..2b37f7e0e2c 100644 --- a/marimo/_ai/_pydantic_ai_utils.py +++ b/marimo/_ai/_pydantic_ai_utils.py @@ -98,7 +98,7 @@ def safe_part_processor( or generate_id("message") ) role = message.get("role", "assistant") - parts = [_sanitize_part(part) for part in message.get("parts", [])] + parts = [sanitize_part(part) for part in message.get("parts", [])] metadata = message.get("metadata") ui_message = UIMessage( @@ -147,7 +147,7 @@ def _tool_part_allowed_fields() -> dict[tuple[bool, str], frozenset[str]]: return result -def _sanitize_part(part: Any) -> Any: +def sanitize_part(part: Any) -> Any: """Drop fields the AI SDK spread onto a tool part during a state transition. The AI SDK transitions tool parts via `{ ...part, state, approval }`, which diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index a179e07b547..56fa50c5605 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -3,7 +3,6 @@ import abc import mimetypes -from collections.abc import Iterator from dataclasses import asdict, dataclass, is_dataclass from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast @@ -11,7 +10,6 @@ from marimo import _loggers from marimo._dependencies.dependencies import DependencyManager -from marimo._utils.dicts import remove_none_values from marimo._utils.parse_dataclass import parse_raw LOGGER = _loggers.marimo_logger() @@ -169,6 +167,8 @@ class StepStartPart: if TYPE_CHECKING: from collections.abc import Iterator + from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart + ChatPart = ( TextPart | ReasoningPart @@ -190,7 +190,7 @@ class StepStartPart: ] -class ChatMessage(msgspec.Struct): +class ChatMessage(msgspec.Struct, dict=True): """ A message in a chat. """ @@ -215,44 +215,89 @@ class ChatMessage(msgspec.Struct): metadata: Any | None = None def __post_init__(self) -> None: - # Hack: msgspec only supports discriminated unions. This is a hack to just - # iterate through possible part variants and decode until one works. - if self.parts: - parts = [] - for part in self.parts: - if converted := self._convert_part(part): - parts.append(converted) - self.parts = parts + # Non-struct attribute (via `dict=True`) so it isn't serialized. + # Snapshots raw dict inputs 1:1 with `self.parts` so SDK fields the + # typed dataclasses don't model survive the round-trip. + self._raw_parts: list[dict[str, Any] | None] | None = None + if not self.parts: + return + snapshots: list[dict[str, Any] | None] = [] + typed: list[ChatPart] = [] + for part in self.parts: + converted = self._convert_part(part) + if converted is None: + continue + snapshots.append( + cast(dict[str, Any], part) if isinstance(part, dict) else None + ) + typed.append(converted) + if any(s is not None for s in snapshots): + self._raw_parts = snapshots + self.parts = typed def _convert_part(self, part: Any) -> ChatPart | None: - # If we receive a Vercel AI SDK part (through pydantic-ai), return it as is. if DependencyManager.pydantic_ai.imported(): from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart if isinstance(part, UIMessagePart): return cast(ChatPart, part) - PartType = None - for PartType in PART_TYPES: - try: - if is_dataclass(part): - return cast(ChatPart, part) - return parse_raw(part, cls=PartType, allow_unknown_keys=True) - except Exception: - continue + if is_dataclass(part) and not isinstance(part, type): + return cast(ChatPart, part) + + # Unknown dicts pass through verbatim so future SDK part types still + # round-trip. + if isinstance(part, dict): + for PartType in PART_TYPES: + try: + return parse_raw( + part, cls=PartType, allow_unknown_keys=True + ) + except Exception: + continue + return cast(ChatPart, part) - LOGGER.debug( - f"Could not decode part {part}. Ignore if it's a Vercel UI message part." - ) + LOGGER.debug("Dropping unrecognized part %r", part) return None + def raw_or_dumped_parts(self) -> list[dict[str, Any]]: + """Return parts in dict form, preferring the original wire payload.""" + ui_message_part_cls: type[UIMessagePart] | None = None + if DependencyManager.pydantic_ai.imported(): + from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart + + # `UIMessagePart` is a union type alias; the runtime value works + # with `isinstance` but doesn't match `type[...]` statically. + ui_message_part_cls = UIMessagePart # type: ignore[assignment] # pyright: ignore[reportAssignmentType] + + def dump(part: Any) -> dict[str, Any] | None: + if is_dataclass(part) and not isinstance(part, type): + return asdict(part) + if ui_message_part_cls is not None and isinstance( + part, ui_message_part_cls + ): + return part.model_dump(by_alias=True, exclude_none=True) # type: ignore[no-any-return] + if isinstance(part, dict): + return cast(dict[str, Any], part) + return None + + raws = self._raw_parts + result: list[dict[str, Any]] = [] + for i, part in enumerate(self.parts): + snap = raws[i] if raws is not None and i < len(raws) else None + if snap is not None: + result.append(snap) + elif (d := dump(part)) is not None: + result.append(d) + return result + def __iter__(self) -> Iterator[tuple[str, Any]]: """Allow dict(message) to build the serialized dict.""" out: ChatMessageDict = { "role": self.role, "id": self.id, "content": self.content, - "parts": [cast(ChatPartDict, asdict(part)) for part in self.parts], + "parts": cast(list[ChatPartDict], self.raw_or_dumped_parts()), "attachments": [ cast(ChatAttachmentDict, asdict(a)) for a in self.attachments ] @@ -278,12 +323,16 @@ def create( """ if part_validator_class: + # Lazy import: `_pydantic_ai_utils` pulls in `marimo._server.*`, + # which we don't want to load just to define the types module. + from marimo._ai._pydantic_ai_utils import sanitize_part + validated_parts = [] for part in parts: if isinstance(part, part_validator_class): validated_parts.append(part) elif isinstance(part, dict): - sanitized_part = remove_none_values(part) + sanitized_part = sanitize_part(part) # Try pydantic validation for dict -> class conversion try: from pydantic import TypeAdapter diff --git a/marimo/_ai/llm/_impl.py b/marimo/_ai/llm/_impl.py index c3980af68f9..6d838c94cfc 100644 --- a/marimo/_ai/llm/_impl.py +++ b/marimo/_ai/llm/_impl.py @@ -1,16 +1,14 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations -import dataclasses import json import os import re from typing import TYPE_CHECKING, Any, cast from marimo import _loggers -from marimo._ai._pydantic_ai_utils import generate_id +from marimo._ai._pydantic_ai_utils import generate_id, sanitize_part from marimo._plugins.ui._impl.chat.chat import AI_SDK_VERSION, DONE_CHUNK -from marimo._utils.dicts import remove_none_values if TYPE_CHECKING: from collections.abc import AsyncGenerator, Callable, Generator @@ -747,18 +745,14 @@ def _build_ui_messages( if not message.id: LOGGER.warning("Message %s has no id", message) + # Prefer the raw wire payload when we have it so fields outside + # marimo's lossy dataclasses (`approval`, `providerExecuted`, + # `preliminary`, ...) survive the round-trip into pydantic-ai. parts: list[UIMessagePart] = [] if message.parts: parts = cast( list[UIMessagePart], - [ - self._remove_none_values( - dataclasses.asdict(part) - if dataclasses.is_dataclass(part) - else part - ) - for part in message.parts - ], + [sanitize_part(p) for p in message.raw_or_dumped_parts()], ) if not parts: if message.content is not None: @@ -784,11 +778,6 @@ def _build_ui_messages( ) return ui_messages - def _remove_none_values(self, obj: dict[str, Any]) -> dict[str, Any]: - if isinstance(obj, dict) and hasattr(obj, "items"): - return remove_none_values(obj) - return obj - def _serialize_vercel_ai_chunk( self, chunk: BaseChunk ) -> dict[str, Any] | None: @@ -849,7 +838,17 @@ async def _stream_response( messages=ui_messages, ) - adapter = VercelAIAdapter(agent=self.agent, run_input=run_input) + try: + adapter = VercelAIAdapter( + agent=self.agent, + run_input=run_input, + sdk_version=AI_SDK_VERSION, + ) + except TypeError: + adapter = VercelAIAdapter( + agent=self.agent, + run_input=run_input, + ) event_stream = adapter.run_stream(model_settings=model_settings) async for event in event_stream: if serialized := self._serialize_vercel_ai_chunk(event): diff --git a/tests/_ai/llm/test_impl.py b/tests/_ai/llm/test_impl.py index 6b23b74c55a..9a84c7233d7 100644 --- a/tests/_ai/llm/test_impl.py +++ b/tests/_ai/llm/test_impl.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from unittest.mock import MagicMock, patch import pytest @@ -20,8 +20,11 @@ simple, ) from marimo._dependencies.dependencies import DependencyManager +from marimo._plugins.ui._impl.chat.chat import AI_SDK_VERSION if TYPE_CHECKING: + from collections.abc import AsyncIterator + from pydantic_ai.settings import ModelSettings @@ -1410,6 +1413,9 @@ async def mock_run_stream(**_kwargs: Any): {"id": "2", "type": "text-delta", "delta": " World"}, ] + _, kwargs = mock_adapter.call_args + assert kwargs.get("sdk_version") == AI_SDK_VERSION + async def test_stream_text(self): """Test _stream_text streams text from the model.""" mock_agent = MagicMock() @@ -1597,6 +1603,137 @@ def test_pydantic_ai_serialize_vercel_ai_chunk_error_handling( result = model._serialize_vercel_ai_chunk(cast(Any, error_chunk)) assert result is None + def test_build_ui_messages_preserves_tool_approval_field(self): + """When the frontend posts back an `approval-responded` tool part, + the `approval` payload must reach pydantic-ai intact. The lossy + `ToolInvocationPart` dataclass doesn't model `approval`, so without + the `_raw_parts` snapshot the field would be silently dropped and + the agent would loop on the same tool call forever. + + Regression for the second half of the tool-approval fix: the first + half (`sdk_version=AI_SDK_VERSION`) made the request chunk visible + to the frontend; this half makes the user's response visible to the + agent. + """ + from pydantic_ai.ui.vercel_ai.request_types import ( + ToolApprovalResponded, + ToolApprovalRespondedPart, + ) + + model = pydantic_ai(MagicMock()) + messages = [ + ChatMessage( + role="user", + content="Delete secrets.env", + id="msg-user-1", + parts=cast( # pyright: ignore[reportAny] + Any, + [{"type": "text", "text": "Delete secrets.env"}], + ), + ), + ChatMessage( + role="assistant", + content=None, + id="msg-assistant-1", + parts=cast( # pyright: ignore[reportAny] + Any, + [ + { + "type": "tool-delete_file", + "toolCallId": "call-1", + "state": "approval-responded", + "input": {"path": "secrets.env"}, + "approval": { + "id": "call-1", + "approved": True, + }, + } + ], + ), + ), + ] + + ui_messages = model._build_ui_messages(messages) + assert len(ui_messages) == 2 + + # The tool part must be reified as the *responded* variant — the + # one that carries the approval — and the approval must survive. + tool_part = ui_messages[1].parts[0] + assert isinstance(tool_part, ToolApprovalRespondedPart), ( + f"Expected ToolApprovalRespondedPart, got {type(tool_part).__name__}: " + f"{tool_part!r}" + ) + approval = tool_part.approval + assert isinstance(approval, ToolApprovalResponded), ( + f"Expected ToolApprovalResponded, got {approval!r}" + ) + assert approval.id == "call-1" + assert approval.approved is True + + async def test_stream_response_emits_tool_approval_request(self): + """Tools with `requires_approval=True` should surface an + approval-request chunk so the frontend can render an Approve/Deny + card. This is the v6-only behavior unlocked by passing + `sdk_version=AI_SDK_VERSION` to the adapter. + """ + from pydantic_ai import Agent, DeferredToolRequests + from pydantic_ai.models.function import ( + AgentInfo, + DeltaToolCall, + DeltaToolCalls, + FunctionModel, + ) + + async def respond( + messages: list[Any], _info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls]: + del messages, _info + yield { + 0: DeltaToolCall( + name="delete_file", + json_args='{"path": "secrets.env"}', + tool_call_id="call-1", + ) + } + + agent = Agent( + FunctionModel(stream_function=respond), + output_type=[str, DeferredToolRequests], + ) + + @agent.tool_plain(requires_approval=True) + def delete_file(path: str) -> str: + return f"File {path!r} deleted" + + model = pydantic_ai(agent) + messages = [ + ChatMessage( + role="user", + content="Delete secrets.env", + id="msg-user-1", + parts=[TextPart(type="text", text="Delete secrets.env")], + ), + ] + config = ChatModelConfig(max_tokens=100) + + chunks = [ + chunk async for chunk in model._stream_response(messages, config) + ] + + approval_chunks = [ + chunk + for chunk in chunks + if chunk.get("type") == "tool-approval-request" + ] + # Older pydantic-ai generates a UUID approvalId; newer versions reuse + # toolCallId. Either is fine — assert the shape, not the exact value. + assert len(approval_chunks) == 1 + chunk = approval_chunks[0] + assert chunk["type"] == "tool-approval-request" + assert chunk["toolCallId"] == "call-1" + assert isinstance(chunk["approvalId"], str) + assert chunk["approvalId"] + class MockBaseChunkWithError: """Mock BaseChunk that raises on serialization.""" diff --git a/tests/_ai/test_ai_types.py b/tests/_ai/test_ai_types.py index 2f262bd5b2c..4724750610e 100644 --- a/tests/_ai/test_ai_types.py +++ b/tests/_ai/test_ai_types.py @@ -473,7 +473,11 @@ def test_keeps_already_typed_parts(self): assert message.parts[0] is text_part def test_handles_invalid_parts_gracefully(self): - """Test that invalid parts are dropped gracefully.""" + """Unknown dict parts don't crash construction; they're preserved + verbatim so future SDK part types still round-trip on serialization. + See `test_raw_part_round_trips_verbatim` for the contract. + """ + raw_unknown = {"type": "unknown_type", "data": "invalid"} message = ChatMessage( role="user", content="Hello", @@ -481,17 +485,19 @@ def test_handles_invalid_parts_gracefully(self): Any, [ {"type": "text", "text": "Valid"}, - {"type": "unknown_type", "data": "invalid"}, + raw_unknown, ], ), ) - # Valid part should be kept, invalid should be dropped - assert message == ChatMessage( - role="user", - content="Hello", - parts=[TextPart(type="text", text="Valid")], - ) + assert len(message.parts) == 2 + assert isinstance(message.parts[0], TextPart) + assert message.parts[0].text == "Valid" + assert message.parts[1] == raw_unknown + assert message.raw_or_dumped_parts() == [ + {"type": "text", "text": "Valid"}, + raw_unknown, + ] def test_with_none_parts(self): """Test that None parts is handled.""" @@ -546,3 +552,129 @@ def test_parts_and_attachments_serialized_to_dict(self): ], "metadata": None, } + + +class TestChatMessageRawPartsRoundTrip: + """`ChatMessage` must snapshot raw wire payloads so we don't drop AI + SDK fields the typed dataclasses don't model (e.g. `approval`, + `callProviderMetadata`). Without this the pydantic-ai bridge would + lose context on every deferred tool run. + """ + + def test_typed_input_does_not_snapshot(self): + """Typed parts are themselves the source of truth — no snapshot.""" + message = ChatMessage( + role="user", + content="hi", + id="msg", + parts=[TextPart(type="text", text="hi")], + ) + assert message._raw_parts is None + assert dict[str, Any](message)["parts"] == [ + {"type": "text", "text": "hi"} + ] + + def test_equality_ignores_raw_parts(self): + """Equality compares value, not cache state.""" + from_dict = ChatMessage( + role="user", + content="hi", + id="msg", + parts=[cast(ChatPart, {"type": "text", "text": "hi"})], + ) + from_typed = ChatMessage( + role="user", + content="hi", + id="msg", + parts=[TextPart(type="text", text="hi")], + ) + assert from_dict._raw_parts is not None + assert from_typed._raw_parts is None + assert from_dict == from_typed + + @pytest.mark.parametrize( + "raw", + [ + # Tool part carrying fields marimo's dataclass doesn't model — + # this is the case that motivated `_raw_parts` in the first place. + { + "type": "tool-delete_file", + "toolCallId": "call-1", + "state": "approval-responded", + "input": {"path": "secrets.env"}, + "approval": {"id": "call-1", "approved": True}, + "callProviderMetadata": {"openai": {"foo": "bar"}}, + }, + # Alternate tool shape. + { + "type": "dynamic-tool", + "toolName": "delete_file", + "toolCallId": "call-1", + "state": "input-streaming", + }, + # Non-tool part with its own bag of optional fields. + { + "type": "file", + "mediaType": "image/png", + "url": "data:image/png;base64,abc", + }, + # `data-*` parts are user-defined; the type string itself + # carries information and must survive the round-trip. + { + "type": "data-reasoning-signature", + "data": {"signature": "x"}, + }, + ], + ) + def test_raw_part_round_trips_verbatim(self, raw: dict[str, Any]): + """`dict(ChatMessage(parts=[raw]))["parts"]` must equal `[raw]`.""" + message = ChatMessage( + role="assistant", + content=None, + id="msg", + parts=[cast(ChatPart, raw)], + ) + assert dict[str, Any](message)["parts"] == [raw] + assert message._raw_parts == [raw] + + def test_mixed_dict_and_typed_parts_preserve_raw_dicts(self): + """When `parts` mixes raw dicts and typed inputs, the raw dicts' + unmodeled fields (e.g. `approval`) must survive — not just the + all-dicts case. + """ + raw = { + "type": "tool-delete_file", + "toolCallId": "call-1", + "state": "approval-responded", + "input": {"path": "secrets.env"}, + "approval": {"id": "call-1", "approved": True}, + } + message = ChatMessage( + role="assistant", + content=None, + id="msg", + parts=[ + cast(ChatPart, raw), + TextPart(type="text", text="ok"), + ], + ) + assert message.raw_or_dumped_parts() == [ + raw, + {"type": "text", "text": "ok"}, + ] + + def test_dict_part_appended_after_init_is_preserved(self): + message = ChatMessage( + role="assistant", + content=None, + id="msg", + parts=[TextPart(type="text", text="hi")], + ) + assert message._raw_parts is None + + extra = {"type": "data-custom", "data": {"k": "v"}} + message.parts.append(cast(ChatPart, extra)) + + dumped = message.raw_or_dumped_parts() + assert dumped == [{"type": "text", "text": "hi"}, extra] + assert dict[str, Any](message)["parts"] == dumped diff --git a/tests/_ai/test_pydantic_utils.py b/tests/_ai/test_pydantic_utils.py index 91a36600f9d..e3b79441611 100644 --- a/tests/_ai/test_pydantic_utils.py +++ b/tests/_ai/test_pydantic_utils.py @@ -9,11 +9,11 @@ pytest.importorskip("pydantic_ai", reason="pydantic_ai not installed") from marimo._ai._pydantic_ai_utils import ( - _sanitize_part, convert_to_pydantic_messages, create_simple_prompt, form_toolsets, generate_id, + sanitize_part, ) from marimo._server.ai.tools.types import ToolDefinition @@ -521,7 +521,7 @@ def test_strips_stale_fields_on_approval_responded(self): "output": "Invalid arguments for tool request_user_blessing", "errorText": "boom", } - clean = _sanitize_part(stale) + clean = sanitize_part(stale) assert "output" not in clean assert "errorText" not in clean assert clean["approval"] == {"id": "call_abc", "approved": True} @@ -538,7 +538,7 @@ def test_preserves_real_fields_on_output_available(self): "output": {"ok": True}, "preliminary": False, } - clean = _sanitize_part(part) + clean = sanitize_part(part) assert clean["output"] == {"ok": True} assert clean["preliminary"] is False @@ -551,7 +551,7 @@ def test_preserves_real_fields_on_output_error(self): "rawInput": {"x": 1}, "errorText": "boom", } - clean = _sanitize_part(part) + clean = sanitize_part(part) assert clean["errorText"] == "boom" assert clean["rawInput"] == {"x": 1} @@ -565,7 +565,7 @@ def test_dynamic_tool_preserves_tool_name(self): "approval": {"id": "c1", "approved": True}, "output": "stale", } - clean = _sanitize_part(stale) + clean = sanitize_part(stale) assert "output" not in clean assert clean["toolName"] == "my_tool" @@ -584,7 +584,7 @@ def test_dynamic_tool_preserves_tool_name(self): ], ) def test_non_tool_parts_pass_through_unchanged(self, part): - assert _sanitize_part(part) is part + assert sanitize_part(part) is part @pytest.mark.parametrize( "part", @@ -606,7 +606,7 @@ def test_non_tool_parts_pass_through_unchanged(self, part): ], ) def test_malformed_parts_pass_through_unchanged(self, part): - assert _sanitize_part(part) == part + assert sanitize_part(part) == part class TestConvertToPydanticMessagesSanitizes: