@@ -162,7 +162,7 @@ def reduce_scatter_quantized(
162
162
opts : ReduceScatterOptions | ReduceOp ,
163
163
process_group : "ProcessGroup" ,
164
164
sync_stream : cuda .Stream | None = None ,
165
- ) -> Future [ None ] :
165
+ ) -> Work :
166
166
"""
167
167
Performs a quantized reduce-scatter operation on a list of tensors.
168
168
@@ -196,10 +196,10 @@ def reduce_scatter_quantized(
196
196
"""
197
197
198
198
if isinstance (opts , ReduceOp ):
199
- reducescatter_opts = ReduceScatterOptions ()
199
+ reducescatter_opts : ReduceScatterOptions = ReduceScatterOptions ()
200
200
reducescatter_opts .reduceOp = opts
201
201
else :
202
- reducescatter_opts = opts
202
+ reducescatter_opts : ReduceScatterOptions = opts
203
203
204
204
# Check if the reduceOp is AVG or SUM
205
205
if reducescatter_opts .reduceOp not in {
@@ -211,15 +211,15 @@ def reduce_scatter_quantized(
211
211
f"for quantized reduce-scatter, only AVG and SUM are supported"
212
212
)
213
213
214
- rank = process_group .rank ()
215
- world_size = process_group .size ()
214
+ rank : int = process_group .rank ()
215
+ world_size : int = process_group .size ()
216
216
217
217
reduce_output_sizes = [
218
218
torch .Size ((s [0 ] // world_size , * s [1 :]))
219
219
for s in get_padded_sizes (inputs , world_size )
220
220
]
221
221
reduce_output_numels = [s .numel () for s in reduce_output_sizes ]
222
- reduce_outputs = [
222
+ reduce_outputs : list [ torch . Tensor ] = [
223
223
o .view (s )
224
224
for o , s in zip (
225
225
output .split (reduce_output_numels ),
@@ -240,48 +240,51 @@ def reduce_scatter_quantized(
240
240
quantized_inputs = fused_quantize_into_fp8 (inputs , world_size )
241
241
242
242
# Allocate output tensor where all-reduce results will be stored
243
- quantized_inputs_out = torch .zeros_like (quantized_inputs )
243
+ quantized_inputs_out : torch . Tensor = torch .zeros_like (quantized_inputs )
244
244
# Collect chunks and their scales from other ranks
245
- process_group .alltoall_base (
245
+ work = process_group .alltoall_base (
246
246
quantized_inputs_out .view (world_size , - 1 ),
247
247
quantized_inputs .view (world_size , - 1 ),
248
248
[],
249
249
[],
250
250
_to_alltoall_options (reducescatter_opts ),
251
- ).wait ()
252
-
253
- # Reduce chunks locally in higher precision after dequantization.
254
- # The output is again quantized.
255
- fused_reduce_fp8 (
256
- inputs ,
257
- quantized_inputs_out ,
258
- world_size ,
259
- rank ,
260
- reducescatter_opts .reduceOp ,
261
251
)
252
+ work .wait ()
262
253
263
- # Get view into the output tensor that corresponds to the
264
- # current rank
265
- quantized_reduce_scatter = (
266
- quantized_inputs_out .view (world_size , - 1 ).split (1 )[rank ].squeeze (0 )
267
- )
268
- # Dequantize the result back to the original precision for
269
- # the current rank
270
- fused_dequantize_from_fp8 (
271
- reduce_outputs ,
272
- quantized_reduce_scatter ,
273
- 1 ,
274
- )
254
+ fut = work .get_future ()
275
255
276
- # pyre-ignore[29]
277
- return _QuantizedOpFuture (
278
- sync_stream ,
279
- [
280
- quantized_inputs ,
281
- quantized_inputs_out ,
282
- ],
283
- [output ],
284
- )
256
+ def callback (fut : Future [list [torch .Tensor ]]) -> None :
257
+ nonlocal inputs , quantized_inputs_out , world_size , sync_stream , rank , reduce_outputs , reducescatter_opts
258
+
259
+ with torch .cuda .stream (sync_stream ):
260
+ # Setup stream dependency
261
+ fut .wait ()
262
+ # Reduce chunks locally in higher precision after dequantization.
263
+ # The output is again quantized.
264
+ fused_reduce_fp8 (
265
+ inputs ,
266
+ quantized_inputs_out ,
267
+ world_size ,
268
+ rank ,
269
+ reducescatter_opts .reduceOp ,
270
+ )
271
+
272
+ # Get view into the output tensor that corresponds to the
273
+ # current rank
274
+ quantized_reduce_scatter = (
275
+ quantized_inputs_out .view (world_size , - 1 ).split (1 )[rank ].squeeze (0 )
276
+ )
277
+ # Dequantize the result back to the original precision for
278
+ # the current rank
279
+ fused_dequantize_from_fp8 (
280
+ reduce_outputs ,
281
+ quantized_reduce_scatter ,
282
+ 1 ,
283
+ )
284
+
285
+ fut .add_done_callback (callback )
286
+
287
+ return work
285
288
286
289
287
290
def allreduce_quantized (
0 commit comments