diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServer.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServer.java index 678c9ac36bc..1f441009551 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServer.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServer.java @@ -8,6 +8,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.taobao.arthas.mcp.server.CommandExecutor; import com.taobao.arthas.mcp.server.protocol.spec.*; +import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager; +import com.taobao.arthas.mcp.server.task.*; import com.taobao.arthas.mcp.server.util.Assert; import com.taobao.arthas.mcp.server.util.Utils; import org.slf4j.Logger; @@ -15,14 +17,12 @@ import java.time.Duration; import java.util.*; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.*; import java.util.function.BiFunction; /** - * A Netty-based MCP server implementation that provides access to tools, resources, and prompts. + * A Netty-based MCP server implementation that provides access to tools, + * resources, and prompts. * * @author Yeaury */ @@ -42,40 +42,68 @@ public class McpNettyServer { private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + private final ConcurrentHashMap toolsByName = new ConcurrentHashMap<>(); + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + private final ServerTaskToolHandler serverTaskToolHandler; + + private final ArthasCommandSessionManager sessionManager; + private McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.DEBUG; private List protocolVersions; McpNettyServer(McpStreamableServerTransportProvider mcpTransportProvider, - ObjectMapper objectMapper, Duration requestTimeout, - McpServerFeatures.McpServerConfig features, - CommandExecutor commandExecutor) { + ObjectMapper objectMapper, Duration requestTimeout, + McpServerFeatures.McpServerConfig features, + CommandExecutor commandExecutor, + ArthasCommandSessionManager sessionManager) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.getServerInfo(); this.serverCapabilities = features.getServerCapabilities(); this.instructions = features.getInstructions(); this.tools.addAll(features.getTools()); + + for (McpServerFeatures.ToolSpecification tool : features.getTools()) { + this.toolsByName.put(tool.getTool().getName(), tool); + } + this.resources.putAll(features.getResources()); this.resourceTemplates.addAll(features.getResourceTemplates()); this.prompts.putAll(features.getPrompts()); + this.sessionManager = sessionManager; + + this.serverTaskToolHandler = new ServerTaskToolHandler( + features.getTaskTools(), + features.getTaskOptions(), + objectMapper, + this::notifyAllClients, + Duration.ofSeconds(30), + sessionManager); + Map> requestHandlers = prepareRequestHandlers(); Map notificationHandlers = prepareNotificationHandlers(features); this.protocolVersions = mcpTransportProvider.protocolVersions(); + TaskStore taskStore = this.serverTaskToolHandler + .getTaskStore(); + TaskMessageQueue taskMessageQueue = this.serverTaskToolHandler.getTaskMessageQueue(); + mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout, - this::initializeRequestHandler, requestHandlers, notificationHandlers, commandExecutor)); + this::initializeRequestHandler, requestHandlers, notificationHandlers, commandExecutor, + taskStore, taskMessageQueue)); } - private Map prepareNotificationHandlers(McpServerFeatures.McpServerConfig features) { + private Map prepareNotificationHandlers( + McpServerFeatures.McpServerConfig features) { Map notificationHandlers = new HashMap<>(); notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, @@ -86,9 +114,9 @@ private Map prepareNotificationHandlers(McpServe if (Utils.isEmpty(rootsChangeConsumers)) { rootsChangeConsumers = Collections.singletonList( - (exchange, roots) -> CompletableFuture.runAsync(() -> - logger.warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots)) - ); + (exchange, roots) -> CompletableFuture.runAsync(() -> logger.warn( + "Roots list changed notification, but no consumers provided. Roots list changed: {}", + roots))); } notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, @@ -129,10 +157,13 @@ private Map> prepareRequestHandlers() { requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); } + // Add tasks API handlers via ServerTaskToolHandler + this.serverTaskToolHandler.logCapabilityMismatches(this.serverCapabilities.getTasks()); + requestHandlers.putAll(this.serverTaskToolHandler.getRequestHandlers(this.serverCapabilities.getTasks())); + return requestHandlers; } - // --------------------------------------- // Lifecycle Management // --------------------------------------- @@ -150,8 +181,7 @@ private CompletableFuture initializeRequestHandler( if (protocolVersions.contains(initializeRequest.getProtocolVersion())) { serverProtocolVersion = initializeRequest.getProtocolVersion(); - } - else { + } else { logger.warn( "Client requested unsupported protocol version: {}, " + "so the server will suggest {} instead", initializeRequest.getProtocolVersion(), serverProtocolVersion); @@ -170,10 +200,19 @@ public McpSchema.Implementation getServerInfo() { } public CompletableFuture closeGracefully() { - return this.mcpTransportProvider.closeGracefully(); + if (this.serverTaskToolHandler != null) { + return this.serverTaskToolHandler.closeGracefully() + .thenCompose(v -> mcpTransportProvider.closeGracefully()); + } + return mcpTransportProvider.closeGracefully(); } public void close() { + try { + closeGracefully().get(5, TimeUnit.SECONDS); + } catch (Exception e) { + logger.warn("Error during graceful close", e); + } this.mcpTransportProvider.close(); } @@ -225,11 +264,13 @@ public CompletableFuture addTool(McpServerFeatures.ToolSpecification toolS return CompletableFuture.supplyAsync(() -> { // Check for duplicate tool names - if (this.tools.stream().anyMatch(th -> th.getTool().getName().equals(toolSpecification.getTool().getName()))) { + if (this.tools.stream() + .anyMatch(th -> th.getTool().getName().equals(toolSpecification.getTool().getName()))) { throw new CompletionException( new McpError("Tool with name '" + toolSpecification.getTool().getName() + "' already exists")); } this.tools.add(toolSpecification); + this.toolsByName.put(toolSpecification.getTool().getName(), toolSpecification); logger.debug("Added tool handler: {}", toolSpecification.getTool().getName()); return null; }).thenCompose(ignored -> { @@ -261,6 +302,7 @@ public CompletableFuture removeTool(String toolName) { if (!removed) { throw new CompletionException(new McpError("Tool with name '" + toolName + "' not found")); } + this.toolsByName.remove(toolName); logger.debug("Removed tool handler: {}", toolName); return null; }).thenCompose(ignored -> { @@ -280,33 +322,94 @@ public CompletableFuture notifyToolsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); } + + private CompletableFuture notifyAllClients(String method, Object notification) { + return this.mcpTransportProvider.notifyClients(method, notification); + } + + public CompletableFuture addTaskTool(TaskAwareToolSpecification taskToolSpecification) { + if (this.serverCapabilities.getTools() == null) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND) + .message("Server must be configured with tool capabilities") + .build()); + return future; + } + return this.serverTaskToolHandler.addTaskTool(taskToolSpecification, this.serverCapabilities.getTools()); + } + + /** + * Remove a task-aware tool at runtime. + * @param toolName The name of the task-aware tool to remove + * @return Mono that completes when clients have been notified of the change + */ + public CompletableFuture removeTaskTool(String toolName) { + if (this.serverCapabilities.getTools() == null) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND) + .message("Server must be configured with tool capabilities") + .build()); + return future; + } + return this.serverTaskToolHandler.removeTaskTool(toolName, this.serverCapabilities.getTools()); + } + private McpRequestHandler toolsListRequestHandler() { return (exchange, commandContext, params) -> { - List tools = new ArrayList<>(); + List toolList = new ArrayList<>(); for (McpServerFeatures.ToolSpecification toolSpec : this.tools) { - tools.add(toolSpec.getTool()); + toolList.add(toolSpec.getTool()); } + toolList.addAll(this.serverTaskToolHandler.getToolDefinitions()); - return CompletableFuture.completedFuture(new McpSchema.ListToolsResult(tools, null)); + return CompletableFuture.completedFuture(new McpSchema.ListToolsResult(toolList, null)); }; } - private McpRequestHandler toolsCallRequestHandler() { + private McpRequestHandler toolsCallRequestHandler() { return (exchange, commandContext, params) -> { McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, new TypeReference() { }); - Optional toolSpecification = this.tools.stream() - .filter(tr -> callToolRequest.getName().equals(tr.getTool().getName())) - .findAny(); - if (!toolSpecification.isPresent()) { - CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(new McpError("no tool found: " + callToolRequest.getName())); - return future; + String toolName = callToolRequest.getName(); + + McpServerFeatures.ToolSpecification normalTool = this.toolsByName.get(toolName); + if (normalTool != null) { + // Normal tools don't support task enhancement requests + if (callToolRequest.getTask() != null) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.METHOD_NOT_FOUND) + .message("Tool '" + toolName + "' does not support task-augmented requests") + .data("Remove the 'task' parameter or use a task-aware tool") + .build()); + return future; + } + return normalTool.getCall().apply(exchange, commandContext, callToolRequest) + .thenApply(result -> (Object) result); + } + + // task aware tools are delegated to ServerTaskToolHandler + CompletableFuture taskToolResult = this.serverTaskToolHandler.handleToolCall( + exchange, commandContext, callToolRequest); + if (taskToolResult != null) { + return taskToolResult.thenApply(result -> { + if (result == null) { + throw McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: " + callToolRequest.getName()) + .data("Tool not found: " + callToolRequest.getName()) + .build(); + } + return result; + }); } - return toolSpecification.get().getCall().apply(exchange, commandContext, callToolRequest); + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Unknown tool: " + callToolRequest.getName()) + .data("Tool not found: " + callToolRequest.getName()) + .build()); + return future; }; } @@ -327,7 +430,8 @@ public CompletableFuture addResource(McpServerFeatures.ResourceSpecificati } return CompletableFuture.supplyAsync(() -> { - if (this.resources.putIfAbsent(resourceSpecification.getResource().getUri(), resourceSpecification) != null) { + if (this.resources.putIfAbsent(resourceSpecification.getResource().getUri(), + resourceSpecification) != null) { throw new CompletionException(new McpError( "Resource with URI '" + resourceSpecification.getResource().getUri() + "' already exists")); } @@ -393,7 +497,7 @@ private McpRequestHandler resourcesListRequestHan private McpRequestHandler resourceTemplateListRequestHandler() { return (exchange, commandContext, params) -> CompletableFuture - .completedFuture(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); + .completedFuture(new McpSchema.ListResourceTemplatesResult(this.resourceTemplates, null)); } private McpRequestHandler resourcesReadRequestHandler() { @@ -430,10 +534,11 @@ public CompletableFuture addPrompt(McpServerFeatures.PromptSpecification p return CompletableFuture.supplyAsync(() -> { McpServerFeatures.PromptSpecification existing = this.prompts - .putIfAbsent(promptSpecification.getPrompt().getName(), promptSpecification); + .putIfAbsent(promptSpecification.getPrompt().getName(), promptSpecification); if (existing != null) { throw new CompletionException( - new McpError("Prompt with name '" + promptSpecification.getPrompt().getName() + "' already exists")); + new McpError( + "Prompt with name '" + promptSpecification.getPrompt().getName() + "' already exists")); } logger.debug("Added prompt handler: {}", promptSpecification.getPrompt().getName()); @@ -533,10 +638,10 @@ private McpRequestHandler> setLoggerRequestHandler() { McpSchema.SetLevelRequest.class); this.minLoggingLevel = request.getLevel(); return CompletableFuture.completedFuture(Collections.emptyMap()); - } - catch (Exception e) { + } catch (Exception e) { CompletableFuture> future = new CompletableFuture<>(); - future.completeExceptionally(new McpError("An error occurred while processing a request to set the log level: " + e.getMessage())); + future.completeExceptionally(new McpError( + "An error occurred while processing a request to set the log level: " + e.getMessage())); return future; } }; diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServerExchange.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServerExchange.java index 8a26808a743..45c18fdf1d4 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServerExchange.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpNettyServerExchange.java @@ -10,11 +10,17 @@ import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.LoggingLevel; import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.LoggingMessageNotification; import com.taobao.arthas.mcp.server.protocol.spec.McpSession; +import com.taobao.arthas.mcp.server.task.QueuedMessage; +import com.taobao.arthas.mcp.server.task.TaskDefaults; +import com.taobao.arthas.mcp.server.task.TaskMessageQueue; +import com.taobao.arthas.mcp.server.task.TaskStore; import com.taobao.arthas.mcp.server.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.time.Duration; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; /** * Represents the interaction between MCP server and client. Provides methods for communication, logging, and context management. @@ -47,8 +53,16 @@ public class McpNettyServerExchange { private final McpTransportContext transportContext; + private final TaskMessageQueue taskMessageQueue; + + private final TaskStore taskStore; + private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; + private final AtomicLong sideChannelRequestCounter = new AtomicLong(0); + + private static final Duration SIDE_CHANNEL_TIMEOUT = Duration.ofMinutes(TaskDefaults.DEFAULT_SIDE_CHANNEL_TIMEOUT_MINUTES); + private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference() { }; @@ -60,19 +74,53 @@ public class McpNettyServerExchange { private static final TypeReference ELICIT_USER_INPUT_RESULT_TYPE_REF = new TypeReference() { }; + + private static final TypeReference GET_TASK_RESULT_TYPE_REF = + new TypeReference() { + }; + + private static final TypeReference CREATE_TASK_RESULT_TYPE_REF = + new TypeReference() { + }; + + private static final TypeReference LIST_TASKS_RESULT_TYPE_REF = + new TypeReference() { + }; + + private static final TypeReference CANCEL_TASK_RESULT_TYPE_REF = + new TypeReference() { + }; public static final TypeReference OBJECT_TYPE_REF = new TypeReference() { }; public McpNettyServerExchange(String sessionId, McpSession session, - McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, - McpTransportContext transportContext) { + McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo, + McpTransportContext transportContext, + TaskMessageQueue taskMessageQueue) { + this(sessionId, session, clientCapabilities, clientInfo, transportContext, taskMessageQueue, null); + } + + + public McpNettyServerExchange(String sessionId, McpSession session, + McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo, + McpTransportContext transportContext, + TaskMessageQueue taskMessageQueue, + TaskStore taskStore) { this.sessionId = sessionId; this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; this.transportContext = transportContext; + this.taskMessageQueue = taskMessageQueue; + this.taskStore = taskStore; } + + public CompletableFuture sendNotification(String method, Object params) { + return session.sendNotification(method, params); + } /** * Get client capabilities. * @return Client capabilities @@ -105,6 +153,9 @@ public McpTransportContext getTransportContext() { return this.transportContext; } + public String sessionId() { + return this.sessionId; + } /** * Create a new message using client sampling capability. MCP provides a standardized way for servers to request * LLM sampling ("completion" or "generation") through the client. This flow allows clients to maintain control @@ -116,19 +167,33 @@ public McpTransportContext getTransportContext() { */ public CompletableFuture createMessage( McpSchema.CreateMessageRequest createMessageRequest) { + return createMessage(createMessageRequest, null); + } + + public CompletableFuture createMessage( + McpSchema.CreateMessageRequest createMessageRequest, + String taskId) { if (this.clientCapabilities == null) { logger.error("Client not initialized, cannot create message"); CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(new McpError("Client must be initialized first. Please call initialize method!")); + future.completeExceptionally(new McpError("Client must be initialized. Call the initialize method first!")); return future; } if (this.clientCapabilities.getSampling() == null) { logger.error("Client not configured with sampling capability, cannot create message"); CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(new McpError("Client must be configured with sampling capability")); + future.completeExceptionally(new McpError("Client must be configured with sampling capabilities")); return future; } + // Side-channel flow: enqueue request and wait for response via tasks/result + if (taskId != null && this.taskMessageQueue != null && this.taskStore != null) { + return sideChannelRequest(taskId, McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, + createMessageRequest, McpSchema.CreateMessageResult.class, + "Waiting for sampling response"); + } + + // No task context: send immediately logger.debug("Creating client message, session ID: {}", this.sessionId); return this.session .sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest, CREATE_MESSAGE_RESULT_TYPE_REF) @@ -171,18 +236,28 @@ public CompletableFuture listRoots(String cursor) { } public CompletableFuture loggingNotification(LoggingMessageNotification loggingMessageNotification) { + return loggingNotification(loggingMessageNotification, null); + } + + public CompletableFuture loggingNotification(LoggingMessageNotification loggingMessageNotification, String taskId) { if (loggingMessageNotification == null) { CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(new McpError("log messages cannot be empty")); + future.completeExceptionally(new McpError("Logging message must not be null")); return future; } if (this.isNotificationForLevelAllowed(loggingMessageNotification.getLevel())) { + // Side-channel flow: enqueue notification for delivery via tasks/result + if (taskId != null && this.taskMessageQueue != null) { + return sideChannelNotification(taskId, McpSchema.METHOD_NOTIFICATION_MESSAGE, + loggingMessageNotification); + } + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification) .whenComplete((result, error) -> { if (error != null) { - logger.error("Failed to send logging notification, level: {}, session ID: {}, error: {}", loggingMessageNotification.getLevel(), - this.sessionId, error.getMessage()); + logger.error("Failed to send logging notification, level: {}, session ID: {}, error: {}", + loggingMessageNotification.getLevel(), this.sessionId, error.getMessage()); } }); } @@ -193,10 +268,15 @@ public CompletableFuture ping() { return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF); } - public CompletableFuture createElicitation(McpSchema.ElicitRequest request) { - if (request == null) { + + public CompletableFuture createElicitation(McpSchema.ElicitRequest elicitRequest) { + return createElicitation(elicitRequest, null); + } + + public CompletableFuture createElicitation(McpSchema.ElicitRequest elicitRequest, String taskId) { + if (elicitRequest == null) { CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(new McpError("elicit request cannot be null")); + future.completeExceptionally(new McpError("Elicit request must not be null")); return future; } if (this.clientCapabilities == null) { @@ -209,8 +289,17 @@ public CompletableFuture createElicitation(McpSchema.Eli future.completeExceptionally(new McpError("Client must be configured with elicitation capabilities")); return future; } + + // Side-channel flow: enqueue request and wait for response via tasks/result + if (taskId != null && this.taskMessageQueue != null && this.taskStore != null) { + return sideChannelRequest(taskId, McpSchema.METHOD_ELICITATION_CREATE, + elicitRequest, McpSchema.ElicitResult.class, + "Waiting for user input"); + } + + // No task context: send immediately return this.session - .sendRequest(McpSchema.METHOD_ELICITATION_CREATE, request, ELICIT_USER_INPUT_RESULT_TYPE_REF) + .sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, ELICIT_USER_INPUT_RESULT_TYPE_REF) .whenComplete((result, error) -> { if (error != null) { logger.error("Failed to elicit user input, session ID: {}, error: {}", this.sessionId, error.getMessage()); @@ -231,12 +320,23 @@ private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) { } public CompletableFuture progressNotification(McpSchema.ProgressNotification progressNotification) { + return progressNotification(progressNotification, null); + } + + public CompletableFuture progressNotification(McpSchema.ProgressNotification progressNotification, String taskId) { if (progressNotification == null) { CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(new McpError("progress notifications cannot be empty")); + future.completeExceptionally(new McpError("Progress notification must not be null")); return future; } + // Side-channel flow: enqueue notification for delivery via tasks/result + if (taskId != null && this.taskMessageQueue != null) { + return sideChannelNotification(taskId, McpSchema.METHOD_NOTIFICATION_PROGRESS, + progressNotification); + } + + // Send immediately return this.session .sendNotification(McpSchema.METHOD_NOTIFICATION_PROGRESS, progressNotification) .whenComplete((result, error) -> { @@ -247,4 +347,87 @@ public CompletableFuture progressNotification(McpSchema.ProgressNotificati } }); } + + + public CompletableFuture getTask(McpSchema.GetTaskRequest getTaskRequest) { + return this.session.sendRequest(McpSchema.METHOD_TASKS_GET, getTaskRequest, GET_TASK_RESULT_TYPE_REF); + } + + public CompletableFuture getTask(String taskId) { + return this.getTask(new McpSchema.GetTaskRequest(taskId, null)); + } + + public CompletableFuture getTaskResult( + McpSchema.GetTaskPayloadRequest getTaskPayloadRequest, + TypeReference resultTypeRef) { + return this.session.sendRequest(McpSchema.METHOD_TASKS_RESULT, getTaskPayloadRequest, resultTypeRef); + } + + public CompletableFuture getTaskResult( + String taskId, + TypeReference resultTypeRef) { + return this.getTaskResult(new McpSchema.GetTaskPayloadRequest(taskId, null), resultTypeRef); + } + + public CompletableFuture listTasks() { + return this.listTasks(null); + } + + public CompletableFuture listTasks(String cursor) { + return this.session.sendRequest(McpSchema.METHOD_TASKS_LIST, + new McpSchema.PaginatedRequest(cursor), + LIST_TASKS_RESULT_TYPE_REF); + } + + public CompletableFuture cancelTask(McpSchema.CancelTaskRequest cancelTaskRequest) { + return this.session.sendRequest(McpSchema.METHOD_TASKS_CANCEL, cancelTaskRequest, CANCEL_TASK_RESULT_TYPE_REF); + } + + public CompletableFuture cancelTask(String taskId) { + Assert.notNull(taskId, "Task ID must not be null"); + if (taskId.trim().isEmpty()) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new IllegalArgumentException("Task ID must not be empty")); + return future; + } + return cancelTask(new McpSchema.CancelTaskRequest(taskId, null)); + } + + // === Side-Channel Helpers === + + @SuppressWarnings("unchecked") + private CompletableFuture sideChannelRequest( + String taskId, String method, McpSchema.Request request, + Class resultType, String inputMessage) { + + String requestId = "sc-" + this.sessionId + "-" + this.sideChannelRequestCounter.getAndIncrement(); + + logger.debug("Side-channel request: taskId={}, method={}, requestId={}", taskId, method, requestId); + + // 1. Enqueue the request for the side-channel handler to pick up + QueuedMessage.Request queuedRequest = new QueuedMessage.Request(requestId, method, request); + + return this.taskMessageQueue.enqueue(taskId, queuedRequest) + .thenCompose(v -> { + // 2. Set task to INPUT_REQUIRED so client polls tasks/result + return this.taskStore.updateTaskStatus(taskId, this.sessionId, + McpSchema.TaskStatus.INPUT_REQUIRED, inputMessage); + }) + .thenCompose(v -> { + // 3. Wait for the response to arrive via the queue + return this.taskMessageQueue.waitForResponse(taskId, requestId, SIDE_CHANNEL_TIMEOUT); + }) + .thenCompose(response -> { + // 4. Restore task to WORKING status + return this.taskStore.updateTaskStatus(taskId, this.sessionId, + McpSchema.TaskStatus.WORKING, null) + .thenApply(v -> (T) response.result()); + }); + } + + private CompletableFuture sideChannelNotification(String taskId, String method, Object notification) { + logger.debug("Side-channel notification: taskId={}, method={}", taskId, method); + QueuedMessage.Notification queuedNotification = new QueuedMessage.Notification(method, notification); + return this.taskMessageQueue.enqueue(taskId, queuedNotification); + } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpServer.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpServer.java index bc01aa32dc7..a5205746dc3 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpServer.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpServer.java @@ -10,6 +10,10 @@ import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; import com.taobao.arthas.mcp.server.protocol.spec.McpStatelessServerTransport; import com.taobao.arthas.mcp.server.protocol.spec.McpStreamableServerTransportProvider; +import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager; +import com.taobao.arthas.mcp.server.task.TaskAwareToolSpecification; +import com.taobao.arthas.mcp.server.task.TaskMessageQueue; +import com.taobao.arthas.mcp.server.task.TaskStore; import com.taobao.arthas.mcp.server.util.Assert; import java.time.Duration; @@ -50,6 +54,8 @@ class StreamableServerNettySpecification { final List tools = new ArrayList<>(); + final List taskTools = new ArrayList<>(); + final Map resources = new HashMap<>(); final List resourceTemplates = new ArrayList<>(); @@ -60,6 +66,12 @@ class StreamableServerNettySpecification { Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + TaskStore taskStore; + + TaskMessageQueue taskMessageQueue; + + ArthasCommandSessionManager sessionManager; + public StreamableServerNettySpecification(McpStreamableServerTransportProvider transportProvider) { this.transportProvider = transportProvider; } @@ -202,15 +214,87 @@ public StreamableServerNettySpecification commandExecutor(CommandExecutor comman this.commandExecutor = commandExecutor; return this; } + + public StreamableServerNettySpecification sessionManager(ArthasCommandSessionManager sessionManager) { + this.sessionManager = sessionManager; + return this; + } + + public StreamableServerNettySpecification taskTool(TaskAwareToolSpecification taskToolSpecification) { + Assert.notNull(taskToolSpecification, "Task tool specification must not be null"); + assertNoDuplicateTool(taskToolSpecification.tool().getName()); + this.taskTools.add(taskToolSpecification); + return this; + } + + public StreamableServerNettySpecification taskTools(List taskToolSpecifications) { + Assert.notNull(taskToolSpecifications, "Task tool specifications list must not be null"); + for (TaskAwareToolSpecification taskTool : taskToolSpecifications) { + assertNoDuplicateTool(taskTool.tool().getName()); + this.taskTools.add(taskTool); + } + return this; + } + + public StreamableServerNettySpecification taskTools(TaskAwareToolSpecification... taskToolSpecifications) { + Assert.notNull(taskToolSpecifications, "Task tool specifications must not be null"); + for (TaskAwareToolSpecification taskTool : taskToolSpecifications) { + assertNoDuplicateTool(taskTool.tool().getName()); + this.taskTools.add(taskTool); + } + return this; + } + + public StreamableServerNettySpecification taskStore(TaskStore taskStore) { + Assert.notNull(taskStore, "Task store must not be null"); + this.taskStore = taskStore; + return this; + } + + public StreamableServerNettySpecification taskMessageQueue(TaskMessageQueue taskMessageQueue) { + Assert.notNull(taskMessageQueue, "Task message queue must not be null"); + this.taskMessageQueue = taskMessageQueue; + return this; + } + + protected void validateTaskConfiguration() { + boolean hasTaskTools = !this.taskTools.isEmpty(); + boolean hasTaskStore = this.taskStore != null; + + if (hasTaskTools && !hasTaskStore) { + throw new IllegalStateException("Task-aware tools registered but no TaskStore configured. " + + "Add a TaskStore via .taskStore(store) or remove task tools."); + } + // Note: Having taskStore without taskTools is allowed (for future dynamic registration) + } + + private void assertNoDuplicateTool(String toolName) { + for (McpServerFeatures.ToolSpecification tool : this.tools) { + if (tool.getTool().getName().equals(toolName)) { + throw new IllegalArgumentException("Duplicate tool name: " + toolName); + } + } + for (TaskAwareToolSpecification taskTool : this.taskTools) { + if (taskTool.tool().getName().equals(toolName)) { + throw new IllegalArgumentException("Duplicate tool name: " + toolName); + } + } + } public McpNettyServer build() { + validateTaskConfiguration(); + ObjectMapper mapper = this.objectMapper != null ? this.objectMapper : JsonParser.getObjectMapper(); Assert.notNull(this.commandExecutor, "CommandExecutor must be set before building"); return new McpNettyServer( this.transportProvider, mapper, this.requestTimeout, new McpServerFeatures.McpServerConfig(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, this.instructions - ), this.commandExecutor + this.taskTools, + this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, this.instructions, + this.taskStore, + this.taskMessageQueue + ), this.commandExecutor, + this.sessionManager ); } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpServerFeatures.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpServerFeatures.java index 9ab7f073cee..e6da213aeea 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpServerFeatures.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/McpServerFeatures.java @@ -6,6 +6,10 @@ import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; import com.taobao.arthas.mcp.server.session.ArthasCommandContext; +import com.taobao.arthas.mcp.server.task.TaskAwareToolSpecification; +import com.taobao.arthas.mcp.server.task.TaskManagerOptions; +import com.taobao.arthas.mcp.server.task.TaskMessageQueue; +import com.taobao.arthas.mcp.server.task.TaskStore; import com.taobao.arthas.mcp.server.util.Assert; import com.taobao.arthas.mcp.server.util.Utils; @@ -25,43 +29,59 @@ public static class McpServerConfig { private final McpSchema.Implementation serverInfo; private final McpSchema.ServerCapabilities serverCapabilities; private final List tools; + private final List taskTools; private final Map resources; private final List resourceTemplates; private final Map prompts; private final List, CompletableFuture>> rootsChangeConsumers; private final String instructions; + private final TaskStore taskStore; + + private final TaskMessageQueue taskMessageQueue; + public McpServerConfig( McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, + List taskTools, Map resources, List resourceTemplates, Map prompts, List, CompletableFuture>> rootsChangeConsumers, - String instructions) { + String instructions, + TaskStore taskStore, + TaskMessageQueue taskMessageQueue) { - Assert.notNull(serverInfo, "The server information cannot be empty"); + Assert.notNull(serverInfo, "Server info must not be null"); - // If serverCapabilities is empty, the appropriate capability configuration - // is automatically built based on the provided capabilities + // 如果 serverCapabilities 为空,根据提供的功能自动构建合适的能力配置 if (serverCapabilities == null) { serverCapabilities = new McpSchema.ServerCapabilities( null, // experimental new McpSchema.ServerCapabilities.LoggingCapabilities(), !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, !Utils.isEmpty(resources) ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, - !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + (!Utils.isEmpty(tools) || !Utils.isEmpty(taskTools)) + ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null, + !Utils.isEmpty(taskTools) ? McpSchema.ServerCapabilities.TaskCapabilities.builder() + .list() + .cancel() + .toolsCall() + .build() : null); } + this.serverInfo = serverInfo; + this.serverCapabilities = serverCapabilities; this.tools = (tools != null) ? tools : Collections.emptyList(); + this.taskTools = (taskTools != null) ? taskTools : Collections.emptyList(); this.resources = (resources != null) ? resources : Collections.emptyMap(); this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : Collections.emptyList(); this.prompts = (prompts != null) ? prompts : Collections.emptyMap(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : Collections.emptyList(); - this.serverInfo = serverInfo; - this.serverCapabilities = serverCapabilities; this.instructions = instructions; + this.taskStore = taskStore; + this.taskMessageQueue = taskMessageQueue; } public McpSchema.Implementation getServerInfo() { @@ -75,6 +95,10 @@ public McpSchema.ServerCapabilities getServerCapabilities() { public List getTools() { return tools; } + + public List getTaskTools() { + return taskTools; + } public Map getResources() { return resources; @@ -96,6 +120,30 @@ public String getInstructions() { return instructions; } + public TaskStore getTaskStore() { + return taskStore; + } + + public TaskMessageQueue getTaskMessageQueue() { + return taskMessageQueue; + } + + public TaskManagerOptions getTaskOptions() { + return buildTaskOptions(this.taskStore, this.taskMessageQueue); + } + + private static TaskManagerOptions buildTaskOptions( + TaskStore taskStore, + TaskMessageQueue taskMessageQueue) { + if (taskStore == null && taskMessageQueue == null) { + return null; + } + return TaskManagerOptions.builder() + .store(taskStore) + .messageQueue(taskMessageQueue) + .build(); + } + public static Builder builder() { return new Builder(); } @@ -104,11 +152,14 @@ public static class Builder { private McpSchema.Implementation serverInfo; private McpSchema.ServerCapabilities serverCapabilities; private final List tools = new ArrayList<>(); + private final List taskTools = new ArrayList<>(); private final Map resources = new HashMap<>(); private final List resourceTemplates = new ArrayList<>(); private final Map prompts = new HashMap<>(); private final List, CompletableFuture>> rootsChangeConsumers = new ArrayList<>(); private String instructions; + private TaskStore taskStore; + private TaskMessageQueue taskMessageQueue; public Builder serverInfo(McpSchema.Implementation serverInfo) { this.serverInfo = serverInfo; @@ -125,6 +176,11 @@ public Builder addTool(ToolSpecification tool) { return this; } + public Builder addTaskTool(com.taobao.arthas.mcp.server.task.TaskAwareToolSpecification taskTool) { + this.taskTools.add(taskTool); + return this; + } + public Builder addResource(String key, ResourceSpecification resource) { this.resources.put(key, resource); return this; @@ -151,9 +207,19 @@ public Builder instructions(String instructions) { return this; } + public Builder taskStore(TaskStore taskStore) { + this.taskStore = taskStore; + return this; + } + + public Builder taskMessageQueue(TaskMessageQueue taskMessageQueue) { + this.taskMessageQueue = taskMessageQueue; + return this; + } + public McpServerConfig build() { - return new McpServerConfig(serverInfo, serverCapabilities, tools, resources, resourceTemplates, prompts, - rootsChangeConsumers, instructions); + return new McpServerConfig(serverInfo, serverCapabilities, tools, taskTools, resources, + resourceTemplates, prompts, rootsChangeConsumers, instructions, taskStore, taskMessageQueue); } } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/transport/NettyStreamableServerTransportProvider.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/transport/NettyStreamableServerTransportProvider.java index 549294a71af..9f6f4be68c2 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/transport/NettyStreamableServerTransportProvider.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/server/transport/NettyStreamableServerTransportProvider.java @@ -56,7 +56,7 @@ private NettyStreamableServerTransportProvider(ObjectMapper objectMapper, String @Override public List protocolVersions() { return Arrays.asList(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, - ProtocolVersions.MCP_2025_06_18); + ProtocolVersions.MCP_2025_06_18, ProtocolVersions.MCP_2025_11_25); } @Override diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/DefaultMcpStreamableServerSessionFactory.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/DefaultMcpStreamableServerSessionFactory.java index 8becafec962..96d9f26f624 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/DefaultMcpStreamableServerSessionFactory.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/DefaultMcpStreamableServerSessionFactory.java @@ -9,6 +9,8 @@ import com.taobao.arthas.mcp.server.protocol.server.McpNotificationHandler; import com.taobao.arthas.mcp.server.protocol.server.McpRequestHandler; import com.taobao.arthas.mcp.server.protocol.server.store.InMemoryEventStore; +import com.taobao.arthas.mcp.server.task.TaskMessageQueue; +import com.taobao.arthas.mcp.server.task.TaskStore; import java.time.Duration; import java.util.Map; @@ -27,17 +29,23 @@ public class DefaultMcpStreamableServerSessionFactory implements McpStreamableSe private final Map> requestHandlers; private final Map notificationHandlers; private final CommandExecutor commandExecutor; + private final TaskStore taskStore; + private final TaskMessageQueue taskMessageQueue; public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout, McpInitRequestHandler mcpInitRequestHandler, Map> requestHandlers, Map notificationHandlers, - CommandExecutor commandExecutor) { + CommandExecutor commandExecutor, + TaskStore taskStore, + TaskMessageQueue taskMessageQueue) { this.requestTimeout = requestTimeout; this.mcpInitRequestHandler = mcpInitRequestHandler; this.requestHandlers = requestHandlers; this.notificationHandlers = notificationHandlers; this.commandExecutor = commandExecutor; + this.taskStore = taskStore; + this.taskMessageQueue = taskMessageQueue; } @Override @@ -53,7 +61,9 @@ public McpStreamableServerSession.McpStreamableServerSessionInit startSession( requestHandlers, notificationHandlers, commandExecutor, - new InMemoryEventStore()); + new InMemoryEventStore(), + taskStore, + taskMessageQueue); // Handle the initialization request CompletableFuture initResult = diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpError.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpError.java index b128477bb93..84e1dd1af97 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpError.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpError.java @@ -5,6 +5,7 @@ package com.taobao.arthas.mcp.server.protocol.spec; import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; +import com.taobao.arthas.mcp.server.util.Assert; /** * Exception class for representing JSON-RPC errors in MCP protocol. @@ -13,19 +14,97 @@ */ public class McpError extends RuntimeException { - private JSONRPCError jsonRpcError; + private JSONRPCError jsonRpcError; - public McpError(JSONRPCError jsonRpcError) { - super(jsonRpcError.getMessage()); - this.jsonRpcError = jsonRpcError; - } + public McpError(JSONRPCError jsonRpcError) { + super(jsonRpcError.getMessage()); + this.jsonRpcError = jsonRpcError; + } - public McpError(Object error) { - super(error.toString()); - } + @Deprecated + public McpError(Object error) { + super(error.toString()); + } - public JSONRPCError getJsonRpcError() { - return jsonRpcError; - } + public JSONRPCError getJsonRpcError() { + return jsonRpcError; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(super.toString()); + if (jsonRpcError != null) { + builder.append("\n"); + builder.append(jsonRpcError.toString()); + } + return builder.toString(); + } + + public static Builder builder(int errorCode) { + return new Builder(errorCode); + } + + public static class Builder { + + private final int code; + + private String message; + + private Object data; + + private Builder(int code) { + this.code = code; + } + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder data(Object data) { + this.data = data; + return this; + } + + public McpError build() { + Assert.hasText(message, "message must not be empty"); + return new McpError(new JSONRPCError(code, message, data)); + } + + } + + public static Throwable findRootCause(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + Throwable rootCause = throwable; + while (rootCause.getCause() != null && rootCause.getCause() != rootCause) { + rootCause = rootCause.getCause(); + } + return rootCause; + } + + public static String aggregateExceptionMessages(Throwable throwable) { + Assert.notNull(throwable, "throwable must not be null"); + + StringBuilder messages = new StringBuilder(); + Throwable current = throwable; + + while (current != null) { + if (messages.length() > 0) { + messages.append("\n Caused by: "); + } + + messages.append(current.getClass().getSimpleName()); + if (current.getMessage() != null) { + messages.append(": ").append(current.getMessage()); + } + + if (current.getCause() == current) { + break; + } + current = current.getCause(); + } + + return messages.toString(); + } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpSchema.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpSchema.java index 081a865d7b3..48b7d3e2305 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpSchema.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpSchema.java @@ -16,9 +16,11 @@ import java.util.*; /** - * Based on the JSON-RPC 2.0 specification - * and the - * Model Context Protocol Schema. + * Based on the JSON-RPC 2.0 + * specification + * and the + * Model Context Protocol Schema. * * @author Yeaury */ @@ -29,7 +31,7 @@ public final class McpSchema { private McpSchema() { } - public static final String LATEST_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_06_18; + public static final String LATEST_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_11_25; public static final String JSONRPC_VERSION = "2.0"; @@ -83,7 +85,28 @@ private McpSchema() { public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; // Elicitation Methods - public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + + // Tasks Methods + public static final String METHOD_TASKS_LIST = "tasks/list"; + public static final String METHOD_TASKS_GET = "tasks/get"; + public static final String METHOD_TASKS_RESULT = "tasks/result"; // Blocking result retrieval + public static final String METHOD_TASKS_CANCEL = "tasks/cancel"; + public static final String METHOD_NOTIFICATION_TASKS_STATUS = "notifications/tasks/status"; + public static final String METHOD_NOTIFICATION_TASKS_LIST_CHANGED = "notifications/tasks/list_changed"; + + // --------------------------- + // Metadata Keys + // --------------------------- + + /** + * 标准的关联任务元数据键。 + * + *

