Skip to content

Commit 7a20f93

Browse files
committed
Add ModelDiffRunner
These changes were made by SujataSaurabh at PR 820. I cannot use co-author becasuse it will raise the same issue for CLA.
1 parent 801151b commit 7a20f93

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

tests/modeldiffs/criteo1tb/compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Criteo1TbDlrmSmallWorkload as JaxWorkload
1313
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
1414
Criteo1TbDlrmSmallWorkload as PyTorchWorkload
15-
from tests.modeldiffs.diff import out_diff
15+
from tests.modeldiffs.diff import ModelDiffRunner
1616

1717

1818
def key_transform(k):
@@ -74,11 +74,11 @@ def sd_transform(sd):
7474
rng=jax.random.PRNGKey(0),
7575
update_batch_norm=False)
7676

77-
out_diff(
77+
ModelDiffRunner(
7878
jax_workload=jax_workload,
7979
pytorch_workload=pytorch_workload,
8080
jax_model_kwargs=jax_model_kwargs,
8181
pytorch_model_kwargs=pytorch_model_kwargs,
8282
key_transform=key_transform,
8383
sd_transform=sd_transform,
84-
out_transform=None)
84+
out_transform=None).run()

tests/modeldiffs/diff.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,40 @@ def out_diff(jax_workload,
6464

6565
print(f'Max fprop difference between jax and pytorch: {max_diff}')
6666
print(f'Min fprop difference between jax and pytorch: {min_diff}')
67+
68+
69+
class ModelDiffRunner:
70+
def __init__(self, jax_workload,
71+
pytorch_workload,
72+
jax_model_kwargs,
73+
pytorch_model_kwargs,
74+
key_transform=None,
75+
sd_transform=None,
76+
out_transform=None) -> None:
77+
"""Initializes the instance based on diffing logic.
78+
Args:
79+
jax_workload: Workload implementation using JAX
80+
pytorch_workload: Workload implementation using PyTorch
81+
jax_model_kwargs: Arguments to be used for model_fn in jax workload
82+
pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload
83+
key_transform: Transformation function for keys.
84+
sd_transform: Transformation function for State Dictionary.
85+
out_transform: Transformation function for the output.
86+
"""
87+
88+
self.jax_workload = jax_workload
89+
self.pytorch_workload = pytorch_workload
90+
self.jax_model_kwargs = jax_model_kwargs
91+
self.pytorch_model_kwargs = pytorch_model_kwargs
92+
self.key_transform = key_transform
93+
self.sd_transform = sd_transform
94+
self.out_transform = out_transform
95+
96+
def run(self):
97+
out_diff(self.jax_workload,
98+
self.pytorch_workload,
99+
self.jax_model_kwargs,
100+
self.pytorch_model_kwargs,
101+
self.key_transform,
102+
self.sd_transform,
103+
self.out_transform)

0 commit comments

Comments
 (0)