Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

from ._group_chat._base_group_chat import BaseGroupChat
from ._group_chat._message_store import ListMessageStore, MessageStore
from ._group_chat._graph import (
DiGraph,
DiGraphBuilder,
Expand All @@ -18,6 +19,8 @@

__all__ = [
"BaseGroupChat",
"MessageStore",
"ListMessageStore",
"RoundRobinGroupChat",
"SelectorGroupChat",
"Swarm",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ...base import TerminationCondition
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, SelectSpeakerEvent, StopMessage
from ._message_store import ListMessageStore, MessageStore
from ._events import (
GroupChatAgentResponse,
GroupChatError,
Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(
max_turns: int | None,
message_factory: MessageFactory,
emit_team_events: bool = False,
message_store: MessageStore | None = None,
):
super().__init__(
description="Group chat manager",
Expand Down Expand Up @@ -74,6 +76,7 @@ def __init__(
name: topic_type for name, topic_type in zip(participant_names, participant_topic_types, strict=True)
}
self._participant_descriptions = participant_descriptions
self._message_store: MessageStore = message_store or ListMessageStore()
self._message_thread: List[BaseAgentEvent | BaseChatMessage] = []
self._output_message_queue = output_message_queue
self._termination_condition = termination_condition
Expand Down Expand Up @@ -299,8 +302,12 @@ async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseCh
"""Update the message thread with the new messages.
This is called when the group chat receives a GroupChatStart or GroupChatAgentResponse event,
before calling the select_speakers method.

Messages are added to both the legacy ``_message_thread`` list (for backward
compatibility with subclasses) and the ``_message_store`` abstraction.
"""
self._message_thread.extend(messages)
await self._message_store.add(messages)

@abstractmethod
async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str] | str:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Message store abstraction for group chat message threads."""

from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import List, Sequence

from ...messages import BaseAgentEvent, BaseChatMessage

ChatMessage = BaseAgentEvent | BaseChatMessage


class MessageStore(ABC):
"""Abstract base class for storing group chat message threads.

Implementations can provide different storage backends (in-memory, database, etc.)
and message retention policies (e.g., TTL-based expiration).
"""

@abstractmethod
async def add(self, messages: Sequence[ChatMessage]) -> None:
"""Add messages to the store.

Args:
messages: A sequence of messages to append to the thread.
"""
...

@abstractmethod
async def get(self) -> List[ChatMessage]:
"""Return all messages currently in the store.

Returns:
A list of messages in chronological order.
"""
...

@abstractmethod
async def clear(self) -> None:
"""Remove all messages from the store."""
...


class ListMessageStore(MessageStore):
"""In-memory message store backed by a Python list.

Args:
ttl: Optional time-to-live for messages. When set, messages older than
``ttl`` are automatically excluded from :meth:`get` results.
"""

def __init__(self, *, ttl: timedelta | None = None) -> None:
self._messages: List[ChatMessage] = []
self._timestamps: List[datetime] = []
self._ttl = ttl

async def add(self, messages: Sequence[ChatMessage]) -> None:
now = datetime.now()
self._messages.extend(messages)
self._timestamps.extend(now for _ in messages)

async def get(self) -> List[ChatMessage]:
if self._ttl is not None:
cutoff = datetime.now() - self._ttl
return [
msg
for msg, ts in zip(self._messages, self._timestamps)
if ts >= cutoff
]
return list(self._messages)

async def clear(self) -> None:
self._messages.clear()
self._timestamps.clear()