+ * 所有与任务相关的请求、响应和通知都应在 _meta 字段中包含此键, + * 其值为 RelatedTaskMetadata 对象,用于关联消息与其对应的任务。 + */ + public static final String RELATED_TASK_META_KEY = "io.modelcontextprotocol/related-task"; private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); @@ -122,40 +145,49 @@ public static final class ErrorCodes { } - public interface Meta { + public interface Meta { - default Map meta() { - return null; - } + default Map meta() { + return null; + } - } + } public interface Request extends Meta { - default Object progressToken() { - Map metadata = meta(); - if (metadata != null && metadata.containsKey("progressToken")) { - return metadata.get("progressToken"); - } - return null; - } + default Object progressToken() { + Map metadata = meta(); + if (metadata != null && metadata.containsKey("progressToken")) { + return metadata.get("progressToken"); + } + return null; + } } public interface Result extends Meta { } + + public interface ServerTaskPayloadResult extends Result { + } + + public interface ClientTaskPayloadResult extends Result { + } + private static final TypeReference> MAP_TYPE_REF = new TypeReference>() { }; /** * Deserializes a JSON string into a JSONRPCMessage object. + * * @param objectMapper The ObjectMapper instance to use for deserialization - * @param jsonText The JSON string to deserialize + * @param jsonText The JSON string to deserialize * @return A JSONRPCMessage instance using either the {@link JSONRPCRequest}, - * {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes. - * @throws IOException If there's an error during deserialization - * @throws IllegalArgumentException If the JSON structure doesn't match any known - * message type + * {@link JSONRPCNotification}, or {@link JSONRPCResponse} classes. + * @throws IOException If there's an error during deserialization + * @throws IllegalArgumentException If the JSON structure doesn't match any + * known + * message type */ public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper, String jsonText) throws IOException { @@ -167,11 +199,9 @@ public static JSONRPCMessage deserializeJsonRpcMessage(ObjectMapper objectMapper // Determine message type based on specific JSON structure if (map.containsKey("method") && map.containsKey("id")) { return objectMapper.convertValue(map, JSONRPCRequest.class); - } - else if (map.containsKey("method") && !map.containsKey("id")) { + } else if (map.containsKey("method") && !map.containsKey("id")) { return objectMapper.convertValue(map, JSONRPCNotification.class); - } - else if (map.containsKey("result") || map.containsKey("error")) { + } else if (map.containsKey("result") || map.containsKey("error")) { return objectMapper.convertValue(map, JSONRPCResponse.class); } @@ -388,8 +418,10 @@ public String getInstructions() { } /** - * Clients can implement additional features to enrich connected MCP servers with - * additional capabilities. These capabilities can be used to extend the functionality + * Clients can implement additional features to enrich connected MCP servers + * with + * additional capabilities. These capabilities can be used to extend the + * functionality * of the server, or to provide additional information to the server about the * client's capabilities. */ @@ -400,27 +432,28 @@ public static class ClientCapabilities { private final Map experimental; private final RootCapabilities roots; private final Sampling sampling; - private final Elicitation elicitation; + private final Elicitation elicitation; public ClientCapabilities( @JsonProperty("experimental") Map experimental, @JsonProperty("roots") RootCapabilities roots, @JsonProperty("sampling") Sampling sampling, - @JsonProperty("elicitation") Elicitation elicitation) { + @JsonProperty("elicitation") Elicitation elicitation) { this.experimental = experimental; this.roots = roots; this.sampling = sampling; - this.elicitation = elicitation; + this.elicitation = elicitation; } /** - * Roots define the boundaries of where servers can operate within the filesystem, + * Roots define the boundaries of where servers can operate within the + * filesystem, * allowing them to understand which directories and files they have access to. * Servers can request the list of roots from supporting clients and * receive notifications when that list changes. - */ + */ @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) + @JsonIgnoreProperties(ignoreUnknown = true) public static class RootCapabilities { private final Boolean listChanged; @@ -436,7 +469,7 @@ public Boolean getListChanged() { /** * Provides a standardized way for servers to request LLM - * sampling ("completions" or "generations") from language + * sampling ("completions" or "generations") from language * models via clients. This flow allows clients to maintain * control over model access, selection, and permissions * while enabling servers to leverage AI capabilities—with @@ -444,13 +477,13 @@ public Boolean getListChanged() { * image-based interactions and optionally include context * from MCP servers in their prompts. */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonInclude(JsonInclude.Include.NON_ABSENT) public static class Sampling { } - @JsonInclude(JsonInclude.Include.NON_ABSENT) - public static class Elicitation { - } + @JsonInclude(JsonInclude.Include.NON_ABSENT) + public static class Elicitation { + } public Map getExperimental() { return experimental; @@ -464,9 +497,9 @@ public Sampling getSampling() { return sampling; } - public Elicitation getElicitation() { - return elicitation; - } + public Elicitation getElicitation() { + return elicitation; + } public static Builder builder() { return new Builder(); @@ -476,7 +509,7 @@ public static class Builder { private Map experimental; private RootCapabilities roots; private Sampling sampling; - private Elicitation elicitation; + private Elicitation elicitation; public Builder experimental(Map experimental) { this.experimental = experimental; @@ -493,10 +526,10 @@ public Builder sampling() { return this; } - public Builder elicitation() { - this.elicitation = new Elicitation(); - return this; - } + public Builder elicitation() { + this.elicitation = new Elicitation(); + return this; + } public ClientCapabilities build() { return new ClientCapabilities(experimental, roots, sampling, elicitation); @@ -512,18 +545,31 @@ public static class ServerCapabilities { private final PromptCapabilities prompts; private final ResourceCapabilities resources; private final ToolCapabilities tools; + private final TaskCapabilities tasks; public ServerCapabilities( @JsonProperty("experimental") Map experimental, @JsonProperty("logging") LoggingCapabilities logging, @JsonProperty("prompts") PromptCapabilities prompts, @JsonProperty("resources") ResourceCapabilities resources, - @JsonProperty("tools") ToolCapabilities tools) { + @JsonProperty("tools") ToolCapabilities tools, + @JsonProperty("tasks") TaskCapabilities tasks) { this.experimental = experimental; this.logging = logging; this.prompts = prompts; this.resources = resources; this.tools = tools; + this.tasks = tasks; + } + + // Backward compatibility constructor + public ServerCapabilities( + Map experimental, + LoggingCapabilities logging, + PromptCapabilities prompts, + ResourceCapabilities resources, + ToolCapabilities tools) { + this(experimental, logging, prompts, resources, tools, null); } public static Builder builder() { @@ -536,6 +582,7 @@ public static class Builder { private PromptCapabilities prompts; private ResourceCapabilities resources; private ToolCapabilities tools; + private TaskCapabilities tasks; public Builder experimental(Map experimental) { this.experimental = experimental; @@ -562,8 +609,13 @@ public Builder tools(ToolCapabilities tools) { return this; } + public Builder tasks(TaskCapabilities tasks) { + this.tasks = tasks; + return this; + } + public ServerCapabilities build() { - return new ServerCapabilities(experimental, logging, prompts, resources, tools); + return new ServerCapabilities(experimental, logging, prompts, resources, tools, tasks); } } @@ -618,6 +670,106 @@ public Boolean getListChanged() { } } + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class TaskCapabilities { + private final ListTaskCapability list; + private final CancelTaskCapability cancel; + private final TaskRequestCapabilities requests; + + public TaskCapabilities( + @JsonProperty("list") ListTaskCapability list, + @JsonProperty("cancel") CancelTaskCapability cancel, + @JsonProperty("requests") TaskRequestCapabilities requests) { + this.list = list; + this.cancel = cancel; + this.requests = requests; + } + + public ListTaskCapability getList() { + return list; + } + + public CancelTaskCapability getCancel() { + return cancel; + } + + public TaskRequestCapabilities getRequests() { + return requests; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + public static class ListTaskCapability { + } + + @JsonIgnoreProperties(ignoreUnknown = true) + public static class CancelTaskCapability { + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class TaskRequestCapabilities { + private final ToolsTaskCapabilities tools; + + public TaskRequestCapabilities(@JsonProperty("tools") ToolsTaskCapabilities tools) { + this.tools = tools; + } + + public ToolsTaskCapabilities getTools() { + return tools; + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class ToolsTaskCapabilities { + private final CallTaskCapability call; + + public ToolsTaskCapabilities(@JsonProperty("call") CallTaskCapability call) { + this.call = call; + } + + public CallTaskCapability getCall() { + return call; + } + + @JsonIgnoreProperties(ignoreUnknown = true) + public static class CallTaskCapability { + } + } + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private ListTaskCapability list; + private CancelTaskCapability cancel; + private TaskRequestCapabilities requests; + + public Builder list() { + this.list = new ListTaskCapability(); + return this; + } + + public Builder cancel() { + this.cancel = new CancelTaskCapability(); + return this; + } + + public Builder toolsCall() { + this.requests = new TaskRequestCapabilities( + new TaskRequestCapabilities.ToolsTaskCapabilities( + new TaskRequestCapabilities.ToolsTaskCapabilities.CallTaskCapability())); + return this; + } + + public TaskCapabilities build() { + return new TaskCapabilities(list, cancel, requests); + } + } + } + public Map getExperimental() { return experimental; } @@ -637,8 +789,11 @@ public ResourceCapabilities getResources() { public ToolCapabilities getTools() { return tools; } - } + public TaskCapabilities getTasks() { + return tasks; + } + } @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) @@ -664,27 +819,36 @@ public String getVersion() { // Existing Enums and Base Types public enum Role { - @JsonProperty("user") USER, - @JsonProperty("assistant") ASSISTANT + @JsonProperty("user") + USER, + @JsonProperty("assistant") + ASSISTANT } public enum StopReason { - @JsonProperty("stop") STOP, - @JsonProperty("length") LENGTH, - @JsonProperty("content_filter") CONTENT_FILTER + @JsonProperty("stop") + STOP, + @JsonProperty("length") + LENGTH, + @JsonProperty("content_filter") + CONTENT_FILTER } public enum ContextInclusionStrategy { - @JsonProperty("none") NONE, - @JsonProperty("all") ALL, - @JsonProperty("relevant") RELEVANT + @JsonProperty("none") + NONE, + @JsonProperty("all") + ALL, + @JsonProperty("relevant") + RELEVANT } // --------------------------- // Resource Interfaces // --------------------------- /** - * Base for objects that include optional annotations for the client. The client can + * Base for objects that include optional annotations for the client. The client + * can * use annotations to inform how objects are used or displayed */ public interface Annotated { @@ -694,7 +858,8 @@ public interface Annotated { } /** - * Optional annotations for the client. The client can use annotations to inform how + * Optional annotations for the client. The client can use annotations to inform + * how * objects are used or displayed. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -943,7 +1108,8 @@ public Map getMeta() { } /** - * Sent from the client to request resources/updated notifications from the server + * Sent from the client to request resources/updated notifications from the + * server * whenever a particular resource changes. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -986,12 +1152,14 @@ public interface ResourceContents { /** * The URI of this resource. + * * @return the URI of this resource. */ String uri(); /** * The MIME type of this resource. + * * @return the MIME type of this resource. */ String mimeType(); @@ -1035,7 +1203,8 @@ public String getText() { /** * Binary contents of a resource. * - * This must only be set if the resource can actually be represented as binary data + * This must only be set if the resource can actually be represented as binary + * data * (not text). */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -1138,7 +1307,8 @@ public Boolean getRequired() { /** * Describes a message returned as part of a prompt. - * This is similar to `SamplingMessage`, but also supports the embedding of resources + * This is similar to `SamplingMessage`, but also supports the embedding of + * resources * from the MCP server. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -1352,7 +1522,8 @@ public Boolean getAdditionalProperties() { /** * Represents a tool that the server provides. Tools enable servers to expose - * executable functionality to the system. Through these tools, you can interact with + * executable functionality to the system. Through these tools, you can interact + * with * external systems, perform computations, and take actions in the real world. */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -1361,14 +1532,21 @@ public static class Tool { private final String name; private final String description; private final JsonSchema inputSchema; + private final ToolExecution execution; public Tool( @JsonProperty("name") String name, @JsonProperty("description") String description, - @JsonProperty("inputSchema") JsonSchema inputSchema) { + @JsonProperty("inputSchema") JsonSchema inputSchema, + @JsonProperty("execution") ToolExecution execution) { this.name = name; this.description = description; this.inputSchema = inputSchema; + this.execution = execution; + } + + public Tool(String name, String description, JsonSchema inputSchema) { + this(name, description, inputSchema, null); } public String getName() { @@ -1382,13 +1560,100 @@ public String getDescription() { public JsonSchema getInputSchema() { return inputSchema; } + + public ToolExecution getExecution() { + return execution; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String name; + private String description; + private JsonSchema inputSchema; + private ToolExecution execution; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder inputSchema(JsonSchema inputSchema) { + this.inputSchema = inputSchema; + return this; + } + + public Builder execution(ToolExecution execution) { + this.execution = execution; + return this; + } + + public Builder taskSupport(String taskSupport) { + TaskSupportMode mode = null; + if (taskSupport != null) { + mode = TaskSupportMode.valueOf(taskSupport.toUpperCase()); + } + this.execution = new ToolExecution(mode); + return this; + } + + public Tool build() { + if (name == null || name.trim().isEmpty()) { + throw new IllegalArgumentException("Tool name must not be null or empty"); + } + if (inputSchema == null) { + throw new IllegalArgumentException("Tool inputSchema must not be null"); + } + return new Tool(name, description, inputSchema, execution); + } + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public enum TaskSupportMode { + + @JsonProperty("forbidden") + FORBIDDEN, + + @JsonProperty("optional") + OPTIONAL, + + @JsonProperty("required") + REQUIRED + } + + public static class ToolExecution { + private final TaskSupportMode taskSupport; + + public ToolExecution(@JsonProperty("taskSupport") TaskSupportMode taskSupport) { + this.taskSupport = taskSupport; + } + + public TaskSupportMode getTaskSupport() { + return taskSupport; + } + + public boolean supportsTask() { + return taskSupport == TaskSupportMode.OPTIONAL || taskSupport == TaskSupportMode.REQUIRED; + } + + public boolean requiresTask() { + return taskSupport == TaskSupportMode.REQUIRED; + } } private static JsonSchema parseSchema(String schema) { try { return OBJECT_MAPPER.readValue(schema, JsonSchema.class); - } - catch (IOException e) { + } catch (IOException e) { throw new IllegalArgumentException("Invalid schema: " + schema, e); } } @@ -1402,25 +1667,31 @@ public static class CallToolRequest implements Request { private final String name; private final Map arguments; private final Map meta; + private final TaskMetadata task; public CallToolRequest( @JsonProperty("name") String name, @JsonProperty("arguments") Map arguments, - @JsonProperty("_meta") Map meta) { + @JsonProperty("_meta") Map meta, + @JsonProperty("task") TaskMetadata task) { this.name = name; this.arguments = arguments; this.meta = meta; + this.task = task; } private static Map parseJsonArguments(String jsonArguments) { try { return OBJECT_MAPPER.readValue(jsonArguments, MAP_TYPE_REF); - } - catch (IOException e) { + } catch (IOException e) { throw new IllegalArgumentException("Invalid arguments: " + jsonArguments, e); } } + public CallToolRequest(String name, Map arguments, Map meta) { + this(name, arguments, meta, null); + } + public String getName() { return name; } @@ -1434,6 +1705,14 @@ public Map meta() { return meta; } + public Map getMeta() { + return meta; + } + + public TaskMetadata getTask() { + return task; + } + public static Builder builder() { return new Builder(); } @@ -1446,6 +1725,8 @@ public static class Builder { private Map meta; + private TaskMetadata task; + public Builder name(String name) { this.name = name; return this; @@ -1466,6 +1747,16 @@ public Builder meta(Map meta) { return this; } + public Builder task(TaskMetadata task) { + this.task = task; + return this; + } + + public Builder taskWithTtl(Long ttl) { + this.task = new TaskMetadata(ttl); + return this; + } + public Builder progressToken(String progressToken) { if (this.meta == null) { this.meta = new HashMap<>(); @@ -1476,107 +1767,109 @@ public Builder progressToken(String progressToken) { public CallToolRequest build() { Assert.hasText(name, "name must not be empty"); - return new CallToolRequest(name, arguments, meta); + return new CallToolRequest(name, arguments, meta, task); } } } /** * The server's response to a tools/call request from the client. + * + *

+ * 实现 {@link ServerTaskPayloadResult},可作为服务端任务的结果类型。 */ - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public static class CallToolResult implements Result { - private final List content; - private final Boolean isError; - private final Map meta; - - public CallToolResult( - @JsonProperty("content") List content, - @JsonProperty("isError") Boolean isError, - @JsonProperty("_meta") Map meta) { - this.content = content; - this.isError = isError; - this.meta = meta; - } - - public CallToolResult(String content, Boolean isError, Map meta) { - this(Collections.singletonList(new TextContent(content)), isError, meta); - } - - public List getContent() { - return content; - } - - public Boolean getIsError() { - return isError; - } - - @Override - public Map meta() { - return meta; - } - - public Map getMeta() { - return meta(); - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private List content = new ArrayList<>(); - private Boolean isError; - private Map meta; - - public Builder content(List content) { - Assert.notNull(content, "content must not be null"); - this.content = content; - return this; - } - - public Builder textContent(List textContent) { - Assert.notNull(textContent, "textContent must not be null"); - textContent.stream() - .map(TextContent::new) - .forEach(this.content::add); - return this; - } - - public Builder addContent(Content contentItem) { - Assert.notNull(contentItem, "contentItem must not be null"); - if (this.content == null) { - this.content = new ArrayList<>(); - } - this.content.add(contentItem); - return this; - } - - public Builder addTextContent(String text) { - Assert.notNull(text, "text must not be null"); - return addContent(new TextContent(text)); - } - - public Builder isError(Boolean isError) { - Assert.notNull(isError, "isError must not be null"); - this.isError = isError; - return this; - } - - public Builder meta(Map meta) { - this.meta = meta; - return this; - } - - public CallToolResult build() { - return new CallToolResult(content, isError, meta); - } - } - } - - - // --------------------------- + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class CallToolResult implements ServerTaskPayloadResult { + private final List content; + private final Boolean isError; + private final Map meta; + + public CallToolResult( + @JsonProperty("content") List content, + @JsonProperty("isError") Boolean isError, + @JsonProperty("_meta") Map meta) { + this.content = content; + this.isError = isError; + this.meta = meta; + } + + public CallToolResult(String content, Boolean isError, Map meta) { + this(Collections.singletonList(new TextContent(content)), isError, meta); + } + + public List getContent() { + return content; + } + + public Boolean getIsError() { + return isError; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private List content = new ArrayList<>(); + private Boolean isError; + private Map meta; + + public Builder content(List content) { + Assert.notNull(content, "content must not be null"); + this.content = content; + return this; + } + + public Builder textContent(List textContent) { + Assert.notNull(textContent, "textContent must not be null"); + textContent.stream() + .map(TextContent::new) + .forEach(this.content::add); + return this; + } + + public Builder addContent(Content contentItem) { + Assert.notNull(contentItem, "contentItem must not be null"); + if (this.content == null) { + this.content = new ArrayList<>(); + } + this.content.add(contentItem); + return this; + } + + public Builder addTextContent(String text) { + Assert.notNull(text, "text must not be null"); + return addContent(new TextContent(text)); + } + + public Builder isError(Boolean isError) { + Assert.notNull(isError, "isError must not be null"); + this.isError = isError; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public CallToolResult build() { + return new CallToolResult(content, isError, meta); + } + } + } + + // --------------------------- // Sampling Interfaces // --------------------------- @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -1673,7 +1966,7 @@ public CreateMessageRequest( @JsonProperty("temperature") Double temperature, @JsonProperty("maxTokens") int maxTokens, @JsonProperty("stopSequences") List stopSequences, - @JsonProperty("_meta") Map meta) { + @JsonProperty("_meta") Map meta) { this.messages = messages; this.modelPreferences = modelPreferences; this.systemPrompt = systemPrompt; @@ -1712,10 +2005,10 @@ public List getStopSequences() { return stopSequences; } - @Override - public Map meta() { - return meta; - } + @Override + public Map meta() { + return meta; + } } @JsonInclude(JsonInclude.Include.NON_ABSENT) @@ -1754,221 +2047,223 @@ public StopReason getStopReason() { } } - // Elicitation - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public static class ElicitRequest implements Request { - - private final String message; - private final Map requestedSchema; - private final Map meta; - - // Constructor - public ElicitRequest( - @JsonProperty("message") String message, - @JsonProperty("requestedSchema") Map requestedSchema, - @JsonProperty("_meta") Map meta) { - this.message = message; - this.requestedSchema = requestedSchema; - this.meta = meta; - } - - public String getMessage() { - return message; - } - - public Map getRequestedSchema() { - return requestedSchema; - } - - @Override - public Map meta() { - return meta; - } - - public Map getMeta() { - return meta(); - } - - // Backwards compatibility constructor - public ElicitRequest(String message, Map requestedSchema) { - this(message, requestedSchema, null); - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private String message; - private Map requestedSchema; - private Map meta; - - public Builder message(String message) { - this.message = message; - return this; - } - - public Builder requestedSchema(Map requestedSchema) { - this.requestedSchema = requestedSchema; - return this; - } - - public Builder meta(Map meta) { - this.meta = meta; - return this; - } - - public Builder progressToken(Object progressToken) { - if (this.meta == null) { - this.meta = new HashMap<>(); - } - this.meta.put("progressToken", progressToken); - return this; - } - - public ElicitRequest build() { - return new ElicitRequest(message, requestedSchema, meta); - } - } - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public static class ElicitResult implements Result { - - private final Action action; - private final Map content; - private final Map meta; - - public enum Action { - @JsonProperty("accept") ACCEPT, - @JsonProperty("decline") DECLINE, - @JsonProperty("cancel") CANCEL - } - - // Constructor - public ElicitResult( - @JsonProperty("action") Action action, - @JsonProperty("content") Map content, - @JsonProperty("_meta") Map meta) { - this.action = action; - this.content = content; - this.meta = meta; - } - - public Action getAction() { - return action; - } - - public Map getContent() { - return content; - } - - @Override - public Map meta() { - return meta; - } - - public Map getMeta() { - return meta(); - } - - // Backwards compatibility constructor - public ElicitResult(Action action, Map content) { - this(action, content, null); - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private Action action; - private Map content; - private Map meta; - - public Builder action(Action action) { - this.action = action; - return this; - } - - public Builder content(Map content) { - this.content = content; - return this; - } - - public Builder meta(Map meta) { - this.meta = meta; - return this; - } - - public ElicitResult build() { - return new ElicitResult(action, content, meta); - } - } - } - - - // --------------------------- - // Pagination Interfaces - // --------------------------- + // Elicitation @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - public static class PaginatedRequest { - private final String cursor; + public static class ElicitRequest implements Request { - public PaginatedRequest( - @JsonProperty("cursor") String cursor) { - this.cursor = cursor; - } + private final String message; + private final Map requestedSchema; + private final Map meta; - public String getCursor() { - return cursor; + // Constructor + public ElicitRequest( + @JsonProperty("message") String message, + @JsonProperty("requestedSchema") Map requestedSchema, + @JsonProperty("_meta") Map meta) { + this.message = message; + this.requestedSchema = requestedSchema; + this.meta = meta; } - } - - @JsonInclude(JsonInclude.Include.NON_ABSENT) - @JsonIgnoreProperties(ignoreUnknown = true) - public static class PaginatedResult { - private final String nextCursor; - public PaginatedResult( - @JsonProperty("nextCursor") String nextCursor) { - this.nextCursor = nextCursor; + public String getMessage() { + return message; } - public String getNextCursor() { - return nextCursor; + public Map getRequestedSchema() { + return requestedSchema; } - } - // --------------------------- - // Progress and Logging - // --------------------------- - @JsonIgnoreProperties(ignoreUnknown = true) - public static class ProgressNotification { - private final String progressToken; - private final double progress; - private final Double total; + @Override + public Map meta() { + return meta; + } - public ProgressNotification( - @JsonProperty("progressToken") String progressToken, - @JsonProperty("progress") double progress, - @JsonProperty("total") Double total) { - this.progressToken = progressToken; - this.progress = progress; - this.total = total; + public Map getMeta() { + return meta(); } - public String getProgressToken() { - return progressToken; + // Backwards compatibility constructor + public ElicitRequest(String message, Map requestedSchema) { + this(message, requestedSchema, null); } - public double getProgress() { - return progress; + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String message; + private Map requestedSchema; + private Map meta; + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder requestedSchema(Map requestedSchema) { + this.requestedSchema = requestedSchema; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public Builder progressToken(Object progressToken) { + if (this.meta == null) { + this.meta = new HashMap<>(); + } + this.meta.put("progressToken", progressToken); + return this; + } + + public ElicitRequest build() { + return new ElicitRequest(message, requestedSchema, meta); + } + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class ElicitResult implements Result { + + private final Action action; + private final Map content; + private final Map meta; + + public enum Action { + @JsonProperty("accept") + ACCEPT, + @JsonProperty("decline") + DECLINE, + @JsonProperty("cancel") + CANCEL + } + + // Constructor + public ElicitResult( + @JsonProperty("action") Action action, + @JsonProperty("content") Map content, + @JsonProperty("_meta") Map meta) { + this.action = action; + this.content = content; + this.meta = meta; + } + + public Action getAction() { + return action; + } + + public Map getContent() { + return content; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + + // Backwards compatibility constructor + public ElicitResult(Action action, Map content) { + this(action, content, null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Action action; + private Map content; + private Map meta; + + public Builder action(Action action) { + this.action = action; + return this; + } + + public Builder content(Map content) { + this.content = content; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public ElicitResult build() { + return new ElicitResult(action, content, meta); + } + } + } + + // --------------------------- + // Pagination Interfaces + // --------------------------- + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class PaginatedRequest { + private final String cursor; + + public PaginatedRequest( + @JsonProperty("cursor") String cursor) { + this.cursor = cursor; + } + + public String getCursor() { + return cursor; + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class PaginatedResult { + private final String nextCursor; + + public PaginatedResult( + @JsonProperty("nextCursor") String nextCursor) { + this.nextCursor = nextCursor; + } + + public String getNextCursor() { + return nextCursor; + } + } + + // --------------------------- + // Progress and Logging + // --------------------------- + @JsonIgnoreProperties(ignoreUnknown = true) + public static class ProgressNotification { + private final String progressToken; + private final double progress; + private final Double total; + + public ProgressNotification( + @JsonProperty("progressToken") String progressToken, + @JsonProperty("progress") double progress, + @JsonProperty("total") Double total) { + this.progressToken = progressToken; + this.progress = progress; + this.total = total; + } + + public String getProgressToken() { + return progressToken; + } + + public double getProgress() { + return progress; } public Double getTotal() { @@ -1977,9 +2272,11 @@ public Double getTotal() { } /** - * The Model Context Protocol (MCP) provides a standardized way for servers to send + * The Model Context Protocol (MCP) provides a standardized way for servers to + * send * structured log messages to clients. Clients can control logging verbosity by - * setting minimum log levels, with servers sending notifications containing severity + * setting minimum log levels, with servers sending notifications containing + * severity * levels, optional logger names, and arbitrary JSON-serializable data. */ @JsonIgnoreProperties(ignoreUnknown = true) @@ -2040,14 +2337,22 @@ public LoggingMessageNotification build() { } public enum LoggingLevel { - @JsonProperty("debug") DEBUG(0), - @JsonProperty("info") INFO(1), - @JsonProperty("notice") NOTICE(2), - @JsonProperty("warning") WARNING(3), - @JsonProperty("error") ERROR(4), - @JsonProperty("critical") CRITICAL(5), - @JsonProperty("alert") ALERT(6), - @JsonProperty("emergency") EMERGENCY(7); + @JsonProperty("debug") + DEBUG(0), + @JsonProperty("info") + INFO(1), + @JsonProperty("notice") + NOTICE(2), + @JsonProperty("warning") + WARNING(3), + @JsonProperty("error") + ERROR(4), + @JsonProperty("critical") + CRITICAL(5), + @JsonProperty("alert") + ALERT(6), + @JsonProperty("emergency") + EMERGENCY(7); private final int level; @@ -2076,7 +2381,8 @@ public LoggingLevel getLevel() { } /** - * Notification for sending intermediate results during streaming tool execution. + * Notification for sending intermediate results during streaming tool + * execution. * This allows tools to send partial results to clients in real-time. */ @JsonIgnoreProperties(ignoreUnknown = true) @@ -2105,10 +2411,14 @@ public Object getData() { // Autocomplete // --------------------------- public enum CompleteArgument { - @JsonProperty("name") NAME, - @JsonProperty("description") DESCRIPTION, - @JsonProperty("uri") URI, - @JsonProperty("mimeType") MIME_TYPE + @JsonProperty("name") + NAME, + @JsonProperty("description") + DESCRIPTION, + @JsonProperty("uri") + URI, + @JsonProperty("mimeType") + MIME_TYPE } public static class CompleteRequest implements Request { @@ -2226,11 +2536,9 @@ public interface Content { default String type() { if (this instanceof TextContent) { return "text"; - } - else if (this instanceof ImageContent) { + } else if (this instanceof ImageContent) { return "image"; - } - else if (this instanceof EmbeddedResource) { + } else if (this instanceof EmbeddedResource) { return "resource"; } throw new IllegalArgumentException("Unknown content type: " + this); @@ -2387,4 +2695,829 @@ public List getRoots() { } } + // --------------------------- + // Tasks + // --------------------------- + + public enum TaskStatus { + @JsonProperty("working") + WORKING, + @JsonProperty("input_required") + INPUT_REQUIRED, + @JsonProperty("completed") + COMPLETED, + @JsonProperty("failed") + FAILED, + @JsonProperty("cancelled") + CANCELLED; + + public boolean isTerminal() { + return this == COMPLETED || this == FAILED || this == CANCELLED; + } + } + + /** + * Task 表示一个长时间运行的操作的执行状态。 + * + * @see MCP + * Tasks Specification + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class Task { + private final String taskId; + private final TaskStatus status; + private final String statusMessage; + private final String createdAt; + private final String lastUpdatedAt; + private final Long ttl; + private final Long pollInterval; + + public Task( + @JsonProperty("taskId") String taskId, + @JsonProperty("status") TaskStatus status, + @JsonProperty("statusMessage") String statusMessage, + @JsonProperty("createdAt") String createdAt, + @JsonProperty("lastUpdatedAt") String lastUpdatedAt, + @JsonProperty("ttl") Long ttl, + @JsonProperty("pollInterval") Long pollInterval) { + this.taskId = taskId; + this.status = status; + this.statusMessage = statusMessage; + this.createdAt = createdAt; + this.lastUpdatedAt = lastUpdatedAt; + this.ttl = ttl; + this.pollInterval = pollInterval; + } + + public String getTaskId() { + return taskId; + } + + public TaskStatus getStatus() { + return status; + } + + public String getStatusMessage() { + return statusMessage; + } + + public String getCreatedAt() { + return createdAt; + } + + public String getLastUpdatedAt() { + return lastUpdatedAt; + } + + public Long getTtl() { + return ttl; + } + + public Long getPollInterval() { + return pollInterval; + } + + public boolean isTerminal() { + return status.isTerminal(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String taskId; + private TaskStatus status; + private String statusMessage; + private String createdAt; + private String lastUpdatedAt; + private Long ttl; + private Long pollInterval; + + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder status(TaskStatus status) { + this.status = status; + return this; + } + + public Builder statusMessage(String statusMessage) { + this.statusMessage = statusMessage; + return this; + } + + public Builder createdAt(String createdAt) { + this.createdAt = createdAt; + return this; + } + + public Builder lastUpdatedAt(String lastUpdatedAt) { + this.lastUpdatedAt = lastUpdatedAt; + return this; + } + + public Builder ttl(Long ttl) { + this.ttl = ttl; + return this; + } + + public Builder pollInterval(Long pollInterval) { + this.pollInterval = pollInterval; + return this; + } + + public Task build() { + if (taskId == null || taskId.trim().isEmpty()) { + throw new IllegalArgumentException("Task taskId must not be null or empty"); + } + if (status == null) { + throw new IllegalArgumentException("Task status must not be null"); + } + if (createdAt == null || createdAt.trim().isEmpty()) { + throw new IllegalArgumentException("Task createdAt must not be null or empty"); + } + if (lastUpdatedAt == null || lastUpdatedAt.trim().isEmpty()) { + throw new IllegalArgumentException("Task lastUpdatedAt must not be null or empty"); + } + return new Task(taskId, status, statusMessage, createdAt, lastUpdatedAt, ttl, pollInterval); + } + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class TaskMetadata { + private final Long ttl; + + public TaskMetadata(@JsonProperty("ttl") Long ttl) { + this.ttl = ttl; + } + + public Long getTtl() { + return ttl; + } + + @JsonIgnore + public java.time.Duration ttlAsDuration() { + return ttl != null ? java.time.Duration.ofMillis(ttl) : null; + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class RelatedTaskMetadata { + private final String taskId; + + public RelatedTaskMetadata(@JsonProperty("taskId") String taskId) { + this.taskId = taskId; + } + + public String getTaskId() { + return taskId; + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class CreateTaskResult implements Result { + private final Task task; + private final Map meta; + + public CreateTaskResult( + @JsonProperty("task") Task task, + @JsonProperty("_meta") Map meta) { + this.task = task; + this.meta = meta; + } + + public Task getTask() { + return task; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class ListTasksResult implements Result { + private final List tasks; + private final String nextCursor; + private final Map meta; + + public ListTasksResult( + @JsonProperty("tasks") List tasks, + @JsonProperty("nextCursor") String nextCursor, + @JsonProperty("_meta") Map meta) { + this.tasks = tasks; + this.nextCursor = nextCursor; + this.meta = meta; + } + + public ListTasksResult(List tasks, String nextCursor) { + this(tasks, nextCursor, null); + } + + public List getTasks() { + return tasks; + } + + public String getNextCursor() { + return nextCursor; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class ListTasksRequest implements Request { + private final Map meta; + + public ListTasksRequest(@JsonProperty("_meta") Map meta) { + this.meta = meta; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class GetTaskRequest implements Request { + private final String taskId; + private final Map meta; + + public GetTaskRequest( + @JsonProperty("taskId") String taskId, + @JsonProperty("_meta") Map meta) { + this.taskId = taskId; + this.meta = meta; + } + + public String getTaskId() { + return taskId; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class GetTaskResult implements Result { + private final String taskId; + private final TaskStatus status; + private final String statusMessage; + private final String createdAt; + private final String lastUpdatedAt; + private final Long ttl; + private final Long pollInterval; + private final Map meta; + + public GetTaskResult( + @JsonProperty("taskId") String taskId, + @JsonProperty("status") TaskStatus status, + @JsonProperty("statusMessage") String statusMessage, + @JsonProperty("createdAt") String createdAt, + @JsonProperty("lastUpdatedAt") String lastUpdatedAt, + @JsonProperty("ttl") Long ttl, + @JsonProperty("pollInterval") Long pollInterval, + @JsonProperty("_meta") Map meta) { + this.taskId = taskId; + this.status = status; + this.statusMessage = statusMessage; + this.createdAt = createdAt; + this.lastUpdatedAt = lastUpdatedAt; + this.ttl = ttl; + this.pollInterval = pollInterval; + this.meta = meta; + } + + public String getTaskId() { + return taskId; + } + + public TaskStatus getStatus() { + return status; + } + + public String getStatusMessage() { + return statusMessage; + } + + public String getCreatedAt() { + return createdAt; + } + + public String getLastUpdatedAt() { + return lastUpdatedAt; + } + + public Long getTtl() { + return ttl; + } + + public Long getPollInterval() { + return pollInterval; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + + public static GetTaskResult fromTask(Task task) { + return new GetTaskResult( + task.getTaskId(), + task.getStatus(), + task.getStatusMessage(), + task.getCreatedAt(), + task.getLastUpdatedAt(), + task.getTtl(), + task.getPollInterval(), + null); + } + + public Task toTask() { + return Task.builder() + .taskId(taskId) + .status(status) + .statusMessage(statusMessage) + .createdAt(createdAt) + .lastUpdatedAt(lastUpdatedAt) + .ttl(ttl) + .pollInterval(pollInterval) + .build(); + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class GetTaskPayloadRequest implements Request { + private final String taskId; + private final Map meta; + + public GetTaskPayloadRequest( + @JsonProperty("taskId") String taskId, + @JsonProperty("_meta") Map meta) { + this.taskId = taskId; + this.meta = meta; + } + + public String getTaskId() { + return taskId; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class GetTaskPayloadResult implements Result { + private final Object payload; + private final Map meta; + + public GetTaskPayloadResult( + @JsonProperty("payload") Object payload, + @JsonProperty("_meta") Map meta) { + this.payload = payload; + this.meta = meta; + } + + public Object getPayload() { + return payload; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class CancelTaskRequest implements Request { + private final String taskId; + private final Map meta; + + public CancelTaskRequest( + @JsonProperty("taskId") String taskId, + @JsonProperty("_meta") Map meta) { + this.taskId = taskId; + this.meta = meta; + } + + public String getTaskId() { + return taskId; + } + + @Override + public Map meta() { + return meta; + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class CancelTaskResult implements Result { + private final String taskId; + private final TaskStatus status; + private final String statusMessage; + private final String createdAt; + private final String lastUpdatedAt; + private final Long ttl; + private final Long pollInterval; + private final Map meta; + + public CancelTaskResult( + @JsonProperty("taskId") String taskId, + @JsonProperty("status") TaskStatus status, + @JsonProperty("statusMessage") String statusMessage, + @JsonProperty("createdAt") String createdAt, + @JsonProperty("lastUpdatedAt") String lastUpdatedAt, + @JsonProperty("ttl") Long ttl, + @JsonProperty("pollInterval") Long pollInterval, + @JsonProperty("_meta") Map meta) { + this.taskId = taskId; + this.status = status; + this.statusMessage = statusMessage; + this.createdAt = createdAt; + this.lastUpdatedAt = lastUpdatedAt; + this.ttl = ttl; + this.pollInterval = pollInterval; + this.meta = meta; + } + + public String getTaskId() { + return taskId; + } + + public TaskStatus getStatus() { + return status; + } + + public String getStatusMessage() { + return statusMessage; + } + + public String getCreatedAt() { + return createdAt; + } + + public String getLastUpdatedAt() { + return lastUpdatedAt; + } + + public Long getTtl() { + return ttl; + } + + public Long getPollInterval() { + return pollInterval; + } + + @Override + public Map meta() { + return meta; + } + + public Map getMeta() { + return meta(); + } + + public static CancelTaskResult fromTask(Task task) { + return new CancelTaskResult( + task.getTaskId(), + task.getStatus(), + task.getStatusMessage(), + task.getCreatedAt(), + task.getLastUpdatedAt(), + task.getTtl(), + task.getPollInterval(), + null // meta + ); + } + } + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public static class TaskStatusNotification { + private final String taskId; + private final TaskStatus status; + private final String statusMessage; + private final String createdAt; + private final String lastUpdatedAt; + private final Long ttl; + private final Long pollInterval; + private final Map meta; + + @JsonCreator + public TaskStatusNotification( + @JsonProperty("taskId") String taskId, + @JsonProperty("status") TaskStatus status, + @JsonProperty("statusMessage") String statusMessage, + @JsonProperty("createdAt") String createdAt, + @JsonProperty("lastUpdatedAt") String lastUpdatedAt, + @JsonProperty("ttl") Long ttl, + @JsonProperty("pollInterval") Long pollInterval, + @JsonProperty("_meta") Map meta) { + this.taskId = taskId; + this.status = status; + this.statusMessage = statusMessage; + this.createdAt = createdAt; + this.lastUpdatedAt = lastUpdatedAt; + this.ttl = ttl; + this.pollInterval = pollInterval; + this.meta = meta; + } + + public String getTaskId() { + return taskId; + } + + public TaskStatus getStatus() { + return status; + } + + public String getStatusMessage() { + return statusMessage; + } + + public String getCreatedAt() { + return createdAt; + } + + public String getLastUpdatedAt() { + return lastUpdatedAt; + } + + public Long getTtl() { + return ttl; + } + + public Long getPollInterval() { + return pollInterval; + } + + @JsonProperty("_meta") + public Map getMeta() { + return meta; + } + + public boolean isTerminal() { + return status != null && status.isTerminal(); + } + + public static TaskStatusNotification fromTask(Task task) { + return new TaskStatusNotification( + task.getTaskId(), + task.getStatus(), + task.getStatusMessage(), + task.getCreatedAt(), + task.getLastUpdatedAt(), + task.getTtl(), + task.getPollInterval(), + null); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String taskId; + private TaskStatus status; + private String statusMessage; + private String createdAt; + private String lastUpdatedAt; + private Long ttl; + private Long pollInterval; + private Map meta; + + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder status(TaskStatus status) { + this.status = status; + return this; + } + + public Builder statusMessage(String statusMessage) { + this.statusMessage = statusMessage; + return this; + } + + public Builder createdAt(String createdAt) { + this.createdAt = createdAt; + return this; + } + + public Builder lastUpdatedAt(String lastUpdatedAt) { + this.lastUpdatedAt = lastUpdatedAt; + return this; + } + + public Builder ttl(Long ttl) { + this.ttl = ttl; + return this; + } + + public Builder pollInterval(Long pollInterval) { + this.pollInterval = pollInterval; + return this; + } + + public Builder meta(Map meta) { + this.meta = meta; + return this; + } + + public TaskStatusNotification build() { + return new TaskStatusNotification(taskId, status, statusMessage, + createdAt, lastUpdatedAt, ttl, pollInterval, meta); + } + } + } + + // ======================== + // ResponseMessage 类型 + // ======================== + + /** + * 流式响应消息接口,用于任务增强请求的 SSE 流式推送。 + * + *

消息类型: + *

    + *
  • {@link TaskCreatedMessage} — 任务创建后的第一条消息
  • + *
  • {@link TaskStatusMessage} — 轮询期间的状态更新
  • + *
  • {@link ResultMessage} — 最终成功结果
  • + *
  • {@link ErrorMessage} — 错误消息(终态)
  • + *
+ * + * @param 预期的结果类型 + */ + public interface ResponseMessage { + + /** + * 返回消息类型标识符。 + * @return 类型字符串("taskCreated"、"taskStatus"、"result" 或 "error") + */ + String type(); + } + + /** + * 表示任务已创建的消息。这是任务增强请求的第一条消息。 + * + * @param 预期的结果类型 + */ + public static class TaskCreatedMessage implements ResponseMessage { + + private final Task task; + + private TaskCreatedMessage(Task task) { + this.task = task; + } + + public Task getTask() { + return this.task; + } + + @Override + public String type() { + return "taskCreated"; + } + + public static TaskCreatedMessage of(Task task) { + return new TaskCreatedMessage(task); + } + } + + /** + * 表示任务状态更新的消息。在轮询等待终态期间周期性产生。 + * + * @param 预期的结果类型 + */ + public static class TaskStatusMessage implements ResponseMessage { + + private final Task task; + + private TaskStatusMessage(Task task) { + this.task = task; + } + + public Task getTask() { + return this.task; + } + + @Override + public String type() { + return "taskStatus"; + } + + public static TaskStatusMessage of(Task task) { + return new TaskStatusMessage(task); + } + } + + /** + * 表示最终成功结果的消息。这是终态消息,之后不会再有消息。 + * + * @param 结果类型 + */ + public static class ResultMessage implements ResponseMessage { + + private final T result; + + private ResultMessage(T result) { + this.result = result; + } + + public T getResult() { + return this.result; + } + + @Override + public String type() { + return "result"; + } + + public static ResultMessage of(T result) { + return new ResultMessage(result); + } + } + + /** + * 表示发生错误的消息。这是终态消息,之后不会再有消息。 + * + * @param 预期的结果类型 + */ + public static class ErrorMessage implements ResponseMessage { + + private final McpError error; + + private ErrorMessage(McpError error) { + this.error = error; + } + + public McpError getError() { + return this.error; + } + + @Override + public String type() { + return "error"; + } + + public static ErrorMessage of(McpError error) { + return new ErrorMessage(error); + } + } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStatelessServerTransport.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStatelessServerTransport.java index 11660a4945b..0eb4e7c1cd6 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStatelessServerTransport.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStatelessServerTransport.java @@ -21,7 +21,8 @@ default void close() { CompletableFuture closeGracefully(); default List protocolVersions() { - return Arrays.asList(ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18); + return Arrays.asList(ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18, + ProtocolVersions.MCP_2025_11_25); } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStreamableServerSession.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStreamableServerSession.java index e40ded4dc90..ac41485d1d4 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStreamableServerSession.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/McpStreamableServerSession.java @@ -13,6 +13,8 @@ import com.taobao.arthas.mcp.server.protocol.server.McpTransportContext; import com.taobao.arthas.mcp.server.session.ArthasCommandContext; import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager; +import com.taobao.arthas.mcp.server.task.TaskMessageQueue; +import com.taobao.arthas.mcp.server.task.TaskStore; import com.taobao.arthas.mcp.server.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,11 +63,16 @@ public class McpStreamableServerSession implements McpSession { private final EventStore eventStore; + private final TaskStore taskStore; + + private final TaskMessageQueue taskMessageQueue; + public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, Duration requestTimeout, Map> requestHandlers, Map notificationHandlers, - CommandExecutor commandExecutor, EventStore eventStore) { + CommandExecutor commandExecutor, EventStore eventStore, + TaskStore taskStore, TaskMessageQueue taskMessageQueue) { this.id = id; this.missingMcpTransportSession = new MissingMcpTransportSession(id); this.listeningStreamRef = new AtomicReference<>(this.missingMcpTransportSession); @@ -77,6 +84,8 @@ public McpStreamableServerSession(String id, McpSchema.ClientCapabilities client this.commandExecutor = commandExecutor; this.commandSessionManager = new ArthasCommandSessionManager(commandExecutor); this.eventStore = eventStore; + this.taskStore = taskStore; + this.taskMessageQueue = taskMessageQueue; } /** @@ -171,12 +180,33 @@ public CompletableFuture responseStream(McpSchema.JSONRPCRequest jsonrpcRe .thenCompose(v -> transport.closeGracefully()); } ArthasCommandContext commandContext = createCommandContext(transportContext.get(MCP_AUTH_SUBJECT_KEY)); - + return requestHandler .handle(new McpNettyServerExchange(this.id, stream, clientCapabilities.get(), - clientInfo.get(), transportContext), commandContext, jsonrpcRequest.getParams()) - .thenApply(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, - jsonrpcRequest.getId(), result, null)) + clientInfo.get(), transportContext, taskMessageQueue, taskStore), + commandContext, jsonrpcRequest.getParams()) + .handle((result, ex) -> { + if (ex != null) { + Throwable cause = ex; + if (cause instanceof java.util.concurrent.CompletionException) { + cause = cause.getCause(); + } + + McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError; + if (cause instanceof McpError && ((McpError) cause).getJsonRpcError() != null) { + jsonRpcError = ((McpError) cause).getJsonRpcError(); + } else { + jsonRpcError = new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + cause.getMessage(), McpError.aggregateExceptionMessages(cause)); + } + + return new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.getId(), + null, jsonRpcError); + } else { + return new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + jsonrpcRequest.getId(), result, null); + } + }) .thenCompose(response -> transport.sendMessage(response, null)) .thenCompose(v -> transport.closeGracefully()); } @@ -193,7 +223,8 @@ public CompletableFuture accept(McpSchema.JSONRPCNotification notification ArthasCommandContext commandContext = createCommandContext(transportContext.get(MCP_AUTH_SUBJECT_KEY)); McpSession listeningStream = this.listeningStreamRef.get(); return notificationHandler.handle(new McpNettyServerExchange(this.id, listeningStream, - this.clientCapabilities.get(), this.clientInfo.get(), transportContext), commandContext, notification.getParams()); + this.clientCapabilities.get(), this.clientInfo.get(), transportContext, taskMessageQueue, taskStore), + commandContext, notification.getParams()); } public CompletableFuture accept(McpSchema.JSONRPCResponse response) { diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/ProtocolVersions.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/ProtocolVersions.java index de06695e601..2dedbd2b7f6 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/ProtocolVersions.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/protocol/spec/ProtocolVersions.java @@ -20,4 +20,10 @@ public interface ProtocolVersions { */ String MCP_2025_06_18 = "2025-06-18"; + /** + * MCP protocol version for 2025-11-25. + * https://modelcontextprotocol.io/specification/2025-11-25 + */ + String MCP_2025_11_25 = "2025-11-25"; + } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandContext.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandContext.java index dc3207247cd..503ef8f008b 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandContext.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandContext.java @@ -146,15 +146,12 @@ public Map pullResults() { * Interrupt the current job */ public Map interruptJob() { - requireSessionSupport(); - return commandExecutor.interruptJob(binding.getArthasSessionId()); + if (binding != null) { + return commandExecutor.interruptJob(binding.getArthasSessionId()); + } + return null; } - /** - * Set session userId for statistics reporting - * - * @param userId 用户 ID - */ public void setSessionUserId(String userId) { if (binding != null && userId != null) { commandExecutor.setSessionUserId(binding.getArthasSessionId(), userId); diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandSessionManager.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandSessionManager.java index 924d074649f..400d48d77d6 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandSessionManager.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/session/ArthasCommandSessionManager.java @@ -21,6 +21,9 @@ public class ArthasCommandSessionManager { private final CommandExecutor commandExecutor; private final ConcurrentHashMap sessionBindings = new ConcurrentHashMap<>(); + + // 独立管理 task session + private final ConcurrentHashMap taskSessionBindings = new ConcurrentHashMap<>(); public ArthasCommandSessionManager(CommandExecutor commandExecutor) { this.commandExecutor = commandExecutor; @@ -153,4 +156,42 @@ public void closeCommandSession(String mcpSessionId) { public void closeAllSessions() { sessionBindings.keySet().forEach(this::closeCommandSession); } + + /** + * 为 task 创建独立的 Arthas Session。 + */ + public CommandSessionBinding createIsolatedTaskSession(String taskId) { + Map result = commandExecutor.createSession(); + + CommandSessionBinding binding = new CommandSessionBinding( + "task-" + taskId, // 使用 task ID 作为 MCP session ID + (String) result.get("sessionId"), + (String) result.get("consumerId") + ); + + // 注册到独立的 map,方便追踪和清理 + taskSessionBindings.put(taskId, binding); + + logger.info("Created isolated task session: taskId={}, arthasSessionId={}", + taskId, binding.getArthasSessionId()); + return binding; + } + + public void closeTaskSession(String taskId) { + CommandSessionBinding binding = taskSessionBindings.remove(taskId); + if (binding != null) { + try { + commandExecutor.closeSession(binding.getArthasSessionId()); + logger.info("Closed task session: taskId={}, arthasSessionId={}", + taskId, binding.getArthasSessionId()); + } catch (Exception e) { + logger.warn("Failed to close task session: taskId={}, error={}", + taskId, e.getMessage()); + } + } + } + + public int getActiveTaskSessionCount() { + return taskSessionBindings.size(); + } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/AbstractTaskAwareToolSpecificationBuilder.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/AbstractTaskAwareToolSpecificationBuilder.java new file mode 100644 index 00000000000..c31ff249a2d --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/AbstractTaskAwareToolSpecificationBuilder.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +/** + * Abstract base for task-aware tool specification builders with self-referencing generics. + * + * @param the concrete builder type + * @author Yeaury + */ +public abstract class AbstractTaskAwareToolSpecificationBuilder> { + + protected String name; + protected String description; + protected McpSchema.JsonSchema inputSchema; + protected McpSchema.TaskSupportMode taskSupport = McpSchema.TaskSupportMode.OPTIONAL; + + @SuppressWarnings("unchecked") + protected T self() { + return (T) this; + } + + public T name(String name) { + this.name = name; + return self(); + } + + public T description(String description) { + this.description = description; + return self(); + } + + public T inputSchema(McpSchema.JsonSchema schema) { + this.inputSchema = schema; + return self(); + } + + public T taskSupport(McpSchema.TaskSupportMode mode) { + this.taskSupport = mode; + return self(); + } + + public T taskSupport(String mode) { + if ("optional".equalsIgnoreCase(mode)) { + this.taskSupport = McpSchema.TaskSupportMode.OPTIONAL; + } else if ("required".equalsIgnoreCase(mode)) { + this.taskSupport = McpSchema.TaskSupportMode.REQUIRED; + } else if ("forbidden".equalsIgnoreCase(mode)) { + this.taskSupport = McpSchema.TaskSupportMode.FORBIDDEN; + } else { + throw new IllegalArgumentException("Invalid taskSupport mode: " + mode); + } + return self(); + } + + protected void validateCommonFields() { + if (name == null || name.trim().isEmpty()) { + throw new IllegalArgumentException("Tool name must not be null or empty"); + } + if (inputSchema == null) { + throw new IllegalArgumentException("Input schema must not be null"); + } + } + + protected McpSchema.Tool buildTool() { + return McpSchema.Tool.builder() + .name(name) + .description(description) + .inputSchema(inputSchema) + .execution(new McpSchema.ToolExecution(taskSupport)) + .build(); + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/AbstractTaskHandler.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/AbstractTaskHandler.java new file mode 100644 index 00000000000..44ecd51c2e7 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/AbstractTaskHandler.java @@ -0,0 +1,137 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +/** + * Abstract base for task handlers, managing TaskStore and TaskManager lifecycle. + * + * @param result type stored in TaskStore + * @author Yeaury + */ +public abstract class AbstractTaskHandler implements TaskManagerHost { + + private static final Logger logger = LoggerFactory.getLogger(AbstractTaskHandler.class); + + protected final TaskStore taskStore; + protected final TaskManager taskManager; + protected final TaskHandlerRegistry taskHandlerRegistry = new TaskHandlerRegistry(); + + protected AbstractTaskHandler(TaskStore taskStore, TaskManagerOptions taskOptions) { + this.taskStore = taskStore; + if (taskOptions != null && taskStore != null) { + this.taskManager = taskOptions.createTaskManager(); + this.taskManager.bind(this); + logger.info("TaskManager created: {}", this.taskManager.getClass().getSimpleName()); + } else { + this.taskManager = NullTaskManager.getInstance(); + logger.info("Using NullTaskManager (tasks not configured)"); + } + } + + @Override + public void registerHandler(String method, TaskRequestHandler handler) { + this.taskHandlerRegistry.registerHandler(method, handler); + logger.debug("Registered task handler for method: {}", method); + } + + @Override + public CompletableFuture invokeCustomTaskHandler( + String taskId, String method, McpSchema.Request request, + TaskHandlerContext context, Class resultType) { + + if (this.taskStore == null) { + return CompletableFuture.completedFuture(null); + } + return this.taskStore.getTask(taskId, context.sessionId()) + .thenCompose(storeResult -> { + if (storeResult == null) { + logger.debug("invokeCustomTaskHandler: task not found for taskId={}", taskId); + return CompletableFuture.completedFuture(null); + } + return findAndInvokeCustomHandler(storeResult, method, request, context, resultType); + }) + .exceptionally(ex -> { + logger.debug("invokeCustomTaskHandler: task lookup failed for taskId={}, returning null", + taskId, ex); + return null; + }); + } + + /** Hook for subclasses to find and invoke tool-specific custom handlers. Returns null by default. */ + protected CompletableFuture findAndInvokeCustomHandler( + GetTaskFromStoreResult storeResult, String method, McpSchema.Request request, + TaskHandlerContext context, Class resultType) { + return CompletableFuture.completedFuture(null); + } + + public TaskStore getTaskStore() { + return this.taskStore; + } + + public TaskManager taskManager() { + return this.taskManager; + } + + public void close() { + if (this.taskManager != null) { + try { + this.taskManager.onClose(); + logger.info("TaskManager closed"); + } catch (Exception e) { + logger.error("Error closing TaskManager", e); + } + } + if (this.taskStore != null) { + try { + this.taskStore.shutdown().get(TaskDefaults.TASK_STORE_SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS); + logger.info("TaskStore shutdown completed"); + } catch (Exception e) { + logger.error("Error shutting down TaskStore", e); + } + } + } + + public CompletableFuture closeGracefully() { + if (this.taskManager != null) { + this.taskManager.onClose(); + } + return this.taskStore != null ? this.taskStore.shutdown() : CompletableFuture.completedFuture(null); + } + + // --------------------------------------- + // Handler Context Factory + // --------------------------------------- + + protected static TaskManagerHost.TaskHandlerContext createTaskHandlerContext( + String sessionId, + TriFunction, CompletableFuture> requestSender, + java.util.function.BiFunction> notificationSender) { + return new TaskManagerHost.TaskHandlerContext() { + @Override + public String sessionId() { + return sessionId; + } + + @Override + @SuppressWarnings("unchecked") + public CompletableFuture sendRequest( + String reqMethod, Object reqParams, Class resultType) { + return (CompletableFuture) requestSender.apply(reqMethod, reqParams, resultType); + } + + @Override + public CompletableFuture sendNotification(String notifMethod, Object notification) { + return notificationSender.apply(notifMethod, notification); + } + }; + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/CreateTaskContext.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/CreateTaskContext.java new file mode 100644 index 00000000000..22f31a63eb9 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/CreateTaskContext.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.server.McpNettyServerExchange; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import com.taobao.arthas.mcp.server.session.ArthasCommandContext; +import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager; + +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * Task lifecycle context provided to {@link CreateTaskHandler} implementations. + * + * @author Yeaury + */ +public interface CreateTaskContext { + + McpNettyServerExchange exchange(); + + String sessionId(); + + Long requestTtl(); + + McpSchema.Request originatingRequest(); + + ArthasCommandContext commandContext(); + + CompletableFuture createTask(); + + CompletableFuture createTask(Consumer customizer); + + CompletableFuture completeTask(String taskId, McpSchema.CallToolResult result); + + CompletableFuture failTask(String taskId, McpSchema.CallToolResult errorResult); + + CompletableFuture setInputRequired(String taskId, String message); + + CompletableFuture isCancellationRequested(String taskId); + + ArthasCommandSessionManager sessionManager(); + + ArthasCommandContext createIsolatedTaskSession(String taskId); + + void cleanupTaskSession(String taskId); +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/CreateTaskHandler.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/CreateTaskHandler.java new file mode 100644 index 00000000000..d0af0de4822 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/CreateTaskHandler.java @@ -0,0 +1,24 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Handler for task creation. Implementations start async work and return immediately with a task. + * + * @author Yeaury + */ +@FunctionalInterface +public interface CreateTaskHandler { + + CompletableFuture createTask( + Map args, + CreateTaskContext context + ); +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/CreateTaskOptions.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/CreateTaskOptions.java new file mode 100644 index 00000000000..14d13f262e9 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/CreateTaskOptions.java @@ -0,0 +1,102 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +/** + * Options for creating a task. + * + * @author Yeaury + */ +public class CreateTaskOptions { + + private final String sessionId; + private final String taskId; + private final Long requestedTtl; + private final Long pollInterval; + private final McpSchema.Request originatingRequest; + private final Object context; + + private CreateTaskOptions(Builder builder) { + this.sessionId = builder.sessionId; + this.taskId = builder.taskId; + this.requestedTtl = builder.requestedTtl; + this.pollInterval = builder.pollInterval; + this.originatingRequest = builder.originatingRequest; + this.context = builder.context; + } + + public String sessionId() { + return sessionId; + } + + public String taskId() { + return taskId; + } + + public Long requestedTtl() { + return requestedTtl; + } + + public Long pollInterval() { + return pollInterval; + } + + public McpSchema.Request originatingRequest() { + return originatingRequest; + } + + public Object context() { + return context; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private String sessionId; + private String taskId; + private Long requestedTtl; + private Long pollInterval; + private McpSchema.Request originatingRequest; + private Object context; + + public Builder sessionId(String sessionId) { + this.sessionId = sessionId; + return this; + } + + public Builder taskId(String taskId) { + this.taskId = taskId; + return this; + } + + public Builder ttl(Long ttl) { + this.requestedTtl = ttl; + return this; + } + + public Builder pollInterval(Long pollInterval) { + this.pollInterval = pollInterval; + return this; + } + + public Builder originatingRequest(McpSchema.Request request) { + this.originatingRequest = request; + return this; + } + + public Builder context(Object context) { + this.context = context; + return this; + } + + public CreateTaskOptions build() { + return new CreateTaskOptions(this); + } + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContext.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContext.java new file mode 100644 index 00000000000..075e0669ce5 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultCreateTaskContext.java @@ -0,0 +1,146 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.server.McpNettyServerExchange; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import com.taobao.arthas.mcp.server.session.ArthasCommandContext; +import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager; +import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager.CommandSessionBinding; + +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * Default implementation of {@link CreateTaskContext}. + * + * @author Yeaury + */ +public class DefaultCreateTaskContext implements CreateTaskContext { + + private final TaskStore taskStore; + + private final TaskMessageQueue messageQueue; + + private final McpNettyServerExchange exchange; + + private final String sessionId; + + private final Long requestTtl; + + private final McpSchema.Request originatingRequest; + + private final ArthasCommandContext commandContext; + + private final ArthasCommandSessionManager sessionManager; + + public DefaultCreateTaskContext( + TaskStore taskStore, + TaskMessageQueue messageQueue, + McpNettyServerExchange exchange, + String sessionId, + Long requestTtl, + McpSchema.Request originatingRequest, + ArthasCommandContext commandContext, + ArthasCommandSessionManager sessionManager) { + this.taskStore = taskStore; + this.messageQueue = messageQueue; + this.exchange = exchange; + this.sessionId = sessionId; + this.requestTtl = requestTtl; + this.originatingRequest = originatingRequest; + this.commandContext = commandContext; + this.sessionManager = sessionManager; + } + + @Override + public McpNettyServerExchange exchange() { + return exchange; + } + + @Override + public String sessionId() { + return sessionId; + } + + @Override + public Long requestTtl() { + return requestTtl; + } + + @Override + public McpSchema.Request originatingRequest() { + return originatingRequest; + } + + @Override + public ArthasCommandContext commandContext() { + return commandContext; + } + + @Override + public CompletableFuture createTask() { + return createTask(builder -> {}); + } + + @Override + public CompletableFuture createTask(Consumer customizer) { + CreateTaskOptions.Builder builder = CreateTaskOptions.builder() + .sessionId(sessionId) + .ttl(requestTtl) + .originatingRequest(originatingRequest); + customizer.accept(builder); + return taskStore.createTask(builder.build()); + } + + @Override + public CompletableFuture completeTask(String taskId, McpSchema.CallToolResult result) { + return taskStore.storeTaskResult(taskId, sessionId, McpSchema.TaskStatus.COMPLETED, result); + } + + @Override + public CompletableFuture failTask(String taskId, McpSchema.CallToolResult errorResult) { + return taskStore.storeTaskResult(taskId, sessionId, McpSchema.TaskStatus.FAILED, errorResult); + } + + @Override + public CompletableFuture setInputRequired(String taskId, String message) { + return taskStore.updateTaskStatus(taskId, sessionId, McpSchema.TaskStatus.INPUT_REQUIRED, message); + } + + TaskStore taskStore() { + return taskStore; + } + + TaskMessageQueue taskMessageQueue() { + return messageQueue; + } + + @Override + public CompletableFuture isCancellationRequested(String taskId) { + return taskStore.isCancellationRequested(taskId, sessionId); + } + + @Override + public ArthasCommandSessionManager sessionManager() { + return sessionManager; + } + + @Override + public ArthasCommandContext createIsolatedTaskSession(String taskId) { + if (sessionManager == null) { + throw new IllegalStateException("SessionManager is not available"); + } + CommandSessionBinding binding = sessionManager.createIsolatedTaskSession(taskId); + return new ArthasCommandContext(commandContext.getCommandExecutor(), binding); + } + + @Override + public void cleanupTaskSession(String taskId) { + if (sessionManager != null) { + sessionManager.closeTaskSession(taskId); + } + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultTaskManager.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultTaskManager.java new file mode 100644 index 00000000000..45a8aafef3b --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/DefaultTaskManager.java @@ -0,0 +1,694 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpError; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeoutException; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * Default {@link TaskManager} implementation handling task orchestration and side-channel messaging. + * + * @see TaskManager + * @see NullTaskManager + */ +class DefaultTaskManager implements TaskManager { + + private static final Logger logger = LoggerFactory.getLogger(DefaultTaskManager.class); + + private static final String RELATED_TASK_META_KEY = "relatedTask"; + + private final TaskStore taskStore; + private final TaskMessageQueue messageQueue; + private final Duration defaultPollInterval; + private final Duration pollTimeout; + + /** Pending request resolvers awaiting side-channel responses. */ + private final Map requestResolvers = new ConcurrentHashMap<>(); + + /** Stores response and error handlers for a pending request. */ + private static class RequestResolver { + final Consumer responseHandler; + final Consumer errorHandler; + + RequestResolver(Consumer responseHandler, Consumer errorHandler) { + this.responseHandler = responseHandler; + this.errorHandler = errorHandler; + } + } + + private TaskManagerHost host; + + DefaultTaskManager(TaskManagerOptions options) { + this.taskStore = options.taskStore(); + this.messageQueue = options.messageQueue(); + this.defaultPollInterval = options.defaultPollInterval(); + this.pollTimeout = options.pollTimeout() != null + ? options.pollTimeout() + : Duration.ofMillis(TaskDefaults.DEFAULT_AUTOMATIC_POLLING_TIMEOUT_MS); + } + + @Override + public void bind(TaskManagerHost host) { + this.host = host; + + if (this.taskStore != null) { + host.registerHandler(McpSchema.METHOD_TASKS_GET, this::handleGetTask); + host.registerHandler(McpSchema.METHOD_TASKS_RESULT, this::handleGetTaskResult); + host.registerHandler(McpSchema.METHOD_TASKS_LIST, this::handleListTasks); + host.registerHandler(McpSchema.METHOD_TASKS_CANCEL, this::handleCancelTask); + } + } + + @Override + public InboundRequestResult processInboundRequest(String requestMethod, Object requestParams, + InboundRequestContext ctx) { + String relatedTaskId = extractRelatedTaskId(requestParams); + TaskCreationParams taskCreationParams = extractTaskCreationParams(requestParams); + + Consumer wrappedSendNotification; + if (relatedTaskId != null) { + wrappedSendNotification = notification -> ctx.sendNotification() + .send(notification, NotificationOptions.withRelatedTask(new RelatedTaskInfo(relatedTaskId))) + .exceptionally(ex -> { + logger.warn("Failed to send notification", ex); + return null; + }); + } else { + wrappedSendNotification = notification -> ctx.sendNotification() + .send(notification, NotificationOptions.empty()) + .exceptionally(ex -> { + logger.warn("Failed to send notification", ex); + return null; + }); + } + + RequestSender wrappedSendRequest = getSendRequest(ctx, relatedTaskId); + + Function> routeResponse = response -> { + if (relatedTaskId == null) { + return CompletableFuture.completedFuture(false); + } + return CompletableFuture.completedFuture(false); + }; + + return new InboundRequestResult( + wrappedSendNotification, + wrappedSendRequest, + routeResponse, + taskCreationParams != null + ); + } + + private static RequestSender getSendRequest(InboundRequestContext ctx, String relatedTaskId) { + if (relatedTaskId != null) { + return new RequestSender() { + @Override + public CompletableFuture send(Object request, Class resultType, RequestOptions options) { + RequestOptions augmented = new RequestOptions( + options != null ? options.task() : null, + new RelatedTaskInfo(relatedTaskId)); + return ctx.sendRequest().send(request, resultType, augmented); + } + }; + } else { + return ctx.sendRequest(); + } + } + + @Override + public OutboundRequestResult processOutboundRequest(String requestMethod, Object requestParams, + RequestOptions options, Object messageId, + Consumer responseHandler, + Consumer errorHandler) { + String relatedTaskId = options != null && options.relatedTask() != null + ? options.relatedTask().taskId() + : null; + + if (relatedTaskId != null && this.messageQueue != null) { + this.requestResolvers.put(messageId, new RequestResolver(responseHandler, errorHandler)); + + McpSchema.Request typedRequest = requestParams instanceof McpSchema.Request + ? (McpSchema.Request) requestParams : null; + QueuedMessage.Request queuedRequest = new QueuedMessage.Request(messageId, requestMethod, typedRequest); + this.messageQueue.enqueue(relatedTaskId, queuedRequest) + .exceptionally(ex -> { + errorHandler.accept(ex); + return null; + }); + + return new OutboundRequestResult(true); + } + + return new OutboundRequestResult(false); + } + + @Override + public InboundResponseResult processInboundResponse(Object responseResult, Object messageId) { + RequestResolver resolver = this.requestResolvers.remove(messageId); + if (resolver != null) { + if (responseResult instanceof McpSchema.JSONRPCResponse) { + McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) responseResult; + if (response.getError() != null) { + if (resolver.errorHandler != null) { + resolver.errorHandler.accept(new McpError(response.getError())); + } else { + resolver.responseHandler.accept(new McpError(response.getError())); + } + } else { + resolver.responseHandler.accept(response.getResult()); + } + } else if (responseResult instanceof Throwable) { + if (resolver.errorHandler != null) { + resolver.errorHandler.accept((Throwable) responseResult); + } else { + resolver.responseHandler.accept(responseResult); + } + } else { + resolver.responseHandler.accept(responseResult); + } + return new InboundResponseResult(true); + } + + return new InboundResponseResult(false); + } + + @Override + public CompletableFuture processOutboundNotification( + String notificationMethod, Object notification, NotificationOptions options) { + String relatedTaskId = options != null && options.relatedTask() != null + ? options.relatedTask().taskId() + : null; + + if (relatedTaskId != null && this.messageQueue != null) { + QueuedMessage.Notification queuedNotification = new QueuedMessage.Notification( + notificationMethod, notification); + + return this.messageQueue.enqueue(relatedTaskId, queuedNotification) + .thenApply(v -> new OutboundNotificationResult(true, null)); + } + + return CompletableFuture.completedFuture( + new OutboundNotificationResult(false, notification)); + } + + @Override + public void onClose() { + this.requestResolvers.clear(); + } + + @Override + public Optional> taskStore() { + return Optional.ofNullable(this.taskStore); + } + + @Override + public Optional messageQueue() { + return Optional.ofNullable(this.messageQueue); + } + + @Override + public Duration defaultPollInterval() { + return this.defaultPollInterval; + } + + // Private helpers + + @SuppressWarnings("unchecked") + private String extractRelatedTaskId(Object requestParams) { + if (requestParams == null) { + return null; + } + try { + if (requestParams instanceof Map) { + Map params = (Map) requestParams; + Object meta = params.get("_meta"); + if (meta instanceof Map) { + Map metaMap = (Map) meta; + Object relatedTask = metaMap.get(RELATED_TASK_META_KEY); + if (relatedTask instanceof Map) { + Map relatedTaskMap = (Map) relatedTask; + return (String) relatedTaskMap.get("taskId"); + } + } + } + } catch (ClassCastException e) { + logger.debug("Failed to extract related task ID: {}", e.getMessage()); + } + return null; + } + + @SuppressWarnings("unchecked") + private TaskCreationParams extractTaskCreationParams(Object requestParams) { + if (requestParams == null) { + return null; + } + try { + if (requestParams instanceof Map) { + Map params = (Map) requestParams; + Object task = params.get("task"); + if (task instanceof Map) { + Map taskMap = (Map) task; + Long ttl = taskMap.containsKey("ttl") + ? ((Number) taskMap.get("ttl")).longValue() + : null; + return new TaskCreationParams(ttl); + } + } + } catch (ClassCastException e) { + logger.debug("Failed to extract task creation params: {}", e.getMessage()); + } + return null; + } + + // Handler implementations + + private CompletableFuture handleGetTask(String requestMethod, Object requestParams, + TaskManagerHost.TaskHandlerContext ctx) { + if (this.taskStore == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("TaskStore not configured") + .build()); + return failed; + } + + String taskId = extractTaskIdFromParams(requestParams); + if (taskId == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Missing required parameter: taskId") + .build()); + return failed; + } + + McpSchema.GetTaskRequest typedRequest = new McpSchema.GetTaskRequest(taskId, null); + + return host.invokeCustomTaskHandler(taskId, McpSchema.METHOD_TASKS_GET, typedRequest, ctx, + McpSchema.GetTaskResult.class) + .thenCompose(result -> { + if (result != null) { + return CompletableFuture.completedFuture(result); + } + return this.taskStore.getTask(taskId, ctx.sessionId()) + .thenCompose(storeResult -> { + if (storeResult == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL)") + .data("Task ID: " + taskId) + .build()); + return failed; + } + return CompletableFuture.completedFuture( + (McpSchema.Result) McpSchema.GetTaskResult.fromTask(storeResult.task())); + }); + }); + } + + private CompletableFuture handleGetTaskResult(String requestMethod, Object requestParams, + TaskManagerHost.TaskHandlerContext ctx) { + if (this.taskStore == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("TaskStore not configured") + .build()); + return failed; + } + + String taskId = extractTaskIdFromParams(requestParams); + if (taskId == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Missing required parameter: taskId") + .build()); + return failed; + } + + String sessionId = ctx.sessionId(); + + McpSchema.GetTaskPayloadRequest typedRequest = new McpSchema.GetTaskPayloadRequest(taskId, null); + + return this.taskStore.getTask(taskId, sessionId) + .thenCompose(storeResult -> { + if (storeResult == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL)") + .data("Task ID: " + taskId) + .build()); + return failed; + } + + McpSchema.Task task = storeResult.task(); + + logger.debug("handleGetTaskResult: Task {} status={}, messageQueue={}", + taskId, task.getStatus(), + this.messageQueue != null ? "present" : "null"); + + // Handle INPUT_REQUIRED: process queued side-channel messages first + if (task.getStatus() == McpSchema.TaskStatus.INPUT_REQUIRED && this.messageQueue != null) { + logger.debug("handleGetTaskResult: Task {} is INPUT_REQUIRED, starting side-channel processing", + taskId); + return processQueuedMessagesAndWaitForTerminal(ctx, taskId, sessionId) + .thenCompose(sideChannelResult -> { + return tryCustomHandlerOrDefault(taskId, typedRequest, ctx, sessionId); + }); + } + + return tryCustomHandlerOrDefault(taskId, typedRequest, ctx, sessionId); + }); + } + + /** Tries the custom tasks/result handler, falling back to default store lookup. */ + private CompletableFuture tryCustomHandlerOrDefault( + String taskId, McpSchema.GetTaskPayloadRequest typedRequest, + TaskManagerHost.TaskHandlerContext ctx, String sessionId) { + + return host.invokeCustomTaskHandler(taskId, McpSchema.METHOD_TASKS_RESULT, typedRequest, ctx, + McpSchema.ServerTaskPayloadResult.class) + .thenCompose(result -> { + if (result != null) { + return CompletableFuture.completedFuture(result); + } + return defaultGetTaskResult(taskId, sessionId); + }); + } + + /** Default tasks/result implementation using the TaskStore. */ + private CompletableFuture defaultGetTaskResult(String taskId, String sessionId) { + return this.taskStore.getTask(taskId, sessionId) + .thenCompose(storeResult -> { + if (storeResult == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found") + .build()); + return failed; + } + + McpSchema.Task task = storeResult.task(); + + if (task.isTerminal()) { + logger.debug("defaultGetTaskResult: Task {} is terminal, fetching result", taskId); + return fetchTaskResult(taskId, sessionId); + } + + return watchAndFetchResult(taskId, sessionId); + }); + } + + /** Fetches the result of a terminal task. */ + @SuppressWarnings("unchecked") + private CompletableFuture fetchTaskResult(String taskId, String sessionId) { + // Re-fetch the task to get its current status for fallback construction. + return this.taskStore.getTask(taskId, sessionId).thenCompose(storeResult -> { + final McpSchema.Task task = storeResult != null ? storeResult.task() : null; + TaskStore store = (TaskStore) this.taskStore; + return store.getTaskResult(taskId, sessionId) + .thenApply(result -> { + if (result != null) { + return result; + } + // CANCELLED tasks never store a payload — construct a semantic response. + if (task != null && task.getStatus() == McpSchema.TaskStatus.CANCELLED) { + String msg = "Task was cancelled" + + (task.getStatusMessage() != null ? ": " + task.getStatusMessage() : ""); + return (McpSchema.Result) new McpSchema.CallToolResult(msg, true, null); + } + // Should not reach here for FAILED tasks (payload stored by failTask). + throw new RuntimeException("Task result not found"); + }); + }); + } + + /** Watches a task until terminal, then fetches its result. */ + private CompletableFuture watchAndFetchResult(String taskId, String sessionId) { + long timeoutMs = this.pollTimeout.toMillis(); + return this.taskStore.watchTaskUntilTerminal(taskId, sessionId, timeoutMs) + .thenCompose(updates -> { + if (updates == null || updates.isEmpty()) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task did not complete within timeout") + .data("Task ID: " + taskId) + .build()); + return failed; + } + McpSchema.Task terminalTask = updates.get(updates.size() - 1); + if (!terminalTask.isTerminal()) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task did not complete within timeout") + .data("Task ID: " + taskId) + .build()); + return failed; + } + return fetchTaskResult(taskId, sessionId); + }) + .exceptionally(ex -> { + if (ex instanceof TimeoutException || ex.getCause() instanceof TimeoutException) { + throw new RuntimeException( + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task did not complete within timeout") + .data("Task ID: " + taskId) + .build()); + } + throw new RuntimeException(ex); + }); + } + + /** + * Processes all queued side-channel messages for an INPUT_REQUIRED task, then waits for terminal state. + */ + private CompletableFuture processQueuedMessagesAndWaitForTerminal( + TaskManagerHost.TaskHandlerContext ctx, String taskId, String sessionId) { + logger.debug("processQueuedMessagesAndWaitForTerminal: Starting side-channel processing for task {}", taskId); + + return processAllQueuedMessages(ctx, taskId) + .thenCompose(v -> { + logger.debug("processQueuedMessagesAndWaitForTerminal: Finished processing queue for task {}", + taskId); + return pollAndProcessUntilTerminal(ctx, taskId, sessionId); + }); + } + + /** Dequeues and processes all actionable messages for a task. */ + private CompletableFuture processAllQueuedMessages(TaskManagerHost.TaskHandlerContext ctx, String taskId) { + return this.messageQueue.dequeueAll(taskId) + .thenCompose(messages -> { + CompletableFuture allProcessed = CompletableFuture.completedFuture(null); + for (QueuedMessage msg : messages) { + allProcessed = allProcessed.thenCompose(v -> processMessage(ctx, msg, taskId)); + } + return allProcessed; + }); + } + + /** Dispatches a single queued message to the client. */ + private CompletableFuture processMessage(TaskManagerHost.TaskHandlerContext ctx, QueuedMessage msg, + String taskId) { + if (msg instanceof QueuedMessage.Request) { + QueuedMessage.Request req = (QueuedMessage.Request) msg; + return sendRequestAndEnqueueResponse(ctx, req, taskId); + } + + if (msg instanceof QueuedMessage.Notification) { + QueuedMessage.Notification notif = (QueuedMessage.Notification) msg; + return sendNotificationToClient(ctx, notif, taskId); + } + + return CompletableFuture.completedFuture(null); + } + + /** Sends a request to the client and enqueues the response for waitForResponse() to retrieve. */ + private CompletableFuture sendRequestAndEnqueueResponse(TaskManagerHost.TaskHandlerContext ctx, + QueuedMessage.Request req, String taskId) { + String requestId = String.valueOf(req.requestId()); + + logger.debug("sendRequestAndEnqueueResponse: Sending {} request {} to client for task {}", + req.method(), requestId, taskId); + + Class resultClass = getResultClass(req.method()); + + return ctx.sendRequest(req.method(), req.request(), resultClass) + .thenCompose(result -> { + logger.debug("sendRequestAndEnqueueResponse: Got response for request {}, enqueueing for task {}", + requestId, taskId); + QueuedMessage.Response response = new QueuedMessage.Response(requestId, result); + return this.messageQueue.enqueue(taskId, response); + }); + } + + /** Returns the result class for a known side-channel method. */ + private Class getResultClass(String method) { + if (McpSchema.METHOD_ELICITATION_CREATE.equals(method)) { + return McpSchema.ElicitResult.class; + } else if (McpSchema.METHOD_SAMPLING_CREATE_MESSAGE.equals(method)) { + return McpSchema.CreateMessageResult.class; + } else { + throw new IllegalArgumentException("Unsupported side-channel method: " + method); + } + } + + /** Sends a notification to the client without waiting for a response. */ + private CompletableFuture sendNotificationToClient(TaskManagerHost.TaskHandlerContext ctx, + QueuedMessage.Notification notif, String taskId) { + Object notification = TaskMetadataUtils.addRelatedTaskMetadata(taskId, notif.notification()); + return ctx.sendNotification(notif.method(), notification); + } + + /** Polls and processes messages until the task reaches a terminal state. */ + private CompletableFuture pollAndProcessUntilTerminal( + TaskManagerHost.TaskHandlerContext ctx, String taskId, String sessionId) { + + CompletableFuture pollingFuture = doPollAndProcess(ctx, taskId, sessionId); + + CompletableFuture timeoutFuture = delay(this.pollTimeout.toMillis()) + .thenApply(v -> { + throw new java.util.concurrent.CompletionException( + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task did not complete within timeout") + .data("Task ID: " + taskId) + .build()); + }); + + return CompletableFuture.anyOf(pollingFuture, timeoutFuture) + .thenApply(obj -> (McpSchema.Result) obj); + } + + /** Recursive poll-and-process loop. */ + private CompletableFuture doPollAndProcess( + TaskManagerHost.TaskHandlerContext ctx, String taskId, String sessionId) { + return this.taskStore.getTask(taskId, sessionId) + .thenCompose(storeResult -> { + if (storeResult == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found during polling") + .data("Task ID: " + taskId) + .build()); + return failed; + } + + McpSchema.Task task = storeResult.task(); + + if (task.isTerminal()) { + return fetchTaskResult(taskId, sessionId); + } + + long interval = task.getPollInterval() != null + ? task.getPollInterval() + : this.defaultPollInterval.toMillis(); + + if (task.getStatus() == McpSchema.TaskStatus.INPUT_REQUIRED) { + return processAllQueuedMessages(ctx, taskId) + .thenCompose(v -> delay(interval)) + .thenCompose(ignored -> doPollAndProcess(ctx, taskId, sessionId)); + } + + return delay(interval) + .thenCompose(ignored -> doPollAndProcess(ctx, taskId, sessionId)); + }); + } + + private CompletableFuture handleListTasks(String requestMethod, Object requestParams, + TaskManagerHost.TaskHandlerContext ctx) { + if (this.taskStore == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("TaskStore not configured") + .build()); + return failed; + } + + String cursor = extractCursorFromParams(requestParams); + + return this.taskStore.listTasks(cursor, ctx.sessionId()) + .thenApply(result -> (McpSchema.Result) result); + } + + private CompletableFuture handleCancelTask(String requestMethod, Object requestParams, + TaskManagerHost.TaskHandlerContext ctx) { + if (this.taskStore == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("TaskStore not configured") + .build()); + return failed; + } + + String taskId = extractTaskIdFromParams(requestParams); + if (taskId == null) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Missing required parameter: taskId") + .build()); + return failed; + } + + return this.taskStore.requestCancellation(taskId, ctx.sessionId()) + .thenApply(task -> { + if (task == null) { + throw new CompletionException( + McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found or not accessible") + .data("Task ID: " + taskId) + .build()); + } + return (McpSchema.Result) McpSchema.CancelTaskResult.fromTask(task); + }); + } + + private String extractTaskIdFromParams(Object params) { + return extractStringFromParams(params, "taskId"); + } + + private String extractCursorFromParams(Object params) { + return extractStringFromParams(params, "cursor"); + } + + @SuppressWarnings("unchecked") + private String extractStringFromParams(Object params, String key) { + if (params == null) { + return null; + } + try { + if (params instanceof Map) { + Map paramsMap = (Map) params; + return (String) paramsMap.get(key); + } + } catch (ClassCastException e) { + logger.debug("Failed to extract {} from params: {}", key, e.getMessage()); + } + return null; + } + + /** Shared scheduler to avoid creating a new thread pool per delay() call. */ + private static final java.util.concurrent.ScheduledExecutorService DELAY_SCHEDULER = + java.util.concurrent.Executors.newSingleThreadScheduledExecutor(r -> { + Thread t = new Thread(r, "task-manager-delay"); + t.setDaemon(true); + return t; + }); + + /** Delays for the given number of milliseconds. */ + private CompletableFuture delay(long milliseconds) { + CompletableFuture future = new CompletableFuture<>(); + DELAY_SCHEDULER.schedule(() -> future.complete(null), milliseconds, java.util.concurrent.TimeUnit.MILLISECONDS); + return future; + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/GetTaskFromStoreResult.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/GetTaskFromStoreResult.java new file mode 100644 index 00000000000..10671d25256 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/GetTaskFromStoreResult.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +/** + * Result of a task lookup from {@link TaskStore}, including the originating request. + * + * @author Yeaury + */ +public class GetTaskFromStoreResult { + + private final McpSchema.Task task; + private final McpSchema.Request originatingRequest; + + public GetTaskFromStoreResult(McpSchema.Task task, McpSchema.Request originatingRequest) { + this.task = task; + this.originatingRequest = originatingRequest; + } + + public McpSchema.Task task() { + return task; + } + + public McpSchema.Request originatingRequest() { + return originatingRequest; + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/GetTaskHandler.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/GetTaskHandler.java new file mode 100644 index 00000000000..bd236c10050 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/GetTaskHandler.java @@ -0,0 +1,24 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.server.McpNettyServerExchange; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +import java.util.concurrent.CompletableFuture; + +/** + * Optional custom handler for {@code tasks/get} requests. + * + * @author Yeaury + */ +@FunctionalInterface +public interface GetTaskHandler { + + CompletableFuture handle( + McpNettyServerExchange exchange, + McpSchema.GetTaskRequest request + ); +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/GetTaskResultHandler.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/GetTaskResultHandler.java new file mode 100644 index 00000000000..d35f827b595 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/GetTaskResultHandler.java @@ -0,0 +1,24 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.server.McpNettyServerExchange; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +import java.util.concurrent.CompletableFuture; + +/** + * Optional custom handler for {@code tasks/result} requests. + * + * @author Yeaury + */ +@FunctionalInterface +public interface GetTaskResultHandler { + + CompletableFuture handle( + McpNettyServerExchange exchange, + McpSchema.GetTaskPayloadRequest request + ); +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/InMemoryTaskMessageQueue.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/InMemoryTaskMessageQueue.java new file mode 100644 index 00000000000..cf16f64be67 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/InMemoryTaskMessageQueue.java @@ -0,0 +1,176 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.*; + +/** + * In-memory {@link TaskMessageQueue} implementation. + * + *

Uses two separate queues per task: actionable (Request/Notification, returned by + * dequeue/dequeueAll) and response (Response, retrieved via waitForResponse only). + * + * @author Yeaury + */ +public class InMemoryTaskMessageQueue implements TaskMessageQueue { + + private static final Logger logger = LoggerFactory.getLogger(InMemoryTaskMessageQueue.class); + + private final ConcurrentHashMap> actionableQueues + = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap> responseQueues + = new ConcurrentHashMap<>(); + + @Override + public CompletableFuture enqueue(String taskId, QueuedMessage message) { + return CompletableFuture.runAsync(() -> { + if (message instanceof QueuedMessage.Response) { + QueuedMessage.Response response = (QueuedMessage.Response) message; + responseQueues.computeIfAbsent(taskId, k -> new ConcurrentLinkedQueue<>()) + .offer(response); + logger.debug("Enqueued response for task {} (requestId: {})", taskId, response.requestId()); + } else { + actionableQueues.computeIfAbsent(taskId, k -> new ConcurrentLinkedQueue<>()) + .offer(message); + logger.debug("Enqueued {} for task {}", message.getClass().getSimpleName(), taskId); + } + }); + } + + @Override + public CompletableFuture dequeue(String taskId) { + return CompletableFuture.supplyAsync(() -> { + ConcurrentLinkedQueue queue = actionableQueues.get(taskId); + if (queue == null || queue.isEmpty()) { + return null; + } + QueuedMessage msg = queue.poll(); + if (msg != null) { + logger.debug("Dequeued {} for task {}", msg.getClass().getSimpleName(), taskId); + } + return msg; + }); + } + + @Override + public CompletableFuture> dequeueAll(String taskId) { + return CompletableFuture.supplyAsync(() -> { + ConcurrentLinkedQueue queue = actionableQueues.get(taskId); + if (queue == null || queue.isEmpty()) { + return Collections.emptyList(); + } + List messages = new ArrayList<>(); + QueuedMessage msg; + while ((msg = queue.poll()) != null) { + messages.add(msg); + } + if (!messages.isEmpty()) { + logger.debug("Dequeued {} messages for task {}", messages.size(), taskId); + } + return messages; + }); + } + + @Override + public CompletableFuture waitForResponse(String taskId, Object requestId, Duration timeout) { + return CompletableFuture.supplyAsync(() -> { + long startTime = System.currentTimeMillis(); + long timeoutMs = timeout.toMillis(); + long pollInterval = TaskDefaults.RESPONSE_POLL_INTERVAL_MS; + + logger.debug("waitForResponse: Waiting for response to request {} for task {} (timeout: {}ms)", + requestId, taskId, timeoutMs); + + while (System.currentTimeMillis() - startTime < timeoutMs) { + ConcurrentLinkedQueue queue = responseQueues.get(taskId); + if (queue != null) { + for (QueuedMessage.Response response : queue) { + if (requestId.equals(response.requestId())) { + queue.remove(response); + logger.debug("waitForResponse: Found response for request {} in task {}", + requestId, taskId); + return response; + } + } + } + try { + Thread.sleep(pollInterval); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new CompletionException("Interrupted while waiting for response", e); + } + } + + logger.warn("waitForResponse: Timeout waiting for response to request {} for task {}", + requestId, taskId); + throw new CompletionException(new TimeoutException( + "Timeout waiting for response to request " + requestId + " for task " + taskId)); + }); + } + + @Override + public CompletableFuture clearTask(String taskId) { + return CompletableFuture.runAsync(() -> { + ConcurrentLinkedQueue actionableQueue = actionableQueues.remove(taskId); + ConcurrentLinkedQueue responseQueue = responseQueues.remove(taskId); + int totalCleared = 0; + if (actionableQueue != null) totalCleared += actionableQueue.size(); + if (responseQueue != null) totalCleared += responseQueue.size(); + if (totalCleared > 0) { + logger.debug("Cleared {} messages for task {}", totalCleared, taskId); + } + }); + } + + @Override + public CompletableFuture getQueueSize(String taskId) { + return CompletableFuture.supplyAsync(() -> { + int size = 0; + ConcurrentLinkedQueue actionableQueue = actionableQueues.get(taskId); + if (actionableQueue != null) size += actionableQueue.size(); + ConcurrentLinkedQueue responseQueue = responseQueues.get(taskId); + if (responseQueue != null) size += responseQueue.size(); + return size; + }); + } + + @Override + public CompletableFuture shutdown() { + return CompletableFuture.runAsync(() -> { + int totalMessages = actionableQueues.values().stream().mapToInt(ConcurrentLinkedQueue::size).sum(); + totalMessages += responseQueues.values().stream().mapToInt(ConcurrentLinkedQueue::size).sum(); + actionableQueues.clear(); + responseQueues.clear(); + logger.info("TaskMessageQueue shut down (cleared {} messages)", totalMessages); + }); + } + + /** Returns the actionable message count for a task (for testing/monitoring). */ + public int getActionableMessageCount(String taskId) { + ConcurrentLinkedQueue queue = actionableQueues.get(taskId); + return queue != null ? queue.size() : 0; + } + + /** Returns the response message count for a task (for testing/monitoring). */ + public int getResponseMessageCount(String taskId) { + ConcurrentLinkedQueue queue = responseQueues.get(taskId); + return queue != null ? queue.size() : 0; + } + + /** Returns the total message count across all tasks (for testing/monitoring). */ + public int getTotalMessageCount() { + int total = actionableQueues.values().stream().mapToInt(ConcurrentLinkedQueue::size).sum(); + total += responseQueues.values().stream().mapToInt(ConcurrentLinkedQueue::size).sum(); + return total; + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/InMemoryTaskStore.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/InMemoryTaskStore.java new file mode 100644 index 00000000000..5b40462064f --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/InMemoryTaskStore.java @@ -0,0 +1,518 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpError; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.time.Instant; +import java.util.*; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +/** + * In-memory {@link TaskStore} implementation with TTL-based cleanup. + * + *

Uses {@link ConcurrentSkipListMap} for O(log n) sorted access and efficient + * cursor-based pagination via {@code tailMap()}. + * + * @param result type + * @author Yeaury + */ +public class InMemoryTaskStore implements TaskStore { + + private static final Logger logger = LoggerFactory.getLogger(InMemoryTaskStore.class); + + private static final long DEFAULT_TTL_MS = TaskDefaults.DEFAULT_TTL_MS; + + private static final long DEFAULT_POLL_INTERVAL_MS = TaskDefaults.DEFAULT_POLL_INTERVAL_MS; + + private static final int DEFAULT_PAGE_SIZE = TaskDefaults.DEFAULT_PAGE_SIZE; + + private static final int DEFAULT_MAX_TASKS = TaskDefaults.DEFAULT_MAX_TASKS; + + private final NavigableMap tasks = new ConcurrentSkipListMap<>(); + + private final Map results = new ConcurrentHashMap<>(); + + private final Set cancellationRequests = ConcurrentHashMap.newKeySet(); + + private final ScheduledExecutorService cleanupExecutor; + + private final long defaultTtl; + + private final long defaultPollInterval; + + private static final AtomicLong INSTANCE_COUNTER = new AtomicLong(0); + + private final long instanceId; + + private final TaskMessageQueue messageQueue; + + private final int maxTasks; + + public InMemoryTaskStore() { + this(DEFAULT_TTL_MS, DEFAULT_POLL_INTERVAL_MS, null, DEFAULT_MAX_TASKS); + } + + public InMemoryTaskStore(long defaultTtl, long defaultPollInterval) { + this(defaultTtl, defaultPollInterval, null, DEFAULT_MAX_TASKS); + } + + public InMemoryTaskStore(long defaultTtl, long defaultPollInterval, TaskMessageQueue messageQueue) { + this(defaultTtl, defaultPollInterval, messageQueue, DEFAULT_MAX_TASKS); + } + + public InMemoryTaskStore(long defaultTtl, long defaultPollInterval, + TaskMessageQueue messageQueue, int maxTasks) { + if (defaultTtl <= 0) throw new IllegalArgumentException("defaultTtl must be positive"); + if (defaultPollInterval <= 0) throw new IllegalArgumentException("defaultPollInterval must be positive"); + if (maxTasks <= 0) throw new IllegalArgumentException("maxTasks must be positive"); + + this.instanceId = INSTANCE_COUNTER.incrementAndGet(); + this.defaultTtl = defaultTtl; + this.defaultPollInterval = defaultPollInterval; + this.messageQueue = messageQueue; + this.maxTasks = maxTasks; + + this.cleanupExecutor = Executors.newSingleThreadScheduledExecutor(r -> { + Thread t = new Thread(r, "mcp-task-cleanup-" + instanceId); + t.setDaemon(true); + return t; + }); + this.cleanupExecutor.scheduleAtFixedRate( + this::cleanupExpiredTasks, 1, 1, TimeUnit.MINUTES); + } + + public static Builder builder() { + return new Builder<>(); + } + + public static class Builder { + + private long defaultTtl = DEFAULT_TTL_MS; + private long defaultPollInterval = DEFAULT_POLL_INTERVAL_MS; + private TaskMessageQueue messageQueue = null; + private int maxTasks = DEFAULT_MAX_TASKS; + + public Builder defaultTtl(Duration ttl) { + this.defaultTtl = ttl.toMillis(); + return this; + } + + public Builder defaultTtlMs(long ttlMs) { + this.defaultTtl = ttlMs; + return this; + } + + public Builder defaultPollInterval(Duration interval) { + this.defaultPollInterval = interval.toMillis(); + return this; + } + + public Builder defaultPollIntervalMs(long intervalMs) { + this.defaultPollInterval = intervalMs; + return this; + } + + public Builder messageQueue(TaskMessageQueue queue) { + this.messageQueue = queue; + return this; + } + + public Builder maxTasks(int max) { + this.maxTasks = max; + return this; + } + + public InMemoryTaskStore build() { + return new InMemoryTaskStore<>(defaultTtl, defaultPollInterval, messageQueue, maxTasks); + } + } + + private final Object createTaskLock = new Object(); + + private boolean isSessionValid(TaskEntry entry, String requestSessionId) { + if (requestSessionId == null) return true; + String taskSessionId = entry.sessionId(); + if (taskSessionId == null || taskSessionId.isEmpty()) return true; + return requestSessionId.equals(taskSessionId); + } + + @Override + public CompletableFuture createTask(CreateTaskOptions options) { + return CompletableFuture.supplyAsync(() -> { + synchronized (createTaskLock) { + if (tasks.size() >= maxTasks) { + throw new CompletionException( + McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Maximum task limit reached (" + maxTasks + ")") + .build() + ); + } + + String taskId = options.taskId() != null ? options.taskId() : UUID.randomUUID().toString(); + String now = Instant.now().toString(); + Long ttl = options.requestedTtl() != null ? options.requestedTtl() : defaultTtl; + Long pollInterval = options.pollInterval() != null ? options.pollInterval() : defaultPollInterval; + String sessionId = options.sessionId(); + + McpSchema.Task task = McpSchema.Task.builder() + .taskId(taskId) + .status(McpSchema.TaskStatus.WORKING) + .createdAt(now) + .lastUpdatedAt(now) + .ttl(ttl) + .pollInterval(pollInterval) + .build(); + + tasks.put(taskId, new TaskEntry(task, options.originatingRequest(), + options.context(), sessionId)); + logger.info("createTask: Created task - taskId: {}, sessionId: {}", taskId, sessionId); + return task; + } + }); + } + + @Override + public CompletableFuture getTask(String taskId, String sessionId) { + return CompletableFuture.supplyAsync(() -> { + TaskEntry entry = tasks.get(taskId); + if (entry == null) { + logger.debug("getTask: Task not found - taskId: {}", taskId); + return null; + } + if (!isSessionValid(entry, sessionId)) { + logger.warn("getTask: Session validation failed - taskId: {}", taskId); + return null; + } + return new GetTaskFromStoreResult(entry.task(), entry.originatingRequest()); + }); + } + + @Override + public CompletableFuture updateTaskStatus(String taskId, String sessionId, + McpSchema.TaskStatus status, String statusMessage) { + return CompletableFuture.runAsync(() -> { + tasks.computeIfPresent(taskId, (id, entry) -> { + if (!isSessionValid(entry, sessionId)) return entry; + McpSchema.Task oldTask = entry.task(); + if (TaskHelper.isTerminal(oldTask.getStatus())) return entry; + String now = Instant.now().toString(); + McpSchema.Task newTask = McpSchema.Task.builder() + .taskId(oldTask.getTaskId()) + .status(status) + .statusMessage(statusMessage) + .createdAt(oldTask.getCreatedAt()) + .lastUpdatedAt(now) + .ttl(oldTask.getTtl()) + .pollInterval(oldTask.getPollInterval()) + .build(); + logger.debug("Updated task {} status: {} -> {}", taskId, oldTask.getStatus(), status); + return new TaskEntry(newTask, entry.originatingRequest(), entry.context(), entry.sessionId()); + }); + }); + } + + @Override + public CompletableFuture storeTaskResult(String taskId, String sessionId, + McpSchema.TaskStatus status, R result) { + return CompletableFuture.runAsync(() -> { + AtomicBoolean taskFound = new AtomicBoolean(false); + AtomicBoolean sessionValid = new AtomicBoolean(true); + AtomicBoolean wasTerminal = new AtomicBoolean(false); + + tasks.computeIfPresent(taskId, (id, entry) -> { + taskFound.set(true); + if (!isSessionValid(entry, sessionId)) { + sessionValid.set(false); + return entry; + } + McpSchema.Task oldTask = entry.task(); + if (TaskHelper.isTerminal(oldTask.getStatus())) { + wasTerminal.set(true); + return entry; + } + results.put(taskId, result); + String now = Instant.now().toString(); + McpSchema.Task newTask = McpSchema.Task.builder() + .taskId(oldTask.getTaskId()) + .status(status) + .createdAt(oldTask.getCreatedAt()) + .lastUpdatedAt(now) + .ttl(oldTask.getTtl()) + .pollInterval(oldTask.getPollInterval()) + .build(); + logger.debug("Stored result for task: {}", taskId); + return new TaskEntry(newTask, entry.originatingRequest(), entry.context(), entry.sessionId()); + }); + + if (!taskFound.get()) { + throw new CompletionException( + McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL): " + taskId) + .build() + ); + } + if (!sessionValid.get()) { + throw new CompletionException( + McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL): " + taskId) + .build() + ); + } + if (wasTerminal.get()) { + logger.debug("Skipped storing result for task {} - already in terminal state", taskId); + } + }); + } + + @Override + public CompletableFuture getTaskResult(String taskId, String sessionId) { + return CompletableFuture.supplyAsync(() -> { + TaskEntry entry = tasks.get(taskId); + if (entry == null || !isSessionValid(entry, sessionId)) return null; + return results.get(taskId); + }); + } + + @Override + public CompletableFuture listTasks(String cursor, String sessionId) { + return CompletableFuture.supplyAsync(() -> { + List taskList = new ArrayList<>(); + String nextCursor = null; + + // Use tailMap for O(log n) cursor lookup; handles missing cursors gracefully + NavigableMap view = cursor != null + ? tasks.tailMap(cursor, false) + : tasks; + + Iterator> iterator = view.entrySet().iterator(); + int count = 0; + String lastKey = null; + + while (iterator.hasNext() && count < DEFAULT_PAGE_SIZE) { + Map.Entry entry = iterator.next(); + TaskEntry taskEntry = entry.getValue(); + if (sessionId != null && !sessionId.equals(taskEntry.sessionId())) continue; + taskList.add(taskEntry.task()); + lastKey = entry.getKey(); + count++; + } + + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (sessionId == null || sessionId.equals(entry.getValue().sessionId())) { + nextCursor = lastKey; + break; + } + } + + return new McpSchema.ListTasksResult(taskList, nextCursor); + }); + } + + @Override + public CompletableFuture requestCancellation(String taskId, String sessionId) { + return CompletableFuture.supplyAsync(() -> { + AtomicReference resultRef = new AtomicReference<>(); + AtomicReference terminalStatusRef = new AtomicReference<>(); + AtomicBoolean sessionValid = new AtomicBoolean(true); + + tasks.computeIfPresent(taskId, (id, entry) -> { + if (!isSessionValid(entry, sessionId)) { + sessionValid.set(false); + return entry; + } + McpSchema.Task oldTask = entry.task(); + if (TaskHelper.isTerminal(oldTask.getStatus())) { + terminalStatusRef.set(oldTask.getStatus()); + resultRef.set(oldTask); + return entry; + } + cancellationRequests.add(taskId); + String now = Instant.now().toString(); + McpSchema.Task newTask = McpSchema.Task.builder() + .taskId(oldTask.getTaskId()) + .status(McpSchema.TaskStatus.CANCELLED) + .statusMessage("Cancellation requested") + .createdAt(oldTask.getCreatedAt()) + .lastUpdatedAt(now) + .ttl(oldTask.getTtl()) + .pollInterval(oldTask.getPollInterval()) + .build(); + resultRef.set(newTask); + logger.info("Cancelled task: {}", taskId); + return new TaskEntry(newTask, entry.originatingRequest(), entry.context(), entry.sessionId()); + }); + + if (!sessionValid.get()) return null; + + McpSchema.TaskStatus terminalStatus = terminalStatusRef.get(); + if (terminalStatus != null) { + throw new CompletionException( + McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Cannot cancel task: already in terminal status '" + terminalStatus + "'") + .data("taskId: " + taskId) + .build() + ); + } + + return resultRef.get(); + }); + } + + @Override + public CompletableFuture isCancellationRequested(String taskId, String sessionId) { + return CompletableFuture.supplyAsync(() -> { + TaskEntry entry = tasks.get(taskId); + if (entry == null || !isSessionValid(entry, sessionId)) return false; + return cancellationRequests.contains(taskId); + }); + } + + @Override + public CompletableFuture> watchTaskUntilTerminal( + String taskId, String sessionId, long timeoutMs) { + + CompletableFuture> future = new CompletableFuture<>(); + List updates = new CopyOnWriteArrayList<>(); + + return getTask(taskId, sessionId).thenCompose(initialResult -> { + if (initialResult == null) { + CompletableFuture> failed = new CompletableFuture<>(); + failed.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("Task not found (may have expired after TTL): " + taskId) + .build()); + return failed; + } + + McpSchema.Task initialTask = initialResult.task(); + long pollInterval = initialTask.getPollInterval() != null + ? initialTask.getPollInterval() + : DEFAULT_POLL_INTERVAL_MS; + + ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor(r -> { + Thread t = new Thread(r, "task-watch-" + taskId); + t.setDaemon(true); + return t; + }); + + long startTime = System.currentTimeMillis(); + + ScheduledFuture pollingTask = scheduler.scheduleAtFixedRate(() -> { + try { + if (System.currentTimeMillis() - startTime > timeoutMs) { + future.completeExceptionally(new TimeoutException("Task watch timeout")); + scheduler.shutdown(); + return; + } + getTask(taskId, sessionId).whenComplete((result, ex) -> { + if (ex != null) { + future.completeExceptionally(ex); + scheduler.shutdown(); + return; + } + if (result != null) { + McpSchema.Task task = result.task(); + if (updates.size() < TaskDefaults.MAX_WATCH_UPDATES) { + updates.add(task); + } else { + updates.remove(0); + updates.add(task); + } + if (TaskHelper.isTerminal(task.getStatus())) { + future.complete(updates); + scheduler.shutdown(); + } + } + }); + } catch (Exception e) { + future.completeExceptionally(e); + scheduler.shutdown(); + } + }, 0, pollInterval, TimeUnit.MILLISECONDS); + + future.whenComplete((result, ex) -> { + if (!pollingTask.isDone()) pollingTask.cancel(false); + scheduler.shutdown(); + }); + + return future; + }); + } + + /** Package-visible for testing. */ + void cleanupExpiredTasks() { + Instant now = Instant.now(); + List expiredTaskIds = new ArrayList<>(); + + tasks.entrySet().removeIf(entry -> { + McpSchema.Task task = entry.getValue().task(); + if (task.getTtl() == null) return false; + Instant expiresAt = Instant.parse(task.getCreatedAt()).plusMillis(task.getTtl()); + if (now.isAfter(expiresAt)) { + String taskId = entry.getKey(); + results.remove(taskId); + cancellationRequests.remove(taskId); + expiredTaskIds.add(taskId); + logger.debug("Removed expired task: {}", taskId); + return true; + } + return false; + }); + + if (messageQueue != null && !expiredTaskIds.isEmpty()) { + for (String taskId : expiredTaskIds) { + messageQueue.clearTask(taskId).exceptionally(ex -> { + logger.warn("Failed to clear task queue for {}", taskId, ex); + return null; + }); + } + logger.debug("Completed cleanup of {} expired tasks", expiredTaskIds.size()); + } + } + + @Override + public CompletableFuture shutdown() { + return CompletableFuture.runAsync(() -> { + cleanupExecutor.shutdown(); + try { + if (!cleanupExecutor.awaitTermination(5, TimeUnit.SECONDS)) { + cleanupExecutor.shutdownNow(); + } + logger.info("TaskStore shut down"); + } catch (InterruptedException e) { + cleanupExecutor.shutdownNow(); + Thread.currentThread().interrupt(); + } + }); + } + + private static class TaskEntry { + private final McpSchema.Task task; + private final McpSchema.Request originatingRequest; + private final Object context; + private final String sessionId; + + TaskEntry(McpSchema.Task task, McpSchema.Request originatingRequest, + Object context, String sessionId) { + this.task = task; + this.originatingRequest = originatingRequest; + this.context = context; + this.sessionId = sessionId; + } + + McpSchema.Task task() { return task; } + McpSchema.Request originatingRequest() { return originatingRequest; } + Object context() { return context; } + String sessionId() { return sessionId; } + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/NullTaskManager.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/NullTaskManager.java new file mode 100644 index 00000000000..690dcfe9f05 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/NullTaskManager.java @@ -0,0 +1,79 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * No-op {@link TaskManager} used when task support is not configured. + * + * @author Yeaury + */ +final class NullTaskManager implements TaskManager { + + private static final NullTaskManager INSTANCE = new NullTaskManager(); + + private NullTaskManager() {} + + static TaskManager getInstance() { + return INSTANCE; + } + + @Override + public void bind(TaskManagerHost host) { + } + + @Override + public InboundRequestResult processInboundRequest(String requestMethod, Object requestParams, + InboundRequestContext ctx) { + return new InboundRequestResult( + notification -> {}, + ctx.sendRequest(), + response -> CompletableFuture.completedFuture(false), + false + ); + } + + @Override + public OutboundRequestResult processOutboundRequest(String requestMethod, Object requestParams, + RequestOptions options, Object messageId, + Consumer responseHandler, + Consumer errorHandler) { + return new OutboundRequestResult(false); + } + + @Override + public InboundResponseResult processInboundResponse(Object responseResult, Object messageId) { + return new InboundResponseResult(false); + } + + @Override + public CompletableFuture processOutboundNotification( + String notificationMethod, Object notification, NotificationOptions options) { + return CompletableFuture.completedFuture(new OutboundNotificationResult(false)); + } + + @Override + public void onClose() { + } + + @Override + public Optional> taskStore() { + return Optional.empty(); + } + + @Override + public Optional messageQueue() { + return Optional.empty(); + } + + @Override + public Duration defaultPollInterval() { + return Duration.ofMillis(TaskDefaults.DEFAULT_POLL_INTERVAL_MS); + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/QueuedMessage.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/QueuedMessage.java new file mode 100644 index 00000000000..e54a35bf91e --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/QueuedMessage.java @@ -0,0 +1,96 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +/** + * Message types for side-channel communication during task execution. + * + *

Request and Notification are dequeued for delivery to the client. + * Response messages are retrieved exclusively via {@code waitForResponse}. + * + * @author Yeaury + */ +public abstract class QueuedMessage { + + /** Server-to-client request (e.g. elicitation, sampling) requiring a response. */ + public static class Request extends QueuedMessage { + private final Object requestId; + private final String method; + private final McpSchema.Request request; + + public Request(Object requestId, String method, McpSchema.Request request) { + this.requestId = requestId; + this.method = method; + this.request = request; + } + + public Object requestId() { + return requestId; + } + + public String method() { + return method; + } + + public McpSchema.Request request() { + return request; + } + + @Override + public String toString() { + return "QueuedMessage.Request{requestId=" + requestId + ", method='" + method + "'}"; + } + } + + /** Client response to a prior Request. */ + public static class Response extends QueuedMessage { + private final Object requestId; + private final McpSchema.Result result; + + public Response(Object requestId, McpSchema.Result result) { + this.requestId = requestId; + this.result = result; + } + + public Object requestId() { + return requestId; + } + + public McpSchema.Result result() { + return result; + } + + @Override + public String toString() { + return "QueuedMessage.Response{requestId=" + requestId + "}"; + } + } + + /** Async notification (e.g. progress update) that requires no response. */ + public static class Notification extends QueuedMessage { + private final String method; + private final Object notification; + + public Notification(String method, Object notification) { + this.method = method; + this.notification = notification; + } + + public String method() { + return method; + } + + public Object notification() { + return notification; + } + + @Override + public String toString() { + return "QueuedMessage.Notification{method='" + method + "'}"; + } + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/ServerTaskToolHandler.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/ServerTaskToolHandler.java new file mode 100644 index 00000000000..fd8ac7a9f5c --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/ServerTaskToolHandler.java @@ -0,0 +1,562 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.taobao.arthas.mcp.server.protocol.server.McpNettyServerExchange; +import com.taobao.arthas.mcp.server.protocol.server.McpRequestHandler; +import com.taobao.arthas.mcp.server.protocol.server.McpTransportContext; +import com.taobao.arthas.mcp.server.protocol.spec.McpError; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import com.taobao.arthas.mcp.server.session.ArthasCommandContext; +import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +/** + * Manages task-aware tool registration and task lifecycle on the server side. + * + * @see TaskManager + * @see TaskManagerHost + */ +public class ServerTaskToolHandler extends AbstractTaskHandler { + + private static final Logger logger = LoggerFactory.getLogger(ServerTaskToolHandler.class); + + private final ObjectMapper objectMapper; + private final TaskManagerOptions taskOptions; + private final Duration automaticPollingTimeout; + private final ArthasCommandSessionManager sessionManager; + + private final CopyOnWriteArrayList taskTools = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap taskToolsByName = new ConcurrentHashMap<>(); + + private final Object toolRegistrationLock = new Object(); + + /** Notifies all connected clients of a method/params pair. */ + private final BiFunction> clientNotifier; + + @SuppressWarnings("unchecked") + public ServerTaskToolHandler( + List taskTools, + TaskManagerOptions taskOptions, + ObjectMapper objectMapper, + BiFunction> clientNotifier, + Duration automaticPollingTimeout, + ArthasCommandSessionManager sessionManager) { + + super( + taskOptions != null ? (TaskStore) taskOptions.taskStore() : null, + taskOptions + ); + + this.objectMapper = objectMapper; + this.clientNotifier = clientNotifier; + this.automaticPollingTimeout = automaticPollingTimeout; + this.taskOptions = taskOptions; + this.sessionManager = sessionManager; + + this.taskTools.addAll(taskTools); + for (TaskAwareToolSpecification taskTool : taskTools) { + this.taskToolsByName.put(taskTool.tool().getName(), taskTool); + } + } + + // --------------------------------------- + // Task Tool Registration + // --------------------------------------- + + public CompletableFuture addTaskTool( + TaskAwareToolSpecification taskToolSpecification, + McpSchema.ServerCapabilities.ToolCapabilities toolCapabilities) { + + if (taskToolSpecification == null) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(new IllegalArgumentException("Task tool specification must not be null")); + return f; + } + if (taskToolSpecification.tool() == null) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(new IllegalArgumentException("Tool must not be null")); + return f; + } + if (taskToolSpecification.createTaskHandler() == null) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(new IllegalArgumentException("createTask handler must not be null")); + return f; + } + + return CompletableFuture.supplyAsync(() -> { + String toolName = taskToolSpecification.tool().getName(); + synchronized (this.toolRegistrationLock) { + if (this.taskTools.removeIf(th -> th.tool().getName().equals(toolName))) { + logger.warn("Replace existing TaskTool with name '{}'", toolName); + } + + this.taskTools.add(taskToolSpecification); + this.taskToolsByName.put(toolName, taskToolSpecification); + } + logger.debug("Added task tool handler: {}", toolName); + + if (toolCapabilities != null && toolCapabilities.getListChanged() != null + && toolCapabilities.getListChanged()) { + return notifyToolsListChanged(); + } + return CompletableFuture.completedFuture(null); + }).thenCompose(f -> f); + } + + public CompletableFuture removeTaskTool( + String toolName, + McpSchema.ServerCapabilities.ToolCapabilities toolCapabilities) { + + if (toolName == null) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(new IllegalArgumentException("Tool name must not be null")); + return f; + } + + return CompletableFuture.supplyAsync(() -> { + if (this.taskTools.removeIf(toolSpec -> toolSpec.tool().getName().equals(toolName))) { + this.taskToolsByName.remove(toolName); + logger.debug("Removed task tool handler: {}", toolName); + if (toolCapabilities != null && toolCapabilities.getListChanged() != null + && toolCapabilities.getListChanged()) { + return notifyToolsListChanged(); + } + } + else { + logger.warn("Ignore as a TaskTool with name '{}' not found", toolName); + } + return CompletableFuture.completedFuture(null); + }).thenCompose(f -> f); + } + + public List listTaskTools() { + return this.taskTools.stream() + .map(TaskAwareToolSpecification::tool) + .collect(Collectors.toList()); + } + + public List getToolDefinitions() { + return this.taskTools.stream() + .map(TaskAwareToolSpecification::tool) + .collect(Collectors.toList()); + } + + public boolean hasToolNamed(String name) { + return this.taskToolsByName.containsKey(name); + } + + public Object getToolRegistrationLock() { + return this.toolRegistrationLock; + } + + // --------------------------------------- + // Task Tool Call Handling + // --------------------------------------- + + public CompletableFuture handleToolCall( + McpNettyServerExchange exchange, + ArthasCommandContext commandContext, + McpSchema.CallToolRequest callToolRequest) { + + TaskAwareToolSpecification taskTool = this.taskToolsByName.get(callToolRequest.getName()); + if (taskTool == null) { + return null; + } + + return doHandleTaskToolCall(exchange, commandContext, callToolRequest, taskTool) + .thenApply(r -> (Object) r); + } + + /** Dispatches a task-aware tool call; handles task creation or automatic polling. */ + private CompletableFuture doHandleTaskToolCall( + McpNettyServerExchange exchange, + ArthasCommandContext commandContext, + McpSchema.CallToolRequest request, + TaskAwareToolSpecification taskTool) { + + McpSchema.ToolExecution execution = taskTool.tool().getExecution(); + McpSchema.TaskSupportMode taskSupportMode = execution != null ? execution.getTaskSupport() : null; + + if (request.getTask() != null) { + if (getTaskStore() == null) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_REQUEST) + .message("Server does not support tasks") + .data("Task store not configured") + .build()); + return f; + } + return handleTaskToolCreateTask(exchange, commandContext, request, taskTool); + } + + if (taskSupportMode == McpSchema.TaskSupportMode.REQUIRED) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_PARAMS) + .message("This tool requires task-augmented execution") + .data("Tool '" + request.getName() + "' requires task metadata in the request") + .build()); + return f; + } + + if (getTaskStore() != null) { + return handleAutomaticTaskPolling(exchange, commandContext, request, taskTool); + } + + if (taskTool.callHandler() != null) { + return taskTool.callHandler().apply(exchange, commandContext, request); + } + + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Tool requires task store or callHandler for execution") + .build()); + return f; + } + + /** Handles task creation for a task-aware tool call. */ + private CompletableFuture handleTaskToolCreateTask( + McpNettyServerExchange exchange, + ArthasCommandContext commandContext, + McpSchema.CallToolRequest request, + TaskAwareToolSpecification taskTool) { + + Long requestTtl = request.getTask() != null ? request.getTask().getTtl() : null; + + String sessionId = extractSessionId(exchange); + logger.info("handleTaskToolCreateTask: Creating task for tool '{}' with sessionId: {}", + request.getName(), sessionId); + + CreateTaskContext extra = new DefaultCreateTaskContext( + this.taskStore, + getTaskMessageQueue(), + exchange, + sessionId, + requestTtl, + request, + commandContext, + this.sessionManager + ); + + Map args = request.getArguments() != null ? request.getArguments() : Collections.emptyMap(); + + return taskTool.createTaskHandler().createTask(args, extra) + .exceptionally(ex -> { + Throwable cause = ex instanceof CompletionException ? ex.getCause() : ex; + if (!(cause instanceof McpError)) { + throw new CompletionException(new McpError( + new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INTERNAL_ERROR, + "Task creation failed: " + cause.getMessage(), + null + ) + )); + } + throw new CompletionException(cause); + }); + } + + /** Handles automatic task polling for a task-aware tool call without task metadata. */ + private CompletableFuture handleAutomaticTaskPolling( + McpNettyServerExchange exchange, + ArthasCommandContext commandContext, + McpSchema.CallToolRequest request, + TaskAwareToolSpecification taskTool) { + + CreateTaskContext extra = new DefaultCreateTaskContext( + this.taskStore, + getTaskMessageQueue(), + exchange, + extractSessionId(exchange), + null, + request, + commandContext, + this.sessionManager + ); + + Map args = request.getArguments() != null ? request.getArguments() : Collections.emptyMap(); + + return taskTool.createTaskHandler().createTask(args, extra) + .thenCompose(createResult -> { + McpSchema.Task task = createResult.getTask(); + if (task == null) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("createTaskHandler did not return a task") + .build()); + return f; + } + + String taskId = task.getTaskId(); + String sessionId = extractSessionId(exchange); + + return pollTaskUntilTerminal(taskId, sessionId, task, taskTool); + }); + } + + /** Polls a task until it reaches a terminal state, then returns the result. */ + private CompletableFuture pollTaskUntilTerminal( + String taskId, + String sessionId, + McpSchema.Task initialTask, + TaskAwareToolSpecification taskTool) { + + long pollInterval = initialTask.getPollInterval() != null + ? initialTask.getPollInterval() + : TaskDefaults.DEFAULT_POLL_INTERVAL_MS; + + Duration timeout = this.automaticPollingTimeout != null + ? this.automaticPollingTimeout + : Duration.ofMillis(TaskDefaults.DEFAULT_AUTOMATIC_POLLING_TIMEOUT_MS); + + CompletableFuture> watchFuture = taskStore.watchTaskUntilTerminal( + taskId, + sessionId, + timeout.toMillis() + ); + + return watchFuture.thenCompose(tasks -> { + if (tasks.isEmpty()) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INTERNAL_ERROR) + .message("Task watch returned empty list") + .data("Task ID: " + taskId) + .build()); + return f; + } + + McpSchema.Task finalTask = tasks.get(tasks.size() - 1); + + if (finalTask.getStatus() == McpSchema.TaskStatus.INPUT_REQUIRED) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(new McpError( + new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INTERNAL_ERROR, + "Task requires interactive input which is not supported in automatic polling mode. " + + "Use task-augmented requests (with TaskMetadata) to enable interactive input. " + + "Task ID: " + taskId, + null + ) + )); + return f; + } + + // For FAILED/CANCELLED: fetch the stored payload (FAILED has one via failTask; + // CANCELLED has none, so fall back to a synthetic error result). + if (finalTask.getStatus() == McpSchema.TaskStatus.FAILED + || finalTask.getStatus() == McpSchema.TaskStatus.CANCELLED) { + return taskStore.getTaskResult(taskId, sessionId) + .thenApply(result -> { + if (result != null) { + return (McpSchema.CallToolResult) result; + } + // CANCELLED (or FAILED without payload as a safety net) + String msg = finalTask.getStatus() == McpSchema.TaskStatus.CANCELLED + ? "Task was cancelled" + + (finalTask.getStatusMessage() != null ? ": " + finalTask.getStatusMessage() : "") + : "Task failed" + + (finalTask.getStatusMessage() != null ? ": " + finalTask.getStatusMessage() : ""); + return new McpSchema.CallToolResult(msg, true, null); + }); + } + + return taskStore.getTaskResult(taskId, sessionId) + .thenApply(result -> (McpSchema.CallToolResult) result); + }).exceptionally(ex -> { + Throwable cause = ex instanceof java.util.concurrent.CompletionException ? ex.getCause() : ex; + if (cause instanceof java.util.concurrent.TimeoutException) { + throw new java.util.concurrent.CompletionException(new McpError( + new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INTERNAL_ERROR, + "Task timed out waiting for completion: " + taskId, + null + ) + )); + } + throw new java.util.concurrent.CompletionException(cause); + }); + } + + private String extractSessionId(McpNettyServerExchange exchange) { + return exchange.sessionId(); + } + + // --------------------------------------- + // Task Request Handler Wiring + // --------------------------------------- + + public Map> getRequestHandlers( + McpSchema.ServerCapabilities.TaskCapabilities taskCapabilities) { + Map> handlers = new HashMap<>(); + if (taskCapabilities != null && getTaskStore() != null) { + this.taskHandlerRegistry.wireHandlers( + taskCapabilities.getList() != null, + taskCapabilities.getCancel() != null, + this::adaptTaskHandler, + handlers::put + ); + } + return handlers; + } + + public void logCapabilityMismatches(McpSchema.ServerCapabilities.TaskCapabilities taskCapabilities) { + if (taskCapabilities != null && getTaskStore() == null) { + logger.warn("Server has tasks capability enabled but no TaskStore configured. " + + "Task operations will be unavailable. Provide a TaskStore or remove the tasks capability."); + } + if (getTaskStore() != null && taskCapabilities == null) { + logger.warn("Server has a TaskStore configured but tasks capability is not enabled. " + + "Task operations will be unavailable. Enable the tasks capability or remove the TaskStore."); + } + } + + @SuppressWarnings("unchecked") + private McpRequestHandler adaptTaskHandler(String method, TaskManagerHost.TaskRequestHandler taskHandler) { + return (exchange, commandContext, params) -> { + String sessionId = extractSessionId(exchange); + + TaskManagerHost.TaskHandlerContext ctx = createTaskHandlerContext( + sessionId, + (reqMethod, reqParams, resultType) -> { + TypeReference typeRef = new TypeReference() {}; + return exchange.getSession().sendRequest(reqMethod, reqParams, typeRef); + }, + (notifMethod, notification) -> { + return exchange.sendNotification(notifMethod, notification); + } + ); + + return this.taskHandlerRegistry.invokeHandler(method, params, ctx).thenApply(result -> { + if (McpSchema.METHOD_TASKS_RESULT.equals(method) && result instanceof McpSchema.Result) { + try { + McpSchema.GetTaskPayloadRequest payloadReq = + objectMapper.convertValue(params, McpSchema.GetTaskPayloadRequest.class); + if (payloadReq != null && payloadReq.getTaskId() != null) { + return (T) addRelatedTaskMetadata(payloadReq.getTaskId(), (McpSchema.Result) result); + } + } catch (Exception e) { + logger.warn("Failed to add related-task metadata", e); + } + } + return result; + }); + }; + } + + // --------------------------------------- + // Metadata Helpers + // --------------------------------------- + + private McpSchema.Result addRelatedTaskMetadata(String taskId, McpSchema.Result result) { + if (result instanceof McpSchema.CallToolResult) { + McpSchema.CallToolResult ctr = (McpSchema.CallToolResult) result; + Map newMeta = TaskMetadataUtils.mergeRelatedTaskMetadata(taskId, ctr.getMeta()); + return new McpSchema.CallToolResult(ctr.getContent(), ctr.getIsError(), newMeta); + } + return result; + } + + public TaskMessageQueue getTaskMessageQueue() { + return this.taskOptions != null ? this.taskOptions.messageQueue() : null; + } + + // --------------------------------------- + // Lifecycle + // --------------------------------------- + + public CompletableFuture notifyTaskStatus(McpSchema.TaskStatusNotification taskStatusNotification) { + if (taskStatusNotification == null) { + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(McpError.builder(McpSchema.ErrorCodes.INVALID_REQUEST) + .message("Task status notification must not be null") + .build()); + return f; + } + return this.clientNotifier.apply(McpSchema.METHOD_NOTIFICATION_TASKS_STATUS, taskStatusNotification); + } + + private CompletableFuture notifyToolsListChanged() { + return this.clientNotifier.apply(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, + Collections.emptyMap()); + } + + @Override + public void close() { + super.close(); + logger.info("ServerTaskToolHandler closed"); + } + + // --------------------------------------- + // TaskManagerHost Implementation + // --------------------------------------- + + @Override + public CompletableFuture request(McpSchema.Request request, Class resultType) { + logger.debug("TaskManagerHost.request called on server. For session-specific requests, use exchange methods."); + CompletableFuture f = new CompletableFuture<>(); + f.completeExceptionally(new UnsupportedOperationException( + "Broadcast requests not supported for tasks. Use session-specific exchange methods.")); + return f; + } + + @Override + public CompletableFuture notification(String notificationMethod, Object notification) { + return this.clientNotifier.apply(notificationMethod, notification); + } + + @Override + @SuppressWarnings("unchecked") + protected CompletableFuture findAndInvokeCustomHandler( + GetTaskFromStoreResult storeResult, String method, McpSchema.Request request, + TaskManagerHost.TaskHandlerContext context, Class resultType) { + + String toolName = null; + if (storeResult.originatingRequest() instanceof McpSchema.CallToolRequest) { + McpSchema.CallToolRequest ctr = (McpSchema.CallToolRequest) storeResult.originatingRequest(); + toolName = ctr.getName(); + } + + TaskAwareToolSpecification taskTool = toolName != null ? this.taskToolsByName.get(toolName) : null; + + if (taskTool == null) { + return CompletableFuture.completedFuture(null); + } + + McpNettyServerExchange exchange = new McpNettyServerExchange(context.sessionId(), null, null, + null, McpTransportContext.EMPTY, null); + + if (McpSchema.METHOD_TASKS_GET.equals(method)) { + GetTaskHandler handler = taskTool.getTaskHandler(); + if (handler != null && request instanceof McpSchema.GetTaskRequest) { + McpSchema.GetTaskRequest getRequest = (McpSchema.GetTaskRequest) request; + return handler.handle(exchange, getRequest) + .thenApply(result -> resultType.cast(result)); + } + } + else if (McpSchema.METHOD_TASKS_RESULT.equals(method)) { + GetTaskResultHandler handler = taskTool.getTaskResultHandler(); + if (handler != null && request instanceof McpSchema.GetTaskPayloadRequest) { + McpSchema.GetTaskPayloadRequest payloadRequest = (McpSchema.GetTaskPayloadRequest) request; + return handler.handle(exchange, payloadRequest) + .thenApply(result -> resultType.cast(result)); + } + } + + return CompletableFuture.completedFuture(null); + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskAwareToolSpecification.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskAwareToolSpecification.java new file mode 100644 index 00000000000..4bd594f6f3c --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskAwareToolSpecification.java @@ -0,0 +1,101 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.server.McpNettyServerExchange; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import com.taobao.arthas.mcp.server.session.ArthasCommandContext; + +import java.util.concurrent.CompletableFuture; + +/** + * Task-aware tool specification combining tool definition with task handlers. + * + *

Task support modes: OPTIONAL (default, backward-compatible), REQUIRED, FORBIDDEN. + * + * @author Yeaury + */ +public final class TaskAwareToolSpecification { + + private final McpSchema.Tool tool; + private final TriFunction> callHandler; + private final CreateTaskHandler createTaskHandler; + private final GetTaskHandler getTaskHandler; + private final GetTaskResultHandler getTaskResultHandler; + + private TaskAwareToolSpecification( + McpSchema.Tool tool, + TriFunction> callHandler, + CreateTaskHandler createTaskHandler, + GetTaskHandler getTaskHandler, + GetTaskResultHandler getTaskResultHandler) { + this.tool = tool; + this.callHandler = callHandler; + this.createTaskHandler = createTaskHandler; + this.getTaskHandler = getTaskHandler; + this.getTaskResultHandler = getTaskResultHandler; + } + + public McpSchema.Tool tool() { + return tool; + } + + public TriFunction> callHandler() { + return callHandler; + } + + public CreateTaskHandler createTaskHandler() { + return createTaskHandler; + } + + public GetTaskHandler getTaskHandler() { + return getTaskHandler; + } + + public GetTaskResultHandler getTaskResultHandler() { + return getTaskResultHandler; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder extends AbstractTaskAwareToolSpecificationBuilder { + + private TriFunction> callHandler; + private CreateTaskHandler createTaskHandler; + private GetTaskHandler getTaskHandler; + private GetTaskResultHandler getTaskResultHandler; + + public Builder callHandler(TriFunction> callHandler) { + this.callHandler = callHandler; + return this; + } + + public Builder createTaskHandler(CreateTaskHandler createTaskHandler) { + this.createTaskHandler = createTaskHandler; + return this; + } + + public Builder getTaskHandler(GetTaskHandler getTaskHandler) { + this.getTaskHandler = getTaskHandler; + return this; + } + + public Builder getTaskResultHandler(GetTaskResultHandler getTaskResultHandler) { + this.getTaskResultHandler = getTaskResultHandler; + return this; + } + + public TaskAwareToolSpecification build() { + validateCommonFields(); + if (createTaskHandler == null) { + throw new IllegalArgumentException("createTaskHandler must not be null"); + } + McpSchema.Tool tool = buildTool(); + return new TaskAwareToolSpecification(tool, callHandler, createTaskHandler, getTaskHandler, getTaskResultHandler); + } + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskDefaults.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskDefaults.java new file mode 100644 index 00000000000..0541658046e --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskDefaults.java @@ -0,0 +1,76 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +import java.time.Duration; + +/** + * Default configuration constants for the task system. + * + * @author Yeaury + */ +public final class TaskDefaults { + + public static final long DEFAULT_TTL_MS = 60_000L; + + public static final long DEFAULT_POLL_INTERVAL_MS = 1000L; + + public static final int DEFAULT_PAGE_SIZE = 100; + + public static final int DEFAULT_MAX_TASKS = 10_000; + + public static final long DEFAULT_AUTOMATIC_POLLING_TIMEOUT_MS = 600000L; + + public static final int DEFAULT_SIDE_CHANNEL_TIMEOUT_MINUTES = 5; + + public static final long MAX_TTL_MS = 24 * 60 * 60 * 1000L; + + public static final long MIN_POLL_INTERVAL_MS = 100L; + + public static final long MAX_POLL_INTERVAL_MS = 60 * 60 * 1000L; + + public static final long CLEANUP_INTERVAL_MINUTES = 1L; + + public static final long MESSAGE_QUEUE_CLEANUP_TIMEOUT_MS = 1_000L; + + public static final long RESPONSE_POLL_INTERVAL_MS = 50L; + + public static final long TASK_STORE_SHUTDOWN_TIMEOUT_SECONDS = 5L; + + public static final int DEFAULT_MAX_POLL_ATTEMPTS = 60; + + public static final long MAX_TIMEOUT_MS = 3_600_000L; + + public static final int MAX_WATCH_UPDATES = 100; + + public static final McpSchema.JsonSchema EMPTY_INPUT_SCHEMA = + new McpSchema.JsonSchema("object", null, null, null); + + /** + * Calculates a polling timeout scaled to the given poll interval, capped at {@link #MAX_TIMEOUT_MS}. + */ + public static Duration calculateTimeout(Long pollInterval) { + long interval = pollInterval != null ? pollInterval : DEFAULT_POLL_INTERVAL_MS; + long calculatedMs = interval * DEFAULT_MAX_POLL_ATTEMPTS; + return Duration.ofMillis(Math.min(calculatedMs, MAX_TIMEOUT_MS)); + } + + /** + * Validates that task-aware tools are not registered without a TaskStore. + */ + public static void validateTaskConfiguration(boolean hasTaskTools, boolean hasTaskStore) { + if (hasTaskTools && !hasTaskStore) { + throw new IllegalStateException( + "Task-aware tools registered but no TaskStore configured. " + + "Add a TaskStore via .taskStore(store) or remove task tools."); + } + } + + private TaskDefaults() { + throw new UnsupportedOperationException("Utility class"); + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskHandlerRegistry.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskHandlerRegistry.java new file mode 100644 index 00000000000..1b6de0be9fe --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskHandlerRegistry.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; + +/** + * Registry for task request handlers, shared by client and server. + * + * @author Yeaury + */ +class TaskHandlerRegistry { + + private static final Logger logger = LoggerFactory.getLogger(TaskHandlerRegistry.class); + + private final ConcurrentHashMap handlers = new ConcurrentHashMap<>(); + + public void registerHandler(String method, TaskManagerHost.TaskRequestHandler handler) { + logger.debug("Registered task handler: {}", method); + this.handlers.put(method, handler); + } + + @SuppressWarnings("unchecked") + public CompletableFuture invokeHandler(String method, Object params, + TaskManagerHost.TaskHandlerContext context) { + TaskManagerHost.TaskRequestHandler handler = this.handlers.get(method); + if (handler == null) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new IllegalStateException("No handler registered: " + method)); + return future; + } + return handler.handle(method, params, context) + .thenApply(result -> (T) result); + } + + /** + * Wires registered handlers via the adapter/registrar pair. + * tasks/get and tasks/result are always wired; tasks/list and tasks/cancel are conditional. + */ + public void wireHandlers(boolean supportsList, boolean supportsCancel, + BiFunction adapter, + BiConsumer registrar) { + wireIfPresent(McpSchema.METHOD_TASKS_GET, adapter, registrar); + wireIfPresent(McpSchema.METHOD_TASKS_RESULT, adapter, registrar); + if (supportsList) { + wireIfPresent(McpSchema.METHOD_TASKS_LIST, adapter, registrar); + } + if (supportsCancel) { + wireIfPresent(McpSchema.METHOD_TASKS_CANCEL, adapter, registrar); + } + } + + private void wireIfPresent(String method, + BiFunction adapter, + BiConsumer registrar) { + TaskManagerHost.TaskRequestHandler handler = this.handlers.get(method); + if (handler != null) { + registrar.accept(method, adapter.apply(method, handler)); + } else { + logger.warn("No handler registered for {}", method); + } + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskHelper.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskHelper.java new file mode 100644 index 00000000000..ca6a5f4e3a3 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskHelper.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.TaskStatus; + +/** + * Utility methods for task status checks. + * + * @author Yeaury + */ +public final class TaskHelper { + + private TaskHelper() {} + + /** Returns true if the status is COMPLETED, FAILED, or CANCELLED. */ + public static boolean isTerminal(TaskStatus status) { + if (status == null) { + return false; + } + return status == TaskStatus.COMPLETED + || status == TaskStatus.FAILED + || status == TaskStatus.CANCELLED; + } + + /** + * Returns true if the transition from {@code from} to {@code to} is valid. + * + *

Terminal states cannot transition further. WORKING can transition to any state. + * INPUT_REQUIRED can transition to WORKING or any terminal state. + */ + public static boolean isValidTransition(TaskStatus from, TaskStatus to) { + if (from == null || to == null) { + return false; + } + if (isTerminal(from)) { + return false; + } + if (from == TaskStatus.WORKING) { + return true; + } + if (from == TaskStatus.INPUT_REQUIRED) { + return to == TaskStatus.WORKING || isTerminal(to); + } + return false; + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskManager.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskManager.java new file mode 100644 index 00000000000..42992c6a24c --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskManager.java @@ -0,0 +1,230 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * Orchestrates task state, message queuing, polling, and handler registration. + * + *

Interacts with the protocol layer via five lifecycle methods: + * {@link #processInboundRequest}, {@link #processOutboundRequest}, + * {@link #processInboundResponse}, {@link #processOutboundNotification}, {@link #onClose}. + * Must be bound to a {@link TaskManagerHost} via {@link #bind} before use. + * + * @author Yeaury + * @see DefaultTaskManager + * @see NullTaskManager + */ +public interface TaskManager { + + void bind(TaskManagerHost host); + + InboundRequestResult processInboundRequest(String requestMethod, Object requestParams, + InboundRequestContext ctx); + + OutboundRequestResult processOutboundRequest(String requestMethod, Object requestParams, + RequestOptions options, Object messageId, + Consumer responseHandler, + Consumer errorHandler); + + InboundResponseResult processInboundResponse(Object responseResult, Object messageId); + + CompletableFuture processOutboundNotification( + String notificationMethod, Object notification, NotificationOptions options); + + void onClose(); + + Optional> taskStore(); + + Optional messageQueue(); + + Duration defaultPollInterval(); + + // === Supporting types === + + class InboundRequestContext { + private final String sessionId; + private final NotificationSender sendNotification; + private final RequestSender sendRequest; + + public InboundRequestContext(String sessionId, + NotificationSender sendNotification, + RequestSender sendRequest) { + this.sessionId = sessionId; + this.sendNotification = sendNotification; + this.sendRequest = sendRequest; + } + + public String sessionId() { + return sessionId; + } + + public NotificationSender sendNotification() { + return sendNotification; + } + + public RequestSender sendRequest() { + return sendRequest; + } + } + + @FunctionalInterface + interface NotificationSender { + CompletableFuture send(Object notification, NotificationOptions options); + } + + @FunctionalInterface + interface RequestSender { + CompletableFuture send(Object request, Class resultType, RequestOptions options); + } + + class InboundRequestResult { + private final Consumer sendNotification; + private final RequestSender sendRequest; + private final java.util.function.Function> routeResponse; + private final boolean hasTaskCreationParams; + + public InboundRequestResult(Consumer sendNotification, + RequestSender sendRequest, + java.util.function.Function> routeResponse, + boolean hasTaskCreationParams) { + this.sendNotification = sendNotification; + this.sendRequest = sendRequest; + this.routeResponse = routeResponse; + this.hasTaskCreationParams = hasTaskCreationParams; + } + + public Consumer sendNotification() { + return sendNotification; + } + + public RequestSender sendRequest() { + return sendRequest; + } + + public java.util.function.Function> routeResponse() { + return routeResponse; + } + + public boolean hasTaskCreationParams() { + return hasTaskCreationParams; + } + } + + class OutboundRequestResult { + private final boolean queued; + + public OutboundRequestResult(boolean queued) { + this.queued = queued; + } + + public boolean queued() { + return queued; + } + } + + class InboundResponseResult { + private final boolean consumed; + + public InboundResponseResult(boolean consumed) { + this.consumed = consumed; + } + + public boolean consumed() { + return consumed; + } + } + + class OutboundNotificationResult { + private final boolean queued; + private final Object jsonrpcNotification; + + public OutboundNotificationResult(boolean queued, Object jsonrpcNotification) { + this.queued = queued; + this.jsonrpcNotification = jsonrpcNotification; + } + + public OutboundNotificationResult(boolean queued) { + this(queued, null); + } + + public boolean queued() { + return queued; + } + + public Object jsonrpcNotification() { + return jsonrpcNotification; + } + } + + class RequestOptions { + private final TaskCreationParams task; + private final RelatedTaskInfo relatedTask; + + public RequestOptions(TaskCreationParams task, RelatedTaskInfo relatedTask) { + this.task = task; + this.relatedTask = relatedTask; + } + + public static RequestOptions empty() { + return new RequestOptions(null, null); + } + + public TaskCreationParams task() { + return task; + } + + public RelatedTaskInfo relatedTask() { + return relatedTask; + } + } + + class NotificationOptions { + private final RelatedTaskInfo relatedTask; + + public NotificationOptions(RelatedTaskInfo relatedTask) { + this.relatedTask = relatedTask; + } + + public static NotificationOptions empty() { + return new NotificationOptions(null); + } + public static NotificationOptions withRelatedTask(RelatedTaskInfo relatedTask) { + return new NotificationOptions(relatedTask); + } + + public RelatedTaskInfo relatedTask() { + return relatedTask; + } + } + + class TaskCreationParams { + private final Long ttl; + + public TaskCreationParams(Long ttl) { + this.ttl = ttl; + } + + public Long ttl() { + return ttl; + } + } + + class RelatedTaskInfo { + private final String taskId; + + public RelatedTaskInfo(String taskId) { + this.taskId = taskId; + } + + public String taskId() { + return taskId; + } + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskManagerHost.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskManagerHost.java new file mode 100644 index 00000000000..1d1e9b64aa0 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskManagerHost.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +import java.util.concurrent.CompletableFuture; + +/** + * Communication interface between {@link TaskManager} and the protocol layer. + * + * @author Yeaury + */ +public interface TaskManagerHost { + + CompletableFuture request(McpSchema.Request request, Class resultType); + + CompletableFuture notification(String notificationMethod, Object notification); + + /** Register a handler for a task-related method (e.g. tasks/get). */ + void registerHandler(String method, TaskRequestHandler handler); + + /** + * Invoke a custom task handler if one is registered for the given task and method. + * Returns null if no custom handler exists. + */ + CompletableFuture invokeCustomTaskHandler( + String taskId, String method, McpSchema.Request request, + TaskHandlerContext context, Class resultType); + + @FunctionalInterface + interface TaskRequestHandler { + CompletableFuture handle(String requestMethod, Object requestParams, + TaskHandlerContext context); + } + + interface TaskHandlerContext { + String sessionId(); + + CompletableFuture sendRequest(String method, Object params, + Class resultType); + + CompletableFuture sendNotification(String method, Object notification); + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskManagerOptions.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskManagerOptions.java new file mode 100644 index 00000000000..835f0b8caa5 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskManagerOptions.java @@ -0,0 +1,92 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import java.time.Duration; + +/** + * Configuration options for {@link TaskManager}. + * + * @author Yeaury + */ +public class TaskManagerOptions { + + private final TaskStore taskStore; + private final TaskMessageQueue messageQueue; + private final Duration defaultPollInterval; + private final Duration pollTimeout; + + private TaskManagerOptions(Builder builder) { + this.taskStore = builder.taskStore; + this.messageQueue = builder.messageQueue; + this.defaultPollInterval = builder.defaultPollInterval != null + ? builder.defaultPollInterval + : Duration.ofMillis(TaskDefaults.DEFAULT_POLL_INTERVAL_MS); + this.pollTimeout = builder.pollTimeout; + } + + public TaskStore taskStore() { + return this.taskStore; + } + + public TaskMessageQueue messageQueue() { + return this.messageQueue; + } + + public Duration defaultPollInterval() { + return this.defaultPollInterval; + } + + public Duration pollTimeout() { + return this.pollTimeout; + } + + /** + * Creates a {@link DefaultTaskManager} if store/queue are configured, otherwise {@link NullTaskManager}. + */ + public TaskManager createTaskManager() { + if (this.taskStore == null && this.messageQueue == null) { + return NullTaskManager.getInstance(); + } + return new DefaultTaskManager(this); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private TaskStore taskStore; + private TaskMessageQueue messageQueue; + private Duration defaultPollInterval; + private Duration pollTimeout; + + private Builder() {} + + public Builder store(TaskStore taskStore) { + this.taskStore = taskStore; + return this; + } + + public Builder messageQueue(TaskMessageQueue messageQueue) { + this.messageQueue = messageQueue; + return this; + } + + public Builder defaultPollInterval(Duration interval) { + this.defaultPollInterval = interval; + return this; + } + + public Builder pollTimeout(Duration timeout) { + this.pollTimeout = timeout; + return this; + } + + public TaskManagerOptions build() { + return new TaskManagerOptions(this); + } + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskMessageQueue.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskMessageQueue.java new file mode 100644 index 00000000000..7dc3ccea21f --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskMessageQueue.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * Queue for side-channel messages during task execution. + * + *

Supports three message types: Request (server→client, needs response), + * Notification (async, no response), and Response (client reply, retrieved via + * {@link #waitForResponse} only). + * + * @author Yeaury + */ +public interface TaskMessageQueue { + + /** Enqueue a message (Request, Response, or Notification). */ + CompletableFuture enqueue(String taskId, QueuedMessage message); + + /** Dequeue the next actionable message (Request or Notification); returns null if empty. */ + CompletableFuture dequeue(String taskId); + + /** Dequeue all actionable messages (Request and Notification). */ + CompletableFuture> dequeueAll(String taskId); + + /** Block until a Response matching {@code requestId} is enqueued, or timeout. */ + CompletableFuture waitForResponse(String taskId, Object requestId, Duration timeout); + + /** Remove all messages for a task (called on task expiry/cleanup). */ + CompletableFuture clearTask(String taskId); + + default CompletableFuture getQueueSize(String taskId) { + return CompletableFuture.completedFuture(0); + } + + CompletableFuture shutdown(); +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskMetadataUtils.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskMetadataUtils.java new file mode 100644 index 00000000000..bbb148ee373 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskMetadataUtils.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +import java.util.HashMap; +import java.util.Map; + +/** + * Utilities for adding {@code relatedTask} metadata to notifications and results. + * + * @author Yeaury + */ +public final class TaskMetadataUtils { + + private TaskMetadataUtils() {} + + /** + * Injects {@code _meta.relatedTask.taskId} into a notification. + * {@link McpSchema.TaskStatusNotification} is returned unchanged (already contains taskId). + */ + @SuppressWarnings("unchecked") + public static Object addRelatedTaskMetadata(String taskId, Object notification) { + if (notification == null || taskId == null) { + return notification; + } + if (notification instanceof McpSchema.TaskStatusNotification) { + return notification; + } + if (notification instanceof Map) { + Map notifMap = new HashMap<>((Map) notification); + Map meta = notifMap.containsKey("_meta") && notifMap.get("_meta") instanceof Map + ? new HashMap<>((Map) notifMap.get("_meta")) + : new HashMap<>(); + Map relatedTask = new HashMap<>(); + relatedTask.put("taskId", taskId); + meta.put(McpSchema.RELATED_TASK_META_KEY, relatedTask); + notifMap.put("_meta", meta); + return notifMap; + } + return notification; + } + + /** Merges {@code relatedTask: {taskId}} into a new metadata map, overlaying existing entries. */ + public static Map mergeRelatedTaskMetadata(String taskId, Map existingMeta) { + Map taskIdMap = new HashMap<>(); + taskIdMap.put("taskId", taskId); + return mergeRelatedTaskMetadata((Object) taskIdMap, existingMeta); + } + + /** Merges {@code relatedTask: relatedTaskValue} into a new metadata map, overlaying existing entries. */ + public static Map mergeRelatedTaskMetadata(Object relatedTaskValue, + Map existingMeta) { + Map newMeta = new HashMap<>(); + newMeta.put(McpSchema.RELATED_TASK_META_KEY, relatedTaskValue); + if (existingMeta != null) { + newMeta.putAll(existingMeta); + } + return newMeta; + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskStore.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskStore.java new file mode 100644 index 00000000000..374a495389d --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TaskStore.java @@ -0,0 +1,51 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; + +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * Persistent store for task state and results with session isolation. + * + *

Session validation rules: null {@code sessionId} allows all access (single-tenant); + * tasks without a session are accessible from any session; otherwise sessionIds must match. + * + *

Error conventions: {@link #getTask} and {@link #getTaskResult} return null on miss; + * {@link #storeTaskResult} throws on miss; {@link #updateTaskStatus} silently ignores misses; + * {@link #requestCancellation} throws (-32602) for terminal tasks. + * + * @param result type stored by this store + * @author Yeaury + */ +public interface TaskStore { + + CompletableFuture createTask(CreateTaskOptions options); + + CompletableFuture getTask(String taskId, String sessionId); + + CompletableFuture updateTaskStatus(String taskId, String sessionId, + McpSchema.TaskStatus status, String statusMessage); + + CompletableFuture storeTaskResult(String taskId, String sessionId, + McpSchema.TaskStatus status, R result); + + CompletableFuture getTaskResult(String taskId, String sessionId); + + CompletableFuture listTasks(String cursor, String sessionId); + + CompletableFuture requestCancellation(String taskId, String sessionId); + + CompletableFuture isCancellationRequested(String taskId, String sessionId); + + CompletableFuture> watchTaskUntilTerminal( + String taskId, String sessionId, long timeoutMs); + + default CompletableFuture shutdown() { + return CompletableFuture.completedFuture(null); + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TriFunction.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TriFunction.java new file mode 100644 index 00000000000..9eeb2f65fb8 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/task/TriFunction.java @@ -0,0 +1,20 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package com.taobao.arthas.mcp.server.task; + +/** + * Three-argument function interface. + * + * @param first argument type + * @param second argument type + * @param third argument type + * @param result type + * @author Yeaury + */ +@FunctionalInterface +public interface TriFunction { + + R apply(T t, U u, V v); +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/ToolCallbackCreateTaskHandler.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/ToolCallbackCreateTaskHandler.java new file mode 100644 index 00000000000..18c85ecda76 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/ToolCallbackCreateTaskHandler.java @@ -0,0 +1,239 @@ +package com.taobao.arthas.mcp.server.tool; + +import com.taobao.arthas.mcp.server.protocol.server.McpTransportContext; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import com.taobao.arthas.mcp.server.session.ArthasCommandContext; +import com.taobao.arthas.mcp.server.task.CreateTaskContext; +import com.taobao.arthas.mcp.server.task.CreateTaskHandler; +import com.taobao.arthas.mcp.server.util.JsonParser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import static com.taobao.arthas.mcp.server.tool.ToolContextKeys.*; + +/** + * 将 ToolCallback 适配为 CreateTaskHandler 的通用适配器。 + * + * @see MCP Tasks Specification + */ +public class ToolCallbackCreateTaskHandler implements CreateTaskHandler { + + private static final Logger logger = LoggerFactory.getLogger(ToolCallbackCreateTaskHandler.class); + + private final ToolCallback toolCallback; + + public ToolCallbackCreateTaskHandler(ToolCallback toolCallback) { + this.toolCallback = toolCallback; + } + + @Override + public CompletableFuture createTask( + Map args, + CreateTaskContext context) { + + logger.debug("Creating task for tool: {}", toolCallback.getToolDefinition().getName()); + + return context.createTask(opts -> { + // 使用默认配置,工具可以通过注解自定义 pollInterval 等 + }).thenCompose(task -> { + String taskId = task.getTaskId(); + + logger.info("Task created: {}, starting async tool execution", taskId); + + CompletableFuture.runAsync(() -> { + executeToolAndUpdateTaskStatus(taskId, args, context); + }); + + return CompletableFuture.completedFuture( + new McpSchema.CreateTaskResult(task, null) + ); + }); + } + + /** + * 在后台执行工具并更新任务状态。 + * + * @param taskId 任务 ID + * @param args 工具参数 + * @param context 任务上下文 + */ + private void executeToolAndUpdateTaskStatus(String taskId, Map args, CreateTaskContext context) { + ArthasCommandContext isolatedContext = null; + try { + // 执行前检查任务是否已被取消 + Boolean alreadyCancelled = context.isCancellationRequested(taskId).join(); + if (Boolean.TRUE.equals(alreadyCancelled)) { + logger.info("Task {} was cancelled before execution started, skipping", taskId); + return; + } + + logger.debug("Executing tool: {} for task: {}", + toolCallback.getToolDefinition().getName(), taskId); + + // 为 task 创建独立的 session + isolatedContext = context.createIsolatedTaskSession(taskId); + logger.debug("Created isolated session for task: {}, arthasSessionId: {}", + taskId, isolatedContext.getArthasSessionId()); + + // 使用独立的 context 构建工具上下文 + ToolContext enhancedContext = buildEnhancedToolContext(taskId, context, isolatedContext); + + // 调用工具方法(工具内部的轮询循环会检查取消状态) + String toolInput = JsonParser.toJson(args); + String resultJson = toolCallback.call(toolInput, enhancedContext); + + // 执行完成后再次检查取消状态 + Boolean cancelledAfter = context.isCancellationRequested(taskId).join(); + if (Boolean.TRUE.equals(cancelledAfter)) { + logger.info("Task {} was cancelled during execution, interrupting job", taskId); + interruptJob(isolatedContext); + return; + } + + // 解析结果为 CallToolResult + McpSchema.CallToolResult result = parseToolResult(resultJson); + + // 检查工具返回的结果是否标记为 cancelled(由 StreamableToolUtils 设置) + if (isResultCancelled(resultJson)) { + logger.info("Task {} execution detected cancellation, interrupting job", taskId); + interruptJob(isolatedContext); + return; + } + + // 根据结果类型更新任务状态 + if (Boolean.TRUE.equals(result.getIsError())) { + // 工具返回错误结果,标记任务为失败 + String errorMessage = extractErrorMessage(result); + context.failTask(taskId, new McpSchema.CallToolResult(errorMessage, true, null)) + .exceptionally(ex -> { + logger.error("Failed to mark task as failed: {}", taskId, ex); + return null; + }); + logger.warn("Tool execution returned error for task: {}", taskId); + } else { + // 工具执行成功,完成任务 + context.completeTask(taskId, result) + .thenRun(() -> { + logger.info("Task completed successfully: {}", taskId); + }) + .exceptionally(ex -> { + logger.error("Failed to update task completion: {}", taskId, ex); + return null; + }); + } + + } catch (Exception e) { + logger.error("Tool execution failed for task: {}", taskId, e); + + // 标记任务失败(如果任务已被取消,updateTaskStatus 会静默忽略终态任务) + context.failTask(taskId, new McpSchema.CallToolResult("Tool execution failed: " + e.getMessage(), true, null)) + .exceptionally(ex -> { + logger.error("Failed to update task failure: {}", taskId, ex); + return null; + }); + } finally { + // 清理独立的 session + cleanupTaskSession(taskId, context); + } + } + + private void interruptJob(ArthasCommandContext commandContext) { + try { + if (commandContext != null) { + commandContext.interruptJob(); + } + } catch (Exception e) { + logger.warn("Failed to interrupt job: {}", e.getMessage()); + } + } + + @SuppressWarnings("unchecked") + private boolean isResultCancelled(String resultJson) { + try { + Map resultMap = JsonParser.fromJson(resultJson, Map.class); + return Boolean.TRUE.equals(resultMap.get("cancelled")); + } catch (Exception e) { + return false; + } + } + + /** + * 清理 task 的独立 session。 + */ + private void cleanupTaskSession(String taskId, CreateTaskContext context) { + try { + context.cleanupTaskSession(taskId); + logger.debug("Cleaned up task session: {}", taskId); + } catch (Exception e) { + logger.warn("Failed to cleanup task session: {}, error={}", taskId, e.getMessage()); + } + } + + /** + * 构建增强的 ToolContext,用于 task 执行。 + */ + private ToolContext buildEnhancedToolContext( + String taskId, + CreateTaskContext context, + ArthasCommandContext isolatedContext) { + + Map contextMap = new HashMap<>(); + + contextMap.put(CREATE_TASK_CONTEXT, context); + contextMap.put(TASK_ID, taskId); + + contextMap.put(COMMAND_CONTEXT, isolatedContext); + + if (context.exchange() != null) { + contextMap.put(EXCHANGE, context.exchange()); + + McpTransportContext transportContext = context.exchange().getTransportContext(); + if (transportContext != null) { + contextMap.put(MCP_TRANSPORT_CONTEXT, transportContext); + } + } + + return new ToolContext(contextMap); + } + + private McpSchema.CallToolResult parseToolResult(String resultJson) { + try { + Map resultMap = JsonParser.fromJson(resultJson, Map.class); + + if (resultMap.containsKey("content")) { + return JsonParser.fromJson(resultJson, McpSchema.CallToolResult.class); + } + + McpSchema.TextContent textContent = new McpSchema.TextContent(resultJson); + return new McpSchema.CallToolResult( + java.util.Collections.singletonList(textContent), + false, + null + ); + + } catch (Exception e) { + logger.debug("Failed to parse tool result as JSON, treating as plain text", e); + + McpSchema.TextContent textContent = new McpSchema.TextContent(resultJson); + return new McpSchema.CallToolResult( + java.util.Collections.singletonList(textContent), + false, + null + ); + } + } + + private String extractErrorMessage(McpSchema.CallToolResult result) { + if (result.getContent() != null && !result.getContent().isEmpty()) { + McpSchema.Content firstContent = result.getContent().get(0); + if (firstContent instanceof McpSchema.TextContent) { + return ((McpSchema.TextContent) firstContent).getText(); + } + } + return "Tool execution failed"; + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/ToolContextKeys.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/ToolContextKeys.java new file mode 100644 index 00000000000..0d84e3865e0 --- /dev/null +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/ToolContextKeys.java @@ -0,0 +1,23 @@ +package com.taobao.arthas.mcp.server.tool; + +/** + * ToolContext 键名常量定义。 + * + */ +public final class ToolContextKeys { + + public static final String EXCHANGE = "exchange"; + + public static final String COMMAND_CONTEXT = "commandContext"; + + public static final String MCP_TRANSPORT_CONTEXT = "mcpTransportContext"; + + public static final String PROGRESS_TOKEN = "progressToken"; + + public static final String TASK_ID = "taskId"; + + public static final String CREATE_TASK_CONTEXT = "createTaskContext"; + + private ToolContextKeys() { + } +} diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/annotation/Tool.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/annotation/Tool.java index 4f1eb1bfe10..058f610b741 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/annotation/Tool.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/annotation/Tool.java @@ -1,5 +1,7 @@ package com.taobao.arthas.mcp.server.tool.annotation; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.TaskSupportMode; + import java.lang.annotation.*; @Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE }) @@ -12,5 +14,19 @@ String description() default ""; boolean streamable() default false; + + /** + * 任务支持模式。 + * + *

    + *
  • {@link TaskSupportMode#FORBIDDEN FORBIDDEN} - 不支持任务(默认)
  • + *
  • {@link TaskSupportMode#OPTIONAL OPTIONAL} - 可选支持任务
  • + *
  • {@link TaskSupportMode#REQUIRED REQUIRED} - 必须以任务模式调用
  • + *
+ * + * @return 任务支持模式 + * @see MCP Tasks Specification + */ + TaskSupportMode taskSupport() default TaskSupportMode.FORBIDDEN; } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/definition/ToolDefinition.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/definition/ToolDefinition.java index 1ebaad3109f..d5477d06d72 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/definition/ToolDefinition.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/definition/ToolDefinition.java @@ -1,6 +1,7 @@ package com.taobao.arthas.mcp.server.tool.definition; import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.TaskSupportMode; public class ToolDefinition { private String name; @@ -10,13 +11,21 @@ public class ToolDefinition { private McpSchema.JsonSchema inputSchema; private boolean streamable; + + private TaskSupportMode taskSupport; public ToolDefinition(String name, String description, - McpSchema.JsonSchema inputSchema, boolean streamable) { + McpSchema.JsonSchema inputSchema, boolean streamable, TaskSupportMode taskSupport) { this.name = name; this.description = description; this.inputSchema = inputSchema; this.streamable = streamable; + this.taskSupport = taskSupport; + } + + public ToolDefinition(String name, String description, + McpSchema.JsonSchema inputSchema, boolean streamable) { + this(name, description, inputSchema, streamable, TaskSupportMode.FORBIDDEN); } public String getName() { @@ -35,6 +44,10 @@ public boolean isStreamable() { return streamable; } + public TaskSupportMode taskSupport() { + return taskSupport; + } + public static Builder builder() { return new Builder(); } @@ -48,6 +61,8 @@ public static final class Builder { private McpSchema.JsonSchema inputSchema; private boolean streamable; + + private TaskSupportMode taskSupport = TaskSupportMode.FORBIDDEN; private Builder() { } @@ -71,10 +86,14 @@ public Builder streamable(boolean streamable) { this.streamable = streamable; return this; } + + public Builder taskSupport(TaskSupportMode taskSupport) { + this.taskSupport = taskSupport; + return this; + } public ToolDefinition build() { - return new ToolDefinition(this.name, this.description, this.inputSchema, this.streamable); + return new ToolDefinition(name, description, inputSchema, streamable, taskSupport); } - } } diff --git a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/definition/ToolDefinitions.java b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/definition/ToolDefinitions.java index fd4b421a281..0c30b62b93b 100644 --- a/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/definition/ToolDefinitions.java +++ b/arthas-mcp-server/src/main/java/com/taobao/arthas/mcp/server/tool/definition/ToolDefinitions.java @@ -1,5 +1,6 @@ package com.taobao.arthas.mcp.server.tool.definition; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; import com.taobao.arthas.mcp.server.tool.annotation.Tool; import com.taobao.arthas.mcp.server.tool.util.JsonSchemaGenerator; import com.taobao.arthas.mcp.server.util.Assert; @@ -14,7 +15,8 @@ public static ToolDefinition.Builder builder(Method method) { .name(getToolName(method)) .description(getToolDescription(method)) .inputSchema(JsonSchemaGenerator.generateForMethodInput(method)) - .streamable(isStreamable(method)); + .streamable(isStreamable(method)) + .taskSupport(getTaskSupport(method)); } public static ToolDefinition from(Method method) { @@ -48,4 +50,13 @@ public static boolean isStreamable(Method method) { return tool.streamable(); } + public static McpSchema.TaskSupportMode getTaskSupport(Method method) { + Assert.notNull(method, "method cannot be null"); + Tool tool = method.getAnnotation(Tool.class); + if (tool == null) { + return McpSchema.TaskSupportMode.FORBIDDEN; + } + return tool.taskSupport(); + } + } diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/ArthasMcpServer.java b/core/src/main/java/com/taobao/arthas/core/mcp/ArthasMcpServer.java index 8e1b9a00dfd..a5b377f969d 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/ArthasMcpServer.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/ArthasMcpServer.java @@ -16,9 +16,18 @@ import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.Implementation; import com.taobao.arthas.mcp.server.protocol.spec.McpSchema.ServerCapabilities; import com.taobao.arthas.mcp.server.protocol.spec.McpStreamableServerTransportProvider; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; +import com.taobao.arthas.mcp.server.session.ArthasCommandSessionManager; +import com.taobao.arthas.mcp.server.task.InMemoryTaskMessageQueue; +import com.taobao.arthas.mcp.server.task.InMemoryTaskStore; +import com.taobao.arthas.mcp.server.task.TaskAwareToolSpecification; +import com.taobao.arthas.mcp.server.task.TaskMessageQueue; +import com.taobao.arthas.mcp.server.task.TaskStore; import com.taobao.arthas.mcp.server.tool.DefaultToolCallbackProvider; import com.taobao.arthas.mcp.server.tool.ToolCallback; +import com.taobao.arthas.mcp.server.tool.ToolCallbackCreateTaskHandler; import com.taobao.arthas.mcp.server.tool.ToolCallbackProvider; +import com.taobao.arthas.mcp.server.tool.definition.ToolDefinition; import com.taobao.arthas.mcp.server.util.JsonParser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,6 +55,7 @@ public class ArthasMcpServer { private final ServerProtocol protocol; private final CommandExecutor commandExecutor; + private ArthasCommandSessionManager sessionManager; private McpHttpRequestHandler unifiedMcpHandler; @@ -75,13 +85,13 @@ public McpHttpRequestHandler getMcpRequestHandler() { } /** - * Start MCP server + * 启动 MCP 服务器 */ public void start() { try { - // Register Arthas-specific JSON filter + // 注册 Arthas 特定的 JSON 过滤器 com.taobao.arthas.core.mcp.util.McpObjectVOFilter.register(); - + McpServerProperties properties = new McpServerProperties.Builder() .name("arthas-mcp-server") .version("4.1.5") @@ -93,14 +103,7 @@ public void start() { .protocol(this.protocol) .build(); - // Use Arthas tool base package from core module - DefaultToolCallbackProvider toolCallbackProvider = new DefaultToolCallbackProvider(); - toolCallbackProvider.setToolBasePackage(ARTHAS_TOOL_BASE_PACKAGE); - - ToolCallback[] callbacks = toolCallbackProvider.getToolCallbacks(); - List providerToolCallbacks = Arrays.stream(callbacks) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + ToolClassification toolClassification = scanAndClassifyTools(); unifiedMcpHandler = McpHttpRequestHandler.builder() .mcpEndpoint(properties.getMcpEndpoint()) @@ -109,51 +112,188 @@ public void start() { .build(); if (properties.getProtocol() == ServerProtocol.STREAMABLE) { - McpStreamableServerTransportProvider transportProvider = createStreamableHttpTransportProvider(properties); - streamableHandler = transportProvider.getMcpRequestHandler(); - unifiedMcpHandler.setStreamableHandler(streamableHandler); - - McpServer.StreamableServerNettySpecification streamableServerNettySpecification = McpServer.netty(transportProvider) - .serverInfo(new Implementation(properties.getName(), properties.getVersion())) - .capabilities(buildServerCapabilities(properties)) - .instructions(properties.getInstructions()) - .requestTimeout(properties.getRequestTimeout()) - .commandExecutor(commandExecutor) - .objectMapper(properties.getObjectMapper() != null ? properties.getObjectMapper() : JsonParser.getObjectMapper()); - - streamableServerNettySpecification.tools( - McpToolUtils.toStreamableToolSpecifications(providerToolCallbacks)); - - streamableServer = streamableServerNettySpecification.build(); + startStreamableServer(properties, toolClassification); } else { - NettyStatelessServerTransport statelessTransport = createStatelessHttpTransport(properties); - statelessHandler = statelessTransport.getMcpRequestHandler(); - unifiedMcpHandler.setStatelessHandler(statelessHandler); - - McpServer.StatelessServerNettySpecification statelessServerNettySpecification = McpServer.netty(statelessTransport) - .serverInfo(new Implementation(properties.getName(), properties.getVersion())) - .capabilities(buildServerCapabilities(properties)) - .instructions(properties.getInstructions()) - .requestTimeout(properties.getRequestTimeout()) - .commandExecutor(commandExecutor) - .objectMapper(properties.getObjectMapper() != null ? properties.getObjectMapper() : JsonParser.getObjectMapper()); - - statelessServerNettySpecification.tools( - McpToolUtils.toStatelessToolSpecifications(providerToolCallbacks)); - - statelessServer = statelessServerNettySpecification.build(); + startStatelessServer(properties, toolClassification); } logger.info("Arthas MCP server started successfully"); logger.info("- MCP Endpoint: {}", properties.getMcpEndpoint()); logger.info("- Transport mode: {}", properties.getProtocol()); - logger.info("- Available tools: {}", providerToolCallbacks.size()); - logger.info("- Server ready to accept connections"); } catch (Exception e) { logger.error("Failed to start Arthas MCP server", e); throw new RuntimeException("Failed to start Arthas MCP server", e); } } + + /** + * 扫描并分类工具 + */ + private ToolClassification scanAndClassifyTools() { + DefaultToolCallbackProvider toolCallbackProvider = new DefaultToolCallbackProvider(); + toolCallbackProvider.setToolBasePackage(ARTHAS_TOOL_BASE_PACKAGE); + + ToolCallback[] allCallbacks = toolCallbackProvider.getToolCallbacks(); + + // 根据 taskSupport 属性分类工具 + List requiredTaskTools = new ArrayList<>(); // taskSupport=required + List optionalTaskTools = new ArrayList<>(); // taskSupport=optional + List normalTools = new ArrayList<>(); // taskSupport=forbidden + + for (ToolCallback callback : allCallbacks) { + if (callback == null) { + continue; + } + + ToolDefinition def = callback.getToolDefinition(); + McpSchema.TaskSupportMode taskSupport = def.taskSupport(); + + // 根据 taskSupport 分类 + switch (taskSupport) { + case REQUIRED: + requiredTaskTools.add(callback); + break; + case OPTIONAL: + optionalTaskTools.add(callback); + break; + case FORBIDDEN: + default: + normalTools.add(callback); + break; + } + } + + logger.info("Scanned {} tools: {} normal, {} optional-task, {} required-task", + allCallbacks.length, normalTools.size(), optionalTaskTools.size(), requiredTaskTools.size()); + + return new ToolClassification(Arrays.asList(allCallbacks), normalTools, optionalTaskTools, requiredTaskTools); + } + + /** + * 启动 Streamable 模式服务器 + */ + private void startStreamableServer(McpServerProperties properties, ToolClassification classification) { + // 初始化 SessionManager + this.sessionManager = new ArthasCommandSessionManager(commandExecutor); + logger.info("Initialized ArthasCommandSessionManager for MCP server"); + + McpStreamableServerTransportProvider transportProvider = createStreamableHttpTransportProvider(properties); + streamableHandler = transportProvider.getMcpRequestHandler(); + unifiedMcpHandler.setStreamableHandler(streamableHandler); + + // 准备任务感知工具列表(taskSupport = OPTIONAL 或 REQUIRED) + List taskAwareTools = new ArrayList<>(); + taskAwareTools.addAll(classification.optionalTaskTools); + taskAwareTools.addAll(classification.requiredTaskTools); + + boolean hasTaskTools = !taskAwareTools.isEmpty(); + + McpServer.StreamableServerNettySpecification serverSpec = McpServer.netty(transportProvider) + .serverInfo(new Implementation(properties.getName(), properties.getVersion())) + .capabilities(buildServerCapabilities(properties, hasTaskTools)) + .instructions(properties.getInstructions()) + .requestTimeout(properties.getRequestTimeout()) + .commandExecutor(commandExecutor) + .sessionManager(this.sessionManager) + .objectMapper(properties.getObjectMapper() != null ? properties.getObjectMapper() : JsonParser.getObjectMapper()); + + // 只注册普通工具(taskSupport = FORBIDDEN) + serverSpec.tools(McpToolUtils.toStreamableToolSpecifications(classification.normalTools)); + logger.debug("Registered {} normal tools", classification.normalTools.size()); + + if (hasTaskTools) { + configureTaskSupport(serverSpec, taskAwareTools); + } + + streamableServer = serverSpec.build(); + } + + /** + * 配置任务支持 + */ + private void configureTaskSupport(McpServer.StreamableServerNettySpecification serverSpec, + List taskAwareTools) { + logger.info("Configuring tasks support for {} task-aware tools", taskAwareTools.size()); + + // 创建 TaskStore 和 TaskMessageQueue + TaskStore taskStore = InMemoryTaskStore.builder() + .defaultTtl(Duration.ofMinutes(30)) // 任务 TTL 30 分钟 + .build(); + + TaskMessageQueue messageQueue = new InMemoryTaskMessageQueue(); + + // 配置 TaskStore 和 TaskMessageQueue + serverSpec.taskStore(taskStore).taskMessageQueue(messageQueue); + + // 为每个任务感知工具创建 TaskAwareToolSpecification + for (ToolCallback callback : taskAwareTools) { + ToolDefinition def = callback.getToolDefinition(); + + ToolCallbackCreateTaskHandler createTaskHandler = new ToolCallbackCreateTaskHandler(callback); + + TaskAwareToolSpecification spec = TaskAwareToolSpecification.builder() + .name(def.getName()) + .description(def.getDescription()) + .inputSchema(def.getInputSchema()) + .taskSupport(def.taskSupport()) + .createTaskHandler(createTaskHandler) + .build(); + + serverSpec.taskTool(spec); + logger.debug("Registered task-aware tool: {} (taskSupport: {})", def.getName(), def.taskSupport()); + } + + logger.info("Registered {} task-aware tools successfully", taskAwareTools.size()); + } + + /** + * 启动 Stateless 模式服务器 + */ + private void startStatelessServer(McpServerProperties properties, ToolClassification classification) { + // 创建传输层 + NettyStatelessServerTransport statelessTransport = createStatelessHttpTransport(properties); + statelessHandler = statelessTransport.getMcpRequestHandler(); + unifiedMcpHandler.setStatelessHandler(statelessHandler); + + // Stateless 模式不支持任务 + boolean enableTasks = false; + + // 构建服务器规格 + McpServer.StatelessServerNettySpecification serverSpec = McpServer.netty(statelessTransport) + .serverInfo(new Implementation(properties.getName(), properties.getVersion())) + .capabilities(buildServerCapabilities(properties, enableTasks)) + .instructions(properties.getInstructions()) + .requestTimeout(properties.getRequestTimeout()) + .commandExecutor(commandExecutor) + .objectMapper(properties.getObjectMapper() != null ? properties.getObjectMapper() : JsonParser.getObjectMapper()); + + // 在 stateless 模式下,所有工具都作为普通工具注册(不支持任务) + serverSpec.tools(McpToolUtils.toStatelessToolSpecifications(classification.allCallbacks)); + logger.info("Registered {} tools in stateless mode (tasks not supported)", classification.allCallbacks.size()); + + // 构建并启动服务器 + statelessServer = serverSpec.build(); + } + + /** + * 工具分类结果 + */ + private static class ToolClassification { + final List allCallbacks; + final List normalTools; + final List optionalTaskTools; + final List requiredTaskTools; + + ToolClassification(List allCallbacks, + List normalTools, + List optionalTaskTools, + List requiredTaskTools) { + this.allCallbacks = allCallbacks; + this.normalTools = normalTools; + this.optionalTaskTools = optionalTaskTools; + this.requiredTaskTools = requiredTaskTools; + } + } /** * Default keep-alive interval for MCP server (15 seconds) @@ -178,12 +318,35 @@ private NettyStatelessServerTransport createStatelessHttpTransport(McpServerProp .build(); } - private ServerCapabilities buildServerCapabilities(McpServerProperties properties) { - return ServerCapabilities.builder() + /** + * 构建服务器能力声明。 + * + * @param properties 服务器属性 + * @param enableTasks 是否启用任务支持(只有在有任务工具时才启用) + * @return ServerCapabilities + */ + private ServerCapabilities buildServerCapabilities(McpServerProperties properties, boolean enableTasks) { + ServerCapabilities.Builder builder = ServerCapabilities.builder() .prompts(new ServerCapabilities.PromptCapabilities(properties.isPromptChangeNotification())) .resources(new ServerCapabilities.ResourceCapabilities(properties.isResourceSubscribe(), properties.isResourceChangeNotification())) - .tools(new ServerCapabilities.ToolCapabilities(properties.isToolChangeNotification())) - .build(); + .tools(new ServerCapabilities.ToolCapabilities(properties.isToolChangeNotification())); + + // 只有在有任务工具时才声明 tasks capability + if (enableTasks) { + // 声明服务器支持的任务能力 + ServerCapabilities.TaskCapabilities taskCapabilities = ServerCapabilities.TaskCapabilities.builder() + .list() // 支持 tasks/list(列出所有任务) + .cancel() // 支持 tasks/cancel(取消任务) + .toolsCall() // 支持 tools/call 的任务增强执行(包括 tasks/get 和 tasks/result) + .build(); + + builder.tasks(taskCapabilities); + logger.info("Tasks capability enabled (supports list, cancel, tools/call with tasks)"); + } else { + logger.info("Tasks capability disabled (no task-aware tools)"); + } + + return builder.build(); } public void stop() { diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasTool.java b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasTool.java index 09ea3cbce9f..2464767a3ba 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasTool.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/AbstractArthasTool.java @@ -4,14 +4,15 @@ import com.taobao.arthas.core.mcp.util.McpAuthExtractor; import com.taobao.arthas.mcp.server.protocol.server.McpNettyServerExchange; import com.taobao.arthas.mcp.server.protocol.server.McpTransportContext; +import com.taobao.arthas.mcp.server.task.CreateTaskContext; import com.taobao.arthas.mcp.server.tool.ToolContext; +import com.taobao.arthas.mcp.server.tool.ToolContextKeys; import com.taobao.arthas.mcp.server.util.JsonParser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Map; -import static com.taobao.arthas.core.mcp.tool.util.McpToolUtils.*; import static com.taobao.arthas.core.mcp.tool.function.StreamableToolUtils.*; /** @@ -39,18 +40,18 @@ protected static class ToolExecutionContext { private final boolean isStreamable; public ToolExecutionContext(ToolContext toolContext, boolean isStreamable) { - this.commandContext = (ArthasCommandContext) toolContext.getContext().get(TOOL_CONTEXT_COMMAND_CONTEXT_KEY); + this.commandContext = (ArthasCommandContext) toolContext.getContext().get(ToolContextKeys.COMMAND_CONTEXT); this.isStreamable = isStreamable; // 尝试获取 Exchange (在 Stateless 模式下为 null) - this.exchange = (McpNettyServerExchange) toolContext.getContext().get(TOOL_CONTEXT_MCP_EXCHANGE_KEY); + this.exchange = (McpNettyServerExchange) toolContext.getContext().get(ToolContextKeys.EXCHANGE); // 尝试获取 Progress Token - Object progressTokenObj = toolContext.getContext().get(PROGRESS_TOKEN); + Object progressTokenObj = toolContext.getContext().get(ToolContextKeys.PROGRESS_TOKEN); this.progressToken = progressTokenObj != null ? String.valueOf(progressTokenObj) : null; // 尝试获取 Transport Context (在 Stateless 模式下可能为 null) - this.mcpTransportContext = (McpTransportContext) toolContext.getContext().get(MCP_TRANSPORT_CONTEXT); + this.mcpTransportContext = (McpTransportContext) toolContext.getContext().get(ToolContextKeys.MCP_TRANSPORT_CONTEXT); // 从 Transport Context 中提取认证信息 if (this.mcpTransportContext != null) { @@ -111,14 +112,14 @@ protected String executeSync(ToolContext toolContext, String commandStr) { } } - protected String executeStreamable(ToolContext toolContext, String commandStr, - Integer expectedResultCount, Integer pollIntervalMs, + protected String executeStreamable(ToolContext toolContext, String commandStr, + Integer expectedResultCount, Integer pollIntervalMs, Integer timeoutMs, String successMessage) { ToolExecutionContext execContext = null; try { execContext = new ToolExecutionContext(toolContext, true); - + logger.info("Starting streamable execution: {}", commandStr); // Set userId to session before async execution for stat reporting @@ -133,16 +134,25 @@ protected String executeStreamable(ToolContext toolContext, String commandStr, } logger.debug("Async execution started: {}", asyncResult); + // 构建取消检查器(如果在 task 模式下运行) + StreamableToolUtils.CancellationChecker cancellationChecker = buildCancellationChecker(toolContext); + Map results = executeAndCollectResults( - execContext.getExchange(), - execContext.getCommandContext(), - expectedResultCount, - pollIntervalMs, + execContext.getExchange(), + execContext.getCommandContext(), + expectedResultCount, + pollIntervalMs, timeoutMs, - execContext.getProgressToken() + execContext.getProgressToken(), + cancellationChecker ); - + if (results != null) { + // 检查是否被取消 + if (Boolean.TRUE.equals(results.get("cancelled"))) { + return JsonParser.toJson(results); + } + String message = successMessage != null ? successMessage : "Command execution completed successfully"; if (Boolean.TRUE.equals(results.get("timedOut"))) { @@ -153,12 +163,12 @@ protected String executeStreamable(ToolContext toolContext, String commandStr, message = "Command execution ended (Timed out). No results captured within the time limit."; } } - + return JsonParser.toJson(createCompletedResponse(message, results)); } else { return JsonParser.toJson(createErrorResponse("Command execution failed due to timeout or error limits exceeded")); } - + } catch (Exception e) { logger.error("Error executing streamable command: {}", commandStr, e); return JsonParser.toJson(createErrorResponse("Error executing command: " + e.getMessage())); @@ -173,6 +183,37 @@ protected String executeStreamable(ToolContext toolContext, String commandStr, } } + /** + * 从 ToolContext 中构建取消检查器。 + * + *

当工具在 task 模式下运行时,ToolContext 中包含 CREATE_TASK_EXTRA 和 TASK_ID, + * 可以用来定期检查任务是否已被取消。非 task 模式下返回 null。 + */ + private static StreamableToolUtils.CancellationChecker buildCancellationChecker(ToolContext toolContext) { + if (toolContext == null) { + return null; + } + Object extraObj = toolContext.getContext().get(ToolContextKeys.CREATE_TASK_CONTEXT); + Object taskIdObj = toolContext.getContext().get(ToolContextKeys.TASK_ID); + if (!(extraObj instanceof CreateTaskContext) || !(taskIdObj instanceof String)) { + return null; + } + final CreateTaskContext taskContext = (CreateTaskContext) extraObj; + final String taskId = (String) taskIdObj; + return new StreamableToolUtils.CancellationChecker() { + @Override + public boolean isCancelled() { + try { + Boolean result = taskContext.isCancellationRequested(taskId).join(); + return Boolean.TRUE.equals(result); + } catch (Exception e) { + // 检查失败不应阻止工具继续执行 + return false; + } + } + }; + } + private static boolean isAsyncExecutionStarted(Map asyncResult) { if (asyncResult == null) { return false; diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/StreamableToolUtils.java b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/StreamableToolUtils.java index b14cff52b13..d0f1adba045 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/StreamableToolUtils.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/StreamableToolUtils.java @@ -33,12 +33,23 @@ public final class StreamableToolUtils { public static final int MIN_ALLOW_INPUT_COUNT_TO_COMPLETE = 2; + /** + * 取消检查器函数式接口。 + * + *

用于在轮询循环中定期检查任务是否已被取消。 + */ + @FunctionalInterface + public interface CancellationChecker { + + boolean isCancelled(); + } + private StreamableToolUtils() { } /** * 同步执行命令并收集所有结果,支持进度通知 - * + * * @param exchange MCP交换器,用于发送进度通知 * @param commandContext 命令上下文 * @param expectedResultCount 预期结果数量 @@ -47,11 +58,32 @@ private StreamableToolUtils() { * @param progressToken 进度令牌 * @return 包含所有结果的Map,如果执行失败返回null */ - public static Map executeAndCollectResults(McpNettyServerExchange exchange, - ArthasCommandContext commandContext, - Integer expectedResultCount, Integer intervalMs, + public static Map executeAndCollectResults(McpNettyServerExchange exchange, + ArthasCommandContext commandContext, + Integer expectedResultCount, Integer intervalMs, Integer timeoutMs, String progressToken) { + return executeAndCollectResults(exchange, commandContext, expectedResultCount, intervalMs, timeoutMs, progressToken, null); + } + + /** + * 同步执行命令并收集所有结果,支持进度通知和取消检查 + * + * @param exchange MCP交换器,用于发送进度通知 + * @param commandContext 命令上下文 + * @param expectedResultCount 预期结果数量 + * @param intervalMs 轮询间隔 + * @param timeoutMs 超时时间(毫秒) + * @param progressToken 进度令牌 + * @param cancellationChecker 取消检查器,为 null 时不检查取消 + * @return 包含所有结果的Map,取消时返回带 cancelled 标记的结果,执行失败返回null + */ + public static Map executeAndCollectResults(McpNettyServerExchange exchange, + ArthasCommandContext commandContext, + Integer expectedResultCount, Integer intervalMs, + Integer timeoutMs, + String progressToken, + CancellationChecker cancellationChecker) { List allResults = new ArrayList<>(); int errorRetries = 0; int allowInputCount = 0; @@ -69,6 +101,12 @@ public static Map executeAndCollectResults(McpNettyServerExchang try { while (System.currentTimeMillis() < deadline) { + // 检查任务是否已被取消 + if (cancellationChecker != null && cancellationChecker.isCancelled()) { + logger.info("Task cancellation detected, stopping command execution"); + return createCancelledResult(allResults, totalResultCount); + } + try { Map results = commandContext.pullResults(); if (results == null) { @@ -304,6 +342,17 @@ public static Map createErrorResponseWithResults(String message, return response; } + private static Map createCancelledResult(List allResults, int totalResultCount) { + Map result = new HashMap<>(); + result.put("results", allResults); + result.put("resultCount", totalResultCount); + result.put("status", "cancelled"); + result.put("stage", "final"); + result.put("cancelled", true); + result.put("message", "Task was cancelled by user"); + return result; + } + private static Map createFinalResult(List allResults, int totalResultCount, boolean timedOut, long timeoutMs) { Map finalResult = new HashMap<>(); finalResult.put("results", allResults); diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/MonitorTool.java b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/MonitorTool.java index 28479ef75b8..5de3b114239 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/MonitorTool.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/MonitorTool.java @@ -1,6 +1,7 @@ package com.taobao.arthas.core.mcp.tool.function.monitor200; import com.taobao.arthas.core.mcp.tool.function.AbstractArthasTool; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; import com.taobao.arthas.mcp.server.tool.ToolContext; import com.taobao.arthas.mcp.server.tool.annotation.Tool; import com.taobao.arthas.mcp.server.tool.annotation.ToolParam; @@ -27,7 +28,8 @@ public class MonitorTool extends AbstractArthasTool { @Tool( name = "monitor", description = "Monitor 方法调用监控工具: 实时监控指定类的指定方法的调用情况,包括调用次数、成功次数、失败次数、平均RT、失败率等统计信息。对应 Arthas 的 monitor 命令。", - streamable = true + streamable = true, + taskSupport = McpSchema.TaskSupportMode.OPTIONAL ) public String monitor( @ToolParam(description = "类名表达式匹配,支持通配符,如demo.MathGame") diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/StackTool.java b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/StackTool.java index f5ae1984a74..3c62450ca85 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/StackTool.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/StackTool.java @@ -1,6 +1,7 @@ package com.taobao.arthas.core.mcp.tool.function.monitor200; import com.taobao.arthas.core.mcp.tool.function.AbstractArthasTool; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; import com.taobao.arthas.mcp.server.tool.ToolContext; import com.taobao.arthas.mcp.server.tool.annotation.Tool; import com.taobao.arthas.mcp.server.tool.annotation.ToolParam; @@ -24,7 +25,8 @@ public class StackTool extends AbstractArthasTool { @Tool( name = "stack", description = "Stack 调用堆栈跟踪工具: 输出当前方法被调用的调用路径,帮助分析方法的调用链路。对应 Arthas 的 stack 命令。", - streamable = true + streamable = true, + taskSupport = McpSchema.TaskSupportMode.OPTIONAL ) public String stack( @ToolParam(description = "类名表达式匹配,支持通配符,如demo.MathGame") diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/TimeTunnelTool.java b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/TimeTunnelTool.java index 1f3e040c7f6..b225d60488c 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/TimeTunnelTool.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/TimeTunnelTool.java @@ -1,6 +1,7 @@ package com.taobao.arthas.core.mcp.tool.function.monitor200; import com.taobao.arthas.core.mcp.tool.function.AbstractArthasTool; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; import com.taobao.arthas.mcp.server.tool.ToolContext; import com.taobao.arthas.mcp.server.tool.annotation.Tool; import com.taobao.arthas.mcp.server.tool.annotation.ToolParam; @@ -18,7 +19,8 @@ public class TimeTunnelTool extends AbstractArthasTool { @Tool( name = "tt", description = "TimeTunnel 时空隧道工具: 方法执行数据的时空隧道,记录下指定方法每次调用的入参和返回信息,对应 Arthas 的 tt 命令。支持记录、列表、搜索、查看详情、重放、删除等操作。", - streamable = true + streamable = true, + taskSupport = McpSchema.TaskSupportMode.OPTIONAL ) public String timeTunnel( @ToolParam(description = "操作类型: record/t(记录), list/l(列表), search/s(搜索), info/i(查看详情), replay/p(重放), delete/d(删除), deleteAll/da(删除所有),默认record") diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/TraceTool.java b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/TraceTool.java index ecf864b33aa..2f2a1891e67 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/TraceTool.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/TraceTool.java @@ -1,6 +1,7 @@ package com.taobao.arthas.core.mcp.tool.function.monitor200; import com.taobao.arthas.core.mcp.tool.function.AbstractArthasTool; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; import com.taobao.arthas.mcp.server.tool.ToolContext; import com.taobao.arthas.mcp.server.tool.annotation.Tool; import com.taobao.arthas.mcp.server.tool.annotation.ToolParam; @@ -18,7 +19,8 @@ public class TraceTool extends AbstractArthasTool { @Tool( name = "trace", description = "Trace 方法内部调用路径跟踪工具: 追踪方法内部调用路径,输出每个节点的耗时信息,对应 Arthas 的 trace 命令。", - streamable = true + streamable = true, + taskSupport = McpSchema.TaskSupportMode.OPTIONAL ) public String trace( @ToolParam(description = "类名表达式匹配,支持通配符,如demo.MathGame") diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/WatchTool.java b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/WatchTool.java index 9a000314d70..955fa27d2a6 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/WatchTool.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/tool/function/monitor200/WatchTool.java @@ -1,6 +1,7 @@ package com.taobao.arthas.core.mcp.tool.function.monitor200; import com.taobao.arthas.core.mcp.tool.function.AbstractArthasTool; +import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; import com.taobao.arthas.mcp.server.tool.ToolContext; import com.taobao.arthas.mcp.server.tool.annotation.Tool; import com.taobao.arthas.mcp.server.tool.annotation.ToolParam; @@ -21,7 +22,8 @@ public class WatchTool extends AbstractArthasTool { @Tool( name = "watch", description = "Watch 方法执行观察工具: 观察指定方法的调用情况,包括入参、返回值和抛出异常等信息,支持实时流式输出。对应 Arthas 的 watch 命令。", - streamable = true + streamable = true, + taskSupport = McpSchema.TaskSupportMode.OPTIONAL ) public String watch( @ToolParam(description = "类名表达式匹配,支持通配符,如demo.MathGame") diff --git a/core/src/main/java/com/taobao/arthas/core/mcp/tool/util/McpToolUtils.java b/core/src/main/java/com/taobao/arthas/core/mcp/tool/util/McpToolUtils.java index eb2f5e409be..aad64393f5b 100644 --- a/core/src/main/java/com/taobao/arthas/core/mcp/tool/util/McpToolUtils.java +++ b/core/src/main/java/com/taobao/arthas/core/mcp/tool/util/McpToolUtils.java @@ -1,12 +1,12 @@ package com.taobao.arthas.core.mcp.tool.util; import com.fasterxml.jackson.databind.ObjectMapper; -import com.taobao.arthas.mcp.server.session.ArthasCommandContext; import com.taobao.arthas.mcp.server.protocol.server.McpServerFeatures; import com.taobao.arthas.mcp.server.protocol.server.McpStatelessServerFeatures; import com.taobao.arthas.mcp.server.protocol.spec.McpSchema; import com.taobao.arthas.mcp.server.tool.ToolCallback; import com.taobao.arthas.mcp.server.tool.ToolContext; +import com.taobao.arthas.mcp.server.tool.ToolContextKeys; import java.util.*; import java.util.concurrent.CompletableFuture; @@ -14,14 +14,6 @@ public final class McpToolUtils { - public static final String TOOL_CONTEXT_MCP_EXCHANGE_KEY = "exchange"; - - public static final String TOOL_CONTEXT_COMMAND_CONTEXT_KEY = "commandContext"; - - public static final String MCP_TRANSPORT_CONTEXT = "mcpTransportContext"; - - public static final String PROGRESS_TOKEN = "progressToken"; - private McpToolUtils() { } @@ -50,18 +42,19 @@ public static McpServerFeatures.ToolSpecification toToolSpecification(ToolCallba McpSchema.Tool tool = new McpSchema.Tool( toolCallback.getToolDefinition().getName(), toolCallback.getToolDefinition().getDescription(), - toolCallback.getToolDefinition().getInputSchema() + toolCallback.getToolDefinition().getInputSchema(), + new McpSchema.ToolExecution(toolCallback.getToolDefinition().taskSupport()) ); McpServerFeatures.ToolCallFunction callFunction = (exchange, commandContext, request) -> { try { Map contextMap = new HashMap<>(); - contextMap.put(TOOL_CONTEXT_MCP_EXCHANGE_KEY, exchange); - contextMap.put(TOOL_CONTEXT_COMMAND_CONTEXT_KEY, commandContext); - contextMap.put(PROGRESS_TOKEN, request.progressToken()); + contextMap.put(ToolContextKeys.EXCHANGE, exchange); + contextMap.put(ToolContextKeys.COMMAND_CONTEXT, commandContext); + contextMap.put(ToolContextKeys.PROGRESS_TOKEN, request.progressToken()); // Add MCP_TRANSPORT_CONTEXT from exchange for streamable tools to access auth info if (exchange != null && exchange.getTransportContext() != null) { - contextMap.put(MCP_TRANSPORT_CONTEXT, exchange.getTransportContext()); + contextMap.put(ToolContextKeys.MCP_TRANSPORT_CONTEXT, exchange.getTransportContext()); } ToolContext toolContext = new ToolContext(contextMap); @@ -100,14 +93,15 @@ public static McpStatelessServerFeatures.ToolSpecification toStatelessToolSpecif McpSchema.Tool tool = new McpSchema.Tool( toolCallback.getToolDefinition().getName(), toolCallback.getToolDefinition().getDescription(), - toolCallback.getToolDefinition().getInputSchema() + toolCallback.getToolDefinition().getInputSchema(), + new McpSchema.ToolExecution(toolCallback.getToolDefinition().taskSupport()) ); McpStatelessServerFeatures.ToolCallFunction callFunction = (context, commandContext, arguments) -> { try { Map contextMap = new HashMap<>(); - contextMap.put(MCP_TRANSPORT_CONTEXT, context); - contextMap.put(TOOL_CONTEXT_COMMAND_CONTEXT_KEY, commandContext); + contextMap.put(ToolContextKeys.MCP_TRANSPORT_CONTEXT, context); + contextMap.put(ToolContextKeys.COMMAND_CONTEXT, commandContext); ToolContext toolContext = new ToolContext(contextMap); String argumentsJson = convertParametersToString(arguments);