Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 31 additions & 47 deletions packages/kilo-vscode/src/KiloProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ import {
} from "./kilo-provider/handlers/question"
import { fetchAndSendPendingSuggestions, routeSuggestionWebviewMessage } from "./kilo-provider/handlers/suggestion"
import { nativeTitle } from "./kilo-provider/native-tab-title"
import * as Preferences from "./kilo-provider/model-preferences"
import { handleEnhancePrompt } from "./kilo-provider/enhance-prompt"

import {
buildActionContext,
computeDefaultSelection,
fetchProviderData,
validateRecents,
validateFavorites,
connectProvider as connectProviderAction,
authorizeProviderOAuth as authorizeOAuthAction,
completeProviderOAuth as completeOAuthAction,
Expand Down Expand Up @@ -319,6 +319,25 @@ export class KiloProvider implements vscode.WebviewViewProvider, TelemetryProper
}
}

private get preferenceCtx(): Parameters<typeof Preferences.persistVariant>[0] {
return {
extensionContext: this.extensionContext,
postMessage: (msg) => this.postMessage(msg),
notifyFavoritesChanged: (favorites) => this.connectionService.notifyFavoritesChanged(favorites),
}
}

private get enhancePromptCtx(): Parameters<typeof handleEnhancePrompt>[0] {
return {
client: this.client,
postMessage: (msg) => this.postMessage(msg),
getErrorMessage,
showErrorMessage: (msg) => {
void vscode.window.showErrorMessage(msg)
},
}
}

