Skip to content

Commit be3e833

Browse files
authored
allow using gloo from flag (#239)
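Summary: add a USE_NCCL environment variable to the sample training script; the script now uses the Gloo process group by default and uses NCCL only when USE_NCCL=True is set and CUDA is available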
1 parent 7db8d26 commit be3e833

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

train_diloco.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID)
1515

1616
USE_STREAMING = os.getenv("USE_STREAMING", "False") == "True"
17+
USE_NCCL = os.getenv("USE_NCCL", "False") == "True"
1718

1819
import torch
1920
import torch.nn.functional as F
@@ -60,19 +61,19 @@ def state_dict():
6061
"outer_optim": outer_optimizer.state_dict(),
6162
}
6263

63-
device = "cuda" if torch.cuda.is_available() else "cpu"
64+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6465
pg = (
6566
ProcessGroupNCCL(
6667
timeout=timedelta(seconds=10),
6768
)
68-
if torch.cuda.is_available()
69+
if torch.cuda.is_available() and USE_NCCL
6970
else ProcessGroupGloo(timeout=timedelta(seconds=5))
7071
)
7172

7273
transport = PGTransport(
7374
pg,
7475
timeout=timedelta(seconds=10),
75-
device=("cuda" if torch.cuda.is_available() else "cpu"),
76+
device=device,
7677
)
7778

7879
manager = Manager(

0 commit comments

Comments
 (0)