Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions agents/src/inference/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export {
LLMStream,
type ChatCompletionOptions,
type GatewayOptions,
type InferenceClass,
type InferenceLLMOptions,
type LLMModels,
type XAIModels,
Expand Down
107 changes: 107 additions & 0 deletions agents/src/inference/llm.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// SPDX-FileCopyrightText: 2026 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import { beforeAll, describe, expect, it } from 'vitest';
import { ChatContext } from '../llm/index.js';
import { initializeLogger } from '../log.js';
import { type InferenceClass, LLM } from './llm.js';

beforeAll(() => {
  // Keep the suite's output quiet — these tests only inspect headers.
  initializeLogger({ pretty: false, level: 'silent' });
});

/** Header map captured from the stubbed OpenAI client call. */
type CapturedHeaders = { [header: string]: string };

/**
* Build an LLM, stub its OpenAI client's chat.completions.create, start a chat
* stream with the given per-call value, drain the stream, and return the headers
* that were passed to the create call.
*/
async function captureHeaders(opts: {
ctor?: InferenceClass;
perCall?: InferenceClass;
}): Promise<CapturedHeaders> {
const llm = new LLM({
model: 'openai/gpt-4o-mini',
apiKey: 'test-key',
apiSecret: 'test-secret',
baseURL: 'https://example.livekit.cloud',
inferenceClass: opts.ctor,
});

let capturedHeaders: CapturedHeaders = {};

const stub = async (_body: unknown, options?: unknown) => {
capturedHeaders = (options as { headers?: CapturedHeaders } | undefined)?.headers ?? {};
return {
[Symbol.asyncIterator]() {
return {
next: async () => ({ done: true as const, value: undefined }),
};
},
};
};

const internal = llm as unknown as {
client: { chat: { completions: { create: typeof stub } } };
};
internal.client.chat.completions.create = stub;

const stream = llm.chat({
chatCtx: new ChatContext(),
inferenceClass: opts.perCall,
});

// Drain the stream so run() completes and headers get captured.
for await (const _chunk of stream) {
// no-op — stub yields zero chunks
void _chunk;
}

return capturedHeaders;
}

describe('inference.LLM X-LiveKit-Inference-Priority header', () => {
// --- no value anywhere ---

it('omits the header when neither constructor nor chat() sets inferenceClass', async () => {
const headers = await captureHeaders({});
expect(headers['X-LiveKit-Inference-Priority']).toBeUndefined();
});

// --- constructor-only ---

it("uses constructor 'priority' when chat() does not override", async () => {
const headers = await captureHeaders({ ctor: 'priority' });
expect(headers['X-LiveKit-Inference-Priority']).toBe('priority');
});

it("uses constructor 'standard' when chat() does not override", async () => {
const headers = await captureHeaders({ ctor: 'standard' });
expect(headers['X-LiveKit-Inference-Priority']).toBe('standard');
});

// --- per-call-only ---

it("uses per-call 'priority' when no constructor default is set", async () => {
const headers = await captureHeaders({ perCall: 'priority' });
expect(headers['X-LiveKit-Inference-Priority']).toBe('priority');
});

it("uses per-call 'standard' when no constructor default is set", async () => {
const headers = await captureHeaders({ perCall: 'standard' });
expect(headers['X-LiveKit-Inference-Priority']).toBe('standard');
});

// --- per-call overrides constructor ---

it("per-call 'standard' overrides constructor 'priority'", async () => {
const headers = await captureHeaders({ ctor: 'priority', perCall: 'standard' });
expect(headers['X-LiveKit-Inference-Priority']).toBe('standard');
});

it("per-call 'priority' overrides constructor 'standard'", async () => {
const headers = await captureHeaders({ ctor: 'standard', perCall: 'priority' });
expect(headers['X-LiveKit-Inference-Priority']).toBe('priority');
});
});
23 changes: 22 additions & 1 deletion agents/src/inference/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { DEFAULT_API_CONNECT_OPTIONS } from '../types.js';
import { type Expand, toError } from '../utils.js';
import {
type AnyString,
INFERENCE_PRIORITY_HEADER,
INFERENCE_PROVIDER_HEADER,
buildMetadataHeaders,
createAccessToken,
getDefaultInferenceUrl,
Expand Down Expand Up @@ -150,6 +152,8 @@ function dropUnsupportedParams(
return result;
}

/**
 * Inference routing class sent to the Agent Gateway via the
 * X-LiveKit-Inference-Priority header. Can be set as a constructor default
 * on {@link LLM} and overridden per call in chat().
 */
export type InferenceClass = 'priority' | 'standard';

export interface InferenceLLMOptions {
model: LLMModels;
provider?: string;
Expand All @@ -158,6 +162,7 @@ export interface InferenceLLMOptions {
apiSecret: string;
modelOptions: ChatCompletionOptions;
strictToolSchema?: boolean;
inferenceClass?: InferenceClass;
}

export interface GatewayOptions {
Expand All @@ -180,6 +185,7 @@ export class LLM extends llm.LLM {
apiSecret?: string;
modelOptions?: InferenceLLMOptions['modelOptions'];
strictToolSchema?: boolean;
inferenceClass?: InferenceClass;
}) {
super();

Expand All @@ -191,6 +197,7 @@ export class LLM extends llm.LLM {
apiSecret,
modelOptions,
strictToolSchema = false,
inferenceClass,
} = opts;

const lkBaseURL = baseURL || getDefaultInferenceUrl();
Expand All @@ -213,6 +220,7 @@ export class LLM extends llm.LLM {
apiSecret: lkApiSecret,
modelOptions: modelOptions || {},
strictToolSchema,
inferenceClass,
};

this.client = new OpenAI({
Expand Down Expand Up @@ -243,6 +251,7 @@ export class LLM extends llm.LLM {
connOptions = DEFAULT_API_CONNECT_OPTIONS,
parallelToolCalls,
toolChoice,
inferenceClass,
// TODO(AJS-270): Add response_format parameter support
extraKwargs,
}: {
Expand All @@ -251,6 +260,7 @@ export class LLM extends llm.LLM {
connOptions?: APIConnectOptions;
parallelToolCalls?: boolean;
toolChoice?: llm.ToolChoice;
inferenceClass?: InferenceClass;
// TODO(AJS-270): Add responseFormat parameter
extraKwargs?: Record<string, unknown>;
}): LLMStream {
Expand All @@ -274,6 +284,9 @@ export class LLM extends llm.LLM {
modelOptions.tool_choice = toolChoice as ToolChoice;
}

const resolvedInferenceClass =
inferenceClass !== undefined ? inferenceClass : this.opts.inferenceClass;

// TODO(AJS-270): Add response_format support here

modelOptions = { ...modelOptions, ...this.opts.modelOptions };
Expand All @@ -291,6 +304,7 @@ export class LLM extends llm.LLM {
apiKey: this.opts.apiKey,
apiSecret: this.opts.apiSecret,
},
inferenceClass: resolvedInferenceClass,
});
}
}
Expand All @@ -302,6 +316,7 @@ export class LLMStream extends llm.LLMStream {
private client: OpenAI;
private modelOptions: Record<string, unknown>;
private strictToolSchema: boolean;
private inferenceClass?: InferenceClass;

private gatewayOptions?: GatewayOptions;
private toolCallId?: string;
Expand All @@ -323,6 +338,7 @@ export class LLMStream extends llm.LLMStream {
modelOptions,
providerFmt,
strictToolSchema,
inferenceClass,
}: {
model: LLMModels;
provider?: string;
Expand All @@ -334,6 +350,7 @@ export class LLMStream extends llm.LLMStream {
modelOptions: Record<string, unknown>;
providerFmt?: llm.ProviderFormat;
strictToolSchema: boolean;
inferenceClass?: InferenceClass;
},
) {
super(llm, { chatCtx, toolCtx, connOptions });
Expand All @@ -344,6 +361,7 @@ export class LLMStream extends llm.LLMStream {
this.modelOptions = modelOptions;
this.model = model;
this.strictToolSchema = strictToolSchema;
this.inferenceClass = inferenceClass;
}

protected async run(): Promise<void> {
Expand Down Expand Up @@ -403,7 +421,10 @@ export class LLMStream extends llm.LLMStream {
...((requestOptions.extra_headers as Record<string, string> | undefined) ?? {}),
};
if (this.provider) {
extraHeaders['X-LiveKit-Inference-Provider'] = this.provider;
extraHeaders[INFERENCE_PROVIDER_HEADER] = this.provider;
}
if (this.inferenceClass !== undefined) {
extraHeaders[INFERENCE_PRIORITY_HEADER] = this.inferenceClass;
}
delete requestOptions.extra_headers;

Expand Down
4 changes: 4 additions & 0 deletions agents/src/inference/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ export const DEFAULT_INFERENCE_URL = 'https://agent-gateway.livekit.cloud/v1';
/** Staging inference URL */
export const STAGING_INFERENCE_URL = 'https://agent-gateway.staging.livekit.cloud/v1';

/** LiveKit Agent Gateway routing header names. */
// Names the upstream provider the gateway should route the request to.
export const INFERENCE_PROVIDER_HEADER = 'X-LiveKit-Inference-Provider';
// Carries the request's inference class ('priority' | 'standard'); omitted when unset.
export const INFERENCE_PRIORITY_HEADER = 'X-LiveKit-Inference-Priority';

/**
* Get the default inference URL based on the environment.
*
Expand Down
Loading