Skip to content

Commit 443a58e

Browse files
authored
Merge pull request #49 from alan-turing-institute/dev
Merge dev branch for a 0.5.1 release
2 parents 3634e1b + 00f4f31 commit 443a58e

11 files changed

+446
-24
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -27,9 +27,10 @@ julia = "1"
2727
[extras]
2828
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
2929
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
30+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
3031
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
3132
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3233
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
3334

3435
[targets]
35-
test = ["CSV", "DataFrames", "LossFunctions", "Test", "TypedTables"]
36+
test = ["CSV", "DataFrames", "Distances", "LossFunctions", "Test", "TypedTables"]

src/MLJBase.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ export selectrows, selectcols, select, nrows # data.jl
2020
export table, levels_seen, matrix, container_type # data.jl
2121
export partition, unpack # data.jl
2222
export @set_defaults # utilities.jl
23+
export @mlj_model # mlj_model_macro.jl
24+
export metadata_model, metadata_pkg # metadata_utilities
2325
export HANDLE_GIVEN_ID, @more, @constant # show.jl
2426
export color_on, color_off # show.jl
2527
export UnivariateFinite, average # distributions.jl
@@ -50,6 +52,7 @@ export pdf, mode, median, mean, shuffle!, categorical, shuffle, levels, levels!
5052
export std
5153

5254
import Base.==
55+
import Base: @__doc__
5356

5457
using Tables
5558
using OrderedCollections # already a dependency of StatsBase
@@ -83,7 +86,7 @@ const DEFAULT_SHOW_DEPTH = 0
8386
include("utilities.jl")
8487

8588

86-
## BASE TYPES
89+
## BASE TYPES
8790

8891
abstract type MLJType end
8992
include("equality.jl") # equality for MLJType objects
@@ -116,7 +119,7 @@ abstract type UnsupervisedNetwork <: Unsupervised end
116119
## THE MODEL INTERFACE
117120

118121
# every model interface must implement a `fit` method of the form
119-
# `fit(model, verbosity::Integer, training_args...) -> fitresult, cache, report`
122+
# `fit(model, verbosity::Integer, training_args...) -> fitresult, cache, report`
120123
# or, one the simplified versions
121124
# `fit(model, training_args...) -> fitresult`
122125
# `fit(model, X, ys...) -> fitresult`
@@ -169,14 +172,14 @@ clean!(model::Model) = ""
169172

170173
## TRAITS
171174

172-
"""
175+
"""
173176
174177
info(object)
175178
176179
List the traits of an object, such as a model or a performance measure.
177180
178181
"""
179-
info(object) = info(object, Val(ScientificTypes.trait(object)))
182+
info(object) = info(object, Val(ScientificTypes.trait(object)))
180183

181184

182185
include("model_traits.jl")
@@ -199,6 +202,12 @@ include("datasets.jl") # importing CSV will also load datasets_requires.jl
199202
include("tasks.jl")
200203
include("measures.jl")
201204

205+
# mlj model macro to help define models
206+
include("mlj_model_macro.jl")
207+
208+
# metadata utils
209+
include("metadata_utilities.jl")
210+
202211
# __init__() function:
203212
include("init.jl")
204213

src/data.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ function partition(rows::AbstractVector{Int}, fractions...; shuffle::Bool=false,
4545
end
4646

4747
"""
48-
t1, t2, ...., tk = unnpack(table, c1, c2, ... ck; wrap_singles=false)
48+
t1, t2, ...., tk = unnpack(table, t1, t2, ... tk; wrap_singles=false)
4949
5050
Split any Tables.jl compatible `table` into smaller tables (or
5151
vectors) `t1, t2, ..., tk` by making selections *without replacement*
52-
from the column names defined by the conditionals `c1`, `c2`, ...,
53-
`ck`. A *conditional* is any object `c` such that `c(name)` is `true`
52+
from the column names defined by the tests `t1`, `t2`, ...,
53+
`tk`. A *test* is any object `t` such that `t(name)` is `true`
5454
or `false` for each column `name::Symbol` of `table`.
5555
5656
Whenever a returned table contains a single column, it is converted to
@@ -59,7 +59,7 @@ a vector unless `wrap_singles=true`.
5959
Scientific type conversions can be optionally specified (note
6060
semicolon):
6161
62-
unpack(table, c...; wrap_singles=false, col1=>scitype1, col2=>scitype2, ... )
62+
unpack(table, t...; wrap_singles=false, col1=>scitype1, col2=>scitype2, ... )
6363
6464
### Example
6565
@@ -82,7 +82,7 @@ julia> Z
8282
```
8383
8484
"""
85-
function unpack(X, conditionals...; wrap_singles=false, pairs...)
85+
function unpack(X, tests...; wrap_singles=false, pairs...)
8686

8787
if isempty(pairs)
8888
Xfixed = X
@@ -94,7 +94,7 @@ function unpack(X, conditionals...; wrap_singles=false, pairs...)
9494
names_left = schema(Xfixed).names |> collect
9595
history = ""
9696
counter = 1
97-
for c in conditionals
97+
for c in tests
9898
names = filter(c, names_left)
9999
filter!(!in(names), names_left)
100100
history *= "selection $counter: $names\n remaining: $names_left\n"

src/distributions.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
const Dist = Distributions
44

5-
65
## EQUALITY OF DISTRIBUTIONS (minor type piracy)
76

7+
# TODO: We should get rid of this. I think it is used only in
8+
# MLJModels/test.
9+
810
function ==(d1::D, d2::D) where D<:Dist.Sampleable
911
ret = true
1012
for fld in fieldnames(D)
@@ -108,8 +110,10 @@ end
108110

109111
function UnivariateFinite(classes::AbstractVector{L},
110112
p::AbstractVector{<:Real}) where L
111-
L <: CategoricalElement || error("classes must have CategoricalValue or "*
112-
"CategoricalString type.")
113+
L <: CategoricalElement ||
114+
error("`classes` must have type `AbstractVector{T}` where "*
115+
"`T <: Union{CategoricalValue,CategoricalString}. "*
116+
"Perhaps you have `T=Any`? ")
113117
Dist.@check_args(UnivariateFinite, length(classes)==length(p))
114118
prob_given_class = LittleDict([classes[i]=>p[i] for i in eachindex(p)])
115119
return UnivariateFinite(prob_given_class)
@@ -138,6 +142,28 @@ function Base.show(stream::IO, d::UnivariateFinite)
138142
print(stream, str)
139143
end
140144

