Skip to content

Commit 684909d

Browse files
committed
apply pr microsoft#20
1 parent 35f54d6 commit 684909d

File tree

5 files changed

+89
-9
lines changed

5 files changed

+89
-9
lines changed

RepoCoder/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ This project contains the basic components of RepoCoder. Here is an overview:
2222
|-- build_prompt.py # build the prompt with the unfinished code and the retrieved code snippets
2323
|-- run_pipeline.py # run the code completion pipeline
2424
|-- compute_score.py # evaluate the performance of the code completion
25+
|-- codegen_inference.py # an example script for using CodeGen to generate code completions
2526
|-- utils.py # utility functions
2627
|-- datasets/datasets.zip # the input data for the code completion task
2728
|-- function_level_completion_4k_context_codex.test.jsonl

RepoCoder/codegen_inference.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import torch
2+
import tqdm
3+
import json
4+
from transformers import AutoModelForCausalLM, AutoTokenizer
5+
6+
7+
class Tools:
    """Small helpers for reading and writing JSON Lines files."""

    @staticmethod
    def load_jsonl(path):
        """Return the list of JSON objects stored one-per-line in *path*.

        Blank lines (for example a trailing empty line left by an editor)
        are skipped instead of being passed to ``json.loads``, which would
        raise on them. The file is decoded as UTF-8 regardless of the
        platform's default locale encoding.
        """
        with open(path, 'r', encoding='utf-8') as f:
            # Iterate the file object directly so the whole file is not
            # materialized as a list of raw lines first.
            return [json.loads(line) for line in f if line.strip()]

    @staticmethod
    def dump_jsonl(obj, path):
        """Write each item of the iterable *obj* to *path* as one JSON line."""
        with open(path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(line) + '\n' for line in obj)
18+
19+
20+
class CodeGen:
    """Batched greedy code completion with a Hugging Face causal LM.

    The model and tokenizer are loaded from ``model_name``. Prompts are read
    from a JSON Lines file, completed in batches of ``batch_size`` and the
    completions are written next to the input file.
    """

    def __init__(self, model_name, batch_size):
        """Load model and tokenizer for *model_name* and pick a device.

        Args:
            model_name: identifier passed to ``from_pretrained``.
            batch_size: number of prompts sent to the model at once.
        """
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # Pad on the left so that, within a batch, every prompt ends right
        # where the newly generated tokens begin.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        # Reuse the EOS token as the padding token so batches can be padded.
        self.tokenizer.add_special_tokens({'pad_token': self.tokenizer.eos_token})
        # Use the GPU when there is one; fall back to CPU instead of crashing.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.batch_size = batch_size
        print('done loading model')

    def _get_batchs(self, prompts, batch_size):
        """Split *prompts* into consecutive chunks of at most *batch_size*."""
        return [prompts[i:i + batch_size] for i in range(0, len(prompts), batch_size)]

    def _generate_batch(self, prompt_batch, max_new_tokens=100):
        """Return the greedy completion text for every prompt in the batch.

        Args:
            prompt_batch: list of prompt strings.
            max_new_tokens: upper bound on generated tokens per prompt.

        Returns:
            A list of strings, one per prompt, holding only the newly
            generated text (the prompt itself is not included).
        """
        encoded = self.tokenizer(prompt_batch, return_tensors='pt', padding=True, truncation=True)
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)

        with torch.no_grad():
            gen_tokens = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=False,
                max_new_tokens=max_new_tokens,
                # Pass the padding id explicitly rather than relying on
                # generate() to infer it from the EOS token.
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Drop the prompt at the token level. Every row of the padded batch
        # has the same input length, so one slice removes all prompts.
        # Slicing the decoded text by len(prompt) characters is fragile: it
        # cuts at the wrong place whenever the prompt was truncated or the
        # decoded text does not reproduce the prompt character-for-character.
        new_tokens = gen_tokens[:, input_ids.shape[1]:]
        return self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

    def _output_path(self, file):
        """Return the output path: *file* with ``_<model>`` before ``.jsonl``.

        Only a trailing ``.jsonl`` is treated as the extension. When *file*
        has a different suffix, ``_<model>.jsonl`` is appended, so the result
        can never be the input path itself.
        """
        model_tag = self.model_name.split("/")[-1]
        ext = '.jsonl'
        stem = file[:-len(ext)] if file.endswith(ext) else file
        return f'{stem}_{model_tag}{ext}'

    def batch_generate(self, file, max_new_tokens=100):
        """Complete every prompt in the JSONL *file* and save the results.

        Each input record must provide ``prompt`` and ``metadata`` keys. The
        output records keep both and add ``choices`` with the generated text.

        Args:
            file: path of the input JSON Lines file.
            max_new_tokens: upper bound on generated tokens per prompt.

        Raises:
            RuntimeError: if the number of completions differs from the
                number of prompts.
        """
        print(f'generating from {file}')
        lines = Tools.load_jsonl(file)
        # have a new line at the end
        prompts = [f"{line['prompt']}\n" for line in lines]
        gen_text = []
        for batch in tqdm.tqdm(self._get_batchs(prompts, self.batch_size)):
            gen_text.extend(self._generate_batch(batch, max_new_tokens=max_new_tokens))
        print(f'generated {len(gen_text)} samples')
        # Explicit check (not ``assert``) so it still runs under ``python -O``
        # and zip() below can never silently drop records.
        if len(gen_text) != len(prompts):
            raise RuntimeError(
                f'expected {len(prompts)} completions but got {len(gen_text)}'
            )
        new_lines = [
            {
                'prompt': line['prompt'],
                'metadata': line['metadata'],
                'choices': [{'text': gen}],
            }
            for line, gen in zip(lines, gen_text)
        ]
        Tools.dump_jsonl(new_lines, self._output_path(file))
