
Commit 693d8a3

[modular] i2i and t2i support for kontext modular (huggingface#12454)
* up
* get ready
* fix import
* up
* up
1 parent: a9df12a

File tree

11 files changed: +659 -46 lines

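This commit adds Flux Kontext support to the modular pipeline system, covering both text-to-image (t2i) and image-to-image (i2i). A minimal usage sketch follows; it assumes the same workflow the other modular AutoBlocks use (instantiate the blocks, convert them with `init_pipeline`, load components, call with `output="images"`). The repository id, the `load_components` call, and the generation arguments are illustrative assumptions, not part of this diff.

import torch
from diffusers.modular_pipelines import FluxKontextAutoBlocks
from diffusers.utils import load_image

# Build the auto blocks added by this commit and turn them into a runnable pipeline.
# The repo id below is an assumed example, not taken from the commit.
blocks = FluxKontextAutoBlocks()
pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-Kontext-dev")
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

# i2i (Kontext editing): pass a conditioning image; omit `image` for plain t2i.
image = load_image("input.png")  # placeholder path
result = pipe(
    prompt="make the sky a deep sunset orange",
    image=image,
    num_inference_steps=28,
    output="images",
)
result[0].save("kontext_edit.png")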

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -386,6 +386,8 @@
     _import_structure["modular_pipelines"].extend(
         [
             "FluxAutoBlocks",
+            "FluxKontextAutoBlocks",
+            "FluxKontextModularPipeline",
             "FluxModularPipeline",
             "QwenImageAutoBlocks",
             "QwenImageEditAutoBlocks",
@@ -1050,6 +1052,8 @@
 else:
     from .modular_pipelines import (
         FluxAutoBlocks,
+        FluxKontextAutoBlocks,
+        FluxKontextModularPipeline,
         FluxModularPipeline,
         QwenImageAutoBlocks,
         QwenImageEditAutoBlocks,
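With these entries in place, the new classes resolve from the package root as well as from `diffusers.modular_pipelines`; a quick smoke test:

# Both names are exposed at the top level by this commit.
from diffusers import FluxKontextAutoBlocks, FluxKontextModularPipeline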

src/diffusers/modular_pipelines/__init__.py

Lines changed: 7 additions & 2 deletions
@@ -46,7 +46,12 @@
     ]
     _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
     _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
-    _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
+    _import_structure["flux"] = [
+        "FluxAutoBlocks",
+        "FluxModularPipeline",
+        "FluxKontextAutoBlocks",
+        "FluxKontextModularPipeline",
+    ]
     _import_structure["qwenimage"] = [
         "QwenImageAutoBlocks",
         "QwenImageModularPipeline",
@@ -65,7 +70,7 @@
     from ..utils.dummy_pt_objects import *  # noqa F403
 else:
     from .components_manager import ComponentsManager
-    from .flux import FluxAutoBlocks, FluxModularPipeline
+    from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
     from .modular_pipeline import (
         AutoPipelineBlocks,
         BlockState,

src/diffusers/modular_pipelines/flux/__init__.py

Lines changed: 12 additions & 3 deletions
@@ -25,14 +25,18 @@
     _import_structure["modular_blocks"] = [
         "ALL_BLOCKS",
         "AUTO_BLOCKS",
+        "AUTO_BLOCKS_KONTEXT",
+        "FLUX_KONTEXT_BLOCKS",
         "TEXT2IMAGE_BLOCKS",
         "FluxAutoBeforeDenoiseStep",
         "FluxAutoBlocks",
-        "FluxAutoBlocks",
         "FluxAutoDecodeStep",
         "FluxAutoDenoiseStep",
+        "FluxKontextAutoBlocks",
+        "FluxKontextAutoDenoiseStep",
+        "FluxKontextBeforeDenoiseStep",
     ]
-    _import_structure["modular_pipeline"] = ["FluxModularPipeline"]
+    _import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -45,13 +49,18 @@
         from .modular_blocks import (
             ALL_BLOCKS,
             AUTO_BLOCKS,
+            AUTO_BLOCKS_KONTEXT,
+            FLUX_KONTEXT_BLOCKS,
             TEXT2IMAGE_BLOCKS,
             FluxAutoBeforeDenoiseStep,
             FluxAutoBlocks,
             FluxAutoDecodeStep,
             FluxAutoDenoiseStep,
+            FluxKontextAutoBlocks,
+            FluxKontextAutoDenoiseStep,
+            FluxKontextBeforeDenoiseStep,
         )
-        from .modular_pipeline import FluxModularPipeline
+        from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
 else:
     import sys

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 72 additions & 12 deletions
@@ -118,15 +118,6 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")


-# TODO: align this with Qwen patchifier
-def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-    latents = latents.permute(0, 2, 4, 1, 3, 5)
-    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-    return latents
-
-
 def _get_initial_timesteps_and_optionals(
     transformer,
     scheduler,
@@ -398,16 +389,15 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        # TODO: move packing latents code to a patchifier
+        # TODO: move packing latents code to a patchifier similar to Qwen
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+        latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)

         return latents

     @torch.no_grad()
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-
         block_state.height = block_state.height or components.default_height
         block_state.width = block_state.width or components.default_width
         block_state.device = components._execution_device
@@ -557,3 +547,73 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         self.set_block_state(state, block_state)

         return components, state
+
+
+class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
+    model_name = "flux-kontext"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps."
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="image_height"),
+            InputParam(name="image_width"),
+            InputParam(name="height"),
+            InputParam(name="width"),
+            InputParam(name="prompt_embeds"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name="txt_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=List[int],
+                description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
+            ),
+            OutputParam(
+                name="img_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=List[int],
+                description="The sequence lengths of the image latents, used for RoPE calculation.",
+            ),
+        ]
+
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        prompt_embeds = block_state.prompt_embeds
+        device, dtype = prompt_embeds.device, prompt_embeds.dtype
+        block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+            device=prompt_embeds.device, dtype=prompt_embeds.dtype
+        )
+
+        img_ids = None
+        if (
+            getattr(block_state, "image_height", None) is not None
+            and getattr(block_state, "image_width", None) is not None
+        ):
+            image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
+            image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+            img_ids = FluxPipeline._prepare_latent_image_ids(
+                None, image_latent_height // 2, image_latent_width // 2, device, dtype
+            )
+            # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+            img_ids[..., 0] = 1
+
+        height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+        width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+        latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
+
+        if img_ids is not None:
+            latent_ids = torch.cat([latent_ids, img_ids], dim=0)
+
+        block_state.img_ids = latent_ids
+
+        self.set_block_state(state, block_state)
+
+        return components, state
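`FluxKontextRoPEInputsStep` prepares the position ids the transformer uses for RoPE: `txt_ids` is one zero row per prompt token, the target latents get a (0, row, col) id per packed 2x2 patch, and the conditioning image latents get the same grid with the first coordinate set to 1 before being concatenated after the target ids. Below is a shape-level sketch of that layout which re-implements the id grid locally instead of calling `FluxPipeline._prepare_latent_image_ids`; the helper name and the concrete sizes are illustrative, not from this commit.

import torch

def make_latent_image_ids(height: int, width: int, offset: float = 0.0) -> torch.Tensor:
    # One (t, row, col) id per 2x2 latent patch, mirroring Flux's packed latent layout.
    ids = torch.zeros(height, width, 3)
    ids[..., 0] = offset                                    # 0 for target latents, 1 for the Kontext image
    ids[..., 1] += torch.arange(height)[:, None]            # row index
    ids[..., 2] += torch.arange(width)[None, :]             # column index
    return ids.reshape(height * width, 3)

# Example: 1024x1024 target and conditioning image, vae_scale_factor = 8
vae_scale_factor = 8
h = 2 * (1024 // (vae_scale_factor * 2))    # 128 latent rows
w = 2 * (1024 // (vae_scale_factor * 2))    # 128 latent cols

txt_ids = torch.zeros(512, 3)                                # one row per prompt token (seq len 512 here)
latent_ids = make_latent_image_ids(h // 2, w // 2)           # ids for the latents being denoised
image_ids = make_latent_image_ids(h // 2, w // 2, offset=1)  # same grid, first coordinate set to 1

img_ids = torch.cat([latent_ids, image_ids], dim=0)          # what the step stores as block_state.img_ids
print(txt_ids.shape, img_ids.shape)                          # torch.Size([512, 3]) torch.Size([8192, 3])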

src/diffusers/modular_pipelines/flux/denoise.py

Lines changed: 107 additions & 0 deletions
@@ -109,6 +109,96 @@ def __call__(
         return components, block_state


+class FluxKontextLoopDenoiser(ModularPipelineBlocks):
+    model_name = "flux-kontext"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [ComponentSpec("transformer", FluxTransformer2DModel)]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that denoise the latents for Flux Kontext. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `FluxDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("joint_attention_kwargs"),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+            ),
+            InputParam(
+                "image_latents",
+                type_hint=torch.Tensor,
+                description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
+            ),
+            InputParam(
+                "guidance",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Guidance scale as a tensor",
+            ),
+            InputParam(
+                "prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Prompt embeddings",
+            ),
+            InputParam(
+                "pooled_prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Pooled prompt embeddings",
+            ),
+            InputParam(
+                "txt_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="IDs computed from text sequence needed for RoPE",
+            ),
+            InputParam(
+                "img_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="IDs computed from latent sequence needed for RoPE",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(
+        self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+    ) -> PipelineState:
+        latents = block_state.latents
+        latent_model_input = latents
+        image_latents = block_state.image_latents
+        if image_latents is not None:
+            latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
+
+        timestep = t.expand(latents.shape[0]).to(latents.dtype)
+        noise_pred = components.transformer(
+            hidden_states=latent_model_input,
+            timestep=timestep / 1000,
+            guidance=block_state.guidance,
+            encoder_hidden_states=block_state.prompt_embeds,
+            pooled_projections=block_state.pooled_prompt_embeds,
+            joint_attention_kwargs=block_state.joint_attention_kwargs,
+            txt_ids=block_state.txt_ids,
+            img_ids=block_state.img_ids,
+            return_dict=False,
+        )[0]
+        noise_pred = noise_pred[:, : latents.size(1)]
+        block_state.noise_pred = noise_pred
+
+        return components, block_state
+
+
 class FluxLoopAfterDenoiser(ModularPipelineBlocks):
     model_name = "flux"

@@ -221,3 +311,20 @@ def description(self) -> str:
             " - `FluxLoopAfterDenoiser`\n"
             "This block supports both text2image and img2img tasks."
         )
+
+
+class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
+    model_name = "flux-kontext"
+    block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
+    block_names = ["denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoise the latents. \n"
+            "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+            " - `FluxKontextLoopDenoiser`\n"
+            " - `FluxLoopAfterDenoiser`\n"
+            "This block supports both text2image and img2img tasks."
+        )
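The Kontext-specific part of the loop lives in `FluxKontextLoopDenoiser.__call__`: conditioning image latents are concatenated to the target latents along the sequence dimension before the transformer call, and the prediction is sliced back to the first `latents.size(1)` tokens so only the target latents are updated by the scheduler. A toy, shape-only sketch of that pattern, with random tensors standing in for the transformer and the token counts chosen for illustration:

import torch

batch, target_tokens, image_tokens, channels = 1, 4096, 4096, 64

latents = torch.randn(batch, target_tokens, channels)         # latents being denoised
image_latents = torch.randn(batch, image_tokens, channels)     # packed Kontext conditioning latents

# Concatenate along the sequence dimension, as in FluxKontextLoopDenoiser.
latent_model_input = torch.cat([latents, image_latents], dim=1)

# Stand-in for the transformer: one prediction per input token.
noise_pred = torch.randn_like(latent_model_input)

# Keep only the predictions for the target tokens; the image tokens are conditioning only.
noise_pred = noise_pred[:, : latents.size(1)]
assert noise_pred.shape == latents.shape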
