Skip to content

Commit caacb84

Browse files
committed
add tests
1 parent d3f25d8 commit caacb84

File tree

2 files changed

+262
-0
lines changed

2 files changed

+262
-0
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
Runs fwd pass with random input for OGBG
3+
"""
4+
5+
import os
6+
7+
import jraph
8+
9+
from absl.testing import absltest
10+
from absl.testing import parameterized
11+
import jax
12+
import jax.numpy as jnp
13+
14+
from algoperf.workloads.ogbg.ogbg_jax.models_ref import \
15+
GNN as OrigCls
16+
from algoperf.workloads.ogbg.ogbg_jax.models import \
17+
GNN as CustCls
18+
19+
# Model / test hyper-params
# Fixed PRNG seed so reference and custom models see identical randomness.
SEED = 1994
22+
class ModeEquivalenceTest(parameterized.TestCase):
  """Checks the custom OGBG GNN against the reference implementation.

  Both models are initialized from the same RNG and applied to the same
  minimal graph; their outputs must agree in train and eval mode.
  """

  @staticmethod
  def _fake_graph():
    """Build the minimal single-node, single-edge graph batch.

    The same fixture is used both for `init` and for the forward pass.
    Feature sizes (9 node features, 3 edge features, 128 globals) match
    the OGBG workload defaults — TODO confirm against the workload config.
    """
    return jraph.GraphsTuple(
        n_node=jnp.asarray([1]),
        n_edge=jnp.asarray([1]),
        nodes=jnp.ones((1, 9)),
        edges=jnp.ones((1, 3)),
        globals=jnp.zeros((1, 128)),
        senders=jnp.asarray([0]),
        receivers=jnp.asarray([0]))

  @parameterized.named_parameters(
      dict(testcase_name='OGBG, p=0.0', dropout_rate=0.0),
      dict(testcase_name='OGBG, p=0.1', dropout_rate=0.1),
  )
  def test_forward(self, dropout_rate):
    """Forward passes of both models agree for an explicit dropout_rate."""
    # Three-way split kept (second key unused) so dropout_rng matches the
    # seeding used elsewhere in these equivalence tests.
    rng, _, dropout_rng = jax.random.split(jax.random.key(SEED), 3)

    # The reference model takes dropout_rate at construction; the custom
    # model takes it at apply time.
    orig_model = OrigCls(num_outputs=128, dropout_rate=dropout_rate)
    cust_model = CustCls(num_outputs=128)

    fake_batch = self._fake_graph()
    initial_params_original = orig_model.init({'params': rng},
                                              fake_batch,
                                              train=False)
    initial_params_custom = cust_model.init({'params': rng},
                                            fake_batch,
                                            train=False)

    x = self._fake_graph()
    for mode in ('train', 'eval'):
      train = mode == 'train'
      y1 = orig_model.apply(
          initial_params_original,
          x,
          train=train,
          rngs={'dropout': dropout_rng})
      y2 = cust_model.apply(
          initial_params_custom,
          x,
          train=train,
          dropout_rate=dropout_rate,
          rngs={'dropout': dropout_rng})
      # Loose tolerance: the two implementations may differ in op ordering.
      assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)

  @parameterized.named_parameters(
      dict(testcase_name='OGBG, default'),
  )
  def test_default_dropout(self):
    """Custom model with no dropout_rate argument matches the reference."""
    rng, _, dropout_rng = jax.random.split(jax.random.key(SEED), 3)

    orig_model = OrigCls(num_outputs=128)
    cust_model = CustCls(num_outputs=128)

    fake_batch = self._fake_graph()
    initial_params_original = orig_model.init({'params': rng},
                                              fake_batch,
                                              train=False)
    initial_params_custom = cust_model.init({'params': rng},
                                            fake_batch,
                                            train=False)

    x = self._fake_graph()
    for mode in ('train', 'eval'):
      train = mode == 'train'
      y1 = orig_model.apply(
          initial_params_original,
          x,
          train=train,
          rngs={'dropout': dropout_rng})
      y2 = cust_model.apply(
          initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
      # Exact equality expected here: with the default dropout rate both
      # models should trace to the identical computation.
      assert jnp.allclose(y1, y2, atol=0, rtol=0)
131+
132+
# Allow running this test file directly (outside a test runner).
if __name__ == '__main__':
  absltest.main()
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
"""
Runs fwd pass with random input for WMT Transformer models and compares outputs.
Run it as:
  python3 tests/dropout_fix/wmt_jax/test_model_equivalence.py
