@@ -22,13 +22,37 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
2222 return lu! (copy (A), normalize_pivot (pivot), thread; kwargs... )
2323end
2424
"""
    NotIPIV(len)

Stand-in pivot vector for factorizations performed without pivoting. It is an
`AbstractVector{BlasInt}` of length `len` whose `i`-th entry is `i` itself, so
no pivot storage is needed.
"""
struct NotIPIV <: AbstractVector{BlasInt}
    len::Int
end
Base.size(A::NotIPIV) = (A.len,)
# Convert so the returned element always agrees with the declared element type;
# `BlasInt` is not guaranteed to be the same type as `Int`.
Base.getindex(::NotIPIV, i::Int) = convert(BlasInt, i)
# A view over a unit range is again a `NotIPIV`, sized to the range.
Base.view(::NotIPIV, r::AbstractUnitRange) = NotIPIV(length(r))

# Pivot storage of length `minmn`: a `NotIPIV` when pivoting is disabled
# (`Val(false)`), an uninitialized `Vector{BlasInt}` when it is enabled (`Val(true)`).
init_pivot(::Val{false}, minmn) = NotIPIV(minmn)
init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
33+
# Extend `LinearAlgebra._ipiv_cols!` only when this LinearAlgebra actually
# defines that function.
if isdefined(LinearAlgebra, :_ipiv_cols!)
    # For an `LU` whose pivot vector is a `NotIPIV`, `B` is handed back unchanged.
    LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
                              B::StridedVecOrMat) = B
end
# Extend `LinearAlgebra._ipiv_rows!` only when this LinearAlgebra actually
# defines that function.
if isdefined(LinearAlgebra, :_ipiv_rows!)
    # For an `LU` whose pivot vector is a `NotIPIV`, `B` is handed back unchanged.
    LinearAlgebra._ipiv_rows!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
                              B::StridedVecOrMat) = B
end
46+
2547function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
2648 m, n = size (A)
2749 minmn = min (m, n)
28- F = if minmn < 10 # avx introduces small performance degradation
50+ # we want the type on both branches to match. When pivot = Val(false), we construct
51+ # a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
52+ F = if pivot === Val (true ) && minmn < 10 # avx introduces small performance degradation
2953 LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check = check)
3054 else
31- lu! (A, Vector {BlasInt} (undef , minmn), normalize_pivot (pivot), thread; check = check,
55+ lu! (A, init_pivot (pivot , minmn), normalize_pivot (pivot), thread; check = check,
3256 kwargs... )
3357 end
3458 return F
@@ -44,6 +68,8 @@ pick_threshold() = LoopVectorization.register_size() == 64 ? 48 : 40
# Predicate resolved by dispatch on the argument type: any `StridedArray`
# answers `true`, every other argument answers `false`.
function recurse(::StridedArray)
    return true
end

function recurse(_)
    return false
end
4670
# Wrap the pivot vector in a `PtrArray`, except for a `NotIPIV`, which is
# returned as is.
function _ptrarray(ipiv)
    return PtrArray(ipiv)
end

function _ptrarray(ipiv::NotIPIV)
    return ipiv
end
4773function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
4874 pivot = Val (true ), thread = Val (true );
4975 check:: Bool = true ,
@@ -58,7 +84,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
5884 if T <: Union{Float32, Float64}
5985 GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
6086 m, n, mnmin,
61- PtrArray (ipiv), info, blocksize,
87+ _ptrarray (ipiv), info, blocksize,
6288 thread) end
6389 else
6490 info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
90116 # [AL AR]
91117 AL = @view A[:, 1 : m]
92118 AR = @view A[:, (m + 1 ): n]
93- apply_permutation! (ipiv, AR, Val ( Thread))
94- ldiv! (_unit_lower_triangular (AL), AR, Val ( Thread))
119+ apply_permutation! (ipiv, AR, Val { Thread} ( ))
120+ ldiv! (_unit_lower_triangular (AL), AR, Val { Thread} ( ))
95121 end
96122 info
97123end
@@ -187,8 +213,10 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
187213 Pivot && apply_permutation! (P2, A21, thread)
188214
189215 info != previnfo && (info += n1)
190- @turbo warn_check_args= false for i in 1 : n2
191- P2[i] += n1
216+ if Pivot
217+ @turbo warn_check_args= false for i in 1 : n2
218+ P2[i] += n1
219+ end
192220 end
193221 return info
194222 end # inbounds
@@ -234,8 +262,8 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
234262 amax = absi
235263 end
236264 end
265+ ipiv[k] = kp
237266 end
238- ipiv[k] = kp
239267 if ! iszero (A[kp, k])
240268 if k != kp
241269 # Interchange
0 commit comments