Skip to content

Commit 52849b7

Browse files
authored
Merge pull request #808 from JuliaAI/dev
For a 0.20.12 release
2 parents 1c79f8f + ed601bd commit 52849b7

File tree

12 files changed

+172
-81
lines changed

12 files changed

+172
-81
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.20.11"
4+
version = "0.20.12"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -36,14 +36,14 @@ ComputationalResources = "0.3"
3636
Distributions = "0.25.3"
3737
InvertedIndices = "1"
3838
LossFunctions = "0.5, 0.6, 0.7, 0.8"
39-
MLJModelInterface = "1.5"
39+
MLJModelInterface = "1.6"
4040
Missings = "0.4, 1"
4141
OrderedCollections = "1.1"
4242
Parameters = "0.12"
4343
PrettyTables = "1"
4444
ProgressMeter = "1.7.1"
4545
ScientificTypes = "3"
46-
StatisticalTraits = "3"
46+
StatisticalTraits = "3.2"
4747
StatsBase = "0.32, 0.33"
4848
Tables = "0.2, 1.0"
4949
julia = "1.6"

src/composition/learning_networks/machines.jl

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,22 @@ $DOC_SIGNATURES
225225
"""
226226
glb(mach::Machine{<:Union{Composite,Surrogate}}) = glb(mach.fitresult)
227227

228+
"""
229+
report(fitresult::CompositeFitresult)
230+
231+
Return a named tuple combining the report from `fitresult.glb` (a `Node` report) with the
232+
additions coming from nodes declared as report nodes in `fitresult.signature`, but without
233+
merging the two.
234+
235+
$DOC_SIGNATURES
236+
237+
**Private method**
238+
"""
239+
function report(fitresult::CompositeFitresult)
240+
basic = report(glb(fitresult))
241+
additions = _call(_report_part(signature(fitresult)))
242+
return (; basic, additions)
243+
end
228244

229245
"""
230246
fit!(mach::Machine{<:Surrogate};
@@ -245,11 +261,10 @@ See also [`machine`](@ref)
245261
246262
"""
247263
function fit!(mach::Machine{<:Surrogate}; kwargs...)
248-
glb_node = glb(mach)
249-
fit!(glb_node; kwargs...)
264+
glb = MLJBase.glb(mach)
265+
fit!(glb; kwargs...)
250266
mach.state += 1
251-
report_additions_ = _call(_report_part(signature(mach.fitresult)))
252-
mach.report = merge(report(glb_node), report_additions_)
267+
mach.report = MLJBase.report(mach.fitresult)
253268
return mach
254269
end
255270

@@ -347,7 +362,7 @@ the following:
347362
348363
- Calls `fit!(mach, verbosity=verbosity, acceleration=acceleration)`.
349364
350-
- Records a copy of `model` in a variable called `cache`.
365+
- Records (among other things) a copy of `model` in a variable called `cache`.
351366
352367
- Returns `cache` and outcomes of training in an appropriate form
353368
(specifically, `(mach.fitresult, cache, mach.report)`; see [Adding
@@ -396,6 +411,7 @@ function return!(mach::Machine{<:Surrogate},
396411
# record the current hyper-parameter values:
397412
old_model = deepcopy(model)
398413

414+
glb = MLJBase.glb(mach)
399415
cache = (; old_model)
400416

401417
setfield!(mach.fitresult,
@@ -647,9 +663,8 @@ function restore!(mach::Machine{<:Composite})
647663
return mach
648664
end
649665

650-
651-
function setreport!(mach::Machine{<:Composite}, report)
652-
basereport = MLJBase.report(glb(mach))
653-
report_additions = Base.structdiff(report, basereport)
654-
mach.report = merge(basereport, report_additions)
666+
function setreport!(copymach::Machine{<:Composite}, mach)
667+
basic = report(glb(copymach.fitresult))
668+
additions = mach.report.additions
669+
copymach.report = (; basic, additions)
655670
end

src/composition/models/inspection.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ try_scalarize(v) = length(v) == 1 ? v[1] : v
55
function machines_given_model_name(mach::Machine{M}) where M<:Composite
66
network_model_names = getfield(mach.fitresult, :network_model_names)
77
names = unique(filter(name->!(name === nothing), network_model_names))
8-
network_models = MLJBase.models(glb(mach))
9-
network_machines = MLJBase.machines(glb(mach))
8+
glb = MLJBase.glb(mach)
9+
network_models = MLJBase.models(glb)
10+
network_machines = MLJBase.machines(glb)
1011
ret = LittleDict{Symbol,Any}()
1112
for name in names
1213
mask = map(==(name), network_model_names)
@@ -17,22 +18,27 @@ function machines_given_model_name(mach::Machine{M}) where M<:Composite
1718
return ret
1819
end
1920

20-
function tuple_keyed_on_model_names(item_given_machine, mach)
21+
function tuple_keyed_on_model_names(machines, mach, f)
2122
dict = MLJBase.machines_given_model_name(mach)
2223
names = tuple(keys(dict)...)
2324
named_tuple_values = map(names) do name
24-
[item_given_machine[m] for m in dict[name]] |> try_scalarize
25+
[f(m) for m in dict[name]] |> try_scalarize
2526
end
2627
return NamedTuple{names}(named_tuple_values)
2728
end
2829

29-
function report(mach::Machine{<:Composite})
30-
dict = mach.report.report_given_machine
31-
return merge(tuple_keyed_on_model_names(dict, mach), mach.report)
30+
function report(mach::Machine{<:Union{Composite,Surrogate}})
31+
report_additions = mach.report.additions
32+
report_basic = mach.report.basic
33+
report_components = mach isa Machine{<:Surrogate} ? NamedTuple() :
34+
MLJBase.tuple_keyed_on_model_names(report_basic.machines, mach, MLJBase.report)
35+
return merge(report_components, report_basic, report_additions)
3236
end
3337

3438
function fitted_params(mach::Machine{<:Composite})
35-
fp = fitted_params(mach.model, mach.fitresult)
36-
dict = fp.fitted_params_given_machine
37-
return merge(MLJBase.tuple_keyed_on_model_names(dict, mach), fp)
39+
fp_basic = fitted_params(mach.model, mach.fitresult)
40+
machines = fp_basic.machines
41+
fp_components =
42+
MLJBase.tuple_keyed_on_model_names(machines, mach, MLJBase.fitted_params)
43+
return merge(fp_components, fp_basic)
3844
end

src/composition/models/methods.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,25 @@ function update(model::M,
3131
# underlying learning network machine.
3232

3333
network_model_names = getfield(fitresult, :network_model_names)
34-
old_model = cache.old_model
3534

36-
glb_node = glb(fitresult) # greatest lower bound
35+
old_model = cache.old_model
36+
glb = MLJBase.glb(fitresult) # greatest lower bound of network, a node
3737

38-
if fallback(model, old_model, network_model_names, glb_node)
38+
if fallback(model, old_model, network_model_names, glb)
3939
return fit(model, verbosity, args...)
4040
end
4141

42-
fit!(glb_node; verbosity=verbosity)
42+
fit!(glb; verbosity=verbosity)
43+
4344
# Retrieve additional report values
44-
report_additions_ = _call(_report_part(signature(fitresult)))
45+
report = MLJBase.report(fitresult)
4546

4647
# record current model state:
4748
cache = (; old_model = deepcopy(model))
4849

4950
return (fitresult,
5051
cache,
51-
merge(report(glb_node), report_additions_))
52+
report)
5253

5354
end
5455

src/composition/models/pipelines.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ MMI.target_scitype(p::SupervisedPipeline) = target_scitype(supervised_component(
608608
# ## Training losses
609609

610610
function MMI.training_losses(pipe::SupervisedPipeline, pipe_report)
611-
mach = supervised(pipe_report.machines)
611+
mach = supervised(pipe_report.basic.machines)
612612
_report = report(mach)
613613
return training_losses(mach.model, _report)
614614
end

src/composition/models/transformed_target_model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ end
224224
# # TRAINING LOSSES
225225

226226
function training_losses(model::SomeTT, tt_report)
227-
mach = first(tt_report.machines)
227+
mach = first(tt_report.basic.machines)
228228
return training_losses(mach)
229229
end
230230

src/machines.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,11 @@ end
469469
# Not one, but *two*, fit methods are defined for machines here,
470470
# `fit!` and `fit_only!`.
471471

472-
# - `fit_only!`: trains a machine without touching the learned
473-
# parameters (`fitresult`) of any other machine. It may error if
474-
# another machine on which it depends (through its node training
475-
# arguments `N1, N2, ...`) has not been trained.
472+
# - `fit_only!`: trains a machine without touching the learned parameters (`fitresult`) of
473+
# any other machine. It may error if another machine on which it depends (through its node
474+
# training arguments `N1, N2, ...`) has not been trained. It's possible that a dependent
475+
# machine `mach` may have its report mutated if `reporting_operations(mach.model)` is
476+
# non-empty.
476477

477478
# - `fit!`: trains a machine after first progressively training all
478479
# machines on which the machine depends. Implicitly this involves
@@ -909,13 +910,14 @@ function serializable(mach::Machine{<:Any, C}) where C
909910
setfield!(copymach, fieldname, ())
910911
# Make fitresult ready for serialization
911912
elseif fieldname == :fitresult
913+
# this `save` does the actual emptying of fields
912914
copymach.fitresult = save(mach.model, getfield(mach, fieldname))
913915
else
914916
setfield!(copymach, fieldname, getfield(mach, fieldname))
915917
end
916918
end
917919

918-
setreport!(copymach, mach.report)
920+
setreport!(copymach, mach)
919921

920922
return copymach
921923
end
@@ -997,6 +999,8 @@ function save(file::Union{String,IO},
997999
serialize(file, smach)
9981000
end
9991001

1002+
setreport!(copymach, mach) =
1003+
setfield!(copymach, :report, mach.report)
10001004

1001-
setreport!(mach::Machine, report) =
1002-
setfield!(mach, :report, report)
1005+
# NOTE. there is also a specialization for `setreport!` for `Composite` models, defined in
1006+
# /src/composition/learning_networks/machines.jl

src/operations.jl

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,37 +37,59 @@ warn_serializable_mach(operation) = "The operation $operation has been called on
3737
"deserialised machine mach whose learned parameters "*
3838
"may be unusable. To be sure, first run restore!(mach)."
3939

40+
# Given return value `ret` of an operation with symbol `operation` (eg, `:predict`) return
41+
# `ret` in the ordinary case that the operation does not include a "report" component;
42+
# otherwise update `mach.report` with that component and return the non-report part of
43+
# `ret`:
44+
function get!(ret, operation, mach)
45+
if operation in reporting_operations(mach.model)
46+
report = last(ret)
47+
if isnothing(mach.report) || isempty(mach.report)
48+
mach.report = report
49+
else
50+
mach.report = merge(mach.report, report)
51+
end
52+
return first(ret)
53+
end
54+
return ret
55+
end
56+
4057
# 0. operations on machine, given rows=...:
4158

4259
for operation in OPERATIONS
4360

44-
if operation != :inverse_transform
61+
quoted_operation = QuoteNode(operation) # eg, :(:predict)
4562

46-
ex = quote
47-
function $(operation)(mach::Machine{<:Model,false}; rows=:)
48-
# catch deserialized machine with no data:
49-
isempty(mach.args) && _err_serialized($operation)
50-
return ($operation)(mach, mach.args[1](rows=rows))
51-
end
52-
function $(operation)(mach::Machine{<:Model,true}; rows=:)
53-
# catch deserialized machine with no data:
54-
isempty(mach.args) && _err_serialized($operation)
55-
model = mach.model
56-
return ($operation)(model,
57-
mach.fitresult,
58-
selectrows(model, rows, mach.data[1])...)
59-
end
60-
end
61-
eval(ex)
63+
operation == :inverse_transform && continue
6264

65+
ex = quote
66+
function $(operation)(mach::Machine{<:Model,false}; rows=:)
67+
# catch deserialized machine with no data:
68+
isempty(mach.args) && _err_serialized($operation)
69+
ret = ($operation)(mach, mach.args[1](rows=rows))
70+
return get!(ret, $quoted_operation, mach)
71+
end
72+
function $(operation)(mach::Machine{<:Model,true}; rows=:)
73+
# catch deserialized machine with no data:
74+
isempty(mach.args) && _err_serialized($operation)
75+
model = mach.model
76+
ret = ($operation)(
77+
model,
78+
mach.fitresult,
79+
selectrows(model, rows, mach.data[1])...,
80+
)
81+
return get!(ret, $quoted_operation, mach)
82+
end
6383
end
84+
eval(ex)
85+
6486
end
6587

6688
# special case of Static models (no training arguments):
6789
transform(mach::Machine{<:Static}; rows=:) = _err_rows_not_allowed()
6890

6991
inverse_transform(mach::Machine; rows=:) =
70-
throw(ArgumentError("`inverse_transform()(mach)` and "*
92+
throw(ArgumentError("`inverse_transform(mach)` and "*
7193
"`inverse_transform(mach, rows=...)` are "*
7294
"not supported. Data or nodes "*
7395
"must be explictly specified, "*
@@ -77,22 +99,32 @@ _symbol(f) = Base.Core.Typeof(f).name.mt.name
7799

78100
for operation in OPERATIONS
79101

102+
quoted_operation = QuoteNode(operation) # eg, :(:predict)
103+
80104
ex = quote
81105
# 1. operations on machines, given *concrete* data:
82106
function $operation(mach::Machine, Xraw)
83107
if mach.state != 0
84108
mach.state == -1 && @warn warn_serializable_mach($operation)
85-
return $(operation)(mach.model,
86-
mach.fitresult,
87-
reformat(mach.model, Xraw)...)
109+
ret = $(operation)(
110+
mach.model,
111+
mach.fitresult,
112+
reformat(mach.model, Xraw)...,
113+
)
114+
get!(ret, $quoted_operation, mach)
88115
else
89116
error("$mach has not been trained.")
90117
end
91118
end
92119

93120
function $operation(mach::Machine{<:Static}, Xraw, Xraw_more...)
94-
return $(operation)(mach.model, mach.fitresult,
95-
Xraw, Xraw_more...)
121+
ret = $(operation)(
122+
mach.model,
123+
mach.fitresult,
124+
Xraw,
125+
Xraw_more...,
126+
)
127+
get!(ret, $quoted_operation, mach)
96128
end
97129

98130
# 2. operations on machines, given *dynamic* data (nodes):

test/composition/learning_networks/machines.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ end
8888
@test Θ.transform == Wout
8989
Θ.report.some_stuff == rnode
9090
@test report(mach).some_stuff == :stuff
91-
9291
@test report(mach).machines == fitted_params(mach).machines
9392

9493
# supervised
@@ -281,7 +280,7 @@ end
281280
end
282281

283282
# Testing extra report field : it is a deepcopy
284-
@test smach.report.cv_report === mach.report.cv_report
283+
@test report(smach).cv_report === report(mach).cv_report
285284

286285
@test smach.fitresult isa MLJBase.CompositeFitresult
287286

@@ -356,7 +355,8 @@ end
356355
metalearner = FooBarRegressor(lambda=1.),
357356
resampling = dcv,
358357
model_1 = DeterministicConstantRegressor(),
359-
model_2=ConstantRegressor())
358+
model_2=ConstantRegressor()
359+
)
360360

361361
filesizes = []
362362
for n in [100, 500, 1000]

0 commit comments

Comments
 (0)