Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""add chat_messages.langfuse_trace_id

Revision ID: 092e2aa153ce
Revises: a15a753f44c8
Create Date: 2026-05-21 09:00:00.000000

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


revision: str = '092e2aa153ce'
down_revision: Union[str, Sequence[str], None] = 'a15a753f44c8'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add the nullable ``langfuse_trace_id`` column to ``chat_messages``.

    The column is created together with a non-unique index named
    ``ix_chat_messages_langfuse_trace_id``. Both operations run inside a
    single Alembic batch-alter context for the table.
    """
    # Nullable: rows without a trace id simply carry NULL.
    trace_id_column = sa.Column('langfuse_trace_id', sa.String(), nullable=True)
    with op.batch_alter_table('chat_messages') as table_ops:
        table_ops.add_column(trace_id_column)
        table_ops.create_index(
            'ix_chat_messages_langfuse_trace_id', ['langfuse_trace_id']
        )


def downgrade() -> None:
    """Remove ``chat_messages.langfuse_trace_id`` and its index.

    Mirrors ``upgrade`` in reverse order: the index is dropped first, then
    the column it covers.
    """
    with op.batch_alter_table('chat_messages') as table_ops:
        table_ops.drop_index('ix_chat_messages_langfuse_trace_id')
        table_ops.drop_column('langfuse_trace_id')
3 changes: 3 additions & 0 deletions packages/api/src/cell_explorer_api/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class ChatMessageRow(SQLModel, table=True):
role: str # "user" | "assistant"
content: str
created_at: datetime = Field(default_factory=_utcnow)
# Set on assistant rows when Langfuse tracing is enabled. Used by the
# feedback PUT route to forward thumbs ratings to Langfuse Scores.
langfuse_trace_id: str | None = Field(default=None, index=True)


class ChatFeedback(SQLModel, table=True):
Expand Down
50 changes: 50 additions & 0 deletions packages/api/src/cell_explorer_api/routes/chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""HTTP routes for the chat agent (Plan 2a)."""

import asyncio
import logging
import uuid as _uuid
from collections.abc import AsyncIterator
from datetime import datetime
Expand All @@ -13,8 +14,10 @@
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from cell_explorer_agent.config import AgentConfig
from cell_explorer_agent.events import Error, ThreadOpen
from cell_explorer_agent.messages import AssistantMessage, UserMessage
from cell_explorer_agent.telemetry import langfuse_client
from cell_explorer_api.auth.dependencies import require_auth
from cell_explorer_api.auth.models import User
from cell_explorer_api.auth.optional import optional_auth
Expand Down Expand Up @@ -48,13 +51,48 @@
load_thread,
)

logger = logging.getLogger(__name__)

router = APIRouter(tags=["chat"])

# Note: this router is included by routes/__init__.py whose own router has
# prefix="/api". Routes inside use absolute paths ("/chat/{slug}/context"),
# matching the convention in datasets.py.


def _forward_feedback_to_langfuse(
    *,
    trace_id: str,
    user_sub: str,
    message_id: _uuid.UUID,
    rating: str,
    comment: str | None,
) -> None:
    """Send one thumbs rating to Langfuse as a ``user_feedback`` score.

    This is best-effort. The feedback row in our own DB is the source of
    truth and Langfuse is only an observability sink, so every failure is
    logged and swallowed rather than propagated to the caller.

    The score id is derived deterministically from (user_sub, message_id),
    so a repeated rating by the same user on the same message reuses the
    same id. This assumes Langfuse upserts on score id; if it does not,
    duplicates can still be grouped later by that id.

    Args:
        trace_id: Langfuse trace the score is attached to.
        user_sub: Subject of the rating user; part of the score id.
        message_id: Rated message id; part of the score id.
        rating: ``"up"`` is sent as 1.0, any other value as 0.0.
        comment: Optional free-text comment forwarded unchanged.
    """
    try:
        langfuse = langfuse_client.get(AgentConfig())
        if langfuse is not None:
            # A None client means there is nothing to forward to.
            if rating == "up":
                numeric_rating = 1.0
            else:
                numeric_rating = 0.0
            langfuse.create_score(
                name="user_feedback",
                value=numeric_rating,
                data_type="NUMERIC",
                trace_id=trace_id,
                score_id=f"feedback-{user_sub}-{message_id}",
                comment=comment,
            )
    except Exception:
        # Deliberately broad: telemetry must never break the feedback route.
        logger.exception("Failed to forward feedback to Langfuse")


