Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ GenericLinearAlgebra = "0.3.19"
GenericSchur = "0.5.6"
JET = "0.9, 0.10"
LinearAlgebra = "1"
Mooncake = "0.4.183"
Mooncake = "0.4.195"
ParallelTestRunner = "2"
Random = "1"
SafeTestsets = "0.1"
Expand Down
21 changes: 20 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKitCUDAExt
using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
Expand Down Expand Up @@ -195,4 +195,23 @@ end
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) =
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))

MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4)
MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4)
function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...)
As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...))
return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this needed? what breaks if we don't do CuArray.(As')?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

norm doesn't work for Adjoint{CuArray} for example

end

function LinearAlgebra.sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
#=m = size(A, 1)
n = size(B, 2)
I_n = fill!(similar(A, n), one(eltype(A)))
I_m = fill!(similar(B, m), one(eltype(B)))
L = kron(diagm(I_n), A) + kron(adjoint(B), diagm(I_m))
x_vec = L \ -vec(C)
X = CuMatrix(reshape(x_vec, m, n))=#
hX = sylvester(collect(A), collect(B), collect(C))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very awful but I wasn't able to find a correct way to do it in five minutes so there you go

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any chance we could:

  1. open an issue for CUDA and add the link in some comment here
  2. insert a function _sylvester to avoid type piracy

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes to both

return CuArray(hX)
end

