@@ -819,6 +819,93 @@ async def client(addr):
819819 self .assertEqual (msg1 , b"hello world 1!\n " )
820820 self .assertEqual (msg2 , b"hello world 2!\n " )
821821
822+ def _test_start_tls_buffered_data (self , send_combined ):
823+ # gh-142352: test start_tls() with buffered data
824+
825+ PROXY_LINE = b"PROXY TCP4 127.0.0.1 127.0.0.1 54321 443\r \n "
826+ TEST_MESSAGE = b"hello world\n "
827+
828+ async def pipe (src , dst ):
829+ try :
830+ while data := await src .read (4096 ):
831+ dst .write (data )
832+ await dst .drain ()
833+ finally :
834+ dst .close ()
835+ await dst .wait_closed ()
836+
837+ async def proxy_handler (client_reader , client_writer , backend_addr ):
838+ backend_reader , backend_writer = await asyncio .open_connection (
839+ * backend_addr )
840+ try :
841+ tls_data = await client_reader .read (4096 )
842+ if send_combined :
843+ backend_writer .write (PROXY_LINE + tls_data )
844+ else :
845+ backend_writer .write (PROXY_LINE )
846+ await backend_writer .drain ()
847+ await asyncio .sleep (0.01 )
848+ backend_writer .write (tls_data )
849+ await backend_writer .drain ()
850+
851+ await asyncio .gather (
852+ pipe (client_reader , backend_writer ),
853+ pipe (backend_reader , client_writer ),
854+ )
855+ finally :
856+ client_writer .close ()
857+ backend_writer .close ()
858+ await asyncio .gather (
859+ client_writer .wait_closed (),
860+ backend_writer .wait_closed (),
861+ return_exceptions = True
862+ )
863+
864+ async def server_handler (client_reader , client_writer ):
865+ self .assertEqual (await client_reader .readline (), PROXY_LINE )
866+ await client_writer .start_tls (test_utils .simple_server_sslcontext ())
867+ self .assertEqual (await client_reader .readline (), TEST_MESSAGE )
868+ client_writer .close ()
869+ await client_writer .wait_closed ()
870+
871+ async def client (addr ):
872+ _ , writer = await asyncio .open_connection (* addr )
873+ await writer .start_tls (test_utils .simple_client_sslcontext ())
874+ writer .write (TEST_MESSAGE )
875+ await writer .drain ()
876+ writer .close ()
877+ await writer .wait_closed ()
878+
879+ async def run_test ():
880+ server = await asyncio .start_server (
881+ server_handler , socket_helper .HOSTv4 , 0 )
882+ server_addr = server .sockets [0 ].getsockname ()
883+
884+ proxy = await asyncio .start_server (
885+ lambda r , w : proxy_handler (r , w , server_addr ),
886+ socket_helper .HOSTv4 , 0 )
887+ proxy_addr = proxy .sockets [0 ].getsockname ()
888+
889+ await asyncio .wait_for (client (proxy_addr ), timeout = 5.0 )
890+ proxy .close ()
891+ server .close ()
892+ await asyncio .gather (proxy .wait_closed (), server .wait_closed ())
893+
894+ messages = []
895+ self .loop .set_exception_handler (lambda loop , ctx : messages .append (ctx ))
896+ self .loop .run_until_complete (run_test ())
897+ self .assertEqual (messages , [])
898+
899+ @unittest .skipIf (ssl is None , 'No ssl module' )
900+ def test_start_tls_buffered_data_combined (self ):
901+ # gh-142352: Test TLS data buffered before start_tls
902+ self ._test_start_tls_buffered_data (send_combined = True )
903+
904+ @unittest .skipIf (ssl is None , 'No ssl module' )
905+ def test_start_tls_buffered_data_separate (self ):
906+ # gh-142352: Test TLS data sent separately
907+ self ._test_start_tls_buffered_data (send_combined = False )
908+
822909 def test_streamreader_constructor_without_loop (self ):
823910 with self .assertRaisesRegex (RuntimeError , 'no current event loop' ):
824911 asyncio .StreamReader ()
0 commit comments