Skip to content

Commit bf55f48

Browse files
committed
add redirect_stdout(..., per_thread: bool = False)
1 parent 5f57f69 commit bf55f48

File tree

2 files changed

+439
-6
lines changed

2 files changed

+439
-6
lines changed

Lib/contextlib.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import abc
33
import os
44
import sys
5+
import threading
56
import _collections_abc
67
from collections import deque
78
from functools import wraps
@@ -390,22 +391,76 @@ async def __aexit__(self, *exc_info):
390391
await self.thing.aclose()
391392

392393

394+
class _PerThreadStream:
395+
def __init__(self, default_stream):
396+
self.default_stream = default_stream
397+
# each stack entry is (stream, thread_id). thread_id is None if
398+
# per_thread=False.
399+
self._stack = []
400+
self._lock = threading.Lock()
401+
402+
@property
403+
def _current_stream(self):
404+
thread_id = threading.get_ident()
405+
# look for the most recent redirect which was either:
406+
# * per_thread=False
407+
# * per_thread=True, and in our thread
408+
#
409+
# If none match, fall back to the default stream.
410+
with self._lock:
411+
for stream, entry_thread_id in reversed(self._stack):
412+
if entry_thread_id is None or entry_thread_id == thread_id:
413+
return stream
414+
return self.default_stream
415+
416+
def add_entry(self, entry):
417+
with self._lock:
418+
self._stack.append(entry)
419+
420+
def remove_entry(self, entry):
421+
# remove by identity, not equality, in case two streams compare equal
422+
with self._lock:
423+
for i, e in enumerate(self._stack):
424+
if e is entry:
425+
del self._stack[i]
426+
return
427+
428+
def __getattr__(self, name):
429+
return getattr(self._current_stream, name)
430+
431+
393432
class _RedirectStream(AbstractContextManager):
394433

395434
_stream = None
435+
_lock = None
436+
_stream_ref = None
396437

397-
def __init__(self, new_target):
438+
def __init__(self, new_target, *, per_thread=False):
398439
self._new_target = new_target
399-
# We use a list of old targets to make this CM re-entrant
400-
self._old_targets = []
440+
self._per_thread = per_thread
441+
self._entries = [] # stack for reentrant usage
401442

402443
def __enter__(self):
403-
self._old_targets.append(getattr(sys, self._stream))
404-
setattr(sys, self._stream, self._new_target)
444+
with self._lock:
445+
if self._stream_ref is None:
446+
type(self)._stream_ref = _PerThreadStream(getattr(sys, self._stream))
447+
setattr(sys, self._stream, self._stream_ref)
448+
entry = (
449+
self._new_target,
450+
threading.get_ident() if self._per_thread else None,
451+
)
452+
self._entries.append(entry)
453+
self._stream_ref.add_entry(entry)
454+
405455
return self._new_target
406456

407457
def __exit__(self, exctype, excinst, exctb):
408-
setattr(sys, self._stream, self._old_targets.pop())
458+
with self._lock:
459+
entry = self._entries.pop()
460+
self._stream_ref.remove_entry(entry)
461+
if len(self._stream_ref._stack) == 0:
462+
setattr(sys, self._stream, self._stream_ref.default_stream)
463+
type(self)._stream_ref = None
409464

410465

411466
class redirect_stdout(_RedirectStream):
@@ -422,12 +477,16 @@ class redirect_stdout(_RedirectStream):
422477
"""
423478

424479
_stream = "stdout"
480+
_lock = threading.Lock()
481+
_stream_ref = None
425482

426483

427484
class redirect_stderr(_RedirectStream):
428485
"""Context manager for temporarily redirecting stderr to another file."""
429486

430487
_stream = "stderr"
488+
_lock = threading.Lock()
489+
_stream_ref = None
431490

432491

433492
class suppress(AbstractContextManager):

0 commit comments

Comments
 (0)