Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 61 additions & 47 deletions pkg/api/handlers/mcpgateway/oauth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,76 +460,90 @@ func (h *handler) doTokenExchange(req api.Context, oauthClient v1.OAuthClient, r
})
}

if mcpServer.Spec.Manifest.Runtime == types.RuntimeComposite {
_, componentMCPID, ok := strings.Cut(resource, "/mcp-connect/")
audienceID := componentMCPID

var (
token string
expiresAt time.Time
)

if ok {
if system.IsMCPServerInstanceID(componentMCPID) {
_, resourceMCPServerID, isConnectURL := strings.Cut(resource, "/mcp-connect/")

var (
token string
expiresAt time.Time
)

// If the resource is an internal MCP connect URL, then process it here and don't try to get OAuth credentials.
if isConnectURL || mcpServer.Spec.Manifest.Runtime == types.RuntimeComposite {
Comment thread
thedadams marked this conversation as resolved.
Outdated
audienceID := resourceMCPServerID
if system.IsSystemMCPServerID(audienceID) {
var systemMCPServer v1.SystemMCPServer
if err := req.Get(&systemMCPServer, audienceID); err != nil {
return types.NewErrBadRequest("%v", Error{
Code: ErrInvalidRequest,
Description: "failed to retrieve system MCP server " + audienceID,
})
}
} else if mcpServer.Spec.Manifest.Runtime == types.RuntimeComposite {
if !isConnectURL {
// No component MCP ID in resource, return the original token
token = subjectToken
if tokenCtx != nil {
expiresAt = tokenCtx.ExpiresAt
} else if apiKeyExpiresAt != nil {
expiresAt = *apiKeyExpiresAt
} else {
expiresAt = time.Now().Add(tokenExpiration)
}
} else if system.IsMCPServerInstanceID(resourceMCPServerID) {
// Ensure this MCP server instance belongs to this composite MCP server.
var component v1.MCPServerInstance
if err := req.Get(&component, componentMCPID); err != nil || component.Spec.CompositeName != mcpServer.Name {
if err := req.Get(&component, resourceMCPServerID); err != nil || component.Spec.CompositeName != mcpServer.Name {
return types.NewErrBadRequest("%v", Error{
Code: ErrInvalidRequest,
Description: "failed to retrieve composite MCP server " + componentMCPID,
Description: "failed to retrieve composite MCP server " + resourceMCPServerID,
})
}

audienceID = component.Spec.MCPServerName
} else {
// Ensure this MCP server belongs to this composite MCP server.
var component v1.MCPServer
if err := req.Get(&component, componentMCPID); err != nil || component.Spec.CompositeName != mcpServer.Name {
if err := req.Get(&component, resourceMCPServerID); err != nil || component.Spec.CompositeName != mcpServer.Name {
return types.NewErrBadRequest("%v", Error{
Code: ErrInvalidRequest,
Description: "failed to retrieve composite MCP server " + componentMCPID,
})
}
}

if subjectTokenType == tokenTypeAPIKey {
// Pass the API key through to component servers for their own token exchange
token = subjectToken
expiresAt = time.Now().Add(tokenExpiration)
if apiKeyExpiresAt != nil {
expiresAt = *apiKeyExpiresAt
}
} else {
// For JWTs, update the existing token context and create a new token
tokenCtx.MCPID = componentMCPID
tokenCtx.Audience = fmt.Sprintf("%s/mcp-connect/%s", h.baseURL, audienceID)
expiresAt = tokenCtx.ExpiresAt

var err error
token, err = h.tokenService.NewToken(req.Context(), *tokenCtx)
if err != nil {
log.Errorf("failed to create token for component MCP server %s: %v", componentMCPID, err)
return types.NewErrBadRequest("%v", Error{
Code: ErrServerError,
Description: "failed to create token",
Description: "failed to retrieve composite MCP server " + resourceMCPServerID,
})
}
}
} else {
// No component MCP ID in resource, return the original token
return types.NewErrBadRequest("%v", Error{
Code: ErrInvalidRequest,
Description: "invalid resource for token exchange",
})
}

if subjectTokenType == tokenTypeAPIKey {
// Pass the API key through to component servers for their own token exchange
token = subjectToken
if tokenCtx != nil {
expiresAt = tokenCtx.ExpiresAt
} else if apiKeyExpiresAt != nil {
expiresAt = time.Now().Add(tokenExpiration)
if apiKeyExpiresAt != nil {
expiresAt = *apiKeyExpiresAt
} else {
expiresAt = time.Now().Add(tokenExpiration)
}
} else if isConnectURL {
// For JWTs, update the existing token context and create a new token
tokenCtx.MCPID = resourceMCPServerID
tokenCtx.Audience = fmt.Sprintf("%s/mcp-connect/%s", h.baseURL, audienceID)
expiresAt = tokenCtx.ExpiresAt

var err error
token, err = h.tokenService.NewToken(req.Context(), *tokenCtx)
if err != nil {
log.Errorf("failed to create token for component MCP server %s: %v", resourceMCPServerID, err)
return types.NewErrBadRequest("%v", Error{
Code: ErrServerError,
Description: "failed to create token",
})
}
}

// For composite MCP servers, return the token.
// This ensures it gets passed to the component MCP servers so they can do token exchange.
log.Infof("Issued token-exchange response for composite MCP server: client=%s mcpID=%s audienceResource=%s subjectTokenType=%s", oauthClient.Name, mcpID, resource, subjectTokenType)
log.Infof("Issued token-exchange response for MCP server for internal connect URL: client=%s mcpID=%s audienceResource=%s subjectTokenType=%s", oauthClient.Name, mcpID, resource, subjectTokenType)
return req.Write(TokenExchangeResponse{
AccessToken: token,
IssuedTokenType: tokenTypeAccessToken,
Expand All @@ -539,7 +553,7 @@ func (h *handler) doTokenExchange(req api.Context, oauthClient v1.OAuthClient, r
}
} else if system.IsMCPServerInstanceID(mcpID) {
return types.NewErrNotFound("no token exchange for %s", resource)
} else if system.IsSystemMCPServerID(mcpID) {
} else if mcpID == system.ObotMCPServerName {
// Return a new token that represents the user, so that the SystemMCPServer can make API calls to Obot on behalf of the user.
// Preserve the user's existing groups/roles when available from the subject token,
// otherwise look up the user to determine their role.
Expand Down
46 changes: 26 additions & 20 deletions pkg/controller/handlers/systemmcpserver/systemmcpserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"maps"
"strings"
"time"

Expand Down Expand Up @@ -66,7 +67,7 @@ func (h *Handler) EnsureSecretInfo(req router.Request, _ router.Response) error
}
}

secretCredToolName := secretInfoToolName(systemServer.Name)
secretCredToolName := SecretInfoToolName(systemServer.Name)

if systemServer.Status.AuditLogTokenHash != "" {
cred, err := h.gptClient.RevealCredential(req.Ctx, []string{systemServer.Name}, secretCredToolName)
Expand Down Expand Up @@ -168,7 +169,7 @@ func (h *Handler) EnsureDeployment(req router.Request, _ router.Response) error
return fmt.Errorf("failed to list credentials: %w", err)
}

secretToolName := secretInfoToolName(systemServer.Name)
secretToolName := SecretInfoToolName(systemServer.Name)
credEnv := make(map[string]string)
for _, cred := range creds {
// Skip the secret info credential — those vars go to the shim only, not the MCP server.
Expand All @@ -180,9 +181,8 @@ func (h *Handler) EnsureDeployment(req router.Request, _ router.Response) error
if err != nil {
continue
}
for k, v := range credDetail.Env {
credEnv[k] = v
}

maps.Copy(credEnv, credDetail.Env)
}

// Retrieve the token exchange credential
Expand Down Expand Up @@ -239,12 +239,23 @@ func (h *Handler) EnsureDeployment(req router.Request, _ router.Response) error
// CleanupDeployment handles cleanup when SystemMCPServer is deleted
func (h *Handler) CleanupDeployment(req router.Request, _ router.Response) error {
systemServer := req.Object.(*v1.SystemMCPServer)
creds, err := h.gptClient.ListCredentials(req.Ctx, gptscript.ListCredentialsOptions{
CredentialContexts: []string{systemServer.Name},
})
if err != nil {
return fmt.Errorf("failed to list credentials for %s system server cleanup: %w", systemServer.Name, err)
}

for _, cred := range creds {
if err := h.gptClient.DeleteCredential(req.Ctx, cred.Context, cred.ToolName); err != nil && !errors.As(err, &gptscript.ErrNotFound{}) {
return fmt.Errorf("failed to delete credential %s: %w", cred.ToolName, err)
}
}

// Shutdown deployment via backend
// The backend's shutdownServer will remove the deployment (Docker container or K8s deployment)
err := h.mcpSessionManager.ShutdownServer(req.Ctx, systemServer.Name)
if err != nil {
return fmt.Errorf("failed to shutdown system MCP server: %w", err)
if err = h.mcpSessionManager.ShutdownServer(req.Ctx, systemServer.Name); err != nil {
return fmt.Errorf("failed to shutdown system MCP server %s: %w", systemServer.Name, err)
}

return nil
Expand All @@ -253,29 +264,28 @@ func (h *Handler) CleanupDeployment(req router.Request, _ router.Response) error
// isSystemServerConfigured checks if all required configuration is present
func isSystemServerConfigured(ctx context.Context, gptClient *gptscript.GPTScript, server v1.SystemMCPServer) bool {
// Check if all required env vars are configured
credCtx := server.Name
credCtx := []string{server.Name}
creds, err := gptClient.ListCredentials(ctx, gptscript.ListCredentialsOptions{
CredentialContexts: []string{credCtx},
CredentialContexts: credCtx,
})
if err != nil {
log.Infof("Failed to list credentials for system MCP server %s configuration check: %v",
server.Name, err)
return false
}

secretToolName := secretInfoToolName(server.Name)
secretToolName := SecretInfoToolName(server.Name)
credEnv := make(map[string]string)
for _, cred := range creds {
if cred.ToolName == secretToolName {
continue
}
credDetail, err := gptClient.RevealCredential(ctx, []string{credCtx}, cred.ToolName)
credDetail, err := gptClient.RevealCredential(ctx, credCtx, cred.ToolName)
if err != nil {
continue
}
for k, v := range credDetail.Env {
credEnv[k] = v
}

maps.Copy(credEnv, credDetail.Env)
}

for _, env := range server.Spec.Manifest.Env {
Expand All @@ -289,12 +299,8 @@ func isSystemServerConfigured(ctx context.Context, gptClient *gptscript.GPTScrip
return true
}

func secretInfoToolName(serverName string) string {
return serverName + "-secret-info"
}

// SecretInfoToolName returns the credential toolName used to store token exchange secrets
// for the given system MCP server. Exported for use by API handlers.
func SecretInfoToolName(serverName string) string {
return secretInfoToolName(serverName)
return serverName + "-secret-info"
}
59 changes: 0 additions & 59 deletions pkg/controller/mcpwebhookvalidation/cleanupresources.go

This file was deleted.

Loading
Loading