I have tested mxfp8 training for Qwen MoE models and for DeepSeekV3 16b on B200. It did not show any speedup, and in some cases it even slowed down when I used mxfp8 (quantize.grouped_mm.mx).
I found this in the torchao repo, which says mxfp8 gives up to a 1.6x speedup for DeepSeekV3 671b. Does it only work for big MoE models?
I have also tried benchmarking a single MoE layer like here.
This is what I got with the dims used in DeepSeekV3 16b [dim=2048, moe_inter_dim=1408]:
```
$ python -m benchmarks.prototype.moe_training.bench_moe_layer --recipe mxfp8 --local_batch_size=16 --dim=2048 --hidden_dim=1408 --local_num_experts=8
total_M: 131072, N: 1408, K: 2048
bf16 time: 16.882 ms
mxfp8 time: 17.710 ms
speedup: 0.953x
```
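For reference, here is a minimal standalone sketch of how the bf16 expert GEMMs at these shapes could be timed with CUDA events. The even 8-way token split and `torch.bmm` as a stand-in for the grouped GEMM are simplifications of mine, not the actual benchmark code:

```python
# Minimal sketch: time bf16 expert GEMMs at the shapes printed above.
# Assumes an even split of total_M tokens across the 8 local experts
# (real MoE routing is uneven, so this is only an approximation).
import torch

total_M, N, K, num_experts = 131072, 1408, 2048, 8
tokens_per_expert = total_M // num_experts

device = "cuda"
x = torch.randn(num_experts, tokens_per_expert, K, device=device, dtype=torch.bfloat16)
w = torch.randn(num_experts, K, N, device=device, dtype=torch.bfloat16)

def bench(fn, iters=20, warmup=5):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per iteration

ms = bench(lambda: torch.bmm(x, w))
print(f"bf16 batched expert GEMM: {ms:.3f} ms")
```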
I couldn't get any speedup on Qwen3 235B-A22B and 30B-A3B either.
Benchmarking the MoE layer with the dims from Qwen3 235B-A22B [dim=4096, moe_inter_dim=1536] gives the following:
```
$ python -m benchmarks.prototype.moe_training.bench_moe_layer --recipe mxfp8 --local_batch_size=16 --dim=4096 --hidden_dim=1536 --local_num_experts=8
total_M: 131072, N: 1536, K: 4096
bf16 time: 34.154 ms
mxfp8 time: 34.196 ms
speedup: 0.999x
```
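For context, a quick sketch of the per-expert GEMM sizes implied by the two runs above (assuming an even 8-way token split, which is an approximation):

```python
# Per-expert GEMM sizes implied by the two benchmark runs above,
# assuming the 131072 tokens split evenly across the 8 local experts.
configs = {
    "DeepSeekV3-16b dims": dict(total_M=131072, N=1408, K=2048),
    "Qwen3-235B-A22B dims": dict(total_M=131072, N=1536, K=4096),
}
for name, c in configs.items():
    m = c["total_M"] // 8                   # ~16K tokens per expert
    gflop = 2 * m * c["N"] * c["K"] / 1e9   # one expert GEMM, forward only
    print(f"{name}: per-expert GEMM [{m} x {c['K']}] @ [{c['K']} x {c['N']}], {gflop:.0f} GFLOP")
```

Maybe these per-expert GEMMs are simply not large enough to amortize the quantization overhead, which is related to my question below.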
Is there any way I can get a speedup with mxfp8 for the above models?