Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/adaptation/Adaptation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct NaiveHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAd
pc::M
ssa::Tssa
end
function Base.show(io::IO, ::MIME"text/plain", a::NaiveHMCAdaptor)
function Base.show(io::IO, a::NaiveHMCAdaptor)
return print(io, "NaiveHMCAdaptor(pc=", a.pc, ", ssa=", a.ssa, ")")
end

Expand Down
6 changes: 3 additions & 3 deletions src/adaptation/massmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ end

struct UnitMassMatrix{T<:AbstractFloat} <: MassMatrixAdaptor end

function Base.show(io::IO, mime::MIME"text/plain", ::UnitMassMatrix{T}) where {T}
function Base.show(io::IO, ::UnitMassMatrix{T}) where {T}
return print(io, "UnitMassMatrix{", T, "} adaptor")
end

Expand Down Expand Up @@ -93,7 +93,7 @@ mutable struct WelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T},V<:AbstractVec
end
end

function Base.show(io::IO, mime::MIME"text/plain", ::WelfordVar{T}) where {T}
function Base.show(io::IO, ::WelfordVar{T}) where {T}
return print(io, "WelfordVar{", T, "} adaptor")
end

Expand Down Expand Up @@ -194,7 +194,7 @@ mutable struct WelfordCov{F<:AbstractFloat,C<:AbstractMatrix{F}} <: DenseMatrixE
cov::C
end

function Base.show(io::IO, mime::MIME"text/plain", ::WelfordCov{T}) where {T}
function Base.show(io::IO, ::WelfordCov{T}) where {T}
return print(io, "WelfordCov{", T, "} adaptor")
end

