Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
import re
import uuid
from collections.abc import AsyncGenerator

Expand All @@ -40,6 +41,21 @@

logger = logging.getLogger(__name__)

_MODEL_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._/:@\- ]{0,254}$")


def _build_lc_config(max_tool_calls: int, model: str | None, *, supports_override: bool = True) -> dict:
lc_config: dict = {"recursion_limit": (max_tool_calls + 1) * 2}
if model is not None:
if not _MODEL_NAME_RE.match(model):
raise ValueError(f"Invalid model name: {model!r}")
if not supports_override:
raise ValueError(f"The configured inference backend does not support per-request model selection. "
f"Remove the 'model' field from the request or switch to a supported backend "
f"(e.g. NIM, OpenAI, Bedrock). Requested: {model!r}")
lc_config["configurable"] = {"model_name": model}
return lc_config

Comment thread
coderabbitai[bot] marked this conversation as resolved.

class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_agent"):
"""
Expand Down Expand Up @@ -99,6 +115,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import trim_messages
from langchain_core.runnables.configurable import RunnableConfigurableFields
from langgraph.errors import GraphRecursionError
from langgraph.graph.state import CompiledStateGraph

Expand All @@ -112,6 +129,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde

# we can choose an LLM for the ReAct agent in the config file
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
_supports_model_override = isinstance(llm, RunnableConfigurableFields)
# the agent can run any installed tool, simply install the tool and add it to the config file
# the sample tool provided can easily be copied or changed
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
Expand Down Expand Up @@ -160,7 +178,10 @@ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatRes
state = ReActGraphState(messages=messages)

# run the ReAct Agent Graph
state = await graph.ainvoke(state, config={'recursion_limit': (config.max_tool_calls + 1) * 2})
state = await graph.ainvoke(state,
config=_build_lc_config(config.max_tool_calls,
message.model,
supports_override=_supports_model_override))
# setting recursion_limit: 4 allows 1 tool call
# - allows the ReAct Agent to perform 1 cycle / call 1 single tool,
# - but stops the agent when it tries to call a tool a second time
Expand Down Expand Up @@ -211,10 +232,12 @@ async def _stream_fn(chat_request_or_message: ChatRequestOrMessage) -> AsyncGene
buffer = ""
found_final_answer = False

async for msg, metadata in graph.astream(
state,
config={'recursion_limit': (config.max_tool_calls + 1) * 2},
stream_mode="messages"):
async for msg, metadata in graph.astream(state,
config=_build_lc_config(
config.max_tool_calls,
message.model,
supports_override=_supports_model_override),
stream_mode="messages"):
if not isinstance(msg, AIMessageChunk):
continue
if not isinstance(metadata, dict) or metadata.get("langgraph_node") != "agent":
Expand All @@ -240,8 +263,7 @@ async def _stream_fn(chat_request_or_message: ChatRequestOrMessage) -> AsyncGene
except GraphRecursionError:
logger.warning(
"%s ReAct Agent reached its maximum iteration limit (%d) without producing a final answer. "
"This typically means the LLM kept calling tools instead of returning a response.",
AGENT_LOG_PREFIX,
"This typically means the LLM kept calling tools instead of returning a response.", AGENT_LOG_PREFIX,
config.max_tool_calls)
yield ChatResponseChunk.create_streaming_chunk(
f"The react agent could not produce a final answer within {config.max_tool_calls} "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nat.data_models.api_server import ChatRequestOrMessage
from nat.data_models.api_server import ChatResponse
from nat.plugins.langchain.agent.react_agent.register import ReActAgentWorkflowConfig
from nat.plugins.langchain.agent.react_agent.register import _build_lc_config

logger = logging.getLogger(__name__)

Expand All @@ -43,6 +44,7 @@ async def per_user_react_agent_workflow(config: PerUserReActAgentWorkflowConfig,
"""Per-user ReAct Agent - each user gets their own isolated instance."""
from langchain_core.messages import BaseMessage
from langchain_core.messages import trim_messages
from langchain_core.runnables.configurable import RunnableConfigurableFields
from langgraph.graph.state import CompiledStateGraph

from nat.data_models.api_server import Usage
Expand All @@ -54,6 +56,7 @@ async def per_user_react_agent_workflow(config: PerUserReActAgentWorkflowConfig,

prompt = create_react_agent_prompt(config)
llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
_supports_model_override = isinstance(llm, RunnableConfigurableFields)
tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

if not tools:
Expand Down Expand Up @@ -82,7 +85,10 @@ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatRes
start_on="human",
include_system=True)
state = ReActGraphState(messages=messages)
state = await graph.ainvoke(state, config={'recursion_limit': (config.max_tool_calls + 1) * 2})
state = await graph.ainvoke(
state,
config=_build_lc_config(config.max_tool_calls, message.model,
supports_override=_supports_model_override))
state = ReActGraphState(**state)
output_message = state.messages[-1]
content = str(output_message.content)
Expand Down
21 changes: 21 additions & 0 deletions packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from typing import Any
from typing import TypeVar

from langchain_core.runnables import ConfigurableField

from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.cli.register_workflow import register_llm_client
Expand Down Expand Up @@ -54,6 +56,21 @@

ModelType = TypeVar("ModelType")

# Azure: model_name is tracing metadata only; the deployment controls which model runs.
# HuggingFace: local pipeline has the model loaded into RAM; cannot hot-swap at inference time.
_NO_RUNTIME_MODEL_OVERRIDE: frozenset[str] = frozenset({"AzureChatOpenAI", "AsyncChatHuggingFace"})


def _get_model_configurable_field(client: Any) -> str | None:
"""Return the Pydantic field name for model selection, or None if not supported."""
if type(client).__name__ in _NO_RUNTIME_MODEL_OVERRIDE:
return None
fields = getattr(type(client), "model_fields", {})
for candidate in ("model_name", "model", "model_id"):
if candidate in fields:
return candidate
return None


def _get_langchain_oci_chat_model():
from langchain_oci import ChatOCIGenAI
Expand Down Expand Up @@ -108,6 +125,10 @@ def inject(self, messages: LanguageModelInput, *args, **kwargs) -> FunctionArgum
return FunctionArgumentWrapper(messages, *args, **kwargs)
raise ValueError(f"Unsupported message type: {type(messages)}")

model_field = _get_model_configurable_field(client)
if model_field is not None:
client = client.configurable_fields(**{model_field: ConfigurableField(id="model_name")})

if isinstance(llm_config, RetryMixin):
client = patch_with_retry(client,
retries=llm_config.num_retries,
Expand Down
Loading