File tree Expand file tree Collapse file tree 30 files changed +90
-101
lines changed
librispeech_conformer_attention_temperature
librispeech_conformer_gelu
librispeech_conformer_layernorm
librispeech_deepspeech_noresnet
librispeech_deepspeech_normaug
librispeech_deepspeech_tanh Expand file tree Collapse file tree 30 files changed +90
-101
lines changed Original file line number Diff line number Diff line change 1212 Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload
1313from algoperf .workloads .criteo1tb .criteo1tb_pytorch .workload import \
1414 Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload
15- from tests .modeldiffs .diff import out_diff
15+ from tests .modeldiffs .diff import ModelDiffRunner
1616
1717
1818def key_transform (k ):
@@ -73,11 +73,11 @@ def sd_transform(sd):
7373 rng = jax .random .PRNGKey (0 ),
7474 update_batch_norm = False )
7575
76- out_diff (
76+ ModelDiffRunner (
7777 jax_workload = jax_workload ,
7878 pytorch_workload = pytorch_workload ,
7979 jax_model_kwargs = jax_model_kwargs ,
8080 pytorch_model_kwargs = pytorch_model_kwargs ,
8181 key_transform = key_transform ,
8282 sd_transform = sd_transform ,
83- out_transform = None )
83+ out_transform = None ). run ()
Original file line number Diff line number Diff line change 1212 Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload
1313from algoperf .workloads .criteo1tb .criteo1tb_pytorch .workload import \
1414 Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload
15- from tests .modeldiffs .diff import out_diff
15+ from tests .modeldiffs .diff import ModelDiffRunner
1616
1717
1818def key_transform (k ):
@@ -85,11 +85,11 @@ def sd_transform(sd):
8585 rng = jax .random .PRNGKey (0 ),
8686 update_batch_norm = False )
8787
88- out_diff (
88+ ModelDiffRunner (
8989 jax_workload = jax_workload ,
9090 pytorch_workload = pytorch_workload ,
9191 jax_model_kwargs = jax_model_kwargs ,
9292 pytorch_model_kwargs = pytorch_model_kwargs ,
9393 key_transform = key_transform ,
9494 sd_transform = sd_transform ,
95- out_transform = None )
95+ out_transform = None ). run ()
Original file line number Diff line number Diff line change 1313 Criteo1TbDlrmSmallResNetWorkload as JaxWorkload
1414from algoperf .workloads .criteo1tb .criteo1tb_pytorch .workload import \
1515 Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload
16- from tests .modeldiffs .diff import out_diff
16+ from tests .modeldiffs .diff import ModelDiffRunner
1717
1818
1919def key_transform (k ):
@@ -92,11 +92,11 @@ def sd_transform(sd):
9292 rng = jax .random .PRNGKey (0 ),
9393 update_batch_norm = False )
9494
95- out_diff (
95+ ModelDiffRunner (
9696 jax_workload = jax_workload ,
9797 pytorch_workload = pytorch_workload ,
9898 jax_model_kwargs = jax_model_kwargs ,
9999 pytorch_model_kwargs = pytorch_model_kwargs ,
100100 key_transform = key_transform ,
101101 sd_transform = sd_transform ,
102- out_transform = None )
102+ out_transform = None ). run ()
Original file line number Diff line number Diff line change 1111 FastMRIWorkload as JaxWorkload
1212from algoperf .workloads .fastmri .fastmri_pytorch .workload import \
1313 FastMRIWorkload as PyTorchWorkload
14- from tests .modeldiffs .diff import out_diff
14+ from tests .modeldiffs .diff import ModelDiffRunner
1515
1616
1717def sd_transform (sd ):
@@ -76,11 +76,10 @@ def sort_key(k):
7676 rng = jax .random .PRNGKey (0 ),
7777 update_batch_norm = False )
7878
79- out_diff (
79+ ModelDiffRunner (
8080 jax_workload = jax_workload ,
8181 pytorch_workload = pytorch_workload ,
8282 jax_model_kwargs = jax_model_kwargs ,
8383 pytorch_model_kwargs = pytorch_model_kwargs ,
8484 key_transform = None ,
85- sd_transform = sd_transform ,
86- )
85+ sd_transform = sd_transform ).run ()
Original file line number Diff line number Diff line change 1111 FastMRILayerNormWorkload as JaxWorkload
1212from algoperf .workloads .fastmri .fastmri_pytorch .workload import \
1313 FastMRILayerNormWorkload as PyTorchWorkload
14- from tests .modeldiffs .diff import out_diff
14+ from tests .modeldiffs .diff import ModelDiffRunner
1515
1616
1717def sd_transform (sd ):
@@ -83,11 +83,10 @@ def sort_key(k):
8383 rng = jax .random .PRNGKey (0 ),
8484 update_batch_norm = False )
8585
86- out_diff (
86+ ModelDiffRunner (
8787 jax_workload = jax_workload ,
8888 pytorch_workload = pytorch_workload ,
8989 jax_model_kwargs = jax_model_kwargs ,
9090 pytorch_model_kwargs = pytorch_model_kwargs ,
9191 key_transform = None ,
92- sd_transform = sd_transform ,
93- )
92+ sd_transform = sd_transform ).run ()
Original file line number Diff line number Diff line change 1111 FastMRIModelSizeWorkload as JaxWorkload
1212from algoperf .workloads .fastmri .fastmri_pytorch .workload import \
1313 FastMRIModelSizeWorkload as PyTorchWorkload
14- from tests .modeldiffs .diff import out_diff
14+ from tests .modeldiffs .diff import ModelDiffRunner
1515
1616
1717def sd_transform (sd ):
@@ -76,11 +76,10 @@ def sort_key(k):
7676 rng = jax .random .PRNGKey (0 ),
7777 update_batch_norm = False )
7878
79- out_diff (
79+ ModelDiffRunner (
8080 jax_workload = jax_workload ,
8181 pytorch_workload = pytorch_workload ,
8282 jax_model_kwargs = jax_model_kwargs ,
8383 pytorch_model_kwargs = pytorch_model_kwargs ,
8484 key_transform = None ,
85- sd_transform = sd_transform ,
86- )
85+ sd_transform = sd_transform ).run ()
Original file line number Diff line number Diff line change 1111 FastMRITanhWorkload as JaxWorkload
1212from algoperf .workloads .fastmri .fastmri_pytorch .workload import \
1313 FastMRITanhWorkload as PyTorchWorkload
14- from tests .modeldiffs .diff import out_diff
14+ from tests .modeldiffs .diff import ModelDiffRunner
1515
1616
1717def sd_transform (sd ):
@@ -76,11 +76,10 @@ def sort_key(k):
7676 rng = jax .random .PRNGKey (0 ),
7777 update_batch_norm = False )
7878
79- out_diff (
79+ ModelDiffRunner (
8080 jax_workload = jax_workload ,
8181 pytorch_workload = pytorch_workload ,
8282 jax_model_kwargs = jax_model_kwargs ,
8383 pytorch_model_kwargs = pytorch_model_kwargs ,
8484 key_transform = None ,
85- sd_transform = sd_transform ,
86- )
85+ sd_transform = sd_transform ).run ()
Original file line number Diff line number Diff line change 1111 ImagenetResNetWorkload as JaxWorkload
1212from algoperf .workloads .imagenet_resnet .imagenet_pytorch .workload import \
1313 ImagenetResNetWorkload as PyTorchWorkload
14- from tests .modeldiffs .diff import out_diff
14+ from tests .modeldiffs .diff import ModelDiffRunner
1515
1616
1717def key_transform (k ):
@@ -93,11 +93,10 @@ def sd_transform(sd):
9393 rng = jax .random .PRNGKey (0 ),
9494 update_batch_norm = False )
9595
96- out_diff (
96+ ModelDiffRunner (
9797 jax_workload = jax_workload ,
9898 pytorch_workload = pytorch_workload ,
9999 jax_model_kwargs = jax_model_kwargs ,
100100 pytorch_model_kwargs = pytorch_model_kwargs ,
101101 key_transform = key_transform ,
102- sd_transform = sd_transform ,
103- )
102+ sd_transform = sd_transform ).run ()
Original file line number Diff line number Diff line change 1111 ImagenetResNetGELUWorkload as JaxWorkload
1212from algoperf .workloads .imagenet_resnet .imagenet_pytorch .workload import \
1313 ImagenetResNetGELUWorkload as PyTorchWorkload
14- from tests .modeldiffs .diff import out_diff
14+ from tests .modeldiffs .diff import ModelDiffRunner
1515from tests .modeldiffs .imagenet_resnet .compare import key_transform
1616from tests .modeldiffs .imagenet_resnet .compare import sd_transform
1717
4040 rng = jax .random .PRNGKey (0 ),
4141 update_batch_norm = False )
4242
43- out_diff (
43+ ModelDiffRunner (
4444 jax_workload = jax_workload ,
4545 pytorch_workload = pytorch_workload ,
4646 jax_model_kwargs = jax_model_kwargs ,
4747 pytorch_model_kwargs = pytorch_model_kwargs ,
4848 key_transform = key_transform ,
49- sd_transform = sd_transform ,
50- )
49+ sd_transform = sd_transform ).run ()
Original file line number Diff line number Diff line change 1111 ImagenetResNetSiLUWorkload as JaxWorkload
1212from algoperf .workloads .imagenet_resnet .imagenet_pytorch .workload import \
1313 ImagenetResNetSiLUWorkload as PyTorchWorkload
14- from tests .modeldiffs .diff import out_diff
14+ from tests .modeldiffs .diff import ModelDiffRunner
1515from tests .modeldiffs .imagenet_resnet .compare import key_transform
1616from tests .modeldiffs .imagenet_resnet .compare import sd_transform
1717
4040 rng = jax .random .PRNGKey (0 ),
4141 update_batch_norm = False )
4242
43- out_diff (
43+ ModelDiffRunner (
4444 jax_workload = jax_workload ,
4545 pytorch_workload = pytorch_workload ,
4646 jax_model_kwargs = jax_model_kwargs ,
4747 pytorch_model_kwargs = pytorch_model_kwargs ,
4848 key_transform = key_transform ,
49- sd_transform = sd_transform ,
50- )
49+ sd_transform = sd_transform ).run ()
You can’t perform that action at this time.
0 commit comments