add skip_layers argument to SD3 transformer model class #9880
Conversation
Force-pushed from cbeef91 to 21075fa.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
this is part of the work to implement skip-layer guidance for CFG in SD 3.5 Medium. the recommendation from SAI is to skip layers 7, 8, and 9 when doing negative guidance, so this will have to be altered for that to work.
cc @asomoza here
upstream:

```python
# Run cond and uncond in a batch together
batched = self.model.apply_model(
    torch.cat([x, x]),
    torch.cat([timestep, timestep]),
    c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]),
    y=torch.cat([cond["y"], uncond["y"]]),
)
# Then split and apply CFG scaling
pos_out, neg_out = batched.chunk(2)
scaled = neg_out + (pos_out - neg_out) * cond_scale
# Then run with skip layers
if (
    self.slg > 0
    and self.step > (self.skip_start * self.steps)
    and self.step < (self.skip_end * self.steps)
):
    skip_layer_out = self.model.apply_model(
        x,
        timestep,
        c_crossattn=cond["c_crossattn"],
        y=cond["y"],
        skip_layers=self.skip_layers,
    )
    # Then scale according to skip-layer guidance
    scaled = scaled + (pos_out - skip_layer_out) * self.slg
```
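in effect, the final prediction combines both guidance terms: `scaled = uncond + cfg_scale * (cond - uncond) + slg_scale * (cond - cond_skip)`, where `cond_skip` is the conditional pass with the configured layers skipped, and the second term is only applied inside the `skip_start`/`skip_end` step window.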
link to #9819
ohh got it
@vladmandic see the above pipeline commit if you're interested
I've asked @Dango233 for a review; let's work on this PR and get it merged soon :)
i think a possible improvement would be to dynamically determine whether to skip a layer and what scale to apply, but it's a bit hard to do.
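purely as an illustration of that idea (nothing in this PR implements it, and every name here is hypothetical), a dynamic scale could be a schedule over the active step window, e.g. a linear ramp down to zero:

```python
def dynamic_slg_scale(step, steps, base_scale, skip_start=0.01, skip_end=0.2):
    """Hypothetical sketch: ramp the skip-layer guidance scale linearly
    to zero across the active step window instead of keeping it fixed."""
    start, end = skip_start * steps, skip_end * steps
    if not (start < step < end):
        return 0.0
    return base_scale * (1.0 - (step - start) / (end - start))
```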
closing due to lack of updates.
for anyone wanting this feature: it seems candle has an interest in keeping up with community pull requests. the diffusers project has been falling behind quite a bit lately in addressing development.
@bghira i agree diffusers has been falling behind lately, but why close this pr? @yiyixuxu @sayakpaul what can we do here?
they said some other pull request was coming, so i assume the preference is for that pull request. at this point i've simply advocated for an internal fork of Diffusers that fulfils our needs, where we merely cherry-pick fixes from this project when it makes sense to.
initially, the author of SD3.5 was going to send a pull request for this feature around the same time; that's why I said this: #9880 (comment). I checked with him after we received your PR, and we agreed we should move forward with this PR and have him do a review instead. I was traveling for the past week and could not follow up. I understand your frustration. If you are willing to re-open the PR, we can continue to work on this and get it merged soon. Otherwise, we will add it in a new PR and add you as a co-author (cc @asomoza here).
regardless of whether we reopen this pr or implement it independently, we need to figure out how to make diffusers current, at least when it comes to top models - slg has been the default behavior in sd35-medium for 3+ weeks now and we're still discussing it here. i understand that cannot be done for all models, but for something that is currently in the top-3 we should aim for parity much faster.
as the branch has been deleted, the pull request can no longer be reopened. edit: found a hidden button
@vladmandic the strangest thing is i thought this would be much harder because no one had gotten around to it yet, which is why it took me so long to even attempt it. when i saw how simple it was, i had it working in under an hour. but then it made me greatly confused why SD 3.5 Medium support in Diffusers is so far behind, since even i was capable of doing it.
```python
for index_block, block in enumerate(self.transformer_blocks):
    # Skip specified layers
    if skip_layers is not None and index_block in skip_layers:
        if block_controlnet_hidden_states is not None and block.context_pre_only is False:
```
can we skip the block of code that needs to be skipped instead of adding duplicated code here? otherwise, if we have to change the part of the code that handles the controlnet residual in the future, we have to remember to change both places, which is not great.
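one way to restructure it along those lines (a sketch assuming the loop shape quoted above; the residual indexing mirrors the existing controlnet handling in the SD3 forward pass, but should be checked against the actual code):

```python
for index_block, block in enumerate(self.transformer_blocks):
    # decide once whether this block is skipped
    is_skipped = skip_layers is not None and index_block in skip_layers

    if not is_skipped:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
        )

    # the controlnet residual handling lives in exactly one place, so a
    # future change to it does not need to be mirrored into a skip branch
    if block_controlnet_hidden_states is not None and block.context_pre_only is False:
        interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
        hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
```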
| ) | ||
|
|
||
```python
self._guidance_scale = guidance_scale
self._skip_layer_guidance_scale = skip_layer_guidance_scale
```
need to add a decorator for this too, like this:
`src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py`, line 642 in 07d0fbf: `def guidance_scale(self):`
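i.e. something like the following, mirroring the existing `guidance_scale` property (a sketch of what the review asks for, not the merged code):

```python
@property
def skip_layer_guidance_scale(self):
    return self._skip_layer_guidance_scale
```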
@bghira from what I take from the other discussion, you would prefer if we just take over this PR and finish it, right?
@asomoza I made the change I requested - feel free to merge after you test it and think it's ok.
thanks!
remember that it's applying a skip to the negative prompt, so that part of things also impacts testing. for the simpletuner use case, it's there to improve the results of validation so that they match inference in comfyUI, where skip-layer guidance is more often than not used to improve results. after introducing the feature, users often report that they don't want to disable the option anymore, because the results so much more closely match their inference results in CUI.
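a minimal test sketch might look like the following; the argument names beyond `skip_layer_guidance_scale` (which appears in the diff above) are assumptions from this thread, not a confirmed API:

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")

# run with skip-layer guidance using the layers SAI recommends for
# SD 3.5 Medium, then compare against a run without it
image = pipe(
    "a photo of a cat",
    guidance_scale=4.5,
    skip_guidance_layers=[7, 8, 9],  # assumed parameter name
    skip_layer_guidance_scale=2.8,   # illustrative value
).images[0]
```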
* add skip_layers argument to SD3 transformer model class
* add unit test for skip_layers in stable diffusion 3
* sd3: pipeline should support skip layer guidance
* up

Co-authored-by: bghira <[email protected]>
Co-authored-by: yiyixuxu <[email protected]>













What does this PR do?
Adds a `skip_layers` parameter to the transformer model class for Stable Diffusion 3.
I can un-bundle the batched CFG and also include the pipeline changes in this pull request if you require.
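as a rough illustration of the new argument (a sketch; the tensor names and the call site around it are assumptions, not part of this PR's diff):

```python
# extra forward pass on the conditional inputs with the SAI-recommended
# layers skipped, as used by skip-layer guidance for SD 3.5 Medium
noise_pred_skip = transformer(
    hidden_states=latents,
    timestep=timestep,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    skip_layers=[7, 8, 9],
)[0]
```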
Before submitting
See the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@sayakpaul @yiyixuxu