Edit: See this important issue.
The measures part of MLJBase could do with some TLC. It is not the shiniest part of the MLJ code base: it was written in a bit of a hurry, because nothing much could go forward without something in place and the existing packages came up short.
I think the API is more-or-less fine, but the way things are implemented is less than ideal, leading to:
(i) code redundancy
(ii) reduced functionality: measures that could support weights, or could implement reports_each_observation, don't.
Recall that if a measure m reports_each_observation, then m(v1, v2) returns a vector of measurements, one per observation; otherwise it returns a single scalar. So it doesn't really make sense for auc, for example, to report each observation (and it doesn't). However, mae should (but doesn't).
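To make the distinction concrete, here is a minimal sketch using plain functions rather than MLJBase's measure types. The function names are hypothetical; they show the same absolute-error loss once reporting each observation (as mae would if it reported each observation) and once returning a single aggregated scalar.

```julia
# Hypothetical illustration, not MLJBase code: one loss, exposed two ways.
yhat = [1.0, 2.0, 4.0]
y    = [1.0, 3.0, 2.0]

# reports_each_observation == true: one measurement per observation
per_observation(yhat, y) = abs.(yhat .- y)

# reports_each_observation == false: a single aggregated scalar (the mean)
aggregated(yhat, y) = sum(abs.(yhat .- y)) / length(y)

per_observation(yhat, y)  # → [0.0, 1.0, 2.0]
aggregated(yhat, y)       # → 1.0
```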
I propose we make the following assumption that will allow us to resolve these issues for the majority of measures:
If a measure m reports_each_observation, then it is understood that its measurements are the values of some scalar version m(s1, s2) on each pair of observations, and that these are aggregated by taking their sum or mean.
For such measures, then, we need only implement the scalar method m(s1, s2) and we can generate the other methods m(v1, v2), m(v1, v2, w) automatically.
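A minimal, self-contained sketch of this scheme follows. `Measure`, `MAE` and the trait definitions here are hypothetical stand-ins rather than MLJBase's actual definitions: the measure author writes only the scalar method, and the vector and weighted methods come from generic fallbacks that broadcast it.

```julia
using Statistics: mean

# hypothetical stand-ins for the measure type hierarchy and trait
abstract type Measure end
reports_each_observation(::Measure) = false   # default trait value

struct MAE <: Measure end
reports_each_observation(::MAE) = true

# the only method the measure author writes: the scalar version m(s1, s2)
(::MAE)(yhat::Real, y::Real) = abs(yhat - y)

# generic fallbacks, written once for all measures
function (m::Measure)(yhat::AbstractVector, y::AbstractVector)
    reports_each_observation(m) ||
        error("$(typeof(m)) must implement m(yhat, y) explicitly")
    return m.(yhat, y)   # broadcast the scalar method over observations
end
(m::Measure)(yhat::AbstractVector, y::AbstractVector, w::AbstractVector) =
    w .* m(yhat, y)      # weighted per-observation measurements

mae = MAE()
yhat, y = [1.0, 2.0, 4.0], [1.0, 3.0, 2.0]
mae(yhat, y)                   # → [0.0, 1.0, 2.0]
mae(yhat, y, [1.0, 1.0, 0.5])  # → [0.0, 1.0, 1.0]
mean(mae(yhat, y))             # → 1.0, the aggregated mean absolute error
```

A measure that does not report each observation falls through to the error unless it defines its own, more specific, vector method, which is exactly the situation described for auc and the rms family.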
For other measures, such as auc and the rms family, m(v1, v2) (and optionally m(v1, v2, w)) must be explicitly implemented, as at present.
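By contrast, rms is the square root of a mean, so it is not the sum or mean of any per-observation quantity and its vector method has to be written by hand. The sketch below uses a hypothetical `RMS` type; the weighted version normalises by `sum(w)`, which is an assumed convention, not necessarily the one MLJBase uses.

```julia
abstract type Measure end
struct RMS <: Measure end   # hypothetical stand-in for the rms family

# explicit vector method: square root of the mean squared error
(::RMS)(yhat::AbstractVector, y::AbstractVector) =
    sqrt(sum(abs2, yhat .- y) / length(y))

# optional explicit weighted method (normalising by sum(w) is assumed)
(::RMS)(yhat::AbstractVector, y::AbstractVector, w::AbstractVector) =
    sqrt(sum(w .* abs2.(yhat .- y)) / sum(w))

rms = RMS()
yhat, y = [1.0, 2.0, 4.0], [1.0, 3.0, 2.0]
rms(yhat, y)                        # sqrt(5/3) ≈ 1.291
sum(abs.(yhat .- y)) / length(y)    # 1.0: mean of per-observation sqrt(e^2)
```

The last line shows why the scalar route fails here: averaging the per-observation root squared errors (that is, the absolute errors) gives 1.0, not the rms of about 1.291.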
In addition to the docs, there is a lot about the measure design in this discussion.
Details
To "automatically generate" the extra methods, we could do something like this:
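# fallbacks for measures (`Measure` and the traits `reports_each_observation`
# and `supports_weights` are assumed to be defined elsewhere)
(m::Measure)(yhat, y::AbstractVector) = _eval(Val(reports_each_observation(m)), m, yhat, y)
(m::Measure)(yhat, y::AbstractVector, w) = _eval(Val(reports_each_observation(m)), m, yhat, y, w)
# measures that do not report each observation must implement m(yhat, y) themselves
_eval(::Val{false}, m, args...) = error("$(typeof(m)) must implement m(yhat, y) explicitly")
# otherwise broadcast the scalar method m(s1, s2) over the observations, returning
# (optionally weighted) per-observation measurements; aggregating them (sum or
# mean, per aggregation(m)) is left to the caller
_eval(::Val{true}, m, yhat, y) = m.(yhat, y)
_eval(::Val{true}, m, yhat, y, w) = w .* m.(yhat, y)
# measures that report each observation then support weights automatically
supports_weights(m::Measure) = _sw(Val(reports_each_observation(m)), m)
_sw(::Val{false}, m) = false
_sw(::Val{true}, m) = true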