Skip to content

Commit 3ac97ae

Browse files
committed
fix formatting
1 parent 9090e43 commit 3ac97ae

File tree

7 files changed

+540
-470
lines changed

7 files changed

+540
-470
lines changed

tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py

Lines changed: 76 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -31,63 +31,90 @@
3131
torch.backends.cudnn.deterministic = True
3232
torch.use_deterministic_algorithms(True)
3333

34+
3435
class ModelEquivalenceTest(parameterized.TestCase):
3536

36-
@parameterized.named_parameters(
37-
dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0),
38-
dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0),
39-
dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1),
40-
dict(testcase_name='DlrmSmall, p=0.1', model='dlrm_small', dropout_rate=0.1),
41-
dict(testcase_name='DLRMResNet, p=1.0', model='dlrm_resnet', dropout_rate=1.0),
42-
dict(testcase_name='DlrmSmall, p=1.0', model='dlrm_small', dropout_rate=1.0),
37+
@parameterized.named_parameters(
38+
dict(
39+
testcase_name='DLRMResNet, p=0.0',
40+
model='dlrm_resnet',
41+
dropout_rate=0.0),
42+
dict(
43+
testcase_name='DlrmSmall, p=0.0',
44+
model='dlrm_small',
45+
dropout_rate=0.0),
46+
dict(
47+
testcase_name='DLRMResNet, p=0.1',
48+
model='dlrm_resnet',
49+
dropout_rate=0.1),
50+
dict(
51+
testcase_name='DlrmSmall, p=0.1',
52+
model='dlrm_small',
53+
dropout_rate=0.1),
54+
dict(
55+
testcase_name='DLRMResNet, p=1.0',
56+
model='dlrm_resnet',
57+
dropout_rate=1.0),
58+
dict(
59+
testcase_name='DlrmSmall, p=1.0',
60+
model='dlrm_small',
61+
dropout_rate=1.0),
62+
)
63+
def test_forward(self, model, dropout_rate):
64+
OrigCls, CustCls = (
65+
(OriginalDLRMResNet, CustomDLRMResNet)
66+
if model == 'dlrm_resnet'
67+
else (OriginalDlrmSmall, CustomDlrmSmall)
4368
)
44-
def test_forward(self, model, dropout_rate):
45-
OrigCls, CustCls = (
46-
(OriginalDLRMResNet, CustomDLRMResNet)
47-
if model == 'dlrm_resnet'
48-
else (OriginalDlrmSmall, CustomDlrmSmall)
49-
)
5069

51-
torch.manual_seed(SEED)
52-
orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE)
70+
torch.manual_seed(SEED)
71+
orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE)
72+
73+
torch.manual_seed(SEED)
74+
cust = CustCls(vocab_size=VOCAB).to(DEVICE)
75+
76+
x = torch.randn(BATCH, FEATURES, device=DEVICE)
5377

