"""Batch greedy decoding with a CodeGen model over a JSONL prompt dataset."""
import json

import torch
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


class Tools:
    @staticmethod
    def load_jsonl(path):
        with open(path, 'r') as f:
            return [json.loads(line) for line in f]

    @staticmethod
    def dump_jsonl(obj, path):
        with open(path, 'w') as f:
            for line in obj:
                f.write(json.dumps(line) + '\n')


class CodeGen:
    def __init__(self, model_name, batch_size):
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # Left padding keeps each prompt flush against the generated tokens
        # when batching decoder-only inputs.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
        # CodeGen ships without a pad token; reuse EOS for padding.
        self.tokenizer.add_special_tokens({'pad_token': self.tokenizer.eos_token})
        self.model.cuda()
        self.batch_size = batch_size
        print('done loading model')

    def _get_batches(self, prompts, batch_size):
        batches = []
        for i in range(0, len(prompts), batch_size):
            batches.append(prompts[i:i + batch_size])
        return batches

    def _generate_batch(self, prompt_batch, max_new_tokens=100):
        prompts = self.tokenizer(prompt_batch, return_tensors='pt', padding=True, truncation=True)

        with torch.no_grad():
            gen_tokens = self.model.generate(
                input_ids=prompts['input_ids'].cuda(),
                attention_mask=prompts['attention_mask'].cuda(),
                do_sample=False,  # greedy decoding
                max_new_tokens=max_new_tokens,
            )
        gen_text = self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
        # The decoded output echoes the prompt; slice it off to keep only the completion.
        for i in range(len(gen_text)):
            gen_text[i] = gen_text[i][len(prompt_batch[i]):]
        return gen_text

    def batch_generate(self, file):
        print(f'generating from {file}')
        lines = Tools.load_jsonl(file)
        # append a newline so the model completes a fresh line
        prompts = [f"{line['prompt']}\n" for line in lines]
        batches = self._get_batches(prompts, self.batch_size)
        gen_text = []
        for batch in tqdm.tqdm(batches):
            gen_text.extend(self._generate_batch(batch))
        print(f'generated {len(gen_text)} samples')
        assert len(gen_text) == len(prompts)
        new_lines = []
        for line, gen in zip(lines, gen_text):
            new_lines.append({
                'prompt': line['prompt'],
                'metadata': line['metadata'],
                'choices': [{'text': gen}],
            })
        # write alongside the input file, tagged with the model name
        out_path = file.replace('.jsonl', f'_{self.model_name.split("/")[-1]}.jsonl')
        Tools.dump_jsonl(new_lines, out_path)


if __name__ == '__main__':
    file_path = 'datasets/line_level_completion_1k_context_codegen.test.jsonl'
    tiny_codegen = 'Salesforce/codegen-350M-mono'

    cg = CodeGen(tiny_codegen, batch_size=8)
    cg.batch_generate(file_path)
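
# A minimal sketch of the JSONL shapes this script works with, inferred from
# batch_generate (field values are illustrative, not taken from the dataset):
#   input line:  {"prompt": "def add(a, b):", "metadata": {...}}
#   output line: {"prompt": ..., "metadata": ..., "choices": [{"text": " return a + b"}]}
# The output lands next to the input file with the model name appended,
# e.g. '..._codegen-350M-mono.jsonl'.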