@@ -1204,30 +1204,6 @@ def callback(
1204
1204
return work
1205
1205
1206
1206
1207
- class _ManagedWork (Work ):
1208
- def __init__ (self , manager : "Manager" , work : Work , default_result : object ) -> None :
1209
- super ().__init__ ()
1210
-
1211
- self ._manager = manager
1212
- self ._work = work
1213
- self ._default_result = default_result
1214
-
1215
- def wait (self , timeout : Optional [timedelta ] = None ) -> bool :
1216
- try :
1217
- if self ._work is not None :
1218
- if timeout is not None :
1219
- self ._work .wait (timeout )
1220
- else :
1221
- self ._work .wait ()
1222
- except Exception as e :
1223
- self ._manager .report_error (e )
1224
-
1225
- return True
1226
-
1227
- def get_future (self ) -> Future [object ]:
1228
- return self ._manager .wrap_future (self ._work .get_future (), self ._default_result )
1229
-
1230
-
1231
1207
class ManagedProcessGroup (ProcessGroupWrapper ):
1232
1208
"""
1233
1209
This is a wrapper around any ProcessGroup that is managed by a torchft
@@ -1247,23 +1223,13 @@ def __init__(self, manager: "Manager") -> None:
1247
1223
self ._manager = manager
1248
1224
1249
1225
def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
    """
    Allreduce ``tensors`` by delegating to the manager's ``allreduce``.

    Args:
        tensors: the tensors to reduce, forwarded unchanged to the manager
        opts: either a ``ReduceOp`` or an ``AllreduceOptions``; only the
            reduce op is forwarded to the manager (as ``reduce_op``)

    Returns:
        whatever ``self._manager.allreduce`` returns

    Raises:
        TypeError: if ``opts`` is neither a ``ReduceOp`` nor an
            ``AllreduceOptions``
    """
    if isinstance(opts, ReduceOp):
        return self._manager.allreduce(tensors, reduce_op=opts)

    if isinstance(opts, AllreduceOptions):
        return self._manager.allreduce(tensors, reduce_op=opts.reduceOp)

    # Raise explicitly instead of `assert False`: asserts are stripped under
    # `python -O`, which would make this method silently return None.
    raise TypeError(
        f"allreduce: unsupported opts type {type(opts).__name__!r}; "
        "expected ReduceOp or AllreduceOptions"
    )
1267
1233
1268
1234
def size(self) -> int:
    """
    Return the value reported by the manager's ``num_participants()``.
    """
    participants = self._manager.num_participants()
    return participants
0 commit comments