
Commit 2e4cc9e

ogbg jit migration
1 parent c208cc7 commit 2e4cc9e

6 files changed, +53 -33 lines

algoperf/workloads/ogbg/input_pipeline.py

Lines changed: 7 additions & 2 deletions
@@ -148,10 +148,15 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
       weights_shards.append(weights)

     if count == num_shards:
+      # yield {
+      #     'inputs': jraph.batch(graphs_shards),
+      #     'targets': np.vstack(labels_shards),
+      #     'weights': np.vstack(weights_shards)
+      # }

       def f(x):
-        return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])
-
+        return jax.tree.map(lambda *vals: np.concatenate(vals, axis=0), x[0], *x[1:])
+
       graphs_shards = f(graphs_shards)
       labels_shards = f(labels_shards)
       weights_shards = f(weights_shards)
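
The switch from np.stack to np.concatenate matches the pmap-to-jit migration in the rest of this commit: pmap expects a per-device leading axis, while jit with shardings consumes one flat global batch and lets the sharding annotations split it across devices. A minimal shape sketch, not taken from the repo:

import numpy as np

shards = [np.ones((4, 8)), np.ones((4, 8))]  # two host-side shards of batch size 4

stacked = np.stack(shards, axis=0)     # (2, 4, 8): per-device leading axis, as pmap expects
flat = np.concatenate(shards, axis=0)  # (8, 8): one global batch for jit with shardings
print(stacked.shape, flat.shape)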

algoperf/workloads/ogbg/ogbg_jax/models.py

Lines changed: 3 additions & 1 deletion
@@ -2,6 +2,7 @@
 # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
 from typing import Optional, Tuple

+import jax
 from flax import linen as nn
 import jax.numpy as jnp
 import jraph
@@ -78,7 +79,8 @@ def __call__(self, graph, train):
             self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
         update_global_fn=_make_mlp(
             self.hidden_dims, dropout=dropout, activation_fn=activation_fn))
-
+    # jax.debug.print(str(graph))
+
     graph = net(graph)

     # Map globals to represent the final result

algoperf/workloads/ogbg/ogbg_jax/workload.py

Lines changed: 18 additions & 6 deletions
@@ -8,6 +8,7 @@
 import jraph
 import optax

+from algoperf import sharding_utils
 from algoperf import param_utils
 from algoperf import spec
 from algoperf.workloads.ogbg import metrics
@@ -45,7 +46,8 @@ def init_model_fn(
     params = params['params']
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
-    return jax_utils.replicate(params), None
+    params = sharding_utils.shard_replicated(params)
+    return params, None

   def is_output_params(self, param_key: spec.ParameterKey) -> bool:
     return param_key == 'Dense_17'
@@ -106,11 +108,20 @@ def _eval_metric(self, labels, logits, masks):
     return metrics.EvalMetrics.single_from_model_output(
         loss=loss['per_example'], logits=logits, labels=labels, mask=masks)

+  # @functools.partial(
+  #     jax.pmap,
+  #     axis_name='batch',
+  #     in_axes=(None, 0, 0, 0, None),
+  #     static_broadcasted_argnums=(0,))
   @functools.partial(
-      jax.pmap,
-      axis_name='batch',
-      in_axes=(None, 0, 0, 0, None),
-      static_broadcasted_argnums=(0,))
+      jax.jit,
+      in_shardings=(sharding_utils.get_replicated_sharding(),
+                    sharding_utils.get_naive_sharding_spec(),
+                    sharding_utils.get_replicated_sharding(),
+                    sharding_utils.get_replicated_sharding()),
+      static_argnums=(0,),
+      out_shardings=sharding_utils.get_replicated_sharding(),
+  )
   def _eval_batch(self, params, batch, model_state, rng):
     return super()._eval_batch(params, batch, model_state, rng)

@@ -119,7 +130,8 @@ def _normalize_eval_metrics(
       Any]) -> Dict[str, float]:
     """Normalize eval metrics."""
     del num_examples
-    total_metrics = total_metrics.reduce()
+    # total_metrics = total_metrics.reduce()
+    print(total_metrics)
     return {k: float(v) for k, v in total_metrics.compute().items()}

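The decorator swap above is the core of the migration: jax.pmap with in_axes/static_broadcasted_argnums becomes jax.jit with in_shardings/out_shardings/static_argnums, where the batch argument gets a data-parallel sharding and everything else stays replicated. The sharding_utils helpers are not shown in this diff, so the following is only a minimal sketch of the same pattern written directly against jax.sharding (illustrative names, not the repo's implementation):

import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('batch',))
replicated = NamedSharding(mesh, P())           # full copy on every device
data_sharded = NamedSharding(mesh, P('batch'))  # leading dim split across devices

@functools.partial(
    jax.jit,
    in_shardings=(replicated, data_sharded, replicated),
    out_shardings=replicated)
def eval_batch(params, batch, rng):
  del rng  # toy body; the real workload runs a forward pass and metrics here
  return jnp.sum(batch * params)

batch = jnp.ones((4 * len(jax.devices()), 8))
print(eval_batch(jnp.float32(2.0), batch, jax.random.PRNGKey(0)))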
algoperf/workloads/ogbg/workload.py

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@ def _eval_batch(self,
         spec.ForwardPassMode.EVAL,
         rng,
         update_batch_norm=False)
+    jax.debug.print(str(logits))
     return self._eval_metric(batch['targets'], logits, batch['weights'])

   def _eval_model_on_split(self,
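
One caveat on the added debug line: jax.debug.print only prints runtime values for arrays passed through its format arguments; str(logits) is evaluated at trace time and bakes the tracer's repr into the message. The adamw submission below uses the format-string form. An illustrative comparison, not part of the commit:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print(str(x))              # prints the tracer repr, fixed at trace time
  jax.debug.print("logits: {v}", v=x)  # prints the concrete values at runtime
  return x * 2

f(jnp.arange(3.0))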

reference_algorithms/paper_baselines/adamw/jax/submission.py

Lines changed: 21 additions & 22 deletions
@@ -75,6 +75,7 @@ def _loss_fn(params):
         spec.ForwardPassMode.TRAIN,
         rng,
         update_batch_norm=True,)
+    jax.debug.print("logits: {logits}", logits=logits)
     loss_dict = workload.loss_fn(
         label_batch=batch['targets'],
         logits_batch=logits,
@@ -140,31 +141,29 @@ def update_params(
   replicated = NamedSharding(mesh, P())  # No partitioning
   sharded = NamedSharding(mesh, P('batch'))  # Partition along batch dimension

-  # Define input and output shardings
-  arg_shardings = (
-      # workload is static
-      # opt_update_fn is static
-      replicated,  # model_state
-      replicated,  # optimizer_state
-      replicated,  # current_param_container
-      sharded,  # batch
-      replicated,  # rng
-      replicated,  # grad_clip
-      replicated  # label_smoothing
-  )
-  out_shardings = (
-      replicated,  # new_optimizer_state
-      replicated,  # updated_params
-      replicated,  # new_model_state
-      replicated,  # loss
-      replicated  # grad_norm
-  )
   jitted_train_step = jax.jit(
       train_step,
       static_argnums=(0, 1),
       donate_argnums=(2, 3, 4),
-      in_shardings=arg_shardings,
-      out_shardings=out_shardings)
+      in_shardings= (
+          # workload is static
+          # opt_update_fn is static
+          replicated,  # model_state
+          replicated,  # optimizer_state
+          replicated,  # current_param_container
+          sharded,  # batch
+          replicated,  # rng
+          replicated,  # grad_clip
+          replicated  # label_smoothing
+      ),
+      out_shardings=(
+          replicated,  # new_optimizer_state
+          replicated,  # updated_params
+          replicated,  # new_model_state
+          replicated,  # loss
+          replicated  # grad_norm
+      ))
+  # print(batch)
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload,
                                                                                          opt_update_fn,
                                                                                          model_state,
@@ -176,7 +175,7 @@ def update_params(
       label_smoothing)

   # Log loss, grad_norm.
-  if global_step % 100 == 0 and workload.metrics_logger is not None:
+  if global_step % 1 == 0 and workload.metrics_logger is not None:
     workload.metrics_logger.append_scalar_metrics(
         {
             'loss': loss.item(),
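
The inlined tuples keep the same placement rules as before: workload and opt_update_fn are static, model state, optimizer state, and params are donated and replicated, and only the batch is split along the mesh's 'batch' axis. A reduced sketch of the same jit call shape with toy stand-ins (names and functions here are illustrative, not the repo's):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('batch',))
replicated = NamedSharding(mesh, P())
sharded = NamedSharding(mesh, P('batch'))

def train_step(lr, params, batch):
  grads = jnp.mean(batch, axis=0)              # toy "gradient" from the sharded batch
  return params - lr * grads, jnp.mean(batch)  # (updated params, loss)

jitted_train_step = jax.jit(
    train_step,
    static_argnums=(0,),                   # lr as compile-time constant, like workload/opt_update_fn
    donate_argnums=(1,),                   # params buffer may be reused, like argnums (2, 3, 4) above
    in_shardings=(replicated, sharded),    # shardings cover only the non-static args
    out_shardings=(replicated, replicated))

params = jnp.zeros((8,))
batch = jnp.ones((4 * len(jax.devices()), 8))
new_params, loss = jitted_train_step(0.1, params, batch)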

submission_runner.py

Lines changed: 3 additions & 2 deletions
@@ -392,8 +392,9 @@ def train_once(
           train_step_end_time - train_state['last_step_end_time'])

       # Check if submission is eligible for an untimed eval.
-      if ((train_step_end_time - train_state['last_eval_time']) >=
-          workload.eval_period_time_sec or train_state['training_complete']):
+      if False:
+      # if ((train_step_end_time - train_state['last_eval_time']) >=
+      #     workload.eval_period_time_sec or train_state['training_complete']):

         # Prepare for evaluation (timed).
         if prepare_for_eval is not None:
