@@ -462,12 +462,19 @@ outlier detection model.
  When `evaluate`/`evaluate!` is called, a number of train/test pairs
  ("folds") of row indices are generated, according to the options
  provided, which are discussed in the [`evaluate!`](@ref)
- doc-string. Rows correspond to observations. The train/test pairs
- generated are recorded in the `train_test_rows` field of the
+ doc-string. Rows correspond to observations. The generated train/test
+ pairs are recorded in the `train_test_rows` field of the
  `PerformanceEvaluation` struct, and the corresponding estimates,
  aggregated over all train/test pairs, are recorded in `measurement`, a
  vector with one entry for each measure (metric) recorded in `measure`.

+ When displayed, a `PerformanceEvaluation` object includes a value under
+ the heading `1.96*SE`, derived from the standard error of the `per_fold`
+ entries. This value is suitable for constructing a formal 95%
+ confidence interval for the given `measurement`. Such intervals should
+ be interpreted with caution. See, for example, Bates et al.
+ [(2021)](https://arxiv.org/abs/2104.00673).
+
  ### Fields

  These fields are part of the public API of the `PerformanceEvaluation`
@@ -503,8 +510,9 @@ struct.
    machine `mach` training in resampling - one machine per train/test
    pair.

- - `train_test_rows`: a vector of tuples, each of the form `(train, test)`, where `train` and `test`
-   are vectors of row (observation) indices for training and evaluation respectively.
+ - `train_test_rows`: a vector of tuples, each of the form `(train, test)`,
+   where `train` and `test` are vectors of row (observation) indices for
+   training and evaluation respectively.
  """
  struct PerformanceEvaluation{M,
                               Measurement,
@@ -532,18 +540,35 @@ _short(v::Vector{<:Real}) = MLJBase.short_string(v)
  _short(v::Vector) = string("[", join(_short.(v), ", "), "]")
  _short(::Missing) = missing

- function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
-     _measure = map(e.measure) do m
-         repr(MIME("text/plain"), m)
+ function _standard_errors(e::PerformanceEvaluation)
+     factor = 1.96 # For the 95% confidence interval.
+     measure = e.measure
+     nfolds = length(e.per_fold[1])
+     nfolds == 1 && return [nothing]
+     std_errors = map(e.per_fold) do per_fold
+         factor * std(per_fold) / sqrt(nfolds - 1)
      end
+     return std_errors
+ end
+
+ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
+     _measure = [repr(MIME("text/plain"), m) for m in e.measure]
      _measurement = round3.(e.measurement)
      _per_fold = [round3.(v) for v in e.per_fold]
+     _sterr = round3.(_standard_errors(e))
+
+     # Only show the standard error if the number of folds is higher than 1.
+     show_sterr = any(!isnothing, _sterr)
+     data = show_sterr ?
+         hcat(_measure, e.operation, _measurement, _sterr, _per_fold) :
+         hcat(_measure, e.operation, _measurement, _per_fold)
+     header = show_sterr ?
+         ["measure", "operation", "measurement", "1.96*SE", "per_fold"] :
+         ["measure", "operation", "measurement", "per_fold"]

-     data = hcat(_measure, _measurement, e.operation, _per_fold)
-     header = ["measure", "measurement", "operation", "per_fold"]
      println(io, "PerformanceEvaluation object " *
              "with these fields:")
-     println(io, "measure, measurement, operation, per_fold,\n" *
+     println(io, "measure, operation, measurement, per_fold,\n" *
              "per_observation, fitted_params_per_fold,\n" *
              "report_per_fold, train_test_rows")
      println(io, "Extract:")
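
For reference, the `1.96*SE` column added in this diff can be reproduced by hand. The following is a minimal sketch, not code from the PR: the `per_fold` values are made up for illustration, and it assumes the aggregated `measurement` is the plain mean of the per-fold values, which need not hold for every measure. It uses the same formula as `_standard_errors` above. Note that the diff divides by `sqrt(nfolds - 1)`, whereas the textbook standard error of the mean, computed from the corrected sample standard deviation that `std` returns by default, divides by `sqrt(nfolds)`; the half-width here is therefore slightly wider than the textbook value.

```julia
using Statistics

# Hypothetical per-fold values for a single measure (illustrative only).
per_fold = [0.30, 0.34, 0.28, 0.36, 0.32, 0.32]
nfolds = length(per_fold)

# Assumption: the aggregate is the mean of the per-fold values.
measurement = mean(per_fold)

# Same formula as `_standard_errors` in the diff:
# 1.96 * (corrected sample std) / sqrt(nfolds - 1).
half_width = 1.96 * std(per_fold) / sqrt(nfolds - 1)

# Formal 95% interval: measurement ± 1.96*SE.
interval = (measurement - half_width, measurement + half_width)

# Textbook variant for comparison: divide by sqrt(nfolds) instead.
half_width_textbook = 1.96 * std(per_fold) / sqrt(nfolds)

println("measurement = ", measurement)
println("1.96*SE     = ", half_width)
println("interval    = ", interval)
```

As the docstring warns, intervals built this way treat the per-fold estimates as independent, which cross-validation folds are not, so the result is best read as a rough indication of spread.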