Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 73 additions & 33 deletions pkg/api/handlers/mcpgateway/oauth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package oauth

import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
Expand All @@ -21,6 +22,7 @@ import (
"github.com/obot-platform/obot/pkg/storage/selectors"
"github.com/obot-platform/obot/pkg/system"
"golang.org/x/crypto/bcrypt"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -460,60 +462,42 @@ func (h *handler) doTokenExchange(req api.Context, oauthClient v1.OAuthClient, r
})
}

_, resourceMCPID, isConnectURL := strings.Cut(resource, "/mcp-connect/")
if mcpServer.Spec.Manifest.Runtime == types.RuntimeComposite {
_, componentMCPID, ok := strings.Cut(resource, "/mcp-connect/")
audienceID := componentMCPID
audienceID := resourceMCPID

var (
token string
expiresAt time.Time
err error
)

if ok {
if system.IsMCPServerInstanceID(componentMCPID) {
if isConnectURL {
if system.IsMCPServerInstanceID(resourceMCPID) {
// Ensure this MCP server instance belongs to this composite MCP server.
var component v1.MCPServerInstance
if err := req.Get(&component, componentMCPID); err != nil || component.Spec.CompositeName != mcpServer.Name {
if err := req.Get(&component, resourceMCPID); err != nil || component.Spec.CompositeName != mcpServer.Name {
return types.NewErrBadRequest("%v", Error{
Code: ErrInvalidRequest,
Description: "failed to retrieve composite MCP server " + componentMCPID,
Description: "failed to retrieve composite MCP server " + resourceMCPID,
})
}

audienceID = component.Spec.MCPServerName
} else {
// Ensure this MCP server belongs to this composite MCP server.
var component v1.MCPServer
if err := req.Get(&component, componentMCPID); err != nil || component.Spec.CompositeName != mcpServer.Name {
if err := req.Get(&component, resourceMCPID); err != nil || component.Spec.CompositeName != mcpServer.Name {
return types.NewErrBadRequest("%v", Error{
Code: ErrInvalidRequest,
Description: "failed to retrieve composite MCP server " + componentMCPID,
Description: "failed to retrieve composite MCP server " + resourceMCPID,
})
}
}

if subjectTokenType == tokenTypeAPIKey {
// Pass the API key through to component servers for their own token exchange
token = subjectToken
expiresAt = time.Now().Add(tokenExpiration)
if apiKeyExpiresAt != nil {
expiresAt = *apiKeyExpiresAt
}
} else {
// For JWTs, update the existing token context and create a new token
tokenCtx.MCPID = componentMCPID
tokenCtx.Audience = fmt.Sprintf("%s/mcp-connect/%s", h.baseURL, audienceID)
expiresAt = tokenCtx.ExpiresAt

var err error
token, err = h.tokenService.NewToken(req.Context(), *tokenCtx)
if err != nil {
log.Errorf("failed to create token for component MCP server %s: %v", componentMCPID, err)
return types.NewErrBadRequest("%v", Error{
Code: ErrServerError,
Description: "failed to create token",
})
}
token, expiresAt, err = h.getTokenForConnectResource(req.Context(), subjectTokenType, subjectToken, apiKeyExpiresAt, tokenCtx, resourceMCPID, audienceID)
if err != nil {
return err
}
} else {
// No component MCP ID in resource, return the original token
Expand All @@ -537,10 +521,37 @@ func (h *handler) doTokenExchange(req api.Context, oauthClient v1.OAuthClient, r
ExpiresIn: max(int(time.Until(expiresAt).Seconds()), 0),
})
}

// If this is an MCP server (validated by the IsMCPServerID check above), and it is trying to call
// a webhook system MCP server, then return a valid token for that system MCP server after validating
// that the system MCP server exists.
// If the server doesn't exist, then let the logic fall-through to the normal token exchange logic.
if isConnectURL && system.IsWebhookSystemMCPServerID(resourceMCPID) {
var systemMCPServer v1.SystemMCPServer
if err := req.Get(&systemMCPServer, resourceMCPID); err == nil {
token, expiresAt, err := h.getTokenForConnectResource(req.Context(), subjectTokenType, subjectToken, apiKeyExpiresAt, tokenCtx, resourceMCPID, resourceMCPID)
if err != nil {
return err
}

log.Infof("Issued token-exchange response for webhook system MCP server: client=%s mcpID=%s audienceResource=%s subjectTokenType=%s", oauthClient.Name, mcpID, resource, subjectTokenType)
return req.Write(TokenExchangeResponse{
AccessToken: token,
IssuedTokenType: tokenTypeAccessToken,
TokenType: "Bearer",
ExpiresIn: max(int(time.Until(expiresAt).Seconds()), 0),
})
} else if !apierrors.IsNotFound(err) {
return Error{
Code: ErrInvalidRequest,
Description: fmt.Sprintf("failed to retrieve system MCP server %s: %v", resourceMCPID, err),
}
}
}
} else if system.IsMCPServerInstanceID(mcpID) {
return types.NewErrNotFound("no token exchange for %s", resource)
} else if system.IsSystemMCPServerID(mcpID) {
// Return a new token that represents the user, so that the SystemMCPServer can make API calls to Obot on behalf of the user.
} else if mcpID == system.ObotMCPServerName {
// Return a new token that represents the user, so that the Obot MCP server can make API calls to Obot on behalf of the user.
// Preserve the user's existing groups/roles when available from the subject token,
// otherwise look up the user to determine their role.
var userGroups []string
Expand Down Expand Up @@ -617,11 +628,40 @@ func (h *handler) doTokenExchange(req api.Context, oauthClient v1.OAuthClient, r
})
}

// getTokenForConnectResource handles the special case of token exchange for /mcp-connect/{resourceMCPID} resources.
// It returns the same API key if the subject token is an API key, or creates a new token with the appropriate audience if the subject token is a JWT.
func (h *handler) getTokenForConnectResource(ctx context.Context, subjectTokenType, subjectToken string, apiKeyExpiresAt *time.Time, tokenCtx *persistent.TokenContext, resourceMCPID, audienceID string) (string, time.Time, error) {
if subjectTokenType == tokenTypeAPIKey {
// Pass the API key through to component servers for their own token exchange.
expiresAt := time.Now().Add(tokenExpiration)
if apiKeyExpiresAt != nil {
expiresAt = *apiKeyExpiresAt
}

return subjectToken, expiresAt, nil
}

// For JWTs, update the existing token context and create a new token.
tokenCtx.MCPID = resourceMCPID
tokenCtx.Audience = fmt.Sprintf("%s/mcp-connect/%s", h.baseURL, audienceID)

token, err := h.tokenService.NewToken(ctx, *tokenCtx)
if err != nil {
log.Errorf("failed to create token for component MCP server %s: %v", resourceMCPID, err)
return "", time.Time{}, types.NewErrBadRequest("%v", Error{
Code: ErrServerError,
Description: "failed to create token",
})
}

return token, tokenCtx.ExpiresAt, nil
}

// validateAPIKeyAccess checks if the API key has access to the specified MCP server.
// For component servers (servers that belong to a composite), it instead checks whether
// the corresponding composite server is in the allowed list.
func validateAPIKeyAccess(ctx api.Context, apiKey *gwtypes.APIKey, mcpID string) error {
if slices.Contains(apiKey.MCPServerIDs, "*") || slices.Contains(apiKey.MCPServerIDs, mcpID) {
if system.IsWebhookSystemMCPServerID(mcpID) || slices.Contains(apiKey.MCPServerIDs, "*") || slices.Contains(apiKey.MCPServerIDs, mcpID) {
return nil
}

Expand Down
46 changes: 26 additions & 20 deletions pkg/controller/handlers/systemmcpserver/systemmcpserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"maps"
"strings"
"time"

Expand Down Expand Up @@ -66,7 +67,7 @@ func (h *Handler) EnsureSecretInfo(req router.Request, _ router.Response) error
}
}

secretCredToolName := secretInfoToolName(systemServer.Name)
secretCredToolName := SecretInfoToolName(systemServer.Name)

if systemServer.Status.AuditLogTokenHash != "" {
cred, err := h.gptClient.RevealCredential(req.Ctx, []string{systemServer.Name}, secretCredToolName)
Expand Down Expand Up @@ -168,7 +169,7 @@ func (h *Handler) EnsureDeployment(req router.Request, _ router.Response) error
return fmt.Errorf("failed to list credentials: %w", err)
}

secretToolName := secretInfoToolName(systemServer.Name)
secretToolName := SecretInfoToolName(systemServer.Name)
credEnv := make(map[string]string)
for _, cred := range creds {
// Skip the secret info credential — those vars go to the shim only, not the MCP server.
Expand All @@ -180,9 +181,8 @@ func (h *Handler) EnsureDeployment(req router.Request, _ router.Response) error
if err != nil {
continue
}
for k, v := range credDetail.Env {
credEnv[k] = v
}

maps.Copy(credEnv, credDetail.Env)
}

// Retrieve the token exchange credential
Expand Down Expand Up @@ -239,12 +239,23 @@ func (h *Handler) EnsureDeployment(req router.Request, _ router.Response) error
// CleanupDeployment handles cleanup when SystemMCPServer is deleted
func (h *Handler) CleanupDeployment(req router.Request, _ router.Response) error {
systemServer := req.Object.(*v1.SystemMCPServer)
creds, err := h.gptClient.ListCredentials(req.Ctx, gptscript.ListCredentialsOptions{
CredentialContexts: []string{systemServer.Name},
})
if err != nil {
return fmt.Errorf("failed to list credentials for %s system server cleanup: %w", systemServer.Name, err)
}

for _, cred := range creds {
if err := h.gptClient.DeleteCredential(req.Ctx, cred.Context, cred.ToolName); err != nil && !errors.As(err, &gptscript.ErrNotFound{}) {
return fmt.Errorf("failed to delete credential %s: %w", cred.ToolName, err)
}
}

// Shutdown deployment via backend
// The backend's shutdownServer will remove the deployment (Docker container or K8s deployment)
err := h.mcpSessionManager.ShutdownServer(req.Ctx, systemServer.Name)
if err != nil {
return fmt.Errorf("failed to shutdown system MCP server: %w", err)
if err = h.mcpSessionManager.ShutdownServer(req.Ctx, systemServer.Name); err != nil {
return fmt.Errorf("failed to shutdown system MCP server %s: %w", systemServer.Name, err)
}

return nil
Expand All @@ -253,29 +264,28 @@ func (h *Handler) CleanupDeployment(req router.Request, _ router.Response) error
// isSystemServerConfigured checks if all required configuration is present
func isSystemServerConfigured(ctx context.Context, gptClient *gptscript.GPTScript, server v1.SystemMCPServer) bool {
// Check if all required env vars are configured
credCtx := server.Name
credCtx := []string{server.Name}
creds, err := gptClient.ListCredentials(ctx, gptscript.ListCredentialsOptions{
CredentialContexts: []string{credCtx},
CredentialContexts: credCtx,
})
if err != nil {
log.Infof("Failed to list credentials for system MCP server %s configuration check: %v",
server.Name, err)
return false
}

secretToolName := secretInfoToolName(server.Name)
secretToolName := SecretInfoToolName(server.Name)
credEnv := make(map[string]string)
for _, cred := range creds {
if cred.ToolName == secretToolName {
continue
}
credDetail, err := gptClient.RevealCredential(ctx, []string{credCtx}, cred.ToolName)
credDetail, err := gptClient.RevealCredential(ctx, credCtx, cred.ToolName)
if err != nil {
continue
}
for k, v := range credDetail.Env {
credEnv[k] = v
}

maps.Copy(credEnv, credDetail.Env)
}

for _, env := range server.Spec.Manifest.Env {
Expand All @@ -289,12 +299,8 @@ func isSystemServerConfigured(ctx context.Context, gptClient *gptscript.GPTScrip
return true
}

func secretInfoToolName(serverName string) string {
return serverName + "-secret-info"
}

// SecretInfoToolName returns the credential toolName used to store token exchange secrets
// for the given system MCP server. Exported for use by API handlers.
func SecretInfoToolName(serverName string) string {
return secretInfoToolName(serverName)
return serverName + "-secret-info"
}
59 changes: 0 additions & 59 deletions pkg/controller/mcpwebhookvalidation/cleanupresources.go

This file was deleted.

Loading
Loading