[Tracking] Improvements to measures #17

@ablaom

Edit: See this important issue.

The measures part of MLJBase could do with some TLC. It is not the shiniest part of the MLJ code base: it was written in a bit of a hurry, because nothing much could go forward without something in place and the existing packages came up short.

I think the API is more-or-less fine, but the way things are implemented is less than ideal, leading to:

(i) code redundancy
(ii) less functionality: measures that could support weights or implement reports_each_observation don't

Recall that if a measure implements reports_each_observation, then m(v1, v2) returns a vector of measurements, one per observation; otherwise a single scalar is returned. So it doesn't really make sense for auc, for example, to implement reports_each_observation (which it doesn't). However, mae should (but doesn't).
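To make the distinction concrete, here is a minimal sketch on toy data. The function names `abs_errors` and `mean_abs_error` are hypothetical and are not part of MLJBase; they only illustrate the difference between a vector of per-observation measurements and a single aggregated scalar.

```julia
using Statistics  # for mean

y    = [1.0, 2.0, 4.0]   # ground truth (toy data)
yhat = [1.5, 2.0, 3.0]   # predictions (toy data)

# Hypothetical per-observation measure: one measurement per observation.
abs_errors(yhat, y) = abs.(yhat .- y)

# Hypothetical aggregated measure: a single scalar for the whole sample.
mean_abs_error(yhat, y) = mean(abs_errors(yhat, y))

per_obs = abs_errors(yhat, y)       # a vector with length(y) entries
overall = mean_abs_error(yhat, y)   # a single scalar
```

A measure like auc has no analogue of `abs_errors`, since it is only defined for the sample as a whole.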

I propose we make the following assumption that will allow us to resolve these issues for the majority of measures:

If a measure m(v1, v2) implements reports_each_observation then it is understood that it is the sum or mean value of some scalar version m(s1, s2).

For such measures, then, we need only implement the scalar method m(s1, s2) and we can generate the other methods m(v1, v2), m(v1, v2, w) automatically.

For other measures, such as auc and the rms family, m(v1, v2) (and optionally m(v1, v2, w)) must be explicitly implemented, as at present.
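A small numerical check, on made-up data, suggests why the rms family does not fit the assumption: the root of a mean is generally not the mean of any per-observation scalar. The variable names below are illustrative only.

```julia
using Statistics  # for mean

y    = [0.0, 0.0]   # ground truth (toy data)
yhat = [3.0, 4.0]   # predictions (toy data)

# Root mean square error: the square root is applied after averaging.
rms_value = sqrt(mean((yhat .- y) .^ 2))          # sqrt((9 + 16) / 2)

# Averaging a per-observation "root of square" gives the mean absolute error.
mean_of_scalar = mean(sqrt.((yhat .- y) .^ 2))    # (3 + 4) / 2

differs = !(rms_value ≈ mean_of_scalar)
```

Because the two quantities differ, an rms-type measure cannot be recovered by aggregating a scalar method with sum or mean, and needs its own vector implementation.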

In addition to the docs there is a lot about the measure design in this discussion.

Details

To "automatically generate" the extra methods, we could do something like this:

# fallbacks for measures
(m::Measure)(yhat, y::AbstractVector) = _eval(Val(reports_each_observation(m)), m, yhat, y)
(m::Measure)(yhat, y::AbstractVector, w) = _eval(Val(reports_each_observation(m)), m, yhat, y, w)

# no scalar version to fall back on: the vector methods must be implemented explicitly
_eval(::Val{false}, m, args...) = throw(MethodError(m, args))

# broadcast the scalar method m(s1, s2) over the observations, then aggregate
_eval(::Val{true}, m, yhat, y) = m.(yhat, y) |> aggregation(m)
_eval(::Val{true}, m, yhat, y, w) = (w .* m.(yhat, y)) |> aggregation(m)

supports_measures(m::Measure) = _sm(Val(reports_each_observation(m)), m)
_sm(::Val{false}, m) = false
_sm(::Val{true}, m) = true
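For experimenting with the idea, here is a self-contained variant of the fallback that runs on its own. Everything in it is an assumption made for illustration: the toy `Measure` type, the `AbsError` measure, the trait defaults, the use of `mean` as the value of `aggregation`, and the choice to raise an error when no scalar method is available. It is not the MLJBase implementation, and it requires Julia 1.3 or later, which allows call methods on abstract types.

```julia
using Statistics  # for mean

abstract type Measure end

# Hypothetical traits with conservative defaults.
reports_each_observation(::Measure) = false
aggregation(::Measure) = mean

# A toy measure defined only through its scalar method m(s1, s2).
struct AbsError <: Measure end
reports_each_observation(::AbsError) = true
(::AbsError)(s1::Real, s2::Real) = abs(s1 - s2)

# Generic vector fallbacks, dispatching on the trait value.
(m::Measure)(yhat::AbstractVector, y::AbstractVector) =
    _eval(Val(reports_each_observation(m)), m, yhat, y)
(m::Measure)(yhat::AbstractVector, y::AbstractVector, w::AbstractVector) =
    _eval(Val(reports_each_observation(m)), m, yhat, y, w)

# No scalar version available: the measure has to supply its own vector methods.
_eval(::Val{false}, m, args...) =
    error("$(typeof(m)) must implement its vector methods explicitly")

# Broadcast the scalar method over observations, then aggregate.
_eval(::Val{true}, m, yhat, y) = aggregation(m)(m.(yhat, y))
_eval(::Val{true}, m, yhat, y, w) = aggregation(m)(w .* m.(yhat, y))

# Usage on toy data.
mae  = AbsError()
yhat = [1.5, 2.0, 3.0]
y    = [1.0, 2.0, 4.0]
w    = [2.0, 1.0, 1.0]

unweighted = mae(yhat, y)      # mean of [0.5, 0.0, 1.0]
weighted   = mae(yhat, y, w)   # mean of [1.0, 0.0, 1.0]
```

Note that the weights here simply multiply the per-observation values before aggregation, following the sketch above; whether to normalise by `sum(w)` is a separate design question.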
