[Bug] param.grad is None or gradient is zero when reproducing FastWan #858

@dingangui

Description

Describe the bug

Error log

  Steps:   0%|▏                                                              | 64/20000 [20:09<102:30:56, 18.51s/it, total_loss=0.7242, generator_loss=0.6643, fake_score_loss=0.0599, step_time=18.71s, grad_norm=None, ema=✓, ema2=✗]Traceback (most recent call last):
    File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 76, in <module>
      main(args)
    File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 64, in main
      pipeline.train()
    File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/distillation_pipeline.py", line 1524, in train
      training_batch = self.train_one_step(training_batch)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  [rank1]: Traceback (most recent call last):
  [rank1]:   File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 76, in <module>
  [rank1]:     main(args)
  [rank1]:   File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 64, in main
  [rank1]:     pipeline.train()
  [rank1]:   File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/distillation_pipeline.py", line 1524, in train
  [rank1]:     training_batch = self.train_one_step(training_batch)
  [rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  [rank1]:   File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/distillation_pipeline.py", line 1003, in train_one_step
  [rank1]:     assert param.grad is not None and param.grad.abs().sum() > 0
  [rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  [rank1]: AssertionError
    File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/distillation_pipeline.py", line 1003, in train_one_step
      assert param.grad is not None and param.grad.abs().sum() > 0
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  AssertionError
  [rank0]: Traceback (most recent call last):
  [rank0]:   File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 76, in <module>
  [rank0]:     main(args)
  [rank0]:   File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/wan_distillation_pipeline.py", line 64, in main
  [rank0]:     pipeline.train()
  [rank0]:   File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/distillation_pipeline.py", line 1524, in train
  [rank0]:     training_batch = self.train_one_step(training_batch)
  [rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  [rank0]:   File "/DATA/dgg/codes/temp/FastVideo/fastvideo/training/distillation_pipeline.py", line 1003, in train_one_step
  [rank0]:     assert param.grad is not None and param.grad.abs().sum() > 0
  [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  [rank0]: AssertionError

Reproduction

Training script:

I ran this script at commit 404314d ("[Feature] Add video-to-video (V2V) pipeline (#829)"). Apart from adding the following .sh file, nothing else was modified.

distill.sh

#!/bin/bash
#SBATCH --job-name=t2v
#SBATCH --partition=main
#SBATCH --nodes=8
#SBATCH --ntasks=8
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=128
#SBATCH --mem=1440G
#SBATCH --output=dmd_t2v_output/t2v_%j.out
#SBATCH --error=dmd_t2v_output/t2v_%j.err
#SBATCH --exclusive
set -e -x

# Environment Setup
# source ~/conda/miniconda/bin/activate
# conda activate your_env

# Basic Info
export WANDB_MODE=online
export WANDB_BASE_URL=http://localhost:8080
export WANDB_API_KEY=local-091fbd542428e2f9b998b07f97ef6ea46d2f74cf
export TOKENIZERS_PARALLELISM=false
export FASTVIDEO_ATTENTION_BACKEND=VIDEO_SPARSE_ATTN
export TRITON_CACHE_DIR=/tmp/triton_cache
export MASTER_ADDR=localhost
export MASTER_PORT=$(python -c 'import socket; s=socket.socket(); s.bind(("",0)); print(s.getsockname()[1]); s.close()')
export NODE_RANK=0
export CUDA_VISIBLE_DEVICES=6,7

# Configs
NUM_GPUS=2
MODEL_PATH="/DATA/dgg/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
REAL_SCORE_MODEL_PATH=$MODEL_PATH
FAKE_SCORE_MODEL_PATH=$MODEL_PATH
DATA_DIR=/DATA/dgg/models/FastVideo/Wan-Syn_77x448x832_600k/train
VALIDATION_DATASET_FILE=/DATA/dgg/models/FastVideo/Wan-Syn_77x448x832_600k/val/Part_1/latents_chunk_0000.parquet
INTERVAL=1

PREFIX=$(date +"%m%d_%H%M")
WANDB_RUN_NAME="${PREFIX}_vsa_1step_${NUM_GPUS}gpu_interval${INTERVAL}"
OUTPUT_DIR="checkpoints/distill_wan_t2v_to_t2v/${WANDB_RUN_NAME}"
# export CUDA_VISIBLE_DEVICES=4,5
# IP=[MASTER NODE IP]

# Training arguments
training_args=(
  --tracker_project_name wan_t2v_distill_dmd_VSA
  --wandb_run_name ${WANDB_RUN_NAME}
  --output_dir "$OUTPUT_DIR"
  --max_train_steps 20000
  --train_batch_size 1
  --train_sp_batch_size 1
  --gradient_accumulation_steps 1
  --num_latent_t 20
  --num_height 448
  --num_width 832
  --num_frames 77
  --enable_gradient_checkpointing_type "full"
  --mode distillation
)

# Parallel arguments
parallel_args=(
  --num_gpus $NUM_GPUS
  --sp_size 1
  --tp_size 1
  --hsdp_replicate_dim $NUM_GPUS
  --hsdp_shard_dim 1
)

# Model arguments
model_args=(
  --model_path $MODEL_PATH
  --pretrained_model_name_or_path $MODEL_PATH
  --real_score_model_path $REAL_SCORE_MODEL_PATH
  --fake_score_model_path $FAKE_SCORE_MODEL_PATH
)

# Dataset arguments
dataset_args=(
  --data_path "$DATA_DIR"
  --dataloader_num_workers 4
)

# Validation arguments
validation_args=(
  --log_validation
  --validation_dataset_file "$VALIDATION_DATASET_FILE"
  --validation_steps 500
  --validation_sampling_steps "1"
  --validation_guidance_scale "6.0" # not used for dmd inference
)

# Optimizer arguments
optimizer_args=(
  --learning_rate 2e-6
  --mixed_precision "bf16"
  --training_state_checkpointing_steps 1000
  --weight_only_checkpointing_steps 1000
  --weight_decay 0.01
  --max_grad_norm 1.0
)

# Miscellaneous arguments
miscellaneous_args=(
  --inference_mode False
  --checkpoints_total_limit 3
  --training_cfg_rate 0.0
  --dit_precision "fp32"
  --ema_start_step 0
  --flow_shift 8
  --seed 1000
)

# DMD arguments
dmd_args=(
  --dmd_denoising_steps '1000'
  --min_timestep_ratio 0.02
  --max_timestep_ratio 0.98
  --generator_update_interval $INTERVAL
  --real_score_guidance_scale 3.5
  --VSA_sparsity 0.8
)

v2lv_args=(
  # --log_visualization # disable if oom
#   --task_flag ""
#   --token_concat_mode "sequential"
#   --use_flow_matching_loss
#   --flow_matching_weight 1.0
)

torchrun \
  --standalone \
  --nnodes=1 \
  --nproc_per_node=$NUM_GPUS \
  --master_port=$MASTER_PORT \
    fastvideo/training/wan_distillation_pipeline.py \
    "${parallel_args[@]}" \
    "${model_args[@]}" \
    "${dataset_args[@]}" \
    "${training_args[@]}" \
    "${optimizer_args[@]}" \
    "${validation_args[@]}" \
    "${miscellaneous_args[@]}" \
    "${dmd_args[@]}" \
    "${v2lv_args[@]}"

Environment

collect_env.py output

➜ python collect_env.py

  INFO 10-30 11:56:59 [__init__.py:109] ROCm platform is unavailable: No module named 'amdsmi'
  WARNING 10-30 11:56:59 [logger.py:122]  By default, logger.info(..) will only log from the local main process. Set logger.info(..., is_local_main_process=False) to log from all processes.
  INFO 10-30 11:56:59 [__init__.py:47] CUDA is available
  Collecting environment information...
  PyTorch version: 2.7.1+cu126
  Is debug build: False
  CUDA used to build PyTorch: 12.6
  ROCM used to build PyTorch: N/A
  
  OS: Ubuntu 18.04.6 LTS (x86_64)
  GCC version: (conda-forge gcc 12.4.0-2) 12.4.0
  Clang version: Could not collect
  CMake version: version 3.10.2
  Libc version: glibc-2.31
  
  Python version: 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 14 2025, 16:16:33) [GCC 11.2.0] (64-bit runtime)
  Python platform: Linux-4.15.0-187-generic-x86_64-with-glibc2.31
  Is CUDA available: True
  CUDA runtime version: 11.6.55
  CUDA_MODULE_LOADING set to: LAZY
  GPU models and configuration: 
  GPU 0: NVIDIA A100 80GB PCIe
  GPU 1: NVIDIA A100 80GB PCIe
  GPU 2: NVIDIA A100 80GB PCIe
  GPU 3: NVIDIA A100 80GB PCIe
  GPU 4: NVIDIA A100 80GB PCIe
  GPU 5: NVIDIA A100 80GB PCIe
  GPU 6: NVIDIA A100 80GB PCIe
  GPU 7: NVIDIA A100 80GB PCIe
  
  Nvidia driver version: 535.54.03
  cuDNN version: Could not collect
  HIP runtime version: N/A
  MIOpen runtime version: N/A
  Is XNNPACK available: True
  
  CPU:
  Architecture:        x86_64
  CPU op-mode(s):      32-bit, 64-bit
  Byte Order:          Little Endian
  CPU(s):              72
  On-line CPU(s) list: 0-71
  Thread(s) per core:  2
  Core(s) per socket:  18
  Socket(s):           2
  NUMA node(s):        2
  Vendor ID:           GenuineIntel
  CPU family:          6
  Model:               85
  Model name:          Intel(R) Xeon(R) Gold 6240 CPU @ 2.60GHz
  Stepping:            7
  CPU MHz:             2844.296
  CPU max MHz:         2601.0000
  CPU min MHz:         1000.0000
  BogoMIPS:            5200.00
  Virtualization:      VT-x
  L1d cache:           32K
  L1i cache:           32K
  L2 cache:            1024K
  L3 cache:            25344K
  NUMA node0 CPU(s):   0-17,36-53
  NUMA node1 CPU(s):   18-35,54-71
  Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
  
  Versions of relevant libraries:
  [pip3] accelerate==1.0.1
  [pip3] numpy==2.2.6
  [pip3] nvidia-cublas-cu12==12.6.4.1
  [pip3] nvidia-cuda-cupti-cu12==12.6.80
  [pip3] nvidia-cuda-nvrtc-cu12==12.6.77
  [pip3] nvidia-cuda-runtime-cu12==12.6.77
  [pip3] nvidia-cudnn-cu12==9.5.1.17
  [pip3] nvidia-cufft-cu12==11.3.0.4
  [pip3] nvidia-cufile-cu12==1.11.1.6
  [pip3] nvidia-curand-cu12==10.3.7.77
  [pip3] nvidia-cusolver-cu12==11.7.1.2
  [pip3] nvidia-cusparse-cu12==12.5.4.2
  [pip3] nvidia-cusparselt-cu12==0.6.3
  [pip3] nvidia-ml-py==13.580.82
  [pip3] nvidia-nccl-cu12==2.26.2
  [pip3] nvidia-nvjitlink-cu12==12.6.85
  [pip3] nvidia-nvshmem-cu12==3.3.20
  [pip3] nvidia-nvtx-cu12==12.6.77
  [pip3] peft==0.17.1
  [pip3] torch==2.7.1
  [pip3] torchcodec==0.5
  [pip3] torchdata==0.11.0
  [pip3] torchvision==0.22.1
  [pip3] transformers==4.57.1
  [pip3] triton==3.3.1
  [conda] accelerate                1.0.1                    pypi_0    pypi
  [conda] numpy                     2.2.6                    pypi_0    pypi
  [conda] nvidia-cublas-cu12        12.6.4.1                 pypi_0    pypi
  [conda] nvidia-cuda-cupti-cu12    12.6.80                  pypi_0    pypi
  [conda] nvidia-cuda-nvrtc-cu12    12.6.77                  pypi_0    pypi
  [conda] nvidia-cuda-runtime-cu12  12.6.77                  pypi_0    pypi
  [conda] nvidia-cudnn-cu12         9.5.1.17                 pypi_0    pypi
  [conda] nvidia-cufft-cu12         11.3.0.4                 pypi_0    pypi
  [conda] nvidia-cufile-cu12        1.11.1.6                 pypi_0    pypi
  [conda] nvidia-curand-cu12        10.3.7.77                pypi_0    pypi
  [conda] nvidia-cusolver-cu12      11.7.1.2                 pypi_0    pypi
  [conda] nvidia-cusparse-cu12      12.5.4.2                 pypi_0    pypi
  [conda] nvidia-cusparselt-cu12    0.6.3                    pypi_0    pypi
  [conda] nvidia-ml-py              13.580.82                pypi_0    pypi
  [conda] nvidia-nccl-cu12          2.26.2                   pypi_0    pypi
  [conda] nvidia-nvjitlink-cu12     12.6.85                  pypi_0    pypi
  [conda] nvidia-nvshmem-cu12       3.3.20                   pypi_0    pypi
  [conda] nvidia-nvtx-cu12          12.6.77                  pypi_0    pypi
  [conda] peft                      0.17.1                   pypi_0    pypi
  [conda] torch                     2.7.1                    pypi_0    pypi
  [conda] torchcodec                0.5                      pypi_0    pypi
  [conda] torchdata                 0.11.0                   pypi_0    pypi
  [conda] torchvision               0.22.1                   pypi_0    pypi
  [conda] transformers              4.57.1                   pypi_0    pypi
  [conda] triton                    3.3.1                    pypi_0    pypi
  FastVideo Version: 
  FastVideo Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
  GPU Topology:
          GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    CPU Affinity    NUMA Affinity   GPU NUMA ID
  GPU0     X      PIX     PIX     PIX     NODE    NODE    NODE    NODE    PIX     PIX     0-17,36-53      0               N/A
  GPU1    PIX      X      PIX     PIX     NODE    NODE    NODE    NODE    PIX     PIX     0-17,36-53      0               N/A
  GPU2    PIX     PIX      X      PIX     NODE    NODE    NODE    NODE    PIX     PIX     0-17,36-53      0               N/A
  GPU3    PIX     PIX     PIX      X      NODE    NODE    NODE    NODE    PIX     PIX     0-17,36-53      0               N/A
  GPU4    NODE    NODE    NODE    NODE     X      PIX     PIX     PIX     NODE    NODE    0-17,36-53      0               N/A
  GPU5    NODE    NODE    NODE    NODE    PIX      X      PIX     PIX     NODE    NODE    0-17,36-53      0               N/A
  GPU6    NODE    NODE    NODE    NODE    PIX     PIX      X      PIX     NODE    NODE    0-17,36-53      0               N/A
  GPU7    NODE    NODE    NODE    NODE    PIX     PIX     PIX      X      NODE    NODE    0-17,36-53      0               N/A
  NIC0    PIX     PIX     PIX     PIX     NODE    NODE    NODE    NODE     X      PIX
  NIC1    PIX     PIX     PIX     PIX     NODE    NODE    NODE    NODE    PIX      X 
  
  Legend:
  
    X    = Self
    SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
    NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
    PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
    PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
    PIX  = Connection traversing at most a single PCIe bridge
    NV#  = Connection traversing a bonded set of # NVLinks
  
  NIC Legend:
  
    NIC0: mlx5_0
    NIC1: mlx5_1
  
  LD_LIBRARY_PATH=:/usr/local/cuda-12.4/lib64
  CUDA_HOME=/usr/local/cuda-12.4
  CUDA_MODULE_LOADING=LAZY
  TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_gpu


Other info (VSA availability check)

➜ python               
Python 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 14 2025, 16:16:33) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from fastvideo.utils import is_vsa_available
INFO 10-30 12:47:12 [__init__.py:109] ROCm platform is unavailable: No module named 'amdsmi'
WARNING 10-30 12:47:12 [logger.py:122]  By default, logger.info(..) will only log from the local main process. Set logger.info(..., is_local_main_process=False) to log from all processes.
INFO 10-30 12:47:12 [__init__.py:47] CUDA is available
>>> is_vsa_available()
True
>>> 
