From 33281b06dc720173d3a2149446ad12bc4e9cb430 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 10:52:11 +0100 Subject: [PATCH 01/11] Split pullbacks into separate files --- ext/TensorOperationsChainRulesCoreExt.jl | 4 +- .../TensorOperationsMooncakeExt.jl | 153 ++---------------- src/TensorOperations.jl | 6 + src/pullbacks/add.jl | 32 ++++ src/pullbacks/contract.jl | 62 +++++++ src/pullbacks/trace.jl | 47 ++++++ src/utils.jl | 2 + 7 files changed, 160 insertions(+), 146 deletions(-) create mode 100644 src/pullbacks/add.jl create mode 100644 src/pullbacks/contract.jl create mode 100644 src/pullbacks/trace.jl diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 539d9d9..8338468 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -1,7 +1,7 @@ module TensorOperationsChainRulesCoreExt using TensorOperations -using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent +using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent, trivtuple using TensorOperations: DefaultBackend, DefaultAllocator, _kron using ChainRulesCore using TupleTools @@ -9,8 +9,6 @@ using VectorInterface using TupleTools: invperm using LinearAlgebra -trivtuple(N) = ntuple(identity, N) - @non_differentiable TensorOperations.tensorstructure(args...) @non_differentiable TensorOperations.tensoradd_structure(args...) @non_differentiable TensorOperations.tensoradd_type(args...) diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 49d1708..9450d77 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -6,7 +6,7 @@ using TensorOperations # extension are in fact loaded using Mooncake, Mooncake.CRC using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator -using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!, _kron, numind, _needs_tangent, numin, numout +using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace! using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent using VectorInterface, TupleTools @@ -16,8 +16,6 @@ Mooncake.tangent_type(::Type{DefaultAllocator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{CUDAAllocator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{ManualAllocator}) = Mooncake.NoTangent -trivtuple(N) = ntuple(identity, N) - @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), Any} @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_structure), Any} @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_type), Any} @@ -61,69 +59,9 @@ function Mooncake.rrule!!( TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) function contract_pb(::NoRData) scale!(C, C_cache, One()) - if Tα == Zero && Tβ == Zero - scale!(dC, zero(TC)) - return ntuple(i -> NoRData(), 11 + length(ba)) - end - ipAB = invperm(linearize(pAB)) - pdC = ( - TupleTools.getindices(ipAB, trivtuple(numout(pA))), - TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))), - ) - ipA = (invperm(linearize(pA)), ()) - ipB = (invperm(linearize(pB)), ()) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB - dA = tensorcontract!( - dA, - dC, pdC, conjΔC, - B, reverse(pB), conjB′, - ipA, - conjA ? 
α : conj(α), One(), ba... - ) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA - dB = tensorcontract!( - dB, - A, reverse(pA), conjA′, - dC, pdC, conjΔC, - ipB, - conjB ? α : conj(α), One(), ba... - ) - dα = if _needs_tangent(Tα) - C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - # TODO: consider using `inner` - Mooncake._rdata( - tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(pAB))), true, - dC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - dβ = if _needs_tangent(Tβ) - # TODO: consider using `inner` - Mooncake._rdata( - tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pAB))), true, - dC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - if β === Zero() - scale!(dC, β) - else - scale!(dC, conj(β)) - end + dC, dA, dB, Δα, Δβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, C, A, B, α, β, pA, pB, pAB, conjA, conjB, ba...) + dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) + dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, contract_pb @@ -151,35 +89,9 @@ function Mooncake.rrule!!( TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...) function add_pb(::NoRData) scale!(C, C_cache, One()) - ipA = invperm(linearize(pA)) - dA = tensoradd!(dA, dC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) - dα = if _needs_tangent(Tα) - tensorscalar( - tensorcontract( - A, ((), linearize(pA)), !conjA, - dC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - else - Mooncake.NoRData() - end - dβ = if _needs_tangent(Tβ) - tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pA))), true, - dC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - else - Mooncake.NoRData() - end - if β === Zero() - scale!(dC, β) - else - scale!(dC, conj(β)) - end + dC, dA, Δα, Δβ = TensorOperations.tensoradd_pullback!(dC, dA, C, A, α, β, pA, conjA, ba...) + dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) + dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, add_pb @@ -209,54 +121,9 @@ function Mooncake.rrule!!( TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...) function trace_pb(::NoRData) scale!(C, C_cache, One()) - ip = invperm((linearize(p)..., q[1]..., q[2]...)) - Es = map(q[1], q[2]) do i1, i2 - one( - TensorOperations.tensoralloc_add( - TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA - ) - ) - end - E = _kron(Es, ba) - dA = tensorproduct!( - dA, dC, (trivtuple(numind(p)), ()), conjA, - E, ((), trivtuple(numind(q))), conjA, - (ip, ()), - conjA ? α : conj(α), One(), ba... - ) - C_αβ = tensortrace(A, p, q, false, One(), ba...) - dα = if _needs_tangent(Tα) - Mooncake._rdata( - tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(p))), - !conjA, - dC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - dβ = if _needs_tangent(Tβ) - Mooncake._rdata( - tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(p))), true, - dC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... 
- ) - ) - ) - else - NoRData() - end - if β === Zero() - scale!(dC, β) - else - scale!(dC, conj(β)) - end + dC, dA, Δα, Δβ = TensorOperations.tensortrace_pullback!(dC, dA, C, A, α, β, p, q, conjA, ba...) + dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) + dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, trace_pb diff --git a/src/TensorOperations.jl b/src/TensorOperations.jl index 23f205d..05ae808 100644 --- a/src/TensorOperations.jl +++ b/src/TensorOperations.jl @@ -36,6 +36,12 @@ include("backends.jl") include("interface.jl") include("utils.jl") +# Generic pullbacks for AD +#--------------------------- +include("pullbacks/add.jl") +include("pullbacks/trace.jl") +include("pullbacks/contract.jl") + # Index notation via macros #--------------------------- @nospecialize diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl new file mode 100644 index 0000000..bf34e3f --- /dev/null +++ b/src/pullbacks/add.jl @@ -0,0 +1,32 @@ +function tensoradd_pullback!(ΔC, ΔA, C, A, α, β, pA, conjA::Bool, ba...) + ipA = invperm(linearize(pA)) + ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA + tensoradd!(ΔAc, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔA .+= real.(ΔAc) + end + Δα = if _needs_tangent(α) + tensorscalar( + tensorcontract( + A, ((), linearize(pA)), !conjA, + ΔC, (trivtuple(numind(pA)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + Δβ = if _needs_tangent(β) + tensorscalar( + tensorcontract( + C, ((), trivtuple(numind(pA))), true, + ΔC, (trivtuple(numind(pA)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + scale!(ΔC, conj(β)) + return ΔC, ΔA, Δα, Δβ +end diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl new file mode 100644 index 0000000..2bd42ac --- /dev/null +++ b/src/pullbacks/contract.jl @@ -0,0 +1,62 @@ +function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, B, α, β, pA, pB, pAB, conjA::Bool, conjB::Bool, ba...) + ipAB = invperm(linearize(pAB)) + pdC = ( + TupleTools.getindices(ipAB, trivtuple(numout(pA))), + TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))), + ) + ipA = (invperm(linearize(pA)), ()) + ipB = (invperm(linearize(pB)), ()) + conjΔC = conjA + conjB′ = conjA ? conjB : !conjB + ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA + tensorcontract!( + ΔAc, + ΔC, pdC, conjΔC, + B, reverse(pB), conjB′, + ipA, + conjA ? α : conj(α), One(), ba... + ) + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔA .+= real.(ΔAc) + end + conjΔC = conjB + conjA′ = conjB ? conjA : !conjA + ΔBc = eltype(ΔC) <: Complex && eltype(ΔB) <: Real ? zerovector(B, VectorInterface.promote_add(ΔC, α)) : ΔB + tensorcontract!( + ΔBc, + A, reverse(pA), conjA′, + ΔC, pdC, conjΔC, + ipB, + conjB ? α : conj(α), One(), ba... + ) + if eltype(ΔC) <: Complex && eltype(ΔB) <: Real + ΔB .+= real.(ΔBc) + end + Δα = if _needs_tangent(α) + C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + # TODO: consider using `inner` + tensorscalar( + tensorcontract( + C_αβ, ((), trivtuple(numind(pAB))), true, + ΔC, (trivtuple(numind(pAB)), ()), false, + ((), ()), One(), ba... 
+ ) + ) + else + nothing + end + Δβ = if _needs_tangent(β) + # TODO: consider using `inner` + tensorscalar( + tensorcontract( + C, ((), trivtuple(numind(pAB))), true, + ΔC, (trivtuple(numind(pAB)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + scale!(ΔC, conj(β)) + return ΔC, ΔA, ΔB, Δα, Δβ +end diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl new file mode 100644 index 0000000..5c08033 --- /dev/null +++ b/src/pullbacks/trace.jl @@ -0,0 +1,47 @@ +function tensortrace_pullback!(ΔC, ΔA, C, A, α, β, p, q, conjA, ba...) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + Es = map(q[1], q[2]) do i1, i2 + one( + TensorOperations.tensoralloc_add( + TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA + ) + ) + end + E = _kron(Es, ba) + ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA + tensorproduct!( + ΔAc, ΔC, (trivtuple(numind(p)), ()), conjA, + E, ((), trivtuple(numind(q))), conjA, + (ip, ()), + conjA ? α : conj(α), One(), ba... + ) + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔA .+= real.(ΔAc) + end + C_αβ = tensortrace(A, p, q, false, One(), ba...) + Δα = if _needs_tangent(α) + tensorscalar( + tensorcontract( + C_αβ, ((), trivtuple(numind(p))), + !conjA, + ΔC, (trivtuple(numind(p)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + Δβ = if _needs_tangent(β) + tensorscalar( + tensorcontract( + C, ((), trivtuple(numind(p))), true, + ΔC, (trivtuple(numind(p)), ()), false, + ((), ()), One(), ba... + ) + ) + else + nothing + end + scale!(ΔC, conj(β)) + return ΔC, ΔA, Δα, Δβ +end diff --git a/src/utils.jl b/src/utils.jl index 20e406b..68b98c8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,3 +13,5 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{<:Number}) = true _needs_tangent(::Type{<:Integer}) = false _needs_tangent(::Type{<:Union{One, Zero}}) = false + +trivtuple(N) = ntuple(identity, N) From 2b20afaf0d7bb5e92a899a4e0aa6c9874e6fe6b8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 10:55:26 +0100 Subject: [PATCH 02/11] Fix StridedViews version for now --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c05ec0c..b9bb62e 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ Preferences = "1.4" PtrArrays = "1.2" Random = "1" Strided = "2.2" -StridedViews = "0.3, 0.4" +StridedViews = "=0.4.1" Test = "1" TupleTools = "1.6" VectorInterface = "0.4.1,0.5" From fce1435244c4244dfe6c9acdc71032aa08851b38 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 13:43:14 +0100 Subject: [PATCH 03/11] Fix ordering --- .../TensorOperationsMooncakeExt.jl | 8 ++++---- src/pullbacks/add.jl | 2 +- src/pullbacks/contract.jl | 2 +- src/pullbacks/trace.jl | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 9450d77..64ea471 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -59,7 +59,7 @@ function Mooncake.rrule!!( TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) function contract_pb(::NoRData) scale!(C, C_cache, One()) - dC, dA, dB, Δα, Δβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, C, A, B, α, β, pA, pB, pAB, conjA, conjB, ba...) 
+ dC, dA, dB, Δα, Δβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... @@ -80,7 +80,7 @@ function Mooncake.rrule!!( ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} C, dC = arrayify(C_dC) A, dA = arrayify(A_dA) - pA = primal(pA_dpA) + pA = primal(pA_dpA) conjA = primal(conjA_dconjA) α = primal(α_dα) β = primal(β_dβ) @@ -89,7 +89,7 @@ function Mooncake.rrule!!( TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...) function add_pb(::NoRData) scale!(C, C_cache, One()) - dC, dA, Δα, Δβ = TensorOperations.tensoradd_pullback!(dC, dA, C, A, α, β, pA, conjA, ba...) + dC, dA, Δα, Δβ = TensorOperations.tensoradd_pullback!(dC, dA, C, A, pA, conjA, α, β, ba...) dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... @@ -121,7 +121,7 @@ function Mooncake.rrule!!( TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...) function trace_pb(::NoRData) scale!(C, C_cache, One()) - dC, dA, Δα, Δβ = TensorOperations.tensortrace_pullback!(dC, dA, C, A, α, β, p, q, conjA, ba...) + dC, dA, Δα, Δβ = TensorOperations.tensortrace_pullback!(dC, dA, C, A, p, q, conjA, α, β, ba...) dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl index bf34e3f..4febed9 100644 --- a/src/pullbacks/add.jl +++ b/src/pullbacks/add.jl @@ -1,4 +1,4 @@ -function tensoradd_pullback!(ΔC, ΔA, C, A, α, β, pA, conjA::Bool, ba...) +function tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, β, ba...) ipA = invperm(linearize(pA)) ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA tensoradd!(ΔAc, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl index 2bd42ac..84ad353 100644 --- a/src/pullbacks/contract.jl +++ b/src/pullbacks/contract.jl @@ -1,4 +1,4 @@ -function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, B, α, β, pA, pB, pAB, conjA::Bool, conjB::Bool, ba...) +function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α, β, ba...) ipAB = invperm(linearize(pAB)) pdC = ( TupleTools.getindices(ipAB, trivtuple(numout(pA))), diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl index 5c08033..09b82d3 100644 --- a/src/pullbacks/trace.jl +++ b/src/pullbacks/trace.jl @@ -1,4 +1,4 @@ -function tensortrace_pullback!(ΔC, ΔA, C, A, α, β, p, q, conjA, ba...) +function tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α, β, ba...) 
ip = invperm((linearize(p)..., q[1]..., q[2]...)) Es = map(q[1], q[2]) do i1, i2 one( From b20f8069af6e7158e6953391df836bc774f8a054 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 13:44:43 +0100 Subject: [PATCH 04/11] Format --- ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 64ea471..8bff2f8 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -80,7 +80,7 @@ function Mooncake.rrule!!( ) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} C, dC = arrayify(C_dC) A, dA = arrayify(A_dA) - pA = primal(pA_dpA) + pA = primal(pA_dpA) conjA = primal(conjA_dconjA) α = primal(α_dα) β = primal(β_dβ) From 3a5ebadf06caeb545888f18a80584086efe8c6f5 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 14:49:51 +0100 Subject: [PATCH 05/11] Fix Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b9bb62e..deccc00 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ Preferences = "1.4" PtrArrays = "1.2" Random = "1" Strided = "2.2" -StridedViews = "=0.4.1" +StridedViews = "0.3, 0.4, ~0.4.2" Test = "1" TupleTools = "1.6" VectorInterface = "0.4.1,0.5" From 149de598105dbc02cac5a2cc7aada24256ecadc2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 20 Jan 2026 10:20:43 -0500 Subject: [PATCH 06/11] index magic --- src/indices.jl | 29 +++++++++++++++++++++++++++++ src/pullbacks/add.jl | 6 +++--- src/pullbacks/contract.jl | 22 +++++++++++----------- src/pullbacks/trace.jl | 14 +++++++------- src/utils.jl | 6 ++---- 5 files changed, 52 insertions(+), 25 deletions(-) diff --git a/src/indices.jl b/src/indices.jl index d844478..61eeddc 100644 --- a/src/indices.jl +++ b/src/indices.jl @@ -26,10 +26,39 @@ trivialpermutation(p::IndexTuple{N}) where {N} = ntuple(identity, Val(N)) function trivialpermutation(p::Index2Tuple) return (trivialpermutation(p[1]), numout(p) .+ trivialpermutation(p[2])) end +trivialpermutation(N::Integer) = ntuple(identity, N) +trivialpermutation(N₁::Integer, N₂::Integer) = (trivialpermutation(N₁), trivialpermutation(N₂) .+ N₁) istrivialpermutation(p::IndexTuple) = p == trivialpermutation(p) istrivialpermutation(p::Index2Tuple) = p == trivialpermutation(p) +Base.@constprop :aggressive function repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + partition = trivialpermutation(N₁, length(p) - N₁) + return TupleTools.getindices(p, partition[1]), TupleTools.getindices(p, partition[2]) +end +@inline repartition(p::Index2Tuple, N₁::Int) = repartition(linearize(p), N₁) + +@inline repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} = + repartition(p, N₁) +@inline repartition(p::Union{IndexTuple, Index2Tuple}, A::AbstractArray) = repartition(p, ndims(A)) + +""" + inversepermutation(p::Index2Tuple) -> ip::IndexTuple + inversepermutation(p::Index2Tuple, N₁::Int) -> ip::Index2Tuple + inversepermutation(p::Index2Tuple, partition_as::Index2Tuple) -> ip::Index2Tuple + inversepermutation(p::Index2Tuple, partition_as::AbstractArray) -> ip::Index2Tuple + +Compute the inverse permutation associated to `p`. 
+If no extra arguments are provided, the result is returned as a single `IndexTuple`. +Otherwise, the extra arguments are used to partition the inverse permutation into an `Index2Tuple`. +""" +inversepermutation(p::Index2Tuple) = invperm(linearize(p)) +function inversepermutation(p::Index2Tuple, args...) + ip = invperm(linearize(p)) + return repartition(ip, args...) +end + """ const LabelType = Union{Int,Symbol,Char} diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl index 4febed9..7cf10c0 100644 --- a/src/pullbacks/add.jl +++ b/src/pullbacks/add.jl @@ -9,7 +9,7 @@ function tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, tensorscalar( tensorcontract( A, ((), linearize(pA)), !conjA, - ΔC, (trivtuple(numind(pA)), ()), false, + ΔC, trivialpermutation(numind(pA), 0), false, ((), ()), One(), ba... ) ) @@ -19,8 +19,8 @@ function tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, Δβ = if _needs_tangent(β) tensorscalar( tensorcontract( - C, ((), trivtuple(numind(pA))), true, - ΔC, (trivtuple(numind(pA)), ()), false, + C, trivialpermutation(0, numind(pA)), true, + ΔC, trivialpermutation(numind(pA), 0), false, ((), ()), One(), ba... ) ) diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl index 84ad353..ff3a8a1 100644 --- a/src/pullbacks/contract.jl +++ b/src/pullbacks/contract.jl @@ -1,11 +1,8 @@ function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α, β, ba...) - ipAB = invperm(linearize(pAB)) - pdC = ( - TupleTools.getindices(ipAB, trivtuple(numout(pA))), - TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))), - ) - ipA = (invperm(linearize(pA)), ()) - ipB = (invperm(linearize(pB)), ()) + pdC = inversepermutation(pAB, numout(pA)) + ipA = inversepermutation(pA, A) + ipB = inversepermutation(pB, B) + conjΔC = conjA conjB′ = conjA ? conjB : !conjB ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA @@ -19,6 +16,7 @@ function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, pA::Index2Tuple, conjA::B if eltype(ΔC) <: Complex && eltype(ΔA) <: Real ΔA .+= real.(ΔAc) end + conjΔC = conjB conjA′ = conjB ? conjA : !conjA ΔBc = eltype(ΔC) <: Complex && eltype(ΔB) <: Real ? zerovector(B, VectorInterface.promote_add(ΔC, α)) : ΔB @@ -32,25 +30,27 @@ function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, pA::Index2Tuple, conjA::B if eltype(ΔC) <: Complex && eltype(ΔB) <: Real ΔB .+= real.(ΔBc) end + Δα = if _needs_tangent(α) C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) # TODO: consider using `inner` tensorscalar( tensorcontract( - C_αβ, ((), trivtuple(numind(pAB))), true, - ΔC, (trivtuple(numind(pAB)), ()), false, + C_αβ, trivialtuple(0, numind(pAB)), true, + ΔC, trivialtuple(numind(pAB, 0), false, ((), ()), One(), ba... ) ) else nothing end + Δβ = if _needs_tangent(β) # TODO: consider using `inner` tensorscalar( tensorcontract( - C, ((), trivtuple(numind(pAB))), true, - ΔC, (trivtuple(numind(pAB)), ()), false, + C, trivialtuple(0, numind(pAB)), true, + ΔC, trivialtuple(numind(pAB), 0), false, ((), ()), One(), ba... ) ) diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl index 09b82d3..bd0e2d9 100644 --- a/src/pullbacks/trace.jl +++ b/src/pullbacks/trace.jl @@ -1,5 +1,5 @@ function tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α, β, ba...) 
- ip = invperm((linearize(p)..., q[1]..., q[2]...)) + ip = repartition(invperm((linearize(p)..., linearize(q)...)), numind(p) + numind(q)) Es = map(q[1], q[2]) do i1, i2 one( TensorOperations.tensoralloc_add( @@ -10,8 +10,8 @@ function tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, c E = _kron(Es, ba) ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA tensorproduct!( - ΔAc, ΔC, (trivtuple(numind(p)), ()), conjA, - E, ((), trivtuple(numind(q))), conjA, + ΔAc, ΔC, trivialpermutation(numind(p), 0), conjA, + E, trivialpermutation(0, numind(q)), conjA, (ip, ()), conjA ? α : conj(α), One(), ba... ) @@ -22,9 +22,9 @@ function tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, c Δα = if _needs_tangent(α) tensorscalar( tensorcontract( - C_αβ, ((), trivtuple(numind(p))), + C_αβ, trivialpermutation(0, numind(p)), !conjA, - ΔC, (trivtuple(numind(p)), ()), false, + ΔC, trivialpermutation(numind(p), 0), false, ((), ()), One(), ba... ) ) @@ -34,8 +34,8 @@ function tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, c Δβ = if _needs_tangent(β) tensorscalar( tensorcontract( - C, ((), trivtuple(numind(p))), true, - ΔC, (trivtuple(numind(p)), ()), false, + C, trivialpermtation(0, numind(p)), true, + ΔC, trivialpermutation(numind(p), 0), false, ((), ()), One(), ba... ) ) diff --git a/src/utils.jl b/src/utils.jl index 68b98c8..2010e8b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,8 +2,8 @@ _kron(Es::NTuple{1}, ba) = Es[1] function _kron(Es::NTuple{N, Any}, ba) where {N} E1 = Es[1] E2 = _kron(Base.tail(Es), ba) - p2 = ((), trivtuple(2 * N - 2)) - p = ((1, (2 .+ trivtuple(N - 1))...), (2, ((N + 1) .+ trivtuple(N - 1))...)) + p2 = trivialpermutation(0, 2N - 2) + p = ((1, (2 .+ trivialpermutation(N - 1))...), (2, ((N + 1) .+ trivialpermutation(N - 1))...)) return tensorproduct(p, E1, ((1, 2), ()), false, E2, p2, false, One(), ba...) 
end @@ -13,5 +13,3 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{<:Number}) = true _needs_tangent(::Type{<:Integer}) = false _needs_tangent(::Type{<:Union{One, Zero}}) = false - -trivtuple(N) = ntuple(identity, N) From 6f0aa858a58ae5677c3e85976b934cc15a6d4aea Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 20 Jan 2026 11:38:11 -0500 Subject: [PATCH 07/11] ChainRules using common pullback functions --- ext/TensorOperationsChainRulesCoreExt.jl | 180 ++++------------------- src/TensorOperations.jl | 1 + src/pullbacks/add.jl | 38 ++--- src/pullbacks/common.jl | 11 ++ src/pullbacks/contract.jl | 124 ++++++++++------ src/pullbacks/trace.jl | 78 +++++++--- src/utils.jl | 7 - 7 files changed, 200 insertions(+), 239 deletions(-) create mode 100644 src/pullbacks/common.jl diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 8338468..4ea1123 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -1,7 +1,12 @@ module TensorOperationsChainRulesCoreExt using TensorOperations -using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent, trivtuple +using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent +using TensorOperations: pullback_dC, pullback_dβ, + tensoradd_pullback_dA, tensoradd_pullback_dα, + tensorcontract_pullback_dA, tensorcontract_pullback_dB, tensorcontract_pullback_dα, + tensortrace_pullback_dA, tensortrace_pullback_dα + using TensorOperations: DefaultBackend, DefaultAllocator, _kron using ChainRulesCore using TupleTools @@ -72,45 +77,18 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) projectα = ProjectTo(α) projectβ = ProjectTo(β) - function pullback(ΔC′) + function tensoradd_pullback(ΔC′) ΔC = unthunk(ΔC′) - dC = if β === Zero() - ZeroTangent() - else - @thunk projectC(scale(ΔC, conj(β))) - end - dA = @thunk let - ipA = invperm(linearize(pA)) - _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) - _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...) - projectA(_dA) - end + + dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β)) + dA = @thunk projectA(tensoradd_pullback_dA(ΔC, C, A, pA, conjA, α, ba...)) dα = if _needs_tangent(α) - @thunk let - _dα = tensorscalar( - tensorcontract( - A, ((), linearize(pA)), !conjA, - ΔC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - projectα(_dα) - end + @thunk projectα(tensoradd_pullback_dα(ΔC, C, A, pA, conjA, α, ba...)) else ZeroTangent() end dβ = if _needs_tangent(β) - @thunk let - # TODO: consider using `inner` - _dβ = tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pA))), true, - ΔC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - projectβ(_dβ) - end + @thunk projectβ(pullback_dβ(ΔC, C, β)) else ZeroTangent() end @@ -118,7 +96,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba... 
end - return C′, pullback + return C′, tensoradd_pullback end function ChainRulesCore.rrule( @@ -141,84 +119,31 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) projectα = ProjectTo(α) projectβ = ProjectTo(β) - function pullback(ΔC′) + function tensorcontract_pullback(ΔC′) ΔC = unthunk(ΔC′) - ipAB = invperm(linearize(pAB)) - pΔC = ( - TupleTools.getindices(ipAB, trivtuple(numout(pA))), - TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))), - ) - dC = if β === Zero() - ZeroTangent() - else - @thunk projectC(scale(ΔC, conj(β))) - end - dA = @thunk let - ipA = (invperm(linearize(pA)), ()) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB - _dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α))) - _dA = tensorcontract!( - _dA, - ΔC, pΔC, conjΔC, - B, reverse(pB), conjB′, - ipA, - conjA ? α : conj(α), Zero(), ba... - ) - projectA(_dA) - end - dB = @thunk let - ipB = (invperm(linearize(pB)), ()) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA - _dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α))) - _dB = tensorcontract!( - _dB, - A, reverse(pA), conjA′, - ΔC, pΔC, conjΔC, - ipB, - conjB ? α : conj(α), Zero(), ba... - ) - projectB(_dB) - end + + dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β)) + dA = @thunk projectA(tensorcontract_pullback_dA(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...)) + dB = @thunk projectB(tensorcontract_pullback_dB(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...)) dα = if _needs_tangent(α) - @thunk let - C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - # TODO: consider using `inner` - _dα = tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(pAB))), true, - ΔC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - projectα(_dα) - end + @thunk projectα(tensorcontract_pullback_dα(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...)) else ZeroTangent() end dβ = if _needs_tangent(β) - @thunk let - # TODO: consider using `inner` - _dβ = tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pAB))), true, - ΔC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - projectβ(_dβ) - end + @thunk projectβ(pullback_dβ(ΔC, C, β)) else ZeroTangent() end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, - dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), - NoTangent(), dα, dβ, dba... + dA, NoTangent(), NoTangent(), + dB, NoTangent(), NoTangent(), + NoTangent(), + dα, dβ, dba... end - return C′, pullback + return C′, tensorcontract_pullback end function ChainRulesCore.rrule( @@ -237,59 +162,18 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) projectα = ProjectTo(α) projectβ = ProjectTo(β) - function pullback(ΔC′) + function tensortrace_pullback(ΔC′) ΔC = unthunk(ΔC′) - dC = if β === Zero() - ZeroTangent() - else - @thunk projectC(scale(ΔC, conj(β))) - end - dA = @thunk let - ip = invperm((linearize(p)..., q[1]..., q[2]...)) - Es = map(q[1], q[2]) do i1, i2 - one( - TensorOperations.tensoralloc_add( - scalartype(A), A, ((i1,), (i2,)), conjA - ) - ) - end - E = _kron(Es, ba) - _dA = zerovector(A, VectorInterface.promote_scale(ΔC, α)) - _dA = tensorproduct!( - _dA, ΔC, (trivtuple(numind(p)), ()), conjA, - E, ((), trivtuple(numind(q))), conjA, - (ip, ()), - conjA ? α : conj(α), Zero(), ba... - ) - projectA(_dA) - end + + dC = β === Zero() ? 
ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β)) + dA = @thunk projectA(tensortrace_pullback_dA(ΔC, C, A, p, q, conjA, α, ba...)) dα = if _needs_tangent(α) - @thunk let - C_αβ = tensortrace(A, p, q, false, One(), ba...) - _dα = tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(p))), - !conjA, - ΔC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - projectα(_dα) - end + @thunk projectα(tensortrace_pullback_dα(ΔC, C, A, p, q, conjA, α, ba...)) else ZeroTangent() end dβ = if _needs_tangent(β) - @thunk let - _dβ = tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(p))), true, - ΔC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - projectβ(_dβ) - end + @thunk projectβ(pullback_dβ(ΔC, C, β)) else ZeroTangent() end @@ -297,7 +181,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba... end - return C′, pullback + return C′, tensortrace_pullback end # NCON functions diff --git a/src/TensorOperations.jl b/src/TensorOperations.jl index 05ae808..9bc4dbd 100644 --- a/src/TensorOperations.jl +++ b/src/TensorOperations.jl @@ -38,6 +38,7 @@ include("utils.jl") # Generic pullbacks for AD #--------------------------- +include("pullbacks/common.jl") include("pullbacks/add.jl") include("pullbacks/trace.jl") include("pullbacks/contract.jl") diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl index 7cf10c0..63d366c 100644 --- a/src/pullbacks/add.jl +++ b/src/pullbacks/add.jl @@ -1,25 +1,31 @@ function tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, β, ba...) - ipA = invperm(linearize(pA)) - ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA - tensoradd!(ΔAc, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) + dA = tensoradd_pullback_dA!(ΔA, ΔC, C, A, pA, conjA, α, ba...) + dα = tensoradd_pullback_dα(ΔC, C, A, pA, conjA, α, ba...) + dβ = pullback_dβ(ΔC, C, β) + dC = pullback_dC!(ΔC, β) + return dC, dA, dα, dβ +end + +function tensoradd_pullback_dA(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + ipA = inversepermutation(pA, A) + return tensorcopy(ΔC, ipA, conjA, conjA ? α : conj(α), ba...) +end +function tensoradd_pullback_dA!(ΔA, ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔAc = tensoradd_pullback_dA(ΔC, C, A, pA, conjA, α, ba...) ΔA .+= real.(ΔAc) - end - Δα = if _needs_tangent(α) - tensorscalar( - tensorcontract( - A, ((), linearize(pA)), !conjA, - ΔC, trivialpermutation(numind(pA), 0), false, - ((), ()), One(), ba... - ) - ) else - nothing + ipA = inversepermutation(pA, ΔA) + tensoradd!(ΔA, ΔC, ipA, conjA, conjA ? α : conj(α), One(), ba...) end - Δβ = if _needs_tangent(β) + return ΔA +end + +function tensoradd_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + return if _needs_tangent(α) tensorscalar( tensorcontract( - C, trivialpermutation(0, numind(pA)), true, + A, repartition(pA, 0), !conjA, ΔC, trivialpermutation(numind(pA), 0), false, ((), ()), One(), ba... 
) @@ -27,6 +33,4 @@ function tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, else nothing end - scale!(ΔC, conj(β)) - return ΔC, ΔA, Δα, Δβ end diff --git a/src/pullbacks/common.jl b/src/pullbacks/common.jl new file mode 100644 index 0000000..9df8b44 --- /dev/null +++ b/src/pullbacks/common.jl @@ -0,0 +1,11 @@ +# To avoid computing rrules for α and β when these aren't needed, we want to have a +# type-stable quick bail-out +_needs_tangent(x) = _needs_tangent(typeof(x)) +_needs_tangent(::Type{<:Number}) = true +_needs_tangent(::Type{<:Integer}) = false +_needs_tangent(::Type{<:Union{One, Zero}}) = false + +# (partial) pullbacks that are shared +pullback_dC!(ΔC, β) = scale!(ΔC, conj(β)) +pullback_dC(ΔC, β) = scale(ΔC, conj(β)) +pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? inner(C, ΔC) : nothing diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl index ff3a8a1..257652e 100644 --- a/src/pullbacks/contract.jl +++ b/src/pullbacks/contract.jl @@ -1,62 +1,96 @@ -function tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α, β, ba...) +function tensorcontract_pullback!( + ΔC, ΔA, ΔB, + C, + A, pA::Index2Tuple, conjA::Bool, + B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, + α::Number, β::Number, + ba... + ) + dA = tensorcontract_pullback_dA!(ΔA, ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) + dB = tensorcontract_pullback_dB!(ΔB, ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) + dα = tensorcontract_pullback_dα(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) + dβ = pullback_dβ(ΔC, C, β) + dC = pullback_dC!(ΔC, β) + return dC, dA, dB, dα, dβ +end + +function tensorcontract_pullback_dA( + ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... + ) pdC = inversepermutation(pAB, numout(pA)) ipA = inversepermutation(pA, A) - ipB = inversepermutation(pB, B) - - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB - ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA - tensorcontract!( - ΔAc, - ΔC, pdC, conjΔC, - B, reverse(pB), conjB′, - ipA, - conjA ? α : conj(α), One(), ba... + return tensorcontract( + ΔC, pdC, conjA, B, reverse(pB), conjA ? conjB : !conjB, + ipA, conjA ? α : conj(α), ba... + ) +end +function tensorcontract_pullback_dA!( + ΔA, ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... ) + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔAc = tensorcontract_pullback_dA(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) ΔA .+= real.(ΔAc) + else + pdC = inversepermutation(pAB, numout(pA)) + ipA = inversepermutation(pA, A) + tensorcontract!( + ΔA, + ΔC, pdC, conjA, B, reverse(pB), conjA ? conjB : !conjB, + ipA, conjA ? α : conj(α), One(), ba... + ) end - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA - ΔBc = eltype(ΔC) <: Complex && eltype(ΔB) <: Real ? zerovector(B, VectorInterface.promote_add(ΔC, α)) : ΔB - tensorcontract!( - ΔBc, - A, reverse(pA), conjA′, - ΔC, pdC, conjΔC, - ipB, - conjB ? α : conj(α), One(), ba... + return ΔA +end + +function tensorcontract_pullback_dB( + ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... + ) + pdC = inversepermutation(pAB, numout(pA)) + ipB = inversepermutation(pB, B) + return tensorcontract( + A, reverse(pA), conjB ? conjA : !conjA, ΔC, pdC, conjB, + ipB, conjB ? α : conj(α), ba... 
+ ) +end +function tensorcontract_pullback_dB!( + ΔB, ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... ) if eltype(ΔC) <: Complex && eltype(ΔB) <: Real + ΔBc = tensorcontract_pullback_dB(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) ΔB .+= real.(ΔBc) - end - - Δα = if _needs_tangent(α) - C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - # TODO: consider using `inner` - tensorscalar( - tensorcontract( - C_αβ, trivialtuple(0, numind(pAB)), true, - ΔC, trivialtuple(numind(pAB, 0), false, - ((), ()), One(), ba... - ) - ) else - nothing + pdC = inversepermutation(pAB, numout(pA)) + ipB = inversepermutation(pB, B) + tensorcontract!( + ΔB, + A, reverse(pA), conjB ? conjA : !conjA, ΔC, pdC, conjB, + ipB, conjB ? α : conj(α), One(), ba... + ) end - Δβ = if _needs_tangent(β) - # TODO: consider using `inner` - tensorscalar( - tensorcontract( - C, trivialtuple(0, numind(pAB)), true, - ΔC, trivialtuple(numind(pAB), 0), false, - ((), ()), One(), ba... - ) - ) + return ΔB +end + +function tensorcontract_pullback_dα( + ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... + ) + return if _needs_tangent(α) + C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + inner(C_αβ, ΔC) else nothing end - scale!(ΔC, conj(β)) - return ΔC, ΔA, ΔB, Δα, Δβ end diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl index bd0e2d9..067e02a 100644 --- a/src/pullbacks/trace.jl +++ b/src/pullbacks/trace.jl @@ -1,5 +1,21 @@ -function tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α, β, ba...) - ip = repartition(invperm((linearize(p)..., linearize(q)...)), numind(p) + numind(q)) +function tensortrace_pullback!( + ΔC, ΔA, + C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, + α::Number, β::Number, ba... + ) + dA = tensortrace_pullback_dA!(ΔA, ΔC, C, A, p, q, conjA, α, ba...) + dα = tensortrace_pullback_dα(ΔC, C, A, p, q, conjA, α, ba...) + dβ = pullback_dβ(ΔC, C, β) + dC = pullback_dC!(ΔC, β) + + return dC, dA, dα, dβ +end + +function tensortrace_pullback_dA( + ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, + α::Number, ba... + ) + ip = repartition(invperm((linearize(p)..., linearize(q)...)), A) Es = map(q[1], q[2]) do i1, i2 one( TensorOperations.tensoralloc_add( @@ -8,33 +24,53 @@ function tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, c ) end E = _kron(Es, ba) - ΔAc = eltype(ΔC) <: Complex && eltype(ΔA) <: Real ? zerovector(A, VectorInterface.promote_add(ΔC, α)) : ΔA - tensorproduct!( - ΔAc, ΔC, trivialpermutation(numind(p), 0), conjA, + + return tensorproduct( + ΔC, trivialpermutation(numind(p), 0), conjA, E, trivialpermutation(0, numind(q)), conjA, - (ip, ()), - conjA ? α : conj(α), One(), ba... + ip, + conjA ? α : conj(α), ba... + ) +end +function tensortrace_pullback_dA!( + ΔA, ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, + α::Number, ba... ) + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔAc = tensortrace_pullback_dA(ΔC, C, A, p, q, conjA, α, ba...) ΔA .+= real.(ΔAc) - end - C_αβ = tensortrace(A, p, q, false, One(), ba...) - Δα = if _needs_tangent(α) - tensorscalar( - tensorcontract( - C_αβ, trivialpermutation(0, numind(p)), - !conjA, - ΔC, trivialpermutation(numind(p), 0), false, - ((), ()), One(), ba... 
+ else + ip = repartition(invperm((linearize(p)..., linearize(q)...)), A) + Es = map(q[1], q[2]) do i1, i2 + one( + TensorOperations.tensoralloc_add( + TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA + ) ) + end + E = _kron(Es, ba) + tensorproduct!( + ΔA, ΔC, trivialpermutation(numind(p), 0), conjA, + E, trivialpermutation(0, numind(q)), conjA, + ip, + conjA ? α : conj(α), One(), ba... ) - else - nothing end - Δβ = if _needs_tangent(β) + + return ΔA +end + +function tensortrace_pullback_dα( + ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, + α::Number, ba... + ) + return if _needs_tangent(α) + C_αβ = tensortrace(A, p, q, false, One(), ba...) tensorscalar( tensorcontract( - C, trivialpermtation(0, numind(p)), true, + C_αβ, trivialpermutation(0, numind(p)), + !conjA, ΔC, trivialpermutation(numind(p), 0), false, ((), ()), One(), ba... ) @@ -42,6 +78,4 @@ function tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, c else nothing end - scale!(ΔC, conj(β)) - return ΔC, ΔA, Δα, Δβ end diff --git a/src/utils.jl b/src/utils.jl index 2010e8b..0afef28 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,10 +6,3 @@ function _kron(Es::NTuple{N, Any}, ba) where {N} p = ((1, (2 .+ trivialpermutation(N - 1))...), (2, ((N + 1) .+ trivialpermutation(N - 1))...)) return tensorproduct(p, E1, ((1, 2), ()), false, E2, p2, false, One(), ba...) end - -# To avoid computing rrules for α and β when these aren't needed, we want to have a -# type-stable quick bail-out -_needs_tangent(x) = _needs_tangent(typeof(x)) -_needs_tangent(::Type{<:Number}) = true -_needs_tangent(::Type{<:Integer}) = false -_needs_tangent(::Type{<:Union{One, Zero}}) = false From c538e5a638a0d9bcea4a04010c8abe19cee25961 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 20 Jan 2026 15:09:03 -0500 Subject: [PATCH 08/11] only check public functions are included in docs --- docs/make.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index 77d50fb..bd6c62d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -18,7 +18,8 @@ makedocs(; "man/precompilation.md", ], "Index" => "index/index.md", - ] + ], + checkdocs = :public ) # Documenter can also automatically deploy documentation to gh-pages. From 99fed3568915b62571b7856383fc0e0eed9a5c0a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 11:01:00 -0500 Subject: [PATCH 09/11] update docstrings --- src/indices.jl | 6 +++--- src/pullbacks/add.jl | 23 +++++++++++++++++++++++ src/pullbacks/common.jl | 22 ++++++++++++++++++++++ src/pullbacks/contract.jl | 33 +++++++++++++++++++++++++++++++++ src/pullbacks/trace.jl | 21 +++++++++++++++++++++ 5 files changed, 102 insertions(+), 3 deletions(-) diff --git a/src/indices.jl b/src/indices.jl index 61eeddc..5e5bfeb 100644 --- a/src/indices.jl +++ b/src/indices.jl @@ -49,9 +49,9 @@ end inversepermutation(p::Index2Tuple, partition_as::Index2Tuple) -> ip::Index2Tuple inversepermutation(p::Index2Tuple, partition_as::AbstractArray) -> ip::Index2Tuple -Compute the inverse permutation associated to `p`. -If no extra arguments are provided, the result is returned as a single `IndexTuple`. -Otherwise, the extra arguments are used to partition the inverse permutation into an `Index2Tuple`. +Compute the inverse permutation associated with `p`. +If no extra arguments are provided, the result is returned as a single [`IndexTuple`](@ref). +Otherwise, the extra arguments are used to partition the inverse permutation into an [`Index2Tuple`](@ref). 
""" inversepermutation(p::Index2Tuple) = invperm(linearize(p)) function inversepermutation(p::Index2Tuple, args...) diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl index 63d366c..e0b2f08 100644 --- a/src/pullbacks/add.jl +++ b/src/pullbacks/add.jl @@ -1,3 +1,11 @@ +""" + tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, β, ba...) -> ΔC, ΔA, Δα, Δβ + +Compute pullbacks for [`tensoradd!`](@ref), updating cotangent arrays and returning cotangent scalars. + +See also [`pullback_dC`](@ref), [`tensoradd_pullback_dA`](@ref), [`tensoradd_pullback_dα`](@ref) and [`pullback_dβ`](@ref) +for computing pullbacks for the individual components. +""" function tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, β, ba...) dA = tensoradd_pullback_dA!(ΔA, ΔC, C, A, pA, conjA, α, ba...) dα = tensoradd_pullback_dα(ΔC, C, A, pA, conjA, α, ba...) @@ -6,6 +14,16 @@ function tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, return dC, dA, dα, dβ end +@doc """ + tensoradd_pullback_dA(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + tensoradd_pullback_dA!(ΔA, ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + +Compute the pullback for [`tensoradd!`](@ref) with respect to the input `A`. +The mutating version can be used to accumulate the result into `ΔA`. + +See also [`tensoradd_pullback_dA!`](@ref) for computing and updating the gradient in-place. +""" tensoradd_pullback_dA, tensoradd_pullback_dA! + function tensoradd_pullback_dA(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) ipA = inversepermutation(pA, A) return tensorcopy(ΔC, ipA, conjA, conjA ? α : conj(α), ba...) @@ -21,6 +39,11 @@ function tensoradd_pullback_dA!(ΔA, ΔC, C, A, pA::Index2Tuple, conjA::Bool, α return ΔA end +""" + tensoradd_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + +Compute the pullback for [`tensoradd!]`(ref) with respect to scaling coefficient `α`. +""" function tensoradd_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) return if _needs_tangent(α) tensorscalar( diff --git a/src/pullbacks/common.jl b/src/pullbacks/common.jl index 9df8b44..add21ce 100644 --- a/src/pullbacks/common.jl +++ b/src/pullbacks/common.jl @@ -1,11 +1,33 @@ # To avoid computing rrules for α and β when these aren't needed, we want to have a # type-stable quick bail-out +""" + _needs_tangent(x) + _needs_tangent(::Type{T}) + +Determine whether a value requires tangent computation during automatic differentiation. +Returns `false` for constants like Integer, One, and Zero types to avoid unnecessary computation +in automatic differentiation. Returns `true` only for general Numbers. +""" _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{<:Number}) = true _needs_tangent(::Type{<:Integer}) = false _needs_tangent(::Type{<:Union{One, Zero}}) = false # (partial) pullbacks that are shared +@doc """ + pullback_dC(ΔC, β) + pullback_dC!(ΔC, β) + +For functions of the form `f!(C, β, ...) = βC + ...`, compute the pullback with respect to `C`. +""" pullback_dC, pullback_dC! + pullback_dC!(ΔC, β) = scale!(ΔC, conj(β)) pullback_dC(ΔC, β) = scale(ΔC, conj(β)) + +@doc """ + pullback_dβ(ΔC, C, β) + +For functions of the form `f!(C, β, ...) = βC + ...`, compute the pullback with respect to `β`. +""" pullback_dβ + pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? 
inner(C, ΔC) : nothing diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl index 257652e..ab965c0 100644 --- a/src/pullbacks/contract.jl +++ b/src/pullbacks/contract.jl @@ -1,3 +1,11 @@ +""" + tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, β::Number, ba...) -> ΔC, ΔA, ΔB, Δα, Δβ + +Compute pullbacks for [`tensorcontract!`](@ref), updating cotangent arrays and returning cotangent scalars. + +See also [`pullback_dC`](@ref), [`tensorcontract_pullback_dA`](@ref), [`tensorcontract_pullback_dB`](@ref), +[`tensorcontract_pullback_dα`](@ref) and [`pullback_dβ`](@ref) for computing pullbacks for the individual components. +""" function tensorcontract_pullback!( ΔC, ΔA, ΔB, C, @@ -15,6 +23,16 @@ function tensorcontract_pullback!( return dC, dA, dB, dα, dβ end +@doc """ + tensorcontract_pullback_dA(ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + tensorcontract_pullback_dA!(ΔA, ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + +Compute the pullback for [`tensorcontract!`](@ref) with respect to the input `A`. +The mutating version can be used to accumulate the result into `ΔA`. + +See also [`tensorcontract_pullback_dB`](@ref) and [`tensorcontract_pullback_dB!`](@ref) for the pullback with respect to `B`. +""" tensorcontract_pullback_dA, tensorcontract_pullback_dA! + function tensorcontract_pullback_dA( ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, @@ -49,6 +67,16 @@ function tensorcontract_pullback_dA!( return ΔA end +@doc """ + tensorcontract_pullback_dB(ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + tensorcontract_pullback_dB!(ΔB, ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + +Compute the pullback for [`tensorcontract!`](@ref) with respect to the input `B`. +The mutating version can be used to accumulate the result into `ΔB`. + +See also [`tensorcontract_pullback_dA`](@ref) and [`tensorcontract_pullback_dA!`](@ref) for the pullback with respect to `A`. +""" tensorcontract_pullback_dB, tensorcontract_pullback_dB! + function tensorcontract_pullback_dB( ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, @@ -82,6 +110,11 @@ function tensorcontract_pullback_dB!( return ΔB end +""" + tensorcontract_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + +Compute the pullback for [`tensorcontract!`](@ref) with respect to scaling coefficient `α`. +""" function tensorcontract_pullback_dα( ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl index 067e02a..f9673a4 100644 --- a/src/pullbacks/trace.jl +++ b/src/pullbacks/trace.jl @@ -1,3 +1,11 @@ +""" + tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, β::Number, ba...) -> ΔC, ΔA, Δα, Δβ + +Compute pullbacks for [`tensortrace!`](@ref), updating cotangent arrays and returning cotangent scalars. + +See also [`pullback_dC`](@ref), [`tensortrace_pullback_dA`](@ref), [`tensortrace_pullback_dα`](@ref) and [`pullback_dβ`](@ref) +for computing pullbacks for the individual components. 
+""" function tensortrace_pullback!( ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, @@ -11,6 +19,14 @@ function tensortrace_pullback!( return dC, dA, dα, dβ end +@doc """ + tensortrace_pullback_dA(ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, ba...) + tensortrace_pullback_dA!(ΔA, ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, ba...) + +Compute the pullback for [`tensortrace!`](@ref) with respect to the input `A`. +The mutating version can be used to accumulate the result into `ΔA`. +""" tensortrace_pullback_dA, tensortrace_pullback_dA! + function tensortrace_pullback_dA( ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, ba... @@ -61,6 +77,11 @@ function tensortrace_pullback_dA!( return ΔA end +""" + tensortrace_pullback_dα(ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, ba...) + +Compute the pullback for [`tensortrace!`](@ref) with respect to scaling coefficient `α`. +""" function tensortrace_pullback_dα( ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, ba... From dc8fb444948c21c7f4211eac9c8f52457566a9e1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 11:01:09 -0500 Subject: [PATCH 10/11] cleaner syntax --- src/pullbacks/add.jl | 17 +++++++---------- src/pullbacks/contract.jl | 9 +++------ src/pullbacks/trace.jl | 21 +++++++++------------ 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl index e0b2f08..a71cbfa 100644 --- a/src/pullbacks/add.jl +++ b/src/pullbacks/add.jl @@ -45,15 +45,12 @@ end Compute the pullback for [`tensoradd!]`(ref) with respect to scaling coefficient `α`. """ function tensoradd_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) - return if _needs_tangent(α) - tensorscalar( - tensorcontract( - A, repartition(pA, 0), !conjA, - ΔC, trivialpermutation(numind(pA), 0), false, - ((), ()), One(), ba... - ) + _needs_tangent(α) || return nothing + return tensorscalar( + tensorcontract( + A, repartition(pA, 0), !conjA, + ΔC, trivialpermutation(numind(pA), 0), false, + ((), ()), One(), ba... ) - else - nothing - end + ) end diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl index ab965c0..f8f4659 100644 --- a/src/pullbacks/contract.jl +++ b/src/pullbacks/contract.jl @@ -120,10 +120,7 @@ function tensorcontract_pullback_dα( A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba... ) - return if _needs_tangent(α) - C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - inner(C_αβ, ΔC) - else - nothing - end + _needs_tangent(α) || return nothing + C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + return inner(C_αβ, ΔC) end diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl index f9673a4..f3135a2 100644 --- a/src/pullbacks/trace.jl +++ b/src/pullbacks/trace.jl @@ -86,17 +86,14 @@ function tensortrace_pullback_dα( ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, ba... ) - return if _needs_tangent(α) - C_αβ = tensortrace(A, p, q, false, One(), ba...) - tensorscalar( - tensorcontract( - C_αβ, trivialpermutation(0, numind(p)), - !conjA, - ΔC, trivialpermutation(numind(p), 0), false, - ((), ()), One(), ba... - ) + _needs_tangent(α) || return nothing + C_αβ = tensortrace(A, p, q, false, One(), ba...) + return tensorscalar( + tensorcontract( + C_αβ, trivialpermutation(0, numind(p)), + !conjA, + ΔC, trivialpermutation(numind(p), 0), false, + ((), ()), One(), ba... 
) - else - nothing - end + ) end From 914b47b302bedb522191b42dca4138c9d269ed1c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 11:01:19 -0500 Subject: [PATCH 11/11] support complex _needs_tangent --- src/pullbacks/common.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pullbacks/common.jl b/src/pullbacks/common.jl index add21ce..c561c43 100644 --- a/src/pullbacks/common.jl +++ b/src/pullbacks/common.jl @@ -12,6 +12,7 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{<:Number}) = true _needs_tangent(::Type{<:Integer}) = false _needs_tangent(::Type{<:Union{One, Zero}}) = false +_needs_tangent(::Type{Complex{T}}) where {T} = _needs_tangent(T) # (partial) pullbacks that are shared @doc """
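
A minimal usage sketch (not part of the patch series) of how the shared pullback kernels introduced above compose. It assumes the internal, unexported names from the diffs (`tensoradd_pullback!`, `inversepermutation`) and plain `Float64` coefficients so that the scalar cotangents are actually computed; the arrays and values are illustrative only.

# Sketch: reverse-mode rule for a permuted copy, C = α * permutedims(A) + β * C,
# built from the generic pullback kernel added in src/pullbacks/add.jl.
using TensorOperations

A  = randn(3, 4)
pA = ((2, 1), ())                                        # Index2Tuple: C takes A's indices in reverse order
C  = zeros(4, 3)
TensorOperations.tensoradd!(C, A, pA, false, 1.0, 0.0)   # forward pass: C = permutedims(A)

ΔC  = randn(4, 3)                                        # incoming cotangent for C
ΔC0 = copy(ΔC)                                           # the pullback rescales ΔC in place by conj(β)
ΔA  = zeros(3, 4)                                        # cotangent accumulator for A
_, dA, dα, dβ = TensorOperations.tensoradd_pullback!(ΔC, ΔA, C, A, pA, false, 1.0, 0.0)

# The adjoint of an index permutation is the inverse permutation:
dA ≈ permutedims(ΔC0, TensorOperations.inversepermutation(pA))   # expected to hold
# dα and dβ are scalar cotangents here; with α = One() and β = Zero(),
# `_needs_tangent` short-circuits and both are returned as `nothing`.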