
Commit d3f25d8

add tests
1 parent: ad36a7c

10 files changed: 337 additions, 598 deletions


algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@ class DLRMResNet(nn.Module):
   mlp_bottom_dims: Sequence[int] = (256, 256, 256)
   mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
   embed_dim: int = 128
-  dropout_rate: float = 0.0
+  dropout_rate: float = 0.1
   use_layer_norm: bool = False  # Unused.
   embedding_init_multiplier: float = None  # Unused
 
```
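The practical effect is limited to callers that do not pass an explicit rate. A minimal sketch of checking the new default (assuming the module is importable from the path above and that `dropout_rate` remains a plain dataclass default):

```python
# Sketch: the changed default is visible directly on the Flax module dataclass.
from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import DLRMResNet

assert DLRMResNet.dropout_rate == 0.1  # was 0.0 before this commit
```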

algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -36,12 +36,12 @@ class ConformerConfig:
   encoder_dim: int = 512
   num_attention_heads: int = 8
   num_encoder_layers: int = 4
-  attention_dropout_rate: float = 0.0
+  attention_dropout_rate: float = 0.1
   # If None, defaults to 0.1.
   attention_residual_dropout_rate: Optional[float] = 0.1
   # If None, defaults to 0.0.
-  conv_residual_dropout_rate: Optional[float] = 0.0
-  feed_forward_dropout_rate: float = 0.0
+  conv_residual_dropout_rate: Optional[float] = 0.1
+  feed_forward_dropout_rate: float = 0.1
   # If None, defaults to 0.1.
   feed_forward_residual_dropout_rate: Optional[float] = 0.1
   convolution_kernel_size: int = 5
```
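The same pattern applies to the Conformer reference config: three rates now default to 0.1 while the residual-dropout fields keep their existing defaults. A quick sanity check, as a sketch (assuming `ConformerConfig` can be instantiated with all fields left at their defaults):

```python
from algoperf.workloads.librispeech_conformer.librispeech_jax.models_ref import \
    ConformerConfig

cfg = ConformerConfig()  # all fields at their defaults
assert cfg.attention_dropout_rate == 0.1
assert cfg.conv_residual_dropout_rate == 0.1
assert cfg.feed_forward_dropout_rate == 0.1

# Any rate can still be overridden per instance, e.g. for evaluation:
eval_cfg = ConformerConfig(attention_dropout_rate=0.0)
```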

algoperf/workloads/ogbg/ogbg_jax/models.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -93,4 +93,4 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
     decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
     graph = decoder(graph)
 
-    return graph.globals
+    return graph.globals
```

algoperf/workloads/ogbg/ogbg_jax/models_ref.py

Lines changed: 6 additions & 10 deletions
```diff
@@ -15,7 +15,7 @@ def make_fn(inputs):
   return make_fn
 
 
-def _make_mlp(hidden_dims, dropout, activation_fn):
+def _make_mlp(hidden_dims, activation_fn, train, dropout_rate):
   """Creates a MLP with specified dimensions."""
 
   @jraph.concatenated_args
@@ -25,7 +25,7 @@ def make_fn(inputs):
       x = nn.Dense(features=dim)(x)
       x = nn.LayerNorm()(x)
       x = activation_fn(x)
-      x = dropout(x)
+      x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
     return x
 
   return make_fn
@@ -46,11 +46,7 @@ class GNN(nn.Module):
 
   @nn.compact
   def __call__(self, graph, train):
-    if self.dropout_rate is None:
-      dropout_rate = 0.1
-    else:
-      dropout_rate = self.dropout_rate
-    dropout = nn.Dropout(rate=dropout_rate, deterministic=not train)
+    dropout_rate = self.dropout_rate
 
     graph = graph._replace(
         globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
@@ -73,11 +69,11 @@ def __call__(self, graph, train):
     for _ in range(self.num_message_passing_steps):
       net = jraph.GraphNetwork(
           update_edge_fn=_make_mlp(
-              self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
+              self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
           update_node_fn=_make_mlp(
-              self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
+              self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
           update_global_fn=_make_mlp(
-              self.hidden_dims, dropout=dropout, activation_fn=activation_fn))
+              self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate))
 
       graph = net(graph)
 
```
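The shape of this refactor: instead of building one `nn.Dropout` callable up front (with a None-means-0.1 fallback) and closing over it, `_make_mlp` now receives `train` and `dropout_rate` directly and constructs `nn.Dropout` where it is applied. Below is a self-contained toy sketch of the same call pattern; `TinyNet` and its sizes are hypothetical, only to make the snippet runnable, and the repository wires the returned functions into `jraph.GraphNetwork` instead.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


def _make_mlp(hidden_dims, activation_fn, train, dropout_rate):
  """Returns a fn applying Dense -> LayerNorm -> activation -> Dropout per dim."""

  def make_fn(x):
    for dim in hidden_dims:
      x = nn.Dense(features=dim)(x)
      x = nn.LayerNorm()(x)
      x = activation_fn(x)
      # Dropout is built inline from the passed-in rate and train flag.
      x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
    return x

  return make_fn


class TinyNet(nn.Module):
  """Hypothetical wrapper so the MLP factory runs inside a compact scope."""
  dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, x, train):
    mlp = _make_mlp((32, 32), nn.relu, train=train,
                    dropout_rate=self.dropout_rate)
    return mlp(x)


x = jnp.ones((4, 16))
params = TinyNet().init({'params': jax.random.key(0)}, x, train=False)
y = TinyNet().apply(params, x, train=True, rngs={'dropout': jax.random.key(1)})
```

Passing the rate as an argument rather than baking it into a pre-built layer is what lets the test suite below sweep `dropout_rate` at call time.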

tests/dropout_fix/fastmri_jax/test_model_equivalence.py

Lines changed: 81 additions & 93 deletions
```diff
@@ -6,125 +6,113 @@
 
 import os
 
+
 from absl.testing import absltest
 from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
+import jax
+import jax.numpy as jnp
+# import equinox as eqx
+
 
-from algoperf.workloads.fastmri.fastmri_pytorch.models import \
+from algoperf.workloads.fastmri.fastmri_jax.models_ref import \
     UNet as OriginalUNet
-from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \
+from algoperf.workloads.fastmri.fastmri_jax.models import \
     UNet as CustomUNet
 
 BATCH, IN_CHANS, H, W = 4, 1, 256, 256
 OUT_CHANS, C, LAYERS = 1, 32, 4
-DEVICE = 'cuda'
-TORCH_COMPILE = False
 SEED = 1996
 
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-
 
-class FastMRIModeEquivalenceTest(parameterized.TestCase):
-
-  def fwd_pass(self, orig, cust, dropout_rate):
-    x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
-    for mode in ('train', 'eval'):
-      getattr(orig, mode)()
-      getattr(cust, mode)()
-      torch.manual_seed(0)
-      y1 = orig(x)
-      torch.manual_seed(0)
-      y2 = cust(x, dropout_rate)
-      assert_close(y1, y2, atol=0, rtol=0)
-      if mode == 'eval':  # one extra test: omit dropout at eval
-        torch.manual_seed(0)
-        y2 = cust(x)
-        assert_close(y1, y2, atol=0, rtol=0)
+class ModelEquivalenceTest(parameterized.TestCase):
 
   @parameterized.named_parameters(
-      dict(testcase_name='p=0.0', dropout_rate=0.0),
-      dict(testcase_name='p=0.1', dropout_rate=0.1),
-      dict(testcase_name='p=0.7', dropout_rate=0.7),
-      dict(testcase_name='p=1.0', dropout_rate=1.0),
+      dict(
+          testcase_name='UNet, p=0.0',
+          dropout_rate=0.0),
+      dict(
+          testcase_name='UNet, p=0.1',
+          dropout_rate=0.1),
   )
-  def test_dropout_values(self, dropout_rate):
-    """Test different values of dropout_rate."""
+  def test_forward(self, dropout_rate):
+    OrigCls, CustCls = (OriginalUNet, CustomUNet)
 
-    torch.manual_seed(SEED)
-    orig = OriginalUNet(
-        IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
 
-    torch.manual_seed(SEED)
-    cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+    # init model
+    rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
 
-    cust.load_state_dict(orig.state_dict())  # sync weights
-    if TORCH_COMPILE:
-      orig = torch.compile(orig)
-      cust = torch.compile(cust)
+    kwargs = dict(num_pool_layers = LAYERS, num_channels=IN_CHANS)
+    orig_model = OrigCls(**kwargs)
+    cust_model = CustCls(**kwargs)
 
-    self.fwd_pass(orig, cust, dropout_rate)
+    fake_batch = jnp.ones((BATCH, IN_CHANS, H, W))
 
-  @parameterized.named_parameters(
-      dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
-      dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
-      dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
-      dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
-  )
-  def test_arch_configs(self, use_tanh, use_layer_norm):
-    """Test different architecture configurations, fixed dropout_rate."""
-    dropout_rate = 0.1
-
-    torch.manual_seed(SEED)
-    orig = OriginalUNet(
-        IN_CHANS,
-        OUT_CHANS,
-        C,
-        LAYERS,
-        dropout_rate=dropout_rate,
-        use_tanh=use_tanh,
-        use_layer_norm=use_layer_norm).to(DEVICE)
-
-    torch.manual_seed(SEED)
-    cust = CustomUNet(
-        IN_CHANS,
-        OUT_CHANS,
-        C,
-        LAYERS,
-        use_tanh=use_tanh,
-        use_layer_norm=use_layer_norm).to(DEVICE)
-
-    cust.load_state_dict(orig.state_dict())  # sync weights
-    if TORCH_COMPILE:
-      orig = torch.compile(orig)
-      cust = torch.compile(cust)
-
-    self.fwd_pass(orig, cust, dropout_rate)
+    initial_params_original = orig_model.init({'params': rng},
+                                              fake_batch,
+                                              train=False)
+    initial_params_custom = cust_model.init({'params': rng},
+                                            fake_batch,
+                                            train=False)
+
+    # fwd
+    x = jax.random.normal(data_rng, shape=(BATCH, H, W))
+
+    for mode in ('train', 'eval'):
+      train = mode == 'train'
+      y1 = orig_model.apply(
+          initial_params_original,
+          x,
+          train=train,
+          rngs={'dropout': dropout_rng})
+      y2 = cust_model.apply(
+          initial_params_custom,
+          x,
+          train=train,
+          dropout_rate=dropout_rate,
+          rngs={'dropout': dropout_rng})
+
+      assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
 
   @parameterized.named_parameters(
-      dict(testcase_name=''),)
+      dict(testcase_name='UNet, default'),
+  )
   def test_default_dropout(self):
     """Test default dropout_rate."""
+    OrigCls, CustCls = (OriginalUNet, CustomUNet)
 
-    torch.manual_seed(SEED)
-    orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
-    torch.manual_seed(SEED)
-    cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
-    cust.load_state_dict(orig.state_dict())  # sync weights
 
-    x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
-    for mode in ('train', 'eval'):
-      getattr(orig, mode)()
-      getattr(cust, mode)()
-      torch.manual_seed(0)
-      y1 = orig(x)
-      torch.manual_seed(0)
-      y2 = cust(x)
-      assert_close(y1, y2, atol=0, rtol=0)
+    # init model
+    rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+    kwargs = dict(num_pool_layers=LAYERS,
+                  num_channels=IN_CHANS,
+                  )
+    orig_model = OrigCls(**kwargs)
+    cust_model = CustCls(**kwargs)
 
+    fake_batch = jnp.ones((2, IN_CHANS, H, W))
+
+    initial_params_original = orig_model.init({'params': rng},
+                                              fake_batch,
+                                              train=False)
+    initial_params_custom = cust_model.init({'params': rng},
+                                            fake_batch,
+                                            train=False)
+
+    # fwd
+    x = jax.random.normal(data_rng, shape=(BATCH, H, W))
+
+    for mode in ('train', 'eval'):
+      train = mode == 'train'
+      y1 = orig_model.apply(
+          initial_params_original,
+          x,
+          train=train,
+          rngs={'dropout': dropout_rng})
+      y2 = cust_model.apply(
+          initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
+
+      assert jnp.allclose(y1, y2, atol=0, rtol=0)
 
 if __name__ == '__main__':
   absltest.main()
```
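The new test file keeps the absltest entry point, so it can be executed directly as a script or loaded programmatically. A small sketch using the standard unittest loader (assuming the `tests` directory is importable as a package from the repository root and JAX is installed):

```python
# Sketch: load and run the new JAX equivalence tests without the CLI.
import unittest

from tests.dropout_fix.fastmri_jax import test_model_equivalence

suite = unittest.defaultTestLoader.loadTestsFromModule(test_model_equivalence)
unittest.TextTestRunner(verbosity=2).run(suite)
```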
