From a99bf0c36fad28eacb3c00ed86f5236efce9bb4b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 14 Jan 2026 15:32:57 +0100 Subject: [PATCH 1/9] Add Enzyme rules --- Project.toml | 7 +- .../TensorOperationsEnzymeExt.jl | 347 ++++++++++++++++++ test/enzyme.jl | 92 +++++ test/runtests.jl | 4 +- 4 files changed, 448 insertions(+), 2 deletions(-) create mode 100644 ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl create mode 100644 test/enzyme.jl diff --git a/Project.toml b/Project.toml index c05ec0c..52f2d97 100644 --- a/Project.toml +++ b/Project.toml @@ -24,12 +24,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [extensions] TensorOperationsBumperExt = "Bumper" TensorOperationsChainRulesCoreExt = "ChainRulesCore" TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"] TensorOperationsMooncakeExt = "Mooncake" +TensorOperationsEnzymeExt = "Enzyme" [compat] Aqua = "0.6, 0.7, 0.8" @@ -38,6 +40,7 @@ CUDA = "5" ChainRulesCore = "1" ChainRulesTestUtils = "1" DynamicPolynomials = "0.5, 0.6" +Enzyme = "0.13.115" LRUCache = "1" LinearAlgebra = "1.6" Logging = "1.6" @@ -61,6 +64,8 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -68,4 +73,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] -test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"] +test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"] diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl new file mode 100644 index 0000000..b681e88 --- /dev/null +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -0,0 +1,347 @@ +module TensorOperationsEnzymeExt + +using TensorOperations +using TensorOperations: numind, numin, numout, promote_contract +using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator +using VectorInterface +using TupleTools +using Enzyme +using Enzyme.EnzymeCore +using Enzyme.EnzymeCore: EnzymeRules + +trivtuple(N) = ntuple(identity, N) + +# To avoid computing rrules for α and β when these aren't needed, we want to have a +# type-stable quick bail-out +_needs_tangent(x) = _needs_tangent(typeof(x)) +_needs_tangent(::Type{<:Number}) = true +_needs_tangent(::Type{<:Integer}) = false +_needs_tangent(::Type{<:Union{One, Zero}}) = false + +_kron(Es::NTuple{1}, ba) = Es[1] +function _kron(Es::NTuple{N, Any}, ba) where {N} + E1 = Es[1] + E2 = _kron(Base.tail(Es), ba) + p2 = ((), trivtuple(2 * N - 2)) + p = ((1, (2 .+ trivtuple(N - 1))...), (2, ((N + 1) .+ trivtuple(N - 1))...)) + return tensorproduct(p, E1, ((1, 2), ()), false, E2, p2, false, One(), ba...) +end + +@inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true +@inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true +@inline EnzymeRules.inactive_type(v::Type{CUDAAllocator}) = true +@inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true +@inline EnzymeRules.inactive_type(v::Type{Index2Tuple}) = true + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorOperations.tensorcontract!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + B_dB::Annotation{<:AbstractArray{TB}}, + pB_dpB::Const{<:Index2Tuple}, + conjB_dconjB::Const{Bool}, + pAB_dpAB::Const{<:Index2Tuple}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + # form caches if needed + cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing + cache_B = !isa(B_dB, Const) && EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing + cache_C = copy(C_dC.val) + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + primal = if EnzymeRules.needs_primal(config) + TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) + C_dC.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + C_dC.dval + else + nothing + end + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_C)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorOperations.tensorcontract!)}, + ::Type{RT}, + cache, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + B_dB::Annotation{<:AbstractArray{TB}}, + pB_dpB::Const{<:Index2Tuple}, + conjB_dconjB::Const{Bool}, + pAB_dpAB::Const{<:Index2Tuple}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + cache_A, cache_B, cache_C = cache + Aval = something(cache_A, A_dA.val) + Bval = something(cache_B, B_dB.val) + Cval = cache_C + #=if Tα == Zero && Tβ == Zero + scale!(C_dC.dval, zero(TC)) + return ntuple(i -> nothing, 10 + length(ba_dba)) + end=# + ipAB = invperm(linearize(pAB_dpAB.val)) + pdC = ( + TupleTools.getindices(ipAB, trivtuple(numout(pA_dpA.val))), + TupleTools.getindices(ipAB, numout(pA_dpA.val) .+ trivtuple(numin(pB_dpB.val))), + ) + ipA = (invperm(linearize(pA_dpA.val)), ()) + ipB = (invperm(linearize(pB_dpB.val)), ()) + conjA = conjA_dconjA.val + conjB = conjB_dconjB.val + conjΔC = conjA + conjB′ = conjA ? conjB : !conjB + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + α = α_dα.val + β = β_dβ.val + tensorcontract!( + A_dA.dval, + C_dC.dval, pdC, conjΔC, + Bval, reverse(pB_dpB.val), conjB′, + ipA, + conjA ? α : conj(α), One(), ba... + ) + conjΔC = conjB + conjA′ = conjB ? conjA : !conjA + tensorcontract!( + B_dB.dval, + Aval, reverse(pA_dpA.val), conjA′, + C_dC.dval, pdC, conjΔC, + ipB, + conjB ? α : conj(α), One(), ba... + ) + dα = if !isa(α_dα, Const) && _needs_tangent(Tα) + C_αβ = tensorcontract(Aval, pA_dpA.val, conjA, Bval, pB_dpB.val, conjB, pAB_dpAB.val, One(), ba...) + # TODO: consider using `inner` + tensorscalar( + tensorcontract( + C_αβ, ((), trivtuple(numind(pAB_dpAB.val))), true, + C_dC.dval, (trivtuple(numind(pAB_dpAB.val)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + dβ = if !isa(β_dβ, Const) && _needs_tangent(Tβ) + # TODO: consider using `inner` + tensorscalar( + tensorcontract( + Cval, ((), trivtuple(numind(pAB_dpAB.val))), true, + C_dC.dval, (trivtuple(numind(pAB_dpAB.val)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + if β === Zero() + scale!(C_dC.dval, β) + else + scale!(C_dC.dval, conj(β)) + end + return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + ::Annotation{typeof(tensoradd!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + # form caches if needed + cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing + cache_C = copy(C_dC.val) + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + α = α_dα.val + β = β_dβ.val + conjA = conjA_dconjA.val + primal = if EnzymeRules.needs_primal(config) + TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...) + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + C_dC.dval + else + nothing + end + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + ::Annotation{typeof(tensoradd!)}, + ::Type{RT}, + cache, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + cache_A, cache_C = cache + Aval = something(cache_A, A_dA.val) + Cval = cache_C + pA = pA_dpA.val + ipA = invperm(linearize(pA)) + conjA = conjA_dconjA.val + α = α_dα.val + β = β_dβ.val + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + tensoradd!(A_dA.dval, C_dC.dval, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) + dα = if !isa(α_dα, Const) && _needs_tangent(Tα) + tensorscalar( + tensorcontract( + Aval, ((), linearize(pA)), !conjA, + C_dC.dval, (trivtuple(numind(pA)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + dβ = if !isa(β_dβ, Const) && _needs_tangent(Tβ) + tensorscalar( + tensorcontract( + Cval, ((), trivtuple(numind(pA))), true, + C_dC.dval, (trivtuple(numind(pA)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + if β === Zero() + scale!(C_dC.dval, β) + else + scale!(C_dC.dval, conj(β)) + end + return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + ::Annotation{typeof(tensortrace!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + p_dp::Const{<:Index2Tuple}, + q_dq::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + # form caches if needed + cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing + cache_C = copy(C_dC.val) + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + α = α_dα.val + β = β_dβ.val + conjA = conjA_dconjA.val + primal = if EnzymeRules.needs_primal(config) + TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...) + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + C_dC.dval + else + nothing + end + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + ::Annotation{typeof(tensortrace!)}, + ::Type{RT}, + cache, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + p_dp::Const{<:Index2Tuple}, + q_dq::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + cache_A, cache_C = cache + Aval = something(cache_A, A_dA.val) + Cval = cache_C + p = p_dp.val + q = q_dq.val + conjA = conjA_dconjA.val + α = α_dα.val + β = β_dβ.val + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + Es = map(q[1], q[2]) do i1, i2 + one( + TensorOperations.tensoralloc_add( + TensorOperations.scalartype(Aval), Aval, ((i1,), (i2,)), conjA + ) + ) + end + E = _kron(Es, ba) + dA = tensorproduct!( + A_dA.dval, C_dC.dval, (trivtuple(numind(p)), ()), conjA, + E, ((), trivtuple(numind(q))), conjA, + (ip, ()), + conjA ? α : conj(α), One(), ba... + ) + C_αβ = tensortrace(Aval, p, q, false, One(), ba...) + dα = if !isa(α_dα, Const) && _needs_tangent(Tα) + tensorscalar( + tensorcontract( + C_αβ, ((), trivtuple(numind(p))), + !conjA, + C_dC.dval, (trivtuple(numind(p)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + dβ = if !isa(β_dβ, Const) && _needs_tangent(Tβ) + tensorscalar( + tensorcontract( + Cval, ((), trivtuple(numind(p))), true, + C_dC.dval, (trivtuple(numind(p)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + if β === Zero() + scale!(C_dC.dval, β) + else + scale!(C_dC.dval, conj(β)) + end + return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... +end + +end diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 0000000..f6ee2da --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,92 @@ +using TensorOperations, VectorInterface +using Enzyme, EnzymeTestUtils + +@testset "tensorcontract! ($T₁, $T₂)" for (T₁, T₂) in + ( + (Float64, Float64), + (Float32, Float64), + (ComplexF64, ComplexF64), + #(Float64, ComplexF64), + ) + T = promote_type(T₁, T₂) + atol = max(precision(T₁), precision(T₂)) + rtol = max(precision(T₁), precision(T₂)) + + pAB = ((3, 2, 4, 1), ()) + pA = ((2, 4, 5), (1, 3)) + pB = ((2, 1), (3,)) + + A = rand(T₁, (2, 3, 4, 2, 5)) + B = rand(T₂, (4, 2, 3)) + C = rand(T, (5, 2, 3, 3)) + @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T))) + Tα = α === Zero() ? Const : Active + Tβ = β === Zero() ? Const : Active + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + end +end + +@testset "tensoradd! ($T₁, $T₂)" for (T₁, T₂) in ( + (Float64, Float64), + (Float32, Float64), + (ComplexF64, ComplexF64), + #(Float64, ComplexF64), + ) + T = promote_type(T₁, T₂) + atol = max(precision(T₁), precision(T₂)) + rtol = max(precision(T₁), precision(T₂)) + + pA = ((2, 1, 4, 3, 5), ()) + A = rand(T₁, (2, 3, 4, 2, 1)) + C = rand(T₂, size.(Ref(A), pA[1])) + @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T))) + Tα = α === Zero() ? Const : Active + Tβ = β === Zero() ? Const : Active + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + end +end + +@testset "tensortrace! ($T₁, $T₂)" for (T₁, T₂) in + ( + (Float64, Float64), + (Float32, Float64), + (ComplexF64, ComplexF64), + #(Float64, ComplexF64), + ) + T = promote_type(T₁, T₂) + atol = max(precision(T₁), precision(T₂)) + rtol = max(precision(T₁), precision(T₂)) + + p = ((3, 5, 2), ()) + q = ((1,), (4,)) + A = rand(T₁, (2, 3, 4, 2, 5)) + C = rand(T₂, size.(Ref(A), p[1])) + @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T))) + Tα = α === Zero() ? Const : Active + Tβ = β === Zero() ? Const : Active + + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + end +end + +@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64) + atol = precision(T) + rtol = precision(T) + + C = Array{T, 0}(undef, ()) + fill!(C, rand(T)) + test_reverse(tensorscalar, Active, (C, Duplicated); atol, rtol) +end diff --git a/test/runtests.jl b/test/runtests.jl index f67fbe6..dd594b1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,6 @@ precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-8 # specific ones is_buildkite = get(ENV, "BUILDKITE", "false") == "true" if !is_buildkite - @testset "tensoropt" verbose = true begin include("tensoropt.jl") end @@ -37,6 +36,9 @@ if !is_buildkite @testset "mooncake" verbose = false begin include("mooncake.jl") end + @testset "enzyme" verbose = false begin + include("enzyme.jl") + end end if is_buildkite From 062c0ce786625536b24d74c0a54cf4ae376710cc Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 15 Jan 2026 15:00:56 +0100 Subject: [PATCH 2/9] Move pullback into own file --- .../TensorOperationsEnzymeExt.jl | 65 ++---------------- .../TensorOperationsMooncakeExt.jl | 66 +------------------ src/pullbacks/contract.jl | 63 ++++++++++++++++++ 3 files changed, 70 insertions(+), 124 deletions(-) create mode 100644 src/pullbacks/contract.jl diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index b681e88..01f4fed 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -89,70 +89,13 @@ function EnzymeRules.reverse( Aval = something(cache_A, A_dA.val) Bval = something(cache_B, B_dB.val) Cval = cache_C - #=if Tα == Zero && Tβ == Zero - scale!(C_dC.dval, zero(TC)) - return ntuple(i -> nothing, 10 + length(ba_dba)) - end=# - ipAB = invperm(linearize(pAB_dpAB.val)) - pdC = ( - TupleTools.getindices(ipAB, trivtuple(numout(pA_dpA.val))), - TupleTools.getindices(ipAB, numout(pA_dpA.val) .+ trivtuple(numin(pB_dpB.val))), - ) - ipA = (invperm(linearize(pA_dpA.val)), ()) - ipB = (invperm(linearize(pB_dpB.val)), ()) - conjA = conjA_dconjA.val - conjB = conjB_dconjB.val - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB + dC = C_dC.dval + dA = A_dA.dval + dB = B_dB.dval ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val - tensorcontract!( - A_dA.dval, - C_dC.dval, pdC, conjΔC, - Bval, reverse(pB_dpB.val), conjB′, - ipA, - conjA ? α : conj(α), One(), ba... - ) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA - tensorcontract!( - B_dB.dval, - Aval, reverse(pA_dpA.val), conjA′, - C_dC.dval, pdC, conjΔC, - ipB, - conjB ? α : conj(α), One(), ba... - ) - dα = if !isa(α_dα, Const) && _needs_tangent(Tα) - C_αβ = tensorcontract(Aval, pA_dpA.val, conjA, Bval, pB_dpB.val, conjB, pAB_dpAB.val, One(), ba...) - # TODO: consider using `inner` - tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(pAB_dpAB.val))), true, - C_dC.dval, (trivtuple(numind(pAB_dpAB.val)), ()), false, - ((), ()), One(), ba... - ) - ) - else - nothing - end - dβ = if !isa(β_dβ, Const) && _needs_tangent(Tβ) - # TODO: consider using `inner` - tensorscalar( - tensorcontract( - Cval, ((), trivtuple(numind(pAB_dpAB.val))), true, - C_dC.dval, (trivtuple(numind(pAB_dpAB.val)), ()), false, - ((), ()), One(), ba... - ) - ) - else - nothing - end - if β === Zero() - scale!(C_dC.dval, β) - else - scale!(C_dC.dval, conj(β)) - end + ΔC, ΔA, ΔB, dα, dβ = tensorcontract_pb!(dC, Cval, dA, Aval, dB, Bval, α, β, pA_dpA.val, pB_dpB.val, pAB_dpAB.val, conjA_dconjA.val, conjB_dconjB.val, ba...) return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... end diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 49d1708..a58b703 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -61,69 +61,9 @@ function Mooncake.rrule!!( TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) function contract_pb(::NoRData) scale!(C, C_cache, One()) - if Tα == Zero && Tβ == Zero - scale!(dC, zero(TC)) - return ntuple(i -> NoRData(), 11 + length(ba)) - end - ipAB = invperm(linearize(pAB)) - pdC = ( - TupleTools.getindices(ipAB, trivtuple(numout(pA))), - TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))), - ) - ipA = (invperm(linearize(pA)), ()) - ipB = (invperm(linearize(pB)), ()) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB - dA = tensorcontract!( - dA, - dC, pdC, conjΔC, - B, reverse(pB), conjB′, - ipA, - conjA ? α : conj(α), One(), ba... - ) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA - dB = tensorcontract!( - dB, - A, reverse(pA), conjA′, - dC, pdC, conjΔC, - ipB, - conjB ? α : conj(α), One(), ba... - ) - dα = if _needs_tangent(Tα) - C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - # TODO: consider using `inner` - Mooncake._rdata( - tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(pAB))), true, - dC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - dβ = if _needs_tangent(Tβ) - # TODO: consider using `inner` - Mooncake._rdata( - tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pAB))), true, - dC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - if β === Zero() - scale!(dC, β) - else - scale!(dC, conj(β)) - end + dC, dA, dB, Δα, Δβ = tensorcontract_pb!(dC, C, dA, A, dB, B, α, β, pA, pB, pAB, conjA, conjB, ba...) + dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) + dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, contract_pb diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl new file mode 100644 index 0000000..a2a9e7c --- /dev/null +++ b/src/pullbacks/contract.jl @@ -0,0 +1,63 @@ +function tensorcontract_pb!(ΔC, C, ΔA, A, ΔB, B, α, β, pA, pB, pAB, conjA::Bool, conjB::Bool, ba...) + ipAB = invperm(linearize(pAB)) + pdC = ( + TupleTools.getindices(ipAB, trivtuple(numout(pA))), + TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))), + ) + ipA = (invperm(linearize(pA)), ()) + ipB = (invperm(linearize(pB)), ()) + conjΔC = conjA + conjB′ = conjA ? conjB : !conjB + tensorcontract!( + ΔA, + ΔC, pdC, conjΔC, + B, reverse(pB), conjB′, + ipA, + conjA ? α : conj(α), One(), ba... + ) + conjΔC = conjB + conjA′ = conjB ? conjA : !conjA + tensorcontract!( + ΔB, + A, reverse(pA), conjA′, + ΔC, pdC, conjΔC, + ipB, + conjB ? α : conj(α), One(), ba... + ) + Δα = if _needs_tangent(Tα) + C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + # TODO: consider using `inner` + tensorscalar( + tensorcontract( + C_αβ, ((), trivtuple(numind(pAB))), true, + ΔC, (trivtuple(numind(pAB)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + Δβ = if _needs_tangent(Tβ) + # TODO: consider using `inner` + tensorscalar( + tensorcontract( + C, ((), trivtuple(numind(pAB))), true, + ΔC, (trivtuple(numind(pAB)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + if β === Zero() + scale!(ΔC, β) + else + scale!(ΔC, conj(β)) + end + return ΔC, ΔA, ΔB, Δα, Δβ +end + +function tensorcontract_pb!(ΔC, C, ΔA, A, ΔB, B, α::Zero, β::Zero, args...) + scale!(ΔC, zero(eltype(C))) + return ntuple(i -> nothing, 5) +end From c0dad8a58626b932a24f50a27bea0b8a2d10a17d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 15 Jan 2026 17:15:22 +0100 Subject: [PATCH 3/9] Move some code into pullbacks --- Project.toml | 1 + ext/TensorOperationsChainRulesCoreExt.jl | 4 +- .../TensorOperationsEnzymeExt.jl | 96 +------------------ .../TensorOperationsMooncakeExt.jl | 89 ++--------------- src/TensorOperations.jl | 6 ++ src/pullbacks/add.jl | 32 +++++++ src/pullbacks/contract.jl | 11 +-- src/pullbacks/trace.jl | 47 +++++++++ src/utils.jl | 2 + 9 files changed, 103 insertions(+), 185 deletions(-) create mode 100644 src/pullbacks/add.jl create mode 100644 src/pullbacks/trace.jl diff --git a/Project.toml b/Project.toml index 52f2d97..42be359 100644 --- a/Project.toml +++ b/Project.toml @@ -41,6 +41,7 @@ ChainRulesCore = "1" ChainRulesTestUtils = "1" DynamicPolynomials = "0.5, 0.6" Enzyme = "0.13.115" +EnzymeTestUtils = "0.2" LRUCache = "1" LinearAlgebra = "1.6" Logging = "1.6" diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 539d9d9..8338468 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -1,7 +1,7 @@ module TensorOperationsChainRulesCoreExt using TensorOperations -using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent +using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent, trivtuple using TensorOperations: DefaultBackend, DefaultAllocator, _kron using ChainRulesCore using TupleTools @@ -9,8 +9,6 @@ using VectorInterface using TupleTools: invperm using LinearAlgebra -trivtuple(N) = ntuple(identity, N) - @non_differentiable TensorOperations.tensorstructure(args...) @non_differentiable TensorOperations.tensoradd_structure(args...) @non_differentiable TensorOperations.tensoradd_type(args...) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 01f4fed..e9625c8 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -1,7 +1,6 @@ module TensorOperationsEnzymeExt using TensorOperations -using TensorOperations: numind, numin, numout, promote_contract using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator using VectorInterface using TupleTools @@ -9,24 +8,6 @@ using Enzyme using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules -trivtuple(N) = ntuple(identity, N) - -# To avoid computing rrules for α and β when these aren't needed, we want to have a -# type-stable quick bail-out -_needs_tangent(x) = _needs_tangent(typeof(x)) -_needs_tangent(::Type{<:Number}) = true -_needs_tangent(::Type{<:Integer}) = false -_needs_tangent(::Type{<:Union{One, Zero}}) = false - -_kron(Es::NTuple{1}, ba) = Es[1] -function _kron(Es::NTuple{N, Any}, ba) where {N} - E1 = Es[1] - E2 = _kron(Base.tail(Es), ba) - p2 = ((), trivtuple(2 * N - 2)) - p = ((1, (2 .+ trivtuple(N - 1))...), (2, ((N + 1) .+ trivtuple(N - 1))...)) - return tensorproduct(p, E1, ((1, 2), ()), false, E2, p2, false, One(), ba...) -end - @inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true @inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{CUDAAllocator}) = true @@ -95,7 +76,7 @@ function EnzymeRules.reverse( ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val - ΔC, ΔA, ΔB, dα, dβ = tensorcontract_pb!(dC, Cval, dA, Aval, dB, Bval, α, β, pA_dpA.val, pB_dpB.val, pAB_dpAB.val, conjA_dconjA.val, conjB_dconjB.val, ba...) + dα, dβ = TensorOperations.tensorcontract_pb!(dC, Cval, dA, Aval, dB, Bval, α, β, pA_dpA.val, pB_dpB.val, pAB_dpAB.val, conjA_dconjA.val, conjB_dconjB.val, ba...) return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... end @@ -148,39 +129,11 @@ function EnzymeRules.reverse( Aval = something(cache_A, A_dA.val) Cval = cache_C pA = pA_dpA.val - ipA = invperm(linearize(pA)) conjA = conjA_dconjA.val α = α_dα.val β = β_dβ.val ba = map(ba_ -> getfield(ba_, :val), ba_dba) - tensoradd!(A_dA.dval, C_dC.dval, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) - dα = if !isa(α_dα, Const) && _needs_tangent(Tα) - tensorscalar( - tensorcontract( - Aval, ((), linearize(pA)), !conjA, - C_dC.dval, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - else - nothing - end - dβ = if !isa(β_dβ, Const) && _needs_tangent(Tβ) - tensorscalar( - tensorcontract( - Cval, ((), trivtuple(numind(pA))), true, - C_dC.dval, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - else - nothing - end - if β === Zero() - scale!(C_dC.dval, β) - else - scale!(C_dC.dval, conj(β)) - end + dα, dβ = TensorOperations.tensoradd_pb!(C_dC.dval, Cval, A_dA.dval, Aval, α, β, pA, conjA, ba...) return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... end @@ -240,50 +193,7 @@ function EnzymeRules.reverse( α = α_dα.val β = β_dβ.val ba = map(ba_ -> getfield(ba_, :val), ba_dba) - ip = invperm((linearize(p)..., q[1]..., q[2]...)) - Es = map(q[1], q[2]) do i1, i2 - one( - TensorOperations.tensoralloc_add( - TensorOperations.scalartype(Aval), Aval, ((i1,), (i2,)), conjA - ) - ) - end - E = _kron(Es, ba) - dA = tensorproduct!( - A_dA.dval, C_dC.dval, (trivtuple(numind(p)), ()), conjA, - E, ((), trivtuple(numind(q))), conjA, - (ip, ()), - conjA ? α : conj(α), One(), ba... - ) - C_αβ = tensortrace(Aval, p, q, false, One(), ba...) - dα = if !isa(α_dα, Const) && _needs_tangent(Tα) - tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(p))), - !conjA, - C_dC.dval, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - else - nothing - end - dβ = if !isa(β_dβ, Const) && _needs_tangent(Tβ) - tensorscalar( - tensorcontract( - Cval, ((), trivtuple(numind(p))), true, - C_dC.dval, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - else - nothing - end - if β === Zero() - scale!(C_dC.dval, β) - else - scale!(C_dC.dval, conj(β)) - end + dα, dβ = TensorOperations.tensortrace_pb!(C_dC.dval, Cval, A_dA.dval, Aval, α, β, p, q, conjA, ba...) return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... end diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index a58b703..5f85a51 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -6,7 +6,7 @@ using TensorOperations # extension are in fact loaded using Mooncake, Mooncake.CRC using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator -using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!, _kron, numind, _needs_tangent, numin, numout +using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace! using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent using VectorInterface, TupleTools @@ -16,8 +16,6 @@ Mooncake.tangent_type(::Type{DefaultAllocator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{CUDAAllocator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{ManualAllocator}) = Mooncake.NoTangent -trivtuple(N) = ntuple(identity, N) - @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), Any} @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_structure), Any} @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_type), Any} @@ -61,7 +59,7 @@ function Mooncake.rrule!!( TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) function contract_pb(::NoRData) scale!(C, C_cache, One()) - dC, dA, dB, Δα, Δβ = tensorcontract_pb!(dC, C, dA, A, dB, B, α, β, pA, pB, pAB, conjA, conjB, ba...) + Δα, Δβ = TensorOperations.tensorcontract_pb!(dC, C, dA, A, dB, B, α, β, pA, pB, pAB, conjA, conjB, ba...) dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... @@ -91,35 +89,9 @@ function Mooncake.rrule!!( TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...) function add_pb(::NoRData) scale!(C, C_cache, One()) - ipA = invperm(linearize(pA)) - dA = tensoradd!(dA, dC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) - dα = if _needs_tangent(Tα) - tensorscalar( - tensorcontract( - A, ((), linearize(pA)), !conjA, - dC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - else - Mooncake.NoRData() - end - dβ = if _needs_tangent(Tβ) - tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pA))), true, - dC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - else - Mooncake.NoRData() - end - if β === Zero() - scale!(dC, β) - else - scale!(dC, conj(β)) - end + Δα, Δβ = TensorOperations.tensoradd_pb!(dC, C, dA, A, α, β, pA, conjA, ba...) + dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) + dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, add_pb @@ -149,54 +121,9 @@ function Mooncake.rrule!!( TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...) function trace_pb(::NoRData) scale!(C, C_cache, One()) - ip = invperm((linearize(p)..., q[1]..., q[2]...)) - Es = map(q[1], q[2]) do i1, i2 - one( - TensorOperations.tensoralloc_add( - TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA - ) - ) - end - E = _kron(Es, ba) - dA = tensorproduct!( - dA, dC, (trivtuple(numind(p)), ()), conjA, - E, ((), trivtuple(numind(q))), conjA, - (ip, ()), - conjA ? α : conj(α), One(), ba... - ) - C_αβ = tensortrace(A, p, q, false, One(), ba...) - dα = if _needs_tangent(Tα) - Mooncake._rdata( - tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(p))), - !conjA, - dC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - dβ = if _needs_tangent(Tβ) - Mooncake._rdata( - tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(p))), true, - dC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - if β === Zero() - scale!(dC, β) - else - scale!(dC, conj(β)) - end + Δα, Δβ = TensorOperations.tensortrace_pb!(dC, C, dA, A, α, β, p, q, conjA, ba...) + dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) + dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, trace_pb diff --git a/src/TensorOperations.jl b/src/TensorOperations.jl index 23f205d..05ae808 100644 --- a/src/TensorOperations.jl +++ b/src/TensorOperations.jl @@ -36,6 +36,12 @@ include("backends.jl") include("interface.jl") include("utils.jl") +# Generic pullbacks for AD +#--------------------------- +include("pullbacks/add.jl") +include("pullbacks/trace.jl") +include("pullbacks/contract.jl") + # Index notation via macros #--------------------------- @nospecialize diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl new file mode 100644 index 0000000..bb0b9b6 --- /dev/null +++ b/src/pullbacks/add.jl @@ -0,0 +1,32 @@ +function tensoradd_pb!(ΔC, C, ΔA, A, α, β, pA, conjA::Bool, ba...) + ipA = invperm(linearize(pA)) + tensoradd!(ΔA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) + Δα = if _needs_tangent(α) + tensorscalar( + tensorcontract( + A, ((), linearize(pA)), !conjA, + ΔC, (trivtuple(numind(pA)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + Δβ = if _needs_tangent(β) + tensorscalar( + tensorcontract( + C, ((), trivtuple(numind(pA))), true, + ΔC, (trivtuple(numind(pA)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + if β === Zero() + scale!(ΔC, β) + else + scale!(ΔC, conj(β)) + end + return Δα, Δβ +end diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl index a2a9e7c..b8190f3 100644 --- a/src/pullbacks/contract.jl +++ b/src/pullbacks/contract.jl @@ -24,7 +24,7 @@ function tensorcontract_pb!(ΔC, C, ΔA, A, ΔB, B, α, β, pA, pB, pAB, conjA:: ipB, conjB ? α : conj(α), One(), ba... ) - Δα = if _needs_tangent(Tα) + Δα = if _needs_tangent(α) C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) # TODO: consider using `inner` tensorscalar( @@ -37,7 +37,7 @@ function tensorcontract_pb!(ΔC, C, ΔA, A, ΔB, B, α, β, pA, pB, pAB, conjA:: else nothing end - Δβ = if _needs_tangent(Tβ) + Δβ = if _needs_tangent(β) # TODO: consider using `inner` tensorscalar( tensorcontract( @@ -54,10 +54,5 @@ function tensorcontract_pb!(ΔC, C, ΔA, A, ΔB, B, α, β, pA, pB, pAB, conjA:: else scale!(ΔC, conj(β)) end - return ΔC, ΔA, ΔB, Δα, Δβ -end - -function tensorcontract_pb!(ΔC, C, ΔA, A, ΔB, B, α::Zero, β::Zero, args...) - scale!(ΔC, zero(eltype(C))) - return ntuple(i -> nothing, 5) + return Δα, Δβ end diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl new file mode 100644 index 0000000..c616014 --- /dev/null +++ b/src/pullbacks/trace.jl @@ -0,0 +1,47 @@ +function tensortrace_pb!(ΔC, C, ΔA, A, α, β, p, q, conjA, ba...) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + Es = map(q[1], q[2]) do i1, i2 + one( + TensorOperations.tensoralloc_add( + TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA + ) + ) + end + E = _kron(Es, ba) + tensorproduct!( + ΔA, ΔC, (trivtuple(numind(p)), ()), conjA, + E, ((), trivtuple(numind(q))), conjA, + (ip, ()), + conjA ? α : conj(α), One(), ba... + ) + C_αβ = tensortrace(A, p, q, false, One(), ba...) + Δα = if _needs_tangent(α) + tensorscalar( + tensorcontract( + C_αβ, ((), trivtuple(numind(p))), + !conjA, + ΔC, (trivtuple(numind(p)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + Δβ = if _needs_tangent(β) + tensorscalar( + tensorcontract( + C, ((), trivtuple(numind(p))), true, + ΔC, (trivtuple(numind(p)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + if β === Zero() + scale!(ΔC, β) + else + scale!(ΔC, conj(β)) + end + return Δα, Δβ +end diff --git a/src/utils.jl b/src/utils.jl index 20e406b..68b98c8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,3 +13,5 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{<:Number}) = true _needs_tangent(::Type{<:Integer}) = false _needs_tangent(::Type{<:Union{One, Zero}}) = false + +trivtuple(N) = ntuple(identity, N) From 61727e41f5ea3f5eb17b1da3537f59e5d758524c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 16 Jan 2026 12:03:49 +0100 Subject: [PATCH 4/9] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Lukas Devos --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index e9625c8..01c861f 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -10,7 +10,7 @@ using Enzyme.EnzymeCore: EnzymeRules @inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true @inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true -@inline EnzymeRules.inactive_type(v::Type{CUDAAllocator}) = true +@inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{Index2Tuple}) = true From 4788ba760dc325c6f25000374422c47daa648d85 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 16 Jan 2026 14:30:01 +0100 Subject: [PATCH 5/9] Updates re comments --- .../TensorOperationsEnzymeExt.jl | 10 +++++++--- .../TensorOperationsMooncakeExt.jl | 6 +++--- src/pullbacks/add.jl | 10 +++++++--- src/pullbacks/contract.jl | 16 ++++++++++++---- src/pullbacks/trace.jl | 10 +++++++--- test/enzyme.jl | 7 ++++--- test/mooncake.jl | 1 + 7 files changed, 41 insertions(+), 19 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 01c861f..df72208 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -76,7 +76,7 @@ function EnzymeRules.reverse( ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val - dα, dβ = TensorOperations.tensorcontract_pb!(dC, Cval, dA, Aval, dB, Bval, α, β, pA_dpA.val, pB_dpB.val, pAB_dpAB.val, conjA_dconjA.val, conjB_dconjB.val, ba...) + dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, Bval, α, β, pA_dpA.val, pB_dpB.val, pAB_dpAB.val, conjA_dconjA.val, conjB_dconjB.val, ba...) return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... end @@ -133,7 +133,9 @@ function EnzymeRules.reverse( α = α_dα.val β = β_dβ.val ba = map(ba_ -> getfield(ba_, :val), ba_dba) - dα, dβ = TensorOperations.tensoradd_pb!(C_dC.dval, Cval, A_dA.dval, Aval, α, β, pA, conjA, ba...) + dC = C_dC.dval + dA = A_dA.dval + dC, dA, dα, dβ = TensorOperations.tensoradd_pullback!(dC, dA, Cval, Aval, α, β, pA, conjA, ba...) return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... end @@ -193,7 +195,9 @@ function EnzymeRules.reverse( α = α_dα.val β = β_dβ.val ba = map(ba_ -> getfield(ba_, :val), ba_dba) - dα, dβ = TensorOperations.tensortrace_pb!(C_dC.dval, Cval, A_dA.dval, Aval, α, β, p, q, conjA, ba...) + dC = C_dC.dval + dA = A_dA.dval + dC, dA, dα, dβ = TensorOperations.tensortrace_pullback!(dC, dA, Cval, Aval, α, β, p, q, conjA, ba...) return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... end diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 5f85a51..9450d77 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -59,7 +59,7 @@ function Mooncake.rrule!!( TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) function contract_pb(::NoRData) scale!(C, C_cache, One()) - Δα, Δβ = TensorOperations.tensorcontract_pb!(dC, C, dA, A, dB, B, α, β, pA, pB, pAB, conjA, conjB, ba...) + dC, dA, dB, Δα, Δβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, C, A, B, α, β, pA, pB, pAB, conjA, conjB, ba...) dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... @@ -89,7 +89,7 @@ function Mooncake.rrule!!( TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...) function add_pb(::NoRData) scale!(C, C_cache, One()) - Δα, Δβ = TensorOperations.tensoradd_pb!(dC, C, dA, A, α, β, pA, conjA, ba...) + dC, dA, Δα, Δβ = TensorOperations.tensoradd_pullback!(dC, dA, C, A, α, β, pA, conjA, ba...) dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... @@ -121,7 +121,7 @@ function Mooncake.rrule!!( TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...) function trace_pb(::NoRData) scale!(C, C_cache, One()) - Δα, Δβ = TensorOperations.tensortrace_pb!(dC, C, dA, A, α, β, p, q, conjA, ba...) + dC, dA, Δα, Δβ = TensorOperations.tensortrace_pullback!(dC, dA, C, A, α, β, p, q, conjA, ba...) dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl index bb0b9b6..1cc5dab 100644 --- a/src/pullbacks/add.jl +++ b/src/pullbacks/add.jl @@ -1,6 +1,10 @@ -function tensoradd_pb!(ΔC, C, ΔA, A, α, β, pA, conjA::Bool, ba...) +function tensoradd_pullback!(ΔC, ΔA, C, A, α, β, pA, conjA::Bool, ba...) ipA = invperm(linearize(pA)) - tensoradd!(ΔA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) + ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA + tensoradd!(ΔAc, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔA .+= real.(ΔAc) + end Δα = if _needs_tangent(α) tensorscalar( tensorcontract( @@ -28,5 +32,5 @@ function tensoradd_pb!(ΔC, C, ΔA, A, α, β, pA, conjA::Bool, ba...) else scale!(ΔC, conj(β)) end - return Δα, Δβ + return ΔC, ΔA, Δα, Δβ end diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl index b8190f3..939b189 100644 --- a/src/pullbacks/contract.jl +++ b/src/pullbacks/contract.jl @@ -1,4 +1,4 @@ -function tensorcontract_pb!(ΔC, C, ΔA, A, ΔB, B, α, β, pA, pB, pAB, conjA::Bool, conjB::Bool, ba...) +function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, B, α, β, pA, pB, pAB, conjA::Bool, conjB::Bool, ba...) ipAB = invperm(linearize(pAB)) pdC = ( TupleTools.getindices(ipAB, trivtuple(numout(pA))), @@ -8,22 +8,30 @@ function tensorcontract_pb!(ΔC, C, ΔA, A, ΔB, B, α, β, pA, pB, pAB, conjA:: ipB = (invperm(linearize(pB)), ()) conjΔC = conjA conjB′ = conjA ? conjB : !conjB + ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA tensorcontract!( - ΔA, + ΔAc, ΔC, pdC, conjΔC, B, reverse(pB), conjB′, ipA, conjA ? α : conj(α), One(), ba... ) + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔA .+= real.(ΔAc) + end conjΔC = conjB conjA′ = conjB ? conjA : !conjA + ΔBc = eltype(ΔC) <: Complex && eltype(ΔB) <: Real ? zerovector(B, VectorInterface.promote_add(ΔC, α)) : ΔB tensorcontract!( - ΔB, + ΔBc, A, reverse(pA), conjA′, ΔC, pdC, conjΔC, ipB, conjB ? α : conj(α), One(), ba... ) + if eltype(ΔC) <: Complex && eltype(ΔB) <: Real + ΔB .+= real.(ΔBc) + end Δα = if _needs_tangent(α) C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) # TODO: consider using `inner` @@ -54,5 +62,5 @@ function tensorcontract_pb!(ΔC, C, ΔA, A, ΔB, B, α, β, pA, pB, pAB, conjA:: else scale!(ΔC, conj(β)) end - return Δα, Δβ + return ΔC, ΔA, ΔB, Δα, Δβ end diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl index c616014..3ca8b78 100644 --- a/src/pullbacks/trace.jl +++ b/src/pullbacks/trace.jl @@ -1,4 +1,4 @@ -function tensortrace_pb!(ΔC, C, ΔA, A, α, β, p, q, conjA, ba...) +function tensortrace_pullback!(ΔC, ΔA, C, A, α, β, p, q, conjA, ba...) ip = invperm((linearize(p)..., q[1]..., q[2]...)) Es = map(q[1], q[2]) do i1, i2 one( @@ -8,12 +8,16 @@ function tensortrace_pb!(ΔC, C, ΔA, A, α, β, p, q, conjA, ba...) ) end E = _kron(Es, ba) + ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA tensorproduct!( - ΔA, ΔC, (trivtuple(numind(p)), ()), conjA, + ΔAc, ΔC, (trivtuple(numind(p)), ()), conjA, E, ((), trivtuple(numind(q))), conjA, (ip, ()), conjA ? α : conj(α), One(), ba... ) + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔA .+= real.(ΔAc) + end C_αβ = tensortrace(A, p, q, false, One(), ba...) Δα = if _needs_tangent(α) tensorscalar( @@ -43,5 +47,5 @@ function tensortrace_pb!(ΔC, C, ΔA, A, α, β, p, q, conjA, ba...) else scale!(ΔC, conj(β)) end - return Δα, Δβ + return ΔC, ΔA, Δα, Δβ end diff --git a/test/enzyme.jl b/test/enzyme.jl index f6ee2da..8390f45 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -6,7 +6,8 @@ using Enzyme, EnzymeTestUtils (Float64, Float64), (Float32, Float64), (ComplexF64, ComplexF64), - #(Float64, ComplexF64), + (Float64, ComplexF64), + (ComplexF64, Float64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) @@ -35,7 +36,7 @@ end (Float64, Float64), (Float32, Float64), (ComplexF64, ComplexF64), - #(Float64, ComplexF64), + (Float64, ComplexF64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) @@ -60,7 +61,7 @@ end (Float64, Float64), (Float32, Float64), (ComplexF64, ComplexF64), - #(Float64, ComplexF64), + (Float64, ComplexF64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) diff --git a/test/mooncake.jl b/test/mooncake.jl index 729a74d..8db47e9 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -14,6 +14,7 @@ is_primitive = false (Float32, Float64), #(ComplexF64, ComplexF64), #(Float64, ComplexF64), + #(ComplexF64, Float64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) From d927e7d7b4a2e18e9d9f29c3752d09de5a3cb976 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 16 Jan 2026 15:25:33 +0100 Subject: [PATCH 6/9] Import CRC rules for alloc and free --- Project.toml | 2 +- .../TensorOperationsEnzymeExt.jl | 27 +++++++------------ 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index 42be359..d7bdd0d 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ TensorOperationsBumperExt = "Bumper" TensorOperationsChainRulesCoreExt = "ChainRulesCore" TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"] TensorOperationsMooncakeExt = "Mooncake" -TensorOperationsEnzymeExt = "Enzyme" +TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"] [compat] Aqua = "0.6, 0.7, 0.8" diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index df72208..2f71d90 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -4,10 +4,13 @@ using TensorOperations using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator using VectorInterface using TupleTools -using Enzyme +using Enzyme, ChainRulesCore using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules +Enzyme.@import_rrule typeof(TensorOperations.tensorfree!) Any +Enzyme.@import_rrule typeof(TensorOperations.tensoralloc) Any + @inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true @inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true @@ -33,7 +36,7 @@ function EnzymeRules.augmented_primal( # form caches if needed cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing cache_B = !isa(B_dB, Const) && EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing - cache_C = copy(C_dC.val) + cache_C = copy(C_dC.val) # do we need to do this, if we don't need the primal? ba = map(ba_ -> getfield(ba_, :val), ba_dba) primal = if EnzymeRules.needs_primal(config) TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) @@ -41,11 +44,7 @@ function EnzymeRules.augmented_primal( else nothing end - shadow = if EnzymeRules.needs_shadow(config) - C_dC.dval - else - nothing - end + shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_C)) end @@ -101,14 +100,11 @@ function EnzymeRules.augmented_primal( conjA = conjA_dconjA.val primal = if EnzymeRules.needs_primal(config) TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...) + C_dC.val else nothing end - shadow = if EnzymeRules.needs_shadow(config) - C_dC.dval - else - nothing - end + shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C)) end @@ -161,14 +157,11 @@ function EnzymeRules.augmented_primal( conjA = conjA_dconjA.val primal = if EnzymeRules.needs_primal(config) TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...) + C_dC.val else nothing end - shadow = if EnzymeRules.needs_shadow(config) - C_dC.dval - else - nothing - end + shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C)) end From 2519616be115707461348c78c11880dafa90c66f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 19 Jan 2026 16:10:33 +0100 Subject: [PATCH 7/9] Or don't --- Project.toml | 4 ++-- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 5 +---- test/runtests.jl | 6 ++++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index d7bdd0d..70ba45f 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ TensorOperationsBumperExt = "Bumper" TensorOperationsChainRulesCoreExt = "ChainRulesCore" TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"] TensorOperationsMooncakeExt = "Mooncake" -TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"] +TensorOperationsEnzymeExt = "Enzyme" [compat] Aqua = "0.6, 0.7, 0.8" @@ -52,7 +52,7 @@ Preferences = "1.4" PtrArrays = "1.2" Random = "1" Strided = "2.2" -StridedViews = "0.3, 0.4" +StridedViews = "=0.4.1" Test = "1" TupleTools = "1.6" VectorInterface = "0.4.1,0.5" diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 2f71d90..e53e944 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -4,13 +4,10 @@ using TensorOperations using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator using VectorInterface using TupleTools -using Enzyme, ChainRulesCore +using Enzyme using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules -Enzyme.@import_rrule typeof(TensorOperations.tensorfree!) Any -Enzyme.@import_rrule typeof(TensorOperations.tensoralloc) Any - @inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true @inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true diff --git a/test/runtests.jl b/test/runtests.jl index dd594b1..733a5fe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,8 +36,10 @@ if !is_buildkite @testset "mooncake" verbose = false begin include("mooncake.jl") end - @testset "enzyme" verbose = false begin - include("enzyme.jl") + @static if VERSION > v"1.11.0" + @testset "enzyme" verbose = false begin + include("enzyme.jl") + end end end From b8b6cc9e55b59589b4344b5abe8de635420fe8fa Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 19 Jan 2026 16:38:28 +0100 Subject: [PATCH 8/9] Fixes for primal --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 6 +++--- test/runtests.jl | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index e53e944..ffe9d8b 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -35,8 +35,8 @@ function EnzymeRules.augmented_primal( cache_B = !isa(B_dB, Const) && EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing cache_C = copy(C_dC.val) # do we need to do this, if we don't need the primal? ba = map(ba_ -> getfield(ba_, :val), ba_dba) + TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) primal = if EnzymeRules.needs_primal(config) - TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) C_dC.val else nothing @@ -95,8 +95,8 @@ function EnzymeRules.augmented_primal( α = α_dα.val β = β_dβ.val conjA = conjA_dconjA.val + TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...) primal = if EnzymeRules.needs_primal(config) - TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...) C_dC.val else nothing @@ -152,8 +152,8 @@ function EnzymeRules.augmented_primal( α = α_dα.val β = β_dβ.val conjA = conjA_dconjA.val + TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...) primal = if EnzymeRules.needs_primal(config) - TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...) C_dC.val else nothing diff --git a/test/runtests.jl b/test/runtests.jl index 733a5fe..dd594b1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,10 +36,8 @@ if !is_buildkite @testset "mooncake" verbose = false begin include("mooncake.jl") end - @static if VERSION > v"1.11.0" - @testset "enzyme" verbose = false begin - include("enzyme.jl") - end + @testset "enzyme" verbose = false begin + include("enzyme.jl") end end From aac7e43c1f6f8b0d5b6f89f9af36163c11bd78bc Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 19 Jan 2026 18:25:30 +0100 Subject: [PATCH 9/9] Comments --- Project.toml | 5 +++-- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 7 +++++-- src/pullbacks/add.jl | 6 +----- src/pullbacks/contract.jl | 6 +----- src/pullbacks/trace.jl | 6 +----- test/enzyme.jl | 2 +- 6 files changed, 12 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 70ba45f..d57dce1 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ TensorOperationsBumperExt = "Bumper" TensorOperationsChainRulesCoreExt = "ChainRulesCore" TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"] TensorOperationsMooncakeExt = "Mooncake" -TensorOperationsEnzymeExt = "Enzyme" +TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"] [compat] Aqua = "0.6, 0.7, 0.8" @@ -63,6 +63,7 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -74,4 +75,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] -test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"] +test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "ChainRulesCore", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"] diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index ffe9d8b..18fb7b0 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -4,15 +4,18 @@ using TensorOperations using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator using VectorInterface using TupleTools -using Enzyme +using Enzyme, ChainRulesCore using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules +@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any) = true +Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any) + @inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true @inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true -@inline EnzymeRules.inactive_type(v::Type{Index2Tuple}) = true +@inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl index 1cc5dab..bf34e3f 100644 --- a/src/pullbacks/add.jl +++ b/src/pullbacks/add.jl @@ -27,10 +27,6 @@ function tensoradd_pullback!(ΔC, ΔA, C, A, α, β, pA, conjA::Bool, ba...) else nothing end - if β === Zero() - scale!(ΔC, β) - else - scale!(ΔC, conj(β)) - end + scale!(ΔC, conj(β)) return ΔC, ΔA, Δα, Δβ end diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl index 939b189..2bd42ac 100644 --- a/src/pullbacks/contract.jl +++ b/src/pullbacks/contract.jl @@ -57,10 +57,6 @@ function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, B, α, β, pA, pB, pAB, c else nothing end - if β === Zero() - scale!(ΔC, β) - else - scale!(ΔC, conj(β)) - end + scale!(ΔC, conj(β)) return ΔC, ΔA, ΔB, Δα, Δβ end diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl index 3ca8b78..5c08033 100644 --- a/src/pullbacks/trace.jl +++ b/src/pullbacks/trace.jl @@ -42,10 +42,6 @@ function tensortrace_pullback!(ΔC, ΔA, C, A, α, β, p, q, conjA, ba...) else nothing end - if β === Zero() - scale!(ΔC, β) - else - scale!(ΔC, conj(β)) - end + scale!(ΔC, conj(β)) return ΔC, ΔA, Δα, Δβ end diff --git a/test/enzyme.jl b/test/enzyme.jl index 8390f45..2852b5d 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,5 +1,5 @@ using TensorOperations, VectorInterface -using Enzyme, EnzymeTestUtils +using Enzyme, ChainRulesCore, EnzymeTestUtils @testset "tensorcontract! ($T₁, $T₂)" for (T₁, T₂) in (