diff --git a/crates/test-programs/src/bin/p3_readdir.rs b/crates/test-programs/src/bin/p3_readdir.rs
new file mode 100644
index 000000000000..aaeb414f8262
--- /dev/null
+++ b/crates/test-programs/src/bin/p3_readdir.rs
@@ -0,0 +1,95 @@
+use test_programs::p3::wasi;
+use test_programs::p3::wasi::filesystem::types::{
+    Descriptor, DescriptorFlags, DescriptorType, DirectoryEntry, OpenFlags, PathFlags,
+};
+
+struct Component;
+
+test_programs::p3::export!(Component);
+
+impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
+    async fn run() -> Result<(), ()> {
+        let preopens = wasi::filesystem::preopens::get_directories();
+        let (dir, _) = &preopens[0];
+
+        test_readdir(dir).await;
+        test_readdir_lots(dir).await;
+        Ok(())
+    }
+}
+
+fn main() {
+    unreachable!()
+}
+
+async fn read_dir(dir: &Descriptor) -> Vec<DirectoryEntry> {
+    let (dirs, result) = dir.read_directory().await;
+    let mut dirs = dirs.collect().await;
+    result.await.unwrap();
+    dirs.sort_by_key(|d| d.name.clone());
+    dirs
+}
+
+async fn assert_empty_dir(dir: &Descriptor) {
+    let dirs = read_dir(dir).await;
+    assert_eq!(dirs.len(), 0);
+}
+
+async fn test_readdir(dir: &Descriptor) {
+    // Check the behavior in an empty directory
+    assert_empty_dir(dir).await;
+
+    dir.open_at(
+        PathFlags::empty(),
+        "file".to_string(),
+        OpenFlags::CREATE,
+        DescriptorFlags::READ | DescriptorFlags::WRITE,
+    )
+    .await
+    .unwrap();
+
+    dir.create_directory_at("nested".to_string()).await.unwrap();
+    let nested = dir
+        .open_at(
+            PathFlags::empty(),
+            "nested".to_string(),
+            OpenFlags::DIRECTORY,
+            DescriptorFlags::empty(),
+        )
+        .await
+        .unwrap();
+
+    let entries = read_dir(dir).await;
+    assert_eq!(entries.len(), 2);
+    assert_eq!(entries[0].name, "file");
+    assert_eq!(entries[0].type_, DescriptorType::RegularFile);
+    assert_eq!(entries[1].name, "nested");
+    assert_eq!(entries[1].type_, DescriptorType::Directory);
+
+    assert_empty_dir(&nested).await;
+    drop(nested);
+
+    dir.unlink_file_at("file".to_string()).await.unwrap();
+    dir.remove_directory_at("nested".to_string()).await.unwrap();
+}
+
+async fn test_readdir_lots(dir: &Descriptor) {
+    for count in 0..1000 {
+        dir.open_at(
+            PathFlags::empty(),
+            format!("file.{count}"),
+            OpenFlags::CREATE,
+            DescriptorFlags::READ | DescriptorFlags::WRITE,
+        )
+        .await
+        .expect("failed to create file");
+    }
+
+    assert_eq!(read_dir(dir).await.len(), 1000);
+
+    for count in 0..1000 {
+        dir.unlink_file_at(format!("file.{count}"))
+            .await
+            .expect("removing a file");
+    }
+}
diff --git a/crates/wasi/src/p3/filesystem/host.rs b/crates/wasi/src/p3/filesystem/host.rs
index 763040934b88..eedbccc39687 100644
--- a/crates/wasi/src/p3/filesystem/host.rs
+++ b/crates/wasi/src/p3/filesystem/host.rs
@@ -6,7 +6,8 @@ use crate::p3::bindings::filesystem::types::{
 };
 use crate::p3::filesystem::{FilesystemError, FilesystemResult, preopens};
 use crate::p3::{
-    DEFAULT_BUFFER_CAPACITY, FutureOneshotProducer, FutureReadyProducer, StreamEmptyProducer,
+    DEFAULT_BUFFER_CAPACITY, FallibleIteratorProducer, FutureOneshotProducer, FutureReadyProducer,
+    StreamEmptyProducer,
 };
 use crate::{DirPerms, FilePerms};
 use anyhow::{Context as _, anyhow};
