@@ -0,0 +1,146 @@
#!/bin/bash
#SBATCH --job-name=t2v
#SBATCH --partition=main
#SBATCH --nodes=8
#SBATCH --ntasks=8
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=128
#SBATCH --mem=1440G
#SBATCH --output=dmd_Wan2.2/t2v_g2e5_f1e5_%j.out
#SBATCH --error=dmd_Wan2.2/t2v_g2e5_f1e5_%j.err
#SBATCH --exclusive
set -e -x

# Environment Setup
source ~/conda/miniconda/bin/activate
conda activate your_env

# Basic Info
export WANDB_MODE="online"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
# give each task its own Triton cache dir to avoid cache clashes between processes
export TRITON_CACHE_DIR=/tmp/triton_cache_${SLURM_PROCID}
export MASTER_PORT=29500
export NODE_RANK=$SLURM_PROCID
nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )
export MASTER_ADDR=${nodes[0]}
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
export TOKENIZERS_PARALLELISM=false
export WANDB_BASE_URL="https://api.wandb.ai"
export FASTVIDEO_ATTENTION_BACKEND=FLASH_ATTN
export WANDB_API_KEY=your_wandb_api_key
# export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA

echo "MASTER_ADDR: $MASTER_ADDR"
echo "NODE_RANK: $NODE_RANK"

# Configs
NUM_GPUS=8 # GPUs per node (passed to torchrun --nproc_per_node)
MODEL_PATH="Wan-AI/Wan2.2-TI2V-5B-Diffusers"
DATA_DIR=your_data_dir
VALIDATION_DIR=your_validation_path # (example: validation_64.json)
# export CUDA_VISIBLE_DEVICES=4,5
# IP=[MASTER NODE IP]

# Training arguments
training_args=(
--tracker_project_name Wan_distillation
--output_dir "your_output_dir"
--max_train_steps 4000
--train_batch_size 1
--train_sp_batch_size 1
--gradient_accumulation_steps 1
--num_latent_t 31
--num_height 704
--num_width 1280
--num_frames 121
--enable_gradient_checkpointing_type "full"
--lora_rank 32
--lora_training True
)

# Parallel arguments
parallel_args=(
--num_gpus 64
--sp_size 1
--tp_size 1
--hsdp_replicate_dim 64
--hsdp_shard_dim 1
)

# Model arguments
model_args=(
--model_path $MODEL_PATH
--pretrained_model_name_or_path $MODEL_PATH
)

# Dataset arguments
dataset_args=(
--data_path "$DATA_DIR"
--dataloader_num_workers 4
)

# Validation arguments
validation_args=(
--log_validation
--validation_dataset_file "$VALIDATION_DIR"
--validation_steps 200
--validation_sampling_steps "3"
--validation_guidance_scale "6.0" # not used for dmd inference
)

# Optimizer arguments
optimizer_args=(
--learning_rate 2e-4
--lr_scheduler "cosine_with_min_lr"
--min_lr_ratio 0.5
--lr_warmup_steps 100
--fake_score_learning_rate 1e-5
--fake_score_lr_scheduler "cosine_with_min_lr"
--mixed_precision "bf16"
--training_state_checkpointing_steps 500
--weight_only_checkpointing_steps 200
--weight_decay 0.01
--max_grad_norm 1.0
)

# Miscellaneous arguments
miscellaneous_args=(
--inference_mode False
--checkpoints_total_limit 3
--training_cfg_rate 0.0
--dit_precision "fp32"
--ema_start_step 0
--flow_shift 5
--seed 1000
)

# DMD arguments
dmd_args=(
--dmd_denoising_steps '1000,757,522'
--min_timestep_ratio 0.02
--max_timestep_ratio 0.98
--generator_update_interval 5
--real_score_guidance_scale 3
--simulate_generator_forward
--log_visualization # disable if oom
)

srun torchrun \
--nnodes $SLURM_JOB_NUM_NODES \
--nproc_per_node $NUM_GPUS \
--node_rank $SLURM_PROCID \
--rdzv_backend=c10d \
--rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \
fastvideo/training/wan_distillation_pipeline.py \
"${parallel_args[@]}" \
"${model_args[@]}" \
"${dataset_args[@]}" \
"${training_args[@]}" \
"${optimizer_args[@]}" \
"${validation_args[@]}" \
"${miscellaneous_args[@]}" \
"${dmd_args[@]}"
35 changes: 30 additions & 5 deletions fastvideo/distributed/parallel_state.py
@@ -36,6 +36,7 @@

import torch
import torch.distributed
import torch.distributed as dist
from torch.distributed import Backend, ProcessGroup, ReduceOp

import fastvideo.envs as envs
@@ -739,6 +740,23 @@ def init_world_group(ranks: list[int], local_rank: int,
)


def get_node_group() -> GroupCoordinator:
assert _NODE is not None, ("node group is not initialized")
return _NODE


def init_node_group(local_rank: int, backend: str):
cpu_group = get_world_group().cpu_group
node_ranks = same_node_ranks(cpu_group)
node_size = len(node_ranks)
all_node_ranks = [
list(range(i * node_size, (i + 1) * node_size))
for i in range(dist.get_world_size() // node_size)
]
global _NODE
_NODE = init_model_parallel_group(all_node_ranks, local_rank, backend)


def init_model_parallel_group(
group_ranks: list[list[int]],
local_rank: int,
@@ -825,6 +843,8 @@ def init_distributed_environment(
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size")
# Init a group for each node
init_node_group(local_rank, backend)


_SP: GroupCoordinator | None = None
@@ -1075,17 +1095,22 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
ray.shutdown()


def is_the_same_node_as(pg: ProcessGroup | StatelessProcessGroup,
source_rank: int = 0) -> list[int]:
def same_node_ranks(pg: ProcessGroup | StatelessProcessGroup,
source_rank: int = 0) -> list[int]:
"""
This is a collective operation that returns if each rank is in the same node
This is a collective operation that returns ranks that are in the same node
as the source rank. It tests if processes are attached to the same
memory system (shared access to shared memory).
Args:
pg: the global process group to test
source_rank: the rank to test against
Returns:
A list of ranks that are in the same node as the source rank.
"""
if isinstance(pg, ProcessGroup):
assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, (
"in_the_same_node_as should be tested with a non-NCCL group.")
"same_node_ranks should be tested with a non-NCCL group.")
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)
@@ -1157,7 +1182,7 @@ def is_the_same_node_as(pg: ProcessGroup | StatelessProcessGroup,
rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
aggregated_data += rank_data

return [x == 1 for x in aggregated_data.tolist()]
return [i for i, x in enumerate(aggregated_data.tolist()) if x == 1]


def initialize_tensor_parallel_group(
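
For reference, a small illustration of how init_node_group partitions ranks, under the same assumptions the list comprehension above makes (equal-sized nodes, contiguous rank numbering):

# With 2 nodes x 4 GPUs, same_node_ranks() on rank 0 returns [0, 1, 2, 3],
# so node_size == 4 and the per-node groups come out as:
world_size, node_size = 8, 4
all_node_ranks = [
    list(range(i * node_size, (i + 1) * node_size))
    for i in range(world_size // node_size)
]
assert all_node_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]
# After init_distributed_environment(), each rank can fetch its own node group
# via get_node_group().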
8 changes: 7 additions & 1 deletion fastvideo/layers/layernorm.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor import DTensor

from fastvideo.layers.custom_op import CustomOp

@@ -78,7 +79,12 @@ def forward_native(
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
# TODO(wenxuan): with CPU offload, FSDP has a bug where the final_layer_norm
# weight is left wrapped as a DTensor instead of being unwrapped; report this upstream.
if isinstance(self.weight, DTensor):
x = x * self.weight.to_local().to(x.device)
else:
x = x * self.weight
if residual is None:
return x
else:
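
A minimal standalone sketch of the unwrap pattern above (illustrative only, not FastVideo code; assumes the weight may arrive DTensor-wrapped):

import torch
from torch.distributed.tensor import DTensor

def scale_by_weight(x: torch.Tensor, weight) -> torch.Tensor:
    # FSDP with CPU offload can hand the norm a DTensor-wrapped weight; take
    # the local shard and move it next to the activations before the
    # elementwise multiply.
    if isinstance(weight, DTensor):
        weight = weight.to_local().to(x.device)
    return x * weight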
1 change: 0 additions & 1 deletion fastvideo/models/loader/fsdp_load.py
@@ -165,7 +165,6 @@ def maybe_load_fsdp_model(

def shard_model(
model,
*,
cpu_offload: bool,
reshard_after_forward: bool = True,
mp_policy: MixedPrecisionPolicy | None = MixedPrecisionPolicy(), # noqa
61 changes: 51 additions & 10 deletions fastvideo/models/loader/weight_utils.py
@@ -11,9 +11,11 @@
import filelock
import huggingface_hub.constants
import torch
import torch.distributed as dist
from safetensors.torch import safe_open
from tqdm.auto import tqdm

from fastvideo.distributed.parallel_state import get_node_group
from fastvideo.distributed import get_local_torch_device
from fastvideo.logger import init_logger

@@ -119,12 +121,24 @@ def filter_files_not_needed_for_inference(

def safetensors_weights_iterator(
hf_weights_files: list[str],
to_cpu: bool = True,
to_cpu: bool = False,
async_broadcast: bool = False
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
device = "cpu" if to_cpu else str(get_local_torch_device())
"""Iterate over the weights in the model safetensor files.
Args:
hf_weights_files: List of safetensor files to load.
to_cpu: Whether to load the weights to CPU. If False, load to the GPU bound to the node-local rank.
async_broadcast: Whether to overlap reading from disk with broadcasting to the other ranks on the node. If True,
the caller must consume the whole iterator before using the weights. Only valid when to_cpu is False.
"""
node_group = get_node_group()
local_rank = node_group.local_rank
device = f"cuda:{local_rank}" if not to_cpu else "cpu"
enable_tqdm = not torch.distributed.is_initialized() or local_rank == 0
assert not (async_broadcast
and to_cpu), "Cannot broadcast weights when loading to CPU"

handles = []
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
@@ -133,18 +147,45 @@
):
with safe_open(st_file, framework="pt", device=device) as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
if to_cpu:
param = f.get_tensor(name)
else:
if local_rank == 0:
param = f.get_tensor(name)
else:
shape = f.get_slice(name).get_shape()
param = torch.empty(shape, device=device)
# broadcast to local ranks
# TODO(Wenxuan): scatter instead of broadcast
if node_group.world_size > 1:
group = node_group.device_group
if async_broadcast:
handle = dist.broadcast(param,
src=dist.get_global_rank(
group, 0),
async_op=True,
group=group)
handles.append(handle)
else:
dist.broadcast(param,
src=dist.get_global_rank(group, 0),
group=group)
yield name, param

if async_broadcast:
for handle in handles:
handle.wait()


def pt_weights_iterator(
hf_weights_files: list[str],
to_cpu: bool = True,
to_cpu: bool = True # default to CPU for text encoder
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
device = "cpu" if to_cpu else str(get_local_torch_device())
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
node_group = get_node_group()
local_rank = node_group.local_rank
device = f"cuda:{local_rank}" if not to_cpu else "cpu"
enable_tqdm = not torch.distributed.is_initialized() or local_rank == 0
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
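
A hedged usage sketch of the new loading path (file names and the state_dict holder are placeholders): every rank on a node iterates the same file list, but only the node-local rank 0 touches disk; the other ranks receive each tensor via an intra-node broadcast.

from fastvideo.models.loader.weight_utils import safetensors_weights_iterator

hf_weights_files = ["model-00001-of-00002.safetensors"]  # placeholder paths
state_dict = {}
for name, param in safetensors_weights_iterator(hf_weights_files,
                                                to_cpu=False,
                                                async_broadcast=True):
    state_dict[name] = param
# With async_broadcast=True the broadcasts may still be in flight inside the
# loop, so only use the tensors after the iterator has been fully consumed.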
2 changes: 1 addition & 1 deletion fastvideo/pipelines/training/__init__.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
Training pipelines for fastvideo.v1.
Training pipelines for fastvideo.

This package contains pipelines for training diffusion models.
"""