@@ -1596,48 +1596,131 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     converted_state_dict = {}
     original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
 
-    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
+    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+    lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
+    lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+
+    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
+    if diff_keys:
+        for diff_k in diff_keys:
+            param = original_state_dict[diff_k]
+            all_zero = torch.all(param == 0).item()
+            if all_zero:
+                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+                original_state_dict.pop(diff_k)
+
+    # For the `diff_b` keys, we treat them as lora_bias.
+    # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
 
     for i in range(num_blocks):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_A.weight"
+                f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_B.weight"
+                f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.self_attn.{o}.diff_b"
+                )
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.cross_attn.{o}.diff_b"
+                )
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                    f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
                 )
                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                    f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
                 )
+                if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
+                    converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
+                        f"blocks.{i}.cross_attn.{o}.diff_b"
+                    )
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_A.weight"
+                f"blocks.{i}.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_B.weight"
+                f"blocks.{i}.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.{o}.diff_b"
+                )
+
+    # Remaining.
+    if original_state_dict:
+        if any("time_projection" in k for k in original_state_dict):
+            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
+                f"time_projection.1.{lora_down_key}.weight"
+            )
+            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
+                f"time_projection.1.{lora_up_key}.weight"
+            )
+            if "time_projection.1.diff_b" in original_state_dict:
+                converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
+                    "time_projection.1.diff_b"
+                )
+
+        if any("head.head" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                f"head.head.{lora_down_key}.weight"
+            )
+            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+            if "head.head.diff_b" in original_state_dict:
+                converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
+
+        for text_time in ["text_embedding", "time_embedding"]:
+            if any(text_time in k for k in original_state_dict):
+                for b_n in [0, 2]:
+                    diffusers_b_n = 1 if b_n == 0 else 2
+                    diffusers_name = (
+                        "condition_embedder.text_embedder"
+                        if text_time == "text_embedding"
+                        else "condition_embedder.time_embedder"
+                    )
+                    if any(f"{text_time}.{b_n}" in k for k in original_state_dict):
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight")
+                        )
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight")
+                        )
+                        if f"{text_time}.{b_n}.diff_b" in original_state_dict:
+                            converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = (
+                                original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
+                            )
 
     if len(original_state_dict) > 0:
-        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+        diff = all(".diff" in k for k in original_state_dict)
+        if diff:
+            diff_keys = {k for k in original_state_dict if k.endswith(".diff")}
+            if not all("lora" not in k for k in diff_keys):
+                raise ValueError
+            logger.info(
+                "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
+                "https://github.com/huggingface/diffusers//issues/new"
+            )
+        else:
+            raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
 
     for key in list(converted_state_dict.keys()):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
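
Not part of the diff above: a minimal sketch of what the converter expects and returns, useful for sanity-checking the new `lora_down`/`lora_up` detection and the `diff_b` -> `lora_B.bias` mapping. The helper `_convert_non_diffusers_wan_lora_to_diffusers` is private and is normally invoked internally when `load_lora_weights` receives a non-diffusers Wan LoRA; the import path, rank, and tensor shapes below are illustrative assumptions, not something the commit itself ships.

import torch

# Assumed import path for the private helper shown in the diff.
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers

rank, dim = 4, 16  # made-up rank/width, only for illustration
state_dict = {}
block = "diffusion_model.blocks.0"
for attn in ["self_attn", "cross_attn"]:
    for proj in ["q", "k", "v", "o"]:
        # Old-style key names (`lora_down`/`lora_up`) plus a non-zero `diff_b` bias.
        state_dict[f"{block}.{attn}.{proj}.lora_down.weight"] = torch.randn(rank, dim)
        state_dict[f"{block}.{attn}.{proj}.lora_up.weight"] = torch.randn(dim, rank)
        state_dict[f"{block}.{attn}.{proj}.diff_b"] = torch.randn(dim)
for ffn in ["ffn.0", "ffn.2"]:
    state_dict[f"{block}.{ffn}.lora_down.weight"] = torch.randn(rank, dim)
    state_dict[f"{block}.{ffn}.lora_up.weight"] = torch.randn(dim, rank)

converted = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
# Keys are remapped to the diffusers transformer layout, e.g.
# "transformer.blocks.0.attn1.to_q.lora_A.weight" and, for `diff_b`,
# "transformer.blocks.0.attn1.to_q.lora_B.bias".
print(len(converted))

In practice one would not call the private helper directly; loading the checkpoint through a Wan pipeline's `load_lora_weights` should trigger the same conversion when the keys carry the `diffusion_model.` prefix.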