@@ -118,15 +118,6 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")


-# TODO: align this with Qwen patchifier
-def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-    latents = latents.permute(0, 2, 4, 1, 3, 5)
-    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-    return latents
-
-
 def _get_initial_timesteps_and_optionals(
     transformer,
     scheduler,
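
Note: the helper removed above matches the 2x2 patchification that `FluxPipeline._pack_latents` already performs (used in the next hunk). A minimal standalone sketch of that operation, for orientation only:

```python
import torch

def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # Fold every 2x2 spatial patch into the channel dimension:
    # (B, C, H, W) -> (B, (H // 2) * (W // 2), C * 4)
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
```
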
@@ -398,16 +389,15 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        # TODO: move packing latents code to a patchifier
+        # TODO: move packing latents code to a patchifier similar to Qwen
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+        latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)

         return latents

     @torch.no_grad()
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-
         block_state.height = block_state.height or components.default_height
         block_state.width = block_state.width or components.default_width
         block_state.device = components._execution_device
@@ -557,3 +547,73 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
         self.set_block_state(state, block_state)

         return components, state
+
+
+class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
+    model_name = "flux-kontext"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps."
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="image_height"),
+            InputParam(name="image_width"),
+            InputParam(name="height"),
+            InputParam(name="width"),
+            InputParam(name="prompt_embeds"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name="txt_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=List[int],
+                description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
+            ),
+            OutputParam(
+                name="img_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=List[int],
+                description="The sequence lengths of the image latents, used for RoPE calculation.",
+            ),
+        ]
+
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        prompt_embeds = block_state.prompt_embeds
+        device, dtype = prompt_embeds.device, prompt_embeds.dtype
+        block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+            device=prompt_embeds.device, dtype=prompt_embeds.dtype
+        )
+
+        img_ids = None
+        if (
+            getattr(block_state, "image_height", None) is not None
+            and getattr(block_state, "image_width", None) is not None
+        ):
+            image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
+            image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
+            img_ids = FluxPipeline._prepare_latent_image_ids(
+                None, image_latent_height // 2, image_latent_width // 2, device, dtype
+            )
+            # image ids match the latent ids except that the first id channel is set to 1 instead of 0
+            img_ids[..., 0] = 1
+
+        height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+        width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+        latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
+
+        if img_ids is not None:
+            latent_ids = torch.cat([latent_ids, img_ids], dim=0)
+
+        block_state.img_ids = latent_ids
+
+        self.set_block_state(state, block_state)
+
+        return components, state
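
A hedged, standalone sketch of what `FluxKontextRoPEInputsStep` assembles: zero position ids for the prompt tokens, a position grid for the target latents (first id channel 0), and, when a reference image is present, a second grid with the first id channel set to 1, concatenated behind the target ids. Token counts below are illustrative, and `make_image_ids` is a hypothetical stand-in for `FluxPipeline._prepare_latent_image_ids`:

```python
import torch

def make_image_ids(h_tokens, w_tokens, index=0):
    # (h_tokens * w_tokens, 3): column 0 separates target (0) from reference image (1),
    # columns 1 and 2 hold the row / column position of each latent token
    ids = torch.zeros(h_tokens, w_tokens, 3)
    ids[..., 0] = index
    ids[..., 1] = torch.arange(h_tokens)[:, None]
    ids[..., 2] = torch.arange(w_tokens)[None, :]
    return ids.reshape(-1, 3)

txt_ids = torch.zeros(512, 3)                        # one zero row per prompt token
latent_ids = make_image_ids(64, 64, index=0)         # e.g. a 1024x1024 target -> 64x64 tokens
img_ids = make_image_ids(64, 64, index=1)            # Kontext reference image tokens
combined = torch.cat([latent_ids, img_ids], dim=0)   # what the step stores in block_state.img_ids
```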