diff --git a/Project.toml b/Project.toml index 934bdb6ed..11b378c4a 100644 --- a/Project.toml +++ b/Project.toml @@ -56,7 +56,7 @@ TensorKitSectors = "0.3.3" TensorOperations = "5.1" Test = "1" TestExtras = "0.2,0.3" -TupleTools = "1.1" +TupleTools = "1.5" VectorInterface = "0.4.8, 0.5" Zygote = "0.7" cuTENSOR = "2" diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index b35c73f4c..4c692adb9 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -1,17 +1,20 @@ module TensorKitMooncakeExt using Mooncake -using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal +using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoFData, NoRData, CoDual, arrayify, primal using TensorKit +import TensorKit as TK +using VectorInterface using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize import TensorOperations as TO -using VectorInterface: One, Zero using TupleTools - include("utility.jl") include("tangent.jl") include("linalg.jl") +include("indexmanipulations.jl") +include("vectorinterface.jl") include("tensoroperations.jl") +include("planaroperations.jl") end diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl new file mode 100644 index 000000000..464c18392 --- /dev/null +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -0,0 +1,409 @@ +for transform in (:permute, :transpose) + add_transform! = Symbol(:add_, transform, :!) 
+ add_transform_pullback = Symbol(add_transform!, :_pullback) + @eval Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TK.$add_transform!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, + Number, Number, Vararg{Any}, + } + ) + + @eval function Mooncake.rrule!!( + ::CoDual{typeof(TK.$add_transform!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + ba_Δba::CoDual... + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + α, β = primal.((α_Δα, β_Δβ)) + ba = primal.(ba_Δba) + + C_cache = copy(C) + + # if we need to compute Δa, it is faster to allocate an intermediate permuted A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Ap = if Tdα === NoRData + TK.$add_transform!(C, A, p, α, β, ba...) + nothing + else + Ap = $transform(A, p) + add!(C, Ap, α, β) + Ap + end + + function $add_transform_pullback(::NoRData) + copy!(C, C_cache) + + scale!(ΔC, conj(β)) + ΔCr = NoRData() + + # ΔA + ip = invperm(linearize(p)) + pΔA = _repartition(ip, A) + TK.$add_transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) + ΔAr = NoRData() + + # Δα + Δαr = if isnothing(Ap) + NoRData() + else + Mooncake._rdata(inner(Ap, ΔC)) + end + + # Δβ + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Δβr = if Tdβ === NoRData + NoRData() + else + Mooncake._rdata(inner(C, ΔC)) + end + + + return NoRData(), ΔCr, ΔAr, NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... 
+ end + + return C_ΔC, $add_transform_pullback + end +end + +Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TK.add_braid!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, IndexTuple, + Number, Number, Vararg{Any}, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TK.add_braid!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, levels_Δlevels::CoDual{<:IndexTuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + ba_Δba::CoDual... + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + levels = primal(levels_Δlevels) + α, β = primal.((α_Δα, β_Δβ)) + ba = primal.(ba_Δba) + + C_cache = copy(C) + + # if we need to compute Δa, it is faster to allocate an intermediate braided A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Ap = if Tdα === NoRData + TK.add_braid!(C, A, p, levels, α, β, ba...) + nothing + else + Ap = braid(A, p, levels) + add!(C, Ap, α, β) + Ap + end + + function add_braid!_pullback(::NoRData) + copy!(C, C_cache) + + scale!(ΔC, conj(β)) + ΔCr = NoRData() + + # ΔA + ip = invperm(linearize(p)) + pΔA = _repartition(ip, A) + ilevels = TupleTools.permute(levels, linearize(p)) + TK.add_braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) + ΔAr = NoRData() + + # Δα + Δαr = if isnothing(Ap) + NoRData() + else + Mooncake._rdata(inner(Ap, ΔC)) + end + + # Δβ + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Δβr = if Tdβ === NoRData + NoRData() + else + Mooncake._rdata(inner(C, ΔC)) + end + + + return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... 
+ end + + return C_ΔC, add_braid!_pullback +end + +# both are needed for correctly capturing every dispatch +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(twist!), AbstractTensorMap, Any} +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(twist!), AbstractTensorMap, Any} + +function Mooncake.rrule!!(::CoDual{typeof(twist!)}, t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = false + inds = primal(inds_Δinds) + + # primal call + t_cache = copy(t) + twist!(t, inds; inv) + + function twist_pullback(::NoRData) + copy!(t, t_cache) + twist!(Δt, inds; inv = !inv) + return ntuple(Returns(NoRData()), 3) + end + + return t_Δt, twist_pullback + +end +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{@NamedTuple{inv::Bool}}, ::CoDual{typeof(twist!)}, + t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual + ) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = primal(kwargs_Δkwargs).inv + inds = primal(inds_Δinds) + + # primal call + t_cache = copy(t) + twist!(t, inds; inv) + + function twist_pullback(::NoRData) + copy!(t, t_cache) + twist!(Δt, inds; inv = !inv) + return ntuple(Returns(NoRData()), 5) + end + + return t_Δt, twist_pullback +end + +# both are needed for correctly capturing every dispatch +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(flip), AbstractTensorMap, Any} +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(flip), AbstractTensorMap, Any} + +function Mooncake.rrule!!(::CoDual{typeof(flip)}, t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = false + inds = primal(inds_Δinds) + + # primal call + t_flipped = flip(t, inds; inv) + t_flipped_Δt_flipped = Mooncake.zero_fcodual(t_flipped) + _, Δt_flipped = arrayify(t_flipped_Δt_flipped) + + function twist_pullback(::NoRData) + 
copy!(Δt, flip(Δt_flipped, inds; inv = !inv)) + return ntuple(Returns(NoRData()), 3) + end + + return t_flipped_Δt_flipped, twist_pullback +end +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{@NamedTuple{inv::Bool}}, ::CoDual{typeof(flip)}, + t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual + ) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = primal(kwargs_Δkwargs).inv + inds = primal(inds_Δinds) + + # primal call + t_flipped = flip(t, inds; inv) + t_flipped_Δt_flipped = Mooncake.zero_fcodual(t_flipped) + _, Δt_flipped = arrayify(t_flipped_Δt_flipped) + + function twist_pullback(::NoRData) + copy!(Δt, flip(Δt_flipped, inds; inv = !inv)) + return ntuple(Returns(NoRData()), 5) + end + + return t_flipped_Δt_flipped, twist_pullback +end + +for insertunit in (:insertleftunit, :insertrightunit) + insertunit_pullback = Symbol(insertunit, :_pullback) + @eval begin + # both are needed for correctly capturing every dispatch + Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof($insertunit), AbstractTensorMap, Val} + Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof($insertunit), AbstractTensorMap, Val} + + function Mooncake.rrule!!(::CoDual{typeof($insertunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val}) + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + + # tdst shares data with tsrc if <:TensorMap, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap + tsrc_cache = copy(tsrc) + tdst = $insertunit(tsrc, ival) + # note: this is somewhat of a hack that makes use of the fact that the tangent is + # encoded without any information about the space, which allows us to simply reuse + # the tangent exactly without having to modify the space information + tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + else + tsrc_cache = nothing + tdst = $insertunit(tsrc, ival) + tdst_Δtdst 
= Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function $insertunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + copy!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! + end + return ntuple(Returns(NoRData()), 3) + end + + return tdst_Δtdst, $insertunit_pullback + end + function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{<:NamedTuple}, + ::CoDual{typeof($insertunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val} + ) + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + kwargs = primal(kwargs_Δkwargs) + + # tdst shares data with tsrc if <:TensorMap & copy=false, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap && !get(kwargs, :copy, false) + tsrc_cache = copy(tsrc) + tdst = $insertunit(tsrc, ival; kwargs...) + # note: this is somewhat of a hack that makes use of the fact that the tangent is + # encoded without any information about the space, which allows us to simply reuse + # the tangent exactly without having to modify the space information + tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + else + tsrc_cache = nothing + tdst = $insertunit(tsrc, ival; kwargs...) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function $insertunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + copy!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! 
+ end + return ntuple(Returns(NoRData()), 5) + end + + return tdst_Δtdst, $insertunit_pullback + end + end +end + + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(removeunit), AbstractTensorMap, Val} +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof(removeunit), AbstractTensorMap, Val} + +function Mooncake.rrule!!(::CoDual{typeof(removeunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{Val{i}}) where {i} + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + + # tdst shares data with tsrc if <:TensorMap, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap + tsrc_cache = copy(tsrc) + tdst = removeunit(tsrc, ival) + # note: this is somewhat of a hack that makes use of the fact that the tangent is + # encoded without any information about the space, which allows us to simply reuse + # the tangent exactly without having to modify the space information + tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + else + tsrc_cache = nothing + tdst = removeunit(tsrc, ival) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function removeunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + copy!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! 
+ end + return ntuple(Returns(NoRData()), 3) + end + + return tdst_Δtdst, removeunit_pullback +end +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{<:NamedTuple}, + ::CoDual{typeof(removeunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val} + ) + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + kwargs = primal(kwargs_Δkwargs) + + # tdst shares data with tsrc if <:TensorMap & copy=false, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap && !get(kwargs, :copy, false) + tsrc_cache = copy(tsrc) + tdst = removeunit(tsrc, ival; kwargs...) + # note: this is somewhat of a hack that makes use of the fact that the tangent is + # encoded without any information about the space, which allows us to simply reuse + # the tangent exactly without having to modify the space information + tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + else + tsrc_cache = nothing + tdst = removeunit(tsrc, ival; kwargs...) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function removeunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + copy!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! 
+ end + return ntuple(Returns(NoRData()), 5) + end + + return tdst_Δtdst, removeunit_pullback +end diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 56533d227..092ddf369 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -1,3 +1,42 @@ +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number} + +function Mooncake.rrule!!( + ::CoDual{typeof(mul!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number} + ) + (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) + α, β = primal.((α_Δα, β_Δβ)) + + # primal call + C_cache = copy(C) + AB = if _needs_tangent(α) + AB = A * B + add!(C, AB, α, β) + AB + else + mul!(C, A, B, α, β) + nothing + end + + function mul_pullback(::NoRData) + copy!(C, C_cache) + + scale!(ΔC, conj(β)) + mul!(ΔA, ΔC, B', conj(α), One()) + mul!(ΔB, A', ΔC, conj(α), One()) + ΔCr = NoRData() + ΔAr = NoRData() + ΔBr = NoRData() + Δαr = isnothing(AB) ? NoRData() : Mooncake._rdata(inner(AB, ΔC)) + Δβr = _needs_tangent(β) ? 
Mooncake._rdata(inner(C, ΔC)) : NoRData() + + return NoRData(), ΔCr, ΔAr, ΔBr, Δαr, Δβr + end + + return C_ΔC, mul_pullback +end + Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real} function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real}) @@ -12,3 +51,19 @@ function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorM end return CoDual(n, Mooncake.NoFData()), norm_pullback end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tr), AbstractTensorMap} + +function Mooncake.rrule!!(::CoDual{typeof(tr)}, A_ΔA::CoDual{<:AbstractTensorMap}) + A, ΔA = arrayify(A_ΔA) + trace = tr(A) + + function tr_pullback(Δtrace) + for (_, b) in blocks(ΔA) + TensorKit.diagview(b) .+= Δtrace + end + return NoRData(), NoRData() + end + + return CoDual(trace, Mooncake.NoFData()), tr_pullback +end diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl new file mode 100644 index 000000000..a480293af --- /dev/null +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -0,0 +1,88 @@ +# planartrace! 
+# ------------ +Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TensorKit.planartrace!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, Index2Tuple, + Number, Number, + Any, Any, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TensorKit.planartrace!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, q_Δq::CoDual{<:Index2Tuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + backend_Δbackend::CoDual, allocator_Δallocator::CoDual + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + q = primal(q_Δq) + α, β = primal.((α_Δα, β_Δβ)) + backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) + + # primal call + C_cache = copy(C) + TensorKit.planartrace!(C, A, p, q, α, β, backend, allocator) + + function planartrace_pullback(::NoRData) + copy!(C, C_cache) + + ΔAr = planartrace_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend, allocator) + Δαr = planartrace_pullback_Δα(ΔC, A, p, q, α, backend, allocator) + Δβr = planartrace_pullback_Δβ(ΔC, C, β) + ΔCr = planartrace_pullback_ΔC!(ΔC, β) + + return NoRData(), + ΔCr, ΔAr, NoRData(), NoRData(), + Δαr, Δβr, NoRData(), NoRData() + end + + return C_ΔC, planartrace_pullback +end + +planartrace_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) + +function planartrace_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend, allocator + ) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TensorKit.planarcontract!( + ΔA, ΔC, pΔC, E, pE, pdA, conj(α), One(), backend, allocator + ) + return NoRData() +end + +function planartrace_pullback_Δα( + ΔC, A, p, q, α, backend, allocator + ) + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Tdα === NoRData && return 
NoRData() + + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = TO.tensoralloc_add(scalartype(A), A, p, false, Val(true), allocator) + TensorKit.planartrace!(At, A, p, q, false, One(), backend, allocator) + Δα = inner(At, ΔC) + TO.tensorfree!(At, allocator) + return Mooncake._rdata(Δα) +end + +function planartrace_pullback_Δβ(ΔC, C, β) + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Tdβ === NoRData && return NoRData() + + Δβ = inner(C, ΔC) + return Mooncake._rdata(Δβ) +end diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 761e626f0..9fa6e401a 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -5,3 +5,11 @@ function Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) dA = typeof(A)(data, A.space) return A, dA end + +function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TensorKit.AdjointTensorMap}) + Aᴴ = Mooncake.primal(Aᴴ_ΔAᴴ) + ΔAᴴ = Mooncake.tangent(Aᴴ_ΔAᴴ) + A_ΔA = CoDual(Aᴴ', ΔAᴴ.data.parent) + A, ΔA = arrayify(A_ΔA) + return A', ΔA' +end diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index d663a3281..59a398e27 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -1,73 +1,72 @@ +# tensorcontract! 
+# --------------- Mooncake.@is_primitive( DefaultCtx, ReverseMode, Tuple{ - typeof(TO.tensorcontract!), + typeof(TensorKit.blas_contract!), AbstractTensorMap, - AbstractTensorMap, Index2Tuple, Bool, - AbstractTensorMap, Index2Tuple, Bool, + AbstractTensorMap, Index2Tuple, + AbstractTensorMap, Index2Tuple, Index2Tuple, Number, Number, - Vararg{Any}, + Any, Any, } ) function Mooncake.rrule!!( - ::CoDual{typeof(TO.tensorcontract!)}, + ::CoDual{typeof(TensorKit.blas_contract!)}, C_ΔC::CoDual{<:AbstractTensorMap}, - A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, - B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, conjB_ΔconjB::CoDual{Bool}, + A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, + B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, pAB_ΔpAB::CoDual{<:Index2Tuple}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, - ba_Δba::CoDual..., + backend_Δbackend::CoDual, allocator_Δallocator::CoDual ) # prepare arguments (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB)) - conjA, conjB = primal.((conjA_ΔconjA, conjB_ΔconjB)) α, β = primal.((α_Δα, β_Δβ)) - ba = primal.(ba_Δba) + backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) # primal call C_cache = copy(C) - TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) + TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator) - function tensorcontract_pullback(::NoRData) + function blas_contract_pullback(::NoRData) copy!(C, C_cache) - ΔCr = tensorcontract_pullback_ΔC!(ΔC, β) - ΔAr = tensorcontract_pullback_ΔA!( - ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + ΔAr = blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - ΔBr = tensorcontract_pullback_ΔB!( - ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... 
+ ΔBr = blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - Δαr = tensorcontract_pullback_Δα( - ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + Δαr = blas_contract_pullback_Δα( + ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - Δβr = tensorcontract_pullback_Δβ(ΔC, C, β) + Δβr = blas_contract_pullback_Δβ(ΔC, C, β) + ΔCr = blas_contract_pullback_ΔC!(ΔC, β) return NoRData(), ΔCr, - ΔAr, NoRData(), NoRData(), - ΔBr, NoRData(), NoRData(), + ΔAr, NoRData(), + ΔBr, NoRData(), NoRData(), Δαr, Δβr, - map(ba_ -> NoRData(), ba)... + NoRData(), NoRData() end - return C_ΔC, tensorcontract_pullback + return C_ΔC, blas_contract_pullback end -tensorcontract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) +blas_contract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) -function tensorcontract_pullback_ΔA!( - ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... +function blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) ipAB = invperm(linearize(pAB)) pΔC = _repartition(ipAB, TO.numout(pA)) ipA = _repartition(invperm(linearize(pA)), A) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB tB = twist( B, @@ -79,24 +78,22 @@ function tensorcontract_pullback_ΔA!( TO.tensorcontract!( ΔA, - ΔC, pΔC, conjΔC, - tB, reverse(pB), conjB′, + ΔC, pΔC, false, + tB, reverse(pB), true, ipA, - conjA ? α : conj(α), Zero(), - ba... + conj(α), Zero(), + backend, allocator ) return NoRData() end -function tensorcontract_pullback_ΔB!( - ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... +function blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) ipAB = invperm(linearize(pAB)) pΔC = _repartition(ipAB, TO.numout(pA)) ipB = _repartition(invperm(linearize(pB)), B) - conjΔC = conjB - conjA′ = conjB ? 
conjA : !conjA tA = twist( A, @@ -108,27 +105,114 @@ function tensorcontract_pullback_ΔB!( TO.tensorcontract!( ΔB, - tA, reverse(pA), conjA′, - ΔC, pΔC, conjΔC, + tA, reverse(pA), true, + ΔC, pΔC, false, ipB, - conjB ? α : conj(α), Zero(), ba... + conj(α), Zero(), backend, allocator ) return NoRData() end -function tensorcontract_pullback_Δα( - ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... +function blas_contract_pullback_Δα( + ΔC, A, pA, B, pB, pAB, α, backend, allocator ) Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) Tdα === NoRData && return NoRData() - AB = TO.tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator) Δα = inner(AB, ΔC) return Mooncake._rdata(Δα) end -function tensorcontract_pullback_Δβ(ΔC, C, β) +function blas_contract_pullback_Δβ(ΔC, C, β) + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Tdβ === NoRData && return NoRData() + + Δβ = inner(C, ΔC) + return Mooncake._rdata(Δβ) +end + +# tensortrace! 
+# ------------ +Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TensorKit.trace_permute!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, Index2Tuple, + Number, Number, + Any, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TensorKit.trace_permute!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, q_Δq::CoDual{<:Index2Tuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + backend_Δbackend::CoDual + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + q = primal(q_Δq) + α, β = primal.((α_Δα, β_Δβ)) + backend = primal(backend_Δbackend) + + # primal call + C_cache = copy(C) + TensorKit.trace_permute!(C, A, p, q, α, β, backend) + + function trace_permute_pullback(::NoRData) + copy!(C, C_cache) + + ΔAr = trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) + Δαr = trace_permute_pullback_Δα(ΔC, A, p, q, α, backend) + Δβr = trace_permute_pullback_Δβ(ΔC, C, β) + ΔCr = trace_permute_pullback_ΔC!(ΔC, β) + + return NoRData(), + ΔCr, ΔAr, NoRData(), NoRData(), + Δαr, Δβr, NoRData() + end + + return C_ΔC, trace_permute_pullback +end + +trace_permute_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) + +function trace_permute_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend + ) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TO.tensorproduct!( + ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend + ) + return NoRData() +end + +function trace_permute_pullback_Δα( + ΔC, A, p, q, α, backend + ) + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Tdα === NoRData && return NoRData() + + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = 
TO.tensortrace(A, p, q, false, One(), backend) + Δα = inner(At, ΔC) + return Mooncake._rdata(Δα) +end + +function trace_permute_pullback_Δβ(ΔC, C, β) Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) Tdβ === NoRData && return NoRData() diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index ca2c79b54..e93de22be 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -25,4 +25,15 @@ end # Ignore derivatives # ------------------ + +# A VectorSpace has no meaningful notion of a vector space (tangent space) +Mooncake.tangent_type(::Type{<:VectorSpace}) = Mooncake.NoTangent + @zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} + +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.braid), HomSpace, Index2Tuple, IndexTuple} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.compose), HomSpace, HomSpace} +@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract), HomSpace, Index2Tuple, Bool, HomSpace, Index2Tuple, Bool, Index2Tuple} diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl new file mode 100644 index 000000000..2c1bfe984 --- /dev/null +++ b/ext/TensorKitMooncakeExt/vectorinterface.jl @@ -0,0 +1,93 @@ +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, Number} + +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + α = primal(α_Δα) + + # primal call + C_cache = copy(C) + scale!(C, α) + + function scale_pullback(::NoRData) + copy!(C, C_cache) + scale!(ΔC, conj(α)) + TΔα = 
Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Δαr = TΔα === NoRData ? NoRData() : inner(C, ΔC) + return NoRData(), NoRData(), Δαr + end + + return C_ΔC, scale_pullback +end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} + +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α = primal(α_Δα) + + # primal call + C_cache = copy(C) + scale!(C, A, α) + + function scale_pullback(::NoRData) + copy!(C, C_cache) + zerovector!(ΔC) + scale!(ΔA, conj(α)) + TΔα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Δαr = TΔα === NoRData ? NoRData() : inner(C, ΔC) + return NoRData(), NoRData(), NoRData(), Δαr + end + + return C_ΔC, scale_pullback +end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} + +function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α = primal(α_Δα) + β = primal(β_Δβ) + + # primal call + C_cache = copy(C) + add!(C, A, α, β) + + function add_pullback(::NoRData) + copy!(C, C_cache) + scale!(ΔC, conj(β)) + scale!(ΔA, conj(α)) + + TΔα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Δαr = TΔα === NoRData ? NoRData() : inner(A, ΔC) + TΔβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Δβr = TΔβ === NoRData ? 
NoRData() : inner(C, ΔC) + + return NoRData(), NoRData(), NoRData(), Δαr, Δβr + end + + return C_ΔC, add_pullback +end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} + +function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}) + # prepare arguments + A, ΔA = arrayify(A_ΔA) + B, ΔB = arrayify(B_ΔB) + + # primal call + s = inner(A, B) + + function inner_pullback(Δs) + scale!(ΔA, B, conj(Δs)) + scale!(ΔB, A, Δs) + return NoRData(), NoRData(), NoRData() + end + + return CoDual(s, NoFData()), inner_pullback +end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 1cd74fa27..e9f7d01d7 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -3,6 +3,7 @@ using TensorKit using TensorOperations using Mooncake using Random +using TupleTools mode = Mooncake.ReverseMode rng = Random.default_rng() @@ -13,6 +14,14 @@ function randindextuple(N::Int, k::Int = rand(0:N)) _p = randperm(N) return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...)) end +function randcircshift(N₁::Int, N₂::Int, k::Int = rand(0:(N₁ + N₂))) + N = N₁ + N₂ + @assert 0 ≤ k ≤ N + p = TupleTools.vcat(ntuple(identity, N₁), reverse(ntuple(identity, N₂) .+ N₁)) + n = rand(0:N) + _p = TupleTools.circshift(p, n) + return (tuple(_p[1:k]...), reverse(tuple(_p[(k + 1):end]...))) +end const _repartition = @static if isdefined(Base, :get_extension) Base.get_extension(TensorKit, :TensorKitMooncakeExt)._repartition @@ -50,13 +59,13 @@ spacelist = ( Vect[SU2Irrep](1 // 2 => 2), Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', ), - ( - Vect[FibonacciAnyon](:I => 2, :τ => 1), - Vect[FibonacciAnyon](:I => 1, :τ => 2)', - Vect[FibonacciAnyon](:I => 2, :τ => 2)', - Vect[FibonacciAnyon](:I => 2, :τ => 3), - Vect[FibonacciAnyon](:I => 2, :τ => 2), - ), + # ( + # Vect[FibonacciAnyon](:I => 2, :τ => 1), + # Vect[FibonacciAnyon](:I => 1, :τ => 2)', + # Vect[FibonacciAnyon](:I 
=> 2, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 3), + # Vect[FibonacciAnyon](:I => 2, :τ => 2), + # ), ) for V in spacelist @@ -68,6 +77,139 @@ for V in spacelist println("Mooncake with symmetry: $Istr") println("---------------------------------------") eltypes = (Float64,) # no complex support yet + + @timedtestset "VectorInterface with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) + + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, mode, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, mode, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode) + end + + @timedtestset "LinearAlgebra with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) + + C = randn(T, V[1] ⊗ V[2] ← V[5]) + A = randn(T, codomain(C) ← V[3] ⊗ V[4]) + B = randn(T, domain(A) ← domain(C)) + α = randn(T) + β = randn(T) + + Mooncake.TestUtils.test_rule(rng, mul!, C, A, B, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, mul!, C, A, B; atol, rtol, mode, is_primitive = false) + + Mooncake.TestUtils.test_rule(rng, norm, C, 2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, norm, C', 2; atol, rtol, mode) + + D1 = randn(T, V[1] ← V[1]) + D2 = 
randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + + Mooncake.TestUtils.test_rule(rng, tr, D1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tr, D2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tr, D3; atol, rtol, mode) + end + + + @timedtestset "Index manipulations with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) + + symmetricbraiding && @timedtestset "add_permute!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randindextuple(numind(A)) + C = randn!(permute(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_permute!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "add_transpose!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randcircshift(numout(A), numin(A)) + C = randn!(transpose(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "add_braid!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randcircshift(numout(A), numin(A)) + levels = tuple(randperm(numind(A))...) + C = randn!(braid(A, p, levels)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_braid!, C, A, p, levels, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "flip_n_twist!"
begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), twist!, A, [1, 3]; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, [1, 3]; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), flip, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), flip, A, [1, 3]; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, flip, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, flip, A, [1, 3]; atol, rtol, mode) + end + + @timedtestset "insert and remove units" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + + for insertunit in (insertleftunit, insertrightunit) + Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(1); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(4); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, insertunit, A', Val(2); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), insertunit, A, Val(1); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), insertunit, A, Val(2); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false, dual = true, conj = true), insertunit, A, Val(3); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false, dual = true, conj = true), insertunit, A', Val(3); atol, rtol, mode) + end + + for i in 1:4 + B = insertleftunit(A, i; dual = rand(Bool)) + Mooncake.TestUtils.test_rule(rng, removeunit, B, Val(i); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), removeunit, B, Val(i); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), removeunit, B, Val(i); atol, rtol, mode) + end + end + end + symmetricbraiding && 
@timedtestset "TensorOperations with scalartype $T" for T in eltypes atol = precision(T) rtol = precision(T) @@ -97,20 +239,115 @@ for V in spacelist β = randn(T) V2_conj = prod(conj, V2; init = one(V[1])) - for conjA in (false, true), conjB in (false, true) - A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA)) - B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB)) - C = randn!( - TensorOperations.tensoralloc_contract( - T, A, pA, conjA, B, pB, conjB, pAB, Val(false) - ) - ) - Mooncake.TestUtils.test_rule( - rng, tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β; - atol, rtol, mode, is_primitive + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) ) + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, B, pB, pAB, α, β, + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) + end + end + + @timedtestset "trace_permute!" begin + for _ in 1:5 + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + Mooncake.TestUtils.test_rule( + rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend(); + atol, rtol, mode + ) + end + end + end + + @timedtestset "PlanarOperations with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) + + @timedtestset "planarcontract!" 
begin + for _ in 1:5 + d = 0 + local V1, V2, V3, k1, k2, k3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 1 && break end + k′ = rand(0:(k1 + k2)) + pA = randcircshift(k′, k1 + k2 - k′, k1) + ipA = _repartition(invperm(linearize(pA)), k′) + k′ = rand(0:(k2 + k3)) + pB = randcircshift(k′, k2 + k3 - k′, k2) + ipB = _repartition(invperm(linearize(pB)), k′) + # TODO: primal value already is broken for this? + # pAB = randcircshift(k1, k3) + pAB = _repartition(tuple((1:(k1 + k3))...), k1) + + α = randn(T) + β = randn(T) + + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, α, β; + atol, rtol, mode, is_primitive = false + ) + end + end + + @timedtestset "planartrace!" begin + for _ in 1:5 + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + + k′ = rand(0:(k1 + 2k2)) + (_p, _q) = randcircshift(k′, k1 + 2 * k2 - k′, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), k′) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planartrace!, + C, A, p, q, α, β, + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) end end end