Commit 168d4b7

[mxfp8 moe training] integrate mxfp8 dim0 triton kernel
stack-info: PR: #3129, branch: danielvegamyhre/stack/76
Parent: c62b0f0
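For context, mxfp8 "dim0" quantization scales each row of a tensor in contiguous blocks along the last dimension, producing fp8 e4m3 payloads plus one shared power-of-two e8m0 scale per block. A simplified sketch of the per-block recipe, assuming the OCP MX floor-based scale selection (torchao's real path also handles rounding modes, clamping, and special values):

import torch

def mxfp8_quantize_block(block: torch.Tensor):
    # Shared power-of-two scale: align the block's amax with e4m3's largest
    # exponent (emax = 8, since the e4m3 max value 448 = 1.75 * 2^8).
    shared_exp = torch.floor(torch.log2(block.abs().amax())) - 8
    scale = torch.exp2(shared_exp)
    data = (block / scale).to(torch.float8_e4m3fn)
    return scale, data

# One 32-element block, as the dim0 (rowwise) kernel sees it.
scale, data = mxfp8_quantize_block(torch.randn(32))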

File tree

3 files changed: 17 additions & 12 deletions


test/prototype/mx_formats/test_kernels.py
Lines changed: 1 addition & 2 deletions

@@ -440,12 +440,11 @@ def triton_to_mxfp8_dim0_reference(
     """
     from torchao.prototype.mx_formats.mx_tensor import to_mx
 
-    # cast across dim0 (rowwise) - no transpose needed
     scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
     scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
     return (
         x_hp_d0_normalized,
-        scale_e8m0_dim0.unsqueeze(-1),
+        scale_e8m0_dim0,
     )
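With the trailing unsqueeze gone, the reference and the Triton kernel agree on the scale layout: one e8m0 scale per inner_block_size contiguous row elements, with no trailing unit dimension. A minimal shape-contract sketch, assuming a CUDA device and the prototype APIs touched in this diff:

import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
data, scale = triton_to_mxfp8_dim0(x, inner_block_size=32)
assert data.shape == (128, 256)         # fp8 payload, same shape as input
assert scale.shape == (128, 256 // 32)  # one scale per 32-element row block
assert scale.dtype == torch.float8_e8m0fnu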

torchao/prototype/moe_training/scaled_grouped_mm.py
Lines changed: 8 additions & 8 deletions

@@ -32,7 +32,7 @@
     MXGemmKernelChoice,
     ScaleCalculationMode,
 )
-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -303,16 +303,16 @@ def forward(
 
         # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_data = to_mx(
-            A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        A_data, A_scale = triton_to_mxfp8_dim0(
+            A,
+            inner_block_size=block_size,
         )
 
         # B_data shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_data = to_mx(
+        B_data, B_scales = triton_to_mxfp8_dim0(
             B_t.transpose(-2, -1),
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            inner_block_size=block_size,
         )
 
         # Convert scales to blocked format for 2d-3d grouped mm
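The B operand is what makes the kernel's new multi-dimensional support load-bearing: B_t.transpose(-2, -1) is a 3d (E, N, K) tensor, which the updated kernel flattens to 2d internally and reshapes back on return. A hedged sketch with hypothetical sizes (the explicit .contiguous() is for illustration; at the real call site B_t is typically already a transpose view, so transposing back restores contiguity):

import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

E, N, K, block_size = 4, 256, 512, 32
B_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")
B_data, B_scales = triton_to_mxfp8_dim0(
    B_t.transpose(-2, -1).contiguous(), inner_block_size=block_size
)
assert B_data.shape == (E, N, K)                  # quantized rowwise along K
assert B_scales.shape == (E, N, K // block_size)  # one scale per K-block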
@@ -351,8 +351,8 @@ def backward(ctx, grad_out: torch.Tensor):
 
         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_data = to_mx(
-            grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
+            grad_out, inner_block_size=block_size
         )
 
         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
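Note the tuple-order flip at every call site: to_mx returns (scale, data) while triton_to_mxfp8_dim0 returns (data, scale), which is why each left-hand side swaps. A hedged sanity check of the new path against the old one, mirroring the reference comparison in test_kernels.py (assumes a CUDA device; the actual test's tolerances may differ):

import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

g = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
ref_scale, ref_data = to_mx(g, torch.float8_e4m3fn, 32)             # old: (scale, data)
tri_data, tri_scale = triton_to_mxfp8_dim0(g, inner_block_size=32)  # new: (data, scale)
torch.testing.assert_close(tri_data.to(torch.float32), ref_data.to(torch.float32))
assert torch.equal(tri_scale.view(torch.uint8), ref_scale.view(torch.uint8))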

torchao/prototype/mx_formats/kernels.py
Lines changed: 8 additions & 2 deletions

@@ -1162,7 +1162,9 @@ def triton_to_mxfp8_dim0(
     assert x.is_contiguous(), "`x` must be contiguous"
     assert inner_block_size <= 32
 
-    # Get tensor shape
+    # Reshape tensor to 2d if necessary and get shape
+    x_orig_shape = x.shape
+    x = x.reshape(-1, x.shape[-1])
     n_rows, n_cols = x.shape
 
     # Masking of loads and stores is not well tested yet, so for now enforce
@@ -1181,7 +1183,7 @@
 
     # Create scale tensors for rowwise scaling
     row_scale = torch.empty(
-        (n_rows, n_cols // inner_block_size, 1),
+        (n_rows, n_cols // inner_block_size),
         dtype=torch.uint8,
         device=x.device,
     )
@@ -1202,6 +1204,10 @@
         INNER_BLOCK_SIZE=inner_block_size,
     )
 
+    # Reshape output back to original shape
+    output = output.reshape(x_orig_shape)
+    row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
+
     return (
         output,
         row_scale.view(torch.float8_e8m0fnu),
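The flatten/restore trick is safe for rowwise (dim0) scaling because quantization blocks never cross the last dimension: merging all leading dims into one row axis leaves every inner_block_size-wide block intact. A minimal pure-torch sketch of the bookkeeping, with stand-in buffers where the Triton launch does the real work:

import torch

E, N, K, block = 4, 64, 128, 32
x = torch.randn(E, N, K)
x_orig_shape = x.shape          # (4, 64, 128)
x = x.reshape(-1, x.shape[-1])  # (256, 128): K-blocks untouched
n_rows, n_cols = x.shape

# Stand-ins for the kernel's outputs on the flattened view.
output = torch.empty(n_rows, n_cols)  # the real kernel writes float8_e4m3fn
row_scale = torch.empty(n_rows, n_cols // block, dtype=torch.uint8)

# Restore the caller-visible shapes.
output = output.reshape(x_orig_shape)                                   # (4, 64, 128)
row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])  # (4, 64, 4)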
