Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/modeldiffs/criteo1tb/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Criteo1TbDlrmSmallWorkload as JaxWorkload
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner


def key_transform(k):
Expand Down Expand Up @@ -53,16 +53,16 @@ def sd_transform(sd):
jax_workload = JaxWorkload()
pytorch_workload = PyTorchWorkload()

pyt_batch = {
pytorch_batch = {
'inputs': torch.ones((2, 13 + 26)),
'targets': torch.randint(low=0, high=1, size=(2,)),
'weights': torch.ones(2),
}
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}

# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
augmented_and_preprocessed_input_batch=pyt_batch,
augmented_and_preprocessed_input_batch=pytorch_batch,
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
Expand All @@ -74,11 +74,11 @@ def sd_transform(sd):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=key_transform,
sd_transform=sd_transform,
out_transform=None)
out_transform=None).run()
12 changes: 6 additions & 6 deletions tests/modeldiffs/criteo1tb_embed_init/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner


def key_transform(k):
Expand Down Expand Up @@ -52,16 +52,16 @@ def sd_transform(sd):
jax_workload = JaxWorkload()
pytorch_workload = PyTorchWorkload()

pyt_batch = {
pytorch_batch = {
'inputs': torch.ones((2, 13 + 26)),
'targets': torch.randint(low=0, high=1, size=(2,)),
'weights': torch.ones(2),
}
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}

# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
augmented_and_preprocessed_input_batch=pyt_batch,
augmented_and_preprocessed_input_batch=pytorch_batch,
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
Expand All @@ -73,11 +73,11 @@ def sd_transform(sd):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=key_transform,
sd_transform=sd_transform,
out_transform=None)
out_transform=None).run()
12 changes: 6 additions & 6 deletions tests/modeldiffs/criteo1tb_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner


def key_transform(k):
Expand Down Expand Up @@ -64,16 +64,16 @@ def sd_transform(sd):
jax_workload = JaxWorkload()
pytorch_workload = PyTorchWorkload()

pyt_batch = {
pytorch_batch = {
'inputs': torch.ones((2, 13 + 26)),
'targets': torch.randint(low=0, high=1, size=(2,)),
'weights': torch.ones(2),
}
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}

# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
augmented_and_preprocessed_input_batch=pyt_batch,
augmented_and_preprocessed_input_batch=pytorch_batch,
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
Expand All @@ -85,11 +85,11 @@ def sd_transform(sd):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=key_transform,
sd_transform=sd_transform,
out_transform=None)
out_transform=None).run()
12 changes: 6 additions & 6 deletions tests/modeldiffs/criteo1tb_resnet/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Criteo1TbDlrmSmallResNetWorkload as JaxWorkload
from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner


def key_transform(k):
Expand Down Expand Up @@ -64,7 +64,7 @@ def sd_transform(sd):
jax_workload = JaxWorkload()
pytorch_workload = PyTorchWorkload()

