🐛 Bug
I have a large training script in which I encountered CUDA OOM issues. I was able to narrow the problem down to my particular usage of gpytorch: using the ExactGP module leads to memory leaks. The whole training script is too complicated to post here, but I think I was able to reduce the bug to a minimal example.
It seems that even just initializing the ExactGP module leaks memory, without doing backpropagation or calling the module's forward function.
To reproduce
I found that the following code snippet already leads to a memory leak. Please let me know if I am using the library incorrectly in any way!
import time

import gpytorch
import torch
import torch.nn as nn


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class BoxModule(nn.Module):
    """A plain nn.Module holding the same tensors, for comparison."""

    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y


def loss(y):
    inputs = torch.randn(20, 2, device="cuda")
    # Option 1: construct an ExactGP model (memory is not freed between iterations)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = ExactGPModel(inputs, y, likelihood)
    # Option 2: construct a plain nn.Module instead (memory is freed as expected)
    # box = BoxModule(inputs, y)


def main():
    torch.cuda.memory._record_memory_history()
    y = torch.rand(20, device="cuda")
    for _ in range(10):
        loss(y)
        time.sleep(3)
    try:
        torch.cuda.memory._dump_snapshot("memory.pickle")
    except Exception as e:
        print(f"Failed to capture memory snapshot: {e}")
    torch.cuda.memory._record_memory_history(enabled=None)


if __name__ == "__main__":
    main()
Memory profile
Using PyTorch's memory visualization tool, I can observe that the memory for the GP input tensors does not seem to be freed after each iteration. Of course, in this example the memory build-up is negligible, but in my actual use case it leads to CUDA OOM issues.
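As a quicker sanity check that doesn't require the snapshot tooling, here is a minimal sketch (assuming the loss function from the repro above is in scope) that prints the allocated CUDA memory after each iteration; if the leak is real, the number should grow with Option 1 and stay flat with Option 2:

# Minimal sketch: track allocated CUDA memory per iteration
# (assumes the `loss` function from the repro above is in scope).
import torch

y = torch.rand(20, device="cuda")
for i in range(10):
    loss(y)
    # With Option 1 (ExactGP) this number grows every iteration;
    # with Option 2 (BoxModule) it should stay flat.
    print(f"iter {i}: {torch.cuda.memory_allocated() / 1024:.1f} KiB allocated")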
Expected Behavior
If we comment out the usage of ExactGP in the loss function and instead initialise the BoxModule (Option 2), the memory is freed after each iteration, as we would expect.
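For what it's worth, one possible explanation (an assumption on my part, not confirmed) is that the ExactGP object ends up in a reference cycle, so reference counting alone never frees it and the CUDA tensors it holds stay alive until the cyclic garbage collector happens to run. If that were the cause, forcing a collection between iterations should release the memory:

# Diagnostic sketch (assumption: the leak comes from a reference cycle).
# If memory stops accumulating with the explicit gc.collect(), the objects
# were only reachable through a cycle and were never refcount-freed.
import gc
import torch

y = torch.rand(20, device="cuda")
for i in range(10):
    loss(y)
    gc.collect()  # break potential reference cycles
    print(f"iter {i}: {torch.cuda.memory_allocated()} bytes allocated")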
System information
Please complete the following information:
Python version: 3.12
gpytorch.__version__: 1.13
torch.__version__: 2.6.0
torch.version.cuda: 12.4

