Skip to content

Commit cc3dbe5

Browse files
authored
Merge pull request #770 from JuliaAI/dev
For a 0.20.3 release
2 parents db7cc01 + 2b8d578 commit cc3dbe5

File tree

5 files changed

+127
-73
lines changed

5 files changed

+127
-73
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.20.2"
4+
version = "0.20.3"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/resampling.jl

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -462,12 +462,19 @@ outlier detection model.
462462
When `evaluate`/`evaluate!` is called, a number of train/test pairs
463463
("folds") of row indices are generated, according to the options
464464
provided, which are discussed in the [`evaluate!`](@ref)
465-
doc-string. Rows correspond to observations. The train/test pairs
466-
generated are recorded in the `train_test_rows` field of the
465+
doc-string. Rows correspond to observations. The generated train/test
466+
pairs are recorded in the `train_test_rows` field of the
467467
`PerformanceEvaluation` struct, and the corresponding estimates,
468468
aggregated over all train/test pairs, are recorded in `measurement`, a
469469
vector with one entry for each measure (metric) recorded in `measure`.
470470
471+
When displayed, a `PerformanceEvaluation` object includes a value under
472+
the heading `1.96*SE`, derived from the standard error of the `per_fold`
473+
entries. This value is suitable for constructing a formal 95%
474+
confidence interval for the given `measurement`. Such intervals should
475+
be interpreted with caution. See, for example, Bates et al.
476+
[(2021)](https://arxiv.org/abs/2104.00673).
477+
471478
### Fields
472479
473480
These fields are part of the public API of the `PerformanceEvaluation`
@@ -503,8 +510,9 @@ struct.
503510
machine `mach` training in resampling - one machine per train/test
504511
pair.
505512
506-
- `train_test_rows`: a vector of tuples, each of the form `(train, test)`, where `train` and `test`
507-
are vectors of row (observation) indices for training and evaluation respectively.
513+
- `train_test_rows`: a vector of tuples, each of the form `(train, test)`,
514+
where `train` and `test` are vectors of row (observation) indices for
515+
training and evaluation respectively.
508516
"""
509517
struct PerformanceEvaluation{M,
510518
Measurement,
@@ -532,18 +540,35 @@ _short(v::Vector{<:Real}) = MLJBase.short_string(v)
532540
_short(v::Vector) = string("[", join(_short.(v), ", "), "]")
533541
_short(::Missing) = missing
534542

535-
function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
536-
_measure = map(e.measure) do m
537-
repr(MIME("text/plain"), m)
543+
function _standard_errors(e::PerformanceEvaluation)
544+
factor = 1.96 # For the 95% confidence interval.
545+
measure = e.measure
546+
nfolds = length(e.per_fold[1])
547+
nfolds == 1 && return [nothing]
548+
std_errors = map(e.per_fold) do per_fold
549+
factor * std(per_fold) / sqrt(nfolds - 1)
538550
end
551+
return std_errors
552+
end
553+
554+
function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
555+
_measure = [repr(MIME("text/plain"), m) for m in e.measure]
539556
_measurement = round3.(e.measurement)
540557
_per_fold = [round3.(v) for v in e.per_fold]
558+
_sterr = round3.(_standard_errors(e))
559+
560+
# Only show the standard error if the number of folds is higher than 1.
561+
show_sterr = any(!isnothing, _sterr)
562+
data = show_sterr ?
563+
hcat(_measure, e.operation, _measurement, _sterr, _per_fold) :
564+
hcat(_measure, e.operation, _measurement, _per_fold)
565+
header = show_sterr ?
566+
["measure", "operation", "measurement", "1.96*SE", "per_fold"] :
567+
["measure", "operation", "measurement", "per_fold"]
541568

542-
data = hcat(_measure, _measurement, e.operation, _per_fold)
543-
header = ["measure", "measurement", "operation", "per_fold"]
544569
println(io, "PerformanceEvaluation object "*
545570
"with these fields:")
546-
println(io, " measure, measurement, operation, per_fold,\n"*
571+
println(io, " measure, operation, measurement, per_fold,\n"*
547572
" per_observation, fitted_params_per_fold,\n"*
548573
" report_per_fold, train_test_rows")
549574
println(io, "Extract:")

test/preliminaries.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using MLJBase
2+
if !MLJBase.TESTING
3+
error(
4+
"To test MLJBase, the environment variable "*
5+
"`TEST_MLJBASE` must be set to `\"true\"`\n"*
6+
"You can do this in the REPL with `ENV[\"TEST_MLJBASE\"]=\"true\"`"
7+
)
8+
end
9+
10+
using Distributed
11+
# Thanks to https://stackoverflow.com/a/70895939/5056635 for the exeflags tip.
12+
addprocs(; exeflags="--project=$(Base.active_project())")
13+
14+
@info "nprocs() = $(nprocs())"
15+
@static if VERSION >= v"1.3.0-DEV.573"
16+
import .Threads
17+
@info "nthreads() = $(Threads.nthreads())"
18+
else
19+
@info "Running julia $(VERSION). Multithreading tests excluded. "
20+
end
21+
22+
@everywhere begin
23+
using MLJModelInterface
24+
using MLJBase
25+
using Test
26+
using CategoricalArrays
27+
using Logging
28+
using ComputationalResources
29+
using StableRNGs
30+
end
31+
32+
import TypedTables
33+
using Tables
34+
35+
function include_everywhere(filepath)
36+
include(filepath) # Load on Node 1 first, triggering any precompile
37+
if nprocs() > 1
38+
fullpath = joinpath(@__DIR__, filepath)
39+
@sync for p in workers()
40+
@async remotecall_wait(include, p, fullpath)
41+
end
42+
end
43+
end
44+
45+
include("test_utilities.jl")
46+
47+
# load Models module containing model implementations for testing:
48+
print("Loading some models for testing...")
49+
include_everywhere("_models/models.jl")
50+
print("\r \r")
51+
52+
# enable conditional testing of modules by providing test_args
53+
# e.g. `Pkg.test("MLJBase", test_args=["misc"])`
54+
RUN_ALL_TESTS = isempty(ARGS)
55+
macro conditional_testset(name, expr)
56+
name = string(name)
57+
esc(quote
58+
if RUN_ALL_TESTS || $name in ARGS
59+
@testset $name $expr
60+
end
61+
end)
62+
end
63+
64+
# To avoid printing `@conditional_testset (macro with 1 method)`
65+
# when loading this file via `include("test/preliminaries.jl")`.
66+
nothing

test/resampling.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,12 +775,21 @@ end
775775
@test T <: PerformanceEvaluation
776776

777777
show_text = sprint(show, MIME"text/plain"(), evaluations)
778+
cols = ["measure", "operation", "measurement", "1.96*SE", "per_fold"]
779+
@test all(contains.(show_text, cols))
780+
print(show_text)
778781
docstring_text = string(@doc(PerformanceEvaluation))
779782
for fieldname in fieldnames(PerformanceEvaluation)
780783
@test contains(show_text, string(fieldname))
781784
# string(text::Markdown.MD) converts `-` list items to `*`.
782785
@test contains(docstring_text, " * `$fieldname`")
783786
end
787+
788+
measures = [LogLoss(), Accuracy()]
789+
evaluations = evaluate(clf, X, y; measures, resampling=Holdout())
790+
show_text = sprint(show, MIME"text/plain"(), evaluations)
791+
print(show_text)
792+
@test !contains(show_text, "std")
784793
end
785794

786795
#end

test/runtests.jl

Lines changed: 16 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,19 @@
1-
using Distributed
2-
addprocs()
3-
4-
5-
using MLJBase
6-
if !MLJBase.TESTING
7-
error(
8-
"To test MLJBase, the environment variable "*
9-
"`TEST_MLJBASE` must be set to `\"true\"`\n"*
10-
"You can do this in the REPL with `ENV[\"TEST_MLJBASE\"]=\"true\"`"
11-
)
12-
end
13-
14-
@info "nprocs() = $(nprocs())"
15-
@static if VERSION >= v"1.3.0-DEV.573"
16-
import .Threads
17-
@info "nthreads() = $(Threads.nthreads())"
18-
else
19-
@info "Running julia $(VERSION). Multithreading tests excluded. "
20-
end
21-
22-
@everywhere begin
23-
using MLJModelInterface
24-
using MLJBase
25-
using Test
26-
using CategoricalArrays
27-
using Logging
28-
using ComputationalResources
29-
using StableRNGs
30-
end
31-
32-
import TypedTables
33-
using Tables
34-
35-
function include_everywhere(filepath)
36-
include(filepath) # Load on Node 1 first, triggering any precompile
37-
if nprocs() > 1
38-
fullpath = joinpath(@__DIR__, filepath)
39-
@sync for p in workers()
40-
@async remotecall_wait(include, p, fullpath)
41-
end
42-
end
43-
end
44-
45-
include("test_utilities.jl")
46-
47-
# load Models module containing model implementations for testing:
48-
print("Loading some models for testing...")
49-
include_everywhere("_models/models.jl")
50-
print("\r \r")
51-
52-
# enable conditional testing of modules by providing test_args
53-
# e.g. `Pkg.test("MLJBase", test_args=["misc"])`
54-
RUN_ALL_TESTS = isempty(ARGS)
55-
macro conditional_testset(name, expr)
56-
name = string(name)
57-
esc(quote
58-
if RUN_ALL_TESTS || $name in ARGS
59-
@testset $name $expr
60-
end
61-
end)
62-
end
1+
# To speed up the development workflow, use `TestEnv`.
2+
# For example:
3+
# ```
4+
# $ julia --project
5+
#
6+
# julia> ENV["TEST_MLJBASE"] = "true"
7+
#
8+
# julia> using TestEnv; TestEnv.activate()
9+
#
10+
# julia> include("test/preliminaries.jl")
11+
# [...]
12+
#
13+
# julia> include("test/resampling.jl")
14+
# [...]
15+
# ```
16+
include("preliminaries.jl")
6317

6418
@conditional_testset "misc" begin
6519
@test include("utilities.jl")

0 commit comments

Comments
 (0)