"""
6+
7+
import os
8+
9+
from absl.testing import absltest
10+
from absl.testing import parameterized
11+
import jax
12+
import jax.numpy as jnp
13+
14+
from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig as CustClsConfig
15+
from algoperf.workloads.wmt.wmt_jax.models import Transformer as CustCls
16+
17+
from algoperf.workloads.wmt.wmt_jax.models_ref import TransformerConfig as OrigClsConfig
18+
from algoperf.workloads.wmt.wmt_jax.models_ref import Transformer as OrigCls
19+
20+
21+
# Model / test hyper-params
# Fixed PRNG seed so reference and custom models see identical randomness.
SEED = 1994
23+
24+
class ModeEquivalenceTest(parameterized.TestCase):
  """Checks the custom WMT Transformer against the reference implementation.

  Both models are initialized from the same RNG and applied to the same
  all-ones source/target batch; outputs must agree in train and eval mode.
  """

  # Shared fixture dimensions: batch of 8 sequences of length 256.
  _INPUT_SHAPE = (8, 256)
  _TARGET_SHAPE = (8, 256)

  @parameterized.named_parameters(
      dict(testcase_name='WMT, p=0.0', dropout_rate=0.0),
      dict(testcase_name='WMT p=0.1', dropout_rate=0.1),
  )
  def test_forward(self, dropout_rate):
    """Forward passes of both models agree for an explicit dropout_rate."""
    # Three-way split kept (second key unused) so dropout_rng matches the
    # seeding used elsewhere in these equivalence tests.
    rng, _, dropout_rng = jax.random.split(jax.random.key(SEED), 3)

    orig_model = OrigCls(OrigClsConfig)
    cust_model = CustCls(CustClsConfig)

    inputs = jnp.ones(self._INPUT_SHAPE, jnp.float32)
    targets = jnp.ones(self._TARGET_SHAPE, jnp.float32)

    initial_params_original = orig_model.init({'params': rng},
                                              inputs,
                                              targets,
                                              train=False)
    initial_params_custom = cust_model.init({'params': rng},
                                            inputs,
                                            targets,
                                            train=False)

    for mode in ('train', 'eval'):
      train = mode == 'train'
      # mutable=['batch_stats'] makes apply return (output, mutated_vars);
      # both tuples are compared element-wise below.
      y1 = orig_model.apply(
          initial_params_original,
          inputs,
          targets,
          train=train,
          rngs={'dropout': dropout_rng},
          mutable=['batch_stats'])
      y2 = cust_model.apply(
          initial_params_custom,
          inputs,
          targets,
          train=train,
          dropout_rate=dropout_rate,
          rngs={'dropout': dropout_rng},
          mutable=['batch_stats'])

      for a, b in zip(y1, y2):
        assert jnp.allclose(a, b)

  @parameterized.named_parameters(
      dict(testcase_name='WMT, default'),
  )
  def test_default_dropout(self):
    """Custom model with no dropout_rate argument matches the reference."""
    rng, _, dropout_rng = jax.random.split(jax.random.key(SEED), 3)

    orig_model = OrigCls(OrigClsConfig)
    cust_model = CustCls(CustClsConfig)

    inputs = jnp.ones(self._INPUT_SHAPE, jnp.float32)
    targets = jnp.ones(self._TARGET_SHAPE, jnp.float32)

    initial_params_original = orig_model.init({'params': rng},
                                              inputs,
                                              targets,
                                              train=False)
    initial_params_custom = cust_model.init({'params': rng},
                                            inputs,
                                            targets,
                                            train=False)

    # NOTE(review): the original body built `x` from an undefined
    # INPUT_SHAPE constant (a NameError) and never used it; that dead,
    # broken line and a leftover debug print() were removed.
    for mode in ('train', 'eval'):
      train = mode == 'train'
      y1 = orig_model.apply(
          initial_params_original,
          inputs,
          targets,
          train=train,
          rngs={'dropout': dropout_rng},
          mutable=['batch_stats'])
      y2 = cust_model.apply(
          initial_params_custom,
          inputs,
          targets,
          train=train,
          rngs={'dropout': dropout_rng},
          mutable=['batch_stats'])

      for a, b in zip(y1, y2):
        assert jnp.allclose(a, b)
126+
127+
128+
# Allow running this test file directly (outside a test runner).
if __name__ == '__main__':
  absltest.main()

0 commit comments

Comments
 (0)