70+
71+
72+
if __name__ == '__main__':
    # Example run: load the 350M "mono" CodeGen checkpoint and complete the
    # prompts of the line-level, 1k-context CodeGen benchmark file in
    # batches of eight.
    CodeGen('Salesforce/codegen-350M-mono', batch_size=8).batch_generate(
        'datasets/line_level_completion_1k_context_codegen.test.jsonl'
    )

RepoCoder/make_window.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def build_window(self):
138138
}
139139
})
140140
print(f'build {len(code_windows)} ground truth windows for {self.repo} with window size {self.window_size}')
141-
output_path = FilePathBuilder.search_first_window_path(self.benchmark, CONSTANTS.rg, self.repo, self.window_size)
141+
output_path = FilePathBuilder.search_first_window_path(self.benchmark, CONSTANTS.gt, self.repo, self.window_size)
142142
Tools.dump_pickle(code_windows, output_path)
143143

144144
class PredictionWindowMaker:

RepoCoder/run_pipeline.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212
from utils import CONSTANTS, CodexTokenizer
1313

1414
def make_repo_window(repos, window_sizes, slice_sizes):
15-
worker = MakeWindowWrapper(None, repos, window_sizes, slice_sizes)
16-
worker.window_for_repo_files()
15+
MakeWindowWrapper(None, repos, window_sizes, slice_sizes).window_for_repo_files()
16+
vectorizer = BagOfWords
17+
BuildVectorWrapper(None, vectorizer, repos, window_sizes, slice_sizes).vectorize_repo_windows()
1718

1819

1920
def run_RG1_and_oracle_method(benchmark, repos, window_sizes, slice_sizes):
20-
# build code snippets for all the repositories
21-
make_repo_window(repos, window_sizes, slice_sizes)
2221
# build code snippets for vanilla retrieval-augmented approach and ground truth
2322
MakeWindowWrapper(benchmark, repos, window_sizes, slice_sizes).window_for_baseline_and_ground()
2423
# build vector for vanilla retrieval-augmented approach and ground truth
@@ -62,6 +61,9 @@ def run_RepoCoder_method(benchmark, repos, window_sizes, slice_sizes, prediction
6261
window_sizes = [20]
6362
slice_sizes = [2] # 20 / 2 = 10
6463

64+
# build window for the repos
65+
make_repo_window(repos, window_sizes, slice_sizes)
66+
6567
# build prompt for the RG1 and oracle methods
6668
run_RG1_and_oracle_method(CONSTANTS.api_benchmark, repos, window_sizes, slice_sizes)
6769

RepoCoder/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ class CONSTANTS:
2020
rgrg = 'r-g-r-g' # RepoCoder, two-stage retrieval and generation
2121

2222
class FilePathBuilder:
23-
api_completion_benchmark = 'datasets/random-api-completion.test.jsonl'
24-
random_line_completion_benchmark = 'datasets/random-line-completion.test.jsonl'
23+
api_completion_benchmark = 'datasets/api_level_completion_2k_context_codex.test.jsonl'
24+
random_line_completion_benchmark = 'datasets/line_level_completion_2k_context_codex.test.jsonl'
2525
# short version for codegen
26-
short_api_completion_benchmark = 'datasets/random-api-completion-short-version.test.jsonl'
27-
short_random_line_completion_benchmark = 'datasets/random-line-completion-short-version.test.jsonl'
26+
short_api_completion_benchmark = 'datasets/api_level_completion_1k_context_codegen.test.jsonl'
27+
short_random_line_completion_benchmark = 'datasets/line_level_completion_1k_context_codegen.test.jsonl'
2828
repo_base_dir = 'repositories/line_and_api_level'
2929

3030
@staticmethod

0 commit comments

Comments
 (0)