Skip to content

Commit 5788885

Browse files
committed
Merge branch 'master' of https://github.com/JuliaPOMDP/POMDPs.jl
2 parents e9ea12c + 563697c commit 5788885

File tree

8 files changed

+133
-43
lines changed

8 files changed

+133
-43
lines changed

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ These functions should return *states*, *observations*, and/or *rewards*.
7272
```@docs
7373
gen
7474
initialstate
75+
initialobs
7576
```
7677

7778
### [Common](@id common_api)

docs/src/generative.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ The *generative* interface consists of two functions:
1919
The generative interface is typically used when it is easier to return sampled states and observations rather than explicit distributions as in the [Explicit interface](@ref explicit_doc).
2020
This type of model is often referred to as a "black-box" model.
2121

22+
In some special cases (e.g. reinforcement learning with [RLInterface.jl](https://github.com/JuliaPOMDP/RLInterface.jl)), an initial observation is needed before any actions are taken. In this case, the [`initialobs`](@ref) function will be used.
23+
2224
## The [`gen`](@ref) function
2325

2426
The [`gen`](@ref) function has three versions differentiated by the type of the first argument.

src/POMDPs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export
3232
# Generative model functions
3333
gen,
3434
initialstate,
35+
initialobs,
3536

3637
# Discrete Functions
3738
length,

src/errors.jl

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ struct DistributionNotImplemented <: Exception
33
gen_firstarg::Type
44
func::Function
55
modeltype::Type
6-
dep_argtypes::NamedTuple
6+
dep_argtypes::AbstractVector
77
end
88

99
function Base.showerror(io::IO, ex::DistributionNotImplemented)
@@ -16,17 +16,20 @@ function Base.showerror(io::IO, ex::DistributionNotImplemented)
1616

1717
i = 1
1818
if ex.gen_firstarg <: DDNOut
19-
printstyled(io, "$i) Implement POMDPs.gen($argstring, ::AbstractRNG) to return a NamedTuple with key :$(ex.sym).\n", bold=true)
20-
gen_analysis(io, ex)
19+
M = ex.modeltype
20+
S = statetype(M)
21+
A = actiontype(M)
22+
printstyled(io, "$i) Implement POMDPs.gen(::$M, ::$S, ::$A, ::AbstractRNG) to return a NamedTuple with key :$(ex.sym).\n", bold=true)
23+
gen_analysis(io, ex.sym, M, [S,A])
2124
println(io)
2225
i += 1
2326
end
2427
printstyled(io, "$i) Implement POMDPs.gen(::DDNNode{:$(ex.sym)}, $argstring, ::AbstractRNG).\n",
2528
bold=true)
26-
Base.show_method_candidates(io, MethodError(gen, Tuple{DDNNode{ex.sym}, ex.modeltype, ex.dep_argtypes..., AbstractRNG})) # this is not exported - it may break
29+
try_show_method_candidates(io, MethodError(gen, Tuple{DDNNode{ex.sym}, ex.modeltype, ex.dep_argtypes..., AbstractRNG}))
2730
i += 1
2831
printstyled(io, "\n\n$i) Implement $(ex.func)($argstring).\n", bold=true)
29-
Base.show_method_candidates(io, MethodError(transition, Tuple{ex.modeltype, ex.dep_argtypes...}))
32+
try_show_method_candidates(io, MethodError(ex.func, Tuple{ex.modeltype, ex.dep_argtypes...}))
3033

3134
println(io, "\n\nThis error message uses heuristics to make recommendations for POMDPs.jl problem implementers. If it was misleading or you believe there is an inconsistency, please file an issue: https://github.com/JuliaPOMDP/POMDPs.jl/issues/new")
3235
end
@@ -36,63 +39,78 @@ function distribution_impl_error(sym, func, modeltype, dep_argtypes)
3639
acceptable = (:distribution_impl_error, nameof(func), nameof(gen), nameof(genout))
3740
gen_firstarg = nothing # The first argument to the `gen` call that is furthest down in the stack trace
3841

39-
for sf in stacktrace() # step up the stack trace
42+
try
43+
for sf in stacktrace() # step up the stack trace
4044

41-
# if it is a macro from ddn_struct.jl or gen_impl.jl it is ok
42-
if sf.func === Symbol("macro expansion")
43-
bn = basename(String(sf.file))
44-
if !(bn in ["ddn_struct.jl", "gen_impl.jl", "none"])
45-
break
46-
# the call stack includes a macro from some other package
47-
end
45+
# if it is a macro from ddn_struct.jl or gen_impl.jl it is ok
46+
if sf.func === Symbol("macro expansion")
47+
bn = basename(String(sf.file))
48+
if !(bn in ["ddn_struct.jl", "gen_impl.jl", "none"])
49+
break
50+
# the call stack includes a macro from some other package
51+
end
4852

49-
# if it is not a function we know about, give up
50-
elseif !(sf.func in acceptable)
51-
break
53+
# if it is not a function we know about, give up
54+
elseif !(sf.func in acceptable)
55+
break
5256

53-
# if it is gen, check to see if it's the DDNNode version
54-
elseif sf.func === nameof(gen)
55-
sig = sf.linfo.def.sig
56-
if sig isa UnionAll &&
57-
sig.body.parameters[1] == typeof(gen) &&
58-
sig.body.parameters[2] <: Union{DDNNode, DDNOut}
59-
# bingo!
60-
gen_firstarg = sig.body.parameters[2]
57+
# if it is gen, check to see if it's the DDNNode version
58+
elseif sf.func === nameof(gen)
59+
sig = sf.linfo.def.sig
60+
if sig isa UnionAll &&
61+
sig.body.parameters[1] == typeof(gen) &&
62+
sig.body.parameters[2] <: Union{DDNNode, DDNOut}
63+
# bingo!
64+
gen_firstarg = sig.body.parameters[2]
65+
dep_argtypes = [sig.body.parameters[3:end-1]...]
66+
end
6167
end
6268
end
69+
catch ex
70+
@debug("Error throwing DistributionNotImplemented error:\n$(sprint(showerror, ex))")
71+
throw(MethodError(func, Tuple{modeltype, dep_argtypes...}))
6372
end
6473

6574
if gen_firstarg === nothing
66-
throw(MethodError(transition, Tuple{modeltype, dep_argtypes...}))
75+
throw(MethodError(func, Tuple{modeltype, dep_argtypes...}))
6776
else
68-
throw(DistributionNotImplemented(:sp, gen_firstarg, func, modeltype, dep_argtypes))
77+
throw(DistributionNotImplemented(sym, gen_firstarg, func, modeltype, dep_argtypes))
6978
end
7079
end
7180

72-
function gen_analysis(io, ex::DistributionNotImplemented)
73-
argtypes = Tuple{ex.modeltype, ex.dep_argtypes..., AbstractRNG}
81+
function gen_analysis(io, sym::Symbol, modeltype::Type, dep_argtypes)
82+
argtypes = Tuple{modeltype, dep_argtypes..., AbstractRNG}
7483
rts = Base.return_types(gen, argtypes)
75-
@assert length(rts) > 0 # there should always be the default NamedTuple() impl.
76-
if length(rts) == 1
84+
if length(rts) <= 0 # there should always be the default NamedTuple() impl.
85+
@debug("Error analyzing the return types for gen. Please submit an issue at https://github.com/JuliaPOMDP/POMDPs.jl/issues/new", argtypes=argtypes, rts=rts)
86+
elseif length(rts) == 1
7787
rt = first(rts)
7888
if rt == typeof(NamedTuple()) && !implemented(gen, argtypes)
79-
Base.show_method_candidates(io, MethodError(gen, argtypes))
89+
try_show_method_candidates(io, MethodError(gen, argtypes))
8090
println(io)
8191
else
82-
println(io, "\nThis method was implemented and the return type was inferred to be $rt. Is this type always a NamedTuple with key :$(ex.sym)?")
92+
println(io, "\nThis method was implemented and the return type was inferred to be $rt. Is this type always a NamedTuple with key :$(sym)?")
8393
end
8494
else
8595
println(io, "(POMDPs.jl could not determine if this method was implemented correctly. [Base.return_types(gen, argtypes) = $(rts)])")
8696
end
8797
end
8898

89-
transition(m, s, a) = distribution_impl_error(:sp, transition, typeof(m), (s=typeof(s), a=typeof(a)))
99+
function try_show_method_candidates(io, args...)
100+
try
101+
Base.show_method_candidates(io, args...) # this isn't exported, so it might break
102+
catch ex
103+
@debug("Unable to show method candidates. Please submit an issue at https://github.com/JuliaPOMDP/POMDPs.jl/issues/new.\n$(sprint(showerror, ex))")
104+
end
105+
end
106+
107+
transition(m, s, a) = distribution_impl_error(:sp, transition, typeof(m), [typeof(s), typeof(a)])
90108
function implemented(t::typeof(transition), TT::TupleType)
91109
m = which(t, TT)
92110
return m.module != POMDPs # see if this was implemented by a user elsewhere
93111
end
94112

95-
observation(m, sp) = distribution_impl_error(:o, observation, typeof(m), (sp=typeof(sp),))
113+
observation(m, sp) = distribution_impl_error(:o, observation, typeof(m), [typeof(sp)])
96114
function implemented(o::typeof(observation), TT::Type{Tuple{M, SP}}) where {M<:POMDP, SP}
97115
m = which(o, TT)
98116
return m.module != POMDPs

src/generative.jl

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,59 @@ end
101101
return rand(rng, d)
102102
end
103103

104+
# it is technically illegal to call this within the generated function
104105
if implemented(initialstate_distribution, Tuple{p})
105106
return impl
106107
else
107-
req = @req initialstate_distribution(::p)
108-
reqs = [(implemented(req...), req...)]
109-
this = @req(initialstate(::p, ::rng))
110108
return quote
111109
try
112110
$impl # trick to get the compiler to insert the right backedges
113111
catch
114-
# TODO failed_synth_warning($this, $reqs)
115112
throw(MethodError(initialstate, (p, rng)))
116113
end
117114
end
118115
end
119116
end
117+
118+
"""
119+
initialobs(m::POMDP, s, rng::AbstractRNG)
120+
121+
Return a sampled initial observation for the problem `m` and state `s`.
122+
123+
This function is only used in cases where the policy expects an initial observation rather than an initial belief, e.g. in a reinforcement learning setting. It is not used in a standard POMDP simulation.
124+
125+
By default, it will fall back to `observation(m, s)`. The random number generator `rng` should be used to draw this sample (e.g. use `rand(rng)` instead of `rand()`).
126+
"""
127+
function initialobs end
128+
129+
function implemented(f::typeof(initialobs), TT::Type)
130+
if !hasmethod(f, TT)
131+
return false
132+
end
133+
m = which(f, TT)
134+
if m.module == POMDPs && !implemented(observation, Tuple{TT.parameters[1:2]...})
135+
return false
136+
else
137+
return true
138+
end
139+
end
140+
141+
@generated function initialobs(m::POMDP, s, rng::AbstractRNG)
142+
impl = quote
143+
d = observation(m, s)
144+
return rand(rng, d)
145+
end
146+
147+
# it is technically illegal to call this within the generated function
148+
if implemented(observation, Tuple{m, s})
149+
return impl
150+
else
151+
return quote
152+
try
153+
$impl # trick to get the compiler to insert the right backedges
154+
catch
155+
throw(MethodError(initialobs, (m, s, rng)))
156+
end
157+
end
158+
end
159+
end

test/test_generative.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,29 @@ import POMDPs: transition, reward, initialstate_distribution
22
import POMDPs: gen
33

44
struct W <: POMDP{Int, Bool, Int} end
5+
@test !@implemented initialstate(::W, ::typeof(Random.GLOBAL_RNG))
6+
@test !@implemented initialstate(::W, ::typeof(Random.GLOBAL_RNG), ::Nothing) # wrong number args
57
@test_throws MethodError initialstate(W(), Random.GLOBAL_RNG)
8+
@test !@implemented initialobs(::W, ::Int, ::typeof(Random.GLOBAL_RNG))
9+
@test !@implemented initialobs(::W, ::Int, ::typeof(Random.GLOBAL_RNG), ::Nothing) # wrong number args
10+
@test_throws MethodError initialobs(W(), 1, Random.GLOBAL_RNG)
611
@test_throws DistributionNotImplemented gen(DDNNode(:sp), W(), 1, true, Random.GLOBAL_RNG)
12+
try
13+
gen(DDNNode(:sp), W(), 1, true, Random.GLOBAL_RNG)
14+
catch ex
15+
str = sprint(showerror, ex)
16+
@test occursin(":sp", str)
17+
@test occursin("transition", str)
18+
end
719
@test_throws DistributionNotImplemented gen(DDNOut(:sp,:r), W(), 1, true, Random.GLOBAL_RNG)
820
@test_throws DistributionNotImplemented gen(DDNNode(:o), W(), 1, true, 2, Random.GLOBAL_RNG)
21+
try
22+
gen(DDNNode(:o), W(), 1, true, 2, Random.GLOBAL_RNG)
23+
catch ex
24+
str = sprint(showerror, ex)
25+
@test occursin(":o", str)
26+
@test occursin("observation", str)
27+
end
928
@test_throws DistributionNotImplemented gen(DDNOut(:sp,:o), W(), 1, true, Random.GLOBAL_RNG)
1029
@test_throws DistributionNotImplemented gen(DDNOut(:sp,:o,:r), W(), 1, true, Random.GLOBAL_RNG)
1130
POMDPs.gen(::W, ::Int, ::Bool, ::AbstractRNG) = nothing
@@ -34,7 +53,11 @@ gen(::DDNNode{:o}, b::B, s::Int, a::Bool, sp::Int, rng::AbstractRNG) = sp
3453
@test @inferred gen(DDNOut(:sp,:o,:r), B(), 1, true, Random.GLOBAL_RNG) == (2, 2, -1.0)
3554

3655
initialstate_distribution(b::B) = Int[1,2,3]
56+
@test @implemented initialstate(::B, ::MersenneTwister)
3757
@test initialstate(B(), Random.GLOBAL_RNG) in initialstate_distribution(B())
58+
POMDPs.observation(b::B, s::Int) = Bool[s]
59+
@test @implemented initialobs(::B, ::Int, ::MersenneTwister)
60+
@test initialobs(B(), 1, Random.GLOBAL_RNG) == 1
3861

3962
mutable struct C <: POMDP{Nothing, Nothing, Nothing} end
4063
gen(::DDNNode{:sp}, c::C, s::Nothing, a::Nothing, rng::AbstractRNG) = nothing
@@ -57,6 +80,5 @@ struct GE <: MDP{Int, Int} end
5780
@test_throws DistributionNotImplemented gen(DDNNode(:sp), GE(), 1, 1, Random.GLOBAL_RNG)
5881
@test_throws DistributionNotImplemented gen(DDNOut(:sp,:r), GE(), 1, 1, Random.GLOBAL_RNG)
5982
POMDPs.gen(::GE, s, a, ::AbstractRNG) = (sp=s+a, r=s^2)
60-
@show gen(DDNOut(:sp), GE(), 1, 1, Random.GLOBAL_RNG)
6183
@test @inferred gen(DDNOut(:sp), GE(), 1, 1, Random.GLOBAL_RNG) == 2
6284
@test @inferred gen(DDNOut(:sp,:r), GE(), 1, 1, Random.GLOBAL_RNG) == (2, 1)

test/test_generative_backedges.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ let
1010
@test_throws DistributionNotImplemented gen(DDNOut(:sp,:o,:r), M(), 1, 1, MersenneTwister(4))
1111
@test_throws DistributionNotImplemented gen(DDNOut(:sp,:r), M(), 1, 1, MersenneTwister(4))
1212
POMDPs.reward(::M, ::Int, ::Int, ::Int) = 0.0
13-
POMDPs.gen(::DDNNode{:o}, ::M, ::Int, ::Int, ::Int, ::AbstractRNG) = `a`
13+
POMDPs.gen(::DDNNode{:o}, ::M, ::Int, ::Int, ::Int, ::AbstractRNG) = 'a'
1414
@test gen(DDNOut(:sp,:r), M(), 1, 1, MersenneTwister(4)) == (1, 0.0)
15-
@test gen(DDNOut(:sp,:o,:r), M(), 1, 1, MersenneTwister(4)) == (1, `a`, 0.0)
15+
@test gen(DDNOut(:sp,:o,:r), M(), 1, 1, MersenneTwister(4)) == (1, 'a', 0.0)
16+
17+
@test_throws MethodError initialobs(M(), 1, MersenneTwister(4))
18+
POMDPs.observation(::M, ::Int) = ['a']
19+
@test initialobs(M(), 1, MersenneTwister(4)) == 'a'
1620
end

test/test_requirements.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ end
6464

6565
reqs = nothing # to check the hygeine of the macro
6666
println("There should be a warning about no @reqs here:")
67+
# 27 minutes has been spent trying to suppress this warning and automate a test for it. If you work more on it, please update this counter. The following things have been tried
68+
# - @test_logs (:warn, "No") @POMDP_requirements ...
69+
# - @capture_err @POMDP_requirements ... # From Suppressor.jl
70+
# - @capture_out @POMDP_requirements ... # From Suppressor.jl
6771
@POMDP_requirements "Warn none" begin
6872
1+1
6973
end
@@ -90,6 +94,4 @@ POMDPs.observations(p::SimplePOMDP) = [1,2,3]
9094
Random.rand(rng::AbstractRNG, d::SimpleDistribution) = sample(rng, d.ss, WeightVec(d.b))
9195
POMDPs.gen(::DDNOut{:o}, m::SimplePOMDP, s, a, rng) = 1
9296

93-
println("There should be no warnings or requirements output below this point!\n")
94-
9597
@test solve(CoolSolver(), SimplePOMDP())

0 commit comments

Comments
 (0)