Skip to content

Commit dde2eda

Browse files
Thomas Polasek authored and facebook-github-bot committed
Convert directory fbcode/aitemplate to use the Ruff Formatter (#1030)
Summary: Pull Request resolved: #1030 Converts the directory specified to use the Ruff formatter in pyfmt ruff_dog If this diff causes merge conflicts when rebasing, please run `hg status -n -0 --change . -I '**/*.{py,pyi}' | xargs -0 arc pyfmt` on your diff, and amend any changes before rebasing onto latest. That should help reduce or eliminate any merge conflicts. allow-large-files Reviewed By: amyreese Differential Revision: D64264218
1 parent 437b48a commit dde2eda

File tree

460 files changed

+498
-85
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

460 files changed

+498
-85
lines changed

examples/01_resnet-50/benchmark_ait.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def mark_output(y):
4545

4646

4747
def compile_module(model_name, batch_size, **kwargs):
48-
4948
if model_name != "resnet50":
5049
raise NotImplementedError
5150

examples/01_resnet-50/weight_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
Only tested on resnet50
1818
"""
1919

20-
2120
import pickle
2221
import re
2322

examples/02_detectron2/demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616
A main inference script for rcnn models
1717
"""
18+
1819
import glob
1920
import os
2021

examples/02_detectron2/modeling/roi_heads/roi_heads.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def get_shape(self, x):
5959
return shape
6060

6161
def forward(self, features: Dict[str, Tensor], rois: Tensor, proposals: Tensor):
62-
6362
box_features = [features[f] for f in self.in_features]
6463
roi_feat = self.box_head(box_features, rois)
6564
detections = self.box_predictor(roi_feat, proposals)

examples/02_detectron2/predictor/builtin_meta.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
COCO model (with correct class names and colors).
2626
"""
2727

28-
2928
# All coco categories, together with their nice-looking visualization colors
3029
# It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json
3130
COCO_CATEGORIES = [

examples/04_vit/weight_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
#
15-
"""script for converting vit model from timm to ait
16-
"""
15+
"""script for converting vit model from timm to ait"""
16+
1717
import pickle
1818

1919
import click

examples/05_stable_diffusion/scripts/download_pipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
help="Pipeline files local directory.",
3737
)
3838
def download_pipeline_files(model_name, token, save_directory) -> None:
39-
4039
StableDiffusionPipeline.from_pretrained(
4140
model_name,
4241
revision="fp16",

fx2ait/fx2ait/acc_tracer/acc_normalizer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,10 @@ def move_kwargs_to_acc_out_ty(
283283

284284
for kwarg_replacement_tuple in normalization_info.kwargs_to_move_to_acc_out_ty:
285285
if len(kwarg_replacement_tuple) == 2:
286-
orig_kwarg_name, tmd_field_name, move_to_qparams = *kwarg_replacement_tuple, False # type: ignore[misc]
286+
orig_kwarg_name, tmd_field_name, move_to_qparams = (
287+
*kwarg_replacement_tuple,
288+
False,
289+
) # type: ignore[misc]
287290
else:
288291
assert len(kwarg_replacement_tuple) == 3
289292
orig_kwarg_name, tmd_field_name, move_to_qparams = kwarg_replacement_tuple # type: ignore[misc]
@@ -331,9 +334,7 @@ def get_normalized_kwargs(
331334
new_kwargs[new_kwarg_name] = node.args[i]
332335
else:
333336
# Verify the arg we're trying to normalize was optional.
334-
assert (
335-
is_optional
336-
), f"Cannot normalize {orig_kwargs_names} to {new_kwarg_name} for {node.name}"
337+
assert is_optional, f"Cannot normalize {orig_kwargs_names} to {new_kwarg_name} for {node.name}"
337338
else:
338339
new_kwargs[new_kwarg_name] = node.kwargs[orig_kwargs_name]
339340

fx2ait/fx2ait/acc_tracer/acc_tracer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,10 @@ def __init__(self, orig):
462462
for k, v in orig.__dict__.items():
463463
if k == "_modules":
464464
for mod_k, mod_v in v.items():
465-
if getattr(mod_v, "_base_class_origin", type(mod_v)) in leaf_module_list: # type: ignore[operator]
465+
if (
466+
getattr(mod_v, "_base_class_origin", type(mod_v))
467+
in leaf_module_list
468+
): # type: ignore[operator]
466469
_LOGGER.info(
467470
f"Skip rewriting leaf module {type(mod_v)}"
468471
)

fx2ait/fx2ait/converters/ait_converters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -920,8 +920,8 @@ def acc_ops_conv_transpose2d(
920920
# Grouped conv doesn't currently work on AIT CUDA, manually map
921921
groups = kwargs["groups"]
922922
assert (
923-
w_last_dim * groups
924-
) % 8 == 0, f"cutlass needs weight output channel={w_last_dim*groups} is not divisble by 8! This restriction may be not valid in newer version"
923+
(w_last_dim * groups) % 8 == 0
924+
), f"cutlass needs weight output channel={w_last_dim*groups} is not divisble by 8! This restriction may be not valid in newer version"
925925

926926
group_size = input_val.shape()[3]._attrs["values"][0] // groups
927927
w_group_size = weight.shape()[0]._attrs["values"][0] // groups
@@ -1767,7 +1767,7 @@ def acc_ops_to_dtype(
17671767
input_val = kwargs["input"]
17681768

17691769
def _get_cast_to_dtype_from_kwargs(
1770-
kwargs: Dict[str, Argument]
1770+
kwargs: Dict[str, Argument],
17711771
) -> Optional[torch.dtype]:
17721772
torch_dtype_to_ait_dtype_str = {
17731773
torch.float: "float32",

0 commit comments

Comments (0)