1313 Criteo1TbDlrmSmallResNetWorkload as JaxWorkload
1414from algoperf .workloads .criteo1tb .criteo1tb_pytorch .workload import \
1515 Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload
16- from tests .modeldiffs .diff import out_diff
16+ from tests .modeldiffs .diff import ModelDiffRunner
1717
1818
1919def key_transform (k ):
@@ -64,7 +64,7 @@ def sd_transform(sd):
6464 jax_workload = JaxWorkload ()
6565 pytorch_workload = PyTorchWorkload ()
6666
67- pyt_batch = {
67+ pytorch_batch = {
6868 'inputs' : torch .ones ((2 , 13 + 26 )),
6969 'targets' : torch .randint (low = 0 , high = 1 , size = (2 ,)),
7070 'weights' : torch .ones (2 ),
@@ -75,12 +75,12 @@ def sd_transform(sd):
7575 input_size = 13 + num_categorical_features
7676 input_shape = (init_fake_batch_size , input_size )
7777 fake_inputs = jnp .ones (input_shape , jnp .float32 )
78- jax_batch = {k : np .array (v ) for k , v in pyt_batch .items ()}
78+ jax_batch = {k : np .array (v ) for k , v in pytorch_batch .items ()}
7979 jax_batch ['inputs' ] = fake_inputs
8080
8181 # Test outputs for identical weights and inputs.
8282 pytorch_model_kwargs = dict (
83- augmented_and_preprocessed_input_batch = pyt_batch ,
83+ augmented_and_preprocessed_input_batch = pytorch_batch ,
8484 model_state = None ,
8585 mode = spec .ForwardPassMode .EVAL ,
8686 rng = None ,
@@ -92,11 +92,11 @@ def sd_transform(sd):
9292 rng = jax .random .PRNGKey (0 ),
9393 update_batch_norm = False )
9494
95- out_diff (
95+ ModelDiffRunner (
9696 jax_workload = jax_workload ,
9797 pytorch_workload = pytorch_workload ,
9898 jax_model_kwargs = jax_model_kwargs ,
9999 pytorch_model_kwargs = pytorch_model_kwargs ,
100100 key_transform = key_transform ,
101101 sd_transform = sd_transform ,
102- out_transform = None )
102+ out_transform = None ). run ()
0 commit comments