@@ -182,7 +182,6 @@ def _manager_state_dict() -> Dict[str, T]:
182182 self ._batches_committed = 0
183183
184184 # first step is 1
185- self ._should_step = True
186185 self ._participating_rank : Optional [int ] = None
187186 self ._participating_world_size : int = 0
188187
@@ -218,8 +217,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso
218217 fut .set_result (grad )
219218 return fut
220219
221- assert self ._quorum_future is not None , "must call step before allreduce_grad"
222- self ._quorum_future .result ()
220+ self .wait_quorum ()
223221
224222 if not self .is_participating ():
225223 grad .zero_ ()
@@ -315,21 +313,28 @@ def callback(
315313 self ._pending_work .append (cast (torch .futures .Future [object ], fut ))
316314 return fut
317315
318- def start_step (self ) -> None :
316+ def start_quorum (self , allow_heal : bool = True ) -> None :
319317 """
320318 .. note::
321319 We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
322320
323321 Computes a new quorum (potentially asynchronously) and readies the
324322 manager for a new step.
325323
326- Must be called before the forwards pass of each step for best
324+ It's best practice to call this before the forwards pass of each step for
327325 performance as computing quorum may take some time.
326+
327+ If allow_heal is set, the manager will attempt to heal either
328+ synchronously before returning or asynchronously prior to any network
329+ calls.
330+
331+ Args:
332+ allow_heal: whether to allow healing at the beginning of the step
328333 """
329334
330- if self . _should_step :
331- self ._step += 1
332- self ._batches_committed += self . num_participants ()
335+ # wait for previous quorum to complete
336+ if self ._quorum_future is not None :
337+ self ._quorum_future . result ()
333338
334339 self ._errored = None
335340 self ._healing = False
@@ -338,9 +343,9 @@ def start_step(self) -> None:
338343 # TODO: we should really be wrapping this whole section in a try-except
339344 # block to allow gracefully recovering from issues in PG setup and quorum.
340345
341- self ._quorum_future = self ._executor .submit (self ._async_quorum )
346+ self ._quorum_future = self ._executor .submit (self ._async_quorum , allow_heal )
342347 if not self ._use_async_quorum :
343- self ._quorum_future . result ()
348+ self .wait_quorum ()
344349
345350 if self ._healing :
346351 # eagerly apply pending state_dict so we can run the forwards pass
@@ -350,7 +355,18 @@ def start_step(self) -> None:
350355 # and don't need to zero_grad
351356 self ._healing = False
352357
353- def _async_quorum (self ) -> None :
def wait_quorum(self) -> None:
    """
    Block until the quorum started by ``start_quorum`` has completed.

    ProcessGroup will be in a healthy state after this returns.

    Raises:
        AssertionError: if no quorum future exists yet, i.e.
            ``start_quorum`` has not been called before this method.
    """
    quorum_future = self._quorum_future
    # Raise explicitly rather than via an ``assert`` statement: asserts are
    # stripped when Python runs with ``-O``, which would turn this misuse
    # into an opaque ``AttributeError`` on ``None``. The exception type and
    # message are kept identical so existing callers are unaffected.
    if quorum_future is None:
        raise AssertionError("must call start_quorum before wait_quorum")
    # Waits for the pending quorum work; the returned value is discarded.
    quorum_future.result()
368+
369+ def _async_quorum (self , allow_heal : bool ) -> None :
354370 (
355371 quorum_id ,
356372 replica_rank ,
@@ -372,7 +388,7 @@ def _async_quorum(self) -> None:
372388 # workers will be healthy.
373389 self ._participating_rank , self ._participating_world_size = (
374390 (max_rank , max_world_size )
375- if self ._use_async_quorum
391+ if self ._use_async_quorum or not allow_heal
376392 else (replica_rank , replica_world_size )
377393 )
378394
@@ -397,7 +413,7 @@ def _async_quorum(self) -> None:
397413 self ._quorum_id = quorum_id
398414
399415 # See manager.rs for healing conditions
400- if heal :
416+ if heal and allow_heal :
401417 self ._healing = True
402418 self ._logger .info (
403419 f"healing required, fetching checkpoint server address from { address = } { max_step = } "
@@ -475,7 +491,9 @@ def should_commit(self) -> bool:
475491 self ._ckpt_server .disallow_checkpoint ()
476492
477493 # decide whether we're in a healthy state to increase the step count
478- self ._should_step = should_commit
494+ if should_commit :
495+ self ._step += 1
496+ self ._batches_committed += self .num_participants ()
479497
480498 return should_commit
481499
0 commit comments