From 5b90ba978eae063187e1c20b09eb4f3833f52786 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Thu, 14 May 2026 10:13:45 +0530 Subject: [PATCH 1/2] vMCP rate-limit middleware wiring Signed-off-by: Sanskarzz --- pkg/ratelimit/middleware.go | 27 +- pkg/vmcp/cli/serve.go | 10 + pkg/vmcp/server/ratelimit.go | 81 ++++++ pkg/vmcp/server/ratelimit_test.go | 237 +++++++++++++++ pkg/vmcp/server/server.go | 181 +++++++----- .../virtualmcp_rate_limiting_test.go | 271 ++++++++++++++++++ 6 files changed, 728 insertions(+), 79 deletions(-) create mode 100644 pkg/vmcp/server/ratelimit.go create mode 100644 pkg/vmcp/server/ratelimit_test.go create mode 100644 test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go diff --git a/pkg/ratelimit/middleware.go b/pkg/ratelimit/middleware.go index ee108fdf87..ce437db9e6 100644 --- a/pkg/ratelimit/middleware.go +++ b/pkg/ratelimit/middleware.go @@ -52,6 +52,17 @@ type rateLimitMiddleware struct { client redis.UniversalClient } +// ToolNameResolver resolves the rate-limit tool name from a parsed MCP request. +type ToolNameResolver func(*mcp.ParsedMCPRequest) string + +// DefaultToolNameResolver uses the parsed MCP resource ID as the rate-limit tool name. +func DefaultToolNameResolver(parsed *mcp.ParsedMCPRequest) string { + if parsed == nil { + return "" + } + return parsed.ResourceID +} + // Handler returns the middleware function used by the proxy. func (m *rateLimitMiddleware) Handler() types.MiddlewareFunction { return m.handler @@ -99,16 +110,19 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun } mw := &rateLimitMiddleware{ - handler: rateLimitHandler(limiter), + handler: NewMiddleware(limiter, nil), client: client, } runner.AddMiddleware(MiddlewareType, mw) return nil } -// rateLimitHandler returns a middleware function that enforces rate limits +// NewMiddleware returns a middleware function that enforces rate limits // on tools/call requests. 
-func rateLimitHandler(limiter Limiter) types.MiddlewareFunction { +func NewMiddleware(limiter Limiter, resolveToolName ToolNameResolver) types.MiddlewareFunction { + if resolveToolName == nil { + resolveToolName = DefaultToolNameResolver + } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Rate limits only apply to parsed tools/call requests. @@ -127,7 +141,7 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction { if identity, ok := auth.IdentityFromContext(r.Context()); ok { userID = identity.Subject } - decision, err := limiter.Allow(r.Context(), parsed.ResourceID, userID) + decision, err := limiter.Allow(r.Context(), resolveToolName(parsed), userID) if err != nil { slog.Warn("rate limit check failed, allowing request", "error", err) next.ServeHTTP(w, r) @@ -142,6 +156,11 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction { } } +// rateLimitHandler returns the default rate-limit middleware used by tests and legacy callers. +func rateLimitHandler(limiter Limiter) types.MiddlewareFunction { + return NewMiddleware(limiter, nil) +} + // writeRateLimited writes an HTTP 429 response with a JSON-RPC error body. 
func writeRateLimited(w http.ResponseWriter, requestID any, retryAfter time.Duration) { retrySeconds := int(math.Ceil(retryAfter.Seconds())) diff --git a/pkg/vmcp/cli/serve.go b/pkg/vmcp/cli/serve.go index a962f52f27..6bf88a34e8 100644 --- a/pkg/vmcp/cli/serve.go +++ b/pkg/vmcp/cli/serve.go @@ -376,6 +376,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error { serverCfg := &vmcpserver.Config{ Name: vmcpCfg.Name, + Namespace: vmcpNamespace(), Version: versions.Version, GroupRef: vmcpCfg.Group, Host: cfg.Host, @@ -394,6 +395,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error { OptimizerConfig: optCfg, SessionFactory: sessionFactory, SessionStorage: vmcpCfg.SessionStorage, + RateLimiting: vmcpCfg.RateLimiting, } // Assign Watcher only when backendWatcher is non-nil. A typed nil @@ -529,6 +531,14 @@ func generateQuickModeConfig(groupRef string) (*config.Config, error) { return cfg, nil } +func vmcpNamespace() string { + namespace := os.Getenv("VMCP_NAMESPACE") + if namespace == "" { + return "local" + } + return namespace +} + // loadAuthServerConfig loads the auth server RunConfig from a sibling file // alongside the main config. The operator serializes authserver.RunConfig as a // separate ConfigMap key (authserver-config.yaml). diff --git a/pkg/vmcp/server/ratelimit.go b/pkg/vmcp/server/ratelimit.go new file mode 100644 index 0000000000..8dd0996ab5 --- /dev/null +++ b/pkg/vmcp/server/ratelimit.go @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package server + +import ( + "context" + "fmt" + "net/http" + "os" + "time" + + "github.com/redis/go-redis/v9" + + mcpparser "github.com/stacklok/toolhive/pkg/mcp" + "github.com/stacklok/toolhive/pkg/ratelimit" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +const rateLimitRedisPingTimeout = 5 * time.Second + +func (s *Server) buildRateLimitMiddleware( + ctx context.Context, +) (func(http.Handler) http.Handler, func(context.Context) error, error) { + if s.config.RateLimiting == nil { + return nil, nil, nil + } + if s.config.SessionStorage == nil || s.config.SessionStorage.Provider != "redis" { + return nil, nil, fmt.Errorf("rate limiting requires Redis session storage") + } + if s.config.SessionStorage.Address == "" { + return nil, nil, fmt.Errorf("rate limiting requires Redis session storage address") + } + + client := redis.NewClient(&redis.Options{ + Addr: s.config.SessionStorage.Address, + DB: int(s.config.SessionStorage.DB), + Password: os.Getenv(vmcpconfig.RedisPasswordEnvVar), + }) + + pingCtx, cancel := context.WithTimeout(ctx, rateLimitRedisPingTimeout) + defer cancel() + if err := client.Ping(pingCtx).Err(); err != nil { + _ = client.Close() + return nil, nil, fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w", + s.config.SessionStorage.Address, err) + } + + limiter, err := ratelimit.NewLimiter(client, s.config.Namespace, s.config.Name, s.config.RateLimiting) + if err != nil { + _ = client.Close() + return nil, nil, fmt.Errorf("failed to create rate limiter: %w", err) + } + + cleanup := func(context.Context) error { + return client.Close() + } + return ratelimit.NewMiddleware(limiter, s.rateLimitToolName), cleanup, nil +} + +func (s *Server) rateLimitToolName(parsed *mcpparser.ParsedMCPRequest) string { + if parsed == nil { + return "" + } + toolName := parsed.ResourceID + if !s.optimizerEnabled() || toolName != "call_tool" { + return toolName + } + if parsed.Arguments == 
nil { + return toolName + } + innerToolName, ok := parsed.Arguments["tool_name"].(string) + if !ok || innerToolName == "" { + return toolName + } + return innerToolName +} + +func (s *Server) optimizerEnabled() bool { + return s.config.OptimizerConfig != nil || s.config.OptimizerFactory != nil +} diff --git a/pkg/vmcp/server/ratelimit_test.go b/pkg/vmcp/server/ratelimit_test.go new file mode 100644 index 0000000000..76b45cfe6d --- /dev/null +++ b/pkg/vmcp/server/ratelimit_test.go @@ -0,0 +1,237 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package server + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/stacklok/toolhive/pkg/auth" + mcpparser "github.com/stacklok/toolhive/pkg/mcp" + ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" +) + +func TestBuildRateLimitMiddlewareDisabledWithoutConfig(t *testing.T) { + t.Parallel() + + s := &Server{config: &Config{Name: "vmcp", Namespace: "default"}} + + middleware, cleanup, err := s.buildRateLimitMiddleware(t.Context()) + + require.NoError(t, err) + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestBuildRateLimitMiddlewareRequiresRedisSessionStorage(t *testing.T) { + t.Parallel() + + s := &Server{ + config: &Config{ + Name: "vmcp", + Namespace: "default", + RateLimiting: sharedRateLimitConfig(1), + }, + } + + middleware, cleanup, err := s.buildRateLimitMiddleware(t.Context()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "requires Redis session storage") + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestRateLimitMiddlewarePerUserSharedAcrossTools(t 
*testing.T) { + t.Parallel() + + handler := newTestRateLimitHandler(t, &Config{ + Name: "vmcp", + Namespace: "default", + RateLimiting: &ratelimittypes.RateLimitConfig{ + PerUser: &ratelimittypes.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }) + + first := serveToolCall(t, handler, "backend_a_echo", "alice", nil) + assert.Equal(t, http.StatusOK, first.Code) + + second := serveToolCall(t, handler, "backend_b_echo", "alice", nil) + assert.Equal(t, http.StatusTooManyRequests, second.Code) + assertRateLimitedBody(t, second) +} + +func TestRateLimitMiddlewareUsesPostAggregationToolNames(t *testing.T) { + t.Parallel() + + handler := newTestRateLimitHandler(t, &Config{ + Name: "vmcp", + Namespace: "default", + RateLimiting: &ratelimittypes.RateLimitConfig{ + Tools: []ratelimittypes.ToolRateLimitConfig{ + { + Name: "backend_a_echo", + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + }, + }) + + first := serveToolCall(t, handler, "backend_a_echo", "", nil) + assert.Equal(t, http.StatusOK, first.Code) + + otherTool := serveToolCall(t, handler, "backend_b_echo", "", nil) + assert.Equal(t, http.StatusOK, otherTool.Code) + + secondMatchingTool := serveToolCall(t, handler, "backend_a_echo", "", nil) + assert.Equal(t, http.StatusTooManyRequests, secondMatchingTool.Code) +} + +func TestRateLimitMiddlewareOptimizerExtractsInnerToolName(t *testing.T) { + t.Parallel() + + handler := newTestRateLimitHandler(t, &Config{ + Name: "vmcp", + Namespace: "default", + OptimizerConfig: &optimizer.Config{}, + RateLimiting: &ratelimittypes.RateLimitConfig{ + Tools: []ratelimittypes.ToolRateLimitConfig{ + { + Name: "backend_fetch_fetch", + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + }, + }) + + args := map[string]any{ + "tool_name": "backend_fetch_fetch", + "parameters": 
map[string]any{"url": "https://example.com"}, + } + first := serveToolCall(t, handler, "call_tool", "", args) + assert.Equal(t, http.StatusOK, first.Code) + + second := serveToolCall(t, handler, "call_tool", "", args) + assert.Equal(t, http.StatusTooManyRequests, second.Code) +} + +func TestRateLimitToolNameFallsBackToCallTool(t *testing.T) { + t.Parallel() + + s := &Server{config: &Config{OptimizerConfig: &optimizer.Config{}}} + parsed := &mcpparser.ParsedMCPRequest{ + Method: "tools/call", + ResourceID: "call_tool", + Arguments: map[string]any{}, + } + + assert.Equal(t, "call_tool", s.rateLimitToolName(parsed)) +} + +func newTestRateLimitHandler(t *testing.T, cfg *Config) http.Handler { + t.Helper() + + mr := miniredis.RunT(t) + cfg.SessionStorage = &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + Address: mr.Addr(), + } + + s := &Server{config: cfg} + middleware, cleanup, err := s.buildRateLimitMiddleware(t.Context()) + require.NoError(t, err) + require.NotNil(t, middleware) + t.Cleanup(func() { + require.NoError(t, cleanup(context.Background())) + }) + + next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + return withIdentityMiddleware(mcpparser.ParsingMiddleware(middleware(next))) +} + +func withIdentityMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := r.Header.Get("X-Test-User") + if user != "" { + identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: user}} + r = r.WithContext(auth.WithIdentity(r.Context(), identity)) + } + next.ServeHTTP(w, r) + }) +} + +func serveToolCall( + t *testing.T, + handler http.Handler, + toolName string, + user string, + arguments map[string]any, +) *httptest.ResponseRecorder { + t.Helper() + + if arguments == nil { + arguments = map[string]any{} + } + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": 
map[string]any{ + "name": toolName, + "arguments": arguments, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + if user != "" { + req.Header.Set("X-Test-User", user) + } + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) + return recorder +} + +func assertRateLimitedBody(t *testing.T, recorder *httptest.ResponseRecorder) { + t.Helper() + + var resp map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + errObj := resp["error"].(map[string]any) + assert.Equal(t, float64(-32029), errObj["code"]) + assert.Equal(t, "Rate limit exceeded", errObj["message"]) +} + +func sharedRateLimitConfig(maxTokens int32) *ratelimittypes.RateLimitConfig { + return &ratelimittypes.RateLimitConfig{ + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: maxTokens, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + } +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index d070166600..4d997d1173 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -29,6 +29,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth" asrunner "github.com/stacklok/toolhive/pkg/authserver/runner" mcpparser "github.com/stacklok/toolhive/pkg/mcp" + ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types" "github.com/stacklok/toolhive/pkg/recovery" "github.com/stacklok/toolhive/pkg/telemetry" transportmiddleware "github.com/stacklok/toolhive/pkg/transport/middleware" @@ -93,6 +94,9 @@ type Config struct { // Name is the server name exposed in MCP protocol Name string + // Namespace is the namespace used to scope distributed runtime state. + Namespace string + // Version is the server version Version string @@ -182,6 +186,10 @@ type Config struct { // session persistence; the Redis password is read from the // THV_SESSION_REDIS_PASSWORD environment variable. 
SessionStorage *vmcpconfig.SessionStorageConfig + + // RateLimiting configures Redis-backed rate limiting for tools/call requests. + // When nil, rate limiting is disabled. + RateLimiting *ratelimittypes.RateLimitConfig } // Server is the Virtual MCP Server that aggregates multiple backends. @@ -254,6 +262,9 @@ type Server struct { // Populated during Start() initialization before blocking; no mutex needed // since Stop() is only called after Start()'s select returns. shutdownFuncs []func(context.Context) error + + // rateLimitMiddleware is initialized once during New() when configured. + rateLimitMiddleware func(http.Handler) http.Handler } // buildSessionDataStorage constructs the DataStorage backend from cfg. @@ -319,6 +330,9 @@ func New( if cfg.Name == "" { cfg.Name = "toolhive-vmcp" } + if cfg.Namespace == "" { + cfg.Namespace = "local" + } if cfg.Version == "" { cfg.Version = "0.1.0" } @@ -485,6 +499,14 @@ func New( if optimizerCleanup != nil { srv.shutdownFuncs = append(srv.shutdownFuncs, optimizerCleanup) } + rateLimitMiddleware, rateLimitCleanup, err := srv.buildRateLimitMiddleware(ctx) + if err != nil { + return nil, err + } + srv.rateLimitMiddleware = rateLimitMiddleware + if rateLimitCleanup != nil { + srv.shutdownFuncs = append(srv.shutdownFuncs, rateLimitCleanup) + } // Register OnRegisterSession hook to inject capabilities after SDK registers session. // See handleSessionRegistration for implementation details. 
@@ -571,57 +593,79 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { slog.Debug("embedded authorization server routes registered") } - // MCP endpoint - apply middleware chain (wrapping order, execution happens in reverse): - // Code wraps: auth+parser → audit → discovery → annotation-enrichment → - // authz → backend-enrichment → MCP-parsing → telemetry - // Execution order: recovery → header-val → auth+parser → audit → - // discovery → annotation-enrichment → authz → backend-enrichment → - // MCP-parsing → telemetry → handler + mcpHandler, err := s.buildMCPHandler(streamableServer) + if err != nil { + return nil, err + } - var mcpHandler http.Handler = streamableServer + // Apply Accept header validation (rejects GET requests without Accept: text/event-stream) + mcpHandler = headerValidatingMiddleware(mcpHandler) + + // Clear the write deadline for qualifying SSE connections (GET + + // Accept: text/event-stream + MCP endpoint path) so the server-level + // WriteTimeout does not kill long-lived SSE streams (see golang/go#16100). + // Non-qualifying requests are left untouched; http.Server.WriteTimeout + // (defaultWriteTimeout) remains in effect for them. + mcpHandler = transportmiddleware.WriteTimeout(s.config.EndpointPath)(mcpHandler) + + // Apply recovery middleware as outermost (catches panics from all inner middleware) + mcpHandler = recovery.Middleware(mcpHandler) + slog.Info("recovery middleware enabled for MCP endpoints") + + mux.Handle("/", mcpHandler) + + return mux, nil +} + +func (s *Server) buildMCPHandler(streamableServer http.Handler) (http.Handler, error) { + // Wrapping order is inside-out; execution is outside-in: + // auth → MCP parser → rate limit → audit → discovery → annotation enrichment → + // authz → backend enrichment → telemetry → streamable HTTP handler. 
+ mcpHandler := streamableServer if s.config.TelemetryProvider != nil { mcpHandler = s.config.TelemetryProvider.Middleware(s.config.Name, "streamable-http")(mcpHandler) slog.Info("telemetry middleware enabled for MCP endpoints") } + mcpHandler = s.applyBackendEnrichment(mcpHandler) + mcpHandler = s.applyAuthorization(mcpHandler) + mcpHandler = s.applyDiscovery(mcpHandler) - // Apply MCP parsing middleware to extract JSON-RPC method from request body. - // This runs before telemetry so that recordMetrics can label metrics with the - // actual mcp_method (e.g. "tools/call", "initialize") instead of "unknown". - // Note: ParsingMiddleware is also composed inside the auth middleware (for audit/authz). - // The second application here is a no-op because the context already holds a - // ParsedMCPRequest; it exists only so the telemetry layer works correctly even - // when auth middleware is nil. - mcpHandler = mcpparser.ParsingMiddleware(mcpHandler) + var err error + mcpHandler, err = s.applyAudit(mcpHandler) + if err != nil { + return nil, err + } + mcpHandler = s.applyRateLimiting(mcpHandler) - // Apply backend enrichment middleware if audit is configured - // This runs after discovery populates the routing table, so it can extract backend names - if s.config.AuditConfig != nil { - mcpHandler = s.backendEnrichmentMiddleware(mcpHandler) - slog.Info("backend enrichment middleware enabled for audit events") + mcpHandler = mcpparser.ParsingMiddleware(mcpHandler) + if s.config.AuthMiddleware != nil { + mcpHandler = s.config.AuthMiddleware(mcpHandler) + slog.Info("authentication middleware enabled for MCP endpoints") } + return mcpHandler, nil +} - // Apply authorization middleware if configured (runs AFTER discovery in execution). - // Wrapping it here (before discovery wrap) means discovery runs first, then authz. 
- if s.config.AuthzMiddleware != nil { - mcpHandler = s.config.AuthzMiddleware(mcpHandler) - slog.Info("authorization middleware enabled for MCP endpoints (post-discovery)") +func (s *Server) applyBackendEnrichment(next http.Handler) http.Handler { + if s.config.AuditConfig == nil { + return next } + slog.Info("backend enrichment middleware enabled for audit events") + return s.backendEnrichmentMiddleware(next) +} - // Apply annotation enrichment middleware (runs after discovery, before authz in execution). - // Reads tool annotations from discovered capabilities and injects them into the - // request context so the authz middleware can make annotation-aware decisions. - if s.config.AuthzMiddleware != nil { - mcpHandler = AnnotationEnrichmentMiddleware(mcpHandler) - slog.Info("annotation enrichment middleware enabled for MCP endpoints") +func (s *Server) applyAuthorization(next http.Handler) http.Handler { + if s.config.AuthzMiddleware == nil { + return next } + next = s.config.AuthzMiddleware(next) + slog.Info("authorization middleware enabled for MCP endpoints (post-discovery)") + next = AnnotationEnrichmentMiddleware(next) + slog.Info("annotation enrichment middleware enabled for MCP endpoints") + return next +} - // Apply discovery middleware (runs after audit/auth middleware) - // Discovery middleware performs per-request capability aggregation with user context. - // vmcpSessionMgr (MultiSessionGetter) is used to retrieve the fully-formed MultiSession - // for subsequent requests so the routing table can be injected into context. - // The backend registry provides a dynamic backend list (supports DynamicRegistry for K8s). - // The health monitor enables filtering based on current health status (respects circuit breaker). 
+func (s *Server) applyDiscovery(next http.Handler) http.Handler { s.healthMonitorMu.RLock() healthMon := s.healthMonitor s.healthMonitorMu.RUnlock() @@ -630,51 +674,38 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { if healthMon != nil { healthStatusProvider = healthMon } - mcpHandler = discovery.Middleware( + next = discovery.Middleware( s.discoveryMgr, s.backendRegistry, s.vmcpSessionMgr, healthStatusProvider, discovery.WithSessionScopedRouting(), - )(mcpHandler) + )(next) slog.Info("discovery middleware enabled for lazy per-user capability discovery") + return next +} - // Apply audit middleware if configured (runs after auth, before discovery) - if s.config.AuditConfig != nil { - if err := s.config.AuditConfig.Validate(); err != nil { - return nil, fmt.Errorf("invalid audit configuration: %w", err) - } - auditor, err := audit.NewAuditorWithTransport( - s.config.AuditConfig, - "streamable-http", // vMCP uses streamable HTTP transport - ) - if err != nil { - return nil, fmt.Errorf("failed to create auditor: %w", err) - } - mcpHandler = auditor.Middleware(mcpHandler) - slog.Info("audit middleware enabled for MCP endpoints") +func (s *Server) applyAudit(next http.Handler) (http.Handler, error) { + if s.config.AuditConfig == nil { + return next, nil } - - // Apply authentication middleware if configured (runs first in chain) - if s.config.AuthMiddleware != nil { - mcpHandler = s.config.AuthMiddleware(mcpHandler) - slog.Info("authentication middleware enabled for MCP endpoints") + if err := s.config.AuditConfig.Validate(); err != nil { + return nil, fmt.Errorf("invalid audit configuration: %w", err) } + auditor, err := audit.NewAuditorWithTransport( + s.config.AuditConfig, + "streamable-http", // vMCP uses streamable HTTP transport + ) + if err != nil { + return nil, fmt.Errorf("failed to create auditor: %w", err) + } + slog.Info("audit middleware enabled for MCP endpoints") + return auditor.Middleware(next), nil +} - // Apply Accept header 
validation (rejects GET requests without Accept: text/event-stream) - mcpHandler = headerValidatingMiddleware(mcpHandler) - - // Clear the write deadline for qualifying SSE connections (GET + - // Accept: text/event-stream + MCP endpoint path) so the server-level - // WriteTimeout does not kill long-lived SSE streams (see golang/go#16100). - // Non-qualifying requests are left untouched; http.Server.WriteTimeout - // (defaultWriteTimeout) remains in effect for them. - mcpHandler = transportmiddleware.WriteTimeout(s.config.EndpointPath)(mcpHandler) - - // Apply recovery middleware as outermost (catches panics from all inner middleware) - mcpHandler = recovery.Middleware(mcpHandler) - slog.Info("recovery middleware enabled for MCP endpoints") - - mux.Handle("/", mcpHandler) - - return mux, nil +func (s *Server) applyRateLimiting(next http.Handler) http.Handler { + if s.rateLimitMiddleware == nil { + return next + } + slog.Info("rate limit middleware enabled for MCP endpoints") + return s.rateLimitMiddleware(next) } // Start starts the Virtual MCP Server and begins serving requests. diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go new file mode 100644 index 0000000000..3a4f89f85b --- /dev/null +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go @@ -0,0 +1,271 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package virtualmcp + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "os/exec" + "strings" + "time" + + mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + mcpv1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/test/e2e/images" +) + +var _ = ginkgo.Describe("VirtualMCPServer Rate Limiting", ginkgo.Ordered, func() { + const ( + timeout = 5 * time.Minute + pollInterval = 2 * time.Second + oidcAudience = "vmcp-audience" + ) + + var ( + mcpGroupName string + backendName string + vmcpName string + redisName string + oidcName string + vmcpLocalPort int + oidcLocalPort int + vmcpPortForwardCleanup func() + oidcPortForwardCleanup func() + oidcCleanup func() + ) + + ginkgo.BeforeAll(func() { + ts := time.Now().UnixNano() + mcpGroupName = fmt.Sprintf("e2e-rl-group-%d", ts) + backendName = fmt.Sprintf("e2e-rl-backend-%d", ts) + vmcpName = fmt.Sprintf("e2e-rl-vmcp-%d", ts) + redisName = fmt.Sprintf("e2e-rl-redis-%d", ts) + oidcName = fmt.Sprintf("e2e-rl-oidc-%d", ts) + + ginkgo.By("Deploying Redis") + deployRedis(redisName) + + ginkgo.By("Deploying parameterized OIDC server") + oidcIssuer, _, cleanup := DeployParameterizedOIDCServer( + ctx, k8sClient, oidcName, defaultNamespace, timeout, pollInterval, + ) + oidcCleanup = cleanup + var err error + oidcLocalPort, oidcPortForwardCleanup, err = startRateLimitServicePortForward(oidcName, 80) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + ginkgo.By("Creating MCPOIDCConfig") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.MCPOIDCConfig{ + ObjectMeta: metav1.ObjectMeta{Name: oidcName, Namespace: defaultNamespace}, + Spec: 
mcpv1beta1.MCPOIDCConfigSpec{ + Type: mcpv1beta1.MCPOIDCConfigTypeInline, + Inline: &mcpv1beta1.InlineOIDCSharedConfig{ + Issuer: oidcIssuer, + InsecureAllowHTTP: true, + JWKSAllowPrivateIP: true, + ProtectedResourceAllowPrivateIP: true, + }, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Creating MCPGroup") + CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, defaultNamespace, + "E2E vMCP rate limiting group", timeout, pollInterval) + + ginkgo.By("Creating backend MCPServer") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: backendName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.MCPServerSpec{ + GroupRef: &mcpv1beta1.MCPGroupRef{Name: mcpGroupName}, + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: 8080, + MCPPort: 8080, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for backend MCPServer to be ready") + gomega.Eventually(func() error { + server := &mcpv1beta1.MCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: backendName, + Namespace: defaultNamespace, + }, server); err != nil { + return err + } + if server.Status.Phase != mcpv1beta1.MCPServerPhaseReady { + return fmt.Errorf("backend not ready yet, phase: %s", server.Status.Phase) + } + return nil + }, timeout, pollInterval).Should(gomega.Succeed()) + + redisAddr := fmt.Sprintf("%s.%s.svc.cluster.local:6379", redisName, defaultNamespace) + ginkgo.By("Creating VirtualMCPServer with per-user rate limiting") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: vmcpName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.VirtualMCPServerSpec{ + GroupRef: &mcpv1beta1.MCPGroupRef{Name: mcpGroupName}, + Config: vmcpconfig.Config{ + Group: mcpGroupName, + RateLimiting: &mcpv1beta1.RateLimitConfig{ + PerUser: &mcpv1beta1.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + IncomingAuth: 
&mcpv1beta1.IncomingAuthConfig{ + Type: "oidc", + OIDCConfigRef: &mcpv1beta1.MCPOIDCConfigReference{ + Name: oidcName, + Audience: oidcAudience, + }, + }, + SessionStorage: &mcpv1beta1.SessionStorageConfig{ + Provider: mcpv1beta1.SessionStorageProviderRedis, + Address: redisAddr, + }, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for VirtualMCPServer to be ready") + WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpName, defaultNamespace, timeout, pollInterval) + + ginkgo.By("Port-forwarding VirtualMCPServer service") + vmcpLocalPort, vmcpPortForwardCleanup, err = startRateLimitServicePortForward(VMCPServiceName(vmcpName), 4483) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + }) + + ginkgo.AfterAll(func() { + if vmcpPortForwardCleanup != nil { + vmcpPortForwardCleanup() + } + if oidcPortForwardCleanup != nil { + oidcPortForwardCleanup() + } + if oidcCleanup != nil { + oidcCleanup() + } + _ = k8sClient.Delete(ctx, &mcpv1beta1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: vmcpName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: backendName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPGroup{ + ObjectMeta: metav1.ObjectMeta{Name: mcpGroupName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPOIDCConfig{ + ObjectMeta: metav1.ObjectMeta{Name: oidcName, Namespace: defaultNamespace}, + }) + cleanupRedis(redisName) + }) + + ginkgo.It("rejects tools/call after the per-user limit is exceeded", func() { + token := fetchRateLimitOIDCToken(oidcLocalPort, "alice") + mcpClient := newRateLimitMCPClient(vmcpLocalPort, token) + defer mcpClient.Close() + + tools, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + toolName := firstEchoToolName(tools.Tools) + gomega.Expect(toolName).ToNot(gomega.BeEmpty()) + + req := mcp.CallToolRequest{} + req.Params.Name = toolName + 
req.Params.Arguments = map[string]any{"input": "ratelimittest"} + + _, err = mcpClient.CallTool(ctx, req) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + _, err = mcpClient.CallTool(ctx, req) + gomega.Expect(err).To(gomega.HaveOccurred()) + gomega.Expect(err.Error()).To(gomega.Or( + gomega.ContainSubstring("429"), + gomega.ContainSubstring("-32029"), + gomega.ContainSubstring("Rate limit exceeded"), + )) + }) +}) + +func fetchRateLimitOIDCToken(oidcPort int, subject string) string { + url := fmt.Sprintf("http://localhost:%d/token?subject=%s", oidcPort, subject) + resp, err := http.Post(url, "application/x-www-form-urlencoded", nil) //nolint:noctx + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + defer resp.Body.Close() + gomega.Expect(resp.StatusCode).To(gomega.Equal(http.StatusOK)) + + var tokenResp struct { + AccessToken string `json:"access_token"` + } + gomega.Expect(json.NewDecoder(resp.Body).Decode(&tokenResp)).To(gomega.Succeed()) + gomega.Expect(tokenResp.AccessToken).ToNot(gomega.BeEmpty()) + return tokenResp.AccessToken +} + +func newRateLimitMCPClient(vmcpPort int, token string) *mcpclient.Client { + httpClient := &http.Client{ + Transport: &authRoundTripper{token: token, transport: http.DefaultTransport}, + Timeout: 30 * time.Second, + } + serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpPort) + return InitializeMCPClientWithRetries(serverURL, 2*time.Minute, transport.WithHTTPBasicClient(httpClient)) +} + +func startRateLimitServicePortForward(serviceName string, servicePort int32) (int, func(), error) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + return 0, nil, fmt.Errorf("failed to find free local port: %w", err) + } + localPort := listener.Addr().(*net.TCPAddr).Port + _ = listener.Close() + + kubeconfigArg := fmt.Sprintf("--kubeconfig=%s", kubeconfig) + //nolint:gosec // kubeconfig, serviceName, and ports are test-controlled values. 
+	cmd := exec.Command("kubectl", kubeconfigArg,
+		"-n", defaultNamespace, "port-forward",
+		fmt.Sprintf("svc/%s", serviceName),
+		fmt.Sprintf("%d:%d", localPort, servicePort))
+	if err := cmd.Start(); err != nil {
+		return 0, nil, fmt.Errorf("failed to start port-forward to service %s: %w", serviceName, err)
+	}
+
+	// cleanup kills kubectl and reaps it; errors are ignored because the
+	// process may already have exited.
+	cleanup := func() {
+		if cmd.Process != nil {
+			_ = cmd.Process.Kill()
+			_ = cmd.Wait()
+		}
+	}
+
+	// Poll up to 30 times (500ms dial timeout plus 500ms sleep per attempt)
+	// until the forwarded port accepts a TCP connection.
+	for range 30 {
+		conn, dialErr := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", localPort), 500*time.Millisecond)
+		if dialErr == nil {
+			_ = conn.Close()
+			return localPort, cleanup, nil
+		}
+		time.Sleep(500 * time.Millisecond)
+	}
+
+	cleanup()
+	return 0, nil, fmt.Errorf("port-forward to service %s never became ready on localhost:%d", serviceName, localPort)
+}
+
+// firstEchoToolName returns the name of the first tool that is called "echo"
+// or whose name ends in "_echo", or "" if no tool matches.
+func firstEchoToolName(tools []mcp.Tool) string {
+	for _, tool := range tools {
+		if tool.Name == "echo" || strings.HasSuffix(tool.Name, "_echo") {
+			return tool.Name
+		}
+	}
+	return ""
+}
From 619264d0338f1dc1d7362a47a2ad9d475c3ec6b5 Mon Sep 17 00:00:00 2001
From: Sanskarzz
Date: Fri, 15 May 2026 19:39:39 +0530
Subject: [PATCH 2/2] improve test coverage

Signed-off-by: Sanskarzz
---
 pkg/ratelimit/middleware_test.go  |  73 +++++++++++++
 pkg/vmcp/cli/serve_test.go        |  14 +++
 pkg/vmcp/server/ratelimit_test.go | 171 ++++++++++++++++++++++++++++++
 3 files changed, 258 insertions(+)

diff --git a/pkg/ratelimit/middleware_test.go b/pkg/ratelimit/middleware_test.go
index ed76e72e0c..017dea4b1a 100644
--- a/pkg/ratelimit/middleware_test.go
+++ b/pkg/ratelimit/middleware_test.go
@@ -13,11 +13,17 @@ import (
 	"testing"
 	"time"
 
+	"github.com/alicebob/miniredis/v2"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 
+	v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
 	"github.com/stacklok/toolhive/pkg/auth"
 	"github.com/stacklok/toolhive/pkg/mcp"
+	transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
+	transportmocks "github.com/stacklok/toolhive/pkg/transport/types/mocks"
 )
 
 // dummyLimiter is a test double for the Limiter interface.
@@ -208,3 +214,70 @@ func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) {
 	assert.Equal(t, "echo", recorder.toolName)
 	assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID")
 }
+
+// TestDefaultToolNameResolverNilParsedRequest checks that the default resolver
+// returns an empty tool name for a nil parsed request.
+func TestDefaultToolNameResolverNilParsedRequest(t *testing.T) {
+	t.Parallel()
+
+	assert.Empty(t, DefaultToolNameResolver(nil))
+}
+
+// TestNewMiddlewareUsesCustomToolNameResolver checks that the limiter receives
+// the name produced by the custom resolver ("resolved-tool"), not the
+// resource ID carried by the parsed request ("raw-tool").
+func TestNewMiddlewareUsesCustomToolNameResolver(t *testing.T) {
+	t.Parallel()
+
+	recorder := &recordingLimiter{}
+	handler := NewMiddleware(recorder, func(*mcp.ParsedMCPRequest) string {
+		return "resolved-tool"
+	})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+
+	req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
+	req = withParsedMCPRequest(req, "tools/call", "raw-tool", 1)
+	w := httptest.NewRecorder()
+
+	handler.ServeHTTP(w, req)
+
+	assert.Equal(t, http.StatusOK, w.Code)
+	assert.Equal(t, "resolved-tool", recorder.toolName)
+}
+
+// TestRateLimitMiddlewareHandlerReturnsConfiguredHandler checks that Handler()
+// returns the exact middleware function stored on the struct.
+func TestRateLimitMiddlewareHandlerReturnsConfiguredHandler(t *testing.T) {
+	t.Parallel()
+
+	// Go func values cannot be compared with ==, so identity is proven by side
+	// effect: only the configured function sets this flag. A bare non-nil check
+	// would also pass if Handler() returned some other function.
+	configuredInvoked := false
+	mw := &rateLimitMiddleware{
+		handler: func(next http.Handler) http.Handler {
+			configuredInvoked = true
+			return next
+		},
+	}
+
+	got := mw.Handler()
+	require.NotNil(t, got)
+
+	next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
+	_ = got(next)
+	assert.True(t, configuredInvoked, "Handler() should return the configured middleware function")
+}
+
+// TestCreateMiddlewareRegistersUsableMiddleware runs CreateMiddleware against
+// a miniredis instance and a mock runner, and inspects the middleware that is
+// handed to AddMiddleware.
+func TestCreateMiddlewareRegistersUsableMiddleware(t *testing.T) {
+	t.Parallel()
+
+	mr := miniredis.RunT(t)
+	cfg, err := transporttypes.NewMiddlewareConfig(MiddlewareType, MiddlewareParams{
+		Namespace:  "default",
+		ServerName: "server",
+		RedisAddr:  mr.Addr(),
+		Config: &v1beta1.RateLimitConfig{
+			Shared: &v1beta1.RateLimitBucket{
+				MaxTokens:    1,
+				RefillPeriod: metav1.Duration{Duration: time.Minute},
+			},
+		},
+	})
+	require.NoError(t, err)
+
+	ctrl := gomock.NewController(t)
+	runner := transportmocks.NewMockMiddlewareRunner(ctrl)
+	var registered
transporttypes.Middleware
+	runner.EXPECT().
+		AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&rateLimitMiddleware{})).
+		Do(func(_ string, middleware transporttypes.Middleware) {
+			registered = middleware
+		})
+
+	require.NoError(t, CreateMiddleware(cfg, runner))
+	require.NotNil(t, registered)
+	require.NotNil(t, registered.Handler())
+	require.NoError(t, registered.Close())
+}
diff --git a/pkg/vmcp/cli/serve_test.go b/pkg/vmcp/cli/serve_test.go
index 667b285779..22b0260295 100644
--- a/pkg/vmcp/cli/serve_test.go
+++ b/pkg/vmcp/cli/serve_test.go
@@ -337,6 +337,20 @@ func TestValidateQuickModeHost(t *testing.T) {
 	}
 }
+
+// TestVMCPNamespace checks that vmcpNamespace returns "local" when
+// VMCP_NAMESPACE is empty and the environment value otherwise.
+//
+// The test is not parallel: t.Setenv must not be used in parallel tests.
+func TestVMCPNamespace(t *testing.T) {
+	cases := []struct {
+		name string
+		env  string
+		want string
+	}{
+		{name: "defaults to local", env: "", want: "local"},
+		{name: "uses environment value", env: "toolhive-system", want: "toolhive-system"},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Setenv("VMCP_NAMESPACE", tc.env)
+
+			assert.Equal(t, tc.want, vmcpNamespace())
+		})
+	}
+}
+
 // TestRunDiscovery_ZeroBackends exercises the branch in runDiscovery where the
 // discoverer succeeds but returns no backends.
