
Commit 511ce6c

add LearnAPI.clone; tweak contract for update
1 parent 29ccc3b commit 511ce6c

10 files changed: 124 additions & 131 deletions

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -13,9 +13,11 @@ julia = "1.6"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["DataFrames", "LinearAlgebra", "MLUtils", "Serialization", "Tables", "Test"]
+test = ["DataFrames", "LinearAlgebra", "MLUtils", "Random", "Serialization", "Statistics", "Tables", "Test"]

ROADMAP.md

Lines changed: 0 additions & 5 deletions
@@ -39,11 +39,6 @@
 - [ ] meta-algorithms

 - [ ] In a utility package provide:
-  - [ ] Method to clone an algorithm with user-specified property (hyperparameter)
-    replacement in `LearnAPI.clone(algorithm, p1=value1, p22=value2, ...)` (since
-    `algorithm` can have any type, can't really overload `Base.replace` without
-    piracy). This will be needed in tuning meta-algorithms. Or should this be in
-    LearnAPI.jl proper, to expose it to all users?
   - [ ] Methods to facilitate common-use case data interfaces: support simultaneously
     `fit` data of the form `data = (X, y)` where `X` is table *or* matrix, and `data` a
     table with target specified by hyperparameter; here `obs` will return a thin wrapping

docs/src/reference.md

Lines changed: 14 additions & 0 deletions
@@ -91,9 +91,15 @@ named_properties = NamedTuple{properties}(getproperty.(Ref(algorithm), propertie
 @assert algorithm == LearnAPI.constructor(algorithm)(; named_properties...)
 ```

+which can be tested with `@assert `[`LearnAPI.clone(algorithm)`](@ref)` == algorithm`.
+
 Note that if `algorithm` is an instance of a *mutable* struct, this requirement
 generally requires overloading `Base.==` for the struct.

+No LearnAPI.jl method is permitted to mutate an algorithm. In particular, one should make
+deep copies of RNG hyperparameters before using them in a new implementation of
+[`fit`](@ref).
+
 #### Composite algorithms (wrappers)

 A *composite algorithm* is one with at least one property that can take other algorithms
@@ -179,6 +185,14 @@ Most algorithms will also implement [`predict`](@ref) and/or [`transform`](@ref)
 record general information about the algorithm. Only [`LearnAPI.constructor`](@ref) and
 [`LearnAPI.functions`](@ref) are universally compulsory.

+
+## Utilities
+
+```@docs
+LearnAPI.clone
+LearnAPI.@trait
+```
+
 ---

 ¹ We acknowledge users may not like this terminology, and may know "algorithm" by some
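
The new no-mutation paragraph is easiest to violate when an algorithm holds an RNG hyperparameter. Below is a minimal sketch of the intended pattern, assuming a hypothetical `PerceptronClassifier` type with an `rng` field; none of these names come from the commit itself.

```julia
using Random

# Hypothetical algorithm holding an RNG hyperparameter:
struct PerceptronClassifier
    epochs::Int
    rng::AbstractRNG
end
PerceptronClassifier(; epochs=10, rng=Random.default_rng()) =
    PerceptronClassifier(epochs, rng)

# Sketch of a `fit` implementation respecting the no-mutation rule:
function fit(algorithm::PerceptronClassifier, data; verbosity=1)
    rng = deepcopy(algorithm.rng)   # work on a copy; `algorithm` is left untouched
    coefficients = randn(rng, 3)    # stand-in for a real training loop drawing from `rng`
    return (; algorithm, coefficients)
end
```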

docs/src/traits.md

Lines changed: 0 additions & 1 deletion
@@ -105,5 +105,4 @@ LearnAPI.iteration_parameter
 LearnAPI.fit_observation_scitype
 LearnAPI.target_observation_scitype
 LearnAPI.predict_or_transform_mutates
-LearnAPI.@trait
 ```

src/LearnAPI.jl

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ include("target_weights_features.jl")
 include("obs.jl")
 include("accessor_functions.jl")
 include("traits.jl")
+include("clone.jl")

 export @trait
 export fit, update, update_observations, update_features

src/clone.jl

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+"""
+    LearnAPI.clone(algorithm; replacements...)
+
+Return a shallow copy of `algorithm` with the specified hyperparameter replacements.
+
+```julia
+clone(algorithm; epochs=100, learning_rate=0.01)
+```
+
+It is guaranteed that `LearnAPI.clone(algorithm) == algorithm`.
+
+"""
+function clone(algorithm; replacements...)
+    reps = NamedTuple(replacements)
+    names = propertynames(algorithm)
+    rep_names = keys(reps)
+
+    new_values = map(names) do name
+        name in rep_names && return getproperty(reps, name)
+        getproperty(algorithm, name)
+    end
+    return LearnAPI.constructor(algorithm)(NamedTuple{names}(new_values)...)
+end
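
For orientation, a small usage sketch of the new utility; `MyRidge` and its `lambda` field are hypothetical stand-ins, not part of the commit:

```julia
using LearnAPI

struct MyRidge                            # hypothetical algorithm type
    lambda::Float64
end
MyRidge(; lambda=0.1) = MyRidge(lambda)   # keyword constructor, per the LearnAPI contract
LearnAPI.constructor(::MyRidge) = MyRidge

algorithm = MyRidge(lambda=0.1)

@assert LearnAPI.clone(algorithm) == algorithm   # the documented guarantee

# A tuning meta-algorithm can generate candidates without mutating `algorithm`:
candidates = [LearnAPI.clone(algorithm; lambda=l) for l in (0.01, 0.1, 1.0)]
@assert [c.lambda for c in candidates] == [0.01, 0.1, 1.0]
```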

src/fit_update.jl

Lines changed: 14 additions & 7 deletions
@@ -59,14 +59,15 @@ Return an updated version of the `model` object returned by a previous [`fit`](@
 `update` call, but with the specified hyperparameter replacements, in the form `p1=value1,
 p2=value2, ...`.

-Provided that `data` is identical with the data presented in a preceding `fit` call, as in
-the example below, execution is semantically equivalent to the call `fit(algorithm,
-data)`, where `algorithm` is `LearnAPI.algorithm(model)` with the specified
-replacements. In some cases (typically, when changing an iteration parameter) there may be
-a performance benefit to using `update` instead of retraining ab initio.
+Provided that `data` is identical with the data presented in a preceding `fit` call *and*
+there is at most one hyperparameter replacement, as in the example below, execution is
+semantically equivalent to the call `fit(algorithm, data)`, where `algorithm` is
+`LearnAPI.algorithm(model)` with the specified replacements. In some cases (typically,
+when changing an iteration parameter) there may be a performance benefit to using `update`
+instead of retraining ab initio.

-If `data` differs from that in the preceding `fit` or `update` call, then behaviour is
-algorithm-specific.
+If `data` differs from that in the preceding `fit` or `update` call, or there is more than
+one hyperparameter replacement, then behaviour is algorithm-specific.

 ```julia
 algorithm = MyForest(ntrees=100)
@@ -85,6 +86,8 @@ See also [`fit`](@ref), [`update_observations`](@ref), [`update_features`](@ref)
 Implementation is optional. The signature must include
 `verbosity`. $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.update)"))

+See also [`LearnAPI.clone`](@ref).
+
 """
 update(model, data1, datas...; kwargs...) = update(model, (data1, datas...); kwargs...)

@@ -119,6 +122,8 @@ See also [`fit`](@ref), [`update`](@ref), [`update_features`](@ref).
 Implementation is optional. The signature must include
 `verbosity`. $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.update_observations)"))

+See also [`LearnAPI.clone`](@ref).
+
 """
 update_observations(algorithm, data1, datas...; kwargs...) =
     update_observations(algorithm, (data1, datas...); kwargs...)
@@ -144,6 +149,8 @@ See also [`fit`](@ref), [`update`](@ref), [`update_features`](@ref).
 Implementation is optional. The signature must include
 `verbosity`. $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.update_features)"))

+See also [`LearnAPI.clone`](@ref).
+
 """
 update_features(algorithm, data1, datas...; kwargs...) =
     update_features(algorithm, (data1, datas...); kwargs...)
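
To make the tweaked contract concrete, here is a self-contained toy sketch. Every name below (`MyForest`, `MyForestModel`, `seed`, `trees`) is invented for illustration and is not part of LearnAPI.jl: with identical data and a single `ntrees` replacement, `update` is semantically equivalent to `fit` on a clone, but only the extra "trees" are grown.

```julia
using LearnAPI
using Random

# Hypothetical ensemble algorithm: `ntrees` random numbers stand in for trees.
struct MyForest
    ntrees::Int
    seed::Int
end
MyForest(; ntrees=100, seed=123) = MyForest(ntrees, seed)
LearnAPI.constructor(::MyForest) = MyForest

struct MyForestModel
    algorithm::MyForest
    rng::Xoshiro             # RNG state after training, so `update` can resume it
    trees::Vector{Float64}   # stand-ins for real trees
end
LearnAPI.algorithm(model::MyForestModel) = model.algorithm

function LearnAPI.fit(algorithm::MyForest, data; verbosity=1)
    rng = Xoshiro(algorithm.seed)   # fresh RNG; the algorithm itself is never mutated
    trees = [rand(rng) * sum(data) for _ in 1:algorithm.ntrees]
    return MyForestModel(algorithm, rng, trees)
end

# With identical `data` and `ntrees` as the single replacement, this is semantically
# equivalent to `fit(LearnAPI.clone(algorithm; ntrees=ntrees), data)`, but only the
# extra trees are grown:
function LearnAPI.update(model::MyForestModel, data; verbosity=1, ntrees)
    algorithm = LearnAPI.clone(LearnAPI.algorithm(model); ntrees=ntrees)
    rng = copy(model.rng)
    extra = ntrees - length(model.trees)
    trees = vcat(model.trees, [rand(rng) * sum(data) for _ in 1:extra])
    return MyForestModel(algorithm, rng, trees)
end

data = rand(10)
model = LearnAPI.fit(MyForest(ntrees=100), data)
model_a = LearnAPI.update(model, data; ntrees=150)
model_b = LearnAPI.fit(LearnAPI.clone(MyForest(ntrees=100); ntrees=150), data)
@assert model_a.trees == model_b.trees   # the documented equivalence, in this toy case
```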

src/traits.jl

Lines changed: 31 additions & 17 deletions
@@ -105,29 +105,43 @@ value is non-empty.
 All new implementations must overload this trait. Here's a checklist for elements in the
 return value:

-| symbol                            | implementation/overloading compulsory? | include in returned tuple?         |
-|-----------------------------------|----------------------------------------|------------------------------------|
-| `:(LearnAPI.fit)`                 | yes                                    | yes                                |
-| `:(LearnAPI.algorithm)`           | yes                                    | yes                                |
-| `:(LearnAPI.minimize)`            | no                                     | yes                                |
-| `:(LearnAPI.obs)`                 | no                                     | yes                                |
-| `:(LearnAPI.features)`            | no                                     | yes, unless `fit` consumes no data |
-| `:(LearnAPI.update)`              | no                                     | only if implemented                |
-| `:(LearnAPI.update_observations)` | no                                     | only if implemented                |
-| `:(LearnAPI.update_features)`     | no                                     | only if implemented                |
-| `:(LearnAPI.target)`              | no                                     | only if implemented                |
-| `:(LearnAPI.weights)`             | no                                     | only if implemented                |
-| `:(LearnAPI.predict)`             | no                                     | only if implemented                |
-| `:(LearnAPI.transform)`           | no                                     | only if implemented                |
-| `:(LearnAPI.inverse_transform)`   | no                                     | only if implemented                |
-| <accessor functions>              | no                                     | only if implemented                |
+| expression                        | implementation compulsory? | include in returned tuple?         |
+|-----------------------------------|----------------------------|------------------------------------|
+| `:(LearnAPI.fit)`                 | yes                        | yes                                |
+| `:(LearnAPI.algorithm)`           | yes                        | yes                                |
+| `:(LearnAPI.minimize)`            | no                         | yes                                |
+| `:(LearnAPI.obs)`                 | no                         | yes                                |
+| `:(LearnAPI.features)`            | no                         | yes, unless `fit` consumes no data |
+| `:(LearnAPI.target)`              | no                         | only if implemented                |
+| `:(LearnAPI.weights)`             | no                         | only if implemented                |
+| `:(LearnAPI.update)`              | no                         | only if implemented                |
+| `:(LearnAPI.update_observations)` | no                         | only if implemented                |
+| `:(LearnAPI.update_features)`     | no                         | only if implemented                |
+| `:(LearnAPI.predict)`             | no                         | only if implemented                |
+| `:(LearnAPI.transform)`           | no                         | only if implemented                |
+| `:(LearnAPI.inverse_transform)`   | no                         | only if implemented                |
+| <accessor functions>              | no                         | only if implemented                |

 Also include any implemented accessor functions, both those owned by LearnAPI.jl, and any
 algorithm-specific ones. The LearnAPI.jl accessor functions are: $ACCESSOR_FUNCTIONS_LIST.

 """
 functions(::Any) = ()
-
+functions() = (
+    :(LearnAPI.fit),
+    :(LearnAPI.algorithm),
+    :(LearnAPI.minimize),
+    :(LearnAPI.obs),
+    :(LearnAPI.features),
+    :(LearnAPI.target),
+    :(LearnAPI.weights),
+    :(LearnAPI.update),
+    :(LearnAPI.update_observations),
+    :(LearnAPI.update_features),
+    :(LearnAPI.predict),
+    :(LearnAPI.transform),
+    :(LearnAPI.inverse_transform),
+)

 """
     LearnAPI.kinds_of_proxy(algorithm)
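
As a sketch of how an implementation might satisfy the revised checklist, assume a hypothetical `MyTransformer` type that implements `transform` in addition to the compulsory entries (nothing below is part of the commit itself):

```julia
using LearnAPI

struct MyTransformer end   # hypothetical algorithm type

# List the compulsory entries plus whatever else this algorithm actually implements:
LearnAPI.functions(::MyTransformer) = (
    :(LearnAPI.fit),
    :(LearnAPI.algorithm),
    :(LearnAPI.minimize),
    :(LearnAPI.obs),
    :(LearnAPI.features),
    :(LearnAPI.transform),
)
```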

test/clone.jl

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+using Test
+using LearnAPI
+
+struct Potato
+    x
+    y
+end
+
+Potato(; x=1, y=2) = Potato(x, y)
+LearnAPI.constructor(::Potato) = Potato
+
+@test LearnAPI.clone(Potato()) == Potato()
+
+p = LearnAPI.clone(Potato(), y=20)
+@test p.y == 20
+@test p.x == 1
+
+q = LearnAPI.clone(Potato(), y=20, x=10)
+@test q.y == 20
+@test q.x == 10
+
+true

test/runtests.jl

Lines changed: 16 additions & 100 deletions
@@ -1,103 +1,19 @@
 using Test

-@testset "tools.jl" begin
-    include("tools.jl")
+test_files = [
+    "tools.jl",
+    "traits.jl",
+    "clone.jl",
+    "integration/regression.jl",
+    "integration/static_algorithms.jl",
+]
+
+files = isempty(ARGS) ? test_files : ARGS
+
+for file in files
+    quote
+        @testset $file begin
+            include($file)
+        end
+    end |> eval
 end
-
-@testset "traits.jl" begin
-    include("traits.jl")
-end
-
-# # INTEGRATION TESTS
-
-@testset "regression" begin
-    include("integration/regression.jl")
-end
-
-# @testset "classification" begin
-#     include("integration/classification.jl")
-# end
-
-# @testset "clustering" begin
-#     include("integration/clustering.jl")
-# end
-
-# @testset "gradient_descent" begin
-#     include("integration/gradient_descent.jl")
-# end
-
-# @testset "iterative_algorithms" begin
-#     include("integration/iterative_algorithms.jl")
-# end
-
-# @testset "incremental_algorithms" begin
-#     include("integration/incremental_algorithms.jl")
-# end
-
-# @testset "dimension_reduction" begin
-#     include("integration/dimension_reduction.jl")
-# end
-
-# @testset "encoders" begin
-#     include("integration/encoders.jl")
-# end
-
-@testset "static_algorithms" begin
-    include("integration/static_algorithms.jl")
-end
-
-# @testset "missing_value_imputation" begin
-#     include("integration/missing_value_imputation.jl")
-# end
-
-# @testset "ensemble_algorithms" begin
-#     include("integration/ensemble_algorithms.jl")
-# end
-
-# @testset "wrappers" begin
-#     include("integration/wrappers.jl")
-# end
-
-# @testset "time_series_forecasting" begin
-#     include("integration/time_series_forecasting.jl")
-# end
-
-# @testset "time_series_classification" begin
-#     include("integration/time_series_classification.jl")
-# end
-
-# @testset "survival_analysis" begin
-#     include("integration/survival_analysis.jl")
-# end
-
-# @testset "distribution_fitters" begin
-#     include("integration/distribution_fitters.jl")
-# end
-
-# @testset "Bayesian_algorithms" begin
-#     include("integration/Bayesian_algorithms.jl")
-# end
-
-# @testset "outlier_detection" begin
-#     include("integration/outlier_detection.jl")
-# end
-
-# @testset "collaborative_filtering" begin
-#     include("integration/collaborative_filtering.jl")
-# end
-
-# @testset "text_analysis" begin
-#     include("integration/text_analysis.jl")
-# end
-
-# @testset "audio_analysis" begin
-#     include("integration/audio_analysis.jl")
-# end
-
-# @testset "natural_language_processing" begin
-#     include("integration/natural_language_processing.jl")
-# end
-
-# @testset "image_processing" begin
-#     include("integration/image_processing.jl")
-# end
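
One way to exercise the new `ARGS` hook from the REPL, relying on the standard `Pkg.test` `test_args` keyword rather than anything added in this commit:

```julia
using Pkg

# Forward extra arguments to test/runtests.jl as ARGS, running only the clone tests:
Pkg.test("LearnAPI"; test_args=["clone.jl"])
```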
