@@ -22,13 +22,23 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
2222 return lu! (copy (A), normalize_pivot (pivot), thread; kwargs... )
2323end
2424
"""
    NotIPIV(len)

Stand-in pivot vector of length `len` for factorizations that do not pivot.

It stores only its length. Indexing at `i` returns `i`, so it reads like an identity
permutation, and taking a `view` over a unit range yields another `NotIPIV` whose
length is the length of that range.
"""
struct NotIPIV <: AbstractVector{BlasInt}
    len::Int
end

Base.size(A::NotIPIV) = (A.len,)

# Convert to `BlasInt` so the returned value always matches the declared element type,
# including builds where `BlasInt` and `Int` are different types.
Base.getindex(::NotIPIV, i::Int) = BlasInt(i)

# The result is re-based: it has `length(r)` entries and indexes from 1 again.
Base.view(::NotIPIV, r::AbstractUnitRange) = NotIPIV(length(r))

"""
    init_pivot(pivot, minmn)

Create the pivot storage for a factorization whose smaller dimension is `minmn`.

- `Val(false)`: return a `NotIPIV(minmn)`.
- `Val(true)`: return an uninitialized `Vector{BlasInt}` of length `minmn`.
- anything else: pass `pivot` through `normalize_pivot` first and dispatch on the
  result.

# Throws
- `ArgumentError` if `normalize_pivot(pivot)` is neither `Val(true)` nor `Val(false)`.
"""
init_pivot(::Val{false}, minmn) = NotIPIV(minmn)
init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
function init_pivot(pivot, minmn)
    # NOTE(review): `normalize_pivot` is defined elsewhere in this file; callers apply it
    # to `pivot` before dispatching on `Val`, so it presumably maps other pivot
    # specifications to `Val(true)` / `Val(false)` - confirm.
    npivot = normalize_pivot(pivot)
    # Guard against endless re-dispatch if normalization does not produce a handled Val.
    npivot isa Union{Val{true}, Val{false}} ||
        throw(ArgumentError("unsupported pivot specification: $(pivot)"))
    return init_pivot(npivot, minmn)
end
31+
32+
2533function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
2634 m, n = size (A)
2735 minmn = min (m, n)
28- F = if minmn < 10 # avx introduces small performance degradation
36+ # we want the type on both branches to match. When pivot = Val(false), we construct
37+ # a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
38+ F = if pivot === Val (true ) && minmn < 10 # avx introduces small performance degradation
2939 LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check = check)
3040 else
31- lu! (A, Vector {BlasInt} (undef , minmn), normalize_pivot (pivot), thread; check = check,
41+ lu! (A, init_pivot (pivot , minmn), normalize_pivot (pivot), thread; check = check,
3242 kwargs... )
3343 end
3444 return F
@@ -44,6 +54,8 @@ pick_threshold() = LoopVectorization.register_size() == 64 ? 48 : 40
"""
    recurse(A)

Return `true` when `A` is a `StridedArray`, and `false` for every other argument.
"""
function recurse(::StridedArray)
    return true
end
function recurse(::Any)
    return false
end
4656
# Prepare a pivot vector for the kernel call: generic vectors are wrapped in a
# `PtrArray`, while a `NotIPIV` is handed back unchanged.
function _ptrarray(ipiv)
    return PtrArray(ipiv)
end
function _ptrarray(ipiv::NotIPIV)
    return ipiv
end
4759function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
4860 pivot = Val (true ), thread = Val (true );
4961 check:: Bool = true ,
@@ -58,7 +70,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
5870 if T <: Union{Float32, Float64}
5971 GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
6072 m, n, mnmin,
61- PtrArray (ipiv), info, blocksize,
73+ _ptrarray (ipiv), info, blocksize,
6274 thread) end
6375 else
6476 info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
@@ -187,8 +199,10 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
187199 Pivot && apply_permutation! (P2, A21, thread)
188200
189201 info != previnfo && (info += n1)
190- @turbo warn_check_args= false for i in 1 : n2
191- P2[i] += n1
202+ if Pivot
203+ @turbo warn_check_args= false for i in 1 : n2
204+ P2[i] += n1
205+ end
192206 end
193207 return info
194208 end # inbounds
@@ -234,8 +248,8 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
234248 amax = absi
235249 end
236250 end
251+ ipiv[k] = kp
237252 end
238- ipiv[k] = kp
239253 if ! iszero (A[kp, k])
240254 if k != kp
241255 # Interchange
0 commit comments