diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 1458f073..ec3cf822 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -3,25 +3,29 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 """
 LocalSGD
 =========
-
 This module implements a fault tolerant version of LocalSGD and related methods.
 """
-
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
 
 import torch
 from torch import nn, optim
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
+from torch.utils.hooks import RemovableHandle
 
 from torchft.manager import Manager
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
-class LocalSGD(nn.Module):
+class LocalSGD:
     """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
     implements the algorithm described in https://arxiv.org/pdf/1805.09767
 
     This will synchronize the model parameters periodically in a fault tolerant
@@ -68,18 +72,14 @@ def __init__(
             pin_memory: Whether to pin the memory used for the backup of the model parameters.
         """
         super().__init__()
-
         self._manager = manager
         self._model = model
+        self._local_optimizer = optimizer
         self._local_step = 0
-        self._started_step = False
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-
         device = backup_device or torch.device("cpu")
-
         self._backup_parameters: Dict[str, torch.Tensor] = {}
-
         for name, p in self._model.named_parameters():
             t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
             if (
@@ -90,95 +90,150 @@ def __init__(
                 t = t.pin_memory()
             self._backup_parameters[name] = t
 
+        self._hooks: List[RemovableHandle] = []
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()
-        optimizer.register_step_post_hook(self._step_post_hook)
 
+    def __enter__(self) -> "LocalSGD":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
 
     def _save_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            self._backup_parameters[name].copy_(p.data, non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self._backup_parameters[name].copy_(p.data, non_blocking=True)
 
     def _restore_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            p.data.copy_(self._backup_parameters[name], non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self._backup_parameters[name], non_blocking=False)
 
-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
     ) -> None:
         """
-        Loads the state dict to the model and the backup parameters.
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
 
-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+    def sync(self) -> None:
         """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0
 
-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
         """
-        Run the model parameters.
+        self._average()
+        if self._manager.should_commit():
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()
 
-        This should be called before the optimizer step.
+    def _average(self) -> None:
+        # TODO: do we need to broadcast buffers like DDP does?
 
-        This will start the quorum and save the parameters if this is the first step.
- """ - if self._local_step == 0: - self._manager.start_quorum() + works = [] + + for p in self._model.parameters(): + # TODO: bucketize parameters + works.append(self._manager.allreduce(p.data.detach())) - self._started_step = True + for work in works: + work.wait() - return self._model.forward(*args, **kwargs) - def _step_post_hook( - self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object] - ) -> None: - """ - This hook is registered on the optimizer and is called after the optimizer step. +class DiLoCo(LocalSGD): + """ + DiLoCo is a subclass of LocalSGD that overrides the synchronization + mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights). - This will call the allreduce on the model weights every sync_every steps. - If any errors occur it will restore to the weights from the previous sync. + diloco: https://arxiv.org/pdf/2311.08105 + """ - ``forward`` must be called before this function. + def __init__( + self, + manager: Manager, + model: nn.Module, + inner_optimizer: optim.Optimizer, + outer_optimizer: optim.Optimizer, + sync_every: int, + backup_device: Optional[torch.device] = None, + pin_memory: bool = True, + ) -> None: + if manager._use_async_quorum: + raise ValueError( + "Using DiLoCo require synchronous quorum to be enabled. " + "Ensure that the manager is initialized with use_async_quorum=False" + ) + super().__init__( + manager, model, inner_optimizer, sync_every, backup_device, pin_memory + ) + self._outer_optimizer = outer_optimizer + + def _perform_sync(self) -> None: + """ + Overrides the sync method to calculate the pseugradient, average them across the manager group, and + step using the outer optimizer. """ - assert self._started_step, "forward must be called before step" - self._started_step = False - self._local_step += 1 + # Set the .grad field of each parameter to its pseudogradient + for name, p in self._model.named_parameters(): + assert name in self._backup_parameters + pseudogradient = p.data - self._backup_parameters[name] + p.grad = pseudogradient - if self._local_step >= self._sync_every: - self._local_step = 0 - self._average() + self._average_grads() + # Restore the parameters back to the previous state + self._restore_parameters() - if self._manager.should_commit(): - # save the parameters so we can restore from them later if necessary. - self._save_parameters() - else: - # commit failed, restore from the backup parameters - self._restore_parameters() - - def _average(self) -> None: - # TODO: do we need to broadcast buffers like DDP does? + if self._manager.should_commit(): + # Use the outer optimizer to update the model parameters + self._outer_optimizer.step() + self._save_parameters() + self._outer_optimizer.zero_grad() + def _average_grads(self) -> None: + """ + Average the gradients across the diloco group. 
+ """ works = [] - for p in self._model.parameters(): - # TODO: bucketize parameters - works.append(self._manager.allreduce(p.data.detach())) - + # Perform allreduce on the pseudogradients + assert p.grad is not None + work = self._manager.allreduce(p.grad) + works.append(work) + # Wait for all allreduce operations to complete for work in works: work.wait() diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py index d2b73cd5..05f88b7a 100644 --- a/torchft/local_sgd_test.py +++ b/torchft/local_sgd_test.py @@ -11,7 +11,7 @@ import torch from torch import nn, optim -from torchft.local_sgd import LocalSGD +from torchft.local_sgd import DiLoCo, LocalSGD from torchft.manager import Manager @@ -40,57 +40,107 @@ def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Ten class LocalSGDTest(TestCase): def test_local_sgd_healthy(self) -> None: - base_m = SimpleModel() - optimizer = optim.SGD(base_m.parameters()) + model = SimpleModel() + optimizer = optim.SGD(model.parameters()) manager = create_autospec(Manager) - - m = LocalSGD(manager, base_m, optimizer, sync_every=2) - self.assertEqual(m._local_step, 0) - - torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) - - inp = torch.rand(2, 3) - - loss = m(inp).mean() - loss.backward() - optimizer.step() - - self.assertEqual(m._local_step, 1) - self.assertEqual(manager.start_quorum.call_count, 1) - - loss = m(inp).mean() - loss.backward() - optimizer.step() - - manager.should_commit.return_value = True - self.assertEqual(m._local_step, 0) - - torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) - self.assertEqual(manager.should_commit.call_count, 1) - self.assertEqual(manager.allreduce.call_count, 4) + with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd: + self.assertEqual(local_sgd._local_step, 0) + torch.testing.assert_close( + local_sgd._backup_parameters, _params_dict(model) + ) + inp = torch.rand(2, 3) + loss = model(inp).mean() + loss.backward() + optimizer.step() + + self.assertEqual(local_sgd._local_step, 1) + self.assertEqual(manager.start_quorum.call_count, 0) + loss = model(inp).mean() + loss.backward() + optimizer.step() + self.assertEqual(manager.start_quorum.call_count, 1) + + manager.should_commit.return_value = True + self.assertEqual(local_sgd._local_step, 0) + torch.testing.assert_close( + local_sgd._backup_parameters, _params_dict(model) + ) + self.assertEqual(manager.should_commit.call_count, 1) + self.assertEqual(manager.allreduce.call_count, 4) def test_local_sgd_recovery(self) -> None: - base_m = SimpleModel() - optimizer = optim.SGD(base_m.parameters()) + model = SimpleModel() + optimizer = optim.SGD(model.parameters()) manager = create_autospec(Manager) - m = LocalSGD(manager, base_m, optimizer, sync_every=2) - - torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) - og_state_dict = _copy_state_dict(base_m.state_dict()) - - inp = torch.rand(2, 3) - - loss = m(inp).mean() - loss.backward() - optimizer.step() - - self.assertEqual(m._local_step, 1) - - state_dict = m.state_dict() - torch.testing.assert_close(state_dict, m._backup_parameters) - torch.testing.assert_close(state_dict, og_state_dict) + with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd: + torch.testing.assert_close( + local_sgd._backup_parameters, _params_dict(model) + ) + og_state_dict = _copy_state_dict(model.state_dict()) + print(og_state_dict) + + inp = torch.rand(2, 3) + + loss = model(inp).mean() + loss.backward() + optimizer.step() + + # Check 
+            for name, param in model.state_dict().items():
+                # Ensure the parameter has changed
+                self.assertFalse(
+                    torch.equal(og_state_dict[name], param),
+                    f"Parameter {name} did not change.",
+                )
+            self.assertEqual(local_sgd._local_step, 1)
+
+            local_sgd._restore_parameters()
+            torch.testing.assert_close(
+                local_sgd._backup_parameters, _params_dict(model)
+            )
+
+
+class DiLoCoTest(TestCase):
+    def test_diloco_healthy(self) -> None:
+        model = SimpleModel()
+
+        # Setup optimizers
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
 
-        m.load_state_dict(state_dict)
-        torch.testing.assert_close(_params_dict(base_m), state_dict)
-        torch.testing.assert_close(m._backup_parameters, _params_dict(base_m))
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        with DiLoCo(
+            manager, model, inner_optimizer, outer_optimizer, sync_every=2
+        ) as diloco:
+            parameter_count = len(list(model.parameters()))
+            initial_outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(initial_outer_opt_state["state"], {})
+
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            self.assertEqual(diloco._local_step, 1)
+            self.assertEqual(manager.start_quorum.call_count, 0)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+            self.assertEqual(manager.start_quorum.call_count, 1)
+
+            manager.should_commit.return_value = True
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(diloco._backup_parameters, _params_dict(model))
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, parameter_count)
+
+            outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(len(outer_opt_state["state"]), parameter_count)
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index 0721b17e..8c7c45d2 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 import threading
 import time
@@ -5,7 +6,7 @@
 from contextlib import ExitStack, contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Dict, Generator, List, Protocol, Set, Tuple
+from typing import Any, Dict, Generator, List, Protocol, Set, Tuple
 from unittest import TestCase
 
 import torch
@@ -14,7 +15,7 @@
 from torch import nn, optim
 
 from torchft.ddp import DistributedDataParallel
-from torchft.local_sgd import LocalSGD
+from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
 from torchft.process_group import ProcessGroupGloo
@@ -76,6 +77,7 @@ class Runner:
     world_size: int = 1
     attempts: int = 3
     manager_args: Dict[str, object] = field(default_factory=dict)
+    train_loop_args: Dict[str, Any] = field(default_factory=dict)
 
     def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
         store = dist.TCPStore(
@@ -103,7 +105,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
         try:
             fut.result()
         except Exception as e:
-            logger.exception(f"worker threw exception: {e}")
+            logger.exception(f"worker {self.replica_id=} threw exception: {e}")
             raise
 
         return [fut.result() for fut in futures]
@@ -227,30 +229,114 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         m: nn.Module = MyModel()
         optimizer: optim.Optimizer = optim.Adam(m.parameters())
-        m = LocalSGD(manager, m, optimizer, sync_every=2)
         criterion = nn.CrossEntropyLoss()
 
-        while True:
-            inputs = torch.rand(2, 3)
-            labels = torch.randint(4, (2,))
+        with LocalSGD(manager, m, optimizer, sync_every=2):
+            while True:
+                inputs = torch.rand(2, 3)
+                labels = torch.randint(4, (2,))
 
-            optimizer.zero_grad()
-            out = m(inputs)
-            loss = criterion(out, labels)
+                optimizer.zero_grad()
+                out = m(inputs)
+                loss = criterion(out, labels)
 
-            loss.backward()
+                loss.backward()
 
-            optimizer.step()
+                optimizer.step()
 
-            if manager.current_step() >= 4:
-                break
+                if manager.current_step() >= 4:
+                    break
 
-            runner.failure_injector.check(rank, manager.current_step())
+                runner.failure_injector.check(rank, manager.current_step())
 
         # return state_dict so we can check consistency
         return state_dict()
 
 
+def diloco_train_loop(
+    rank: int,
+    store_port: int,
+    runner: Runner,
+) -> Dict[str, Dict[str, object]]:
+    with ExitStack() as stack:
+        # Declare the model and optimizers
+        m: nn.Module = MyModel()
+        model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"]
+        m.load_state_dict(model_state_dict)
+
+        # Setup optimizers
+        inner_optimizer: optim.Optimizer = torch.optim.AdamW(
+            m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer: optim.Optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+
+        # pyre-ignore[53]
+        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
+            m.load_state_dict(state_dict["model"])
+            # TODO: make this cleaner so we don't have to save this
+            diloco._backup_parameters = state_dict["backup_params"]
+            inner_optimizer.load_state_dict(state_dict["inner_optim"])
+            outer_optimizer.load_state_dict(state_dict["outer_optim"])
+
+        def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
+            return {
+                "model": m.state_dict(),
+                "backup_params": copy.deepcopy(diloco._backup_parameters),
+                "inner_optim": inner_optimizer.state_dict(),
+                "outer_optim": outer_optimizer.state_dict(),
+            }
+
+        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
+
+        pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=2,
+            use_async_quorum=False,
+            load_state_dict=load_state_dict,
+            state_dict=state_dict,
+            replica_id=str(runner.replica_id),
+            store_addr="localhost",
+            store_port=store_port,
+            rank=rank,
+            world_size=runner.world_size,
+            lighthouse_addr=runner.lighthouse_address,
+            port=19530 + runner.replica_id,
+            # pyre-fixme[6]: Incompatible parameter type
+            **runner.manager_args,
+        )
+        stack.callback(manager.shutdown)
+
+        criterion = nn.CrossEntropyLoss()
+        all_state_dicts = {}
+        with DiLoCo(
+            manager, m, inner_optimizer, outer_optimizer, sync_every=2
+        ) as diloco:
+            while True:
+                inputs = torch.rand(2, 3)
+                labels = torch.randint(4, (2,))
+
+                out = m(inputs)
+                loss = criterion(out, labels)
+
+                inner_optimizer.zero_grad()
+                loss.backward()
+                inner_optimizer.step()
+                manager_step_str = str(manager.current_step())
+                all_state_dicts[manager_step_str] = state_dict()
+
+                # after 4 model updates then break
+                if manager.current_step() >= 4:
+                    break
+
+                runner.failure_injector.check(rank, manager.current_step())
+
+        # return state_dict so we can check consistency
+        return all_state_dicts
+
+
 class ManagerIntegTest(TestCase):
     @contextmanager
     def assertElapsedLessThan(
@@ -431,6 +517,108 @@ def test_local_sgd_recovery(self) -> None:
         self.assertEqual(failure_injectors[1].count, 1)
 
+    def test_diloco_healthy(self) -> None:
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MyModel()
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id in range(num_replicas):
+                failure_injector = FailureInjector()
+                runner = Runner(
+                    replica_id=replica_id,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    train_loop=diloco_train_loop,
+                    train_loop_args={
+                        "model_state_dict": m.state_dict(),
+                    },
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                state_dicts.append(fut.result()[0])
+
+        lighthouse.shutdown()
+
+        for replica_group in state_dicts:
+            for step, state_dict in replica_group.items():
+                # inner optimizer will be different, outer optimizer and model should be the same
+                torch.testing.assert_close(
+                    state_dict["backup_params"],
+                    state_dicts[0][str(step)]["backup_params"],
+                )
+                torch.testing.assert_close(
+                    state_dict["outer_optim"], state_dicts[0][str(step)]["outer_optim"]
+                )
+
+    def test_diloco_recovery(self) -> None:
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector().fail_at(0, 2),
+        ]
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MyModel()
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
+                runner = Runner(
+                    replica_id=replica_id,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    train_loop=diloco_train_loop,
+                    train_loop_args={
+                        "model_state_dict": m.state_dict(),
+                    },
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                try:
+                    state_dicts.append(fut.result()[0])
+                except Exception as e:
+                    print(e)
+                    raise
+
+        lighthouse.shutdown()
+        for replica_group in state_dicts:
+            for step, state_dict in replica_group.items():
+                str_step = str(step)
+                if str_step in state_dicts[0]:
+                    # inner optimizer will be different, outer optimizer and model should be the same
+                    torch.testing.assert_close(
+                        state_dict["backup_params"],
+                        state_dicts[0][str_step]["backup_params"],
+                    )
+                    torch.testing.assert_close(
+                        state_dict["outer_optim"],
+                        state_dicts[0][str_step]["outer_optim"],
+                    )
+
+        self.assertEqual(failure_injectors[1].count, 1)
+
     def test_quorum_timeout(self) -> None:
         with ExitStack() as stack:
             lighthouse = Lighthouse(
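
Reviewer note, not part of the diff: a minimal usage sketch of the context-manager API this change introduces, distilled from the integration tests above. The `Manager` construction is elided, and the `train_localsgd` function and `data_iter` iterable are illustrative assumptions rather than torchft requirements.

```python
# Illustrative only: a LocalSGD training loop using the context-manager API
# added in this PR. Assumes a Manager has already been constructed elsewhere
# (see torchft.manager) and that data_iter yields (inputs, labels) batches.
from typing import Iterable, Tuple

import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager


def train_localsgd(
    manager: Manager,
    model: nn.Module,
    data_iter: Iterable[Tuple[torch.Tensor, torch.Tensor]],
) -> None:
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    # __enter__ registers a post-step hook on the optimizer; every sync_every
    # optimizer steps the hook starts a quorum, allreduces the weights, and
    # either commits (saving a new backup) or restores the previous backup.
    with LocalSGD(manager, model, optimizer, sync_every=2):
        for inputs, labels in data_iter:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()  # the LocalSGD hook runs after this call
```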
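Also not part of the diff: a rough sketch of the DiLoCo synchronization step that `DiLoCo._perform_sync` implements (pseudogradient, allreduce, restore, outer step), unrolled as a free function to make the data flow explicit. The `diloco_sync_sketch` name and the `allreduce_fn` callable are hypothetical stand-ins; the real code path goes through `Manager.allreduce` (which returns a work handle) and gates the outer step on `Manager.should_commit`.

```python
# Illustrative only: the DiLoCo sync step as implemented in this PR, written
# as a standalone function. allreduce_fn is an assumed stand-in for
# Manager.allreduce; commit/abort handling is omitted for brevity.
from typing import Callable, Dict

import torch
from torch import nn, optim


def diloco_sync_sketch(
    model: nn.Module,
    backup_params: Dict[str, torch.Tensor],
    outer_optimizer: optim.Optimizer,
    allreduce_fn: Callable[[torch.Tensor], None],
) -> None:
    # 1. Pseudogradient: current local weights minus the last synchronized
    #    (backup) weights, stored in .grad, exactly as _perform_sync does.
    for name, p in model.named_parameters():
        p.grad = p.data - backup_params[name]
    # 2. Average the pseudogradients across the replica group.
    for p in model.parameters():
        assert p.grad is not None
        allreduce_fn(p.grad)
    # 3. Restore the weights to the previous synchronized state, then apply
    #    the averaged pseudogradient with the outer optimizer (the tests use
    #    SGD with Nesterov momentum).
    with torch.no_grad():
        for name, p in model.named_parameters():
            p.data.copy_(backup_params[name])
    outer_optimizer.step()
    outer_optimizer.zero_grad()
    # 4. The post-step weights become the new backup for the next round.
    with torch.no_grad():
        for name, p in model.named_parameters():
            backup_params[name].copy_(p.data)
```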