Skip to content

Commit e5b7930

Browse files
committed
call out_diff through ModelDiffRunner().run()
1 parent d96da4a commit e5b7930

File tree

30 files changed

+90
-101
lines changed

30 files changed

+90
-101
lines changed

tests/modeldiffs/criteo1tb_embed_init/compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload
1313
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
1414
Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload
15-
from tests.modeldiffs.diff import out_diff
15+
from tests.modeldiffs.diff import ModelDiffRunner
1616

1717

1818
def key_transform(k):
@@ -73,11 +73,11 @@ def sd_transform(sd):
7373
rng=jax.random.PRNGKey(0),
7474
update_batch_norm=False)
7575

76-
out_diff(
76+
ModelDiffRunner(
7777
jax_workload=jax_workload,
7878
pytorch_workload=pytorch_workload,
7979
jax_model_kwargs=jax_model_kwargs,
8080
pytorch_model_kwargs=pytorch_model_kwargs,
8181
key_transform=key_transform,
8282
sd_transform=sd_transform,
83-
out_transform=None)
83+
out_transform=None).run()

tests/modeldiffs/criteo1tb_layernorm/compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload
1313
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
1414
Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload
15-
from tests.modeldiffs.diff import out_diff
15+
from tests.modeldiffs.diff import ModelDiffRunner
1616

1717

1818
def key_transform(k):
@@ -85,11 +85,11 @@ def sd_transform(sd):
8585
rng=jax.random.PRNGKey(0),
8686
update_batch_norm=False)
8787

88-
out_diff(
88+
ModelDiffRunner(
8989
jax_workload=jax_workload,
9090
pytorch_workload=pytorch_workload,
9191
jax_model_kwargs=jax_model_kwargs,
9292
pytorch_model_kwargs=pytorch_model_kwargs,
9393
key_transform=key_transform,
9494
sd_transform=sd_transform,
95-
out_transform=None)
95+
out_transform=None).run()

tests/modeldiffs/criteo1tb_resnet/compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
Criteo1TbDlrmSmallResNetWorkload as JaxWorkload
1414
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
1515
Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload
16-
from tests.modeldiffs.diff import out_diff
16+
from tests.modeldiffs.diff import ModelDiffRunner
1717

1818

1919
def key_transform(k):
@@ -92,11 +92,11 @@ def sd_transform(sd):
9292
rng=jax.random.PRNGKey(0),
9393
update_batch_norm=False)
9494

95-
out_diff(
95+
ModelDiffRunner(
9696
jax_workload=jax_workload,
9797
pytorch_workload=pytorch_workload,
9898
jax_model_kwargs=jax_model_kwargs,
9999
pytorch_model_kwargs=pytorch_model_kwargs,
100100
key_transform=key_transform,
101101
sd_transform=sd_transform,
102-
out_transform=None)
102+
out_transform=None).run()

tests/modeldiffs/fastmri/compare.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FastMRIWorkload as JaxWorkload
1212
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
1313
FastMRIWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515

1616

1717
def sd_transform(sd):
@@ -76,11 +76,10 @@ def sort_key(k):
7676
rng=jax.random.PRNGKey(0),
7777
update_batch_norm=False)
7878

79-
out_diff(
79+
ModelDiffRunner(
8080
jax_workload=jax_workload,
8181
pytorch_workload=pytorch_workload,
8282
jax_model_kwargs=jax_model_kwargs,
8383
pytorch_model_kwargs=pytorch_model_kwargs,
8484
key_transform=None,
85-
sd_transform=sd_transform,
86-
)
85+
sd_transform=sd_transform).run()

tests/modeldiffs/fastmri_layernorm/compare.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FastMRILayerNormWorkload as JaxWorkload
1212
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
1313
FastMRILayerNormWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515

1616

1717
def sd_transform(sd):
@@ -83,11 +83,10 @@ def sort_key(k):
8383
rng=jax.random.PRNGKey(0),
8484
update_batch_norm=False)
8585

86-
out_diff(
86+
ModelDiffRunner(
8787
jax_workload=jax_workload,
8888
pytorch_workload=pytorch_workload,
8989
jax_model_kwargs=jax_model_kwargs,
9090
pytorch_model_kwargs=pytorch_model_kwargs,
9191
key_transform=None,
92-
sd_transform=sd_transform,
93-
)
92+
sd_transform=sd_transform).run()

