From f3504354621141110a62d37c0b420516cb0ff5ae Mon Sep 17 00:00:00 2001 From: John Toman Date: Fri, 20 Feb 2026 14:46:16 -0800 Subject: [PATCH] Robustness in the face of LLM goofs 1. When the LLM fails to call a tool, graphcore will currently loop pointlessly between tool and tool_result nodes. This PR adds new routing for a node which sends a reminder about tool calls instead. 2. LLMs failing to follow tool calling schemas is distressingly common: when a pydantic validation error occurs, pass it back to the LLM instead of crashing --- graph.py | 89 ++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 77 insertions(+), 12 deletions(-) diff --git a/graph.py b/graph.py index 23537d9..955b45f 100644 --- a/graph.py +++ b/graph.py @@ -26,7 +26,7 @@ from langgraph.types import Command from langgraph.prebuilt import ToolNode from langchain_anthropic import ChatAnthropic -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from .utils import cached_invoke, acached_invoke from .summary import SummaryConfig @@ -279,6 +279,25 @@ def to_return(state: StateT) -> PureFunctionGenerator: return {} return to_return +def _get_scolding_pure( + t: type[StateT] +) -> PureFunction[StateT, dict[str, list[BaseMessage]]]: + def impl( + state: StateT + ) -> PureFunctionGenerator[dict[str, list[BaseMessage]]]: + m = state["messages"].copy() + scolding = HumanMessage( + content="Every AI turn must end with at least one tool call. Double check your initial prompt for what tools you should be using. " \ + "In particular, if you are done processing, make sure to deliver your output via the result tool." + ) + m.append(scolding) + res = yield m + assert isinstance(res, AIMessage) + return { + "messages": [scolding, res] + } + return impl + I = TypeVar("I", bound=FlowInput) O = TypeVar("O") @@ -402,10 +421,37 @@ def __call__( ) -> AnyNodeFunction[InputState, StateT]: ... 
+class _ScoldingFact(Protocol): + def __call__( + self, + llm: LLM, + ty: type[StateT] + ) -> AnyChatNodeFunction[StateT]: + ... + +def get_async_scolder( + llm: LLM, + ty: type[StateT] +) -> AsyncChatNodeFunction[StateT]: + return _stitch_async_impl( + _get_scolding_pure(ty), + _async_llm(llm) + ) + +def get_sync_scolder( + llm: LLM, + ty: type[StateT] +) -> ChatNodeFunction[StateT]: + return _stitch_sync_impl( + _get_scolding_pure(ty), + _sync_llm(llm) + ) + INITIAL_NODE = "initial" TOOLS_NODE = "tools" TOOL_RESULT_NODE = "tool_result" SUMMARIZE_NODE = "summarize" +SCOLDING_NODE = "scold" BoundLLM = LLM @@ -544,7 +590,7 @@ def with_tools(self, l: Iterable[BaseTool | SplitTool]) -> "Builder[_BStateT, _B to_ret._tools.extend(l) return to_ret - def _build_internal(self, r: _ResultFact, i: _InitialFact, s: _SummarizerFact) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore + def _build_internal(self, r: _ResultFact, i: _InitialFact, s: _SummarizerFact, scold: _ScoldingFact) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore if self._state_class is None: raise ValueError("state_class is required") if self._input_type is None: @@ -571,21 +617,24 @@ def _build_internal(self, r: _ResultFact, i: _InitialFact, s: _SummarizerFact) - init_fact=i, result_fact=r, summary_fact=s, - output_schema=None + output_schema=None, + scolder=scold ) def build(self) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore return self._build_internal( s=get_summarizer, i=initial_node, - r=tool_result_generator + r=tool_result_generator, + scold=get_sync_scolder ) def build_async(self) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore return self._build_internal( s=get_async_summarizer, i=async_initial_node, - r=async_tool_result_generator + r=async_tool_result_generator, + scold=get_async_scolder ) def compile_async( @@ -626,7 +675,8 @@ def build_workflow( 
summary_config, tool_result_generator, initial_node, - get_summarizer + get_summarizer, + get_sync_scolder ) @@ -655,7 +705,8 @@ def build_async_workflow( summary_config, async_tool_result_generator, async_initial_node, - get_async_summarizer + get_async_summarizer, + get_async_scolder ) def _build_workflow( @@ -671,7 +722,8 @@ def _build_workflow( summary_config: SummaryConfig[StateT] | None, result_fact: _ResultFact, init_fact: _InitialFact, - summary_fact: _SummarizerFact + summary_fact: _SummarizerFact, + scolder: _ScoldingFact, ) -> Tuple[StateGraph[StateT, ContextT, InputState, OutputT], LLM]: """ Build a standard workflow with initial node -> tools -> tool_result pattern. @@ -709,6 +761,16 @@ def should_end(state: StateT) -> Literal["__end__", "tool_result"]: return "__end__" return "tool_result" + def should_scold_about_tools(state: StateT) -> Literal["tools", "scold"]: + m = state["messages"] + last = m[-1] + if not isinstance(last, AIMessage): + raise ValueError("Routing is broken, have non-AI message at end of message") + if len(last.tool_calls) == 0: + return "scold" + else: + return "tools" + tool_schemas : list[BaseTool | dict] = [] tool_impls : list[BaseTool] = [] @@ -734,7 +796,7 @@ def should_end(state: StateT) -> Literal["__end__", "tool_result"]: # Create initial node and tool node with curried LLM init_node = init_fact(input_type, state_class, sys_prompt=sys_prompt, initial_prompt=initial_prompt, llm=llm) - tool_node = ToolNode(tool_impls, handle_tool_errors=False) + tool_node = ToolNode(tool_impls, handle_tool_errors=(ValidationError,)) tool_result_node = result_fact(state_class, llm) # Build the graph with fixed input schema, no context @@ -749,8 +811,11 @@ def should_end(state: StateT) -> Literal["__end__", "tool_result"]: builder.add_node(INITIAL_NODE, init_node, input_schema=input_type) builder.add_node(TOOLS_NODE, tool_node) builder.add_node(TOOL_RESULT_NODE, tool_result_node) - builder.add_edge(INITIAL_NODE, TOOLS_NODE) - 
builder.add_edge(TOOL_RESULT_NODE, TOOLS_NODE) + builder.add_conditional_edges(INITIAL_NODE, should_scold_about_tools) + builder.add_conditional_edges(SCOLDING_NODE, should_scold_about_tools) + builder.add_conditional_edges(TOOL_RESULT_NODE, should_scold_about_tools) + + builder.add_node(SCOLDING_NODE, scolder(llm, state_class)) if summary_config is not None: def routing(state: StateT) -> Literal["summarize", "tool_result", "__end__"]: @@ -765,7 +830,7 @@ def routing(state: StateT) -> Literal["summarize", "tool_result", "__end__"]: unbound_llm, sys_prompt, initial_prompt, state_class, summary_config ) builder.add_node(SUMMARIZE_NODE, summarizer) - builder.add_edge(SUMMARIZE_NODE, TOOL_RESULT_NODE) + builder.add_edge(SUMMARIZE_NODE, TOOL_RESULT_NODE) builder.add_conditional_edges(TOOLS_NODE, routing) else: builder.add_conditional_edges(TOOLS_NODE, should_end)