Skip to content

Commit 93ff958

Browse files
committed
ogbg debugging
1 parent 2e4cc9e commit 93ff958

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

algoperf/workloads/ogbg/input_pipeline.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
import torch
1212

1313
AVG_NODES_PER_GRAPH = 26
14-
AVG_EDGES_PER_GRAPH = 56
14+
AVG_EDGES_PER_GRAPH = 28
1515

1616
TFDS_SPLIT_NAME = {
1717
'train': 'train',
@@ -148,24 +148,24 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
148148
weights_shards.append(weights)
149149

150150
if count == num_shards:
151-
# yield {
152-
# 'inputs': jraph.batch(graphs_shards),
153-
# 'targets': np.vstack(labels_shards),
154-
# 'weights': np.vstack(weights_shards)
155-
# }
156-
157-
def f(x):
158-
return jax.tree.map(lambda *vals: np.concatenate(vals, axis=0), x[0], *x[1:])
159-
160-
graphs_shards = f(graphs_shards)
161-
labels_shards = f(labels_shards)
162-
weights_shards = f(weights_shards)
163151
yield {
164-
'inputs': graphs_shards,
165-
'targets': labels_shards,
166-
'weights': weights_shards,
152+
'inputs': jraph.batch(graphs_shards),
153+
'targets': np.vstack(labels_shards),
154+
'weights': np.vstack(weights_shards)
167155
}
168156

157+
# def f(x):
158+
# return jax.tree.map(lambda *vals: np.concatenate(vals, axis=0), x[0], *x[1:])
159+
160+
# graphs_shards = f(graphs_shards)
161+
# labels_shards = f(labels_shards)
162+
# weights_shards = f(weights_shards)
163+
# yield {
164+
# 'inputs': graphs_shards,
165+
# 'targets': labels_shards,
166+
# 'weights': weights_shards,
167+
# }
168+
169169
count = 0
170170
graphs_shards = []
171171
labels_shards = []

reference_algorithms/paper_baselines/adamw/jax/submission.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def _loss_fn(params):
7474
model_state,
7575
spec.ForwardPassMode.TRAIN,
7676
rng,
77-
update_batch_norm=True,)
77+
update_batch_norm=True)
7878
jax.debug.print("logits: {logits}", logits=logits)
7979
loss_dict = workload.loss_fn(
8080
label_batch=batch['targets'],
@@ -136,6 +136,10 @@ def update_params(
136136
else:
137137
grad_clip = None
138138

139+
batch_shapes = jax.tree.map(jnp.shape, batch)
140+
print("batch shapes:")
141+
print(batch_shapes)
142+
139143
# Set up mesh and sharding
140144
mesh = sharding_utils.get_mesh()
141145
replicated = NamedSharding(mesh, P()) # No partitioning

submission_runner.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,8 @@
3232
import jax
3333
import tensorflow as tf
3434

35+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
36+
3537
# New PRNG implementation for correct sharding
3638
jax.config.update('jax_default_prng_impl', 'threefry2x32')
3739
jax.config.update('jax_threefry_partitionable', True)

0 commit comments

Comments
 (0)