
Commit 67645cd

Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current issues**

This doesn't work at the moment: groups hang when a new group joins.

**Issue 1:** ~~Group 0 and group 1 will hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~~ Fixed with meta-pytorch/torchft#83.

**Issue 2:** ~~Group 0 and group 1 will pass `should_commit`, but group 0 incorrectly requires healing, and the healing process causes another hang.~~ Fixed with meta-pytorch/torchft#83.

**Issue 3:** ~~A byproduct of issues 1 and 2: group 1 will keep printing~~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with meta-pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict every step.
***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command. Seems to be fixed; will need more tests.

**Issue 5:** A hang occurs when using functional collectives.
***How to reproduce?*** Pull the latest version of this PR, then comment out line 41 and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with meta-pytorch/torchft#82.
2. Start the lighthouse.
3. Execute the following command in one terminal:

   ```
   TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
   ```

4. Wait 10 seconds, then execute the following command in another terminal:

   ```
   TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
   ```

ghstack-source-id: 9fba357
Pull Request resolved: #834
1 parent 3d9e881 commit 67645cd
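
For orientation, here is a minimal sketch of how the pieces touched by this PR are meant to fit together at the training-script level. The `ft.Manager(...)` construction is not part of the diffs shown below, so its arguments (and `ProcessGroupBabyNCCL`) are assumptions taken from the TorchFT README; the `ParallelDims`, `build_optimizers`, and `CheckpointManager` call shapes follow the signatures changed in this commit.

```python
# Orientation sketch only -- not code from this PR. The ft.Manager arguments
# below are assumptions (TorchFT README); everything else mirrors the
# signatures changed in the diffs that follow.
import torchft as ft

from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.optimizer import build_optimizers
from torchtitan.distributed.parallel_dims import ParallelDims


def build_ft_components(job_config, dataloader, model_parts, states, world_size):
    ft_manager = None
    if job_config.experimental.enable_torchft:
        # Assumed constructor call. TORCHFT_LIGHTHOUSE / TORCHFT_MANAGER_PORT
        # come from the environment set up in run_train.sh; the state_dict /
        # load_state_dict callbacks are registered later by CheckpointManager
        # through ft_manager.set_state_dict_fns().
        ft_manager = ft.Manager(
            pg=ft.ProcessGroupBabyNCCL(),
            min_replica_size=1,
            load_state_dict=None,
            state_dict=None,
        )

    # With an ft_manager, "dp_replicate" stays in the device mesh even at
    # degree 1 (see the parallel_dims.py diff).
    parallel_dims = ParallelDims(
        dp_replicate=1,
        dp_shard=job_config.training.data_parallel_shard_degree,
        cp=1,
        tp=1,
        pp=1,
        world_size=world_size,
        enable_loss_parallel=False,  # illustrative value
        ft_manager=ft_manager,
    )

    # With an ft_manager, each optimizer is wrapped in ft.Optimizer
    # (see the optimizer.py diff).
    optimizers = build_optimizers(model_parts, job_config, ft_manager)
    lr_schedulers = ...  # unchanged by this PR

    # The checkpointer registers the state_dict fns with TorchFT and keeps a
    # per-replica-group dataloader checkpoint (see the checkpoint.py diff).
    # The parameter order before lr_schedulers is assumed from the existing class.
    checkpointer = CheckpointManager(
        dataloader, model_parts, optimizers, lr_schedulers, states,
        job_config, ft_manager,
    )
    return parallel_dims, optimizers, checkpointer
```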

File tree

8 files changed: +277 -42 lines


run_train.sh

Lines changed: 4 additions & 0 deletions
```diff
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=http://localhost:29510 \
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
```

torchtitan/components/checkpoint.py

Lines changed: 101 additions & 27 deletions
```diff
@@ -225,6 +225,7 @@ class CheckpointManager:
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
     """
 
     def __init__(
@@ -235,16 +236,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional["ft.Manager"] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id
 
         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
 
         self.states = states
@@ -256,6 +282,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -266,7 +299,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
            self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -340,35 +373,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """
 
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return
 
         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()
 
-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )
 
     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -385,6 +427,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False
 
@@ -468,10 +513,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)
 
+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")
 
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -492,6 +563,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
         return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:
@@ -576,6 +649,7 @@ def _cpu_staging(self, checkpoint_id: Optional[str]) -> None:
     def _purge_stale_checkpoints(self):
         if (
             self.keep_latest_k > 0
+            and self.ft_manager.participating_rank() == 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
         ):
```
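
As a quick illustration of the layout produced by the new `_ft_folder()` / `_ft_save()` helpers, the per-replica-group dataloader state lands under its own subfolder next to the regular `step-*` checkpoints. The `outputs/checkpoint` path below is only an example setting.

```python
# Small standalone sketch of the FT checkpoint path produced by the new
# helpers above; "outputs/checkpoint" is an example location, and the folder
# name keeps the literal "ft-replicat" spelling used in the diff.
import os


def ft_checkpoint_id(checkpoint_folder: str, replica_id: int, step: int) -> str:
    # Mirrors _ft_folder() followed by _create_checkpoint_id().
    return os.path.join(checkpoint_folder, f"ft-replicat-{replica_id}", f"step-{step}")


print(ft_checkpoint_id("outputs/checkpoint", replica_id=1, step=100))
# -> outputs/checkpoint/ft-replicat-1/step-100
```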

torchtitan/components/optimizer.py

Lines changed: 51 additions & 7 deletions
```diff
@@ -6,7 +6,7 @@
 
 import copy
 import functools
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Dict, Iterable, List, Optional
 
 import torch
 import torch.nn as nn
@@ -177,8 +177,49 @@ def zero_grad(self) -> None:
         pass
 
 
+class FTOptimizersContainer(OptimizersContainer):
+    def __init__(
+        self,
+        model_parts: List[nn.Module],
+        optimizer_kwargs: Dict[str, Any],
+        name: str,
+        ft_manager: Optional["ft.Manager"],
+    ) -> None:
+        import torchft as ft
+
+        super().__init__(model_parts, optimizer_kwargs, name)
+
+        # Force to initialize the optimizer state so that `optim.step()`
+        # won't be called by state_dict() and load_state_dict().
+        _ = {
+            k: v
+            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
+            for k, v in sd.items()
+        }
+        self.optimizers = [
+            ft.Optimizer(ft_manager, optim) for optim in self.optimizers
+        ]
+        self.cache_state_dict: Dict[str, Any] = {}
+
+    def init_cache_state_dict(self) -> None:
+        self.cache_state_dict = super().state_dict()
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.cache_state_dict
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # We have to invalidate the `cache_state_dict` because optimizer uses
+        # assign instead of copy when doing `load_state_dict()`. Without
+        # invalidating the `cache_state_dict`, there will be memory leakage.
+        self.cache_state_dict = {}
+        super().load_state_dict(state_dict)
+        self.init_cache_state_dict()
+
+
 def build_optimizers(
-    model_parts: List[nn.Module], job_config: JobConfig
+    model_parts: List[nn.Module],
+    job_config: JobConfig,
+    ft_manager: Optional["ft.Manager"] = None,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.
 
@@ -219,11 +260,14 @@ def build_optimizers(
         "foreach": foreach,
    }
 
-    return (
-        OptimizersContainer(model_parts, optimizer_kwargs, name)
-        if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
-    )
+    if optim_in_bwd and ft_manager:
+        raise ValueError("TorchFT is not supported with optimizers in backward.")
+    elif optim_in_bwd:
+        return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
+    elif ft_manager:
+        return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager)
+    else:
+        return OptimizersContainer(model_parts, optimizer_kwargs, name)
 
 
 class LRSchedulersContainer(Stateful):
```
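
As a usage note, here is a hedged sketch of how the new `build_optimizers()` dispatch is expected to be called from the trainer; `model_parts`, `job_config`, `ft_manager`, and `loss` are assumed to exist in the surrounding training loop and are not defined here.

```python
# Hedged sketch of calling the new build_optimizers() dispatch above.
# model_parts, job_config, ft_manager, and loss are assumed to come from the
# surrounding training loop; only the call shape is taken from this diff.
from torchtitan.components.optimizer import build_optimizers

optimizers = build_optimizers(model_parts, job_config, ft_manager)

# With ft_manager set, each wrapped optimizer is an ft.Optimizer, so stepping
# the container is expected to drive TorchFT's quorum/commit handling for the step.
optimizers.zero_grad()
loss.backward()
optimizers.step()
```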

torchtitan/config_manager.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -661,6 +661,30 @@ def __init__(self):
             action="store_true",
         )
 
+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_id",
+            type=int,
+            default=0,
+            help="The TorchFT replica ID of this run.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_group_size",
+            type=int,
+            default=1,
+            help="""
            The number of TorchFT replicate groups. This number will be used for
            dataloader to split the dataset across the replicate groups and FSDP
            dimension.
            """,
+        )
+
     def to_dict(self):
         return self.args_dict
 
```
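
To illustrate the new options, here is a small hedged example of parsing them through `JobConfig`; `parse_args()` and the attribute-style access are assumed from the existing config_manager, while the flag names and defaults come from this diff.

```python
# Hedged sketch: exercising the three new experimental flags. JobConfig and
# parse_args() are assumed from the existing torchtitan config_manager; the
# flag names and defaults are taken from this diff.
from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--experimental.enable_torchft",
        "--experimental.ft_replica_id",
        "1",
        "--experimental.ft_group_size",
        "2",
    ]
)

assert config.experimental.enable_torchft is True
assert config.experimental.ft_replica_id == 1
assert config.experimental.ft_group_size == 2
```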

torchtitan/distributed/parallel_dims.py

Lines changed: 17 additions & 4 deletions
```diff
@@ -6,6 +6,7 @@
 
 from dataclasses import dataclass
 from functools import cached_property
+from typing import Any, Optional
 
 from torch.distributed.device_mesh import init_device_mesh
 
@@ -24,6 +25,7 @@ class ParallelDims:
     pp: int
     world_size: int
     enable_loss_parallel: bool
+    ft_manager: Optional["ft.Manager"]
 
     def __post_init__(self):
         self._validate()
@@ -56,13 +58,24 @@ def build_mesh(self, device_type):
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            if d > 1 or (name == "dp_replicate" and self.ft_manager is not None):
                 dims.append(d)
                 names.append(name)
 
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
-        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        if self.ft_manager is None:
+            mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        else:
+            from torchft.process_group import ft_init_device_mesh
+
+            mesh = ft_init_device_mesh(
+                device_type=device_type,
+                mesh_shape=dims,
+                mesh_dim_names=names,
+                replicate_dim=names.index("dp_replicate"),
+                manager=self.ft_manager,
+            )
 
         # Create all the submesh here to ensure all required process groups are
         # initialized:
@@ -73,7 +86,7 @@ def build_mesh(self, device_type):
         # Mesh for loss all-reduce
         dp_cp_mesh_dim_names = []
 
-        if self.dp_replicate_enabled:
+        if self.dp_replicate_enabled or self.ft_manager is not None:
             dp_mesh_dim_names.append("dp_replicate")
             dp_cp_mesh_dim_names.append("dp_replicate")
         if self.dp_shard_enabled:
@@ -101,7 +114,7 @@ def dp_enabled(self):
 
     @property
     def dp_replicate_enabled(self):
-        return self.dp_replicate > 1
+        return self.dp_replicate > 1 or self.ft_manager is not None
 
     @property
    def dp_shard_enabled(self):
```
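To make the mesh change concrete, here is a self-contained sketch of the dimension selection above: with an FT manager present, `dp_replicate` is kept even at degree 1 so TorchFT can own that dimension through `ft_init_device_mesh()`. The degrees below are illustrative (2 GPUs per replica group, matching the repro steps).

```python
# Standalone sketch of the dimension-selection change above: with an FT
# manager present, "dp_replicate" stays in the mesh even at degree 1.
# The degrees are illustrative, not taken from a real run.
from typing import List, Optional, Tuple


def mesh_dims(
    dp_replicate: int, dp_shard: int, ft_manager: Optional[object]
) -> Tuple[List[int], List[str]]:
    dims, names = [], []
    for d, name in zip(
        [1, dp_replicate, dp_shard, 1, 1],
        ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
    ):
        # Same condition as the parallel_dims.py diff.
        if d > 1 or (name == "dp_replicate" and ft_manager is not None):
            dims.append(d)
            names.append(name)
    return dims, names


print(mesh_dims(1, 2, ft_manager=None))      # ([2], ['dp_shard'])
print(mesh_dims(1, 2, ft_manager=object()))  # ([1, 2], ['dp_replicate', 'dp_shard'])
```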