diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index c67b79e4bd..59baf1fd86 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -33,6 +33,7 @@ import ( vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" "github.com/stacklok/toolhive/pkg/vmcp/conversion" + "github.com/stacklok/toolhive/pkg/vmcp/headerforward" healthcontext "github.com/stacklok/toolhive/pkg/vmcp/health/context" ) @@ -306,7 +307,9 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm // Inject per-backend HTTP headers from MCPServerEntry.spec.headerForward. // Resolves plaintext + secret-backed headers once here; auth (inner) always // wins over user-supplied headers because it runs after this tripper. - baseTransport, err = buildHeaderForwardTripper(ctx, baseTransport, target.HeaderForward, h.secretsProvider, target.WorkloadID) + baseTransport, err = headerforward.BuildHeaderForwardTripper( + ctx, baseTransport, target.HeaderForward, h.secretsProvider, target.WorkloadID, + ) if err != nil { return nil, fmt.Errorf("failed to build header-forward transport: %w", err) } diff --git a/pkg/vmcp/client/header_forward.go b/pkg/vmcp/headerforward/transport.go similarity index 92% rename from pkg/vmcp/client/header_forward.go rename to pkg/vmcp/headerforward/transport.go index 1708742dd8..8bac2cf500 100644 --- a/pkg/vmcp/client/header_forward.go +++ b/pkg/vmcp/headerforward/transport.go @@ -1,7 +1,12 @@ // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 -package client +// Package headerforward provides the HTTP round-tripper that injects +// per-backend forwarded headers (plaintext + secret-resolved) onto outbound +// requests. It is consumed by both the capability-discovery HTTP client +// (pkg/vmcp/client) and the per-session HTTP client +// (pkg/vmcp/session/internal/backend). +package headerforward import ( "context" @@ -55,7 +60,7 @@ func (h *headerForwardRoundTripper) RoundTrip(req *http.Request) (*http.Response return h.base.RoundTrip(reqCopy) } -// buildHeaderForwardTripper constructs a headerForwardRoundTripper for the +// BuildHeaderForwardTripper constructs a headerForwardRoundTripper for the // backend's pre-resolved HeaderForwardConfig. Returns base unchanged when no // header injection is configured or the effective header set is empty. // @@ -66,7 +71,7 @@ func (h *headerForwardRoundTripper) RoundTrip(req *http.Request) (*http.Response // Restricted header names (matching pkg/transport/middleware.RestrictedHeaders) // are rejected to prevent Host, Content-Length, Authorization, hop-by-hop, and // X-Forwarded-* spoofing via user-supplied config. -func buildHeaderForwardTripper( +func BuildHeaderForwardTripper( ctx context.Context, base http.RoundTripper, cfg *vmcp.HeaderForwardConfig, diff --git a/pkg/vmcp/client/header_forward_test.go b/pkg/vmcp/headerforward/transport_test.go similarity index 98% rename from pkg/vmcp/client/header_forward_test.go rename to pkg/vmcp/headerforward/transport_test.go index eddbfd21c9..d1d19042d8 100644 --- a/pkg/vmcp/client/header_forward_test.go +++ b/pkg/vmcp/headerforward/transport_test.go @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 -package client +package headerforward import ( "context" @@ -237,7 +237,7 @@ func TestResolveHeaderForward_NilCfgReturnsNil(t *testing.T) { func TestBuildHeaderForwardTripper_NilCfgReturnsBase(t *testing.T) { t.Parallel() base := &captureTripper{} - got, err := buildHeaderForwardTripper(t.Context(), base, nil, nil, "x") + got, err := BuildHeaderForwardTripper(t.Context(), base, nil, nil, "x") require.NoError(t, err) assert.Same(t, base, got, "nil cfg must pass base through untouched") }