Skip to content

[Bug] ValueError: text_dict cannot be None for distillation pipeline #877

@wzr1201

Description

@wzr1201

Describe the bug

hi, great work! but there is some bugs when I run the script "distill_dmd_VSA_t2v_1.3B", the DATA_DIR="FastVideo/Wan-Syn_77x448x832_600k/train", can u help me for this, thanks

Reproduction

#!/bin/bash

单机 8 卡,无 SLURM

set -e -x

##############################

1. 基础环境(同原脚本)

##############################

source ~/conda/miniconda/bin/activate

conda activate your_env

export WANDB_MODE=online
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export TRITON_CACHE_DIR=/tmp/triton_cache_$$
export TOKENIZERS_PARALLELISM=false
export WANDB_BASE_URL="https://api.wandb.ai"
export FASTVIDEO_ATTENTION_BACKEND=VIDEO_SPARSE_ATTN

##############################

2. 路径 & 超参(同原脚本)

##############################

MODEL_PATH="Wan-AI/Wan2.1-T2V-14B-Diffusers"

REAL_SCORE_MODEL_PATH="Wan-AI/Wan2.1-T2V-14B-Diffusers"

FAKE_SCORE_MODEL_PATH="Wan-AI/Wan2.1-T2V-14B-Diffusers"

FAKE_SCORE_MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
REAL_SCORE_MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

DATA_DIR="FastVideo/Wan-Syn_77x448x832_600k/train"
VALIDATION_DATASET_FILE="your_validation_dataset_file"
OUTPUT_DIR="checkpoints/wan_t2v_finetune"

##############################

3. 训练参数数组(直接照搬)

##############################
parallel_args=(
--num_gpus 4 # 单机 8 卡
--sp_size 4
--tp_size 1
--hsdp_replicate_dim 1
--hsdp_shard_dim 4
)

model_args=(
--model_path $MODEL_PATH
--pretrained_model_name_or_path $MODEL_PATH
--real_score_model_path $REAL_SCORE_MODEL_PATH
--fake_score_model_path $FAKE_SCORE_MODEL_PATH
)

dataset_args=( --data_path "$DATA_DIR" --dataloader_num_workers 4 )

training_args=(
--tracker_project_name wan_t2v_distill_dmd_VSA
--output_dir "$OUTPUT_DIR"
--max_train_steps 4000
--train_batch_size 1
--train_sp_batch_size 1
--gradient_accumulation_steps 1
--num_latent_t 21
--num_height 480
--num_width 480
--num_frames 81
--enable_gradient_checkpointing_type "full"
)

optimizer_args=(
--learning_rate 2e-6
--mixed_precision "bf16"
--training_state_checkpointing_steps 500
--weight_only_checkpointing_steps 500
--weight_decay 0.01
--max_grad_norm 1.0
)

validation_args=(
--validation_dataset_file "$VALIDATION_DATASET_FILE"
--validation_steps 200
--validation_sampling_steps "3"
--validation_guidance_scale "6.0"
)

miscellaneous_args=(
--inference_mode False
--checkpoints_total_limit 3
--training_cfg_rate 0.0
--dit_precision "fp32"
--ema_start_step 0
--flow_shift 3
--seed 1000
)

dmd_args=(
--dmd_denoising_steps '1000,757,522'
--min_timestep_ratio 0.02
--max_timestep_ratio 0.98
--generator_update_interval 5
--real_score_guidance_scale 3.5
--VSA_sparsity 0.9
)

##############################

4. 启动训练

##############################

单机 8 卡,master=127.0.0.1

torchrun
--nnodes=1
--nproc_per_node=4
--master_port=29500
fastvideo/training/wan_distillation_pipeline.py
"${parallel_args[@]}"
"${model_args[@]}"
"${dataset_args[@]}"
"${training_args[@]}"
"${optimizer_args[@]}"
"${validation_args[@]}"
"${miscellaneous_args[@]}"
"${dmd_args[@]}"

Environment

linux
torch 2.7.1
fastvideo 0.1.6 /root/wangzairan01/projects/FastVideo
cuda12.4(A800)
Python 3.12.12

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions