diff --git a/apps/web/src/app/api/internal/bot-session-callback/[botRequestId]/route.ts b/apps/web/src/app/api/internal/bot-session-callback/[botRequestId]/route.ts index ef6eb5dda7..dda710d0b5 100644 --- a/apps/web/src/app/api/internal/bot-session-callback/[botRequestId]/route.ts +++ b/apps/web/src/app/api/internal/bot-session-callback/[botRequestId]/route.ts @@ -24,10 +24,11 @@ import { } from '@/lib/bot/request-logging'; import { parseBotCallbackStep } from '@/lib/bot/step-budget'; import { runBotAgent, type BotAgentMessageLike } from '@/lib/bot/agent-runner'; +import { getRehydratedBotRequestMessageState } from '@/lib/bot/message-state'; import { botPlatforms } from '@/lib/bot/platforms'; import { getPlatformIntegrationById } from '@/lib/bot/platform-helpers'; import { findUserById } from '@/lib/user'; -import type { Thread } from 'chat'; +import type { Message, Thread } from 'chat'; type ExecutionCallbackPayload = { sessionId: string; @@ -183,6 +184,7 @@ async function continueBotAgentAfterCallback(params: { requestRow: Awaited>; platformIntegration: PlatformIntegration; thread: Thread; + message: Message; continuationPrompt: string; completedStepCount: number; }) { @@ -195,23 +197,8 @@ async function continueBotAgentAfterCallback(params: { return await botPlatforms.require(params.platformIntegration.platform).withAuthContext({ platformIntegration: params.platformIntegration, fn: async () => { - const originalMessage = await Promise.resolve( - params.thread.adapter.fetchMessage?.( - params.thread.id, - params.requestRow.platform_message_id - ) ?? null - ).catch(error => { - console.warn('[BotSessionCallback] Failed to fetch original platform message:', { - error, - platform: params.platformIntegration.platform, - threadId: params.thread.id, - messageId: params.requestRow.platform_message_id, - }); - return null; - }); - const callbackMessage: BotAgentMessageLike = { - author: originalMessage?.author ?? { + author: params.message.author ?? { fullName: 'Cloud Agent Callback', isBot: false, isMe: false, @@ -402,6 +389,7 @@ async function handleCompletedCallback( requestRow: NonNullable>>, platformIntegration: PlatformIntegration, thread: Thread, + message: Message, completedStepCount: number, trackedCallbackSession: BotRequestCloudAgentSession | undefined ) { @@ -642,6 +630,7 @@ ${cloudAgentResultsForPrompt}`; requestRow, platformIntegration, thread, + message, continuationPrompt, completedStepCount, }); @@ -882,7 +871,43 @@ export async function POST( requestRow.platform_integration_id ); await bot.initialize(); - const thread = bot.thread(requestRow.platform_thread_id); + let rehydrated: Awaited>; + try { + rehydrated = await getRehydratedBotRequestMessageState(botRequestId); + } catch (error) { + captureException(error, { + tags: { + source: 'bot-session-callback-api', + op: 'rehydrate-message-state', + }, + extra: { + botRequestId, + callbackSessionId, + status: payload.status, + }, + }); + const updated = await failBotRequest({ + botRequestId, + errorMessage: + 'Cloud Agent callback could not recover message state for this bot request.', + responseTimeMs: Date.now() - startedAt, + }); + logCallback( + 'Failed to rehydrate message state for callback; marked bot request as error', + { + botRequestId, + updated: Boolean(updated), + } + ); + return; + } + const { thread, message } = rehydrated; + + logCallback('Resolved callback chat context', { + botRequestId, + threadId: thread.id, + messageId: message?.id, + }); if (childSessionStatus && trackedCallbackSession) { try { @@ -934,6 +959,7 @@ export async function POST( requestRow, platformIntegration, thread, + message, completedStepCount, trackedCallbackSession ); diff --git a/apps/web/src/lib/bot-identity.ts b/apps/web/src/lib/bot-identity.ts index 25f7497237..8efb723708 100644 --- a/apps/web/src/lib/bot-identity.ts +++ b/apps/web/src/lib/bot-identity.ts @@ -5,6 +5,7 @@ import * as z from 'zod'; import { NEXTAUTH_SECRET } from '@/lib/config.server'; import { botIdentityRedisKey } from '@/lib/redis-keys'; import { PLATFORM } from '@/lib/integrations/core/constants'; +import { serializedMessageSchema, serializedThreadSchema } from '@/lib/bot/message-state'; const CHAT_SDK_CACHE_KEY_PREFIX = 'chat-sdk:cache:'; const LINK_ACCOUNT_CONTEXT_KEY_PREFIX = 'link-account-context:'; @@ -152,44 +153,6 @@ const platformIdentitySchema = z.object({ userId: z.string(), }); -const serializedThreadShape = z.looseObject({ - _type: z.literal('chat:Thread'), - adapterName: z.string(), - channelId: z.string(), - id: z.string(), - isDM: z.boolean(), -}); - -const serializedThreadSchema = z.custom( - value => serializedThreadShape.safeParse(value).success -); - -const serializedMessageShape = z.looseObject({ - _type: z.literal('chat:Message'), - attachments: z.array(z.unknown()), - author: z.object({ - userId: z.string(), - userName: z.string(), - fullName: z.string(), - isBot: z.union([z.boolean(), z.literal('unknown')]), - isMe: z.boolean(), - }), - formatted: z.unknown(), - id: z.string(), - metadata: z.object({ - dateSent: z.iso.datetime(), - edited: z.boolean(), - editedAt: z.iso.datetime().optional(), - }), - raw: z.unknown(), - text: z.string(), - threadId: z.string(), -}); - -const serializedMessageSchema = z.custom( - value => serializedMessageShape.safeParse(value).success -); - const linkTokenPayloadSchema = z.object({ identity: platformIdentitySchema, contextKey: z.string(), diff --git a/apps/web/src/lib/bot.test.ts b/apps/web/src/lib/bot.test.ts index 46285aeaba..74f5f945e8 100644 --- a/apps/web/src/lib/bot.test.ts +++ b/apps/web/src/lib/bot.test.ts @@ -4,8 +4,11 @@ declare global { globalThis.__botTestMentionHandler ??= null; +let mockState: { kind: string } | undefined; + function getMockState() { - return { kind: 'state' }; + mockState ??= { kind: 'state' }; + return mockState; } function getMockSlackAdapter() { @@ -188,8 +191,6 @@ const mockedCanKiloUserAccessPlatformIntegration = jest.mocked( const mockedGetPlatformIntegration = jest.mocked(getPlatformIntegration); const mockedFindUserById = jest.mocked(findUserById); const mockedProcessLinkedMessage = jest.mocked(processLinkedMessage); -const mockState = getMockState(); - function makeThread() { return { id: 'thread-1', adapter: { name: 'slack' } }; } @@ -259,6 +260,7 @@ describe('bot mention authorization', () => { message, platformIntegration: integration, user, + state: mockState, }); }); }); diff --git a/apps/web/src/lib/bot.ts b/apps/web/src/lib/bot.ts index 1eae29fa6f..6bd5550c18 100644 --- a/apps/web/src/lib/bot.ts +++ b/apps/web/src/lib/bot.ts @@ -101,7 +101,12 @@ function createKiloBot( } try { - await processLinkedMessage({ thread, message, platformIntegration, user }); + await processLinkedMessage({ + thread, + message, + platformIntegration, + user, + }); } catch (error) { console.error('[Bot] Unhandled error in message handler:', error); await thread.post({ markdown: 'Sorry, something went wrong while processing your message.' }); diff --git a/apps/web/src/lib/bot/message-state.ts b/apps/web/src/lib/bot/message-state.ts new file mode 100644 index 0000000000..bdf0d7f424 --- /dev/null +++ b/apps/web/src/lib/bot/message-state.ts @@ -0,0 +1,137 @@ +import 'server-only'; +import * as z from 'zod'; +import { ThreadImpl, Message, type StateAdapter, type Thread } from 'chat'; +import type { Root } from 'mdast'; +import { bot } from '@/lib/bot'; + +const BOT_REQUEST_MESSAGE_STATE_KEY_PREFIX = 'bot-request-message-state:'; +const BOT_REQUEST_MESSAGE_STATE_TTL_MS = 24 * 60 * 60 * 1000; + +type SerializedThread = ReturnType; +type SerializedMessage = ReturnType; + +type BotRequestMessageState = { + thread: SerializedThread; + message: SerializedMessage; +}; + +const serializedThreadShape = z.looseObject({ + _type: z.literal('chat:Thread'), + adapterName: z.string(), + channelId: z.string(), + channelVisibility: z.enum(['private', 'workspace', 'external', 'unknown']).optional(), + currentMessage: z.lazy(() => serializedMessageShape).optional(), + id: z.string(), + isDM: z.boolean(), +}) satisfies z.ZodType; + +const serializedMessageAttachmentShape = z.object({ + type: z.enum(['image', 'file', 'video', 'audio']), + url: z.string().optional(), + name: z.string().optional(), + mimeType: z.string().optional(), + size: z.number().optional(), + width: z.number().optional(), + height: z.number().optional(), + fetchMetadata: z.record(z.string(), z.string()).optional(), +}); + +const serializedMessageLinkShape = z.object({ + url: z.string(), + title: z.string().optional(), + description: z.string().optional(), + imageUrl: z.string().optional(), + siteName: z.string().optional(), +}); + +const formattedContentShape = z.custom( + value => + z.object({ type: z.literal('root'), children: z.array(z.unknown()) }).safeParse(value).success +); + +export const serializedThreadSchema = z.custom( + value => serializedThreadShape.safeParse(value).success +); + +const serializedMessageShape = z.looseObject({ + _type: z.literal('chat:Message'), + attachments: z.array(serializedMessageAttachmentShape), + author: z.object({ + userId: z.string(), + userName: z.string(), + fullName: z.string(), + isBot: z.union([z.boolean(), z.literal('unknown')]), + isMe: z.boolean(), + }), + formatted: formattedContentShape, + id: z.string(), + isMention: z.boolean().optional(), + links: z.array(serializedMessageLinkShape).optional(), + metadata: z.object({ + dateSent: z.iso.datetime(), + edited: z.boolean(), + editedAt: z.iso.datetime().optional(), + }), + raw: z.unknown(), + text: z.string(), + threadId: z.string(), +}) satisfies z.ZodType; + +export const serializedMessageSchema = z.custom( + value => serializedMessageShape.safeParse(value).success +); + +const botRequestMessageStateSchema = z.object({ + thread: serializedThreadSchema, + message: serializedMessageSchema, +}); + +function botRequestMessageStateKey(botRequestId: string): string { + return `${BOT_REQUEST_MESSAGE_STATE_KEY_PREFIX}${botRequestId}`; +} + +export async function storeBotRequestMessageState({ + state, + botRequestId, + thread, + message, +}: { + state: StateAdapter; + botRequestId: string; + thread: Thread; + message: Message; +}): Promise { + await state.set( + botRequestMessageStateKey(botRequestId), + { + thread: thread.toJSON(), + message: message.toJSON(), + }, + BOT_REQUEST_MESSAGE_STATE_TTL_MS + ); +} + +export async function getBotRequestMessageState( + state: StateAdapter, + botRequestId: string +): Promise { + const value = await state.get(botRequestMessageStateKey(botRequestId)); + if (!value) { + return null; + } + + return botRequestMessageStateSchema.parse(value); +} + +export async function getRehydratedBotRequestMessageState(botRequestId: string) { + const stored = await getBotRequestMessageState(bot.getState(), botRequestId); + + if (!stored) { + throw new Error('Could not find message state for botRequest ' + botRequestId); + } + + return { + thread: ThreadImpl.fromJSON(stored.thread), + message: Message.fromJSON(stored.message), + }; +} diff --git a/apps/web/src/lib/bot/run.ts b/apps/web/src/lib/bot/run.ts index 04691178fa..745d993b2e 100644 --- a/apps/web/src/lib/bot/run.ts +++ b/apps/web/src/lib/bot/run.ts @@ -1,9 +1,11 @@ import { createBotRequest, updateBotRequest } from '@/lib/bot/request-logging'; import { runBotAgent } from '@/lib/bot/agent-runner'; import { extractAndUploadImages } from '@/lib/bot/images'; +import { storeBotRequestMessageState } from '@/lib/bot/message-state'; import type { PlatformIntegration, User } from '@kilocode/db'; import type { Message, Thread } from 'chat'; import { captureException } from '@sentry/nextjs'; +import { bot } from '@/lib/bot'; export async function processLinkedMessage({ thread, @@ -30,6 +32,12 @@ export async function processLinkedMessage({ userMessage: message.text, modelUsed: undefined, }); + await storeBotRequestMessageState({ + state: bot.getState(), + botRequestId, + thread, + message, + }); } catch (error) { captureException(error, { tags: { component: 'kilo-bot', op: 'create-bot-request' },