Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions astrbot/core/core_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None:
self.subagent_orchestrator: SubAgentOrchestrator | None = None
self.cron_manager: CronJobManager | None = None
self.temp_dir_cleaner: TempDirCleaner | None = None
self._default_chat_provider_warning_emitted = False

# 设置代理
proxy_config = self.astrbot_config.get("http_proxy", "")
Expand Down Expand Up @@ -97,6 +98,65 @@ async def _init_or_reload_subagent_orchestrator(self) -> None:
except Exception as e:
logger.error(f"Subagent orchestrator init failed: {e}", exc_info=True)

@staticmethod
def _is_chat_provider_config(
provider_config: dict, provider_sources: dict[str, dict]
) -> bool:
if provider_config.get("provider_type") == "chat_completion":
return True

provider_source_id = provider_config.get("provider_source_id")
if not provider_source_id:
return False

provider_source = provider_sources.get(provider_source_id, {})
if provider_source.get("provider_type") == "chat_completion":
return True

provider_source_type = provider_source.get("type", "")
return isinstance(provider_source_type, str) and (
"chat_completion" in provider_source_type
)

def _warn_about_unset_default_chat_provider(self, config: dict) -> None:
if self._default_chat_provider_warning_emitted:
return

provider_settings = config.get("provider_settings", {})
default_provider_id = provider_settings.get("default_provider_id", "")
if default_provider_id:
return

provider_sources = {
source.get("id"): source
for source in config.get("provider_sources", [])
if isinstance(source, dict) and source.get("id")
}
enabled_chat_provider_ids: list[str] = []
for provider in config.get("provider", []):
if not isinstance(provider, dict):
continue
if not provider.get("enable", True):
continue
if not self._is_chat_provider_config(provider, provider_sources):
continue

provider_id = provider.get("id")
if isinstance(provider_id, str) and provider_id:
enabled_chat_provider_ids.append(provider_id)

if len(enabled_chat_provider_ids) <= 1:
return

