diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 6787370460..ad6500058a 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2131,6 +2131,7 @@ class ChatProviderTemplate(TypedDict): "embedding_api_key": { "description": "API Key", "type": "string", + "hint": "使用 vLLM 作为提供商时,请在 API Key 中填写 'vllm' 以启用兼容模式(自动禁用 dimensions 参数)", }, "embedding_api_base": { "description": "API Base URL", @@ -2671,12 +2672,12 @@ class ChatProviderTemplate(TypedDict): "deerflow_assistant_id": { "description": "Assistant ID", "type": "string", - "hint": "DeerFlow 2.0 LangGraph assistant_id,默认为 lead_agent。", + "hint": "LangGraph assistant_id,默认为 lead_agent。", }, "deerflow_model_name": { "description": "模型名称覆盖", "type": "string", - "hint": "可选。覆盖 DeerFlow 默认模型(对应运行时 configurable 的 model_name)。", + "hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。", }, "deerflow_thinking_enabled": { "description": "启用思考模式", @@ -2685,17 +2686,17 @@ class ChatProviderTemplate(TypedDict): "deerflow_plan_mode": { "description": "启用计划模式", "type": "bool", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 is_plan_mode。", + "hint": "对应 DeerFlow 的 is_plan_mode。", }, "deerflow_subagent_enabled": { "description": "启用子智能体", "type": "bool", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 subagent_enabled。", + "hint": "对应 DeerFlow 的 subagent_enabled。", }, "deerflow_max_concurrent_subagents": { "description": "子智能体最大并发数", "type": "int", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。", + "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。", }, "deerflow_recursion_limit": { "description": "递归深度上限", diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index ae531996ae..dab622db08 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -1,8 +1,6 @@ import httpx from openai import AsyncOpenAI - from astrbot import logger - from ..entities import ProviderType from ..provider import EmbeddingProvider from ..register import register_provider_adapter @@ -18,12 +16,14 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.provider_config = provider_config self.provider_settings = provider_settings + proxy = provider_config.get("proxy", "") provider_id = provider_config.get("id", "unknown_id") http_client = None if proxy: logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}") http_client = httpx.AsyncClient(proxy=proxy) + api_base = ( provider_config.get("embedding_api_base", "https://api.openai.com/v1") .strip() @@ -33,7 +33,12 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"): # /v4 see #5699 api_base = api_base + "/v1" + + # [新增] 保存处理后的 api_base 并转换为小写,用于后续特征比对 + self.api_base_normalized = api_base.lower() + logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}") + self.client = AsyncOpenAI( api_key=provider_config.get("embedding_api_key"), base_url=api_base, @@ -41,48 +46,183 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: http_client=http_client, ) self.model = provider_config.get("embedding_model", "text-embedding-3-small") + + # [新增] 运行时状态标记:一旦触发 400 错误将此设为 True + self._is_vllm_detected = False + + def _is_vllm(self) -> bool: + """检测是否是 vLLM(vLLM 不支持 dimensions 参数)""" + # 1. 优先检查运行时已证实的标记 + if self._is_vllm_detected: + return True + + # 2. [核心修改] 检查 API Key 是否为 "vllm" + api_key = self.provider_config.get("embedding_api_key", "") + if api_key and api_key.lower() == "vllm": + logger.info("[OpenAI Embedding] vLLM mode enabled by API Key 'vllm'.") + return True + + # 3. 辅助检查:ID 或 URL 中是否显式包含 "vllm" + provider_id = self.provider_config.get("id", "").lower() + api_base = self.api_base_normalized.lower() + if "vllm" in provider_id or "vllm" in api_base: + logger.info(f"[OpenAI Embedding] Detected vLLM by id/api_base: {provider_id}") + return True + + # 4. 移除对端口 (8000, 8001) 的静态判定,避免误伤其他兼容服务 + return False + + def _mark_as_vllm(self) -> None: + """标记此实例为vLLM(通过运行时错误检测出来的)""" + self._is_vllm_detected = True + logger.info("[OpenAI Embedding] Marked as vLLM (runtime detection via error)") async def get_embedding(self, text: str) -> list[float]: """获取文本的嵌入""" kwargs = self._embedding_kwargs() - embedding = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs, - ) - return embedding.data[0].embedding + try: + embedding = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs, + ) + return embedding.data[0].embedding + except Exception as e: + # 如果包含"matryoshka"或"dimensions"相关的错误,说明vLLM不支持该参数 + # 尝试不带dimensions重试 + error_msg = str(e).lower() + if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get("dimensions"): + logger.warning( + f"[OpenAI Embedding] Detected vLLM dimensions error, retrying without dimensions parameter: {e}" + ) + kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} + try: + embedding = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs_retry, + ) + logger.info( + "[OpenAI Embedding] Successfully retrieved embedding without dimensions parameter, marking as vLLM" + ) + # 标记为vLLM以便后续调用也跳过dimensions + self._mark_as_vllm() + return embedding.data[0].embedding + except Exception as retry_error: + logger.error( + f"[OpenAI Embedding] Retry without dimensions also failed: {retry_error}" + ) + raise retry_error + else: + raise async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" kwargs = self._embedding_kwargs() - embeddings = await self.client.embeddings.create( - input=text, - model=self.model, - **kwargs, - ) - return [item.embedding for item in embeddings.data] + try: + embeddings = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs, + ) + return [item.embedding for item in embeddings.data] + except Exception as e: + # 如果包含"matryoshka"或"dimensions"相关的错误,说明vLLM不支持该参数 + # 尝试不带dimensions重试 + error_msg = str(e).lower() + if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get("dimensions"): + logger.warning( + f"[OpenAI Embedding] Detected vLLM dimensions error in batch mode, retrying without dimensions: {e}" + ) + kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} + try: + embeddings = await self.client.embeddings.create( + input=text, + model=self.model, + **kwargs_retry, + ) + logger.info( + "[OpenAI Embedding] Successfully retrieved batch embeddings without dimensions parameter" + ) + # 标记为vLLM以便后续调用也跳过dimensions + self._mark_as_vllm() + return [item.embedding for item in embeddings.data] + except Exception as retry_error: + logger.error( + f"[OpenAI Embedding] Batch retry without dimensions also failed: {retry_error}" + ) + raise retry_error + else: + raise def _embedding_kwargs(self) -> dict: """构建嵌入请求的可选参数""" kwargs = {} - if "embedding_dimensions" in self.provider_config: + provider_id = self.provider_config.get("id", "unknown") + embedding_dim_config = self.provider_config.get("embedding_dimensions", "") + # 检查是否是vLLM + is_vllm = self._is_vllm() + if is_vllm: + logger.info( + f"[OpenAI Embedding] {provider_id}: Detected vLLM, skipping dimensions parameter (config value: '{embedding_dim_config}')" + ) + return kwargs + # 非vLLM服务(OpenAI等)支持dimensions,读取配置 + if embedding_dim_config and embedding_dim_config != "": try: - kwargs["dimensions"] = int(self.provider_config["embedding_dimensions"]) + dim_value = int(embedding_dim_config) + kwargs["dimensions"] = dim_value + logger.info( + f"[OpenAI Embedding] {provider_id}: Added dimensions parameter: {dim_value}" + ) except (ValueError, TypeError): logger.warning( - f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." + f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', ignored." ) + else: + logger.info( + f"[OpenAI Embedding] {provider_id}: No embedding_dimensions configured, API will use default" + ) return kwargs def get_dim(self) -> int: """获取向量的维度""" - if "embedding_dimensions" in self.provider_config: + provider_id = self.provider_config.get("id", "unknown") + # 首先尝试从config读取 + embedding_dim_config = self.provider_config.get("embedding_dimensions", "") + if embedding_dim_config and embedding_dim_config != "": try: - return int(self.provider_config["embedding_dimensions"]) + dim = int(embedding_dim_config) + if dim > 0: + logger.info( + f"[OpenAI Embedding] {provider_id}: Dimension from config: {dim}" + ) + return dim except (ValueError, TypeError): logger.warning( - f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." + f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', trying model inference" + ) + # config为空或无效时根据模型名推断维度 + # 这样Living Memory可以在自动检测后匹配正确的维度 + model = self.provider_config.get("embedding_model", "").lower() + model_dims = { + "bge-m3": 1024, + "bge-large-en-v1.5": 1024, + "bge-large-zh-v1.5": 1024, + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, + } + for model_key, dim in model_dims.items(): + if model_key in model: + logger.info( + f"[OpenAI Embedding] {provider_id}: Inferred dimension {dim} from model: {model}" ) + return dim + # 无法推断时返回0(Living Memory会检测实际维度) + logger.warning( + f"[OpenAI Embedding] {provider_id}: Could not determine dimension (model: {model}, config: '{embedding_dim_config}')" + ) return 0 async def terminate(self): diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index bcd7e075c7..8087b7bd83 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -906,6 +906,12 @@ async def get_embedding_dim(self): return Response().ok({"embedding_dimensions": dim}).__dict__ except Exception as e: logger.error(traceback.format_exc()) + err_msg = str(e).lower() + # [新增] 识别 vLLM 的特定报错关键字 + if "matryoshka" in err_msg or "dimensions" in err_msg: + logger.info("Detected vLLM specific error, bypassing...") + # 伪造一个成功的响应,告知前端进入"兼容模式" + return Response().ok({"embedding_dimensions": "vLLM-Adaptive"}).__dict__ return Response().error(f"获取嵌入维度失败: {e!s}").__dict__ async def get_provider_source_models(self): diff --git a/dashboard/src/components/shared/AstrBotConfig.vue b/dashboard/src/components/shared/AstrBotConfig.vue index 33273a36c9..9a332fb0e5 100644 --- a/dashboard/src/components/shared/AstrBotConfig.vue +++ b/dashboard/src/components/shared/AstrBotConfig.vue @@ -104,17 +104,17 @@ function saveEditedContent() { async function getEmbeddingDimensions(providerConfig) { if (loadingEmbeddingDim.value) return - loadingEmbeddingDim.value = true try { const response = await axios.post('/api/config/provider/get_embedding_dim', { provider_config: providerConfig }) - if (response.data.status != "error" && response.data.data?.embedding_dimensions) { console.log(response.data.data.embedding_dimensions) - providerConfig.embedding_dimensions = response.data.data.embedding_dimensions + //[已禁用] 不再自动写入配置文件,仅显示提示 + // providerConfig.embedding_dimensions = response.data.data.embedding_dimensions useToast().success("获取成功: " + response.data.data.embedding_dimensions) + useToast().info(`检测到维度: ${response.data.data.embedding_dimensions}。如需保存,请手动填入后点保存。`) } else { useToast().error(response.data.message) } diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 935abb358e..c5bf0889b7 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -1276,7 +1276,8 @@ "hint": "嵌入模型名称。" }, "embedding_api_key": { - "description": "API Key" + "description": "API Key", + "hint": "使用 vLLM 作为提供商时,请在 API Key 中填写 'vllm' 以启用兼容模式(自动禁用 dimensions 参数)" }, "embedding_api_base": { "description": "API Base URL" @@ -1604,26 +1605,26 @@ }, "deerflow_assistant_id": { "description": "Assistant ID", - "hint": "DeerFlow 2.0 LangGraph assistant_id,默认为 lead_agent。" + "hint": "LangGraph assistant_id,默认为 lead_agent。" }, "deerflow_model_name": { "description": "模型名称覆盖", - "hint": "可选。覆盖 DeerFlow 默认模型(对应运行时 configurable 的 model_name)。" + "hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。" }, "deerflow_thinking_enabled": { "description": "启用思考模式" }, "deerflow_plan_mode": { "description": "启用计划模式", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 is_plan_mode。" + "hint": "对应 DeerFlow 的 is_plan_mode。" }, "deerflow_subagent_enabled": { "description": "启用子智能体", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 subagent_enabled。" + "hint": "对应 DeerFlow 的 subagent_enabled。" }, "deerflow_max_concurrent_subagents": { "description": "子智能体最大并发数", - "hint": "对应 DeerFlow 2.0 运行时 configurable 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。" + "hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。" }, "deerflow_recursion_limit": { "description": "递归深度上限",