Skip to content

Commit 5dca0d6

Browse files
authored
Static arrays (#99)
* Initial changes for StaticArrays [ci skip] * Changes to reduce allocation * Refactoring to reduce allocations * Slight changes in results * Relax tolerance on comparison
1 parent a560d35 commit 5dca0d6

File tree

13 files changed

+136
-360
lines changed

13 files changed

+136
-360
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ GLM 0.7
88
NamedArrays
99
NLopt 0.3
1010
Showoff
11+
StaticArrays 0.6
1112
StatsBase 0.11
1213
StatsFuns 0.3

src/MixedModels.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ __precompile__()
33
module MixedModels
44

55
using ArgCheck, BlockArrays, DataArrays, DataFrames
6-
using Distributions, GLM, NLopt, Showoff, StatsBase
6+
using Distributions, GLM, NLopt, Showoff, StaticArrays, StatsBase
77
using StatsFuns: log2π
88
using NamedArrays: NamedArray, setnames!
99
using Base.LinAlg: BlasFloat, BlasReal, HermOrSym, PosDefException, checksquare, copytri!

src/linalg.jl

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,45 +25,12 @@ function αβA_mul_Bc!(α::T, A::SparseMatrixCSC{T}, B::SparseMatrixCSC{T},
2525
C
2626
end
2727

28-
#= This is not tested. See if it is really needed.
29-
function αβA_mul_Bc!(α::T, A::SparseMatrixCSC{T}, B::SparseMatrixCSC{T},
30-
β::T, C::SparseMatrixCSC{T}) where T <: Number
31-
@argcheck B.m == C.n && A.m == C.m && A.n == B.n DimensionMismatch
32-
anz = nonzeros(A)
33-
arv = rowvals(A)
34-
bnz = nonzeros(B)
35-
brv = rowvals(B)
36-
cnz = nonzeros(C)
37-
crv = rowvals(C)
38-
if β ≠ one(T)
39-
iszero(β) ? fill!(cnz, β) : scale!(cnz, β)
40-
end
41-
for j = 1:A.n
42-
for ib in nzrange(B, j)
43-
αbnz = α * bnz[ib]
44-
jj = brv[ib]
45-
for ia in nzrange(A, j)
46-
crng = nzrange(C, jj)
47-
ind = findfirst(crv[crng], arv[ia])
48-
if iszero(ind)
49-
throw(ArgumentError("A*B' has nonzero positions not in C"))
50-
end
51-
cnz[crng[ind]] += anz[ia] * αbnz
52-
end
53-
end
54-
end
55-
C
56-
end
57-
=#
58-
5928
function αβA_mul_Bc!::T, A::StridedVecOrMat{T}, B::SparseMatrixCSC{T}, β::T,
6029
C::StridedVecOrMat{T}) where T
6130
m, n = size(A)
6231
p, q = size(B)
6332
r, s = size(C)
64-
if r m || s p || n q
65-
throw(DimensionMismatch("size(C,1) ≠ size(A,1) or size(C,2) ≠ size(B,1) or size(A,2) ≠ size(B,2)"))
66-
end
33+
@argcheck(r == m && s == p && n == q, DimensionMismatch)
6734
if β one(T)
6835
iszero(β) ? fill!(C, β) : scale!(C, β)
6936
end

src/linalg/cholUnblocked.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,9 @@ function cholUnblocked!(A::StridedMatrix{T}, ::Type{Val{:L}}) where T<:BlasFloat
2929
A
3030
end
3131

32-
cholUnblocked!(D::UniformBlockDiagonal, ::Type{Val{:L}}) = cholUnblocked!.(D.facevec, Val{:L})
32+
function cholUnblocked!(D::UniformBlockDiagonal, ::Type{Val{:L}})
33+
for f in D.facevec
34+
cholUnblocked!(f, Val{:L})
35+
end
36+
D
37+
end

src/linalg/logdet.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ end
3232
Return the value of `log(det(Λ'Z'ZΛ + I))` evaluated in place.
3333
"""
3434
function logdet(m::LinearMixedModel{T}) where {T}
35-
s = zero(T)
35+
s = log(one(T))
36+
Ldat = m.L.data
3637
for (i, trm) in enumerate(m.trms)
3738
if isa(trm, AbstractFactorReTerm)
38-
s += T(LD(m.L.data[Block(i, i)]))
39+
s += LD(m.L.data[Block(i, i)])
3940
end
4041
end
4142
2s

src/linalg/rankUpdate.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,22 @@ function rankUpdate!(α::T, a::StridedVector{T},
1515
BLAS.syr!(A.uplo, α, a, A.data)
1616
A
1717
end
18-
rankUpdate!(a::StridedVector{T}, A::HermOrSym{T,S}) where {T<:BlasReal,S<:StridedMatrix} = rankUpdate!(one(T), a, A)
18+
19+
rankUpdate!(a::StridedVector{T}, A::HermOrSym{T,S}) where {T<:BlasReal,S<:StridedMatrix} =
20+
rankUpdate!(one(T), a, A)
1921

2022
rankUpdate!::T, A::StridedMatrix{T}, β::T,
21-
C::HermOrSym{T,S}) where {T<:BlasReal,S<:StridedMatrix} = BLAS.syrk!(C.uplo, 'N', α, A, β, C.data)
23+
C::HermOrSym{T,S}) where {T<:BlasReal,S<:StridedMatrix} =
24+
BLAS.syrk!(C.uplo, 'N', α, A, β, C.data)
25+
2226
rankUpdate!::T, A::StridedMatrix{T}, C::HermOrSym{T,S}) where {T<:Real,S<:StridedMatrix} =
2327
rankUpdate!(α, A, one(T), C)
28+
2429
rankUpdate!(A::StridedMatrix{T}, C::HermOrSym{T,S}) where {T<:Real,S<:StridedMatrix} =
2530
rankUpdate!(one(T), A, one(T), C)
2631

27-
function rankUpdate!::T, A::SparseMatrixCSC{T,I},
28-
β::T, C::HermOrSym{T,S}) where {T,I,S<:StridedMatrix{T}}
32+
function rankUpdate!::T, A::SparseMatrixCSC{T},
33+
β::T, C::HermOrSym{T,S}) where {T,S<:StridedMatrix{T}}
2934
m, n = size(A)
3035
@argcheck m == size(C, 2) && C.uplo == 'L' DimensionMismatch
3136
Cd = C.data
@@ -73,6 +78,7 @@ function rankUpdate!(α::T, A::SparseMatrixCSC{T},
7378
C::HermOrSym{T,UniformBlockDiagonal{T}}) where T<:Number
7479
m, n, k = size(C.data.data)
7580
@argcheck m == n && size(A, 1) == m * k DimensionMismatch
81+
# Another expensive evaluation in terms of storage allocation
7682
aat = α * (A * A')
7783
nz = nonzeros(aat)
7884
rv = rowvals(aat)

0 commit comments

Comments
 (0)