From 5fb3588aee77cdc8f20c3672ad6e0e5127a686d5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 8 Jan 2026 15:37:26 +0100 Subject: [PATCH 1/3] try to make truncation GPU-friendly --- src/factorizations/truncation.jl | 49 +++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index b9d060fec..3bc1ebf48 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -191,6 +191,23 @@ function _sort_and_perm(values::SectorVector; by = identity, rev::Bool = false) return values_sorted, perms end +function _findtruncvalue_order(values::SectorVector, n::Int; by = identity, rev::Bool = false) + I = sectortype(values) + p = sortperm(parent(values); by, rev) + + if FusionStyle(I) isa UniqueFusion # dimensions are all 1 + return n <= 0 ? nothing : p[min(n, length(p))] + else + dims = similar(values, Base.promote_op(dim, I)) + for (c, v) in pairs(dims) + fill!(v, dim(c)) + end + cumulative_dim = cumsum(Base.permute!(parent(dims), p)) + k = findlast(<=(n), cumulative_dim) + return isnothing(k) ? k : p[k] + end +end + # findtruncated # ------------- # Generic fallback @@ -202,25 +219,25 @@ function MAK.findtruncated(values::SectorVector, ::NoTruncation) return SectorDict(c => Colon() for c in keys(values)) end +# TruncationByOrder strategy: +# - find the howmany'th value of the input sorted according to the strategy +# - discard everything that is ordered after that value + function MAK.findtruncated(values::SectorVector, strategy::TruncationByOrder) - values_sorted, perms = _sort_and_perm(values; strategy.by, strategy.rev) - inds = MAK.findtruncated_svd(values_sorted, truncrank(strategy.howmany)) - return SectorDict(c => perms[c][I] for (c, I) in inds) -end -function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) - I = keytype(values) - truncdim = SectorDict{I, Int}(c => length(d) for (c, d) in pairs(values)) - totaldim = sum(dim(c) * d for (c, d) in truncdim; init = 0) - while totaldim > strategy.howmany - next = _findnexttruncvalue(values, truncdim; strategy.by, strategy.rev) - isnothing(next) && break - _, cmin = next - truncdim[cmin] -= 1 - totaldim -= dim(cmin) - truncdim[cmin] == 0 && delete!(truncdim, cmin) + k = _findtruncvalue_order(values, strategy.howmany; strategy.by, strategy.rev) + + if isnothing(k) + # discard everything + return SectorDict{sectortype(values), UnitRange{Int}}() + else + val = strategy.by(values[k]) + strategy = trunctol(; atol = val, strategy.by, keep_below = !strategy.rev) + return MAK.findtruncated_svd(values, strategy) end - return SectorDict(c => Base.OneTo(d) for (c, d) in truncdim) end +# disambiguate +MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByOrder) = + MAK.findtruncated(values, strategy) function MAK.findtruncated(values::SectorVector, strategy::TruncationByFilter) return SectorDict(c => findall(strategy.filter, d) for (c, d) in pairs(values)) From 77f0ffa0879bd84e3cd304ec6c9e0b29a20b12de Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 11:21:17 +0100 Subject: [PATCH 2/3] Temporarily fix StridedViews version --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 934bdb6ed..3e0c7a02d 100644 --- a/Project.toml +++ b/Project.toml @@ -52,6 +52,7 @@ Random = "1" SafeTestsets = "0.1" ScopedValues = "1.3.0" Strided = "2" +StridedViews = "=0.4.1" TensorKitSectors = "0.3.3" TensorOperations = "5.1" Test = "1" @@ -75,6 +76,7 @@ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" From 5e5d87b4f113bc9a25dc55f61079ac8571091cde Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 20 Jan 2026 16:16:49 +0100 Subject: [PATCH 3/3] Revert "Temporarily fix StridedViews version" This reverts commit 77f0ffa0879bd84e3cd304ec6c9e0b29a20b12de. --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 3e0c7a02d..934bdb6ed 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,6 @@ Random = "1" SafeTestsets = "0.1" ScopedValues = "1.3.0" Strided = "2" -StridedViews = "=0.4.1" TensorKitSectors = "0.3.3" TensorOperations = "5.1" Test = "1" @@ -76,7 +75,6 @@ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"