@@ -105,26 +105,26 @@ def calculate_shift(
 
 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
 def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-    latents = latents.permute(0, 2, 4, 1, 3, 5)
-    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+    latents = latents.permute(0, 2, 4, 1, 3, 5)
+    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
 
-    return latents
+    return latents
 
 
 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
 def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-    latent_image_ids = torch.zeros(height, width, 3)
-    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+    latent_image_ids = torch.zeros(height, width, 3)
+    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
-    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
-    latent_image_ids = latent_image_ids.reshape(
-        latent_image_id_height * latent_image_id_width, latent_image_id_channels
-    )
+    latent_image_ids = latent_image_ids.reshape(
+        latent_image_id_height * latent_image_id_width, latent_image_id_channels
+    )
 
-    return latent_image_ids.to(device=device, dtype=dtype)
+    return latent_image_ids.to(device=device, dtype=dtype)
 
 
 class FluxInputStep(PipelineBlock):
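For readers new to Flux's latent layout: the two helpers in the hunk above fold each 2x2 spatial patch of a (B, C, H, W) latent into the channel axis to form a token sequence, and build one (0, row, col) position id per token for the transformer's RoPE. The sketch below reproduces that arithmetic inline with illustrative shapes; the 16-channel 64x64 latent and the inverse "unpack" step at the end are assumptions for demonstration, not code from this PR.

import torch

# Illustrative shapes only (assumed for this sketch): Flux latents have 16
# channels; 64x64 is an arbitrary latent-space resolution.
batch_size, num_channels_latents, height, width = 1, 16, 64, 64
latents = torch.randn(batch_size, num_channels_latents, height, width)

# _pack_latents: fold each 2x2 spatial patch into the channel axis, giving a
# token sequence of shape (B, H/2 * W/2, C * 4).
packed = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
print(packed.shape)  # torch.Size([1, 1024, 64])

# _prepare_latent_image_ids: one (0, row, col) triple per packed token; the
# caller is expected to pass the halved dimensions so the row count matches
# the packed sequence length.
ids = torch.zeros(height // 2, width // 2, 3)
ids[..., 1] += torch.arange(height // 2)[:, None]
ids[..., 2] += torch.arange(width // 2)[None, :]
ids = ids.reshape((height // 2) * (width // 2), 3)
print(ids.shape)  # torch.Size([1024, 3])

# Inverse of the packing (not part of this diff): recover (B, C, H, W) before
# VAE decoding.
unpacked = packed.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5).reshape(batch_size, num_channels_latents, height, width)
assert torch.equal(unpacked, latents)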
@@ -180,13 +180,11 @@ def intermediate_outputs(self) -> List[str]:
             OutputParam(
                 "prompt_embeds",
                 type_hint=torch.Tensor,
-                # kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
                 description="text embeddings used to guide the image generation",
             ),
             OutputParam(
                 "pooled_prompt_embeds",
                 type_hint=torch.Tensor,
-                # kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
                 description="pooled text embeddings used to guide the image generation",
             ),
             # TODO: support negative embeddings?
@@ -235,10 +233,10 @@ def description(self) -> str:
     def inputs(self) -> List[InputParam]:
         return [
             InputParam("num_inference_steps", default=50),
-            InputParam("timesteps"),
+            InputParam("timesteps"),
             InputParam("sigmas"),
             InputParam("guidance_scale", default=3.5),
-            InputParam("latents", type_hint=torch.Tensor)
+            InputParam("latents", type_hint=torch.Tensor),
         ]
 
     @property
@@ -261,7 +259,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
                 type_hint=int,
                 description="The number of denoising steps to perform at inference time",
             ),
-            OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used.")
+            OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
         ]
 
     @torch.no_grad()
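The hunks above are formatting-only (trailing commas and whitespace), but the `guidance` output they touch deserves a note: guidance-distilled Flux checkpoints take the guidance scale as a model input instead of running classifier-free guidance. Below is a hedged sketch of the conventional way Flux pipelines build that tensor from `guidance_scale`; whether this block does exactly this is not visible in the hunk, and `uses_guidance_embeds` stands in for the transformer config flag the non-modular FluxPipeline checks.

import torch

def prepare_guidance(guidance_scale, batch_size, device, uses_guidance_embeds):
    # Guidance-distilled checkpoints expect one scale value per sample;
    # otherwise no guidance tensor is passed to the transformer.
    if not uses_guidance_embeds:
        return None
    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
    return guidance.expand(batch_size)

# Example with the block's default guidance_scale of 3.5 and a batch of 2.
print(prepare_guidance(3.5, batch_size=2, device="cpu", uses_guidance_embeds=True))
# tensor([3.5000, 3.5000])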
@@ -340,10 +338,11 @@ def intermediate_outputs(self) -> List[OutputParam]:
                 "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
             ),
             OutputParam(
-                "latent_image_ids", type_hint=torch.Tensor, description="IDs computed from the image sequence needed for RoPE"
-            )
+                "latent_image_ids",
+                type_hint=torch.Tensor,
+                description="IDs computed from the image sequence needed for RoPE",
+            ),
         ]
-
 
     @staticmethod
     def check_inputs(components, block_state):
@@ -417,7 +416,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
             block_state.generator,
             block_state.latents,
         )
-
+
         self.set_block_state(state, block_state)
 
         return components, state
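The last hunk only strips trailing whitespace around the block's prepare-latents call. For context, the `generator` / `latents` pair it forwards follows the usual diffusers convention: reuse caller-supplied latents when given, otherwise draw reproducible noise through the generator. A minimal sketch of that convention using the library's `randn_tensor` helper; the exact logic inside this block is not shown in the hunk and is assumed.

import torch
from diffusers.utils.torch_utils import randn_tensor

def initial_noise_sketch(shape, generator=None, latents=None, device="cpu", dtype=torch.float32):
    # Caller-supplied latents win and are only moved to the target device/dtype;
    # otherwise noise is drawn through the generator so a fixed seed reproduces it.
    if latents is not None:
        return latents.to(device=device, dtype=dtype)
    return randn_tensor(shape, generator=generator, device=torch.device(device), dtype=dtype)

gen = torch.Generator("cpu").manual_seed(0)
noise = initial_noise_sketch((1, 16, 64, 64), generator=gen)
print(noise.shape)  # torch.Size([1, 16, 64, 64])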