
Commit dcb6dd9

a-r-r-o-w, DN6, and sayakpaul authored
Context Parallel w/ Ring & Ulysses & Unified Attention (#11941)
* update
* update
* add coauthor (Co-Authored-By: Dhruv Nair <[email protected]>)
* improve test
* handle ip adapter params correctly
* fix chroma qkv fusion test
* fix fastercache implementation
* fix more tests
* fight more tests
* add back set_attention_backend
* update
* update
* make style
* make fix-copies
* make ip adapter processor compatible with attention dispatcher
* refactor chroma as well
* remove rmsnorm assert
* minify and deprecate npu/xla processors
* update
* refactor
* refactor; support flash attention 2 with cp
* fix
* support sage attention with cp
* make torch compile compatible
* update
* refactor
* update
* refactor
* refactor
* add ulysses backward
* try to make dreambooth script work; accelerator backward not playing well
* Revert "try to make dreambooth script work; accelerator backward not playing well" (reverts commit 768d0ea)
* workaround compilation problems with triton when doing all-to-all
* support wan
* handle backward correctly
* support qwen
* support ltx
* make fix-copies
* Update src/diffusers/models/modeling_utils.py (Co-authored-by: Dhruv Nair <[email protected]>)
* apply review suggestions
* update docs
* add explanation
* make fix-copies
* add docstrings
* support passing parallel_config to from_pretrained
* apply review suggestions
* make style
* update
* Update docs/source/en/api/parallel.md (Co-authored-by: Aryan <[email protected]>)
* up

---------

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: sayakpaul <[email protected]>
1 parent 043ab25 commit dcb6dd9

16 files changed: +1571 −174 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,8 @@
     title: Reduce memory usage
   - local: optimization/speed-memory-optims
     title: Compiling and offloading quantized models
+  - local: api/parallel
+    title: Parallel inference
   - title: Community optimizations
     sections:
     - local: optimization/pruna

docs/source/en/api/parallel.md

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# Parallelism

Parallelism strategies help speed up diffusion transformers by distributing computation across multiple devices, reducing inference and training time.

## ParallelConfig

[[autodoc]] ParallelConfig

## ContextParallelConfig

[[autodoc]] ContextParallelConfig

[[autodoc]] hooks.apply_context_parallel
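
The new doc page only lists the API reference, so here is a hedged usage sketch (not part of this commit's diff). It assumes `ContextParallelConfig` exposes a `ring_degree` argument, in line with the Ring/Ulysses support named in the commit title, that `from_pretrained` accepts `parallel_config` as stated in the commit message, and that the Wan checkpoint id and prompt are purely illustrative.

```python
# Hedged sketch: context-parallel inference with a Wan transformer. Only the existence of
# ContextParallelConfig and the `parallel_config` argument to from_pretrained are taken
# from this commit; `ring_degree`, the checkpoint id, and the prompt are assumptions.
# Launch with: torchrun --nproc_per_node=2 cp_inference.py
import torch
import torch.distributed as dist

from diffusers import ContextParallelConfig, WanPipeline, WanTransformer3DModel

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=ContextParallelConfig(ring_degree=2),  # assumed parameter name
)
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

video = pipe("A cat surfing a wave at sunset", num_frames=33).frames[0]
if rank == 0:
    print(f"rank 0 generated {len(video)} frames")

dist.destroy_process_group()
```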

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -202,6 +202,7 @@
         "CogView4Transformer2DModel",
         "ConsisIDTransformer3DModel",
         "ConsistencyDecoderVAE",
+        "ContextParallelConfig",
         "ControlNetModel",
         "ControlNetUnionModel",
         "ControlNetXSAdapter",
@@ -229,6 +230,7 @@
         "MultiAdapter",
         "MultiControlNetModel",
         "OmniGenTransformer2DModel",
+        "ParallelConfig",
         "PixArtTransformer2DModel",
         "PriorTransformer",
         "QwenImageControlNetModel",
@@ -888,6 +890,7 @@
         CogView4Transformer2DModel,
         ConsisIDTransformer3DModel,
         ConsistencyDecoderVAE,
+        ContextParallelConfig,
         ControlNetModel,
         ControlNetUnionModel,
         ControlNetXSAdapter,
@@ -915,6 +918,7 @@
         MultiAdapter,
         MultiControlNetModel,
         OmniGenTransformer2DModel,
+        ParallelConfig,
         PixArtTransformer2DModel,
         PriorTransformer,
         QwenImageControlNetModel,
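
Since both names are added to the library's top-level export lists above, a quick sanity check (assuming an environment with this commit installed) is simply importing them from the package root:

```python
# The new config classes resolve from the package root after this change.
from diffusers import ContextParallelConfig, ParallelConfig

print(ContextParallelConfig.__name__, ParallelConfig.__name__)
```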

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@


 if is_torch_available():
+    from .context_parallel import apply_context_parallel
     from .faster_cache import FasterCacheConfig, apply_faster_cache
     from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
     from .group_offloading import apply_group_offloading
src/diffusers/hooks/context_parallel.py

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union

import torch
import torch.distributed._functional_collectives as funcol

from ..models._modeling_parallel import (
    ContextParallelConfig,
    ContextParallelInput,
    ContextParallelModelPlan,
    ContextParallelOutput,
)
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"


# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
    cached_parameter_indices: Dict[str, int] = None
    _cls: Type = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        kwargs = kwargs or {}

        if identifier in kwargs:
            return kwargs[identifier], True, None

        if self.cached_parameter_indices is not None:
            index = self.cached_parameter_indices.get(identifier, None)
            if index is None:
                raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
            return args[index], False, index

        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")

        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
        parameters = parameters[1:]  # skip `self`
        self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        if identifier not in self.cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")

        index = self.cached_parameter_indices[identifier]

        if index >= len(args):
            raise ValueError(f"Expected {index} arguments but got {len(args)}.")

        return args[index], False, index


def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ContextParallelConfig,
    plan: Dict[str, ContextParallelModelPlan],
) -> None:
    """Apply context parallel on a model."""
    logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")

    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")

        for m in submodule:
            if isinstance(cp_model_plan, dict):
                hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                if isinstance(cp_model_plan, ContextParallelOutput):
                    cp_model_plan = [cp_model_plan]
                if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
                    raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)


def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        for m in submodule:
            registry = HookRegistry.check_if_exists_or_initialize(m)
            if isinstance(cp_model_plan, dict):
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry.remove_hook(hook_name)


class ContextParallelSplitHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config
        self.module_forward_metadata = None

    def initialize_hook(self, module):
        cls = unwrap_module(module).__class__
        self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
        return module

    def pre_forward(self, module, *args, **kwargs):
        args_list = list(args)

        for name, cpm in self.metadata.items():
            if isinstance(cpm, ContextParallelInput) and cpm.split_output:
                continue

            # Maybe the parameter was passed as a keyword argument
            input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
                name, args_list, kwargs
            )

            if input_val is None:
                continue

            # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
            # the output instead of input for a particular layer by setting split_output=True
            if isinstance(input_val, torch.Tensor):
                input_val = self._prepare_cp_input(input_val, cpm)
            elif isinstance(input_val, (list, tuple)):
                if len(input_val) != len(cpm):
                    raise ValueError(
                        f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
                    )
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    if torch.is_tensor(x) and not cpm[i].split_output:
                        x = self._prepare_cp_input(x, cpm[i])
                    sharded_input_val.append(x)
                input_val = sharded_input_val
            else:
                raise ValueError(f"Unsupported input type: {type(input_val)}")

            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val
            else:
                raise ValueError(
                    f"An unexpected error occurred while processing the input '{name}'. Please open an "
                    f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
                    f"example along with the full stack trace."
                )

        return tuple(args_list), kwargs

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)
        is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)

        if not is_tensor and not is_tensor_list:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = [output] if is_tensor else list(output)
        for index, cpm in self.metadata.items():
            if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
                continue
            if index >= len(output):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
            current_output = output[index]
            current_output = self._prepare_cp_input(current_output, cpm)
            output[index] = current_output

        return output[0] if is_tensor else tuple(output)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


class ContextParallelGatherHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)

        if is_tensor:
            output = [output]
        elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = list(output)

        if len(output) != len(self.metadata):
            raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")

        for i, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)

        return output[0] if is_tensor else tuple(output)


class AllGatherFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, dim, group):
        ctx.dim = dim
        ctx.group = group
        ctx.world_size = torch.distributed.get_world_size(group)
        ctx.rank = torch.distributed.get_rank(group)
        return funcol.all_gather_tensor(tensor, dim, group=group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
        return grad_chunks[ctx.rank], None, None


class EquipartitionSharder:
    @classmethod
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        # NOTE: the following assertion does not have to be true in general. We simply enforce it for now
        # because the alternate case has not yet been tested/required for any model.
        assert tensor.size()[dim] % mesh.size() == 0, (
            "Tensor size along dimension to be sharded must be divisible by mesh size"
        )

        # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
        # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]

        return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]

    @classmethod
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
        return tensor


def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name.count("*") > 1:
        raise ValueError("Wildcard '*' can only be used once in the name")
    return _find_submodule_by_name(model, name)


def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name == "":
        return model
    first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
    if first_atom == "*":
        if not isinstance(model, torch.nn.ModuleList):
            raise ValueError("Wildcard '*' can only be used with ModuleList")
        submodules = []
        for submodule in model:
            subsubmodules = _find_submodule_by_name(submodule, remaining_name)
            if not isinstance(subsubmodules, list):
                subsubmodules = [subsubmodules]
            submodules.extend(subsubmodules)
        return submodules
    else:
        if hasattr(model, first_atom):
            submodule = getattr(model, first_atom)
            return _find_submodule_by_name(submodule, remaining_name)
        else:
            raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
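
To make the hook wiring above concrete, here is a hedged sketch of what a context-parallel plan could look like for a toy model. The field names (`split_dim`, `expected_dims`, `split_output`, `gather_dim`) are taken from how the hooks in this file read them, but the keyword constructors, the `ring_degree` argument, and the module names are assumptions; real plans ship with each supported transformer, and the config's device mesh is initialized by the library before `apply_context_parallel` runs.

```python
# Hedged sketch (assumptions noted above): a minimal plan for a hypothetical transformer
# that has a `blocks` ModuleList and a final `proj_out` layer.
from diffusers.hooks import apply_context_parallel
from diffusers.models._modeling_parallel import (
    ContextParallelConfig,
    ContextParallelInput,
    ContextParallelOutput,
)

toy_plan = {
    # Split hook: shard `hidden_states` (B, S, D) along the sequence dim before block 0.
    "blocks.0": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
    },
    # Gather hook: all-gather the final projection output back along the sequence dim.
    "proj_out": ContextParallelOutput(gather_dim=1),
}

config = ContextParallelConfig(ring_degree=2)  # assumed argument name
# The config's `_mesh` / `_flattened_mesh` must already be set up (handled by the
# modeling_utils changes in this commit when `parallel_config` is passed to
# `from_pretrained`), so the call is shown rather than executed standalone:
# apply_context_parallel(transformer, parallel_config=config, plan=toy_plan)
```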

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
 _import_structure = {}

 if is_torch_available():
+    _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
@@ -119,6 +120,7 @@

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
+        from ._modeling_parallel import ContextParallelConfig, ParallelConfig
         from .adapter import MultiAdapter, T2IAdapter
         from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
