|
| 1 | +# Copyright 2025 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import inspect |
| 16 | +from dataclasses import dataclass |
| 17 | +from typing import Dict, List, Type, Union |
| 18 | + |
| 19 | +import torch |
| 20 | +import torch.distributed._functional_collectives as funcol |
| 21 | + |
| 22 | +from ..models._modeling_parallel import ( |
| 23 | + ContextParallelConfig, |
| 24 | + ContextParallelInput, |
| 25 | + ContextParallelModelPlan, |
| 26 | + ContextParallelOutput, |
| 27 | +) |
| 28 | +from ..utils import get_logger |
| 29 | +from ..utils.torch_utils import unwrap_module |
| 30 | +from .hooks import HookRegistry, ModelHook |
| 31 | + |
| 32 | + |
| 33 | +logger = get_logger(__name__) # pylint: disable=invalid-name |
| 34 | + |
| 35 | +_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}" |
| 36 | +_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}" |
| 37 | + |
| 38 | + |
| 39 | +# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata |
| 40 | +@dataclass |
| 41 | +class ModuleForwardMetadata: |
| 42 | + cached_parameter_indices: Dict[str, int] = None |
| 43 | + _cls: Type = None |
| 44 | + |
| 45 | + def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None): |
| 46 | + kwargs = kwargs or {} |
| 47 | + |
| 48 | + if identifier in kwargs: |
| 49 | + return kwargs[identifier], True, None |
| 50 | + |
| 51 | + if self.cached_parameter_indices is not None: |
| 52 | + index = self.cached_parameter_indices.get(identifier, None) |
| 53 | + if index is None: |
| 54 | + raise ValueError(f"Parameter '{identifier}' not found in cached indices.") |
| 55 | + return args[index], False, index |
| 56 | + |
| 57 | + if self._cls is None: |
| 58 | + raise ValueError("Model class is not set for metadata.") |
| 59 | + |
| 60 | + parameters = list(inspect.signature(self._cls.forward).parameters.keys()) |
| 61 | + parameters = parameters[1:] # skip `self` |
| 62 | + self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)} |
| 63 | + |
| 64 | + if identifier not in self.cached_parameter_indices: |
| 65 | + raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.") |
| 66 | + |
| 67 | + index = self.cached_parameter_indices[identifier] |
| 68 | + |
| 69 | + if index >= len(args): |
| 70 | + raise ValueError(f"Expected {index} arguments but got {len(args)}.") |
| 71 | + |
| 72 | + return args[index], False, index |
| 73 | + |
| 74 | + |
| 75 | +def apply_context_parallel( |
| 76 | + module: torch.nn.Module, |
| 77 | + parallel_config: ContextParallelConfig, |
| 78 | + plan: Dict[str, ContextParallelModelPlan], |
| 79 | +) -> None: |
| 80 | + """Apply context parallel on a model.""" |
| 81 | + logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}") |
| 82 | + |
| 83 | + for module_id, cp_model_plan in plan.items(): |
| 84 | + submodule = _get_submodule_by_name(module, module_id) |
| 85 | + if not isinstance(submodule, list): |
| 86 | + submodule = [submodule] |
| 87 | + |
| 88 | + logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules") |
| 89 | + |
| 90 | + for m in submodule: |
| 91 | + if isinstance(cp_model_plan, dict): |
| 92 | + hook = ContextParallelSplitHook(cp_model_plan, parallel_config) |
| 93 | + hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id) |
| 94 | + elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)): |
| 95 | + if isinstance(cp_model_plan, ContextParallelOutput): |
| 96 | + cp_model_plan = [cp_model_plan] |
| 97 | + if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan): |
| 98 | + raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}") |
| 99 | + hook = ContextParallelGatherHook(cp_model_plan, parallel_config) |
| 100 | + hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id) |
| 101 | + else: |
| 102 | + raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}") |
| 103 | + registry = HookRegistry.check_if_exists_or_initialize(m) |
| 104 | + registry.register_hook(hook, hook_name) |
| 105 | + |
| 106 | + |
| 107 | +def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None: |
| 108 | + for module_id, cp_model_plan in plan.items(): |
| 109 | + submodule = _get_submodule_by_name(module, module_id) |
| 110 | + if not isinstance(submodule, list): |
| 111 | + submodule = [submodule] |
| 112 | + |
| 113 | + for m in submodule: |
| 114 | + registry = HookRegistry.check_if_exists_or_initialize(m) |
| 115 | + if isinstance(cp_model_plan, dict): |
| 116 | + hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id) |
| 117 | + elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)): |
| 118 | + hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id) |
| 119 | + else: |
| 120 | + raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}") |
| 121 | + registry.remove_hook(hook_name) |
| 122 | + |
| 123 | + |
| 124 | +class ContextParallelSplitHook(ModelHook): |
| 125 | + def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None: |
| 126 | + super().__init__() |
| 127 | + self.metadata = metadata |
| 128 | + self.parallel_config = parallel_config |
| 129 | + self.module_forward_metadata = None |
| 130 | + |
| 131 | + def initialize_hook(self, module): |
| 132 | + cls = unwrap_module(module).__class__ |
| 133 | + self.module_forward_metadata = ModuleForwardMetadata(_cls=cls) |
| 134 | + return module |
| 135 | + |
| 136 | + def pre_forward(self, module, *args, **kwargs): |
| 137 | + args_list = list(args) |
| 138 | + |
| 139 | + for name, cpm in self.metadata.items(): |
| 140 | + if isinstance(cpm, ContextParallelInput) and cpm.split_output: |
| 141 | + continue |
| 142 | + |
| 143 | + # Maybe the parameter was passed as a keyword argument |
| 144 | + input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs( |
| 145 | + name, args_list, kwargs |
| 146 | + ) |
| 147 | + |
| 148 | + if input_val is None: |
| 149 | + continue |
| 150 | + |
| 151 | + # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard |
| 152 | + # the output instead of input for a particular layer by setting split_output=True |
| 153 | + if isinstance(input_val, torch.Tensor): |
| 154 | + input_val = self._prepare_cp_input(input_val, cpm) |
| 155 | + elif isinstance(input_val, (list, tuple)): |
| 156 | + if len(input_val) != len(cpm): |
| 157 | + raise ValueError( |
| 158 | + f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}." |
| 159 | + ) |
| 160 | + sharded_input_val = [] |
| 161 | + for i, x in enumerate(input_val): |
| 162 | + if torch.is_tensor(x) and not cpm[i].split_output: |
| 163 | + x = self._prepare_cp_input(x, cpm[i]) |
| 164 | + sharded_input_val.append(x) |
| 165 | + input_val = sharded_input_val |
| 166 | + else: |
| 167 | + raise ValueError(f"Unsupported input type: {type(input_val)}") |
| 168 | + |
| 169 | + if is_kwarg: |
| 170 | + kwargs[name] = input_val |
| 171 | + elif index is not None and index < len(args_list): |
| 172 | + args_list[index] = input_val |
| 173 | + else: |
| 174 | + raise ValueError( |
| 175 | + f"An unexpected error occurred while processing the input '{name}'. Please open an " |
| 176 | + f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible " |
| 177 | + f"example along with the full stack trace." |
| 178 | + ) |
| 179 | + |
| 180 | + return tuple(args_list), kwargs |
| 181 | + |
| 182 | + def post_forward(self, module, output): |
| 183 | + is_tensor = isinstance(output, torch.Tensor) |
| 184 | + is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output) |
| 185 | + |
| 186 | + if not is_tensor and not is_tensor_list: |
| 187 | + raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.") |
| 188 | + |
| 189 | + output = [output] if is_tensor else list(output) |
| 190 | + for index, cpm in self.metadata.items(): |
| 191 | + if not isinstance(cpm, ContextParallelInput) or not cpm.split_output: |
| 192 | + continue |
| 193 | + if index >= len(output): |
| 194 | + raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.") |
| 195 | + current_output = output[index] |
| 196 | + current_output = self._prepare_cp_input(current_output, cpm) |
| 197 | + output[index] = current_output |
| 198 | + |
| 199 | + return output[0] if is_tensor else tuple(output) |
| 200 | + |
| 201 | + def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor: |
| 202 | + if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims: |
| 203 | + raise ValueError( |
| 204 | + f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions." |
| 205 | + ) |
| 206 | + return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) |
| 207 | + |
| 208 | + |
| 209 | +class ContextParallelGatherHook(ModelHook): |
| 210 | + def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None: |
| 211 | + super().__init__() |
| 212 | + self.metadata = metadata |
| 213 | + self.parallel_config = parallel_config |
| 214 | + |
| 215 | + def post_forward(self, module, output): |
| 216 | + is_tensor = isinstance(output, torch.Tensor) |
| 217 | + |
| 218 | + if is_tensor: |
| 219 | + output = [output] |
| 220 | + elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)): |
| 221 | + raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.") |
| 222 | + |
| 223 | + output = list(output) |
| 224 | + |
| 225 | + if len(output) != len(self.metadata): |
| 226 | + raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.") |
| 227 | + |
| 228 | + for i, cpm in enumerate(self.metadata): |
| 229 | + if cpm is None: |
| 230 | + continue |
| 231 | + output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh) |
| 232 | + |
| 233 | + return output[0] if is_tensor else tuple(output) |
| 234 | + |
| 235 | + |
| 236 | +class AllGatherFunction(torch.autograd.Function): |
| 237 | + @staticmethod |
| 238 | + def forward(ctx, tensor, dim, group): |
| 239 | + ctx.dim = dim |
| 240 | + ctx.group = group |
| 241 | + ctx.world_size = torch.distributed.get_world_size(group) |
| 242 | + ctx.rank = torch.distributed.get_rank(group) |
| 243 | + return funcol.all_gather_tensor(tensor, dim, group=group) |
| 244 | + |
| 245 | + @staticmethod |
| 246 | + def backward(ctx, grad_output): |
| 247 | + grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim) |
| 248 | + return grad_chunks[ctx.rank], None, None |
| 249 | + |
| 250 | + |
| 251 | +class EquipartitionSharder: |
| 252 | + @classmethod |
| 253 | + def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: |
| 254 | + # NOTE: the following assertion does not have to be true in general. We simply enforce it for now |
| 255 | + # because the alternate case has not yet been tested/required for any model. |
| 256 | + assert tensor.size()[dim] % mesh.size() == 0, ( |
| 257 | + "Tensor size along dimension to be sharded must be divisible by mesh size" |
| 258 | + ) |
| 259 | + |
| 260 | + # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank) |
| 261 | + # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()] |
| 262 | + |
| 263 | + return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())] |
| 264 | + |
| 265 | + @classmethod |
| 266 | + def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: |
| 267 | + tensor = tensor.contiguous() |
| 268 | + tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group()) |
| 269 | + return tensor |
| 270 | + |
| 271 | + |
| 272 | +def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: |
| 273 | + if name.count("*") > 1: |
| 274 | + raise ValueError("Wildcard '*' can only be used once in the name") |
| 275 | + return _find_submodule_by_name(model, name) |
| 276 | + |
| 277 | + |
| 278 | +def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: |
| 279 | + if name == "": |
| 280 | + return model |
| 281 | + first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "") |
| 282 | + if first_atom == "*": |
| 283 | + if not isinstance(model, torch.nn.ModuleList): |
| 284 | + raise ValueError("Wildcard '*' can only be used with ModuleList") |
| 285 | + submodules = [] |
| 286 | + for submodule in model: |
| 287 | + subsubmodules = _find_submodule_by_name(submodule, remaining_name) |
| 288 | + if not isinstance(subsubmodules, list): |
| 289 | + subsubmodules = [subsubmodules] |
| 290 | + submodules.extend(subsubmodules) |
| 291 | + return submodules |
| 292 | + else: |
| 293 | + if hasattr(model, first_atom): |
| 294 | + submodule = getattr(model, first_atom) |
| 295 | + return _find_submodule_by_name(submodule, remaining_name) |
| 296 | + else: |
| 297 | + raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'") |
0 commit comments