From 6d8eb496297c2a0743420d3f3470b40e53dc5371 Mon Sep 17 00:00:00 2001 From: IBatae01 Date: Tue, 27 Jan 2026 16:08:51 +0300 Subject: [PATCH] add context dict to invoke, stream, stream_log, astream_events --- langserve/api_handler.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index 51cea2c1..fc6d6cf4 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -816,6 +816,7 @@ async def _get_config_and_input( except json.JSONDecodeError: raise RequestValidationError(errors=["Invalid JSON body"]) try: + context = body.get("context", {}) body = InvokeRequestShallowValidator.model_validate(body) # Merge the config from the path with the config from the body. @@ -839,7 +840,7 @@ async def _get_config_and_input( # using configuration. schema = self._runnable.with_config(config).input_schema input_ = schema.model_validate(body.input) - return config, _unpack_input(input_) + return config, _unpack_input(input_), context except ValidationError as e: raise RequestValidationError(e.errors(), body=body) @@ -862,7 +863,7 @@ async def invoke( """ # We do not use the InvokeRequest model here since configurable runnables # have dynamic schema -- so the validation below is a bit more involved. - config, input_ = await self._get_config_and_input( + config, input_, context = await self._get_config_and_input( request, config_hash, endpoint="invoke", @@ -876,6 +877,7 @@ async def invoke( invoke_coro = self._runnable.ainvoke( input_, config=config, + context=context, ) feedback_key: Optional[str] @@ -1146,7 +1148,7 @@ async def stream( """ run_id = None try: - config, input_ = await self._get_config_and_input( + config, input_, context = await self._get_config_and_input( request, config_hash, endpoint="stream", @@ -1184,6 +1186,7 @@ async def _stream() -> AsyncIterator[dict]: async for chunk in self._runnable.astream( input_, config=config_w_callbacks, + context=context, ): # Send a metadata event as soon as possible if not has_sent_metadata: @@ -1234,7 +1237,7 @@ async def stream_log( It's attached to _stream_log_docs endpoint. """ try: - config, input_ = await self._get_config_and_input( + config, input_, context = await self._get_config_and_input( request, config_hash, endpoint="stream_log", @@ -1277,6 +1280,7 @@ async def _stream_log() -> AsyncIterator[dict]: async for chunk in self._runnable.astream_log( input_, config=config, + context=context, diff=True, include_names=stream_log_request.include_names, include_types=stream_log_request.include_types, @@ -1343,7 +1347,7 @@ async def astream_events( """Stream events from the runnable.""" run_id = None try: - config, input_ = await self._get_config_and_input( + config, input_, context = await self._get_config_and_input( request, config_hash, endpoint="stream_events", @@ -1393,6 +1397,7 @@ async def _stream_events() -> AsyncIterator[dict]: async for event in self._runnable.astream_events( input_, config=config, + context=context, include_names=stream_events_request.include_names, include_types=stream_events_request.include_types, include_tags=stream_events_request.include_tags,