As per the discussion in https://github.com/JuliaGPU/CuArrays.jl/issues/52, which I report here, the current way to use a specific GPU in a Flux script is the following:
    using Flux, CUDAapi, CUDAdrv, CuArrays

    gpu_id = 0  ## set < 0 for no cuda, >= 0 for using a specific device (if available)

    if CUDAapi.has_cuda_gpu() && gpu_id >= 0
        CUDAdrv.device!(gpu_id)
        CuArrays.allowscalar(false)
        device = Flux.gpu
        @info "Training on GPU-$(gpu_id)"
    else
        device = Flux.cpu
        @info "Training on CPU"
    end

    model = model |> device
    for x in data
        x = x |> device
        ...
    end

This involves importing three CUDA libraries (CUDAapi, CuArrays, and CUDAdrv) which likely won't be used anywhere else in the script.
I suggest wrapping the same functionality in Flux, in a gpu! function, so that we could equivalently write:
    gpu_id = 0  ## set < 0 for no cuda, >= 0 for using a specific device (if available)

    if CUDAapi.has_cuda_gpu() && gpu_id >= 0
        device = Flux.gpu!(gpu_id, allowscalar=false)
        @info "Training on GPU-$(gpu_id)"
    else
        device = Flux.cpu
        @info "Training on CPU"
    end

Thoughts?
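For concreteness, a minimal sketch of what such a gpu! helper could look like. This is hypothetical: the name gpu! and the allowscalar keyword come from the proposal above, and the body just wraps the same CUDAdrv/CuArrays calls from the first snippet so user scripts no longer need those imports.

    # Hypothetical sketch of the proposed Flux.gpu!; not an existing Flux API.
    function gpu!(id::Integer; allowscalar::Bool = true)
        CUDAdrv.device!(id)                 # select the CUDA device to use
        CuArrays.allowscalar(allowscalar)   # optionally forbid slow scalar indexing
        return gpu                          # hand back the usual Flux.gpu converter
    end

Returning the gpu conversion function keeps the `device = Flux.gpu!(gpu_id, ...)` pattern above working, so the rest of the script stays identical to the CPU branch.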