Expand Down
19 changes: 18 additions & 1 deletion src/adaptation/stan_adaptor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function initialize!(
return nothing
end

function Base.show(io::IO, mime::MIME"text/plain", state::StanHMCAdaptorState)
function Base.show(io::IO, state::StanHMCAdaptorState)
print(io, "window(", state.window_start, ", ", state.window_end, "), window_splits(")
join(io, state.window_splits, ", ")
return print(io, ")")
Expand All @@ -66,6 +66,23 @@ struct StanHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractAda
window_size::Int
state::StanHMCAdaptorState
end

function Base.show(io::IO, a::StanHMCAdaptor)
return print(
io,
"StanHMCAdaptor(",
a.pc,
", ",
a.ssa,
"; init_buffer=",
a.init_buffer,
", term_buffer=",
a.term_buffer,
", window_size=",
a.window_size,
")",
)
end
function Base.show(io::IO, mime::MIME"text/plain", a::StanHMCAdaptor)
return print(
io,
Expand Down
21 changes: 19 additions & 2 deletions src/adaptation/stepsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ getϵ(ss::StepSizeAdaptor) = ss.state.ϵ
struct FixedStepSize{T<:AbstractScalarOrVec{<:AbstractFloat}} <: StepSizeAdaptor
ϵ::T
end
function Base.show(io::IO, mime::MIME"text/plain", a::FixedStepSize)
function Base.show(io::IO, a::FixedStepSize)
return print(io, "FixedStepSize adaptor with step size ", a.ϵ)
end

Expand All @@ -86,7 +86,7 @@ getϵ(fss::FixedStepSize) = fss.ϵ
struct ManualSSAdaptor{T<:AbstractScalarOrVec{<:AbstractFloat}} <: StepSizeAdaptor
state::MSSState{T}
end
function Base.show(io::IO, mime::MIME"text/plain", a::ManualSSAdaptor{T}) where {T}
function Base.show(io::IO, a::ManualSSAdaptor{T}) where {T}
return print(io, "ManualSSAdaptor{", T, "} with step size of ", a.state.ϵ)
end

Expand Down Expand Up @@ -119,6 +119,23 @@ struct NesterovDualAveraging{T<:AbstractFloat,S<:AbstractScalarOrVec{T}} <: Step
δ::T
state::DAState{S}
end

function Base.show(io::IO, a::NesterovDualAveraging)
print(
io,
"NesterovDualAveraging(",
a.γ,
", ",
a.t_0,
", ",
a.κ,
", ",
a.δ,
", ",
a.state.ϵ,
")",
)
end
function Base.show(io::IO, mime::MIME"text/plain", a::NesterovDualAveraging{T}) where {T}
return print(
io,
Expand Down
2 changes: 1 addition & 1 deletion src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ struct Hamiltonian{M<:AbstractMetric,K<:AbstractKinetic,Tlogπ,T∂logπ∂θ}
ℓπ::Tlogπ
∂ℓπ∂θ::T∂logπ∂θ
end
function Base.show(io::IO, mime::MIME"text/plain", h::Hamiltonian)
function Base.show(io::IO, h::Hamiltonian)
return print(
io,
"Hamiltonian with ",
Expand Down
8 changes: 4 additions & 4 deletions src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ struct Leapfrog{T<:AbstractScalarOrVec{<:AbstractFloat}} <: AbstractLeapfrog{T}
"Step size."
ϵ::T
end
function Base.show(io::IO, mime::MIME"text/plain", l::Leapfrog)
return print(io, "Leapfrog with step size ϵ=", round.(l.ϵ; sigdigits=3), ")")
function Base.show(io::IO, l::Leapfrog)
return print(io, "Leapfrog with step size ϵ=", round.(l.ϵ; sigdigits=3))
end
integrator_eltype(i::AbstractLeapfrog{T}) where {T<:AbstractFloat} = T

Expand Down Expand Up @@ -120,7 +120,7 @@ end

JitteredLeapfrog(ϵ0, jitter) = JitteredLeapfrog(ϵ0, jitter, ϵ0)

function Base.show(io::IO, mime::MIME"text/plain", l::JitteredLeapfrog)
function Base.show(io::IO, l::JitteredLeapfrog)
return print(
io,
"JitteredLeapfrog with step size ",
Expand Down Expand Up @@ -178,7 +178,7 @@ struct TemperedLeapfrog{FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} <: Abstrac
α::FT
end

function Base.show(io::IO, mime::MIME"text/plain", l::TemperedLeapfrog)
function Base.show(io::IO, l::TemperedLeapfrog)
return print(
io,
"TemperedLeapfrog with step size ϵ=",
Expand Down
14 changes: 13 additions & 1 deletion src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ abstract type AbstractMetric end

_string_M⁻¹(mat::AbstractMatrix, n_chars::Int=32) = _string_M⁻¹(diag(mat), n_chars)
function _string_M⁻¹(vec::AbstractVector, n_chars::Int=32)
s_vec = string(vec)
s_vec = repr(vec; context=(:compact => true))
l = length(s_vec)
s_dots = " ...]"
n_diag_chars = n_chars - length(s_dots)
Expand All @@ -33,6 +33,10 @@ renew(ue::UnitEuclideanMetric, M⁻¹) = UnitEuclideanMetric(M⁻¹, ue.size)
Base.eltype(::UnitEuclideanMetric{T}) where {T} = T
Base.size(e::UnitEuclideanMetric) = e.size
Base.size(e::UnitEuclideanMetric, dim::Int) = e.size[dim]

function Base.show(io::IO, uem::UnitEuclideanMetric{T}) where {T}
print(io, "UnitEuclideanMetric(", T, ", ", uem.size, ")")
end
function Base.show(io::IO, ::MIME"text/plain", uem::UnitEuclideanMetric{T}) where {T}
return print(
io,
Expand Down Expand Up @@ -66,6 +70,10 @@ renew(ue::DiagEuclideanMetric, M⁻¹) = DiagEuclideanMetric(M⁻¹)

Base.eltype(::DiagEuclideanMetric{T}) where {T} = T
Base.size(e::DiagEuclideanMetric, dim...) = size(e.M⁻¹, dim...)

function Base.show(io::IO, dem::DiagEuclideanMetric)
print(io, "DiagEuclideanMetric(", _string_M⁻¹(dem.M⁻¹), ")")
end
function Base.show(io::IO, ::MIME"text/plain", dem::DiagEuclideanMetric{T}) where {T}
return print(
io,
Expand Down Expand Up @@ -110,6 +118,10 @@ renew(ue::DenseEuclideanMetric, M⁻¹) = DenseEuclideanMetric(M⁻¹)

Base.eltype(::DenseEuclideanMetric{T}) where {T} = T
Base.size(e::DenseEuclideanMetric, dim...) = size(e._temp, dim...)

function Base.show(io::IO, dem::DenseEuclideanMetric)
print(io, "DenseEuclideanMetric(", _string_M⁻¹(dem.M⁻¹), ")")
end
function Base.show(io::IO, ::MIME"text/plain", dem::DenseEuclideanMetric{T}) where {T}
return print(
io,
Expand Down
6 changes: 3 additions & 3 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ struct SliceTS{F<:AbstractFloat,P<:PhasePoint} <: AbstractTrajectorySampler
n::Int
end

function Base.show(io::IO, mime::MIME"text/plain", s::SliceTS)
function Base.show(io::IO, s::SliceTS)
return print(
io,
"SliceTS with slice variable ℓu=",
Expand Down Expand Up @@ -225,7 +225,7 @@ end

ConstructionBase.constructorof(::Type{<:Trajectory{TS}}) where {TS} = Trajectory{TS}

function Base.show(io::IO, mime::MIME"text/plain", τ::Trajectory{TS}) where {TS}
function Base.show(io::IO, τ::Trajectory{TS}) where {TS}
return print(
io,
"Trajectory{",
Expand Down Expand Up @@ -482,7 +482,7 @@ struct Termination
numerical::Bool
end

function Base.show(io::IO, mime::MIME"text/plain", d::Termination)
function Base.show(io::IO, d::Termination)
return print(
io, "Termination reasons of (dynamic=", d.dynamic, ", numerical=", d.numerical, ")"
)
Expand Down
Loading