Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion ext/StatsFunsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ ChainRulesCore.@scalar_rule(
(α - 1) / x + (1 - β) / (1 - x),
),
)
ChainRulesCore.@scalar_rule(
betalogupdf(α::Real, β::Real, x::Number),
(
log(x),
log1p(-x),
(α - 1) / x + (1 - β) / (1 - x),
),
)

ChainRulesCore.@scalar_rule(
binomlogpdf(n::Real, p::Real, k::Real),
Expand All @@ -22,12 +30,28 @@ ChainRulesCore.@scalar_rule(
ChainRulesCore.NoTangent(),
),
)
ChainRulesCore.@scalar_rule(
binomlogupdf(n::Real, p::Real, k::Real),
(
ChainRulesCore.NoTangent(),
(k / p - n) / (1 - p),
ChainRulesCore.NoTangent(),
),
)

ChainRulesCore.@scalar_rule(
chisqlogpdf(k::Real, x::Number),
@setup(hk = k / 2),
(
(log(x) - logtwo - digamma(hk)) / 2,
(log(x / 2) - digamma(hk)) / 2,
(hk - 1) / x - one(hk) / 2,
),
)
ChainRulesCore.@scalar_rule(
chisqlogupdf(k::Real, x::Number),
@setup(hk = k / 2),
(
log(x / 2) / 2,
(hk - 1) / x - one(hk) / 2,
),
)
Expand All @@ -47,6 +71,19 @@ ChainRulesCore.@scalar_rule(
((ν1 - 2) / x - ν1 * νsum / temp1) / 2,
),
)
ChainRulesCore.@scalar_rule(
fdistlogupdf(ν1::Real, ν2::Real, x::Number),
@setup(
tmp = x * ν1 + ν2,
a = x * (ν1 + ν2) / tmp,
b = ν2 / tmp,
),
(
(log(x * b) - a) / 2,
(log(b) + (ν1 / ν2) * a) / 2,
(ν1 * (1 - a) - 2) / (2 * x),
),
)

ChainRulesCore.@scalar_rule(
gammalogpdf(k::Real, θ::Real, x::Number),
Expand All @@ -61,11 +98,32 @@ ChainRulesCore.@scalar_rule(
- (1 + z) / x,
),
)
ChainRulesCore.@scalar_rule(
gammalogupdf(k::Real, θ::Real, x::Number),
@setup(
invθ = inv(θ),
xoθ = invθ * x,
z = xoθ - (k - 1),
),
(
log(xoθ),
invθ * z,
- z / x,
),
)

ChainRulesCore.@scalar_rule(
poislogpdf(λ::Number, x::Number),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, ChainRulesCore.NoTangent()),
)
ChainRulesCore.@scalar_rule(
poislogupdf(λ::Number, x::Number),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ), ChainRulesCore.NoTangent()),
)
ChainRulesCore.@scalar_rule(
poislogulikelihood(λ::Number, x::Number),
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, ChainRulesCore.NoTangent()),
)

ChainRulesCore.@scalar_rule(
tdistlogpdf(ν::Real, x::Number),
Expand All @@ -81,5 +139,17 @@ ChainRulesCore.@scalar_rule(
- x * b,
),
)
ChainRulesCore.@scalar_rule(
tdistlogupdf(ν::Real, x::Number),
@setup(
xsq = x^2,
a = xsq / ν,
b = (ν + 1) / (ν + xsq),
),
(
(a * b - log1p(a)) / 2,
- x * b,
),
)

end # module
32 changes: 32 additions & 0 deletions src/StatsFuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ export
# distrs/beta
betapdf, # pdf of beta distribution
betalogpdf, # logpdf of beta distribution
betalogupdf, # unnormalized logpdf of beta distribution (parameters constant)
betalogulikelihood, # unnormalized logpdf of beta distribution (data constant)
betacdf, # cdf of beta distribution
betaccdf, # ccdf of beta distribution
betalogcdf, # logcdf of beta distribution
Expand All @@ -67,6 +69,8 @@ export
# distrs/binom
binompdf, # pdf of binomial distribution
binomlogpdf, # logpdf of binomial distribution
binomlogupdf, # unnormalized logpdf of binomial distribution (parameters constant)
binomlogulikelihood, # unnormalized logpdf of binomial distribution (data constant)
binomcdf, # cdf of binomial distribution
binomccdf, # ccdf of binomial distribution
binomlogcdf, # logcdf of binomial distribution
Expand All @@ -79,6 +83,8 @@ export
# distrs/chisq
chisqpdf, # pdf of chi-square distribution
chisqlogpdf, # logpdf of chi-square distribution
chisqlogupdf, # unnormalized logpdf of chi-square distribution (parameters constant)
chisqlogulikelihood, # unnormalized logpdf of chi-square distribution (data constant)
chisqcdf, # cdf of chi-square distribution
chisqccdf, # ccdf of chi-square distribution
chisqlogcdf, # logcdf of chi-square distribution
Expand All @@ -91,6 +97,8 @@ export
# distrs/fdist
fdistpdf, # pdf of F distribution
fdistlogpdf, # logpdf of F distribution
fdistlogupdf, # unnormalized logpdf of F distribution (parameters constant)
fdistlogulikelihood, # unnormalized logpdf of F distribution (data constant)
fdistcdf, # cdf of F distribution
fdistccdf, # ccdf of F distribution
fdistlogcdf, # logcdf of F distribution
Expand All @@ -103,6 +111,8 @@ export
# distrs/gamma
gammapdf, # pdf of gamma distribution
gammalogpdf, # logpdf of gamma distribution
gammalogupdf, # unnormalized logpdf of gamma distribution (parameters constant)
gammalogulikelihood, # unnormalized logpdf of gamma distribution (data constant)
gammacdf, # cdf of gamma distribution
gammaccdf, # ccdf of gamma distribution
gammalogcdf, # logcdf of gamma distribution
Expand All @@ -115,6 +125,8 @@ export
# distrs/hyper
hyperpdf, # pdf of hypergeometric distribution
hyperlogpdf, # logpdf of hypergeometric distribution
hyperlogupdf, # unnormalized logpdf of hypergeometric distribution (parameters constant)
hyperlogulikelihood, # unnormalized logpdf of hypergeometric distribution (data constant)
hypercdf, # cdf of hypergeometric distribution
hyperccdf, # ccdf of hypergeometric distribution
hyperlogcdf, # logcdf of hypergeometric distribution
Expand All @@ -127,6 +139,8 @@ export
# distrs/nbeta
nbetapdf, # pdf of noncentral beta distribution
nbetalogpdf, # logpdf of noncentral beta distribution
nbetalogupdf, # unnormalized logpdf of noncentral beta distribution (parameters constant)
nbetalogulikelihood, # unnormalized logpdf of noncentral beta distribution (data constant)
nbetacdf, # cdf of noncentral beta distribution
nbetaccdf, # ccdf of noncentral beta distribution
nbetalogcdf, # logcdf of noncentral beta distribution
Expand All @@ -139,6 +153,8 @@ export
# distrs/nbinom
nbinompdf, # pdf of negative nbinomial distribution
nbinomlogpdf, # logpdf of negative nbinomial distribution
nbinomlogupdf, # unnormalized logpdf of negative nbinomial distribution (parameters constant)
nbinomlogulikelihood, # unnormalized logpdf of negative nbinomial distribution (data constant)
nbinomcdf, # cdf of negative nbinomial distribution
nbinomccdf, # ccdf of negative nbinomial distribution
nbinomlogcdf, # logcdf of negative nbinomial distribution
Expand All @@ -151,6 +167,8 @@ export
# distrs/nchisq
nchisqpdf, # pdf of noncentral chi-square distribution
nchisqlogpdf, # logpdf of noncentral chi-square distribution
nchisqlogupdf, # unnormalized logpdf of noncentral chi-square distribution (parameters constant)
nchisqlogulikelihood, # unnormalized logpdf of noncentral chi-square distribution (data constant)
nchisqcdf, # cdf of noncentral chi-square distribution
nchisqccdf, # ccdf of noncentral chi-square distribution
nchisqlogcdf, # logcdf of noncentral chi-square distribution
Expand All @@ -163,6 +181,8 @@ export
# distrs/nfdist
nfdistpdf, # pdf of noncentral F distribution
nfdistlogpdf, # logpdf of noncentral F distribution
nfdistlogupdf, # unnormalized logpdf of noncentral F distribution (parameters constant)
nfdistlogulikelihood, # unnormalized logpdf of noncentral F distribution (data constant)
nfdistcdf, # cdf of noncentral F distribution
nfdistccdf, # ccdf of noncentral F distribution
nfdistlogcdf, # logcdf of noncentral F distribution
Expand All @@ -175,6 +195,8 @@ export
# distrs/norm
normpdf, # pdf of normal distribution
normlogpdf, # logpdf of normal distribution
normlogupdf, # unnormalized logpdf of normal distribution (parameters constant)
normlogulikelihood, # unnormalized logpdf of normal distribution (data constant)
normcdf, # cdf of normal distribution
normccdf, # ccdf of normal distribution
normlogcdf, # logcdf of normal distribution
Expand All @@ -187,6 +209,8 @@ export
# distrs/ntdist
ntdistpdf, # pdf of noncentral t distribution
ntdistlogpdf, # logpdf of noncentral t distribution
ntdistlogupdf, # unnormalized logpdf of noncentral t distribution (parameters constant)
ntdistlogulikelihood, # unnormalized logpdf of noncentral t distribution (data constant)
ntdistcdf, # cdf of noncentral t distribution
ntdistccdf, # ccdf of noncentral t distribution
ntdistlogcdf, # logcdf of noncentral t distribution
Expand All @@ -199,6 +223,8 @@ export
# distrs/pois
poispdf, # pdf of Poisson distribution
poislogpdf, # logpdf of Poisson distribution
poislogupdf, # unnormalized logpdf of Poisson distribution (parameters constant)
poislogulikelihood, # unnormalized logpdf of Poisson distribution (data constant)
poiscdf, # cdf of Poisson distribution
poisccdf, # ccdf of Poisson distribution
poislogcdf, # logcdf of Poisson distribution
Expand All @@ -211,6 +237,8 @@ export
# distrs/tdist
tdistpdf, # pdf of student's t distribution
tdistlogpdf, # logpdf of student's t distribution
tdistlogupdf, # unnormalized logpdf of student's t distribution (parameters constant)
tdistlogulikelihood, # unnormalized logpdf of student's t distribution (data constant)
tdistcdf, # cdf of student's t distribution
tdistccdf, # ccdf of student's t distribution
tdistlogcdf, # logcdf of student's t distribution
Expand All @@ -223,6 +251,8 @@ export
# distrs/signrank
signrankpdf,
signranklogpdf,
signranklogupdf,
signranklogulikelihood,
signrankcdf,
signranklogcdf,
signrankccdf,
Expand All @@ -245,6 +275,8 @@ export
# distrs/wilcox
wilcoxpdf,
wilcoxlogpdf,
wilcoxlogupdf,
wilcoxlogulikelihood,
wilcoxcdf,
wilcoxlogcdf,
wilcoxccdf,
Expand Down
10 changes: 9 additions & 1 deletion src/distrs/beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,20 @@ betapdf(α::Real, β::Real, x::Real) = exp(betalogpdf(α, β, x))

