Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 77 additions & 12 deletions graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from langgraph.types import Command
from langgraph.prebuilt import ToolNode
from langchain_anthropic import ChatAnthropic
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from .utils import cached_invoke, acached_invoke
from .summary import SummaryConfig

Expand Down Expand Up @@ -279,6 +279,25 @@ def to_return(state: StateT) -> PureFunctionGenerator:
return {}
return to_return

def _get_scolding_pure(
t: type[StateT]
) -> PureFunction[StateT, dict[str, list[BaseMessage]]]:
def impl(
state: StateT
) -> PureFunctionGenerator[dict[str, list[BaseMessage]]]:
m = state["messages"].copy()
scolding = HumanMessage(
content="Every AI turn must end with at least one tool call. Double check your initial prompt for what tools you should be using. " \
"In particular, if you are done processing, make sure to deliver your output via the result tool."
)
m.append(scolding)
res = yield m
assert isinstance(res, AIMessage)
return {
"messages": [scolding, res]
}
return impl

I = TypeVar("I", bound=FlowInput)
O = TypeVar("O")

Expand Down Expand Up @@ -402,10 +421,37 @@ def __call__(
) -> AnyNodeFunction[InputState, StateT]:
...

class _ScoldingFact(Protocol):
def __call__(
self,
llm: LLM,
ty: type[StateT]
) -> AnyChatNodeFunction[StateT]:
...

def get_async_scolder(
llm: LLM,
ty: type[StateT]
) -> AsyncChatNodeFunction[StateT]:
return _stitch_async_impl(
_get_scolding_pure(ty),
_async_llm(llm)
)

def get_sync_scolder(
llm: LLM,
ty: type[StateT]
) -> ChatNodeFunction[StateT]:
return _stitch_sync_impl(
_get_scolding_pure(ty),
_sync_llm(llm)
)

INITIAL_NODE = "initial"
TOOLS_NODE = "tools"
TOOL_RESULT_NODE = "tool_result"
SUMMARIZE_NODE = "summarize"
SCOLDING_NODE = "scold"

BoundLLM = LLM

Expand Down Expand Up @@ -544,7 +590,7 @@ def with_tools(self, l: Iterable[BaseTool | SplitTool]) -> "Builder[_BStateT, _B
to_ret._tools.extend(l)
return to_ret

def _build_internal(self, r: _ResultFact, i: _InitialFact, s: _SummarizerFact) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore
def _build_internal(self, r: _ResultFact, i: _InitialFact, s: _SummarizerFact, scold: _ScoldingFact) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore
if self._state_class is None:
raise ValueError("state_class is required")
if self._input_type is None:
Expand All @@ -571,21 +617,24 @@ def _build_internal(self, r: _ResultFact, i: _InitialFact, s: _SummarizerFact) -
init_fact=i,
result_fact=r,
summary_fact=s,
output_schema=None
output_schema=None,
scolder=scold
)

def build(self) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore
return self._build_internal(
s=get_summarizer,
i=initial_node,
r=tool_result_generator
r=tool_result_generator,
scold=get_sync_scolder
)

def build_async(self) -> Tuple["StateGraph[_BStateT, _BContextT, _BInputT, Any]", BoundLLM]: #type: ignore
return self._build_internal(
s=get_async_summarizer,
i=async_initial_node,
r=async_tool_result_generator
r=async_tool_result_generator,
scold=get_async_scolder
)

def compile_async(
Expand Down Expand Up @@ -626,7 +675,8 @@ def build_workflow(
summary_config,
tool_result_generator,
initial_node,
get_summarizer
get_summarizer,
get_sync_scolder
)


Expand Down Expand Up @@ -655,7 +705,8 @@ def build_async_workflow(
summary_config,
async_tool_result_generator,
async_initial_node,
get_async_summarizer
get_async_summarizer,
get_async_scolder
)

def _build_workflow(
Expand All @@ -671,7 +722,8 @@ def _build_workflow(
summary_config: SummaryConfig[StateT] | None,
result_fact: _ResultFact,
init_fact: _InitialFact,
summary_fact: _SummarizerFact
summary_fact: _SummarizerFact,
scolder: _ScoldingFact,
) -> Tuple[StateGraph[StateT, ContextT, InputState, OutputT], LLM]:
"""
Build a standard workflow with initial node -> tools -> tool_result pattern.
Expand Down Expand Up @@ -709,6 +761,16 @@ def should_end(state: StateT) -> Literal["__end__", "tool_result"]:
return "__end__"
return "tool_result"

def should_scold_about_tools(state: StateT) -> Literal["tools", "scold"]:
m = state["messages"]
last = m[-1]
if not isinstance(last, AIMessage):
raise ValueError("Routing is broken, have non-AI message at end of message")
if len(last.tool_calls) == 0:
return "scold"
else:
return "tools"

tool_schemas : list[BaseTool | dict] = []
tool_impls : list[BaseTool] = []

Expand All @@ -734,7 +796,7 @@ def should_end(state: StateT) -> Literal["__end__", "tool_result"]:

# Create initial node and tool node with curried LLM
init_node = init_fact(input_type, state_class, sys_prompt=sys_prompt, initial_prompt=initial_prompt, llm=llm)
tool_node = ToolNode(tool_impls, handle_tool_errors=False)
tool_node = ToolNode(tool_impls, handle_tool_errors=(ValidationError,))
tool_result_node = result_fact(state_class, llm)

# Build the graph with fixed input schema, no context
Expand All @@ -749,8 +811,11 @@ def should_end(state: StateT) -> Literal["__end__", "tool_result"]:
builder.add_node(INITIAL_NODE, init_node, input_schema=input_type)
builder.add_node(TOOLS_NODE, tool_node)
builder.add_node(TOOL_RESULT_NODE, tool_result_node)
builder.add_edge(INITIAL_NODE, TOOLS_NODE)
builder.add_edge(TOOL_RESULT_NODE, TOOLS_NODE)
builder.add_conditional_edges(INITIAL_NODE, should_scold_about_tools)
builder.add_conditional_edges(SCOLDING_NODE, should_scold_about_tools)
builder.add_conditional_edges(TOOL_RESULT_NODE, should_scold_about_tools)

builder.add_node(SCOLDING_NODE, scolder(llm, state_class))

if summary_config is not None:
def routing(state: StateT) -> Literal["summarize", "tool_result", "__end__"]:
Expand All @@ -765,7 +830,7 @@ def routing(state: StateT) -> Literal["summarize", "tool_result", "__end__"]:
unbound_llm, sys_prompt, initial_prompt, state_class, summary_config
)
builder.add_node(SUMMARIZE_NODE, summarizer)
builder.add_edge(SUMMARIZE_NODE, TOOL_RESULT_NODE)
builder.add_edge(SUMMARIZE_NODE, TOOL_RESULT_NODE)
builder.add_conditional_edges(TOOLS_NODE, routing)
else:
builder.add_conditional_edges(TOOLS_NODE, should_end)
Expand Down