
Commit aa0cafb

Merge branch 'main' into support-group-offloading-pipeline-level
2 parents 1a8ebf6 + 4067d6c commit aa0cafb

23 files changed (+5387, -25 lines)

docs/source/en/api/image_processor.md

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,12 @@ All pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or NumPy
 
 [[autodoc]] image_processor.VaeImageProcessor
 
+## InpaintProcessor
+
+The [`InpaintProcessor`] accepts `mask` and `image` inputs and processes them together. Optionally, it can accept `padding_mask_crop` and apply a mask overlay.
+
+[[autodoc]] image_processor.InpaintProcessor
+
 ## VaeImageProcessorLDM3D
 
 The [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs.
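A minimal usage sketch of the new processor, based on the description above and the implementation added in `src/diffusers/image_processor.py` later in this commit. The file names and the `denoise` placeholder are hypothetical; only the `InpaintProcessor` API comes from the commit.

```python
from PIL import Image

from diffusers.image_processor import InpaintProcessor

processor = InpaintProcessor()

init_image = Image.open("dog.png").convert("RGB")     # hypothetical input image
mask_image = Image.open("dog_mask.png").convert("L")  # hypothetical inpainting mask

# With a mask, preprocess returns the processed image, the processed mask, and a
# dict of kwargs to pass back into postprocess (crop coords, originals for overlay).
image_t, mask_t, post_kwargs = processor.preprocess(
    init_image, mask_image, height=1024, width=1024, padding_mask_crop=32
)

# result = denoise(image_t, mask_t)                   # placeholder for the pipeline call
# final = processor.postprocess(result, output_type="pil", **post_kwargs)
```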

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 30 additions & 13 deletions
@@ -29,8 +29,9 @@
 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
@@ -1222,6 +1223,9 @@ def main(args):
         kwargs_handlers=[kwargs],
     )
 
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
@@ -1438,17 +1442,20 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_one_lora_layers_to_save = None
             modules_to_save = {}
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["transformer"] = model
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    model = unwrap_model(model)
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
             FluxKontextPipeline.save_lora_weights(
                 output_dir,
@@ -1461,15 +1468,25 @@ def load_model_hook(models, input_dir):
         transformer_ = None
         text_encoder_one_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = FluxTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="text_encoder"
+            )
 
         lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
 
@@ -2069,7 +2086,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
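The `unwrap_model(...)` change in the hooks matters because, under DeepSpeed (and DDP), Accelerate hands the hooks wrapped modules, so an `isinstance` check against the bare model class silently fails. A toy sketch of the failure mode, with `Wrapper` standing in for the DeepSpeed engine / DDP wrapper:

```python
import torch.nn as nn


class Transformer(nn.Module):  # stand-in for FluxTransformer2DModel
    pass


class Wrapper(nn.Module):  # stand-in for the DeepSpeed engine / DDP wrapper
    def __init__(self, module):
        super().__init__()
        self.module = module


def unwrap_model(model):
    # simplified version of what accelerator.unwrap_model does for this toy wrapper
    return model.module if isinstance(model, Wrapper) else model


wrapped = Wrapper(Transformer())
print(isinstance(wrapped, Transformer))                # False -> the old check skips the model
print(isinstance(unwrap_model(wrapped), Transformer))  # True  -> the patched check matches
```

The `if weights:` guard and the DeepSpeed branch of `load_model_hook` address the same situation: under DeepSpeed the hooks receive different inputs than in the single-process case, so the script guards the pop and rebuilds the LoRA-equipped transformer and text encoder from the pretrained checkpoint instead of taking them from `models`.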

src/diffusers/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -385,6 +385,10 @@
         [
             "FluxAutoBlocks",
             "FluxModularPipeline",
+            "QwenImageAutoBlocks",
+            "QwenImageEditAutoBlocks",
+            "QwenImageEditModularPipeline",
+            "QwenImageModularPipeline",
             "StableDiffusionXLAutoBlocks",
             "StableDiffusionXLModularPipeline",
             "WanAutoBlocks",
@@ -506,6 +510,7 @@
             "PixArtAlphaPipeline",
             "PixArtSigmaPAGPipeline",
             "PixArtSigmaPipeline",
+            "QwenImageControlNetInpaintPipeline",
             "QwenImageControlNetPipeline",
             "QwenImageEditInpaintPipeline",
             "QwenImageEditPipeline",
@@ -1038,6 +1043,10 @@
         from .modular_pipelines import (
             FluxAutoBlocks,
             FluxModularPipeline,
+            QwenImageAutoBlocks,
+            QwenImageEditAutoBlocks,
+            QwenImageEditModularPipeline,
+            QwenImageModularPipeline,
             StableDiffusionXLAutoBlocks,
             StableDiffusionXLModularPipeline,
             WanAutoBlocks,
@@ -1155,6 +1164,7 @@
             PixArtAlphaPipeline,
             PixArtSigmaPAGPipeline,
             PixArtSigmaPipeline,
+            QwenImageControlNetInpaintPipeline,
             QwenImageControlNetPipeline,
             QwenImageEditInpaintPipeline,
             QwenImageEditPipeline,
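Both changes to `src/diffusers/__init__.py` are export plumbing: the new names are added once as strings (for the lazy-import path) and once as real imports (for the `TYPE_CHECKING` branch). A quick sanity check, assuming a build that contains this commit:

```python
from diffusers import (
    QwenImageAutoBlocks,
    QwenImageControlNetInpaintPipeline,
    QwenImageModularPipeline,
)

# resolved lazily on first attribute access via the package's _import_structure
print(QwenImageModularPipeline.__module__)
```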

src/diffusers/hooks/_helpers.py

Lines changed: 10 additions & 0 deletions
@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
     from ..models.attention_processor import AttnProcessor2_0
     from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
     from ..models.transformers.transformer_flux import FluxAttnProcessor
+    from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
     from ..models.transformers.transformer_wan import WanAttnProcessor2_0
 
     # AttnProcessor2_0
@@ -140,6 +141,14 @@ def _register_attention_processors_metadata():
         metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
     )
 
+    # QwenDoubleStreamAttnProcessor2_0
+    AttentionProcessorRegistry.register(
+        model_class=QwenDoubleStreamAttnProcessor2_0,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
+        ),
+    )
+
 
 def _register_transformer_blocks_metadata():
     from ..models.attention import BasicTransformerBlock
@@ -298,4 +307,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
 _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
 # not sure what this is yet.
 _skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
 # fmt: on
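A stripped-down sketch of the registry pattern this file extends (the classes below are stand-ins, not the `diffusers.hooks._helpers` internals verbatim): each attention-processor class is mapped to metadata whose `skip_processor_output_fn` tells hook code what to return when the processor is skipped, e.g. for group offloading or cache-assisted skipping.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Type


@dataclass
class AttentionProcessorMetadata:
    skip_processor_output_fn: Callable


class AttentionProcessorRegistry:
    _registry: Dict[Type, AttentionProcessorMetadata] = {}

    @classmethod
    def register(cls, model_class, metadata):
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class):
        return cls._registry[model_class]


def _skip_attention___ret___hidden_states(self, *args, **kwargs):
    # plausible stand-in for the helper referenced above: pass hidden_states through untouched
    return kwargs.get("hidden_states", args[0] if args else None)


class QwenDoubleStreamAttnProcessor2_0:  # stand-in for the real Qwen attention processor
    pass


AttentionProcessorRegistry.register(
    model_class=QwenDoubleStreamAttnProcessor2_0,
    metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_attention___ret___hidden_states),
)
print(AttentionProcessorRegistry.get(QwenDoubleStreamAttnProcessor2_0))
```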

src/diffusers/image_processor.py

Lines changed: 132 additions & 0 deletions
@@ -523,6 +523,7 @@ def resize(
                 size=(height, width),
             )
             image = self.pt_to_numpy(image)
+
         return image
 
     def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
@@ -838,6 +839,137 @@ def apply_overlay(
         return image
 
 
+class InpaintProcessor(ConfigMixin):
+    """
+    Image processor for inpainting image and mask.
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        vae_latent_channels: int = 4,
+        resample: str = "lanczos",
+        reducing_gap: int = None,
+        do_normalize: bool = True,
+        do_binarize: bool = False,
+        do_convert_grayscale: bool = False,
+        mask_do_normalize: bool = False,
+        mask_do_binarize: bool = True,
+        mask_do_convert_grayscale: bool = True,
+    ):
+        super().__init__()
+
+        self._image_processor = VaeImageProcessor(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            resample=resample,
+            reducing_gap=reducing_gap,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+        self._mask_processor = VaeImageProcessor(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            resample=resample,
+            reducing_gap=reducing_gap,
+            do_normalize=mask_do_normalize,
+            do_binarize=mask_do_binarize,
+            do_convert_grayscale=mask_do_convert_grayscale,
+        )
+
+    def preprocess(
+        self,
+        image: PIL.Image.Image,
+        mask: PIL.Image.Image = None,
+        height: int = None,
+        width: int = None,
+        padding_mask_crop: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Preprocess the image and mask.
+        """
+        if mask is None and padding_mask_crop is not None:
+            raise ValueError("mask must be provided if padding_mask_crop is provided")
+
+        # if mask is None, same behavior as regular image processor
+        if mask is None:
+            return self._image_processor.preprocess(image, height=height, width=width)
+
+        if padding_mask_crop is not None:
+            crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        processed_image = self._image_processor.preprocess(
+            image,
+            height=height,
+            width=width,
+            crops_coords=crops_coords,
+            resize_mode=resize_mode,
+        )
+
+        processed_mask = self._mask_processor.preprocess(
+            mask,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
+        )
+
+        if crops_coords is not None:
+            postprocessing_kwargs = {
+                "crops_coords": crops_coords,
+                "original_image": image,
+                "original_mask": mask,
+            }
+        else:
+            postprocessing_kwargs = {
+                "crops_coords": None,
+                "original_image": None,
+                "original_mask": None,
+            }
+
+        return processed_image, processed_mask, postprocessing_kwargs
+
+    def postprocess(
+        self,
+        image: torch.Tensor,
+        output_type: str = "pil",
+        original_image: Optional[PIL.Image.Image] = None,
+        original_mask: Optional[PIL.Image.Image] = None,
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
+    ) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
+        """
+        Postprocess the image, optionally apply mask overlay
+        """
+        image = self._image_processor.postprocess(
+            image,
+            output_type=output_type,
+        )
+        # optionally apply the mask overlay
+        if crops_coords is not None and (original_image is None or original_mask is None):
+            raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
+
+        elif crops_coords is not None and output_type != "pil":
+            raise ValueError("output_type must be 'pil' if crops_coords is provided")
+
+        elif crops_coords is not None:
+            image = [
+                self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
+            ]
+
+        return image
+
+
 class VaeImageProcessorLDM3D(VaeImageProcessor):
     """
     Image processor for VAE LDM3D.
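Two behaviors of the class above are worth calling out: without a `mask`, `preprocess` falls back to plain `VaeImageProcessor` behavior and returns a single tensor rather than the `(image, mask, kwargs)` triple; and the overlay path in `postprocess` only runs when `crops_coords` is set, in which case `output_type` must be `"pil"` and the original image/mask must be supplied. A small sketch of the plain (no-overlay) path, with a random tensor standing in for real decoded pipeline output:

```python
import torch

from diffusers.image_processor import InpaintProcessor

proc = InpaintProcessor()

fake_decoded = torch.rand(1, 3, 512, 512) * 2 - 1            # pretend VAE output in [-1, 1]
pil_out = proc.postprocess(fake_decoded, output_type="pil")  # crops_coords=None -> no overlay
print(type(pil_out[0]))                                      # <class 'PIL.Image.Image'>
```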

src/diffusers/modular_pipelines/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -47,6 +47,12 @@
     _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
     _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
     _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
+    _import_structure["qwenimage"] = [
+        "QwenImageAutoBlocks",
+        "QwenImageModularPipeline",
+        "QwenImageEditModularPipeline",
+        "QwenImageEditAutoBlocks",
+    ]
     _import_structure["components_manager"] = ["ComponentsManager"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -68,6 +74,12 @@
             SequentialPipelineBlocks,
         )
         from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
+        from .qwenimage import (
+            QwenImageAutoBlocks,
+            QwenImageEditAutoBlocks,
+            QwenImageEditModularPipeline,
+            QwenImageModularPipeline,
+        )
         from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
         from .wan import WanAutoBlocks, WanModularPipeline
 else:
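For context on why each name appears twice in these `__init__.py` diffs: the string map feeds a lazy module proxy at runtime, while the `TYPE_CHECKING` imports exist only for static analyzers. A simplified stand-in for that mechanism (not the `diffusers.utils._LazyModule` implementation itself):

```python
import importlib
import types


class LazyModuleSketch(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # invert {submodule: [names]} into {name: submodule} for attribute lookup
        self._name_to_module = {
            attr: sub for sub, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        # import the owning submodule only when the attribute is first requested
        submodule = self._name_to_module[attr]
        module = importlib.import_module(f"{self.__name__}.{submodule}")
        return getattr(module, attr)


# e.g. the qwenimage entry added in this commit would be consumed roughly like:
# proxy = LazyModuleSketch("diffusers.modular_pipelines", {"qwenimage": ["QwenImageAutoBlocks", ...]})
# proxy.QwenImageAutoBlocks  -> imports diffusers.modular_pipelines.qwenimage on first access
```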
