Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ import {
} from '@/lib/bot/request-logging';
import { parseBotCallbackStep } from '@/lib/bot/step-budget';
import { runBotAgent, type BotAgentMessageLike } from '@/lib/bot/agent-runner';
import { getRehydratedBotRequestMessageState } from '@/lib/bot/message-state';
import { botPlatforms } from '@/lib/bot/platforms';
import { getPlatformIntegrationById } from '@/lib/bot/platform-helpers';
import { findUserById } from '@/lib/user';
import type { Thread } from 'chat';
import type { Message, Thread } from 'chat';

type ExecutionCallbackPayload = {
sessionId: string;
Expand Down Expand Up @@ -183,6 +184,7 @@ async function continueBotAgentAfterCallback(params: {
requestRow: Awaited<ReturnType<typeof getBotRequest>>;
platformIntegration: PlatformIntegration;
thread: Thread;
message: Message;
continuationPrompt: string;
completedStepCount: number;
}) {
Expand All @@ -195,23 +197,8 @@ async function continueBotAgentAfterCallback(params: {
return await botPlatforms.require(params.platformIntegration.platform).withAuthContext({
platformIntegration: params.platformIntegration,
fn: async () => {
const originalMessage = await Promise.resolve(
params.thread.adapter.fetchMessage?.(
params.thread.id,
params.requestRow.platform_message_id
) ?? null
).catch(error => {
console.warn('[BotSessionCallback] Failed to fetch original platform message:', {
error,
platform: params.platformIntegration.platform,
threadId: params.thread.id,
messageId: params.requestRow.platform_message_id,
});
return null;
});

const callbackMessage: BotAgentMessageLike = {
author: originalMessage?.author ?? {
author: params.message.author ?? {
fullName: 'Cloud Agent Callback',
isBot: false,
isMe: false,
Expand Down Expand Up @@ -402,6 +389,7 @@ async function handleCompletedCallback(
requestRow: NonNullable<Awaited<ReturnType<typeof getBotRequest>>>,
platformIntegration: PlatformIntegration,
thread: Thread,
message: Message,
completedStepCount: number,
trackedCallbackSession: BotRequestCloudAgentSession | undefined
) {
Expand Down Expand Up @@ -642,6 +630,7 @@ ${cloudAgentResultsForPrompt}`;
requestRow,
platformIntegration,
thread,
message,
continuationPrompt,
completedStepCount,
});
Expand Down Expand Up @@ -882,7 +871,43 @@ export async function POST(
requestRow.platform_integration_id
);
await bot.initialize();
const thread = bot.thread(requestRow.platform_thread_id);
let rehydrated: Awaited<ReturnType<typeof getRehydratedBotRequestMessageState>>;
try {
rehydrated = await getRehydratedBotRequestMessageState(botRequestId);
} catch (error) {
captureException(error, {
tags: {
source: 'bot-session-callback-api',
op: 'rehydrate-message-state',
},
extra: {
botRequestId,
callbackSessionId,
status: payload.status,
},
});
const updated = await failBotRequest({
botRequestId,
errorMessage:
'Cloud Agent callback could not recover message state for this bot request.',
responseTimeMs: Date.now() - startedAt,
});
logCallback(
'Failed to rehydrate message state for callback; marked bot request as error',
{
botRequestId,
updated: Boolean(updated),
}
);
return;
}
const { thread, message } = rehydrated;

logCallback('Resolved callback chat context', {
botRequestId,
threadId: thread.id,
messageId: message?.id,
});

if (childSessionStatus && trackedCallbackSession) {
try {
Expand Down Expand Up @@ -934,6 +959,7 @@ export async function POST(
requestRow,
platformIntegration,
thread,
message,
completedStepCount,
trackedCallbackSession
);
Expand Down
39 changes: 1 addition & 38 deletions apps/web/src/lib/bot-identity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import * as z from 'zod';
import { NEXTAUTH_SECRET } from '@/lib/config.server';
import { botIdentityRedisKey } from '@/lib/redis-keys';
import { PLATFORM } from '@/lib/integrations/core/constants';
import { serializedMessageSchema, serializedThreadSchema } from '@/lib/bot/message-state';

const CHAT_SDK_CACHE_KEY_PREFIX = 'chat-sdk:cache:';
const LINK_ACCOUNT_CONTEXT_KEY_PREFIX = 'link-account-context:';
Expand Down Expand Up @@ -152,44 +153,6 @@ const platformIdentitySchema = z.object({
userId: z.string(),
});

const serializedThreadShape = z.looseObject({
_type: z.literal('chat:Thread'),
adapterName: z.string(),
channelId: z.string(),
id: z.string(),
isDM: z.boolean(),
});

const serializedThreadSchema = z.custom<SerializedThread>(
value => serializedThreadShape.safeParse(value).success
);

const serializedMessageShape = z.looseObject({
_type: z.literal('chat:Message'),
attachments: z.array(z.unknown()),
author: z.object({
userId: z.string(),
userName: z.string(),
fullName: z.string(),
isBot: z.union([z.boolean(), z.literal('unknown')]),
isMe: z.boolean(),
}),
formatted: z.unknown(),
id: z.string(),
metadata: z.object({
dateSent: z.iso.datetime(),
edited: z.boolean(),
editedAt: z.iso.datetime().optional(),
}),
raw: z.unknown(),
text: z.string(),
threadId: z.string(),
});

const serializedMessageSchema = z.custom<SerializedMessage>(
value => serializedMessageShape.safeParse(value).success
);

const linkTokenPayloadSchema = z.object({
identity: platformIdentitySchema,
contextKey: z.string(),
Expand Down
8 changes: 5 additions & 3 deletions apps/web/src/lib/bot.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ declare global {

globalThis.__botTestMentionHandler ??= null;

let mockState: { kind: string } | undefined;

function getMockState() {
return { kind: 'state' };
mockState ??= { kind: 'state' };
return mockState;
}

function getMockSlackAdapter() {
Expand Down Expand Up @@ -188,8 +191,6 @@ const mockedCanKiloUserAccessPlatformIntegration = jest.mocked(
const mockedGetPlatformIntegration = jest.mocked(getPlatformIntegration);
const mockedFindUserById = jest.mocked(findUserById);
const mockedProcessLinkedMessage = jest.mocked(processLinkedMessage);
const mockState = getMockState();

function makeThread() {
return { id: 'thread-1', adapter: { name: 'slack' } };
}
Expand Down Expand Up @@ -259,6 +260,7 @@ describe('bot mention authorization', () => {
message,
platformIntegration: integration,
user,
state: mockState,
});
});
});
Expand Down
7 changes: 6 additions & 1 deletion apps/web/src/lib/bot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ function createKiloBot(
}

try {
await processLinkedMessage({ thread, message, platformIntegration, user });
await processLinkedMessage({
thread,
message,
platformIntegration,
user,
});
} catch (error) {
console.error('[Bot] Unhandled error in message handler:', error);
await thread.post({ markdown: 'Sorry, something went wrong while processing your message.' });
Expand Down
137 changes: 137 additions & 0 deletions apps/web/src/lib/bot/message-state.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import 'server-only';
import * as z from 'zod';
import { ThreadImpl, Message, type StateAdapter, type Thread } from 'chat';
import type { Root } from 'mdast';
import { bot } from '@/lib/bot';

const BOT_REQUEST_MESSAGE_STATE_KEY_PREFIX = 'bot-request-message-state:';
const BOT_REQUEST_MESSAGE_STATE_TTL_MS = 24 * 60 * 60 * 1000;

type SerializedThread = ReturnType<Thread['toJSON']>;
type SerializedMessage = ReturnType<Message['toJSON']>;

type BotRequestMessageState = {
thread: SerializedThread;
message: SerializedMessage;
};

const serializedThreadShape = z.looseObject({
_type: z.literal('chat:Thread'),
adapterName: z.string(),
channelId: z.string(),
channelVisibility: z.enum(['private', 'workspace', 'external', 'unknown']).optional(),
currentMessage: z.lazy(() => serializedMessageShape).optional(),
id: z.string(),
isDM: z.boolean(),
}) satisfies z.ZodType<SerializedThread>;

const serializedMessageAttachmentShape = z.object({
type: z.enum(['image', 'file', 'video', 'audio']),
url: z.string().optional(),
name: z.string().optional(),
mimeType: z.string().optional(),
size: z.number().optional(),
width: z.number().optional(),
height: z.number().optional(),
fetchMetadata: z.record(z.string(), z.string()).optional(),
});

const serializedMessageLinkShape = z.object({
url: z.string(),
title: z.string().optional(),
description: z.string().optional(),
imageUrl: z.string().optional(),
siteName: z.string().optional(),
});

const formattedContentShape = z.custom<Root>(
value =>
z.object({ type: z.literal('root'), children: z.array(z.unknown()) }).safeParse(value).success
);

export const serializedThreadSchema = z.custom<SerializedThread>(
value => serializedThreadShape.safeParse(value).success
);

const serializedMessageShape = z.looseObject({
_type: z.literal('chat:Message'),
attachments: z.array(serializedMessageAttachmentShape),
author: z.object({
userId: z.string(),
userName: z.string(),
fullName: z.string(),
isBot: z.union([z.boolean(), z.literal('unknown')]),
isMe: z.boolean(),
}),
formatted: formattedContentShape,
id: z.string(),
isMention: z.boolean().optional(),
links: z.array(serializedMessageLinkShape).optional(),
metadata: z.object({
dateSent: z.iso.datetime(),
edited: z.boolean(),
editedAt: z.iso.datetime().optional(),
}),
raw: z.unknown(),
text: z.string(),
threadId: z.string(),
}) satisfies z.ZodType<SerializedMessage>;

export const serializedMessageSchema = z.custom<SerializedMessage>(
value => serializedMessageShape.safeParse(value).success
);

const botRequestMessageStateSchema = z.object({
thread: serializedThreadSchema,
message: serializedMessageSchema,
});

function botRequestMessageStateKey(botRequestId: string): string {
return `${BOT_REQUEST_MESSAGE_STATE_KEY_PREFIX}${botRequestId}`;
}

export async function storeBotRequestMessageState({
state,
botRequestId,
thread,
message,
}: {
state: StateAdapter;
botRequestId: string;
thread: Thread;
message: Message;
}): Promise<void> {
await state.set<BotRequestMessageState>(
botRequestMessageStateKey(botRequestId),
{
thread: thread.toJSON(),
message: message.toJSON(),
},
BOT_REQUEST_MESSAGE_STATE_TTL_MS
);
}

export async function getBotRequestMessageState(
state: StateAdapter,
botRequestId: string
): Promise<BotRequestMessageState | null> {
const value = await state.get<unknown>(botRequestMessageStateKey(botRequestId));
if (!value) {
return null;
}

return botRequestMessageStateSchema.parse(value);
}

export async function getRehydratedBotRequestMessageState(botRequestId: string) {
const stored = await getBotRequestMessageState(bot.getState(), botRequestId);

if (!stored) {
throw new Error('Could not find message state for botRequest ' + botRequestId);
}

return {
thread: ThreadImpl.fromJSON(stored.thread),
message: Message.fromJSON(stored.message),
};
}
8 changes: 8 additions & 0 deletions apps/web/src/lib/bot/run.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { createBotRequest, updateBotRequest } from '@/lib/bot/request-logging';
import { runBotAgent } from '@/lib/bot/agent-runner';
import { extractAndUploadImages } from '@/lib/bot/images';
import { storeBotRequestMessageState } from '@/lib/bot/message-state';
import type { PlatformIntegration, User } from '@kilocode/db';
import type { Message, Thread } from 'chat';
import { captureException } from '@sentry/nextjs';
import { bot } from '@/lib/bot';

export async function processLinkedMessage({
thread,
Expand All @@ -30,6 +32,12 @@ export async function processLinkedMessage({
userMessage: message.text,
modelUsed: undefined,
});
await storeBotRequestMessageState({
Comment thread
RSO marked this conversation as resolved.
state: bot.getState(),
botRequestId,
thread,
message,
});
} catch (error) {
captureException(error, {
tags: { component: 'kilo-bot', op: 'create-bot-request' },
Expand Down
Loading