Skip to content

Commit 510dcc3

Browse files
authored
use FMA where possible in fitting (#740)
* use FMA where possible in fitting * use muladd everywhere * NEWS update * format
1 parent c1f9ca0 commit 510dcc3

File tree

6 files changed

+30
-23
lines changed

6 files changed

+30
-23
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
MixedModels v4.22.5 Release Notes
22
==============================
3+
* Use `muladd` where possible to enable fused multiply-add (FMA) on architectures with hardware support. FMA generally improves computational speed and gives more accurate rounding. [#740]
34
* Replace broadcasted lambda with explicit loop and use `one`. This may result in a small performance improvement. [#738]
45

56
MixedModels v4.22.4 Release Notes
@@ -500,5 +501,6 @@ Package dependencies
500501
[#717]: https://github.com/JuliaStats/MixedModels.jl/issues/717
501502
[#733]: https://github.com/JuliaStats/MixedModels.jl/issues/733
502503
[#738]: https://github.com/JuliaStats/MixedModels.jl/issues/738
504+
[#740]: https://github.com/JuliaStats/MixedModels.jl/issues/740
503505
[#744]: https://github.com/JuliaStats/MixedModels.jl/issues/744
504506
[#748]: https://github.com/JuliaStats/MixedModels.jl/issues/748

src/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function LinearAlgebra.mul!(
1818
αbnz = α * bnz[ib]
1919
jj = brv[ib]
2020
for ia in nzrange(A, j)
21-
C[arv[ia], jj] += anz[ia] * αbnz
21+
C[arv[ia], jj] = muladd(anz[ia], αbnz, C[arv[ia], jj])
2222
end
2323
end
2424
end

src/linalg/rankUpdate.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function MixedModels.rankUpdate!(
2222
Cdiag = C.data.diag
2323
Adiag = A.diag
2424
@inbounds for idx in eachindex(Cdiag, Adiag)
25-
Cdiag[idx] = β * Cdiag[idx] + α * abs2(Adiag[idx])
25+
Cdiag[idx] = muladd(β, Cdiag[idx], α * abs2(Adiag[idx]))
2626
end
2727
return C
2828
end
@@ -52,7 +52,7 @@ function _columndot(rv, nz, rngi, rngj)
5252
while i ≤ ni && j ≤ nj
5353
@inbounds ri, rj = rv[rngi[i]], rv[rngj[j]]
5454
if ri == rj
55-
@inbounds accum += nz[rngi[i]] * nz[rngj[j]]
55+
@inbounds accum = muladd(nz[rngi[i]], nz[rngj[j]], accum)
5656
i += 1
5757
j += 1
5858
elseif ri < rj
@@ -80,17 +80,17 @@ function rankUpdate!(C::HermOrSym{T,S}, A::SparseMatrixCSC{T}, α, β) where {T,
8080
rvj = rv[j]
8181
for i in k:lenrngjj
8282
kk = rangejj[i]
83-
Cd[rv[kk], rvj] += nz[kk] * anzj
83+
Cd[rv[kk], rvj] = muladd(nz[kk], anzj, Cd[rv[kk], rvj])
8484
end
8585
end
8686
end
8787
else
8888
@inbounds for j in axes(C, 2)
8989
rngj = nzrange(A, j)
9090
for i in 1:(j - 1)
91-
Cd[i, j] += α * _columndot(rv, nz, nzrange(A, i), rngj)
91+
Cd[i, j] = muladd(α, _columndot(rv, nz, nzrange(A, i), rngj), Cd[i, j])
9292
end
93-
Cd[j, j] += α * sum(i -> abs2(nz[i]), rngj)
93+
Cd[j, j] = muladd(α, sum(i -> abs2(nz[i]), rngj), Cd[j, j])
9494
end
9595
end
9696
return C
@@ -109,7 +109,7 @@ function rankUpdate!(
109109
isone(β) || rmul!(Cdiag, β)
110110

111111
@inbounds for i in eachindex(Cdiag)
112-
Cdiag[i] += α * sum(abs2, view(A, i, :))
112+
Cdiag[i] = muladd(α, sum(abs2, view(A, i, :)), Cdiag[i])
113113
end
114114

115115
return C
@@ -132,9 +132,9 @@ function rankUpdate!(
132132
AtAij = 0
133133
for idx in axes(A, 2)
134134
# because the second multiplicand is from A', swap index order
135-
AtAij += A[iind, idx] * A[jind, idx]
135+
AtAij = muladd(A[iind, idx], A[jind, idx], AtAij)
136136
end
137-
Cdat[i, j, k] += α * AtAij
137+
Cdat[i, j, k] = muladd(α, AtAij, Cdat[i, j, k])
138138
end
139139
end
140140

@@ -152,7 +152,7 @@ function rankUpdate!(
152152
throw(ArgumentError("Columns of A must have exactly 1 nonzero"))
153153

154154
for (r, nz) in zip(rowvals(A), nonzeros(A))
155-
dd[r] += α * abs2(nz)
155+
dd[r] = muladd(α, abs2(nz), dd[r])
156156
end
157157

158158
return C

src/linearmixedmodel.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,9 @@ function StatsAPI.leverage(m::LinearMixedModel{T}) where {T}
767767
z = trm.z
768768
stride = size(z, 1)
769769
mul!(
770-
view(rhs2, (rhsoffset + (trm.refs[i] - 1) * stride) .+ Base.OneTo(stride)),
770+
view(
771+
rhs2, muladd((trm.refs[i] - 1), stride, rhsoffset) .+ Base.OneTo(stride)
772+
),
771773
adjoint(trm.λ),
772774
view(z, :, i),
773775
)
@@ -816,7 +818,7 @@ function objective(m::LinearMixedModel{T}) where {T}
816818
val = if isnothing(σ)
817819
logdet(m) + denomdf * (one(T) + log2π + log(pwrss(m) / denomdf))
818820
else
819-
denomdf * (log2π + 2 * log(σ)) + logdet(m) + pwrss(m) / σ^2
821+
muladd(denomdf, muladd(2, log(σ), log2π), (logdet(m) + pwrss(m) / σ^2))
820822
end
821823
return isempty(wts) ? val : val - T(2.0) * sum(log, wts)
822824
end

src/remat.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ function LinearAlgebra.mul!(
284284
@inbounds for (j, rrj) in enumerate(B.refs)
285285
αzj = α * zz[j]
286286
for i in 1:p
287-
C[i, rrj] += αzj * Awt[j, i]
287+
C[i, rrj] = muladd(αzj, Awt[j, i], C[i, rrj])
288288
end
289289
end
290290
return C
@@ -310,7 +310,7 @@ function LinearAlgebra.mul!(
310310
aki = α * Awt[k, i]
311311
kk = Int(rr[k])
312312
for ii in 1:S
313-
scr[ii, kk] += aki * Bwt[ii, k]
313+
scr[ii, kk] = muladd(aki, Bwt[ii, k], scr[ii, kk])
314314
end
315315
end
316316
for j in 1:q
@@ -340,7 +340,7 @@ function LinearAlgebra.mul!(
340340
coljlast = Int(C.colptr[j + 1] - 1)
341341
K = searchsortedfirst(rv, i, Int(C.colptr[j]), coljlast, Base.Order.Forward)
342342
if K ≤ coljlast && rv[K] == i
343-
nz[K] += Az[k] * Bz[k]
343+
nz[K] = muladd(Az[k], Bz[k], nz[K])
344344
else
345345
throw(ArgumentError("C does not have the nonzero pattern of A'B"))
346346
end
@@ -361,7 +361,7 @@ function LinearAlgebra.mul!(
361361
@inbounds for i in 1:S
362362
zij = Awtz[i, j]
363363
for k in 1:S
364-
Cd[k, i, r] += zij * Awtz[k, j]
364+
Cd[k, i, r] = muladd(zij, Awtz[k, j], Cd[k, i, r])
365365
end
366366
end
367367
end
@@ -397,7 +397,7 @@ function LinearAlgebra.mul!(
397397
jjo = jj + joffset
398398
Bzijj = Bz[jj, i]
399399
for ii in 1:S
400-
C[ii + ioffset, jjo] += Az[ii, i] * Bzijj
400+
C[ii + ioffset, jjo] = muladd(Az[ii, i], Bzijj, C[ii + ioffset, jjo])
401401
end
402402
end
403403
end
@@ -416,7 +416,8 @@ function LinearAlgebra.mul!(
416416
isone(beta) || rmul!(y, beta)
417417
z = A.z
418418
@inbounds for (i, r) in enumerate(A.refs)
419-
y[i] += alpha * b[r] * z[i]
419+
# must be muladd and not fma because of potential missings
420+
y[i] = muladd(alpha * b[r], z[i], y[i])
420421
end
421422
return y
422423
end
@@ -446,7 +447,8 @@ function LinearAlgebra.mul!(
446447
@inbounds for (i, ii) in enumerate(A.refs)
447448
offset = (ii - 1) * k
448449
for j in 1:k
449-
y[i] += alpha * Z[j, i] * b[offset + j]
450+
# must be muladd and not fma because of potential missings
451+
y[i] = muladd(alpha * Z[j, i], b[offset + j], y[i])
450452
end
451453
end
452454
return y
@@ -466,7 +468,8 @@ function LinearAlgebra.mul!(
466468
isone(beta) || rmul!(y, beta)
467469
@inbounds for (i, ii) in enumerate(refarray(A))
468470
for j in 1:k
469-
y[i] += alpha * Z[j, i] * B[j, ii]
471+
# must be muladd and not fma because of potential missings
472+
y[i] = muladd(alpha * Z[j, i], B[j, ii], y[i])
470473
end
471474
end
472475
return y
@@ -566,7 +569,7 @@ function copyscaleinflate!(Ljj::Diagonal{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1})
566569
Ldiag, Adiag = Ljj.diag, Ajj.diag
567570
lambsq = abs2(only(Λj.λ.data))
568571
@inbounds for i in eachindex(Ldiag, Adiag)
569-
Ldiag[i] = lambsq * Adiag[i] + one(T)
572+
Ldiag[i] = muladd(lambsq, Adiag[i], one(T))
570573
end
571574
return Ljj
572575
end
@@ -575,7 +578,7 @@ function copyscaleinflate!(Ljj::Matrix{T}, Ajj::Diagonal{T}, Λj::ReMat{T,1}) wh
575578
fill!(Ljj, zero(T))
576579
lambsq = abs2(only(Λj.λ.data))
577580
@inbounds for (i, a) in enumerate(Ajj.diag)
578-
Ljj[i, i] = lambsq * a + one(T)
581+
Ljj[i, i] = muladd(lambsq, a, one(T))
579582
end
580583
return Ljj
581584
end

test/pls.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ end
139139

140140
vc = fm1.vcov
141141
@test isa(vc, Matrix{Float64})
142-
@test only(vc) ≈ 375.7167775 rtol=1.e-6
142+
@test only(vc) ≈ 375.7167775 rtol=1.e-3
143143
# since we're caching the fits, we should get it back to being correctly fitted
144144
# we also take this opportunity to test fitlog
145145
@testset "fitlog" begin

0 commit comments

Comments
 (0)