diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextDelegate.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextDelegate.kt index 33c813e64..7ff7ed55c 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextDelegate.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextDelegate.kt @@ -16,7 +16,6 @@ package com.embabel.agent.api.common.support import com.embabel.agent.api.common.* -import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector import com.embabel.agent.api.tool.ArtifactSinkingTool import com.embabel.agent.api.tool.Tool import com.embabel.agent.api.tool.ToolCallContext @@ -38,8 +37,7 @@ import com.embabel.agent.experimental.primitive.Determination import com.embabel.agent.spi.loop.ToolChainingInjectionStrategy import com.embabel.agent.spi.loop.ToolInjectionStrategy import com.embabel.agent.spi.loop.ToolNotFoundPolicy -import com.embabel.agent.spi.support.springai.ChatClientLlmOperations -import com.embabel.agent.spi.support.springai.streaming.StreamingChatClientOperations +import com.embabel.agent.spi.streaming.StreamingLlmOperations import com.embabel.chat.AssistantMessage import com.embabel.chat.ImagePart import com.embabel.chat.Message @@ -321,12 +319,14 @@ internal data class OperationContextDelegate( override fun supportsStreaming(): Boolean { val llmOperations = context.agentPlatform().platformServices.llmOperations - return StreamingCapabilityDetector.supportsStreaming(llmOperations, this.llm) + // Level 1 sanity check + if (llmOperations !is StreamingLlmOperations) return false + + return llmOperations.supportsStreaming(this.llm) } override fun generateStream(): Flux { - val llmOperations = context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations - val streamingLlmOperations = StreamingChatClientOperations(llmOperations) + val 
streamingLlmOperations = context.agentPlatform().platformServices.llmOperations as StreamingLlmOperations return streamingLlmOperations.generateStream( messages = messages, @@ -337,8 +337,7 @@ internal data class OperationContextDelegate( } override fun createObjectStream(itemClass: Class): Flux { - val llmOperations = context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations - val streamingLlmOperations = StreamingChatClientOperations(llmOperations) + val streamingLlmOperations = context.agentPlatform().platformServices.llmOperations as StreamingLlmOperations return streamingLlmOperations.createObjectStream( messages = messages, @@ -350,8 +349,8 @@ internal data class OperationContextDelegate( } override fun createObjectStreamWithThinking(itemClass: Class): Flux> { - val llmOperations = context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations - val streamingLlmOperations = StreamingChatClientOperations(llmOperations) + val streamingLlmOperations = context.agentPlatform().platformServices.llmOperations as StreamingLlmOperations + return streamingLlmOperations.createObjectStreamWithThinking( messages = messages, interaction = streamingInteraction(), diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/streaming/StreamingCapabilityDetector.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingCapabilityDetector.kt similarity index 74% rename from embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/streaming/StreamingCapabilityDetector.kt rename to embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingCapabilityDetector.kt index b732ced9e..f95b20d7a 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/streaming/StreamingCapabilityDetector.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingCapabilityDetector.kt @@ -13,15 +13,8 @@ * See the License for the 
specific language governing permissions and * limitations under the License. */ -package com.embabel.agent.api.common.support.streaming +package com.embabel.agent.spi.streaming -import com.embabel.agent.api.common.InteractionId -import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector.supportsStreaming -import com.embabel.agent.core.internal.LlmOperations -import com.embabel.agent.core.support.LlmInteraction -import com.embabel.agent.spi.support.springai.ChatClientLlmOperations -import com.embabel.agent.spi.support.springai.SpringAiLlmService -import com.embabel.common.ai.model.LlmOptions import com.embabel.common.util.loggerFor import org.springframework.ai.chat.messages.UserMessage import org.springframework.ai.chat.model.ChatModel @@ -63,27 +56,6 @@ internal object StreamingCapabilityDetector { } } - /** - * Delegates to [supportsStreaming] - */ - fun supportsStreaming(llmOperations: LlmOperations, llmOptions: LlmOptions): Boolean { - - // Level 1 sanity check - if (llmOperations !is ChatClientLlmOperations) return false - - // Level 2: Must have actual streaming capability - val llm = llmOperations.getLlm( - LlmInteraction( // check for circular dependency - id = InteractionId("capability-check"), - llm = llmOptions - ) - ) - - val springAiLlm = llm as? 
SpringAiLlmService ?: return false - return supportsStreaming(springAiLlm.chatModel) - - } - private fun testStreamingCapability(model: ChatModel): Boolean { return try { // Use a prompt that should generate a response diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingLlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingLlmOperations.kt index 0e13a035b..c3809b0c2 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingLlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingLlmOperations.kt @@ -20,7 +20,7 @@ import com.embabel.agent.core.Action import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.support.LlmInteraction import com.embabel.chat.Message -import com.embabel.chat.UserMessage +import com.embabel.common.ai.model.LlmOptions import com.embabel.common.core.streaming.StreamingEvent import reactor.core.publisher.Flux @@ -40,26 +40,12 @@ import reactor.core.publisher.Flux interface StreamingLlmOperations { /** - * Generate streaming text in the context of an AgentProcess. - * Returns a Flux that emits text chunks as they arrive from the LLM. + * Tests whether the given chat model actually supports streaming operations. 
* - * @param prompt Prompt to generate text from - * @param interaction Llm options and tool callbacks to use, plus unique identifier - * @param agentProcess Agent process we are running within - * @param action Action we are running within if we are running within an action - * @return Flux of text chunks as they arrive from the LLM + * @param llmOptions LlmOptions to use, including the name of the model to test + * @return true if the model supports streaming, false otherwise */ - fun generateStream( - prompt: String, - interaction: LlmInteraction, - agentProcess: AgentProcess, - action: Action?, - ): Flux = generateStream( - messages = listOf(UserMessage(prompt)), - interaction = interaction, - agentProcess = agentProcess, - action = action, - ) + fun supportsStreaming(llmOptions: LlmOptions): Boolean /** * Generate streaming text from messages in the context of an AgentProcess. @@ -76,7 +62,9 @@ interface StreamingLlmOperations { interaction: LlmInteraction, agentProcess: AgentProcess, action: Action?, - ): Flux + ): Flux { + return doTransformStream(messages, interaction, null, agentProcess, action) + } /** * Create a streaming list of objects from JSONL response in the context of an AgentProcess. @@ -98,7 +86,9 @@ interface StreamingLlmOperations { outputClass: Class, agentProcess: AgentProcess, action: Action?, - ): Flux + ): Flux { + return doTransformObjectStream(messages, interaction, outputClass, null, agentProcess, action) + } /** * Try to create a streaming list of objects in the context of an AgentProcess. @@ -118,7 +108,13 @@ interface StreamingLlmOperations { outputClass: Class, agentProcess: AgentProcess, action: Action?, - ): Flux> + ): Flux> { + return createObjectStream(messages, interaction, outputClass, agentProcess, action) + .map { Result.success(it) } + .onErrorResume { throwable -> + Flux.just(Result.failure(throwable)) + } + } /** * Create a streaming list of objects with LLM thinking content from mixed JSONL response. 
@@ -141,7 +137,9 @@ interface StreamingLlmOperations { outputClass: Class, agentProcess: AgentProcess, action: Action?, - ): Flux> + ): Flux> { + return doTransformObjectStreamWithThinking(messages, interaction, outputClass, null, agentProcess, action) + } /** * Low level streaming transform with optional platform context. diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt index b8ecf955f..e3c1ee4fc 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt @@ -16,6 +16,8 @@ package com.embabel.agent.spi.support.springai import com.embabel.agent.api.common.Asyncer +import com.embabel.agent.api.common.InteractionId +import com.embabel.agent.spi.streaming.StreamingCapabilityDetector import com.embabel.agent.api.event.LlmRequestEvent import com.embabel.agent.api.tool.Tool import com.embabel.agent.api.tool.config.ToolLoopConfiguration @@ -29,20 +31,25 @@ import com.embabel.agent.spi.ToolDecorator import com.embabel.agent.spi.loop.AutoCorrectionPolicy import com.embabel.agent.spi.loop.LlmMessageSender import com.embabel.agent.spi.loop.ToolLoopFactory +import com.embabel.agent.spi.loop.streaming.LlmMessageStreamer import com.embabel.agent.spi.support.LlmDataBindingProperties import com.embabel.agent.spi.support.LlmOperationsPromptsProperties import com.embabel.agent.spi.support.MaybeReturn import com.embabel.agent.spi.support.OutputConverter +import com.embabel.agent.spi.support.PROMPT_ELEMENT_SEPARATOR import com.embabel.agent.spi.support.ToolLoopLlmOperations import com.embabel.agent.spi.support.ToolResolutionHelper import com.embabel.agent.spi.support.guardrails.validateAssistantResponse import 
com.embabel.agent.spi.support.guardrails.validateUserInput +import com.embabel.agent.spi.support.springai.streaming.SpringAiLlmMessageStreamer import com.embabel.agent.spi.validation.DefaultValidationPromptGenerator import com.embabel.agent.spi.validation.ValidationPromptGenerator import com.embabel.chat.Message import com.embabel.common.ai.converters.FilteringJacksonOutputConverter +import com.embabel.common.ai.converters.streaming.StreamingJacksonOutputConverter import com.embabel.common.ai.model.LlmOptions import com.embabel.common.ai.model.ModelProvider +import com.embabel.common.core.streaming.StreamingEvent import com.embabel.common.core.thinking.ThinkingException import com.embabel.common.core.thinking.ThinkingResponse import com.embabel.common.core.thinking.spi.InternalThinkingApi @@ -75,6 +82,8 @@ import org.springframework.context.ApplicationContext import org.springframework.core.ParameterizedTypeReference import org.springframework.retry.support.RetrySynchronizationManager import org.springframework.stereotype.Service +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono // Log message constants to avoid duplication private const val LLM_TIMEOUT_MESSAGE = "LLM {}: attempt {} timed out after {}ms" @@ -93,6 +102,7 @@ private const val LLM_INTERRUPTED_MESSAGE = "LLM {}: attempt {} was interrupted" * @param toolDecorator ToolDecorator to decorate tools to make them aware of platform * @param templateRenderer TemplateRenderer to render templates * @param dataBindingProperties properties + * @param useMessageStreamer When true, delegates raw streaming to [LlmMessageStreamer] instead of calling Spring AI ChatClient directly. This decouples streaming from Spring AI, enabling vendor-neutral implementations. 
*/ @ThreadSafe @Service @@ -111,6 +121,7 @@ internal class ChatClientLlmOperations( private val customizers: List = emptyList(), asyncer: Asyncer, toolLoopFactory: ToolLoopFactory = ToolLoopFactory.create(ToolLoopConfiguration(), asyncer, AutoCorrectionPolicy()), + private val useMessageStreamer: Boolean = false, ) : ToolLoopLlmOperations( toolDecorator = toolDecorator, modelProvider = modelProvider, @@ -124,7 +135,7 @@ internal class ChatClientLlmOperations( toolLoopFactory = toolLoopFactory, asyncer = asyncer, templateRenderer = templateRenderer, -) { +), LlmOperationsIncludingStreaming { @PostConstruct private fun logPropertyConfiguration() { @@ -578,10 +589,7 @@ internal class ChatClientLlmOperations( } } - /** - * Expose LLM selection for streaming operations - */ - internal fun getLlm(interaction: LlmInteraction): LlmService<*> = chooseLlm(interaction.llm) + private fun getLlm(interaction: LlmInteraction): LlmService<*> = chooseLlm(interaction.llm) /** * Require the LLM to be a SpringAiLlm for Spring AI specific operations. @@ -604,7 +612,7 @@ internal class ChatClientLlmOperations( * @param llm the LLM service to create a client for * @param llmRequestEvent optional domain context; when present, enables instrumentation */ - internal fun createChatClient( + private fun createChatClient( llm: LlmService<*>, llmRequestEvent: LlmRequestEvent<*>? = null, ): ChatClient { @@ -896,6 +904,359 @@ internal class ChatClientLlmOperations( action = action, toolDecorator = toolDecorator, ) + + override fun supportsStreaming(llmOptions: LlmOptions): Boolean { + val llm = getLlm( + LlmInteraction( + id = InteractionId("capability-check"), + llm = llmOptions + ) + ) + val springAiLlm = llm as? 
SpringAiLlmService ?: return false + return StreamingCapabilityDetector.supportsStreaming(springAiLlm.chatModel) + } + + override fun doTransformStream( + messages: List, + interaction: LlmInteraction, + llmRequestEvent: LlmRequestEvent?, + agentProcess: AgentProcess?, + action: Action?, + ): Flux { + // Use ChatClientLlmOperations to get LLM and create ChatClient + val llm = getLlm(interaction) + val chatClient = createChatClient(llm) + + // Build prompt using helper methods + val promptContributions = buildPromptContributions(interaction, llm) + val springAiPrompt = buildBasicPrompt(promptContributions, messages) + + // Guardrails: Pre-validation of user input + val userMessages = messages.filterIsInstance() + validateUserInput(userMessages, interaction, llmRequestEvent?.agentProcess?.blackboard) + + val chatOptions = requireSpringAiLlm(llm).optionsConverter.convertOptions(interaction.llm) + + // Resolve tool groups and decorate tools + val tools = resolveAndDecorateTools(interaction, agentProcess, action) + + return createStreamInternal( + chatClient = chatClient, + messages = messages, + promptContributions = promptContributions, + tools = tools, + chatOptions = chatOptions, + springAiPrompt = springAiPrompt, + ) + } + + /** + * Creates a stream of typed objects from LLM JSONL responses, with thinking content suppressed. + * + * This method provides a clean object-only stream by filtering the internal unified stream + * to exclude thinking content and extract only typed objects. 
+ * + * **Stream Characteristics:** + * - **Input**: Raw LLM chunks containing JSONL + thinking content + * - **Processing**: Chunks → Lines → Events → Objects (thinking filtered out) + * - **Output**: `Flux` containing only parsed typed objects + * - **Error Handling**: Malformed JSON is skipped; stream continues + * - **Backpressure**: Supports standard Flux operators and subscription patterns + * + * **Example Usage:** + * ```kotlin + * val objectStream: Flux = doTransformObjectStream(messages, interaction, User::class.java, null) + * + * objectStream + * .doOnNext { user -> println("Received user: ${user.name}") } + * .doOnError { error -> logger.error("Stream error", error) } + * .doOnComplete { println("Stream completed") } + * .subscribe() + * ``` + * + * **Difference from doTransformObjectStreamWithThinking:** + * - This method: Returns `Flux` with only objects (thinking suppressed) + * - WithThinking: Returns `Flux>` with both thinking and objects + * + * @param messages The conversation messages to send to LLM + * @param interaction LLM configuration and context + * @param outputClass The target class for object deserialization + * @param llmRequestEvent Optional event for tracking/observability + * @return Flux of typed objects, thinking content filtered out + */ + override fun doTransformObjectStream( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + llmRequestEvent: LlmRequestEvent?, + agentProcess: AgentProcess?, + action: Action?, + ): Flux { + return doTransformObjectStreamInternal( + messages = messages, + interaction = interaction, + outputClass = outputClass, + llmRequestEvent = llmRequestEvent, + agentProcess = agentProcess, + action = action, + ) + .filter { it.isObject() } + .map { (it as StreamingEvent.Object).item } + } + + /** + * Creates a mixed stream containing both LLM thinking content and typed objects. 
+ * + * This method returns the full unified stream without filtering, allowing users to receive + * both thinking events (LLM reasoning) and object events (parsed JSON data) in the order + * they appear in the LLM response. + * + * **Stream Characteristics:** + * - **Input**: Raw LLM chunks containing JSONL + thinking content + * - **Processing**: Chunks → Lines → Events (both thinking and objects preserved) + * - **Output**: `Flux>` with mixed content + * - **Event Types**: `StreamingEvent.Thinking(content)` and `StreamingEvent.Object(data)` + * - **Error Handling**: Malformed JSON treated as thinking content; stream continues + * + * **Example Usage:** + * ```kotlin + * val mixedStream: Flux> = doTransformObjectStreamWithThinking(...) + * + * mixedStream.subscribe { event -> + * when { + * event.isThinking() -> println("LLM thinking: ${event.getThinking()}") + * event.isObject() -> println("User object: ${event.getObject()}") + * } + * } + * ``` + * + * **User Filtering Options:** + * ```kotlin + * // Get only thinking content: + * val thinkingOnly = mixedStream.filter { it.isThinking() }.map { it.getThinking()!! } + * + * // Get only objects (equivalent to doTransformObjectStream): + * val objectsOnly = mixedStream.filter { it.isObject() }.map { it.getObject()!! 
} + * ``` + * + * @param messages The conversation messages to send to LLM + * @param interaction LLM configuration and context + * @param outputClass The target class for object deserialization + * @param llmRequestEvent Optional event for tracking/observability + * @return Flux of StreamingEvent containing both thinking and object events + */ + override fun doTransformObjectStreamWithThinking( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + llmRequestEvent: LlmRequestEvent?, + agentProcess: AgentProcess?, + action: Action?, + ): Flux> { + return doTransformObjectStreamInternal( + messages = messages, + interaction = interaction, + outputClass = outputClass, + llmRequestEvent = llmRequestEvent, + agentProcess = agentProcess, + action = action, + ) + } + + /** + * Internal unified streaming implementation - workhorse -that handles the complete transformation pipeline. + * + * This method implements a robust 3-step transformation pipeline: + * 1. **Raw LLM Chunks**: Receives arbitrary-sized chunks from LLM via Spring AI ChatClient + * 2. **Line Buffering**: Accumulates chunks into complete logical lines using stateful LineBuffer + * 3. 
**Event Generation**: Classifies lines as thinking vs objects, converts to StreamingEvent + * + * **Design Principles:** + * - **Single Source of Truth**: All streaming logic centralized here + * - **Error Isolation**: Malformed lines don't break the entire stream + * - **Order Preservation**: Events maintain LLM response order via concatMap + * - **Backpressure Support**: Full Flux lifecycle support with reactive operators + * + * **Event Types Generated:** + * - `StreamingEvent.Thinking(content)`: LLM reasoning text (from `` blocks or prefix thinking) + * - `StreamingEvent.Object(data)`: Parsed typed objects from JSONL content + * + * **Error Handling Strategy:** + * - Chunk processing errors: Skip chunk, continue stream + * - Line classification errors: Treat as thinking content + * - JSON parsing errors: Skip line, continue processing + * - Stream continues on individual failures to maximize data recovery + * + * **Performance Characteristics:** + * - Streaming-friendly: no blocking operations + * + * @return Unified Flux> that public methods can filter as needed + */ + private fun doTransformObjectStreamInternal( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + @Suppress("UNUSED_PARAMETER") + llmRequestEvent: LlmRequestEvent?, + agentProcess: AgentProcess?, + action: Action?, + ): Flux> { + // Common setup - delegate to ChatClientLlmOperations for LLM setup + val llm = getLlm(interaction) + // Chat Client + val chatClient = createChatClient(llm) + // Chat Options, additional potential option "streaming" + val chatOptions = requireSpringAiLlm(llm).optionsConverter.convertOptions(interaction.llm) + + val streamingConverter = StreamingJacksonOutputConverter( + clazz = outputClass, + objectMapper = objectMapper, + fieldFilter = interaction.fieldFilter + ) + + // Build prompt using helper methods, including streaming format instructions + val promptContributions = buildPromptContributions(interaction, llm) + val streamingFormatInstructions = 
streamingConverter.getFormat() + logger.debug("STREAMING FORMAT INSTRUCTIONS: $streamingFormatInstructions") + val fullPromptContributions = if (promptContributions.isNotEmpty()) { + "$promptContributions$PROMPT_ELEMENT_SEPARATOR$streamingFormatInstructions" + } else { + streamingFormatInstructions + } + val springAiPrompt = buildBasicPrompt(fullPromptContributions, messages) + + // Guardrails: Pre-validation of user input + val userMessages = messages.filterIsInstance() + validateUserInput(userMessages, interaction, llmRequestEvent?.agentProcess?.blackboard) + + // Resolve tool groups and decorate tools + val tools = resolveAndDecorateTools(interaction, agentProcess, action) + + // Step 1: Original raw chunk stream from LLM + val rawChunkFlux: Flux = createStreamInternal( + chatClient = chatClient, + messages = messages, + promptContributions = fullPromptContributions, + tools = tools, + chatOptions = chatOptions, + springAiPrompt = springAiPrompt, + ).filter { it.isNotEmpty() } + .doOnNext { chunk -> logger.trace("RAW CHUNK: '${chunk.replace("\n", "\\n")}'") } + + // Step 2: Transform raw chunks to complete newline-delimited lines + val lineFlux: Flux = rawChunkFlux + .transform { chunkFlux -> rawChunksToLines(chunkFlux) } + .doOnNext { line -> logger.trace("COMPLETE LINE: '$line'") } + + // Step 3: Final flux of StreamingEvent (thinking + objects) + val event = lineFlux + .concatMap { line -> streamingConverter.convertStreamWithThinking(line) } + + return event + } + + /** + * Convert raw streaming chunks → NDJSON lines + * Handles all general cases: + * - multiple \n in one chunk + * - no \n in chunk + * - line spanning many chunks + */ + fun rawChunksToLines(raw: Flux): Flux { + val buffer = StringBuilder() + return raw.concatMap { chunk -> // ONLY CHANGE: handle → concatMap + buffer.append(chunk) + val lines = mutableListOf() + while (true) { + val idx = buffer.indexOf('\n') + if (idx < 0) break + val line = buffer.substring(0, idx).trim() + if 
(line.isNotEmpty()) lines.add(line) + buffer.delete(0, idx + 1) + } + + Flux.fromIterable(lines) // emit multiple lines + + }.doOnComplete { + // Log any remaining buffer content when stream ends + if (buffer.isNotEmpty()) { + val finalLine = buffer.toString().trim() + if (finalLine.isNotEmpty()) { + logger.trace("FINAL LINE: '$finalLine'") + } + } + }.concatWith( + // final emit + Mono.fromSupplier { buffer.toString().trim() } + .filter { it.isNotEmpty() } + ) + } + + /* ------------------------------------------------------------------------- + * Streaming Abstraction Layer + * + * Supports decoupling streaming from Spring AI via LlmMessageStreamer interface. + * Controlled by useMessageStreamer flag: + * - false (default): uses Spring AI ChatClient directly + * - true: delegates to vendor-neutral LlmMessageStreamer + * + * Enables future support for non-Spring AI providers (e.g., LangChain4j). + * ------------------------------------------------------------------------ */ + + /** + * Build message list with prompt contributions prepended as system message. + * + * Mirrors [buildBasicPrompt] but returns Embabel messages instead of Spring AI Prompt. + * Used by the decoupled streaming path (when useMessageStreamer=true). + * + * @param messages Conversation messages + * @param promptContributions Prompt contributions to prepend + * @return Message list with contributions as first system message (if non-empty) + */ + private fun buildMessagesWithContributions( + messages: List, + promptContributions: String, + ): List = buildList { + if (promptContributions.isNotEmpty()) { + add(com.embabel.chat.SystemMessage(promptContributions)) + } + addAll(messages) + } + + /** + * Create raw content stream from LLM. + * + * Switches between decoupled path (LlmMessageStreamer) and current path (Spring AI direct) + * based on [useMessageStreamer] flag. 
+ * + * @param chatClient Spring AI ChatClient instance + * @param messages Embabel conversation messages + * @param promptContributions Prompt contributions string + * @param tools Embabel tools available for LLM + * @param chatOptions Spring AI chat options + * @param springAiPrompt Pre-built Spring AI prompt (used when useMessageStreamer=false) + * @return Flux of raw content chunks + */ + private fun createStreamInternal( + chatClient: org.springframework.ai.chat.client.ChatClient, + messages: List, + promptContributions: String, + tools: List, + chatOptions: org.springframework.ai.chat.prompt.ChatOptions, + springAiPrompt: Prompt, + ): Flux { + return if (useMessageStreamer) { + val streamerMessages = buildMessagesWithContributions(messages, promptContributions) + SpringAiLlmMessageStreamer(chatClient, chatOptions).stream(streamerMessages, tools) + } else { + chatClient + .prompt(springAiPrompt) + .toolCallbacks(tools.toSpringToolCallbacks()) + .options(chatOptions) + .stream() + .content() + } + } } /** diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/LlmOperationsIncludingStreaming.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/LlmOperationsIncludingStreaming.kt new file mode 100644 index 000000000..830199d88 --- /dev/null +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/LlmOperationsIncludingStreaming.kt @@ -0,0 +1,21 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.spi.support.springai + +import com.embabel.agent.core.internal.LlmOperations +import com.embabel.agent.spi.streaming.StreamingLlmOperations + +interface LlmOperationsIncludingStreaming : LlmOperations, StreamingLlmOperations diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperations.kt deleted file mode 100644 index 63c015e7c..000000000 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperations.kt +++ /dev/null @@ -1,487 +0,0 @@ -/* - * Copyright 2024-2026 Embabel Pty Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.embabel.agent.spi.support.springai.streaming - -import com.embabel.agent.api.event.LlmRequestEvent -import com.embabel.agent.core.Action -import com.embabel.agent.core.AgentProcess -import com.embabel.agent.core.support.LlmInteraction -import com.embabel.agent.spi.LlmService -import com.embabel.agent.spi.loop.streaming.LlmMessageStreamer -import com.embabel.agent.spi.streaming.StreamingLlmOperations -import com.embabel.agent.spi.support.PROMPT_ELEMENT_SEPARATOR -import com.embabel.agent.spi.support.guardrails.validateUserInput -import com.embabel.agent.spi.support.springai.ChatClientLlmOperations -import com.embabel.agent.spi.support.springai.SpringAiLlmService -import com.embabel.agent.spi.support.springai.toSpringAiMessage -import com.embabel.agent.spi.support.springai.toSpringToolCallbacks -import com.embabel.chat.Message -import com.embabel.common.ai.converters.streaming.StreamingJacksonOutputConverter -import com.embabel.common.core.streaming.StreamingEvent -import org.slf4j.LoggerFactory -import org.springframework.ai.chat.messages.SystemMessage -import org.springframework.ai.chat.prompt.Prompt -import reactor.core.publisher.Flux -import reactor.core.publisher.Mono - -/** - * Streaming implementation that provides real-time LLM response processing with unified event streams. - * - * Delegates to ChatClientLlmOperations for core LLM functionality while adding sophisticated - * streaming capabilities that handle chunk-to-line buffering, thinking content classification, - * and typed object parsing. 
- * - * **Core Capabilities:** - * - **Raw Text Streaming**: Direct access to LLM chunks as they arrive - * - **Typed Object Streaming**: Real-time JSONL parsing to typed objects - * - **Mixed Content Streaming**: Combined thinking + object events in unified stream - * - **Error Resilience**: Individual line failures don't break the stream - * - **Backpressure Support**: Full reactive streaming with lifecycle management - * - * **Unified Architecture:** - * All streaming methods are built on a single internal pipeline that emits `StreamingEvent`, - * allowing consistent behavior and the flexibility to filter events as needed by different use cases. - */ -internal class StreamingChatClientOperations( - private val chatClientLlmOperations: ChatClientLlmOperations, - /** - * When true, delegates raw streaming to [LlmMessageStreamer] instead of - * calling Spring AI ChatClient directly. This decouples streaming from - * Spring AI, enabling vendor-neutral implementations. - */ - private val useMessageStreamer: Boolean = false, -) : StreamingLlmOperations { - - // once streaming feature gets stable set log level to TRACE - private val logger = LoggerFactory.getLogger(StreamingChatClientOperations::class.java) - - /** - * Build prompt contributions string from interaction and LLM contributors. - * Consider helper - */ - private fun buildPromptContributions(interaction: LlmInteraction, llm: LlmService<*>): String { - return (interaction.promptContributors + llm.promptContributors) - .joinToString(PROMPT_ELEMENT_SEPARATOR) { it.contribution() } - } - - /** - * Build Spring AI Prompt from messages and contributions. 
- * Consider helper - */ - private fun buildSpringAiPrompt(messages: List, promptContributions: String): Prompt { - return Prompt( - buildList { - if (promptContributions.isNotEmpty()) { - add(SystemMessage(promptContributions)) - } - addAll(messages.map { it.toSpringAiMessage() }) - } - ) - } - - override fun generateStream( - messages: List, - interaction: LlmInteraction, - agentProcess: AgentProcess, - action: Action?, - ): Flux { - return doTransformStream(messages, interaction, null, agentProcess, action) - } - - override fun createObjectStream( - messages: List, - interaction: LlmInteraction, - outputClass: Class, - agentProcess: AgentProcess, - action: Action?, - ): Flux { - return doTransformObjectStream(messages, interaction, outputClass, null, agentProcess, action) - } - - override fun createObjectStreamWithThinking( - messages: List, - interaction: LlmInteraction, - outputClass: Class, - agentProcess: AgentProcess, - action: Action?, - ): Flux> { - return doTransformObjectStreamWithThinking(messages, interaction, outputClass, null, agentProcess, action) - } - - override fun createObjectStreamIfPossible( - messages: List, - interaction: LlmInteraction, - outputClass: Class, - agentProcess: AgentProcess, - action: Action?, - ): Flux> { - return createObjectStream(messages, interaction, outputClass, agentProcess, action) - .map { Result.success(it) } - .onErrorResume { throwable -> - Flux.just(Result.failure(throwable)) - } - } - - /** - * Require the LLM to be a SpringAiLlm for Spring AI specific operations. - */ - private fun requireSpringAiLlm(llm: LlmService<*>): SpringAiLlmService { - return llm as? 
SpringAiLlmService - ?: throw IllegalStateException("StreamingChatClientOperations requires SpringAiLlm, got ${llm::class.simpleName}") - } - - override fun doTransformStream( - messages: List, - interaction: LlmInteraction, - llmRequestEvent: LlmRequestEvent?, - agentProcess: AgentProcess?, - action: Action?, - ): Flux { - // Use ChatClientLlmOperations to get LLM and create ChatClient - val llm = chatClientLlmOperations.getLlm(interaction) - val chatClient = chatClientLlmOperations.createChatClient(llm) - - // Build prompt using helper methods - val promptContributions = buildPromptContributions(interaction, llm) - val springAiPrompt = buildSpringAiPrompt(messages, promptContributions) - - // Guardrails: Pre-validation of user input - val userMessages = messages.filterIsInstance() - validateUserInput(userMessages, interaction, llmRequestEvent?.agentProcess?.blackboard) - - val chatOptions = requireSpringAiLlm(llm).optionsConverter.convertOptions(interaction.llm) - - // Resolve tool groups and decorate tools - val tools = chatClientLlmOperations.resolveAndDecorateTools(interaction, agentProcess, action) - - return createStreamInternal( - chatClient = chatClient, - messages = messages, - promptContributions = promptContributions, - tools = tools, - chatOptions = chatOptions, - springAiPrompt = springAiPrompt, - ) - } - - /** - * Creates a stream of typed objects from LLM JSONL responses, with thinking content suppressed. - * - * This method provides a clean object-only stream by filtering the internal unified stream - * to exclude thinking content and extract only typed objects. 
- * - * **Stream Characteristics:** - * - **Input**: Raw LLM chunks containing JSONL + thinking content - * - **Processing**: Chunks → Lines → Events → Objects (thinking filtered out) - * - **Output**: `Flux` containing only parsed typed objects - * - **Error Handling**: Malformed JSON is skipped; stream continues - * - **Backpressure**: Supports standard Flux operators and subscription patterns - * - * **Example Usage:** - * ```kotlin - * val objectStream: Flux = doTransformObjectStream(messages, interaction, User::class.java, null) - * - * objectStream - * .doOnNext { user -> println("Received user: ${user.name}") } - * .doOnError { error -> logger.error("Stream error", error) } - * .doOnComplete { println("Stream completed") } - * .subscribe() - * ``` - * - * **Difference from doTransformObjectStreamWithThinking:** - * - This method: Returns `Flux` with only objects (thinking suppressed) - * - WithThinking: Returns `Flux>` with both thinking and objects - * - * @param messages The conversation messages to send to LLM - * @param interaction LLM configuration and context - * @param outputClass The target class for object deserialization - * @param llmRequestEvent Optional event for tracking/observability - * @return Flux of typed objects, thinking content filtered out - */ - override fun doTransformObjectStream( - messages: List, - interaction: LlmInteraction, - outputClass: Class, - llmRequestEvent: LlmRequestEvent?, - agentProcess: AgentProcess?, - action: Action?, - ): Flux { - return doTransformObjectStreamInternal( - messages = messages, - interaction = interaction, - outputClass = outputClass, - llmRequestEvent = llmRequestEvent, - agentProcess = agentProcess, - action = action, - ) - .filter { it.isObject() } - .map { (it as StreamingEvent.Object).item } - } - - /** - * Creates a mixed stream containing both LLM thinking content and typed objects. 
- * - * This method returns the full unified stream without filtering, allowing users to receive - * both thinking events (LLM reasoning) and object events (parsed JSON data) in the order - * they appear in the LLM response. - * - * **Stream Characteristics:** - * - **Input**: Raw LLM chunks containing JSONL + thinking content - * - **Processing**: Chunks → Lines → Events (both thinking and objects preserved) - * - **Output**: `Flux>` with mixed content - * - **Event Types**: `StreamingEvent.Thinking(content)` and `StreamingEvent.Object(data)` - * - **Error Handling**: Malformed JSON treated as thinking content; stream continues - * - * **Example Usage:** - * ```kotlin - * val mixedStream: Flux> = doTransformObjectStreamWithThinking(...) - * - * mixedStream.subscribe { event -> - * when { - * event.isThinking() -> println("LLM thinking: ${event.getThinking()}") - * event.isObject() -> println("User object: ${event.getObject()}") - * } - * } - * ``` - * - * **User Filtering Options:** - * ```kotlin - * // Get only thinking content: - * val thinkingOnly = mixedStream.filter { it.isThinking() }.map { it.getThinking()!! } - * - * // Get only objects (equivalent to doTransformObjectStream): - * val objectsOnly = mixedStream.filter { it.isObject() }.map { it.getObject()!! 
} - * ``` - * - * @param messages The conversation messages to send to LLM - * @param interaction LLM configuration and context - * @param outputClass The target class for object deserialization - * @param llmRequestEvent Optional event for tracking/observability - * @return Flux of StreamingEvent containing both thinking and object events - */ - override fun doTransformObjectStreamWithThinking( - messages: List, - interaction: LlmInteraction, - outputClass: Class, - llmRequestEvent: LlmRequestEvent?, - agentProcess: AgentProcess?, - action: Action?, - ): Flux> { - return doTransformObjectStreamInternal( - messages = messages, - interaction = interaction, - outputClass = outputClass, - llmRequestEvent = llmRequestEvent, - agentProcess = agentProcess, - action = action, - ) - } - - /** - * Internal unified streaming implementation - workhorse -that handles the complete transformation pipeline. - * - * This method implements a robust 3-step transformation pipeline: - * 1. **Raw LLM Chunks**: Receives arbitrary-sized chunks from LLM via Spring AI ChatClient - * 2. **Line Buffering**: Accumulates chunks into complete logical lines using stateful LineBuffer - * 3. 
**Event Generation**: Classifies lines as thinking vs objects, converts to StreamingEvent - * - * **Design Principles:** - * - **Single Source of Truth**: All streaming logic centralized here - * - **Error Isolation**: Malformed lines don't break the entire stream - * - **Order Preservation**: Events maintain LLM response order via concatMap - * - **Backpressure Support**: Full Flux lifecycle support with reactive operators - * - * **Event Types Generated:** - * - `StreamingEvent.Thinking(content)`: LLM reasoning text (from `` blocks or prefix thinking) - * - `StreamingEvent.Object(data)`: Parsed typed objects from JSONL content - * - * **Error Handling Strategy:** - * - Chunk processing errors: Skip chunk, continue stream - * - Line classification errors: Treat as thinking content - * - JSON parsing errors: Skip line, continue processing - * - Stream continues on individual failures to maximize data recovery - * - * **Performance Characteristics:** - * - Streaming-friendly: no blocking operations - * - * @return Unified Flux> that public methods can filter as needed - */ - private fun doTransformObjectStreamInternal( - messages: List, - interaction: LlmInteraction, - outputClass: Class, - @Suppress("UNUSED_PARAMETER") - llmRequestEvent: LlmRequestEvent?, - agentProcess: AgentProcess?, - action: Action?, - ): Flux> { - // Common setup - delegate to ChatClientLlmOperations for LLM setup - val llm = chatClientLlmOperations.getLlm(interaction) - // Chat Client - val chatClient = chatClientLlmOperations.createChatClient(llm) - // Chat Options, additional potential option "streaming" - val chatOptions = requireSpringAiLlm(llm).optionsConverter.convertOptions(interaction.llm) - - val streamingConverter = StreamingJacksonOutputConverter( - clazz = outputClass, - objectMapper = chatClientLlmOperations.objectMapper, - fieldFilter = interaction.fieldFilter - ) - - // Build prompt using helper methods, including streaming format instructions - val promptContributions = 
buildPromptContributions(interaction, llm) - val streamingFormatInstructions = streamingConverter.getFormat() - logger.debug("STREAMING FORMAT INSTRUCTIONS: $streamingFormatInstructions") - val fullPromptContributions = if (promptContributions.isNotEmpty()) { - "$promptContributions$PROMPT_ELEMENT_SEPARATOR$streamingFormatInstructions" - } else { - streamingFormatInstructions - } - val springAiPrompt = buildSpringAiPrompt(messages, fullPromptContributions) - - // Guardrails: Pre-validation of user input - val userMessages = messages.filterIsInstance() - validateUserInput(userMessages, interaction, llmRequestEvent?.agentProcess?.blackboard) - - // Resolve tool groups and decorate tools - val tools = chatClientLlmOperations.resolveAndDecorateTools(interaction, agentProcess, action) - - // Step 1: Original raw chunk stream from LLM - val rawChunkFlux: Flux = createStreamInternal( - chatClient = chatClient, - messages = messages, - promptContributions = fullPromptContributions, - tools = tools, - chatOptions = chatOptions, - springAiPrompt = springAiPrompt, - ).filter { it.isNotEmpty() } - .doOnNext { chunk -> logger.trace("RAW CHUNK: '${chunk.replace("\n", "\\n")}'") } - - // Step 2: Transform raw chunks to complete newline-delimited lines - val lineFlux: Flux = rawChunkFlux - .transform { chunkFlux -> rawChunksToLines(chunkFlux) } - .doOnNext { line -> logger.trace("COMPLETE LINE: '$line'") } - - // Step 3: Final flux of StreamingEvent (thinking + objects) - val event = lineFlux - .concatMap { line -> streamingConverter.convertStreamWithThinking(line) } - - return event - } - - /** - * Convert raw streaming chunks → NDJSON lines - * Handles all general cases: - * - multiple \n in one chunk - * - no \n in chunk - * - line spanning many chunks - */ - fun rawChunksToLines(raw: Flux): Flux { - val buffer = StringBuilder() - return raw.concatMap { chunk -> // ONLY CHANGE: handle → concatMap - buffer.append(chunk) - val lines = mutableListOf() - while (true) { - val idx = 
buffer.indexOf('\n') - if (idx < 0) break - val line = buffer.substring(0, idx).trim() - if (line.isNotEmpty()) lines.add(line) - buffer.delete(0, idx + 1) - } - - Flux.fromIterable(lines) // emit multiple lines - - }.doOnComplete { - // Log any remaining buffer content when stream ends - if (buffer.isNotEmpty()) { - val finalLine = buffer.toString().trim() - if (finalLine.isNotEmpty()) { - logger.trace("FINAL LINE: '$finalLine'") - } - } - }.concatWith( - // final emit - Mono.fromSupplier { buffer.toString().trim() } - .filter { it.isNotEmpty() } - ) - } - - /* ------------------------------------------------------------------------- - * Streaming Abstraction Layer - * - * Supports decoupling streaming from Spring AI via LlmMessageStreamer interface. - * Controlled by useMessageStreamer flag: - * - false (default): uses Spring AI ChatClient directly - * - true: delegates to vendor-neutral LlmMessageStreamer - * - * Enables future support for non-Spring AI providers (e.g., LangChain4j). - * ------------------------------------------------------------------------ */ - - /** - * Build message list with prompt contributions prepended as system message. - * - * Mirrors [buildSpringAiPrompt] but returns Embabel messages instead of Spring AI Prompt. - * Used by the decoupled streaming path (when useMessageStreamer=true). - * - * @param messages Conversation messages - * @param promptContributions Prompt contributions to prepend - * @return Message list with contributions as first system message (if non-empty) - */ - private fun buildMessagesWithContributions( - messages: List, - promptContributions: String, - ): List = buildList { - if (promptContributions.isNotEmpty()) { - add(com.embabel.chat.SystemMessage(promptContributions)) - } - addAll(messages) - } - - /** - * Create raw content stream from LLM. - * - * Switches between decoupled path (LlmMessageStreamer) and current path (Spring AI direct) - * based on [useMessageStreamer] flag. 
- * - * @param chatClient Spring AI ChatClient instance - * @param messages Embabel conversation messages - * @param promptContributions Prompt contributions string - * @param tools Embabel tools available for LLM - * @param chatOptions Spring AI chat options - * @param springAiPrompt Pre-built Spring AI prompt (used when useMessageStreamer=false) - * @return Flux of raw content chunks - */ - private fun createStreamInternal( - chatClient: org.springframework.ai.chat.client.ChatClient, - messages: List, - promptContributions: String, - tools: List, - chatOptions: org.springframework.ai.chat.prompt.ChatOptions, - springAiPrompt: Prompt, - ): Flux { - return if (useMessageStreamer) { - val streamerMessages = buildMessagesWithContributions(messages, promptContributions) - SpringAiLlmMessageStreamer(chatClient, chatOptions).stream(streamerMessages, tools) - } else { - chatClient - .prompt(springAiPrompt) - .toolCallbacks(tools.toSpringToolCallbacks()) - .options(chatOptions) - .stream() - .content() - } - } -} diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerStreamingTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerStreamingTest.kt index ccd76ad32..68ad560ae 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerStreamingTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerStreamingTest.kt @@ -20,10 +20,8 @@ import com.embabel.agent.core.AgentPlatform import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.Operation import com.embabel.agent.core.internal.LlmOperations -import com.embabel.agent.core.support.LlmInteraction import com.embabel.agent.spi.streaming.StreamingLlmOperations -import com.embabel.agent.spi.support.springai.ChatClientLlmOperations -import com.embabel.agent.spi.support.springai.SpringAiLlmService +import 
com.embabel.agent.spi.support.springai.LlmOperationsIncludingStreaming import com.embabel.agent.test.integration.DummyObjectCreatingLlmOperations import com.embabel.chat.UserMessage import com.embabel.common.ai.model.LlmOptions @@ -31,13 +29,6 @@ import io.mockk.every import io.mockk.mockk import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows -import org.springframework.ai.chat.messages.AssistantMessage -import org.springframework.ai.chat.model.ChatModel -import org.springframework.ai.chat.model.ChatResponse -import org.springframework.ai.chat.model.Generation -import org.springframework.ai.chat.model.StreamingChatModel -import org.springframework.ai.chat.prompt.Prompt -import reactor.core.publisher.Flux import kotlin.test.assertFalse import kotlin.test.assertNotNull @@ -56,22 +47,12 @@ class OperationContextPromptRunnerStreamingTest { @Test fun `should create streaming operations when llmOperations is ChatClientLlmOperations`() { - // Given: Mock with ChatClientLlmOperations that implements StreamingLlmOperations AND has StreamingChatModel - val mockStreamingChatModel = mockk(moreInterfaces = arrayOf(StreamingChatModel::class)) { - every { stream(any()) } returns Flux.just( - ChatResponse( - listOf(Generation(AssistantMessage("streaming response"))) - ) - ) - } - - val mockLlm = mockk { - every { chatModel } returns mockStreamingChatModel - } + // Given: Mock with ChatClientLlmOperations that implements StreamingLlmOperations AND has model supporting streaming + val llmOptions = LlmOptions.withModel("test-model-supporting-streaming") - val mockChatClientLlmOperations = - mockk(moreInterfaces = arrayOf(StreamingLlmOperations::class), relaxed = true) { - every { getLlm(any()) } returns mockLlm + val mockLlmOperationsIncludingStreaming = + mockk(moreInterfaces = arrayOf(StreamingLlmOperations::class), relaxed = true) { + every { supportsStreaming(llmOptions) } returns true } val mockAgentPlatform = mockk() @@ -85,7 +66,7 @@ class 
OperationContextPromptRunnerStreamingTest { } val mockPlatformServices = mockk { - every { llmOperations } returns mockChatClientLlmOperations // This enables all capability levels + every { llmOperations } returns mockLlmOperationsIncludingStreaming // This enables all capability levels } every { mockAgentPlatform.platformServices } returns mockPlatformServices @@ -94,7 +75,7 @@ class OperationContextPromptRunnerStreamingTest { // Create with real value objects where possible val promptRunner = OperationContextPromptRunner( context = mockOperationContext, - llm = LlmOptions.withModel("test-model"), + llm = llmOptions, messages = listOf(UserMessage("Test message")), toolGroups = emptySet(), toolObjects = emptyList(), diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt index 78127602e..c5fe61788 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt @@ -19,7 +19,6 @@ import com.embabel.agent.api.common.* import com.embabel.agent.api.common.nested.support.PromptRunnerCreating import com.embabel.agent.api.common.nested.support.PromptRunnerRendering import com.embabel.agent.api.common.streaming.StreamingPromptRunner -import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector import com.embabel.agent.api.common.support.streaming.StreamingImpl import com.embabel.agent.api.common.thinking.support.ThinkingPromptRunnerOperationsImpl import com.embabel.agent.api.tool.Tool @@ -35,8 +34,8 @@ import com.embabel.agent.core.support.LlmInteraction import com.embabel.agent.core.support.safelyGetTools import com.embabel.agent.experimental.primitive.Determination import com.embabel.agent.spi.loop.ToolNotFoundPolicy +import 
com.embabel.agent.spi.streaming.StreamingLlmOperations import com.embabel.agent.spi.support.springai.ChatClientLlmOperations -import com.embabel.agent.spi.support.springai.streaming.StreamingChatClientOperations import com.embabel.chat.ImagePart import com.embabel.chat.Message import com.embabel.chat.UserMessage @@ -275,14 +274,16 @@ internal data class OperationContextPromptRunner( /** * Check if streaming is supported by the underlying LLM model. * Performs three-level capability detection: - * 1. Must be ChatClientLlmOperations for Spring AI integration + * 1. Must implement StreamingLlmOperations * 2. Must have StreamingChatModel */ override fun supportsStreaming(): Boolean { val llmOperations = context.agentPlatform().platformServices.llmOperations + // Level 1 sanity check + if (llmOperations !is StreamingLlmOperations) return false - return StreamingCapabilityDetector.supportsStreaming(llmOperations, this.llm) + return llmOperations.supportsStreaming(this.llm) } override fun streaming(): StreamingPromptRunner.Streaming { @@ -297,9 +298,7 @@ internal data class OperationContextPromptRunner( } return StreamingImpl( - streamingLlmOperations = StreamingChatClientOperations( - context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations - ), + streamingLlmOperations = context.agentPlatform().platformServices.llmOperations as StreamingLlmOperations, interaction = LlmInteraction( llm = llm, toolGroups = toolGroups, diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/e2e/LLMStreamingIntegrationTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/e2e/LLMStreamingIntegrationTest.kt index b700128f7..6bfb20365 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/e2e/LLMStreamingIntegrationTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/e2e/LLMStreamingIntegrationTest.kt @@ -21,7 +21,7 @@ import com.embabel.agent.api.common.Ai import com.embabel.agent.api.common.autonomy.Autonomy import 
com.embabel.agent.api.common.streaming.StreamingPromptRunner import com.embabel.agent.api.common.streaming.asStreaming -import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector +import com.embabel.agent.spi.streaming.StreamingCapabilityDetector import com.embabel.agent.core.AgentPlatform import com.embabel.agent.core.internal.LlmOperations import com.embabel.agent.spi.LlmService diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsTest.kt index c1f9a9094..fdb860b2c 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/ChatClientLlmOperationsTest.kt @@ -27,8 +27,8 @@ import com.embabel.agent.core.support.InvalidLlmReturnFormatException import com.embabel.agent.core.support.InvalidLlmReturnTypeException import com.embabel.agent.core.support.LlmInteraction import com.embabel.agent.core.support.safelyGetToolsFrom +import com.embabel.agent.spi.streaming.StreamingLlmOperations import com.embabel.agent.spi.support.springai.ChatClientLlmOperations -import com.embabel.agent.spi.support.MaybeReturn import com.embabel.agent.spi.support.springai.SpringAiLlmService import com.embabel.agent.spi.validation.DefaultValidationPromptGenerator import com.embabel.agent.support.SimpleTestAgent @@ -48,8 +48,10 @@ import io.mockk.slot import jakarta.validation.Validation import jakarta.validation.constraints.Pattern import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test +import org.slf4j.LoggerFactory import org.springframework.ai.chat.messages.AssistantMessage import org.springframework.ai.chat.model.ChatModel import org.springframework.ai.chat.model.ChatResponse @@ -58,6 +60,9 @@ import 
org.springframework.ai.chat.prompt.ChatOptions import org.springframework.ai.chat.prompt.DefaultChatOptions import org.springframework.ai.chat.prompt.Prompt import org.springframework.ai.model.tool.ToolCallingChatOptions +import reactor.core.publisher.Flux +import reactor.test.StepVerifier +import java.time.Duration import java.time.LocalDate import java.util.concurrent.Executors import java.util.function.Predicate @@ -104,13 +109,31 @@ class FakeChatModel( ) ) } + + override fun stream(prompt: Prompt): Flux { + promptsPassed.add(prompt) + val options = prompt.options as? ToolCallingChatOptions + ?: throw IllegalArgumentException("Expected ToolCallingChatOptions") + optionsPassed.add(options) + return Flux.fromIterable(responses) + .map { response -> + ChatResponse( + listOf( + Generation(AssistantMessage(response)) + ) + ) + } + } } class ChatClientLlmOperationsTest { + private val logger = LoggerFactory.getLogger(ChatClientLlmOperationsTest::class.java) + data class Setup( val llmOperations: LlmOperations, + val streamingLlmOperations: StreamingLlmOperations, val mockAgentProcess: AgentProcess, val mutableLlmInvocationHistory: MutableLlmInvocationHistory, ) @@ -118,6 +141,7 @@ class ChatClientLlmOperationsTest { private fun createChatClientLlmOperations( fakeChatModel: FakeChatModel, dataBindingProperties: LlmDataBindingProperties = LlmDataBindingProperties(), + useMessageStreamer: Boolean = false, ): Setup { val ese = EventSavingAgenticEventListener() val mutableLlmInvocationHistory = MutableLlmInvocationHistory() @@ -157,8 +181,9 @@ class ChatClientLlmOperationsTest { objectMapper = jacksonObjectMapper().registerModule(JavaTimeModule()), dataBindingProperties = dataBindingProperties, asyncer = ExecutorAsyncer(Executors.newCachedThreadPool()), + useMessageStreamer = useMessageStreamer ) - return Setup(cco, mockAgentProcess, mutableLlmInvocationHistory) + return Setup(cco, cco, mockAgentProcess, mutableLlmInvocationHistory) } data class Dog(val name: String) @@ 
-408,6 +433,398 @@ class ChatClientLlmOperationsTest { } } + @Nested + inner class CreateObjectStream { + + @Test + fun `passes correct prompt`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel(jacksonObjectMapper().writeValueAsString(duke)) + + val prompt = + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + val setup = createChatClientLlmOperations(fakeChatModel) + setup.streamingLlmOperations.createObjectStream( + messages = listOf(UserMessage(prompt)), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + + val promptText = fakeChatModel.promptsPassed[0].toString() + assertTrue(promptText.contains("\$schema"), "Prompt contains JSON schema") + assertTrue(promptText.contains(promptText), "Prompt contains user prompt:\n$promptText") + } + + @Test + fun `handles ill formed JSON when returning data class`() { + val fakeChatModel = FakeChatModel("This ain't no JSON") + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStream( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ) + + StepVerifier.create(result) + .verifyComplete() // No data returned + } + + @Test + fun `returns data class`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel(jacksonObjectMapper().writeValueAsString(duke)) + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStream( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + 
outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(duke, result) + } + + @Test + fun `passes JSON few shot example`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel(jacksonObjectMapper().writeValueAsString(duke)) + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStream( + messages = listOf( + UserMessage( + """ + Return a dog. Dogs look like this: + { + "name": "Duke", + "type": "Dog" + } + """.trimIndent() + ) + ), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(duke, result) + } + + @Test + fun `presents no tools to ChatModel`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel(jacksonObjectMapper().writeValueAsString(duke)) + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStream( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(duke, result) + assertEquals(1, fakeChatModel.promptsPassed.size) + val tools = fakeChatModel.optionsPassed[0].toolCallbacks + assertEquals(0, tools.size) + } + + @Test + fun `presents tools to ChatModel via doTransform`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel(jacksonObjectMapper().writeValueAsString(duke)) + + // Wumpus's have tools + val tools = safelyGetToolsFrom(ToolObject(Wumpus("wumpy"))) + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.doTransformObjectStream( + messages = listOf( + 
SystemMessage("do whatever"), + UserMessage("prompt"), + ), + interaction = LlmInteraction( + id = InteractionId("id"), + llm = LlmOptions(), + tools = tools, + ), + outputClass = Dog::class.java, + llmRequestEvent = null, + ).blockLast() + assertEquals(duke, result) + assertEquals(1, fakeChatModel.promptsPassed.size) + val passedTools = fakeChatModel.optionsPassed[0].toolCallbacks + assertEquals(tools.size, passedTools.size, "Must have passed same number of tools") + assertEquals( + tools.map { it.definition.name }.toSet(), + passedTools.map { it.toolDefinition.name() }.toSet(), + ) + } + + @Test + fun `presents tools to ChatModel when given multiple messages`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel(jacksonObjectMapper().writeValueAsString(duke)) + + // Wumpus's have tools - use native Tool interface + val tools = safelyGetToolsFrom(ToolObject(Wumpus("wumpy"))) + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStream( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), + llm = LlmOptions(), + tools = tools, + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(duke, result) + assertEquals(1, fakeChatModel.promptsPassed.size) + val passedTools = fakeChatModel.optionsPassed[0].toolCallbacks + assertEquals(tools.size, passedTools.size, "Must have passed same number of tools") + assertEquals( + tools.map { it.definition.name }.sorted(), + passedTools.map { it.toolDefinition.name() }) + } + + @Test + fun `handles reasoning model return`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel( + "Deep thoughts\n" + jacksonObjectMapper().writeValueAsString(duke) + ) + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStream( + messages = 
listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(duke, result) + } + + @Test + fun `handles LocalDate return`() { + val duke = TemporalDog("Duke", birthDate = LocalDate.of(2021, 2, 26)) + + val fakeChatModel = FakeChatModel( + jacksonObjectMapper().registerModule(JavaTimeModule()).writeValueAsString(duke) + ) + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStream( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = TemporalDog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(duke, result) + } + } + + data class TestItem(val name: String, val value: Int) + + @Nested + inner class CreateObjectStreamWithThinking { + + @Test + fun `should handle single complete chunk`() { + val fakeChatModel = FakeChatModel("This is thinking content\n") + val setup = createChatClientLlmOperations(fakeChatModel) + + // When + val result = setup.streamingLlmOperations.createObjectStreamWithThinking( + messages = listOf(UserMessage("test")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = TestItem::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess + ) + + + // Then: Should emit one thinking event for the complete line + StepVerifier.create(result) + .expectNextMatches { + it.isThinking() && it.getThinking() == "This is thinking content" + } + .expectComplete() + .verify(Duration.ofSeconds(1)) + } + + @Test + fun `should handle multi-chunk JSONL object stream`() { + // Given: Multiple chunks forming JSONL objects + val chunks = listOf( + 
"{\"name\":\"Item1\",\"value\":", + "42}\n{\"name\":\"Item2\",", + "\"value\":84}\n" + ) + val fakeChatModel = FakeChatModel(chunks) + val setup = createChatClientLlmOperations(fakeChatModel) + + // When + val result = setup.streamingLlmOperations.createObjectStreamWithThinking( + messages = listOf(UserMessage("test")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = TestItem::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess + ) + + // Then: Should emit two object events + StepVerifier.create(result) + .expectNextMatches { + it.isObject() && it.getObject()?.name == "Item1" && it.getObject()?.value == 42 + } + .expectNextMatches { + it.isObject() && it.getObject()?.name == "Item2" && it.getObject()?.value == 84 + } + .expectComplete() + .verify(Duration.ofSeconds(1)) + } + + @Test + fun `should handle mixed thinking and object content in chunks`() { + // Given: Realistic chunking that splits thinking and JSON across chunk boundaries + val chunks = listOf( + "Ana", // Partial thinking start + "lyzing req", // Partial thinking middle + "uirement\n{\"name\":", // Thinking end + partial JSON + "\"TestItem\",\"va", // Partial JSON middle + "lue\":123}\nDone", // JSON end + partial thinking + "\n" // Thinking end + ) + val fakeChatModel = FakeChatModel(chunks) + val setup = createChatClientLlmOperations(fakeChatModel) + + // When + val result = setup.streamingLlmOperations.createObjectStreamWithThinking( + messages = listOf(UserMessage("test")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = TestItem::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess + ) + + // Then: Should emit thinking, object, thinking in correct order + StepVerifier.create(result) + .expectNextMatches { + it.isThinking() && it.getThinking() == "Analyzing requirement" + } + .expectNextMatches { + 
it.isObject() && it.getObject()?.name == "TestItem" && it.getObject()?.value == 123 + } + .expectNextMatches { + it.isThinking() && it.getThinking() == "Done" + } + .expectComplete() + .verify(Duration.ofSeconds(1)) + } + + @Test + fun `should handle real streaming with reactive callbacks`() { + // Given: Mixed content with multiple events + val chunks = listOf( + "Processing request\n", + "{\"name\":\"Item1\",\"value\":100}\n", + "{\"name\":\"Item2\",\"value\":200}\n", + "Request completed\n" + ) + val fakeChatModel = FakeChatModel(chunks) + val setup = createChatClientLlmOperations(fakeChatModel) + + // When: Subscribe with real reactive callbacks + val receivedEvents = mutableListOf() + var errorOccurred: Throwable? = null + var completionCalled = false + + val result = setup.streamingLlmOperations.createObjectStreamWithThinking( + messages = listOf(UserMessage("test")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = TestItem::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess + ) + + result + .doOnNext { event -> + when { + event.isThinking() -> { + val content = event.getThinking()!! + receivedEvents.add("THINKING: $content") + logger.info("Received thinking: {}", content) + } + + event.isObject() -> { + val obj = event.getObject()!! 
+ receivedEvents.add("OBJECT: ${obj.name}=${obj.value}") + logger.info("Received object: {}={}", obj.name, obj.value) + } + } + } + .doOnError { error -> + errorOccurred = error + logger.error("Stream error: {}", error.message) + } + .doOnComplete { + completionCalled = true + logger.info("Stream completed successfully") + } + .subscribe() + + // Give stream time to complete + Thread.sleep(500) + + // Then: Verify real reactive behavior + assertNull(errorOccurred, "No errors should occur") + assertTrue(completionCalled, "Stream should complete successfully") + assertEquals(4, receivedEvents.size, "Should receive all events") + assertEquals("THINKING: Processing request", receivedEvents[0]) + assertEquals("OBJECT: Item1=100", receivedEvents[1]) + assertEquals("OBJECT: Item2=200", receivedEvents[2]) + assertEquals("THINKING: Request completed", receivedEvents[3]) + } + } + @Nested inner class CreateObjectIfPossible { @@ -595,6 +1012,202 @@ class ChatClientLlmOperationsTest { } } + @Nested + inner class CreateObjectStreamIfPossible { + + @Test + @Disabled("createObjectStreamIfPossible does not have an implemenation with a specific prompt yet") + fun `should have correct prompt with success and failure`() { + val fakeChatModel = + FakeChatModel( + jacksonObjectMapper().writeValueAsString( + MaybeReturn( + failure = "didn't work" + ) + ) + ) + + val prompt = "The quick brown fox jumped over the lazy dog" + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStreamIfPossible( + messages = listOf(UserMessage(prompt)), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertTrue(result!!.isFailure) + val promptText = fakeChatModel.promptsPassed[0].toString() + assertTrue(promptText.contains("\$schema"), "Prompt contains JSON schema") + 
assertTrue(promptText.contains(prompt), "Prompt contains user prompt:\n$promptText") + + assertTrue(promptText.contains("possible"), "Prompt mentions possible") + assertTrue(promptText.contains("success"), "Prompt mentions success") + assertTrue(promptText.contains("failure"), "Prompt mentions failure") + } + + @Test + fun `returns data class - success`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel( + jacksonObjectMapper().writeValueAsString(duke) + ) + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStreamIfPossible( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(duke, result!!.getOrThrow()) + } + + @Test + fun `handles reasoning model success return`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel( + "More deep thoughts\n" + jacksonObjectMapper().writeValueAsString(duke) + ) + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStreamIfPossible( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(duke, result!!.getOrThrow()) + } + + @Test + fun `handles LocalDate return`() { + val duke = TemporalDog("Duke", birthDate = LocalDate.of(2021, 2, 26)) + + val fakeChatModel = FakeChatModel( + jacksonObjectMapper().registerModule(JavaTimeModule()).writeValueAsString(duke) + ) + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStreamIfPossible( + messages = listOf(UserMessage("prompt")), + 
interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = TemporalDog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(duke, result!!.getOrThrow()) + } + + @Test + fun `handles ill formed JSON when returning data class`() { + val fakeChatModel = FakeChatModel("This ain't no JSON") + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStreamIfPossible( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ) + + StepVerifier.create(result) + .verifyComplete() // No data returned + } + + @Test + fun `returns data class - failure`() { + val fakeChatModel = + FakeChatModel( + jacksonObjectMapper().writeValueAsString( + MaybeReturn( + failure = "didn't work" + ) + ) + ) + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.createObjectStreamIfPossible( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertTrue(result!!.isFailure) + } + + @Test + fun `presents tools to ChatModel`() { + val duke = Dog("Duke") + + val fakeChatModel = FakeChatModel( + jacksonObjectMapper().writeValueAsString( + MaybeReturn(duke) + ) + ) + + // Wumpus's have tools - use native Tool interface + val tools = safelyGetToolsFrom(ToolObject(Wumpus("wumpy"))) + val setup = createChatClientLlmOperations(fakeChatModel) + setup.streamingLlmOperations.createObjectStreamIfPossible( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = 
InteractionId("id"), + llm = LlmOptions(), + tools = tools, + ), + outputClass = Dog::class.java, + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + assertEquals(1, fakeChatModel.promptsPassed.size) + val passedTools = fakeChatModel.optionsPassed[0].toolCallbacks + assertEquals(tools.size, passedTools.size, "Must have passed same number of tools") + assertEquals( + tools.map { it.definition.name }.sorted(), + passedTools.map { it.toolDefinition.name() }) + } + } + + @Nested + inner class GenerateStream { + + @Test + fun `returns string`() { + val fakeChatModel = FakeChatModel("fake response") + + val setup = createChatClientLlmOperations(fakeChatModel) + val result = setup.streamingLlmOperations.generateStream( + messages = listOf(UserMessage("prompt")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ).blockLast() + + assertEquals(fakeChatModel.response, result) + } + } + @Nested inner class TimeoutBehavior { @@ -697,7 +1310,7 @@ class ChatClientLlmOperationsTest { llmOperationsPromptsProperties = promptsProperties, asyncer = ExecutorAsyncer(Executors.newCachedThreadPool()), ) - return Setup(cco, mockAgentProcess, mutableLlmInvocationHistory) + return Setup(cco, cco, mockAgentProcess, mutableLlmInvocationHistory) } } @@ -1326,4 +1939,144 @@ class ChatClientLlmOperationsTest { } } + /** + * Test of lower level internal implementation + */ + @Nested + inner class StreamedChunks { + + @Test + fun `rawChunksToLines should handle single line chunks`() { + val chunks = Flux.just("line1\n", "line2\n") + + val setup = createChatClientLlmOperations(FakeChatModel("fake")) + val result = (setup.streamingLlmOperations as ChatClientLlmOperations).rawChunksToLines(chunks) + + StepVerifier.create(result) + .expectNext("line1") + .expectNext("line2") + .verifyComplete() + } + + @Test + fun `rawChunksToLines 
should handle multi-line chunks from Anthropic`() { + val chunks = Flux.just(".\n\n\n{\"") + + val setup = createChatClientLlmOperations(FakeChatModel("fake")) + val result = (setup.streamingLlmOperations as ChatClientLlmOperations).rawChunksToLines(chunks) + + StepVerifier.create(result) + .expectNext(".") + .expectNext("") + .expectNext("{\"") + .verifyComplete() + } + + @Test + fun `rawChunksToLines should handle incomplete lines across chunks`() { + val chunks = Flux.just("partial", " line\n", "complete\n") + + val setup = createChatClientLlmOperations(FakeChatModel("fake")) + val result = (setup.streamingLlmOperations as ChatClientLlmOperations).rawChunksToLines(chunks) + + StepVerifier.create(result) + .expectNext("partial line") + .expectNext("complete") + .verifyComplete() + } + + @Test + fun `rawChunksToLines should emit final incomplete line`() { + val chunks = Flux.just("line1\n", "incomplete") + + val setup = createChatClientLlmOperations(FakeChatModel("fake")) + val result = (setup.streamingLlmOperations as ChatClientLlmOperations).rawChunksToLines(chunks) + + StepVerifier.create(result) + .expectNext("line1") + .expectNext("incomplete") + .verifyComplete() + } + } + + /** + * Tests for useMessageStreamer=true (decoupled streaming path via LlmMessageStreamer). 
+ */ + @Nested + inner class MessageStreamerTests { + + @Test + fun `should use LlmMessageStreamer when useMessageStreamer is true`() { + // Given + val fakeChatModel = FakeChatModel(listOf("streamed ", "content")) + + val setup = createChatClientLlmOperations(fakeChatModel, useMessageStreamer = true) + + // When + val result = setup.streamingLlmOperations.generateStream( + listOf(UserMessage("test")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ) + + // Then + StepVerifier.create(result) + .expectNext("streamed ") + .expectNext("content") + .verifyComplete() + } + + @Test + fun `should prepend prompt contributions as system message`() { + // Given + val fakeChatModel = FakeChatModel(listOf("response")) + + val setup = createChatClientLlmOperations(fakeChatModel, useMessageStreamer = true) + + // When + val result = setup.streamingLlmOperations.generateStream( + listOf(UserMessage("user message")), + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ) + + // Then + StepVerifier.create(result) + .expectNext("response") + .verifyComplete() + } + + @Test + fun `should handle object streaming`() { + // Given + val fakeChatModel = FakeChatModel(listOf("{\"name\":\"Test\",\"value\":42}\n")) + + val setup = createChatClientLlmOperations(fakeChatModel, useMessageStreamer = true) + + // When + val result = setup.streamingLlmOperations.createObjectStreamWithThinking( + messages = listOf(UserMessage("test")), + outputClass = TestItem::class.java, + interaction = LlmInteraction( + id = InteractionId("id"), llm = LlmOptions() + ), + action = SimpleTestAgent.actions.first(), + agentProcess = setup.mockAgentProcess, + ) + + // Then + StepVerifier.create(result) + .expectNextMatches { + it.isObject() && it.getObject()?.name == "Test" && 
it.getObject()?.value == 42 + } + .expectComplete() + .verify(Duration.ofSeconds(1)) + } + } } diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsGuardRailTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/ChatClientStreamingOperationsGuardRailTest.kt similarity index 97% rename from embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsGuardRailTest.kt rename to embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/ChatClientStreamingOperationsGuardRailTest.kt index b2811b723..0ce86f0f5 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsGuardRailTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/ChatClientStreamingOperationsGuardRailTest.kt @@ -26,6 +26,7 @@ import com.embabel.agent.core.LlmInvocation import com.embabel.agent.core.LlmInvocationHistory import com.embabel.agent.core.ProcessContext import com.embabel.agent.core.support.LlmInteraction +import com.embabel.agent.spi.streaming.StreamingLlmOperations import com.embabel.agent.spi.support.DefaultToolDecorator import com.embabel.agent.spi.support.LlmDataBindingProperties import com.embabel.agent.spi.support.RegistryToolGroupResolver @@ -120,12 +121,12 @@ class StreamingGuardRailTestFakeChatModel( } /** - * Tests for guardrail validation in StreamingChatClientOperations + * Tests for guardrail validation in streaming operations in ChatClientLlmOperations */ -class StreamingChatClientOperationsGuardRailTest { +class ChatClientStreamingOperationsGuardRailTest { internal data class Setup( - val streamingOperations: StreamingChatClientOperations, + val streamingOperations: StreamingLlmOperations, val mockAgentProcess: AgentProcess, val mutableLlmInvocationHistory: 
StreamingGuardRailTestMutableLlmInvocationHistory, ) @@ -172,8 +173,7 @@ class StreamingChatClientOperationsGuardRailTest { dataBindingProperties = dataBindingProperties, asyncer = com.embabel.agent.spi.support.ExecutorAsyncer(java.util.concurrent.Executors.newCachedThreadPool()), ) - val streamingOperations = StreamingChatClientOperations(cco) - return Setup(streamingOperations, mockAgentProcess, mutableLlmInvocationHistory) + return Setup(cco, mockAgentProcess, mutableLlmInvocationHistory) } @Test diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsTest.kt deleted file mode 100644 index 2d1c58652..000000000 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsTest.kt +++ /dev/null @@ -1,568 +0,0 @@ -/* - * Copyright 2024-2026 Embabel Pty Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.embabel.agent.spi.support.springai.streaming - -import com.embabel.agent.core.Action -import com.embabel.agent.core.AgentProcess -import com.embabel.agent.core.support.LlmInteraction -import com.embabel.agent.spi.streaming.StreamingLlmOperations -import com.embabel.agent.spi.support.springai.ChatClientLlmOperations -import com.embabel.agent.spi.support.springai.SpringAiLlmService -import com.embabel.chat.UserMessage -import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper -import io.mockk.every -import io.mockk.mockk -import io.mockk.verify -import org.junit.jupiter.api.Assertions.* -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Nested -import org.junit.jupiter.api.Test -import org.slf4j.LoggerFactory -import org.springframework.ai.chat.client.ChatClient -import org.springframework.ai.chat.prompt.Prompt -import org.springframework.ai.tool.ToolCallback -import reactor.core.publisher.Flux -import reactor.test.StepVerifier -import java.time.Duration - -/** - * Unit tests for StreamingChatClientOperations. - * - * Tests the SPI layer delegation behavior and interface implementation. - * StreamingChatClientOperations sits at the SPI layer, below the API layer's - * streaming capability detection. This class assumes streaming has already been - * validated as available by the API layer (OperationContextPromptRunner). - * - * Key responsibilities tested: - * - Proper delegation to ChatClientLlmOperations - * - Correct implementation of StreamingLlmOperations interface - * - Bridge between embabel streaming interfaces and Spring AI ChatClient - * - * Note: Streaming capability detection is tested at the API layer. - * Complex Spring AI interactions are tested in integration tests. 
- */ -class StreamingChatClientOperationsTest { - - private val logger = LoggerFactory.getLogger(StreamingChatClientOperationsTest::class.java) - private lateinit var mockChatClientLlmOperations: ChatClientLlmOperations - private lateinit var mockLlm: SpringAiLlmService - private lateinit var mockChatClient: ChatClient - private lateinit var mockInteraction: LlmInteraction - private lateinit var mockAgentProcess: AgentProcess - private lateinit var mockAction: Action - private lateinit var streamingOperations: StreamingChatClientOperations - - @BeforeEach - fun setUp() { - mockChatClientLlmOperations = mockk(relaxed = true) - mockLlm = mockk(relaxed = true) - mockChatClient = mockk(relaxed = true) - mockInteraction = mockk(relaxed = true) - mockAgentProcess = mockk(relaxed = true) - mockAction = mockk(relaxed = true) - - // Setup basic delegation - every { mockChatClientLlmOperations.getLlm(any()) } returns mockLlm - every { mockChatClientLlmOperations.createChatClient(mockLlm) } returns mockChatClient - every { mockInteraction.promptContributors } returns emptyList() - every { mockLlm.promptContributors } returns emptyList() - val mockOptionsConverter = mockk>(relaxed = true) - every { mockLlm.optionsConverter } returns mockOptionsConverter - every { mockOptionsConverter.convertOptions(any()) } returns mockk(relaxed = true) - every { mockInteraction.llm } returns mockk(relaxed = true) - every { mockInteraction.tools } returns emptyList() - every { mockChatClientLlmOperations.objectMapper } returns jacksonObjectMapper() - every { mockInteraction.fieldFilter } returns { true } - - streamingOperations = StreamingChatClientOperations(mockChatClientLlmOperations) - } - - @Test - fun `should implement StreamingLlmOperations interface`() { - // Given & When & Then - assertTrue(streamingOperations is StreamingLlmOperations) - } - - @Test - fun `should delegate getLlm to ChatClientLlmOperations on generateStream`() { - // Given - val messages = listOf(UserMessage("Test 
prompt")) - - // When - streamingOperations.generateStream(messages, mockInteraction, mockAgentProcess, mockAction) - - // Then - verify { mockChatClientLlmOperations.getLlm(mockInteraction) } - } - - @Test - fun `should delegate createChatClient to ChatClientLlmOperations on generateStream`() { - // Given - val messages = listOf(UserMessage("Test prompt")) - - // When - streamingOperations.generateStream(messages, mockInteraction, mockAgentProcess, mockAction) - - // Then - verify { mockChatClientLlmOperations.createChatClient(mockLlm) } - } - - @Test - fun `should delegate getLlm to ChatClientLlmOperations on createObjectStream`() { - // Given - val messages = listOf(UserMessage("Create objects")) - val outputClass = TestItem::class.java - - // When - streamingOperations.createObjectStream(messages, mockInteraction, outputClass, mockAgentProcess, mockAction) - - // Then - verify { mockChatClientLlmOperations.getLlm(mockInteraction) } - } - - @Test - fun `should delegate createChatClient to ChatClientLlmOperations on createObjectStream`() { - // Given - val messages = listOf(UserMessage("Create objects")) - val outputClass = TestItem::class.java - - // When - streamingOperations.createObjectStream(messages, mockInteraction, outputClass, mockAgentProcess, mockAction) - - // Then - verify { mockChatClientLlmOperations.createChatClient(mockLlm) } - } - - @Test - fun `should delegate getLlm to ChatClientLlmOperations on createObjectStreamWithThinking`() { - // Given - val messages = listOf(UserMessage("Create objects with thinking")) - val outputClass = TestItem::class.java - - // When - streamingOperations.createObjectStreamWithThinking( - messages, - mockInteraction, - outputClass, - mockAgentProcess, - mockAction - ) - - // Then - verify { mockChatClientLlmOperations.getLlm(mockInteraction) } - } - - @Test - fun `should delegate createChatClient to ChatClientLlmOperations on createObjectStreamWithThinking`() { - // Given - val messages = listOf(UserMessage("Create 
objects with thinking")) - val outputClass = TestItem::class.java - - // When - streamingOperations.createObjectStreamWithThinking( - messages, - mockInteraction, - outputClass, - mockAgentProcess, - mockAction - ) - - // Then - verify { mockChatClientLlmOperations.createChatClient(mockLlm) } - } - - @Test - fun `should return Flux from generateStream`() { - // Given - val messages = listOf(UserMessage("Test prompt")) - - // When - val result = streamingOperations.generateStream(messages, mockInteraction, mockAgentProcess, mockAction) - - // Then - assertNotNull(result) - } - - @Test - fun `should return Flux from createObjectStream`() { - // Given - val messages = listOf(UserMessage("Create objects")) - val outputClass = TestItem::class.java - - // When - val result = - streamingOperations.createObjectStream(messages, mockInteraction, outputClass, mockAgentProcess, mockAction) - - // Then - assertNotNull(result) - } - - @Test - fun `should return Flux from createObjectStreamWithThinking`() { - // Given - val messages = listOf(UserMessage("Create objects with thinking")) - val outputClass = TestItem::class.java - - // When - val result = streamingOperations.createObjectStreamWithThinking( - messages, - mockInteraction, - outputClass, - mockAgentProcess, - mockAction - ) - - // Then - assertNotNull(result) - assertTrue(result is Flux<*>) - } - - @Test - fun `should return Flux from createObjectStreamIfPossible`() { - // Given - val messages = listOf(UserMessage("Create objects safely")) - val outputClass = TestItem::class.java - - // When - val result = streamingOperations.createObjectStreamIfPossible( - messages, - mockInteraction, - outputClass, - mockAgentProcess, - mockAction - ) - - // Then - assertNotNull(result) - assertTrue(result is Flux<*>) - } - - @Test - fun `should accept null action parameter`() { - // Given - val messages = listOf(UserMessage("Test prompt")) - - // When & Then - should not throw exception - val result = 
streamingOperations.generateStream(messages, mockInteraction, mockAgentProcess, null) - assertNotNull(result) - } - - data class TestItem(val name: String, val value: Int) - - @Test - fun `should handle single complete chunk`() { - // Given: Single chunk with thinking content - val chunkFlux = Flux.just("This is thinking content\n") - mockChatClientForStreaming(chunkFlux) - - // When - val result = streamingOperations.createObjectStreamWithThinking( - messages = listOf(UserMessage("test")), - interaction = mockInteraction, - outputClass = TestItem::class.java, - agentProcess = mockAgentProcess, - action = mockAction - ) - - - // Then: Should emit one thinking event for the complete line - StepVerifier.create(result) - .expectNextMatches { - it.isThinking() && it.getThinking() == "This is thinking content" - } - .expectComplete() - .verify(Duration.ofSeconds(1)) - } - - @Test - fun `should handle multi-chunk JSONL object stream`() { - // Given: Multiple chunks forming JSONL objects - val chunkFlux = Flux.just( - "{\"name\":\"Item1\",\"value\":", - "42}\n{\"name\":\"Item2\",", - "\"value\":84}\n" - ) - mockChatClientForStreaming(chunkFlux) - - // When - val result = streamingOperations.createObjectStreamWithThinking( - messages = listOf(UserMessage("test")), - interaction = mockInteraction, - outputClass = TestItem::class.java, - agentProcess = mockAgentProcess, - action = mockAction - ) - - // Then: Should emit two object events - StepVerifier.create(result) - .expectNextMatches { - it.isObject() && it.getObject()?.name == "Item1" && it.getObject()?.value == 42 - } - .expectNextMatches { - it.isObject() && it.getObject()?.name == "Item2" && it.getObject()?.value == 84 - } - .expectComplete() - .verify(Duration.ofSeconds(1)) - } - - @Test - fun `should handle mixed thinking and object content in chunks`() { - // Given: Realistic chunking that splits thinking and JSON across chunk boundaries - val chunkFlux = Flux.just( - "Ana", // Partial thinking start - "lyzing 
req", // Partial thinking middle - "uirement\n{\"name\":", // Thinking end + partial JSON - "\"TestItem\",\"va", // Partial JSON middle - "lue\":123}\nDone", // JSON end + partial thinking - "\n" // Thinking end - ) - mockChatClientForStreaming(chunkFlux) - - // When - val result = streamingOperations.createObjectStreamWithThinking( - messages = listOf(UserMessage("test")), - interaction = mockInteraction, - outputClass = TestItem::class.java, - agentProcess = mockAgentProcess, - action = mockAction - ) - - // Then: Should emit thinking, object, thinking in correct order - StepVerifier.create(result) - .expectNextMatches { - it.isThinking() && it.getThinking() == "Analyzing requirement" - } - .expectNextMatches { - it.isObject() && it.getObject()?.name == "TestItem" && it.getObject()?.value == 123 - } - .expectNextMatches { - it.isThinking() && it.getThinking() == "Done" - } - .expectComplete() - .verify(Duration.ofSeconds(1)) - } - - @Test - fun `should handle real streaming with reactive callbacks`() { - // Given: Mixed content with multiple events - val chunkFlux = Flux.just( - "Processing request\n", - "{\"name\":\"Item1\",\"value\":100}\n", - "{\"name\":\"Item2\",\"value\":200}\n", - "Request completed\n" - ) - mockChatClientForStreaming(chunkFlux) - - // When: Subscribe with real reactive callbacks - val receivedEvents = mutableListOf() - var errorOccurred: Throwable? = null - var completionCalled = false - - val result = streamingOperations.createObjectStreamWithThinking( - messages = listOf(UserMessage("test")), - interaction = mockInteraction, - outputClass = TestItem::class.java, - agentProcess = mockAgentProcess, - action = mockAction - ) - - result - .doOnNext { event -> - when { - event.isThinking() -> { - val content = event.getThinking()!! - receivedEvents.add("THINKING: $content") - logger.info("Received thinking: {}", content) - } - - event.isObject() -> { - val obj = event.getObject()!! 
- receivedEvents.add("OBJECT: ${obj.name}=${obj.value}") - logger.info("Received object: {}={}", obj.name, obj.value) - } - } - } - .doOnError { error -> - errorOccurred = error - logger.error("Stream error: {}", error.message) - } - .doOnComplete { - completionCalled = true - logger.info("Stream completed successfully") - } - .subscribe() - - // Give stream time to complete - Thread.sleep(500) - - // Then: Verify real reactive behavior - assertNull(errorOccurred, "No errors should occur") - assertTrue(completionCalled, "Stream should complete successfully") - assertEquals(4, receivedEvents.size, "Should receive all events") - assertEquals("THINKING: Processing request", receivedEvents[0]) - assertEquals("OBJECT: Item1=100", receivedEvents[1]) - assertEquals("OBJECT: Item2=200", receivedEvents[2]) - assertEquals("THINKING: Request completed", receivedEvents[3]) - } - - @Test - fun `rawChunksToLines should handle single line chunks`() { - val chunks = Flux.just("line1\n", "line2\n") - - val result = streamingOperations.rawChunksToLines(chunks) - - StepVerifier.create(result) - .expectNext("line1") - .expectNext("line2") - .verifyComplete() - } - - @Test - fun `rawChunksToLines should handle multi-line chunks from Anthropic`() { - val chunks = Flux.just(".\n\n\n{\"") - - val result = streamingOperations.rawChunksToLines(chunks) - - StepVerifier.create(result) - .expectNext(".") - .expectNext("") - .expectNext("{\"") - .verifyComplete() - } - - @Test - fun `rawChunksToLines should handle incomplete lines across chunks`() { - val chunks = Flux.just("partial", " line\n", "complete\n") - - val result = streamingOperations.rawChunksToLines(chunks) - - StepVerifier.create(result) - .expectNext("partial line") - .expectNext("complete") - .verifyComplete() - } - - @Test - fun `rawChunksToLines should emit final incomplete line`() { - val chunks = Flux.just("line1\n", "incomplete") - - val result = streamingOperations.rawChunksToLines(chunks) - - StepVerifier.create(result) - 
.expectNext("line1") - .expectNext("incomplete") - .verifyComplete() - } - - - private fun mockChatClientForStreaming(chunkFlux: Flux) { - val mockRequestSpec = mockk(relaxed = true) - val mockContentStreamSpec = mockk(relaxed = true) - - every { mockChatClient.prompt(any()) } returns mockRequestSpec - every { mockRequestSpec.toolCallbacks(any>()) } returns mockRequestSpec - every { mockRequestSpec.options(any()) } returns mockRequestSpec - every { mockRequestSpec.stream() } returns mockContentStreamSpec - every { mockContentStreamSpec.content() } returns chunkFlux - - } - - /** - * Tests for useMessageStreamer=true (decoupled streaming path via LlmMessageStreamer). - */ - @Nested - inner class MessageStreamerTests { - - private lateinit var streamingOpsWithStreamer: StreamingChatClientOperations - - @BeforeEach - fun setUpStreamer() { - streamingOpsWithStreamer = StreamingChatClientOperations( - mockChatClientLlmOperations, - useMessageStreamer = true - ) - } - - @Test - fun `should use LlmMessageStreamer when useMessageStreamer is true`() { - // Given - val chunkFlux = Flux.just("streamed ", "content") - mockChatClientForStreaming(chunkFlux) - - // When - val result = streamingOpsWithStreamer.generateStream( - listOf(UserMessage("test")), - mockInteraction, - mockAgentProcess, - mockAction - ) - - // Then - StepVerifier.create(result) - .expectNext("streamed ") - .expectNext("content") - .verifyComplete() - } - - @Test - fun `should prepend prompt contributions as system message`() { - // Given - val mockContributor = mockk() - every { mockContributor.contribution() } returns "System contribution" - every { mockInteraction.promptContributors } returns listOf(mockContributor) - - val chunkFlux = Flux.just("response") - mockChatClientForStreaming(chunkFlux) - - // When - val result = streamingOpsWithStreamer.generateStream( - listOf(UserMessage("user message")), - mockInteraction, - mockAgentProcess, - mockAction - ) - - // Then - StepVerifier.create(result) - 
.expectNext("response") - .verifyComplete() - } - - @Test - fun `should handle object streaming`() { - // Given - val chunkFlux = Flux.just("{\"name\":\"Test\",\"value\":42}\n") - mockChatClientForStreaming(chunkFlux) - - // When - val result = streamingOpsWithStreamer.createObjectStreamWithThinking( - messages = listOf(UserMessage("test")), - interaction = mockInteraction, - outputClass = TestItem::class.java, - agentProcess = mockAgentProcess, - action = mockAction - ) - - // Then - StepVerifier.create(result) - .expectNextMatches { - it.isObject() && it.getObject()?.name == "Test" && it.getObject()?.value == 42 - } - .expectComplete() - .verify(Duration.ofSeconds(1)) - } - } -} diff --git a/embabel-agent-docs/src/main/asciidoc/reference/testing/page.adoc b/embabel-agent-docs/src/main/asciidoc/reference/testing/page.adoc index 8ceb62bfe..dec202a83 100644 --- a/embabel-agent-docs/src/main/asciidoc/reference/testing/page.adoc +++ b/embabel-agent-docs/src/main/asciidoc/reference/testing/page.adoc @@ -795,15 +795,26 @@ class StoryWriterIntegrationTest : EmbabelMockitoIntegrationTest() { ===== Key Integration Testing Features **Base Class Benefits:** -- `EmbabelMockitoIntegrationTest` handles Spring Boot setup and LLM mocking automatically - Provides `agentPlatform` and `llmOperations` pre-configured - Includes helper methods for common testing patterns + +- `EmbabelMockitoIntegrationTest` handles Spring Boot setup and LLM mocking automatically +- Provides `agentPlatform` and `llmOperations` pre-configured +- Includes helper methods for common testing patterns **Convenient Stubbing Methods:** -- `whenCreateObject(prompt, outputClass)`: Mock object creation calls - `whenGenerateText(prompt)`: Mock text generation calls - Support for both exact prompts and `contains()` matching + +- `whenCreateObject(prompt, outputClass)`: Mock object creation calls +- `whenGenerateText(prompt)`: Mock text generation calls +- Support for both exact prompts and `contains()` matching +- 
Supports streaming calls by calling `supportsStreaming(true)` in test setup. **Advanced Verification:** -- `verifyCreateObjectMatching()`: Verify prompts with custom matchers - `verifyGenerateTextMatching()`: Verify text generation calls - `verifyNoMoreInteractions()`: Ensure no unexpected LLM calls + +- `verifyCreateObjectMatching()`: Verify prompts with custom matchers +- `verifyGenerateTextMatching()`: Verify text generation calls +- `verifyNoMoreInteractions()`: Ensure no unexpected LLM calls **LLM Configuration Testing:** + - Verify temperature settings: `llm.getLlm().getTemperature() == 0.9` - Check tool groups: `llm.getToolGroups().isEmpty()` - Validate persona and other LLM options \ No newline at end of file diff --git a/embabel-agent-test-support/embabel-agent-test/src/main/java/com/embabel/agent/test/integration/EmbabelMockitoIntegrationTest.java b/embabel-agent-test-support/embabel-agent-test/src/main/java/com/embabel/agent/test/integration/EmbabelMockitoIntegrationTest.java index 6ed996bd8..69f53e479 100644 --- a/embabel-agent-test-support/embabel-agent-test/src/main/java/com/embabel/agent/test/integration/EmbabelMockitoIntegrationTest.java +++ b/embabel-agent-test-support/embabel-agent-test/src/main/java/com/embabel/agent/test/integration/EmbabelMockitoIntegrationTest.java @@ -16,9 +16,10 @@ package com.embabel.agent.test.integration; import com.embabel.agent.core.AgentPlatform; -import com.embabel.agent.core.internal.LlmOperations; import com.embabel.agent.core.support.LlmInteraction; +import com.embabel.agent.spi.support.springai.LlmOperationsIncludingStreaming; import com.embabel.chat.Message; +import com.embabel.common.ai.model.LlmOptions; import com.embabel.common.ai.model.ModelProvider; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatcher; @@ -28,11 +29,14 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.test.context.TestPropertySource; import 
org.springframework.test.context.bean.override.mockito.MockitoBean; +import reactor.core.publisher.Flux; import java.util.List; import java.util.function.Predicate; -import static org.mockito.ArgumentMatchers.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -58,19 +62,33 @@ public class EmbabelMockitoIntegrationTest { private ModelProvider modelProvider; @MockitoBean - protected LlmOperations llmOperations; + protected LlmOperationsIncludingStreaming llmOperations; + + protected void supportsStreaming(boolean supportsStreaming) { + when(llmOperations.supportsStreaming(any(LlmOptions.class))).thenReturn(supportsStreaming); + } // Stubbing methods protected OngoingStubbing whenCreateObject(Predicate promptMatcher, Class outputClass, Predicate llmInteractionPredicate) { // Mock the lower level LLM operation to create an object // that will ultimately be called - return when(llmOperations.createObject(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), argThat(llmInteractionPredicate::test), eq(outputClass), any(), any())); + return when( + llmOperations.createObject(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), argThat(llmInteractionPredicate::test), eq(outputClass), any(), any())); } protected OngoingStubbing whenCreateObject(Predicate promptMatcher, Class outputClass) { return whenCreateObject(promptMatcher, outputClass, llmi -> true); } + protected OngoingStubbing> whenCreateObjectStream(Predicate promptMatcher, Class outputClass, Predicate llmInteractionPredicate) { + return when( + llmOperations.createObjectStream(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), argThat(llmInteractionPredicate::test), eq(outputClass), any(), any())); + } + + protected OngoingStubbing> whenCreateObjectStream(Predicate promptMatcher, 
Class outputClass) { + return whenCreateObjectStream(promptMatcher, outputClass, llmi -> true); + } + protected OngoingStubbing whenGenerateText(Predicate promptMatcher, Predicate llmInteractionMatcher) { return when(llmOperations.createObject(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), argThat(llmInteractionMatcher::test), eq(String.class), any(), any())); @@ -80,6 +98,15 @@ protected OngoingStubbing whenGenerateText(Predicate promptMatch return whenGenerateText(promptMatcher, llmi -> true); } + protected OngoingStubbing> whenGenerateStream(Predicate promptMatcher, Predicate llmInteractionMatcher) { + return when(llmOperations.generateStream(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), + argThat(llmInteractionMatcher::test), any(), any())); + } + + protected OngoingStubbing> whenGenerateStream(Predicate promptMatcher) { + return whenGenerateStream(promptMatcher, llmi -> true); + } + // Verification methods protected void verifyCreateObject(Predicate promptMatcher, Class outputClass, Predicate llmInteractionMatcher) { verify(llmOperations).createObject(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), @@ -90,6 +117,15 @@ protected void verifyCreateObject(Predicate prompt, Class outputC verifyCreateObject(prompt, outputClass, llmi -> true); } + protected void verifyCreateObjectStream(Predicate promptMatcher, Class outputClass, Predicate llmInteractionMatcher) { + verify(llmOperations).createObjectStream(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), + argThat(llmInteractionMatcher::test), eq(outputClass), any(), any()); + } + + protected void verifyCreateObjectStream(Predicate prompt, Class outputClass) { + verifyCreateObjectStream(prompt, outputClass, llmi -> true); + } + protected void verifyGenerateText(Predicate promptMatcher, Predicate llmInteractionMatcher) { verify(llmOperations).createObject(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), 
argThat(llmInteractionMatcher::test), eq(String.class), any(), any()); } @@ -98,6 +134,14 @@ protected void verifyGenerateText(Predicate promptMatcher) { verifyGenerateText(promptMatcher, llmi -> true); } + protected void verifyGenerateStream(Predicate promptMatcher, Predicate llmInteractionMatcher) { + verify(llmOperations).generateStream(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), argThat(llmInteractionMatcher::test), any(), any()); + } + + protected void verifyGenerateStream(Predicate promptMatcher) { + verifyGenerateStream(promptMatcher, llmi -> true); + } + // Verification methods with argument matchers protected void verifyCreateObjectMatching(Predicate promptMatcher, Class outputClass, ArgumentMatcher llmInteractionMatcher) { verify(llmOperations).createObject(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), argThat(llmInteractionMatcher), eq(outputClass), any(), any()); @@ -109,13 +153,32 @@ protected void verifyCreateObjectMatchingMessages(ArgumentMatcher void verifyCreateObjectStreamMatching(Predicate promptMatcher, Class outputClass, ArgumentMatcher llmInteractionMatcher) { + verify(llmOperations).createObjectStream(argThat(m -> firstMessageContentSatisfiesMatcher(m, promptMatcher)), argThat(llmInteractionMatcher), eq(outputClass), any(), any()); + } + + protected void verifyCreateObjectStreamMatchingMessages(ArgumentMatcher> promptMatcher, Class outputClass, ArgumentMatcher llmInteractionMatcher) { + verify(llmOperations).createObjectStream(argThat(promptMatcher), + argThat(llmInteractionMatcher), + eq(outputClass), any(), any()); + } + protected void verifyGenerateTextMatching(Predicate promptMatcher) { verify(llmOperations).createObject(argThat(messages -> firstMessageContentSatisfiesMatcher(messages, promptMatcher)), any(), eq(String.class), any(), any()); } - protected void verifyGenerateTextMatching(Predicate promptMatcher, - LlmInteraction llmInteraction) { - 
Mockito.verify(llmOperations).createObject(argThat(messages -> firstMessageContentSatisfiesMatcher(messages, promptMatcher)), eq(llmInteraction), eq(String.class), any(), any()); + protected void verifyGenerateTextMatching(Predicate promptMatcher, LlmInteraction llmInteraction) { + Mockito.verify(llmOperations) + .createObject(argThat(messages -> firstMessageContentSatisfiesMatcher(messages, promptMatcher)), eq(llmInteraction), eq(String.class), any(), any()); + } + + protected void verifyGenerateStreamMatching(Predicate promptMatcher) { + verify(llmOperations).generateStream(argThat(messages -> firstMessageContentSatisfiesMatcher(messages, promptMatcher)), any(), any(), any()); + } + + protected void verifyGenerateStreamMatching(Predicate promptMatcher, LlmInteraction llmInteraction) { + Mockito.verify(llmOperations) + .generateStream(argThat(messages -> firstMessageContentSatisfiesMatcher(messages, promptMatcher)), eq(llmInteraction), any(), any()); } // Convenience verification methods