 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
+import torchvision.transforms as TT
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
 
 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="Reshape mode applied to input videos. Choose between ['center', 'random', 'none'].",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to this many frames."
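
A note on the new flag: an unsupported value currently surfaces only as a `NotImplementedError` deep inside `_resize_for_rectangle_crop`. A minimal sketch (not part of this diff) of how argparse could reject bad values at parse time instead, assuming the same flag name and default:

```python
# Hypothetical variant of the argument above: `choices` makes argparse
# fail fast on anything other than the three supported modes.
parser.add_argument(
    "--video_reshape_mode",
    type=str,
    default="center",
    choices=["center", "random", "none"],
    help="Reshape mode applied to input videos.",
)
```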
@@ -413,6 +423,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ def __init__(
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):
 
         return instance_prompts, instance_videos
 
+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
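
To make the geometry concrete: `_resize_for_rectangle_crop` first resizes so the target rectangle fits entirely inside the frame (matching height when the source is wider than the target aspect ratio, width otherwise), then crops the overflow along the other axis. A self-contained sketch with an assumed 16:9 source and the script's default 480x720 target:

```python
# Dummy [F, C, H, W] clip, matching what _preprocess_data passes in.
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize

frames = torch.rand(49, 3, 1080, 1920)  # 16:9 source
target_h, target_w = 480, 720           # 3:2 target

# 1920/1080 > 720/480, so height is matched first and the width overflows.
scaled = resize(
    frames,
    size=[target_h, int(frames.shape[3] * target_h / frames.shape[2])],
    interpolation=InterpolationMode.BICUBIC,
)
print(scaled.shape)  # torch.Size([49, 3, 480, 853])

# The 133 extra columns are then cropped away: centered for "center",
# at a random offset for "random"/"none". Equivalent to TT.functional.crop:
left = (scaled.shape[3] - target_w) // 2
cropped = scaled[..., :, left : left + target_w]
print(cropped.shape)  # torch.Size([49, 3, 480, 720])
```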
@@ -542,15 +586,14 @@ def _preprocess_data(self):
 
         decord.bridge.set_bridge("torch")
 
-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Resizing and cropping videos",
         )
+        videos = []
 
         for filename in self.instance_video_paths:
-            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+            video_reader = decord.VideoReader(uri=filename.as_posix())
             video_num_frames = len(video_reader)
 
             start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@ def _preprocess_data(self):
             assert (selected_num_frames - 1) % 4 == 0
 
             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            progress_dataset_bar.set_description(
+                f"Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
+            )
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+            progress_dataset_bar.update(1)
 
+        progress_dataset_bar.close()
         return videos
 
 
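The dropped `transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)` and the new `(frames - 127.5) / 127.5` are the same map from `[0, 255]` to `[-1, 1]`; the real change is that decord no longer resizes during decode, so normalization now runs on full-resolution frames before `_resize_for_rectangle_crop`. A quick equivalence check on dummy data:

```python
import torch

x = torch.randint(0, 256, (3, 480, 720)).float()  # a fake uint8 frame
old = x / 255.0 * 2.0 - 1.0   # previous Lambda transform
new = (x - 127.5) / 127.5     # new in-place arithmetic
print(torch.allclose(old, new))  # True
```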
@@ -694,8 +743,13 @@ def log_validation(
 
     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
-        videos.append(video)
+        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
+        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
+
+        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
+        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+
+        videos.append(image_pil)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
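
Validation now requests `output_type="pt"` from the pipeline and converts to PIL through `VaeImageProcessor`'s static helpers, rather than taking NumPy frames directly. A sketch of that round trip on dummy frames (shapes assumed to match the pipeline's `[F, C, H, W]` output in `[0, 1]`):

```python
import torch
from diffusers.image_processor import VaeImageProcessor

frames = torch.rand(49, 3, 480, 720)                  # [F, C, H, W] in [0, 1]
image_np = VaeImageProcessor.pt_to_numpy(frames)      # [F, H, W, C] float32
image_pil = VaeImageProcessor.numpy_to_pil(image_np)  # list of PIL.Image
print(len(image_pil), image_pil[0].size)              # 49 (720, 480)
```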
@@ -1171,6 +1225,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )
 
-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist
 
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Encoding videos",
+    )
+    train_dataset.instance_videos = [
+        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
+    ]
+    progress_encode_bar.close()
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
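
Each clip is still encoded exactly once up front; what is cached is the VAE's `latent_dist`, so `collate_fn` draws a fresh sample (scaled by `vae.config.scaling_factor`) on every step rather than reusing one fixed latent. A hedged sketch of the shapes involved, assuming a loaded `AutoencoderKLCogVideoX` bound to `vae`:

```python
import torch

video = torch.rand(49, 3, 480, 720)                # [F, C, H, W] from the dataset
video = video.unsqueeze(0).permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
latent_dist = vae.encode(video.to(vae.device, dtype=vae.dtype)).latent_dist
latents = latent_dist.sample() * vae.config.scaling_factor
# e.g. torch.Size([1, 16, 13, 60, 90]) with CogVideoX's 8x spatial and
# 4x temporal compression
print(latents.shape)
```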