Skip to content

Commit 4b1f024

Browse files
Merge pull request #980 from ChrisRackauckas/fix-formatting
Apply JuliaFormatter to fix code formatting
2 parents 29e6f50 + ebb7d49 commit 4b1f024

File tree

6 files changed

+128
-117
lines changed

6 files changed

+128
-117
lines changed

docs/src/examples/augmented_neural_ode.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ function plot_contour(model, ps, st, npoints = 300)
6161
x = range(-4.0f0, 4.0f0; length = npoints)
6262
y = range(-4.0f0, 4.0f0; length = npoints)
6363
for x1 in x, x2 in y
64+
6465
grid_points[:, idx] .= [x1, x2]
6566
idx += 1
6667
end
@@ -212,6 +213,7 @@ function plot_contour(model, ps, st, npoints = 300)
212213
x = range(-4.0f0, 4.0f0; length = npoints)
213214
y = range(-4.0f0, 4.0f0; length = npoints)
214215
for x1 in x, x2 in y
216+
215217
grid_points[:, idx] .= [x1, x2]
216218
idx += 1
217219
end

docs/src/examples/multiple_shooting.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,15 @@ function loss_function(data, pred)
6262
return sum(abs2, data - pred)
6363
end
6464
65-
l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
65+
l1,
66+
preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
6667
Tsit5(), group_size; continuity_term)
6768
6869
function loss_multiple_shooting(p)
6970
ps = ComponentArray(p, pax)
7071
71-
loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
72+
loss,
73+
currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
7274
Tsit5(), group_size; continuity_term)
7375
global preds = currpred
7476
return loss
@@ -93,7 +95,8 @@ function callback(state, l; doplot = true, prob_node = prob_node)
9395
# plot the original data
9496
plt = scatter(tsteps, ode_data[1, :]; label = "Data")
9597
# plot the different predictions for individual shoot
96-
l1, preds = multiple_shoot(
98+
l1,
99+
preds = multiple_shoot(
97100
ComponentArray(state.u, pax), ode_data, tsteps, prob_node, loss_function,
98101
Tsit5(), group_size; continuity_term)
99102
plot_multiple_shoot(plt, preds, group_size)
@@ -127,7 +130,8 @@ pd, pax = getdata(ps), getaxes(ps)
127130
128131
function loss_single_shooting(p)
129132
ps = ComponentArray(p, pax)
130-
loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
133+
loss,
134+
currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
131135
Tsit5(), group_size; continuity_term)
132136
global preds = currpred
133137
return loss

