Commit 9e1f337: clean up ogbg

Merge commit (2 parents: 7a71cf0 + 801151b)

21 files changed: +316 -95 lines

algoperf/workloads/ogbg/input_pipeline.py (3 additions, 15 deletions)

@@ -148,22 +148,10 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
       weights_shards.append(weights)
 
       if count == num_shards:
-        # yield {
-        #     'inputs': jraph.batch(graphs_shards),
-        #     'targets': np.vstack(labels_shards),
-        #     'weights': np.vstack(weights_shards)
-        # }
-
-        def f(x):
-          return jax.tree.map(lambda *vals: np.concatenate(vals, axis=0), x[0], *x[1:])
-
-        graphs_shards = f(graphs_shards)
-        labels_shards = f(labels_shards)
-        weights_shards = f(weights_shards)
         yield {
-            'inputs': graphs_shards,
-            'targets': labels_shards,
-            'weights': weights_shards,
+            'inputs': jraph.batch(graphs_shards),
+            'targets': np.vstack(labels_shards),
+            'weights': np.vstack(weights_shards)
         }
 
         count = 0
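
Note: the restored yield is the original batching path. jraph.batch merges a list of GraphsTuples into one batched GraphsTuple (concatenating node/edge arrays and offsetting senders/receivers), and np.vstack stacks the per-shard label and weight rows. A minimal standalone sketch; the shapes are invented for the demo, not taken from the workload:

import jraph
import numpy as np

# Two toy graphs with 4-dim node and edge features.
g1 = jraph.GraphsTuple(
    nodes=np.ones((3, 4)), edges=np.ones((2, 4)),
    senders=np.array([0, 1]), receivers=np.array([1, 2]),
    n_node=np.array([3]), n_edge=np.array([2]), globals=np.zeros((1, 4)))
g2 = jraph.GraphsTuple(
    nodes=np.ones((5, 4)), edges=np.ones((1, 4)),
    senders=np.array([0]), receivers=np.array([4]),
    n_node=np.array([5]), n_edge=np.array([1]), globals=np.zeros((1, 4)))

# One GraphsTuple now holds both graphs: 8 nodes total, n_node == [3, 5].
batched = jraph.batch([g1, g2])

# Labels arrive as one array per shard; vstack restores a single
# (total_examples, num_tasks) array.
labels = np.vstack([np.zeros((2, 1)), np.ones((2, 1))])  # shape (4, 1)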

algoperf/workloads/ogbg/ogbg_jax/models.py (1 addition, 2 deletions)

@@ -79,8 +79,7 @@ def __call__(self, graph, train):
             self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
         update_global_fn=_make_mlp(
             self.hidden_dims, dropout=dropout, activation_fn=activation_fn))
-    # jax.debug.print(str(graph))
-
+
     graph = net(graph)
 
     # Map globals to represent the final result

algoperf/workloads/ogbg/ogbg_jax/workload.py (1 addition, 7 deletions)

@@ -108,11 +108,7 @@ def _eval_metric(self, labels, logits, masks):
     return metrics.EvalMetrics.single_from_model_output(
         loss=loss['per_example'], logits=logits, labels=labels, mask=masks)
 
-  # @functools.partial(
-  #     jax.pmap,
-  #     axis_name='batch',
-  #     in_axes=(None, 0, 0, 0, None),
-  #     static_broadcasted_argnums=(0,))
+
   @functools.partial(
       jax.jit,
       in_shardings=(sharding_utils.get_replicated_sharding(),
@@ -130,8 +126,6 @@ def _normalize_eval_metrics(
       Any]) -> Dict[str, float]:
     """Normalize eval metrics."""
     del num_examples
-    # total_metrics = total_metrics.reduce()
-    print(total_metrics)
     return {k: float(v) for k, v in total_metrics.compute().items()}
 
 
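
Note: this hunk finishes the migration from jax.pmap to jax.jit with explicit shardings; the deleted comment block was the leftover pmap decorator. A minimal sketch of the pattern, assuming a 1-D device mesh named 'batch' (sharding_utils is the repo's own helper and is not reproduced here):

import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis over all local devices; this also runs on a single device.
mesh = Mesh(jax.devices(), axis_names=('batch',))
replicated = NamedSharding(mesh, P())            # same value on every device
batch_sharded = NamedSharding(mesh, P('batch'))  # split along leading axis

@functools.partial(
    jax.jit,
    in_shardings=(replicated, batch_sharded),
    out_shardings=replicated)
def mean_loss(scale, batch):
  # Unlike pmap, no axis_name/psum is needed here: the partitioner inserts
  # the cross-device reduction that jnp.mean requires.
  return scale * jnp.mean(batch)

print(mean_loss(jnp.asarray(2.0), jnp.arange(8.0)))  # 7.0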

docker/build_docker_images.sh (1 addition, 1 deletion)

@@ -14,7 +14,7 @@ do
 done
 
 # Artifact repostiory
-ARTIFACT_REPO="europe-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
+ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
 
 if [[ -z ${GIT_BRANCH+x} ]]
 then

docker/scripts/startup.sh (0 additions, 1 deletion)

@@ -293,7 +293,6 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then
     --workload=${WORKLOAD} \
     --submission_path=${SUBMISSION_PATH} \
     --data_dir=${DATA_DIR} \
-    --num_tuning_trials=1 \
     --experiment_dir=${EXPERIMENT_DIR} \
     --experiment_name=${EXPERIMENT_NAME} \
     --overwrite=${OVERWRITE} \

prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py (2 additions, 0 deletions)

@@ -1,5 +1,6 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
 
+import collections
 import math
 from typing import Any, Dict, Iterator, List, Optional, Tuple
 
@@ -24,6 +25,7 @@
     "weight_decay": 0.08121616522670176,
     "warmup_factor": 0.02
 }
+HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
 
 
 # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.

prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py (2 additions, 0 deletions)

@@ -1,5 +1,6 @@
 """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch."""
 
+import collections
 import math
 from typing import Any, Dict, Iterator, List, Optional, Tuple
 
@@ -24,6 +25,7 @@
     "weight_decay": 0.08121616522670176,
     "warmup_factor": 0.02
 }
+HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
 
 
 # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
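
Note on the HPARAMS line added to both self-tuning baselines: wrapping the dict in a namedtuple gives attribute-style access (HPARAMS.weight_decay) while keeping the JSON-like literal above it readable. A standalone sketch; the learning_rate value here is invented for illustration:

import collections

HPARAMS = {
    "learning_rate": 0.0017,  # invented value, for illustration only
    "weight_decay": 0.08121616522670176,
    "warmup_factor": 0.02,
}
# Build a namedtuple type from the dict's keys, then instantiate it with the
# dict's values; fields are now read as HPARAMS.weight_decay and so on.
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)

print(HPARAMS.weight_decay)  # 0.08121616522670176
print(HPARAMS._asdict())     # convert back to a dict if needed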
