diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register.py index 0c64b426ea..95d1ed4962 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import re import uuid from collections.abc import AsyncGenerator @@ -40,6 +41,21 @@ logger = logging.getLogger(__name__) +_MODEL_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._/:@\- ]{0,254}$") + + +def _build_lc_config(max_tool_calls: int, model: str | None, *, supports_override: bool = True) -> dict: + lc_config: dict = {"recursion_limit": (max_tool_calls + 1) * 2} + if model is not None: + if not _MODEL_NAME_RE.match(model): + raise ValueError(f"Invalid model name: {model!r}") + if not supports_override: + raise ValueError(f"The configured inference backend does not support per-request model selection. " + f"Remove the 'model' field from the request or switch to a supported backend " + f"(e.g. NIM, OpenAI, Bedrock). Requested: {model!r}") + lc_config["configurable"] = {"model_name": model} + return lc_config + class ReActAgentWorkflowConfig(AgentBaseConfig, OptimizableMixin, name="react_agent"): """ @@ -99,6 +115,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde from langchain_core.messages import AIMessageChunk from langchain_core.messages import BaseMessage from langchain_core.messages import trim_messages + from langchain_core.runnables.configurable import RunnableConfigurableFields from langgraph.errors import GraphRecursionError from langgraph.graph.state import CompiledStateGraph @@ -112,6 +129,7 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde # we can choose an LLM for the ReAct agent in the config file llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) + _supports_model_override = isinstance(llm, RunnableConfigurableFields) # the agent can run any installed tool, simply install the tool and add it to the config file # the sample tool provided can easily be copied or changed tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) @@ -160,7 +178,10 @@ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatRes state = ReActGraphState(messages=messages) # run the ReAct Agent Graph - state = await graph.ainvoke(state, config={'recursion_limit': (config.max_tool_calls + 1) * 2}) + state = await graph.ainvoke(state, + config=_build_lc_config(config.max_tool_calls, + message.model, + supports_override=_supports_model_override)) # setting recursion_limit: 4 allows 1 tool call # - allows the ReAct Agent to perform 1 cycle / call 1 single tool, # - but stops the agent when it tries to call a tool a second time @@ -211,10 +232,12 @@ async def _stream_fn(chat_request_or_message: ChatRequestOrMessage) -> AsyncGene buffer = "" found_final_answer = False - async for msg, metadata in graph.astream( - state, - config={'recursion_limit': (config.max_tool_calls + 1) * 2}, - stream_mode="messages"): + async for msg, metadata in graph.astream(state, + config=_build_lc_config( + config.max_tool_calls, + message.model, + supports_override=_supports_model_override), + stream_mode="messages"): if not isinstance(msg, AIMessageChunk): continue if not isinstance(metadata, dict) or metadata.get("langgraph_node") != "agent": @@ -240,8 +263,7 @@ async def _stream_fn(chat_request_or_message: ChatRequestOrMessage) -> AsyncGene except GraphRecursionError: logger.warning( "%s ReAct Agent reached its maximum iteration limit (%d) without producing a final answer. " - "This typically means the LLM kept calling tools instead of returning a response.", - AGENT_LOG_PREFIX, + "This typically means the LLM kept calling tools instead of returning a response.", AGENT_LOG_PREFIX, config.max_tool_calls) yield ChatResponseChunk.create_streaming_chunk( f"The react agent could not produce a final answer within {config.max_tool_calls} " diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register_per_user_agent.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register_per_user_agent.py index 9bcaa571d7..9e3646ea4a 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register_per_user_agent.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/react_agent/register_per_user_agent.py @@ -23,6 +23,7 @@ from nat.data_models.api_server import ChatRequestOrMessage from nat.data_models.api_server import ChatResponse from nat.plugins.langchain.agent.react_agent.register import ReActAgentWorkflowConfig +from nat.plugins.langchain.agent.react_agent.register import _build_lc_config logger = logging.getLogger(__name__) @@ -43,6 +44,7 @@ async def per_user_react_agent_workflow(config: PerUserReActAgentWorkflowConfig, """Per-user ReAct Agent - each user gets their own isolated instance.""" from langchain_core.messages import BaseMessage from langchain_core.messages import trim_messages + from langchain_core.runnables.configurable import RunnableConfigurableFields from langgraph.graph.state import CompiledStateGraph from nat.data_models.api_server import Usage @@ -54,6 +56,7 @@ async def per_user_react_agent_workflow(config: PerUserReActAgentWorkflowConfig, prompt = create_react_agent_prompt(config) llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) + _supports_model_override = isinstance(llm, RunnableConfigurableFields) tools = await builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) if not tools: @@ -82,7 +85,10 @@ async def _response_fn(chat_request_or_message: ChatRequestOrMessage) -> ChatRes start_on="human", include_system=True) state = ReActGraphState(messages=messages) - state = await graph.ainvoke(state, config={'recursion_limit': (config.max_tool_calls + 1) * 2}) + state = await graph.ainvoke( + state, + config=_build_lc_config(config.max_tool_calls, message.model, + supports_override=_supports_model_override)) state = ReActGraphState(**state) output_message = state.messages[-1] content = str(output_message.content) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index 8d5a640d51..ff76464956 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -22,6 +22,8 @@ from typing import Any from typing import TypeVar +from langchain_core.runnables import ConfigurableField + from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum from nat.cli.register_workflow import register_llm_client @@ -54,6 +56,21 @@ ModelType = TypeVar("ModelType") +# Azure: model_name is tracing metadata only; the deployment controls which model runs. +# HuggingFace: local pipeline has the model loaded into RAM; cannot hot-swap at inference time. +_NO_RUNTIME_MODEL_OVERRIDE: frozenset[str] = frozenset({"AzureChatOpenAI", "AsyncChatHuggingFace"}) + + +def _get_model_configurable_field(client: Any) -> str | None: + """Return the Pydantic field name for model selection, or None if not supported.""" + if type(client).__name__ in _NO_RUNTIME_MODEL_OVERRIDE: + return None + fields = getattr(type(client), "model_fields", {}) + for candidate in ("model_name", "model", "model_id"): + if candidate in fields: + return candidate + return None + def _get_langchain_oci_chat_model(): from langchain_oci import ChatOCIGenAI @@ -108,6 +125,10 @@ def inject(self, messages: LanguageModelInput, *args, **kwargs) -> FunctionArgum return FunctionArgumentWrapper(messages, *args, **kwargs) raise ValueError(f"Unsupported message type: {type(messages)}") + model_field = _get_model_configurable_field(client) + if model_field is not None: + client = client.configurable_fields(**{model_field: ConfigurableField(id="model_name")}) + if isinstance(llm_config, RetryMixin): client = patch_with_retry(client, retries=llm_config.num_retries,