From 3e12319853338d9c22c7ce925afa2138aa75c079 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 16 Jan 2026 17:25:01 +0100 Subject: [PATCH 1/6] add Mooncake extension --- Project.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ff890157a..7230feece 100644 --- a/Project.toml +++ b/Project.toml @@ -23,12 +23,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] TensorKitAdaptExt = "Adapt" TensorKitCUDAExt = ["CUDA", "cuTENSOR"] TensorKitChainRulesCoreExt = "ChainRulesCore" TensorKitFiniteDifferencesExt = "FiniteDifferences" +TensorKitMooncakeExt = "Mooncake" [compat] Adapt = "4" @@ -70,6 +72,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -78,4 +81,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] -test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"] +test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"] From 0779ec406427effa3288747bbd8fba8cd31773d6 Mon Sep 17 00:00:00 
2001 From: Lukas Devos Date: Fri, 16 Jan 2026 17:25:08 +0100 Subject: [PATCH 2/6] start adding some rules --- .../TensorKitMooncakeExt.jl | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl new file mode 100644 index 000000000..f79f5143e --- /dev/null +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -0,0 +1,158 @@ +module TensorKitMooncakeExt + +using Mooncake +using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal +using TensorKit +using TensorOperations: TensorOperations, tensorcontract!, IndexTuple, Index2Tuple, linearize +import TensorOperations as TO +using VectorInterface: One, Zero +using TupleTools + +# Ignore derivatives +# ------------------ +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} + +_needs_tangent(x) = _needs_tangent(typeof(x)) +_needs_tangent(::Type{<:Number}) = true +_needs_tangent(::Type{<:Integer}) = false +_needs_tangent(::Type{<:Union{One, Zero}}) = false + + +function Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) + A = Mooncake.primal(A_dA) + dA_fw = Mooncake.tangent(A_dA) + data = dA_fw.data.data + dA = typeof(A)(data, A.space) + return A, dA +end +trivtuple(N) = ntuple(identity, N) +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, 
t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensorcontract!), AbstractTensorMap, AbstractTensorMap, Index2Tuple, Bool, AbstractTensorMap, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}} +function Mooncake.rrule!!( + ::CoDual{typeof(tensorcontract!)}, + C_dC::CoDual{<:AbstractTensorMap{TC}}, + A_dA::CoDual{<:AbstractTensorMap{TA}}, pA_dpA::CoDual{<:Index2Tuple}, conjA_dconjA::CoDual{Bool}, + B_dB::CoDual{<:AbstractTensorMap{TB}}, pB_dpB::CoDual{<:Index2Tuple}, conjB_dconjB::CoDual{Bool}, + pAB_dpAB::CoDual{<:Index2Tuple}, + α_dα::CoDual{Tα}, β_dβ::CoDual{Tβ}, + ba_dba::CoDual..., + ) where {Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + C, ΔC = arrayify(C_dC) + A, ΔA = arrayify(A_dA) + B, ΔB = arrayify(B_dB) + pA = primal(pA_dpA) + pB = primal(pB_dpB) + pAB = primal(pAB_dpAB) + conjA = primal(conjA_dconjA) + conjB = primal(conjB_dconjB) + α = primal(α_dα) + β = primal(β_dβ) + ba = primal.(ba_dba) + C_cache = copy(C) + TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) + + function tensorcontract_pullback(::NoRData) + copy!(C, C_cache) + if Tα == Zero && Tβ == Zero + scale!(ΔC, zero(TC)) + return ntuple(i -> NoRData(), 11 + length(ba)) + end + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + + # dC + if β === Zero() + scale!(ΔC, β) + else + scale!(ΔC, conj(β)) + end + + # dA + ipA = _repartition(invperm(linearize(pA)), A) + conjΔC = conjA + conjB′ = conjA ? conjB : !conjB + # TODO: allocator + tB = twist( + B, + TupleTools.vcat( + filter(x -> !isdual(space(B, x)), pB[1]), + filter(x -> isdual(space(B, x)), pB[2]) + ); copy = false + ) + tensorcontract!( + ΔA, + ΔC, pΔC, conjΔC, + tB, reverse(pB), conjB′, + ipA, + conjA ? α : conj(α), Zero(), ba... + ) + + # dB + ipB = _repartition(invperm(linearize(pB)), B) + conjΔC = conjB + conjA′ = conjB ? 
conjA : !conjA + # TODO: allocator + tA = twist( + A, + TupleTools.vcat( + filter(x -> isdual(space(A, x)), pA[1]), + filter(x -> !isdual(space(A, x)), pA[2]) + ); copy = false + ) + tensorcontract!( + ΔB, + tA, reverse(pA), conjA′, + ΔC, pΔC, conjΔC, + ipB, + conjB ? α : conj(α), Zero(), ba... + ) + + dα = if _needs_tangent(Tα) + AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + Mooncake._rdata(inner(AB, ΔC)) + else + NoRData() + end + dβ = if _needs_tangent(Tβ) + # TODO: consider using `inner` + Mooncake._rdata(inner(C, ΔC)) + else + NoRData() + end + + return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... + end + return C_dC, tensorcontract_pullback +end + + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real} +function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real}) + t, Δt = arrayify(tΔt) + p = primal(pdp) + p == 2 || error("currently only implemented for p = 2") + n = norm(t, p) + function norm_pullback(Δn) + x = (Δn' + Δn) / 2 / hypot(n, eps(one(n))) + add!(Δt, t, x) + return NoRData(), NoRData(), NoRData() + end + return CoDual(n, Mooncake.NoFData()), norm_pullback +end + +end From db467bd9191dba7c554e4cd02b485a9e9af404d3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 16 Jan 2026 21:36:22 +0100 Subject: [PATCH 3/6] reorganize mooncake extension --- .../TensorKitMooncakeExt.jl | 149 +----------------- ext/TensorKitMooncakeExt/linalg.jl | 14 ++ ext/TensorKitMooncakeExt/tangent.jl | 7 + ext/TensorKitMooncakeExt/tensoroperations.jl | 98 ++++++++++++ ext/TensorKitMooncakeExt/utility.jl | 28 ++++ 5 files changed, 151 insertions(+), 145 deletions(-) create mode 100644 ext/TensorKitMooncakeExt/linalg.jl create mode 100644 ext/TensorKitMooncakeExt/tangent.jl create mode 100644 ext/TensorKitMooncakeExt/tensoroperations.jl create mode 100644 
ext/TensorKitMooncakeExt/utility.jl diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index f79f5143e..57074ede1 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -8,151 +8,10 @@ import TensorOperations as TO using VectorInterface: One, Zero using TupleTools -# Ignore derivatives -# ------------------ -@zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} -_needs_tangent(x) = _needs_tangent(typeof(x)) -_needs_tangent(::Type{<:Number}) = true -_needs_tangent(::Type{<:Integer}) = false -_needs_tangent(::Type{<:Union{One, Zero}}) = false - - -function Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) - A = Mooncake.primal(A_dA) - dA_fw = Mooncake.tangent(A_dA) - data = dA_fw.data.data - dA = typeof(A)(data, A.space) - return A, dA -end -trivtuple(N) = ntuple(identity, N) -Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) - length(p) >= N₁ || - throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) - return TupleTools.getindices(p, trivtuple(N₁)), - TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) -end -Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) - return _repartition(linearize(p), N₁) -end -function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} - return _repartition(p, N₁) -end -function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) - return _repartition(p, TensorKit.numout(t)) -end - -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensorcontract!), AbstractTensorMap, AbstractTensorMap, Index2Tuple, Bool, AbstractTensorMap, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}} -function Mooncake.rrule!!( - ::CoDual{typeof(tensorcontract!)}, - C_dC::CoDual{<:AbstractTensorMap{TC}}, - A_dA::CoDual{<:AbstractTensorMap{TA}}, pA_dpA::CoDual{<:Index2Tuple}, 
conjA_dconjA::CoDual{Bool}, - B_dB::CoDual{<:AbstractTensorMap{TB}}, pB_dpB::CoDual{<:Index2Tuple}, conjB_dconjB::CoDual{Bool}, - pAB_dpAB::CoDual{<:Index2Tuple}, - α_dα::CoDual{Tα}, β_dβ::CoDual{Tβ}, - ba_dba::CoDual..., - ) where {Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} - C, ΔC = arrayify(C_dC) - A, ΔA = arrayify(A_dA) - B, ΔB = arrayify(B_dB) - pA = primal(pA_dpA) - pB = primal(pB_dpB) - pAB = primal(pAB_dpAB) - conjA = primal(conjA_dconjA) - conjB = primal(conjB_dconjB) - α = primal(α_dα) - β = primal(β_dβ) - ba = primal.(ba_dba) - C_cache = copy(C) - TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) - - function tensorcontract_pullback(::NoRData) - copy!(C, C_cache) - if Tα == Zero && Tβ == Zero - scale!(ΔC, zero(TC)) - return ntuple(i -> NoRData(), 11 + length(ba)) - end - ipAB = invperm(linearize(pAB)) - pΔC = _repartition(ipAB, TO.numout(pA)) - - # dC - if β === Zero() - scale!(ΔC, β) - else - scale!(ΔC, conj(β)) - end - - # dA - ipA = _repartition(invperm(linearize(pA)), A) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB - # TODO: allocator - tB = twist( - B, - TupleTools.vcat( - filter(x -> !isdual(space(B, x)), pB[1]), - filter(x -> isdual(space(B, x)), pB[2]) - ); copy = false - ) - tensorcontract!( - ΔA, - ΔC, pΔC, conjΔC, - tB, reverse(pB), conjB′, - ipA, - conjA ? α : conj(α), Zero(), ba... - ) - - # dB - ipB = _repartition(invperm(linearize(pB)), B) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA - # TODO: allocator - tA = twist( - A, - TupleTools.vcat( - filter(x -> isdual(space(A, x)), pA[1]), - filter(x -> !isdual(space(A, x)), pA[2]) - ); copy = false - ) - tensorcontract!( - ΔB, - tA, reverse(pA), conjA′, - ΔC, pΔC, conjΔC, - ipB, - conjB ? α : conj(α), Zero(), ba... - ) - - dα = if _needs_tangent(Tα) - AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) 
- Mooncake._rdata(inner(AB, ΔC)) - else - NoRData() - end - dβ = if _needs_tangent(Tβ) - # TODO: consider using `inner` - Mooncake._rdata(inner(C, ΔC)) - else - NoRData() - end - - return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... - end - return C_dC, tensorcontract_pullback -end - - -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real} -function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real}) - t, Δt = arrayify(tΔt) - p = primal(pdp) - p == 2 || error("currently only implemented for p = 2") - n = norm(t, p) - function norm_pullback(Δn) - x = (Δn' + Δn) / 2 / hypot(n, eps(one(n))) - add!(Δt, t, x) - return NoRData(), NoRData(), NoRData() - end - return CoDual(n, Mooncake.NoFData()), norm_pullback -end +include("utility.jl") +include("tangent.jl") +include("linalg.jl") +include("tensoroperations.jl") end diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl new file mode 100644 index 000000000..56533d227 --- /dev/null +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -0,0 +1,14 @@ +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real} + +function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real}) + t, Δt = arrayify(tΔt) + p = primal(pdp) + p == 2 || error("currently only implemented for p = 2") + n = norm(t, p) + function norm_pullback(Δn) + x = (Δn' + Δn) / 2 / hypot(n, eps(one(n))) + add!(Δt, t, x) + return NoRData(), NoRData(), NoRData() + end + return CoDual(n, Mooncake.NoFData()), norm_pullback +end diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl new file mode 100644 index 000000000..761e626f0 --- /dev/null +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -0,0 +1,7 @@ +function Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) + A = 
Mooncake.primal(A_dA) + dA_fw = Mooncake.tangent(A_dA) + data = dA_fw.data.data + dA = typeof(A)(data, A.space) + return A, dA +end diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl new file mode 100644 index 000000000..f06d4d345 --- /dev/null +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -0,0 +1,98 @@ +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensorcontract!), AbstractTensorMap, AbstractTensorMap, Index2Tuple, Bool, AbstractTensorMap, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}} + +function Mooncake.rrule!!( + ::CoDual{typeof(tensorcontract!)}, + C_dC::CoDual{<:AbstractTensorMap{TC}}, + A_dA::CoDual{<:AbstractTensorMap{TA}}, pA_dpA::CoDual{<:Index2Tuple}, conjA_dconjA::CoDual{Bool}, + B_dB::CoDual{<:AbstractTensorMap{TB}}, pB_dpB::CoDual{<:Index2Tuple}, conjB_dconjB::CoDual{Bool}, + pAB_dpAB::CoDual{<:Index2Tuple}, + α_dα::CoDual{Tα}, β_dβ::CoDual{Tβ}, + ba_dba::CoDual..., + ) where {Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + C, ΔC = arrayify(C_dC) + A, ΔA = arrayify(A_dA) + B, ΔB = arrayify(B_dB) + pA = primal(pA_dpA) + pB = primal(pB_dpB) + pAB = primal(pAB_dpAB) + conjA = primal(conjA_dconjA) + conjB = primal(conjB_dconjB) + α = primal(α_dα) + β = primal(β_dβ) + ba = primal.(ba_dba) + C_cache = copy(C) + TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) + + function tensorcontract_pullback(::NoRData) + copy!(C, C_cache) + if Tα == Zero && Tβ == Zero + scale!(ΔC, zero(TC)) + return ntuple(i -> NoRData(), 11 + length(ba)) + end + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + + # dC + if β === Zero() + scale!(ΔC, β) + else + scale!(ΔC, conj(β)) + end + + # dA + ipA = _repartition(invperm(linearize(pA)), A) + conjΔC = conjA + conjB′ = conjA ? 
conjB : !conjB + # TODO: allocator + tB = twist( + B, + TupleTools.vcat( + filter(x -> !isdual(space(B, x)), pB[1]), + filter(x -> isdual(space(B, x)), pB[2]) + ); copy = false + ) + tensorcontract!( + ΔA, + ΔC, pΔC, conjΔC, + tB, reverse(pB), conjB′, + ipA, + conjA ? α : conj(α), Zero(), ba... + ) + + # dB + ipB = _repartition(invperm(linearize(pB)), B) + conjΔC = conjB + conjA′ = conjB ? conjA : !conjA + # TODO: allocator + tA = twist( + A, + TupleTools.vcat( + filter(x -> isdual(space(A, x)), pA[1]), + filter(x -> !isdual(space(A, x)), pA[2]) + ); copy = false + ) + tensorcontract!( + ΔB, + tA, reverse(pA), conjA′, + ΔC, pΔC, conjΔC, + ipB, + conjB ? α : conj(α), Zero(), ba... + ) + + dα = if _needs_tangent(Tα) + AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + Mooncake._rdata(inner(AB, ΔC)) + else + NoRData() + end + dβ = if _needs_tangent(Tβ) + # TODO: consider using `inner` + Mooncake._rdata(inner(C, ΔC)) + else + NoRData() + end + + return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... 
+ end + return C_dC, tensorcontract_pullback +end diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl new file mode 100644 index 000000000..ca2c79b54 --- /dev/null +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -0,0 +1,28 @@ +_needs_tangent(x) = _needs_tangent(typeof(x)) +_needs_tangent(::Type{<:Number}) = true +_needs_tangent(::Type{<:Integer}) = false +_needs_tangent(::Type{<:Union{One, Zero}}) = false + +# IndexTuple utility +# ------------------ +trivtuple(N) = ntuple(identity, N) + +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +# Ignore derivatives +# ------------------ +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} From c078b010fe04c9509be3e3dd564b17664239b9e3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 16 Jan 2026 21:43:02 +0100 Subject: [PATCH 4/6] reorganize tensorcontract_pullback --- .../TensorKitMooncakeExt.jl | 2 +- ext/TensorKitMooncakeExt/tensoroperations.jl | 211 +++++++++++------- 2 files changed, 126 insertions(+), 87 deletions(-) diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index 57074ede1..b35c73f4c 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -3,7 +3,7 @@ module TensorKitMooncakeExt using Mooncake using Mooncake: 
@zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal using TensorKit -using TensorOperations: TensorOperations, tensorcontract!, IndexTuple, Index2Tuple, linearize +using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize import TensorOperations as TO using VectorInterface: One, Zero using TupleTools diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index f06d4d345..d663a3281 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -1,98 +1,137 @@ -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensorcontract!), AbstractTensorMap, AbstractTensorMap, Index2Tuple, Bool, AbstractTensorMap, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}} +Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TO.tensorcontract!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, Bool, + AbstractTensorMap, Index2Tuple, Bool, + Index2Tuple, + Number, Number, + Vararg{Any}, + } +) function Mooncake.rrule!!( - ::CoDual{typeof(tensorcontract!)}, - C_dC::CoDual{<:AbstractTensorMap{TC}}, - A_dA::CoDual{<:AbstractTensorMap{TA}}, pA_dpA::CoDual{<:Index2Tuple}, conjA_dconjA::CoDual{Bool}, - B_dB::CoDual{<:AbstractTensorMap{TB}}, pB_dpB::CoDual{<:Index2Tuple}, conjB_dconjB::CoDual{Bool}, - pAB_dpAB::CoDual{<:Index2Tuple}, - α_dα::CoDual{Tα}, β_dβ::CoDual{Tβ}, - ba_dba::CoDual..., - ) where {Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} - C, ΔC = arrayify(C_dC) - A, ΔA = arrayify(A_dA) - B, ΔB = arrayify(B_dB) - pA = primal(pA_dpA) - pB = primal(pB_dpB) - pAB = primal(pAB_dpAB) - conjA = primal(conjA_dconjA) - conjB = primal(conjB_dconjB) - α = primal(α_dα) - β = primal(β_dβ) - ba = primal.(ba_dba) + ::CoDual{typeof(TO.tensorcontract!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, + 
B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, conjB_ΔconjB::CoDual{Bool}, + pAB_ΔpAB::CoDual{<:Index2Tuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + ba_Δba::CoDual..., + ) + # prepare arguments + (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) + pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB)) + conjA, conjB = primal.((conjA_ΔconjA, conjB_ΔconjB)) + α, β = primal.((α_Δα, β_Δβ)) + ba = primal.(ba_Δba) + + # primal call C_cache = copy(C) - TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) + TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) function tensorcontract_pullback(::NoRData) copy!(C, C_cache) - if Tα == Zero && Tβ == Zero - scale!(ΔC, zero(TC)) - return ntuple(i -> NoRData(), 11 + length(ba)) - end - ipAB = invperm(linearize(pAB)) - pΔC = _repartition(ipAB, TO.numout(pA)) - - # dC - if β === Zero() - scale!(ΔC, β) - else - scale!(ΔC, conj(β)) - end - - # dA - ipA = _repartition(invperm(linearize(pA)), A) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB - # TODO: allocator - tB = twist( - B, - TupleTools.vcat( - filter(x -> !isdual(space(B, x)), pB[1]), - filter(x -> isdual(space(B, x)), pB[2]) - ); copy = false - ) - tensorcontract!( - ΔA, - ΔC, pΔC, conjΔC, - tB, reverse(pB), conjB′, - ipA, - conjA ? α : conj(α), Zero(), ba... - ) - # dB - ipB = _repartition(invperm(linearize(pB)), B) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA - # TODO: allocator - tA = twist( - A, - TupleTools.vcat( - filter(x -> isdual(space(A, x)), pA[1]), - filter(x -> !isdual(space(A, x)), pA[2]) - ); copy = false + ΔAr = tensorcontract_pullback_ΔA!( + ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... ) - tensorcontract!( - ΔB, - tA, reverse(pA), conjA′, - ΔC, pΔC, conjΔC, - ipB, - conjB ? α : conj(α), Zero(), ba... + ΔBr = tensorcontract_pullback_ΔB!( + ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... 
+ ) + Δαr = tensorcontract_pullback_Δα( + ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + ) + Δβr = tensorcontract_pullback_Δβ(ΔC, C, β) + ΔCr = tensorcontract_pullback_ΔC!(ΔC, β) # must run last: the pullbacks above need the unscaled output cotangent - dα = if _needs_tangent(Tα) - AB = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - Mooncake._rdata(inner(AB, ΔC)) - else - NoRData() - end - dβ = if _needs_tangent(Tβ) - # TODO: consider using `inner` - Mooncake._rdata(inner(C, ΔC)) - else - NoRData() - end - - return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... + return NoRData(), ΔCr, + ΔAr, NoRData(), NoRData(), + ΔBr, NoRData(), NoRData(), + NoRData(), + Δαr, Δβr, + map(ba_ -> NoRData(), ba)... end - return C_dC, tensorcontract_pullback + + return C_ΔC, tensorcontract_pullback +end + +tensorcontract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) + +function tensorcontract_pullback_ΔA!( + ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + ) + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + ipA = _repartition(invperm(linearize(pA)), A) + conjΔC = conjA + conjB′ = conjA ? conjB : !conjB + + tB = twist( + B, + TupleTools.vcat( + filter(x -> !isdual(space(B, x)), pB[1]), + filter(x -> isdual(space(B, x)), pB[2]) + ); copy = false + ) + + TO.tensorcontract!( + ΔA, + ΔC, pΔC, conjΔC, + tB, reverse(pB), conjB′, + ipA, + conjA ? α : conj(α), One(), # One(): accumulate into ΔA instead of overwriting it + ba... + ) + + return NoRData() +end + +function tensorcontract_pullback_ΔB!( + ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + ) + ipAB = invperm(linearize(pAB)) + pΔC = _repartition(ipAB, TO.numout(pA)) + ipB = _repartition(invperm(linearize(pB)), B) + conjΔC = conjB + conjA′ = conjB ? 
conjA : !conjA + + tA = twist( + A, + TupleTools.vcat( + filter(x -> isdual(space(A, x)), pA[1]), + filter(x -> !isdual(space(A, x)), pA[2]) + ); copy = false + ) + + TO.tensorcontract!( + ΔB, + tA, reverse(pA), conjA′, + ΔC, pΔC, conjΔC, + ipB, + conjB ? α : conj(α), One(), ba... # One(): accumulate into ΔB instead of overwriting it 
+ ) + + return NoRData() +end + +function tensorcontract_pullback_Δα( + ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + ) + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Tdα === NoRData && return NoRData() + + AB = TO.tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + Δα = inner(AB, ΔC) + return Mooncake._rdata(Δα) +end + +function tensorcontract_pullback_Δβ(ΔC, C, β) + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Tdβ === NoRData && return NoRData() + + Δβ = inner(C, ΔC) + return Mooncake._rdata(Δβ) end From 5f5144579cff026db58278b2fde376e7e29a077d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 16 Jan 2026 22:48:38 +0100 Subject: [PATCH 5/6] add tests --- test/autodiff/{ad.jl => chainrules.jl} | 0 test/autodiff/mooncake.jl | 117 +++++++++++++++++++++++++ 2 files changed, 117 insertions(+) rename test/autodiff/{ad.jl => chainrules.jl} (100%) create mode 100644 test/autodiff/mooncake.jl diff --git a/test/autodiff/ad.jl b/test/autodiff/chainrules.jl similarity index 100% rename from test/autodiff/ad.jl rename to test/autodiff/chainrules.jl diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl new file mode 100644 index 000000000..1cd74fa27 --- /dev/null +++ b/test/autodiff/mooncake.jl @@ -0,0 +1,117 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Mooncake +using Random + +mode = Mooncake.ReverseMode +rng = Random.default_rng() +is_primitive = false + +function randindextuple(N::Int, k::Int = rand(0:N)) + @assert 0 ≤ k ≤ N + _p = randperm(N) + return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...)) +end + +const _repartition = @static if isdefined(Base, :get_extension) + Base.get_extension(TensorKit, :TensorKitMooncakeExt)._repartition +else + TensorKit.TensorKitMooncakeExt._repartition +end + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + 
Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) + +for V in spacelist + I = sectortype(eltype(V)) + Istr = TensorKit.type_repr(I) + + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + println("---------------------------------------") + println("Mooncake with symmetry: $Istr") + println("---------------------------------------") + eltypes = (Float64,) # no complex support yet + symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) + + @timedtestset "tensorcontract!" begin + for _ in 1:5 + d = 0 + local V1, V2, V3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? 
v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 0 && break + end + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) + + α = randn(T) + β = randn(T) + V2_conj = prod(conj, V2; init = one(V[1])) + + for conjA in (false, true), conjB in (false, true) + A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA)) + B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, conjA, B, pB, conjB, pAB, Val(false) + ) + ) + Mooncake.TestUtils.test_rule( + rng, tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β; + atol, rtol, mode, is_primitive + ) + + end + end + end + end +end From 134ddd50ded1f180ee59bf9c480940e7f78a45ee Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 17 Jan 2026 07:13:03 +0100 Subject: [PATCH 6/6] Add Mooncake compat --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 7230feece..934bdb6ed 100644 --- a/Project.toml +++ b/Project.toml @@ -45,6 +45,7 @@ GPUArrays = "11.3.1" LRUCache = "1.0.2" LinearAlgebra = "1" MatrixAlgebraKit = "0.6.2" +Mooncake = "0.4.183" OhMyThreads = "0.8.0" Printf = "1" Random = "1"