From cadaf9d6a5ec1ffd3998765709abd8182bd8307a Mon Sep 17 00:00:00 2001 From: Ameer Deen Date: Tue, 26 May 2026 21:26:37 +0400 Subject: [PATCH] add xai model constants --- README.md | 2 +- _examples/chat-console/main.go | 2 +- _examples/hello-world/main.go | 2 +- provider/xai/model.go | 12 ++++++++++++ provider/xai/xai.go | 6 +++--- provider/xai/xai_test.go | 20 ++++++++++---------- 6 files changed, 28 insertions(+), 16 deletions(-) create mode 100644 provider/xai/model.go diff --git a/README.md b/README.md index 7184105..4265c1e 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ Or use the xAI Responses API with built-in web search: ```go import "github.com/katasec/forge/provider/xai" -provider := xai.New(os.Getenv("XAI_API_KEY"), "grok-4-1-fast-non-reasoning", xai.WithWebSearch()) +provider := xai.New(os.Getenv("XAI_API_KEY"), xai.ModelGrok4FastNonReasoning, xai.WithWebSearch()) // After running the agent, access citations: citations := provider.LastCitations() diff --git a/_examples/chat-console/main.go b/_examples/chat-console/main.go index df54f54..1a272a7 100644 --- a/_examples/chat-console/main.go +++ b/_examples/chat-console/main.go @@ -45,7 +45,7 @@ func buildProvider(name string) (forge.Provider, func()) { return openai.New("https://api.x.ai/v1", key, "grok-3-mini"), func() {} case "xai-search": key := requireEnv("XAI_API_KEY") - provider := xai.New(key, "grok-4-1-fast-non-reasoning", xai.WithWebSearch()) + provider := xai.New(key, xai.ModelGrok4FastNonReasoning, xai.WithWebSearch()) return provider, func() { printCitations(provider.LastCitations()) } default: log.Fatalf("unknown provider %q; use anthropic, xai, or xai-search", name) diff --git a/_examples/hello-world/main.go b/_examples/hello-world/main.go index 79cfe0f..1bce5c2 100644 --- a/_examples/hello-world/main.go +++ b/_examples/hello-world/main.go @@ -55,7 +55,7 @@ func main() { if key == "" { log.Fatal("Set XAI_API_KEY environment variable") } - xaiProvider = xai.New(key, "grok-4-1-fast-non-reasoning", xai.WithWebSearch()) + xaiProvider = xai.New(key, xai.ModelGrok4FastNonReasoning, xai.WithWebSearch()) provider = xaiProvider default: log.Fatalf("Unknown provider: %s (use 'anthropic', 'xai', or 'xai-search')", *providerFlag) diff --git a/provider/xai/model.go b/provider/xai/model.go new file mode 100644 index 0000000..a74d00d --- /dev/null +++ b/provider/xai/model.go @@ -0,0 +1,12 @@ +package xai + +// Model is an xAI model identifier. +type Model string + +const ( + // ModelGrok3Mini is xAI's Grok 3 mini model. + ModelGrok3Mini Model = "grok-3-mini" + + // ModelGrok4FastNonReasoning is xAI's fast non-reasoning Grok 4.1 model. + ModelGrok4FastNonReasoning Model = "grok-4-1-fast-non-reasoning" +) diff --git a/provider/xai/xai.go b/provider/xai/xai.go index 01a9083..ed1607f 100644 --- a/provider/xai/xai.go +++ b/provider/xai/xai.go @@ -5,7 +5,7 @@ // // Usage: // -// provider := xai.New(apiKey, "grok-3-mini", xai.WithWebSearch()) +// provider := xai.New(apiKey, xai.ModelGrok3Mini, xai.WithWebSearch()) package xai import ( @@ -62,11 +62,11 @@ type xSearchConfig struct { } // New creates an xAI provider using the Responses API. -func New(apiKey, model string, opts ...Option) *Provider { +func New(apiKey string, model Model, opts ...Option) *Provider { p := &Provider{ baseURL: "https://api.x.ai/v1", apiKey: apiKey, - model: model, + model: string(model), client: &http.Client{}, } for _, opt := range opts { diff --git a/provider/xai/xai_test.go b/provider/xai/xai_test.go index 3018104..f655039 100644 --- a/provider/xai/xai_test.go +++ b/provider/xai/xai_test.go @@ -14,12 +14,12 @@ import ( var _ forge.Provider = (*Provider)(nil) func TestNew(t *testing.T) { - p := New("test-key", "grok-3-mini") + p := New("test-key", ModelGrok3Mini) if p.apiKey != "test-key" { t.Errorf("apiKey = %q, want %q", p.apiKey, "test-key") } - if p.model != "grok-3-mini" { - t.Errorf("model = %q, want %q", p.model, "grok-3-mini") + if p.model != string(ModelGrok3Mini) { + t.Errorf("model = %q, want %q", p.model, ModelGrok3Mini) } if p.baseURL != "https://api.x.ai/v1" { t.Errorf("baseURL = %q, want default", p.baseURL) @@ -30,7 +30,7 @@ func TestNew(t *testing.T) { } func TestNewWithOptions(t *testing.T) { - p := New("key", "model", + p := New("key", Model("custom-model"), WithBaseURL("http://localhost"), WithWebSearch(AllowedDomains("wikipedia.org", "github.com")), WithXSearch(ExcludedHandles("@spam")), @@ -72,7 +72,7 @@ func TestGenerate(t *testing.T) { if err := json.NewDecoder(r.Body).Decode(&req); err != nil { t.Fatalf("decode request: %v", err) } - if req.Model != "grok-3-mini" { + if req.Model != string(ModelGrok3Mini) { t.Errorf("model = %q", req.Model) } // Should have system + user message. @@ -103,7 +103,7 @@ func TestGenerate(t *testing.T) { })) defer srv.Close() - p := New("test-key", "grok-3-mini", WithBaseURL(srv.URL)) + p := New("test-key", ModelGrok3Mini, WithBaseURL(srv.URL)) resp, err := p.Generate(context.Background(), forge.ProviderRequest{ SystemPrompt: "Be helpful.", Messages: []forge.Message{ @@ -164,7 +164,7 @@ func TestGenerateWithFunctionCalls(t *testing.T) { })) defer srv.Close() - p := New("key", "grok-3-mini", WithBaseURL(srv.URL)) + p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL)) resp, err := p.Generate(context.Background(), forge.ProviderRequest{ Messages: []forge.Message{{Role: forge.RoleUser, Content: "Weather in SF?"}}, Tools: []forge.ToolDefinition{{ @@ -235,7 +235,7 @@ func TestGenerateWithToolResults(t *testing.T) { })) defer srv.Close() - p := New("key", "grok-3-mini", WithBaseURL(srv.URL)) + p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL)) resp, err := p.Generate(context.Background(), forge.ProviderRequest{ Messages: []forge.Message{ {Role: forge.RoleUser, Content: "Weather in SF?"}, @@ -306,7 +306,7 @@ func TestGenerateWithWebSearch(t *testing.T) { })) defer srv.Close() - p := New("key", "grok-3-mini", + p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL), WithWebSearch(AllowedDomains("reuters.com")), ) @@ -348,7 +348,7 @@ func TestGenerateAPIError(t *testing.T) { })) defer srv.Close() - p := New("key", "grok-3-mini", WithBaseURL(srv.URL)) + p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL)) _, err := p.Generate(context.Background(), forge.ProviderRequest{ Messages: []forge.Message{{Role: forge.RoleUser, Content: "Hi"}}, })