diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9c7f61f..71e10f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,6 @@ name: CI on: push: - branches: [main] pull_request: branches: [main] @@ -22,9 +21,23 @@ jobs: - name: Vet run: go vet ./... + - name: Vet examples + run: | + cd _examples/hello-world + go vet ./... + cd ../calculator + go vet ./... + - name: Test run: gotestsum --junitfile test-results.xml -- ./... -v + - name: Test examples + run: | + cd _examples/hello-world + go test ./... + cd ../calculator + go test ./... + - name: Test Summary uses: test-summary/action@v2 if: always() diff --git a/README.md b/README.md index 0341cda..6b9d599 100644 --- a/README.md +++ b/README.md @@ -43,16 +43,12 @@ func main() { log.Fatal(err) } - resp, err := agent.Run(context.Background(), forge.AgentRequest{ - Messages: []forge.Message{ - {Role: forge.RoleUser, Content: "Hello! What are you?"}, - }, - }) + resp, err := agent.Ask(context.Background(), "Hello! What are you?") if err != nil { log.Fatal(err) } - fmt.Println(resp.Messages[len(resp.Messages)-1].Content) + fmt.Println(resp.LastText()) } ``` @@ -125,13 +121,28 @@ type Tool interface { `Agent.Run` executes this loop: -1. Load conversation history from memory (if configured) +1. Load conversation history from memory (unless `DisableMemory` is set) 2. Call the provider with messages + tool definitions 3. If the provider says **stop** → return the response 4. If the provider requests **tool use** → execute tools, feed results back, go to 2 5. If **iteration limit** hit → return with `FinishReasonIterLimit` 6. 
Save conversation to memory +For the common case, use `Ask`: + +```go +resp, err := agent.Ask(ctx, "Hello") +fmt.Println(resp.LastText()) +``` + +Use `AskIn` when you want to manage multiple named conversations: + +```go +resp, err := agent.AskIn(ctx, "support-ticket-123", "What happened last?") +``` + +Use `Run` when you need full control over message roles, multiple messages, or advanced conversation wiring. + ### Error Policy Controls what happens when a tool returns an error: @@ -165,26 +176,30 @@ Middleware composes as decorators: given `[A, B, C]`, request flows `A → B → ### Memory -Persist conversations across `Agent.Run` calls: +Forge uses in-memory conversation history by default. Repeated `Ask` calls on the same agent continue the same default conversation: ```go -store := forge.NewInMemoryStore() - agent, _ := forge.NewAgent(forge.Config{ Provider: myProvider, - Memory: store, }) -// First call — starts a conversation. -resp, _ := agent.Run(ctx, forge.AgentRequest{ - ConversationID: "conv-1", - Messages: []forge.Message{{Role: forge.RoleUser, Content: "Hi"}}, -}) +resp, _ := agent.Ask(ctx, "My name is Ameer.") +resp, _ = agent.Ask(ctx, "What is my name?") +``` + +For named conversations: + +```go +resp, _ := agent.AskIn(ctx, "conv-1", "Hi") +resp, _ = agent.AskIn(ctx, "conv-1", "What did I just say?") +``` + +Disable memory explicitly for stateless agents: -// Second call — continues the same conversation. 
-resp, _ = agent.Run(ctx, forge.AgentRequest{ - ConversationID: "conv-1", - Messages: []forge.Message{{Role: forge.RoleUser, Content: "What did I just say?"}}, +```go +agent, _ := forge.NewAgent(forge.Config{ + Provider: myProvider, + DisableMemory: true, }) ``` diff --git a/_examples/calculator/main.go b/_examples/calculator/main.go index 71f310f..11899ca 100644 --- a/_examples/calculator/main.go +++ b/_examples/calculator/main.go @@ -107,7 +107,7 @@ func main() { agent, err := forge.NewAgent(forge.Config{ Provider: &MockProvider{}, Tools: []forge.Tool{addTool, mulTool}, - Middleware: []forge.Middleware{logging}, + Middleware: []forge.Middleware{logging}, SystemPrompt: "You are a helpful calculator assistant.", MaxIterations: 5, ErrorPolicy: forge.ErrorPolicyContinue, @@ -120,17 +120,13 @@ func main() { fmt.Println("User: What is 12 + 30?") fmt.Println(strings.Repeat("-", 40)) - resp, err := agent.Run(context.Background(), forge.AgentRequest{ - Messages: []forge.Message{ - {Role: forge.RoleUser, Content: "What is 12 + 30?"}, - }, - }) + resp, err := agent.Ask(context.Background(), "What is 12 + 30?") if err != nil { log.Fatal(err) } fmt.Println(strings.Repeat("-", 40)) - fmt.Printf("Assistant: %s\n", resp.Messages[len(resp.Messages)-1].Content) + fmt.Printf("Assistant: %s\n", resp.LastText()) fmt.Printf("Finish reason: %s\n", resp.FinishReason) fmt.Printf("Tokens: %d in, %d out\n", resp.Usage.InputTokens, resp.Usage.OutputTokens) fmt.Printf("Conversation: %d messages\n", len(resp.Messages)) diff --git a/_examples/hello-world/main.go b/_examples/hello-world/main.go index 212fe88..a85f490 100644 --- a/_examples/hello-world/main.go +++ b/_examples/hello-world/main.go @@ -70,17 +70,13 @@ func main() { log.Fatal(err) } - // Run it. - resp, err := agent.Run(context.Background(), forge.AgentRequest{ - Messages: []forge.Message{ - {Role: forge.RoleUser, Content: "Hello! What are you?"}, - }, - }) + // Ask preserves conversation history on this agent by default. 
+ resp, err := agent.Ask(context.Background(), "Hello! What are you?") if err != nil { log.Fatal(err) } - fmt.Println(resp.Messages[len(resp.Messages)-1].Content) + fmt.Println(resp.LastText()) fmt.Printf("\n[%s | tokens: %d in, %d out]\n", *providerFlag, resp.Usage.InputTokens, resp.Usage.OutputTokens) // Show citations if using xai-search. diff --git a/agent.go b/agent.go index 9de6e32..d83dfde 100644 --- a/agent.go +++ b/agent.go @@ -22,16 +22,28 @@ type AgentResponse struct { Errors []ToolError `json:"errors,omitempty"` } +// LastText returns the latest assistant text in the response. +func (r *AgentResponse) LastText() string { + for i := len(r.Messages) - 1; i >= 0; i-- { + msg := r.Messages[i] + if msg.Role == RoleAssistant && msg.Content != "" { + return msg.Content + } + } + return "" +} + // Agent orchestrates the LLM call → tool execution → response loop. type Agent struct { - provider Provider - registry *ToolRegistry - executor ToolExecutor - run RunFunc - memory MemoryStore - systemPrompt string - maxIterations int - errorPolicy ErrorPolicy + provider Provider + registry *ToolRegistry + executor ToolExecutor + run RunFunc + memory MemoryStore + defaultConversationID string + systemPrompt string + maxIterations int + errorPolicy ErrorPolicy } // NewAgent creates an Agent from the given Config. 
@@ -58,18 +70,37 @@ func NewAgent(cfg Config) (*Agent, error) { errorPolicy = ErrorPolicyStop } + memory := cfg.Memory + if memory == nil && !cfg.DisableMemory { + memory = NewInMemoryStore() + } + return &Agent{ - provider: cfg.Provider, - registry: registry, - executor: executor, - run: run, - memory: cfg.Memory, - systemPrompt: cfg.SystemPrompt, - maxIterations: cfg.MaxIterations, - errorPolicy: errorPolicy, + provider: cfg.Provider, + registry: registry, + executor: executor, + run: run, + memory: memory, + defaultConversationID: uuid.New().String(), + systemPrompt: cfg.SystemPrompt, + maxIterations: cfg.MaxIterations, + errorPolicy: errorPolicy, }, nil } +// Ask sends a user prompt in the agent's default conversation. +func (a *Agent) Ask(ctx context.Context, prompt string) (*AgentResponse, error) { + return a.AskIn(ctx, a.defaultConversationID, prompt) +} + +// AskIn sends a user prompt in the named conversation. +func (a *Agent) AskIn(ctx context.Context, conversationID, prompt string) (*AgentResponse, error) { + return a.Run(ctx, AgentRequest{ + ConversationID: conversationID, + Messages: []Message{UserMessage(prompt)}, + }) +} + // Run executes the agent loop per the design spec pseudocode. func (a *Agent) Run(ctx context.Context, req AgentRequest) (*AgentResponse, error) { conversationID := req.ConversationID diff --git a/agent_test.go b/agent_test.go index 34f48a6..9250af6 100644 --- a/agent_test.go +++ b/agent_test.go @@ -30,6 +30,26 @@ func (m *mockProvider) Generate(_ context.Context, _ ProviderRequest) (*Provider }, nil } +// recordingProvider stores provider requests so tests can inspect conversation history. 
+type recordingProvider struct { + responses []*ProviderResponse + requests []ProviderRequest + calls int +} + +func (r *recordingProvider) Generate(_ context.Context, req ProviderRequest) (*ProviderResponse, error) { + r.requests = append(r.requests, req) + i := r.calls + r.calls++ + if i < len(r.responses) { + return r.responses[i], nil + } + return &ProviderResponse{ + Message: Message{Role: RoleAssistant, Content: "default"}, + FinishReason: FinishReasonStop, + }, nil +} + func TestNewAgentNilProvider(t *testing.T) { _, err := NewAgent(Config{}) if err == nil { @@ -49,6 +69,130 @@ func TestNewAgentDefaultErrorPolicy(t *testing.T) { } } +func TestNewAgentDefaultsToInMemoryStore(t *testing.T) { + agent, err := NewAgent(Config{ + Provider: &mockProvider{}, + }) + if err != nil { + t.Fatalf("NewAgent error: %v", err) + } + if agent.memory == nil { + t.Fatal("expected default memory store") + } + if _, ok := agent.memory.(*InMemoryStore); !ok { + t.Fatalf("memory = %T, want *InMemoryStore", agent.memory) + } +} + +func TestNewAgentDisableMemory(t *testing.T) { + agent, err := NewAgent(Config{ + Provider: &mockProvider{}, + DisableMemory: true, + }) + if err != nil { + t.Fatalf("NewAgent error: %v", err) + } + if agent.memory != nil { + t.Fatalf("memory = %T, want nil", agent.memory) + } +} + +func TestAgentAskPreservesDefaultConversation(t *testing.T) { + provider := &recordingProvider{ + responses: []*ProviderResponse{ + { + Message: Message{Role: RoleAssistant, Content: "hello"}, + FinishReason: FinishReasonStop, + }, + { + Message: Message{Role: RoleAssistant, Content: "I remember"}, + FinishReason: FinishReasonStop, + }, + }, + } + + agent, err := NewAgent(Config{Provider: provider}) + if err != nil { + t.Fatalf("NewAgent error: %v", err) + } + + first, err := agent.Ask(context.Background(), "My name is Ameer.") + if err != nil { + t.Fatalf("Ask first error: %v", err) + } + second, err := agent.Ask(context.Background(), "What is my name?") + if err != nil { + 
t.Fatalf("Ask second error: %v", err) + } + + if first.ConversationID == "" { + t.Fatal("expected first response to include conversation ID") + } + if second.ConversationID != first.ConversationID { + t.Fatalf("conversation ID = %q, want %q", second.ConversationID, first.ConversationID) + } + if len(provider.requests) != 2 { + t.Fatalf("provider requests = %d, want 2", len(provider.requests)) + } + if len(provider.requests[1].Messages) != 3 { + t.Fatalf("second request messages = %d, want 3", len(provider.requests[1].Messages)) + } + if provider.requests[1].Messages[0].Content != "My name is Ameer." { + t.Errorf("first remembered message = %q", provider.requests[1].Messages[0].Content) + } +} + +func TestAgentAskInUsesNamedConversations(t *testing.T) { + provider := &recordingProvider{ + responses: []*ProviderResponse{ + {Message: Message{Role: RoleAssistant, Content: "forge noted"}, FinishReason: FinishReasonStop}, + {Message: Message{Role: RoleAssistant, Content: "other noted"}, FinishReason: FinishReasonStop}, + {Message: Message{Role: RoleAssistant, Content: "forge remembered"}, FinishReason: FinishReasonStop}, + }, + } + + agent, err := NewAgent(Config{Provider: provider}) + if err != nil { + t.Fatalf("NewAgent error: %v", err) + } + + if _, err := agent.AskIn(context.Background(), "forge", "Remember forge."); err != nil { + t.Fatalf("AskIn forge error: %v", err) + } + if _, err := agent.AskIn(context.Background(), "other", "Remember other."); err != nil { + t.Fatalf("AskIn other error: %v", err) + } + resp, err := agent.AskIn(context.Background(), "forge", "What did I ask you to remember?") + if err != nil { + t.Fatalf("AskIn forge follow-up error: %v", err) + } + + if resp.ConversationID != "forge" { + t.Fatalf("conversation ID = %q, want forge", resp.ConversationID) + } + if len(provider.requests[2].Messages) != 3 { + t.Fatalf("forge follow-up messages = %d, want 3", len(provider.requests[2].Messages)) + } + if provider.requests[2].Messages[0].Content != 
"Remember forge." { + t.Errorf("first forge message = %q", provider.requests[2].Messages[0].Content) + } +} + +func TestAgentResponseLastText(t *testing.T) { + resp := &AgentResponse{ + Messages: []Message{ + UserMessage("hello"), + {Role: RoleAssistant, Content: "first"}, + {Role: RoleTool, ToolResults: []ToolResult{{Content: "tool result"}}}, + {Role: RoleAssistant, Content: "latest"}, + }, + } + + if got := resp.LastText(); got != "latest" { + t.Fatalf("LastText() = %q, want latest", got) + } +} + func TestAgentRunStop(t *testing.T) { provider := &mockProvider{ responses: []*ProviderResponse{ diff --git a/config.go b/config.go index 91f15e4..78f0860 100644 --- a/config.go +++ b/config.go @@ -5,7 +5,8 @@ type Config struct { Provider Provider Tools []Tool Middleware []Middleware - Memory MemoryStore // optional, nil means no persistence + Memory MemoryStore // optional, defaults to in-memory unless DisableMemory is true + DisableMemory bool // optional, true means no conversation persistence SystemPrompt string // optional MaxIterations int // 0 means no limit ErrorPolicy ErrorPolicy // defaults to ErrorPolicyStop diff --git a/docs/design/design.md b/docs/design/design.md index 7680cce..41ba9eb 100644 --- a/docs/design/design.md +++ b/docs/design/design.md @@ -29,6 +29,16 @@ forge/memory/redis/ Redis-backed MemoryStore forge/executor/concurrent/ Parallel tool executor ``` +## Code Style & API Philosophy + +Forge uses progressive disclosure in both code and public API design. + +Top-level functions should read like intent. Complex behavior should be composed from small, named, single-responsibility helpers so a reader can understand what happens first, then drill into how each step works. + +The public API should make the common path obvious before exposing the lower-level machinery. 
Developers should be able to start with `Ask(ctx, prompt)`, use `AskIn(ctx, conversationID, prompt)` when they need named conversations, and drop to `Run(ctx, AgentRequest{...})` only when they need full control over roles, message history, or advanced orchestration. + +Package layout should follow the same rule as the code. The root package may remain a friendly facade, but implementation-heavy defaults should move toward focused packages as the library grows: agent orchestration, messages, tools, providers, memory, executors, and metadata. + --- ## 1. Core Types (`types.go`) @@ -56,11 +66,14 @@ type Message struct { ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolResults []ToolResult `json:"tool_results,omitempty"` } + +func UserMessage(content string) Message ``` - `ID` is assigned by the caller or generated (UUID) when empty. - `ToolCalls` is populated only on assistant messages when the LLM requests tool use. - `ToolResults` is populated only on tool-role messages returning results. +- `UserMessage` is the preferred helper for creating a single user prompt when using the lower-level `Run` API. ### ToolCall & ToolResult @@ -334,7 +347,8 @@ type Config struct { Provider Provider Tools []Tool Middleware []Middleware - Memory MemoryStore // optional, nil means no persistence + Memory MemoryStore // optional, defaults to in-memory unless DisableMemory is true + DisableMemory bool // optional, true means no conversation persistence SystemPrompt string // optional MaxIterations int // 0 means no limit ErrorPolicy ErrorPolicy // defaults to ErrorPolicyStop @@ -354,7 +368,9 @@ func NewAgent(cfg Config) (*Agent, error) - Builds a `ToolRegistry` from `cfg.Tools`. - Creates a `SequentialExecutor` with the registry. - Applies `cfg.Middleware` to build the composed `RunFunc`. +- Defaults `Memory` to `NewInMemoryStore()` when `cfg.Memory` is nil and `cfg.DisableMemory` is false. - Defaults `ErrorPolicy` to `ErrorPolicyStop` if empty. 
+- Creates a default conversation ID used by `Ask`. ### AgentRequest & AgentResponse @@ -371,8 +387,22 @@ type AgentResponse struct { Usage TokenUsage `json:"usage"` Errors []ToolError `json:"errors,omitempty"` } + +func (r *AgentResponse) LastText() string +``` + +### Convenience API + +```go +func (a *Agent) Ask(ctx context.Context, prompt string) (*AgentResponse, error) +func (a *Agent) AskIn(ctx context.Context, conversationID, prompt string) (*AgentResponse, error) ``` +- `Ask` sends a user prompt to the agent's default conversation. +- `AskIn` sends a user prompt to a named conversation. +- Both return the full `AgentResponse`, preserving access to conversation ID, token usage, finish reason, errors, and message history. +- `LastText` returns the latest assistant text content in the response. + ### Agent Loop — `Agent.Run(ctx, req)` ``` @@ -451,7 +481,8 @@ Pseudocode: - Tool errors with `ErrorPolicyContinue` are collected in `errors` but the loop continues, letting the LLM see the error and adapt. - Tool errors with `ErrorPolicyStop` break the loop immediately but still include the tool results in the message history. - `FinishReasonToolUse` never appears in the final `AgentResponse` — the loop always processes tool calls. -- Memory is saved once at the end, with the complete conversation. +- Memory is enabled by default with an in-memory store. It is saved once at the end, with the complete conversation. +- File-backed or database-backed memory must be explicitly configured because persistence changes privacy and lifecycle expectations. - Context cancellation is respected: `composedRunFunc` and `tool.Invoke` should check `ctx`. 
--- diff --git a/tool_test.go b/tool_test.go index 4a03199..f4d4c54 100644 --- a/tool_test.go +++ b/tool_test.go @@ -83,6 +83,16 @@ func TestFuncInvokeInvalidArgs(t *testing.T) { } } +func TestUserMessage(t *testing.T) { + msg := UserMessage("hello") + if msg.Role != RoleUser { + t.Fatalf("Role = %q, want %q", msg.Role, RoleUser) + } + if msg.Content != "hello" { + t.Fatalf("Content = %q, want hello", msg.Content) + } +} + func TestRegistryRegisterAndGet(t *testing.T) { r := NewToolRegistry() tool := Func[addInput]("add", "adds", func(_ context.Context, in addInput) (string, error) { diff --git a/types.go b/types.go index 736f001..080ad97 100644 --- a/types.go +++ b/types.go @@ -21,6 +21,11 @@ type Message struct { ToolResults []ToolResult `json:"tool_results,omitempty"` } +// UserMessage creates a user-role message with the given content. +func UserMessage(content string) Message { + return Message{Role: RoleUser, Content: content} +} + // ToolCall represents a request from the LLM to invoke a tool. type ToolCall struct { ID string `json:"id"`