Skip to content

Commit ef4570a

Browse files
committed
gh-142352: asyncio.streams: transfer buffered data to SSL layer in start_tls()
1 parent cf71e34 commit ef4570a

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

Lib/asyncio/base_events.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,14 @@ async def start_tls(self, transport, protocol, sslcontext, *,
13411341
ssl_shutdown_timeout=ssl_shutdown_timeout,
13421342
call_connection_made=False)
13431343

1344+
# gh-142352: move buffered StreamReader data to SSLProtocol
1345+
stream_reader = getattr(protocol, '_stream_reader', None)
1346+
if stream_reader is not None:
1347+
buffer = stream_reader._buffer
1348+
if buffer:
1349+
ssl_protocol._incoming.write(buffer)
1350+
buffer.clear()
1351+
13441352
# Pause early so that "ssl_protocol.data_received()" doesn't
13451353
# have a chance to get called before "ssl_protocol.connection_made()".
13461354
transport.pause_reading()

Lib/test/test_asyncio/test_streams.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,93 @@ async def client(addr):
819819
self.assertEqual(msg1, b"hello world 1!\n")
820820
self.assertEqual(msg2, b"hello world 2!\n")
821821

822+
def _test_start_tls_buffered_data(self, send_combined):
823+
# gh-142352: test start_tls() with buffered data
824+
825+
PROXY_LINE = b"PROXY TCP4 127.0.0.1 127.0.0.1 54321 443\r\n"
826+
TEST_MESSAGE = b"hello world\n"
827+
828+
async def pipe(src, dst):
829+
try:
830+
while data := await src.read(4096):
831+
dst.write(data)
832+
await dst.drain()
833+
finally:
834+
dst.close()
835+
await dst.wait_closed()
836+
837+
async def proxy_handler(client_reader, client_writer, backend_addr):
838+
backend_reader, backend_writer = await asyncio.open_connection(
839+
*backend_addr)
840+
try:
841+
tls_data = await client_reader.read(4096)
842+
if send_combined:
843+
backend_writer.write(PROXY_LINE + tls_data)
844+
else:
845+
backend_writer.write(PROXY_LINE)
846+
await backend_writer.drain()
847+
await asyncio.sleep(0.01)
848+
backend_writer.write(tls_data)
849+
await backend_writer.drain()
850+
851+
await asyncio.gather(
852+
pipe(client_reader, backend_writer),
853+
pipe(backend_reader, client_writer),
854+
)
855+
finally:
856+
client_writer.close()
857+
backend_writer.close()
858+
await asyncio.gather(
859+
client_writer.wait_closed(),
860+
backend_writer.wait_closed(),
861+
return_exceptions=True
862+
)
863+
864+
async def server_handler(client_reader, client_writer):
865+
self.assertEqual(await client_reader.readline(), PROXY_LINE)
866+
await client_writer.start_tls(test_utils.simple_server_sslcontext())
867+
self.assertEqual(await client_reader.readline(), TEST_MESSAGE)
868+
client_writer.close()
869+
await client_writer.wait_closed()
870+
871+
async def client(addr):
872+
_, writer = await asyncio.open_connection(*addr)
873+
await writer.start_tls(test_utils.simple_client_sslcontext())
874+
writer.write(TEST_MESSAGE)
875+
await writer.drain()
876+
writer.close()
877+
await writer.wait_closed()
878+
879+
async def run_test():
880+
server = await asyncio.start_server(
881+
server_handler, socket_helper.HOSTv4, 0)
882+
server_addr = server.sockets[0].getsockname()
883+
884+
proxy = await asyncio.start_server(
885+
lambda r, w: proxy_handler(r, w, server_addr),
886+
socket_helper.HOSTv4, 0)
887+
proxy_addr = proxy.sockets[0].getsockname()
888+
889+
await asyncio.wait_for(client(proxy_addr), timeout=5.0)
890+
proxy.close()
891+
server.close()
892+
await asyncio.gather(proxy.wait_closed(), server.wait_closed())
893+
894+
messages = []
895+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
896+
self.loop.run_until_complete(run_test())
897+
self.assertEqual(messages, [])
898+
899+
@unittest.skipIf(ssl is None, 'No ssl module')
900+
def test_start_tls_buffered_data_combined(self):
901+
# gh-142352: Test TLS data buffered before start_tls
902+
self._test_start_tls_buffered_data(send_combined=True)
903+
904+
@unittest.skipIf(ssl is None, 'No ssl module')
905+
def test_start_tls_buffered_data_separate(self):
906+
# gh-142352: Test TLS data sent separately
907+
self._test_start_tls_buffered_data(send_combined=False)
908+
822909
def test_streamreader_constructor_without_loop(self):
823910
with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
824911
asyncio.StreamReader()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Fix :meth:`asyncio.StreamWriter.start_tls` to transfer buffered data from
2+
:class:`~asyncio.StreamReader` to the SSL layer, preventing data loss when
3+
upgrading a connection to TLS mid-stream (e.g., when implementing PROXY
4+
protocol support).

0 commit comments

Comments
 (0)