@@ -1068,17 +1068,15 @@ def forward(
         latent_sequence_length = hidden_states.shape[1]
         condition_sequence_length = encoder_hidden_states.shape[1]
         sequence_length = latent_sequence_length + condition_sequence_length
-        attention_mask = torch.zeros(
+        attention_mask = torch.ones(
             batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
         )  # [B, N]
-
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
-
-        for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i]] = True
-        # [B, 1, 1, N], for broadcasting across attention heads
-        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+        indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0)  # [1, N]
+        mask_indices = indices >= effective_sequence_length.unsqueeze(1)  # [B, N]
+        attention_mask = attention_mask.masked_fill(mask_indices, False)
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, N]
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
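For context, here is a minimal standalone sketch of the equivalence between the old per-sample loop and the new vectorized mask construction. The tensor sizes and the `encoder_attention_mask` values are illustrative only, not taken from the PR:

```python
import torch

# Illustrative sizes (hypothetical): batch of 2, 4 latent tokens, 3 condition tokens.
B, N_latent, N_cond = 2, 4, 3
N = N_latent + N_cond
encoder_attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.bool)  # [B, N_cond]

# Per-sample effective length: all latent tokens plus the valid condition tokens.
effective_len = N_latent + encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]

# Old approach: Python loop, fill each row up to its effective length.
loop_mask = torch.zeros(B, N, dtype=torch.bool)
for i in range(B):
    loop_mask[i, : effective_len[i]] = True

# New approach: compare a shared index row against per-sample lengths, no loop.
indices = torch.arange(N).unsqueeze(0)  # [1, N]
vec_mask = torch.ones(B, N, dtype=torch.bool).masked_fill(
    indices >= effective_len.unsqueeze(1), False
)  # [B, N]

assert torch.equal(loop_mask, vec_mask)
```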