-
Notifications
You must be signed in to change notification settings - Fork 210
Description
Describe the bug
Hi, great work! However, I ran into some bugs when running the script "distill_dmd_VSA_t2v_1.3B" with DATA_DIR="FastVideo/Wan-Syn_77x448x832_600k/train". Could you help me with this? Thanks.
Reproduction
#!/bin/bash
# Single-node DMD distillation launch for Wan2.1-T2V-1.3B using the
# VIDEO_SPARSE_ATTN attention backend, started directly with torchrun
# (no SLURM).
#
# NOTE(review): the original annotations said "8 GPUs", but every setting in
# this script (--num_gpus, --sp_size, --hsdp_shard_dim, --nproc_per_node)
# is 4. The values are kept at 4; adjust them together if 8 was intended.
set -e -x

##############################
# 1. Base environment (same as the original script)
##############################
source ~/conda/miniconda/bin/activate
conda activate your_env
export WANDB_MODE=online
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
# $$ is this shell's PID, so each launch gets its own Triton cache directory.
export TRITON_CACHE_DIR="/tmp/triton_cache_$$"
export TOKENIZERS_PARALLELISM=false
export WANDB_BASE_URL="https://api.wandb.ai"
export FASTVIDEO_ATTENTION_BACKEND=VIDEO_SPARSE_ATTN

##############################
# 2. Paths & hyperparameters (same as the original script)
##############################
# The 14B assignments were immediately overwritten by the 1.3B ones, so they
# never took effect; they are kept here only as commented-out alternatives.
# MODEL_PATH="Wan-AI/Wan2.1-T2V-14B-Diffusers"
# REAL_SCORE_MODEL_PATH="Wan-AI/Wan2.1-T2V-14B-Diffusers"
# FAKE_SCORE_MODEL_PATH="Wan-AI/Wan2.1-T2V-14B-Diffusers"
FAKE_SCORE_MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
REAL_SCORE_MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
DATA_DIR="FastVideo/Wan-Syn_77x448x832_600k/train"
VALIDATION_DATASET_FILE="your_validation_dataset_file"
OUTPUT_DIR="checkpoints/wan_t2v_finetune"

# Number of worker processes on this node. Shared by --num_gpus and
# torchrun's --nproc_per_node so the two values cannot drift apart.
NUM_GPUS=4

##############################
# 3. Training argument arrays (copied verbatim)
##############################
parallel_args=(
  --num_gpus "$NUM_GPUS"
  --sp_size 4
  --tp_size 1
  --hsdp_replicate_dim 1
  --hsdp_shard_dim 4
)

model_args=(
  --model_path "$MODEL_PATH"
  --pretrained_model_name_or_path "$MODEL_PATH"
  --real_score_model_path "$REAL_SCORE_MODEL_PATH"
  --fake_score_model_path "$FAKE_SCORE_MODEL_PATH"
)

dataset_args=(
  --data_path "$DATA_DIR"
  --dataloader_num_workers 4
)

# NOTE(review): the dataset path is named "77x448x832", while the settings
# below use 81 frames at 480x480 with 21 latent frames. If the preprocessed
# data really is 77 frames at 448x832, this mismatch is a likely cause of the
# reported failure -- confirm against the dataset and the upstream script.
training_args=(
  --tracker_project_name wan_t2v_distill_dmd_VSA
  --output_dir "$OUTPUT_DIR"
  --max_train_steps 4000
  --train_batch_size 1
  --train_sp_batch_size 1
  --gradient_accumulation_steps 1
  --num_latent_t 21
  --num_height 480
  --num_width 480
  --num_frames 81
  --enable_gradient_checkpointing_type "full"
)

optimizer_args=(
  --learning_rate 2e-6
  --mixed_precision "bf16"
  --training_state_checkpointing_steps 500
  --weight_only_checkpointing_steps 500
  --weight_decay 0.01
  --max_grad_norm 1.0
)

validation_args=(
  --validation_dataset_file "$VALIDATION_DATASET_FILE"
  --validation_steps 200
  --validation_sampling_steps "3"
  --validation_guidance_scale "6.0"
)

miscellaneous_args=(
  --inference_mode False
  --checkpoints_total_limit 3
  --training_cfg_rate 0.0
  --dit_precision "fp32"
  --ema_start_step 0
  --flow_shift 3
  --seed 1000
)

dmd_args=(
  --dmd_denoising_steps '1000,757,522'
  --min_timestep_ratio 0.02
  --max_timestep_ratio 0.98
  --generator_update_interval 5
  --real_score_guidance_scale 3.5
  --VSA_sparsity 0.9
)

##############################
# 4. Launch training
##############################
# Single node; no --master_addr is passed, only the rendezvous port.
# Each trailing backslash is required: without it torchrun would run with no
# arguments and the following lines would be executed as separate commands.
torchrun \
  --nnodes=1 \
  --nproc_per_node="$NUM_GPUS" \
  --master_port=29500 \
  fastvideo/training/wan_distillation_pipeline.py \
  "${parallel_args[@]}" \
  "${model_args[@]}" \
  "${dataset_args[@]}" \
  "${training_args[@]}" \
  "${optimizer_args[@]}" \
  "${validation_args[@]}" \
  "${miscellaneous_args[@]}" \
  "${dmd_args[@]}"
Environment
linux
torch 2.7.1
fastvideo 0.1.6 /root/wangzairan01/projects/FastVideo
cuda12.4(A800)
Python 3.12.12