diff --git a/.design/a2a-sdk-migration.md b/.design/a2a-sdk-migration.md new file mode 100644 index 000000000..cee5e8648 --- /dev/null +++ b/.design/a2a-sdk-migration.md @@ -0,0 +1,127 @@ +# A2A Go SDK Migration + +## Status: In Progress +## Date: 2026-06-08 + +## Summary + +Migrate the scion-a2a-bridge from a hand-rolled A2A protocol implementation to +the official `a2a-go` SDK (`github.com/a2aproject/a2a-go/v2`). This replaces +our custom JSON-RPC handling, task lifecycle management, and streaming +infrastructure with the SDK's spec-compliant implementations while preserving +our Scion Hub routing core. + +## Motivation + +- **Spec compliance**: The SDK tracks the A2A spec automatically. Our hand-rolled + implementation required manual updates for each spec revision. +- **Reduced maintenance**: ~500 lines of JSON-RPC, SSE streaming, and task store + code replaced by SDK. +- **Multi-transport**: SDK provides JSON-RPC, REST, and gRPC transports from a + single `RequestHandler` — we get gRPC and REST nearly for free. +- **Correctness**: SDK handles edge cases (OCC, concurrent cancellation, event + ordering) that our MVP implementation simplified or deferred. + +## Architecture + +### Before (hand-rolled) + +``` +HTTP Request → server.go (JSON-RPC dispatch) → bridge.go (task management) + → Hub API → Broker → bridge.go (response correlation) → JSON-RPC response +``` + +### After (SDK-based) + +``` +HTTP Request → auth middleware → route extraction → SDK JSONRPC Handler + → SDK RequestHandler → SDK task lifecycle → ScionExecutor.Execute() + → bridge.go (Hub routing) → Broker → waiter channel → SDK events + → SDK response serialization → HTTP response +``` + +### Key Components + +**ScionExecutor** (`executor.go`): Implements `a2asrv.AgentExecutor`. The bridge +between the SDK's event-driven model and our Scion Hub message routing. + +- `Execute()`: Translates SDK message → Scion StructuredMessage, sends to Hub, + waits for broker response, yields SDK events. +- `Cancel()`: Sends interrupt to Scion agent, yields canceled status event. + +**Server** (`server.go`): Simplified HTTP routing layer. Handles: +- Multi-project/agent URL routing (`/projects/{p}/agents/{a}/jsonrpc`) +- Agent card serving (kept custom — SDK's card handler is single-agent) +- Auth middleware, rate limiting, metrics (unchanged) +- Delegates JSON-RPC to SDK's `NewJSONRPCHandler` + +**Bridge** (`bridge.go`): Core Hub routing preserved. Changes: +- Added `sdkRequestHandler` field for multi-transport access +- Task lifecycle now managed by SDK's in-memory task store +- SQLite store retained for context mapping and broker correlation + +**Translate** (`translate.go`): Added SDK-compatible translation functions: +- `TranslateA2APartsToScion()`: SDK `a2a.ContentParts` → Scion message +- `TranslateScionToA2AParts()`: Scion message → SDK `a2a.Message` + `a2a.Artifact` +- `MapActivityToSDKTaskState()`: Scion activity → SDK `a2a.TaskState` +- Original functions retained for backward compatibility + +## What Changed + +| Component | Before | After | +|-----------|--------|-------| +| JSON-RPC parsing | `server.go` hand-rolled | SDK `a2asrv.NewJSONRPCHandler` | +| Task lifecycle | `bridge.go` + SQLite | SDK in-memory task store | +| SSE streaming | `stream.go` custom | SDK built-in | +| Push notifications | `push.go` custom | SDK `push.Sender` (future) | +| A2A types | `translate.go` custom structs | SDK `a2a` package | +| Error codes | Custom constants | SDK `a2a.Err*` sentinel errors | + +## What's Preserved + +- **Bridge core**: Hub client routing, broker plugin, agent lookup, context + resolution, auto-provisioning — all unchanged. +- **Config**: Same YAML format, same fields. +- **Auth**: Same API key / Bearer middleware. +- **Metrics**: Same Prometheus metrics. +- **Rate limiting**: Same per-IP/key token bucket. +- **Broker plugin**: Same go-plugin RPC server. +- **SQLite store**: Retained for context mapping. Task state now also in SDK + in-memory store. + +## PR Structure + +### PR A: SDK Adoption (`a2a/sdk-migration`) +- Add `a2a-go/v2` dependency +- New `executor.go` (AgentExecutor implementation) +- Rewritten `server.go` (SDK handler delegation) +- Updated `translate.go` (SDK type translations) +- Updated `bridge.go` (sdkRequestHandler field) +- Updated `main.go` (SDK wiring) +- Updated tests + +### PR B: gRPC + REST Transports (`a2a/sdk-grpc-rest`) +- `a2agrpc.NewHandler` for gRPC transport +- `a2asrv.NewRESTHandler` for REST transport +- Config fields: `grpc_listen_address`, `rest_listen_address` +- Startup wiring in `main.go` + +## Migration Risks + +1. **Task store divergence**: SDK uses in-memory store; our SQLite store tracks + context mappings separately. Tasks visible via A2A protocol come from SDK + store; context lookups use SQLite. + +2. **Broker correlation**: The SDK doesn't know about our broker. Response + correlation happens inside `ScionExecutor.Execute()` using the same waiter + channel pattern. + +3. **Push notification gap**: SDK has `push.Sender` interface but we haven't + wired our SSRF-safe push dispatcher yet. Push is disabled in capabilities. + +## Future Work + +- Wire SDK push notification support with our SSRF-safe dispatcher +- Implement SDK `taskstore.Store` interface backed by SQLite for persistence +- Add multi-turn conversation support (SDK handles it; our executor needs updates) +- Evaluate SDK's work queue for distributed deployment diff --git a/extras/scion-a2a-bridge/cmd/scion-a2a-bridge/main.go b/extras/scion-a2a-bridge/cmd/scion-a2a-bridge/main.go index d75579504..dbd810f08 100644 --- a/extras/scion-a2a-bridge/cmd/scion-a2a-bridge/main.go +++ b/extras/scion-a2a-bridge/cmd/scion-a2a-bridge/main.go @@ -29,9 +29,15 @@ import ( "syscall" "time" + "net" + secretmanager "cloud.google.com/go/secretmanager/apiv1" smpb "cloud.google.com/go/secretmanager/apiv1/secretmanagerpb" + "github.com/a2aproject/a2a-go/v2/a2a" + a2agrpc "github.com/a2aproject/a2a-go/v2/a2agrpc/v0" + "github.com/a2aproject/a2a-go/v2/a2asrv" "github.com/prometheus/client_golang/prometheus" + "google.golang.org/grpc" "gopkg.in/yaml.v3" "github.com/GoogleCloudPlatform/scion/extras/scion-a2a-bridge/internal/bridge" @@ -136,20 +142,39 @@ func main() { // Wire broker into the bridge for subscription management. b.SetBroker(broker) + // Create SDK executor and request handler. + executor := bridge.NewScionExecutor(b, log.With("component", "executor")) + sdkRequestHandler := a2asrv.NewHandler( + executor, + a2asrv.WithLogger(log.With("component", "a2a-sdk")), + a2asrv.WithCapabilityChecks(&a2a.AgentCapabilities{ + Streaming: true, + PushNotifications: false, + }), + a2asrv.WithAgentInactivityTimeout(cfg.Timeouts.SendMessage), + ) + b.SetSDKRequestHandler(sdkRequestHandler) + + // Create SDK JSON-RPC transport handler. + sdkJSONRPCHandler := a2asrv.NewJSONRPCHandler( + sdkRequestHandler, + a2asrv.WithTransportKeepAlive(cfg.Timeouts.SSEKeepalive), + ) + // Start A2A HTTP server. listenAddr := cfg.Bridge.ListenAddress if listenAddr == "" { listenAddr = ":8443" } - srv := bridge.NewServer(b, cfg, metrics, log.With("component", "a2a-server")) + srv := bridge.NewServer(b, cfg, metrics, log.With("component", "a2a-server"), sdkJSONRPCHandler) srv.WarnOnOpenAuth() httpServer := &http.Server{ Addr: listenAddr, Handler: srv.Handler(), ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, + WriteTimeout: 0, // Disabled for SSE connections; SDK handles timeouts. IdleTimeout: 120 * time.Second, MaxHeaderBytes: 1 << 20, } @@ -163,7 +188,118 @@ func main() { } }() - log.Info("scion-a2a-bridge ready") + // Start gRPC server if configured. + // NOTE: gRPC and REST transports require a single-project, single-agent + // configuration because they lack per-request project/agent routing. The + // executor injects the configured default route into every request context. + // Auth is also not applied to these transports — secure them via network + // policy or a proxy. + var grpcServer *grpc.Server + if cfg.Bridge.GRPCListenAddress != "" { + if len(cfg.Projects) == 0 || len(cfg.Projects[0].ExposedAgents) == 0 { + log.Error("gRPC transport requires at least one project with exposed agents in config") + os.Exit(1) + } + defaultRoute := bridge.RouteInfo{ + ProjectSlug: cfg.Projects[0].Slug, + AgentSlug: cfg.Projects[0].ExposedAgents[0], + } + log.Warn("gRPC transport uses fixed routing — all requests go to the first configured agent", + "project", defaultRoute.ProjectSlug, "agent", defaultRoute.AgentSlug) + + if !cfg.Bridge.GRPCInsecure { + log.Warn("gRPC transport: auth enabled — clients must provide credentials via gRPC metadata", + "scheme", cfg.Auth.Scheme) + } else { + log.Warn("⚠ gRPC transport: auth DISABLED (grpc_insecure: true) — any client can send requests without credentials", + "address", cfg.Bridge.GRPCListenAddress) + } + + grpcServer = grpc.NewServer( + grpc.ChainUnaryInterceptor( + bridge.AuthUnaryInterceptor(cfg), + bridge.RouteInfoUnaryInterceptor(defaultRoute), + ), + grpc.ChainStreamInterceptor( + bridge.AuthStreamInterceptor(cfg), + bridge.RouteInfoStreamInterceptor(defaultRoute), + ), + ) + grpcHandler := a2agrpc.NewHandler(sdkRequestHandler) + grpcHandler.RegisterWith(grpcServer) + + grpcListener, err := net.Listen("tcp", cfg.Bridge.GRPCListenAddress) + if err != nil { + log.Error("failed to listen for gRPC", "address", cfg.Bridge.GRPCListenAddress, "error", err) + os.Exit(1) + } + + go func() { + log.Info("gRPC transport starting", "address", cfg.Bridge.GRPCListenAddress) + if err := grpcServer.Serve(grpcListener); err != nil { + errCh <- fmt.Errorf("gRPC server: %w", err) + } + }() + } + + // Start REST server if configured. + var restServer *http.Server + if cfg.Bridge.RESTListenAddress != "" { + if len(cfg.Projects) == 0 || len(cfg.Projects[0].ExposedAgents) == 0 { + log.Error("REST transport requires at least one project with exposed agents in config") + os.Exit(1) + } + defaultRoute := bridge.RouteInfo{ + ProjectSlug: cfg.Projects[0].Slug, + AgentSlug: cfg.Projects[0].ExposedAgents[0], + } + log.Warn("REST transport uses fixed routing — all requests go to the first configured agent", + "project", defaultRoute.ProjectSlug, "agent", defaultRoute.AgentSlug) + if !cfg.Bridge.RESTInsecure { + log.Warn("REST transport: auth enabled — clients must provide credentials via HTTP headers", + "scheme", cfg.Auth.Scheme) + } else { + log.Warn("⚠ REST transport: auth DISABLED (rest_insecure: true) — any client can send requests without credentials", + "address", cfg.Bridge.RESTListenAddress) + } + + restHandler := bridge.AuthHTTPMiddleware(cfg, bridge.RouteInfoMiddleware(defaultRoute, + bridge.SSEWriteDeadlineMiddleware( + a2asrv.NewRESTHandler( + sdkRequestHandler, + a2asrv.WithTransportKeepAlive(cfg.Timeouts.SSEKeepalive), + ), + ), + )) + + restServer = &http.Server{ + Addr: cfg.Bridge.RESTListenAddress, + Handler: restHandler, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + MaxHeaderBytes: 1 << 20, + } + + go func() { + log.Info("REST transport starting", "address", cfg.Bridge.RESTListenAddress) + if err := restServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + errCh <- fmt.Errorf("REST server: %w", err) + } + }() + } + + transports := []string{"JSON-RPC"} + if cfg.Bridge.GRPCListenAddress != "" { + transports = append(transports, "gRPC") + } + if cfg.Bridge.RESTListenAddress != "" { + transports = append(transports, "REST") + } + log.Info("scion-a2a-bridge ready", + "transports", transports, + "sdk", "a2a-go/v2", + ) // Wait for shutdown signal. sigCh := make(chan os.Signal, 1) @@ -183,6 +319,28 @@ func main() { log.Error("failed to stop A2A server", "error", err) } + if grpcServer != nil { + grpcStopped := make(chan struct{}) + go func() { + grpcServer.GracefulStop() + close(grpcStopped) + }() + select { + case <-grpcStopped: + log.Info("gRPC server stopped gracefully") + case <-shutdownCtx.Done(): + log.Warn("gRPC graceful shutdown timed out, forcing stop") + grpcServer.Stop() + } + } + + if restServer != nil { + if err := restServer.Shutdown(shutdownCtx); err != nil { + log.Error("failed to stop REST server", "error", err) + } + log.Info("REST server stopped") + } + // Drain background goroutines before closing the store. b.Shutdown() diff --git a/extras/scion-a2a-bridge/go.mod b/extras/scion-a2a-bridge/go.mod index e5d26ce66..5deac8a68 100644 --- a/extras/scion-a2a-bridge/go.mod +++ b/extras/scion-a2a-bridge/go.mod @@ -5,12 +5,14 @@ go 1.26.1 require ( cloud.google.com/go/secretmanager v1.16.0 github.com/GoogleCloudPlatform/scion v0.0.0-00010101000000-000000000000 + github.com/a2aproject/a2a-go/v2 v2.3.1 github.com/go-jose/go-jose/v4 v4.1.4 github.com/google/uuid v1.6.0 github.com/hashicorp/go-plugin v1.7.0 github.com/mattn/go-sqlite3 v1.14.28 github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_model v0.6.2 + google.golang.org/grpc v1.80.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -19,6 +21,7 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/iam v1.5.3 // indirect + github.com/a2aproject/a2a-go v0.3.15 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/fatih/color v1.16.0 // indirect @@ -45,6 +48,7 @@ require ( go.opentelemetry.io/otel/trace v1.43.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/crypto v0.49.0 // indirect + golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sync v0.20.0 // indirect @@ -54,9 +58,8 @@ require ( golang.org/x/time v0.14.0 // indirect google.golang.org/api v0.259.0 // indirect google.golang.org/genproto v0.0.0-20251202230838-ff82c1b0f217 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect - google.golang.org/grpc v1.80.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 // indirect google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/extras/scion-a2a-bridge/go.sum b/extras/scion-a2a-bridge/go.sum index b7f2c7d7f..b7e4eb4b4 100644 --- a/extras/scion-a2a-bridge/go.sum +++ b/extras/scion-a2a-bridge/go.sum @@ -10,6 +10,10 @@ cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= cloud.google.com/go/secretmanager v1.16.0 h1:19QT7ZsLJ8FSP1k+4esQvuCD7npMJml6hYzilxVyT+k= cloud.google.com/go/secretmanager v1.16.0/go.mod h1://C/e4I8D26SDTz1f3TQcddhcmiC3rMEl0S1Cakvs3Q= +github.com/a2aproject/a2a-go v0.3.15 h1:h5YpCiPq3jxQ5rIns7oDjPag3ivP8u817AzdA4F+NiI= +github.com/a2aproject/a2a-go v0.3.15/go.mod h1:I7Cm+a1oL+UT6zMoP+roaRE5vdfUa1iQGVN8aSOuZ0I= +github.com/a2aproject/a2a-go/v2 v2.3.1 h1:QWMdOX2UsJ8BJmjs952eo1FRyGsOVl0gFCKeM76AgGE= +github.com/a2aproject/a2a-go/v2 v2.3.1/go.mod h1:mkZr8y2bUgAVQsjs/5fHK7xrRlAHDybMEyxWh2tKRC8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/FBatYVw= @@ -122,6 +126,8 @@ go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= @@ -148,10 +154,10 @@ google.golang.org/api v0.259.0 h1:90TaGVIxScrh1Vn/XI2426kRpBqHwWIzVBzJsVZ5XrQ= google.golang.org/api v0.259.0/go.mod h1:LC2ISWGWbRoyQVpxGntWwLWN/vLNxxKBK9KuJRI8Te4= google.golang.org/genproto v0.0.0-20251202230838-ff82c1b0f217 h1:GvESR9BIyHUahIb0NcTum6itIWtdoglGX+rnGxm2934= google.golang.org/genproto v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:yJ2HH4EHEDTd3JiLmhds6NkJ17ITVYOdV3m3VKOnws0= -google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= -google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4 h1:yOzSCGPx+cp5VO7IxvZ9SBFF7j1tZVcNtlHR2iYKtVo= +google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:Q9HWtNeE7tM9npdIsEvqXj1QJIvVoeAV3rtXtS715Cw= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 h1:tEkOQcXgF6dH1G+MVKZrfpYvozGrzb91k6ha7jireSM= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= diff --git a/extras/scion-a2a-bridge/internal/bridge/bridge.go b/extras/scion-a2a-bridge/internal/bridge/bridge.go index 7742413ef..d21d2c666 100644 --- a/extras/scion-a2a-bridge/internal/bridge/bridge.go +++ b/extras/scion-a2a-bridge/internal/bridge/bridge.go @@ -23,6 +23,7 @@ import ( "sync" "time" + "github.com/a2aproject/a2a-go/v2/a2asrv" "github.com/google/uuid" "github.com/GoogleCloudPlatform/scion/extras/scion-a2a-bridge/internal/identity" @@ -58,6 +59,9 @@ type Bridge struct { metrics *Metrics log *slog.Logger + // sdkRequestHandler holds the SDK RequestHandler for multi-transport use (gRPC, REST). + sdkRequestHandler a2asrv.RequestHandler + // waiters tracks channels waiting for agent responses, keyed by taskID. mu sync.RWMutex waiters map[string]*waiter @@ -229,6 +233,11 @@ func (b *Bridge) SetBroker(broker *BrokerServer) { b.broker = broker } +// SetSDKRequestHandler stores the SDK RequestHandler for multi-transport access. +func (b *Bridge) SetSDKRequestHandler(h a2asrv.RequestHandler) { + b.sdkRequestHandler = h +} + // agentKey returns a composite key for project-scoped agent isolation. func agentKey(projectID, agentSlug string) string { return projectID + ":" + agentSlug @@ -640,9 +649,27 @@ func (b *Bridge) dispatchBrokerMessage(topic string, msg *messages.StructuredMes // If the message carries a task correlation ID, dispatch only to that task // after verifying the message's agent matches the task's owner. if taskID := msg.Metadata["a2aTaskId"]; taskID != "" { + // Try waiter first — SDK-created tasks (via AgentExecutor) may not + // be stored in the local SQLite store, but they register a waiter + // for blocking response correlation. Check the waiter before the + // store to avoid dropping responses for SDK-managed tasks. + if b.dispatchToWaiter(taskID, msg) { + return + } + task, err := b.store.GetTask(taskID) if err != nil || task == nil { - b.log.Debug("ignoring message for unknown task", "task_id", taskID) + // Also check if the task is registered as active (SDK executor + // registers in activeTasks even without a store entry). + b.tasksMu.RLock() + _, isActive := b.activeTasks[taskID] + b.tasksMu.RUnlock() + if !isActive { + b.log.Debug("ignoring message for unknown task", "task_id", taskID) + return + } + // Active but not in store — SDK-managed task, dispatch via active path. + b.dispatchToActiveTask(ctx, taskID, agentSlug, msg) return } if task.AgentSlug != agentSlug { @@ -651,9 +678,6 @@ func (b *Bridge) dispatchBrokerMessage(topic string, msg *messages.StructuredMes return } - if b.dispatchToWaiter(taskID, msg) { - return - } b.tasksMu.RLock() _, isActive := b.activeTasks[taskID] b.tasksMu.RUnlock() @@ -696,7 +720,8 @@ func (b *Bridge) dispatchBrokerMessage(topic string, msg *messages.StructuredMes // dispatchToWaiter sends a message to a blocking waiter for the given taskID. // Returns true if a waiter exists and handled the message (callers should skip // further dispatch). State-change messages are skipped so the actual reply -// lands in the buffer. +// lands in the buffer. Verifies the message sender's agent slug matches the +// waiter's expected agent to prevent cross-agent message injection. func (b *Bridge) dispatchToWaiter(taskID string, msg *messages.StructuredMessage) bool { b.mu.RLock() w, ok := b.waiters[taskID] @@ -704,6 +729,20 @@ func (b *Bridge) dispatchToWaiter(taskID string, msg *messages.StructuredMessage if !ok { return false } + // Verify agent ownership: the waiter's expected agent must match the + // message sender's agent slug. This prevents a response from Agent B + // being delivered to a task that was started for Agent A. + if w.agentSlug != "" { + senderAgent := extractAgentIDFromSender(msg.Sender) + if senderAgent != "" && senderAgent != w.agentSlug { + b.log.Warn("dropping cross-agent message for waiter", + "task_id", taskID, + "expected_agent", w.agentSlug, + "sender_agent", senderAgent, + ) + return true // consumed but rejected — don't fall through to other dispatch paths + } + } if msg.Type == messages.TypeStateChange { // Terminal state-changes must still be persisted to the DB even though // we skip the waiter — otherwise the task's stored state is never updated. @@ -878,7 +917,7 @@ func (b *Bridge) GenerateAgentCard(ctx context.Context, projectSlug, agentSlug s "version": "1.0.0", "capabilities": map[string]bool{ "streaming": true, - "pushNotifications": true, + "pushNotifications": false, }, "defaultInputModes": []string{"text/plain", "application/json"}, "defaultOutputModes": []string{"text/plain", "application/json"}, diff --git a/extras/scion-a2a-bridge/internal/bridge/config.go b/extras/scion-a2a-bridge/internal/bridge/config.go index 72784df0a..b579a1c31 100644 --- a/extras/scion-a2a-bridge/internal/bridge/config.go +++ b/extras/scion-a2a-bridge/internal/bridge/config.go @@ -32,10 +32,14 @@ type Config struct { // BridgeConfig holds the A2A protocol server settings. type BridgeConfig struct { - ListenAddress string `yaml:"listen_address"` - ExternalURL string `yaml:"external_url"` - MaxSubscribers int `yaml:"max_subscribers"` - Provider ProviderConfig `yaml:"provider"` + ListenAddress string `yaml:"listen_address"` + GRPCListenAddress string `yaml:"grpc_listen_address"` + RESTListenAddress string `yaml:"rest_listen_address"` + GRPCInsecure bool `yaml:"grpc_insecure"` + RESTInsecure bool `yaml:"rest_insecure"` + ExternalURL string `yaml:"external_url"` + MaxSubscribers int `yaml:"max_subscribers"` + Provider ProviderConfig `yaml:"provider"` } // ProviderConfig describes the bridge operator. diff --git a/extras/scion-a2a-bridge/internal/bridge/executor.go b/extras/scion-a2a-bridge/internal/bridge/executor.go new file mode 100644 index 000000000..fa17f8ec4 --- /dev/null +++ b/extras/scion-a2a-bridge/internal/bridge/executor.go @@ -0,0 +1,405 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bridge + +import ( + "context" + "fmt" + "iter" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/a2aproject/a2a-go/v2/a2a" + "github.com/a2aproject/a2a-go/v2/a2asrv" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/GoogleCloudPlatform/scion/pkg/messages" +) + +// routeKey is a context key for passing project/agent routing info to the executor. +type routeKey struct{} + +// RouteInfo carries the project and agent slugs extracted from the HTTP path +// so the executor knows which Scion agent to route to. +type RouteInfo struct { + ProjectSlug string + AgentSlug string +} + +// WithRouteInfo attaches routing metadata to a context. +func WithRouteInfo(ctx context.Context, info RouteInfo) context.Context { + return context.WithValue(ctx, routeKey{}, info) +} + +// RouteInfoFrom extracts routing metadata from a context. +func RouteInfoFrom(ctx context.Context) (RouteInfo, bool) { + info, ok := ctx.Value(routeKey{}).(RouteInfo) + return info, ok +} + +// ScionExecutor implements a2asrv.AgentExecutor, bridging the SDK's event model +// to the Scion Hub message routing. Each Execute call: +// 1. Translates the SDK message to a Scion StructuredMessage +// 2. Sends it to the target agent via Hub +// 3. Waits for the agent response via the broker +// 4. Translates the response back to SDK events +type ScionExecutor struct { + bridge *Bridge + log *slog.Logger +} + +var _ a2asrv.AgentExecutor = (*ScionExecutor)(nil) + +// NewScionExecutor creates a new executor that routes A2A requests to Scion agents. +func NewScionExecutor(bridge *Bridge, log *slog.Logger) *ScionExecutor { + return &ScionExecutor{bridge: bridge, log: log} +} + +// Execute implements a2asrv.AgentExecutor. It routes the incoming A2A message +// to a Scion agent and yields events as the agent responds. +func (e *ScionExecutor) Execute(ctx context.Context, execCtx *a2asrv.ExecutorContext) iter.Seq2[a2a.Event, error] { + return func(yield func(a2a.Event, error) bool) { + if execCtx == nil { + yield(nil, fmt.Errorf("executor context is nil: %w", a2a.ErrInternalError)) + return + } + route, ok := RouteInfoFrom(ctx) + if !ok { + yield(nil, fmt.Errorf("missing route info in context: %w", a2a.ErrInternalError)) + return + } + + taskID := execCtx.TaskID + + if e.bridge.hubClient == nil { + yield(nil, fmt.Errorf("hub client not configured: %w", a2a.ErrInternalError)) + return + } + + // Resolve the Scion agent context (agent ID, project ID). + // TODO(multi-turn): Pass execCtx.ContextID here to reuse existing + // Scion contexts for multi-turn conversations. Currently always creates + // a new context, breaking agents that use input-required → completed flows. + agentCtx, err := e.bridge.resolveContext(ctx, route.ProjectSlug, route.AgentSlug, "") + if err != nil { + yield(nil, fmt.Errorf("resolve agent: %w", err)) + return + } + + // Emit submitted task. + if execCtx.StoredTask == nil { + task := a2a.NewSubmittedTask(execCtx, execCtx.Message) + if !yield(task, nil) { + return + } + } + + // Translate A2A message parts to Scion format. + scionMsg := TranslateA2APartsToScion(execCtx.Message.Parts) + scionMsg.Sender = fmt.Sprintf("user:%s", e.bridge.config.Hub.User) + scionMsg.Recipient = fmt.Sprintf("agent:%s", agentCtx.AgentSlug) + scionMsg.Metadata = map[string]string{"a2aTaskId": string(taskID)} + + // Request broker subscription for responses. + if e.bridge.broker != nil { + pattern := fmt.Sprintf("scion.project.%s.user.%s.messages", agentCtx.ProjectID, e.bridge.config.Hub.User) + if err := e.bridge.broker.RequestSubscription(pattern); err != nil { + e.log.Warn("failed to request subscription", "pattern", pattern, "error", err) + } + legacyPattern := fmt.Sprintf("scion.grove.%s.user.%s.messages", agentCtx.ProjectID, e.bridge.config.Hub.User) + if err := e.bridge.broker.RequestSubscription(legacyPattern); err != nil { + e.log.Warn("failed to request legacy subscription", "pattern", legacyPattern, "error", err) + } + } + + // Register active task for broker correlation. + aKey := agentKey(agentCtx.ProjectID, agentCtx.AgentSlug) + e.bridge.registerActiveTask(string(taskID), aKey) + defer e.bridge.unregisterActiveTask(string(taskID), aKey) + + // Set up response channel. + responseCh := make(chan *messages.StructuredMessage, 1) + e.bridge.addWaiter(string(taskID), &waiter{ + ch: responseCh, + agentSlug: agentCtx.AgentSlug, + projectID: agentCtx.ProjectID, + }) + defer e.bridge.removeWaiter(string(taskID)) + + // Send to Hub. + if _, err := e.bridge.hubClient.Agents().SendStructuredMessage(ctx, agentCtx.AgentID, scionMsg, false, false, false); err != nil { + failMsg := a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart(fmt.Sprintf("Failed to send message to agent: %v", err))) + yield(a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateFailed, failMsg), nil) + return + } + + // Emit working status. + if !yield(a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateWorking, nil), nil) { + return + } + + if e.bridge.metrics != nil { + e.bridge.metrics.TasksCreated.WithLabelValues(agentCtx.ProjectID).Inc() + } + + // Wait for agent response. + timeout := e.bridge.config.Timeouts.SendMessage + if timeout == 0 { + timeout = 120 * time.Second + } + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case response, ok := <-responseCh: + if !ok || response == nil { + yield(nil, fmt.Errorf("response channel closed without a reply")) + return + } + agentMsg, artifacts := TranslateScionToA2AParts(response) + + // Emit artifact events using SDK constructors. + for _, art := range artifacts { + artEvent := a2a.NewArtifactEvent(execCtx, art.Parts...) + artEvent.LastChunk = true + if art.Name != "" { + artEvent.Artifact.Name = art.Name + } + if art.Description != "" { + artEvent.Artifact.Description = art.Description + } + if !yield(artEvent, nil) { + return + } + } + + // Emit completed status with agent message. + statusMsg := a2a.NewMessageForTask(a2a.MessageRoleAgent, execCtx, agentMsg.Parts...) + yield(a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateCompleted, statusMsg), nil) + + if e.bridge.metrics != nil { + e.bridge.metrics.TasksCompleted.WithLabelValues("completed").Inc() + } + + case <-timer.C: + failMsg := a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart(fmt.Sprintf("Timeout waiting for agent response after %v", timeout))) + yield(a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateFailed, failMsg), nil) + + if e.bridge.metrics != nil { + e.bridge.metrics.TasksCompleted.WithLabelValues("failed").Inc() + } + + case <-ctx.Done(): + failMsg := a2a.NewMessage(a2a.MessageRoleAgent, a2a.NewTextPart("Request cancelled")) + yield(a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateFailed, failMsg), nil) + + if e.bridge.metrics != nil { + e.bridge.metrics.TasksCompleted.WithLabelValues("failed").Inc() + } + } + } +} + +// Cancel implements a2asrv.AgentExecutor. It sends an interrupt to the Scion +// agent and emits a canceled status. +func (e *ScionExecutor) Cancel(ctx context.Context, execCtx *a2asrv.ExecutorContext) iter.Seq2[a2a.Event, error] { + return func(yield func(a2a.Event, error) bool) { + taskID := execCtx.TaskID + + // Look up the stored task to find the agent and send an interrupt. + if execCtx.StoredTask != nil && e.bridge.hubClient != nil { + route, ok := RouteInfoFrom(ctx) + if ok { + if agent := e.bridge.lookupAgent(ctx, route.ProjectSlug, route.AgentSlug); agent != nil { + interruptMsg := &messages.StructuredMessage{ + Version: 1, + Timestamp: time.Now().UTC().Format(time.RFC3339), + Sender: fmt.Sprintf("user:%s", e.bridge.config.Hub.User), + Recipient: fmt.Sprintf("agent:%s", route.AgentSlug), + Msg: "Task cancelled by A2A client.", + Type: messages.TypeInstruction, + Metadata: map[string]string{"a2aTaskId": string(taskID)}, + } + if _, err := e.bridge.hubClient.Agents().SendStructuredMessage(ctx, agent.ID, interruptMsg, true, false, false); err != nil { + e.log.Error("failed to send cancel interrupt", "error", err, "task_id", taskID) + } + } + } else { + e.log.Warn("cancel: missing route info in context, cannot send interrupt to agent", "task_id", taskID) + } + } + + yield(a2a.NewStatusUpdateEvent(execCtx, a2a.TaskStateCanceled, nil), nil) + } +} + +// SSEWriteDeadlineMiddleware wraps an http.Handler to clear the write deadline +// for SSE (text/event-stream) responses, allowing long-lived streaming +// connections while keeping WriteTimeout enabled for non-streaming endpoints. +func SSEWriteDeadlineMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(&sseDeadlineWriter{ResponseWriter: w}, r) + }) +} + +// sseDeadlineWriter intercepts WriteHeader to clear the write deadline when +// the response is an SSE stream (Content-Type: text/event-stream). +type sseDeadlineWriter struct { + http.ResponseWriter + cleared bool +} + +func (s *sseDeadlineWriter) WriteHeader(code int) { + if !s.cleared { + if ct := s.ResponseWriter.Header().Get("Content-Type"); ct == "text/event-stream" { + rc := http.NewResponseController(s.ResponseWriter) + _ = rc.SetWriteDeadline(time.Time{}) // clear deadline for SSE + } + s.cleared = true + } + s.ResponseWriter.WriteHeader(code) +} + +func (s *sseDeadlineWriter) Write(b []byte) (int, error) { + if !s.cleared { + if ct := s.ResponseWriter.Header().Get("Content-Type"); ct == "text/event-stream" { + rc := http.NewResponseController(s.ResponseWriter) + _ = rc.SetWriteDeadline(time.Time{}) // clear deadline for SSE + } + s.cleared = true + } + return s.ResponseWriter.Write(b) +} + +func (s *sseDeadlineWriter) Flush() { + if f, ok := s.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func (s *sseDeadlineWriter) Unwrap() http.ResponseWriter { + return s.ResponseWriter +} + +// RouteInfoMiddleware wraps an http.Handler to inject a fixed RouteInfo into the +// request context. Used for transports (REST) that don't have per-request +// project/agent routing. +func RouteInfoMiddleware(route RouteInfo, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := WithRouteInfo(r.Context(), route) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// AuthUnaryInterceptor returns a gRPC unary interceptor that validates API key +// or bearer token credentials from gRPC metadata. Pass insecure=true to skip +// auth (requires explicit opt-in via config). +func AuthUnaryInterceptor(cfg *Config) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if cfg.Auth.Scheme == "none" || cfg.Bridge.GRPCInsecure { + return handler(ctx, req) + } + if err := validateGRPCAuth(ctx, cfg); err != nil { + return nil, err + } + return handler(ctx, req) + } +} + +// AuthStreamInterceptor returns a gRPC stream interceptor that validates API key +// or bearer token credentials from gRPC metadata. +func AuthStreamInterceptor(cfg *Config) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if cfg.Auth.Scheme == "none" || cfg.Bridge.GRPCInsecure { + return handler(srv, ss) + } + if err := validateGRPCAuth(ss.Context(), cfg); err != nil { + return err + } + return handler(srv, ss) + } +} + +// validateGRPCAuth extracts credentials from gRPC metadata and validates them. +func validateGRPCAuth(ctx context.Context, cfg *Config) error { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return status.Error(codes.Unauthenticated, "missing metadata") + } + + var credential string + switch cfg.Auth.Scheme { + case "apiKey": + if vals := md.Get("x-api-key"); len(vals) > 0 { + credential = vals[0] + } + case "bearer": + if vals := md.Get("authorization"); len(vals) > 0 { + auth := vals[0] + if strings.HasPrefix(auth, "Bearer ") { + credential = strings.TrimPrefix(auth, "Bearer ") + } + } + default: + if vals := md.Get("x-api-key"); len(vals) > 0 { + credential = vals[0] + } + if credential == "" { + if vals := md.Get("authorization"); len(vals) > 0 { + auth := vals[0] + if strings.HasPrefix(auth, "Bearer ") { + credential = strings.TrimPrefix(auth, "Bearer ") + } + } + } + } + + if !verifyCredential(credential, cfg.Auth.APIKey) { + return status.Error(codes.Unauthenticated, "invalid credentials") + } + return nil +} + +// RouteInfoUnaryInterceptor returns a gRPC unary server interceptor that injects +// a fixed RouteInfo into the request context. +func RouteInfoUnaryInterceptor(route RouteInfo) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return handler(WithRouteInfo(ctx, route), req) + } +} + +// RouteInfoStreamInterceptor returns a gRPC stream server interceptor that +// injects a fixed RouteInfo into the stream context. +func RouteInfoStreamInterceptor(route RouteInfo) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + wrapped := &routeInfoServerStream{ServerStream: ss, ctx: WithRouteInfo(ss.Context(), route)} + return handler(srv, wrapped) + } +} + +// routeInfoServerStream wraps a grpc.ServerStream to override its Context. +type routeInfoServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (s *routeInfoServerStream) Context() context.Context { + return s.ctx +} diff --git a/extras/scion-a2a-bridge/internal/bridge/followup_test.go b/extras/scion-a2a-bridge/internal/bridge/followup_test.go index 2ff51e378..10eeccab2 100644 --- a/extras/scion-a2a-bridge/internal/bridge/followup_test.go +++ b/extras/scion-a2a-bridge/internal/bridge/followup_test.go @@ -28,6 +28,9 @@ import ( "testing" "time" + "github.com/a2aproject/a2a-go/v2/a2a" + "github.com/a2aproject/a2a-go/v2/a2asrv" + "github.com/GoogleCloudPlatform/scion/extras/scion-a2a-bridge/internal/state" "github.com/GoogleCloudPlatform/scion/pkg/hubclient" "github.com/GoogleCloudPlatform/scion/pkg/messages" @@ -793,14 +796,44 @@ func TestSendFollowUp_ResolvesAgentIDViaLookup(t *testing.T) { // --- Server-layer tests for handleSendMessage with TaskID --- -func TestHandleSendMessage_PassesTaskIDToSendMessage(t *testing.T) { +// newTestServerWithHub creates a test server wired to a mock hub client and SDK handler. +func newTestServerWithHub(t *testing.T, hub hubclient.Client, cfg *Config) (*Bridge, *httptest.Server, *state.Store) { + t.Helper() + dir := t.TempDir() store, err := state.New(filepath.Join(dir, "test.db")) if err != nil { t.Fatalf("state.New: %v", err) } - defer store.Close() + t.Cleanup(func() { store.Close() }) + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + b := New(store, hub, nil, cfg, nil, log) + t.Cleanup(func() { b.Shutdown() }) + + executor := NewScionExecutor(b, log) + sdkRequestHandler := a2asrv.NewHandler( + executor, + a2asrv.WithLogger(log), + a2asrv.WithCapabilityChecks(&a2a.AgentCapabilities{ + Streaming: true, + PushNotifications: false, + }), + ) + b.SetSDKRequestHandler(sdkRequestHandler) + sdkJSONRPCHandler := a2asrv.NewJSONRPCHandler(sdkRequestHandler) + + srv := NewServer(b, cfg, nil, log, sdkJSONRPCHandler) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + return b, ts, store +} + +func TestHandleSendMessage_PassesTaskIDToSendMessage(t *testing.T) { + // This test verifies that the bridge's SendMessage correctly propagates the + // a2aTaskId metadata to the hub when following up on an existing task. + // Uses the bridge API directly since the SDK manages its own task lifecycle. var mu sync.Mutex var capturedMeta map[string]string agents := &mockAgentService{ @@ -811,40 +844,13 @@ func TestHandleSendMessage_PassesTaskIDToSendMessage(t *testing.T) { return nil, nil }, } - - cfg := &Config{ - Bridge: BridgeConfig{ExternalURL: "https://test.example.com"}, - Hub: HubConfig{User: "test-user"}, - Auth: AuthConfig{Scheme: "apiKey", APIKey: "test-key"}, - Projects: []ProjectConfig{{Slug: "proj-1", ExposedAgents: []string{"agent-a"}}}, - Timeouts: TimeoutConfig{SendMessage: 2 * time.Second}, - } - log := slog.New(slog.NewTextHandler(io.Discard, nil)) - hub := &mockHubClient{agents: agents} - bridge := New(store, hub, nil, cfg, nil, log) - defer bridge.Shutdown() - srv := NewServer(bridge, cfg, nil, log) - ts := httptest.NewServer(srv.Handler()) - defer ts.Close() - + b, store := newFollowUpTestBridge(t, agents) seedTask(t, store, "existing-task", "ctx-1", "proj-1", "agent-a", "aid", TaskStateWorking) - params := SendMessageParams{ - TaskID: "existing-task", - Message: Message{ - Role: RoleUser, - Parts: []Part{{Text: "follow up"}}, - }, - Configuration: &SendMessageConfig{ - Blocking: boolPtr(false), - }, - } - - rpcResp := doRPC(t, ts, "/projects/proj-1/agents/agent-a/jsonrpc", - "message/send", params, "test-key") - - if rpcResp.Error != nil { - t.Fatalf("unexpected error: code=%d msg=%s", rpcResp.Error.Code, rpcResp.Error.Message) + _, err := b.SendMessage(context.Background(), "proj-1", "agent-a", "", "existing-task", + []Part{{Text: "follow up"}}, false) + if err != nil { + t.Fatalf("SendMessage: %v", err) } // Poll until the send function captures metadata. @@ -871,109 +877,69 @@ func TestHandleSendMessage_PassesTaskIDToSendMessage(t *testing.T) { } } -func TestHandleSendMessage_ErrTaskTerminal_ReturnsCorrectError(t *testing.T) { - dir := t.TempDir() - store, err := state.New(filepath.Join(dir, "test.db")) - if err != nil { - t.Fatalf("state.New: %v", err) - } - defer store.Close() - +func TestHandleSendMessage_ErrTaskTerminal_ReturnsError(t *testing.T) { agents := &mockAgentService{} cfg := &Config{ Bridge: BridgeConfig{ExternalURL: "https://test.example.com"}, Hub: HubConfig{User: "test-user"}, Auth: AuthConfig{Scheme: "apiKey", APIKey: "test-key"}, Projects: []ProjectConfig{{Slug: "proj-1", ExposedAgents: []string{"agent-a"}}}, + Timeouts: TimeoutConfig{SendMessage: 2 * time.Second}, } - log := slog.New(slog.NewTextHandler(io.Discard, nil)) hub := &mockHubClient{agents: agents} - bridge := New(store, hub, nil, cfg, nil, log) - defer bridge.Shutdown() - srv := NewServer(bridge, cfg, nil, log) - ts := httptest.NewServer(srv.Handler()) - defer ts.Close() + _, ts, store := newTestServerWithHub(t, hub, cfg) seedTask(t, store, "done-task", "ctx-1", "proj-1", "agent-a", "aid", TaskStateCompleted) - params := SendMessageParams{ - TaskID: "done-task", - Message: Message{ - Role: RoleUser, - Parts: []Part{{Text: "try to follow up"}}, - }, + params := a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("try to follow up")), } + params.Message.TaskID = "done-task" rpcResp := doRPC(t, ts, "/projects/proj-1/agents/agent-a/jsonrpc", - "message/send", params, "test-key") + "SendMessage", params, "test-key") + // The SDK should return an error for a terminal task. if rpcResp.Error == nil { t.Fatal("expected error for terminal task") } - if rpcResp.Error.Code != ErrCodeInvalidParams { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInvalidParams) - } - if rpcResp.Error.Message != "task is in a terminal state" { - t.Errorf("error message = %q, want %q", rpcResp.Error.Message, "task is in a terminal state") + // The SDK uses negative error codes for A2A errors. + if rpcResp.Error.Code >= 0 { + t.Errorf("expected negative error code, got %d", rpcResp.Error.Code) } } -func TestHandleSendMessage_UnknownTaskID_ReturnsAgentNotFound(t *testing.T) { - dir := t.TempDir() - store, err := state.New(filepath.Join(dir, "test.db")) - if err != nil { - t.Fatalf("state.New: %v", err) - } - defer store.Close() - +func TestHandleSendMessage_UnknownTaskID_ReturnsError(t *testing.T) { agents := &mockAgentService{} cfg := &Config{ Bridge: BridgeConfig{ExternalURL: "https://test.example.com"}, Hub: HubConfig{User: "test-user"}, Auth: AuthConfig{Scheme: "apiKey", APIKey: "test-key"}, Projects: []ProjectConfig{{Slug: "proj-1", ExposedAgents: []string{"agent-a"}}}, + Timeouts: TimeoutConfig{SendMessage: 2 * time.Second}, } - log := slog.New(slog.NewTextHandler(io.Discard, nil)) hub := &mockHubClient{agents: agents} - bridge := New(store, hub, nil, cfg, nil, log) - defer bridge.Shutdown() - srv := NewServer(bridge, cfg, nil, log) - ts := httptest.NewServer(srv.Handler()) - defer ts.Close() + _, ts, _ := newTestServerWithHub(t, hub, cfg) - params := SendMessageParams{ - TaskID: "no-such-task", - Message: Message{ - Role: RoleUser, - Parts: []Part{{Text: "follow up"}}, - }, + params := a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("follow up")), } + params.Message.TaskID = "no-such-task" rpcResp := doRPC(t, ts, "/projects/proj-1/agents/agent-a/jsonrpc", - "message/send", params, "test-key") + "SendMessage", params, "test-key") + // The SDK should return an error for an unknown task. if rpcResp.Error == nil { t.Fatal("expected error for unknown task ID") } - if rpcResp.Error.Code != ErrCodeInvalidParams { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInvalidParams) - } - if rpcResp.Error.Message != "agent not found" { - t.Errorf("error message = %q, want %q", rpcResp.Error.Message, "agent not found") + if rpcResp.Error.Code >= 0 { + t.Errorf("expected negative error code, got %d", rpcResp.Error.Code) } } func TestHandleSendMessage_NoTaskID_RoutesToNewTask(t *testing.T) { - // When TaskID is empty, SendMessage should try to create a new task (and fail - // because there's no real hub client to resolve the context). This verifies - // the router correctly falls through to the new-task path. - dir := t.TempDir() - store, err := state.New(filepath.Join(dir, "test.db")) - if err != nil { - t.Fatalf("state.New: %v", err) - } - defer store.Close() - + // When TaskID is empty, SendMessage should try to create a new task. agents := &mockAgentService{ listFn: func(ctx context.Context, opts *hubclient.ListAgentsOptions) (*hubclient.ListAgentsResponse, error) { return &hubclient.ListAgentsResponse{ @@ -993,58 +959,36 @@ func TestHandleSendMessage_NoTaskID_RoutesToNewTask(t *testing.T) { Projects: []ProjectConfig{{Slug: "proj-1", ExposedAgents: []string{"agent-a"}}}, Timeouts: TimeoutConfig{SendMessage: 2 * time.Second}, } - log := slog.New(slog.NewTextHandler(io.Discard, nil)) hub := &mockHubClient{agents: agents} - bridge := New(store, hub, nil, cfg, nil, log) - defer bridge.Shutdown() - srv := NewServer(bridge, cfg, nil, log) - ts := httptest.NewServer(srv.Handler()) - defer ts.Close() + _, ts, _ := newTestServerWithHub(t, hub, cfg) - params := SendMessageParams{ - Message: Message{ - Role: RoleUser, - Parts: []Part{{Text: "new message"}}, - }, - Configuration: &SendMessageConfig{ - Blocking: boolPtr(false), - }, + params := a2a.SendMessageRequest{ + Message: a2a.NewMessage(a2a.MessageRoleUser, a2a.NewTextPart("new message")), + Config: &a2a.SendMessageConfig{ReturnImmediately: true}, } rpcResp := doRPC(t, ts, "/projects/proj-1/agents/agent-a/jsonrpc", - "message/send", params, "test-key") + "SendMessage", params, "test-key") // Should succeed — the new task path creates a context and task. if rpcResp.Error != nil { t.Fatalf("unexpected error: code=%d msg=%s", rpcResp.Error.Code, rpcResp.Error.Message) } - resultBytes, err2 := json.Marshal(rpcResp.Result) - if err2 != nil { - t.Fatalf("marshal result: %v", err2) - } - var result TaskResult - if err2 = json.Unmarshal(resultBytes, &result); err2 != nil { - t.Fatalf("unmarshal result: %v", err2) - } - - if result.ID == "" { - t.Error("expected non-empty task ID for new task") - } - if result.Status.State != TaskStateSubmitted { - t.Errorf("status.state = %q, want %q", result.Status.State, TaskStateSubmitted) + if rpcResp.Result == nil { + t.Fatal("expected non-nil result for new task") } } -func TestSendFollowUp_SendMessageParams_TaskIDField(t *testing.T) { - // Verify the TaskID field is correctly parsed from JSON. - raw := `{"taskId":"my-task-123","message":{"role":"user","parts":[{"text":"hi"}]}}` - var params SendMessageParams +func TestSendFollowUp_SDKSendMessageRequest_TaskIDField(t *testing.T) { + // Verify the TaskID field is correctly parsed from JSON in the SDK type. + raw := `{"message":{"taskId":"my-task-123","role":"user","messageId":"msg-1","parts":[{"text":"hi"}]}}` + var params a2a.SendMessageRequest if err := json.Unmarshal([]byte(raw), ¶ms); err != nil { t.Fatalf("unmarshal: %v", err) } - if params.TaskID != "my-task-123" { - t.Errorf("TaskID = %q, want %q", params.TaskID, "my-task-123") + if params.Message.TaskID != "my-task-123" { + t.Errorf("TaskID = %q, want %q", params.Message.TaskID, "my-task-123") } } diff --git a/extras/scion-a2a-bridge/internal/bridge/server.go b/extras/scion-a2a-bridge/internal/bridge/server.go index 54643bc81..f1f26da30 100644 --- a/extras/scion-a2a-bridge/internal/bridge/server.go +++ b/extras/scion-a2a-bridge/internal/bridge/server.go @@ -18,88 +18,35 @@ import ( "crypto/sha256" "crypto/subtle" "encoding/json" - "errors" "fmt" "log/slog" "net/http" "net/url" "regexp" "strings" - "time" -) - -var slugRE = regexp.MustCompile(`^[a-z0-9][a-z0-9-]{0,62}$`) -// A2A JSON-RPC error codes. -const ( - ErrCodeParseError = -32700 - ErrCodeInvalidRequest = -32600 - ErrCodeMethodNotFound = -32601 - ErrCodeInvalidParams = -32602 - ErrCodeInternalError = -32603 - ErrCodeTaskNotFound = -32001 - ErrCodeTaskNotCancelable = -32002 - ErrCodeUnsupportedOp = -32004 + "github.com/a2aproject/a2a-go/v2/a2asrv" ) -// JSONRPCRequest represents an incoming JSON-RPC 2.0 request. -type JSONRPCRequest struct { - JSONRPC string `json:"jsonrpc"` - ID interface{} `json:"id"` - Method string `json:"method"` - Params json.RawMessage `json:"params"` -} - -// JSONRPCResponse represents an outgoing JSON-RPC 2.0 response. -type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID interface{} `json:"id"` - Result interface{} `json:"result,omitempty"` - Error *JSONRPCError `json:"error,omitempty"` -} - -// JSONRPCError represents a JSON-RPC 2.0 error. -type JSONRPCError struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data,omitempty"` -} - -// SendMessageParams holds parameters for the SendMessage RPC method. -type SendMessageParams struct { - Message Message `json:"message"` - Configuration *SendMessageConfig `json:"configuration,omitempty"` - ContextID string `json:"contextId,omitempty"` - TaskID string `json:"taskId,omitempty"` -} - -// SendMessageConfig holds SendMessage configuration options. -type SendMessageConfig struct { - AcceptedOutputModes []string `json:"acceptedOutputModes,omitempty"` - Blocking *bool `json:"blocking,omitempty"` -} - -// TaskQueryParams holds parameters for GetTask/ListTasks. -type TaskQueryParams struct { - ID string `json:"id,omitempty"` - ContextID string `json:"contextId,omitempty"` -} +var slugRE = regexp.MustCompile(`^[a-z0-9][a-z0-9-]{0,62}$`) -// Server is the A2A HTTP server that handles JSON-RPC requests. +// Server is the A2A HTTP server that routes requests to the SDK handler. type Server struct { - bridge *Bridge - config *Config - metrics *Metrics - log *slog.Logger + bridge *Bridge + config *Config + metrics *Metrics + log *slog.Logger + sdkHandler http.Handler // SDK JSON-RPC handler } -// NewServer creates a new A2A protocol server. -func NewServer(bridge *Bridge, cfg *Config, metrics *Metrics, log *slog.Logger) *Server { +// NewServer creates a new A2A protocol server backed by the SDK. +func NewServer(bridge *Bridge, cfg *Config, metrics *Metrics, log *slog.Logger, sdkHandler http.Handler) *Server { return &Server{ - bridge: bridge, - config: cfg, - metrics: metrics, - log: log, + bridge: bridge, + config: cfg, + metrics: metrics, + log: log, + sdkHandler: sdkHandler, } } @@ -138,6 +85,13 @@ func ValidateConfig(cfg *Config) error { return fmt.Errorf("bridge.provider.url is invalid: %w", err) } } + // Require explicit opt-in for unauthenticated gRPC/REST transports. + if cfg.Bridge.GRPCListenAddress != "" && !cfg.Bridge.GRPCInsecure && cfg.Auth.Scheme == "none" { + return fmt.Errorf("gRPC transport is configured but auth.scheme is \"none\"; set bridge.grpc_insecure: true to acknowledge unauthenticated gRPC access, or configure auth") + } + if cfg.Bridge.RESTListenAddress != "" && !cfg.Bridge.RESTInsecure && cfg.Auth.Scheme == "none" { + return fmt.Errorf("REST transport is configured but auth.scheme is \"none\"; set bridge.rest_insecure: true to acknowledge unauthenticated REST access, or configure auth") + } return nil } @@ -160,7 +114,7 @@ func (s *Server) Handler() http.Handler { // Top-level well-known agent card (registry). mux.HandleFunc("GET /.well-known/agent-card.json", s.handleWellKnownAgentCard) - // Per-agent routes. + // Per-agent routes — the SDK handler handles JSON-RPC protocol. mux.HandleFunc("GET /projects/{projectSlug}/agents/{agentSlug}/.well-known/agent-card.json", s.handleAgentCard) mux.HandleFunc("POST /projects/{projectSlug}/agents/{agentSlug}/jsonrpc", s.handleJSONRPC) @@ -180,6 +134,14 @@ func (s *Server) Handler() http.Handler { return handler } +// SDKRequestHandler returns the a2asrv.RequestHandler for use with other transports (gRPC, REST). +// Returns nil if the server was created without an SDK handler. +func (s *Server) SDKRequestHandler() a2asrv.RequestHandler { + // The SDK handler is stored as http.Handler but we also need the RequestHandler + // for gRPC/REST transports. This is set via SetSDKRequestHandler. + return s.bridge.sdkRequestHandler +} + func (s *Server) handleHealthz(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(map[string]string{"status": "ok"}); err != nil { @@ -225,7 +187,7 @@ func (s *Server) handleWellKnownAgentCard(w http.ResponseWriter, r *http.Request "version": "1.0.0", "capabilities": map[string]bool{ "streaming": true, - "pushNotifications": true, + "pushNotifications": false, }, } @@ -281,486 +243,104 @@ func (s *Server) handleAgentCard(w http.ResponseWriter, r *http.Request) { } } +// handleJSONRPC validates the project/agent routing and delegates to the SDK handler. func (s *Server) handleJSONRPC(w http.ResponseWriter, r *http.Request) { projectSlug := r.PathValue("projectSlug") agentSlug := r.PathValue("agentSlug") if !slugRE.MatchString(projectSlug) || !slugRE.MatchString(agentSlug) { - s.writeRPCError(w, nil, ErrCodeInvalidParams, "invalid slug format") + writeJSONRPCError(w, nil, -32602, "invalid slug format") return } if err := s.bridge.AuthorizeExposed(projectSlug, agentSlug); err != nil { - s.writeRPCError(w, nil, ErrCodeInvalidParams, "agent not found") + writeJSONRPCError(w, nil, -32602, "agent not found") return } + // Limit request body to 1 MB to prevent memory exhaustion from oversized payloads. r.Body = http.MaxBytesReader(w, r.Body, 1<<20) - var req JSONRPCRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - s.writeRPCError(w, nil, ErrCodeParseError, "parse error") - return - } - - if req.JSONRPC != "2.0" { - s.writeRPCError(w, req.ID, ErrCodeInvalidRequest, "invalid JSON-RPC version") - return - } - - // JSON-RPC 2.0 §4.1: notifications (id absent/null) must not receive responses. - if req.ID == nil { - s.log.Debug("ignoring JSON-RPC notification", "method", req.Method) - return - } - - s.log.Debug("JSON-RPC request", - "method", req.Method, - "project", projectSlug, - "agent", agentSlug, - ) - - switch req.Method { - case "message/send": - s.handleSendMessage(w, r, req, projectSlug, agentSlug) - case "message/stream": - s.handleStreamMessage(w, r, req, projectSlug, agentSlug) - case "tasks/get": - s.handleGetTask(w, r, req, projectSlug, agentSlug) - case "tasks/list": - s.handleListTasks(w, r, req, projectSlug, agentSlug) - case "tasks/cancel": - s.handleCancelTask(w, r, req, projectSlug, agentSlug) - case "tasks/pushNotification/set": - s.handleSetPushNotification(w, r, req, projectSlug, agentSlug) - case "tasks/pushNotification/get": - s.handleGetPushNotification(w, r, req, projectSlug, agentSlug) - case "tasks/pushNotification/delete": - s.handleDeletePushNotification(w, r, req, projectSlug, agentSlug) - case "tasks/resubscribe": - s.handleResubscribe(w, r, req, projectSlug, agentSlug) - default: - s.writeRPCError(w, req.ID, ErrCodeMethodNotFound, fmt.Sprintf("method %q not found", req.Method)) - } -} - -func (s *Server) handleSendMessage(w http.ResponseWriter, r *http.Request, req JSONRPCRequest, projectSlug, agentSlug string) { - var params SendMessageParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.log.Warn("invalid SendMessage params", "error", err) - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "invalid parameters") - return - } - - if len(params.Message.Parts) == 0 { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "message.parts must be non-empty") - return - } - if params.Message.Role != "" && params.Message.Role != RoleUser { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "message.role must be 'user'") - return - } - - blocking := true - if params.Configuration != nil && params.Configuration.Blocking != nil { - blocking = *params.Configuration.Blocking - } - - result, err := s.bridge.SendMessage(r.Context(), projectSlug, agentSlug, params.ContextID, params.TaskID, params.Message.Parts, blocking) - if err != nil { - s.log.Error("SendMessage failed", "error", err, "project", projectSlug, "agent", agentSlug) - switch { - case errors.Is(err, ErrAgentNotFound): - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "agent not found") - case errors.Is(err, ErrContextUnknown): - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "unknown context ID") - case errors.Is(err, ErrTaskTerminal): - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "task is in a terminal state") - default: - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - } - return - } - - s.writeRPCResult(w, req.ID, result) -} - -func (s *Server) handleGetTask(w http.ResponseWriter, r *http.Request, req JSONRPCRequest, projectSlug, agentSlug string) { - var params TaskQueryParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.log.Warn("invalid GetTask params", "error", err) - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "invalid parameters") - return - } - - if params.ID == "" { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "id is required") - return - } - - task, err := s.bridge.AuthorizeTask(params.ID, projectSlug, agentSlug) - if err != nil { - s.log.Error("GetTask failed", "error", err, "taskID", params.ID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - if task == nil { - s.writeRPCError(w, req.ID, ErrCodeTaskNotFound, "task not found") - return - } - - s.writeRPCResult(w, req.ID, &TaskResult{ - ID: task.ID, - ContextID: task.ContextID, - Status: TaskStatus{State: task.State}, + // Inject routing info into context for the executor. + ctx := WithRouteInfo(r.Context(), RouteInfo{ + ProjectSlug: projectSlug, + AgentSlug: agentSlug, }) -} - -func (s *Server) handleListTasks(w http.ResponseWriter, r *http.Request, req JSONRPCRequest, projectSlug, agentSlug string) { - var params TaskQueryParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.log.Warn("invalid ListTasks params", "error", err) - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "invalid parameters") - return - } - - if params.ContextID == "" { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "contextId is required") - return - } - - authorized, authErr := s.bridge.AuthorizeContext(params.ContextID, projectSlug, agentSlug) - if authErr != nil { - s.log.Error("AuthorizeContext failed", "error", authErr, "contextID", params.ContextID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - if !authorized { - s.writeRPCError(w, req.ID, ErrCodeTaskNotFound, "context not found") - return - } - - tasks, err := s.bridge.ListTasks(r.Context(), params.ContextID) - if err != nil { - s.log.Error("ListTasks failed", "error", err, "contextID", params.ContextID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - - s.writeRPCResult(w, req.ID, tasks) -} - -func (s *Server) handleCancelTask(w http.ResponseWriter, r *http.Request, req JSONRPCRequest, projectSlug, agentSlug string) { - var params TaskQueryParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.log.Warn("invalid CancelTask params", "error", err) - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "invalid parameters") - return - } - - if params.ID == "" { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "id is required") - return - } - - task, err := s.bridge.AuthorizeTask(params.ID, projectSlug, agentSlug) - if err != nil { - s.log.Error("CancelTask auth failed", "error", err, "taskID", params.ID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - if task == nil { - s.writeRPCError(w, req.ID, ErrCodeTaskNotFound, "task not found") - return - } - - result, err := s.bridge.CancelTask(r.Context(), params.ID) - if err != nil { - s.log.Error("CancelTask failed", "error", err, "taskID", params.ID) - s.writeRPCError(w, req.ID, ErrCodeTaskNotCancelable, "task cannot be canceled") - return - } - if result == nil { - s.writeRPCError(w, req.ID, ErrCodeTaskNotFound, "task not found") - return - } - - s.writeRPCResult(w, req.ID, result) -} - -func (s *Server) handleStreamMessage(w http.ResponseWriter, r *http.Request, req JSONRPCRequest, projectSlug, agentSlug string) { - var params SendMessageParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.log.Warn("invalid StreamMessage params", "error", err) - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "invalid parameters") - return - } - - if len(params.Message.Parts) == 0 { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "message.parts must be non-empty") - return - } - if params.Message.Role != "" && params.Message.Role != RoleUser { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "message.role must be 'user'") - return - } - - taskID, events, cleanup, err := s.bridge.SendStreamingMessage(r.Context(), projectSlug, agentSlug, params.ContextID, params.Message.Parts) - if err != nil { - s.log.Error("SendStreamingMessage failed", "error", err, "project", projectSlug, "agent", agentSlug) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - defer cleanup() - - s.writeSSEStream(w, r, taskID, events) -} - -func (s *Server) handleResubscribe(w http.ResponseWriter, r *http.Request, req JSONRPCRequest, projectSlug, agentSlug string) { - var params TaskQueryParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.log.Warn("invalid Resubscribe params", "error", err) - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "invalid parameters") - return - } - - if params.ID == "" { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "id is required") - return - } - - task, err := s.bridge.AuthorizeTask(params.ID, projectSlug, agentSlug) - if err != nil { - s.log.Error("Resubscribe auth failed", "error", err, "taskID", params.ID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - if task == nil { - s.writeRPCError(w, req.ID, ErrCodeTaskNotFound, "task not found") - return - } + r = r.WithContext(ctx) - events, cleanup, err := s.bridge.SubscribeToTask(r.Context(), params.ID) - if err != nil { - s.log.Error("SubscribeToTask failed", "error", err, "taskID", params.ID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - defer cleanup() - - s.writeSSEStream(w, r, params.ID, events) + // Delegate to SDK JSON-RPC handler. + s.sdkHandler.ServeHTTP(w, r) } -func (s *Server) writeSSEStream(w http.ResponseWriter, r *http.Request, taskID string, events <-chan StreamEvent) { - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "streaming not supported", http.StatusInternalServerError) - return - } - - // Disable the global WriteTimeout for this long-lived SSE connection. - rc := http.NewResponseController(w) - if err := rc.SetWriteDeadline(time.Time{}); err != nil { - s.log.Warn("failed to disable write deadline for SSE", "error", err) - } - - if s.metrics != nil { - s.metrics.ActiveSSE.Inc() - defer s.metrics.ActiveSSE.Dec() - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusOK) - flusher.Flush() - - keepalive := s.config.Timeouts.SSEKeepalive - if keepalive == 0 { - keepalive = 30 * time.Second - } - ticker := time.NewTicker(keepalive) - defer ticker.Stop() - - for { - select { - case event, ok := <-events: - if !ok { - return - } - data, err := json.Marshal(event) - if err != nil { - s.log.Error("marshal SSE event", "error", err) - continue - } - // SSE spec: each line of a multi-line payload must be prefixed with "data: ". - dataStr := string(data) - lines := strings.Split(dataStr, "\n") - for _, line := range lines { - fmt.Fprintf(w, "data: %s\n", line) - } - fmt.Fprintf(w, "\n") - flusher.Flush() - - if event.StatusUpdate != nil && event.StatusUpdate.Final { - return - } - case <-ticker.C: - fmt.Fprintf(w, ": keepalive\n\n") - flusher.Flush() - case <-r.Context().Done(): - return - } - } -} - -// PushNotificationParams holds parameters for push notification operations. -type PushNotificationParams struct { - TaskID string `json:"taskId"` - ID string `json:"id,omitempty"` - URL string `json:"url,omitempty"` - Token string `json:"token,omitempty"` - AuthScheme string `json:"authScheme,omitempty"` - AuthCredentials string `json:"authCredentials,omitempty"` -} - -func (s *Server) handleSetPushNotification(w http.ResponseWriter, r *http.Request, req JSONRPCRequest, projectSlug, agentSlug string) { - var params PushNotificationParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.log.Warn("invalid SetPushNotification params", "error", err) - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "invalid parameters") - return - } - - task, err := s.bridge.AuthorizeTask(params.TaskID, projectSlug, agentSlug) - if err != nil { - s.log.Error("SetPushNotification auth failed", "error", err, "taskID", params.TaskID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - if task == nil { - s.writeRPCError(w, req.ID, ErrCodeTaskNotFound, "task not found") - return - } - - parsed, err := url.Parse(params.URL) - if err != nil || parsed.Host == "" || (parsed.Scheme != "http" && parsed.Scheme != "https") { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "url must be an absolute http or https URL") - return - } - - // SSRF validation is also enforced inside SetPushNotificationConfig (defense-in-depth). - cfg, err := s.bridge.SetPushNotificationConfig(r.Context(), params.TaskID, params.URL, params.Token, params.AuthScheme, params.AuthCredentials) - if err != nil { - s.log.Error("SetPushNotificationConfig failed", "error", err, "taskID", params.TaskID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - - s.writeRPCResult(w, req.ID, cfg) -} - -func (s *Server) handleGetPushNotification(w http.ResponseWriter, r *http.Request, req JSONRPCRequest, projectSlug, agentSlug string) { - var params PushNotificationParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.log.Warn("invalid GetPushNotification params", "error", err) - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "invalid parameters") - return - } - - task, err := s.bridge.AuthorizeTask(params.TaskID, projectSlug, agentSlug) - if err != nil { - s.log.Error("GetPushNotification auth failed", "error", err, "taskID", params.TaskID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - if task == nil { - s.writeRPCError(w, req.ID, ErrCodeTaskNotFound, "task not found") - return - } - - configs, err := s.bridge.GetPushNotificationConfig(r.Context(), params.TaskID) - if err != nil { - s.log.Error("GetPushNotificationConfig failed", "error", err, "taskID", params.TaskID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - - s.writeRPCResult(w, req.ID, configs) -} - -func (s *Server) handleDeletePushNotification(w http.ResponseWriter, r *http.Request, req JSONRPCRequest, projectSlug, agentSlug string) { - var params PushNotificationParams - if err := json.Unmarshal(req.Params, ¶ms); err != nil { - s.log.Warn("invalid DeletePushNotification params", "error", err) - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "invalid parameters") - return - } - - if params.TaskID == "" { - s.writeRPCError(w, req.ID, ErrCodeInvalidParams, "taskId is required") - return - } - - task, err := s.bridge.AuthorizeTask(params.TaskID, projectSlug, agentSlug) - if err != nil { - s.log.Error("DeletePushNotification auth failed", "error", err, "taskID", params.TaskID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - if task == nil { - s.writeRPCError(w, req.ID, ErrCodeTaskNotFound, "task not found") - return - } - - if err := s.bridge.DeletePushNotificationConfig(r.Context(), params.TaskID, params.ID); err != nil { - s.log.Error("DeletePushNotificationConfig failed", "error", err, "pushID", params.ID) - s.writeRPCError(w, req.ID, ErrCodeInternalError, "internal error") - return - } - - s.writeRPCResult(w, req.ID, map[string]bool{"ok": true}) -} - -// normalizeJSONRPCID ensures the id conforms to JSON-RPC 2.0 (string, number, or null). -// Per §4, fractional numbers and structured values (object/array) are forbidden as IDs. -// We coerce invalid types to null rather than echoing them, accepting that this makes -// client-side correlation impossible for malformed requests. +// normalizeJSONRPCID ensures only valid JSON-RPC 2.0 ID types (string, number, +// null) are echoed back. Arrays, objects, and booleans are replaced with null +// per JSON-RPC 2.0 §4.1. func normalizeJSONRPCID(id interface{}) interface{} { switch id.(type) { - case float64, string: + case nil, string, float64, int, int64: + return id + case json.Number: return id - case nil: - return nil default: return nil } } -func (s *Server) writeRPCResult(w http.ResponseWriter, id interface{}, result interface{}) { - resp := JSONRPCResponse{ +// writeJSONRPCError writes a minimal JSON-RPC error response. +func writeJSONRPCError(w http.ResponseWriter, id interface{}, code int, message string) { + type jsonrpcError struct { + Code int `json:"code"` + Message string `json:"message"` + } + type jsonrpcResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Error *jsonrpcError `json:"error,omitempty"` + } + resp := jsonrpcResponse{ JSONRPC: "2.0", ID: normalizeJSONRPCID(id), - Result: result, + Error: &jsonrpcError{Code: code, Message: message}, } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { - s.log.Error("failed to encode RPC result", "error", err) + slog.Default().Error("failed to encode JSON-RPC error response", "error", err) } } -func (s *Server) writeRPCError(w http.ResponseWriter, id interface{}, code int, message string) { - resp := JSONRPCResponse{ - JSONRPC: "2.0", - ID: normalizeJSONRPCID(id), - Error: &JSONRPCError{Code: code, Message: message}, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(resp); err != nil { - s.log.Error("failed to encode RPC error", "error", err) +// extractCredential extracts the API key or bearer token from an HTTP request +// based on the configured auth scheme. +func extractCredential(scheme string, headers http.Header) string { + switch scheme { + case "apiKey": + return headers.Get("X-API-Key") + case "bearer": + auth := headers.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + return "" + default: + // When auth.scheme is unset (empty), accept credentials from either + // X-API-Key or Authorization: Bearer headers for convenience. + apiKey := headers.Get("X-API-Key") + if apiKey == "" { + auth := headers.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + apiKey = strings.TrimPrefix(auth, "Bearer ") + } + } + return apiKey } } +// verifyCredential checks that the provided credential matches the configured API key. +func verifyCredential(provided, expected string) bool { + expectedHash := sha256.Sum256([]byte(expected)) + providedHash := sha256.Sum256([]byte(provided)) + return subtle.ConstantTimeCompare(expectedHash[:], providedHash[:]) == 1 +} + // authMiddleware validates API key authentication on non-public endpoints. func (s *Server) authMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -782,35 +362,29 @@ func (s *Server) authMiddleware(next http.Handler) http.Handler { return } - var apiKey string - switch s.config.Auth.Scheme { - case "apiKey": - apiKey = r.Header.Get("X-API-Key") - case "bearer": - auth := r.Header.Get("Authorization") - if strings.HasPrefix(auth, "Bearer ") { - apiKey = strings.TrimPrefix(auth, "Bearer ") - } - default: - // When auth.scheme is unset (empty), accept credentials from either - // X-API-Key or Authorization: Bearer headers for convenience. - apiKey = r.Header.Get("X-API-Key") - if apiKey == "" { - auth := r.Header.Get("Authorization") - if strings.HasPrefix(auth, "Bearer ") { - apiKey = strings.TrimPrefix(auth, "Bearer ") - } - } + apiKey := extractCredential(s.config.Auth.Scheme, r.Header) + if !verifyCredential(apiKey, s.config.Auth.APIKey) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return } - // Compare SHA-256 hashes to avoid leaking key length via timing. - expectedHash := sha256.Sum256([]byte(s.config.Auth.APIKey)) - providedHash := sha256.Sum256([]byte(apiKey)) - if subtle.ConstantTimeCompare(expectedHash[:], providedHash[:]) != 1 { + next.ServeHTTP(w, r) + }) +} + +// AuthHTTPMiddleware wraps an http.Handler with API key / bearer token validation. +// Used for REST transport auth. Requests without valid credentials are rejected +// with 401. Pass insecure=true to skip auth (requires explicit opt-in via config). +func AuthHTTPMiddleware(cfg *Config, next http.Handler) http.Handler { + if cfg.Auth.Scheme == "none" || cfg.Bridge.RESTInsecure { + return next + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiKey := extractCredential(cfg.Auth.Scheme, r.Header) + if !verifyCredential(apiKey, cfg.Auth.APIKey) { http.Error(w, "unauthorized", http.StatusUnauthorized) return } - next.ServeHTTP(w, r) }) } diff --git a/extras/scion-a2a-bridge/internal/bridge/server_test.go b/extras/scion-a2a-bridge/internal/bridge/server_test.go index facfccbe8..345546183 100644 --- a/extras/scion-a2a-bridge/internal/bridge/server_test.go +++ b/extras/scion-a2a-bridge/internal/bridge/server_test.go @@ -26,9 +26,35 @@ import ( "testing" "time" + "github.com/a2aproject/a2a-go/v2/a2a" + "github.com/a2aproject/a2a-go/v2/a2asrv" + "github.com/GoogleCloudPlatform/scion/extras/scion-a2a-bridge/internal/state" + "github.com/GoogleCloudPlatform/scion/pkg/messages" ) +// jsonRPCRequest is a test helper for constructing JSON-RPC requests. +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` +} + +// jsonRPCResponse is a test helper for parsing JSON-RPC responses. +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID interface{} `json:"id"` + Result interface{} `json:"result,omitempty"` + Error *jsonRPCErr `json:"error,omitempty"` +} + +type jsonRPCErr struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + func newTestServer(t *testing.T) (*Server, *httptest.Server, *state.Store) { t.Helper() @@ -60,8 +86,22 @@ func newTestServer(t *testing.T) (*Server, *httptest.Server, *state.Store) { } log := slog.New(slog.NewTextHandler(io.Discard, nil)) - bridge := New(store, nil, nil, cfg, nil, log) - srv := NewServer(bridge, cfg, nil, log) + b := New(store, nil, nil, cfg, nil, log) + + // Create a minimal SDK executor and handler for testing. + executor := NewScionExecutor(b, log) + sdkRequestHandler := a2asrv.NewHandler( + executor, + a2asrv.WithLogger(log), + a2asrv.WithCapabilityChecks(&a2a.AgentCapabilities{ + Streaming: true, + PushNotifications: false, + }), + ) + b.SetSDKRequestHandler(sdkRequestHandler) + sdkJSONRPCHandler := a2asrv.NewJSONRPCHandler(sdkRequestHandler) + + srv := NewServer(b, cfg, nil, log, sdkJSONRPCHandler) ts := httptest.NewServer(srv.Handler()) t.Cleanup(ts.Close) @@ -69,7 +109,7 @@ func newTestServer(t *testing.T) (*Server, *httptest.Server, *state.Store) { return srv, ts, store } -func doRPC(t *testing.T, ts *httptest.Server, path string, method string, params interface{}, apiKey string) *JSONRPCResponse { +func doRPC(t *testing.T, ts *httptest.Server, path string, method string, params interface{}, apiKey string) *jsonRPCResponse { t.Helper() paramsJSON, err := json.Marshal(params) @@ -77,7 +117,7 @@ func doRPC(t *testing.T, ts *httptest.Server, path string, method string, params t.Fatalf("marshal params: %v", err) } - req := JSONRPCRequest{ + req := jsonRPCRequest{ JSONRPC: "2.0", ID: 1, Method: method, @@ -100,7 +140,7 @@ func doRPC(t *testing.T, ts *httptest.Server, path string, method string, params } defer resp.Body.Close() - var rpcResp JSONRPCResponse + var rpcResp jsonRPCResponse if err := json.NewDecoder(resp.Body).Decode(&rpcResp); err != nil { t.Fatalf("decode response: %v", err) } @@ -163,17 +203,6 @@ func TestWellKnownAgentCard(t *testing.T) { if provider["organization"] != "Test Org" { t.Errorf("provider.organization = %q, want %q", provider["organization"], "Test Org") } - - caps, ok := card["capabilities"].(map[string]interface{}) - if !ok { - t.Fatal("expected capabilities object in card") - } - if caps["streaming"] != true { - t.Errorf("capabilities.streaming = %v, want true", caps["streaming"]) - } - if caps["pushNotifications"] != true { - t.Errorf("capabilities.pushNotifications = %v, want true", caps["pushNotifications"]) - } } func TestPerAgentCard(t *testing.T) { @@ -200,17 +229,6 @@ func TestPerAgentCard(t *testing.T) { if card["url"] != expectedURL { t.Errorf("url = %q, want %q", card["url"], expectedURL) } - - caps, ok := card["capabilities"].(map[string]interface{}) - if !ok { - t.Fatal("expected capabilities object in per-agent card") - } - if caps["streaming"] != true { - t.Errorf("capabilities.streaming = %v, want true", caps["streaming"]) - } - if caps["pushNotifications"] != true { - t.Errorf("capabilities.pushNotifications = %v, want true", caps["pushNotifications"]) - } } func TestPerAgentCardNotExposed(t *testing.T) { @@ -255,7 +273,7 @@ func TestAuthMiddleware(t *testing.T) { } // JSON-RPC without auth should be rejected. - rpcReq, _ := json.Marshal(JSONRPCRequest{JSONRPC: "2.0", ID: 1, Method: "tasks/get", Params: json.RawMessage(`{"id":"x"}`)}) + rpcReq, _ := json.Marshal(jsonRPCRequest{JSONRPC: "2.0", ID: 1, Method: "tasks/get", Params: json.RawMessage(`{"id":"x"}`)}) httpReq, _ := http.NewRequest(http.MethodPost, ts.URL+"/projects/test-grove/agents/test-agent/jsonrpc", bytes.NewReader(rpcReq)) httpReq.Header.Set("Content-Type", "application/json") @@ -286,28 +304,16 @@ func TestAuthMiddleware(t *testing.T) { func TestGetTaskNotFound(t *testing.T) { _, ts, _ := newTestServer(t) + // The SDK handler will return TaskNotFound via its own error handling. rpcResp := doRPC(t, ts, "/projects/test-grove/agents/test-agent/jsonrpc", - "tasks/get", TaskQueryParams{ID: "nonexistent-task"}, "test-api-key") + "tasks/get", map[string]interface{}{"id": "nonexistent-task"}, "test-api-key") if rpcResp.Error == nil { t.Fatal("expected error for nonexistent task") } - if rpcResp.Error.Code != ErrCodeTaskNotFound { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeTaskNotFound) - } -} - -func TestListTasksRequiresContextID(t *testing.T) { - _, ts, _ := newTestServer(t) - - rpcResp := doRPC(t, ts, "/projects/test-grove/agents/test-agent/jsonrpc", - "tasks/list", TaskQueryParams{}, "test-api-key") - - if rpcResp.Error == nil { - t.Fatal("expected error when contextId is missing") - } - if rpcResp.Error.Code != ErrCodeInvalidParams { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInvalidParams) + // SDK uses standard A2A error codes. + if rpcResp.Error.Code >= 0 { + t.Errorf("expected negative error code, got %d", rpcResp.Error.Code) } } @@ -320,99 +326,9 @@ func TestUnknownMethod(t *testing.T) { if rpcResp.Error == nil { t.Fatal("expected error for unknown method") } - if rpcResp.Error.Code != ErrCodeMethodNotFound { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeMethodNotFound) - } -} - -func TestCancelTaskSuccess(t *testing.T) { - dir := t.TempDir() - store, err := state.New(filepath.Join(dir, "cancel-test.db")) - if err != nil { - t.Fatalf("state.New: %v", err) - } - defer store.Close() - - cfg := &Config{ - Bridge: BridgeConfig{ - ExternalURL: "https://a2a.test.example.com", - Provider: ProviderConfig{Organization: "Test Org", URL: "https://test.example.com"}, - }, - Auth: AuthConfig{Scheme: "apiKey", APIKey: "test-api-key"}, - Projects: []ProjectConfig{ - {Slug: "test-grove", ExposedAgents: []string{"test-agent"}}, - }, - } - - log := slog.New(slog.NewTextHandler(io.Discard, nil)) - bridge := New(store, nil, nil, cfg, nil, log) - srv := NewServer(bridge, cfg, nil, log) - ts2 := httptest.NewServer(srv.Handler()) - defer ts2.Close() - - now := time.Now() - store.CreateTask(&state.Task{ - ID: "cancel-me", ContextID: "ctx-1", ProjectID: "test-grove", AgentSlug: "test-agent", - State: "working", CreatedAt: now, UpdatedAt: now, Metadata: "{}", - }) - - rpcResp := doRPC(t, ts2, "/projects/test-grove/agents/test-agent/jsonrpc", - "tasks/cancel", map[string]string{"id": "cancel-me"}, "test-api-key") - - if rpcResp.Error != nil { - t.Fatalf("unexpected error: code=%d msg=%s", rpcResp.Error.Code, rpcResp.Error.Message) - } - - resultBytes, _ := json.Marshal(rpcResp.Result) - var result TaskResult - if err := json.Unmarshal(resultBytes, &result); err != nil { - t.Fatalf("unmarshal result: %v", err) - } - if result.Status.State != TaskStateCanceled { - t.Errorf("status.state = %q, want %q", result.Status.State, TaskStateCanceled) - } - - // Verify the store was updated. - task, _ := store.GetTask("cancel-me") - if task.State != TaskStateCanceled { - t.Errorf("store state = %q, want %q", task.State, TaskStateCanceled) - } -} - -func TestCancelTaskAlreadyTerminal(t *testing.T) { - dir := t.TempDir() - store, err := state.New(filepath.Join(dir, "cancel-terminal.db")) - if err != nil { - t.Fatalf("state.New: %v", err) - } - defer store.Close() - - cfg := &Config{ - Bridge: BridgeConfig{ExternalURL: "https://a2a.test.example.com"}, - Auth: AuthConfig{Scheme: "apiKey", APIKey: "test-api-key"}, - Projects: []ProjectConfig{{Slug: "test-grove", ExposedAgents: []string{"test-agent"}}}, - } - - log := slog.New(slog.NewTextHandler(io.Discard, nil)) - bridge := New(store, nil, nil, cfg, nil, log) - srv := NewServer(bridge, cfg, nil, log) - ts := httptest.NewServer(srv.Handler()) - defer ts.Close() - - now := time.Now() - store.CreateTask(&state.Task{ - ID: "done-task", ContextID: "ctx-1", ProjectID: "test-grove", AgentSlug: "test-agent", - State: TaskStateCompleted, CreatedAt: now, UpdatedAt: now, Metadata: "{}", - }) - - rpcResp := doRPC(t, ts, "/projects/test-grove/agents/test-agent/jsonrpc", - "tasks/cancel", map[string]string{"id": "done-task"}, "test-api-key") - - if rpcResp.Error == nil { - t.Fatal("expected error when canceling a completed task") - } - if rpcResp.Error.Code != ErrCodeTaskNotCancelable { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeTaskNotCancelable) + // -32601 is method not found in JSON-RPC spec. + if rpcResp.Error.Code != -32601 { + t.Errorf("error code = %d, want -32601", rpcResp.Error.Code) } } @@ -425,9 +341,6 @@ func TestCancelTaskNotFound(t *testing.T) { if rpcResp.Error == nil { t.Fatal("expected error for cancel of nonexistent task") } - if rpcResp.Error.Code != ErrCodeTaskNotFound { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeTaskNotFound) - } } func TestInvalidJSONRPC(t *testing.T) { @@ -450,15 +363,12 @@ func TestInvalidJSONRPC(t *testing.T) { } defer resp.Body.Close() - var rpcResp JSONRPCResponse + var rpcResp jsonRPCResponse json.NewDecoder(resp.Body).Decode(&rpcResp) if rpcResp.Error == nil { t.Fatal("expected error for invalid JSON-RPC version") } - if rpcResp.Error.Code != ErrCodeInvalidRequest { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInvalidRequest) - } } func TestMalformedJSON(t *testing.T) { @@ -475,214 +385,84 @@ func TestMalformedJSON(t *testing.T) { } defer resp.Body.Close() - var rpcResp JSONRPCResponse + var rpcResp jsonRPCResponse json.NewDecoder(resp.Body).Decode(&rpcResp) if rpcResp.Error == nil { t.Fatal("expected parse error") } - if rpcResp.Error.Code != ErrCodeParseError { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeParseError) + // -32700 is parse error in JSON-RPC spec. + if rpcResp.Error.Code != -32700 { + t.Errorf("error code = %d, want -32700", rpcResp.Error.Code) } } -// --- Phase 2 server tests --- - -func TestPushNotificationSetGetDelete(t *testing.T) { +func TestJSONRPCDeniesNonExposedAgent(t *testing.T) { _, ts, _ := newTestServer(t) - // Create a task first (needed for push config FK). - rpcPath := "/projects/test-grove/agents/test-agent/jsonrpc" - - // Create a task directly in the store via the test bridge. - // We access it indirectly by creating it in the store. - dir := t.TempDir() - store, err := state.New(filepath.Join(dir, "push-test.db")) - if err != nil { - t.Fatal(err) - } - defer store.Close() - - now := time.Now() - store.CreateTask(&state.Task{ - ID: "push-task-1", ContextID: "ctx-1", ProjectID: "test-grove", AgentSlug: "test-agent", - State: "working", CreatedAt: now, UpdatedAt: now, Metadata: "{}", - }) - - // Set push config — this test verifies the JSON-RPC dispatch works even though - // the task is in a different store. The server handler delegates to bridge which - // uses its own store, so we test the handler's param parsing and error paths. - rpcResp := doRPC(t, ts, rpcPath, - "tasks/pushNotification/set", - PushNotificationParams{ - TaskID: "nonexistent-task", - URL: "https://example.com/webhook", - Token: "tok", - }, - "test-api-key", - ) - - // Should fail because task doesn't exist in the server's store. - if rpcResp.Error == nil { - t.Fatal("expected error for nonexistent task") + methods := []string{ + "message/send", + "tasks/get", + "tasks/cancel", } -} - -func TestPushNotificationSetRejectsPrivateIP(t *testing.T) { - _, ts, store := newTestServer(t) - rpcPath := "/projects/test-grove/agents/test-agent/jsonrpc" - - now := time.Now() - store.CreateTask(&state.Task{ - ID: "push-priv-task", ContextID: "ctx-1", ProjectID: "test-grove", AgentSlug: "test-agent", - State: "working", CreatedAt: now, UpdatedAt: now, Metadata: "{}", - }) - cases := []struct { - name string - url string - }{ - {"loopback", "https://127.0.0.1/webhook"}, - {"metadata", "https://169.254.169.254/latest/meta-data/"}, - {"rfc1918-10", "https://10.0.0.1/hook"}, - {"rfc1918-172", "https://172.16.0.1/hook"}, - {"rfc1918-192", "https://192.168.1.1/hook"}, - {"unspecified", "https://0.0.0.0/hook"}, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - rpcResp := doRPC(t, ts, rpcPath, - "tasks/pushNotification/set", - PushNotificationParams{ - TaskID: "push-priv-task", - URL: tc.url, - Token: "tok", - }, - "test-api-key", - ) + for _, method := range methods { + t.Run("hidden-agent/"+method, func(t *testing.T) { + rpcResp := doRPC(t, ts, "/projects/test-grove/agents/hidden-agent/jsonrpc", + method, map[string]string{"id": "x"}, "test-api-key") if rpcResp.Error == nil { - t.Fatal("expected error for private IP URL") + t.Fatalf("expected error for non-exposed agent on %s", method) } - if rpcResp.Error.Code != ErrCodeInternalError { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInternalError) + if rpcResp.Error.Message != "agent not found" { + t.Errorf("error message = %q, want %q", rpcResp.Error.Message, "agent not found") } }) - } -} - -func TestPushNotificationGetReturnsEmpty(t *testing.T) { - _, ts, store := newTestServer(t) - - now := time.Now() - store.CreateTask(&state.Task{ - ID: "push-get-task", ContextID: "ctx-1", ProjectID: "test-grove", AgentSlug: "test-agent", - State: "working", CreatedAt: now, UpdatedAt: now, Metadata: "{}", - }) - - rpcResp := doRPC(t, ts, "/projects/test-grove/agents/test-agent/jsonrpc", - "tasks/pushNotification/get", - PushNotificationParams{TaskID: "push-get-task"}, - "test-api-key", - ) - - // Should succeed with empty result (no configs). - if rpcResp.Error != nil { - t.Fatalf("unexpected error: %s", rpcResp.Error.Message) - } -} - -func TestPushNotificationDeleteNonexistent(t *testing.T) { - _, ts, store := newTestServer(t) - now := time.Now() - store.CreateTask(&state.Task{ - ID: "push-del-task", ContextID: "ctx-1", ProjectID: "test-grove", AgentSlug: "test-agent", - State: "working", CreatedAt: now, UpdatedAt: now, Metadata: "{}", - }) - - rpcResp := doRPC(t, ts, "/projects/test-grove/agents/test-agent/jsonrpc", - "tasks/pushNotification/delete", - PushNotificationParams{TaskID: "push-del-task", ID: "nonexistent-push-id"}, - "test-api-key", - ) + t.Run("unknown-project/"+method, func(t *testing.T) { + rpcResp := doRPC(t, ts, "/projects/unknown-grove/agents/test-agent/jsonrpc", + method, map[string]string{"id": "x"}, "test-api-key") - if rpcResp.Error == nil { - t.Fatal("expected error when deleting nonexistent push config") - } - if rpcResp.Error.Code != ErrCodeInternalError { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInternalError) + if rpcResp.Error == nil { + t.Fatalf("expected error for unknown project on %s", method) + } + }) } } -func TestStreamMethodInvalidParams(t *testing.T) { +func TestLegacyGrovePath(t *testing.T) { _, ts, _ := newTestServer(t) - // Send a raw JSON string that can't be unmarshaled to SendMessageParams. - rpcReq := JSONRPCRequest{ - JSONRPC: "2.0", - ID: 1, - Method: "message/stream", - Params: json.RawMessage(`"not an object"`), - } - body, _ := json.Marshal(rpcReq) - httpReq, _ := http.NewRequest(http.MethodPost, - ts.URL+"/projects/test-grove/agents/test-agent/jsonrpc", bytes.NewReader(body)) - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("X-API-Key", "test-api-key") - - resp, err := http.DefaultClient.Do(httpReq) + // Test legacy .well-known path (public access) + resp, err := http.Get(ts.URL + "/groves/test-grove/agents/test-agent/.well-known/agent-card.json") if err != nil { - t.Fatalf("do request: %v", err) + t.Fatalf("GET legacy agent card: %v", err) } defer resp.Body.Close() - var rpcResp JSONRPCResponse - json.NewDecoder(resp.Body).Decode(&rpcResp) - - if rpcResp.Error == nil { - t.Fatal("expected error for invalid params") - } - if rpcResp.Error.Code != ErrCodeInvalidParams { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInvalidParams) + if resp.StatusCode != http.StatusOK { + t.Errorf("status = %d, want 200", resp.StatusCode) } -} -func TestResubscribeTaskNotFound(t *testing.T) { - _, ts, _ := newTestServer(t) - - rpcResp := doRPC(t, ts, "/projects/test-grove/agents/test-agent/jsonrpc", - "tasks/resubscribe", - TaskQueryParams{ID: "nonexistent-task"}, - "test-api-key", - ) + // Test legacy JSON-RPC path (requires auth) + rpcReq, _ := json.Marshal(jsonRPCRequest{JSONRPC: "2.0", ID: 1, Method: "tasks/get", Params: json.RawMessage(`{"id":"x"}`)}) + httpReq, _ := http.NewRequest(http.MethodPost, ts.URL+"/groves/test-grove/agents/test-agent/jsonrpc", bytes.NewReader(rpcReq)) + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("X-API-Key", "test-api-key") - if rpcResp.Error == nil { - t.Fatal("expected error for nonexistent task") + resp, err = http.DefaultClient.Do(httpReq) + if err != nil { + t.Fatal(err) } -} - -func TestResubscribeRequiresID(t *testing.T) { - _, ts, _ := newTestServer(t) - - rpcResp := doRPC(t, ts, "/projects/test-grove/agents/test-agent/jsonrpc", - "tasks/resubscribe", - TaskQueryParams{}, - "test-api-key", - ) + defer resp.Body.Close() - if rpcResp.Error == nil { - t.Fatal("expected error for empty task ID") - } - if rpcResp.Error.Code != ErrCodeInvalidParams { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInvalidParams) + // Should be 200 OK (the actual RPC might fail with "task not found" but the route should be authorized) + if resp.StatusCode != http.StatusOK { + t.Errorf("legacy RPC: status = %d, want 200", resp.StatusCode) } } func TestAuthorizeTaskReturnsNilNil(t *testing.T) { - _, _, store := newTestServer(t) - dir := t.TempDir() s, err := state.New(filepath.Join(dir, "auth-test.db")) if err != nil { @@ -697,8 +477,6 @@ func TestAuthorizeTaskReturnsNilNil(t *testing.T) { b := New(s, nil, nil, cfg, nil, log) now := time.Now() - _ = store // use the outer store for unrelated setup - s.CreateTask(&state.Task{ ID: "owned-task", ContextID: "ctx-1", ProjectID: "grove-a", AgentSlug: "agent-x", State: "working", CreatedAt: now, UpdatedAt: now, Metadata: "{}", @@ -732,192 +510,216 @@ func TestAuthorizeTaskReturnsNilNil(t *testing.T) { } } -func TestJSONRPCDeniesNonExposedAgent(t *testing.T) { - _, ts, _ := newTestServer(t) +func TestRouteInfoContext(t *testing.T) { + ctx := WithRouteInfo(context.Background(), RouteInfo{ProjectSlug: "proj", AgentSlug: "agt"}) + info, ok := RouteInfoFrom(ctx) + if !ok { + t.Fatal("expected route info in context") + } + if info.ProjectSlug != "proj" || info.AgentSlug != "agt" { + t.Errorf("RouteInfo = %+v, want {proj, agt}", info) + } +} - methods := []string{ - "message/send", - "tasks/get", - "tasks/list", - "tasks/cancel", - "tasks/resubscribe", - "tasks/pushNotification/set", - "tasks/pushNotification/get", - "tasks/pushNotification/delete", +func TestRouteInfoContextMissing(t *testing.T) { + _, ok := RouteInfoFrom(context.Background()) + if ok { + t.Fatal("expected no route info in empty context") } +} - for _, method := range methods { - t.Run("hidden-agent/"+method, func(t *testing.T) { - rpcResp := doRPC(t, ts, "/projects/test-grove/agents/hidden-agent/jsonrpc", - method, map[string]string{"id": "x"}, "test-api-key") +func TestRouteInfoMiddleware(t *testing.T) { + route := RouteInfo{ProjectSlug: "test-proj", AgentSlug: "test-agent"} - if rpcResp.Error == nil { - t.Fatalf("expected error for non-exposed agent on %s", method) - } - if rpcResp.Error.Code != ErrCodeInvalidParams { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInvalidParams) - } - if rpcResp.Error.Message != "agent not found" { - t.Errorf("error message = %q, want %q", rpcResp.Error.Message, "agent not found") - } - }) + var capturedRoute RouteInfo + var capturedOK bool - t.Run("unknown-project/"+method, func(t *testing.T) { - rpcResp := doRPC(t, ts, "/projects/unknown-grove/agents/test-agent/jsonrpc", - method, map[string]string{"id": "x"}, "test-api-key") + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedRoute, capturedOK = RouteInfoFrom(r.Context()) + w.WriteHeader(http.StatusOK) + }) - if rpcResp.Error == nil { - t.Fatalf("expected error for unknown project on %s", method) - } - if rpcResp.Error.Code != ErrCodeInvalidParams { - t.Errorf("error code = %d, want %d", rpcResp.Error.Code, ErrCodeInvalidParams) + handler := RouteInfoMiddleware(route, inner) + ts := httptest.NewServer(handler) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/test") + if err != nil { + t.Fatalf("GET: %v", err) + } + resp.Body.Close() + + if !capturedOK { + t.Fatal("expected route info in context") + } + if capturedRoute.ProjectSlug != "test-proj" || capturedRoute.AgentSlug != "test-agent" { + t.Errorf("RouteInfo = %+v, want {test-proj, test-agent}", capturedRoute) + } +} + +func TestNormalizeJSONRPCID(t *testing.T) { + tests := []struct { + name string + id interface{} + want interface{} + }{ + {"nil", nil, nil}, + {"string", "abc", "abc"}, + {"float64", float64(42), float64(42)}, + {"int", int(7), int(7)}, + {"json.Number", json.Number("99"), json.Number("99")}, + {"array rejected", []int{1, 2}, nil}, + {"object rejected", map[string]int{"a": 1}, nil}, + {"bool rejected", true, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeJSONRPCID(tt.id) + if got != tt.want { + t.Errorf("normalizeJSONRPCID(%v) = %v (%T), want %v (%T)", tt.id, got, got, tt.want, tt.want) } }) } } -func TestNewRPCMethods(t *testing.T) { +func TestMaxBytesReaderOnJSONRPC(t *testing.T) { _, ts, _ := newTestServer(t) - // Verify these methods are recognized (not "method not found"). - // message/stream and tasks/resubscribe are excluded because they trigger - // resolveContext which requires a hub client (nil in test fixture). - methods := []string{ - "tasks/pushNotification/set", - "tasks/pushNotification/get", - "tasks/pushNotification/delete", - "tasks/resubscribe", + // Send a body larger than 1MB. + bigBody := make([]byte, 2<<20) // 2MB + for i := range bigBody { + bigBody[i] = 'a' } - for _, method := range methods { - t.Run(method, func(t *testing.T) { - rpcResp := doRPC(t, ts, "/projects/test-grove/agents/test-agent/jsonrpc", - method, - map[string]string{}, - "test-api-key", - ) - - if rpcResp.Error != nil && rpcResp.Error.Code == ErrCodeMethodNotFound { - t.Errorf("method %q should be registered but got method not found", method) - } - }) - } -} + httpReq, _ := http.NewRequest(http.MethodPost, ts.URL+"/projects/test-grove/agents/test-agent/jsonrpc", bytes.NewReader(bigBody)) + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("X-API-Key", "test-api-key") -func TestGenerateAgentCardCapabilities(t *testing.T) { - dir := t.TempDir() - store, err := state.New(filepath.Join(dir, "caps-test.db")) + resp, err := http.DefaultClient.Do(httpReq) if err != nil { - t.Fatalf("state.New: %v", err) + t.Fatal(err) } - defer store.Close() + defer resp.Body.Close() + + // MaxBytesReader causes the SDK handler's body read to fail. + // The response should either be an error status (413) or contain + // a JSON-RPC error (parse error). We accept any non-success outcome. + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode == http.StatusOK { + // If 200, verify the JSON-RPC response contains an error. + var rpcResp jsonRPCResponse + if json.Unmarshal(body, &rpcResp) == nil && rpcResp.Error == nil { + t.Error("expected error for oversized request body") + } + } +} +func TestValidateConfigGRPCInsecureRequired(t *testing.T) { cfg := &Config{ Bridge: BridgeConfig{ - ExternalURL: "https://a2a.test.example.com", + ExternalURL: "https://test.example.com", + GRPCListenAddress: ":50051", + // GRPCInsecure not set }, + Hub: HubConfig{Endpoint: "https://hub.example.com", User: "test"}, + Auth: AuthConfig{Scheme: "none"}, } - log := slog.New(slog.NewTextHandler(io.Discard, nil)) - bridge := New(store, nil, nil, cfg, nil, log) - - card := bridge.GenerateAgentCard(context.Background(), "test-project", "test-agent") - - caps, ok := card["capabilities"].(map[string]bool) - if !ok { - t.Fatal("expected capabilities to be map[string]bool") + err := ValidateConfig(cfg) + if err == nil { + t.Fatal("expected error for gRPC without grpc_insecure when auth is none") } - if !caps["streaming"] { - t.Error("capabilities.streaming should be true") - } - if !caps["pushNotifications"] { - t.Error("capabilities.pushNotifications should be true") + if !bytes.Contains([]byte(err.Error()), []byte("grpc_insecure")) { + t.Errorf("error should mention grpc_insecure: %v", err) } - // Verify other required fields are present. - if card["name"] != "test-agent" { - t.Errorf("name = %q, want %q", card["name"], "test-agent") - } - expectedURL := "https://a2a.test.example.com/projects/test-project/agents/test-agent" - if card["url"] != expectedURL { - t.Errorf("url = %q, want %q", card["url"], expectedURL) - } - if card["version"] != "1.0.0" { - t.Errorf("version = %q, want %q", card["version"], "1.0.0") + // With GRPCInsecure set, validation should pass (for this check). + cfg.Bridge.GRPCInsecure = true + err = ValidateConfig(cfg) + if err != nil && bytes.Contains([]byte(err.Error()), []byte("grpc_insecure")) { + t.Errorf("should not error with grpc_insecure set: %v", err) } } -func TestRegistryAndPerAgentCardCapabilitiesMatch(t *testing.T) { - _, ts, _ := newTestServer(t) - - // Fetch registry card. - resp, err := http.Get(ts.URL + "/.well-known/agent-card.json") +func TestDispatchToWaiterAgentSlugVerification(t *testing.T) { + dir := t.TempDir() + store, err := state.New(filepath.Join(dir, "waiter-test.db")) if err != nil { - t.Fatalf("GET registry card: %v", err) - } - defer resp.Body.Close() - var registryCard map[string]interface{} - json.NewDecoder(resp.Body).Decode(®istryCard) - - registryCaps, ok := registryCard["capabilities"].(map[string]interface{}) - if !ok { - t.Fatal("expected capabilities in registry card") + t.Fatal(err) } + defer store.Close() - // Fetch per-agent card. - resp2, err := http.Get(ts.URL + "/projects/test-grove/agents/test-agent/.well-known/agent-card.json") - if err != nil { - t.Fatalf("GET per-agent card: %v", err) + cfg := &Config{ + Bridge: BridgeConfig{ExternalURL: "https://a2a.test.example.com"}, + Timeouts: TimeoutConfig{SendMessage: 10 * time.Second}, } - defer resp2.Body.Close() - var agentCard map[string]interface{} - json.NewDecoder(resp2.Body).Decode(&agentCard) + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + b := New(store, nil, nil, cfg, nil, log) + defer b.Shutdown() + + // Register a waiter for agent-a. + ch := make(chan *messages.StructuredMessage, 1) + b.addWaiter("task-1", &waiter{ + ch: ch, + agentSlug: "agent-a", + projectID: "proj-1", + }) - agentCaps, ok := agentCard["capabilities"].(map[string]interface{}) - if !ok { - t.Fatal("expected capabilities in per-agent card") + // Message from the correct agent should be dispatched. + correctMsg := &messages.StructuredMessage{ + Sender: "agent:agent-a", + Msg: "hello from agent-a", } - - // Capabilities should be identical. - for key, regVal := range registryCaps { - if agentCaps[key] != regVal { - t.Errorf("capability %q: registry=%v, agent=%v", key, regVal, agentCaps[key]) - } + if !b.dispatchToWaiter("task-1", correctMsg) { + t.Error("dispatchToWaiter should return true for matching agent") } - for key, agentVal := range agentCaps { - if registryCaps[key] != agentVal { - t.Errorf("capability %q: agent=%v, registry=%v", key, agentVal, registryCaps[key]) + select { + case got := <-ch: + if got.Msg != "hello from agent-a" { + t.Errorf("expected message from agent-a, got %q", got.Msg) } + default: + t.Error("expected message in channel for matching agent") } -} -func TestLegacyGrovePath(t *testing.T) { - _, ts, _ := newTestServer(t) - - // Test legacy .well-known path (public access) - resp, err := http.Get(ts.URL + "/groves/test-grove/agents/test-agent/.well-known/agent-card.json") - if err != nil { - t.Fatalf("GET legacy agent card: %v", err) + // Message from a different agent should be rejected. + wrongMsg := &messages.StructuredMessage{ + Sender: "agent:agent-b", + Msg: "hello from agent-b", } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Errorf("status = %d, want 200", resp.StatusCode) + if !b.dispatchToWaiter("task-1", wrongMsg) { + t.Error("dispatchToWaiter should return true (consumed) even for wrong agent") + } + select { + case got := <-ch: + t.Errorf("should not receive message from wrong agent, got %q", got.Msg) + default: + // correct — message was rejected } - // Test legacy JSON-RPC path (requires auth) - rpcReq, _ := json.Marshal(JSONRPCRequest{JSONRPC: "2.0", ID: 1, Method: "tasks/get", Params: json.RawMessage(`{"id":"x"}`)}) - httpReq, _ := http.NewRequest(http.MethodPost, ts.URL+"/groves/test-grove/agents/test-agent/jsonrpc", bytes.NewReader(rpcReq)) - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("X-API-Key", "test-api-key") + b.removeWaiter("task-1") +} - resp, err = http.DefaultClient.Do(httpReq) - if err != nil { - t.Fatal(err) +func TestValidateConfigRESTInsecureRequired(t *testing.T) { + cfg := &Config{ + Bridge: BridgeConfig{ + ExternalURL: "https://test.example.com", + RESTListenAddress: ":8080", + // RESTInsecure not set + }, + Hub: HubConfig{Endpoint: "https://hub.example.com", User: "test"}, + Auth: AuthConfig{Scheme: "none"}, + } + err := ValidateConfig(cfg) + if err == nil { + t.Fatal("expected error for REST without rest_insecure when auth is none") + } + if !bytes.Contains([]byte(err.Error()), []byte("rest_insecure")) { + t.Errorf("error should mention rest_insecure: %v", err) } - defer resp.Body.Close() - // Should be 200 OK (the actual RPC might fail with "task not found" but the route should be authorized) - if resp.StatusCode != http.StatusOK { - t.Errorf("legacy RPC: status = %d, want 200", resp.StatusCode) + cfg.Bridge.RESTInsecure = true + err = ValidateConfig(cfg) + if err != nil && bytes.Contains([]byte(err.Error()), []byte("rest_insecure")) { + t.Errorf("should not error with rest_insecure set: %v", err) } } diff --git a/extras/scion-a2a-bridge/internal/bridge/translate.go b/extras/scion-a2a-bridge/internal/bridge/translate.go index da739181d..fbba4c86e 100644 --- a/extras/scion-a2a-bridge/internal/bridge/translate.go +++ b/extras/scion-a2a-bridge/internal/bridge/translate.go @@ -19,6 +19,8 @@ import ( "strings" "time" + "github.com/a2aproject/a2a-go/v2/a2a" + "github.com/GoogleCloudPlatform/scion/pkg/messages" "github.com/google/uuid" ) @@ -153,6 +155,9 @@ func TranslateA2AToScion(parts []Part) *messages.StructuredMessage { // TranslateScionToA2A converts a Scion StructuredMessage into an A2A Message and optional Artifacts. func TranslateScionToA2A(msg *messages.StructuredMessage) (Message, []Artifact) { + if msg == nil { + return Message{MessageID: uuid.New().String(), Role: RoleAgent}, nil + } parts := []Part{{Text: msg.Msg, MediaType: "text/plain"}} for _, att := range msg.Attachments { @@ -177,3 +182,95 @@ func TranslateScionToA2A(msg *messages.StructuredMessage) (Message, []Artifact) return message, artifacts } + +// --- SDK-compatible translation functions --- + +// TranslateA2APartsToScion converts SDK a2a.ContentParts into a Scion StructuredMessage. +func TranslateA2APartsToScion(parts a2a.ContentParts) *messages.StructuredMessage { + var textContent strings.Builder + var attachments []string + + for _, part := range parts { + switch v := part.Content.(type) { + case a2a.Text: + if textContent.Len() > 0 { + textContent.WriteString("\n") + } + textContent.WriteString(string(v)) + case a2a.URL: + attachments = append(attachments, string(v)) + case a2a.Data: + jsonBytes, err := json.Marshal(v.Value) + if err == nil { + if textContent.Len() > 0 { + textContent.WriteString("\n") + } + textContent.WriteString(string(jsonBytes)) + } + } + } + + msg := textContent.String() + if msg == "" { + if len(attachments) > 0 { + msg = "[A2A request with attachments only]" + } else { + msg = "[empty A2A request]" + } + } + + return &messages.StructuredMessage{ + Version: 1, + Timestamp: time.Now().UTC().Format(time.RFC3339), + Msg: msg, + Type: messages.TypeInstruction, + Attachments: attachments, + } +} + +// TranslateScionToA2AParts converts a Scion StructuredMessage into SDK a2a types. +// Returns parts for the agent message and artifacts for content delivery. +func TranslateScionToA2AParts(msg *messages.StructuredMessage) (*a2a.Message, []*a2a.Artifact) { + if msg == nil { + return nil, nil + } + var sdkParts []*a2a.Part + sdkParts = append(sdkParts, &a2a.Part{Content: a2a.Text(msg.Msg), MediaType: "text/plain"}) + + for _, att := range msg.Attachments { + sdkParts = append(sdkParts, &a2a.Part{Content: a2a.URL(att)}) + } + + message := a2a.NewMessage(a2a.MessageRoleAgent, sdkParts...) + + var artifacts []*a2a.Artifact + switch msg.Type { + case "", messages.TypeInstruction, messages.TypeAssistantReply: + artifacts = append(artifacts, &a2a.Artifact{ + ID: a2a.NewArtifactID(), + Parts: sdkParts, + }) + } + + return message, artifacts +} + +// MapActivityToSDKTaskState maps a Scion agent activity string to an SDK a2a.TaskState. +func MapActivityToSDKTaskState(activity string) a2a.TaskState { + switch strings.ToUpper(activity) { + case "WORKING": + return a2a.TaskStateWorking + case "THINKING", "EXECUTING": + return a2a.TaskStateWorking + case "WAITING_FOR_INPUT": + return a2a.TaskStateInputRequired + case "COMPLETED": + return a2a.TaskStateCompleted + case "ERROR": + return a2a.TaskStateFailed + case "STALLED", "LIMITS_EXCEEDED", "OFFLINE": + return a2a.TaskStateFailed + default: + return a2a.TaskStateWorking + } +} diff --git a/extras/scion-a2a-bridge/internal/bridge/translate_test.go b/extras/scion-a2a-bridge/internal/bridge/translate_test.go index e834d6d97..b9896fb6d 100644 --- a/extras/scion-a2a-bridge/internal/bridge/translate_test.go +++ b/extras/scion-a2a-bridge/internal/bridge/translate_test.go @@ -17,6 +17,8 @@ package bridge import ( "testing" + "github.com/a2aproject/a2a-go/v2/a2a" + "github.com/GoogleCloudPlatform/scion/pkg/messages" ) @@ -153,3 +155,154 @@ func TestTranslateScionToA2AStateChange(t *testing.T) { t.Errorf("Artifacts = %d, want 0 for state-change messages", len(artifacts)) } } + +// --- SDK translation function tests --- + +func TestTranslateA2APartsToScionText(t *testing.T) { + parts := a2a.ContentParts{ + {Content: a2a.Text("Hello"), MediaType: "text/plain"}, + {Content: a2a.Text("World"), MediaType: "text/plain"}, + } + + msg := TranslateA2APartsToScion(parts) + + if msg.Msg != "Hello\nWorld" { + t.Errorf("Msg = %q, want %q", msg.Msg, "Hello\nWorld") + } + if msg.Type != messages.TypeInstruction { + t.Errorf("Type = %q, want %q", msg.Type, messages.TypeInstruction) + } + if msg.Version != 1 { + t.Errorf("Version = %d, want 1", msg.Version) + } + if msg.Timestamp == "" { + t.Error("expected non-empty Timestamp") + } +} + +func TestTranslateA2APartsToScionURL(t *testing.T) { + parts := a2a.ContentParts{ + {Content: a2a.Text("See this file:"), MediaType: "text/plain"}, + {Content: a2a.URL("https://example.com/file.pdf")}, + } + + msg := TranslateA2APartsToScion(parts) + + if msg.Msg != "See this file:" { + t.Errorf("Msg = %q, want %q", msg.Msg, "See this file:") + } + if len(msg.Attachments) != 1 || msg.Attachments[0] != "https://example.com/file.pdf" { + t.Errorf("Attachments = %v, want [https://example.com/file.pdf]", msg.Attachments) + } +} + +func TestTranslateA2APartsToScionData(t *testing.T) { + parts := a2a.ContentParts{ + {Content: a2a.Data{Value: map[string]interface{}{"key": "value"}}}, + } + + msg := TranslateA2APartsToScion(parts) + + if msg.Msg != `{"key":"value"}` { + t.Errorf("Msg = %q, want JSON data", msg.Msg) + } +} + +func TestTranslateA2APartsToScionEmpty(t *testing.T) { + msg := TranslateA2APartsToScion(nil) + + if msg.Msg != "[empty A2A request]" { + t.Errorf("Msg = %q, want %q", msg.Msg, "[empty A2A request]") + } +} + +func TestTranslateA2APartsToScionAttachmentOnly(t *testing.T) { + parts := a2a.ContentParts{ + {Content: a2a.URL("https://example.com/data.csv")}, + } + + msg := TranslateA2APartsToScion(parts) + + if msg.Msg != "[A2A request with attachments only]" { + t.Errorf("Msg = %q, want attachment-only placeholder", msg.Msg) + } + if len(msg.Attachments) != 1 { + t.Errorf("Attachments = %d, want 1", len(msg.Attachments)) + } +} + +func TestTranslateScionToA2AParts(t *testing.T) { + scionMsg := &messages.StructuredMessage{ + Version: 1, + Msg: "Agent response text", + Type: messages.TypeAssistantReply, + Attachments: []string{"https://example.com/output.pdf"}, + } + + message, artifacts := TranslateScionToA2AParts(scionMsg) + + if message == nil { + t.Fatal("expected non-nil message") + } + if message.Role != a2a.MessageRoleAgent { + t.Errorf("Role = %v, want %v", message.Role, a2a.MessageRoleAgent) + } + if len(message.Parts) != 2 { + t.Fatalf("Parts = %d, want 2", len(message.Parts)) + } + if text, ok := message.Parts[0].Content.(a2a.Text); !ok || string(text) != "Agent response text" { + t.Errorf("Parts[0] = %v, want Text('Agent response text')", message.Parts[0].Content) + } + if url, ok := message.Parts[1].Content.(a2a.URL); !ok || string(url) != "https://example.com/output.pdf" { + t.Errorf("Parts[1] = %v, want URL attachment", message.Parts[1].Content) + } + + if len(artifacts) != 1 { + t.Fatalf("Artifacts = %d, want 1 for assistant reply", len(artifacts)) + } + if artifacts[0].ID == "" { + t.Error("expected non-empty artifact ID") + } +} + +func TestTranslateScionToA2APartsStateChange(t *testing.T) { + scionMsg := &messages.StructuredMessage{ + Version: 1, + Msg: "State changed", + Type: messages.TypeStateChange, + } + + _, artifacts := TranslateScionToA2AParts(scionMsg) + + if len(artifacts) != 0 { + t.Errorf("Artifacts = %d, want 0 for state-change", len(artifacts)) + } +} + +func TestMapActivityToSDKTaskState(t *testing.T) { + tests := []struct { + activity string + want a2a.TaskState + }{ + {"WORKING", a2a.TaskStateWorking}, + {"THINKING", a2a.TaskStateWorking}, + {"EXECUTING", a2a.TaskStateWorking}, + {"WAITING_FOR_INPUT", a2a.TaskStateInputRequired}, + {"COMPLETED", a2a.TaskStateCompleted}, + {"ERROR", a2a.TaskStateFailed}, + {"STALLED", a2a.TaskStateFailed}, + {"LIMITS_EXCEEDED", a2a.TaskStateFailed}, + {"OFFLINE", a2a.TaskStateFailed}, + {"UNKNOWN_ACTIVITY", a2a.TaskStateWorking}, + {"working", a2a.TaskStateWorking}, + } + + for _, tt := range tests { + t.Run(tt.activity, func(t *testing.T) { + got := MapActivityToSDKTaskState(tt.activity) + if got != tt.want { + t.Errorf("MapActivityToSDKTaskState(%q) = %q, want %q", tt.activity, got, tt.want) + } + }) + } +}