+import csv
+import glob
 import os
 import subprocess
 
 DEFAULT_PORT = 23333
 
 
-def get_model_type(model_name):
-    model_name_lower = model_name.lower()
+def write_to_summary(model_name, tp_num, result, msg, worker_id, work_dir=None):
+    status = '✅ PASS' if result else '❌ FAIL'
 
-    chat_patterns = [
-        'chat',
-        'instruct',
-        'gemma',
-        'llama3',
-        'llama2',
-        'llama',
-    ]
-    if any(pattern in model_name_lower for pattern in chat_patterns):
-        return 'chat'
+    metrics = {}
+
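+    # OpenCompass writes each run's results to <work_dir>/<timestamp>/summary/
+    # as summary_<timestamp>.csv; the globs below pick the newest summary CSV.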
+    if work_dir and os.path.exists(work_dir):
+        try:
+            summary_dirs = glob.glob(os.path.join(work_dir, '*', 'summary'))
+            if summary_dirs:
+                summary_dir = summary_dirs[0]
+                csv_files = glob.glob(os.path.join(summary_dir, 'summary_*.csv'))
+                if csv_files:
+                    csv_file = sorted(csv_files)[-1]
+                    if os.path.exists(csv_file):
+                        with open(csv_file, 'r') as f:
+                            reader = csv.reader(f)
+                            next(reader)
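+                            # Expected row layout: dataset, version, metric,
+                            # mode, score; row[0] names the dataset and row[4]
+                            # holds the score for the model under test.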
+                            for row in reader:
+                                if len(row) >= 5 and row[4]:
+                                    dataset = row[0]
+                                    metric_value = row[4]
+                                    try:
+                                        metrics[dataset] = f'{float(metric_value):.2f}'
+                                    except ValueError:
+                                        metrics[dataset] = metric_value
+        except Exception as e:
+            print(f'Error reading metrics: {str(e)}')
+
+    mmlu_value = metrics.get('mmlu', '')
+    gsm8k_value = metrics.get('gsm8k', '')
+
+    summary_line = f'| {model_name} | TP{tp_num} | {status} | {mmlu_value} | {gsm8k_value} |\n'
+
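+    # GITHUB_STEP_SUMMARY is set by GitHub Actions; markdown appended to that
+    # file is rendered on the workflow run page. Fall back to stdout locally.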
+    summary_file = os.environ.get('GITHUB_STEP_SUMMARY', None)
+    if summary_file:
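+        # Write the table header only once per job: when the file is new or
+        # empty, or when its first 200 bytes do not already contain it.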
+        write_header = False
+        if not os.path.exists(summary_file) or os.path.getsize(summary_file) == 0:
+            write_header = True
+        else:
+            with open(summary_file, 'r') as f:
+                first_lines = f.read(200)
+                if '| Model | TP | Status | mmlu | gsm8k |' not in first_lines:
+                    write_header = True
+
+        with open(summary_file, 'a') as f:
+            if write_header:
+                f.write('## Model Evaluation Results\n')
+                f.write('| Model | TP | Status | mmlu | gsm8k |\n')
+                f.write('|-------|----|--------|------|-------|\n')
+            f.write(summary_line)
     else:
-        return 'base'
+        print(f'Summary: {model_name} | TP{tp_num} | {status} | {mmlu_value} | {gsm8k_value}')
 
 
 def restful_test(config, run_id, prepare_environment, worker_id='gw0', port=DEFAULT_PORT):
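+    # Initialized before the try block so the exception handlers at the end
+    # can reference work_dir even if setup fails before it is assigned.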
+    work_dir = None
     try:
         model_name = prepare_environment['model']
         backend_type = prepare_environment['backend']
         tp_num = prepare_environment.get('tp_num', 1)
         communicator = prepare_environment.get('communicator', 'native')
         quant_policy = prepare_environment.get('quant_policy', 0)
 
-        model_type = get_model_type(model_name)
-        print(f'Model {model_name} identified as {model_type} model')
-
         current_dir = os.path.dirname(os.path.abspath(__file__))
         parent_dir = os.path.dirname(current_dir)
 
-        if model_type == 'base':
-            config_file = os.path.join(parent_dir, 'evaluate/eval_config_base.py')
-        else:
-            config_file = os.path.join(parent_dir, 'evaluate/eval_config_chat.py')
+        config_file = os.path.join(parent_dir, 'evaluate/eval_config_chat.py')
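+        # All models are evaluated with the chat config.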
 
         model_base_path = config.get('model_path', '/nvme/qa_test_models')
         model_path = os.path.join(model_base_path, model_name)
 
         print(f'Starting OpenCompass evaluation for model: {model_name}')
         print(f'Model path: {model_path}')
         print(f'Backend: {backend_type}')
-        print(f'Model type: {model_type}')
         print(f'Config file: {config_file}')
 
         log_path = config.get('log_path', '/nvme/qa_test_models/autotest_model/log')
         os.makedirs(log_path, exist_ok=True)
 
         original_cwd = os.getcwd()
         work_dir = os.path.join(
-            log_path,
-            f"wk_{backend_type}_{model_name.replace('/', '_')}_{model_type}_{communicator}_{worker_id}_{quant_policy}")
+            log_path, f"wk_{backend_type}_{model_name.replace('/', '_')}_{communicator}_{worker_id}_{quant_policy}")
         os.makedirs(work_dir, exist_ok=True)
 
         try:
@@ -99,15 +131,13 @@ def restful_test(config, run_id, prepare_environment, worker_id='gw0', port=DEFA
 
             log_filename = (f'eval_{backend_type}_'
                             f"{model_name.replace('/', '_')}_"
-                            f'{model_type}_'
                             f'{communicator}_'
                             f'{worker_id}_'
                             f'{quant_policy}.log')
             log_file = os.path.join(log_path, log_filename)
 
             with open(log_file, 'w', encoding='utf-8') as f:
                 f.write(f'Model: {model_name}\n')
-                f.write(f'Model type: {model_type}\n')
                 f.write(f'Config file: {temp_config_file}\n')
                 f.write(f'Backend: {backend_type}\n')
                 f.write(f'TP Num: {tp_num}\n')
@@ -131,25 +161,29 @@ def restful_test(config, run_id, prepare_environment, worker_id='gw0', port=DEFA
                     break
 
             if result.returncode == 0 and not evaluation_failed:
-                return True, f'Evaluation completed successfully for {model_name} ({model_type})'
+                final_result = True
+                final_msg = f'Evaluation completed successfully for {model_name}'
             else:
-                error_msg = f'Evaluation failed for {model_name} ({model_type}) '
+                final_result = False
+                final_msg = f'Evaluation failed for {model_name} '
                 if result.returncode != 0:
-                    error_msg += f'with return code {result.returncode}'
+                    final_msg += f'with return code {result.returncode}'
                 elif evaluation_failed:
-                    error_msg += 'with internal errors detected in logs'
+                    final_msg += 'with internal errors detected in logs'
 
                 if stderr_output:
-                    error_msg += f'\nSTDERR: {stderr_output}'
+                    final_msg += f'\nSTDERR: {stderr_output}'
                 else:
                     error_lines = []
                     for line in stdout_output.split('\n'):
                         if any(keyword in line for keyword in error_keywords):
                             error_lines.append(line)
                     if error_lines:
-                        error_msg += f'\nLog errors: {" | ".join(error_lines[:3])}'
+                        final_msg += f'\nLog errors: {" | ".join(error_lines[:3])}'
+
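+            # Record the outcome in the job summary whether it passed or failed.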
+            write_to_summary(model_name, tp_num, final_result, final_msg, worker_id, work_dir)
 
-            return False, error_msg
+            return final_result, final_msg
 
         finally:
             os.chdir(original_cwd)
@@ -158,6 +192,11 @@ def restful_test(config, run_id, prepare_environment, worker_id='gw0', port=DEFA
     except subprocess.TimeoutExpired:
         timeout_msg = (f'Evaluation timed out for {model_name} '
                        f'after 7200 seconds')
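+        # Timeouts and unexpected errors also land in the summary, but only
+        # when setup got far enough for a work_dir to exist.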
+        if work_dir:
+            write_to_summary(model_name, tp_num, False, timeout_msg, worker_id, work_dir)
         return False, timeout_msg
     except Exception as e:
-        return False, f'Error during evaluation for {model_name}: {str(e)}'
+        error_msg = f'Error during evaluation for {model_name}: {str(e)}'
+        if work_dir:
+            write_to_summary(model_name, tp_num, False, error_msg, worker_id, work_dir)
+        return False, error_msg