
Commit 7c70988

Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current issues**

This does not work at the moment: groups hang when a new group joins.

**Issue 1:** ~~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~~ Fixed with meta-pytorch/torchft#83.

**Issue 2:** ~~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly marked as needing healing, and the healing process causes another hang.~~ Fixed with meta-pytorch/torchft#83.

**Issue 3:** ~~A byproduct of issues 1 and 2: group 1 keeps printing~~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with meta-pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict on every step.

***How to reproduce?*** Use the reproduce steps below to run 2 groups, then add another group by modifying the command. Seems to be fixed; needs more testing.

**Issue 5:** A hang occurs when using functional collectives.

***How to reproduce?*** Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with meta-pytorch/torchft#82.
2. Start the lighthouse.
3. Execute the following command in one terminal:

```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

4. Wait 10 seconds, then execute the following command in another terminal:

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 5f9a731
Pull Request resolved: #834
Parent: 43fe812

10 files changed (+282, -42 lines)

run_train.sh

Lines changed: 4 additions & 0 deletions
```diff
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=http://localhost:29510 \
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
```

tests/unit_tests/test_checkpoint.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -61,10 +61,18 @@ class DummyJob:
     dump_folder: str = "dummy_folder"
 
 
+@dataclass
+class DummyExperimental:
+    ft_replica_id = 0
+    ft_group_size = 1
+
+
 @dataclass
 class DummyJobConfig:
     checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
     job: DummyJob = field(default_factory=DummyJob)
+    experimental: DummyExperimental = field(default_factory=DummyExperimental)
+    ft_manager = None
 
 
 # Dummy instances to supply as constructor arguments.
```

tests/unit_tests/test_model_converter.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -22,6 +22,7 @@ def build_parallel_dims(job_config, world_size):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
+        ft_manager=None,
     )
     return parallel_dims
 
```

torchtitan/components/checkpoint.py

Lines changed: 101 additions & 27 deletions
```diff
@@ -221,6 +221,7 @@ class CheckpointManager:
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
     """
 
     def __init__(
@@ -231,16 +232,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional["ft.Manager"] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id
 
         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
 
         self.states = states
@@ -252,6 +278,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -262,7 +295,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -336,35 +369,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """
 
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return
 
         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()
 
-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )
 
     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -381,6 +423,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False
 
@@ -464,10 +509,36 @@ def _find_load_step(self, folder: str = "") -> int:
            return -1
        return max(step_counts)
 
+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")
 
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -488,6 +559,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
         return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:
@@ -574,6 +647,7 @@ def _purge_stale_checkpoints(self):
             self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
+            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
         ):
             discovered_checkpoints = []
             for filename in os.listdir(self.folder):
```
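For orientation, here is a minimal wiring sketch of how the new `ft_manager` argument might be threaded into the checkpointer from a training script. Only the two `ft_manager=` keyword arguments and the import paths come from this commit's diffs; the helper itself, the other parameter names, and the `"train_state"` key are illustrative stand-ins for objects the trainer would already have (the real wiring lives in `torchtitan/train.py`, which is part of this PR but not shown in this excerpt).

```python
# Hypothetical wiring sketch, not the PR's train.py. Only the `ft_manager=`
# keyword arguments are taken from the diffs above; other names are stand-ins.
from typing import Optional

from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.optimizer import build_optimizers


def build_checkpointer(job_config, dataloader, model_parts, lr_schedulers,
                       train_state, ft_manager: Optional["ft.Manager"] = None):
    # A single optional ft_manager is threaded through both new entry points;
    # passing None keeps the pre-existing (non-TorchFT) behavior.
    optimizers = build_optimizers(model_parts, job_config, ft_manager=ft_manager)
    return CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,  # an FTOptimizersContainer when ft_manager is set
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )
```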

torchtitan/components/optimizer.py

Lines changed: 49 additions & 7 deletions
```diff
@@ -6,7 +6,7 @@
 
 import copy
 import functools
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Dict, Iterable, List, Optional
 
 import torch
 import torch.nn as nn
@@ -177,8 +177,47 @@ def zero_grad(self) -> None:
         pass
 
 
+class FTOptimizersContainer(OptimizersContainer):
+    def __init__(
+        self,
+        model_parts: List[nn.Module],
+        optimizer_kwargs: Dict[str, Any],
+        name: str,
+        ft_manager: Optional["ft.Manager"],
+    ) -> None:
+        import torchft as ft
+
+        super().__init__(model_parts, optimizer_kwargs, name)
+
+        # Force to initialize the optimizer state so that `optim.step()`
+        # won't be called by state_dict() and load_state_dict().
+        _ = {
+            k: v
+            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
+            for k, v in sd.items()
+        }
+        self.optimizers = [ft.Optimizer(ft_manager, optim) for optim in self.optimizers]
+        self.cache_state_dict: Dict[str, Any] = {}
+
+    def init_cache_state_dict(self) -> None:
+        self.cache_state_dict = super().state_dict()
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.cache_state_dict
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # We have to invalidate the `cache_state_dict` because optimizer uses
+        # assign instead of copy when doing `load_state_dict()`. Without
+        # invalidating the `cache_state_dict`, there will be memory leakage.
+        self.cache_state_dict = {}
+        super().load_state_dict(state_dict)
+        self.init_cache_state_dict()
+
+
 def build_optimizers(
-    model_parts: List[nn.Module], job_config: JobConfig
+    model_parts: List[nn.Module],
+    job_config: JobConfig,
+    ft_manager: Optional["ft.Manager"] = None,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.
 
@@ -219,11 +258,14 @@ def build_optimizers(
         "foreach": foreach,
     }
 
-    return (
-        OptimizersContainer(model_parts, optimizer_kwargs, name)
-        if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
-    )
+    if optim_in_bwd and ft_manager:
+        raise ValueError("TorchFT is not supported with optimizers in backward.")
+    elif optim_in_bwd:
+        return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
+    elif ft_manager:
+        return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager)
+    else:
+        return OptimizersContainer(model_parts, optimizer_kwargs, name)
 
 
 class LRSchedulersContainer(Stateful):
```
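A short usage sketch of the dispatch added to `build_optimizers` above; `model_parts`, `job_config`, and `ft_manager` are assumed to already exist in the calling code, and the comments simply restate the branch behavior visible in the diff.

```python
# Hedged usage sketch of the new build_optimizers() dispatch; the caller-side
# objects (model_parts, job_config, ft_manager) are assumptions.

# TorchFT disabled: behavior is unchanged.
optimizers = build_optimizers(model_parts, job_config)  # -> OptimizersContainer

# TorchFT enabled: every inner optimizer is wrapped in ft.Optimizer, and
# state_dict() is served from cache_state_dict (see FTOptimizersContainer above).
ft_optimizers = build_optimizers(
    model_parts, job_config, ft_manager=ft_manager
)  # -> FTOptimizersContainer

# Combining optim_in_bwd with a ft_manager raises ValueError, per the diff.
```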

torchtitan/config_manager.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -661,6 +661,30 @@ def __init__(self):
             action="store_true",
         )
 
+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_id",
+            type=int,
+            default=0,
+            help="The TorchFT replica ID of this run.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_group_size",
+            type=int,
+            default=1,
+            help="""
+                The number of TorchFT replicate groups. This number will be used for
+                dataloader to split the dataset across the replicate groups and FSDP
+                dimension.
+            """,
+        )
+
     def to_dict(self):
         return self.args_dict
 
```