Skip to content
This repository was archived by the owner on May 5, 2026. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,7 @@ async def _get_config_and_input(
except json.JSONDecodeError:
raise RequestValidationError(errors=["Invalid JSON body"])
try:
context = body.get("context", {})
body = InvokeRequestShallowValidator.model_validate(body)

# Merge the config from the path with the config from the body.
Expand All @@ -839,7 +840,7 @@ async def _get_config_and_input(
# using configuration.
schema = self._runnable.with_config(config).input_schema
input_ = schema.model_validate(body.input)
return config, _unpack_input(input_)
return config, _unpack_input(input_), context
except ValidationError as e:
raise RequestValidationError(e.errors(), body=body)

Expand All @@ -862,7 +863,7 @@ async def invoke(
"""
# We do not use the InvokeRequest model here since configurable runnables
# have dynamic schema -- so the validation below is a bit more involved.
config, input_ = await self._get_config_and_input(
config, input_, context = await self._get_config_and_input(
request,
config_hash,
endpoint="invoke",
Expand All @@ -876,6 +877,7 @@ async def invoke(
invoke_coro = self._runnable.ainvoke(
input_,
config=config,
context=context,
)

feedback_key: Optional[str]
Expand Down Expand Up @@ -1146,7 +1148,7 @@ async def stream(
"""
run_id = None
try:
config, input_ = await self._get_config_and_input(
config, input_, context = await self._get_config_and_input(
request,
config_hash,
endpoint="stream",
Expand Down Expand Up @@ -1184,6 +1186,7 @@ async def _stream() -> AsyncIterator[dict]:
async for chunk in self._runnable.astream(
input_,
config=config_w_callbacks,
context=context,
):
# Send a metadata event as soon as possible
if not has_sent_metadata:
Expand Down Expand Up @@ -1234,7 +1237,7 @@ async def stream_log(
It's attached to _stream_log_docs endpoint.
"""
try:
config, input_ = await self._get_config_and_input(
config, input_, context = await self._get_config_and_input(
request,
config_hash,
endpoint="stream_log",
Expand Down Expand Up @@ -1277,6 +1280,7 @@ async def _stream_log() -> AsyncIterator[dict]:
async for chunk in self._runnable.astream_log(
input_,
config=config,
context=context,
diff=True,
include_names=stream_log_request.include_names,
include_types=stream_log_request.include_types,
Expand Down Expand Up @@ -1343,7 +1347,7 @@ async def astream_events(
"""Stream events from the runnable."""
run_id = None
try:
config, input_ = await self._get_config_and_input(
config, input_, context = await self._get_config_and_input(
request,
config_hash,
endpoint="stream_events",
Expand Down Expand Up @@ -1393,6 +1397,7 @@ async def _stream_events() -> AsyncIterator[dict]:
async for event in self._runnable.astream_events(
input_,
config=config,
context=context,
include_names=stream_events_request.include_names,
include_types=stream_events_request.include_types,
include_tags=stream_events_request.include_tags,
Expand Down