tests/modeldiffs/fastmri_model_size/compare.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FastMRIModelSizeWorkload as JaxWorkload
1212
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
1313
FastMRIModelSizeWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515

1616

1717
def sd_transform(sd):
@@ -76,11 +76,10 @@ def sort_key(k):
7676
rng=jax.random.PRNGKey(0),
7777
update_batch_norm=False)
7878

79-
out_diff(
79+
ModelDiffRunner(
8080
jax_workload=jax_workload,
8181
pytorch_workload=pytorch_workload,
8282
jax_model_kwargs=jax_model_kwargs,
8383
pytorch_model_kwargs=pytorch_model_kwargs,
8484
key_transform=None,
85-
sd_transform=sd_transform,
86-
)
85+
sd_transform=sd_transform).run()

tests/modeldiffs/fastmri_tanh/compare.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
FastMRITanhWorkload as JaxWorkload
1212
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
1313
FastMRITanhWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515

1616

1717
def sd_transform(sd):
@@ -76,11 +76,10 @@ def sort_key(k):
7676
rng=jax.random.PRNGKey(0),
7777
update_batch_norm=False)
7878

79-
out_diff(
79+
ModelDiffRunner(
8080
jax_workload=jax_workload,
8181
pytorch_workload=pytorch_workload,
8282
jax_model_kwargs=jax_model_kwargs,
8383
pytorch_model_kwargs=pytorch_model_kwargs,
8484
key_transform=None,
85-
sd_transform=sd_transform,
86-
)
85+
sd_transform=sd_transform).run()

tests/modeldiffs/imagenet_resnet/compare.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
ImagenetResNetWorkload as JaxWorkload
1212
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
1313
ImagenetResNetWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515

1616

1717
def key_transform(k):
@@ -93,11 +93,10 @@ def sd_transform(sd):
9393
rng=jax.random.PRNGKey(0),
9494
update_batch_norm=False)
9595

96-
out_diff(
96+
ModelDiffRunner(
9797
jax_workload=jax_workload,
9898
pytorch_workload=pytorch_workload,
9999
jax_model_kwargs=jax_model_kwargs,
100100
pytorch_model_kwargs=pytorch_model_kwargs,
101101
key_transform=key_transform,
102-
sd_transform=sd_transform,
103-
)
102+
sd_transform=sd_transform).run()

tests/modeldiffs/imagenet_resnet/gelu_compare.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
ImagenetResNetGELUWorkload as JaxWorkload
1212
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
1313
ImagenetResNetGELUWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515
from tests.modeldiffs.imagenet_resnet.compare import key_transform
1616
from tests.modeldiffs.imagenet_resnet.compare import sd_transform
1717

@@ -40,11 +40,10 @@
4040
rng=jax.random.PRNGKey(0),
4141
update_batch_norm=False)
4242

43-
out_diff(
43+
ModelDiffRunner(
4444
jax_workload=jax_workload,
4545
pytorch_workload=pytorch_workload,
4646
jax_model_kwargs=jax_model_kwargs,
4747
pytorch_model_kwargs=pytorch_model_kwargs,
4848
key_transform=key_transform,
49-
sd_transform=sd_transform,
50-
)
49+
sd_transform=sd_transform).run()

tests/modeldiffs/imagenet_resnet/silu_compare.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
ImagenetResNetSiLUWorkload as JaxWorkload
1212
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
1313
ImagenetResNetSiLUWorkload as PyTorchWorkload
14-
from tests.modeldiffs.diff import out_diff
14+
from tests.modeldiffs.diff import ModelDiffRunner
1515
from tests.modeldiffs.imagenet_resnet.compare import key_transform
1616
from tests.modeldiffs.imagenet_resnet.compare import sd_transform
1717

@@ -40,11 +40,10 @@
4040
rng=jax.random.PRNGKey(0),
4141
update_batch_norm=False)
4242

43-
out_diff(
43+
ModelDiffRunner(
4444
jax_workload=jax_workload,
4545
pytorch_workload=pytorch_workload,
4646
jax_model_kwargs=jax_model_kwargs,
4747
pytorch_model_kwargs=pytorch_model_kwargs,
4848
key_transform=key_transform,
49-
sd_transform=sd_transform,
50-
)
49+
sd_transform=sd_transform).run()

0 commit comments

Comments
 (0)