From 61a88c12369dbaf422b72cbfdce4b301e0c78a51 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Wed, 21 Jan 2026 20:00:29 +0100 Subject: [PATCH 01/14] Generalize environment gauge fixing --- src/algorithms/ctmrg/gaugefix.jl | 56 +++++++++++++++++++++++++++++- test/ctmrg/fixed_iterscheme.jl | 55 +++++++++++++++++++---------- test/ctmrg/flavors.jl | 13 +++++-- test/ctmrg/gaugefix.jl | 30 +++++++++++++--- test/ctmrg/jacobian_real_linear.jl | 11 +++--- test/ctmrg/unitcell.jl | 8 ++--- 6 files changed, 136 insertions(+), 37 deletions(-) diff --git a/src/algorithms/ctmrg/gaugefix.jl b/src/algorithms/ctmrg/gaugefix.jl index 105b827b9..66af48cc8 100644 --- a/src/algorithms/ctmrg/gaugefix.jl +++ b/src/algorithms/ctmrg/gaugefix.jl @@ -1,3 +1,14 @@ +function gauge_fix(boundary_alg::CTMRGAlgorithm, signs, info) + # TODO + decomp_alg_fixed = _fix_decomposition(decomposition_algorithm(alg.projector_alg), signs, info) + alg_fixed = @set alg.projector_alg.alg = decomp_alg_fixed + alg_fixed = @set alg_fixed.projector_alg.trunc = notrunc() + return alg_fixed +end + +struct ScramblingEnvGauge end +struct ScramblingEnvGaugeC4v end + """ $(SIGNATURES) @@ -6,7 +17,7 @@ This assumes that the `envfinal` is the result of one CTMRG iteration on `envpre Given that the CTMRG run is converged, the returned environment will be element-wise converged to `envprev`. 
""" -function gauge_fix(envprev::CTMRGEnv{C, T}, envfinal::CTMRGEnv{C, T}) where {C, T} +function gauge_fix(envfinal::CTMRGEnv{C, T}, ::ScramblingEnvGauge, envprev::CTMRGEnv{C, T}) where {C, T} # Check if spaces in envprev and envfinal are the same same_spaces = map(eachcoordinate(envfinal, 1:4)) do (dir, r, c) space(envfinal.edges[dir, r, c]) == space(envprev.edges[dir, r, c]) && @@ -54,6 +65,42 @@ function gauge_fix(envprev::CTMRGEnv{C, T}, envfinal::CTMRGEnv{C, T}) where {C, return fix_global_phases(envprev, CTMRGEnv(cornersfix, edgesfix)), signs end +# C4v specialized gauge fixing routine with Hermitian transfer matrix +function gauge_fix(envfinal::CTMRGEnv{C, T}, ::ScramblingEnvGaugeC4v, envprev::CTMRGEnv{C, T}) where {C, T} + # Check if spaces in envprev and envfinal are the same + same_spaces = map(eachcoordinate(envfinal, 1:4)) do (dir, r, c) + space(envfinal.edges[dir, r, c]) == space(envprev.edges[dir, r, c]) && + space(envfinal.corners[dir, r, c]) == space(envprev.corners[dir, r, c]) + end + @assert all(same_spaces) "Spaces of envprev and envfinal are not the same" + + # "general" algorithm from https://arxiv.org/abs/2311.11894 + Tprev = envprev.edges[1, 1, 1] + Tfinal = envfinal.edges[1, 1, 1] + + # Random Hermitian MPS of same bond dimension + # (make Hermitian such that T-M transfer matrix has real eigenvalues) + M = project_hermitian(randn(scalartype(Tfinal), space(Tfinal))) + + # Find right fixed points of mixed transfer matrices + ρinit = randn( + scalartype(T), MPSKit._lastspace(Tfinal)' ← MPSKit._lastspace(M)' + ) + ρprev = c4v_transfermatrix_fixedpoint(Tprev, M, ρinit) + ρfinal = c4v_transfermatrix_fixedpoint(Tfinal, M, ρinit) + + # Decompose and multiply + Qprev, = left_orth!(ρprev) + Qfinal, = left_orth!(ρfinal) + + σ = Qprev * Qfinal' + + @tensor cornerfix[χ_in; χ_out] := σ[χ_in; χ1] * envfinal.corners[1][χ1; χ2] * conj(σ[χ_out; χ2]) + @tensor edgefix[χ_in D_in_above D_in_below; χ_out] := + σ[χ_in; χ1] * envfinal.edges[1][χ1 D_in_above 
D_in_below; χ2] * conj(σ[χ_out; χ2]) + return _c4v_env(cornerfix, edgefix), fill(σ, (4, 1, 1)) +end + # this is a bit of a hack to get the fixed point of the mixed transfer matrix # because MPSKit is not compatible with AD # NOTE: the action of the transfer operator here is NOT the same as that of @@ -82,6 +129,13 @@ function transfermatrix_fixedpoint(tops, bottoms, ρinit) end return first(vecs) end +function c4v_transfermatrix_fixedpoint(top, bottom, ρinit) + _, vecs, info = eigsolve(ρinit, 1, :LM, Lanczos()) do ρ + PEPSKit.mps_transfer_right(ρ, top, bottom) + end + info.converged > 0 || @warn "eigsolve did not converge" + return first(vecs) +end # Explicit fixing of relative phases (doing this compactly in a loop is annoying) function fix_relative_phases(envfinal::CTMRGEnv, signs) diff --git a/test/ctmrg/fixed_iterscheme.jl b/test/ctmrg/fixed_iterscheme.jl index 817dbad24..72b9d2936 100644 --- a/test/ctmrg/fixed_iterscheme.jl +++ b/test/ctmrg/fixed_iterscheme.jl @@ -12,7 +12,8 @@ using PEPSKit: fix_relative_phases, fix_global_phases, calc_elementwise_convergence, - _fix_svd_algorithm + fix_decomposition, + gauge_fix # initialize parameters χbond = 2 @@ -39,13 +40,40 @@ atol = 1.0e-5 # do extra iteration to get SVD env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) - env_fix, signs = gauge_fix(env_conv1, env_conv2) + env_fix, signs = gauge_fix(env_conv2, ScramblingEnvGauge(), env_conv1) @test calc_elementwise_convergence(env_conv1, env_fix) ≈ 0 atol = atol # fix gauge of SVD - svd_alg_fix = _fix_svd_algorithm(ctm_alg.projector_alg.svd_alg, signs, info) - ctm_alg_fix = @set ctm_alg.projector_alg.svd_alg = svd_alg_fix - ctm_alg_fix = @set ctm_alg_fix.projector_alg.trunc = notrunc() + ctm_alg_fix = gauge_fix(ctm_alg, signs, info) + + # do iteration with FixedSVD + env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) + env_fixedsvd = fix_global_phases(env_conv1, env_fixedsvd) + @test calc_elementwise_convergence(env_conv1, 
env_fixedsvd) ≈ 0 atol = atol +end + +eigh_algs = [EighAdjoint(; fwd_alg = qriteration), EighAdjoint(; fwd_alg = :lanczos)] +projector_algs = [:c4v_eigh, :c4v_qr] + +# test same thing for C4v CTMRG +@testset "" for (eigh_alg, projector_alg) in Iterators.product(eigh_algs, projector_algs) + # TODO + ctm_alg = C4vCTMRG(; eigh_alg, projector_alg) + + # initialize states + Random.seed!(2394823842) + psi = InfinitePEPS(ComplexSpace(2), ComplexSpace(χbond); unitcell) + n = InfiniteSquareNetwork(psi) + + env_conv1, = leading_boundary(CTMRGEnv(psi, ComplexSpace(χenv)), psi, ctm_alg) + + # do extra iteration to get SVD + env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) + env_fix, signs = gauge_fix(env_conv2, ScramblingEnvGaugeC4v(), env_conv1) + @test calc_elementwise_convergence(env_conv1, env_fix) ≈ 0 atol = atol + + # fix gauge of SVD + ctm_alg_fix = gauge_fix(ctm_alg, signs, info) # do iteration with FixedSVD env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) @@ -69,25 +97,16 @@ end # do extra iteration to get SVD env_conv2_iter, info_iter = ctmrg_iteration(n, env_conv1, ctm_alg_iter) - env_fix_iter, signs_iter = gauge_fix(env_conv1, env_conv2_iter) + env_fix_iter, signs_iter = gauge_fix(env_conv2_iter, ScramblingEnvGauge(), env_conv1) @test calc_elementwise_convergence(env_conv1, env_fix_iter) ≈ 0 atol = atol env_conv2_full, info_full = ctmrg_iteration(n, env_conv1, ctm_alg_full) - env_fix_full, signs_full = gauge_fix(env_conv1, env_conv2_full) + env_fix_full, signs_full = gauge_fix(env_conv2_full, ScramblingEnvGauge(), env_conv1) @test calc_elementwise_convergence(env_conv1, env_fix_full) ≈ 0 atol = atol # fix gauge of SVD - svd_alg_fix_iter = _fix_svd_algorithm( - ctm_alg_iter.projector_alg.svd_alg, signs_iter, info_iter - ) - ctm_alg_fix_iter = @set ctm_alg_iter.projector_alg.svd_alg = svd_alg_fix_iter - ctm_alg_fix_iter = @set ctm_alg_fix_iter.projector_alg.trunc = notrunc() - - svd_alg_fix_full = _fix_svd_algorithm( - 
ctm_alg_full.projector_alg.svd_alg, signs_full, info_full - ) - ctm_alg_fix_full = @set ctm_alg_full.projector_alg.svd_alg = svd_alg_fix_full - ctm_alg_fix_full = @set ctm_alg_fix_full.projector_alg.trunc = notrunc() + ctm_alg_fix_iter = gauge_fix(ctm_alg_iter, signs_iter, info_iter) + ctm_alg_fix_full = gauge_fix(ctm_alg_full, signs_full, info_full) # do iteration with FixedSVD env_fixedsvd_iter, = ctmrg_iteration(n, env_conv1, ctm_alg_fix_iter) diff --git a/test/ctmrg/flavors.jl b/test/ctmrg/flavors.jl index 292f81ee2..d15705b88 100644 --- a/test/ctmrg/flavors.jl +++ b/test/ctmrg/flavors.jl @@ -9,10 +9,11 @@ using PEPSKit χbond = 2 χenv = 16 unitcells = [(1, 1), (3, 4)] -projector_algs = [:halfinfinite, :fullinfinite] +projector_algs_asymm = [:halfinfinite, :fullinfinite] +projector_algs_c4v = [:c4v_eigh, :c4v_qr] @testset "$(unitcell) unit cell with $projector_alg" for (unitcell, projector_alg) in - Iterators.product(unitcells, projector_algs) + Iterators.product(unitcells, projector_algs_asymm) # compute environments Random.seed!(32350283290358) psi = InfinitePEPS(ComplexSpace(2), ComplexSpace(χbond); unitcell) @@ -54,7 +55,7 @@ end # test fixedspace actually fixes space @testset "Fixedspace truncation using $alg and $projector_alg" for (alg, projector_alg) in - Iterators.product([:sequential, :simultaneous], projector_algs) + Iterators.product([:sequential, :simultaneous], projector_algs_asymm) Ds = ComplexSpace.(fill(2, 3, 3)) χs = ComplexSpace.([16 17 18; 15 20 21; 14 19 22]) psi = InfinitePEPS(Ds, Ds, Ds) @@ -67,3 +68,9 @@ end @test all(space.(env.corners) .== space.(env2.corners)) @test all(space.(env.edges) .== space.(env2.edges)) end + +@testset "C4v CTMRG using $alg and $projector_alg" for (alg, projector_alg) in + Iterators.product([:c4v], projector_algs_c4v) + + # TODO +end diff --git a/test/ctmrg/gaugefix.jl b/test/ctmrg/gaugefix.jl index 8d6b68c1f..425079144 100644 --- a/test/ctmrg/gaugefix.jl +++ b/test/ctmrg/gaugefix.jl @@ -4,12 +4,14 @@ using 
PEPSKit using TensorKit using PEPSKit: ctmrg_iteration, calc_elementwise_convergence +using PEPSKit: ScramblingEnvGauge, ScramblingEnvGaugeC4v spacetypes = [ComplexSpace, Z2Space] scalartypes = [Float64, ComplexF64] unitcells = [(1, 1), (2, 2), (3, 2)] ctmrg_algs = [SequentialCTMRG, SimultaneousCTMRG] projector_algs = [:halfinfinite, :fullinfinite] +gauge_algs = [ScramblingEnvGauge()] tol = 1.0e-6 # large tol due to χ=6 χ = 6 atol = 1.0e-4 @@ -37,17 +39,35 @@ for (S, T, unitcell) in Iterators.product(spacetypes, scalartypes, unitcells) push!(preconv, (S, T, unitcell) => result) end -@testset "($S) - ($T) - ($unitcell) - ($ctmrg_alg) - ($projector_alg)" for ( - S, T, unitcell, ctmrg_alg, projector_alg, +# asymmetric CTMRG +@testset "($S) - ($T) - ($unitcell) - ($ctmrg_alg) - ($projector_alg) - ($gauge_alg)" for ( + S, T, unitcell, ctmrg_alg, projector_alg, gauge_alg, ) in Iterators.product( - spacetypes, scalartypes, unitcells, ctmrg_algs, projector_algs + spacetypes, scalartypes, unitcells, ctmrg_algs, projector_algs, gauge_algs ) alg = ctmrg_alg(; tol, projector_alg) env_pre, psi = preconv[(S, T, unitcell)] n = InfiniteSquareNetwork(psi) - env_pre env, = leading_boundary(env_pre, psi, alg) env′, = ctmrg_iteration(n, env, alg) - env_fixed, = gauge_fix(env, env′) + env_fixed, = gauge_fix(env′, gauge_alg, env) + @test calc_elementwise_convergence(env, env_fixed) ≈ 0 atol = atol +end + +projector_algs_c4v = [:c4v_eigh, :c4v_qr] +gauge_algs_c4v = [ScramblingEnvGaugeC4v()] + +# C4v CTMRG +@testset "($S) - ($T) - ($projector_alg) - ($gauge_alg)" for ( + S, T, unitcell, ctmrg_alg, projector_alg, gauge_alg, + ) in Iterators.product( + spacetypes, scalartypes, projector_algs_c4v, gauge_algs_c4v + ) + alg = C4vCTMRG(; tol, projector_alg) + env_pre, psi = preconv[(S, T, unitcell)] # TODO + n = InfiniteSquareNetwork(psi) + env, = leading_boundary(env_pre, psi, alg) + env′, = ctmrg_iteration(n, env, alg) + env_fixed, = gauge_fix(env′, gauge_alg, env) @test 
calc_elementwise_convergence(env, env_fixed) ≈ 0 atol = atol end diff --git a/test/ctmrg/jacobian_real_linear.jl b/test/ctmrg/jacobian_real_linear.jl index 98d7cd07b..fce51a4ca 100644 --- a/test/ctmrg/jacobian_real_linear.jl +++ b/test/ctmrg/jacobian_real_linear.jl @@ -4,7 +4,7 @@ using Accessors using Zygote using TensorKit, KrylovKit, PEPSKit using PEPSKit: - ctmrg_iteration, fix_relative_phases, fix_global_phases, _fix_svd_algorithm + ctmrg_iteration, fix_relative_phases, fix_global_phases, fix_decomposition, ScramblingEnvGauge algs = [ (:fixed, SimultaneousCTMRG(; projector_alg = :halfinfinite)), @@ -16,6 +16,7 @@ algs = [ # (:diffgauge, SimultaneousCTMRG(; projector_alg=FullInfiniteProjector)), ] Dbond, χenv = 2, 16 +alg_gauge = ScramblingEnvGauge() @testset "$iterscheme and $ctm_alg" for (iterscheme, ctm_alg) in algs Random.seed!(123521938519) @@ -25,10 +26,8 @@ Dbond, χenv = 2, 16 # follow code of _rrule if iterscheme == :fixed env_conv, info = ctmrg_iteration(InfiniteSquareNetwork(state), env, ctm_alg) - env_fixed, signs = gauge_fix(env, env_conv) - svd_alg_fixed = _fix_svd_algorithm(ctm_alg.projector_alg.svd_alg, signs, info) - alg_fixed = @set ctm_alg.projector_alg.svd_alg = svd_alg_fixed - alg_fixed = @set alg_fixed.projector_alg.trunc = notrunc() + env_fixed, signs = gauge_fix(env_conv, alg_gauge, env) + alg_fixed = gauge_fix(ctm_alg, signs, info) _, env_vjp = pullback(state, env_fixed) do A, x e, = PEPSKit.ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed) @@ -36,7 +35,7 @@ Dbond, χenv = 2, 16 end elseif iterscheme == :diffgauge _, env_vjp = pullback(state, env) do A, x - return gauge_fix(x, ctmrg_iteration(InfiniteSquareNetwork(A), x, ctm_alg)[1])[1] + return gauge_fix(ctmrg_iteration(InfiniteSquareNetwork(A), x, ctm_alg)[1], alg_gauge, x)[1] end end diff --git a/test/ctmrg/unitcell.jl b/test/ctmrg/unitcell.jl index 1992d807b..69540d1b8 100644 --- a/test/ctmrg/unitcell.jl +++ b/test/ctmrg/unitcell.jl @@ -1,7 +1,7 @@ using Test using Random using 
PEPSKit -using PEPSKit: _prev, _next, ctmrg_iteration, _fix_svd_algorithm +using PEPSKit: _prev, _next, ctmrg_iteration, fix_decomposition using TensorKit # settings @@ -39,11 +39,11 @@ function test_unitcell( @test expectation_value(peps, random_op, env′) isa Number # test if gauge fixing routines run through - _, signs = gauge_fix(env′, env″) + _, signs = gauge_fix(env″, ScramblingEnvGauge(), env′) @test signs isa Array return if ctm_alg isa SimultaneousCTMRG # also test :fixed mode gauge fixing for simultaneous CTMRG - svd_alg_fixed_full = _fix_svd_algorithm(SVDAdjoint(; fwd_alg = (; alg = :sdd)), signs, info) - svd_alg_fixed_iter = _fix_svd_algorithm(SVDAdjoint(; fwd_alg = (; alg = :iterative)), signs, info) + svd_alg_fixed_full = fix_decomposition(SVDAdjoint(; fwd_alg = (; alg = :sdd)), signs, info) + svd_alg_fixed_iter = fix_decomposition(SVDAdjoint(; fwd_alg = (; alg = :iterative)), signs, info) @test svd_alg_fixed_full isa SVDAdjoint @test svd_alg_fixed_iter isa SVDAdjoint end From 0ec910b7c0d593c199636a5b9cd278607699f7c9 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Wed, 21 Jan 2026 20:00:55 +0100 Subject: [PATCH 02/14] Add C4v CTMRG algorithm --- src/algorithms/ctmrg/c4v.jl | 98 +++++++ src/algorithms/ctmrg/ctmrg.jl | 4 +- src/algorithms/ctmrg/projectors.jl | 19 +- src/utility/eig.jl | 398 +++++++++++++++++++++++++++++ 4 files changed, 508 insertions(+), 11 deletions(-) create mode 100644 src/algorithms/ctmrg/c4v.jl create mode 100644 src/utility/eig.jl diff --git a/src/algorithms/ctmrg/c4v.jl b/src/algorithms/ctmrg/c4v.jl new file mode 100644 index 000000000..2b336a320 --- /dev/null +++ b/src/algorithms/ctmrg/c4v.jl @@ -0,0 +1,98 @@ +struct C4vCTMRG{P <: ProjectorAlgorithm} <: CTMRGAlgorithm + tol::Float64 + maxiter::Int + miniter::Int + verbosity::Int + projector_alg::P +end +function C4vCTMRG(; kwargs...) + return CTMRGAlgorithm(; alg = :c4v, kwargs...) 
+end +CTMRG_SYMBOLS[:c4v] = C4vCTMRG + +struct C4vEighProjector{S <: EighAdjoint, T} <: ProjectorAlgorithm + alg::S + trunc::T + verbosity::Int +end +function C4vEighProjector(; kwargs...) + return ProjectorAlgorithm(; alg = :c4v_eigh, kwargs...) +end +decomposition_algorithm(alg::C4vEighProjector) = alg.alg +PROJECTOR_SYMBOLS[:c4v_eigh] = C4vEighProjector + +struct C4vQRProjector{S, T} <: ProjectorAlgorithm + alg::S + verbosity::Int +end +function C4vQRProjector(; kwargs...) + return ProjectorAlgorithm(; alg = :c4v_qr, kwargs...) +end +decomposition_algorithm(alg::C4vEighProjector) = alg.alg +PROJECTOR_SYMBOLS[:c4v_qr] = C4vQRProjector + +function ctmrg_iteration( + network, + env::CTMRGEnv, + alg::C4vCTMRG, + ) + enlarged_corner = TensorMap(EnlargedCorner(network, env, (NORTHWEST, 1, 1))) + corner′, projector, info = c4v_projector(enlarged_corner, alg.projector_alg) + edge′ = c4v_renormalize(network, env, projector) + return _c4v_env(corner′, edge′), info +end + +function c4v_projector(enlarged_corner, alg::C4vEighProjector) + hermitian_corner = 0.5 * (enlarged_corner + enlarged_corner') / norm(enlarged_corner) + trunc = truncation_strategy(alg, enlarged_corner) + D, V, info = eigh_trunc!(hermitian_corner, decomposition_algorithm(alg); trunc) + + # Check for degenerate eigenvalues + Zygote.isderiving() && ignore_derivatives() do + if alg.verbosity > 0 && is_degenerate_spectrum(D) + vals = TensorKit.SectorDict(c => diag(b) for (c, b) in blocks(D)) + @warn("degenerate eigenvalues detected: ", vals) + end + end + + return D / norm(D), V, (; D, V, info...) 
+end + +function c4v_projector(enlarged_corner, alg::C4vQRProjector) + # TODO +end + +function c4v_renormalize(network, env, projector) + new_edge = renormalize_north_edge(env.edges[1], projector, projector', network[1, 1]) + return new_edge / norm(new_edge) +end + +# TODO: this won't differentiate properly probably due to custom CTMRGEnv rrule defined in PEPSKit +function CTMRGEnv(corner::CornerTensor, edge::EdgeTensor) + corners = fill(corner, 4, 1, 1) + edge_SW = physical_flip(edge) + edges = reshape([edge, edge, edge_SW, edge_SW], (4, 1, 1)) + return CTMRGEnv(corners, edges) +end + +function _c4v_env(corner::CornerTensor, edge::EdgeTensor) + corners = fill(corner, 4, 1, 1) + edge_SW = physical_flip(edge) + edges = reshape([edge, edge, edge_SW, edge_SW], (4, 1, 1)) + return CTMRGEnv(corners, edges) +end + +# environment with dummy corner singlet(V) ← singlet(V) and identity edge V ← V, initialized at dim(Venv) +function initialize_singlet_c4v_env(Vpeps::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) + corner₀ = DiagonalTensorMap(zeros(real(T), Venv ← Venv)) + corner₀.data[1] = one(real(T)) + edge₀ = permute(id(T, Venv ⊗ Vpeps), ((1, 2, 4), (3,))) + return CTMRGEnv(corner₀, edge₀) +end + +function initialize_random_c4v_env(Vpeps::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) + corner₀ = DiagonalTensorMap(randn(real(T), Venv ← Venv)) + edge₀ = randn(T, Venv ⊗ Vpeps ⊗ Vpeps' ← Venv) + edge₀ = project_hermitian(edge₀) + return CTMRGEnv(corner₀, edge₀) +end diff --git a/src/algorithms/ctmrg/ctmrg.jl b/src/algorithms/ctmrg/ctmrg.jl index fb17aa45b..2ac7684cf 100644 --- a/src/algorithms/ctmrg/ctmrg.jl +++ b/src/algorithms/ctmrg/ctmrg.jl @@ -19,7 +19,7 @@ function CTMRGAlgorithm(; maxiter = Defaults.ctmrg_maxiter, miniter = Defaults.ctmrg_miniter, verbosity = Defaults.ctmrg_verbosity, trunc = (; alg = Defaults.trunc), - svd_alg = (;), + decomposition_alg = (;), projector_alg = Defaults.projector_alg, # only allows for Symbol/NamedTuple to expose 
projector kwargs ) # replace symbol with projector alg type @@ -29,7 +29,7 @@ function CTMRGAlgorithm(; # parse CTMRG projector algorithm projector_algorithm = ProjectorAlgorithm(; - alg = projector_alg, svd_alg, trunc, verbosity + alg = projector_alg, decomposition_alg, trunc, verbosity ) return alg_type(tol, maxiter, miniter, verbosity, projector_algorithm) diff --git a/src/algorithms/ctmrg/projectors.jl b/src/algorithms/ctmrg/projectors.jl index 073ccfc25..befd90434 100644 --- a/src/algorithms/ctmrg/projectors.jl +++ b/src/algorithms/ctmrg/projectors.jl @@ -14,7 +14,7 @@ Keyword argument parser returning the appropriate `ProjectorAlgorithm` algorithm """ function ProjectorAlgorithm(; alg = Defaults.projector_alg, - svd_alg = (;), + decomposition_alg = (;), trunc = (;), verbosity = Defaults.projector_verbosity, ) @@ -24,7 +24,7 @@ function ProjectorAlgorithm(; alg_type = PROJECTOR_SYMBOLS[alg] # parse SVD forward & rrule algorithm - svd_algorithm = _alg_or_nt(SVDAdjoint, svd_alg) + decomposition_algorithm = _alg_or_nt(SVDAdjoint, decomposition_alg) # TODO: generalize this to DecompositionAdjoint # parse truncation scheme truncation_strategy = if trunc isa TruncationStrategy @@ -35,13 +35,14 @@ function ProjectorAlgorithm(; throw(ArgumentError("unknown trunc $trunc")) end - return alg_type(svd_algorithm, truncation_strategy, verbosity) + return alg_type(decomposition_algorithm, truncation_strategy, verbosity) end -function svd_algorithm(alg::ProjectorAlgorithm, (dir, r, c)) - if alg.svd_alg isa SVDAdjoint{<:FixedSVD} - fwd_alg = alg.svd_alg.fwd_alg - fix_svd = if isfullsvd(alg.svd_alg.fwd_alg) +function decomposition_algorithm(alg::ProjectorAlgorithm, (dir, r, c)) + decomp_alg = decomposition_algorithm(alg) + if decomp_alg isa SVDAdjoint{<:FixedSVD} + fwd_alg = decomp_alg.fwd_alg + fix_svd = if isfullsvd(decomp_alg.fwd_alg) FixedSVD( fwd_alg.U[dir, r, c], fwd_alg.S[dir, r, c], fwd_alg.V[dir, r, c], fwd_alg.U_full[dir, r, c], fwd_alg.S_full[dir, r, c], 
fwd_alg.V_full[dir, r, c], @@ -52,9 +53,9 @@ function svd_algorithm(alg::ProjectorAlgorithm, (dir, r, c)) nothing, nothing, nothing, ) end - return SVDAdjoint(; fwd_alg = fix_svd, rrule_alg = alg.svd_alg.rrule_alg) + return SVDAdjoint(; fwd_alg = fix_svd, rrule_alg = decomp_alg.rrule_alg) else - return alg.svd_alg + return decomp_alg end end diff --git a/src/utility/eig.jl b/src/utility/eig.jl new file mode 100644 index 000000000..3728c61dc --- /dev/null +++ b/src/utility/eig.jl @@ -0,0 +1,398 @@ +using MatrixAlgebraKit: TruncationStrategy, NoTruncation, LAPACK_EighAlgorithm, truncate +using MatrixAlgebraKit: eigh_pullback!, eigh_trunc_pullback!, findtruncated, diagview +using TensorKit: AdjointTensorMap, SectorDict, Factorizations.TruncationSpace, + throw_invalid_innerproduct, similarstoragetype +using KrylovKit: Lanczos, BlockLanczos +const KrylovKitCRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreExt) + +""" +$(TYPEDEF) + +Eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_pullback!`. + +## Fields + +$(TYPEDFIELDS) + +## Constructors + + FullEighPullback(; kwargs...) + +Construct a `FullEighPullback` algorithm struct from the following keyword arguments: + +* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. +""" +@kwdef struct FullEighPullback + verbosity::Int = 1 +end + +""" +$(TYPEDEF) + +Truncated eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_trunc_pullback!`. + +## Fields + +$(TYPEDFIELDS) + +## Constructors + + TruncEighPullback(; kwargs...) + +Construct a `TruncEighPullback` algorithm struct from the following keyword arguments: + +* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. +""" +@kwdef struct TruncEighPullback + verbosity::Int = 1 +end + +abstract type DecompositionAdjoint end + +# TODO: should this be same struct as SVDAdjoint? 
+""" +$(TYPEDEF) + +Wrapper for a eigenvalue decomposition algorithm `fwd_alg` with a defined reverse rule `rrule_alg`. +If `isnothing(rrule_alg)`, Zygote differentiates the forward call automatically. + +## Fields + +$(TYPEDFIELDS) + +## Constructors + + EighAdjoint(; kwargs...) + +Construct a `EighAdjoint` algorithm struct based on the following keyword arguments: + +* `fwd_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.eigh_fwd_alg))`: Eig algorithm of the forward pass which can either be passed as an `Algorithm` instance or a `NamedTuple` where `alg` is one of the following: +* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.eigh_rrule_alg))`: Reverse-rule algorithm for differentiating the eigenvalue decomposition. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following: +""" +struct EighAdjoint{F, R} <: DecompositionAdjoint + fwd_alg::F + rrule_alg::R +end # Keep truncation algorithm separate to be able to specify CTMRG dependent information + +const EIG_FWD_SYMBOLS = IdDict{Symbol, Any}( + :qriteration => LAPACK_QRIteration, + :bisection => LAPACK_Bisection, + :divideandconquer => LAPACK_DivideAndConquer, + :multiple => LAPACK_MultipleRelativelyRobustRepresentations, + :lanczos => + (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> + IterEig(; alg = Lanczos(; tol, krylovdim), kwargs...), + :blocklanczos => + (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> + IterEig(; alg = BlockLanczos(; tol, krylovdim), kwargs...), +) +const EIG_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( + :full => FullEighPullback, :trunc => TruncEighPullback, + # :gmres => GMRES, :bicgstab => BiCGStab, :arnoldi => Arnoldi +) + +function EighAdjoint(; fwd_alg = (;), rrule_alg = (;)) + # parse forward algorithm + fwd_algorithm = if fwd_alg isa NamedTuple + fwd_kwargs = (; alg = Defaults.eig_fwd_alg, fwd_alg...) 
# overwrite with specified kwargs + haskey(EIG_FWD_SYMBOLS, fwd_kwargs.alg) || + throw(ArgumentError("unknown forward algorithm: $(fwd_kwargs.alg)")) + fwd_type = EIG_FWD_SYMBOLS[fwd_kwargs.alg] + fwd_kwargs = Base.structdiff(fwd_kwargs, (; alg = nothing)) # remove `alg` keyword argument + fwd_type(; fwd_kwargs...) + else + fwd_alg + end + + # parse reverse-rule algorithm + rrule_algorithm = if rrule_alg isa NamedTuple + rrule_kwargs = (; + alg = Defaults.eig_rrule_alg, + # tol = Defaults.svd_rrule_tol, # ignore GMRES/BiCGStab/Arnoldi for the moment + # krylovdim = Defaults.svd_rrule_min_krylovdim, + # broadening = Defaults.svd_rrule_broadening, + verbosity = Defaults.eig_rrule_verbosity, + rrule_alg..., + ) # overwrite with specified kwargs + + haskey(EIG_RRULE_SYMBOLS, rrule_kwargs.alg) || + throw(ArgumentError("unknown rrule algorithm: $(rrule_kwargs.alg)")) + rrule_type = EIG_RRULE_SYMBOLS[rrule_kwargs.alg] + if rrule_type <: Union{FullEighPullback, TruncEighPullback} + rrule_kwargs = (; rrule_kwargs.verbosity) + end + + rrule_type(; rrule_kwargs...) + else + rrule_alg + end + + return EighAdjoint(fwd_algorithm, rrule_algorithm) +end + +""" + eigh_trunc(t, alg::EighAdjoint; trunc=notrunc()) + eigh_trunc!(t, alg::EighAdjoint; trunc=notrunc()) + +Wrapper around `eigh_trunc(!)` which dispatches on the `EighAdjoint` algorithm. +This is needed since a custom adjoint may be defined, depending on the `alg`. +""" +MatrixAlgebraKit.eigh_trunc(t, alg::EighAdjoint; kwargs...) = eigh_trunc!(copy(t), alg; kwargs...) 
+function MatrixAlgebraKit.eigh_trunc!(t, alg::EighAdjoint; trunc = notrunc()) + return _eigh_trunc!(t, alg.fwd_alg, trunc) +end +function MatrixAlgebraKit.eigh_trunc!( + t::AdjointTensorMap, alg::EighAdjoint; trunc = notrunc() + ) + D, V, info = eigh_trunc!(adjoint(t), alg; trunc) + return adjoint(D), adjoint(V), info +end + +## Forward algorithms + +# Truncated eigh but also return full D and V to make it compatible with :fixed mode +function _eigh_trunc!( + t::TensorMap, + alg::LAPACK_EighAlgorithm, + trunc::TruncationStrategy, + ) + D, V = eigh_full!(t; alg) + D̃, Ṽ, truncerror = _truncate_eigh((D, V), trunc) + + # construct info NamedTuple + condnum = cond(D) + info = (; + truncation_error = truncerror, condition_number = condnum, D_full = D, V_full = V, + ) + return D̃, Ṽ, info +end + +# hacky way of computing the truncation error for current version of eigh_trunc! +# TODO: replace once TensorKit updates to new MatrixAlgebraKit which returns truncation error as well +function _truncate_eigh((D, V), trunc::TruncationStrategy) + if !(trunc isa NoTruncation) && !isempty(blocksectors(D)) + D̃, Ṽ = truncate(eigh_trunc!, (D, V), trunc)[1] + truncerror = sqrt(abs(norm(D)^2 - norm(D̃)^2)) + return D̃, Ṽ, truncerror + else + return D, V, zero(real(scalartype(D))) + end +end + +""" +$(TYPEDEF) + +Eigenvalue decomposition struct containing a pre-computed decomposition or even multiple ones. +Additionally, it can contain the untruncated full decomposition as well. The call to +`eigh_trunc`/`eig_trunc` just returns the pre-computed D and V. In the reverse pass, +the adjoint is computed with these exact D and V and, potentially, the full decompositions +if the adjoints needs access to them. 
+ +## Fields + +$(TYPEDFIELDS) +""" +struct FixedEig{Dt, Vt, Dtf, Vtf} + D::Dt + V::Vt + D_full::Dtf + V_full::Vtf +end + +# check whether the full D and V are supplied +function isfulleig(alg::FixedEig) + if isnothing(alg.D_full) || isnothing(alg.V_full) + return false + else + return true + end +end + +# Return pre-computed decomposition +function _eigh_trunc!(_, alg::FixedEig, ::TruncationStrategy) + info = (; + truncation_error = zero(real(scalartype(alg.D))), + condition_number = cond(alg.D), + D_full = alg.D_full, + V_full = alg.V_full, + ) + return alg.D, alg.V, info +end + + +""" +$(TYPEDEF) + +Iterative eigenvalue solver based on KrylovKit's `eigsolve`, adapted to (symmetric) tensors. +The number of targeted eigenvalues is set via the `truncspace` in `ProjectorAlg`. +In particular, this makes it possible to specify the targeted eigenvalues block-wise. +In case the symmetry block is too small as compared to the number of singular values, or +the iterative decomposition didn't converge, the algorithm falls back to a dense `eigh`/`eigh`. + +## Fields + +$(TYPEDFIELDS) + +## Constructors + + IterEig(; kwargs...) + +Construct an `IterEig` algorithm struct based on the following keyword arguments: + +* `alg=KrylovKit.Lanczos(; tol=1e-14, krylovdim=25)` : KrylovKit algorithm struct for iterative eigenvalue decomposition. +* `fallback_threshold::Float64=Inf` : Threshold for `howmany / minimum(size(block))` above which (if the block is too small) the algorithm falls back to a dense decomposition. +* `start_vector=random_start_vector` : Function providing the initial vector for the iterative algorithm. 
+""" +@kwdef struct IterEig + alg = KrylovKit.Lanczos(; tol = 1.0e-14, krylovdim = 25) + fallback_threshold::Float64 = Inf + start_vector = random_start_vector +end + +# Compute eigh data block-wise using KrylovKit algorithm +function _eigh_trunc!(f, alg::IterEig, trunc::TruncationStrategy) + D, V = if isempty(blocksectors(f)) + # early return + truncation_error = zero(real(scalartype(f))) + MatrixAlgebraKit.initialize_output(eigh_full!, f, LAPACK_QRIteration()) # specified algorithm doesn't matter here + else + eighdata, dims = _compute_eighdata!(f, alg, trunc) + _create_eightensors(f, eighdata, dims) + end + + # construct info NamedTuple + truncation_error = + trunc isa NoTruncation ? abs(zero(scalartype(f))) : norm(V * D * V' - f) + condition_number = cond(D) + info = (; truncation_error, condition_number, D_full = nothing, V_full = nothing) + + return D, V, info +end + +# Obtain sparse decomposition from block-wise eigsolve calls +function _compute_eighdata!( + f, alg::IterEig, trunc::Union{NoTruncation, TruncationSpace} + ) + InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh_trunc!) + domain(f) == codomain(f) || + throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same")) + I = sectortype(f) + dims = SectorDict{I, Int}() + + sectors = trunc isa NoTruncation ? blocksectors(f) : blocksectors(trunc.space) + generator = Base.Iterators.map(sectors) do c + b = block(f, c) + howmany = trunc isa NoTruncation ? 
minimum(size(b)) : blockdim(trunc.space, c) + + if howmany / minimum(size(b)) > alg.fallback_threshold # Use dense decomposition for small blocks + D, V = eigh_full!(b, LAPACK_QRIteration()) + lm_ordering = sortperm(abs.(D.diag); rev = true) # order values and vectors consistently with eigsolve + D = D.diag[lm_ordering] # extracts diagonal as Vector instead of Diagonal to make compatible with D of svdsolve + V = view(V, lm_ordering)[:, 1:howmany] + else + x₀ = alg.start_vector(b) + eig_alg = alg.alg + if howmany > alg.alg.krylovdim + eig_alg = @set eig_alg.krylovdim = round(Int, howmany * 1.2) + end + D, lvecs, info = eigsolve(b, x₀, howmany, :LM, eig_alg) + if info.converged < howmany # Fall back to dense SVD if not properly converged + @warn "Iterative eigendecomposition did not converge for block $c, falling back to eigh_full" + D, V = eigh_full!(b, LAPACK_QRIteration()) + lm_ordering = sortperm(abs.(D.diag); rev = true) + D = D.diag[lm_ordering] + V = view(V, lm_ordering)[:, 1:howmany] + else # Slice in case more values were converged than requested + V = stack(view(lvecs, 1:howmany)) + end + end + + resize!(D, howmany) + dims[c] = length(D) + return c => (D, V) + end + + eigdata = SectorDict(generator) + return eigdata, dims +end + +# Create eigh TensorMaps from sparse SectorDict +function _create_eightensors(t::TensorMap, eighdata, dims) + InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!) 
+ + T = scalartype(t) + S = spacetype(t) + W = S(dims) + + Tr = real(T) + A = similarstoragetype(t, Tr) + D = DiagonalTensorMap{Tr, S, A}(undef, W) + V = similar(t, domain(t) ← W) + for (c, (Dc, Vc)) in eighdata + r = Base.OneTo(dims[c]) + copy!(block(D, c), Diagonal(view(Dc, r))) + copy!(block(V, c), view(Vc, :, r)) + end + return D, V +end + +## Reverse-rule algorithms +function _get_pullback_gauge_tol(verbosity::Int) + if verbosity ≤ 0 # never print gauge sensitivity + return (_) -> Inf + elseif verbosity == 1 # print gauge sensitivity above default atol + MatrixAlgebraKit.default_pullback_gaugetol + else # always print gauge sensitivity + return (_) -> 0.0 + end +end + +# eigh_trunc! rrule wrapping MatrixAlgebraKit's eigh_pullback! +function ChainRulesCore.rrule( + ::typeof(eigh_trunc!), + t::AbstractTensorMap, + alg::EighAdjoint{F, R}; + trunc::TruncationStrategy = notrunc(), + ) where {F <: Union{<:LAPACK_EighAlgorithm, <:FixedEig}, R <: FullEighPullback} + D̃, Ṽ, info = eigh_trunc(t, alg; trunc) + D, V = info.D_full, info.V_full # untruncated decomposition + inds = findtruncated(diagview(D), truncspace(only(domain(D̃)))) + gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) + + function eigh_trunc!_full_pullback(ΔDV) + Δt = eigh_pullback!( # TODO: does this work by now? + zeros(scalartype(t), space(t)), t, (D, V), ΔDV, inds; + gauge_atol = gtol(ΔDV) + ) + return NoTangent(), Δt, NoTangent() + end + function eigh_trunc!_full_pullback(::Tuple{ZeroTangent, ZeroTangent}) + return NoTangent(), ZeroTangent(), NoTangent() + end + + return (D̃, Ṽ, info), eigh_trunc!_full_pullback +end + +# eigh_trunc! rrule wrapping MatrixAlgebraKit's eigh_trunc_pullback! 
(also works for IterEig) +function ChainRulesCore.rrule( + ::typeof(eigh_trunc!), + t, + alg::EighAdjoint{F, R}; + trunc::TruncationStrategy = notrunc(), + ) where {F <: Union{<:LAPACK_EighAlgorithm, <:FixedEig, IterEig}, R <: TruncEighPullback} + D, V, info = eigh_trunc(t, alg; trunc) + gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) + + function eigh_trunc!_trunc_pullback(ΔDV) + Δf = eigh_trunc_pullback!( + zeros(scalartype(t), space(t)), t, (D, V), ΔDV; + gauge_atol = gtol(ΔDV) + ) + return NoTangent(), Δf, NoTangent() + end + function eigh_trunc!_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent}) + return NoTangent(), ZeroTangent(), NoTangent() + end + + return (D, V, info), eigh_trunc!_trunc_pullback +end From 7256f3f3cdda0a48955a11d6bf7dbcfbf599fd88 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Wed, 21 Jan 2026 20:01:13 +0100 Subject: [PATCH 03/14] Add C4v CTMRG rrules and defaults --- src/Defaults.jl | 5 ++ src/PEPSKit.jl | 2 + .../fixed_point_differentiation.jl | 78 +++++++++++++++++-- src/utility/svd.jl | 7 +- 4 files changed, 79 insertions(+), 13 deletions(-) diff --git a/src/Defaults.jl b/src/Defaults.jl index ae05dfe5d..7cdab9f96 100644 --- a/src/Defaults.jl +++ b/src/Defaults.jl @@ -101,6 +101,11 @@ const svd_rrule_alg = :full # ∈ {:full, :gmres, :bicgstab, :arnoldi} const svd_rrule_broadening = 1.0e-13 const krylovdim_factor = 1.4 +# eigh forward & reverse +const eigh_fwd_alg = :qriteration +const eigh_rrule_alg = :trunc +const eigh_rrule_verbosity = 0 + # Projectors const projector_alg = :halfinfinite # ∈ {:halfinfinite, :fullinfinite} const projector_verbosity = 0 diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index 6047b2068..787167a0e 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -29,6 +29,7 @@ include("Defaults.jl") # Include first to allow for docstring interpolation wit include("utility/util.jl") include("utility/diffable_threads.jl") +include("utility/eig.jl") include("utility/svd.jl") include("utility/rotations.jl") 
include("utility/hook_pullback.jl") @@ -71,6 +72,7 @@ include("algorithms/ctmrg/projectors.jl") include("algorithms/ctmrg/simultaneous.jl") include("algorithms/ctmrg/sequential.jl") include("algorithms/ctmrg/gaugefix.jl") +include("algorithms/ctmrg/c4v.jl") include("algorithms/truncation/truncationschemes.jl") include("algorithms/truncation/fullenv_truncation.jl") diff --git a/src/algorithms/optimization/fixed_point_differentiation.jl b/src/algorithms/optimization/fixed_point_differentiation.jl index 8c11a0d33..2e568fe6f 100644 --- a/src/algorithms/optimization/fixed_point_differentiation.jl +++ b/src/algorithms/optimization/fixed_point_differentiation.jl @@ -217,13 +217,14 @@ function _rrule( ) env, info = leading_boundary(envinit, state, alg) alg_fixed = @set alg.projector_alg.trunc = FixedSpaceTruncation() # fix spaces during differentiation + alg_gauge = ScramblingEnvGauge() # TODO: make this a field in GradMode? function leading_boundary_diffgauge_pullback((Δenv′, Δinfo)) Δenv = unthunk(Δenv′) # find partial gradients of gauge-fixed single CTMRG iteration function f(A, x) - return gauge_fix(x, ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed)[1])[1] + return gauge_fix(ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed)[1], alg_gauge, x)[1] end _, env_vjp = rrule_via_ad(config, f, state, env) @@ -249,13 +250,12 @@ function _rrule( ) env, = leading_boundary(envinit, state, alg) alg_fixed = @set alg.projector_alg.trunc = FixedSpaceTruncation() # fix spaces during differentiation + alg_gauge = ScramblingEnvGauge() env_conv, info = ctmrg_iteration(InfiniteSquareNetwork(state), env, alg_fixed) - env_fixed, signs = gauge_fix(env, env_conv) + env_fixed, signs = gauge_fix(env_conv, alg_gauge, env) # Fix SVD - svd_alg_fixed = _fix_svd_algorithm(alg.projector_alg.svd_alg, signs, info) - alg_fixed = @set alg.projector_alg.svd_alg = svd_alg_fixed - alg_fixed = @set alg_fixed.projector_alg.trunc = notrunc() + alg_fixed = gauge_fix(alg, signs, info) function 
leading_boundary_fixed_pullback((Δenv′, Δinfo)) Δenv = unthunk(Δenv′) @@ -278,7 +278,7 @@ function _rrule( return (env_fixed, info), leading_boundary_fixed_pullback end -function _fix_svd_algorithm(alg::SVDAdjoint, signs, info) +function fix_decomposition(alg::SVDAdjoint, signs, info) # embed gauge signs in larger space to fix gauge of full U and V on truncated subspace rowsize, colsize = size(signs, 2), size(signs, 3) signs_full = map(Iterators.product(1:4, 1:rowsize, 1:colsize)) do (dir, r, c) @@ -311,7 +311,7 @@ function _fix_svd_algorithm(alg::SVDAdjoint, signs, info) rrule_alg = alg.rrule_alg, ) end -function _fix_svd_algorithm(alg::SVDAdjoint{F}, signs, info) where {F <: IterSVD} +function fix_decomposition(alg::SVDAdjoint{F}, signs, info) where {F <: IterSVD} # fix kept U and V only since iterative SVD doesn't have access to full spectrum U_fixed, V_fixed = fix_relative_phases(info.U, info.V, signs) return SVDAdjoint(; @@ -319,6 +319,70 @@ function _fix_svd_algorithm(alg::SVDAdjoint{F}, signs, info) where {F <: IterSVD rrule_alg = alg.rrule_alg, ) end +function fix_decomposition(alg::EighAdjoint, signs, info) + # embed gauge signs in larger space to fix gauge of full V on truncated subspace + σ = signs[1] + extended_σ = zeros(scalartype(σ), space(info.D_full)) + for (c, b) in blocks(extended_σ) + σc = block(σ, c) + kept_dim = size(σc, 1) + b[diagind(b)] .= one(scalartype(σ)) # put ones on the diagonal + b[1:kept_dim, 1:kept_dim] .= σc # set to σ on kept subspace + end + + # fix kept and full U + U_fixed = info.U * σ' + U_full_fixed = info.U_full * extended_σ' + return EighAdjoint(; + fwd_alg = FixedEig(info.D, U_fixed, info.D_full, U_full_fixed), + rrule_alg = alg.rrule_alg, + ) +end +function fix_decomposition(alg::EighAdjoint{F}, signs, info) where {F <: IterEig} + # fix kept U only since iterative decomposition doesn't have access to full spectrum + U_fixed = info.U * signs[1]' + return EighAdjoint(; + fwd_alg = FixedEig(info.D, U_fixed, nothing, 
nothing), + rrule_alg = alg.rrule_alg, + ) +end + +# nested fixed-point gradient evaluation for C4v CTMRG +function _rrule( + gradmode::GradMode{:fixed}, + config::RuleConfig, + ::typeof(MPSKit.leading_boundary), + env₀, + state, + alg::C4vCTMRG, + ) + env, = leading_boundary(env₀, state, alg) + alg_fixed = @set alg.projector_alg.trunc = FixedSpaceTruncation() # fix spaces during differentiation + alg_gauge = ScramblingEnvGaugeC4v() + env_conv, info = ctmrg_iteration(InfiniteSquareNetwork(state), env, alg_fixed) + _, signs = gauge_fix(env_conv, alg_gauge, env) + + # Fix eigendecomposition + alg_fixed = gauge_fix(alg, signs, info) + + function leading_boundary_fixed_pullback((Δenv′, Δinfo)) + Δenv = unthunk(Δenv′) + + f(A, x) = ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed)[1] + _, env_vjp = rrule_via_ad(config, f, state, env) + + # evaluate the geometric sum + ∂f∂A(x)::typeof(state) = env_vjp(x)[2] + # ∂f∂x(x)::typeof(env) = env_vjp(x)[3] # TODO: why is this derivative type-instable? 
The corner gradient is a complex DiagonalTensorMap + ∂f∂x(x) = env_vjp(x)[3] + ∂F∂env = fpgrad(Δenv, ∂f∂x, ∂f∂A, Δenv, gradmode) + + return NoTangent(), ZeroTangent(), ∂F∂env, NoTangent() + end + + # TODO: also return env (instead of `env_fixed`) for general :fixed mode CTMRG in PEPSKit + return (env, info), leading_boundary_fixed_pullback +end @doc """ fpgrad(∂F∂x, ∂f∂x, ∂f∂A, y0, alg) diff --git a/src/utility/svd.jl b/src/utility/svd.jl index 27aa8a6c2..d8bc45803 100644 --- a/src/utility/svd.jl +++ b/src/utility/svd.jl @@ -1,8 +1,3 @@ -using MatrixAlgebraKit: NoTruncation, truncate -using TensorKit: AdjointTensorMap, SectorDict, Factorizations.TruncationSpace, - throw_invalid_innerproduct, similarstoragetype -const KrylovKitCRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreExt) - """ $(TYPEDEF) @@ -53,7 +48,7 @@ Construct a `SVDAdjoint` algorithm struct based on the following keyword argumen - `:bicgstab`: BiCGStab iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BiCGStab) for details - `:arnoldi`: Arnoldi Krylov algorithm, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Arnoldi) for details """ -struct SVDAdjoint{F, R} +struct SVDAdjoint{F, R} <: DecompositionAdjoint fwd_alg::F rrule_alg::R end # Keep truncation algorithm separate to be able to specify CTMRG dependent information From 1a326b9106b64f21e60119852ca0c5e40cebdb12 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Thu, 22 Jan 2026 18:11:48 +0100 Subject: [PATCH 04/14] Fix tensor decompositions --- src/utility/eigh.jl | 392 ++++++++++++++++++++++++++++++++++++++++++++ src/utility/qr.jl | 68 ++++++++ src/utility/svd.jl | 2 +- 3 files changed, 461 insertions(+), 1 deletion(-) create mode 100644 src/utility/eigh.jl create mode 100644 src/utility/qr.jl diff --git a/src/utility/eigh.jl b/src/utility/eigh.jl new file mode 100644 index 000000000..3ee86130b --- /dev/null +++ 
b/src/utility/eigh.jl @@ -0,0 +1,392 @@ +using MatrixAlgebraKit: TruncationStrategy, NoTruncation, LAPACK_EighAlgorithm, truncate +using MatrixAlgebraKit: eigh_pullback!, eigh_trunc_pullback!, findtruncated, diagview +using TensorKit: AdjointTensorMap, SectorDict, Factorizations.TruncationSpace, + throw_invalid_innerproduct, similarstoragetype +using KrylovKit: Lanczos, BlockLanczos +const KrylovKitCRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreExt) + +""" +$(TYPEDEF) + +Eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_pullback!`. + +## Fields + +$(TYPEDFIELDS) + +## Constructors + + FullEighPullback(; kwargs...) + +Construct a `FullEighPullback` algorithm struct from the following keyword arguments: + +* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. +""" +@kwdef struct FullEighPullback + verbosity::Int = 1 +end + +""" +$(TYPEDEF) + +Truncated eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_trunc_pullback!`. + +## Fields + +$(TYPEDFIELDS) + +## Constructors + + TruncEighPullback(; kwargs...) + +Construct a `TruncEighPullback` algorithm struct from the following keyword arguments: + +* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. +""" +@kwdef struct TruncEighPullback + verbosity::Int = 1 +end + +# TODO: should this be same struct as SVDAdjoint? +""" +$(TYPEDEF) + +Wrapper for a eigenvalue decomposition algorithm `fwd_alg` with a defined reverse rule `rrule_alg`. +If `isnothing(rrule_alg)`, Zygote differentiates the forward call automatically. + +## Fields + +$(TYPEDFIELDS) + +## Constructors + + EighAdjoint(; kwargs...) 
+ +Construct a `EighAdjoint` algorithm struct based on the following keyword arguments: + +* `fwd_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.eigh_fwd_alg))`: Eig algorithm of the forward pass which can either be passed as an `Algorithm` instance or a `NamedTuple` where `alg` is one of the following: +* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.eigh_rrule_alg))`: Reverse-rule algorithm for differentiating the eigenvalue decomposition. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following: +""" +struct EighAdjoint{F, R} + fwd_alg::F + rrule_alg::R +end # Keep truncation algorithm separate to be able to specify CTMRG dependent information + +const EIGH_FWD_SYMBOLS = IdDict{Symbol, Any}( + :qriteration => LAPACK_QRIteration, + :bisection => LAPACK_Bisection, + :divideandconquer => LAPACK_DivideAndConquer, + :multiple => LAPACK_MultipleRelativelyRobustRepresentations, + :lanczos => + (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> + IterEig(; alg = Lanczos(; tol, krylovdim), kwargs...), + :blocklanczos => + (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> + IterEig(; alg = BlockLanczos(; tol, krylovdim), kwargs...), +) +const EIGH_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( + :full => FullEighPullback, :trunc => TruncEighPullback, +) + +function EighAdjoint(; fwd_alg = (;), rrule_alg = (;)) + # parse forward algorithm + fwd_algorithm = if fwd_alg isa NamedTuple + fwd_kwargs = (; alg = Defaults.eigh_fwd_alg, fwd_alg...) # overwrite with specified kwargs + haskey(EIGH_FWD_SYMBOLS, fwd_kwargs.alg) || + throw(ArgumentError("unknown forward algorithm: $(fwd_kwargs.alg)")) + fwd_type = EIGH_FWD_SYMBOLS[fwd_kwargs.alg] + fwd_kwargs = Base.structdiff(fwd_kwargs, (; alg = nothing)) # remove `alg` keyword argument + fwd_type(; fwd_kwargs...) 
+ else + fwd_alg + end + + # parse reverse-rule algorithm + rrule_algorithm = if rrule_alg isa NamedTuple + rrule_kwargs = (; + alg = Defaults.eigh_rrule_alg, + verbosity = Defaults.eigh_rrule_verbosity, + rrule_alg..., + ) # overwrite with specified kwargs + + haskey(EIGH_RRULE_SYMBOLS, rrule_kwargs.alg) || + throw(ArgumentError("unknown rrule algorithm: $(rrule_kwargs.alg)")) + rrule_type = EIGH_RRULE_SYMBOLS[rrule_kwargs.alg] + if rrule_type <: Union{FullEighPullback, TruncEighPullback} + rrule_kwargs = (; rrule_kwargs.verbosity) + end + + rrule_type(; rrule_kwargs...) + else + rrule_alg + end + + return EighAdjoint(fwd_algorithm, rrule_algorithm) +end + +""" + eigh_trunc(t, alg::EighAdjoint; trunc=notrunc()) + eigh_trunc!(t, alg::EighAdjoint; trunc=notrunc()) + +Wrapper around `eigh_trunc(!)` which dispatches on the `EighAdjoint` algorithm. +This is needed since a custom adjoint may be defined, depending on the `alg`. +""" +MatrixAlgebraKit.eigh_trunc(t, alg::EighAdjoint; kwargs...) = eigh_trunc!(copy(t), alg; kwargs...) +function MatrixAlgebraKit.eigh_trunc!(t, alg::EighAdjoint; trunc = notrunc()) + return _eigh_trunc!(t, alg.fwd_alg, trunc) +end +function MatrixAlgebraKit.eigh_trunc!( + t::AdjointTensorMap, alg::EighAdjoint; trunc = notrunc() + ) + D, V, info = eigh_trunc!(adjoint(t), alg; trunc) + return adjoint(D), adjoint(V), info +end + +## Forward algorithms + +# Truncated eigh but also return full D and V to make it compatible with :fixed mode +function _eigh_trunc!( + t::TensorMap, + alg::LAPACK_EighAlgorithm, + trunc::TruncationStrategy, + ) + D, V = eigh_full!(t; alg) + D̃, Ṽ, truncerror = _truncate_eigh((D, V), trunc) + + # construct info NamedTuple + condnum = cond(D) + info = (; + truncation_error = truncerror, condition_number = condnum, D_full = D, V_full = V, + ) + return D̃, Ṽ, info +end + +# hacky way of computing the truncation error for current version of eigh_trunc! 
# TODO: replace once TensorKit updates to new MatrixAlgebraKit which returns truncation error as well
# Truncate the full decomposition `(D, V)` according to `trunc` and return
# `(D̃, Ṽ, truncerror)`. Without truncation, or for a tensor without any blocks, the
# input is passed through unchanged and the truncation error is zero.
function _truncate_eigh((D, V), trunc::TruncationStrategy)
    if trunc isa NoTruncation || isempty(blocksectors(D))
        return D, V, zero(real(scalartype(D)))
    end
    D̃, Ṽ = truncate(eigh_trunc!, (D, V), trunc)[1]
    # 2-norm of the discarded eigenvalues, obtained from the norm difference of the full
    # and the kept spectrum (`abs` presumably guards against negative round-off)
    truncerror = sqrt(abs(norm(D)^2 - norm(D̃)^2))
    return D̃, Ṽ, truncerror
end

"""
$(TYPEDEF)

Eigenvalue decomposition struct containing a pre-computed decomposition or even multiple ones.
Additionally, it can contain the untruncated full decomposition as well. The call to
`eigh_trunc`/`eigh_trunc!` just returns the pre-computed D and V. In the reverse pass,
the adjoint is computed with these exact D and V and, potentially, the full decompositions
if the adjoint needs access to them.

## Fields

$(TYPEDFIELDS)
"""
struct FixedEig{Dt, Vt, Dtf, Vtf}
    D::Dt
    V::Vt
    D_full::Dtf
    V_full::Vtf
end

# check whether the full D and V are supplied
isfulleig(alg::FixedEig) = !isnothing(alg.D_full) && !isnothing(alg.V_full)

# Return pre-computed decomposition; the input tensor and the truncation strategy are
# ignored and the reported truncation error is always zero
function _eigh_trunc!(_, alg::FixedEig, ::TruncationStrategy)
    info = (;
        truncation_error = zero(real(scalartype(alg.D))),
        condition_number = cond(alg.D),
        D_full = alg.D_full,
        V_full = alg.V_full,
    )
    return alg.D, alg.V, info
end


"""
$(TYPEDEF)

Iterative eigenvalue solver based on KrylovKit's `eigsolve`, adapted to (symmetric) tensors.
The number of targeted eigenvalues is set via the `truncspace` in `ProjectorAlg`.
In particular, this makes it possible to specify the targeted eigenvalues block-wise.
In case the symmetry block is too small as compared to the number of requested eigenvalues,
or the iterative decomposition didn't converge, the algorithm falls back to a dense
`eigh_full!`.

## Fields

$(TYPEDFIELDS)

## Constructors

    IterEig(; kwargs...)

Construct an `IterEig` algorithm struct based on the following keyword arguments:

* `alg=KrylovKit.Lanczos(; tol=1e-14, krylovdim=25)` : KrylovKit algorithm struct for iterative eigenvalue decomposition.
* `fallback_threshold::Float64=Inf` : Threshold for `howmany / minimum(size(block))` above which (if the block is too small) the algorithm falls back to a dense decomposition.
* `start_vector=random_start_vector` : Function providing the initial vector for the iterative algorithm.
"""
@kwdef struct IterEig
    alg = KrylovKit.Lanczos(; tol = 1.0e-14, krylovdim = 25)
    fallback_threshold::Float64 = Inf
    start_vector = random_start_vector
end

# Compute eigh data block-wise using KrylovKit algorithm.
# Returns `(D, V, info)`; since only part of the spectrum is computed, `info.D_full` and
# `info.V_full` are `nothing`.
function _eigh_trunc!(f, alg::IterEig, trunc::TruncationStrategy)
    D, V = if isempty(blocksectors(f))
        # nothing to decompose: only allocate the output tensors
        # (the algorithm specified here doesn't matter)
        MatrixAlgebraKit.initialize_output(eigh_full!, f, LAPACK_QRIteration())
    else
        eighdata, dims = _compute_eighdata!(f, alg, trunc)
        _create_eightensors(f, eighdata, dims)
    end

    # construct info NamedTuple
    # NOTE(review): the dense fallback in `_compute_eighdata!` calls the in-place
    # `eigh_full!` on blocks of `f`; confirm that `f` is still intact when the
    # reconstruction error is evaluated here
    truncation_error =
        trunc isa NoTruncation ? zero(real(scalartype(f))) : norm(V * D * V' - f)
    condition_number = cond(D)
    info = (; truncation_error, condition_number, D_full = nothing, V_full = nothing)

    return D, V, info
end

# Dense eigendecomposition of the block `b`, ordered consistently with `eigsolve(…, :LM, …)`:
# returns all eigenvalues sorted by decreasing magnitude (as a `Vector`, not a `Diagonal`,
# to match the values returned by `eigsolve`) and the eigenvectors belonging to the
# `howmany` largest-magnitude ones as the columns of a matrix
function _dense_eigh_lm!(b, howmany)
    D, V = eigh_full!(b, LAPACK_QRIteration())
    lm_ordering = sortperm(abs.(D.diag); rev = true)
    # select whole eigenvector *columns*; indexing `V` with `lm_ordering` alone would be
    # linear indexing and pick permuted entries of the first column instead
    return D.diag[lm_ordering], V[:, view(lm_ordering, 1:howmany)]
end

# Obtain sparse decomposition from block-wise eigsolve calls.
# Returns a `SectorDict` mapping each sector to `(eigenvalues, eigenvectors)` and a
# `SectorDict` with the number of kept eigenvalues per sector.
function _compute_eighdata!(
        f, alg::IterEig, trunc::Union{NoTruncation, TruncationSpace}
    )
    InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh_trunc!)
    domain(f) == codomain(f) ||
        throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same"))
    I = sectortype(f)
    dims = SectorDict{I, Int}()

    sectors = trunc isa NoTruncation ? blocksectors(f) : blocksectors(trunc.space)
    generator = Base.Iterators.map(sectors) do c
        b = block(f, c)
        howmany = trunc isa NoTruncation ? minimum(size(b)) : blockdim(trunc.space, c)

        D, V = if howmany / minimum(size(b)) > alg.fallback_threshold
            # Use dense decomposition for small blocks
            _dense_eigh_lm!(b, howmany)
        else
            x₀ = alg.start_vector(b)
            eig_alg = alg.alg
            if howmany > alg.alg.krylovdim
                # enlarge the Krylov subspace so that it can hold all requested vectors
                eig_alg = @set eig_alg.krylovdim = round(Int, howmany * 1.2)
            end
            vals, vecs, info = eigsolve(b, x₀, howmany, :LM, eig_alg)
            if info.converged < howmany # Fall back to dense eigh if not properly converged
                @warn "Iterative eigendecomposition did not converge for block $c, falling back to eigh_full"
                _dense_eigh_lm!(b, howmany)
            else # Slice in case more values were converged than requested
                vals, stack(view(vecs, 1:howmany))
            end
        end

        resize!(D, howmany) # drop surplus eigenvalues
        dims[c] = length(D)
        return c => (D, V)
    end

    # `generator` is lazy: `dims` is only filled while the `SectorDict` is constructed
    eighdata = SectorDict(generator)
    return eighdata, dims
end

# Create eigh TensorMaps from sparse SectorDict:
# `D` is a real `DiagonalTensorMap` on the space `W` spanned by `dims`, and `V` maps
# `W` into the domain of `t`
function _create_eightensors(t::TensorMap, eighdata, dims)
    InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!)

    T = scalartype(t)
    S = spacetype(t)
    W = S(dims)

    Tr = real(T)
    A = similarstoragetype(t, Tr)
    D = DiagonalTensorMap{Tr, S, A}(undef, W)
    V = similar(t, domain(t) ← W)
    for (c, (Dc, Vc)) in eighdata
        r = Base.OneTo(dims[c])
        copy!(block(D, c), Diagonal(view(Dc, r)))
        copy!(block(V, c), view(Vc, :, r))
    end
    return D, V
end

## Reverse-rule algorithms
# Map the rrule verbosity to a function `ΔDV -> gauge_atol` which is handed to the
# MatrixAlgebraKit pullbacks
function _get_pullback_gauge_tol(verbosity::Int)
    if verbosity ≤ 0 # never print gauge sensitivity
        return (_) -> Inf
    elseif verbosity == 1 # print gauge sensitivity above default atol
        return MatrixAlgebraKit.default_pullback_gaugetol
    else # always print gauge sensitivity
        return (_) -> 0.0
    end
end

# eigh_trunc! rrule wrapping MatrixAlgebraKit's eigh_pullback!
# (differentiates through the untruncated decomposition, hence `info.D_full` and
# `info.V_full` must be available)
function ChainRulesCore.rrule(
        ::typeof(eigh_trunc!),
        t::AbstractTensorMap,
        alg::EighAdjoint{F, R};
        trunc::TruncationStrategy = notrunc(),
    ) where {F <: Union{<:LAPACK_EighAlgorithm, <:FixedEig}, R <: FullEighPullback}
    D̃, Ṽ, info = eigh_trunc(t, alg; trunc)
    D, V = info.D_full, info.V_full # untruncated decomposition
    if isnothing(D) || isnothing(V) # e.g. a `FixedEig` that only stores the kept part
        throw(
            ArgumentError(
                "`FullEighPullback` requires the untruncated decomposition (`D_full`, `V_full`)"
            )
        )
    end
    # indices of the kept eigenvalues within the full spectrum
    inds = findtruncated(diagview(D), truncspace(only(domain(D̃))))
    gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity)

    function eigh_trunc!_full_pullback(ΔDV)
        Δt = eigh_pullback!( # TODO: does this work by now?
            zeros(scalartype(t), space(t)), t, (D, V), ΔDV, inds;
            gauge_atol = gtol(ΔDV)
        )
        return NoTangent(), Δt, NoTangent()
    end
    # short-circuit when no cotangent flows into either output
    function eigh_trunc!_full_pullback(::Tuple{ZeroTangent, ZeroTangent})
        return NoTangent(), ZeroTangent(), NoTangent()
    end

    return (D̃, Ṽ, info), eigh_trunc!_full_pullback
end

# eigh_trunc! rrule wrapping MatrixAlgebraKit's eigh_trunc_pullback!
(also works for IterEig) +function ChainRulesCore.rrule( + ::typeof(eigh_trunc!), + t, + alg::EighAdjoint{F, R}; + trunc::TruncationStrategy = notrunc(), + ) where {F <: Union{<:LAPACK_EighAlgorithm, <:FixedEig, IterEig}, R <: TruncEighPullback} + D, V, info = eigh_trunc(t, alg; trunc) + gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) + + function eigh_trunc!_trunc_pullback(ΔDV) + Δf = eigh_trunc_pullback!( + zeros(scalartype(t), space(t)), t, (D, V), ΔDV; + gauge_atol = gtol(ΔDV) + ) + return NoTangent(), Δf, NoTangent() + end + function eigh_trunc!_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent}) + return NoTangent(), ZeroTangent(), NoTangent() + end + + return (D, V, info), eigh_trunc!_trunc_pullback +end diff --git a/src/utility/qr.jl b/src/utility/qr.jl new file mode 100644 index 000000000..efd3a912d --- /dev/null +++ b/src/utility/qr.jl @@ -0,0 +1,68 @@ +""" +$(TYPEDEF) + +Wrapper for a QR decomposition algorithm `fwd_alg` with a defined reverse rule `rrule_alg`. +If `isnothing(rrule_alg)`, Zygote differentiates the forward call automatically. + +## Fields + +$(TYPEDFIELDS) + +## Constructors + + QRAdjoint(; kwargs...) + +Construct a `QRAdjoint` algorithm struct based on the following keyword arguments: + +* `fwd_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.qr_fwd_alg))`: Eig algorithm of the forward pass which can either be passed as an `Algorithm` instance or a `NamedTuple` where `alg` is one of the following: +* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.qr_rrule_alg))`: Reverse-rule algorithm for differentiating the eigenvalue decomposition. 
Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following: +""" +struct QRAdjoint{F, R} + fwd_alg::F + rrule_alg::R +end # Keep truncation algorithm separate to be able to specify CTMRG dependent information + +const QR_FWD_SYMBOLS = IdDict{Symbol, Any}( + # TODO +) +const QR_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( + # TODO +) + +function QRAdjoint(; fwd_alg = (;), rrule_alg = (;)) + # parse forward algorithm + fwd_algorithm = if fwd_alg isa NamedTuple + fwd_kwargs = (; alg = Defaults.qr_fwd_alg, fwd_alg...) # overwrite with specified kwargs + haskey(QR_FWD_SYMBOLS, fwd_kwargs.alg) || + throw(ArgumentError("unknown forward algorithm: $(fwd_kwargs.alg)")) + fwd_type = QR_FWD_SYMBOLS[fwd_kwargs.alg] + fwd_kwargs = Base.structdiff(fwd_kwargs, (; alg = nothing)) # remove `alg` keyword argument + fwd_type(; fwd_kwargs...) + else + fwd_alg + end + + # parse reverse-rule algorithm + rrule_algorithm = if rrule_alg isa NamedTuple + rrule_kwargs = (; + alg = Defaults.qr_rrule_alg, + verbosity = Defaults.qr_rrule_verbosity, + rrule_alg..., + ) # overwrite with specified kwargs + + haskey(QR_RRULE_SYMBOLS, rrule_kwargs.alg) || + throw(ArgumentError("unknown rrule algorithm: $(rrule_kwargs.alg)")) + rrule_type = QR_RRULE_SYMBOLS[rrule_kwargs.alg] + if rrule_type <: Something # TODO + rrule_kwargs = (; rrule_kwargs.verbosity) + end + + rrule_type(; rrule_kwargs...) 
+ else + rrule_alg + end + + return QRAdjoint(fwd_algorithm, rrule_algorithm) +end + +# TODO: implement wrapper for MatrixAlgebraKit QR functions diff --git a/src/utility/svd.jl b/src/utility/svd.jl index d8bc45803..348554342 100644 --- a/src/utility/svd.jl +++ b/src/utility/svd.jl @@ -48,7 +48,7 @@ Construct a `SVDAdjoint` algorithm struct based on the following keyword argumen - `:bicgstab`: BiCGStab iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BiCGStab) for details - `:arnoldi`: Arnoldi Krylov algorithm, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Arnoldi) for details """ -struct SVDAdjoint{F, R} <: DecompositionAdjoint +struct SVDAdjoint{F, R} fwd_alg::F rrule_alg::R end # Keep truncation algorithm separate to be able to specify CTMRG dependent information From 23c16102bb83ee07302a739df67e32aec6721865 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Thu, 22 Jan 2026 18:12:19 +0100 Subject: [PATCH 05/14] Fix parameter selection --- src/Defaults.jl | 6 + src/algorithms/ctmrg/ctmrg.jl | 6 +- src/algorithms/ctmrg/projectors.jl | 22 +- src/algorithms/ctmrg/sequential.jl | 2 +- src/algorithms/ctmrg/simultaneous.jl | 2 +- src/algorithms/select_algorithm.jl | 15 +- src/utility/eig.jl | 398 --------------------------- 7 files changed, 34 insertions(+), 417 deletions(-) delete mode 100644 src/utility/eig.jl diff --git a/src/Defaults.jl b/src/Defaults.jl index 7cdab9f96..e1674491e 100644 --- a/src/Defaults.jl +++ b/src/Defaults.jl @@ -106,9 +106,15 @@ const eigh_fwd_alg = :qriteration const eigh_rrule_alg = :trunc const eigh_rrule_verbosity = 0 +# QR forward & reverse +const qr_fwd_alg = :something # TODO +const qr_rrule_alg = :something +const qr_rrule_verbosity = :something + # Projectors const projector_alg = :halfinfinite # ∈ {:halfinfinite, :fullinfinite} const projector_verbosity = 0 +const projector_alg_c4v = :c4v_eigh # Fixed-point 
gradient const gradient_tol = 1.0e-6 diff --git a/src/algorithms/ctmrg/ctmrg.jl b/src/algorithms/ctmrg/ctmrg.jl index 2ac7684cf..2de2653fa 100644 --- a/src/algorithms/ctmrg/ctmrg.jl +++ b/src/algorithms/ctmrg/ctmrg.jl @@ -27,7 +27,9 @@ function CTMRGAlgorithm(; alg_type = CTMRG_SYMBOLS[alg] # parse CTMRG projector algorithm - + if alg == :c4v && projector_alg == Defaults.projector_alg + projector_alg = Defaults.projector_alg_c4v + end projector_algorithm = ProjectorAlgorithm(; alg = projector_alg, decomposition_alg, trunc, verbosity ) @@ -76,7 +78,7 @@ supplied via the keyword arguments or directly as an [`CTMRGAlgorithm`](@ref) st - `:truncrank` : Additionally supply truncation dimension `η`; truncate such that the 2-norm of the truncated values is smaller than `η` - `:truncspace` : Additionally supply truncation space `η`; truncate according to the supplied vector space - `:trunctol` : Additionally supply singular value cutoff `η`; truncate such that every retained singular value is larger than `η` -* `svd_alg::Union{<:SVDAdjoint,NamedTuple}` : SVD algorithm for computing projectors. See also [`SVDAdjoint`](@ref). By default, a reverse-rule tolerance of `tol=1e1tol` where the `krylovdim` is adapted to the `env₀` environment dimension. +* `decomposition_alg::Union{<:DecompositionAdjoint,NamedTuple}` : Tensor decomposition algorithm for computing projectors. See e.g. [`SVDAdjoint`](@ref). * `projector_alg::Symbol=:$(Defaults.projector_alg)` : Variant of the projector algorithm. See also [`ProjectorAlgorithm`](@ref). - `:halfinfinite` : Projection via SVDs of half-infinite (two enlarged corners) CTMRG environments. - `:fullinfinite` : Projection via SVDs of full-infinite (all four enlarged corners) CTMRG environments. 
diff --git a/src/algorithms/ctmrg/projectors.jl b/src/algorithms/ctmrg/projectors.jl index befd90434..04388cb58 100644 --- a/src/algorithms/ctmrg/projectors.jl +++ b/src/algorithms/ctmrg/projectors.jl @@ -24,7 +24,14 @@ function ProjectorAlgorithm(; alg_type = PROJECTOR_SYMBOLS[alg] # parse SVD forward & rrule algorithm - decomposition_algorithm = _alg_or_nt(SVDAdjoint, decomposition_alg) # TODO: generalize this to DecompositionAdjoint + + decomposition_algorithm = if alg in [:halfinfinite, :fullinfinite] + _alg_or_nt(SVDAdjoint, decomposition_alg) + elseif alg in [:c4v_eigh,] + _alg_or_nt(EighAdjoint, decomposition_alg) + elseif alg in [:c4v_qr,] + _alg_or_nt(QRAdjoint, decomposition_alg) + end # TODO: how do we solve this in a proper way? # parse truncation scheme truncation_strategy = if trunc isa TruncationStrategy @@ -38,6 +45,7 @@ function ProjectorAlgorithm(; return alg_type(decomposition_algorithm, truncation_strategy, verbosity) end +decomposition_algorithm(alg::ProjectorAlgorithm) = alg.alg function decomposition_algorithm(alg::ProjectorAlgorithm, (dir, r, c)) decomp_alg = decomposition_algorithm(alg) if decomp_alg isa SVDAdjoint{<:FixedSVD} @@ -83,7 +91,7 @@ $(TYPEDFIELDS) Construct the half-infinite projector algorithm based on the following keyword arguments: -* `svd_alg::Union{<:SVDAdjoint,NamedTuple}=SVDAdjoint()` : SVD algorithm including the reverse rule. See [`SVDAdjoint`](@ref). +* `alg::Union{<:SVDAdjoint,NamedTuple}=SVDAdjoint()` : SVD algorithm including the reverse rule. See [`SVDAdjoint`](@ref). * `trunc::Union{TruncationStrategy,NamedTuple}=(; alg::Symbol=:$(Defaults.trunc))` : Truncation strategy for the projector computation, which controls the resulting virtual spaces. 
Here, `alg` can be one of the following: - `:fixedspace` : Keep virtual spaces fixed during projection - `:notrunc` : No singular values are truncated and the performed SVDs are exact @@ -96,7 +104,7 @@ Construct the half-infinite projector algorithm based on the following keyword a 1. Print singular value degeneracy warnings """ struct HalfInfiniteProjector{S <: SVDAdjoint, T} <: ProjectorAlgorithm - svd_alg::S + alg::S trunc::T verbosity::Int end @@ -121,7 +129,7 @@ $(TYPEDFIELDS) Construct the full-infinite projector algorithm based on the following keyword arguments: -* `svd_alg::Union{<:SVDAdjoint,NamedTuple}=SVDAdjoint()` : SVD algorithm including the reverse rule. See [`SVDAdjoint`](@ref). +* `alg::Union{<:SVDAdjoint,NamedTuple}=SVDAdjoint()` : SVD algorithm including the reverse rule. See [`SVDAdjoint`](@ref). * `trunc::Union{TruncationStrategy,NamedTuple}=(; alg::Symbol=:$(Defaults.trunc))` : Truncation scheme for the projector computation, which controls the resulting virtual spaces. Here, `alg` can be one of the following: - `:fixedspace` : Keep virtual spaces fixed during projection - `:notrunc` : No singular values are truncated and the performed SVDs are exact @@ -134,7 +142,7 @@ Construct the full-infinite projector algorithm based on the following keyword a 1. Print singular value degeneracy warnings """ struct FullInfiniteProjector{S <: SVDAdjoint, T} <: ProjectorAlgorithm - svd_alg::S + alg::S trunc::T verbosity::Int end @@ -153,7 +161,7 @@ and the given coordinate using the specified `alg`. function compute_projector(enlarged_corners, coordinate, alg::HalfInfiniteProjector) # SVD half-infinite environment halfinf = half_infinite_environment(enlarged_corners...) 
- svd_alg = svd_algorithm(alg, coordinate) + svd_alg = decomposition_algorithm(alg, coordinate) U, S, V, info = svd_trunc!(halfinf / norm(halfinf), svd_alg; trunc = alg.trunc) # Check for degenerate singular values @@ -174,7 +182,7 @@ function compute_projector(enlarged_corners, coordinate, alg::FullInfiniteProjec # SVD full-infinite environment fullinf = full_infinite_environment(halfinf_left, halfinf_right) - svd_alg = svd_algorithm(alg, coordinate) + svd_alg = decomposition_algorithm(alg, coordinate) U, S, V, info = svd_trunc!(fullinf / norm(fullinf), svd_alg; trunc = alg.trunc) # Check for degenerate singular values diff --git a/src/algorithms/ctmrg/sequential.jl b/src/algorithms/ctmrg/sequential.jl index 9818f79b2..92f8969cf 100644 --- a/src/algorithms/ctmrg/sequential.jl +++ b/src/algorithms/ctmrg/sequential.jl @@ -21,7 +21,7 @@ For a full description, see [`leading_boundary`](@ref). The supported keywords a * `miniter::Int=$(Defaults.ctmrg_miniter)` * `verbosity::Int=$(Defaults.ctmrg_verbosity)` * `trunc::Union{TruncationStrategy,NamedTuple}=(; alg::Symbol=:$(Defaults.trunc))` -* `svd_alg::Union{<:SVDAdjoint,NamedTuple}` +* `decomposition_alg::Union{<:SVDAdjoint,NamedTuple}` * `projector_alg::Symbol=:$(Defaults.projector_alg)` """ struct SequentialCTMRG{P <: ProjectorAlgorithm} <: CTMRGAlgorithm diff --git a/src/algorithms/ctmrg/simultaneous.jl b/src/algorithms/ctmrg/simultaneous.jl index 68f880acf..34ac95a16 100644 --- a/src/algorithms/ctmrg/simultaneous.jl +++ b/src/algorithms/ctmrg/simultaneous.jl @@ -20,7 +20,7 @@ For a full description, see [`leading_boundary`](@ref). 
The supported keywords a * `miniter::Int=$(Defaults.ctmrg_miniter)` * `verbosity::Int=$(Defaults.ctmrg_verbosity)` * `trunc::Union{TruncationStrategy,NamedTuple}=(; alg::Symbol=:$(Defaults.trunc))` -* `svd_alg::Union{<:SVDAdjoint,NamedTuple}` +* `decomposition_alg::Union{<:SVDAdjoint,NamedTuple}` * `projector_alg::Symbol=:$(Defaults.projector_alg)` """ struct SimultaneousCTMRG{P <: ProjectorAlgorithm} <: CTMRGAlgorithm diff --git a/src/algorithms/select_algorithm.jl b/src/algorithms/select_algorithm.jl index a2a177a61..57b615ac1 100644 --- a/src/algorithms/select_algorithm.jl +++ b/src/algorithms/select_algorithm.jl @@ -54,13 +54,13 @@ function select_algorithm( alg = Defaults.ctmrg_alg, tol = Defaults.ctmrg_tol, verbosity = Defaults.ctmrg_verbosity, - svd_alg = (;), + decomposition_alg = (;), kwargs..., ) # adjust SVD rrule settings to CTMRG tolerance, verbosity and environment dimension - if svd_alg isa NamedTuple && - haskey(svd_alg, :rrule_alg) && - svd_alg.rrule_alg isa NamedTuple + if decomposition_alg isa NamedTuple && + haskey(decomposition_alg, :rrule_alg) && + decomposition_alg.rrule_alg isa NamedTuple χenv = maximum(env₀.corners) do corner return dim(space(corner, 1)) end @@ -68,10 +68,9 @@ function select_algorithm( krylovdim = max( Defaults.svd_rrule_min_krylovdim, round(Int, Defaults.krylovdim_factor * χenv) ) - rrule_alg = (; tol = 1.0e1tol, verbosity = verbosity - 2, krylovdim, svd_alg.rrule_alg...) - svd_alg = (; rrule_alg, svd_alg...) + rrule_alg = (; tol = 1.0e1tol, verbosity = verbosity - 2, krylovdim, decomposition_alg.rrule_alg...) + decomposition_alg = (; rrule_alg, decomposition_alg...) end - svd_algorithm = SVDAdjoint(; svd_alg...) - return CTMRGAlgorithm(; alg, tol, verbosity, svd_alg = svd_algorithm, kwargs...) + return CTMRGAlgorithm(; alg, tol, verbosity, decomposition_alg, kwargs...) 
end diff --git a/src/utility/eig.jl b/src/utility/eig.jl deleted file mode 100644 index 3728c61dc..000000000 --- a/src/utility/eig.jl +++ /dev/null @@ -1,398 +0,0 @@ -using MatrixAlgebraKit: TruncationStrategy, NoTruncation, LAPACK_EighAlgorithm, truncate -using MatrixAlgebraKit: eigh_pullback!, eigh_trunc_pullback!, findtruncated, diagview -using TensorKit: AdjointTensorMap, SectorDict, Factorizations.TruncationSpace, - throw_invalid_innerproduct, similarstoragetype -using KrylovKit: Lanczos, BlockLanczos -const KrylovKitCRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreExt) - -""" -$(TYPEDEF) - -Eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_pullback!`. - -## Fields - -$(TYPEDFIELDS) - -## Constructors - - FullEighPullback(; kwargs...) - -Construct a `FullEighPullback` algorithm struct from the following keyword arguments: - -* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. -""" -@kwdef struct FullEighPullback - verbosity::Int = 1 -end - -""" -$(TYPEDEF) - -Truncated eigh reverse-rule algorithm which wraps MatrixAlgebraKit's `eigh_trunc_pullback!`. - -## Fields - -$(TYPEDFIELDS) - -## Constructors - - TruncEighPullback(; kwargs...) - -Construct a `TruncEighPullback` algorithm struct from the following keyword arguments: - -* `verbosity::Int=0` : Suppresses all output if `≤0`, prints gauge dependency warnings if `1`, and always prints gauge dependency if `≥2`. -""" -@kwdef struct TruncEighPullback - verbosity::Int = 1 -end - -abstract type DecompositionAdjoint end - -# TODO: should this be same struct as SVDAdjoint? -""" -$(TYPEDEF) - -Wrapper for a eigenvalue decomposition algorithm `fwd_alg` with a defined reverse rule `rrule_alg`. -If `isnothing(rrule_alg)`, Zygote differentiates the forward call automatically. - -## Fields - -$(TYPEDFIELDS) - -## Constructors - - EighAdjoint(; kwargs...) 
- -Construct a `EighAdjoint` algorithm struct based on the following keyword arguments: - -* `fwd_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.eigh_fwd_alg))`: Eig algorithm of the forward pass which can either be passed as an `Algorithm` instance or a `NamedTuple` where `alg` is one of the following: -* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.eigh_rrule_alg))`: Reverse-rule algorithm for differentiating the eigenvalue decomposition. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following: -""" -struct EighAdjoint{F, R} <: DecompositionAdjoint - fwd_alg::F - rrule_alg::R -end # Keep truncation algorithm separate to be able to specify CTMRG dependent information - -const EIG_FWD_SYMBOLS = IdDict{Symbol, Any}( - :qriteration => LAPACK_QRIteration, - :bisection => LAPACK_Bisection, - :divideandconquer => LAPACK_DivideAndConquer, - :multiple => LAPACK_MultipleRelativelyRobustRepresentations, - :lanczos => - (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> - IterEig(; alg = Lanczos(; tol, krylovdim), kwargs...), - :blocklanczos => - (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> - IterEig(; alg = BlockLanczos(; tol, krylovdim), kwargs...), -) -const EIG_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( - :full => FullEighPullback, :trunc => TruncEighPullback, - # :gmres => GMRES, :bicgstab => BiCGStab, :arnoldi => Arnoldi -) - -function EighAdjoint(; fwd_alg = (;), rrule_alg = (;)) - # parse forward algorithm - fwd_algorithm = if fwd_alg isa NamedTuple - fwd_kwargs = (; alg = Defaults.eig_fwd_alg, fwd_alg...) # overwrite with specified kwargs - haskey(EIG_FWD_SYMBOLS, fwd_kwargs.alg) || - throw(ArgumentError("unknown forward algorithm: $(fwd_kwargs.alg)")) - fwd_type = EIG_FWD_SYMBOLS[fwd_kwargs.alg] - fwd_kwargs = Base.structdiff(fwd_kwargs, (; alg = nothing)) # remove `alg` keyword argument - fwd_type(; fwd_kwargs...) 
- else - fwd_alg - end - - # parse reverse-rule algorithm - rrule_algorithm = if rrule_alg isa NamedTuple - rrule_kwargs = (; - alg = Defaults.eig_rrule_alg, - # tol = Defaults.svd_rrule_tol, # ignore GMRES/BiCGStab/Arnoldi for the moment - # krylovdim = Defaults.svd_rrule_min_krylovdim, - # broadening = Defaults.svd_rrule_broadening, - verbosity = Defaults.eig_rrule_verbosity, - rrule_alg..., - ) # overwrite with specified kwargs - - haskey(EIG_RRULE_SYMBOLS, rrule_kwargs.alg) || - throw(ArgumentError("unknown rrule algorithm: $(rrule_kwargs.alg)")) - rrule_type = EIG_RRULE_SYMBOLS[rrule_kwargs.alg] - if rrule_type <: Union{FullEighPullback, TruncEighPullback} - rrule_kwargs = (; rrule_kwargs.verbosity) - end - - rrule_type(; rrule_kwargs...) - else - rrule_alg - end - - return EighAdjoint(fwd_algorithm, rrule_algorithm) -end - -""" - eigh_trunc(t, alg::EighAdjoint; trunc=notrunc()) - eigh_trunc!(t, alg::EighAdjoint; trunc=notrunc()) - -Wrapper around `eigh_trunc(!)` which dispatches on the `EighAdjoint` algorithm. -This is needed since a custom adjoint may be defined, depending on the `alg`. -""" -MatrixAlgebraKit.eigh_trunc(t, alg::EighAdjoint; kwargs...) = eigh_trunc!(copy(t), alg; kwargs...) 
-function MatrixAlgebraKit.eigh_trunc!(t, alg::EighAdjoint; trunc = notrunc()) - return _eigh_trunc!(t, alg.fwd_alg, trunc) -end -function MatrixAlgebraKit.eigh_trunc!( - t::AdjointTensorMap, alg::EighAdjoint; trunc = notrunc() - ) - D, V, info = eigh_trunc!(adjoint(t), alg; trunc) - return adjoint(D), adjoint(V), info -end - -## Forward algorithms - -# Truncated eigh but also return full D and V to make it compatible with :fixed mode -function _eigh_trunc!( - t::TensorMap, - alg::LAPACK_EighAlgorithm, - trunc::TruncationStrategy, - ) - D, V = eigh_full!(t; alg) - D̃, Ṽ, truncerror = _truncate_eigh((D, V), trunc) - - # construct info NamedTuple - condnum = cond(D) - info = (; - truncation_error = truncerror, condition_number = condnum, D_full = D, V_full = V, - ) - return D̃, Ṽ, info -end - -# hacky way of computing the truncation error for current version of eigh_trunc! -# TODO: replace once TensorKit updates to new MatrixAlgebraKit which returns truncation error as well -function _truncate_eigh((D, V), trunc::TruncationStrategy) - if !(trunc isa NoTruncation) && !isempty(blocksectors(D)) - D̃, Ṽ = truncate(eigh_trunc!, (D, V), trunc)[1] - truncerror = sqrt(abs(norm(D)^2 - norm(D̃)^2)) - return D̃, Ṽ, truncerror - else - return D, V, zero(real(scalartype(D))) - end -end - -""" -$(TYPEDEF) - -Eigenvalue decomposition struct containing a pre-computed decomposition or even multiple ones. -Additionally, it can contain the untruncated full decomposition as well. The call to -`eigh_trunc`/`eig_trunc` just returns the pre-computed D and V. In the reverse pass, -the adjoint is computed with these exact D and V and, potentially, the full decompositions -if the adjoints needs access to them. 
- -## Fields - -$(TYPEDFIELDS) -""" -struct FixedEig{Dt, Vt, Dtf, Vtf} - D::Dt - V::Vt - D_full::Dtf - V_full::Vtf -end - -# check whether the full D and V are supplied -function isfulleig(alg::FixedEig) - if isnothing(alg.D_full) || isnothing(alg.V_full) - return false - else - return true - end -end - -# Return pre-computed decomposition -function _eigh_trunc!(_, alg::FixedEig, ::TruncationStrategy) - info = (; - truncation_error = zero(real(scalartype(alg.D))), - condition_number = cond(alg.D), - D_full = alg.D_full, - V_full = alg.V_full, - ) - return alg.D, alg.V, info -end - - -""" -$(TYPEDEF) - -Iterative eigenvalue solver based on KrylovKit's `eigsolve`, adapted to (symmetric) tensors. -The number of targeted eigenvalues is set via the `truncspace` in `ProjectorAlg`. -In particular, this makes it possible to specify the targeted eigenvalues block-wise. -In case the symmetry block is too small as compared to the number of singular values, or -the iterative decomposition didn't converge, the algorithm falls back to a dense `eigh`/`eigh`. - -## Fields - -$(TYPEDFIELDS) - -## Constructors - - IterEig(; kwargs...) - -Construct an `IterEig` algorithm struct based on the following keyword arguments: - -* `alg=KrylovKit.Lanczos(; tol=1e-14, krylovdim=25)` : KrylovKit algorithm struct for iterative eigenvalue decomposition. -* `fallback_threshold::Float64=Inf` : Threshold for `howmany / minimum(size(block))` above which (if the block is too small) the algorithm falls back to a dense decomposition. -* `start_vector=random_start_vector` : Function providing the initial vector for the iterative algorithm. 
-""" -@kwdef struct IterEig - alg = KrylovKit.Lanczos(; tol = 1.0e-14, krylovdim = 25) - fallback_threshold::Float64 = Inf - start_vector = random_start_vector -end - -# Compute eigh data block-wise using KrylovKit algorithm -function _eigh_trunc!(f, alg::IterEig, trunc::TruncationStrategy) - D, V = if isempty(blocksectors(f)) - # early return - truncation_error = zero(real(scalartype(f))) - MatrixAlgebraKit.initialize_output(eigh_full!, f, LAPACK_QRIteration()) # specified algorithm doesn't matter here - else - eighdata, dims = _compute_eighdata!(f, alg, trunc) - _create_eightensors(f, eighdata, dims) - end - - # construct info NamedTuple - truncation_error = - trunc isa NoTruncation ? abs(zero(scalartype(f))) : norm(V * D * V' - f) - condition_number = cond(D) - info = (; truncation_error, condition_number, D_full = nothing, V_full = nothing) - - return D, V, info -end - -# Obtain sparse decomposition from block-wise eigsolve calls -function _compute_eighdata!( - f, alg::IterEig, trunc::Union{NoTruncation, TruncationSpace} - ) - InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh_trunc!) - domain(f) == codomain(f) || - throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same")) - I = sectortype(f) - dims = SectorDict{I, Int}() - - sectors = trunc isa NoTruncation ? blocksectors(f) : blocksectors(trunc.space) - generator = Base.Iterators.map(sectors) do c - b = block(f, c) - howmany = trunc isa NoTruncation ? 
minimum(size(b)) : blockdim(trunc.space, c) - - if howmany / minimum(size(b)) > alg.fallback_threshold # Use dense decomposition for small blocks - D, V = eigh_full!(b, LAPACK_QRIteration()) - lm_ordering = sortperm(abs.(D.diag); rev = true) # order values and vectors consistently with eigsolve - D = D.diag[lm_ordering] # extracts diagonal as Vector instead of Diagonal to make compatible with D of svdsolve - V = view(V, lm_ordering)[:, 1:howmany] - else - x₀ = alg.start_vector(b) - eig_alg = alg.alg - if howmany > alg.alg.krylovdim - eig_alg = @set eig_alg.krylovdim = round(Int, howmany * 1.2) - end - D, lvecs, info = eigsolve(b, x₀, howmany, :LM, eig_alg) - if info.converged < howmany # Fall back to dense SVD if not properly converged - @warn "Iterative eigendecomposition did not converge for block $c, falling back to eigh_full" - D, V = eigh_full!(b, LAPACK_QRIteration()) - lm_ordering = sortperm(abs.(D.diag); rev = true) - D = D.diag[lm_ordering] - V = view(V, lm_ordering)[:, 1:howmany] - else # Slice in case more values were converged than requested - V = stack(view(lvecs, 1:howmany)) - end - end - - resize!(D, howmany) - dims[c] = length(D) - return c => (D, V) - end - - eigdata = SectorDict(generator) - return eigdata, dims -end - -# Create eigh TensorMaps from sparse SectorDict -function _create_eightensors(t::TensorMap, eighdata, dims) - InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!) 
- - T = scalartype(t) - S = spacetype(t) - W = S(dims) - - Tr = real(T) - A = similarstoragetype(t, Tr) - D = DiagonalTensorMap{Tr, S, A}(undef, W) - V = similar(t, domain(t) ← W) - for (c, (Dc, Vc)) in eighdata - r = Base.OneTo(dims[c]) - copy!(block(D, c), Diagonal(view(Dc, r))) - copy!(block(V, c), view(Vc, :, r)) - end - return D, V -end - -## Reverse-rule algorithms -function _get_pullback_gauge_tol(verbosity::Int) - if verbosity ≤ 0 # never print gauge sensitivity - return (_) -> Inf - elseif verbosity == 1 # print gauge sensitivity above default atol - MatrixAlgebraKit.default_pullback_gaugetol - else # always print gauge sensitivity - return (_) -> 0.0 - end -end - -# eigh_trunc! rrule wrapping MatrixAlgebraKit's eigh_pullback! -function ChainRulesCore.rrule( - ::typeof(eigh_trunc!), - t::AbstractTensorMap, - alg::EighAdjoint{F, R}; - trunc::TruncationStrategy = notrunc(), - ) where {F <: Union{<:LAPACK_EighAlgorithm, <:FixedEig}, R <: FullEighPullback} - D̃, Ṽ, info = eigh_trunc(t, alg; trunc) - D, V = info.D_full, info.V_full # untruncated decomposition - inds = findtruncated(diagview(D), truncspace(only(domain(D̃)))) - gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) - - function eigh_trunc!_full_pullback(ΔDV) - Δt = eigh_pullback!( # TODO: does this work by now? - zeros(scalartype(t), space(t)), t, (D, V), ΔDV, inds; - gauge_atol = gtol(ΔDV) - ) - return NoTangent(), Δt, NoTangent() - end - function eigh_trunc!_full_pullback(::Tuple{ZeroTangent, ZeroTangent}) - return NoTangent(), ZeroTangent(), NoTangent() - end - - return (D̃, Ṽ, info), eigh_trunc!_full_pullback -end - -# eigh_trunc! rrule wrapping MatrixAlgebraKit's eigh_trunc_pullback! 
(also works for IterEig) -function ChainRulesCore.rrule( - ::typeof(eigh_trunc!), - t, - alg::EighAdjoint{F, R}; - trunc::TruncationStrategy = notrunc(), - ) where {F <: Union{<:LAPACK_EighAlgorithm, <:FixedEig, IterEig}, R <: TruncEighPullback} - D, V, info = eigh_trunc(t, alg; trunc) - gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) - - function eigh_trunc!_trunc_pullback(ΔDV) - Δf = eigh_trunc_pullback!( - zeros(scalartype(t), space(t)), t, (D, V), ΔDV; - gauge_atol = gtol(ΔDV) - ) - return NoTangent(), Δf, NoTangent() - end - function eigh_trunc!_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent}) - return NoTangent(), ZeroTangent(), NoTangent() - end - - return (D, V, info), eigh_trunc!_trunc_pullback -end From 0a475fe85325f1f8da04db924d4826061ae8cb7f Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Thu, 22 Jan 2026 18:12:30 +0100 Subject: [PATCH 06/14] Make some tests runnable --- src/PEPSKit.jl | 4 +- src/algorithms/ctmrg/c4v.jl | 46 ++++++++++++++----- src/algorithms/ctmrg/gaugefix.jl | 39 +++++++++++++--- .../fixed_point_differentiation.jl | 8 ++-- src/environments/ctmrg_environments.jl | 4 +- test/ctmrg/fixed_iterscheme.jl | 23 ++++------ test/ctmrg/flavors.jl | 35 ++++++++++---- test/ctmrg/jacobian_real_linear.jl | 2 +- test/ctmrg/unitcell.jl | 6 +-- test/gradients/ctmrg_gradients.jl | 2 +- 10 files changed, 118 insertions(+), 51 deletions(-) diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index 787167a0e..aace22875 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -29,8 +29,9 @@ include("Defaults.jl") # Include first to allow for docstring interpolation wit include("utility/util.jl") include("utility/diffable_threads.jl") -include("utility/eig.jl") +include("utility/eigh.jl") include("utility/svd.jl") +include("utility/qr.jl") include("utility/rotations.jl") include("utility/hook_pullback.jl") include("utility/autoopt.jl") @@ -104,6 +105,7 @@ export SVDAdjoint, FullSVDReverseRule, IterSVD export CTMRGEnv, SequentialCTMRG, SimultaneousCTMRG export 
FixedSpaceTruncation, SiteDependentTruncation export HalfInfiniteProjector, FullInfiniteProjector +export EighAdjoint, C4vCTMRG, C4vEighProjector, C4vQRProjector export LocalOperator, physicalspace export product_peps export reduced_densitymatrix, expectation_value, network_value, cost_function diff --git a/src/algorithms/ctmrg/c4v.jl b/src/algorithms/ctmrg/c4v.jl index 2b336a320..fa3065f36 100644 --- a/src/algorithms/ctmrg/c4v.jl +++ b/src/algorithms/ctmrg/c4v.jl @@ -18,7 +18,6 @@ end function C4vEighProjector(; kwargs...) return ProjectorAlgorithm(; alg = :c4v_eigh, kwargs...) end -decomposition_algorithm(alg::C4vEighProjector) = alg.alg PROJECTOR_SYMBOLS[:c4v_eigh] = C4vEighProjector struct C4vQRProjector{S, T} <: ProjectorAlgorithm @@ -28,9 +27,12 @@ end function C4vQRProjector(; kwargs...) return ProjectorAlgorithm(; alg = :c4v_qr, kwargs...) end -decomposition_algorithm(alg::C4vEighProjector) = alg.alg PROJECTOR_SYMBOLS[:c4v_qr] = C4vQRProjector +# +## C4v-symmetric CTMRG iteration (called through `leading_boundary`) +# + function ctmrg_iteration( network, env::CTMRGEnv, @@ -39,7 +41,7 @@ function ctmrg_iteration( enlarged_corner = TensorMap(EnlargedCorner(network, env, (NORTHWEST, 1, 1))) corner′, projector, info = c4v_projector(enlarged_corner, alg.projector_alg) edge′ = c4v_renormalize(network, env, projector) - return _c4v_env(corner′, edge′), info + return CTMRGEnv(corner′, edge′), info end function c4v_projector(enlarged_corner, alg::C4vEighProjector) @@ -67,21 +69,43 @@ function c4v_renormalize(network, env, projector) return new_edge / norm(new_edge) end -# TODO: this won't differentiate properly probably due to custom CTMRGEnv rrule defined in PEPSKit -function CTMRGEnv(corner::CornerTensor, edge::EdgeTensor) +function CTMRGEnv( + corner::AbstractTensorMap{T, S, 1, 1}, edge::AbstractTensorMap{T′, S, N, 1} + ) where {T, T′, S, N} corners = fill(corner, 4, 1, 1) edge_SW = physical_flip(edge) edges = reshape([edge, edge, edge_SW, edge_SW], (4, 1, 1)) 
return CTMRGEnv(corners, edges) end -function _c4v_env(corner::CornerTensor, edge::EdgeTensor) - corners = fill(corner, 4, 1, 1) - edge_SW = physical_flip(edge) - edges = reshape([edge, edge, edge_SW, edge_SW], (4, 1, 1)) - return CTMRGEnv(corners, edges) +# +## utility +# + +# Adjoint of an edge tensor, but permutes the physical spaces back into the codomain. +# Intuitively, this conjugates a tensor and then reinterprets its 'direction' as an edge tensor. +function _dag(A::AbstractTensorMap{T, S, N, 1}) where {T, S, N} + return permute(A', ((1, (3:(N + 1))...), (2,))) +end +function physical_flip(A::AbstractTensorMap{T, S, N, 1}) where {T, S, N} + return flip(A, 2:N) end +# should perform this check at the beginning of `leading_boundary` really... +function check_symmetry(state, ::RotateReflect; atol = 1.0e-10) + @assert length(state) == 1 "check_symmetry only works for single site unit cells" + @assert norm(state[1] - _fit_spaces(rotl90(state[1]), state[1])) / + norm(state[1]) < atol "not rotation invariant" + @assert norm(state[1] - _fit_spaces(herm_depth(state[1]), state[1])) / + norm(state[1]) < atol "not hermitian-reflection invariant" + return nothing +end + +# +## environment initialization +# + +# TODO: rewrite this using `initialize_environment` and C4v-specific initialization algorithms # environment with dummy corner singlet(V) ← singlet(V) and identity edge V ← V, initialized at dim(Venv) function initialize_singlet_c4v_env(Vpeps::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) corner₀ = DiagonalTensorMap(zeros(real(T), Venv ← Venv)) @@ -93,6 +117,6 @@ end function initialize_random_c4v_env(Vpeps::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) corner₀ = DiagonalTensorMap(randn(real(T), Venv ← Venv)) edge₀ = randn(T, Venv ⊗ Vpeps ⊗ Vpeps' ← Venv) - edge₀ = project_hermitian(edge₀) + edge₀ = 0.5 * (edge₀ + physical_flip(_dag(edge₀))) return CTMRGEnv(corner₀, edge₀) end diff --git a/src/algorithms/ctmrg/gaugefix.jl 
b/src/algorithms/ctmrg/gaugefix.jl index 66af48cc8..ce0a57905 100644 --- a/src/algorithms/ctmrg/gaugefix.jl +++ b/src/algorithms/ctmrg/gaugefix.jl @@ -1,12 +1,38 @@ -function gauge_fix(boundary_alg::CTMRGAlgorithm, signs, info) - # TODO - decomp_alg_fixed = _fix_decomposition(decomposition_algorithm(alg.projector_alg), signs, info) - alg_fixed = @set alg.projector_alg.alg = decomp_alg_fixed - alg_fixed = @set alg_fixed.projector_alg.trunc = notrunc() + +""" + gauge_fix(alg::CTMRGAlgorithm, signs, info) + gauge_fix(alg::ProjectorAlgorithm, signs, info) + +Fix the free gauges of the tensor decompositions associated with `alg`. +""" +function gauge_fix(alg::CTMRGAlgorithm, signs, info) + alg_fixed = @set alg.projector_alg = gauge_fix(alg.projector_alg, signs, info) + return alg_fixed +end +function gauge_fix(alg::ProjectorAlgorithm, signs, info) + decomposition_alg_fixed = gauge_fix(alg.alg, signs, info) # every ProjectorAlgorithm needs an `alg` field + alg_fixed = @set alg.alg = decomposition_alg_fixed + alg_fixed = @set alg_fixed.trunc = notrunc() return alg_fixed end +""" +$(TYPEDEF) + +CTMRG environment gauge fixing algorithm implementing the "general" technique from +https://arxiv.org/abs/2311.11894. This works by constructing a transfer matrix consisting +of an edge tensor and a random MPS, thus scrambling potential degeneracies, and then +performing a QR decomposition to extract the gauge signs. This is adapted accordingly for +asymmetric CTMRG algorithms using multi-site unit cell transfer matrices. +""" struct ScramblingEnvGauge end + +""" +$(TYPEDEF) + +C4v-symmetric equivalent of the [`ScramblingEnvGauge`](@ref) environment gauge fixing +algorithm.
+""" struct ScramblingEnvGaugeC4v end """ @@ -25,7 +51,6 @@ function gauge_fix(envfinal::CTMRGEnv{C, T}, ::ScramblingEnvGauge, envprev::CTMR end @assert all(same_spaces) "Spaces of envprev and envfinal are not the same" - # Try the "general" algorithm from https://arxiv.org/abs/2311.11894 signs = map(eachcoordinate(envfinal, 1:4)) do (dir, r, c) # Gather edge tensors and pretend they're InfiniteMPSs if dir == NORTH @@ -98,7 +123,7 @@ function gauge_fix(envfinal::CTMRGEnv{C, T}, ::ScramblingEnvGaugeC4v, envprev::C @tensor cornerfix[χ_in; χ_out] := σ[χ_in; χ1] * envfinal.corners[1][χ1; χ2] * conj(σ[χ_out; χ2]) @tensor edgefix[χ_in D_in_above D_in_below; χ_out] := σ[χ_in; χ1] * envfinal.edges[1][χ1 D_in_above D_in_below; χ2] * conj(σ[χ_out; χ2]) - return _c4v_env(cornerfix, edgefix), fill(σ, (4, 1, 1)) + return CTMRGEnv(cornerfix, edgefix), fill(σ, (4, 1, 1)) end # this is a bit of a hack to get the fixed point of the mixed transfer matrix diff --git a/src/algorithms/optimization/fixed_point_differentiation.jl b/src/algorithms/optimization/fixed_point_differentiation.jl index 2e568fe6f..c05dcfc3d 100644 --- a/src/algorithms/optimization/fixed_point_differentiation.jl +++ b/src/algorithms/optimization/fixed_point_differentiation.jl @@ -278,7 +278,7 @@ function _rrule( return (env_fixed, info), leading_boundary_fixed_pullback end -function fix_decomposition(alg::SVDAdjoint, signs, info) +function gauge_fix(alg::SVDAdjoint, signs, info) # embed gauge signs in larger space to fix gauge of full U and V on truncated subspace rowsize, colsize = size(signs, 2), size(signs, 3) signs_full = map(Iterators.product(1:4, 1:rowsize, 1:colsize)) do (dir, r, c) @@ -311,7 +311,7 @@ function fix_decomposition(alg::SVDAdjoint, signs, info) rrule_alg = alg.rrule_alg, ) end -function fix_decomposition(alg::SVDAdjoint{F}, signs, info) where {F <: IterSVD} +function gauge_fix(alg::SVDAdjoint{F}, signs, info) where {F <: IterSVD} # fix kept U and V only since iterative SVD doesn't have access 
to full spectrum U_fixed, V_fixed = fix_relative_phases(info.U, info.V, signs) return SVDAdjoint(; @@ -319,7 +319,7 @@ function fix_decomposition(alg::SVDAdjoint{F}, signs, info) where {F <: IterSVD} rrule_alg = alg.rrule_alg, ) end -function fix_decomposition(alg::EighAdjoint, signs, info) +function gauge_fix(alg::EighAdjoint, signs, info) # embed gauge signs in larger space to fix gauge of full V on truncated subspace σ = signs[1] extended_σ = zeros(scalartype(σ), space(info.D_full)) @@ -338,7 +338,7 @@ function fix_decomposition(alg::EighAdjoint, signs, info) rrule_alg = alg.rrule_alg, ) end -function fix_decomposition(alg::EighAdjoint{F}, signs, info) where {F <: IterEig} +function gauge_fix(alg::EighAdjoint{F}, signs, info) where {F <: IterEig} # fix kept U only since iterative decomposition doesn't have access to full spectrum U_fixed = info.U * signs[1]' return EighAdjoint(; diff --git a/src/environments/ctmrg_environments.jl b/src/environments/ctmrg_environments.jl index 7adfb99e0..cf85e4702 100644 --- a/src/environments/ctmrg_environments.jl +++ b/src/environments/ctmrg_environments.jl @@ -233,7 +233,9 @@ end @non_differentiable CTMRGEnv(state::Union{InfinitePartitionFunction, InfinitePEPS}, args...) 
# Custom adjoint for CTMRGEnv constructor, needed for fixed-point differentiation -function ChainRulesCore.rrule(::Type{CTMRGEnv}, corners, edges) +function ChainRulesCore.rrule( + ::Type{CTMRGEnv}, corners::Array{C, 3}, edges::Array{T, 3} + ) where {C, T} ctmrgenv_pullback(ē) = NoTangent(), ē.corners, ē.edges return CTMRGEnv(corners, edges), ctmrgenv_pullback end diff --git a/test/ctmrg/fixed_iterscheme.jl b/test/ctmrg/fixed_iterscheme.jl index 72b9d2936..b3221e8d0 100644 --- a/test/ctmrg/fixed_iterscheme.jl +++ b/test/ctmrg/fixed_iterscheme.jl @@ -12,24 +12,24 @@ using PEPSKit: fix_relative_phases, fix_global_phases, calc_elementwise_convergence, - fix_decomposition, - gauge_fix # initialize parameters χbond = 2 χenv = 16 svd_algs = [SVDAdjoint(; fwd_alg = LAPACK_DivideAndConquer()), SVDAdjoint(; fwd_alg = IterSVD())] -projector_algs = [:halfinfinite] #, :fullinfinite] +eigh_algs = [EighAdjoint(; fwd_alg = qriteration), EighAdjoint(; fwd_alg = :lanczos)] +projector_algs_asymm = [:halfinfinite] #, :fullinfinite] +projector_algs_c4v = [:c4v_eigh, :c4v_qr] unitcells = [(1, 1), (3, 4)] atol = 1.0e-5 # test for element-wise convergence after application of fixed step @testset "$unitcell unit cell with $(typeof(svd_alg.fwd_alg)) and $projector_alg" for ( - unitcell, svd_alg, projector_alg, + unitcell, decomposition_alg, projector_alg, ) in Iterators.product( - unitcells, svd_algs, projector_algs + unitcells, svd_algs, projector_algs_asymm ) - ctm_alg = SimultaneousCTMRG(; svd_alg, projector_alg) + ctm_alg = SimultaneousCTMRG(; decomposition_alg, projector_alg) # initialize states Random.seed!(2394823842) @@ -52,13 +52,10 @@ atol = 1.0e-5 @test calc_elementwise_convergence(env_conv1, env_fixedsvd) ≈ 0 atol = atol end -eigh_algs = [EighAdjoint(; fwd_alg = qriteration), EighAdjoint(; fwd_alg = :lanczos)] -projector_algs = [:c4v_eigh, :c4v_qr] - # test same thing for C4v CTMRG -@testset "" for (eigh_alg, projector_alg) in Iterators.product(eigh_algs, projector_algs)
+@testset "" for (eigh_alg, projector_alg) in Iterators.product(eigh_algs, projector_algs_c4v) # TODO - ctm_alg = C4vCTMRG(; eigh_alg, projector_alg) + ctm_alg = C4vCTMRG(; decomposition_alg, projector_alg) # initialize states Random.seed!(2394823842) @@ -84,9 +81,9 @@ end @testset "Element-wise consistency of LAPACK_DivideAndConquer and IterSVD" begin ctm_alg_iter = SimultaneousCTMRG(; maxiter = 200, - svd_alg = SVDAdjoint(; fwd_alg = IterSVD(; alg = GKL(; tol = 1.0e-14, krylovdim = χenv + 10))), + decomposition_alg = SVDAdjoint(; fwd_alg = IterSVD(; alg = GKL(; tol = 1.0e-14, krylovdim = χenv + 10))), ) - ctm_alg_full = SimultaneousCTMRG(; svd_alg = SVDAdjoint(; fwd_alg = LAPACK_DivideAndConquer())) + ctm_alg_full = SimultaneousCTMRG(; decomposition_alg = SVDAdjoint(; fwd_alg = LAPACK_DivideAndConquer())) # initialize states Random.seed!(91283219347) diff --git a/test/ctmrg/flavors.jl b/test/ctmrg/flavors.jl index d15705b88..2bc961367 100644 --- a/test/ctmrg/flavors.jl +++ b/test/ctmrg/flavors.jl @@ -4,24 +4,27 @@ using MatrixAlgebraKit using TensorKit using MPSKit using PEPSKit +using PEPSKit: peps_normalize, initialize_random_c4v_env, initialize_singlet_c4v_env # initialize parameters -χbond = 2 -χenv = 16 +D = 2 +χ = 16 unitcells = [(1, 1), (3, 4)] projector_algs_asymm = [:halfinfinite, :fullinfinite] -projector_algs_c4v = [:c4v_eigh, :c4v_qr] +projector_algs_c4v = [:c4v_eigh] # :c4v_qr] +Ts = [Float64, ComplexF64] +eigh_algs = [:qriteration, :lanczos] @testset "$(unitcell) unit cell with $projector_alg" for (unitcell, projector_alg) in Iterators.product(unitcells, projector_algs_asymm) # compute environments Random.seed!(32350283290358) - psi = InfinitePEPS(ComplexSpace(2), ComplexSpace(χbond); unitcell) + psi = InfinitePEPS(ComplexSpace(2), ComplexSpace(D); unitcell) env_sequential, = leading_boundary( - CTMRGEnv(psi, ComplexSpace(χenv)), psi; alg = :sequential, projector_alg + CTMRGEnv(psi, ComplexSpace(χ)), psi; alg = :sequential, projector_alg ) 
env_simultaneous, = leading_boundary( - CTMRGEnv(psi, ComplexSpace(χenv)), psi; alg = :simultaneous, projector_alg + CTMRGEnv(psi, ComplexSpace(χ)), psi; alg = :simultaneous, projector_alg ) # compare norms @@ -69,8 +72,22 @@ end @test all(space.(env.edges) .== space.(env2.edges)) end -@testset "C4v CTMRG using $alg and $projector_alg" for (alg, projector_alg) in - Iterators.product([:c4v], projector_algs_c4v) +@testset "C4v with ($T) - ($projector_alg) - ($eigh_alg)" for (projector_alg, T, eigh_alg) in + Iterators.product(projector_algs_c4v, Ts, eigh_algs) - # TODO + Random.seed!(29358293829382) + symm = RotateReflect() + Vphys = ComplexSpace(2) + Vpeps = ComplexSpace(D) + Venv = ComplexSpace(χ) + + peps = InfinitePEPS(randn, T, Vphys, Vpeps, Vpeps) + peps = peps_normalize(symmetrize!(peps, symm)) + + # boundary_alg = C4vCTMRG(; projector_alg, decomposition_alg = (; fwd_alg)) + env₀ = initialize_random_c4v_env(Vpeps, Venv, scalartype(peps)) + env, = leading_boundary( + env₀, peps; alg = :c4v, projector_alg, + decomposition_alg = (; fwd_alg = (; alg = eigh_alg)) + ) end diff --git a/test/ctmrg/jacobian_real_linear.jl b/test/ctmrg/jacobian_real_linear.jl index fce51a4ca..ae9869d8c 100644 --- a/test/ctmrg/jacobian_real_linear.jl +++ b/test/ctmrg/jacobian_real_linear.jl @@ -4,7 +4,7 @@ using Accessors using Zygote using TensorKit, KrylovKit, PEPSKit using PEPSKit: - ctmrg_iteration, fix_relative_phases, fix_global_phases, fix_decomposition, ScramblingEnvGauge + ctmrg_iteration, fix_relative_phases, fix_global_phases, ScramblingEnvGauge algs = [ (:fixed, SimultaneousCTMRG(; projector_alg = :halfinfinite)), diff --git a/test/ctmrg/unitcell.jl b/test/ctmrg/unitcell.jl index 69540d1b8..6524c68ab 100644 --- a/test/ctmrg/unitcell.jl +++ b/test/ctmrg/unitcell.jl @@ -1,7 +1,7 @@ using Test using Random using PEPSKit -using PEPSKit: _prev, _next, ctmrg_iteration, fix_decomposition +using PEPSKit: _prev, _next, ctmrg_iteration, gauge_fix using TensorKit # settings @@ -42,8 
+42,8 @@ function test_unitcell( _, signs = gauge_fix(env″, ScramblingEnvGauge(), env′) @test signs isa Array return if ctm_alg isa SimultaneousCTMRG # also test :fixed mode gauge fixing for simultaneous CTMRG - svd_alg_fixed_full = fix_decomposition(SVDAdjoint(; fwd_alg = (; alg = :sdd)), signs, info) - svd_alg_fixed_iter = fix_decomposition(SVDAdjoint(; fwd_alg = (; alg = :iterative)), signs, info) + svd_alg_fixed_full = gauge_fix(SVDAdjoint(; fwd_alg = (; alg = :sdd)), signs, info) + svd_alg_fixed_iter = gauge_fix(SVDAdjoint(; fwd_alg = (; alg = :iterative)), signs, info) @test svd_alg_fixed_full isa SVDAdjoint @test svd_alg_fixed_iter isa SVDAdjoint end diff --git a/test/gradients/ctmrg_gradients.jl b/test/gradients/ctmrg_gradients.jl index e973c0a03..1c03d83dc 100644 --- a/test/gradients/ctmrg_gradients.jl +++ b/test/gradients/ctmrg_gradients.jl @@ -74,7 +74,7 @@ naive_gradient_done = Set() alg = ctmrg_alg, verbosity = ctmrg_verbosity, projector_alg = projector_alg, - svd_alg = (; rrule_alg = (; alg = svd_rrule_alg)), + decomposition_alg = (; rrule_alg = (; alg = svd_rrule_alg)), ) # instantiate because hook_pullback doesn't go through the keyword selector... 
concrete_gradient_alg = if isnothing(gradient_alg) From 094af8c77e9a2134a2d8aed73f0f5196bdd31898 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Thu, 22 Jan 2026 19:07:25 +0100 Subject: [PATCH 07/14] Fix some more stuff and add fixed_iterscheme C4v test --- src/PEPSKit.jl | 4 +- src/algorithms/ctmrg/c4v.jl | 14 ++++- src/algorithms/ctmrg/gaugefix.jl | 3 +- src/algorithms/ctmrg/projectors.jl | 4 +- src/utility/eigh.jl | 90 +++++++++++++++--------------- test/ctmrg/fixed_iterscheme.jl | 60 ++++++++++++-------- test/ctmrg/flavors.jl | 1 + 7 files changed, 99 insertions(+), 77 deletions(-) diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index aace22875..bfdb0399e 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -44,6 +44,8 @@ include("networks/infinitesquarenetwork.jl") include("states/infinitepeps.jl") include("states/infinitepartitionfunction.jl") +include("utility/symmetrization.jl") + include("operators/infinitepepo.jl") include("operators/transfermatrix.jl") include("operators/localoperator.jl") @@ -92,8 +94,6 @@ include("algorithms/transfermatrix.jl") include("algorithms/toolbox.jl") include("algorithms/correlators.jl") -include("utility/symmetrization.jl") - include("algorithms/optimization/fixed_point_differentiation.jl") include("algorithms/optimization/peps_optimization.jl") diff --git a/src/algorithms/ctmrg/c4v.jl b/src/algorithms/ctmrg/c4v.jl index fa3065f36..191f58725 100644 --- a/src/algorithms/ctmrg/c4v.jl +++ b/src/algorithms/ctmrg/c4v.jl @@ -47,7 +47,7 @@ end function c4v_projector(enlarged_corner, alg::C4vEighProjector) hermitian_corner = 0.5 * (enlarged_corner + enlarged_corner') / norm(enlarged_corner) trunc = truncation_strategy(alg, enlarged_corner) - D, V, info = eigh_trunc!(hermitian_corner, decomposition_algorithm(alg); trunc) + D, U, info = eigh_trunc!(hermitian_corner, decomposition_algorithm(alg); trunc) # Check for degenerate eigenvalues Zygote.isderiving() && ignore_derivatives() do @@ -57,7 +57,7 @@ function 
c4v_projector(enlarged_corner, alg::C4vEighProjector) end end - return D / norm(D), V, (; D, V, info...) + return D / norm(D), U, (; D, U, info...) end function c4v_projector(enlarged_corner, alg::C4vQRProjector) @@ -90,6 +90,14 @@ end function physical_flip(A::AbstractTensorMap{T, S, N, 1}) where {T, S, N} return flip(A, 2:N) end +function project_hermitian(E::AbstractTensorMap{T, S, N, 1}) where {T, S, N} + E´ = (E + physical_flip(_dag(E))) / 2 + return E´ +end +function project_hermitian(C::AbstractTensorMap{T, S, 1, 1}) where {T, S} + C´ = (C + C') / 2 + return C´ +end # should perform this check at the beginning of `leading_boundary` really... function check_symmetry(state, ::RotateReflect; atol = 1.0e-10) @@ -117,6 +125,6 @@ end function initialize_random_c4v_env(Vpeps::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) corner₀ = DiagonalTensorMap(randn(real(T), Venv ← Venv)) edge₀ = randn(T, Venv ⊗ Vpeps ⊗ Vpeps' ← Venv) - edge₀ = 0.5 * (edge₀ + physical_flip(_dag(edge₀))) + edge₀ = project_hermitian(edge₀) return CTMRGEnv(corner₀, edge₀) end diff --git a/src/algorithms/ctmrg/gaugefix.jl b/src/algorithms/ctmrg/gaugefix.jl index ce0a57905..4758ec09e 100644 --- a/src/algorithms/ctmrg/gaugefix.jl +++ b/src/algorithms/ctmrg/gaugefix.jl @@ -1,4 +1,3 @@ - """ gauge_fix(alg::CTMRGAlgorithm, signs, info) gauge_fix(alg::ProjectorAlgorithm, signs, info) @@ -11,7 +10,7 @@ function gauge_fix(alg::CTMRGAlgorithm, signs, info) end function gauge_fix(alg::ProjectorAlgorithm, signs, info) decomposition_alg_fixed = gauge_fix(alg.alg, signs, info) # every ProjectorAlgorithm needs an `alg` field - alg_fixed = @set alg.alg = decomposition_alg_fixed + alg_fixed = @set alg.alg = decomposition_alg_fixed alg_fixed = @set alg_fixed.trunc = notrunc() return alg_fixed end diff --git a/src/algorithms/ctmrg/projectors.jl b/src/algorithms/ctmrg/projectors.jl index 04388cb58..26a8c821f 100644 --- a/src/algorithms/ctmrg/projectors.jl +++ b/src/algorithms/ctmrg/projectors.jl @@ -27,9 
+27,9 @@ function ProjectorAlgorithm(; decomposition_algorithm = if alg in [:halfinfinite, :fullinfinite] _alg_or_nt(SVDAdjoint, decomposition_alg) - elseif alg in [:c4v_eigh,] + elseif alg in [:c4v_eigh] _alg_or_nt(EighAdjoint, decomposition_alg) - elseif alg in [:c4v_qr,] + elseif alg in [:c4v_qr] _alg_or_nt(QRAdjoint, decomposition_alg) end # TODO: how do we solve this in a proper way? diff --git a/src/utility/eigh.jl b/src/utility/eigh.jl index 3ee86130b..2fd2bd9ef 100644 --- a/src/utility/eigh.jl +++ b/src/utility/eigh.jl @@ -138,38 +138,38 @@ end function MatrixAlgebraKit.eigh_trunc!( t::AdjointTensorMap, alg::EighAdjoint; trunc = notrunc() ) - D, V, info = eigh_trunc!(adjoint(t), alg; trunc) - return adjoint(D), adjoint(V), info + D, U, info = eigh_trunc!(adjoint(t), alg; trunc) + return adjoint(D), adjoint(U), info end ## Forward algorithms -# Truncated eigh but also return full D and V to make it compatible with :fixed mode +# Truncated eigh but also return full D and U to make it compatible with :fixed mode function _eigh_trunc!( t::TensorMap, alg::LAPACK_EighAlgorithm, trunc::TruncationStrategy, ) - D, V = eigh_full!(t; alg) - D̃, Ṽ, truncerror = _truncate_eigh((D, V), trunc) + D, U = eigh_full!(t; alg) + D̃, Ũ, truncerror = _truncate_eigh((D, U), trunc) # construct info NamedTuple condnum = cond(D) info = (; - truncation_error = truncerror, condition_number = condnum, D_full = D, V_full = V, + truncation_error = truncerror, condition_number = condnum, D_full = D, U_full = U, ) - return D̃, Ṽ, info + return D̃, Ũ, info end # hacky way of computing the truncation error for current version of eigh_trunc! 
# TODO: replace once TensorKit updates to new MatrixAlgebraKit which returns truncation error as well -function _truncate_eigh((D, V), trunc::TruncationStrategy) +function _truncate_eigh((D, U), trunc::TruncationStrategy) if !(trunc isa NoTruncation) && !isempty(blocksectors(D)) - D̃, Ṽ = truncate(eigh_trunc!, (D, V), trunc)[1] + D̃, Ũ = truncate(eigh_trunc!, (D, U), trunc)[1] truncerror = sqrt(abs(norm(D)^2 - norm(D̃)^2)) - return D̃, Ṽ, truncerror + return D̃, Ũ, truncerror else - return D, V, zero(real(scalartype(D))) + return D, U, zero(real(scalartype(D))) end end @@ -178,24 +178,24 @@ $(TYPEDEF) Eigenvalue decomposition struct containing a pre-computed decomposition or even multiple ones. Additionally, it can contain the untruncated full decomposition as well. The call to -`eigh_trunc`/`eig_trunc` just returns the pre-computed D and V. In the reverse pass, -the adjoint is computed with these exact D and V and, potentially, the full decompositions +`eigh_trunc`/`eig_trunc` just returns the pre-computed D and U. In the reverse pass, +the adjoint is computed with these exact D and U and, potentially, the full decompositions if the adjoints needs access to them. 
## Fields $(TYPEDFIELDS) """ -struct FixedEig{Dt, Vt, Dtf, Vtf} +struct FixedEig{Dt, Ut, Dtf, Utf} D::Dt - V::Vt + U::Ut D_full::Dtf - V_full::Vtf + U_full::Utf end -# check whether the full D and V are supplied +# check whether the full D and U are supplied function isfulleig(alg::FixedEig) - if isnothing(alg.D_full) || isnothing(alg.V_full) + if isnothing(alg.D_full) || isnothing(alg.U_full) return false else return true @@ -208,9 +208,9 @@ function _eigh_trunc!(_, alg::FixedEig, ::TruncationStrategy) truncation_error = zero(real(scalartype(alg.D))), condition_number = cond(alg.D), D_full = alg.D_full, - V_full = alg.V_full, + U_full = alg.U_full, ) - return alg.D, alg.V, info + return alg.D, alg.U, info end @@ -245,7 +245,7 @@ end # Compute eigh data block-wise using KrylovKit algorithm function _eigh_trunc!(f, alg::IterEig, trunc::TruncationStrategy) - D, V = if isempty(blocksectors(f)) + D, U = if isempty(blocksectors(f)) # early return truncation_error = zero(real(scalartype(f))) MatrixAlgebraKit.initialize_output(eigh_full!, f, LAPACK_QRIteration()) # specified algorithm doesn't matter here @@ -256,11 +256,11 @@ function _eigh_trunc!(f, alg::IterEig, trunc::TruncationStrategy) # construct info NamedTuple truncation_error = - trunc isa NoTruncation ? abs(zero(scalartype(f))) : norm(V * D * V' - f) + trunc isa NoTruncation ? abs(zero(scalartype(f))) : norm(U * D * U' - f) condition_number = cond(D) - info = (; truncation_error, condition_number, D_full = nothing, V_full = nothing) + info = (; truncation_error, condition_number, D_full = nothing, U_full = nothing) - return D, V, info + return D, U, info end # Obtain sparse decomposition from block-wise eigsolve calls @@ -279,10 +279,10 @@ function _compute_eighdata!( howmany = trunc isa NoTruncation ? 
minimum(size(b)) : blockdim(trunc.space, c) if howmany / minimum(size(b)) > alg.fallback_threshold # Use dense decomposition for small blocks - D, V = eigh_full!(b, LAPACK_QRIteration()) + D, U = eigh_full!(b, LAPACK_QRIteration()) lm_ordering = sortperm(abs.(D.diag); rev = true) # order values and vectors consistently with eigsolve D = D.diag[lm_ordering] # extracts diagonal as Vector instead of Diagonal to make compatible with D of svdsolve - V = view(V, lm_ordering)[:, 1:howmany] + U = view(U, lm_ordering)[:, 1:howmany] else x₀ = alg.start_vector(b) eig_alg = alg.alg @@ -292,18 +292,18 @@ function _compute_eighdata!( D, lvecs, info = eigsolve(b, x₀, howmany, :LM, eig_alg) if info.converged < howmany # Fall back to dense SVD if not properly converged @warn "Iterative eigendecomposition did not converge for block $c, falling back to eigh_full" - D, V = eigh_full!(b, LAPACK_QRIteration()) + D, U = eigh_full!(b, LAPACK_QRIteration()) lm_ordering = sortperm(abs.(D.diag); rev = true) D = D.diag[lm_ordering] - V = view(V, lm_ordering)[:, 1:howmany] + U = view(U, lm_ordering)[:, 1:howmany] else # Slice in case more values were converged than requested - V = stack(view(lvecs, 1:howmany)) + U = stack(view(lvecs, 1:howmany)) end end resize!(D, howmany) dims[c] = length(D) - return c => (D, V) + return c => (D, U) end eigdata = SectorDict(generator) @@ -321,13 +321,13 @@ function _create_eightensors(t::TensorMap, eighdata, dims) Tr = real(T) A = similarstoragetype(t, Tr) D = DiagonalTensorMap{Tr, S, A}(undef, W) - V = similar(t, domain(t) ← W) - for (c, (Dc, Vc)) in eighdata + U = similar(t, domain(t) ← W) + for (c, (Dc, Uc)) in eighdata r = Base.OneTo(dims[c]) copy!(block(D, c), Diagonal(view(Dc, r))) - copy!(block(V, c), view(Vc, :, r)) + copy!(block(U, c), view(Uc, :, r)) end - return D, V + return D, U end ## Reverse-rule algorithms @@ -348,15 +348,15 @@ function ChainRulesCore.rrule( alg::EighAdjoint{F, R}; trunc::TruncationStrategy = notrunc(), ) where {F <: 
Union{<:LAPACK_EighAlgorithm, <:FixedEig}, R <: FullEighPullback} - D̃, Ṽ, info = eigh_trunc(t, alg; trunc) - D, V = info.D_full, info.V_full # untruncated decomposition + D̃, Ũ, info = eigh_trunc(t, alg; trunc) + D, U = info.D_full, info.U_full # untruncated decomposition inds = findtruncated(diagview(D), truncspace(only(domain(D̃)))) gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) - function eigh_trunc!_full_pullback(ΔDV) + function eigh_trunc!_full_pullback(ΔDU) Δt = eigh_pullback!( # TODO: does this work by now? - zeros(scalartype(t), space(t)), t, (D, V), ΔDV, inds; - gauge_atol = gtol(ΔDV) + zeros(scalartype(t), space(t)), t, (D, U), ΔDU, inds; + gauge_atol = gtol(ΔDU) ) return NoTangent(), Δt, NoTangent() end @@ -364,7 +364,7 @@ function ChainRulesCore.rrule( return NoTangent(), ZeroTangent(), NoTangent() end - return (D̃, Ṽ, info), eigh_trunc!_full_pullback + return (D̃, Ũ, info), eigh_trunc!_full_pullback end # eigh_trunc! rrule wrapping MatrixAlgebraKit's eigh_trunc_pullback! 
(also works for IterEig) @@ -374,13 +374,13 @@ function ChainRulesCore.rrule( alg::EighAdjoint{F, R}; trunc::TruncationStrategy = notrunc(), ) where {F <: Union{<:LAPACK_EighAlgorithm, <:FixedEig, IterEig}, R <: TruncEighPullback} - D, V, info = eigh_trunc(t, alg; trunc) + D, U, info = eigh_trunc(t, alg; trunc) gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) - function eigh_trunc!_trunc_pullback(ΔDV) + function eigh_trunc!_trunc_pullback(ΔDU) Δf = eigh_trunc_pullback!( - zeros(scalartype(t), space(t)), t, (D, V), ΔDV; - gauge_atol = gtol(ΔDV) + zeros(scalartype(t), space(t)), t, (D, U), ΔDU; + gauge_atol = gtol(ΔDU) ) return NoTangent(), Δf, NoTangent() end @@ -388,5 +388,5 @@ function ChainRulesCore.rrule( return NoTangent(), ZeroTangent(), NoTangent() end - return (D, V, info), eigh_trunc!_trunc_pullback + return (D, U, info), eigh_trunc!_trunc_pullback end diff --git a/test/ctmrg/fixed_iterscheme.jl b/test/ctmrg/fixed_iterscheme.jl index b3221e8d0..85ea49893 100644 --- a/test/ctmrg/fixed_iterscheme.jl +++ b/test/ctmrg/fixed_iterscheme.jl @@ -12,19 +12,23 @@ using PEPSKit: fix_relative_phases, fix_global_phases, calc_elementwise_convergence, + peps_normalize, + initialize_random_c4v_env, + ScramblingEnvGauge, + ScramblingEnvGaugeC4v # initialize parameters -χbond = 2 -χenv = 16 +D = 2 +χ = 16 svd_algs = [SVDAdjoint(; fwd_alg = LAPACK_DivideAndConquer()), SVDAdjoint(; fwd_alg = IterSVD())] -eigh_algs = [EighAdjoint(; fwd_alg = qriteration), EighAdjoint(; fwd_alg = :lanczos)] +eigh_algs = [EighAdjoint(; fwd_alg = (; alg = :qriteration)), EighAdjoint(; fwd_alg = (; alg = :lanczos))] projector_algs_asymm = [:halfinfinite] #, :fullinfinite] -projector_algs_c4v = [:c4v_eigh, :c4v_qr] +projector_algs_c4v = [:c4v_eigh] # :c4v_qr] unitcells = [(1, 1), (3, 4)] atol = 1.0e-5 # test for element-wise convergence after application of fixed step -@testset "$unitcell unit cell with $(typeof(svd_alg.fwd_alg)) and $projector_alg" for ( +@testset "$unitcell unit cell with 
$(typeof(decomposition_alg.fwd_alg)) and $projector_alg" for ( unitcell, decomposition_alg, projector_alg, ) in Iterators.product( unitcells, svd_algs, projector_algs_asymm @@ -33,13 +37,14 @@ atol = 1.0e-5 # initialize states Random.seed!(2394823842) - psi = InfinitePEPS(ComplexSpace(2), ComplexSpace(χbond); unitcell) + psi = InfinitePEPS(ComplexSpace(2), ComplexSpace(D); unitcell) n = InfiniteSquareNetwork(psi) - env_conv1, = leading_boundary(CTMRGEnv(psi, ComplexSpace(χenv)), psi, ctm_alg) + env_conv1, = leading_boundary(CTMRGEnv(psi, ComplexSpace(χ)), psi, ctm_alg) # do extra iteration to get SVD - env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) + # env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) + env_conv2, info = ctmrg_iteration(n, env_conv1, ctm_alg) env_fix, signs = gauge_fix(env_conv2, ScramblingEnvGauge(), env_conv1) @test calc_elementwise_convergence(env_conv1, env_fix) ≈ 0 atol = atol @@ -47,33 +52,40 @@ atol = 1.0e-5 ctm_alg_fix = gauge_fix(ctm_alg, signs, info) # do iteration with FixedSVD - env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) + # env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) + env_fixedsvd, = ctmrg_iteration(n, env_conv1, ctm_alg_fix) env_fixedsvd = fix_global_phases(env_conv1, env_fixedsvd) @test calc_elementwise_convergence(env_conv1, env_fixedsvd) ≈ 0 atol = atol end # test same thing for C4v CTMRG -@testset "" for (eigh_alg, projector_alg) in Iterators.product(eigh_algs, projector_algs_c4v) - # TODO - ctm_alg = C4vCTMRG(; decomposition_alg, projector_alg) - +@testset "$(typeof(decomposition_alg.fwd_alg)) and $projector_alg" for (decomposition_alg, projector_alg) in + Iterators.product(eigh_algs, projector_algs_c4v) # initialize states Random.seed!(2394823842) - psi = InfinitePEPS(ComplexSpace(2), ComplexSpace(χbond); unitcell) + ctm_alg = C4vCTMRG(; projector_alg, decomposition_alg) + symm = RotateReflect() + + psi = 
InfinitePEPS(ComplexSpace(2), ComplexSpace(D)) + psi = peps_normalize(symmetrize!(psi, symm)) n = InfiniteSquareNetwork(psi) - env_conv1, = leading_boundary(CTMRGEnv(psi, ComplexSpace(χenv)), psi, ctm_alg) + env₀ = initialize_random_c4v_env(ComplexSpace(2), ComplexSpace(χ), scalartype(psi)) + env_conv1, = leading_boundary(env₀, psi, ctm_alg) # do extra iteration to get SVD - env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) + # env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) + env_conv2, info = ctmrg_iteration(n, env_conv1, ctm_alg) env_fix, signs = gauge_fix(env_conv2, ScramblingEnvGaugeC4v(), env_conv1) - @test calc_elementwise_convergence(env_conv1, env_fix) ≈ 0 atol = atol + env_ref = CTMRGEnv(TensorMap.(env_conv1.corners), env_conv1.edges) + @test calc_elementwise_convergence(env_ref, env_fix) ≈ 0 atol = atol # fix gauge of SVD ctm_alg_fix = gauge_fix(ctm_alg, signs, info) # do iteration with FixedSVD - env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) + # env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) + env_fixedsvd, = ctmrg_iteration(n, env_conv1, ctm_alg_fix) env_fixedsvd = fix_global_phases(env_conv1, env_fixedsvd) @test calc_elementwise_convergence(env_conv1, env_fixedsvd) ≈ 0 atol = atol end @@ -81,15 +93,15 @@ end @testset "Element-wise consistency of LAPACK_DivideAndConquer and IterSVD" begin ctm_alg_iter = SimultaneousCTMRG(; maxiter = 200, - decomposition_alg = SVDAdjoint(; fwd_alg = IterSVD(; alg = GKL(; tol = 1.0e-14, krylovdim = χenv + 10))), + decomposition_alg = SVDAdjoint(; fwd_alg = IterSVD(; alg = GKL(; tol = 1.0e-14, krylovdim = χ + 10))), ) ctm_alg_full = SimultaneousCTMRG(; decomposition_alg = SVDAdjoint(; fwd_alg = LAPACK_DivideAndConquer())) # initialize states Random.seed!(91283219347) - psi = InfinitePEPS(ComplexSpace(2), ComplexSpace(χbond)) + psi = InfinitePEPS(ComplexSpace(2), ComplexSpace(D)) n = InfiniteSquareNetwork(psi) - env₀ = 
CTMRGEnv(psi, ComplexSpace(χenv)) + env₀ = CTMRGEnv(psi, ComplexSpace(χ)) env_conv1, = leading_boundary(env₀, psi, ctm_alg_iter) # do extra iteration to get SVD @@ -131,8 +143,10 @@ end @test svalues_check # check normalization of U's and V's - Us = [info_iter.U, svd_alg_fix_iter.fwd_alg.U, info_full.U, svd_alg_fix_full.fwd_alg.U] - Vs = [info_iter.V, svd_alg_fix_iter.fwd_alg.V, info_full.V, svd_alg_fix_full.fwd_alg.V] + salg_fix_iter = ctm_alg_fix_iter.projector_alg.alg.fwd_alg + salg_fix_full = ctm_alg_fix_full.projector_alg.alg.fwd_alg + Us = [info_iter.U, salg_fix_iter.U, info_full.U, salg_fix_full.U] + Vs = [info_iter.V, salg_fix_iter.V, info_full.V, salg_fix_full.V] for (U, V) in zip(Us, Vs) U_check = all(U) do u uu = u' * u diff --git a/test/ctmrg/flavors.jl b/test/ctmrg/flavors.jl index 2bc961367..ae5dd99f6 100644 --- a/test/ctmrg/flavors.jl +++ b/test/ctmrg/flavors.jl @@ -90,4 +90,5 @@ end env₀, peps; alg = :c4v, projector_alg, decomposition_alg = (; fwd_alg = (; alg = eigh_alg)) ) + @test env isa CTMRGEnv end From a6ce10842adb27fefbb15570a06c1cffceb8eebf Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Fri, 23 Jan 2026 11:02:55 +0100 Subject: [PATCH 08/14] Comment out QR things for now --- src/Defaults.jl | 6 +- src/PEPSKit.jl | 2 +- src/algorithms/ctmrg/c4v.jl | 32 ++++++----- src/utility/qr.jl | 106 ++++++++++++++++++------------------ 4 files changed, 76 insertions(+), 70 deletions(-) diff --git a/src/Defaults.jl b/src/Defaults.jl index e1674491e..72ff7397a 100644 --- a/src/Defaults.jl +++ b/src/Defaults.jl @@ -107,9 +107,9 @@ const eigh_rrule_alg = :trunc const eigh_rrule_verbosity = 0 # QR forward & reverse -const qr_fwd_alg = :something # TODO -const qr_rrule_alg = :something -const qr_rrule_verbosity = :something +# const qr_fwd_alg = :something # TODO +# const qr_rrule_alg = :something +# const qr_rrule_verbosity = :something # Projectors const projector_alg = :halfinfinite # ∈ {:halfinfinite, :fullinfinite} diff --git a/src/PEPSKit.jl 
b/src/PEPSKit.jl index bfdb0399e..dbad9cc05 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -31,7 +31,7 @@ include("utility/util.jl") include("utility/diffable_threads.jl") include("utility/eigh.jl") include("utility/svd.jl") -include("utility/qr.jl") +# include("utility/qr.jl") include("utility/rotations.jl") include("utility/hook_pullback.jl") include("utility/autoopt.jl") diff --git a/src/algorithms/ctmrg/c4v.jl b/src/algorithms/ctmrg/c4v.jl index 191f58725..c201ffc98 100644 --- a/src/algorithms/ctmrg/c4v.jl +++ b/src/algorithms/ctmrg/c4v.jl @@ -20,14 +20,14 @@ function C4vEighProjector(; kwargs...) end PROJECTOR_SYMBOLS[:c4v_eigh] = C4vEighProjector -struct C4vQRProjector{S, T} <: ProjectorAlgorithm - alg::S - verbosity::Int -end -function C4vQRProjector(; kwargs...) - return ProjectorAlgorithm(; alg = :c4v_qr, kwargs...) -end -PROJECTOR_SYMBOLS[:c4v_qr] = C4vQRProjector +# struct C4vQRProjector{S, T} <: ProjectorAlgorithm +# alg::S +# verbosity::Int +# end +# function C4vQRProjector(; kwargs...) +# return ProjectorAlgorithm(; alg = :c4v_qr, kwargs...) 
+# end +# PROJECTOR_SYMBOLS[:c4v_qr] = C4vQRProjector # ## C4v-symmetric CTMRG iteration (called through `leading_boundary`) @@ -38,12 +38,19 @@ function ctmrg_iteration( env::CTMRGEnv, alg::C4vCTMRG, ) - enlarged_corner = TensorMap(EnlargedCorner(network, env, (NORTHWEST, 1, 1))) + enlarged_corner = c4v_enlarge(network, env, alg.projector_alg) corner′, projector, info = c4v_projector(enlarged_corner, alg.projector_alg) edge′ = c4v_renormalize(network, env, projector) return CTMRGEnv(corner′, edge′), info end +function c4v_enlarge(network, env, ::C4vEighProjector) + return TensorMap(EnlargedCorner(network, env, (NORTHWEST, 1, 1))) +end +# function c4v_enlarge(enlarged_corner, alg::C4vQRProjector) +# # TODO +# end + function c4v_projector(enlarged_corner, alg::C4vEighProjector) hermitian_corner = 0.5 * (enlarged_corner + enlarged_corner') / norm(enlarged_corner) trunc = truncation_strategy(alg, enlarged_corner) @@ -59,10 +66,9 @@ function c4v_projector(enlarged_corner, alg::C4vEighProjector) return D / norm(D), U, (; D, U, info...) end - -function c4v_projector(enlarged_corner, alg::C4vQRProjector) - # TODO -end +# function c4v_projector(enlarged_corner, alg::C4vQRProjector) +# # TODO +# end function c4v_renormalize(network, env, projector) new_edge = renormalize_north_edge(env.edges[1], projector, projector', network[1, 1]) diff --git a/src/utility/qr.jl b/src/utility/qr.jl index efd3a912d..8b4cadfbc 100644 --- a/src/utility/qr.jl +++ b/src/utility/qr.jl @@ -1,68 +1,68 @@ -""" -$(TYPEDEF) +# """ +# $(TYPEDEF) -Wrapper for a QR decomposition algorithm `fwd_alg` with a defined reverse rule `rrule_alg`. -If `isnothing(rrule_alg)`, Zygote differentiates the forward call automatically. +# Wrapper for a QR decomposition algorithm `fwd_alg` with a defined reverse rule `rrule_alg`. +# If `isnothing(rrule_alg)`, Zygote differentiates the forward call automatically. 
-## Fields +# ## Fields -$(TYPEDFIELDS) +# $(TYPEDFIELDS) -## Constructors +# ## Constructors - QRAdjoint(; kwargs...) +# QRAdjoint(; kwargs...) -Construct a `QRAdjoint` algorithm struct based on the following keyword arguments: +# Construct a `QRAdjoint` algorithm struct based on the following keyword arguments: -* `fwd_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.qr_fwd_alg))`: Eig algorithm of the forward pass which can either be passed as an `Algorithm` instance or a `NamedTuple` where `alg` is one of the following: -* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.qr_rrule_alg))`: Reverse-rule algorithm for differentiating the eigenvalue decomposition. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following: -""" -struct QRAdjoint{F, R} - fwd_alg::F - rrule_alg::R -end # Keep truncation algorithm separate to be able to specify CTMRG dependent information +# * `fwd_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.qr_fwd_alg))`: Eig algorithm of the forward pass which can either be passed as an `Algorithm` instance or a `NamedTuple` where `alg` is one of the following: +# * `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.qr_rrule_alg))`: Reverse-rule algorithm for differentiating the eigenvalue decomposition. 
Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following: +# """ +# struct QRAdjoint{F, R} +# fwd_alg::F +# rrule_alg::R +# end # Keep truncation algorithm separate to be able to specify CTMRG dependent information -const QR_FWD_SYMBOLS = IdDict{Symbol, Any}( - # TODO -) -const QR_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( - # TODO -) +# const QR_FWD_SYMBOLS = IdDict{Symbol, Any}( +# # TODO +# ) +# const QR_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( +# # TODO +# ) -function QRAdjoint(; fwd_alg = (;), rrule_alg = (;)) - # parse forward algorithm - fwd_algorithm = if fwd_alg isa NamedTuple - fwd_kwargs = (; alg = Defaults.qr_fwd_alg, fwd_alg...) # overwrite with specified kwargs - haskey(QR_FWD_SYMBOLS, fwd_kwargs.alg) || - throw(ArgumentError("unknown forward algorithm: $(fwd_kwargs.alg)")) - fwd_type = QR_FWD_SYMBOLS[fwd_kwargs.alg] - fwd_kwargs = Base.structdiff(fwd_kwargs, (; alg = nothing)) # remove `alg` keyword argument - fwd_type(; fwd_kwargs...) - else - fwd_alg - end +# function QRAdjoint(; fwd_alg = (;), rrule_alg = (;)) +# # parse forward algorithm +# fwd_algorithm = if fwd_alg isa NamedTuple +# fwd_kwargs = (; alg = Defaults.qr_fwd_alg, fwd_alg...) # overwrite with specified kwargs +# haskey(QR_FWD_SYMBOLS, fwd_kwargs.alg) || +# throw(ArgumentError("unknown forward algorithm: $(fwd_kwargs.alg)")) +# fwd_type = QR_FWD_SYMBOLS[fwd_kwargs.alg] +# fwd_kwargs = Base.structdiff(fwd_kwargs, (; alg = nothing)) # remove `alg` keyword argument +# fwd_type(; fwd_kwargs...) 
+# else +# fwd_alg +# end - # parse reverse-rule algorithm - rrule_algorithm = if rrule_alg isa NamedTuple - rrule_kwargs = (; - alg = Defaults.qr_rrule_alg, - verbosity = Defaults.qr_rrule_verbosity, - rrule_alg..., - ) # overwrite with specified kwargs +# # parse reverse-rule algorithm +# rrule_algorithm = if rrule_alg isa NamedTuple +# rrule_kwargs = (; +# alg = Defaults.qr_rrule_alg, +# verbosity = Defaults.qr_rrule_verbosity, +# rrule_alg..., +# ) # overwrite with specified kwargs - haskey(QR_RRULE_SYMBOLS, rrule_kwargs.alg) || - throw(ArgumentError("unknown rrule algorithm: $(rrule_kwargs.alg)")) - rrule_type = QR_RRULE_SYMBOLS[rrule_kwargs.alg] - if rrule_type <: Something # TODO - rrule_kwargs = (; rrule_kwargs.verbosity) - end +# haskey(QR_RRULE_SYMBOLS, rrule_kwargs.alg) || +# throw(ArgumentError("unknown rrule algorithm: $(rrule_kwargs.alg)")) +# rrule_type = QR_RRULE_SYMBOLS[rrule_kwargs.alg] +# if rrule_type <: Something # TODO +# rrule_kwargs = (; rrule_kwargs.verbosity) +# end - rrule_type(; rrule_kwargs...) - else - rrule_alg - end +# rrule_type(; rrule_kwargs...) 
+# else +# rrule_alg +# end - return QRAdjoint(fwd_algorithm, rrule_algorithm) -end +# return QRAdjoint(fwd_algorithm, rrule_algorithm) +# end # TODO: implement wrapper for MatrixAlgebraKit QR functions From 97760e889c993e1390edf415817749004e4e35cc Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Fri, 23 Jan 2026 15:10:24 +0100 Subject: [PATCH 09/14] Add gauge fixing and partition function C4v tests --- Project.toml | 2 +- src/algorithms/ctmrg/c4v.jl | 24 +++++++++++------ src/algorithms/ctmrg/gaugefix.jl | 17 +++++++++++- test/ctmrg/fixed_iterscheme.jl | 17 +++++------- test/ctmrg/flavors.jl | 2 +- test/ctmrg/gaugefix.jl | 46 ++++++++++++++++++++++---------- test/ctmrg/partition_function.jl | 22 +++++++-------- 7 files changed, 83 insertions(+), 47 deletions(-) diff --git a/Project.toml b/Project.toml index 6fc2cf0d8..45102057b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PEPSKit" uuid = "52969e89-939e-4361-9b68-9bc7cde4bdeb" -authors = ["Paul Brehmer", "Lander Burgelman", "Lukas Devos "] version = "0.7.0" +authors = ["Paul Brehmer", "Lander Burgelman", "Lukas Devos "] [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/src/algorithms/ctmrg/c4v.jl b/src/algorithms/ctmrg/c4v.jl index c201ffc98..ff9fb2e82 100644 --- a/src/algorithms/ctmrg/c4v.jl +++ b/src/algorithms/ctmrg/c4v.jl @@ -121,16 +121,24 @@ end # TODO: rewrite this using `initialize_environment` and C4v-specific initialization algorithms # environment with dummy corner singlet(V) ← singlet(V) and identity edge V ← V, initialized at dim(Venv) -function initialize_singlet_c4v_env(Vpeps::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) - corner₀ = DiagonalTensorMap(zeros(real(T), Venv ← Venv)) - corner₀.data[1] = one(real(T)) - edge₀ = permute(id(T, Venv ⊗ Vpeps), ((1, 2, 4), (3,))) - return CTMRGEnv(corner₀, edge₀) -end +# function initialize_singlet_c4v_env(Vpeps::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) +# corner₀ = 
DiagonalTensorMap(zeros(real(T), Venv ← Venv)) +# corner₀.data[1] = one(real(T)) +# edge₀ = permute(id(T, Venv ⊗ Vpeps), ((1, 2, 4), (3,))) +# return CTMRGEnv(corner₀, edge₀) +# end -function initialize_random_c4v_env(Vpeps::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) +function initialize_random_c4v_env(Vstate::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) corner₀ = DiagonalTensorMap(randn(real(T), Venv ← Venv)) - edge₀ = randn(T, Venv ⊗ Vpeps ⊗ Vpeps' ← Venv) + edge₀ = randn(T, Venv ⊗ Vstate ← Venv) edge₀ = project_hermitian(edge₀) return CTMRGEnv(corner₀, edge₀) end +function initialize_random_c4v_env(state::InfinitePEPS, Venv::ElementarySpace, T = scalartype(state)) + Vpeps = domain(state[1])[1] + return initialize_random_c4v_env(Vpeps ⊗ Vpeps', Venv, T) +end +function initialize_random_c4v_env(state::InfinitePartitionFunction, Venv::ElementarySpace, T = scalartype(state)) + Vpf = domain(state[1])[1] + return initialize_random_c4v_env(Vpf, Venv, T) +end diff --git a/src/algorithms/ctmrg/gaugefix.jl b/src/algorithms/ctmrg/gaugefix.jl index 4758ec09e..c625419a7 100644 --- a/src/algorithms/ctmrg/gaugefix.jl +++ b/src/algorithms/ctmrg/gaugefix.jl @@ -246,7 +246,9 @@ end Check if the element-wise difference of the corner and edge tensors of the final and fixed CTMRG environments are below `atol` and return the maximal difference. """ -function calc_elementwise_convergence(envfinal::CTMRGEnv, envfix::CTMRGEnv; atol::Real = 1.0e-6) +function calc_elementwise_convergence( + envfinal::CTMRGEnv, envfix::CTMRGEnv; atol::Real = 1.0e-6 + ) ΔC = envfinal.corners .- envfix.corners ΔCmax = norm(ΔC, Inf) ΔCmean = norm(ΔC) @@ -272,5 +274,18 @@ function calc_elementwise_convergence(envfinal::CTMRGEnv, envfix::CTMRGEnv; atol return max(ΔCmax, ΔTmax) end +function calc_elementwise_convergence( + envfinal::CTMRGEnv{C}, envfix::CTMRGEnv{C′}; kwargs... 
+ ) where {C <: DiagonalTensorMap, C′} # case where one of the environments might have diagonal corners + return calc_elementwise_convergence( # make corners non-diagonal TensorMaps such that you can compute difference + CTMRGEnv(convert.(TensorMap, envfinal.corners), envfinal.edges; kwargs...), + CTMRGEnv(envfix.corners, envfix.edges; kwargs...) + ) +end +function calc_elementwise_convergence( + envfinal::CTMRGEnv{C}, envfix::CTMRGEnv{C′}; kwargs... + ) where {C, C′ <: DiagonalTensorMap} + return calc_elementwise_convergence(envfix, envfinal) +end @non_differentiable calc_elementwise_convergence(args...) diff --git a/test/ctmrg/fixed_iterscheme.jl b/test/ctmrg/fixed_iterscheme.jl index 85ea49893..5a635bb1d 100644 --- a/test/ctmrg/fixed_iterscheme.jl +++ b/test/ctmrg/fixed_iterscheme.jl @@ -43,8 +43,7 @@ atol = 1.0e-5 env_conv1, = leading_boundary(CTMRGEnv(psi, ComplexSpace(χ)), psi, ctm_alg) # do extra iteration to get SVD - # env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) - env_conv2, info = ctmrg_iteration(n, env_conv1, ctm_alg) + env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) env_fix, signs = gauge_fix(env_conv2, ScramblingEnvGauge(), env_conv1) @test calc_elementwise_convergence(env_conv1, env_fix) ≈ 0 atol = atol @@ -52,8 +51,7 @@ atol = 1.0e-5 ctm_alg_fix = gauge_fix(ctm_alg, signs, info) # do iteration with FixedSVD - # env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) - env_fixedsvd, = ctmrg_iteration(n, env_conv1, ctm_alg_fix) + env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) env_fixedsvd = fix_global_phases(env_conv1, env_fixedsvd) @test calc_elementwise_convergence(env_conv1, env_fixedsvd) ≈ 0 atol = atol end @@ -70,22 +68,19 @@ end psi = peps_normalize(symmetrize!(psi, symm)) n = InfiniteSquareNetwork(psi) - env₀ = initialize_random_c4v_env(ComplexSpace(2), ComplexSpace(χ), scalartype(psi)) + env₀ = initialize_random_c4v_env(psi, ComplexSpace(χ)) 
env_conv1, = leading_boundary(env₀, psi, ctm_alg) # do extra iteration to get SVD - # env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) - env_conv2, info = ctmrg_iteration(n, env_conv1, ctm_alg) + env_conv2, info = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg) env_fix, signs = gauge_fix(env_conv2, ScramblingEnvGaugeC4v(), env_conv1) - env_ref = CTMRGEnv(TensorMap.(env_conv1.corners), env_conv1.edges) - @test calc_elementwise_convergence(env_ref, env_fix) ≈ 0 atol = atol + @test calc_elementwise_convergence(env_conv1, env_fix) ≈ 0 atol = atol # fix gauge of SVD ctm_alg_fix = gauge_fix(ctm_alg, signs, info) # do iteration with FixedSVD - # env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) - env_fixedsvd, = ctmrg_iteration(n, env_conv1, ctm_alg_fix) + env_fixedsvd, = @constinferred ctmrg_iteration(n, env_conv1, ctm_alg_fix) env_fixedsvd = fix_global_phases(env_conv1, env_fixedsvd) @test calc_elementwise_convergence(env_conv1, env_fixedsvd) ≈ 0 atol = atol end diff --git a/test/ctmrg/flavors.jl b/test/ctmrg/flavors.jl index ae5dd99f6..727e042a9 100644 --- a/test/ctmrg/flavors.jl +++ b/test/ctmrg/flavors.jl @@ -85,7 +85,7 @@ end peps = peps_normalize(symmetrize!(peps, symm)) # boundary_alg = C4vCTMRG(; projector_alg, decomposition_alg = (; fwd_alg)) - env₀ = initialize_random_c4v_env(Vpeps, Venv, scalartype(peps)) + env₀ = initialize_random_c4v_env(peps, Venv) env, = leading_boundary( env₀, peps; alg = :c4v, projector_alg, decomposition_alg = (; fwd_alg = (; alg = eigh_alg)) diff --git a/test/ctmrg/gaugefix.jl b/test/ctmrg/gaugefix.jl index 425079144..886c011d7 100644 --- a/test/ctmrg/gaugefix.jl +++ b/test/ctmrg/gaugefix.jl @@ -5,24 +5,33 @@ using TensorKit using PEPSKit: ctmrg_iteration, calc_elementwise_convergence using PEPSKit: ScramblingEnvGauge, ScramblingEnvGaugeC4v +using PEPSKit: peps_normalize, initialize_random_c4v_env spacetypes = [ComplexSpace, Z2Space] scalartypes = [Float64, ComplexF64] unitcells = 
[(1, 1), (2, 2), (3, 2)] -ctmrg_algs = [SequentialCTMRG, SimultaneousCTMRG] -projector_algs = [:halfinfinite, :fullinfinite] -gauge_algs = [ScramblingEnvGauge()] +ctmrg_algs_asymm = [SequentialCTMRG, SimultaneousCTMRG] +projector_algs_asymm = [:halfinfinite, :fullinfinite] +projector_algs_c4v = [:c4v_eigh] #, :c4v_qr] +gauge_algs_asymm = [ScramblingEnvGauge()] +gauge_algs_c4v = [ScramblingEnvGaugeC4v()] tol = 1.0e-6 # large tol due to χ=6 χ = 6 atol = 1.0e-4 function _pre_converge_env( - ::Type{T}, physical_space, peps_space, ctm_space, unitcell; seed = 12345 + ::Type{T}, alg, physical_space, peps_space, env_space, unitcell; + seed = 985293852935829 ) where {T} Random.seed!(seed) # Seed RNG to make random environment consistent psi = InfinitePEPS(rand, T, physical_space, peps_space; unitcell) - env₀ = CTMRGEnv(psi, ctm_space) - env_conv, = leading_boundary(env₀, psi; alg = :sequential, tol) + alg == :c4v && (psi = peps_normalize(symmetrize!(psi, RotateReflect()))) + env₀ = if alg == :c4v + initialize_random_c4v_env(psi, env_space, T) + else + CTMRGEnv(psi, env_space) + end + env_conv, = leading_boundary(env₀, psi; alg, tol) return env_conv, psi end @@ -30,20 +39,32 @@ end preconv = Dict() for (S, T, unitcell) in Iterators.product(spacetypes, scalartypes, unitcells) if S == ComplexSpace - result = _pre_converge_env(T, S(2), S(2), S(χ), unitcell) + result = _pre_converge_env(T, :sequential, S(2), S(2), S(χ), unitcell) elseif S == Z2Space result = _pre_converge_env( - T, S(0 => 1, 1 => 1), S(0 => 1, 1 => 1), S(0 => χ ÷ 2, 1 => χ ÷ 2), unitcell + T, :sequential, S(0 => 1, 1 => 1), S(0 => 1, 1 => 1), + S(0 => χ ÷ 2, 1 => χ ÷ 2), unitcell ) end push!(preconv, (S, T, unitcell) => result) end +preconv_c4v = Dict() +for (S, T) in Iterators.product(spacetypes, scalartypes) + if S == ComplexSpace + result = _pre_converge_env(T, :c4v, S(2), S(2), S(χ), (1, 1)) + elseif S == Z2Space + result = _pre_converge_env( + T, :c4v, S(0 => 1, 1 => 1), S(0 => 1, 1 => 1), S(0 => χ ÷ 2, 1 
=> χ ÷ 2), (1, 1) + ) + end + push!(preconv_c4v, (S, T) => result) +end # asymmetric CTMRG @testset "($S) - ($T) - ($unitcell) - ($ctmrg_alg) - ($projector_alg) - ($gauge_alg)" for ( S, T, unitcell, ctmrg_alg, projector_alg, gauge_alg, ) in Iterators.product( - spacetypes, scalartypes, unitcells, ctmrg_algs, projector_algs, gauge_algs + spacetypes, scalartypes, unitcells, ctmrg_algs_asymm, projector_algs_asymm, gauge_algs_asymm ) alg = ctmrg_alg(; tol, projector_alg) env_pre, psi = preconv[(S, T, unitcell)] @@ -54,17 +75,14 @@ end @test calc_elementwise_convergence(env, env_fixed) ≈ 0 atol = atol end -projector_algs_c4v = [:c4v_eigh, :c4v_qr] -gauge_algs_c4v = [ScramblingEnvGaugeC4v()] - # C4v CTMRG @testset "($S) - ($T) - ($projector_alg) - ($gauge_alg)" for ( - S, T, unitcell, ctmrg_alg, projector_alg, gauge_alg, + S, T, projector_alg, gauge_alg, ) in Iterators.product( spacetypes, scalartypes, projector_algs_c4v, gauge_algs_c4v ) alg = C4vCTMRG(; tol, projector_alg) - env_pre, psi = preconv[(S, T, unitcell)] # TODO + env_pre, psi = preconv_c4v[(S, T)] n = InfiniteSquareNetwork(psi) env, = leading_boundary(env_pre, psi, alg) env′, = ctmrg_iteration(n, env, alg) diff --git a/test/ctmrg/partition_function.jl b/test/ctmrg/partition_function.jl index 056092022..e05b480a4 100644 --- a/test/ctmrg/partition_function.jl +++ b/test/ctmrg/partition_function.jl @@ -97,15 +97,16 @@ end beta = 0.6 O, M, E = classical_ising(; beta) Z = InfinitePartitionFunction(O) +Venv = ℂ^12 Random.seed!(81812781143) - -# contract -χenv = ℂ^12 -env0 = CTMRGEnv(Z, χenv) - +env₀ = CTMRGEnv(Z, Venv) +env₀_c4v = initialize_random_c4v_env(Z, Venv) # cover all different flavors -ctm_styles = [:sequential, :simultaneous] -projector_algs = [:halfinfinite, :fullinfinite] +args = [ + (:sequential, :halfinfinite), (:sequential, :fullinfinite), + (:simultaneous, :halfinfinite), (:simultaneous, :fullinfinite), + (:c4v, :c4v_eigh), #(:c4v, :c4v_qr) +] # Basic properties @test spacetype(typeof(Z)) === 
ComplexSpace @@ -122,10 +123,9 @@ projector_algs = [:halfinfinite, :fullinfinite] @testset "Classical Ising partition function using $alg with $projector_alg" for ( alg, projector_alg, - ) in Iterators.product( - ctm_styles, projector_algs - ) - env, = leading_boundary(env0, Z; alg, maxiter = 150, projector_alg) + ) in args + env₀₀ = alg == :c4v ? env₀_c4v : env₀ + env, = leading_boundary(env₀₀, Z; alg, maxiter = 300, projector_alg) # check observables λ = network_value(Z, env) From 6effabe5e1fa9d1976311c74f773e17d69904e2c Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Fri, 23 Jan 2026 16:14:30 +0100 Subject: [PATCH 10/14] Fix differentiability and add C4v Heisenberg test --- src/algorithms/ctmrg/c4v.jl | 2 +- src/environments/ctmrg_environments.jl | 2 +- test/examples/heisenberg.jl | 43 ++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/algorithms/ctmrg/c4v.jl b/src/algorithms/ctmrg/c4v.jl index ff9fb2e82..082d0a670 100644 --- a/src/algorithms/ctmrg/c4v.jl +++ b/src/algorithms/ctmrg/c4v.jl @@ -128,7 +128,7 @@ end # return CTMRGEnv(corner₀, edge₀) # end -function initialize_random_c4v_env(Vstate::ElementarySpace, Venv::ElementarySpace, T = ComplexF64) +function initialize_random_c4v_env(Vstate::VectorSpace, Venv::ElementarySpace, T = ComplexF64) corner₀ = DiagonalTensorMap(randn(real(T), Venv ← Venv)) edge₀ = randn(T, Venv ⊗ Vstate ← Venv) edge₀ = project_hermitian(edge₀) diff --git a/src/environments/ctmrg_environments.jl b/src/environments/ctmrg_environments.jl index cf85e4702..8060595ec 100644 --- a/src/environments/ctmrg_environments.jl +++ b/src/environments/ctmrg_environments.jl @@ -234,7 +234,7 @@ end # Custom adjoint for CTMRGEnv constructor, needed for fixed-point differentiation function ChainRulesCore.rrule( - ::Type{CTMRGEnv}, corners::Array{3, C}, edges::Array{3, T} + ::Type{CTMRGEnv}, corners::Array{C, 3}, edges::Array{T, 3} ) where {C, T} ctmrgenv_pullback(ē) = NoTangent(), ē.corners, ē.edges return 
CTMRGEnv(corners, edges), ctmrgenv_pullback diff --git a/test/examples/heisenberg.jl b/test/examples/heisenberg.jl index df526e9f4..bd89ceb0b 100644 --- a/test/examples/heisenberg.jl +++ b/test/examples/heisenberg.jl @@ -5,6 +5,8 @@ using PEPSKit using TensorKit using KrylovKit using OptimKit +using PEPSKit: peps_normalize, initialize_random_c4v_env +using MPSKitModels: S_xx, S_yy, S_zz # initialize parameters Dbond = 2 @@ -14,6 +16,25 @@ gradtol = 1.0e-3 # https://github.com/jurajHasik/j1j2_ipeps_states/blob/main/single-site_pg-C4v-A1/j20.0/state_1s_A1_j20.0_D2_chi_opt48.dat E_ref = -0.6602310934799577 +# Heisenberg model assuming C4v symmetric PEPS and environment, which only evaluates necessary term +function heisenberg_XYZ_c4v(lattice::InfiniteSquare; kwargs...) + return heisenberg_XYZ_c4v(ComplexF64, Trivial, lattice; kwargs...) +end +function heisenberg_XYZ_c4v( + T::Type{<:Number}, S::Type{<:Sector}, lattice::InfiniteSquare; + Jx = -1.0, Jy = 1.0, Jz = -1.0, spin = 1 // 2, + ) + @assert size(lattice) == (1, 1) "only trivial unit cells supported by C4v-symmetric Hamiltonians" + term = + rmul!(S_xx(T, S; spin = spin), Jx) + + rmul!(S_yy(T, S; spin = spin), Jy) + + rmul!(S_zz(T, S; spin = spin), Jz) + spaces = fill(domain(term)[1], (1, 1)) + return LocalOperator( # horizontal and vertical contributions are identical + spaces, (CartesianIndex(1, 1), CartesianIndex(1, 2)) => 2 * term + ) +end + @testset "(1, 1) unit cell AD optimization" begin # initialize states Random.seed!(123) @@ -29,6 +50,28 @@ E_ref = -0.6602310934799577 @test all(@. 
ξ_h > 0 && ξ_v > 0) end +@testset "C4v AD optimization" begin + # initialize symmetric states + Random.seed!(123) + symm = RotateReflect() + H = heisenberg_XYZ_c4v(InfiniteSquare()) + peps₀ = InfinitePEPS(ComplexSpace(2), ComplexSpace(Dbond)) + peps₀ = peps_normalize(symmetrize!(peps₀, symm)) + e₀ = initialize_random_c4v_env(peps₀, ComplexSpace(χenv)) + env₀, = leading_boundary(e₀, peps₀; alg = :c4v) + + # optimize energy and compute correlation lengths + peps, env, E, = fixedpoint( + H, peps₀, env₀; + optimizer_alg = (; tol = gradtol, maxiter = 25), + boundary_alg = (; alg = :c4v), symmetrization = symm, + ) + ξ_h, ξ_v, = correlation_length(peps, env) + + @test E ≈ E_ref atol = 1.0e-2 + @test all(@. ξ_h > 0 && ξ_v > 0) +end + @testset "(1, 2) unit cell AD optimization" begin # initialize states Random.seed!(456) From 089a005799038481a1da13aa5962ff5990a850cf Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Fri, 23 Jan 2026 16:46:11 +0100 Subject: [PATCH 11/14] Automatically symmetrize optimization when alg=:c4v --- src/algorithms/select_algorithm.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/algorithms/select_algorithm.jl b/src/algorithms/select_algorithm.jl index 57b615ac1..8ba6beda1 100644 --- a/src/algorithms/select_algorithm.jl +++ b/src/algorithms/select_algorithm.jl @@ -23,7 +23,7 @@ function select_algorithm( tol = Defaults.optimizer_tol, # top-level tolerance verbosity = 3, # top-level verbosity boundary_alg = (;), gradient_alg = (;), optimizer_alg = (;), - kwargs..., + symmetrization = nothing, kwargs..., ) # adjust CTMRG tols and verbosity if boundary_alg isa NamedTuple @@ -45,7 +45,12 @@ function select_algorithm( optimizer_alg = merge(defaults, optimizer_alg) end - return PEPSOptimize(; boundary_alg, gradient_alg, optimizer_alg, kwargs...) 
+ # symmetrize state and gradient when doing C4v optimization + if boundary_alg isa C4vCTMRG && isnothing(symmetrization) + symmetrization = RotateReflect() + end + + return PEPSOptimize(; boundary_alg, gradient_alg, optimizer_alg, symmetrization, kwargs...) end function select_algorithm( From f3d89e1f6adfc1cab888667037ed0b7dec47566f Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Fri, 23 Jan 2026 16:46:22 +0100 Subject: [PATCH 12/14] Fix unitcell test imports --- test/ctmrg/unitcell.jl | 2 +- test/examples/heisenberg.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ctmrg/unitcell.jl b/test/ctmrg/unitcell.jl index 6524c68ab..50016fa79 100644 --- a/test/ctmrg/unitcell.jl +++ b/test/ctmrg/unitcell.jl @@ -1,7 +1,7 @@ using Test using Random using PEPSKit -using PEPSKit: _prev, _next, ctmrg_iteration, gauge_fix +using PEPSKit: _prev, _next, ctmrg_iteration, gauge_fix, ScramblingEnvGauge using TensorKit # settings diff --git a/test/examples/heisenberg.jl b/test/examples/heisenberg.jl index bd89ceb0b..d389cf54a 100644 --- a/test/examples/heisenberg.jl +++ b/test/examples/heisenberg.jl @@ -64,12 +64,12 @@ end peps, env, E, = fixedpoint( H, peps₀, env₀; optimizer_alg = (; tol = gradtol, maxiter = 25), - boundary_alg = (; alg = :c4v), symmetrization = symm, + boundary_alg = (; alg = :c4v), ) ξ_h, ξ_v, = correlation_length(peps, env) @test E ≈ E_ref atol = 1.0e-2 - @test all(@. 
ξ_h > 0 && ξ_v > 0) + @test only(ξ_h) ≈ only(ξ_v) end @testset "(1, 2) unit cell AD optimization" begin From fcdea38093971b2392fcd7a0a2980d54a8fb9368 Mon Sep 17 00:00:00 2001 From: Paul Brehmer Date: Fri, 23 Jan 2026 17:24:40 +0100 Subject: [PATCH 13/14] Add eigh_wrapper test and fix IterEigh --- src/PEPSKit.jl | 2 +- .../fixed_point_differentiation.jl | 2 +- src/utility/eigh.jl | 23 +++--- test/runtests.jl | 3 + test/utility/eigh_wrapper.jl | 74 +++++++++++++++++++ test/utility/svd_wrapper.jl | 4 +- 6 files changed, 93 insertions(+), 15 deletions(-) create mode 100644 test/utility/eigh_wrapper.jl diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index dbad9cc05..24606eb0d 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -105,7 +105,7 @@ export SVDAdjoint, FullSVDReverseRule, IterSVD export CTMRGEnv, SequentialCTMRG, SimultaneousCTMRG export FixedSpaceTruncation, SiteDependentTruncation export HalfInfiniteProjector, FullInfiniteProjector -export EighAdjoint, C4vCTMRG, C4vEighProjector, C4vQRProjector +export EighAdjoint, IterEigh, C4vCTMRG, C4vEighProjector, C4vQRProjector export LocalOperator, physicalspace export product_peps export reduced_densitymatrix, expectation_value, network_value, cost_function diff --git a/src/algorithms/optimization/fixed_point_differentiation.jl b/src/algorithms/optimization/fixed_point_differentiation.jl index c05dcfc3d..680e3a16c 100644 --- a/src/algorithms/optimization/fixed_point_differentiation.jl +++ b/src/algorithms/optimization/fixed_point_differentiation.jl @@ -338,7 +338,7 @@ function gauge_fix(alg::EighAdjoint, signs, info) rrule_alg = alg.rrule_alg, ) end -function gauge_fix(alg::EighAdjoint{F}, signs, info) where {F <: IterEig} +function gauge_fix(alg::EighAdjoint{F}, signs, info) where {F <: IterEigh} # fix kept U only since iterative decomposition doesn't have access to full spectrum U_fixed = info.U * signs[1]' return EighAdjoint(; diff --git a/src/utility/eigh.jl b/src/utility/eigh.jl index 2fd2bd9ef..2e7906485 100644 
--- a/src/utility/eigh.jl +++ b/src/utility/eigh.jl @@ -79,10 +79,10 @@ const EIGH_FWD_SYMBOLS = IdDict{Symbol, Any}( :multiple => LAPACK_MultipleRelativelyRobustRepresentations, :lanczos => (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> - IterEig(; alg = Lanczos(; tol, krylovdim), kwargs...), + IterEigh(; alg = Lanczos(; tol, krylovdim), kwargs...), :blocklanczos => (; tol = 1.0e-14, krylovdim = 30, kwargs...) -> - IterEig(; alg = BlockLanczos(; tol, krylovdim), kwargs...), + IterEigh(; alg = BlockLanczos(; tol, krylovdim), kwargs...), ) const EIGH_RRULE_SYMBOLS = IdDict{Symbol, Type{<:Any}}( :full => FullEighPullback, :trunc => TruncEighPullback, @@ -229,22 +229,22 @@ $(TYPEDFIELDS) ## Constructors - IterEig(; kwargs...) + IterEigh(; kwargs...) -Construct an `IterEig` algorithm struct based on the following keyword arguments: +Construct an `IterEigh` algorithm struct based on the following keyword arguments: * `alg=KrylovKit.Lanczos(; tol=1e-14, krylovdim=25)` : KrylovKit algorithm struct for iterative eigenvalue decomposition. * `fallback_threshold::Float64=Inf` : Threshold for `howmany / minimum(size(block))` above which (if the block is too small) the algorithm falls back to a dense decomposition. * `start_vector=random_start_vector` : Function providing the initial vector for the iterative algorithm. 
""" -@kwdef struct IterEig +@kwdef struct IterEigh alg = KrylovKit.Lanczos(; tol = 1.0e-14, krylovdim = 25) fallback_threshold::Float64 = Inf start_vector = random_start_vector end # Compute eigh data block-wise using KrylovKit algorithm -function _eigh_trunc!(f, alg::IterEig, trunc::TruncationStrategy) +function _eigh_trunc!(f, alg::IterEigh, trunc::TruncationStrategy) D, U = if isempty(blocksectors(f)) # early return truncation_error = zero(real(scalartype(f))) @@ -265,7 +265,7 @@ end # Obtain sparse decomposition from block-wise eigsolve calls function _compute_eighdata!( - f, alg::IterEig, trunc::Union{NoTruncation, TruncationSpace} + f, alg::IterEigh, trunc::Union{NoTruncation, TruncationSpace} ) InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh_trunc!) domain(f) == codomain(f) || @@ -282,7 +282,8 @@ function _compute_eighdata!( D, U = eigh_full!(b, LAPACK_QRIteration()) lm_ordering = sortperm(abs.(D.diag); rev = true) # order values and vectors consistently with eigsolve D = D.diag[lm_ordering] # extracts diagonal as Vector instead of Diagonal to make compatible with D of svdsolve - U = view(U, lm_ordering)[:, 1:howmany] + @show lm_ordering + U = stack(eachcol(U)[lm_ordering])[:, 1:howmany] else x₀ = alg.start_vector(b) eig_alg = alg.alg @@ -295,7 +296,7 @@ function _compute_eighdata!( D, U = eigh_full!(b, LAPACK_QRIteration()) lm_ordering = sortperm(abs.(D.diag); rev = true) D = D.diag[lm_ordering] - U = view(U, lm_ordering)[:, 1:howmany] + U = stack(eachcol(U)[lm_ordering])[:, 1:howmany] else # Slice in case more values were converged than requested U = stack(view(lvecs, 1:howmany)) end @@ -367,13 +368,13 @@ function ChainRulesCore.rrule( return (D̃, Ũ, info), eigh_trunc!_full_pullback end -# eigh_trunc! rrule wrapping MatrixAlgebraKit's eigh_trunc_pullback! (also works for IterEig) +# eigh_trunc! rrule wrapping MatrixAlgebraKit's eigh_trunc_pullback! 
(also works for IterEigh) function ChainRulesCore.rrule( ::typeof(eigh_trunc!), t, alg::EighAdjoint{F, R}; trunc::TruncationStrategy = notrunc(), - ) where {F <: Union{<:LAPACK_EighAlgorithm, <:FixedEig, IterEig}, R <: TruncEighPullback} + ) where {F <: Union{<:LAPACK_EighAlgorithm, <:FixedEig, IterEigh}, R <: TruncEighPullback} D, U, info = eigh_trunc(t, alg; trunc) gtol = _get_pullback_gauge_tol(alg.rrule_alg.verbosity) diff --git a/test/runtests.jl b/test/runtests.jl index 5fc589315..9085b763c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -111,6 +111,9 @@ end end end if GROUP == "ALL" || GROUP == "UTILITY" + @time @safetestset "SVD wrapper" begin + include("utility/eigh_wrapper.jl") + end @time @safetestset "SVD wrapper" begin include("utility/svd_wrapper.jl") end diff --git a/test/utility/eigh_wrapper.jl b/test/utility/eigh_wrapper.jl new file mode 100644 index 000000000..a1f1a4f8e --- /dev/null +++ b/test/utility/eigh_wrapper.jl @@ -0,0 +1,74 @@ +using Test +using Random +using LinearAlgebra +using TensorKit +using ChainRulesCore, Zygote +using Accessors +using FixedPointAD + +# Gauge-invariant loss function +function lossfun(A, alg, R = randn(space(A)), trunc = notrunc()) + D, V, = eigh_trunc(A, alg; trunc) + return real(dot(R, V * V')) + dot(D, D) # Overlap with random tensor R is gauge-invariant and differentiable +end + +dtype = ComplexF64 +n = 20 +χ = 10 +trunc = truncspace(ℂ^χ) +rtol = 1.0e-9 +Random.seed!(123456789) +r = randn(dtype, ℂ^n, ℂ^n) +r = 0.5 * (r + r') # make r Hermitian +R = randn(space(r)) +R = 0.5 * (R + R') + +full_alg = EighAdjoint(; fwd_alg = (; alg = :qriteration), rrule_alg = (; alg = :trunc)) +iter_alg = EighAdjoint(; fwd_alg = (; alg = :lanczos), rrule_alg = (; alg = :trunc)) + +@testset "Non-truncacted eigh" begin + l_full, g_full = withgradient(A -> lossfun(A, full_alg, R), r) + l_iter, g_iter = withgradient(A -> lossfun(A, iter_alg, R), r) + + @test l_iter ≈ l_full + @test g_full[1] ≈ g_iter[1] rtol = rtol +end + 
+@testset "Truncated eigh with χ=$χ" begin + l_full, g_full = withgradient(A -> lossfun(A, full_alg, R, trunc), r) + l_iter, g_iter = withgradient(A -> lossfun(A, iter_alg, R, trunc), r) + + @test l_iter ≈ l_full + @test g_full[1] ≈ g_iter[1] rtol = rtol +end + +symm_m, symm_n = 18, 24 +symm_space = Z2Space(0 => symm_m, 1 => symm_n) +symm_trspace = truncspace(Z2Space(0 => symm_m ÷ 2, 1 => symm_n ÷ 3)) +symm_r = randn(dtype, symm_space, symm_space) +symm_r = 0.5 * (symm_r + symm_r') +symm_R = randn(dtype, space(symm_r)) +symm_R = 0.5 * (symm_R + symm_R') + +@testset "IterEig of symmetric tensors" begin + l_full, g_full = withgradient(A -> lossfun(A, full_alg, symm_R), symm_r) + l_iter, g_iter = withgradient(A -> lossfun(A, iter_alg, symm_R), symm_r) + @test l_iter ≈ l_full + @test g_full[1] ≈ g_iter[1] rtol = rtol + + l_full_tr, g_full_tr = withgradient( + A -> lossfun(A, full_alg, symm_R, symm_trspace), symm_r + ) + l_iter_tr, g_iter_tr = withgradient( + A -> lossfun(A, iter_alg, symm_R, symm_trspace), symm_r + ) + @test l_iter_tr ≈ l_full_tr + @test g_full_tr[1] ≈ g_iter_tr[1] rtol = rtol + + iter_alg_fallback = @set iter_alg.fwd_alg.fallback_threshold = 0.4 # Do dense decomposition in one block, sparse one in the other + l_iter_fb, g_iter_fb = withgradient( + A -> lossfun(A, iter_alg_fallback, symm_R, symm_trspace), symm_r + ) + @test l_iter_fb ≈ l_full_tr + @test g_full_tr[1] ≈ g_iter_fb[1] rtol = rtol +end diff --git a/test/utility/svd_wrapper.jl b/test/utility/svd_wrapper.jl index b6bad0d7c..59b05d9a6 100644 --- a/test/utility/svd_wrapper.jl +++ b/test/utility/svd_wrapper.jl @@ -13,8 +13,8 @@ function lossfun(A, alg, R = randn(space(A)), trunc = notrunc()) return real(dot(R, U * V)) + dot(S, S) # Overlap with random tensor R is gauge-invariant and differentiable, also for m≠n end -m, n = 20, 30 dtype = ComplexF64 +m, n = 20, 30 χ = 12 trunc = truncspace(ℂ^χ) rtol = 1.0e-9 @@ -25,7 +25,7 @@ R = randn(space(r)) full_alg = SVDAdjoint(; rrule_alg = (; alg = :full, 
broadening = 0)) iter_alg = SVDAdjoint(; fwd_alg = (; alg = :iterative)) -@testset "Non-truncacted SVD" begin +@testset "Non-truncated SVD" begin l_fullsvd, g_fullsvd = withgradient(A -> lossfun(A, full_alg, R), r) l_itersvd, g_itersvd = withgradient(A -> lossfun(A, iter_alg, R), r) From 790557088a26b36f2b27668a1eff0fa30fddafe0 Mon Sep 17 00:00:00 2001 From: Yue Zhengyuan Date: Sat, 24 Jan 2026 14:50:20 +0800 Subject: [PATCH 14/14] Fix eigh_wrapper test imports --- test/utility/eigh_wrapper.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utility/eigh_wrapper.jl b/test/utility/eigh_wrapper.jl index a1f1a4f8e..6318cd6d7 100644 --- a/test/utility/eigh_wrapper.jl +++ b/test/utility/eigh_wrapper.jl @@ -4,7 +4,7 @@ using LinearAlgebra using TensorKit using ChainRulesCore, Zygote using Accessors -using FixedPointAD +using PEPSKit # Gauge-invariant loss function function lossfun(A, alg, R = randn(space(A)), trunc = notrunc())