
Commit efedbdf

more fixes
1 parent 63b7613

16 files changed: +57 −143 lines

docs/src/models/advanced.md

Lines changed: 4 additions & 1 deletion
@@ -84,7 +84,10 @@ There is a second, more severe, kind of restriction possible. This is not recomm

 Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there are multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in the machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf).

-Naively, we could have a struct that stores the weights along each path and implement the joining/splitting in the forward pass function. But that would mean a new struct any time the operations along each path change. Instead, this guide will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path.
+We could have a struct that stores the weights along each path and implement the joining/splitting in the forward pass function. That would mean a new struct for each different block,
+e.g. one would have a `TransformerBlock` struct for a transformer block and a `ResNetBlock` struct for a ResNet block, each block being composed of smaller sub-blocks. This is often the simplest and cleanest way to implement complex models.
+
+This guide will instead show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path.

 ### Multiple inputs: a custom `Join` layer

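For context, the guide referenced in the added text goes on to build exactly this kind of multi-path container. Below is an illustrative sketch, not the guide's exact code; the `Join` name, the `vcat` combiner, and the layer sizes are assumptions:

```julia
using Flux

# A container holding one sub-layer per input path, plus a function
# that combines the per-path outputs into a single result.
struct Join{T<:Tuple, F}
  combine::F
  paths::T
end
Join(combine, paths...) = Join(combine, paths)

Flux.@layer Join  # registers the sub-layers' parameters (older versions use `Flux.@functor`)

# Apply each path to its own input, then combine the results.
(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)...)
(m::Join)(xs...) = m(xs)

# Two inputs of different sizes, each with its own Dense path:
model = Join(vcat, Dense(4 => 2), Dense(3 => 2))
model(rand(Float32, 4), rand(Float32, 3))  # 4-element output vector
```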

docs/src/models/recurrence.md

Lines changed: 3 additions & 4 deletions
@@ -154,7 +154,7 @@ In such a model, only the last two outputs are used to compute the loss, hence t
 Alternatively, if one wants to perform some warmup of the sequence, it could be performed once, followed by a regular training where all the steps of the sequence are considered for the gradient update:

 ```julia
-function loss(x, y)
+function loss(m, x, y)
   sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
 end

@@ -172,9 +172,8 @@ data = zip(X,Y)
 Flux.reset!(m)
 [m(x) for x in seq_init]

-ps = Flux.params(m)
-opt = Adam(1e-3)
-Flux.train!(loss, ps, data, opt)
+opt = Flux.setup(Adam(1e-3), m)
+Flux.train!(loss, m, data, opt)
 ```

 In this previous example, the model's state is first reset with `Flux.reset!`. Then a warmup is performed over a sequence of length 1 by feeding `seq_init` to the model, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (`seq_1` and `seq_2`) and all the timestep outputs are considered for the loss.
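For anyone migrating, the explicit pattern shown in this diff can be exercised end to end along these lines. The toy model, data, and sizes below are invented for illustration and are not the recurrence example itself:

```julia
using Flux

# Toy data: 4 (input, target) batches, invented for illustration.
X = [rand(Float32, 3, 8) for _ in 1:4]
Y = [rand(Float32, 2, 8) for _ in 1:4]
m = Chain(Dense(3 => 16, relu), Dense(16 => 2))

# The loss now takes the model as its first argument.
loss(m, x, y) = Flux.mse(m(x), y)

opt = Flux.setup(Adam(1e-3), m)       # replaces `Flux.params` + implicit `train!`
Flux.train!(loss, m, zip(X, Y), opt)  # one epoch over the data
```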

docs/src/training/reference.md

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ The available optimization rules are listed the [optimisation rules](@ref man-op

 ```@docs
 Flux.Train.setup
-Flux.Train.train!(loss, model, data, state; cb)
+Flux.Train.train!(loss, model, data, state)
 Optimisers.update!
 ```

src/Flux.jl

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@ using MLUtils
 const stack = MLUtils.stack # now exported by Base
 @reexport using Optimisers
 import Optimisers: trainable
+using Optimisers: update!, trainables
 using Random: default_rng
 using Zygote, ChainRulesCore
 using Zygote: Params, @adjoint, gradient, pullback
@@ -43,7 +44,7 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion

 include("train.jl")
 using .Train
-using .Train: setup
+using .Train: setup, train!

 using Adapt, Functors, OneHotArrays
 include("utils.jl")
@@ -55,7 +56,7 @@ include("functor.jl")
 # from Functors.jl
 functor, @functor,
 # from Train/Optimisers.jl
-setup, update!, destructure, freeze!, thaw!, adjust!, params, trainable
+setup, update!, destructure, freeze!, thaw!, adjust!, trainable, trainables
 ))

 # Pirate error to catch a common mistake.
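A small sketch of how the newly exported `trainables` stands in for the old `params` when one just needs the model's parameter arrays; the model below is arbitrary:

```julia
using Flux

m = Chain(Dense(2 => 3, tanh), Dense(3 => 1))

# `trainables` (from Optimisers.jl) returns a vector of the model's trainable arrays.
ps = Flux.trainables(m)
length(ps)        # 4: two weight matrices and two bias vectors
sum(length, ps)   # 13 trainable parameters in total (3*2 + 3 + 1*3 + 1)
```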

src/deprecations.jl

Lines changed: 2 additions & 6 deletions
@@ -1,13 +1,9 @@
 # v0.15 deprecations

-# Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc:
-# Base.@deprecate_binding Optimiser OptimiserChain
-# Base.@deprecate_binding ClipValue ClipGrad
-
-train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
+Train.train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
 """On Flux 0.15, `train!` no longer accepts implicit `Zygote.Params`.
 Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
-it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
+it now needs `opt_state = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt_state)`
 where `loss_mxy` accepts the model as its first argument.
 """
 ))

src/train.jl

Lines changed: 3 additions & 10 deletions
@@ -8,21 +8,14 @@ using ..Flux: Flux # used only in docstring
 export setup, train!

 using ProgressLogging: @progress, @withprogress, @logprogress
-using Zygote: Zygote, Params
+using Zygote: Zygote

 """
     opt_state = setup(rule, model)

 This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
-It differs from `Optimisers.setup` in that it:
-* has one extra check for mutability (since Flux expects to mutate the model in-place,
-  while Optimisers.jl is designed to return an updated model)
-* has methods which accept Flux's old optimisers, and convert them.
-  (The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.)
-
-!!! compat "New"
-    This function was added in Flux 0.13.9. It was not used by the old "implicit"
-    interface, using `Flux.Optimise` module and [`Flux.params`](@ref).
+It differs from `Optimisers.setup` in that it has one extra check for mutability (since Flux expects to mutate the model in-place,
+while Optimisers.jl is designed to return an updated model).

 # Example
 ```jldoctest
test/optimise.jl renamed to test/TOREMOVE_optimise.jl

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,3 @@
-using Flux.Optimise
-using Flux.Optimise: runall
 using Flux: Params, gradient
 import FillArrays, ComponentArrays
 import Optimisers

test/data.jl

Lines changed: 6 additions & 4 deletions
@@ -80,18 +80,20 @@ using Random
 # test interaction with `train!`
 θ = ones(2)
 X = zeros(2, 10)
-loss(x) = sum((x .- θ).^2)
+loss(θ, x) = sum((x .- θ).^2)
 d = DataLoader(X)
-Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
+opt = Flux.setup(Descent(0.1), θ)
+Flux.train!(loss, θ, ncycle(d, 10), opt)
 @test norm(θ) < 1e-4

 # test interaction with `train!`
 θ = zeros(2)
 X = ones(2, 10)
 Y = fill(2, 10)
-loss(x, y) = sum((y - x'*θ).^2)
+loss(θ, x, y) = sum((y - x'*θ).^2)
 d = DataLoader((X, Y))
-Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
+opt = Flux.setup(Descent(0.1), θ)
+Flux.train!(loss, θ, ncycle(d, 10), opt)
 @test norm(θ .- 1) < 1e-10

 # specify the rng

test/layers/basic.jl

Lines changed: 3 additions & 3 deletions
@@ -80,7 +80,7 @@ using Flux: activations
 @test size(Dense(10 => 5)(randn(10,2))) == (5,2)
 @test size(Dense(10 => 5)(randn(10,2,3))) == (5,2,3)
 @test size(Dense(10 => 5)(randn(10,2,3,4))) == (5,2,3,4)
-@test_throws DimensionMismatch Dense(10, 5)(randn(11,2,3))
+@test_throws DimensionMismatch Dense(10 => 5)(randn(11,2,3))
 end
 @testset "zeros" begin
 @test Dense(10 => 1, identity, init = ones)(ones(10,1)) == 10*ones(1, 1)
@@ -156,9 +156,9 @@ using Flux: activations
 @test mo(input) == target
 end

-@testset "params" begin
+@testset "trainables" begin
 mo = Maxout(()->Dense(32 => 64), 4)
-ps = Flux.params(mo)
+ps = Flux.trainables(mo)
 @test length(ps) == 8 #4 alts, each with weight and bias
 end
 end

test/layers/conv.jl

Lines changed: 13 additions & 13 deletions
@@ -36,35 +36,35 @@ end
 @test size(m(r)) == (10, 5)

 # Test bias switch
-bias = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
+m2 = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
 ip = zeros(Float32, 28,28,1,1)

-op = bias(ip)
+op = m2(ip)
 @test sum(op) == prod(size(op))

 @testset "No bias mapped through $lmap" for lmap in (identity, cpu, f32)
-  bias = Conv((2,2), 1=>3, bias = false) |> lmap
-  op = bias(ip)
+  m3 = Conv((2,2), 1=>3, bias = false) |> lmap
+  op = m3(ip)
   @test sum(op) ≈ 0.f0
-  gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
-  @test bias.bias ∉ gs.params
+  gs = gradient(m -> sum(m(ip)), m3)[1]
+  @test gs.bias === nothing
 end

 # Train w/o bias and make sure no convergence happens
 # when only bias can be converged
-bias = Conv((2, 2), 1=>3, bias = false);
+m4 = Conv((2, 2), 1=>3, bias = false);
 ip = zeros(Float32, 28,28,1,1)
 op = zeros(Float32, 27,27,3,1) .+ 2.f0
-opt = Descent()
+opt_state = Flux.setup(Descent(), m4)

 for _ = 1:10^3
-  gs = gradient(Flux.params(bias)) do
-    Flux.Losses.mse(bias(ip), op)
-  end
-  Flux.Optimise.update!(opt, params(bias), gs)
+  gs = gradient(m4) do m
+    Flux.mse(m(ip), op)
+  end[1]
+  Flux.update!(opt_state, m4, gs)
 end

-@test Flux.Losses.mse(bias(ip), op) ≈ 4.f0
+@test Flux.Losses.mse(m4(ip), op) ≈ 4.f0

 @testset "Grouped Conv" begin
 ip = rand(Float32, 28, 100, 2)
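The explicit loop used in this test, written out as a standalone sketch; the layer, optimiser, and iteration count are taken from the test, while the packaging as a plain script is assumed:

```julia
using Flux

m = Conv((2, 2), 1 => 3, bias = false)
ip = zeros(Float32, 28, 28, 1, 1)
op = fill(2.0f0, 27, 27, 3, 1)

opt_state = Flux.setup(Descent(), m)

for _ in 1:10^3
  # Explicit gradient with respect to the model; `[1]` unpacks Zygote's tuple.
  gs = gradient(model -> Flux.mse(model(ip), op), m)[1]
  # Mutate the model and the optimiser state in place.
  Flux.update!(opt_state, m, gs)
end

Flux.mse(m(ip), op)  # stays ≈ 4: the input is all zeros and there is no bias to learn
```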
