1+ """
2+ Runs fwd pass with random input for FASTMRI U-Net models and compares outputs.
3+ Run it as:
4+ python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
5+ """
6+
7+ import os
8+
9+ from absl .testing import absltest
10+ from absl .testing import parameterized
11+ import jax
12+ import jax .numpy as jnp
13+
14+ from algoperf .workloads .wmt .wmt_jax .models import TransformerConfig as CustClsConfig
15+ from algoperf .workloads .wmt .wmt_jax .models import Transformer as CustCls
16+
17+ from algoperf .workloads .wmt .wmt_jax .models_ref import TransformerConfig as OrigClsConfig
18+ from algoperf .workloads .wmt .wmt_jax .models_ref import Transformer as OrigCls
19+
20+
21+ # Model / test hyper-params
22+ SEED = 1994
23+
24+ class ModeEquivalenceTest (parameterized .TestCase ):
25+
26+ @parameterized .named_parameters (
27+ dict (
28+ testcase_name = 'WMT, p=0.0' ,
29+ dropout_rate = 0.0 ),
30+ dict (
31+ testcase_name = 'WMT p=0.1' ,
32+ dropout_rate = 0.1 ),
33+ )
34+ def test_forward (self , dropout_rate ):
35+
36+ # init model
37+ rng , data_rng , dropout_rng = jax .random .split (jax .random .key (SEED ), 3 )
38+
39+ orig_model = OrigCls (OrigClsConfig )
40+ cust_model = CustCls (CustClsConfig )
41+
42+ init_fake_batch_size = 8
43+ input_shape = (init_fake_batch_size , 256 )
44+ target_shape = (init_fake_batch_size , 256 )
45+
46+ initial_params_original = orig_model .init ({'params' : rng },
47+ jnp .ones (input_shape , jnp .float32 ),
48+ jnp .ones (target_shape , jnp .float32 ),
49+ train = False )
50+ initial_params_custom = cust_model .init ({'params' : rng },
51+ jnp .ones (input_shape , jnp .float32 ),
52+ jnp .ones (target_shape , jnp .float32 ),
53+ train = False )
54+
55+ # fwd
56+
57+ for mode in ('train' , 'eval' ):
58+ train = mode == 'train'
59+ y1 = orig_model .apply (
60+ initial_params_original ,
61+ jnp .ones (input_shape , jnp .float32 ),
62+ jnp .ones (target_shape , jnp .float32 ),
63+ train = train ,
64+ rngs = {'dropout' : dropout_rng },
65+ mutable = ['batch_stats' ],)
66+ y2 = cust_model .apply (
67+ initial_params_custom ,
68+ jnp .ones (input_shape , jnp .float32 ),
69+ jnp .ones (target_shape , jnp .float32 ),
70+ train = train ,
71+ dropout_rate = dropout_rate ,
72+ rngs = {'dropout' : dropout_rng },
73+ mutable = ['batch_stats' ])
74+
75+ for i in range (len (y1 )):
76+ assert jnp .allclose (y1 [i ], y2 [i ])
77+
78+
79+
80+ @parameterized .named_parameters (
81+ dict (testcase_name = 'WMT, default' ),
82+ )
83+ def test_default_dropout (self ):
84+ """Test default dropout_rate."""
85+ # init model
86+ rng , data_rng , dropout_rng = jax .random .split (jax .random .key (SEED ), 3 )
87+
88+ orig_model = OrigCls (OrigClsConfig )
89+ cust_model = CustCls (CustClsConfig )
90+
91+ init_fake_batch_size = 8
92+ input_shape = (init_fake_batch_size , 256 )
93+ target_shape = (init_fake_batch_size , 256 )
94+
95+ initial_params_original = orig_model .init ({'params' : rng },
96+ jnp .ones (input_shape , jnp .float32 ),
97+ jnp .ones (target_shape , jnp .float32 ),
98+ train = False )
99+ initial_params_custom = cust_model .init ({'params' : rng },
100+ jnp .ones (input_shape , jnp .float32 ),
101+ jnp .ones (target_shape , jnp .float32 ),
102+ train = False )
103+
104+ # fwd
105+ x = [jax .random .normal (data_rng , (2 , * x )) for x in INPUT_SHAPE ]
106+
107+ for mode in ('train' , 'eval' ):
108+ train = mode == 'train'
109+ y1 = orig_model .apply (
110+ initial_params_original ,
111+ jnp .ones (input_shape , jnp .float32 ),
112+ jnp .ones (target_shape , jnp .float32 ),
113+ train = train ,
114+ rngs = {'dropout' : dropout_rng }, mutable = ['batch_stats' ])
115+ y2 = cust_model .apply (
116+ initial_params_custom ,
117+ jnp .ones (input_shape , jnp .float32 ),
118+ jnp .ones (target_shape , jnp .float32 ),
119+ train = train , rngs = {'dropout' : dropout_rng },
120+ mutable = ['batch_stats' ])
121+
122+ print (jax .tree .map (lambda x : x .shape , y1 ))
123+
124+ for i in range (len (y1 )):
125+ assert jnp .allclose (y1 [i ], y2 [i ])
126+
127+
128+ if __name__ == '__main__' :
129+ absltest .main ()
0 commit comments