diff --git a/torchft/checkpointing/pg_transport.py b/torchft/checkpointing/pg_transport.py index cd07771f..d263febf 100644 --- a/torchft/checkpointing/pg_transport.py +++ b/torchft/checkpointing/pg_transport.py @@ -4,12 +4,17 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta -from typing import Generator, List, Tuple, TypeVar, Union, cast +from typing import Callable, Generator, Optional, TypeVar, Union, cast import torch from torch.distributed import Work from torch.distributed.tensor import DTensor, _DTensorSpec -from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten +from torch.utils._pytree import ( + KeyPath, + TreeSpec, + tree_flatten_with_path, + tree_unflatten, +) from torchft.checkpointing.transport import CheckpointTransport from torchft.process_group import ProcessGroup @@ -32,7 +37,7 @@ class _TensorMeta: shape: torch.Size dtype: torch.dtype storage_offset: int - stride: Tuple[int, ...] + stride: tuple[int, ...] nbytes: int @@ -61,13 +66,15 @@ class _StateDictMeta: Args: step: the step of the checkpoint to verify consistency treespec: the pytree spec of the state dict + paths: the path of each leaf in the state dict non_tensor_leaves: the metadata for each tensor in the state dict and any non-tensor leaves in the state dict """ step: int treespec: TreeSpec - non_tensor_leaves: List[Union[object, _TensorMeta, _DTensorMeta]] + paths: list[KeyPath] + non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] @contextmanager @@ -78,7 +85,7 @@ def _timeit(name: str) -> Generator[None, None, None]: logger.info(f"{name} took {dur}s") -def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]: +def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]: return ( _cast_tensor(tensor, torch.uint8), _TensorMeta( @@ -95,12 +102,16 @@ def _prepare_state_dict( state_dict: object, step: int, device: torch.device, -) -> Tuple[_StateDictMeta, List[torch.Tensor]]: - leaves, treespec = tree_flatten(state_dict) +) -> tuple[_StateDictMeta, list[torch.Tensor]]: + leaves: list[tuple[KeyPath, object]] + leaves, treespec = tree_flatten_with_path(state_dict) + + paths: list[KeyPath] = [] + non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = [] + tensors: list[torch.Tensor] = [] + for key_path, v in leaves: + paths.append(key_path) - non_tensor_leaves = [] - tensors = [] - for v in leaves: if isinstance(v, DTensor): tensor, tensor_meta = _prepare_tensor(v._local_tensor) @@ -123,6 +134,7 @@ def _prepare_state_dict( _StateDictMeta( step=step, treespec=treespec, + paths=paths, non_tensor_leaves=non_tensor_leaves, ), tensors, @@ -139,6 +151,9 @@ def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: caveat that the cast tensor may be larger than the original tensor due to the differences in striding. """ + assert ( + type(tensor) is torch.Tensor + ), f"can only cast standard tensors not {type(tensor)}" storage = tensor.untyped_storage() ret = torch.tensor(storage, dtype=dtype, device=tensor.device) assert ret.untyped_storage() is storage, "storage should be the same" @@ -150,17 +165,28 @@ class PGTransport(CheckpointTransport[T]): This is a checkpoint transport that uses the process group to transfer checkpoints. This allows for fast recovery of workers by fetching the current weights from an existing worker. 
+ Args: - state_dict: a callable that returns the state dict to be transferred + pg: the process group to use for communication + timeout: the timeout for communication + device: the device to use for tensors + state_dict: if specified this function will be called to do an inplace + receive into the returned state_dict. This is much faster than + having to allocate new tensors and transferring them to the CPU. """ def __init__( - self, pg: ProcessGroup, timeout: timedelta, device: torch.device + self, + pg: ProcessGroup, + timeout: timedelta, + device: torch.device, + state_dict: Optional[Callable[[], object]] = None, ) -> None: - self._work: List[Work] = [] + self._work: list[Work] = [] self._pg = pg self._timeout = timeout self._device = device + self._state_dict = state_dict def metadata(self) -> str: return "" @@ -169,7 +195,7 @@ def disallow_checkpoint(self) -> None: pass def send_checkpoint( - self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta + self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta ) -> None: with _timeit("preparing state_dict"): meta, tensors = _prepare_state_dict(state_dict, step, device=self._device) @@ -186,13 +212,17 @@ def send_checkpoint( with _timeit("send tensors"): for i, t in enumerate(tensors): + original_device = t.device t = t.to(self._device) for dst_rank in dst_ranks: work.append(self._pg.send([t], dst_rank, tag=3 + i)) - # allow 3 concurrent transfers at a time to avoid OOMs - while len(work) > (3 * len(dst_ranks)): - work.pop(0).wait(timeout) + # if we did a copy we should wait for the work to complete so we + # can free the memory to avoid OOMs + if original_device == torch.device("cpu"): + for w in work: + w.wait(timeout) + work = [] for w in work: w.wait(timeout) @@ -200,6 +230,11 @@ def send_checkpoint( def recv_checkpoint( self, src_rank: int, metadata: str, step: int, timeout: timedelta ) -> T: + state_dict = self._state_dict() if self._state_dict else {} + state_dict_leaves, _ = tree_flatten_with_path(state_dict) + + dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves) + len_t = torch.zeros(1, dtype=torch.int64, device=self._device) self._pg.recv([len_t], src_rank, tag=1).wait(timeout) length = cast(int, len_t.item()) @@ -213,18 +248,34 @@ def recv_checkpoint( assert meta.step == step i: int = 0 + works: list[Work] = [] - def recv(v: _TensorMeta) -> torch.Tensor: + def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor: nonlocal i - t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device) - # TODO: parallelize receives - self._pg.recv([t], src_rank, tag=3 + i).wait(timeout) + inplace = dst_tensors.get(path) + if ( + isinstance(inplace, torch.Tensor) + and inplace.device.type == self._device.type + ): + if isinstance(inplace, DTensor): + inplace = inplace._local_tensor + t = _cast_tensor(inplace, torch.uint8) + assert ( + t.nbytes == v.nbytes + ), "inplace tensor storage must be the same size" + else: + t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device) + + work = self._pg.recv([t], src_rank, tag=3 + i) i += 1 - # TODO: allow in place receives to avoid having to copy to cpu to - # avoid OOMs - t = t.cpu() + if inplace is None: + # if not inplace we need to copy it to CPU to avoid OOMing + work.wait(timeout) + t = t.cpu() + else: + works.append(work) return torch.as_strided( t.view(v.dtype), @@ -234,14 +285,17 @@ def recv(v: _TensorMeta) -> torch.Tensor: ) values = [] - for v in meta.non_tensor_leaves: + for path, v in zip(meta.paths, meta.non_tensor_leaves): if 
isinstance(v, _TensorMeta): - values.append(recv(v)) + values.append(recv(path, v)) elif isinstance(v, _DTensorMeta): - tensor = recv(v.local) + tensor = recv(path, v.local) # pyre-fixme[29]: DTensor is not a function values.append(DTensor(tensor, v.spec, requires_grad=False)) else: values.append(v) + for work in works: + work.wait(timeout) + return tree_unflatten(values, meta.treespec) diff --git a/torchft/checkpointing/pg_transport_bench.py b/torchft/checkpointing/pg_transport_bench.py new file mode 100644 index 00000000..1bf385f8 --- /dev/null +++ b/torchft/checkpointing/pg_transport_bench.py @@ -0,0 +1,93 @@ +import logging +import sys +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta + +import torch +import torch.distributed as dist + +from torchft.checkpointing.pg_transport import PGTransport, _timeit +from torchft.process_group import ProcessGroupBabyNCCL + +logger: logging.Logger = logging.getLogger(__name__) + + +def main(argv: list[str]) -> None: + import argparse + + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument("--inplace", action="store_true") + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--chunk-size", type=int, default=3_000_000) # 3MB + parser.add_argument("--total-size", type=int, default=12_000_000_000) # 12GB + args = parser.parse_args(argv) + + CHUNK_SIZE: int = args.chunk_size + TOTAL_SIZE: int = args.total_size + INPLACE: bool = args.inplace + DEVICE: str = args.device + + timeout: timedelta = timedelta(seconds=10) + + store = dist.TCPStore( + "localhost", + 0, + is_master=True, + timeout=timeout, + wait_for_workers=False, + ) + store_addr: str = f"localhost:{store.port}" + + def run(rank: int) -> None: + torch.cuda.set_device(rank) + + device = torch.device(DEVICE) + + with _timeit("init_pg"): + pg = ProcessGroupBabyNCCL(timeout=timeout) + pg.configure(store_addr=store_addr, rank=rank, world_size=2) + + t = torch.zeros(10, device=device, dtype=torch.float32) + pg.allreduce([t], dist.ReduceOp.SUM).wait(timeout=timeout) + + with _timeit("create state_dict"): + state_dict: dict[str, torch.Tensor] = {} + for i in range(0, TOTAL_SIZE, CHUNK_SIZE): + state_dict[f"chunk/{i}"] = torch.zeros( + CHUNK_SIZE // 4, dtype=torch.float32, device=device + ) + + def get_state_dict() -> object: + return state_dict + + transport = PGTransport( + pg=pg, + timeout=timeout, + device=device, + state_dict=get_state_dict if INPLACE else None, + ) + metadata = transport.metadata() + + if rank == 0: + with _timeit("send_checkpoint"): + transport.send_checkpoint( + dst_ranks=[1], + step=1, + state_dict=state_dict, + timeout=timedelta(seconds=60), + ) + elif rank == 1: + with _timeit("recv_checkpoint"): + transport.recv_checkpoint( + src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=60) + ) + + with ThreadPoolExecutor(max_workers=2) as executor: + results = executor.map(run, range(2)) + list(results) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/torchft/checkpointing/pg_transport_test.py b/torchft/checkpointing/pg_transport_test.py index a7b9c123..068acaa9 100644 --- a/torchft/checkpointing/pg_transport_test.py +++ b/torchft/checkpointing/pg_transport_test.py @@ -1,5 +1,4 @@ from datetime import timedelta -from typing import Dict from unittest import TestCase, skipUnless import torch @@ -7,7 +6,10 @@ from torchft.checkpointing.pg_transport import PGTransport from torchft.checkpointing.transport import CheckpointTransport -from 
torchft.checkpointing.transport_test import run_multi_recovery_test +from torchft.checkpointing.transport_test import ( + make_state_dict, + run_multi_recovery_test, +) from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo @@ -18,7 +20,7 @@ def test_pg_transport_gloo(self) -> None: ) device: torch.device = torch.device("cpu") - def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]: + def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]: pg = ProcessGroupGloo() pg.configure( store_addr=f"localhost:{store.port}/prefix", @@ -26,7 +28,7 @@ def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]: world_size=world_size, ) - return PGTransport[Dict[str, object]]( + return PGTransport[dict[str, object]]( pg, timeout=timedelta(seconds=10), device=device ) @@ -39,19 +41,49 @@ def test_pg_transport_baby_nccl(self) -> None: host_name="localhost", port=0, is_master=True, wait_for_workers=False ) device: torch.device = torch.device("cuda") + timeout: timedelta = timedelta(seconds=10) - def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]: + def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]: torch.cuda.set_device(rank) - pg = ProcessGroupBabyNCCL() + pg = ProcessGroupBabyNCCL(timeout=timeout) pg.configure( store_addr=f"localhost:{store.port}/prefix", rank=rank, world_size=world_size, ) - return PGTransport[Dict[str, object]]( - pg, timeout=timedelta(seconds=10), device=device + return PGTransport[dict[str, object]](pg, timeout=timeout, device=device) + + run_multi_recovery_test(self, init, device=device) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices") + def test_pg_transport_baby_nccl_inplace(self) -> None: + store: TCPStore = TCPStore( + host_name="localhost", port=0, is_master=True, wait_for_workers=False + ) + device: torch.device = torch.device("cuda") + timeout: timedelta = timedelta(seconds=10) + + def state_dict() -> dict[str, object]: + return make_state_dict(device) + + def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]: + torch.cuda.set_device(rank) + + pg = ProcessGroupBabyNCCL(timeout=timeout) + pg.configure( + store_addr=f"localhost:{store.port}/prefix", + rank=rank, + world_size=world_size, + ) + + return PGTransport[dict[str, object]]( + pg, + timeout=timeout, + device=device, + state_dict=state_dict, ) run_multi_recovery_test(self, init, device=device) diff --git a/torchft/checkpointing/transport_test.py b/torchft/checkpointing/transport_test.py index 5601db6b..0756c2ca 100644 --- a/torchft/checkpointing/transport_test.py +++ b/torchft/checkpointing/transport_test.py @@ -1,7 +1,8 @@ import threading +import traceback from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import timedelta -from typing import Callable, Dict, List +from typing import Callable from unittest import TestCase import torch @@ -14,12 +15,12 @@ def assertStateDictEqual( - self: TestCase, a: Dict[str, object], b: Dict[str, object] + self: TestCase, a: dict[str, object], b: dict[str, object] ) -> None: for k, v1 in a.items(): v2 = b[k] if isinstance(v1, DTensor) and isinstance(v2, DTensor): - torch.testing.assert_close(v1._local_tensor, v2._local_tensor) + torch.testing.assert_close(v1._local_tensor.cpu(), v2._local_tensor.cpu()) self.assertEqual(v1._spec, v2._spec) elif isinstance(v1, torch.Tensor) and isinstance(v2, 
torch.Tensor): torch.testing.assert_close(v1.cpu(), v2.cpu()) @@ -27,9 +28,23 @@ def assertStateDictEqual( self.assertEqual(v1, v2) +def make_state_dict(device: torch.device) -> dict[str, object]: + device_mesh = DeviceMesh("cpu", 1) + tensor = torch.tensor([5, 6, 7]) + dtensor: DTensor = distribute_tensor(tensor, device_mesh, []) + + return { + "rank": torch.tensor([1, 2, 3], device=device), + # "strided": torch.tensor([10], device=device)[1::2], + "str": "str", + "int": 1234, + "dtensor": dtensor, + } + + def run_multi_recovery_test( self: TestCase, - init_transport: Callable[[int, int], CheckpointTransport[Dict[str, object]]], + init_transport: Callable[[int, int], CheckpointTransport[dict[str, object]]], device: torch.device, ) -> None: """ @@ -41,18 +56,14 @@ def run_multi_recovery_test( WORLD_SIZE: int = 3 # barrier is used to simulate quorum/allreduce barriers - barrier: threading.Barrier = threading.Barrier(WORLD_SIZE) + barrier: threading.Barrier = threading.Barrier(WORLD_SIZE, timeout=10) metadata: str = "" dist.init_process_group( backend="gloo", rank=0, world_size=1, store=dist.HashStore() ) - device_mesh = DeviceMesh("cpu", 1) - tensor = torch.randn(4, 4) - dtensor: DTensor = distribute_tensor(tensor, device_mesh, []) - - def run(rank: int) -> CheckpointTransport[Dict[str, object]]: + def run(rank: int) -> CheckpointTransport[dict[str, object]]: transport = init_transport(rank, WORLD_SIZE) if rank == 0: @@ -61,12 +72,7 @@ def run(rank: int) -> CheckpointTransport[Dict[str, object]]: barrier.wait() - state_dict: Dict[str, object] = { - "rank": torch.tensor([1, 2, 3], device=device), - "str": "str", - "int": 1234, - "dtensor": dtensor, - } + state_dict: dict[str, object] = make_state_dict(device) # 3 node recovery if rank == 0: @@ -140,6 +146,7 @@ def run(rank: int) -> CheckpointTransport[Dict[str, object]]: transports.append(fut.result()) except Exception as e: print(e) + traceback.print_exc() raise for transport in transports: diff --git a/torchft/multiprocessing.py b/torchft/multiprocessing.py index 273e820e..6e038a9d 100644 --- a/torchft/multiprocessing.py +++ b/torchft/multiprocessing.py @@ -38,18 +38,20 @@ def get(self, timeout: Union[float, timedelta]) -> object: start = time.perf_counter() while True: + try: + v = self._q.get(timeout=self._poll_interval_s) + break + except queue.Empty: + pass + elapsed = time.perf_counter() - start if elapsed > timeout: raise TimeoutError(f"queue.get() timed out after {timeout} seconds") + + # polling the process can be slow so we only do it every poll_interval if not self._p.is_alive(): raise RuntimeError(f"process is not alive {self._p.exitcode}") - try: - v = self._q.get(timeout=self._poll_interval_s) - break - except queue.Empty: - continue - if isinstance(v, Exception): raise v return v @@ -71,18 +73,20 @@ def put(self, obj: object, timeout: Union[float, timedelta]) -> None: start = time.perf_counter() while True: + try: + self._q.put(obj, timeout=self._poll_interval_s) + break + except queue.Full: + pass + elapsed = time.perf_counter() - start if elapsed > timeout: raise TimeoutError(f"queue.put() timed out after {timeout} seconds") + + # polling the process can be slow so we only do it every poll_interval if not self._p.is_alive(): raise RuntimeError(f"process is not alive {self._p.exitcode}") - try: - self._q.put(obj, timeout=self._poll_interval_s) - break - except queue.Full: - continue - def close(self) -> None: self._q.close() diff --git a/torchft/process_group.py b/torchft/process_group.py index 3a49e3c0..3ce4dcb7 100644 --- 
a/torchft/process_group.py +++ b/torchft/process_group.py @@ -1148,8 +1148,6 @@ def _future_handler(self, future_queue: _MonitoredQueue) -> None: logger.exception(f"got unexpected error in future handler: {e}") def _get_future(self, op_id: int) -> Future[object]: - self._assert_alive() - with self._futures_lock: fut = Future() # pyre-fixme[29]: is not a function self._futures[op_id] = fut @@ -1162,8 +1160,6 @@ def _get_future(self, op_id: int) -> Future[object]: return fut def _wait(self, op_id: int, timeout: Optional[timedelta] = None) -> bool: - self._assert_alive() - assert self._tx is not None self._tx.put(("wait", op_id, timeout), timeout=self._timeout) @@ -1179,14 +1175,10 @@ def _wait(self, op_id: int, timeout: Optional[timedelta] = None) -> bool: return True def _del(self, op_id: int) -> None: - self._assert_alive() - assert self._tx is not None self._tx.put(("del", op_id), timeout=self._timeout) def _run_func(self, func: str, *args: object, **kwargs: object) -> Work: - self._assert_alive() - rx = self._rx tx = self._tx assert rx is not None @@ -1222,16 +1214,6 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work: return _BabyWork(pg=self, op_id=op_id) - def _assert_alive(self) -> None: - """ - Assert that the process group is alive. This is used to ensure that - operations are not performed on a dead process group and any errors are surfaced. - """ - p = self._p - assert p is not None - if not p.is_alive(): - raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}") - def allgather( self, output_tensors: List[List[torch.Tensor]],
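
# ---------------------------------------------------------------------------
# Reviewer sketch (not part of the patch): why the in-place receive works.
# _cast_tensor() returns a uint8 view over the destination tensor's untyped
# storage, so bytes written into that view land directly in the pre-allocated
# tensor; torch.as_strided() then restores the original shape/stride/dtype
# from _TensorMeta. The copy_() below stands in for the pg.recv() call that
# recv_checkpoint() issues with the same buffer.
import torch

from torchft.checkpointing.pg_transport import _cast_tensor, _prepare_tensor

src = torch.tensor([1.0, 2.0, 3.0, 4.0])
dst = torch.zeros(4, dtype=torch.float32)  # pre-allocated destination

src_bytes, meta = _prepare_tensor(src)      # uint8 view of src + _TensorMeta
dst_bytes = _cast_tensor(dst, torch.uint8)  # shares storage with dst
assert dst_bytes.nbytes == meta.nbytes, "inplace tensor storage must be the same size"

dst_bytes.copy_(src_bytes)  # stand-in for self._pg.recv([dst_bytes], ...)
restored = torch.as_strided(
    dst_bytes.view(meta.dtype), meta.shape, meta.stride, meta.storage_offset
)
assert torch.equal(dst, src)       # dst was filled in place, no staging copy
assert torch.equal(restored, src)  # and the strided view matches the original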
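
# ---------------------------------------------------------------------------
# Reviewer sketch (not part of the patch): how the new KeyPath bookkeeping
# pairs incoming tensors with pre-allocated destinations. The sender records
# one KeyPath per leaf (the `paths` field of _StateDictMeta); the receiver
# flattens its own state_dict with tree_flatten_with_path() and indexes it by
# path, so each incoming tensor can be matched to an existing buffer. The toy
# dicts below are hypothetical, and copy_() stands in for the actual recv.
import torch
from torch.utils._pytree import tree_flatten_with_path

sender_sd = {"model": {"weight": torch.ones(4)}, "step": 7}
receiver_sd = {"model": {"weight": torch.zeros(4)}, "step": 0}

send_leaves, _ = tree_flatten_with_path(sender_sd)
recv_leaves, _ = tree_flatten_with_path(receiver_sd)
dst_by_path = dict(recv_leaves)  # KeyPath -> pre-allocated leaf

for path, leaf in send_leaves:
    if isinstance(leaf, torch.Tensor):
        # recv_checkpoint() would recv bytes straight into this buffer
        dst_by_path[path].copy_(leaf)

assert torch.equal(receiver_sd["model"]["weight"], torch.ones(4))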
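
# ---------------------------------------------------------------------------
# Reviewer sketch (not part of the patch): the reordered polling loop in
# _MonitoredQueue.get()/put(). Each iteration now blocks on the queue for one
# poll interval first; only after an empty poll does it check the overall
# deadline and then the (comparatively slow) child-process liveness, so a
# value that is already queued is returned without paying that cost. The
# _FakeProcess stub is hypothetical and only exists to make the sketch run.
import queue
import time
from typing import Optional


class _FakeProcess:
    exitcode: Optional[int] = None

    def is_alive(self) -> bool:
        return True


def monitored_get(
    q: "queue.Queue[object]",
    p: _FakeProcess,
    timeout: float,
    poll_interval_s: float = 0.1,
) -> object:
    start = time.perf_counter()
    while True:
        try:
            v = q.get(timeout=poll_interval_s)
            break
        except queue.Empty:
            pass

        if time.perf_counter() - start > timeout:
            raise TimeoutError(f"queue.get() timed out after {timeout} seconds")
        if not p.is_alive():
            raise RuntimeError(f"process is not alive {p.exitcode}")

    if isinstance(v, Exception):
        raise v
    return v


q: "queue.Queue[object]" = queue.Queue()
q.put("ready")
assert monitored_get(q, _FakeProcess(), timeout=1.0) == "ready"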