Hi,
This is probably less of a bug and more of a discussion/feature request, but while porting models from PyTorch to Flux I've noticed that the BatchNorm layer seems to use a lot more GPU memory for large arrays than PyTorch does. This becomes a bottleneck for e.g. graph-based models on very large datasets (~100k nodes). One could of course just move to a bigger GPU, so it's not a blocker, but since PyTorch handles the same case on the same GPU it might be worth investigating why. Any help/insight would be appreciated :)
As an MWE, I ran the following on an 8GB NVIDIA RTX 2000 with CUDA.jl and Flux:
using Flux
using CUDA
m = BatchNorm(32) |> gpu
m(CUDA.rand(32, 800_000))
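As a possible clue: a hand-rolled normalization on the generic CUDA.jl broadcasting path, which never touches cudnnBatchNormalizationForwardInference, might help narrow down whether the limit is specific to the cuDNN kernel. The manual_batchnorm helper below is purely illustrative (it is not Flux API) and omits the affine scale/shift and running statistics for brevity:
using CUDA
using Statistics

# Illustrative only, not Flux API: normalize each feature (row) over the
# batch dimension with plain broadcasting, avoiding the cuDNN kernel.
function manual_batchnorm(x; eps = 1f-5)
    mu = mean(x; dims = 2)               # per-feature mean
    s2 = mean(abs2.(x .- mu); dims = 2)  # per-feature (biased) variance
    return (x .- mu) ./ sqrt.(s2 .+ eps)
end

y = manual_batchnorm(CUDA.rand(32, 800_000))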
This fails with the following error:
ERROR: CUDNNError: CUDNN_STATUS_NOT_SUPPORTED (code 3000)
Stacktrace:
[1] throw_api_error(res::cuDNN.cudnnStatus_t)
@ cuDNN ~/.julia/packages/cuDNN/7odoD/src/libcudnn.jl:15
[2] check
@ ~/.julia/packages/cuDNN/7odoD/src/libcudnn.jl:26 [inlined]
[3] cudnnBatchNormalizationForwardInference
@ ~/.julia/packages/GPUToolbox/cZlg7/src/ccalls.jl:33 [inlined]
[4] cudnnBNForward!(y::CuArray{…}, g::CuArray{…}, b::CuArray{…}, x::CuArray{…}, running_mean::CuArray{…}, running_var::CuArray{…}, momentum::Float32; cache::Nothing, alpha::Int64, beta::Int64, eps::Float32, training::Bool, affine::Bool, track_stats::Bool)
@ NNlibCUDACUDNNExt ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:88
[5] cudnnBNForward!
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:40 [inlined]
[6] #batchnorm#64
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:37 [inlined]
[7] batchnorm
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:35 [inlined]
[8] #batchnorm#63
@ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:31 [inlined]
[9] (::BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.DeviceMemory}, Float32, CuArray{Float32, 1, CUDA.DeviceMemory}})(x::CuArray{Float32, 2, CUDA.DeviceMemory}, cache::Nothing)
@ FluxCUDAcuDNNExt ~/.julia/packages/Flux/uRn8o/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl:13
[10] (::BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.DeviceMemory}, Float32, CuArray{Float32, 1, CUDA.DeviceMemory}})(x::CuArray{Float32, 2, CUDA.DeviceMemory})
@ FluxCUDAcuDNNExt ~/.julia/packages/Flux/uRn8o/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl:10
[11] top-level scope
@ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.
For comparison, the equivalent code runs without issue using PyTorch v2.1.2 on Python v3.12.2:
import torch
device = torch.device("cuda")
m = torch.nn.BatchNorm1d(32).to(device)
input = torch.randn(800000, 32).to(device)
output = m(input)
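In the meantime, the only workaround I can think of on the Flux side is chunking the forward pass along the batch dimension so that each cuDNN call sees fewer columns. In inference mode with track_stats = true every column is normalized independently using the running statistics, so the result should be unchanged; whether a given chunk size stays under whatever limit cuDNN is hitting here is a guess on my part. A rough sketch (chunked_forward and the 100_000 chunk size are arbitrary choices of mine, not Flux API):
using Flux
using CUDA

# Illustrative helper, not Flux API: run the layer over column chunks and
# reassemble the output, so each cuDNN call sees a smaller batch.
function chunked_forward(m, x; chunk = 100_000)
    n = size(x, 2)
    parts = [m(x[:, i:min(i + chunk - 1, n)]) for i in 1:chunk:n]
    return hcat(parts...)
end

m = BatchNorm(32) |> gpu
Flux.testmode!(m)  # use the running statistics, so chunking does not change the result
y = chunked_forward(m, CUDA.rand(32, 800_000))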