From 014da40846bd78c0cbac472c5b45767d293ab29b Mon Sep 17 00:00:00 2001 From: Ameer Deen Date: Tue, 7 Apr 2026 20:53:36 +0530 Subject: [PATCH] Add xAI Responses API provider with web search and citations New provider/xai package implementing forge.Provider via the modern xAI Responses API. Supports built-in server-side tools (web search, X search) with domain/handle filtering, native function calling with flat tool definitions, and citation extraction from inline annotations. Co-Authored-By: Claude Opus 4.6 (1M context) --- README.md | 21 +- _examples/hello-world/main.go | 29 ++- docs/design/design.md | 1 + provider/xai/xai.go | 378 ++++++++++++++++++++++++++++++++++ provider/xai/xai_test.go | 358 ++++++++++++++++++++++++++++++++ 5 files changed, 782 insertions(+), 5 deletions(-) create mode 100644 provider/xai/xai.go create mode 100644 provider/xai/xai_test.go diff --git a/README.md b/README.md index 74b6159..0341cda 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ Forge handles the **LLM call → tool execution → response** cycle. You supply go get github.com/katasec/forge go get github.com/katasec/forge/provider/anthropic # optional go get github.com/katasec/forge/provider/openai # optional +go get github.com/katasec/forge/provider/xai # optional — xAI Responses API with web search ``` ## Quick Start @@ -63,13 +64,29 @@ import "github.com/katasec/forge/provider/openai" provider := openai.New("https://api.x.ai/v1", os.Getenv("XAI_API_KEY"), "grok-3-mini") ``` -The `openai` package works with any OpenAI-compatible API (xAI, OpenAI, Together, Groq, etc.). See [`_examples/hello-world`](./_examples/hello-world) for the full runnable code. +The `openai` package works with any OpenAI-compatible API (xAI, OpenAI, Together, Groq, etc.). + +Or use the xAI Responses API with built-in web search: + +```go +import "github.com/katasec/forge/provider/xai" + +provider := xai.New(os.Getenv("XAI_API_KEY"), "grok-4-1-fast-non-reasoning", xai.WithWebSearch()) + +// After running the agent, access citations: +citations := provider.LastCitations() +for _, c := range citations { + fmt.Printf("[%s] %s\n", c.Title, c.URL) +} +``` + +See [`_examples/hello-world`](./_examples/hello-world) for the full runnable code. ## Core Concepts ### Provider -The `Provider` interface makes a single LLM call. Forge ships with two built-in providers, or you can implement your own: +The `Provider` interface makes a single LLM call. Forge ships with three built-in providers, or you can implement your own: ```go type Provider interface { diff --git a/_examples/hello-world/main.go b/_examples/hello-world/main.go index cd2f2e5..212fe88 100644 --- a/_examples/hello-world/main.go +++ b/_examples/hello-world/main.go @@ -8,9 +8,13 @@ // export ANTHROPIC_API_KEY=sk-ant-... // go run . // -// # Or use xAI instead: +// # Or use xAI (OpenAI-compatible) instead: // export XAI_API_KEY=xai-... // go run . -provider xai +// +// # Or use xAI Responses API with web search: +// export XAI_API_KEY=xai-... +// go run . -provider xai-search package main import ( @@ -23,14 +27,16 @@ import ( "github.com/katasec/forge" "github.com/katasec/forge/provider/anthropic" "github.com/katasec/forge/provider/openai" + "github.com/katasec/forge/provider/xai" ) func main() { - providerFlag := flag.String("provider", "anthropic", "Provider to use: anthropic or xai") + providerFlag := flag.String("provider", "anthropic", "Provider to use: anthropic, xai, or xai-search") flag.Parse() // Pick your provider — this is the only thing that changes. var provider forge.Provider + var xaiProvider *xai.Provider // for citation access switch *providerFlag { case "anthropic": key := os.Getenv("ANTHROPIC_API_KEY") @@ -44,8 +50,15 @@ func main() { log.Fatal("Set XAI_API_KEY environment variable") } provider = openai.New("https://api.x.ai/v1", key, "grok-3-mini") + case "xai-search": + key := os.Getenv("XAI_API_KEY") + if key == "" { + log.Fatal("Set XAI_API_KEY environment variable") + } + xaiProvider = xai.New(key, "grok-4-1-fast-non-reasoning", xai.WithWebSearch()) + provider = xaiProvider default: - log.Fatalf("Unknown provider: %s (use 'anthropic' or 'xai')", *providerFlag) + log.Fatalf("Unknown provider: %s (use 'anthropic', 'xai', or 'xai-search')", *providerFlag) } // Build the agent — same code regardless of provider. @@ -69,4 +82,14 @@ func main() { fmt.Println(resp.Messages[len(resp.Messages)-1].Content) fmt.Printf("\n[%s | tokens: %d in, %d out]\n", *providerFlag, resp.Usage.InputTokens, resp.Usage.OutputTokens) + + // Show citations if using xai-search. + if xaiProvider != nil { + if citations := xaiProvider.LastCitations(); len(citations) > 0 { + fmt.Println("\nSources:") + for i, c := range citations { + fmt.Printf(" [%d] %s — %s\n", i+1, c.Title, c.URL) + } + } + } } diff --git a/docs/design/design.md b/docs/design/design.md index 369c488..7680cce 100644 --- a/docs/design/design.md +++ b/docs/design/design.md @@ -18,6 +18,7 @@ Sub-packages hold swappable backend implementations (following the `database/sql forge/ root — interfaces, types, Agent, Config, defaults forge/provider/anthropic/ Anthropic Messages API provider forge/provider/openai/ OpenAI-compatible provider (OpenAI, xAI, Together, Groq) +forge/provider/xai/ xAI Responses API provider (web search, X search, citations) ``` Future sub-packages (created when implementations exist, not preemptively): diff --git a/provider/xai/xai.go b/provider/xai/xai.go new file mode 100644 index 0000000..4571599 --- /dev/null +++ b/provider/xai/xai.go @@ -0,0 +1,378 @@ +// Package xai implements forge.Provider using the xAI Responses API. +// +// This provider supports the modern xAI Responses API with built-in +// server-side tools (web search, X search) and native function calling. +// +// Usage: +// +// provider := xai.New(apiKey, "grok-3-mini", xai.WithWebSearch()) +package xai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + + "github.com/katasec/forge" +) + +// Provider implements forge.Provider using the xAI Responses API. +type Provider struct { + baseURL string + apiKey string + model string + client *http.Client + tools []requestTool // persistent server-side tools (web_search, x_search) + + mu sync.Mutex + lastCitations []Citation +} + +// Citation represents a source reference returned by xAI search tools. +type Citation struct { + URL string `json:"url"` + Title string `json:"title"` + Snippet string `json:"snippet"` + Source string `json:"source"` + StartIndex int `json:"start_index"` + EndIndex int `json:"end_index"` +} + +// Option configures a Provider. +type Option func(*Provider) + +// WebSearchOption configures the web_search tool. +type WebSearchOption func(*webSearchConfig) + +// XSearchOption configures the x_search tool. +type XSearchOption func(*xSearchConfig) + +type webSearchConfig struct { + AllowedDomains []string `json:"allowed_domains,omitempty"` + ExcludedDomains []string `json:"excluded_domains,omitempty"` +} + +type xSearchConfig struct { + AllowedHandles []string `json:"allowed_x_handles,omitempty"` + ExcludedHandles []string `json:"excluded_x_handles,omitempty"` +} + +// New creates an xAI provider using the Responses API. +func New(apiKey, model string, opts ...Option) *Provider { + p := &Provider{ + baseURL: "https://api.x.ai/v1", + apiKey: apiKey, + model: model, + client: &http.Client{}, + } + for _, opt := range opts { + opt(p) + } + return p +} + +// WithBaseURL overrides the API base URL (useful for testing). +func WithBaseURL(url string) Option { + return func(p *Provider) { p.baseURL = url } +} + +// WithWebSearch enables the built-in web search tool. +func WithWebSearch(opts ...WebSearchOption) Option { + return func(p *Provider) { + cfg := &webSearchConfig{} + for _, o := range opts { + o(cfg) + } + t := requestTool{Type: "web_search"} + if len(cfg.AllowedDomains) > 0 { + t.AllowedDomains = cfg.AllowedDomains + } + if len(cfg.ExcludedDomains) > 0 { + t.ExcludedDomains = cfg.ExcludedDomains + } + p.tools = append(p.tools, t) + } +} + +// WithXSearch enables the built-in X/Twitter search tool. +func WithXSearch(opts ...XSearchOption) Option { + return func(p *Provider) { + cfg := &xSearchConfig{} + for _, o := range opts { + o(cfg) + } + t := requestTool{Type: "x_search"} + if len(cfg.AllowedHandles) > 0 { + t.AllowedHandles = cfg.AllowedHandles + } + if len(cfg.ExcludedHandles) > 0 { + t.ExcludedHandles = cfg.ExcludedHandles + } + p.tools = append(p.tools, t) + } +} + +// AllowedDomains restricts web search to the specified domains. +func AllowedDomains(domains ...string) WebSearchOption { + return func(c *webSearchConfig) { c.AllowedDomains = domains } +} + +// ExcludedDomains excludes the specified domains from web search. +func ExcludedDomains(domains ...string) WebSearchOption { + return func(c *webSearchConfig) { c.ExcludedDomains = domains } +} + +// AllowedHandles restricts X search to the specified handles. +func AllowedHandles(handles ...string) XSearchOption { + return func(c *xSearchConfig) { c.AllowedHandles = handles } +} + +// ExcludedHandles excludes the specified handles from X search. +func ExcludedHandles(handles ...string) XSearchOption { + return func(c *xSearchConfig) { c.ExcludedHandles = handles } +} + +// LastCitations returns the citations from the most recent Generate call. +func (p *Provider) LastCitations() []Citation { + p.mu.Lock() + defer p.mu.Unlock() + return p.lastCitations +} + +// --- xAI Responses API wire types --- + +type request struct { + Model string `json:"model"` + Input []inputItem `json:"input"` + Tools []requestTool `json:"tools,omitempty"` +} + +type inputItem struct { + // Message fields + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + // Tool result fields + Type string `json:"type,omitempty"` // "function_call_output" + CallID string `json:"call_id,omitempty"` + Output string `json:"output,omitempty"` +} + +type requestTool struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` + // web_search options + AllowedDomains []string `json:"allowed_domains,omitempty"` + ExcludedDomains []string `json:"excluded_domains,omitempty"` + // x_search options + AllowedHandles []string `json:"allowed_x_handles,omitempty"` + ExcludedHandles []string `json:"excluded_x_handles,omitempty"` +} + +type response struct { + ID string `json:"id"` + Output []outputItem `json:"output"` + Usage responseUsage `json:"usage"` +} + +type outputItem struct { + Type string `json:"type"` + // function_call fields + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` + CallID string `json:"call_id,omitempty"` + // message fields + Role string `json:"role,omitempty"` + Content []contentItem `json:"content,omitempty"` +} + +type contentItem struct { + Type string `json:"type"` + Text string `json:"text"` + Annotations []annotation `json:"annotations,omitempty"` +} + +type annotation struct { + Type string `json:"type"` + URL string `json:"url"` + Title string `json:"title"` + StartIndex int `json:"start_index"` + EndIndex int `json:"end_index"` +} + +type responseUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// --- Conversion helpers --- + +// convertMessages converts forge messages to xAI input items. +func convertMessages(msgs []forge.Message, systemPrompt string) []inputItem { + var items []inputItem + + if systemPrompt != "" { + items = append(items, inputItem{Role: "system", Content: systemPrompt}) + } + + for _, m := range msgs { + if m.Role == forge.RoleSystem { + continue // handled above + } + + // Tool result messages expand into one input item per result. + if m.Role == forge.RoleTool && len(m.ToolResults) > 0 { + for _, tr := range m.ToolResults { + items = append(items, inputItem{ + Type: "function_call_output", + CallID: tr.CallID, + Output: tr.Content, + }) + } + continue + } + + items = append(items, inputItem{ + Role: string(m.Role), + Content: m.Content, + }) + } + + return items +} + +// convertTools converts forge tool definitions to xAI flat tool format. +func convertTools(defs []forge.ToolDefinition) []requestTool { + tools := make([]requestTool, 0, len(defs)) + for _, d := range defs { + tools = append(tools, requestTool{ + Type: "function", + Name: d.Name, + Description: d.Description, + Parameters: d.Schema.Parameters, + }) + } + return tools +} + +// parseResponse converts an xAI response to a forge ProviderResponse. +func parseResponse(resp *response) (*forge.ProviderResponse, []Citation) { + var content string + var toolCalls []forge.ToolCall + var citations []Citation + + for _, item := range resp.Output { + switch item.Type { + case "function_call": + toolCalls = append(toolCalls, forge.ToolCall{ + ID: item.CallID, + Name: item.Name, + Arguments: json.RawMessage(item.Arguments), + }) + case "message": + for _, c := range item.Content { + if c.Type == "output_text" { + content += c.Text + // Extract citations from inline annotations. + for _, a := range c.Annotations { + if a.Type == "url_citation" { + citations = append(citations, Citation{ + URL: a.URL, + Title: a.Title, + Source: "web", + StartIndex: a.StartIndex, + EndIndex: a.EndIndex, + }) + } + } + } + } + // Server-side tool calls (web_search_call, x_search_call, etc.) + // are auto-executed by xAI — we don't surface them. + } + } + + finishReason := forge.FinishReasonStop + if len(toolCalls) > 0 { + finishReason = forge.FinishReasonToolUse + } + + return &forge.ProviderResponse{ + Message: forge.Message{ + Role: forge.RoleAssistant, + Content: content, + ToolCalls: toolCalls, + }, + FinishReason: finishReason, + Usage: forge.TokenUsage{ + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + }, + }, citations +} + +// Generate sends a request to the xAI Responses API. +func (p *Provider) Generate(ctx context.Context, req forge.ProviderRequest) (*forge.ProviderResponse, error) { + // Build the input items from forge messages. + input := convertMessages(req.Messages, req.SystemPrompt) + + // Build tools: merge function tools from the request with persistent server-side tools. + var tools []requestTool + tools = append(tools, p.tools...) // server-side tools (web_search, x_search) + if len(req.Tools) > 0 { + tools = append(tools, convertTools(req.Tools)...) + } + + body := request{ + Model: p.model, + Input: input, + Tools: tools, + } + + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", + fmt.Sprintf("%s/responses", p.baseURL), bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + + httpResp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer httpResp.Body.Close() + + respBody, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if httpResp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("xAI API error (%d): %s", httpResp.StatusCode, string(respBody)) + } + + var apiResp response + if err := json.Unmarshal(respBody, &apiResp); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + providerResp, citations := parseResponse(&apiResp) + + // Store citations for provider-specific access. + p.mu.Lock() + p.lastCitations = citations + p.mu.Unlock() + + return providerResp, nil +} diff --git a/provider/xai/xai_test.go b/provider/xai/xai_test.go new file mode 100644 index 0000000..3018104 --- /dev/null +++ b/provider/xai/xai_test.go @@ -0,0 +1,358 @@ +package xai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/katasec/forge" +) + +// Compile-time check that *Provider satisfies forge.Provider. +var _ forge.Provider = (*Provider)(nil) + +func TestNew(t *testing.T) { + p := New("test-key", "grok-3-mini") + if p.apiKey != "test-key" { + t.Errorf("apiKey = %q, want %q", p.apiKey, "test-key") + } + if p.model != "grok-3-mini" { + t.Errorf("model = %q, want %q", p.model, "grok-3-mini") + } + if p.baseURL != "https://api.x.ai/v1" { + t.Errorf("baseURL = %q, want default", p.baseURL) + } + if len(p.tools) != 0 { + t.Errorf("tools = %d, want 0", len(p.tools)) + } +} + +func TestNewWithOptions(t *testing.T) { + p := New("key", "model", + WithBaseURL("http://localhost"), + WithWebSearch(AllowedDomains("wikipedia.org", "github.com")), + WithXSearch(ExcludedHandles("@spam")), + ) + + if p.baseURL != "http://localhost" { + t.Errorf("baseURL = %q", p.baseURL) + } + if len(p.tools) != 2 { + t.Fatalf("tools = %d, want 2", len(p.tools)) + } + if p.tools[0].Type != "web_search" { + t.Errorf("tools[0].Type = %q", p.tools[0].Type) + } + if len(p.tools[0].AllowedDomains) != 2 { + t.Errorf("allowed_domains = %v", p.tools[0].AllowedDomains) + } + if p.tools[1].Type != "x_search" { + t.Errorf("tools[1].Type = %q", p.tools[1].Type) + } + if len(p.tools[1].ExcludedHandles) != 1 { + t.Errorf("excluded_handles = %v", p.tools[1].ExcludedHandles) + } +} + +func TestGenerate(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("method = %s, want POST", r.Method) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization = %q", got) + } + if r.URL.Path != "/responses" { + t.Errorf("path = %q, want /responses", r.URL.Path) + } + + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + if req.Model != "grok-3-mini" { + t.Errorf("model = %q", req.Model) + } + // Should have system + user message. + if len(req.Input) != 2 { + t.Fatalf("input items = %d, want 2", len(req.Input)) + } + if req.Input[0].Role != "system" { + t.Errorf("input[0].role = %q, want system", req.Input[0].Role) + } + if req.Input[1].Role != "user" { + t.Errorf("input[1].role = %q, want user", req.Input[1].Role) + } + + resp := response{ + ID: "resp-123", + Output: []outputItem{{ + Type: "message", + Role: "assistant", + Content: []contentItem{{ + Type: "output_text", + Text: "Hello from Grok!", + }}, + }}, + Usage: responseUsage{InputTokens: 10, OutputTokens: 5}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + p := New("test-key", "grok-3-mini", WithBaseURL(srv.URL)) + resp, err := p.Generate(context.Background(), forge.ProviderRequest{ + SystemPrompt: "Be helpful.", + Messages: []forge.Message{ + {Role: forge.RoleUser, Content: "Hi"}, + }, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + + if resp.Message.Content != "Hello from Grok!" { + t.Errorf("content = %q", resp.Message.Content) + } + if resp.Message.Role != forge.RoleAssistant { + t.Errorf("role = %q", resp.Message.Role) + } + if resp.FinishReason != forge.FinishReasonStop { + t.Errorf("finishReason = %q", resp.FinishReason) + } + if resp.Usage.InputTokens != 10 || resp.Usage.OutputTokens != 5 { + t.Errorf("usage = %+v", resp.Usage) + } +} + +func TestGenerateWithFunctionCalls(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode: %v", err) + } + + // Verify function tools are flat (not nested in "function" wrapper). + if len(req.Tools) != 1 { + t.Fatalf("tools = %d, want 1", len(req.Tools)) + } + tool := req.Tools[0] + if tool.Type != "function" { + t.Errorf("tool.type = %q, want function", tool.Type) + } + if tool.Name != "get_weather" { + t.Errorf("tool.name = %q", tool.Name) + } + if tool.Description != "Get weather" { + t.Errorf("tool.description = %q", tool.Description) + } + + resp := response{ + ID: "resp-456", + Output: []outputItem{{ + Type: "function_call", + Name: "get_weather", + Arguments: `{"city":"SF"}`, + CallID: "call-1", + }}, + Usage: responseUsage{InputTokens: 15, OutputTokens: 8}, + } + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + p := New("key", "grok-3-mini", WithBaseURL(srv.URL)) + resp, err := p.Generate(context.Background(), forge.ProviderRequest{ + Messages: []forge.Message{{Role: forge.RoleUser, Content: "Weather in SF?"}}, + Tools: []forge.ToolDefinition{{ + Name: "get_weather", + Description: "Get weather", + Schema: forge.ToolSchema{Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`)}, + }}, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + + if resp.FinishReason != forge.FinishReasonToolUse { + t.Errorf("finishReason = %q, want tool_use", resp.FinishReason) + } + if len(resp.Message.ToolCalls) != 1 { + t.Fatalf("toolCalls = %d, want 1", len(resp.Message.ToolCalls)) + } + tc := resp.Message.ToolCalls[0] + if tc.ID != "call-1" { + t.Errorf("toolCall.ID = %q", tc.ID) + } + if tc.Name != "get_weather" { + t.Errorf("toolCall.Name = %q", tc.Name) + } + if string(tc.Arguments) != `{"city":"SF"}` { + t.Errorf("toolCall.Arguments = %s", tc.Arguments) + } +} + +func TestGenerateWithToolResults(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode: %v", err) + } + + // Find the function_call_output input item. + var found bool + for _, item := range req.Input { + if item.Type == "function_call_output" { + found = true + if item.CallID != "call-1" { + t.Errorf("call_id = %q", item.CallID) + } + if item.Output != "72°F" { + t.Errorf("output = %q", item.Output) + } + } + } + if !found { + t.Error("no function_call_output found in input") + } + + resp := response{ + ID: "resp-789", + Output: []outputItem{{ + Type: "message", + Role: "assistant", + Content: []contentItem{{ + Type: "output_text", + Text: "It's 72°F in SF.", + }}, + }}, + Usage: responseUsage{InputTokens: 20, OutputTokens: 10}, + } + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + p := New("key", "grok-3-mini", WithBaseURL(srv.URL)) + resp, err := p.Generate(context.Background(), forge.ProviderRequest{ + Messages: []forge.Message{ + {Role: forge.RoleUser, Content: "Weather in SF?"}, + {Role: forge.RoleAssistant, Content: "", ToolCalls: []forge.ToolCall{ + {ID: "call-1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"SF"}`)}, + }}, + {Role: forge.RoleTool, ToolResults: []forge.ToolResult{ + {CallID: "call-1", Content: "72°F"}, + }}, + }, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + + if resp.Message.Content != "It's 72°F in SF." { + t.Errorf("content = %q", resp.Message.Content) + } + if resp.FinishReason != forge.FinishReasonStop { + t.Errorf("finishReason = %q", resp.FinishReason) + } +} + +func TestGenerateWithWebSearch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req request + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode: %v", err) + } + + // Verify web_search tool is included. + var hasWebSearch bool + for _, tool := range req.Tools { + if tool.Type == "web_search" { + hasWebSearch = true + if len(tool.AllowedDomains) != 1 || tool.AllowedDomains[0] != "reuters.com" { + t.Errorf("allowed_domains = %v", tool.AllowedDomains) + } + } + } + if !hasWebSearch { + t.Error("web_search tool not found in request") + } + + resp := response{ + ID: "resp-search", + Output: []outputItem{ + {Type: "web_search_call"}, // server-side, should be ignored + { + Type: "message", + Role: "assistant", + Content: []contentItem{{ + Type: "output_text", + Text: "According to Reuters, xAI launched...", + Annotations: []annotation{{ + Type: "url_citation", + URL: "https://reuters.com/tech/xai", + Title: "1", + StartIndex: 22, + EndIndex: 35, + }}, + }}, + }, + }, + Usage: responseUsage{InputTokens: 50, OutputTokens: 30}, + } + json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + p := New("key", "grok-3-mini", + WithBaseURL(srv.URL), + WithWebSearch(AllowedDomains("reuters.com")), + ) + resp, err := p.Generate(context.Background(), forge.ProviderRequest{ + Messages: []forge.Message{{Role: forge.RoleUser, Content: "Latest xAI news?"}}, + }) + if err != nil { + t.Fatalf("Generate: %v", err) + } + + if resp.Message.Content != "According to Reuters, xAI launched..." { + t.Errorf("content = %q", resp.Message.Content) + } + if resp.FinishReason != forge.FinishReasonStop { + t.Errorf("finishReason = %q", resp.FinishReason) + } + + // Check citations via provider-specific accessor. + citations := p.LastCitations() + if len(citations) != 1 { + t.Fatalf("citations = %d, want 1", len(citations)) + } + c := citations[0] + if c.URL != "https://reuters.com/tech/xai" { + t.Errorf("citation.URL = %q", c.URL) + } + if c.Source != "web" { + t.Errorf("citation.Source = %q, want web", c.Source) + } + if c.StartIndex != 22 || c.EndIndex != 35 { + t.Errorf("citation indices = [%d, %d]", c.StartIndex, c.EndIndex) + } +} + +func TestGenerateAPIError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"error":"rate limited"}`)) + })) + defer srv.Close() + + p := New("key", "grok-3-mini", WithBaseURL(srv.URL)) + _, err := p.Generate(context.Background(), forge.ProviderRequest{ + Messages: []forge.Message{{Role: forge.RoleUser, Content: "Hi"}}, + }) + if err == nil { + t.Fatal("expected error for 429 response") + } +}