Commit 0b4c6f0

fix(mm): flux model variant probing
In #7780 we added FLUX Fill support, and needed the probe to be able to distinguish between "normal" FLUX models and FLUX Fill models. Logic was added to the probe to check a particular state dict key (input channels), which should be 384 for FLUX Fill and 64 for other FLUX models.

The new logic was stricter than before: instead of falling back to the "normal" variant, it raised when an unexpected value for input channels was detected. This caused probe failures for BNB-NF4 quantized FLUX Dev/Schnell, which apparently have only 1 input channel.

After checking a variety of FLUX models, I loosened the strictness of the variant probing logic to special-case only the new FLUX Fill model and otherwise fall back to returning the "normal" variant. This better matches the old behaviour and fixes the import errors.

Closes #7822
1 parent a629102 commit 0b4c6f0

1 file changed: +12 -6 lines changed

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 12 additions & 6 deletions
@@ -563,14 +563,20 @@ def get_variant_type(self) -> ModelVariantType:
 
         if base_type == BaseModelType.Flux:
             in_channels = state_dict["img_in.weight"].shape[1]
-            if in_channels == 64:
-                return ModelVariantType.Normal
-            elif in_channels == 384:
+
+            # FLUX Model variant types are distinguished by input channels:
+            # - Unquantized Dev and Schnell have in_channels=64
+            # - BNB-NF4 Dev and Schnell have in_channels=1
+            # - FLUX Fill has in_channels=384
+            # - Unsure of quantized FLUX Fill models
+            # - Unsure of GGUF-quantized models
+            if in_channels == 384:
+                # This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant
+                # type is used to determine whether to use the fill model or the base model.
                 return ModelVariantType.Inpaint
             else:
-                raise InvalidModelConfigException(
-                    f"Unexpected in_channels (in_channels={in_channels}) for FLUX model at {self.model_path}."
-                )
+                # Fall back on "normal" variant type for all other FLUX models.
+                return ModelVariantType.Normal
 
         in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
         if in_channels == 9:
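For reference, here is a minimal standalone sketch of the same in_channels check, for readers who want to see which variant a local FLUX checkpoint would probe as after this change. It assumes a .safetensors transformer checkpoint read with the safetensors library; the helper name guess_flux_variant, the example filename, and the string return values are illustrative only and are not part of this commit or of the probe's API.

# Illustrative sketch only; mirrors the probe's in_channels check on a raw checkpoint.
from safetensors import safe_open


def guess_flux_variant(checkpoint_path: str) -> str:
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        # "img_in.weight" is the FLUX input projection; its second dimension is
        # the number of input channels the model expects.
        in_channels = f.get_tensor("img_in.weight").shape[1]

    if in_channels == 384:
        # FLUX Fill checkpoints expect extra conditioning channels.
        return "inpaint"
    # Everything else (unquantized Dev/Schnell at 64 channels, BNB-NF4 quants
    # that report 1 channel, etc.) falls back to the "normal" variant.
    return "normal"


if __name__ == "__main__":
    # Hypothetical local file path, for illustration only.
    print(guess_flux_variant("flux1-dev.safetensors"))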
