11import asyncio
2+ from dataclasses import dataclass , field
23from typing import Any
34
45from acp import (
1920 stdio_streams ,
2021 PROTOCOL_VERSION ,
2122)
22- from acp .schema import TextContentBlock , AgentMessageChunk
23+ from acp .schema import (
24+ AgentMessageChunk ,
25+ AllowedOutcome ,
26+ ContentToolCallContent ,
27+ PermissionOption ,
28+ RequestPermissionRequest ,
29+ TextContentBlock ,
30+ ToolCallUpdate ,
31+ )
32+
33+
34+ @dataclass
35+ class SessionState :
36+ cancel_event : asyncio .Event = field (default_factory = asyncio .Event )
37+ prompt_counter : int = 0
38+
39+ def begin_prompt (self ) -> None :
40+ self .prompt_counter += 1
41+ self .cancel_event .clear ()
42+
43+ def cancel (self ) -> None :
44+ self .cancel_event .set ()
2345
2446
2547class ExampleAgent (Agent ):
2648 def __init__ (self , conn : AgentSideConnection ) -> None :
2749 self ._conn = conn
2850 self ._next_session_id = 0
51+ self ._sessions : dict [str , SessionState ] = {}
52+
53+ def _session (self , session_id : str ) -> SessionState :
54+ state = self ._sessions .get (session_id )
55+ if state is None :
56+ state = SessionState ()
57+ self ._sessions [session_id ] = state
58+ return state
59+
60+ async def _send_text (self , session_id : str , text : str ) -> None :
61+ await self ._conn .sessionUpdate (
62+ SessionNotification (
63+ sessionId = session_id ,
64+ update = AgentMessageChunk (
65+ sessionUpdate = "agent_message_chunk" ,
66+ content = TextContentBlock (type = "text" , text = text ),
67+ ),
68+ )
69+ )
70+
71+ def _format_prompt_preview (self , blocks : list [Any ]) -> str :
72+ parts : list [str ] = []
73+ for block in blocks :
74+ if isinstance (block , dict ):
75+ if block .get ("type" ) == "text" :
76+ parts .append (str (block .get ("text" , "" )))
77+ else :
78+ parts .append (f"<{ block .get ('type' , 'content' )} >" )
79+ else :
80+ parts .append (getattr (block , "text" , "<content>" ))
81+ preview = " \n " .join (filter (None , parts )).strip ()
82+ return preview or "<empty prompt>"
83+
84+ async def _request_permission (self , session_id : str , preview : str , state : SessionState ) -> str :
85+ state .prompt_counter += 1
86+ request = RequestPermissionRequest (
87+ sessionId = session_id ,
88+ toolCall = ToolCallUpdate (
89+ toolCallId = f"echo-{ state .prompt_counter } " ,
90+ title = "Echo input" ,
91+ kind = "echo" ,
92+ status = "pending" ,
93+ content = [
94+ ContentToolCallContent (
95+ type = "content" ,
96+ content = TextContentBlock (type = "text" , text = preview ),
97+ )
98+ ],
99+ ),
100+ options = [
101+ PermissionOption (optionId = "allow-once" , name = "Allow once" , kind = "allow_once" ),
102+ PermissionOption (optionId = "deny" , name = "Deny" , kind = "reject_once" ),
103+ ],
104+ )
105+
106+ permission_task = asyncio .create_task (self ._conn .requestPermission (request ))
107+ cancel_task = asyncio .create_task (state .cancel_event .wait ())
108+
109+ done , pending = await asyncio .wait ({permission_task , cancel_task }, return_when = asyncio .FIRST_COMPLETED )
110+
111+ for task in pending :
112+ task .cancel ()
113+
114+ if cancel_task in done :
115+ permission_task .cancel ()
116+ return "cancelled"
117+
118+ try :
119+ response = await permission_task
120+ except asyncio .CancelledError :
121+ return "cancelled"
122+ except Exception as exc : # noqa: BLE001
123+ await self ._send_text (session_id , f"Permission failed: { exc } " )
124+ return "error"
125+
126+ if isinstance (response .outcome , AllowedOutcome ):
127+ option_id = response .outcome .optionId
128+ if option_id .startswith ("allow" ):
129+ return "allowed"
130+ return "denied"
131+ return "cancelled"
29132
30133 async def initialize (self , params : InitializeRequest ) -> InitializeResponse :
31134 return InitializeResponse (protocolVersion = PROTOCOL_VERSION , agentCapabilities = None , authMethods = [])
@@ -36,6 +139,7 @@ async def authenticate(self, params: AuthenticateRequest) -> AuthenticateRespons
36139 async def newSession (self , params : NewSessionRequest ) -> NewSessionResponse : # noqa: ARG002
37140 session_id = f"sess-{ self ._next_session_id } "
38141 self ._next_session_id += 1
142+ self ._sessions [session_id ] = SessionState ()
39143 return NewSessionResponse (sessionId = session_id )
40144
41145 async def loadSession (self , params ): # type: ignore[override]
@@ -45,41 +149,39 @@ async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeR
45149 return {}
46150
47151 async def prompt (self , params : PromptRequest ) -> PromptResponse :
48- # Stream a couple of agent message chunks, then end the turn
49- # 1) Prefix
50- await self ._conn .sessionUpdate (
51- SessionNotification (
52- sessionId = params .sessionId ,
53- update = AgentMessageChunk (
54- sessionUpdate = "agent_message_chunk" ,
55- content = TextContentBlock (type = "text" , text = "Client sent: " ),
56- ),
57- )
58- )
59- # 2) Echo text blocks
152+ state = self ._session (params .sessionId )
153+ state .begin_prompt ()
154+
155+ preview = self ._format_prompt_preview (list (params .prompt ))
156+ await self ._send_text (params .sessionId , "Agent received a prompt. Checking permissions..." )
157+
158+ decision = await self ._request_permission (params .sessionId , preview , state )
159+ if decision == "cancelled" :
160+ await self ._send_text (params .sessionId , "Prompt cancelled before permission decided." )
161+ return PromptResponse (stopReason = "cancelled" )
162+ if decision == "denied" :
163+ await self ._send_text (params .sessionId , "Permission denied by the client." )
164+ return PromptResponse (stopReason = "permission_denied" )
165+ if decision == "error" :
166+ return PromptResponse (stopReason = "error" )
167+
168+ await self ._send_text (params .sessionId , "Permission granted. Echoing content:" )
169+
60170 for block in params .prompt :
61- if isinstance (block , dict ):
62- # tolerate raw dicts
63- if block .get ("type" ) == "text" :
64- text = str (block .get ("text" , "" ))
65- else :
66- text = f"<{ block .get ('type' , 'content' )} >"
67- else :
68- # pydantic model TextContentBlock
69- text = getattr (block , "text" , "<content>" )
70- await self ._conn .sessionUpdate (
71- SessionNotification (
72- sessionId = params .sessionId ,
73- update = AgentMessageChunk (
74- sessionUpdate = "agent_message_chunk" ,
75- content = TextContentBlock (type = "text" , text = text ),
76- ),
77- )
78- )
171+ if state .cancel_event .is_set ():
172+ await self ._send_text (params .sessionId , "Prompt interrupted by cancellation." )
173+ return PromptResponse (stopReason = "cancelled" )
174+ text = self ._format_prompt_preview ([block ])
175+ await self ._send_text (params .sessionId , text )
176+ await asyncio .sleep (0.4 )
177+
79178 return PromptResponse (stopReason = "end_turn" )
80179
81180 async def cancel (self , params : CancelNotification ) -> None : # noqa: ARG002
82- return None
181+ state = self ._sessions .get (params .sessionId )
182+ if state :
183+ state .cancel ()
184+ await self ._send_text (params .sessionId , "Agent received cancel signal." )
83185
84186 async def extMethod (self , method : str , params : dict ) -> dict : # noqa: ARG002
85187 return {"example" : "response" }
@@ -90,7 +192,6 @@ async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG
90192
91193async def main () -> None :
92194 reader , writer = await stdio_streams ()
93- # For an agent process, local writes go to client stdin (writer=stdout)
94195 AgentSideConnection (lambda conn : ExampleAgent (conn ), writer , reader )
95196 await asyncio .Event ().wait ()
96197
0 commit comments