diff --git a/agentrun/conversation_service/__ots_backend_async_template.py b/agentrun/conversation_service/__ots_backend_async_template.py index 5bd9a13..6279f3c 100644 --- a/agentrun/conversation_service/__ots_backend_async_template.py +++ b/agentrun/conversation_service/__ots_backend_async_template.py @@ -35,6 +35,10 @@ ) from agentrun.conversation_service.model import ( + CHECKPOINT_BLOBS_SCHEMA_VERSION, + CHECKPOINT_SCHEMA_VERSION, + CHECKPOINT_WRITES_SCHEMA_VERSION, + CONVERSATION_SCHEMA_VERSION, ConversationEvent, ConversationSession, DEFAULT_APP_STATE_TABLE, @@ -48,6 +52,9 @@ DEFAULT_STATE_SEARCH_INDEX, DEFAULT_STATE_TABLE, DEFAULT_USER_STATE_TABLE, + EVENT_SCHEMA_VERSION, + SCHEMA_VERSION_COLUMN, + STATE_SCHEMA_VERSION, StateData, StateScope, ) @@ -113,7 +120,7 @@ def __init__( ) # ----------------------------------------------------------------------- - # 建表(异步)/ Table creation (async) + # 建表 / Table creation # ----------------------------------------------------------------------- async def init_tables_async(self) -> None: @@ -174,6 +181,13 @@ async def init_search_index_async(self) -> None: await self._create_conversation_search_index_async() await self._create_state_search_index_async() + async def init_conversation_search_index_async(self) -> None: + """仅创建 Conversation 多元索引(异步)。 + + 索引已存在时跳过,可重复调用。 + """ + await self._create_conversation_search_index_async() + async def init_checkpoint_tables_async(self) -> None: """创建 LangGraph checkpoint 相关的 3 张表(异步)。 @@ -595,7 +609,7 @@ async def _create_state_search_index_async(self) -> None: raise # ----------------------------------------------------------------------- - # Session CRUD(异步)/ Session CRUD (async) + # Session CRUD # ----------------------------------------------------------------------- async def put_session_async(self, session: ConversationSession) -> None: @@ -607,6 +621,7 @@ async def put_session_async(self, session: ConversationSession) -> None: ] attribute_columns = [ + (SCHEMA_VERSION_COLUMN, CONVERSATION_SCHEMA_VERSION), ("created_at", session.created_at), ("updated_at", session.updated_at), ("is_pinned", session.is_pinned), @@ -946,7 +961,7 @@ async def search_sessions_async( return sessions, search_response.total_count or 0 # ----------------------------------------------------------------------- - # Event CRUD(异步)/ Event CRUD (async) + # Event CRUD # ----------------------------------------------------------------------- async def put_event_async( @@ -991,6 +1006,7 @@ async def put_event_async( content_json = json.dumps(content, ensure_ascii=False) attribute_columns = [ + (SCHEMA_VERSION_COLUMN, EVENT_SCHEMA_VERSION), ("type", event_type), ("content", content_json), ("created_at", created_at), @@ -1171,7 +1187,7 @@ async def delete_events_by_session_async( return deleted # ----------------------------------------------------------------------- - # State CRUD(JSON 字符串存储 + 列分片)(异步) + # State CRUD(JSON 字符串存储 + 列分片) # ----------------------------------------------------------------------- async def put_state_async( @@ -1204,6 +1220,7 @@ async def put_state_async( state_json = serialize_state(state) put_cols: list[tuple[str, Any]] = [ + (SCHEMA_VERSION_COLUMN, STATE_SCHEMA_VERSION), ("updated_at", now), ("version", version + 1), ] @@ -1328,7 +1345,7 @@ async def delete_state_row_async( await self._async_client.delete_row(table_name, row, condition) # ----------------------------------------------------------------------- - # Checkpoint CRUD(LangGraph)(异步) + # Checkpoint CRUD(LangGraph) # ----------------------------------------------------------------------- async def put_checkpoint_async( @@ -1349,6 +1366,7 @@ async def put_checkpoint_async( ("checkpoint_id", checkpoint_id), ] attribute_columns = [ + (SCHEMA_VERSION_COLUMN, CHECKPOINT_SCHEMA_VERSION), ("checkpoint_type", checkpoint_type), ("checkpoint_data", checkpoint_data), ("metadata", metadata_json), @@ -1502,6 +1520,7 @@ async def put_checkpoint_writes_async( ("task_idx", w["task_idx"]), ] attrs = [ + (SCHEMA_VERSION_COLUMN, CHECKPOINT_WRITES_SCHEMA_VERSION), ("task_id", w["task_id"]), ("task_path", w.get("task_path", "")), ("channel", w["channel"]), @@ -1580,6 +1599,7 @@ async def put_checkpoint_blob_async( ("version", version), ] attribute_columns = [ + (SCHEMA_VERSION_COLUMN, CHECKPOINT_BLOBS_SCHEMA_VERSION), ("blob_type", blob_type), ("blob_data", blob_data), ] @@ -1747,7 +1767,7 @@ async def _scan_and_delete_async( await self._async_client.batch_write_row(request) # ----------------------------------------------------------------------- - # 内部辅助方法(I/O 相关,异步) + # 内部辅助方法(I/O 相关) # ----------------------------------------------------------------------- async def _get_chunk_count_async( diff --git a/agentrun/conversation_service/__session_store_async_template.py b/agentrun/conversation_service/__session_store_async_template.py index e8f91d4..667c38c 100644 --- a/agentrun/conversation_service/__session_store_async_template.py +++ b/agentrun/conversation_service/__session_store_async_template.py @@ -71,7 +71,7 @@ async def init_langchain_tables_async(self) -> None: 表或索引已存在时跳过,可重复调用。 """ await self._backend.init_core_tables_async() - await self._backend.init_search_index_async() + await self._backend.init_conversation_search_index_async() async def init_langgraph_tables_async(self) -> None: """创建 LangGraph 所需的全部表和索引(异步)。 @@ -81,7 +81,7 @@ async def init_langgraph_tables_async(self) -> None: 表或索引已存在时跳过,可重复调用。 """ await self._backend.init_core_tables_async() - await self._backend.init_search_index_async() + await self._backend.init_conversation_search_index_async() await self._backend.init_checkpoint_tables_async() async def init_adk_tables_async(self) -> None: @@ -96,7 +96,7 @@ async def init_adk_tables_async(self) -> None: await self._backend.init_search_index_async() # ------------------------------------------------------------------- - # Checkpoint 管理(LangGraph)(异步) + # Checkpoint 管理(LangGraph) # ------------------------------------------------------------------- async def put_checkpoint_async( @@ -210,7 +210,7 @@ async def delete_thread_checkpoints_async( await self._backend.delete_thread_checkpoints_async(thread_id) # ------------------------------------------------------------------- - # Session 管理(异步)/ Session management (async) + # Session 管理 / Session management # ------------------------------------------------------------------- async def create_session_async( @@ -496,7 +496,7 @@ async def update_session_async( ) # ------------------------------------------------------------------- - # Event 管理(异步)/ Event management (async) + # Event 管理 / Event management # ------------------------------------------------------------------- async def append_event_async( @@ -631,7 +631,7 @@ async def get_recent_events_async( return events # ------------------------------------------------------------------- - # State 管理(异步)/ State management (async) + # State 管理 / State management # ------------------------------------------------------------------- async def get_session_state_async( @@ -746,7 +746,7 @@ async def get_merged_state_async( return merged # ------------------------------------------------------------------- - # 内部辅助方法(异步) + # 内部辅助方法 # ------------------------------------------------------------------- async def _apply_delta_async( @@ -802,7 +802,7 @@ async def _apply_delta_async( ) # ------------------------------------------------------------------- - # 工厂方法(异步)/ Factory methods (async) + # 工厂方法 / Factory methods # ------------------------------------------------------------------- @classmethod diff --git a/agentrun/conversation_service/model.py b/agentrun/conversation_service/model.py index c7b344e..89b43c0 100644 --- a/agentrun/conversation_service/model.py +++ b/agentrun/conversation_service/model.py @@ -28,6 +28,41 @@ DEFAULT_CHECKPOINT_WRITES_TABLE = "checkpoint_writes" DEFAULT_CHECKPOINT_BLOBS_TABLE = "checkpoint_blobs" +# --------------------------------------------------------------------------- +# OTS Schema 版本管理 +# +# 用于 SDK 写入端与 Core 读取端(funagent-core)的兼容性协调。 +# 每次写入行(PutRow / UpdateRow / BatchWriteRow)时在 +# attribute_columns 中携带 _schema_version 字段。 +# Core 端读取时检查该字段,版本不匹配时打 WARN 日志并尽力解析。 +# 历史数据(无此字段)视为 v0。 +# +# 版本计数规则: +# - 大部分表独立计数 +# - state / app_state / user_state 三张表共享 STATE_SCHEMA_VERSION +# +# 升级流程: +# 1. 递增对应表的 *_SCHEMA_VERSION 常量 +# 2. 在 PR 描述中记录变更的列名/类型/语义 +# 3. 通知 funagent-core 侧同步更新解析逻辑和版本常量 +# 4. 如涉及 breaking change,提供数据迁移指引 +# +# 兼容性规则: +# - 只加不删:新增列允许,删除/重命名列视为 breaking change +# - PK 不可变:主键结构永不改变 +# - 索引名不可变:Search Index 名称一旦确定不再修改 +# - 语义不可变:已有列的类型和含义不改变 +# --------------------------------------------------------------------------- + +SCHEMA_VERSION_COLUMN = "_schema_version" + +CONVERSATION_SCHEMA_VERSION = 1 +EVENT_SCHEMA_VERSION = 1 +STATE_SCHEMA_VERSION = 1 # state / app_state / user_state 共享 +CHECKPOINT_SCHEMA_VERSION = 1 +CHECKPOINT_WRITES_SCHEMA_VERSION = 1 +CHECKPOINT_BLOBS_SCHEMA_VERSION = 1 + # --------------------------------------------------------------------------- # 枚举 diff --git a/agentrun/conversation_service/ots_backend.py b/agentrun/conversation_service/ots_backend.py index d79f7c5..e69d7cc 100644 --- a/agentrun/conversation_service/ots_backend.py +++ b/agentrun/conversation_service/ots_backend.py @@ -45,6 +45,10 @@ ) from agentrun.conversation_service.model import ( + CHECKPOINT_BLOBS_SCHEMA_VERSION, + CHECKPOINT_SCHEMA_VERSION, + CHECKPOINT_WRITES_SCHEMA_VERSION, + CONVERSATION_SCHEMA_VERSION, ConversationEvent, ConversationSession, DEFAULT_APP_STATE_TABLE, @@ -58,6 +62,9 @@ DEFAULT_STATE_SEARCH_INDEX, DEFAULT_STATE_TABLE, DEFAULT_USER_STATE_TABLE, + EVENT_SCHEMA_VERSION, + SCHEMA_VERSION_COLUMN, + STATE_SCHEMA_VERSION, StateData, StateScope, ) @@ -123,7 +130,7 @@ def __init__( ) # ----------------------------------------------------------------------- - # 建表(异步)/ Table creation (async) + # 建表 / Table creation # ----------------------------------------------------------------------- async def init_tables_async(self) -> None: @@ -242,6 +249,20 @@ def init_search_index(self) -> None: self._create_conversation_search_index() self._create_state_search_index() + async def init_conversation_search_index_async(self) -> None: + """仅创建 Conversation 多元索引(异步)。 + + 索引已存在时跳过,可重复调用。 + """ + await self._create_conversation_search_index_async() + + def init_conversation_search_index(self) -> None: + """仅创建 Conversation 多元索引(同步)。 + + 索引已存在时跳过,可重复调用。 + """ + self._create_conversation_search_index() + async def init_checkpoint_tables_async(self) -> None: """创建 LangGraph checkpoint 相关的 3 张表(异步)。 @@ -991,26 +1012,18 @@ async def _create_state_search_index_async(self) -> None: self._state_table, ) except OTSServiceError as e: - err_str = str(e).lower() - if "already exist" in err_str or ( + if "already exist" in str(e).lower() or ( hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" ): logger.warning( "Search index %s already exists, skipping.", self._state_search_index, ) - elif "does not exist" in err_str and "table" in err_str: - logger.warning( - "Table %s does not exist, skipping search index creation" - " for %s.", - self._state_table, - self._state_search_index, - ) else: raise # ----------------------------------------------------------------------- - # Session CRUD(异步)/ Session CRUD (async) + # Session CRUD # ----------------------------------------------------------------------- def _create_state_search_index(self) -> None: @@ -1084,26 +1097,18 @@ def _create_state_search_index(self) -> None: self._state_table, ) except OTSServiceError as e: - err_str = str(e).lower() - if "already exist" in err_str or ( + if "already exist" in str(e).lower() or ( hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" ): logger.warning( "Search index %s already exists, skipping.", self._state_search_index, ) - elif "does not exist" in err_str and "table" in err_str: - logger.warning( - "Table %s does not exist, skipping search index creation" - " for %s.", - self._state_table, - self._state_search_index, - ) else: raise # ----------------------------------------------------------------------- - # Session CRUD(同步)/ Session CRUD (async) + # Session CRUD # ----------------------------------------------------------------------- async def put_session_async(self, session: ConversationSession) -> None: @@ -1115,6 +1120,7 @@ async def put_session_async(self, session: ConversationSession) -> None: ] attribute_columns = [ + (SCHEMA_VERSION_COLUMN, CONVERSATION_SCHEMA_VERSION), ("created_at", session.created_at), ("updated_at", session.updated_at), ("is_pinned", session.is_pinned), @@ -1148,6 +1154,7 @@ def put_session(self, session: ConversationSession) -> None: ] attribute_columns = [ + (SCHEMA_VERSION_COLUMN, CONVERSATION_SCHEMA_VERSION), ("created_at", session.created_at), ("updated_at", session.updated_at), ("is_pinned", session.is_pinned), @@ -1690,7 +1697,7 @@ async def search_sessions_async( return sessions, search_response.total_count or 0 # ----------------------------------------------------------------------- - # Event CRUD(异步)/ Event CRUD (async) + # Event CRUD # ----------------------------------------------------------------------- def search_sessions( @@ -1799,7 +1806,7 @@ def search_sessions( return sessions, search_response.total_count or 0 # ----------------------------------------------------------------------- - # Event CRUD(同步)/ Event CRUD (async) + # Event CRUD # ----------------------------------------------------------------------- async def put_event_async( @@ -1844,6 +1851,7 @@ async def put_event_async( content_json = json.dumps(content, ensure_ascii=False) attribute_columns = [ + (SCHEMA_VERSION_COLUMN, EVENT_SCHEMA_VERSION), ("type", event_type), ("content", content_json), ("created_at", created_at), @@ -1918,6 +1926,7 @@ def put_event( content_json = json.dumps(content, ensure_ascii=False) attribute_columns = [ + (SCHEMA_VERSION_COLUMN, EVENT_SCHEMA_VERSION), ("type", event_type), ("content", content_json), ("created_at", created_at), @@ -2174,7 +2183,7 @@ async def delete_events_by_session_async( return deleted # ----------------------------------------------------------------------- - # State CRUD(JSON 字符串存储 + 列分片)(异步) + # State CRUD(JSON 字符串存储 + 列分片) # ----------------------------------------------------------------------- def delete_events_by_session( @@ -2249,7 +2258,7 @@ def delete_events_by_session( return deleted # ----------------------------------------------------------------------- - # State CRUD(JSON 字符串存储 + 列分片)(同步) + # State CRUD(JSON 字符串存储 + 列分片) # ----------------------------------------------------------------------- async def put_state_async( @@ -2282,6 +2291,7 @@ async def put_state_async( state_json = serialize_state(state) put_cols: list[tuple[str, Any]] = [ + (SCHEMA_VERSION_COLUMN, STATE_SCHEMA_VERSION), ("updated_at", now), ("version", version + 1), ] @@ -2374,6 +2384,7 @@ def put_state( state_json = serialize_state(state) put_cols: list[tuple[str, Any]] = [ + (SCHEMA_VERSION_COLUMN, STATE_SCHEMA_VERSION), ("updated_at", now), ("version", version + 1), ] @@ -2542,7 +2553,7 @@ async def delete_state_row_async( await self._async_client.delete_row(table_name, row, condition) # ----------------------------------------------------------------------- - # State CRUD(同步) + # Checkpoint CRUD(LangGraph) # ----------------------------------------------------------------------- def delete_state_row( @@ -2561,7 +2572,7 @@ def delete_state_row( self._client.delete_row(table_name, row, condition) # ----------------------------------------------------------------------- - # Checkpoint CRUD(LangGraph)(异步) + # Checkpoint CRUD(LangGraph) # ----------------------------------------------------------------------- async def put_checkpoint_async( @@ -2582,6 +2593,7 @@ async def put_checkpoint_async( ("checkpoint_id", checkpoint_id), ] attribute_columns = [ + (SCHEMA_VERSION_COLUMN, CHECKPOINT_SCHEMA_VERSION), ("checkpoint_type", checkpoint_type), ("checkpoint_data", checkpoint_data), ("metadata", metadata_json), @@ -2591,10 +2603,6 @@ async def put_checkpoint_async( condition = Condition(RowExistenceExpectation.IGNORE) await self._async_client.put_row(self._checkpoint_table, row, condition) - # ----------------------------------------------------------------------- - # Checkpoint CRUD(LangGraph)(同步) - # ----------------------------------------------------------------------- - def put_checkpoint( self, thread_id: str, @@ -2613,6 +2621,7 @@ def put_checkpoint( ("checkpoint_id", checkpoint_id), ] attribute_columns = [ + (SCHEMA_VERSION_COLUMN, CHECKPOINT_SCHEMA_VERSION), ("checkpoint_type", checkpoint_type), ("checkpoint_data", checkpoint_data), ("metadata", metadata_json), @@ -2881,6 +2890,7 @@ async def put_checkpoint_writes_async( ("task_idx", w["task_idx"]), ] attrs = [ + (SCHEMA_VERSION_COLUMN, CHECKPOINT_WRITES_SCHEMA_VERSION), ("task_id", w["task_id"]), ("task_path", w.get("task_path", "")), ("channel", w["channel"]), @@ -2928,6 +2938,7 @@ def put_checkpoint_writes( ("task_idx", w["task_idx"]), ] attrs = [ + (SCHEMA_VERSION_COLUMN, CHECKPOINT_WRITES_SCHEMA_VERSION), ("task_id", w["task_id"]), ("task_path", w.get("task_path", "")), ("channel", w["channel"]), @@ -3048,6 +3059,7 @@ async def put_checkpoint_blob_async( ("version", version), ] attribute_columns = [ + (SCHEMA_VERSION_COLUMN, CHECKPOINT_BLOBS_SCHEMA_VERSION), ("blob_type", blob_type), ("blob_data", blob_data), ] @@ -3075,6 +3087,7 @@ def put_checkpoint_blob( ("version", version), ] attribute_columns = [ + (SCHEMA_VERSION_COLUMN, CHECKPOINT_BLOBS_SCHEMA_VERSION), ("blob_type", blob_type), ("blob_data", blob_data), ] @@ -3357,7 +3370,7 @@ async def _scan_and_delete_async( await self._async_client.batch_write_row(request) # ----------------------------------------------------------------------- - # 内部辅助方法(I/O 相关,异步) + # 内部辅助方法(I/O 相关) # ----------------------------------------------------------------------- def _scan_and_delete( @@ -3401,7 +3414,7 @@ def _scan_and_delete( self._client.batch_write_row(request) # ----------------------------------------------------------------------- - # 内部辅助方法(I/O 相关,同步) + # 内部辅助方法(I/O 相关) # ----------------------------------------------------------------------- async def _get_chunk_count_async( diff --git a/agentrun/conversation_service/session_store.py b/agentrun/conversation_service/session_store.py index 9d3f567..2360c02 100644 --- a/agentrun/conversation_service/session_store.py +++ b/agentrun/conversation_service/session_store.py @@ -111,7 +111,7 @@ async def init_langchain_tables_async(self) -> None: 表或索引已存在时跳过,可重复调用。 """ await self._backend.init_core_tables_async() - await self._backend.init_search_index_async() + await self._backend.init_conversation_search_index_async() def init_langchain_tables(self) -> None: """创建 LangChain 所需的全部表和索引(同步)。 @@ -120,7 +120,7 @@ def init_langchain_tables(self) -> None: 表或索引已存在时跳过,可重复调用。 """ self._backend.init_core_tables() - self._backend.init_search_index() + self._backend.init_conversation_search_index() async def init_langgraph_tables_async(self) -> None: """创建 LangGraph 所需的全部表和索引(异步)。 @@ -130,7 +130,7 @@ async def init_langgraph_tables_async(self) -> None: 表或索引已存在时跳过,可重复调用。 """ await self._backend.init_core_tables_async() - await self._backend.init_search_index_async() + await self._backend.init_conversation_search_index_async() await self._backend.init_checkpoint_tables_async() def init_langgraph_tables(self) -> None: @@ -141,7 +141,7 @@ def init_langgraph_tables(self) -> None: 表或索引已存在时跳过,可重复调用。 """ self._backend.init_core_tables() - self._backend.init_search_index() + self._backend.init_conversation_search_index() self._backend.init_checkpoint_tables() async def init_adk_tables_async(self) -> None: @@ -156,7 +156,7 @@ async def init_adk_tables_async(self) -> None: await self._backend.init_search_index_async() # ------------------------------------------------------------------- - # Checkpoint 管理(LangGraph)(异步) + # Checkpoint 管理(LangGraph) # ------------------------------------------------------------------- def init_adk_tables(self) -> None: @@ -171,7 +171,7 @@ def init_adk_tables(self) -> None: self._backend.init_search_index() # ------------------------------------------------------------------- - # Checkpoint 管理(LangGraph)(同步) + # Checkpoint 管理(LangGraph) # ------------------------------------------------------------------- async def put_checkpoint_async( @@ -388,7 +388,7 @@ async def delete_thread_checkpoints_async( await self._backend.delete_thread_checkpoints_async(thread_id) # ------------------------------------------------------------------- - # Session 管理(异步)/ Session management (async) + # Session 管理 / Session management # ------------------------------------------------------------------- def delete_thread_checkpoints( @@ -399,7 +399,7 @@ def delete_thread_checkpoints( self._backend.delete_thread_checkpoints(thread_id) # ------------------------------------------------------------------- - # Session 管理(同步)/ Session management (async) + # Session 管理 / Session management # ------------------------------------------------------------------- async def create_session_async( @@ -909,7 +909,7 @@ async def update_session_async( ) # ------------------------------------------------------------------- - # Event 管理(异步)/ Event management (async) + # Event 管理 / Event management # ------------------------------------------------------------------- def update_session( @@ -965,7 +965,7 @@ def update_session( ) # ------------------------------------------------------------------- - # Event 管理(同步)/ Event management (async) + # Event 管理 / Event management # ------------------------------------------------------------------- async def append_event_async( @@ -1199,7 +1199,7 @@ async def get_recent_events_async( return events # ------------------------------------------------------------------- - # State 管理(异步)/ State management (async) + # State 管理 / State management # ------------------------------------------------------------------- def get_recent_events( @@ -1233,7 +1233,7 @@ def get_recent_events( return events # ------------------------------------------------------------------- - # State 管理(同步)/ State management (async) + # State 管理 / State management # ------------------------------------------------------------------- async def get_session_state_async( @@ -1429,7 +1429,7 @@ async def get_merged_state_async( return merged # ------------------------------------------------------------------- - # 内部辅助方法(异步) + # 内部辅助方法 # ------------------------------------------------------------------- def get_merged_state( @@ -1457,7 +1457,7 @@ def get_merged_state( return merged # ------------------------------------------------------------------- - # 内部辅助方法(同步) + # 内部辅助方法 # ------------------------------------------------------------------- async def _apply_delta_async( @@ -1513,7 +1513,7 @@ async def _apply_delta_async( ) # ------------------------------------------------------------------- - # 工厂方法(异步)/ Factory methods (async) + # 工厂方法 / Factory methods # ------------------------------------------------------------------- def _apply_delta( @@ -1567,7 +1567,7 @@ def _apply_delta( ) # ------------------------------------------------------------------- - # 工厂方法(同步)/ Factory methods (async) + # 工厂方法 / Factory methods # ------------------------------------------------------------------- @classmethod diff --git a/tests/unittests/conversation_service/test_ots_backend.py b/tests/unittests/conversation_service/test_ots_backend.py index 3b22526..4701773 100644 --- a/tests/unittests/conversation_service/test_ots_backend.py +++ b/tests/unittests/conversation_service/test_ots_backend.py @@ -18,8 +18,15 @@ from tablestore import OTSServiceError, Row # type: ignore[import-untyped] from agentrun.conversation_service.model import ( + CHECKPOINT_BLOBS_SCHEMA_VERSION, + CHECKPOINT_SCHEMA_VERSION, + CHECKPOINT_WRITES_SCHEMA_VERSION, + CONVERSATION_SCHEMA_VERSION, ConversationEvent, ConversationSession, + EVENT_SCHEMA_VERSION, + SCHEMA_VERSION_COLUMN, + STATE_SCHEMA_VERSION, StateData, StateScope, ) @@ -29,6 +36,20 @@ ) from agentrun.conversation_service.utils import MAX_COLUMN_SIZE + +def _extract_attr_columns_dict( + row_arg: Row, +) -> dict[str, Any]: + """Extract attribute columns from a Row arg into a dict for easy assertion.""" + cols = row_arg.attribute_columns + if isinstance(cols, dict): + # UpdateRow format: {"PUT": [...], ...} + put_list = cols.get("PUT", []) + return {name: val for name, val in put_list} + # PutRow format: [(name, value), ...] + return {col[0]: col[1] for col in cols} + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -1970,3 +1991,179 @@ async def test_search_sessions_with_row_objects(self) -> None: sessions, total = await backend.search_sessions_async("agent1") assert len(sessions) == 1 + + +class TestSchemaVersionAsync: + """验证所有异步 put_* 方法在写入时携带 _schema_version 列。""" + + @pytest.mark.asyncio + async def test_put_session_has_schema_version(self) -> None: + backend = _make_async_backend() + session = ConversationSession("a", "u", "s", 100, 200) + await backend.put_session_async(session) + + call_args = backend._async_client.put_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == CONVERSATION_SCHEMA_VERSION + + @pytest.mark.asyncio + async def test_put_event_has_schema_version(self) -> None: + backend = _make_async_backend() + await backend.put_event_async("a", "u", "s", "msg", {"text": "hi"}) + + call_args = backend._async_client.put_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == EVENT_SCHEMA_VERSION + + @pytest.mark.asyncio + async def test_put_state_has_schema_version(self) -> None: + backend = _make_async_backend() + await backend.put_state_async( + StateScope.SESSION, "a", "u", "s", {"key": "val"}, 0 + ) + + call_args = backend._async_client.update_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == STATE_SCHEMA_VERSION + + @pytest.mark.asyncio + async def test_put_checkpoint_has_schema_version(self) -> None: + backend = _make_async_backend() + await backend.put_checkpoint_async( + "t1", + "ns1", + "c1", + checkpoint_type="json", + checkpoint_data="{}", + metadata_json="{}", + ) + + call_args = backend._async_client.put_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == CHECKPOINT_SCHEMA_VERSION + + @pytest.mark.asyncio + async def test_put_checkpoint_writes_has_schema_version(self) -> None: + backend = _make_async_backend() + backend._async_client.batch_write_row = AsyncMock() + writes = [{ + "task_idx": "0", + "task_id": "t", + "channel": "c", + "value_type": "json", + "value_data": "{}", + }] + await backend.put_checkpoint_writes_async("t1", "ns1", "c1", writes) + + call_args = backend._async_client.batch_write_row.call_args + request = call_args[0][0] + table_item = list(request.items.values())[0] + row_arg = table_item.row_items[0].row + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == CHECKPOINT_WRITES_SCHEMA_VERSION + + @pytest.mark.asyncio + async def test_put_checkpoint_blob_has_schema_version(self) -> None: + backend = _make_async_backend() + await backend.put_checkpoint_blob_async( + "t1", + "ns1", + "ch1", + "v1", + blob_type="json", + blob_data="{}", + ) + + call_args = backend._async_client.put_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == CHECKPOINT_BLOBS_SCHEMA_VERSION + + +class TestSchemaVersionSync: + """验证所有同步 put_* 方法在写入时携带 _schema_version 列。""" + + def test_put_session_has_schema_version(self) -> None: + backend = _make_backend() + session = ConversationSession("a", "u", "s", 100, 200) + backend.put_session(session) + + call_args = backend._client.put_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == CONVERSATION_SCHEMA_VERSION + + def test_put_event_has_schema_version(self) -> None: + backend = _make_backend() + return_row = MagicMock() + return_row.primary_key = [("seq_id", 1)] + backend._client.put_row.return_value = (None, return_row) + backend.put_event("a", "u", "s", "msg", {"text": "hi"}) + + call_args = backend._client.put_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == EVENT_SCHEMA_VERSION + + def test_put_state_has_schema_version(self) -> None: + backend = _make_backend() + backend.put_state(StateScope.SESSION, "a", "u", "s", {"key": "val"}, 0) + + call_args = backend._client.update_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == STATE_SCHEMA_VERSION + + def test_put_checkpoint_has_schema_version(self) -> None: + backend = _make_backend() + backend.put_checkpoint( + "t1", + "ns1", + "c1", + checkpoint_type="json", + checkpoint_data="{}", + metadata_json="{}", + ) + + call_args = backend._client.put_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == CHECKPOINT_SCHEMA_VERSION + + def test_put_checkpoint_writes_has_schema_version(self) -> None: + backend = _make_backend() + writes = [{ + "task_idx": "0", + "task_id": "t", + "channel": "c", + "value_type": "json", + "value_data": "{}", + }] + backend.put_checkpoint_writes("t1", "ns1", "c1", writes) + + call_args = backend._client.batch_write_row.call_args + request = call_args[0][0] + table_item = list(request.items.values())[0] + row_arg = table_item.row_items[0].row + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == CHECKPOINT_WRITES_SCHEMA_VERSION + + def test_put_checkpoint_blob_has_schema_version(self) -> None: + backend = _make_backend() + backend.put_checkpoint_blob( + "t1", + "ns1", + "ch1", + "v1", + blob_type="json", + blob_data="{}", + ) + + call_args = backend._client.put_row.call_args + row_arg = call_args[0][1] + attrs = _extract_attr_columns_dict(row_arg) + assert attrs[SCHEMA_VERSION_COLUMN] == CHECKPOINT_BLOBS_SCHEMA_VERSION