diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 804f90a1d..a3e0d1bc8 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -185,6 +185,9 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: current_tool_use["toolUseId"] = tool_use_data["toolUseId"] current_tool_use["name"] = tool_use_data["name"] current_tool_use["input"] = "" + # Preserve type field for server-side tools (e.g., "server_tool_use" for nova_grounding) + if "type" in tool_use_data: + current_tool_use["type"] = tool_use_data["type"] return current_tool_use @@ -280,11 +283,15 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: tool_use_id = current_tool_use["toolUseId"] tool_use_name = current_tool_use["name"] - tool_use = ToolUse( - toolUseId=tool_use_id, - name=tool_use_name, - input=current_tool_use["input"], - ) + tool_use: ToolUse = { + "toolUseId": tool_use_id, + "name": tool_use_name, + "input": current_tool_use["input"], + } + # Preserve type field for server-side tools (e.g., "server_tool_use" for nova_grounding) + if "type" in current_tool_use: + tool_use["type"] = current_tool_use["type"] + content.append({"toolUse": tool_use}) state["current_tool_use"] = {} diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 08d8f400c..679f289f7 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -52,6 +52,18 @@ DEFAULT_READ_TIMEOUT = 120 +# Keys in additional_args that may conflict with built-in request construction +_CONFLICTING_ADDITIONAL_ARGS_KEYS = frozenset( + { + "toolConfig", + "inferenceConfig", + "guardrailConfig", + "system", + "messages", + "modelId", + } +) + class BedrockModel(Model): """AWS Bedrock model provider implementation. @@ -88,6 +100,9 @@ class BedrockConfig(TypedDict, total=False): True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto". stop_sequences: List of sequences that will stop generation when encountered streaming: Flag to enable/disable streaming. Defaults to True. + system_tools: List of Bedrock system tool definitions (e.g., nova_grounding). + These are server-side tools merged with agent tools in toolConfig. + Example: [{"systemTool": {"name": "nova_grounding"}}] temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) """ @@ -110,6 +125,7 @@ class BedrockConfig(TypedDict, total=False): include_tool_result_status: Optional[Literal["auto"] | bool] stop_sequences: Optional[list[str]] streaming: Optional[bool] + system_tools: Optional[list[dict[str, Any]]] temperature: Optional[float] top_p: Optional[float] @@ -187,6 +203,21 @@ def get_config(self) -> BedrockConfig: """ return self.config + def _warn_on_conflicting_additional_args(self) -> None: + """Warn if additional_args contains keys that conflict with built-in parameters.""" + additional_args = self.config.get("additional_args") + if not additional_args: + return + + for key in _CONFLICTING_ADDITIONAL_ARGS_KEYS: + if key in additional_args: + warnings.warn( + f"additional_args contains '{key}' which may conflict with built-in request parameters. " + f"Values in additional_args are merged last and may overwrite built-in values.", + UserWarning, + stacklevel=3, + ) + def _format_request( self, messages: Messages, @@ -206,6 +237,8 @@ def _format_request( Returns: A Bedrock converse stream request. """ + self._warn_on_conflicting_additional_args() + if not tool_specs: has_tool_content = any( any("toolUse" in block or "toolResult" in block for block in msg.get("content", [])) for msg in messages @@ -238,18 +271,19 @@ def _format_request( "inputSchema": tool_spec["inputSchema"], } } - for tool_spec in tool_specs + for tool_spec in (tool_specs or []) ], *( [{"cachePoint": {"type": self.config["cache_tools"]}}] if self.config.get("cache_tools") else [] ), + *(self.config.get("system_tools") or []), ], **({"toolChoice": tool_choice if tool_choice else {"auto": {}}}), } } - if tool_specs + if tool_specs or self.config.get("system_tools") else {} ), **( @@ -672,8 +706,12 @@ def _stream( logger.debug("got response from model") if streaming: response = self.client.converse_stream(**request) - # Track tool use events to fix stopReason for streaming responses - has_tool_use = False + # Track tool use/result events to fix stopReason for streaming responses + # We need to distinguish server-side tools (already executed) from client-side tools + tool_use_info: dict[str, str] = {} # toolUseId -> type (e.g., "server_tool_use") + tool_result_ids: set[str] = set() # IDs of tools with results + has_client_tools = False + for chunk in response["stream"]: if ( "metadata" in chunk @@ -685,22 +723,40 @@ def _stream( for event in self._generate_redaction_events(): callback(event) - # Track if we see tool use events - if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"): - has_tool_use = True + # Track tool use events with their types + if "contentBlockStart" in chunk: + tool_use_start = chunk["contentBlockStart"].get("start", {}).get("toolUse") + if tool_use_start: + tool_use_id = tool_use_start.get("toolUseId", "") + tool_type = tool_use_start.get("type", "") + tool_use_info[tool_use_id] = tool_type + # Check if it's a client-side tool (not server_tool_use) + if tool_type != "server_tool_use": + has_client_tools = True + + # Track tool result events (for server-side tools that were already executed) + tool_result_start = chunk["contentBlockStart"].get("start", {}).get("toolResult") + if tool_result_start: + tool_result_ids.add(tool_result_start.get("toolUseId", "")) # Fix stopReason for streaming responses that contain tool use + # BUT: Only override if there are client-side tools without results if ( - has_tool_use - and "messageStop" in chunk + "messageStop" in chunk and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn" ): - # Create corrected chunk with tool_use stopReason - modified_chunk = chunk.copy() - modified_chunk["messageStop"] = message_stop.copy() - modified_chunk["messageStop"]["stopReason"] = "tool_use" - logger.warning("Override stop reason from end_turn to tool_use") - callback(modified_chunk) + # Check if we have client-side tools that need execution + needs_execution = has_client_tools and not set(tool_use_info.keys()).issubset(tool_result_ids) + + if needs_execution: + # Create corrected chunk with tool_use stopReason + modified_chunk = chunk.copy() + modified_chunk["messageStop"] = message_stop.copy() + modified_chunk["messageStop"]["stopReason"] = "tool_use" + logger.warning("Override stop reason from end_turn to tool_use") + callback(modified_chunk) + else: + callback(chunk) else: callback(chunk) @@ -762,6 +818,43 @@ def _stream( callback() logger.debug("finished streaming response from model") + def _has_client_side_tools_to_execute(self, message_content: list[dict[str, Any]]) -> bool: + """Check if message contains client-side tools that need execution. + + Server-side tools (like nova_grounding) are executed by Bedrock and include + toolResult blocks in the response. We should NOT override stopReason to + "tool_use" for these tools. + + Args: + message_content: The content array from Bedrock response. + + Returns: + True if there are client-side tools without results, False otherwise. + """ + tool_use_ids = set() + tool_result_ids = set() + has_client_tools = False + + for content in message_content: + if "toolUse" in content: + tool_use = content["toolUse"] + tool_use_ids.add(tool_use["toolUseId"]) + + # Check if it's a server-side tool (Bedrock executes these) + if tool_use.get("type") != "server_tool_use": + has_client_tools = True + + elif "toolResult" in content: + # Track which tools already have results + tool_result_ids.add(content["toolResult"]["toolUseId"]) + + # Only return True if there are client-side tools without results + if not has_client_tools: + return False + + # Check if all tool uses have corresponding results + return not tool_use_ids.issubset(tool_result_ids) + def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: """Convert a non-streaming response to the streaming format. @@ -842,10 +935,12 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera # Yield messageStop event # Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side + # BUT: Don't override for server-side tools (like nova_grounding) that are already executed current_stop_reason = response["stopReason"] if current_stop_reason == "end_turn": message_content = response["output"]["message"]["content"] - if any("toolUse" in content for content in message_content): + # Only override if there are client-side tools that need execution + if self._has_client_side_tools_to_execute(message_content): current_stop_reason = "tool_use" logger.warning("Override stop reason from end_turn to tool_use") diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 5d01c5d48..769b5fac5 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -180,24 +180,38 @@ async def _stream( invocation_state = before_event.invocation_state if not selected_tool: - if tool_func == selected_tool: + # Server-side tools (e.g., nova_grounding) are executed by the model provider + # and their results are already included in the response. We provide a + # placeholder result to satisfy the tool result requirement. + if tool_use.get("type") == "server_tool_use": + logger.debug("tool_name=<%s> | server-side tool executed by model provider", tool_name) + result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "success", + "content": [{"text": f"Server-side tool '{tool_name}' executed by model provider"}], + } + elif tool_func == selected_tool: logger.error( "tool_name=<%s>, available_tools=<%s> | tool not found in registry", tool_name, list(agent.tool_registry.registry.keys()), ) + result = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Unknown tool: {tool_name}"}], + } else: logger.debug( "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", tool_name, str(tool_use.get("toolUseId")), ) - - result: ToolResult = { - "toolUseId": str(tool_use.get("toolUseId")), - "status": "error", - "content": [{"text": f"Unknown tool: {tool_name}"}], - } + result = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"Unknown tool: {tool_name}"}], + } after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( agent, selected_tool, tool_use, invocation_state, result diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 4d0bbe412..3026e4dd2 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -8,7 +8,7 @@ from typing import Dict, List, Literal, Optional -from typing_extensions import TypedDict +from typing_extensions import NotRequired, TypedDict from .citations import CitationsContentBlock from .media import DocumentContent, ImageContent, VideoContent @@ -129,10 +129,12 @@ class ContentBlockStartToolUse(TypedDict): Attributes: name: The name of the tool that the model is requesting to use. toolUseId: The ID for the tool request. + type: Optional type identifier (e.g., "server_tool_use" for server-side tools). """ name: str toolUseId: str + type: NotRequired[str] class ContentBlockStart(TypedDict, total=False): diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 8f4dba6b1..9aa6313f2 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -57,11 +57,13 @@ class ToolUse(TypedDict): Can be any JSON-serializable type. name: The name of the tool to invoke. toolUseId: A unique identifier for this specific tool use request. + type: Optional type identifier for the tool use (e.g., "server_tool_use" for server-side tools). """ input: Any name: str toolUseId: str + type: NotRequired[str] class ToolResultContent(TypedDict, total=False): diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index c6e44b78a..ba3a7172e 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -124,6 +124,11 @@ def test_handle_message_start(): {"start": {"toolUse": {"toolUseId": "test", "name": "test"}}}, {"toolUseId": "test", "name": "test", "input": ""}, ), + # Server-side tool with type field (e.g., nova_grounding) + ( + {"start": {"toolUse": {"toolUseId": "server-1", "name": "nova_grounding", "type": "server_tool_use"}}}, + {"toolUseId": "server-1", "name": "nova_grounding", "input": "", "type": "server_tool_use"}, + ), ], ) def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use): @@ -328,6 +333,39 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s "redactedContent": b"", }, ), + # Server-side tool with type field (e.g., nova_grounding) + ( + { + "content": [], + "current_tool_use": { + "toolUseId": "server-1", + "name": "nova_grounding", + "input": "{}", + "type": "server_tool_use", + }, + "text": "", + "reasoningText": "", + "citationsContent": [], + "redactedContent": b"", + }, + { + "content": [ + { + "toolUse": { + "toolUseId": "server-1", + "name": "nova_grounding", + "input": {}, + "type": "server_tool_use", + } + } + ], + "current_tool_use": {}, + "text": "", + "reasoningText": "", + "citationsContent": [], + "redactedContent": b"", + }, + ), # Text ( { diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 5ec5a7072..9c44ce266 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -543,6 +543,143 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type): assert tru_request == exp_request +def test_format_request_system_tools_only(model, messages, model_id): + """System tools work without agent tools.""" + system_tools = [{"systemTool": {"name": "nova_grounding"}}] + model.update_config(system_tools=system_tools) + + tru_request = model._format_request(messages, tool_specs=None) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [{"systemTool": {"name": "nova_grounding"}}], + "toolChoice": {"auto": {}}, + }, + } + + assert tru_request == exp_request + + +def test_format_request_system_tools_with_agent_tools(model, messages, model_id, tool_spec): + """System tools are merged after agent tools.""" + system_tools = [{"systemTool": {"name": "nova_grounding"}}] + model.update_config(system_tools=system_tools) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [ + {"toolSpec": tool_spec}, + {"systemTool": {"name": "nova_grounding"}}, + ], + "toolChoice": {"auto": {}}, + }, + } + + assert tru_request == exp_request + + +def test_format_request_system_tools_with_cache(model, messages, model_id, tool_spec, cache_type): + """System tools appear after cache_tools.""" + system_tools = [{"systemTool": {"name": "nova_grounding"}}] + model.update_config(system_tools=system_tools, cache_tools=cache_type) + + tru_request = model._format_request(messages, tool_specs=[tool_spec]) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [ + {"toolSpec": tool_spec}, + {"cachePoint": {"type": cache_type}}, + {"systemTool": {"name": "nova_grounding"}}, + ], + "toolChoice": {"auto": {}}, + }, + } + + assert tru_request == exp_request + + +def test_format_request_empty_system_tools(model, messages, model_id): + """Empty system_tools list doesn't add toolConfig.""" + model.update_config(system_tools=[]) + + tru_request = model._format_request(messages, tool_specs=None) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + } + + assert tru_request == exp_request + + +def test_format_request_multiple_system_tools(model, messages, model_id): + """Multiple system tools are all included.""" + system_tools = [ + {"systemTool": {"name": "nova_grounding"}}, + {"systemTool": {"name": "another_tool"}}, + ] + model.update_config(system_tools=system_tools) + + tru_request = model._format_request(messages, tool_specs=None) + exp_request = { + "inferenceConfig": {}, + "modelId": model_id, + "messages": messages, + "system": [], + "toolConfig": { + "tools": [ + {"systemTool": {"name": "nova_grounding"}}, + {"systemTool": {"name": "another_tool"}}, + ], + "toolChoice": {"auto": {}}, + }, + } + + assert tru_request == exp_request + + +def test_format_request_warns_on_conflicting_toolconfig(model, messages, model_id): + """Warning is emitted when additional_args contains toolConfig.""" + model.update_config(additional_args={"toolConfig": {"tools": []}}) + + with pytest.warns(UserWarning, match="additional_args contains 'toolConfig'"): + model._format_request(messages, tool_specs=None) + + +def test_format_request_warns_on_conflicting_inferenceconfig(model, messages, model_id): + """Warning is emitted when additional_args contains inferenceConfig.""" + model.update_config(additional_args={"inferenceConfig": {"maxTokens": 100}}) + + with pytest.warns(UserWarning, match="additional_args contains 'inferenceConfig'"): + model._format_request(messages, tool_specs=None) + + +def test_format_request_no_warning_for_safe_additional_args(model, messages, model_id): + """No warning for non-conflicting additional_args keys.""" + model.update_config(additional_args={"customField": "value"}) + + # Should not emit any warnings + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + model._format_request(messages, tool_specs=None) + assert len(w) == 0 + + @pytest.mark.asyncio async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist): error_message = "Rate exceeded" @@ -2143,3 +2280,171 @@ async def test_citations_content_preserves_tagged_union_structure(bedrock_client "(documentChar, documentPage, documentChunk, searchResultLocation, or web) " "with the location fields nested inside." ) + + + +def test_has_client_side_tools_to_execute_with_client_tools(model): + """Test that client-side tools are correctly identified as needing execution.""" + message_content = [ + { + "toolUse": { + "toolUseId": "tool-123", + "name": "my_tool", + "input": {"param": "value"}, + } + } + ] + + assert model._has_client_side_tools_to_execute(message_content) is True + + +def test_has_client_side_tools_to_execute_with_server_tools(model): + """Test that server-side tools (like nova_grounding) are NOT identified as needing execution.""" + message_content = [ + { + "toolUse": { + "toolUseId": "tool-123", + "name": "nova_grounding", + "type": "server_tool_use", + "input": {}, + } + }, + { + "toolResult": { + "toolUseId": "tool-123", + "content": [{"text": "Grounding result"}], + } + }, + ] + + assert model._has_client_side_tools_to_execute(message_content) is False + + +def test_has_client_side_tools_to_execute_with_mixed_tools(model): + """Test mixed server and client tools - should return True if client tools need execution.""" + message_content = [ + # Server-side tool with result + { + "toolUse": { + "toolUseId": "server-tool-123", + "name": "nova_grounding", + "type": "server_tool_use", + "input": {}, + } + }, + { + "toolResult": { + "toolUseId": "server-tool-123", + "content": [{"text": "Grounding result"}], + } + }, + # Client-side tool without result + { + "toolUse": { + "toolUseId": "client-tool-456", + "name": "my_tool", + "input": {"param": "value"}, + } + }, + ] + + assert model._has_client_side_tools_to_execute(message_content) is True + + +def test_has_client_side_tools_to_execute_with_no_tools(model): + """Test that no tools returns False.""" + message_content = [{"text": "Just some text"}] + + assert model._has_client_side_tools_to_execute(message_content) is False + + +@pytest.mark.asyncio +async def test_stream_server_tool_use_does_not_override_stop_reason(bedrock_client, alist, messages): + """Test that stopReason is NOT overridden for server-side tools like nova_grounding.""" + model = BedrockModel(model_id="amazon.nova-premier-v1:0") + model.client = bedrock_client + + # Simulate streaming response with server-side tool use and result + bedrock_client.converse_stream.return_value = { + "stream": [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": "tool-123", + "name": "nova_grounding", + "type": "server_tool_use", + } + } + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "{}"}}}}, + {"contentBlockStop": {}}, + { + "contentBlockStart": { + "start": { + "toolResult": { + "toolUseId": "tool-123", + } + } + } + }, + {"contentBlockDelta": {"delta": {"text": "Grounding result"}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "Final response"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + } + + events = await alist(model.stream(messages)) + + # Find the messageStop event + message_stop_event = next(e for e in events if "messageStop" in e) + + # Verify stopReason was NOT overridden (should remain end_turn for server-side tools) + assert message_stop_event["messageStop"]["stopReason"] == "end_turn" + + +@pytest.mark.asyncio +async def test_stream_non_streaming_server_tool_use_does_not_override_stop_reason(bedrock_client, alist, messages): + """Test that stopReason is NOT overridden for server-side tools in non-streaming mode.""" + model = BedrockModel(model_id="amazon.nova-premier-v1:0", streaming=False) + model.client = bedrock_client + + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-123", + "name": "nova_grounding", + "type": "server_tool_use", + "input": {}, + } + }, + { + "toolResult": { + "toolUseId": "tool-123", + "content": [{"text": "Grounding result"}], + } + }, + {"text": "Final response based on grounding"}, + ], + } + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 10, "outputTokens": 20}, + } + + events = await alist(model.stream(messages)) + + # Find the messageStop event + message_stop_event = next(e for e in events if "messageStop" in e) + + # Verify stopReason was NOT overridden (should remain end_turn for server-side tools) + assert message_stop_event["messageStop"]["stopReason"] == "end_turn" diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py index 8139fbf66..0d0b3ba8a 100644 --- a/tests/strands/tools/executors/test_executor.py +++ b/tests/strands/tools/executors/test_executor.py @@ -192,6 +192,44 @@ async def test_executor_stream_yields_unknown_tool(executor, agent, tool_results assert tru_hook_after_event == exp_hook_after_event +@pytest.mark.asyncio +async def test_executor_stream_handles_server_side_tools( + executor, agent, tool_results, invocation_state, hook_events, alist +): + """Test that server-side tools (type: server_tool_use) return success with placeholder result. + + Server-side tools like nova_grounding are executed by the model provider (e.g., Bedrock). + The executor should return a success result (not error) to satisfy the tool result requirement. + """ + tool_use = {"name": "nova_grounding", "toolUseId": "server-1", "type": "server_tool_use", "input": {}} + stream = executor._stream(agent, tool_use, tool_results, invocation_state) + + tru_events = await alist(stream) + exp_events = [ + ToolResultEvent({ + "toolUseId": "server-1", + "status": "success", + "content": [{"text": "Server-side tool 'nova_grounding' executed by model provider"}], + }) + ] + assert tru_events == exp_events + + tru_results = tool_results + exp_results = [exp_events[-1].tool_result] + assert tru_results == exp_results + + # Hooks should still be invoked for server-side tools + tru_hook_after_event = hook_events[-1] + exp_hook_after_event = AfterToolCallEvent( + agent=agent, + selected_tool=None, + tool_use=tool_use, + invocation_state=invocation_state, + result=exp_results[0], + ) + assert tru_hook_after_event == exp_hook_after_event + + @pytest.mark.asyncio async def test_executor_stream_with_trace( executor, tracer, agent, tool_results, cycle_trace, cycle_span, invocation_state, alist