Skip to content
5 changes: 3 additions & 2 deletions lib/OptimizationOptimisers/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ version = "0.3.12"
[deps]
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -19,6 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[compat]
julia = "1.10"
Expand All @@ -29,4 +30,4 @@ Optimisers = "0.2, 0.3, 0.4"
Reexport = "1.2"

[targets]
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote"]
test = ["ComponentArrays", "ForwardDiff", "Lux", "MLDataDevices", "MLUtils", "Random", "Test", "Zygote", "Printf"]
118 changes: 57 additions & 61 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module OptimizationOptimisers

using Reexport, Printf, ProgressLogging
using Reexport, ProgressLogging, UUIDs
@reexport using Optimisers, OptimizationBase
using SciMLBase

Expand Down Expand Up @@ -95,77 +95,73 @@ function SciMLBase.__solve(cache::OptimizationBase.OptimizationCache{
gevals = 0
t0 = time()
breakall = false
begin
for epoch in 1:epochs
if breakall
break
progress_id = uuid4()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about using a random UUID here. That as a symbol will intern. What's the reasoning?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String or symbol? because we've never used a UUID there before, at least a string one. It at least before was a symbol, and we would get memory leaks if it was unique. So check if it's always a string?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for epoch in 1:epochs, d in data
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
iterations += 1
fevals += 1
gevals += 1
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
iterations += 1
fevals += 2
gevals += 1
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
iterations += 1
fevals += 1
gevals += 1
else
cache.f.grad(G, θ)
x = cache.f(θ)
iterations += 1
fevals += 2
gevals += 1
end
opt_state = OptimizationBase.OptimizationState(
iter = iterations,
u = θ,
p = d,
objective = x[1],
grad = G,
original = state)
breakall = cache.callback(opt_state, x...)
if !(breakall isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
elseif breakall
break
end
cache.progress &&
@info ProgressLogging.Progress(progress_id, iterations / maxiters;
name = "loss: $(round(first(first(x)); digits=3))")

if cache.solver_args.save_best
if first(x)[1] < first(min_err)[1] #found a better solution
min_opt = opt
min_err = x
min_θ = copy(θ)
end
for (i, d) in enumerate(data)
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
iterations += 1
fevals += 1
gevals += 1
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
iterations += 1
fevals += 2
gevals += 1
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
iterations += 1
fevals += 1
gevals += 1
else
cache.f.grad(G, θ)
x = cache.f(θ)
iterations += 1
fevals += 2
gevals += 1
end
opt_state = OptimizationBase.OptimizationState(
iter = i + (epoch - 1) * length(data),
if iterations == length(data) * epochs #Last iter, revert to best.
opt = min_opt
x = min_err
θ = min_θ
cache.f.grad(G, θ, d)
opt_state = OptimizationBase.OptimizationState(iter = iterations,
u = θ,
p = d,
objective = x[1],
grad = G,
original = state)
breakall = cache.callback(opt_state, x...)
if !(breakall isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
elseif breakall
break
end
msg = @sprintf("loss: %.3g", first(x)[1])
cache.progress && ProgressLogging.@logprogress msg iterations/maxiters

if cache.solver_args.save_best
if first(x)[1] < first(min_err)[1] #found a better solution
min_opt = opt
min_err = x
min_θ = copy(θ)
end
if iterations == length(data) * epochs #Last iter, revert to best.
opt = min_opt
x = min_err
θ = min_θ
cache.f.grad(G, θ, d)
opt_state = OptimizationBase.OptimizationState(iter = iterations,
u = θ,
p = d,
objective = x[1],
grad = G,
original = state)
breakall = cache.callback(opt_state, x...)
break
end
end
state, θ = Optimisers.update(state, θ, G)
break
end
end
state, θ = Optimisers.update(state, θ, G)
end

cache.progress && @info ProgressLogging.Progress(progress_id; done = true)
t1 = time()
stats = OptimizationBase.OptimizationStats(; iterations,
time = t1 - t0, fevals, gevals)
Expand Down
Loading