
BatchNorm Memory Blowup on Very Large CUDA Arrays #2616

@BenCurran98

Description

Hi,

This is probably less of a bug and more a discussion/feature request, but while porting models from PyTorch to Flux I've noticed that the BatchNorm layer seems to use a lot more GPU memory on large arrays than PyTorch does. This becomes a bottleneck for e.g. graph-based models on very large datasets (~100k nodes). Since one could always move to a GPU with more memory it isn't a hard blocker, but given that PyTorch handles the same case on the same GPU, it seems worth investigating why. Any help/insight would be appreciated :)
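
For reference, GPU memory use on the Julia side can be watched with CUDA.jl's built-in reporting (a quick sketch; exact figures will depend on the device and pool state):

using Flux
using CUDA

m = BatchNorm(32) |> gpu
x = CUDA.rand(32, 100_000)   # a size that still fits on this GPU

CUDA.memory_status()         # pool usage before the forward pass
CUDA.@time m(x)              # also reports GPU allocations for the call
CUDA.memory_status()         # pool usage afterwards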

As an MWE, I ran the following on an 8 GB NVIDIA RTX 2000 GPU with CUDA.jl and Flux:

using Flux
using CUDA

m = BatchNorm(32) |> gpu
m(CUDA.rand(32, 800_000))

This fails with the following error:

ERROR: CUDNNError: CUDNN_STATUS_NOT_SUPPORTED (code 3000)
Stacktrace:
  [1] throw_api_error(res::cuDNN.cudnnStatus_t)
    @ cuDNN ~/.julia/packages/cuDNN/7odoD/src/libcudnn.jl:15
  [2] check
    @ ~/.julia/packages/cuDNN/7odoD/src/libcudnn.jl:26 [inlined]
  [3] cudnnBatchNormalizationForwardInference
    @ ~/.julia/packages/GPUToolbox/cZlg7/src/ccalls.jl:33 [inlined]
  [4] cudnnBNForward!(y::CuArray{…}, g::CuArray{…}, b::CuArray{…}, x::CuArray{…}, running_mean::CuArray{…}, running_var::CuArray{…}, momentum::Float32; cache::Nothing, alpha::Int64, beta::Int64, eps::Float32, training::Bool, affine::Bool, track_stats::Bool)
    @ NNlibCUDACUDNNExt ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:88
  [5] cudnnBNForward!
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:40 [inlined]
  [6] #batchnorm#64
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:37 [inlined]
  [7] batchnorm
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:35 [inlined]
  [8] #batchnorm#63
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/batchnorm.jl:31 [inlined]
  [9] (::BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.DeviceMemory}, Float32, CuArray{Float32, 1, CUDA.DeviceMemory}})(x::CuArray{Float32, 2, CUDA.DeviceMemory}, cache::Nothing)
    @ FluxCUDAcuDNNExt ~/.julia/packages/Flux/uRn8o/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl:13
 [10] (::BatchNorm{typeof(identity), CuArray{Float32, 1, CUDA.DeviceMemory}, Float32, CuArray{Float32, 1, CUDA.DeviceMemory}})(x::CuArray{Float32, 2, CUDA.DeviceMemory})
    @ FluxCUDAcuDNNExt ~/.julia/packages/Flux/uRn8o/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl:10
 [11] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.

By contrast, I can run the equivalent code without error using PyTorch v2.1.2 on Python 3.12.2:

import torch
device = torch.device("cuda")
m = torch.nn.BatchNorm1d(32).to(device)
input = torch.randn(800000, 32).to(device)
output = m(input)
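
In case it helps narrow things down, an inference-only normalization that skips the cuDNN path and just broadcasts on the GPU would look roughly like the sketch below. The field names (λ, β, γ, μ, σ², ϵ) are taken from Flux's BatchNorm struct as I understand it; this ignores training mode and running-statistics updates, so it's an illustration rather than a drop-in replacement:

using Flux
using CUDA

# Sketch: inference-only batch norm over a (channels, batch) matrix,
# bypassing NNlib's cuDNN kernel by broadcasting directly on the GPU.
function manual_batchnorm(m::BatchNorm, x::CuMatrix)
    μ  = reshape(m.μ,  :, 1)   # tracked mean, one value per channel
    σ² = reshape(m.σ², :, 1)   # tracked variance
    γ  = reshape(m.γ,  :, 1)   # scale
    β  = reshape(m.β,  :, 1)   # shift
    return m.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ m.ϵ) .+ β)
end

m = BatchNorm(32) |> gpu
x = CUDA.rand(32, 800_000)
y = manual_batchnorm(m, x)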
