[core] feat: support group offloading at the pipeline level #12283
Conversation
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
exclude_modules: Optional[Union[str, List[str]]] = None,
I think it's okay to expose this as an argument, as opposed to how we do `model_cpu_offload_seq`, for example:

model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"

This is because model CPU offloading relies on a sequence for device management. I don't think we have that constraint in the case of group offloading.
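The distinction can be sketched in plain Python. This is a toy simulation, not the diffusers implementation; the function names, `Component` class, and device strings are all made up for illustration:

```python
# Toy contrast between sequential model CPU offloading (needs a declared
# order) and group offloading (no global ordering constraint).
# Hypothetical names; not the actual diffusers hooks.

class Component:
    def __init__(self, name):
        self.name = name
        self.device = "cpu"

def run_with_sequential_offload(components, order):
    """Model CPU offloading: components run in the order declared by a
    string like "text_encoder->transformer->vae"; only one component
    sits on the accelerator at a time."""
    trace = []
    for name in order.split("->"):
        comp = components[name]
        comp.device = "cuda"          # onload just before use
        trace.append((comp.name, comp.device))
        comp.device = "cpu"           # offload before the next component
    return trace

def run_with_group_offload(components):
    """Group offloading: each component's groups are moved on/off the
    device by per-module hooks, so no pipeline-wide sequence needs to
    be declared up front."""
    trace = []
    for comp in components.values():  # iteration order does not matter
        comp.device = "cuda"
        trace.append((comp.name, comp.device))
        comp.device = "cpu"
    return trace

components = {n: Component(n) for n in ["text_encoder", "transformer", "vae"]}
seq_trace = run_with_sequential_offload(components, "text_encoder->transformer->vae")
grp_trace = run_with_group_offload(components)
```

The point of the sketch: the sequential variant breaks if the order string is wrong, while the group variant has no such input to get wrong, which is why exposing per-call arguments (rather than a sequence string) seems fine here.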
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Really neat how simplified this has become! 🎉
Nice! We could reduce LoC by creating a kwargs dict and passing it into both branches, but definitely not a blocker.
Tests look good, but maybe the two could be combined into one to reduce the total run count.
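The kwargs-dict suggestion could look roughly like this. It is a sketch only: `enable_offload` is a stub standing in for the real entrypoint, and the branch condition and path are made up; the shared parameter names mirror the signature shown in the diff above:

```python
def enable_offload(module, **kwargs):
    # Stub standing in for the real offloading entrypoint (hypothetical).
    return kwargs

def apply_offloading(module, use_disk_offload, disk_path=None):
    # Collect the arguments shared by both branches once, instead of
    # repeating them in each call site.
    common_kwargs = {
        "record_stream": False,
        "low_cpu_mem_usage": False,
        "exclude_modules": None,
    }
    if use_disk_offload:
        return enable_offload(module, offload_to_disk_path=disk_path, **common_kwargs)
    return enable_offload(module, **common_kwargs)

with_disk = apply_offloading("m", True, disk_path="/tmp/offload")
without_disk = apply_offloading("m", False)
```

Both branches then stay trivially in sync when a shared argument is added or renamed.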
init Co-authored-by: Sayak Paul <[email protected]>
What does this PR do?
to
Of course, if users still want to apply different offloading techniques to different model-level components, they can easily choose to do so. But IMO, `enable_group_offload()` is an easier entrypoint. We can allow users to pass mappings like we do for `quant_mapping` in `PipelineQuantizationConfig` in the future. Will request reviews after CI.
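A per-component mapping in the spirit of `quant_mapping` might look like the sketch below. Everything here is hypothetical: the `offload_mapping` parameter, the helper function, and the config keys are illustrative only and are not part of this PR:

```python
# Sketch: dispatch one pipeline-level call across named components,
# with an optional per-component override mapping (hypothetical API).

DEFAULT_CONFIG = {"offload_type": "block_level", "num_blocks_per_group": 1}

def enable_group_offload_for_pipeline(component_names, offload_mapping=None):
    """Return the offload config chosen for each component; components
    absent from the mapping fall back to a single default. Real code
    would install offloading hooks here instead of returning dicts."""
    overrides = offload_mapping or {}
    return {
        name: overrides.get(name, DEFAULT_CONFIG)
        for name in component_names
    }

applied = enable_group_offload_for_pipeline(
    ["text_encoder", "transformer", "vae"],
    offload_mapping={"transformer": {"offload_type": "leaf_level"}},
)
```

This keeps the common case a single call with no mapping, while still allowing per-component control for users who want it.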
TODOs