diff --git a/HISTORY.md b/HISTORY.md index d0643cae..6265ac0f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -31,8 +31,10 @@ Our backwards-compatibility policy can be found [here](https://github.com/python ([#707](https://github.com/python-attrs/cattrs/issues/707) [#708](https://github.com/python-attrs/cattrs/pull/708)) - Enum handling has been optimized by switching to hook factories, improving performance especially for plain enums. ([#705](https://github.com/python-attrs/cattrs/pull/705)) -- Fix `include_subclasses` when used with `configure_tagged_union` and classes using diamond inheritance. +- Fix {func}`cattrs.strategies.include_subclasses` when used with {func}`cattrs.strategies.configure_tagged_union` and classes using diamond inheritance. ([#685](https://github.com/python-attrs/cattrs/issues/685) [#713](https://github.com/python-attrs/cattrs/pull/713)) +- Fix {func}`cattrs.strategies.configure_tagged_union` when used with recursive type aliases. + ([#678](https://github.com/python-attrs/cattrs/issues/678) [#714](https://github.com/python-attrs/cattrs/pull/714)) ## 25.3.0 (2025-10-07) diff --git a/src/cattrs/strategies/_unions.py b/src/cattrs/strategies/_unions.py index 7cc06d81..14add983 100644 --- a/src/cattrs/strategies/_unions.py +++ b/src/cattrs/strategies/_unions.py @@ -52,23 +52,10 @@ def configure_tagged_union( if is_type_alias(union): union = union.__value__ args = union.__args__ + tag_to_hook = {} exact_cl_unstruct_hooks = {} - for cl in args: - tag = tag_generator(cl) - struct_handler = converter.get_structure_hook(cl) - unstruct_handler = converter.get_unstructure_hook(cl) - - def structure_union_member(val: dict, _cl=cl, _h=struct_handler) -> cl: - return _h(val, _cl) - - def unstructure_union_member(val: union, _h=unstruct_handler) -> dict: - return _h(val) - - tag_to_hook[tag] = structure_union_member - exact_cl_unstruct_hooks[cl] = unstructure_union_member - - cl_to_tag = {cl: tag_generator(cl) for cl in args} + cl_to_tag = {} if default is not NOTHING: default_handler = converter.get_structure_hook(default) @@ -76,36 +63,9 @@ def unstructure_union_member(val: union, _h=unstruct_handler) -> dict: def structure_default(val: dict, _cl=default, _h=default_handler): return _h(val, _cl) - tag_to_hook = defaultdict(lambda: structure_default, tag_to_hook) - cl_to_tag = defaultdict(lambda: default, cl_to_tag) + tag_to_hook = defaultdict(lambda: structure_default) + cl_to_tag = defaultdict(lambda: default) - def unstructure_tagged_union( - val: union, - _exact_cl_unstruct_hooks=exact_cl_unstruct_hooks, - _cl_to_tag=cl_to_tag, - _tag_name=tag_name, - ) -> dict: - res = _exact_cl_unstruct_hooks[val.__class__](val) - res[_tag_name] = _cl_to_tag[val.__class__] - return res - - if default is NOTHING: - if getattr(converter, "forbid_extra_keys", False): - - def structure_tagged_union( - val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name - ) -> union: - val = val.copy() - return _tag_to_cl[val.pop(_tag_name)](val) - - else: - - def structure_tagged_union( - val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name - ) -> union: - return _tag_to_cl[val[_tag_name]](val) - - else: if getattr(converter, "forbid_extra_keys", False): def structure_tagged_union( @@ -135,9 +95,50 @@ def structure_tagged_union( return _tag_to_hook[val[_tag_name]](val) return _dh(val, _default) + else: + if getattr(converter, "forbid_extra_keys", False): + + def structure_tagged_union( + val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name + ) -> union: + val = val.copy() + return _tag_to_cl[val.pop(_tag_name)](val) + + else: + + def structure_tagged_union( + val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name + ) -> union: + return _tag_to_cl[val[_tag_name]](val) + + def unstructure_tagged_union( + val: union, + _exact_cl_unstruct_hooks=exact_cl_unstruct_hooks, + _cl_to_tag=cl_to_tag, + _tag_name=tag_name, + ) -> dict: + res = _exact_cl_unstruct_hooks[val.__class__](val) + res[_tag_name] = _cl_to_tag[val.__class__] + return res + converter.register_unstructure_hook(union, unstructure_tagged_union) converter.register_structure_hook(union, structure_tagged_union) + for cl in args: + tag = tag_generator(cl) + struct_handler = converter.get_structure_hook(cl) + unstruct_handler = converter.get_unstructure_hook(cl) + + def structure_union_member(val: dict, _cl=cl, _h=struct_handler) -> cl: + return _h(val, _cl) + + def unstructure_union_member(val: union, _h=unstruct_handler) -> dict: + return _h(val) + + tag_to_hook[tag] = structure_union_member + exact_cl_unstruct_hooks[cl] = unstructure_union_member + cl_to_tag[cl] = tag + def configure_union_passthrough( union: Any, converter: BaseConverter, accept_ints_as_floats: bool = True diff --git a/tests/strategies/test_tagged_unions_695.py b/tests/strategies/test_tagged_unions_695.py index 34d6fcc8..42043a56 100644 --- a/tests/strategies/test_tagged_unions_695.py +++ b/tests/strategies/test_tagged_unions_695.py @@ -1,13 +1,13 @@ -import pytest +from __future__ import annotations -from cattrs import BaseConverter +from attrs import define + +from cattrs import BaseConverter, Converter from cattrs.strategies import configure_tagged_union -from .._compat import is_py312_plus from .test_tagged_unions import A, B -@pytest.mark.skipif(not is_py312_plus, reason="New type alias syntax") def test_type_alias(converter: BaseConverter): """Type aliases to unions also work.""" type AOrB = A | B @@ -19,3 +19,37 @@ def test_type_alias(converter: BaseConverter): assert converter.structure({"_type": "A", "a": 1}, AOrB) == A(1) assert converter.structure({"_type": "B", "a": 1}, AOrB) == B("1") + + +@define +class Lit: + value: float + + +@define +class Add: + left: Expr + right: Expr + + +type Expr = Add | Lit + + +def test_recursive_type_alias(genconverter: Converter): + """Recursive type aliases to unions also work. + + Only tests on the GenConverter since the BaseConverter doesn't support + stringified annotations. + """ + + configure_tagged_union(Expr, genconverter) + + val = Add(Lit(1.0), Lit(2.0)) + expected = { + "_type": "Add", + "left": {"_type": "Lit", "value": 1.0}, + "right": {"_type": "Lit", "value": 2.0}, + } + + assert genconverter.unstructure(val, Expr) == expected + assert genconverter.structure(expected, Expr) == val