1. Problem Description
When we deployed the decoder node on H20 with EP16/BS32, we found in the profile that DeepGEMM.v2 had a certain performance loss compared to v1. The following screenshot is from Deepseek_v2 MoE layer UP&GATE GEMM.
- v2: (profile screenshot)
- v1: (profile screenshot)
We are also working on computation/communication overlap on H20, as shown in sgl-project/sglang#9660. We implement overlap without splitting the batch, which requires the computation kernels and the communication kernels to share SMs. We found that in v2 the gap between the GEMM kernels with and without overlap has grown significantly compared to v1. From the profile, we suspect this is caused by the larger block_n size (192) introduced in v2 on sm90. We therefore introduced a new parameter, max_block_n, in #183 to cap the size of block_n. Our experiments show that setting max_block_n to 160 restores performance to the v1 level. The overall test results are shown below:
```
$ python tests/test_v2.py
Library path:
 > ['/root/eric.hc/DeepGEMM-Async/deep_gemm']
Testing m-grouped masked GEMM with max_block_n=256:
Warning: please use at least NVCC 12.9 for the best DeepGEMM performance
 > Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D, enable_overlap=False): 254 us | 112 TFLOPS | 1881 GB/s | BLOCK_N=192 | NUM_SMS=71
 > Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D, enable_overlap=False): 126 us | 118 TFLOPS | 1924 GB/s | BLOCK_N=192 | NUM_SMS=76
 > Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D, enable_overlap=True): 258 us | 113 TFLOPS | 1849 GB/s | BLOCK_N=192 | NUM_SMS=71
 > Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D, enable_overlap=True): 148 us | 93 TFLOPS | 1642 GB/s | BLOCK_N=192 | NUM_SMS=68
Testing m-grouped masked GEMM with max_block_n=160:
 > Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D, enable_overlap=False): 236 us | 123 TFLOPS | 2023 GB/s | BLOCK_N=144 | NUM_SMS=78
 > Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D, enable_overlap=False): 136 us | 113 TFLOPS | 1792 GB/s | BLOCK_N=160 | NUM_SMS=72
 > Perf (num_groups=16, expected_m_per_group= 32, n=4096, k=7168, 1D2D, enable_overlap=True): 264 us | 107 TFLOPS | 1811 GB/s | BLOCK_N=160 | NUM_SMS=70
 > Perf (num_groups=16, expected_m_per_group= 32, n=7168, k=2048, 1D2D, enable_overlap=True): 142 us | 108 TFLOPS | 1717 GB/s | BLOCK_N=160 | NUM_SMS=72
```
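Conceptually, capping block_n amounts to filtering the candidate tile widths before the usual wave-minimizing config search. The sketch below illustrates the idea only; the function name and candidate list are our assumptions, not the actual DeepGEMM API:

```python
# Hypothetical sketch of a max_block_n cap during config selection.
# The helper name and candidate list are illustrative, not DeepGEMM's API.
def filter_block_n(candidates, max_block_n):
    """Keep only candidate tile widths that do not exceed the cap."""
    return [bn for bn in candidates if bn <= max_block_n]

print(filter_block_n([96, 128, 144, 160, 192, 256], 160))
# -> [96, 128, 144, 160]
```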
The above test results show that when using H20 for Deepseek_v2 in a BS32/EP16 deployment, UP&GATE GEMM performance suffers compared to v1, while DOWN GEMM performance improves. However, for overlapped workloads (which require sharing at least 3 SMs for communication on the H20; refer to deepseek-ai/DeepEP#390 for the specific implementation), there's also a certain performance loss on the DOWN GEMM side.
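The extra wave under overlap can be reproduced from the logged numbers. Assuming the grid is num_groups × ceil(n / block_n) output tiles (block_m = 64 covers the expected m = 32 in a single tile; this grid formula is our assumption), the DOWN GEMM rows with block_n = 192 give:

```python
import math

# Output tiles for the DOWN GEMM (num_groups=16, n=7168, block_n=192).
blocks = 16 * math.ceil(7168 / 192)   # 608 tiles
print(math.ceil(blocks / 76))         # w/o overlap, 76 SMs -> 8 waves
print(math.ceil(blocks / 68))         # w/ overlap,  68 SMs -> 9 waves
```

Losing a handful of SMs to communication pushes the same tile count over a wave boundary, which matches the 126 us vs. 148 us gap in the log above.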
Also, based on performance comparisons with v1, we believe the current DeepGEMM get_best_config logic may not be optimal. This logic can be roughly summarized as: minimize the total number of waves required for the computation, then maximize the block count in the last wave. Assuming the total computation time is num_waves × time_per_wave, minimizing the wave count alone does not guarantee the minimum total time.
Based on the above test results, we can roughly calculate the following table:
| GEMM Type | num_groups | expected_m | n | k | block_m | block_n | block_k | valid_sm | sm_usage | num_waves | Time Cost (us) | Time/Wave (us) | TFLOPS |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| UP&GATE w/o overlap (max_n=256) | 16 | 32 | 4096 | 7168 | 64 | 192 | 128 | 78 | 71 | 5 | 254.0 | 50.80 | 112 |
| UP&GATE w/ overlap (max_n=256) | 16 | 32 | 4096 | 7168 | 64 | 192 | 128 | 75 | 71 | 5 | 258.0 | 51.60 | 113 |
| DOWN w/o overlap (max_n=256) | 16 | 32 | 7168 | 2048 | 64 | 192 | 128 | 78 | 76 | 8 | 126.0 | 15.75 | 118 |
| DOWN w/ overlap (max_n=256) | 16 | 32 | 7168 | 2048 | 64 | 192 | 128 | 75 | 68 | 9 | 148.0 | 16.44 | 93 |
| UP&GATE w/o overlap (max_n=160) | 16 | 32 | 4096 | 7168 | 64 | 144 | 128 | 78 | 78 | 7 | 236.0 | 33.71 | 123 |
| UP&GATE w/ overlap (max_n=160) | 16 | 32 | 4096 | 7168 | 64 | 160 | 128 | 75 | 70 | 8 | 264.0 | 33.00 | 107 |
| DOWN w/o overlap (max_n=160) | 16 | 32 | 7168 | 2048 | 64 | 160 | 128 | 78 | 72 | 10 | 136.0 | 13.60 | 113 |
| DOWN w/ overlap (max_n=160) | 16 | 32 | 7168 | 2048 | 64 | 160 | 128 | 75 | 72 | 10 | 142.0 | 14.20 | 108 |
Based on this conjecture, we believe the core reason for the suboptimal performance is that compressing the number of waves prolongs the computation time of a single wave; since the total computation time is determined by both factors, optimizing only one of them is likely to yield a suboptimal final result.
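This trade-off can be checked directly against the measured Time/Wave values in the UP&GATE rows of the table, modeling total time as num_waves × time_per_wave:

```python
# Total time = num_waves * time_per_wave: minimizing waves alone is not
# enough, because a wider block_n also lengthens each wave.
# Time/Wave values are the measured numbers from the table above.
t_block_n_192 = 5 * 50.80   # fewer, slower waves -> 254.0 us
t_block_n_144 = 7 * 33.71   # more, faster waves  -> ~236 us
assert t_block_n_144 < t_block_n_192
```

The config with two extra waves still finishes ~18 us earlier, so a cost model should weigh both terms rather than minimizing the wave count first.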
2. Reproduction
You can reproduce our results on SGLang using the branch and startup parameters in sgl-project/sglang#9660. You can also use the test code in https://github.com/Sulfur6/DeepGEMM/tree/sbo.v2.test to see the performance of GEMM in isolation.
3. Environment
```
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
```