@@ -22,7 +23,7 @@ use tokio::task::{JoinHandle, spawn_blocking};
 use wasmtime::StoreContextMut;
 use wasmtime::component::{
     Accessor, Destination, FutureReader, Resource, ResourceTable, Source, StreamConsumer,
-    StreamProducer, StreamReader, StreamResult, VecBuffer,
+    StreamProducer, StreamReader, StreamResult,
 };
 
 fn get_descriptor<'a>(
@@ -291,150 +292,94 @@ fn map_dir_entry(
     }
 }
 
-struct BlockingDirectoryStreamProducer {
-    dir: Arc<Dir>,
+struct ReadDirStream {
+    rx: mpsc::Receiver<DirectoryEntry>,
+    task: JoinHandle<Result<(), ErrorCode>>,
     result: Option<oneshot::Sender<Result<(), ErrorCode>>>,
 }
 
-impl Drop for BlockingDirectoryStreamProducer {
-    fn drop(&mut self) {
-        self.close(Ok(()))
-    }
-}
-
-impl BlockingDirectoryStreamProducer {
-    fn close(&mut self, res: Result<(), ErrorCode>) {
-        if let Some(tx) = self.result.take() {
-            _ = tx.send(res);
-        }
-    }
-}
-
-impl<D> StreamProducer<D> for BlockingDirectoryStreamProducer {
-    type Item = DirectoryEntry;
-    type Buffer = VecBuffer<Self::Item>;
-
-    fn poll_produce<'a>(
-        mut self: Pin<&mut Self>,
-        _: &mut Context<'_>,
-        _: StoreContextMut<'a, D>,
-        mut dst: Destination<'a, Self::Item, Self::Buffer>,
-        _finish: bool,
-    ) -> Poll<wasmtime::Result<StreamResult>> {
-        let entries = match self.dir.entries() {
-            Ok(entries) => entries,
-            Err(err) => {
-                self.close(Err(err.into()));
-                return Poll::Ready(Ok(StreamResult::Dropped));
-            }
-        };
-        let res = match entries
-            .filter_map(|entry| map_dir_entry(entry).transpose())
-            .collect::<Result<Vec<_>, _>>()
-        {
-            Ok(entries) => {
-                dst.set_buffer(entries.into());
-                Ok(())
-            }
-            Err(err) => Err(err),
-        };
-        self.close(res);
-        Poll::Ready(Ok(StreamResult::Dropped))
-    }
-}
-
-struct NonblockingDirectoryStreamProducer(DirStreamState);
-
-enum DirStreamState {
-    Init {
+impl ReadDirStream {
+    fn new(
         dir: Arc<Dir>,
         result: oneshot::Sender<Result<(), ErrorCode>>,
-    },
-    InProgress {
-        rx: mpsc::Receiver<DirectoryEntry>,
-        task: JoinHandle<Result<(), ErrorCode>>,
-        result: oneshot::Sender<Result<(), ErrorCode>>,
-    },
-    Closed,
-}
-
-impl Drop for NonblockingDirectoryStreamProducer {
-    fn drop(&mut self) {
-        self.close(Ok(()))
+    ) -> ReadDirStream {
+        let (tx, rx) = mpsc::channel(1);
+        ReadDirStream {
+            task: spawn_blocking(move || {
+                let entries = dir.entries()?;
+                for entry in entries {
+                    if let Some(entry) = map_dir_entry(entry)? {
+                        if let Err(_) = tx.blocking_send(entry) {
+                            break;
+                        }
+                    }
+                }
+                Ok(())
+            }),
+            rx,
+            result: Some(result),
+        }
     }
-}
 
-impl NonblockingDirectoryStreamProducer {
     fn close(&mut self, res: Result<(), ErrorCode>) {
-        if let DirStreamState::Init { result, .. } | DirStreamState::InProgress { result, .. } =
-            mem::replace(&mut self.0, DirStreamState::Closed)
-        {
-            _ = result.send(res);
-        }
+        self.rx.close();
+        self.task.abort();
+        let _ = self.result.take().unwrap().send(res);
     }
 }
 
-impl<D> StreamProducer<D> for NonblockingDirectoryStreamProducer {
+impl<D> StreamProducer<D> for ReadDirStream {
     type Item = DirectoryEntry;
     type Buffer = Option<Self::Item>;
 
     fn poll_produce<'a>(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-        store: StoreContextMut<'a, D>,
+        mut store: StoreContextMut<'a, D>,
         mut dst: Destination<'a, Self::Item, Self::Buffer>,
         finish: bool,
     ) -> Poll<wasmtime::Result<StreamResult>> {
-        match mem::replace(&mut self.0, DirStreamState::Closed) {
-            DirStreamState::Init { .. } if finish => Poll::Ready(Ok(StreamResult::Cancelled)),
-            DirStreamState::Init { dir, result } => {
-                let (entry_tx, entry_rx) = mpsc::channel(1);
-                let task = spawn_blocking(move || {
-                    let entries = dir.entries()?;
-                    for entry in entries {
-                        if let Some(entry) = map_dir_entry(entry)? {
-                            if let Err(_) = entry_tx.blocking_send(entry) {
-                                break;
-                            }
-                        }
-                    }
-                    Ok(())
-                });
-                self.0 = DirStreamState::InProgress {
-                    rx: entry_rx,
-                    task,
-                    result,
-                };
-                self.poll_produce(cx, store, dst, finish)
+        // If this is a 0-length read then `mpsc::Receiver` does not expose an
+        // API to wait for an item to be available without taking it out of the
+        // channel. In lieu of that just say that we're complete and ready for a
+        // read.
+        if dst.remaining(&mut store) == Some(0) {
+            return Poll::Ready(Ok(StreamResult::Completed));
+        }
+
+        match self.rx.poll_recv(cx) {
+            // If an item is on the channel then send that along and say that
+            // the read is now complete with one item being yielded.
+            Poll::Ready(Some(item)) => {
+                dst.set_buffer(Some(item));
+                Poll::Ready(Ok(StreamResult::Completed))
             }
-            DirStreamState::InProgress {
-                mut rx,
-                mut task,
-                result,
-            } => {
-                let Poll::Ready(res) = rx.poll_recv(cx) else {
-                    self.0 = DirStreamState::InProgress { rx, task, result };
-                    if finish {
-                        return Poll::Ready(Ok(StreamResult::Cancelled));
-                    }
-                    return Poll::Pending;
-                };
-                match res {
-                    Some(entry) => {
-                        self.0 = DirStreamState::InProgress { rx, task, result };
-                        dst.set_buffer(Some(entry));
-                        Poll::Ready(Ok(StreamResult::Completed))
-                    }
-                    None => {
-                        let res = ready!(Pin::new(&mut task).poll(cx))
-                            .context("failed to join I/O task")?;
-                        self.0 = DirStreamState::InProgress { rx, task, result };
-                        self.close(res);
-                        Poll::Ready(Ok(StreamResult::Dropped))
-                    }
-                }
+
+            // If there's nothing left on the channel then that means that an
+            // error occurred or the iterator is done. In both cases an
+            // un-cancellable wait for the spawned task is entered and we await
+            // its completion. Upon completion there our own stream is closed
+            // with the result (sending an error code on our oneshot) and then
+            // the stream is reported as dropped.
+            Poll::Ready(None) => {
+                let result = ready!(Pin::new(&mut self.task).poll(cx))
+                    .expect("spawned task should not panic");
+                self.close(result);
+                Poll::Ready(Ok(StreamResult::Dropped))
             }
-            DirStreamState::Closed => Poll::Ready(Ok(StreamResult::Dropped)),
+
+            // If an item isn't ready yet then cancel this outstanding request
+            // if `finish` is set, otherwise propagate the `Pending` status.
+            Poll::Pending if finish => Poll::Ready(Ok(StreamResult::Cancelled)),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+impl Drop for ReadDirStream {
+    fn drop(&mut self) {
+        if self.result.is_some() {
+            self.close(Ok(()));
         }
     }
 }
@@ -848,23 +793,22 @@ impl types::HostDescriptorWithStore for WasiFilesystem {
         let dir = Arc::clone(dir.as_dir());
         let (result_tx, result_rx) = oneshot::channel();
         let stream = if allow_blocking_current_thread {
-            StreamReader::new(
-                instance,
-                &mut store,
-                BlockingDirectoryStreamProducer {
-                    dir,
-                    result: Some(result_tx),
-                },
-            )
+            match dir.entries() {
+                Ok(readdir) => StreamReader::new(
+                    instance,
+                    &mut store,
+                    FallibleIteratorProducer::new(
+                        readdir.filter_map(|e| map_dir_entry(e).transpose()),
+                        result_tx,
+                    ),
+                ),
+                Err(e) => {
+                    result_tx.send(Err(e.into())).unwrap();
+                    StreamReader::new(instance, &mut store, StreamEmptyProducer::default())
+                }
+            }
         } else {
-            StreamReader::new(
-                instance,
-                &mut store,
-                NonblockingDirectoryStreamProducer(DirStreamState::Init {
-                    dir,
-                    result: result_tx,
-                }),
-            )
+            StreamReader::new(instance, &mut store, ReadDirStream::new(dir, result_tx))
         };
         Ok((
             stream,
diff --git a/crates/wasi/src/p3/mod.rs b/crates/wasi/src/p3/mod.rs
index 67c2b6699c10..f3bab1d63217 100644
--- a/crates/wasi/src/p3/mod.rs
+++ b/crates/wasi/src/p3/mod.rs
@@ -24,7 +24,7 @@ use core::task::{Context, Poll};
 use tokio::sync::oneshot;
 use wasmtime::StoreContextMut;
 use wasmtime::component::{
-    Accessor, Destination, FutureProducer, Linker, StreamProducer, StreamResult,
+    Accessor, Destination, FutureProducer, Linker, StreamProducer, StreamResult, VecBuffer,
 };
 
 // Default buffer capacity to use for reads of byte-sized values.
@@ -79,6 +79,100 @@ where
     }
 }
 
+/// Helper structure to convert an iterator of `Result<T, E>` into a `stream<T>`
+/// plus a `future<result<_, E>>` in WIT.
+///
+/// This will drain the iterator on calls to `poll_produce` and place as many
+/// items as the input buffer has capacity for into the result. This will avoid
+/// doing anything if the async read is cancelled.
+///
+/// Note that this does not actually do anything async, it's assuming that the
+/// internal `iter` is either fast or intended to block.
+struct FallibleIteratorProducer<I, E> {
+    iter: I,
+    result: Option<oneshot::Sender<Result<(), E>>>,
+}
+
+impl<D, I, T, E> StreamProducer<D> for FallibleIteratorProducer<I, E>
+where
+    I: Iterator<Item = Result<T, E>> + Send + Unpin + 'static,
+    T: Send + Sync + 'static,
+    E: Send + 'static,
+{
+    type Item = T;
+    type Buffer = VecBuffer<Self::Item>;
+
+    fn poll_produce<'a>(
+        mut self: Pin<&mut Self>,
+        _: &mut Context<'_>,
+        mut store: StoreContextMut<'a, D>,
+        mut dst: Destination<'a, Self::Item, Self::Buffer>,
+        // Explicitly ignore `_finish` because this implementation never
+        // returns `Poll::Pending` anyway meaning that it never "blocks" in the
+        // async sense.
+        _finish: bool,
+    ) -> Poll<wasmtime::Result<StreamResult>> {
+        // Take up to `count` items as requested by the guest, or pick some
+        // reasonable-ish number for the host.
+        let count = dst.remaining(&mut store).unwrap_or(32);
+
+        // Handle 0-length reads which test for readiness as saying "we're
+        // always ready" since, in theory, this is.
+        if count == 0 {
+            return Poll::Ready(Ok(StreamResult::Completed));
+        }
+
+        // Drain `self.iter`. Successful results go into `buf`. Any errors make
+        // their way to the `oneshot` result inside this structure. Otherwise
+        // this only gets dropped if `None` is seen or an error. Also this'll
+        // terminate once `buf` grows too large.
+        let mut buf = Vec::new();
+        let result = loop {
+            match self.iter.next() {
+                Some(Ok(item)) => buf.push(item),
+                Some(Err(e)) => {
+                    self.close(Err(e));
+                    break StreamResult::Dropped;
+                }
+
+                None => {
+                    self.close(Ok(()));
+                    break StreamResult::Dropped;
+                }
+            }
+            if buf.len() >= count {
+                break StreamResult::Completed;
+            }
+        };
+
+        dst.set_buffer(buf.into());
+        return Poll::Ready(Ok(result));
+    }
+}
+
+impl<I, E> FallibleIteratorProducer<I, E> {
+    fn new(iter: I, result: oneshot::Sender<Result<(), E>>) -> Self {
+        Self {
+            iter,
+            result: Some(result),
+        }
+    }
+
+    fn close(&mut self, result: Result<(), E>) {
+        // Ignore send failures because it means the other end wasn't interested
+        // in the final error, if any.
+        let _ = self.result.take().unwrap().send(result);
+    }
+}
+
+impl<I, E> Drop for FallibleIteratorProducer<I, E> {
+    fn drop(&mut self) {
+        if self.result.is_some() {
+            self.close(Ok(()));
+        }
+    }
+}
+
 /// Add all WASI interfaces from this module into the `linker` provided.
 ///
 /// This function will add all interfaces implemented by this module to the
diff --git a/crates/wasi/tests/all/p3/mod.rs b/crates/wasi/tests/all/p3/mod.rs
index 4efc033968ba..d9a0d04915f6 100644
--- a/crates/wasi/tests/all/p3/mod.rs
+++ b/crates/wasi/tests/all/p3/mod.rs
@@ -7,6 +7,13 @@ use wasmtime::component::{Component, Linker};
 use wasmtime_wasi::p3::bindings::Command;
 
 async fn run(path: &str) -> Result<()> {
+    run_allow_blocking_current_thread(path, false).await
+}
+
+async fn run_allow_blocking_current_thread(
+    path: &str,
+    allow_blocking_current_thread: bool,
+) -> Result<()> {
     let path = Path::new(path);
     let name = path.file_stem().unwrap().to_str().unwrap();
     let engine = test_programs_artifacts::engine(|config| {
@@ -20,7 +27,9 @@ async fn run(path: &str) -> Result<()> {
     wasmtime_wasi::p3::add_to_linker(&mut linker).context("failed to link `wasi:cli@0.3.x`")?;
 
     let (mut store, _td) = Ctx::new(&engine, name, |builder| MyWasiCtx {
-        wasi: builder.build(),
+        wasi: builder
+            .allow_blocking_current_thread(allow_blocking_current_thread)
+            .build(),
         table: Default::default(),
     })?;
     let component = Component::from_file(&engine, path)?;
@@ -118,3 +127,13 @@ async fn p3_sockets_udp_sockopts() -> anyhow::Result<()> {
 async fn p3_sockets_udp_states() -> anyhow::Result<()> {
     run(P3_SOCKETS_UDP_STATES_COMPONENT).await
 }
+
+#[test_log::test(tokio::test(flavor = "multi_thread"))]
+async fn p3_readdir() -> anyhow::Result<()> {
+    run(P3_READDIR_COMPONENT).await
+}
+
+#[test_log::test(tokio::test(flavor = "multi_thread"))]
+async fn p3_readdir_blocking() -> anyhow::Result<()> {
+    run_allow_blocking_current_thread(P3_READDIR_COMPONENT, true).await
+}