File tree Expand file tree Collapse file tree 2 files changed +40
-3
lines changed
Expand file tree Collapse file tree 2 files changed +40
-3
lines changed Original file line number Diff line number Diff line change 1212 Criteo1TbDlrmSmallWorkload as JaxWorkload
1313from algoperf .workloads .criteo1tb .criteo1tb_pytorch .workload import \
1414 Criteo1TbDlrmSmallWorkload as PyTorchWorkload
15- from tests .modeldiffs .diff import out_diff
15+ from tests .modeldiffs .diff import ModelDiffRunner
1616
1717
1818def key_transform (k ):
@@ -74,11 +74,11 @@ def sd_transform(sd):
7474 rng = jax .random .PRNGKey (0 ),
7575 update_batch_norm = False )
7676
77- out_diff (
77+ ModelDiffRunner (
7878 jax_workload = jax_workload ,
7979 pytorch_workload = pytorch_workload ,
8080 jax_model_kwargs = jax_model_kwargs ,
8181 pytorch_model_kwargs = pytorch_model_kwargs ,
8282 key_transform = key_transform ,
8383 sd_transform = sd_transform ,
84- out_transform = None )
84+ out_transform = None ). run ()
Original file line number Diff line number Diff line change @@ -64,3 +64,40 @@ def out_diff(jax_workload,
6464
6565 print (f'Max fprop difference between jax and pytorch: { max_diff } ' )
6666 print (f'Min fprop difference between jax and pytorch: { min_diff } ' )
67+
68+
69+ class ModelDiffRunner :
70+ def __init__ (self , jax_workload ,
71+ pytorch_workload ,
72+ jax_model_kwargs ,
73+ pytorch_model_kwargs ,
74+ key_transform = None ,
75+ sd_transform = None ,
76+ out_transform = None ) -> None :
77+ """Initializes the instance based on diffing logic.
78+ Args:
79+ jax_workload: Workload implementation using JAX
80+ pytorch_workload: Workload implementation using PyTorch
81+ jax_model_kwargs: Arguments to be used for model_fn in jax workload
82+ pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload
83+ key_transform: Transformation function for keys.
84+ sd_transform: Transformation function for State Dictionary.
85+ out_transform: Transformation function for the output.
86+ """
87+
88+ self .jax_workload = jax_workload
89+ self .pytorch_workload = pytorch_workload
90+ self .jax_model_kwargs = jax_model_kwargs
91+ self .pytorch_model_kwargs = pytorch_model_kwargs
92+ self .key_transform = key_transform
93+ self .sd_transform = sd_transform
94+ self .out_transform = out_transform
95+
96+ def run (self ):
97+ out_diff (self .jax_workload ,
98+ self .pytorch_workload ,
99+ self .jax_model_kwargs ,
100+ self .pytorch_model_kwargs ,
101+ self .key_transform ,
102+ self .sd_transform ,
103+ self .out_transform )
You can’t perform that action at this time.
0 commit comments