From 7a20f936515bffcbfbccf273cf4042c22ba7fe5a Mon Sep 17 00:00:00 2001 From: naeemkh Date: Thu, 27 Mar 2025 16:23:19 +0000 Subject: [PATCH 1/4] Add ModelDiffRunner These changes were made by SujataSaurabh in PR 820. I cannot use co-author because it would raise the same CLA issue. --- tests/modeldiffs/criteo1tb/compare.py | 6 ++--- tests/modeldiffs/diff.py | 37 +++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index d280803af..3fe7ae85e 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -12,7 +12,7 @@ Criteo1TbDlrmSmallWorkload as JaxWorkload from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -74,11 +74,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index 9253d2633..a60d3aed1 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -64,3 +64,40 @@ def out_diff(jax_workload, print(f'Max fprop difference between jax and pytorch: {max_diff}') print(f'Min fprop difference between jax and pytorch: {min_diff}') + + +class ModelDiffRunner: + def __init__(self, jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None) -> None: + """Initializes the instance based on diffing logic. + Args: + jax_workload: Workload implementation using JAX + pytorch_workload: Workload implementation using PyTorch + jax_model_kwargs: Arguments to be used for model_fn in jax workload + pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload + key_transform: Transformation function for keys. + sd_transform: Transformation function for State Dictionary. + out_transform: Transformation function for the output.
+ """ + + self.jax_workload = jax_workload + self.pytorch_workload = pytorch_workload + self.jax_model_kwargs = jax_model_kwargs + self.pytorch_model_kwargs = pytorch_model_kwargs + self.key_transform = key_transform + self.sd_transform = sd_transform + self.out_transform = out_transform + + def run(self): + out_diff(self.jax_workload, + self.pytorch_workload, + self.jax_model_kwargs, + self.pytorch_model_kwargs, + self.key_transform, + self.sd_transform, + self.out_transform) \ No newline at end of file From d96da4a4bfa12a33b2101fa6c80191a08ed46ae9 Mon Sep 17 00:00:00 2001 From: naeemkh Date: Thu, 27 Mar 2025 23:50:47 +0000 Subject: [PATCH 2/4] fix style --- tests/modeldiffs/diff.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py index a60d3aed1..52241fd3a 100644 --- a/tests/modeldiffs/diff.py +++ b/tests/modeldiffs/diff.py @@ -67,19 +67,24 @@ def out_diff(jax_workload, class ModelDiffRunner: - def __init__(self, jax_workload, - pytorch_workload, - jax_model_kwargs, - pytorch_model_kwargs, - key_transform=None, - sd_transform=None, - out_transform=None) -> None: - """Initializes the instance based on diffing logic. + + def __init__(self, + jax_workload, + pytorch_workload, + jax_model_kwargs, + pytorch_model_kwargs, + key_transform=None, + sd_transform=None, + out_transform=None) -> None: + """ + Initializes the instance based on diffing logic. + Args: - jax_workload: Workload implementation using JAX - pytorch_workload: Workload implementation using PyTorch - jax_model_kwargs: Arguments to be used for model_fn in jax workload - pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload + jax_workload: Workload implementation using JAX. + pytorch_workload: Workload implementation using PyTorch. + jax_model_kwargs: Arguments to be used for model_fn in jax workload. + pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch + workload. key_transform: Transformation function for keys. sd_transform: Transformation function for State Dictionary. out_transform: Transformation function for the output. 
@@ -100,4 +105,4 @@ def run(self): self.pytorch_model_kwargs, self.key_transform, self.sd_transform, - self.out_transform) \ No newline at end of file + self.out_transform) From e5b7930ef2b138d67554b8f0fef33b0aa4ea23fb Mon Sep 17 00:00:00 2001 From: naeemkh Date: Fri, 28 Mar 2025 00:35:50 +0000 Subject: [PATCH 3/4] call out_diff through ModelDiffRunner().run() --- tests/modeldiffs/criteo1tb_embed_init/compare.py | 6 +++--- tests/modeldiffs/criteo1tb_layernorm/compare.py | 6 +++--- tests/modeldiffs/criteo1tb_resnet/compare.py | 6 +++--- tests/modeldiffs/fastmri/compare.py | 7 +++---- tests/modeldiffs/fastmri_layernorm/compare.py | 7 +++---- tests/modeldiffs/fastmri_model_size/compare.py | 7 +++---- tests/modeldiffs/fastmri_tanh/compare.py | 7 +++---- tests/modeldiffs/imagenet_resnet/compare.py | 7 +++---- tests/modeldiffs/imagenet_resnet/gelu_compare.py | 7 +++---- tests/modeldiffs/imagenet_resnet/silu_compare.py | 7 +++---- tests/modeldiffs/imagenet_vit/compare.py | 7 +++---- tests/modeldiffs/imagenet_vit_glu/compare.py | 7 +++---- tests/modeldiffs/imagenet_vit_map/compare.py | 7 +++---- tests/modeldiffs/imagenet_vit_postln/compare.py | 7 +++---- tests/modeldiffs/librispeech_conformer/compare.py | 6 +++--- .../librispeech_conformer_attention_temperature/compare.py | 6 +++--- tests/modeldiffs/librispeech_conformer_gelu/compare.py | 6 +++--- .../modeldiffs/librispeech_conformer_layernorm/compare.py | 6 +++--- tests/modeldiffs/librispeech_deepspeech/compare.py | 6 +++--- .../modeldiffs/librispeech_deepspeech_noresnet/compare.py | 6 +++--- tests/modeldiffs/librispeech_deepspeech_normaug/compare.py | 6 +++--- tests/modeldiffs/librispeech_deepspeech_tanh/compare.py | 6 +++--- tests/modeldiffs/ogbg/compare.py | 6 +++--- tests/modeldiffs/ogbg_gelu/compare.py | 6 +++--- tests/modeldiffs/ogbg_model_size/compare.py | 6 +++--- tests/modeldiffs/ogbg_silu/compare.py | 6 +++--- tests/modeldiffs/wmt/compare.py | 6 +++--- tests/modeldiffs/wmt_attention_temp/compare.py | 6 +++--- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 +++--- tests/modeldiffs/wmt_post_ln/compare.py | 6 +++--- 30 files changed, 90 insertions(+), 101 deletions(-) diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index 73744c667..986dbb520 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -12,7 +12,7 @@ Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -73,11 +73,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 96e3cc5cc..577243c6d 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -12,7 +12,7 @@ Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload -from 
tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -85,11 +85,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 188e4cac3..25164d942 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -13,7 +13,7 @@ Criteo1TbDlrmSmallResNetWorkload as JaxWorkload from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -92,11 +92,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index da5f0ba0a..8e73a24ad 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -11,7 +11,7 @@ FastMRIWorkload as JaxWorkload from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRIWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): @@ -76,11 +76,10 @@ def sort_key(k): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=None, - sd_transform=sd_transform, - ) + sd_transform=sd_transform).run() diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 5f1eb1842..b67219a71 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -11,7 +11,7 @@ FastMRILayerNormWorkload as JaxWorkload from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRILayerNormWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): @@ -83,11 +83,10 @@ def sort_key(k): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=None, - sd_transform=sd_transform, - ) + sd_transform=sd_transform).run() diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index ebb8669f8..d20b1b9ae 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -11,7 +11,7 @@ FastMRIModelSizeWorkload as JaxWorkload from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRIModelSizeWorkload as PyTorchWorkload -from tests.modeldiffs.diff import 
out_diff +from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): @@ -76,11 +76,10 @@ def sort_key(k): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=None, - sd_transform=sd_transform, - ) + sd_transform=sd_transform).run() diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 558bc2ba1..ead07ba94 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -11,7 +11,7 @@ FastMRITanhWorkload as JaxWorkload from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRITanhWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def sd_transform(sd): @@ -76,11 +76,10 @@ def sort_key(k): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=None, - sd_transform=sd_transform, - ) + sd_transform=sd_transform).run() diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index 0a6a1b7c5..933aed1e5 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -11,7 +11,7 @@ ImagenetResNetWorkload as JaxWorkload from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -93,11 +93,10 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, - sd_transform=sd_transform, - ) + sd_transform=sd_transform).run() diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 4f20873b7..4387b0c4c 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -11,7 +11,7 @@ ImagenetResNetGELUWorkload as JaxWorkload from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetGELUWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -40,11 +40,10 @@ rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, - sd_transform=sd_transform, - ) + sd_transform=sd_transform).run() diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index e94fdcd4c..8ca9408bc 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -11,7 +11,7 @@ ImagenetResNetSiLUWorkload as JaxWorkload from 
algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetSiLUWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.imagenet_resnet.compare import key_transform from tests.modeldiffs.imagenet_resnet.compare import sd_transform @@ -40,11 +40,10 @@ rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, - sd_transform=sd_transform, - ) + sd_transform=sd_transform).run() diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index b7b9af794..f42ec3172 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -11,7 +11,7 @@ ImagenetVitWorkload as JaxWorkload from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -106,11 +106,10 @@ def key_transform(k): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, - sd_transform=None, - ) + sd_transform=None).run() diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py index 11edcd84e..fb73d51ae 100644 --- a/tests/modeldiffs/imagenet_vit_glu/compare.py +++ b/tests/modeldiffs/imagenet_vit_glu/compare.py @@ -1,6 +1,6 @@ import os -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.imagenet_vit.compare import key_transform # Disable GPU access for both jax and pytorch. @@ -42,11 +42,10 @@ rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, - sd_transform=None, - ) + sd_transform=None).run() diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py index 70bcd2e04..8b6f9928d 100644 --- a/tests/modeldiffs/imagenet_vit_map/compare.py +++ b/tests/modeldiffs/imagenet_vit_map/compare.py @@ -1,6 +1,6 @@ import os -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.imagenet_vit.compare import key_transform # Disable GPU access for both jax and pytorch. 
@@ -53,11 +53,10 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, - sd_transform=sd_transform, - ) + sd_transform=sd_transform).run() diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py index 113a65a2a..2b7f93b09 100644 --- a/tests/modeldiffs/imagenet_vit_postln/compare.py +++ b/tests/modeldiffs/imagenet_vit_postln/compare.py @@ -1,6 +1,6 @@ import os -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.imagenet_vit.compare import key_transform # Disable GPU access for both jax and pytorch. @@ -42,11 +42,10 @@ rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, - sd_transform=None, - ) + sd_transform=None).run() diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index 5bfbf915a..a6a1a78c0 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -11,7 +11,7 @@ LibriSpeechConformerWorkload as JaxWorkload from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -81,7 +81,7 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, @@ -89,4 +89,4 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])) + (1 - out_outpad[1][:, :, None])).run() diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index bb9a8fae1..bc56ca623 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -11,7 +11,7 @@ LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -81,7 +81,7 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, @@ -89,4 +89,4 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])) + (1 - out_outpad[1][:, :, None])).run() diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index 629418488..bc2ebaa3b 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ 
b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -11,7 +11,7 @@ LibriSpeechConformerGeluWorkload as JaxWorkload from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerGeluWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -81,7 +81,7 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, @@ -89,4 +89,4 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])) + (1 - out_outpad[1][:, :, None])).run() diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index 48fe991f7..a17b818a5 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ -11,7 +11,7 @@ LibriSpeechConformerLayerNormWorkload as JaxWorkload from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerLayerNormWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -81,7 +81,7 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, @@ -89,4 +89,4 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])) + (1 - out_outpad[1][:, :, None])).run() diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index 81e12b15d..92e5b67d0 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -11,7 +11,7 @@ LibriSpeechDeepSpeechWorkload as JaxWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -106,7 +106,7 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, @@ -114,4 +114,4 @@ def sd_transform(sd): key_transform=key_transform, sd_transform=sd_transform, out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])) + (1 - out_outpad[1][:, :, None])).run() diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py index ea106ebe4..ec0dde49e 100644 --- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -11,7 +11,7 @@ LibriSpeechDeepSpeechTanhWorkload as JaxWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner from 
tests.modeldiffs.librispeech_deepspeech.compare import key_transform from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform @@ -42,7 +42,7 @@ rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, @@ -50,4 +50,4 @@ key_transform=key_transform, sd_transform=sd_transform, out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])) + (1 - out_outpad[1][:, :, None])).run() diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py index ecb6d28af..a3231b2f9 100644 --- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -11,7 +11,7 @@ LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.librispeech_deepspeech.compare import key_transform from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform @@ -42,7 +42,7 @@ rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, @@ -50,4 +50,4 @@ key_transform=key_transform, sd_transform=sd_transform, out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])) + (1 - out_outpad[1][:, :, None])).run() diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py index 31d9029b4..518c8c9e0 100644 --- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -11,7 +11,7 @@ LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner from tests.modeldiffs.librispeech_deepspeech.compare import key_transform from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform @@ -42,7 +42,7 @@ rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, @@ -50,4 +50,4 @@ key_transform=key_transform, sd_transform=sd_transform, out_transform=lambda out_outpad: out_outpad[0] * - (1 - out_outpad[1][:, :, None])) + (1 - out_outpad[1][:, :, None])).run() diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 43ca48764..e41f032c6 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -13,7 +13,7 @@ OgbgWorkload as JaxWorkload from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) @@ -110,11 +110,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, 
pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index 062588fe2..bbcc49255 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -13,7 +13,7 @@ OgbgGeluWorkload as JaxWorkload from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgGeluWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) @@ -110,11 +110,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 2eb70d097..afb8e5728 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -13,7 +13,7 @@ OgbgModelSizeWorkload as JaxWorkload from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgModelSizeWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) @@ -110,11 +110,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 19e446030..21ab631d4 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -13,7 +13,7 @@ OgbgSiluWorkload as JaxWorkload from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgSiluWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner # Todo: refactor tests to use workload properties in cleaner way hidden_dims = len(JaxWorkload().hidden_dims) @@ -110,11 +110,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 64401ef7f..2e01ce276 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -10,7 +10,7 @@ from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWorkload from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkload as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def 
key_transform(k): @@ -130,11 +130,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 01dc2895c..61a807452 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -11,7 +11,7 @@ WmtWorkloadAttentionTemp as JaxWorkload from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkloadAttentionTemp as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -131,11 +131,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index 77e71c826..1e335a81e 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -11,7 +11,7 @@ WmtWorkloadGLUTanH as JaxWorkload from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkloadGLUTanH as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -131,11 +131,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 909fcd672..34ce37c83 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -11,7 +11,7 @@ WmtWorkloadPostLN as JaxWorkload from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkloadPostLN as PyTorchWorkload -from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.diff import ModelDiffRunner def key_transform(k): @@ -131,11 +131,11 @@ def sd_transform(sd): rng=jax.random.PRNGKey(0), update_batch_norm=False) - out_diff( + ModelDiffRunner( jax_workload=jax_workload, pytorch_workload=pytorch_workload, jax_model_kwargs=jax_model_kwargs, pytorch_model_kwargs=pytorch_model_kwargs, key_transform=key_transform, sd_transform=sd_transform, - out_transform=None) + out_transform=None).run() From 5efdfad14058c8533c96f197da01a669e88e9d6b Mon Sep 17 00:00:00 2001 From: naeemkh Date: Fri, 28 Mar 2025 01:29:07 +0000 Subject: [PATCH 4/4] change var name to enhance readability pyt --> pytorch k --> eval_metric_key --- tests/modeldiffs/criteo1tb/compare.py | 6 ++-- .../criteo1tb_embed_init/compare.py | 6 ++-- .../modeldiffs/criteo1tb_layernorm/compare.py | 6 ++-- tests/modeldiffs/criteo1tb_resnet/compare.py | 6 ++-- tests/modeldiffs/fastmri/compare.py | 4 +-- tests/modeldiffs/fastmri_layernorm/compare.py | 4 +--
.../modeldiffs/fastmri_model_size/compare.py | 4 +-- tests/modeldiffs/fastmri_tanh/compare.py | 4 +-- tests/modeldiffs/imagenet_resnet/compare.py | 4 +-- .../imagenet_resnet/gelu_compare.py | 4 +-- .../imagenet_resnet/silu_compare.py | 4 +-- tests/modeldiffs/imagenet_vit/compare.py | 4 +-- tests/modeldiffs/imagenet_vit_glu/compare.py | 4 +-- tests/modeldiffs/imagenet_vit_map/compare.py | 4 +-- .../modeldiffs/imagenet_vit_postln/compare.py | 4 +-- .../librispeech_conformer/compare.py | 4 +-- .../compare.py | 4 +-- .../librispeech_conformer_gelu/compare.py | 4 +-- .../compare.py | 4 +-- .../librispeech_deepspeech/compare.py | 4 +-- .../compare.py | 4 +-- .../librispeech_deepspeech_normaug/compare.py | 4 +-- .../librispeech_deepspeech_tanh/compare.py | 4 +-- tests/modeldiffs/ogbg/compare.py | 10 +++--- tests/modeldiffs/ogbg_gelu/compare.py | 10 +++--- tests/modeldiffs/ogbg_model_size/compare.py | 10 +++--- tests/modeldiffs/ogbg_silu/compare.py | 10 +++--- tests/modeldiffs/wmt/compare.py | 4 +-- .../modeldiffs/wmt_attention_temp/compare.py | 4 +-- tests/modeldiffs/wmt_glu_tanh/compare.py | 4 +-- tests/modeldiffs/wmt_post_ln/compare.py | 4 +-- tests/test_traindiffs.py | 32 +++++++++---------- 32 files changed, 94 insertions(+), 94 deletions(-) diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index 3fe7ae85e..9de61a2a5 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -53,16 +53,16 @@ def sd_transform(sd): jax_workload = JaxWorkload() pytorch_workload = PyTorchWorkload() - pyt_batch = { + pytorch_batch = { 'inputs': torch.ones((2, 13 + 26)), 'targets': torch.randint(low=0, high=1, size=(2,)), 'weights': torch.ones(2), } - jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index 986dbb520..f1897d16f 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -52,16 +52,16 @@ def sd_transform(sd): jax_workload = JaxWorkload() pytorch_workload = PyTorchWorkload() - pyt_batch = { + pytorch_batch = { 'inputs': torch.ones((2, 13 + 26)), 'targets': torch.randint(low=0, high=1, size=(2,)), 'weights': torch.ones(2), } - jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. 
pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 577243c6d..5aad3cc67 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -64,16 +64,16 @@ def sd_transform(sd): jax_workload = JaxWorkload() pytorch_workload = PyTorchWorkload() - pyt_batch = { + pytorch_batch = { 'inputs': torch.ones((2, 13 + 26)), 'targets': torch.randint(low=0, high=1, size=(2,)), 'weights': torch.ones(2), } - jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 25164d942..169b1cdf4 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -64,7 +64,7 @@ def sd_transform(sd): jax_workload = JaxWorkload() pytorch_workload = PyTorchWorkload() - pyt_batch = { + pytorch_batch = { 'inputs': torch.ones((2, 13 + 26)), 'targets': torch.randint(low=0, high=1, size=(2,)), 'weights': torch.ones(2), @@ -75,12 +75,12 @@ def sd_transform(sd): input_size = 13 + num_categorical_features input_shape = (init_fake_batch_size, input_size) fake_inputs = jnp.ones(input_shape, jnp.float32) - jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} jax_batch['inputs'] = fake_inputs # Test outputs for identical weights and inputs. 
pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index 8e73a24ad..c1a349cec 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -61,10 +61,10 @@ def sort_key(k): image = torch.randn(2, 320, 320) jax_batch = {'inputs': image.detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index b67219a71..f26ad185e 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -68,10 +68,10 @@ def sort_key(k): image = torch.randn(2, 320, 320) jax_batch = {'inputs': image.detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index d20b1b9ae..42789539b 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -61,10 +61,10 @@ def sort_key(k): image = torch.randn(2, 320, 320) jax_batch = {'inputs': image.detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index ead07ba94..13ecb890c 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -61,10 +61,10 @@ def sort_key(k): image = torch.randn(2, 320, 320) jax_batch = {'inputs': image.detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index 933aed1e5..59ab45555 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -78,10 +78,10 @@ def sd_transform(sd): image = torch.randn(2, 3, 224, 224) jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 4387b0c4c..07510ad70 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py 
@@ -25,10 +25,10 @@ image = torch.randn(2, 3, 224, 224) jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index 8ca9408bc..8246d17a2 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -25,10 +25,10 @@ image = torch.randn(2, 3, 224, 224) jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index f42ec3172..b4ca7d8ec 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -91,10 +91,10 @@ def key_transform(k): image = torch.randn(2, 3, 224, 224) jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py index fb73d51ae..c152410b5 100644 --- a/tests/modeldiffs/imagenet_vit_glu/compare.py +++ b/tests/modeldiffs/imagenet_vit_glu/compare.py @@ -27,10 +27,10 @@ image = torch.randn(2, 3, 224, 224) jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py index 8b6f9928d..7f1af41ab 100644 --- a/tests/modeldiffs/imagenet_vit_map/compare.py +++ b/tests/modeldiffs/imagenet_vit_map/compare.py @@ -38,10 +38,10 @@ def sd_transform(sd): image = torch.randn(2, 3, 224, 224) jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py index 2b7f93b09..a3a639101 100644 --- a/tests/modeldiffs/imagenet_vit_postln/compare.py +++ b/tests/modeldiffs/imagenet_vit_postln/compare.py @@ -27,10 +27,10 @@ image = torch.randn(2, 3, 224, 224) jax_batch = {'inputs': image.permute(0, 2, 3, 1).detach().numpy()} - pyt_batch = {'inputs': image} + pytorch_batch = {'inputs': image} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, 
mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index a6a1a78c0..664b1242d 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -66,10 +66,10 @@ def sd_transform(sd): pad[0, 200000:] = 1 jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} - pyt_batch = {'inputs': (wave, pad)} + pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index bc56ca623..b0812e77d 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -66,10 +66,10 @@ def sd_transform(sd): pad[0, 200000:] = 1 jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} - pyt_batch = {'inputs': (wave, pad)} + pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index bc2ebaa3b..3032a0005 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -66,10 +66,10 @@ def sd_transform(sd): pad[0, 200000:] = 1 jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} - pyt_batch = {'inputs': (wave, pad)} + pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index a17b818a5..d623ef352 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ -66,10 +66,10 @@ def sd_transform(sd): pad[0, 200000:] = 1 jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} - pyt_batch = {'inputs': (wave, pad)} + pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index 92e5b67d0..84b0a6c86 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -91,10 +91,10 @@ def sd_transform(sd): pad[0, 200000:] = 1 jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} - pyt_batch = {'inputs': (wave, pad)} + pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git 
a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py index ec0dde49e..2540c1b93 100644 --- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -27,10 +27,10 @@ pad[0, 200000:] = 1 jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} - pyt_batch = {'inputs': (wave, pad)} + pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py index a3231b2f9..e5972120d 100644 --- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -27,10 +27,10 @@ pad[0, 200000:] = 1 jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} - pyt_batch = {'inputs': (wave, pad)} + pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py index 518c8c9e0..4d2c4a5d5 100644 --- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -27,10 +27,10 @@ pad[0, 200000:] = 1 jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} - pyt_batch = {'inputs': (wave, pad)} + pytorch_batch = {'inputs': (wave, pad)} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index e41f032c6..5d5ef50bf 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -79,7 +79,7 @@ def sd_transform(sd): jax_workload = JaxWorkload() pytorch_workload = PyTorchWorkload() - pyt_batch = dict( + pytorch_batch = dict( n_node=torch.LongTensor([5]), n_edge=torch.LongTensor([5]), nodes=torch.randn(5, 9), @@ -88,17 +88,17 @@ def sd_transform(sd): senders=torch.LongTensor(list(range(5))), receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) - jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. 
graph_j = jraph.GraphsTuple(**jax_batch) - graph_p = jraph.GraphsTuple(**pyt_batch) + graph_p = jraph.GraphsTuple(**pytorch_batch) jax_batch = {'inputs': graph_j} - pyt_batch = {'inputs': graph_p} + pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index bbcc49255..fc3992998 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -79,7 +79,7 @@ def sd_transform(sd): jax_workload = JaxWorkload() pytorch_workload = PyTorchWorkload() - pyt_batch = dict( + pytorch_batch = dict( n_node=torch.LongTensor([5]), n_edge=torch.LongTensor([5]), nodes=torch.randn(5, 9), @@ -88,17 +88,17 @@ def sd_transform(sd): senders=torch.LongTensor(list(range(5))), receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) - jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. graph_j = jraph.GraphsTuple(**jax_batch) - graph_p = jraph.GraphsTuple(**pyt_batch) + graph_p = jraph.GraphsTuple(**pytorch_batch) jax_batch = {'inputs': graph_j} - pyt_batch = {'inputs': graph_p} + pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index afb8e5728..e7cfa745c 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -79,7 +79,7 @@ def sd_transform(sd): jax_workload = JaxWorkload() pytorch_workload = PyTorchWorkload() - pyt_batch = dict( + pytorch_batch = dict( n_node=torch.LongTensor([5]), n_edge=torch.LongTensor([5]), nodes=torch.randn(5, 9), @@ -88,17 +88,17 @@ def sd_transform(sd): senders=torch.LongTensor(list(range(5))), receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) - jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. 
graph_j = jraph.GraphsTuple(**jax_batch) - graph_p = jraph.GraphsTuple(**pyt_batch) + graph_p = jraph.GraphsTuple(**pytorch_batch) jax_batch = {'inputs': graph_j} - pyt_batch = {'inputs': graph_p} + pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 21ab631d4..4e3b96cf7 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -79,7 +79,7 @@ def sd_transform(sd): jax_workload = JaxWorkload() pytorch_workload = PyTorchWorkload() - pyt_batch = dict( + pytorch_batch = dict( n_node=torch.LongTensor([5]), n_edge=torch.LongTensor([5]), nodes=torch.randn(5, 9), @@ -88,17 +88,17 @@ def sd_transform(sd): senders=torch.LongTensor(list(range(5))), receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)])) - jax_batch = {k: np.array(v) for k, v in pyt_batch.items()} + jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()} # Test outputs for identical weights and inputs. graph_j = jraph.GraphsTuple(**jax_batch) - graph_p = jraph.GraphsTuple(**pyt_batch) + graph_p = jraph.GraphsTuple(**pytorch_batch) jax_batch = {'inputs': graph_j} - pyt_batch = {'inputs': graph_p} + pytorch_batch = {'inputs': graph_p} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 2e01ce276..109bfa629 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -115,10 +115,10 @@ def sd_transform(sd): 'inputs': inp_tokens.detach().numpy(), 'targets': tgt_tokens.detach().numpy(), } - pyt_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} + pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 61a807452..1aa20fe3b 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -116,10 +116,10 @@ def sd_transform(sd): 'inputs': inp_tokens.detach().numpy(), 'targets': tgt_tokens.detach().numpy(), } - pyt_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} + pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index 1e335a81e..e98a6945d 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -116,10 +116,10 @@ def sd_transform(sd): 'inputs': inp_tokens.detach().numpy(), 'targets': tgt_tokens.detach().numpy(), } - pyt_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} + pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + 
augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 34ce37c83..d110715b5 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -116,10 +116,10 @@ def sd_transform(sd): 'inputs': inp_tokens.detach().numpy(), 'targets': tgt_tokens.detach().numpy(), } - pyt_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} + pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens} pytorch_model_kwargs = dict( - augmented_and_preprocessed_input_batch=pyt_batch, + augmented_and_preprocessed_input_batch=pytorch_batch, model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py index 80795826d..cea589202 100644 --- a/tests/test_traindiffs.py +++ b/tests/test_traindiffs.py @@ -46,11 +46,11 @@ def test_workload(self, workload): CUDA OOM errors resulting from the two frameworks competing with each other for GPU memory. """ name = f'Testing {workload}' - jax_logs = '/tmp/jax_log.pkl' - pyt_logs = '/tmp/pyt_log.pkl' + jax_logs_path = '/tmp/jax_log.pkl' + pytorch_logs_path = '/tmp/pyt_log.pkl' try: run( - f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs}' + f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs_path}' f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', shell=True, stdout=DEVNULL, @@ -60,7 +60,7 @@ def test_workload(self, workload): print("Error:", e) try: run( - f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pyt_logs}' + f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pytorch_logs_path}' f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}', shell=True, stdout=DEVNULL, @@ -68,13 +68,13 @@ def test_workload(self, workload): check=True) except subprocess.CalledProcessError as e: print("Error:", e) - with open(jax_logs, 'rb') as f: + with open(jax_logs_path, 'rb') as f: jax_results = pickle.load(f) - with open(pyt_logs, 'rb') as f: - pyt_results = pickle.load(f) + with open(pytorch_logs_path, 'rb') as f: + pytorch_results = pickle.load(f) # PRINT RESULTS - k = next( + eval_metric_key = next( iter( filter(lambda k: 'train' in k and 'loss' in k, jax_results['eval_results'][0]))) @@ -99,30 +99,30 @@ def test_workload(self, workload): row = map(lambda x: str(round(x, 5)), [ - jax_results['eval_results'][i][k], - pyt_results['eval_results'][i][k], + jax_results['eval_results'][i][eval_metric_key], + pytorch_results['eval_results'][i][eval_metric_key], jax_results['scalars'][i]['grad_norm'], - pyt_results['scalars'][i]['grad_norm'], + pytorch_results['scalars'][i]['grad_norm'], jax_results['scalars'][i]['loss'], - pyt_results['scalars'][i]['loss'], + 
pytorch_results['scalars'][i]['loss'], ]) print(fmt([f'{i}', *row])) print('=' * len(header)) self.assertTrue( # eval_results allclose( - jax_results['eval_results'][i][k], - pyt_results['eval_results'][i][k], + jax_results['eval_results'][i][eval_metric_key], + pytorch_results['eval_results'][i][eval_metric_key], rtol=rtol)) self.assertTrue( # grad_norms allclose( jax_results['scalars'][i]['grad_norm'], - pyt_results['scalars'][i]['grad_norm'], + pytorch_results['scalars'][i]['grad_norm'], rtol=rtol)) self.assertTrue( # loss allclose( jax_results['scalars'][i]['loss'], - pyt_results['scalars'][i]['loss'], + pytorch_results['scalars'][i]['loss'], rtol=rtol))
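
For reference, here is a minimal sketch (not part of the patches) of how a compare script invokes the new ModelDiffRunner after this series, modeled on tests/modeldiffs/fastmri/compare.py. The `from algoperf import spec` import and the full contents of jax_model_kwargs are assumptions inferred from the context lines above, and the None transforms are placeholders for the workload-specific key_transform/sd_transform that the real scripts define.

import jax
import torch

from algoperf import spec  # assumed import path, matching the other compare scripts
from algoperf.workloads.fastmri.fastmri_jax.workload import \
    FastMRIWorkload as JaxWorkload
from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
    FastMRIWorkload as PyTorchWorkload
from tests.modeldiffs.diff import ModelDiffRunner

if __name__ == '__main__':
  jax_workload = JaxWorkload()
  pytorch_workload = PyTorchWorkload()

  # Identical random input for both frameworks, as in the fastmri hunks above.
  image = torch.randn(2, 320, 320)
  jax_batch = {'inputs': image.detach().numpy()}
  pytorch_batch = {'inputs': image}

  pytorch_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=pytorch_batch,
      model_state=None,
      mode=spec.ForwardPassMode.EVAL,
      rng=None,
      update_batch_norm=False)
  # Assumed to mirror the PyTorch kwargs; only the rng and update_batch_norm
  # lines are visible in the hunks above.
  jax_model_kwargs = dict(
      augmented_and_preprocessed_input_batch=jax_batch,
      mode=spec.ForwardPassMode.EVAL,
      rng=jax.random.PRNGKey(0),
      update_batch_norm=False)

  # out_diff() is now reached through the wrapper object; the real fastmri
  # script passes its workload-specific sd_transform instead of None.
  ModelDiffRunner(
      jax_workload=jax_workload,
      pytorch_workload=pytorch_workload,
      jax_model_kwargs=jax_model_kwargs,
      pytorch_model_kwargs=pytorch_model_kwargs,
      key_transform=None,
      sd_transform=None,
      out_transform=None).run()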