From c3aa6d33bb376db66e806be5d023866ed37272d9 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Wed, 20 May 2026 15:52:24 +0800 Subject: [PATCH 01/10] support tool approvals --- examples/ai/chat/pydantic-ai-chat.py | 315 +++++++++++++++--- .../chat/__tests__/chat-utils.test.ts | 269 +++++++++++++++ frontend/src/components/chat/chat-utils.ts | 72 +--- frontend/src/plugins/impl/chat/chat-ui.tsx | 8 +- marimo/_ai/_pydantic_ai_utils.py | 4 +- marimo/_ai/_types.py | 76 ++++- marimo/_ai/llm/_impl.py | 27 +- tests/_ai/llm/test_impl.py | 138 +++++++- tests/_ai/test_ai_types.py | 84 +++++ tests/_ai/test_pydantic_utils.py | 14 +- 10 files changed, 872 insertions(+), 135 deletions(-) create mode 100644 frontend/src/components/chat/__tests__/chat-utils.test.ts diff --git a/examples/ai/chat/pydantic-ai-chat.py b/examples/ai/chat/pydantic-ai-chat.py index b57b615faad..b3d740e7066 100644 --- a/examples/ai/chat/pydantic-ai-chat.py +++ b/examples/ai/chat/pydantic-ai-chat.py @@ -6,9 +6,10 @@ # "pydantic==2.12.5", # ] # /// + import marimo -__generated_with = "0.21.1" +__generated_with = "0.23.6" app = marimo.App(width="medium") with app.setup(hide_code=True): @@ -16,7 +17,12 @@ import os import httpx - from pydantic_ai import Agent, RunContext, BinaryImage + from pydantic_ai import ( + Agent, + BinaryImage, + DeferredToolRequests, + RunContext, + ) from pydantic_ai.models.google import GoogleModel, GoogleModelSettings from pydantic_ai.providers.google import GoogleProvider from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings @@ -52,6 +58,7 @@ def _(): structured = mo.ui.checkbox(label="Structured outputs") thinking = mo.ui.checkbox(label="Reasoning") fetch_dog_tool = mo.ui.checkbox(label="Fetch dog pics tool") + delete_file_tool = mo.ui.checkbox(label="Delete file tool (requires approval)") models = mo.ui.dropdown( options={ @@ -64,8 +71,8 @@ def _(): label="Choose a model", ) - mo.vstack([models, structured, thinking, fetch_dog_tool]) - return fetch_dog_tool, models, structured, thinking + mo.vstack([models, structured, thinking, fetch_dog_tool, delete_file_tool]) + return delete_file_tool, fetch_dog_tool, models, structured, thinking @app.cell(hide_code=True) @@ -129,6 +136,7 @@ def get_model( @app.cell(hide_code=True) def _( + delete_file_tool, fetch_dog_tool, input_key, model_name, @@ -153,6 +161,15 @@ class CodeOutput(BaseModel): elif structured.value: output_type = [CodeOutput, str] + # Tools that pause for human approval require `DeferredToolRequests` + # in the output type; pydantic-ai returns it whenever a tool flagged + # `requires_approval=True` is called. + if delete_file_tool.value: + if isinstance(output_type, list): + output_type = [*output_type, DeferredToolRequests] + else: + output_type = [output_type, DeferredToolRequests] + agent = Agent( model, output_type=output_type, @@ -172,6 +189,14 @@ def fetch_dog_picture_url(ctx: RunContext[str]) -> str: return response_json["message"] else: return "Error fetching dog URL" + + + if delete_file_tool.value: + + @agent.tool_plain(requires_approval=True) + def delete_file(path: str) -> str: + """Pretend to delete the file at `path`.""" + return f"File {path!r} deleted" return (agent,) @@ -184,6 +209,7 @@ def _(agent): "Who is Ada Lovelace?", "What is marimo?", "I need dogs (render as markdown)", + "Delete the file at path 'secrets.env'", ], allow_attachments=True, show_configuration_controls=True, @@ -198,52 +224,115 @@ def _(chatbot): return -@app.cell +@app.cell(hide_code=True) def _(): - mo.md(""" - ## Custom Model Sample + mo.md(r""" + ## Custom model sample + + `mo.ui.chat` accepts any async generator that yields Vercel AI SDK chunks. + The model below is a hand-rolled showcase of every part the SDK knows + about — reasoning, streamed tool input, file/source/data attachments, + a deliberately failed tool, and a final tool that pauses for human + approval. """) return -@app.cell +@app.cell(hide_code=True) def _(): + import asyncio import uuid + import pydantic_ai.ui.vercel_ai.response_types as vercel + def _new_id(prefix: str) -> str: + return f"{prefix}_{uuid.uuid4().hex[:8]}" + + + def _pending_approval(messages) -> dict | None: + """Find a tool part the user just approved or denied, if any. + + After Approve/Deny, the SDK transitions the tool part on the last + assistant message to `approval-responded` and auto-resumes. We + look for that state on the most recent assistant turn so we know + whether to start a fresh showcase or finish the deletion. + """ + for message in reversed(messages): + if message.role != "assistant": + continue + for part in message.raw_or_dumped_parts(): + if not isinstance(part, dict): + continue + if not str(part.get("type", "")).startswith("tool-"): + continue + if part.get("state") == "approval-responded": + return part + return None + return None + + async def custom_model(messages, config): - # Generate unique IDs for message parts - reasoning_id = f"reasoning_{uuid.uuid4().hex}" - text_id = f"text_{uuid.uuid4().hex}" - tool_id = f"tool_{uuid.uuid4().hex}" + del config + + pending = _pending_approval(messages) + if pending is not None: + async for chunk in _resume_after_approval(pending): + yield chunk + return + + async for chunk in _showcase_turn(): + yield chunk + + + async def _showcase_turn(): + reasoning_id = _new_id("reasoning") + search_id = _new_id("tc") + translate_id = _new_id("tc") + delete_id = _new_id("tc") + approval_id = _new_id("ap") + intro_id = _new_id("text") + followup_id = _new_id("text") + error_text_id = _new_id("text") + ask_id = _new_id("text") + data_id = _new_id("data") + + # Message-level metadata round-trips on `message.metadata` in the UI. + yield vercel.MessageMetadataChunk( + message_metadata={"demo": "vercel-ai-sdk-showcase", "turn": 1} + ) - # --- Stream reasoning/thinking --- + # ── Step 1: think + run a tool that succeeds ────────────────── yield vercel.StartStepChunk() + yield vercel.ReasoningStartChunk(id=reasoning_id) - yield vercel.ReasoningDeltaChunk( - id=reasoning_id, - delta="The user is asking about Van Gogh. I should fetch information about his famous works.", - ) + for chunk in [ + "The user wants the full tour. ", + "I'll search for a famous painting, ", + "compose an answer with citations and an image, ", + "demonstrate an erroring tool, ", + "and finally offer to clean up a temp file ", + "behind a human-approval gate.", + ]: + yield vercel.ReasoningDeltaChunk(id=reasoning_id, delta=chunk) + await asyncio.sleep(0.04) yield vercel.ReasoningEndChunk(id=reasoning_id) - # --- Stream tool call to fetch artwork information --- + yield vercel.ToolInputStartChunk( + tool_call_id=search_id, tool_name="search_artwork" + ) + for delta in ['{"artist":', ' "Vincent van Gogh",', ' "limit": 1}']: + yield vercel.ToolInputDeltaChunk( + tool_call_id=search_id, input_text_delta=delta + ) + await asyncio.sleep(0.04) yield vercel.ToolInputAvailableChunk( - tool_call_id=tool_id, + tool_call_id=search_id, tool_name="search_artwork", input={"artist": "Vincent van Gogh", "limit": 1}, ) - yield vercel.ToolInputStartChunk( - tool_call_id=tool_id, tool_name="search_artwork" - ) - yield vercel.ToolInputDeltaChunk( - tool_call_id=tool_id, - input_text_delta='{"artist": "Vincent van Gogh", "limit": 1}', - ) - - # --- Tool output (simulated artwork search result) --- yield vercel.ToolOutputAvailableChunk( - tool_call_id=tool_id, + tool_call_id=search_id, output={ "title": "The Starry Night", "year": 1889, @@ -251,28 +340,170 @@ async def custom_model(messages, config): }, ) - # --- Stream text response --- - yield vercel.TextStartChunk(id=text_id) - yield vercel.TextDeltaChunk( - id=text_id, - delta="One of Vincent van Gogh's most iconic works is 'The Starry Night', painted in 1889. Here's the painting:\n\n", - ) + yield vercel.FinishStepChunk() + + # ── Step 2: compose the answer with rich media ──────────────── + yield vercel.StartStepChunk() + + yield vercel.TextStartChunk(id=intro_id) + for delta in [ + "One of Vincent van Gogh's most iconic works is ", + "**The Starry Night**, painted in 1889. ", + "Here is the painting:", + ]: + yield vercel.TextDeltaChunk(id=intro_id, delta=delta) + await asyncio.sleep(0.04) + yield vercel.TextEndChunk(id=intro_id) - # --- Embed the artwork image --- yield vercel.FileChunk( - url="https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg", + url=( + "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/" + "Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/" + "1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg" + ), media_type="image/jpeg", ) + + yield vercel.SourceUrlChunk( + source_id=_new_id("src"), + url="https://www.moma.org/collection/works/79802", + title="The Starry Night | MoMA", + ) + yield vercel.SourceDocumentChunk( + source_id=_new_id("src"), + media_type="application/pdf", + title="Faille catalogue raisonné, vol. III", + filename="van-gogh-catalogue.pdf", + ) + + # Custom data-* parts let backends ship arbitrary structured + # payloads to bespoke UI widgets without bending the text channel. + yield vercel.DataChunk( + id=data_id, + type="data-artwork-card", + data={ + "title": "The Starry Night", + "year": 1889, + "movement": "Post-Impressionism", + }, + ) + + yield vercel.TextStartChunk(id=followup_id) + yield vercel.TextDeltaChunk( + id=followup_id, + delta=( + "\n\nNext I'll try a translation tool that's expected to" + " fail — handy for seeing how errors render." + ), + ) + yield vercel.TextEndChunk(id=followup_id) + + yield vercel.FinishStepChunk() + + # ── Step 3: a tool whose execution fails ────────────────────── + yield vercel.StartStepChunk() + + yield vercel.ToolInputStartChunk( + tool_call_id=translate_id, tool_name="translate" + ) + yield vercel.ToolInputAvailableChunk( + tool_call_id=translate_id, + tool_name="translate", + input={"text": "Sterrennacht", "from": "nl", "to": "klingon"}, + ) + yield vercel.ToolOutputErrorChunk( + tool_call_id=translate_id, + error_text="UnsupportedLanguage: 'klingon' is not a supported target.", + ) + + yield vercel.TextStartChunk(id=error_text_id) yield vercel.TextDeltaChunk( - id=text_id, - delta="\nThis masterpiece is now housed at the Museum of Modern Art in New York and remains one of the most recognizable paintings in the world.", + id=error_text_id, + delta="That call failed, as expected — moving on.", ) - yield vercel.TextEndChunk(id=text_id) + yield vercel.TextEndChunk(id=error_text_id) + yield vercel.FinishStepChunk() - yield vercel.FinishChunk() + # ── Step 4: ask for approval, then stop ─────────────────────── + yield vercel.StartStepChunk() - custom_chat = mo.ui.chat(custom_model) + yield vercel.TextStartChunk(id=ask_id) + yield vercel.TextDeltaChunk( + id=ask_id, + delta=( + "I'd like to delete the search cache file. " + "Approve below to proceed, or deny to keep it." + ), + ) + yield vercel.TextEndChunk(id=ask_id) + + yield vercel.ToolInputStartChunk( + tool_call_id=delete_id, tool_name="delete_file" + ) + yield vercel.ToolInputAvailableChunk( + tool_call_id=delete_id, + tool_name="delete_file", + input={"path": "/tmp/van-gogh-search.cache"}, + ) + yield vercel.ToolApprovalRequestChunk( + approval_id=approval_id, tool_call_id=delete_id + ) + + yield vercel.FinishStepChunk() + yield vercel.FinishChunk(finish_reason="tool-calls") + + + async def _resume_after_approval(pending: dict): + tool_call_id = pending["toolCallId"] + approval = pending.get("approval") or {} + approved = bool(approval.get("approved")) + path = (pending.get("input") or {}).get("path", "") + + text_id = _new_id("text") + + yield vercel.MessageMetadataChunk( + message_metadata={ + "demo": "vercel-ai-sdk-showcase", + "turn": 2, + "approval": approval, + } + ) + yield vercel.StartStepChunk() + + if approved: + yield vercel.ToolOutputAvailableChunk( + tool_call_id=tool_call_id, + output={"deleted": True, "path": path}, + ) + yield vercel.TextStartChunk(id=text_id) + yield vercel.TextDeltaChunk( + id=text_id, delta=f"Done — `{path}` has been removed." + ) + yield vercel.TextEndChunk(id=text_id) + else: + yield vercel.ToolOutputDeniedChunk(tool_call_id=tool_call_id) + yield vercel.TextStartChunk(id=text_id) + yield vercel.TextDeltaChunk( + id=text_id, + delta=( + f"No problem — I'll leave `{path}` alone. " + f"Reason: {approval.get('reason') or 'no reason given'}." + ), + ) + yield vercel.TextEndChunk(id=text_id) + + yield vercel.FinishStepChunk() + yield vercel.FinishChunk(finish_reason="stop") + + + custom_chat = mo.ui.chat( + custom_model, + prompts=[ + "Run the full Vercel AI SDK part showcase", + "Show me reasoning, citations, and an approval-gated tool", + ], + ) custom_chat return (custom_chat,) diff --git a/frontend/src/components/chat/__tests__/chat-utils.test.ts b/frontend/src/components/chat/__tests__/chat-utils.test.ts new file mode 100644 index 00000000000..6b4ad171582 --- /dev/null +++ b/frontend/src/components/chat/__tests__/chat-utils.test.ts @@ -0,0 +1,269 @@ +/* Copyright 2026 Marimo. All rights reserved. */ + +import type { UIMessage } from "ai"; +import { describe, expect, it } from "vitest"; +import { hasPendingToolCalls } from "../chat-utils"; + +/** + * `hasPendingToolCalls` powers `sendAutomaticallyWhen` in `mo.ui.chat`: + * returns true only when the last assistant message *ends* with a tool + * call in a ready-to-round-trip state. Any trailing non-tool part (text, + * file, source-*, reasoning, data-*, new step-start) means the assistant + * has already answered and we leave the next turn to the user. The + * approval flow relies on this firing for `approval-responded`. + */ + +const userMessage = (text: string): UIMessage => ({ + id: `user-${text}`, + role: "user", + parts: [{ type: "text", text }], +}); + +const assistantToolMessage = ( + parts: UIMessage["parts"], + id = "assistant-1", +): UIMessage => ({ + id, + role: "assistant", + parts, +}); + +describe("hasPendingToolCalls", () => { + it("returns false when there are no messages", () => { + expect(hasPendingToolCalls([])).toBe(false); + }); + + it("returns false when the last message is a user message", () => { + expect(hasPendingToolCalls([userMessage("hi")])).toBe(false); + }); + + it("returns false when the last assistant message has no tool parts", () => { + expect( + hasPendingToolCalls([ + userMessage("hi"), + assistantToolMessage([{ type: "text", text: "hello!" }]), + ]), + ).toBe(false); + }); + + it("returns false while a tool is still streaming or awaiting approval", () => { + expect( + hasPendingToolCalls([ + userMessage("delete it"), + assistantToolMessage([ + { + type: "tool-delete_file", + toolCallId: "call-1", + state: "approval-requested", + input: { path: "secrets.env" }, + approval: { id: "approval-1" }, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns true when the user has responded to an approval request", () => { + // The chat must auto-resume as soon as Approve/Deny is clicked. + expect( + hasPendingToolCalls([ + userMessage("delete it"), + assistantToolMessage([ + { + type: "tool-delete_file", + toolCallId: "call-1", + state: "approval-responded", + input: { path: "secrets.env" }, + approval: { id: "approval-1", approved: true }, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(true); + }); + + it("returns true when a tool reached a terminal output state", () => { + expect( + hasPendingToolCalls([ + userMessage("run it"), + assistantToolMessage([ + { + type: "tool-run_query", + toolCallId: "call-1", + state: "output-available", + input: { sql: "select 1" }, + output: 1, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(true); + }); + + it("returns false when only some tool calls are ready", () => { + expect( + hasPendingToolCalls([ + userMessage("two things"), + assistantToolMessage([ + { + type: "tool-first", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + { + type: "tool-second", + toolCallId: "call-2", + state: "input-available", + input: {}, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns false once the assistant has appended text after the tool result", () => { + expect( + hasPendingToolCalls([ + userMessage("run it"), + assistantToolMessage([ + { + type: "tool-run_query", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + { type: "text", text: "The query returned 1." }, + ]), + ]), + ).toBe(false); + }); + + it("returns false when a file part trails the completed tool call", () => { + // Regression: tool → text → file used to loop because only trailing + // text counted as "the assistant has answered". + expect( + hasPendingToolCalls([ + userMessage("show me Starry Night"), + assistantToolMessage([ + { type: "step-start" }, + { + type: "tool-search_artwork", + toolCallId: "call-1", + state: "output-available", + input: { artist: "Van Gogh" }, + output: { title: "The Starry Night" }, + } as unknown as UIMessage["parts"][number], + { type: "text", text: "Here is the painting:" }, + { + type: "file", + mediaType: "image/jpeg", + url: "https://example.com/starry-night.jpg", + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns false when a source-url part trails the completed tool call", () => { + expect( + hasPendingToolCalls([ + userMessage("cite your sources"), + assistantToolMessage([ + { + type: "tool-web_search", + toolCallId: "call-1", + state: "output-available", + input: { q: "marimo notebook" }, + output: "found", + } as unknown as UIMessage["parts"][number], + { type: "text", text: "marimo is a reactive notebook." }, + { + type: "source-url", + sourceId: "src-1", + url: "https://marimo.io", + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns false when a reasoning part trails the completed tool call", () => { + expect( + hasPendingToolCalls([ + userMessage("explain"), + assistantToolMessage([ + { + type: "tool-lookup", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + { + type: "reasoning", + text: "Now I'll summarize.", + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns false when a new step-start follows the completed tool call", () => { + expect( + hasPendingToolCalls([ + userMessage("multi-step"), + assistantToolMessage([ + { type: "step-start" }, + { + type: "tool-run_query", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + { type: "step-start" }, + ]), + ]), + ).toBe(false); + }); + + it("ignores providerExecuted tools", () => { + // Provider-side tools are resolved by the model, not the runtime, so + // they must not drive an auto-resume. + expect( + hasPendingToolCalls([ + userMessage("hi"), + assistantToolMessage([ + { + type: "tool-web_search", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + providerExecuted: true, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(false); + }); + + it("returns true for dynamic-tool parts in a terminal state", () => { + // `dynamic-tool` parts must drive auto-resume alongside `tool-*`. + expect( + hasPendingToolCalls([ + userMessage("run it"), + assistantToolMessage([ + { + type: "dynamic-tool", + toolName: "run_query", + toolCallId: "call-1", + state: "output-available", + input: {}, + output: 1, + } as unknown as UIMessage["parts"][number], + ]), + ]), + ).toBe(true); + }); +}); diff --git a/frontend/src/components/chat/chat-utils.ts b/frontend/src/components/chat/chat-utils.ts index fbb3552f30d..c09195345af 100644 --- a/frontend/src/components/chat/chat-utils.ts +++ b/frontend/src/components/chat/chat-utils.ts @@ -5,7 +5,8 @@ import { type ChatAddToolOutputFunction, type FileUIPart, isToolUIPart, - type ToolUIPart, + lastAssistantMessageIsCompleteWithApprovalResponses, + lastAssistantMessageIsCompleteWithToolCalls, type UIMessage, } from "ai"; import { useState } from "react"; @@ -17,7 +18,6 @@ import type { InvokeAiToolRequest, InvokeAiToolResponse, } from "@/core/network/types"; -import { logNever } from "@/utils/assertNever"; import { blobToString } from "@/utils/fileToBase64"; import { Logger } from "@/utils/Logger"; import { getAICompletionBodyWithAttachments } from "../editor/ai/completion-utils"; @@ -169,69 +169,25 @@ export async function handleToolCall({ } /** - * Returns true if a tool call is "ready to be sent back to the server" — i.e. - * either it has reached a terminal output state, or the user has just supplied - * an approval response that the server hasn't seen yet. - */ -function isToolCallReadyToSend(state: ToolUIPart["state"]): boolean { - switch (state) { - case "output-available": - case "output-error": - case "output-denied": - case "approval-responded": - return true; - case "input-streaming": - case "input-available": - case "approval-requested": - return false; - default: - logNever(state); - return false; - } -} - -/** - * Checks if we should send a message automatically based on the messages. - * We auto-send when every tool call on the last assistant message has either - * finished (output-available/error/denied) or has just received a user - * approval response, and the assistant hasn't replied yet. + * Auto-send the next turn when the last assistant message ends with a + * tool call ready to round-trip. Any non-tool trailing part (text, file, + * source-*, reasoning, data-*, new step-start) means the assistant has + * already answered, so we leave the next turn to the user. State checks + * are delegated to the SDK to stay in sync with upstream. */ export function hasPendingToolCalls(messages: UIMessage[]): boolean { - if (messages.length === 0) { - return false; - } - - const lastMessage = messages[messages.length - 1]; - const parts = lastMessage.parts; - - if (parts.length === 0) { - return false; - } - - // Only auto-send if the last message is an assistant message - // Because assistant messages are the ones that can have tool calls - if (lastMessage.role !== "assistant") { + const lastMessage = messages.at(-1); + if (!lastMessage || lastMessage.role !== "assistant") { return false; } - - const toolParts = parts.filter(isToolUIPart); - - if (toolParts.length === 0) { + const lastPart = lastMessage.parts.at(-1); + if (!lastPart || !isToolUIPart(lastPart)) { return false; } - - const allToolCallsReady = toolParts.every((part) => - isToolCallReadyToSend(part.state), + return ( + lastAssistantMessageIsCompleteWithToolCalls({ messages }) || + lastAssistantMessageIsCompleteWithApprovalResponses({ messages }) ); - - // Check if the last part has any text content - const lastPart = parts[parts.length - 1]; - const hasTextContent = - lastPart.type === "text" && lastPart.text?.trim().length > 0; - - Logger.debug("All tool calls ready to send: %s", allToolCallsReady); - - return allToolCallsReady && !hasTextContent; } export function useFileState() { diff --git a/frontend/src/plugins/impl/chat/chat-ui.tsx b/frontend/src/plugins/impl/chat/chat-ui.tsx index fd719e568dd..a10d9bba1fd 100644 --- a/frontend/src/plugins/impl/chat/chat-ui.tsx +++ b/frontend/src/plugins/impl/chat/chat-ui.tsx @@ -27,7 +27,10 @@ import { import React, { useEffect, useRef, useState } from "react"; import { z } from "zod"; import { renderUIMessage } from "@/components/chat/chat-display"; -import { convertToFileUIPart } from "@/components/chat/chat-utils"; +import { + convertToFileUIPart, + hasPendingToolCalls, +} from "@/components/chat/chat-utils"; import { type AdditionalCompletions, PromptInput, @@ -186,7 +189,9 @@ export const Chatbot: React.FC = (props) => { error, regenerate, clearError, + addToolApprovalResponse, } = useChat({ + sendAutomaticallyWhen: ({ messages }) => hasPendingToolCalls(messages), transport: new DefaultChatTransport({ fetch: async ( request: RequestInfo | URL, @@ -440,6 +445,7 @@ export const Chatbot: React.FC = (props) => { message, isStreamingReasoning: status === "streaming", isLast, + addToolApprovalResponse, })}
diff --git a/marimo/_ai/_pydantic_ai_utils.py b/marimo/_ai/_pydantic_ai_utils.py index c1653ad21d4..2b37f7e0e2c 100644 --- a/marimo/_ai/_pydantic_ai_utils.py +++ b/marimo/_ai/_pydantic_ai_utils.py @@ -98,7 +98,7 @@ def safe_part_processor( or generate_id("message") ) role = message.get("role", "assistant") - parts = [_sanitize_part(part) for part in message.get("parts", [])] + parts = [sanitize_part(part) for part in message.get("parts", [])] metadata = message.get("metadata") ui_message = UIMessage( @@ -147,7 +147,7 @@ def _tool_part_allowed_fields() -> dict[tuple[bool, str], frozenset[str]]: return result -def _sanitize_part(part: Any) -> Any: +def sanitize_part(part: Any) -> Any: """Drop fields the AI SDK spread onto a tool part during a state transition. The AI SDK transitions tool parts via `{ ...part, state, approval }`, which diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index a179e07b547..8042718415d 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -5,13 +5,20 @@ import mimetypes from collections.abc import Iterator from dataclasses import asdict, dataclass, is_dataclass -from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + TypedDict, + cast, + override, +) import msgspec from marimo import _loggers from marimo._dependencies.dependencies import DependencyManager -from marimo._utils.dicts import remove_none_values from marimo._utils.parse_dataclass import parse_raw LOGGER = _loggers.marimo_logger() @@ -190,7 +197,7 @@ class StepStartPart: ] -class ChatMessage(msgspec.Struct): +class ChatMessage(msgspec.Struct, eq=False): """ A message in a chat. """ @@ -214,16 +221,42 @@ class ChatMessage(msgspec.Struct): metadata: Any | None = None + # High-fidelity snapshot of `parts` as they arrive over the wire. + _raw_parts: list[dict[str, Any]] | None = None + def __post_init__(self) -> None: - # Hack: msgspec only supports discriminated unions. This is a hack to just - # iterate through possible part variants and decode until one works. if self.parts: + # Snapshot the wire-format dicts before lossy conversion. We only + # snapshot when every part is a dict so we don't store a partial + # view that mixes typed and raw entries. + if self._raw_parts is None and all( + isinstance(p, dict) for p in self.parts + ): + self._raw_parts = [cast(dict[str, Any], p) for p in self.parts] + # Hack: msgspec only supports discriminated unions. This is a hack to just + # iterate through possible part variants and decode until one works. parts = [] for part in self.parts: if converted := self._convert_part(part): parts.append(converted) self.parts = parts + # Fields excluded from equality. Add anything here that is a cache / + # representation detail rather than part of the message's identity, so + # `__eq__` keeps comparing every "real" field automatically as the + # struct grows. + _EQ_EXCLUDE: ClassVar[frozenset[str]] = frozenset({"_raw_parts"}) + + @override + def __eq__(self, other: object) -> bool: + if not isinstance(other, ChatMessage): + return NotImplemented + return all( + getattr(self, name) == getattr(other, name) + for name in self.__struct_fields__ + if name not in self._EQ_EXCLUDE + ) + def _convert_part(self, part: Any) -> ChatPart | None: # If we receive a Vercel AI SDK part (through pydantic-ai), return it as is. if DependencyManager.pydantic_ai.imported(): @@ -246,13 +279,38 @@ def _convert_part(self, part: Any) -> ChatPart | None: ) return None + def raw_or_dumped_parts(self) -> list[dict[str, Any]]: + """Return parts in dict form, preferring the original wire payload.""" + if self._raw_parts is not None: + return self._raw_parts + + result: list[dict[str, Any]] = [] + for part in self.parts: + if is_dataclass(part): + result.append(asdict(part)) + elif DependencyManager.pydantic_ai.imported(): + from pydantic_ai.ui.vercel_ai.request_types import ( + UIMessagePart, + ) + + if isinstance(part, UIMessagePart): + result.append( + part.model_dump(by_alias=True, exclude_none=True) + ) + elif isinstance(part, dict): + # Defensive: at runtime `parts` may carry raw dicts because + # `ChatPart` is `dict[str, Any]` at runtime even though the + # type-checking alias is a union of dataclasses. + result.append(part) + return result + def __iter__(self) -> Iterator[tuple[str, Any]]: """Allow dict(message) to build the serialized dict.""" out: ChatMessageDict = { "role": self.role, "id": self.id, "content": self.content, - "parts": [cast(ChatPartDict, asdict(part)) for part in self.parts], + "parts": cast(list[ChatPartDict], self.raw_or_dumped_parts()), "attachments": [ cast(ChatAttachmentDict, asdict(a)) for a in self.attachments ] @@ -278,12 +336,16 @@ def create( """ if part_validator_class: + # Lazy import: `_pydantic_ai_utils` pulls in `marimo._server.*`, + # which we don't want to load just to define the types module. + from marimo._ai._pydantic_ai_utils import sanitize_part + validated_parts = [] for part in parts: if isinstance(part, part_validator_class): validated_parts.append(part) elif isinstance(part, dict): - sanitized_part = remove_none_values(part) + sanitized_part = sanitize_part(part) # Try pydantic validation for dict -> class conversion try: from pydantic import TypeAdapter diff --git a/marimo/_ai/llm/_impl.py b/marimo/_ai/llm/_impl.py index c3980af68f9..37ea89a5105 100644 --- a/marimo/_ai/llm/_impl.py +++ b/marimo/_ai/llm/_impl.py @@ -1,16 +1,14 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations -import dataclasses import json import os import re from typing import TYPE_CHECKING, Any, cast from marimo import _loggers -from marimo._ai._pydantic_ai_utils import generate_id +from marimo._ai._pydantic_ai_utils import generate_id, sanitize_part from marimo._plugins.ui._impl.chat.chat import AI_SDK_VERSION, DONE_CHUNK -from marimo._utils.dicts import remove_none_values if TYPE_CHECKING: from collections.abc import AsyncGenerator, Callable, Generator @@ -747,18 +745,14 @@ def _build_ui_messages( if not message.id: LOGGER.warning("Message %s has no id", message) + # Prefer the raw wire payload when we have it so fields outside + # marimo's lossy dataclasses (`approval`, `providerExecuted`, + # `preliminary`, ...) survive the round-trip into pydantic-ai. parts: list[UIMessagePart] = [] if message.parts: parts = cast( list[UIMessagePart], - [ - self._remove_none_values( - dataclasses.asdict(part) - if dataclasses.is_dataclass(part) - else part - ) - for part in message.parts - ], + [sanitize_part(p) for p in message.raw_or_dumped_parts()], ) if not parts: if message.content is not None: @@ -784,11 +778,6 @@ def _build_ui_messages( ) return ui_messages - def _remove_none_values(self, obj: dict[str, Any]) -> dict[str, Any]: - if isinstance(obj, dict) and hasattr(obj, "items"): - return remove_none_values(obj) - return obj - def _serialize_vercel_ai_chunk( self, chunk: BaseChunk ) -> dict[str, Any] | None: @@ -849,7 +838,11 @@ async def _stream_response( messages=ui_messages, ) - adapter = VercelAIAdapter(agent=self.agent, run_input=run_input) + adapter = VercelAIAdapter( + agent=self.agent, + run_input=run_input, + sdk_version=AI_SDK_VERSION, + ) event_stream = adapter.run_stream(model_settings=model_settings) async for event in event_stream: if serialized := self._serialize_vercel_ai_chunk(event): diff --git a/tests/_ai/llm/test_impl.py b/tests/_ai/llm/test_impl.py index 6b23b74c55a..7db95cfb305 100644 --- a/tests/_ai/llm/test_impl.py +++ b/tests/_ai/llm/test_impl.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from unittest.mock import MagicMock, patch import pytest @@ -20,8 +20,11 @@ simple, ) from marimo._dependencies.dependencies import DependencyManager +from marimo._plugins.ui._impl.chat.chat import AI_SDK_VERSION if TYPE_CHECKING: + from collections.abc import AsyncIterator + from pydantic_ai.settings import ModelSettings @@ -1410,6 +1413,9 @@ async def mock_run_stream(**_kwargs: Any): {"id": "2", "type": "text-delta", "delta": " World"}, ] + _, kwargs = mock_adapter.call_args + assert kwargs.get("sdk_version") == AI_SDK_VERSION + async def test_stream_text(self): """Test _stream_text streams text from the model.""" mock_agent = MagicMock() @@ -1597,6 +1603,136 @@ def test_pydantic_ai_serialize_vercel_ai_chunk_error_handling( result = model._serialize_vercel_ai_chunk(cast(Any, error_chunk)) assert result is None + def test_build_ui_messages_preserves_tool_approval_field(self): + """When the frontend posts back an `approval-responded` tool part, + the `approval` payload must reach pydantic-ai intact. The lossy + `ToolInvocationPart` dataclass doesn't model `approval`, so without + the `_raw_parts` snapshot the field would be silently dropped and + the agent would loop on the same tool call forever. + + Regression for the second half of the tool-approval fix: the first + half (`sdk_version=AI_SDK_VERSION`) made the request chunk visible + to the frontend; this half makes the user's response visible to the + agent. + """ + from pydantic_ai.ui.vercel_ai.request_types import ( + ToolApprovalResponded, + ToolApprovalRespondedPart, + ) + + model = pydantic_ai(MagicMock()) + messages = [ + ChatMessage( + role="user", + content="Delete secrets.env", + id="msg-user-1", + parts=cast( # pyright: ignore[reportAny] + Any, + [{"type": "text", "text": "Delete secrets.env"}], + ), + ), + ChatMessage( + role="assistant", + content=None, + id="msg-assistant-1", + parts=cast( # pyright: ignore[reportAny] + Any, + [ + { + "type": "tool-delete_file", + "toolCallId": "call-1", + "state": "approval-responded", + "input": {"path": "secrets.env"}, + "approval": { + "id": "call-1", + "approved": True, + }, + } + ], + ), + ), + ] + + ui_messages = model._build_ui_messages(messages) + assert len(ui_messages) == 2 + + # The tool part must be reified as the *responded* variant — the + # one that carries the approval — and the approval must survive. + tool_part = ui_messages[1].parts[0] + assert isinstance(tool_part, ToolApprovalRespondedPart), ( + f"Expected ToolApprovalRespondedPart, got {type(tool_part).__name__}: " + f"{tool_part!r}" + ) + approval = tool_part.approval + assert isinstance(approval, ToolApprovalResponded), ( + f"Expected ToolApprovalResponded, got {approval!r}" + ) + assert approval.id == "call-1" + assert approval.approved is True + + async def test_stream_response_emits_tool_approval_request(self): + """Tools with `requires_approval=True` should surface an + approval-request chunk so the frontend can render an Approve/Deny + card. This is the v6-only behavior unlocked by passing + `sdk_version=AI_SDK_VERSION` to the adapter. + """ + from pydantic_ai import Agent, DeferredToolRequests + from pydantic_ai.models.function import ( + AgentInfo, + DeltaToolCall, + DeltaToolCalls, + FunctionModel, + ) + + async def respond( + messages: list[Any], _info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls]: + del messages, _info + yield { + 0: DeltaToolCall( + name="delete_file", + json_args='{"path": "secrets.env"}', + tool_call_id="call-1", + ) + } + + agent = Agent( + FunctionModel(stream_function=respond), + output_type=[str, DeferredToolRequests], + ) + + @agent.tool_plain(requires_approval=True) + def delete_file(path: str) -> str: + return f"File {path!r} deleted" + + model = pydantic_ai(agent) + messages = [ + ChatMessage( + role="user", + content="Delete secrets.env", + id="msg-user-1", + parts=[TextPart(type="text", text="Delete secrets.env")], + ), + ] + config = ChatModelConfig(max_tokens=100) + + chunks = [ + chunk async for chunk in model._stream_response(messages, config) + ] + + approval_chunks = [ + chunk + for chunk in chunks + if chunk.get("type") == "tool-approval-request" + ] + assert approval_chunks == [ + { + "type": "tool-approval-request", + "approvalId": "call-1", + "toolCallId": "call-1", + } + ] + class MockBaseChunkWithError: """Mock BaseChunk that raises on serialization.""" diff --git a/tests/_ai/test_ai_types.py b/tests/_ai/test_ai_types.py index 2f262bd5b2c..4228ef7da2b 100644 --- a/tests/_ai/test_ai_types.py +++ b/tests/_ai/test_ai_types.py @@ -546,3 +546,87 @@ def test_parts_and_attachments_serialized_to_dict(self): ], "metadata": None, } + + +class TestChatMessageRawPartsRoundTrip: + """`ChatMessage` must snapshot raw wire payloads so we don't drop AI + SDK fields the typed dataclasses don't model (e.g. `approval`, + `callProviderMetadata`). Without this the pydantic-ai bridge would + lose context on every deferred tool run. + """ + + def test_typed_input_does_not_snapshot(self): + """Typed parts are themselves the source of truth — no snapshot.""" + message = ChatMessage( + role="user", + content="hi", + id="msg", + parts=[TextPart(type="text", text="hi")], + ) + assert message._raw_parts is None + assert dict[str, Any](message)["parts"] == [ + {"type": "text", "text": "hi"} + ] + + def test_equality_ignores_raw_parts(self): + """Equality compares value, not cache state.""" + from_dict = ChatMessage( + role="user", + content="hi", + id="msg", + parts=[cast(ChatPart, {"type": "text", "text": "hi"})], + ) + from_typed = ChatMessage( + role="user", + content="hi", + id="msg", + parts=[TextPart(type="text", text="hi")], + ) + assert from_dict._raw_parts is not None + assert from_typed._raw_parts is None + assert from_dict == from_typed + + @pytest.mark.parametrize( + "raw", + [ + # Tool part carrying fields marimo's dataclass doesn't model — + # this is the case that motivated `_raw_parts` in the first place. + { + "type": "tool-delete_file", + "toolCallId": "call-1", + "state": "approval-responded", + "input": {"path": "secrets.env"}, + "approval": {"id": "call-1", "approved": True}, + "callProviderMetadata": {"openai": {"foo": "bar"}}, + }, + # Alternate tool shape. + { + "type": "dynamic-tool", + "toolName": "delete_file", + "toolCallId": "call-1", + "state": "input-streaming", + }, + # Non-tool part with its own bag of optional fields. + { + "type": "file", + "mediaType": "image/png", + "url": "data:image/png;base64,abc", + }, + # `data-*` parts are user-defined; the type string itself + # carries information and must survive the round-trip. + { + "type": "data-reasoning-signature", + "data": {"signature": "x"}, + }, + ], + ) + def test_raw_part_round_trips_verbatim(self, raw: dict[str, Any]): + """`dict(ChatMessage(parts=[raw]))["parts"]` must equal `[raw]`.""" + message = ChatMessage( + role="assistant", + content=None, + id="msg", + parts=[cast(ChatPart, raw)], + ) + assert dict[str, Any](message)["parts"] == [raw] + assert message._raw_parts == [raw] diff --git a/tests/_ai/test_pydantic_utils.py b/tests/_ai/test_pydantic_utils.py index 91a36600f9d..e3b79441611 100644 --- a/tests/_ai/test_pydantic_utils.py +++ b/tests/_ai/test_pydantic_utils.py @@ -9,11 +9,11 @@ pytest.importorskip("pydantic_ai", reason="pydantic_ai not installed") from marimo._ai._pydantic_ai_utils import ( - _sanitize_part, convert_to_pydantic_messages, create_simple_prompt, form_toolsets, generate_id, + sanitize_part, ) from marimo._server.ai.tools.types import ToolDefinition @@ -521,7 +521,7 @@ def test_strips_stale_fields_on_approval_responded(self): "output": "Invalid arguments for tool request_user_blessing", "errorText": "boom", } - clean = _sanitize_part(stale) + clean = sanitize_part(stale) assert "output" not in clean assert "errorText" not in clean assert clean["approval"] == {"id": "call_abc", "approved": True} @@ -538,7 +538,7 @@ def test_preserves_real_fields_on_output_available(self): "output": {"ok": True}, "preliminary": False, } - clean = _sanitize_part(part) + clean = sanitize_part(part) assert clean["output"] == {"ok": True} assert clean["preliminary"] is False @@ -551,7 +551,7 @@ def test_preserves_real_fields_on_output_error(self): "rawInput": {"x": 1}, "errorText": "boom", } - clean = _sanitize_part(part) + clean = sanitize_part(part) assert clean["errorText"] == "boom" assert clean["rawInput"] == {"x": 1} @@ -565,7 +565,7 @@ def test_dynamic_tool_preserves_tool_name(self): "approval": {"id": "c1", "approved": True}, "output": "stale", } - clean = _sanitize_part(stale) + clean = sanitize_part(stale) assert "output" not in clean assert clean["toolName"] == "my_tool" @@ -584,7 +584,7 @@ def test_dynamic_tool_preserves_tool_name(self): ], ) def test_non_tool_parts_pass_through_unchanged(self, part): - assert _sanitize_part(part) is part + assert sanitize_part(part) is part @pytest.mark.parametrize( "part", @@ -606,7 +606,7 @@ def test_non_tool_parts_pass_through_unchanged(self, part): ], ) def test_malformed_parts_pass_through_unchanged(self, part): - assert _sanitize_part(part) == part + assert sanitize_part(part) == part class TestConvertToPydanticMessagesSanitizes: From 92417be1ef752d3848bb826576f09dd08e5a2a18 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Wed, 20 May 2026 15:59:35 +0800 Subject: [PATCH 02/10] remove override --- marimo/_ai/_types.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index 8042718415d..fe78d5e35d6 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -5,15 +5,7 @@ import mimetypes from collections.abc import Iterator from dataclasses import asdict, dataclass, is_dataclass -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Literal, - TypedDict, - cast, - override, -) +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast import msgspec @@ -247,7 +239,6 @@ def __post_init__(self) -> None: # struct grows. _EQ_EXCLUDE: ClassVar[frozenset[str]] = frozenset({"_raw_parts"}) - @override def __eq__(self, other: object) -> bool: if not isinstance(other, ChatMessage): return NotImplemented From 319f270318540c10495ca701ca4fc632f17add4e Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Wed, 20 May 2026 16:02:27 +0800 Subject: [PATCH 03/10] fix from comment --- marimo/_ai/_types.py | 22 +++++++++++----------- tests/_ai/test_ai_types.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index fe78d5e35d6..ceb3b81ae82 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -275,23 +275,23 @@ def raw_or_dumped_parts(self) -> list[dict[str, Any]]: if self._raw_parts is not None: return self._raw_parts + ui_message_part_cls: type | None = None + if DependencyManager.pydantic_ai.imported(): + from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart + + ui_message_part_cls = UIMessagePart + result: list[dict[str, Any]] = [] for part in self.parts: if is_dataclass(part): result.append(asdict(part)) - elif DependencyManager.pydantic_ai.imported(): - from pydantic_ai.ui.vercel_ai.request_types import ( - UIMessagePart, + elif ui_message_part_cls is not None and isinstance( + part, ui_message_part_cls + ): + result.append( + part.model_dump(by_alias=True, exclude_none=True) ) - - if isinstance(part, UIMessagePart): - result.append( - part.model_dump(by_alias=True, exclude_none=True) - ) elif isinstance(part, dict): - # Defensive: at runtime `parts` may carry raw dicts because - # `ChatPart` is `dict[str, Any]` at runtime even though the - # type-checking alias is a union of dataclasses. result.append(part) return result diff --git a/tests/_ai/test_ai_types.py b/tests/_ai/test_ai_types.py index 4228ef7da2b..ccd28b4be7f 100644 --- a/tests/_ai/test_ai_types.py +++ b/tests/_ai/test_ai_types.py @@ -630,3 +630,19 @@ def test_raw_part_round_trips_verbatim(self, raw: dict[str, Any]): ) assert dict[str, Any](message)["parts"] == [raw] assert message._raw_parts == [raw] + + def test_dict_part_appended_after_init_is_preserved(self): + message = ChatMessage( + role="assistant", + content=None, + id="msg", + parts=[TextPart(type="text", text="hi")], + ) + assert message._raw_parts is None + + extra = {"type": "data-custom", "data": {"k": "v"}} + message.parts.append(cast(ChatPart, extra)) + + dumped = message.raw_or_dumped_parts() + assert dumped == [{"type": "text", "text": "hi"}, extra] + assert dict[str, Any](message)["parts"] == dumped From 53f63abcf7dbb8c970564b2f8e93725bf6af4288 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Wed, 20 May 2026 16:24:54 +0800 Subject: [PATCH 04/10] fix mypy type for ui_message_part_cls Co-Authored-By: Claude Opus 4.7 --- marimo/_ai/_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index ceb3b81ae82..07db4389943 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -275,7 +275,7 @@ def raw_or_dumped_parts(self) -> list[dict[str, Any]]: if self._raw_parts is not None: return self._raw_parts - ui_message_part_cls: type | None = None + ui_message_part_cls: Any = None if DependencyManager.pydantic_ai.imported(): from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart From 12efb1a582f7096a089f6d65a1a866853d810354 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Wed, 20 May 2026 16:51:13 +0800 Subject: [PATCH 05/10] address PR review feedback - VercelAIAdapter: fallback when sdk_version kwarg unsupported - chat-ui: only pass addToolApprovalResponse to the last message - ChatMessage: move _raw_parts off the msgspec struct so it isn't serialized Co-Authored-By: Claude Opus 4.7 --- frontend/src/plugins/impl/chat/chat-ui.tsx | 4 +++- marimo/_ai/_types.py | 23 +++++----------------- marimo/_ai/llm/_impl.py | 16 ++++++++++----- 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/frontend/src/plugins/impl/chat/chat-ui.tsx b/frontend/src/plugins/impl/chat/chat-ui.tsx index a10d9bba1fd..f4fabb3b8db 100644 --- a/frontend/src/plugins/impl/chat/chat-ui.tsx +++ b/frontend/src/plugins/impl/chat/chat-ui.tsx @@ -445,7 +445,9 @@ export const Chatbot: React.FC = (props) => { message, isStreamingReasoning: status === "streaming", isLast, - addToolApprovalResponse, + addToolApprovalResponse: isLast + ? addToolApprovalResponse + : undefined, })}
diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index 07db4389943..cf955b1e24b 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -5,7 +5,7 @@ import mimetypes from collections.abc import Iterator from dataclasses import asdict, dataclass, is_dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast +from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast import msgspec @@ -189,7 +189,7 @@ class StepStartPart: ] -class ChatMessage(msgspec.Struct, eq=False): +class ChatMessage(msgspec.Struct, eq=False, dict=True): """ A message in a chat. """ @@ -213,17 +213,11 @@ class ChatMessage(msgspec.Struct, eq=False): metadata: Any | None = None - # High-fidelity snapshot of `parts` as they arrive over the wire. - _raw_parts: list[dict[str, Any]] | None = None - def __post_init__(self) -> None: + # Non-struct attribute (via dict=True) so it never reaches the wire. + self._raw_parts: list[dict[str, Any]] | None = None if self.parts: - # Snapshot the wire-format dicts before lossy conversion. We only - # snapshot when every part is a dict so we don't store a partial - # view that mixes typed and raw entries. - if self._raw_parts is None and all( - isinstance(p, dict) for p in self.parts - ): + if all(isinstance(p, dict) for p in self.parts): self._raw_parts = [cast(dict[str, Any], p) for p in self.parts] # Hack: msgspec only supports discriminated unions. This is a hack to just # iterate through possible part variants and decode until one works. @@ -233,19 +227,12 @@ def __post_init__(self) -> None: parts.append(converted) self.parts = parts - # Fields excluded from equality. Add anything here that is a cache / - # representation detail rather than part of the message's identity, so - # `__eq__` keeps comparing every "real" field automatically as the - # struct grows. - _EQ_EXCLUDE: ClassVar[frozenset[str]] = frozenset({"_raw_parts"}) - def __eq__(self, other: object) -> bool: if not isinstance(other, ChatMessage): return NotImplemented return all( getattr(self, name) == getattr(other, name) for name in self.__struct_fields__ - if name not in self._EQ_EXCLUDE ) def _convert_part(self, part: Any) -> ChatPart | None: diff --git a/marimo/_ai/llm/_impl.py b/marimo/_ai/llm/_impl.py index 37ea89a5105..6d838c94cfc 100644 --- a/marimo/_ai/llm/_impl.py +++ b/marimo/_ai/llm/_impl.py @@ -838,11 +838,17 @@ async def _stream_response( messages=ui_messages, ) - adapter = VercelAIAdapter( - agent=self.agent, - run_input=run_input, - sdk_version=AI_SDK_VERSION, - ) + try: + adapter = VercelAIAdapter( + agent=self.agent, + run_input=run_input, + sdk_version=AI_SDK_VERSION, + ) + except TypeError: + adapter = VercelAIAdapter( + agent=self.agent, + run_input=run_input, + ) event_stream = adapter.run_stream(model_settings=model_settings) async for event in event_stream: if serialized := self._serialize_vercel_ai_chunk(event): From 7b3da27f7225ef5ce903082412fd7f8ff0f4804d Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Wed, 20 May 2026 16:52:54 +0800 Subject: [PATCH 06/10] restore _EQ_EXCLUDE infrastructure on ChatMessage Co-Authored-By: Claude Opus 4.7 --- marimo/_ai/_types.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index cf955b1e24b..48e2eff3ae9 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -5,7 +5,7 @@ import mimetypes from collections.abc import Iterator from dataclasses import asdict, dataclass, is_dataclass -from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast import msgspec @@ -227,12 +227,19 @@ def __post_init__(self) -> None: parts.append(converted) self.parts = parts + # Fields excluded from equality. Add anything here that is a cache / + # representation detail rather than part of the message's identity, so + # `__eq__` keeps comparing every "real" field automatically as the + # struct grows. + _EQ_EXCLUDE: ClassVar[frozenset[str]] = frozenset() + def __eq__(self, other: object) -> bool: if not isinstance(other, ChatMessage): return NotImplemented return all( getattr(self, name) == getattr(other, name) for name in self.__struct_fields__ + if name not in self._EQ_EXCLUDE ) def _convert_part(self, part: Any) -> ChatPart | None: From 3f3a6390d23f9640bcb33c63275fbec533beb6da Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Wed, 20 May 2026 22:43:23 +0800 Subject: [PATCH 07/10] preserve raw dict snapshots in mixed-input ChatMessage - ChatMessage: snapshot raw dicts per-element so unmodeled SDK fields (e.g. `approval`, `callProviderMetadata`) survive even when `parts` mixes typed and dict entries. - test_impl: don't pin approvalId to a specific value; older pydantic-ai emits a UUID while newer reuses toolCallId. Co-Authored-By: Claude Opus 4.7 --- marimo/_ai/_types.py | 66 +++++++++++++++++++++++++++----------- tests/_ai/llm/test_impl.py | 15 +++++---- tests/_ai/test_ai_types.py | 26 +++++++++++++++ 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index 48e2eff3ae9..f7d02149293 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -214,17 +214,30 @@ class ChatMessage(msgspec.Struct, eq=False, dict=True): metadata: Any | None = None def __post_init__(self) -> None: - # Non-struct attribute (via dict=True) so it never reaches the wire. - self._raw_parts: list[dict[str, Any]] | None = None + # Non-struct attributes (via dict=True) so they never reach the wire. + # Snapshots are aligned with the original input order; SDK fields the + # typed dataclasses don't model are preserved even when the input is + # a mix of dicts and typed parts. `_parts_idx[i]` is the index of the + # converted part in `self.parts` (or None when conversion dropped it). + self._raw_parts: list[dict[str, Any] | None] | None = None + self._parts_idx: list[int | None] | None = None if self.parts: - if all(isinstance(p, dict) for p in self.parts): - self._raw_parts = [cast(dict[str, Any], p) for p in self.parts] - # Hack: msgspec only supports discriminated unions. This is a hack to just - # iterate through possible part variants and decode until one works. - parts = [] + snapshots: list[dict[str, Any] | None] = [] + idx_map: list[int | None] = [] + parts: list[ChatPart] = [] for part in self.parts: + is_dict = isinstance(part, dict) + snapshots.append( + cast(dict[str, Any], part) if is_dict else None + ) if converted := self._convert_part(part): + idx_map.append(len(parts)) parts.append(converted) + else: + idx_map.append(None) + if any(s is not None for s in snapshots): + self._raw_parts = snapshots + self._parts_idx = idx_map self.parts = parts # Fields excluded from equality. Add anything here that is a cache / @@ -266,8 +279,8 @@ def _convert_part(self, part: Any) -> ChatPart | None: def raw_or_dumped_parts(self) -> list[dict[str, Any]]: """Return parts in dict form, preferring the original wire payload.""" - if self._raw_parts is not None: - return self._raw_parts + raw_parts = self._raw_parts + idx_map = self._parts_idx ui_message_part_cls: Any = None if DependencyManager.pydantic_ai.imported(): @@ -275,18 +288,35 @@ def raw_or_dumped_parts(self) -> list[dict[str, Any]]: ui_message_part_cls = UIMessagePart - result: list[dict[str, Any]] = [] - for part in self.parts: + def dump(part: Any) -> dict[str, Any] | None: if is_dataclass(part): - result.append(asdict(part)) - elif ui_message_part_cls is not None and isinstance( + return asdict(part) + if ui_message_part_cls is not None and isinstance( part, ui_message_part_cls ): - result.append( - part.model_dump(by_alias=True, exclude_none=True) - ) - elif isinstance(part, dict): - result.append(part) + return part.model_dump(by_alias=True, exclude_none=True) + if isinstance(part, dict): + return cast(dict[str, Any], part) + return None + + if raw_parts is None or idx_map is None: + return [d for p in self.parts if (d := dump(p)) is not None] + + result: list[dict[str, Any]] = [] + last_used = -1 + for snap, parts_idx in zip(raw_parts, idx_map, strict=True): + if snap is not None: + result.append(snap) + if parts_idx is not None: + last_used = max(last_used, parts_idx) + elif parts_idx is not None and parts_idx < len(self.parts): + if (d := dump(self.parts[parts_idx])) is not None: + result.append(d) + last_used = max(last_used, parts_idx) + # Parts appended after __post_init__ aren't in idx_map; dump them live. + for part in self.parts[last_used + 1 :]: + if (d := dump(part)) is not None: + result.append(d) return result def __iter__(self) -> Iterator[tuple[str, Any]]: diff --git a/tests/_ai/llm/test_impl.py b/tests/_ai/llm/test_impl.py index 7db95cfb305..9a84c7233d7 100644 --- a/tests/_ai/llm/test_impl.py +++ b/tests/_ai/llm/test_impl.py @@ -1725,13 +1725,14 @@ def delete_file(path: str) -> str: for chunk in chunks if chunk.get("type") == "tool-approval-request" ] - assert approval_chunks == [ - { - "type": "tool-approval-request", - "approvalId": "call-1", - "toolCallId": "call-1", - } - ] + # Older pydantic-ai generates a UUID approvalId; newer versions reuse + # toolCallId. Either is fine — assert the shape, not the exact value. + assert len(approval_chunks) == 1 + chunk = approval_chunks[0] + assert chunk["type"] == "tool-approval-request" + assert chunk["toolCallId"] == "call-1" + assert isinstance(chunk["approvalId"], str) + assert chunk["approvalId"] class MockBaseChunkWithError: diff --git a/tests/_ai/test_ai_types.py b/tests/_ai/test_ai_types.py index ccd28b4be7f..e731a058d9e 100644 --- a/tests/_ai/test_ai_types.py +++ b/tests/_ai/test_ai_types.py @@ -631,6 +631,32 @@ def test_raw_part_round_trips_verbatim(self, raw: dict[str, Any]): assert dict[str, Any](message)["parts"] == [raw] assert message._raw_parts == [raw] + def test_mixed_dict_and_typed_parts_preserve_raw_dicts(self): + """When `parts` mixes raw dicts and typed inputs, the raw dicts' + unmodeled fields (e.g. `approval`) must survive — not just the + all-dicts case. + """ + raw = { + "type": "tool-delete_file", + "toolCallId": "call-1", + "state": "approval-responded", + "input": {"path": "secrets.env"}, + "approval": {"id": "call-1", "approved": True}, + } + message = ChatMessage( + role="assistant", + content=None, + id="msg", + parts=[ + cast(ChatPart, raw), + TextPart(type="text", text="ok"), + ], + ) + assert message.raw_or_dumped_parts() == [ + raw, + {"type": "text", "text": "ok"}, + ] + def test_dict_part_appended_after_init_is_preserved(self): message = ChatMessage( role="assistant", From 39b5a16a8fcb07b8894728bd91c2f91e1cfdcf04 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 21 May 2026 11:02:20 +0800 Subject: [PATCH 08/10] refactor and fix types --- marimo/_ai/_convert.py | 2 +- marimo/_ai/_types.py | 116 ++++++++++++++----------------------- tests/_ai/test_ai_types.py | 22 ++++--- 3 files changed, 59 insertions(+), 81 deletions(-) diff --git a/marimo/_ai/_convert.py b/marimo/_ai/_convert.py index 3fe7b6507be..360efc338ce 100644 --- a/marimo/_ai/_convert.py +++ b/marimo/_ai/_convert.py @@ -164,7 +164,7 @@ def convert_to_openai_messages( # Reset parts since we've added the messages current_parts = [] - else: + elif dataclasses.is_dataclass(part) and not isinstance(part, type): current_parts.append(dataclasses.asdict(part)) # type: ignore if current_parts: diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index f7d02149293..2b8fd969a14 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -5,7 +5,7 @@ import mimetypes from collections.abc import Iterator from dataclasses import asdict, dataclass, is_dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast +from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast import msgspec @@ -189,7 +189,7 @@ class StepStartPart: ] -class ChatMessage(msgspec.Struct, eq=False, dict=True): +class ChatMessage(msgspec.Struct, dict=True): """ A message in a chat. """ @@ -214,74 +214,53 @@ class ChatMessage(msgspec.Struct, eq=False, dict=True): metadata: Any | None = None def __post_init__(self) -> None: - # Non-struct attributes (via dict=True) so they never reach the wire. - # Snapshots are aligned with the original input order; SDK fields the - # typed dataclasses don't model are preserved even when the input is - # a mix of dicts and typed parts. `_parts_idx[i]` is the index of the - # converted part in `self.parts` (or None when conversion dropped it). + # Non-struct attribute (via `dict=True`) so it isn't serialized. + # Snapshots raw dict inputs 1:1 with `self.parts` so SDK fields the + # typed dataclasses don't model survive the round-trip. self._raw_parts: list[dict[str, Any] | None] | None = None - self._parts_idx: list[int | None] | None = None - if self.parts: - snapshots: list[dict[str, Any] | None] = [] - idx_map: list[int | None] = [] - parts: list[ChatPart] = [] - for part in self.parts: - is_dict = isinstance(part, dict) - snapshots.append( - cast(dict[str, Any], part) if is_dict else None - ) - if converted := self._convert_part(part): - idx_map.append(len(parts)) - parts.append(converted) - else: - idx_map.append(None) - if any(s is not None for s in snapshots): - self._raw_parts = snapshots - self._parts_idx = idx_map - self.parts = parts - - # Fields excluded from equality. Add anything here that is a cache / - # representation detail rather than part of the message's identity, so - # `__eq__` keeps comparing every "real" field automatically as the - # struct grows. - _EQ_EXCLUDE: ClassVar[frozenset[str]] = frozenset() - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ChatMessage): - return NotImplemented - return all( - getattr(self, name) == getattr(other, name) - for name in self.__struct_fields__ - if name not in self._EQ_EXCLUDE - ) + if not self.parts: + return + snapshots: list[dict[str, Any] | None] = [] + typed: list[ChatPart] = [] + for part in self.parts: + converted = self._convert_part(part) + if converted is None: + continue + snapshots.append( + cast(dict[str, Any], part) if isinstance(part, dict) else None + ) + typed.append(converted) + if any(s is not None for s in snapshots): + self._raw_parts = snapshots + self.parts = typed def _convert_part(self, part: Any) -> ChatPart | None: - # If we receive a Vercel AI SDK part (through pydantic-ai), return it as is. if DependencyManager.pydantic_ai.imported(): from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart if isinstance(part, UIMessagePart): return cast(ChatPart, part) - PartType = None - for PartType in PART_TYPES: - try: - if is_dataclass(part): - return cast(ChatPart, part) - return parse_raw(part, cls=PartType, allow_unknown_keys=True) - except Exception: - continue + if is_dataclass(part) and not isinstance(part, type): + return cast(ChatPart, part) - LOGGER.debug( - f"Could not decode part {part}. Ignore if it's a Vercel UI message part." - ) + # Unknown dicts pass through verbatim so future SDK part types still + # round-trip. + if isinstance(part, dict): + for PartType in PART_TYPES: + try: + return parse_raw( + part, cls=PartType, allow_unknown_keys=True + ) + except Exception: + continue + return cast(ChatPart, part) + + LOGGER.debug("Dropping unrecognized part %r", part) return None def raw_or_dumped_parts(self) -> list[dict[str, Any]]: """Return parts in dict form, preferring the original wire payload.""" - raw_parts = self._raw_parts - idx_map = self._parts_idx - ui_message_part_cls: Any = None if DependencyManager.pydantic_ai.imported(): from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart @@ -289,33 +268,26 @@ def raw_or_dumped_parts(self) -> list[dict[str, Any]]: ui_message_part_cls = UIMessagePart def dump(part: Any) -> dict[str, Any] | None: - if is_dataclass(part): + if is_dataclass(part) and not isinstance(part, type): return asdict(part) if ui_message_part_cls is not None and isinstance( part, ui_message_part_cls ): - return part.model_dump(by_alias=True, exclude_none=True) + return cast( + dict[str, Any], + part.model_dump(by_alias=True, exclude_none=True), + ) if isinstance(part, dict): return cast(dict[str, Any], part) return None - if raw_parts is None or idx_map is None: - return [d for p in self.parts if (d := dump(p)) is not None] - + raws = self._raw_parts result: list[dict[str, Any]] = [] - last_used = -1 - for snap, parts_idx in zip(raw_parts, idx_map, strict=True): + for i, part in enumerate(self.parts): + snap = raws[i] if raws is not None and i < len(raws) else None if snap is not None: result.append(snap) - if parts_idx is not None: - last_used = max(last_used, parts_idx) - elif parts_idx is not None and parts_idx < len(self.parts): - if (d := dump(self.parts[parts_idx])) is not None: - result.append(d) - last_used = max(last_used, parts_idx) - # Parts appended after __post_init__ aren't in idx_map; dump them live. - for part in self.parts[last_used + 1 :]: - if (d := dump(part)) is not None: + elif (d := dump(part)) is not None: result.append(d) return result diff --git a/tests/_ai/test_ai_types.py b/tests/_ai/test_ai_types.py index e731a058d9e..4724750610e 100644 --- a/tests/_ai/test_ai_types.py +++ b/tests/_ai/test_ai_types.py @@ -473,7 +473,11 @@ def test_keeps_already_typed_parts(self): assert message.parts[0] is text_part def test_handles_invalid_parts_gracefully(self): - """Test that invalid parts are dropped gracefully.""" + """Unknown dict parts don't crash construction; they're preserved + verbatim so future SDK part types still round-trip on serialization. + See `test_raw_part_round_trips_verbatim` for the contract. + """ + raw_unknown = {"type": "unknown_type", "data": "invalid"} message = ChatMessage( role="user", content="Hello", @@ -481,17 +485,19 @@ def test_handles_invalid_parts_gracefully(self): Any, [ {"type": "text", "text": "Valid"}, - {"type": "unknown_type", "data": "invalid"}, + raw_unknown, ], ), ) - # Valid part should be kept, invalid should be dropped - assert message == ChatMessage( - role="user", - content="Hello", - parts=[TextPart(type="text", text="Valid")], - ) + assert len(message.parts) == 2 + assert isinstance(message.parts[0], TextPart) + assert message.parts[0].text == "Valid" + assert message.parts[1] == raw_unknown + assert message.raw_or_dumped_parts() == [ + {"type": "text", "text": "Valid"}, + raw_unknown, + ] def test_with_none_parts(self): """Test that None parts is handled.""" From 81587532139bf3e712cad0faafc7b06956633602 Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 21 May 2026 11:11:15 +0800 Subject: [PATCH 09/10] stronger typing --- marimo/_ai/_types.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index 2b8fd969a14..c41a4df39cf 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -168,6 +168,8 @@ class StepStartPart: if TYPE_CHECKING: from collections.abc import Iterator + from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart + ChatPart = ( TextPart | ReasoningPart @@ -261,11 +263,13 @@ def _convert_part(self, part: Any) -> ChatPart | None: def raw_or_dumped_parts(self) -> list[dict[str, Any]]: """Return parts in dict form, preferring the original wire payload.""" - ui_message_part_cls: Any = None + ui_message_part_cls: type[UIMessagePart] | None = None if DependencyManager.pydantic_ai.imported(): from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart - ui_message_part_cls = UIMessagePart + # `UIMessagePart` is a union type alias; the runtime value works + # with `isinstance` but doesn't match `type[...]` statically. + ui_message_part_cls = UIMessagePart # type: ignore[assignment] # pyright: ignore[reportAssignmentType] def dump(part: Any) -> dict[str, Any] | None: if is_dataclass(part) and not isinstance(part, type): @@ -273,10 +277,7 @@ def dump(part: Any) -> dict[str, Any] | None: if ui_message_part_cls is not None and isinstance( part, ui_message_part_cls ): - return cast( - dict[str, Any], - part.model_dump(by_alias=True, exclude_none=True), - ) + return part.model_dump(by_alias=True, exclude_none=True) # type: ignore[no-any-return] if isinstance(part, dict): return cast(dict[str, Any], part) return None From 1474735fd14273e1304fe5750e2d3086c2150e2a Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Thu, 21 May 2026 11:28:53 +0800 Subject: [PATCH 10/10] fix type --- marimo/_ai/_convert.py | 5 +++++ marimo/_ai/_types.py | 1 - 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/marimo/_ai/_convert.py b/marimo/_ai/_convert.py index 360efc338ce..198195a62d2 100644 --- a/marimo/_ai/_convert.py +++ b/marimo/_ai/_convert.py @@ -166,6 +166,11 @@ def convert_to_openai_messages( current_parts = [] elif dataclasses.is_dataclass(part) and not isinstance(part, type): current_parts.append(dataclasses.asdict(part)) # type: ignore + else: + LOGGER.debug( + "Dropping unsupported part %s during OpenAI conversion", + type(part).__name__, + ) if current_parts: openai_messages.append( diff --git a/marimo/_ai/_types.py b/marimo/_ai/_types.py index c41a4df39cf..56fa50c5605 100644 --- a/marimo/_ai/_types.py +++ b/marimo/_ai/_types.py @@ -3,7 +3,6 @@ import abc import mimetypes -from collections.abc import Iterator from dataclasses import asdict, dataclass, is_dataclass from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast