diff --git a/.changeset/forty-results-speak.md b/.changeset/forty-results-speak.md new file mode 100644 index 000000000..6c97e22da --- /dev/null +++ b/.changeset/forty-results-speak.md @@ -0,0 +1,5 @@ +--- +'@hono/node-ws': minor +--- + +Fix WebSocket connections failing when the endpoint is registered under `app.route()`. diff --git a/packages/node-ws/src/index.test.ts b/packages/node-ws/src/index.test.ts index 0ced85ce1..a746fe636 100644 --- a/packages/node-ws/src/index.test.ts +++ b/packages/node-ws/src/index.test.ts @@ -312,4 +312,61 @@ describe('WebSocket helper', () => { expect(clientWs).toBeTruthy() expect(wss.clients.size).toBe(1) }) + + it('Should work with app.route()', async () => { + const subApp = new Hono() + subApp.get( + '/ws', + upgradeWebSocket(() => ({ + onOpen(_, ws) { + ws.send('Hello from sub app') + }, + })) + ) + + app.route('/sub', subApp) + + const ws = new WebSocket('ws://localhost:3030/sub/ws') + const mainPromise = new Promise((resolve, reject) => { + ws.onmessage = (event) => { + resolve(event.data as string) + } + ws.onerror = () => { + reject(new Error('WebSocket error')) + } + }) + + expect(await mainPromise).toBe('Hello from sub app') + ws.close() + }) + + it('Should work with nested app.route()', async () => { + const subSubApp = new Hono() + subSubApp.get( + '/ws', + upgradeWebSocket(() => ({ + onOpen(_, ws) { + ws.send('Hello from nested') + }, + })) + ) + + const subApp = new Hono() + subApp.route('/nested', subSubApp) + + app.route('/sub', subApp) + + const ws = new WebSocket('ws://localhost:3030/sub/nested/ws') + const mainPromise = new Promise((resolve, reject) => { + ws.onmessage = (event) => { + resolve(event.data as string) + } + ws.onerror = () => { + reject(new Error('WebSocket error')) + } + }) + + expect(await mainPromise).toBe('Hello from nested') + ws.close() + }) }) diff --git a/packages/node-ws/src/index.ts b/packages/node-ws/src/index.ts index 39b8044af..2351996d5 100644 --- a/packages/node-ws/src/index.ts +++ b/packages/node-ws/src/index.ts @@ -36,22 +36,20 @@ const CONNECTION_SYMBOL_KEY: unique symbol = Symbol('CONNECTION_SYMBOL_KEY') */ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { const wss = new WebSocketServer({ noServer: true }) - const waiterMap = new Map< - IncomingMessage, - { resolve: (ws: WebSocket) => void; connectionSymbol: symbol } - >() + const upgradeAllowed = new WeakSet() + const waiterMap = new Map void>() wss.on('connection', (ws, request) => { - const waiter = waiterMap.get(request) - if (waiter) { - waiter.resolve(ws) + const resolve = waiterMap.get(request) + if (resolve) { + resolve(ws) waiterMap.delete(request) } }) - const nodeUpgradeWebSocket = (request: IncomingMessage, connectionSymbol: symbol) => { + const nodeUpgradeWebSocket = (request: IncomingMessage) => { return new Promise((resolve) => { - waiterMap.set(request, { resolve, connectionSymbol }) + waiterMap.set(request, resolve) }) } @@ -72,15 +70,13 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { const env: { incoming: IncomingMessage outgoing: undefined - [CONNECTION_SYMBOL_KEY]?: symbol } = { incoming: request, outgoing: undefined, } await init.app.request(url, { headers: headers }, env) - const waiter = waiterMap.get(request) - if (!waiter || waiter.connectionSymbol !== env[CONNECTION_SYMBOL_KEY]) { + if (!upgradeAllowed.has(request)) { socket.end( 'HTTP/1.1 400 Bad Request\r\n' + 'Connection: close\r\n' + @@ -91,6 +87,9 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { return } + // Remove the mark after checking to prevent memory leak + upgradeAllowed.delete(request) + wss.handleUpgrade(request, socket, head, (ws) => { wss.emit('connection', ws, request) }) @@ -102,10 +101,12 @@ export const createNodeWebSocket = (init: NodeWebSocketInit): NodeWebSocket => { return } - const connectionSymbol = generateConnectionSymbol() - c.env[CONNECTION_SYMBOL_KEY] = connectionSymbol + const request = c.env.incoming as IncomingMessage + + // Instead of writing to c.env, use a WeakSet to track the request object directly + upgradeAllowed.add(request) ;(async () => { - const ws = await nodeUpgradeWebSocket(c.env.incoming, connectionSymbol) + const ws = await nodeUpgradeWebSocket(request) // buffer messages to handle messages received before the events are set up const messagesReceivedInStarting: [data: WebSocket.RawData, isBinary: boolean][] = []