diff --git a/README.md b/README.md index 7f24efe45..c162e221d 100644 --- a/README.md +++ b/README.md @@ -241,7 +241,7 @@ For comprehensive TOML configuration documentation, including: - Denied resources for restricting access to sensitive resource types - Server instructions for MCP Tool Search - [Custom MCP prompts](docs/prompts.md) -- [OAuth/OIDC authentication](docs/KEYCLOAK_OIDC_SETUP.md) for HTTP mode +- OAuth/OIDC authentication for HTTP mode ([Keycloak](docs/KEYCLOAK_OIDC_SETUP.md), [Microsoft Entra ID](docs/ENTRA_ID_SETUP.md)) See the **[Configuration Reference](docs/configuration.md)**. diff --git a/docs/ENTRA_ID_SETUP.md b/docs/ENTRA_ID_SETUP.md new file mode 100644 index 000000000..5ad008d3b --- /dev/null +++ b/docs/ENTRA_ID_SETUP.md @@ -0,0 +1,503 @@ +# Microsoft Entra ID Setup for Kubernetes MCP Server + +This guide shows you how to configure the Kubernetes MCP Server to use Microsoft Entra ID (formerly Azure AD) as the OIDC provider. + +## Overview + +Entra ID differs from Keycloak in that it only exposes the standard OpenID Connect discovery endpoint (`/.well-known/openid-configuration`) and does not implement the OAuth Authorization Server Metadata endpoints (`/.well-known/oauth-authorization-server`). + +The MCP server automatically handles this by falling back to `openid-configuration` when the OAuth-specific endpoints return 404. + +## Prerequisites + +- Microsoft Entra ID admin access (Azure Portal) +- Kubernetes cluster configured with Entra ID as the OIDC provider +- `kubectl` CLI with cluster access + +## Step 1: Register an App in Entra ID + +### Create the App Registration + +1. Go to **Azure Portal** → **Microsoft Entra ID** → **App registrations** +2. Click **New registration** +3. Fill in: + - **Name:** `MCP Server` (or any name) + - **Supported account types:** "Accounts in this organizational directory only" + - **Redirect URI:** Leave blank for now +4. Click **Register** + +### Note Your IDs + +From the app's **Overview** page, copy: +- **Application (client) ID** → `CLIENT_ID` +- **Directory (tenant) ID** → `TENANT_ID` + +### Configure Client Credentials + +You need **one** of the following — a client secret or a certificate. If you only need MCP server authentication (no other systems sharing this app registration), certificate-based auth is recommended. + +#### Option A: Client Secret + +Use this if you prefer simplicity or if other systems (e.g., cluster console OIDC login) share this app registration and require a client secret. + +1. Go to **Certificates & secrets** (left sidebar) +2. Click **New client secret** +3. Add description and expiration +4. Click **Add** +5. **Copy the Value immediately** (only shown once) → `CLIENT_SECRET` + +#### Option B: Certificate (Recommended for MCP Server) + +Use this for production deployments. No secret to manage — the MCP server authenticates using a signed JWT assertion. + +1. Generate a certificate (or use your PKI): + ```bash + openssl req -x509 -newkey rsa:2048 -keyout client.key -out client.crt -days 365 -nodes -subj "/CN=MCP Server" + ``` +2. Go to **Certificates & secrets** (left sidebar) +3. Click the **Certificates** tab +4. Click **Upload certificate** and select `client.crt` +5. Note the **Thumbprint** shown after upload — you can use this to verify your config later + +> **Tip:** If your cluster already uses a separate app registration with a client secret for console OIDC login, you can create a dedicated app registration for the MCP server using certificate auth only (no secret needed). See [Separate App Registration for MCP Server](#separate-app-registration-for-mcp-server) below. + +### Configure API Permissions + +1. Go to **API permissions** (left sidebar) +2. Click **Add a permission** → **Microsoft Graph** → **Delegated permissions** +3. Add these permissions: + - `openid` + - `profile` + - `email` +4. Click **Add permissions** +5. Click **Grant admin consent for [your org]** + +### Configure Token Claims + +1. Go to **Token configuration** (left sidebar) +2. Click **Add optional claim** +3. Select **ID** token type +4. Check these claims: + - `email` + - `preferred_username` +5. Click **Add** + +### Add Redirect URI (Optional - for Testing) + +If you plan to test with MCP Inspector: + +1. Go to **Authentication** (left sidebar) +2. Under **Platform configurations**, click **Add a platform** → **Web** +3. Add redirect URI: `http://localhost:6274/oauth/callback` +4. Click **Configure** + +## Step 2: Configure MCP Server + +Create a configuration file (`config.toml`): + +### Basic Configuration + +Use this configuration when your Kubernetes cluster accepts Entra ID tokens directly (cluster OIDC is configured with the same Entra ID tenant): + +```toml +require_oauth = true +oauth_audience = "" +oauth_scopes = ["openid", "profile", "email"] + +# Entra ID uses v2.0 endpoints +authorization_url = "https://login.microsoftonline.com//v2.0" +``` + +Replace: +- `` with your Application (client) ID +- `` with your Directory (tenant) ID + +> **Note:** When `cluster_auth_mode` is not set, the server auto-detects: +> - If `require_oauth = true` → uses `passthrough` +> - Otherwise → uses `kubeconfig` +> +> In `passthrough` mode, if token exchange is configured (`token_exchange_strategy` or `sts_audience`), the token is exchanged before being passed to the cluster. + +### With ServiceAccount Credentials + +If your Kubernetes cluster doesn't accept Entra ID tokens on the API server, use this configuration: + +```toml +require_oauth = true +oauth_audience = "" +oauth_scopes = ["openid", "profile", "email"] + +authorization_url = "https://login.microsoftonline.com//v2.0" + +# Use kubeconfig ServiceAccount credentials for cluster access +cluster_auth_mode = "kubeconfig" +kubeconfig = "/path/to/sa-kubeconfig" +``` + +This setup: +- **MCP clients authenticate via Entra ID** (OAuth required for MCP access) +- **Cluster access uses ServiceAccount token** (from kubeconfig) + +#### Creating a ServiceAccount Kubeconfig + +Your regular kubeconfig likely uses interactive login. Create a kubeconfig with a static ServiceAccount token: + +```bash +# Create ServiceAccount +kubectl create sa mcp-server -n default + +# Grant permissions (adjust role as needed) +kubectl create clusterrolebinding mcp-server-reader \ + --clusterrole=view \ + --serviceaccount=default:mcp-server + +# Create a token (adjust duration to your security requirements) +kubectl create token mcp-server -n default --duration=720h > sa-token + +# Create kubeconfig with the token +export SA_TOKEN=$(cat sa-token) +export CLUSTER_URL=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.server}') +export CLUSTER_CA=$(kubectl config view --raw --minify -o jsonpath='{.clusters[0].cluster.certificate-authority-data}') + +cat > mcp-kubeconfig << EOF +apiVersion: v1 +kind: Config +clusters: +- cluster: + certificate-authority-data: ${CLUSTER_CA} + server: ${CLUSTER_URL} + name: cluster +contexts: +- context: + cluster: cluster + user: mcp-server + name: mcp-context +current-context: mcp-context +users: +- name: mcp-server + user: + token: ${SA_TOKEN} +EOF +``` + +Then run: +```bash +./kubernetes-mcp-server --config config.toml +``` + +### With Token Exchange (On-Behalf-Of Flow) + +If your cluster accepts Entra ID tokens and you want to exchange the user's token via the On-Behalf-Of (OBO) flow, use one of the credential options below. + +#### With Client Secret (Option A) + +```toml +require_oauth = true +oauth_audience = "" +oauth_scopes = ["openid", "profile", "email"] + +authorization_url = "https://login.microsoftonline.com//v2.0" + +# Token exchange configuration (passthrough will use this automatically) +token_exchange_strategy = "entra-obo" +sts_client_id = "" +sts_client_secret = "" +sts_scopes = ["api:///.default"] +``` + +#### With Certificate (Option B — Recommended) + +```toml +require_oauth = true +oauth_audience = "" +oauth_scopes = ["openid", "profile", "email"] + +authorization_url = "https://login.microsoftonline.com//v2.0" + +# Token exchange with certificate authentication (RFC 7523 JWT Client Assertion) +token_exchange_strategy = "entra-obo" +sts_client_id = "" +sts_auth_style = "assertion" +sts_client_cert_file = "/path/to/client.crt" +sts_client_key_file = "/path/to/client.key" +sts_scopes = ["api:///.default"] +``` + +No client secret is needed when using certificate auth. The MCP server signs a short-lived JWT assertion (5 minutes) using the private key, and Entra ID validates it against the uploaded certificate. + +#### OBO Prerequisites + +For OBO to work, you need to configure API permissions in Azure: +1. Go to your app registration → **API permissions** +2. Click **Add a permission** → **APIs my organization uses** +3. Select the downstream API app registration +4. Add the required delegated permissions + +## Step 3: Run the MCP Server + +```bash +./kubernetes-mcp-server --config config.toml +``` + +## Testing with MCP Inspector (Optional) + +To test authentication with MCP Inspector: + +1. Ensure redirect URI is configured (see Step 1) +2. Start MCP Inspector: + ```bash + npx @modelcontextprotocol/inspector@latest $(pwd)/kubernetes-mcp-server --config config.toml + ``` +3. In **Authentication** section: + - Set **Client ID** to your `` + - Set **Scope** to `openid profile email` +4. Click **Connect** +5. Login with your Entra ID credentials + +## How It Works + +### Client Registration + +Entra ID doesn't support RFC 7591 Dynamic Client Registration - clients must be pre-registered in the Azure portal (as shown in Step 1 above). + +Add redirect URIs in the Azure portal → Authentication for your MCP clients: +- `http://localhost:6274/oauth/callback` (MCP Inspector default) + +### Well-Known Endpoint Fallback + +The MCP server implements automatic fallback for OIDC providers that don't support all OAuth 2.0 well-known endpoints: + +1. When a client requests `/.well-known/oauth-authorization-server`, the server first tries to proxy the request to Entra ID +2. Entra ID returns 404 (this endpoint doesn't exist) +3. The server automatically falls back to fetching `/.well-known/openid-configuration` +4. The openid-configuration response is returned, which contains all required OAuth metadata + +This allows MCP clients to work with Entra ID without any special configuration. + +## Troubleshooting + +### "invalid_client" Error + +Check that: +- You're using the correct client ID +- The redirect URI matches exactly what's configured in Entra ID +- The client secret is correct (if using client secret auth) + +### "AADSTS700027" Certificate Not Registered + +This means the certificate used to sign the JWT assertion doesn't match any certificate uploaded to your app registration. + +1. Check your certificate's thumbprint: + ```bash + openssl x509 -in /path/to/client.crt -fingerprint -sha1 -noout + ``` +2. Go to Azure Portal → App registrations → your app → **Certificates & secrets** → **Certificates** +3. Compare the thumbprint. If it doesn't match, upload the correct certificate +4. Make sure `sts_client_cert_file` and `sts_client_key_file` point to the matching cert/key pair + +### "AADSTS50011" Redirect URI Mismatch + +The redirect URI in your request doesn't match Entra ID configuration: +1. Go to Azure Portal → App registrations → your app → Authentication +2. Add the exact redirect URI shown in the error message + +### Token Validation Fails + +Ensure your Kubernetes cluster is configured to trust Entra ID tokens: +- The OIDC issuer should be `https://login.microsoftonline.com/{tenant}/v2.0` +- The audience should match your client ID or application ID URI + +### Well-Known Endpoint Returns 404 + +This is expected for `oauth-authorization-server` and `oauth-protected-resource` endpoints. The MCP server automatically handles this by falling back to `openid-configuration`. + +## Differences from Keycloak + +| Feature | Keycloak | Entra ID | +|---------|----------|----------| +| oauth-authorization-server endpoint | ✅ Supported | ❌ Not available | +| oauth-protected-resource endpoint | ✅ Supported | ❌ Not available | +| openid-configuration endpoint | ✅ Supported | ✅ Supported | +| Token Exchange (RFC 8693) | ✅ Supported | ❌ Use On-Behalf-Of flow | +| Dynamic Client Registration | ✅ Supported | ❌ Not available | + +The MCP server handles these differences automatically through the well-known endpoint fallback mechanism. + +## Quick Reference + +| Item | Where to Find | +|------|---------------| +| Client ID | Azure Portal → App registrations → Overview → Application (client) ID | +| Tenant ID | Azure Portal → App registrations → Overview → Directory (tenant) ID | +| Client Secret | Azure Portal → App registrations → Certificates & secrets → Value column | +| Authorization URL | `https://login.microsoftonline.com//v2.0` | + +## Configuring Your Cluster to Accept Entra ID Tokens + +For the passthrough flow to work, your Kubernetes cluster's API server must be configured to accept Entra ID tokens via OIDC. This is separate from any console or dashboard login configuration your cluster may have. + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ │ +│ ┌──────────┐ ┌─────────────┐ ┌──────────┐ ┌─────────────────┐ │ +│ │ User │────▶│ MCP Client │────▶│MCP Server│────▶│ Kubernetes │ │ +│ │ │ │ (Inspector) │ │ │ │ Cluster │ │ +│ └──────────┘ └─────────────┘ └──────────┘ └─────────────────┘ │ +│ │ │ │ │ │ +│ │ 1. OAuth │ 2. Bearer │ 3. OBO + │ │ +│ │ Login │ Token │ Assertion │ │ +│ │ │ │ │ │ +│ ▼ ▼ ▼ ▼ │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ Microsoft Entra ID │ │ +│ │ │ │ +│ │ 1. User authenticates via OAuth 2.0 (authorization code flow) │ │ +│ │ 2. MCP Server validates user token │ │ +│ │ 3. MCP Server exchanges token using OBO + JWT client assertion │ │ +│ │ 4. Cluster validates exchanged token via OIDC │ │ +│ └──────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +### Prerequisites + +- Cluster admin access +- Ability to configure kube-apiserver OIDC flags (managed clusters may expose this differently) + +### Configure kube-apiserver OIDC Flags + +The API server needs the following OIDC flags. How you set these depends on your cluster type: + +- **kubeadm / self-managed**: edit `/etc/kubernetes/manifests/kube-apiserver.yaml` +- **Managed Kubernetes (EKS, AKS, GKE)**: use the provider's OIDC identity configuration +- **Kind / Minikube**: pass flags via cluster config + +The required flags: + +``` +--oidc-issuer-url=https://login.microsoftonline.com//v2.0 +--oidc-client-id= +--oidc-username-claim=preferred_username +--oidc-groups-claim=groups +``` + +### Create RBAC for Entra ID Users + +Once the API server accepts Entra ID tokens, create RBAC bindings for your users: + +```bash +# For a specific user +kubectl create clusterrolebinding entra-user-admin \ + --clusterrole=cluster-admin \ + --user="user@yourdomain.com" + +# For a group (requires groups claim configured in Entra ID) +kubectl create clusterrolebinding entra-group-admin \ + --clusterrole=cluster-admin \ + --group="your-group-object-id" +``` + +### Verify OIDC Is Working + +You can test that the API server accepts Entra ID tokens by using `kubectl` with a token: + +```bash +kubectl --token="" get namespaces +``` + +If this returns namespaces, your cluster is correctly configured. + +### Complete MCP Server Config for Passthrough + +```toml +# config.toml +log_level = 4 +port = "8080" + +# OAuth: Users authenticate via Entra ID +require_oauth = true +authorization_url = "https://login.microsoftonline.com//v2.0" +oauth_audience = "" + +# Pass exchanged token to cluster +cluster_auth_mode = "passthrough" + +# Token Exchange: OBO flow with JWT client assertion +token_exchange_strategy = "entra-obo" +sts_client_id = "" +sts_auth_style = "assertion" +sts_client_cert_file = "/path/to/client.crt" +sts_client_key_file = "/path/to/client.key" +sts_scopes = ["/.default"] +``` + +### Understanding the Two Trust Relationships + +1. **MCP Server → Entra ID (OBO Exchange)** + - MCP Server authenticates using JWT client assertion (certificate) + - No client_secret needed + - This is what `sts_auth_style = "assertion"` configures + +2. **Cluster → Entra ID (Token Validation)** + - The API server validates tokens from Entra ID + - Configured via kube-apiserver OIDC flags + - Uses OIDC discovery to fetch signing keys + +Both relationships use the same app registration but serve different purposes. + +## Separate App Registration for MCP Server + +If your cluster already uses an Entra ID app registration with a client secret for console OIDC login, you may want a **separate app registration** for the MCP server — especially if you prefer certificate-based auth and don't want to add a certificate to the existing app. + +### When to Use This + +- Your cluster's OIDC app registration is shared with other systems (console, CLI) and uses a client secret +- You want the MCP server to use certificate auth without affecting the existing setup +- You want to scope the MCP server's permissions separately + +### Setup + +**App Registration A** (existing) — used by the cluster for console/CLI OIDC login. Has a client secret. No changes needed. + +**App Registration B** (new) — used by the MCP server for OBO token exchange. Uses certificate auth, no secret required. + +1. Create a new app registration in Azure (follow [Step 1](#step-1-register-an-app-in-entra-id) above) +2. Choose **Option B: Certificate** for credentials — skip the client secret +3. On App Registration B, go to **Expose an API** (left sidebar): + - Click **Set** next to "Application ID URI" (accept the default `api://`) + - Click **Add a scope** → name it (e.g., `access_as_user`) → set "Who can consent" to "Admins and users" → enable it +4. On App Registration A (the existing one), go to **API permissions**: + - Click **Add a permission** → **APIs my organization uses** → find App Registration B + - Add the delegated scope you created (e.g., `access_as_user`) + - Click **Grant admin consent** + +### MCP Server Configuration + +```toml +require_oauth = true +# Use App A's client ID — this is what MCP clients authenticate with +oauth_audience = "" +oauth_scopes = ["openid", "profile", "email"] + +authorization_url = "https://login.microsoftonline.com//v2.0" + +# OBO exchange uses App B's credentials (certificate, no secret) +token_exchange_strategy = "entra-obo" +sts_client_id = "" +sts_auth_style = "assertion" +sts_client_cert_file = "/path/to/client.crt" +sts_client_key_file = "/path/to/client.key" +sts_scopes = ["/.default"] +``` + +This way, the cluster's existing OIDC configuration is untouched, and the MCP server has its own credentials with certificate-based auth. + +## See Also + +- [Entra ID OAuth 2.0 Documentation](https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow) +- [Entra ID On-Behalf-Of Flow](https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow) +- [Kubernetes OIDC Authentication](https://kubernetes.io/docs/reference/access-authn-authz/authentication/#openid-connect-tokens) +- [Keycloak OIDC Setup](KEYCLOAK_OIDC_SETUP.md) - Alternative OIDC provider setup diff --git a/docs/configuration.md b/docs/configuration.md index 7fc8775d5..2d2c5b8a0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -468,10 +468,15 @@ Configure OAuth/OIDC authentication for HTTP mode deployments. | `sts_client_secret` | string | `""` | OAuth client secret for backend token exchange. | | `sts_audience` | string | `""` | Audience for STS token exchange. | | `sts_scopes` | string[] | `[]` | Scopes for STS token exchange. | +| `token_exchange_strategy` | string | `""` | Token exchange strategy: `rfc8693`, `keycloak-v1`, or `entra-obo`. | +| `sts_auth_style` | string | `"params"` | How client credentials are sent: `params` (body), `header` (Basic Auth), or `assertion` (JWT). | +| `sts_client_cert_file` | string | `""` | Path to client certificate PEM file (for `assertion` auth style). | +| `sts_client_key_file` | string | `""` | Path to client private key PEM file (for `assertion` auth style). | +| `cluster_auth_mode` | string | `""` | Cluster auth mode: `passthrough` (use OAuth token) or `kubeconfig` (use kubeconfig credentials). | | `certificate_authority` | string | `""` | Path to CA certificate for validating authorization server connections. | | `server_url` | string | `""` | Public URL of the MCP server (used for OAuth metadata). | -**Example:** +**Example (with client secret):** ```toml require_oauth = true authorization_url = "https://keycloak.example.com/realms/mcp" @@ -483,7 +488,21 @@ sts_client_secret = "your-client-secret" sts_audience = "kubernetes-api" ``` -For a complete OIDC setup guide, see [KEYCLOAK_OIDC_SETUP.md](KEYCLOAK_OIDC_SETUP.md). +**Example (with certificate-based auth for Entra ID):** +```toml +require_oauth = true +authorization_url = "https://login.microsoftonline.com//v2.0" +oauth_audience = "" + +token_exchange_strategy = "entra-obo" +sts_client_id = "" +sts_auth_style = "assertion" +sts_client_cert_file = "/path/to/client.crt" +sts_client_key_file = "/path/to/client.key" +sts_scopes = ["api:///.default"] +``` + +For a complete OIDC setup guide, see [KEYCLOAK_OIDC_SETUP.md](KEYCLOAK_OIDC_SETUP.md) or [ENTRA_ID_SETUP.md](ENTRA_ID_SETUP.md). ### Telemetry diff --git a/go.mod b/go.mod index 3f87caa72..6fe5692e4 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/google/gnostic-models v0.7.1 github.com/google/jsonschema-go v0.4.2 github.com/modelcontextprotocol/go-sdk v1.5.0 + github.com/google/uuid v1.6.0 github.com/prometheus/client_golang v1.23.2 github.com/spf13/afero v1.15.0 github.com/spf13/cobra v1.10.2 @@ -85,7 +86,6 @@ require ( github.com/google/btree v1.1.3 // indirect github.com/google/cel-go v0.27.0 // indirect github.com/google/go-cmp v0.7.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/gosuri/uitable v0.0.4 // indirect github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect diff --git a/pkg/api/config.go b/pkg/api/config.go index 7f9de416c..a86da7df5 100644 --- a/pkg/api/config.go +++ b/pkg/api/config.go @@ -7,11 +7,30 @@ const ( ClusterProviderKcp = "kcp" ) +// ClusterAuthMode constants define how the MCP server authenticates to the cluster. +const ( + // ClusterAuthPassthrough passes the OAuth token to the cluster. + // If token exchange is configured (token_exchange_strategy or sts_audience), + // the token is exchanged first before being passed through. + ClusterAuthPassthrough = "passthrough" + + // ClusterAuthKubeconfig uses kubeconfig credentials (e.g., ServiceAccount token). + // Use when cluster auth is separate from MCP client auth. + ClusterAuthKubeconfig = "kubeconfig" +) + type AuthProvider interface { // IsRequireOAuth indicates whether OAuth authentication is required. IsRequireOAuth() bool } +// ClusterAuthProvider provides configuration for how the MCP server authenticates to clusters. +type ClusterAuthProvider interface { + // GetClusterAuthMode returns the cluster authentication mode. + // Returns empty string for auto-detection based on other config. + GetClusterAuthMode() string +} + type ClusterProvider interface { // GetClusterProviderStrategy returns the cluster provider strategy (if configured). GetClusterProviderStrategy() string @@ -51,6 +70,10 @@ type StsConfigProvider interface { GetStsClientSecret() string GetStsAudience() string GetStsScopes() []string + GetStsStrategy() string + GetStsAuthStyle() string + GetStsClientCertFile() string + GetStsClientKeyFile() string } // ValidationEnabledProvider provides access to validation enabled setting. @@ -65,6 +88,7 @@ type RequireTLSProvider interface { type BaseConfig interface { AuthProvider + ClusterAuthProvider ClusterProvider ConfirmationRulesProvider DeniedResourcesProvider diff --git a/pkg/config/config.go b/pkg/config/config.go index ee81c078a..a9e701d33 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -73,9 +73,25 @@ type StaticConfig struct { // StsAudience is the audience for the STS token exchange. StsAudience string `toml:"sts_audience,omitempty"` // StsScopes is the scopes for the STS token exchange. - StsScopes []string `toml:"sts_scopes,omitempty"` - CertificateAuthority string `toml:"certificate_authority,omitempty"` - ServerURL string `toml:"server_url,omitempty"` + StsScopes []string `toml:"sts_scopes,omitempty"` + // TokenExchangeStrategy is the token exchange strategy to use (rfc8693, keycloak-v1, entra-obo). + // When set with passthrough mode, the token is exchanged before being passed to the cluster. + TokenExchangeStrategy string `toml:"token_exchange_strategy,omitempty"` + // StsAuthStyle specifies how client credentials are sent during token exchange. + // "params" (default): client_id/secret in request body + // "header": HTTP Basic Authentication header + // "assertion": JWT client assertion (RFC 7523, for Entra ID certificate auth) + StsAuthStyle string `toml:"sts_auth_style,omitempty"` + // StsClientCertFile is the path to the client certificate PEM file for JWT assertion auth + StsClientCertFile string `toml:"sts_client_cert_file,omitempty"` + // StsClientKeyFile is the path to the client private key PEM file for JWT assertion auth + StsClientKeyFile string `toml:"sts_client_key_file,omitempty"` + // ClusterAuthMode determines how the MCP server authenticates to the cluster. + // Valid values: "passthrough" (use OAuth token, with optional exchange), "kubeconfig" (use kubeconfig credentials). + // If empty, auto-detects: passthrough when require_oauth=true, otherwise kubeconfig. + ClusterAuthMode string `toml:"cluster_auth_mode,omitempty"` + CertificateAuthority string `toml:"certificate_authority,omitempty"` + ServerURL string `toml:"server_url,omitempty"` // TLS configuration for the HTTP server // TLSCert is the path to the TLS certificate file for HTTPS @@ -389,6 +405,22 @@ func (c *StaticConfig) GetStsScopes() []string { return c.StsScopes } +func (c *StaticConfig) GetStsStrategy() string { + return c.TokenExchangeStrategy +} + +func (c *StaticConfig) GetStsAuthStyle() string { + return c.StsAuthStyle +} + +func (c *StaticConfig) GetStsClientCertFile() string { + return c.StsClientCertFile +} + +func (c *StaticConfig) GetStsClientKeyFile() string { + return c.StsClientKeyFile +} + func (c *StaticConfig) IsValidationEnabled() bool { return c.ValidationEnabled } @@ -417,3 +449,15 @@ func (c *StaticConfig) ValidateRequireTLS() error { "sse_base_url": c.SSEBaseURL, }) } + +func (c *StaticConfig) GetClusterAuthMode() string { + return c.ClusterAuthMode +} + +// ValidateClusterAuthMode validates that cluster_auth_mode is a known value. +func (c *StaticConfig) ValidateClusterAuthMode() error { + if c.ClusterAuthMode != "" && c.ClusterAuthMode != api.ClusterAuthPassthrough && c.ClusterAuthMode != api.ClusterAuthKubeconfig { + return fmt.Errorf("invalid cluster_auth_mode %q: must be %q or %q", c.ClusterAuthMode, api.ClusterAuthPassthrough, api.ClusterAuthKubeconfig) + } + return nil +} diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index 14e1c8b9e..73a39547d 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -4,15 +4,17 @@ import ( "context" "fmt" "net/http" + "slices" "strings" "github.com/coreos/go-oidc/v3/oidc" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" "k8s.io/klog/v2" - "k8s.io/utils/strings/slices" "github.com/containers/kubernetes-mcp-server/pkg/config" + internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" + "github.com/containers/kubernetes-mcp-server/pkg/oauth" ) // write401 sends a 401/Unauthorized response with WWW-Authenticate header. @@ -48,11 +50,16 @@ func write401(w http.ResponseWriter, wwwAuthenticateHeader, errorType, message s // - The token is then validated against the OIDC Provider. // // see TestAuthorizationOidcToken -func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider) func(http.Handler) http.Handler { +func AuthorizationMiddleware(staticConfig *config.StaticConfig, oauthState *oauth.State) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if slices.Contains(infraPaths, r.URL.Path) || - slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) { + // Skip auth for infrastructure endpoints (health, metrics) and well-known endpoints + // Use prefix matching per endpoint to handle sub-paths like /.well-known/oauth-protected-resource/sse + requestPath := r.URL.EscapedPath() + isWellKnown := !strings.Contains(requestPath, "..") && slices.ContainsFunc(WellKnownEndpoints, func(ep string) bool { + return requestPath == ep || strings.HasPrefix(requestPath, ep+"/") + }) + if slices.Contains(infraPaths, r.URL.Path) || isWellKnown { next.ServeHTTP(w, r) return } @@ -86,7 +93,18 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi } // Online OIDC provider validation if err == nil { - err = claims.ValidateWithProvider(r.Context(), staticConfig.OAuthAudience, oidcProvider) + snapshot := oauthState.Load() + if snapshot == nil || snapshot.OIDCProvider == nil { + // Provider was configured (authorization_url set) but is unavailable — reject + if staticConfig.AuthorizationURL != "" { + klog.V(1).Infof("Authentication rejected - OIDC provider unavailable: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr) + write401(w, wwwAuthenticateHeader, "temporarily_unavailable", "OIDC provider is not available") + return + } + // No provider configured — offline validation only + } else { + err = claims.ValidateWithProvider(r.Context(), staticConfig.OAuthAudience, snapshot.OIDCProvider) + } } if err != nil { klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err) @@ -94,7 +112,10 @@ func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oi return } - next.ServeHTTP(w, r) + // Store the validated Authorization header in context for MCP handlers + // This is necessary because SSE transport doesn't propagate HTTP headers to MCP requests + ctx := context.WithValue(r.Context(), internalk8s.OAuthAuthorizationHeader, authHeader) + next.ServeHTTP(w, r.WithContext(ctx)) }) } } diff --git a/pkg/http/authorization_mcp_test.go b/pkg/http/authorization_mcp_test.go index 2659497d3..ecf939faf 100644 --- a/pkg/http/authorization_mcp_test.go +++ b/pkg/http/authorization_mcp_test.go @@ -294,14 +294,15 @@ func (s *AuthorizationSuite) TestAuthorizationUnauthorizedTokenExchangeFailure() s.Require().NotNil(s.mcpClient.Session, "Expected session for valid authentication") s.Require().NotNil(s.mcpClient.Session.InitializeResult(), "Expected initial request to not be nil") }) - s.Run("Call tool exchanges token VALID OIDC EXCHANGE Authorization header", func() { + s.Run("Call tool returns error when token exchange fails", func() { toolResult, err := s.mcpClient.Session.CallTool(s.T().Context(), &mcp.CallToolParams{ Name: "events_list", Arguments: map[string]any{}, }) - s.Require().NoError(err, "Expected no error calling tool") // TODO: Should error - s.Require().NotNil(toolResult, "Expected tool result to not be nil") // Should be nil - s.Regexp("token exchange failed:[^:]+: status code 401", s.logBuffer.String()) + // When token exchange is explicitly configured and fails, + // the error should propagate rather than silently passing through + s.Require().Error(err, "Expected tool call to fail when token exchange fails") + s.Require().Nil(toolResult, "Expected no tool result when token exchange fails") }) }) s.mcpClient.Close() @@ -424,7 +425,7 @@ func (s *AuthorizationSuite) TestAuthorizationOidcTokenExchange() { }) s.Require().NoError(err, "Expected no error calling tool") s.Require().NotNil(toolResult, "Expected tool result to not be nil") - s.Contains(s.logBuffer.String(), "token exchanged successfully") + // Token exchange is verified by the successful tool call with STS configured }) }) s.mcpClient.Close() diff --git a/pkg/http/http.go b/pkg/http/http.go index 425b8f4fc..cd7d3127b 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -14,12 +14,11 @@ import ( "syscall" "time" - "github.com/coreos/go-oidc/v3/oidc" - "k8s.io/klog/v2" "github.com/containers/kubernetes-mcp-server/pkg/config" "github.com/containers/kubernetes-mcp-server/pkg/mcp" + "github.com/containers/kubernetes-mcp-server/pkg/oauth" ) // tlsErrorFilterWriter filters out noisy TLS handshake errors from health checks @@ -104,11 +103,11 @@ func statsHandler(mcpServer *mcp.Server) http.HandlerFunc { } } -func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, httpClient *http.Client) error { +func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oauthState *oauth.State) error { mux := http.NewServeMux() wrappedMux := RequestMiddleware( - AuthorizationMiddleware(staticConfig, oidcProvider)( + AuthorizationMiddleware(staticConfig, oauthState)( MaxBodyMiddleware(staticConfig.HTTP.MaxBodyBytes)(mux), ), ) @@ -144,7 +143,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat }) mux.HandleFunc(statsEndpoint, statsHandler(mcpServer)) mux.Handle(metricsEndpoint, mcpServer.GetMetrics().PrometheusHandler()) - mux.Handle("/.well-known/", WellKnownHandler(staticConfig, httpClient)) + mux.Handle("/.well-known/", WellKnownHandler(staticConfig, oauthState)) ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index ec356b285..b4179f4a3 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -18,16 +18,16 @@ import ( "github.com/containers/kubernetes-mcp-server/internal/test" "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/containers/kubernetes-mcp-server/pkg/config" "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" + "github.com/containers/kubernetes-mcp-server/pkg/mcp" + "github.com/containers/kubernetes-mcp-server/pkg/oauth" "github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc/oidctest" "github.com/stretchr/testify/suite" "golang.org/x/sync/errgroup" "k8s.io/klog/v2" "k8s.io/klog/v2/textlogger" - - "github.com/containers/kubernetes-mcp-server/pkg/config" - "github.com/containers/kubernetes-mcp-server/pkg/mcp" ) type BaseHttpSuite struct { @@ -36,6 +36,7 @@ type BaseHttpSuite struct { StaticConfig *config.StaticConfig mcpServer *mcp.Server OidcProvider *oidc.Provider + OAuthState *oauth.State timeoutCancel context.CancelFunc StopServer context.CancelFunc WaitForShutdown func() error @@ -55,7 +56,8 @@ func (s *BaseHttpSuite) StartServer() { s.Require().NoError(err, "Expected no error getting random port address") s.StaticConfig.Port = strconv.Itoa(tcpAddr.Port) - provider, err := kubernetes.NewProvider(s.StaticConfig, kubernetes.WithTokenExchange(s.OidcProvider, nil)) + s.OAuthState = oauth.NewState(oauth.SnapshotFromConfig(s.StaticConfig, s.OidcProvider, nil)) + provider, err := kubernetes.NewProvider(s.StaticConfig, kubernetes.WithTokenExchange(s.OAuthState)) s.Require().NoError(err, "Expected no error creating kubernetes target provider") s.mcpServer, err = mcp.NewServer(mcp.Configuration{StaticConfig: s.StaticConfig}, provider) s.Require().NoError(err, "Expected no error creating MCP server") @@ -64,7 +66,7 @@ func (s *BaseHttpSuite) StartServer() { timeoutCtx, s.timeoutCancel = context.WithTimeout(s.T().Context(), 10*time.Second) group, gc := errgroup.WithContext(timeoutCtx) cancelCtx, s.StopServer = context.WithCancel(gc) - group.Go(func() error { return Serve(cancelCtx, s.mcpServer, s.StaticConfig, s.OidcProvider, nil) }) + group.Go(func() error { return Serve(cancelCtx, s.mcpServer, s.StaticConfig, s.OAuthState) }) s.WaitForShutdown = group.Wait s.Require().NoError(test.WaitForServer(tcpAddr), "HTTP server did not start in time") s.Require().NoError(test.WaitForHealthz(tcpAddr), "HTTP server /healthz endpoint did not respond with non-404 in time") @@ -90,6 +92,7 @@ type httpContext struct { WaitForShutdown func() error StaticConfig *config.StaticConfig OidcProvider *oidc.Provider + OAuthState *oauth.State } func (c *httpContext) beforeEach(t *testing.T) { @@ -117,7 +120,8 @@ func (c *httpContext) beforeEach(t *testing.T) { t.Fatalf("Failed to close random port listener: %v", randomPortErr) } c.StaticConfig.Port = fmt.Sprintf("%d", ln.Addr().(*net.TCPAddr).Port) - provider, err := kubernetes.NewProvider(c.StaticConfig, kubernetes.WithTokenExchange(c.OidcProvider, nil)) + c.OAuthState = oauth.NewState(oauth.SnapshotFromConfig(c.StaticConfig, c.OidcProvider, nil)) + provider, err := kubernetes.NewProvider(c.StaticConfig, kubernetes.WithTokenExchange(c.OAuthState)) if err != nil { t.Fatalf("Failed to create kubernetes target provider: %v", err) } @@ -129,7 +133,7 @@ func (c *httpContext) beforeEach(t *testing.T) { timeoutCtx, c.timeoutCancel = context.WithTimeout(t.Context(), 10*time.Second) group, gc := errgroup.WithContext(timeoutCtx) cancelCtx, c.StopServer = context.WithCancel(gc) - group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OidcProvider, nil) }) + group.Go(func() error { return Serve(cancelCtx, mcpServer, c.StaticConfig, c.OAuthState) }) c.WaitForShutdown = group.Wait // Wait for HTTP server to start (using net) for i := 0; i < 10; i++ { diff --git a/pkg/http/wellknown.go b/pkg/http/wellknown.go index 82da0ef5c..6555950a2 100644 --- a/pkg/http/wellknown.go +++ b/pkg/http/wellknown.go @@ -7,12 +7,16 @@ import ( "net/http" "net/url" "strings" + "sync" + "time" "github.com/containers/kubernetes-mcp-server/pkg/config" + "github.com/containers/kubernetes-mcp-server/pkg/oauth" "k8s.io/klog/v2" ) const maxWellKnownResponseSize = 1 << 20 // 1 MB +const oidcConfigCacheTTL = 5 * time.Minute var allowedResponseHeaders = map[string]bool{ "Cache-Control": true, @@ -35,76 +39,157 @@ var WellKnownEndpoints = []string{ openIDConfigurationEndpoint, } +// WellKnownMetadataGenerator generates well-known metadata when the upstream +// authorization server doesn't provide certain endpoints. +// This allows supporting OIDC providers that only implement openid-configuration. +type WellKnownMetadataGenerator interface { + // GenerateAuthorizationServerMetadata generates oauth-authorization-server metadata + // from the openid-configuration. Returns nil if generation is not possible. + GenerateAuthorizationServerMetadata(oidcConfig map[string]interface{}) map[string]interface{} + + // GenerateProtectedResourceMetadata generates oauth-protected-resource metadata (RFC 9728) + // for the MCP server. authorizationServerURL is where OAuth metadata can be fetched. + GenerateProtectedResourceMetadata(oidcConfig map[string]interface{}, authorizationServerURL string) map[string]interface{} +} + +// DefaultMetadataGenerator provides standard metadata generation for OIDC providers +// that only implement openid-configuration (e.g., Entra ID, Auth0, etc.) +type DefaultMetadataGenerator struct{} + +// GenerateAuthorizationServerMetadata returns the openid-configuration as-is, +// since it contains the required OAuth 2.0 Authorization Server Metadata fields. +func (g *DefaultMetadataGenerator) GenerateAuthorizationServerMetadata(oidcConfig map[string]interface{}) map[string]interface{} { + return oidcConfig +} + +// GenerateProtectedResourceMetadata generates RFC 9728 compliant metadata +// for the MCP server acting as an OAuth 2.0 protected resource. +func (g *DefaultMetadataGenerator) GenerateProtectedResourceMetadata(oidcConfig map[string]interface{}, authorizationServerURL string) map[string]interface{} { + metadata := map[string]interface{}{ + "authorization_servers": []string{authorizationServerURL}, + } + + // Copy relevant fields from openid-configuration + if scopes, ok := oidcConfig["scopes_supported"]; ok { + metadata["scopes_supported"] = scopes + } + metadata["bearer_methods_supported"] = []string{"header"} + + return metadata +} + type WellKnown struct { - authorizationUrl string - scopesSupported []string - disableDynamicClientRegistration bool - httpClient *http.Client + oauthState *oauth.State + staticConfig *config.StaticConfig + metadataGenerator WellKnownMetadataGenerator + // Cache for openid-configuration to avoid repeated fetches (TTL: oidcConfigCacheTTL) + oidcConfigCache map[string]interface{} + oidcConfigCacheTime time.Time + oidcConfigCacheURL string // tracks which authURL the cache was fetched for + oidcConfigCacheLock sync.RWMutex } var _ http.Handler = &WellKnown{} -func WellKnownHandler(staticConfig *config.StaticConfig, httpClient *http.Client) http.Handler { - authorizationUrl := staticConfig.AuthorizationURL - if authorizationUrl != "" && strings.HasSuffix(authorizationUrl, "/") { - authorizationUrl = strings.TrimSuffix(authorizationUrl, "/") - } - if httpClient == nil { - // Create a TLS-enforcing client instead of using http.DefaultClient - httpClient = config.NewTLSEnforcingClient(nil, staticConfig.IsRequireTLS) +func WellKnownHandler(staticConfig *config.StaticConfig, oauthState *oauth.State) http.Handler { + return WellKnownHandlerWithGenerator(staticConfig, oauthState, &DefaultMetadataGenerator{}) +} + +// WellKnownHandlerWithGenerator creates a WellKnown handler with a custom metadata generator. +// This allows customizing how metadata is generated for different OIDC providers. +func WellKnownHandlerWithGenerator(staticConfig *config.StaticConfig, oauthState *oauth.State, generator WellKnownMetadataGenerator) http.Handler { + if generator == nil { + generator = &DefaultMetadataGenerator{} } return &WellKnown{ - authorizationUrl: authorizationUrl, - disableDynamicClientRegistration: staticConfig.DisableDynamicClientRegistration, - scopesSupported: staticConfig.OAuthScopes, - httpClient: httpClient, + oauthState: oauthState, + staticConfig: staticConfig, + metadataGenerator: generator, } } -func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - if w.authorizationUrl == "" { - http.Error(writer, "Authorization URL is not configured", http.StatusNotFound) - return +// authorizationURL returns the current authorization URL from the oauth snapshot, trimming trailing slashes. +func (w *WellKnown) authorizationURL() string { + snap := w.oauthState.Load() + if snap == nil { + return "" } - upstreamURL, err := url.JoinPath(w.authorizationUrl, request.URL.EscapedPath()) - if err != nil || !strings.HasPrefix(upstreamURL, w.authorizationUrl+"/") { - http.Error(writer, "Invalid well-known path", http.StatusBadRequest) - return + return strings.TrimSuffix(snap.AuthorizationURL, "/") +} + +// wellKnownHTTPClient returns the current HTTP client from the oauth snapshot, +// falling back to a TLS-enforcing client if none is available. +func (w *WellKnown) wellKnownHTTPClient() *http.Client { + snap := w.oauthState.Load() + if snap != nil && snap.HTTPClient != nil { + return snap.HTTPClient } - req, err := http.NewRequest(request.Method, upstreamURL, nil) - if err != nil { - klog.V(1).Infof("Well-known proxy failed to create request for %s: %v", request.URL.Path, err) - http.Error(writer, "Failed to create upstream request", http.StatusInternalServerError) + return config.NewTLSEnforcingClient(nil, w.staticConfig.IsRequireTLS) +} + +func (w *WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + authURL := w.authorizationURL() + if authURL == "" { + http.Error(writer, "Authorization URL is not configured", http.StatusNotFound) return } - resp, err := w.httpClient.Do(req.WithContext(request.Context())) - if err != nil { - klog.V(1).Infof("Well-known proxy request failed for %s: %v", request.URL.Path, err) - http.Error(writer, "Failed to fetch upstream well-known metadata", http.StatusInternalServerError) + + requestPath := request.URL.EscapedPath() + + // Validate the URL path to prevent path traversal + upstreamURL, err := url.JoinPath(authURL, requestPath) + if err != nil || !strings.HasPrefix(upstreamURL, authURL+"/") { + http.Error(writer, "Invalid well-known path", http.StatusBadRequest) return } - defer func() { _ = resp.Body.Close() }() - var resourceMetadata map[string]interface{} - err = json.NewDecoder(io.LimitReader(resp.Body, maxWellKnownResponseSize)).Decode(&resourceMetadata) + + // Try direct proxy first (works for Keycloak and other providers that support all endpoints) + resourceMetadata, respHeaders, err := w.fetchWellKnownEndpoint(request, upstreamURL) if err != nil { - klog.V(1).Infof("Well-known proxy failed to decode response for %s: %v", request.URL.Path, err) - http.Error(writer, "Failed to read upstream response", http.StatusInternalServerError) + klog.V(1).Infof("Well-known proxy failed to fetch %s: %v", requestPath, err) + http.Error(writer, "Failed to fetch well-known metadata", http.StatusInternalServerError) return } - if w.disableDynamicClientRegistration { - delete(resourceMetadata, "registration_endpoint") - resourceMetadata["require_request_uri_registration"] = false - } - if len(w.scopesSupported) > 0 { - resourceMetadata["scopes_supported"] = w.scopesSupported + + // If direct fetch returned nil (404), generate metadata using the configured generator. + // This provides fallback support for OIDC providers that only implement openid-configuration. + // Use prefix matching to handle paths like /.well-known/oauth-protected-resource/sse + if resourceMetadata == nil { + switch { + case strings.HasPrefix(requestPath, oauthAuthorizationServerEndpoint): + resourceMetadata, err = w.generateAuthorizationServerMetadata(request) + if err != nil { + klog.V(1).Infof("Well-known proxy failed to generate authorization server metadata: %v", err) + http.Error(writer, "Failed to generate well-known metadata", http.StatusInternalServerError) + return + } + respHeaders = nil + case strings.HasPrefix(requestPath, oauthProtectedResourceEndpoint): + resourceMetadata, err = w.generateProtectedResourceMetadata(request) + if err != nil { + klog.V(1).Infof("Well-known proxy failed to generate protected resource metadata: %v", err) + http.Error(writer, "Failed to generate well-known metadata", http.StatusInternalServerError) + return + } + respHeaders = nil + } + if resourceMetadata == nil { + http.Error(writer, "Failed to fetch well-known metadata", http.StatusNotFound) + return + } } + + w.applyConfigOverrides(resourceMetadata) + body, err := json.Marshal(resourceMetadata) if err != nil { klog.V(1).Infof("Well-known proxy failed to marshal response for %s: %v", request.URL.Path, err) http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - for key, values := range resp.Header { + + // Copy allowed headers from backend response if available + for key, values := range respHeaders { if !allowedResponseHeaders[http.CanonicalHeaderKey(key)] { continue } @@ -115,10 +200,143 @@ func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) writer.Header().Set("Content-Type", "application/json") writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) withCORSHeaders(writer) - writer.WriteHeader(resp.StatusCode) + writer.WriteHeader(http.StatusOK) _, _ = writer.Write(body) } +// fetchWellKnownEndpoint fetches a well-known endpoint and returns the parsed JSON. +// Returns nil metadata if the endpoint returns 404 (to allow fallback). +func (w *WellKnown) fetchWellKnownEndpoint(request *http.Request, url string) (map[string]interface{}, http.Header, error) { + req, err := http.NewRequest(request.Method, url, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to create request: %w", err) + } + resp, err := w.wellKnownHTTPClient().Do(req.WithContext(request.Context())) + if err != nil { + return nil, nil, fmt.Errorf("failed to perform request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // Return nil for 404 to trigger fallback + if resp.StatusCode == http.StatusNotFound { + return nil, nil, nil + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, nil, fmt.Errorf("upstream returned status %d", resp.StatusCode) + } + + var resourceMetadata map[string]interface{} + if err := json.NewDecoder(io.LimitReader(resp.Body, maxWellKnownResponseSize)).Decode(&resourceMetadata); err != nil { + return nil, nil, fmt.Errorf("failed to read response body: %w", err) + } + + return resourceMetadata, resp.Header, nil +} + +// fetchOpenIDConfiguration fetches and caches the openid-configuration from the authorization server. +func (w *WellKnown) fetchOpenIDConfiguration(request *http.Request) (map[string]interface{}, error) { + authURL := w.authorizationURL() + + // Check cache first (with TTL and URL match — invalidate if authorization URL changed) + w.oidcConfigCacheLock.RLock() + if w.oidcConfigCache != nil && w.oidcConfigCacheURL == authURL && time.Since(w.oidcConfigCacheTime) < oidcConfigCacheTTL { + result := copyMap(w.oidcConfigCache) + w.oidcConfigCacheLock.RUnlock() + return result, nil + } + w.oidcConfigCacheLock.RUnlock() + + // Fetch openid-configuration + oidcURL := authURL + openIDConfigurationEndpoint + oidcConfig, _, err := w.fetchWellKnownEndpoint(request, oidcURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch openid-configuration: %w", err) + } + if oidcConfig == nil { + return nil, nil + } + + // Cache the result with timestamp and URL + w.oidcConfigCacheLock.Lock() + w.oidcConfigCache = copyMap(oidcConfig) + w.oidcConfigCacheTime = time.Now() + w.oidcConfigCacheURL = authURL + w.oidcConfigCacheLock.Unlock() + + return oidcConfig, nil +} + +// generateAuthorizationServerMetadata generates oauth-authorization-server metadata +// using the configured metadata generator and the fetched openid-configuration. +func (w *WellKnown) generateAuthorizationServerMetadata(request *http.Request) (map[string]interface{}, error) { + oidcConfig, err := w.fetchOpenIDConfiguration(request) + if err != nil { + return nil, err + } + if oidcConfig == nil { + return nil, nil + } + return w.metadataGenerator.GenerateAuthorizationServerMetadata(oidcConfig), nil +} + +// generateProtectedResourceMetadata generates oauth-protected-resource metadata (RFC 9728) +// using the configured metadata generator. +func (w *WellKnown) generateProtectedResourceMetadata(request *http.Request) (map[string]interface{}, error) { + oidcConfig, err := w.fetchOpenIDConfiguration(request) + if err != nil { + return nil, err + } + if oidcConfig == nil { + return nil, nil + } + + // MCP server URL - where OAuth metadata can be fetched + mcpServerURL := w.buildResourceURL(request) + return w.metadataGenerator.GenerateProtectedResourceMetadata(oidcConfig, mcpServerURL), nil +} + +// buildResourceURL constructs the canonical resource URL for the MCP server. +// Uses server_url from config when set, otherwise infers from request headers. +// SECURITY: When server_url is not configured, this trusts X-Forwarded-Host and +// X-Forwarded-Proto headers. The server MUST be behind a trusted reverse proxy. +func (w *WellKnown) buildResourceURL(request *http.Request) string { + if w.staticConfig != nil && w.staticConfig.ServerURL != "" { + return strings.TrimSuffix(w.staticConfig.ServerURL, "/") + } + scheme := "https" + if request.TLS == nil && !strings.HasPrefix(request.Header.Get("X-Forwarded-Proto"), "https") { + scheme = "http" + } + host := request.Host + if fwdHost := request.Header.Get("X-Forwarded-Host"); fwdHost != "" { + host = fwdHost + } + return fmt.Sprintf("%s://%s", scheme, host) +} + +// applyConfigOverrides applies server configuration overrides to the metadata. +func (w *WellKnown) applyConfigOverrides(resourceMetadata map[string]interface{}) { + snap := w.oauthState.Load() + if snap != nil && snap.DisableDynamicClientRegistration { + delete(resourceMetadata, "registration_endpoint") + resourceMetadata["require_request_uri_registration"] = false + } + if snap != nil && len(snap.OAuthScopes) > 0 { + resourceMetadata["scopes_supported"] = snap.OAuthScopes + } +} + +// copyMap creates a shallow copy so the cached original is not mutated by +// applyConfigOverrides, which only modifies top-level keys. +func copyMap(src map[string]interface{}) map[string]interface{} { + dst := make(map[string]interface{}, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + func withCORSHeaders(writer http.ResponseWriter) { writer.Header().Set("Access-Control-Allow-Origin", "*") writer.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") diff --git a/pkg/http/wellknown_test.go b/pkg/http/wellknown_test.go index 1d2a1adda..a9deea270 100644 --- a/pkg/http/wellknown_test.go +++ b/pkg/http/wellknown_test.go @@ -328,6 +328,124 @@ func (s *WellknownSuite) TestOversizedUpstreamResponse() { }) } +func (s *WellknownSuite) TestMetadataGenerationFallback() { + s.Run("generates oauth-authorization-server from openid-configuration when endpoint returns 404", func() { + // Simulate OIDC providers that only implement openid-configuration (e.g., Entra ID, Auth0) + s.TestServer.Close() + s.TestServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.EscapedPath() { + case "/.well-known/openid-configuration": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "issuer": "https://login.microsoftonline.com/tenant/v2.0", + "authorization_endpoint": "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize", + "token_endpoint": "https://login.microsoftonline.com/tenant/oauth2/v2.0/token", + "jwks_uri": "https://login.microsoftonline.com/tenant/discovery/v2.0/keys", + "scopes_supported": ["openid", "profile", "email"] + }`)) + default: + http.NotFound(w, r) + } + })) + s.StaticConfig.AuthorizationURL = s.TestServer.URL + s.StaticConfig.RequireOAuth = true + s.StartServer() + + // oauth-authorization-server should work via fallback + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/.well-known/oauth-authorization-server", s.StaticConfig.Port)) + s.Require().NoError(err) + s.T().Cleanup(func() { _ = resp.Body.Close() }) + + s.Equal(http.StatusOK, resp.StatusCode, "Expected fallback to succeed") + + body, err := io.ReadAll(resp.Body) + s.Require().NoError(err) + s.Contains(string(body), "login.microsoftonline.com", "Expected Entra ID issuer in response") + s.Contains(string(body), "authorization_endpoint", "Expected authorization_endpoint in response") + }) + + s.Run("generates RFC 9728 compliant oauth-protected-resource when endpoint returns 404", func() { + s.TestServer.Close() + s.TestServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.EscapedPath() { + case "/.well-known/openid-configuration": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "issuer": "https://login.microsoftonline.com/tenant/v2.0", + "token_endpoint": "https://login.microsoftonline.com/tenant/oauth2/v2.0/token", + "scopes_supported": ["openid", "profile"] + }`)) + default: + http.NotFound(w, r) + } + })) + s.StaticConfig.AuthorizationURL = s.TestServer.URL + s.StaticConfig.RequireOAuth = true + s.StartServer() + + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/.well-known/oauth-protected-resource", s.StaticConfig.Port)) + s.Require().NoError(err) + s.T().Cleanup(func() { _ = resp.Body.Close() }) + + s.Equal(http.StatusOK, resp.StatusCode, "Expected fallback to succeed") + + body, err := io.ReadAll(resp.Body) + s.Require().NoError(err) + + // Verify RFC 9728 format - MCP server is the authorization_server from client's perspective + s.Contains(string(body), `"authorization_servers":`, "Expected authorization_servers array per RFC 9728") + s.Contains(string(body), fmt.Sprintf("127.0.0.1:%s", s.StaticConfig.Port), "Expected authorization_servers to contain MCP server URL") + s.Contains(string(body), `"scopes_supported":`, "Expected scopes_supported from openid-configuration") + }) + + s.Run("returns 404 when both oauth-authorization-server and openid-configuration return 404", func() { + s.TestServer.Close() + s.TestServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + s.StaticConfig.AuthorizationURL = s.TestServer.URL + s.StaticConfig.RequireOAuth = true + s.StartServer() + + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/.well-known/oauth-authorization-server", s.StaticConfig.Port)) + s.Require().NoError(err) + s.T().Cleanup(func() { _ = resp.Body.Close() }) + + s.Equal(http.StatusNotFound, resp.StatusCode, "Expected 404 when all endpoints fail") + }) + + s.Run("applies config overrides to fallback response", func() { + s.TestServer.Close() + s.TestServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.EscapedPath() { + case "/.well-known/openid-configuration": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "issuer": "https://login.microsoftonline.com/tenant/v2.0", + "scopes_supported": ["openid"], + "registration_endpoint": "https://should-be-removed" + }`)) + default: + http.NotFound(w, r) + } + })) + s.StaticConfig.AuthorizationURL = s.TestServer.URL + s.StaticConfig.RequireOAuth = true + s.StaticConfig.DisableDynamicClientRegistration = true + s.StaticConfig.OAuthScopes = []string{"custom-scope"} + s.StartServer() + + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/.well-known/oauth-authorization-server", s.StaticConfig.Port)) + s.Require().NoError(err) + s.T().Cleanup(func() { _ = resp.Body.Close() }) + + body, err := io.ReadAll(resp.Body) + s.Require().NoError(err) + s.NotContains(string(body), "registration_endpoint", "registration_endpoint should be removed") + s.Contains(string(body), `"scopes_supported":["custom-scope"]`, "scopes should be overridden") + }) +} + func TestWellknown(t *testing.T) { suite.Run(t, new(WellknownSuite)) } diff --git a/pkg/kubernetes-mcp-server/cmd/root.go b/pkg/kubernetes-mcp-server/cmd/root.go index 8f15ac0bc..9bcc8d727 100644 --- a/pkg/kubernetes-mcp-server/cmd/root.go +++ b/pkg/kubernetes-mcp-server/cmd/root.go @@ -2,12 +2,9 @@ package cmd import ( "context" - "crypto/tls" - "crypto/x509" "errors" "flag" "fmt" - "net/http" "net/url" "os" "os/signal" @@ -17,7 +14,6 @@ import ( "syscall" "time" - "github.com/coreos/go-oidc/v3/oidc" "github.com/spf13/cobra" "k8s.io/cli-runtime/pkg/genericiooptions" @@ -31,6 +27,7 @@ import ( internalhttp "github.com/containers/kubernetes-mcp-server/pkg/http" "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" "github.com/containers/kubernetes-mcp-server/pkg/mcp" + internaloauth "github.com/containers/kubernetes-mcp-server/pkg/oauth" "github.com/containers/kubernetes-mcp-server/pkg/output" "github.com/containers/kubernetes-mcp-server/pkg/telemetry" "github.com/containers/kubernetes-mcp-server/pkg/toolsets" @@ -346,6 +343,9 @@ func (m *MCPServerOptions) Validate() error { if err := m.StaticConfig.ValidateRequireTLS(); err != nil { return err } + if err := m.StaticConfig.ValidateClusterAuthMode(); err != nil { + return err + } return nil } @@ -375,48 +375,13 @@ func (m *MCPServerOptions) Run() error { return nil } - var oidcProvider *oidc.Provider - var httpClient *http.Client - if m.StaticConfig.AuthorizationURL != "" { - ctx := context.Background() - if m.StaticConfig.CertificateAuthority != "" { - httpClient = &http.Client{} - caCert, err := os.ReadFile(m.StaticConfig.CertificateAuthority) - if err != nil { - return fmt.Errorf("failed to read CA certificate from %s: %w", m.StaticConfig.CertificateAuthority, err) - } - caCertPool := x509.NewCertPool() - if !caCertPool.AppendCertsFromPEM(caCert) { - return fmt.Errorf("failed to append CA certificate from %s to pool", m.StaticConfig.CertificateAuthority) - } - - if caCertPool.Equal(x509.NewCertPool()) { - caCertPool = nil - } - - var transport http.RoundTripper = &http.Transport{ - TLSClientConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - RootCAs: caCertPool, - }, - } - // Wrap transport with TLS enforcement - transport = config.NewTLSEnforcingTransport(transport, m.StaticConfig.IsRequireTLS) - httpClient.Transport = transport - ctx = oidc.ClientContext(ctx, httpClient) - } else { - // No custom CA, but still enforce TLS if required - httpClient = config.NewTLSEnforcingClient(nil, m.StaticConfig.IsRequireTLS) - ctx = oidc.ClientContext(ctx, httpClient) - } - provider, err := oidc.NewProvider(ctx, m.StaticConfig.AuthorizationURL) - if err != nil { - return fmt.Errorf("unable to setup OIDC provider: %w", err) - } - oidcProvider = provider + oidcProvider, httpClient, err := internaloauth.CreateOIDCProviderAndClient(m.StaticConfig) + if err != nil { + return err } + oauthState := internaloauth.NewState(internaloauth.SnapshotFromConfig(m.StaticConfig, oidcProvider, httpClient)) - provider, err := kubernetes.NewProvider(m.StaticConfig, kubernetes.WithTokenExchange(oidcProvider, httpClient)) + provider, err := kubernetes.NewProvider(m.StaticConfig, kubernetes.WithTokenExchange(oauthState)) if err != nil { return fmt.Errorf("unable to create kubernetes target provider: %w", err) } @@ -437,12 +402,12 @@ func (m *MCPServerOptions) Run() error { // Set up SIGHUP handler for configuration reload if m.ConfigPath != "" || m.ConfigDir != "" { - _ = m.setupSIGHUPHandler(mcpServer) + _ = m.setupSIGHUPHandler(mcpServer, oauthState) } if m.StaticConfig.Port != "" { ctx := context.Background() - return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider, httpClient) + return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oauthState) } ctx := context.Background() @@ -456,7 +421,7 @@ func (m *MCPServerOptions) Run() error { // setupSIGHUPHandler sets up a signal handler to reload configuration on SIGHUP. // Returns a stop function that should be called to clean up the handler. // The stop function waits for the handler goroutine to finish. -func (m *MCPServerOptions) setupSIGHUPHandler(mcpServer *mcp.Server) (stop func()) { +func (m *MCPServerOptions) setupSIGHUPHandler(mcpServer *mcp.Server, oauthState *internaloauth.State) (stop func()) { sigHupCh := make(chan os.Signal, 1) done := make(chan struct{}) signal.Notify(sigHupCh, syscall.SIGHUP) @@ -473,12 +438,35 @@ func (m *MCPServerOptions) setupSIGHUPHandler(mcpServer *mcp.Server) (stop func( continue } - // Apply the new configuration to the MCP server + // Apply the new configuration to the MCP server first — if this fails, + // we skip the OAuth state update to avoid inconsistent state. if err := mcpServer.ReloadConfiguration(newConfig); err != nil { klog.Errorf("Failed to apply reloaded configuration: %v", err) continue } + // Check if OAuth-relevant config changed and update the shared state + currentSnapshot := oauthState.Load() + if currentSnapshot == nil { + currentSnapshot = &internaloauth.Snapshot{} + } + newSnapshot := internaloauth.SnapshotFromConfig(newConfig, currentSnapshot.OIDCProvider, currentSnapshot.HTTPClient) + if currentSnapshot.HasProviderConfigChanged(newSnapshot) { + klog.V(1).Info("OAuth configuration changed, recreating OIDC provider...") + newProvider, newClient, err := internaloauth.CreateOIDCProviderAndClient(newConfig) + if err != nil { + klog.Errorf("Failed to recreate OIDC provider during reload: %v", err) + continue + } + newSnapshot.OIDCProvider = newProvider + newSnapshot.HTTPClient = newClient + oauthState.Store(newSnapshot) + klog.V(1).Info("OIDC provider and HTTP client updated successfully") + } else if currentSnapshot.HasWellKnownConfigChanged(newSnapshot) { + oauthState.Store(newSnapshot) + klog.V(1).Info("OAuth well-known configuration updated") + } + klog.V(1).Info("Configuration reloaded successfully via SIGHUP") } }() diff --git a/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go b/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go index 5d86966ad..6fd6b10a0 100644 --- a/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go +++ b/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go @@ -15,6 +15,7 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/config" "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" "github.com/containers/kubernetes-mcp-server/pkg/mcp" + "github.com/containers/kubernetes-mcp-server/pkg/oauth" "github.com/stretchr/testify/suite" "k8s.io/klog/v2" "k8s.io/klog/v2/textlogger" @@ -78,7 +79,8 @@ func (s *SIGHUPSuite) InitServer(configPath, configDir string) { ConfigPath: configPath, ConfigDir: configDir, } - s.stopSIGHUP = opts.setupSIGHUPHandler(s.server) + oauthState := oauth.NewState(&oauth.Snapshot{}) + s.stopSIGHUP = opts.setupSIGHUPHandler(s.server, oauthState) } func (s *SIGHUPSuite) TestSIGHUPReloadsConfigFromFile() { diff --git a/pkg/kubernetes/manager.go b/pkg/kubernetes/manager.go index 8275f9085..61a9e64c7 100644 --- a/pkg/kubernetes/manager.go +++ b/pkg/kubernetes/manager.go @@ -116,10 +116,11 @@ func NewManager(config api.BaseConfig, restConfig *rest.Config, clientCmdConfig func (m *Manager) Derived(ctx context.Context) (*Kubernetes, error) { authorization, ok := ctx.Value(OAuthAuthorizationHeader).(string) if !ok || !strings.HasPrefix(authorization, "Bearer ") { - if m.config.IsRequireOAuth() { - return nil, errors.New("oauth token required") + // Use kubeconfig credentials if explicitly configured or if OAuth is not required + if m.config.GetClusterAuthMode() == api.ClusterAuthKubeconfig || !m.config.IsRequireOAuth() { + return m.kubernetes, nil } - return m.kubernetes, nil + return nil, errors.New("oauth token required") } klog.V(5).Infof("%s header found (Bearer), using provided bearer token", OAuthAuthorizationHeader) userAgent := CustomUserAgent diff --git a/pkg/kubernetes/provider.go b/pkg/kubernetes/provider.go index 9312901c1..30de75069 100644 --- a/pkg/kubernetes/provider.go +++ b/pkg/kubernetes/provider.go @@ -2,11 +2,10 @@ package kubernetes import ( "context" - "net/http" "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/containers/kubernetes-mcp-server/pkg/oauth" "github.com/containers/kubernetes-mcp-server/pkg/tokenexchange" - "github.com/coreos/go-oidc/v3/oidc" ) // McpReload is a function type that defines a callback for reloading MCP toolsets (including tools, prompts, or other configurations) @@ -54,20 +53,12 @@ type TokenExchangeProvider interface { type ProviderOption func(*providerOptions) type providerOptions struct { - tokenExchangeConfig *providerTokenExchangeConfig + oauthState *oauth.State } -type providerTokenExchangeConfig struct { - oidcProvider *oidc.Provider - httpClient *http.Client -} - -func WithTokenExchange(oidcProvider *oidc.Provider, httpClient *http.Client) ProviderOption { +func WithTokenExchange(oauthState *oauth.State) ProviderOption { return func(opts *providerOptions) { - opts.tokenExchangeConfig = &providerTokenExchangeConfig{ - oidcProvider: oidcProvider, - httpClient: httpClient, - } + opts.oauthState = oauthState } } @@ -89,12 +80,11 @@ func NewProvider(cfg api.BaseConfig, opts ...ProviderOption) (Provider, error) { return nil, err } - if providerOpts.tokenExchangeConfig != nil { + if providerOpts.oauthState != nil { provider = newTokenExchangingProvider( provider, cfg, - providerOpts.tokenExchangeConfig.oidcProvider, - providerOpts.tokenExchangeConfig.httpClient, + providerOpts.oauthState, ) } diff --git a/pkg/kubernetes/provider_token_exchange.go b/pkg/kubernetes/provider_token_exchange.go index 8c1907f34..ee3221b4b 100644 --- a/pkg/kubernetes/provider_token_exchange.go +++ b/pkg/kubernetes/provider_token_exchange.go @@ -2,40 +2,105 @@ package kubernetes import ( "context" - "net/http" + "sync" "github.com/containers/kubernetes-mcp-server/pkg/api" - "github.com/coreos/go-oidc/v3/oidc" + "github.com/containers/kubernetes-mcp-server/pkg/oauth" + "github.com/containers/kubernetes-mcp-server/pkg/tokenexchange" + "k8s.io/klog/v2" ) type tokenExchangingProvider struct { - provider Provider - stsConfigProvider api.StsConfigProvider - oidcProvider *oidc.Provider - httpClient *http.Client + provider Provider + baseConfig api.BaseConfig + oauthState *oauth.State + // stsConfig is cached and reused across calls so that assertion caching + // in TargetTokenExchangeConfig is effective. Rebuilt when the token URL changes + // (e.g., after SIGHUP reloads the OIDC provider). + stsConfig *tokenexchange.TargetTokenExchangeConfig + stsConfigMu sync.Mutex + stsTokenURL string // tracks which token URL the cached config was built for } var _ Provider = &tokenExchangingProvider{} func newTokenExchangingProvider( provider Provider, - stsConfigProvider api.StsConfigProvider, - oidcProvider *oidc.Provider, - httpClient *http.Client, + baseConfig api.BaseConfig, + oauthState *oauth.State, ) Provider { return &tokenExchangingProvider{ - provider: provider, - stsConfigProvider: stsConfigProvider, - oidcProvider: oidcProvider, - httpClient: httpClient, + provider: provider, + baseConfig: baseConfig, + oauthState: oauthState, } } func (p *tokenExchangingProvider) GetDerivedKubernetes(ctx context.Context, target string) (*Kubernetes, error) { - ctx = ExchangeTokenInContext(ctx, p.stsConfigProvider, p.oidcProvider, p.httpClient, p.provider, target) + snap := p.oauthState.Load() + if snap == nil { + return p.provider.GetDerivedKubernetes(ctx, target) + } + stsConfig := p.getOrBuildStsConfig(snap) + ctx, err := ExchangeTokenInContext(ctx, p.baseConfig, snap.OIDCProvider, snap.HTTPClient, p.provider, target, stsConfig) + if err != nil { + return nil, err + } return p.provider.GetDerivedKubernetes(ctx, target) } +// getOrBuildStsConfig returns a cached STS config, rebuilding it when the +// OIDC provider's token URL changes (e.g., after SIGHUP). +func (p *tokenExchangingProvider) getOrBuildStsConfig(snap *oauth.Snapshot) *tokenexchange.TargetTokenExchangeConfig { + strategy := p.baseConfig.GetStsStrategy() + if strategy == "" { + return nil + } + + var tokenURL string + if snap.OIDCProvider != nil { + if endpoint := snap.OIDCProvider.Endpoint(); endpoint.TokenURL != "" { + tokenURL = endpoint.TokenURL + } + } + if tokenURL == "" { + klog.Warningf("token exchange strategy %q configured but OIDC provider returned empty token URL", strategy) + return nil + } + + p.stsConfigMu.Lock() + defer p.stsConfigMu.Unlock() + + // Return cached config if token URL hasn't changed + if p.stsConfig != nil && p.stsTokenURL == tokenURL { + return p.stsConfig + } + + authStyle := p.baseConfig.GetStsAuthStyle() + if authStyle == "" { + authStyle = tokenexchange.AuthStyleParams + } + + cfg := &tokenexchange.TargetTokenExchangeConfig{ + TokenURL: tokenURL, + ClientID: p.baseConfig.GetStsClientId(), + ClientSecret: p.baseConfig.GetStsClientSecret(), + Audience: p.baseConfig.GetStsAudience(), + Scopes: p.baseConfig.GetStsScopes(), + AuthStyle: authStyle, + ClientCertFile: p.baseConfig.GetStsClientCertFile(), + ClientKeyFile: p.baseConfig.GetStsClientKeyFile(), + } + if err := cfg.Validate(); err != nil { + klog.Warningf("STS config validation failed, token exchange will be attempted per-request but will likely fail with the same error: %v", err) + return nil + } + + p.stsConfig = cfg + p.stsTokenURL = tokenURL + return p.stsConfig +} + func (p *tokenExchangingProvider) IsOpenShift(ctx context.Context) bool { return p.provider.IsOpenShift(ctx) } diff --git a/pkg/kubernetes/token_exchange.go b/pkg/kubernetes/token_exchange.go index e6b35d3c4..631dcb5e3 100644 --- a/pkg/kubernetes/token_exchange.go +++ b/pkg/kubernetes/token_exchange.go @@ -2,6 +2,7 @@ package kubernetes import ( "context" + "fmt" "net/http" "strings" @@ -12,57 +13,118 @@ import ( "k8s.io/klog/v2" ) +// ExchangeTokenInContext exchanges the OAuth token in the context for a token +// that can access the target cluster. The optional stsConfig parameter allows +// callers to reuse a TargetTokenExchangeConfig across calls to benefit from +// assertion caching (pass nil to build a fresh config each time). func ExchangeTokenInContext( ctx context.Context, - stsConfigProvider api.StsConfigProvider, + baseConfig api.BaseConfig, oidcProvider *oidc.Provider, httpClient *http.Client, provider Provider, target string, -) context.Context { + stsConfig *tokenexchange.TargetTokenExchangeConfig, +) (context.Context, error) { auth, ok := ctx.Value(OAuthAuthorizationHeader).(string) if !ok || !strings.HasPrefix(auth, "Bearer ") { - return ctx + return ctx, nil } subjectToken := strings.TrimPrefix(auth, "Bearer ") tep, ok := provider.(TokenExchangeProvider) if !ok { - return stsExchangeTokenInContext(ctx, stsConfigProvider, oidcProvider, httpClient, subjectToken) + return stsExchangeTokenInContext(ctx, baseConfig, oidcProvider, httpClient, subjectToken, stsConfig) } exCfg := tep.GetTokenExchangeConfig(target) if exCfg == nil { - return stsExchangeTokenInContext(ctx, stsConfigProvider, oidcProvider, httpClient, subjectToken) + return stsExchangeTokenInContext(ctx, baseConfig, oidcProvider, httpClient, subjectToken, stsConfig) } exchanger, ok := tokenexchange.GetTokenExchanger(tep.GetTokenExchangeStrategy()) if !ok { klog.Warningf("token exchange strategy %q not found in registry", tep.GetTokenExchangeStrategy()) - return stsExchangeTokenInContext(ctx, stsConfigProvider, oidcProvider, httpClient, subjectToken) + return stsExchangeTokenInContext(ctx, baseConfig, oidcProvider, httpClient, subjectToken, stsConfig) } exchanged, err := exchanger.Exchange(ctx, exCfg, subjectToken) if err != nil { - klog.Errorf("token exchange failed for target %q: %v", target, err) - return ctx + return ctx, fmt.Errorf("token exchange failed for target %q: %w", target, err) } - - klog.V(4).Infof("token exchanged successfully for target %q", target) - return context.WithValue(ctx, OAuthAuthorizationHeader, "Bearer "+exchanged.AccessToken) + return context.WithValue(ctx, OAuthAuthorizationHeader, "Bearer "+exchanged.AccessToken), nil } -// TODO(Cali0707): remove this method and move to using the rfc8693 token exchanger for the global token exchange func stsExchangeTokenInContext( ctx context.Context, - stsConfigProvider api.StsConfigProvider, + baseConfig api.BaseConfig, oidcProvider *oidc.Provider, httpClient *http.Client, token string, -) context.Context { - sts := NewFromConfig(stsConfigProvider, oidcProvider) + stsConfig *tokenexchange.TargetTokenExchangeConfig, +) (context.Context, error) { + // Determine cluster auth mode first to avoid unnecessary token exchange + mode := baseConfig.GetClusterAuthMode() + if mode == "" { + mode = detectClusterAuthMode(baseConfig) + } + + switch mode { + case api.ClusterAuthKubeconfig: + // Use kubeconfig credentials, no token exchange needed + return context.WithValue(ctx, OAuthAuthorizationHeader, ""), nil + + case api.ClusterAuthPassthrough: + // Exchange the token if configured, then pass through to cluster + exchangedToken := token + if strategy := baseConfig.GetStsStrategy(); strategy != "" { + exchangedCtx, err := strategyBasedTokenExchange(ctx, baseConfig, oidcProvider, httpClient, token, strategy, stsConfig) + if err != nil { + return ctx, fmt.Errorf("strategy-based token exchange failed: %w", err) + } + if auth, ok := exchangedCtx.Value(OAuthAuthorizationHeader).(string); ok && strings.HasPrefix(auth, "Bearer ") { + exchangedToken = strings.TrimPrefix(auth, "Bearer ") + } + } else { + sts := NewFromConfig(baseConfig, oidcProvider) + if sts.IsEnabled() { + exchangedCtx, err := builtinStsExchange(ctx, baseConfig, oidcProvider, httpClient, token) + if err != nil { + return ctx, fmt.Errorf("built-in STS token exchange failed: %w", err) + } + if auth, ok := exchangedCtx.Value(OAuthAuthorizationHeader).(string); ok && strings.HasPrefix(auth, "Bearer ") { + exchangedToken = strings.TrimPrefix(auth, "Bearer ") + } + } + } + return context.WithValue(ctx, OAuthAuthorizationHeader, "Bearer "+exchangedToken), nil + + default: + return ctx, fmt.Errorf("unknown cluster_auth_mode %q", mode) + } +} + +// detectClusterAuthMode auto-detects the cluster auth mode based on config. +func detectClusterAuthMode(baseConfig api.BaseConfig) string { + // If OAuth is required, default to passthrough + if baseConfig.IsRequireOAuth() { + return api.ClusterAuthPassthrough + } + // No OAuth required, use kubeconfig credentials + return api.ClusterAuthKubeconfig +} + +// builtinStsExchange performs the built-in RFC 8693 STS exchange. +func builtinStsExchange( + ctx context.Context, + baseConfig api.BaseConfig, + oidcProvider *oidc.Provider, + httpClient *http.Client, + token string, +) (context.Context, error) { + sts := NewFromConfig(baseConfig, oidcProvider) if !sts.IsEnabled() { - return ctx + return ctx, fmt.Errorf("token-exchange mode configured but STS is not enabled") } if httpClient != nil { @@ -74,10 +136,65 @@ func stsExchangeTokenInContext( TokenType: "Bearer", }) if err != nil { - klog.Errorf("token exchange failed: %v", err) - return ctx + return ctx, fmt.Errorf("built-in STS exchange: %w", err) } + return context.WithValue(ctx, OAuthAuthorizationHeader, "Bearer "+exchangedToken.AccessToken), nil +} - klog.V(4).Info("token exchanged successfully") - return context.WithValue(ctx, OAuthAuthorizationHeader, "Bearer "+exchangedToken.AccessToken) +func strategyBasedTokenExchange( + ctx context.Context, + baseConfig api.BaseConfig, + oidcProvider *oidc.Provider, + httpClient *http.Client, + token string, + strategy string, + cachedConfig *tokenexchange.TargetTokenExchangeConfig, +) (context.Context, error) { + exchanger, ok := tokenexchange.GetTokenExchanger(strategy) + if !ok { + return ctx, fmt.Errorf("token exchange strategy %q not found", strategy) + } + + cfg := cachedConfig + if cfg == nil { + // Build token URL from OIDC provider + var tokenURL string + if oidcProvider != nil { + if endpoint := oidcProvider.Endpoint(); endpoint.TokenURL != "" { + tokenURL = endpoint.TokenURL + } + } + if tokenURL == "" { + return ctx, fmt.Errorf("token exchange failed: no token URL available from OIDC provider") + } + + authStyle := baseConfig.GetStsAuthStyle() + if authStyle == "" { + authStyle = tokenexchange.AuthStyleParams + } + + cfg = &tokenexchange.TargetTokenExchangeConfig{ + TokenURL: tokenURL, + ClientID: baseConfig.GetStsClientId(), + ClientSecret: baseConfig.GetStsClientSecret(), + Audience: baseConfig.GetStsAudience(), + Scopes: baseConfig.GetStsScopes(), + AuthStyle: authStyle, + ClientCertFile: baseConfig.GetStsClientCertFile(), + ClientKeyFile: baseConfig.GetStsClientKeyFile(), + } + if err := cfg.Validate(); err != nil { + return ctx, fmt.Errorf("token exchange config validation: %w", err) + } + } + + if httpClient != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + } + + exchanged, err := exchanger.Exchange(ctx, cfg, token) + if err != nil { + return ctx, fmt.Errorf("token exchange with strategy %q: %w", strategy, err) + } + return context.WithValue(ctx, OAuthAuthorizationHeader, "Bearer "+exchanged.AccessToken), nil } diff --git a/pkg/kubernetes/token_exchange_test.go b/pkg/kubernetes/token_exchange_test.go new file mode 100644 index 000000000..0fa16ac12 --- /dev/null +++ b/pkg/kubernetes/token_exchange_test.go @@ -0,0 +1,86 @@ +package kubernetes + +import ( + "context" + "testing" + + "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/containers/kubernetes-mcp-server/pkg/config" + "github.com/stretchr/testify/suite" +) + +type TokenExchangeRoutingSuite struct { + suite.Suite +} + +func (s *TokenExchangeRoutingSuite) TestDetectClusterAuthMode() { + s.Run("returns passthrough when OAuth is required", func() { + cfg := config.Default() + cfg.RequireOAuth = true + mode := detectClusterAuthMode(cfg) + s.Equal(api.ClusterAuthPassthrough, mode) + }) + + s.Run("returns kubeconfig when OAuth is not required", func() { + cfg := config.Default() + cfg.RequireOAuth = false + mode := detectClusterAuthMode(cfg) + s.Equal(api.ClusterAuthKubeconfig, mode) + }) +} + +func (s *TokenExchangeRoutingSuite) TestStsExchangeTokenInContextRouting() { + s.Run("kubeconfig mode clears OAuth token", func() { + cfg := config.Default() + cfg.ClusterAuthMode = api.ClusterAuthKubeconfig + + ctx := context.WithValue(context.Background(), OAuthAuthorizationHeader, "Bearer original-token") + result, err := stsExchangeTokenInContext(ctx, cfg, nil, nil, "original-token", nil) + s.Require().NoError(err) + + auth, _ := result.Value(OAuthAuthorizationHeader).(string) + s.Equal("", auth) + }) + + s.Run("passthrough mode preserves token", func() { + cfg := config.Default() + cfg.ClusterAuthMode = api.ClusterAuthPassthrough + + ctx := context.Background() + result, err := stsExchangeTokenInContext(ctx, cfg, nil, nil, "original-token", nil) + s.Require().NoError(err) + + auth, _ := result.Value(OAuthAuthorizationHeader).(string) + s.Equal("Bearer original-token", auth) + }) + + s.Run("auto-detect defaults to kubeconfig when OAuth not required", func() { + cfg := config.Default() + cfg.RequireOAuth = false + cfg.ClusterAuthMode = "" // auto-detect + + ctx := context.Background() + result, err := stsExchangeTokenInContext(ctx, cfg, nil, nil, "original-token", nil) + s.Require().NoError(err) + + auth, _ := result.Value(OAuthAuthorizationHeader).(string) + s.Equal("", auth) + }) + + s.Run("auto-detect defaults to passthrough when OAuth required", func() { + cfg := config.Default() + cfg.RequireOAuth = true + cfg.ClusterAuthMode = "" // auto-detect + + ctx := context.Background() + result, err := stsExchangeTokenInContext(ctx, cfg, nil, nil, "original-token", nil) + s.Require().NoError(err) + + auth, _ := result.Value(OAuthAuthorizationHeader).(string) + s.Equal("Bearer original-token", auth) + }) +} + +func TestTokenExchangeRouting(t *testing.T) { + suite.Run(t, new(TokenExchangeRoutingSuite)) +} diff --git a/pkg/mcp/middleware.go b/pkg/mcp/middleware.go index 726d95653..0bc1de07e 100644 --- a/pkg/mcp/middleware.go +++ b/pkg/mcp/middleware.go @@ -46,6 +46,9 @@ func authHeaderPropagationMiddleware(next mcp.MethodHandler) mcp.MethodHandler { return next(context.WithValue(ctx, internalk8s.OAuthAuthorizationHeader, customAuthHeader), method, req) } } + + // If no auth header in RequestExtra, context may already have it from HTTP middleware + // (used by SSE transport where HTTP headers aren't propagated to RequestExtra) return next(ctx, method, req) } } diff --git a/pkg/oauth/state.go b/pkg/oauth/state.go new file mode 100644 index 000000000..3a3e59d7e --- /dev/null +++ b/pkg/oauth/state.go @@ -0,0 +1,131 @@ +package oauth + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "os" + "slices" + "sync/atomic" + + "github.com/containers/kubernetes-mcp-server/pkg/config" + "github.com/coreos/go-oidc/v3/oidc" +) + +// Snapshot is an immutable point-in-time capture of OAuth-related state. +// It is swapped atomically via State so all consumers see a consistent view. +type Snapshot struct { + OIDCProvider *oidc.Provider + HTTPClient *http.Client + AuthorizationURL string + CertificateAuthority string + OAuthScopes []string + DisableDynamicClientRegistration bool +} + +// HasProviderConfigChanged reports whether the fields that require OIDC provider +// and HTTP client recreation have changed between two snapshots. +func (s *Snapshot) HasProviderConfigChanged(other *Snapshot) bool { + if s == nil || other == nil { + return s != other + } + return s.AuthorizationURL != other.AuthorizationURL || + s.CertificateAuthority != other.CertificateAuthority +} + +// HasWellKnownConfigChanged reports whether any WellKnown-serving fields changed. +func (s *Snapshot) HasWellKnownConfigChanged(other *Snapshot) bool { + if s == nil || other == nil { + return s != other + } + if s.HasProviderConfigChanged(other) { + return true + } + if s.DisableDynamicClientRegistration != other.DisableDynamicClientRegistration { + return true + } + if !slices.Equal(s.OAuthScopes, other.OAuthScopes) { + return true + } + return false +} + +// State holds the current OAuth snapshot and allows atomic, lock-free reads. +type State struct { + ref atomic.Pointer[Snapshot] +} + +// NewState creates a new State initialized with the given snapshot. +func NewState(snap *Snapshot) *State { + s := &State{} + s.ref.Store(snap) + return s +} + +// Load returns the current snapshot. Safe for concurrent use. +func (s *State) Load() *Snapshot { + return s.ref.Load() +} + +// Store atomically replaces the current snapshot. +func (s *State) Store(snap *Snapshot) { + s.ref.Store(snap) +} + +// SnapshotFromConfig extracts OAuth-relevant fields from a StaticConfig and +// pairs them with the corresponding OIDC provider and HTTP client. +func SnapshotFromConfig(cfg *config.StaticConfig, provider *oidc.Provider, httpClient *http.Client) *Snapshot { + return &Snapshot{ + OIDCProvider: provider, + HTTPClient: httpClient, + AuthorizationURL: cfg.AuthorizationURL, + CertificateAuthority: cfg.CertificateAuthority, + OAuthScopes: cfg.OAuthScopes, + DisableDynamicClientRegistration: cfg.DisableDynamicClientRegistration, + } +} + +// CreateOIDCProviderAndClient builds an OIDC provider and HTTP client from config. +// Returns (nil, nil, nil) when AuthorizationURL is empty (OAuth not configured). +func CreateOIDCProviderAndClient(cfg *config.StaticConfig) (*oidc.Provider, *http.Client, error) { + if cfg.AuthorizationURL == "" { + return nil, nil, nil + } + + ctx := context.Background() + var httpClient *http.Client + + if cfg.CertificateAuthority != "" { + caCert, err := os.ReadFile(cfg.CertificateAuthority) + if err != nil { + return nil, nil, fmt.Errorf("failed to read CA certificate from %s: %w", cfg.CertificateAuthority, err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, nil, fmt.Errorf("failed to append CA certificate from %s to pool", cfg.CertificateAuthority) + } + if caCertPool.Equal(x509.NewCertPool()) { + caCertPool = nil + } + var transport http.RoundTripper = &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + RootCAs: caCertPool, + }, + } + transport = config.NewTLSEnforcingTransport(transport, cfg.IsRequireTLS) + httpClient = &http.Client{Transport: transport} + } else { + httpClient = config.NewTLSEnforcingClient(nil, cfg.IsRequireTLS) + } + + ctx = oidc.ClientContext(ctx, httpClient) + provider, err := oidc.NewProvider(ctx, cfg.AuthorizationURL) + if err != nil { + return nil, nil, fmt.Errorf("unable to setup OIDC provider: %w", err) + } + + return provider, httpClient, nil +} diff --git a/pkg/oauth/state_test.go b/pkg/oauth/state_test.go new file mode 100644 index 000000000..703a0d73f --- /dev/null +++ b/pkg/oauth/state_test.go @@ -0,0 +1,133 @@ +package oauth + +import ( + "sync" + "testing" + + "github.com/containers/kubernetes-mcp-server/pkg/config" + "github.com/stretchr/testify/suite" +) + +type OAuthStateSuite struct { + suite.Suite +} + +func TestOAuthState(t *testing.T) { + suite.Run(t, new(OAuthStateSuite)) +} + +func (s *OAuthStateSuite) TestHasProviderConfigChanged() { + s.Run("nil snapshots", func() { + var snap *Snapshot + s.True(snap.HasProviderConfigChanged(&Snapshot{})) + s.True((&Snapshot{}).HasProviderConfigChanged(nil)) + s.True((*Snapshot)(nil).HasProviderConfigChanged(&Snapshot{AuthorizationURL: "https://example.com"})) + }) + + s.Run("both nil", func() { + var a, b *Snapshot + s.False(a.HasProviderConfigChanged(b)) + }) + + s.Run("same values", func() { + a := &Snapshot{AuthorizationURL: "https://auth.example.com", CertificateAuthority: "/ca.pem"} + b := &Snapshot{AuthorizationURL: "https://auth.example.com", CertificateAuthority: "/ca.pem"} + s.False(a.HasProviderConfigChanged(b)) + }) + + s.Run("authorization URL changed", func() { + a := &Snapshot{AuthorizationURL: "https://old.example.com"} + b := &Snapshot{AuthorizationURL: "https://new.example.com"} + s.True(a.HasProviderConfigChanged(b)) + }) + + s.Run("certificate authority changed", func() { + a := &Snapshot{CertificateAuthority: "/old-ca.pem"} + b := &Snapshot{CertificateAuthority: "/new-ca.pem"} + s.True(a.HasProviderConfigChanged(b)) + }) + + s.Run("non-provider fields do not trigger change", func() { + a := &Snapshot{AuthorizationURL: "https://auth.example.com", OAuthScopes: []string{"a"}} + b := &Snapshot{AuthorizationURL: "https://auth.example.com", OAuthScopes: []string{"b"}} + s.False(a.HasProviderConfigChanged(b)) + }) +} + +func (s *OAuthStateSuite) TestHasWellKnownConfigChanged() { + s.Run("scopes changed", func() { + a := &Snapshot{OAuthScopes: []string{"openid"}} + b := &Snapshot{OAuthScopes: []string{"openid", "profile"}} + s.True(a.HasWellKnownConfigChanged(b)) + }) + + s.Run("disable dynamic client registration changed", func() { + a := &Snapshot{DisableDynamicClientRegistration: false} + b := &Snapshot{DisableDynamicClientRegistration: true} + s.True(a.HasWellKnownConfigChanged(b)) + }) + + s.Run("provider config change implies wellknown change", func() { + a := &Snapshot{AuthorizationURL: "https://old.example.com"} + b := &Snapshot{AuthorizationURL: "https://new.example.com"} + s.True(a.HasWellKnownConfigChanged(b)) + }) + + s.Run("no changes", func() { + a := &Snapshot{ + AuthorizationURL: "https://auth.example.com", + OAuthScopes: []string{"openid"}, + DisableDynamicClientRegistration: true, + } + b := &Snapshot{ + AuthorizationURL: "https://auth.example.com", + OAuthScopes: []string{"openid"}, + DisableDynamicClientRegistration: true, + } + s.False(a.HasWellKnownConfigChanged(b)) + }) +} + +func (s *OAuthStateSuite) TestStateLoadStore() { + s.Run("load returns stored snapshot", func() { + snap := &Snapshot{AuthorizationURL: "https://auth.example.com"} + state := NewState(snap) + s.Equal(snap, state.Load()) + }) + + s.Run("store replaces snapshot", func() { + snap1 := &Snapshot{AuthorizationURL: "https://old.example.com"} + snap2 := &Snapshot{AuthorizationURL: "https://new.example.com"} + state := NewState(snap1) + state.Store(snap2) + s.Equal(snap2, state.Load()) + }) + + s.Run("concurrent access is safe", func() { + state := NewState(&Snapshot{AuthorizationURL: "https://initial.example.com"}) + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(2) + go func() { + defer wg.Done() + state.Store(&Snapshot{AuthorizationURL: "https://new.example.com"}) + }() + go func() { + defer wg.Done() + snap := state.Load() + s.NotNil(snap) + }() + } + wg.Wait() + }) +} + +func (s *OAuthStateSuite) TestCreateOIDCProviderAndClient() { + s.Run("empty authorization URL returns nil", func() { + cfg := config.Default() + provider, client, err := CreateOIDCProviderAndClient(cfg) + s.NoError(err) + s.Nil(provider) + s.Nil(client) + }) +} diff --git a/pkg/tokenexchange/assertion.go b/pkg/tokenexchange/assertion.go new file mode 100644 index 000000000..774337844 --- /dev/null +++ b/pkg/tokenexchange/assertion.go @@ -0,0 +1,184 @@ +package tokenexchange + +import ( + "crypto" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "os" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/google/uuid" + "k8s.io/klog/v2" +) + +const ( + // ClientAssertionType is the OAuth client assertion type for JWT bearer (RFC 7523) + ClientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + + // FormKeyClientAssertion is the form parameter name for the JWT assertion + FormKeyClientAssertion = "client_assertion" + + // FormKeyClientAssertionType is the form parameter name for the assertion type + FormKeyClientAssertionType = "client_assertion_type" + + // DefaultAssertionLifetime is the default validity period for assertions + DefaultAssertionLifetime = 5 * time.Minute + + // AssertionRefreshMargin is how early to refresh before expiry + AssertionRefreshMargin = 30 * time.Second +) + +// loadCertificateAndKey reads the certificate and private key from PEM files +func loadCertificateAndKey(certFile, keyFile string) (*x509.Certificate, crypto.Signer, error) { + certPEM, err := os.ReadFile(certFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to read certificate file %q: %w", certFile, err) + } + + certBlock, _ := pem.Decode(certPEM) + if certBlock == nil { + return nil, nil, fmt.Errorf("failed to decode PEM block from certificate file %q", certFile) + } + + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse certificate from %q: %w", certFile, err) + } + + keyPEM, err := os.ReadFile(keyFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to read private key file %q: %w", keyFile, err) + } + + keyBlock, _ := pem.Decode(keyPEM) + if keyBlock == nil { + return nil, nil, fmt.Errorf("failed to decode PEM block from private key file %q", keyFile) + } + + var privateKey crypto.Signer + + // Try PKCS8 first (modern format) + if key, err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes); err == nil { + if signer, ok := key.(crypto.Signer); ok { + privateKey = signer + } else { + return nil, nil, fmt.Errorf("private key from %q does not implement crypto.Signer", keyFile) + } + } else if key, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes); err == nil { + // RSA PKCS1 format + privateKey = key + } else { + return nil, nil, fmt.Errorf("failed to parse private key from %q: unsupported format (only RSA keys are supported)", keyFile) + } + + // Only RSA keys are currently supported for JWT client assertions + if _, ok := privateKey.(*rsa.PrivateKey); !ok { + return nil, nil, fmt.Errorf("unsupported key type %T from %q: only RSA keys are currently supported for JWT client assertions", privateKey, keyFile) + } + + return cert, privateKey, nil +} + +// computeX5T computes the x5t (X.509 certificate SHA-1 thumbprint) header value +func computeX5T(cert *x509.Certificate) string { + thumbprint := sha1.Sum(cert.Raw) + return base64.RawURLEncoding.EncodeToString(thumbprint[:]) +} + +// getSignatureAlgorithm determines the jose.SignatureAlgorithm based on key type +func getSignatureAlgorithm(key crypto.Signer) (jose.SignatureAlgorithm, error) { + switch key.(type) { + case *rsa.PrivateKey: + return jose.RS256, nil + default: + return "", fmt.Errorf("unsupported key type: %T (only RSA keys are currently supported)", key) + } +} + +// BuildClientAssertion creates a signed JWT assertion for client authentication +func BuildClientAssertion(clientID, tokenURL, certFile, keyFile string, lifetime time.Duration) (string, time.Time, error) { + if lifetime == 0 { + lifetime = DefaultAssertionLifetime + } + + cert, privateKey, err := loadCertificateAndKey(certFile, keyFile) + if err != nil { + return "", time.Time{}, err + } + + algorithm, err := getSignatureAlgorithm(privateKey) + if err != nil { + return "", time.Time{}, err + } + + now := time.Now() + expiry := now.Add(lifetime) + + claims := jwt.Claims{ + Issuer: clientID, + Subject: clientID, + Audience: jwt.Audience{tokenURL}, + ID: uuid.New().String(), + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(expiry), + } + + signerOpts := jose.SignerOptions{} + signerOpts.WithHeader("x5t", computeX5T(cert)) + signerOpts.WithType("JWT") + + signer, err := jose.NewSigner( + jose.SigningKey{Algorithm: algorithm, Key: privateKey}, + &signerOpts, + ) + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to create JWT signer: %w", err) + } + + signedJWT, err := jwt.Signed(signer).Claims(claims).Serialize() + if err != nil { + return "", time.Time{}, fmt.Errorf("failed to sign JWT assertion: %w", err) + } + + klog.V(4).Infof("Built JWT client assertion: issuer=%s, audience=%s, jti=%s, x5t=%s, expires=%s", + clientID, tokenURL, claims.ID, computeX5T(cert), expiry.Format(time.RFC3339)) + + return signedJWT, expiry, nil +} + +// GetOrBuildAssertion returns a cached assertion or builds a new one +func (c *TargetTokenExchangeConfig) GetOrBuildAssertion() (string, error) { + c.assertionMutex.Lock() + defer c.assertionMutex.Unlock() + + // Check if cached assertion is still valid (with margin) + if c.cachedAssertion != "" && time.Now().Add(AssertionRefreshMargin).Before(c.cachedAssertionExpiry) { + klog.V(4).Infof("Using cached JWT client assertion, expires=%s", c.cachedAssertionExpiry.Format(time.RFC3339)) + return c.cachedAssertion, nil + } + + klog.V(4).Infof("Building new JWT client assertion: client_id=%s, token_url=%s, cert_file=%s", + c.ClientID, c.TokenURL, c.ClientCertFile) + + assertion, expiry, err := BuildClientAssertion( + c.ClientID, + c.TokenURL, + c.ClientCertFile, + c.ClientKeyFile, + c.AssertionLifetime, + ) + if err != nil { + return "", err + } + + c.cachedAssertion = assertion + c.cachedAssertionExpiry = expiry + + return assertion, nil +} diff --git a/pkg/tokenexchange/assertion_test.go b/pkg/tokenexchange/assertion_test.go new file mode 100644 index 000000000..9c9047ca4 --- /dev/null +++ b/pkg/tokenexchange/assertion_test.go @@ -0,0 +1,348 @@ +package tokenexchange + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net/http" + "net/url" + "os" + "path/filepath" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/stretchr/testify/suite" +) + +type AssertionTestSuite struct { + suite.Suite + tempDir string + certFile string + keyFile string +} + +func (s *AssertionTestSuite) SetupTest() { + var err error + s.tempDir, err = os.MkdirTemp("", "assertion-test-*") + s.Require().NoError(err) + + s.certFile, s.keyFile = s.generateTestCertAndKey() +} + +func (s *AssertionTestSuite) TearDownTest() { + _ = os.RemoveAll(s.tempDir) +} + +func (s *AssertionTestSuite) generateTestCertAndKey() (string, string) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + s.Require().NoError(err) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + s.Require().NoError(err) + + certFile := filepath.Join(s.tempDir, "cert.pem") + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + s.Require().NoError(os.WriteFile(certFile, certPEM, 0644)) + + keyFile := filepath.Join(s.tempDir, "key.pem") + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + s.Require().NoError(os.WriteFile(keyFile, keyPEM, 0600)) + + return certFile, keyFile +} + +func (s *AssertionTestSuite) TestBuildClientAssertion() { + s.Run("builds valid JWT with correct claims", func() { + clientID := "test-client-id" + tokenURL := "https://login.microsoftonline.com/tenant/oauth2/v2.0/token" + + assertion, expiry, err := BuildClientAssertion(clientID, tokenURL, s.certFile, s.keyFile, 5*time.Minute) + s.Require().NoError(err) + s.NotEmpty(assertion) + s.True(expiry.After(time.Now())) + + token, err := jwt.ParseSigned(assertion, []jose.SignatureAlgorithm{jose.RS256}) + s.Require().NoError(err) + + var claims jwt.Claims + err = token.UnsafeClaimsWithoutVerification(&claims) + s.Require().NoError(err) + + s.Equal(clientID, claims.Issuer) + s.Equal(clientID, claims.Subject) + s.True(claims.Audience.Contains(tokenURL)) + s.NotEmpty(claims.ID) + }) + + s.Run("includes x5t header", func() { + assertion, _, err := BuildClientAssertion("client", "https://token.url", s.certFile, s.keyFile, 0) + s.Require().NoError(err) + + token, err := jwt.ParseSigned(assertion, []jose.SignatureAlgorithm{jose.RS256}) + s.Require().NoError(err) + + s.Len(token.Headers, 1) + x5t, ok := token.Headers[0].ExtraHeaders["x5t"] + s.True(ok, "x5t header should be present") + s.NotEmpty(x5t) + }) + + s.Run("uses default lifetime when zero", func() { + _, expiry, err := BuildClientAssertion("client", "https://token.url", s.certFile, s.keyFile, 0) + s.Require().NoError(err) + + expectedExpiry := time.Now().Add(DefaultAssertionLifetime) + s.InDelta(expectedExpiry.Unix(), expiry.Unix(), 5) + }) + + s.Run("returns error for missing cert file", func() { + _, _, err := BuildClientAssertion("client", "https://token.url", "/nonexistent/cert.pem", s.keyFile, 0) + s.Error(err) + s.Contains(err.Error(), "failed to read certificate file") + }) + + s.Run("returns error for missing key file", func() { + _, _, err := BuildClientAssertion("client", "https://token.url", s.certFile, "/nonexistent/key.pem", 0) + s.Error(err) + s.Contains(err.Error(), "failed to read private key file") + }) + + s.Run("returns error for invalid cert PEM", func() { + invalidCert := filepath.Join(s.tempDir, "invalid-cert.pem") + s.Require().NoError(os.WriteFile(invalidCert, []byte("not a valid PEM"), 0644)) + + _, _, err := BuildClientAssertion("client", "https://token.url", invalidCert, s.keyFile, 0) + s.Error(err) + s.Contains(err.Error(), "failed to decode PEM block") + }) + + s.Run("returns error for invalid key PEM", func() { + invalidKey := filepath.Join(s.tempDir, "invalid-key.pem") + s.Require().NoError(os.WriteFile(invalidKey, []byte("not a valid PEM"), 0600)) + + _, _, err := BuildClientAssertion("client", "https://token.url", s.certFile, invalidKey, 0) + s.Error(err) + s.Contains(err.Error(), "failed to decode PEM block") + }) +} + +func (s *AssertionTestSuite) TestGetOrBuildAssertion() { + s.Run("caches assertion", func() { + cfg := &TargetTokenExchangeConfig{ + TokenURL: "https://token.url", + ClientID: "test-client", + ClientCertFile: s.certFile, + ClientKeyFile: s.keyFile, + } + + assertion1, err := cfg.GetOrBuildAssertion() + s.Require().NoError(err) + + assertion2, err := cfg.GetOrBuildAssertion() + s.Require().NoError(err) + + s.Equal(assertion1, assertion2, "should return cached assertion") + }) + + s.Run("returns error for missing files", func() { + cfg := &TargetTokenExchangeConfig{ + TokenURL: "https://token.url", + ClientID: "test-client", + ClientCertFile: "/nonexistent/cert.pem", + ClientKeyFile: "/nonexistent/key.pem", + } + + _, err := cfg.GetOrBuildAssertion() + s.Error(err) + }) +} + +func (s *AssertionTestSuite) TestLoadCertificateAndKey() { + s.Run("loads PKCS1 RSA key", func() { + cert, key, err := loadCertificateAndKey(s.certFile, s.keyFile) + s.Require().NoError(err) + s.NotNil(cert) + s.NotNil(key) + }) + + s.Run("loads PKCS8 RSA key", func() { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + s.Require().NoError(err) + + pkcs8Key, err := x509.MarshalPKCS8PrivateKey(privateKey) + s.Require().NoError(err) + + pkcs8KeyFile := filepath.Join(s.tempDir, "pkcs8-key.pem") + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: pkcs8Key}) + s.Require().NoError(os.WriteFile(pkcs8KeyFile, keyPEM, 0600)) + + cert, key, err := loadCertificateAndKey(s.certFile, pkcs8KeyFile) + s.Require().NoError(err) + s.NotNil(cert) + s.NotNil(key) + }) +} + +func (s *AssertionTestSuite) TestComputeX5T() { + certPEM, err := os.ReadFile(s.certFile) + s.Require().NoError(err) + + block, _ := pem.Decode(certPEM) + s.Require().NotNil(block) + + cert, err := x509.ParseCertificate(block.Bytes) + s.Require().NoError(err) + + x5t := computeX5T(cert) + s.NotEmpty(x5t) + s.Len(x5t, 27) // Base64url encoded SHA-1 (20 bytes) = 27 chars without padding +} + +func (s *AssertionTestSuite) TestValidate() { + s.Run("valid assertion config", func() { + cfg := &TargetTokenExchangeConfig{ + AuthStyle: AuthStyleAssertion, + ClientCertFile: "/path/to/cert.pem", + ClientKeyFile: "/path/to/key.pem", + } + err := cfg.Validate() + s.NoError(err) + }) + + s.Run("assertion requires cert file", func() { + cfg := &TargetTokenExchangeConfig{ + AuthStyle: AuthStyleAssertion, + ClientKeyFile: "/path/to/key.pem", + } + err := cfg.Validate() + s.Error(err) + s.Contains(err.Error(), "client_cert_file is required") + }) + + s.Run("assertion requires key file", func() { + cfg := &TargetTokenExchangeConfig{ + AuthStyle: AuthStyleAssertion, + ClientCertFile: "/path/to/cert.pem", + } + err := cfg.Validate() + s.Error(err) + s.Contains(err.Error(), "client_key_file is required") + }) + + s.Run("params style is valid", func() { + cfg := &TargetTokenExchangeConfig{ + AuthStyle: AuthStyleParams, + } + err := cfg.Validate() + s.NoError(err) + }) + + s.Run("header style is valid", func() { + cfg := &TargetTokenExchangeConfig{ + AuthStyle: AuthStyleHeader, + } + err := cfg.Validate() + s.NoError(err) + }) + + s.Run("invalid auth style", func() { + cfg := &TargetTokenExchangeConfig{ + AuthStyle: "invalid", + } + err := cfg.Validate() + s.Error(err) + s.Contains(err.Error(), "invalid auth_style") + }) +} + +func (s *AssertionTestSuite) TestLoadCertificateAndKeyRejectsECKeys() { + s.Run("rejects EC private key in SEC1 format", func() { + ecKeyFile := s.generateECKeyFile("EC PRIVATE KEY", false) + _, _, err := loadCertificateAndKey(s.certFile, ecKeyFile) + s.Error(err) + s.Contains(err.Error(), "unsupported format") + }) + + s.Run("rejects EC private key in PKCS8 format", func() { + ecKeyFile := s.generateECKeyFile("PRIVATE KEY", true) + _, _, err := loadCertificateAndKey(s.certFile, ecKeyFile) + s.Error(err) + s.Contains(err.Error(), "unsupported key type") + }) +} + +func (s *AssertionTestSuite) generateECKeyFile(pemType string, pkcs8 bool) string { + ecKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + s.Require().NoError(err) + + var keyBytes []byte + if pkcs8 { + keyBytes, err = x509.MarshalPKCS8PrivateKey(ecKey) + s.Require().NoError(err) + } else { + keyBytes, err = x509.MarshalECPrivateKey(ecKey) + s.Require().NoError(err) + } + + ecKeyFile := filepath.Join(s.tempDir, "ec-key.pem") + ecPEM := pem.EncodeToMemory(&pem.Block{Type: pemType, Bytes: keyBytes}) + s.Require().NoError(os.WriteFile(ecKeyFile, ecPEM, 0600)) + return ecKeyFile +} + +func (s *AssertionTestSuite) TestInjectClientAuthWithAssertion() { + s.Run("sets client assertion form fields", func() { + cfg := &TargetTokenExchangeConfig{ + TokenURL: "https://login.microsoftonline.com/tenant/oauth2/v2.0/token", + ClientID: "test-client", + AuthStyle: AuthStyleAssertion, + ClientCertFile: s.certFile, + ClientKeyFile: s.keyFile, + } + + data := url.Values{} + header := http.Header{} + err := injectClientAuth(cfg, data, header) + s.Require().NoError(err) + + s.Equal("test-client", data.Get(FormKeyClientID)) + s.Equal(ClientAssertionType, data.Get(FormKeyClientAssertionType)) + s.NotEmpty(data.Get(FormKeyClientAssertion), "client_assertion should be set") + s.Empty(data.Get(FormKeyClientSecret), "client_secret should not be set for assertion auth") + s.Empty(header.Get(HeaderAuthorization), "Authorization header should not be set for assertion auth") + }) + + s.Run("returns error for invalid cert files", func() { + cfg := &TargetTokenExchangeConfig{ + TokenURL: "https://token.url", + ClientID: "test-client", + AuthStyle: AuthStyleAssertion, + ClientCertFile: "/nonexistent/cert.pem", + ClientKeyFile: "/nonexistent/key.pem", + } + + data := url.Values{} + header := http.Header{} + err := injectClientAuth(cfg, data, header) + s.Error(err) + s.Contains(err.Error(), "failed to build client assertion") + }) +} + +func TestAssertion(t *testing.T) { + suite.Run(t, new(AssertionTestSuite)) +} diff --git a/pkg/tokenexchange/config.go b/pkg/tokenexchange/config.go index 8382d5388..7ab6f8d2e 100644 --- a/pkg/tokenexchange/config.go +++ b/pkg/tokenexchange/config.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "os" + "sync" "time" ) @@ -14,6 +15,8 @@ const ( AuthStyleParams = "params" // AuthStyleHeader sends client credentials as HTTP Basic Authentication header AuthStyleHeader = "header" + // AuthStyleAssertion sends a signed JWT client assertion (RFC 7523) + AuthStyleAssertion = "assertion" ) // TargetTokenExchangeConfig holds per-target token exchange configuration @@ -43,21 +46,47 @@ type TargetTokenExchangeConfig struct { // AuthStyle specifies how client credentials are sent to the token endpoint // "params" (default): client_id/secret in request body // "header": HTTP Basic Authentication header + // "assertion": JWT client assertion (RFC 7523) AuthStyle string `toml:"auth_style,omitempty"` + // ClientCertFile is the path to the client certificate PEM file + // Used with AuthStyleAssertion for JWT client assertion authentication + ClientCertFile string `toml:"client_cert_file,omitempty"` + // ClientKeyFile is the path to the client private key PEM file + // Used with AuthStyleAssertion for JWT client assertion authentication + ClientKeyFile string `toml:"client_key_file,omitempty"` + // AssertionLifetime is the validity duration for generated JWT assertions + // Defaults to 5 minutes if not specified + AssertionLifetime time.Duration `toml:"assertion_lifetime,omitempty"` // client is a http client configured to work with the IdP for this target client *http.Client `toml:"-"` + // cachedAssertion stores the most recently generated JWT assertion + cachedAssertion string `toml:"-"` + // cachedAssertionExpiry is when the cached assertion expires + cachedAssertionExpiry time.Time `toml:"-"` + // assertionMutex protects assertion caching from race conditions + assertionMutex sync.Mutex `toml:"-"` } // Validate checks that the configuration values are valid func (c *TargetTokenExchangeConfig) Validate() error { - if c.AuthStyle != "" && c.AuthStyle != AuthStyleParams && c.AuthStyle != AuthStyleHeader { - return fmt.Errorf("invalid auth_style %q: must be %q or %q", c.AuthStyle, AuthStyleParams, AuthStyleHeader) + switch c.AuthStyle { + case "", AuthStyleParams, AuthStyleHeader: + // valid + case AuthStyleAssertion: + if c.ClientCertFile == "" { + return fmt.Errorf("client_cert_file is required when auth_style is %q", AuthStyleAssertion) + } + if c.ClientKeyFile == "" { + return fmt.Errorf("client_key_file is required when auth_style is %q", AuthStyleAssertion) + } + default: + return fmt.Errorf("invalid auth_style %q: must be %q, %q, or %q", c.AuthStyle, AuthStyleParams, AuthStyleHeader, AuthStyleAssertion) } return nil } -func (c *TargetTokenExchangeConfig) HTTPCLient() (*http.Client, error) { +func (c *TargetTokenExchangeConfig) HTTPClient() (*http.Client, error) { if c.client != nil { return c.client, nil } diff --git a/pkg/tokenexchange/entra_obo_exchanger.go b/pkg/tokenexchange/entra_obo_exchanger.go new file mode 100644 index 000000000..9c3134ba7 --- /dev/null +++ b/pkg/tokenexchange/entra_obo_exchanger.go @@ -0,0 +1,54 @@ +package tokenexchange + +import ( + "context" + "net/http" + "net/url" + "strings" + + "golang.org/x/oauth2" +) + +const ( + StrategyEntraOBO = "entra-obo" + + // Entra ID OBO-specific constants + GrantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" + FormKeyAssertion = "assertion" + FormKeyRequestedUse = "requested_token_use" + RequestedTokenUseOBO = "on_behalf_of" +) + +// entraOBOExchanger implements the Entra ID On-Behalf-Of flow. +// This is used when the MCP server needs to exchange a user's token for a token +// that can access downstream APIs (like Kubernetes) on behalf of that user. +// +// See: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-on-behalf-of-flow +type entraOBOExchanger struct{} + +var _ TokenExchanger = &entraOBOExchanger{} + +func (e *entraOBOExchanger) Exchange(ctx context.Context, cfg *TargetTokenExchangeConfig, subjectToken string) (*oauth2.Token, error) { + httpClient, err := cfg.HTTPClient() + if err != nil { + return nil, err + } + + data := url.Values{} + data.Set(FormKeyGrantType, GrantTypeJWTBearer) + data.Set(FormKeyAssertion, subjectToken) + data.Set(FormKeyRequestedUse, RequestedTokenUseOBO) + + if len(cfg.Scopes) > 0 { + data.Set(FormKeyScope, strings.Join(cfg.Scopes, " ")) + } else if cfg.Audience != "" { + data.Set(FormKeyScope, cfg.Audience) + } + + headers := make(http.Header) + if err := injectClientAuth(cfg, data, headers); err != nil { + return nil, err + } + + return doTokenExchange(ctx, httpClient, cfg.TokenURL, data, headers) +} diff --git a/pkg/tokenexchange/entra_obo_exchanger_test.go b/pkg/tokenexchange/entra_obo_exchanger_test.go new file mode 100644 index 000000000..ed3fcf553 --- /dev/null +++ b/pkg/tokenexchange/entra_obo_exchanger_test.go @@ -0,0 +1,101 @@ +package tokenexchange + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" +) + +type EntraOBOExchangerTestSuite struct { + suite.Suite +} + +func (s *EntraOBOExchangerTestSuite) TestExchange() { + s.Run("successful token exchange", func() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.Equal(http.MethodPost, r.Method) + s.Equal("application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + + err := r.ParseForm() + s.Require().NoError(err) + + s.Equal(GrantTypeJWTBearer, r.Form.Get(FormKeyGrantType)) + s.Equal("incoming-token", r.Form.Get(FormKeyAssertion)) + s.Equal(RequestedTokenUseOBO, r.Form.Get(FormKeyRequestedUse)) + s.Equal("test-client", r.Form.Get(FormKeyClientID)) + s.Equal("test-secret", r.Form.Get(FormKeyClientSecret)) + s.Equal("api://target/.default", r.Form.Get(FormKeyScope)) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "exchanged-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer server.Close() + + exchanger := &entraOBOExchanger{} + cfg := &TargetTokenExchangeConfig{ + TokenURL: server.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + Scopes: []string{"api://target/.default"}, + } + + token, err := exchanger.Exchange(context.Background(), cfg, "incoming-token") + s.Require().NoError(err) + s.Equal("exchanged-token", token.AccessToken) + s.Equal("Bearer", token.TokenType) + }) + + s.Run("uses audience as scope when scopes empty", func() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + s.Equal("api://kubernetes/.default", r.Form.Get(FormKeyScope)) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "token", + "token_type": "Bearer", + }) + })) + defer server.Close() + + exchanger := &entraOBOExchanger{} + cfg := &TargetTokenExchangeConfig{ + TokenURL: server.URL, + Audience: "api://kubernetes/.default", + } + + token, err := exchanger.Exchange(context.Background(), cfg, "incoming-token") + s.Require().NoError(err) + s.NotEmpty(token.AccessToken) + }) + + s.Run("returns error on failed exchange", func() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error": "invalid_grant"}`)) + })) + defer server.Close() + + exchanger := &entraOBOExchanger{} + cfg := &TargetTokenExchangeConfig{ + TokenURL: server.URL, + } + + token, err := exchanger.Exchange(context.Background(), cfg, "bad-token") + s.Error(err) + s.Nil(token) + s.Contains(err.Error(), "400") + }) +} + +func TestEntraOBOExchanger(t *testing.T) { + suite.Run(t, new(EntraOBOExchangerTestSuite)) +} diff --git a/pkg/tokenexchange/exchanger.go b/pkg/tokenexchange/exchanger.go index f72c46532..6d02fc6e3 100644 --- a/pkg/tokenexchange/exchanger.go +++ b/pkg/tokenexchange/exchanger.go @@ -48,21 +48,30 @@ type TokenExchanger interface { } // injectClientAuth adds client credentials to the request based on auth style -func injectClientAuth(cfg *TargetTokenExchangeConfig, data url.Values, header http.Header) { +func injectClientAuth(cfg *TargetTokenExchangeConfig, data url.Values, header http.Header) error { if cfg.ClientID == "" { - return + return nil } switch cfg.AuthStyle { case AuthStyleHeader: credentials := cfg.ClientID + ":" + cfg.ClientSecret header.Set(HeaderAuthorization, "Basic "+base64.StdEncoding.EncodeToString([]byte(credentials))) + case AuthStyleAssertion: + assertion, err := cfg.GetOrBuildAssertion() + if err != nil { + return fmt.Errorf("failed to build client assertion: %w", err) + } + data.Set(FormKeyClientID, cfg.ClientID) + data.Set(FormKeyClientAssertionType, ClientAssertionType) + data.Set(FormKeyClientAssertion, assertion) default: // AuthStyleParams or empty (default) data.Set(FormKeyClientID, cfg.ClientID) if cfg.ClientSecret != "" { data.Set(FormKeyClientSecret, cfg.ClientSecret) } } + return nil } // tokenExchangeResponse represents the OAuth token exchange response diff --git a/pkg/tokenexchange/keycloak_v1_exchanger.go b/pkg/tokenexchange/keycloak_v1_exchanger.go index 53fd98db7..1a1e7d433 100644 --- a/pkg/tokenexchange/keycloak_v1_exchanger.go +++ b/pkg/tokenexchange/keycloak_v1_exchanger.go @@ -16,7 +16,7 @@ type keycloakV1Exchanger struct{} var _ TokenExchanger = &keycloakV1Exchanger{} func (e *keycloakV1Exchanger) Exchange(ctx context.Context, cfg *TargetTokenExchangeConfig, subjectToken string) (*oauth2.Token, error) { - httpClient, err := cfg.HTTPCLient() + httpClient, err := cfg.HTTPClient() if err != nil { return nil, fmt.Errorf("failed to acquire http client to talk to IdP for target: %w", err) } @@ -36,7 +36,9 @@ func (e *keycloakV1Exchanger) Exchange(ctx context.Context, cfg *TargetTokenExch } headers := http.Header{} - injectClientAuth(cfg, data, headers) + if err := injectClientAuth(cfg, data, headers); err != nil { + return nil, err + } return doTokenExchange(ctx, httpClient, cfg.TokenURL, data, headers) } diff --git a/pkg/tokenexchange/registry.go b/pkg/tokenexchange/registry.go index 4bff301c1..903a0bb6b 100644 --- a/pkg/tokenexchange/registry.go +++ b/pkg/tokenexchange/registry.go @@ -7,6 +7,7 @@ var ( func init() { RegisterTokenExchanger(StrategyKeycloakV1, &keycloakV1Exchanger{}) RegisterTokenExchanger(StrategyRFC8693, &rfc8693Exchanger{}) + RegisterTokenExchanger(StrategyEntraOBO, &entraOBOExchanger{}) } func RegisterTokenExchanger(strategy string, exchanger TokenExchanger) { diff --git a/pkg/tokenexchange/registry_test.go b/pkg/tokenexchange/registry_test.go index 06ec0d972..edb7de7e4 100644 --- a/pkg/tokenexchange/registry_test.go +++ b/pkg/tokenexchange/registry_test.go @@ -21,6 +21,11 @@ func (s *TokenExchangerRegistryTestSuite) TestGetTokenExchanger() { s.True(ok, "Expected rfc8693 exchanger to be registered") s.NotNil(exchanger, "Expected rfc8693 exchanger to be non-nil") }) + s.Run("returns entra-obo exchanger", func() { + exchanger, ok := GetTokenExchanger(StrategyEntraOBO) + s.True(ok, "Expected entra-obo exchanger to be registered") + s.NotNil(exchanger, "Expected entra-obo exchanger to be non-nil") + }) s.Run("returns false for unregistered strategy", func() { exchanger, ok := GetTokenExchanger("non-existent") s.False(ok, "Expected false for non-existent strategy") diff --git a/pkg/tokenexchange/rfc8693_exchanger.go b/pkg/tokenexchange/rfc8693_exchanger.go index f7ef2e6b1..bc78fad1f 100644 --- a/pkg/tokenexchange/rfc8693_exchanger.go +++ b/pkg/tokenexchange/rfc8693_exchanger.go @@ -15,7 +15,7 @@ type rfc8693Exchanger struct{} var _ TokenExchanger = &rfc8693Exchanger{} func (e *rfc8693Exchanger) Exchange(ctx context.Context, cfg *TargetTokenExchangeConfig, subjectToken string) (*oauth2.Token, error) { - httpClient, err := cfg.HTTPCLient() + httpClient, err := cfg.HTTPClient() if err != nil { return nil, fmt.Errorf("failed to acquire http client to talk to IdP for target: %w", err) } @@ -32,7 +32,9 @@ func (e *rfc8693Exchanger) Exchange(ctx context.Context, cfg *TargetTokenExchang } headers := http.Header{} - injectClientAuth(cfg, data, headers) + if err := injectClientAuth(cfg, data, headers); err != nil { + return nil, err + } return doTokenExchange(ctx, httpClient, cfg.TokenURL, data, headers) }