Skip to content

Commit d5feb89

Browse files
author
Chojan Shang
committed
refactor: update the buggy examples
Signed-off-by: Chojan Shang <chojan.shang@vesoft.com>
1 parent ba99789 commit d5feb89

File tree

4 files changed

+290
-57
lines changed

4 files changed

+290
-57
lines changed

examples/agent.py

Lines changed: 134 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from dataclasses import dataclass, field
23
from typing import Any
34

45
from acp import (
@@ -19,13 +20,115 @@
1920
stdio_streams,
2021
PROTOCOL_VERSION,
2122
)
22-
from acp.schema import TextContentBlock, AgentMessageChunk
23+
from acp.schema import (
24+
AgentMessageChunk,
25+
AllowedOutcome,
26+
ContentToolCallContent,
27+
PermissionOption,
28+
RequestPermissionRequest,
29+
TextContentBlock,
30+
ToolCallUpdate,
31+
)
32+
33+
34+
@dataclass
35+
class SessionState:
36+
cancel_event: asyncio.Event = field(default_factory=asyncio.Event)
37+
prompt_counter: int = 0
38+
39+
def begin_prompt(self) -> None:
40+
self.prompt_counter += 1
41+
self.cancel_event.clear()
42+
43+
def cancel(self) -> None:
44+
self.cancel_event.set()
2345

2446