end
21 changes: 21 additions & 0 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ for eig in (:eig, :eigh)
eig_t! = Symbol(eig, "_trunc!")
eig_t_pb = Symbol(eig, "_trunc_pullback")
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
eig_t_ne! = Symbol(eig, "_trunc_no_error!")
eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback")
_make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb)
eig_v = Symbol(eig, "_vals")
eig_v! = Symbol(eig_v, "!")
eig_v_pb = Symbol(eig_v, "_pullback")
Expand Down Expand Up @@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
end
return $eig_t_pb
end
function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm)
Ac = copy_input($eig_f, A)
DV = $(eig_f!)(Ac, DV, alg.alg)
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
return DV′, $(_make_eig_t_ne_pb)(A, DV, ind)
end
function $(_make_eig_t_ne_pb)(A, DV, ind)
function $eig_t_ne_pb(ΔDV)
ΔA = zero(A)
ΔD, ΔV = ΔDV
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return $eig_t_ne_pb
end
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
DV = $eig_f(A, alg)
function $eig_v_pb(ΔD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview, copy_input
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
Expand All @@ -18,14 +18,16 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
dAc = Mooncake.zero_tangent(Ac)
Ac_dAc = Mooncake.zero_fcodual(Ac)
dAc = Mooncake.tangent(Ac_dAc)
function copy_input_pb(::NoRData)
Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
return NoRData(), NoRData(), NoRData()
end
return CoDual(Ac, dAc), copy_input_pb
return Ac_dAc, copy_input_pb
end

@zero_derivative Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any}
# two-argument in-place factorizations like LQ, QR, EIG
for (f!, f, pb, adj) in (
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),
Expand Down
9 changes: 6 additions & 3 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ function eig_pullback!(
mul!(view(VᴴΔV, :, indV), V', ΔV)

mask = abs.(transpose(D) .- D) .< degeneracy_atol
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mask probably also belongs inside of the if

Δgauge = norm(view(VᴴΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
if isa(ΔA, Array)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we refactor this into a separate function and use dispatch to resolve this? It might be slightly safer to simply allocate instead of taking a view for the GPU versions, which should then no longer be scalar-indexed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep!

# not GPU friendly...
Δgauge = norm(view(VᴴΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end

VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))

Expand Down
48 changes: 27 additions & 21 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,30 @@ function lq_pullback!(
ΔA1 = view(ΔA, 1:p, :)
ΔA2 = view(ΔA, (p + 1):m, :)

if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
Δgauge = max(Δgauge, norm(ΔL22, Inf))
if isa(ΔA, Array) # not GPU friendly
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment here about refactoring into a separate function

if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
Δgauge = max(Δgauge, norm(ΔL22, Inf))
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end

ΔQ̃ = zero!(similar(Q, (p, n)))
if !iszerotangent(ΔQ)
ΔQ1 = view(ΔQ, 1:p, :)
copy!(ΔQ̃, ΔQ1)
ΔQ̃ .= ΔQ1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this also a GPU thing? copy! does not work but broadcasting does?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah

if p < size(Q, 1)
Q2 = view(Q, (p + 1):size(Q, 1), :)
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Expand All @@ -69,9 +71,11 @@ function lq_pullback!(
# how the full Q2 will change, but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge dependent quantity.
ΔQ2Q1ᴴ = ΔQ2 * Q1'
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
if isa(ΔA, Array) # not GPU friendly
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment about refactoring into a function

Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
end
end
Expand All @@ -95,8 +99,10 @@ function lq_pullback!(
Md = diagview(M)
Md .= real.(Md)
end
ldiv!(LowerTriangular(L11)', M)
ldiv!(LowerTriangular(L11)', ΔQ̃)
# not GPU friendly...
L11arr = typeof(L)(L11)
ldiv!(LowerTriangular(L11arr)', M)
ldiv!(LowerTriangular(L11arr)', ΔQ̃)
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
ΔA1 .+= ΔQ̃
return ΔA
Expand Down
6 changes: 3 additions & 3 deletions src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
if !iszerotangent(ΔW)
ΔWP = ΔW / P
WdΔWP = W' * ΔWP
ΔWP = mul!(ΔWP, W, WdΔWP, -1, 1)
ΔWP .-= W * WdΔWP
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still a Diagonal thing that I was missing? It's not super nice to have to allocate and broadcast when mul! is really what we want, if not could we maybe use _mul!! and use dispatch to select between the two?

ΔA .+= ΔWP
end
return ΔA
Expand All @@ -48,11 +48,11 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
C = sylvester(P, P, M' - M)
C .+= ΔP
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
ΔA .+= C * Wᴴ
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same _mul!! comment here

if !iszerotangent(ΔWᴴ)
PΔWᴴ = P \ ΔWᴴ
PΔWᴴW = PΔWᴴ * Wᴴ'
PΔWᴴ = mul!(PΔWᴴ, PΔWᴴW, Wᴴ, -1, 1)
PΔWᴴ .-= PΔWᴴW * Wᴴ
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same _mul!! comment here

ΔA .+= PΔWᴴ
end
return ΔA
Expand Down
52 changes: 29 additions & 23 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,29 @@ function qr_pullback!(
ΔA1 = view(ΔA, :, 1:p)
ΔA2 = view(ΔA, :, (p + 1):n)

if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
Δgauge = max(Δgauge, norm(ΔR22, Inf))
if isa(ΔA, Array) # not GPU friendly
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same refactor into function comment here

if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n)
Δgauge = max(Δgauge, norm(ΔR22, Inf))
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end

ΔQ̃ = zero!(similar(Q, (m, p)))
if !iszerotangent(ΔQ)
copy!(ΔQ̃, view(ΔQ, :, 1:p))
ΔQ̃ .= view(ΔQ, :, 1:p)
if p < size(Q, 2)
Q2 = view(Q, :, (p + 1):size(Q, 2))
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Expand All @@ -69,9 +71,11 @@ function qr_pullback!(
# how the full Q2 will change, but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge dependent quantity.
Q1dΔQ2 = Q1' * ΔQ2
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
if isa(ΔA, Array) # not GPU friendly
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same refactor into a function comment here

Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1)
end
end
Expand All @@ -87,16 +91,18 @@ function qr_pullback!(
M = zero!(similar(R, (p, p)))
if !iszerotangent(ΔR)
ΔR11 = view(ΔR, 1:p, 1:p)
M = mul!(M, ΔR11, R11', 1, 1)
M += ΔR11 * R11'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same _mul!! comment

end
M = mul!(M, Q1', ΔQ̃, -1, 1)
M -= Q1' * ΔQ̃
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same _mul!! comment

view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M)))
if eltype(M) <: Complex
Md = diagview(M)
Md .= real.(Md)
end
rdiv!(M, UpperTriangular(R11)')
rdiv!(ΔQ̃, UpperTriangular(R11)')
# not GPU-friendly...
R11arr = typeof(R)(R11)
rdiv!(M, UpperTriangular(R11arr)')
rdiv!(ΔQ̃, UpperTriangular(R11arr)')
ΔA1 = mul!(ΔA1, Q1, M, +1, 1)
ΔA1 .+= ΔQ̃
return ΔA
Expand Down
20 changes: 11 additions & 9 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol
"""
function svd_pullback!(
ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon();
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
rank_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we handle this conversion in the functions instead of at the callsite?

degeneracy_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)
# Extract the SVD components
Expand All @@ -33,7 +33,7 @@ function svd_pullback!(
minmn = min(m, n)
S = diagview(Smat)
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
r = searchsortedlast(S, rank_atol; rev = true) # rank
r = findlast(s -> s ≥ rank_atol, S) # rank
Ur = view(U, :, 1:r)
Vᴴr = view(Vᴴ, 1:r, :)
Sr = view(S, 1:r)
Expand Down Expand Up @@ -71,9 +71,11 @@ function svd_pullback!(

# check whether cotangents arise from gauge-invariance objective function
mask = abs.(Sr' .- Sr) .< degeneracy_atol
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
if isa(ΔA, Array) # norm check not GPU friendly
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end

UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+
(aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol)
Expand All @@ -84,18 +86,18 @@ function svd_pullback!(
length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))"))
view(diagview(UdΔAV), indS) .+= real.(ΔS)
end
ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA
ΔA .+= Ur * UdΔAV * Vᴴr # add the contribution to ΔA

# Add the remaining contributions
if m > r && !iszerotangent(ΔU) # remaining ΔU is already orthogonal to Ur
Sp = view(S, indU)
Vᴴp = view(Vᴴ, indU, :)
ΔA = mul!(ΔA, ΔU ./ Sp', Vᴴp, 1, 1)
ΔA .+= (ΔU ./ Sp') * Vᴴp
end
if n > r && !iszerotangent(ΔVᴴ) # remaining ΔV is already orthogonal to Vᴴr
Sp = view(S, indV)
Up = view(U, :, indV)
ΔA = mul!(ΔA, Up, Sp .\ ΔVᴴ, 1, 1)
ΔA .+= Up * (Sp .\ ΔVᴴ)
end
return ΔA
end
Expand Down
Loading
Loading