@@ -74,15 +74,23 @@ def pe_selection_index_based_on_dim(self, h, w):
         # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
         # because original input are in flattened format, we have to flatten this 2d grid as well.
         h_p, w_p = h // self.patch_size, w // self.patch_size
-        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
         h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
-        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+
+        # Calculate the top-left corner indices for the centered patch grid
         starth = h_max // 2 - h_p // 2
-        endh = starth + h_p
         startw = w_max // 2 - w_p // 2
-        endw = startw + w_p
-        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
-        return original_pe_indexes.flatten()
+
+        # Generate the row and column indices for the desired patch grid
+        rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
+        cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
+
+        # Create a 2D grid of indices
+        row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
+
+        # Convert the 2D grid indices to flattened 1D indices
+        selected_indices = (row_indices * w_max + col_indices).flatten()
+
+        return selected_indices

     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()
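The new indexing computes the same selection as the old view-and-slice code, but it derives each flattened index arithmetically (row * w_max + col) instead of materializing and slicing a full h_max x w_max index grid. Below is a minimal standalone sketch, not part of this commit and with made-up helper names, that checks the two formulations agree:

```python
import torch


def select_centered_indices(h_p, w_p, h_max, w_max):
    # New formulation: compute flattened indices of a centered h_p x w_p window
    # in a row-major h_max x w_max grid directly.
    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2
    rows = torch.arange(starth, starth + h_p)
    cols = torch.arange(startw, startw + w_p)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")
    return (row_idx * w_max + col_idx).flatten()


def select_centered_indices_old(h_p, w_p, h_max, w_max):
    # Old formulation: build the full index grid, then slice out the center.
    grid = torch.arange(h_max * w_max).view(h_max, w_max)
    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2
    return grid[starth : starth + h_p, startw : startw + w_p].flatten()


# e.g. a 16 x 24 patch grid selected from a 32 x 32 positional-embedding grid
assert torch.equal(
    select_centered_indices(16, 24, 32, 32),
    select_centered_indices_old(16, 24, 32, 32),
)
```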
@@ -275,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
         sample_size (`int`): The width of the latent images. This is fixed during training since
             it is used to learn a number of position embeddings.
         patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
         num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
-        num_single_dit_layers (`int`, *optional*, defaults to 4):
+        num_single_dit_layers (`int`, *optional*, defaults to 32):
             The number of layers of Transformer blocks to use. These blocks use concatenated image and text
             representations.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
         caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+        out_channels (`int`, defaults to 4): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
     """

     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
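The corrected docstring defaults hang together arithmetically: with the model's default sample_size of 64 and patch_size of 2 (assumed here, since they are not shown in this hunk), the positional-embedding grid is 32 x 32, giving pos_embed_max_size = 1024, and num_attention_heads * attention_head_dim = 12 * 256 = 3072, which matches the caption_projection_dim used elsewhere in the config (also an assumption). A quick sanity check under those assumptions:

```python
# Sanity check of the documented defaults. sample_size, patch_size and
# caption_projection_dim are assumed values; the rest come from the updated
# docstring in this commit.
sample_size, patch_size = 64, 2
num_attention_heads, attention_head_dim = 12, 256
pos_embed_max_size, caption_projection_dim = 1024, 3072

assert (sample_size // patch_size) ** 2 == pos_embed_max_size  # 32 x 32 grid of patches
assert num_attention_heads * attention_head_dim == caption_projection_dim  # 3072-dim inner size
```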