-
Notifications
You must be signed in to change notification settings - Fork 657
feat(llm): add LangChain adapter and framework registry #1759
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Pouyanpi
merged 5 commits into
develop
from
feat/langchain-decouple/stack-2-adapter-layer
Apr 13, 2026
Merged
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
474c3f4
feat(llm): add LangChain adapter and framework registry
Pouyanpi 9913a39
fix(llm): address review feedback on adapter and framework registry
Pouyanpi a1814b3
refactor(llm): extract shared helpers in LangChain adapter
Pouyanpi aa679c2
apply review suggestions
Pouyanpi ee8d73b
apply review suggestions
Pouyanpi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,394 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import logging | ||
| import uuid | ||
| from typing import Any, AsyncIterator, Dict, List, Optional, Union | ||
|
|
||
| from nemoguardrails.types import ( | ||
| ChatMessage, | ||
| FinishReason, | ||
| LLMModel, | ||
| LLMResponse, | ||
| LLMResponseChunk, | ||
| ToolCall, | ||
| ToolCallFunction, | ||
| UsageInfo, | ||
| ) | ||
|
|
||
| log = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def _infer_model_name(llm: Any): | ||
| """Helper to infer the model name based from an LLM instance. | ||
|
|
||
| Because not all models implement correctly _identifying_params from LangChain, we have to | ||
| try to do this manually. | ||
| """ | ||
| for attr in ["model", "model_name"]: | ||
| if hasattr(llm, attr): | ||
| val = getattr(llm, attr) | ||
| if isinstance(val, str): | ||
| return val | ||
|
|
||
| model_kwargs = getattr(llm, "model_kwargs", None) | ||
| if model_kwargs and isinstance(model_kwargs, dict): | ||
| for attr in ["model", "model_name", "name"]: | ||
| val = model_kwargs.get(attr) | ||
| if isinstance(val, str): | ||
| return val | ||
|
|
||
| # If we still can't figure out, return "unknown". | ||
| return "unknown" | ||
|
|
||
|
|
||
| def _infer_provider_from_module(llm: Any) -> Optional[str]: | ||
| """Infer provider name from the LLM's module path. | ||
|
|
||
| This function extracts the provider name from LangChain package naming conventions: | ||
| - langchain_openai -> openai | ||
| - langchain_anthropic -> anthropic | ||
| - langchain_google_genai -> google_genai | ||
| - langchain_nvidia_ai_endpoints -> nvidia_ai_endpoints | ||
| - langchain_community.chat_models.ollama -> ollama | ||
|
|
||
| For patched/wrapped classes, checks base classes as well. | ||
|
|
||
| Args: | ||
| llm: The LLM instance | ||
|
|
||
| Returns: | ||
| The inferred provider name, or None if it cannot be determined | ||
| """ | ||
| module = type(llm).__module__ | ||
|
|
||
| if module.startswith("langchain_"): | ||
| package = module.split(".")[0] | ||
| provider = package.replace("langchain_", "") | ||
|
|
||
| if provider == "community": | ||
| parts = module.split(".") | ||
| return parts[-1] if len(parts) >= 3 else "community" | ||
| else: | ||
| return provider | ||
|
|
||
| for base_class in type(llm).__mro__[1:]: | ||
| base_module = base_class.__module__ | ||
| if base_module.startswith("langchain_"): | ||
|
Pouyanpi marked this conversation as resolved.
|
||
| package = base_module.split(".")[0] | ||
| provider = package.replace("langchain_", "") | ||
|
|
||
| if provider == "community": | ||
| parts = base_module.split(".") | ||
| return parts[-1] if len(parts) >= 3 else "community" | ||
| else: | ||
| return provider | ||
|
|
||
| return None | ||
|
|
||
|
|
||
# Attribute names that LangChain integrations commonly use to store the
# provider's base/endpoint URL; probed in order by provider_url.
# NOTE: temporarily duplicates utils.py BASE_URL_ATTRIBUTES — the utils.py
# copy is removed in stack-3 when it switches to model.provider_url.
_BASE_URL_ATTRIBUTES = [
    "base_url",
    "endpoint_url",
    "server_url",
    "azure_endpoint",
    "openai_api_base",
    "api_base",
    "api_host",
    "endpoint",
]
|
|
||
|
|
||
class LangChainLLMAdapter:
    """Adapter exposing a LangChain LLM through a framework-agnostic interface.

    Wraps a LangChain chat/completion model and provides normalized
    ``generate``/``stream`` entry points returning ``LLMResponse`` /
    ``LLMResponseChunk`` objects.
    """

    def __init__(self, llm):
        # The wrapped LangChain LLM instance (chat or completion model).
        self._llm = llm

    @property
    def raw_llm(self) -> Any:
        """The underlying LangChain LLM instance, unwrapped."""
        return self._llm

    @property
    def model_name(self) -> str:
        """Best-effort model name inferred from the wrapped LLM."""
        return _infer_model_name(self._llm)

    @property
    def provider_name(self) -> Optional[str]:
        """Provider name inferred from the LLM's module path, or None."""
        return _infer_provider_from_module(self._llm)

    @property
    def provider_url(self) -> Optional[str]:
        """Base/endpoint URL of the provider, or None if not discoverable."""
        # temp: uses _BASE_URL_ATTRIBUTES which duplicates utils.py BASE_URL_ATTRIBUTES.
        # utils.py copy will be removed in stack-3 when it switches to model.provider_url.
        for attr in _BASE_URL_ATTRIBUTES:
            value = getattr(self._llm, attr, None)
            if value:
                return str(value)
        # Some integrations only expose the URL on their inner client object.
        client = getattr(self._llm, "client", None)
        if client and hasattr(client, "base_url"):
            return str(client.base_url)
        return None

    def _filter_reasoning_model_params(self, params: Optional[dict]) -> Optional[dict]:
        """Strip params unsupported by OpenAI reasoning models.

        OpenAI reasoning models (o1/o3 families and non-chat gpt-5 variants)
        reject the ``temperature`` parameter, so it is removed for those
        models; params are returned unchanged otherwise.
        """
        if not params or "temperature" not in params:
            return params

        model_name = _infer_model_name(self._llm).lower()

        is_openai_reasoning_model = (
            model_name.startswith("o1")
            or model_name.startswith("o3")
            or (model_name.startswith("gpt-5") and "chat" not in model_name)
        )

        if is_openai_reasoning_model:
            # Copy so the caller's dict is not mutated.
            filtered = params.copy()
            filtered.pop("temperature", None)
            log.debug("Stripped 'temperature' for reasoning model '%s'", model_name)
            return filtered

        return params

    def _prepare_llm(self, kwargs: dict):
        """Return the LLM with call-time kwargs bound (after param filtering)."""
        kwargs = self._filter_reasoning_model_params(kwargs) or {}
        llm = self._llm
        if kwargs:
            # bind() returns a new runnable; the wrapped LLM stays untouched.
            llm = llm.bind(**kwargs)
        return llm

    def _to_langchain_input(self, prompt):
        """Convert a list of ChatMessage into LangChain messages; pass strings through."""
        if isinstance(prompt, list):
            # Imported lazily to keep module import light.
            from nemoguardrails.integrations.langchain.message_utils import (
                chatmessages_to_langchain_messages,
            )

            return chatmessages_to_langchain_messages(prompt)
        return prompt

    async def generate(
        self,
        prompt: Union[str, List[ChatMessage]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> LLMResponse:
        """Run a single (non-streaming) completion.

        Args:
            prompt: A plain string or a list of ChatMessage objects.
            stop: Optional stop sequences forwarded to the LLM.
            **kwargs: Extra model parameters bound for this call only.

        Returns:
            The normalized LLMResponse.
        """
        llm = self._prepare_llm(kwargs)
        messages = self._to_langchain_input(prompt)
        response = await llm.ainvoke(messages, stop=stop)

        return _langchain_response_to_llm_response(response)

    async def stream(
        self,
        prompt: Union[str, List[ChatMessage]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> AsyncIterator[LLMResponseChunk]:
        """Stream a completion as normalized LLMResponseChunk objects.

        Arguments mirror :meth:`generate`; chunks are yielded incrementally
        as the underlying LLM produces them.
        """
        llm = self._prepare_llm(kwargs)
        messages = self._to_langchain_input(prompt)
        async for chunk in llm.astream(messages, stop=stop):
            yield _langchain_chunk_to_llm_response_chunk(chunk)
|
|
||
|
|
||
class LangChainFramework:
    """Factory that builds LangChain-backed models wrapped in the adapter."""

    def create_model(
        self,
        model_name: str,
        provider_name: str,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> LLMModel:
        """Initialize a LangChain model and wrap it in a LangChainLLMAdapter.

        Args:
            model_name: Name of the model to initialize.
            provider_name: LangChain provider to initialize it with.
            model_kwargs: Optional extra constructor kwargs; the special
                ``"mode"`` key (default ``"chat"``) is consumed here and
                not forwarded to the model constructor.

        Returns:
            The adapter wrapping the freshly initialized LangChain model.
        """
        # Imported lazily to keep module import light.
        from nemoguardrails.llm.models.langchain_initializer import (
            init_langchain_model,
        )

        init_kwargs: Dict[str, Any] = dict(model_kwargs) if model_kwargs else {}
        mode = init_kwargs.pop("mode", "chat")

        base_llm = init_langchain_model(
            model_name=model_name,
            provider_name=provider_name,
            mode=mode,
            kwargs=init_kwargs,
        )
        return LangChainLLMAdapter(base_llm)
|
|
||
|
|
||
# Maps provider-specific finish/stop reasons (covering both OpenAI-style and
# Anthropic-style values) onto the normalized FinishReason vocabulary.
_FINISH_REASON_MAP: Dict[str, FinishReason] = {
    "stop": "stop",
    "end_turn": "stop",
    "length": "length",
    "max_tokens": "length",
    "tool_calls": "tool_calls",
    "tool_use": "tool_calls",
    "content_filter": "content_filter",
}
|
|
||
|
|
||
def _map_finish_reason(raw: Optional[str]) -> Optional[FinishReason]:
    """Normalize a provider finish reason; unrecognized values become "other"."""
    return None if raw is None else _FINISH_REASON_MAP.get(raw, "other")
|
|
||
|
|
||
| def _build_usage_info(raw: Any) -> Optional[UsageInfo]: | ||
| if raw is None: | ||
| return None | ||
| if not isinstance(raw, dict): | ||
| try: | ||
| raw = dict(raw) | ||
| except (TypeError, ValueError): | ||
| return None | ||
| if not raw: | ||
| return None | ||
| input_tokens = raw.get("input_tokens", raw.get("prompt_tokens", 0)) | ||
| output_tokens = raw.get("output_tokens", raw.get("completion_tokens", 0)) | ||
| total_tokens = raw.get("total_tokens") or (input_tokens + output_tokens) | ||
| return UsageInfo( | ||
| input_tokens=input_tokens, | ||
| output_tokens=output_tokens, | ||
| total_tokens=total_tokens, | ||
| reasoning_tokens=raw.get("reasoning_tokens"), | ||
| cached_tokens=raw.get("cached_tokens", raw.get("cache_read_input_tokens")), | ||
| ) | ||
|
Pouyanpi marked this conversation as resolved.
|
||
|
|
||
|
|
||
# Metadata keys that are already surfaced as dedicated LLMResponse fields
# (model, finish_reason, request_id, usage, ...); excluded from
# provider_metadata to avoid duplication.
_EXTRACTED_METADATA_KEYS = frozenset(
    {
        "model_name",
        "model",
        "finish_reason",
        "stop_reason",
        "stop_sequence",
        "id",
        "request_id",
        "token_usage",
        "usage",
    }
)

# additional_kwargs keys carrying reasoning content; these are extracted
# separately (see _extract_reasoning) and excluded from provider_metadata.
_REASONING_KEYS = frozenset({"reasoning_content"})
|
|
||
|
|
||
| def _extract_reasoning(response: Any) -> Optional[str]: | ||
| content_blocks = getattr(response, "content_blocks", None) | ||
| if content_blocks: | ||
| for block in content_blocks: | ||
| if isinstance(block, dict) and block.get("type") == "reasoning": | ||
| val = block.get("reasoning") | ||
| if val: | ||
| return val | ||
|
|
||
| additional_kwargs = getattr(response, "additional_kwargs", None) | ||
| if additional_kwargs and isinstance(additional_kwargs, dict): | ||
| val = additional_kwargs.get("reasoning_content") | ||
| if val: | ||
| return val | ||
|
|
||
| return None | ||
|
|
||
|
|
||
| def _extract_tool_calls(response: Any) -> Optional[List[ToolCall]]: | ||
| raw = getattr(response, "tool_calls", None) | ||
| if not raw: | ||
| return None | ||
| return [ | ||
| ToolCall( | ||
| id=tc.get("id") or str(uuid.uuid4()), | ||
| type="function", | ||
| function=ToolCallFunction( | ||
| name=tc.get("name", ""), | ||
| arguments=tc.get("args", {}), | ||
| ), | ||
| ) | ||
| for tc in raw | ||
| ] | ||
|
|
||
|
|
||
def _extract_usage(response: Any) -> Optional[UsageInfo]:
    """Locate token usage on a response, preferring ``usage_metadata``.

    Falls back to ``token_usage``/``usage`` entries nested inside
    ``response_metadata`` or ``generation_info``.
    """
    direct = _build_usage_info(getattr(response, "usage_metadata", None))
    if direct is not None:
        return direct

    # Some providers only report usage inside these metadata dicts.
    for metadata in (
        getattr(response, "response_metadata", None) or {},
        getattr(response, "generation_info", None) or {},
    ):
        nested = metadata.get("token_usage") or metadata.get("usage")
        if nested:
            fallback = _build_usage_info(nested)
            if fallback is not None:
                return fallback

    return None
|
|
||
|
|
||
def _extract_model_info(response_metadata: Dict[str, Any]) -> tuple:
    """Return ``(model, finish_reason, stop_sequence, request_id)`` from metadata.

    Each element is None when the corresponding key(s) are absent.
    """
    model = response_metadata.get("model_name") or response_metadata.get("model")
    finish_reason = _map_finish_reason(
        response_metadata.get("finish_reason") or response_metadata.get("stop_reason")
    )
    return (
        model,
        finish_reason,
        response_metadata.get("stop_sequence"),
        response_metadata.get("id") or response_metadata.get("request_id"),
    )
|
Pouyanpi marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
def _build_provider_metadata(
    response_metadata: Dict[str, Any],
    additional_kwargs: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
    """Collect provider-specific metadata not already extracted elsewhere.

    Keys listed in _EXTRACTED_METADATA_KEYS (surfaced as dedicated response
    fields) and reasoning keys are skipped; on key collisions, values from
    ``response_metadata`` win over ``additional_kwargs``.
    """
    metadata = {}
    for key, value in response_metadata.items():
        if key not in _EXTRACTED_METADATA_KEYS:
            metadata[key] = value
    for key, value in (additional_kwargs or {}).items():
        if key not in _REASONING_KEYS and key not in metadata:
            metadata[key] = value
    return metadata or None
|
|
||
|
|
||
def _langchain_response_to_llm_response(response: Any) -> LLMResponse:
    """Convert a LangChain response object into a normalized LLMResponse."""
    raw_content = getattr(response, "content", None)
    content = str(response) if raw_content is None else raw_content

    metadata = getattr(response, "response_metadata", None) or {}
    extra_kwargs = getattr(response, "additional_kwargs", None) or {}
    model, finish_reason, stop_sequence, request_id = _extract_model_info(metadata)

    return LLMResponse(
        content=content,
        reasoning=_extract_reasoning(response),
        tool_calls=_extract_tool_calls(response),
        model=model,
        finish_reason=finish_reason,
        stop_sequence=stop_sequence,
        request_id=request_id,
        usage=_extract_usage(response),
        provider_metadata=_build_provider_metadata(metadata, extra_kwargs),
    )
|
|
||
|
|
||
def _langchain_chunk_to_llm_response_chunk(chunk: Any) -> LLMResponseChunk:
    """Convert a LangChain streaming chunk into a normalized LLMResponseChunk."""
    # Prefer .content, then .text, then the string form of the chunk itself.
    delta = getattr(chunk, "content", None)
    if delta is None:
        delta = getattr(chunk, "text", None)
    if delta is None:
        delta = str(chunk)

    # generation_info entries override response_metadata on key collisions.
    combined = {
        **(getattr(chunk, "response_metadata", None) or {}),
        **(getattr(chunk, "generation_info", None) or {}),
    }

    return LLMResponseChunk(
        delta_content=delta,
        usage=_extract_usage(chunk),
        provider_metadata=combined or None,
    )
|
Pouyanpi marked this conversation as resolved.
Pouyanpi marked this conversation as resolved.
Pouyanpi marked this conversation as resolved.
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.