Skip to content

Conversation

@kshyatt
Copy link
Member

@kshyatt kshyatt commented Dec 23, 2025

This lets us reuse a lot of the setup infrastructure for ChainRules, Mooncake, and (soon) Enzyme. Also starts testing the AD rules on GPU.

@kshyatt kshyatt requested review from Jutho and lkdvos December 23, 2025 07:58
L = kron(diagm(I_n), A) + kron(adjoint(B), diagm(I_m))
x_vec = L \ -vec(C)
X = CuMatrix(reshape(x_vec, m, n))=#
hX = sylvester(collect(A), collect(B), collect(C))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very awful but I wasn't able to find a correct way to do it in five minutes so there you go

@github-actions
Copy link

github-actions bot commented Dec 23, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
index 637012c..75701ba 100644
--- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
+++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
@@ -17,9 +17,9 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N
 
 @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
 function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
-    Ac     = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
+    Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
     Ac_dAc = Mooncake.zero_fcodual(Ac)
-    dAc    = Mooncake.tangent(Ac_dAc)
+    dAc = Mooncake.tangent(Ac_dAc)
     function copy_input_pb(::NoRData)
         Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
         return NoRData(), NoRData(), NoRData()
diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl
index e4400a3..0caed08 100644
--- a/src/pullbacks/polar.jl
+++ b/src/pullbacks/polar.jl
@@ -48,12 +48,12 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
     !iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
     C = sylvester(P, P, M' - M)
     C .+= ΔP
-    ΔA .+= C*Wᴴ
+    ΔA .+= C * Wᴴ
     if !iszerotangent(ΔWᴴ)
-        PΔWᴴ   = P \ ΔWᴴ
-        PΔWᴴW  = PΔWᴴ * Wᴴ'
+        PΔWᴴ = P \ ΔWᴴ
+        PΔWᴴW = PΔWᴴ * Wᴴ'
         PΔWᴴ .-= PΔWᴴW * Wᴴ
-        ΔA   .+= PΔWᴴ
+        ΔA .+= PΔWᴴ
     end
     return ΔA
 end
diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl
index c87407f..bc7b9cb 100644
--- a/src/pullbacks/svd.jl
+++ b/src/pullbacks/svd.jl
@@ -90,7 +90,7 @@ function svd_pullback!(
 
     # Add the remaining contributions
     if m > r && !iszerotangent(ΔU) # remaining ΔU is already orthogonal to Ur
-        Sp  = view(S, indU)
+        Sp = view(S, indU)
         Vᴴp = view(Vᴴ, indU, :)
         ΔA .+= (ΔU ./ Sp') * Vᴴp
     end
diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl
index dfd61ef..70ea161 100644
--- a/test/testsuite/ad_utils.jl
+++ b/test/testsuite/ad_utils.jl
@@ -128,14 +128,14 @@ function ad_qr_rd_compact_setup(A::Diagonal)
     T = eltype(A)
     r = minmn - 5
     Ard_ = randn!(similar(A, T, m))
-    Ard_[r+1:m] .= zero(T)
-    Ard  = Diagonal(Ard_)
+    Ard_[(r + 1):m] .= zero(T)
+    Ard = Diagonal(Ard_)
     Q, R = qr_compact(Ard)
     QR = (Q, R)
     ΔQ = Diagonal(randn!(similar(A.diag, T, m)))
     ΔR = Diagonal(randn!(similar(A.diag, T, m)))
-    diagview(ΔQ)[r+1:m] .= zero(T)
-    diagview(ΔR)[r+1:m] .= zero(T)
+    diagview(ΔQ)[(r + 1):m] .= zero(T)
+    diagview(ΔR)[(r + 1):m] .= zero(T)
     return (Q, R), (ΔQ, ΔR)
 end
 
@@ -283,7 +283,7 @@ function ad_svd_compact_setup(A)
     ΔS2 = Diagonal(randn!(similar(A, real(T), minmn)))
     ΔVᴴ = randn!(similar(A, T, minmn, n))
     U, S, Vᴴ = svd_compact(A)
-    ΔU, ΔVᴴ  = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ)
+    ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ)
     return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ)
 end
 
