Skip to content

Commit 04c9189

Browse files
authored
Merge pull request #104 from nsajko/feature_logabstanh
add `logabstanh` function
2 parents 8dee873 + 00e2d63 commit 04c9189

File tree

6 files changed

+31
-4
lines changed

6 files changed

+31
-4
lines changed

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ logistic
1616
logit
1717
logcosh
1818
logabssinh
19+
logabstanh
1920
log1psq
2021
log1pexp
2122
softplus

ext/LogExpFunctionsChainRulesCoreExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ ChainRulesCore.@scalar_rule(logistic(x::Real), (Ω * (1 - Ω),))
118118
ChainRulesCore.@scalar_rule(logit(x::Real), (inv(x * (1 - x)),))
119119
ChainRulesCore.@scalar_rule(logcosh(x::Real), tanh(x))
120120
ChainRulesCore.@scalar_rule(logabssinh(x::Real), coth(x))
121+
ChainRulesCore.@scalar_rule(logabstanh(x::Real), inv(cosh(x) * sinh(x)))
121122
ChainRulesCore.@scalar_rule(log1psq(x::Real), (2 * x / (1 + x^2),))
122123
ChainRulesCore.@scalar_rule(log1pexp(x::Real), (logistic(x),))
123124
ChainRulesCore.@scalar_rule(log1mexp(x::Real), (-exp(x - Ω),))

src/LogExpFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import LinearAlgebra
88

99
export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1,
1010
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
11-
softmax!, logcosh, logabssinh, cloglog, cexpexp,
11+
softmax!, logcosh, logabssinh, logabstanh, cloglog, cexpexp,
1212
loglogistic, logitexp, log1mlogistic, logit1mexp
1313

1414
include("basicfuns.jl")

src/basicfuns.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,22 @@ end
218218
"""
219219
$(SIGNATURES)
220220
221+
Return `log(abs(tanh(x)))`, evaluated carefully.
222+
223+
The implementation ensures `logabstanh(-x) = logabstanh(x)`.
224+
"""
225+
function logabstanh(x::Real)
226+
a = abs(x)
227+
if 8*a < 3
228+
log(tanh(a))
229+
else
230+
log1p(-2/(exp(2*a)+1))
231+
end
232+
end
233+
234+
"""
235+
$(SIGNATURES)
236+
221237
Return `log(1+x^2)` evaluated carefully for `abs(x)` very small or very large.
222238
"""
223239
log1psq(x::Real) = log1p(abs2(x))

test/basicfuns.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,29 +93,36 @@ end
9393
end
9494
end
9595

96-
@testset "logcosh and logabssinh" begin
96+
@testset "logcosh and logabssinh and logabstanh" begin
9797
for x in (randn(), randn(Float32))
9898
@test @inferred(logcosh(x)) isa typeof(x)
9999
@test logcosh(x) log(cosh(x))
100100
@test logcosh(-x) == logcosh(x)
101101
@test @inferred(logabssinh(x)) isa typeof(x)
102102
@test logabssinh(x) log(abs(sinh(x)))
103103
@test logabssinh(-x) == logabssinh(x)
104+
@test @inferred(logabstanh(x)) isa typeof(x)
105+
@test logabstanh(x) log(abs(tanh(x)))
106+
@test logabstanh(-x) == logabstanh(x)
104107
end
105108

106109
# special values
107110
for x in (-Inf, Inf, -Inf32, Inf32)
108111
@test @inferred(logcosh(x)) === oftype(x, Inf)
109112
@test @inferred(logabssinh(x)) === oftype(x, Inf)
113+
@test @inferred(logabstanh(x)) === -oftype(x, 0)
110114
end
111115
for x in (NaN, NaN32)
112116
@test @inferred(logcosh(x)) === x
113117
@test @inferred(logabssinh(x)) === x
118+
@test @inferred(logabstanh(x)) === x
114119
end
115120

116-
@testset "accuracy of `logcosh`" begin
121+
@testset "accuracy" begin
117122
for t in (Float16, Float32, Float64)
118-
@test ulp_error_maximum(logcosh, range(start = t(-3), stop = t(3), length = 1000)) < 3
123+
ran = range(start = t(-3), stop = t(3), length = 1000)
124+
@test ulp_error_maximum(logcosh, ran) < 3
125+
@test ulp_error_maximum(logabstanh, ran) < 3
119126
end
120127
end
121128
end

test/chainrules.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@
6464
test_rrule(logcosh, x)
6565
test_frule(logabssinh, x)
6666
test_rrule(logabssinh, x)
67+
test_frule(logabstanh, x)
68+
test_rrule(logabstanh, x)
6769
end
6870

6971
@testset "log1pexp" begin

0 commit comments

Comments
 (0)