 
-import os
 
 from absl.testing import absltest
 from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
+import jax
+import jax.numpy as jnp
 
-from algoperf.workloads.fastmri.fastmri_pytorch.models import \
+from algoperf.workloads.fastmri.fastmri_jax.models_ref import \
     UNet as OriginalUNet
-from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \
+from algoperf.workloads.fastmri.fastmri_jax.models import \
     UNet as CustomUNet
 
 BATCH, IN_CHANS, H, W = 4, 1, 256, 256
 OUT_CHANS, C, LAYERS = 1, 32, 4
-DEVICE = 'cuda'
-TORCH_COMPILE = False
 SEED = 1996
 
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-
 
-class FastMRIModeEquivalenceTest(parameterized.TestCase):
-
-  def fwd_pass(self, orig, cust, dropout_rate):
-    x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
-    for mode in ('train', 'eval'):
-      getattr(orig, mode)()
-      getattr(cust, mode)()
-      torch.manual_seed(0)
-      y1 = orig(x)
-      torch.manual_seed(0)
-      y2 = cust(x, dropout_rate)
-      assert_close(y1, y2, atol=0, rtol=0)
-      if mode == 'eval':  # one extra test: omit dropout at eval
-        torch.manual_seed(0)
-        y2 = cust(x)
-        assert_close(y1, y2, atol=0, rtol=0)
+class ModelEquivalenceTest(parameterized.TestCase):
 
   @parameterized.named_parameters(
-      dict(testcase_name='p=0.0', dropout_rate=0.0),
-      dict(testcase_name='p=0.1', dropout_rate=0.1),
-      dict(testcase_name='p=0.7', dropout_rate=0.7),
-      dict(testcase_name='p=1.0', dropout_rate=1.0),
+      dict(testcase_name='UNet, p=0.0', dropout_rate=0.0),
+      dict(testcase_name='UNet, p=0.1', dropout_rate=0.1),
   )
-  def test_dropout_values(self, dropout_rate):
-    """Test different values of dropout_rate."""
+  def test_forward(self, dropout_rate):
+    """Test that reference and custom UNets agree for a given dropout_rate."""
 
-    torch.manual_seed(SEED)
-    orig = OriginalUNet(
-        IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
 
-    torch.manual_seed(SEED)
-    cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+    # Init both models from the same RNG so their weights are identical.
+    rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
 
-    cust.load_state_dict(orig.state_dict())  # sync weights
-    if TORCH_COMPILE:
-      orig = torch.compile(orig)
-      cust = torch.compile(cust)
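+    # num_pool_layers sets the UNet depth and num_channels its base conv
+    # width (hence C, not IN_CHANS).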
+    kwargs = dict(num_pool_layers=LAYERS, num_channels=C)
+    orig_model = OriginalUNet(**kwargs)
+    cust_model = CustomUNet(**kwargs)
 
-    self.fwd_pass(orig, cust, dropout_rate)
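+    # Flax infers parameter shapes from this dummy (batch, H, W) input.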
+    fake_batch = jnp.ones((BATCH, H, W))
 
-  @parameterized.named_parameters(
-      dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
-      dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
-      dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
-      dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
-  )
-  def test_arch_configs(self, use_tanh, use_layer_norm):
-    """Test different architecture configurations, fixed dropout_rate."""
-    dropout_rate = 0.1
-
-    torch.manual_seed(SEED)
-    orig = OriginalUNet(
-        IN_CHANS,
-        OUT_CHANS,
-        C,
-        LAYERS,
-        dropout_rate=dropout_rate,
-        use_tanh=use_tanh,
-        use_layer_norm=use_layer_norm).to(DEVICE)
-
-    torch.manual_seed(SEED)
-    cust = CustomUNet(
-        IN_CHANS,
-        OUT_CHANS,
-        C,
-        LAYERS,
-        use_tanh=use_tanh,
-        use_layer_norm=use_layer_norm).to(DEVICE)
-
-    cust.load_state_dict(orig.state_dict())  # sync weights
-    if TORCH_COMPILE:
-      orig = torch.compile(orig)
-      cust = torch.compile(cust)
-
-    self.fwd_pass(orig, cust, dropout_rate)
+    initial_params_original = orig_model.init({'params': rng},
+                                              fake_batch,
+                                              train=False)
+    initial_params_custom = cust_model.init({'params': rng},
+                                            fake_batch,
+                                            train=False)
+
+    # Forward pass in train and eval mode.
+    x = jax.random.normal(data_rng, shape=(BATCH, H, W))
+
+    for mode in ('train', 'eval'):
+      train = mode == 'train'
+      y1 = orig_model.apply(
+          initial_params_original,
+          x,
+          train=train,
+          rngs={'dropout': dropout_rng})
+      y2 = cust_model.apply(
+          initial_params_custom,
+          x,
+          train=train,
+          dropout_rate=dropout_rate,
+          rngs={'dropout': dropout_rng})
+
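+      # Identical weights and a shared dropout RNG should give matching
+      # outputs up to numerical tolerance.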
+      assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
 
   @parameterized.named_parameters(
-      dict(testcase_name=''),)
+      dict(testcase_name='UNet, default'),
+  )
   def test_default_dropout(self):
     """Test default dropout_rate."""
 
-    torch.manual_seed(SEED)
-    orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
-    torch.manual_seed(SEED)
-    cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
-    cust.load_state_dict(orig.state_dict())  # sync weights
 
-    x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
-    for mode in ('train', 'eval'):
-      getattr(orig, mode)()
-      getattr(cust, mode)()
-      torch.manual_seed(0)
-      y1 = orig(x)
-      torch.manual_seed(0)
-      y2 = cust(x)
-      assert_close(y1, y2, atol=0, rtol=0)
+    # Init both models from the same RNG so their weights are identical.
+    rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+    kwargs = dict(num_pool_layers=LAYERS, num_channels=C)
+    orig_model = OriginalUNet(**kwargs)
+    cust_model = CustomUNet(**kwargs)
 
+    fake_batch = jnp.ones((2, H, W))
+
+    initial_params_original = orig_model.init({'params': rng},
+                                              fake_batch,
+                                              train=False)
+    initial_params_custom = cust_model.init({'params': rng},
+                                            fake_batch,
+                                            train=False)
+
+    # Forward pass in train and eval mode.
+    x = jax.random.normal(data_rng, shape=(BATCH, H, W))
+
+    for mode in ('train', 'eval'):
+      train = mode == 'train'
+      y1 = orig_model.apply(
+          initial_params_original,
+          x,
+          train=train,
+          rngs={'dropout': dropout_rng})
+      y2 = cust_model.apply(
+          initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
+
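+      # With the default dropout_rate the custom model should reproduce the
+      # reference exactly.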
+      assert jnp.allclose(y1, y2, atol=0, rtol=0)
 
 if __name__ == '__main__':
   absltest.main()