diff --git a/Project.toml b/Project.toml
index c05ec0c..d57dce1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -24,12 +24,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

 [extensions]
 TensorOperationsBumperExt = "Bumper"
 TensorOperationsChainRulesCoreExt = "ChainRulesCore"
 TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
 TensorOperationsMooncakeExt = "Mooncake"
+TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"]

 [compat]
 Aqua = "0.6, 0.7, 0.8"
@@ -38,6 +40,8 @@ CUDA = "5"
 ChainRulesCore = "1"
 ChainRulesTestUtils = "1"
 DynamicPolynomials = "0.5, 0.6"
+Enzyme = "0.13.115"
+EnzymeTestUtils = "0.2"
 LRUCache = "1"
 LinearAlgebra = "1.6"
 Logging = "1.6"
@@ -48,7 +52,7 @@ Preferences = "1.4"
 PtrArrays = "1.2"
 Random = "1"
 Strided = "2.2"
-StridedViews = "0.3, 0.4"
+StridedViews = "=0.4.1"
 Test = "1"
 TupleTools = "1.6"
 VectorInterface = "0.4.1,0.5"
@@ -59,8 +63,11 @@ julia = "1.10"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -68,4 +75,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

 [targets]
-test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"]
+test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "ChainRulesCore", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"]
diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl
index 539d9d9..8338468 100644
--- a/ext/TensorOperationsChainRulesCoreExt.jl
+++ b/ext/TensorOperationsChainRulesCoreExt.jl
@@ -1,7 +1,7 @@
 module TensorOperationsChainRulesCoreExt

 using TensorOperations
-using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent
+using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent, trivtuple
 using TensorOperations: DefaultBackend, DefaultAllocator, _kron
 using ChainRulesCore
 using TupleTools
@@ -9,8 +9,6 @@ using VectorInterface
 using TupleTools: invperm
 using LinearAlgebra

-trivtuple(N) = ntuple(identity, N)
-
 @non_differentiable TensorOperations.tensorstructure(args...)
 @non_differentiable TensorOperations.tensoradd_structure(args...)
 @non_differentiable TensorOperations.tensoradd_type(args...)
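Note: `trivtuple` disappears from this extension (and from the Mooncake extension below) because it is promoted to a shared helper in `src/utils.jl`; see that hunk further down. It simply builds the identity permutation of a given length:

    trivtuple(N) = ntuple(identity, N)   # as defined in src/utils.jl below
    trivtuple(4)                         # == (1, 2, 3, 4)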
diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl
new file mode 100644
index 0000000..18fb7b0
--- /dev/null
+++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl
@@ -0,0 +1,197 @@
+module TensorOperationsEnzymeExt
+
+using TensorOperations
+using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
+using VectorInterface
+using TupleTools
+using Enzyme, ChainRulesCore
+using Enzyme.EnzymeCore
+using Enzyme.EnzymeCore: EnzymeRules
+
+@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any) = true
+Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any)
+
+@inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true
+@inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true
+@inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true
+@inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true
+@inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true
+
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(TensorOperations.tensorcontract!)},
+        ::Type{RT},
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        pA_dpA::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        B_dB::Annotation{<:AbstractArray{TB}},
+        pB_dpB::Const{<:Index2Tuple},
+        conjB_dconjB::Const{Bool},
+        pAB_dpAB::Const{<:Index2Tuple},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
+    # form caches if needed
+    cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
+    cache_B = !isa(B_dB, Const) && EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing
+    cache_C = copy(C_dC.val) # do we need to do this, if we don't need the primal?
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...)
+    primal = if EnzymeRules.needs_primal(config)
+        C_dC.val
+    else
+        nothing
+    end
+    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
+    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_C))
+end
+
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(TensorOperations.tensorcontract!)},
+        ::Type{RT},
+        cache,
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        pA_dpA::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        B_dB::Annotation{<:AbstractArray{TB}},
+        pB_dpB::Const{<:Index2Tuple},
+        conjB_dconjB::Const{Bool},
+        pAB_dpAB::Const{<:Index2Tuple},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
+    cache_A, cache_B, cache_C = cache
+    Aval = something(cache_A, A_dA.val)
+    Bval = something(cache_B, B_dB.val)
+    Cval = cache_C
+    dC = C_dC.dval
+    dA = A_dA.dval
+    dB = B_dB.dval
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    α = α_dα.val
+    β = β_dβ.val
+    dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, Bval, α, β, pA_dpA.val, pB_dpB.val, pAB_dpAB.val, conjA_dconjA.val, conjB_dconjB.val, ba...)
+    return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
+end
+
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        ::Annotation{typeof(tensoradd!)},
+        ::Type{RT},
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        pA_dpA::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
+    # form caches if needed
+    cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
+    cache_C = copy(C_dC.val)
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    α = α_dα.val
+    β = β_dβ.val
+    conjA = conjA_dconjA.val
+    TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...)
+    primal = if EnzymeRules.needs_primal(config)
+        C_dC.val
+    else
+        nothing
+    end
+    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
+    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
+end
+
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        ::Annotation{typeof(tensoradd!)},
+        ::Type{RT},
+        cache,
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        pA_dpA::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
+    cache_A, cache_C = cache
+    Aval = something(cache_A, A_dA.val)
+    Cval = cache_C
+    pA = pA_dpA.val
+    conjA = conjA_dconjA.val
+    α = α_dα.val
+    β = β_dβ.val
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    dC = C_dC.dval
+    dA = A_dA.dval
+    dC, dA, dα, dβ = TensorOperations.tensoradd_pullback!(dC, dA, Cval, Aval, α, β, pA, conjA, ba...)
+    return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
+end
+
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        ::Annotation{typeof(tensortrace!)},
+        ::Type{RT},
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        p_dp::Const{<:Index2Tuple},
+        q_dq::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
+    # form caches if needed
+    cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
+    cache_C = copy(C_dC.val)
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    α = α_dα.val
+    β = β_dβ.val
+    conjA = conjA_dconjA.val
+    TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...)
+    primal = if EnzymeRules.needs_primal(config)
+        C_dC.val
+    else
+        nothing
+    end
+    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
+    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
+end
+
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        ::Annotation{typeof(tensortrace!)},
+        ::Type{RT},
+        cache,
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        p_dp::Const{<:Index2Tuple},
+        q_dq::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
+    cache_A, cache_C = cache
+    Aval = something(cache_A, A_dA.val)
+    Cval = cache_C
+    p = p_dp.val
+    q = q_dq.val
+    conjA = conjA_dconjA.val
+    α = α_dα.val
+    β = β_dβ.val
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    dC = C_dC.dval
+    dA = A_dA.dval
+    dC, dA, dα, dβ = TensorOperations.tensortrace_pullback!(dC, dA, Cval, Aval, α, β, p, q, conjA, ba...)
+    return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
+end
+
+end
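A minimal usage sketch of the rules above (the shapes, index tuples, and the wrapper `f!` are illustrative, not taken from this PR): reverse-mode through an in-place matrix product C = A * B written as `tensorcontract!` accumulates the usual adjoints into the shadow arrays.

    using TensorOperations, Enzyme, VectorInterface

    # C[i, j] = Σₖ A[i, k] * B[k, j]; return nothing so the output activity is Const
    f!(C, A, B) = (tensorcontract!(C, A, ((1,), (2,)), false, B, ((1,), (2,)), false,
                                   ((1, 2), ()), One(), Zero()); nothing)

    A, B, C = rand(2, 3), rand(3, 2), zeros(2, 2)
    dA, dB = zero(A), zero(B)
    dC = ones(2, 2)   # cotangent seed, e.g. from an objective sum(C)

    autodiff(Reverse, f!, Const, Duplicated(C, dC), Duplicated(A, dA), Duplicated(B, dB))
    # afterwards dA ≈ ones(2, 2) * B' and dB ≈ A' * ones(2, 2); dC is zeroed since β = Zero()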
diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl
index 49d1708..9450d77 100644
--- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl
+++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl
@@ -6,7 +6,7 @@ using TensorOperations
 # extension are in fact loaded
 using Mooncake, Mooncake.CRC
 using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
-using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!, _kron, numind, _needs_tangent, numin, numout
+using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!
 using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent
 using VectorInterface, TupleTools

@@ -16,8 +16,6 @@ Mooncake.tangent_type(::Type{DefaultAllocator}) = Mooncake.NoTangent
 Mooncake.tangent_type(::Type{CUDAAllocator}) = Mooncake.NoTangent
 Mooncake.tangent_type(::Type{ManualAllocator}) = Mooncake.NoTangent

-trivtuple(N) = ntuple(identity, N)
-
 @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), Any}
 @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_structure), Any}
 @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_type), Any}
@@ -61,69 +59,9 @@ function Mooncake.rrule!!(
     TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
     function contract_pb(::NoRData)
         scale!(C, C_cache, One())
-        if Tα == Zero && Tβ == Zero
-            scale!(dC, zero(TC))
-            return ntuple(i -> NoRData(), 11 + length(ba))
-        end
-        ipAB = invperm(linearize(pAB))
-        pdC = (
-            TupleTools.getindices(ipAB, trivtuple(numout(pA))),
-            TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
-        )
-        ipA = (invperm(linearize(pA)), ())
-        ipB = (invperm(linearize(pB)), ())
-        conjΔC = conjA
-        conjB′ = conjA ? conjB : !conjB
-        dA = tensorcontract!(
-            dA,
-            dC, pdC, conjΔC,
-            B, reverse(pB), conjB′,
-            ipA,
-            conjA ? α : conj(α), One(), ba...
-        )
-        conjΔC = conjB
-        conjA′ = conjB ? conjA : !conjA
-        dB = tensorcontract!(
-            dB,
-            A, reverse(pA), conjA′,
-            dC, pdC, conjΔC,
-            ipB,
-            conjB ? α : conj(α), One(), ba...
-        )
-        dα = if _needs_tangent(Tα)
-            C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
-            # TODO: consider using `inner`
-            Mooncake._rdata(
-                tensorscalar(
-                    tensorcontract(
-                        C_αβ, ((), trivtuple(numind(pAB))), true,
-                        dC, (trivtuple(numind(pAB)), ()), false,
-                        ((), ()), One(), ba...
-                    )
-                )
-            )
-        else
-            NoRData()
-        end
-        dβ = if _needs_tangent(Tβ)
-            # TODO: consider using `inner`
-            Mooncake._rdata(
-                tensorscalar(
-                    tensorcontract(
-                        C, ((), trivtuple(numind(pAB))), true,
-                        dC, (trivtuple(numind(pAB)), ()), false,
-                        ((), ()), One(), ba...
-                    )
-                )
-            )
-        else
-            NoRData()
-        end
-        if β === Zero()
-            scale!(dC, β)
-        else
-            scale!(dC, conj(β))
-        end
+        dC, dA, dB, Δα, Δβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, C, A, B, α, β, pA, pB, pAB, conjA, conjB, ba...)
+        dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα)
+        dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ)
         return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
     end
     return C_dC, contract_pb
@@ -151,35 +89,9 @@ function Mooncake.rrule!!(
     TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...)
     function add_pb(::NoRData)
         scale!(C, C_cache, One())
-        ipA = invperm(linearize(pA))
-        dA = tensoradd!(dA, dC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...)
-        dα = if _needs_tangent(Tα)
-            tensorscalar(
-                tensorcontract(
-                    A, ((), linearize(pA)), !conjA,
-                    dC, (trivtuple(numind(pA)), ()), false,
-                    ((), ()), One(), ba...
-                )
-            )
-        else
-            Mooncake.NoRData()
-        end
-        dβ = if _needs_tangent(Tβ)
-            tensorscalar(
-                tensorcontract(
-                    C, ((), trivtuple(numind(pA))), true,
-                    dC, (trivtuple(numind(pA)), ()), false,
-                    ((), ()), One(), ba...
-                )
-            )
-        else
-            Mooncake.NoRData()
-        end
-        if β === Zero()
-            scale!(dC, β)
-        else
-            scale!(dC, conj(β))
-        end
+        dC, dA, Δα, Δβ = TensorOperations.tensoradd_pullback!(dC, dA, C, A, α, β, pA, conjA, ba...)
+        dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα)
+        dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ)
         return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
     end
     return C_dC, add_pb
@@ -209,54 +121,9 @@ function Mooncake.rrule!!(
     TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...)
     function trace_pb(::NoRData)
         scale!(C, C_cache, One())
-        ip = invperm((linearize(p)..., q[1]..., q[2]...))
-        Es = map(q[1], q[2]) do i1, i2
-            one(
-                TensorOperations.tensoralloc_add(
-                    TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA
-                )
-            )
-        end
-        E = _kron(Es, ba)
-        dA = tensorproduct!(
-            dA, dC, (trivtuple(numind(p)), ()), conjA,
-            E, ((), trivtuple(numind(q))), conjA,
-            (ip, ()),
-            conjA ? α : conj(α), One(), ba...
-        )
-        C_αβ = tensortrace(A, p, q, false, One(), ba...)
-        dα = if _needs_tangent(Tα)
-            Mooncake._rdata(
-                tensorscalar(
-                    tensorcontract(
-                        C_αβ, ((), trivtuple(numind(p))),
-                        !conjA,
-                        dC, (trivtuple(numind(p)), ()), false,
-                        ((), ()), One(), ba...
-                    )
-                )
-            )
-        else
-            NoRData()
-        end
-        dβ = if _needs_tangent(Tβ)
-            Mooncake._rdata(
-                tensorscalar(
-                    tensorcontract(
-                        C, ((), trivtuple(numind(p))), true,
-                        dC, (trivtuple(numind(p)), ()), false,
-                        ((), ()), One(), ba...
-                    )
-                )
-            )
-        else
-            NoRData()
-        end
-        if β === Zero()
-            scale!(dC, β)
-        else
-            scale!(dC, conj(β))
-        end
+        dC, dA, Δα, Δβ = TensorOperations.tensortrace_pullback!(dC, dA, C, A, α, β, p, q, conjA, ba...)
+        dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα)
+        dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ)
         return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
     end
     return C_dC, trace_pb
diff --git a/src/TensorOperations.jl b/src/TensorOperations.jl
index 23f205d..05ae808 100644
--- a/src/TensorOperations.jl
+++ b/src/TensorOperations.jl
@@ -36,6 +36,12 @@ include("backends.jl")
 include("interface.jl")
 include("utils.jl")

+# Generic pullbacks for AD
+#---------------------------
+include("pullbacks/add.jl")
+include("pullbacks/trace.jl")
+include("pullbacks/contract.jl")
+
 # Index notation via macros
 #---------------------------
 @nospecialize
diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl
new file mode 100644
index 0000000..bf34e3f
--- /dev/null
+++ b/src/pullbacks/add.jl
@@ -0,0 +1,32 @@
+function tensoradd_pullback!(ΔC, ΔA, C, A, α, β, pA, conjA::Bool, ba...)
+    ipA = invperm(linearize(pA))
+    ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA
+    tensoradd!(ΔAc, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...)
+    if eltype(ΔC) <: Complex && eltype(ΔA) <: Real
+        ΔA .+= real.(ΔAc)
+    end
+    Δα = if _needs_tangent(α)
+        tensorscalar(
+            tensorcontract(
+                A, ((), linearize(pA)), !conjA,
+                ΔC, (trivtuple(numind(pA)), ()), false,
+                ((), ()), One(), ba...
+            )
+        )
+    else
+        nothing
+    end
+    Δβ = if _needs_tangent(β)
+        tensorscalar(
+            tensorcontract(
+                C, ((), trivtuple(numind(pA))), true,
+                ΔC, (trivtuple(numind(pA)), ()), false,
+                ((), ()), One(), ba...
+            )
+        )
+    else
+        nothing
+    end
+    scale!(ΔC, conj(β))
+    return ΔC, ΔA, Δα, Δβ
+end
diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl
new file mode 100644
index 0000000..2bd42ac
--- /dev/null
+++ b/src/pullbacks/contract.jl
@@ -0,0 +1,62 @@
+function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, B, α, β, pA, pB, pAB, conjA::Bool, conjB::Bool, ba...)
+    ipAB = invperm(linearize(pAB))
+    pdC = (
+        TupleTools.getindices(ipAB, trivtuple(numout(pA))),
+        TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
+    )
+    ipA = (invperm(linearize(pA)), ())
+    ipB = (invperm(linearize(pB)), ())
+    conjΔC = conjA
+    conjB′ = conjA ? conjB : !conjB
+    ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA
+    tensorcontract!(
+        ΔAc,
+        ΔC, pdC, conjΔC,
+        B, reverse(pB), conjB′,
+        ipA,
+        conjA ? α : conj(α), One(), ba...
+    )
+    if eltype(ΔC) <: Complex && eltype(ΔA) <: Real
+        ΔA .+= real.(ΔAc)
+    end
+    conjΔC = conjB
+    conjA′ = conjB ? conjA : !conjA
+    ΔBc = eltype(ΔC) <: Complex && eltype(ΔB) <: Real ? zerovector(B, VectorInterface.promote_add(ΔC, α)) : ΔB
+    tensorcontract!(
+        ΔBc,
+        A, reverse(pA), conjA′,
+        ΔC, pdC, conjΔC,
+        ipB,
+        conjB ? α : conj(α), One(), ba...
+    )
+    if eltype(ΔC) <: Complex && eltype(ΔB) <: Real
+        ΔB .+= real.(ΔBc)
+    end
+    Δα = if _needs_tangent(α)
+        C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
+        # TODO: consider using `inner`
+        tensorscalar(
+            tensorcontract(
+                C_αβ, ((), trivtuple(numind(pAB))), true,
+                ΔC, (trivtuple(numind(pAB)), ()), false,
+                ((), ()), One(), ba...
+            )
+        )
+    else
+        nothing
+    end
+    Δβ = if _needs_tangent(β)
+        # TODO: consider using `inner`
+        tensorscalar(
+            tensorcontract(
+                C, ((), trivtuple(numind(pAB))), true,
+                ΔC, (trivtuple(numind(pAB)), ()), false,
+                ((), ()), One(), ba...
+            )
+        )
+    else
+        nothing
+    end
+    scale!(ΔC, conj(β))
+    return ΔC, ΔA, ΔB, Δα, Δβ
+end
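For orientation, `tensorcontract_pullback!` above implements the standard adjoints of C ← α (A ⋆ B) + β C, written here with permutations and conjugation flags elided (the `real.(ΔAc)` branches project the cotangent back onto the reals when a real input meets a complex cotangent):

    \Delta A \mathrel{+}= \bar{\alpha}\, \Delta C\, B^{\dagger}, \qquad
    \Delta B \mathrel{+}= \bar{\alpha}\, A^{\dagger}\, \Delta C, \qquad
    \Delta \alpha = \langle A B,\, \Delta C \rangle, \qquad
    \Delta \beta = \langle C_{\mathrm{old}},\, \Delta C \rangle, \qquad
    \Delta C \leftarrow \bar{\beta}\, \Delta C

Here ⟨·,·⟩ is the Frobenius inner product, conjugate-linear in the first slot; Δα and Δβ are only formed when `_needs_tangent` holds. `tensoradd_pullback!` above and `tensortrace_pullback!` (next file) follow the same pattern with A ⋆ B replaced by the permuted or partially traced A.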
diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl
new file mode 100644
index 0000000..5c08033
--- /dev/null
+++ b/src/pullbacks/trace.jl
@@ -0,0 +1,47 @@
+function tensortrace_pullback!(ΔC, ΔA, C, A, α, β, p, q, conjA, ba...)
+    ip = invperm((linearize(p)..., q[1]..., q[2]...))
+    Es = map(q[1], q[2]) do i1, i2
+        one(
+            TensorOperations.tensoralloc_add(
+                TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA
+            )
+        )
+    end
+    E = _kron(Es, ba)
+    ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA
+    tensorproduct!(
+        ΔAc, ΔC, (trivtuple(numind(p)), ()), conjA,
+        E, ((), trivtuple(numind(q))), conjA,
+        (ip, ()),
+        conjA ? α : conj(α), One(), ba...
+    )
+    if eltype(ΔC) <: Complex && eltype(ΔA) <: Real
+        ΔA .+= real.(ΔAc)
+    end
+    C_αβ = tensortrace(A, p, q, false, One(), ba...)
+    Δα = if _needs_tangent(α)
+        tensorscalar(
+            tensorcontract(
+                C_αβ, ((), trivtuple(numind(p))),
+                !conjA,
+                ΔC, (trivtuple(numind(p)), ()), false,
+                ((), ()), One(), ba...
+            )
+        )
+    else
+        nothing
+    end
+    Δβ = if _needs_tangent(β)
+        tensorscalar(
+            tensorcontract(
+                C, ((), trivtuple(numind(p))), true,
+                ΔC, (trivtuple(numind(p)), ()), false,
+                ((), ()), One(), ba...
+            )
+        )
+    else
+        nothing
+    end
+    scale!(ΔC, conj(β))
+    return ΔC, ΔA, Δα, Δβ
+end
diff --git a/src/utils.jl b/src/utils.jl
index 20e406b..68b98c8 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -13,3 +13,5 @@ _needs_tangent(x) = _needs_tangent(typeof(x))
 _needs_tangent(::Type{<:Number}) = true
 _needs_tangent(::Type{<:Integer}) = false
 _needs_tangent(::Type{<:Union{One, Zero}}) = false
+
+trivtuple(N) = ntuple(identity, N)
diff --git a/test/enzyme.jl b/test/enzyme.jl
new file mode 100644
index 0000000..2852b5d
--- /dev/null
+++ b/test/enzyme.jl
@@ -0,0 +1,93 @@
+using TensorOperations, VectorInterface
+using Enzyme, ChainRulesCore, EnzymeTestUtils
+
+@testset "tensorcontract! ($T₁, $T₂)" for (T₁, T₂) in
+    (
+        (Float64, Float64),
+        (Float32, Float64),
+        (ComplexF64, ComplexF64),
+        (Float64, ComplexF64),
+        (ComplexF64, Float64),
+    )
+    T = promote_type(T₁, T₂)
+    atol = max(precision(T₁), precision(T₂))
+    rtol = max(precision(T₁), precision(T₂))
+
+    pAB = ((3, 2, 4, 1), ())
+    pA = ((2, 4, 5), (1, 3))
+    pB = ((2, 1), (3,))
+
+    A = rand(T₁, (2, 3, 4, 2, 5))
+    B = rand(T₂, (4, 2, 3))
+    C = rand(T, (5, 2, 3, 3))
+    @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T)))
+        Tα = α === Zero() ? Const : Active
+        Tβ = β === Zero() ? Const : Active
+        test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
+        test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
+        test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
+
+        test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
+        test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
+    end
+end
+
+@testset "tensoradd! ($T₁, $T₂)" for (T₁, T₂) in (
+        (Float64, Float64),
+        (Float32, Float64),
+        (ComplexF64, ComplexF64),
+        (Float64, ComplexF64),
+    )
+    T = promote_type(T₁, T₂)
+    atol = max(precision(T₁), precision(T₂))
+    rtol = max(precision(T₁), precision(T₂))
+
+    pA = ((2, 1, 4, 3, 5), ())
+    A = rand(T₁, (2, 3, 4, 2, 1))
+    C = rand(T₂, size.(Ref(A), pA[1]))
+    @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T)))
+        Tα = α === Zero() ? Const : Active
+        Tβ = β === Zero() ? Const : Active
+        test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
+        test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)
+
+        test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
+        test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
+    end
+end
+
+@testset "tensortrace! ($T₁, $T₂)" for (T₁, T₂) in
+    (
+        (Float64, Float64),
+        (Float32, Float64),
+        (ComplexF64, ComplexF64),
+        (Float64, ComplexF64),
+    )
+    T = promote_type(T₁, T₂)
+    atol = max(precision(T₁), precision(T₂))
+    rtol = max(precision(T₁), precision(T₂))
+
+    p = ((3, 5, 2), ())
+    q = ((1,), (4,))
+    A = rand(T₁, (2, 3, 4, 2, 5))
+    C = rand(T₂, size.(Ref(A), p[1]))
+    @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T)))
+        Tα = α === Zero() ? Const : Active
+        Tβ = β === Zero() ? Const : Active
+
+        test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
+        test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)
+
+        test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
+        test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
+    end
+end
+
+@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64)
+    atol = precision(T)
+    rtol = precision(T)
+
+    C = Array{T, 0}(undef, ())
+    fill!(C, rand(T))
+    test_reverse(tensorscalar, Active, (C, Duplicated); atol, rtol)
+end
diff --git a/test/mooncake.jl b/test/mooncake.jl
index 729a74d..8db47e9 100644
--- a/test/mooncake.jl
+++ b/test/mooncake.jl
@@ -14,6 +14,7 @@ is_primitive = false
         (Float32, Float64),
         #(ComplexF64, ComplexF64),
         #(Float64, ComplexF64),
+        #(ComplexF64, Float64),
     )
     T = promote_type(T₁, T₂)
     atol = max(precision(T₁), precision(T₂))
diff --git a/test/runtests.jl b/test/runtests.jl
index f67fbe6..dd594b1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,7 +15,6 @@ precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-8
 # specific ones
 is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
 if !is_buildkite
-
     @testset "tensoropt" verbose = true begin
         include("tensoropt.jl")
     end
@@ -37,6 +36,9 @@ if !is_buildkite
     @testset "mooncake" verbose = false begin
         include("mooncake.jl")
     end
+    @testset "enzyme" verbose = false begin
+        include("enzyme.jl")
+    end
 end

 if is_buildkite