diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 9cdd190aa..6942085ab 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -7,7 +7,6 @@ import ( "os" "slices" "sync" - "time" "github.com/modelcontextprotocol/go-sdk/mcp" "k8s.io/klog/v2" @@ -72,26 +71,23 @@ type Server struct { } func NewServer(configuration Configuration, targetProvider internalk8s.Provider) (*Server, error) { - s := &Server{ - configuration: &configuration, - server: mcp.NewServer( - &mcp.Implementation{ - Name: version.BinaryName, - Title: version.BinaryName, - Version: version.Version, - WebsiteURL: version.WebsiteURL, + // Initialize MCP server + mcpServer := mcp.NewServer( + &mcp.Implementation{ + Name: version.BinaryName, + Title: version.BinaryName, + Version: version.Version, + WebsiteURL: version.WebsiteURL, + }, + &mcp.ServerOptions{ + Capabilities: &mcp.ServerCapabilities{ + Resources: nil, + Prompts: &mcp.PromptCapabilities{ListChanged: !configuration.Stateless}, + Tools: &mcp.ToolCapabilities{ListChanged: !configuration.Stateless}, + Logging: &mcp.LoggingCapabilities{}, }, - &mcp.ServerOptions{ - Capabilities: &mcp.ServerCapabilities{ - Resources: nil, - Prompts: &mcp.PromptCapabilities{ListChanged: !configuration.Stateless}, - Tools: &mcp.ToolCapabilities{ListChanged: !configuration.Stateless}, - Logging: &mcp.LoggingCapabilities{}, - }, - Instructions: configuration.ServerInstructions, - }), - p: targetProvider, - } + Instructions: configuration.ServerInstructions, + }) // Initialize metrics system metricsInstance, err := metrics.New(metrics.Config{ @@ -100,20 +96,36 @@ func NewServer(configuration Configuration, targetProvider internalk8s.Provider) ServiceVersion: version.Version, Telemetry: &configuration.Telemetry, }) + + // Add receiving middleware to the MCP server + mcpServer.AddReceivingMiddleware(sessionInjectionMiddleware) + mcpServer.AddReceivingMiddleware(traceContextPropagationMiddleware) + mcpServer.AddReceivingMiddleware(tracingMiddleware(version.BinaryName + "/mcp")) + mcpServer.AddReceivingMiddleware(authHeaderPropagationMiddleware) + mcpServer.AddReceivingMiddleware(userAgentPropagationMiddleware(version.BinaryName, version.Version)) + mcpServer.AddReceivingMiddleware(toolCallLoggingMiddleware) + mcpServer.AddReceivingMiddleware(metricsMiddleware(metricsInstance)) + if err != nil { return nil, fmt.Errorf("failed to initialize metrics: %w", err) } - s.metrics = metricsInstance - s.server.AddReceivingMiddleware(sessionInjectionMiddleware) - s.server.AddReceivingMiddleware(traceContextPropagationMiddleware) - s.server.AddReceivingMiddleware(tracingMiddleware(version.BinaryName + "/mcp")) - s.server.AddReceivingMiddleware(authHeaderPropagationMiddleware) - s.server.AddReceivingMiddleware(userAgentPropagationMiddleware(version.BinaryName, version.Version)) - s.server.AddReceivingMiddleware(toolCallLoggingMiddleware) - s.server.AddReceivingMiddleware(s.metricsMiddleware()) + return NewServerFrom(configuration, mcpServer, targetProvider, metricsInstance) +} - err = s.reloadToolsets() +// NewServerFrom creates a new MCP server from pre-configured components. +// Use this when you need full control over the MCP server, metrics, or middleware. +// For standard usage with defaults, prefer NewServer. +func NewServerFrom(configuration Configuration, mcpServer *mcp.Server, targetProvider internalk8s.Provider, metrics *metrics.Metrics) (*Server, error) { + s := &Server{ + configuration: &configuration, + server: mcpServer, + p: targetProvider, + metrics: metrics, + } + + // reload toolsets + err := s.reloadToolsets() if err != nil { return nil, err } @@ -266,31 +278,6 @@ func (s *Server) registerPrompt(prompt api.ServerPrompt) error { return nil } -// metricsMiddleware returns a metrics middleware with access to the server's metrics system -func (s *Server) metricsMiddleware() func(mcp.MethodHandler) mcp.MethodHandler { - return func(next mcp.MethodHandler) mcp.MethodHandler { - return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - start := time.Now() - result, err := next(ctx, method, req) - duration := time.Since(start) - - toolName := method - if method == "tools/call" { - if params, ok := req.GetParams().(*mcp.CallToolParamsRaw); ok { - if toolReq, _ := GoSdkToolCallParamsToToolCallRequest(params); toolReq != nil { - toolName = toolReq.Name - } - } - } - - // Record to all collectors - s.metrics.RecordToolCall(ctx, toolName, duration, err) - - return result, err - } - } -} - // GetMetrics returns the metrics system for use by the HTTP server. func (s *Server) GetMetrics() *metrics.Metrics { return s.metrics diff --git a/pkg/mcp/middleware.go b/pkg/mcp/middleware.go index 726d95653..4a42f441e 100644 --- a/pkg/mcp/middleware.go +++ b/pkg/mcp/middleware.go @@ -6,9 +6,11 @@ import ( "fmt" "runtime" "strings" + "time" internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" "github.com/containers/kubernetes-mcp-server/pkg/mcplog" + "github.com/containers/kubernetes-mcp-server/pkg/metrics" "github.com/containers/kubernetes-mcp-server/pkg/telemetry" "github.com/modelcontextprotocol/go-sdk/mcp" "go.opentelemetry.io/otel" @@ -212,6 +214,31 @@ func getMcpReqUserAgent(req mcp.Request) string { return fmt.Sprintf("%s/%s", initParams.ClientInfo.Name, initParams.ClientInfo.Version) } +// metricsMiddleware returns a metrics middleware with access to the server's metrics system +func metricsMiddleware(metrics *metrics.Metrics) func(mcp.MethodHandler) mcp.MethodHandler { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { + start := time.Now() + result, err := next(ctx, method, req) + duration := time.Since(start) + + toolName := method + if method == "tools/call" { + if params, ok := req.GetParams().(*mcp.CallToolParamsRaw); ok { + if toolReq, _ := GoSdkToolCallParamsToToolCallRequest(params); toolReq != nil { + toolName = toolReq.Name + } + } + } + + // Record to all collectors + metrics.RecordToolCall(ctx, toolName, duration, err) + + return result, err + } + } +} + // metaCarrier adapts an MCP Meta map to the OpenTelemetry TextMapCarrier interface type metaCarrier struct { meta map[string]any