@@ -277,13 +277,56 @@ def prepare_tensor(tensor, dst_shape, *, force_transpose=False):
277277 if len (tensor .shape ) != 1 :
278278 print ("attention same shape not transpose !!!!!!!!!!!!!!!!!!!!!!" )
279279 return tensor
280- if len (tensor .shape ) == 2 and paddle .transpose (tensor , perm = [1 , 0 ]).contiguous ().shape == dst_shape :
280+
281+ if len (tensor .shape ) == 2 :
282+ num_experts , hidden_size = tensor .shape
283+ assert hidden_size == dst_shape [0 ], f"Shape not match: { tensor .shape } { dst_shape } "
284+ if num_experts != dst_shape [1 ]:
285+ print (f"Slice weight: { tensor .shape } -> { dst_shape } " )
286+ tensor = tensor [:dst_shape [1 ]]
281287 return paddle .transpose (tensor , perm = [1 , 0 ]).contiguous ()
282288
283- print ("shape not match here" )
289+ if len (tensor .shape ) == 1 :
290+ print (f"Slice weight: { tensor .shape } -> { dst_shape } " )
291+ tensor = tensor [:dst_shape [0 ]]
292+ return tensor
293+
294+ print ("Fatal: shape not match here:" , tensor .shape , dst_shape )
284295 sys .exit ()
285296
286297
def hf_cache(path, cache_dir='/dev/shm'):
    """Copy *path* into a node-local cache directory exactly once per machine.

    Many ranks on one host may call this concurrently for the same checkpoint
    shard; a ``.lock`` file created with O_EXCL elects a single copier, and the
    other ranks wait for the lock to disappear.  The copy is written to the
    lock file and atomically renamed into place, so ``cache_path`` is only ever
    observed complete.

    Args:
        path: source file to cache (e.g. a safetensors shard).
        cache_dir: destination directory; defaults to ``/dev/shm`` (tmpfs).

    Returns:
        The path of the cached copy inside ``cache_dir``.

    Raises:
        OSError: if the copy keeps failing after repeated retries.
    """
    import os
    import shutil
    import time

    print('looking up:', path)
    cache_path = os.path.join(cache_dir, 'lshrun_' + os.path.basename(path))
    lock_path = cache_path + '.lock'

    # Case 1: cache already populated (by us or by a peer).
    if os.path.exists(cache_path):
        print('hit cache:', cache_path)
        return cache_path

    try:
        # O_CREAT|O_EXCL is atomic: exactly one process wins the lock.
        # (The original used open(lock_path, 'x') and leaked the handle.)
        fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        os.close(fd)
    except FileExistsError:
        # Case 2: a peer holds the lock and is copying; poll until it is done.
        print('waiting peer load:', lock_path)
        while os.path.exists(lock_path):
            time.sleep(0.1)
        print('peer done:', lock_path)
        if not os.path.exists(cache_path):
            # Peer crashed after removing its lock without publishing the
            # cache file -> retry from scratch (we may win the lock this time).
            return hf_cache(path, cache_dir)
    else:
        # Case 3: we won the lock; copy, then atomically publish.
        print('copying:', lock_path)
        try:
            for _ in range(100):
                try:
                    shutil.copyfile(path, lock_path)
                    break
                except OSError:
                    # e.g. "too many open files" under heavy parallel load;
                    # back off and retry, but do not loop forever.
                    print('retrying:', path, '->', lock_path)
                    time.sleep(10)
            else:
                raise OSError(f'failed to copy {path} after repeated attempts')
            # Atomic rename both publishes the cache and releases the lock.
            os.replace(lock_path, cache_path)
        except BaseException:
            # Never leave a stale lock behind: it would deadlock every waiter.
            if os.path.exists(lock_path):
                os.remove(lock_path)
            raise
        print('done copy:', lock_path)

    return cache_path
328+
329+
287330def load_huggingface_ckpt (model , huggingface_ckpt_path ):
288331 ckpt_pre = huggingface_ckpt_path
289332
@@ -328,8 +371,9 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
328371 check_list = []
329372 print ("Start load huggingface ckpt" )
330373 for i , filename in enumerate (required_files ):
374+ print (f'loading { i + 1 } /{ len (required_files )} : { filename } ' )
331375 try :
332- with safe_open (ckpt_pre + filename , framework = "paddle" , device = "cpu" ) as f :
376+ with safe_open (hf_cache ( ckpt_pre + filename ) , framework = "paddle" , device = "cpu" ) as f :
333377 # 加载该文件包含的所有参数
334378 pd_params = file_to_pd_param_name [filename ]
335379 for pd_param in pd_params :
@@ -359,12 +403,12 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
359403 if weight_map [hf_name [0 ]] == filename :
360404 tensor0 = f .get_tensor (hf_name [0 ])
361405 with safe_open (
362- ckpt_pre + weight_map [hf_name [1 ]], framework = "paddle" , device = "cpu"
406+ hf_cache ( ckpt_pre + weight_map [hf_name [1 ]]) , framework = "paddle" , device = "cpu"
363407 ) as f_other :
364408 tensor1 = f_other .get_tensor (hf_name [1 ])
365409 else :
366410 with safe_open (
367- ckpt_pre + weight_map [hf_name [0 ]], framework = "paddle" , device = "cpu"
411+ hf_cache ( ckpt_pre + weight_map [hf_name [0 ]]) , framework = "paddle" , device = "cpu"
368412 ) as f_other :
369413 tensor0 = f_other .get_tensor (hf_name [0 ])
370414 tensor1 = f .get_tensor (hf_name [1 ])
@@ -376,3 +420,4 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
376420 except Exception as e :
377421 print (f"Error loading { filename } : { str (e )} " )
378422 raise
423+ print ("End load huggingface ckpt" )
0 commit comments