
Conversation

@d4l3k (Member) commented Feb 11, 2025

This makes HTTP recovery much much faster with a few key changes:

  1. Use an RWLock so multiple readers can fetch the state_dict at the same time.
  2. Use _streaming_save/_streaming_load when available, which is ~2x faster on average and avoids the 2x memory overhead (requires a PyTorch nightly).
  3. Support parallel transfers by using pytree to divvy the leaf values up into chunks that can be sent via parallel HTTP requests (see the sketch after this description).

This doesn't change the default behavior of Manager, since using parallel chunks with torch.save/load can actually increase time significantly: PyTorchStreamReader holds the GIL during deserialization.

The optimal config is with both streaming and chunking enabled.
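For illustration, here is a minimal sketch of the pytree-based chunking idea under a simple round-robin split; the helper names (split_leaves, join_leaves) are hypothetical, and the actual implementation lives in torchft/checkpointing/http_transport.py. Each chunk can then be serialized independently and fetched over its own HTTP request.

```python
# Hypothetical sketch (not the torchft implementation): split a state_dict's
# pytree leaves into num_chunks round-robin groups for parallel transfer,
# then reassemble them into the original structure on the receiving side.
from typing import Any, Dict, List, Tuple

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def split_leaves(state_dict: Dict[str, Any], num_chunks: int) -> Tuple[List[List[Any]], Any]:
    """Flatten the state_dict and deal its leaves round-robin into num_chunks groups."""
    leaves, spec = tree_flatten(state_dict)
    chunks: List[List[Any]] = [[] for _ in range(num_chunks)]
    for i, leaf in enumerate(leaves):
        chunks[i % num_chunks].append(leaf)
    return chunks, spec


def join_leaves(chunks: List[List[Any]], spec: Any) -> Dict[str, Any]:
    """Interleave the chunks back into leaf order and rebuild the original structure."""
    num_chunks = len(chunks)
    leaves: List[Any] = [None] * sum(len(c) for c in chunks)
    for c, chunk in enumerate(chunks):
        for j, leaf in enumerate(chunk):
            leaves[j * num_chunks + c] = leaf
    return tree_unflatten(leaves, spec)


if __name__ == "__main__":
    sd = {"layer.weight": torch.randn(4, 4), "step": 3}
    chunks, spec = split_leaves(sd, num_chunks=2)
    restored = join_leaves(chunks, spec)
    assert torch.equal(restored["layer.weight"], sd["layer.weight"])
```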

Test plan:

pytest

Testing with 12GB total and 1MB tensors

pytorch nightly

$ python torchft/checkpointing/http_transport.py --num-chunks 0
INFO:__main__:fetching checkpoint took 6.626614563167095s
$ python torchft/checkpointing/http_transport.py --num-chunks 10
INFO:__main__:fetching checkpoint took 3.0460726767778397s
$ python torchft/checkpointing/http_transport.py --num-chunks 0 --device cuda
INFO:__main__:fetching checkpoint took 6.147395346313715s
$ python torchft/checkpointing/http_transport.py --num-chunks 10 --device cuda
INFO:__main__:fetching checkpoint took 2.9234009198844433s

pytorch 2.6.0

$ python torchft/checkpointing/http_transport.py --num-chunks 0
INFO:__main__:fetching checkpoint took 17.019980508834124s
$ python torchft/checkpointing/http_transport.py --num-chunks 10
INFO:__main__:fetching checkpoint took 40.383272521197796s

@d4l3k requested review from H-Huang and fegin on February 11, 2025 02:19
@facebook-github-bot added the CLA Signed label on Feb 11, 2025
@H-Huang (Contributor) left a comment


Looks good, thanks for the change!

python torchft/checkpointing/http_transport.py --num-chunks 0 --device cuda
INFO:__main__:fetching checkpoint took 12.673462141305208s

How come --num-chunks 0 with cuda is slower than CPU?

return output_list


def bench_main() -> None:
Contributor

nit: maybe we could move the benchmarking code to its own folder

@d4l3k (Member Author)

Moved to a _bench.py file and added a small test to make sure benchmark doesn't regress

return tree_unflatten(values, spec)


def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
Contributor

nit: could you use tree_map here?

@d4l3k (Member Author)

tree_map does the same flatten+unflatten internally, so I think I'll just keep it like this to avoid a duplicate mapping pass.
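A small illustration of the point (hypothetical, not torchft code): tree_map is just tree_flatten + map + tree_unflatten, so when the caller already holds the flat leaf list and spec, mapping over the leaves directly skips a second flatten/unflatten pass.

```python
# Hypothetical example: tree_map flattens and unflattens internally, so the two
# forms below are equivalent; the second avoids re-flattening when leaves/spec
# are already available (as they are around _to_cpu).
import torch
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

state = {"w": torch.ones(2), "b": [torch.zeros(3), 1.0]}

doubled_via_tree_map = tree_map(lambda v: v * 2, state)

leaves, spec = tree_flatten(state)
doubled_via_leaves = tree_unflatten([v * 2 for v in leaves], spec)

assert torch.equal(doubled_via_tree_map["w"], doubled_via_leaves["w"])
```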

@d4l3k
Copy link
Member Author

d4l3k commented Feb 11, 2025

Fixed slowness w/ CUDA due to duplicate transfers to CPU
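A hypothetical sketch of the class of fix being described (the actual change is in the merged commit): move each CUDA tensor to the CPU at most once and reuse that staged copy, instead of issuing another device-to-host transfer at serialization time.

```python
# Hypothetical illustration (not the actual patch): each CUDA tensor is copied
# to the CPU exactly once; tensors already on the CPU pass through untouched.
from typing import Any, List

import torch


def stage_on_cpu_once(leaves: List[Any]) -> List[Any]:
    staged = []
    for v in leaves:
        if isinstance(v, torch.Tensor) and v.device.type != "cpu":
            v = v.to("cpu", non_blocking=True)  # the single transfer for this tensor
        staged.append(v)
    return staged
```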

@d4l3k merged commit f44aaa5 into main on Feb 11, 2025
6 checks passed
@d4l3k deleted the d4l3k/fast_http branch on February 11, 2025 21:23