[Bug] Memory leak when using ExactGP module #2649

@treigerm

Description

🐛 Bug

I have a large training script in which I encountered CUDA OOM issues. I was able to narrow the problem down to my particular usage of gpytorch: using the ExactGP module was leading to memory leaks. The whole training script is too complicated to post here, but I think I was able to reduce the bug to a minimal example.

It seems that merely initializing the ExactGP module leads to a memory leak, even without doing backpropagation or calling the module's forward function.

To reproduce

I found that the following code snippet already leads to a memory leak. Please let me know if I am using the library incorrectly in any way!

import gpytorch
import torch
import torch.nn as nn
import time


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

class BoxModule(nn.Module):

    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

def loss(y):
    inputs = torch.randn(20, 2, device="cuda")
    
    # Option 1:
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = ExactGPModel(inputs, y, likelihood)
    # Option 2:
    # box = BoxModule(inputs, y)


def main():
    torch.cuda.memory._record_memory_history()
    y = torch.rand(20, device="cuda")

    for _ in range(10):
        loss(y)
        time.sleep(3)

    try:
        torch.cuda.memory._dump_snapshot("memory.pickle")
    except Exception as e:
        print(f"Failed to capture memory snapshot {e}")
    torch.cuda.memory._record_memory_history(enabled=None)


if __name__ == "__main__":
    main()

Memory profile

Using PyTorch's memory visualization tool, I can observe that the memory for the GP input tensors does not appear to be freed after each iteration. In this example the build-up is negligible, but in my actual use case it leads to CUDA OOM errors.
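As a side note, the "tensors kept alive by a lingering reference" hypothesis can be checked without the profiler using `weakref`. The following is a pure-Python sketch of that diagnostic (the `Payload`, `make_and_drop`, and `cache` names are illustrative, not part of gpytorch): an object held by an external reference survives the scope that created it, exactly like the GP inputs in the profile above.

```python
import weakref


class Payload:
    """Stand-in for a tensor whose lifetime we want to track."""
    pass


def make_and_drop(cache):
    # Create an object inside a scope and return only a weak reference to it.
    p = Payload()
    ref = weakref.ref(p)
    if cache is not None:
        cache.append(p)  # simulated leak: an external reference keeps p alive
    return ref


# No external reference: the object is freed as soon as the function returns,
# so the weak reference is dead.
ref_free = make_and_drop(None)
assert ref_free() is None

# External reference: the object survives the scope, so the weak reference
# still resolves -- the pattern the memory profile above suggests for the
# ExactGP inputs.
cache = []
ref_leaked = make_and_drop(cache)
assert ref_leaked() is not None
```

Applying the same `weakref.ref` check to `inputs` inside `loss()` (CUDA tensors support weak references) would confirm whether something in ExactGP is holding on to them after the function returns.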

[Memory snapshot: allocations for the GP input tensors accumulate across iterations]

Expected Behavior

If we comment out the ExactGP usage in the loss function and instead initialise the BoxModule (Option 2), the memory is freed after each iteration, as we would expect.

[Memory snapshot: allocations are freed after each iteration in the BoxModule case]

System information

Please complete the following information:

  • Python version: 3.12
  • gpytorch.__version__: 1.13
  • torch.__version__: 2.6.0
  • torch.version.cuda: 12.4
