Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Or use the xAI Responses API with built-in web search:
```go
import "github.com/katasec/forge/provider/xai"

provider := xai.New(os.Getenv("XAI_API_KEY"), "grok-4-1-fast-non-reasoning", xai.WithWebSearch())
provider := xai.New(os.Getenv("XAI_API_KEY"), xai.ModelGrok4FastNonReasoning, xai.WithWebSearch())

// After running the agent, access citations:
citations := provider.LastCitations()
Expand Down
2 changes: 1 addition & 1 deletion _examples/chat-console/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func buildProvider(name string) (forge.Provider, func()) {
return openai.New("https://api.x.ai/v1", key, "grok-3-mini"), func() {}
case "xai-search":
key := requireEnv("XAI_API_KEY")
provider := xai.New(key, "grok-4-1-fast-non-reasoning", xai.WithWebSearch())
provider := xai.New(key, xai.ModelGrok4FastNonReasoning, xai.WithWebSearch())
return provider, func() { printCitations(provider.LastCitations()) }
default:
log.Fatalf("unknown provider %q; use anthropic, xai, or xai-search", name)
Expand Down
2 changes: 1 addition & 1 deletion _examples/hello-world/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func main() {
if key == "" {
log.Fatal("Set XAI_API_KEY environment variable")
}
xaiProvider = xai.New(key, "grok-4-1-fast-non-reasoning", xai.WithWebSearch())
xaiProvider = xai.New(key, xai.ModelGrok4FastNonReasoning, xai.WithWebSearch())
provider = xaiProvider
default:
log.Fatalf("Unknown provider: %s (use 'anthropic', 'xai', or 'xai-search')", *providerFlag)
Expand Down
12 changes: 12 additions & 0 deletions provider/xai/model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package xai

// Model is an xAI model identifier.
type Model string

const (
// ModelGrok3Mini is xAI's Grok 3 mini model.
ModelGrok3Mini Model = "grok-3-mini"

// ModelGrok4FastNonReasoning is xAI's fast non-reasoning Grok 4.1 model.
ModelGrok4FastNonReasoning Model = "grok-4-1-fast-non-reasoning"
)
6 changes: 3 additions & 3 deletions provider/xai/xai.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//
// Usage:
//
// provider := xai.New(apiKey, "grok-3-mini", xai.WithWebSearch())
// provider := xai.New(apiKey, xai.ModelGrok3Mini, xai.WithWebSearch())
package xai

import (
Expand Down Expand Up @@ -62,11 +62,11 @@ type xSearchConfig struct {
}

// New creates an xAI provider using the Responses API.
func New(apiKey, model string, opts ...Option) *Provider {
func New(apiKey string, model Model, opts ...Option) *Provider {
p := &Provider{
baseURL: "https://api.x.ai/v1",
apiKey: apiKey,
model: model,
model: string(model),
client: &http.Client{},
}
for _, opt := range opts {
Expand Down
20 changes: 10 additions & 10 deletions provider/xai/xai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ import (
var _ forge.Provider = (*Provider)(nil)

func TestNew(t *testing.T) {
p := New("test-key", "grok-3-mini")
p := New("test-key", ModelGrok3Mini)
if p.apiKey != "test-key" {
t.Errorf("apiKey = %q, want %q", p.apiKey, "test-key")
}
if p.model != "grok-3-mini" {
t.Errorf("model = %q, want %q", p.model, "grok-3-mini")
if p.model != string(ModelGrok3Mini) {
t.Errorf("model = %q, want %q", p.model, ModelGrok3Mini)
}
if p.baseURL != "https://api.x.ai/v1" {
t.Errorf("baseURL = %q, want default", p.baseURL)
Expand All @@ -30,7 +30,7 @@ func TestNew(t *testing.T) {
}

func TestNewWithOptions(t *testing.T) {
p := New("key", "model",
p := New("key", Model("custom-model"),
WithBaseURL("http://localhost"),
WithWebSearch(AllowedDomains("wikipedia.org", "github.com")),
WithXSearch(ExcludedHandles("@spam")),
Expand Down Expand Up @@ -72,7 +72,7 @@ func TestGenerate(t *testing.T) {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
if req.Model != "grok-3-mini" {
if req.Model != string(ModelGrok3Mini) {
t.Errorf("model = %q", req.Model)
}
// Should have system + user message.
Expand Down Expand Up @@ -103,7 +103,7 @@ func TestGenerate(t *testing.T) {
}))
defer srv.Close()

p := New("test-key", "grok-3-mini", WithBaseURL(srv.URL))
p := New("test-key", ModelGrok3Mini, WithBaseURL(srv.URL))
resp, err := p.Generate(context.Background(), forge.ProviderRequest{
SystemPrompt: "Be helpful.",
Messages: []forge.Message{
Expand Down Expand Up @@ -164,7 +164,7 @@ func TestGenerateWithFunctionCalls(t *testing.T) {
}))
defer srv.Close()

p := New("key", "grok-3-mini", WithBaseURL(srv.URL))
p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL))
resp, err := p.Generate(context.Background(), forge.ProviderRequest{
Messages: []forge.Message{{Role: forge.RoleUser, Content: "Weather in SF?"}},
Tools: []forge.ToolDefinition{{
Expand Down Expand Up @@ -235,7 +235,7 @@ func TestGenerateWithToolResults(t *testing.T) {
}))
defer srv.Close()

p := New("key", "grok-3-mini", WithBaseURL(srv.URL))
p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL))
resp, err := p.Generate(context.Background(), forge.ProviderRequest{
Messages: []forge.Message{
{Role: forge.RoleUser, Content: "Weather in SF?"},
Expand Down Expand Up @@ -306,7 +306,7 @@ func TestGenerateWithWebSearch(t *testing.T) {
}))
defer srv.Close()

p := New("key", "grok-3-mini",
p := New("key", ModelGrok3Mini,
WithBaseURL(srv.URL),
WithWebSearch(AllowedDomains("reuters.com")),
)
Expand Down Expand Up @@ -348,7 +348,7 @@ func TestGenerateAPIError(t *testing.T) {
}))
defer srv.Close()

p := New("key", "grok-3-mini", WithBaseURL(srv.URL))
p := New("key", ModelGrok3Mini, WithBaseURL(srv.URL))
_, err := p.Generate(context.Background(), forge.ProviderRequest{
Messages: []forge.Message{{Role: forge.RoleUser, Content: "Hi"}},
})
Expand Down
Loading