Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@
_import_structure["modular_pipelines"].extend(
[
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
"FluxModularPipeline",
"QwenImageAutoBlocks",
"QwenImageEditAutoBlocks",
Expand Down Expand Up @@ -1050,6 +1052,8 @@
else:
from .modular_pipelines import (
FluxAutoBlocks,
FluxKontextAutoBlocks,
FluxKontextModularPipeline,
FluxModularPipeline,
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
Expand Down
9 changes: 7 additions & 2 deletions src/diffusers/modular_pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
_import_structure["flux"] = [
"FluxAutoBlocks",
"FluxModularPipeline",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
"QwenImageModularPipeline",
Expand All @@ -65,7 +70,7 @@
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxModularPipeline
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
Expand Down
15 changes: 12 additions & 3 deletions src/diffusers/modular_pipelines/flux/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"AUTO_BLOCKS_KONTTEXT",
"FLUX_KONTEXT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
"FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
"FluxKontextAutoBlocks",
"FluxKontextAutoDenoiseStep",
"FluxKontextBeforeDenoiseStep",
]
_import_structure["modular_pipeline"] = ["FluxModularPipeline"]
_import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
Expand All @@ -45,13 +49,18 @@
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
AUTO_BLOCKS_KONTTEXT,
FLUX_KONTEXT_BLOCKS,
TEXT2IMAGE_BLOCKS,
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
FluxKontextAutoBlocks,
FluxKontextAutoDenoiseStep,
FluxKontextBeforeDenoiseStep,
)
from .modular_pipeline import FluxModularPipeline
from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
else:
import sys

Expand Down
83 changes: 72 additions & 11 deletions src/diffusers/modular_pipelines/flux/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,6 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")


# TODO: align this with Qwen patchifier
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

return latents


def _get_initial_timesteps_and_optionals(
transformer,
scheduler,
Expand Down Expand Up @@ -398,9 +389,9 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

# TODO: move packing latents code to a patchifier
# TODO: move packing latents code to a patchifier similar to Qwen
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)

return latents

Expand Down Expand Up @@ -557,3 +548,73 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
self.set_block_state(state, block_state)

return components, state


class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
model_name = "flux_kontext"

@property
def description(self) -> str:
return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps."

@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="image_height"),
InputParam(name="image_width"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="prompt_embeds"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
),
OutputParam(
name="img_ids",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the image latents, used for RoPE calculation.",
),
]

def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

prompt_embeds = block_state.prompt_embeds
device, dtype = prompt_embeds.device, prompt_embeds.dtype
block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
device=prompt_embeds.device, dtype=prompt_embeds.dtype
)

img_ids = None
if (
getattr(block_state, "image_height", None) is not None
and getattr(block_state, "image_width", None) is not None
):
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
img_ids = FluxPipeline._prepare_latent_image_ids(
None, image_latent_height // 2, image_latent_width // 2, device, dtype
)
# image ids are the same as latent ids with the first dimension set to 1 instead of 0
img_ids[..., 0] = 1

height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)

if img_ids is not None:
latent_ids = torch.cat([latent_ids, img_ids], dim=0)

block_state.img_ids = latent_ids

self.set_block_state(state, block_state)

return components, state
106 changes: 106 additions & 0 deletions src/diffusers/modular_pipelines/flux/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,96 @@ def __call__(
return components, block_state


class FluxKontextLoopDenoiser(ModularPipelineBlocks):
model_name = "flux_kontext"

@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", FluxTransformer2DModel)]

@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents for Flux Kontext. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `FluxDenoiseLoopWrapper`)"
)

@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"guidance",
required=True,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Prompt embeddings",
),
InputParam(
"pooled_prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Pooled prompt embeddings",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from text sequence needed for RoPE",
),
InputParam(
"img_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from latent sequence needed for RoPE",
),
]

@torch.no_grad()
def __call__(
self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents
image_latents = block_state.image_latents
if image_latents is not None:
latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=block_state.guidance,
encoder_hidden_states=block_state.prompt_embeds,
pooled_projections=block_state.pooled_prompt_embeds,
joint_attention_kwargs=block_state.joint_attention_kwargs,
txt_ids=block_state.txt_ids,
img_ids=block_state.img_ids,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred

return components, block_state


class FluxLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux"

Expand Down Expand Up @@ -221,3 +311,19 @@ def description(self) -> str:
" - `FluxLoopAfterDenoiser`\n"
"This block supports both text2image and img2img tasks."
)


class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]

@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `FluxKontextLoopDenoiser`\n"
" - `FluxLoopAfterDenoiser`\n"
"This block supports both text2image and img2img tasks."
)
Loading
Loading