Skip to content

Commit a5c638d

Browse files
fix piracy with DataLoader (#2608)
* fix piracy
* fix test
* Update test/Project.toml
1 parent 3659201 commit a5c638d

File tree

3 files changed

+0
-81
lines changed

3 files changed

+0
-81
lines changed

src/devices.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,3 @@ get_device(x) = MLDataDevices.get_device(x)
22

33
@doc (@doc MLDataDevices.get_device) get_device
44

5-
function (device::MLDataDevices.AbstractDevice)(d::MLUtils.DataLoader)
6-
MLUtils.DataLoader(MLUtils.mapobs(device, d.data),
7-
d.batchsize,
8-
d.buffer,
9-
d.partial,
10-
d.shuffle,
11-
d.parallel,
12-
d.collate,
13-
d.rng,
14-
)
15-
end

src/functor.jl

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -179,72 +179,3 @@ Chain(
179179
```
180180
"""
181181
f16(m) = _paramtype(Float16, m)
182-
183-
184-
"""
185-
gpu(data::DataLoader)
186-
cpu(data::DataLoader)
187-
188-
Transforms a given `DataLoader` to apply `gpu` or `cpu` to each batch of data,
189-
when iterated over. (If no GPU is available, this does nothing.)
190-
191-
# Example
192-
193-
```julia-repl
194-
julia> dl = Flux.DataLoader((x = ones(2,10), y='a':'j'), batchsize=3)
195-
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}, batchsize=3)
196-
with first element:
197-
(; x = 2×3 Matrix{Float64}, y = 3-element StepRange{Char, Int64})
198-
199-
julia> first(dl)
200-
(x = [1.0 1.0 1.0; 1.0 1.0 1.0], y = 'a':1:'c')
201-
202-
julia> c_dl = gpu(dl)
203-
4-element DataLoader(::MLUtils.MappedData{:auto, typeof(gpu), NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}}, batchsize=3)
204-
with first element:
205-
(; x = 2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element StepRange{Char, Int64})
206-
207-
julia> first(c_dl).x
208-
2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
209-
1.0 1.0 1.0
210-
1.0 1.0 1.0
211-
```
212-
213-
For large datasets, this is preferred over moving all the data to
214-
the GPU before creating the `DataLoader`, like this:
215-
216-
```julia-repl
217-
julia> Flux.DataLoader((x = ones(2,10), y=2:11) |> gpu, batchsize=3)
218-
4-element DataLoader(::NamedTuple{(:x, :y), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, UnitRange{Int64}}}, batchsize=3)
219-
with first element:
220-
(; x = 2×3 CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element UnitRange{Int64})
221-
```
222-
223-
!!! warning
224-
This only works if `gpu` is applied directly to the `DataLoader`.
225-
While `gpu` acts recursively on Flux models and many basic Julia structs,
226-
it will not work on (say) a tuple of `DataLoader`s.
227-
"""
228-
function gpu(d::MLUtils.DataLoader)
229-
MLUtils.DataLoader(MLUtils.mapobs(gpu, d.data);
230-
d.batchsize,
231-
d.buffer,
232-
d.partial,
233-
d.shuffle,
234-
d.parallel,
235-
d.collate,
236-
d.rng,
237-
)
238-
end
239-
240-
function cpu(d::MLUtils.DataLoader)
241-
MLUtils.DataLoader(MLUtils.mapobs(cpu, d.data);
242-
d.batchsize,
243-
d.buffer,
244-
d.partial,
245-
d.shuffle,
246-
d.parallel,
247-
d.collate,
248-
d.rng,
249-
)
250-
end

test/data.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
@test batches[3] == X[:,5:5]
1515

1616
d_cpu = d |> cpu # does nothing but shouldn't error
17-
@test d_cpu isa DataLoader
1817
@test first(d_cpu) == X[:,1:2]
1918
@test length(d_cpu) == 3
2019

0 commit comments

Comments (0)