@@ -4249,7 +4249,33 @@ def lora_state_dict(
 
         return state_dict
 
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    @classmethod
+    def _maybe_expand_t2v_lora_for_i2v(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+        if transformer.config.image_dim is None:
+            return state_dict
+
+        if any(k.startswith("transformer.blocks.") for k in state_dict):
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
+            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+
+            if is_i2v_lora:
+                return state_dict
+
+            for i in range(num_blocks):
+                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
+                    )
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
+                    )
+
+        return state_dict
+
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
     ):
@@ -4287,7 +4313,11 @@ def load_lora_weights(
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
+        # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+        state_dict = self._maybe_expand_t2v_lora_for_i2v(
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
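
A minimal usage sketch of what this change enables (the checkpoint id and LoRA path below are illustrative placeholders, not part of this commit): a LoRA trained against the Wan T2V transformer carries no attn2 add_k_proj/add_v_proj (k_img/v_img) entries, and _maybe_expand_t2v_lora_for_i2v zero-fills them so the adapter loads into an I2V pipeline without key mismatches.

    import torch
    from diffusers import WanImageToVideoPipeline

    # Illustrative model id and LoRA path -- substitute real ones.
    pipe = WanImageToVideoPipeline.from_pretrained(
        "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
    )

    # The T2V-trained LoRA lacks the I2V-only image-attention projections;
    # load_lora_weights now pads them with zero LoRA weights, so the
    # image-conditioning path is left unchanged by the adapter.
    pipe.load_lora_weights("path/to/t2v_lora", adapter_name="t2v-style")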