22import abc
33import os
44import sys
5+ import threading
56import _collections_abc
67from collections import deque
78from functools import wraps
@@ -390,22 +391,76 @@ async def __aexit__(self, *exc_info):
390391 await self .thing .aclose ()
391392
392393
394+ class _PerThreadStream :
395+ def __init__ (self , default_stream ):
396+ self .default_stream = default_stream
397+ # each stack entry is (stream, thread_id). thread_id is None if
398+ # per_thread=False.
399+ self ._stack = []
400+ self ._lock = threading .Lock ()
401+
402+ @property
403+ def _current_stream (self ):
404+ thread_id = threading .get_ident ()
405+ # look for the most recent redirect which was either:
406+ # * per_thread=False
407+ # * per_thread=True, and in our thread
408+ #
409+ # If none match, fall back to the default stream.
410+ with self ._lock :
411+ for stream , entry_thread_id in reversed (self ._stack ):
412+ if entry_thread_id is None or entry_thread_id == thread_id :
413+ return stream
414+ return self .default_stream
415+
416+ def add_entry (self , entry ):
417+ with self ._lock :
418+ self ._stack .append (entry )
419+
420+ def remove_entry (self , entry ):
421+ # remove by identity, not equality, in case two streams compare equal
422+ with self ._lock :
423+ for i , e in enumerate (self ._stack ):
424+ if e is entry :
425+ del self ._stack [i ]
426+ return
427+
428+ def __getattr__ (self , name ):
429+ return getattr (self ._current_stream , name )
430+
431+
393432class _RedirectStream (AbstractContextManager ):
394433
395434 _stream = None
435+ _lock = None
436+ _stream_ref = None
396437
397- def __init__ (self , new_target ):
438+ def __init__ (self , new_target , * , per_thread = False ):
398439 self ._new_target = new_target
399- # We use a list of old targets to make this CM re-entrant
400- self ._old_targets = []
440+ self . _per_thread = per_thread
441+ self ._entries = [] # stack for reentrant usage
401442
402443 def __enter__ (self ):
403- self ._old_targets .append (getattr (sys , self ._stream ))
404- setattr (sys , self ._stream , self ._new_target )
444+ with self ._lock :
445+ if self ._stream_ref is None :
446+ type(self )._stream_ref = _PerThreadStream (getattr (sys , self ._stream ))
447+ setattr (sys , self ._stream , self ._stream_ref )
448+ entry = (
449+ self ._new_target ,
450+ threading .get_ident () if self ._per_thread else None ,
451+ )
452+ self ._entries .append (entry )
453+ self ._stream_ref .add_entry (entry )
454+
405455 return self ._new_target
406456
407457 def __exit__ (self , exctype , excinst , exctb ):
408- setattr (sys , self ._stream , self ._old_targets .pop ())
458+ with self ._lock :
459+ entry = self ._entries .pop ()
460+ self ._stream_ref .remove_entry (entry )
461+ if len (self ._stream_ref ._stack ) == 0 :
462+ setattr (sys , self ._stream , self ._stream_ref .default_stream )
463+ type(self )._stream_ref = None
409464
410465
411466class redirect_stdout (_RedirectStream ):
@@ -422,12 +477,16 @@ class redirect_stdout(_RedirectStream):
422477 """
423478
424479 _stream = "stdout"
480+ _lock = threading .Lock ()
481+ _stream_ref = None
425482
426483
427484class redirect_stderr (_RedirectStream ):
428485 """Context manager for temporarily redirecting stderr to another file."""
429486
430487 _stream = "stderr"
488+ _lock = threading .Lock ()
489+ _stream_ref = None
431490
432491
433492class suppress (AbstractContextManager ):
0 commit comments