From 25d9c70d1ceb1239bdaa68b1b4315514fff5715e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Sep 2025 11:11:11 +0530 Subject: [PATCH 1/4] feat: support group offloading at the pipeline level. --- src/diffusers/pipelines/pipeline_utils.py | 135 ++++++++++++++++++++++ 1 file changed, 135 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 023feae4dd27..a6e87ca85fe5 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1334,6 +1334,141 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un offload_buffers = len(model._parameters) > 0 cpu_offload(model, device, offload_buffers=offload_buffers) + def enable_group_offload( + self, + onload_device: torch.device, + offload_device: torch.device = torch.device("cpu"), + offload_type: str = "block_level", + num_blocks_per_group: Optional[int] = None, + non_blocking: bool = False, + use_stream: bool = False, + record_stream: bool = False, + low_cpu_mem_usage=False, + offload_to_disk_path: Optional[str] = None, + exclude_modules: Optional[Union[str, List[str]]] = None, + ) -> None: + r""" + Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, + and where it is beneficial, we need to first provide some context on how other supported offloading methods + work. + + Typically, offloading is done at two levels: + - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It + works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator + device when needed for computation. This method is more memory-efficient than keeping all components on the + accelerator, but the memory requirements are still quite high. For this method to work, one needs memory + equivalent to size of the model in runtime dtype + size of largest intermediate activation tensors to be able + to complete the forward pass. + - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. + It + works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and + onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator + memory, but can be slower due to the excessive number of device synchronizations. + + Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers, + (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level + offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations + is reduced. + + Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability + to overlap data transfer and computation to reduce the overall execution time compared to sequential + offloading. This is enabled using layer prefetching with streams, i.e., the layer that is to be executed next + starts onloading to the accelerator device while the current layer is being executed - this increases the + memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made + much faster when using streams. + + Args: + onload_device (`torch.device`): + The device to which the group of modules are onloaded. + offload_device (`torch.device`, defaults to `torch.device("cpu")`): + The device to which the group of modules are offloaded. This should typically be the CPU. Default is + CPU. + offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"): + The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is + "block_level". + offload_to_disk_path (`str`, *optional*, defaults to `None`): + The path to the directory where parameters will be offloaded. Setting this option can be useful in + limited RAM environment settings where a reasonable speed-memory trade-off is desired. + num_blocks_per_group (`int`, *optional*): + The number of blocks per group when using offload_type="block_level". This is required when using + offload_type="block_level". + non_blocking (`bool`, defaults to `False`): + If True, offloading and onloading is done with non-blocking data transfer. + use_stream (`bool`, defaults to `False`): + If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for + overlapping computation and data transfer. + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to + the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) + more details. + low_cpu_mem_usage (`bool`, defaults to `False`): + If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. + This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be + useful when the CPU memory is a bottleneck but may counteract the benefits of using streams. + exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading. + + Example: + ```python + >>> from diffusers import DiffusionPipeline + >>> import torch + + >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) + + >>> pipe.enable_group_offload( + ... onload_device=torch.device("cuda"), + ... offload_device=torch.device("cpu"), + ... offload_type="leaf_level", + ... use_stream=True, + ... ) + >>> image = pipe("a beautiful sunset").images[0] + ``` + """ + from ..hooks import apply_group_offloading + + if exclude_modules is not None and isinstance(exclude_modules, str): + exclude_modules = [exclude_modules] + + unknown = set(exclude_modules) - set(self.components.keys()) + if unknown: + logger.info( + f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected." + ) + + for name, component in self.components.items(): + if name not in exclude_modules and isinstance(component, torch.nn.Module): + if hasattr(component, "enable_group_offload"): + component.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + non_blocking=non_blocking, + use_stream=use_stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, + offload_to_disk_path=offload_to_disk_path, + ) + else: + apply_group_offloading( + module=component, + onload_device=onload_device, + offload_device=offload_device, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + non_blocking=non_blocking, + use_stream=use_stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, + offload_to_disk_path=offload_to_disk_path, + ) + + if exclude_modules: + for module_name in exclude_modules: + module = getattr(self, module_name, None) + if module is not None and isinstance(module, torch.nn.Module): + module.to(onload_device) + logger.debug(f"Placed `{module_name}` on {onload_device} device as it was in `exclude_modules`.") + def reset_device_map(self): r""" Resets the device maps (if any) to None. From e141f5cfd0bccf692920bef9ff54642864432db8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 4 Sep 2025 11:49:54 +0530 Subject: [PATCH 2/4] add tests --- src/diffusers/pipelines/pipeline_utils.py | 6 +- tests/pipelines/test_pipelines_common.py | 68 +++++++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a6e87ca85fe5..04e33e655ade 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1425,10 +1425,12 @@ def enable_group_offload( """ from ..hooks import apply_group_offloading - if exclude_modules is not None and isinstance(exclude_modules, str): + if isinstance(exclude_modules, str): exclude_modules = [exclude_modules] + elif exclude_modules is None: + exclude_modules = [] - unknown = set(exclude_modules) - set(self.components.keys()) + unknown = set(exclude_modules) - self.components.keys() if unknown: logger.info( f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected." diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index dcef33897e6a..db8209835be4 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -9,6 +9,7 @@ import numpy as np import PIL.Image +import pytest import torch import torch.nn as nn from huggingface_hub import ModelCard, delete_repo @@ -2362,6 +2363,73 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4 max_diff = np.abs(to_np(out) - to_np(loaded_out)).max() self.assertLess(max_diff, expected_max_difference) + @require_torch_accelerator + def test_pipeline_level_group_offloading_sanity_checks(self): + components = self.get_dummy_components() + pipe: DiffusionPipeline = self.pipeline_class(**components) + + for name, component in pipe.components.items(): + if hasattr(component, "_supports_group_offloading"): + if not component._supports_group_offloading: + pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.") + + module_names = sorted( + [name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)] + ) + exclude_module_name = module_names[0] + offload_device = "cpu" + pipe.enable_group_offload( + onload_device=torch_device, + offload_device=offload_device, + offload_type="leaf_level", + exclude_modules=exclude_module_name, + ) + excluded_module = getattr(pipe, exclude_module_name) + self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type) + + for name, component in pipe.components.items(): + if name not in [exclude_module_name] and isinstance(component, torch.nn.Module): + # `component.device` prints the `onload_device` type. We should probably override the + # `device` property in `ModelMixin`. + component_device = next(component.parameters())[0].device + self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type) + + @require_torch_accelerator + def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4): + components = self.get_dummy_components() + pipe: DiffusionPipeline = self.pipeline_class(**components) + + for name, component in pipe.components.items(): + if hasattr(component, "_supports_group_offloading"): + if not component._supports_group_offloading: + pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.") + + # Regular inference. + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + torch.manual_seed(0) + inputs = self.get_dummy_inputs(torch_device) + inputs["generator"] = torch.manual_seed(0) + out = pipe(**inputs)[0] + + pipe.to("cpu") + del pipe + + # Inference with offloading + pipe: DiffusionPipeline = self.pipeline_class(**components) + offload_device = "cpu" + pipe.enable_group_offload( + onload_device=torch_device, + offload_device=offload_device, + offload_type="leaf_level", + ) + pipe.set_progress_bar_config(disable=None) + inputs["generator"] = torch.manual_seed(0) + out_offload = pipe(**inputs)[0] + + max_diff = np.abs(to_np(out) - to_np(out_offload)).max() + self.assertLess(max_diff, expected_max_difference) + @is_staging_test class PipelinePushToHubTester(unittest.TestCase): From 8e2d0383e1e4f29eb1adf3d8581ff9a7a5fbee63 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Sep 2025 08:12:16 +0530 Subject: [PATCH 3/4] up --- src/diffusers/pipelines/pipeline_utils.py | 36 ++++++++--------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 04e33e655ade..0116ad917c00 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1436,33 +1436,23 @@ def enable_group_offload( f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected." ) + group_offload_kwargs = { + "onload_device": onload_device, + "offload_device": offload_device, + "offload_type": offload_type, + "num_blocks_per_group": num_blocks_per_group, + "non_blocking": non_blocking, + "use_stream": use_stream, + "record_stream": record_stream, + "low_cpu_mem_usage": low_cpu_mem_usage, + "offload_to_disk_path": offload_to_disk_path, + } for name, component in self.components.items(): if name not in exclude_modules and isinstance(component, torch.nn.Module): if hasattr(component, "enable_group_offload"): - component.enable_group_offload( - onload_device=onload_device, - offload_device=offload_device, - offload_type=offload_type, - num_blocks_per_group=num_blocks_per_group, - non_blocking=non_blocking, - use_stream=use_stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - offload_to_disk_path=offload_to_disk_path, - ) + component.enable_group_offload(**group_offload_kwargs) else: - apply_group_offloading( - module=component, - onload_device=onload_device, - offload_device=offload_device, - offload_type=offload_type, - num_blocks_per_group=num_blocks_per_group, - non_blocking=non_blocking, - use_stream=use_stream, - record_stream=record_stream, - low_cpu_mem_usage=low_cpu_mem_usage, - offload_to_disk_path=offload_to_disk_path, - ) + apply_group_offloading(module=component, **group_offload_kwargs) if exclude_modules: for module_name in exclude_modules: From 511056a4e38d998c7a9c18e952ed8133913ed6d4 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Tue, 9 Sep 2025 19:56:36 -0700 Subject: [PATCH 4/4] [docs] Pipeline group offloading (#12286) init Co-authored-by: Sayak Paul --- docs/source/en/optimization/memory.md | 49 +++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index 78fd96e0277d..611e07ec7655 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -291,13 +291,53 @@ Group offloading moves groups of internal layers ([torch.nn.ModuleList](https:// > [!WARNING] > Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism. -Call [`~ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead. - -The `offload_type` parameter can be set to `block_level` or `leaf_level`. +Enable group offloading by configuring the `offload_type` parameter to `block_level` or `leaf_level`. - `block_level` offloads groups of layers based on the `num_blocks_per_group` parameter. For example, if `num_blocks_per_group=2` on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). This drastically reduces memory requirements. - `leaf_level` offloads individual layers at the lowest level and is equivalent to [CPU offloading](#cpu-offloading). But it can be made faster if you use streams without giving up inference speed. +Group offloading is supported for entire pipelines or individual models. Applying group offloading to the entire pipeline is the easiest option while selectively applying it to individual models gives users more flexibility to use different offloading techniques for different models. + + + + +Call [`~DiffusionPipeline.enable_group_offload`] on a pipeline. + +```py +import torch +from diffusers import CogVideoXPipeline +from diffusers.hooks import apply_group_offloading +from diffusers.utils import export_to_video + +onload_device = torch.device("cuda") +offload_device = torch.device("cpu") + +pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipeline.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type="leaf_level", + use_stream=True +) + +prompt = ( + "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + "atmosphere of this unique musical performance." +) +video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB") +export_to_video(video, "output.mp4", fps=8) +``` + + + + +Call [`~ModelMixin.enable_group_offload`] on standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead. + ```py import torch from diffusers import CogVideoXPipeline @@ -328,6 +368,9 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G export_to_video(video, "output.mp4", fps=8) ``` + + + #### CUDA stream The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.