The function must return a
 // non-error, an empty (non-nil) backend slice, and pass through the client and
diff --git a/pkg/vmcp/server/ratelimit_test.go b/pkg/vmcp/server/ratelimit_test.go
index 76b45cfe6d..8987e93507 100644
--- a/pkg/vmcp/server/ratelimit_test.go
+++ b/pkg/vmcp/server/ratelimit_test.go
@@ -15,13 +15,18 @@ import (
 	"github.com/alicebob/miniredis/v2"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 
 	"github.com/stacklok/toolhive/pkg/auth"
 	mcpparser "github.com/stacklok/toolhive/pkg/mcp"
 	ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types"
+	"github.com/stacklok/toolhive/pkg/vmcp"
 	vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
+	discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks"
+	"github.com/stacklok/toolhive/pkg/vmcp/mocks"
 	"github.com/stacklok/toolhive/pkg/vmcp/optimizer"
+	routerMocks "github.com/stacklok/toolhive/pkg/vmcp/router/mocks"
 )
 
 func TestBuildRateLimitMiddlewareDisabledWithoutConfig(t *testing.T) {
@@ -55,6 +60,82 @@ func TestBuildRateLimitMiddlewareRequiresRedisSessionStorage(t *testing.T) {
 	assert.Nil(t, cleanup)
 }
 
+// TestBuildRateLimitMiddlewareRequiresRedisAddress checks that building the
+// middleware fails, returning neither a middleware nor a cleanup function,
+// when the redis provider is selected without an address.
+func TestBuildRateLimitMiddlewareRequiresRedisAddress(t *testing.T) {
+	t.Parallel()
+
+	// Provider is set, Address is deliberately left empty.
+	storage := &vmcpconfig.SessionStorageConfig{Provider: "redis"}
+	srv := &Server{config: &Config{
+		Name:           "vmcp",
+		Namespace:      "default",
+		RateLimiting:   sharedRateLimitConfig(1),
+		SessionStorage: storage,
+	}}
+
+	mw, cleanupFn, buildErr := srv.buildRateLimitMiddleware(t.Context())
+
+	require.Error(t, buildErr)
+	assert.Contains(t, buildErr.Error(), "requires Redis session storage address")
+	assert.Nil(t, mw)
+	assert.Nil(t, cleanupFn)
+}
+
+func TestBuildRateLimitMiddlewareRedisPingFailure(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond)
+	defer cancel()
+	s := &Server{
+		config: &Config{
+			Name:         "vmcp",
+			Namespace:    "default",
+			RateLimiting:
sharedRateLimitConfig(1),
+			SessionStorage: &vmcpconfig.SessionStorageConfig{
+				Provider: "redis",
+				Address:  "127.0.0.1:1",
+			},
+		},
+	}
+
+	middleware, cleanup, err := s.buildRateLimitMiddleware(ctx)
+
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "failed to connect to Redis")
+	assert.Nil(t, middleware)
+	assert.Nil(t, cleanup)
+}
+
+// TestBuildRateLimitMiddlewareInvalidRateLimitConfig checks that a reachable
+// Redis combined with a rejected rate-limit configuration surfaces a
+// "failed to create rate limiter" error and returns no middleware or cleanup.
+func TestBuildRateLimitMiddlewareInvalidRateLimitConfig(t *testing.T) {
+	t.Parallel()
+
+	redisSrv := miniredis.RunT(t)
+	// Presumably rejected because MaxTokens is zero; the refill period is set.
+	invalidLimits := &ratelimittypes.RateLimitConfig{
+		Shared: &ratelimittypes.RateLimitBucket{
+			MaxTokens:    0,
+			RefillPeriod: metav1.Duration{Duration: time.Minute},
+		},
+	}
+	srv := &Server{config: &Config{
+		Name:         "vmcp",
+		Namespace:    "default",
+		RateLimiting: invalidLimits,
+		SessionStorage: &vmcpconfig.SessionStorageConfig{
+			Provider: "redis",
+			Address:  redisSrv.Addr(),
+		},
+	}}
+
+	mw, cleanupFn, buildErr := srv.buildRateLimitMiddleware(t.Context())
+
+	require.Error(t, buildErr)
+	assert.Contains(t, buildErr.Error(), "failed to create rate limiter")
+	assert.Nil(t, mw)
+	assert.Nil(t, cleanupFn)
+}
+
 func TestRateLimitMiddlewarePerUserSharedAcrossTools(t *testing.T) {
 	t.Parallel()
 
@@ -150,6 +231,96 @@ func TestRateLimitToolNameFallsBackToCallTool(t *testing.T) {
 	assert.Equal(t, "call_tool", s.rateLimitToolName(parsed))
 }
 
+// TestRateLimitToolNameNilParsedRequest checks that a nil parsed request
+// resolves to an empty tool name.
+func TestRateLimitToolNameNilParsedRequest(t *testing.T) {
+	t.Parallel()
+
+	srv := &Server{config: &Config{}}
+	resolved := srv.rateLimitToolName(nil)
+
+	assert.Empty(t, resolved)
+}
+
+// TestRateLimitToolNameOptimizerFallsBackForInvalidInnerToolName checks that,
+// with an optimizer config present, a non-string "tool_name" argument makes
+// the resolver return the request's resource ID.
+func TestRateLimitToolNameOptimizerFallsBackForInvalidInnerToolName(t *testing.T) {
+	t.Parallel()
+
+	const outerTool = "call_tool"
+	srv := &Server{config: &Config{OptimizerConfig: &optimizer.Config{}}}
+	request := &mcpparser.ParsedMCPRequest{
+		Method:     "tools/call",
+		ResourceID: outerTool,
+		Arguments:  map[string]any{"tool_name": 123}, // a number, not a string
+	}
+
+	assert.Equal(t, outerTool, srv.rateLimitToolName(request))
+}
+
+func TestApplyRateLimitingWrapsConfiguredMiddleware(t *testing.T) {
+	t.Parallel()
+
+	s := &Server{
+		rateLimitMiddleware: func(next http.Handler) http.Handler {
+			return
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				w.Header().Set("X-Rate-Limit-Test", "wrapped")
+				next.ServeHTTP(w, r)
+			})
+		},
+	}
+	handler := s.applyRateLimiting(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusAccepted)
+	}))
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
+	handler.ServeHTTP(rec, req)
+
+	assert.Equal(t, http.StatusAccepted, rec.Code)
+	assert.Equal(t, "wrapped", rec.Header().Get("X-Rate-Limit-Test"))
+}
+
+// TestApplyAuthorizationWrapsConfiguredMiddleware checks that
+// applyAuthorization routes requests through Config.AuthzMiddleware: the
+// marker header set by the middleware and the status written by the inner
+// handler must both appear in the response.
+func TestApplyAuthorizationWrapsConfiguredMiddleware(t *testing.T) {
+	t.Parallel()
+
+	markingMiddleware := func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("X-Authz-Test", "wrapped")
+			next.ServeHTTP(w, r)
+		})
+	}
+	srv := &Server{config: &Config{AuthzMiddleware: markingMiddleware}}
+
+	inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusAccepted)
+	})
+	wrapped := srv.applyAuthorization(inner)
+
+	recorder := httptest.NewRecorder()
+	wrapped.ServeHTTP(recorder, httptest.NewRequest(http.MethodPost, "/mcp", nil))
+
+	assert.Equal(t, http.StatusAccepted, recorder.Code)
+	assert.Equal(t, "wrapped", recorder.Header().Get("X-Authz-Test"))
+}
+
+// TestNewDefaultsNamespace checks that New sets the server's namespace to
+// "local" when the supplied Config leaves it unset.
+func TestNewDefaultsNamespace(t *testing.T) {
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	t.Cleanup(ctrl.Finish)
+
+	cfg := &Config{SessionFactory: testMinimalFactory()}
+	srv, err := New(
+		t.Context(),
+		cfg,
+		routerMocks.NewMockRouter(ctrl),
+		mocks.NewMockBackendClient(ctrl),
+		discoveryMocks.NewMockManager(ctrl),
+		vmcp.NewImmutableRegistry([]vmcp.Backend{}),
+		nil,
+	)
+
+	require.NoError(t, err)
+	assert.Equal(t, "local", srv.config.Namespace)
+}
+
 func newTestRateLimitHandler(t *testing.T, cfg *Config) http.Handler {
 	t.Helper()