betalogpdf(α::Real, β::Real, x::Real) = betalogpdf(promote(α, β, x)...)
function betalogpdf(α::T, β::T, x::T) where {T <: Real}
logupdf = betalogupdf(α, β, x)
return isfinite(logupdf) ? logupdf - logbeta(α, β) : logupdf
end

betalogupdf(α::Real, β::Real, x::Real) = betalogupdf(promote(α, β, x)...)
function betalogupdf(α::T, β::T, x::T) where {T <: Real}
# we ensure that `log(x)` and `log1p(-x)` do not error
y = clamp(x, 0, 1)
val = xlogy(α - 1, y) + xlog1py(β - 1, -y) - logbeta(α, β)
val = xlogy(α - 1, y) + xlog1py(β - 1, -y)
return x < 0 || x > 1 ? oftype(val, -Inf) : val
end

betalogulikelihood(α::Real, β::Real, x::Real) = betalogpdf(α, β, x)

function betacdf(α::Real, β::Real, x::Real)
# Handle degenerate cases
if iszero(α) && β > 0
Expand Down
14 changes: 13 additions & 1 deletion src/distrs/binom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,23 @@ binompdf(n::Real, p::Real, k::Real) = exp(binomlogpdf(n, p, k))

binomlogpdf(n::Real, p::Real, k::Real) = binomlogpdf(promote(n, p, k)...)
function binomlogpdf(n::T, p::T, k::T) where {T <: Real}
logupdf = binomlogupdf(n, p, k)
if isfinite(logupdf)
return min(0, logupdf - log(n + 1))
else
return logupdf
end
end

binomlogupdf(n::Real, p::Real, k::Real) = binomlogupdf(promote(n, p, k)...)
function binomlogupdf(n::T, p::T, k::T) where {T <: Real}
m = clamp(k, 0, n)
val = min(0, betalogpdf(m + 1, n - m + 1, p) - log(n + 1))
val = betalogpdf(m + 1, n - m + 1, p)
return 0 <= k <= n && isinteger(k) ? val : oftype(val, -Inf)
end

binomlogulikelihood(n::Real, p::Real, k::Real) = binomlogpdf(n, p, k)

for l in ("", "log"), compl in (false, true)
fbinom = Symbol(string("binom", l, ifelse(compl, "c", ""), "cdf"))
fbeta = Symbol(string("beta", l, ifelse(compl, "", "c"), "cdf"))
Expand Down
10 changes: 9 additions & 1 deletion src/distrs/chisq.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
# functions related to chi-square distribution

