diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextDelegate.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextDelegate.kt index 33c813e64..71f7da5be 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextDelegate.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextDelegate.kt @@ -16,7 +16,7 @@ package com.embabel.agent.api.common.support import com.embabel.agent.api.common.* -import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector +import com.embabel.agent.spi.support.streaming.StreamingCapabilityDetector import com.embabel.agent.api.tool.ArtifactSinkingTool import com.embabel.agent.api.tool.Tool import com.embabel.agent.api.tool.ToolCallContext @@ -35,11 +35,10 @@ import com.embabel.agent.core.internal.LlmOperations import com.embabel.agent.core.support.LlmInteraction import com.embabel.agent.core.support.safelyGetTools import com.embabel.agent.experimental.primitive.Determination +import com.embabel.agent.core.internal.streaming.StreamingLlmOperationsFactory import com.embabel.agent.spi.loop.ToolChainingInjectionStrategy import com.embabel.agent.spi.loop.ToolInjectionStrategy import com.embabel.agent.spi.loop.ToolNotFoundPolicy -import com.embabel.agent.spi.support.springai.ChatClientLlmOperations -import com.embabel.agent.spi.support.springai.streaming.StreamingChatClientOperations import com.embabel.chat.AssistantMessage import com.embabel.chat.ImagePart import com.embabel.chat.Message @@ -325,8 +324,7 @@ internal data class OperationContextDelegate( } override fun generateStream(): Flux { - val llmOperations = context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations - val streamingLlmOperations = StreamingChatClientOperations(llmOperations) + val streamingLlmOperations = streamingFactory().createStreamingOperations(llm) return 
streamingLlmOperations.generateStream( messages = messages, @@ -337,8 +335,7 @@ internal data class OperationContextDelegate( } override fun createObjectStream(itemClass: Class): Flux { - val llmOperations = context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations - val streamingLlmOperations = StreamingChatClientOperations(llmOperations) + val streamingLlmOperations = streamingFactory().createStreamingOperations(llm) return streamingLlmOperations.createObjectStream( messages = messages, @@ -350,8 +347,7 @@ internal data class OperationContextDelegate( } override fun createObjectStreamWithThinking(itemClass: Class): Flux> { - val llmOperations = context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations - val streamingLlmOperations = StreamingChatClientOperations(llmOperations) + val streamingLlmOperations = streamingFactory().createStreamingOperations(llm) return streamingLlmOperations.createObjectStreamWithThinking( messages = messages, interaction = streamingInteraction(), @@ -381,6 +377,15 @@ internal data class OperationContextDelegate( ) } + private fun streamingFactory(): StreamingLlmOperationsFactory { + val llmOperations = context.agentPlatform().platformServices.llmOperations + return llmOperations as? 
StreamingLlmOperationsFactory + ?: throw UnsupportedOperationException( + "Streaming not supported: LlmOperations (${llmOperations::class.simpleName}) " + + "does not implement StreamingLlmOperationsFactory" + ) + } + override fun supportsThinking(): Boolean = true // Patterned after createObject() - uses ProcessContext flow diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/streaming/StreamingCapabilityDetector.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/streaming/StreamingCapabilityDetector.kt deleted file mode 100644 index b732ced9e..000000000 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/streaming/StreamingCapabilityDetector.kt +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright 2024-2026 Embabel Pty Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.embabel.agent.api.common.support.streaming - -import com.embabel.agent.api.common.InteractionId -import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector.supportsStreaming -import com.embabel.agent.core.internal.LlmOperations -import com.embabel.agent.core.support.LlmInteraction -import com.embabel.agent.spi.support.springai.ChatClientLlmOperations -import com.embabel.agent.spi.support.springai.SpringAiLlmService -import com.embabel.common.ai.model.LlmOptions -import com.embabel.common.util.loggerFor -import org.springframework.ai.chat.messages.UserMessage -import org.springframework.ai.chat.model.ChatModel -import org.springframework.ai.chat.model.ChatResponse -import org.springframework.ai.chat.prompt.Prompt -import reactor.core.publisher.Flux -import java.time.Duration -import java.util.concurrent.ConcurrentHashMap - -/** - * Utility for detecting actual streaming capability in ChatModel implementations. - * - * Spring AI's ChatModel interface extends StreamingChatModel, but not all implementations - * provide meaningful streaming support. Some may throw UnsupportedOperationException, - * return empty Flux, or provide stub implementations. - * - * This detector performs lightweight behavioral testing to determine if a model - * actually supports streaming operations. - */ -internal object StreamingCapabilityDetector { - private val logger = loggerFor() - private val capabilityCache = ConcurrentHashMap, Boolean>() - - private const val CACHE_MISS_LOG_MESSAGE = "Cache miss for {}, testing streaming capability..." - private const val TEST_PROMPT_MESSAGE = "Say 'test' to confirm streaming works" - - /** - * Tests whether the given ChatModel actually supports streaming operations. 
- * - * @param model The ChatModel to test - * @return true if the model supports streaming, false otherwise - */ - fun supportsStreaming(model: ChatModel): Boolean { - // Cache by model class to avoid repeated tests - //return true - return capabilityCache.computeIfAbsent(model.javaClass) { - logger.debug(CACHE_MISS_LOG_MESSAGE, model.javaClass.simpleName) - testStreamingCapability(model) - } - } - - /** - * Delegates to [supportsStreaming] - */ - fun supportsStreaming(llmOperations: LlmOperations, llmOptions: LlmOptions): Boolean { - - // Level 1 sanity check - if (llmOperations !is ChatClientLlmOperations) return false - - // Level 2: Must have actual streaming capability - val llm = llmOperations.getLlm( - LlmInteraction( // check for circular dependency - id = InteractionId("capability-check"), - llm = llmOptions - ) - ) - - val springAiLlm = llm as? SpringAiLlmService ?: return false - return supportsStreaming(springAiLlm.chatModel) - - } - - private fun testStreamingCapability(model: ChatModel): Boolean { - return try { - // Use a prompt that should generate a response - val testRequest = Prompt(listOf(UserMessage(TEST_PROMPT_MESSAGE))) - val stream = model.stream(testRequest) - - // Test if stream can be consumed without errors - canConsumeStream(stream) - true - - } catch (e: UnsupportedOperationException) { - false - } catch (e: Exception) { - false - } - } - - private fun canConsumeStream(stream: Flux): Boolean { - return try { - // Test if we can check stream elements without consuming - stream.hasElements() - .timeout(Duration.ofMillis(100)) // configure - .block() - - // If we get here without exceptions, streaming capability exists - - true - - } catch (e: Exception) { - false // Any exception means streaming doesn't work - } - } -} diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingLlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/internal/streaming/StreamingLlmOperations.kt similarity 
index 97% rename from embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingLlmOperations.kt rename to embabel-agent-api/src/main/kotlin/com/embabel/agent/core/internal/streaming/StreamingLlmOperations.kt index 0e13a035b..93f0da684 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/streaming/StreamingLlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/internal/streaming/StreamingLlmOperations.kt @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.embabel.agent.spi.streaming +package com.embabel.agent.core.internal.streaming import com.embabel.agent.api.event.LlmRequestEvent import com.embabel.agent.core.Action @@ -22,12 +22,13 @@ import com.embabel.agent.core.support.LlmInteraction import com.embabel.chat.Message import com.embabel.chat.UserMessage import com.embabel.common.core.streaming.StreamingEvent +import org.jetbrains.annotations.ApiStatus import reactor.core.publisher.Flux /** * Streaming extension of LlmOperations for real-time LLM response processing. * - * This SPI interface provides reactive streaming capabilities that support + * This internal interface provides reactive streaming capabilities that support * the API layer StreamingPromptRunner interfaces, enabling: * - Real-time processing of LLM responses as they arrive * - Streaming lists of objects from JSONL responses @@ -37,6 +38,7 @@ import reactor.core.publisher.Flux * All streaming methods return Project Reactor Flux streams for integration * with Spring WebFlux and other reactive frameworks. 
*/ +@ApiStatus.Internal interface StreamingLlmOperations { /** diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/internal/streaming/StreamingLlmOperationsFactory.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/internal/streaming/StreamingLlmOperationsFactory.kt new file mode 100644 index 000000000..be37d3738 --- /dev/null +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/core/internal/streaming/StreamingLlmOperationsFactory.kt @@ -0,0 +1,52 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.core.internal.streaming + +import com.embabel.common.ai.model.LlmOptions +import org.jetbrains.annotations.ApiStatus + +/** + * Factory interface for creating [StreamingLlmOperations] instances. + * + * This interface is separate from [com.embabel.agent.core.internal.LlmOperations] + * to maintain interface segregation. Implementations can choose to implement + * both interfaces or just one, depending on their capabilities. + * + * @see StreamingLlmOperations + * @see com.embabel.agent.core.internal.LlmOperations + */ +@ApiStatus.Internal +interface StreamingLlmOperationsFactory { + + /** + * Check if streaming is supported for the given LLM options. + * + * This method allows capability detection without resolving specific + * implementation details, enabling proper mocking in tests. 
+ * + * @param options LLM options including model selection criteria + * @return true if streaming is supported for the resolved LLM + */ + fun supportsStreaming(options: LlmOptions): Boolean + + /** + * Create a [StreamingLlmOperations] instance configured with the given options. + * + * @param options LLM options including model selection criteria + * @return A streaming operations instance for the selected LLM + */ + fun createStreamingOperations(options: LlmOptions): StreamingLlmOperations +} diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/LlmService.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/LlmService.kt index 944474675..da160ce7c 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/LlmService.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/LlmService.kt @@ -16,6 +16,7 @@ package com.embabel.agent.spi import com.embabel.agent.spi.loop.LlmMessageSender +import com.embabel.agent.spi.loop.streaming.LlmMessageStreamer import com.embabel.common.ai.model.LlmMetadata import com.embabel.common.ai.model.LlmOptions import com.embabel.common.ai.prompt.PromptContributor @@ -48,6 +49,27 @@ interface LlmService> : LlmMetadata, PromptContributorCo */ fun createMessageSender(options: LlmOptions): LlmMessageSender + /** + * Create a message streamer for this LLM configured with the given options. + * + * The message streamer handles streaming LLM API calls. Tool execution is + * handled internally by the underlying LLM framework during streaming. + * + * @param options Configuration options for the LLM call (temperature, max tokens, etc.) + * @return A message streamer configured for this LLM + */ + fun createMessageStreamer(options: LlmOptions): LlmMessageStreamer + + /** + * Check if this LLM service supports streaming operations. + * + * Each LlmService instance is bound to a specific model, so this checks + * whether that particular model supports streaming. 
+ * + * @return true if the underlying model supports streaming, false otherwise + */ + fun supportsStreaming(): Boolean + /** * Returns a copy of this LLM service with the specified knowledge cutoff date. * diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/AbstractLlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/AbstractLlmOperations.kt index 9eed83e34..02d3f52ec 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/AbstractLlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/AbstractLlmOperations.kt @@ -21,6 +21,9 @@ import com.embabel.agent.api.tool.Tool import com.embabel.agent.core.Action import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.internal.LlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperationsFactory +import com.embabel.agent.spi.support.streaming.StreamingLlmOperationsImpl import com.embabel.agent.core.support.InvalidLlmReturnTypeException import com.embabel.agent.core.support.LlmInteraction import com.embabel.agent.spi.AutoLlmSelectionCriteriaResolver @@ -37,6 +40,7 @@ import com.embabel.common.ai.model.ModelSelectionCriteria import com.embabel.common.ai.model.PreResolvedModelSelectionCriteria import com.embabel.common.core.thinking.ThinkingResponse import com.embabel.common.util.time +import com.fasterxml.jackson.databind.ObjectMapper import jakarta.validation.ConstraintViolation import jakarta.validation.Validator import java.lang.reflect.Field @@ -67,7 +71,8 @@ abstract class AbstractLlmOperations( protected val dataBindingProperties: LlmDataBindingProperties, protected val promptsProperties: LlmOperationsPromptsProperties = LlmOperationsPromptsProperties(), protected val asyncer: Asyncer, -) : LlmOperations { + internal open val objectMapper: ObjectMapper, +) : LlmOperations, StreamingLlmOperationsFactory 
{ protected val logger: Logger = LoggerFactory.getLogger(javaClass) @@ -419,6 +424,22 @@ abstract class AbstractLlmOperations( return modelProvider.getLlm(crit) } + override fun supportsStreaming(options: LlmOptions): Boolean { + val llmService = chooseLlm(options) + return llmService.supportsStreaming() + } + + override fun createStreamingOperations(options: LlmOptions): StreamingLlmOperations { + val llmService = chooseLlm(options) + val messageStreamer = llmService.createMessageStreamer(options) + return StreamingLlmOperationsImpl( + messageStreamer = messageStreamer, + objectMapper = objectMapper, + llmService = llmService, + toolDecorator = toolDecorator, + ) + } + protected abstract fun doTransformIfPossible( messages: List, interaction: LlmInteraction, diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt index edc30a069..568b7036e 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/ToolLoopLlmOperations.kt @@ -58,8 +58,6 @@ import java.time.Duration import java.time.Instant import javax.annotation.concurrent.ThreadSafe -const val PROMPT_ELEMENT_SEPARATOR = "\n----\n" - /** * Output converter abstraction for parsing LLM output. * Framework-agnostic interface that can be implemented by Spring AI converters or others. 
@@ -106,7 +104,7 @@ open class ToolLoopLlmOperations( dataBindingProperties: LlmDataBindingProperties = LlmDataBindingProperties(), autoLlmSelectionCriteriaResolver: AutoLlmSelectionCriteriaResolver = AutoLlmSelectionCriteriaResolver.DEFAULT, promptsProperties: LlmOperationsPromptsProperties = LlmOperationsPromptsProperties(), - internal open val objectMapper: ObjectMapper = jacksonObjectMapper().registerModule(JavaTimeModule()), + objectMapper: ObjectMapper = jacksonObjectMapper().registerModule(JavaTimeModule()), protected val observationRegistry: ObservationRegistry = ObservationRegistry.NOOP, asyncer: Asyncer = ExecutorAsyncer(java.util.concurrent.Executors.newCachedThreadPool()), protected val toolLoopFactory: ToolLoopFactory = ToolLoopFactory.create(ToolLoopConfiguration(), asyncer, AutoCorrectionPolicy()), @@ -120,6 +118,7 @@ open class ToolLoopLlmOperations( autoLlmSelectionCriteriaResolver = autoLlmSelectionCriteriaResolver, promptsProperties = promptsProperties, asyncer = asyncer, + objectMapper = objectMapper, ) { override fun doTransform( @@ -631,10 +630,7 @@ open class ToolLoopLlmOperations( protected fun buildPromptContributions( interaction: LlmInteraction, llm: LlmService<*>, - ): String { - return (interaction.promptContributors + llm.promptContributors) - .joinToString(PROMPT_ELEMENT_SEPARATOR) { it.contribution() } - } + ): String = buildPromptContributionsString(interaction.promptContributors, llm.promptContributors) /** * Build initial messages for the tool loop, including system prompt contributions and schema. diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/messagePromptBuilders.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/messagePromptBuilders.kt new file mode 100644 index 000000000..61f842592 --- /dev/null +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/messagePromptBuilders.kt @@ -0,0 +1,94 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.spi.support + +import com.embabel.chat.Message +import com.embabel.chat.SystemMessage +import com.embabel.common.ai.prompt.PromptContributor + +/** + * Separator used between prompt elements when building consolidated prompts. + */ +internal const val PROMPT_ELEMENT_SEPARATOR = "\n----\n" + +/** + * Builds a prompt contributions string from prompt contributors. + * + * @param interactionContributors Contributors from the LlmInteraction + * @param llmContributors Contributors from the LlmService + * @return Consolidated prompt contributions string + */ +internal fun buildPromptContributionsString( + interactionContributors: List, + llmContributors: List, +): String = (interactionContributors + llmContributors) + .joinToString(PROMPT_ELEMENT_SEPARATOR) { it.contribution() } + +/** + * Partitions messages into system message content and non-system messages. + * This enables consolidating all system content at the beginning of the prompt. 
+ * + * @param messages The messages to partition + * @return Pair of (system content strings, non-system messages) + */ +internal fun partitionMessages(messages: List): Pair, List> { + val systemContent = mutableListOf() + val nonSystemMessages = mutableListOf() + for (message in messages) { + if (message is SystemMessage) { + systemContent.add(message.content) + } else { + nonSystemMessages.add(message) + } + } + return systemContent to nonSystemMessages +} + +/** + * Consolidates multiple system content strings into a single string. + * Follows OpenAI best practices and ensures compatibility with models like DeepSeek + * that have strict message ordering requirements. + * + * @param contents The content strings to consolidate + * @return Single consolidated string with contents joined by double newlines + */ +internal fun buildConsolidatedSystemMessage(vararg contents: String): String = + contents.filter { it.isNotEmpty() }.joinToString("\n\n") + +/** + * Builds a message list with all system content consolidated into a single + * system message at the beginning. + * + * Partitions input messages, extracts system content, merges with prompt contributions, + * and returns a new list with consolidated system message followed by non-system messages. 
+ * + * @param messages The input messages (may contain system messages to extract) + * @param promptContributions The prompt contributions string to include + * @return Message list with consolidated system message first + */ +internal fun buildConsolidatedPromptMessages( + messages: List, + promptContributions: String, +): List { + val (systemContent, nonSystemMessages) = partitionMessages(messages) + val allSystemContent = buildConsolidatedSystemMessage(promptContributions, *systemContent.toTypedArray()) + return buildList { + if (allSystemContent.isNotEmpty()) { + add(SystemMessage(allSystemContent)) + } + addAll(nonSystemMessages) + } +} diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt index b8ecf955f..9f40220bb 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt @@ -20,6 +20,8 @@ import com.embabel.agent.api.event.LlmRequestEvent import com.embabel.agent.api.tool.Tool import com.embabel.agent.api.tool.config.ToolLoopConfiguration import com.embabel.agent.core.Action +import com.embabel.agent.core.internal.streaming.StreamingLlmOperations +import com.embabel.agent.spi.support.springai.streaming.StreamingChatClientOperations import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.support.LlmInteraction import com.embabel.agent.core.support.toEmbabelUsage @@ -35,6 +37,8 @@ import com.embabel.agent.spi.support.MaybeReturn import com.embabel.agent.spi.support.OutputConverter import com.embabel.agent.spi.support.ToolLoopLlmOperations import com.embabel.agent.spi.support.ToolResolutionHelper +import com.embabel.agent.spi.support.buildConsolidatedSystemMessage +import com.embabel.agent.spi.support.partitionMessages import 
com.embabel.agent.spi.support.guardrails.validateAssistantResponse import com.embabel.agent.spi.support.guardrails.validateUserInput import com.embabel.agent.spi.validation.DefaultValidationPromptGenerator @@ -55,6 +59,7 @@ import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import io.micrometer.observation.ObservationRegistry import jakarta.annotation.PostConstruct import jakarta.validation.Validator +import org.springframework.beans.factory.annotation.Value import java.lang.reflect.ParameterizedType import java.util.concurrent.CompletableFuture import java.util.concurrent.ExecutionException @@ -111,6 +116,8 @@ internal class ChatClientLlmOperations( private val customizers: List = emptyList(), asyncer: Asyncer, toolLoopFactory: ToolLoopFactory = ToolLoopFactory.create(ToolLoopConfiguration(), asyncer, AutoCorrectionPolicy()), + @Value("\${embabel.agent.platform.streaming.use-legacy-streaming:false}") + private val useLegacyStreaming: Boolean = false, ) : ToolLoopLlmOperations( toolDecorator = toolDecorator, modelProvider = modelProvider, @@ -737,32 +744,6 @@ internal class ChatClientLlmOperations( ) } - /** - * Partitions messages into system messages (content only) and non-system messages. - * This enables consolidating all system content at the beginning of the prompt. - */ - private fun partitionMessages(messages: List): Pair, List> { - val systemContent = mutableListOf() - val nonSystemMessages = mutableListOf() - for (message in messages) { - if (message is com.embabel.chat.SystemMessage) { - systemContent.add(message.content) - } else { - nonSystemMessages.add(message) - } - } - return systemContent to nonSystemMessages - } - - /** - * Consolidates multiple system content strings into a single system message. - * This ensures a single system message at the beginning of the conversation, - * following OpenAI best practices and ensuring compatibility with models like DeepSeek - * that have strict message ordering requirements. 
- */ - private fun buildConsolidatedSystemMessage(vararg contents: String): String = - contents.filter { it.isNotEmpty() }.joinToString("\n\n") - // ==================================== // EXCEPTION HANDLING // ==================================== @@ -896,6 +877,25 @@ internal class ChatClientLlmOperations( action = action, toolDecorator = toolDecorator, ) + + // ==================================== + // STREAMING OPERATIONS FACTORY + // ==================================== + + /** + * Creates streaming operations with fallback to legacy Spring AI implementation. + * + * TODO: Remove this override, the [useLegacyStreaming] flag, and [StreamingChatClientOperations] + * once the new StreamingLlmOperationsImpl is fully validated. After removal, + * casts to ChatClientLlmOperations will no longer be needed anywhere in the codebase. + */ + override fun createStreamingOperations(options: LlmOptions): StreamingLlmOperations { + return if (useLegacyStreaming) { + StreamingChatClientOperations(this) + } else { + super.createStreamingOperations(options) + } + } } /** diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringAiLlmService.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringAiLlmService.kt index 89d05d152..400bb73b2 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringAiLlmService.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/SpringAiLlmService.kt @@ -17,13 +17,60 @@ package com.embabel.agent.spi.support.springai import com.embabel.agent.spi.LlmService import com.embabel.agent.spi.loop.LlmMessageSender +import com.embabel.agent.spi.loop.streaming.LlmMessageStreamer +import com.embabel.agent.spi.support.springai.streaming.SpringAiLlmMessageStreamer import com.embabel.common.ai.model.* import com.embabel.common.ai.prompt.KnowledgeCutoffDate import com.embabel.common.ai.prompt.PromptContributor import 
com.fasterxml.jackson.databind.annotation.JsonSerialize +import org.springframework.ai.chat.client.ChatClient +import org.springframework.ai.chat.messages.UserMessage import org.springframework.ai.chat.model.ChatModel +import org.springframework.ai.chat.model.ChatResponse +import org.springframework.ai.chat.prompt.Prompt +import reactor.core.publisher.Flux +import java.time.Duration import java.time.LocalDate +/** + * Helper object for verifying streaming capability of ChatModel instances. + * + * Spring AI's ChatModel interface extends StreamingChatModel, but not all implementations + * provide meaningful streaming support. Some may throw UnsupportedOperationException, + * return empty Flux, or provide stub implementations. + * + * This helper performs lightweight behavioral testing to determine if a model + * actually supports streaming operations. + */ +private object StreamingCapabilityVerifier { + private const val TEST_PROMPT_MESSAGE = "Say 'test' to confirm streaming works" + private const val STREAMING_TEST_TIMEOUT_MS = 100L + + fun supportsStreaming(chatModel: ChatModel): Boolean { + return try { + val testRequest = Prompt(listOf(UserMessage(TEST_PROMPT_MESSAGE))) + val stream = chatModel.stream(testRequest) + canConsumeStream(stream) + true + } catch (e: UnsupportedOperationException) { + false + } catch (e: Exception) { + false + } + } + + private fun canConsumeStream(stream: Flux): Boolean { + return try { + stream.hasElements() + .timeout(Duration.ofMillis(STREAMING_TEST_TIMEOUT_MS)) + .block() + true + } catch (e: Exception) { + false + } + } +} + /** * Spring AI implementation that provides decoupled LLM operations. 
* @@ -71,6 +118,14 @@ data class SpringAiLlmService @JvmOverloads constructor( return SpringAiLlmMessageSender(chatModel, chatOptions, toolResponseContentAdapter) } + override fun createMessageStreamer(options: LlmOptions): LlmMessageStreamer { + val chatOptions = optionsConverter.convertOptions(options) + val chatClient = ChatClient.create(chatModel) + return SpringAiLlmMessageStreamer(chatClient, chatOptions) + } + + override fun supportsStreaming(): Boolean = StreamingCapabilityVerifier.supportsStreaming(chatModel) + override fun withKnowledgeCutoffDate(date: LocalDate): SpringAiLlmService = copy( knowledgeCutoffDate = date, diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperations.kt index 63c015e7c..d2ce5a37f 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperations.kt @@ -21,8 +21,10 @@ import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.support.LlmInteraction import com.embabel.agent.spi.LlmService import com.embabel.agent.spi.loop.streaming.LlmMessageStreamer -import com.embabel.agent.spi.streaming.StreamingLlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperations import com.embabel.agent.spi.support.PROMPT_ELEMENT_SEPARATOR +import com.embabel.agent.spi.support.buildConsolidatedPromptMessages +import com.embabel.agent.spi.support.buildPromptContributionsString import com.embabel.agent.spi.support.guardrails.validateUserInput import com.embabel.agent.spi.support.springai.ChatClientLlmOperations import com.embabel.agent.spi.support.springai.SpringAiLlmService @@ -54,7 +56,19 @@ import reactor.core.publisher.Mono * **Unified Architecture:** 
* All streaming methods are built on a single internal pipeline that emits `StreamingEvent`, * allowing consistent behavior and the flexibility to filter events as needed by different use cases. + * + * @deprecated Use [com.embabel.agent.spi.support.streaming.StreamingLlmOperationsImpl] via + * [com.embabel.agent.core.internal.streaming.StreamingLlmOperationsFactory.createStreamingOperations]. + * This class will be removed once the new vendor-agnostic streaming implementation is fully validated. */ +@Deprecated( + message = "Use StreamingLlmOperationsImpl via StreamingLlmOperationsFactory.createStreamingOperations(). " + + "Will be removed once new streaming implementation is fully validated.", + replaceWith = ReplaceWith( + "StreamingLlmOperationsFactory.createStreamingOperations(options)", + "com.embabel.agent.core.internal.streaming.StreamingLlmOperationsFactory" + ) +) internal class StreamingChatClientOperations( private val chatClientLlmOperations: ChatClientLlmOperations, /** @@ -70,12 +84,9 @@ internal class StreamingChatClientOperations( /** * Build prompt contributions string from interaction and LLM contributors. - * Consider helper */ - private fun buildPromptContributions(interaction: LlmInteraction, llm: LlmService<*>): String { - return (interaction.promptContributors + llm.promptContributors) - .joinToString(PROMPT_ELEMENT_SEPARATOR) { it.contribution() } - } + private fun buildPromptContributions(interaction: LlmInteraction, llm: LlmService<*>): String = + buildPromptContributionsString(interaction.promptContributors, llm.promptContributors) /** * Build Spring AI Prompt from messages and contributions. 
@@ -443,12 +454,7 @@ internal class StreamingChatClientOperations( private fun buildMessagesWithContributions( messages: List, promptContributions: String, - ): List = buildList { - if (promptContributions.isNotEmpty()) { - add(com.embabel.chat.SystemMessage(promptContributions)) - } - addAll(messages) - } + ): List = buildConsolidatedPromptMessages(messages, promptContributions) /** * Create raw content stream from LLM. diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/streaming/StreamingCapabilityDetector.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/streaming/StreamingCapabilityDetector.kt new file mode 100644 index 000000000..5408d4d74 --- /dev/null +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/streaming/StreamingCapabilityDetector.kt @@ -0,0 +1,59 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.spi.support.streaming + +import com.embabel.agent.core.internal.LlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperationsFactory +import com.embabel.common.ai.model.LlmOptions +import com.embabel.common.util.loggerFor +import java.util.concurrent.ConcurrentHashMap + +/** + * Utility for detecting and caching streaming capability of LLM services. + * + * Not all LLM implementations provide meaningful streaming support. 
Some may throw + * UnsupportedOperationException, return empty Flux, or provide stub implementations. + * + * This detector caches the results of streaming capability tests (delegated to + * [StreamingLlmOperationsFactory.supportsStreaming]) using the model as cache key. + */ +internal object StreamingCapabilityDetector { + private val logger = loggerFor<StreamingCapabilityDetector>() + private val capabilityCache = ConcurrentHashMap<String, Boolean>() + + private const val CACHE_MISS_LOG_MESSAGE = "Cache miss for {}, testing streaming capability..." + + /** + * Tests whether the LLM resolved from the given operations and options supports streaming. + * + * Results are cached by model to avoid repeated tests. + * + * @param llmOperations The LLM operations instance + * @param llmOptions Options used to resolve the LLM + * @return true if streaming is supported, false otherwise + */ + fun supportsStreaming(llmOperations: LlmOperations, llmOptions: LlmOptions): Boolean { + // Must be a StreamingLlmOperationsFactory to support streaming + if (llmOperations !is StreamingLlmOperationsFactory) return false + + // Cache by model (or criteria string) + val cacheKey = llmOptions.model ?: llmOptions.criteria?.toString() ?: "default" + return capabilityCache.computeIfAbsent(cacheKey) { + logger.debug(CACHE_MISS_LOG_MESSAGE, cacheKey) + llmOperations.supportsStreaming(llmOptions) + } + } +} diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/streaming/StreamingLlmOperationsImpl.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/streaming/StreamingLlmOperationsImpl.kt new file mode 100644 index 000000000..6b96be260 --- /dev/null +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/streaming/StreamingLlmOperationsImpl.kt @@ -0,0 +1,300 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.spi.support.streaming + +import com.embabel.agent.api.event.LlmRequestEvent +import com.embabel.agent.api.tool.Tool +import com.embabel.agent.core.Action +import com.embabel.agent.core.AgentProcess +import com.embabel.agent.core.support.LlmInteraction +import com.embabel.agent.spi.LlmService +import com.embabel.agent.spi.ToolDecorator +import com.embabel.agent.spi.loop.streaming.LlmMessageStreamer +import com.embabel.agent.core.internal.streaming.StreamingLlmOperations +import com.embabel.agent.spi.support.PROMPT_ELEMENT_SEPARATOR +import com.embabel.agent.spi.support.ToolResolutionHelper +import com.embabel.agent.spi.support.buildConsolidatedPromptMessages +import com.embabel.agent.spi.support.buildPromptContributionsString +import com.embabel.agent.spi.support.guardrails.validateUserInput +import com.embabel.chat.Message +import com.embabel.chat.SystemMessage +import com.embabel.chat.UserMessage +import com.embabel.common.ai.converters.streaming.StreamingJacksonOutputConverter +import com.embabel.common.core.streaming.StreamingEvent +import com.fasterxml.jackson.databind.ObjectMapper +import org.slf4j.LoggerFactory +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono + +/** + * Vendor-neutral implementation of [StreamingLlmOperations]. + * + * This class provides streaming LLM operations without depending on any specific + * LLM framework (Spring AI, LangChain4j, etc.). 
It delegates raw streaming to + * [LlmMessageStreamer] and handles: + * - Line buffering from raw chunks + * - JSONL parsing to typed objects + * - Thinking content extraction + * + * @param messageStreamer The streamer for raw LLM content + * @param objectMapper ObjectMapper for JSON parsing + * @param llmService The LLM service for prompt contributions + * @param toolDecorator Decorator to make tools platform-aware + */ +internal class StreamingLlmOperationsImpl( + private val messageStreamer: LlmMessageStreamer, + private val objectMapper: ObjectMapper, + private val llmService: LlmService<*>, + private val toolDecorator: ToolDecorator, +) : StreamingLlmOperations { + + private val logger = LoggerFactory.getLogger(StreamingLlmOperationsImpl::class.java) + + // ======================================== + // Public API methods + // ======================================== + + override fun generateStream( + messages: List, + interaction: LlmInteraction, + agentProcess: AgentProcess, + action: Action?, + ): Flux { + return doTransformStream(messages, interaction, null, agentProcess, action) + } + + override fun createObjectStream( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + agentProcess: AgentProcess, + action: Action?, + ): Flux { + return doTransformObjectStream(messages, interaction, outputClass, null, agentProcess, action) + } + + override fun createObjectStreamIfPossible( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + agentProcess: AgentProcess, + action: Action?, + ): Flux> { + return createObjectStream(messages, interaction, outputClass, agentProcess, action) + .map { Result.success(it) } + .onErrorResume { throwable -> + Flux.just(Result.failure(throwable)) + } + } + + override fun createObjectStreamWithThinking( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + agentProcess: AgentProcess, + action: Action?, + ): Flux> { + return doTransformObjectStreamWithThinking(messages, 
interaction, outputClass, null, agentProcess, action) + } + + // ======================================== + // Low-level transform methods + // ======================================== + + override fun doTransformStream( + messages: List, + interaction: LlmInteraction, + llmRequestEvent: LlmRequestEvent?, + agentProcess: AgentProcess?, + action: Action?, + ): Flux { + // Build prompt contributions + val promptContributions = buildPromptContributions(interaction) + + // Guardrails: Pre-validation of user input + val userMessages = messages.filterIsInstance() + validateUserInput(userMessages, interaction, llmRequestEvent?.agentProcess?.blackboard) + + // Resolve and decorate tools + val tools = resolveTools(interaction, agentProcess, action) + + // Build messages with contributions + val messagesWithContributions = buildMessagesWithContributions(messages, promptContributions) + + // Stream raw chunks from LLM + return messageStreamer.stream(messagesWithContributions, tools) + } + + override fun doTransformObjectStream( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + llmRequestEvent: LlmRequestEvent?, + agentProcess: AgentProcess?, + action: Action?, + ): Flux { + return doTransformObjectStreamInternal( + messages = messages, + interaction = interaction, + outputClass = outputClass, + llmRequestEvent = llmRequestEvent, + agentProcess = agentProcess, + action = action, + ) + .filter { it.isObject() } + .map { (it as StreamingEvent.Object).item } + } + + override fun doTransformObjectStreamWithThinking( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + llmRequestEvent: LlmRequestEvent?, + agentProcess: AgentProcess?, + action: Action?, + ): Flux> { + return doTransformObjectStreamInternal( + messages = messages, + interaction = interaction, + outputClass = outputClass, + llmRequestEvent = llmRequestEvent, + agentProcess = agentProcess, + action = action, + ) + } + + // ======================================== + // Internal 
implementation + // ======================================== + + /** + * Internal unified streaming implementation that handles the complete transformation pipeline. + * + * Pipeline: + * 1. Raw LLM chunks from [LlmMessageStreamer] + * 2. Line buffering via [rawChunksToLines] + * 3. Event generation via [StreamingJacksonOutputConverter] + */ + private fun doTransformObjectStreamInternal( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + @Suppress("UNUSED_PARAMETER") + llmRequestEvent: LlmRequestEvent?, + agentProcess: AgentProcess?, + action: Action?, + ): Flux> { + // Create converter for JSONL parsing + val streamingConverter = StreamingJacksonOutputConverter( + clazz = outputClass, + objectMapper = objectMapper, + fieldFilter = interaction.fieldFilter + ) + + // Build prompt contributions with streaming format instructions + val promptContributions = buildPromptContributions(interaction) + val streamingFormatInstructions = streamingConverter.getFormat() + logger.debug("STREAMING FORMAT INSTRUCTIONS: $streamingFormatInstructions") + val fullPromptContributions = if (promptContributions.isNotEmpty()) { + "$promptContributions$PROMPT_ELEMENT_SEPARATOR$streamingFormatInstructions" + } else { + streamingFormatInstructions + } + + // Guardrails: Pre-validation of user input + val userMessages = messages.filterIsInstance() + validateUserInput(userMessages, interaction, llmRequestEvent?.agentProcess?.blackboard) + + // Resolve and decorate tools + val tools = resolveTools(interaction, agentProcess, action) + + // Build messages with contributions + val messagesWithContributions = buildMessagesWithContributions(messages, fullPromptContributions) + + // Step 1: Raw chunk stream from LLM + val rawChunkFlux: Flux = messageStreamer.stream(messagesWithContributions, tools) + .filter { it.isNotEmpty() } + .doOnNext { chunk -> logger.trace("RAW CHUNK: '${chunk.replace("\n", "\\n")}'") } + + // Step 2: Transform raw chunks to complete newline-delimited 
lines + val lineFlux: Flux = rawChunkFlux + .transform { chunkFlux -> rawChunksToLines(chunkFlux) } + .doOnNext { line -> logger.trace("COMPLETE LINE: '$line'") } + + // Step 3: Convert lines to StreamingEvent (thinking + objects) + return lineFlux.concatMap { line -> streamingConverter.convertStreamWithThinking(line) } + } + + // ======================================== + // Helper methods + // ======================================== + + /** + * Build prompt contributions string from interaction and LLM contributors. + */ + private fun buildPromptContributions(interaction: LlmInteraction): String = + buildPromptContributionsString(interaction.promptContributors, llmService.promptContributors) + + /** + * Build message list with prompt contributions and any existing system messages + * consolidated into a single system message at the beginning. + */ + private fun buildMessagesWithContributions( + messages: List, + promptContributions: String, + ): List = buildConsolidatedPromptMessages(messages, promptContributions) + + /** + * Resolve and decorate tools using [ToolResolutionHelper]. + */ + private fun resolveTools( + interaction: LlmInteraction, + agentProcess: AgentProcess?, + action: Action?, + ): List { + return ToolResolutionHelper.resolveAndDecorate(interaction, agentProcess, action, toolDecorator) + } + + /** + * Convert raw streaming chunks to NDJSON lines. + * Handles all cases: multiple \n in one chunk, no \n in chunk, line spanning many chunks. 
+ */ + private fun rawChunksToLines(raw: Flux): Flux { + val buffer = StringBuilder() + return raw.concatMap { chunk -> + buffer.append(chunk) + val lines = mutableListOf() + while (true) { + val idx = buffer.indexOf('\n') + if (idx < 0) break + val line = buffer.substring(0, idx).trim() + if (line.isNotEmpty()) lines.add(line) + buffer.delete(0, idx + 1) + } + Flux.fromIterable(lines) + }.doOnComplete { + if (buffer.isNotEmpty()) { + val finalLine = buffer.toString().trim() + if (finalLine.isNotEmpty()) { + logger.trace("FINAL LINE: '$finalLine'") + } + } + }.concatWith( + Mono.fromSupplier { buffer.toString().trim() } + .filter { it.isNotEmpty() } + ) + } +} diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerStreamingTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerStreamingTest.kt index ccd76ad32..bac593086 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerStreamingTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerStreamingTest.kt @@ -21,7 +21,7 @@ import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.Operation import com.embabel.agent.core.internal.LlmOperations import com.embabel.agent.core.support.LlmInteraction -import com.embabel.agent.spi.streaming.StreamingLlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperations import com.embabel.agent.spi.support.springai.ChatClientLlmOperations import com.embabel.agent.spi.support.springai.SpringAiLlmService import com.embabel.agent.test.integration.DummyObjectCreatingLlmOperations @@ -67,11 +67,15 @@ class OperationContextPromptRunnerStreamingTest { val mockLlm = mockk { every { chatModel } returns mockStreamingChatModel + every { provider } returns "test-provider" + every { name } returns "test-model" + every { supportsStreaming() } returns true } val 
mockChatClientLlmOperations = mockk(moreInterfaces = arrayOf(StreamingLlmOperations::class), relaxed = true) { every { getLlm(any()) } returns mockLlm + every { supportsStreaming(any()) } returns true } val mockAgentPlatform = mockk() diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt index 78127602e..0d7780ce3 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt @@ -19,7 +19,7 @@ import com.embabel.agent.api.common.* import com.embabel.agent.api.common.nested.support.PromptRunnerCreating import com.embabel.agent.api.common.nested.support.PromptRunnerRendering import com.embabel.agent.api.common.streaming.StreamingPromptRunner -import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector +import com.embabel.agent.spi.support.streaming.StreamingCapabilityDetector import com.embabel.agent.api.common.support.streaming.StreamingImpl import com.embabel.agent.api.common.thinking.support.ThinkingPromptRunnerOperationsImpl import com.embabel.agent.api.tool.Tool diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/streaming/StreamingImpl.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/streaming/StreamingImpl.kt index 14679ee00..9ba00b65a 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/streaming/StreamingImpl.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/streaming/StreamingImpl.kt @@ -19,7 +19,7 @@ import com.embabel.agent.api.common.streaming.StreamingPromptRunner import com.embabel.agent.core.Action import com.embabel.agent.core.AgentProcess import 
com.embabel.agent.core.support.LlmInteraction -import com.embabel.agent.spi.streaming.StreamingLlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperations import com.embabel.chat.Message import com.embabel.chat.UserMessage import com.embabel.common.core.streaming.StreamingEvent diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/streaming/StreamingImplTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/streaming/StreamingImplTest.kt index 5fbb32ee3..c779da071 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/streaming/StreamingImplTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/support/streaming/StreamingImplTest.kt @@ -18,7 +18,7 @@ package com.embabel.agent.api.common.support.streaming import com.embabel.agent.core.Action import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.support.LlmInteraction -import com.embabel.agent.spi.streaming.StreamingLlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperations import com.embabel.chat.Message import com.embabel.chat.UserMessage import com.embabel.common.core.streaming.StreamingEvent diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/e2e/LLMStreamingIntegrationTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/e2e/LLMStreamingIntegrationTest.kt index b700128f7..e05e44fce 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/e2e/LLMStreamingIntegrationTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/e2e/LLMStreamingIntegrationTest.kt @@ -21,7 +21,6 @@ import com.embabel.agent.api.common.Ai import com.embabel.agent.api.common.autonomy.Autonomy import com.embabel.agent.api.common.streaming.StreamingPromptRunner import com.embabel.agent.api.common.streaming.asStreaming -import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector import 
com.embabel.agent.core.AgentPlatform import com.embabel.agent.core.internal.LlmOperations import com.embabel.agent.spi.LlmService @@ -176,10 +175,14 @@ class LLMStreamingIntegrationTest( @Test fun `test streaming capability detection`() { - // Direct test of StreamingCapabilityDetector + // Direct test of LlmService.supportsStreaming() val fakeStreamingModel = FakeStreamingChatModel("test response") - val detector = StreamingCapabilityDetector - val supportsStreaming = detector.supportsStreaming(fakeStreamingModel) + val llmService = SpringAiLlmService( + name = "test-streaming", + chatModel = fakeStreamingModel, + provider = "test", + ) + val supportsStreaming = llmService.supportsStreaming() assertTrue(supportsStreaming, "FakeStreamingChatModel should be detected as supporting streaming") diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/AbstractLlmOperationsStreamingTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/AbstractLlmOperationsStreamingTest.kt new file mode 100644 index 000000000..b84f0205f --- /dev/null +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/AbstractLlmOperationsStreamingTest.kt @@ -0,0 +1,160 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.embabel.agent.spi.support + +import com.embabel.agent.api.common.Asyncer +import com.embabel.agent.api.event.LlmRequestEvent +import com.embabel.agent.core.support.LlmInteraction +import com.embabel.agent.spi.AutoLlmSelectionCriteriaResolver +import com.embabel.agent.spi.LlmService +import com.embabel.agent.spi.ToolDecorator +import com.embabel.agent.spi.loop.streaming.LlmMessageStreamer +import com.embabel.agent.spi.support.streaming.StreamingLlmOperationsImpl +import com.embabel.agent.spi.validation.DefaultValidationPromptGenerator +import com.embabel.chat.Message +import com.embabel.common.ai.model.LlmOptions +import com.embabel.common.ai.model.ModelProvider +import com.embabel.common.ai.model.ModelSelectionCriteria +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule +import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import jakarta.validation.Validation +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import kotlin.test.assertIs + +/** + * Tests for [AbstractLlmOperations.createStreamingOperations]. 
+ */ +class AbstractLlmOperationsStreamingTest { + + private lateinit var mockModelProvider: ModelProvider + private lateinit var mockToolDecorator: ToolDecorator + private lateinit var mockLlmService: LlmService<*> + private lateinit var mockMessageStreamer: LlmMessageStreamer + private lateinit var mockAsyncer: Asyncer + private val objectMapper: ObjectMapper = jacksonObjectMapper().registerModule(JavaTimeModule()) + private val validator = Validation.buildDefaultValidatorFactory().validator + + private lateinit var llmOperations: TestableAbstractLlmOperations + + @BeforeEach + fun setup() { + mockModelProvider = mockk() + mockToolDecorator = mockk() + mockLlmService = mockk() + mockMessageStreamer = mockk() + mockAsyncer = mockk() + + every { mockModelProvider.getLlm(any()) } returns mockLlmService + every { mockLlmService.createMessageStreamer(any()) } returns mockMessageStreamer + + llmOperations = TestableAbstractLlmOperations( + toolDecorator = mockToolDecorator, + modelProvider = mockModelProvider, + validator = validator, + objectMapper = objectMapper, + asyncer = mockAsyncer, + ) + } + + @Test + fun `createStreamingOperations returns StreamingLlmOperationsImpl`() { + val options = LlmOptions() + + val result = llmOperations.createStreamingOperations(options) + + assertIs(result) + } + + @Test + fun `createStreamingOperations calls chooseLlm with correct options`() { + val options = LlmOptions() + + llmOperations.createStreamingOperations(options) + + verify { mockModelProvider.getLlm(any()) } + } + + @Test + fun `createStreamingOperations calls createMessageStreamer on LlmService`() { + val options = LlmOptions() + + llmOperations.createStreamingOperations(options) + + verify { mockLlmService.createMessageStreamer(options) } + } + + /** + * Testable concrete implementation of AbstractLlmOperations. 
+ */ + private class TestableAbstractLlmOperations( + toolDecorator: ToolDecorator, + modelProvider: ModelProvider, + validator: jakarta.validation.Validator, + objectMapper: ObjectMapper, + asyncer: Asyncer, + ) : AbstractLlmOperations( + toolDecorator = toolDecorator, + modelProvider = modelProvider, + validator = validator, + validationPromptGenerator = DefaultValidationPromptGenerator(), + autoLlmSelectionCriteriaResolver = AutoLlmSelectionCriteriaResolver.DEFAULT, + dataBindingProperties = LlmDataBindingProperties(), + promptsProperties = LlmOperationsPromptsProperties(), + asyncer = asyncer, + objectMapper = objectMapper, + ) { + override fun doTransformIfPossible( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + llmRequestEvent: LlmRequestEvent, + ): Result { + throw UnsupportedOperationException("Not needed for streaming tests") + } + + override fun doTransform( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + llmRequestEvent: LlmRequestEvent?, + ): O { + throw UnsupportedOperationException("Not needed for streaming tests") + } + + override fun doTransformWithThinking( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + llmRequestEvent: LlmRequestEvent?, + ): com.embabel.common.core.thinking.ThinkingResponse { + throw UnsupportedOperationException("Not needed for streaming tests") + } + + override fun doTransformWithThinkingIfPossible( + messages: List, + interaction: LlmInteraction, + outputClass: Class, + llmRequestEvent: LlmRequestEvent?, + ): Result> { + throw UnsupportedOperationException("Not needed for streaming tests") + } + } +} diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/MessagePromptBuildersTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/MessagePromptBuildersTest.kt new file mode 100644 index 000000000..e26ac81d8 --- /dev/null +++ 
b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/MessagePromptBuildersTest.kt @@ -0,0 +1,244 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.spi.support + +import com.embabel.chat.AssistantMessage +import com.embabel.chat.SystemMessage +import com.embabel.chat.UserMessage +import com.embabel.common.ai.prompt.PromptContributor +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +/** + * Tests for message prompt builder helper functions. 
+ */ +class MessagePromptBuildersTest { + + // ======================================== + // buildPromptContributionsString tests + // ======================================== + + @Test + fun `buildPromptContributionsString with empty lists returns empty string`() { + val result = buildPromptContributionsString(emptyList(), emptyList()) + assertEquals("", result) + } + + @Test + fun `buildPromptContributionsString joins contributors with separator`() { + val interactionContributors = listOf( + testPromptContributor("interaction1"), + testPromptContributor("interaction2") + ) + val llmContributors = listOf( + testPromptContributor("llm1") + ) + + val result = buildPromptContributionsString(interactionContributors, llmContributors) + + assertEquals("interaction1\n----\ninteraction2\n----\nllm1", result) + } + + @Test + fun `buildPromptContributionsString with only interaction contributors`() { + val interactionContributors = listOf(testPromptContributor("only-interaction")) + + val result = buildPromptContributionsString(interactionContributors, emptyList()) + + assertEquals("only-interaction", result) + } + + @Test + fun `buildPromptContributionsString with only llm contributors`() { + val llmContributors = listOf(testPromptContributor("only-llm")) + + val result = buildPromptContributionsString(emptyList(), llmContributors) + + assertEquals("only-llm", result) + } + + private fun testPromptContributor(content: String): PromptContributor = object : PromptContributor { + override fun contribution(): String = content + } + + // ======================================== + // partitionMessages tests + // ======================================== + + @Test + fun `partitionMessages with empty list returns empty results`() { + val (systemContent, nonSystemMessages) = partitionMessages(emptyList()) + + assertTrue(systemContent.isEmpty()) + assertTrue(nonSystemMessages.isEmpty()) + } + + @Test + fun `partitionMessages separates system messages from others`() { + val messages = 
listOf( + SystemMessage("system1"), + UserMessage("user1"), + SystemMessage("system2"), + AssistantMessage("assistant1"), + UserMessage("user2") + ) + + val (systemContent, nonSystemMessages) = partitionMessages(messages) + + assertEquals(listOf("system1", "system2"), systemContent) + assertEquals(3, nonSystemMessages.size) + assertTrue(nonSystemMessages[0] is UserMessage) + assertTrue(nonSystemMessages[1] is AssistantMessage) + assertTrue(nonSystemMessages[2] is UserMessage) + } + + @Test + fun `partitionMessages with only system messages`() { + val messages = listOf( + SystemMessage("system1"), + SystemMessage("system2") + ) + + val (systemContent, nonSystemMessages) = partitionMessages(messages) + + assertEquals(listOf("system1", "system2"), systemContent) + assertTrue(nonSystemMessages.isEmpty()) + } + + @Test + fun `partitionMessages with no system messages`() { + val messages = listOf( + UserMessage("user1"), + AssistantMessage("assistant1") + ) + + val (systemContent, nonSystemMessages) = partitionMessages(messages) + + assertTrue(systemContent.isEmpty()) + assertEquals(2, nonSystemMessages.size) + } + + // ======================================== + // buildConsolidatedSystemMessage tests + // ======================================== + + @Test + fun `buildConsolidatedSystemMessage with empty contents returns empty string`() { + val result = buildConsolidatedSystemMessage() + assertEquals("", result) + } + + @Test + fun `buildConsolidatedSystemMessage filters empty strings`() { + val result = buildConsolidatedSystemMessage("content1", "", "content2", "") + assertEquals("content1\n\ncontent2", result) + } + + @Test + fun `buildConsolidatedSystemMessage joins with double newlines`() { + val result = buildConsolidatedSystemMessage("first", "second", "third") + assertEquals("first\n\nsecond\n\nthird", result) + } + + @Test + fun `buildConsolidatedSystemMessage with single content`() { + val result = buildConsolidatedSystemMessage("only-content") + 
assertEquals("only-content", result) + } + + @Test + fun `buildConsolidatedSystemMessage with all empty strings returns empty`() { + val result = buildConsolidatedSystemMessage("", "", "") + assertEquals("", result) + } + + // ======================================== + // buildConsolidatedPromptMessages tests + // ======================================== + + @Test + fun `buildConsolidatedPromptMessages with empty messages and contributions`() { + val result = buildConsolidatedPromptMessages(emptyList(), "") + + assertTrue(result.isEmpty()) + } + + @Test + fun `buildConsolidatedPromptMessages consolidates system messages at beginning`() { + val messages = listOf( + UserMessage("user1"), + SystemMessage("system1"), + AssistantMessage("assistant1"), + SystemMessage("system2") + ) + + val result = buildConsolidatedPromptMessages(messages, "contributions") + + // 1 consolidated system + 2 non-system messages = 3 + assertEquals(3, result.size) + assertTrue(result[0] is SystemMessage) + assertEquals("contributions\n\nsystem1\n\nsystem2", (result[0] as SystemMessage).content) + assertTrue(result[1] is UserMessage) + assertTrue(result[2] is AssistantMessage) + } + + @Test + fun `buildConsolidatedPromptMessages with only prompt contributions`() { + val messages = listOf( + UserMessage("user1"), + AssistantMessage("assistant1") + ) + + val result = buildConsolidatedPromptMessages(messages, "prompt-contributions") + + assertEquals(3, result.size) + assertTrue(result[0] is SystemMessage) + assertEquals("prompt-contributions", (result[0] as SystemMessage).content) + assertTrue(result[1] is UserMessage) + assertTrue(result[2] is AssistantMessage) + } + + @Test + fun `buildConsolidatedPromptMessages preserves message order for non-system messages`() { + val messages = listOf( + UserMessage("first"), + AssistantMessage("second"), + UserMessage("third") + ) + + val result = buildConsolidatedPromptMessages(messages, "sys") + + assertEquals(4, result.size) + assertEquals("first", 
(result[1] as UserMessage).content) + assertEquals("second", (result[2] as AssistantMessage).content) + assertEquals("third", (result[3] as UserMessage).content) + } + + @Test + fun `buildConsolidatedPromptMessages with empty contributions and system messages`() { + val messages = listOf( + SystemMessage("system-only"), + UserMessage("user1") + ) + + val result = buildConsolidatedPromptMessages(messages, "") + + assertEquals(2, result.size) + assertTrue(result[0] is SystemMessage) + assertEquals("system-only", (result[0] as SystemMessage).content) + } +} diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/SpringAiLlmServiceTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/SpringAiLlmServiceTest.kt index d6310ca8a..efcfd924d 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/SpringAiLlmServiceTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/SpringAiLlmServiceTest.kt @@ -15,6 +15,7 @@ */ package com.embabel.agent.spi.support.springai +import com.embabel.agent.spi.support.springai.streaming.SpringAiLlmMessageStreamer import com.embabel.common.ai.model.DefaultOptionsConverter import com.embabel.common.ai.model.LlmOptions import com.embabel.common.ai.model.OptionsConverter @@ -254,6 +255,48 @@ class SpringAiLlmServiceTest { } } + @Nested + inner class CreateMessageStreamerTests { + + // Relaxed mock needed because ChatClient.create() calls chatModel.getDefaultOptions() + private val relaxedChatModel: ChatModel = mockk(relaxed = true) + + @Test + fun `createMessageStreamer returns SpringAiLlmMessageStreamer`() { + val service = SpringAiLlmService( + name = "test-model", + provider = "Provider", + chatModel = relaxedChatModel, + ) + val options = LlmOptions() + + val streamer = service.createMessageStreamer(options) + + assertThat(streamer).isInstanceOf(SpringAiLlmMessageStreamer::class.java) + } + + @Test + fun `createMessageStreamer 
uses optionsConverter`() { + var converterCalled = false + val customConverter = object : OptionsConverter { + override fun convertOptions(options: LlmOptions): ChatOptions { + converterCalled = true + return mockk() + } + } + val service = SpringAiLlmService( + name = "test-model", + provider = "Provider", + chatModel = relaxedChatModel, + optionsConverter = customConverter, + ) + + service.createMessageStreamer(LlmOptions()) + + assertThat(converterCalled).isTrue() + } + } + @Nested inner class ToolResponseContentAdapterTests { diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsTest.kt index 2d1c58652..f4d98cd67 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/streaming/StreamingChatClientOperationsTest.kt @@ -18,7 +18,7 @@ package com.embabel.agent.spi.support.springai.streaming import com.embabel.agent.core.Action import com.embabel.agent.core.AgentProcess import com.embabel.agent.core.support.LlmInteraction -import com.embabel.agent.spi.streaming.StreamingLlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperations import com.embabel.agent.spi.support.springai.ChatClientLlmOperations import com.embabel.agent.spi.support.springai.SpringAiLlmService import com.embabel.chat.UserMessage diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/streaming/StreamingCapabilityDetectorTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/streaming/StreamingCapabilityDetectorTest.kt new file mode 100644 index 000000000..fd7459ebd --- /dev/null +++ 
b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/streaming/StreamingCapabilityDetectorTest.kt @@ -0,0 +1,89 @@ +/* + * Copyright 2024-2026 Embabel Pty Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.spi.support.streaming + +import com.embabel.agent.core.internal.LlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperations +import com.embabel.agent.core.internal.streaming.StreamingLlmOperationsFactory +import com.embabel.common.ai.model.LlmOptions +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import org.junit.jupiter.api.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * Tests for [StreamingCapabilityDetector]. 
+ */ +class StreamingCapabilityDetectorTest { + + @Test + fun `supportsStreaming returns false when llmOperations is not StreamingLlmOperationsFactory`() { + val mockLlmOperations = mockk() + val options = LlmOptions.withModel("test-model") + + val result = StreamingCapabilityDetector.supportsStreaming(mockLlmOperations, options) + + assertFalse(result) + } + + @Test + fun `supportsStreaming delegates to factory when llmOperations is StreamingLlmOperationsFactory`() { + val mockFactory = mockk() + val options = LlmOptions.withModel("streaming-model") + + every { mockFactory.supportsStreaming(options) } returns true + + val result = StreamingCapabilityDetector.supportsStreaming(mockFactory, options) + + assertTrue(result) + verify { mockFactory.supportsStreaming(options) } + } + + @Test + fun `supportsStreaming returns false when factory reports no streaming support`() { + val mockFactory = mockk() + val options = LlmOptions.withModel("non-streaming-model") + + every { mockFactory.supportsStreaming(options) } returns false + + val result = StreamingCapabilityDetector.supportsStreaming(mockFactory, options) + + assertFalse(result) + } + + @Test + fun `supportsStreaming caches result for same model`() { + val mockFactory = mockk() + val options = LlmOptions.withModel("cached-model") + + every { mockFactory.supportsStreaming(options) } returns true + + // Call twice + StreamingCapabilityDetector.supportsStreaming(mockFactory, options) + StreamingCapabilityDetector.supportsStreaming(mockFactory, options) + + // Factory should only be called once due to caching + verify(exactly = 1) { mockFactory.supportsStreaming(options) } + } + + /** + * Test interface that combines LlmOperations and StreamingLlmOperationsFactory + * for mocking purposes. + */ + private interface TestStreamingLlmOperationsFactory : LlmOperations, StreamingLlmOperationsFactory +}