diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 4a203f64..6b89b64f 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -78,6 +78,16 @@ function eig_pullback!( end return ΔA end +function eig_pullback!( + ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eig_trunc_pullback!( @@ -151,6 +161,16 @@ function eig_trunc_pullback!( end return ΔA end +function eig_trunc_pullback!( + ΔA::Diagonal, A, DV, ΔDV; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eig_vals_pullback!( diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 195539cf..11171685 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -68,6 +68,16 @@ function eigh_pullback!( end return ΔA end +function eigh_pullback!( + ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eigh_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eigh_trunc_pullback!( @@ -141,6 +151,16 @@ function eigh_trunc_pullback!( end return ΔA end +function eigh_trunc_pullback!( + ΔA::Diagonal, A, DV, ΔDV; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = eigh_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ eigh_vals_pullback!( diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index a8f8b70c..1608343e 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -99,6 +99,17 @@ function svd_pullback!( end return ΔA end +function svd_pullback!( + ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ, ind = Colon(); + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = svd_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ svd_trunc_pullback!( @@ -201,6 +212,17 @@ function svd_trunc_pullback!( ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1, 1) return ΔA end +function svd_trunc_pullback!( + ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ; + rank_atol::Real = 0, + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) + ) + ΔA_full = zero!(similar(ΔA, size(ΔA))) + ΔA_full = svd_trunc_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ; rank_atol, degeneracy_atol, gauge_atol) + diagview(ΔA) .+= diagview(ΔA_full) + return ΔA +end """ svd_vals_pullback!(