docs/src/examples/neural_ode_weather_forecast.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing;
134134
p === nothing && (p = p_new)
135135
state === nothing && (state = state_new)
136136
137-
p, state = train_one_round(node, p, state, y, OptimizationOptimisers.AdamW(lr),
137+
p,
138+
state = train_one_round(node, p, state, y, OptimizationOptimisers.AdamW(lr),
138139
maxiters, rng; callback = log_results(ps, losses), kwargs...)
139140
end
140141
ps, state, losses

test/cnf_tests.jl

Lines changed: 107 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,25 @@ export callback
2020
end
2121

2222
@testitem "Smoke test for FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
23-
nn = Chain(Dense(1, 1, tanh))
24-
tspan = (0.0f0, 1.0f0)
25-
ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
26-
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
27-
ps = ComponentArray(ps)
23+
nn=Chain(Dense(1, 1, tanh))
24+
tspan=(0.0f0, 1.0f0)
25+
ffjord_mdl=FFJORD(nn, tspan, (1,), Tsit5())
26+
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
27+
ps=ComponentArray(ps)
2828

29-
data_dist = Beta(2.0f0, 2.0f0)
30-
train_data = Float32.(rand(data_dist, 1, 100))
29+
data_dist=Beta(2.0f0, 2.0f0)
30+
train_data=Float32.(rand(data_dist, 1, 100))
3131

3232
function loss(model, θ)
33-
logpx, λ₁, λ₂ = model(train_data, θ)
33+
logpx, λ₁, λ₂=model(train_data, θ)
3434
return -mean(logpx)
3535
end
3636

3737
@testset "ADType: $(adtype)" for adtype in (Optimization.AutoForwardDiff(),
3838
Optimization.AutoReverseDiff(), Optimization.AutoTracker(),
3939
Optimization.AutoZygote(), Optimization.AutoFiniteDiff())
40-
@testset "regularize = $(regularize) & monte_carlo = $(monte_carlo)" for regularize in (
40+
@testset "regularize = $(regularize) & monte_carlo = $(monte_carlo)" for regularize in
41+
(
4142
true, false),
4243
monte_carlo in (true, false)
4344

@@ -54,177 +55,177 @@ end
5455
end
5556

5657
@testitem "Smoke test for FFJORDDistribution (sampling & pdf)" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
57-
nn = Chain(Dense(1, 1, tanh))
58-
tspan = (0.0f0, 1.0f0)
59-
ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
60-
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
61-
ps = ComponentArray(ps)
58+
nn=Chain(Dense(1, 1, tanh))
59+
tspan=(0.0f0, 1.0f0)
60+
ffjord_mdl=FFJORD(nn, tspan, (1,), Tsit5())
61+
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
62+
ps=ComponentArray(ps)
6263

63-
regularize = false
64-
monte_carlo = false
64+
regularize=false
65+
monte_carlo=false
6566

66-
data_dist = Beta(2.0f0, 2.0f0)
67-
train_data = Float32.(rand(data_dist, 1, 100))
67+
data_dist=Beta(2.0f0, 2.0f0)
68+
train_data=Float32.(rand(data_dist, 1, 100))
6869

6970
function loss(model, θ)
70-
logpx, λ₁, λ₂ = model(train_data, θ)
71+
logpx, λ₁, λ₂=model(train_data, θ)
7172
return -mean(logpx)
7273
end
7374

74-
adtype = Optimization.AutoZygote()
75+
adtype=Optimization.AutoZygote()
7576

76-
st_ = (; st..., regularize, monte_carlo)
77-
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
77+
st_=(; st..., regularize, monte_carlo)
78+
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
7879

79-
optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
80-
optprob = Optimization.OptimizationProblem(optf, ps)
81-
res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
80+
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
81+
optprob=Optimization.OptimizationProblem(optf, ps)
82+
res=Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
8283

83-
ffjord_d = FFJORDDistribution(ffjord_mdl, res.u, st_)
84+
ffjord_d=FFJORDDistribution(ffjord_mdl, res.u, st_)
8485

8586
@test !isnothing(pdf(ffjord_d, train_data))
8687
@test !isnothing(rand(ffjord_d))
8788
@test !isnothing(rand(ffjord_d, 10))
8889
end
8990

9091
@testitem "Test for default base distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
91-
nn = Chain(Dense(1, 1, tanh))
92-
tspan = (0.0f0, 1.0f0)
93-
ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
94-
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
95-
ps = ComponentArray(ps)
92+
nn=Chain(Dense(1, 1, tanh))
93+
tspan=(0.0f0, 1.0f0)
94+
ffjord_mdl=FFJORD(nn, tspan, (1,), Tsit5())
95+
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
96+
ps=ComponentArray(ps)
9697

97-
regularize = false
98-
monte_carlo = false
98+
regularize=false
99+
monte_carlo=false
99100

100-
data_dist = Beta(7.0f0, 7.0f0)
101-
train_data = Float32.(rand(data_dist, 1, 100))
102-
test_data = Float32.(rand(data_dist, 1, 100))
101+
data_dist=Beta(7.0f0, 7.0f0)
102+
train_data=Float32.(rand(data_dist, 1, 100))
103+
test_data=Float32.(rand(data_dist, 1, 100))
103104

104105
function loss(model, θ)
105-
logpx, λ₁, λ₂ = model(train_data, θ)
106+
logpx, λ₁, λ₂=model(train_data, θ)
106107
return -mean(logpx)
107108
end
108109

109-
adtype = Optimization.AutoZygote()
110-
st_ = (; st..., regularize, monte_carlo)
111-
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
110+
adtype=Optimization.AutoZygote()
111+
st_=(; st..., regularize, monte_carlo)
112+
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
112113

113-
optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
114-
optprob = Optimization.OptimizationProblem(optf, ps)
115-
res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
114+
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
115+
optprob=Optimization.OptimizationProblem(optf, ps)
116+
res=Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
116117

117-
actual_pdf = pdf.(data_dist, test_data)
118-
learned_pdf = exp.(model(test_data, res.u)[1])
118+
actual_pdf=pdf.(data_dist, test_data)
119+
learned_pdf=exp.(model(test_data, res.u)[1])
119120

120121
@test ps != res.u
121122
@test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.9
122123
end
123124

124125
@testitem "Test for alternative base distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
125-
nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
126-
tspan = (0.0f0, 1.0f0)
127-
ffjord_mdl = FFJORD(
126+
nn=Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
127+
tspan=(0.0f0, 1.0f0)
128+
ffjord_mdl=FFJORD(
128129
nn, tspan, (1,), Tsit5(); basedist = MvNormal([0.0f0], Diagonal([4.0f0])))
129-
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
130-
ps = ComponentArray(ps)
130+
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
131+
ps=ComponentArray(ps)
131132

132-
regularize = false
133-
monte_carlo = false
133+
regularize=false
134+
monte_carlo=false
134135

135-
data_dist = Normal(6.0f0, 0.7f0)
136-
train_data = Float32.(rand(data_dist, 1, 100))
137-
test_data = Float32.(rand(data_dist, 1, 100))
136+
data_dist=Normal(6.0f0, 0.7f0)
137+
train_data=Float32.(rand(data_dist, 1, 100))
138+
test_data=Float32.(rand(data_dist, 1, 100))
138139

139140
function loss(model, θ)
140-
logpx, λ₁, λ₂ = model(train_data, θ)
141+
logpx, λ₁, λ₂=model(train_data, θ)
141142
return -mean(logpx)
142143
end
143144

144-
adtype = Optimization.AutoZygote()
145-
st_ = (; st..., regularize, monte_carlo)
146-
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
145+
adtype=Optimization.AutoZygote()
146+
st_=(; st..., regularize, monte_carlo)
147+
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
147148

148-
optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
149-
optprob = Optimization.OptimizationProblem(optf, ps)
150-
res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 30)
149+
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
150+
optprob=Optimization.OptimizationProblem(optf, ps)
151+
res=Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 30)
151152

152-
actual_pdf = pdf.(data_dist, test_data)
153-
learned_pdf = exp.(model(test_data, res.u)[1])
153+
actual_pdf=pdf.(data_dist, test_data)
154+
learned_pdf=exp.(model(test_data, res.u)[1])
154155

155156
@test ps != res.u
156157
@test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25
157158
end
158159

159160
@testitem "Test for multivariate distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
160-
nn = Chain(Dense(2, 2, tanh))
161-
tspan = (0.0f0, 1.0f0)
162-
ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5())
163-
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
164-
ps = ComponentArray(ps)
161+
nn=Chain(Dense(2, 2, tanh))
162+
tspan=(0.0f0, 1.0f0)
163+
ffjord_mdl=FFJORD(nn, tspan, (2,), Tsit5())
164+
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
165+
ps=ComponentArray(ps)
165166

