diff --git a/examples/agent.rs b/examples/agent.rs index da6e1e9..5dae9bf 100644 --- a/examples/agent.rs +++ b/examples/agent.rs @@ -142,15 +142,6 @@ impl acp::Agent for ExampleAgent { Ok(acp::SetSessionModelResponse::default()) } - #[cfg(feature = "unstable_session_list")] - async fn list_sessions( - &self, - args: acp::ListSessionsRequest, - ) -> Result { - log::info!("Received list sessions request {args:?}"); - Ok(acp::ListSessionsResponse::new(vec![])) - } - async fn ext_method(&self, args: acp::ExtRequest) -> Result { log::info!( "Received extension method call: method={}, params={:?}", diff --git a/src/agent-client-protocol/src/agent.rs b/src/agent-client-protocol/src/agent.rs index f1e4ba4..eb40a93 100644 --- a/src/agent-client-protocol/src/agent.rs +++ b/src/agent-client-protocol/src/agent.rs @@ -6,6 +6,8 @@ use agent_client_protocol_schema::{ LoadSessionResponse, NewSessionRequest, NewSessionResponse, PromptRequest, PromptResponse, Result, SetSessionModeRequest, SetSessionModeResponse, }; +#[cfg(feature = "unstable_session_fork")] +use agent_client_protocol_schema::{ForkSessionRequest, ForkSessionResponse}; #[cfg(feature = "unstable_session_list")] use agent_client_protocol_schema::{ListSessionsRequest, ListSessionsResponse}; #[cfg(feature = "unstable_session_model")] @@ -140,6 +142,18 @@ pub trait Agent { Err(Error::method_not_found()) } + /// **UNSTABLE** + /// + /// This capability is not part of the spec yet, and may be removed or changed at any point. + /// + /// Forks an existing session, creating a new session with the same conversation history. + /// + /// Only available if the Agent supports the `sessionCapabilities.fork` capability. + #[cfg(feature = "unstable_session_fork")] + async fn fork_session(&self, _args: ForkSessionRequest) -> Result { + Err(Error::method_not_found()) + } + /// Handles extension method requests from the client. /// /// Extension methods provide a way to add custom functionality while maintaining @@ -198,6 +212,10 @@ impl Agent for Rc { async fn list_sessions(&self, args: ListSessionsRequest) -> Result { self.as_ref().list_sessions(args).await } + #[cfg(feature = "unstable_session_fork")] + async fn fork_session(&self, args: ForkSessionRequest) -> Result { + self.as_ref().fork_session(args).await + } async fn ext_method(&self, args: ExtRequest) -> Result { self.as_ref().ext_method(args).await } @@ -243,6 +261,10 @@ impl Agent for Arc { async fn list_sessions(&self, args: ListSessionsRequest) -> Result { self.as_ref().list_sessions(args).await } + #[cfg(feature = "unstable_session_fork")] + async fn fork_session(&self, args: ForkSessionRequest) -> Result { + self.as_ref().fork_session(args).await + } async fn ext_method(&self, args: ExtRequest) -> Result { self.as_ref().ext_method(args).await } diff --git a/src/agent-client-protocol/src/lib.rs b/src/agent-client-protocol/src/lib.rs index a2b88b3..2abaab9 100644 --- a/src/agent-client-protocol/src/lib.rs +++ b/src/agent-client-protocol/src/lib.rs @@ -165,6 +165,16 @@ impl Agent for ClientSideConnection { .await } + #[cfg(feature = "unstable_session_fork")] + async fn fork_session(&self, args: ForkSessionRequest) -> Result { + self.conn + .request( + AGENT_METHOD_NAMES.session_fork, + Some(ClientRequest::ForkSessionRequest(args)), + ) + .await + } + async fn ext_method(&self, args: ExtRequest) -> Result { self.conn .request( @@ -532,6 +542,10 @@ impl Side for AgentSide { m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get()) .map(ClientRequest::ListSessionsRequest) .map_err(Into::into), + #[cfg(feature = "unstable_session_fork")] + m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get()) + .map(ClientRequest::ForkSessionRequest) + .map_err(Into::into), m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get()) .map(ClientRequest::PromptRequest) .map_err(Into::into), @@ -606,6 +620,11 @@ impl MessageHandler for T { let response = self.list_sessions(args).await?; Ok(AgentResponse::ListSessionsResponse(response)) } + #[cfg(feature = "unstable_session_fork")] + ClientRequest::ForkSessionRequest(args) => { + let response = self.fork_session(args).await?; + Ok(AgentResponse::ForkSessionResponse(response)) + } ClientRequest::ExtMethodRequest(args) => { let response = self.ext_method(args).await?; Ok(AgentResponse::ExtMethodResponse(response)) diff --git a/src/agent-client-protocol/src/rpc_tests.rs b/src/agent-client-protocol/src/rpc_tests.rs index 8467564..3e62d96 100644 --- a/src/agent-client-protocol/src/rpc_tests.rs +++ b/src/agent-client-protocol/src/rpc_tests.rs @@ -133,7 +133,7 @@ impl Client for TestClient { #[derive(Clone)] struct TestAgent { - sessions: Arc>>, + sessions: Arc>>, prompts_received: Arc>>, cancellations_received: Arc>>, extension_notifications: Arc>>, @@ -144,7 +144,7 @@ type PromptReceived = (SessionId, Vec); impl TestAgent { fn new() -> Self { Self { - sessions: Arc::new(Mutex::new(std::collections::HashSet::new())), + sessions: Arc::new(Mutex::new(std::collections::HashMap::new())), prompts_received: Arc::new(Mutex::new(Vec::new())), cancellations_received: Arc::new(Mutex::new(Vec::new())), extension_notifications: Arc::new(Mutex::new(Vec::new())), @@ -163,9 +163,12 @@ impl Agent for TestAgent { Ok(AuthenticateResponse::default()) } - async fn new_session(&self, _arguments: NewSessionRequest) -> Result { + async fn new_session(&self, arguments: NewSessionRequest) -> Result { let session_id = SessionId::new("test-session-123"); - self.sessions.lock().unwrap().insert(session_id.clone()); + self.sessions + .lock() + .unwrap() + .insert(session_id.clone(), arguments.cwd); Ok(NewSessionResponse::new(session_id)) } @@ -210,8 +213,30 @@ impl Agent for TestAgent { &self, _args: agent_client_protocol_schema::ListSessionsRequest, ) -> Result { + let sessions = self.sessions.lock().unwrap(); + let session_infos: Vec<_> = sessions + .iter() + .map(|(id, cwd)| { + agent_client_protocol_schema::SessionInfo::new(id.clone(), cwd.clone()) + }) + .collect(); Ok(agent_client_protocol_schema::ListSessionsResponse::new( - vec![], + session_infos, + )) + } + + #[cfg(feature = "unstable_session_fork")] + async fn fork_session( + &self, + args: agent_client_protocol_schema::ForkSessionRequest, + ) -> Result { + let new_session_id = SessionId::new(format!("fork-of-{}", args.session_id.0)); + self.sessions + .lock() + .unwrap() + .insert(new_session_id.clone(), args.cwd); + Ok(agent_client_protocol_schema::ForkSessionResponse::new( + new_session_id, )) } @@ -665,3 +690,86 @@ async fn test_extension_methods_and_notifications() { }) .await; } + +#[cfg(feature = "unstable_session_fork")] +#[tokio::test] +async fn test_fork_session() { + let local_set = tokio::task::LocalSet::new(); + local_set + .run_until(async { + let client = TestClient::new(); + let agent = TestAgent::new(); + + let (agent_conn, _client_conn) = create_connection_pair(&client, &agent); + + // First create a session + let new_session_response = agent_conn + .new_session(NewSessionRequest::new("/test")) + .await + .expect("new_session failed"); + + let original_session_id = new_session_response.session_id; + + // Fork the session + let fork_response = agent_conn + .fork_session(agent_client_protocol_schema::ForkSessionRequest::new( + original_session_id.clone(), + "/test", + )) + .await + .expect("fork_session failed"); + + // Verify the forked session has a different ID + assert_ne!(fork_response.session_id, original_session_id); + assert_eq!( + fork_response.session_id.0.as_ref(), + format!("fork-of-{}", original_session_id.0) + ); + + // Verify the forked session was added to the agent's sessions + let sessions = agent.sessions.lock().unwrap(); + assert!(sessions.contains_key(&fork_response.session_id)); + }) + .await; +} + +#[cfg(feature = "unstable_session_list")] +#[tokio::test] +async fn test_list_sessions() { + let local_set = tokio::task::LocalSet::new(); + local_set + .run_until(async { + let client = TestClient::new(); + let agent = TestAgent::new(); + + let (agent_conn, _client_conn) = create_connection_pair(&client, &agent); + + // First create a session + let new_session_response = agent_conn + .new_session(NewSessionRequest::new("/test")) + .await + .expect("new_session failed"); + + // Verify the session was created + assert!(!new_session_response.session_id.0.is_empty()); + + // List sessions + let list_response = agent_conn + .list_sessions(agent_client_protocol_schema::ListSessionsRequest::new()) + .await + .expect("list_sessions failed"); + + // Verify the response contains our session + assert_eq!(list_response.sessions.len(), 1); + assert_eq!( + list_response.sessions[0].session_id, + new_session_response.session_id + ); + assert_eq!( + list_response.sessions[0].cwd, + std::path::PathBuf::from("/test") + ); + assert!(list_response.next_cursor.is_none()); + }) + .await; +}