Skip to content

Commit 85f3dc0

Browse files
committed
Add REML named argument to fit!
1 parent 710fa69 commit 85f3dc0

File tree

3 files changed

+44
-25
lines changed

3 files changed

+44
-25
lines changed

src/linalg/logdet.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,20 @@ function LD(d::DenseMatrix{T}) where {T}
2727
end
2828

2929
"""
30-
logdet(m::LinearMixedModel, REML::Bool=false)
30+
logdet(m::LinearMixedModel)
3131
3232
Return the value of `log(det(Λ'Z'ZΛ + I)) + log(det(LX*LX'))` evaluated in place.
3333
3434
Here LX is the diagonal term corresponding to the fixed-effects in the blocked
3535
lower Cholesky factor.
3636
"""
37-
function LinearAlgebra.logdet(m::LinearMixedModel{T}, REML::Bool=false) where {T}
37+
function LinearAlgebra.logdet(m::LinearMixedModel{T}) where {T}
3838
s = log(one(T))
3939
Ldat = m.L.data
4040
@inbounds for (i, trm) in enumerate(m.trms)
4141
isa(trm, AbstractFactorReTerm) && (s += LD(Ldat[Block(i, i)]))
4242
end
43-
if REML
43+
if m.optsum.REML
4444
feindex = length(m.trms) - 1
4545
fetrm = m.trms[feindex]
4646
if isa(fetrm, MatrixTerm)

src/pls.jl

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,24 @@ end
155155
StatsBase.coef(m::MixedModel) = fixef(m, false)
156156

157157
"""
158-
fit!(m::LinearMixedModel[, verbose::Bool=false])
158+
fit!(m::LinearMixedModel; verbose=false, REML=false)
159159
160160
Optimize the objective of a `LinearMixedModel`. When `verbose` is `true` the values of the
161-
objective and the parameters are printed on stdout at each function evaluation.
161+
objective and the parameters are printed on stdout at each function evaluation. The
162+
objective is negative twice the log-likelihood when `REML` is `false` (the default) or
163+
the REML criterion.
162164
"""
163-
function StatsBase.fit!(m::LinearMixedModel{T}; verbose::Bool=false, REML::Bool=false) where {T}
165+
function StatsBase.fit!(m::LinearMixedModel{T}; verbose=false, REML=nothing) where {T}
164166
optsum = m.optsum
165167
opt = Opt(optsum)
166168
feval = 0
169+
if isa(REML, Bool)
170+
optsum.REML = REML
171+
end
167172
function obj(x, g)
168173
isempty(g) || error("gradient not defined")
169174
feval += 1
170-
val = objective(updateL!(setθ!(m, x)), REML)
175+
val = objective(updateL!(setθ!(m, x)))
171176
feval == 1 && (optsum.finitial = val)
172177
verbose && println("f_", feval, ": ", round(val, digits=5), " ", x)
173178
val
@@ -229,13 +234,14 @@ lowerbd(m::LinearMixedModel) = lowerbd(m.trms)
229234
StatsBase.model_response(m::LinearMixedModel) = vec(m.trms[end].x)
230235

231236
"""
232-
objective(m::LinearMixedModel, REML::Bool=false)
237+
objective(m::LinearMixedModel)
233238
234-
Return negative twice the log-likelihood of model `m` or the REML criterion
239+
Return negative twice the log-likelihood of model `m` or the REML criterion,
240+
according to the value of `m.optsum.REML`
235241
"""
236-
function objective(m::LinearMixedModel, REML::Bool=false)
242+
function objective(m::LinearMixedModel)
237243
wts = m.sqrtwts
238-
logdet(m, REML) + dof_residual(m, REML)*(1 + log2π + log(varest(m, REML))) -
244+
logdet(m) + varest_denom(m)*(1 + log2π + log(varest(m))) -
239245
(isempty(wts) ? 0 : 2sum(log, wts))
240246
end
241247

@@ -269,7 +275,7 @@ end
269275

270276
StatsBase.dof(m::LinearMixedModel) = size(m)[2] + sum(nθ, m.trms) + 1
271277

