Skip to content

Commit 73bcca4

Browse files
committed
setting up to run the workloads
1 parent 5741c76 commit 73bcca4

File tree

3 files changed

+0
-102
lines changed

3 files changed

+0
-102
lines changed

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from algoperf import param_utils, spec
1212
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
1313
from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload
14-
from custom_pytorch_jax_converter import use_pytorch_weights
1514

1615

1716
class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
@@ -105,7 +104,6 @@ def init_model_fn(
105104
jnp.ones(input_shape, jnp.float32),
106105
)
107106
initial_params = initial_variables['params']
108-
initial_params = use_pytorch_weights(file_name="~/results/pytorch_base_model_criteo1tb_1_july.pth")
109107
self._param_shapes = param_utils.jax_param_shapes(initial_params)
110108
self._param_types = param_utils.jax_param_types(self._param_shapes)
111109
return jax_utils.replicate(initial_params), None

algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
8282
use_layer_norm=self.use_layer_norm,
8383
embedding_init_multiplier=self.embedding_init_multiplier,
8484
)
85-
torch.save(model.state_dict(), '~/results/pytorch_base_model_criteo1tb_1_july.pth')
8685
self._param_shapes = param_utils.pytorch_param_shapes(model)
8786
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
8887
model.to(DEVICE)

custom_pytorch_jax_converter.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

0 commit comments

Comments
 (0)