Merged
49 changes: 46 additions & 3 deletions docs/source/en/optimization/memory.md
@@ -291,13 +291,53 @@ Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://
> [!WARNING]
> Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism.

Call [`~ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.

The `offload_type` parameter can be set to `block_level` or `leaf_level`.
Enable group offloading by configuring the `offload_type` parameter to `block_level` or `leaf_level`.

- `block_level` offloads groups of layers based on the `num_blocks_per_group` parameter. For example, if `num_blocks_per_group=2` on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). This drastically reduces memory requirements; see the sketch after this list.
- `leaf_level` offloads individual layers at the lowest level and is equivalent to [CPU offloading](#cpu-offloading) in terms of memory savings, but with streams it avoids most of the inference speed penalty.
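
The snippet below is a minimal sketch of `block_level` offloading applied to a single model component, as referenced in the list above. It assumes the CogVideoX pipeline used elsewhere on this page and that its `transformer` supports group offloading; `num_blocks_per_group=2` is only illustrative.

```py
import torch
from diffusers import CogVideoXPipeline

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Keep only two transformer blocks on the GPU at a time; the rest stay on the CPU.
pipeline.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)
```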

Group offloading is supported for entire pipelines or individual models. Applying it to the entire pipeline is the easiest option, while applying it selectively to individual models gives you more flexibility to use different offloading techniques for different models.

<hfoptions id="group-offloading">
<hfoption id="pipeline">

Call [`~DiffusionPipeline.enable_group_offload`] on a pipeline.

```py
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipeline.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True
)

prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)
```

</hfoption>
<hfoption id="model">

Call [`~ModelMixin.enable_group_offload`] on standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.

```py
import torch
from diffusers import CogVideoXPipeline
@@ -328,6 +368,9 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)
```
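
The following is a minimal sketch of the [`~hooks.apply_group_offloading`] path for a component that doesn't inherit from [`ModelMixin`]. It assumes the CogVideoX pipeline from the pipeline example, whose text encoder is a `transformers` model; the exact arguments are illustrative.

```py
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# The text encoder is a plain torch.nn.Module from transformers, so use the function form.
apply_group_offloading(
    pipeline.text_encoder,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)
```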

</hfoption>
</hfoptions>

#### CUDA stream

The `use_stream` parameter can be enabled on CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation through layer prefetching: the next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can significantly increase CPU memory usage, so ensure you have at least 2x the model size in available CPU memory.
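
Below is a minimal sketch of combining streams with pipeline-level group offloading. It reuses the `pipeline` object from the examples above and assumes a CUDA device; `low_cpu_mem_usage=True` pins tensors on-the-fly instead of pre-pinning them, which can lower CPU memory at the cost of some of the stream speedup.

```py
import torch

# Prefetch the next layer on a separate CUDA stream while the current layer runs.
pipeline.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,
)
```
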
127 changes: 127 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -1334,6 +1334,133 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
offload_buffers = len(model._parameters) > 0
cpu_offload(model, device, offload_buffers=offload_buffers)

def enable_group_offload(
self,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
exclude_modules: Optional[Union[str, List[str]]] = None,

Review comment (Member Author):

I think it's okay to expose this as an argument as opposed to how we do model_cpu_offload_seq, for example:

model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"

This is because model CPU offloading relies on a sequence for device management. I don't think we have that constraint in the case of group offloading.

) -> None:
r"""
Applies group offloading to the internal layers of the pipeline's `torch.nn.Module` components. To understand
what group offloading is, and where it is beneficial, we first need to provide some context on how the other
supported offloading methods work.

Typically, offloading is done at two levels:
- Module-level: In Diffusers, this can be enabled using the `DiffusionPipeline.enable_model_cpu_offload()` method.
It works by offloading each component of a pipeline to the CPU for storage, and onloading it to the accelerator
device when needed for computation. This method is more memory-efficient than keeping all components on the
accelerator, but the memory requirements are still quite high. For this method to work, one needs memory
equivalent to the size of the model in its runtime dtype plus the size of the largest intermediate activation
tensors to be able to complete the forward pass.
- Leaf-level: In Diffusers, this can be enabled using the `DiffusionPipeline.enable_sequential_cpu_offload()`
method. It works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage,
and onloading only the leaves to the accelerator device for computation. This uses the lowest amount of
accelerator memory, but can be slower due to the excessive number of device synchronizations.

Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers
(either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations
is reduced.

Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability
to overlap data transfer and computation to reduce the overall execution time compared to sequential
offloading. This is enabled using layer prefetching with streams, i.e., the layer that is to be executed next
starts onloading to the accelerator device while the current layer is being executed; this increases the
memory requirements slightly. Note that this implementation also supports leaf-level offloading, which can be
made much faster when using streams.

Args:
onload_device (`torch.device`):
The device to which the group of modules is onloaded.
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
The device to which the group of modules is offloaded. This should typically be the CPU.
offload_type (`str` or `GroupOffloadingType`, defaults to `"block_level"`):
The type of offloading to be applied. Can be one of `"block_level"` or `"leaf_level"`.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
limited RAM environment settings where a reasonable speed-memory trade-off is desired.
num_blocks_per_group (`int`, *optional*):
The number of blocks per group. This is required when using `offload_type="block_level"`.
non_blocking (`bool`, defaults to `False`):
If True, offloading and onloading is done with non-blocking data transfer.
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for
more details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them.
This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be
useful when the CPU memory is a bottleneck but may counteract the benefits of using streams.
exclude_modules (`Union[str, List[str]]`, defaults to `None`): The name of a module or a list of module names to exclude from group offloading.

Example:
```python
>>> from diffusers import DiffusionPipeline
>>> import torch

>>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

>>> pipe.enable_group_offload(
... onload_device=torch.device("cuda"),
... offload_device=torch.device("cpu"),
... offload_type="leaf_level",
... use_stream=True,
... )
>>> image = pipe("a beautiful sunset").images[0]
```
"""
from ..hooks import apply_group_offloading

if isinstance(exclude_modules, str):
exclude_modules = [exclude_modules]
elif exclude_modules is None:
exclude_modules = []

unknown = set(exclude_modules) - self.components.keys()
if unknown:
logger.info(
f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."
)

group_offload_kwargs = {
"onload_device": onload_device,
"offload_device": offload_device,
"offload_type": offload_type,
"num_blocks_per_group": num_blocks_per_group,
"non_blocking": non_blocking,
"use_stream": use_stream,
"record_stream": record_stream,
"low_cpu_mem_usage": low_cpu_mem_usage,
"offload_to_disk_path": offload_to_disk_path,
}
for name, component in self.components.items():
if name not in exclude_modules and isinstance(component, torch.nn.Module):
if hasattr(component, "enable_group_offload"):
component.enable_group_offload(**group_offload_kwargs)
else:
apply_group_offloading(module=component, **group_offload_kwargs)

if exclude_modules:
for module_name in exclude_modules:
module = getattr(self, module_name, None)
if module is not None and isinstance(module, torch.nn.Module):
module.to(onload_device)
logger.debug(f"Placed `{module_name}` on {onload_device} device as it was in `exclude_modules`.")

def reset_device_map(self):
r"""
Resets the device maps (if any) to None.
68 changes: 68 additions & 0 deletions tests/pipelines/test_pipelines_common.py
@@ -9,6 +9,7 @@

import numpy as np
import PIL.Image
import pytest
import torch
import torch.nn as nn
from huggingface_hub import ModelCard, delete_repo
@@ -2362,6 +2363,73 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4
max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
self.assertLess(max_diff, expected_max_difference)

@require_torch_accelerator
def test_pipeline_level_group_offloading_sanity_checks(self):
components = self.get_dummy_components()
pipe: DiffusionPipeline = self.pipeline_class(**components)

for name, component in pipe.components.items():
if hasattr(component, "_supports_group_offloading"):
if not component._supports_group_offloading:
pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")

module_names = sorted(
[name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]
)
exclude_module_name = module_names[0]
offload_device = "cpu"
pipe.enable_group_offload(
onload_device=torch_device,
offload_device=offload_device,
offload_type="leaf_level",
exclude_modules=exclude_module_name,
)
excluded_module = getattr(pipe, exclude_module_name)
self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type)

for name, component in pipe.components.items():
if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
# `component.device` prints the `onload_device` type. We should probably override the
# `device` property in `ModelMixin`.
component_device = next(component.parameters()).device
self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)

@require_torch_accelerator
def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe: DiffusionPipeline = self.pipeline_class(**components)

for name, component in pipe.components.items():
if hasattr(component, "_supports_group_offloading"):
if not component._supports_group_offloading:
pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")

# Regular inference.
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
torch.manual_seed(0)
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = torch.manual_seed(0)
out = pipe(**inputs)[0]

pipe.to("cpu")
del pipe

# Inference with offloading
pipe: DiffusionPipeline = self.pipeline_class(**components)
offload_device = "cpu"
pipe.enable_group_offload(
onload_device=torch_device,
offload_device=offload_device,
offload_type="leaf_level",
)
pipe.set_progress_bar_config(disable=None)
inputs["generator"] = torch.manual_seed(0)
out_offload = pipe(**inputs)[0]

max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
self.assertLess(max_diff, expected_max_difference)


@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):