Commit 729e0d7

complete addition of update methods + other tweaks
1 parent 6e721c8 commit 729e0d7

17 files changed: 301 additions, 156 deletions

docs/make.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,11 @@ makedocs(
1515
"Anatomy of an Implementation" => "anatomy_of_an_implementation.md",
1616
"Reference" => [
1717
"Overview" => "reference.md",
18-
"fit" => "fit.md",
18+
"fit/update" => "fit.md",
1919
"predict/transform" => "predict_transform.md",
2020
"Kinds of Target Proxy" => "kinds_of_target_proxy.md",
2121
"minimize" => "minimize.md",
22-
"target/weights/input" => "target_weights_input.md",
22+
"target/weights/features" => "target_weights_features.md",
2323
"obs" => "obs.md",
2424
"Accessor Functions" => "accessor_functions.md",
2525
"Algorithm Traits" => "traits.md",

docs/src/accessor_functions.md

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,7 @@
11
# [Accessor Functions](@id accessor_functions)
22

3-
The sole argument of an accessor function is the output, `model`, of [`fit`](@ref).
3+
The sole argument of an accessor function is the output, `model`, of
4+
[`fit`](@ref). Algorithms are free to implement any number of these, or none of them.
45

56
- [`LearnAPI.algorithm(model)`](@ref)
67
- [`LearnAPI.extras(model)`](@ref)
@@ -15,6 +16,9 @@ The sole argument of an accessor function is the output, `model`, of [`fit`](@re
1516
- [`LearnAPI.training_scores(model)`](@ref)
1617
- [`LearnAPI.components(model)`](@ref)
1718

19+
Algorithm-specific accessor functions may also be implemented. The names of all accessor
20+
functions are included in the list returned by [`LearnAPI.functions(algorithm)`](@ref).
21+
1822
## Implementation guide
1923

2024
All new implementations must implement [`LearnAPI.algorithm`](@ref). While, all others are

docs/src/anatomy_of_an_implementation.md

Lines changed: 25 additions & 14 deletions
@@ -5,26 +5,34 @@ regression](https://en.wikipedia.org/wiki/Ridge_regression) with no intercept. T
55
workflow we want to enable has been previewed in [Sample workflow](@ref). Readers can also
66
refer to the [demonstration](@ref workflow) of the implementation given later.
77

8-
For a transformer, implementations ordinarily implement `transform` instead of
8+
A transformer ordinarily implements `transform` instead of
99
`predict`. For more on `predict` versus `transform`, see [Predict or transform?](@ref)
1010

1111
!!! note
1212

1313
New implementations of `fit`, `predict`, etc,
1414
always have a *single* `data` argument, as in
1515
`LearnAPI.fit(algorithm, data; verbosity=1) = ...`.
16-
For convenience, user calls like `fit(algorithm, X, y)` automatically fallback
16+
For convenience, user calls, such as `fit(algorithm, X, y)`, automatically fall back
1717
to `fit(algorithm, (X, y))`.
1818

1919
!!! note
2020

21+
By default, it is assumed that `data` supports the [`LearnAPI.RandomAccess`](@ref)
22+
interface (this includes all matrices, with observations-as-columns, most tables, and
23+
tuples thereof). See [`LearnAPI.RandomAccess`](@ref) for details. If this is not the
24+
case then an implementation must either:
25+
2126
If the `data` object consumed by `fit`, `predict`, or `transform` is not
2227
a suitable table¹, array³, tuple of tables and arrays, or some
2328
other object implementing
2429
the MLUtils.jl `getobs`/`numobs` interface,
25-
then an implementation must: (i) suitably overload the trait
26-
[`LearnAPI.data_interface`](@ref); and/or (ii) overload [`obs`](@ref), as
27-
illustrated below under [Providing an advanced data interface](@ref).
30+
then an implementation must: (i) overload [`obs`](@ref) to articulate how
31+
provided data can be transformed into a form that does support
32+
it, as illustrated below under
33+
[Providing an advanced data interface](@ref); or (ii) overload the trait
34+
[`LearnAPI.data_interface`](@ref) to specify a more relaxed data
35+
API.
2836

2937
The first line below imports the lightweight package LearnAPI.jl whose methods we will be
3038
extending. The second imports libraries needed for the core algorithm.
@@ -152,9 +160,9 @@ from training data, by implementing [`LearnAPI.target`](@ref):
152160
LearnAPI.target(algorithm, data) = last(data)
153161
```
154162

155-
There is a similar method, [`LearnAPI.input`](@ref) for declaring how input data can be
156-
extracted (for passing to `predict`, for example) but this method has a fallback which
157-
typically suffices: return `first(data)` if `data` is a tuple, and otherwise return
163+
There is a similar method, [`LearnAPI.features`](@ref), for declaring how training features
164+
can be extracted (for passing to `predict`, for example) but this method has a fallback
165+
which typically suffices: return `first(data)` if `data` is a tuple, and otherwise return
158166
`data`.
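
Reviewer sketch of the fallback just described, using hypothetical stand-in names (`toy_features` and `toy_target` are not LearnAPI.jl definitions, and the `nothing` passed as the algorithm is a placeholder). It only illustrates the dispatch pattern: tuple data yields its first element as features, and anything else is returned unchanged.

```julia
# Hypothetical stand-ins for illustration; not the LearnAPI.jl source.
toy_features(algorithm, data) = data                # non-tuple data is returned as-is
toy_features(algorithm, data::Tuple) = first(data)  # tuple data: take the first element

# Mirrors the `last(data)` implementation of `target` shown above.
toy_target(algorithm, data) = last(data)

X = rand(3, 10)  # 3 features, 10 observations as columns
y = rand(10)

@assert toy_features(nothing, (X, y)) === X
@assert toy_features(nothing, X) === X
@assert toy_target(nothing, (X, y)) === y
```

An implementation whose `fit` consumes some other data representation would need its own method, which is why the diff further down overloads `LearnAPI.features` for `RidgeFitObs`.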
159167

160168

@@ -218,7 +226,7 @@ A macro provides a shortcut, convenient when multiple traits are to be defined:
218226
:(LearnAPI.algorithm),
219227
:(LearnAPI.minimize),
220228
:(LearnAPI.obs),
221-
:(LearnAPI.input),
229+
:(LearnAPI.features),
222230
:(LearnAPI.target),
223231
:(LearnAPI.predict),
224232
:(LearnAPI.coefficients),
@@ -325,7 +333,7 @@ LearnAPI.minimize(model::RidgeFitted) =
325333
:(LearnAPI.algorithm),
326334
:(LearnAPI.minimize),
327335
:(LearnAPI.obs),
328-
:(LearnAPI.input),
336+
:(LearnAPI.features),
329337
:(LearnAPI.target),
330338
:(LearnAPI.predict),
331339
:(LearnAPI.coefficients),
@@ -423,7 +431,7 @@ LearnAPI.predict(model::RidgeFitted, ::LiteralTarget, Xnew) =
423431
predict(model, LiteralTarget(), obs(model, Xnew))
424432
```
425433

426-
### `target` and `input` methods
434+
### `target` and `features` methods
427435

428436
We provide an additional overloading of [`LearnAPI.target`](@ref) to handle the additional
429437
supported data argument of `fit`:
@@ -432,11 +440,11 @@ supported data argument of `fit`:
432440
LearnAPI.target(::Ridge, observations::RidgeFitObs) = observations.y
433441
```
434442

435-
Similarly, we must overload [`LearnAPI.input`](@ref), which extracts inputs from training
436-
data (objects that can be passed to `predict`) like this
443+
Similarly, we must overload [`LearnAPI.features`](@ref), which extracts features from
444+
training data (objects that can be passed to `predict`) like this
437445

438446
```@example anatomy2
439-
LearnAPI.input(::Ridge, observations::RidgeFitObs) = observations.A
447+
LearnAPI.features(::Ridge, observations::RidgeFitObs) = observations.A
440448
```
441449
as the fallback mentioned above is no longer adequate.
442450

@@ -482,6 +490,9 @@ ẑ = predict(model, MLUtils.getobs(observations_for_predict, test))
482490
@assert==
483491
```
484492

493+
For an application of [`obs`](@ref) to efficient cross-validation, see [here](@ref
494+
obs_workflows).
495+
485496
---
486497

487498
¹ In LearnAPI.jl a *table* is any object `X` implementing the

docs/src/fit.md

Lines changed: 36 additions & 21 deletions
@@ -1,22 +1,28 @@
1-
# [`fit`](@ref fit)
1+
# [`fit`, `update`, `update_observations`, and `update_features`](@id fit)
22

3-
Training for the first time:
3+
### Training
44

55
```julia
66
fit(algorithm, data; verbosity=1) -> model
77
fit(algorithm; verbosity=1) -> static_model
88
```
99

10-
Updating:
10+
A "static" algorithm is one that does not generalize to new observations (e.g., some
11+
clustering algorithms); there is no training data and the algorithm is executed by
12+
`predict` or `transform`, which receive the data. See the example below.
13+
14+
When `fit` expects a tuple form of argument, `data = (X1, ..., Xn)`, then the signature
15+
`fit(algorithm, X1, ..., Xn)` is also provided.
16+
17+
### Updating
1118

1219
```
13-
fit(model, data; verbosity=1, param1=new_value1, param2=new_value2, ...) -> updated_model
14-
fit(model, NewObservations(), new_data; verbosity=1, param1=new_value1, ...) -> updated_model
15-
fit(model, NewFeatures(), new_data; verbosity=1, param1=new_value1, ...) -> updated_model
20+
update(model, data; verbosity=1, param1=new_value1, param2=new_value2, ...) -> updated_model
21+
update_observations(model, new_data; verbosity=1, param1=new_value1, ...) -> updated_model
22+
update_features(model, new_data; verbosity=1, param1=new_value1, ...) -> updated_model
1623
```
1724

18-
When `fit` expects a tuple form of argument, `data = (X1, ..., Xn)`, then the signature
19-
`fit(algorithm, X1, ..., Xn)` is also provided.
25+
Data-slurping forms are similarly provided for the update methods.
2026

2127
## Typical workflows
2228

@@ -27,46 +33,55 @@ algorithm = Algorithm(n=100)
2733
model = fit(algorithm, (X, y)) # or `fit(algorithm, X, y)`
2834

2935
# Predict probability distributions:
30-
ŷ = predict(model, Distribution(), Xnew)
36+
ŷ = predict(model, Distribution(), Xnew)
3137

3238
# Inspect some byproducts of training:
3339
LearnAPI.feature_importances(model)
3440

3541
# Add 50 iterations and predict again:
36-
model = fit(model; n=150)
42+
model = update(model; n=150)
3743
predict(model, Distribution(), X)
3844
```
3945

4046
### A static algorithm (no "learning")
4147

4248
```julia
4349
# Apply some clustering algorithm which cannot be generalized to new data:
44-
model = fit(algorithm)
45-
labels = predict(model, LabelAmbiguous(), X) # mutates `model`
50+
model = fit(algorithm) # no training data
51+
labels = predict(model, LabelAmbiguous(), X) # may mutate `model`
52+
53+
# Or, in one line:
54+
labels = predict(algorithm, LabelAmbiguous(), X)
4655

47-
# inspect byproducts of the clustering algorithm (e.g., outliers):
56+
# But the two-line version exposes byproducts of the clustering algorithm (e.g., outliers):
4857
LearnAPI.extras(model)
4958
```
5059

5160
## Implementation guide
5261

53-
Initial training:
62+
### Training
5463

5564
| method | fallback | compulsory? |
5665
|:-------------------------------------------------------------------------------|:-----------------------------------------------------------------|--------------------|
5766
| [`fit`](@ref)`(algorithm, data; verbosity=1)` | ignores `data` and applies signature below | yes, unless static |
5867
| [`fit`](@ref)`(algorithm; verbosity=1)` | none | no, unless static |
5968

60-
Updating:
69+
### Updating
70+
71+
| method | fallback | compulsory? |
72+
|:-------------------------------------------------------------------------------------|:---------|-------------|
73+
| [`update`](@ref)`(model, data; verbosity=1, hyperparameter_updates...)` | none | no |
74+
| [`update_observations`](@ref)`(model, data; verbosity=1, hyperparameter_updates...)` | none | no |
75+
| [`update_features`](@ref)`(model, data; verbosity=1, hyperparameter_updates...)` | none | no |
6176

62-
| method | fallback | compulsory? |
63-
|:-------------------------------------------------------------------------------|:---------------------------------------------------------------------------|-------------|
64-
| [`fit`](@ref)`(model, data; verbosity=1, param_updates...)` | retrains from scratch on `data` with specified hyperparameter replacements | no |
65-
| [`fit`](@ref)`(model, ::NewObservations, data; verbosity=1, param_updates...)` | none | no |
66-
| [`fit`](@ref)`(model, ::NewFeatures, data; verbosity=1, param_updates...)` | none | no |
77+
There are some contracts regarding the behaviour of the update methods, as they relate to
78+
a previous `fit` call. Consult the document strings for details.
6779

6880
## Reference
6981

7082
```@docs
71-
LearnAPI.fit
83+
fit
84+
update
85+
update_observations
86+
update_features
7287
```
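
Reviewer sketch of the "add 50 iterations" workflow from `fit.md`, written against a toy iterative estimator rather than LearnAPI.jl itself. Every name here (`ToyAlgorithm`, `ToyModel`, `toy_fit`, `toy_update`, `iterate_estimate`) is hypothetical, and storing the training data inside the model is an assumption made only to keep the example self-contained. The actual contracts for `update` are in the docstrings, which this diff does not show.

```julia
# Hypothetical toy types; not part of LearnAPI.jl.
struct ToyAlgorithm
    n::Int  # total number of iterations
end
ToyAlgorithm(; n=100) = ToyAlgorithm(n)

struct ToyModel
    algorithm::ToyAlgorithm
    data::Vector{Float64}   # assumption: the model keeps its training data
    estimate::Float64
    iterations_run::Int
end

# Move `estimate` a fraction of the way towards the data mean, `iterations` times.
function iterate_estimate(estimate, data, iterations)
    target = sum(data) / length(data)
    for _ in 1:iterations
        estimate += 0.1 * (target - estimate)
    end
    return estimate
end

toy_fit(algorithm::ToyAlgorithm, data) =
    ToyModel(algorithm, data, iterate_estimate(0.0, data, algorithm.n), algorithm.n)

# Warm restart: run only the missing iterations; retrain if `n` was reduced.
function toy_update(model::ToyModel; n=model.algorithm.n)
    extra = n - model.iterations_run
    extra >= 0 || return toy_fit(ToyAlgorithm(n), model.data)
    estimate = iterate_estimate(model.estimate, model.data, extra)
    return ToyModel(ToyAlgorithm(n), model.data, estimate, n)
end

data = [1.0, 2.0, 3.0]
model = toy_fit(ToyAlgorithm(n=100), data)
updated = toy_update(model; n=150)            # adds 50 iterations
retrained = toy_fit(ToyAlgorithm(n=150), data)

@assert updated.iterations_run == 150
@assert isapprox(updated.estimate, retrained.estimate)
```

The final assertion states the property one would hope an update method preserves in this toy setting: updating to `n=150` gives the same result as training from scratch with `n=150`, while doing only the extra work.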

docs/src/index.md

Lines changed: 20 additions & 14 deletions
@@ -9,12 +9,14 @@ A base Julia interface for machine learning and statistics </span>
99
<br>
1010
```
1111

12-
LearnAPI.jl is a lightweight, functional-style interface, providing a collection of
13-
[methods](@ref Methods), such as `fit` and `predict`, to be implemented by algorithms from
14-
machine learning and statistics. Through such implementations, these algorithms buy into
15-
functionality, such as hyperparameter optimization and model composition, as provided by
16-
ML/statistics toolboxes and other packages. LearnAPI.jl also provides a number of Julia
17-
[traits](@ref traits) for promising specific behavior.
12+
LearnAPI.jl is a lightweight, functional-style interface, providing a
13+
collection of [methods](@ref Methods), such as `fit` and `predict`, to be implemented by
14+
algorithms from machine learning and statistics. Through such implementations, these
15+
algorithms buy into functionality, such as hyperparameter optimization and model
16+
composition, as provided by ML/statistics toolboxes and other packages. LearnAPI.jl also
17+
provides a number of Julia [traits](@ref traits) for promising specific behavior.
18+
19+
LearnAPI.jl has no package dependencies.
1820

1921
```@raw html
2022
&#128679;
@@ -41,15 +43,18 @@ X = <some training features>
4143
y = <some training target>
4244
Xnew = <some test or production features>
4345

46+
# List LearnAPI functions implemented for `forest`:
47+
LearnAPI.functions(forest)
48+
4449
# Train:
4550
model = fit(forest, X, y)
4651

52+
# Generate point predictions:
53+
ŷ = predict(model, Xnew) # or `predict(model, LiteralTarget(), Xnew)`
54+
4755
# Predict probability distributions:
4856
predict(model, Distribution(), Xnew)
4957

50-
# Generate point predictions:
51-
ŷ = predict(model, LiteralTarget(), Xnew) # or `predict(model, Xnew)`
52-
5358
# Apply an "accessor function" to inspect byproducts of training:
5459
LearnAPI.feature_importances(model)
5560

@@ -77,13 +82,14 @@ data_interface) (read as "observations") gives users and meta-algorithms access
7782
algorithm-specific representation of input data, which is also guaranteed to implement a
7883
standard interface for accessing individual observations, unless the algorithm explicitly
7984
opts out. Moreover, the `fit` and `predict` methods will also be able to consume these
80-
alternative data representations.
85+
alternative data representations, for performance benefits in some situations.
8186

8287
The fallback data interface is the [MLUtils.jl](https://github.com/JuliaML/MLUtils.jl)
83-
`getobs/numobs` interface, and if the input consumed by the algorithm already implements
84-
that interface (tables, arrays, etc.) then overloading `obs` is completely optional. Plain
85-
iteration interfaces, with or without knowledge of the number of observations, can also be
86-
specified (to support, e.g., data loaders reading images from disk).
88+
`getobs/numobs` interface (here tagged as [`LearnAPI.RandomAccess()`](@ref)), and if the
89+
input consumed by the algorithm already implements that interface (tables, arrays, etc.)
90+
then overloading `obs` is completely optional. Plain iteration interfaces, with or without
91+
knowledge of the number of observations, can also be specified (to support, e.g., data
92+
loaders reading images from disk).
8793

8894
## Learning more
8995

docs/src/kinds_of_target_proxy.md

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ expectiles at 50% will provide `LiteralTarget` instead.
4747
> Table of concrete subtypes of `LearnAPI.IID <: LearnAPI.KindOfProxy`.
4848
4949

50-
## Proxies for distribution-fitting algorithms
50+
## Proxies for density estimation algorithms
5151

5252
```@docs
5353
LearnAPI.Single

docs/src/obs.md

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ obs(algorithm, data) # can be passed to `fit` instead of `data`
1111
obs(model, data) # can be passed to `predict` or `transform` instead of `data`
1212
```
1313

14-
## Typical workflows
14+
## [Typical workflows](@id obs_workflows)
1515

1616
LearnAPI.jl makes no universal assumptions about the form of `data` in a call
1717
like `fit(algorithm, data)`. However, if we define
@@ -46,7 +46,7 @@ import MLUtils
4646
algorithm = <some supervised learner>
4747

4848
data = <some data that `fit` can consume, with 30 observations>
49-
X = LearnAPI.input(algorithm, data)
49+
X = LearnAPI.features(algorithm, data)
5050
y = LearnAPI.target(algorithm, data)
5151

5252
train_test_folds = map([1:10, 11:20, 21:30]) do test
