-
Notifications
You must be signed in to change notification settings - Fork 10
Model Config: DB-driven validator, seed sarvamai/elevenlabs/google #859
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
9b40427
d7e7169
85d31ec
06f94ef
3bc791a
4be5b8c
45e6b6a
5fa40d2
48d56ef
944efdd
e00df2e
e0f63d4
51e6e01
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| """seed stt/tts model_config rows for google, sarvamai, elevenlabs | ||
|
|
||
| Revision ID: 063 | ||
| Revises: 062 | ||
| Create Date: 2026-05-19 00:00:00.000000 | ||
|
|
||
| """ | ||
|
|
||
| from alembic import op | ||
|
|
||
| # revision identifiers, used by Alembic. | ||
| revision = "063" | ||
| down_revision = "062" | ||
| branch_labels = None | ||
| depends_on = None | ||
|
|
||
|
|
||
| SEEDED_MODELS = [ | ||
| ("google", "gemini-2.5-pro"), | ||
| ("google", "gemini-3.1-pro-preview"), | ||
| ("google", "gemini-3-flash-preview"), | ||
| ("google", "gemini-2.5-flash"), | ||
| ("google", "gemini-2.5-flash-preview-tts"), | ||
| ("google", "gemini-2.5-pro-preview-tts"), | ||
| ("sarvamai", "saaras:v3"), | ||
| ("sarvamai", "bulbul:v3"), | ||
| ("elevenlabs", "scribe_v2"), | ||
| ("elevenlabs", "eleven_v3"), | ||
| ] | ||
|
|
||
|
|
||
| def upgrade(): | ||
| # Re-align identity sequence to MAX(id) so new rows get contiguous ids | ||
| # even if dev/test DBs drifted from manual inserts/deletes. | ||
| op.execute( | ||
| "SELECT setval(pg_get_serial_sequence('global.model_config', 'id'), " | ||
| "(SELECT COALESCE(MAX(id), 1) FROM global.model_config))" | ||
| ) | ||
|
|
||
| op.execute( | ||
| """ | ||
| INSERT INTO global.model_config | ||
| (provider, model_name, config, input_modalities, output_modalities, pricing, is_active, inserted_at, updated_at) | ||
| VALUES | ||
| ('google', 'gemini-2.5-pro', '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), | ||
| ('google', 'gemini-3.1-pro-preview', '{"thinking_level": {"type": "enum", "default": "high", "options": ["low", "medium", "high"], "description": "Max reasoning depth before output. high = best quality, low = faster/cheaper."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), | ||
| ('google', 'gemini-3-flash-preview', '{"thinking_level": {"type": "enum", "default": "high", "options": ["minimal", "low", "medium", "high"], "description": "Max reasoning depth before output."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), | ||
| ('google', 'gemini-2.5-flash', '{"temperature": {"type": "float", "default": 1.0, "min": 0.0, "max": 2.0, "description": "Controls randomness. Lower = more deterministic."}}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), | ||
| ('google', 'gemini-2.5-flash-preview-tts', '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), | ||
| ('google', 'gemini-2.5-pro-preview-tts', '{"voice": {"type": "enum", "default": "Kore", "options": ["Kore", "Orus", "Leda", "Charon"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), | ||
| ('sarvamai', 'saaras:v3', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), | ||
| ('sarvamai', 'bulbul:v3', '{"voice": {"type": "enum", "default": "simran", "options": ["simran", "shubh", "roopa"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()), | ||
| ('elevenlabs', 'scribe_v2', '{}', '{AUDIO}', '{TEXT}', NULL, true, NOW(), NOW()), | ||
| ('elevenlabs', 'eleven_v3', '{"voice": {"type": "enum", "default": "Sarah", "options": ["Sarah", "George", "Callum", "Liam"], "description": "TTS voice."}}', '{TEXT}', '{AUDIO}', NULL, true, NOW(), NOW()) | ||
| ON CONFLICT (provider, model_name) DO NOTHING | ||
| """ | ||
| ) | ||
|
|
||
| # Keep sequence in sync after insert | ||
| op.execute( | ||
| "SELECT setval(pg_get_serial_sequence('global.model_config', 'id'), " | ||
| "(SELECT MAX(id) FROM global.model_config))" | ||
| ) | ||
|
|
||
|
|
||
| def downgrade(): | ||
| op.execute( | ||
| """ | ||
| DELETE FROM global.model_config | ||
| WHERE (provider, model_name) IN ( | ||
| ('google', 'gemini-2.5-pro'), | ||
| ('google', 'gemini-3.1-pro-preview'), | ||
| ('google', 'gemini-3-flash-preview'), | ||
| ('google', 'gemini-2.5-flash'), | ||
| ('google', 'gemini-2.5-flash-preview-tts'), | ||
| ('google', 'gemini-2.5-pro-preview-tts'), | ||
| ('sarvamai', 'saaras:v3'), | ||
| ('sarvamai', 'bulbul:v3'), | ||
| ('elevenlabs', 'scribe_v2'), | ||
| ('elevenlabs', 'eleven_v3') | ||
| ) | ||
| """ | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,24 @@ | ||
| from typing import Any, Literal | ||
|
|
||
| from fastapi import HTTPException | ||
| from sqlalchemy.dialects.postgresql import ARRAY | ||
| from sqlalchemy.sql import sqltypes | ||
| from sqlmodel import Session, select | ||
|
|
||
| from app.models import ModelConfig | ||
|
|
||
| Provider = Literal["openai", "google", "sarvamai", "elevenlabs"] | ||
| CompletionType = Literal["text", "stt", "tts"] | ||
|
|
||
|
|
||
| def _normalize_provider(raw: str) -> str: | ||
| """Map NativeCompletionConfig providers (e.g. 'openai-native') to model_config provider names.""" | ||
| return raw[: -len("-native")] if raw.endswith("-native") else raw | ||
|
|
||
|
|
||
| def list_active_model_configs( | ||
| session: Session, | ||
| provider: Literal["openai", "google"] | None = None, | ||
| provider: Provider | None = None, | ||
| skip: int = 0, | ||
| limit: int = 100, | ||
| ) -> tuple[list[ModelConfig], bool]: | ||
|
|
@@ -30,7 +41,7 @@ def list_active_model_configs( | |
|
|
||
| def list_all_active_model_configs( | ||
| session: Session, | ||
| provider: Literal["openai", "google"] | None = None, | ||
| provider: Provider | None = None, | ||
| ) -> list[ModelConfig]: | ||
| statement = select(ModelConfig).where(ModelConfig.is_active) | ||
|
|
||
|
|
@@ -42,7 +53,7 @@ def list_all_active_model_configs( | |
|
|
||
|
|
||
| def get_model_config( | ||
| session: Session, provider: Literal["openai", "google"], model_name: str | ||
| session: Session, provider: Provider, model_name: str | ||
| ) -> ModelConfig | None: | ||
| statement = select(ModelConfig).where( | ||
| ModelConfig.provider == provider, | ||
|
|
@@ -52,9 +63,127 @@ def get_model_config( | |
| return session.exec(statement).first() | ||
|
|
||
|
|
||
| def is_reasoning_model( | ||
| session: Session, provider: Literal["openai", "google"], model_name: str | ||
| def _modality_filter(stmt, completion_type: CompletionType): | ||
| """Restrict query to models matching the completion type via modalities.""" | ||
| str_array = ARRAY(sqltypes.String) | ||
| input_col = ModelConfig.input_modalities | ||
| output_col = ModelConfig.output_modalities | ||
|
|
||
| if completion_type == "stt": | ||
| return stmt.where( | ||
| input_col.cast(str_array).contains(["AUDIO"]), | ||
| output_col.cast(str_array).contains(["TEXT"]), | ||
| ) | ||
| if completion_type == "tts": | ||
| return stmt.where( | ||
| input_col.cast(str_array).contains(["TEXT"]), | ||
| output_col.cast(str_array).contains(["AUDIO"]), | ||
| ) | ||
| # text: must produce TEXT and not consume/produce AUDIO | ||
| return stmt.where( | ||
| output_col.cast(str_array).contains(["TEXT"]), | ||
| ~input_col.cast(str_array).contains(["AUDIO"]), | ||
| ~output_col.cast(str_array).contains(["AUDIO"]), | ||
| ) | ||
|
|
||
|
|
||
| def list_supported_models( | ||
| session: Session, provider: Provider, completion_type: CompletionType | ||
| ) -> list[str]: | ||
| """Return active model names for a provider+completion type.""" | ||
| stmt = select(ModelConfig.model_name).where( | ||
| ModelConfig.provider == provider, | ||
| ModelConfig.is_active, | ||
| ) | ||
| stmt = _modality_filter(stmt, completion_type) | ||
| return list(session.exec(stmt).all()) | ||
|
Comment on lines
+69
to
+74
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Order the supported-model list before surfacing it. This list is returned directly in the 400 detail on Lines 141-144. Without an Proposed fix stmt = select(ModelConfig.model_name).where(
ModelConfig.provider == provider,
ModelConfig.completion_type == completion_type,
ModelConfig.is_active,
)
+ stmt = stmt.order_by(ModelConfig.model_name)
return list(session.exec(stmt).all())🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| def is_model_supported( | ||
| session: Session, | ||
| provider: Provider, | ||
| completion_type: CompletionType, | ||
| model_name: str, | ||
| ) -> bool: | ||
| """Check whether (provider, model_name) is active and matches the completion type.""" | ||
| stmt = select(ModelConfig.id).where( | ||
| ModelConfig.provider == provider, | ||
| ModelConfig.model_name == model_name, | ||
| ModelConfig.is_active, | ||
| ) | ||
| stmt = _modality_filter(stmt, completion_type) | ||
| return session.exec(stmt).first() is not None | ||
|
|
||
|
|
||
| def validate_blob_model_or_raise(session: Session, blob: Any) -> None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inline-blob path on /llm/call and /llm/chain no longer validates model/voice. Before this PR, model/voice checks lived in KaapiCompletionConfig.validate_params, so FastAPI ran them on every request body that contained a ConfigBlob — including ad-hoc inline blobs sent to /llm/call. This PR moves the check into validate_blob_model_or_raise (which needs a DB session) and only wires it into ConfigCrud.create_or_raise and ConfigVersionCrud.create_or_raise. The inline-blob branch in services/llm/jobs.py:525 (config_blob = config.blob) never calls it, so a client can now POST {"provider": "google", "type": "tts", "params": {"model": "gemini-99-ultra", "voice": "Nonexistent"}} and the request will be accepted, a job row created, and the failure surfaces deep in the worker instead of as a 4xx at request time |
||
| """Reject ConfigBlob whose completion.params.model is not in model_config. | ||
|
|
||
| Native configs forward raw provider params; we still expect a `model` key | ||
| in params for text/stt/tts. Missing model is treated as a validation error. | ||
| """ | ||
| completion = blob.completion | ||
| raw_provider = completion.provider | ||
| completion_type = completion.type | ||
| if raw_provider is None: | ||
| return | ||
|
|
||
| if raw_provider.endswith("-native"): | ||
| return | ||
|
|
||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| provider = _normalize_provider(raw_provider) | ||
| model_name = (completion.params or {}).get("model") | ||
| if not model_name: | ||
| raise HTTPException( | ||
| status_code=400, | ||
| detail=f"completion.params.model is required for provider='{raw_provider}'", | ||
| ) | ||
|
|
||
| model_row = get_model_config( | ||
| session=session, | ||
| provider=provider, # type: ignore[arg-type] | ||
| model_name=model_name, | ||
| ) | ||
| if model_row is None or not is_model_supported( | ||
| session=session, | ||
| provider=provider, # type: ignore[arg-type] | ||
| completion_type=completion_type, | ||
| model_name=model_name, | ||
| ): | ||
| allowed = list_supported_models( | ||
| session=session, | ||
| provider=provider, # type: ignore[arg-type] | ||
| completion_type=completion_type, | ||
| ) | ||
| raise HTTPException( | ||
| status_code=400, | ||
| detail=( | ||
| f"Model '{model_name}' is not supported for provider='{provider}' " | ||
| f"type='{completion_type}'. Allowed: {allowed}" | ||
| ), | ||
| ) | ||
|
|
||
| # TTS voice check: voice must match options declared in model_config.config.voice | ||
| if completion_type == "tts": | ||
| voice = (completion.params or {}).get("voice") | ||
| voice_spec = ( | ||
| model_row.config.get("voice") | ||
| if isinstance(model_row.config, dict) | ||
| else None | ||
| ) | ||
| allowed_voices = ( | ||
| voice_spec.get("options") if isinstance(voice_spec, dict) else None | ||
| ) | ||
| if voice and allowed_voices and voice not in allowed_voices: | ||
| raise HTTPException( | ||
| status_code=400, | ||
| detail=( | ||
| f"Voice '{voice}' is not supported for provider='{provider}' " | ||
| f"model='{model_name}'. Allowed: {allowed_voices}" | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| def is_reasoning_model(session: Session, provider: Provider, model_name: str) -> bool: | ||
| """Return True if the model is configured with a reasoning `effort` control. | ||
|
|
||
| A model is considered reasoning-capable if its `config` JSON contains an | ||
|
|
@@ -69,7 +198,7 @@ def is_reasoning_model( | |
|
|
||
| def estimate_model_cost( | ||
| session: Session, | ||
| provider: Literal["openai", "google"], | ||
| provider: Provider, | ||
| model_name: str, | ||
| input_tokens: int, | ||
| output_tokens: int, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.