File tree Expand file tree Collapse file tree 1 file changed +18
-13
lines changed Expand file tree Collapse file tree 1 file changed +18
-13
lines changed Original file line number Diff line number Diff line change @@ -67,19 +67,24 @@ def out_diff(jax_workload,
6767
6868
6969class ModelDiffRunner :
70- def __init__ (self , jax_workload ,
71- pytorch_workload ,
72- jax_model_kwargs ,
73- pytorch_model_kwargs ,
74- key_transform = None ,
75- sd_transform = None ,
76- out_transform = None ) -> None :
77- """Initializes the instance based on diffing logic.
70+
71+ def __init__ (self ,
72+ jax_workload ,
73+ pytorch_workload ,
74+ jax_model_kwargs ,
75+ pytorch_model_kwargs ,
76+ key_transform = None ,
77+ sd_transform = None ,
78+ out_transform = None ) -> None :
79+ """
80+ Initializes the instance based on diffing logic.
81+
7882 Args:
79- jax_workload: Workload implementation using JAX
80- pytorch_workload: Workload implementation using PyTorch
81- jax_model_kwargs: Arguments to be used for model_fn in jax workload
82- pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload
83+ jax_workload: Workload implementation using JAX.
84+ pytorch_workload: Workload implementation using PyTorch.
85+ jax_model_kwargs: Arguments to be used for model_fn in jax workload.
86+ pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch
87+ workload.
8388 key_transform: Transformation function for keys.
8489 sd_transform: Transformation function for State Dictionary.
8590 out_transform: Transformation function for the output.
@@ -100,4 +105,4 @@ def run(self):
100105 self .pytorch_model_kwargs ,
101106 self .key_transform ,
102107 self .sd_transform ,
103- self .out_transform )
108+ self .out_transform )
You can’t perform that action at this time.
0 commit comments