Skip to content

Commit 8de12fc

Browse files
authored
Merge pull request #238 from IINemo/official_em
Official em
2 parents 4da765b + 9b7cb3f commit 8de12fc

30 files changed

+159
-24
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
process_output_fn:
2+
path: instruct/output_processing_scripts/coqa.py
3+
fn_name: normalize_em_coqa
4+
process_target_fn:
5+
path: instruct/output_processing_scripts/coqa.py
6+
fn_name: normalize_em_coqa
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
process_output_fn:
2+
path: instruct/output_processing_scripts/triviaqa.py
3+
fn_name: normalize_em_triviaqa
4+
process_target_fn:
5+
path: instruct/output_processing_scripts/triviaqa.py
6+
fn_name: normalize_em_triviaqa

examples/configs/instruct/cot_processing.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ process_output_fn:
33
fn_name: process_output_cot
44
process_target_fn:
55
path: output_processing_scripts/default.py
6-
fn_name: process_target
6+
fn_name: normalize_text
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
process_output_fn:
2+
path: output_processing_scripts/coqa.py
3+
fn_name: process_output_cot_coqa
4+
process_target_fn:
5+
path: output_processing_scripts/coqa.py
6+
fn_name: normalize_em_coqa
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
process_output_fn:
2+
path: output_processing_scripts/triviaqa.py
3+
fn_name: process_output_cot_triviaqa  # triviaqa.py defines only the top1/topk/cot variants
4+
process_target_fn:
5+
path: output_processing_scripts/triviaqa.py
6+
fn_name: normalize_em_triviaqa
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import re
2+
import string
3+
4+
from default import (
5+
TOP1_OUTPUT_IGNORE_REGEX,
6+
TOPK_OUTPUT_IGNORE_REGEX,
7+
CoT_OUTPUT_IGNORE_REGEX,
8+
)
9+
10+
11+
def normalize_em_coqa(s: str) -> str:
    """Normalize ``s`` for CoQA exact-match comparison.

    Steps, applied in this order:
      1. lowercase the text;
      2. delete every character in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse whitespace runs to single spaces and trim both ends.
    """
    lowered = s.lower()
    # Punctuation is deleted outright (not turned into a space),
    # so "U.S.A." becomes "usa" rather than "u s a".
    punctuation_free = lowered.translate(
        str.maketrans("", "", string.punctuation)
    )
    article_free = re.sub(
        r"\b(a|an|the)\b", " ", punctuation_free, flags=re.UNICODE
    )
    return " ".join(article_free.split())
27+
28+
29+
def process_output_top1_coqa(output: str) -> str:
    """Remove every ``TOP1_OUTPUT_IGNORE_REGEX`` match from ``output``, then
    return the result of ``normalize_em_coqa`` on what is left."""
    stripped = TOP1_OUTPUT_IGNORE_REGEX.sub("", output)
    return normalize_em_coqa(stripped)
33+
34+
35+
def process_output_topk_coqa(output: str) -> str:
    """Remove every ``TOPK_OUTPUT_IGNORE_REGEX`` match from ``output``, then
    return the result of ``normalize_em_coqa`` on what is left."""
    stripped = TOPK_OUTPUT_IGNORE_REGEX.sub("", output)
    return normalize_em_coqa(stripped)
39+
40+
41+
def process_output_cot_coqa(output: str) -> str:
    """Remove every ``CoT_OUTPUT_IGNORE_REGEX`` match from ``output``, then
    return the result of ``normalize_em_coqa`` on what is left."""
    stripped = CoT_OUTPUT_IGNORE_REGEX.sub("", output)
    return normalize_em_coqa(stripped)

examples/configs/instruct/output_processing_scripts/default.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,6 @@ def normalize_text(text: str) -> str:
1212
return text
1313

1414

15-
def process_target(target: str) -> str:
16-
target = normalize_text(target)
17-
return target
18-
19-
2015
def process_output_top1(output: str) -> str:
2116
output = TOP1_OUTPUT_IGNORE_REGEX.sub("", output)
2217
output = normalize_text(output)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import re
2+
import string
3+
4+
from default import (
5+
TOP1_OUTPUT_IGNORE_REGEX,
6+
TOPK_OUTPUT_IGNORE_REGEX,
7+
CoT_OUTPUT_IGNORE_REGEX,
8+
)
9+
10+
11+
def normalize_em_triviaqa(s: str) -> str:
    """Normalize ``s`` for TriviaQA exact-match comparison.

    Steps, applied in this order:
      1. replace underscores with spaces;
      2. lowercase the text;
      3. replace each punctuation character (``string.punctuation`` plus
         the quote-like characters ‘ ’ ´ `) with a space;
      4. replace the standalone words "a", "an" and "the" with a space;
      5. collapse whitespace runs to single spaces and trim both ends.
    """
    punctuation = string.punctuation + "".join(("‘", "’", "´", "`"))
    # Unlike deletion, mapping punctuation to a space keeps the pieces on
    # either side apart, e.g. "dog's" -> "dog s".
    punctuation_to_space = str.maketrans(punctuation, " " * len(punctuation))

    text = s.replace("_", " ").lower()
    text = text.translate(punctuation_to_space)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split()).strip()
31+
32+
33+
def process_output_top1_triviaqa(output: str) -> str:
    """Remove every ``TOP1_OUTPUT_IGNORE_REGEX`` match from ``output``, then
    return the result of ``normalize_em_triviaqa`` on what is left."""
    stripped = TOP1_OUTPUT_IGNORE_REGEX.sub("", output)
    return normalize_em_triviaqa(stripped)
37+
38+
39+
def process_output_topk_triviaqa(output: str) -> str:
    """Remove every ``TOPK_OUTPUT_IGNORE_REGEX`` match from ``output``, then
    return the result of ``normalize_em_triviaqa`` on what is left."""
    stripped = TOPK_OUTPUT_IGNORE_REGEX.sub("", output)
    return normalize_em_triviaqa(stripped)
43+
44+
45+
def process_output_cot_triviaqa(output: str) -> str:
    """Remove every ``CoT_OUTPUT_IGNORE_REGEX`` match from ``output``, then
    return the result of ``normalize_em_triviaqa`` on what is left."""
    stripped = CoT_OUTPUT_IGNORE_REGEX.sub("", output)
    return normalize_em_triviaqa(stripped)

examples/configs/instruct/polygraph_eval_coqa_empirical_baselines.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
defaults:
22
- polygraph_eval_coqa_default_instruct
3-
- top1_processing
3+
- top1_processing_coqa
44
- _self_
55

66
experiment_name: coqa_empirical_baselines

examples/configs/instruct/polygraph_eval_coqa_ling_1s.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
defaults:
22
- polygraph_eval_coqa_default_instruct
3-
- top1_processing
3+
- top1_processing_coqa
44
- _self_
55

66
experiment_name: coqa_ling_1s

0 commit comments

Comments
 (0)