Commit 21075fa

Author: bghira

add skip_layers argument to SD3 transformer model class

1 parent 76b7d86 · commit 21075fa

1 file changed: +12 −3 lines

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 12 additions & 3 deletions
@@ -268,6 +268,7 @@ def forward(
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -279,9 +280,9 @@ def forward(
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
-            timestep ( `torch.LongTensor`):
+            timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                 A list of tensors that if specified are added to the residuals of transformer blocks.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +291,8 @@ def forward(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.

         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -317,6 +320,13 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         for index_block, block in enumerate(self.transformer_blocks):
+            # Skip specified layers
+            if skip_layers is not None and index_block in skip_layers:
+                if block_controlnet_hidden_states is not None and block.context_pre_only is False:
+                    interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
+                    hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+                continue
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
@@ -336,7 +346,6 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-
             else:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
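
For illustration, here is a minimal usage sketch (not part of the commit): it builds a tiny, randomly initialized SD3Transformer2DModel and calls forward() with the new skip_layers argument. The small config values, tensor shapes, and the choice of skip_layers=[1, 2] are illustrative assumptions, not the production SD3 configuration.

# Minimal sketch of the new skip_layers argument, assuming a tiny toy config
# (the values below are illustrative, not the shipped SD3 configuration).
import torch

from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=4,
    num_layers=4,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,  # matches num_attention_heads * attention_head_dim
    pooled_projection_dim=64,
    out_channels=4,
)
model.eval()

hidden_states = torch.randn(1, 4, 32, 32)        # latent image
encoder_hidden_states = torch.randn(1, 77, 32)   # prompt token embeddings
pooled_projections = torch.randn(1, 64)          # pooled text embedding
timestep = torch.tensor([500])

with torch.no_grad():
    # Transformer blocks 1 and 2 are bypassed; the remaining blocks run as usual.
    sample = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timestep,
        skip_layers=[1, 2],
    ).sample

print(sample.shape)  # torch.Size([1, 4, 32, 32])

Note that, per the diff above, a skipped block still receives its ControlNet residual when block_controlnet_hidden_states is provided and the block is not context-pre-only, so ControlNet conditioning is preserved even for bypassed layers.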
