From 7e4abe4f73e32a118332524c7860c841d0645298 Mon Sep 17 00:00:00 2001
From: Diretnan Domnan
Date: Fri, 13 Feb 2026 13:33:04 +0100
Subject: [PATCH] H1 servers can shut down connections with pending unflushed
 data for slower clients

---
 Cargo.toml                          |   5 +
 src/proto/h1/dispatch.rs            |  11 +-
 tests/h1_shutdown_while_buffered.rs | 206 ++++++++++++++++++++++++++++
 3 files changed, 221 insertions(+), 1 deletion(-)
 create mode 100644 tests/h1_shutdown_while_buffered.rs

diff --git a/Cargo.toml b/Cargo.toml
index 4441bdcdea..ca4b55b87d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -243,3 +243,8 @@ required-features = ["full"]
 name = "server"
 path = "tests/server.rs"
 required-features = ["full"]
+
+[[test]]
+name = "h1_shutdown_while_buffered"
+path = "tests/h1_shutdown_while_buffered.rs"
+required-features = ["full"]
diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs
index 5daeb5ebf6..240a0985c7 100644
--- a/src/proto/h1/dispatch.rs
+++ b/src/proto/h1/dispatch.rs
@@ -171,7 +171,16 @@ where
         for _ in 0..16 {
             let _ = self.poll_read(cx)?;
             let _ = self.poll_write(cx)?;
-            let _ = self.poll_flush(cx)?;
+            let flush_result = self.poll_flush(cx)?;
+
+            // Check if flush is still pending before exiting poll_loop.
+            // If we have buffered data that needs to be written, we should return Poll::Pending to
+            // allow the buffer to drain, otherwise poll_shutdown may be called prematurely with
+            // data still buffered. This should also be a no-op for Unbuffered streams as
+            // `flush_result` should always be Ready.
+            if flush_result.is_pending() {
+                return Poll::Pending;
+            }
 
             // This could happen if reading paused before blocking on IO,
             // such as getting to the end of a framed message, but then
diff --git a/tests/h1_shutdown_while_buffered.rs b/tests/h1_shutdown_while_buffered.rs
new file mode 100644
index 0000000000..ca60a6d47f
--- /dev/null
+++ b/tests/h1_shutdown_while_buffered.rs
@@ -0,0 +1,206 @@
+// Test: Ensures poll_shutdown() is never called with buffered data
+//
+// Reproduces a rare timing bug where an HTTP/1.1 server calls shutdown() on a socket while
+// response data is still buffered (not flushed), leading to data loss.
+//
+// Scenario:
+// 1. Request fully received and read.
+// 2. Server computes a "large" response with Full::new()
+// 3. Socket accepts only a chunk of the response and then pends
+// 4. Flush returns Pending (remaining data still buffered), result ignored
+// 5. self.conn.wants_read_again() is false and poll_loop returns Ready
+// 6. BUG: poll_shutdown called prematurely and the buffered body is lost
+// 7. FIX: poll_loop checks the flush result and returns Pending, giving poll_loop the chance
+//    to run again

+use std::{
+    pin::Pin,
+    sync::{Arc, Mutex},
+    task::Poll,
+    time::Duration,
+};
+
+use bytes::Bytes;
+use http::{Request, Response};
+use http_body_util::Full;
+use hyper::{body::Incoming, service::service_fn};
+use support::TokioIo;
+use tokio::{
+    io::{AsyncRead, AsyncWrite},
+    net::{TcpListener, TcpStream},
+    time::{sleep, timeout},
+};
+mod support;
+
+#[derive(Debug, Default)]
+struct PendingStreamStatistics {
+    bytes_written: usize,
+    total_attempted: usize,
+    shutdown_called_with_buffered: bool,
+    buffered_at_shutdown: usize,
+}
+
+// Wrapper stream that performs a single partial write and then pends perpetually
+struct PendingStream {
+    inner: TcpStream,
+    // Tracks how many times poll_write has been entered so that only the first call
+    // writes anything out
+    write_count: usize,
+    // Only write this many bytes of the full buffer on that first call
+    write_chunk_size: usize,
+    stats: Arc<Mutex<PendingStreamStatistics>>,
+}
+
+impl PendingStream {
+    fn new(
+        inner: TcpStream,
+        write_chunk_size: usize,
+        stats: Arc<Mutex<PendingStreamStatistics>>,
+    ) -> Self {
+        Self {
+            inner,
+            stats,
+            write_chunk_size,
+            write_count: 0,
+        }
+    }
+}
+
+impl AsyncRead for PendingStream {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        Pin::new(&mut self.inner).poll_read(cx, buf)
+    }
+}
+
+impl AsyncWrite for PendingStream {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> Poll<std::io::Result<usize>> {
+        self.write_count += 1;
+
+        let mut stats = self.stats.lock().unwrap();
+        stats.total_attempted += buf.len();
+
+        if self.write_count == 1 {
+            // First write: partial only
+            let partial = std::cmp::min(buf.len(), self.write_chunk_size);
+            drop(stats);
+
+            let result = Pin::new(&mut self.inner).poll_write(cx, &buf[..partial]);
+            if let Poll::Ready(Ok(n)) = result {
+                self.stats.lock().unwrap().bytes_written += n;
+            }
+            return result;
+        }
+
+        // Block all further writes to simulate a pending buffer
+        Poll::Pending
+    }
+
+    fn poll_shutdown(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let mut stats = self.stats.lock().unwrap();
+        let buffered = stats.total_attempted - stats.bytes_written;
+
+        if buffered > 0 {
+            eprintln!(
+                "\n❌BUG: shutdown() called with {} bytes buffered",
+                buffered
+            );
+            stats.shutdown_called_with_buffered = true;
+            stats.buffered_at_shutdown = buffered;
+        }
+        drop(stats);
+        Pin::new(&mut self.inner).poll_shutdown(cx)
+    }
+
+    fn poll_flush(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let stats = self.stats.lock().unwrap();
+        let buffered = stats.total_attempted - stats.bytes_written;
+
+        if buffered > 0 {
+            return Poll::Pending;
+        }
+
+        drop(stats);
+        Pin::new(&mut self.inner).poll_flush(cx)
+    }
+}
+
+// The test does not check that the connection completed successfully; it only checks that
+// shutdown was not called while data was still sitting in hyper's internal write buffer
+#[tokio::test]
+async fn test_no_premature_shutdown_while_buffered() {
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let addr = listener.local_addr().unwrap();
+    let stats = Arc::new(Mutex::new(PendingStreamStatistics::default()));
+
+    let stats_clone = stats.clone();
+    let server = tokio::spawn(async move {
+        let (stream, _) = listener.accept().await.unwrap();
+        let pending_stream = PendingStream::new(stream, 212_992, stats_clone);
+        let io = TokioIo::new(pending_stream);
+
+        let service = service_fn(|_req: Request<Incoming>| async move {
+            // A Full response larger than write_chunk_size
+            let body = Full::new(Bytes::from(vec![b'X'; 500_000]));
+            Ok::<_, hyper::Error>(Response::new(body))
+        });
+
+        hyper::server::conn::http1::Builder::new()
+            .serve_connection(io, service)
+            .await
+    });
+
+    // Wait for the server to be ready
+    sleep(Duration::from_millis(50)).await;
+
+    // Client sends a request
+    tokio::spawn(async move {
+        let mut stream = TcpStream::connect(addr).await.unwrap();
+
+        use tokio::io::AsyncWriteExt;
+
+        stream
+            .write_all(
+                b"POST / HTTP/1.1\r\n\
+                  Host: localhost\r\n\
+                  Transfer-Encoding: chunked\r\n\
+                  \r\n",
+            )
+            .await
+            .unwrap();
+
+        // Chunk size is hex: "Hello World" is 11 bytes, i.e. 0xB
+        stream.write_all(b"B\r\nHello World\r\n").await.unwrap();
+        stream.write_all(b"0\r\n\r\n").await.unwrap();
+        stream.flush().await.unwrap();
+
+        // keep the connection open
+        sleep(Duration::from_secs(2)).await;
+    });
+
+    // Wait for completion
+    let result = timeout(Duration::from_millis(900), server).await;
+
+    let stats = stats.lock().unwrap();
+
+    assert!(
+        !stats.shutdown_called_with_buffered,
+        "shutdown() called with {} bytes still buffered (wrote {} of {} bytes)",
+        stats.buffered_at_shutdown, stats.bytes_written, stats.total_attempted
+    );
+    if let Ok(Ok(conn_result)) = result {
+        conn_result.ok();
+    }
+}
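
Reviewer note (not part of the patch): the sketch below is a standalone, simplified stand-in
for the dispatch.rs change. `FakeConn`, `flush_per_poll`, and this `poll_loop` are hypothetical
names invented for illustration, not hyper's internals; the byte counts are borrowed from the
test above. It only shows why surfacing a pending flush keeps the caller polling instead of
proceeding to shutdown while bytes remain buffered.

use std::task::Poll;

/// Hypothetical stand-in for a connection with an internal write buffer.
struct FakeConn {
    /// Bytes accepted by an earlier write but not yet flushed to the socket.
    buffered: usize,
    /// Bytes the simulated socket drains per flush attempt.
    flush_per_poll: usize,
}

impl FakeConn {
    fn poll_flush(&mut self) -> Poll<()> {
        let drained = self.flush_per_poll.min(self.buffered);
        self.buffered -= drained;
        if self.buffered == 0 {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

/// Simplified version of the patched loop: a pending flush is surfaced to the
/// caller instead of being discarded with `let _ = ...`.
fn poll_loop(conn: &mut FakeConn) -> Poll<()> {
    for _ in 0..16 {
        let flush_result = conn.poll_flush();
        if flush_result.is_pending() {
            // Not done yet: the caller must poll again later rather than
            // moving on to poll_shutdown with data still buffered.
            return Poll::Pending;
        }
    }
    Poll::Ready(())
}

fn main() {
    let mut conn = FakeConn {
        buffered: 500_000,
        flush_per_poll: 212_992,
    };
    // Early polls report Pending while the buffer drains; only once it is
    // empty does the loop report Ready, making shutdown safe.
    while poll_loop(&mut conn).is_pending() {
        println!("{} bytes still buffered, not safe to shut down", conn.buffered);
    }
    println!("buffer drained, safe to shut down");
}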