Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Preferences = "1.4"
PtrArrays = "1.2"
Random = "1"
Strided = "2.2"
StridedViews = "0.3, 0.4"
StridedViews = "0.3, 0.4, ~0.4.2"
Test = "1"
TupleTools = "1.6"
VectorInterface = "0.4.1,0.5"
Expand Down
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ makedocs(;
"man/precompilation.md",
],
"Index" => "index/index.md",
]
],
checkdocs = :public
)

# Documenter can also automatically deploy documentation to gh-pages.
Expand Down
180 changes: 31 additions & 149 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ module TensorOperationsChainRulesCoreExt

using TensorOperations
using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent
using TensorOperations: pullback_dC, pullback_dβ,
tensoradd_pullback_dA, tensoradd_pullback_dα,
tensorcontract_pullback_dA, tensorcontract_pullback_dB, tensorcontract_pullback_dα,
tensortrace_pullback_dA, tensortrace_pullback_dα

using TensorOperations: DefaultBackend, DefaultAllocator, _kron
using ChainRulesCore
using TupleTools
using VectorInterface
using TupleTools: invperm
using LinearAlgebra

trivtuple(N) = ntuple(identity, N)

@non_differentiable TensorOperations.tensorstructure(args...)
@non_differentiable TensorOperations.tensoradd_structure(args...)
@non_differentiable TensorOperations.tensoradd_type(args...)
Expand Down Expand Up @@ -74,53 +77,26 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
function tensoradd_pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = if β === Zero()
ZeroTangent()
else
@thunk projectC(scale(ΔC, conj(β)))
end
dA = @thunk let
ipA = invperm(linearize(pA))
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
projectA(_dA)
end

dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β))
dA = @thunk projectA(tensoradd_pullback_dA(ΔC, C, A, pA, conjA, α, ba...))
dα = if _needs_tangent(α)
@thunk let
_dα = tensorscalar(
tensorcontract(
A, ((), linearize(pA)), !conjA,
ΔC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
)
)
projectα(_dα)
end
@thunk projectα(tensoradd_pullback_dα(ΔC, C, A, pA, conjA, α, ba...))
else
ZeroTangent()
end
dβ = if _needs_tangent(β)
@thunk let
# TODO: consider using `inner`
_dβ = tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pA))), true,
ΔC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
)
)
projectβ(_dβ)
end
@thunk projectβ(pullback_dβ(ΔC, C, β))
else
ZeroTangent()
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
end

return C′, pullback
return C′, tensoradd_pullback
end

function ChainRulesCore.rrule(
Expand All @@ -143,84 +119,31 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
function tensorcontract_pullback(ΔC′)
ΔC = unthunk(ΔC′)
ipAB = invperm(linearize(pAB))
pΔC = (
TupleTools.getindices(ipAB, trivtuple(numout(pA))),
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
)
dC = if β === Zero()
ZeroTangent()
else
@thunk projectC(scale(ΔC, conj(β)))
end
dA = @thunk let
ipA = (invperm(linearize(pA)), ())
conjΔC = conjA
conjB′ = conjA ? conjB : !conjB
_dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α)))
_dA = tensorcontract!(
_dA,
ΔC, pΔC, conjΔC,
B, reverse(pB), conjB′,
ipA,
conjA ? α : conj(α), Zero(), ba...
)
projectA(_dA)
end
dB = @thunk let
ipB = (invperm(linearize(pB)), ())
conjΔC = conjB
conjA′ = conjB ? conjA : !conjA
_dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α)))
_dB = tensorcontract!(
_dB,
A, reverse(pA), conjA′,
ΔC, pΔC, conjΔC,
ipB,
conjB ? α : conj(α), Zero(), ba...
)
projectB(_dB)
end

dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β))
dA = @thunk projectA(tensorcontract_pullback_dA(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...))
dB = @thunk projectB(tensorcontract_pullback_dB(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...))
dα = if _needs_tangent(α)
@thunk let
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
# TODO: consider using `inner`
_dα = tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
)
)
projectα(_dα)
end
@thunk projectα(tensorcontract_pullback_dα(ΔC, C, A, pA, conjA, B, pB, conjB, pAB, α, ba...))
else
ZeroTangent()
end
dβ = if _needs_tangent(β)
@thunk let
# TODO: consider using `inner`
_dβ = tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
)
)
projectβ(_dβ)
end
@thunk projectβ(pullback_dβ(ΔC, C, β))
else
ZeroTangent()
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC,
dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(),
NoTangent(), dα, dβ, dba...
dA, NoTangent(), NoTangent(),
dB, NoTangent(), NoTangent(),
NoTangent(),
dα, dβ, dba...
end

return C′, pullback
return C′, tensorcontract_pullback
end

function ChainRulesCore.rrule(
Expand All @@ -239,67 +162,26 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC′)
function tensortrace_pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = if β === Zero()
ZeroTangent()
else
@thunk projectC(scale(ΔC, conj(β)))
end
dA = @thunk let
ip = invperm((linearize(p)..., q[1]..., q[2]...))
Es = map(q[1], q[2]) do i1, i2
one(
TensorOperations.tensoralloc_add(
scalartype(A), A, ((i1,), (i2,)), conjA
)
)
end
E = _kron(Es, ba)
_dA = zerovector(A, VectorInterface.promote_scale(ΔC, α))
_dA = tensorproduct!(
_dA, ΔC, (trivtuple(numind(p)), ()), conjA,
E, ((), trivtuple(numind(q))), conjA,
(ip, ()),
conjA ? α : conj(α), Zero(), ba...
)
projectA(_dA)
end

dC = β === Zero() ? ZeroTangent() : @thunk projectC(pullback_dC(ΔC, β))
dA = @thunk projectA(tensortrace_pullback_dA(ΔC, C, A, p, q, conjA, α, ba...))
dα = if _needs_tangent(α)
@thunk let
C_αβ = tensortrace(A, p, q, false, One(), ba...)
_dα = tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(p))),
!conjA,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
)
)
projectα(_dα)
end
@thunk projectα(tensortrace_pullback_dα(ΔC, C, A, p, q, conjA, α, ba...))
else
ZeroTangent()
end
dβ = if _needs_tangent(β)
@thunk let
_dβ = tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(p))), true,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
)
)
projectβ(_dβ)
end
@thunk projectβ(pullback_dβ(ΔC, C, β))
else
ZeroTangent()
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...
end

return C′, pullback
return C′, tensortrace_pullback
end

# NCON functions
Expand Down
Loading
Loading