pyt_batch = {
pytorch_batch = {
'inputs': torch.ones((2, 13 + 26)),
'targets': torch.randint(low=0, high=1, size=(2,)),
'weights': torch.ones(2),
Expand All @@ -75,12 +75,12 @@ def sd_transform(sd):
input_size = 13 + num_categorical_features
input_shape = (init_fake_batch_size, input_size)
fake_inputs = jnp.ones(input_shape, jnp.float32)
jax_batch = {k: np.array(v) for k, v in pyt_batch.items()}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
jax_batch['inputs'] = fake_inputs

# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
augmented_and_preprocessed_input_batch=pyt_batch,
augmented_and_preprocessed_input_batch=pytorch_batch,
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
Expand All @@ -92,11 +92,11 @@ def sd_transform(sd):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=key_transform,
sd_transform=sd_transform,
out_transform=None)
out_transform=None).run()
42 changes: 42 additions & 0 deletions tests/modeldiffs/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,45 @@ def out_diff(jax_workload,

print(f'Max fprop difference between jax and pytorch: {max_diff}')
print(f'Min fprop difference between jax and pytorch: {min_diff}')


class ModelDiffRunner:

def __init__(self,
jax_workload,
pytorch_workload,
jax_model_kwargs,
pytorch_model_kwargs,
key_transform=None,
sd_transform=None,
out_transform=None) -> None:
"""
Initializes the instance based on diffing logic.

Args:
jax_workload: Workload implementation using JAX.
pytorch_workload: Workload implementation using PyTorch.
jax_model_kwargs: Arguments to be used for model_fn in jax workload.
pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch
workload.
key_transform: Transformation function for keys.
sd_transform: Transformation function for State Dictionary.
out_transform: Transformation function for the output.
"""

self.jax_workload = jax_workload
self.pytorch_workload = pytorch_workload
self.jax_model_kwargs = jax_model_kwargs
self.pytorch_model_kwargs = pytorch_model_kwargs
self.key_transform = key_transform
self.sd_transform = sd_transform
self.out_transform = out_transform

def run(self):
out_diff(self.jax_workload,
self.pytorch_workload,
self.jax_model_kwargs,
self.pytorch_model_kwargs,
self.key_transform,
self.sd_transform,
self.out_transform)
11 changes: 5 additions & 6 deletions tests/modeldiffs/fastmri/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
FastMRIWorkload as JaxWorkload
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
FastMRIWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner


def sd_transform(sd):
Expand Down Expand Up @@ -61,10 +61,10 @@ def sort_key(k):
image = torch.randn(2, 320, 320)

jax_batch = {'inputs': image.detach().numpy()}
pyt_batch = {'inputs': image}
pytorch_batch = {'inputs': image}

pytorch_model_kwargs = dict(
augmented_and_preprocessed_input_batch=pyt_batch,
augmented_and_preprocessed_input_batch=pytorch_batch,
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
Expand All @@ -76,11 +76,10 @@ def sort_key(k):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=None,
sd_transform=sd_transform,
)
sd_transform=sd_transform).run()
11 changes: 5 additions & 6 deletions tests/modeldiffs/fastmri_layernorm/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
FastMRILayerNormWorkload as JaxWorkload
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
FastMRILayerNormWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner


def sd_transform(sd):
Expand Down Expand Up @@ -68,10 +68,10 @@ def sort_key(k):
image = torch.randn(2, 320, 320)

jax_batch = {'inputs': image.detach().numpy()}
pyt_batch = {'inputs': image}
pytorch_batch = {'inputs': image}

pytorch_model_kwargs = dict(
augmented_and_preprocessed_input_batch=pyt_batch,
augmented_and_preprocessed_input_batch=pytorch_batch,
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
Expand All @@ -83,11 +83,10 @@ def sort_key(k):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=None,
sd_transform=sd_transform,
)
sd_transform=sd_transform).run()
11 changes: 5 additions & 6 deletions tests/modeldiffs/fastmri_model_size/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
FastMRIModelSizeWorkload as JaxWorkload
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
FastMRIModelSizeWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner


def sd_transform(sd):
Expand Down Expand Up @@ -61,10 +61,10 @@ def sort_key(k):
image = torch.randn(2, 320, 320)

jax_batch = {'inputs': image.detach().numpy()}
pyt_batch = {'inputs': image}
pytorch_batch = {'inputs': image}

pytorch_model_kwargs = dict(
augmented_and_preprocessed_input_batch=pyt_batch,
augmented_and_preprocessed_input_batch=pytorch_batch,
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
Expand All @@ -76,11 +76,10 @@ def sort_key(k):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=None,
sd_transform=sd_transform,
)
sd_transform=sd_transform).run()
11 changes: 5 additions & 6 deletions tests/modeldiffs/fastmri_tanh/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
FastMRITanhWorkload as JaxWorkload
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
FastMRITanhWorkload as PyTorchWorkload
from tests.modeldiffs.diff import out_diff
from tests.modeldiffs.diff import ModelDiffRunner


def sd_transform(sd):
Expand Down Expand Up @@ -61,10 +61,10 @@ def sort_key(k):
image = torch.randn(2, 320, 320)

jax_batch = {'inputs': image.detach().numpy()}
pyt_batch = {'inputs': image}
pytorch_batch = {'inputs': image}

pytorch_model_kwargs = dict(
augmented_and_preprocessed_input_batch=pyt_batch,
augmented_and_preprocessed_input_batch=pytorch_batch,
model_state=None,
mode=spec.ForwardPassMode.EVAL,
rng=None,
Expand All @@ -76,11 +76,10 @@ def sort_key(k):
rng=jax.random.PRNGKey(0),
update_batch_norm=False)

out_diff(
ModelDiffRunner(
jax_workload=jax_workload,
pytorch_workload=pytorch_workload,
jax_model_kwargs=jax_model_kwargs,
pytorch_model_kwargs=pytorch_model_kwargs,
key_transform=None,
sd_transform=sd_transform,
)
sd_transform=sd_transform).run()
Loading