// Strip edit-tool metadata.filediff.before/after (multi-MB for edit-heavy
// sessions) to keep session switches fast. Logic in kilo-provider/slim-metadata.ts.
private slimPart<T>(part: T): T {
Expand Down Expand Up @@ -971,41 +990,29 @@ export class KiloProvider implements vscode.WebviewViewProvider, TelemetryProper
TelemetryProxy.capture(message.event, message.properties)
break
case "persistVariant": {
const stored = this.extensionContext?.globalState.get<Record<string, string>>("variantSelections") ?? {}
stored[message.key] = message.value
await this.extensionContext?.globalState.update("variantSelections", stored)
await Preferences.persistVariant(this.preferenceCtx, message.key, message.value)
break
}
case "requestVariants": {
const variants = this.extensionContext?.globalState.get<Record<string, string>>("variantSelections") ?? {}
this.postMessage({ type: "variantsLoaded", variants })
Preferences.requestVariants(this.preferenceCtx)
break
}
case "persistRecents":
await this.extensionContext?.globalState.update("recentModels", validateRecents(message.recents))
await Preferences.persistRecents(this.preferenceCtx, message.recents)
break
case "requestRecents": {
const recents = validateRecents(this.extensionContext?.globalState.get("recentModels"))
this.postMessage({ type: "recentsLoaded", recents })
Preferences.requestRecents(this.preferenceCtx)
break
}
case "toggleFavorite": {
const current = validateFavorites(this.extensionContext?.globalState.get("favoriteModels"))
const key = `${message.providerID}/${message.modelID}`
const exists = current.some((f) => `${f.providerID}/${f.modelID}` === key)
const favorites =
message.action === "add" && !exists
? [...current, { providerID: message.providerID, modelID: message.modelID }]
: message.action === "remove" && exists
? current.filter((f) => `${f.providerID}/${f.modelID}` !== key)
: current
await this.extensionContext?.globalState.update("favoriteModels", favorites)
this.connectionService.notifyFavoritesChanged(favorites)
await Preferences.toggleFavorite(this.preferenceCtx, message.action, {
providerID: message.providerID,
modelID: message.modelID,
})
break
}
case "requestFavorites": {
const favorites = validateFavorites(this.extensionContext?.globalState.get("favoriteModels"))
this.postMessage({ type: "favoritesLoaded", favorites })
Preferences.requestFavorites(this.preferenceCtx)
break
}
// legacy-migration start
Expand All @@ -1026,30 +1033,7 @@ export class KiloProvider implements vscode.WebviewViewProvider, TelemetryProper
break
// legacy-migration end
case "enhancePrompt": {
const sdkClient = this.client
if (!sdkClient) {
this.postMessage({
type: "enhancePromptError",
error: "Not connected to CLI backend",
requestId: message.requestId,
})
break
}
void sdkClient.enhancePrompt
.enhance({ text: message.text }, { throwOnError: true })
.then(({ data }) => {
this.postMessage({ type: "enhancePromptResult", text: data.text, requestId: message.requestId })
})
.catch((err: unknown) => {
const msg = getErrorMessage(err) || "Failed to enhance prompt"
console.error("[Kilo New] KiloProvider: Failed to enhance prompt:", err)
vscode.window.showErrorMessage(`Enhance prompt failed: ${msg}`)
this.postMessage({
type: "enhancePromptError",
error: msg,
requestId: message.requestId,
})
})
handleEnhancePrompt(this.enhancePromptCtx, message.text, message.requestId)
break
}
case "fetchMarketplaceData": {
Expand Down
38 changes: 38 additions & 0 deletions packages/kilo-vscode/src/kilo-provider/enhance-prompt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import type { KiloClient } from "@kilocode/sdk/v2/client"

type Post = (msg: unknown) => void

interface Context {
readonly client: KiloClient | null
postMessage: Post
getErrorMessage(error: unknown): string
showErrorMessage(message: string): void
}

export function handleEnhancePrompt(ctx: Context, text: string, requestId: string): void {
const client = ctx.client
if (!client) {
ctx.postMessage({
type: "enhancePromptError",
error: "Not connected to CLI backend",
requestId,
})
return
}

void client.enhancePrompt
.enhance({ text }, { throwOnError: true })
.then(({ data }) => {
ctx.postMessage({ type: "enhancePromptResult", text: data.text, requestId })
})
.catch((err: unknown) => {
const msg = ctx.getErrorMessage(err) || "Failed to enhance prompt"
console.error("[Kilo New] KiloProvider: Failed to enhance prompt:", err)
ctx.showErrorMessage(`Enhance prompt failed: ${msg}`)
ctx.postMessage({
type: "enhancePromptError",
error: msg,
requestId,
})
})
}
76 changes: 76 additions & 0 deletions packages/kilo-vscode/src/kilo-provider/model-preferences.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import type { ExtensionContext } from "vscode"

type Model = { providerID: string; modelID: string }
type Post = (msg: unknown) => void

interface Context {
readonly extensionContext?: ExtensionContext
postMessage: Post
notifyFavoritesChanged(favorites: Model[]): void
}

function selection(value: unknown): value is Model {
return (
!!value &&
typeof value === "object" &&
typeof (value as Record<string, unknown>).providerID === "string" &&
typeof (value as Record<string, unknown>).modelID === "string"
)
}

export function validateRecents(raw: unknown): Model[] {
if (!Array.isArray(raw)) return []
return raw
.filter(selection)
.slice(0, 5)
.map((item) => ({ providerID: item.providerID, modelID: item.modelID }))
}

export function validateFavorites(raw: unknown): Model[] {
if (!Array.isArray(raw)) return []
return raw.filter(selection).map((item) => ({ providerID: item.providerID, modelID: item.modelID }))
}

export function updateFavorites(current: Model[], action: "add" | "remove", model: Model): Model[] {
const key = `${model.providerID}/${model.modelID}`
const exists = current.some((item) => `${item.providerID}/${item.modelID}` === key)
if (action === "add" && !exists) return [...current, model]
if (action === "remove" && exists) return current.filter((item) => `${item.providerID}/${item.modelID}` !== key)
return current
}

export async function persistVariant(ctx: Context, key: string, value: string): Promise<void> {
const stored = ctx.extensionContext?.globalState.get<Record<string, string>>("variantSelections") ?? {}
stored[key] = value
await ctx.extensionContext?.globalState.update("variantSelections", stored)
}

export function requestVariants(ctx: Context): void {
const variants = ctx.extensionContext?.globalState.get<Record<string, string>>("variantSelections") ?? {}
ctx.postMessage({ type: "variantsLoaded", variants })
}

export async function persistRecents(ctx: Context, recents: unknown): Promise<void> {
const valid = validateRecents(recents)
await ctx.extensionContext?.globalState.update("recentModels", valid)
}

export function requestRecents(ctx: Context): void {
const stored = ctx.extensionContext?.globalState.get("recentModels")
const recents = validateRecents(stored)
ctx.postMessage({ type: "recentsLoaded", recents })
}

export async function toggleFavorite(ctx: Context, action: "add" | "remove", model: Model): Promise<void> {
const stored = ctx.extensionContext?.globalState.get("favoriteModels")
const current = validateFavorites(stored)
const favorites = updateFavorites(current, action, model)
await ctx.extensionContext?.globalState.update("favoriteModels", favorites)
ctx.notifyFavoritesChanged(favorites)
}

export function requestFavorites(ctx: Context): void {
const stored = ctx.extensionContext?.globalState.get("favoriteModels")
const favorites = validateFavorites(stored)
ctx.postMessage({ type: "favoritesLoaded", favorites })
}
15 changes: 0 additions & 15 deletions packages/kilo-vscode/src/provider-actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,6 @@ function isModelSelection(r: unknown): r is { providerID: string; modelID: strin
)
}

/** Validate and sanitize recent model selections from untrusted sources. */
export function validateRecents(raw: unknown): Array<{ providerID: string; modelID: string }> {
if (!Array.isArray(raw)) return []
return raw
.filter(isModelSelection)
.slice(0, 5)
.map((r) => ({ providerID: r.providerID, modelID: r.modelID }))
}

/** Validate and sanitize favorite model selections from untrusted sources. */
export function validateFavorites(raw: unknown): Array<{ providerID: string; modelID: string }> {
if (!Array.isArray(raw)) return []
return raw.filter(isModelSelection).map((r) => ({ providerID: r.providerID, modelID: r.modelID }))
}

/** Validate and sanitize per-mode model selections from untrusted sources. */
export function validateModelSelections(raw: unknown): Record<string, { providerID: string; modelID: string }> {
if (!raw || typeof raw !== "object" || Array.isArray(raw)) return {}
Expand Down
83 changes: 83 additions & 0 deletions packages/kilo-vscode/tests/unit/enhance-prompt.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import { describe, expect, it, spyOn } from "bun:test"
import { handleEnhancePrompt } from "../../src/kilo-provider/enhance-prompt"

type Context = Parameters<typeof handleEnhancePrompt>[0]
type Client = NonNullable<Context["client"]>

async function tick() {
await new Promise((resolve) => setTimeout(resolve, 0))
}

function createCtx(client: Client | null, error = "boom") {
const calls = {
posts: [] as unknown[],
errors: [] as string[],
}
const ctx: Context = {
client,
postMessage: (msg) => calls.posts.push(msg),
getErrorMessage: () => error,
showErrorMessage: (msg) => calls.errors.push(msg),
}
return { calls, ctx }
}

describe("handleEnhancePrompt", () => {
it("posts an error when the client is unavailable", () => {
const { calls, ctx } = createCtx(null)

handleEnhancePrompt(ctx, "draft", "req-1")

expect(calls.posts).toEqual([
{
type: "enhancePromptError",
error: "Not connected to CLI backend",
requestId: "req-1",
},
])
expect(calls.errors).toEqual([])
})

it("posts enhanced text on success", async () => {
const client = {
enhancePrompt: {
enhance: async (input: { text?: string }) => ({ data: { text: `better ${input.text}` } }),
},
} as unknown as Client
const { calls, ctx } = createCtx(client)

handleEnhancePrompt(ctx, "draft", "req-2")
await tick()

expect(calls.posts).toEqual([{ type: "enhancePromptResult", text: "better draft", requestId: "req-2" }])
expect(calls.errors).toEqual([])
})

it("posts and shows the error message on failure", async () => {
const log = spyOn(console, "error").mockImplementation(() => {})
try {
const client = {
enhancePrompt: {
enhance: async () => {
throw new Error("server failed")
},
},
} as unknown as Client
const { calls, ctx } = createCtx(client, "server failed")

handleEnhancePrompt(ctx, "draft", "req-3")
await tick()

expect(calls.errors).toEqual(["Enhance prompt failed: server failed"])
expect(calls.posts).toEqual([
{
type: "enhancePromptError",
error: "server failed",
requestId: "req-3",
},
])
} finally {
log.mockRestore()
}
})
})
36 changes: 36 additions & 0 deletions packages/kilo-vscode/tests/unit/model-preferences.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { describe, expect, it } from "bun:test"
import { updateFavorites } from "../../src/kilo-provider/model-preferences"

describe("updateFavorites", () => {
it("adds a favorite that is not present", () => {
const current = [{ providerID: "anthropic", modelID: "claude-sonnet-4" }]

expect(updateFavorites(current, "add", { providerID: "openai", modelID: "gpt-5" })).toEqual([
{ providerID: "anthropic", modelID: "claude-sonnet-4" },
{ providerID: "openai", modelID: "gpt-5" },
])
})

it("does not add a duplicate favorite", () => {
const current = [{ providerID: "openai", modelID: "gpt-5" }]

expect(updateFavorites(current, "add", { providerID: "openai", modelID: "gpt-5" })).toBe(current)
})

it("removes an existing favorite", () => {
const current = [
{ providerID: "anthropic", modelID: "claude-sonnet-4" },
{ providerID: "openai", modelID: "gpt-5" },
]

expect(updateFavorites(current, "remove", { providerID: "openai", modelID: "gpt-5" })).toEqual([
{ providerID: "anthropic", modelID: "claude-sonnet-4" },
])
})

it("keeps favorites unchanged when removing a missing favorite", () => {
const current = [{ providerID: "anthropic", modelID: "claude-sonnet-4" }]

expect(updateFavorites(current, "remove", { providerID: "openai", modelID: "gpt-5" })).toBe(current)
})
})
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { describe, expect, it } from "bun:test"
import { validateModelSelections, validateRecents, validateFavorites } from "../../src/provider-actions"
import { validateModelSelections } from "../../src/provider-actions"

describe("validateModelSelections", () => {
it("returns empty object for null", () => {
Expand Down
Loading