Skip to content

Commit 3634e1b

Browse files
authored
Merge pull request #39 from alan-turing-institute/dev
Lots of little improvements
2 parents ae1aa06 + d18e533 commit 3634e1b

19 files changed

+1294
-329
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ julia:
88
# - 0.7
99
- 1.0
1010
- 1.1
11+
- 1.2
1112
- nightly
1213
notifications:
1314
email: false

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.4.0"
4+
version = "0.5.0"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -20,15 +20,16 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2020
CSV = "0.5"
2121
CategoricalArrays = "<0.5.3"
2222
Requires = "^0.5.2"
23-
ScientificTypes = "0.1.3"
23+
ScientificTypes = "0.2.0"
2424
Tables = "<0.1.19, >= 0.2"
2525
julia = "1"
2626

2727
[extras]
2828
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
2929
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
30+
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
3031
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3132
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
3233

3334
[targets]
34-
test = ["CSV", "DataFrames", "Test", "TypedTables"]
35+
test = ["CSV", "DataFrames", "LossFunctions", "Test", "TypedTables"]

src/MLJBase.jl

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
# Users of this module should first read the document
22
# https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/
3-
43
module MLJBase
54

6-
export MLJType, Model, Supervised, Unsupervised, Deterministic, Probabilistic
5+
export MLJType, Model, Supervised, Unsupervised
6+
export Deterministic, Probabilistic, Interval
77
export DeterministicNetwork, ProbabilisticNetwork, UnsupervisedNetwork
88
export fit, update, clean!
99
export predict, predict_mean, predict_mode, fitted_params
1010
export transform, inverse_transform, se, evaluate, best
11-
export load_path, package_url, package_name, package_uuid
12-
export input_scitype, supports_weights
13-
export target_scitype, output_scitype
14-
export is_pure_julia, is_wrapper
11+
export info, info_dict
1512

13+
export load_path, package_url, package_name, package_uuid # model_traits.jl
14+
export input_scitype, supports_weights # model_traits.jl
15+
export target_scitype, output_scitype # model_traits.jl
16+
export is_pure_julia, is_wrapper, prediction_type # model_traits.jl
1617
export params # parameters.jl
1718
export reconstruct, int, decoder, classes # data.jl
1819
export selectrows, selectcols, select, nrows # data.jl
@@ -25,6 +26,14 @@ export UnivariateFinite, average # distributions.jl
2526
export SupervisedTask, UnsupervisedTask, MLJTask # tasks.jl
2627
export X_and_y, X_, y_, nrows, nfeatures # tasks.jl
2728
export info # info.jl
29+
export @load_boston, @load_ames, @load_iris # datasets.jl
30+
export @load_reduced_ames # datasets.jl
31+
export @load_crabs # datasets.jl
32+
export orientation, reports_each_observation # measures.jl
33+
export is_feature_dependent # measures.jl
34+
export default_measure, value # measures.jl
35+
export mav, mae, rms, rmsl, rmslp1, rmsp, l1, l2 # measures.jl
36+
export misclassification_rate, cross_entropy # measures.jl
2837

2938
# methods from other packages to be rexported:
3039
export pdf, mean, mode
@@ -36,6 +45,10 @@ export OrderedFactor, Multiclass, Count, Continuous
3645
export Binary, ColorImage, GrayImage, Image
3746
export scitype, scitype_union, coerce, schema
3847

48+
# rexport from Random, Statistics, Distributions, CategoricalArrays:
49+
export pdf, mode, median, mean, shuffle!, categorical, shuffle, levels, levels!
50+
export std
51+
3952
import Base.==
4053

4154
using Tables
@@ -90,6 +103,9 @@ abstract type Probabilistic <: Supervised end
90103
# supervised models that `predict` point-values are of:
91104
abstract type Deterministic <: Supervised end
92105

106+
# supervised models that `predict` intervals:
107+
abstract type Interval <: Supervised end
108+
93109
# for models that are "exported" learning networks (return a Node as
94110
# their fit-result; see MLJ docs:
95111
abstract type ProbabilisticNetwork <: Probabilistic end
@@ -125,39 +141,6 @@ function inverse_transform end
125141
# fitted parameters (eg, coeficients of linear model):
126142
fitted_params(::Model, fitresult) = (fitresult=fitresult,)
127143

128-
# operations implemented by some meta-models:
129-
function se end
130-
function evaluate end
131-
function best end
132-
133-
# a model wishing invalid hyperparameters to be corrected with a
134-
# warning should overload this method (return value is the warning
135-
# message):
136-
clean!(model::Model) = ""
137-
138-
# fallback trait declarations:
139-
input_scitype(::Any) = Unknown
140-
output_scitype(::Any) = Unknown
141-
target_scitype(::Any) = Unknown
142-
is_pure_julia(::Any) = false
143-
package_name(::Any) = "unknown"
144-
package_license(::Any) = "unknown"
145-
load_path(::Any) = "unknown"
146-
package_uuid(::Any) = "unknown"
147-
package_url(::Any) = "unknown"
148-
is_wrapper(::Any) = false
149-
supports_weights(::Any) = false
150-
151-
input_scitype(model::Model) = input_scitype(typeof(model))
152-
output_scitype(model::Model) = output_scitype(typeof(model))
153-
target_scitype(model::Model) = target_scitype(typeof(model))
154-
is_pure_julia(model::Model) = is_pure_julia(typeof(model))
155-
package_name(model::Model) = package_name(typeof(model))
156-
load_path(model::Model) = load_path(typeof(model))
157-
package_uuid(model::Model) = package_uuid(typeof(model))
158-
package_url(model::Model) = package_url(typeof(model))
159-
is_wrapper(m::Model) = is_wrapper(typeof(m))
160-
161144
# probabilistic supervised models may also overload one or more of
162145
# `predict_mode`, `predict_median` and `predict_mean` defined below.
163146

@@ -173,6 +156,31 @@ predict_mean(model::Probabilistic, fitresult, Xnew) =
173156
predict_median(model::Probabilistic, fitresult, Xnew) =
174157
median.(predict(model, fitresult, Xnew))
175158

159+
# operations implemented by some meta-models:
160+
function se end
161+
function evaluate end
162+
function best end
163+
164+
# a model wishing invalid hyperparameters to be corrected with a
165+
# warning should overload this method (return value is the warning
166+
# message):
167+
clean!(model::Model) = ""
168+
169+
170+
## TRAITS
171+
172+
"""
173+
174+
info(object)
175+
176+
List the traits of an object, such as a model or a performance measure.
177+
178+
"""
179+
info(object) = info(object, Val(ScientificTypes.trait(object)))
180+
181+
182+
include("model_traits.jl")
183+
176184
# for unpacking the fields of MLJ objects:
177185
include("parameters.jl")
178186

@@ -187,7 +195,9 @@ include("data.jl")
187195
include("distributions.jl")
188196

189197
include("info.jl")
198+
include("datasets.jl") # importing CSV will also load datasets_requires.jl
190199
include("tasks.jl")
200+
include("measures.jl")
191201

192202
# __init__() function:
193203
include("init.jl")

0 commit comments

Comments
 (0)