diff --git a/pyproject.toml b/pyproject.toml index b7839300..f826b132 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-langchain" -version = "0.3.2" +version = "0.3.3" description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" diff --git a/src/uipath_langchain/agent/react/router.py b/src/uipath_langchain/agent/react/router.py index dfcbf57a..6b27828e 100644 --- a/src/uipath_langchain/agent/react/router.py +++ b/src/uipath_langchain/agent/react/router.py @@ -3,6 +3,7 @@ from typing import Literal from langchain_core.messages import AIMessage, AnyMessage, ToolCall +from langgraph.types import Send from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL from ..exceptions import AgentNodeRoutingException @@ -59,7 +60,7 @@ def create_route_agent(thinking_messages_limit: int = 0): def route_agent( state: AgentGraphState, - ) -> list[str] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]: + ) -> list[str | Send] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]: """Route after agent: handles all routing logic including control flow detection. Routing logic: @@ -86,7 +87,15 @@ def route_agent( return AgentGraphNode.TERMINATE if tool_calls: - return [tc["name"] for tc in tool_calls] + return [ + Send( + tc["name"], + AgentGraphState( + messages=messages, inner_state=state.inner_state, tool_call=tc + ), + ) + for tc in tool_calls + ] consecutive_thinking_messages = count_consecutive_thinking_messages(messages) diff --git a/src/uipath_langchain/agent/react/types.py b/src/uipath_langchain/agent/react/types.py index ebd84258..1d7cb2fb 100644 --- a/src/uipath_langchain/agent/react/types.py +++ b/src/uipath_langchain/agent/react/types.py @@ -2,6 +2,7 @@ from typing import Annotated, Any, Optional from langchain_core.messages import AnyMessage +from langchain_core.messages.tool import ToolCall from langgraph.graph.message import add_messages from pydantic import BaseModel, Field from uipath.platform.attachments import Attachment @@ -33,6 +34,9 @@ class AgentGraphState(BaseModel): inner_state: Annotated[InnerAgentGraphState, merge_objects] = Field( default_factory=InnerAgentGraphState ) + tool_call: Optional[ToolCall] = ( + None # This field is used to pass tool inputs to tool nodes. + ) class AgentGuardrailsGraphState(AgentGraphState): diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index aec540b4..e45e9969 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -4,7 +4,6 @@ from inspect import signature from typing import Any, Awaitable, Callable, Literal -from langchain_core.messages.ai import AIMessage from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.runnables.config import RunnableConfig from langchain_core.tools import BaseTool @@ -12,6 +11,8 @@ from langgraph.types import Command from pydantic import BaseModel +from ..react.types import AgentGraphState + # the type safety can be improved with generics ToolWrapperType = Callable[ [BaseTool, ToolCall, Any], dict[str, Any] | Command[Any] | None @@ -49,7 +50,9 @@ def __init__( self.wrapper = wrapper self.awrapper = awrapper - def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType: + def _func( + self, state: AgentGraphState, config: RunnableConfig | None = None + ) -> OutputType: call = self._extract_tool_call(state) if call is None: return None @@ -61,7 +64,7 @@ def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType: return self._process_result(call, result) async def _afunc( - self, state: Any, config: RunnableConfig | None = None + self, state: AgentGraphState, config: RunnableConfig | None = None ) -> OutputType: call = self._extract_tool_call(state) if call is None: @@ -73,20 +76,10 @@ async def _afunc( result = await self.tool.ainvoke(call["args"]) return self._process_result(call, result) - def _extract_tool_call(self, state: Any) -> ToolCall | None: + def _extract_tool_call(self, state: AgentGraphState) -> ToolCall | None: """Extract the tool call from the state messages.""" - if not hasattr(state, "messages"): - raise ValueError("State does not have messages key") - - last_message = state.messages[-1] - if not isinstance(last_message, AIMessage): - raise ValueError("Last message in message stack is not an AIMessage.") - - for tool_call in last_message.tool_calls: - if tool_call["name"] == self.tool.name: - return tool_call - return None + return state.tool_call def _process_result( self, call: ToolCall, result: dict[str, Any] | Command[Any] | None @@ -101,7 +94,7 @@ def _process_result( return {"messages": [message]} def _filter_state( - self, state: Any, wrapper: ToolWrapperType | AsyncToolWrapperType + self, state: AgentGraphState, wrapper: ToolWrapperType | AsyncToolWrapperType ) -> BaseModel: """Filter the state to the expected model type.""" model_type = list(signature(wrapper).parameters.values())[2].annotation diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py index 3d5633ed..362505ef 100644 --- a/tests/agent/tools/test_tool_node.py +++ b/tests/agent/tools/test_tool_node.py @@ -1,6 +1,6 @@ """Tests for tool_node.py module.""" -from typing import Any, Dict +from typing import Any, Dict, Optional import pytest from langchain_core.messages import AIMessage, HumanMessage @@ -55,6 +55,7 @@ class MockState(BaseModel): messages: list[Any] = [] user_id: str = "test_user" session_id: str = "test_session" + tool_call: Optional[ToolCall] = None def mock_wrapper( @@ -101,7 +102,7 @@ def mock_state(self): "id": "test_call_id", } ai_message = AIMessage(content="Using tool", tool_calls=[tool_call]) - return MockState(messages=[ai_message]) + return MockState(messages=[ai_message], tool_call=tool_call) @pytest.fixture def empty_state(self): @@ -200,26 +201,6 @@ def test_no_tool_calls_returns_none(self, mock_tool, empty_state): assert result is None - def test_non_ai_message_raises_error(self, mock_tool, non_ai_state): - """Test that non-AI messages raise ValueError.""" - node = UiPathToolNode(mock_tool) - - with pytest.raises( - ValueError, match="Last message in message stack is not an AIMessage" - ): - node._func(non_ai_state) - - def test_mismatched_tool_name_returns_none(self, mock_tool, mock_state): - """Test that mismatched tool names return None.""" - # Change the tool call name to something different - mock_state.messages[-1].tool_calls[0]["name"] = "different_tool" - - node = UiPathToolNode(mock_tool) - - result = node._func(mock_state) - - assert result is None - def test_state_filtering(self, mock_tool, mock_state): """Test that state is properly filtered for wrapper functions.""" node = UiPathToolNode(mock_tool, wrapper=mock_wrapper) diff --git a/uv.lock b/uv.lock index eda76354..812eed82 100644 --- a/uv.lock +++ b/uv.lock @@ -3282,7 +3282,7 @@ wheels = [ [[package]] name = "uipath-langchain" -version = "0.3.2" +version = "0.3.3" source = { editable = "." } dependencies = [ { name = "aiosqlite" },