- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.4k
[core] CogVideoX memory optimizations in VAE encode #9340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
(cherry picked from commit bf890bc)
| The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. | 
| # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different | ||
| # number of temporal frames. | ||
| self.num_latent_frames_batch_size = 2 | ||
| self.num_sample_frames_batch_size = 8 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be configurable? Why change to a different hardcoded value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be configurable to multiples of temporal_compression_ratio which is 4. We can configure this using vae.num_sample_frames_batch_size = 16 for the moment, and I think it's best to do it that way. 8, here, corresponds to 1 second of video because Cog was trained on 8 FPS videos. I would leave the configuration option out of enable_tiling and keep it for power users who have looked through the code and understand its usage tbh
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Just a comment about whether num_sample_frames_batch_size can be made configurable?
fake context parallel cache, vae encode tiling (cherry picked from commit bf890bc)
| This optimization doesn't seem to be the right thing to do I think maybe? | 
| gather_norm is just the SpatialNorm3D layer in the VAE. It contains one GroupNorm and two Conv3D layers. Since GroupNorm only operates on the channel dimension in groups of fixed size, you can chunk inference across all the other dimensions (in this case, we are chunking about the frame dimension). The Conv3D is a bit more tricky to understand because of the intermediate conv_cache padding, but it is also independent of the frame dimension as long as we correctly handle padding (otherwise conv across temporal dimension may lead to lower or higher frames than expected depending on whether you are doing encoding or decoding). If there was any operation in which the channel dimension was involved but was not independent of frame dimension (for example, if there were any attention layers), then framewise encoding/decoding would not have been possible. This is the case with Mochi VAE, for example. The encoder contains attention layer, so you can't chunk across frame dimension and perform framewise encoding, but since decoder is comprised only of convolutions (with appropriate padding) and norms, computation can be done independent of the frame dimension. I hope that makes sense haha! | 
| Yes I know what you mean about conv3d because I'm just adapting the diffusers implementation with real context parallel to support my distillation training. | 
| If the gather_norm is enabled, it means that the mean and std need to take consideration of the whole TimeFrame sequence, this will disable us to conduct framewise encoding/decoding. If the gather_norm is disabled, we need to carefully align with the training process to have a specific X frames per forward. Is that right? I have been confused about this for a long time... | 
| Maybe I can give a simpler example with  B, S = 4, 100
model = nn.Linear(256, 1024)
input = torch.randn((B, S, 256))
# Normal forward
output1 = model(input)
# Chunking across batch dimension with total 256 // 4 = 64 forward passes
output2 = torch.cat([model(x) for x in input.split(4, dim=0)], dim=0)
# Chunking across sequence length dimension with total 100 // 20 = 5 forward passes
output3 = torch.cat([model(x) for x in input.split(20, dim=1)], dim=1)
# Collapsing across single dimension with total (256 * 100) // 50 = 512 forward passes
input = input.flatten(0, 1)
output4 = torch.cat([model(x) for x in input.split(50, dim=0)], dim=0)
output4 = output4.unflatten(0, (B, -1))The outputs in all these cases are the exact same. Due to numerical precision issues and rounding and order of operations, there is maybe a small difference at the order of 1e-7 to 1e-15 or lower, which is negligible. This kind of chunking across dimension is possible for any kind of layer when done correctly. As group norm only affects small number of groups on the channel dimension, computation is fully independent of all the other dimensions (this is why you can perform spatial tiling as well as framewise decoding). It is really equivalent and there is no mistake in doing this. For your context parallel implementation, the simplest thing you could do for parallelizing is collapse your tensor to a  On another note, we are trying to look into distillation of CogVideoX too internally, so if you're interested in collaborating on the training and further future work, we would definitely love to join you! If this is of interest, let me know and I will setup a Slack channel where we can communicate together | 
| Of course I would be happy to do so!!! I think I have been in a channel with @sayakpaul about our Hyper-SD. | 
| It means that it is not like linear layers to be independent of the other dimensions I think! | 
fake context parallel cache, vae encode tiling (cherry picked from commit bf890bc)

(cherry picked from commit bf890bc)
What does this PR do?
Adds VAE encode memory optimizations from #9333 separately for a cleaner history of additions and because this is needed in #9302.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@DN6 @yiyixuxu