diff --git a/train_diloco.py b/train_diloco.py index 0c6b9cf6..e207e73e 100644 --- a/train_diloco.py +++ b/train_diloco.py @@ -34,7 +34,7 @@ ProcessGroupGloo, ProcessGroupNCCL, ) -from torchft.checkpointing.pg_transport import PGTransport +from torchft.checkpointing.http_transport import HTTPTransport from torchft.local_sgd import DiLoCo logging.basicConfig(level=logging.INFO) @@ -67,13 +67,12 @@ def state_dict(): timeout=timedelta(seconds=10), ) if torch.cuda.is_available() and USE_NCCL - else ProcessGroupGloo(timeout=timedelta(seconds=5)) + else ProcessGroupGloo(timeout=timedelta(seconds=10)) ) - transport = PGTransport( - pg, + transport = HTTPTransport( timeout=timedelta(seconds=10), - device=device, + num_chunks=0, ) manager = Manager(