Skip to content

Commit 7f1d537

Browse files
authored
Merge pull request #35 from JuliaAI/no-slurping
Remove data slurping fallbacks
2 parents e9c39a8 + 6242737 commit 7f1d537

13 files changed

+132
-101
lines changed

docs/src/anatomy_of_an_implementation.md

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
11
# Anatomy of an Implementation
22

3-
This section explains a detailed implementation of the LearnAPI for naive [ridge
3+
This section explains a detailed implementation of the LearnAPI.jl for naive [ridge
44
regression](https://en.wikipedia.org/wiki/Ridge_regression) with no intercept. The kind of
55
workflow we want to enable has been previewed in [Sample workflow](@ref). Readers can also
66
refer to the [demonstration](@ref workflow) of the implementation given later.
77

8-
A transformer ordinarily implements `transform` instead of
9-
`predict`. For more on `predict` versus `transform`, see [Predict or transform?](@ref)
8+
The core LearnAPI.jl pattern looks like this:
9+
10+
```julia
11+
model = fit(algorithm, data)
12+
predict(model, newdata)
13+
```
14+
15+
A transformer ordinarily implements `transform` instead of `predict`. For more on
16+
`predict` versus `transform`, see [Predict or transform?](@ref)
1017

1118
!!! note
1219

1320
New implementations of `fit`, `predict`, etc,
14-
always have a *single* `data` argument, as in
15-
`LearnAPI.fit(algorithm, data; verbosity=1) = ...`.
16-
For convenience, user-calls, such as `fit(algorithm, X, y)`, automatically fallback
17-
to `fit(algorithm, (X, y))`.
21+
always have a *single* `data` argument as above.
22+
For convenience, a signature such as `fit(algorithm, X, y)`, calling
23+
`fit(algorithm, (X, y))`, can be added, but the LearnAPI.jl specification is
24+
silent on the meaning or existence of signatures with extra arguments.
1825

1926
!!! note
2027

@@ -52,7 +59,7 @@ nothing # hide
5259

5360
Instances of `Ridge` will be [algorithms](@ref algorithms), in LearnAPI.jl parlance.
5461

55-
Associated with each new type of LearnAPI [algorithm](@ref algorithms) will be a keyword
62+
Associated with each new type of LearnAPI.jl [algorithm](@ref algorithms) will be a keyword
5663
argument constructor, providing default values for all properties (struct fields) that are
5764
not other algorithms, and we must implement [`LearnAPI.constructor(algorithm)`](@ref), for
5865
recovering the constructor from an instance:
@@ -244,6 +251,14 @@ in LearnAPI.functions(algorithm)`, for every instance `algorithm`. With [some
244251
exceptions](@ref trait_contract), the value of a trait should depend only on the *type* of
245252
the argument.
246253

254+
## Signatures added for convenience
255+
256+
We add one `fit` signature for user-convenience only. The LearnAPI.jl specification has
257+
nothing to say about `fit` signatures with more than two positional arguments.
258+
259+
```@example anatomy
260+
LearnAPI.fit(algorithm::Ridge, X, y; kwargs...) = fit(algorithm, (X, y); kwargs...)
261+
```
247262

248263
## [Demonstration](@id workflow)
249264

@@ -466,6 +481,14 @@ overload the trait, [`LearnAPI.data_interface(algorithm)`](@ref). See [Data
466481
interfaces](@ref data_interfaces) for details.
467482

468483

484+
### Addition of signatures for user convenience
485+
486+
As above, we add a signature which plays no role vis-à-vis LearnAPI.jl.
487+
488+
```@example anatomy2
489+
LearnAPI.fit(algorithm::Ridge, X, y; kwargs...) = fit(algorithm, (X, y); kwargs...)
490+
```
491+
469492
## Demonstration of an advanced `obs` workflow
470493

471494
We now can train and predict using internal data representations, resampled using the

docs/src/common_implementation_patterns.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Common Implementation Patterns
22

3+
!!! warning
4+
35
!!! warning
46

57
This section is only an implementation guide. The definitive specification of the

docs/src/fit_update.md

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ A "static" algorithm is one that does not generalize to new observations (e.g.,
1111
clustering algorithms); there is no trainiing data and the algorithm is executed by
1212
`predict` or `transform` which receive the data. See example below.
1313

14-
When `fit` expects a tuple form of argument, `data = (X1, ..., Xn)`, then the signature
15-
`fit(algorithm, X1, ..., Xn)` is also provided.
1614

1715
### Updating
1816

@@ -32,7 +30,7 @@ Supposing `Algorithm` is some supervised classifier type, with an iteration para
3230

3331
```julia
3432
algorithm = Algorithm(n=100)
35-
model = fit(algorithm, (X, y)) # or `fit(algorithm, X, y)`
33+
model = fit(algorithm, (X, y))
3634

3735
# Predict probability distributions:
3836
= predict(model, Distribution(), Xnew)
@@ -76,6 +74,22 @@ labels = predict(algorithm, X)
7674
LearnAPI.extras(model)
7775
```
7876

77+
### Density estimation
78+
79+
In density estimation, `fit` consumes no features, only a target variable; `predict`,
80+
which consumes no data, returns the learned density:
81+
82+
```julia
83+
model = fit(algorithm, y) # no features
84+
predict(model) # shortcut for `predict(model, Distribution())`
85+
```
86+
87+
A one-liner will typically be implemented as well:
88+
89+
```julia
90+
predict(algorithm, y)
91+
```
92+
7993
## Implementation guide
8094

8195
### Training

docs/src/predict_transform.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,15 @@ transform(model, data)
66
inverse_transform(model, data)
77
```
88

9-
When a method expects a tuple form of argument, `data = (X1, ..., Xn)`, then a slurping
10-
signature is also provided, as in `transform(model, X1, ..., Xn)`.
11-
9+
Versions without the `data` argument may also appear, for example in [Density
10+
estimation](@ref).
1211

1312
## [Typical worklows](@id predict_workflow)
1413

1514
Train some supervised `algorithm`:
1615

1716
```julia
18-
model = fit(algorithm, (X, y)) # or `fit(algorithm, X, y)`
17+
model = fit(algorithm, (X, y))
1918
```
2019

2120
Predict probability distributions:

src/fit_update.jl

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,10 @@ The second signature is provided by algorithms that do not generalize to new obs
1414
..., data)` carries out the actual algorithm execution, writing any byproducts of that
1515
operation to the mutable object `model` returned by `fit`.
1616
17-
Whenever `fit` expects a tuple form of argument, `data = (X1, ..., Xn)`, then the
18-
signature `fit(algorithm, X1, ..., Xn)` is also provided.
19-
20-
For example, a supervised classifier will typically admit this workflow:
17+
For example, a supervised classifier might have a workflow like this:
2118
2219
```julia
23-
model = fit(algorithm, (X, y)) # or `fit(algorithm, X, y)`
20+
model = fit(algorithm, (X, y))
2421
ŷ = predict(model, Xnew)
2522
```
2623
@@ -33,24 +30,22 @@ See also [`predict`](@ref), [`transform`](@ref), [`inverse_transform`](@ref),
3330
3431
# New implementations
3532
36-
Implementation is compulsory. The signature must include `verbosity`. Fallbacks provide
37-
the data slurping versions. A fallback for the first signature calls the second, ignoring
38-
`data`:
33+
Implementation of exactly one of the signatures is compulsory. If `fit(algorithm;
34+
verbosity=1)` is implemented, then the trait [`LearnAPI.is_static`](@ref) must be
35+
overloaded to return `true`.
3936
40-
```julia
41-
fit(algorithm, data; kwargs...) = fit(algorithm; kwargs...)
42-
```
37+
The signature must include `verbosity`.
4338
44-
If only the `fit(algorithm)` signature is expliclty implemented, then the trait
45-
[`LearnAPI.is_static`](@ref) must be overloaded to return `true`.
39+
The LearnAPI.jl specification has nothing to say regarding `fit` signatures with more than
40+
two arguments. For convenience, for example, an algorithm is free to implement a slurping
41+
signature, such as `fit(algorithm, X, y, extras...) = fit(algorithm, (X, y, extras...))` but
42+
LearnAPI.jl does not guarantee such signatures are actually implemented.
4643
4744
$(DOC_DATA_INTERFACE(:fit))
4845
4946
"""
50-
fit(algorithm, data; kwargs...) =
51-
fit(algorithm; kwargs...)
52-
fit(algorithm, data1, datas...; kwargs...) =
53-
fit(algorithm, (data1, datas...); kwargs...)
47+
function fit end
48+
5449

5550
# # UPDATE AND COUSINS
5651

@@ -91,7 +86,7 @@ Implementation is optional. The signature must include
9186
See also [`LearnAPI.clone`](@ref)
9287
9388
"""
94-
update(model, data1, datas...; kwargs...) = update(model, (data1, datas...); kwargs...)
89+
function update end
9590

9691
"""
9792
update_observations(model, new_data; verbosity=1, parameter_replacements...)
@@ -127,8 +122,7 @@ Implementation is optional. The signature must include
127122
See also [`LearnAPI.clone`](@ref).
128123
129124
"""
130-
update_observations(algorithm, data1, datas...; kwargs...) =
131-
update_observations(algorithm, (data1, datas...); kwargs...)
125+
function update_observations end
132126

133127
"""
134128
update_features(model, new_data; verbosity=1, parameter_replacements...)
@@ -154,5 +148,4 @@ Implementation is optional. The signature must include
154148
See also [`LearnAPI.clone`](@ref).
155149
156150
"""
157-
update_features(algorithm, data1, datas...; kwargs...) =
158-
update_features(algorithm, (data1, datas...); kwargs...)
151+
function update_features end

src/predict_transform.jl

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,20 @@ DOC_MUTATION(op) =
1313
1414
"""
1515

16+
DOC_SLURPING(op) =
17+
"""
18+
19+
An algorithm is free to implement `$op` signatures with additional positional
20+
arguments (eg., data-slurping signatures) but LearnAPI.jl is silent about their
21+
interpretation or existence.
22+
23+
"""
1624

1725
DOC_MINIMIZE(func) =
1826
"""
1927
20-
If, additionally, [`LearnAPI.strip(model)`](@ref) is overloaded, then the following identity
21-
must hold:
28+
If, additionally, [`LearnAPI.strip(model)`](@ref) is overloaded, then the following
29+
identity must hold:
2230
2331
```julia
2432
$func(LearnAPI.strip(model), args...) = $func(model, args...)
@@ -63,7 +71,7 @@ which lists all supported target proxies.
6371
The argument `model` is anything returned by a call of the form `fit(algorithm, ...)`.
6472
6573
If `LearnAPI.features(LearnAPI.algorithm(model)) == nothing`, then argument `data` is
66-
omitted. An example is density estimators.
74+
omitted in both signatures. An example is density estimators.
6775
6876
# Example
6977
@@ -79,20 +87,20 @@ See also [`fit`](@ref), [`transform`](@ref), [`inverse_transform`](@ref).
7987
8088
# Extended help
8189
82-
If `predict` supports data in the form of a tuple `data = (X1, ..., Xn)`, then a slurping
83-
signature is also provided, as in `predict(model, X1, ..., Xn)`.
84-
85-
Note `predict ` does not mutate any argument, except in the special case
90+
Note `predict ` must not mutate any argument, except in the special case
8691
`LearnAPI.is_static(algorithm) == true`.
8792
8893
# New implementations
8994
9095
If there is no notion of a "target" variable in the LearnAPI.jl sense, or you need an
9196
operation with an inverse, implement [`transform`](@ref) instead.
9297
93-
Implementation is optional. Only the first signature is implemented, but each
94-
`kind_of_proxy` that gets an implementation must be added to the list returned by
95-
[`LearnAPI.kinds_of_proxy`](@ref).
98+
Implementation is optional. Only the first signature (with or without the `data` argument)
99+
is implemented, but each `kind_of_proxy` that gets an implementation must be added to the
100+
list returned by [`LearnAPI.kinds_of_proxy`](@ref).
101+
102+
If `data` is not present in the implemented signature (eg., for density estimators) then
103+
[`LearnAPI.features(algorithm, data)`](@ref) must return `nothing`.
96104
97105
$(DOC_IMPLEMENTED_METHODS(":(LearnAPI.predict)"))
98106
@@ -106,23 +114,12 @@ $(DOC_DATA_INTERFACE(:predict))
106114
predict(model, data) = predict(model, kinds_of_proxy(algorithm(model)) |> first, data)
107115
predict(model) = predict(model, kinds_of_proxy(algorithm(model)) |> first)
108116

109-
# automatic slurping of multiple data arguments:
110-
predict(model, k::KindOfProxy, data1, data2, datas...; kwargs...) =
111-
predict(model, k, (data1, data2, datas...); kwargs...)
112-
predict(model, data1, data2, datas...; kwargs...) =
113-
predict(model, (data1, data2, datas...); kwargs...)
114-
115-
116-
117117
"""
118118
transform(model, data)
119119
120120
Return a transformation of some `data`, using some `model`, as returned by
121121
[`fit`](@ref).
122122
123-
For `data` that consists of a tuple, a slurping version is also provided, i.e., you can do
124-
`transform(model, X1, X2, X3)` in place of `transform(model, (X1, X2, X3))`.
125-
126123
# Example
127124
128125
Below, `X` and `Xnew` are data of the same form.
@@ -157,8 +154,10 @@ See also [`fit`](@ref), [`predict`](@ref),
157154
158155
# New implementations
159156
160-
Implementation for new LearnAPI.jl algorithms is optional. A fallback provides the
161-
slurping version. $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.transform)"))
157+
Implementation for new LearnAPI.jl algorithms is
158+
optional. $(DOC_IMPLEMENTED_METHODS(":(LearnAPI.transform)"))
159+
160+
$(DOC_SLURPING(:transform))
162161
163162
$(DOC_MINIMIZE(:transform))
164163
@@ -167,8 +166,8 @@ $(DOC_MUTATION(:transform))
167166
$(DOC_DATA_INTERFACE(:transform))
168167
169168
"""
170-
transform(model, data1, data2, datas...; kwargs...) =
171-
transform(model, (data1, data2, datas...); kwargs...) # automatic slurping
169+
function transform end
170+
172171

173172
"""
174173
inverse_transform(model, data)

test/fit_update.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

test/patterns/ensembling.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ LearnAPI.strip(model::EnsembleFitted) = EnsembleFitted(
160160
)
161161
)
162162

163+
# convenience method:
164+
LearnAPI.fit(algorithm::Ensemble, X, y, extras...; kwargs...) =
165+
fit(algorithm, (X, y, extras...); kwargs...)
166+
LearnAPI.update(algorithm::EnsembleFitted, X, y, extras...; kwargs...) =
167+
update(algorithm, (X, y, extras...); kwargs...)
168+
169+
163170
# synthetic test data:
164171
N = 10 # number of observations
165172
train = 1:6

test/patterns/gradient_descent.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,9 @@ function LearnAPI.update_observations(
227227
)
228228

229229
# unpack data:
230-
X = observations.X
231-
y_hot = observations.y_hot
232-
classes = observations.classes
230+
X = observations_new.X
231+
y_hot = observations_new.y_hot
232+
classes = observations_new.classes
233233
nclasses = length(classes)
234234

235235
classes == model.classes || error("New training target has incompatible classes.")
@@ -328,6 +328,16 @@ LearnAPI.training_losses(model::PerceptronClassifierFitted) = model.losses
328328
)
329329

330330

331+
# ### Convenience methods
332+
333+
LearnAPI.fit(algorithm::PerceptronClassifier, X, y; kwargs...) =
334+
fit(algorithm, (X, y); kwargs...)
335+
LearnAPI.update_observations(algorithm::PerceptronClassifier, X, y; kwargs...) =
336+
update_observations(algorithm, (X, y); kwargs...)
337+
LearnAPI.update(algorithm::PerceptronClassifier, X, y; kwargs...) =
338+
update(algorithm, (X, y); kwargs...)
339+
340+
331341
# ## Tests
332342

333343
# synthetic test data:

0 commit comments

Comments
 (0)