diff --git a/Project.toml b/Project.toml index c05ec0c..deccc00 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ Preferences = "1.4" PtrArrays = "1.2" Random = "1" Strided = "2.2" -StridedViews = "0.3, 0.4" +StridedViews = "0.3, 0.4, ~0.4.2" Test = "1" TupleTools = "1.6" VectorInterface = "0.4.1,0.5" diff --git a/docs/make.jl b/docs/make.jl index 77d50fb..bd6c62d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -18,7 +18,8 @@ makedocs(; "man/precompilation.md", ], "Index" => "index/index.md", - ] + ], + checkdocs = :public ) # Documenter can also automatically deploy documentation to gh-pages. diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 539d9d9..4ea1123 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -2,6 +2,11 @@ module TensorOperationsChainRulesCoreExt using TensorOperations using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent +using TensorOperations: pullback_dC, pullback_dβ, + tensoradd_pullback_dA, tensoradd_pullback_dα, + tensorcontract_pullback_dA, tensorcontract_pullback_dB, tensorcontract_pullback_dα, + tensortrace_pullback_dA, tensortrace_pullback_dα + using TensorOperations: DefaultBackend, DefaultAllocator, _kron using ChainRulesCore using TupleTools @@ -9,8 +14,6 @@ using VectorInterface using TupleTools: invperm using LinearAlgebra -trivtuple(N) = ntuple(identity, N) - @non_differentiable TensorOperations.tensorstructure(args...) @non_differentiable TensorOperations.tensoradd_structure(args...) @non_differentiable TensorOperations.tensoradd_type(args...) @@ -74,45 +77,18 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) projectα = ProjectTo(α) projectβ = ProjectTo(β) - function pullback(ΔC′) + function tensoradd_pullback(ΔC′) ΔC = unthunk(ΔC′) - dC = if β === Zero() - ZeroTangent() - else - @thunk projectC(scale(ΔC, conj(β))) - end - dA = @thunk let - ipA = invperm(linearize(pA)) - _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) - _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...) - projectA(_dA) - end + + dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β)) + dA = @thunk projectA(tensoradd_pullback_dA(ΔC, C, A, pA, conjA, α, ba...)) dα = if _needs_tangent(α) - @thunk let - _dα = tensorscalar( - tensorcontract( - A, ((), linearize(pA)), !conjA, - ΔC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - projectα(_dα) - end + @thunk projectα(tensoradd_pullback_dα(ΔC, C, A, pA, conjA, α, ba...)) else ZeroTangent() end dβ = if _needs_tangent(β) - @thunk let - # TODO: consider using `inner` - _dβ = tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pA))), true, - ΔC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - projectβ(_dβ) - end + @thunk projectβ(pullback_dβ(ΔC, C, β)) else ZeroTangent() end @@ -120,7 +96,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba... end - return C′, pullback + return C′, tensoradd_pullback end function ChainRulesCore.rrule( @@ -143,84 +119,31 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) projectα = ProjectTo(α) projectβ = ProjectTo(β) - function pullback(ΔC′) + function tensorcontract_pullback(ΔC′) ΔC = unthunk(ΔC′) - ipAB = invperm(linearize(pAB)) - pΔC = ( - TupleTools.getindices(ipAB, trivtuple(numout(pA))), - TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))), - ) - dC = if β === Zero() - ZeroTangent() - else - @thunk projectC(scale(ΔC, conj(β))) - end - dA = @thunk let - ipA = (invperm(linearize(pA)), ()) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB - _dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α))) - _dA = tensorcontract!( - _dA, - ΔC, pΔC, conjΔC, - B, reverse(pB), conjB′, - ipA, - conjA ? α : conj(α), Zero(), ba... - ) - projectA(_dA) - end - dB = @thunk let - ipB = (invperm(linearize(pB)), ()) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA - _dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α))) - _dB = tensorcontract!( - _dB, - A, reverse(pA), conjA′, - ΔC, pΔC, conjΔC, - ipB, - conjB ? α : conj(α), Zero(), ba... - ) - projectB(_dB) - end + + dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β)) + dA = @thunk projectA(tensorcontract_pullback_dA(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...)) + dB = @thunk projectB(tensorcontract_pullback_dB(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...)) dα = if _needs_tangent(α) - @thunk let - C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - # TODO: consider using `inner` - _dα = tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(pAB))), true, - ΔC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - projectα(_dα) - end + @thunk projectα(tensorcontract_pullback_dα(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...)) else ZeroTangent() end dβ = if _needs_tangent(β) - @thunk let - # TODO: consider using `inner` - _dβ = tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pAB))), true, - ΔC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - projectβ(_dβ) - end + @thunk projectβ(pullback_dβ(ΔC, C, β)) else ZeroTangent() end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, - dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), - NoTangent(), dα, dβ, dba... + dA, NoTangent(), NoTangent(), + dB, NoTangent(), NoTangent(), + NoTangent(), + dα, dβ, dba... end - return C′, pullback + return C′, tensorcontract_pullback end function ChainRulesCore.rrule( @@ -239,59 +162,18 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) projectα = ProjectTo(α) projectβ = ProjectTo(β) - function pullback(ΔC′) + function tensortrace_pullback(ΔC′) ΔC = unthunk(ΔC′) - dC = if β === Zero() - ZeroTangent() - else - @thunk projectC(scale(ΔC, conj(β))) - end - dA = @thunk let - ip = invperm((linearize(p)..., q[1]..., q[2]...)) - Es = map(q[1], q[2]) do i1, i2 - one( - TensorOperations.tensoralloc_add( - scalartype(A), A, ((i1,), (i2,)), conjA - ) - ) - end - E = _kron(Es, ba) - _dA = zerovector(A, VectorInterface.promote_scale(ΔC, α)) - _dA = tensorproduct!( - _dA, ΔC, (trivtuple(numind(p)), ()), conjA, - E, ((), trivtuple(numind(q))), conjA, - (ip, ()), - conjA ? α : conj(α), Zero(), ba... - ) - projectA(_dA) - end + + dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β)) + dA = @thunk projectA(tensortrace_pullback_dA(ΔC, C, A, p, q, conjA, α, ba...)) dα = if _needs_tangent(α) - @thunk let - C_αβ = tensortrace(A, p, q, false, One(), ba...) - _dα = tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(p))), - !conjA, - ΔC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - projectα(_dα) - end + @thunk projectα(tensortrace_pullback_dα(ΔC, C, A, p, q, conjA, α, ba...)) else ZeroTangent() end dβ = if _needs_tangent(β) - @thunk let - _dβ = tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(p))), true, - ΔC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - projectβ(_dβ) - end + @thunk projectβ(pullback_dβ(ΔC, C, β)) else ZeroTangent() end @@ -299,7 +181,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba... end - return C′, pullback + return C′, tensortrace_pullback end # NCON functions diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 49d1708..8bff2f8 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -6,7 +6,7 @@ using TensorOperations # extension are in fact loaded using Mooncake, Mooncake.CRC using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator -using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!, _kron, numind, _needs_tangent, numin, numout +using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace! using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent using VectorInterface, TupleTools @@ -16,8 +16,6 @@ Mooncake.tangent_type(::Type{DefaultAllocator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{CUDAAllocator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{ManualAllocator}) = Mooncake.NoTangent -trivtuple(N) = ntuple(identity, N) - @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), Any} @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_structure), Any} @zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_type), Any} @@ -61,69 +59,9 @@ function Mooncake.rrule!!( TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) function contract_pb(::NoRData) scale!(C, C_cache, One()) - if Tα == Zero && Tβ == Zero - scale!(dC, zero(TC)) - return ntuple(i -> NoRData(), 11 + length(ba)) - end - ipAB = invperm(linearize(pAB)) - pdC = ( - TupleTools.getindices(ipAB, trivtuple(numout(pA))), - TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))), - ) - ipA = (invperm(linearize(pA)), ()) - ipB = (invperm(linearize(pB)), ()) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB - dA = tensorcontract!( - dA, - dC, pdC, conjΔC, - B, reverse(pB), conjB′, - ipA, - conjA ? α : conj(α), One(), ba... - ) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA - dB = tensorcontract!( - dB, - A, reverse(pA), conjA′, - dC, pdC, conjΔC, - ipB, - conjB ? α : conj(α), One(), ba... - ) - dα = if _needs_tangent(Tα) - C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - # TODO: consider using `inner` - Mooncake._rdata( - tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(pAB))), true, - dC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - dβ = if _needs_tangent(Tβ) - # TODO: consider using `inner` - Mooncake._rdata( - tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pAB))), true, - dC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - if β === Zero() - scale!(dC, β) - else - scale!(dC, conj(β)) - end + dC, dA, dB, Δα, Δβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) + dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) + dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, contract_pb @@ -151,35 +89,9 @@ function Mooncake.rrule!!( TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...) function add_pb(::NoRData) scale!(C, C_cache, One()) - ipA = invperm(linearize(pA)) - dA = tensoradd!(dA, dC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...) - dα = if _needs_tangent(Tα) - tensorscalar( - tensorcontract( - A, ((), linearize(pA)), !conjA, - dC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - else - Mooncake.NoRData() - end - dβ = if _needs_tangent(Tβ) - tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pA))), true, - dC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - else - Mooncake.NoRData() - end - if β === Zero() - scale!(dC, β) - else - scale!(dC, conj(β)) - end + dC, dA, Δα, Δβ = TensorOperations.tensoradd_pullback!(dC, dA, C, A, pA, conjA, α, β, ba...) + dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) + dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, add_pb @@ -209,54 +121,9 @@ function Mooncake.rrule!!( TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...) function trace_pb(::NoRData) scale!(C, C_cache, One()) - ip = invperm((linearize(p)..., q[1]..., q[2]...)) - Es = map(q[1], q[2]) do i1, i2 - one( - TensorOperations.tensoralloc_add( - TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA - ) - ) - end - E = _kron(Es, ba) - dA = tensorproduct!( - dA, dC, (trivtuple(numind(p)), ()), conjA, - E, ((), trivtuple(numind(q))), conjA, - (ip, ()), - conjA ? α : conj(α), One(), ba... - ) - C_αβ = tensortrace(A, p, q, false, One(), ba...) - dα = if _needs_tangent(Tα) - Mooncake._rdata( - tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(p))), - !conjA, - dC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - dβ = if _needs_tangent(Tβ) - Mooncake._rdata( - tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(p))), true, - dC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... - ) - ) - ) - else - NoRData() - end - if β === Zero() - scale!(dC, β) - else - scale!(dC, conj(β)) - end + dC, dA, Δα, Δβ = TensorOperations.tensortrace_pullback!(dC, dA, C, A, p, q, conjA, α, β, ba...) + dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) + dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, trace_pb diff --git a/src/TensorOperations.jl b/src/TensorOperations.jl index 23f205d..9bc4dbd 100644 --- a/src/TensorOperations.jl +++ b/src/TensorOperations.jl @@ -36,6 +36,13 @@ include("backends.jl") include("interface.jl") include("utils.jl") +# Generic pullbacks for AD +#--------------------------- +include("pullbacks/common.jl") +include("pullbacks/add.jl") +include("pullbacks/trace.jl") +include("pullbacks/contract.jl") + # Index notation via macros #--------------------------- @nospecialize diff --git a/src/indices.jl b/src/indices.jl index d844478..5e5bfeb 100644 --- a/src/indices.jl +++ b/src/indices.jl @@ -26,10 +26,39 @@ trivialpermutation(p::IndexTuple{N}) where {N} = ntuple(identity, Val(N)) function trivialpermutation(p::Index2Tuple) return (trivialpermutation(p[1]), numout(p) .+ trivialpermutation(p[2])) end +trivialpermutation(N::Integer) = ntuple(identity, N) +trivialpermutation(N₁::Integer, N₂::Integer) = (trivialpermutation(N₁), trivialpermutation(N₂) .+ N₁) istrivialpermutation(p::IndexTuple) = p == trivialpermutation(p) istrivialpermutation(p::Index2Tuple) = p == trivialpermutation(p) +Base.@constprop :aggressive function repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + partition = trivialpermutation(N₁, length(p) - N₁) + return TupleTools.getindices(p, partition[1]), TupleTools.getindices(p, partition[2]) +end +@inline repartition(p::Index2Tuple, N₁::Int) = repartition(linearize(p), N₁) + +@inline repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} = + repartition(p, N₁) +@inline repartition(p::Union{IndexTuple, Index2Tuple}, A::AbstractArray) = repartition(p, ndims(A)) + +""" + inversepermutation(p::Index2Tuple) -> ip::IndexTuple + inversepermutation(p::Index2Tuple, N₁::Int) -> ip::Index2Tuple + inversepermutation(p::Index2Tuple, partition_as::Index2Tuple) -> ip::Index2Tuple + inversepermutation(p::Index2Tuple, partition_as::AbstractArray) -> ip::Index2Tuple + +Compute the inverse permutation associated with `p`. +If no extra arguments are provided, the result is returned as a single [`IndexTuple`](@ref). +Otherwise, the extra arguments are used to partition the inverse permutation into an [`Index2Tuple`](@ref). +""" +inversepermutation(p::Index2Tuple) = invperm(linearize(p)) +function inversepermutation(p::Index2Tuple, args...) + ip = invperm(linearize(p)) + return repartition(ip, args...) +end + """ const LabelType = Union{Int,Symbol,Char} diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl new file mode 100644 index 0000000..a71cbfa --- /dev/null +++ b/src/pullbacks/add.jl @@ -0,0 +1,56 @@ +""" + tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, β, ba...) -> ΔC, ΔA, Δα, Δβ + +Compute pullbacks for [`tensoradd!`](@ref), updating cotangent arrays and returning cotangent scalars. + +See also [`pullback_dC`](@ref), [`tensoradd_pullback_dA`](@ref), [`tensoradd_pullback_dα`](@ref) and [`pullback_dβ`](@ref) +for computing pullbacks for the individual components. +""" +function tensoradd_pullback!(ΔC, ΔA, C, A, pA::Index2Tuple, conjA::Bool, α, β, ba...) + dA = tensoradd_pullback_dA!(ΔA, ΔC, C, A, pA, conjA, α, ba...) + dα = tensoradd_pullback_dα(ΔC, C, A, pA, conjA, α, ba...) + dβ = pullback_dβ(ΔC, C, β) + dC = pullback_dC!(ΔC, β) + return dC, dA, dα, dβ +end + +@doc """ + tensoradd_pullback_dA(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + tensoradd_pullback_dA!(ΔA, ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + +Compute the pullback for [`tensoradd!`](@ref) with respect to the input `A`. +The mutating version can be used to accumulate the result into `ΔA`. + +See also [`tensoradd_pullback_dA!`](@ref) for computing and updating the gradient in-place. +""" tensoradd_pullback_dA, tensoradd_pullback_dA! + +function tensoradd_pullback_dA(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + ipA = inversepermutation(pA, A) + return tensorcopy(ΔC, ipA, conjA, conjA ? α : conj(α), ba...) +end +function tensoradd_pullback_dA!(ΔA, ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔAc = tensoradd_pullback_dA(ΔC, C, A, pA, conjA, α, ba...) + ΔA .+= real.(ΔAc) + else + ipA = inversepermutation(pA, ΔA) + tensoradd!(ΔA, ΔC, ipA, conjA, conjA ? α : conj(α), One(), ba...) + end + return ΔA +end + +""" + tensoradd_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + +Compute the pullback for [`tensoradd!]`(ref) with respect to scaling coefficient `α`. +""" +function tensoradd_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) + _needs_tangent(α) || return nothing + return tensorscalar( + tensorcontract( + A, repartition(pA, 0), !conjA, + ΔC, trivialpermutation(numind(pA), 0), false, + ((), ()), One(), ba... + ) + ) +end diff --git a/src/pullbacks/common.jl b/src/pullbacks/common.jl new file mode 100644 index 0000000..c561c43 --- /dev/null +++ b/src/pullbacks/common.jl @@ -0,0 +1,34 @@ +# To avoid computing rrules for α and β when these aren't needed, we want to have a +# type-stable quick bail-out +""" + _needs_tangent(x) + _needs_tangent(::Type{T}) + +Determine whether a value requires tangent computation during automatic differentiation. +Returns `false` for constants like Integer, One, and Zero types to avoid unnecessary computation +in automatic differentiation. Returns `true` only for general Numbers. +""" +_needs_tangent(x) = _needs_tangent(typeof(x)) +_needs_tangent(::Type{<:Number}) = true +_needs_tangent(::Type{<:Integer}) = false +_needs_tangent(::Type{<:Union{One, Zero}}) = false +_needs_tangent(::Type{Complex{T}}) where {T} = _needs_tangent(T) + +# (partial) pullbacks that are shared +@doc """ + pullback_dC(ΔC, β) + pullback_dC!(ΔC, β) + +For functions of the form `f!(C, β, ...) = βC + ...`, compute the pullback with respect to `C`. +""" pullback_dC, pullback_dC! + +pullback_dC!(ΔC, β) = scale!(ΔC, conj(β)) +pullback_dC(ΔC, β) = scale(ΔC, conj(β)) + +@doc """ + pullback_dβ(ΔC, C, β) + +For functions of the form `f!(C, β, ...) = βC + ...`, compute the pullback with respect to `β`. +""" pullback_dβ + +pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? inner(C, ΔC) : nothing diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl new file mode 100644 index 0000000..f8f4659 --- /dev/null +++ b/src/pullbacks/contract.jl @@ -0,0 +1,126 @@ +""" + tensorcontract_pullback!(ΔC, ΔA, ΔB, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, β::Number, ba...) -> ΔC, ΔA, ΔB, Δα, Δβ + +Compute pullbacks for [`tensorcontract!`](@ref), updating cotangent arrays and returning cotangent scalars. + +See also [`pullback_dC`](@ref), [`tensorcontract_pullback_dA`](@ref), [`tensorcontract_pullback_dB`](@ref), +[`tensorcontract_pullback_dα`](@ref) and [`pullback_dβ`](@ref) for computing pullbacks for the individual components. +""" +function tensorcontract_pullback!( + ΔC, ΔA, ΔB, + C, + A, pA::Index2Tuple, conjA::Bool, + B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, + α::Number, β::Number, + ba... + ) + dA = tensorcontract_pullback_dA!(ΔA, ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) + dB = tensorcontract_pullback_dB!(ΔB, ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) + dα = tensorcontract_pullback_dα(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) + dβ = pullback_dβ(ΔC, C, β) + dC = pullback_dC!(ΔC, β) + return dC, dA, dB, dα, dβ +end + +@doc """ + tensorcontract_pullback_dA(ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + tensorcontract_pullback_dA!(ΔA, ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + +Compute the pullback for [`tensorcontract!`](@ref) with respect to the input `A`. +The mutating version can be used to accumulate the result into `ΔA`. + +See also [`tensorcontract_pullback_dB`](@ref) and [`tensorcontract_pullback_dB!`](@ref) for the pullback with respect to `B`. +""" tensorcontract_pullback_dA, tensorcontract_pullback_dA! + +function tensorcontract_pullback_dA( + ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... + ) + pdC = inversepermutation(pAB, numout(pA)) + ipA = inversepermutation(pA, A) + return tensorcontract( + ΔC, pdC, conjA, B, reverse(pB), conjA ? conjB : !conjB, + ipA, conjA ? α : conj(α), ba... + ) +end +function tensorcontract_pullback_dA!( + ΔA, ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... + ) + + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔAc = tensorcontract_pullback_dA(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) + ΔA .+= real.(ΔAc) + else + pdC = inversepermutation(pAB, numout(pA)) + ipA = inversepermutation(pA, A) + tensorcontract!( + ΔA, + ΔC, pdC, conjA, B, reverse(pB), conjA ? conjB : !conjB, + ipA, conjA ? α : conj(α), One(), ba... + ) + end + + return ΔA +end + +@doc """ + tensorcontract_pullback_dB(ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + tensorcontract_pullback_dB!(ΔB, ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + +Compute the pullback for [`tensorcontract!`](@ref) with respect to the input `B`. +The mutating version can be used to accumulate the result into `ΔB`. + +See also [`tensorcontract_pullback_dA`](@ref) and [`tensorcontract_pullback_dA!`](@ref) for the pullback with respect to `A`. +""" tensorcontract_pullback_dB, tensorcontract_pullback_dB! + +function tensorcontract_pullback_dB( + ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... + ) + pdC = inversepermutation(pAB, numout(pA)) + ipB = inversepermutation(pB, B) + return tensorcontract( + A, reverse(pA), conjB ? conjA : !conjA, ΔC, pdC, conjB, + ipB, conjB ? α : conj(α), ba... + ) +end +function tensorcontract_pullback_dB!( + ΔB, ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... + ) + if eltype(ΔC) <: Complex && eltype(ΔB) <: Real + ΔBc = tensorcontract_pullback_dB(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...) + ΔB .+= real.(ΔBc) + else + pdC = inversepermutation(pAB, numout(pA)) + ipB = inversepermutation(pB, B) + tensorcontract!( + ΔB, + A, reverse(pA), conjB ? conjA : !conjA, ΔC, pdC, conjB, + ipB, conjB ? α : conj(α), One(), ba... + ) + end + + return ΔB +end + +""" + tensorcontract_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number, ba...) + +Compute the pullback for [`tensorcontract!`](@ref) with respect to scaling coefficient `α`. +""" +function tensorcontract_pullback_dα( + ΔC, C, + A, pA::Index2Tuple, conjA::Bool, B, pB::Index2Tuple, conjB::Bool, + pAB::Index2Tuple, α::Number, ba... + ) + _needs_tangent(α) || return nothing + C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + return inner(C_αβ, ΔC) +end diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl new file mode 100644 index 0000000..f3135a2 --- /dev/null +++ b/src/pullbacks/trace.jl @@ -0,0 +1,99 @@ +""" + tensortrace_pullback!(ΔC, ΔA, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, β::Number, ba...) -> ΔC, ΔA, Δα, Δβ + +Compute pullbacks for [`tensortrace!`](@ref), updating cotangent arrays and returning cotangent scalars. + +See also [`pullback_dC`](@ref), [`tensortrace_pullback_dA`](@ref), [`tensortrace_pullback_dα`](@ref) and [`pullback_dβ`](@ref) +for computing pullbacks for the individual components. +""" +function tensortrace_pullback!( + ΔC, ΔA, + C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, + α::Number, β::Number, ba... + ) + dA = tensortrace_pullback_dA!(ΔA, ΔC, C, A, p, q, conjA, α, ba...) + dα = tensortrace_pullback_dα(ΔC, C, A, p, q, conjA, α, ba...) + dβ = pullback_dβ(ΔC, C, β) + dC = pullback_dC!(ΔC, β) + + return dC, dA, dα, dβ +end + +@doc """ + tensortrace_pullback_dA(ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, ba...) + tensortrace_pullback_dA!(ΔA, ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, ba...) + +Compute the pullback for [`tensortrace!`](@ref) with respect to the input `A`. +The mutating version can be used to accumulate the result into `ΔA`. +""" tensortrace_pullback_dA, tensortrace_pullback_dA! + +function tensortrace_pullback_dA( + ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, + α::Number, ba... + ) + ip = repartition(invperm((linearize(p)..., linearize(q)...)), A) + Es = map(q[1], q[2]) do i1, i2 + one( + TensorOperations.tensoralloc_add( + TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA + ) + ) + end + E = _kron(Es, ba) + + return tensorproduct( + ΔC, trivialpermutation(numind(p), 0), conjA, + E, trivialpermutation(0, numind(q)), conjA, + ip, + conjA ? α : conj(α), ba... + ) +end +function tensortrace_pullback_dA!( + ΔA, ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, + α::Number, ba... + ) + + if eltype(ΔC) <: Complex && eltype(ΔA) <: Real + ΔAc = tensortrace_pullback_dA(ΔC, C, A, p, q, conjA, α, ba...) + ΔA .+= real.(ΔAc) + else + ip = repartition(invperm((linearize(p)..., linearize(q)...)), A) + Es = map(q[1], q[2]) do i1, i2 + one( + TensorOperations.tensoralloc_add( + TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA + ) + ) + end + E = _kron(Es, ba) + tensorproduct!( + ΔA, ΔC, trivialpermutation(numind(p), 0), conjA, + E, trivialpermutation(0, numind(q)), conjA, + ip, + conjA ? α : conj(α), One(), ba... + ) + end + + return ΔA +end + +""" + tensortrace_pullback_dα(ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, α::Number, ba...) + +Compute the pullback for [`tensortrace!`](@ref) with respect to scaling coefficient `α`. +""" +function tensortrace_pullback_dα( + ΔC, C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, + α::Number, ba... + ) + _needs_tangent(α) || return nothing + C_αβ = tensortrace(A, p, q, false, One(), ba...) + return tensorscalar( + tensorcontract( + C_αβ, trivialpermutation(0, numind(p)), + !conjA, + ΔC, trivialpermutation(numind(p), 0), false, + ((), ()), One(), ba... + ) + ) +end diff --git a/src/utils.jl b/src/utils.jl index 20e406b..0afef28 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,14 +2,7 @@ _kron(Es::NTuple{1}, ba) = Es[1] function _kron(Es::NTuple{N, Any}, ba) where {N} E1 = Es[1] E2 = _kron(Base.tail(Es), ba) - p2 = ((), trivtuple(2 * N - 2)) - p = ((1, (2 .+ trivtuple(N - 1))...), (2, ((N + 1) .+ trivtuple(N - 1))...)) + p2 = trivialpermutation(0, 2N - 2) + p = ((1, (2 .+ trivialpermutation(N - 1))...), (2, ((N + 1) .+ trivialpermutation(N - 1))...)) return tensorproduct(p, E1, ((1, 2), ()), false, E2, p2, false, One(), ba...) end - -# To avoid computing rrules for α and β when these aren't needed, we want to have a -# type-stable quick bail-out -_needs_tangent(x) = _needs_tangent(typeof(x)) -_needs_tangent(::Type{<:Number}) = true -_needs_tangent(::Type{<:Integer}) = false -_needs_tangent(::Type{<:Union{One, Zero}}) = false