Commit c5f036e

Merge pull request #333 from kvcache-ai/feat_experts_gpu
toy support for experts on GPU, no CUDA Graph
2 parents ae8da01 + 8ed8eb2 commit c5f036e

7 files changed: +204 -68 lines changed

doc/en/FAQ.md
Lines changed: 3 additions & 1 deletion

````diff
@@ -25,7 +25,7 @@ from-https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
 1. local_chat.py: You can increase the context window size by setting `--max_new_tokens` to a larger value.
 2. server: Increase the `--cache_lens' to a larger value.
 2. Move more weights to the GPU.
-   Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml
+   Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
    ```yaml
    - match:
      name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$" # inject experts in layer 4~10 as marlin expert
@@ -39,6 +39,8 @@ from-https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
    You can modify layer as you want, eg. `name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$"` to `name: "^model\\.layers\\.([4-12])\\.mlp\\.experts$"` to move more weights to the GPU.
 
 > Note: The first matched rule in yaml will be applied. For example, if you have two rules that match the same layer, only the first rule's replacement will be valid.
+> Note: Currently, executing experts on the GPU will conflict with CUDA Graph. Without CUDA Graph, there will be a significant slowdown. Therefore, unless you have a substantial amount of VRAM (placing a single layer of experts for DeepSeek-V3/R1 on the GPU requires at least 5.6GB of VRAM), we do not recommend enabling this feature. We are actively working on optimization.
+> Note: KExpertsTorch is untested.
 
 
 ### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them?
````
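For a rough sense of the "at least 5.6GB" figure in the note above, the sketch below estimates the VRAM needed to host one layer of routed experts on the GPU. The DeepSeek-V3/R1 shape values and the ~4-bit quantization are assumptions used for illustration, not values taken from this commit.

```python
# Rough VRAM estimate for one layer of routed experts placed on the GPU.
# Assumed DeepSeek-V3/R1 shapes (illustrative, not read from this commit):
hidden_size = 7168
moe_intermediate_size = 2048
n_routed_experts = 256
bits_per_weight = 4.0  # e.g. a Q4-family GGUF quant, ignoring scale/zero-point overhead

# Each routed expert is a gated MLP with three projections: gate, up and down.
params_per_expert = 3 * hidden_size * moe_intermediate_size
total_params = params_per_expert * n_routed_experts
vram_gb = total_params * bits_per_weight / 8 / 1e9

print(f"{total_params / 1e9:.2f}B params -> ~{vram_gb:.1f} GB at {bits_per_weight} bits/weight")
# ~11.27B params, ~5.6 GB, consistent with the note above.
```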

ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
Lines changed: 14 additions & 14 deletions

```diff
@@ -17,8 +17,8 @@
 #include <c10/cuda/CUDAGuard.h>
 
 __global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
         for(int i=0;i<blk_size;i++){
             float scale = scales[block_id];
             output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];
@@ -37,8 +37,8 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
 }
 
 __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
 
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
@@ -72,10 +72,10 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size
 
 __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
 
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
     const uint32_t kmask1 = 0x03030303;
     const uint32_t kmask2 = 0x0f0f0f0f;
-    for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
+    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
 
         uint32_t aux[4];
@@ -128,8 +128,8 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size
 
 
 __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         // const uint8_t * q = data[i].qs;
         const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
@@ -152,8 +152,8 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
 }
 
 __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
 
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
@@ -181,8 +181,8 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
 }
 
 __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208)));
 
@@ -215,8 +215,8 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
 static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
 __global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
+    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
         float* __restrict__ output_blk = (float*)(output + block_id * 256);
         const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
         const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
```
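The `int` to `long long` change above matters once a whole stacked expert tensor is dequantized in one launch: the flat output index `block_id * blk_size + i` (or `block_id * 256` in the k-quant kernels) can exceed `INT32_MAX` and wrap around. A back-of-the-envelope check, assuming DeepSeek-V3-like expert shapes (illustrative values, not taken from this commit):

```python
# Why 32-bit indices overflow when dequantizing a full stacked expert tensor.
INT32_MAX = 2**31 - 1

# Assumed DeepSeek-V3-like shapes for one ffn_up_exps tensor (illustrative).
hidden_size = 7168
moe_intermediate_size = 2048
n_routed_experts = 256
elements = hidden_size * moe_intermediate_size * n_routed_experts

QK_K = 256                             # output elements per k-quant block
num_blocks = elements // QK_K          # blocks processed by one kernel launch
max_out_index = num_blocks * QK_K - 1  # largest value of block_id * 256 + i

print(f"{elements:,} elements, max output index {max_out_index:,}")
print("overflows int32:", max_out_index > INT32_MAX)
# ~3.76e9 elements, so a 32-bit index wraps; 64-bit (long long) indexing is required.
```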

ktransformers/operators/experts.py
Lines changed: 93 additions & 43 deletions

```diff
@@ -18,6 +18,7 @@
 import torch
 import sys, os
 from ktransformers.operators.base_operator import BaseInjectedModule
+from tqdm import tqdm
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
@@ -225,6 +226,7 @@ def unload(self):
         return
 
     def load_weights(self, override_key: str | None = None, device: str = "cpu"):
+        # TODO: support Bias
         res = {}
         if override_key is not None:
             keys = override_key
@@ -288,6 +290,8 @@ def __init__(
         self.act_fn = ACT2FN[config.hidden_act]
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
         self.device = device
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
+
         # create empty marlin experts according to the number of experts per token
         # up
         self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
@@ -299,17 +303,34 @@
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
         assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
-        if w is None: w = self.load_weights()[self.key]
-
-        if isinstance(w, dict):
-            self.gate = w["gate"]
-            self.up = (w["up"])
-            self.down = (w["down"])
-            for i in range(self.expert_num):
-                self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
-                self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
-                self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
-                self.loaded_experts_idx.append(i)
+        if w is None:
+            w = self.load_weights()
+            load_by_experts = True
+
+        if load_by_experts:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)
+
+                    self.up_projs[i].load(nn.Parameter(up_weights), device=device)
+                    self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
+                    self.down_projs[i].load(nn.Parameter(down_weights), device=device)
+                    self.loaded_experts_idx.append(i)
+        else:
+            if isinstance(w, dict):
+                self.gate = w["gate"]
+                self.up = (w["up"])
+                self.down = (w["down"])
+                for i in range(self.expert_num):
+                    self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
+                    self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
+                    self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
+                    self.loaded_experts_idx.append(i)
         return
 
     def unload(self):
```
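The substance of the `load` change above is that `KExpertsMarlin` no longer materializes the full stacked gate/up/down tensors before converting them: `load_weights` now returns memory-mapped GGUF data, and the loop dequantizes and repacks one expert at a time, so peak memory is roughly one expert rather than a whole layer. A minimal, hypothetical sketch of that pattern (the helper below is illustrative; it is not the actual `GGUFLoader.load_expert_tensor`, and it glosses over the GGUF block dequantization that the real loader performs on each slice):

```python
import numpy as np
import torch

def load_expert_slice(mmap_weights: np.memmap, expert_id: int,
                      elements_per_expert: int, device: str = "cuda") -> torch.Tensor:
    """Hypothetical sketch: copy a single expert's weights out of a memory-mapped
    stacked tensor, so only that slice is paged in and transferred to the GPU."""
    start = expert_id * elements_per_expert
    end = start + elements_per_expert
    chunk = np.ascontiguousarray(mmap_weights[start:end])
    return torch.from_numpy(chunk).to(device)

# Usage sketch, mirroring the load() loop above: convert expert i, hand it to the
# Marlin linear for quantization/packing, then let it be freed before expert i+1.
# for i in range(expert_num):
#     up_i = load_expert_slice(up_mmap, i, elements_per_tensor)
#     ...
```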
```diff
@@ -329,20 +350,13 @@ def load_weights(self, override_key: str | None = None):
         gate = None
         up = None
         down = None
-        gate_type = None
-        up_type = None
-        down_type = None
 
         for key in keys:
             if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
-                gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight")
-                up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight")
-                down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight")
-                gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
-                up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
-                down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
-                # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
-            res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+        res = {"gate": gate, "up": up, "down": down}
         return res
 
     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
@@ -381,6 +395,7 @@ def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.T
 
         return final_hidden_states.to(dtype=org_dtype, device=org_device)
 
+# untested, CUDA OOM
 class KExpertsTorch(KExpertsBase):
     expert_num: int
     loaded_experts_idx: list[int]
@@ -402,26 +417,65 @@ def __init__(
         # self.loaded_experts_idx = []
         self.act_fn = ACT2FN[config.hidden_act]
         self.device = device
-        self.gate = None
-        self.up = None
-        self.donw = None
+        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
+        self.gate = [None for _ in range(self.expert_num)]
+        self.up = [None for _ in range(self.expert_num)]
+        self.down = [None for _ in range(self.expert_num)]
         self.dtype = torch.get_default_dtype()
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
         if device is None: device = self.device
-        if w is None: w = self.load_weights(device=device)[self.key]
-
-        if isinstance(w, dict):
-            self.gate = w["gate"].to(device=device, dtype=self.dtype)
-            self.up = w["up"].to(device=device, dtype=self.dtype)
-            self.down = w["down"].to(device=device, dtype=self.dtype)
+        if w is None:
+            w = self.load_weights()
+            load_by_experts = True
+
+        if load_by_experts:
+            if isinstance(w, dict):
+                for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"):
+                    up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
+                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
+                    down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)
+
+                    self.up[i] = up_weights
+                    self.gate[i] = gate_weights
+                    self.down[i] = down_weights
+        else:
+            if isinstance(w, dict):
+                for i in range(self.expert_num):
+                    self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
+                    self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
+                    self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
+
+        self.up = torch.cat(self.gate, dim=0)
+        self.gate = torch.cat(self.gate, dim=0)
+        self.down = torch.cat(self.gate, dim=0)
+        return
 
     def unload(self):
         if self.gate is not None:
             self.gate = None
             self.up = None
             self.down = None
 
+    def load_weights(self, override_key: str | None = None):
+        res = {}
+        if override_key is not None:
+            keys = override_key
+        else:
+            keys = [self.key]
+
+        gate = None
+        up = None
+        down = None
+
+        for key in keys:
+            if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
+                gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
+                up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
+                down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
+        res = {"gate": gate, "up": up, "down": down}
+        return res
+
     def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
 
         org_device = hidden_states_cpu.device
```
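For reference, a `KExpertsTorch`-style backend computes an ordinary routed gated-MLP in plain PyTorch once the dequantized gate/up/down weights are in place. The sketch below is illustrative only: the function name, tensor layouts, and the SiLU activation are assumptions, and it is not the class's actual `forward`.

```python
import torch
import torch.nn.functional as F

def torch_experts_forward(x, topk_ids, topk_weight, gate, up, down):
    """Illustrative routed gated-MLP forward (not the actual KExpertsTorch code).
    x:           [n_tokens, hidden]
    topk_ids:    [n_tokens, k]    selected expert indices
    topk_weight: [n_tokens, k]    routing weights
    gate, up:    [n_experts, inter, hidden]
    down:        [n_experts, hidden, inter]
    """
    out = torch.zeros_like(x)
    for e in range(gate.shape[0]):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue  # expert e received no tokens this step
        xe = x[token_idx]
        # Gated MLP: down( act(gate(x)) * up(x) ), with SiLU assumed for act_fn.
        ye = (F.silu(xe @ gate[e].T) * (xe @ up[e].T)) @ down[e].T
        out.index_add_(0, token_idx, ye * topk_weight[token_idx, slot].unsqueeze(-1))
    return out
```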
```diff
@@ -582,7 +636,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)
@@ -601,8 +655,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return y, router_logits
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
 
@@ -672,7 +725,7 @@ def forward(self, hidden_states):
         y_ = self.shared_experts(identity).squeeze(0)
 
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
@@ -692,8 +745,7 @@ def forward(self, hidden_states):
         return y
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
 
@@ -773,7 +825,7 @@ def forward(self, hidden_states):
         y_ = self.shared_experts(identity).squeeze(0)
 
         if isinstance(self.experts, KExpertsBase):
-            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
+            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
         elif hidden_states.size(0) > 10:
             # TODO may bugs here
             y = (
@@ -793,8 +845,7 @@ def forward(self, hidden_states):
         return y
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
 
@@ -881,7 +932,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         if isinstance(self.experts, KExpertsBase):
             y = (
-                self.moe_on_cpuinfer(
+                self.moe_kexperts(
                     hidden_states_expert, selected_experts_expert, routing_weights_expert
                 )
                 .view(*orig_shape)
@@ -900,8 +951,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return y, router_logits
 
     @torch.no_grad()
-    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
-        outs = torch.empty_like(x)
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
         outs = self.experts(x, topk_ids, topk_weight)
         return outs
 
```