1818import torch
1919import sys , os
2020from ktransformers .operators .base_operator import BaseInjectedModule
21+ from tqdm import tqdm
2122
2223sys .path .append (os .path .join (os .path .dirname (__file__ ), ".." , "ktransformers_ext" , "build" ))
2324sys .path .append (os .path .join (os .path .dirname (__file__ ), ".." , "ktransformers_ext" , "build" , "Release" ))
@@ -225,6 +226,7 @@ def unload(self):
225226 return
226227
227228 def load_weights (self , override_key : str | None = None , device : str = "cpu" ):
229+ # TODO: support Bias
228230 res = {}
229231 if override_key is not None :
230232 keys = override_key
@@ -288,6 +290,8 @@ def __init__(
288290 self .act_fn = ACT2FN [config .hidden_act ]
289291 assert device .lower () != "cpu" , "Marlin experts can only be loaded on GPU"
290292 self .device = device
293+ self .elements_per_tensor = config .moe_intermediate_size * config .hidden_size
294+
291295 # create empty marlin experts according to the number of experts per token
292296 # up
293297 self .up_projs = [KLinearMarlin (key + "." + "ffn_up_exps" , gguf_loader , config , device = device ) for i in range (self .expert_num )]
@@ -299,17 +303,34 @@ def __init__(
299303 def load (self , w : dict | nn .Parameter | tuple | None = None , device : str | None = None , warmup : bool = False ):
300304 if device is None : device = self .device
301305 assert device .lower () != "cpu" , "Marlin experts can only be loaded on GPU"
302- if w is None : w = self .load_weights ()[self .key ]
303-
304- if isinstance (w , dict ):
305- self .gate = w ["gate" ]
306- self .up = (w ["up" ])
307- self .down = (w ["down" ])
308- for i in range (self .expert_num ):
309- self .up_projs [i ].load (nn .Parameter (self .up [i ,...]), device = device )
310- self .gate_projs [i ].load (nn .Parameter (self .gate [i ,...]), device = device )
311- self .down_projs [i ].load (nn .Parameter (self .down [i ,...]), device = device )
312- self .loaded_experts_idx .append (i )
306+ if w is None :
307+ w = self .load_weights ()
308+ load_by_experts = True
309+
310+ if load_by_experts :
311+ if isinstance (w , dict ):
312+ self .gate = w ["gate" ]
313+ self .up = (w ["up" ])
314+ self .down = (w ["down" ])
315+ for i in tqdm (range (self .expert_num ), desc = f"Dequanting and quanting for KExpertsMarlin { self .key } " ):
316+ up_weights = self .gguf_loader .load_expert_tensor (self .key + ".ffn_up_exps.weight" , self .up , i , self .elements_per_tensor , device = self .device )
317+ gate_weights = self .gguf_loader .load_expert_tensor (self .key + ".ffn_gate_exps.weight" , self .gate , i , self .elements_per_tensor , device = self .device )
318+ down_weights = self .gguf_loader .load_expert_tensor (self .key + ".ffn_down_exps.weight" , self .down , i , self .elements_per_tensor , device = self .device )
319+
320+ self .up_projs [i ].load (nn .Parameter (up_weights ), device = device )
321+ self .gate_projs [i ].load (nn .Parameter (gate_weights ), device = device )
322+ self .down_projs [i ].load (nn .Parameter (down_weights ), device = device )
323+ self .loaded_experts_idx .append (i )
324+ else :
325+ if isinstance (w , dict ):
326+ self .gate = w ["gate" ]
327+ self .up = (w ["up" ])
328+ self .down = (w ["down" ])
329+ for i in range (self .expert_num ):
330+ self .up_projs [i ].load (nn .Parameter (self .up [i ,...]), device = device )
331+ self .gate_projs [i ].load (nn .Parameter (self .gate [i ,...]), device = device )
332+ self .down_projs [i ].load (nn .Parameter (self .down [i ,...]), device = device )
333+ self .loaded_experts_idx .append (i )
313334 return
314335
315336 def unload (self ):
@@ -329,20 +350,13 @@ def load_weights(self, override_key: str | None = None):
329350 gate = None
330351 up = None
331352 down = None
332- gate_type = None
333- up_type = None
334- down_type = None
335353
336354 for key in keys :
337355 if key + ".ffn_gate_exps.weight" in self .gguf_loader .tensor_info :
338- gate = self .gguf_loader .load_gguf_tensor (key + ".ffn_gate_exps.weight" )
339- up = self .gguf_loader .load_gguf_tensor (key + ".ffn_up_exps.weight" )
340- down = self .gguf_loader .load_gguf_tensor (key + ".ffn_down_exps.weight" )
341- gate_type = self .gguf_loader .tensor_info [key + ".ffn_gate_exps.weight" ]["ggml_type" ]
342- up_type = self .gguf_loader .tensor_info [key + ".ffn_up_exps.weight" ]["ggml_type" ]
343- down_type = self .gguf_loader .tensor_info [key + ".ffn_down_exps.weight" ]["ggml_type" ]
344- # tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
345- res = {key :{"gate" : nn .Parameter (gate ), "up" : nn .Parameter (up ), "down" : nn .Parameter (down ), "gate_type" : gate_type , "up_type" : up_type , "down_type" : down_type }}
356+ gate = self .gguf_loader .get_mmap_tensor (key + ".ffn_gate_exps.weight" )
357+ up = self .gguf_loader .get_mmap_tensor (key + ".ffn_up_exps.weight" )
358+ down = self .gguf_loader .get_mmap_tensor (key + ".ffn_down_exps.weight" )
359+ res = {"gate" : gate , "up" : up , "down" : down }
346360 return res
347361
348362 def forward (self , hidden_states_cpu : torch .Tensor , selected_experts_cpu : torch .Tensor , routing_weights_cpu : torch .Tensor ) -> torch .Tensor :
@@ -381,6 +395,7 @@ def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.T
381395
382396 return final_hidden_states .to (dtype = org_dtype , device = org_device )
383397
398+ # untested, CUDA OOM
384399class KExpertsTorch (KExpertsBase ):
385400 expert_num : int
386401 loaded_experts_idx : list [int ]
@@ -402,26 +417,65 @@ def __init__(
402417 # self.loaded_experts_idx = []
403418 self .act_fn = ACT2FN [config .hidden_act ]
404419 self .device = device
405- self .gate = None
406- self .up = None
407- self .donw = None
420+ self .elements_per_tensor = config .moe_intermediate_size * config .hidden_size
421+ self .gate = [None for _ in range (self .expert_num )]
422+ self .up = [None for _ in range (self .expert_num )]
423+ self .down = [None for _ in range (self .expert_num )]
408424 self .dtype = torch .get_default_dtype ()
409425
410426 def load (self , w : dict | nn .Parameter | tuple | None = None , device : str | None = None , warmup : bool = False ):
411427 if device is None : device = self .device
412- if w is None : w = self .load_weights (device = device )[self .key ]
413-
414- if isinstance (w , dict ):
415- self .gate = w ["gate" ].to (device = device , dtype = self .dtype )
416- self .up = w ["up" ].to (device = device , dtype = self .dtype )
417- self .down = w ["down" ].to (device = device , dtype = self .dtype )
428+ if w is None :
429+ w = self .load_weights ()
430+ load_by_experts = True
431+
432+ if load_by_experts :
433+ if isinstance (w , dict ):
434+ for i in tqdm (range (self .expert_num ), desc = f"Dequanting for KExpertsTorch { self .key } " ):
435+ up_weights = self .gguf_loader .load_expert_tensor (self .key + ".ffn_up_exps.weight" , w ["up" ], i , self .elements_per_tensor , device = self .device )
436+ gate_weights = self .gguf_loader .load_expert_tensor (self .key + ".ffn_gate_exps.weight" , w ["gate" ], i , self .elements_per_tensor , device = self .device )
437+ down_weights = self .gguf_loader .load_expert_tensor (self .key + ".ffn_down_exps.weight" , w ["down" ], i , self .elements_per_tensor , device = self .device )
438+
439+ self .up [i ] = up_weights
440+ self .gate [i ] = gate_weights
441+ self .down [i ] = down_weights
442+ else :
443+ if isinstance (w , dict ):
444+ for i in range (self .expert_num ):
445+ self .gate [i ] = w ["gate" ][i , ...].to (device = device , dtype = self .dtype )
446+ self .up [i ] = w ["up" ][i , ...].to (device = device , dtype = self .dtype )
447+ self .down [i ] = w ["down" ][i , ...].to (device = device , dtype = self .dtype )
448+
449+ self .up = torch .cat (self .gate , dim = 0 )
450+ self .gate = torch .cat (self .gate , dim = 0 )
451+ self .down = torch .cat (self .gate , dim = 0 )
452+ return
418453
419454 def unload (self ):
420455 if self .gate is not None :
421456 self .gate = None
422457 self .up = None
423458 self .down = None
424459
460+ def load_weights (self , override_key : str | None = None ):
461+ res = {}
462+ if override_key is not None :
463+ keys = override_key
464+ else :
465+ keys = [self .key ]
466+
467+ gate = None
468+ up = None
469+ down = None
470+
471+ for key in keys :
472+ if key + ".ffn_gate_exps.weight" in self .gguf_loader .tensor_info :
473+ gate = self .gguf_loader .get_mmap_tensor (key + ".ffn_gate_exps.weight" )
474+ up = self .gguf_loader .get_mmap_tensor (key + ".ffn_up_exps.weight" )
475+ down = self .gguf_loader .get_mmap_tensor (key + ".ffn_down_exps.weight" )
476+ res = {"gate" : gate , "up" : up , "down" : down }
477+ return res
478+
425479 def forward (self , hidden_states_cpu : torch .Tensor , selected_experts_cpu : torch .Tensor , routing_weights_cpu : torch .Tensor ) -> torch .Tensor :
426480
427481 org_device = hidden_states_cpu .device
@@ -582,7 +636,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
582636
583637 if isinstance (self .experts , KExpertsBase ):
584638 y = (
585- self .moe_on_cpuinfer (
639+ self .moe_kexperts (
586640 hidden_states_expert , selected_experts_expert , routing_weights_expert
587641 )
588642 .view (* orig_shape )
@@ -601,8 +655,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
601655 return y , router_logits
602656
603657 @torch .no_grad ()
604- def moe_on_cpuinfer (self , x : torch .Tensor , topk_ids : torch .Tensor , topk_weight : torch .Tensor ) -> torch .Tensor :
605- outs = torch .empty_like (x )
658+ def moe_kexperts (self , x : torch .Tensor , topk_ids : torch .Tensor , topk_weight : torch .Tensor ) -> torch .Tensor :
606659 outs = self .experts (x , topk_ids , topk_weight )
607660 return outs
608661
@@ -672,7 +725,7 @@ def forward(self, hidden_states):
672725 y_ = self .shared_experts (identity ).squeeze (0 )
673726
674727 if isinstance (self .experts , KExpertsBase ):
675- y = self .moe_on_cpuinfer (hidden_states , topk_idx , topk_weight ).view (* orig_shape ).to (device = hidden_states .device )
728+ y = self .moe_kexperts (hidden_states , topk_idx , topk_weight ).view (* orig_shape ).to (device = hidden_states .device )
676729 elif hidden_states .size (0 ) > 10 :
677730 # TODO may bugs here
678731 y = (
@@ -692,8 +745,7 @@ def forward(self, hidden_states):
692745 return y
693746
694747 @torch .no_grad ()
695- def moe_on_cpuinfer (self , x : torch .Tensor , topk_ids : torch .Tensor , topk_weight : torch .Tensor ) -> torch .Tensor :
696- outs = torch .empty_like (x )
748+ def moe_kexperts (self , x : torch .Tensor , topk_ids : torch .Tensor , topk_weight : torch .Tensor ) -> torch .Tensor :
697749 outs = self .experts (x , topk_ids , topk_weight )
698750 return outs
699751
@@ -773,7 +825,7 @@ def forward(self, hidden_states):
773825 y_ = self .shared_experts (identity ).squeeze (0 )
774826
775827 if isinstance (self .experts , KExpertsBase ):
776- y = self .moe_on_cpuinfer (hidden_states , topk_idx , topk_weight ).view (* orig_shape ).to (device = hidden_states .device )
828+ y = self .moe_kexperts (hidden_states , topk_idx , topk_weight ).view (* orig_shape ).to (device = hidden_states .device )
777829 elif hidden_states .size (0 ) > 10 :
778830 # TODO may bugs here
779831 y = (
@@ -793,8 +845,7 @@ def forward(self, hidden_states):
793845 return y
794846
795847 @torch .no_grad ()
796- def moe_on_cpuinfer (self , x : torch .Tensor , topk_ids : torch .Tensor , topk_weight : torch .Tensor ) -> torch .Tensor :
797- outs = torch .empty_like (x )
848+ def moe_kexperts (self , x : torch .Tensor , topk_ids : torch .Tensor , topk_weight : torch .Tensor ) -> torch .Tensor :
798849 outs = self .experts (x , topk_ids , topk_weight )
799850 return outs
800851
@@ -881,7 +932,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
881932
882933 if isinstance (self .experts , KExpertsBase ):
883934 y = (
884- self .moe_on_cpuinfer (
935+ self .moe_kexperts (
885936 hidden_states_expert , selected_experts_expert , routing_weights_expert
886937 )
887938 .view (* orig_shape )
@@ -900,8 +951,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
900951 return y , router_logits
901952
902953 @torch .no_grad ()
903- def moe_on_cpuinfer (self , x : torch .Tensor , topk_ids : torch .Tensor , topk_weight : torch .Tensor ) -> torch .Tensor :
904- outs = torch .empty_like (x )
954+ def moe_kexperts (self , x : torch .Tensor , topk_ids : torch .Tensor , topk_weight : torch .Tensor ) -> torch .Tensor :
905955 outs = self .experts (x , topk_ids , topk_weight )
906956 return outs
907957
0 commit comments