Skip to content

Commit f192d4f

Browse files
Merge pull request #859 from mlcommons/iss807_modeldiff
[WIP] Refactor modeldiffs tests to improve variable names script names and deduplicate logic
2 parents 801151b + 5efdfad commit f192d4f

File tree

33 files changed

+229
-198
lines changed

33 files changed

+229
-198
lines changed

tests/modeldiffs/criteo1tb/compare.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Criteo1TbDlrmSmallWorkload as JaxWorkload
1313
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
1414
Criteo1TbDlrmSmallWorkload as PyTorchWorkload
15-
from tests.modeldiffs.diff import out_diff
15+
from tests.modeldiffs.diff import ModelDiffRunner
1616

1717

1818
def key_transform(k):
@@ -53,16 +53,16 @@ def sd_transform(sd):
5353
jax_workload = JaxWorkload()
5454
pytorch_workload = PyTorchWorkload()
5555

56-
pyt_batch = {
56+
pytorch_batch = {
5757
'inputs': torch.ones((2, 13 + 26)),
5858
'targets': torch.randint(low=0, high=1, size=(2,)),
5959
'weights': torch.ones(2),
6060
}
61-
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
61+
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
6262

6363
# Test outputs for identical weights and inputs.
6464
pytorch_model_kwargs = dict(
65-
augmented_and_preprocessed_input_batch=pyt_batch,
65+
augmented_and_preprocessed_input_batch=pytorch_batch,
6666
model_state=None,
6767
mode=spec.ForwardPassMode.EVAL,
6868
rng=None,
@@ -74,11 +74,11 @@ def sd_transform(sd):
7474
rng=jax.random.PRNGKey(0),
7575
update_batch_norm=False)
7676

77-
out_diff(
77+
ModelDiffRunner(
7878
jax_workload=jax_workload,
7979
pytorch_workload=pytorch_workload,
8080
jax_model_kwargs=jax_model_kwargs,
8181
pytorch_model_kwargs=pytorch_model_kwargs,
8282
key_transform=key_transform,
8383
sd_transform=sd_transform,
84-
out_transform=None)
84+
out_transform=None).run()

tests/modeldiffs/criteo1tb_embed_init/compare.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload
1313
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
1414
Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload
15-
from tests.modeldiffs.diff import out_diff
15+
from tests.modeldiffs.diff import ModelDiffRunner
1616

1717

1818
def key_transform(k):
@@ -52,16 +52,16 @@ def sd_transform(sd):
5252
jax_workload = JaxWorkload()
5353
pytorch_workload = PyTorchWorkload()
5454

55-
pyt_batch = {
55+
pytorch_batch = {
5656
'inputs': torch.ones((2, 13 + 26)),
5757
'targets': torch.randint(low=0, high=1, size=(2,)),
5858
'weights': torch.ones(2),
5959
}
60-
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
60+
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
6161

6262
# Test outputs for identical weights and inputs.
6363
pytorch_model_kwargs = dict(
64-
augmented_and_preprocessed_input_batch=pyt_batch,
64+
augmented_and_preprocessed_input_batch=pytorch_batch,
6565
model_state=None,
6666
mode=spec.ForwardPassMode.EVAL,
6767
rng=None,
@@ -73,11 +73,11 @@ def sd_transform(sd):
7373
rng=jax.random.PRNGKey(0),
7474
update_batch_norm=False)
7575

76-
out_diff(
76+
ModelDiffRunner(
7777
jax_workload=jax_workload,
7878
pytorch_workload=pytorch_workload,
7979
jax_model_kwargs=jax_model_kwargs,
8080
pytorch_model_kwargs=pytorch_model_kwargs,
8181
key_transform=key_transform,
8282
sd_transform=sd_transform,
83-
out_transform=None)
83+
out_transform=None).run()

tests/modeldiffs/criteo1tb_layernorm/compare.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload
1313
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
1414
Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload
15-
from tests.modeldiffs.diff import out_diff
15+
from tests.modeldiffs.diff import ModelDiffRunner
1616

1717

1818
def key_transform(k):
@@ -64,16 +64,16 @@ def sd_transform(sd):
6464
jax_workload = JaxWorkload()
6565
pytorch_workload = PyTorchWorkload()
6666

67-
pyt_batch = {
67+
pytorch_batch = {
6868
'inputs': torch.ones((2, 13 + 26)),
6969
'targets': torch.randint(low=0, high=1, size=(2,)),
7070
'weights': torch.ones(2),
7171
}
72-
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
72+
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
7373

7474
# Test outputs for identical weights and inputs.
7575
pytorch_model_kwargs = dict(
76-
augmented_and_preprocessed_input_batch=pyt_batch,
76+
augmented_and_preprocessed_input_batch=pytorch_batch,
7777
model_state=None,
7878
mode=spec.ForwardPassMode.EVAL,
7979
rng=None,
@@ -85,11 +85,11 @@ def sd_transform(sd):
8585
rng=jax.random.PRNGKey(0),
8686
update_batch_norm=False)
8787

88-
out_diff(
88+
ModelDiffRunner(
8989
jax_workload=jax_workload,
9090
pytorch_workload=pytorch_workload,
9191
jax_model_kwargs=jax_model_kwargs,
9292
pytorch_model_kwargs=pytorch_model_kwargs,
9393
key_transform=key_transform,
9494
sd_transform=sd_transform,
95-
out_transform=None)
95+
out_transform=None).run()

tests/modeldiffs/criteo1tb_resnet/compare.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
Criteo1TbDlrmSmallResNetWorkload as JaxWorkload
1414
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
1515
Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload
16-
from tests.modeldiffs.diff import out_diff
16+
from tests.modeldiffs.diff import ModelDiffRunner
1717

1818

1919
def key_transform(k):
@@ -64,7 +64,7 @@ def sd_transform(sd):
6464
jax_workload = JaxWorkload()
6565
pytorch_workload = PyTorchWorkload()
6666

67-
pyt_batch = {
67+
pytorch_batch = {
6868
'inputs': torch.ones((2, 13 + 26)),
6969
'targets': torch.randint(low=0, high=1, size=(2,)),
7070
'weights': torch.ones(2),
@@ -75,12 +75,12 @@ def sd_transform(sd):
7575
input_size = 13 + num_categorical_features
7676
input_shape = (init_fake_batch_size, input_size)
7777
fake_inputs = jnp.ones(input_shape, jnp.float32)
78-
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
78+
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
7979
jax_batch['inputs'] = fake_inputs
8080

8181
# Test outputs for identical weights and inputs.
8282
pytorch_model_kwargs = dict(
83-
augmented_and_preprocessed_input_batch=pyt_batch,
83+
augmented_and_preprocessed_input_batch=pytorch_batch,
8484
model_state=None,
8585
mode=spec.ForwardPassMode.EVAL,
8686
rng=None,
@@ -92,11 +92,11 @@ def sd_transform(sd):
9292
rng=jax.random.PRNGKey(0),
9393
update_batch_norm=False)
9494

95-
out_diff(
95+
ModelDiffRunner(
9696
jax_workload=jax_workload,
9797
pytorch_workload=pytorch_workload,
9898
jax_model_kwargs=jax_model_kwargs,
9999
pytorch_model_kwargs=pytorch_model_kwargs,
100100
key_transform=key_transform,
101101
sd_transform=sd_transform,
102-
out_transform=None)
102+
out_transform=None).run()

tests/modeldiffs/diff.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,45 @@ def out_diff(jax_workload,
6464

6565
print(f'Max fprop difference between jax and pytorch: {max_diff}')
6666
print(f'Min fprop difference between jax and pytorch: {min_diff}')
67+
68+
69+
class ModelDiffRunner:
70+
71+
def __init__(self,
72+
jax_workload,
73+
pytorch_workload,
74+
jax_model_kwargs,
75+
pytorch_model_kwargs,
76+
key_transform=None,
77+
sd_transform=None,
78+
out_transform=None) -> None:
79+
"""
80+
Initializes the instance based on diffing logic.
81+
82+
Args:
83+
jax_workload: Workload implementation using JAX.
84+
pytorch_workload: Workload implementation using PyTorch.
85+
jax_model_kwargs: Arguments to be used for model_fn in jax workload.
86+
pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch
87+
workload.
88+
key_transform: Transformation function for keys.
89+
sd_transform: Transformation function for State Dictionary.
90+
out_transform: Transformation function for the output.
91+
"""
92+
93+
self.jax_workload = jax_workload
94+
self.pytorch_workload = pytorch_workload
95+
self.jax_model_kwargs = jax_model_kwargs
96+
self.pytorch_model_kwargs = pytorch_model_kwargs
97+
self.key_transform = key_transform
98+
self.sd_transform = sd_transform
99+
self.out_transform = out_transform
100+
101+
def run(self):
102+
out_diff(self.jax_workload,
103+
self.pytorch_workload,
104+
self.jax_model_kwargs,
105+
self.pytorch_model_kwargs,
106+
self.key_transform,
107+
self.sd_transform,
108+
self.out_transform)

tests/modeldiffs/fastmri/compare.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FastMRIWorkload as JaxWorkload
1212
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
1313
FastMRIWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515

1616

1717
def sd_transform(sd):
@@ -61,10 +61,10 @@ def sort_key(k):
6161
image = torch.randn(2, 320, 320)
6262

6363
jax_batch = {'inputs': image.detach().numpy()}
64-
pyt_batch = {'inputs': image}
64+
pytorch_batch = {'inputs': image}
6565

6666
pytorch_model_kwargs = dict(
67-
augmented_and_preprocessed_input_batch=pyt_batch,
67+
augmented_and_preprocessed_input_batch=pytorch_batch,
6868
model_state=None,
6969
mode=spec.ForwardPassMode.EVAL,
7070
rng=None,
@@ -76,11 +76,10 @@ def sort_key(k):
7676
rng=jax.random.PRNGKey(0),
7777
update_batch_norm=False)
7878

79-
out_diff(
79+
ModelDiffRunner(
8080
jax_workload=jax_workload,
8181
pytorch_workload=pytorch_workload,
8282
jax_model_kwargs=jax_model_kwargs,
8383
pytorch_model_kwargs=pytorch_model_kwargs,
8484
key_transform=None,
85-
sd_transform=sd_transform,
86-
)
85+
sd_transform=sd_transform).run()

tests/modeldiffs/fastmri_layernorm/compare.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FastMRILayerNormWorkload as JaxWorkload
1212
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
1313
FastMRILayerNormWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515

1616

1717
def sd_transform(sd):
@@ -68,10 +68,10 @@ def sort_key(k):
6868
image = torch.randn(2, 320, 320)
6969

7070
jax_batch = {'inputs': image.detach().numpy()}
71-
pyt_batch = {'inputs': image}
71+
pytorch_batch = {'inputs': image}
7272

7373
pytorch_model_kwargs = dict(
74-
augmented_and_preprocessed_input_batch=pyt_batch,
74+
augmented_and_preprocessed_input_batch=pytorch_batch,
7575
model_state=None,
7676
mode=spec.ForwardPassMode.EVAL,
7777
rng=None,
@@ -83,11 +83,10 @@ def sort_key(k):
8383
rng=jax.random.PRNGKey(0),
8484
update_batch_norm=False)
8585

86-
out_diff(
86+
ModelDiffRunner(
8787
jax_workload=jax_workload,
8888
pytorch_workload=pytorch_workload,
8989
jax_model_kwargs=jax_model_kwargs,
9090
pytorch_model_kwargs=pytorch_model_kwargs,
9191
key_transform=None,
92-
sd_transform=sd_transform,
93-
)
92+
sd_transform=sd_transform).run()

tests/modeldiffs/fastmri_model_size/compare.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FastMRIModelSizeWorkload as JaxWorkload
1212
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
1313
FastMRIModelSizeWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515

1616

1717
def sd_transform(sd):
@@ -61,10 +61,10 @@ def sort_key(k):
6161
image = torch.randn(2, 320, 320)
6262

6363
jax_batch = {'inputs': image.detach().numpy()}
64-
pyt_batch = {'inputs': image}
64+
pytorch_batch = {'inputs': image}
6565

6666
pytorch_model_kwargs = dict(
67-
augmented_and_preprocessed_input_batch=pyt_batch,
67+
augmented_and_preprocessed_input_batch=pytorch_batch,
6868
model_state=None,
6969
mode=spec.ForwardPassMode.EVAL,
7070
rng=None,
@@ -76,11 +76,10 @@ def sort_key(k):
7676
rng=jax.random.PRNGKey(0),
7777
update_batch_norm=False)
7878

79-
out_diff(
79+
ModelDiffRunner(
8080
jax_workload=jax_workload,
8181
pytorch_workload=pytorch_workload,
8282
jax_model_kwargs=jax_model_kwargs,
8383
pytorch_model_kwargs=pytorch_model_kwargs,
8484
key_transform=None,
85-
sd_transform=sd_transform,
86-
)
85+
sd_transform=sd_transform).run()

tests/modeldiffs/fastmri_tanh/compare.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FastMRITanhWorkload as JaxWorkload
1212
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
1313
FastMRITanhWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515

1616

1717
def sd_transform(sd):
@@ -61,10 +61,10 @@ def sort_key(k):
6161
image = torch.randn(2, 320, 320)
6262

6363
jax_batch = {'inputs': image.detach().numpy()}
64-
pyt_batch = {'inputs': image}
64+
pytorch_batch = {'inputs': image}
6565

6666
pytorch_model_kwargs = dict(
67-
augmented_and_preprocessed_input_batch=pyt_batch,
67+
augmented_and_preprocessed_input_batch=pytorch_batch,
6868
model_state=None,
6969
mode=spec.ForwardPassMode.EVAL,
7070
rng=None,
@@ -76,11 +76,10 @@ def sort_key(k):
7676
rng=jax.random.PRNGKey(0),
7777
update_batch_norm=False)
7878

79-
out_diff(
79+
ModelDiffRunner(
8080
jax_workload=jax_workload,
8181
pytorch_workload=pytorch_workload,
8282
jax_model_kwargs=jax_model_kwargs,
8383
pytorch_model_kwargs=pytorch_model_kwargs,
8484
key_transform=None,
85-
sd_transform=sd_transform,
86-
)
85+
sd_transform=sd_transform).run()

0 commit comments

Comments
 (0)