Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,7 @@ class ChatProviderTemplate(TypedDict):
"embedding_api_key": {
"description": "API Key",
"type": "string",
"hint": "使用 vLLM 作为提供商时,请在 API Key 中填写 'vllm' 以启用兼容模式(自动禁用 dimensions 参数)",
},
"embedding_api_base": {
"description": "API Base URL",
Expand Down Expand Up @@ -2671,12 +2672,12 @@ class ChatProviderTemplate(TypedDict):
"deerflow_assistant_id": {
"description": "Assistant ID",
"type": "string",
"hint": "DeerFlow 2.0 LangGraph assistant_id,默认为 lead_agent。",
"hint": "LangGraph assistant_id,默认为 lead_agent。",
},
"deerflow_model_name": {
"description": "模型名称覆盖",
"type": "string",
"hint": "可选。覆盖 DeerFlow 默认模型(对应运行时 configurable 的 model_name)。",
"hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。",
},
"deerflow_thinking_enabled": {
"description": "启用思考模式",
Expand All @@ -2685,17 +2686,17 @@ class ChatProviderTemplate(TypedDict):
"deerflow_plan_mode": {
"description": "启用计划模式",
"type": "bool",
"hint": "对应 DeerFlow 2.0 运行时 configurable 的 is_plan_mode。",
"hint": "对应 DeerFlow 的 is_plan_mode。",
},
"deerflow_subagent_enabled": {
"description": "启用子智能体",
"type": "bool",
"hint": "对应 DeerFlow 2.0 运行时 configurable 的 subagent_enabled。",
"hint": "对应 DeerFlow 的 subagent_enabled。",
},
"deerflow_max_concurrent_subagents": {
"description": "子智能体最大并发数",
"type": "int",
"hint": "对应 DeerFlow 2.0 运行时 configurable 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。",
"hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。",
},
"deerflow_recursion_limit": {
"description": "递归深度上限",
Expand Down
180 changes: 160 additions & 20 deletions astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import httpx
from openai import AsyncOpenAI

from astrbot import logger

from ..entities import ProviderType
from ..provider import EmbeddingProvider
from ..register import register_provider_adapter
Expand All @@ -18,12 +16,14 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings

proxy = provider_config.get("proxy", "")
provider_id = provider_config.get("id", "unknown_id")
http_client = None
if proxy:
logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}")
http_client = httpx.AsyncClient(proxy=proxy)

api_base = (
provider_config.get("embedding_api_base", "https://api.openai.com/v1")
.strip()
Expand All @@ -33,56 +33,196 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"):
# /v4 see #5699
api_base = api_base + "/v1"

# [新增] 保存处理后的 api_base 并转换为小写,用于后续特征比对
self.api_base_normalized = api_base.lower()

logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}")

self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
base_url=api_base,
timeout=int(provider_config.get("timeout", 20)),
http_client=http_client,
)
self.model = provider_config.get("embedding_model", "text-embedding-3-small")

# [新增] 运行时状态标记:一旦触发 400 错误将此设为 True
self._is_vllm_detected = False

def _is_vllm(self) -> bool:
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改了下_is_vllm的逻辑,现在在OpenAI Embedding提供商的API key这一栏填入“vllm“才会使得dimensions参数不会发给vllm防止报错。优化了之前靠检测端口号的静态判定逻辑

Image

"""检测是否是 vLLM(vLLM 不支持 dimensions 参数)"""
# 1. 优先检查运行时已证实的标记
if self._is_vllm_detected:
return True

# 2. [核心修改] 检查 API Key 是否为 "vllm"
api_key = self.provider_config.get("embedding_api_key", "")
if api_key and api_key.lower() == "vllm":
logger.info("[OpenAI Embedding] vLLM mode enabled by API Key 'vllm'.")
return True

# 3. 辅助检查:ID 或 URL 中是否显式包含 "vllm"
provider_id = self.provider_config.get("id", "").lower()
api_base = self.api_base_normalized.lower()
if "vllm" in provider_id or "vllm" in api_base:
logger.info(f"[OpenAI Embedding] Detected vLLM by id/api_base: {provider_id}")
return True

# 4. 移除对端口 (8000, 8001) 的静态判定,避免误伤其他兼容服务
return False

def _mark_as_vllm(self) -> None:
    """Mark this instance as vLLM (detected at runtime via an API error).

    Once set, subsequent requests skip the ``dimensions`` parameter
    up front instead of relying on the retry path.
    """
    self._is_vllm_detected = True
    logger.info("[OpenAI Embedding] Marked as vLLM (runtime detection via error)")

async def get_embedding(self, text: str) -> list[float]:
    """Get the embedding vector for a single text.

    If the server rejects the ``dimensions`` parameter (vLLM reports
    "matryoshka"/"dimensions" errors), retry once without it and mark
    the backend as vLLM so later calls skip the parameter entirely.

    Raises: whatever the OpenAI client raises when both the initial
    request and the dimensions-free retry fail.
    """
    kwargs = self._embedding_kwargs()
    try:
        embedding = await self.client.embeddings.create(
            input=text,
            model=self.model,
            **kwargs,
        )
        return embedding.data[0].embedding
    except Exception as e:
        # "matryoshka" / "dimensions" in the message means the server does
        # not support the dimensions parameter (observed vLLM behavior);
        # retry without it.
        error_msg = str(e).lower()
        if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get(
            "dimensions"
        ):
            logger.warning(
                f"[OpenAI Embedding] Detected vLLM dimensions error, retrying without dimensions parameter: {e}"
            )
            kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"}
            try:
                embedding = await self.client.embeddings.create(
                    input=text,
                    model=self.model,
                    **kwargs_retry,
                )
                logger.info(
                    "[OpenAI Embedding] Successfully retrieved embedding without dimensions parameter, marking as vLLM"
                )
                # Remember so subsequent calls omit dimensions up front.
                self._mark_as_vllm()
                return embedding.data[0].embedding
            except Exception as retry_error:
                logger.error(
                    f"[OpenAI Embedding] Retry without dimensions also failed: {retry_error}"
                )
                # Bare raise preserves the original traceback
                # (equivalent to `raise retry_error` here, but idiomatic).
                raise
        else:
            raise

async def get_embeddings(self, text: list[str]) -> list[list[float]]:
    """Get embedding vectors for a batch of texts.

    Mirrors get_embedding: on a "matryoshka"/"dimensions" error (vLLM
    rejecting the ``dimensions`` parameter), retry once without it and
    mark the backend as vLLM for subsequent calls.

    Raises: whatever the OpenAI client raises when both the initial
    request and the dimensions-free retry fail.
    """
    kwargs = self._embedding_kwargs()
    try:
        embeddings = await self.client.embeddings.create(
            input=text,
            model=self.model,
            **kwargs,
        )
        return [item.embedding for item in embeddings.data]
    except Exception as e:
        # Same vLLM detection heuristic as the single-text path.
        error_msg = str(e).lower()
        if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get(
            "dimensions"
        ):
            logger.warning(
                f"[OpenAI Embedding] Detected vLLM dimensions error in batch mode, retrying without dimensions: {e}"
            )
            kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"}
            try:
                embeddings = await self.client.embeddings.create(
                    input=text,
                    model=self.model,
                    **kwargs_retry,
                )
                logger.info(
                    "[OpenAI Embedding] Successfully retrieved batch embeddings without dimensions parameter"
                )
                # Remember so subsequent calls omit dimensions up front.
                self._mark_as_vllm()
                return [item.embedding for item in embeddings.data]
            except Exception as retry_error:
                logger.error(
                    f"[OpenAI Embedding] Batch retry without dimensions also failed: {retry_error}"
                )
                # Bare raise preserves the original traceback.
                raise
        else:
            raise

def _embedding_kwargs(self) -> dict:
"""构建嵌入请求的可选参数"""
kwargs = {}
if "embedding_dimensions" in self.provider_config:
provider_id = self.provider_config.get("id", "unknown")
embedding_dim_config = self.provider_config.get("embedding_dimensions", "")
# 检查是否是vLLM
is_vllm = self._is_vllm()
if is_vllm:
logger.info(
f"[OpenAI Embedding] {provider_id}: Detected vLLM, skipping dimensions parameter (config value: '{embedding_dim_config}')"
)
return kwargs
# 非vLLM服务(OpenAI等)支持dimensions,读取配置
if embedding_dim_config and embedding_dim_config != "":
try:
kwargs["dimensions"] = int(self.provider_config["embedding_dimensions"])
dim_value = int(embedding_dim_config)
kwargs["dimensions"] = dim_value
logger.info(
f"[OpenAI Embedding] {provider_id}: Added dimensions parameter: {dim_value}"
)
except (ValueError, TypeError):
logger.warning(
f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored."
f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', ignored."
)
else:
logger.info(
f"[OpenAI Embedding] {provider_id}: No embedding_dimensions configured, API will use default"
)
return kwargs

def get_dim(self) -> int:
    """Return the embedding vector dimension.

    Resolution order:
    1. a valid positive integer in ``embedding_dimensions`` config,
    2. inference from well-known model names,
    3. 0 — meaning "unknown"; the caller (Living Memory) detects the
       actual dimension at runtime.
    """
    provider_id = self.provider_config.get("id", "unknown")
    # 1. Try the explicit config value first.
    embedding_dim_config = self.provider_config.get("embedding_dimensions", "")
    if embedding_dim_config:
        try:
            dim = int(embedding_dim_config)
            if dim > 0:
                logger.info(
                    f"[OpenAI Embedding] {provider_id}: Dimension from config: {dim}"
                )
                return dim
        except (ValueError, TypeError):
            logger.warning(
                f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', trying model inference"
            )
    # 2. Config empty/invalid: infer from the model name so Living Memory
    # can match the correct dimension after auto-detection.
    model = self.provider_config.get("embedding_model", "").lower()
    model_dims = {
        "bge-m3": 1024,
        "bge-large-en-v1.5": 1024,
        "bge-large-zh-v1.5": 1024,
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }
    for model_key, dim in model_dims.items():
        if model_key in model:
            logger.info(
                f"[OpenAI Embedding] {provider_id}: Inferred dimension {dim} from model: {model}"
            )
            return dim
    # 3. Unknown: return 0 (Living Memory probes the real dimension).
    logger.warning(
        f"[OpenAI Embedding] {provider_id}: Could not determine dimension (model: {model}, config: '{embedding_dim_config}')"
    )
    return 0

async def terminate(self):
Expand Down
6 changes: 6 additions & 0 deletions astrbot/dashboard/routes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,12 @@ async def get_embedding_dim(self):
return Response().ok({"embedding_dimensions": dim}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
err_msg = str(e).lower()
# [新增] 识别 vLLM 的特定报错关键字
if "matryoshka" in err_msg or "dimensions" in err_msg:
logger.info("Detected vLLM specific error, bypassing...")
# 伪造一个成功的响应,告知前端进入"兼容模式"
return Response().ok({"embedding_dimensions": "vLLM-Adaptive"}).__dict__
return Response().error(f"获取嵌入维度失败: {e!s}").__dict__

async def get_provider_source_models(self):
Expand Down
6 changes: 3 additions & 3 deletions dashboard/src/components/shared/AstrBotConfig.vue
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,17 @@ function saveEditedContent() {

async function getEmbeddingDimensions(providerConfig) {
if (loadingEmbeddingDim.value) return

loadingEmbeddingDim.value = true
try {
const response = await axios.post('/api/config/provider/get_embedding_dim', {
provider_config: providerConfig
})

if (response.data.status != "error" && response.data.data?.embedding_dimensions) {
console.log(response.data.data.embedding_dimensions)
providerConfig.embedding_dimensions = response.data.data.embedding_dimensions
//[已禁用] 不再自动写入配置文件,仅显示提示
// providerConfig.embedding_dimensions = response.data.data.embedding_dimensions
useToast().success("获取成功: " + response.data.data.embedding_dimensions)
useToast().info(`检测到维度: ${response.data.data.embedding_dimensions}。如需保存,请手动填入后点保存。`)
} else {
useToast().error(response.data.message)
}
Expand Down
13 changes: 7 additions & 6 deletions dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,8 @@
"hint": "嵌入模型名称。"
},
"embedding_api_key": {
"description": "API Key"
"description": "API Key",
"hint": "使用 vLLM 作为提供商时,请在 API Key 中填写 'vllm' 以启用兼容模式(自动禁用 dimensions 参数)"
},
"embedding_api_base": {
"description": "API Base URL"
Expand Down Expand Up @@ -1604,26 +1605,26 @@
},
"deerflow_assistant_id": {
"description": "Assistant ID",
"hint": "DeerFlow 2.0 LangGraph assistant_id,默认为 lead_agent。"
"hint": "LangGraph assistant_id,默认为 lead_agent。"
},
"deerflow_model_name": {
"description": "模型名称覆盖",
"hint": "可选。覆盖 DeerFlow 默认模型(对应运行时 configurable 的 model_name)。"
"hint": "可选。覆盖 DeerFlow 默认模型(对应 runtime context 的 model_name)。"
},
"deerflow_thinking_enabled": {
"description": "启用思考模式"
},
"deerflow_plan_mode": {
"description": "启用计划模式",
"hint": "对应 DeerFlow 2.0 运行时 configurable 的 is_plan_mode。"
"hint": "对应 DeerFlow 的 is_plan_mode。"
},
"deerflow_subagent_enabled": {
"description": "启用子智能体",
"hint": "对应 DeerFlow 2.0 运行时 configurable 的 subagent_enabled。"
"hint": "对应 DeerFlow 的 subagent_enabled。"
},
"deerflow_max_concurrent_subagents": {
"description": "子智能体最大并发数",
"hint": "对应 DeerFlow 2.0 运行时 configurable 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。"
"hint": "对应 DeerFlow 的 max_concurrent_subagents。仅在启用子智能体时生效,默认 3。"
},
"deerflow_recursion_limit": {
"description": "递归深度上限",
Expand Down