Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ Our backwards-compatibility policy can be found [here](https://github.com/python
([#707](https://github.com/python-attrs/cattrs/issues/707) [#708](https://github.com/python-attrs/cattrs/pull/708))
- Enum handling has been optimized by switching to hook factories, improving performance especially for plain enums.
([#705](https://github.com/python-attrs/cattrs/pull/705))
- Fix `include_subclasses` when used with `configure_tagged_union` and classes using diamond inheritance.
- Fix {func}`cattrs.strategies.include_subclasses` when used with {func}`cattrs.strategies.configure_tagged_union` and classes using diamond inheritance.
([#685](https://github.com/python-attrs/cattrs/issues/685) [#713](https://github.com/python-attrs/cattrs/pull/713))
- Fix {func}`cattrs.strategies.configure_tagged_union` when used with recursive type aliases.
([#678](https://github.com/python-attrs/cattrs/issues/678) [#714](https://github.com/python-attrs/cattrs/pull/714))

## 25.3.0 (2025-10-07)

Expand Down
89 changes: 45 additions & 44 deletions src/cattrs/strategies/_unions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,60 +52,20 @@ def configure_tagged_union(
if is_type_alias(union):
union = union.__value__
args = union.__args__

tag_to_hook = {}
exact_cl_unstruct_hooks = {}
for cl in args:
tag = tag_generator(cl)
struct_handler = converter.get_structure_hook(cl)
unstruct_handler = converter.get_unstructure_hook(cl)

def structure_union_member(val: dict, _cl=cl, _h=struct_handler) -> cl:
return _h(val, _cl)

def unstructure_union_member(val: union, _h=unstruct_handler) -> dict:
return _h(val)

tag_to_hook[tag] = structure_union_member
exact_cl_unstruct_hooks[cl] = unstructure_union_member

cl_to_tag = {cl: tag_generator(cl) for cl in args}
cl_to_tag = {}

if default is not NOTHING:
default_handler = converter.get_structure_hook(default)

def structure_default(val: dict, _cl=default, _h=default_handler):
return _h(val, _cl)

tag_to_hook = defaultdict(lambda: structure_default, tag_to_hook)
cl_to_tag = defaultdict(lambda: default, cl_to_tag)
tag_to_hook = defaultdict(lambda: structure_default)
cl_to_tag = defaultdict(lambda: default)

def unstructure_tagged_union(
val: union,
_exact_cl_unstruct_hooks=exact_cl_unstruct_hooks,
_cl_to_tag=cl_to_tag,
_tag_name=tag_name,
) -> dict:
res = _exact_cl_unstruct_hooks[val.__class__](val)
res[_tag_name] = _cl_to_tag[val.__class__]
return res

if default is NOTHING:
if getattr(converter, "forbid_extra_keys", False):

def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
val = val.copy()
return _tag_to_cl[val.pop(_tag_name)](val)

else:

def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
return _tag_to_cl[val[_tag_name]](val)

else:
if getattr(converter, "forbid_extra_keys", False):

def structure_tagged_union(
Expand Down Expand Up @@ -135,9 +95,50 @@ def structure_tagged_union(
return _tag_to_hook[val[_tag_name]](val)
return _dh(val, _default)

else:
if getattr(converter, "forbid_extra_keys", False):

def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
val = val.copy()
return _tag_to_cl[val.pop(_tag_name)](val)

else:

def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
return _tag_to_cl[val[_tag_name]](val)

def unstructure_tagged_union(
val: union,
_exact_cl_unstruct_hooks=exact_cl_unstruct_hooks,
_cl_to_tag=cl_to_tag,
_tag_name=tag_name,
) -> dict:
res = _exact_cl_unstruct_hooks[val.__class__](val)
res[_tag_name] = _cl_to_tag[val.__class__]
return res

converter.register_unstructure_hook(union, unstructure_tagged_union)
converter.register_structure_hook(union, structure_tagged_union)

for cl in args:
tag = tag_generator(cl)
struct_handler = converter.get_structure_hook(cl)
unstruct_handler = converter.get_unstructure_hook(cl)

def structure_union_member(val: dict, _cl=cl, _h=struct_handler) -> cl:
return _h(val, _cl)

def unstructure_union_member(val: union, _h=unstruct_handler) -> dict:
return _h(val)

tag_to_hook[tag] = structure_union_member
exact_cl_unstruct_hooks[cl] = unstructure_union_member
cl_to_tag[cl] = tag


def configure_union_passthrough(
union: Any, converter: BaseConverter, accept_ints_as_floats: bool = True
Expand Down
42 changes: 38 additions & 4 deletions tests/strategies/test_tagged_unions_695.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import pytest
from __future__ import annotations

from cattrs import BaseConverter
from attrs import define

from cattrs import BaseConverter, Converter
from cattrs.strategies import configure_tagged_union

from .._compat import is_py312_plus
from .test_tagged_unions import A, B


@pytest.mark.skipif(not is_py312_plus, reason="New type alias syntax")
def test_type_alias(converter: BaseConverter):
"""Type aliases to unions also work."""
type AOrB = A | B
Expand All @@ -19,3 +19,37 @@ def test_type_alias(converter: BaseConverter):

assert converter.structure({"_type": "A", "a": 1}, AOrB) == A(1)
assert converter.structure({"_type": "B", "a": 1}, AOrB) == B("1")


@define
class Lit:
value: float


@define
class Add:
left: Expr
right: Expr


type Expr = Add | Lit


def test_recursive_type_alias(genconverter: Converter):
"""Recursive type aliases to unions also work.

Only tests on the GenConverter since the BaseConverter doesn't support
stringified annotations.
"""

configure_tagged_union(Expr, genconverter)

val = Add(Lit(1.0), Lit(2.0))
expected = {
"_type": "Add",
"left": {"_type": "Lit", "value": 1.0},
"right": {"_type": "Lit", "value": 2.0},
}

assert genconverter.unstructure(val, Expr) == expected
assert genconverter.structure(expected, Expr) == val
Loading