Commit 7e7e62c

Authored by rockerBOO, with co-authors sayakpaul and github-actions[bot]

Convert alphas for embedders for sd-scripts to ai toolkit conversion (huggingface#12332)

* Convert alphas for embedders for sd-scripts to ai toolkit conversion
* Add kohya embedders conversion test
* Apply style fixes

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

1 parent eda9ff8 commit 7e7e62c

File tree

2 files changed: +46 -47 lines changed

src/diffusers/loaders/lora_conversion_utils.py
tests/lora/test_lora_layers_flux.py

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 39 additions & 47 deletions

@@ -558,70 +558,62 @@ def assign_remaining_weights(assignments, source):
                 ait_sd[target_key] = value
 
     if any("guidance_in" in k for k in sds_sd):
-        assign_remaining_weights(
-            [
-                (
-                    "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
-                    "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
-                    None,
-                ),
-                (
-                    "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
-                    "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
-                    None,
-                ),
-            ],
+        _convert_to_ai_toolkit(
             sds_sd,
+            ait_sd,
+            "lora_unet_guidance_in_in_layer",
+            "time_text_embed.guidance_embedder.linear_1",
+        )
+
+        _convert_to_ai_toolkit(
+            sds_sd,
+            ait_sd,
+            "lora_unet_guidance_in_out_layer",
+            "time_text_embed.guidance_embedder.linear_2",
         )
 
     if any("img_in" in k for k in sds_sd):
-        assign_remaining_weights(
-            [
-                ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
-            ],
+        _convert_to_ai_toolkit(
             sds_sd,
+            ait_sd,
+            "lora_unet_img_in",
+            "x_embedder",
         )
 
     if any("txt_in" in k for k in sds_sd):
-        assign_remaining_weights(
-            [
-                ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
-            ],
+        _convert_to_ai_toolkit(
             sds_sd,
+            ait_sd,
+            "lora_unet_txt_in",
+            "context_embedder",
         )
 
     if any("time_in" in k for k in sds_sd):
-        assign_remaining_weights(
-            [
-                (
-                    "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
-                    "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
-                    None,
-                ),
-                (
-                    "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
-                    "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
-                    None,
-                ),
-            ],
+        _convert_to_ai_toolkit(
             sds_sd,
+            ait_sd,
+            "lora_unet_time_in_in_layer",
+            "time_text_embed.timestep_embedder.linear_1",
+        )
+        _convert_to_ai_toolkit(
+            sds_sd,
+            ait_sd,
+            "lora_unet_time_in_out_layer",
+            "time_text_embed.timestep_embedder.linear_2",
         )
 
     if any("vector_in" in k for k in sds_sd):
-        assign_remaining_weights(
-            [
-                (
-                    "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
-                    "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
-                    None,
-                ),
-                (
-                    "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
-                    "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
-                    None,
-                ),
-            ],
+        _convert_to_ai_toolkit(
+            sds_sd,
+            ait_sd,
+            "lora_unet_vector_in_in_layer",
+            "time_text_embed.text_embedder.linear_1",
+        )
+        _convert_to_ai_toolkit(
             sds_sd,
+            ait_sd,
+            "lora_unet_vector_in_out_layer",
+            "time_text_embed.text_embedder.linear_2",
         )
 
     if any("final_layer" in k for k in sds_sd):
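
For context: sd-scripts (kohya) checkpoints store a scalar alpha next to each module's lora_down/lora_up matrices and scale the LoRA delta by alpha / rank at inference time, while the ai-toolkit/diffusers layout carries no alpha entry. The old assign_remaining_weights calls only remapped key names, leaving the embedders' alphas unapplied; _convert_to_ai_toolkit folds them into the converted weights. Below is a minimal sketch of that idea, assuming the sd-scripts key layout described above; convert_embedder_sketch is a hypothetical name, and the real helper in lora_conversion_utils.py differs in details such as how the scale is distributed between the two matrices.

import torch

def convert_embedder_sketch(sds_sd: dict, ait_sd: dict, sds_key: str, ait_key: str) -> None:
    # Hypothetical, simplified stand-in for diffusers' _convert_to_ai_toolkit.
    down = sds_sd.pop(f"{sds_key}.lora_down.weight", None)
    if down is None:
        return  # this module is not present in the checkpoint
    up = sds_sd.pop(f"{sds_key}.lora_up.weight")

    # sd-scripts applies the LoRA delta scaled by alpha / rank; the target
    # format carries no alpha, so the scale must be baked into the weights.
    rank = down.shape[0]
    alpha = sds_sd.pop(f"{sds_key}.alpha", torch.tensor(float(rank))).item()
    scale = alpha / rank

    # Putting the whole scale on lora_A keeps the composed delta
    # lora_B @ lora_A equal to sd-scripts' (alpha / rank) * up @ down.
    ait_sd[f"{ait_key}.lora_A.weight"] = down * scale
    ait_sd[f"{ait_key}.lora_B.weight"] = up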

tests/lora/test_lora_layers_flux.py

Lines changed: 7 additions & 0 deletions

@@ -907,6 +907,13 @@ def test_flux_kohya_with_text_encoder(self):
 
         assert max_diff < 1e-3
 
+    def test_flux_kohya_embedders_conversion(self):
+        """Test that embedders load without throwing errors"""
+        self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
+        self.pipeline.unload_lora_weights()
+
+        assert True
+
     def test_flux_xlabs(self):
         self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
         self.pipeline.fuse_lora()
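
The new test checks only that load and unload complete without raising, which is what the embedder handling above fixes. As a hedged usage sketch outside the test suite (the base model repo, dtype, device, and prompt are illustrative assumptions; only the rockerBOO/flux-bpo-po-lora repo id comes from this commit):

import torch
from diffusers import FluxPipeline

# Assumed base model; the commit itself only changes the LoRA conversion path.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# With this commit, sd-scripts LoRAs that train the embedder modules
# (guidance_in, img_in, txt_in, time_in, vector_in) load without errors.
pipe.load_lora_weights("rockerBOO/flux-bpo-po-lora")
image = pipe("a photo of a corgi", num_inference_steps=28).images[0]
pipe.unload_lora_weights()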
