Skip to content

Commit 4341ee3

Browse files
authored
Merge pull request #794 from JuliaAI/dev
For a 0.20.7 release
2 parents 66065db + 095485f commit 4341ee3

File tree

5 files changed

+61
-26
lines changed

5 files changed

+61
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.20.6"
4+
version = "0.20.7"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/measures/finite.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -966,12 +966,7 @@ function _mtn(m::CM, return_type::Type{LittleDict})
966966
return LittleDict(m.labels, vec(_sum))
967967
end
968968

969-
@inline function _mean(x::Arr{<:Real})
970-
for i in eachindex(x)
971-
@inbounds x[i] = ifelse(isnan(x[i]), zero(eltype(x)), x[i])
972-
end
973-
return mean(x)
974-
end
969+
@inline _mean(x::Arr{<:Real}) = mean(skipnan(x)) # defined in src/data/data.jl
975970

976971
@inline function _class_w(level_m::Arr{<:String},
977972
class_w::AbstractDict{<:Any, <:Real})

src/resampling.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -827,11 +827,11 @@ measure or vector.
827827
828828
Do `subtypes(MLJ.ResamplingStrategy)` to obtain a list of available
829829
resampling strategies. If `resampling` is not an object of type
830-
`MLJ.ResamplingStrategy`, then a vector of pairs (of the form
830+
`MLJ.ResamplingStrategy`, then a vector of tuples (of the form
831831
`(train_rows, test_rows)`) is expected. For example, setting
832832
833-
resampling = [(1:100), (101:200)),
834-
(101:200), (1:100)]
833+
resampling = [((1:100), (101:200)),
834+
((101:200), (1:100))]
835835
836836
gives two-fold cross-validation using the first 200 rows of data.
837837
@@ -1164,7 +1164,7 @@ function evaluate!(mach::Machine, resampling, weights,
11641164

11651165
if !(resampling isa TrainTestPairs)
11661166
error("`resampling` must be an "*
1167-
"`MLJ.ResamplingStrategy` or tuple of pairs "*
1167+
"`MLJ.ResamplingStrategy` or tuple of rows "*
11681168
"of the form `(train_rows, test_rows)`")
11691169
end
11701170

test/measures/finite.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,28 @@ end
284284
2 ./ ( 1 ./ ( sum(cm_tp) ./ sum(cm_fp.+cm_tp) ) + 1 ./ ( sum(cm_tp) ./ sum(cm_fn.+cm_tp) ) )
285285
end
286286

287+
@testset "issue #630" begin
288+
# multiclass fscore corner case of absent class
289+
290+
y = coerce([1, 2, 2, 2, 3], OrderedFactor)[1:4]
291+
# [1, 2, 2, 2] # but 3 is in the pool
292+
yhat = reverse(y)
293+
# [2, 2, 2, 1]
294+
295+
# In this case, assigning "3" as "positive" gives all true negative,
296+
# and so NaN for that class's contribution to the average F1Score,
297+
# which should accordingly be skipped.
298+
299+
# positive class | TP | FP | FN | score for that class
300+
# --------------|----|----|----|---------------------
301+
# 1 | 0 | 1 | 2 | 0
302+
# 2 | 2 | 1 | 1 | 2/3
303+
# 3 | 0 | 0 | 0 | NaN
304+
305+
# mean score, skipping the NaN, is 1/3
306+
@test MulticlassFScore()(yhat, y) ≈ 1/3
307+
end
308+
287309
@testset "Metadata binary" begin
288310
for m in (accuracy, recall, Precision(), f1score, specificity)
289311
e = info(m)

test/resampling.jl

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -486,24 +486,43 @@ end
486486
end
487487

488488
@testset_accelerated "class weights in evaluation" accel begin
489-
x = [1,2,3,4,5,6,7]
490-
X, y = Tables.table([x x x x x x]), coerce([1,2,1,3,1,2,2], Multiclass)
491-
model = Models.DeterministicConstantClassifier()
492-
mach = machine(model, X, y)
489+
X, y = make_blobs(rng=rng)
493490
cv=CV(nfolds = 2)
491+
fold1, fold2 = partition(eachindex(y), 0.5)
492+
m = MLJBase.MulticlassFScore()
494493
class_w = Dict(1=>1, 2=>2, 3=>3)
495-
e = evaluate!(mach,
496-
resampling=cv,
497-
measure=MLJBase.MulticlassFScore(return_type=Vector),
498-
class_weights=class_w,
499-
verbosity=verb,
500-
acceleration=accel).measurement[1]
501-
@test round(e, digits=3) ≈ 0.217
494+
495+
model = Models.DeterministicConstantClassifier()
496+
mach = machine(model, X, y)
497+
498+
# fscore by hand:
499+
fit!(mach, rows=fold1, verbosity=0)
500+
score1 = m(predict(mach, rows=fold2), y[fold2], class_w)
501+
fit!(mach, rows=fold2, verbosity=0)
502+
score2 = m(predict(mach, rows=fold1), y[fold1], class_w)
503+
score_by_hand = mean([score1, score2])
504+
505+
# fscore by evaluate!:
506+
score = evaluate!(
507+
mach,
508+
resampling=cv,
509+
measure=m,
510+
class_weights=class_w,
511+
verbosity=verb,
512+
acceleration=accel,
513+
).measurement[1]
514+
515+
@test score ≈ score_by_hand
502516

503517
# if class weights in `evaluate!` isn't specified:
504-
e = evaluate!(mach, resampling=cv, measure=multiclass_f1score,
505-
verbosity=verb, acceleration=accel).measurement[1]
506-
@test e ≈ 0.15
518+
plain_score = evaluate!(
519+
mach,
520+
resampling=cv,
521+
measure=m,
522+
verbosity=verb,
523+
acceleration=accel,
524+
).measurement[1]
525+
@test !(score ≈ plain_score)
507526
end
508527

509528
@testset_accelerated "resampler as machine" accel begin
@@ -794,4 +813,3 @@ end
794813

795814
#end
796815
true
797-

0 commit comments

Comments
 (0)