-
-
Notifications
You must be signed in to change notification settings - Fork 2k
fix(provider): resolve vLLM embedding compatibility and dimension inf… #7509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Creeper3222
wants to merge
3
commits into
AstrBotDevs:master
Choose a base branch
from
Creeper3222:fix/vllm-embedding-compatibility
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,6 @@ | ||
| import httpx | ||
| from openai import AsyncOpenAI | ||
|
|
||
| from astrbot import logger | ||
|
|
||
| from ..entities import ProviderType | ||
| from ..provider import EmbeddingProvider | ||
| from ..register import register_provider_adapter | ||
|
|
@@ -18,12 +16,14 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: | |
| super().__init__(provider_config, provider_settings) | ||
| self.provider_config = provider_config | ||
| self.provider_settings = provider_settings | ||
|
|
||
| proxy = provider_config.get("proxy", "") | ||
| provider_id = provider_config.get("id", "unknown_id") | ||
| http_client = None | ||
| if proxy: | ||
| logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}") | ||
| http_client = httpx.AsyncClient(proxy=proxy) | ||
|
|
||
| api_base = ( | ||
| provider_config.get("embedding_api_base", "https://api.openai.com/v1") | ||
| .strip() | ||
|
|
@@ -33,56 +33,195 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: | |
| if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"): | ||
| # /v4 see #5699 | ||
| api_base = api_base + "/v1" | ||
|
|
||
| # [新增] 保存处理后的 api_base 并转换为小写,用于后续特征比对 | ||
| self.api_base_normalized = api_base.lower() | ||
|
|
||
| logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}") | ||
|
|
||
| self.client = AsyncOpenAI( | ||
| api_key=provider_config.get("embedding_api_key"), | ||
| base_url=api_base, | ||
| timeout=int(provider_config.get("timeout", 20)), | ||
| http_client=http_client, | ||
| ) | ||
| self.model = provider_config.get("embedding_model", "text-embedding-3-small") | ||
|
|
||
| # [新增] 运行时状态标记:一旦触发 400 错误将此设为 True | ||
| self._is_vllm_detected = False | ||
|
|
||
| def _is_vllm(self) -> bool: | ||
| """检测是否是vLLM(vLLM不支持dimensions参数)""" | ||
| # 先检查运行时检测标志 | ||
| if self._is_vllm_detected: | ||
| return True | ||
| # 方法1:检查provider_id是否包含vllm | ||
| provider_id = self.provider_config.get("id", "").lower() | ||
| if "vllm" in provider_id: | ||
| logger.info(f"[OpenAI Embedding] Detected vLLM by provider id: {provider_id}") | ||
| return True | ||
| # 方法2:检查api_base中的特征端口或主机名 | ||
| api_base = self.api_base_normalized.lower() | ||
| if "vllm" in api_base: | ||
| logger.info(f"[OpenAI Embedding] Detected vLLM by api_base keyword") | ||
| return True | ||
| # 方法3:检查常见的vLLM端口(8000, 8001等) | ||
| if ":8000" in api_base or ":8001" in api_base or ":8002" in api_base: | ||
|
sourcery-ai[bot] marked this conversation as resolved.
Outdated
|
||
| logger.info(f"[OpenAI Embedding] Detected vLLM by common port in api_base: {api_base}") | ||
| return True | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| return False | ||
|
|
||
| def _mark_as_vllm(self) -> None: | ||
| """标记此实例为vLLM(通过运行时错误检测出来的)""" | ||
| self._is_vllm_detected = True | ||
| logger.info("[OpenAI Embedding] Marked as vLLM (runtime detection via error)") | ||
|
|
||
| async def get_embedding(self, text: str) -> list[float]: | ||
|
Creeper3222 marked this conversation as resolved.
|
||
| """获取文本的嵌入""" | ||
| kwargs = self._embedding_kwargs() | ||
| embedding = await self.client.embeddings.create( | ||
| input=text, | ||
| model=self.model, | ||
| **kwargs, | ||
| ) | ||
| return embedding.data[0].embedding | ||
| try: | ||
| embedding = await self.client.embeddings.create( | ||
| input=text, | ||
| model=self.model, | ||
| **kwargs, | ||
| ) | ||
| return embedding.data[0].embedding | ||
| except Exception as e: | ||
| # 如果包含"matryoshka"或"dimensions"相关的错误,说明vLLM不支持该参数 | ||
| # 尝试不带dimensions重试 | ||
| error_msg = str(e).lower() | ||
| if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get("dimensions"): | ||
| logger.warning( | ||
| f"[OpenAI Embedding] Detected vLLM dimensions error, retrying without dimensions parameter: {e}" | ||
| ) | ||
| kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} | ||
|
Creeper3222 marked this conversation as resolved.
|
||
| try: | ||
| embedding = await self.client.embeddings.create( | ||
| input=text, | ||
| model=self.model, | ||
| **kwargs_retry, | ||
| ) | ||
| logger.info( | ||
| "[OpenAI Embedding] Successfully retrieved embedding without dimensions parameter, marking as vLLM" | ||
| ) | ||
| # 标记为vLLM以便后续调用也跳过dimensions | ||
| self._mark_as_vllm() | ||
| return embedding.data[0].embedding | ||
| except Exception as retry_error: | ||
| logger.error( | ||
| f"[OpenAI Embedding] Retry without dimensions also failed: {retry_error}" | ||
| ) | ||
| raise retry_error | ||
| else: | ||
| raise | ||
|
|
||
| async def get_embeddings(self, text: list[str]) -> list[list[float]]: | ||
| """批量获取文本的嵌入""" | ||
| kwargs = self._embedding_kwargs() | ||
| embeddings = await self.client.embeddings.create( | ||
| input=text, | ||
| model=self.model, | ||
| **kwargs, | ||
| ) | ||
| return [item.embedding for item in embeddings.data] | ||
| try: | ||
| embeddings = await self.client.embeddings.create( | ||
| input=text, | ||
| model=self.model, | ||
| **kwargs, | ||
| ) | ||
| return [item.embedding for item in embeddings.data] | ||
| except Exception as e: | ||
| # 如果包含"matryoshka"或"dimensions"相关的错误,说明vLLM不支持该参数 | ||
| # 尝试不带dimensions重试 | ||
| error_msg = str(e).lower() | ||
| if ("matryoshka" in error_msg or "dimensions" in error_msg) and kwargs.get("dimensions"): | ||
| logger.warning( | ||
| f"[OpenAI Embedding] Detected vLLM dimensions error in batch mode, retrying without dimensions: {e}" | ||
| ) | ||
| kwargs_retry = {k: v for k, v in kwargs.items() if k != "dimensions"} | ||
| try: | ||
| embeddings = await self.client.embeddings.create( | ||
| input=text, | ||
| model=self.model, | ||
| **kwargs_retry, | ||
| ) | ||
| logger.info( | ||
| "[OpenAI Embedding] Successfully retrieved batch embeddings without dimensions parameter" | ||
| ) | ||
| # 标记为vLLM以便后续调用也跳过dimensions | ||
| self._mark_as_vllm() | ||
| return [item.embedding for item in embeddings.data] | ||
| except Exception as retry_error: | ||
| logger.error( | ||
| f"[OpenAI Embedding] Batch retry without dimensions also failed: {retry_error}" | ||
| ) | ||
| raise retry_error | ||
| else: | ||
| raise | ||
|
|
||
| def _embedding_kwargs(self) -> dict: | ||
| """构建嵌入请求的可选参数""" | ||
| kwargs = {} | ||
| if "embedding_dimensions" in self.provider_config: | ||
| provider_id = self.provider_config.get("id", "unknown") | ||
| embedding_dim_config = self.provider_config.get("embedding_dimensions", "") | ||
| # 检查是否是vLLM | ||
| is_vllm = self._is_vllm() | ||
| if is_vllm: | ||
| logger.info( | ||
| f"[OpenAI Embedding] {provider_id}: Detected vLLM, skipping dimensions parameter (config value: '{embedding_dim_config}')" | ||
| ) | ||
| return kwargs | ||
| # 非vLLM服务(OpenAI等)支持dimensions,读取配置 | ||
| if embedding_dim_config and embedding_dim_config != "": | ||
| try: | ||
| kwargs["dimensions"] = int(self.provider_config["embedding_dimensions"]) | ||
| dim_value = int(embedding_dim_config) | ||
| kwargs["dimensions"] = dim_value | ||
| logger.info( | ||
| f"[OpenAI Embedding] {provider_id}: Added dimensions parameter: {dim_value}" | ||
| ) | ||
| except (ValueError, TypeError): | ||
| logger.warning( | ||
| f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." | ||
| f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', ignored." | ||
| ) | ||
| else: | ||
| logger.info( | ||
| f"[OpenAI Embedding] {provider_id}: No embedding_dimensions configured, API will use default" | ||
| ) | ||
| return kwargs | ||
|
|
||
| def get_dim(self) -> int: | ||
| """获取向量的维度""" | ||
| if "embedding_dimensions" in self.provider_config: | ||
| provider_id = self.provider_config.get("id", "unknown") | ||
| # 首先尝试从config读取 | ||
| embedding_dim_config = self.provider_config.get("embedding_dimensions", "") | ||
| if embedding_dim_config and embedding_dim_config != "": | ||
| try: | ||
| return int(self.provider_config["embedding_dimensions"]) | ||
| dim = int(embedding_dim_config) | ||
| if dim > 0: | ||
| logger.info( | ||
| f"[OpenAI Embedding] {provider_id}: Dimension from config: {dim}" | ||
| ) | ||
| return dim | ||
| except (ValueError, TypeError): | ||
| logger.warning( | ||
| f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." | ||
| f"[OpenAI Embedding] {provider_id}: embedding_dimensions is not a valid integer: '{embedding_dim_config}', trying model inference" | ||
| ) | ||
| # config为空或无效时根据模型名推断维度 | ||
| # 这样Living Memory可以在自动检测后匹配正确的维度 | ||
| model = self.provider_config.get("embedding_model", "").lower() | ||
| model_dims = { | ||
| "bge-m3": 1024, | ||
| "bge-large-en-v1.5": 1024, | ||
| "bge-large-zh-v1.5": 1024, | ||
| "text-embedding-3-small": 1536, | ||
| "text-embedding-3-large": 3072, | ||
| "text-embedding-ada-002": 1536, | ||
| } | ||
| for model_key, dim in model_dims.items(): | ||
| if model_key in model: | ||
| logger.info( | ||
| f"[OpenAI Embedding] {provider_id}: Inferred dimension {dim} from model: {model}" | ||
| ) | ||
| return dim | ||
| # 无法推断时返回0(Living Memory会检测实际维度) | ||
| logger.warning( | ||
| f"[OpenAI Embedding] {provider_id}: Could not determine dimension (model: {model}, config: '{embedding_dim_config}')" | ||
| ) | ||
| return 0 | ||
|
|
||
| async def terminate(self): | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -104,17 +104,17 @@ function saveEditedContent() { | |
|
|
||
| async function getEmbeddingDimensions(providerConfig) { | ||
| if (loadingEmbeddingDim.value) return | ||
|
|
||
| loadingEmbeddingDim.value = true | ||
| try { | ||
| const response = await axios.post('/api/config/provider/get_embedding_dim', { | ||
| provider_config: providerConfig | ||
| }) | ||
|
|
||
| if (response.data.status != "error" && response.data.data?.embedding_dimensions) { | ||
| console.log(response.data.data.embedding_dimensions) | ||
| providerConfig.embedding_dimensions = response.data.data.embedding_dimensions | ||
| [已禁用] 不再自动写入配置文件,仅显示提示 | ||
|
sourcery-ai[bot] marked this conversation as resolved.
Outdated
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| // providerConfig.embedding_dimensions = response.data.data.embedding_dimensions | ||
| useToast().success("获取成功: " + response.data.data.embedding_dimensions) | ||
| useToast().info(`检测到维度: ${response.data.data.embedding_dimensions}。如需保存,请手动填入后点保存。`) | ||
| } else { | ||
| useToast().error(response.data.message) | ||
| } | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改了下_is_vllm的逻辑,现在在OpenAI Embedding提供商的API key这一栏填入“vllm“才会使得dimensions参数不会发给vllm防止报错。优化了之前靠检测端口号的静态判定逻辑