diff --git a/packages/api/alembic/versions/092e2aa153ce_add_chat_messages_langfuse_trace_id.py b/packages/api/alembic/versions/092e2aa153ce_add_chat_messages_langfuse_trace_id.py new file mode 100644 index 0000000..d14b8d0 --- /dev/null +++ b/packages/api/alembic/versions/092e2aa153ce_add_chat_messages_langfuse_trace_id.py @@ -0,0 +1,36 @@ +"""add chat_messages.langfuse_trace_id + +Revision ID: 092e2aa153ce +Revises: a15a753f44c8 +Create Date: 2026-05-21 09:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = '092e2aa153ce' +down_revision: Union[str, Sequence[str], None] = 'a15a753f44c8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + with op.batch_alter_table('chat_messages') as batch_op: + batch_op.add_column( + sa.Column('langfuse_trace_id', sa.String(), nullable=True) + ) + batch_op.create_index( + 'ix_chat_messages_langfuse_trace_id', + ['langfuse_trace_id'], + ) + + +def downgrade() -> None: + """Downgrade schema.""" + with op.batch_alter_table('chat_messages') as batch_op: + batch_op.drop_index('ix_chat_messages_langfuse_trace_id') + batch_op.drop_column('langfuse_trace_id') diff --git a/packages/api/src/cell_explorer_api/db/models.py b/packages/api/src/cell_explorer_api/db/models.py index f3c09b2..2e31a51 100644 --- a/packages/api/src/cell_explorer_api/db/models.py +++ b/packages/api/src/cell_explorer_api/db/models.py @@ -101,6 +101,9 @@ class ChatMessageRow(SQLModel, table=True): role: str # "user" | "assistant" content: str created_at: datetime = Field(default_factory=_utcnow) + # Set on assistant rows when Langfuse tracing is enabled. Used by the + # feedback PUT route to forward thumbs ratings to Langfuse Scores. + langfuse_trace_id: str | None = Field(default=None, index=True) class ChatFeedback(SQLModel, table=True): diff --git a/packages/api/src/cell_explorer_api/routes/chat.py b/packages/api/src/cell_explorer_api/routes/chat.py index e5c3c4c..28452a4 100644 --- a/packages/api/src/cell_explorer_api/routes/chat.py +++ b/packages/api/src/cell_explorer_api/routes/chat.py @@ -1,6 +1,7 @@ """HTTP routes for the chat agent (Plan 2a).""" import asyncio +import logging import uuid as _uuid from collections.abc import AsyncIterator from datetime import datetime @@ -13,8 +14,10 @@ from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +from cell_explorer_agent.config import AgentConfig from cell_explorer_agent.events import Error, ThreadOpen from cell_explorer_agent.messages import AssistantMessage, UserMessage +from cell_explorer_agent.telemetry import langfuse_client from cell_explorer_api.auth.dependencies import require_auth from cell_explorer_api.auth.models import User from cell_explorer_api.auth.optional import optional_auth @@ -48,6 +51,8 @@ load_thread, ) +logger = logging.getLogger(__name__) + router = APIRouter(tags=["chat"]) # Note: this router is included by routes/__init__.py whose own router has @@ -55,6 +60,39 @@ # matching the convention in datasets.py. +def _forward_feedback_to_langfuse( + *, + trace_id: str, + user_sub: str, + message_id: _uuid.UUID, + rating: str, + comment: str | None, +) -> None: + """Best-effort forward of a thumbs rating to Langfuse Scores. + + Idempotent on (user_sub, message_id) via deterministic score_id so + re-clicking updates the score in place (assuming Langfuse upserts; + duplicate rows are at worst dedupable later via score_id prefix). + + Any failure is logged and swallowed — feedback writes to our DB are + the source of truth; Langfuse is an observability sink. + """ + try: + client = langfuse_client.get(AgentConfig()) + if client is None: + return + client.create_score( + name="user_feedback", + value=1.0 if rating == "up" else 0.0, + data_type="NUMERIC", + trace_id=trace_id, + score_id=f"feedback-{user_sub}-{message_id}", + comment=comment, + ) + except Exception: + logger.exception("Failed to forward feedback to Langfuse") + + class ChatMessage(BaseModel): role: Literal["user", "assistant"] content: str = Field(min_length=1) @@ -183,6 +221,7 @@ async def _ndjson_event_stream( msg = await append_message( stream_db, thread, role="assistant", content="".join(assistant_buffer), + langfuse_trace_id=event.trace_id, ) # Capture id before commit() expires ORM attributes. message_id = str(msg.id) @@ -534,6 +573,8 @@ async def put_message_feedback( status_code=422, detail="Only assistant messages can be rated" ) + # Capture trace_id before commit() expires ORM attributes. + langfuse_trace_id = msg.langfuse_trace_id # Upsert by (message_id, user_sub). existing = ( await db.exec( @@ -559,6 +600,15 @@ async def put_message_feedback( fb = existing await db.commit() await db.refresh(fb) + # Forward to Langfuse Scores when we know the trace id. Best-effort. + if langfuse_trace_id: + _forward_feedback_to_langfuse( + trace_id=langfuse_trace_id, + user_sub=user.sub, + message_id=message_id, + rating=fb.rating, + comment=fb.comment, + ) return FeedbackResponse( rating=fb.rating, # type: ignore[arg-type] comment=fb.comment, diff --git a/packages/api/src/cell_explorer_api/services/threads.py b/packages/api/src/cell_explorer_api/services/threads.py index 57890fc..8a51997 100644 --- a/packages/api/src/cell_explorer_api/services/threads.py +++ b/packages/api/src/cell_explorer_api/services/threads.py @@ -144,9 +144,15 @@ async def append_message( *, role: str, content: str, + langfuse_trace_id: str | None = None, ) -> ChatMessageRow: """Append a message to the thread and bump its updated_at.""" - msg = ChatMessageRow(thread_id=thread.id, role=role, content=content) + msg = ChatMessageRow( + thread_id=thread.id, + role=role, + content=content, + langfuse_trace_id=langfuse_trace_id, + ) session.add(msg) thread.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) session.add(thread) diff --git a/packages/api/tests/routes/test_chat_feedback.py b/packages/api/tests/routes/test_chat_feedback.py index d3cb9f2..aead3e5 100644 --- a/packages/api/tests/routes/test_chat_feedback.py +++ b/packages/api/tests/routes/test_chat_feedback.py @@ -284,3 +284,152 @@ def test_thread_detail_includes_message_id(seeded_app): res = client.get(f"/api/chat/public-atlas/threads/{thread_id}") assert res.status_code == 200 assert res.json()["messages"][0]["id"] == msg_id + + +def _create_assistant_msg_with_trace_id(seeded_app, *, user_sub: str, trace_id: str) -> str: + """Seed one assistant ChatMessageRow with langfuse_trace_id set.""" + from cell_explorer_api.db.models import ChatMessageRow, ChatThread, Dataset + from sqlmodel import select + + async def _seed() -> str: + engine = seeded_app.state.db_engine + async with AsyncSession(engine) as session: + dataset = (await session.exec(select(Dataset))).first() + t = ChatThread(user_sub=user_sub, dataset_id=dataset.id, title="t") + session.add(t) + await session.flush() + m = ChatMessageRow( + thread_id=t.id, + role="assistant", + content="hello", + langfuse_trace_id=trace_id, + ) + session.add(m) + await session.commit() + await session.refresh(m) + return str(m.id) + + return asyncio.run(_seed()) + + +def test_put_feedback_forwards_to_langfuse_when_trace_id_present( + seeded_app, monkeypatch +): + """PUT 👍 on a message with langfuse_trace_id calls Langfuse create_score.""" + from cell_explorer_agent.telemetry import langfuse_client + from cell_explorer_agent.telemetry.fake import FakeLangfuseClient + + fake = FakeLangfuseClient() + monkeypatch.setattr(langfuse_client, "get", lambda _config: fake) + + client = TestClient(seeded_app) + _set_auth_cookie(client, seeded_app, sub="user-1") + msg_id = _create_assistant_msg_with_trace_id( + seeded_app, user_sub="user-1", trace_id="trace-xyz", + ) + + res = client.put( + f"/api/chat/public-atlas/messages/{msg_id}/feedback", + json={"rating": "up", "comment": "great answer"}, + ) + assert res.status_code == 200, res.text + + assert len(fake.scores) == 1 + s = fake.scores[0] + assert s["name"] == "user_feedback" + assert s["value"] == 1.0 + assert s["data_type"] == "NUMERIC" + assert s["trace_id"] == "trace-xyz" + assert s["score_id"] == f"feedback-user-1-{msg_id}" + assert s["comment"] == "great answer" + + +def test_put_feedback_forwards_down_as_zero(seeded_app, monkeypatch): + """👎 maps to value=0.0.""" + from cell_explorer_agent.telemetry import langfuse_client + from cell_explorer_agent.telemetry.fake import FakeLangfuseClient + + fake = FakeLangfuseClient() + monkeypatch.setattr(langfuse_client, "get", lambda _config: fake) + + client = TestClient(seeded_app) + _set_auth_cookie(client, seeded_app, sub="user-1") + msg_id = _create_assistant_msg_with_trace_id( + seeded_app, user_sub="user-1", trace_id="trace-xyz", + ) + + res = client.put( + f"/api/chat/public-atlas/messages/{msg_id}/feedback", + json={"rating": "down"}, + ) + assert res.status_code == 200 + assert fake.scores[0]["value"] == 0.0 + + +def test_put_feedback_skips_langfuse_when_no_trace_id(seeded_app, monkeypatch): + """When the message has no langfuse_trace_id, no Langfuse call is made.""" + from cell_explorer_agent.telemetry import langfuse_client + from cell_explorer_agent.telemetry.fake import FakeLangfuseClient + + fake = FakeLangfuseClient() + monkeypatch.setattr(langfuse_client, "get", lambda _config: fake) + + client = TestClient(seeded_app) + _set_auth_cookie(client, seeded_app, sub="user-1") + # Message created via _create_thread_with_assistant_msg has no trace_id. + _thread_id, msg_id = _create_thread_with_assistant_msg(seeded_app, user_sub="user-1") + + res = client.put( + f"/api/chat/public-atlas/messages/{msg_id}/feedback", + json={"rating": "up"}, + ) + assert res.status_code == 200 + assert fake.scores == [] + + +def test_put_feedback_swallows_langfuse_errors(seeded_app, monkeypatch): + """A Langfuse failure must not break the user-facing feedback PUT.""" + from cell_explorer_agent.telemetry import langfuse_client + + class _BrokenClient: + def create_score(self, **_kwargs): + raise RuntimeError("langfuse down") + + monkeypatch.setattr(langfuse_client, "get", lambda _config: _BrokenClient()) + + client = TestClient(seeded_app) + _set_auth_cookie(client, seeded_app, sub="user-1") + msg_id = _create_assistant_msg_with_trace_id( + seeded_app, user_sub="user-1", trace_id="trace-xyz", + ) + + res = client.put( + f"/api/chat/public-atlas/messages/{msg_id}/feedback", + json={"rating": "up"}, + ) + assert res.status_code == 200, res.text + + +def test_delete_feedback_does_not_call_langfuse(seeded_app, monkeypatch): + """DELETE clears local feedback only; Langfuse history is preserved.""" + from cell_explorer_agent.telemetry import langfuse_client + from cell_explorer_agent.telemetry.fake import FakeLangfuseClient + + fake = FakeLangfuseClient() + monkeypatch.setattr(langfuse_client, "get", lambda _config: fake) + + client = TestClient(seeded_app) + _set_auth_cookie(client, seeded_app, sub="user-1") + msg_id = _create_assistant_msg_with_trace_id( + seeded_app, user_sub="user-1", trace_id="trace-xyz", + ) + # First PUT creates feedback + one score; then DELETE clears the row. + client.put( + f"/api/chat/public-atlas/messages/{msg_id}/feedback", + json={"rating": "up"}, + ) + pre_delete_scores = list(fake.scores) + res = client.delete(f"/api/chat/public-atlas/messages/{msg_id}/feedback") + assert res.status_code == 204 + # DELETE did not produce any additional Langfuse calls. + assert fake.scores == pre_delete_scores diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/agent.py b/packages/cell-explorer-agent/src/cell_explorer_agent/agent.py index e9f34d9..9b9446b 100644 --- a/packages/cell-explorer-agent/src/cell_explorer_agent/agent.py +++ b/packages/cell-explorer-agent/src/cell_explorer_agent/agent.py @@ -176,7 +176,7 @@ async def run( if stop_reason == "end_turn" or not pending_calls: trace.set_output("".join(final_assistant_text_buffer)) - yield Done(usage=total_usage) + yield Done(usage=total_usage, trace_id=trace.trace_id) return if tool_calls_this_turn + len(pending_calls) > max_calls: @@ -187,7 +187,7 @@ async def run( retryable=False, ) trace.set_output("".join(final_assistant_text_buffer)) - yield Done(usage=total_usage) + yield Done(usage=total_usage, trace_id=trace.trace_id) return results: list[ToolResult] = [] diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/events.py b/packages/cell-explorer-agent/src/cell_explorer_agent/events.py index 907dd1a..3f0a119 100644 --- a/packages/cell-explorer-agent/src/cell_explorer_agent/events.py +++ b/packages/cell-explorer-agent/src/cell_explorer_agent/events.py @@ -58,6 +58,11 @@ class Done(BaseModel): # does not persist, so this is None at agent emit time; the API layer fills # it in just before yielding the event to the client. message_id: str | None = None + # Langfuse trace id for this turn. Set by the agent when telemetry is + # enabled; the API layer persists it on the assistant row so later + # feedback (PUT /feedback) can forward to Langfuse Scores. Wire-level + # field — the frontend ignores it. + trace_id: str | None = None class ThreadOpen(BaseModel): diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/telemetry/fake.py b/packages/cell-explorer-agent/src/cell_explorer_agent/telemetry/fake.py index 9905861..9e5efb1 100644 --- a/packages/cell-explorer-agent/src/cell_explorer_agent/telemetry/fake.py +++ b/packages/cell-explorer-agent/src/cell_explorer_agent/telemetry/fake.py @@ -93,6 +93,11 @@ def __init__(self) -> None: self.trace_updates: list[dict[str, Any]] = [] self.flushed_count: int = 0 self._current_stack: list[FakeObservation] = [] + # Tests can override; defaults to a synthetic value returned by + # get_current_trace_id() while a span context is active. + self.fake_trace_id: str = "fake-trace-id" + # Records every create_score(**kwargs) call for assertion in tests. + self.scores: list[dict[str, Any]] = [] # ---- observation creation ---- @@ -135,6 +140,19 @@ def update_current_span(self, **kwargs: Any) -> None: obs.update(**kwargs) return + # ---- trace id resolution ---- + + def get_current_trace_id(self) -> str | None: + """Mirrors v3 SDK: returns the active trace id while a span context + is open; None otherwise.""" + return self.fake_trace_id if self._current_stack else None + + # ---- scores ---- + + def create_score(self, **kwargs: Any) -> None: + """Record a create_score call. Tests inspect self.scores.""" + self.scores.append(dict(kwargs)) + # ---- lifecycle ---- def flush(self) -> None: diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/telemetry/trace_context.py b/packages/cell-explorer-agent/src/cell_explorer_agent/telemetry/trace_context.py index c687e79..d3b7346 100644 --- a/packages/cell-explorer-agent/src/cell_explorer_agent/telemetry/trace_context.py +++ b/packages/cell-explorer-agent/src/cell_explorer_agent/telemetry/trace_context.py @@ -63,6 +63,10 @@ def __init__( # v3 internals self._root_cm: Any | None = None self._root_obs: Any | None = None + # Captured inside __aenter__ once the root span is active. Used by + # callers (route layer) to persist a message → trace association so + # later feedback can be forwarded to Langfuse Scores. + self.trace_id: str | None = None # Per-turn counter so a turn with multiple LLM rounds shows # llm-call-1, llm-call-2, ... in the Langfuse trace tree # instead of three identical 'llm-call' rows. @@ -94,6 +98,12 @@ async def __aenter__(self) -> "TurnTrace": }, ) self._root_obs = self._root_cm.__enter__() + # Capture trace id from the now-active span context. None if the + # SDK can't resolve one (e.g. fake clients without OTel context). + try: + self.trace_id = self._client.get_current_trace_id() + except Exception: + self.trace_id = None # Set trace-level fields (user/session/tags) once the root span # is active — these belong on the implicit trace, not the span. self._client.update_current_trace( diff --git a/packages/cell-explorer-agent/tests/test_telemetry_trace_context.py b/packages/cell-explorer-agent/tests/test_telemetry_trace_context.py index 4fa880d..e70c301 100644 --- a/packages/cell-explorer-agent/tests/test_telemetry_trace_context.py +++ b/packages/cell-explorer-agent/tests/test_telemetry_trace_context.py @@ -110,6 +110,40 @@ async def test_private_dataset_trace_redacts_content(fake): assert fake.tool_spans[0].input == {"_redacted": "tool_args", "tool": "filter_by_ids"} +async def test_turn_trace_captures_trace_id_from_active_span(fake): + """Once the root span is open, get_current_trace_id() returns a value; + TurnTrace caches it on self.trace_id so the route layer can persist it.""" + fake.fake_trace_id = "abc-trace-123" + async with TurnTrace( + client=fake, + user_id="u", + thread_id="t", + dataset_slug="ds", + is_public=True, + model="m", + environment="test", + user_input="hi", + view_state=None, + ) as trace: + assert trace.trace_id == "abc-trace-123" + + +async def test_turn_trace_trace_id_none_when_no_client(): + """No client → no span context → no trace_id.""" + async with TurnTrace( + client=None, + user_id="u", + thread_id="t", + dataset_slug="ds", + is_public=True, + model="m", + environment="test", + user_input="hi", + view_state=None, + ) as trace: + assert trace.trace_id is None + + async def test_no_client_is_noop(): """When client=None, the context manager records nothing and does not raise.""" async with TurnTrace(