# Just use the Gamma definitions
for f in ("pdf", "logpdf", "cdf", "ccdf", "logcdf", "logccdf", "invcdf", "invccdf", "invlogcdf", "invlogccdf")
for f in ("pdf", "logpdf", "logupdf", "cdf", "ccdf", "logcdf", "logccdf", "invcdf", "invccdf", "invlogcdf", "invlogccdf")
_chisqf = Symbol("chisq" * f)
_gammaf = Symbol("gamma" * f)
@eval begin
$(_chisqf)(k::Real, x::Real) = $(_chisqf)(promote(k, x)...)
$(_chisqf)(k::T, x::T) where {T <: Real} = $(_gammaf)(k / 2, 2, x)
end
end

chisqlogulikelihood(k::Real, x::Real) = chisqlogulikelihood(promote(k, x)...)
function chisqlogulikelihood(k::T, x::T) where {T <: Real}
y = max(x, 0)
k2 = k / 2
val = xlogy(k2, x / 2) - loggamma(k2)
return x < 0 ? oftype(val, -Inf) : val
end
21 changes: 21 additions & 0 deletions src/distrs/fdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,27 @@ function fdistlogpdf(ν1::T, ν2::T, x::T) where {T <: Real}
return x < 0 ? oftype(val, -Inf) : val
end

fdistlogupdf(ν1::Real, ν2::Real, x::Real) = fdistlogupdf(promote(ν1, ν2, x)...)
function fdistlogupdf(ν1::T, ν2::T, x::T) where {T <: Real}
# we ensure that `log(x)` does not error if `x < 0`
y = max(x, 0)
val = (xlogy(ν1 - 2, y) - xlogy(ν1 + ν2, ν1 * y + ν2)) / 2
return x < 0 ? oftype(val, -Inf) : val
end

fdistloguloglikelihood(ν1::Real, ν2::Real, x::Real) = fdistlogulikelihood(promote(ν1, ν2, x)...)
function fdistlogulikelihood(ν1::T, ν2::T, x::T) where {T}
# we ensure that `log(x)` does not error if `x < 0`
y = max(x, 0)
tmp = ν1 * y + ν2
a = ν1 / tmp
b = ν2 / tmp
halfν1 = ν1 / 2
halfν2 = ν2 / 2
val = (xlogy(halfν1, a) + xlogy(halfν2, b)) - logbeta(halfν1, halfν2)
return x < 0 ? oftype(val, -Inf) : val
end

for f in ("cdf", "ccdf", "logcdf", "logccdf")
ff = Symbol("fdist" * f)
bf = Symbol("beta" * f)
Expand Down
26 changes: 22 additions & 4 deletions src/distrs/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,31 @@ gammapdf(k::Real, θ::Real, x::Real) = exp(gammalogpdf(k, θ, x))

gammalogpdf(k::Real, θ::Real, x::Real) = gammalogpdf(promote(k, θ, x)...)
function gammalogpdf(k::T, θ::T, x::T) where {T <: Real}
logupdf = gammalogupdf(k, θ, x)
return isfinite(logupdf) ? logupdf - loggamma(k) - k * log(θ) : logupdf
end

gammalogupdf(k::Real, θ::Real, x::Real) = gammalogupdf(promote(k, θ, x)...)
function gammalogupdf(k::T, θ::T, x::T) where {T <: Real}
# we ensure that `log(x)` does not error if `x < 0`
= max(x, 0) / θ
val = -loggamma(k) - log(θ) - xθ
y = max(x, 0)
val = -float(y / θ)
# xlogy(k - 1, xθ) - xθ -> -∞ for xθ -> ∞ so we only add the first term
# when it's safe
if isfinite(xθ)
val += xlogy(k - 1, xθ)
if isfinite(val)
val += xlogy(k - 1, y)
end
return x < 0 ? oftype(val, -Inf) : val
end

function gammalogulikelihood(k::Real, θ::Real, x::Real)
# we ensure that `log(x)` does not error if `x < 0`
xθ = max(x, 0) / θ
val = - xθ - loggamma(k)
# xlogy(k, xθ) - xθ -> -∞ for xθ -> ∞ so we only add the first term
# when it's safe
if isfinite(val)
val += xlogy(k, xθ)
end
return x < 0 ? oftype(val, -Inf) : val
end
Expand Down
4 changes: 4 additions & 0 deletions src/distrs/hyper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ using .RFunctions:
hyperinvccdf,
hyperinvlogcdf,
hyperinvlogccdf


hyperlogupdf(ms::Real, mf::Real, n::Real, x::Real) = hyperlogpdf(ms, mf, n, x)
hyperlogulikelihood(ms::Real, mf::Real, n::Real, x::Real) = hyperlogpdf(ms, mf, n, x)
Loading
Loading