From 5ead2ff5dadf0de96a8192b5c50f0e770b032fa0 Mon Sep 17 00:00:00 2001 From: Laurel Orr Date: Sat, 16 May 2026 15:11:50 -0700 Subject: [PATCH 1/2] Wire HeaderForward into vMCP session HTTP client PR #5239 added HeaderForward support to the startup capability-discovery client at pkg/vmcp/client/client.go, but the per-session MCP HTTP client in pkg/vmcp/session/internal/backend builds a parallel transport chain (DefaultTransport -> auth -> identity) that never reads target.HeaderForward. Every post-initialize MCP call (tools/list, tools/call, ...) therefore reaches the upstream without user-configured headers, leaving features like GitHub Copilot's X-MCP-Toolsets filter silently broken in v0.27.2. Construct a single secrets.EnvironmentProvider at connector build time, plumb it into createMCPClient, and wrap the shared chain with BuildHeaderForwardTripper as the outermost stage so vMCP auth/identity headers still win on overlapping names. Export the existing helper from pkg/vmcp/client so the session backend can reuse it without duplication. Fixes #5289 --- pkg/vmcp/headerforward/transport.go | 5 + .../session/internal/backend/mcp_session.go | 25 ++- .../mcp_session_header_forward_test.go | 190 ++++++++++++++++++ .../internal/backend/mcp_session_test.go | 4 +- 4 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go diff --git a/pkg/vmcp/headerforward/transport.go b/pkg/vmcp/headerforward/transport.go index 8bac2cf500..c6818d267b 100644 --- a/pkg/vmcp/headerforward/transport.go +++ b/pkg/vmcp/headerforward/transport.go @@ -64,6 +64,11 @@ func (h *headerForwardRoundTripper) RoundTrip(req *http.Request) (*http.Response // backend's pre-resolved HeaderForwardConfig. Returns base unchanged when no // header injection is configured or the effective header set is empty. // +// Used by both the vMCP backend client (startup capability discovery) and the +// per-session backend connector (long-lived MCP traffic). Exported so the +// session backend in pkg/vmcp/session/internal/backend can share the same +// transport-chain wiring. +// // Fails loudly (constructor validation, per go-style.md) when a secret identifier // cannot be resolved through the provider, so a misconfigured backend surfaces // at pod startup — not as a silent missing-header on every request. diff --git a/pkg/vmcp/session/internal/backend/mcp_session.go b/pkg/vmcp/session/internal/backend/mcp_session.go index 4ad386ae39..1267ab54b1 100644 --- a/pkg/vmcp/session/internal/backend/mcp_session.go +++ b/pkg/vmcp/session/internal/backend/mcp_session.go @@ -17,11 +17,13 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/secrets" "github.com/stacklok/toolhive/pkg/versions" "github.com/stacklok/toolhive/pkg/vmcp" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" "github.com/stacklok/toolhive/pkg/vmcp/conversion" + "github.com/stacklok/toolhive/pkg/vmcp/headerforward" ) const ( @@ -196,19 +198,25 @@ func (c *mcpSession) GetPrompt( // // registry provides the authentication strategy for outgoing backend requests. // Pass a registry configured with the "unauthenticated" strategy to disable auth. +// +// A single secrets.EnvironmentProvider is constructed once per connector and +// shared across every session it creates; its lifetime matches the connector's. +// It is consumed by BuildHeaderForwardTripper to resolve secret-backed entries +// in target.HeaderForward. func NewHTTPConnector(registry vmcpauth.OutgoingAuthRegistry) func( ctx context.Context, target *vmcp.BackendTarget, identity *auth.Identity, sessionHint string, ) (Session, *vmcp.CapabilityList, error) { + provider := secrets.NewEnvironmentProvider() return func( ctx context.Context, target *vmcp.BackendTarget, identity *auth.Identity, sessionHint string, ) (Session, *vmcp.CapabilityList, error) { - c, err := createMCPClient(target, identity, registry, sessionHint) + c, err := createMCPClient(ctx, target, identity, registry, sessionHint, provider) if err != nil { return nil, nil, fmt.Errorf("failed to create MCP client for backend %s: %w", target.WorkloadID, err) } @@ -238,11 +246,17 @@ func NewHTTPConnector(registry vmcpauth.OutgoingAuthRegistry) func( // to client.Close(), not to any caller-supplied init context. // sessionHint, when non-empty, is passed as the initial Mcp-Session-Id for // streamable-HTTP transports so the backend can resume an existing session. +// +// ctx is used only to resolve secret-backed entries in target.HeaderForward at +// client-creation time; the transport itself is started with context.Background() +// as described above. provider supplies values for those secret-backed headers. func createMCPClient( + ctx context.Context, target *vmcp.BackendTarget, identity *auth.Identity, registry vmcpauth.OutgoingAuthRegistry, sessionHint string, + provider secrets.Provider, ) (*mcpclient.Client, error) { // Resolve and validate the auth strategy once at client creation time. strategyName := authtypes.StrategyTypeUnauthenticated @@ -259,7 +273,10 @@ func createMCPClient( slog.Debug("Applied authentication strategy", "strategy", strategy.Name(), "backendID", target.WorkloadID) - // Build shared transport chain: auth → identity propagation. + // Build shared transport chain: auth → identity propagation → header forward. + // HeaderForward is the outermost stage so inner stages (auth, identity) win + // on any overlapping header name — matching the ordering in + // headerForwardRoundTripper.RoundTrip, which skips names already set. // The per-transport sections below may add a size-limiting wrapper on top. base := http.RoundTripper(http.DefaultTransport) base = &authRoundTripper{ @@ -269,6 +286,10 @@ func createMCPClient( target: target, } base = &identityRoundTripper{base: base, identity: identity} + base, err = headerforward.BuildHeaderForwardTripper(ctx, base, target.HeaderForward, provider, target.WorkloadID) + if err != nil { + return nil, fmt.Errorf("failed to build header-forward transport for backend %s: %w", target.WorkloadID, err) + } var c *mcpclient.Client switch target.TransportType { diff --git a/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go b/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go new file mode 100644 index 0000000000..e49b131eb8 --- /dev/null +++ b/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backend + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// headerCapturingBackend is a minimal streamable-HTTP MCP fake that records +// inbound request headers keyed by JSON-RPC method. The test asserts that a +// user-configured HeaderForward header reaches the backend on POST-INITIALIZE +// traffic — see issue #5289. The startup capability-discovery path was fixed +// in PR #5239; per-session HTTP traffic is still missing the wrap. +type headerCapturingBackend struct { + t *testing.T + + mu sync.Mutex + headersByMethod map[string]http.Header +} + +func newHeaderCapturingBackend(t *testing.T) (*headerCapturingBackend, string) { + t.Helper() + fb := &headerCapturingBackend{ + t: t, + headersByMethod: make(map[string]http.Header), + } + mux := http.NewServeMux() + mux.HandleFunc("/mcp", fb.handle) + ts := httptest.NewServer(mux) + t.Cleanup(ts.Close) + return fb, ts.URL + "/mcp" +} + +func (f *headerCapturingBackend) headersFor(method string) http.Header { + f.mu.Lock() + defer f.mu.Unlock() + return f.headersByMethod[method] +} + +func (f *headerCapturingBackend) handle(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + // Streamable-HTTP transports may open a GET for server-pushed + // notifications; rejecting it cleanly is fine for this test. + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + f.t.Errorf("headerCapturingBackend: read body: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + defer func() { _ = r.Body.Close() }() + + var msg struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Method string `json:"method"` + } + if err := json.Unmarshal(body, &msg); err != nil { + f.t.Errorf("headerCapturingBackend: decode: %v body=%s", err, string(body)) + w.WriteHeader(http.StatusBadRequest) + return + } + + f.mu.Lock() + f.headersByMethod[msg.Method] = r.Header.Clone() + f.mu.Unlock() + + // Notifications (no id, e.g. notifications/initialized) get an empty 202. + if len(msg.ID) == 0 || string(msg.ID) == "null" { + w.WriteHeader(http.StatusAccepted) + return + } + + switch msg.Method { + case string(mcp.MethodInitialize): + w.Header().Set("Mcp-Session-Id", "header-forward-test-session") + f.writeResult(w, msg.ID, map[string]any{ + "protocolVersion": mcp.LATEST_PROTOCOL_VERSION, + "capabilities": map[string]any{ + "tools": map[string]any{}, + }, + "serverInfo": map[string]any{"name": "header-forward-fake", "version": "0.0.0"}, + }) + case string(mcp.MethodToolsList): + f.writeResult(w, msg.ID, map[string]any{ + "tools": []mcp.Tool{{Name: "echo", Description: "echo tool"}}, + }) + case string(mcp.MethodToolsCall): + f.writeResult(w, msg.ID, map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + "isError": false, + }) + default: + f.writeResult(w, msg.ID, map[string]any{}) + } +} + +func (f *headerCapturingBackend) writeResult(w http.ResponseWriter, id json.RawMessage, result any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(id), + "result": result, + }); err != nil { + f.t.Errorf("headerCapturingBackend: encode result: %v", err) + } +} + +// TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests is the red-phase +// regression test for issue #5289. PR #5239 fixed HeaderForward for the vMCP +// backend client (used for startup capability discovery) but did not extend the +// fix to the session-side connector at pkg/vmcp/session/internal/backend. +// As a result, user-configured headers (e.g. X-MCP-Toolsets for GitHub MCP) +// never reach the backend on per-session requests like tools/call. +// +// The test asserts that, after the connector completes Initialize, a +// subsequent CallTool carries the configured plaintext header on the wire. +// On main today it fails because the connector's transport chain does not +// include a header-forward round-tripper — see createMCPClient in +// mcp_session.go (the chain is http.DefaultTransport → authRoundTripper → +// identityRoundTripper, with no HeaderForward stage). +func TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests(t *testing.T) { + t.Parallel() + + const ( + headerName = "X-MCP-Toolsets" + headerValue = "projects,issues,pull_requests,users,repos" + ) + + fb, url := newHeaderCapturingBackend(t) + + target := &vmcp.BackendTarget{ + WorkloadID: "header-forward-backend", + WorkloadName: "header-forward-backend", + BaseURL: url, + TransportType: "streamable-http", + HeaderForward: &vmcp.HeaderForwardConfig{ + AddPlaintextHeaders: map[string]string{ + headerName: headerValue, + }, + }, + } + + registry := newTestRegistry(t) + connector := NewHTTPConnector(registry) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + sess, caps, err := connector(ctx, target, nil, "") + require.NoError(t, err, "connector must initialise the backend successfully") + require.NotNil(t, sess, "connector returned nil session") + require.NotNil(t, caps, "connector returned nil capability list") + t.Cleanup(func() { _ = sess.Close() }) + + // Make a single MCP call AFTER initialize completes. tools/call exercises + // the same transport chain as initialize but is unambiguously a + // post-handshake request — which is exactly where the regression lives. + _, err = sess.CallTool(ctx, "echo", map[string]any{}, nil) + require.NoError(t, err, "post-initialize CallTool must succeed") + + // The recorded inbound headers for the tools/call request must include the + // user-configured forward header. This is the single assertion target: + // the test fails for exactly one reason — header missing on the recorded + // post-initialize request. + gotHeaders := fb.headersFor(string(mcp.MethodToolsCall)) + require.NotNil(t, gotHeaders, "backend never received a tools/call request") + assert.Equal(t, headerValue, gotHeaders.Get(headerName), + "HeaderForward.AddPlaintextHeaders must reach the backend on post-initialize requests") +} diff --git a/pkg/vmcp/session/internal/backend/mcp_session_test.go b/pkg/vmcp/session/internal/backend/mcp_session_test.go index 7daf5decce..842d0368fd 100644 --- a/pkg/vmcp/session/internal/backend/mcp_session_test.go +++ b/pkg/vmcp/session/internal/backend/mcp_session_test.go @@ -4,11 +4,13 @@ package backend import ( + "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stacklok/toolhive/pkg/secrets" "github.com/stacklok/toolhive/pkg/vmcp" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" @@ -40,7 +42,7 @@ func TestCreateMCPClient_UnsupportedTransport(t *testing.T) { TransportType: transport, } - _, err := createMCPClient(target, nil, newTestRegistry(t), "") + _, err := createMCPClient(context.Background(), target, nil, newTestRegistry(t), "", secrets.NewEnvironmentProvider()) require.Error(t, err) assert.ErrorIs(t, err, vmcp.ErrUnsupportedTransport, "transport %q should return ErrUnsupportedTransport", transport) From c6850a8c95f99066a1bbd9dd20c58072678a017b Mon Sep 17 00:00:00 2001 From: Laurel Orr Date: Sat, 16 May 2026 15:56:42 -0700 Subject: [PATCH 2/2] Tighten HeaderForward test coverage and docs Fix the inverted chain-ordering comments on mcp_session.go's transport chain assembly and on headerForwardRoundTripper.RoundTrip. The outermost wrapper runs FIRST on the outbound request, so header-forward injects onto a request without auth/identity headers and the inner stages then overwrite with Set(). The previous wording claimed the opposite. The production behaviour is unchanged; only the doc text moved. Replace headerCapturingBackend with the shared fakeBackend by adding headersByMethod / headersFor and a tools/call handler. This removes a near-duplicate fake server in the session backend tests. Extend the header-forward test to cover overlap precedence (an inner auth stage's value wins on the wire when both sides set the same name) and AddHeadersFromSecret end-to-end via t.Setenv against the env-backed secrets provider. Add a restricted-name rejection case to prove the resolveHeaderForward guard is wired into the session-side connector. Trim diff-narrating comments from the test file: drop the past-tense narrative about which PR fixed what and describe the assertions instead. The issue reference (#5289) is kept on the test docstring. --- pkg/vmcp/headerforward/transport.go | 10 +- .../session/internal/backend/mcp_session.go | 13 +- .../backend/mcp_session_capabilities_test.go | 30 ++ .../mcp_session_header_forward_test.go | 337 ++++++++++-------- 4 files changed, 243 insertions(+), 147 deletions(-) diff --git a/pkg/vmcp/headerforward/transport.go b/pkg/vmcp/headerforward/transport.go index c6818d267b..e235fcb17a 100644 --- a/pkg/vmcp/headerforward/transport.go +++ b/pkg/vmcp/headerforward/transport.go @@ -35,9 +35,13 @@ type headerForwardRoundTripper struct { } // RoundTrip injects the pre-resolved headers onto a clone of the request and -// delegates to the wrapped transport. Headers already present on the request -// are left untouched so inner transports (auth, identity, trace) can always -// override user-supplied values for the same name. +// delegates to the wrapped transport. Names already present on the inbound +// request are skipped — this preserves any header the *caller* set on the +// request before invoking the round-tripper chain. Inner (closer to the wire) +// transports run AFTER this one and use Set() unconditionally, so on any +// overlapping name those stages still win on the wire. Restricted names are +// blocked at resolve time, so user-supplied config cannot reach this point +// for Host, hop-by-hop, or X-Forwarded-* anyway. func (h *headerForwardRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { if len(h.headers) == 0 { return h.base.RoundTrip(req) diff --git a/pkg/vmcp/session/internal/backend/mcp_session.go b/pkg/vmcp/session/internal/backend/mcp_session.go index 1267ab54b1..a90a545a47 100644 --- a/pkg/vmcp/session/internal/backend/mcp_session.go +++ b/pkg/vmcp/session/internal/backend/mcp_session.go @@ -273,10 +273,15 @@ func createMCPClient( slog.Debug("Applied authentication strategy", "strategy", strategy.Name(), "backendID", target.WorkloadID) - // Build shared transport chain: auth → identity propagation → header forward. - // HeaderForward is the outermost stage so inner stages (auth, identity) win - // on any overlapping header name — matching the ordering in - // headerForwardRoundTripper.RoundTrip, which skips names already set. + // Build shared transport chain (innermost first → outermost): + // http.DefaultTransport → authRoundTripper → identityRoundTripper → headerForwardRoundTripper + // On an outbound request, the outermost stage runs first: header-forward + // injects its headers onto a request that does not yet carry auth/identity + // headers, then inner stages run and call Set() unconditionally so any + // overlapping name they care about (Authorization, identity headers) wins on + // the wire. Restricted header names (Host, hop-by-hop, X-Forwarded-*) are + // rejected at resolve time by resolveHeaderForward, so user-supplied + // HeaderForward cannot inject them in the first place. // The per-transport sections below may add a size-limiting wrapper on top. base := http.RoundTripper(http.DefaultTransport) base = &authRoundTripper{ diff --git a/pkg/vmcp/session/internal/backend/mcp_session_capabilities_test.go b/pkg/vmcp/session/internal/backend/mcp_session_capabilities_test.go index bddcd77d81..ba877664ca 100644 --- a/pkg/vmcp/session/internal/backend/mcp_session_capabilities_test.go +++ b/pkg/vmcp/session/internal/backend/mcp_session_capabilities_test.go @@ -54,6 +54,11 @@ type fakeBackend struct { // resources/list failure. mu sync.Mutex methodCalls map[string]int + + // headersByMethod records the inbound request headers keyed by JSON-RPC + // method name. Tests asserting transport-chain behavior (e.g. HeaderForward) + // use headersFor(method) to inspect the headers a backend actually saw. + headersByMethod map[string]http.Header } type jsonRPCError struct { @@ -68,6 +73,7 @@ func newFakeBackend(t *testing.T, fb *fakeBackend) string { t.Helper() fb.t = t fb.methodCalls = make(map[string]int) + fb.headersByMethod = make(map[string]http.Header) mux := http.NewServeMux() mux.HandleFunc("/mcp", fb.handle) @@ -83,6 +89,19 @@ func (f *fakeBackend) callCount(method string) int { return f.methodCalls[method] } +// headersFor returns a clone of the inbound HTTP headers recorded for the most +// recent JSON-RPC request with the given method, or nil if no such request was +// seen. Cloning under the mutex keeps the caller safe from concurrent writes. +func (f *fakeBackend) headersFor(method string) http.Header { + f.mu.Lock() + defer f.mu.Unlock() + h := f.headersByMethod[method] + if h == nil { + return nil + } + return h.Clone() +} + // handle implements the JSON-RPC subset needed for backend init. The // streamable-HTTP transport sends POST requests with Accept: // application/json, text/event-stream — we always reply with @@ -117,6 +136,7 @@ func (f *fakeBackend) handle(w http.ResponseWriter, r *http.Request) { f.mu.Lock() f.methodCalls[msg.Method]++ + f.headersByMethod[msg.Method] = r.Header.Clone() f.mu.Unlock() // Notifications (no id, e.g. notifications/initialized) get an empty 202. @@ -148,6 +168,16 @@ func (f *fakeBackend) handle(w http.ResponseWriter, r *http.Request) { return } f.writeResult(w, msg.ID, map[string]any{"prompts": f.prompts}) + case string(mcp.MethodToolsCall): + // Minimal CallToolResult with a single text content. Tests that exercise + // the post-initialize transport chain (e.g. HeaderForward) need a method + // they can invoke after Initialize completes; tools/call is the cheapest. + f.writeResult(w, msg.ID, map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "ok"}, + }, + "isError": false, + }) default: f.writeError(w, msg.ID, &jsonRPCError{code: mcp.METHOD_NOT_FOUND, message: "Method not found"}) } diff --git a/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go b/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go index e49b131eb8..9d58de9b76 100644 --- a/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go +++ b/pkg/vmcp/session/internal/backend/mcp_session_header_forward_test.go @@ -5,11 +5,7 @@ package backend import ( "context" - "encoding/json" - "io" "net/http" - "net/http/httptest" - "sync" "testing" "time" @@ -17,174 +13,235 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stacklok/toolhive/pkg/secrets" "github.com/stacklok/toolhive/pkg/vmcp" + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" ) -// headerCapturingBackend is a minimal streamable-HTTP MCP fake that records -// inbound request headers keyed by JSON-RPC method. The test asserts that a -// user-configured HeaderForward header reaches the backend on POST-INITIALIZE -// traffic — see issue #5289. The startup capability-discovery path was fixed -// in PR #5239; per-session HTTP traffic is still missing the wrap. -type headerCapturingBackend struct { - t *testing.T - - mu sync.Mutex - headersByMethod map[string]http.Header -} - -func newHeaderCapturingBackend(t *testing.T) (*headerCapturingBackend, string) { - t.Helper() - fb := &headerCapturingBackend{ - t: t, - headersByMethod: make(map[string]http.Header), - } - mux := http.NewServeMux() - mux.HandleFunc("/mcp", fb.handle) - ts := httptest.NewServer(mux) - t.Cleanup(ts.Close) - return fb, ts.URL + "/mcp" -} - -func (f *headerCapturingBackend) headersFor(method string) http.Header { - f.mu.Lock() - defer f.mu.Unlock() - return f.headersByMethod[method] -} - -func (f *headerCapturingBackend) handle(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - // Streamable-HTTP transports may open a GET for server-pushed - // notifications; rejecting it cleanly is fine for this test. - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - f.t.Errorf("headerCapturingBackend: read body: %v", err) - w.WriteHeader(http.StatusBadRequest) - return - } - defer func() { _ = r.Body.Close() }() - - var msg struct { - JSONRPC string `json:"jsonrpc"` - ID json.RawMessage `json:"id"` - Method string `json:"method"` - } - if err := json.Unmarshal(body, &msg); err != nil { - f.t.Errorf("headerCapturingBackend: decode: %v body=%s", err, string(body)) - w.WriteHeader(http.StatusBadRequest) - return - } - - f.mu.Lock() - f.headersByMethod[msg.Method] = r.Header.Clone() - f.mu.Unlock() - - // Notifications (no id, e.g. notifications/initialized) get an empty 202. - if len(msg.ID) == 0 || string(msg.ID) == "null" { - w.WriteHeader(http.StatusAccepted) - return - } +// TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests asserts the +// invariants of the HeaderForward transport stage on the session-side +// connector (pkg/vmcp/session/internal/backend). Fixes #5289. +// +// The three subtests cover, in order: +// - plaintext: a user-configured AddPlaintextHeaders entry reaches the +// backend on post-initialize traffic (tools/call). +// - overlap precedence: when an inner auth stage writes the same header +// name that HeaderForward configures, the inner stage's value lands on +// the wire — proving the chain ordering documented on +// headerForwardRoundTripper.RoundTrip. +// - restricted-name rejection: NewHTTPConnector fails fast when +// HeaderForward names a restricted header, proving the +// resolveHeaderForward guard is wired into the session-side connector. +// +// Secret resolution requires t.Setenv and lives in its own non-parallel test +// below (TestHTTPSession_HeaderForward_ResolvesSecretsFromEnv). +func TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests(t *testing.T) { + t.Parallel() - switch msg.Method { - case string(mcp.MethodInitialize): - w.Header().Set("Mcp-Session-Id", "header-forward-test-session") - f.writeResult(w, msg.ID, map[string]any{ - "protocolVersion": mcp.LATEST_PROTOCOL_VERSION, - "capabilities": map[string]any{ - "tools": map[string]any{}, + t.Run("plaintext header reaches backend on tools/call", func(t *testing.T) { + t.Parallel() + + const ( + headerName = "X-MCP-Toolsets" + headerValue = "projects,issues,pull_requests,users,repos" + ) + + fb := &fakeBackend{advertiseTools: true, tools: []mcp.Tool{{Name: "echo"}}} + url := newFakeBackend(t, fb) + + target := &vmcp.BackendTarget{ + WorkloadID: "header-forward-backend", + WorkloadName: "header-forward-backend", + BaseURL: url, + TransportType: "streamable-http", + HeaderForward: &vmcp.HeaderForwardConfig{ + AddPlaintextHeaders: map[string]string{ + headerName: headerValue, + }, }, - "serverInfo": map[string]any{"name": "header-forward-fake", "version": "0.0.0"}, - }) - case string(mcp.MethodToolsList): - f.writeResult(w, msg.ID, map[string]any{ - "tools": []mcp.Tool{{Name: "echo", Description: "echo tool"}}, - }) - case string(mcp.MethodToolsCall): - f.writeResult(w, msg.ID, map[string]any{ - "content": []map[string]any{ - {"type": "text", "text": "ok"}, + } + + sess := connectAndCallEcho(t, target) + t.Cleanup(func() { _ = sess.Close() }) + + got := fb.headersFor(string(mcp.MethodToolsCall)) + require.NotNil(t, got, "backend never received a tools/call request") + assert.Equal(t, headerValue, got.Get(headerName), + "HeaderForward.AddPlaintextHeaders must reach the backend on post-initialize requests") + }) + + t.Run("inner auth stage wins on overlapping header name", func(t *testing.T) { + t.Parallel() + + // The chain runs outer→inner on the outbound request: + // headerForwardRoundTripper → identityRoundTripper → authRoundTripper → http.DefaultTransport + // headerForwardRoundTripper sets X-Test-Identity first (request didn't + // have it). The inner authRoundTripper's Authenticate() then runs and + // calls Set() unconditionally, overwriting. We assert the auth value + // is the one on the wire. + const ( + headerName = "X-Test-Identity" + headerForwardVal = "from-header-forward" + authStrategyValue = "from-auth" + ) + + fb := &fakeBackend{advertiseTools: true, tools: []mcp.Tool{{Name: "echo"}}} + url := newFakeBackend(t, fb) + + registry := vmcpauth.NewDefaultOutgoingAuthRegistry() + require.NoError(t, registry.RegisterStrategy( + "test-header-setter", + &testHeaderSettingStrategy{name: "test-header-setter", header: headerName, value: authStrategyValue}, + )) + + target := &vmcp.BackendTarget{ + WorkloadID: "overlap-backend", + WorkloadName: "overlap-backend", + BaseURL: url, + TransportType: "streamable-http", + AuthConfig: &authtypes.BackendAuthStrategy{Type: "test-header-setter"}, + HeaderForward: &vmcp.HeaderForwardConfig{ + AddPlaintextHeaders: map[string]string{ + headerName: headerForwardVal, + }, }, - "isError": false, - }) - default: - f.writeResult(w, msg.ID, map[string]any{}) - } -} - -func (f *headerCapturingBackend) writeResult(w http.ResponseWriter, id json.RawMessage, result any) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "jsonrpc": "2.0", - "id": json.RawMessage(id), - "result": result, - }); err != nil { - f.t.Errorf("headerCapturingBackend: encode result: %v", err) - } + } + + connector := NewHTTPConnector(registry) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + sess, _, err := connector(ctx, target, nil, "") + require.NoError(t, err) + t.Cleanup(func() { _ = sess.Close() }) + + _, err = sess.CallTool(ctx, "echo", map[string]any{}, nil) + require.NoError(t, err) + + got := fb.headersFor(string(mcp.MethodToolsCall)) + require.NotNil(t, got, "backend never received a tools/call request") + assert.Equal(t, authStrategyValue, got.Get(headerName), + "inner auth stage must overwrite HeaderForward on overlapping header names") + }) + + t.Run("restricted header in HeaderForward fails connector at startup", func(t *testing.T) { + t.Parallel() + + // resolveHeaderForward rejects names in middleware.RestrictedHeaders. + // Asserts the guard is reachable from the session-side connector — a + // misconfigured backend surfaces at session creation, not silently as + // a missing header on every request. + fb := &fakeBackend{advertiseTools: true, tools: []mcp.Tool{{Name: "echo"}}} + url := newFakeBackend(t, fb) + + target := &vmcp.BackendTarget{ + WorkloadID: "restricted-header-backend", + WorkloadName: "restricted-header-backend", + BaseURL: url, + TransportType: "streamable-http", + HeaderForward: &vmcp.HeaderForwardConfig{ + AddPlaintextHeaders: map[string]string{ + // X-Forwarded-For is in middleware.RestrictedHeaders — letting + // user config inject it would enable identity-spoofing of the + // caller's IP to the backend. + "X-Forwarded-For": "1.2.3.4", + }, + }, + } + + registry := newTestRegistry(t) + connector := NewHTTPConnector(registry) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, _, err := connector(ctx, target, nil, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "restricted", + "connector must reject restricted header names in HeaderForward config") + }) } -// TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests is the red-phase -// regression test for issue #5289. PR #5239 fixed HeaderForward for the vMCP -// backend client (used for startup capability discovery) but did not extend the -// fix to the session-side connector at pkg/vmcp/session/internal/backend. -// As a result, user-configured headers (e.g. X-MCP-Toolsets for GitHub MCP) -// never reach the backend on per-session requests like tools/call. +// TestHTTPSession_HeaderForward_ResolvesSecretsFromEnv asserts that an +// AddHeadersFromSecret entry is resolved via the EnvironmentProvider that +// NewHTTPConnector constructs internally and reaches the backend on +// post-initialize traffic. // -// The test asserts that, after the connector completes Initialize, a -// subsequent CallTool carries the configured plaintext header on the wire. -// On main today it fails because the connector's transport chain does not -// include a header-forward round-tripper — see createMCPClient in -// mcp_session.go (the chain is http.DefaultTransport → authRoundTripper → -// identityRoundTripper, with no HeaderForward stage). -func TestHTTPSession_AppliesHeaderForwardToPostInitializeRequests(t *testing.T) { - t.Parallel() - +// Lives in its own non-parallel top-level test because t.Setenv (the only +// way to inject a value into the connector's env-backed secrets.Provider +// without changing production APIs) cannot be used inside a parallel test +// tree. +func TestHTTPSession_HeaderForward_ResolvesSecretsFromEnv(t *testing.T) { const ( - headerName = "X-MCP-Toolsets" - headerValue = "projects,issues,pull_requests,users,repos" + headerName = "X-GitHub-Auth" + secretID = "GITHUB_PAT" + secretEnvName = secrets.EnvVarPrefix + secretID + secretValue = "ghp_test_value_12345" ) + t.Setenv(secretEnvName, secretValue) - fb, url := newHeaderCapturingBackend(t) + fb := &fakeBackend{advertiseTools: true, tools: []mcp.Tool{{Name: "echo"}}} + url := newFakeBackend(t, fb) target := &vmcp.BackendTarget{ - WorkloadID: "header-forward-backend", - WorkloadName: "header-forward-backend", + WorkloadID: "secret-header-backend", + WorkloadName: "secret-header-backend", BaseURL: url, TransportType: "streamable-http", HeaderForward: &vmcp.HeaderForwardConfig{ - AddPlaintextHeaders: map[string]string{ - headerName: headerValue, + AddHeadersFromSecret: map[string]string{ + headerName: secretID, }, }, } + sess := connectAndCallEcho(t, target) + t.Cleanup(func() { _ = sess.Close() }) + + got := fb.headersFor(string(mcp.MethodToolsCall)) + require.NotNil(t, got, "backend never received a tools/call request") + assert.Equal(t, secretValue, got.Get(headerName), + "HeaderForward.AddHeadersFromSecret must be resolved via the env provider and reach the backend") +} + +// connectAndCallEcho builds an HTTP session for target via the default test +// registry, makes a single tools/call("echo"), and returns the open session. +// The caller is responsible for closing the session (typically via t.Cleanup). +func connectAndCallEcho(t *testing.T, target *vmcp.BackendTarget) Session { + t.Helper() + registry := newTestRegistry(t) connector := NewHTTPConnector(registry) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + t.Cleanup(cancel) sess, caps, err := connector(ctx, target, nil, "") require.NoError(t, err, "connector must initialise the backend successfully") require.NotNil(t, sess, "connector returned nil session") require.NotNil(t, caps, "connector returned nil capability list") - t.Cleanup(func() { _ = sess.Close() }) - // Make a single MCP call AFTER initialize completes. tools/call exercises - // the same transport chain as initialize but is unambiguously a - // post-handshake request — which is exactly where the regression lives. _, err = sess.CallTool(ctx, "echo", map[string]any{}, nil) require.NoError(t, err, "post-initialize CallTool must succeed") - // The recorded inbound headers for the tools/call request must include the - // user-configured forward header. This is the single assertion target: - // the test fails for exactly one reason — header missing on the recorded - // post-initialize request. - gotHeaders := fb.headersFor(string(mcp.MethodToolsCall)) - require.NotNil(t, gotHeaders, "backend never received a tools/call request") - assert.Equal(t, headerValue, gotHeaders.Get(headerName), - "HeaderForward.AddPlaintextHeaders must reach the backend on post-initialize requests") + return sess +} + +// testHeaderSettingStrategy is a vmcpauth.Strategy stand-in that unconditionally +// writes a single header onto every outbound request. Used to drive the +// overlap-precedence assertion: when HeaderForward configures the same name, +// this strategy (called from the inner authRoundTripper) must win on the wire. +type testHeaderSettingStrategy struct { + name string + header string + value string } + +func (s *testHeaderSettingStrategy) Name() string { return s.name } + +func (s *testHeaderSettingStrategy) Authenticate(_ context.Context, req *http.Request, _ *authtypes.BackendAuthStrategy) error { + req.Header.Set(s.header, s.value) + return nil +} + +func (*testHeaderSettingStrategy) Validate(_ *authtypes.BackendAuthStrategy) error { return nil }