diff --git a/pkg/vmcp/headerforward/transport.go b/pkg/vmcp/headerforward/transport.go index 8bac2cf500..e235fcb17a 100644 --- a/pkg/vmcp/headerforward/transport.go +++ b/pkg/vmcp/headerforward/transport.go @@ -35,9 +35,13 @@ type headerForwardRoundTripper struct { } // RoundTrip injects the pre-resolved headers onto a clone of the request and -// delegates to the wrapped transport. Headers already present on the request -// are left untouched so inner transports (auth, identity, trace) can always -// override user-supplied values for the same name. +// delegates to the wrapped transport. Names already present on the inbound +// request are skipped — this preserves any header the *caller* set on the +// request before invoking the round-tripper chain. Inner (closer to the wire) +// transports run AFTER this one and use Set() unconditionally, so on any +// overlapping name those stages still win on the wire. Restricted names are +// blocked at resolve time, so user-supplied config cannot reach this point +// for Host, hop-by-hop, or X-Forwarded-* anyway. func (h *headerForwardRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { if len(h.headers) == 0 { return h.base.RoundTrip(req) @@ -64,6 +68,11 @@ func (h *headerForwardRoundTripper) RoundTrip(req *http.Request) (*http.Response // backend's pre-resolved HeaderForwardConfig. Returns base unchanged when no // header injection is configured or the effective header set is empty. // +// Used by both the vMCP backend client (startup capability discovery) and the +// per-session backend connector (long-lived MCP traffic). Exported so the +// session backend in pkg/vmcp/session/internal/backend can share the same +// transport-chain wiring. +// // Fails loudly (constructor validation, per go-style.md) when a secret identifier // cannot be resolved through the provider, so a misconfigured backend surfaces // at pod startup — not as a silent missing-header on every request. diff --git a/pkg/vmcp/session/internal/backend/mcp_session.go b/pkg/vmcp/session/internal/backend/mcp_session.go index 4ad386ae39..a90a545a47 100644 --- a/pkg/vmcp/session/internal/backend/mcp_session.go +++ b/pkg/vmcp/session/internal/backend/mcp_session.go @@ -17,11 +17,13 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/secrets" "github.com/stacklok/toolhive/pkg/versions" "github.com/stacklok/toolhive/pkg/vmcp" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" "github.com/stacklok/toolhive/pkg/vmcp/conversion" + "github.com/stacklok/toolhive/pkg/vmcp/headerforward" ) const ( @@ -196,19 +198,25 @@ func (c *mcpSession) GetPrompt( // // registry provides the authentication strategy for outgoing backend requests. // Pass a registry configured with the "unauthenticated" strategy to disable auth. +// +// A single secrets.EnvironmentProvider is constructed once per connector and +// shared across every session it creates; its lifetime matches the connector's. +// It is consumed by BuildHeaderForwardTripper to resolve secret-backed entries +// in target.HeaderForward. func NewHTTPConnector(registry vmcpauth.OutgoingAuthRegistry) func( ctx context.Context, target *vmcp.BackendTarget, identity *auth.Identity, sessionHint string, ) (Session, *vmcp.CapabilityList, error) { + provider := secrets.NewEnvironmentProvider() return func( ctx context.Context, target *vmcp.BackendTarget, identity *auth.Identity, sessionHint string, ) (Session, *vmcp.CapabilityList, error) { - c, err := createMCPClient(target, identity, registry, sessionHint) + c, err := createMCPClient(ctx, target, identity, registry, sessionHint, provider) if err != nil { return nil, nil, fmt.Errorf("failed to create MCP client for backend %s: %w", target.WorkloadID, err) } @@ -238,11 +246,17 @@ func NewHTTPConnector(registry vmcpauth.OutgoingAuthRegistry) func( // to client.Close(), not to any caller-supplied init context. // sessionHint, when non-empty, is passed as the initial Mcp-Session-Id for // streamable-HTTP transports so the backend can resume an existing session. +// +// ctx is used only to resolve secret-backed entries in target.HeaderForward at +// client-creation time; the transport itself is started with context.Background() +// as described above. provider supplies values for those secret-backed headers. func createMCPClient( + ctx context.Context, target *vmcp.BackendTarget, identity *auth.Identity, registry vmcpauth.OutgoingAuthRegistry, sessionHint string, + provider secrets.Provider, ) (*mcpclient.Client, error) { // Resolve and validate the auth strategy once at client creation time. strategyName := authtypes.StrategyTypeUnauthenticated @@ -259,7 +273,15 @@ func createMCPClient( slog.Debug("Applied authentication strategy", "strategy", strategy.Name(), "backendID", target.WorkloadID) - // Build shared transport chain: auth → identity propagation. + // Build shared transport chain (innermost first → outermost): + // http.DefaultTransport → authRoundTripper → identityRoundTripper → headerForwardRoundTripper + // On an outbound request, the outermost stage runs first: header-forward + // injects its headers onto a request that does not yet carry auth/identity + // headers, then inner stages run and call Set() unconditionally so any + // overlapping name they care about (Authorization, identity headers) wins on + // the wire. Restricted header names (Host, hop-by-hop, X-Forwarded-*) are + // rejected at resolve time by resolveHeaderForward, so user-supplied + // HeaderForward cannot inject them in the first place. // The per-transport sections below may add a size-limiting wrapper on top. base := http.RoundTripper(http.DefaultTransport) base = &authRoundTripper{ @@ -269,6 +291,10 @@ func createMCPClient( target: target, } base = &identityRoundTripper{base: base, identity: identity} + base, err = headerforward.BuildHeaderForwardTripper(ctx, base, target.HeaderForward, provider, target.WorkloadID) + if err != nil { + return nil, fmt.Errorf("failed to build header-forward transport for backend %s: %w", target.WorkloadID, err) + } var c *mcpclient.Client switch target.TransportType { diff --git a/pkg/vmcp/session/internal/backend/mcp_session_capabilities_test.go b/pkg/vmcp/session/internal/backend/mcp_session_capabilities_test.go index bddcd77d81..ba877664ca 100644 --- a/pkg/vmcp/session/internal/backend/mcp_session_capabilities_test.go +++ b/pkg/vmcp/session/internal/backend/mcp_session_capabilities_test.go @@ -54,6 +54,11 @@ type fakeBackend struct { // resources/list failure. mu sync.Mutex methodCalls map[string]int + + // headersByMethod records the inbound request headers keyed by JSON-RPC + // method name. Tests asserting transport-chain behavior (e.g. HeaderForward) + // use headersFor(method) to inspect the headers a backend actually saw. + headersByMethod map[string]http.Header } type jsonRPCError struct { @@ -68,6 +73,7 @@ func newFakeBackend(t *testing.T, fb *fakeBackend) string { t.Helper() fb.t = t fb.methodCalls = make(map[string]int) + fb.headersByMethod = make(map[string]http.Header) mux := http.NewServeMux() mux.HandleFunc("/mcp", fb.handle) @@ -83,6 +89,19 @@ func (f *fakeBackend) callCount(method string) int { return f.methodCalls[method] } +// headersFor returns a clone of the inbound HTTP headers recorded for the most +// recent JSON-RPC request with the given method, or nil if no such request was +// seen. Cloning under the mutex keeps the caller safe from concurrent writes. +func (f *fakeBackend) headersFor(method string) http.Header { + f.mu.Lock() + defer f.mu.Unlock() + h := f.headersByMethod[method] + if h == nil { + return nil + } + return h.Clone() +} + // handle implements the JSON-RPC subset needed for backend init. The // streamable-HTTP transport sends POST requests with Accept: // application/json, text/event-stream — we always reply with @@ -117,6 +136,7 @@ func (f *fakeBackend) handle(w http.ResponseWriter, r *http.Request) { f.mu.Lock() f.methodCalls[msg.Method]++ + f.headersByMethod[msg.Method] = r.Header.Clone() f.mu.Unlock() // Notifications (no id, e.g. notifications/initialized) get an empty 202. @@ -148,6 +168,16 @@ func (f *fakeBackend) handle(w http.ResponseWriter, r *http.Request) { return } f.writeResult(w, msg.ID, map[string]any{"prompts": f.prompts}) + case string(mcp.MethodToolsCall): + // Minimal CallToolResult with a single text content. Tests that exercise + // the post-initialize transport chain (e.g. HeaderForward) need a method + // they can invoke after Initialize completes; tools/call is the cheapest. + f.writeResult(w, msg.ID, map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + "isError": false, + }) default: f.writeError(w, msg.ID, &jsonRPCError{code: mcp.METHOD_NOT_FOUND, message: "Method not found"}) } diff --git a/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go b/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go new file mode 100644 index 0000000000..9d58de9b76 --- /dev/null +++ b/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go @@ -0,0 +1,247 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backend + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/secrets" + "github.com/stacklok/toolhive/pkg/vmcp" + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" +) + +// TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests asserts the +// invariants of the HeaderForward transport stage on the session-side +// connector (pkg/vmcp/session/internal/backend). Fixes #5289. +// +// The three subtests cover, in order: +// - plaintext: a user-configured AddPlaintextHeaders entry reaches the +// backend on post-initialize traffic (tools/call). +// - overlap precedence: when an inner auth stage writes the same header +// name that HeaderForward configures, the inner stage's value lands on +// the wire — proving the chain ordering documented on +// headerForwardRoundTripper.RoundTrip. +// - restricted-name rejection: NewHTTPConnector fails fast when +// HeaderForward names a restricted header, proving the +// resolveHeaderForward guard is wired into the session-side connector. +// +// Secret resolution requires t.Setenv and lives in its own non-parallel test +// below (TestHTTPSession_HeaderForward_ResolvesSecretsFromEnv). +func TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests(t *testing.T) { + t.Parallel() + + t.Run("plaintext header reaches backend on tools/call", func(t *testing.T) { + t.Parallel() + + const ( + headerName = "X-MCP-Toolsets" + headerValue = "projects,issues,pull_requests,users,repos" + ) + + fb := &fakeBackend{advertiseTools: true, tools: []mcp.Tool{{Name: "echo"}}} + url := newFakeBackend(t, fb) + + target := &vmcp.BackendTarget{ + WorkloadID: "header-forward-backend", + WorkloadName: "header-forward-backend", + BaseURL: url, + TransportType: "streamable-http", + HeaderForward: &vmcp.HeaderForwardConfig{ + AddPlaintextHeaders: map[string]string{ + headerName: headerValue, + }, + }, + } + + sess := connectAndCallEcho(t, target) + t.Cleanup(func() { _ = sess.Close() }) + + got := fb.headersFor(string(mcp.MethodToolsCall)) + require.NotNil(t, got, "backend never received a tools/call request") + assert.Equal(t, headerValue, got.Get(headerName), + "HeaderForward.AddPlaintextHeaders must reach the backend on post-initialize requests") + }) + + t.Run("inner auth stage wins on overlapping header name", func(t *testing.T) { + t.Parallel() + + // The chain runs outer→inner on the outbound request: + // headerForwardRoundTripper → identityRoundTripper → authRoundTripper → http.DefaultTransport + // headerForwardRoundTripper sets X-Test-Identity first (request didn't + // have it). The inner authRoundTripper's Authenticate() then runs and + // calls Set() unconditionally, overwriting. We assert the auth value + // is the one on the wire. + const ( + headerName = "X-Test-Identity" + headerForwardVal = "from-header-forward" + authStrategyValue = "from-auth" + ) + + fb := &fakeBackend{advertiseTools: true, tools: []mcp.Tool{{Name: "echo"}}} + url := newFakeBackend(t, fb) + + registry := vmcpauth.NewDefaultOutgoingAuthRegistry() + require.NoError(t, registry.RegisterStrategy( + "test-header-setter", + &testHeaderSettingStrategy{name: "test-header-setter", header: headerName, value: authStrategyValue}, + )) + + target := &vmcp.BackendTarget{ + WorkloadID: "overlap-backend", + WorkloadName: "overlap-backend", + BaseURL: url, + TransportType: "streamable-http", + AuthConfig: &authtypes.BackendAuthStrategy{Type: "test-header-setter"}, + HeaderForward: &vmcp.HeaderForwardConfig{ + AddPlaintextHeaders: map[string]string{ + headerName: headerForwardVal, + }, + }, + } + + connector := NewHTTPConnector(registry) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + sess, _, err := connector(ctx, target, nil, "") + require.NoError(t, err) + t.Cleanup(func() { _ = sess.Close() }) + + _, err = sess.CallTool(ctx, "echo", map[string]any{}, nil) + require.NoError(t, err) + + got := fb.headersFor(string(mcp.MethodToolsCall)) + require.NotNil(t, got, "backend never received a tools/call request") + assert.Equal(t, authStrategyValue, got.Get(headerName), + "inner auth stage must overwrite HeaderForward on overlapping header names") + }) + + t.Run("restricted header in HeaderForward fails connector at startup", func(t *testing.T) { + t.Parallel() + + // resolveHeaderForward rejects names in middleware.RestrictedHeaders. + // Asserts the guard is reachable from the session-side connector — a + // misconfigured backend surfaces at session creation, not silently as + // a missing header on every request. + fb := &fakeBackend{advertiseTools: true, tools: []mcp.Tool{{Name: "echo"}}} + url := newFakeBackend(t, fb) + + target := &vmcp.BackendTarget{ + WorkloadID: "restricted-header-backend", + WorkloadName: "restricted-header-backend", + BaseURL: url, + TransportType: "streamable-http", + HeaderForward: &vmcp.HeaderForwardConfig{ + AddPlaintextHeaders: map[string]string{ + // X-Forwarded-For is in middleware.RestrictedHeaders — letting + // user config inject it would enable identity-spoofing of the + // caller's IP to the backend. + "X-Forwarded-For": "1.2.3.4", + }, + }, + } + + registry := newTestRegistry(t) + connector := NewHTTPConnector(registry) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, _, err := connector(ctx, target, nil, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "restricted", + "connector must reject restricted header names in HeaderForward config") + }) +} + +// TestHTTPSession_HeaderForward_ResolvesSecretsFromEnv asserts that an +// AddHeadersFromSecret entry is resolved via the EnvironmentProvider that +// NewHTTPConnector constructs internally and reaches the backend on +// post-initialize traffic. +// +// Lives in its own non-parallel top-level test because t.Setenv (the only +// way to inject a value into the connector's env-backed secrets.Provider +// without changing production APIs) cannot be used inside a parallel test +// tree. +func TestHTTPSession_HeaderForward_ResolvesSecretsFromEnv(t *testing.T) { + const ( + headerName = "X-GitHub-Auth" + secretID = "GITHUB_PAT" + secretEnvName = secrets.EnvVarPrefix + secretID + secretValue = "ghp_test_value_12345" + ) + t.Setenv(secretEnvName, secretValue) + + fb := &fakeBackend{advertiseTools: true, tools: []mcp.Tool{{Name: "echo"}}} + url := newFakeBackend(t, fb) + + target := &vmcp.BackendTarget{ + WorkloadID: "secret-header-backend", + WorkloadName: "secret-header-backend", + BaseURL: url, + TransportType: "streamable-http", + HeaderForward: &vmcp.HeaderForwardConfig{ + AddHeadersFromSecret: map[string]string{ + headerName: secretID, + }, + }, + } + + sess := connectAndCallEcho(t, target) + t.Cleanup(func() { _ = sess.Close() }) + + got := fb.headersFor(string(mcp.MethodToolsCall)) + require.NotNil(t, got, "backend never received a tools/call request") + assert.Equal(t, secretValue, got.Get(headerName), + "HeaderForward.AddHeadersFromSecret must be resolved via the env provider and reach the backend") +} + +// connectAndCallEcho builds an HTTP session for target via the default test +// registry, makes a single tools/call("echo"), and returns the open session. +// The caller is responsible for closing the session (typically via t.Cleanup). +func connectAndCallEcho(t *testing.T, target *vmcp.BackendTarget) Session { + t.Helper() + + registry := newTestRegistry(t) + connector := NewHTTPConnector(registry) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + sess, caps, err := connector(ctx, target, nil, "") + require.NoError(t, err, "connector must initialise the backend successfully") + require.NotNil(t, sess, "connector returned nil session") + require.NotNil(t, caps, "connector returned nil capability list") + + _, err = sess.CallTool(ctx, "echo", map[string]any{}, nil) + require.NoError(t, err, "post-initialize CallTool must succeed") + + return sess +} + +// testHeaderSettingStrategy is a vmcpauth.Strategy stand-in that unconditionally +// writes a single header onto every outbound request. Used to drive the +// overlap-precedence assertion: when HeaderForward configures the same name, +// this strategy (called from the inner authRoundTripper) must win on the wire. +type testHeaderSettingStrategy struct { + name string + header string + value string +} + +func (s *testHeaderSettingStrategy) Name() string { return s.name } + +func (s *testHeaderSettingStrategy) Authenticate(_ context.Context, req *http.Request, _ *authtypes.BackendAuthStrategy) error { + req.Header.Set(s.header, s.value) + return nil +} + +func (*testHeaderSettingStrategy) Validate(_ *authtypes.BackendAuthStrategy) error { return nil } diff --git a/pkg/vmcp/session/internal/backend/mcp_session_test.go b/pkg/vmcp/session/internal/backend/mcp_session_test.go index 7daf5decce..842d0368fd 100644 --- a/pkg/vmcp/session/internal/backend/mcp_session_test.go +++ b/pkg/vmcp/session/internal/backend/mcp_session_test.go @@ -4,11 +4,13 @@ package backend import ( + "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stacklok/toolhive/pkg/secrets" "github.com/stacklok/toolhive/pkg/vmcp" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" @@ -40,7 +42,7 @@ func TestCreateMCPClient_UnsupportedTransport(t *testing.T) { TransportType: transport, } - _, err := createMCPClient(target, nil, newTestRegistry(t), "") + _, err := createMCPClient(context.Background(), target, nil, newTestRegistry(t), "", secrets.NewEnvironmentProvider()) require.Error(t, err) assert.ErrorIs(t, err, vmcp.ErrUnsupportedTransport, "transport %q should return ErrUnsupportedTransport", transport)