@@ -268,6 +268,7 @@ def forward(
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        skip_layers: Optional[List[int]] = None,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
@@ -279,9 +280,9 @@ def forward(
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
-            timestep ( `torch.LongTensor`):
+            timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                 A list of tensors that if specified are added to the residuals of transformer blocks.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +291,8 @@ def forward(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
+            skip_layers (`list` of `int`, *optional*):
+                A list of layer indices to skip during the forward pass.

         Returns:
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -317,6 +320,13 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         for index_block, block in enumerate(self.transformer_blocks):
+            # Skip specified layers
+            if skip_layers is not None and index_block in skip_layers:
+                if block_controlnet_hidden_states is not None and block.context_pre_only is False:
+                    interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
+                    hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+                continue
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
@@ -336,7 +346,6 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-
             else:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
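Not part of the diff, but for context: a minimal sketch of how the new `skip_layers` argument might be exercised. The model config sizes below are illustrative toy values (not SD3's real configuration), chosen only so the randomly initialized model is small enough to run:

```python
import torch
from diffusers import SD3Transformer2DModel

# Tiny, randomly initialized model; all dimensions here are hypothetical.
# caption_projection_dim matches num_attention_heads * attention_head_dim.
model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=4,
    attention_head_dim=64,
    num_attention_heads=4,
    joint_attention_dim=256,
    caption_projection_dim=256,
    pooled_projection_dim=256,
    out_channels=16,
)

latents = torch.randn(1, 16, 32, 32)        # (batch, in_channels, h, w)
prompt_embeds = torch.randn(1, 77, 256)     # (batch, seq_len, joint_attention_dim)
pooled = torch.randn(1, 256)                # (batch, pooled_projection_dim)
timestep = torch.tensor([999])

with torch.no_grad():
    # Ordinary forward pass through all transformer blocks.
    full = model(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        pooled_projections=pooled,
        timestep=timestep,
    ).sample

    # Same inputs, but blocks 1 and 2 are bypassed via the new argument.
    skipped = model(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        pooled_projections=pooled,
        timestep=timestep,
        skip_layers=[1, 2],
    ).sample
```

Note that a skipped block still receives its ControlNet residual when `block_controlnet_hidden_states` is provided, mirroring the non-skipped branch: with, say, 24 transformer blocks and 12 residual tensors, `interval_control = 24 // 12 = 2`, so block `index_block` adds residual `index_block // 2` whether or not the block itself runs.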