diff --git a/README.md b/README.md index 4265c1e..61e911e 100644 --- a/README.md +++ b/README.md @@ -52,15 +52,15 @@ func main() { } ``` -Swap to xAI Grok by changing one import: +Swap to OpenAI by changing one import: ```go import "github.com/katasec/forge/provider/openai" -provider := openai.New("https://api.x.ai/v1", os.Getenv("XAI_API_KEY"), "grok-3-mini") +provider := openai.New(os.Getenv("OPENAI_API_KEY"), openai.ModelGPT54Nano) ``` -The `openai` package works with any OpenAI-compatible API (xAI, OpenAI, Together, Groq, etc.). +The `openai` package uses the OpenAI Responses API, including text and image content. Or use the xAI Responses API with built-in web search: @@ -135,6 +135,16 @@ resp, err := agent.Ask(ctx, "Hello") fmt.Println(resp.LastText()) ``` +For multimodal input, use `AskContent`: + +```go +resp, err := agent.AskContent(ctx, + forge.Text("Describe this image."), + forge.ImageURL("https://example.com/cat.png"), +) +fmt.Println(resp.LastText()) +``` + Use `AskIn` when you want to manage multiple named conversations: ```go diff --git a/_examples/calculator/main.go b/_examples/calculator/main.go index c757b0c..98c8eb4 100644 --- a/_examples/calculator/main.go +++ b/_examples/calculator/main.go @@ -43,16 +43,16 @@ func (p *MockProvider) Generate(_ context.Context, req forge.ProviderRequest) (* // First call: "LLM" decides to use the add tool. if p.calls == 1 { return &forge.ProviderResponse{ - Message: forge.Message{ + Messages: []forge.Message{{ Role: forge.RoleAssistant, - ToolCalls: []forge.ToolCall{ - { + Content: []forge.ContentBlock{ + forge.ToolCallBlock(forge.ToolCall{ ID: "call-1", Name: "add", Arguments: json.RawMessage(`{"a": 12, "b": 30}`), - }, + }), }, - }, + }}, FinishReason: forge.FinishReasonToolUse, Usage: forge.TokenUsage{InputTokens: 25, OutputTokens: 15}, }, nil @@ -62,16 +62,13 @@ func (p *MockProvider) Generate(_ context.Context, req forge.ProviderRequest) (* // Look at the last message to find the tool result. var toolResult string for _, msg := range req.Messages { - if msg.Role == forge.RoleTool && len(msg.ToolResults) > 0 { - toolResult = msg.ToolResults[0].Content + if msg.Role == forge.RoleTool && len(msg.ToolResults()) > 0 { + toolResult = msg.ToolResults()[0].Content } } return &forge.ProviderResponse{ - Message: forge.Message{ - Role: forge.RoleAssistant, - Content: fmt.Sprintf("The answer is %s!", toolResult), - }, + Messages: []forge.Message{forge.AssistantText(fmt.Sprintf("The answer is %s!", toolResult))}, FinishReason: forge.FinishReasonStop, Usage: forge.TokenUsage{InputTokens: 40, OutputTokens: 10}, }, nil diff --git a/_examples/chat-console/README.md b/_examples/chat-console/README.md index 9bb3432..d12ebe8 100644 --- a/_examples/chat-console/README.md +++ b/_examples/chat-console/README.md @@ -15,7 +15,14 @@ export ANTHROPIC_API_KEY=sk-ant-... go run . ``` -Use xAI's OpenAI-compatible endpoint: +Use OpenAI's Responses API: + +```bash +export OPENAI_API_KEY=sk-... +go run . -provider openai +``` + +Use xAI's Responses API: ```bash export XAI_API_KEY=xai-... diff --git a/_examples/chat-console/main.go b/_examples/chat-console/main.go index 1a272a7..70c4a21 100644 --- a/_examples/chat-console/main.go +++ b/_examples/chat-console/main.go @@ -42,13 +42,16 @@ func buildProvider(name string) (forge.Provider, func()) { return anthropic.New(key, "claude-sonnet-4-20250514"), func() {} case "xai": key := requireEnv("XAI_API_KEY") - return openai.New("https://api.x.ai/v1", key, "grok-3-mini"), func() {} + return xai.New(key, xai.ModelGrok4FastNonReasoning), func() {} + case "openai": + key := requireEnv("OPENAI_API_KEY") + return openai.New(key, openai.ModelGPT54Nano), func() {} case "xai-search": key := requireEnv("XAI_API_KEY") provider := xai.New(key, xai.ModelGrok4FastNonReasoning, xai.WithWebSearch()) return provider, func() { printCitations(provider.LastCitations()) } default: - log.Fatalf("unknown provider %q; use anthropic, xai, or xai-search", name) + log.Fatalf("unknown provider %q; use anthropic, openai, xai, or xai-search", name) return nil, nil } } diff --git a/_examples/hello-world/main.go b/_examples/hello-world/main.go index 1bce5c2..f44a823 100644 --- a/_examples/hello-world/main.go +++ b/_examples/hello-world/main.go @@ -1,16 +1,16 @@ // Hello World is the simplest possible forge example. // // Shows how to call Claude with your Anthropic API key, and how to -// swap to xAI's Grok by changing one line. +// swap to xAI's Grok or OpenAI by changing one flag. // // Usage: // // export ANTHROPIC_API_KEY=sk-ant-... // go run . // -// # Or use xAI's OpenAI-compatible endpoint instead: -// export XAI_API_KEY=xai-... -// go run . -provider xai +// # Or use OpenAI's Responses API instead: +// export OPENAI_API_KEY=sk-... +// go run . -provider openai // // # Or use xAI Responses API with web search: // export XAI_API_KEY=xai-... @@ -31,7 +31,7 @@ import ( ) func main() { - providerFlag := flag.String("provider", "anthropic", "Provider to use: anthropic, xai, or xai-search") + providerFlag := flag.String("provider", "anthropic", "Provider to use: anthropic, openai, xai, or xai-search") flag.Parse() // Pick your provider. The agent setup below stays the same. @@ -44,12 +44,19 @@ func main() { log.Fatal("Set ANTHROPIC_API_KEY environment variable") } provider = anthropic.New(key, "claude-sonnet-4-20250514") + case "openai": + key := os.Getenv("OPENAI_API_KEY") + if key == "" { + log.Fatal("Set OPENAI_API_KEY environment variable") + } + provider = openai.New(key, openai.ModelGPT54Nano) case "xai": key := os.Getenv("XAI_API_KEY") if key == "" { log.Fatal("Set XAI_API_KEY environment variable") } - provider = openai.New("https://api.x.ai/v1", key, "grok-3-mini") + xaiProvider = xai.New(key, xai.ModelGrok4FastNonReasoning) + provider = xaiProvider case "xai-search": key := os.Getenv("XAI_API_KEY") if key == "" { @@ -58,7 +65,7 @@ func main() { xaiProvider = xai.New(key, xai.ModelGrok4FastNonReasoning, xai.WithWebSearch()) provider = xaiProvider default: - log.Fatalf("Unknown provider: %s (use 'anthropic', 'xai', or 'xai-search')", *providerFlag) + log.Fatalf("Unknown provider: %s (use 'anthropic', 'openai', 'xai', or 'xai-search')", *providerFlag) } // Build the agent: provider, prompt, and runtime behavior live in Config. diff --git a/agent.go b/agent.go index 32fb5f4..8ef0386 100644 --- a/agent.go +++ b/agent.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/google/uuid" + "github.com/katasec/forge/message" ) // AgentRequest is the input to Agent.Run. @@ -26,8 +27,8 @@ type AgentResponse struct { func (r *AgentResponse) LastText() string { for i := len(r.Messages) - 1; i >= 0; i-- { msg := r.Messages[i] - if msg.Role == RoleAssistant && msg.Content != "" { - return msg.Content + if msg.Role == RoleAssistant && msg.Text() != "" { + return msg.Text() } } return "" @@ -97,7 +98,15 @@ func (a *Agent) Ask(ctx context.Context, prompt string) (*AgentResponse, error) func (a *Agent) AskIn(ctx context.Context, conversationID, prompt string) (*AgentResponse, error) { return a.Run(ctx, AgentRequest{ ConversationID: conversationID, - Messages: []Message{UserMessage(prompt)}, + Messages: []Message{UserText(prompt)}, + }) +} + +// AskContent sends a rich user message in the agent's default conversation. +func (a *Agent) AskContent(ctx context.Context, blocks ...ContentBlock) (*AgentResponse, error) { + return a.Run(ctx, AgentRequest{ + ConversationID: a.defaultConversationID, + Messages: []Message{UserMessage(blocks...)}, }) } @@ -147,8 +156,11 @@ func (a *Agent) Run(ctx context.Context, req AgentRequest) (*AgentResponse, erro } usage.InputTokens += resp.Usage.InputTokens + usage.CachedInputTokens += resp.Usage.CachedInputTokens usage.OutputTokens += resp.Usage.OutputTokens - messages = append(messages, resp.Message) + usage.ReasoningOutputTokens += resp.Usage.ReasoningOutputTokens + usage.TotalTokens += resp.Usage.TotalTokens + messages = append(messages, resp.Messages...) iteration++ if resp.FinishReason == FinishReasonStop { @@ -157,7 +169,18 @@ func (a *Agent) Run(ctx context.Context, req AgentRequest) (*AgentResponse, erro } // FinishReason is tool_use - execute the tool calls. - toolResults := a.executor.Execute(ctx, resp.Message.ToolCalls) + if len(resp.Messages) == 0 { + finishReason = FinishReasonError + toolErrors = append(toolErrors, ToolError{Message: "provider requested tool use without a message"}) + break + } + toolCalls := resp.Messages[len(resp.Messages)-1].ToolCalls() + if len(toolCalls) == 0 { + finishReason = FinishReasonError + toolErrors = append(toolErrors, ToolError{Message: "provider requested tool use without tool calls"}) + break + } + toolResults := a.executor.Execute(ctx, toolCalls) // Check for tool errors. hasError := false @@ -176,11 +199,7 @@ func (a *Agent) Run(ctx context.Context, req AgentRequest) (*AgentResponse, erro } // Append tool results message (even on error, for coherent history). - toolMsg := Message{ - Role: RoleTool, - ToolResults: toolResults, - } - messages = append(messages, toolMsg) + messages = append(messages, message.ToolMessage(toolResults...)) if hasError { break diff --git a/agent_test.go b/agent_test.go index 513c195..e7bade0 100644 --- a/agent_test.go +++ b/agent_test.go @@ -27,7 +27,7 @@ func (m *mockProvider) Generate(_ context.Context, _ ProviderRequest) (*Provider } // Default: stop with empty message. return &ProviderResponse{ - Message: Message{Role: RoleAssistant, Content: "default"}, + Messages: []Message{AssistantText("default")}, FinishReason: FinishReasonStop, }, nil } @@ -47,7 +47,7 @@ func (r *recordingProvider) Generate(_ context.Context, req ProviderRequest) (*P return r.responses[i], nil } return &ProviderResponse{ - Message: Message{Role: RoleAssistant, Content: "default"}, + Messages: []Message{AssistantText("default")}, FinishReason: FinishReasonStop, }, nil } @@ -117,11 +117,11 @@ func TestAgentAskPreservesDefaultConversation(t *testing.T) { provider := &recordingProvider{ responses: []*ProviderResponse{ { - Message: Message{Role: RoleAssistant, Content: "hello"}, + Messages: []Message{AssistantText("hello")}, FinishReason: FinishReasonStop, }, { - Message: Message{Role: RoleAssistant, Content: "I remember"}, + Messages: []Message{AssistantText("I remember")}, FinishReason: FinishReasonStop, }, }, @@ -153,17 +153,17 @@ func TestAgentAskPreservesDefaultConversation(t *testing.T) { if len(provider.requests[1].Messages) != 3 { t.Fatalf("second request messages = %d, want 3", len(provider.requests[1].Messages)) } - if provider.requests[1].Messages[0].Content != "My name is Ameer." { - t.Errorf("first remembered message = %q", provider.requests[1].Messages[0].Content) + if provider.requests[1].Messages[0].Text() != "My name is Ameer." { + t.Errorf("first remembered message = %q", provider.requests[1].Messages[0].Text()) } } func TestAgentAskInUsesNamedConversations(t *testing.T) { provider := &recordingProvider{ responses: []*ProviderResponse{ - {Message: Message{Role: RoleAssistant, Content: "forge noted"}, FinishReason: FinishReasonStop}, - {Message: Message{Role: RoleAssistant, Content: "other noted"}, FinishReason: FinishReasonStop}, - {Message: Message{Role: RoleAssistant, Content: "forge remembered"}, FinishReason: FinishReasonStop}, + {Messages: []Message{AssistantText("forge noted")}, FinishReason: FinishReasonStop}, + {Messages: []Message{AssistantText("other noted")}, FinishReason: FinishReasonStop}, + {Messages: []Message{AssistantText("forge remembered")}, FinishReason: FinishReasonStop}, }, } @@ -189,18 +189,18 @@ func TestAgentAskInUsesNamedConversations(t *testing.T) { if len(provider.requests[2].Messages) != 3 { t.Fatalf("forge follow-up messages = %d, want 3", len(provider.requests[2].Messages)) } - if provider.requests[2].Messages[0].Content != "Remember forge." { - t.Errorf("first forge message = %q", provider.requests[2].Messages[0].Content) + if provider.requests[2].Messages[0].Text() != "Remember forge." { + t.Errorf("first forge message = %q", provider.requests[2].Messages[0].Text()) } } func TestAgentResponseLastText(t *testing.T) { resp := &AgentResponse{ Messages: []Message{ - UserMessage("hello"), - {Role: RoleAssistant, Content: "first"}, - {Role: RoleTool, ToolResults: []ToolResult{{Content: "tool result"}}}, - {Role: RoleAssistant, Content: "latest"}, + UserText("hello"), + AssistantText("first"), + {Role: RoleTool, Content: []ContentBlock{ToolResultBlock(ToolResult{Content: "tool result"})}}, + AssistantText("latest"), }, } @@ -213,7 +213,7 @@ func TestAgentRunStop(t *testing.T) { provider := &mockProvider{ responses: []*ProviderResponse{ { - Message: Message{Role: RoleAssistant, Content: "hello back"}, + Messages: []Message{AssistantText("hello back")}, FinishReason: FinishReasonStop, Usage: TokenUsage{InputTokens: 10, OutputTokens: 5}, }, @@ -222,7 +222,7 @@ func TestAgentRunStop(t *testing.T) { agent, _ := NewAgent(Config{Provider: provider}) resp, err := agent.Run(context.Background(), AgentRequest{ - Messages: []Message{{Role: RoleUser, Content: "hello"}}, + Messages: []Message{UserText("hello")}, }) if err != nil { t.Fatalf("Run error: %v", err) @@ -246,18 +246,12 @@ func TestAgentRunIterLimit(t *testing.T) { provider := &mockProvider{ responses: []*ProviderResponse{ { - Message: Message{ - Role: RoleAssistant, - ToolCalls: []ToolCall{{ID: "c1", Name: "echo", Arguments: json.RawMessage(`{"text":"hi"}`)}}, - }, + Messages: []Message{{Role: RoleAssistant, Content: []ContentBlock{ToolCallBlock(ToolCall{ID: "c1", Name: "echo", Arguments: json.RawMessage(`{"text":"hi"}`)})}}}, FinishReason: FinishReasonToolUse, }, // Would loop forever, but iter limit stops it. { - Message: Message{ - Role: RoleAssistant, - ToolCalls: []ToolCall{{ID: "c2", Name: "echo", Arguments: json.RawMessage(`{"text":"hi"}`)}}, - }, + Messages: []Message{{Role: RoleAssistant, Content: []ContentBlock{ToolCallBlock(ToolCall{ID: "c2", Name: "echo", Arguments: json.RawMessage(`{"text":"hi"}`)})}}}, FinishReason: FinishReasonToolUse, }, }, @@ -278,7 +272,7 @@ func TestAgentRunIterLimit(t *testing.T) { }) resp, err := agent.Run(context.Background(), AgentRequest{ - Messages: []Message{{Role: RoleUser, Content: "go"}}, + Messages: []Message{UserText("go")}, }) if err != nil { t.Fatalf("Run error: %v", err) @@ -295,7 +289,7 @@ func TestAgentRunProviderError(t *testing.T) { agent, _ := NewAgent(Config{Provider: provider}) _, err := agent.Run(context.Background(), AgentRequest{ - Messages: []Message{{Role: RoleUser, Content: "hello"}}, + Messages: []Message{UserText("hello")}, }) if err == nil { t.Fatal("expected error from provider") @@ -309,10 +303,7 @@ func TestAgentRunToolErrorStop(t *testing.T) { provider := &mockProvider{ responses: []*ProviderResponse{ { - Message: Message{ - Role: RoleAssistant, - ToolCalls: []ToolCall{{ID: "c1", Name: "broken", Arguments: json.RawMessage(`{}`)}}, - }, + Messages: []Message{{Role: RoleAssistant, Content: []ContentBlock{ToolCallBlock(ToolCall{ID: "c1", Name: "broken", Arguments: json.RawMessage(`{}`)})}}}, FinishReason: FinishReasonToolUse, }, }, @@ -331,7 +322,7 @@ func TestAgentRunToolErrorStop(t *testing.T) { }) resp, err := agent.Run(context.Background(), AgentRequest{ - Messages: []Message{{Role: RoleUser, Content: "go"}}, + Messages: []Message{UserText("go")}, }) if err != nil { t.Fatalf("Run error: %v (tool errors should not be fatal)", err) @@ -358,15 +349,12 @@ func TestAgentRunToolErrorContinue(t *testing.T) { provider := &mockProvider{ responses: []*ProviderResponse{ { - Message: Message{ - Role: RoleAssistant, - ToolCalls: []ToolCall{{ID: "c1", Name: "broken", Arguments: json.RawMessage(`{}`)}}, - }, + Messages: []Message{{Role: RoleAssistant, Content: []ContentBlock{ToolCallBlock(ToolCall{ID: "c1", Name: "broken", Arguments: json.RawMessage(`{}`)})}}}, FinishReason: FinishReasonToolUse, }, // After seeing the error, LLM stops. { - Message: Message{Role: RoleAssistant, Content: "I see the tool failed"}, + Messages: []Message{AssistantText("I see the tool failed")}, FinishReason: FinishReasonStop, }, }, @@ -383,7 +371,7 @@ func TestAgentRunToolErrorContinue(t *testing.T) { }) resp, err := agent.Run(context.Background(), AgentRequest{ - Messages: []Message{{Role: RoleUser, Content: "go"}}, + Messages: []Message{UserText("go")}, }) if err != nil { t.Fatalf("Run error: %v", err) @@ -405,13 +393,13 @@ func TestAgentRunWithMemory(t *testing.T) { // Pre-populate memory. store.Save(ctx, "conv-1", []Message{ - {ID: "prev-1", Role: RoleUser, Content: "earlier message"}, + UserText("earlier message"), }) provider := &mockProvider{ responses: []*ProviderResponse{ { - Message: Message{Role: RoleAssistant, Content: "I remember"}, + Messages: []Message{AssistantText("I remember")}, FinishReason: FinishReasonStop, Usage: TokenUsage{InputTokens: 20, OutputTokens: 10}, }, @@ -425,7 +413,7 @@ func TestAgentRunWithMemory(t *testing.T) { resp, err := agent.Run(ctx, AgentRequest{ ConversationID: "conv-1", - Messages: []Message{{Role: RoleUser, Content: "new message"}}, + Messages: []Message{UserText("new message")}, }) if err != nil { t.Fatalf("Run error: %v", err) @@ -435,8 +423,8 @@ func TestAgentRunWithMemory(t *testing.T) { if len(resp.Messages) != 3 { t.Fatalf("got %d messages, want 3", len(resp.Messages)) } - if resp.Messages[0].Content != "earlier message" { - t.Errorf("Messages[0] = %q, want %q", resp.Messages[0].Content, "earlier message") + if resp.Messages[0].Text() != "earlier message" { + t.Errorf("Messages[0] = %q, want %q", resp.Messages[0].Text(), "earlier message") } // Memory should be updated with all 3 messages. @@ -452,15 +440,12 @@ func TestAgentRunUsageAccumulation(t *testing.T) { provider := &mockProvider{ responses: []*ProviderResponse{ { - Message: Message{ - Role: RoleAssistant, - ToolCalls: []ToolCall{{ID: "c1", Name: "noop", Arguments: json.RawMessage(`{}`)}}, - }, + Messages: []Message{{Role: RoleAssistant, Content: []ContentBlock{ToolCallBlock(ToolCall{ID: "c1", Name: "noop", Arguments: json.RawMessage(`{}`)})}}}, FinishReason: FinishReasonToolUse, Usage: TokenUsage{InputTokens: 10, OutputTokens: 5}, }, { - Message: Message{Role: RoleAssistant, Content: "done"}, + Messages: []Message{AssistantText("done")}, FinishReason: FinishReasonStop, Usage: TokenUsage{InputTokens: 20, OutputTokens: 8}, }, @@ -477,7 +462,7 @@ func TestAgentRunUsageAccumulation(t *testing.T) { }) resp, err := agent.Run(context.Background(), AgentRequest{ - Messages: []Message{{Role: RoleUser, Content: "go"}}, + Messages: []Message{UserText("go")}, }) if err != nil { t.Fatalf("Run error: %v", err) diff --git a/docs/design/design.md b/docs/design/design.md index 137cd76..ee9a414 100644 --- a/docs/design/design.md +++ b/docs/design/design.md @@ -21,7 +21,7 @@ forge/tool/ tool interface, typed Func helper, calls, result forge/tool/registry/ tool registry implementation forge/provider/ provider interface, requests, responses, usage, finish reasons forge/provider/anthropic/ Anthropic Messages API provider -forge/provider/openai/ OpenAI-compatible provider (OpenAI, xAI, Together, Groq) +forge/provider/openai/ OpenAI Responses API provider forge/provider/xai/ xAI Responses API provider (web search, X search, citations) forge/memory/ memory store interface forge/memory/inmem/ in-memory memory store @@ -456,7 +456,7 @@ Pseudocode: 24. 25. usage.InputTokens += providerResp.Usage.InputTokens 26. usage.OutputTokens += providerResp.Usage.OutputTokens -27. append providerResp.Message to messages +27. append providerResp.Messages to messages 28. iteration++ 29. 30. if providerResp.FinishReason == FinishReasonStop: @@ -464,7 +464,7 @@ Pseudocode: 32. break LOOP 33. 34. // FinishReason is tool_use - execute the tool calls -35. toolResults = executor.Execute(ctx, providerResp.Message.ToolCalls) +35. toolResults = executor.Execute(ctx, last(providerResp.Messages).ToolCalls()) 36. 37. // Check for tool errors 38. for each result in toolResults where result.IsError: @@ -472,12 +472,12 @@ Pseudocode: 40. if errorPolicy == ErrorPolicyStop: 41. finishReason = FinishReasonError 42. // still append the tool message so the conversation is coherent -43. toolMsg = Message{Role: RoleTool, ToolResults: toolResults} +43. toolMsg = ToolMessage(toolResults...) 44. append toolMsg to messages 45. break LOOP 46. 47. // Feed results back to the LLM -48. toolMsg = Message{Role: RoleTool, ToolResults: toolResults} +48. toolMsg = ToolMessage(toolResults...) 49. append toolMsg to messages 50. 51. END LOOP diff --git a/memory/inmem/inmem.go b/memory/inmem/inmem.go index 63adfe5..8bddd59 100644 --- a/memory/inmem/inmem.go +++ b/memory/inmem/inmem.go @@ -30,9 +30,7 @@ func (s *Store) Load(_ context.Context, conversationID string) ([]message.Messag return nil, nil } - cp := make([]message.Message, len(msgs)) - copy(cp, msgs) - return cp, nil + return cloneMessages(msgs), nil } // Save replaces the entire message history for the given conversation. @@ -40,9 +38,7 @@ func (s *Store) Save(_ context.Context, conversationID string, messages []messag s.mu.Lock() defer s.mu.Unlock() - cp := make([]message.Message, len(messages)) - copy(cp, messages) - s.data[conversationID] = cp + s.data[conversationID] = cloneMessages(messages) return nil } @@ -54,3 +50,43 @@ func (s *Store) Clear(_ context.Context, conversationID string) error { delete(s.data, conversationID) return nil } + +func cloneMessages(messages []message.Message) []message.Message { + cp := make([]message.Message, len(messages)) + for i, msg := range messages { + cp[i] = msg + cp[i].Content = cloneContentBlocks(msg.Content) + } + return cp +} + +func cloneContentBlocks(blocks []message.ContentBlock) []message.ContentBlock { + cp := make([]message.ContentBlock, len(blocks)) + for i, block := range blocks { + cp[i] = block + if block.Image != nil { + image := *block.Image + if image.Data != nil { + image.Data = append([]byte(nil), image.Data...) + } + cp[i].Image = &image + } + if block.ToolCall != nil { + call := *block.ToolCall + call.Arguments = append([]byte(nil), call.Arguments...) + cp[i].ToolCall = &call + } + if block.ToolResult != nil { + result := *block.ToolResult + cp[i].ToolResult = &result + } + if block.Metadata != nil { + metadata := make(map[string]any, len(block.Metadata)) + for k, v := range block.Metadata { + metadata[k] = v + } + cp[i].Metadata = metadata + } + } + return cp +} diff --git a/memory_test.go b/memory_test.go index 10934f2..1b4fc31 100644 --- a/memory_test.go +++ b/memory_test.go @@ -21,8 +21,8 @@ func TestInMemoryStoreSaveAndLoad(t *testing.T) { ctx := context.Background() messages := []Message{ - {ID: "1", Role: RoleUser, Content: "hello"}, - {ID: "2", Role: RoleAssistant, Content: "hi"}, + {ID: "1", Role: RoleUser, Content: []ContentBlock{Text("hello")}}, + {ID: "2", Role: RoleAssistant, Content: []ContentBlock{Text("hi")}}, } if err := s.Save(ctx, "conv-1", messages); err != nil { @@ -36,8 +36,8 @@ func TestInMemoryStoreSaveAndLoad(t *testing.T) { if len(loaded) != 2 { t.Fatalf("got %d messages, want 2", len(loaded)) } - if loaded[0].Content != "hello" { - t.Errorf("loaded[0].Content = %q, want %q", loaded[0].Content, "hello") + if loaded[0].Text() != "hello" { + t.Errorf("loaded[0].Text() = %q, want %q", loaded[0].Text(), "hello") } } @@ -45,15 +45,15 @@ func TestInMemoryStoreSaveReplaces(t *testing.T) { s := NewInMemoryStore() ctx := context.Background() - s.Save(ctx, "conv-1", []Message{{ID: "1", Content: "first"}}) - s.Save(ctx, "conv-1", []Message{{ID: "2", Content: "second"}}) + s.Save(ctx, "conv-1", []Message{{ID: "1", Content: []ContentBlock{Text("first")}}}) + s.Save(ctx, "conv-1", []Message{{ID: "2", Content: []ContentBlock{Text("second")}}}) loaded, _ := s.Load(ctx, "conv-1") if len(loaded) != 1 { t.Fatalf("got %d messages, want 1", len(loaded)) } - if loaded[0].Content != "second" { - t.Errorf("Content = %q, want %q", loaded[0].Content, "second") + if loaded[0].Text() != "second" { + t.Errorf("Content = %q, want %q", loaded[0].Text(), "second") } } @@ -61,7 +61,7 @@ func TestInMemoryStoreClear(t *testing.T) { s := NewInMemoryStore() ctx := context.Background() - s.Save(ctx, "conv-1", []Message{{ID: "1", Content: "hello"}}) + s.Save(ctx, "conv-1", []Message{{ID: "1", Content: []ContentBlock{Text("hello")}}}) if err := s.Clear(ctx, "conv-1"); err != nil { t.Fatalf("Clear error: %v", err) @@ -77,13 +77,13 @@ func TestInMemoryStoreReturnsCopy(t *testing.T) { s := NewInMemoryStore() ctx := context.Background() - s.Save(ctx, "conv-1", []Message{{ID: "1", Content: "original"}}) + s.Save(ctx, "conv-1", []Message{{ID: "1", Content: []ContentBlock{Text("original")}}}) loaded, _ := s.Load(ctx, "conv-1") - loaded[0].Content = "mutated" + loaded[0].Content[0].Text = "mutated" reloaded, _ := s.Load(ctx, "conv-1") - if reloaded[0].Content != "original" { - t.Errorf("Content = %q, want %q (store should return copies)", reloaded[0].Content, "original") + if reloaded[0].Text() != "original" { + t.Errorf("Content = %q, want %q (store should return copies)", reloaded[0].Text(), "original") } } diff --git a/message/message.go b/message/message.go index a83e39c..5298f69 100644 --- a/message/message.go +++ b/message/message.go @@ -1,6 +1,10 @@ package message -import "github.com/katasec/forge/tool" +import ( + "strings" + + "github.com/katasec/forge/tool" +) // Role identifies the sender of a message in a conversation. type Role string @@ -12,16 +16,118 @@ const ( RoleSystem Role = "system" ) +// ContentType identifies the kind of content inside a message. +type ContentType string + +const ( + ContentTypeText ContentType = "text" + ContentTypeImage ContentType = "image" + ContentTypeToolCall ContentType = "tool_call" + ContentTypeToolResult ContentType = "tool_result" +) + +// ImageContent represents image input for multimodal providers. +type ImageContent struct { + URL string `json:"url,omitempty"` + MediaType string `json:"media_type,omitempty"` + Data []byte `json:"data,omitempty"` +} + +// ContentBlock is one typed unit of message content. +type ContentBlock struct { + Type ContentType `json:"type"` + Text string `json:"text,omitempty"` + Image *ImageContent `json:"image,omitempty"` + ToolCall *tool.Call `json:"tool_call,omitempty"` + ToolResult *tool.Result `json:"tool_result,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// Text creates a text content block. +func Text(content string) ContentBlock { + return ContentBlock{Type: ContentTypeText, Text: content} +} + +// ImageURL creates an image content block backed by a URL. +func ImageURL(url string) ContentBlock { + return ContentBlock{Type: ContentTypeImage, Image: &ImageContent{URL: url}} +} + +// ImageBytes creates an image content block backed by bytes. +func ImageBytes(data []byte, mediaType string) ContentBlock { + return ContentBlock{Type: ContentTypeImage, Image: &ImageContent{Data: data, MediaType: mediaType}} +} + +// ToolCall creates a tool-call content block. +func ToolCall(call tool.Call) ContentBlock { + return ContentBlock{Type: ContentTypeToolCall, ToolCall: &call} +} + +// ToolResult creates a tool-result content block. +func ToolResult(result tool.Result) ContentBlock { + return ContentBlock{Type: ContentTypeToolResult, ToolResult: &result} +} + // Message represents a single message in a conversation. type Message struct { - ID string `json:"id"` - Role Role `json:"role"` - Content string `json:"content"` - ToolCalls []tool.Call `json:"tool_calls,omitempty"` - ToolResults []tool.Result `json:"tool_results,omitempty"` + ID string `json:"id,omitempty"` + Role Role `json:"role"` + Content []ContentBlock `json:"content,omitempty"` +} + +// UserMessage creates a user-role message with the given content blocks. +func UserMessage(blocks ...ContentBlock) Message { + return Message{Role: RoleUser, Content: blocks} +} + +// UserText creates a user-role text message. +func UserText(content string) Message { + return UserMessage(Text(content)) +} + +// AssistantText creates an assistant-role text message. +func AssistantText(content string) Message { + return Message{Role: RoleAssistant, Content: []ContentBlock{Text(content)}} +} + +// ToolMessage creates a tool-role message with tool results. +func ToolMessage(results ...tool.Result) Message { + blocks := make([]ContentBlock, 0, len(results)) + for _, result := range results { + blocks = append(blocks, ToolResult(result)) + } + return Message{Role: RoleTool, Content: blocks} +} + +// Text returns all text blocks joined together. +func (m Message) Text() string { + var parts []string + for _, block := range m.Content { + if block.Type == ContentTypeText && block.Text != "" { + parts = append(parts, block.Text) + } + } + return strings.Join(parts, "") +} + +// ToolCalls returns all tool-call blocks in the message. +func (m Message) ToolCalls() []tool.Call { + var calls []tool.Call + for _, block := range m.Content { + if block.Type == ContentTypeToolCall && block.ToolCall != nil { + calls = append(calls, *block.ToolCall) + } + } + return calls } -// UserMessage creates a user-role message with the given content. -func UserMessage(content string) Message { - return Message{Role: RoleUser, Content: content} +// ToolResults returns all tool-result blocks in the message. +func (m Message) ToolResults() []tool.Result { + var results []tool.Result + for _, block := range m.Content { + if block.Type == ContentTypeToolResult && block.ToolResult != nil { + results = append(results, *block.ToolResult) + } + } + return results } diff --git a/middleware_test.go b/middleware_test.go index 963f90d..3d4f8bd 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -16,7 +16,7 @@ func TestSingleMiddleware(t *testing.T) { inner := RunFunc(func(_ context.Context, _ ProviderRequest) (*ProviderResponse, error) { return &ProviderResponse{ - Message: Message{Role: RoleAssistant, Content: "ok"}, + Messages: []Message{AssistantText("ok")}, FinishReason: FinishReasonStop, }, nil }) @@ -29,8 +29,8 @@ func TestSingleMiddleware(t *testing.T) { if !called { t.Error("middleware was not called") } - if resp.Message.Content != "ok" { - t.Errorf("Content = %q, want %q", resp.Message.Content, "ok") + if resp.Messages[0].Text() != "ok" { + t.Errorf("Content = %q, want %q", resp.Messages[0].Text(), "ok") } } @@ -53,7 +53,7 @@ func TestMiddlewareCompositionOrder(t *testing.T) { inner := RunFunc(func(_ context.Context, _ ProviderRequest) (*ProviderResponse, error) { order = append(order, "provider") return &ProviderResponse{ - Message: Message{Role: RoleAssistant, Content: "ok"}, + Messages: []Message{AssistantText("ok")}, FinishReason: FinishReasonStop, }, nil }) diff --git a/provider/anthropic/anthropic.go b/provider/anthropic/anthropic.go index 735a183..ed073b8 100644 --- a/provider/anthropic/anthropic.go +++ b/provider/anthropic/anthropic.go @@ -28,6 +28,14 @@ func New(apiKey, model string) *Provider { } } +// Capabilities describes the Anthropic provider features Forge currently supports. +func (p *Provider) Capabilities() forge.Capabilities { + return forge.Capabilities{ + Usage: true, + Production: true, + } +} + // --- Anthropic API request/response types --- type request struct { @@ -68,7 +76,7 @@ func (p *Provider) Generate(ctx context.Context, req forge.ProviderRequest) (*fo } msgs = append(msgs, message{ Role: string(m.Role), - Content: m.Content, + Content: m.Text(), }) } @@ -127,10 +135,7 @@ func (p *Provider) Generate(ctx context.Context, req forge.ProviderRequest) (*fo } return &forge.ProviderResponse{ - Message: forge.Message{ - Role: forge.RoleAssistant, - Content: textContent, - }, + Messages: []forge.Message{forge.AssistantText(textContent)}, FinishReason: finishReason, Usage: forge.TokenUsage{ InputTokens: apiResp.Usage.InputTokens, diff --git a/provider/anthropic/anthropic_test.go b/provider/anthropic/anthropic_test.go index 8609e6e..7bcf988 100644 --- a/provider/anthropic/anthropic_test.go +++ b/provider/anthropic/anthropic_test.go @@ -74,18 +74,18 @@ func TestGenerate(t *testing.T) { resp, err := p.Generate(context.Background(), forge.ProviderRequest{ SystemPrompt: "You are helpful.", Messages: []forge.Message{ - {Role: forge.RoleUser, Content: "Hi"}, + forge.UserText("Hi"), }, }) if err != nil { t.Fatalf("Generate: %v", err) } - if resp.Message.Content != "Hello!" { - t.Errorf("content = %q, want %q", resp.Message.Content, "Hello!") + if resp.Messages[0].Text() != "Hello!" { + t.Errorf("content = %q, want %q", resp.Messages[0].Text(), "Hello!") } - if resp.Message.Role != forge.RoleAssistant { - t.Errorf("role = %q, want %q", resp.Message.Role, forge.RoleAssistant) + if resp.Messages[0].Role != forge.RoleAssistant { + t.Errorf("role = %q, want %q", resp.Messages[0].Role, forge.RoleAssistant) } if resp.FinishReason != forge.FinishReasonStop { t.Errorf("finishReason = %q, want %q", resp.FinishReason, forge.FinishReasonStop) @@ -112,7 +112,7 @@ func TestGenerateAPIError(t *testing.T) { } _, err := p.Generate(context.Background(), forge.ProviderRequest{ - Messages: []forge.Message{{Role: forge.RoleUser, Content: "Hi"}}, + Messages: []forge.Message{forge.UserText("Hi")}, }) if err == nil { t.Fatal("expected error for 401 response") diff --git a/provider/openai/model.go b/provider/openai/model.go new file mode 100644 index 0000000..bdbc173 --- /dev/null +++ b/provider/openai/model.go @@ -0,0 +1,9 @@ +package openai + +// Model is an OpenAI model identifier. +type Model string + +const ( + // ModelGPT54Nano is OpenAI's GPT-5.4 nano model. + ModelGPT54Nano Model = "gpt-5.4-nano" +) diff --git a/provider/openai/openai.go b/provider/openai/openai.go index 237d5bf..7095962 100644 --- a/provider/openai/openai.go +++ b/provider/openai/openai.go @@ -1,11 +1,10 @@ -// Package openai implements forge.Provider using the OpenAI-compatible chat -// completions API. Works with OpenAI, xAI (Grok), Together, Groq, and any -// other provider that speaks the OpenAI format. +// Package openai implements forge.Provider using the OpenAI Responses API. package openai import ( "bytes" "context" + "encoding/base64" "encoding/json" "fmt" "io" @@ -14,7 +13,7 @@ import ( "github.com/katasec/forge" ) -// Provider implements forge.Provider using the OpenAI-compatible API. +// Provider implements forge.Provider using the OpenAI Responses API. type Provider struct { baseURL string apiKey string @@ -22,63 +21,91 @@ type Provider struct { client *http.Client } -// New creates an OpenAI-compatible provider for the given base URL, API key, and model. -func New(baseURL, apiKey, model string) *Provider { - return &Provider{ - baseURL: baseURL, +// New creates an OpenAI provider using the Responses API. +func New(apiKey string, model Model, opts ...Option) *Provider { + p := &Provider{ + baseURL: "https://api.openai.com/v1", apiKey: apiKey, - model: model, + model: string(model), client: &http.Client{}, } + for _, opt := range opts { + opt(p) + } + return p +} + +// Capabilities describes the OpenAI provider features Forge currently supports. +func (p *Provider) Capabilities() forge.Capabilities { + return forge.Capabilities{ + Images: true, + Usage: true, + Production: true, + } } -// --- OpenAI API request/response types --- +// --- OpenAI Responses API request/response types --- type request struct { - Model string `json:"model"` - Messages []message `json:"messages"` + Model string `json:"model"` + Input []inputItem `json:"input"` + Instructions string `json:"instructions,omitempty"` } -type message struct { - Role string `json:"role"` - Content string `json:"content"` +type inputItem struct { + Role string `json:"role"` + Content []contentInput `json:"content"` +} + +type contentInput struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL string `json:"image_url,omitempty"` } type response struct { - Choices []choice `json:"choices"` - Usage usage `json:"usage"` + Output []outputItem `json:"output"` + Usage usage `json:"usage"` } -type choice struct { - Message message `json:"message"` - FinishReason string `json:"finish_reason"` +type outputItem struct { + Type string `json:"type"` + Role string `json:"role,omitempty"` + Content []contentOutput `json:"content,omitempty"` +} + +type contentOutput struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` } type usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` + InputTokens int `json:"input_tokens"` + InputTokensDetails inputTokensDetails `json:"input_tokens_details"` + OutputTokens int `json:"output_tokens"` + OutputTokensDetails outputTokensDetails `json:"output_tokens_details"` + TotalTokens int `json:"total_tokens"` +} + +type inputTokensDetails struct { + CachedTokens int `json:"cached_tokens"` } -// Generate sends a request to the OpenAI-compatible chat completions endpoint. +type outputTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` +} + +// Generate sends a request to the OpenAI Responses API. func (p *Provider) Generate(ctx context.Context, req forge.ProviderRequest) (*forge.ProviderResponse, error) { - // Convert forge messages to OpenAI format. - var msgs []message - if req.SystemPrompt != "" { - msgs = append(msgs, message{Role: "system", Content: req.SystemPrompt}) - } - for _, m := range req.Messages { - if m.Role == forge.RoleSystem { - continue - } - msgs = append(msgs, message{ - Role: string(m.Role), - Content: m.Content, - }) + input, err := convertMessages(req.Messages) + if err != nil { + return nil, err } body := request{ - Model: p.model, - Messages: msgs, + Model: p.model, + Input: input, + Instructions: req.SystemPrompt, } jsonBody, err := json.Marshal(body) @@ -86,8 +113,7 @@ func (p *Provider) Generate(ctx context.Context, req forge.ProviderRequest) (*fo return nil, fmt.Errorf("marshal request: %w", err) } - url := fmt.Sprintf("%s/chat/completions", p.baseURL) - httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody)) + httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/responses", bytes.NewReader(jsonBody)) if err != nil { return nil, fmt.Errorf("create request: %w", err) } @@ -106,7 +132,7 @@ func (p *Provider) Generate(ctx context.Context, req forge.ProviderRequest) (*fo } if httpResp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API error (%d): %s", httpResp.StatusCode, string(respBody)) + return nil, fmt.Errorf("openai API error (%d): %s", httpResp.StatusCode, string(respBody)) } var apiResp response @@ -114,26 +140,114 @@ func (p *Provider) Generate(ctx context.Context, req forge.ProviderRequest) (*fo return nil, fmt.Errorf("unmarshal response: %w", err) } - if len(apiResp.Choices) == 0 { - return nil, fmt.Errorf("no choices in response") - } - - ch := apiResp.Choices[0] - - finishReason := forge.FinishReasonStop - if ch.FinishReason == "tool_calls" { - finishReason = forge.FinishReasonToolUse + messages := convertResponse(apiResp) + if len(messages) == 0 { + return nil, fmt.Errorf("no assistant messages in response") } return &forge.ProviderResponse{ - Message: forge.Message{ - Role: forge.RoleAssistant, - Content: ch.Message.Content, - }, - FinishReason: finishReason, + Messages: messages, + FinishReason: forge.FinishReasonStop, Usage: forge.TokenUsage{ - InputTokens: apiResp.Usage.PromptTokens, - OutputTokens: apiResp.Usage.CompletionTokens, + InputTokens: apiResp.Usage.InputTokens, + CachedInputTokens: apiResp.Usage.InputTokensDetails.CachedTokens, + OutputTokens: apiResp.Usage.OutputTokens, + ReasoningOutputTokens: apiResp.Usage.OutputTokensDetails.ReasoningTokens, + TotalTokens: apiResp.Usage.TotalTokens, }, }, nil } + +func convertMessages(messages []forge.Message) ([]inputItem, error) { + items := make([]inputItem, 0, len(messages)) + for _, msg := range messages { + if msg.Role == forge.RoleSystem { + continue + } + + content, err := convertContent(msg.Role, msg.Content) + if err != nil { + return nil, err + } + if len(content) == 0 { + continue + } + + items = append(items, inputItem{ + Role: string(msg.Role), + Content: content, + }) + } + return items, nil +} + +func convertContent(role forge.Role, blocks []forge.ContentBlock) ([]contentInput, error) { + content := make([]contentInput, 0, len(blocks)) + for _, block := range blocks { + switch block.Type { + case forge.ContentTypeText: + contentType := "input_text" + if role == forge.RoleAssistant { + contentType = "output_text" + } + content = append(content, contentInput{Type: contentType, Text: block.Text}) + case forge.ContentTypeImage: + if role != forge.RoleUser { + return nil, fmt.Errorf("openai image content is only supported for user messages") + } + if block.Image == nil { + return nil, fmt.Errorf("image content block missing image data") + } + imageURL, err := openAIImageURL(*block.Image) + if err != nil { + return nil, err + } + content = append(content, contentInput{Type: "input_image", ImageURL: imageURL}) + case forge.ContentTypeToolCall, forge.ContentTypeToolResult: + return nil, fmt.Errorf("openai provider does not support tool content yet") + default: + return nil, fmt.Errorf("unsupported content block type: %s", block.Type) + } + } + return content, nil +} + +func openAIImageURL(image forge.ImageContent) (string, error) { + if image.URL != "" { + return image.URL, nil + } + if len(image.Data) == 0 { + return "", fmt.Errorf("image content requires URL or data") + } + if image.MediaType == "" { + return "", fmt.Errorf("image bytes require media type") + } + encoded := base64.StdEncoding.EncodeToString(image.Data) + return fmt.Sprintf("data:%s;base64,%s", image.MediaType, encoded), nil +} + +func convertResponse(apiResp response) []forge.Message { + var messages []forge.Message + for _, item := range apiResp.Output { + if item.Type != "message" { + continue + } + + var blocks []forge.ContentBlock + for _, content := range item.Content { + if content.Type == "output_text" && content.Text != "" { + blocks = append(blocks, forge.Text(content.Text)) + } + } + if len(blocks) == 0 { + continue + } + + role := forge.RoleAssistant + if item.Role != "" { + role = forge.Role(item.Role) + } + messages = append(messages, forge.Message{Role: role, Content: blocks}) + } + return messages +} diff --git a/provider/openai/openai_test.go b/provider/openai/openai_test.go index 2ea565d..9dfd0a0 100644 --- a/provider/openai/openai_test.go +++ b/provider/openai/openai_test.go @@ -14,18 +14,18 @@ import ( var _ forge.Provider = (*Provider)(nil) func TestNew(t *testing.T) { - p := New("https://api.openai.com/v1", "test-key", "gpt-4") + p := New("test-key", ModelGPT54Nano) if p == nil { t.Fatal("New returned nil") } if p.baseURL != "https://api.openai.com/v1" { - t.Errorf("baseURL = %q, want %q", p.baseURL, "https://api.openai.com/v1") + t.Errorf("baseURL = %q, want trimmed base URL", p.baseURL) } if p.apiKey != "test-key" { - t.Errorf("apiKey = %q, want %q", p.apiKey, "test-key") + t.Errorf("apiKey = %q, want test-key", p.apiKey) } - if p.model != "gpt-4" { - t.Errorf("model = %q, want %q", p.model, "gpt-4") + if p.model != "gpt-5.4-nano" { + t.Errorf("model = %q, want gpt-5.4-nano", p.model) } } @@ -34,6 +34,9 @@ func TestGenerate(t *testing.T) { if r.Method != "POST" { t.Errorf("method = %s, want POST", r.Method) } + if r.URL.Path != "/responses" { + t.Errorf("path = %q, want /responses", r.URL.Path) + } if got := r.Header.Get("Authorization"); got != "Bearer test-key" { t.Errorf("Authorization = %q, want %q", got, "Bearer test-key") } @@ -42,67 +45,130 @@ func TestGenerate(t *testing.T) { if err := json.NewDecoder(r.Body).Decode(&req); err != nil { t.Fatalf("decode request: %v", err) } - if req.Model != "gpt-4" { - t.Errorf("model = %q, want %q", req.Model, "gpt-4") + if req.Model != "gpt-5.4-nano" { + t.Errorf("model = %q, want gpt-5.4-nano", req.Model) + } + if req.Instructions != "You are helpful." { + t.Errorf("instructions = %q", req.Instructions) } - // System prompt should be the first message. - if len(req.Messages) == 0 || req.Messages[0].Role != "system" { - t.Error("expected system message as first message") + if len(req.Input) != 1 || req.Input[0].Role != "user" { + t.Fatalf("input = %+v, want one user item", req.Input) + } + if got := req.Input[0].Content[0]; got.Type != "input_text" || got.Text != "Hi" { + t.Fatalf("content = %+v, want input_text Hi", got) } resp := response{ - Choices: []choice{{ - Message: message{Role: "assistant", Content: "Hello!"}, - FinishReason: "stop", + Output: []outputItem{{ + Type: "message", + Role: "assistant", + Content: []contentOutput{{ + Type: "output_text", + Text: "Hello!", + }}, }}, - Usage: usage{PromptTokens: 8, CompletionTokens: 3}, + Usage: usage{ + InputTokens: 8, + OutputTokens: 3, + TotalTokens: 11, + InputTokensDetails: inputTokensDetails{ + CachedTokens: 2, + }, + OutputTokensDetails: outputTokensDetails{ + ReasoningTokens: 1, + }, + }, } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) })) defer srv.Close() - p := New(srv.URL, "test-key", "gpt-4") + p := New("test-key", ModelGPT54Nano, WithBaseURL(srv.URL)) resp, err := p.Generate(context.Background(), forge.ProviderRequest{ SystemPrompt: "You are helpful.", Messages: []forge.Message{ - {Role: forge.RoleUser, Content: "Hi"}, + forge.UserText("Hi"), }, }) if err != nil { t.Fatalf("Generate: %v", err) } - if resp.Message.Content != "Hello!" { - t.Errorf("content = %q, want %q", resp.Message.Content, "Hello!") + if resp.Messages[0].Text() != "Hello!" { + t.Errorf("content = %q, want Hello!", resp.Messages[0].Text()) } - if resp.Message.Role != forge.RoleAssistant { - t.Errorf("role = %q, want %q", resp.Message.Role, forge.RoleAssistant) + if resp.Messages[0].Role != forge.RoleAssistant { + t.Errorf("role = %q, want %q", resp.Messages[0].Role, forge.RoleAssistant) } if resp.FinishReason != forge.FinishReasonStop { t.Errorf("finishReason = %q, want %q", resp.FinishReason, forge.FinishReasonStop) } - if resp.Usage.InputTokens != 8 || resp.Usage.OutputTokens != 3 { - t.Errorf("usage = %+v, want {8, 3}", resp.Usage) + if resp.Usage.InputTokens != 8 || resp.Usage.OutputTokens != 3 || resp.Usage.TotalTokens != 11 { + t.Errorf("usage = %+v, want input 8 output 3 total 11", resp.Usage) + } + if resp.Usage.CachedInputTokens != 2 || resp.Usage.ReasoningOutputTokens != 1 { + t.Errorf("usage details = %+v, want cached 2 reasoning 1", resp.Usage) + } +} + +func TestGenerateWithImageURL(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + if len(req.Input) != 1 || len(req.Input[0].Content) != 2 { + t.Fatalf("input content = %+v, want text and image", req.Input) + } + image := req.Input[0].Content[1] + if image.Type != "input_image" || image.ImageURL != "https://example.com/cat.png" { + t.Fatalf("image content = %+v", image) + } + + json.NewEncoder(w).Encode(response{ + Output: []outputItem{{ + Type: "message", + Role: "assistant", + Content: []contentOutput{{Type: "output_text", Text: "A cat."}}, + }}, + }) + })) + defer srv.Close() + + p := New("test-key", ModelGPT54Nano, WithBaseURL(srv.URL)) + resp, err := p.Generate(context.Background(), forge.ProviderRequest{ + Messages: []forge.Message{ + forge.UserMessage( + forge.Text("Describe this image."), + forge.ImageURL("https://example.com/cat.png"), + ), + }, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + if resp.Messages[0].Text() != "A cat." { + t.Errorf("content = %q, want A cat.", resp.Messages[0].Text()) } } -func TestGenerateNoChoices(t *testing.T) { +func TestGenerateNoMessages(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := response{Choices: []choice{}, Usage: usage{}} + resp := response{Output: []outputItem{}, Usage: usage{}} w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) })) defer srv.Close() - p := New(srv.URL, "test-key", "gpt-4") + p := New("test-key", ModelGPT54Nano, WithBaseURL(srv.URL)) _, err := p.Generate(context.Background(), forge.ProviderRequest{ - Messages: []forge.Message{{Role: forge.RoleUser, Content: "Hi"}}, + Messages: []forge.Message{forge.UserText("Hi")}, }) if err == nil { - t.Fatal("expected error for empty choices") + t.Fatal("expected error for empty output") } } @@ -113,10 +179,10 @@ func TestGenerateAPIError(t *testing.T) { })) defer srv.Close() - p := New(srv.URL, "test-key", "gpt-4") + p := New("test-key", ModelGPT54Nano, WithBaseURL(srv.URL)) _, err := p.Generate(context.Background(), forge.ProviderRequest{ - Messages: []forge.Message{{Role: forge.RoleUser, Content: "Hi"}}, + Messages: []forge.Message{forge.UserText("Hi")}, }) if err == nil { t.Fatal("expected error for 429 response") diff --git a/provider/openai/options.go b/provider/openai/options.go new file mode 100644 index 0000000..9940048 --- /dev/null +++ b/provider/openai/options.go @@ -0,0 +1,13 @@ +package openai + +import "strings" + +// Option configures a Provider. +type Option func(*Provider) + +// WithBaseURL overrides the OpenAI API base URL. +func WithBaseURL(baseURL string) Option { + return func(p *Provider) { + p.baseURL = strings.TrimRight(baseURL, "/") + } +} diff --git a/provider/provider.go b/provider/provider.go index 0410edc..4ef2b34 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -19,8 +19,26 @@ const ( // TokenUsage tracks token consumption across provider calls. type TokenUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` + CachedInputTokens int `json:"cached_input_tokens,omitempty"` + OutputTokens int `json:"output_tokens"` + ReasoningOutputTokens int `json:"reasoning_output_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +// Capabilities describes optional provider features. +type Capabilities struct { + Tools bool `json:"tools,omitempty"` + Images bool `json:"images,omitempty"` + Streaming bool `json:"streaming,omitempty"` + Usage bool `json:"usage,omitempty"` + Local bool `json:"local,omitempty"` + Production bool `json:"production,omitempty"` +} + +// CapabilityProvider is implemented by providers that can describe their features. +type CapabilityProvider interface { + Capabilities() Capabilities } // Request is the input to a single LLM call. @@ -32,9 +50,10 @@ type Request struct { // Response is the output of a single LLM call. type Response struct { - Message message.Message `json:"message"` - FinishReason FinishReason `json:"finish_reason"` - Usage TokenUsage `json:"usage"` + Messages []message.Message `json:"messages"` + FinishReason FinishReason `json:"finish_reason"` + Usage TokenUsage `json:"usage"` + Metadata map[string]any `json:"metadata,omitempty"` } // Provider makes a single LLM call. It does not loop. diff --git a/provider/xai/xai.go b/provider/xai/xai.go index ed1607f..36ab8a2 100644 --- a/provider/xai/xai.go +++ b/provider/xai/xai.go @@ -18,6 +18,7 @@ import ( "sync" "github.com/katasec/forge" + "github.com/katasec/forge/message" ) // Provider implements forge.Provider using the xAI Responses API. @@ -75,6 +76,15 @@ func New(apiKey string, model Model, opts ...Option) *Provider { return p } +// Capabilities describes the xAI provider features Forge currently supports. +func (p *Provider) Capabilities() forge.Capabilities { + return forge.Capabilities{ + Tools: true, + Usage: true, + Production: true, + } +} + // WithBaseURL overrides the API base URL (useful for testing). func WithBaseURL(url string) Option { return func(p *Provider) { p.baseURL = url } @@ -226,8 +236,8 @@ func convertMessages(msgs []forge.Message, systemPrompt string) []inputItem { } // Tool result messages expand into one input item per result. - if m.Role == forge.RoleTool && len(m.ToolResults) > 0 { - for _, tr := range m.ToolResults { + if m.Role == forge.RoleTool && len(m.ToolResults()) > 0 { + for _, tr := range m.ToolResults() { items = append(items, inputItem{ Type: "function_call_output", CallID: tr.CallID, @@ -239,7 +249,7 @@ func convertMessages(msgs []forge.Message, systemPrompt string) []inputItem { items = append(items, inputItem{ Role: string(m.Role), - Content: m.Content, + Content: m.Text(), }) } @@ -302,12 +312,16 @@ func parseResponse(resp *response) (*forge.ProviderResponse, []Citation) { finishReason = forge.FinishReasonToolUse } + blocks := []forge.ContentBlock{} + if content != "" { + blocks = append(blocks, forge.Text(content)) + } + for _, call := range toolCalls { + blocks = append(blocks, message.ToolCall(call)) + } + return &forge.ProviderResponse{ - Message: forge.Message{ - Role: forge.RoleAssistant, - Content: content, - ToolCalls: toolCalls, - }, + Messages: []forge.Message{{Role: forge.RoleAssistant, Content: blocks}}, FinishReason: finishReason, Usage: forge.TokenUsage{ InputTokens: resp.Usage.InputTokens, diff --git a/provider/xai/xai_test.go b/provider/xai/xai_test.go index f655039..c3eda84 100644 --- a/provider/xai/xai_test.go +++ b/provider/xai/xai_test.go @@ -107,18 +107,18 @@ func TestGenerate(t *testing.T) { resp, err := p.Generate(context.Background(), forge.ProviderRequest{ SystemPrompt: "Be helpful.", Messages: []forge.Message{ - {Role: forge.RoleUser, Content: "Hi"}, + forge.UserText("Hi"), }, }) if err != nil { t.Fatalf("Generate: %v", err) } - if resp.Message.Content != "Hello from Grok!" { - t.Errorf("content = %q", resp.Message.Content) + if resp.Messages[0].Text() != "Hello from Grok!" { + t.Errorf("content = %q", resp.Messages[0].Text()) } - if resp.Message.Role != forge.RoleAssistant { - t.Errorf("role = %q", resp.Message.Role) + if resp.Messages[0].Role != forge.RoleAssistant { + t.Errorf("role = %q", resp.Messages[0].Role) } if resp.FinishReason != forge.FinishReasonStop { t.Errorf("finishReason = %q", resp.FinishReason) @@ -166,7 +166,7 @@ func TestGenerateWithFunctionCalls(t *testing.T) { p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL)) resp, err := p.Generate(context.Background(), forge.ProviderRequest{ - Messages: []forge.Message{{Role: forge.RoleUser, Content: "Weather in SF?"}}, + Messages: []forge.Message{forge.UserText("Weather in SF?")}, Tools: []forge.ToolDefinition{{ Name: "get_weather", Description: "Get weather", @@ -180,10 +180,10 @@ func TestGenerateWithFunctionCalls(t *testing.T) { if resp.FinishReason != forge.FinishReasonToolUse { t.Errorf("finishReason = %q, want tool_use", resp.FinishReason) } - if len(resp.Message.ToolCalls) != 1 { - t.Fatalf("toolCalls = %d, want 1", len(resp.Message.ToolCalls)) + if len(resp.Messages[0].ToolCalls()) != 1 { + t.Fatalf("toolCalls = %d, want 1", len(resp.Messages[0].ToolCalls())) } - tc := resp.Message.ToolCalls[0] + tc := resp.Messages[0].ToolCalls()[0] if tc.ID != "call-1" { t.Errorf("toolCall.ID = %q", tc.ID) } @@ -238,12 +238,12 @@ func TestGenerateWithToolResults(t *testing.T) { p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL)) resp, err := p.Generate(context.Background(), forge.ProviderRequest{ Messages: []forge.Message{ - {Role: forge.RoleUser, Content: "Weather in SF?"}, - {Role: forge.RoleAssistant, Content: "", ToolCalls: []forge.ToolCall{ - {ID: "call-1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"SF"}`)}, + forge.UserText("Weather in SF?"), + {Role: forge.RoleAssistant, Content: []forge.ContentBlock{ + forge.ToolCallBlock(forge.ToolCall{ID: "call-1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"SF"}`)}), }}, - {Role: forge.RoleTool, ToolResults: []forge.ToolResult{ - {CallID: "call-1", Content: "72°F"}, + {Role: forge.RoleTool, Content: []forge.ContentBlock{ + forge.ToolResultBlock(forge.ToolResult{CallID: "call-1", Content: "72°F"}), }}, }, }) @@ -251,8 +251,8 @@ func TestGenerateWithToolResults(t *testing.T) { t.Fatalf("Generate: %v", err) } - if resp.Message.Content != "It's 72°F in SF." { - t.Errorf("content = %q", resp.Message.Content) + if resp.Messages[0].Text() != "It's 72°F in SF." { + t.Errorf("content = %q", resp.Messages[0].Text()) } if resp.FinishReason != forge.FinishReasonStop { t.Errorf("finishReason = %q", resp.FinishReason) @@ -311,14 +311,14 @@ func TestGenerateWithWebSearch(t *testing.T) { WithWebSearch(AllowedDomains("reuters.com")), ) resp, err := p.Generate(context.Background(), forge.ProviderRequest{ - Messages: []forge.Message{{Role: forge.RoleUser, Content: "Latest xAI news?"}}, + Messages: []forge.Message{forge.UserText("Latest xAI news?")}, }) if err != nil { t.Fatalf("Generate: %v", err) } - if resp.Message.Content != "According to Reuters, xAI launched..." { - t.Errorf("content = %q", resp.Message.Content) + if resp.Messages[0].Text() != "According to Reuters, xAI launched..." { + t.Errorf("content = %q", resp.Messages[0].Text()) } if resp.FinishReason != forge.FinishReasonStop { t.Errorf("finishReason = %q", resp.FinishReason) @@ -350,7 +350,7 @@ func TestGenerateAPIError(t *testing.T) { p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL)) _, err := p.Generate(context.Background(), forge.ProviderRequest{ - Messages: []forge.Message{{Role: forge.RoleUser, Content: "Hi"}}, + Messages: []forge.Message{forge.UserText("Hi")}, }) if err == nil { t.Fatal("expected error for 429 response") diff --git a/tool_test.go b/tool_test.go index f4d4c54..0e34474 100644 --- a/tool_test.go +++ b/tool_test.go @@ -84,12 +84,12 @@ func TestFuncInvokeInvalidArgs(t *testing.T) { } func TestUserMessage(t *testing.T) { - msg := UserMessage("hello") + msg := UserText("hello") if msg.Role != RoleUser { t.Fatalf("Role = %q, want %q", msg.Role, RoleUser) } - if msg.Content != "hello" { - t.Fatalf("Content = %q, want hello", msg.Content) + if msg.Text() != "hello" { + t.Fatalf("Content = %q, want hello", msg.Text()) } } diff --git a/types.go b/types.go index 6bb671b..b3d27a3 100644 --- a/types.go +++ b/types.go @@ -16,9 +16,47 @@ const ( ) type Message = message.Message +type ContentBlock = message.ContentBlock +type ContentType = message.ContentType +type ImageContent = message.ImageContent -func UserMessage(content string) Message { - return message.UserMessage(content) +const ( + ContentTypeText = message.ContentTypeText + ContentTypeImage = message.ContentTypeImage + ContentTypeToolCall = message.ContentTypeToolCall + ContentTypeToolResult = message.ContentTypeToolResult +) + +func Text(content string) ContentBlock { + return message.Text(content) +} + +func ImageURL(url string) ContentBlock { + return message.ImageURL(url) +} + +func ImageBytes(data []byte, mediaType string) ContentBlock { + return message.ImageBytes(data, mediaType) +} + +func ToolCallBlock(call ToolCall) ContentBlock { + return message.ToolCall(call) +} + +func ToolResultBlock(result ToolResult) ContentBlock { + return message.ToolResult(result) +} + +func UserMessage(blocks ...ContentBlock) Message { + return message.UserMessage(blocks...) +} + +func UserText(content string) Message { + return message.UserText(content) +} + +func AssistantText(content string) Message { + return message.AssistantText(content) } type ToolCall = tool.Call @@ -35,6 +73,8 @@ const ( ) type TokenUsage = provider.TokenUsage +type Capabilities = provider.Capabilities +type CapabilityProvider = provider.CapabilityProvider type ErrorPolicy string