Commit c96ba62

Merge pull request #862 from mlcommons/dev
Dev -> main
2 parents 42d9ae1 + 7638497 commit c96ba62

49 files changed: +650 additions, -259 deletions

README.md

Lines changed: 6 additions & 0 deletions

@@ -22,6 +22,12 @@
 
 ---
 
+Unlike benchmarks that focus on model architecture or hardware, the AlgoPerf benchmark isolates the training algorithm itself, measuring how quickly it can achieve target performance levels on a fixed set of representative deep learning tasks. These tasks span various domains, including image classification, speech recognition, machine translation, and more, all running on standardized hardware (8x NVIDIA V100 GPUs). The benchmark includes 8 fully specified base workloads. In addition, it defines "randomized" workloads, variations of the fixed workloads designed to discourage overfitting. These randomized workloads were used for scoring the AlgoPerf competition but will not be used for future scoring.
+
+Submissions are evaluated based on their "time-to-result", i.e., the wall-clock time it takes to reach predefined validation and test set performance targets on each workload. Submissions are scored under one of two tuning rulesets. The [external tuning ruleset](https://github.com/mlcommons/algorithmic-efficiency/blob/main/docs/DOCUMENTATION.md#external-tuning-ruleset) allows a limited amount of hyperparameter tuning (20 quasirandom trials) for each workload. The [self-tuning ruleset](https://github.com/mlcommons/algorithmic-efficiency/blob/main/docs/DOCUMENTATION.md#self-tuning-ruleset) allows no external tuning, so any tuning is done "on the clock". For each submission, a single overall benchmark score is computed by integrating its "performance profile" across all fixed workloads. The performance profile captures the submission's training time relative to the best submission on each workload, so the score of each submission is a function of the other submissions in the pool. The higher the benchmark score, the better the submission's overall performance.
+
+---
+
 > This is the repository for the *AlgoPerf: Training Algorithms benchmark* measuring neural network training speedups due to algorithmic improvements.
 > It is developed by the [MLCommons Algorithms Working Group](https://mlcommons.org/en/groups/research-algorithms/).
 > This repository holds the benchmark code, the benchmark's [**technical documentation**](/docs/DOCUMENTATION.md) and [**getting started guides**](/docs/GETTING_STARTED.md). For a detailed description of the benchmark design, see our [**introductory paper**](https://arxiv.org/abs/2306.07179), for the results of the inaugural competition see our [**results paper**](https://openreview.net/forum?id=CtM5xjRSfm).
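To make the scoring rule described above concrete, here is a small, self-contained sketch of computing a performance profile and integrating it. The submission names and per-workload times are made up; the tau range of 1 to 4 mirrors the min_tau/max_tau values used by the scoring script added later in this commit.

# Toy sketch of "integrating the performance profile" (illustrative numbers only).
import numpy as np

# Hypothetical time-to-target per workload, in seconds; np.inf = target not reached.
times = {
    'submission_a': np.array([100., 200., 150., 400.]),
    'submission_b': np.array([120., 180., 300., np.inf]),
}

best = np.min(np.vstack(list(times.values())), axis=0)  # fastest submission per workload
taus = np.linspace(1.0, 4.0, 100)

for name, t in times.items():
  ratios = t / best  # training time relative to the best submission on each workload
  profile = np.array([np.mean(ratios <= tau) for tau in taus])  # performance profile rho(tau)
  score = np.trapz(profile, x=taus) / (taus[-1] - taus[0])  # normalized integral = benchmark score
  print(name, round(float(score), 3))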

docker/build_docker_images.sh

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ do
 done
 
 # Artifact repostiory
-ARTIFACT_REPO="europe-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
+ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
 
 if [[ -z ${GIT_BRANCH+x} ]]
 then

docker/scripts/startup.sh

Lines changed: 0 additions & 1 deletion

@@ -293,7 +293,6 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then
 --workload=${WORKLOAD} \
 --submission_path=${SUBMISSION_PATH} \
 --data_dir=${DATA_DIR} \
---num_tuning_trials=1 \
 --experiment_dir=${EXPERIMENT_DIR} \
 --experiment_name=${EXPERIMENT_NAME} \
 --overwrite=${OVERWRITE} \
Lines changed: 231 additions & 0 deletions

@@ -0,0 +1,231 @@
"""This script can
1. Summarize the raw submission times for each workload run in a set of studies and trials.
2. Produce the performance profiles and scores of a group of submissions.
Note that for performance profiles and final scores are computed w.r.t. a group of submissions.
If you only have logs for one submission you may group it with some reference submission
to compare the performance.

Example usage:
python3 score_submissions.py \
--submission_directory $HOME/algorithmic-efficiency/prize_qualification_baselines/logs \
--strict True
--compute_performance_profiles
"""

import operator
import os
import pickle

from absl import app
from absl import flags
from absl import logging
import numpy as np
import pandas as pd
import performance_profile
import scoring_utils
from tabulate import tabulate

flags.DEFINE_string(
    'submission_directory',
    None,
    'Path to submission directory containing experiment directories.')
flags.DEFINE_string(
    'output_dir',
    'scoring_results',
    'Path to save performance profile artifacts, submission_summaries and results files.'
)
flags.DEFINE_boolean('compute_performance_profiles',
                     False,
                     'Whether or not to compute the performance profiles.')
flags.DEFINE_boolean(
    'strict',
    False,
    'Whether to enforce scoring criteria on variant performance and on'
    '5-trial median performance. Note that during official scoring this '
    'flag will be set to True.')
flags.DEFINE_boolean(
    'self_tuning_ruleset',
    False,
    'Whether to score on self-tuning ruleset or externally tuned ruleset')
flags.DEFINE_string(
    'save_results_to_filename',
    None,
    'Filename to save the processed results that are fed into the performance profile functions.'
)
flags.DEFINE_string(
    'load_results_from_filename',
    None,
    'Filename to load processed results from that are fed into performance profile functions'
)
flags.DEFINE_string(
    'exclude_submissions',
    '',
    'Optional comma seperated list of names of submissions to exclude from scoring.'
)
FLAGS = flags.FLAGS


def get_summary_df(workload, workload_df, include_test_split=False):
  validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation')

  is_minimized = performance_profile.check_if_minimized(validation_metric)
  target_op = operator.le if is_minimized else operator.ge
  best_op = min if is_minimized else max
  idx_op = np.argmin if is_minimized else np.argmax

  summary_df = pd.DataFrame()
  summary_df['workload'] = workload_df['workload']
  summary_df['trial'] = workload_df['trial'].apply(lambda x: x[0])
  summary_df['val target metric name'] = validation_metric
  summary_df['val target metric value'] = validation_target

  summary_df['val target reached'] = workload_df[validation_metric].apply(
      lambda x: target_op(x, validation_target)).apply(np.any)
  summary_df['best metric value on val'] = workload_df[validation_metric].apply(
      lambda x: best_op(x))
  workload_df['index best eval on val'] = workload_df[validation_metric].apply(
      lambda x: idx_op(x))
  summary_df['time to best eval on val (s)'] = workload_df.apply(
      lambda x: x['accumulated_submission_time'][x['index best eval on val']],
      axis=1)
  workload_df['val target reached'] = workload_df[validation_metric].apply(
      lambda x: target_op(x, validation_target)).apply(np.any)
  workload_df['index to target on val'] = workload_df.apply(
      lambda x: np.argmax(target_op(x[validation_metric], validation_target))
      if x['val target reached'] else np.nan,
      axis=1)
  summary_df['time to target on val (s)'] = workload_df.apply(
      lambda x: x['accumulated_submission_time'][int(x[
          'index to target on val'])] if x['val target reached'] else np.inf,
      axis=1)

  # test metrics
  if include_test_split:
    test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test')

    summary_df['test target metric name'] = test_metric
    summary_df['test target metric value'] = test_target

    summary_df['test target reached'] = workload_df[test_metric].apply(
        lambda x: target_op(x, test_target)).apply(np.any)
    summary_df['best metric value on test'] = workload_df[test_metric].apply(
        lambda x: best_op(x))
    workload_df['index best eval on test'] = workload_df[test_metric].apply(
        lambda x: idx_op(x))
    summary_df['time to best eval on test (s)'] = workload_df.apply(
        lambda x: x['accumulated_submission_time'][x['index best eval on test']],
        axis=1)
    summary_df['time to target on test (s)'] = summary_df.apply(
        lambda x: x['time to best eval on test (s)']
        if x['test target reached'] else np.inf,
        axis=1)

  return summary_df


def get_submission_summary(df, include_test_split=True):
  """Summarizes the submission results into metric and time tables
  organized by workload.
  """

  dfs = []
  print(df)
  for workload, group in df.groupby('workload'):
    summary_df = get_summary_df(
        workload, group, include_test_split=include_test_split)
    dfs.append(summary_df)

  df = pd.concat(dfs)
  logging.info('\n' + tabulate(df, headers='keys', tablefmt='psql'))
  return df


def compute_leaderboard_score(df, normalize=True):
  """Compute leaderboard score by taking integral of performance profile.

  Args:
    df: pd.DataFrame returned from `compute_performance_profiles`.
    normalize: divide by the range of the performance profile's tau.

  Returns:
    pd.DataFrame with one column of scores indexed by submission.
  """
  scores = np.trapz(df, x=df.columns)
  if normalize:
    scores /= df.columns.max() - df.columns.min()
  return pd.DataFrame(scores, columns=['score'], index=df.index)


def main(_):
  results = {}
  os.makedirs(FLAGS.output_dir, exist_ok=True)

  # Optionally read results to filename
  if FLAGS.load_results_from_filename:
    with open(
        os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename),
        'rb') as f:
      results = pickle.load(f)
  else:
    for team in os.listdir(FLAGS.submission_directory):
      for submission in os.listdir(
          os.path.join(FLAGS.submission_directory, team)):
        print(submission)
        if submission in FLAGS.exclude_submissions.split(','):
          continue
        experiment_path = os.path.join(FLAGS.submission_directory,
                                       team,
                                       submission)
        df = scoring_utils.get_experiment_df(experiment_path)
        results[submission] = df
        summary_df = get_submission_summary(df)
        with open(
            os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
            'w') as fout:
          summary_df.to_csv(fout)

  # Optionally save results to filename
  if FLAGS.save_results_to_filename:
    with open(
        os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename),
        'wb') as f:
      pickle.dump(results, f)

  if not FLAGS.strict:
    logging.warning(
        'You are running with strict=False. This will relax '
        'scoring criteria on the held-out workloads, number of trials and number '
        'of studies. Your score may not be an accurate representation '
        'under competition scoring rules. To enforce the criteria set strict=True.'
    )
  if FLAGS.compute_performance_profiles:
    performance_profile_df = performance_profile.compute_performance_profiles(
        results,
        time_col='score',
        min_tau=1.0,
        max_tau=4.0,
        reference_submission_tag=None,
        num_points=100,
        scale='linear',
        verbosity=0,
        self_tuning_ruleset=FLAGS.self_tuning_ruleset,
        strict=FLAGS.strict,
        output_dir=FLAGS.output_dir,
    )
    if not os.path.exists(FLAGS.output_dir):
      os.mkdir(FLAGS.output_dir)
    performance_profile.plot_performance_profiles(
        performance_profile_df, 'score', save_dir=FLAGS.output_dir)
    performance_profile_str = tabulate(
        performance_profile_df.T, headers='keys', tablefmt='psql')
    logging.info(f'Performance profile:\n {performance_profile_str}')
    scores = compute_leaderboard_score(performance_profile_df)
    scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv'))
    scores_str = tabulate(scores, headers='keys', tablefmt='psql')
    logging.info(f'Scores: \n {scores_str}')


if __name__ == '__main__':
  # flags.mark_flag_as_required('submission_directory')
  app.run(main)

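One easily missed detail in get_summary_df above is how the time-to-target is computed: target_op yields a boolean array over evals, and np.argmax on a boolean array returns the index of the first True entry. A minimal sketch of that pattern, with made-up metric values and target:

import operator

import numpy as np

# Hypothetical per-eval validation metric for one trial (metric is minimized here).
metric_values = np.array([0.52, 0.41, 0.33, 0.29, 0.27])
accumulated_submission_time = np.array([0., 600., 1200., 1800., 2400.])  # seconds
validation_target = 0.30
target_op = operator.le  # as selected when check_if_minimized(...) is True

reached = target_op(metric_values, validation_target)  # boolean array, one entry per eval
if reached.any():
  first_idx = np.argmax(reached)  # index of the first eval at or below the target
  time_to_target = accumulated_submission_time[first_idx]  # -> 1800.0
else:
  time_to_target = np.inf  # target never reached
print(time_to_target)
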
scoring/performance_profile.py

Lines changed: 11 additions & 6 deletions

@@ -47,17 +47,21 @@
 WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
 BASE_WORKLOADS_DIR = 'algoperf/workloads/'
 # Open json file to read heldout workloads
-# TODO: This probably shouldn't be hardcoded but passed as an argument.
-with open("held_out_workloads_algoperf_v05.json", "r") as f:
-  HELDOUT_WORKLOADS = json.load(f)
+# TODO: This probably shouldn't be hardcoded but passed as an argument.\
+try:
+  with open("held_out_workloads_algoperf_v05.json", "r") as f:
+    HELDOUT_WORKLOADS = json.load(f)
+except:
+  HELDOUT_WORKLOADS = None
+
 # These global variables have to be set according to the current set of
 # workloads and rules for the scoring to be correct.
 # We do not use the workload registry since it contains test and development
 # workloads as well.
 NUM_BASE_WORKLOADS = 8
-NUM_VARIANT_WORKLOADS = 6
+NUM_VARIANT_WORKLOADS = 0
 NUM_TRIALS = 5
-NUM_STUDIES = 5
+NUM_STUDIES = 3
 
 MIN_EVAL_METRICS = [
     'ce_loss',

@@ -318,7 +322,8 @@ def compute_performance_profiles(submissions,
   # Restrict to base and sampled held-out workloads
   # (ignore the additional workload variants of the baseline
   # as they cause issues when checking for nans in workload variants).
-  df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS]
+  if HELDOUT_WORKLOADS:
+    df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS]
   # Sort workloads alphabetically (for better display)
   df = df.reindex(sorted(df.columns), axis=1)
 
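The TODO in the hunk above notes that the held-out workloads file probably should be passed as an argument rather than hardcoded. A hedged sketch of one way to do that with an absl flag; the flag name and helper function are hypothetical and not part of this commit:

import json
import os

from absl import flags

flags.DEFINE_string(
    'heldout_workloads_path',
    None,
    'Optional path to a JSON file listing the sampled held-out workloads.')


def load_heldout_workloads(path):
  """Returns the held-out workload list, or None if no usable file is given."""
  if not path or not os.path.exists(path):
    return None
  with open(path, 'r') as f:
    return json.load(f)
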
scoring/score_submissions.py

Lines changed: 17 additions & 19 deletions

@@ -7,9 +7,10 @@
 
 Example usage:
 python3 score_submissions.py \
---submission_directory $HOME/algorithmic-efficiency/prize_qualification_baselines/logs \
---strict True
---compute_performance_profiles
+--submission_directory $HOME/algoperf-runs/submissions/rolling_leaderboard/self_tuning \
+--compute_performance_profiles \
+--output_dir scoring_results_self_tuning \
+--self_tuning_ruleset
 """
 
 import operator

@@ -160,6 +161,7 @@ def compute_leaderboard_score(df, normalize=True):
 def main(_):
   results = {}
   os.makedirs(FLAGS.output_dir, exist_ok=True)
+  logging.info(f"Scoring submissions in {FLAGS.submission_directory}")
 
   # Optionally read results to filename
   if FLAGS.load_results_from_filename:

@@ -168,22 +170,18 @@
         'rb') as f:
       results = pickle.load(f)
   else:
-    for team in os.listdir(FLAGS.submission_directory):
-      for submission in os.listdir(
-          os.path.join(FLAGS.submission_directory, team)):
-        print(submission)
-        if submission in FLAGS.exclude_submissions.split(','):
-          continue
-        experiment_path = os.path.join(FLAGS.submission_directory,
-                                       team,
-                                       submission)
-        df = scoring_utils.get_experiment_df(experiment_path)
-        results[submission] = df
-        summary_df = get_submission_summary(df)
-        with open(
-            os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
-            'w') as fout:
-          summary_df.to_csv(fout)
+    for submission in os.listdir(FLAGS.submission_directory):
+      print(submission)
+      if submission in FLAGS.exclude_submissions.split(','):
+        continue
+      experiment_path = os.path.join(FLAGS.submission_directory, submission)
+      df = scoring_utils.get_experiment_df(experiment_path)
+      results[submission] = df
+      summary_df = get_submission_summary(df)
+      with open(
+          os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
+          'w') as fout:
+        summary_df.to_csv(fout)
 
   # Optionally save results to filename
   if FLAGS.save_results_to_filename:
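
The rewritten loop above removes the intermediate team directory, so the scorer now expects one experiment directory per submission directly under --submission_directory. A minimal sketch of the layout it walks; the directory names are hypothetical:

import os

submission_directory = '/path/to/rolling_leaderboard/self_tuning'  # hypothetical path
# Expected layout:
#   <submission_directory>/
#     submission_a/   <- passed to scoring_utils.get_experiment_df(...)
#     submission_b/
# (previously there was an extra <team>/ level between the two.)
for submission in sorted(os.listdir(submission_directory)):
  experiment_path = os.path.join(submission_directory, submission)
  print(submission, '->', experiment_path)
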
File renamed without changes.
