Description
Feature description
Allow a sweep over all autodiff graphs to remove tracked nodes that cannot be useful.
Feature motivation
On the autodiff graphs, it's possible to have leaves that are still being tracked but are no longer useful because their underlying tensors have been dropped. This happens, for example, when a tensor set to require grad is created but never participates in the graph that receives a backward call, i.e. when a developer makes a mistake and leaves around extra tensors that shouldn't require grad but do. Even after such a tensor is dropped, its node remains in the autodiff graphs and there is no way to remove it, possibly leading to out-of-memory problems.
Ideally the developer would detect those "hanging" tensors and fix them, preventing them from having a tracked node, but this is somewhat costly during research and experimentation, whereas a sporadic "sweep cleanup" call can be a useful option.
Example:
```toml
[dependencies.burn]
version = "0.19.0"
default-features = false
features = ["autodiff", "ndarray"]
```

```rust
use burn::prelude::*;

type NdArray = burn::backend::NdArray<f32, i32>;
type NdArrayAuto = burn::backend::Autodiff<NdArray>;

fn main() {
    train::<NdArrayAuto>();
}

pub fn train<AutoB: burn::tensor::backend::AutodiffBackend>() {
    let train_device = <<AutoB as Backend>::Device>::default();
    for _ in 0..1_000_000 {
        let a: Tensor<AutoB, 2> = Tensor::zeros([2, 2], &train_device);
        let a = a.require_grad(); // an autodiff node is created
        drop(a); // the tensor is dropped, but its unusable graph persists
    }
    // at this point, AutoB/train_device cannot be used to clean up the "hanging" nodes
}
```

Suggest a Solution
Add a graph_cleanup function to the burn::tensor::backend::AutodiffBackend trait that sweeps over all graphs and removes useless tracked leaves, their parents, and, if applicable, the whole graph. This could reuse the cleanup procedure that already runs after a backward call.
Alternatively, some form of node cleanup could happen every time an autodiff-related tensor is dropped. But since the cleanup involves analyzing the graph that the tensor's node is part of, running it on every drop would be too costly and overkill as far as cleanup goes.
- Linked PR as a suggestion.