From 09d7f690c9b4da5bfc6faf15df3c067ec0262005 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 02:55:12 -0500 Subject: [PATCH 01/11] Use Testsuite for AD tests --- Project.toml | 2 +- .../MatrixAlgebraKitCUDAExt.jl | 21 +- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 21 + .../MatrixAlgebraKitMooncakeExt.jl | 16 +- src/pullbacks/eig.jl | 12 +- src/pullbacks/lq.jl | 48 +- src/pullbacks/polar.jl | 6 +- src/pullbacks/qr.jl | 52 +- src/pullbacks/svd.jl | 20 +- test/ad_utils.jl | 62 -- test/chainrules.jl | 592 +---------------- test/mooncake.jl | 619 +----------------- test/testsuite/TestSuite.jl | 4 + test/testsuite/ad_utils.jl | 418 ++++++++++++ test/testsuite/chainrules.jl | 612 +++++++++++++++++ test/testsuite/mooncake.jl | 483 ++++++++++++++ 16 files changed, 1686 insertions(+), 1302 deletions(-) delete mode 100644 test/ad_utils.jl create mode 100644 test/testsuite/ad_utils.jl create mode 100644 test/testsuite/chainrules.jl create mode 100644 test/testsuite/mooncake.jl diff --git a/Project.toml b/Project.toml index 72bc22be..6d2e9c02 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" JET = "0.9, 0.10" LinearAlgebra = "1" -Mooncake = "0.4.183" +Mooncake = "0.4.195" ParallelTestRunner = "2" Random = "1" SafeTestsets = "0.1" diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 8bb09db1..10322a7b 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitCUDAExt using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! -using MatrixAlgebraKit: diagview, sign_safe +using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! @@ -195,4 +195,23 @@ end MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4) +MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4) +function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...) + As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...)) + return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4) +end + +function LinearAlgebra.sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix) + #=m = size(A, 1) + n = size(B, 2) + I_n = fill!(similar(A, n), one(eltype(A))) + I_m = fill!(similar(B, m), one(eltype(B))) + L = kron(diagm(I_n), A) + kron(adjoint(B), diagm(I_m)) + x_vec = L \ -vec(C) + X = CuMatrix(reshape(x_vec, m, n))=# + hX = sylvester(collect(A), collect(B), collect(C)) + return CuArray(hX) +end + end diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index c2de1758..400b2a79 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -95,6 +95,9 @@ for eig in (:eig, :eigh) eig_t! = Symbol(eig, "_trunc!") eig_t_pb = Symbol(eig, "_trunc_pullback") _make_eig_t_pb = Symbol("_make_", eig_t_pb) + eig_t_ne! = Symbol(eig, "_trunc_no_error!") + eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback") + _make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb) eig_v = Symbol(eig, "_vals") eig_v! = Symbol(eig_v, "!") eig_v_pb = Symbol(eig_v, "_pullback") @@ -136,6 +139,24 @@ for eig in (:eig, :eigh) end return $eig_t_pb end + function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm) + Ac = copy_input($eig_f, A) + DV = $(eig_f!)(Ac, DV, alg.alg) + DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) + return DV′, $(_make_eig_t_ne_pb)(A, DV, ind) + end + function $(_make_eig_t_ne_pb)(A, DV, ind) + function $eig_t_ne_pb(ΔDV) + ΔA = zero(A) + ΔD, ΔV = ΔDV + MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful? + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return $eig_t_ne_pb + end function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg) DV = $eig_f(A, alg) function $eig_v_pb(ΔD) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index f6feda8b..75701ba5 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt using Mooncake using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive using MatrixAlgebraKit -using MatrixAlgebraKit: inv_safe, diagview, copy_input +using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! @@ -18,14 +18,24 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any} function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual) Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA)) - dAc = Mooncake.zero_tangent(Ac) + Ac_dAc = Mooncake.zero_fcodual(Ac) + dAc = Mooncake.tangent(Ac_dAc) function copy_input_pb(::NoRData) Mooncake.increment!!(Mooncake.tangent(A_dA), dAc) return NoRData(), NoRData(), NoRData() end - return CoDual(Ac, dAc), copy_input_pb + return Ac_dAc, copy_input_pb end +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any} +function Mooncake.rrule!!(::CoDual{typeof(initialize_output)}, f_df::CoDual, A_dA::CoDual, alg_dalg::CoDual) + output = initialize_output(Mooncake.primal(f_df), Mooncake.primal(A_dA), Mooncake.primal(alg_dalg)) + output_doutput = Mooncake.zero_fcodual(output) + initialize_output_pb(::NoRData) = (NoRData(), NoRData(), NoRData(), NoRData()) + return output_doutput, initialize_output_pb +end + + # two-argument in-place factorizations like LQ, QR, EIG for (f!, f, pb, adj) in ( (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 6b89b64f..5c79f9f5 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -39,12 +39,16 @@ function eig_pullback!( VᴴΔV = fill!(similar(V), 0) indV = axes(V, 2)[ind] length(indV) == pV || throw(DimensionMismatch()) - mul!(view(VᴴΔV, :, indV), V', ΔV) + VᴴΔV[:, indV] .= V' * ΔV + #mul!(view(VᴴΔV, :, indV), V', ΔV) mask = abs.(transpose(D) .- D) .< degeneracy_atol - Δgauge = norm(view(VᴴΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if isa(ΔA, Array) + # not GPU friendly... + Δgauge = norm(view(VᴴΔV, mask), Inf) + Δgauge ≤ gauge_atol || + @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index b30fe198..cc8ce1fd 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -36,28 +36,30 @@ function lq_pullback!( ΔA1 = view(ΔA, 1:p, :) ΔA2 = view(ΔA, (p + 1):m, :) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) - Δgauge = max(Δgauge, norm(ΔQ2, Inf)) - end - if !iszerotangent(ΔL) - ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn) - Δgauge = max(Δgauge, norm(ΔL22, Inf)) + if isa(ΔA, Array) # not GPU friendly + if minmn > p # case where A is rank-deficient + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + # in this case the number Householder reflections will + # change upon small variations, and all of the remaining + # columns of ΔQ should be zero for a gauge-invariant + # cost function + ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) + Δgauge = max(Δgauge, norm(ΔQ2, Inf)) + end + if !iszerotangent(ΔL) + ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn) + Δgauge = max(Δgauge, norm(ΔL22, Inf)) + end + Δgauge ≤ gauge_atol || + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" end - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" end ΔQ̃ = zero!(similar(Q, (p, n))) if !iszerotangent(ΔQ) ΔQ1 = view(ΔQ, 1:p, :) - copy!(ΔQ̃, ΔQ1) + ΔQ̃ .= ΔQ1 if p < size(Q, 1) Q2 = view(Q, (p + 1):size(Q, 1), :) ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) @@ -69,9 +71,11 @@ function lq_pullback!( # how the full Q2 will change, but this we omit for now, and we consider # Q2' * ΔQ2 as a gauge dependent quantity. ΔQ2Q1ᴴ = ΔQ2 * Q1' - Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if isa(ΔA, Array) # not GPU friendly + Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) + Δgauge ≤ gauge_atol || + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1) end end @@ -95,8 +99,10 @@ function lq_pullback!( Md = diagview(M) Md .= real.(Md) end - ldiv!(LowerTriangular(L11)', M) - ldiv!(LowerTriangular(L11)', ΔQ̃) + # not GPU friendly... + L11arr = typeof(L)(L11) + ldiv!(LowerTriangular(L11arr)', M) + ldiv!(LowerTriangular(L11arr)', ΔQ̃) ΔA1 = mul!(ΔA1, M, Q1, +1, 1) ΔA1 .+= ΔQ̃ return ΔA diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index 1c6de509..0caed08b 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -22,7 +22,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) if !iszerotangent(ΔW) ΔWP = ΔW / P WdΔWP = W' * ΔWP - ΔWP = mul!(ΔWP, W, WdΔWP, -1, 1) + ΔWP .-= W * WdΔWP ΔA .+= ΔWP end return ΔA @@ -48,11 +48,11 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs... !iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1) C = sylvester(P, P, M' - M) C .+= ΔP - ΔA = mul!(ΔA, C, Wᴴ, 1, 1) + ΔA .+= C * Wᴴ if !iszerotangent(ΔWᴴ) PΔWᴴ = P \ ΔWᴴ PΔWᴴW = PΔWᴴ * Wᴴ' - PΔWᴴ = mul!(PΔWᴴ, PΔWᴴW, Wᴴ, -1, 1) + PΔWᴴ .-= PΔWᴴW * Wᴴ ΔA .+= PΔWᴴ end return ΔA diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 888029be..a054ae21 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -37,27 +37,29 @@ function qr_pullback!( ΔA1 = view(ΔA, :, 1:p) ΔA2 = view(ΔA, :, (p + 1):n) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - Δgauge = max(Δgauge, norm(ΔQ2, Inf)) - end - if !iszerotangent(ΔR) - ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n) - Δgauge = max(Δgauge, norm(ΔR22, Inf)) + if isa(ΔA, Array) # not GPU friendly + if minmn > p # case where A is rank-deficient + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + # in this case the number Householder reflections will + # change upon small variations, and all of the remaining + # columns of ΔQ should be zero for a gauge-invariant + # cost function + ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) + Δgauge = max(Δgauge, norm(ΔQ2, Inf)) + end + if !iszerotangent(ΔR) + ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n) + Δgauge = max(Δgauge, norm(ΔR22, Inf)) + end + Δgauge ≤ gauge_atol || + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" end - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" end ΔQ̃ = zero!(similar(Q, (m, p))) if !iszerotangent(ΔQ) - copy!(ΔQ̃, view(ΔQ, :, 1:p)) + ΔQ̃ .= view(ΔQ, :, 1:p) if p < size(Q, 2) Q2 = view(Q, :, (p + 1):size(Q, 2)) ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) @@ -69,9 +71,11 @@ function qr_pullback!( # how the full Q2 will change, but this we omit for now, and we consider # Q2' * ΔQ2 as a gauge dependent quantity. Q1dΔQ2 = Q1' * ΔQ2 - Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if isa(ΔA, Array) # not GPU friendly + Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) + Δgauge ≤ gauge_atol || + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1) end end @@ -87,16 +91,18 @@ function qr_pullback!( M = zero!(similar(R, (p, p))) if !iszerotangent(ΔR) ΔR11 = view(ΔR, 1:p, 1:p) - M = mul!(M, ΔR11, R11', 1, 1) + M += ΔR11 * R11' end - M = mul!(M, Q1', ΔQ̃, -1, 1) + M -= Q1' * ΔQ̃ view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) if eltype(M) <: Complex Md = diagview(M) Md .= real.(Md) end - rdiv!(M, UpperTriangular(R11)') - rdiv!(ΔQ̃, UpperTriangular(R11)') + # not GPU-friendly... + R11arr = typeof(R)(R11) + rdiv!(M, UpperTriangular(R11arr)') + rdiv!(ΔQ̃, UpperTriangular(R11arr)') ΔA1 = mul!(ΔA1, Q1, M, +1, 1) ΔA1 .+= ΔQ̃ return ΔA diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 1608343e..dd393b04 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -22,8 +22,8 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol """ function svd_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon(); - rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), - degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + rank_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])), + degeneracy_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) # Extract the SVD components @@ -33,7 +33,7 @@ function svd_pullback!( minmn = min(m, n) S = diagview(Smat) length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)")) - r = searchsortedlast(S, rank_atol; rev = true) # rank + r = findlast(s -> s ≥ rank_atol, S) # rank Ur = view(U, :, 1:r) Vᴴr = view(Vᴴ, 1:r, :) Sr = view(S, 1:r) @@ -71,9 +71,11 @@ function svd_pullback!( # check whether cotangents arise from gauge-invariance objective function mask = abs.(Sr' .- Sr) .< degeneracy_atol - Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if isa(ΔA, Array) # norm check not GPU friendly + Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) + Δgauge ≤ gauge_atol || + @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+ (aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol) @@ -84,18 +86,18 @@ function svd_pullback!( length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))")) view(diagview(UdΔAV), indS) .+= real.(ΔS) end - ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA + ΔA .+= Ur * UdΔAV * Vᴴr # add the contribution to ΔA # Add the remaining contributions if m > r && !iszerotangent(ΔU) # remaining ΔU is already orthogonal to Ur Sp = view(S, indU) Vᴴp = view(Vᴴ, indU, :) - ΔA = mul!(ΔA, ΔU ./ Sp', Vᴴp, 1, 1) + ΔA .+= (ΔU ./ Sp') * Vᴴp end if n > r && !iszerotangent(ΔVᴴ) # remaining ΔV is already orthogonal to Vᴴr Sp = view(S, indV) Up = view(U, :, indV) - ΔA = mul!(ΔA, Up, Sp .\ ΔVᴴ, 1, 1) + ΔA .+= Up * (Sp .\ ΔVᴴ) end return ΔA end diff --git a/test/ad_utils.jl b/test/ad_utils.jl deleted file mode 100644 index fccc6c00..00000000 --- a/test/ad_utils.jl +++ /dev/null @@ -1,62 +0,0 @@ -function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 - mul!(ΔU, U, gaugepart, -1, 1) - return ΔU, ΔVᴴ -end -function remove_eiggauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end -function remove_eighgauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = V' * ΔV - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end -function stabilize_eigvals!(D::AbstractVector) - absD = abs.(D) - p = invperm(sortperm(absD)) # rank of abs(D) - # account for exact degeneracies in absolute value when having complex conjugate pairs - for i in 1:(length(D) - 1) - if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially - p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones - end - end - n = maximum(p) - # rescale eigenvalues so that they lie on distinct radii in the complex plane - # that are chosen randomly in non-overlapping intervals [10 * k/n, 10 * (k+0.5)/n)] for k=1,...,n - radii = 10 .* ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n - for i in 1:length(D) - D[i] = sign(D[i]) * radii[p[i]] - end - return D -end -function make_eig_matrix(rng, T, n) - A = randn(rng, T, n, n) - D, V = eig_full(A) - stabilize_eigvals!(diagview(D)) - Ac = V * D * inv(V) - return (T <: Real) ? real(Ac) : Ac -end -function make_eigh_matrix(rng, T, n) - A = project_hermitian!(randn(rng, T, n, n)) - D, V = eigh_full(A) - stabilize_eigvals!(diagview(D)) - return project_hermitian!(V * D * V') -end - -precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T))) diff --git a/test/chainrules.jl b/test/chainrules.jl index a8b2fd3b..c0ab618a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,590 +1,18 @@ using MatrixAlgebraKit using Test -using TestExtras -using StableRNGs -using ChainRulesCore, ChainRulesTestUtils, Zygote -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! -include("ad_utils.jl") +#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI -for f in - ( - :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, - :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, - :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, - :left_polar, :right_polar, - ) - copy_f = Symbol(:copy_, f) - f! = Symbol(f, '!') - _hermitian = startswith(string(f), "eigh") - @eval begin - function $copy_f(input, alg) - if $_hermitian - input = (input + input') / 2 - end - return $f(input, alg) - end - function ChainRulesCore.rrule(::typeof($copy_f), input, alg) - output = MatrixAlgebraKit.initialize_output($f!, input, alg) - if $_hermitian - input = (input + input') / 2 - else - input = copy(input) - end - output, pb = ChainRulesCore.rrule($f!, input, output, alg) - return output, x -> (NoTangent(), pb(x)[2], NoTangent()) - end - end -end - -@timedtestset "QR AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - # qr_compact - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - alg = LAPACK_HouseholderQR(; positive = true) - Q, R = copy_qr_compact(A, alg) - ΔQ = randn(rng, T, m, minmn) - ΔR = randn(rng, T, minmn, n) - ΔR2 = UpperTriangular(randn(rng, T, minmn, minmn)) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - test_rrule( - copy_qr_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - test_rrule( - copy_qr_null, A, alg ⊢ NoTangent(); - output_tangent = ΔN, atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ qr_compact, A; - fkwargs = (; positive = true), output_tangent = ΔQ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ qr_compact, A; - fkwargs = (; positive = true), output_tangent = ΔR, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, qr_null, A; - fkwargs = (; positive = true), output_tangent = ΔN, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - # qr_full - Q, R = copy_qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - test_rrule( - copy_qr_full, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_full, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - if m > n - _, null_pb = Zygote.pullback(qr_null, A, alg) - @test_logs (:warn,) null_pb(randn(rng, T, m, max(0, m - minmn))) - _, full_pb = Zygote.pullback(qr_full, A, alg) - @test_logs (:warn,) full_pb((randn(rng, T, m, m), randn(rng, T, m, n))) - end - # rank-deficient A - r = minmn - 5 - A = randn(rng, T, m, r) * randn(rng, T, r, n) - Q, R = qr_compact(A, alg) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) - view(ΔR, (r + 1):minmn, :) .= 0 - test_rrule( - copy_qr_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "LQ AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - # lq_compact - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - alg = LAPACK_HouseholderLQ(; positive = true) - L, Q = copy_lq_compact(A, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - test_rrule( - copy_lq_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - test_rrule( - copy_lq_null, A, alg ⊢ NoTangent(); - output_tangent = ΔNᴴ, atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ lq_compact, A; - fkwargs = (; positive = true), output_tangent = ΔL, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ lq_compact, A; - fkwargs = (; positive = true), output_tangent = ΔQ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, lq_null, A; - fkwargs = (; positive = true), output_tangent = ΔNᴴ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - # lq_full - L, Q = copy_lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - test_rrule( - copy_lq_full, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_full, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - if m < n - Nᴴ, null_pb = Zygote.pullback(lq_null, A, alg) - @test_logs (:warn,) null_pb(randn(rng, T, max(0, n - minmn), n)) - _, full_pb = Zygote.pullback(lq_full, A, alg) - @test_logs (:warn,) full_pb((randn(rng, T, m, n), randn(rng, T, n, n))) - end - # rank-deficient A - r = minmn - 5 - A = randn(rng, T, m, r) * randn(rng, T, r, n) - L, Q = lq_compact(A, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - test_rrule( - copy_lq_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "EIG AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eig_matrix(rng, T, m) - D, V = eig_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - for alg in (LAPACK_Simple(), LAPACK_Expert()) - test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol - ) - test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol - ) - test_rrule( - copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol - ) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, eig_full, A; - output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eig_full, A; - output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ eig_full, A; - output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ eig_full, A; - output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eig_vals, A; - output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eigh_matrix(rng, T, m) - D, V = eigh_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - for alg in ( - LAPACK_QRIteration(), LAPACK_DivideAndConquer(), LAPACK_Bisection(), - LAPACK_MultipleRelativelyRobustRepresentations(), - ) - # copy_eigh_full includes a projector onto the Hermitian part of the matrix - test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol - ) - test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol - ) - test_rrule( - copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol - ) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - # eigh_full does not include a projector onto the Hermitian part of the matrix - test_rrule( - config, eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eigh_vals ∘ Matrix ∘ Hermitian, A; - output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) - for r in 1:4:m - trunc = truncrank(r; by = real) - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) - test_rrule( - config, eigh_trunc2, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end - trunc = trunctol(; rtol = 1 / 2) - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) - test_rrule( - config, eigh_trunc2, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@timedtestset "SVD AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - for alg in (LAPACK_QRIteration(), LAPACK_DivideAndConquer()) - test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔU, ΔS, ΔVᴴ), atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_vals, A, alg ⊢ NoTangent(); - output_tangent = diagview(ΔS), atol, rtol - ) - for r in 1:4:minmn - truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) - dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) - dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, svd_compact, A; - output_tangent = (ΔU, ΔS, ΔVᴴ), atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_compact, A; - output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_vals, A; - output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - for r in 1:4:minmn - trunc = truncrank(r) - ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) - test_rrule( - config, svd_trunc, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_trunc_no_error, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end - trunc = trunctol(; atol = S[1, 1] / 2) - ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) - test_rrule( - config, svd_trunc, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_trunc_no_error, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "Polar AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) - m >= n && - test_rrule(copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) - m <= n && - test_rrule(copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - m >= n && test_rrule( - config, left_polar, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m <= n && test_rrule( - config, right_polar, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "Orth and null with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, left_orth, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, left_orth, A; - fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m >= n && - test_rrule( - config, left_orth, A; - fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - test_rrule( - config, left_null, A; - fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - test_rrule( - config, right_orth, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, right_orth, A; fkwargs = (; alg = :lq), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m <= n && - test_rrule( - config, right_orth, A; fkwargs = (; alg = :polar), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - test_rrule( - config, right_null, A; - fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) +m = 19 +for T in BLASFloats, n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite # doesn't work on GPU + TestSuite.test_chainrules(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/mooncake.jl b/test/mooncake.jl index 760102b1..eed9154d 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -1,597 +1,30 @@ using MatrixAlgebraKit using Test -using TestExtras -using StableRNGs -using Mooncake, Mooncake.TestUtils -using Mooncake: rrule!! -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! - -include("ad_utils.jl") - -make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) -make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA -make_mooncake_tangent(ΔA::Vector{<:Real}) = ΔA -make_mooncake_tangent(ΔA::Matrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔA::Vector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) - -make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...) - -make_mooncake_fdata(x) = make_mooncake_tangent(x) -make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) - -ETs = (Float32, ComplexF64) - -# no `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_copy = make_mooncake_tangent(copy(ΔA)) - A_copy = copy(A) - dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) - copy_pb!!(rdata) - return dA_copy -end - -# `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_copy = make_mooncake_tangent(copy(ΔA)) - A_copy = copy(A) - dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) - copy_pb!!(rdata) - return dA_copy -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_inplace = make_mooncake_tangent(copy(ΔA)) - A_inplace = copy(A) - dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) - end - inplace_pb!!(rdata) - return dA_inplace -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = make_mooncake_tangent(copy(ΔA)) - A_inplace = copy(A) - dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) - end - inplace_pb!!(rdata) - return dA_inplace -end - -""" - test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - -Compare the result of running the *in-place, mutating* function `f!`'s reverse rule -with the result of running its *non-mutating* partner function `f`'s reverse rule. -We must compare directly because many of the mutating functions modify `A` as a -scratch workspace, making testing `f!` against finite differences infeasible. - -The arguments to this function are: - - `f!` the mutating, in-place version of the function (accepts `args` for the function result) - - `f` the non-mutating version of the function (does not accept `args` for the function result) - - `A` the input matrix to factorize - - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) - - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input - - `alg` optional algorithm keyword argument - - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) -""" -function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) - sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - rrule = Mooncake.build_rrule(rvs_interp, sig) - ΔA = randn(rng, eltype(A), size(A)) - - dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) - - dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] - dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] - @test dA_inplace_ ≈ dA_copy_ - return -end - -@timedtestset "QR AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_HouseholderQR(), - LAPACK_HouseholderQR(; positive = true), - ) - @testset "qr_compact" begin - QR = qr_compact(A, alg) - Q = randn(rng, T, m, minmn) - R = randn(rng, T, minmn, n) - Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, A, (Q, R), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) - end - @testset "qr_null" begin - Q, R = qr_compact(A, alg) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - N = qr_null(A, alg) - dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dN, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) - end - @testset "qr_full" begin - Q, R = qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - dQ = make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) - end - @testset "qr_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - Q, R = qr_compact(Ard, alg) - QR = (Q, R) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) - view(ΔR, (r + 1):minmn, :) .= 0 - dQ = make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) - end - end - end -end - -@timedtestset "LQ AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_HouseholderLQ(), - LAPACK_HouseholderLQ(; positive = true), - ) - @testset "lq_compact" begin - L, Q = lq_compact(A, alg) - Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) - end - @testset "lq_null" begin - L, Q = lq_compact(A, alg) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - Nᴴ = randn(rng, T, max(0, n - minmn), n) - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) - end - @testset "lq_full" begin - L, Q = lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) - end - @testset "lq_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - L, Q = lq_compact(Ard, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) - end - end - end -end - -@timedtestset "EIG AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eig_matrix(rng, T, m) - DV = eig_full(A) - D, V = DV - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - - dD = make_mooncake_tangent(ΔD2) - dV = make_mooncake_tangent(ΔV) - dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - # compute the dA corresponding to the above dD, dV - @testset for alg in ( - LAPACK_Simple(), - #LAPACK_Expert(), # expensive on CI - ) - @testset "eig_full" begin - Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) - end - @testset "eig_vals" begin - Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) - end - @testset "eig_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - end -end - -function copy_eigh_full(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_full(A, alg; kwargs...) -end - -function copy_eigh_full!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_full!(A, DV, alg; kwargs...) -end - -function copy_eigh_vals(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals(A, alg; kwargs...) -end - -function copy_eigh_vals!(A, D, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D, alg; kwargs...) -end - -function copy_eigh_trunc(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc(A, alg; kwargs...) -end - -function copy_eigh_trunc!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc!(A, DV, alg; kwargs...) -end - -function copy_eigh_trunc_no_error(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc_no_error(A, alg; kwargs...) -end - -function copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc_no_error!(A, DV, alg; kwargs...) -end - -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) - -@timedtestset "EIGH AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eigh_matrix(rng, T, m) - D, V = eigh_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - dD = make_mooncake_tangent(ΔD2) - dV = make_mooncake_tangent(ΔV) - dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - @testset for alg in ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), - #LAPACK_Bisection(), - #LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI - ) - @testset "eigh_full" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) - end - @testset "eigh_vals" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) - end - @testset "eigh_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - end -end - -@timedtestset "SVD AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), # expensive on CI - ) - @testset "svd_compact" begin - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - dS = make_mooncake_tangent(ΔS2) - dU = make_mooncake_tangent(ΔU) - dVᴴ = make_mooncake_tangent(ΔVᴴ) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), alg) - end - @testset "svd_full" begin - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - ΔUfull = zeros(T, m, m) - ΔSfull = zeros(real(T), m, n) - ΔVᴴfull = zeros(T, n, n) - U, S, Vᴴ = svd_full(A) - view(ΔUfull, :, 1:minmn) .= ΔU - view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ - diagview(ΔSfull)[1:minmn] .= diagview(ΔS2) - dS = make_mooncake_tangent(ΔSfull) - dU = make_mooncake_tangent(ΔUfull) - dVᴴ = make_mooncake_tangent(ΔVᴴfull) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔUfull, ΔSfull, ΔVᴴfull)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) - end - @testset "svd_vals" begin - Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - S = svd_vals(A, alg) - test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, randn(rng, real(T), minmn), alg) - end - @testset "svd_trunc" begin - @testset for r in 1:4:minmn - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) - dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) - dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) - end - @testset "trunctol" begin - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) - dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) - dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) - end - end - end - end -end - -@timedtestset "Polar AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - @testset for alg in PolarViaSVD.( - ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), # expensive on CI - ) - ) - if m >= n - WP = left_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, left_polar!, left_polar, A, WP, (randn(rng, T, m, n), randn(rng, T, n, n)), alg) - elseif m <= n - PWᴴ = right_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, (randn(rng, T, m, m), randn(rng, T, m, n)), alg) - end - end - end -end - -left_orth_qr(X) = left_orth(X; alg = :qr) -left_orth_polar(X) = left_orth(X; alg = :polar) -left_null_qr(X) = left_null(X; alg = :qr) -right_orth_lq(X) = right_orth(X; alg = :lq) -right_orth_polar(X) = right_orth(X; alg = :polar) -right_null_lq(X) = right_null(X; alg = :lq) - -MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) - -@timedtestset "Orth and null with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - VC = left_orth(A) - CVᴴ = right_orth(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, left_orth!, left_orth, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - end - - N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) - test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - - if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - end - - Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) - test_pullbacks_match(rng, ((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if CUDA.functional() + TestSuite.test_mooncake(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + #n == m && TestSuite.test_mooncake(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) + end + #=if AMDGPU.functional() + TestSuite.test_mooncake(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake(Diagonal{T, ROCVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) + end=# # not yet supported + if !is_buildkite + TestSuite.test_mooncake(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + # Pullbacks don't work for a lot of these, waiting on https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/156 + #n == m && TestSuite.test_mooncake(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) end end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 2f3fde50..28833557 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -84,6 +84,8 @@ function instantiate_unitary(T, A::ROCMatrix{<:Complex}, sz) end instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A)))) +include("ad_utils.jl") + include("qr.jl") include("lq.jl") include("polar.jl") @@ -93,5 +95,7 @@ include("eig.jl") include("eigh.jl") include("orthnull.jl") include("svd.jl") +include("mooncake.jl") +include("chainrules.jl") end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl new file mode 100644 index 00000000..70ea1615 --- /dev/null +++ b/test/testsuite/ad_utils.jl @@ -0,0 +1,418 @@ +function remove_svdgauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + ) + gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end +function remove_eiggauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end + +function stabilize_eigvals!(D::AbstractVector) + absD = collect(abs.(D)) + p = invperm(sortperm(collect(absD))) # rank of abs(D) + # account for exact degeneracies in absolute value when having complex conjugate pairs + for i in 1:(length(D) - 1) + if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially + p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones + end + end + n = maximum(p) + # rescale eigenvalues so that they lie on distinct radii in the complex plane + # that are chosen randomly in non-overlapping intervals [10 * k/n, 10 * (k+0.5)/n)] for k=1,...,n + radii = 10 .* ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n + hD = sign.(collect(D)) .* radii[p] + copyto!(D, hD) + return D +end +function make_eig_matrix(T, sz) + A = instantiate_matrix(T, sz) + D, V = eig_full(A) + stabilize_eigvals!(diagview(D)) + Ac = V * D * inv(V) + return (eltype(T) <: Real) ? real(Ac) : Ac +end +function make_eigh_matrix(T, sz) + A = project_hermitian!(instantiate_matrix(T, sz)) + D, V = eigh_full(A) + stabilize_eigvals!(diagview(D)) + return project_hermitian!(V * D * V') +end + +function ad_qr_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + QR = qr_compact(A) + T = eltype(A) + ΔQ = randn!(similar(A, T, m, minmn)) + ΔR = randn!(similar(A, T, minmn, n)) + return QR, (ΔQ, ΔR) +end + +function ad_qr_compact_setup(A::Diagonal) + m, n = size(A) + minmn = min(m, n) + QR = qr_compact(A) + T = eltype(A) + ΔQ = Diagonal(randn!(similar(A.diag, T, m))) + ΔR = Diagonal(randn!(similar(A.diag, T, m))) + return QR, (ΔQ, ΔR) +end + +function ad_qr_null_setup(A) + m, n = size(A) + minmn = min(m, n) + Q, R = qr_compact(A) + T = eltype(A) + ΔN = Q * randn!(similar(A, T, minmn, max(0, m - minmn))) + N = qr_null(A) + return N, ΔN +end + +function ad_qr_full_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + Q, R = qr_full(A) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn!(similar(A, T, m, m)) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn!(similar(A, T, m, n)) + return (Q, R), (ΔQ, ΔR) +end + +ad_qr_full_setup(A::Diagonal) = ad_qr_compact_setup(A) + +function ad_qr_rd_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) + Q, R = qr_compact(Ard) + QR = (Q, R) + ΔQ = randn!(similar(A, T, m, minmn)) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + ΔQ2 .= 0 + ΔR = randn!(similar(A, T, minmn, n)) + view(ΔR, (r + 1):minmn, :) .= 0 + return (Q, R), (ΔQ, ΔR) +end + +function ad_qr_rd_compact_setup(A::Diagonal) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard_ = randn!(similar(A, T, m)) + Ard_[(r + 1):m] .= zero(T) + Ard = Diagonal(Ard_) + Q, R = qr_compact(Ard) + QR = (Q, R) + ΔQ = Diagonal(randn!(similar(A.diag, T, m))) + ΔR = Diagonal(randn!(similar(A.diag, T, m))) + diagview(ΔQ)[(r + 1):m] .= zero(T) + diagview(ΔR)[(r + 1):m] .= zero(T) + return (Q, R), (ΔQ, ΔR) +end + +function ad_lq_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + LQ = lq_compact(A) + T = eltype(A) + ΔL = randn!(similar(A, T, m, minmn)) + ΔQ = randn!(similar(A, T, minmn, n)) + return LQ, (ΔL, ΔQ) +end +ad_lq_compact_setup(A::Diagonal) = ad_qr_compact_setup(A) + +function ad_lq_null_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + L, Q = lq_compact(A) + ΔNᴴ = randn!(similar(A, T, max(0, n - minmn), minmn)) * Q + Nᴴ = randn!(similar(A, T, max(0, n - minmn), n)) + return Nᴴ, ΔNᴴ +end + +function ad_lq_full_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + L, Q = lq_full(A) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn!(similar(A, T, n, n)) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + ΔQ2 .= (ΔQ2 * Q1') * Q1 + ΔL = randn!(similar(A, T, m, n)) + return (L, Q), (ΔL, ΔQ) +end +ad_lq_full_setup(A::Diagonal) = ad_qr_full_setup(A) + +function ad_lq_rd_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) + L, Q = lq_compact(Ard) + ΔL = randn!(similar(A, T, m, minmn)) + ΔQ = randn!(similar(A, T, minmn, n)) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + return (L, Q), (ΔL, ΔQ) +end +ad_lq_rd_compact_setup(A::Diagonal) = ad_qr_rd_compact_setup(A) + +function ad_eig_full_setup(A) + m, n = size(A) + T = eltype(A) + DV = eig_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn!(similar(A, complex(T), m, m)) + ΔV = remove_eiggauge_dependence!(ΔV, D, V) + ΔD = randn!(similar(A, complex(T), m, m)) + ΔD2 = Diagonal(randn!(similar(A, complex(T), m))) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eig_full_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + DV = eig_full(A) + D, V = DV + ΔV = Diagonal(randn!(similar(A.diag, T, m))) + ΔV = remove_eiggauge_dependence!(ΔV, D, V) + ΔD = Diagonal(randn!(similar(A.diag, T, m))) + ΔD2 = Diagonal(randn!(similar(A.diag, T, m))) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eig_vals_setup(A) + m, n = size(A) + T = eltype(A) + D = eig_vals(A) + ΔD = randn!(similar(A, complex(T), m)) + return D, ΔD +end + +function ad_eig_vals_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + D = eig_vals(A) + ΔD = randn!(similar(A.diag)) + return D, ΔD +end + +function ad_eig_trunc_setup(A, truncalg) + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + Dtrunc = Diagonal(diagview(DV[1])[ind]) + Vtrunc = DV[2][:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2V[1])[ind]) + ΔVtrunc = ΔDV[2][:, ind] + return DV, (Dtrunc, Vtrunc), ΔD2V, (ΔDtrunc, ΔVtrunc) +end + +function ad_eigh_full_setup(A) + m, n = size(A) + T = eltype(A) + DV = eigh_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn!(similar(A, T, m, m)) + ΔV = remove_eighgauge_dependence!(ΔV, D, V) + ΔD = randn!(similar(A, real(T), m, m)) + ΔD2 = Diagonal(randn!(similar(A, real(T), m))) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eigh_vals_setup(A) + m, n = size(A) + T = eltype(A) + D = eigh_vals(A) + ΔD = randn!(similar(A, real(T), m)) + return D, ΔD +end + +function ad_eigh_trunc_setup(A, truncalg) + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + Dtrunc = Diagonal(diagview(DV[1])[ind]) + Vtrunc = DV[2][:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2V[1])[ind]) + ΔVtrunc = ΔDV[2][:, ind] + return DV, (Dtrunc, Vtrunc), ΔD2V, (ΔDtrunc, ΔVtrunc) +end + +function ad_svd_compact_setup(A) + m, n = size(A) + T = eltype(A) + minmn = min(m, n) + ΔU = randn!(similar(A, T, m, minmn)) + ΔS = randn!(similar(A, real(T), minmn, minmn)) + ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) + ΔVᴴ = randn!(similar(A, T, minmn, n)) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ) +end + +function ad_svd_compact_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + minmn = min(m, n) + ΔU = randn!(similar(A.diag, T, m, n)) + ΔS = Diagonal(randn!(similar(A.diag, real(T), minmn))) + ΔS2 = Diagonal(randn!(similar(A.diag, real(T), minmn))) + ΔVᴴ = randn!(similar(A.diag, T, m, n)) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ) +end + +function ad_svd_full_setup(A) + m, n = size(A) + T = eltype(A) + minmn = min(m, n) + ΔU = randn!(similar(A, T, m, minmn)) + ΔS = randn!(similar(A, real(T), minmn, minmn)) + ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) + ΔVᴴ = randn!(similar(A, T, minmn, n)) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ΔUfull = similar(A, T, m, m) + ΔUfull .= zero(T) + ΔSfull = similar(A, real(T), m, n) + ΔSfull .= zero(real(T)) + ΔVᴴfull = similar(A, T, n, n) + ΔVᴴfull .= zero(T) + U, S, Vᴴ = svd_full(A) + view(ΔUfull, :, 1:minmn) .= ΔU + view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ + diagview(ΔSfull)[1:minmn] .= diagview(ΔS2) + return (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull) +end + +ad_svd_full_setup(A::Diagonal) = ad_svd_compact_setup(A) + +function ad_svd_vals_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + S = svd_vals(A) + ΔS = randn!(similar(A, real(T), minmn)) + return S, ΔS +end + +function ad_svd_trunc_setup(A, truncalg) + USVᴴ, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(USVᴴ[2]), truncalg.trunc) + Strunc = Diagonal(diagview(USVᴴ[2])[ind]) + Utrunc = USVᴴ[1][:, ind] + Vᴴtrunc = USVᴴ[3][ind, :] + ΔStrunc = Diagonal(diagview(ΔUS2Vᴴ[2])[ind]) + ΔUtrunc = ΔUSVᴴ[1][:, ind] + ΔVᴴtrunc = ΔUSVᴴ[3][ind, :] + return USVᴴ, ΔUS2Vᴴ, (ΔUtrunc, ΔStrunc, ΔVᴴtrunc) +end + +function ad_left_polar_setup(A) + m, n = size(A) + T = eltype(A) + WP = left_polar(A) + ΔWP = (randn!(similar(A, T, m, n)), randn!(similar(A, T, n, n))) + return WP, ΔWP +end + +function ad_left_polar_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + WP = left_polar(A) + ΔWP = (Diagonal(randn!(similar(A.diag))), randn!(similar(WP[2]))) + return WP, ΔWP +end + +function ad_right_polar_setup(A) + m, n = size(A) + T = eltype(A) + PWᴴ = right_polar(A) + ΔPWᴴ = (randn!(similar(A, T, m, m)), randn!(similar(A, T, m, n))) + return PWᴴ, ΔPWᴴ +end +function ad_right_polar_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + PWᴴ = right_polar(A) + ΔPWᴴ = (randn!(similar(PWᴴ[1])), Diagonal(randn!(similar(A.diag)))) + return PWᴴ, ΔPWᴴ +end + +function ad_left_orth_setup(A) + m, n = size(A) + T = eltype(A) + VC = left_orth(A) + ΔVC = (randn!(similar(A, T, size(VC[1])...)), randn!(similar(A, T, size(VC[2])...))) + return VC, ΔVC +end +function ad_left_orth_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + VC = left_orth(A) + ΔVC = (Diagonal(randn!(similar(A.diag, T, m))), Diagonal(randn!(similar(A.diag, T, m)))) + return VC, ΔVC +end + +function ad_left_null_setup(A) + m, n = size(A) + T = eltype(A) + N = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) + ΔN = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) + return N, ΔN +end + +function ad_right_orth_setup(A) + m, n = size(A) + T = eltype(A) + CVᴴ = right_orth(A) + ΔCVᴴ = (randn!(similar(A, T, size(CVᴴ[1])...)), randn!(similar(A, T, size(CVᴴ[2])...))) + return CVᴴ, ΔCVᴴ +end +ad_right_orth_setup(A::Diagonal) = ad_left_orth_setup(A) + +function ad_right_null_setup(A) + m, n = size(A) + T = eltype(A) + Nᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] + ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] + return Nᴴ, ΔNᴴ +end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl new file mode 100644 index 00000000..625ff6a3 --- /dev/null +++ b/test/testsuite/chainrules.jl @@ -0,0 +1,612 @@ +using MatrixAlgebraKit +using ChainRulesCore, ChainRulesTestUtils, Zygote +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +for f in + ( + :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, + :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, + :eig_trunc_no_error, :eigh_trunc_no_error, + :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, + :left_polar, :right_polar, + ) + copy_f = Symbol(:cr_copy_, f) + f! = Symbol(f, '!') + _hermitian = startswith(string(f), "eigh") + @eval begin + function $copy_f(input, alg) + if $_hermitian + input = (input + input') / 2 + end + return $f(input, alg) + end + function ChainRulesCore.rrule(::typeof($copy_f), input, alg) + output = MatrixAlgebraKit.initialize_output($f!, input, alg) + if $_hermitian + input = (input + input') / 2 + else + input = copy(input) + end + output, pb = ChainRulesCore.rrule($f!, input, output, alg) + return output, x -> (NoTangent(), pb(x)[2], NoTangent()) + end + end +end + +function test_chainrules(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Chainrules AD $summary_str" begin + test_chainrules_qr(T, sz; kwargs...) + test_chainrules_lq(T, sz; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_chainrules_eig(T, sz; kwargs...) + test_chainrules_eigh(T, sz; kwargs...) + end + test_chainrules_svd(T, sz; kwargs...) + test_chainrules_polar(T, sz; kwargs...) + test_chainrules_orthnull(T, sz; kwargs...) + end +end + +function test_chainrules_qr( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "QR ChainRules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_qr_algorithm(A) + @testset "qr_compact" begin + QR, ΔQR = ad_qr_compact_setup(A) + ΔQ, ΔR = ΔQR + test_rrule( + cr_copy_qr_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "qr_null" begin + N, ΔN = ad_qr_null_setup(A) + test_rrule( + cr_copy_qr_null, A, alg ⊢ NoTangent(); + output_tangent = ΔN, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_null, A; + fkwargs = (; positive = true), output_tangent = ΔN, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m, n = size(A) + end + @testset "qr_full" begin + QR, ΔQR = ad_qr_full_setup(A) + test_rrule( + cr_copy_qr_full, A, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_full, A; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m, n = size(A) + end + @testset "qr_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + QR, ΔQR = ad_qr_rd_compact_setup(Ard) + ΔQ, ΔR = ΔQR + test_rrule( + cr_copy_qr_compact, Ard, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_compact, Ard; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_lq( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_lq_algorithm(A) + @testset "lq_compact" begin + LQ, ΔLQ = ad_lq_compact_setup(A) + ΔL, ΔQ = ΔLQ + test_rrule( + cr_copy_lq_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔL, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_null" begin + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + test_rrule( + cr_copy_lq_null, A, alg ⊢ NoTangent(); + output_tangent = ΔNᴴ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_null, A; + fkwargs = (; positive = true), output_tangent = ΔNᴴ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_full" begin + LQ, ΔLQ = ad_lq_full_setup(A) + test_rrule( + cr_copy_lq_full, A, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_full, A; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rd_compact_setup(Ard) + test_rrule( + cr_copy_lq_compact, Ard, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_compact, Ard; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_eig( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Chainrules AD rules $summary_str" begin + A = make_eig_matrix(T, sz) + m = size(A, 1) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_eig_algorithm(A) + @testset "eig_full" begin + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + ΔD, ΔV = ΔDV + test_rrule( + cr_copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + ) + test_rrule( + cr_copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + ) + test_rrule( + config, eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔD2V, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eig_vals" begin + D, ΔD = ad_eig_vals_setup(A) + test_rrule( + cr_copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + ) + test_rrule( + config, eig_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔD, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + test_rrule( + cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + test_rrule( + cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function test_chainrules_eigh( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH ChainRules AD rules $summary_str" begin + A = make_eigh_matrix(T, sz) + m = size(A, 1) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_eigh_algorithm(A) + # copy_eigh_xxxx includes a projector onto the Hermitian part of the matrix + @testset "eigh_full" begin + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + ΔD, ΔV = ΔDV + test_rrule( + cr_copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + ) + test_rrule( + cr_copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + ) + # eigh_full does not include a projector onto the Hermitian part of the matrix + test_rrule( + config, eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD2V, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eigh_vals" begin + D, ΔD = ad_eigh_vals_setup(A) + test_rrule( + cr_copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + ) + test_rrule( + config, eigh_vals ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eigh_trunc" begin + eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) + eigh_trunc_no_error2(A; kwargs...) = eigh_trunc_no_error(Matrix(Hermitian(A)); kwargs...) + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + test_rrule( + cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eigh_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = truncrank(r; by = real) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), trunc) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + test_rrule( + config, eigh_trunc2, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_trunc_no_error2, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔDVtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + D, ΔD = ad_eigh_vals_setup(A / 2) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + test_rrule( + cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eigh_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = trunctol(; rtol = 1 / 2) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + test_rrule( + config, eigh_trunc2, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_trunc_no_error2, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔDVtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_svd( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_svd_algorithm(A) + @testset "svd_compact" begin + USV, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(A) + test_rrule( + cr_copy_svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUS2Vᴴ, atol = atol, rtol = rtol + ) + test_rrule( + config, svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUS2Vᴴ, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "svd_vals" begin + S, ΔS = ad_svd_vals_setup(A) + test_rrule( + cr_copy_svd_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔS, atol, rtol + ) + test_rrule( + config, svd_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔS, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "svd_trunc" begin + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(alg, truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + test_rrule( + cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔUSVᴴtrunc, + atol = atol, rtol = rtol + ) + U, S, Vᴴ = USVᴴ + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), ΔUSVᴴtrunc) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = truncrank(r) + ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_trunc_no_error, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔUSVᴴtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + S, ΔS = ad_svd_vals_setup(A) + truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + test_rrule( + cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔUSVᴴtrunc, atol = atol, rtol = rtol + ) + U, S, Vᴴ = USVᴴ + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), ΔUSVᴴtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = trunctol(; atol = S[1, 1] / 2) + ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_trunc_no_error, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔUSVᴴtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_polar_algorithm(A) + @testset "left_polar" begin + if m >= n + test_rrule(cr_copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule( + config, left_polar, A, alg ⊢ NoTangent(); + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + @testset "right_polar" begin + if m <= n + test_rrule(cr_copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule( + config, right_polar, A, alg ⊢ NoTangent(); + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + end +end + +function test_chainrules_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + N, ΔN = ad_left_null_setup(A) + Nᴴ, ΔNᴴ = ad_right_null_setup(A) + test_rrule( + config, left_orth, A; + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, left_orth, A; + fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m >= n && + test_rrule( + config, left_orth, A; + fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, left_null, A; + fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + + test_rrule( + config, right_orth, A; + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, right_orth, A; fkwargs = (; alg = :lq), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m <= n && + test_rrule( + config, right_orth, A; fkwargs = (; alg = :polar), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, right_null, A; + fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end +end diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl new file mode 100644 index 00000000..aac46182 --- /dev/null +++ b/test/testsuite/mooncake.jl @@ -0,0 +1,483 @@ +using TestExtras +using MatrixAlgebraKit +using Mooncake, Mooncake.TestUtils +using Mooncake: rrule!! +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc +using LinearAlgebra: BlasFloat +using GenericLinearAlgebra + +function mc_copy_eigh_full(A; kwargs...) + A = (A + A') / 2 + return eigh_full(A; kwargs...) +end + +function mc_copy_eigh_full!(A, DV; kwargs...) + A = (A + A') / 2 + return eigh_full!(A, DV; kwargs...) +end + +function mc_copy_eigh_vals(A; kwargs...) + A = (A + A') / 2 + return eigh_vals(A; kwargs...) +end + +function mc_copy_eigh_vals!(A, D; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D; kwargs...) +end + +function mc_copy_eigh_trunc(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc(A, alg; kwargs...) +end + +function mc_copy_eigh_trunc!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc!(A, DV, alg; kwargs...) +end + +function mc_copy_eigh_trunc_no_error(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error(A, alg; kwargs...) +end + +function mc_copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV, alg; kwargs...) +end + +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) + +make_mooncake_tangent(ΔAelem::T) where {T <: Real} = ΔAelem +make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) +make_mooncake_tangent(ΔA::AbstractMatrix{<:Real}) = ΔA +make_mooncake_tangent(ΔA::AbstractVector{<:Real}) = ΔA +make_mooncake_tangent(ΔA::AbstractMatrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔA::AbstractVector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) + +make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...) + +make_mooncake_fdata(x) = make_mooncake_tangent(x) +make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) +make_mooncake_fdata(x::Tuple) = map(make_mooncake_fdata, x) + +# no `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_copy = make_mooncake_fdata(copy(ΔA)) + A_copy = copy(A) + dargs_copy = make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) + copy_pb!!(rdata) + return dA_copy +end + +# `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_copy = make_mooncake_fdata(copy(ΔA)) + A_copy = copy(A) + dargs_copy = make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) + copy_pb!!(rdata) + return dA_copy +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_inplace = make_mooncake_fdata(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + end + inplace_pb!!(rdata) + return dA_inplace +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = make_mooncake_fdata(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + end + inplace_pb!!(rdata) + return dA_inplace +end + +""" + test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + +Compare the result of running the *in-place, mutating* function `f!`'s reverse rule +with the result of running its *non-mutating* partner function `f`'s reverse rule. +We must compare directly because many of the mutating functions modify `A` as a +scratch workspace, making testing `f!` against finite differences infeasible. + +The arguments to this function are: + - `f!` the mutating, in-place version of the function (accepts `args` for the function result) + - `f` the non-mutating version of the function (does not accept `args` for the function result) + - `A` the input matrix to factorize + - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) + - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input + - `alg` optional algorithm keyword argument + - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) +""" +function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) + sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + rrule = Mooncake.build_rrule(rvs_interp, sig) + ΔA = isa(A, Diagonal) ? Diagonal(randn!(similar(A.diag))) : randn!(similar(A)) + + dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + + dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] + dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] + @test dA_inplace_ ≈ dA_copy_ + return +end + +function test_mooncake(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake AD $summary_str" begin + test_mooncake_qr(T, sz; kwargs...) + test_mooncake_lq(T, sz; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_mooncake_eig(T, sz; kwargs...) + test_mooncake_eigh(T, sz; kwargs...) + end + test_mooncake_svd(T, sz; kwargs...) + test_mooncake_polar(T, sz; kwargs...) + test_mooncake_orthnull(T, sz; kwargs...) + end +end + +function test_mooncake_qr( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "QR Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "qr_compact" begin + QR, ΔQR = ad_qr_compact_setup(A) + Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(qr_compact!, qr_compact, A, QR, ΔQR) + end + @testset "qr_null" begin + N, ΔN = ad_qr_null_setup(A) + dN = make_mooncake_tangent(copy(ΔN)) + Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol = atol, rtol = rtol) + test_pullbacks_match(qr_null!, qr_null, A, N, ΔN) + end + @testset "qr_full" begin + QR, ΔQR = ad_qr_full_setup(A) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + test_pullbacks_match(qr_full!, qr_full, A, QR, ΔQR) + end + @testset "qr_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + QR, ΔQR = ad_qr_rd_compact_setup(Ard) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) + end + end +end + +function test_mooncake_lq( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "lq_compact" begin + LQ, ΔLQ = ad_lq_compact_setup(A) + Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(lq_compact!, lq_compact, A, LQ, ΔLQ) + end + @testset "lq_null" begin + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(lq_null!, lq_null, A, Nᴴ, ΔNᴴ) + end + @testset "lq_full" begin + LQ, ΔLQ = ad_lq_full_setup(A) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + test_pullbacks_match(lq_full!, lq_full, A, LQ, ΔLQ) + end + @testset "lq_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rd_compact_setup(Ard) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) + end + end +end + +function test_mooncake_eig( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Mooncake AD rules $summary_str" begin + A = make_eig_matrix(T, sz) + m = size(A, 1) + @testset "eig_full" begin + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol = atol, rtol = rtol) + test_pullbacks_match(eig_full!, eig_full, A, DV, ΔD2V) + end + @testset "eig_vals" begin + D, ΔD = ad_eig_vals_setup(A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dD, atol = atol, rtol = rtol) + test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD) + end + if T <: Number # not a GPU array + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg) + end + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg) + end + end + end +end + +function test_mooncake_eigh( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH Mooncake AD rules $summary_str" begin + A = make_eigh_matrix(T, sz) + m = size(A, 1) + @testset "eigh_full" begin + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(mc_copy_eigh_full!, mc_copy_eigh_full, A, DV, ΔD2V) + end + @testset "eigh_vals" begin + D, ΔD = ad_eigh_vals_setup(A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD) + end + if T <: Number + @testset "eigh_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg) + end + D = eigh_vals(A / 2) + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg) + end + end + end +end + +function test_mooncake_svd( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) + @testset "svd_compact" begin + USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_full" begin + USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_vals" begin + S, ΔS = ad_svd_vals_setup(A) + Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS) + end + if T <: Number # not a GPU array + @testset "svd_trunc" begin + S, ΔS = ad_svd_vals_setup(A) + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg) + end + @testset "trunctol" begin + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + ϵ = zero(real(T)) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg) + end + end + end + end +end + +function test_mooncake_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + @testset "left_polar" begin + if m >= n + WP, ΔWP = ad_left_polar_setup(A) + Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(left_polar!, left_polar, A, WP, ΔWP) + end + end + @testset "right_polar" begin + if m <= n + PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) + Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) + end + end + end +end + +left_orth_qr(X) = left_orth(X; alg = :qr) +left_orth_polar(X) = left_orth(X; alg = :polar) +left_null_qr(X) = left_null(X; alg = :qr) +right_orth_lq(X) = right_orth(X; alg = :lq) +right_orth_polar(X) = right_orth(X; alg = :polar) +right_null_lq(X) = right_null(X; alg = :lq) + +MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) + +function test_mooncake_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + VC, ΔVC = ad_left_orth_setup(A) + CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) + Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(left_orth!, left_orth, A, VC, ΔVC) + Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) + + Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) + if m >= n + Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) + end + + N, ΔN = ad_left_null_setup(A) + dN = make_mooncake_tangent(ΔN) + Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) + test_pullbacks_match(((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) + + Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) + + if m <= n + Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) + end + + Nᴴ, ΔNᴴ = ad_right_null_setup(A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) + test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) + end +end From 239686ed2c685541a91d55df8cc1347a443edaa3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 19:33:30 +0100 Subject: [PATCH 02/11] Reenable Diagonals --- test/mooncake.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/mooncake.jl b/test/mooncake.jl index eed9154d..40d81be2 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -24,7 +24,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) end=# # not yet supported if !is_buildkite TestSuite.test_mooncake(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) - # Pullbacks don't work for a lot of these, waiting on https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/156 - #n == m && TestSuite.test_mooncake(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) + n == m && TestSuite.test_mooncake(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) end end From df74a86dbb9b5006ce46e275b7312ad05e84179f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:55:16 +0100 Subject: [PATCH 03/11] Update ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl Co-authored-by: Lukas Devos --- .../MatrixAlgebraKitMooncakeExt.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 75701ba5..872c29cd 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -27,15 +27,7 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu return Ac_dAc, copy_input_pb end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any} -function Mooncake.rrule!!(::CoDual{typeof(initialize_output)}, f_df::CoDual, A_dA::CoDual, alg_dalg::CoDual) - output = initialize_output(Mooncake.primal(f_df), Mooncake.primal(A_dA), Mooncake.primal(alg_dalg)) - output_doutput = Mooncake.zero_fcodual(output) - initialize_output_pb(::NoRData) = (NoRData(), NoRData(), NoRData(), NoRData()) - return output_doutput, initialize_output_pb -end - - +@zero_derivative Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any} # two-argument in-place factorizations like LQ, QR, EIG for (f!, f, pb, adj) in ( (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), From 1629c0c1b2e77fb73f17d5289e6c85243e715154 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:55:29 +0100 Subject: [PATCH 04/11] Update src/pullbacks/eig.jl Co-authored-by: Lukas Devos --- src/pullbacks/eig.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 5c79f9f5..92a7c986 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -39,8 +39,7 @@ function eig_pullback!( VᴴΔV = fill!(similar(V), 0) indV = axes(V, 2)[ind] length(indV) == pV || throw(DimensionMismatch()) - VᴴΔV[:, indV] .= V' * ΔV - #mul!(view(VᴴΔV, :, indV), V', ΔV) + mul!(view(VᴴΔV, :, indV), V', ΔV) mask = abs.(transpose(D) .- D) .< degeneracy_atol if isa(ΔA, Array) From 111cc89d6e0940aa1e91a3df3c820439d767b358 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:55:43 +0100 Subject: [PATCH 05/11] Update test/testsuite/ad_utils.jl Co-authored-by: Lukas Devos --- test/testsuite/ad_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 70ea1615..b55e9d18 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -128,7 +128,7 @@ function ad_qr_rd_compact_setup(A::Diagonal) T = eltype(A) r = minmn - 5 Ard_ = randn!(similar(A, T, m)) - Ard_[(r + 1):m] .= zero(T) + zero!(view(Ard_, (r + 1):m) Ard = Diagonal(Ard_) Q, R = qr_compact(Ard) QR = (Q, R) From 86777c4c0033dc364a3de9c7119131fdc894700a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:56:39 +0100 Subject: [PATCH 06/11] Update test/testsuite/ad_utils.jl Co-authored-by: Lukas Devos --- test/testsuite/ad_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index b55e9d18..9ac167ad 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -132,8 +132,8 @@ function ad_qr_rd_compact_setup(A::Diagonal) Ard = Diagonal(Ard_) Q, R = qr_compact(Ard) QR = (Q, R) - ΔQ = Diagonal(randn!(similar(A.diag, T, m))) - ΔR = Diagonal(randn!(similar(A.diag, T, m))) + ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) + ΔR = Diagonal(randn!(similar(diagview(A), T, m))) diagview(ΔQ)[(r + 1):m] .= zero(T) diagview(ΔR)[(r + 1):m] .= zero(T) return (Q, R), (ΔQ, ΔR) From 3708f83501fd3744f2bc21f515e2532d07deee0f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:56:50 +0100 Subject: [PATCH 07/11] Update test/testsuite/ad_utils.jl Co-authored-by: Lukas Devos --- test/testsuite/ad_utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index 9ac167ad..a92988a6 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -131,7 +131,6 @@ function ad_qr_rd_compact_setup(A::Diagonal) zero!(view(Ard_, (r + 1):m) Ard = Diagonal(Ard_) Q, R = qr_compact(Ard) - QR = (Q, R) ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) ΔR = Diagonal(randn!(similar(diagview(A), T, m))) diagview(ΔQ)[(r + 1):m] .= zero(T) From c3be1425bb13b4764b020ba9fc58f2a1e74e2596 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:57:25 +0100 Subject: [PATCH 08/11] Update test/testsuite/ad_utils.jl Co-authored-by: Lukas Devos --- test/testsuite/ad_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index a92988a6..ecc6f480 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -133,8 +133,8 @@ function ad_qr_rd_compact_setup(A::Diagonal) Q, R = qr_compact(Ard) ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) ΔR = Diagonal(randn!(similar(diagview(A), T, m))) - diagview(ΔQ)[(r + 1):m] .= zero(T) - diagview(ΔR)[(r + 1):m] .= zero(T) + zero!(view(diagview(ΔQ), (r + 1):m) + zero!(view(diagview(ΔR), (r + 1):m) return (Q, R), (ΔQ, ΔR) end From c4627bc2ce398c8d27824647170eacca9136eab0 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 22:59:14 +0100 Subject: [PATCH 09/11] Update test/testsuite/ad_utils.jl Co-authored-by: Lukas Devos --- test/testsuite/ad_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index ecc6f480..a050f589 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -116,7 +116,7 @@ function ad_qr_rd_compact_setup(A) Q1 = view(Q, 1:m, 1:r) Q2 = view(Q, 1:m, (r + 1):minmn) ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 + zero!(ΔQ2) ΔR = randn!(similar(A, T, minmn, n)) view(ΔR, (r + 1):minmn, :) .= 0 return (Q, R), (ΔQ, ΔR) From 6f727544eb1737d9b7fc564825160059907f4deb Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 21 Jan 2026 10:17:00 +0100 Subject: [PATCH 10/11] Refactor checks into functions, rd -> rank_deficient --- src/common/defaults.jl | 1 + src/pullbacks/eig.jl | 24 +++++----- src/pullbacks/eigh.jl | 24 +++++----- src/pullbacks/lq.jl | 78 ++++++++++++++++++++------------ src/pullbacks/polar.jl | 6 +-- src/pullbacks/qr.jl | 86 ++++++++++++++++++++---------------- src/pullbacks/svd.jl | 24 +++++----- test/testsuite/ad_utils.jl | 14 +++--- test/testsuite/chainrules.jl | 2 +- test/testsuite/mooncake.jl | 4 +- 10 files changed, 149 insertions(+), 114 deletions(-) diff --git a/src/common/defaults.jl b/src/common/defaults.jl index dad16376..bc4160a1 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -34,6 +34,7 @@ default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4) Default tolerance for deciding what values should be considered equal to 0. """ default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4) +default_pullback_rank_atol(A::Diagonal) = default_pullback_rank_atol(diagview(A)) """ default_hermitian_tol(A) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 92a7c986..fbbdee8f 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -1,3 +1,12 @@ +function check_eig_cotangents(D, VᴴΔV; degeneracy_atol::Real = default_pullback_rank_atol(D), gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV)) + mask = abs.(transpose(D) .- D) .< degeneracy_atol + # not GPU friendly... + Δgauge = norm(view(VᴴΔV, mask)) + Δgauge ≤ gauge_atol || + @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ eig_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; @@ -40,14 +49,7 @@ function eig_pullback!( indV = axes(V, 2)[ind] length(indV) == pV || throw(DimensionMismatch()) mul!(view(VᴴΔV, :, indV), V', ΔV) - - mask = abs.(transpose(D) .- D) .< degeneracy_atol - if isa(ΔA, Array) - # not GPU friendly... - Δgauge = norm(view(VᴴΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - end + check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) @@ -132,10 +134,7 @@ function eig_trunc_pullback!( if !iszerotangent(ΔV) (n, p) == size(ΔV) || throw(DimensionMismatch()) VᴴΔV = V' * ΔV - mask = abs.(transpose(D) .- D) .< degeneracy_atol - Δgauge = norm(view(VᴴΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) ΔVperp = ΔV - V * inv(G) * VᴴΔV VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) @@ -194,7 +193,6 @@ function eig_vals_pullback!( ΔA, A, DV, ΔD, ind = Colon(); degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), ) - ΔDV = (diagonal(ΔD), nothing) return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol) end diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 11171685..506b5ca3 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -1,3 +1,15 @@ +function check_eigh_cotangents( + D, aVᴴΔV; + degeneracy_atol::Real = default_pullback_rank_atol(D), + gauge_atol::Real = default_pullback_gauge_atol(aVᴴΔV) + ) + mask = abs.(D' .- D) .< degeneracy_atol + Δgauge = norm(view(aVᴴΔV, mask)) + Δgauge ≤ gauge_atol || + @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ eigh_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; @@ -41,12 +53,7 @@ function eigh_pullback!( length(indV) == pV || throw(DimensionMismatch()) mul!(view(VᴴΔV, :, indV), V', ΔV) aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work - - mask = abs.(D' .- D) .< degeneracy_atol - Δgauge = norm(view(aVᴴΔV, mask)) - Δgauge ≤ gauge_atol || - @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - + check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol) aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) if !iszerotangent(ΔDmat) @@ -120,10 +127,7 @@ function eigh_trunc_pullback!( VᴴΔV = V' * ΔV aVᴴΔV = project_antihermitian!(VᴴΔV) - mask = abs.(D' .- D) .< degeneracy_atol - Δgauge = norm(view(aVᴴΔV, mask)) - Δgauge ≤ gauge_atol || - @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol) aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index cc8ce1fd..0998dfa8 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -1,3 +1,42 @@ +function check_lq_cotangents( + L, Q, ΔL, ΔQ, minmn::Int, p::Int; + gauge_atol::Real = default_pullback_gauge_atol(ΔQ) + ) + if minmn > p # case where A is rank-deficient + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + # in this case the number Householder reflections will + # change upon small variations, and all of the remaining + # columns of ΔQ should be zero for a gauge-invariant + # cost function + ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) + Δgauge = max(Δgauge, norm(ΔQ2)) + end + if !iszerotangent(ΔL) + ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn) + Δgauge = max(Δgauge, norm(ΔL22)) + end + Δgauge ≤ gauge_atol || + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end + return +end + +function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(Q1)) + # in the case where A is full rank, but there are more columns in Q than in A + # (the case of `lq_full`), there is gauge-invariant information in the + # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary + # matrix. As the number of Householder reflections is in fixed in the full rank + # case, Q is expected to rotate smoothly (we might even be able to predict) also + # how the full Q2 will change, but this we omit for now, and we consider + # Q2' * ΔQ2 as a gauge dependent quantity. + Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) + Δgauge ≤ gauge_atol || + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + + """ lq_pullback!( ΔA, A, LQ, ΔLQ; @@ -36,25 +75,7 @@ function lq_pullback!( ΔA1 = view(ΔA, 1:p, :) ΔA2 = view(ΔA, (p + 1):m, :) - if isa(ΔA, Array) # not GPU friendly - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) - Δgauge = max(Δgauge, norm(ΔQ2, Inf)) - end - if !iszerotangent(ΔL) - ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn) - Δgauge = max(Δgauge, norm(ΔL22, Inf)) - end - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - end - end + check_lq_cotangents(L, Q, ΔL, ΔQ, minmn, p; gauge_atol) ΔQ̃ = zero!(similar(Q, (p, n))) if !iszerotangent(ΔQ) @@ -71,11 +92,7 @@ function lq_pullback!( # how the full Q2 will change, but this we omit for now, and we consider # Q2' * ΔQ2 as a gauge dependent quantity. ΔQ2Q1ᴴ = ΔQ2 * Q1' - if isa(ΔA, Array) # not GPU friendly - Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - end + check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol) ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1) end end @@ -108,6 +125,14 @@ function lq_pullback!( return ΔA end +function check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)) + aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ') + Δgauge = norm(aNᴴΔN) + Δgauge ≤ gauge_atol || + @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ lq_null_pullback!( ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ; @@ -124,10 +149,7 @@ function lq_null_pullback!( gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ) ) if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0 - aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ') - Δgauge = norm(aNᴴΔN) - Δgauge ≤ gauge_atol || - @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol) L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here? X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ') ΔA = mul!(ΔA, X, Nᴴ, -1, 1) diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index 0caed08b..1c6de509 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -22,7 +22,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) if !iszerotangent(ΔW) ΔWP = ΔW / P WdΔWP = W' * ΔWP - ΔWP .-= W * WdΔWP + ΔWP = mul!(ΔWP, W, WdΔWP, -1, 1) ΔA .+= ΔWP end return ΔA @@ -48,11 +48,11 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs... !iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1) C = sylvester(P, P, M' - M) C .+= ΔP - ΔA .+= C * Wᴴ + ΔA = mul!(ΔA, C, Wᴴ, 1, 1) if !iszerotangent(ΔWᴴ) PΔWᴴ = P \ ΔWᴴ PΔWᴴW = PΔWᴴ * Wᴴ' - PΔWᴴ .-= PΔWᴴW * Wᴴ + PΔWᴴ = mul!(PΔWᴴ, PΔWᴴW, Wᴴ, -1, 1) ΔA .+= PΔWᴴ end return ΔA diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index a054ae21..67c6fbba 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -1,3 +1,38 @@ +function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ)) + if minmn > p # case where A is rank-deficient + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + # in this case the number Householder reflections will + # change upon small variations, and all of the remaining + # columns of ΔQ should be zero for a gauge-invariant + # cost function + ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) + Δgauge = max(Δgauge, norm(ΔQ2, Inf)) + end + if !iszerotangent(ΔR) + ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n) + Δgauge = max(Δgauge, norm(ΔR22, Inf)) + end + Δgauge ≤ gauge_atol || + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end + return +end + +function check_qr_full_cotangents(Q1, ΔQ2, ΔR, Q1dΔQ2, ; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2)) + # in the case where A is full rank, but there are more columns in Q than in A + # (the case of `qr_full`), there is gauge-invariant information in the + # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary + # matrix. As the number of Householder reflections is in fixed in the full rank + # case, Q is expected to rotate smoothly (we might even be able to predict) also + # how the full Q2 will change, but this we omit for now, and we consider + # Q2' * ΔQ2 as a gauge dependent quantity. + Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) + Δgauge ≤ gauge_atol || + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ qr_pullback!( ΔA, A, QR, ΔQR; @@ -37,25 +72,7 @@ function qr_pullback!( ΔA1 = view(ΔA, :, 1:p) ΔA2 = view(ΔA, :, (p + 1):n) - if isa(ΔA, Array) # not GPU friendly - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - Δgauge = max(Δgauge, norm(ΔQ2, Inf)) - end - if !iszerotangent(ΔR) - ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n) - Δgauge = max(Δgauge, norm(ΔR22, Inf)) - end - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - end - end + check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol) ΔQ̃ = zero!(similar(Q, (m, p))) if !iszerotangent(ΔQ) @@ -63,19 +80,8 @@ function qr_pullback!( if p < size(Q, 2) Q2 = view(Q, :, (p + 1):size(Q, 2)) ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - # in the case where A is full rank, but there are more columns in Q than in A - # (the case of `qr_full`), there is gauge-invariant information in the - # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary - # matrix. As the number of Householder reflections is in fixed in the full rank - # case, Q is expected to rotate smoothly (we might even be able to predict) also - # how the full Q2 will change, but this we omit for now, and we consider - # Q2' * ΔQ2 as a gauge dependent quantity. Q1dΔQ2 = Q1' * ΔQ2 - if isa(ΔA, Array) # not GPU friendly - Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - end + check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol) ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1) end end @@ -91,9 +97,9 @@ function qr_pullback!( M = zero!(similar(R, (p, p))) if !iszerotangent(ΔR) ΔR11 = view(ΔR, 1:p, 1:p) - M += ΔR11 * R11' + M = mul!(M, ΔR11, R11', 1, 1) end - M -= Q1' * ΔQ̃ + M = mul!(M, Q1', ΔQ̃, -1, 1) view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) if eltype(M) <: Complex Md = diagview(M) @@ -108,6 +114,14 @@ function qr_pullback!( return ΔA end +function check_qr_null_cotangents(N, ΔN; gauge_atol::Real = default_pullback_gauge_atol(ΔN)) + aNᴴΔN = project_antihermitian!(N' * ΔN) + Δgauge = norm(aNᴴΔN) + Δgauge ≤ gauge_atol || + @warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ qr_null_pullback!( ΔA::AbstractMatrix, A, N, ΔN; @@ -124,11 +138,7 @@ function qr_null_pullback!( gauge_atol::Real = default_pullback_gauge_atol(ΔN) ) if !iszerotangent(ΔN) && size(N, 2) > 0 - aNᴴΔN = project_antihermitian!(N' * ΔN) - Δgauge = norm(aNᴴΔN) - Δgauge ≤ gauge_atol || - @warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" - + check_qr_null_cotangents(N, ΔN; gauge_atol) Q, R = qr_compact(A; positive = true) X = rdiv!(ΔN' * Q, UpperTriangular(R)') ΔA = mul!(ΔA, N, X, -1, 1) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index dd393b04..759cee2e 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -1,3 +1,11 @@ +function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV)) + mask = abs.(Sr' .- Sr) .< degeneracy_atol + Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) + Δgauge ≤ gauge_atol || + @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ svd_pullback!( ΔA, A, USVᴴ, ΔUSVᴴ, [ind]; @@ -22,8 +30,8 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol """ function svd_pullback!( ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon(); - rank_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])), - degeneracy_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])), + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) # Extract the SVD components @@ -70,12 +78,7 @@ function svd_pullback!( aVΔV = project_antihermitian!(VΔV) # check whether cotangents arise from gauge-invariance objective function - mask = abs.(Sr' .- Sr) .< degeneracy_atol - if isa(ΔA, Array) # norm check not GPU friendly - Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - end + check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol, gauge_atol) UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+ (aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol) @@ -171,10 +174,7 @@ function svd_trunc_pullback!( aVΔV = project_antihermitian!(VΔV) # check whether cotangents arise from gauge-invariance objective function - mask = abs.(S' .- S) .< degeneracy_atol - Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol, gauge_atol) UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(S' .- S, degeneracy_atol) .+ (aUΔU .- aVΔV) .* inv_safe.(S' .+ S, degeneracy_atol) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index a050f589..20512b94 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -104,7 +104,7 @@ end ad_qr_full_setup(A::Diagonal) = ad_qr_compact_setup(A) -function ad_qr_rd_compact_setup(A) +function ad_qr_rank_deficient_compact_setup(A) m, n = size(A) minmn = min(m, n) T = eltype(A) @@ -122,19 +122,19 @@ function ad_qr_rd_compact_setup(A) return (Q, R), (ΔQ, ΔR) end -function ad_qr_rd_compact_setup(A::Diagonal) +function ad_qr_rank_deficient_compact_setup(A::Diagonal) m, n = size(A) minmn = min(m, n) T = eltype(A) r = minmn - 5 Ard_ = randn!(similar(A, T, m)) - zero!(view(Ard_, (r + 1):m) + zero!(view(Ard_, (r + 1):m)) Ard = Diagonal(Ard_) Q, R = qr_compact(Ard) ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) ΔR = Diagonal(randn!(similar(diagview(A), T, m))) - zero!(view(diagview(ΔQ), (r + 1):m) - zero!(view(diagview(ΔR), (r + 1):m) + zero!(view(diagview(ΔQ), (r + 1):m)) + zero!(view(diagview(ΔR), (r + 1):m)) return (Q, R), (ΔQ, ΔR) end @@ -173,7 +173,7 @@ function ad_lq_full_setup(A) end ad_lq_full_setup(A::Diagonal) = ad_qr_full_setup(A) -function ad_lq_rd_compact_setup(A) +function ad_lq_rank_deficient_compact_setup(A) m, n = size(A) minmn = min(m, n) T = eltype(A) @@ -189,7 +189,7 @@ function ad_lq_rd_compact_setup(A) view(ΔL, :, (r + 1):minmn) .= 0 return (L, Q), (ΔL, ΔQ) end -ad_lq_rd_compact_setup(A::Diagonal) = ad_qr_rd_compact_setup(A) +ad_lq_rank_deficient_compact_setup(A::Diagonal) = ad_qr_rank_deficient_compact_setup(A) function ad_eig_full_setup(A) m, n = size(A) diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index 625ff6a3..2e134eda 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -112,7 +112,7 @@ function test_chainrules_qr( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rd_compact_setup(Ard) + QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) ΔQ, ΔR = ΔQR test_rrule( cr_copy_qr_compact, Ard, alg ⊢ NoTangent(); diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index aac46182..bf225a59 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -201,7 +201,7 @@ function test_mooncake_qr( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rd_compact_setup(Ard) + QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) dQR = make_mooncake_tangent(ΔQR) Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) @@ -238,7 +238,7 @@ function test_mooncake_lq( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rd_compact_setup(Ard) + LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) dLQ = make_mooncake_tangent(ΔLQ) Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) From b343415dd36be0f4c21828d9a7bce4c182261c8b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 21 Jan 2026 10:29:51 +0100 Subject: [PATCH 11/11] Fix bad comment suggestion --- ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 872c29cd..e435a6dd 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -27,7 +27,7 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu return Ac_dAc, copy_input_pb end -@zero_derivative Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any} +Mooncake.@zero_derivative Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(initialize_output), Any, Any, Any} # two-argument in-place factorizations like LQ, QR, EIG for (f!, f, pb, adj) in ( (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),