From 2c99fe196e5b206700ab2ff5a1855c8019594ab4 Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Mon, 3 Nov 2025 17:48:44 +0100 Subject: [PATCH 1/2] fix `Request::into_http`, use in `handle` Signed-off-by: Roman Volosatovs --- crates/wasi-http/src/p3/body.rs | 146 ++++++++++++++ crates/wasi-http/src/p3/host/handler.rs | 251 +----------------------- crates/wasi-http/src/p3/request.rs | 94 +++++---- 3 files changed, 205 insertions(+), 286 deletions(-) diff --git a/crates/wasi-http/src/p3/body.rs b/crates/wasi-http/src/p3/body.rs index 54e93d61f667..c804307f78cf 100644 --- a/crates/wasi-http/src/p3/body.rs +++ b/crates/wasi-http/src/p3/body.rs @@ -541,3 +541,149 @@ where } } } + +/// A wrapper around [http_body::Body], which allows attaching arbitrary state to it +pub(crate) struct BodyWithState { + body: T, + _state: U, +} + +impl http_body::Body for BodyWithState +where + T: http_body::Body + Unpin, + U: Unpin, +{ + type Data = T::Data; + type Error = T::Error; + + #[inline] + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Pin::new(&mut self.get_mut().body).poll_frame(cx) + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.body.is_end_stream() + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.body.size_hint() + } +} + +/// A wrapper around [http_body::Body], which validates `Content-Length` +pub(crate) struct BodyWithContentLength { + body: T, + error_tx: Option>, + make_error: fn(Option) -> E, + /// Limit of bytes to be sent + limit: u64, + /// Number of bytes sent + sent: u64, +} + +impl BodyWithContentLength { + /// Sends the error constructed by [Self::make_error] on [Self::error_tx]. + /// Does nothing if an error has already been sent on [Self::error_tx]. + fn send_error(&mut self, sent: Option) -> Poll>> { + if let Some(error_tx) = self.error_tx.take() { + _ = error_tx.send((self.make_error)(sent)); + } + Poll::Ready(Some(Err((self.make_error)(sent)))) + } +} + +impl http_body::Body for BodyWithContentLength +where + T: http_body::Body + Unpin, +{ + type Data = T::Data; + type Error = T::Error; + + #[inline] + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + match ready!(Pin::new(&mut self.as_mut().body).poll_frame(cx)) { + Some(Ok(frame)) => { + let Some(data) = frame.data_ref() else { + return Poll::Ready(Some(Ok(frame))); + }; + let Ok(sent) = data.len().try_into() else { + return self.send_error(None); + }; + let Some(sent) = self.sent.checked_add(sent) else { + return self.send_error(None); + }; + if sent > self.limit { + return self.send_error(Some(sent)); + } + self.sent = sent; + Poll::Ready(Some(Ok(frame))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None if self.limit != self.sent => { + // short write + let sent = self.sent; + self.send_error(Some(sent)) + } + None => Poll::Ready(None), + } + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.body.is_end_stream() + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + let n = self.limit.saturating_sub(self.sent); + let mut hint = self.body.size_hint(); + if hint.lower() >= n { + hint.set_exact(n) + } else if let Some(max) = hint.upper() { + hint.set_upper(n.min(max)) + } else { + hint.set_upper(n) + } + hint + } +} + +pub(crate) trait BodyExt { + fn with_state(self, state: T) -> BodyWithState + where + Self: Sized, + { + BodyWithState { + body: self, + _state: state, + } + } + + fn with_content_length( + self, + limit: u64, + error_tx: oneshot::Sender, + make_error: fn(Option) -> E, + ) -> BodyWithContentLength + where + Self: Sized, + { + BodyWithContentLength { + body: self, + error_tx: Some(error_tx), + make_error, + limit, + sent: 0, + } + } +} + +impl BodyExt for T {} diff --git a/crates/wasi-http/src/p3/host/handler.rs b/crates/wasi-http/src/p3/host/handler.rs index 7918d8e33352..44f2c6aff038 100644 --- a/crates/wasi-http/src/p3/host/handler.rs +++ b/crates/wasi-http/src/p3/host/handler.rs @@ -1,14 +1,9 @@ -use crate::get_content_length; use crate::p3::bindings::http::handler::{Host, HostWithStore}; use crate::p3::bindings::http::types::{ErrorCode, Request, Response}; -use crate::p3::body::{Body, GuestBody}; +use crate::p3::body::{Body, BodyExt as _}; use crate::p3::{HttpError, HttpResult, WasiHttp, WasiHttpCtxView}; use anyhow::Context as _; -use bytes::Bytes; -use core::pin::Pin; -use core::task::{Context, Poll, Waker, ready}; -use http::header::HOST; -use http::{HeaderValue, Uri}; +use core::task::{Context, Poll, Waker}; use http_body_util::BodyExt as _; use std::sync::Arc; use tokio::sync::oneshot; @@ -26,152 +21,6 @@ impl Drop for AbortOnDropJoinHandle { } } -/// A wrapper around [http_body::Body], which allows attaching arbitrary state to it -struct BodyWithState { - body: T, - _state: U, -} - -impl http_body::Body for BodyWithState -where - T: http_body::Body + Unpin, - U: Unpin, -{ - type Data = T::Data; - type Error = T::Error; - - #[inline] - fn poll_frame( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - Pin::new(&mut self.get_mut().body).poll_frame(cx) - } - - #[inline] - fn is_end_stream(&self) -> bool { - self.body.is_end_stream() - } - - #[inline] - fn size_hint(&self) -> http_body::SizeHint { - self.body.size_hint() - } -} - -/// A wrapper around [http_body::Body], which validates `Content-Length` -struct BodyWithContentLength { - body: T, - error_tx: Option>, - make_error: fn(Option) -> E, - /// Limit of bytes to be sent - limit: u64, - /// Number of bytes sent - sent: u64, -} - -impl BodyWithContentLength { - /// Sends the error constructed by [Self::make_error] on [Self::error_tx]. - /// Does nothing if an error has already been sent on [Self::error_tx]. - fn send_error(&mut self, sent: Option) -> Poll>> { - if let Some(error_tx) = self.error_tx.take() { - _ = error_tx.send((self.make_error)(sent)); - } - Poll::Ready(Some(Err((self.make_error)(sent)))) - } -} - -impl http_body::Body for BodyWithContentLength -where - T: http_body::Body + Unpin, -{ - type Data = T::Data; - type Error = T::Error; - - #[inline] - fn poll_frame( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - match ready!(Pin::new(&mut self.as_mut().body).poll_frame(cx)) { - Some(Ok(frame)) => { - let Some(data) = frame.data_ref() else { - return Poll::Ready(Some(Ok(frame))); - }; - let Ok(sent) = data.len().try_into() else { - return self.send_error(None); - }; - let Some(sent) = self.sent.checked_add(sent) else { - return self.send_error(None); - }; - if sent > self.limit { - return self.send_error(Some(sent)); - } - self.sent = sent; - Poll::Ready(Some(Ok(frame))) - } - Some(Err(err)) => Poll::Ready(Some(Err(err))), - None if self.limit != self.sent => { - // short write - let sent = self.sent; - self.send_error(Some(sent)) - } - None => Poll::Ready(None), - } - } - - #[inline] - fn is_end_stream(&self) -> bool { - self.body.is_end_stream() - } - - #[inline] - fn size_hint(&self) -> http_body::SizeHint { - let n = self.limit.saturating_sub(self.sent); - let mut hint = self.body.size_hint(); - if hint.lower() >= n { - hint.set_exact(n) - } else if let Some(max) = hint.upper() { - hint.set_upper(n.min(max)) - } else { - hint.set_upper(n) - } - hint - } -} - -trait BodyExt { - fn with_state(self, state: T) -> BodyWithState - where - Self: Sized, - { - BodyWithState { - body: self, - _state: state, - } - } - - fn with_content_length( - self, - limit: u64, - error_tx: oneshot::Sender, - make_error: fn(Option) -> E, - ) -> BodyWithContentLength - where - Self: Sized, - { - BodyWithContentLength { - body: self, - error_tx: Some(error_tx), - make_error, - limit, - sent: 0, - } - } -} - -impl BodyExt for T {} - async fn io_task_result( rx: oneshot::Receiver<( Arc, @@ -203,102 +52,14 @@ impl HostWithStore for WasiHttp { let getter = store.getter(); let fut = store.with(|mut store| { let WasiHttpCtxView { table, .. } = store.get(); - let Request { - method, - scheme, - authority, - path_with_query, - headers, - options, - body, - } = table + let req = table .delete(req) .context("failed to delete request from table") .map_err(HttpError::trap)?; - // `Content-Length` header value is validated in `fields` implementation - let content_length = match get_content_length(&headers) { - Ok(content_length) => content_length, - Err(err) => { - body.drop(&mut store); - return Err(ErrorCode::InternalError(Some(format!("{err:#}"))).into()); - } - }; - let mut headers = Arc::unwrap_or_clone(headers); - let body = match body { - Body::Guest { - contents_rx, - trailers_rx, - result_tx, - } => GuestBody::new( - &mut store, - contents_rx, - trailers_rx, - result_tx, - io_task_result(io_result_rx), - content_length, - ErrorCode::HttpRequestBodySize, - getter, - ) - .with_state(io_task_rx) - .boxed_unsync(), - Body::Host { body, result_tx } => { - if let Some(limit) = content_length { - let (http_result_tx, http_result_rx) = oneshot::channel(); - _ = result_tx.send(Box::new(async move { - if let Ok(err) = http_result_rx.await { - return Err(err); - }; - io_task_result(io_result_rx).await - })); - body.with_content_length( - limit, - http_result_tx, - ErrorCode::HttpRequestBodySize, - ) - .with_state(io_task_rx) - .boxed_unsync() - } else { - _ = result_tx.send(Box::new(io_task_result(io_result_rx))); - body.with_state(io_task_rx).boxed_unsync() - } - } - }; - - let WasiHttpCtxView { ctx, .. } = store.get(); - if ctx.set_host_header() { - let host = if let Some(authority) = authority.as_ref() { - HeaderValue::try_from(authority.as_str()) - .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))? - } else { - HeaderValue::from_static("") - }; - headers.insert(HOST, host); - } - let scheme = match scheme { - None => ctx.default_scheme().ok_or(ErrorCode::HttpProtocolError)?, - Some(scheme) if ctx.is_supported_scheme(&scheme) => scheme, - Some(..) => return Err(ErrorCode::HttpProtocolError.into()), - }; - let mut uri = Uri::builder().scheme(scheme); - if let Some(authority) = authority { - uri = uri.authority(authority) - }; - if let Some(path_with_query) = path_with_query { - uri = uri.path_and_query(path_with_query) - }; - let uri = uri.build().map_err(|err| { - debug!(?err, "failed to build request URI"); - ErrorCode::HttpRequestUriInvalid - })?; - let mut req = http::Request::builder(); - *req.headers_mut().unwrap() = headers; - let req = req - .method(method) - .uri(uri) - .body(body) - .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?; + let (req, options) = + req.into_http_with_getter(&mut store, io_task_result(io_result_rx), getter)?; HttpResult::Ok(store.get().ctx.send_request( - req, + req.map(|body| body.with_state(io_task_rx).boxed_unsync()), options.as_deref().copied(), Box::new(async { // Forward the response processing result to `WasiHttpCtx` implementation diff --git a/crates/wasi-http/src/p3/request.rs b/crates/wasi-http/src/p3/request.rs index 6d3792f2e77a..9f19b5614330 100644 --- a/crates/wasi-http/src/p3/request.rs +++ b/crates/wasi-http/src/p3/request.rs @@ -1,6 +1,6 @@ use crate::get_content_length; use crate::p3::bindings::http::types::ErrorCode; -use crate::p3::body::{Body, GuestBody}; +use crate::p3::body::{Body, BodyExt as _, GuestBody}; use crate::p3::{WasiHttpCtxView, WasiHttpView}; use bytes::Bytes; use core::time::Duration; @@ -15,7 +15,7 @@ use tracing::debug; use wasmtime::AsContextMut; /// The concrete type behind a `wasi:http/types.request-options` resource. -#[derive(Copy, Clone, Debug, Default)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] pub struct RequestOptions { /// How long to wait for a connection to be established. pub connect_timeout: Option, @@ -134,7 +134,13 @@ impl Request { self, store: impl AsContextMut, fut: impl Future> + Send + 'static, - ) -> wasmtime::Result>> { + ) -> Result< + ( + http::Request>, + Option>, + ), + ErrorCode, + > { self.into_http_with_getter(store, fut, T::http) } @@ -144,21 +150,28 @@ impl Request { mut store: impl AsContextMut, fut: impl Future> + Send + 'static, getter: fn(&mut T) -> WasiHttpCtxView<'_>, - ) -> wasmtime::Result>> { + ) -> Result< + ( + http::Request>, + Option>, + ), + ErrorCode, + > { let Request { method, scheme, authority, path_with_query, headers, - options: _, + options, body, } = self; + // `Content-Length` header value is validated in `fields` implementation let content_length = match get_content_length(&headers) { Ok(content_length) => content_length, Err(err) => { body.drop(&mut store); - return Err(ErrorCode::InternalError(Some(format!("{err:#}"))).into()); + return Err(ErrorCode::InternalError(Some(format!("{err:#}")))); } }; // This match must appear before any potential errors handled with '?' @@ -183,13 +196,25 @@ impl Request { ) .boxed_unsync(), Body::Host { body, result_tx } => { - _ = result_tx.send(Box::new(fut)); - body + if let Some(limit) = content_length { + let (http_result_tx, http_result_rx) = oneshot::channel(); + _ = result_tx.send(Box::new(async move { + if let Ok(err) = http_result_rx.await { + return Err(err); + }; + fut.await + })); + body.with_content_length(limit, http_result_tx, ErrorCode::HttpRequestBodySize) + .boxed_unsync() + } else { + _ = result_tx.send(Box::new(fut)); + body + } } }; let mut headers = Arc::unwrap_or_clone(headers); - let mut store_ctx = store.as_context_mut(); - let WasiHttpCtxView { ctx, table: _ } = getter(store_ctx.data_mut()); + let mut store = store.as_context_mut(); + let WasiHttpCtxView { ctx, .. } = getter(store.data_mut()); if ctx.set_host_header() { let host = if let Some(authority) = authority.as_ref() { HeaderValue::try_from(authority.as_str()) @@ -202,7 +227,7 @@ impl Request { let scheme = match scheme { None => ctx.default_scheme().ok_or(ErrorCode::HttpProtocolError)?, Some(scheme) if ctx.is_supported_scheme(&scheme) => scheme, - Some(..) => return Err(ErrorCode::HttpProtocolError.into()), + Some(..) => return Err(ErrorCode::HttpProtocolError), }; let mut uri = Uri::builder().scheme(scheme); if let Some(authority) = authority { @@ -216,21 +241,14 @@ impl Request { ErrorCode::HttpRequestUriInvalid })?; let mut req = http::Request::builder(); - if let Some(headers_mut) = req.headers_mut() { - *headers_mut = headers; - } else { - return Err(ErrorCode::InternalError(Some( - "failed to get mutable headers from request builder".to_string(), - )) - .into()); - } + *req.headers_mut().unwrap() = headers; let req = req .method(method) .uri(uri) .body(body) .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?; let (req, body) = req.into_parts(); - Ok(http::Request::from_parts(req, body)) + Ok((http::Request::from_parts(req, body), options)) } } @@ -465,20 +483,20 @@ pub async fn default_send_request( #[cfg(test)] mod tests { use super::*; - use crate::p3::WasiHttpCtx; + use crate::p3::DefaultWasiHttpCtx; use anyhow::Result; + use core::future::Future; + use core::pin::pin; + use core::str::FromStr; + use core::task::{Context, Poll, Waker}; use http_body_util::{BodyExt, Empty, Full}; - use std::future::Future; - use std::str::FromStr; - use std::task::{Context, Waker}; use wasmtime::{Engine, Store}; use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView}; - struct TestHttpCtx; struct TestCtx { table: ResourceTable, wasi: WasiCtx, - http: TestHttpCtx, + http: DefaultWasiHttpCtx, } impl TestCtx { @@ -486,7 +504,7 @@ mod tests { Self { table: ResourceTable::default(), wasi: WasiCtxBuilder::new().build(), - http: TestHttpCtx, + http: DefaultWasiHttpCtx, } } } @@ -500,8 +518,6 @@ mod tests { } } - impl WasiHttpCtx for TestHttpCtx {} - impl WasiHttpView for TestCtx { fn http(&mut self) -> WasiHttpCtxView<'_> { WasiHttpCtxView { @@ -529,7 +545,8 @@ mod tests { .boxed_unsync(), ); let mut store = Store::new(&engine, TestCtx::new()); - let http_req = req.into_http(&mut store, async { Ok(()) }).unwrap(); + let (http_req, options) = req.into_http(&mut store, async { Ok(()) }).unwrap(); + assert_eq!(options, None); assert_eq!(http_req.method(), Method::POST); let expected_scheme = scheme.unwrap_or(Scheme::HTTPS); // default scheme assert_eq!( @@ -541,11 +558,10 @@ mod tests { .unwrap() ); let body_bytes = http_req.into_body().collect().await?; - assert_eq!(*body_bytes.to_bytes(), *b"body"); + assert_eq!(body_bytes.to_bytes(), b"body".as_slice()); let mut cx = Context::from_waker(Waker::noop()); - let mut fut = Box::pin(fut); - let result = fut.as_mut().poll(&mut cx); - assert!(matches!(result, futures::task::Poll::Ready(Ok(())))); + let result = pin!(fut).poll(&mut cx); + assert!(matches!(result, Poll::Ready(Ok(())))); } Ok(()) @@ -566,16 +582,12 @@ mod tests { let result = req.into_http(&mut store, async { Err(ErrorCode::InternalError(Some("uh oh".to_string()))) }); - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err().downcast_ref::(), - Some(ErrorCode::HttpRequestUriInvalid) - )); + assert!(matches!(result, Err(ErrorCode::HttpRequestUriInvalid))); let mut cx = Context::from_waker(Waker::noop()); - let result = Box::pin(fut).as_mut().poll(&mut cx); + let result = pin!(fut).poll(&mut cx); assert!(matches!( result, - futures::task::Poll::Ready(Err(ErrorCode::InternalError(Some(_)))) + Poll::Ready(Err(ErrorCode::InternalError(Some(_)))) )); Ok(()) From 4db290aa34991702d3e62c9ffe799a25b9ff64d8 Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Mon, 3 Nov 2025 18:04:01 +0100 Subject: [PATCH 2/2] doc: use correct body trait ref Signed-off-by: Roman Volosatovs --- crates/wasi-http/src/p3/request.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/wasi-http/src/p3/request.rs b/crates/wasi-http/src/p3/request.rs index 9f19b5614330..00c56e787e23 100644 --- a/crates/wasi-http/src/p3/request.rs +++ b/crates/wasi-http/src/p3/request.rs @@ -125,7 +125,7 @@ impl Request { ) } - /// Convert this [`Request`] into an [`http::Request>`]. + /// Convert this [`Request`] into an [`http::Request>`]. /// /// The specified future `fut` can be used to communicate a request processing /// error, if any, back to the caller (e.g., if this request was constructed