diff --git a/pyproject.toml b/pyproject.toml index b7839300..f826b132 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-langchain" -version = "0.3.2" +version = "0.3.3" description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" diff --git a/src/uipath_langchain/agent/guardrails/actions/filter_action.py b/src/uipath_langchain/agent/guardrails/actions/filter_action.py index fa11182d..8852d0d8 100644 --- a/src/uipath_langchain/agent/guardrails/actions/filter_action.py +++ b/src/uipath_langchain/agent/guardrails/actions/filter_action.py @@ -1,7 +1,7 @@ import re from typing import Any -from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.messages import ToolMessage from langgraph.types import Command from uipath.core.guardrails.guardrails import FieldReference, FieldSource from uipath.platform.guardrails import BaseGuardrail, GuardrailScope @@ -11,6 +11,7 @@ from ...exceptions import AgentTerminationException from ...react.types import AgentGuardrailsGraphState +from ...react.utils import extract_tool_call_from_state from .base_action import GuardrailAction, GuardrailActionNode @@ -143,66 +144,36 @@ def _filter_tool_input_fields( if not has_input_fields: return {} - msgs = state.messages.copy() - if not msgs: - return {} + tool_call_id = getattr(state, "tool_call_id", None) + tool_call, message = extract_tool_call_from_state( + state, tool_name, tool_call_id, return_message=True + ) - # Find the AIMessage with tool calls - # At PRE_EXECUTION, this is always the last message - ai_message = None - for i in range(len(msgs) - 1, -1, -1): - msg = msgs[i] - if isinstance(msg, AIMessage) and msg.tool_calls: - ai_message = msg - break + if tool_call is None: + return {} - if ai_message is None: + args = tool_call["args"] + if not args or not isinstance(args, dict): return {} - # Find and filter the tool call with matching name - # Type assertion: we know ai_message is AIMessage from the check above - assert isinstance(ai_message, AIMessage) - tool_calls = list(ai_message.tool_calls) + # Filter out the specified input fields + filtered_args = args.copy() modified = False - for tool_call in tool_calls: - call_name = ( - tool_call.get("name") - if isinstance(tool_call, dict) - else getattr(tool_call, "name", None) - ) - - if call_name == tool_name: - # Get the current args - args = ( - tool_call.get("args") - if isinstance(tool_call, dict) - else getattr(tool_call, "args", None) - ) + for field_ref in fields_to_filter: + # Only filter input fields + if field_ref.source == FieldSource.INPUT and field_ref.path in filtered_args: + del filtered_args[field_ref.path] + modified = True - if args and isinstance(args, dict): - # Filter out the specified input fields - filtered_args = args.copy() - for field_ref in fields_to_filter: - # Only filter input fields - if ( - field_ref.source == FieldSource.INPUT - and field_ref.path in filtered_args - ): - del filtered_args[field_ref.path] - modified = True - - # Update the tool call with filtered args - if isinstance(tool_call, dict): - tool_call["args"] = filtered_args - else: - tool_call.args = filtered_args - - break + if modified: + tool_call["args"] = filtered_args + message.tool_calls = [ + tool_call if tool_call["id"] == tc["id"] else tc + for tc in message.tool_calls + ] - if modified: - ai_message.tool_calls = tool_calls - return Command(update={"messages": msgs}) + return Command(update={"messages": [message]}) return {} diff --git a/src/uipath_langchain/agent/react/agent.py b/src/uipath_langchain/agent/react/agent.py index ad887d16..702fd56a 100644 --- a/src/uipath_langchain/agent/react/agent.py +++ b/src/uipath_langchain/agent/react/agent.py @@ -10,6 +10,9 @@ from uipath.platform.guardrails import BaseGuardrail from ..guardrails.actions import GuardrailAction +from .aggregator_node import ( + create_aggregator_node, +) from .guardrails.guardrails_subgraph import ( create_agent_init_guardrails_subgraph, create_agent_terminate_guardrails_subgraph, @@ -107,6 +110,10 @@ def create_agent( ) builder.add_node(AgentGraphNode.TERMINATE, terminate_with_guardrails_subgraph) + # Add aggregator node + aggregator_node = create_aggregator_node() + builder.add_node(AgentGraphNode.AGGREGATOR, aggregator_node) + builder.add_edge(START, AgentGraphNode.INIT) llm_node = create_llm_node(model, llm_tools, config.thinking_messages_limit) @@ -125,7 +132,10 @@ def create_agent( ) for tool_name in tool_node_names: - builder.add_edge(tool_name, AgentGraphNode.AGENT) + builder.add_edge(tool_name, AgentGraphNode.AGGREGATOR) + + # Aggregator goes back to agent + builder.add_edge(AgentGraphNode.AGGREGATOR, AgentGraphNode.AGENT) builder.add_edge(AgentGraphNode.TERMINATE, END) diff --git a/src/uipath_langchain/agent/react/aggregator_node.py b/src/uipath_langchain/agent/react/aggregator_node.py new file mode 100644 index 00000000..05fe003c --- /dev/null +++ b/src/uipath_langchain/agent/react/aggregator_node.py @@ -0,0 +1,125 @@ +"""Aggregator node for merging substates back into main state.""" + +from typing import Any + +from langchain_core.messages import AIMessage, AnyMessage +from langgraph.types import Overwrite + +from uipath_langchain.agent.react.types import AgentGraphState, InnerAgentGraphState + + +def _aggregate_messages( + original_messages: list[AnyMessage], substate_messages: dict[str, list[AnyMessage]] +) -> list[AnyMessage]: + aggregated_by_id: dict[str, AnyMessage] = {} + original_order: list[str] = [] + + for msg in original_messages: + aggregated_by_id[msg.id] = msg + original_order.append(msg.id) + + new_messages: list[AnyMessage] = [] + + for tool_call_id, substate_msgs in substate_messages.items(): + for msg in substate_msgs: + if msg.id in aggregated_by_id: + # existing message + original_msg = aggregated_by_id[msg.id] + if ( + isinstance(msg, AIMessage) + and msg.tool_calls + and len(msg.tool_calls) > 0 + ): + updated_tool_call = next( + (tc for tc in msg.tool_calls if tc["id"] == tool_call_id), None + ) + if updated_tool_call: + # update the specific tool call in the original message + new_tool_calls = [ + updated_tool_call if tc["id"] == tool_call_id else tc + for tc in original_msg.tool_calls + ] + aggregated_by_id[msg.id].tool_calls = new_tool_calls + else: + # new message, add it + new_messages.append(msg) + + result = [] + for msg_id in original_order: + result.append(aggregated_by_id[msg_id]) + result.extend(new_messages) + + return result + + +def create_aggregator_node() -> callable: + """Create an aggregator node that merges substates back into main state.""" + + def aggregator_node(state: AgentGraphState) -> dict[str, Any] | Overwrite: + """ + Aggregate substates back into main state. + + If substates is empty, no-op and continue. + If substates is non-empty: + - for messages, leave placeholder for message aggregation logic + - for each field in inner state, get its reducer and apply updates + - lastly, overwrite the state and clear substates + """ + if not state.substates: + return {} + + # message aggregation + substate_messages = {} + for tool_call_id, substate in state.substates.items(): + if "messages" in substate: + substate_messages[tool_call_id] = substate["messages"] + + aggregated_messages = _aggregate_messages(state.messages, substate_messages) + + # inner state fields aggregation + aggregated_inner_dict = state.inner_state.model_dump() + + inner_state_fields = InnerAgentGraphState.model_fields + for substate in state.substates.values(): + if "inner_state" in substate: + substate_inner_data = substate["inner_state"] + + if isinstance(substate_inner_data, InnerAgentGraphState): + substate_inner_dict = substate_inner_data.model_dump() + else: + substate_inner_dict = substate_inner_data + + # for each field, apply reducer if defined + for field_name, field_info in inner_state_fields.items(): + if field_name in substate_inner_dict: + substate_field_value = substate_inner_dict[field_name] + current_field_value = aggregated_inner_dict[field_name] + + if field_info.metadata and callable(field_info.metadata[-1]): + reducer_func = field_info.metadata[-1] + merged_value = reducer_func( + current_field_value, substate_field_value + ) + else: + # no reducer, just replace + merged_value = substate_field_value + + aggregated_inner_dict[field_name] = merged_value + + aggregated_inner_state = InnerAgentGraphState.model_validate( + aggregated_inner_dict + ) + + state.messages = aggregated_messages + state.inner_state = aggregated_inner_state + state.substates = {} + + # return overwrite command to replace the state + return { + **state.model_dump(exclude={"messages", "inner_state", "substates"}), + "messages": Overwrite(aggregated_messages), + "inner_state": Overwrite(aggregated_inner_state), + "substates": Overwrite({}), + } + + return aggregator_node diff --git a/src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py b/src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py index d7d75a05..ca3661c2 100644 --- a/src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py +++ b/src/uipath_langchain/agent/react/guardrails/guardrails_subgraph.py @@ -23,8 +23,10 @@ ) from uipath_langchain.agent.guardrails.types import ExecutionStage from uipath_langchain.agent.react.types import ( + AgentGraphNode, AgentGraphState, AgentGuardrailsGraphState, + SubgraphOutputModel, ) _VALIDATOR_ALLOWED_STAGES = { @@ -33,6 +35,21 @@ } +def _tool_call_state_handler(state: AgentGuardrailsGraphState) -> dict[str, Any]: + """Handle tool call state by moving contents to substates if tool_call is present.""" + if state.tool_call_id is not None: + # Move current state contents to substates under tool_call_id + return { + "substates": { + state.tool_call_id: { + "messages": state.messages, + "inner_state": state.inner_state, + } + } + } + return {} + + def _filter_guardrails_by_stage( guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None, stage: ExecutionStage, @@ -83,7 +100,7 @@ def _create_guardrails_subgraph( """ inner_name, inner_node = main_inner_node - subgraph = StateGraph(AgentGuardrailsGraphState) + subgraph = StateGraph(AgentGuardrailsGraphState, output_schema=SubgraphOutputModel) subgraph.add_node(inner_name, inner_node) @@ -105,6 +122,10 @@ def _create_guardrails_subgraph( else: subgraph.add_edge(START, inner_name) + # Always add the tool call state handler node at the end + tool_call_handler_name = AgentGraphNode.TOOL_CALL_STATE_HANDLER + subgraph.add_node(tool_call_handler_name, _tool_call_state_handler) + # Add post execution guardrail nodes if ExecutionStage.POST_EXECUTION in execution_stages: post_guardrails = _filter_guardrails_by_stage( @@ -116,12 +137,15 @@ def _create_guardrails_subgraph( scope, ExecutionStage.POST_EXECUTION, node_factory, - END, + tool_call_handler_name, inner_name, ) subgraph.add_edge(inner_name, first_post_exec_guardrail_node) else: - subgraph.add_edge(inner_name, END) + subgraph.add_edge(inner_name, tool_call_handler_name) + + # Always connect tool call handler to END + subgraph.add_edge(tool_call_handler_name, END) return subgraph.compile() diff --git a/src/uipath_langchain/agent/react/llm_node.py b/src/uipath_langchain/agent/react/llm_node.py index 57518d95..59a623c8 100644 --- a/src/uipath_langchain/agent/react/llm_node.py +++ b/src/uipath_langchain/agent/react/llm_node.py @@ -1,6 +1,6 @@ """LLM node for ReAct Agent graph.""" -from typing import Literal, Sequence +from typing import Any, Literal, Sequence from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, AnyMessage @@ -53,8 +53,15 @@ def create_llm_node( base_llm = model.bind_tools(bindable_tools) if bindable_tools else model tool_choice_required_value = _get_required_tool_choice_by_model(model) - async def llm_node(state: AgentGraphState): - messages: list[AnyMessage] = state.messages + async def llm_node(state: Any): + # we need to use Any here because LangGraph has weird edge behavior + # if the type annotation for the state in the edge function is Any/BaseModel/dict/etc aka not a specific model + # then LangGraph will pass the **same** state that was passed to the previous node + # meaning if we want the full state in the edge, we need to pass the full state here as well + # unfortunately, using AgentGraphState in the annotation and relying on extra="allow" does not work + # so we are doing the validation manually here + agent_state = AgentGraphState.model_validate(state, from_attributes=True) + messages: list[AnyMessage] = agent_state.messages consecutive_thinking_messages = count_consecutive_thinking_messages(messages) diff --git a/src/uipath_langchain/agent/react/reducers.py b/src/uipath_langchain/agent/react/reducers.py index cecf323c..2e48c20a 100644 --- a/src/uipath_langchain/agent/react/reducers.py +++ b/src/uipath_langchain/agent/react/reducers.py @@ -31,19 +31,19 @@ def add_job_attachments( def merge_objects(left: Any, right: Any) -> Any: - """Merge a Pydantic model with another model or dict, with right values taking precedence. + """Merge a Pydantic model or dict with another model or dict, with right values taking precedence. Applies field-specific reducers from annotation metadata when merging values. Args: - left: Existing Pydantic BaseModel instance + left: Existing Pydantic BaseModel instance or dict right: New Pydantic BaseModel instance or dict to merge Returns: - New Pydantic model instance with merged values + New Pydantic model instance with merged values (if left is BaseModel) or merged dict (if left is dict) Raises: - TypeError: If left is not a Pydantic BaseModel or right is not a BaseModel or dict + TypeError: If left or right are not Pydantic BaseModel or dict """ if not right: return left @@ -52,12 +52,22 @@ def merge_objects(left: Any, right: Any) -> Any: return right # validate input types - if not isinstance(left, BaseModel): - raise TypeError("Left object must be a Pydantic BaseModel") + if not isinstance(left, (BaseModel, dict)): + raise TypeError("Left object must be a Pydantic BaseModel or dict") if not isinstance(right, (BaseModel, dict)): raise TypeError("Right object must be a Pydantic BaseModel or dict") + # If left is a dict, perform simple dict merging + if isinstance(left, dict): + merged_values = left.copy() + if isinstance(right, BaseModel): + merged_values.update(right.model_dump()) + else: + merged_values.update(right) + return merged_values + + # If left is a BaseModel, use the original logic model_fields = type(left).model_fields merged_values = {} @@ -80,11 +90,19 @@ def merge_objects(left: Any, right: Any) -> Any: left_value = merged_values[field_name] # apply reducer if defined - if field_info.metadata and callable(field_info.metadata[0]): - reducer_func = field_info.metadata[0] + if field_info.metadata and callable(field_info.metadata[-1]): + reducer_func = field_info.metadata[-1] merged_values[field_name] = reducer_func(left_value, right_value) else: merged_values[field_name] = right_value # return new model instance with merged values return type(left)(**merged_values) + + +def replace_once(left: Any | None, right: Any | None) -> Any | None: + """Reducer to replace left value with right value if left is None.""" + if left is None: + return right + + return left diff --git a/src/uipath_langchain/agent/react/router.py b/src/uipath_langchain/agent/react/router.py index dfcbf57a..2d91b77e 100644 --- a/src/uipath_langchain/agent/react/router.py +++ b/src/uipath_langchain/agent/react/router.py @@ -1,12 +1,13 @@ """Routing functions for conditional edges in the agent graph.""" -from typing import Literal +from typing import Any, Literal from langchain_core.messages import AIMessage, AnyMessage, ToolCall +from langgraph.types import Send from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL from ..exceptions import AgentNodeRoutingException -from .types import AgentGraphNode, AgentGraphState +from .types import AgentGraphNode, AgentGraphState, UiPathToolNodeInput from .utils import count_consecutive_thinking_messages FLOW_CONTROL_TOOLS = [END_EXECUTION_TOOL.name, RAISE_ERROR_TOOL.name] @@ -58,8 +59,8 @@ def create_route_agent(thinking_messages_limit: int = 0): """ def route_agent( - state: AgentGraphState, - ) -> list[str] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]: + state: Any, + ) -> list[str | Send] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]: """Route after agent: handles all routing logic including control flow detection. Routing logic: @@ -76,7 +77,13 @@ def route_agent( Raises: AgentNodeRoutingException: When encountering unexpected state (empty messages, non-AIMessage, or excessive completions) """ - messages = state.messages + + # we cannot type hint state as CompleteAgentGraphState because it's defined at runtime + # and directly using AgentGraphState in the type hint, even if extra="allow" is set, does not work + agent_state: AgentGraphState = AgentGraphState.model_validate( + state, from_attributes=True + ) + messages: list[AnyMessage] = agent_state.messages last_message = __validate_last_message_is_AI(messages) tool_calls = list(last_message.tool_calls) if last_message.tool_calls else [] @@ -86,7 +93,15 @@ def route_agent( return AgentGraphNode.TERMINATE if tool_calls: - return [tc["name"] for tc in tool_calls] + return [ + Send( + tc["name"], + UiPathToolNodeInput( + **agent_state.model_dump(), tool_call_id=tc["id"] + ), + ) + for tc in tool_calls + ] consecutive_thinking_messages = count_consecutive_thinking_messages(messages) diff --git a/src/uipath_langchain/agent/react/types.py b/src/uipath_langchain/agent/react/types.py index ebd84258..79c9fb3e 100644 --- a/src/uipath_langchain/agent/react/types.py +++ b/src/uipath_langchain/agent/react/types.py @@ -3,10 +3,14 @@ from langchain_core.messages import AnyMessage from langgraph.graph.message import add_messages -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from uipath.platform.attachments import Attachment -from uipath_langchain.agent.react.reducers import add_job_attachments, merge_objects +from uipath_langchain.agent.react.reducers import ( + add_job_attachments, + merge_objects, + replace_once, +) class AgentTerminationSource(StrEnum): @@ -23,16 +27,24 @@ class AgentTermination(BaseModel): class InnerAgentGraphState(BaseModel): job_attachments: Annotated[dict[str, Attachment], add_job_attachments] = {} - termination: AgentTermination | None = None + termination: Annotated[AgentTermination | None, replace_once] = None class AgentGraphState(BaseModel): """Agent Graph state for standard loop execution.""" + substates: Annotated[dict[str, Any], merge_objects] = {} messages: Annotated[list[AnyMessage], add_messages] = [] inner_state: Annotated[InnerAgentGraphState, merge_objects] = Field( default_factory=InnerAgentGraphState ) + model_config = ConfigDict(extra="allow") + + +class SubgraphOutputModel(BaseModel): + """Subgraph output model.""" + + substates: dict[str, Any] class AgentGuardrailsGraphState(AgentGraphState): @@ -40,6 +52,7 @@ class AgentGuardrailsGraphState(AgentGraphState): guardrail_validation_result: Optional[str] = None agent_result: Optional[dict[str, Any]] = None + tool_call_id: Optional[str] = None class AgentGraphNode(StrEnum): @@ -50,6 +63,8 @@ class AgentGraphNode(StrEnum): TOOLS = "tools" TERMINATE = "terminate" GUARDED_TERMINATE = "guarded-terminate" + TOOL_CALL_STATE_HANDLER = "tool-call-state-handler" + AGGREGATOR = "aggregator" class AgentGraphConfig(BaseModel): @@ -61,3 +76,9 @@ class AgentGraphConfig(BaseModel): ge=0, description="Max consecutive thinking messages before enforcing tool usage. 0 = force tools every time.", ) + + +class UiPathToolNodeInput(AgentGraphState): + """Tool node input model.""" + + tool_call_id: str diff --git a/src/uipath_langchain/agent/react/utils.py b/src/uipath_langchain/agent/react/utils.py index 614f7014..8db3d2a9 100644 --- a/src/uipath_langchain/agent/react/utils.py +++ b/src/uipath_langchain/agent/react/utils.py @@ -3,10 +3,12 @@ from typing import Any, Sequence from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.messages.tool import ToolCall from pydantic import BaseModel from uipath.agent.react import END_EXECUTION_TOOL from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model +from uipath_langchain.agent.react.types import AgentGraphState def resolve_input_model( @@ -48,3 +50,52 @@ def count_consecutive_thinking_messages(messages: Sequence[BaseMessage]) -> int: count += 1 return count + + +def extract_tool_call_from_state( + state: AgentGraphState, + tool_name: str, + tool_call_id: str | None = None, + return_message: bool = False, +) -> ToolCall | None | tuple[ToolCall | None, AIMessage | None]: + """ + Extract tool call from state using consistent logic. + + Search order: + 1. If tool_call_id is provided, search for tool call with matching id and name + 2. Otherwise, find first tool call with matching name from the last AI message + + Args: + state: The agent graph state + tool_name: Name of the tool to find + tool_call_id: Optional tool call id to search for + return_message: If True, returns tuple of (tool_call, message) instead of just tool_call + + Returns: + The matching ToolCall if found, None otherwise. If return_message is True, + returns tuple of (ToolCall | None, AIMessage | None). + """ + if not state.messages: + return (None, None) if return_message else None + + # 1. If tool_call_id is provided, search for tool call with matching id and name + if tool_call_id is not None: + for message in reversed(state.messages): + if isinstance(message, AIMessage): + for tool_call in message.tool_calls: + if ( + tool_call["id"] == tool_call_id + and tool_call["name"] == tool_name + ): + return (tool_call, message) if return_message else tool_call + return (None, None) if return_message else None + + # 2. Find first tool call with matching name from the last AI message + for message in reversed(state.messages): + if isinstance(message, AIMessage): + for tool_call in message.tool_calls: + if tool_call["name"] == tool_name: + return (tool_call, message) if return_message else tool_call + break + + return (None, None) if return_message else None diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index aec540b4..2b494c03 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -4,14 +4,15 @@ from inspect import signature from typing import Any, Awaitable, Callable, Literal -from langchain_core.messages.ai import AIMessage from langchain_core.messages.tool import ToolCall, ToolMessage -from langchain_core.runnables.config import RunnableConfig from langchain_core.tools import BaseTool from langgraph._internal._runnable import RunnableCallable from langgraph.types import Command from pydantic import BaseModel +from uipath_langchain.agent.react.types import AgentGraphState +from uipath_langchain.agent.react.utils import extract_tool_call_from_state + # the type safety can be improved with generics ToolWrapperType = Callable[ [BaseTool, ToolCall, Any], dict[str, Any] | Command[Any] | None @@ -27,8 +28,8 @@ class UiPathToolNode(RunnableCallable): """ A ToolNode that can be used in a React agent graph. It extracts the tool call from the state messages and invokes the tool. + Alternatively, it can accept a UiPathToolNodeInput directly. It supports optional synchronous and asynchronous wrappers for custom processing. - Generic over the state model. Args: tool: The tool to invoke. wrapper: An optional synchronous wrapper for custom processing. @@ -49,8 +50,8 @@ def __init__( self.wrapper = wrapper self.awrapper = awrapper - def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType: - call = self._extract_tool_call(state) + def _func(self, input: AgentGraphState) -> OutputType: + call, state = self._extract_tool_call(input) if call is None: return None if self.wrapper: @@ -60,10 +61,8 @@ def _func(self, state: Any, config: RunnableConfig | None = None) -> OutputType: result = self.tool.invoke(call["args"]) return self._process_result(call, result) - async def _afunc( - self, state: Any, config: RunnableConfig | None = None - ) -> OutputType: - call = self._extract_tool_call(state) + async def _afunc(self, input: AgentGraphState) -> OutputType: + call, state = self._extract_tool_call(input) if call is None: return None if self.awrapper: @@ -73,20 +72,16 @@ async def _afunc( result = await self.tool.ainvoke(call["args"]) return self._process_result(call, result) - def _extract_tool_call(self, state: Any) -> ToolCall | None: - """Extract the tool call from the state messages.""" - - if not hasattr(state, "messages"): - raise ValueError("State does not have messages key") - - last_message = state.messages[-1] - if not isinstance(last_message, AIMessage): - raise ValueError("Last message in message stack is not an AIMessage.") - - for tool_call in last_message.tool_calls: - if tool_call["name"] == self.tool.name: - return tool_call - return None + def _extract_tool_call( + self, input: AgentGraphState + ) -> tuple[ToolCall | None, AgentGraphState]: + """ + Extract the tool call and agent state from the input. + Uses the shared utility function for consistent tool call extraction logic. + """ + tool_call_id = getattr(input, "tool_call_id", None) + tool_call = extract_tool_call_from_state(input, self.tool.name, tool_call_id) + return tool_call, input def _process_result( self, call: ToolCall, result: dict[str, Any] | Command[Any] | None @@ -101,7 +96,7 @@ def _process_result( return {"messages": [message]} def _filter_state( - self, state: Any, wrapper: ToolWrapperType | AsyncToolWrapperType + self, state: AgentGraphState, wrapper: ToolWrapperType | AsyncToolWrapperType ) -> BaseModel: """Filter the state to the expected model type.""" model_type = list(signature(wrapper).parameters.values())[2].annotation @@ -109,7 +104,7 @@ def _filter_state( raise ValueError( "Wrapper state parameter must be a pydantic BaseModel subclass." ) - return model_type.model_validate(state, from_attributes=True) + return model_type.model_validate(state, from_attributes=True, extra="allow") class ToolWrapperMixin: diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py index 3d5633ed..e3f68195 100644 --- a/tests/agent/tools/test_tool_node.py +++ b/tests/agent/tools/test_tool_node.py @@ -9,6 +9,7 @@ from langgraph.types import Command from pydantic import BaseModel +from uipath_langchain.agent.react.types import AgentGraphState, UiPathToolNodeInput from uipath_langchain.agent.tools.tool_node import ( ToolWrapperMixin, UiPathToolNode, @@ -317,3 +318,57 @@ def test_create_tool_node_empty_tools(self): result = create_tool_node([]) assert result == {} + + +class TestUiPathToolNodeInput: + """Test cases for UiPathToolNode with UiPathToolNodeInput.""" + + @pytest.fixture + def mock_tool(self): + """Fixture for mock tool.""" + return MockTool() + + @pytest.fixture + def tool_node_input(self): + """Fixture for UiPathToolNodeInput.""" + tool_call = { + "name": "mock_tool", + "args": {"input_text": "test input"}, + "id": "test_call_id", + } + agent_state = AgentGraphState( + messages=[], user_id="test_user", session_id="test_session" + ) + return UiPathToolNodeInput(tool_call=tool_call, agent_state=agent_state) + + def test_accepts_uipath_tool_node_input(self, mock_tool, tool_node_input): + """Test that UiPathToolNode accepts UiPathToolNodeInput.""" + node = UiPathToolNode(mock_tool) + + result = node._func(tool_node_input) + + assert result is not None + assert isinstance(result, dict) + assert "messages" in result + + def test_extracts_call_and_state_for_wrapper(self, mock_tool, tool_node_input): + """Test that call and state are correctly extracted and passed to wrapper.""" + node = UiPathToolNode(mock_tool, wrapper=mock_wrapper) + + result = node._func(tool_node_input) + + assert result is not None + assert isinstance(result, dict) + tool_message = result["messages"][0] + assert "user: test_user" in tool_message.content + + def test_validates_tool_call_name(self, mock_tool, tool_node_input): + """Test that tool call name is validated.""" + tool_node_input.tool_call["name"] = "different_tool" + node = UiPathToolNode(mock_tool) + + with pytest.raises( + ValueError, + match="Tool call name 'different_tool' does not match tool name 'mock_tool'", + ): + node._func(tool_node_input) diff --git a/uv.lock b/uv.lock index eda76354..812eed82 100644 --- a/uv.lock +++ b/uv.lock @@ -3282,7 +3282,7 @@ wheels = [ [[package]] name = "uipath-langchain" -version = "0.3.2" +version = "0.3.3" source = { editable = "." } dependencies = [ { name = "aiosqlite" },