Skip to content

Commit aed7b79

Browse files
authored
Merge pull request #72 from alan-turing-institute/dev
Patch release 0.7.2
2 parents 18d07f0 + b28cdef commit aed7b79

File tree

8 files changed

+70
-45
lines changed

8 files changed

+70
-45
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.7.1"
4+
version = "0.7.2"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -12,7 +12,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1414
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
15-
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1615
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1716
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1817
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

src/MLJBase.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ export fit, update, clean!
99
export predict, predict_mean, predict_mode, fitted_params
1010
export transform, inverse_transform, se, evaluate, best
1111
export info, info_dict
12+
export is_same_except
1213

1314
export load_path, package_url, package_name, package_uuid # model_traits.jl
1415
export input_scitype, supports_weights # model_traits.jl
@@ -74,7 +75,6 @@ import Missings.levels
7475
using Statistics
7576
using Random
7677
using InteractiveUtils
77-
using SparseArrays
7878

7979

8080
## CONSTANTS
@@ -125,16 +125,22 @@ abstract type UnsupervisedNetwork <: Unsupervised end
125125
# `fit(model, verbosity::Integer, training_args...) -> fitresult, cache, report`
126126
# or, one of the simplified versions
127127
# `fit(model, training_args...) -> fitresult`
128-
# `fit(model, X, ys...) -> fitresult`
129-
fit(model::Model, verbosity::Integer, args...) = fit(model, args...), nothing, nothing
128+
fit(model::Model, verbosity::Integer, args...) =
129+
fit(model, args...), nothing, nothing
130130

131131
# each model interface may optionally overload the following refitting
132132
# method:
133133
update(model::Model, verbosity, fitresult, cache, args...) =
134134
fit(model, verbosity, args...)
135135

136+
# fallbacks for supervised models that don't support sample weights:
137+
fit(model::Supervised, verbosity::Integer, X, y, w) =
138+
fit(model, verbosity, X, y)
139+
update(model::Supervised, verbosity, fitresult, cache, X, y, w) =
140+
update(model, verbosity, fitresult, cache, X, y)
141+
136142
# methods dispatched on a model and fit-result are called
137-
# *operations*. supervised models must implement a `predict`
143+
# *operations*. Supervised models must implement a `predict`
138144
# operation (extending the `predict` method of StatsBase).
139145

140146
# unsupervised methods must implement this operation:

src/data.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ integer arrays, in which case `d` is broadcast over all elements.
205205
julia> d(int(v)) == v
206206
true
207207
208-
*Warning:* It is *not* true that `int(d(u)) == u` always holds.
208+
*Warning:* It is *not* true that `int(d(u)) == u` always holds.
209209
210210
See also: [`int`](@ref), [`classes`](@ref).
211211
@@ -239,9 +239,11 @@ output, unless `transpose=true`.
239239
"""
240240
matrix(X; kwargs...) = matrix(Val(ScientificTypes.trait(X)), X; kwargs...)
241241
matrix(::Val{:other}, X; kwargs...) = throw(ArgumentError)
242-
matrix(::Val{:other}, X::AbstractMatrix; kwargs...) = X
242+
matrix(::Val{:other}, X::AbstractMatrix; transpose=false) =
243+
transpose ? permutedims(X) : X
243244

244245
matrix(::Val{:table}, X; kwargs...) = Tables.matrix(X; kwargs...)
246+
245247
# matrix(::Val{:table, X)
246248
# cols = Tables.columns(X) # property-accessible object
247249
# mat = reduce(hcat, [getproperty(cols, ftr) for ftr in propertynames(cols)])

src/distributions.jl

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,10 @@ function Base.show(stream::IO, d::UnivariateFinite)
150150
# instantiation of d
151151
x1 = d.decoder(first(raw))
152152
p1 = d.prob_given_class[first(raw)]
153-
str = "UnivariateFinite($x1=>$p1"
153+
str = "UnivariateFinite($x1=>$(round(p1, sigdigits=3))"
154154
pairs = (d.decoder(r)=>d.prob_given_class[r] for r in raw[2:end])
155155
for pair in pairs
156-
str *= ", $(pair[1])=>$(pair[2])"
156+
str *= ", $(pair[1])=>$(round(pair[2], sigdigits=3))"
157157
end
158158
str *= ")"
159159
print(stream, str)
@@ -199,18 +199,14 @@ function average(dvec::AbstractVector{UnivariateFinite{L,U,T}};
199199
end
200200

201201
# get all refs:
202-
refs = reduce(union, [keys(d.prob_given_class) for d in dvec])
203-
204-
# pad each individual dicts so they have common keys:
205-
z = LittleDict{U,T}([x => zero(T) for x in refs]...)
206-
prob_given_class_vec = map(dvec) do d
207-
merge(z, d.prob_given_class)
208-
end
202+
refs = Tuple(reduce(union, [keys(d.prob_given_class) for d in dvec]))
209203

210204
# initialize the prob dictionary for the distribution sum:
211-
prob_given_class = LittleDict{U,T}()
212-
for x in refs
213-
prob_given_class[x] = zero(T)
205+
prob_given_class = LittleDict{U,T}(refs, zeros(T, length(refs)))
206+
207+
# make vector of all the distributions dicts padded to have same common keys:
208+
prob_given_class_vec = map(dvec) do d
209+
merge(prob_given_class, d.prob_given_class)
214210
end
215211

216212
# sum up:
@@ -232,15 +228,10 @@ function average(dvec::AbstractVector{UnivariateFinite{L,U,T}};
232228
end
233229

234230
return UnivariateFinite(first(dvec).decoder, prob_given_class)
235-
236231
end
237232

238233
function _pdf(d::UnivariateFinite{L,U,T}, ref) where {L,U,T}
239-
if haskey(d.prob_given_class, ref)
240-
return d.prob_given_class[ref]
241-
else
242-
return zero(T)
243-
end
234+
return get(d.prob_given_class, ref, zero(T))
244235
end
245236

246237
Distributions.pdf(d::UnivariateFinite{L,U,T},
@@ -336,7 +327,9 @@ function Distributions.fit(d::Type{<:UnivariateFinite},
336327
isempty(vpure) && error("No non-missing data to fit. ")
337328
N = length(vpure)
338329
count_given_class = Dist.countmap(vpure)
339-
prob_given_class = LittleDict([x=>c/N for (x, c) in count_given_class])
330+
classes = Tuple(keys(count_given_class))
331+
probs = values(count_given_class)./N
332+
prob_given_class = LittleDict(classes, probs)
340333
return UnivariateFinite(prob_given_class)
341334
end
342335

src/equality.jl

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,29 @@
1-
# by default, MLJType objects are `==` if: (i) they have ==
2-
# type, AND (ii) they have the same set of *defined* fields AND (iii)
3-
# their defined field values are `==` OR the values are both
4-
# AbstractRNG objects.
5-
import Base.==
6-
function ==(m1::M1, m2::M2) where {M1<:MLJType,M2<:MLJType}
1+
"""
2+
is_same_except(m1::MLJType, m2::MLJType, exceptions::Symbol...)
3+
4+
Returns `true` only if the following conditions all hold:
5+
6+
- `m1` and `m2` have the same type.
7+
8+
- `m1` and `m2` have the same undefined fields.
9+
10+
- Corresponding fields agree, or are listed as
11+
`exceptions`, or have `AbstractRNG` as values (one or both)
12+
13+
Note that Base.== is overloaded such that `m1 == m2` if and only if
14+
`is_same_except(m1, m2)`.
15+
16+
"""
17+
function is_same_except(m1::M1, m2::M2,
18+
exceptions::Symbol...) where {M1<:MLJType,M2<:MLJType}
719
if typeof(m1) != typeof(m2)
820
return false
921
end
1022
defined1 = filter(fieldnames(M1)|>collect) do fld
11-
isdefined(m1, fld)
23+
isdefined(m1, fld) && !(fld in exceptions)
1224
end
1325
defined2 = filter(fieldnames(M1)|>collect) do fld
14-
isdefined(m2, fld)
26+
isdefined(m2, fld) && !(fld in exceptions)
1527
end
1628
if defined1 != defined2
1729
return false
@@ -20,17 +32,21 @@ function ==(m1::M1, m2::M2) where {M1<:MLJType,M2<:MLJType}
2032
for fld in defined1
2133
same_values = same_values &&
2234
(getfield(m1, fld) == getfield(m2, fld) ||
23-
getfield(m1, fld) isa AbstractRNG)
35+
getfield(m1, fld) isa AbstractRNG) ||
36+
getfield(m2, fld) isa AbstractRNG
2437
end
2538
return same_values
2639
end
2740

41+
import Base.==
42+
43+
==(m1::M1, m2::M2) where {M1<:MLJType,M2<:MLJType} = is_same_except(m1, m2)
44+
2845
# for using `replace` or `replace!` on collections of MLJType objects
2946
# (eg, Model objects in a learning network) we need a stricter
3047
# equality:
3148
MLJBase.isequal(m1::MLJType, m2::MLJType) = (m1 === m2)
3249

33-
3450
## TODO: Do we need to overload hash here?
3551
function Base.in(x::MLJType, itr::Set)
3652
anymissing = false

test/data.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ module TestData
44
using Test
55
using DataFrames
66
import TypedTables
7-
using StatsBase
7+
# using StatsBase
88
# using JuliaDB
9-
using SparseArrays
9+
# using SparseArrays
1010
using CategoricalArrays
1111
import Tables
1212
using ScientificTypes
@@ -172,6 +172,9 @@ end
172172
tab = table(A)
173173
selectcols(tab, 1) == v
174174

175+
@test matrix(B) == B
176+
@test matrix(B, transpose=true) == permutedims(B)
177+
175178
end
176179

177180
## TABLE INDEXING

test/distributions.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module TestDistributions
22

3-
# using Revise
43
using Test
54
using MLJBase
65
using CategoricalArrays

test/equality.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,27 @@ using Test
55
mutable struct Foo <: MLJType
66
rng::AbstractRNG
77
x::Int
8+
y::Int
89
end
910

1011
mutable struct Bar <: MLJType
1112
rng::AbstractRNG
1213
x::Int
14+
y::Int
1315
end
1416

15-
f1 = Foo(MersenneTwister(7), 1)
16-
f2 = Foo(MersenneTwister(8), 1)
17+
f1 = Foo(MersenneTwister(7), 1, 2)
18+
f2 = Foo(MersenneTwister(8), 1, 2)
1719
@test f1.rng != f2.rng
1820
@test f1 == f2
19-
f1.x = 2
21+
f1.x = 10
2022
@test f1 != f2
21-
b = Bar(MersenneTwister(7), 1)
22-
@test f1 != b
23+
b = Bar(MersenneTwister(7), 1, 2)
24+
@test f2 != b
25+
26+
@test is_same_except(f1, f2, :x)
27+
f1.y = 20
28+
@test f1 != f2
29+
@test is_same_except(f1, f2, :x, :y)
2330

2431
true

0 commit comments

Comments
 (0)