78+
for mode in ('train', 'eval'):
79+
getattr(orig, mode)()
80+
getattr(cust, mode)()
81+
torch.manual_seed(SEED)
82+
y1 = orig(x)
83+
torch.manual_seed(SEED)
84+
y2 = cust(x, dropout_rate)
85+
assert_close(y1, y2, atol=0, rtol=0)
86+
if mode == 'eval': # one extra test: omit dropout at eval
5487
torch.manual_seed(SEED)
55-
cust = CustCls(vocab_size=VOCAB).to(DEVICE)
56-
57-
x = torch.randn(BATCH, FEATURES, device=DEVICE)
58-
59-
for mode in ('train', 'eval'):
60-
getattr(orig, mode)(); getattr(cust, mode)()
61-
torch.manual_seed(SEED); y1 = orig(x)
62-
torch.manual_seed(SEED); y2 = cust(x, dropout_rate)
63-
assert_close(y1, y2, atol=0, rtol=0)
64-
if mode == 'eval': # one extra test: omit dropout at eval
65-
torch.manual_seed(SEED); y2 = cust(x)
66-
assert_close(y1, y2, atol=0, rtol=0)
67-
68-
@parameterized.named_parameters(
69-
dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'),
70-
dict(testcase_name='DlrmSmall, default', model='dlrm_small'),
88+
y2 = cust(x)
89+
assert_close(y1, y2, atol=0, rtol=0)
90+
91+
@parameterized.named_parameters(
92+
dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'),
93+
dict(testcase_name='DlrmSmall, default', model='dlrm_small'),
94+
)
95+
def test_default_dropout(self, model):
96+
"""Test default dropout_rate."""
97+
OrigCls, CustCls = (
98+
(OriginalDLRMResNet, CustomDLRMResNet)
99+
if model == 'dlrm_resnet'
100+
else (OriginalDlrmSmall, CustomDlrmSmall)
71101
)
72-
def test_default_dropout(self, model):
73-
"""Test default dropout_rate."""
74-
OrigCls, CustCls = (
75-
(OriginalDLRMResNet, CustomDLRMResNet)
76-
if model == 'dlrm_resnet'
77-
else (OriginalDlrmSmall, CustomDlrmSmall)
78-
)
79102

80-
torch.manual_seed(SEED)
81-
orig = OrigCls(vocab_size=VOCAB).to(DEVICE)
82-
torch.manual_seed(SEED)
83-
cust = CustCls(vocab_size=VOCAB).to(DEVICE)
103+
torch.manual_seed(SEED)
104+
orig = OrigCls(vocab_size=VOCAB).to(DEVICE)
105+
torch.manual_seed(SEED)
106+
cust = CustCls(vocab_size=VOCAB).to(DEVICE)
107+
108+
x = torch.randn(BATCH, FEATURES, device=DEVICE)
109+
for mode in ('train', 'eval'):
110+
getattr(orig, mode)()
111+
getattr(cust, mode)()
112+
torch.manual_seed(0)
113+
y1 = orig(x)
114+
torch.manual_seed(0)
115+
y2 = cust(x)
116+
assert_close(y1, y2, atol=0, rtol=0)
84117

85-
x = torch.randn(BATCH, FEATURES, device=DEVICE)
86-
for mode in ('train', 'eval'):
87-
getattr(orig, mode)(); getattr(cust, mode)()
88-
torch.manual_seed(0); y1 = orig(x)
89-
torch.manual_seed(0); y2 = cust(x)
90-
assert_close(y1, y2, atol=0, rtol=0)
91118

92119
if __name__ == '__main__':
93-
absltest.main()
120+
absltest.main()

tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py

Lines changed: 97 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -27,87 +27,104 @@
2727
torch.backends.cudnn.deterministic = True
2828
torch.use_deterministic_algorithms(True)
2929

30+
3031
class FastMRIModeEquivalenceTest(parameterized.TestCase):
3132

32-
def fwd_pass(self, orig, cust, dropout_rate):
33-
x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
34-
for mode in ('train', 'eval'):
35-
getattr(orig, mode)(); getattr(cust, mode)()
36-
torch.manual_seed(0); y1 = orig(x)
37-
torch.manual_seed(0); y2 = cust(x, dropout_rate)
38-
assert_close(y1, y2, atol=0, rtol=0)
39-
if mode == 'eval': # one extra test: omit dropout at eval
40-
torch.manual_seed(0); y2 = cust(x)
41-
assert_close(y1, y2, atol=0, rtol=0)
42-
43-
@parameterized.named_parameters(
44-
dict(testcase_name='p=0.0', dropout_rate=0.0),
45-
dict(testcase_name='p=0.1', dropout_rate=0.1),
46-
dict(testcase_name='p=0.7', dropout_rate=0.7),
47-
dict(testcase_name='p=1.0', dropout_rate=1.0),
48-
)
49-
def test_dropout_values(self, dropout_rate):
50-
"""Test different values of dropout_rate."""
51-
52-
torch.manual_seed(SEED)
53-
orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
54-
55-
torch.manual_seed(SEED)
56-
cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
57-
58-
cust.load_state_dict(orig.state_dict()) # sync weights
59-
if TORCH_COMPILE:
60-
orig = torch.compile(orig); cust = torch.compile(cust)
61-
62-
self.fwd_pass(orig, cust, dropout_rate)
63-
64-
65-
@parameterized.named_parameters(
66-
dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
67-
dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
68-
dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
69-
dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
70-
)
71-
def test_arch_configs(self, use_tanh, use_layer_norm):
72-
"""Test different architecture configurations, fixed dropout_rate."""
73-
dropout_rate = 0.1
74-
75-
torch.manual_seed(SEED)
76-
orig = OriginalUNet(
77-
IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate,
78-
use_tanh=use_tanh, use_layer_norm=use_layer_norm
79-
).to(DEVICE)
80-
81-
torch.manual_seed(SEED)
82-
cust = CustomUNet(
83-
IN_CHANS, OUT_CHANS, C, LAYERS,
84-
use_tanh=use_tanh, use_layer_norm=use_layer_norm
85-
).to(DEVICE)
86-
87-
cust.load_state_dict(orig.state_dict()) # sync weights
88-
if TORCH_COMPILE:
89-
orig = torch.compile(orig); cust = torch.compile(cust)
90-
91-
self.fwd_pass(orig, cust, dropout_rate)
92-
93-
@parameterized.named_parameters(
94-
dict(testcase_name=''),
95-
)
96-
def test_default_dropout(self):
97-
"""Test default dropout_rate."""
98-
99-
torch.manual_seed(SEED)
100-
orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
101-
torch.manual_seed(SEED)
102-
cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
103-
cust.load_state_dict(orig.state_dict()) # sync weights
104-
105-
x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
106-
for mode in ('train', 'eval'):
107-
getattr(orig, mode)(); getattr(cust, mode)()
108-
torch.manual_seed(0); y1 = orig(x)
109-
torch.manual_seed(0); y2 = cust(x)
110-
assert_close(y1, y2, atol=0, rtol=0)
33+
def fwd_pass(self, orig, cust, dropout_rate):
34+
x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
35+
for mode in ('train', 'eval'):
36+
getattr(orig, mode)()
37+
getattr(cust, mode)()
38+
torch.manual_seed(0)
39+
y1 = orig(x)
40+
torch.manual_seed(0)
41+
y2 = cust(x, dropout_rate)
42+
assert_close(y1, y2, atol=0, rtol=0)
43+
if mode == 'eval': # one extra test: omit dropout at eval
44+
torch.manual_seed(0)
45+
y2 = cust(x)
46+
assert_close(y1, y2, atol=0, rtol=0)
47+
48+
@parameterized.named_parameters(
49+
dict(testcase_name='p=0.0', dropout_rate=0.0),
50+
dict(testcase_name='p=0.1', dropout_rate=0.1),
51+
dict(testcase_name='p=0.7', dropout_rate=0.7),
52+
dict(testcase_name='p=1.0', dropout_rate=1.0),
53+
)
54+
def test_dropout_values(self, dropout_rate):
55+
"""Test different values of dropout_rate."""
56+
57+
torch.manual_seed(SEED)
58+
orig = OriginalUNet(
59+
IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
60+
61+
torch.manual_seed(SEED)
62+
cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
63+
64+
cust.load_state_dict(orig.state_dict()) # sync weights
65+
if TORCH_COMPILE:
66+
orig = torch.compile(orig)
67+
cust = torch.compile(cust)
68+
69+
self.fwd_pass(orig, cust, dropout_rate)
70+
71+
@parameterized.named_parameters(
72+
dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
73+
dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
74+
dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
75+
dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
76+
)
77+
def test_arch_configs(self, use_tanh, use_layer_norm):
78+
"""Test different architecture configurations, fixed dropout_rate."""
79+
dropout_rate = 0.1
80+
81+
torch.manual_seed(SEED)
82+
orig = OriginalUNet(
83+
IN_CHANS,
84+
OUT_CHANS,
85+
C,
86+
LAYERS,
87+
dropout_rate=dropout_rate,
88+
use_tanh=use_tanh,
89+
use_layer_norm=use_layer_norm).to(DEVICE)
90+
91+
torch.manual_seed(SEED)
92+
cust = CustomUNet(
93+
IN_CHANS,
94+
OUT_CHANS,
95+
C,
96+
LAYERS,
97+
use_tanh=use_tanh,
98+
use_layer_norm=use_layer_norm).to(DEVICE)
99+
100+
cust.load_state_dict(orig.state_dict()) # sync weights
101+
if TORCH_COMPILE:
102+
orig = torch.compile(orig)
103+
cust = torch.compile(cust)
104+
105+
self.fwd_pass(orig, cust, dropout_rate)
106+
107+
@parameterized.named_parameters(
108+
dict(testcase_name=''),)
109+
def test_default_dropout(self):
110+
"""Test default dropout_rate."""
111+
112+
torch.manual_seed(SEED)
113+
orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
114+
torch.manual_seed(SEED)
115+
cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
116+
cust.load_state_dict(orig.state_dict()) # sync weights
117+
118+
x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
119+
for mode in ('train', 'eval'):
120+
getattr(orig, mode)()
121+
getattr(cust, mode)()
122+
torch.manual_seed(0)
123+
y1 = orig(x)
124+
torch.manual_seed(0)
125+
y2 = cust(x)
126+
assert_close(y1, y2, atol=0, rtol=0)
127+
111128

112129
if __name__ == '__main__':
113-
absltest.main()
130+
absltest.main()

0 commit comments

Comments
 (0)