Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions pkg/vmcp/headerforward/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ type headerForwardRoundTripper struct {
}

// RoundTrip injects the pre-resolved headers onto a clone of the request and
// delegates to the wrapped transport. Headers already present on the request
// are left untouched so inner transports (auth, identity, trace) can always
// override user-supplied values for the same name.
// delegates to the wrapped transport. Names already present on the inbound
// request are skipped — this preserves any header the *caller* set on the
// request before invoking the round-tripper chain. Inner (closer to the wire)
// transports run AFTER this one and use Set() unconditionally, so on any
// overlapping name those stages still win on the wire. Restricted names are
// blocked at resolve time, so user-supplied config cannot reach this point
// for Host, hop-by-hop, or X-Forwarded-* anyway.
func (h *headerForwardRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(h.headers) == 0 {
return h.base.RoundTrip(req)
Expand All @@ -64,6 +68,11 @@ func (h *headerForwardRoundTripper) RoundTrip(req *http.Request) (*http.Response
// backend's pre-resolved HeaderForwardConfig. Returns base unchanged when no
// header injection is configured or the effective header set is empty.
//
// Used by both the vMCP backend client (startup capability discovery) and the
// per-session backend connector (long-lived MCP traffic). Exported so the
// session backend in pkg/vmcp/session/internal/backend can share the same
// transport-chain wiring.
//
// Fails loudly (constructor validation, per go-style.md) when a secret identifier
// cannot be resolved through the provider, so a misconfigured backend surfaces
// at pod startup — not as a silent missing-header on every request.
Expand Down
30 changes: 28 additions & 2 deletions pkg/vmcp/session/internal/backend/mcp_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ import (
"github.com/mark3labs/mcp-go/mcp"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/secrets"
"github.com/stacklok/toolhive/pkg/versions"
"github.com/stacklok/toolhive/pkg/vmcp"
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
"github.com/stacklok/toolhive/pkg/vmcp/conversion"
"github.com/stacklok/toolhive/pkg/vmcp/headerforward"
)

const (
Expand Down Expand Up @@ -196,19 +198,25 @@ func (c *mcpSession) GetPrompt(
//
// registry provides the authentication strategy for outgoing backend requests.
// Pass a registry configured with the "unauthenticated" strategy to disable auth.
//
// A single secrets.EnvironmentProvider is constructed once per connector and
// shared across every session it creates; its lifetime matches the connector's.
// It is consumed by BuildHeaderForwardTripper to resolve secret-backed entries
// in target.HeaderForward.
func NewHTTPConnector(registry vmcpauth.OutgoingAuthRegistry) func(
ctx context.Context,
target *vmcp.BackendTarget,
identity *auth.Identity,
sessionHint string,
) (Session, *vmcp.CapabilityList, error) {
provider := secrets.NewEnvironmentProvider()
return func(
ctx context.Context,
target *vmcp.BackendTarget,
identity *auth.Identity,
sessionHint string,
) (Session, *vmcp.CapabilityList, error) {
c, err := createMCPClient(target, identity, registry, sessionHint)
c, err := createMCPClient(ctx, target, identity, registry, sessionHint, provider)
if err != nil {
return nil, nil, fmt.Errorf("failed to create MCP client for backend %s: %w", target.WorkloadID, err)
}
Expand Down Expand Up @@ -238,11 +246,17 @@ func NewHTTPConnector(registry vmcpauth.OutgoingAuthRegistry) func(
// to client.Close(), not to any caller-supplied init context.
// sessionHint, when non-empty, is passed as the initial Mcp-Session-Id for
// streamable-HTTP transports so the backend can resume an existing session.
//
// ctx is used only to resolve secret-backed entries in target.HeaderForward at
// client-creation time; the transport itself is started with context.Background()
// as described above. provider supplies values for those secret-backed headers.
func createMCPClient(
ctx context.Context,
target *vmcp.BackendTarget,
identity *auth.Identity,
registry vmcpauth.OutgoingAuthRegistry,
sessionHint string,
provider secrets.Provider,
) (*mcpclient.Client, error) {
// Resolve and validate the auth strategy once at client creation time.
strategyName := authtypes.StrategyTypeUnauthenticated
Expand All @@ -259,7 +273,15 @@ func createMCPClient(

slog.Debug("Applied authentication strategy", "strategy", strategy.Name(), "backendID", target.WorkloadID)

// Build shared transport chain: auth → identity propagation.
// Build shared transport chain (innermost first → outermost):
// http.DefaultTransport → authRoundTripper → identityRoundTripper → headerForwardRoundTripper
// On an outbound request, the outermost stage runs first: header-forward
// injects its headers onto a request that does not yet carry auth/identity
// headers, then inner stages run and call Set() unconditionally so any
// overlapping name they care about (Authorization, identity headers) wins on
// the wire. Restricted header names (Host, hop-by-hop, X-Forwarded-*) are
// rejected at resolve time by resolveHeaderForward, so user-supplied
// HeaderForward cannot inject them in the first place.
// The per-transport sections below may add a size-limiting wrapper on top.
base := http.RoundTripper(http.DefaultTransport)
base = &authRoundTripper{
Expand All @@ -269,6 +291,10 @@ func createMCPClient(
target: target,
}
base = &identityRoundTripper{base: base, identity: identity}
base, err = headerforward.BuildHeaderForwardTripper(ctx, base, target.HeaderForward, provider, target.WorkloadID)
if err != nil {
return nil, fmt.Errorf("failed to build header-forward transport for backend %s: %w", target.WorkloadID, err)
}

var c *mcpclient.Client
switch target.TransportType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ type fakeBackend struct {
// resources/list failure.
mu sync.Mutex
methodCalls map[string]int

// headersByMethod records the inbound request headers keyed by JSON-RPC
// method name. Tests asserting transport-chain behavior (e.g. HeaderForward)
// use headersFor(method) to inspect the headers a backend actually saw.
headersByMethod map[string]http.Header
}

type jsonRPCError struct {
Expand All @@ -68,6 +73,7 @@ func newFakeBackend(t *testing.T, fb *fakeBackend) string {
t.Helper()
fb.t = t
fb.methodCalls = make(map[string]int)
fb.headersByMethod = make(map[string]http.Header)

mux := http.NewServeMux()
mux.HandleFunc("/mcp", fb.handle)
Expand All @@ -83,6 +89,19 @@ func (f *fakeBackend) callCount(method string) int {
return f.methodCalls[method]
}

// headersFor returns a clone of the inbound HTTP headers recorded for the most
// recent JSON-RPC request with the given method, or nil if no such request was
// seen. Cloning under the mutex keeps the caller safe from concurrent writes.
func (f *fakeBackend) headersFor(method string) http.Header {
f.mu.Lock()
defer f.mu.Unlock()
h := f.headersByMethod[method]
if h == nil {
return nil
}
return h.Clone()
}

// handle implements the JSON-RPC subset needed for backend init. The
// streamable-HTTP transport sends POST requests with Accept:
// application/json, text/event-stream — we always reply with
Expand Down Expand Up @@ -117,6 +136,7 @@ func (f *fakeBackend) handle(w http.ResponseWriter, r *http.Request) {

f.mu.Lock()
f.methodCalls[msg.Method]++
f.headersByMethod[msg.Method] = r.Header.Clone()
f.mu.Unlock()

// Notifications (no id, e.g. notifications/initialized) get an empty 202.
Expand Down Expand Up @@ -148,6 +168,16 @@ func (f *fakeBackend) handle(w http.ResponseWriter, r *http.Request) {
return
}
f.writeResult(w, msg.ID, map[string]any{"prompts": f.prompts})
case string(mcp.MethodToolsCall):
// Minimal CallToolResult with a single text content. Tests that exercise
// the post-initialize transport chain (e.g. HeaderForward) need a method
// they can invoke after Initialize completes; tools/call is the cheapest.
f.writeResult(w, msg.ID, map[string]any{
"content": []map[string]any{
{"type": "text", "text": "ok"},
},
"isError": false,
})
default:
f.writeError(w, msg.ID, &jsonRPCError{code: mcp.METHOD_NOT_FOUND, message: "Method not found"})
}
Expand Down
Loading
Loading