145+
"""
146+
isapprox(d1::UnivariateFinite, d2::UnivariateFinite; kwargs...)
147+
148+
Returns `true` if and only if `Set(classes(d1) == Set(classes(d2))`
149+
and the corresponding probabilities are approximately equal. The
150+
key-word arguments `kwargs` are passed through to each call of
151+
`isapprox` on probabiliity pairs. Returns `false` otherwise.
152+
153+
"""
154+
function Base.isapprox(d1::UnivariateFinite, d2::UnivariateFinite; kwargs...)
155+
156+
classes1 = classes(d1)
157+
classes2 = classes(d2)
158+
159+
for c in classes1
160+
c in classes2 || return false
161+
isapprox(pdf(d1, c), pdf(d2, c); kwargs...) ||
162+
return false # pdf defined below
163+
end
164+
return true
165+
end
166+
141167
function average(dvec::AbstractVector{UnivariateFinite{L,U,T}};
142168
weights=nothing) where {L,U,T}
143169

src/metadata_utilities.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
docstring_ext
3+
4+
Helper function to generate the docstring for a package.
5+
"""
6+
function docstring_ext(T; descr::String="")
7+
package_name = MLJBase.package_name(T)
8+
package_url = MLJBase.package_url(T)
9+
model_name = MLJBase.name(T)
10+
# the message to return
11+
message = "$descr"
12+
message *= "\n→ based on [$package_name]($package_url)"
13+
message *= "\n→ do `@load $model_name` to use the model"
14+
message *= "\n→ do `?$model_name` for documentation."
15+
end
16+
17+
"""
18+
metadata_pkg
19+
20+
Helper function to write the metadata for a package.
21+
"""
22+
function metadata_pkg(T; name::String="unknown", uuid::String="unknown", url::String="unknown",
23+
julia::Union{Missing,Bool}=missing, license::String="unknown",
24+
is_wrapper::Bool=false)
25+
ex = quote
26+
package_name(::Type{<:$T}) = $name
27+
package_uuid(::Type{<:$T}) = $uuid
28+
package_url(::Type{<:$T}) = $url
29+
is_pure_julia(::Type{<:$T}) = $julia
30+
package_license(::Type{<:$T}) = $license
31+
is_wrapper(::Type{<:$T}) = $is_wrapper
32+
end
33+
eval(ex)
34+
end
35+
36+
"""
37+
metadata_model
38+
39+
Helper function to write the metadata for a single model of a package (complements
40+
[`metadata_ext`](@ref)).
41+
"""
42+
function metadata_model(T; input=Unknown, target=Unknown,
43+
output=Unknown, weights::Bool=false,
44+
descr::String="", path::String="")
45+
if isempty(path)
46+
path = "MLJModels.$(package_name(T))_.$(name(T))"
47+
end
48+
ex = quote
49+
input_scitype(::Type{<:$T}) = $input
50+
output_scitype(::Type{<:$T}) = $output
51+
target_scitype(::Type{<:$T}) = $target
52+
supports_weights(::Type{<:$T}) = $weights
53+
docstring(::Type{<:$T}) = docstring_ext($T, descr=$descr)
54+
load_path(::Type{<:$T}) = $path
55+
end
56+
eval(ex)
57+
end

0 commit comments

Comments
 (0)