2 changes: 2 additions & 0 deletions Project.toml
@@ -11,6 +11,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LearnAPI = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
@@ -44,6 +45,7 @@ CategoricalDistributions = "0.1"
ComputationalResources = "0.3"
DelimitedFiles = "1"
Distributions = "0.25.3"
FillArrays = "1.14.0"
InvertedIndices = "1"
LearnAPI = "1"
MLJModelInterface = "1.11"
2 changes: 2 additions & 0 deletions src/MLJBase.jl
@@ -91,6 +91,8 @@ const Dist = Distributions
# Measures
import StatisticalMeasuresBase

import FillArrays

# Plots
using RecipesBase: RecipesBase, @recipe

11 changes: 10 additions & 1 deletion src/resampling.jl
@@ -849,6 +849,10 @@ end
# ---------------------------------------------------------------
# Helpers

# to fill out predictions in the case of density estimation ("cone" construction):
fill_if_needed(yhat, X, n) = yhat
fill_if_needed(yhat, X::Nothing, n) = FillArrays.Fill(yhat, n)
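# A minimal usage sketch (hypothetical values, not part of the change): when
# `X === nothing`, as for a density estimator, the single prediction `yhat` is
# expanded lazily to match the `n` test observations; otherwise it passes
# through unchanged:
#
#     fill_if_needed([1.0, 2.0], X, 2)  # returned as-is when `X !== nothing`
#     fill_if_needed(d, nothing, 3)     # lazy 3-element `FillArrays.Fill` of `d`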

function actual_rows(rows, N, verbosity)
unspecified_rows = (rows === nothing)
_rows = unspecified_rows ? (1:N) : rows
@@ -1470,10 +1474,15 @@ function evaluate!(
function fit_and_extract_on_fold(mach, k)
train, test = resampling[k]
fit!(mach; rows=train, verbosity=verbosity - 1, force=force)
ntest = MLJBase.nrows(test)
# build a dictionary of predictions keyed on the operations
# that appear (`predict`, `predict_mode`, etc):
yhat_given_operation =
Dict(op=>op(mach, rows=test) for op in unique(operations))
Dict(op=>
fill_if_needed(op(mach, rows=test), X, ntest)
for op in unique(operations))
# Note: `fill_if_needed(yhat, X, n) = yhat` in the typical case that `X` is
# different from `nothing`.

ytest = selectrows(y, test)
if per_observation_flag
55 changes: 0 additions & 55 deletions test/interface/model_api.jl
@@ -2,10 +2,7 @@ module TestModelAPI

using Test
using MLJBase
using StatisticalMeasures
import MLJModelInterface
using ..Models
using Distributions
using StableRNGs

rng = StableRNG(661)
@@ -30,57 +27,5 @@ rng = StableRNG(661)
@test predict_mode(rgs, fitresult, X)[1] == 3
end

mutable struct UnivariateFiniteFitter <: MLJModelInterface.Probabilistic
alpha::Float64
end
UnivariateFiniteFitter(;alpha=1.0) = UnivariateFiniteFitter(alpha)

@testset "models that fit a distribution" begin
function MLJModelInterface.fit(model::UnivariateFiniteFitter,
verbosity, X, y)

α = model.alpha
N = length(y)
_classes = classes(y)
d = length(_classes)

frequency_given_class = Distributions.countmap(y)
prob_given_class =
Dict(c => (frequency_given_class[c] + α)/(N + α*d) for c in _classes)

fitresult = MLJBase.UnivariateFinite(prob_given_class)

report = (params=Distributions.params(fitresult),)
cache = nothing

verbosity > 0 && @info "Fitted a $fitresult"

return fitresult, cache, report
end

MLJModelInterface.predict(model::UnivariateFiniteFitter,
fitresult,
X) = fitresult


MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) =
Nothing
MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) =
AbstractVector{<:Finite}

y = coerce(collect("aabbccaa"), Multiclass)
X = nothing
model = UnivariateFiniteFitter(alpha=0)
mach = machine(model, X, y)
fit!(mach, verbosity=0)

ytest = y[1:3]
yhat = predict(mach, nothing) # single UnivariateFinite distribution

@test cross_entropy(fill(yhat, 3), ytest) ≈
mean([-log(1/2), -log(1/2), -log(1/4)])

end

end
true
76 changes: 76 additions & 0 deletions test/resampling.jl
@@ -6,6 +6,8 @@ import Tables
@everywhere import StatisticalMeasures.StatisticalMeasuresBase as API
using StatisticalMeasures
import LearnAPI
import CategoricalDistributions
import MLJModelInterface

@everywhere begin
using .Models
@@ -1001,4 +1003,78 @@ MLJBase.save(logger::DummyLogger, mach::Machine) = mach.model
@test MLJBase.save(mach) == model
end


# # RESAMPLING FOR DENSITY ESTIMATORS

# we define a density estimator to fit a `UnivariateFinite` distribution to some
# Categorical data, with a Laplace smoothing option, α.

mutable struct UnivariateFiniteFitter <: MLJModelInterface.Probabilistic
alpha::Float64
end
UnivariateFiniteFitter(;alpha=1.0) = UnivariateFiniteFitter(alpha)

function MLJModelInterface.fit(model::UnivariateFiniteFitter,
verbosity, X, y)

α = model.alpha
N = length(y)
_classes = classes(y)
d = length(_classes)

frequency_given_class = Distributions.countmap(y)
prob_given_class =
Dict(c => (get(frequency_given_class, c, 0) + α)/(N + α*d) for c in _classes)

fitresult = CategoricalDistributions.UnivariateFinite(prob_given_class)

report = (params=Distributions.params(fitresult),)
cache = nothing

verbosity > 0 && @info "Fitted a $fitresult"

return fitresult, cache, report
end
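
# A worked sketch of the smoothing formula above (hypothetical data): for
# y = [a, a, b], so N = 3, with α = 1 and d = 2 classes, the smoothed
# probabilities are P(a) = (2 + 1)/(3 + 1*2) = 0.6 and
# P(b) = (1 + 1)/(3 + 1*2) = 0.4; with α = 0 they reduce to the raw
# frequencies 2/3 and 1/3.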

MLJModelInterface.predict(model::UnivariateFiniteFitter,
fitresult,
X) = fitresult


MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) =
Nothing
MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) =
AbstractVector{<:Finite}

@testset "resampling for density estimators" begin
y = coerce(rand(StableRNG(123), "abc", 20), Multiclass)
X = nothing

train, test = partition(eachindex(y), 0.8)

# this model type is defined above in this file
model = UnivariateFiniteFitter(alpha=0)

mach = machine(model, X, y)
fit!(mach, rows=train, verbosity=0)

ytest = y[test]
yhat = predict(mach, nothing) # single UnivariateFinite distribution

# Estimate out-of-sample loss. Notice we have to make duplicate versions of `yhat`, to
# match the number of ground truth observations with which we are pairing it ("cone"
# construction):
by_hand = log_loss(fill(yhat, length(ytest)), ytest)

# test some behaviour on which the implementation of `evaluate` for density estimators
# is predicated:
@test isnothing(selectrows(X, 1:3))
@test predict(mach, rows=1:3) ≈ yhat

# evaluate has an internal "cone" construction when `X = nothing`, so this should just
# work:
e = evaluate(model, X, y, resampling=[(train, test)], measure=log_loss)
@test e.measurement[1] ≈ by_hand
end

true