Skip to content

Commit d96da4a

Browse files
committed
fix style
1 parent 7a20f93 commit d96da4a

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

tests/modeldiffs/diff.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,24 @@ def out_diff(jax_workload,
6767

6868

6969
class ModelDiffRunner:
70-
def __init__(self, jax_workload,
71-
pytorch_workload,
72-
jax_model_kwargs,
73-
pytorch_model_kwargs,
74-
key_transform=None,
75-
sd_transform=None,
76-
out_transform=None) -> None:
77-
"""Initializes the instance based on diffing logic.
70+
71+
def __init__(self,
72+
jax_workload,
73+
pytorch_workload,
74+
jax_model_kwargs,
75+
pytorch_model_kwargs,
76+
key_transform=None,
77+
sd_transform=None,
78+
out_transform=None) -> None:
79+
"""
80+
Initializes the instance based on diffing logic.
81+
7882
Args:
79-
jax_workload: Workload implementation using JAX
80-
pytorch_workload: Workload implementation using PyTorch
81-
jax_model_kwargs: Arguments to be used for model_fn in jax workload
82-
pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch workload
83+
jax_workload: Workload implementation using JAX.
84+
pytorch_workload: Workload implementation using PyTorch.
85+
jax_model_kwargs: Arguments to be used for model_fn in jax workload.
86+
pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch
87+
workload.
8388
key_transform: Transformation function for keys.
8489
sd_transform: Transformation function for State Dictionary.
8590
out_transform: Transformation function for the output.
@@ -100,4 +105,4 @@ def run(self):
100105
self.pytorch_model_kwargs,
101106
self.key_transform,
102107
self.sd_transform,
103-
self.out_transform)
108+
self.out_transform)

0 commit comments

Comments
 (0)