From 10a187dba160d16ed919dad89ac2fc7f7b10ede4 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 30 Dec 2025 20:06:04 +1000 Subject: [PATCH 1/6] Add subset label sharing groups --- ultraplot/axes/base.py | 30 +++++ ultraplot/axes/cartesian.py | 28 +++-- ultraplot/axes/geo.py | 19 ++++ ultraplot/figure.py | 219 ++++++++++++++++++++++++++++++++++++ ultraplot/gridspec.py | 36 ++++++ 5 files changed, 323 insertions(+), 9 deletions(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index a0e30f68b..c297563fc 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -3148,12 +3148,42 @@ def _update_share_labels(self, axes=None, target="x"): target : {'x', 'y'}, optional Which axis labels to share ('x' for x-axis, 'y' for y-axis) """ + if axes is False: + self.figure._clear_share_label_groups([self], target=target) + return + if not axes: + return + axes = list(axes) if not axes: return # Convert indices to actual axes objects if isinstance(axes[0], int): axes = [self.figure.axes[i] for i in axes] + axes = [ + ax._get_topmost_axes() if hasattr(ax, "_get_topmost_axes") else ax + for ax in axes + if ax is not None + ] + if len(axes) < 2: + return + # Preserve order while de-duplicating + seen = set() + unique = [] + for ax in axes: + ax_id = id(ax) + if ax_id in seen: + continue + seen.add(ax_id) + unique.append(ax) + axes = unique + if len(axes) < 2: + return + + # Prefer figure-managed spanning labels when possible + if all(isinstance(ax, maxes.SubplotBase) for ax in axes): + self.figure._register_share_label_group(axes, target=target, source=self) + return # Get the center position of the axes group if box := self.get_center_of_axes(axes): diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 46685b5df..351823824 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -5,22 +5,27 @@ import copy import inspect +import matplotlib.axis as maxis import matplotlib.dates as mdates import matplotlib.ticker as mticker import numpy as np - from packaging import version from .. import constructor from .. import scale as pscale from .. import ticker as pticker from ..config import rc -from ..internals import ic # noqa: F401 -from ..internals import _not_none, _pop_rc, _version_mpl, docstring, labels, warnings -from . import plot, shared -import matplotlib.axis as maxis - +from ..internals import ( + _not_none, + _pop_rc, + _version_mpl, + docstring, + ic, # noqa: F401 + labels, + warnings, +) from ..utils import units +from . import plot, shared __all__ = ["CartesianAxes"] @@ -432,9 +437,14 @@ def _apply_axis_sharing_for_axis( # Handle axis label sharing (level > 0) if level > 0: - shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") - labels._transfer_label(axis.label, shared_axis_obj.label) - axis.label.set_visible(False) + if self.figure._is_share_label_group_member(self, axis_name): + pass + elif self.figure._is_share_label_group_member(shared_axis, axis_name): + axis.label.set_visible(False) + else: + shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") + labels._transfer_label(axis.label, shared_axis_obj.label) + axis.label.set_visible(False) # Handle tick label sharing (level > 2) if level > 2: diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 9d65cff98..267acb206 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -32,6 +32,7 @@ _version_cartopy, docstring, ic, # noqa: F401 + labels, warnings, ) from ..utils import units @@ -661,6 +662,24 @@ def _apply_axis_sharing(self): the leftmost and bottommost is the *figure* sharing level. """ + # Share axis labels + if self._sharex and self.figure._sharex >= 1: + if self.figure._is_share_label_group_member(self, "x"): + pass + elif self.figure._is_share_label_group_member(self._sharex, "x"): + self.xaxis.label.set_visible(False) + else: + labels._transfer_label(self.xaxis.label, self._sharex.xaxis.label) + self.xaxis.label.set_visible(False) + if self._sharey and self.figure._sharey >= 1: + if self.figure._is_share_label_group_member(self, "y"): + pass + elif self.figure._is_share_label_group_member(self._sharey, "y"): + self.yaxis.label.set_visible(False) + else: + labels._transfer_label(self.yaxis.label, self._sharey.yaxis.label) + self.yaxis.label.set_visible(False) + # Share interval x if self._sharex and self.figure._sharex >= 2: self._lonaxis.set_view_interval(*self._sharex._lonaxis.get_view_interval()) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index a0f74d201..5a4e5d1db 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -814,6 +814,7 @@ def __init__( self._supxlabel_dict = {} # an axes: label mapping self._supylabel_dict = {} # an axes: label mapping self._suplabel_dict = {"left": {}, "right": {}, "bottom": {}, "top": {}} + self._share_label_groups = {"x": {}, "y": {}} # explicit label-sharing groups self._suptitle_pad = rc["suptitle.pad"] d = self._suplabel_props = {} # store the super label props d["left"] = {"va": "center", "ha": "right"} @@ -840,6 +841,7 @@ def draw(self, renderer): # we can use get_border_axes for the outermost plots and then collect their outermost panels that are not colorbars self._share_ticklabels(axis="x") self._share_ticklabels(axis="y") + self._apply_share_label_groups() super().draw(renderer) def _share_ticklabels(self, *, axis: str) -> None: @@ -1889,6 +1891,223 @@ def _align_axis_label(self, x): if span: self._update_axis_label(pos, axs) + # Apply explicit label-sharing groups for this axis + self._apply_share_label_groups(axis=x) + + def _register_share_label_group(self, axes, *, target, source=None): + """ + Register an explicit label-sharing group for a subset of axes. + """ + if not axes: + return + axes = list(axes) + axes = [ax for ax in axes if ax is not None and ax.figure is self] + if len(axes) < 2: + return + + # Preserve order while de-duplicating + seen = set() + unique = [] + for ax in axes: + ax_id = id(ax) + if ax_id in seen: + continue + seen.add(ax_id) + unique.append(ax) + axes = unique + if len(axes) < 2: + return + + # Split by label side if mixed + axes_by_side = {} + if target == "x": + for ax in axes: + axes_by_side.setdefault(ax.xaxis.get_label_position(), []).append(ax) + else: + for ax in axes: + axes_by_side.setdefault(ax.yaxis.get_label_position(), []).append(ax) + if len(axes_by_side) > 1: + for side, side_axes in axes_by_side.items(): + side_source = source if source in side_axes else None + self._register_share_label_group_for_side( + side_axes, target=target, side=side, source=side_source + ) + return + + side, side_axes = next(iter(axes_by_side.items())) + self._register_share_label_group_for_side( + side_axes, target=target, side=side, source=source + ) + + def _register_share_label_group_for_side(self, axes, *, target, side, source=None): + """ + Register a single label-sharing group for a given label side. + """ + if not axes: + return + axes = [ax for ax in axes if ax is not None and ax.figure is self] + if len(axes) < 2: + return + + # Prefer label text from the source axes if available + label = None + if source in axes: + candidate = getattr(source, f"{target}axis").label + if candidate.get_text().strip(): + label = candidate + if label is None: + for ax in axes: + candidate = getattr(ax, f"{target}axis").label + if candidate.get_text().strip(): + label = candidate + break + + text = label.get_text() if label else "" + props = None + if label is not None: + props = { + "color": label.get_color(), + "fontproperties": label.get_font_properties(), + "rotation": label.get_rotation(), + "rotation_mode": label.get_rotation_mode(), + "ha": label.get_ha(), + "va": label.get_va(), + } + + group_key = tuple(sorted(id(ax) for ax in axes)) + groups = self._share_label_groups[target] + group = groups.get(group_key) + if group is None: + groups[group_key] = { + "axes": axes, + "side": side, + "text": text if text.strip() else "", + "props": props, + } + else: + group["axes"] = axes + group["side"] = side + if text.strip(): + group["text"] = text + group["props"] = props + + def _is_share_label_group_member(self, ax, axis): + """ + Return True if the axes belongs to any explicit label-sharing group. + """ + groups = self._share_label_groups.get(axis, {}) + return any(ax in group["axes"] for group in groups.values()) + + def _has_share_label_groups(self, axis): + """ + Return True if there are any explicit label-sharing groups for an axis. + """ + return bool(self._share_label_groups.get(axis, {})) + + def _clear_share_label_groups(self, axes=None, *, target=None): + """ + Clear explicit label-sharing groups, optionally filtered by axes. + """ + targets = ("x", "y") if target is None else (target,) + for axis in targets: + groups = self._share_label_groups.get(axis, {}) + if axes is None: + groups.clear() + continue + axes_set = {ax for ax in axes if ax is not None} + for key in list(groups): + if any(ax in axes_set for ax in groups[key]["axes"]): + del groups[key] + # Clear any existing spanning labels tied to these axes + if axis == "x": + for ax in axes_set: + if ax in self._supxlabel_dict: + self._supxlabel_dict[ax].set_text("") + else: + for ax in axes_set: + if ax in self._supylabel_dict: + self._supylabel_dict[ax].set_text("") + + def _apply_share_label_groups(self, axis=None): + """ + Apply explicit label-sharing groups, overriding default label sharing. + """ + + def _order_axes_for_side(axs, side): + if side in ("bottom", "top"): + key = ( + (lambda ax: ax._range_subplotspec("y")[1]) + if side == "bottom" + else (lambda ax: ax._range_subplotspec("y")[0]) + ) + reverse = side == "bottom" + else: + key = ( + (lambda ax: ax._range_subplotspec("x")[1]) + if side == "right" + else (lambda ax: ax._range_subplotspec("x")[0]) + ) + reverse = side == "right" + try: + return sorted(axs, key=key, reverse=reverse) + except Exception: + return list(axs) + + axes = (axis,) if axis in ("x", "y") else ("x", "y") + for target in axes: + groups = self._share_label_groups.get(target, {}) + for group in groups.values(): + axs = [ + ax for ax in group["axes"] if ax.figure is self and ax.get_visible() + ] + if len(axs) < 2: + continue + + side = group["side"] + ordered_axs = _order_axes_for_side(axs, side) + + # Refresh label text from any axis with non-empty text + label = None + for ax in ordered_axs: + candidate = getattr(ax, f"{target}axis").label + if candidate.get_text().strip(): + label = candidate + break + text = group["text"] + props = group["props"] + if label is not None: + text = label.get_text() + props = { + "color": label.get_color(), + "fontproperties": label.get_font_properties(), + "rotation": label.get_rotation(), + "rotation_mode": label.get_rotation_mode(), + "ha": label.get_ha(), + "va": label.get_va(), + } + group["text"] = text + group["props"] = props + + if not text: + continue + + try: + _, ax = self._get_align_coord( + side, ordered_axs, includepanels=self._includepanels + ) + except Exception: + continue + axlab = getattr(ax, f"{target}axis").label + axlab.set_text(text) + if props is not None: + axlab.set_color(props["color"]) + axlab.set_fontproperties(props["fontproperties"]) + axlab.set_rotation(props["rotation"]) + axlab.set_rotation_mode(props["rotation_mode"]) + axlab.set_ha(props["ha"]) + axlab.set_va(props["va"]) + self._update_axis_label(side, ordered_axs) + def _align_super_labels(self, side, renderer): """ Adjust the position of super labels. diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 63556ab0d..fadf0fa56 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1749,7 +1749,43 @@ def format(self, **kwargs): ultraplot.figure.Figure.format ultraplot.config.Configurator.context """ + # Implicit label sharing for subset format calls + share_xlabels = kwargs.get("share_xlabels", None) + share_ylabels = kwargs.get("share_ylabels", None) + xlabel = kwargs.get("xlabel", None) + ylabel = kwargs.get("ylabel", None) + if len(self) > 1: + if share_xlabels is False: + self.figure._clear_share_label_groups(self, target="x") + if share_ylabels is False: + self.figure._clear_share_label_groups(self, target="y") + if share_xlabels is None and xlabel is not None: + self.figure._register_share_label_group(self, target="x") + if share_ylabels is None and ylabel is not None: + self.figure._register_share_label_group(self, target="y") self.figure.format(axs=self, **kwargs) + # Refresh groups after labels are set + if len(self) > 1: + if share_xlabels is None and xlabel is not None: + self.figure._register_share_label_group(self, target="x") + if share_ylabels is None and ylabel is not None: + self.figure._register_share_label_group(self, target="y") + + def share_labels(self, *, axis="x"): + """ + Register an explicit label-sharing group for this subset. + """ + if not self: + return self + axis = axis.lower() + if axis in ("x", "y"): + self.figure._register_share_label_group(self, target=axis) + elif axis in ("both", "all", "xy"): + self.figure._register_share_label_group(self, target="x") + self.figure._register_share_label_group(self, target="y") + else: + raise ValueError(f"Invalid axis={axis!r}. Options are 'x', 'y', or 'both'.") + return self @property def figure(self): From 01d31fc904f6f82d455e75e9d69e01d53b7cc236 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 30 Dec 2025 20:06:15 +1000 Subject: [PATCH 2/6] Add subset label sharing tests --- ultraplot/tests/test_subplots.py | 132 +++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index 86ed55a68..6cb81ff78 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -258,6 +258,138 @@ def test_axis_sharing(share): return fig +def test_subset_share_xlabels_override(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share="labels", span=False) + ax[0, 0].format(xlabel="Top-left X") + ax[0, 1].format(xlabel="Top-right X") + bottom = ax[1, :] + bottom[0].format(xlabel="Bottom-row X", share_xlabels=list(bottom)) + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_subset_share_xlabels_implicit(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share="labels", span=False) + ax[0, 0].format(xlabel="Top-left X") + ax[0, 1].format(xlabel="Top-right X") + bottom = ax[1, :] + bottom.format(xlabel="Bottom-row X") + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_subset_share_ylabels_override(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share="labels", span=False) + ax[0, 0].format(ylabel="Left-top Y") + ax[1, 0].format(ylabel="Left-bottom Y") + right = ax[:, 1] + right[0].format(ylabel="Right-column Y", share_ylabels=list(right)) + + fig.canvas.draw() + + assert ax[0, 0].yaxis.get_label().get_visible() + assert ax[0, 0].get_ylabel() == "Left-top Y" + assert ax[1, 0].yaxis.get_label().get_visible() + assert ax[1, 0].get_ylabel() == "Left-bottom Y" + assert right[0].get_ylabel().strip() == "" + assert right[1].get_ylabel().strip() == "" + assert any( + lab.get_text() == "Right-column Y" for lab in fig._supylabel_dict.values() + ) + + uplt.close(fig) + + +def test_subset_share_xlabels_implicit_column(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + right = ax[:, 1] + right.format(xlabel="Right-column X") + + fig.canvas.draw() + + assert ax[0, 1].get_xlabel().strip() == "" + assert ax[1, 1].get_xlabel().strip() == "" + label_axes = [ + axi + for axi, lab in fig._supxlabel_dict.items() + if lab.get_text() == "Right-column X" + ] + assert label_axes and label_axes[0] is ax[1, 1] + + uplt.close(fig) + + +def test_subset_share_ylabels_implicit_row(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + top = ax[0, :] + top.format(ylabel="Top-row Y") + + fig.canvas.draw() + + assert ax[0, 0].get_ylabel().strip() == "" + assert ax[0, 1].get_ylabel().strip() == "" + label_axes = [ + axi for axi, lab in fig._supylabel_dict.items() if lab.get_text() == "Top-row Y" + ] + assert label_axes and label_axes[0] is ax[0, 0] + + uplt.close(fig) + + +def test_subset_share_xlabels_implicit_column_top(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + right = ax[:, 1] + right.format(xlabel="Right-column X (top)", xlabelloc="top") + + fig.canvas.draw() + + assert ax[0, 1].get_xlabel().strip() == "" + assert ax[1, 1].get_xlabel().strip() == "" + label_axes = [ + axi + for axi, lab in fig._supxlabel_dict.items() + if lab.get_text() == "Right-column X (top)" + ] + assert label_axes and label_axes[0] is ax[0, 1] + + uplt.close(fig) + + +def test_subset_share_ylabels_implicit_row_right(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + top = ax[0, :] + top.format(ylabel="Top-row Y (right)", ylabelloc="right") + + fig.canvas.draw() + + assert ax[0, 0].get_ylabel().strip() == "" + assert ax[0, 1].get_ylabel().strip() == "" + label_axes = [ + axi + for axi, lab in fig._supylabel_dict.items() + if lab.get_text() == "Top-row Y (right)" + ] + assert label_axes and label_axes[0] is ax[0, 1] + + uplt.close(fig) + + @pytest.mark.parametrize( "layout", [ From 9de2b6a24c0a7cb50af50ccb47c835957c651557 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 30 Dec 2025 20:06:26 +1000 Subject: [PATCH 3/6] Adjust geo subset label tests --- ultraplot/tests/test_geographic.py | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 9f1842d7b..f1efed6ec 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -407,6 +407,45 @@ def test_geo_panel_share_flag_controls_membership(): assert ax2[0]._panel_sharex_group is False +def test_geo_subset_share_xlabels_override(): + fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share="labels", span=False) + # GeoAxes.format does not accept xlabel/ylabel; set labels directly. + ax[0, 0].set_xlabel("Top-left X") + ax[0, 1].set_xlabel("Top-right X") + bottom = ax[1, :] + bottom[0].set_xlabel("Bottom-row X") + bottom.format(share_xlabels=list(bottom)) + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + +def test_geo_subset_share_xlabels_implicit(): + fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share="labels", span=False) + ax[0, 0].set_xlabel("Top-left X") + ax[0, 1].set_xlabel("Top-right X") + bottom = ax[1, :] + bottom[0].set_xlabel("Bottom-row X") + bottom.share_labels(axis="x") + + fig.canvas.draw() + + assert not ax[0, 0].xaxis.get_label().get_visible() + assert not ax[0, 1].xaxis.get_label().get_visible() + assert bottom[0].get_xlabel().strip() == "" + assert bottom[1].get_xlabel().strip() == "" + assert any(lab.get_text() == "Bottom-row X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + def test_geo_non_rectilinear_right_panel_forces_no_share_and_warns(): """ Non-rectilinear Geo projections should not allow panel sharing; adding a right panel From b7d4853db5dc88b7ac6ba18448d0e60525a414ed Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 30 Dec 2025 21:08:34 +1000 Subject: [PATCH 4/6] Limit implicit label sharing to subsets --- ultraplot/gridspec.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index fadf0fa56..288f1abc4 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1754,21 +1754,24 @@ def format(self, **kwargs): share_ylabels = kwargs.get("share_ylabels", None) xlabel = kwargs.get("xlabel", None) ylabel = kwargs.get("ylabel", None) + axes = [ax for ax in self if ax is not None] + all_axes = set(self.figure._subplot_dict.values()) + is_subset = bool(axes) and all_axes and set(axes) != all_axes if len(self) > 1: if share_xlabels is False: self.figure._clear_share_label_groups(self, target="x") if share_ylabels is False: self.figure._clear_share_label_groups(self, target="y") - if share_xlabels is None and xlabel is not None: + if is_subset and share_xlabels is None and xlabel is not None: self.figure._register_share_label_group(self, target="x") - if share_ylabels is None and ylabel is not None: + if is_subset and share_ylabels is None and ylabel is not None: self.figure._register_share_label_group(self, target="y") self.figure.format(axs=self, **kwargs) # Refresh groups after labels are set if len(self) > 1: - if share_xlabels is None and xlabel is not None: + if is_subset and share_xlabels is None and xlabel is not None: self.figure._register_share_label_group(self, target="x") - if share_ylabels is None and ylabel is not None: + if is_subset and share_ylabels is None and ylabel is not None: self.figure._register_share_label_group(self, target="y") def share_labels(self, *, axis="x"): From 704ebd31c99ed96c044c42547846d5c6e2851708 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 30 Dec 2025 21:24:56 +1000 Subject: [PATCH 5/6] Expand subset label sharing coverage --- ultraplot/tests/test_subplots.py | 68 ++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index 6cb81ff78..eb42c79fc 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -352,6 +352,74 @@ def test_subset_share_ylabels_implicit_row(): uplt.close(fig) +def test_subset_share_xlabels_clear(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + bottom = ax[1, :] + bottom.format(xlabel="Shared") + + fig.canvas.draw() + assert any(lab.get_text() == "Shared" for lab in fig._supxlabel_dict.values()) + + bottom.format(share_xlabels=False, xlabel="Unshared") + fig.canvas.draw() + + assert not any(lab.get_text() == "Shared" for lab in fig._supxlabel_dict.values()) + assert not any(lab.get_text() == "Unshared" for lab in fig._supxlabel_dict.values()) + assert bottom[0].get_xlabel() == "Unshared" + assert bottom[1].get_xlabel() == "Unshared" + + uplt.close(fig) + + +def test_subset_share_labels_method_both(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + right = ax[:, 1] + right[0].set_xlabel("Right-column X") + right[0].set_ylabel("Right-column Y") + right.share_labels(axis="both") + + fig.canvas.draw() + + assert right[0].get_xlabel().strip() == "" + assert right[1].get_xlabel().strip() == "" + assert right[0].get_ylabel().strip() == "" + assert right[1].get_ylabel().strip() == "" + assert any( + lab.get_text() == "Right-column X" for lab in fig._supxlabel_dict.values() + ) + assert any( + lab.get_text() == "Right-column Y" for lab in fig._supylabel_dict.values() + ) + + uplt.close(fig) + + +def test_subset_share_labels_invalid_axis(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + with pytest.raises(ValueError): + ax[:, 1].share_labels(axis="nope") + + uplt.close(fig) + + +def test_subset_share_xlabels_mixed_sides(): + fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) + ax[0, :].format(xlabelloc="top", share_xlabels=False) + ax[1, :].format(xlabelloc="bottom", share_xlabels=False) + ax[0, 0].set_xlabel("Top X") + ax[0, 1].set_xlabel("Top X") + ax[1, 0].set_xlabel("Bottom X") + ax[1, 1].set_xlabel("Bottom X") + ax[0, 0].format(share_xlabels=list(ax)) + + fig.canvas.draw() + + assert any(lab.get_text() == "Top X" for lab in fig._supxlabel_dict.values()) + assert any(lab.get_text() == "Bottom X" for lab in fig._supxlabel_dict.values()) + + uplt.close(fig) + + def test_subset_share_xlabels_implicit_column_top(): fig, ax = uplt.subplots(ncols=2, nrows=2, share=0, span=False) right = ax[:, 1] From 3f5e1347613d540e275e7348853e0bc49c742818 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 30 Dec 2025 21:28:08 +1000 Subject: [PATCH 6/6] dedup logic --- ultraplot/axes/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index c297563fc..01cc96d51 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -3151,10 +3151,7 @@ def _update_share_labels(self, axes=None, target="x"): if axes is False: self.figure._clear_share_label_groups([self], target=target) return - if not axes: - return - axes = list(axes) - if not axes: + if axes is None or not len(list(axes)): return # Convert indices to actual axes objects