diff --git a/agents/src/inference/index.ts b/agents/src/inference/index.ts index e6704fb14..b6847f59c 100644 --- a/agents/src/inference/index.ts +++ b/agents/src/inference/index.ts @@ -10,6 +10,7 @@ export { LLMStream, type ChatCompletionOptions, type GatewayOptions, + type InferenceClass, type InferenceLLMOptions, type LLMModels, type XAIModels, diff --git a/agents/src/inference/llm.test.ts b/agents/src/inference/llm.test.ts new file mode 100644 index 000000000..0162f2ea4 --- /dev/null +++ b/agents/src/inference/llm.test.ts @@ -0,0 +1,107 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { beforeAll, describe, expect, it } from 'vitest'; +import { ChatContext } from '../llm/index.js'; +import { initializeLogger } from '../log.js'; +import { type InferenceClass, LLM } from './llm.js'; + +beforeAll(() => { + initializeLogger({ level: 'silent', pretty: false }); +}); + +type CapturedHeaders = Record<string, string>; + +/** + * Build an LLM, stub its OpenAI client's chat.completions.create, start a chat + * stream with the given per-call value, drain the stream, and return the headers + * that were passed to the create call. + */ +async function captureHeaders(opts: { + ctor?: InferenceClass; + perCall?: InferenceClass; +}): Promise<CapturedHeaders> { + const llm = new LLM({ + model: 'openai/gpt-4o-mini', + apiKey: 'test-key', + apiSecret: 'test-secret', + baseURL: 'https://example.livekit.cloud', + inferenceClass: opts.ctor, + }); + + let capturedHeaders: CapturedHeaders = {}; + + const stub = async (_body: unknown, options?: unknown) => { + capturedHeaders = (options as { headers?: CapturedHeaders } | undefined)?.headers ?? 
{}; + return { + [Symbol.asyncIterator]() { + return { + next: async () => ({ done: true as const, value: undefined }), + }; + }, + }; + }; + + const internal = llm as unknown as { + client: { chat: { completions: { create: typeof stub } } }; + }; + internal.client.chat.completions.create = stub; + + const stream = llm.chat({ + chatCtx: new ChatContext(), + inferenceClass: opts.perCall, + }); + + // Drain the stream so run() completes and headers get captured. + for await (const _chunk of stream) { + // no-op — stub yields zero chunks + void _chunk; + } + + return capturedHeaders; +} + +describe('inference.LLM X-LiveKit-Inference-Priority header', () => { + // --- no value anywhere --- + + it('omits the header when neither constructor nor chat() sets inferenceClass', async () => { + const headers = await captureHeaders({}); + expect(headers['X-LiveKit-Inference-Priority']).toBeUndefined(); + }); + + // --- constructor-only --- + + it("uses constructor 'priority' when chat() does not override", async () => { + const headers = await captureHeaders({ ctor: 'priority' }); + expect(headers['X-LiveKit-Inference-Priority']).toBe('priority'); + }); + + it("uses constructor 'standard' when chat() does not override", async () => { + const headers = await captureHeaders({ ctor: 'standard' }); + expect(headers['X-LiveKit-Inference-Priority']).toBe('standard'); + }); + + // --- per-call-only --- + + it("uses per-call 'priority' when no constructor default is set", async () => { + const headers = await captureHeaders({ perCall: 'priority' }); + expect(headers['X-LiveKit-Inference-Priority']).toBe('priority'); + }); + + it("uses per-call 'standard' when no constructor default is set", async () => { + const headers = await captureHeaders({ perCall: 'standard' }); + expect(headers['X-LiveKit-Inference-Priority']).toBe('standard'); + }); + + // --- per-call overrides constructor --- + + it("per-call 'standard' overrides constructor 'priority'", async () => { + const headers = await 
captureHeaders({ ctor: 'priority', perCall: 'standard' }); + expect(headers['X-LiveKit-Inference-Priority']).toBe('standard'); + }); + + it("per-call 'priority' overrides constructor 'standard'", async () => { + const headers = await captureHeaders({ ctor: 'standard', perCall: 'priority' }); + expect(headers['X-LiveKit-Inference-Priority']).toBe('priority'); + }); +}); diff --git a/agents/src/inference/llm.ts b/agents/src/inference/llm.ts index 104400a53..1cc55149e 100644 --- a/agents/src/inference/llm.ts +++ b/agents/src/inference/llm.ts @@ -9,6 +9,8 @@ import { DEFAULT_API_CONNECT_OPTIONS } from '../types.js'; import { type Expand, toError } from '../utils.js'; import { type AnyString, + INFERENCE_PRIORITY_HEADER, + INFERENCE_PROVIDER_HEADER, buildMetadataHeaders, createAccessToken, getDefaultInferenceUrl, @@ -150,6 +152,8 @@ function dropUnsupportedParams( return result; } +export type InferenceClass = 'priority' | 'standard'; + export interface InferenceLLMOptions { model: LLMModels; provider?: string; @@ -158,6 +162,7 @@ export interface InferenceLLMOptions { apiSecret: string; modelOptions: ChatCompletionOptions; strictToolSchema?: boolean; + inferenceClass?: InferenceClass; } export interface GatewayOptions { @@ -180,6 +185,7 @@ export class LLM extends llm.LLM { apiSecret?: string; modelOptions?: InferenceLLMOptions['modelOptions']; strictToolSchema?: boolean; + inferenceClass?: InferenceClass; }) { super(); @@ -191,6 +197,7 @@ export class LLM extends llm.LLM { apiSecret, modelOptions, strictToolSchema = false, + inferenceClass, } = opts; const lkBaseURL = baseURL || getDefaultInferenceUrl(); @@ -213,6 +220,7 @@ export class LLM extends llm.LLM { apiSecret: lkApiSecret, modelOptions: modelOptions || {}, strictToolSchema, + inferenceClass, }; this.client = new OpenAI({ @@ -243,6 +251,7 @@ export class LLM extends llm.LLM { connOptions = DEFAULT_API_CONNECT_OPTIONS, parallelToolCalls, toolChoice, + inferenceClass, // TODO(AJS-270): Add response_format 
parameter support extraKwargs, }: { @@ -251,6 +260,7 @@ connOptions?: APIConnectOptions; parallelToolCalls?: boolean; toolChoice?: llm.ToolChoice; + inferenceClass?: InferenceClass; // TODO(AJS-270): Add responseFormat parameter extraKwargs?: Record<string, unknown>; }): LLMStream { @@ -274,6 +284,9 @@ modelOptions.tool_choice = toolChoice as ToolChoice; } + const resolvedInferenceClass = + inferenceClass !== undefined ? inferenceClass : this.opts.inferenceClass; + // TODO(AJS-270): Add response_format support here modelOptions = { ...modelOptions, ...this.opts.modelOptions }; @@ -291,6 +304,7 @@ apiKey: this.opts.apiKey, apiSecret: this.opts.apiSecret, }, + inferenceClass: resolvedInferenceClass, }); } } @@ -302,6 +316,7 @@ private client: OpenAI; private modelOptions: Record<string, unknown>; private strictToolSchema: boolean; + private inferenceClass?: InferenceClass; private gatewayOptions?: GatewayOptions; private toolCallId?: string; @@ -323,6 +338,7 @@ modelOptions, providerFmt, strictToolSchema, + inferenceClass, }: { model: LLMModels; provider?: string; @@ -334,6 +350,7 @@ modelOptions: Record<string, unknown>; providerFmt?: llm.ProviderFormat; strictToolSchema: boolean; + inferenceClass?: InferenceClass; }, ) { super(llm, { chatCtx, toolCtx, connOptions }); @@ -344,6 +361,7 @@ this.modelOptions = modelOptions; this.model = model; this.strictToolSchema = strictToolSchema; + this.inferenceClass = inferenceClass; } protected async run(): Promise<void> { @@ -403,7 +421,10 @@ ...((requestOptions.extra_headers as Record<string, string> | undefined) ?? 
{}), }; if (this.provider) { - extraHeaders['X-LiveKit-Inference-Provider'] = this.provider; + extraHeaders[INFERENCE_PROVIDER_HEADER] = this.provider; + } + if (this.inferenceClass !== undefined) { + extraHeaders[INFERENCE_PRIORITY_HEADER] = this.inferenceClass; } delete requestOptions.extra_headers; diff --git a/agents/src/inference/utils.ts b/agents/src/inference/utils.ts index 782d7e8c9..57c3b2ce1 100644 --- a/agents/src/inference/utils.ts +++ b/agents/src/inference/utils.ts @@ -16,6 +16,10 @@ export const DEFAULT_INFERENCE_URL = 'https://agent-gateway.livekit.cloud/v1'; /** Staging inference URL */ export const STAGING_INFERENCE_URL = 'https://agent-gateway.staging.livekit.cloud/v1'; +/** LiveKit Agent Gateway routing header names. */ +export const INFERENCE_PROVIDER_HEADER = 'X-LiveKit-Inference-Provider'; +export const INFERENCE_PRIORITY_HEADER = 'X-LiveKit-Inference-Priority'; + /** * Get the default inference URL based on the environment. *