166-
regularize = false
167-
monte_carlo = false
167+
regularize=false
168+
monte_carlo=false
168169

169-
μ = ones(Float32, 2)
170-
Σ = Diagonal([7.0f0, 7.0f0])
171-
data_dist = MvNormal(μ, Σ)
172-
train_data = Float32.(rand(data_dist, 100))
173-
test_data = Float32.(rand(data_dist, 100))
170+
μ=ones(Float32, 2)
171+
Σ=Diagonal([7.0f0, 7.0f0])
172+
data_dist=MvNormal(μ, Σ)
173+
train_data=Float32.(rand(data_dist, 100))
174+
test_data=Float32.(rand(data_dist, 100))
174175

175176
function loss(model, θ)
176-
logpx, λ₁, λ₂ = model(train_data, θ)
177+
logpx, λ₁, λ₂=model(train_data, θ)
177178
return -mean(logpx)
178179
end
179180

180-
adtype = Optimization.AutoZygote()
181-
st_ = (; st..., regularize, monte_carlo)
182-
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
181+
adtype=Optimization.AutoZygote()
182+
st_=(; st..., regularize, monte_carlo)
183+
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
183184

184-
optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
185-
optprob = Optimization.OptimizationProblem(optf, ps)
186-
res = Optimization.solve(
185+
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
186+
optprob=Optimization.OptimizationProblem(optf, ps)
187+
res=Optimization.solve(
187188
optprob, Adam(0.01); callback = callback(adtype), maxiters = 30)
188189

189-
actual_pdf = pdf(data_dist, test_data)
190-
learned_pdf = exp.(model(test_data, res.u)[1])
190+
actual_pdf=pdf(data_dist, test_data)
191+
learned_pdf=exp.(model(test_data, res.u)[1])
191192

192193
@test ps != res.u
193194
@test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25
194195
end
195196

196197
@testitem "Test for multivariate distribution and FFJORD with regularizers" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
197-
nn = Chain(Dense(2, 2, tanh))
198-
tspan = (0.0f0, 1.0f0)
199-
ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5())
200-
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
201-
ps = ComponentArray(ps) .* 0.001f0
198+
nn=Chain(Dense(2, 2, tanh))
199+
tspan=(0.0f0, 1.0f0)
200+
ffjord_mdl=FFJORD(nn, tspan, (2,), Tsit5())
201+
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
202+
ps=ComponentArray(ps) .* 0.001f0
202203

