Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package com.embabel.agent.api.common.support

import com.embabel.agent.api.common.*
import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector
import com.embabel.agent.api.tool.ArtifactSinkingTool
import com.embabel.agent.api.tool.Tool
import com.embabel.agent.api.tool.ToolCallContext
Expand All @@ -38,8 +37,7 @@ import com.embabel.agent.experimental.primitive.Determination
import com.embabel.agent.spi.loop.ToolChainingInjectionStrategy
import com.embabel.agent.spi.loop.ToolInjectionStrategy
import com.embabel.agent.spi.loop.ToolNotFoundPolicy
import com.embabel.agent.spi.support.springai.ChatClientLlmOperations
import com.embabel.agent.spi.support.springai.streaming.StreamingChatClientOperations
import com.embabel.agent.spi.streaming.StreamingLlmOperations
import com.embabel.chat.AssistantMessage
import com.embabel.chat.ImagePart
import com.embabel.chat.Message
Expand Down Expand Up @@ -321,12 +319,14 @@ internal data class OperationContextDelegate(

override fun supportsStreaming(): Boolean {
val llmOperations = context.agentPlatform().platformServices.llmOperations
return StreamingCapabilityDetector.supportsStreaming(llmOperations, this.llm)
// Level 1 sanity check
if (llmOperations !is StreamingLlmOperations) return false

return llmOperations.supportsStreaming(this.llm)
}

override fun generateStream(): Flux<String> {
val llmOperations = context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations
val streamingLlmOperations = StreamingChatClientOperations(llmOperations)
val streamingLlmOperations = context.agentPlatform().platformServices.llmOperations as StreamingLlmOperations

return streamingLlmOperations.generateStream(
messages = messages,
Expand All @@ -337,8 +337,7 @@ internal data class OperationContextDelegate(
}

override fun <T> createObjectStream(itemClass: Class<T>): Flux<T> {
val llmOperations = context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations
val streamingLlmOperations = StreamingChatClientOperations(llmOperations)
val streamingLlmOperations = context.agentPlatform().platformServices.llmOperations as StreamingLlmOperations

return streamingLlmOperations.createObjectStream(
messages = messages,
Expand All @@ -350,8 +349,8 @@ internal data class OperationContextDelegate(
}

override fun <T> createObjectStreamWithThinking(itemClass: Class<T>): Flux<StreamingEvent<T>> {
val llmOperations = context.agentPlatform().platformServices.llmOperations as ChatClientLlmOperations
val streamingLlmOperations = StreamingChatClientOperations(llmOperations)
val streamingLlmOperations = context.agentPlatform().platformServices.llmOperations as StreamingLlmOperations

return streamingLlmOperations.createObjectStreamWithThinking(
messages = messages,
interaction = streamingInteraction(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.embabel.agent.api.common.support.streaming
package com.embabel.agent.spi.streaming

import com.embabel.agent.api.common.InteractionId
import com.embabel.agent.api.common.support.streaming.StreamingCapabilityDetector.supportsStreaming
import com.embabel.agent.core.internal.LlmOperations
import com.embabel.agent.core.support.LlmInteraction
import com.embabel.agent.spi.support.springai.ChatClientLlmOperations
import com.embabel.agent.spi.support.springai.SpringAiLlmService
import com.embabel.common.ai.model.LlmOptions
import com.embabel.common.util.loggerFor
import org.springframework.ai.chat.messages.UserMessage
import org.springframework.ai.chat.model.ChatModel
Expand Down Expand Up @@ -63,27 +56,6 @@ internal object StreamingCapabilityDetector {
}
}

/**
* Delegates to [supportsStreaming]
*/
fun supportsStreaming(llmOperations: LlmOperations, llmOptions: LlmOptions): Boolean {

// Level 1 sanity check
if (llmOperations !is ChatClientLlmOperations) return false

// Level 2: Must have actual streaming capability
val llm = llmOperations.getLlm(
LlmInteraction( // check for circular dependency
id = InteractionId("capability-check"),
llm = llmOptions
)
)

val springAiLlm = llm as? SpringAiLlmService ?: return false
return supportsStreaming(springAiLlm.chatModel)

}

private fun testStreamingCapability(model: ChatModel): Boolean {
return try {
// Use a prompt that should generate a response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import com.embabel.agent.core.Action
import com.embabel.agent.core.AgentProcess
import com.embabel.agent.core.support.LlmInteraction
import com.embabel.chat.Message
import com.embabel.chat.UserMessage
import com.embabel.common.ai.model.LlmOptions
import com.embabel.common.core.streaming.StreamingEvent
import reactor.core.publisher.Flux

Expand All @@ -40,26 +40,12 @@ import reactor.core.publisher.Flux
interface StreamingLlmOperations {

/**
* Generate streaming text in the context of an AgentProcess.
* Returns a Flux that emits text chunks as they arrive from the LLM.
* Tests whether the given chat model actually supports streaming operations.
*
* @param prompt Prompt to generate text from
* @param interaction Llm options and tool callbacks to use, plus unique identifier
* @param agentProcess Agent process we are running within
* @param action Action we are running within if we are running within an action
* @return Flux of text chunks as they arrive from the LLM
* @param llmOptions LlmOptions to use, including the name of the model to test
* @return true if the model supports streaming, false otherwise
*/
fun generateStream(
prompt: String,
interaction: LlmInteraction,
agentProcess: AgentProcess,
action: Action?,
): Flux<String> = generateStream(
messages = listOf(UserMessage(prompt)),
interaction = interaction,
agentProcess = agentProcess,
action = action,
)
fun supportsStreaming(llmOptions: LlmOptions): Boolean

/**
* Generate streaming text from messages in the context of an AgentProcess.
Expand All @@ -76,7 +62,9 @@ interface StreamingLlmOperations {
interaction: LlmInteraction,
agentProcess: AgentProcess,
action: Action?,
): Flux<String>
): Flux<String> {
return doTransformStream(messages, interaction, null, agentProcess, action)
}

/**
* Create a streaming list of objects from JSONL response in the context of an AgentProcess.
Expand All @@ -98,7 +86,9 @@ interface StreamingLlmOperations {
outputClass: Class<O>,
agentProcess: AgentProcess,
action: Action?,
): Flux<O>
): Flux<O> {
return doTransformObjectStream(messages, interaction, outputClass, null, agentProcess, action)
}

/**
* Try to create a streaming list of objects in the context of an AgentProcess.
Expand All @@ -118,7 +108,13 @@ interface StreamingLlmOperations {
outputClass: Class<O>,
agentProcess: AgentProcess,
action: Action?,
): Flux<Result<O>>
): Flux<Result<O>> {
return createObjectStream(messages, interaction, outputClass, agentProcess, action)
.map { Result.success(it) }
.onErrorResume { throwable ->
Flux.just(Result.failure(throwable))
}
}

/**
* Create a streaming list of objects with LLM thinking content from mixed JSONL response.
Expand All @@ -141,7 +137,9 @@ interface StreamingLlmOperations {
outputClass: Class<O>,
agentProcess: AgentProcess,
action: Action?,
): Flux<StreamingEvent<O>>
): Flux<StreamingEvent<O>> {
return doTransformObjectStreamWithThinking(messages, interaction, outputClass, null, agentProcess, action)
}

/**
* Low level streaming transform with optional platform context.
Expand Down
Loading
Loading