-
Notifications
You must be signed in to change notification settings - Fork 5
Use Testsuite for AD tests #126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
09d7f69
239686e
df74a86
1629c0c
111cc89
86777c4
3708f83
c3be142
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ module MatrixAlgebraKitCUDAExt | |
| using MatrixAlgebraKit | ||
| using MatrixAlgebraKit: @algdef, Algorithm, check_input | ||
| using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! | ||
| using MatrixAlgebraKit: diagview, sign_safe | ||
| using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol | ||
| using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm | ||
| using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm | ||
| import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! | ||
|
|
@@ -195,4 +195,23 @@ end | |
| MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = | ||
| MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) | ||
|
|
||
| MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4) | ||
| MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4) | ||
| function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...) | ||
| As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...)) | ||
| return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4) | ||
| end | ||
|
|
||
| function LinearAlgebra.sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix) | ||
| #=m = size(A, 1) | ||
| n = size(B, 2) | ||
| I_n = fill!(similar(A, n), one(eltype(A))) | ||
| I_m = fill!(similar(B, m), one(eltype(B))) | ||
| L = kron(diagm(I_n), A) + kron(adjoint(B), diagm(I_m)) | ||
| x_vec = L \ -vec(C) | ||
| X = CuMatrix(reshape(x_vec, m, n))=# | ||
| hX = sylvester(collect(A), collect(B), collect(C)) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is very awful but I wasn't able to find a correct way to do it in five minutes so there you go
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any chance we could:
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes to both |
||
| return CuArray(hX) | ||
| end | ||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,9 +42,12 @@ function eig_pullback!( | |
| mul!(view(VᴴΔV, :, indV), V', ΔV) | ||
|
|
||
| mask = abs.(transpose(D) .- D) .< degeneracy_atol | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The mask probably also belongs inside of the |
||
| Δgauge = norm(view(VᴴΔV, mask), Inf) | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| if isa(ΔA, Array) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we refactor this into a separate function and use dispatch to resolve this? It might be slightly safer to simply allocate instead of taking a view for the GPU versions, which should then no longer be scalar-indexed?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep! |
||
| # not GPU friendly... | ||
| Δgauge = norm(view(VᴴΔV, mask), Inf) | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| end | ||
|
|
||
| VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,28 +36,30 @@ function lq_pullback!( | |
| ΔA1 = view(ΔA, 1:p, :) | ||
| ΔA2 = view(ΔA, (p + 1):m, :) | ||
|
|
||
| if minmn > p # case where A is rank-deficient | ||
| Δgauge = abs(zero(eltype(Q))) | ||
| if !iszerotangent(ΔQ) | ||
| # in this case the number Householder reflections will | ||
| # change upon small variations, and all of the remaining | ||
| # columns of ΔQ should be zero for a gauge-invariant | ||
| # cost function | ||
| ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) | ||
| Δgauge = max(Δgauge, norm(ΔQ2, Inf)) | ||
| end | ||
| if !iszerotangent(ΔL) | ||
| ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn) | ||
| Δgauge = max(Δgauge, norm(ΔL22, Inf)) | ||
| if isa(ΔA, Array) # not GPU friendly | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment here about refactoring into a separate function |
||
| if minmn > p # case where A is rank-deficient | ||
| Δgauge = abs(zero(eltype(Q))) | ||
| if !iszerotangent(ΔQ) | ||
| # in this case the number Householder reflections will | ||
| # change upon small variations, and all of the remaining | ||
| # columns of ΔQ should be zero for a gauge-invariant | ||
| # cost function | ||
| ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) | ||
| Δgauge = max(Δgauge, norm(ΔQ2, Inf)) | ||
| end | ||
| if !iszerotangent(ΔL) | ||
| ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn) | ||
| Δgauge = max(Δgauge, norm(ΔL22, Inf)) | ||
| end | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| end | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| end | ||
|
|
||
| ΔQ̃ = zero!(similar(Q, (p, n))) | ||
| if !iszerotangent(ΔQ) | ||
| ΔQ1 = view(ΔQ, 1:p, :) | ||
| copy!(ΔQ̃, ΔQ1) | ||
| ΔQ̃ .= ΔQ1 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this also a GPU thing?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah |
||
| if p < size(Q, 1) | ||
| Q2 = view(Q, (p + 1):size(Q, 1), :) | ||
| ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) | ||
|
|
@@ -69,9 +71,11 @@ function lq_pullback!( | |
| # how the full Q2 will change, but this we omit for now, and we consider | ||
| # Q2' * ΔQ2 as a gauge dependent quantity. | ||
| ΔQ2Q1ᴴ = ΔQ2 * Q1' | ||
| Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| if isa(ΔA, Array) # not GPU friendly | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment about refactoring into a function |
||
| Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| end | ||
| ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1) | ||
| end | ||
| end | ||
|
|
@@ -95,8 +99,10 @@ function lq_pullback!( | |
| Md = diagview(M) | ||
| Md .= real.(Md) | ||
| end | ||
| ldiv!(LowerTriangular(L11)', M) | ||
| ldiv!(LowerTriangular(L11)', ΔQ̃) | ||
| # not GPU friendly... | ||
| L11arr = typeof(L)(L11) | ||
| ldiv!(LowerTriangular(L11arr)', M) | ||
| ldiv!(LowerTriangular(L11arr)', ΔQ̃) | ||
| ΔA1 = mul!(ΔA1, M, Q1, +1, 1) | ||
| ΔA1 .+= ΔQ̃ | ||
| return ΔA | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) | |
| if !iszerotangent(ΔW) | ||
| ΔWP = ΔW / P | ||
| WdΔWP = W' * ΔWP | ||
| ΔWP = mul!(ΔWP, W, WdΔWP, -1, 1) | ||
| ΔWP .-= W * WdΔWP | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this still a |
||
| ΔA .+= ΔWP | ||
| end | ||
| return ΔA | ||
|
|
@@ -48,11 +48,11 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs... | |
| !iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1) | ||
| C = sylvester(P, P, M' - M) | ||
| C .+= ΔP | ||
| ΔA = mul!(ΔA, C, Wᴴ, 1, 1) | ||
| ΔA .+= C * Wᴴ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
||
| if !iszerotangent(ΔWᴴ) | ||
| PΔWᴴ = P \ ΔWᴴ | ||
| PΔWᴴW = PΔWᴴ * Wᴴ' | ||
| PΔWᴴ = mul!(PΔWᴴ, PΔWᴴW, Wᴴ, -1, 1) | ||
| PΔWᴴ .-= PΔWᴴW * Wᴴ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
||
| ΔA .+= PΔWᴴ | ||
| end | ||
| return ΔA | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,27 +37,29 @@ function qr_pullback!( | |
| ΔA1 = view(ΔA, :, 1:p) | ||
| ΔA2 = view(ΔA, :, (p + 1):n) | ||
|
|
||
| if minmn > p # case where A is rank-deficient | ||
| Δgauge = abs(zero(eltype(Q))) | ||
| if !iszerotangent(ΔQ) | ||
| # in this case the number Householder reflections will | ||
| # change upon small variations, and all of the remaining | ||
| # columns of ΔQ should be zero for a gauge-invariant | ||
| # cost function | ||
| ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) | ||
| Δgauge = max(Δgauge, norm(ΔQ2, Inf)) | ||
| end | ||
| if !iszerotangent(ΔR) | ||
| ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n) | ||
| Δgauge = max(Δgauge, norm(ΔR22, Inf)) | ||
| if isa(ΔA, Array) # not GPU friendly | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same refactor into function comment here |
||
| if minmn > p # case where A is rank-deficient | ||
| Δgauge = abs(zero(eltype(Q))) | ||
| if !iszerotangent(ΔQ) | ||
| # in this case the number Householder reflections will | ||
| # change upon small variations, and all of the remaining | ||
| # columns of ΔQ should be zero for a gauge-invariant | ||
| # cost function | ||
| ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) | ||
| Δgauge = max(Δgauge, norm(ΔQ2, Inf)) | ||
| end | ||
| if !iszerotangent(ΔR) | ||
| ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n) | ||
| Δgauge = max(Δgauge, norm(ΔR22, Inf)) | ||
| end | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| end | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| end | ||
|
|
||
| ΔQ̃ = zero!(similar(Q, (m, p))) | ||
| if !iszerotangent(ΔQ) | ||
| copy!(ΔQ̃, view(ΔQ, :, 1:p)) | ||
| ΔQ̃ .= view(ΔQ, :, 1:p) | ||
| if p < size(Q, 2) | ||
| Q2 = view(Q, :, (p + 1):size(Q, 2)) | ||
| ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) | ||
|
|
@@ -69,9 +71,11 @@ function qr_pullback!( | |
| # how the full Q2 will change, but this we omit for now, and we consider | ||
| # Q2' * ΔQ2 as a gauge dependent quantity. | ||
| Q1dΔQ2 = Q1' * ΔQ2 | ||
| Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| if isa(ΔA, Array) # not GPU friendly | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same refactor into a function comment here |
||
| Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| end | ||
| ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1) | ||
| end | ||
| end | ||
|
|
@@ -87,16 +91,18 @@ function qr_pullback!( | |
| M = zero!(similar(R, (p, p))) | ||
| if !iszerotangent(ΔR) | ||
| ΔR11 = view(ΔR, 1:p, 1:p) | ||
| M = mul!(M, ΔR11, R11', 1, 1) | ||
| M += ΔR11 * R11' | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
||
| end | ||
| M = mul!(M, Q1', ΔQ̃, -1, 1) | ||
| M -= Q1' * ΔQ̃ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
||
| view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M))) | ||
| if eltype(M) <: Complex | ||
| Md = diagview(M) | ||
| Md .= real.(Md) | ||
| end | ||
| rdiv!(M, UpperTriangular(R11)') | ||
| rdiv!(ΔQ̃, UpperTriangular(R11)') | ||
| # not GPU-friendly... | ||
| R11arr = typeof(R)(R11) | ||
| rdiv!(M, UpperTriangular(R11arr)') | ||
| rdiv!(ΔQ̃, UpperTriangular(R11arr)') | ||
| ΔA1 = mul!(ΔA1, Q1, M, +1, 1) | ||
| ΔA1 .+= ΔQ̃ | ||
| return ΔA | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,8 +22,8 @@ which `abs(S[i] - S[j]) < degeneracy_atol`, is not small compared to `gauge_atol | |
| """ | ||
| function svd_pullback!( | ||
| ΔA::AbstractMatrix, A, USVᴴ, ΔUSVᴴ, ind = Colon(); | ||
| rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), | ||
| degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), | ||
| rank_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we handle this conversion in the functions instead of at the callsite? |
||
| degeneracy_atol::Real = default_pullback_rank_atol(diagview(USVᴴ[2])), | ||
| gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) | ||
| ) | ||
| # Extract the SVD components | ||
|
|
@@ -33,7 +33,7 @@ function svd_pullback!( | |
| minmn = min(m, n) | ||
| S = diagview(Smat) | ||
| length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)")) | ||
| r = searchsortedlast(S, rank_atol; rev = true) # rank | ||
| r = findlast(s -> s ≥ rank_atol, S) # rank | ||
| Ur = view(U, :, 1:r) | ||
| Vᴴr = view(Vᴴ, 1:r, :) | ||
| Sr = view(S, 1:r) | ||
|
|
@@ -71,9 +71,11 @@ function svd_pullback!( | |
|
|
||
| # check whether cotangents arise from gauge-invariance objective function | ||
| mask = abs.(Sr' .- Sr) .< degeneracy_atol | ||
| Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| if isa(ΔA, Array) # norm check not GPU friendly | ||
| Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) | ||
| Δgauge ≤ gauge_atol || | ||
| @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" | ||
| end | ||
|
|
||
| UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+ | ||
| (aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol) | ||
|
|
@@ -84,18 +86,18 @@ function svd_pullback!( | |
| length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))")) | ||
| view(diagview(UdΔAV), indS) .+= real.(ΔS) | ||
| end | ||
| ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA | ||
| ΔA .+= Ur * UdΔAV * Vᴴr # add the contribution to ΔA | ||
|
|
||
| # Add the remaining contributions | ||
| if m > r && !iszerotangent(ΔU) # remaining ΔU is already orthogonal to Ur | ||
| Sp = view(S, indU) | ||
| Vᴴp = view(Vᴴ, indU, :) | ||
| ΔA = mul!(ΔA, ΔU ./ Sp', Vᴴp, 1, 1) | ||
| ΔA .+= (ΔU ./ Sp') * Vᴴp | ||
| end | ||
| if n > r && !iszerotangent(ΔVᴴ) # remaining ΔV is already orthogonal to Vᴴr | ||
| Sp = view(S, indV) | ||
| Up = view(U, :, indV) | ||
| ΔA = mul!(ΔA, Up, Sp .\ ΔVᴴ, 1, 1) | ||
| ΔA .+= Up * (Sp .\ ΔVᴴ) | ||
| end | ||
| return ΔA | ||
| end | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why was this needed? what breaks if we don't do
CuArray.(As')?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
normdoesn't work forAdjoint{CuArray}for example