From 8f924d40645f13a73b7896884d86f7a1472a71f0 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 30 Dec 2025 07:35:57 +1000 Subject: [PATCH] Preserve log formatter when setting log scales --- ultraplot/axes/cartesian.py | 68 ++++++++++++++++++++++++++++++++---- ultraplot/tests/test_plot.py | 27 ++++++++++++++ 2 files changed, 89 insertions(+), 6 deletions(-) diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 46685b5df..b19047542 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -5,22 +5,27 @@ import copy import inspect +import matplotlib.axis as maxis import matplotlib.dates as mdates import matplotlib.ticker as mticker import numpy as np - from packaging import version from .. import constructor from .. import scale as pscale from .. import ticker as pticker from ..config import rc -from ..internals import ic # noqa: F401 -from ..internals import _not_none, _pop_rc, _version_mpl, docstring, labels, warnings -from . import plot, shared -import matplotlib.axis as maxis - +from ..internals import ( + _not_none, + _pop_rc, + _version_mpl, + docstring, + ic, # noqa: F401 + labels, + warnings, +) from ..utils import units +from . import plot, shared __all__ = ["CartesianAxes"] @@ -779,6 +784,26 @@ def _sharey_setup(self, sharey, *, labels=True, limits=True): if level > 1 and limits: self._sharey_limits(sharey) + def _apply_log_formatter_on_scale(self, s): + """ + Enforce log formatter when log scale is set and rc is enabled. + """ + if not rc.find("formatter.log", context=True): + return + if getattr(self, f"get_{s}scale")() != "log": + return + self._update_formatter(s, "log") + + def set_xscale(self, value, **kwargs): + result = super().set_xscale(value, **kwargs) + self._apply_log_formatter_on_scale("x") + return result + + def set_yscale(self, value, **kwargs): + result = super().set_yscale(value, **kwargs) + self._apply_log_formatter_on_scale("y") + return result + def _update_formatter( self, s, @@ -1389,6 +1414,7 @@ def format( # WARNING: Changing axis scale also changes default locators # and formatters, and restricts possible range of axis limits, # so critical to do it first. + scale_requested = scale is not None if scale is not None: scale = constructor.Scale(scale, **scale_kw) getattr(self, f"set_{s}scale")(scale) @@ -1480,10 +1506,40 @@ def format( tickrange=tickrange, wraprange=wraprange, ) + if ( + scale_requested + and formatter is None + and not formatter_kw + and tickrange is None + and wraprange is None + and rc.find("formatter.log", context=True) + and getattr(self, f"get_{s}scale")() == "log" + ): + self._update_formatter(s, "log") # Ensure ticks are within axis bounds self._fix_ticks(s, fixticks=fixticks) + if rc.find("formatter.log", context=True): + if ( + xscale is not None + and xformatter is None + and not xformatter_kw + and xtickrange is None + and xwraprange is None + and self.get_xscale() == "log" + ): + self._update_formatter("x", "log") + if ( + yscale is not None + and yformatter is None + and not yformatter_kw + and ytickrange is None + and ywraprange is None + and self.get_yscale() == "log" + ): + self._update_formatter("y", "log") + # Parent format method if aspect is not None: self.set_aspect(aspect) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index fb54d191a..1bcb69684 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -361,6 +361,33 @@ def reset(ax): uplt.close(fig) +def test_format_log_scale_preserves_log_formatter(): + """ + Test that setting a log scale preserves the log formatter when enabled. + """ + x = np.linspace(1, 1e6, 10) + log_formatter = uplt.constructor.Formatter("log") + log_formatter_type = type(log_formatter) + + with uplt.rc.context({"formatter.log": True}): + fig, ax = uplt.subplots() + ax.plot(x, x) + ax.format(yscale="log") + assert isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + ax.set_yscale("log") + assert isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + + with uplt.rc.context({"formatter.log": False}): + fig, ax = uplt.subplots() + ax.plot(x, x) + ax.format(yscale="log") + assert not isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + ax.set_yscale("log") + assert not isinstance(ax.yaxis.get_major_formatter(), log_formatter_type) + + uplt.close(fig) + + def test_shading_pcolor(rng): """ Pcolormesh by default adjusts the plot by