Skip to content

Commit 70c4fb3

Browse files
more callback args fixes
1 parent b256e1e commit 70c4fb3

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

docs/src/examples/mnist_conv_neural_ode.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,10 @@ iter = 0
104104
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
105105
opt_prob = OptimizationProblem(opt_func, ps, dataloader);
106106
107-
function callback(ps, l, pred)
107+
function callback(state, l)
108108
global iter += 1
109109
iter % 10 == 0 &&
110-
@info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
110+
@info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
111111
return false
112112
end
113113

docs/src/examples/multiple_shooting.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ group_size = 3
8686
continuity_term = 200
8787
8888
l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
89-
Tsit5(), group_size; continuity_term)
89+
Tsit5(), group_size; continuity_term)
9090
9191
function loss_function(data, pred)
9292
return sum(abs2, data - pred)
@@ -124,14 +124,16 @@ pd, pax = getdata(ps), getaxes(ps)
124124
125125
function loss_single_shooting(p)
126126
ps = ComponentArray(p, pax)
127-
return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
127+
loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
128128
Tsit5(), group_size; continuity_term)
129+
global preds = currpred
130+
return loss
129131
end
130132
131133
adtype = Optimization.AutoZygote()
132134
optf = Optimization.OptimizationFunction((x, p) -> loss_single_shooting(x), adtype)
133135
optprob = Optimization.OptimizationProblem(optf, pd)
134-
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)
136+
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000)
135137
gif(anim, "single_shooting.gif"; fps = 15)
136138
```
137139

0 commit comments

Comments
 (0)