diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 1c18a55..75c1f12 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -42,7 +42,7 @@ src/DotNetCampus.ModelContextProtocol/ - 代码主要以最新版本协议进行编写 - 遇到需要兼容旧协议的部分,用 `Legacy` 命名相关代码并尽量减少代码量 - **协议消息类型规范**:详见 [/docs/knowledge/protocol-messages-guide.md](../docs/knowledge/protocol-messages-guide.md) - - 所有 Protocol 消息类型必须添加中英双语注释 + - **仅** `Protocol/` 文件夹下的消息类型必须添加中英双语注释;其他所有代码(接口、实现类、传输层等)一律使用**纯中文注释**(注:当前存在一些遗留非协议代码仍使用双语注释,如果改到了相关代码,请顺手改为纯中文注释) - 英文注释必须使用 MCP 官方 Schema 原文 - 当前使用协议版本:**2025-11-25** - Schema 文件:[schema.ts](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-11-25/schema.ts) @@ -75,8 +75,9 @@ src/DotNetCampus.ModelContextProtocol/ ## 参考资源 -- [MCP 官方规范 (2025-06-18)](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports) - **当前使用版本** -- [MCP Schema (2025-06-18)](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-06-18/schema.ts) - **官方消息类型定义** +- [MCP 官方规范 (2025-11-25)](https://modelcontextprotocol.io/specification/2025-11-25/basic/transports) - **当前使用版本** +- [MCP Schema (2025-11-25)](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2025-11-25/schema.ts) - **官方消息类型定义** +- [MCP 官方规范 (2025-06-18)](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports) - 旧版本(兼容支持) - [MCP 官方规范 (2024-11-05)](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports) - 旧版本(兼容支持) - [JSON-RPC 2.0 规范](https://www.jsonrpc.org/specification) - [SSE 标准](https://html.spec.whatwg.org/multipage/server-sent-events.html) diff --git a/README.md b/README.md index 3d0e753..88caa37 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,6 @@ internal class Program .WithTool(() => new SampleTools2()) ) // Use Streamable HTTP transport, listening on http://localhost:5943/mcp - // Also compatible with SSE, listening on http://localhost:5943/mcp/sse .WithLocalHostHttp(5943, "mcp") // You can also use stdio (standard input/output) transport, which is recommended by the MCP protocol for all MCP servers // However, it's generally not recommended to enable both http and stdio simultaneously, diff --git a/docs/en/QuickStart.md b/docs/en/QuickStart.md index cdc77fd..66c486b 100644 --- a/docs/en/QuickStart.md +++ b/docs/en/QuickStart.md @@ -19,7 +19,6 @@ internal class Program .WithTool(() => new SampleTools2()) ) // Use Streamable HTTP transport, listening on http://localhost:5943/mcp - // Also compatible with SSE, listening on http://localhost:5943/mcp/sse .WithLocalHostHttp(5943, "mcp") // You can also use stdio (standard input/output) transport, which is recommended by the MCP protocol for all MCP servers // However, it's generally not recommended to enable both http and stdio simultaneously, diff --git a/docs/knowledge/http-server-transport-implementation-guide.md b/docs/knowledge/http-server-transport-implementation-guide.md index 1616fc5..22a9bed 100644 --- a/docs/knowledge/http-server-transport-implementation-guide.md +++ b/docs/knowledge/http-server-transport-implementation-guide.md @@ -59,17 +59,20 @@ * 反序列化 Body 为 `JsonRpcMessage`。 * 将消息通过 `OnMessageReceived` 传递给上层 MCP Server 处理。 6. **响应写入**: - * **情况 1:上层有直接同步返回 (Response)**: + * **`initialize` 请求**: * 设置 `Content-Type: application/json`。 * 写入响应 JSON。 * 返回 `200 OK`。 - * **情况 2:上层无直接返回 (Notification) 或 异步处理**: - * 返回 `202 Accepted`。 - * 无 Body。 - * *高级情况:SSE 升级*(如果 POST 请求 accept SSE 且 Server 决定用 SSE 回复): - * 设置 `Content-Type: text/event-stream`。 - * 保持连接并在稍后推送 SSE Event。 - * *建议:简单起见,POST 尽量使用 application/json 回复,推送信道留给 GET SSE。* + * **`JsonRpcResponse`(客户端回弹采样结果)**: + * 返回 `202 Accepted`,无 Body。 + * **`JsonRpcNotification`(客户端发送的通知)**: + * 返回 `202 Accepted`,无 Body。 + * **所有其他 `JsonRpcRequest`(工具调用等)**: + * 设置 `Content-Type: text/event-stream`,建立本次请求的专属 SSE 流。 + * 先发送一个空注释事件(prime event)保活。 + * 将此 SSE 流绑定到当前 Session(供采样等服务端主动请求使用)。 + * 调用 `HandleRequestAsync` 处理请求(期间采样请求将写入此 SSE 流)。 + * 将最终响应写入 SSE 流,关闭流。 ### C. 处理 GET 请求 (SSE Subscription) @@ -84,17 +87,11 @@ * 设置响应 Header `Content-Type: text/event-stream`。 * 设置 `Cache-Control: no-cache`。 * 返回 `200 OK`(此时不要关闭 Response 流)。 -4. **注册发送通道**: - * 将当前 HTTP Response 流包装为一个 `IAsyncWriter` 或类似接口。 - * 注册到 Session 对象中,作为服务端向客户端推送消息的通道(Server-to-Client Messenger)。 - * **多连接共存策略**:MCP 协议规范 §2.3.1 明确指出 *“The client MAY remain connected to multiple SSE streams simultaneously.”*。因此,服务端**应支持**每个 Session 维护一个活跃连接列表,并将消息广播到所有连接(或仅主连接)。 - * *实现简化建议*:遵循协议精神,服务端应允许新连接加入而不强制断开旧连接。 -5. **发送 Prime Event**: - * 立即发送一个空事件 `event: message\ndata: \n\n` 或仅 `:\n\n` (Comment) 以保活。 - * 根据 SSE 规范,发送 `id` 字段以支持重连。 -6. **保持循环**: - * 进入 `await Task.Delay(-1)` 或等待 Session 关闭信号。 - * 在循环中捕获异常,如果连接断开,从 Session 中注销此通道。 +4. **发送 Prime Event**: + * 立即发送一个空注释 `:\n\n` 以保活连接。 +5. **保持循环**: + * 进入 `await Task.Delay(-1)` 等待,保持 SSE 连接存活(此通路用于未来扩展服务端主动推送,当前暂不发送任何业务消息)。 + * 在循环中捕获异常,如果连接断开则正常退出。 ### D. 处理 DELETE 请求 (Session Termination) @@ -125,22 +122,25 @@ ## 4. 关键数据结构:Session Store -需要一个线程安全的 `ConcurrentDictionary`。 +需要一个线程安全的 `ConcurrentDictionary`。 -**`HttpServerSession` 类职责**: +**`HttpServerTransportSession` 类职责**: * 存储 Session ID。 -* 管理 SSE 发送通道(也就是当前挂着的那个 GET Response 流)。 -* 提供 `SendMessageAsync(JsonRpcMessage)` 方法:将消息序列化为 SSE 格式 (`event: message\ndata: {...}\n\n`) 并写入流。 +* 和待决服务端请求的 TCS 字典(继承自 `ServerTransportSession` 基类)。 +* 管理当前 POST 请求的专属 SSE 输出流(`_currentRequestSseStream`),这是采样等服务端主动请求的通道。 +* 提供 `WriteSseMessageAsync(Stream, JsonRpcMessage)` 方法:将消息序列化为 SSE 格式 (`event: message\ndata: {...}\n\n`) 并写入流。 ## 5. 错误处理 * **JSON 序列化错误**:返回 400。 * **内部异常**:返回 500,并在 Body 中包含(或不包含)JSON-RPC Error。 -## 6. 待办事项 (Checklist) +## 6. 实现状态 (Checklist) -* [ ] 移除旧版兼容代码 (`/mcp/sse`, `/mcp/messages` 路径处理)。 -* [ ] 确保 POST/GET/DELETE 共用同一个 Endpoint URL。 -* [ ] 实现 Session ID 的生成(初始化时)和校验(后续请求)。 -* [ ] 实现 SSE 的心跳或 Keep-Alive(如果底层不自动处理)。 +* [x] POST/GET/DELETE 共用同一个 Endpoint URL `/mcp`。 +* [x] Session ID 的生成(initialize 时)和校验(后续请求)。 +* [x] 非 initialize 的 POST 请求返回 `text/event-stream`,套接 sampling 等服务端主动请求通道。 +* [x] 初始化请求返回 `application/json`。 +* [x] SSE prime event 保活连接。 +* [ ] 旧版协议兼容 (`/mcp/sse`, `/mcp/messages`)(目前未实现)。 diff --git a/docs/knowledge/http-transport-guide.md b/docs/knowledge/http-transport-guide.md index 27766b0..47665bc 100644 --- a/docs/knowledge/http-transport-guide.md +++ b/docs/knowledge/http-transport-guide.md @@ -9,9 +9,9 @@ | **最新** | 2025-11-25 | Streamable HTTP | `/mcp` | `Mcp-Session-Id` header | ✅ 已支持 | | | 2025-06-18 | Streamable HTTP | `/mcp` | `Mcp-Session-Id` header | ✅ 已支持 | | **变更** | 2025-03-26 | Streamable HTTP | `/mcp` | `Mcp-Session-Id` header | ✅ 已支持 | -| **旧协议** | 2024-11-05 | HTTP+SSE | `/mcp/sse`, `/mcp/messages` | query string `sessionId` | ✅ 兼容 | +| **旧协议** | 2024-11-05 | HTTP+SSE | `/mcp/sse`, `/mcp/messages` | query string `sessionId` | ❌ 未实现 | -> **说明**: 2025-11-25、2025-06-18 和 2025-03-26 在传输层上完全兼容,我们的实现同时支持这些版本。 +> **说明**: 2025-11-25、2025-06-18 和 2025-03-26 在传输层上完全兼容,我们的实现同时支持这些版本。旧版 HTTP+SSE 协议(2024-11-05)目前未实现,如有需要请提 issue。 ## 🔑 关键区别 @@ -70,26 +70,30 @@ endpoint.Equals(EndPoint, StringComparison.OrdinalIgnoreCase) ## 📁 代码组织 -```csharp -#region 新协议实现 (Streamable HTTP - 2025-03-26+) -// HandleSseConnectionAsync() -// HandleJsonRpcRequestAsync() -// HandleDeleteSessionAsync() -#endregion - -#region 旧协议兼容 (HTTP+SSE - 2024-11-05) -// HandleLegacySseConnectionAsync() // 带 Legacy 前缀 -// HandleLegacyMessageRequestAsync() -#endregion +POST 处理逻辑被拆分为职责单一的方法(LocalHost 和 TouchSocket 两版结构完全对称): + ``` +HandlePostRequestAsync(入口) + ├── HandleClientResponseAsync // 客户端响应服务端采样请求(JsonRpcResponse) + ├── HandleNotificationAsync // 通知消息,返回 202 Accepted + └── HandleRpcRequestAsync // JSON-RPC 请求 + ├── GetOrCreateSessionAsync // Session 查找/创建 + ├── HandleInitializeAsync // initialize:返回 application/json + └── HandleSseRequestAsync // 其他请求:返回 text/event-stream SSE +``` + +> **POST 响应规则**: +> - `initialize` 请求 → `Content-Type: application/json`,直接返回 +> - 所有其他 JSON-RPC 请求 → `Content-Type: text/event-stream`, +> 采样等服务端发起的消息在此流上推送,最终响应也写入此流后关闭 ## ✅ 测试清单 -- [ ] 新协议:POST `/mcp` 返回 `Mcp-Session-Id` -- [ ] 新协议:GET `/mcp` 建立 Streamable HTTP 连接 -- [ ] 新协议:DELETE `/mcp` 成功终止会话 -- [ ] 旧协议:GET `/mcp/sse` 发送 endpoint 事件 -- [ ] 旧协议:POST `/mcp/messages?sessionId=xxx` 正常工作 +- [x] 新协议:POST `/mcp` initialize 返回 `Mcp-Session-Id`(`application/json`) +- [x] 新协议:POST `/mcp` 工具调用返回 `text/event-stream` SSE 流 +- [x] 新协议:GET `/mcp` 建立 SSE 保活连接 +- [x] 新协议:DELETE `/mcp` 成功终止会话 +- [x] 采样(Sampling):服务端通过 POST 响应 SSE 流发起采样请求,客户端 POST 回采样结果 - [ ] 路径大小写不敏感 - [ ] 会话不存在时 DELETE 返回 200 OK(幂等性) diff --git a/docs/knowledge/logging-style-guide.md b/docs/knowledge/logging-style-guide.md index ae27612..cbf5e5f 100644 --- a/docs/knowledge/logging-style-guide.md +++ b/docs/knowledge/logging-style-guide.md @@ -366,6 +366,78 @@ Log.Warn($"[McpClient][Http] SSE connection error, reconnecting. Error={ex.Messa Log.Error($"[McpServer][StreamableHttp] Failed to start listener: {ex}"); ``` +## 原始消息日志(Raw Message Log) + +原始消息日志专门用于记录 MCP 传输层收发的完整 JSON-RPC 消息内容,是与官方参考实现对比验证的核心手段。 + +### 使用方式 + +不要直接调用 `IMcpLogger`,而是通过传输层管理器上的扩展方法: + +```csharp +// 服务端接收(传输层 _manager 为 IServerTransportManager) +_manager.LogRawIn("[StreamableHttp]", channel, message); + +// 服务端发送 +_manager.LogRawOut("[StreamableHttp]", channel, message); + +// 客户端接收(传输层 _manager 为 IClientTransportManager) +_manager.LogRawIn("[Http]", channel, data); + +// 客户端发送 +_manager.LogRawOut("[Http]", channel, jsonContent); +``` + +- 由 `McpTransportRawMessageLoggingDetailLevel` 控制是否记录及记录多少(`None` / `Trimmed` / `Full`) +- 日志级别固定为 `Debug` +- 只有在 `Debug` 级别启用时才实际输出,方法内部自行判断,调用方无需包裹 `if (IsEnabled(Debug))` + +### 输出格式 + +``` +{role}{tag} {direction} [via {channel}] {rawMessage} +``` + +示例: + +``` +[McpServer][StreamableHttp] ← [via POST, SessionId=abc123] {"jsonrpc":"2.0","method":"initialize",...} +[McpServer][StreamableHttp] → [via POST/json, SessionId=abc123] {"jsonrpc":"2.0","id":"1","result":{...}} +[McpServer][StreamableHttp] → [via POST/sse, SessionId=abc123] {"jsonrpc":"2.0","method":"sampling/createMessage",...} +[McpClient][Http] → [via POST] {"jsonrpc":"2.0","method":"initialize",...} +[McpClient][Http] ← [via POST/json, SessionId=init] {"jsonrpc":"2.0","id":"1","result":{...}} +[McpClient][Http] ← [via GET/sse, SessionId=abc123] {"jsonrpc":"2.0","id":"2","result":{...}} +[McpServer][Stdio] → {"jsonrpc":"2.0","id":"1","result":{...}} +[McpClient][Stdio] → {"jsonrpc":"2.0","method":"initialize",...} +[McpServer][Ipc] ← {"jsonrpc":"2.0","method":"tools/call",...} +``` + +### 渠道标识(channel)规范 + +`channel` 参数描述消息经由的具体 HTTP/传输渠道,格式为 `"类型[, SessionId=xxx]"`。 + +| 场景 | channel 值 | +|------|------------| +| 服务端接收 POST 请求体(初始化前,无会话) | `"POST"` | +| 服务端接收 POST 请求体(有会话) | `$"POST, SessionId={sessionId}"` | +| 服务端发送 POST application/json 响应 | `$"POST/json, SessionId={session.SessionId}"` | +| 服务端发送 POST 内嵌瞬态 SSE | `$"POST/sse, SessionId={SessionId}"` | +| 客户端发送 POST 请求(初始化前) | `"POST"` | +| 客户端发送 POST 请求(有会话) | `$"POST, SessionId={_sessionId}"` | +| 客户端接收 POST application/json 响应 | `$"POST/json, SessionId={_sessionId ?? "init"}"` | +| 客户端接收 POST 内嵌瞬态 SSE | `$"POST/sse, SessionId={_sessionId ?? "init"}"` | +| 客户端接收 GET SSE 后台循环 | `$"GET/sse, SessionId={_sessionId}"` | +| STDIO / IPC | 无需 channel(渠道唯一,省略参数即可) | + +### 覆盖要求 + +每条 JSON-RPC 消息必须在以下两个时刻之一被记录,且只记录一次: + +- **接收侧**:解析完成后(有 `JsonRpcMessage` 对象)、分发给上层逻辑之前 +- **发送侧**:序列化写入流/通道之前 + +--- + ## 代码实现 ### 日志属性命名 diff --git a/docs/knowledge/test-cases.md b/docs/knowledge/test-cases.md index e6d5de9..78881fb 100644 --- a/docs/knowledge/test-cases.md +++ b/docs/knowledge/test-cases.md @@ -89,12 +89,24 @@ | 状态 | 方法名 | 场景描述 | 预期行为 | | :---: | :--- | :--- | :--- | | ✅ | `Delete_TerminateSession` | `LocalHost`, `TouchSocket` | DELETE 请求成功终止会话,IsConnected 为 false | -| ⏳ | `Post_NoSessionId` | 不带 `sessionId` query 发送消息 | 返回 400 Bad Request 或相应错误 | -| ⏳ | `Sse_EndpointEvent` | 建立旧协议 SSE 连接 | 首先收到 `event: endpoint` 消息 | +| ⏳ | `Post_NoSessionId` | 不带 `Mcp-Session-Id` header 发送消息 | 返回 400/404 错误 | +| ⏳ | `Sse_EndpointEvent` | GET SSE 连接 | SSE 流成功建立并保活 | --- -## 3. 官方兼容性测试 (Compliance) +## 3.5 采样功能测试 (Sampling) + +**文件路径**: `tests/DotNetCampus.ModelContextProtocol.Tests/Servers/SamplingTests.cs` +**目标**: 验证服务器向客户端发起 `sampling/createMessage` 请求的完整流程。 + +| 状态 | 方法名 | DataRow / 参数 | 预期行为 | +| :---: | :--- | :--- | :--- | +| ✅ | `ServerToolCanRequestSampling` | `LocalHost`, `TouchSocket` | 工具内调用 Sampling,客户端处理器被执行,返回结果正确 | +| ✅ | `IsSupportedIsFalseWhenClientHasNoCapability` | `LocalHost`, `TouchSocket` | 客户端未声明 Sampling 能力时 IsSupported 为 false | + +--- + +## 4. 官方兼容性测试 (Compliance) **文件路径**: `tests/DotNetCampus.ModelContextProtocol.Tests/Compliance/OfficialServerTests.cs` **目标**: 启动真正的 Node.js MCP Server 验证本库 Client。 @@ -109,9 +121,9 @@ --- -## 4. 已实现的辅助工具 +## 5. 已实现的辅助工具 -### 4.1 测试工具 (Test Tools) +### 5.1 测试工具 (Test Tools) **文件路径**: `tests/DotNetCampus.ModelContextProtocol.Tests/McpTools/` | 文件 | 类名 | 工具方法 | 用途 | @@ -121,15 +133,18 @@ | `ExceptionTool.cs` | `ExceptionTool` | `ThrowError(string? message)`, `ThrowNested()` | 异常处理测试 | | `LongTextTool.cs` | `LongTextTool` | `Generate(int length)` | 大数据量测试 | | `SimpleTool.cs` | `SimpleTool` | `SayHello()` | 最简单的工具 | +| `StatefulCounterTool.cs` | `StatefulCounterTool` | `Increment()`, `GetCount()` | 有状态工具实例语义测试 | +| `InjectedConstructorTool.cs` | `InjectedConstructorTool` | 注入构造函数工具 | 依赖注入测试 | +| `SamplingTool.cs` | `SamplingTool` | `AskLlm(string message, ...)`, `CheckSamplingCapability(...)` | 服务端发起采样请求测试 | -### 4.2 测试资源 (Test Resources) +### 5.2 测试资源 (Test Resources) **文件路径**: `tests/DotNetCampus.ModelContextProtocol.Tests/McpResources/` | 文件 | 类名 | 资源方法 | 用途 | | :--- | :--- | :--- | :--- | | `SimpleResource.cs` | `SimpleResource` | `TextFile()`, `BinaryImage()`, `UserProfile(int userId)` | 基本资源访问测试 | -### 4.3 测试工厂 (Integration Factory) +### 5.3 测试工厂 (Integration Factory) **文件路径**: `tests/DotNetCampus.ModelContextProtocol.Tests/TestMcpFactory.cs` | 方法 | 用途 | @@ -137,16 +152,18 @@ | `CreateSimpleHttpAsync(HttpTransportType)` | 创建仅包含 SimpleTool 的测试包 | | `CreateFullHttpAsync(HttpTransportType)` | 创建包含所有测试工具的测试包 | | `CreateFullHttpWithResourcesAsync(HttpTransportType)` | 创建包含工具和资源的完整测试包 | -| `CreateHttpCoreAsync(HttpTransportType, Action)` | 完全自定义的测试包创建 | +| `CreateTransientCounterHttpAsync(HttpTransportType)` | 创建仅包含 Transient 计数工具的测试包 | +| `CreateHttpAsync(HttpTransportType, Action<...>)` | 自定义工具的测试包 | +| `CreateHttpCoreAsync(HttpTransportType, Action, Action?)` | 完全自定义,支持同时配置服务端和客户端(如配置 Sampling Handler) | -### 4.4 JSON 序列化上下文 +### 5.4 JSON 序列化上下文 **文件路径**: `tests/DotNetCampus.ModelContextProtocol.Tests/McpTools/TestToolJsonContext.cs` 用于 AOT 兼容的复杂对象序列化,包含 `EchoUserInfo` 等类型的注册。 --- -## 5. 待开发的辅助工具 +## 6. 待开发的辅助工具 1. **Mock Transport** * `InProcessServerTransport` / `InProcessClientTransport` @@ -157,11 +174,12 @@ --- -## 6. 测试统计 +## 7. 测试统计 | 类别 | 通过 | 跳过 | 规划 | | :--- | :---: | :---: | :---: | | 核心功能测试 | 28 | 2 | 2 | -| 传输层测试 | 6 | 3 | 4 | +| 传输层测试 | 6 | 2 | 2 | +| 采样功能测试 | 4 | 0 | 0 | | 官方兼容性测试 | 0 | 3 | 0 | -| **总计** | **34** | **8** | **6** | +| **总计** | **38** | **7** | **4** | diff --git a/docs/zh-hans/QuickStart.md b/docs/zh-hans/QuickStart.md index d07a79a..9d00ba8 100644 --- a/docs/zh-hans/QuickStart.md +++ b/docs/zh-hans/QuickStart.md @@ -18,8 +18,7 @@ internal class Program .WithTool(() => new SampleTools()) .WithTool(() => new SampleTools2()) ) - // 传输层使用 Streamable HTTP,监听 http://localhost:5943/mcp, - // 传输层同时兼容 SSE,监听地址为 http://localhost:5943/mcp/sse + // 传输层使用 Streamable HTTP,监听 http://localhost:5943/mcp .WithLocalHostHttp(5943, "mcp") // 传输层也可使用 stdio(标准输入输出),这是 MCP 协议建议所有 MCP 服务器都支持的传输层 // 不过通常不建议同时启用 http 和 stdio,因为前者通常要求单例运行,后者则必须支持多实例运行 diff --git a/samples/DotNetCampus.SampleMcpServer/McpTools/SamplingTool.cs b/samples/DotNetCampus.SampleMcpServer/McpTools/SamplingTool.cs new file mode 100644 index 0000000..b43c306 --- /dev/null +++ b/samples/DotNetCampus.SampleMcpServer/McpTools/SamplingTool.cs @@ -0,0 +1,53 @@ +using DotNetCampus.ModelContextProtocol.CompilerServices; +using DotNetCampus.ModelContextProtocol.Exceptions; +using DotNetCampus.ModelContextProtocol.Protocol.Messages; +using DotNetCampus.ModelContextProtocol.Servers; + +namespace DotNetCampus.SampleMcpServer.McpTools; + +public class SamplingTool +{ + /// + /// 通过客户端的 LLM 进行采样,将 prompt 发送给客户端,获取 LLM 响应并返回。 + /// 用于人工验证 sampling/createMessage 协议流程是否正常。 + /// + /// MCP 工具上下文 + /// 发送给 LLM 的提示词 + /// 最大生成令牌数 + /// 可选的系统提示词 + [McpServerTool] + public async Task AskLlm( + IMcpServerCallToolContext context, + string prompt, + int maxTokens = 1024, + string? systemPrompt = null) + { + if (!context.Sampling.IsSupported) + { + throw new McpToolException("当前客户端未声明 Sampling 能力。请确保客户端支持 sampling/createMessage 请求。"); + } + + try + { + var result = await context.Sampling.CreateMessageAsync(prompt, maxTokens, systemPrompt, context.CancellationToken); + + var responseText = result.Content switch + { + TextContentBlock text => text.Text, + _ => $"[Non-text content: {result.Content?.GetType().Name}]", + }; + + return $""" + Model: {result.Model} + StopReason: {result.StopReason ?? "unknown"} + Role: {result.Role} + --- + {responseText} + """; + } + catch (McpSamplingRejectedException ex) + { + throw new McpToolException($"采样请求被用户拒绝。Sampling request was rejected by the user. Code: {ex.ErrorCode}, Message: {ex.RejectionMessage}"); + } + } +} diff --git a/samples/DotNetCampus.SampleMcpServer/McpTools/SampleTool.cs b/samples/DotNetCampus.SampleMcpServer/McpTools/SimpleTool.cs similarity index 99% rename from samples/DotNetCampus.SampleMcpServer/McpTools/SampleTool.cs rename to samples/DotNetCampus.SampleMcpServer/McpTools/SimpleTool.cs index 4394491..97ddaef 100644 --- a/samples/DotNetCampus.SampleMcpServer/McpTools/SampleTool.cs +++ b/samples/DotNetCampus.SampleMcpServer/McpTools/SimpleTool.cs @@ -3,7 +3,7 @@ namespace DotNetCampus.SampleMcpServer.McpTools; -public class SampleTool +public class SimpleTool { /// /// 用于给 AI 调试使用的工具,原样返回一些信息 diff --git a/samples/DotNetCampus.SampleMcpServer/Program.cs b/samples/DotNetCampus.SampleMcpServer/Program.cs index 23c4551..ceba127 100644 --- a/samples/DotNetCampus.SampleMcpServer/Program.cs +++ b/samples/DotNetCampus.SampleMcpServer/Program.cs @@ -32,11 +32,12 @@ private static async Task Main(string[] args) .WithRequestHandlers(s => new CustomRequestHandlers(s)) .WithJsonSerializer(McpToolJsonContext.Default) .WithTools(t => t - .WithTool(() => new SampleTool()) + .WithTool(() => new SimpleTool()) .WithTool(() => new InputTool()) .WithTool(() => new OutputTool()) .WithTool(() => new PolymorphicTool()) .WithTool(() => new ResourceTool()) + .WithTool(() => new SamplingTool()) ) .WithResources(r => r .WithResource(() => new SampleResource()) diff --git a/src/DotNetCampus.ModelContextProtocol.Ipc/Transports/Ipc/IpcServerTransport.cs b/src/DotNetCampus.ModelContextProtocol.Ipc/Transports/Ipc/IpcServerTransport.cs index 3a61072..31d0e72 100644 --- a/src/DotNetCampus.ModelContextProtocol.Ipc/Transports/Ipc/IpcServerTransport.cs +++ b/src/DotNetCampus.ModelContextProtocol.Ipc/Transports/Ipc/IpcServerTransport.cs @@ -15,13 +15,14 @@ public class IpcServerTransport : IServerTransport { // System.Runtime.InteropServices.MemoryMarshal.Read("Dncp.Mcp"u8).ToString("X") // 小端写入时,可在 IPC 传输序列中看到 Dncp.Mcp = DotNetCampus.ModelContextProtocol 的 ASCII 字符串。 - private const ulong McpIpcHeader = 0x70634D2E70636E44; + private const ulong McpIpcHeader = IpcServerTransportSession.McpIpcHeader; private readonly IServerTransportManager _manager; private readonly TaskCompletionSource _taskCompletionSource = new(); private readonly IpcProvider _server; private readonly bool _isExternalIpcProvider; private readonly ConcurrentDictionary _sessions = []; + private CancellationToken _runningCancellationToken; /// /// 初始化 类的新实例。 @@ -58,6 +59,7 @@ public Task StartAsync(CancellationToken startingCancellationToken, Cancel _server.StartServer(); _server.PeerConnected += OnPeerConnected; + _runningCancellationToken = runningCancellationToken; runningCancellationToken.Register(() => _taskCompletionSource.TrySetResult()); return Task.FromResult(_taskCompletionSource.Task); } @@ -77,7 +79,9 @@ public ValueTask DisposeAsync() private void OnPeerConnected(object? sender, PeerConnectedArgs e) { - _sessions[e.Peer.PeerName] = new IpcServerTransportSession(e.Peer.PeerName); + var session = new IpcServerTransportSession(_manager, e.Peer.PeerName); + session.SetPeer(e.Peer); + _sessions[e.Peer.PeerName] = session; e.Peer.PeerConnectionBroken += OnPeerConnectionBroken; e.Peer.PeerReconnected += OnPeerReconnected; e.Peer.MessageReceived += OnMessageReceived; @@ -92,7 +96,9 @@ private void OnPeerConnectionBroken(object? sender, IPeerConnectionBrokenArgs e) private void OnPeerReconnected(object? sender, IPeerReconnectedArgs e) { var peer = (PeerProxy)sender!; - _sessions[peer.PeerName] = new IpcServerTransportSession(peer.PeerName); + var session = new IpcServerTransportSession(_manager, peer.PeerName); + session.SetPeer(peer); + _sessions[peer.PeerName] = session; } private void OnMessageReceived(object? sender, IPeerMessageArgs e) @@ -120,29 +126,62 @@ private async Task HandleMessageAsync(PeerProxy peer, IpcMessage message) return; } - var request = await _manager.ParseAndCatchRequestAsync(payload.Body.ToMemoryStream()); - if (request is null) + JsonRpcMessage? parsed; + try { - await _manager.RespondJsonRpcAsync(peer, new JsonRpcResponse - { - Error = new JsonRpcError - { - Code = (int)JsonRpcErrorCode.InvalidRequest, - Message = "Invalid request message.", - }, - }, CancellationToken.None); - return; + parsed = await _manager.ReadMessageAsync(payload.Body.ToMemoryStream()); + } + catch + { + parsed = null; } - var response = await _manager.HandleRequestAsync(request, null, CancellationToken.None); - if (response is null) + if (parsed is not null) { - // 按照 MCP 协议规范,本次请求仅需响应而无需回复。 - // 而 IPC 不需要响应。 - return; + _manager.LogRawIn("[Ipc]", parsed); } - await _manager.RespondJsonRpcAsync(peer, response, CancellationToken.None); + switch (parsed) + { + case JsonRpcResponse response: + // 将响应路由到等待的请求(如 sampling/createMessage 回调)。 + if (_sessions.TryGetValue(peer.PeerName, out var responseSession)) + { + responseSession.HandleResponseAsync(response); + } + return; + + case JsonRpcNotification notification: + // 通知,路由到处理器,无需回复。 + await _manager.HandleRequestAsync( + new JsonRpcRequest { Method = notification.Method, Params = notification.Params }, + null, _runningCancellationToken); + return; + + case JsonRpcRequest request: + { + var response2 = await _manager.HandleRequestAsync(request, null, _runningCancellationToken); + if (response2 is null) + { + // 按照 MCP 协议规范,本次请求仅需响应而无需回复。 + // 而 IPC 不需要响应。 + return; + } + await _manager.RespondJsonRpcAsync(peer, response2, _runningCancellationToken); + return; + } + + default: + await _manager.RespondJsonRpcAsync(peer, new JsonRpcResponse + { + Error = new JsonRpcError + { + Code = (int)JsonRpcErrorCode.InvalidRequest, + Message = "Invalid request message.", + }, + }, _runningCancellationToken); + return; + } } } @@ -150,26 +189,14 @@ file static class Extensions { extension(IServerTransportManager manager) { - public async ValueTask ParseAndCatchRequestAsync(Stream data) - { - try - { - return await manager.ReadRequestAsync(data); - } - catch - { - // 请求消息格式不正确,返回 null 后,原样给 MCP 客户端报告错误。 - return null; - } - } - public async ValueTask RespondJsonRpcAsync(PeerProxy peer, JsonRpcResponse response, CancellationToken cancellationToken) { try { + manager.LogRawOut("[Ipc]", response); using var ms = new MemoryStream(); await manager.WriteMessageAsync(ms, response, cancellationToken); - await peer.NotifyAsync(new IpcMessage("", new IpcMessageBody(ms.GetBuffer(), 0, (int)ms.Length))); + await peer.NotifyAsync(new IpcMessage("", new IpcMessageBody(ms.GetBuffer(), 0, (int)ms.Length), IpcServerTransportSession.McpIpcHeader)); } catch { diff --git a/src/DotNetCampus.ModelContextProtocol.Ipc/Transports/Ipc/IpcServerTransportSession.cs b/src/DotNetCampus.ModelContextProtocol.Ipc/Transports/Ipc/IpcServerTransportSession.cs index ba9a4cb..c16b9e4 100644 --- a/src/DotNetCampus.ModelContextProtocol.Ipc/Transports/Ipc/IpcServerTransportSession.cs +++ b/src/DotNetCampus.ModelContextProtocol.Ipc/Transports/Ipc/IpcServerTransportSession.cs @@ -1,4 +1,5 @@ -using dotnetCampus.Ipc.Pipes; +using dotnetCampus.Ipc.Messages; +using dotnetCampus.Ipc.Pipes; using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; namespace DotNetCampus.ModelContextProtocol.Transports.Ipc; @@ -6,31 +7,57 @@ namespace DotNetCampus.ModelContextProtocol.Transports.Ipc; /// /// DotNetCampus.Ipc 传输层的一个会话。 /// -public class IpcServerTransportSession : IServerTransportSession +public class IpcServerTransportSession : ServerTransportSession { + // System.Runtime.InteropServices.MemoryMarshal.Read("Dncp.Mcp"u8).ToString("X") + // 小端写入时,可在 IPC 传输序列中看到 Dncp.Mcp = DotNetCampus.ModelContextProtocol 的 ASCII 字符串。 + internal const ulong McpIpcHeader = 0x70634D2E70636E44; + + private readonly IServerTransportManager _manager; + private PeerProxy? _peer; + /// /// 创建 DotNetCampus.Ipc 传输层的一个会话。 /// + /// /// 会话 Id。 - public IpcServerTransportSession(string sessionId) + public IpcServerTransportSession(IServerTransportManager manager, string sessionId) { + _manager = manager; SessionId = sessionId; } /// /// DotNetCampus.Ipc 传输层其实是严格一对一对应一个 的,所以其实不需要设置此属性。不过我们还是设了,调试稍微方便一点点。 /// - public string SessionId { get; } + public override string SessionId { get; } + + /// + /// 设置与此会话关联的 IPC 对端代理,供服务端主动请求发送时使用。 + /// + internal void SetPeer(PeerProxy peer) + { + _peer = peer; + } /// - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + protected override async Task SendRequestMessageAsync(JsonRpcRequest request, CancellationToken cancellationToken) { - throw new NotImplementedException(); + if (_peer is not { } peer) + { + throw new InvalidOperationException("IPC 对端代理尚未设置,无法发送服务端主动请求。请确认 SetPeer 已在连接建立时被调用。"); + } + + _manager.LogRawOut("[Ipc]", request); + using var ms = new MemoryStream(); + await _manager.WriteMessageAsync(ms, request, cancellationToken); + await peer.NotifyAsync(new IpcMessage("McpServer.SendMessage", new IpcMessageBody(ms.GetBuffer(), 0, (int)ms.Length), McpIpcHeader)); } /// - public ValueTask DisposeAsync() + public override ValueTask DisposeAsync() { + CancelAllPendingRequests(); return ValueTask.CompletedTask; } } diff --git a/src/DotNetCampus.ModelContextProtocol.TouchSocket.Http/Transports/TouchSocket/TouchSocketHttpServerTransport.cs b/src/DotNetCampus.ModelContextProtocol.TouchSocket.Http/Transports/TouchSocket/TouchSocketHttpServerTransport.cs index db203de..8d99d14 100644 --- a/src/DotNetCampus.ModelContextProtocol.TouchSocket.Http/Transports/TouchSocket/TouchSocketHttpServerTransport.cs +++ b/src/DotNetCampus.ModelContextProtocol.TouchSocket.Http/Transports/TouchSocket/TouchSocketHttpServerTransport.cs @@ -13,6 +13,17 @@ namespace DotNetCampus.ModelContextProtocol.Transports.TouchSocket; +// 结构说明:本类的 POST 处理逻辑结构与 LocalHostHttpServerTransport 完全对称。 +// 方法对应关系: +// HandleStreamableHttpMessageAsync ↔ HandlePostRequestAsync +// HandleClientResponseAsync ↔ HandleClientResponseAsync +// HandleNotificationAsync ↔ HandleNotificationAsync +// HandleRpcRequestAsync ↔ HandleRpcRequestAsync +// GetOrCreateSessionAsync ↔ GetOrCreateSessionAsync +// HandleInitializeAsync ↔ HandleInitializeAsync +// HandleSseRequestAsync ↔ HandleSseRequestAsync +// 如需修改协议逻辑,请同时更新对应方法。 + /// /// 基于 TouchSocket.Http 的 Streamable HTTP 传输层实现。 /// @@ -24,11 +35,14 @@ public class TouchSocketHttpServerTransport : PluginBase, IHttpPlugin, IServerTr { private const string ProtocolVersionHeader = "MCP-Protocol-Version"; private const string SessionIdHeader = "Mcp-Session-Id"; + private const int SseKeepAliveIntervalMs = 60000; private static readonly ReadOnlyMemory PrimeEventBytes = ": \n\n"u8.ToArray(); + private static readonly ReadOnlyMemory SseKeepAliveBytes = ": keep-alive\n\n"u8.ToArray(); private readonly IServerTransportManager _manager; private readonly ITouchSocketHttpServerTransportOptions _options; - private readonly ConcurrentDictionary _sessions = new(); + private readonly ConcurrentDictionary _sessions = new(); + private CancellationToken _runningCancellationToken; private readonly TouchSocketConfig? _config; private readonly HttpService? _httpService; @@ -83,6 +97,7 @@ public async Task StartAsync(CancellationToken startingCancellationToken, $"[McpServer][TouchSocket] Transport started with external HttpServer, endpoint: {_options.EndPoint}"); } + _runningCancellationToken = runningCancellationToken; return Task.Delay(Timeout.Infinite, runningCancellationToken); } @@ -154,14 +169,14 @@ await context.Response // Streamable HTTP: 客户端建立连接。 if (method == "GET" && endpoint.Equals(_options.EndPoint, StringComparison.OrdinalIgnoreCase)) { - await HandleStreamableHttpConnectionAsync(context, CancellationToken.None); + await HandleStreamableHttpConnectionAsync(context, _runningCancellationToken); return; } // Streamable HTTP: 客户端发送消息。 if (method == "POST" && endpoint.Equals(_options.EndPoint, StringComparison.OrdinalIgnoreCase)) { - await HandleStreamableHttpMessageAsync(context, CancellationToken.None); + await HandleStreamableHttpMessageAsync(context, _runningCancellationToken); return; } @@ -218,6 +233,8 @@ private async ValueTask HandleStreamableHttpConnectionAsync(HttpContext context, context.Response.ContentType = "text/event-stream"; context.Response.Headers.Add("Cache-Control", "no-cache"); + Log.Info($"[McpServer][TouchSocket] SSE connection established. SessionId={sessionId}"); + try { context.Response.IsChunk = true; @@ -225,7 +242,20 @@ private async ValueTask HandleStreamableHttpConnectionAsync(HttpContext context, await output.WriteAsync(PrimeEventBytes, cancellationToken); await output.FlushAsync(cancellationToken); - await session.RunSseConnectionAsync(output, cancellationToken); + // 定期发送 SSE 心跳,以便在客户端断开时通过写入/刷新失败尽快退出, + // 避免仅依赖外部 cancellationToken 导致连接长期悬挂。 + while (!cancellationToken.IsCancellationRequested) + { + await Task.Delay(SseKeepAliveIntervalMs, cancellationToken); + await output.WriteAsync(SseKeepAliveBytes, cancellationToken); + await output.FlushAsync(cancellationToken); + } + Log.Info($"[McpServer][TouchSocket] SSE connection cancelled. SessionId={sessionId}"); + } + catch (OperationCanceledException) + { + // 正常关闭 + Log.Info($"[McpServer][TouchSocket] SSE connection ended. SessionId={sessionId}"); } catch (Exception ex) { @@ -246,22 +276,21 @@ private async ValueTask HandleStreamableHttpMessageAsync(HttpContext context, Ca // 协议版本检查 var protocolVersion = request.Headers.Get(ProtocolVersionHeader).First; - if (!string.IsNullOrEmpty(protocolVersion)) + if (!string.IsNullOrEmpty(protocolVersion) && (ProtocolVersion)protocolVersion < ProtocolVersion.Minimum) { - // 如果比最小版本小则报错 - if (string.CompareOrdinal(protocolVersion, ProtocolVersion.Minimum) < 0) - { - Log.Warn($"[McpServer][TouchSocket] POST request rejected: Unsupported protocol version. Version={protocolVersion}"); - await context.RespondHttpError(HttpStatusCode.BadRequest, $"Unsupported protocol version. Minimum required: {ProtocolVersion.Minimum}"); - return; - } + Log.Warn($"[McpServer][TouchSocket] POST request rejected: Unsupported protocol version. Version={protocolVersion}"); + await context.RespondHttpError(HttpStatusCode.BadRequest, $"Unsupported protocol version. Minimum required: {ProtocolVersion.Minimum}"); + return; } - JsonRpcRequest? jsonRpcRequest; + var sessionIdStr = request.Headers.Get(SessionIdHeader).First; + + // 解析消息体 + JsonRpcMessage? message; try { var bodyBytes = await request.GetContentAsync(); - jsonRpcRequest = await _manager.ReadRequestAsync(bodyBytes); + message = await _manager.ReadMessageAsync(bodyBytes); } catch (JsonException) { @@ -270,74 +299,209 @@ private async ValueTask HandleStreamableHttpMessageAsync(HttpContext context, Ca return; } - if (jsonRpcRequest == null) + if (message is not null) { - Log.Warn($"[McpServer][TouchSocket] POST request rejected: Empty body."); - await context.RespondHttpError(HttpStatusCode.BadRequest, "Empty body"); - return; + _manager.LogRawIn("[TouchSocket]", $"POST, SessionId={sessionIdStr}", message); } - var isInitialize = jsonRpcRequest.Method == RequestMethods.Initialize; - var sessionIdStr = request.Headers.Get(SessionIdHeader).First; - TouchSocketHttpServerTransportSession? session; + switch (message) + { + case JsonRpcResponse jsonRpcResponse: + await HandleClientResponseAsync(context, sessionIdStr, jsonRpcResponse); + return; + case JsonRpcNotification notification: + await HandleNotificationAsync(context, sessionIdStr, notification, request, cancellationToken); + return; + case JsonRpcRequest jsonRpcRequest: + await HandleRpcRequestAsync(context, sessionIdStr, jsonRpcRequest, request, cancellationToken); + return; + default: + Log.Warn($"[McpServer][TouchSocket] POST request rejected: Invalid or unrecognized JSON-RPC message."); + await context.RespondHttpError(HttpStatusCode.BadRequest, "Invalid or unrecognized JSON-RPC message"); + return; + } + } - if (isInitialize) + /// + /// 客户端响应服务器发起的请求(如 sampling/createMessage)。 + /// + /// + /// + /// + private async ValueTask HandleClientResponseAsync(HttpContext context, string? sessionIdStr, JsonRpcResponse response) + { + if (string.IsNullOrEmpty(sessionIdStr) || !_sessions.TryGetValue(sessionIdStr, out var session)) { - // 初始化请求,创建新 Session - var newSessionId = _manager.MakeNewSessionId(); - var newSession = new TouchSocketHttpServerTransportSession(_manager, newSessionId.Id); + Log.Warn($"[McpServer][TouchSocket] Response routing failed: Session not found. SessionId={sessionIdStr}"); + await context.RespondHttpError(HttpStatusCode.NotFound, "Session not found"); + return; + } + session.HandleResponseAsync(response); + await context.RespondHttpSuccess(HttpStatusCode.Accepted); + } - if (_sessions.TryAdd(newSessionId.Id, newSession)) - { - session = newSession; - _manager.Add(session); - context.Response.Headers.Add(SessionIdHeader, newSessionId.Id); - Log.Info($"[McpServer][TouchSocket] Session created. SessionId={newSessionId.Id}"); - } - else + /// + /// 通知消息,无需响应。 + /// + /// + /// + /// + /// + /// + private async ValueTask HandleNotificationAsync(HttpContext context, string? sessionIdStr, JsonRpcNotification notification, HttpRequest request, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(sessionIdStr) || !_sessions.TryGetValue(sessionIdStr, out var session)) + { + Log.Warn($"[McpServer][TouchSocket] Notification routing failed: Session not found. SessionId={sessionIdStr}"); + await context.RespondHttpError(HttpStatusCode.NotFound, "Session not found"); + return; + } + await _manager.HandleRequestAsync( + new JsonRpcRequest { Method = notification.Method, Params = notification.Params }, + s => { - Log.Error($"[McpServer][TouchSocket] Session ID collision. SessionId={newSessionId.Id}"); - await context.RespondHttpError(HttpStatusCode.InternalServerError, "Session ID collision"); - return; - } + s.AddHttpTransportServices(session.SessionId, request); + s.AddTransportSession(session, Log); + }, + cancellationToken: cancellationToken); + await context.RespondHttpSuccess(HttpStatusCode.Accepted); + } + + /// + /// JSON-RPC 请求(包含 initialize 和普通请求两种路径)。 + /// + /// + /// + /// + /// + /// + private async ValueTask HandleRpcRequestAsync(HttpContext context, string? sessionIdStr, JsonRpcRequest jsonRpcRequest, HttpRequest request, CancellationToken cancellationToken) + { + var session = await GetOrCreateSessionAsync(context, sessionIdStr, jsonRpcRequest); + if (session is null) return; + + Log.Debug($"[McpServer][TouchSocket] Handling JSON-RPC request. SessionId={session.SessionId}, Method={jsonRpcRequest.Method}, MessageId={jsonRpcRequest.Id}"); + + if (jsonRpcRequest.Method == RequestMethods.Initialize) + { + await HandleInitializeAsync(context, session, jsonRpcRequest, request, cancellationToken); } else { - if (string.IsNullOrEmpty(sessionIdStr)) - { - Log.Warn($"[McpServer][TouchSocket] POST request rejected: Missing Mcp-Session-Id header. Method={jsonRpcRequest.Method}"); - await context.RespondHttpError(HttpStatusCode.BadRequest, "Missing Mcp-Session-Id header"); - return; - } + await HandleSseRequestAsync(context, session, jsonRpcRequest, request, cancellationToken); + } + } - if (!_sessions.TryGetValue(sessionIdStr, out session)) + /// + /// 查找已有 Session 或为 initialize 请求创建新 Session。失败时向客户端写入错误响应并返回 null。 + /// + /// + /// + /// + /// + private async ValueTask GetOrCreateSessionAsync(HttpContext context, string? sessionIdStr, JsonRpcRequest jsonRpcRequest) + { + if (jsonRpcRequest.Method == RequestMethods.Initialize) + { + var newSessionId = _manager.MakeNewSessionId(); + var newSession = new HttpServerTransportSession(_manager, newSessionId.Id, "[McpServer][TouchSocket]"); + if (_sessions.TryAdd(newSessionId.Id, newSession)) { - Log.Warn($"[McpServer][TouchSocket] POST request rejected: Session not found. SessionId={sessionIdStr}, Method={jsonRpcRequest.Method}"); - await context.RespondHttpError(HttpStatusCode.NotFound, "Session not found"); - return; + _manager.Add(newSession); + context.Response.Headers.Add(SessionIdHeader, newSessionId.Id); + Log.Info($"[McpServer][TouchSocket] Session created. SessionId={newSessionId.Id}"); + return newSession; } + Log.Error($"[McpServer][TouchSocket] Session ID collision. SessionId={newSessionId.Id}"); + await context.RespondHttpError(HttpStatusCode.InternalServerError, "Session ID collision"); + return null; } - Log.Debug($"[McpServer][TouchSocket] Handling JSON-RPC request. SessionId={session.SessionId}, Method={jsonRpcRequest.Method}, MessageId={jsonRpcRequest.Id}"); + if (string.IsNullOrEmpty(sessionIdStr)) + { + Log.Warn($"[McpServer][TouchSocket] POST request rejected: Missing Mcp-Session-Id header. Method={jsonRpcRequest.Method}"); + await context.RespondHttpError(HttpStatusCode.BadRequest, "Missing Mcp-Session-Id header"); + return null; + } + if (!_sessions.TryGetValue(sessionIdStr, out var session)) + { + Log.Warn($"[McpServer][TouchSocket] POST request rejected: Session not found. SessionId={sessionIdStr}, Method={jsonRpcRequest.Method}"); + await context.RespondHttpError(HttpStatusCode.NotFound, "Session not found"); + return null; + } + return session; + } - var jsonRpcResponse = await _manager.HandleRequestAsync(jsonRpcRequest, - s => s.AddHttpTransportServices(session.SessionId, request), + /// + /// initialize 请求:同步返回 application/json,无需 SSE 流。 + /// + /// + /// + /// + /// + /// + private async ValueTask HandleInitializeAsync(HttpContext context, HttpServerTransportSession session, JsonRpcRequest jsonRpcRequest, HttpRequest request, CancellationToken cancellationToken) + { + var initResponse = await _manager.HandleRequestAsync(jsonRpcRequest, + s => + { + s.AddHttpTransportServices(session.SessionId, request); + s.AddTransportSession(session, Log); + }, cancellationToken: cancellationToken); - if (jsonRpcResponse != null) + if (initResponse != null) { - // Request: Success or Failed. - Log.Debug($"[McpServer][TouchSocket] Sending JSON-RPC response. SessionId={session.SessionId}, Method={jsonRpcRequest.Method}, MessageId={jsonRpcRequest.Id}"); - await context.RespondJsonRpcAsync(_manager, HttpStatusCode.OK, jsonRpcResponse); + Log.Debug($"[McpServer][TouchSocket] Sending initialize response. SessionId={session.SessionId}, MessageId={jsonRpcRequest.Id}"); + _manager.LogRawOut("[TouchSocket]", $"POST/json, SessionId={session.SessionId}", initResponse); + await context.RespondJsonRpcAsync(_manager, HttpStatusCode.OK, initResponse, cancellationToken); } else { - // Notification: No need to respond. - Log.Debug($"[McpServer][TouchSocket] No response for notification. SessionId={session.SessionId}, Method={jsonRpcRequest.Method}, MessageId={jsonRpcRequest.Id}"); + Log.Debug($"[McpServer][TouchSocket] No response for initialize notification. SessionId={session.SessionId}"); await context.RespondHttpSuccess(HttpStatusCode.Accepted); } } + /// + /// 非 initialize 请求:以 text/event-stream 响应,服务端可在处理期间通过 SSE 流发起采样请求。 + /// 规范 §2.1 规则 6:"The server MAY send JSON-RPC requests and notifications before sending + /// the JSON-RPC response. These messages SHOULD relate to the originating client request." + /// + /// + /// + /// + /// + /// + private async ValueTask HandleSseRequestAsync(HttpContext context, HttpServerTransportSession session, JsonRpcRequest jsonRpcRequest, HttpRequest request, CancellationToken cancellationToken) + { + context.Response.SetStatus(HttpStatusCode.OK, ""); + context.Response.ContentType = "text/event-stream"; + context.Response.Headers.Add("Cache-Control", "no-cache"); + + context.Response.IsChunk = true; + await using var output = context.Response.CreateWriteStream(); + await output.WriteAsync(PrimeEventBytes, cancellationToken); + await output.FlushAsync(cancellationToken); + + using var _ = session.SetRequestSseStream(output); + + var resp = await _manager.HandleRequestAsync(jsonRpcRequest, + s => + { + s.AddHttpTransportServices(session.SessionId, request); + s.AddTransportSession(session, Log); + }, + cancellationToken: cancellationToken); + + if (resp != null) + { + Log.Debug($"[McpServer][TouchSocket] Sending JSON-RPC response via SSE. SessionId={session.SessionId}, Method={jsonRpcRequest.Method}, MessageId={jsonRpcRequest.Id}"); + await session.WriteSseMessageAsync(output, resp, cancellationToken); + } + await context.Response.CompleteChunkAsync(cancellationToken); + } + /// /// Streamable HTTP: 客户端关闭连接 (DELETE /mcp)。 /// @@ -441,7 +605,8 @@ file static class Extensions /// 服务端传输管理器。 /// HTTP 状态码。 /// JSON-RPC 响应对象。 - internal async ValueTask RespondJsonRpcAsync(IServerTransportManager manager, int statusCode, JsonRpcResponse response) + /// 取消令牌。 + internal async ValueTask RespondJsonRpcAsync(IServerTransportManager manager, int statusCode, JsonRpcResponse response, CancellationToken cancellationToken) { context.Response.ContentType = "application/json"; context.Response.SetStatus(statusCode, ""); @@ -449,7 +614,7 @@ internal async ValueTask RespondJsonRpcAsync(IServerTransportManager manager, in context.Response.IsChunk = true; await using (var stream = context.Response.CreateWriteStream()) { - await manager.WriteMessageAsync(stream, response, CancellationToken.None); + await manager.WriteMessageAsync(stream, response, cancellationToken); } await context.Response.CompleteChunkAsync(); } diff --git a/src/DotNetCampus.ModelContextProtocol.TouchSocket.Http/Transports/TouchSocket/TouchSocketHttpServerTransportSession.cs b/src/DotNetCampus.ModelContextProtocol.TouchSocket.Http/Transports/TouchSocket/TouchSocketHttpServerTransportSession.cs deleted file mode 100644 index 1b48b54..0000000 --- a/src/DotNetCampus.ModelContextProtocol.TouchSocket.Http/Transports/TouchSocket/TouchSocketHttpServerTransportSession.cs +++ /dev/null @@ -1,128 +0,0 @@ -using System.Threading.Channels; -using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; -using DotNetCampus.ModelContextProtocol.Hosting.Logging; - -namespace DotNetCampus.ModelContextProtocol.Transports.TouchSocket; - -/// -/// Streamable HTTP 传输层的一个会话。 -/// -public class TouchSocketHttpServerTransportSession : IServerTransportSession -{ - private static readonly ReadOnlyMemory EventMessageBytes = "event: message\n"u8.ToArray(); - private static readonly ReadOnlyMemory DataPrefixBytes = "data: "u8.ToArray(); - private static readonly ReadOnlyMemory NewLineBytes = "\n"u8.ToArray(); - - private readonly IServerTransportManager _manager; - private readonly Channel _outgoingMessages; - private readonly CancellationTokenSource _disposeCts = new(); - - private IMcpLogger Log => _manager.Context.Logger; - - /// - public string SessionId { get; } - - /// - /// 初始化 类的新实例。 - /// - /// 辅助管理 MCP 传输层的管理器。 - /// 唯一标识此会话的 ID。 - public TouchSocketHttpServerTransportSession(IServerTransportManager manager, string sessionId) - { - _manager = manager; - SessionId = sessionId; - _outgoingMessages = Channel.CreateUnbounded(new UnboundedChannelOptions - { - SingleReader = true, - SingleWriter = false, - }); - } - - /// - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - { - if (_disposeCts.IsCancellationRequested) - { - return Task.CompletedTask; - } - return _outgoingMessages.Writer.WriteAsync(message, cancellationToken).AsTask(); - } - - /// - /// 运行 SSE 长连接,持续向客户端推送消息,直到连接断开或取消。 - /// - /// 用于向客户端写入 SSE 数据的输出流。 - /// 用于取消操作的令牌。 - public async Task RunSseConnectionAsync(Stream outputStream, CancellationToken cancellationToken) - { - using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposeCts.Token); - var ct = linkedCts.Token; - - try - { - Log.Debug($"[McpServer][TouchSocket] SSE connection started. SessionId={SessionId}"); - - // Wait for messages and write them - await foreach (var message in _outgoingMessages.Reader.ReadAllAsync(ct)) - { - await WriteSseMessageAsync(outputStream, message, ct); - } - } - catch (OperationCanceledException) - { - // Expected on shutdown - } - catch (Exception ex) - { - Log.Warn($"[McpServer][TouchSocket] SSE connection error. SessionId={SessionId}, Error={ex.Message}"); - } - finally - { - Log.Debug($"[McpServer][TouchSocket] SSE connection ended. SessionId={SessionId}"); - } - } - - private async Task WriteSseMessageAsync(Stream stream, JsonRpcMessage message, CancellationToken ct) - { - try - { - // event: message - await stream.WriteAsync(EventMessageBytes, ct); - - // data: ... - await stream.WriteAsync(DataPrefixBytes, ct); - - // Serialize - await _manager.WriteMessageAsync(stream, message, ct); - - // \n\n (End of event) - await stream.WriteAsync(NewLineBytes, ct); - await stream.WriteAsync(NewLineBytes, ct); - - await stream.FlushAsync(ct); - } - catch (Exception ex) - { - Log.Error($"[McpServer][TouchSocket] Failed to write SSE message. SessionId={SessionId}", ex); - throw; // Re-throw to close connection if write fails - } - } - - /// - public async ValueTask DisposeAsync() - { - if (_disposeCts.IsCancellationRequested) - { - return; - } - -#if NET8_0_OR_GREATER - await _disposeCts.CancelAsync(); -#else - await Task.Yield(); - _disposeCts.Cancel(); -#endif - _outgoingMessages.Writer.TryComplete(); - _disposeCts.Dispose(); - } -} diff --git a/src/DotNetCampus.ModelContextProtocol/Clients/McpClient.cs b/src/DotNetCampus.ModelContextProtocol/Clients/McpClient.cs index 76c3b04..5e87a90 100644 --- a/src/DotNetCampus.ModelContextProtocol/Clients/McpClient.cs +++ b/src/DotNetCampus.ModelContextProtocol/Clients/McpClient.cs @@ -1,4 +1,4 @@ -using System.Text.Json; +using System.Text.Json; using System.Text.Json.Serialization.Metadata; using DotNetCampus.ModelContextProtocol.CompilerServices; using DotNetCampus.ModelContextProtocol.Exceptions; @@ -115,12 +115,12 @@ public async Task ListToolsAsync(string? cursor = null, Cancell : JsonSerializer.SerializeToElement(new ListToolsRequestParams { Cursor = cursor, - }, McpServerRequestJsonContext.Default.ListToolsRequestParams), + }, McpInternalJsonContext.Default.ListToolsRequestParams), }; var response = await Transport.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); response.ThrowClientExceptionIfError(); - return DeserializeResult(response, McpServerResponseJsonContext.Default.ListToolsResult); + return DeserializeResult(response, McpInternalJsonContext.Default.ListToolsResult); } /// @@ -142,12 +142,12 @@ public async Task CallToolAsync(string toolName, JsonElement? ar { Name = toolName, Arguments = arguments, - }, McpServerRequestJsonContext.Default.CallToolRequestParams), + }, McpInternalJsonContext.Default.CallToolRequestParams), }; var response = await Transport.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); response.ThrowClientExceptionIfError(); - return DeserializeResult(response, McpServerResponseJsonContext.Default.CallToolResult); + return DeserializeResult(response, McpInternalJsonContext.Default.CallToolResult); } /// @@ -169,12 +169,12 @@ public async Task ListResourcesAsync(string? cursor = null, : JsonSerializer.SerializeToElement(new ListResourcesRequestParams { Cursor = cursor, - }, McpServerRequestJsonContext.Default.ListResourcesRequestParams), + }, McpInternalJsonContext.Default.ListResourcesRequestParams), }; var response = await Transport.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); response.ThrowClientExceptionIfError(); - return DeserializeResult(response, McpServerResponseJsonContext.Default.ListResourcesResult); + return DeserializeResult(response, McpInternalJsonContext.Default.ListResourcesResult); } /// @@ -194,12 +194,12 @@ public async Task ReadResourceAsync(string uri, Cancellation Params = JsonSerializer.SerializeToElement(new ReadResourceRequestParams { Uri = uri, - }, McpServerRequestJsonContext.Default.ReadResourceRequestParams), + }, McpInternalJsonContext.Default.ReadResourceRequestParams), }; var response = await Transport.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); response.ThrowClientExceptionIfError(); - return DeserializeResult(response, McpServerResponseJsonContext.Default.ReadResourceResult); + return DeserializeResult(response, McpInternalJsonContext.Default.ReadResourceResult); } /// @@ -221,12 +221,12 @@ public async Task ListPromptsAsync(string? cursor = null, Can : JsonSerializer.SerializeToElement(new ListPromptsRequestParams { Cursor = cursor, - }, McpServerRequestJsonContext.Default.ListPromptsRequestParams), + }, McpInternalJsonContext.Default.ListPromptsRequestParams), }; var response = await Transport.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); response.ThrowClientExceptionIfError(); - return DeserializeResult(response, McpServerResponseJsonContext.Default.ListPromptsResult); + return DeserializeResult(response, McpInternalJsonContext.Default.ListPromptsResult); } /// @@ -248,12 +248,12 @@ public async Task GetPromptAsync(string name, Dictionary(response, McpServerResponseJsonContext.Default.GetPromptResult); + return DeserializeResult(response, McpInternalJsonContext.Default.GetPromptResult); } /// diff --git a/src/DotNetCampus.ModelContextProtocol/Clients/McpClientBuilder.cs b/src/DotNetCampus.ModelContextProtocol/Clients/McpClientBuilder.cs index e77176b..dec222f 100644 --- a/src/DotNetCampus.ModelContextProtocol/Clients/McpClientBuilder.cs +++ b/src/DotNetCampus.ModelContextProtocol/Clients/McpClientBuilder.cs @@ -15,9 +15,11 @@ public class McpClientBuilder private string _clientName = ""; private string _clientVersion = "0.0.0"; private IMcpLogger? _logger; + private McpTransportRawMessageLoggingDetailLevel _rawMessageLoggingDetailLevel = McpTransportRawMessageLoggingDetailLevel.None; private IServiceProvider? _serviceProvider; private Func? _transportFactory; private ClientCapabilities _capabilities = new(); + private Func>? _samplingHandler; /// /// 设置客户端名称和版本。 @@ -43,6 +45,19 @@ public McpClientBuilder WithLogger(IMcpLogger logger) return this; } + /// + /// 配置 MCP 客户端的日志记录器。 + /// + /// 日志记录器。 + /// 传输层原始消息的日志记录详细级别。 + /// 用于链式调用的 MCP 客户端生成器。 + public McpClientBuilder WithLogger(IMcpLogger logger, McpTransportRawMessageLoggingDetailLevel rawMessageLoggingDetailLevel) + { + _logger = logger; + _rawMessageLoggingDetailLevel = rawMessageLoggingDetailLevel; + return this; + } + /// /// 配置 MCP 客户端的服务提供器。 /// @@ -131,6 +146,43 @@ public McpClientBuilder WithCapabilities(ClientCapabilities capabilities) return this; } + /// + /// 配置 Sampling 处理器,使客户端支持服务器发起的 sampling/createMessage 请求。 + /// 调用此方法会自动在客户端能力中声明 Sampling 支持。 + /// + /// + /// 当服务器请求采样时的处理函数。接收 并返回 。 + /// + /// 用于链式调用的 MCP 客户端生成器。 + public McpClientBuilder WithSamplingHandler( + Func> handler) + { + _samplingHandler = handler; + _capabilities = _capabilities with + { + Sampling = _capabilities.Sampling ?? new SamplingCapability(), + }; + return this; + } + + /// + /// 配置 Sampling 处理器,使客户端支持服务器发起的 sampling/createMessage 请求。 + /// 调用此方法会自动在客户端能力中声明 Sampling 支持。 + /// + /// + /// 处理函数工厂,接收 以便从中获取所需服务。 + /// + /// 用于链式调用的 MCP 客户端生成器。 + public McpClientBuilder WithSamplingHandler( + Func>> handlerFactory) + { + return WithSamplingHandler((p, ct) => + { + var handler = handlerFactory(_serviceProvider); + return handler(p, ct); + }); + } + /// /// 构建 MCP 客户端实例。 /// @@ -148,9 +200,17 @@ public McpClient Build() ServiceProvider = _serviceProvider, }; - var transportManager = new ClientTransportManager(context); + var transportManager = new ClientTransportManager(context) + { + RawMessageLoggingDetailLevel = _rawMessageLoggingDetailLevel, + }; context.Transport = transportManager; + if (_samplingHandler is { } handler) + { + transportManager.SetSamplingHandler(handler); + } + var transport = _transportFactory(transportManager); transportManager.SetTransport(transport); diff --git a/src/DotNetCampus.ModelContextProtocol/CompilerServices/McpJsonContext.cs b/src/DotNetCampus.ModelContextProtocol/CompilerServices/McpJsonContext.cs index 24d0997..e507447 100644 --- a/src/DotNetCampus.ModelContextProtocol/CompilerServices/McpJsonContext.cs +++ b/src/DotNetCampus.ModelContextProtocol/CompilerServices/McpJsonContext.cs @@ -112,58 +112,53 @@ public partial class CompiledSchemaJsonContext : JsonSerializerContext; internal partial class McpServerToolJsonContext : JsonSerializerContext; /// -/// 提供给 MCP 协议中,服务端收到来自客户端的请求数据时使用的 JSON 序列化上下文。 -/// -[JsonSerializable(typeof(CallToolRequestParams))] -[JsonSerializable(typeof(GetPromptRequestParams))] -[JsonSerializable(typeof(InitializeRequestParams))] -[JsonSerializable(typeof(JsonElement))] -[JsonSerializable(typeof(ListPromptsRequestParams))] -[JsonSerializable(typeof(JsonRpcNotification))] -[JsonSerializable(typeof(JsonRpcRequest))] -[JsonSerializable(typeof(ListResourcesRequestParams))] -[JsonSerializable(typeof(ListResourceTemplatesRequestParams))] -[JsonSerializable(typeof(ListToolsRequestParams))] -[JsonSerializable(typeof(LoggingLevel))] -[JsonSerializable(typeof(PingRequestParams))] -[JsonSerializable(typeof(ReadResourceRequestParams))] -[JsonSerializable(typeof(SetLevelRequestParams))] -[JsonSourceGenerationOptions( - PropertyNameCaseInsensitive = true, - PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, - UseStringEnumConverter = true, - WriteIndented = false)] -internal partial class McpServerRequestJsonContext : JsonSerializerContext; - -/// -/// 提供给 MCP 协议中,服务端发送给客户端的响应数据时使用的 JSON 序列化上下文。 +/// MCP 协议内部使用的统一 JSON 序列化上下文,涵盖所有请求参数类型和响应结果类型。 /// [JsonSerializable(typeof(Annotations))] [JsonSerializable(typeof(AudioContentBlock))] [JsonSerializable(typeof(BlobResourceContents))] +[JsonSerializable(typeof(CallToolRequestParams))] [JsonSerializable(typeof(CallToolResult))] [JsonSerializable(typeof(CompiledJsonSchema))] [JsonSerializable(typeof(ContentBlock))] +[JsonSerializable(typeof(CreateMessageRequestParams))] +[JsonSerializable(typeof(CreateMessageResult))] [JsonSerializable(typeof(EmbeddedResourceContentBlock))] [JsonSerializable(typeof(EmptyObject))] +[JsonSerializable(typeof(GetPromptRequestParams))] [JsonSerializable(typeof(GetPromptResult))] [JsonSerializable(typeof(ImageContentBlock))] +[JsonSerializable(typeof(InitializeRequestParams))] [JsonSerializable(typeof(InitializeResult))] [JsonSerializable(typeof(JsonElement))] +[JsonSerializable(typeof(JsonRpcNotification))] +[JsonSerializable(typeof(JsonRpcRequest))] [JsonSerializable(typeof(JsonRpcResponse))] +[JsonSerializable(typeof(ListPromptsRequestParams))] [JsonSerializable(typeof(ListPromptsResult))] +[JsonSerializable(typeof(ListResourcesRequestParams))] [JsonSerializable(typeof(ListResourcesResult))] +[JsonSerializable(typeof(ListResourceTemplatesRequestParams))] [JsonSerializable(typeof(ListResourceTemplatesResult))] +[JsonSerializable(typeof(ListToolsRequestParams))] [JsonSerializable(typeof(ListToolsResult))] +[JsonSerializable(typeof(LoggingLevel))] [JsonSerializable(typeof(McpExceptionData))] +[JsonSerializable(typeof(ModelHint))] +[JsonSerializable(typeof(ModelPreferences))] +[JsonSerializable(typeof(PingRequestParams))] +[JsonSerializable(typeof(ReadResourceRequestParams))] [JsonSerializable(typeof(ReadResourceResult))] [JsonSerializable(typeof(ResourceContents))] [JsonSerializable(typeof(ResourceLinkContentBlock))] +[JsonSerializable(typeof(SamplingMessage))] +[JsonSerializable(typeof(SetLevelRequestParams))] [JsonSerializable(typeof(TextContentBlock))] [JsonSerializable(typeof(TextResourceContents))] +[JsonSerializable(typeof(ToolChoice))] [JsonSourceGenerationOptions( PropertyNameCaseInsensitive = true, PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, UseStringEnumConverter = true, WriteIndented = false)] -internal partial class McpServerResponseJsonContext : JsonSerializerContext; +internal partial class McpInternalJsonContext : JsonSerializerContext; diff --git a/src/DotNetCampus.ModelContextProtocol/Exceptions/McpExceptionData.cs b/src/DotNetCampus.ModelContextProtocol/Exceptions/McpExceptionData.cs index 9d237fa..5ff2c6c 100644 --- a/src/DotNetCampus.ModelContextProtocol/Exceptions/McpExceptionData.cs +++ b/src/DotNetCampus.ModelContextProtocol/Exceptions/McpExceptionData.cs @@ -1,4 +1,4 @@ -using System.Text.Json; +using System.Text.Json; using System.Text.Json.Serialization; using DotNetCampus.ModelContextProtocol.CompilerServices; @@ -33,7 +33,7 @@ public record McpExceptionData /// 表示当前实例的 public JsonElement ToJsonElement() { - return JsonSerializer.SerializeToElement(this, McpServerResponseJsonContext.Default.McpExceptionData); + return JsonSerializer.SerializeToElement(this, McpInternalJsonContext.Default.McpExceptionData); } /// @@ -42,7 +42,7 @@ public JsonElement ToJsonElement() /// 表示当前实例的 JSON 字符串。 public string ToJsonString() { - return JsonSerializer.Serialize(this, McpServerResponseJsonContext.Default.McpExceptionData); + return JsonSerializer.Serialize(this, McpInternalJsonContext.Default.McpExceptionData); } /// diff --git a/src/DotNetCampus.ModelContextProtocol/Exceptions/McpSamplingNotSupportedException.cs b/src/DotNetCampus.ModelContextProtocol/Exceptions/McpSamplingNotSupportedException.cs new file mode 100644 index 0000000..c2fd2b9 --- /dev/null +++ b/src/DotNetCampus.ModelContextProtocol/Exceptions/McpSamplingNotSupportedException.cs @@ -0,0 +1,16 @@ +namespace DotNetCampus.ModelContextProtocol.Exceptions; + +/// +/// 当连接的 MCP 客户端未声明对 Sampling 能力的支持,导致无法发起 sampling/createMessage 请求时引发的异常。
+/// 此异常表示客户端在能力协商阶段未声明 sampling 能力,而非代码使用错误。 +///
+public class McpSamplingNotSupportedException : McpClientException +{ + /// + /// 初始化 类的新实例。 + /// + public McpSamplingNotSupportedException() + : base("当前连接的客户端未声明对 Sampling 的支持。The connected client has not declared Sampling capability.") + { + } +} diff --git a/src/DotNetCampus.ModelContextProtocol/Exceptions/McpSamplingRejectedException.cs b/src/DotNetCampus.ModelContextProtocol/Exceptions/McpSamplingRejectedException.cs new file mode 100644 index 0000000..02aa33c --- /dev/null +++ b/src/DotNetCampus.ModelContextProtocol/Exceptions/McpSamplingRejectedException.cs @@ -0,0 +1,30 @@ +namespace DotNetCampus.ModelContextProtocol.Exceptions; + +/// +/// 当 MCP 客户端的采样请求被用户(人工审批)拒绝时引发的异常。
+/// 根据 MCP 规范,客户端实现应提供人工审批机制(human-in-the-loop),允许用户在采样请求发送给 LLM 之前拒绝它。 +///
+public class McpSamplingRejectedException : McpClientException +{ + /// + /// 初始化 类的新实例。 + /// + /// 来自客户端的 JSON-RPC 错误码。 + /// 来自客户端的拒绝原因说明。 + public McpSamplingRejectedException(int errorCode, string message) + : base($"Sampling request was rejected: [{errorCode}] {message}") + { + ErrorCode = errorCode; + RejectionMessage = message; + } + + /// + /// 获取来自客户端的 JSON-RPC 错误码。 + /// + public int ErrorCode { get; } + + /// + /// 获取来自客户端的拒绝原因说明。 + /// + public string RejectionMessage { get; } +} diff --git a/src/DotNetCampus.ModelContextProtocol/Hosting/Services/McpServiceCollectionTransportExtensions.cs b/src/DotNetCampus.ModelContextProtocol/Hosting/Services/McpServiceCollectionTransportExtensions.cs new file mode 100644 index 0000000..6061a7b --- /dev/null +++ b/src/DotNetCampus.ModelContextProtocol/Hosting/Services/McpServiceCollectionTransportExtensions.cs @@ -0,0 +1,26 @@ +using DotNetCampus.ModelContextProtocol.Hosting.Logging; +using DotNetCampus.ModelContextProtocol.Servers; +using DotNetCampus.ModelContextProtocol.Transports; + +namespace DotNetCampus.ModelContextProtocol.Hosting.Services; + +/// +/// 提供向 注册传输层会话服务的扩展方法。 +/// +public static class McpServiceCollectionTransportExtensions +{ + /// + /// 向 MCP 服务集合中注册传输层会话相关服务,包括 + /// 。 + /// + /// MCP 服务集合。 + /// 当前传输层会话实例。 + /// 日志记录器,传递给 Sampling 实现。 + /// 提供链式调用的服务集合。 + public static IMcpServiceCollection AddTransportSession(this IMcpServiceCollection services, IServerTransportSession session, IMcpLogger logger) + { + services.AddScoped(session); + services.AddScoped(new McpServerSampling(session, logger)); + return services; + } +} diff --git a/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/ContentBlock.cs b/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/ContentBlock.cs index 68e68db..55903f5 100644 --- a/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/ContentBlock.cs +++ b/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/ContentBlock.cs @@ -13,8 +13,8 @@ namespace DotNetCampus.ModelContextProtocol.Protocol.Messages; [JsonDerivedType(typeof(AudioContentBlock), typeDiscriminator: "audio")] [JsonDerivedType(typeof(ResourceLinkContentBlock), typeDiscriminator: "resource_link")] [JsonDerivedType(typeof(EmbeddedResourceContentBlock), typeDiscriminator: "resource")] -[JsonDerivedType(typeof(ToolUseContent), typeDiscriminator: "toolUse")] -[JsonDerivedType(typeof(ToolResultContent), typeDiscriminator: "toolResult")] +[JsonDerivedType(typeof(ToolUseContent), typeDiscriminator: "tool_use")] +[JsonDerivedType(typeof(ToolResultContent), typeDiscriminator: "tool_result")] public abstract record ContentBlock { /// @@ -47,6 +47,14 @@ public sealed record TextContentBlock : ContentBlock /// [JsonPropertyName("text")] public required string Text { get; init; } + + /// + /// 输出此文本内容块的文本。 + /// + public override string ToString() + { + return Text; + } } /// @@ -68,6 +76,14 @@ public sealed record ImageContentBlock : ContentBlock /// [JsonPropertyName("mimeType")] public required string MimeType { get; init; } + + /// + /// 输出此图像内容块的文本表示形式,格式为 data URI scheme(data:[<mediatype>][;base64],<data>)。 + /// + public override string ToString() + { + return $"data:{MimeType};base64,{Data}"; + } } /// @@ -89,6 +105,14 @@ public sealed record AudioContentBlock : ContentBlock /// [JsonPropertyName("mimeType")] public required string MimeType { get; init; } + + /// + /// 输出此音频内容块的文本表示形式,格式为 data URI scheme(data:[<mediatype>][;base64],<data>)。 + /// + public override string ToString() + { + return $"data:{MimeType};base64,{Data}"; + } } /// @@ -155,6 +179,14 @@ public sealed record ResourceLinkContentBlock : ContentBlock [JsonPropertyName("size")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public long? Size { get; init; } + + /// + /// 输出此资源链接内容块的文本表示形式,格式为 data URI scheme(data:[<mediatype>][;base64],<data>)。 + /// + public override string ToString() + { + return $"data:{MimeType ?? "application/octet-stream"};base64,{Uri}"; + } } /// @@ -204,6 +236,14 @@ public abstract record ResourceContents [JsonPropertyName("_meta")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public JsonElement? Meta { get; init; } + + /// + /// 输出此资源内容的文本表示形式,格式为 data URI scheme(data:[<mediatype>][;base64],<data>)。 + /// + public override string ToString() + { + return $"data:{MimeType ?? "application/octet-stream"};base64,{Uri}"; + } } /// @@ -219,6 +259,14 @@ public sealed record TextResourceContents : ResourceContents /// [JsonPropertyName("text")] public required string Text { get; init; } + + /// + /// 输出此文本资源内容的文本表示形式。 + /// + public override string ToString() + { + return Text; + } } /// diff --git a/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/Role.cs b/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/Role.cs index dd8b6c1..ee420f3 100644 --- a/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/Role.cs +++ b/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/Role.cs @@ -1,4 +1,3 @@ -using System.Runtime.Serialization; using System.Text.Json.Serialization; namespace DotNetCampus.ModelContextProtocol.Protocol.Messages; @@ -14,13 +13,13 @@ public enum Role /// 用户角色
/// User role ///
- [EnumMember(Value = "user")] + [JsonStringEnumMemberName("user")] User, /// /// 助手角色
/// Assistant role ///
- [EnumMember(Value = "assistant")] + [JsonStringEnumMemberName("assistant")] Assistant, } diff --git a/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/Sampling.cs b/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/Sampling.cs index cbfdb8f..6df1ccd 100644 --- a/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/Sampling.cs +++ b/src/DotNetCampus.ModelContextProtocol/Protocol/Messages/Sampling.cs @@ -123,7 +123,7 @@ public sealed record CreateMessageResult : Result /// Message content ///
[JsonPropertyName("content")] - public required SamplingMessageContent Content { get; init; } + public required ContentBlock Content { get; init; } /// /// 生成消息的模型名称。
@@ -161,7 +161,7 @@ public sealed record SamplingMessage /// Message content ///
[JsonPropertyName("content")] - public required SamplingMessageContent Content { get; init; } + public required ContentBlock Content { get; init; } /// /// 元数据字段
@@ -173,20 +173,6 @@ public sealed record SamplingMessage public JsonElement? Meta { get; init; } } -/// -/// 采样消息内容(文本、图像、音频、工具使用或工具结果)
-/// Sampling message content (text, image, audio, tool use or tool result) -///
-[JsonPolymorphic(TypeDiscriminatorPropertyName = "type")] -[JsonDerivedType(typeof(TextContentBlock), typeDiscriminator: "text")] -[JsonDerivedType(typeof(ImageContentBlock), typeDiscriminator: "image")] -[JsonDerivedType(typeof(AudioContentBlock), typeDiscriminator: "audio")] -[JsonDerivedType(typeof(ToolUseContent), typeDiscriminator: "toolUse")] -[JsonDerivedType(typeof(ToolResultContent), typeDiscriminator: "toolResult")] -public abstract record SamplingMessageContent -{ -} - /// /// 服务器在采样期间对模型选择的偏好,在采样期间请求客户端。
/// 由于 LLM 可以在多个维度上变化,选择"最佳"模型很少是直截了当的。
diff --git a/src/DotNetCampus.ModelContextProtocol/Servers/IMcpServerPrimitiveContext.cs b/src/DotNetCampus.ModelContextProtocol/Servers/IMcpServerPrimitiveContext.cs index ed0ff43..2d58385 100644 --- a/src/DotNetCampus.ModelContextProtocol/Servers/IMcpServerPrimitiveContext.cs +++ b/src/DotNetCampus.ModelContextProtocol/Servers/IMcpServerPrimitiveContext.cs @@ -4,6 +4,7 @@ namespace DotNetCampus.ModelContextProtocol.Servers; + /// /// 包含 MCP 服务器收到来自客户端的请求时,服务端处理请求具体实现可能会用到的各种上下文信息。
/// Contains various context information that the server-side implementation of the MCP server @@ -60,6 +61,11 @@ public interface IMcpServerCallToolContext : IMcpServerPrimitiveContext /// Cancellation token used to cancel the tool invocation operation. ///
CancellationToken CancellationToken { get; } + + /// + /// 提供服务器向客户端发起 Sampling 请求的能力。始终非空;当传输层或客户端不支持 Sampling 时,。 + /// + IMcpServerSampling Sampling { get; } } /// @@ -92,6 +98,9 @@ internal sealed class McpServerCallToolContext : IMcpServerCallToolContext public required string Name { get; init; } public required JsonElement InputJsonArguments { get; init; } public required CancellationToken CancellationToken { get; init; } + public IMcpServerSampling Sampling => + (IMcpServerSampling?)Services.GetService(typeof(IMcpServerSampling)) + ?? NotSupportedMcpServerSampling.Instance; } internal sealed class McpServerReadResourceContext : IMcpServerReadResourceContext diff --git a/src/DotNetCampus.ModelContextProtocol/Servers/McpProtocolBridge.cs b/src/DotNetCampus.ModelContextProtocol/Servers/McpProtocolBridge.cs index 19c4a04..24e6eea 100644 --- a/src/DotNetCampus.ModelContextProtocol/Servers/McpProtocolBridge.cs +++ b/src/DotNetCampus.ModelContextProtocol/Servers/McpProtocolBridge.cs @@ -1,4 +1,4 @@ -using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization.Metadata; using DotNetCampus.ModelContextProtocol.CompilerServices; @@ -122,28 +122,28 @@ private async ValueTask HandleRequestCoreAsync( }, }, Initialize => await HandleRequestAsync(request, services, context.Handlers.HandleInitializeAsync, - McpServerRequestJsonContext.Default.InitializeRequestParams, McpServerResponseJsonContext.Default.InitializeResult, + McpInternalJsonContext.Default.InitializeRequestParams, McpInternalJsonContext.Default.InitializeResult, cancellationToken), Ping => await HandleRequestAsync(request, services, context.Handlers.HandlePingAsync, - McpServerRequestJsonContext.Default.PingRequestParams, McpServerResponseJsonContext.Default.EmptyObject, + McpInternalJsonContext.Default.PingRequestParams, McpInternalJsonContext.Default.EmptyObject, cancellationToken), LoggingSetLevel => await HandleRequestAsync(request, services, context.Handlers.HandleSetLoggingLevelAsync, - McpServerRequestJsonContext.Default.SetLevelRequestParams, McpServerResponseJsonContext.Default.EmptyObject, + McpInternalJsonContext.Default.SetLevelRequestParams, McpInternalJsonContext.Default.EmptyObject, cancellationToken), ToolsList => await HandleRequestAsync(request, services, context.Handlers.HandleListToolsAsync, - McpServerRequestJsonContext.Default.ListToolsRequestParams, McpServerResponseJsonContext.Default.ListToolsResult, + McpInternalJsonContext.Default.ListToolsRequestParams, McpInternalJsonContext.Default.ListToolsResult, cancellationToken), ToolsCall => await HandleRequestAsync(request, services, context.Handlers.HandleCallToolAsync, - McpServerRequestJsonContext.Default.CallToolRequestParams, McpServerResponseJsonContext.Default.CallToolResult, + McpInternalJsonContext.Default.CallToolRequestParams, McpInternalJsonContext.Default.CallToolResult, cancellationToken), ResourcesList => await HandleRequestAsync(request, services, context.Handlers.HandleListResourcesAsync, - McpServerRequestJsonContext.Default.ListResourcesRequestParams, McpServerResponseJsonContext.Default.ListResourcesResult, + McpInternalJsonContext.Default.ListResourcesRequestParams, McpInternalJsonContext.Default.ListResourcesResult, cancellationToken), ResourcesTemplatesList => await HandleRequestAsync(request, services, context.Handlers.HandleListResourceTemplatesAsync, - McpServerRequestJsonContext.Default.ListResourceTemplatesRequestParams, McpServerResponseJsonContext.Default.ListResourceTemplatesResult, + McpInternalJsonContext.Default.ListResourceTemplatesRequestParams, McpInternalJsonContext.Default.ListResourceTemplatesResult, cancellationToken), ResourcesRead => await HandleRequestAsync(request, services, context.Handlers.HandleReadResourceAsync, - McpServerRequestJsonContext.Default.ReadResourceRequestParams, McpServerResponseJsonContext.Default.ReadResourceResult, + McpInternalJsonContext.Default.ReadResourceRequestParams, McpInternalJsonContext.Default.ReadResourceResult, cancellationToken), _ => new JsonRpcResponse { diff --git a/src/DotNetCampus.ModelContextProtocol/Servers/McpServerBuilder.cs b/src/DotNetCampus.ModelContextProtocol/Servers/McpServerBuilder.cs index 7f9ca73..02f1c18 100644 --- a/src/DotNetCampus.ModelContextProtocol/Servers/McpServerBuilder.cs +++ b/src/DotNetCampus.ModelContextProtocol/Servers/McpServerBuilder.cs @@ -21,6 +21,7 @@ public class McpServerBuilder(string serverName, string serverVersion) private readonly McpServerToolsProvider _tools = new(); private readonly McpServerResourcesProvider _resources = new(); private IMcpLogger? _logger; + private McpTransportRawMessageLoggingDetailLevel _rawMessageLoggingDetailLevel = McpTransportRawMessageLoggingDetailLevel.None; private IMcpServerToolJsonSerializer? _jsonSerializer; private string? _jsonSerializerTypeName; private IServiceProvider? _serviceProvider; @@ -42,7 +43,7 @@ public McpServerBuilder WithStdio() /// MCP 服务器将监听 http://localhost:{port} 上的请求。 /// /// MCP 服务器将监听的路由端点,例如指定为 mcp 时,完整的 URL 为 http://localhost:{port}/mcp。
- /// 所有的 MCP 请求都将发送到该端点;除非客户端使用旧版本(2024-11-05)的 SSE 协议传输时,会自动改为使用 /mcp/sse 端点。
+ /// 所有的 MCP 请求都将发送到该端点。
/// 如果不指定,会使用默认的 /mcp 端点;如果希望监听根路径,请指定为空字符串 ""。 /// /// 用于链式调用的 MCP 服务器生成器。 @@ -96,6 +97,19 @@ public McpServerBuilder WithLogger(IMcpLogger logger) return this; } + /// + /// 配置 MCP 服务器的日志记录器。 + /// + /// 日志记录器。 + /// 传输层原始消息的日志记录详细级别。 + /// 用于链式调用的 MCP 服务器生成器。 + public McpServerBuilder WithLogger(IMcpLogger logger, McpTransportRawMessageLoggingDetailLevel rawMessageLoggingDetailLevel) + { + _logger = logger; + _rawMessageLoggingDetailLevel = rawMessageLoggingDetailLevel; + return this; + } + /// /// 配置自定义的 JSON 序列化上下文。 /// @@ -187,7 +201,10 @@ public McpServer Build() context.Handlers = _requestHandlers is { } requestHandlers ? requestHandlers(server) : new McpServerRequestHandlers(server); - var transportManager = new ServerTransportManager(server, context); + var transportManager = new ServerTransportManager(server, context) + { + RawMessageLoggingDetailLevel = _rawMessageLoggingDetailLevel, + }; context.Transport = transportManager; foreach (var factory in _transportFactories) { diff --git a/src/DotNetCampus.ModelContextProtocol/Servers/McpServerRequestHandlers.cs b/src/DotNetCampus.ModelContextProtocol/Servers/McpServerRequestHandlers.cs index e311b93..5ee6d3e 100644 --- a/src/DotNetCampus.ModelContextProtocol/Servers/McpServerRequestHandlers.cs +++ b/src/DotNetCampus.ModelContextProtocol/Servers/McpServerRequestHandlers.cs @@ -62,7 +62,16 @@ public virtual ValueTask InitializeAsync( CancellationToken cancellationToken) { var clientInfo = request.Params?.ClientInfo; - Logger.Info($"[McpServer][Mcp] Client initializing. ClientName={clientInfo?.Name}, ClientVersion={clientInfo?.Version}, ProtocolVersion={request.Params?.ProtocolVersion}"); + Logger.Info( + $"[McpServer][Mcp] Client initializing. ClientName={clientInfo?.Name}, ClientVersion={clientInfo?.Version}, ProtocolVersion={request.Params?.ProtocolVersion}"); + + // 将客户端能力保存到当前传输层会话,以便后续服务器发起请求(如 sampling)时判断能力。 + var session = (DotNetCampus.ModelContextProtocol.Transports.IServerTransportSession?)request.Services.GetService( + typeof(DotNetCampus.ModelContextProtocol.Transports.IServerTransportSession)); + if (session is not null && request.Params?.Capabilities is { } capabilities) + { + session.ConnectedClientCapabilities = capabilities; + } var hasTools = _server.Tools.Count > 0; var hasResources = _server.Resources.Count > 0; @@ -89,7 +98,8 @@ public virtual ValueTask InitializeAsync( }, }; - Logger.Info($"[McpServer][Mcp] Server initialized. ServerName={_server.ServerName}, ServerVersion={_server.ServerVersion}, ToolCount={_server.Tools.Count}, ResourceCount={_server.Resources.Count}"); + Logger.Info( + $"[McpServer][Mcp] Server initialized. ServerName={_server.ServerName}, ServerVersion={_server.ServerVersion}, ToolCount={_server.Tools.Count}, ResourceCount={_server.Resources.Count}"); return ValueTask.FromResult(result); } @@ -334,6 +344,12 @@ public virtual async ValueTask CallToolAsync( Logger.Warn($"[McpServer][Mcp] Tool call failed. ToolName={toolName}, Arguments={rawRequest.Params}, Error={ex.Message}"); return CallToolResult.FromException(ex); } + catch (McpClientException ex) + { + // 此错误来自 MCP 客户端(例如工具调用过程中,服务端反向发起了请求,但客户端未能正确响应请求)。 + Logger.Warn($"[McpServer][Mcp] Tool call failed: Client error. ToolName={toolName}, Arguments={rawRequest.Params}, Error={ex.Message}"); + return CallToolResult.FromException(ex); + } catch (Exception ex) { // 其他未知错误。 diff --git a/src/DotNetCampus.ModelContextProtocol/Servers/McpServerSampling.cs b/src/DotNetCampus.ModelContextProtocol/Servers/McpServerSampling.cs new file mode 100644 index 0000000..90b89f7 --- /dev/null +++ b/src/DotNetCampus.ModelContextProtocol/Servers/McpServerSampling.cs @@ -0,0 +1,160 @@ +using System.Text.Json; +using DotNetCampus.ModelContextProtocol.CompilerServices; +using DotNetCampus.ModelContextProtocol.Exceptions; +using DotNetCampus.ModelContextProtocol.Hosting.Logging; +using DotNetCampus.ModelContextProtocol.Protocol; +using DotNetCampus.ModelContextProtocol.Protocol.Messages; +using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; +using DotNetCampus.ModelContextProtocol.Transports; +using DotNetCampus.ModelContextProtocol.Utils; + +namespace DotNetCampus.ModelContextProtocol.Servers; + +/// +/// 提供服务器主动向客户端发起 Sampling(AI 采样)请求的能力。 +/// +public interface IMcpServerSampling +{ + /// + /// 指示连接的客户端是否声明了对 Sampling 的支持。
+ /// 在调用 前应检查此属性;若为 ,调用将抛出 。 + ///
+ bool IsSupported { get; } + + /// + /// 向客户端发送 sampling/createMessage 请求,通过客户端对 LLM 进行采样。 + /// + /// 采样请求参数。 + /// 取消令牌。 + /// LLM 生成的采样结果。 + /// 当客户端未声明 Sampling 能力时抛出。可通过提前判断 来避免此异常。 + /// 当采样请求被用户(人工审批)拒绝时抛出。 + Task CreateMessageAsync(CreateMessageRequestParams requestParams, CancellationToken cancellationToken = default); +} + +/// +/// 的扩展方法,提供便捷的文本采样接口。 +/// +public static class McpServerSamplingExtensions +{ + /// + /// 向客户端发送简单的纯文本采样请求。 + /// + /// 采样服务实例。 + /// 用户消息内容。 + /// 最大生成令牌数。 + /// 可选的系统提示词。 + /// 取消令牌。 + /// LLM 生成的采样结果。 + /// 当客户端未声明 Sampling 能力时抛出。 + /// 当采样请求被用户拒绝时抛出。 + public static Task CreateMessageAsync( + this IMcpServerSampling sampling, + string userMessage, + int maxTokens = 1024, + string? systemPrompt = null, + CancellationToken cancellationToken = default) + { + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new TextContentBlock { Text = userMessage }, + }, + ], + MaxTokens = maxTokens, + SystemPrompt = systemPrompt, + }; + return sampling.CreateMessageAsync(requestParams, cancellationToken); + } +} + +/// +/// 的内部实现,通过关联的传输层会话与客户端通信。 +/// +internal sealed class McpServerSampling(IServerTransportSession session, IMcpLogger logger) : IMcpServerSampling +{ + /// + public bool IsSupported => session.ConnectedClientCapabilities?.Sampling is not null; + + /// + public async Task CreateMessageAsync( + CreateMessageRequestParams requestParams, + CancellationToken cancellationToken = default) + { + if (!IsSupported) + { + throw new McpSamplingNotSupportedException(); + } + + var request = new JsonRpcRequest + { + Id = RequestId.MakeNew().ToJsonElement(), + Method = RequestMethods.SamplingCreateMessage, + Params = JsonSerializer.SerializeToElement(requestParams, McpInternalJsonContext.Default.CreateMessageRequestParams), + }; + + logger.Debug($"[McpServer][Mcp] Sending sampling/createMessage request. Id={request.Id}"); + + var response = await session.SendRequestAsync(request, cancellationToken).ConfigureAwait(false); + + if (response.Error is { } error) + { + // 根据 MCP 规范,用户拒绝审批时客户端应返回错误响应。 + // JSON-RPC 保留错误码范围为 -32768 到 -32000;任何高于 -32000 的错误码(如 -1) + // 表示用户自定义错误,通常意味着用户主动拒绝了采样请求。 + // + // 兼容性说明(MCP Inspector 的不规范行为): + // MCP Inspector 拒绝采样时,调用的是 reject(new Error("Sampling request rejected")), + // 传入的是普通 JavaScript Error,没有 code 属性。 + // TypeScript SDK 在将 handler rejection 转换为 JSON-RPC 错误时, + // 其逻辑为:code = Number.isSafeInteger(error['code']) ? error['code'] : ErrorCode.InternalError + // 即当 error 无 code 时回退到 -32603(InternalError),而非规范要求的 -1。 + // 因此需额外检查 -32603 + 消息关键字来识别这类不规范的拒绝响应, + // 同时避免将真正的服务端内部错误误判为用户拒绝。 + var isRejectedByUser = error.Code > -32000 + || (error.Code == -32603 && error.Message.Contains("reject", StringComparison.OrdinalIgnoreCase)); + if (isRejectedByUser) + { + logger.Warn($"[McpServer][Mcp] Sampling/createMessage rejected by user. Id={request.Id}, Code={error.Code}, Message={error.Message}"); + throw new McpSamplingRejectedException(error.Code, error.Message); + } + + logger.Error($"[McpServer][Mcp] Sampling/createMessage failed. Id={request.Id}, Code={error.Code}, Message={error.Message}"); + throw new McpClientException($"Sampling request failed: [{error.Code}] {error.Message}"); + } + + if (response.Result is not { } resultElement) + { + throw new McpClientException("Sampling response missing result."); + } + + logger.Debug($"[McpServer][Mcp] Sampling/createMessage succeeded. Id={request.Id}"); + + return resultElement.Deserialize(McpInternalJsonContext.Default.CreateMessageResult) + ?? throw new McpClientException("Failed to deserialize sampling result."); + } +} + +/// +/// 当传输层或客户端不支持 Sampling 时,用于占位的空对象实现。 +/// +internal sealed class NotSupportedMcpServerSampling : IMcpServerSampling +{ + /// + /// 获取全局单例实例。 + /// + public static readonly NotSupportedMcpServerSampling Instance = new(); + + private NotSupportedMcpServerSampling() { } + + /// + public bool IsSupported => false; + + /// + public Task CreateMessageAsync(CreateMessageRequestParams requestParams, CancellationToken cancellationToken = default) + => throw new McpSamplingNotSupportedException(); +} diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/ClientTransportManager.cs b/src/DotNetCampus.ModelContextProtocol/Transports/ClientTransportManager.cs index 0b38213..7eaea4a 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/ClientTransportManager.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/ClientTransportManager.cs @@ -3,6 +3,7 @@ using DotNetCampus.ModelContextProtocol.Clients; using DotNetCampus.ModelContextProtocol.CompilerServices; using DotNetCampus.ModelContextProtocol.Exceptions; +using DotNetCampus.ModelContextProtocol.Hosting.Logging; using DotNetCampus.ModelContextProtocol.Protocol; using DotNetCampus.ModelContextProtocol.Protocol.Messages; using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; @@ -13,14 +14,21 @@ namespace DotNetCampus.ModelContextProtocol.Transports; /// /// 用于管理 MCP 客户端传输层的管理器。 /// -internal class ClientTransportManager(IClientTransportContext context) : IClientTransportManager +internal class ClientTransportManager(IClientTransportContext context) : IClientTransportManager, IMcpTransportLogger { private readonly ConcurrentDictionary> _pendingRequests = []; private IClientTransport? _transport; + private Func>? _samplingHandler; /// public IClientTransportContext Context { get; } = context; + /// + public IMcpLogger Logger => Context.Logger; + + /// + public McpTransportRawMessageLoggingDetailLevel RawMessageLoggingDetailLevel { get; init; } + /// /// 设置传输层实例。 /// @@ -29,31 +37,65 @@ internal void SetTransport(IClientTransport transport) _transport = transport; } + /// + /// 设置 Sampling 请求处理器,供服务器主动发起 sampling/createMessage 请求时调用。 + /// + internal void SetSamplingHandler(Func> handler) + { + _samplingHandler = handler; + } + /// public RequestId MakeNewRequestId() { return RequestId.MakeNew(); } + /// + public ValueTask ReadMessageAsync(string messageLine) + { + var message = JsonElement.Parse(messageLine); + return ValueTask.FromResult(ClassifyAndDeserialize(message)); + } + + /// + /// 根据 JSON-RPC 2.0 字段特征将 分类并反序列化为具体消息类型。 + /// + private static JsonRpcMessage? ClassifyAndDeserialize(JsonElement element) + { + if (element.TryGetProperty("method", out _)) + { + return element.Deserialize(McpInternalJsonContext.Default.JsonRpcRequest); + } + + if (element.TryGetProperty("result", out _) || element.TryGetProperty("error", out _)) + { + return element.Deserialize(McpInternalJsonContext.Default.JsonRpcResponse); + } + + return null; + } + /// public ValueTask ReadResponseAsync(string responseLine) { - var message = JsonSerializer.Deserialize(responseLine, McpServerResponseJsonContext.Default.JsonRpcResponse); + var message = JsonSerializer.Deserialize(responseLine, McpInternalJsonContext.Default.JsonRpcResponse); return ValueTask.FromResult(message); } /// public ValueTask ReadResponseAsync(Stream responseStream) { - var message = JsonSerializer.Deserialize(responseStream, McpServerResponseJsonContext.Default.JsonRpcResponse); + var message = JsonSerializer.Deserialize(responseStream, McpInternalJsonContext.Default.JsonRpcResponse); return ValueTask.FromResult(message); } /// public string WriteMessageAsync(JsonRpcMessage message) => message switch { - JsonRpcRequest request => JsonSerializer.Serialize(request, McpServerRequestJsonContext.Default.JsonRpcRequest), - JsonRpcNotification notification => JsonSerializer.Serialize(notification, McpServerRequestJsonContext.Default.JsonRpcNotification), + JsonRpcRequest request => JsonSerializer.Serialize(request, McpInternalJsonContext.Default.JsonRpcRequest), + JsonRpcResponse response => JsonSerializer.Serialize(response, McpInternalJsonContext.Default.JsonRpcResponse), + JsonRpcNotification notification => JsonSerializer.Serialize(notification, McpInternalJsonContext.Default.JsonRpcNotification), _ => throw new ArgumentException($"不支持的消息类型:{message.GetType().FullName}."), }; @@ -63,9 +105,11 @@ public async ValueTask WriteMessageAsync(Stream requestStream, JsonRpcMessage me await (message switch { JsonRpcRequest request => JsonSerializer.SerializeAsync( - requestStream, request, McpServerRequestJsonContext.Default.JsonRpcRequest, cancellationToken), + requestStream, request, McpInternalJsonContext.Default.JsonRpcRequest, cancellationToken), + JsonRpcResponse response => JsonSerializer.SerializeAsync( + requestStream, response, McpInternalJsonContext.Default.JsonRpcResponse, cancellationToken), JsonRpcNotification notification => JsonSerializer.SerializeAsync( - requestStream, notification, McpServerRequestJsonContext.Default.JsonRpcNotification, cancellationToken), + requestStream, notification, McpInternalJsonContext.Default.JsonRpcNotification, cancellationToken), _ => throw new ArgumentException($"不支持的消息类型:{message.GetType().FullName}."), }); } @@ -82,12 +126,79 @@ public ValueTask HandleRespondAsync(JsonRpcResponse response, CancellationToken if (_pendingRequests.TryRemove(id, out var tcs)) { + Context.Logger.Debug($"[McpClient][Mcp] Response matched to pending request. Id={id}"); tcs.SetResult(response); } + else + { + Context.Logger.Warn($"[McpClient][Mcp] Received unmatched response. Id={id}"); + } return ValueTask.CompletedTask; } + /// + public async ValueTask HandleServerRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + { + if (request.Id is null) + { + // JSON-RPC 2.0 规定:通知(notification)没有 id,不应发送响应。 + return; + } + + Context.Logger.Info($"[McpClient][Mcp] Received server-initiated request. Method={request.Method}, Id={request.Id}"); + + JsonRpcResponse response; + + if (request.Method == RequestMethods.SamplingCreateMessage && _samplingHandler is { } handler) + { + try + { + CreateMessageRequestParams? requestParams = null; + if (request.Params is { } paramsElement) + { + requestParams = paramsElement.Deserialize(McpInternalJsonContext.Default.CreateMessageRequestParams); + } + requestParams ??= new CreateMessageRequestParams { Messages = [], MaxTokens = 1024 }; + + var result = await handler(requestParams, cancellationToken).ConfigureAwait(false); + Context.Logger.Debug($"[McpClient][Mcp] Sampling request handled successfully. Id={request.Id}"); + response = new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToElement(result, McpInternalJsonContext.Default.CreateMessageResult), + }; + } + catch (Exception ex) + { + Context.Logger.Error($"[McpClient][Mcp] Sampling request handler threw exception. Id={request.Id}, Error={ex.Message}"); + response = new JsonRpcResponse + { + Id = request.Id, + Error = new JsonRpcError + { + Code = (int)JsonRpcErrorCode.InternalError, + Message = ex.Message, + }, + }; + } + } + else + { + Context.Logger.Warn($"[McpClient][Mcp] Unsupported server-initiated request method. Method={request.Method}, Id={request.Id}"); + response = new JsonRpcResponse + { + Id = request.Id, + Error = new JsonRpcError + { + Code = (int)JsonRpcErrorCode.MethodNotFound, + Message = $"Method '{request.Method}' not found or no handler registered.", + }, + }; + } + + await SendMessageAsync(response, cancellationToken).ConfigureAwait(false); + } /// /// 发送请求并等待响应。 /// @@ -107,7 +218,18 @@ public async ValueTask SendRequestAsync(JsonRpcRequest request, } var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _pendingRequests[id] = tcs; + if (!_pendingRequests.TryAdd(id, tcs)) + { + throw new InvalidOperationException($"已存在相同 ID 的挂起请求:{id}。"); + } + + using var registration = cancellationToken.Register(() => + { + if (_pendingRequests.TryRemove(id, out var removed)) + { + removed.TrySetCanceled(cancellationToken); + } + }); try { @@ -159,7 +281,7 @@ public async ValueTask ConnectAndInitializeAsync(McpClient cli Version = client.ClientVersion, }, Capabilities = client.Capabilities, - }, McpServerRequestJsonContext.Default.InitializeRequestParams), + }, McpInternalJsonContext.Default.InitializeRequestParams), }; var response = await SendRequestAsync(request, cancellationToken).ConfigureAwait(false); @@ -174,7 +296,7 @@ public async ValueTask ConnectAndInitializeAsync(McpClient cli throw new McpClientException("初始化响应格式不正确"); } - var result = responseResult.Deserialize(McpServerResponseJsonContext.Default.InitializeResult) + var result = responseResult.Deserialize(McpInternalJsonContext.Default.InitializeResult) ?? throw new McpClientException("无法解析初始化响应"); // 发送 initialized 通知。 diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/Http/HttpClientTransport.cs b/src/DotNetCampus.ModelContextProtocol/Transports/Http/HttpClientTransport.cs index 6f24321..4a29d94 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/Http/HttpClientTransport.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/Http/HttpClientTransport.cs @@ -1,4 +1,4 @@ -using System.Net.Http.Headers; +using System.Net.Http.Headers; using System.Text; using System.Text.Json; using DotNetCampus.ModelContextProtocol.Hosting.Logging; @@ -151,7 +151,7 @@ private async ValueTask SendRequestCoreAsync(JsonRpcMessage message, Cancellatio content.Headers.ContentType = new MediaTypeHeaderValue("application/json"); request.Content = content; - _logger.Debug($"[McpClient][Http] Sending POST request. Url={requestUrl}, Type={(isInitialize ? "Initialize" : message.GetType().Name)}"); + _manager.LogRawOut("[Http]", $"POST, SessionId={_sessionId}", jsonContent); // 4. 发送请求 (ResponseHeadersRead 以支持流式响应) var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); @@ -181,10 +181,10 @@ private async ValueTask SendRequestCoreAsync(JsonRpcMessage message, Cancellatio _logger.Debug($"[McpClient][Http] Received SSE stream response."); await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken); - await ProcessSseStreamAsync(stream, cancellationToken, isInitialize - ? (json) => + await ProcessSseStreamAsync(stream, cancellationToken, $"POST/sse, SessionId={_sessionId ?? "init"}", isInitialize + ? (msg) => { - if (json.TryGetProperty("result", out var resultElement)) + if (msg is JsonRpcResponse { Result: { ValueKind: JsonValueKind.Object } resultElement }) { _protocolVersion = TryExtractProtocolVersion(resultElement, "SSE"); } @@ -212,6 +212,7 @@ await ProcessSseStreamAsync(stream, cancellationToken, isInitialize _protocolVersion = TryExtractProtocolVersion(resultElement, "POST"); } + _manager.LogRawIn("[Http]", $"POST/json, SessionId={_sessionId ?? "init"}", rpcResponse); await _manager.HandleRespondAsync(rpcResponse, cancellationToken); } } @@ -304,7 +305,7 @@ private async Task ReceiveLoopAsync(CancellationToken token) } await using var stream = await response.Content.ReadAsStreamAsync(token); - await ProcessSseStreamAsync(stream, token); + await ProcessSseStreamAsync(stream, token, $"GET/sse, SessionId={_sessionId}"); } _logger.Info($"[McpClient][Http] SSE stream ended, reconnecting."); } @@ -331,7 +332,7 @@ private async Task ReceiveLoopAsync(CancellationToken token) // --- SSE 解析核心逻辑 --- - private async Task ProcessSseStreamAsync(Stream stream, CancellationToken token, Action? messageInspector = null) + private async Task ProcessSseStreamAsync(Stream stream, CancellationToken token, string channel, Action? messageInspector = null) { using var reader = new StreamReader(stream, Encoding.UTF8, leaveOpen: true); @@ -351,7 +352,7 @@ private async Task ProcessSseStreamAsync(Stream stream, CancellationToken token, { if (dataBuffer.Length > 0) { - await DispatchSseEventAsync(currentEvent, dataBuffer.ToString(), token, messageInspector); + await DispatchSseEventAsync(currentEvent, dataBuffer.ToString(), token, channel, messageInspector); dataBuffer.Clear(); currentEvent = null; } @@ -372,33 +373,45 @@ private async Task ProcessSseStreamAsync(Stream stream, CancellationToken token, } } - private async Task DispatchSseEventAsync(string? eventName, string data, CancellationToken token, Action? messageInspector) + private async Task DispatchSseEventAsync(string? eventName, string data, CancellationToken token, string channel, Action? messageInspector) { if (string.IsNullOrEmpty(data) || data == "[DONE]") return; if (string.IsNullOrEmpty(eventName) || string.Equals(eventName, "message", StringComparison.OrdinalIgnoreCase)) { + _manager.LogRawIn("[Http]", channel, data); + try { - // 先尝试用 JsonDocument 解析来执行检查器,因为 _manager.ReadResponseAsync 会直接反序列化为对象, - // 而我们需要 inspect 具体字段(如 protocolVersion) - if (messageInspector != null) + // 一次解析即可分类:有 method → 服务器主动请求;有 result/error → 对客户端请求的响应。 + JsonRpcMessage? message; + try { - try - { - using var doc = JsonDocument.Parse(data); - messageInspector(doc.RootElement); - } - catch - { - // 忽略解析错误,后续 _manager 会处理 - } + message = await _manager.ReadMessageAsync(data); + } + catch + { + _logger.Warn($"[McpClient][Http] Failed to parse SSE message."); + return; + } + + // 传入 messageInspector(如需检查初始化响应中的协议版本等字段) + if (message is not null) + { + messageInspector?.Invoke(message); } - var response = await _manager.ReadResponseAsync(data); - if (response != null) + switch (message) { - await _manager.HandleRespondAsync(response, token); + case JsonRpcRequest request: + await _manager.HandleServerRequestAsync(request, token); + break; + case JsonRpcResponse response: + await _manager.HandleRespondAsync(response, token); + break; + default: + _logger.Warn($"[McpClient][Http] Unrecognized SSE message received."); + break; } } catch (Exception ex) @@ -407,4 +420,5 @@ private async Task DispatchSseEventAsync(string? eventName, string data, Cancell } } } + } diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/Http/HttpServerTransportSession.cs b/src/DotNetCampus.ModelContextProtocol/Transports/Http/HttpServerTransportSession.cs new file mode 100644 index 0000000..239b48c --- /dev/null +++ b/src/DotNetCampus.ModelContextProtocol/Transports/Http/HttpServerTransportSession.cs @@ -0,0 +1,147 @@ +using DotNetCampus.ModelContextProtocol.Hosting.Logging; +using DotNetCampus.ModelContextProtocol.Protocol.Messages; +using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; + +namespace DotNetCampus.ModelContextProtocol.Transports.Http; + +/// +/// Streamable HTTP 传输层的一个会话。 +/// 同时被 和 TouchSocket HTTP 传输层使用。 +/// +public class HttpServerTransportSession : ServerTransportSession +{ + private static readonly ReadOnlyMemory EventMessageBytes = "event: message\n"u8.ToArray(); + private static readonly ReadOnlyMemory DataPrefixBytes = "data: "u8.ToArray(); + private static readonly ReadOnlyMemory NewLineBytes = "\n"u8.ToArray(); + + private readonly IServerTransportManager _manager; + private readonly string _logPrefix; + private readonly CancellationTokenSource _disposeCts = new(); + private readonly SemaphoreSlim _writeLock = new(1, 1); + + /// + /// 当前 POST 请求绑定的 SSE 输出流,使用 AsyncLocal 确保每个异步执行上下文(即每个并发 POST 请求) + /// 都拥有独立的值,避免多个并发请求相互覆盖导致竞态条件。 + /// 非 null 时,SendRequestAsync 直接向此流写入采样请求。 + /// + private static readonly AsyncLocal _currentRequestSseStream = new(); + + private IMcpLogger Log => _manager.Context.Logger; + + /// + public override string SessionId { get; } + + /// + /// 初始化 类的新实例。 + /// + /// 辅助管理 MCP 传输层的管理器。 + /// 唯一标识此会话的 ID。 + /// 日志前缀,用于区分不同传输层实现(如 "[McpServer][StreamableHttp]")。 + public HttpServerTransportSession(IServerTransportManager manager, string sessionId, string logPrefix) + { + _manager = manager; + SessionId = sessionId; + _logPrefix = logPrefix; + } + + /// + /// 将当前 POST 请求的 SSE 输出流绑定到此会话。 + /// 返回的 Dispose 后自动清除绑定(在 POST 请求处理完成后由 Transport 调用)。 + /// + public IDisposable SetRequestSseStream(Stream stream) + { + _currentRequestSseStream.Value = stream; + return new SseStreamScope(this, stream); + } + + private void ClearRequestSseStream(Stream stream) + { + if (_currentRequestSseStream.Value == stream) + { + _currentRequestSseStream.Value = null; + } + } + + /// + protected override async Task SendRequestMessageAsync(JsonRpcRequest request, CancellationToken cancellationToken) + { + var stream = _currentRequestSseStream.Value + ?? throw new InvalidOperationException("当前没有绑定的 SSE 流,无法发送服务端主动请求。"); + + Log.Debug($"{_logPrefix} Sending server-initiated request. Method={request.Method}, Id={request.Id}, SessionId={SessionId}"); + + // 直接写入当前 POST 请求的 SSE 流,不经过 Channel。 + await WriteSseMessageAsync(stream, request, cancellationToken).ConfigureAwait(false); + } + + /// + protected override void OnResponseReceived(string id, JsonRpcResponse response) + { + Log.Debug($"{_logPrefix} Received client response for pending request. Id={id}, SessionId={SessionId}"); + } + + /// + protected override void OnUnmatchedResponse(string id, JsonRpcResponse response) + { + Log.Warn($"{_logPrefix} Received unmatched client response. Id={id}, SessionId={SessionId}"); + } + + /// + /// 将一条 JSON-RPC 消息写入 SSE 流。 + /// + public async Task WriteSseMessageAsync(Stream stream, JsonRpcMessage message, CancellationToken ct) + { + await _writeLock.WaitAsync(ct).ConfigureAwait(false); + try + { + // event: message + await stream.WriteAsync(EventMessageBytes, ct); + + // data: ... + await stream.WriteAsync(DataPrefixBytes, ct); + + // Serialize + await _manager.WriteMessageAsync(stream, message, ct); + _manager.LogRawOut(_logPrefix, $"POST/sse, SessionId={SessionId}", message); + + // \n\n (End of event) + await stream.WriteAsync(NewLineBytes, ct); + await stream.WriteAsync(NewLineBytes, ct); + + await stream.FlushAsync(ct); + } + catch (Exception ex) + { + Log.Error($"{_logPrefix} Failed to write SSE message. SessionId={SessionId}", ex); + throw; + } + finally + { + _writeLock.Release(); + } + } + + /// + public override async ValueTask DisposeAsync() + { + if (_disposeCts.IsCancellationRequested) + { + return; + } + +#if NET8_0_OR_GREATER + await _disposeCts.CancelAsync(); +#else + await Task.Yield(); + _disposeCts.Cancel(); +#endif + CancelAllPendingRequests(); + _disposeCts.Dispose(); + _writeLock.Dispose(); + } + + private sealed class SseStreamScope(HttpServerTransportSession session, Stream stream) : IDisposable + { + public void Dispose() => session.ClearRequestSseStream(stream); + } +} diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/Http/LocalHostHttpServerTransport.cs b/src/DotNetCampus.ModelContextProtocol/Transports/Http/LocalHostHttpServerTransport.cs index 651fbc5..8df41f6 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/Http/LocalHostHttpServerTransport.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/Http/LocalHostHttpServerTransport.cs @@ -1,8 +1,9 @@ -using System.Collections.Concurrent; +using System.Collections.Concurrent; using System.Net; using System.Text; using System.Text.Json; using DotNetCampus.ModelContextProtocol.Hosting.Logging; +using DotNetCampus.ModelContextProtocol.Hosting.Services; using DotNetCampus.ModelContextProtocol.Protocol; using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; @@ -15,12 +16,14 @@ public class LocalHostHttpServerTransport : IServerTransport { private const string ProtocolVersionHeader = "MCP-Protocol-Version"; private const string SessionIdHeader = "Mcp-Session-Id"; + private const int SseKeepAliveIntervalMs = 60000; private static readonly ReadOnlyMemory PrimeEventBytes = ": \n\n"u8.ToArray(); + private static readonly ReadOnlyMemory SseKeepAliveBytes = ": keep-alive\n\n"u8.ToArray(); private readonly IServerTransportManager _manager; private readonly LocalHostHttpServerTransportOptions _options; private readonly HttpListener _listener = new(); - private readonly ConcurrentDictionary _sessions = new(); + private readonly ConcurrentDictionary _sessions = new(); /// /// 初始化 类的新实例。 @@ -174,94 +177,218 @@ private async Task HandlePostRequestAsync(HttpListenerContext context, Cancellat // 协议版本检查 var protocolVersion = request.Headers[ProtocolVersionHeader]; - if (!string.IsNullOrEmpty(protocolVersion)) + if (!string.IsNullOrEmpty(protocolVersion) && protocolVersion < ProtocolVersion.Minimum) { - // 如果比最小版本小则报错 - if (protocolVersion < ProtocolVersion.Minimum) - { - await context.RespondHttpError(HttpStatusCode.BadRequest, $"Unsupported protocol version. Minimum required: {ProtocolVersion.Minimum}"); - return; - } + await context.RespondHttpError(HttpStatusCode.BadRequest, $"Unsupported protocol version. Minimum required: {ProtocolVersion.Minimum}"); + return; } - JsonRpcRequest? jsonRpcRequest; + // 解析消息体 + JsonRpcMessage? message; try { - jsonRpcRequest = await _manager.ReadRequestAsync(request.InputStream); + message = await _manager.ReadMessageAsync(request.InputStream); } catch (JsonException) { await context.RespondHttpError(HttpStatusCode.BadRequest, "Invalid JSON"); return; } - - if (jsonRpcRequest == null) + catch { - await context.RespondHttpError(HttpStatusCode.BadRequest, "Empty body"); + await context.RespondHttpError(HttpStatusCode.BadRequest, "Failed to read request body"); return; } - var isInitialize = jsonRpcRequest.Method == RequestMethods.Initialize; var sessionIdStr = request.Headers[SessionIdHeader]; - LocalHostHttpServerTransportSession? session; - if (isInitialize) + if (message is not null) { - // 初始化请求,创建新 Session - var newSessionId = _manager.MakeNewSessionId(); - var newSession = new LocalHostHttpServerTransportSession(_manager, newSessionId.Id); - - if (_sessions.TryAdd(newSessionId.Id, newSession)) - { - session = newSession; - _manager.Add(session); - context.Response.AppendHeader(SessionIdHeader, newSessionId.Id); - } - else - { - await context.RespondHttpError(HttpStatusCode.InternalServerError, "Session ID collision"); + _manager.LogRawIn("[StreamableHttp]", $"POST, SessionId={sessionIdStr}", message); + } + switch (message) + { + case JsonRpcResponse jsonRpcResponse: + await HandleClientResponseAsync(context, sessionIdStr, jsonRpcResponse); return; - } + case JsonRpcNotification notification: + await HandleNotificationAsync(context, sessionIdStr, notification, cancellationToken); + return; + case JsonRpcRequest jsonRpcRequest: + await HandleRpcRequestAsync(context, sessionIdStr, jsonRpcRequest, cancellationToken); + return; + default: + await context.RespondHttpError(HttpStatusCode.BadRequest, "Invalid or unrecognized JSON-RPC message"); + return; + } + } + + /// + /// 客户端响应服务器发起的请求(如 sampling/createMessage)。 + /// + /// + /// + /// + private async Task HandleClientResponseAsync(HttpListenerContext context, string? sessionIdStr, JsonRpcResponse response) + { + if (string.IsNullOrEmpty(sessionIdStr) || !_sessions.TryGetValue(sessionIdStr, out var session)) + { + await context.RespondHttpError(HttpStatusCode.NotFound, "Session not found"); + return; + } + session.HandleResponseAsync(response); + context.RespondHttpSuccess(HttpStatusCode.Accepted); + } + + /// + /// 通知消息,无需响应。 + /// + /// + /// + /// + /// + private async Task HandleNotificationAsync(HttpListenerContext context, string? sessionIdStr, JsonRpcNotification notification, + CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(sessionIdStr) || !_sessions.TryGetValue(sessionIdStr, out var session)) + { + await context.RespondHttpError(HttpStatusCode.NotFound, "Session not found"); + return; + } + await _manager.HandleRequestAsync( + new JsonRpcRequest { Method = notification.Method, Params = notification.Params }, + s => s.AddTransportSession(session, Log), + cancellationToken); + context.RespondHttpSuccess(HttpStatusCode.Accepted); + } + + /// + /// JSON-RPC 请求(包含 initialize 和普通请求两种路径)。 + /// + /// + /// + /// + /// + private async Task HandleRpcRequestAsync(HttpListenerContext context, string? sessionIdStr, JsonRpcRequest jsonRpcRequest, + CancellationToken cancellationToken) + { + var session = await GetOrCreateSessionAsync(context, sessionIdStr, jsonRpcRequest); + if (session is null) return; + + if (jsonRpcRequest.Method == RequestMethods.Initialize) + { + await HandleInitializeAsync(context, session, jsonRpcRequest, cancellationToken); } else { - if (string.IsNullOrEmpty(sessionIdStr)) - { - await context.RespondHttpError(HttpStatusCode.BadRequest, "Missing Mcp-Session-Id header"); - return; - } + await HandleSseRequestAsync(context, session, jsonRpcRequest, cancellationToken); + } + } - if (!_sessions.TryGetValue(sessionIdStr, out session)) + /// + /// 查找已有 Session 或为 initialize 请求创建新 Session。失败时向客户端写入错误响应并返回 null。 + /// + /// + /// + /// + /// + private async Task GetOrCreateSessionAsync(HttpListenerContext context, string? sessionIdStr, JsonRpcRequest jsonRpcRequest) + { + if (jsonRpcRequest.Method == RequestMethods.Initialize) + { + var newSessionId = _manager.MakeNewSessionId(); + var newSession = new HttpServerTransportSession(_manager, newSessionId.Id, "[McpServer][StreamableHttp]"); + if (_sessions.TryAdd(newSessionId.Id, newSession)) { - await context.RespondHttpError(HttpStatusCode.NotFound, "Session not found"); - return; + _manager.Add(newSession); + context.Response.AppendHeader(SessionIdHeader, newSessionId.Id); + return newSession; } + await context.RespondHttpError(HttpStatusCode.InternalServerError, "Session ID collision"); + return null; } - var jsonRpcResponse = await _manager.HandleRequestAsync(jsonRpcRequest, cancellationToken: cancellationToken); + if (string.IsNullOrEmpty(sessionIdStr)) + { + await context.RespondHttpError(HttpStatusCode.BadRequest, "Missing Mcp-Session-Id header"); + return null; + } + if (!_sessions.TryGetValue(sessionIdStr, out var session)) + { + await context.RespondHttpError(HttpStatusCode.NotFound, "Session not found"); + return null; + } + return session; + } - if (jsonRpcResponse != null) + /// + /// initialize 请求:同步返回 application/json,无需 SSE 流。 + /// + /// + /// + /// + /// + private async Task HandleInitializeAsync(HttpListenerContext context, HttpServerTransportSession session, JsonRpcRequest jsonRpcRequest, + CancellationToken cancellationToken) + { + var initResponse = await _manager.HandleRequestAsync(jsonRpcRequest, + s => s.AddTransportSession(session, Log), + cancellationToken); + + if (initResponse != null) { - // Request: Success or Failed. context.Response.ContentType = "application/json"; context.Response.StatusCode = (int)HttpStatusCode.OK; try { - await _manager.WriteMessageAsync(context.Response.OutputStream, jsonRpcResponse, cancellationToken); + _manager.LogRawOut("[StreamableHttp]", $"POST/json, SessionId={session.SessionId}", initResponse); + await _manager.WriteMessageAsync(context.Response.OutputStream, initResponse, cancellationToken); context.Response.SafeClose(); } catch { - // Ignore write errors + // 忽略写入错误 } } else { - // Notification: No need to respond. context.RespondHttpSuccess(HttpStatusCode.Accepted); } } + /// + /// 非 initialize 请求:以 text/event-stream 响应,服务端可在处理期间通过 SSE 流发起采样请求。 + /// 规范 §2.1 规则 6:"The server MAY send JSON-RPC requests and notifications before sending + /// the JSON-RPC response. These messages SHOULD relate to the originating client request." + /// + /// + /// + /// + /// + private async Task HandleSseRequestAsync(HttpListenerContext context, HttpServerTransportSession session, JsonRpcRequest jsonRpcRequest, + CancellationToken cancellationToken) + { + context.Response.StatusCode = (int)HttpStatusCode.OK; + context.Response.ContentType = "text/event-stream"; + context.Response.Headers["Cache-Control"] = "no-cache"; + + var output = context.Response.OutputStream; + await output.WriteAsync(PrimeEventBytes, cancellationToken); + await output.FlushAsync(cancellationToken); + + using var _ = session.SetRequestSseStream(output); + + var response = await _manager.HandleRequestAsync(jsonRpcRequest, + s => s.AddTransportSession(session, Log), + cancellationToken); + + if (response != null) + { + await session.WriteSseMessageAsync(output, response, cancellationToken); + } + context.Response.SafeClose(); + } + private async Task HandleGetRequestAsync(HttpListenerContext context, CancellationToken cancellationToken) { var request = context.Request; @@ -292,17 +419,32 @@ private async Task HandleGetRequestAsync(HttpListenerContext context, Cancellati context.Response.ContentType = "text/event-stream"; context.Response.Headers["Cache-Control"] = "no-cache"; + Log.Info($"[McpServer][StreamableHttp] SSE connection established. SessionId={sessionId}"); + try { var output = context.Response.OutputStream; await output.WriteAsync(PrimeEventBytes, cancellationToken); await output.FlushAsync(cancellationToken); - await session.RunSseConnectionAsync(output, cancellationToken); + // 定期发送 SSE 心跳,以便在客户端断开时通过写入/刷新失败尽快退出, + // 避免仅依赖外部 cancellationToken 导致连接长期悬挂。 + while (!cancellationToken.IsCancellationRequested) + { + await Task.Delay(SseKeepAliveIntervalMs, cancellationToken); + await output.WriteAsync(SseKeepAliveBytes, cancellationToken); + await output.FlushAsync(cancellationToken); + } + Log.Info($"[McpServer][StreamableHttp] SSE connection cancelled. SessionId={sessionId}"); + } + catch (OperationCanceledException) + { + // 正常关闭 + Log.Info($"[McpServer][StreamableHttp] SSE connection ended. SessionId={sessionId}"); } catch (Exception ex) { - Log.Debug($"[McpServer][StreamableHttp] SSE connection ended. SessionId={sessionId}, Error={ex.Message}"); + Log.Info($"[McpServer][StreamableHttp] SSE connection ended. SessionId={sessionId}, Error={ex.Message}"); } finally { diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/Http/LocalHostHttpServerTransportSession.cs b/src/DotNetCampus.ModelContextProtocol/Transports/Http/LocalHostHttpServerTransportSession.cs deleted file mode 100644 index b28c11f..0000000 --- a/src/DotNetCampus.ModelContextProtocol/Transports/Http/LocalHostHttpServerTransportSession.cs +++ /dev/null @@ -1,115 +0,0 @@ -using System.Threading.Channels; -using DotNetCampus.ModelContextProtocol.Hosting.Logging; -using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; - -namespace DotNetCampus.ModelContextProtocol.Transports.Http; - -/// -/// Streamable HTTP 传输层的一个会话。 -/// -internal class LocalHostHttpServerTransportSession : IServerTransportSession -{ - private static readonly ReadOnlyMemory EventMessageBytes = "event: message\n"u8.ToArray(); - private static readonly ReadOnlyMemory DataPrefixBytes = "data: "u8.ToArray(); - private static readonly ReadOnlyMemory NewLineBytes = "\n"u8.ToArray(); - - private readonly IServerTransportManager _manager; - private readonly Channel _outgoingMessages; - private readonly CancellationTokenSource _disposeCts = new(); - - private IMcpLogger Log => _manager.Context.Logger; - - public string SessionId { get; } - - public LocalHostHttpServerTransportSession(IServerTransportManager manager, string sessionId) - { - _manager = manager; - SessionId = sessionId; - _outgoingMessages = Channel.CreateUnbounded(new UnboundedChannelOptions - { - SingleReader = true, - SingleWriter = false, - }); - } - - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - { - if (_disposeCts.IsCancellationRequested) - { - return Task.CompletedTask; - } - return _outgoingMessages.Writer.WriteAsync(message, cancellationToken).AsTask(); - } - - public async Task RunSseConnectionAsync(Stream outputStream, CancellationToken cancellationToken) - { - using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposeCts.Token); - var ct = linkedCts.Token; - - try - { - Log.Debug($"[McpServer][StreamableHttp] SSE connection started. SessionId={SessionId}"); - - // Wait for messages and write them - await foreach (var message in _outgoingMessages.Reader.ReadAllAsync(ct)) - { - await WriteSseMessageAsync(outputStream, message, ct); - } - } - catch (OperationCanceledException) - { - // Expected on shutdown - } - catch (Exception ex) - { - Log.Warn($"[McpServer][StreamableHttp] SSE connection error. SessionId={SessionId}, Error={ex.Message}"); - } - finally - { - Log.Debug($"[McpServer][StreamableHttp] SSE connection ended. SessionId={SessionId}"); - } - } - - private async Task WriteSseMessageAsync(Stream stream, JsonRpcMessage message, CancellationToken ct) - { - try - { - // event: message - await stream.WriteAsync(EventMessageBytes, ct); - - // data: ... - await stream.WriteAsync(DataPrefixBytes, ct); - - // Serialize - await _manager.WriteMessageAsync(stream, message, ct); - - // \n\n (End of event) - await stream.WriteAsync(NewLineBytes, ct); - await stream.WriteAsync(NewLineBytes, ct); - - await stream.FlushAsync(ct); - } - catch (Exception ex) - { - Log.Error($"[McpServer][StreamableHttp] Failed to write SSE message. SessionId={SessionId}", ex); - throw; // Re-throw to close connection if write fails - } - } - - public async ValueTask DisposeAsync() - { - if (_disposeCts.IsCancellationRequested) - { - return; - } - -#if NET8_0_OR_GREATER - await _disposeCts.CancelAsync(); -#else - await Task.Yield(); - _disposeCts.Cancel(); -#endif - _outgoingMessages.Writer.TryComplete(); - _disposeCts.Dispose(); - } -} diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/IClientTransportManager.cs b/src/DotNetCampus.ModelContextProtocol/Transports/IClientTransportManager.cs index 200174a..5116f96 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/IClientTransportManager.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/IClientTransportManager.cs @@ -19,6 +19,23 @@ public interface IClientTransportManager /// RequestId MakeNewRequestId(); + /// + /// 提供给传输层调用。当传输层收到消息字符串行后,一次解析即可分类并反序列化为具体消息类型。 + /// + /// 消息字符串行。 + /// + /// 读取出来的 JSON-RPC 消息对象: + /// + /// 服务器主动发起的请求 → + /// 对客户端请求的响应 → + /// 无法识别或解析失败 → + /// + /// + /// + /// 如果读取失败,此方法会暴露底层的任何读取异常,传输层需处理好此异常(说明消息不正确)。 + /// + ValueTask ReadMessageAsync(string messageLine); + /// /// 提供给传输层调用。当传输层收到响应字符串行后,调用此方法可以将字符串读取为 JSON-RPC 响应对象。 /// @@ -77,4 +94,11 @@ public interface IClientTransportManager /// 取消令牌。 /// 此方法绝对不会发生异常。 ValueTask HandleRespondAsync(JsonRpcResponse response, CancellationToken cancellationToken = default); + + /// + /// 提供给传输层调用。当传输层收到来自服务器的 JSON-RPC 请求时(如 sampling/createMessage),调用此方法可以将请求交给 MCP 客户端进行处理并回送响应。 + /// + /// 从传输层解析出来的服务器发起的 JSON-RPC 请求。 + /// 取消令牌。 + ValueTask HandleServerRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default); } diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/IServerTransportManager.cs b/src/DotNetCampus.ModelContextProtocol/Transports/IServerTransportManager.cs index ff77c29..078dd10 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/IServerTransportManager.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/IServerTransportManager.cs @@ -47,34 +47,34 @@ public interface IServerTransportManager bool TryGetSession(string sessionId, [NotNullWhen(true)] out T? session) where T : class, IServerTransportSession; /// - /// 提供给传输层调用。当传输层收到请求字符串行后,调用此方法可以将字符串读取为 JSON-RPC 请求对象。 + /// 提供给传输层调用。当传输层收到一行文本消息后,调用此方法将其解析为具体的 JSON-RPC 消息类型。 /// - /// 请求字符串行。 - /// 读取出来的 JSON-RPC 请求对象,如果无法读取则返回 - /// - /// 如果读取失败,此方法会暴露底层的任何读取异常,传输层需处理好此异常(说明请求消息不正确)。 - /// - ValueTask ReadRequestAsync(string requestLine); + /// 消息文本行。 + /// + /// 解析出的消息对象,实际类型为 (有 id)、(无 id) + /// 或 (无 method)之一;无法解析时返回 。 + /// + ValueTask ReadMessageAsync(string messageLine); /// - /// 提供给传输层调用。当传输层收到请求流后,调用此方法可以将请求流读取为 JSON-RPC 请求对象。 + /// 提供给传输层调用。当传输层收到字节流消息后,调用此方法将其解析为具体的 JSON-RPC 消息类型。 /// - /// 请求流。 - /// 读取出来的 JSON-RPC 请求对象,如果无法读取则返回 - /// - /// 如果读取失败,此方法会暴露底层的任何读取异常,传输层需处理好此异常(说明请求消息不正确或连接关闭等)。 - /// - ValueTask ReadRequestAsync(Stream requestStream); + /// 消息流。 + /// + /// 解析出的消息对象,实际类型为 (有 id)、(无 id) + /// 或 (无 method)之一;无法解析时返回 。 + /// + ValueTask ReadMessageAsync(Stream messageStream); /// - /// 提供给传输层调用。当传输层收到请求流后,调用此方法可以将请求流读取为 JSON-RPC 请求对象。 + /// 提供给传输层调用。当传输层收到字节内存消息后,调用此方法将其解析为具体的 JSON-RPC 消息类型。 /// - /// 请求流。 - /// 读取出来的 JSON-RPC 请求对象,如果无法读取则返回 - /// - /// 如果读取失败,此方法会暴露底层的任何读取异常,传输层需处理好此异常(说明请求消息不正确或连接关闭等)。 - /// - ValueTask ReadRequestAsync(ReadOnlyMemory requestMemory); + /// 消息字节内存。 + /// + /// 解析出的消息对象,实际类型为 (有 id)、(无 id) + /// 或 (无 method)之一;无法解析时返回 。 + /// + ValueTask ReadMessageAsync(ReadOnlyMemory messageMemory); /// /// 提供给传输层调用,用于发送消息给 MCP 客户端。 diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/IServerTransportSession.cs b/src/DotNetCampus.ModelContextProtocol/Transports/IServerTransportSession.cs index c454f46..ce3cccc 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/IServerTransportSession.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/IServerTransportSession.cs @@ -1,4 +1,5 @@ -using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; +using DotNetCampus.ModelContextProtocol.Protocol.Messages; +using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; namespace DotNetCampus.ModelContextProtocol.Transports; @@ -15,7 +16,17 @@ public interface IServerTransportSession : IAsyncDisposable string? SessionId { get; } /// - /// 将消息发送给其他端。 + /// 连接的客户端所声明的客户端能力。在 Initialize 握手完成后设置。 /// - Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default); + ClientCapabilities? ConnectedClientCapabilities { get; set; } + + /// + /// 向客户端发送 JSON-RPC 请求并等待响应。用于服务器主动发起的请求(如 sampling/createMessage)。 + /// + Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default); + + /// + /// 处理从客户端收到的 JSON-RPC 响应(对服务器发起的请求的回复)。 + /// + void HandleResponseAsync(JsonRpcResponse response); } diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/McpTransportLoggerExtensions.cs b/src/DotNetCampus.ModelContextProtocol/Transports/McpTransportLoggerExtensions.cs new file mode 100644 index 0000000..38eea92 --- /dev/null +++ b/src/DotNetCampus.ModelContextProtocol/Transports/McpTransportLoggerExtensions.cs @@ -0,0 +1,129 @@ +using System.Text.Json; +using DotNetCampus.ModelContextProtocol.CompilerServices; +using DotNetCampus.ModelContextProtocol.Hosting.Logging; +using DotNetCampus.ModelContextProtocol.Protocol.Messages; +using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; + +namespace DotNetCampus.ModelContextProtocol.Transports; + +/// +/// 专为传输层原始消息进行日志记录的扩展方法。 +/// +public static class McpTransportLoggerExtensions +{ + private const int TrimmedRawMessageMaxLength = 80; + + /// MCP 传输层管理器。 + extension(IServerTransportManager manager) + { + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawIn(string tag, string jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpServer]", tag, "←", jsonRpcRawMessage); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawIn(string tag, JsonRpcMessage jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpServer]", tag, "←", jsonRpcRawMessage); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawIn(string tag, string channel, string jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpServer]", tag, "←", jsonRpcRawMessage, channel); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawIn(string tag, string channel, JsonRpcMessage jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpServer]", tag, "←", jsonRpcRawMessage, channel); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawOut(string tag, string jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpServer]", tag, "→", jsonRpcRawMessage); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawOut(string tag, JsonRpcMessage jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpServer]", tag, "→", jsonRpcRawMessage); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawOut(string tag, string channel, string jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpServer]", tag, "→", jsonRpcRawMessage, channel); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawOut(string tag, string channel, JsonRpcMessage jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpServer]", tag, "→", jsonRpcRawMessage, channel); + } + + /// MCP 传输层管理器。 + extension(IClientTransportManager manager) + { + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawIn(string tag, string jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpClient]", tag, "←", jsonRpcRawMessage); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawIn(string tag, JsonRpcMessage jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpClient]", tag, "←", jsonRpcRawMessage); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawIn(string tag, string channel, string jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpClient]", tag, "←", jsonRpcRawMessage, channel); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawIn(string tag, string channel, JsonRpcMessage jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpClient]", tag, "←", jsonRpcRawMessage, channel); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawOut(string tag, string jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpClient]", tag, "→", jsonRpcRawMessage); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawOut(string tag, JsonRpcMessage jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpClient]", tag, "→", jsonRpcRawMessage); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawOut(string tag, string channel, string jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpClient]", tag, "→", jsonRpcRawMessage, channel); + /// 记录传输层接收到的原始 JSON-RPC 消息。 + public void LogRawOut(string tag, string channel, JsonRpcMessage jsonRpcRawMessage) => ((IMcpTransportLogger)manager).LogRaw("[McpClient]", tag, "→", jsonRpcRawMessage, channel); + } + + private static void LogRaw(this IMcpTransportLogger transportLogger, string role, string tag, string direction, string jsonRpcRawMessage, string? channel = null) + { + if (transportLogger.RawMessageLoggingDetailLevel is not McpTransportRawMessageLoggingDetailLevel.None + && transportLogger.Logger.IsEnabled(LoggingLevel.Debug)) + { + var trimmedMessage = transportLogger.RawMessageLoggingDetailLevel is McpTransportRawMessageLoggingDetailLevel.Full + || jsonRpcRawMessage.Length <= TrimmedRawMessageMaxLength + ? jsonRpcRawMessage + : jsonRpcRawMessage[..TrimmedRawMessageMaxLength] + "...(trimmed)"; + var channelPart = channel != null ? $" [via {channel}]" : ""; + transportLogger.Logger.Debug($"{role}{tag} {direction}{channelPart} {trimmedMessage}"); + } + } + + private static void LogRaw(this IMcpTransportLogger transportLogger, string role, string tag, string direction, JsonRpcMessage jsonRpcRawMessage, string? channel = null) + { + if (transportLogger.RawMessageLoggingDetailLevel is not McpTransportRawMessageLoggingDetailLevel.None + && transportLogger.Logger.IsEnabled(LoggingLevel.Debug)) + { + var json = JsonSerializer.Serialize(jsonRpcRawMessage, jsonRpcRawMessage switch + { + JsonRpcNotification => McpInternalJsonContext.Default.JsonRpcNotification, + JsonRpcRequest => McpInternalJsonContext.Default.JsonRpcRequest, + JsonRpcResponse => McpInternalJsonContext.Default.JsonRpcResponse, + _ => throw new InvalidOperationException($"Unexpected JsonRpcMessage type: {jsonRpcRawMessage.GetType().FullName}"), + }); + var trimmedMessage = transportLogger.RawMessageLoggingDetailLevel is McpTransportRawMessageLoggingDetailLevel.Full + || json.Length <= TrimmedRawMessageMaxLength + ? json + : json[..TrimmedRawMessageMaxLength] + "...(trimmed)"; + var channelPart = channel != null ? $" [via {channel}]" : ""; + transportLogger.Logger.Debug($"{role}{tag} {direction}{channelPart} {trimmedMessage}"); + } + } +} + +/// +/// 提供 MCP 传输层日志记录功能的接口。实现此接口的类可以为传输层原始消息提供日志记录支持。 +/// +internal interface IMcpTransportLogger +{ + /// + /// 获取用于记录 MCP 传输层日志的 实例。 + /// + IMcpLogger Logger { get; } + + /// + /// 获取 MCP 传输层原始消息日志记录的详细程度。根据此属性的值,传输层可以决定是否记录原始消息日志,以及记录多少细节。 + /// + McpTransportRawMessageLoggingDetailLevel RawMessageLoggingDetailLevel { get; } +} + +/// +/// MCP 传输层原始消息日志记录的详细程度。 +/// +public enum McpTransportRawMessageLoggingDetailLevel +{ + /// + /// 不记录原始消息日志。 + /// + None, + + /// + /// 记录裁剪的原始消息。过长的消息会被裁剪以避免日志过大。 + /// + Trimmed, + + /// + /// 记录完整的原始消息。可能会导致日志过大,一般仅建议在调试时使用。 + /// + Full, +} diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/ServerTransportManager.cs b/src/DotNetCampus.ModelContextProtocol/Transports/ServerTransportManager.cs index c82c387..289713d 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/ServerTransportManager.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/ServerTransportManager.cs @@ -1,9 +1,8 @@ -using System.Buffers; using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; -using System.IO.Pipelines; using System.Text.Json; using DotNetCampus.ModelContextProtocol.CompilerServices; +using DotNetCampus.ModelContextProtocol.Hosting.Logging; using DotNetCampus.ModelContextProtocol.Hosting.Services; using DotNetCampus.ModelContextProtocol.Protocol; using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; @@ -12,7 +11,7 @@ namespace DotNetCampus.ModelContextProtocol.Transports; -internal class ServerTransportManager(McpServer server, McpServerContext context) : IServerTransportManager +internal class ServerTransportManager(McpServer server, McpServerContext context) : IServerTransportManager, IMcpTransportLogger { /// /// 表示 MCP 服务正在运行的 。 @@ -52,6 +51,12 @@ internal class ServerTransportManager(McpServer server, McpServerContext context /// public IServerTransportContext Context => context; + /// + public IMcpLogger Logger => Context.Logger; + + /// + public McpTransportRawMessageLoggingDetailLevel RawMessageLoggingDetailLevel { get; init; } + /// /// 获取已注册的传输层列表。 /// @@ -135,45 +140,71 @@ public bool TryGetSession(string sessionId, [NotNullWhen(true)] out T? sessio return false; } - public ValueTask ReadRequestAsync(string requestLine) + public ValueTask ReadMessageAsync(string messageLine) { - var message = JsonSerializer.Deserialize(requestLine, McpServerRequestJsonContext.Default.JsonRpcRequest); - if (message is { Method: RequestMethods.Initialize, Id: null }) - { - return ValueTask.FromResult(message with { Id = MakeNewSessionId().ToJsonElement() }); - } - return ValueTask.FromResult(message); + var message = JsonElement.Parse(messageLine); + return ValueTask.FromResult(ClassifyAndDeserialize(message)); } - public async ValueTask ReadRequestAsync(Stream requestStream) + public async ValueTask ReadMessageAsync(Stream messageStream) { - var message = await JsonSerializer.DeserializeAsync(requestStream, McpServerRequestJsonContext.Default.JsonRpcRequest); - if (message is { Method: RequestMethods.Initialize, Id: null }) - { - return message with { Id = MakeNewSessionId().ToJsonElement() }; - } - return message; + using var doc = await JsonDocument.ParseAsync(messageStream); + return ClassifyAndDeserialize(doc.RootElement); + } + + public ValueTask ReadMessageAsync(ReadOnlyMemory messageMemory) + { + var message = JsonElement.Parse(messageMemory.Span); + return ValueTask.FromResult(ClassifyAndDeserialize(message)); } - public async ValueTask ReadRequestAsync(ReadOnlyMemory requestMemory) + /// + /// 根据 JSON-RPC 2.0 字段特征将 分类并反序列化为具体消息类型。 + /// + private JsonRpcMessage? ClassifyAndDeserialize(JsonElement element) { - var pipeReader = PipeReader.Create(new ReadOnlySequence(requestMemory)); - var message = await JsonSerializer.DeserializeAsync(pipeReader, McpServerRequestJsonContext.Default.JsonRpcRequest); - if (message is { Method: RequestMethods.Initialize, Id: null }) + var hasMethod = element.TryGetProperty("method", out var methodElement); + + if (hasMethod) + { + // 有 id 且非 null → 请求;无 id 或 id 为 null → 通知。 + var hasId = element.TryGetProperty("id", out var idElement) && idElement.ValueKind != JsonValueKind.Null; + + // initialize 请求即使 id 缺失或为 null 也应被视为请求(兼容旧客户端)。 + var isInitialize = methodElement.ValueKind is JsonValueKind.String && methodElement.GetString() == RequestMethods.Initialize; + + if (hasId || isInitialize) + { + var request = element.Deserialize(McpInternalJsonContext.Default.JsonRpcRequest); + if (request is { Method: RequestMethods.Initialize, Id: null }) + { + return request with { Id = MakeNewSessionId().ToJsonElement() }; + } + return request; + } + else + { + return element.Deserialize(McpInternalJsonContext.Default.JsonRpcNotification); + } + } + + var hasResultOrError = element.TryGetProperty("result", out _) || element.TryGetProperty("error", out _); + if (hasResultOrError) { - return message with { Id = MakeNewSessionId().ToJsonElement() }; + return element.Deserialize(McpInternalJsonContext.Default.JsonRpcResponse); } - return message; + + return null; } public Task WriteMessageAsync(Stream stream, JsonRpcMessage message, CancellationToken cancellationToken) => message switch { JsonRpcResponse response => JsonSerializer.SerializeAsync(stream, response, - McpServerResponseJsonContext.Default.JsonRpcResponse, cancellationToken), + McpInternalJsonContext.Default.JsonRpcResponse, cancellationToken), JsonRpcRequest request => JsonSerializer.SerializeAsync(stream, request, - McpServerRequestJsonContext.Default.JsonRpcRequest, cancellationToken), + McpInternalJsonContext.Default.JsonRpcRequest, cancellationToken), JsonRpcNotification notification => JsonSerializer.SerializeAsync(stream, notification, - McpServerRequestJsonContext.Default.JsonRpcNotification, cancellationToken), + McpInternalJsonContext.Default.JsonRpcNotification, cancellationToken), _ => throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"), }; diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/ServerTransportSession.cs b/src/DotNetCampus.ModelContextProtocol/Transports/ServerTransportSession.cs new file mode 100644 index 0000000..fdd13ca --- /dev/null +++ b/src/DotNetCampus.ModelContextProtocol/Transports/ServerTransportSession.cs @@ -0,0 +1,106 @@ +using System.Collections.Concurrent; +using DotNetCampus.ModelContextProtocol.Protocol.Messages; +using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; + +namespace DotNetCampus.ModelContextProtocol.Transports; + +/// +/// 的抽象基类,封装了通用的"等待客户端响应"模式(TCS 字典 + CancellationToken 注册)。 +/// +/// 各传输层的 Session 继承本类,并实现 来完成各自的"发送"操作 +/// (如 Stdio 写 stdout、Streamable HTTP 写 SSE 流)。 +/// +/// +public abstract class ServerTransportSession : IServerTransportSession +{ + private readonly ConcurrentDictionary> _pendingRequests = []; + + /// + public abstract string? SessionId { get; } + + /// + public ClientCapabilities? ConnectedClientCapabilities { get; set; } + + /// + public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + { + if (request.Id?.ToString() is not { } id) + { + throw new InvalidOperationException("请求 ID 不能为 null。Request ID must not be null."); + } + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + if (!_pendingRequests.TryAdd(id, tcs)) + { + throw new InvalidOperationException($"已存在相同 ID 的挂起请求:{id}。"); + } + + using var registration = cancellationToken.Register(() => + { + if (_pendingRequests.TryRemove(id, out var removed)) + { + removed.TrySetCanceled(cancellationToken); + } + }); + + try + { + await SendRequestMessageAsync(request, cancellationToken).ConfigureAwait(false); + return await tcs.Task.ConfigureAwait(false); + } + finally + { + _pendingRequests.TryRemove(id, out _); + } + } + + /// + /// 执行实际的消息发送操作。由子类实现,负责将 写入各自的传输通道 + /// (如 Stdio 写 stdout、Streamable HTTP 写 SSE 流)。 + /// + protected abstract Task SendRequestMessageAsync(JsonRpcRequest request, CancellationToken cancellationToken); + + /// + public void HandleResponseAsync(JsonRpcResponse response) + { + if (response.Id?.ToString() is not { } id) + { + return; + } + + if (_pendingRequests.TryRemove(id, out var tcs)) + { + OnResponseReceived(id, response); + tcs.TrySetResult(response); + } + else + { + OnUnmatchedResponse(id, response); + } + } + + /// + /// 匹配的客户端响应到达时的回调(可用于日志)。默认为空实现。 + /// + protected virtual void OnResponseReceived(string id, JsonRpcResponse response) { } + + /// + /// 无法匹配的客户端响应到达时的回调(可用于日志)。默认为空实现。 + /// + protected virtual void OnUnmatchedResponse(string id, JsonRpcResponse response) { } + + /// + /// 取消所有待处理的挂起请求,供 调用。 + /// + protected void CancelAllPendingRequests() + { + foreach (var (_, tcs) in _pendingRequests) + { + tcs.TrySetCanceled(); + } + _pendingRequests.Clear(); + } + + /// + public abstract ValueTask DisposeAsync(); +} diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioClientTransport.cs b/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioClientTransport.cs index 81b43e7..19127f3 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioClientTransport.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioClientTransport.cs @@ -1,4 +1,4 @@ -using System.Diagnostics; +using System.Diagnostics; using System.Diagnostics.Contracts; using System.Text; using DotNetCampus.ModelContextProtocol.Hosting.Logging; @@ -68,6 +68,7 @@ public async ValueTask SendMessageAsync(JsonRpcMessage message, CancellationToke } var line = _manager.WriteMessageAsync(message); + _manager.LogRawOut("[Stdio]", line); await stdio.StandardInput.WriteAsync(line); await stdio.StandardInput.WriteAsync('\n'); await stdio.StandardInput.FlushAsync(); @@ -98,14 +99,37 @@ private async Task RunLoopAsync(StdioProcessInfo stdio, CancellationToken cancel break; } - var response = await _manager.ParseAndCatchResponseAsync(line); - if (response is null) + if (string.IsNullOrWhiteSpace(line)) + { + continue; + } + + _manager.LogRawIn("[Stdio]", line); + + // 一次解析即可分类:有 method → 服务器主动请求;有 result/error → 对客户端请求的响应。 + JsonRpcMessage? message; + try + { + message = await _manager.ReadMessageAsync(line); + } + catch { Log.Warn($"[McpClient][Stdio] Invalid server message received."); continue; } - await _manager.HandleRespondAsync(response, cancellationToken); + switch (message) + { + case JsonRpcRequest request: + await _manager.HandleServerRequestAsync(request, cancellationToken); + break; + case JsonRpcResponse response: + await _manager.HandleRespondAsync(response, cancellationToken); + break; + default: + Log.Warn($"[McpClient][Stdio] Unrecognized server message received."); + break; + } } } @@ -196,22 +220,3 @@ private readonly record struct StdioProcessInfo public required StreamReader StandardError { get; init; } } } - -file static class Extensions -{ - extension(IClientTransportManager manager) - { - public async ValueTask ParseAndCatchResponseAsync(string inputMessageText) - { - try - { - return await manager.ReadResponseAsync(inputMessageText); - } - catch - { - // 响应消息格式不正确,返回 null 后,原样给 MCP 客户端报告错误。 - return null; - } - } - } -} diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioServerTransport.cs b/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioServerTransport.cs index 30bc8d0..7c4b34c 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioServerTransport.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioServerTransport.cs @@ -1,6 +1,7 @@ -using System.Text; +using System.Text; using DotNetCampus.ModelContextProtocol.Hosting.Logging; using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; +using DotNetCampus.ModelContextProtocol.Servers; namespace DotNetCampus.ModelContextProtocol.Transports.Stdio; @@ -28,7 +29,7 @@ public class StdioServerTransport : IServerTransport public StdioServerTransport(IServerTransportManager manager) { _manager = manager; - _session = new StdioServerTransportSession(); + _session = new StdioServerTransportSession(manager); } private IMcpLogger Log => _manager.Context.Logger; @@ -50,6 +51,7 @@ public Task StartAsync(CancellationToken startingCancellationToken, Cancel StandardInput = input, StandardOutput = output, }; + _session.SetOutput(output); _manager.Add(_session); return Task.FromResult(RunLoopAsync(runningCancellationToken)); @@ -81,32 +83,80 @@ private async Task RunLoopAsync(CancellationToken cancellationToken) var line = await input.ReadLineAsync(cancellationToken); if (line is null) { + Log.Info($"[McpServer][Stdio] Client disconnected (end of input stream)."); break; } - var request = await _manager.ParseAndCatchRequestAsync(line); - if (request is null) + if (string.IsNullOrWhiteSpace(line)) { - await _manager.RespondJsonRpcAsync(output, new JsonRpcResponse - { - Error = new JsonRpcError - { - Code = (int)JsonRpcErrorCode.InvalidRequest, - Message = $"Invalid request message: {line}", - }, - }, cancellationToken); continue; } - var response = await _manager.HandleRequestAsync(request, null, cancellationToken); - if (response is null) + _manager.LogRawIn("[Stdio]", line); + + JsonRpcMessage? message; + try { - // 按照 MCP 协议规范,本次请求仅需响应而无需回复。 - await output.WriteLineAsync(); - continue; + message = await _manager.ReadMessageAsync(line); + } + catch + { + message = null; } - await _manager.RespondJsonRpcAsync(output, response, cancellationToken); + switch (message) + { + case JsonRpcResponse response: + // 将响应路由到等待的请求。 + Log.Debug($"[McpServer][Stdio] Routing client response to session."); + _session.HandleResponseAsync(response); + continue; + + case JsonRpcNotification notification: + // 通知,路由到处理器,无需回复。 + await _manager.HandleRequestAsync( + new JsonRpcRequest { Method = notification.Method, Params = notification.Params }, + s => + { + s.AddScoped(_session); + s.AddScoped(new McpServerSampling(_session, Log)); + }, + cancellationToken); + continue; + + case JsonRpcRequest request: + { + var session = _session; + var response2 = await _manager.HandleRequestAsync(request, + s => + { + s.AddScoped(session); + s.AddScoped(new McpServerSampling(session, Log)); + }, + cancellationToken); + if (response2 is null) + { + // 按照 MCP 协议规范,本次请求仅需响应而无需回复。 + await output.WriteLineAsync(); + continue; + } + await _manager.RespondJsonRpcAsync(output, response2, cancellationToken); + continue; + } + + default: + // 无法解析的消息,回复错误。 + Log.Warn($"[McpServer][Stdio] Received unrecognizable message, responding with error."); + await _manager.RespondJsonRpcAsync(output, new JsonRpcResponse + { + Error = new JsonRpcError + { + Code = (int)JsonRpcErrorCode.InvalidRequest, + Message = $"Invalid request message: {line}", + }, + }, cancellationToken); + continue; + } } } @@ -128,23 +178,11 @@ file static class Extensions { extension(IServerTransportManager manager) { - public async ValueTask ParseAndCatchRequestAsync(string inputMessageText) - { - try - { - return await manager.ReadRequestAsync(inputMessageText); - } - catch - { - // 请求消息格式不正确,返回 null 后,原样给 MCP 客户端报告错误。 - return null; - } - } - public async ValueTask RespondJsonRpcAsync(StreamWriter writer, JsonRpcResponse response, CancellationToken cancellationToken) { try { + manager.LogRawOut("[Stdio]", response); await manager.WriteMessageAsync(writer.BaseStream, response, cancellationToken); await writer.WriteLineAsync(); } diff --git a/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioServerTransportSession.cs b/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioServerTransportSession.cs index 7a119de..e6e7383 100644 --- a/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioServerTransportSession.cs +++ b/src/DotNetCampus.ModelContextProtocol/Transports/Stdio/StdioServerTransportSession.cs @@ -1,33 +1,106 @@ -using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; +using System.Text; +using System.Text.Json; +using DotNetCampus.ModelContextProtocol.CompilerServices; +using DotNetCampus.ModelContextProtocol.Hosting.Logging; +using DotNetCampus.ModelContextProtocol.Protocol.Messages; +using DotNetCampus.ModelContextProtocol.Protocol.Messages.JsonRpc; namespace DotNetCampus.ModelContextProtocol.Transports.Stdio; /// /// STDIO 传输层的一个会话。 /// -public class StdioServerTransportSession : IServerTransportSession +public class StdioServerTransportSession : ServerTransportSession { + private static readonly ReadOnlyMemory NewLineBytes = "\n"u8.ToArray(); + private readonly SemaphoreSlim _writeLock = new(1, 1); + private readonly IServerTransportManager _manager; + private readonly IMcpLogger _logger; + private StreamWriter? _output; + /// - /// STDIO 传输层的一个会话。 + /// 初始化 类的新实例。 /// - public StdioServerTransportSession() + /// 辅助管理 MCP 传输层的管理器。 + public StdioServerTransportSession(IServerTransportManager manager) { + _manager = manager; + _logger = manager.Context.Logger; } /// /// STDIO 传输层是专用的,不需要会话 ID。 /// - public string? SessionId => null; + public override string? SessionId => null; + + /// + /// 由 在启动后设置输出流。 + /// + internal void SetOutput(StreamWriter output) + { + _output = output; + } /// - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + protected override async Task SendRequestMessageAsync(JsonRpcRequest request, CancellationToken cancellationToken) + { + _logger.Debug($"[McpServer][Stdio] Sending server-initiated request. Method={request.Method}, Id={request.Id}, SessionId={SessionId}"); + await SendMessageAsync(request, cancellationToken).ConfigureAwait(false); + } + + private async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { - throw new NotImplementedException(); + if (_output is not { } output) + { + throw new InvalidOperationException("STDIO 传输层尚未初始化输出流,无法发送服务端主动请求。"); + } + + await _writeLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + if (_logger.IsEnabled(LoggingLevel.Debug)) + { + using var ms = new MemoryStream(); + await JsonSerializer.SerializeAsync(ms, message, GetTypeInfo(message), cancellationToken).ConfigureAwait(false); + var bytes = ms.ToArray(); + var json = Encoding.UTF8.GetString(bytes); + _manager.LogRawOut("[Stdio]", json); + await output.BaseStream.WriteAsync(bytes, cancellationToken).ConfigureAwait(false); + } + else + { + await JsonSerializer.SerializeAsync(output.BaseStream, message, GetTypeInfo(message), cancellationToken).ConfigureAwait(false); + } + await output.BaseStream.WriteAsync(NewLineBytes, cancellationToken).ConfigureAwait(false); + await output.BaseStream.FlushAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + _writeLock.Release(); + } } /// - public ValueTask DisposeAsync() + protected override void OnResponseReceived(string id, JsonRpcResponse response) + => _logger.Debug($"[McpServer][Stdio] Received client response for pending request. Id={id}"); + + /// + protected override void OnUnmatchedResponse(string id, JsonRpcResponse response) + => _logger.Warn($"[McpServer][Stdio] Received unmatched client response. Id={id}"); + + /// + public override ValueTask DisposeAsync() { + _writeLock.Dispose(); + CancelAllPendingRequests(); return ValueTask.CompletedTask; } + + private static System.Text.Json.Serialization.Metadata.JsonTypeInfo GetTypeInfo(JsonRpcMessage message) => message switch + { + JsonRpcResponse response => McpInternalJsonContext.Default.JsonRpcResponse, + JsonRpcRequest request => McpInternalJsonContext.Default.JsonRpcRequest, + JsonRpcNotification notification => McpInternalJsonContext.Default.JsonRpcNotification, + _ => throw new ArgumentException($"不支持的消息类型:{message.GetType().FullName}."), + }; } diff --git a/tests/DotNetCampus.ModelContextProtocol.Tests/McpTools/SamplingTool.cs b/tests/DotNetCampus.ModelContextProtocol.Tests/McpTools/SamplingTool.cs new file mode 100644 index 0000000..25bbb52 --- /dev/null +++ b/tests/DotNetCampus.ModelContextProtocol.Tests/McpTools/SamplingTool.cs @@ -0,0 +1,35 @@ +using DotNetCampus.ModelContextProtocol.CompilerServices; +using DotNetCampus.ModelContextProtocol.Protocol.Messages; +using DotNetCampus.ModelContextProtocol.Servers; + +namespace DotNetCampus.ModelContextProtocol.Tests.McpTools; + +/// +/// 用于测试服务器向客户端发起 Sampling 请求的工具。 +/// +public class SamplingTool +{ + /// + /// 通过服务端 Sampling 能力向客户端 LLM 发起采样请求,并返回结果文本。 + /// + [McpServerTool] + public async Task AskLlm(string message, IMcpServerCallToolContext context) + { + if (!context.Sampling.IsSupported) + { + throw new InvalidOperationException("Sampling service not available in this context."); + } + + var result = await context.Sampling.CreateMessageAsync(message); + return result.Content is TextContentBlock textBlock ? textBlock.Text : string.Empty; + } + + /// + /// 检查客户端是否支持 Sampling 能力(IsSupported)。 + /// + [McpServerTool] + public string CheckSamplingCapability(IMcpServerCallToolContext context) + { + return $"has_capability={context.Sampling.IsSupported}"; + } +} diff --git a/tests/DotNetCampus.ModelContextProtocol.Tests/Servers/SamplingTests.cs b/tests/DotNetCampus.ModelContextProtocol.Tests/Servers/SamplingTests.cs new file mode 100644 index 0000000..c9dbe60 --- /dev/null +++ b/tests/DotNetCampus.ModelContextProtocol.Tests/Servers/SamplingTests.cs @@ -0,0 +1,80 @@ +using System.Text.Json; +using DotNetCampus.ModelContextProtocol.Protocol.Messages; +using DotNetCampus.ModelContextProtocol.Tests.McpTools; + +namespace DotNetCampus.ModelContextProtocol.Tests.Servers; + +/// +/// Sampling 功能集成测试:验证服务器向客户端发起 sampling/createMessage 请求的完整流程。 +/// +[TestClass] +public class SamplingTests +{ + #region Sampling 基本功能 + + [TestMethod("Sampling: 服务器工具可通过 context.Sampling 向客户端发起采样请求")] + [DataRow(HttpTransportType.LocalHost, DisplayName = "LocalHost")] + [DataRow(HttpTransportType.TouchSocket, DisplayName = "TouchSocket")] + public async Task ServerToolCanRequestSampling(HttpTransportType transportType) + { + // Arrange + const string expectedResponseText = "Hello from LLM!"; + var samplingHandlerInvoked = false; + + await using var package = await TestMcpFactory.Shared.CreateHttpCoreAsync( + transportType, + configureBuilder: builder => builder.WithTools(t => t.WithTool(() => new SamplingTool())), + configureClient: clientBuilder => clientBuilder.WithSamplingHandler( + (parms, ct) => + { + samplingHandlerInvoked = true; + var result = new CreateMessageResult + { + Role = Role.Assistant, + Content = new TextContentBlock { Text = expectedResponseText }, + Model = "test-model", + StopReason = "endTurn", + }; + return Task.FromResult(result); + })); + + // Act + var toolArgs = JsonSerializer.SerializeToElement(new { message = "What's 2+2?" }); + var callResult = await package.Client.CallToolAsync("ask_llm", toolArgs); + + // Assert + Assert.IsNotNull(callResult, "工具调用结果不应为 null"); + Assert.IsFalse(callResult.IsError, "工具调用不应返回错误"); + Assert.IsTrue(samplingHandlerInvoked, "客户端的 Sampling 处理器应被调用"); + + var textContent = callResult.Content.OfType().FirstOrDefault(); + Assert.IsNotNull(textContent, "工具调用结果应包含文本内容"); + Assert.AreEqual(expectedResponseText, textContent.Text, "工具返回的文本应与 Sampling 响应一致"); + } + + [TestMethod("Sampling: 无 Sampling 能力时 IsSupported 为 false")] + [DataRow(HttpTransportType.LocalHost, DisplayName = "LocalHost")] + [DataRow(HttpTransportType.TouchSocket, DisplayName = "TouchSocket")] + public async Task IsSupportedIsFalseWhenClientHasNoCapability(HttpTransportType transportType) + { + // Arrange - 客户端不配置 WithSamplingHandler,因此不声明采样能力 + await using var package = await TestMcpFactory.Shared.CreateHttpCoreAsync( + transportType, + configureBuilder: builder => builder.WithTools(t => t.WithTool(() => new SamplingTool()))); + + // Act + var callResult = await package.Client.CallToolAsync("check_sampling_capability"); + + // Assert + Assert.IsNotNull(callResult, "工具调用结果不应为 null"); + Assert.IsFalse(callResult.IsError, "工具调用不应返回错误"); + + var textContent = callResult.Content.OfType().FirstOrDefault(); + Assert.IsNotNull(textContent, "工具调用结果应包含文本内容"); + Assert.AreEqual("has_capability=False", textContent.Text, + "when客户端未声明 Sampling 能力时,IsSupported 应为 false"); + } + + #endregion +} + diff --git a/tests/DotNetCampus.ModelContextProtocol.Tests/TestMcpFactory.cs b/tests/DotNetCampus.ModelContextProtocol.Tests/TestMcpFactory.cs index dc12864..47a700d 100644 --- a/tests/DotNetCampus.ModelContextProtocol.Tests/TestMcpFactory.cs +++ b/tests/DotNetCampus.ModelContextProtocol.Tests/TestMcpFactory.cs @@ -144,11 +144,12 @@ public async ValueTask CreateHttpAsync( } /// - /// 核心方法:创建一个完全自定义的 HTTP 传输 MCP 测试包。 + /// 核心方法:创建一个完全自定义的 HTTP 传输 MCP 测试包,支持同时配置服务端和客户端。 /// public async ValueTask CreateHttpCoreAsync( HttpTransportType httpTransportType, - Action configureBuilder) + Action configureBuilder, + Action? configureClient = null) { var port = Interlocked.Increment(ref _port); var mcpServerBuilder = new McpServerBuilder("TestMcpServer", "1.0.0") @@ -174,12 +175,13 @@ public async ValueTask CreateHttpCoreAsync( mcpServer.EnableDebugMode(); await mcpServer.StartAsync(CancellationToken.None); - var mcpClient = new McpClientBuilder() + var mcpClientBuilder = new McpClientBuilder() .WithLogger(DefaultLogger) - .WithHttp($"http://127.0.0.1:{port}/mcp") - .Build(); + .WithHttp($"http://127.0.0.1:{port}/mcp"); + configureClient?.Invoke(mcpClientBuilder); + var builtClient = mcpClientBuilder.Build(); - return new McpTestingPackage(mcpServer, mcpClient); + return new McpTestingPackage(mcpServer, builtClient); } private static IServiceProvider CreateDefaultServices()