checkpointing/HTTPTransport: added streaming serialization and parallel transfer support #106
Conversation
Force-pushed from acb04b0 to 72432f1.
Looks good, thanks for the change!
python torchft/checkpointing/http_transport.py --num-chunks 0 --device cuda
INFO:main:fetching checkpoint took 12.673462141305208s
How come --num-chunks 0 with CUDA is slower than CPU?
    return output_list
...
def bench_main() -> None:
nit: maybe we could move the benchmarking code to its own folder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved to a _bench.py file and added a small test to make sure the benchmark doesn't regress.
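A smoke test for that could look roughly like the following sketch; the module path, entry-point name, and argv handling here are assumptions for illustration, not the PR's actual test:

```python
# Sketch of a benchmark smoke test; the module location and the
# _bench.main(argv) entry point are assumptions, not the PR's code.
from torchft.checkpointing import _bench  # assumed module path


def test_bench_smoke() -> None:
    # A tiny CPU run catches crashes and regressions in the benchmark
    # harness itself without paying for a real multi-GB transfer in CI.
    _bench.main(["--device", "cpu", "--num-chunks", "0"])
```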
    return tree_unflatten(values, spec)
...
def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
nit: could you use tree_map here?
tree_map does the same flatten+unflatten internally, so I'll keep it as-is to avoid a duplicate mapping pass.
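For context, a rough sketch of the equivalence (illustrative only; it uses CPU tensors so it runs anywhere, whereas the real code operates on CUDA tensors):

```python
# Illustrative sketch, not the PR's code. Both paths produce the same
# structure; tree_map just wraps flatten + unflatten internally, so
# layering it over explicit flatten/unflatten would traverse twice.
import torch
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

state = {"w": torch.randn(4), "b": torch.randn(4)}  # CUDA tensors in practice

# Explicit version: flatten once, convert the leaves, unflatten once.
values, spec = tree_flatten(state)
cpu_values = [v.cpu() if isinstance(v, torch.Tensor) else v for v in values]
explicit = tree_unflatten(cpu_values, spec)

# tree_map version: same result, but flattens and unflattens internally.
mapped = tree_map(lambda v: v.cpu() if isinstance(v, torch.Tensor) else v, state)
```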
checkpointing/HTTPTransport: added streaming serialization and parallel transfer support
Force-pushed from 72432f1 to a877119.
Fixed slowness with CUDA due to duplicate transfers to CPU.
This makes HTTP recovery much faster, with two key changes: streaming serialization and parallel (chunked) transfers.
This doesn't change the default behavior of Manager, since using parallel chunks with torch.save/torch.load can actually increase transfer time significantly: PyTorchStreamReader holds the GIL during deserialization, so chunked loads don't actually run in parallel.
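For illustration, the pattern that fails to parallelize looks roughly like this (sketch only, not the transport code):

```python
# Sketch of why chunked torch.load doesn't parallelize well: the thread
# pool overlaps the network reads, but PyTorchStreamReader holds the GIL
# while deserializing, so the torch.load calls largely run one at a time.
import io
from concurrent.futures import ThreadPoolExecutor
from typing import List

import torch


def load_chunks(chunks: List[bytes]) -> List[object]:
    with ThreadPoolExecutor() as pool:
        return list(pool.map(lambda b: torch.load(io.BytesIO(b)), chunks))
```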
The optimal config enables both streaming and chunking.
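Assuming the transport exposes a num_chunks knob as this PR describes (the exact constructor signature below is an assumption, not a confirmed API), enabling it might look like:

```python
# Sketch only: the constructor signature is inferred from the PR
# description and the --num-chunks benchmark flag, not a confirmed API.
from datetime import timedelta

from torchft.checkpointing.http_transport import HTTPTransport

# num_chunks > 0 splits the checkpoint into chunks fetched in parallel;
# streaming serialization avoids buffering the whole payload in memory.
transport = HTTPTransport(timeout=timedelta(seconds=60), num_chunks=4)
```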
Test plan:
Tested with 12 GB total checkpoint size and 1 MB tensors, on:
PyTorch nightly
PyTorch 2.6.0