
Remove autodiff unused nodes according to dropped tensors #3976

@swfsql


Feature description

Allow a sweep over all autodiff graphs to remove tracked nodes that can no longer be useful.

Feature motivation

On the autodiff graphs, it's possible to have leaves that are being tracked but are no longer useful because their underlying tensors have been dropped. This happens, for example, when a tensor that is set to require grad is created but doesn't participate in the graph that receives the backward call, i.e. when a developer makes a mistake and leaves around extra tensors that require grad when they shouldn't. Even after such a tensor is dropped, its node remains registered in the autodiff graphs and there is no way to remove it, which can eventually lead to out-of-memory problems.
Ideally the developer would detect those "hanging" tensors and fix them so they never get a tracked node, but that is somewhat costly during research and experiments, whereas a sporadic "sweep cleanup" call can be a useful option.


Example:

Cargo.toml:

[dependencies.burn]
version = "0.19.0"
default-features = false
features = ["autodiff", "ndarray"]

main.rs:

use burn::prelude::*;

type NdArray = burn::backend::NdArray<f32, i32>;
type NdArrayAuto = burn::backend::Autodiff<NdArray>;

fn main() {
    train::<NdArrayAuto>();
}

pub fn train<AutoB: burn::tensor::backend::AutodiffBackend>() {
    let train_device = <<AutoB as Backend>::Device>::default();
    for _ in 0..1_000_000 {
        let a: Tensor<AutoB, 2> = Tensor::zeros([2, 2], &train_device);
        let a = a.require_grad(); // an autodiff node is created
        drop(a); // the tensor is dropped but its unusable graph persists
    }
    // at this point, AutoB/train_device cannot be used to clean up the "hanging" nodes
}

Suggest a Solution

Add a graph_cleanup function to the burn::tensor::backend::AutodiffBackend trait that sweeps over all graphs and removes useless tracked leaves, their parents, and, if applicable, the whole graph. This can reuse the cleanup procedure that already happens after a backward call.
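
A rough sketch of what this could look like; the graph_cleanup method, its signature, and its exact placement on the trait are hypothetical and only illustrate the proposal:

// Hypothetical addition to the AutodiffBackend trait (sketch, not the actual Burn API).
pub trait AutodiffBackend: Backend {
    // ... existing items ...

    /// Sweeps all autodiff graphs on `device` and removes tracked nodes whose
    /// underlying tensors were dropped and can no longer reach a backward call.
    fn graph_cleanup(device: &Self::Device);
}

With something like that in place, the example above could sporadically sweep the graphs instead of leaking nodes:

pub fn train<AutoB: burn::tensor::backend::AutodiffBackend>() {
    let train_device = <<AutoB as Backend>::Device>::default();
    for i in 0..1_000_000 {
        let a: Tensor<AutoB, 2> = Tensor::zeros([2, 2], &train_device);
        let a = a.require_grad();
        drop(a);

        // Hypothetical call: remove the "hanging" nodes every now and then.
        if i % 10_000 == 0 {
            AutoB::graph_cleanup(&train_device);
        }
    }
}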

Alternatively, some form of node cleanup could happen every time an autodiff-related tensor is dropped, as in the sketch below. But since that cleanup involves analyzing the graph that the tensor's node is part of, running it on every drop would be too costly and overkill as far as cleanup goes.
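
As a standalone toy model (not Burn's internals), the per-drop alternative would look roughly like this: every handle drop triggers a sweep of the shared graph, which is exactly what makes it expensive when tensors are dropped in a hot loop.

use std::collections::HashSet;
use std::sync::{Arc, Mutex};

#[derive(Default)]
struct Graph {
    nodes: HashSet<u64>,
}

impl Graph {
    /// Stand-in for the (expensive) analysis that decides which nodes are
    /// no longer reachable from any live tensor.
    fn sweep(&mut self, dropped: u64) {
        self.nodes.remove(&dropped);
        // A real sweep would also walk parents/children here.
    }
}

struct TensorHandle {
    node: u64,
    graph: Arc<Mutex<Graph>>,
}

impl Drop for TensorHandle {
    fn drop(&mut self) {
        // Runs on every single tensor drop, i.e. once per loop iteration above.
        self.graph.lock().unwrap().sweep(self.node);
    }
}

fn main() {
    let graph = Arc::new(Mutex::new(Graph::default()));
    for node in 0..1_000u64 {
        graph.lock().unwrap().nodes.insert(node);
        let t = TensorHandle { node, graph: Arc::clone(&graph) };
        drop(t); // triggers a sweep each time
    }
    assert!(graph.lock().unwrap().nodes.is_empty());
}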


  • Linked PR as a suggestion.