@@ -296,7 +296,7 @@ function ad_svd_compact_setup(A::Diagonal)
     ΔS2 = Diagonal(randn!(similar(A.diag, real(T), minmn)))
     ΔVᴴ = randn!(similar(A.diag, T, m, n))
     U, S, Vᴴ = svd_compact(A)
-    ΔU, ΔVᴴ  = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ)
+    ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ)
     return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ)
 end
 
diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl
index c587b2e..4736288 100644
--- a/test/testsuite/mooncake.jl
+++ b/test/testsuite/mooncake.jl
@@ -62,14 +62,14 @@ make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_ta
 
 make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...)
 
-make_mooncake_fdata(x)           = make_mooncake_tangent(x)
+make_mooncake_fdata(x) = make_mooncake_tangent(x)
 make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),))
-make_mooncake_fdata(x::Tuple)    = map(make_mooncake_fdata, x)
+make_mooncake_fdata(x::Tuple) = map(make_mooncake_fdata, x)
 
 # no `alg` argument
 function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata)
-    dA_copy    = make_mooncake_fdata(copy(ΔA))
-    A_copy     = copy(A)
+    dA_copy = make_mooncake_fdata(copy(ΔA))
+    A_copy = copy(A)
     dargs_copy = make_mooncake_fdata(deepcopy(Δargs))
     copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy))
     copy_pb!!(rdata)
@@ -78,8 +78,8 @@ end
 
 # `alg` argument
 function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata)
-    dA_copy    = make_mooncake_fdata(copy(ΔA))
-    A_copy     = copy(A)
+    dA_copy = make_mooncake_fdata(copy(ΔA))
+    A_copy = copy(A)
     dargs_copy = make_mooncake_fdata(deepcopy(Δargs))
     copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData()))
     copy_pb!!(rdata)

@kshyatt kshyatt force-pushed the testsuite-ad branch 2 times, most recently from 0669d08 to 958dd36 Compare December 23, 2025 18:25
@kshyatt kshyatt force-pushed the testsuite-ad branch 6 times, most recently from fad3f4a to daeab4d Compare January 9, 2026 09:52
@codecov
Copy link

codecov bot commented Jan 19, 2026

Codecov Report

❌ Patch coverage is 0% with 91 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/pullbacks/qr.jl 0.00% 21 Missing ⚠️
src/pullbacks/lq.jl 0.00% 19 Missing ⚠️
ext/MatrixAlgebraKitChainRulesCoreExt.jl 0.00% 14 Missing ⚠️
...gebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl 0.00% 9 Missing ⚠️
src/pullbacks/svd.jl 0.00% 9 Missing ⚠️
...MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl 0.00% 8 Missing ⚠️
src/pullbacks/polar.jl 0.00% 6 Missing ⚠️
src/pullbacks/eig.jl 0.00% 5 Missing ⚠️
Files with missing lines Coverage Δ
src/pullbacks/eig.jl 0.00% <0.00%> (-96.11%) ⬇️
src/pullbacks/polar.jl 0.00% <0.00%> (-100.00%) ⬇️
...MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl 55.88% <0.00%> (-7.46%) ⬇️
...gebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl 0.00% <0.00%> (-99.13%) ⬇️
src/pullbacks/svd.jl 0.00% <0.00%> (-96.37%) ⬇️
ext/MatrixAlgebraKitChainRulesCoreExt.jl 0.00% <0.00%> (-81.82%) ⬇️
src/pullbacks/lq.jl 0.00% <0.00%> (-95.32%) ⬇️
src/pullbacks/qr.jl 0.00% <0.00%> (-95.24%) ⬇️

... and 30 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants