From 61b46358a06bd7284bcfe93958c50a8bd61073d2 Mon Sep 17 00:00:00 2001 From: Akshaya Shanbhogue Date: Tue, 6 Jan 2026 21:24:53 -0800 Subject: [PATCH] fix(ToolNode): pass correct inputs to tool nodes Scenario: 1. LLM invokes multiple instances of the same tool, such as ["Web Search", "Web Search"] 1. This line schedules multiple tool nodes to run in parallel. 1. Each tool node picks up the first matching tool call from the state. This means that both of the web search tools will execute the first tool call. --- pyproject.toml | 2 +- src/uipath_langchain/agent/react/router.py | 13 ++++++++-- src/uipath_langchain/agent/react/types.py | 4 +++ src/uipath_langchain/agent/tools/tool_node.py | 25 +++++++------------ tests/agent/tools/test_tool_node.py | 25 +++---------------- uv.lock | 2 +- 6 files changed, 29 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b7839300..f826b132 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-langchain" -version = "0.3.2" +version = "0.3.3" description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" diff --git a/src/uipath_langchain/agent/react/router.py b/src/uipath_langchain/agent/react/router.py index dfcbf57a..6b27828e 100644 --- a/src/uipath_langchain/agent/react/router.py +++ b/src/uipath_langchain/agent/react/router.py @@ -3,6 +3,7 @@ from typing import Literal from langchain_core.messages import AIMessage, AnyMessage, ToolCall +from langgraph.types import Send from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL from ..exceptions import AgentNodeRoutingException @@ -59,7 +60,7 @@ def create_route_agent(thinking_messages_limit: int = 0): def route_agent( state: AgentGraphState, - ) -> list[str] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]: + ) -> list[str | Send] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]: """Route after agent: handles all routing logic including control flow detection. Routing logic: @@ -86,7 +87,15 @@ def route_agent( return AgentGraphNode.TERMINATE if tool_calls: - return [tc["name"] for tc in tool_calls] + return [ + Send( + tc["name"], + AgentGraphState( + messages=messages, inner_state=state.inner_state, tool_call=tc + ), + ) + for tc in tool_calls + ] consecutive_thinking_messages = count_consecutive_thinking_messages(messages) diff --git a/src/uipath_langchain/agent/react/types.py b/src/uipath_langchain/agent/react/types.py index ebd84258..1d7cb2fb 100644 --- a/src/uipath_langchain/agent/react/types.py +++ b/src/uipath_langchain/agent/react/types.py @@ -2,6 +2,7 @@ from typing import Annotated, Any, Optional from langchain_core.messages import AnyMessage +from langchain_core.messages.tool import ToolCall from langgraph.graph.message import add_messages from pydantic import BaseModel, Field from uipath.platform.attachments import Attachment @@ -33,6 +34,9 @@ class AgentGraphState(BaseModel): inner_state: Annotated[InnerAgentGraphState, merge_objects] = Field( default_factory=InnerAgentGraphState ) + tool_call: Optional[ToolCall] = ( + None # This field is used to pass tool inputs to tool nodes. + ) class AgentGuardrailsGraphState(AgentGraphState): diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index aec540b4..e45e9969 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -4,7 +4,6 @@ from inspect import signature from typing import Any, Awaitable, Callable, Literal -from langchain_core.messages.ai import AIMessage from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.runnables.config import RunnableConfig from langchain_core.tools import BaseTool @@ -12,6 +11,8 @@ from langgraph.types import Command from pydantic import BaseModel +from ..react.types import AgentGraphState + # the type safety can be improved with generics ToolWrapperType = Callable[ [BaseTool, ToolCall, Any], dict[str, Any] | Command[Any] | None @@ -49,7 +50,9 @@ def __init__( self.wrapper = wrapper self.awrapper = awrapper - def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType: + def _func( + self, state: AgentGraphState, config: RunnableConfig | None = None + ) -> OutputType: call = self._extract_tool_call(state) if call is None: return None @@ -61,7 +64,7 @@ def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType: return self._process_result(call, result) async def _afunc( - self, state: Any, config: RunnableConfig | None = None + self, state: AgentGraphState, config: RunnableConfig | None = None ) -> OutputType: call = self._extract_tool_call(state) if call is None: @@ -73,20 +76,10 @@ async def _afunc( result = await self.tool.ainvoke(call["args"]) return self._process_result(call, result) - def _extract_tool_call(self, state: Any) -> ToolCall | None: + def _extract_tool_call(self, state: AgentGraphState) -> ToolCall | None: """Extract the tool call from the state messages.""" - if not hasattr(state, "messages"): - raise ValueError("State does not have messages key") - - last_message = state.messages[-1] - if not isinstance(last_message, AIMessage): - raise ValueError("Last message in message stack is not an AIMessage.") - - for tool_call in last_message.tool_calls: - if tool_call["name"] == self.tool.name: - return tool_call - return None + return state.tool_call def _process_result( self, call: ToolCall, result: dict[str, Any] | Command[Any] | None @@ -101,7 +94,7 @@ def _process_result( return {"messages": [message]} def _filter_state( - self, state: Any, wrapper: ToolWrapperType | AsyncToolWrapperType + self, state: AgentGraphState, wrapper: ToolWrapperType | AsyncToolWrapperType ) -> BaseModel: """Filter the state to the expected model type.""" model_type = list(signature(wrapper).parameters.values())[2].annotation diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py index 3d5633ed..362505ef 100644 --- a/tests/agent/tools/test_tool_node.py +++ b/tests/agent/tools/test_tool_node.py @@ -1,6 +1,6 @@ """Tests for tool_node.py module.""" -from typing import Any, Dict +from typing import Any, Dict, Optional import pytest from langchain_core.messages import AIMessage, HumanMessage @@ -55,6 +55,7 @@ class MockState(BaseModel): messages: list[Any] = [] user_id: str = "test_user" session_id: str = "test_session" + tool_call: Optional[ToolCall] = None def mock_wrapper( @@ -101,7 +102,7 @@ def mock_state(self): "id": "test_call_id", } ai_message = AIMessage(content="Using tool", tool_calls=[tool_call]) - return MockState(messages=[ai_message]) + return MockState(messages=[ai_message], tool_call=tool_call) @pytest.fixture def empty_state(self): @@ -200,26 +201,6 @@ def test_no_tool_calls_returns_none(self, mock_tool, empty_state): assert result is None - def test_non_ai_message_raises_error(self, mock_tool, non_ai_state): - """Test that non-AI messages raise ValueError.""" - node = UiPathToolNode(mock_tool) - - with pytest.raises( - ValueError, match="Last message in message stack is not an AIMessage" - ): - node._func(non_ai_state) - - def test_mismatched_tool_name_returns_none(self, mock_tool, mock_state): - """Test that mismatched tool names return None.""" - # Change the tool call name to something different - mock_state.messages[-1].tool_calls[0]["name"] = "different_tool" - - node = UiPathToolNode(mock_tool) - - result = node._func(mock_state) - - assert result is None - def test_state_filtering(self, mock_tool, mock_state): """Test that state is properly filtered for wrapper functions.""" node = UiPathToolNode(mock_tool, wrapper=mock_wrapper) diff --git a/uv.lock b/uv.lock index eda76354..812eed82 100644 --- a/uv.lock +++ b/uv.lock @@ -3282,7 +3282,7 @@ wheels = [ [[package]] name = "uipath-langchain" -version = "0.3.2" +version = "0.3.3" source = { editable = "." } dependencies = [ { name = "aiosqlite" },