@@ -95,6 +95,7 @@ def __init__(self, model_config_path):
         self.mlp_method = self.MLPMethod.NORMAL  # Currently no benefit to fused MLP
         self.device_map = ExLlamaDeviceMap(self.num_hidden_layers)
         self.auto_map = None  # List of ints with memory allocation in GB, per CUDA device, overrides device_map
+        self.dequant = None  # Number of layers (per GPU) to de-quantize at load time
 
 
     # Parse and set list of GPU VRAM allocations
@@ -105,6 +106,14 @@ def set_auto_map(self, map_string):
         else: self.auto_map = [float(alloc) for alloc in map_string.split(",")]
 
 
+    # Parse and set number of layers to de-quantize at load, per GPU
+
+    def set_dequant(self, dq_string):
+
+        if dq_string is None: self.dequant = None
+        else: self.dequant = [int(alloc) for alloc in dq_string.split(",")]
+
+
 def _dump_tensor(t, name):
 
     if t is None:
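
Note on set_dequant() above: the comma-separated string supplies one integer per CUDA device, in device order, mirroring set_auto_map(). A minimal sketch of the parsing, with a hypothetical input string (the "4,0" value is illustrative, not from this diff):

    dq_string = "4,0"   # e.g. de-quantize 4 layers on cuda:0, none on cuda:1
    dequant = None if dq_string is None else [int(alloc) for alloc in dq_string.split(",")]
    assert dequant == [4, 0]
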
@@ -146,11 +155,12 @@ def _mlp_switch(config, x):
 
 class Ex4bitLinear(nn.Module):
 
-    def __init__(self, config, in_features, out_features, has_bias, tensors, key):
+    def __init__(self, config, in_features, out_features, has_bias, tensors, key, dequant = False):
         super().__init__()
 
         self.config = config
         self.key = key
+        self.dequant = dequant
 
         self.in_features = in_features
         self.out_features = out_features
@@ -210,6 +220,17 @@ def __init__(self, config, in_features, out_features, has_bias, tensors, key):
 
         if has_bias: self.bias = tensors[key + ".bias"]
 
+        # Optionally dequantize layer at init time
+
+        if self.dequant:
+
+            self.qweight_dequant = cuda_ext.dequantize_q4v2(self.quant_args())
+            self.qweight = None
+            self.scales = None
+            self.zeros = None
+            self.seq_g_idx = None
+            self.x_map = None
+
 
     def quant_args(self):
 
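
For context, a toy sketch of what group-wise 4-bit de-quantization does conceptually: w = (q - zero) * scale per group. This is not the cuda_ext.dequantize_q4v2 kernel or its packed int32 layout, just an illustration with unpacked codes and assumed toy shapes:

    import torch

    groupsize = 4
    q = torch.randint(0, 16, (8, 16))              # already-unpacked 4-bit codes, (in_features, out_features)
    scales = torch.rand(8 // groupsize, 16)        # one scale per group per output column
    zeros = torch.full((8 // groupsize, 16), 8)    # one zero-point per group per output column
    w = (q - zeros.repeat_interleave(groupsize, dim = 0)) * scales.repeat_interleave(groupsize, dim = 0)
    w = w.to(torch.float16)                        # the fp16 copy plays the role of qweight_dequant

Once this copy exists, the packed tensors can be dropped, as the diff does by setting them to None.
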
@@ -268,20 +289,26 @@ def load_streaming(self):
 
     def forward(self, x):
 
-        if torch.is_grad_enabled():
+        if self.dequant:
 
-            # Untested
-            out = cuda_ext.ExAutogradMatmul4bitCuda.apply(x, self.qweight, self.scales, self.qzeros, self.groupsize, self.bits, self.maxq)
+            out = torch.matmul(x, self.qweight_dequant)
 
         else:
 
-            out = cuda_ext.matmul_q4v2(x, self.quant_args(), _matmul_switch(self.config, x))
-            if self.bias is not None: out += self.bias
+            if torch.is_grad_enabled():
 
-            # if self.key == "model.layers.0.mlp.gate_proj":
-            #
-            #     _dump_tensor(x, "cuda_test/model.layers.0.mlp.gate_proj.x")
-            #     sys.exit()
+                # Untested
+                out = cuda_ext.ExAutogradMatmul4bitCuda.apply(x, self.qweight, self.scales, self.qzeros, self.groupsize, self.bits, self.maxq)
+
+            else:
+
+                out = cuda_ext.matmul_q4v2(x, self.quant_args(), _matmul_switch(self.config, x))
+                if self.bias is not None: out += self.bias
+
+        # if self.key == "model.layers.0.mlp.gate_proj":
+        #
+        #     _dump_tensor(x, "cuda_test/model.layers.0.mlp.gate_proj.x")
+        #     sys.exit()
 
         return out
 
@@ -300,20 +327,22 @@ def dump(self, filename):
 
 class ExLlamaMLP(nn.Module):
 
-    def __init__(self, config, tensors, key):
+    def __init__(self, config, tensors, key, dequant = False):
         super().__init__()
 
         self.config = config
+        self.dequant = dequant
 
-        self.gate_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".gate_proj")
-        self.up_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".up_proj")
-        self.down_proj = Ex4bitLinear(config, self.config.intermediate_size, self.config.hidden_size, False, tensors, key + ".down_proj")
+        self.gate_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".gate_proj", dequant = dequant)
+        self.up_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.intermediate_size, False, tensors, key + ".up_proj", dequant = dequant)
+        self.down_proj = Ex4bitLinear(config, self.config.intermediate_size, self.config.hidden_size, False, tensors, key + ".down_proj", dequant = dequant)
 
         self.act_fn = nn.SiLU()
 
 
     def forward_fused(self, x, rms_norm_weight, buffer):
 
+        assert not self.dequant
         x = cuda_ext.mlp_q4v2(x,
                               buffer.x_temp,
                               buffer.x_col_temp,
@@ -367,18 +396,18 @@ def forward(self, hidden_states, buffer):
 
 class ExLlamaAttention(nn.Module):
 
-    def __init__(self, config, tensors, key, sin, cos, index):
+    def __init__(self, config, tensors, key, sin, cos, index, dequant = False):
         super().__init__()
 
         self.config = config
         self.sin = sin
         self.cos = cos
         self.index = index
 
-        self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj")
-        self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".k_proj")
-        self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".v_proj")
-        self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj")
+        self.q_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".q_proj", dequant = dequant)
+        self.k_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".k_proj", dequant = dequant)
+        self.v_proj = Ex4bitLinear(config, self.config.hidden_size, self.config.num_attention_heads * self.config.head_dim, False, tensors, key + ".v_proj", dequant = dequant)
+        self.o_proj = Ex4bitLinear(config, self.config.num_attention_heads * self.config.head_dim, self.config.hidden_size, False, tensors, key + ".o_proj", dequant = dequant)
 
 
     def forward(self, hidden_states, cache, buffer):
@@ -467,14 +496,14 @@ def rotate_half(x):
 
 class ExLlamaDecoderLayer(nn.Module):
 
-    def __init__(self, config, tensors, key, index, sin, cos):
+    def __init__(self, config, tensors, key, index, sin, cos, dequant = False):
         super().__init__()
 
         self.config = config
         self.index = index
 
-        self.self_attn = ExLlamaAttention(self.config, tensors, key + ".self_attn", sin, cos, self.index)
-        self.mlp = ExLlamaMLP(self.config, tensors, key + ".mlp")
+        self.self_attn = ExLlamaAttention(self.config, tensors, key + ".self_attn", sin, cos, self.index, dequant = dequant)
+        self.mlp = ExLlamaMLP(self.config, tensors, key + ".mlp", dequant = dequant)
 
         self.input_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".input_layernorm.weight")
         self.post_attention_layernorm = ExLlamaRMSNorm(self.config, tensors, key + ".post_attention_layernorm.weight")
@@ -487,7 +516,9 @@ def forward(self, hidden_states, cache, buffer):
         hidden_states = self.self_attn(hidden_states, cache, buffer)
         hidden_states = residual + hidden_states
 
-        if _mlp_switch(self.config, hidden_states):
+        # TODO: Support dequantized layer in fused MLP. Also, finish implementing fused MLP
+
+        if self.mlp.dequant or _mlp_switch(self.config, hidden_states):
 
             residual = hidden_states
             hidden_states = self.post_attention_layernorm(hidden_states, buffer)
@@ -741,6 +772,17 @@ def to(self, device):
         return new
 
 
+def _device_to_int(device):
+
+    return int(device[device.find(":") + 1:])
+
+def _skip_key(key):
+
+    if key.endswith("_proj.bias"): return True
+    if key.endswith(".rotary_emb.inv_freq"): return True
+    return False
+
+
 class ExLlama(nn.Module):
 
     def __init__(self, config):
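
An illustrative sanity check of the two helpers added above (assuming both are in scope; the key strings are typical LLaMA tensor names, not taken from this diff):

    assert _device_to_int("cuda:1") == 1
    assert _skip_key("model.layers.0.self_attn.q_proj.bias") is True        # unused, empty bias tensors are skipped
    assert _skip_key("model.layers.0.self_attn.q_proj.qweight") is False    # quantized weights are loaded
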
@@ -762,8 +804,10 @@ def __init__(self, config):
            # Begin auto mapping if enabled
 
            decoder_size = 0
+           decoder_dq_size = 0
            norm_size = 0
            head_size = 0
+           half_element_size = torch.tensor([], dtype = torch.float16).element_size()
 
            if self.config.auto_map is not None:
 
@@ -772,9 +816,15 @@ def __init__(self, config):
 
                for key in f.keys():
 
+                   if _skip_key(key): continue
+
                    if key.startswith("model.layers.0."):
                        tensor = f.get_tensor(key)
                        decoder_size += tensor.numel() * tensor.element_size()
+                       if key.endswith(".weight"):
+                           decoder_dq_size += tensor.numel() * tensor.element_size()
+                       if key.endswith(".qweight"):
+                           decoder_dq_size += tensor.numel() * 8 * half_element_size
 
                    if key.startswith("model.norm."):
                        tensor = f.get_tensor(key)
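
The decoder_dq_size estimate above follows from the packing: each int32 element of a .qweight tensor holds eight 4-bit weights, so the de-quantized fp16 copy needs numel * 8 * half_element_size bytes. A worked example with an assumed 4096 x 4096 projection (the dimensions are illustrative):

    import torch

    half_element_size = torch.tensor([], dtype = torch.float16).element_size()   # 2 bytes
    qweight_numel = (4096 // 8) * 4096        # packed int32 elements for the assumed projection
    packed_bytes = qweight_numel * 4          # ~8 MiB as stored in the safetensors file
    dequant_bytes = qweight_numel * 8 * half_element_size                         # ~32 MiB once de-quantized to fp16
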
@@ -784,37 +834,40 @@ def __init__(self, config):
                        tensor = f.get_tensor(key)
                        head_size += tensor.numel() * tensor.element_size()
 
-               # Assign layers automatically
+               # Assign layers automatically
 
-               device_usage = 0
-               device_index = 0
-               max_usage = self.config.auto_map[device_index] * (1024 ** 3)
+               device_usage = 0
+               device_index = 0
+               layer_index_device = 0
+               max_usage = self.config.auto_map[device_index] * (1024 ** 3)
 
-               for layer in range(self.config.num_hidden_layers + 2):
+               for layer in range(self.config.num_hidden_layers + 2):
 
-                   this_layer_size = decoder_size
-                   if layer == self.config.num_hidden_layers + 0: this_layer_size = norm_size
-                   if layer == self.config.num_hidden_layers + 1: this_layer_size = head_size
+                   this_layer_size = decoder_size
+                   if layer == self.config.num_hidden_layers + 0: this_layer_size = norm_size
+                   elif layer == self.config.num_hidden_layers + 1: this_layer_size = head_size
+                   elif self.config.dequant is not None and layer_index_device < self.config.dequant[device_index]: this_layer_size = decoder_dq_size
 
-                   while device_usage + this_layer_size > max_usage:
-                       device_index += 1
-                       device_usage = 0
-                       max_usage = self.config.auto_map[device_index] * (1024 ** 3)
-                       if device_index >= len(self.config.auto_map): raise ValueError("Model too large for device allocation scheme.")
+                   while device_usage + this_layer_size > max_usage:
+                       device_index += 1
+                       device_usage = 0
+                       layer_index_device = 0
+                       max_usage = self.config.auto_map[device_index] * (1024 ** 3)
+                       if device_index >= len(self.config.auto_map): raise ValueError("Model too large for device allocation scheme.")
 
-                   target = f"cuda:{device_index}"
-                   if layer == self.config.num_hidden_layers + 0: self.config.device_map.norm = target
-                   elif layer == self.config.num_hidden_layers + 1: self.config.device_map.lm_head = target
-                   else: self.config.device_map.layers[layer] = f"cuda:{device_index}"
+                   target = f"cuda:{device_index}"
+                   if layer == self.config.num_hidden_layers + 0: self.config.device_map.norm = target
+                   elif layer == self.config.num_hidden_layers + 1: self.config.device_map.lm_head = target
+                   else: self.config.device_map.layers[layer] = f"cuda:{device_index}"
 
-                   device_usage += this_layer_size
+                   device_usage += this_layer_size
+                   layer_index_device += 1
 
-           # Load tensors to
+           # Load tensors, move to device(s)
 
            for key in f.keys():
 
-               if key.endswith("_proj.bias"): continue  # Skip loading unused, empty bias tensors
-               if key.endswith(".rotary_emb.inv_freq"): continue  # This is always precomputed during init anyway
+               if _skip_key(key): continue
 
               device = self.config.device_map.map(key, loading = True)
               tensor = f.get_tensor(key)
@@ -845,8 +898,10 @@ def __init__(self, config):
 
        # Prepare position embeddings for max seq length
 
+       devs = self.config.device_map.get_layers_devs()
+
        self.sincos = {}
-       for device in self.config.device_map.get_layers_devs():
+       for device in devs:
 
            inv_freq = 1.0 / (self.config.rotary_embedding_base ** (torch.arange(0, self.config.head_dim, 2, device = device).float() / self.config.head_dim))
            t = torch.arange(self.config.max_seq_len, device = device, dtype = torch.float32)
@@ -863,12 +918,21 @@ def __init__(self, config):
        layer_streaming = self.config.stream_layer_interval > 0
 
        modules = []
+       device_layer_index = [0] * len(devs)
+
        for i in range(self.config.num_hidden_layers):
 
            device = self.config.device_map.layers[i]
            sin, cos = self.sincos[device]
 
-           layer = ExLlamaDecoderLayer(self.config, tensors, f"model.layers.{i}", i, sin, cos)
+           dequant = False
+           if self.config.dequant is not None:
+               device_idx = _device_to_int(device)
+               device_layer = device_layer_index[device_idx]
+               device_layer_index[device_idx] += 1
+               if device_layer < self.config.dequant[device_idx]: dequant = True
+
+           layer = ExLlamaDecoderLayer(self.config, tensors, f"model.layers.{i}", i, sin, cos, dequant = dequant)
 
            if layer_streaming and i > 0 and (i + 1) % self.config.stream_layer_interval == 0:
                if self.stream_buffer is None: self.stream_buffer = ExLlamaStreamer(self.config, layer)  # Use first layer as prototype
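
A walk-through of the per-device counting above, using assumed inputs (an 8-layer model split 4/4 across two GPUs and a "2,0" dequant setting; none of these values come from the diff): the first config.dequant[d] layers placed on each device d get dequant = True, the rest stay quantized.

    device_map_layers = ["cuda:0"] * 4 + ["cuda:1"] * 4   # assumed split: layers 0-3 on cuda:0, 4-7 on cuda:1
    dequant_per_gpu = [2, 0]                               # e.g. config.dequant parsed from "2,0"
    device_layer_index = [0] * 2
    flags = []
    for device in device_map_layers:
        device_idx = int(device[device.find(":") + 1:])    # same logic as _device_to_int()
        flags.append(device_layer_index[device_idx] < dequant_per_gpu[device_idx])
        device_layer_index[device_idx] += 1
    assert flags == [True, True, False, False, False, False, False, False]
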