2547
class ExampleAgent(Agent):
2648
def __init__(self, conn: AgentSideConnection) -> None:
2749
self._conn = conn
2850
self._next_session_id = 0
51+
self._sessions: dict[str, SessionState] = {}
52+
53+
def _session(self, session_id: str) -> SessionState:
54+
state = self._sessions.get(session_id)
55+
if state is None:
56+
state = SessionState()
57+
self._sessions[session_id] = state
58+
return state
59+
60+
async def _send_text(self, session_id: str, text: str) -> None:
61+
await self._conn.sessionUpdate(
62+
SessionNotification(
63+
sessionId=session_id,
64+
update=AgentMessageChunk(
65+
sessionUpdate="agent_message_chunk",
66+
content=TextContentBlock(type="text", text=text),
67+
),
68+
)
69+
)
70+
71+
def _format_prompt_preview(self, blocks: list[Any]) -> str:
72+
parts: list[str] = []
73+
for block in blocks:
74+
if isinstance(block, dict):
75+
if block.get("type") == "text":
76+
parts.append(str(block.get("text", "")))
77+
else:
78+
parts.append(f"<{block.get('type', 'content')}>")
79+
else:
80+
parts.append(getattr(block, "text", "<content>"))
81+
preview = " \n".join(filter(None, parts)).strip()
82+
return preview or "<empty prompt>"
83+
84+
async def _request_permission(self, session_id: str, preview: str, state: SessionState) -> str:
85+
state.prompt_counter += 1
86+
request = RequestPermissionRequest(
87+
sessionId=session_id,
88+
toolCall=ToolCallUpdate(
89+
toolCallId=f"echo-{state.prompt_counter}",
90+
title="Echo input",
91+
kind="echo",
92+
status="pending",
93+
content=[
94+
ContentToolCallContent(
95+
type="content",
96+
content=TextContentBlock(type="text", text=preview),
97+
)
98+
],
99+
),
100+
options=[
101+
PermissionOption(optionId="allow-once", name="Allow once", kind="allow_once"),
102+
PermissionOption(optionId="deny", name="Deny", kind="reject_once"),
103+
],
104+
)
105+
106+
permission_task = asyncio.create_task(self._conn.requestPermission(request))
107+
cancel_task = asyncio.create_task(state.cancel_event.wait())
108+
109+
done, pending = await asyncio.wait({permission_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
110+
111+
for task in pending:
112+
task.cancel()
113+
114+
if cancel_task in done:
115+
permission_task.cancel()
116+
return "cancelled"
117+
118+
try:
119+
response = await permission_task
120+
except asyncio.CancelledError:
121+
return "cancelled"
122+
except Exception as exc: # noqa: BLE001
123+
await self._send_text(session_id, f"Permission failed: {exc}")
124+
return "error"
125+
126+
if isinstance(response.outcome, AllowedOutcome):
127+
option_id = response.outcome.optionId
128+
if option_id.startswith("allow"):
129+
return "allowed"
130+
return "denied"
131+
return "cancelled"
29132

30133
async def initialize(self, params: InitializeRequest) -> InitializeResponse:
31134
return InitializeResponse(protocolVersion=PROTOCOL_VERSION, agentCapabilities=None, authMethods=[])
@@ -36,6 +139,7 @@ async def authenticate(self, params: AuthenticateRequest) -> AuthenticateRespons
36139
async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # noqa: ARG002
37140
session_id = f"sess-{self._next_session_id}"
38141
self._next_session_id += 1
142+
self._sessions[session_id] = SessionState()
39143
return NewSessionResponse(sessionId=session_id)
40144

41145
async def loadSession(self, params): # type: ignore[override]
@@ -45,41 +149,39 @@ async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeR
45149
return {}
46150

47151
async def prompt(self, params: PromptRequest) -> PromptResponse:
48-
# Stream a couple of agent message chunks, then end the turn
49-
# 1) Prefix
50-
await self._conn.sessionUpdate(
51-
SessionNotification(
52-
sessionId=params.sessionId,
53-
update=AgentMessageChunk(
54-
sessionUpdate="agent_message_chunk",
55-
content=TextContentBlock(type="text", text="Client sent: "),
56-
),
57-
)
58-
)
59-
# 2) Echo text blocks
152+
state = self._session(params.sessionId)
153+
state.begin_prompt()
154+
155+
preview = self._format_prompt_preview(list(params.prompt))
156+
await self._send_text(params.sessionId, "Agent received a prompt. Checking permissions...")
157+
158+
decision = await self._request_permission(params.sessionId, preview, state)
159+
if decision == "cancelled":
160+
await self._send_text(params.sessionId, "Prompt cancelled before permission decided.")
161+
return PromptResponse(stopReason="cancelled")
162+
if decision == "denied":
163+
await self._send_text(params.sessionId, "Permission denied by the client.")
164+
return PromptResponse(stopReason="permission_denied")
165+
if decision == "error":
166+
return PromptResponse(stopReason="error")
167+
168+
await self._send_text(params.sessionId, "Permission granted. Echoing content:")
169+
60170
for block in params.prompt:
61-
if isinstance(block, dict):
62-
# tolerate raw dicts
63-
if block.get("type") == "text":
64-
text = str(block.get("text", ""))
65-
else:
66-
text = f"<{block.get('type', 'content')}>"
67-
else:
68-
# pydantic model TextContentBlock
69-
text = getattr(block, "text", "<content>")
70-
await self._conn.sessionUpdate(
71-
SessionNotification(
72-
sessionId=params.sessionId,
73-
update=AgentMessageChunk(
74-
sessionUpdate="agent_message_chunk",
75-
content=TextContentBlock(type="text", text=text),
76-
),
77-
)
78-
)
171+
if state.cancel_event.is_set():
172+
await self._send_text(params.sessionId, "Prompt interrupted by cancellation.")
173+
return PromptResponse(stopReason="cancelled")
174+
text = self._format_prompt_preview([block])
175+
await self._send_text(params.sessionId, text)
176+
await asyncio.sleep(0.4)
177+
79178
return PromptResponse(stopReason="end_turn")
80179

81180
async def cancel(self, params: CancelNotification) -> None: # noqa: ARG002
82-
return None
181+
state = self._sessions.get(params.sessionId)
182+
if state:
183+
state.cancel()
184+
await self._send_text(params.sessionId, "Agent received cancel signal.")
83185

84186
async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002
85187
return {"example": "response"}
@@ -90,7 +192,6 @@ async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG
90192

91193
async def main() -> None:
92194
reader, writer = await stdio_streams()
93-
# For an agent process, local writes go to client stdin (writer=stdout)
94195
AgentSideConnection(lambda conn: ExampleAgent(conn), writer, reader)
95196
await asyncio.Event().wait()
96197

0 commit comments

Comments
 (0)