self._default_chat_provider_warning_emitted = True
logger.warning(
"Detected %d enabled chat providers but `provider_settings.default_provider_id` is empty. "
"AstrBot will use `%s` as the startup fallback chat provider. "
"Set a default chat model in the WebUI configuration page to avoid unexpected provider switching.",
len(enabled_chat_provider_ids),
enabled_chat_provider_ids[0],
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of _warn_about_unset_default_chat_provider and the helper _is_chat_provider_config duplicate the logic for identifying chat completion providers, which is already handled by ProviderManager. Furthermore, the current logic contains a bug: it only checks for provider_source_id, meaning providers defined with a direct type (e.g., openai_chat_completion) in the provider list without an external source will be incorrectly ignored.

Since this check is performed after self.provider_manager.initialize(), it is much more robust and cleaner to use the state already available in self.provider_manager.provider_insts, which contains exactly the successfully loaded chat completion providers.

    def _warn_about_unset_default_chat_provider(self) -> None:
        if self._default_chat_provider_warning_emitted:
            return

        pm = self.provider_manager
        if pm.provider_settings.get("default_provider_id"):
            return

        # ProviderManager.provider_insts contains only successfully loaded chat completion providers.
        chat_providers = pm.provider_insts
        if len(chat_providers) <= 1:
            return

        self._default_chat_provider_warning_emitted = True
        logger.warning(
            "Detected %d enabled chat providers but `provider_settings.default_provider_id` is empty. "
            "AstrBot will use `%s` as the startup fallback chat provider. "
            "Set a default chat model in the WebUI configuration page to avoid unexpected provider switching.",
            len(chat_providers),
            chat_providers[0].meta().id,
        )

async def initialize(self) -> None:
"""初始化 AstrBot 核心生命周期管理类.

Expand Down Expand Up @@ -202,6 +262,9 @@ async def initialize(self) -> None:

# 根据配置实例化各个 Provider
await self.provider_manager.initialize()
self._warn_about_unset_default_chat_provider(
self.astrbot_config_mgr.default_conf
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the call to _warn_about_unset_default_chat_provider to match the new signature (no arguments needed).

        self._warn_about_unset_default_chat_provider()


await self.kb_manager.initialize()

Expand Down
94 changes: 94 additions & 0 deletions tests/unit/test_core_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,100 @@ async def test_subagent_orchestrator_error_is_logged(
)


class TestAstrBotCoreLifecycleDefaultChatProviderWarning:
"""Tests for startup warning when default chat provider is unset."""

def test_warns_for_multiple_enabled_chat_providers_without_default(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
self, mock_log_broker, mock_db
):
lifecycle = AstrBotCoreLifecycle(mock_log_broker, mock_db)
config = {
"provider_settings": {"default_provider_id": ""},
"provider_sources": [
{"id": "openai_source", "provider_type": "chat_completion"}
],
"provider": [
{
"id": "openai_source/model-a",
"provider_source_id": "openai_source",
"enable": True,
},
{
"id": "agent_runner_provider",
"provider_type": "agent_runner",
"enable": True,
},
{
"id": "openai_source/model-b",
"provider_source_id": "openai_source",
"enable": True,
},
],
}

with patch("astrbot.core.core_lifecycle.logger") as mock_logger:
lifecycle._warn_about_unset_default_chat_provider(config)

mock_logger.warning.assert_called_once()
assert mock_logger.warning.call_args[0][1] == 2
assert mock_logger.warning.call_args[0][2] == "openai_source/model-a"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The unit tests should be updated to mock self.provider_manager.provider_insts and self.provider_manager.provider_settings instead of passing a raw configuration dictionary. This aligns with the refactored logic in core_lifecycle.py and makes the tests more focused on the warning condition rather than the internal configuration parsing logic.

    def test_warns_for_multiple_enabled_chat_providers_without_default(
        self, mock_log_broker, mock_db
    ):
        lifecycle = AstrBotCoreLifecycle(mock_log_broker, mock_db)
        
        # Mock provider manager state
        mock_pm = MagicMock()
        mock_pm.provider_settings = {"default_provider_id": ""}
        
        mock_p1 = MagicMock()
        mock_p1.meta.return_value.id = "openai_source/model-a"
        mock_p2 = MagicMock()
        mock_p2.meta.return_value.id = "openai_source/model-b"
        
        mock_pm.provider_insts = [mock_p1, mock_p2]
        lifecycle.provider_manager = mock_pm

        with patch("astrbot.core.core_lifecycle.logger") as mock_logger:
            lifecycle._warn_about_unset_default_chat_provider()

        mock_logger.warning.assert_called_once()
        assert mock_logger.warning.call_args[0][1] == 2
        assert mock_logger.warning.call_args[0][2] == "openai_source/model-a"


def test_warns_only_once_per_lifecycle(self, mock_log_broker, mock_db):
lifecycle = AstrBotCoreLifecycle(mock_log_broker, mock_db)
config = {
"provider_settings": {"default_provider_id": ""},
"provider_sources": [
{"id": "openai_source", "provider_type": "chat_completion"}
],
"provider": [
{
"id": "openai_source/model-a",
"provider_source_id": "openai_source",
"enable": True,
},
{
"id": "openai_source/model-b",
"provider_source_id": "openai_source",
"enable": True,
},
],
}

with patch("astrbot.core.core_lifecycle.logger") as mock_logger:
lifecycle._warn_about_unset_default_chat_provider(config)
lifecycle._warn_about_unset_default_chat_provider(config)

mock_logger.warning.assert_called_once()

def test_does_not_warn_when_default_chat_provider_is_set(
self, mock_log_broker, mock_db
):
Comment on lines +317 to +324
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add a test where curr_provider_inst is None to verify the fallback to providers[0] is exercised

Currently all tests set curr_provider_inst to a non-None value, so the or providers[0] branch is untested. Please add a case with multiple providers, curr_provider_inst=None, and an empty default_provider_id to confirm a warning is emitted and the first provider’s id is used. This will exercise the fallback path and clarify behavior when curr_provider_inst is unset.

Suggested change
with patch("astrbot.core.core_lifecycle.logger") as mock_logger:
lifecycle._warn_about_unset_default_chat_provider()
mock_logger.warning.assert_not_called()
def test_does_not_warn_when_default_chat_provider_is_set(
self, mock_log_broker, mock_db
):
with patch("astrbot.core.core_lifecycle.logger") as mock_logger:
lifecycle._warn_about_unset_default_chat_provider()
mock_logger.warning.assert_not_called()
def test_warns_and_uses_first_provider_when_curr_provider_unset(
self, mock_log_broker, mock_db
):
lifecycle = AstrBotCoreLifecycle(mock_log_broker, mock_db)
lifecycle.provider_manager = MagicMock(
provider_settings={"default_provider_id": ""},
provider_insts=[
self._make_provider("openai_source/model-a"),
self._make_provider("openai_source/model-b"),
],
curr_provider_inst=None,
)
with patch("astrbot.core.core_lifecycle.logger") as mock_logger:
lifecycle._warn_about_unset_default_chat_provider()
mock_logger.warning.assert_called_once()
args, kwargs = mock_logger.warning.call_args
assert "openai_source/model-a" in args[0]
def test_does_not_warn_when_default_chat_provider_is_set(
self, mock_log_broker, mock_db
):

lifecycle = AstrBotCoreLifecycle(mock_log_broker, mock_db)
config = {
"provider_settings": {"default_provider_id": "openai_source/model-a"},
"provider_sources": [
{"id": "openai_source", "provider_type": "chat_completion"}
],
"provider": [
{
"id": "openai_source/model-a",
"provider_source_id": "openai_source",
"enable": True,
},
{
"id": "openai_source/model-b",
"provider_source_id": "openai_source",
"enable": True,
},
],
}

with patch("astrbot.core.core_lifecycle.logger") as mock_logger:
lifecycle._warn_about_unset_default_chat_provider(config)

mock_logger.warning.assert_not_called()


class TestAstrBotCoreLifecycleInitialize:
"""Tests for AstrBotCoreLifecycle.initialize method."""

Expand Down