class ChatMessage(BaseModel):
role: Literal["user", "assistant"]
content: str = Field(min_length=1)
Expand Down Expand Up @@ -183,6 +221,7 @@ async def _ndjson_event_stream(
msg = await append_message(
stream_db, thread, role="assistant",
content="".join(assistant_buffer),
langfuse_trace_id=event.trace_id,
)
# Capture id before commit() expires ORM attributes.
message_id = str(msg.id)
Expand Down Expand Up @@ -534,6 +573,8 @@ async def put_message_feedback(
status_code=422, detail="Only assistant messages can be rated"
)

# Capture trace_id before commit() expires ORM attributes.
langfuse_trace_id = msg.langfuse_trace_id
# Upsert by (message_id, user_sub).
existing = (
await db.exec(
Expand All @@ -559,6 +600,15 @@ async def put_message_feedback(
fb = existing
await db.commit()
await db.refresh(fb)
# Forward to Langfuse Scores when we know the trace id. Best-effort.
if langfuse_trace_id:
_forward_feedback_to_langfuse(
trace_id=langfuse_trace_id,
user_sub=user.sub,
message_id=message_id,
rating=fb.rating,
comment=fb.comment,
)
return FeedbackResponse(
rating=fb.rating, # type: ignore[arg-type]
comment=fb.comment,
Expand Down
8 changes: 7 additions & 1 deletion packages/api/src/cell_explorer_api/services/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,15 @@ async def append_message(
*,
role: str,
content: str,
langfuse_trace_id: str | None = None,
) -> ChatMessageRow:
"""Append a message to the thread and bump its updated_at."""
msg = ChatMessageRow(thread_id=thread.id, role=role, content=content)
msg = ChatMessageRow(
thread_id=thread.id,
role=role,
content=content,
langfuse_trace_id=langfuse_trace_id,
)
session.add(msg)
thread.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
session.add(thread)
Expand Down
149 changes: 149 additions & 0 deletions packages/api/tests/routes/test_chat_feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,152 @@ def test_thread_detail_includes_message_id(seeded_app):
res = client.get(f"/api/chat/public-atlas/threads/{thread_id}")
assert res.status_code == 200
assert res.json()["messages"][0]["id"] == msg_id


def _create_assistant_msg_with_trace_id(seeded_app, *, user_sub: str, trace_id: str) -> str:
    """Insert a thread plus one assistant message carrying ``trace_id``.

    The thread is attached to the first ``Dataset`` row returned by the
    query and is owned by ``user_sub``. Returns the new message id as a
    string.
    """
    from cell_explorer_api.db.models import ChatMessageRow, ChatThread, Dataset
    from sqlmodel import select

    async def _insert_rows() -> str:
        async with AsyncSession(seeded_app.state.db_engine) as db:
            first_dataset = (await db.exec(select(Dataset))).first()
            thread = ChatThread(
                user_sub=user_sub, dataset_id=first_dataset.id, title="t"
            )
            db.add(thread)
            # Flush before building the message, which reads thread.id.
            await db.flush()
            message = ChatMessageRow(
                thread_id=thread.id,
                role="assistant",
                content="hello",
                langfuse_trace_id=trace_id,
            )
            db.add(message)
            await db.commit()
            # Reload after commit so message.id is readable.
            await db.refresh(message)
            return str(message.id)

    return asyncio.run(_insert_rows())


def test_put_feedback_forwards_to_langfuse_when_trace_id_present(
    seeded_app, monkeypatch
):
    """PUT 👍 on a message with langfuse_trace_id calls Langfuse create_score."""
    from cell_explorer_agent.telemetry import langfuse_client
    from cell_explorer_agent.telemetry.fake import FakeLangfuseClient

    recorder = FakeLangfuseClient()
    monkeypatch.setattr(langfuse_client, "get", lambda _config: recorder)

    http = TestClient(seeded_app)
    _set_auth_cookie(http, seeded_app, sub="user-1")
    msg_id = _create_assistant_msg_with_trace_id(
        seeded_app, user_sub="user-1", trace_id="trace-xyz",
    )

    feedback_url = f"/api/chat/public-atlas/messages/{msg_id}/feedback"
    response = http.put(
        feedback_url, json={"rating": "up", "comment": "great answer"}
    )
    assert response.status_code == 200, response.text

    # Exactly one score is recorded, carrying the expected fields.
    assert len(recorder.scores) == 1
    score = recorder.scores[0]
    expected = {
        "name": "user_feedback",
        "value": 1.0,
        "data_type": "NUMERIC",
        "trace_id": "trace-xyz",
        "score_id": f"feedback-user-1-{msg_id}",
        "comment": "great answer",
    }
    assert {key: score[key] for key in expected} == expected


def test_put_feedback_forwards_down_as_zero(seeded_app, monkeypatch):
    """👎 maps to value=0.0."""
    from cell_explorer_agent.telemetry import langfuse_client
    from cell_explorer_agent.telemetry.fake import FakeLangfuseClient

    recorder = FakeLangfuseClient()
    monkeypatch.setattr(langfuse_client, "get", lambda _config: recorder)

    http = TestClient(seeded_app)
    _set_auth_cookie(http, seeded_app, sub="user-1")
    msg_id = _create_assistant_msg_with_trace_id(
        seeded_app, user_sub="user-1", trace_id="trace-xyz",
    )

    feedback_url = f"/api/chat/public-atlas/messages/{msg_id}/feedback"
    response = http.put(feedback_url, json={"rating": "down"})
    assert response.status_code == 200

    first_score = recorder.scores[0]
    assert first_score["value"] == 0.0


def test_put_feedback_skips_langfuse_when_no_trace_id(seeded_app, monkeypatch):
    """Rating a message without langfuse_trace_id makes no Langfuse call."""
    from cell_explorer_agent.telemetry import langfuse_client
    from cell_explorer_agent.telemetry.fake import FakeLangfuseClient

    recorder = FakeLangfuseClient()
    monkeypatch.setattr(langfuse_client, "get", lambda _config: recorder)

    http = TestClient(seeded_app)
    _set_auth_cookie(http, seeded_app, sub="user-1")
    # Messages seeded by _create_thread_with_assistant_msg have no trace id.
    _, msg_id = _create_thread_with_assistant_msg(seeded_app, user_sub="user-1")

    feedback_url = f"/api/chat/public-atlas/messages/{msg_id}/feedback"
    response = http.put(feedback_url, json={"rating": "up"})
    assert response.status_code == 200
    assert recorder.scores == []


def test_put_feedback_swallows_langfuse_errors(seeded_app, monkeypatch):
    """The feedback PUT still returns 200 when Langfuse's create_score raises."""
    from cell_explorer_agent.telemetry import langfuse_client

    class _AlwaysFailingClient:
        # Stand-in client whose create_score raises on every call.
        def create_score(self, **_kwargs):
            raise RuntimeError("langfuse down")

    monkeypatch.setattr(
        langfuse_client, "get", lambda _config: _AlwaysFailingClient()
    )

    http = TestClient(seeded_app)
    _set_auth_cookie(http, seeded_app, sub="user-1")
    msg_id = _create_assistant_msg_with_trace_id(
        seeded_app, user_sub="user-1", trace_id="trace-xyz",
    )

    feedback_url = f"/api/chat/public-atlas/messages/{msg_id}/feedback"
    response = http.put(feedback_url, json={"rating": "up"})
    assert response.status_code == 200, response.text


def test_delete_feedback_does_not_call_langfuse(seeded_app, monkeypatch):
    """DELETE removes local feedback without emitting any new Langfuse score."""
    from cell_explorer_agent.telemetry import langfuse_client
    from cell_explorer_agent.telemetry.fake import FakeLangfuseClient

    recorder = FakeLangfuseClient()
    monkeypatch.setattr(langfuse_client, "get", lambda _config: recorder)

    http = TestClient(seeded_app)
    _set_auth_cookie(http, seeded_app, sub="user-1")
    msg_id = _create_assistant_msg_with_trace_id(
        seeded_app, user_sub="user-1", trace_id="trace-xyz",
    )
    feedback_url = f"/api/chat/public-atlas/messages/{msg_id}/feedback"

    # Rate first so there is a feedback row (and a recorded score) to delete.
    http.put(feedback_url, json={"rating": "up"})
    scores_before_delete = recorder.scores.copy()

    response = http.delete(feedback_url)
    assert response.status_code == 204
    # The recorded scores are unchanged by the DELETE.
    assert recorder.scores == scores_before_delete
4 changes: 2 additions & 2 deletions packages/cell-explorer-agent/src/cell_explorer_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async def run(

if stop_reason == "end_turn" or not pending_calls:
trace.set_output("".join(final_assistant_text_buffer))
yield Done(usage=total_usage)
yield Done(usage=total_usage, trace_id=trace.trace_id)
return

if tool_calls_this_turn + len(pending_calls) > max_calls:
Expand All @@ -187,7 +187,7 @@ async def run(
retryable=False,
)
trace.set_output("".join(final_assistant_text_buffer))
yield Done(usage=total_usage)
yield Done(usage=total_usage, trace_id=trace.trace_id)
return

results: list[ToolResult] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class Done(BaseModel):
# does not persist, so this is None at agent emit time; the API layer fills
# it in just before yielding the event to the client.
message_id: str | None = None
# Langfuse trace id for this turn. Set by the agent when telemetry is
# enabled; the API layer persists it on the assistant row so later
# feedback (PUT /feedback) can forward to Langfuse Scores. Wire-level
# field — the frontend ignores it.
trace_id: str | None = None


class ThreadOpen(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def __init__(self) -> None:
self.trace_updates: list[dict[str, Any]] = []
self.flushed_count: int = 0
self._current_stack: list[FakeObservation] = []
# Tests can override; defaults to a synthetic value returned by
# get_current_trace_id() while a span context is active.
self.fake_trace_id: str = "fake-trace-id"
# Records every create_score(**kwargs) call for assertion in tests.
self.scores: list[dict[str, Any]] = []

# ---- observation creation ----

Expand Down Expand Up @@ -135,6 +140,19 @@ def update_current_span(self, **kwargs: Any) -> None:
obs.update(**kwargs)
return

# ---- trace id resolution ----

def get_current_trace_id(self) -> str | None:
"""Mirrors v3 SDK: returns the active trace id while a span context
is open; None otherwise."""
return self.fake_trace_id if self._current_stack else None

# ---- scores ----

def create_score(self, **kwargs: Any) -> None:
    """Append a shallow copy of the call's keyword arguments to ``self.scores``.

    Tests inspect ``self.scores`` to assert on what would have been sent.
    """
    recorded_call = {**kwargs}
    self.scores.append(recorded_call)

# ---- lifecycle ----

def flush(self) -> None:
Expand Down
Loading
Loading