From 0e9a179a9ad28f5a70a903c59538692b8cf3753c Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 29 Oct 2025 00:19:44 +0800 Subject: [PATCH 01/18] init fix --- .../callbacks/progress/rich_progress.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index d4c3c916c7ed0..16efa8d42ac63 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import time from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta +from threading import Event, Thread from typing import Any, Optional, Union, cast import torch @@ -22,6 +24,7 @@ from typing_extensions import override import lightning.pytorch as pl +from lightning.fabric.utilities.imports import _IS_INTERACTIVE from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar from lightning.pytorch.utilities.imports import _RICH_AVAILABLE from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -29,6 +32,7 @@ if _RICH_AVAILABLE: from rich import get_console, reconfigure from rich.console import Console, RenderableType + from rich.live import Live from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn from rich.progress_bar import ProgressBar as _RichProgressBar from rich.style import Style @@ -66,9 +70,46 @@ class CustomInfiniteTask(Task): def time_remaining(self) -> Optional[float]: return None + class _RefreshThread(Thread): + def __init__( + self, + live: Live, + ) -> None: + self.live = live + self.refresh_cond = False + self.done = Event() + super().__init__(daemon=True) + + def run(self) -> None: + while not self.done.is_set(): + if self.refresh_cond: + with self.live._lock: + self.live.refresh() + self.refresh_cond = False + time.sleep(0.001) + + def stop(self) -> None: + self.done.set() + class CustomProgress(Progress): """Overrides ``Progress`` to support adding tasks that have an infinite total size.""" + def start(self) -> None: + if self.live.auto_refresh: + self.live._refresh_thread = _RefreshThread(self.live) + self.live.auto_refresh = False + super().start() + if self.live._refresh_thread: + self.live.auto_refresh = True + self.live._refresh_thread.start() + + def refresh(self) -> None: + if self.live.auto_refresh: + self.live._refresh_thread.refresh_cond = True + if _IS_INTERACTIVE: + return super().refresh() + return None + def add_task( self, description: str, @@ -356,7 +397,7 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component, - auto_refresh=False, + auto_refresh=True, disable=self.is_disabled, console=self._console, ) From c83a8f491876c9be4c65a45680174b7a674f6ccb Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 29 Oct 2025 19:11:04 +0800 Subject: [PATCH 02/18] temp fix unittests --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 8 ++++++++ .../callbacks/progress/test_rich_progress_bar.py | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 16efa8d42ac63..8ec9be8742f92 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -103,6 +103,14 @@ def start(self) -> None: self.live.auto_refresh = True self.live._refresh_thread.start() + def stop(self) -> None: + refresh_thread = self.live._refresh_thread + self.live.auto_refresh = refresh_thread is not None + super().stop() + if refresh_thread: + refresh_thread.stop() + refresh_thread.join() + def refresh(self) -> None: if self.live.auto_refresh: self.live._refresh_thread.refresh_cond = True diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 9d74871ce84e4..7291daf5df53b 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -131,6 +131,8 @@ def test_rich_progress_bar_custom_theme(): _, kwargs = mocks["ProcessingSpeedColumn"].call_args assert kwargs["style"] == theme.processing_speed + progress_bar.progress.live._refresh_thread.stop() + @RunIf(rich=True) def test_rich_progress_bar_keyboard_interrupt(tmp_path): @@ -176,6 +178,8 @@ def configure_columns(self, trainer): assert progress_bar.progress.columns[0] == custom_column assert len(progress_bar.progress.columns) == 2 + progress_bar.progress.stop() + @RunIf(rich=True) @pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 3)])) From 24c1729ea96fae388ad0f8352858c26d26eb60a4 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 29 Oct 2025 20:55:58 +0800 Subject: [PATCH 03/18] release time sleep --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 8ec9be8742f92..35b581cc26893 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -86,7 +86,7 @@ def run(self) -> None: with self.live._lock: self.live.refresh() self.refresh_cond = False - time.sleep(0.001) + time.sleep(0.005) def stop(self) -> None: self.done.set() From 47c3bd9beda77631a970b7667f0587868def9289 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 29 Oct 2025 20:56:23 +0800 Subject: [PATCH 04/18] fix unittest --- tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 7291daf5df53b..a44d116d76d46 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -132,6 +132,7 @@ def test_rich_progress_bar_custom_theme(): assert kwargs["style"] == theme.processing_speed progress_bar.progress.live._refresh_thread.stop() + progress_bar.progress.live._refresh_thread.join() @RunIf(rich=True) From e3b3100bc7af834086d697a204e4edc0add64667 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 30 Oct 2025 05:48:19 +0800 Subject: [PATCH 05/18] ref soft_refresh --- .../pytorch/callbacks/progress/rich_progress.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 35b581cc26893..6d8d8f2d025b1 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -111,12 +111,9 @@ def stop(self) -> None: refresh_thread.stop() refresh_thread.join() - def refresh(self) -> None: + def soft_refresh(self) -> None: if self.live.auto_refresh: self.live._refresh_thread.refresh_cond = True - if _IS_INTERACTIVE: - return super().refresh() - return None def add_task( self, @@ -413,9 +410,12 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: # progress has started self._progress_stopped = False - def refresh(self) -> None: + def refresh(self, hard=False) -> None: if self.progress: - self.progress.refresh() + if hard or _IS_INTERACTIVE: + self.progress.refresh() + else: + self.progress.soft_refresh() @override def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: From f11e58fd9cf64b53fac3d878c51c8d884af924ff Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 01:19:24 +0800 Subject: [PATCH 06/18] fix test --- .../callbacks/progress/test_rich_progress_bar.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index a44d116d76d46..567552459cc28 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -350,7 +350,8 @@ def training_step(self, *args, **kwargs): for key in ("loss", "v_num", "train_loss"): assert key in rendered[train_progress_bar_id][1] - assert key not in rendered[val_progress_bar_id][1] + if val_progress_bar_id in rendered: + assert key not in rendered[val_progress_bar_id][1] def test_rich_progress_bar_metrics_fast_dev_run(tmp_path): @@ -364,7 +365,8 @@ def test_rich_progress_bar_metrics_fast_dev_run(tmp_path): val_progress_bar_id = progress_bar.val_progress_bar_id rendered = progress_bar.progress.columns[-1]._renderable_cache assert "v_num" not in rendered[train_progress_bar_id][1] - assert "v_num" not in rendered[val_progress_bar_id][1] + if val_progress_bar_id in rendered: + assert "v_num" not in rendered[val_progress_bar_id][1] @RunIf(rich=True) From 86d823ade38f8c8db9a32f5c183c235a77e5f58a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 01:30:57 +0800 Subject: [PATCH 07/18] refactor _RefreshThread --- .../callbacks/progress/rich_progress.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 6d8d8f2d025b1..85187b1ca1598 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -16,7 +16,6 @@ from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta -from threading import Event, Thread from typing import Any, Optional, Union, cast import torch @@ -32,7 +31,7 @@ if _RICH_AVAILABLE: from rich import get_console, reconfigure from rich.console import Console, RenderableType - from rich.live import Live + from rich.live import _RefreshThread as _RichRefreshThread from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn from rich.progress_bar import ProgressBar as _RichProgressBar from rich.style import Style @@ -70,15 +69,10 @@ class CustomInfiniteTask(Task): def time_remaining(self) -> Optional[float]: return None - class _RefreshThread(Thread): - def __init__( - self, - live: Live, - ) -> None: - self.live = live + class _RefreshThread(_RichRefreshThread): + def __init__(self, *args, **kwargs) -> None: self.refresh_cond = False - self.done = Event() - super().__init__(daemon=True) + super().__init__(*args, **kwargs) def run(self) -> None: while not self.done.is_set(): @@ -88,15 +82,19 @@ def run(self) -> None: self.refresh_cond = False time.sleep(0.005) - def stop(self) -> None: - self.done.set() - class CustomProgress(Progress): """Overrides ``Progress`` to support adding tasks that have an infinite total size.""" def start(self) -> None: + """Starts the progress display. + + Notes + ----- + This override is needed to support the custom refresh thread. + + """ if self.live.auto_refresh: - self.live._refresh_thread = _RefreshThread(self.live) + self.live._refresh_thread = _RefreshThread(self.live, self.live.refresh_per_second) self.live.auto_refresh = False super().start() if self.live._refresh_thread: @@ -105,7 +103,6 @@ def start(self) -> None: def stop(self) -> None: refresh_thread = self.live._refresh_thread - self.live.auto_refresh = refresh_thread is not None super().stop() if refresh_thread: refresh_thread.stop() From 2f24cb7c1ade87d128b262e2b281054211b762d9 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 01:40:29 +0800 Subject: [PATCH 08/18] add type annotation --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 85187b1ca1598..f849164327e64 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -70,7 +70,7 @@ def time_remaining(self) -> Optional[float]: return None class _RefreshThread(_RichRefreshThread): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self.refresh_cond = False super().__init__(*args, **kwargs) @@ -407,7 +407,7 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: # progress has started self._progress_stopped = False - def refresh(self, hard=False) -> None: + def refresh(self, hard: bool = False) -> None: if self.progress: if hard or _IS_INTERACTIVE: self.progress.refresh() From daaacf9588a3d8e4e3e76949cec20d6737ac4c39 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 20:30:36 +0800 Subject: [PATCH 09/18] add isinstance check --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index f849164327e64..15cec555a93d1 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -109,7 +109,7 @@ def stop(self) -> None: refresh_thread.join() def soft_refresh(self) -> None: - if self.live.auto_refresh: + if self.live.auto_refresh and isinstance(self.live._refresh_thread, _RefreshThread): self.live._refresh_thread.refresh_cond = True def add_task( From 6b39688315ac52ae8c01be2ca56c69c45f3d470a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Mon, 8 Dec 2025 17:23:10 +0800 Subject: [PATCH 10/18] fix(progress): update refresh_rate to be per second --- .../pytorch/callbacks/progress/rich_progress.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 15cec555a93d1..53e18e5f82b73 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -80,7 +80,7 @@ def run(self) -> None: with self.live._lock: self.live.refresh() self.refresh_cond = False - time.sleep(0.005) + time.sleep(1 / self.refresh_per_second) class CustomProgress(Progress): """Overrides ``Progress`` to support adding tasks that have an infinite total size.""" @@ -282,8 +282,8 @@ class RichProgressBar(ProgressBar): trainer = Trainer(callbacks=RichProgressBar()) Args: - refresh_rate: Determines at which rate (in number of batches) the progress bars get updated. - Set it to ``0`` to disable the display. + refresh_rate: Determines at which rate (per second) the progress bars get updated. + Set it to ``0`` to disable the display. Default: 100 leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False theme: Contains styles used to stylize the progress bar. console_kwargs: Args for constructing a `Console` @@ -301,7 +301,7 @@ class RichProgressBar(ProgressBar): def __init__( self, - refresh_rate: int = 1, + refresh_rate: int = 100, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), console_kwargs: Optional[dict[str, Any]] = None, @@ -400,6 +400,7 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: *self.configure_columns(trainer), self._metric_component, auto_refresh=True, + refresh_per_second=self.refresh_rate if self.is_enabled else 1, disable=self.is_disabled, console=self._console, ) From d84c74a51d8ea6456285b9c252ac535a8ff0424a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Mon, 8 Dec 2025 23:37:40 +0800 Subject: [PATCH 11/18] refactor(progress): aligning progress bar's `refresh_rate` semantic --- .../callbacks/progress/rich_progress.py | 46 ++++++------------- 1 file changed, 14 insertions(+), 32 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 53e18e5f82b73..e16aa2ef782d2 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -498,8 +498,6 @@ def on_validation_batch_start( visible=False, ) - self.refresh() - def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID": assert self.progress is not None return self.progress.add_task( @@ -513,28 +511,23 @@ def _initialize_train_progress_bar_id(self) -> None: train_description = self._get_train_description(self.trainer.current_epoch) self.train_progress_bar_id = self._add_task(total_batches, train_description) - def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None: + def _update( + self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True, refresh: bool = True + ) -> None: if self.progress is not None and self.is_enabled and progress_bar_id is not None: - total = self.progress.tasks[progress_bar_id].total - assert total is not None - if not self._should_update(current, total): - return - self.progress.update(progress_bar_id, completed=current, visible=visible) - - def _should_update(self, current: int, total: Union[int, float]) -> bool: - return current % self.refresh_rate == 0 or current == total + self.progress.update(progress_bar_id, completed=current, visible=visible, refresh=refresh) @override def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self.is_enabled and self.val_progress_bar_id is not None and trainer.state.fn == "fit": assert self.progress is not None - self.progress.update(self.val_progress_bar_id, advance=0, visible=False) - self.refresh() + self.progress.update(self.val_progress_bar_id, advance=0, visible=False, refresh=True) @override def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if trainer.state.fn == "fit": self._update_metrics(trainer, pl_module) + self.refresh() self.reset_dataloader_idx_tracker() @override @@ -561,7 +554,6 @@ def on_test_batch_start( assert self.progress is not None self.progress.update(self.test_progress_bar_id, advance=0, visible=False) self.test_progress_bar_id = self._add_task(self.total_test_batches_current_dataloader, self.test_description) - self.refresh() @override def on_predict_batch_start( @@ -581,7 +573,6 @@ def on_predict_batch_start( self.predict_progress_bar_id = self._add_task( self.total_predict_batches_current_dataloader, self.predict_description ) - self.refresh() @override def on_train_batch_end( @@ -595,13 +586,14 @@ def on_train_batch_end( if not self.is_disabled and self.train_progress_bar_id is None: # can happen when resuming from a mid-epoch restart self._initialize_train_progress_bar_id() - self._update(self.train_progress_bar_id, batch_idx + 1) - self._update_metrics(trainer, pl_module, batch_idx + 1) + self._update(self.train_progress_bar_id, batch_idx + 1, refresh=False) + self._update_metrics(trainer, pl_module) self.refresh() @override def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self._update_metrics(trainer, pl_module, total_batches=True) + self._update_metrics(trainer, pl_module) + self.refresh() @override def on_validation_batch_end( @@ -617,13 +609,12 @@ def on_validation_batch_end( return if trainer.sanity_checking: if self.val_sanity_progress_bar_id is not None: - self._update(self.val_sanity_progress_bar_id, batch_idx + 1) + self._update(self.val_sanity_progress_bar_id, batch_idx + 1, refresh=True) return if self.val_progress_bar_id is None: return - self._update(self.val_progress_bar_id, batch_idx + 1) - self.refresh() + self._update(self.val_progress_bar_id, batch_idx + 1, refresh=True) @override def on_test_batch_end( @@ -637,8 +628,7 @@ def on_test_batch_end( ) -> None: if self.is_disabled or self.test_progress_bar_id is None: return - self._update(self.test_progress_bar_id, batch_idx + 1) - self.refresh() + self._update(self.test_progress_bar_id, batch_idx + 1, refresh=True) @override def on_predict_batch_end( @@ -652,8 +642,7 @@ def on_predict_batch_end( ) -> None: if self.is_disabled or self.predict_progress_bar_id is None: return - self._update(self.predict_progress_bar_id, batch_idx + 1) - self.refresh() + self._update(self.predict_progress_bar_id, batch_idx + 1, refresh=True) def _get_train_description(self, current_epoch: int) -> str: train_description = f"Epoch {current_epoch}" @@ -690,17 +679,10 @@ def _update_metrics( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", - current: Optional[int] = None, - total_batches: bool = False, ) -> None: if not self.is_enabled or self._metric_component is None: return - if current is not None and not total_batches: - total = self.total_train_batches - if not self._should_update(current, total): - return - metrics = self.get_metrics(trainer, pl_module) if self._metric_component: self._metric_component.update(metrics) From 4da13b40921f157ff00725607d181e93e7d5a90c Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Tue, 9 Dec 2025 00:40:34 +0800 Subject: [PATCH 12/18] fix(progress): update default refresh_rate to 10 --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index e16aa2ef782d2..937202b0da1ad 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -283,7 +283,7 @@ class RichProgressBar(ProgressBar): Args: refresh_rate: Determines at which rate (per second) the progress bars get updated. - Set it to ``0`` to disable the display. Default: 100 + Set it to ``0`` to disable the display. Default: 10 leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False theme: Contains styles used to stylize the progress bar. console_kwargs: Args for constructing a `Console` @@ -301,7 +301,7 @@ class RichProgressBar(ProgressBar): def __init__( self, - refresh_rate: int = 100, + refresh_rate: int = 10, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), console_kwargs: Optional[dict[str, Any]] = None, From d15460feff5d7ece770191b1c8df955063a398a0 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Tue, 9 Dec 2025 01:51:37 +0800 Subject: [PATCH 13/18] refactor: revert some changes --- .../callbacks/progress/rich_progress.py | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 937202b0da1ad..6cca14ae7a0a0 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -498,6 +498,8 @@ def on_validation_batch_start( visible=False, ) + self.refresh() + def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID": assert self.progress is not None return self.progress.add_task( @@ -512,22 +514,27 @@ def _initialize_train_progress_bar_id(self) -> None: self.train_progress_bar_id = self._add_task(total_batches, train_description) def _update( - self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True, refresh: bool = True + self, + progress_bar_id: Optional["TaskID"], + current: int, + visible: bool = True, + hard: bool = False, ) -> None: if self.progress is not None and self.is_enabled and progress_bar_id is not None: - self.progress.update(progress_bar_id, completed=current, visible=visible, refresh=refresh) + self.progress.update(progress_bar_id, completed=current, visible=visible) + self.refresh(hard=hard) @override def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self.is_enabled and self.val_progress_bar_id is not None and trainer.state.fn == "fit": assert self.progress is not None - self.progress.update(self.val_progress_bar_id, advance=0, visible=False, refresh=True) + self.progress.update(self.val_progress_bar_id, advance=0, visible=False) + self.refresh() @override def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if trainer.state.fn == "fit": self._update_metrics(trainer, pl_module) - self.refresh() self.reset_dataloader_idx_tracker() @override @@ -554,6 +561,7 @@ def on_test_batch_start( assert self.progress is not None self.progress.update(self.test_progress_bar_id, advance=0, visible=False) self.test_progress_bar_id = self._add_task(self.total_test_batches_current_dataloader, self.test_description) + self.refresh() @override def on_predict_batch_start( @@ -573,6 +581,7 @@ def on_predict_batch_start( self.predict_progress_bar_id = self._add_task( self.total_predict_batches_current_dataloader, self.predict_description ) + self.refresh() @override def on_train_batch_end( @@ -586,7 +595,7 @@ def on_train_batch_end( if not self.is_disabled and self.train_progress_bar_id is None: # can happen when resuming from a mid-epoch restart self._initialize_train_progress_bar_id() - self._update(self.train_progress_bar_id, batch_idx + 1, refresh=False) + self._update(self.train_progress_bar_id, batch_idx + 1) self._update_metrics(trainer, pl_module) self.refresh() @@ -609,12 +618,12 @@ def on_validation_batch_end( return if trainer.sanity_checking: if self.val_sanity_progress_bar_id is not None: - self._update(self.val_sanity_progress_bar_id, batch_idx + 1, refresh=True) + self._update(self.val_sanity_progress_bar_id, batch_idx + 1) return if self.val_progress_bar_id is None: return - self._update(self.val_progress_bar_id, batch_idx + 1, refresh=True) + self._update(self.val_progress_bar_id, batch_idx + 1) @override def on_test_batch_end( @@ -628,7 +637,7 @@ def on_test_batch_end( ) -> None: if self.is_disabled or self.test_progress_bar_id is None: return - self._update(self.test_progress_bar_id, batch_idx + 1, refresh=True) + self._update(self.test_progress_bar_id, batch_idx + 1) @override def on_predict_batch_end( @@ -642,7 +651,7 @@ def on_predict_batch_end( ) -> None: if self.is_disabled or self.predict_progress_bar_id is None: return - self._update(self.predict_progress_bar_id, batch_idx + 1, refresh=True) + self._update(self.predict_progress_bar_id, batch_idx + 1) def _get_train_description(self, current_epoch: int) -> str: train_description = f"Epoch {current_epoch}" From 0cb295393c67e5471267877ffdcd37a620cb2cb7 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Tue, 9 Dec 2025 02:21:35 +0800 Subject: [PATCH 14/18] test(progress): update test_rich_progress_bar_with_refresh_rate to test_rich_progress_bar_update_counts --- .../progress/test_rich_progress_bar.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 567552459cc28..7b4a33e9a155a 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -221,22 +221,19 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmp_path): @RunIf(rich=True) @pytest.mark.parametrize( - ("refresh_rate", "train_batches", "val_batches", "expected_call_count"), + ("train_batches", "val_batches", "expected_call_count"), [ # note: there is always one extra update at the very end (+1) - (3, 6, 6, 2 + 2 + 1), - (4, 6, 6, 2 + 2 + 1), - (7, 6, 6, 1 + 1 + 1), - (1, 2, 3, 2 + 3 + 1), - (1, 0, 0, 0 + 0), - (3, 1, 0, 1 + 0), - (3, 1, 1, 1 + 1 + 1), - (3, 5, 0, 2 + 0), - (3, 5, 2, 2 + 1 + 1), - (6, 5, 2, 1 + 1 + 1), + (6, 6, 6 + 6 + 1), + (2, 3, 2 + 3 + 1), + (0, 0, 0 + 0), + (1, 0, 1 + 0), + (1, 1, 1 + 1 + 1), + (5, 0, 5 + 0), + (5, 2, 5 + 2 + 1), ], ) -def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batches, val_batches, expected_call_count): +def test_rich_progress_bar_update_counts(tmp_path, train_batches, val_batches, expected_call_count): model = BoringModel() trainer = Trainer( default_root_dir=tmp_path, @@ -244,7 +241,7 @@ def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batch limit_train_batches=train_batches, limit_val_batches=val_batches, max_epochs=1, - callbacks=RichProgressBar(refresh_rate=refresh_rate), + callbacks=RichProgressBar(), ) trainer.progress_bar_callback.on_train_start(trainer, model) From 21bae149000e2223cb228c25126fb8db7663ab90 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Tue, 9 Dec 2025 02:24:34 +0800 Subject: [PATCH 15/18] reset default refresh_rate to 100 --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 6cca14ae7a0a0..701c969d02b17 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -283,7 +283,7 @@ class RichProgressBar(ProgressBar): Args: refresh_rate: Determines at which rate (per second) the progress bars get updated. - Set it to ``0`` to disable the display. Default: 10 + Set it to ``0`` to disable the display. Default: 100 leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False theme: Contains styles used to stylize the progress bar. console_kwargs: Args for constructing a `Console` @@ -301,7 +301,7 @@ class RichProgressBar(ProgressBar): def __init__( self, - refresh_rate: int = 10, + refresh_rate: int = 100, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), console_kwargs: Optional[dict[str, Any]] = None, From 5c7070d7d3ff10dd44096159e0e1e450928dbe3e Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Tue, 9 Dec 2025 08:22:43 +0800 Subject: [PATCH 16/18] fix test --- .../callbacks/progress/test_tqdm_progress_bar.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index 0bd29b998c598..72da9bf543155 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -320,7 +320,7 @@ def on_validation_epoch_end(self, *args): def test_tqdm_progress_bar_default_value(tmp_path): """Test that a value of None defaults to refresh rate 1.""" - trainer = Trainer(default_root_dir=tmp_path) + trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar()) assert trainer.progress_bar_callback.refresh_rate == 1 @@ -328,9 +328,6 @@ def test_tqdm_progress_bar_default_value(tmp_path): @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) def test_tqdm_progress_bar_value_on_colab(tmp_path): """Test that Trainer will override the default in Google COLAB.""" - trainer = Trainer(default_root_dir=tmp_path) - assert trainer.progress_bar_callback.refresh_rate == 20 - trainer = Trainer(default_root_dir=tmp_path, callbacks=TQDMProgressBar()) assert trainer.progress_bar_callback.refresh_rate == 20 From 456e82c0c23571f2262ae23102d03e9fc7951ea9 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 18 Dec 2025 00:46:14 +0800 Subject: [PATCH 17/18] update hooks doc --- docs/source-pytorch/common/hooks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index 89c1c15d0413f..0a26441a46e35 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -59,7 +59,7 @@ important to understand. The following order is always used: ... Callback: Training is starting! Model: Training is starting! - Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ... + Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ... .. note:: From a6ea6f0763e292c02cf0fd5d2498fb5f1fb9ec8a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 18 Dec 2025 00:53:40 +0800 Subject: [PATCH 18/18] update hooks doc --- docs/source-pytorch/common/hooks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst index 0a26441a46e35..73b05f7765e53 100644 --- a/docs/source-pytorch/common/hooks.rst +++ b/docs/source-pytorch/common/hooks.rst @@ -59,7 +59,7 @@ important to understand. The following order is always used: ... Callback: Training is starting! Model: Training is starting! - Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ... + Epoch 0/0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64/64 ... .. note::