@@ -277,13 +277,54 @@ def prepare_tensor(tensor, dst_shape, *, force_transpose=False):
         if len(tensor.shape) != 1:
             print("attention same shape not transpose !!!!!!!!!!!!!!!!!!!!!!")
         return tensor
-    if len(tensor.shape) == 2 and paddle.transpose(tensor, perm=[1, 0]).contiguous().shape == dst_shape:
+
+    if len(tensor.shape) == 2:
+        num_experts, hidden_size = tensor.shape
+        assert hidden_size == dst_shape[0], f"Shape not match: {tensor.shape} {dst_shape}"
+        if num_experts != dst_shape[1]:
+            print(f"Slice weight: {tensor.shape} -> {dst_shape}")
+            tensor = tensor[:dst_shape[1]]
         return paddle.transpose(tensor, perm=[1, 0]).contiguous()

-    print("shape not match here")
+    if len(tensor.shape) == 1:
+        print(f"Slice weight: {tensor.shape} -> {dst_shape}")
+        tensor = tensor[:dst_shape[0]]
+        return tensor
+
+    print("Fatal: shape not match here:", tensor.shape, dst_shape)
     sys.exit()


+def hf_cache(path):
+    print('looking up:', path)
+    import os, time, subprocess
+    basename = 'lshrun_' + os.path.basename(path)
+    cache_path = os.path.join('/dev/shm', basename)
+    lock_path = cache_path + '.lock'
+
+    # Case 1: cache exists
+    if os.path.exists(cache_path):
+        print('hit cache:', cache_path)
+        return cache_path
+
+    try:
+        open(lock_path, 'x')
+    except FileExistsError:
+        # Case 2: a peer is loading
+        print('waiting peer load:', lock_path)
+        while os.path.exists(lock_path):
+            time.sleep(0.1)
+        print('peer done:', lock_path)
+    else:
+        # Case 3: load it ourselves
+        print('copying:', lock_path)
+        subprocess.run(['cp', path, lock_path], check=True)
+        subprocess.run(['mv', lock_path, cache_path], check=True)
+        print('done copy:', lock_path)
+
+    return cache_path
+
+
 def load_huggingface_ckpt(model, huggingface_ckpt_path):
     ckpt_pre = huggingface_ckpt_path

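Below is a minimal standalone sketch (not part of this commit) of what the new 2-D branch of prepare_tensor does, assuming the checkpoint stores a routing weight as [num_experts, hidden_size] while the sharded model only keeps the leading dst_shape[1] experts; the sizes are made up for illustration:

import paddle

tensor = paddle.ones([64, 7168])   # checkpoint layout: [num_experts, hidden_size] (made-up sizes)
dst_shape = [7168, 16]             # model expects [hidden_size, experts_kept_on_this_rank]

assert tensor.shape[1] == dst_shape[0]
if tensor.shape[0] != dst_shape[1]:
    tensor = tensor[:dst_shape[1]]                        # drop the experts this rank does not hold
out = paddle.transpose(tensor, perm=[1, 0]).contiguous()  # same transpose as prepare_tensor
assert out.shape == dst_shape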
@@ -328,8 +369,9 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
     check_list = []
     print("Start load huggingface ckpt")
     for i, filename in enumerate(required_files):
+        print(f'loading {i + 1}/{len(required_files)}: {filename}')
         try:
-            with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f:
+            with safe_open(hf_cache(ckpt_pre + filename), framework="paddle", device="cpu") as f:
                 # Load all parameters contained in this file
                 pd_params = file_to_pd_param_name[filename]
                 for pd_param in pd_params:
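For reference, a sketch of the intended call pattern for hf_cache wrapped around safe_open: the first call on a node copies the shard into /dev/shm, later calls (including from peer ranks) reuse the cached copy. The shard name below is hypothetical:

from safetensors import safe_open

ckpt_pre = "/path/to/ckpt/"                              # as passed to load_huggingface_ckpt
shard = ckpt_pre + "model-00001-of-00163.safetensors"    # hypothetical shard name
local = hf_cache(shard)                                  # copies into /dev/shm on first use
assert local == "/dev/shm/lshrun_model-00001-of-00163.safetensors"
assert hf_cache(shard) == local                          # second lookup is a cache hit

with safe_open(local, framework="paddle", device="cpu") as f:
    tensor_names = list(f.keys())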
@@ -359,12 +401,12 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
                     if weight_map[hf_name[0]] == filename:
                         tensor0 = f.get_tensor(hf_name[0])
                         with safe_open(
-                            ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu"
+                            hf_cache(ckpt_pre + weight_map[hf_name[1]]), framework="paddle", device="cpu"
                         ) as f_other:
                             tensor1 = f_other.get_tensor(hf_name[1])
                     else:
                         with safe_open(
-                            ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu"
+                            hf_cache(ckpt_pre + weight_map[hf_name[0]]), framework="paddle", device="cpu"
                         ) as f_other:
                             tensor0 = f_other.get_tensor(hf_name[0])
                             tensor1 = f.get_tensor(hf_name[1])
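The branch above relies on the standard sharded-checkpoint index: weight_map maps each HF tensor name to the shard file holding it, so when the two halves of a fused parameter live in different shards, the shard that is not already open is also routed through hf_cache. A sketch with hypothetical tensor names:

import json

ckpt_pre = "/path/to/ckpt/"                               # hypothetical checkpoint directory
with open(ckpt_pre + "model.safetensors.index.json") as fh:
    weight_map = json.load(fh)["weight_map"]              # tensor name -> shard filename

# hypothetical fused pair split across two shards
hf_name = ["model.layers.0.mlp.gate_proj.weight", "model.layers.0.mlp.up_proj.weight"]
shard0, shard1 = weight_map[hf_name[0]], weight_map[hf_name[1]]
if shard0 != shard1:
    # only the second shard needs another hf_cache/safe_open round trip
    other = hf_cache(ckpt_pre + shard1)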
@@ -376,3 +418,4 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
         except Exception as e:
             print(f"Error loading {filename}: {str(e)}")
             raise
+    print("End load huggingface ckpt")
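Finally, a small standalone check (not part of the commit) of the /dev/shm locking behavior: two processes resolve the same source file concurrently, one takes the .lock and copies, the other waits for the lock to disappear, and both end up with the same cached path. The module name is hypothetical:

import multiprocessing as mp
import os
import tempfile

def worker(path, queue):
    from hf_ckpt_loader import hf_cache   # hypothetical module holding hf_cache
    queue.put(hf_cache(path))

if __name__ == "__main__":
    src = tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False)
    src.write(b"dummy shard bytes")
    src.close()

    queue = mp.Queue()
    procs = [mp.Process(target=worker, args=(src.name, queue)) for _ in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

    # both workers should agree on the same /dev/shm copy
    assert queue.get() == queue.get()
    os.unlink(src.name)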