272-
function StatsBase.dof_residual(m::LinearMixedModel, REML::Bool=false)
278+
function StatsBase.dof_residual(m::LinearMixedModel)
273279
(n, p, q, k) = size(m)
274280
n - REML * p
275281
end
@@ -290,7 +296,7 @@ end
290296
291297
Return the estimate of σ, the standard deviation of the per-observation noise.
292298
"""
293-
sdest(m::LinearMixedModel) = sqrtpwrss(m) / nobs(m)
299+
sdest(m::LinearMixedModel) = varest(m)
294300

295301
"""
296302
setθ!{T}(m::LinearMixedModel{T}, v::Vector{T})
@@ -321,12 +327,17 @@ This value is the contents of the `1 × 1` bottom right block of `m.L`
321327
"""
322328
sqrtpwrss(m::LinearMixedModel) = @views m.L.data.blocks[end, end][1]
323329

330+
function varest_denom(m::LinearMixedModel)
331+
(n, p, q, k) = size(m)
332+
n - m.optsum.REML * p
333+
end
334+
324335
"""
325-
varest(m::LinearMixedModel, REML::Bool=false)
336+
varest(m::LinearMixedModel)
326337
327338
Returns the estimate of σ², the variance of the conditional distribution of Y given B.
328339
"""
329-
varest(m::LinearMixedModel, REML::Bool=false) = pwrss(m) / dof_residual(m, REML)
340+
varest(m::LinearMixedModel) = pwrss(m) / varest_denom(m)
330341

331342
"""
332343
pwrss(m::LinearMixedModel)
@@ -392,19 +403,25 @@ function Base.show(io::IO, m::LinearMixedModel)
392403
return nothing
393404
end
394405
n, p, q, k = size(m)
395-
println(io, "Linear mixed model fit by maximum likelihood")
406+
REML = m.optsum.REML
407+
println(io, "Linear mixed model fit by ", REML ? "REML" : "maximum likelihood")
396408
println(io, " ", m.formula)
397409
oo = objective(m)
398-
nums = showoff([-oo/ 2, oo, aic(m), bic(m)])
399-
fieldwd = max(maximum(textwidth.(nums)) + 1, 11)
400-
for label in [" logLik", "-2 logLik", "AIC", "BIC"]
401-
print(io, rpad(lpad(label, (fieldwd + textwidth(label)) >> 1), fieldwd))
410+
if REML
411+
println(io, " REML criterion at convergence: ", oo)
412+
else
413+
nums = showoff([-oo/ 2, oo, aic(m), bic(m)])
414+
fieldwd = max(maximum(textwidth.(nums)) + 1, 11)
415+
for label in [" logLik", "-2 logLik", "AIC", "BIC"]
416+
print(io, rpad(lpad(label, (fieldwd + textwidth(label)) >> 1), fieldwd))
417+
end
418+
println(io)
419+
for num in nums
420+
print(io, lpad(num, fieldwd))
421+
end
422+
println(io)
402423
end
403424
println(io)
404-
for num in nums
405-
print(io, lpad(num, fieldwd))
406-
end
407-
println(io); println(io)
408425

409426
show(io,VarCorr(m))
410427

src/types.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ Summary of an `NLopt` optimization
159159
* `optimizer`: the name of the optimizer used, as a `Symbol`
160160
* `returnvalue`: the return value, as a `Symbol`
161161
* `nAGQ`: number of adaptive Gauss-Hermite quadrature points in deviance evaluation for GLMMs
162+
* `REML`: use the REML criterion for LMM fits
162163
163164
The latter field doesn't really belong here but it has to be in a mutable struct in case it is changed.
164165
"""
@@ -178,12 +179,13 @@ mutable struct OptSummary{T <: AbstractFloat}
178179
optimizer::Symbol
179180
returnvalue::Symbol
180181
nAGQ::Integer # doesn't really belong here but I needed some place to store it
182+
REML::Bool # similarly, just needed a place to store this information
181183
end
182184
function OptSummary(initial::Vector{T}, lowerbd::Vector{T},
183185
optimizer::Symbol; ftol_rel::T=zero(T), ftol_abs::T=zero(T), xtol_rel::T=zero(T),
184186
initial_step::Vector{T}=T[]) where T <: AbstractFloat
185187
OptSummary(initial, lowerbd, T(Inf), ftol_rel, ftol_abs, xtol_rel, zero(initial),
186-
initial_step, -1, copy(initial), T(Inf), -1, optimizer, :FAILURE, 1)
188+
initial_step, -1, copy(initial), T(Inf), -1, optimizer, :FAILURE, 1, false)
187189
end
188190

189191
function Base.show(io::IO, s::OptSummary)

0 commit comments

Comments
 (0)