diff --git a/internal/test/mock_toolset.go b/internal/test/mock_toolset.go new file mode 100644 index 000000000..a293462ec --- /dev/null +++ b/internal/test/mock_toolset.go @@ -0,0 +1,57 @@ +package test + +import ( + "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/containers/kubernetes-mcp-server/pkg/toolsets" +) + +// MockToolset is a test helper for testing toolset functionality +type MockToolset struct { + Name string + Description string + Instructions string + Tools []api.ServerTool + Prompts []api.ServerPrompt +} + +var _ api.Toolset = (*MockToolset)(nil) + +func (m *MockToolset) GetName() string { + return m.Name +} + +func (m *MockToolset) GetDescription() string { + return m.Description +} + +func (m *MockToolset) GetToolsetInstructions() string { + return m.Instructions +} + +func (m *MockToolset) GetTools(_ api.Openshift) []api.ServerTool { + if m.Tools == nil { + return []api.ServerTool{} + } + return m.Tools +} + +func (m *MockToolset) GetPrompts() []api.ServerPrompt { + return m.Prompts +} + +// RegisterMockToolset registers a mock toolset for testing +func RegisterMockToolset(mockToolset *MockToolset) { + toolsets.Register(mockToolset) +} + +// UnregisterMockToolset removes a mock toolset from the registry +func UnregisterMockToolset(name string) { + // Get all toolsets and rebuild the list without the mock + allToolsets := toolsets.Toolsets() + toolsets.Clear() + for _, ts := range allToolsets { + if ts.GetName() != name { + toolsets.Register(ts) + } + } +} diff --git a/pkg/api/toolsets.go b/pkg/api/toolsets.go index 59b1f3c70..8bad5c197 100644 --- a/pkg/api/toolsets.go +++ b/pkg/api/toolsets.go @@ -46,6 +46,11 @@ type Toolset interface { // GetPrompts returns the prompts provided by this toolset. // Returns nil if the toolset doesn't provide any prompts. GetPrompts() []ServerPrompt + // GetToolsetInstructions returns instructions for using the tools in this toolset. + // These instructions will be included in the MCP server's initialize response + // to help LLMs understand how to effectively use the toolset. + // Returns an empty string if no specific instructions are needed. + GetToolsetInstructions() string } type ToolCallRequest interface { diff --git a/pkg/config/config.go b/pkg/config/config.go index 668cb6dd2..77cb97d70 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -86,6 +86,10 @@ type StaticConfig struct { // This can be used to provide specific instructions on how the client should use the server ServerInstructions string `toml:"server_instructions,omitempty"` + // DisableToolsetInstructions indicates whether toolset instructions are to be excluded + // from being provided by the MCP server to the MCP client in the initialize response. + DisableToolsetInstructions bool `toml:"disable_toolset_instructions,omitempty"` + // Telemetry contains OpenTelemetry configuration options. // These can also be configured via OTEL_* environment variables. Telemetry TelemetryConfig `toml:"telemetry,omitempty"` diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index dab9d501e..e06fa5d40 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "slices" + "strings" "time" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -70,6 +71,11 @@ type Server struct { } func NewServer(configuration Configuration, targetProvider internalk8s.Provider) (*Server, error) { + instructions := configuration.ServerInstructions + if !configuration.DisableToolsetInstructions { + instructions = buildServerInstructions(configuration.ServerInstructions, configuration.Toolsets()) + } + s := &Server{ configuration: &configuration, server: mcp.NewServer( @@ -86,7 +92,7 @@ func NewServer(configuration Configuration, targetProvider internalk8s.Provider) Tools: &mcp.ToolCapabilities{ListChanged: !configuration.Stateless}, Logging: &mcp.LoggingCapabilities{}, }, - Instructions: configuration.ServerInstructions, + Instructions: instructions, }), p: targetProvider, } @@ -119,6 +125,28 @@ func NewServer(configuration Configuration, targetProvider internalk8s.Provider) return s, nil } +// buildServerInstructions combines server instructions with toolset-specific instructions +func buildServerInstructions(serverInstructions string, toolsets []api.Toolset) string { + var instructions []string + + if serverInstructions != "" { + instructions = append(instructions, serverInstructions) + } + + for _, toolset := range toolsets { + if toolsetInstructions := toolset.GetToolsetInstructions(); toolsetInstructions != "" { + // Add markdown h2 header with toolset name + header := fmt.Sprintf("## %s", toolset.GetName()) + instructions = append(instructions, header, toolsetInstructions) + } + } + + if len(instructions) == 0 { + return "" + } + return strings.Join(instructions, "\n\n") +} + func (s *Server) reloadToolsets() error { ctx := context.Background() diff --git a/pkg/mcp/mcp_test.go b/pkg/mcp/mcp_test.go index d6380ff8f..93fbe4668 100644 --- a/pkg/mcp/mcp_test.go +++ b/pkg/mcp/mcp_test.go @@ -9,6 +9,7 @@ import ( "github.com/BurntSushi/toml" "github.com/containers/kubernetes-mcp-server/internal/test" + "github.com/containers/kubernetes-mcp-server/pkg/api" internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" "github.com/mark3labs/mcp-go/client/transport" "github.com/stretchr/testify/suite" @@ -266,3 +267,200 @@ func (s *UserAgentPropagationSuite) TestFallsBackToServerPrefixWhenNoClientInfo( func TestUserAgentPropagation(t *testing.T) { suite.Run(t, new(UserAgentPropagationSuite)) } + +type ToolsetInstructionsSuite struct { + BaseMcpSuite +} + +func (s *ToolsetInstructionsSuite) TestToolsetInstructionsAreIncluded() { + mockToolset := &test.MockToolset{ + Name: "mock", + Description: "Mock toolset for testing", + Instructions: "These are mock toolset instructions.\nAlways use caution with mock tools.", + } + + s.Cfg.Toolsets = []string{"mock", "core"} + + test.RegisterMockToolset(mockToolset) + defer test.UnregisterMockToolset("mock") + + s.InitMcpClient() + s.Run("includes toolset instructions in initialize response", func() { + s.Require().NotNil(s.InitializeResult) + s.Contains(s.InitializeResult.Instructions, "These are mock toolset instructions.\nAlways use caution with mock tools.", + "instructions should include toolset instructions") + }) + s.Run("adds markdown header with toolset name", func() { + s.Require().NotNil(s.InitializeResult) + s.Contains(s.InitializeResult.Instructions, "## mock", + "instructions should include markdown header with toolset name") + }) +} + +func (s *ToolsetInstructionsSuite) TestToolsetInstructionsCombinedWithServerInstructions() { + mockToolset := &test.MockToolset{ + Name: "mock", + Description: "Mock toolset for testing", + Instructions: "Toolset-specific instructions.", + } + + s.Require().NoError(toml.Unmarshal([]byte(` + server_instructions = "Server-level instructions." + toolsets = ["mock"] + `), s.Cfg), "Expected to parse config") + + test.RegisterMockToolset(mockToolset) + defer test.UnregisterMockToolset("mock") + + s.InitMcpClient() + s.Run("combines server and toolset instructions", func() { + s.Require().NotNil(s.InitializeResult) + s.Contains(s.InitializeResult.Instructions, "Server-level instructions.", + "instructions should include server instructions") + s.Contains(s.InitializeResult.Instructions, "Toolset-specific instructions.", + "instructions should include toolset instructions") + }) +} + +func (s *ToolsetInstructionsSuite) TestEmptyToolsetInstructionsNotIncluded() { + s.Cfg.Toolsets = []string{"core"} + s.InitMcpClient() + s.Run("does not include empty toolset instructions", func() { + s.Require().NotNil(s.InitializeResult) + s.Empty(s.InitializeResult.Instructions, + "instructions should be empty when toolset instructions are empty") + }) +} + +func (s *ToolsetInstructionsSuite) TestDisableToolsetInstructions() { + mockToolset := &test.MockToolset{ + Name: "mock", + Description: "Mock toolset for testing", + Instructions: "These instructions should be ignored.", + } + + s.Require().NoError(toml.Unmarshal([]byte(` + server_instructions = "Server-level instructions only." + toolsets = ["mock"] + disable_toolset_instructions = true + `), s.Cfg), "Expected to parse config") + + test.RegisterMockToolset(mockToolset) + defer test.UnregisterMockToolset("mock") + + s.InitMcpClient() + s.Run("excludes toolset instructions when disabled", func() { + s.Require().NotNil(s.InitializeResult) + s.Equal("Server-level instructions only.", s.InitializeResult.Instructions, + "instructions should only contain server instructions when toolset instructions are disabled") + s.NotContains(s.InitializeResult.Instructions, "These instructions should be ignored.", + "instructions should not include toolset instructions when disabled") + }) +} + +func (s *ToolsetInstructionsSuite) TestToolsetInstructionsWithExistingHeaders() { + mockToolset := &test.MockToolset{ + Name: "mock", + Description: "Mock toolset for testing", + Instructions: "### Subheader\nActual instructions here.", + } + + s.Cfg.Toolsets = []string{"mock"} + + test.RegisterMockToolset(mockToolset) + defer test.UnregisterMockToolset("mock") + + s.InitMcpClient() + s.Run("preserves existing headers and adds toolset header", func() { + s.Require().NotNil(s.InitializeResult) + s.Contains(s.InitializeResult.Instructions, "## mock", + "instructions should include markdown header with toolset name") + s.Contains(s.InitializeResult.Instructions, "### Subheader", + "instructions should preserve subheader") + s.Contains(s.InitializeResult.Instructions, "Actual instructions here.", + "instructions should include the actual content") + }) +} + +func TestToolsetInstructions(t *testing.T) { + suite.Run(t, new(ToolsetInstructionsSuite)) +} + +type BuildServerInstructionsSuite struct { + suite.Suite +} + +func (s *BuildServerInstructionsSuite) TestBuildServerInstructions() { + s.Run("returns empty string with no instructions", func() { + result := buildServerInstructions("", []api.Toolset{}) + s.Empty(result) + }) + + s.Run("returns only server instructions when no toolsets", func() { + serverInstructions := "Server instructions here" + result := buildServerInstructions(serverInstructions, []api.Toolset{}) + s.Equal(serverInstructions, result) + }) + + s.Run("adds toolset header for single toolset", func() { + mockToolset := &test.MockToolset{ + Name: "test-toolset", + Instructions: "Toolset instructions", + } + result := buildServerInstructions("", []api.Toolset{mockToolset}) + expected := "## test-toolset\n\nToolset instructions" + s.Equal(expected, result) + }) + + s.Run("combines server instructions with multiple toolsets", func() { + mockToolset1 := &test.MockToolset{ + Name: "toolset1", + Instructions: "Instructions for toolset 1", + } + mockToolset2 := &test.MockToolset{ + Name: "toolset2", + Instructions: "### Header\nInstructions for toolset 2", + } + result := buildServerInstructions("Server instructions", []api.Toolset{mockToolset1, mockToolset2}) + expected := "Server instructions\n\n## toolset1\n\nInstructions for toolset 1\n\n## toolset2\n\n### Header\nInstructions for toolset 2" + s.Equal(expected, result) + }) + + s.Run("skips toolsets with empty instructions", func() { + mockToolset1 := &test.MockToolset{ + Name: "toolset1", + Instructions: "Instructions for toolset 1", + } + mockToolset2 := &test.MockToolset{ + Name: "toolset2", + Instructions: "", + } + result := buildServerInstructions("", []api.Toolset{mockToolset1, mockToolset2}) + expected := "## toolset1\n\nInstructions for toolset 1" + s.Equal(expected, result) + }) + + s.Run("handles multiline instructions", func() { + mockToolset := &test.MockToolset{ + Name: "test-toolset", + Instructions: "Line 1\nLine 2\nLine 3", + } + result := buildServerInstructions("", []api.Toolset{mockToolset}) + expected := "## test-toolset\n\nLine 1\nLine 2\nLine 3" + s.Equal(expected, result) + }) + + s.Run("handles instructions with markdown content", func() { + mockToolset := &test.MockToolset{ + Name: "test-toolset", + Instructions: "**Bold text**\n- List item 1\n- List item 2", + } + result := buildServerInstructions("", []api.Toolset{mockToolset}) + expected := "## test-toolset\n\n**Bold text**\n- List item 1\n- List item 2" + s.Equal(expected, result) + }) +} + +func TestBuildServerInstructions(t *testing.T) { + suite.Run(t, new(BuildServerInstructionsSuite)) +} diff --git a/pkg/mcp/mcp_toolset_prompts_test.go b/pkg/mcp/mcp_toolset_prompts_test.go index 3256f882f..4dfb680a2 100644 --- a/pkg/mcp/mcp_toolset_prompts_test.go +++ b/pkg/mcp/mcp_toolset_prompts_test.go @@ -386,6 +386,10 @@ func (m *mockToolsetWithPrompts) GetPrompts() []api.ServerPrompt { return m.prompts } +func (m *mockToolsetWithPrompts) GetToolsetInstructions() string { + return "" +} + func TestMcpToolsetPromptsSuite(t *testing.T) { suite.Run(t, new(McpToolsetPromptsSuite)) } diff --git a/pkg/toolsets/config/toolset.go b/pkg/toolsets/config/toolset.go index 3d08fb597..6e245f39e 100644 --- a/pkg/toolsets/config/toolset.go +++ b/pkg/toolsets/config/toolset.go @@ -30,6 +30,10 @@ func (t *Toolset) GetPrompts() []api.ServerPrompt { return nil } +func (t *Toolset) GetToolsetInstructions() string { + return "" +} + func init() { toolsets.Register(&Toolset{}) } diff --git a/pkg/toolsets/core/toolset.go b/pkg/toolsets/core/toolset.go index 536b9428c..42945efab 100644 --- a/pkg/toolsets/core/toolset.go +++ b/pkg/toolsets/core/toolset.go @@ -35,6 +35,10 @@ func (t *Toolset) GetPrompts() []api.ServerPrompt { ) } +func (t *Toolset) GetToolsetInstructions() string { + return "" +} + func init() { toolsets.Register(&Toolset{}) } diff --git a/pkg/toolsets/helm/toolset.go b/pkg/toolsets/helm/toolset.go index 6bdbfd419..49af13abd 100644 --- a/pkg/toolsets/helm/toolset.go +++ b/pkg/toolsets/helm/toolset.go @@ -30,6 +30,10 @@ func (t *Toolset) GetPrompts() []api.ServerPrompt { return nil } +func (t *Toolset) GetToolsetInstructions() string { + return "" +} + func init() { toolsets.Register(&Toolset{}) } diff --git a/pkg/toolsets/kcp/toolset.go b/pkg/toolsets/kcp/toolset.go index 1a600d1a3..36432c495 100644 --- a/pkg/toolsets/kcp/toolset.go +++ b/pkg/toolsets/kcp/toolset.go @@ -29,6 +29,10 @@ func (t *Toolset) GetPrompts() []api.ServerPrompt { return nil } +func (t *Toolset) GetToolsetInstructions() string { + return "" +} + func init() { toolsets.Register(&Toolset{}) } diff --git a/pkg/toolsets/kiali/toolset.go b/pkg/toolsets/kiali/toolset.go index da5fded30..8efab6150 100644 --- a/pkg/toolsets/kiali/toolset.go +++ b/pkg/toolsets/kiali/toolset.go @@ -38,6 +38,10 @@ func (t *Toolset) GetPrompts() []api.ServerPrompt { return nil } +func (t *Toolset) GetToolsetInstructions() string { + return "" +} + func init() { toolsets.Register(&Toolset{}) } diff --git a/pkg/toolsets/kubevirt/toolset.go b/pkg/toolsets/kubevirt/toolset.go index 63200eeac..bad242a07 100644 --- a/pkg/toolsets/kubevirt/toolset.go +++ b/pkg/toolsets/kubevirt/toolset.go @@ -32,6 +32,10 @@ func (t *Toolset) GetPrompts() []api.ServerPrompt { return initVMTroubleshoot() } +func (t *Toolset) GetToolsetInstructions() string { + return "" +} + func init() { toolsets.Register(&Toolset{}) } diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go index c2e869814..b25c005f6 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/toolsets/toolsets_test.go @@ -36,6 +36,8 @@ func (t *TestToolset) GetTools(_ api.Openshift) []api.ServerTool { return nil } func (t *TestToolset) GetPrompts() []api.ServerPrompt { return nil } +func (t *TestToolset) GetToolsetInstructions() string { return "" } + var _ api.Toolset = (*TestToolset)(nil) func (s *ToolsetsSuite) TestToolsetNames() {