203-
regularize = true
204-
monte_carlo = true
204+
regularize=true
205+
monte_carlo=true
205206

206-
μ = ones(Float32, 2)
207-
Σ = Diagonal([7.0f0, 7.0f0])
208-
data_dist = MvNormal(μ, Σ)
209-
train_data = Float32.(rand(data_dist, 100))
210-
test_data = Float32.(rand(data_dist, 100))
207+
μ=ones(Float32, 2)
208+
Σ=Diagonal([7.0f0, 7.0f0])
209+
data_dist=MvNormal(μ, Σ)
210+
train_data=Float32.(rand(data_dist, 100))
211+
test_data=Float32.(rand(data_dist, 100))
211212

212213
function loss(model, θ)
213-
logpx, λ₁, λ₂ = model(train_data, θ)
214+
logpx, λ₁, λ₂=model(train_data, θ)
214215
return -mean(logpx)
215216
end
216217

217-
adtype = Optimization.AutoZygote()
218-
st_ = (; st..., regularize, monte_carlo)
219-
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
218+
adtype=Optimization.AutoZygote()
219+
st_=(; st..., regularize, monte_carlo)
220+
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
220221

221-
optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
222-
optprob = Optimization.OptimizationProblem(optf, ps)
223-
res = Optimization.solve(
222+
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
223+
optprob=Optimization.OptimizationProblem(optf, ps)
224+
res=Optimization.solve(
224225
optprob, Adam(0.01); callback = callback(adtype), maxiters = 30)
225226

226-
actual_pdf = pdf(data_dist, test_data)
227-
learned_pdf = exp.(model(test_data, res.u)[1])
227+
actual_pdf=pdf(data_dist, test_data)
228+
learned_pdf=exp.(model(test_data, res.u)[1])
228229

229230
@test ps != res.u
230231
@test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25

test/multiple_shoot_tests.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424
(
2525
name = "Multi-D Test Config",
2626
u0 = Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0],
27-
ode_func = (du, u, p, t) -> (du .= ((u .^ 3).*[-0.01 0.02; -0.02 -0.01; 0.01 -0.05])),
27+
ode_func = (
28+
du, u, p, t) -> (du .= ((u .^ 3) .* [-0.01 0.02; -0.02 -0.01; 0.01 -0.05])),
2829
nn = Chain(x -> x .^ 3, Dense(3 => 3, tanh)),
29-
u0s_ensemble = [Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]]
30+
u0s_ensemble = [
31+
Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]]
3032
)
3133
]
3234

@@ -158,7 +160,8 @@
158160
group_size = 3
159161
continuity_term = 200
160162
function loss_multiple_shooting_ens(p)
161-
return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
163+
return multiple_shoot(
164+
p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
162165
loss_function, Tsit5(), group_size; continuity_term,
163166
trajectories, abstol = 1e-8, reltol = 1e-6)[1]
164167
end

0 commit comments

Comments
 (0)