From f161dd39d148cd16213ace188bb8d28100fee2b0 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Dec 2025 11:36:10 +0900 Subject: [PATCH 1/6] feat: Add on_stream to agents as tools --- examples/agent_patterns/README.md | 1 + .../agents_as_tools_streaming.py | 57 +++ src/agents/__init__.py | 2 + src/agents/agent.py | 77 +++- tests/test_agent_as_tool.py | 338 ++++++++++++++++++ 5 files changed, 460 insertions(+), 15 deletions(-) create mode 100644 examples/agent_patterns/agents_as_tools_streaming.py diff --git a/examples/agent_patterns/README.md b/examples/agent_patterns/README.md index 96b48920c..2bdadce0d 100644 --- a/examples/agent_patterns/README.md +++ b/examples/agent_patterns/README.md @@ -28,6 +28,7 @@ The mental model for handoffs is that the new agent "takes over". It sees the pr For example, you could model the translation task above as tool calls instead: rather than handing over to the language-specific agent, you could call the agent as a tool, and then use the result in the next step. This enables things like translating multiple languages at once. See the [`agents_as_tools.py`](./agents_as_tools.py) file for an example of this. +See the [`agents_as_tools_streaming.py`](./agents_as_tools_streaming.py) file for a streaming variant that taps into nested agent events via `on_stream`. ## LLM-as-a-judge diff --git a/examples/agent_patterns/agents_as_tools_streaming.py b/examples/agent_patterns/agents_as_tools_streaming.py new file mode 100644 index 000000000..846593c81 --- /dev/null +++ b/examples/agent_patterns/agents_as_tools_streaming.py @@ -0,0 +1,57 @@ +import asyncio + +from agents import Agent, AgentToolStreamEvent, ModelSettings, Runner, function_tool, trace + + +@function_tool( + name_override="billing_status_checker", + description_override="Answer questions about customer billing status.", +) +def billing_status_checker(customer_id: str | None = None, question: str = "") -> str: + """Return a canned billing answer or a fallback when the question is unrelated.""" + normalized = question.lower() + if "bill" in normalized or "billing" in normalized: + return f"This customer (ID: {customer_id})'s bill is $100" + return "I can only answer questions about billing." + + +def handle_stream(event: AgentToolStreamEvent) -> None: + """Print streaming events emitted by the nested billing agent.""" + stream = event["event"] + print(f"[stream] agent={event['agent_name']} type={stream.type} {stream}") + + +async def main() -> None: + with trace("Agents as tools streaming example"): + billing_agent = Agent( + name="Billing Agent", + instructions="You are a billing agent that answers billing questions.", + model_settings=ModelSettings(tool_choice="required"), + tools=[billing_status_checker], + ) + + billing_agent_tool = billing_agent.as_tool( + tool_name="billing_agent", + tool_description="You are a billing agent that answers billing questions.", + on_stream=handle_stream, + ) + + main_agent = Agent( + name="Customer Support Agent", + instructions=( + "You are a customer support agent. Always call the billing agent to answer billing " + "questions and return the billing agent response to the user." + ), + tools=[billing_agent_tool], + ) + + result = await Runner.run( + main_agent, + "Hello, my customer ID is ABC123. 
How much is my bill for this month?", + ) + + print(f"\nFinal response:\n{result.final_output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 6f4d0815d..00a5ca21e 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -8,6 +8,7 @@ from .agent import ( Agent, AgentBase, + AgentToolStreamEvent, StopAtTools, ToolsToFinalOutputFunction, ToolsToFinalOutputResult, @@ -214,6 +215,7 @@ def enable_verbose_stdout_logging(): __all__ = [ "Agent", "AgentBase", + "AgentToolStreamEvent", "StopAtTools", "ToolsToFinalOutputFunction", "ToolsToFinalOutputResult", diff --git a/src/agents/agent.py b/src/agents/agent.py index c479cc697..6d8206ab6 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -32,8 +32,9 @@ from .lifecycle import AgentHooks, RunHooks from .mcp import MCPServer from .memory.session import Session - from .result import RunResult + from .result import RunResult, RunResultStreaming from .run import RunConfig + from .stream_events import StreamEvent @dataclass @@ -58,6 +59,19 @@ class ToolsToFinalOutputResult: """ +class AgentToolStreamEvent(TypedDict): + """Streaming event emitted when an agent is invoked as a tool.""" + + event: StreamEvent + """The streaming event from the nested agent run.""" + + agent_name: str + """The name of the nested agent emitting the event.""" + + tool_call_id: str | None + """The originating tool call ID, if available.""" + + class StopAtTools(TypedDict): stop_at_tool_names: list[str] """A list of tool names, any of which will stop the agent from running further.""" @@ -382,9 +396,12 @@ def as_tool( self, tool_name: str | None, tool_description: str | None, - custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None, + custom_output_extractor: ( + Callable[[RunResult | RunResultStreaming], Awaitable[str]] | None + ) = None, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, + on_stream: Callable[[AgentToolStreamEvent], MaybeAwaitable[None]] | None = None, run_config: RunConfig | None = None, max_turns: int | None = None, hooks: RunHooks[TContext] | None = None, @@ -409,6 +426,8 @@ def as_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + on_stream: Optional callback (sync or async) to receive streaming events from the nested + agent run. When provided, the nested agent is executed in streaming mode. 
""" @function_tool( @@ -421,21 +440,49 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS - output = await Runner.run( - starting_agent=self, - input=input, - context=context.context, - run_config=run_config, - max_turns=resolved_max_turns, - hooks=hooks, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) + if on_stream is not None: + run_result = Runner.run_streamed( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + async for event in run_result.stream_events(): + payload: AgentToolStreamEvent = { + "event": event, + "agent_name": self.name, + "tool_call_id": getattr(context, "tool_call_id", None), + } + try: + maybe_result = on_stream(payload) + if inspect.isawaitable(maybe_result): + await maybe_result + except Exception: + logger.exception( + "Error while handling on_stream event for agent tool %s.", + self.name, + ) + else: + run_result = await Runner.run( + starting_agent=self, + input=input, + context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) if custom_output_extractor: - return await custom_output_extractor(output) + return await custom_output_extractor(run_result) - return output.final_output + return run_result.final_output return run_agent diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 51d8edf20..2f6b38a4d 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -18,6 +18,7 @@ Session, TResponseInputItem, ) +from agents.stream_events import RawResponsesStreamEvent from agents.tool_context import ToolContext @@ -373,3 +374,340 @@ async def extractor(result) -> str: output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}') assert output == "custom output" + + +@pytest.mark.asyncio +async def test_agent_as_tool_streams_events_with_on_stream( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="streamer") + stream_events = [ + RawResponsesStreamEvent(data={"type": "response_started"}), + RawResponsesStreamEvent(data={"type": "output_text_delta", "delta": "hi"}), + ] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "streamed output" + + async def stream_events(self): + for ev in stream_events: + yield ev + + run_calls: list[dict[str, Any]] = [] + + def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + run_calls.append( + { + "starting_agent": starting_agent, + "input": input, + "context": context, + "max_turns": max_turns, + "hooks": hooks, + "run_config": run_config, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + "session": session, + } + ) + return DummyStreamingResult() + + async def unexpected_run(*args: Any, **kwargs: Any) -> None: + raise AssertionError("Runner.run should not be called when on_stream is provided.") + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) + + received_events: 
list[dict[str, Any]] = [] + + async def on_stream(payload: dict[str, Any]) -> None: + received_events.append(payload) + + tool = agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + on_stream=on_stream, + ) + + tool_context = ToolContext( + context=None, + tool_name="stream_tool", + tool_call_id="call-123", + tool_arguments='{"input": "run streaming"}', + ) + output = await tool.on_invoke_tool(tool_context, '{"input": "run streaming"}') + + assert output == "streamed output" + assert len(received_events) == len(stream_events) + assert received_events[0]["agent_name"] == "streamer" + assert received_events[0]["tool_call_id"] == "call-123" + assert received_events[0]["event"] == stream_events[0] + assert run_calls[0]["input"] == "run streaming" + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_works_with_custom_extractor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="streamer") + stream_events = [RawResponsesStreamEvent(data={"type": "response_started"})] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "raw output" + + async def stream_events(self): + for ev in stream_events: + yield ev + + streamed_instance = DummyStreamingResult() + + def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + return streamed_instance + + async def unexpected_run(*args: Any, **kwargs: Any) -> None: + raise AssertionError("Runner.run should not be called when on_stream is provided.") + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) + + received: list[Any] = [] + + async def extractor(result) -> str: + received.append(result) + return "custom value" + + callbacks: list[Any] = [] + + async def on_stream(payload: dict[str, Any]) -> None: + callbacks.append(payload["event"]) + + tool = agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + custom_output_extractor=extractor, + on_stream=on_stream, + ) + + tool_context = ToolContext( + context=None, + tool_name="stream_tool", + tool_call_id="call-abc", + tool_arguments='{"input": "stream please"}', + ) + output = await tool.on_invoke_tool(tool_context, '{"input": "stream please"}') + + assert output == "custom value" + assert received == [streamed_instance] + assert callbacks == stream_events + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_accepts_sync_handler( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="sync_handler_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield RawResponsesStreamEvent(data={"type": "response_started"}) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + calls: list[str] = [] + + def sync_handler(event: dict[str, Any]) -> None: + calls.append(event["event"].type) + + tool = agent.as_tool( + tool_name="sync_tool", + tool_description="Uses sync handler", + on_stream=sync_handler, + ) + tool_context = ToolContext( + context=None, + tool_name="sync_tool", + tool_call_id="call-sync", + tool_arguments='{"input": "go"}', + ) + + output = await 
tool.on_invoke_tool(tool_context, '{"input": "go"}') + + assert output == "ok" + assert calls == ["raw_response_event"] + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="handler_error_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield RawResponsesStreamEvent(data={"type": "response_started"}) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + def bad_handler(event: dict[str, Any]) -> None: + raise RuntimeError("boom") + + tool = agent.as_tool( + tool_name="error_tool", + tool_description="Handler throws", + on_stream=bad_handler, + ) + tool_context = ToolContext( + context=None, + tool_name="error_tool", + tool_call_id="call-bad", + tool_arguments='{"input": "go"}', + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "go"}') + + assert output == "ok" + + +@pytest.mark.asyncio +async def test_agent_as_tool_without_stream_uses_run( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="nostream_agent") + + class DummyResult: + def __init__(self) -> None: + self.final_output = "plain" + + run_calls: list[dict[str, Any]] = [] + + async def fake_run( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + run_calls.append({"input": input}) + return DummyResult() + + monkeypatch.setattr(Runner, "run", classmethod(fake_run)) + monkeypatch.setattr( + Runner, + "run_streamed", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))), + ) + + tool = agent.as_tool( + tool_name="nostream_tool", + tool_description="No streaming path", + ) + tool_context = ToolContext( + context=None, + tool_name="nostream_tool", + tool_call_id="call-no", + tool_arguments='{"input": "plain"}', + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "plain"}') + + assert output == "plain" + assert run_calls == [{"input": "plain"}] + + +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_sets_tool_call_id_none_for_direct_invocation( + monkeypatch: pytest.MonkeyPatch, +) -> None: + agent = Agent(name="direct_invocation_agent") + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield RawResponsesStreamEvent(data={"type": "response_started"}) + + monkeypatch.setattr( + Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) + ) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + captured: list[dict[str, Any]] = [] + + async def on_stream(event: dict[str, Any]) -> None: + captured.append(event) + + tool = agent.as_tool( + tool_name="direct_stream_tool", + tool_description="Direct invocation", + on_stream=on_stream, + ) + tool_context = ToolContext( + context=None, + tool_name="direct_stream_tool", + tool_call_id=None, # Direct invoke path does not have a tool call ID. 
+ tool_arguments='{"input": "hi"}', + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}') + + assert output == "ok" + assert captured[0]["tool_call_id"] is None From 4ce48a14574bfdba2eced0d1bd78901c9dc73058 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 11 Dec 2025 11:46:01 +0900 Subject: [PATCH 2/6] fix mypy errors --- examples/financial_research_agent/manager.py | 4 +- src/agents/agent.py | 1 + tests/test_agent_as_tool.py | 102 +++++++++++-------- 3 files changed, 64 insertions(+), 43 deletions(-) diff --git a/examples/financial_research_agent/manager.py b/examples/financial_research_agent/manager.py index 58ec11bf2..6dfc631aa 100644 --- a/examples/financial_research_agent/manager.py +++ b/examples/financial_research_agent/manager.py @@ -6,7 +6,7 @@ from rich.console import Console -from agents import Runner, RunResult, custom_span, gen_trace_id, trace +from agents import Runner, RunResult, RunResultStreaming, custom_span, gen_trace_id, trace from .agents.financials_agent import financials_agent from .agents.planner_agent import FinancialSearchItem, FinancialSearchPlan, planner_agent @@ -17,7 +17,7 @@ from .printer import Printer -async def _summary_extractor(run_result: RunResult) -> str: +async def _summary_extractor(run_result: RunResult | RunResultStreaming) -> str: """Custom output extractor for sub‑agents that return an AnalysisSummary.""" # The financial/risk analyst agents emit an AnalysisSummary with a `summary` field. # We want the tool call to return just that summary text so the writer can drop it inline. diff --git a/src/agents/agent.py b/src/agents/agent.py index 6d8206ab6..d449fa3ae 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -439,6 +439,7 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: from .run import DEFAULT_MAX_TURNS, Runner resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS + run_result: RunResult | RunResultStreaming if on_stream is not None: run_result = Runner.run_streamed( diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 2f6b38a4d..ab5f57660 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast import pytest from openai.types.responses import ResponseOutputMessage, ResponseOutputText @@ -9,6 +9,7 @@ from agents import ( Agent, AgentBase, + AgentToolStreamEvent, FunctionTool, MessageOutputItem, RunConfig, @@ -382,8 +383,8 @@ async def test_agent_as_tool_streams_events_with_on_stream( ) -> None: agent = Agent(name="streamer") stream_events = [ - RawResponsesStreamEvent(data={"type": "response_started"}), - RawResponsesStreamEvent(data={"type": "output_text_delta", "delta": "hi"}), + RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})), + RawResponsesStreamEvent(data=cast(Any, {"type": "output_text_delta", "delta": "hi"})), ] class DummyStreamingResult: @@ -431,15 +432,18 @@ async def unexpected_run(*args: Any, **kwargs: Any) -> None: monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) monkeypatch.setattr(Runner, "run", classmethod(unexpected_run)) - received_events: list[dict[str, Any]] = [] + received_events: list[AgentToolStreamEvent] = [] - async def on_stream(payload: dict[str, Any]) -> None: + async def on_stream(payload: AgentToolStreamEvent) -> None: received_events.append(payload) - tool = agent.as_tool( - tool_name="stream_tool", - tool_description="Streams events", - 
on_stream=on_stream, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + on_stream=on_stream, + ), ) tool_context = ToolContext( @@ -463,7 +467,8 @@ async def test_agent_as_tool_streaming_works_with_custom_extractor( monkeypatch: pytest.MonkeyPatch, ) -> None: agent = Agent(name="streamer") - stream_events = [RawResponsesStreamEvent(data={"type": "response_started"})] + stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] + stream_events = [RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))] class DummyStreamingResult: def __init__(self) -> None: @@ -505,14 +510,17 @@ async def extractor(result) -> str: callbacks: list[Any] = [] - async def on_stream(payload: dict[str, Any]) -> None: + async def on_stream(payload: AgentToolStreamEvent) -> None: callbacks.append(payload["event"]) - tool = agent.as_tool( - tool_name="stream_tool", - tool_description="Streams events", - custom_output_extractor=extractor, - on_stream=on_stream, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + custom_output_extractor=extractor, + on_stream=on_stream, + ), ) tool_context = ToolContext( @@ -539,7 +547,7 @@ def __init__(self) -> None: self.final_output = "ok" async def stream_events(self): - yield RawResponsesStreamEvent(data={"type": "response_started"}) + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) monkeypatch.setattr( Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) @@ -552,13 +560,16 @@ async def stream_events(self): calls: list[str] = [] - def sync_handler(event: dict[str, Any]) -> None: + def sync_handler(event: AgentToolStreamEvent) -> None: calls.append(event["event"].type) - tool = agent.as_tool( - tool_name="sync_tool", - tool_description="Uses sync handler", - on_stream=sync_handler, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="sync_tool", + tool_description="Uses sync handler", + on_stream=sync_handler, + ), ) tool_context = ToolContext( context=None, @@ -584,7 +595,7 @@ def __init__(self) -> None: self.final_output = "ok" async def stream_events(self): - yield RawResponsesStreamEvent(data={"type": "response_started"}) + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) monkeypatch.setattr( Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) @@ -595,13 +606,16 @@ async def stream_events(self): classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), ) - def bad_handler(event: dict[str, Any]) -> None: + def bad_handler(event: AgentToolStreamEvent) -> None: raise RuntimeError("boom") - tool = agent.as_tool( - tool_name="error_tool", - tool_description="Handler throws", - on_stream=bad_handler, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="error_tool", + tool_description="Handler throws", + on_stream=bad_handler, + ), ) tool_context = ToolContext( context=None, @@ -651,9 +665,12 @@ async def fake_run( classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))), ) - tool = agent.as_tool( - tool_name="nostream_tool", - tool_description="No streaming path", + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="nostream_tool", + tool_description="No streaming path", + ), ) tool_context = ToolContext( context=None, @@ -669,7 +686,7 @@ async def fake_run( @pytest.mark.asyncio -async def 
test_agent_as_tool_streaming_sets_tool_call_id_none_for_direct_invocation( +async def test_agent_as_tool_streaming_sets_tool_call_id_from_context( monkeypatch: pytest.MonkeyPatch, ) -> None: agent = Agent(name="direct_invocation_agent") @@ -679,7 +696,7 @@ def __init__(self) -> None: self.final_output = "ok" async def stream_events(self): - yield RawResponsesStreamEvent(data={"type": "response_started"}) + yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) monkeypatch.setattr( Runner, "run_streamed", classmethod(lambda *args, **kwargs: DummyStreamingResult()) @@ -690,24 +707,27 @@ async def stream_events(self): classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), ) - captured: list[dict[str, Any]] = [] + captured: list[AgentToolStreamEvent] = [] - async def on_stream(event: dict[str, Any]) -> None: + async def on_stream(event: AgentToolStreamEvent) -> None: captured.append(event) - tool = agent.as_tool( - tool_name="direct_stream_tool", - tool_description="Direct invocation", - on_stream=on_stream, + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="direct_stream_tool", + tool_description="Direct invocation", + on_stream=on_stream, + ), ) tool_context = ToolContext( context=None, tool_name="direct_stream_tool", - tool_call_id=None, # Direct invoke path does not have a tool call ID. + tool_call_id="direct-call-id", tool_arguments='{"input": "hi"}', ) output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}') assert output == "ok" - assert captured[0]["tool_call_id"] is None + assert captured[0]["tool_call_id"] == "direct-call-id" From 7f1672ade8dd00e4d71a542361ff92dec8f022d7 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Mon, 15 Dec 2025 16:01:47 +0900 Subject: [PATCH 3/6] Enrich the event properties --- .../agents_as_tools_streaming.py | 4 +- src/agents/agent.py | 18 +++-- src/agents/tool_context.py | 12 ++- tests/test_agent_as_tool.py | 74 +++++++++++++++---- 4 files changed, 84 insertions(+), 24 deletions(-) diff --git a/examples/agent_patterns/agents_as_tools_streaming.py b/examples/agent_patterns/agents_as_tools_streaming.py index 846593c81..2eeda9989 100644 --- a/examples/agent_patterns/agents_as_tools_streaming.py +++ b/examples/agent_patterns/agents_as_tools_streaming.py @@ -18,7 +18,9 @@ def billing_status_checker(customer_id: str | None = None, question: str = "") - def handle_stream(event: AgentToolStreamEvent) -> None: """Print streaming events emitted by the nested billing agent.""" stream = event["event"] - print(f"[stream] agent={event['agent_name']} type={stream.type} {stream}") + tool_call = event.get("tool_call") + tool_call_info = tool_call.call_id if tool_call is not None else "unknown" + print(f"[stream] agent={event['agent'].name} call={tool_call_info} type={stream.type} {stream}") async def main() -> None: diff --git a/src/agents/agent.py b/src/agents/agent.py index d449fa3ae..153f3705e 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -29,6 +29,8 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: + from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall + from .lifecycle import AgentHooks, RunHooks from .mcp import MCPServer from .memory.session import Session @@ -65,11 +67,11 @@ class AgentToolStreamEvent(TypedDict): event: StreamEvent """The streaming event from the nested agent run.""" - agent_name: str - """The name of the nested agent emitting the event.""" + agent: Agent[Any] + """The nested agent emitting the event.""" - 
tool_call_id: str | None - """The originating tool call ID, if available.""" + tool_call: ResponseFunctionToolCall | None + """The originating tool call, if available.""" class StopAtTools(TypedDict): @@ -427,7 +429,9 @@ def as_tool( context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. on_stream: Optional callback (sync or async) to receive streaming events from the nested - agent run. When provided, the nested agent is executed in streaming mode. + agent run. The callback receives an `AgentToolStreamEvent` containing the nested + agent, the originating tool call (when available), and each stream event. When + provided, the nested agent is executed in streaming mode. """ @function_tool( @@ -456,8 +460,8 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: async for event in run_result.stream_events(): payload: AgentToolStreamEvent = { "event": event, - "agent_name": self.name, - "tool_call_id": getattr(context, "tool_call_id", None), + "agent": self, + "tool_call": getattr(context, "tool_call", None), } try: maybe_result = on_stream(payload) diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index 5b81239f6..0fc354299 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -31,6 +31,9 @@ class ToolContext(RunContextWrapper[TContext]): tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments) """The raw arguments string of the tool call.""" + tool_call: Optional[ResponseFunctionToolCall] = None + """The tool call object associated with this invocation.""" + @classmethod def from_agent_context( cls, @@ -50,6 +53,11 @@ def from_agent_context( tool_call.arguments if tool_call is not None else _assert_must_pass_tool_arguments() ) - return cls( - tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values + tool_context = cls( + tool_name=tool_name, + tool_call_id=tool_call_id, + tool_arguments=tool_args, + tool_call=tool_call, + **base_values, ) + return tool_context diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index ab5f57660..71d923b7b 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -4,6 +4,7 @@ import pytest from openai.types.responses import ResponseOutputMessage, ResponseOutputText +from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall from pydantic import BaseModel from agents import ( @@ -437,6 +438,14 @@ async def unexpected_run(*args: Any, **kwargs: Any) -> None: async def on_stream(payload: AgentToolStreamEvent) -> None: received_events.append(payload) + tool_call = ResponseFunctionToolCall( + id="call_123", + arguments='{"input": "run streaming"}', + call_id="call-123", + name="stream_tool", + type="function_call", + ) + tool = cast( FunctionTool, agent.as_tool( @@ -449,15 +458,16 @@ async def on_stream(payload: AgentToolStreamEvent) -> None: tool_context = ToolContext( context=None, tool_name="stream_tool", - tool_call_id="call-123", - tool_arguments='{"input": "run streaming"}', + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, ) output = await tool.on_invoke_tool(tool_context, '{"input": "run streaming"}') assert output == "streamed output" assert len(received_events) == len(stream_events) - assert received_events[0]["agent_name"] == "streamer" - assert received_events[0]["tool_call_id"] == "call-123" + assert received_events[0]["agent"] is agent + assert received_events[0]["tool_call"] is 
tool_call assert received_events[0]["event"] == stream_events[0] assert run_calls[0]["input"] == "run streaming" @@ -513,6 +523,14 @@ async def extractor(result) -> str: async def on_stream(payload: AgentToolStreamEvent) -> None: callbacks.append(payload["event"]) + tool_call = ResponseFunctionToolCall( + id="call_abc", + arguments='{"input": "stream please"}', + call_id="call-abc", + name="stream_tool", + type="function_call", + ) + tool = cast( FunctionTool, agent.as_tool( @@ -526,8 +544,9 @@ async def on_stream(payload: AgentToolStreamEvent) -> None: tool_context = ToolContext( context=None, tool_name="stream_tool", - tool_call_id="call-abc", - tool_arguments='{"input": "stream please"}', + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, ) output = await tool.on_invoke_tool(tool_context, '{"input": "stream please"}') @@ -563,6 +582,14 @@ async def stream_events(self): def sync_handler(event: AgentToolStreamEvent) -> None: calls.append(event["event"].type) + tool_call = ResponseFunctionToolCall( + id="call_sync", + arguments='{"input": "go"}', + call_id="call-sync", + name="sync_tool", + type="function_call", + ) + tool = cast( FunctionTool, agent.as_tool( @@ -574,8 +601,9 @@ def sync_handler(event: AgentToolStreamEvent) -> None: tool_context = ToolContext( context=None, tool_name="sync_tool", - tool_call_id="call-sync", - tool_arguments='{"input": "go"}', + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, ) output = await tool.on_invoke_tool(tool_context, '{"input": "go"}') @@ -609,6 +637,14 @@ async def stream_events(self): def bad_handler(event: AgentToolStreamEvent) -> None: raise RuntimeError("boom") + tool_call = ResponseFunctionToolCall( + id="call_bad", + arguments='{"input": "go"}', + call_id="call-bad", + name="error_tool", + type="function_call", + ) + tool = cast( FunctionTool, agent.as_tool( @@ -620,8 +656,9 @@ def bad_handler(event: AgentToolStreamEvent) -> None: tool_context = ToolContext( context=None, tool_name="error_tool", - tool_call_id="call-bad", - tool_arguments='{"input": "go"}', + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, ) output = await tool.on_invoke_tool(tool_context, '{"input": "go"}') @@ -686,7 +723,7 @@ async def fake_run( @pytest.mark.asyncio -async def test_agent_as_tool_streaming_sets_tool_call_id_from_context( +async def test_agent_as_tool_streaming_sets_tool_call_from_context( monkeypatch: pytest.MonkeyPatch, ) -> None: agent = Agent(name="direct_invocation_agent") @@ -712,6 +749,14 @@ async def stream_events(self): async def on_stream(event: AgentToolStreamEvent) -> None: captured.append(event) + tool_call = ResponseFunctionToolCall( + id="call_direct", + arguments='{"input": "hi"}', + call_id="direct-call-id", + name="direct_stream_tool", + type="function_call", + ) + tool = cast( FunctionTool, agent.as_tool( @@ -723,11 +768,12 @@ async def on_stream(event: AgentToolStreamEvent) -> None: tool_context = ToolContext( context=None, tool_name="direct_stream_tool", - tool_call_id="direct-call-id", - tool_arguments='{"input": "hi"}', + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, ) output = await tool.on_invoke_tool(tool_context, '{"input": "hi"}') assert output == "ok" - assert captured[0]["tool_call_id"] == "direct-call-id" + assert captured[0]["tool_call"] is tool_call From 94d223937ce01bf9b71db1909fc5e16416e8fe20 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: 
Mon, 15 Dec 2025 16:44:03 +0900 Subject: [PATCH 4/6] improve the concurrency of event handling --- src/agents/agent.py | 40 +++++++++++++++--- tests/test_agent_as_tool.py | 83 +++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 6 deletions(-) diff --git a/src/agents/agent.py b/src/agents/agent.py index 153f3705e..9c86f9727 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -457,12 +457,12 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: conversation_id=conversation_id, session=session, ) - async for event in run_result.stream_events(): - payload: AgentToolStreamEvent = { - "event": event, - "agent": self, - "tool_call": getattr(context, "tool_call", None), - } + # Dispatch callbacks in the background so slow handlers do not block + # event consumption. + event_queue: asyncio.Queue[AgentToolStreamEvent | None] = asyncio.Queue() + + async def _run_handler(payload: AgentToolStreamEvent) -> None: + """Execute the user callback while capturing exceptions.""" try: maybe_result = on_stream(payload) if inspect.isawaitable(maybe_result): @@ -472,6 +472,34 @@ async def run_agent(context: RunContextWrapper, input: str) -> Any: "Error while handling on_stream event for agent tool %s.", self.name, ) + + async def dispatch_stream_events() -> None: + while True: + payload = await event_queue.get() + is_sentinel = payload is None # None marks the end of the stream. + try: + if payload is not None: + await _run_handler(payload) + finally: + event_queue.task_done() + + if is_sentinel: + break + + dispatch_task = asyncio.create_task(dispatch_stream_events()) + + try: + async for event in run_result.stream_events(): + payload: AgentToolStreamEvent = { + "event": event, + "agent": self, + "tool_call": getattr(context, "tool_call", None), + } + await event_queue.put(payload) + finally: + await event_queue.put(None) + await event_queue.join() + await dispatch_task else: run_result = await Runner.run( starting_agent=self, diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 71d923b7b..b480ec00d 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any, cast import pytest @@ -612,6 +613,88 @@ def sync_handler(event: AgentToolStreamEvent) -> None: assert calls == ["raw_response_event"] +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_dispatches_without_blocking( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """on_stream handlers should not block streaming iteration.""" + agent = Agent(name="nonblocking_agent") + + first_handler_started = asyncio.Event() + allow_handler_to_continue = asyncio.Event() + second_event_yielded = asyncio.Event() + second_event_handled = asyncio.Event() + + first_event = RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) + second_event = RawResponsesStreamEvent( + data=cast(Any, {"type": "output_text_delta", "delta": "hi"}) + ) + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "ok" + + async def stream_events(self): + yield first_event + second_event_yielded.set() + yield second_event + + dummy_result = DummyStreamingResult() + + monkeypatch.setattr(Runner, "run_streamed", classmethod(lambda *args, **kwargs: dummy_result)) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + async def on_stream(payload: AgentToolStreamEvent) -> None: + if payload["event"] 
is first_event: + first_handler_started.set() + await allow_handler_to_continue.wait() + else: + second_event_handled.set() + + tool_call = ResponseFunctionToolCall( + id="call_nonblocking", + arguments='{"input": "go"}', + call_id="call-nonblocking", + name="nonblocking_tool", + type="function_call", + ) + + tool = cast( + FunctionTool, + agent.as_tool( + tool_name="nonblocking_tool", + tool_description="Uses non-blocking streaming handler", + on_stream=on_stream, + ), + ) + tool_context = ToolContext( + context=None, + tool_name="nonblocking_tool", + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + + async def _invoke_tool() -> Any: + return await tool.on_invoke_tool(tool_context, '{"input": "go"}') + + invoke_task: asyncio.Task[Any] = asyncio.create_task(_invoke_tool()) + + await asyncio.wait_for(first_handler_started.wait(), timeout=1.0) + await asyncio.wait_for(second_event_yielded.wait(), timeout=1.0) + assert invoke_task.done() is False + + allow_handler_to_continue.set() + await asyncio.wait_for(second_event_handled.wait(), timeout=1.0) + output = await asyncio.wait_for(invoke_task, timeout=1.0) + + assert output == "ok" + + @pytest.mark.asyncio async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call( monkeypatch: pytest.MonkeyPatch, From dbcb6e3294c81b6aaeb0fb9d23550a9bb8cfd7f4 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Mon, 15 Dec 2025 18:02:34 +0900 Subject: [PATCH 5/6] fix --- src/agents/agent.py | 8 +++- tests/test_agent_as_tool.py | 82 ++++++++++++++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/agents/agent.py b/src/agents/agent.py index 9c86f9727..4a6f7821f 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -489,10 +489,16 @@ async def dispatch_stream_events() -> None: dispatch_task = asyncio.create_task(dispatch_stream_events()) try: + from .stream_events import AgentUpdatedStreamEvent + + current_agent = getattr(run_result, "current_agent", self) async for event in run_result.stream_events(): + if isinstance(event, AgentUpdatedStreamEvent): + current_agent = event.new_agent + payload: AgentToolStreamEvent = { "event": event, - "agent": self, + "agent": current_agent, "tool_call": getattr(context, "tool_call", None), } await event_queue.put(payload) diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index b480ec00d..416700f64 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -21,7 +21,7 @@ Session, TResponseInputItem, ) -from agents.stream_events import RawResponsesStreamEvent +from agents.stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent from agents.tool_context import ToolContext @@ -473,6 +473,86 @@ async def on_stream(payload: AgentToolStreamEvent) -> None: assert run_calls[0]["input"] == "run streaming" +@pytest.mark.asyncio +async def test_agent_as_tool_streaming_updates_agent_on_handoff( + monkeypatch: pytest.MonkeyPatch, +) -> None: + first_agent = Agent(name="primary") + handed_off_agent = Agent(name="delegate") + + events = [ + AgentUpdatedStreamEvent(new_agent=first_agent), + RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})), + AgentUpdatedStreamEvent(new_agent=handed_off_agent), + RawResponsesStreamEvent(data=cast(Any, {"type": "output_text_delta", "delta": "hello"})), + ] + + class DummyStreamingResult: + def __init__(self) -> None: + self.final_output = "delegated output" + + async def stream_events(self): + for ev in events: + yield ev + + 
def fake_run_streamed( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + auto_previous_response_id=False, + conversation_id, + session, + ): + return DummyStreamingResult() + + monkeypatch.setattr(Runner, "run_streamed", classmethod(fake_run_streamed)) + monkeypatch.setattr( + Runner, + "run", + classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no run"))), + ) + + seen_agents: list[Agent[Any]] = [] + + async def on_stream(payload: AgentToolStreamEvent) -> None: + seen_agents.append(payload["agent"]) + + tool = cast( + FunctionTool, + first_agent.as_tool( + tool_name="delegate_tool", + tool_description="Streams handoff events", + on_stream=on_stream, + ), + ) + + tool_call = ResponseFunctionToolCall( + id="call_delegate", + arguments='{"input": "handoff"}', + call_id="call-delegate", + name="delegate_tool", + type="function_call", + ) + tool_context = ToolContext( + context=None, + tool_name="delegate_tool", + tool_call_id=tool_call.call_id, + tool_arguments=tool_call.arguments, + tool_call=tool_call, + ) + + output = await tool.on_invoke_tool(tool_context, '{"input": "handoff"}') + + assert output == "delegated output" + assert seen_agents == [first_agent, first_agent, handed_off_agent, handed_off_agent] + + @pytest.mark.asyncio async def test_agent_as_tool_streaming_works_with_custom_extractor( monkeypatch: pytest.MonkeyPatch, From a8104bf28aac025918e9e56ea45340d27cb7b036 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Mon, 15 Dec 2025 18:08:58 +0900 Subject: [PATCH 6/6] get rid of getattr --- src/agents/agent.py | 7 ++++--- tests/test_agent_as_tool.py | 7 +++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/agents/agent.py b/src/agents/agent.py index 4a6f7821f..d7e780ba9 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -25,6 +25,7 @@ from .prompts import DynamicPromptFunction, Prompt, PromptUtil from .run_context import RunContextWrapper, TContext from .tool import FunctionTool, FunctionToolResult, Tool, function_tool +from .tool_context import ToolContext from .util import _transforms from .util._types import MaybeAwaitable @@ -439,7 +440,7 @@ def as_tool( description_override=tool_description or "", is_enabled=is_enabled, ) - async def run_agent(context: RunContextWrapper, input: str) -> Any: + async def run_agent(context: ToolContext, input: str) -> Any: from .run import DEFAULT_MAX_TURNS, Runner resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS @@ -491,7 +492,7 @@ async def dispatch_stream_events() -> None: try: from .stream_events import AgentUpdatedStreamEvent - current_agent = getattr(run_result, "current_agent", self) + current_agent = run_result.current_agent async for event in run_result.stream_events(): if isinstance(event, AgentUpdatedStreamEvent): current_agent = event.new_agent @@ -499,7 +500,7 @@ async def dispatch_stream_events() -> None: payload: AgentToolStreamEvent = { "event": event, "agent": current_agent, - "tool_call": getattr(context, "tool_call", None), + "tool_call": context.tool_call, } await event_queue.put(payload) finally: diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 416700f64..c28ce8fb1 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -392,6 +392,7 @@ async def test_agent_as_tool_streams_events_with_on_stream( class DummyStreamingResult: def __init__(self) -> None: self.final_output = "streamed output" + self.current_agent = agent async def 
stream_events(self): for ev in stream_events: @@ -490,6 +491,7 @@ async def test_agent_as_tool_streaming_updates_agent_on_handoff( class DummyStreamingResult: def __init__(self) -> None: self.final_output = "delegated output" + self.current_agent = first_agent async def stream_events(self): for ev in events: @@ -564,6 +566,7 @@ async def test_agent_as_tool_streaming_works_with_custom_extractor( class DummyStreamingResult: def __init__(self) -> None: self.final_output = "raw output" + self.current_agent = agent async def stream_events(self): for ev in stream_events: @@ -645,6 +648,7 @@ async def test_agent_as_tool_streaming_accepts_sync_handler( class DummyStreamingResult: def __init__(self) -> None: self.final_output = "ok" + self.current_agent = agent async def stream_events(self): yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) @@ -713,6 +717,7 @@ async def test_agent_as_tool_streaming_dispatches_without_blocking( class DummyStreamingResult: def __init__(self) -> None: self.final_output = "ok" + self.current_agent = agent async def stream_events(self): yield first_event @@ -784,6 +789,7 @@ async def test_agent_as_tool_streaming_handler_exception_does_not_fail_call( class DummyStreamingResult: def __init__(self) -> None: self.final_output = "ok" + self.current_agent = agent async def stream_events(self): yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"})) @@ -894,6 +900,7 @@ async def test_agent_as_tool_streaming_sets_tool_call_from_context( class DummyStreamingResult: def __init__(self) -> None: self.final_output = "ok" + self.current_agent = agent async def stream_events(self): yield RawResponsesStreamEvent(data=cast(Any, {"type": "response_started"}))
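
A minimal usage sketch of the `on_stream` hook introduced in this series, shown for context. The agent names, instructions, and the text-delta filtering below are illustrative assumptions, not part of the patches; the sketch relies only on the `AgentToolStreamEvent` keys (`event`, `agent`, `tool_call`) and the `as_tool(..., on_stream=...)` signature added above, plus the OpenAI Responses `ResponseTextDeltaEvent` type for raw text deltas.

import asyncio

from openai.types.responses import ResponseTextDeltaEvent

from agents import Agent, AgentToolStreamEvent, Runner


async def forward_nested_text(event: AgentToolStreamEvent) -> None:
    # Only raw model events carry incremental text; other stream event types are skipped here.
    stream_event = event["event"]
    if stream_event.type == "raw_response_event" and isinstance(
        stream_event.data, ResponseTextDeltaEvent
    ):
        # Prefix each delta with the nested agent's name so interleaved output stays readable.
        print(f"[{event['agent'].name}] {stream_event.data.delta}", end="", flush=True)


async def main() -> None:
    # Hypothetical nested agent; any Agent wrapped via as_tool works the same way.
    summarizer = Agent(name="Summarizer", instructions="Summarize the input in one sentence.")

    orchestrator = Agent(
        name="Orchestrator",
        instructions="Use the summarize tool to answer the user.",
        tools=[
            summarizer.as_tool(
                tool_name="summarize",
                tool_description="Summarize text.",
                on_stream=forward_nested_text,  # async handler; sync callables are also accepted
            )
        ],
    )

    result = await Runner.run(orchestrator, "Summarize: streaming agent tools are useful.")
    print(f"\n{result.final_output}")


if __name__ == "__main__":
    asyncio.run(main())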