diff --git a/examples/02_detectron2/modeling/backbone/utils.py b/examples/02_detectron2/modeling/backbone/utils.py
index 81a0cb203..9c301a48b 100644
--- a/examples/02_detectron2/modeling/backbone/utils.py
+++ b/examples/02_detectron2/modeling/backbone/utils.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 #
 from dataclasses import dataclass
-from typing import Optional
 
 
 @dataclass
@@ -24,7 +23,7 @@ class ShapeSpec:
     to complement the lack of shape inference ability among pytorch modules.
     """
 
-    channels: Optional[int] = None
-    height: Optional[int] = None
-    width: Optional[int] = None
-    stride: Optional[int] = None
+    channels: int | None = None
+    height: int | None = None
+    width: int | None = None
+    stride: int | None = None
diff --git a/examples/02_detectron2/modeling/roi_heads/box_head.py b/examples/02_detectron2/modeling/roi_heads/box_head.py
index 0269a6a4a..4055b162b 100644
--- a/examples/02_detectron2/modeling/roi_heads/box_head.py
+++ b/examples/02_detectron2/modeling/roi_heads/box_head.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Tuple
 
 from aitemplate.compiler import ops
 from aitemplate.frontend import nn
@@ -32,7 +31,7 @@ def __init__(
         feat_dim: int,
         fc_dim: int,
         pooled_size: int,
-        im_shape: Tuple[int, int],
+        im_shape: tuple[int, int],
     ):
         super().__init__()
         self.num_rois = num_rois
diff --git a/examples/02_detectron2/modeling/roi_heads/fast_rcnn.py b/examples/02_detectron2/modeling/roi_heads/fast_rcnn.py
index d825a59a0..e9990772b 100644
--- a/examples/02_detectron2/modeling/roi_heads/fast_rcnn.py
+++ b/examples/02_detectron2/modeling/roi_heads/fast_rcnn.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Tuple
 
 from aitemplate.compiler import ops
 from aitemplate.frontend import nn, Tensor
@@ -21,7 +20,7 @@ class fast_rcnn_inference:
     def __init__(
         self,
-        im_shape: Tuple[int, int],
+        im_shape: tuple[int, int],
         num_rois: int,
         num_classes: int,
         clip_box: bool = True,
diff --git a/examples/02_detectron2/modeling/roi_heads/mask_head.py b/examples/02_detectron2/modeling/roi_heads/mask_head.py
index 94e022205..8fc9d2549 100644
--- a/examples/02_detectron2/modeling/roi_heads/mask_head.py
+++ b/examples/02_detectron2/modeling/roi_heads/mask_head.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Tuple
 
 from aitemplate.compiler import ops
 from aitemplate.frontend import nn
@@ -31,7 +30,7 @@ def __init__(
         feat_dim: int,
         conv_dim: int,
         pooled_size: int,
-        im_shape: Tuple[int, int],
+        im_shape: tuple[int, int],
     ):
         super().__init__()
         HH, WW = im_shape
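Note on the `Optional[X]` → `X | None` rewrites above: PEP 604 union syntax in annotations only evaluates without error on Python 3.10+, and this diff does not add a future import, which suggests the project now assumes 3.10 or newer. A minimal compatibility sketch for older interpreters (the `ShapeSpec` field is taken from the diff; the future import is the standard fallback, not something this commit adds):

    # Sketch: postponed evaluation (PEP 563) keeps `int | None` out of the
    # runtime on Python < 3.10, so the module still imports there.
    from __future__ import annotations

    from dataclasses import dataclass


    @dataclass
    class ShapeSpec:
        channels: int | None = None  # stored as a string annotation, never evaluated

Without the future import, merely defining this class on Python 3.9 raises TypeError: unsupported operand type(s) for |.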
diff --git a/examples/02_detectron2/modeling/roi_heads/roi_heads.py b/examples/02_detectron2/modeling/roi_heads/roi_heads.py
index cdc5d1685..454bbf759 100644
--- a/examples/02_detectron2/modeling/roi_heads/roi_heads.py
+++ b/examples/02_detectron2/modeling/roi_heads/roi_heads.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Dict
 
 from aitemplate.compiler import ops
@@ -58,7 +57,7 @@ def get_shape(self, x):
         shape = [it.value() for it in x._attrs["shape"]]
         return shape
 
-    def forward(self, features: Dict[str, Tensor], rois: Tensor, proposals: Tensor):
+    def forward(self, features: dict[str, Tensor], rois: Tensor, proposals: Tensor):
         box_features = [features[f] for f in self.in_features]
         roi_feat = self.box_head(box_features, rois)
         detections = self.box_predictor(roi_feat, proposals)
diff --git a/examples/02_detectron2/predictor/predictor.py b/examples/02_detectron2/predictor/predictor.py
index ce3f85f24..5d083e6f2 100644
--- a/examples/02_detectron2/predictor/predictor.py
+++ b/examples/02_detectron2/predictor/predictor.py
@@ -14,7 +14,6 @@
 #
 import itertools
 import os
-from typing import Tuple
 
 import cv2
 import numpy as np
@@ -117,7 +116,7 @@ def apply_bbox(self, bbox, im_w, im_h):
     @staticmethod
     def get_output_shape(
         oldh: int, oldw: int, short_edge_length: int, max_size: int
-    ) -> Tuple[int, int]:
+    ) -> tuple[int, int]:
         """
         Compute the output size given input size and target short edge length.
         """
diff --git a/examples/03_bert/benchmark_ait.py b/examples/03_bert/benchmark_ait.py
index a16244a9a..18d605723 100644
--- a/examples/03_bert/benchmark_ait.py
+++ b/examples/03_bert/benchmark_ait.py
@@ -14,8 +14,6 @@
 #
 import os
-from typing import Dict, List
-
 import click
 import numpy as np
 import torch
@@ -35,12 +33,12 @@ def mark_output(y: Tensor) -> None:
         y[i]._attrs["is_output"] = True
         y[i]._attrs["name"] = "output_%d" % (i)
         y_shape = [d._attrs["values"][0] for d in y[i]._attrs["shape"]]
-        print("output_{} shape: {}".format(i, y_shape))
+        print(f"output_{i} shape: {y_shape}")
 
 
 def create_bert_inputs(
     batch_size: int, seq_length: int, dtype: str = "int64"
-) -> List[Tensor]:
+) -> list[Tensor]:
     input_ids = Tensor(
         shape=[batch_size, seq_length],
         name="input_ids",
@@ -76,7 +74,7 @@ def create_bert_encoders_input(
 
 def create_bert_inputs_pt(
     batch_size: int, seq_length: int, dtype: torch.dtype = torch.int64
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
     input_ids = torch.randn(batch_size, seq_length).to(dtype).cuda()
     token_type_ids = torch.randn(batch_size, seq_length).to(dtype).cuda()
     position_ids = torch.randn(batch_size, seq_length).to(dtype).cuda()
@@ -90,14 +88,14 @@ def create_bert_inputs_pt(
 
 def create_bert_encoders_inputs_pt(
     batch_size: int, seq_length: int, hidden_size: int
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
     encoder_input = torch.randn([batch_size, seq_length, hidden_size]).cuda().half()
     return {"input": encoder_input}
 
 
 def map_pt_params(
     ait_bert, pt_bert, batch_size: int, seq_length: int
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
     pt_params = dict(pt_bert.named_parameters())
     mapped_pt_params = {}
     for name, _ in ait_bert.named_parameters():
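The `List[...]`/`Dict[...]` → `list[...]`/`dict[...]` rewrites in benchmark_ait.py rely on PEP 585: builtin containers are subscriptable in annotations from Python 3.9 onward. A short sketch of the pattern, with a hypothetical `make_inputs` standing in for the converted helpers:

    from __future__ import annotations  # needed only if pre-3.9 support matters

    import torch


    def make_inputs(batch_size: int, seq_length: int) -> dict[str, torch.Tensor]:
        # Mirrors the shape/dtype pattern of create_bert_inputs_pt in the diff.
        ids = torch.zeros(batch_size, seq_length, dtype=torch.int64)
        return {"input_ids": ids}

The `print("output_{} ...".format(...))` → f-string change in the same hunk is behavior-preserving; both render the identical string.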
diff --git a/examples/03_bert/modeling/bert.py b/examples/03_bert/modeling/bert.py
index a3a29b54f..413a8cf0b 100644
--- a/examples/03_bert/modeling/bert.py
+++ b/examples/03_bert/modeling/bert.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Tuple
 
 from aitemplate.compiler import ops
 from aitemplate.frontend import nn, Tensor
@@ -70,7 +69,7 @@ def __init__(
     def forward(
         self,
         hidden_states: Tensor,
-    ) -> Tuple[Tensor]:
+    ) -> tuple[Tensor]:
         self_output = self.self(hidden_states, hidden_states)
         attention_output = self.output(self_output)
         outputs = (attention_output,)
diff --git a/examples/05_stable_diffusion/src/modeling/attention.py b/examples/05_stable_diffusion/src/modeling/attention.py
index 06ab5f1bd..a3c51e6aa 100644
--- a/examples/05_stable_diffusion/src/modeling/attention.py
+++ b/examples/05_stable_diffusion/src/modeling/attention.py
@@ -17,8 +17,6 @@
 Implementations are translated from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py.
 """
 
-from typing import Optional
-
 from aitemplate.compiler.ops import reshape
 from aitemplate.frontend import nn, Tensor
@@ -46,7 +44,7 @@ def __init__(
         height: int,
         width: int,
         channels: int,
-        num_head_channels: Optional[int] = None,
+        num_head_channels: int | None = None,
         num_groups: int = 32,
         rescale_output_factor: float = 1.0,
         eps: float = 1e-5,
diff --git a/examples/05_stable_diffusion/src/modeling/clip.py b/examples/05_stable_diffusion/src/modeling/clip.py
index ff0ce792a..cf7c7fd07 100644
--- a/examples/05_stable_diffusion/src/modeling/clip.py
+++ b/examples/05_stable_diffusion/src/modeling/clip.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 #
 from inspect import isfunction
-from typing import Optional
 
 from aitemplate.compiler import ops
 from aitemplate.frontend import nn, Tensor
@@ -279,10 +278,10 @@ def __init__(
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[Tensor] = None,
-        causal_attention_mask: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = False,
-        residual: Optional[Tensor] = None,
+        attention_mask: Tensor | None = None,
+        causal_attention_mask: Tensor | None = None,
+        output_attentions: bool | None = False,
+        residual: Tensor | None = None,
     ):
         if residual is not None:
             self_output = self.attn(hidden_states, residual)
@@ -399,7 +398,7 @@ def __init__(
     def forward(
         self,
         hidden_states: Tensor,
-        output_attentions: Optional[bool] = False,
+        output_attentions: bool | None = False,
     ):
         """
         Args:
@@ -469,11 +468,11 @@ def __init__(
     def forward(
         self,
         inputs_embeds,
-        attention_mask: Optional[Tensor] = None,
-        causal_attention_mask: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        attention_mask: Tensor | None = None,
+        causal_attention_mask: Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
     ):
         r"""
         Args:
@@ -548,7 +547,7 @@ def forward(
         self,
         input_ids: Tensor,
         position_ids: Tensor,
-        inputs_embeds: Optional[Tensor] = None,
+        inputs_embeds: Tensor | None = None,
     ) -> Tensor:
         input_shape = ops.size()(input_ids)
@@ -612,12 +611,12 @@ def __init__(
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        position_ids: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        position_ids: Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
     ):
         r"""
         Returns:
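One caveat carried over unchanged into the controlnet/unet diffs that follow: `Tuple[str]` and its replacement `tuple[str]` both describe a tuple of exactly one string, while the defaults are four-element tuples. The variable-length spelling would be `tuple[str, ...]`; the diff deliberately preserves the original (imprecise) annotation rather than changing semantics. A sketch of the distinction, with hypothetical variable names:

    one_block: tuple[str] = ("DownBlock2D",)  # exactly one element type-checks
    many_blocks: tuple[str, ...] = (          # any number of elements
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )

A strict type checker would flag the four-element defaults against `tuple[str]`; fixing that is out of scope for a mechanical syntax migration.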
diff --git a/examples/05_stable_diffusion/src/modeling/controlnet_unet_2d_condition.py b/examples/05_stable_diffusion/src/modeling/controlnet_unet_2d_condition.py
index 56da472ab..c87685cba 100644
--- a/examples/05_stable_diffusion/src/modeling/controlnet_unet_2d_condition.py
+++ b/examples/05_stable_diffusion/src/modeling/controlnet_unet_2d_condition.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Optional, Tuple, Union
 
 from aitemplate.compiler import ops
 from aitemplate.frontend import nn
@@ -82,26 +81,26 @@ def __init__(
         in_channels: int = 4,
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
-        down_block_types: Tuple[str] = (
+        down_block_types: tuple[str] = (
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "DownBlock2D",
         ),
-        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        block_out_channels: tuple[int] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
         act_fn: str = "silu",
-        norm_num_groups: Optional[int] = 32,
+        norm_num_groups: int | None = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 768,
-        attention_head_dim: Union[int, Tuple[int]] = 8,
+        attention_head_dim: int | tuple[int] = 8,
         use_linear_projection: bool = False,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         controlnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: tuple[int] | None = (16, 32, 96, 256),
         global_pool_conditions: bool = False,
     ):
         super().__init__()
@@ -199,7 +198,7 @@ def forward(
         encoder_hidden_states,
         controlnet_cond,
         conditioning_scale: float = 1.0,
-    ) -> Tuple:
+    ) -> tuple:
         t_emb = self.time_proj(timestep)
         emb = self.time_embedding(t_emb)
@@ -302,25 +301,25 @@ class ControlNetUNet2DConditionModel(nn.Module):
 
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: int | None = None,
         in_channels: int = 4,
         out_channels: int = 4,
         center_input_sample: bool = False,
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
-        down_block_types: Tuple[str] = (
+        down_block_types: tuple[str] = (
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "DownBlock2D",
         ),
-        up_block_types: Tuple[str] = (
+        up_block_types: tuple[str] = (
             "UpBlock2D",
             "CrossAttnUpBlock2D",
             "CrossAttnUpBlock2D",
             "CrossAttnUpBlock2D",
         ),
-        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        block_out_channels: tuple[int] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
@@ -328,7 +327,7 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        attention_head_dim: Union[int, Tuple[int]] = 8,
+        attention_head_dim: int | tuple[int] = 8,
         use_linear_projection: bool = False,
     ):
         super().__init__()
diff --git a/examples/05_stable_diffusion/src/modeling/unet_2d_condition.py b/examples/05_stable_diffusion/src/modeling/unet_2d_condition.py
index 2ad4d9718..00e8817e4 100644
--- a/examples/05_stable_diffusion/src/modeling/unet_2d_condition.py
+++ b/examples/05_stable_diffusion/src/modeling/unet_2d_condition.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Optional, Tuple, Union
 
 from aitemplate.frontend import nn, Tensor
@@ -55,25 +54,25 @@ class UNet2DConditionModel(nn.Module):
 
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: int | None = None,
         in_channels: int = 4,
         out_channels: int = 4,
         center_input_sample: bool = False,
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
-        down_block_types: Tuple[str] = (
+        down_block_types: tuple[str] = (
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "DownBlock2D",
         ),
-        up_block_types: Tuple[str] = (
+        up_block_types: tuple[str] = (
             "UpBlock2D",
             "CrossAttnUpBlock2D",
             "CrossAttnUpBlock2D",
             "CrossAttnUpBlock2D",
         ),
-        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        block_out_channels: tuple[int] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
@@ -81,7 +80,7 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        attention_head_dim: Union[int, Tuple[int]] = 8,
+        attention_head_dim: int | tuple[int] = 8,
         use_linear_projection: bool = False,
     ):
         super().__init__()
@@ -185,8 +184,8 @@ def forward(
         sample,
         timesteps,
         encoder_hidden_states,
-        down_block_additional_residuals: Optional[Tuple[Tensor]] = None,
-        mid_block_additional_residual: Optional[Tensor] = None,
+        down_block_additional_residuals: tuple[Tensor] | None = None,
+        mid_block_additional_residual: Tensor | None = None,
         return_dict: bool = True,
     ):
         """r
""" -from typing import Tuple - from aitemplate.frontend import nn, Tensor from .unet_blocks import get_up_block, UNetMidBlock2D @@ -119,9 +117,9 @@ def __init__( width: int, in_channels: int = 3, out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), + down_block_types: tuple[str] = ("DownEncoderBlock2D",), + up_block_types: tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, diff --git a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait.py b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait.py index a89f43109..b9b4d964e 100644 --- a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait.py +++ b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait.py @@ -16,7 +16,6 @@ import os import warnings -from typing import List, Optional, Union import torch from aitemplate.compiler import Model @@ -75,14 +74,14 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: ( + DDIMScheduler + | PNDMScheduler + | LMSDiscreteScheduler + | EulerDiscreteScheduler + | EulerAncestralDiscreteScheduler + | DPMSolverMultistepScheduler + ), safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, @@ -172,16 +171,16 @@ def vae_inference(self, vae_input): @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", + prompt: str | list[str], + height: int | None = 512, + width: int | None = 512, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + eta: float | None = 0.0, + generator: torch.Generator | None = None, + latents: torch.FloatTensor | None = None, + output_type: str | None = "pil", return_dict: bool = True, **kwargs, ): @@ -276,7 +275,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: - uncond_tokens: List[str] + uncond_tokens: list[str] max_length = text_input.input_ids.shape[-1] if negative_prompt is None: uncond_tokens = [""] * batch_size diff --git a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py index bf9327fb5..f3633c919 100644 --- a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py +++ b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py @@ -16,7 +16,6 @@ import os import re -from typing import List, Optional, Union import torch from aitemplate.compiler import Model @@ -838,16 +837,16 @@ def vae_inference(self, vae_input, height, width): @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] 
diff --git a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py
index bf9327fb5..f3633c919 100644
--- a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py
+++ b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py
@@ -16,7 +16,6 @@
 import os
 import re
-from typing import List, Optional, Union
 
 import torch
 from aitemplate.compiler import Model
@@ -838,16 +837,16 @@ def vae_inference(self, vae_input, height, width):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        eta: Optional[float] = 0.0,
-        generator: Optional[torch.Generator] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
+        prompt: str | list[str],
+        height: int | None = 512,
+        width: int | None = 512,
+        num_inference_steps: int | None = 50,
+        guidance_scale: float | None = 7.5,
+        negative_prompt: str | list[str] | None = None,
+        eta: float | None = 0.0,
+        generator: torch.Generator | None = None,
+        latents: torch.FloatTensor | None = None,
+        output_type: str | None = "pil",
         return_dict: bool = True,
     ):
         r"""
@@ -931,7 +930,7 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            uncond_tokens: List[str]
+            uncond_tokens: list[str]
             max_length = text_input.input_ids.shape[-1]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
diff --git a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_controlnet_ait.py b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_controlnet_ait.py
index e553c70c8..120cf4bf2 100644
--- a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_controlnet_ait.py
+++ b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_controlnet_ait.py
@@ -14,7 +14,6 @@
 #
 import inspect
 import os
-from typing import List, Optional, Union
 
 import torch
 from aitemplate.compiler import Model
@@ -838,17 +837,17 @@ def vae_inference(self, vae_input, height, width):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: str | list[str],
         control_cond: torch.FloatTensor,
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        eta: Optional[float] = 0.0,
-        generator: Optional[torch.Generator] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
+        height: int | None = 512,
+        width: int | None = 512,
+        num_inference_steps: int | None = 50,
+        guidance_scale: float | None = 7.5,
+        negative_prompt: str | list[str] | None = None,
+        eta: float | None = 0.0,
+        generator: torch.Generator | None = None,
+        latents: torch.FloatTensor | None = None,
+        output_type: str | None = "pil",
         return_dict: bool = True,
     ):
         r"""
@@ -932,7 +931,7 @@ def __call__(
         do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
-            uncond_tokens: List[str]
+            uncond_tokens: list[str]
             max_length = text_input.input_ids.shape[-1]
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
diff --git a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_img2img_ait.py b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_img2img_ait.py
index 893db028d..465fb1a74 100644
--- a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_img2img_ait.py
+++ b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_img2img_ait.py
@@ -16,7 +16,6 @@
 # flakes8: noqa
 import inspect
 import os
-from typing import List, Optional, Union
 
 import numpy as np
@@ -83,7 +82,7 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
         requires_safety_checker: bool = True,
@@ -183,14 +182,14 @@ def vae_inference(self, vae_input):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
-        init_image: Union[torch.FloatTensor, PIL.Image.Image],
+        prompt: str | list[str],
+        init_image: torch.FloatTensor | PIL.Image.Image,
         strength: float = 0.8,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        eta: Optional[float] = 0.0,
-        generator: Optional[torch.Generator] = None,
-        output_type: Optional[str] = "pil",
+        num_inference_steps: int | None = 50,
+        guidance_scale: float | None = 7.5,
+        eta: float | None = 0.0,
+        generator: torch.Generator | None = None,
+        output_type: str | None = "pil",
         return_dict: bool = True,
     ):
         r"""
diff --git a/examples/06_how_to_add_an_op/how_to_add_an_op.py b/examples/06_how_to_add_an_op/how_to_add_an_op.py
index 4e0087cd9..a195c73d8 100644
--- a/examples/06_how_to_add_an_op/how_to_add_an_op.py
+++ b/examples/06_how_to_add_an_op/how_to_add_an_op.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List
+from typing import Any
 
 import jinja2
 import torch
@@ -47,7 +47,7 @@ def __call__(self, x: Tensor) -> Tensor:
         self._attrs["outputs"] = [output]
         return output
 
-    def _infer_shape(self, x) -> List[IntVar]:
+    def _infer_shape(self, x) -> list[IntVar]:
         return x.shape()
 
     def gen_function(self) -> str:
@@ -134,7 +134,7 @@ def gen_function(self) -> str:
 )
 
 
-def gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str:
+def gen_function_call(func_attrs: dict[str, Any], indent=" ") -> str:
     assert len(func_attrs["outputs"]) == 1
     assert len(func_attrs["inputs"]) == 1
@@ -151,7 +151,7 @@ def gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str:
 )
 
 
-def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str:
+def gen_function(func_attrs: dict[str, Any], header_files: str, backend_spec) -> str:
     input_x = func_attrs["inputs"][0]
     output_y = func_attrs["outputs"][0]
     input_type = backend_spec.dtype_to_backend_type(input_x._attrs["dtype"])
@@ -172,7 +172,7 @@ def gen_function(func_attrs: Dict[str, Any], header_files: str, backend_spec) -> str:
 )
 
 
-def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str:
+def gen_function_decl(func_attrs: dict[str, Any], backend_spec) -> str:
     return FUNC_DECL.render(
         func_signature=FUNC_SIGNATURE.render(
             func_name=func_attrs["name"],
@@ -187,17 +187,17 @@ def gen_function_decl(func_attrs: Dict[str, Any], backend_spec) -> str:
 
 @registry.reg("cuda.add_one.gen_function")
-def cuda_add_one_gen_function(func_attrs: Dict[str, Any]) -> str:
+def cuda_add_one_gen_function(func_attrs: dict[str, Any]) -> str:
     return gen_function(func_attrs, CUDA_HEADER_FILES, CUDASpec())
 
 
 @registry.reg("cuda.add_one.func_decl")
-def cuda_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str:
+def cuda_add_one_gen_function_decl(func_attrs: dict[str, Any]) -> str:
     return gen_function_decl(func_attrs, CUDASpec())
 
 
 @registry.reg("cuda.add_one.func_call")
-def cuda_add_one_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str:
+def cuda_add_one_gen_function_call(func_attrs: dict[str, Any], indent=" ") -> str:
     return gen_function_call(func_attrs, indent)
 
 
@@ -208,17 +208,17 @@ def cuda_add_one_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str:
 
 @registry.reg("rocm.add_one.gen_function")
-def rocm_add_one_gen_function(func_attrs: Dict[str, Any]) -> str:
+def rocm_add_one_gen_function(func_attrs: dict[str, Any]) -> str:
     return gen_function(func_attrs, HIP_HEADER_FILES, ROCMSpec())
 
 
 @registry.reg("rocm.add_one.func_decl")
-def rocm_add_one_gen_function_decl(func_attrs: Dict[str, Any]) -> str:
+def rocm_add_one_gen_function_decl(func_attrs: dict[str, Any]) -> str:
     return gen_function_decl(func_attrs, ROCMSpec())
 
 
 @registry.reg("rocm.add_one.func_call")
-def rocm_add_one_gen_function_call(func_attrs: Dict[str, Any], indent=" ") -> str:
+def rocm_add_one_gen_function_call(func_attrs: dict[str, Any], indent=" ") -> str:
     return gen_function_call(func_attrs, indent)
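From here on, the fx2ait files also move `Callable` (and `Iterable`, `Sequence`, `Mapping`) from `typing` to `collections.abc`; the `typing` aliases have been deprecated since Python 3.9 and both spellings are subscriptable. A small self-contained sketch of the imported names in use (`apply_all` is a hypothetical helper, not repository code):

    from collections.abc import Callable, Iterable, Sequence


    def apply_all(fns: Iterable[Callable[[int], int]], xs: Sequence[int]) -> list[int]:
        # Run each function over every element, feeding results forward.
        out = list(xs)
        for fn in fns:
            out = [fn(x) for x in out]
        return out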
diff --git a/fx2ait/fx2ait/acc_tracer/acc_normalizer.py b/fx2ait/fx2ait/acc_tracer/acc_normalizer.py
index c394e7570..ad8793128 100644
--- a/fx2ait/fx2ait/acc_tracer/acc_normalizer.py
+++ b/fx2ait/fx2ait/acc_tracer/acc_normalizer.py
@@ -15,7 +15,8 @@
 import inspect
 import logging
 import re
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
+from collections.abc import Callable
+from typing import Any, NamedTuple
 
 import torch
 import torch.fx
@@ -41,7 +42,7 @@
 # signature of the acc_op, or for custom mapped nodes from the original unnormalized op.
 # - The third member is a bool representing whether this arg is optional, i.e. whether it
 #   is allowed to not be present in the original input args.
-ArgReplacementTuplesType = List[Tuple[Tuple[str, ...], str, bool]]
+ArgReplacementTuplesType = list[tuple[tuple[str, ...], str, bool]]
 
 
 class NormalizationInfo(NamedTuple):
@@ -73,34 +74,31 @@ class NormalizationInfo(NamedTuple):
     """
 
     new_fn_target: Callable
-    arg_replacement_tuples: Optional[ArgReplacementTuplesType]
-    custom_mapping_fn: Optional[Callable]
+    arg_replacement_tuples: ArgReplacementTuplesType | None
+    custom_mapping_fn: Callable | None
 
     # either (tensor_meta_field_name, original_field_name, move_to_qparams) or
     # (tensor_meta_field_name, orginal_field_name)
     # when move_to_qparams is True, we'll move the field to qparams
     # dictionary, otherwise it will stay in TensorMeta itself
-    kwargs_to_move_to_acc_out_ty: Optional[
-        List[Union[Tuple[str, str, bool], Tuple[str, str]]]
-    ]
+    kwargs_to_move_to_acc_out_ty: None | (list[tuple[str, str, bool] | tuple[str, str]])
 
     needs_shapes_for_normalization: bool
     skip_normalization_if_none: bool
 
 
 # Dict from (op, target) to NormalizationInfo for that op.
-_normalization_dict: Dict[Tuple[str, Union[str, Callable]], NormalizationInfo] = {}
+_normalization_dict: dict[tuple[str, str | Callable], NormalizationInfo] = {}
 
 # Set of all the acc ops.
-_acc_ops: Set[Callable] = set()
+_acc_ops: set[Callable] = set()
 
 
 def _insert_fun(
-    op_and_target: Tuple[str, Union[str, Callable]],
-    arg_replacement_tuples: List[Tuple],
-    new_fn_target: Optional[Callable] = None,
-    custom_mapping_fn: Optional[Callable] = None,
-    kwargs_to_move_to_acc_out_ty: Optional[
-        List[Union[Tuple[str, str, bool], Tuple[str, str]]]
-    ] = None,
+    op_and_target: tuple[str, str | Callable],
+    arg_replacement_tuples: list[tuple],
+    new_fn_target: Callable | None = None,
+    custom_mapping_fn: Callable | None = None,
+    kwargs_to_move_to_acc_out_ty: None
+    | (list[tuple[str, str, bool] | tuple[str, str]]) = None,
     needs_shapes_for_normalization=False,
     allow_normalize_from_torch_package=False,
     skip_normalization_if_none=False,
@@ -161,12 +159,12 @@ def _insert_fun(
     _normalization_dict[torch_package_op_and_target] = norm_info
 
 
-def _get_dup_signature_tuples(fn: Callable) -> List[Tuple[str, str]]:
+def _get_dup_signature_tuples(fn: Callable) -> list[tuple[str, str]]:
     """
     Helper that inspects the arg signature of `fn` and returns a list of tuples, where
     each tuple is a pair of duplicated names which is used for arg_replacement_tuples.
     """
-    sig_tuples: List[Tuple[str, str]] = []
+    sig_tuples: list[tuple[str, str]] = []
     for param in inspect.signature(inspect.unwrap(fn)).parameters:
         sig_tuples.append((param, param))
     return sig_tuples
@@ -181,18 +179,18 @@ def register_acc_op(acc_op: Callable):
 
 
 def register_acc_op_mapping(
-    op_and_target: Tuple[str, Union[str, Callable]],
-    arg_replacement_tuples: Optional[
-        List[
-            Union[
-                Tuple[Union[str, Tuple[str, ...]], str],
-                Tuple[Union[str, Tuple[str, ...]], str, bool],
-            ]
-        ]
-    ] = None,
-    kwargs_to_move_to_acc_out_ty: Optional[
-        List[Union[Tuple[str, str, bool], Tuple[str, str]]]
-    ] = None,
+    op_and_target: tuple[str, str | Callable],
+    arg_replacement_tuples: None
+    | (
+        list[
+            (
+                tuple[str | tuple[str, ...], str]
+                | tuple[str | tuple[str, ...], str, bool]
+            )
+        ]
+    ) = None,
+    kwargs_to_move_to_acc_out_ty: None
+    | (list[tuple[str, str, bool] | tuple[str, str]]) = None,
     allow_normalize_from_torch_package=False,
 ):
     """
@@ -225,12 +223,9 @@ def insert(new_fn_target: Callable):
 
 
 def register_custom_acc_mapper_fn(
-    op_and_target: Tuple[str, Union[str, Callable]],
-    arg_replacement_tuples: List[
-        Union[
-            Tuple[Union[str, Tuple[str, ...]], str],
-            Tuple[Union[str, Tuple[str, ...]], str, bool],
-        ]
-    ],
+    op_and_target: tuple[str, str | Callable],
+    arg_replacement_tuples: list[
+        (tuple[str | tuple[str, ...], str] | tuple[str | tuple[str, ...], str, bool])
+    ],
     needs_shapes_for_normalization=False,
     allow_normalize_from_torch_package=False,
@@ -251,8 +246,8 @@ def insert(custom_mapping_fn: Callable):
 
 
 def move_kwargs_to_acc_out_ty(
-    node_or_normalization_info: Union[NormalizationInfo, torch.fx.Node],
-    new_kwargs: Dict[str, Any],
+    node_or_normalization_info: NormalizationInfo | torch.fx.Node,
+    new_kwargs: dict[str, Any],
 ):
     """
     Given `node_or_normalization_info` which is either NormalizationInfo for a node, or
@@ -278,8 +273,8 @@ def move_kwargs_to_acc_out_ty(
     # Build a dict representing the new TensorMetadata to use for acc_out_ty,
     # and then remove the kwarg from the new_kwargs since it's passed in via
     # acc_out_ty instead.
-    tmd_dict: Dict[str, Any] = {}
-    qparams: Dict[str, Any] = {}
+    tmd_dict: dict[str, Any] = {}
+    qparams: dict[str, Any] = {}
 
     for kwarg_replacement_tuple in normalization_info.kwargs_to_move_to_acc_out_ty:
         if len(kwarg_replacement_tuple) == 2:
@@ -353,9 +348,7 @@ def get_normalized_kwargs(
 def normalize(
     mod: torch.fx.GraphModule,
     expect_nodes_have_shapes: bool = False,
-    acc_normalization_block_list: Optional[
-        Set[Tuple[str, Union[str, Callable]]]
-    ] = None,
+    acc_normalization_block_list: None | (set[tuple[str, str | Callable]]) = None,
 ):
     assert len(_normalization_dict) > 0
     graph = mod.graph
@@ -376,8 +369,8 @@ def get_target(mod: torch.fx.GraphModule, node: torch.fx.Node):
 
 def normalize_to_acc_op(
     node: torch.fx.Node,
     normalization_info: NormalizationInfo,
-    normalized_args: Tuple[Any, ...],
-    normalized_kwargs: Dict[str, Any],
+    normalized_args: tuple[Any, ...],
+    normalized_kwargs: dict[str, Any],
 ):
     # If there's a custom mapping function then use it.
     if normalization_info.custom_mapping_fn is not None:
diff --git a/fx2ait/fx2ait/acc_tracer/acc_op_properties.py b/fx2ait/fx2ait/acc_tracer/acc_op_properties.py
index 895ad9a97..4e860a1ce 100644
--- a/fx2ait/fx2ait/acc_tracer/acc_op_properties.py
+++ b/fx2ait/fx2ait/acc_tracer/acc_op_properties.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 #
 from collections import defaultdict
+from collections.abc import Callable
 from enum import auto, Flag
-from typing import Callable, DefaultDict, Set
+from typing import DefaultDict
 
 import torch
 import torch.fx
@@ -38,8 +39,8 @@ class AccOpProperty(Flag):
     unary = auto()
 
 
-acc_op_properties: DefaultDict[Callable, Set[AccOpProperty]] = defaultdict(set)
-acc_ops_with_property: DefaultDict[AccOpProperty, Set[Callable]] = defaultdict(set)
+acc_op_properties: DefaultDict[Callable, set[AccOpProperty]] = defaultdict(set)
+acc_ops_with_property: DefaultDict[AccOpProperty, set[Callable]] = defaultdict(set)
 
 
 def register_acc_op_properties(*properties: AccOpProperty):
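acc_op_properties.py above keeps `typing.DefaultDict` even though `Callable` moved. Since Python 3.9, `collections.defaultdict` is itself subscriptable, so the annotation could use the runtime class directly; a hedged sketch (not part of this diff, with `str` standing in for `AccOpProperty`):

    from collections import defaultdict
    from collections.abc import Callable

    acc_op_properties: defaultdict[Callable, set[str]] = defaultdict(set)

Either spelling type-checks; dropping `DefaultDict` would just remove the last `typing` import from that module.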
""" @@ -1867,7 +1868,7 @@ def ceil(*, input): @register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.pad)) @register_acc_op -def pad(*, input, pad: List[int], mode: str, value: float): +def pad(*, input, pad: list[int], mode: str, value: float): return torch.nn.functional.pad(input=input, pad=pad, mode=mode, value=value) @@ -2450,7 +2451,7 @@ def where(*, condition, x, y): def slice_tensor(*, input, dim, start, stop, step): slc = slice(start, stop, step) if dim >= 0: - slices: List[slice] = [slice(None, None, None) for _ in range(dim)] + slices: list[slice] = [slice(None, None, None) for _ in range(dim)] slices.append(slc) else: slices = [Ellipsis, slc] # type: ignore[list-item] diff --git a/fx2ait/fx2ait/acc_tracer/acc_utils.py b/fx2ait/fx2ait/acc_tracer/acc_utils.py index 21646af9e..dfc1cd7e1 100644 --- a/fx2ait/fx2ait/acc_tracer/acc_utils.py +++ b/fx2ait/fx2ait/acc_tracer/acc_utils.py @@ -16,7 +16,8 @@ import logging import os import re -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from collections.abc import Callable +from typing import Any import torch import torch.fx @@ -56,7 +57,7 @@ def get_attr(node: torch.fx.Node) -> Any: return get_target_from_module(node.graph.owning_module, str(node.target)) -def is_acc_op(node_or_target: Union[Callable, torch.fx.Node]) -> bool: +def is_acc_op(node_or_target: Callable | torch.fx.Node) -> bool: """ Returns whether `node_or_target` is an acc_op. If it's a node, then checks whether it's a call_function target is from the acc_ops module. Otherwise it's already @@ -72,9 +73,7 @@ def is_acc_op(node_or_target: Union[Callable, torch.fx.Node]) -> bool: return "acc_ops" in target.__module__ -def is_acc_op_with_kwarg( - node_or_target: Union[Callable, torch.fx.Node], kwarg: str -) -> bool: +def is_acc_op_with_kwarg(node_or_target: Callable | torch.fx.Node, kwarg: str) -> bool: """ Helper that inspects `node_or_target` and returns whether it is an acc_op node (or a target for an acc_op) that has an arg signature that includes `kwarg`. @@ -116,12 +115,12 @@ def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_grap _LOGGER.error(f"Failed to write the FX graph due to: {e}") -def get_model_info_str(gm: torch.fx.GraphModule, header: Optional[str] = None): +def get_model_info_str(gm: torch.fx.GraphModule, header: str | None = None): """ Print out info of the provided `gm`. If `header` is provided then it's included in the printed string. """ - ops_and_counts: Dict[Callable, int] = {} + ops_and_counts: dict[Callable, int] = {} placeholder_count = get_attr_count = call_method_count = call_module_count = 0 for node in gm.graph.nodes: if node.op == "call_function": @@ -151,7 +150,7 @@ def get_model_info_str(gm: torch.fx.GraphModule, header: Optional[str] = None): # Sort and print all the other ops. Sort so it's deterministic between runs and # easier to parse. - pretty_ops_and_counts: List[Tuple[str, int]] = [] + pretty_ops_and_counts: list[tuple[str, int]] = [] for op, count in ops_and_counts.items(): pretty_ops_and_counts.append((_get_qualified_name(op), count)) pretty_ops_and_counts.sort() diff --git a/fx2ait/fx2ait/acc_tracer/ait_acc_ops_registry.py b/fx2ait/fx2ait/acc_tracer/ait_acc_ops_registry.py index 6417d8431..244e630ed 100644 --- a/fx2ait/fx2ait/acc_tracer/ait_acc_ops_registry.py +++ b/fx2ait/fx2ait/acc_tracer/ait_acc_ops_registry.py @@ -12,55 +12,52 @@ # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/fx2ait/fx2ait/acc_tracer/ait_acc_ops_registry.py b/fx2ait/fx2ait/acc_tracer/ait_acc_ops_registry.py
index 6417d8431..244e630ed 100644
--- a/fx2ait/fx2ait/acc_tracer/ait_acc_ops_registry.py
+++ b/fx2ait/fx2ait/acc_tracer/ait_acc_ops_registry.py
@@ -12,55 +12,52 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
+from collections.abc import Callable
+from typing import NamedTuple
 
 
 class AitAccOpMapper(NamedTuple):
     new_fn_target: Callable
-    arg_replacement_tuples: Optional[
-        List[
-            Union[
-                Tuple[Union[str, Tuple[str, ...]], str],
-                Tuple[Union[str, Tuple[str, ...]], str, bool],
-            ]
-        ]
-    ]
-    kwargs_to_move_to_acc_out_ty: Optional[
-        List[Union[Tuple[str, str, bool], Tuple[str, str]]]
-    ]
+    arg_replacement_tuples: (
+        None
+        | (
+            list[
+                (
+                    tuple[str | tuple[str, ...], str]
+                    | tuple[str | tuple[str, ...], str, bool]
+                )
+            ]
+        )
+    )
+    kwargs_to_move_to_acc_out_ty: None | (list[tuple[str, str, bool] | tuple[str, str]])
 
 
 class CustomAitAccOpMapper(NamedTuple):
     custom_mapping_fn: Callable
-    arg_replacement_tuples: List[
-        Union[
-            Tuple[Union[str, Tuple[str, ...]], str],
-            Tuple[Union[str, Tuple[str, ...]], str, bool],
-        ]
-    ]
+    arg_replacement_tuples: list[
+        (tuple[str | tuple[str, ...], str] | tuple[str | tuple[str, ...], str, bool])
+    ]
     needs_shapes_for_normalization: bool
     allow_normalize_from_torch_package: bool
 
 
-_AIT_ACC_OP_MAPPERS: Dict[Tuple[str, Union[str, Callable]], AitAccOpMapper] = {}
-_CUSTOM_AIT_ACC_OP_MAPPERS: Dict[
-    Tuple[str, Union[str, Callable]], CustomAitAccOpMapper
-] = {}
+_AIT_ACC_OP_MAPPERS: dict[tuple[str, str | Callable], AitAccOpMapper] = {}
+_CUSTOM_AIT_ACC_OP_MAPPERS: dict[tuple[str, str | Callable], CustomAitAccOpMapper] = {}
 
 
 def ait_register_acc_op_mapping(
-    op_and_target: Tuple[str, Union[str, Callable]],
-    arg_replacement_tuples: Optional[
-        List[
-            Union[
-                Tuple[Union[str, Tuple[str, ...]], str],
-                Tuple[Union[str, Tuple[str, ...]], str, bool],
-            ]
-        ]
-    ] = None,
-    kwargs_to_move_to_acc_out_ty: Optional[
-        List[Union[Tuple[str, str, bool], Tuple[str, str]]]
-    ] = None,
+    op_and_target: tuple[str, str | Callable],
+    arg_replacement_tuples: None
+    | (
+        list[
+            (
+                tuple[str | tuple[str, ...], str]
+                | tuple[str | tuple[str, ...], str, bool]
+            )
+        ]
+    ) = None,
+    kwargs_to_move_to_acc_out_ty: None
+    | (list[tuple[str, str, bool] | tuple[str, str]]) = None,
 ):
     def insert(new_fn_target: Callable):
         _AIT_ACC_OP_MAPPERS[op_and_target] = AitAccOpMapper(
@@ -74,12 +71,9 @@ def insert(new_fn_target: Callable):
 
 
 def ait_register_custom_acc_mapper_fn(
-    op_and_target: Tuple[str, Union[str, Callable]],
-    arg_replacement_tuples: List[
-        Union[
-            Tuple[Union[str, Tuple[str, ...]], str],
-            Tuple[Union[str, Tuple[str, ...]], str, bool],
-        ]
-    ],
+    op_and_target: tuple[str, str | Callable],
+    arg_replacement_tuples: list[
+        (tuple[str | tuple[str, ...], str] | tuple[str | tuple[str, ...], str, bool])
+    ],
     needs_shapes_for_normalization=False,
     allow_normalize_from_torch_package=False,
@@ -96,11 +90,11 @@ def insert(custom_mapping_fn: Callable):
     return insert
 
 
-def get_ait_acc_op_mappers() -> Dict[Tuple[str, Union[str, Callable]], AitAccOpMapper]:
+def get_ait_acc_op_mappers() -> dict[tuple[str, str | Callable], AitAccOpMapper]:
     return _AIT_ACC_OP_MAPPERS
 
 
 def get_custom_ait_acc_op_mappers() -> (
-    Dict[Tuple[str, Union[str, Callable]], CustomAitAccOpMapper]
+    dict[tuple[str, str | Callable], CustomAitAccOpMapper]
 ):
     return _CUSTOM_AIT_ACC_OP_MAPPERS
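The registry rewrites above show the main readability cost of a mechanical `Optional[X]` → `None | (X)` conversion: the deeply nested replacement-tuple types. A possible follow-up (not in this diff) is naming the union once as a module-level alias; note a runtime alias like this needs Python 3.10+, since `from __future__ import annotations` does not postpone ordinary assignments:

    ArgReplacement = (
        tuple[str | tuple[str, ...], str]
        | tuple[str | tuple[str, ...], str, bool]
    )


    def register(arg_replacement_tuples: list[ArgReplacement] | None = None) -> None: ...

`ArgReplacement` and `register` are illustrative names, not code from the repository.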
diff --git a/fx2ait/fx2ait/ait_module.py b/fx2ait/fx2ait/ait_module.py
index 9f83fe937..1046fb0a8 100644
--- a/fx2ait/fx2ait/ait_module.py
+++ b/fx2ait/fx2ait/ait_module.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import List
 
 import torch
@@ -25,7 +24,7 @@ def __init__(
         engine=None,
         interp_result=None,
     ):
-        super(AITModule, self).__init__()
+        super().__init__()
         self.engine = engine
         self.interp_result = interp_result
 
@@ -65,7 +64,7 @@ def forward(self, *args, **kwargs):
         return tuple(outputs)
 
     def profile(
-        self, inputs: List[torch.Tensor], filename: str, num_iters: int
+        self, inputs: list[torch.Tensor], filename: str, num_iters: int
     ) -> None:
         """
         Profile the AIT module and save the report to a file. The AITModule
diff --git a/fx2ait/fx2ait/ait_splitter.py b/fx2ait/fx2ait/ait_splitter.py
index 16244b777..c9097c7fb 100644
--- a/fx2ait/fx2ait/ait_splitter.py
+++ b/fx2ait/fx2ait/ait_splitter.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 #
 import logging
-from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Set
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any
 
 import torch
 import torch.fx.passes.operator_support as ops
@@ -72,7 +73,7 @@ def create_ait_operator_support(
     """Creates an `OperatorSupportBase` instance used for AIT splitting purpose."""
     # Create an `OperatorSupport` that declares a node supported if it
     # finds a registered AIT converter.
-    support_dict: Dict[str, None] = {}
+    support_dict: dict[str, None] = {}
     for k in AIT_CONVERTERS.keys():
         # may need to switch the op name here
         support_dict[get_acc_ops_name(k)] = None
@@ -126,7 +127,7 @@ def __init__(
 
 
 class SelectedOperatorSupport(ops.OperatorSupportBase):
-    def __init__(self, selected_nodes: Set[torch.fx.Node]) -> None:
+    def __init__(self, selected_nodes: set[torch.fx.Node]) -> None:
         self.selected_nodes = selected_nodes
 
     def is_node_supported(
@@ -148,7 +149,7 @@ def _range_operator_support(
     for i, n in enumerate(module.graph.nodes):
         logger.info(f"Index:{i}, n.op={n.op}, n.target={n.target}, n.name={n.name}")
 
-    selected_nodes: Set[torch.fx.Node] = set()
+    selected_nodes: set[torch.fx.Node] = set()
     for i, n in enumerate(module.graph.nodes):
         if i >= start and i <= end:
             if n.op in CALLABLE_NODE_OPS:
@@ -167,8 +168,8 @@ def __init__(
         self,
         module: torch.fx.GraphModule,
         sample_input: Sequence[Any],
-        operator_support: Optional[ops.OperatorSupportBase] = None,
-        settings: Optional[AITSplitterSettings] = None,
+        operator_support: ops.OperatorSupportBase | None = None,
+        settings: AITSplitterSettings | None = None,
     ):
         if not settings:
             settings = AITSplitterSettings()
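The converter diff below keeps `from typing import Union` for exactly one reason: `ConverterOutput` is an ordinary assignment, evaluated eagerly at import time, so it cannot use the `|` operator on older interpreters even with postponed annotation evaluation. A minimal sketch of the asymmetry:

    from __future__ import annotations  # affects annotations, not assignments

    from typing import Union

    IntOrStr = Union[int, str]  # `int | str` here raises TypeError before 3.10


    def f(x: IntOrStr) -> int | str:  # the annotation may freely use `|`
        return x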
diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py
index 8149872ea..0bf295a41 100644
--- a/fx2ait/fx2ait/converters/ait_converters.py
+++ b/fx2ait/fx2ait/converters/ait_converters.py
@@ -15,7 +15,8 @@
 import logging
 import math
 import operator
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Union
 
 import torch
@@ -97,14 +98,14 @@
 )
 
 logger: logging.Logger = logging.getLogger(__name__)
-ConverterOutput = Union[AITTensor, Tuple[AITTensor, ...], List[IntVar], IntVar]
+ConverterOutput = Union[AITTensor, tuple[AITTensor, ...], list[IntVar], IntVar]
 
 
 @ait_converter(acc_ops.sigmoid)
 def acc_ops_sigmoid(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -116,8 +117,8 @@ def acc_ops_mul(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     return create_binary_op(FuncEnum.MUL, args, kwargs, name)
@@ -126,8 +127,8 @@ def acc_ops_mul(
 @ait_converter(acc_ops.square)
 def acc_ops_square(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     new_kwargs = dict(kwargs.copy())
@@ -138,8 +139,8 @@ def acc_ops_div(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     return create_binary_op(FuncEnum.DIV, args, kwargs, name)
@@ -148,8 +149,8 @@ def acc_ops_floor_div(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     return create_binary_op(FuncEnum.FLOOR_DIV, args, kwargs, name)
@@ -158,8 +159,8 @@ def acc_ops_floor(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -169,8 +170,8 @@ def acc_ops_reciprocal(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -180,8 +181,8 @@ def acc_ops_add(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     return create_binary_op(FuncEnum.ADD, args, kwargs, name)
@@ -190,8 +191,8 @@ def acc_ops_sub(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     return create_binary_op(FuncEnum.SUB, args, kwargs, name)
@@ -200,8 +201,8 @@ def acc_ops_tanh(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -211,8 +212,8 @@ def acc_ops_sin(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -222,8 +223,8 @@ def acc_ops_cos(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -233,8 +234,8 @@ def acc_ops_sqrt(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -244,8 +245,8 @@ def acc_ops_clone(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -261,8 +262,8 @@ def acc_ops_sum(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     return create_reduce_op(reduce_sum, args, kwargs, name)
@@ -271,8 +272,8 @@ def acc_ops_mean(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     return create_reduce_op(reduce_mean, args, kwargs, name)
@@ -281,8 +282,8 @@ def acc_ops_amax(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     return create_reduce_op(reduce_max, args, kwargs, name)
@@ -291,8 +292,8 @@ def acc_ops_amin(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     return create_reduce_op(reduce_min, args, kwargs, name)
@@ -301,8 +302,8 @@ def acc_ops_linear(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -323,8 +324,8 @@ def acc_ops_unsqueeze(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -341,8 +342,8 @@ def acc_ops_clamp(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -358,8 +359,8 @@ def acc_ops_linalg_norm(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -392,8 +393,8 @@ def acc_ops_permute(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -418,8 +419,8 @@ def acc_ops_cat(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     tensors = kwargs["tensors"]
@@ -437,8 +438,8 @@ def acc_ops_sign(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -451,8 +452,8 @@ def acc_ops_abs(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -465,8 +466,8 @@ def acc_ops_exp(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -479,8 +480,8 @@ def acc_ops_log(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -493,8 +494,8 @@ def acc_ops_log1p(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -506,7 +507,7 @@ def acc_ops_var(
-    target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
+    target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str
 ) -> ConverterOutput:
     input_val = kwargs["input"]
     if not isinstance(input_val, AITTensor):
@@ -524,8 +525,8 @@ def acc_ops_softmax(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -537,7 +538,7 @@ def acc_ops_relu(
-    target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
+    target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str
 ) -> ConverterOutput:
     input_val = kwargs["input"]
     if not isinstance(input_val, AITTensor):
@@ -548,7 +549,7 @@ def acc_ops_elu(
-    target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
+    target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str
 ) -> ConverterOutput:
     input_val = kwargs["input"]
     if not isinstance(input_val, AITTensor):
@@ -573,7 +574,7 @@ def acc_ops_leaky_relu(
-    target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
+    target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str
 ) -> ConverterOutput:
     input_val = kwargs["input"]
     if not isinstance(input_val, AITTensor):
@@ -584,7 +585,7 @@ def acc_ops_squeeze(
-    target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
+    target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str
 ) -> ConverterOutput:
     input_val = kwargs["input"]
     if not isinstance(input_val, AITTensor):
@@ -598,8 +599,8 @@ def acc_ops_size(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -617,8 +618,8 @@ def acc_ops_unbind(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -646,8 +647,8 @@ def acc_ops_getitem(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     # operator.getitem does not have kwargs. We copy args to kwargs so the downstream like acc_ops_slice can use it.
@@ -726,8 +727,8 @@ def acc_ops_slice(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -806,8 +807,8 @@ def num_slice_types(slices):
 @ait_converter(acc_ops.reshape)
 def acc_ops_reshape(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -828,8 +829,8 @@ def acc_ops_topk(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -864,8 +865,8 @@ def acc_ops_tuple_construct(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     tensors = kwargs["tensors"]
@@ -875,8 +876,8 @@ def acc_ops_conv_transpose2d(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     output_padding = identical_elem_tuple_to_int(kwargs["output_padding"])
@@ -991,8 +992,8 @@ def make_slice(x, slice_idx, name):
 @ait_converter(acc_ops.nan_to_num)
 def acc_ops_nan_to_num(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -1030,8 +1031,8 @@ def _get_dtype(dtype: str):
 @ait_converter(acc_ops.group_norm)
 def acc_ops_group_norm(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -1051,8 +1052,8 @@ def acc_ops_layer_norm(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -1079,8 +1080,8 @@ def acc_ops_flatten(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = kwargs["input"]
@@ -1096,8 +1097,8 @@ def acc_ops_matmul(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     lhs = kwargs["input"]
@ait_converter(acc_ops.chunk) def acc_ops_chunk( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1184,8 +1185,8 @@ def acc_ops_chunk( @ait_converter(ait_acc_ops.split) def ait_acc_ops_split( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1211,8 +1212,8 @@ def ait_acc_ops_split( @ait_converter(acc_ops.expand) def ait_acc_ops_expand( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1242,8 +1243,8 @@ def _is_int_list(iterable): @ait_converter(acc_ops.interpolate) def ait_acc_ops_interpolate( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1268,8 +1269,8 @@ def ait_acc_ops_interpolate( @ait_converter(acc_ops.batch_norm) def acc_ops_batch_norm( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1374,8 +1375,8 @@ def _choose_conv2d_op( @ait_converter(acc_ops.conv2d) def acc_ops_conv2d( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = ait_nchw2nhwc(kwargs["input"]) @@ -1513,8 +1514,8 @@ def _choose_conv3d_op( @ait_converter(acc_ops.conv3d) def acc_ops_conv3d( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = ait_ncdhw2ndhwc(kwargs["input"]) @@ -1544,8 +1545,8 @@ def acc_ops_conv3d( @ait_converter(acc_ops.max_pool3d) def acc_ops_max_pool3d( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1610,8 +1611,8 @@ def acc_ops_max_pool3d( @ait_converter(acc_ops.max_pool2d) def acc_ops_max_pool2d( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = ait_nchw2nhwc(kwargs["input"]) @@ -1638,8 +1639,8 @@ def acc_ops_max_pool2d( @ait_converter(acc_ops.avg_pool2d) def acc_ops_avg_pool2d( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = ait_nchw2nhwc(kwargs["input"]) @@ -1667,8 +1668,8 @@ def acc_ops_avg_pool2d( @ait_converter(acc_ops.avg_pool3d) def acc_ops_avg_pool3d( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1715,8 +1716,8 @@ def acc_ops_avg_pool3d( @ait_converter(acc_ops.adaptive_avg_pool2d) def acc_ops_adaptive_avg_pool2d( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, 
...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = ait_nchw2nhwc(kwargs["input"]) @@ -1744,8 +1745,8 @@ def acc_ops_adaptive_avg_pool2d( @ait_converter(acc_ops.contiguous) def acc_ops_contiguous( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1756,8 +1757,8 @@ def acc_ops_contiguous( @ait_converter(acc_ops.dtype) def acc_ops_to_dtype( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: # We suppose to bypass this op but in extreme case like @@ -1767,8 +1768,8 @@ def acc_ops_to_dtype( input_val = kwargs["input"] def _get_cast_to_dtype_from_kwargs( - kwargs: Dict[str, Argument], - ) -> Optional[torch.dtype]: + kwargs: dict[str, Argument], + ) -> torch.dtype | None: torch_dtype_to_ait_dtype_str = { torch.float: "float32", torch.half: "float16", @@ -1803,8 +1804,8 @@ def _get_cast_to_dtype_from_kwargs( @ait_converter(acc_ops.gelu) def acc_ops_gelu( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1822,8 +1823,8 @@ def acc_ops_gelu( @ait_converter(acc_ops.pow) def acc_ops_pow( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1837,8 +1838,8 @@ def acc_ops_pow( @ait_converter(acc_ops.tile) def acc_ops_tile( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = kwargs["input"] @@ -1867,14 +1868,14 @@ def acc_ops_tile( @ait_converter(math.sqrt) def math_sqrt( - target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str ) -> ConverterOutput: return create_unary_op(FuncEnum.SQRT, args, kwargs, name) @ait_converter(acc_ops.neg) def acc_ops_neg( - target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str ) -> ConverterOutput: input_val = kwargs["input"] if not isinstance(input_val, AITTensor): @@ -1893,7 +1894,7 @@ def acc_ops_neg( @ait_converter(acc_ops.new_full) def acc_ops_new_full( - target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str ) -> ConverterOutput: input_val = kwargs["input"] if not isinstance(input_val, AITTensor): @@ -1910,7 +1911,7 @@ def acc_ops_new_full( @ait_converter(acc_ops.full_like) def acc_ops_full_like( - target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str ) -> ConverterOutput: input_val = kwargs["input"] if not isinstance(input_val, AITTensor): @@ -1921,7 +1922,7 @@ def acc_ops_full_like( @ait_converter(acc_ops.new_ones) def acc_ops_new_ones( - target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: 
str ) -> ConverterOutput: input_val = kwargs["input"] if not isinstance(input_val, AITTensor): @@ -1937,7 +1938,7 @@ def acc_ops_new_ones( @ait_converter(acc_ops.ones_like) def acc_ops_ones_like( - target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str ) -> ConverterOutput: input_val = kwargs["input"] if not isinstance(input_val, AITTensor): @@ -1947,7 +1948,7 @@ def acc_ops_ones_like( @ait_converter(acc_ops.new_zeros) def acc_ops_new_zeros( - target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str ) -> ConverterOutput: input_val = kwargs["input"] if not isinstance(input_val, AITTensor): @@ -1963,7 +1964,7 @@ def acc_ops_new_zeros( @ait_converter(acc_ops.zeros_like) def acc_ops_zeros_like( - target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str ) -> ConverterOutput: input_val = kwargs["input"] if not isinstance(input_val, AITTensor): @@ -1973,7 +1974,7 @@ def acc_ops_zeros_like( @ait_converter(acc_ops.masked_select) def acc_ops_masked_select( - target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str ) -> ConverterOutput: input_val = kwargs["input"] if not isinstance(input_val, AITTensor): @@ -1986,8 +1987,8 @@ def acc_ops_masked_select( @ait_converter(acc_ops.index_select) def acc_ops_index_select( target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], + args: tuple[Argument, ...], + kwargs: dict[str, Argument], name: str, ) -> ConverterOutput: input_val = args[0] if len(args) >= 1 else kwargs["input"] diff --git a/fx2ait/fx2ait/converters/ait_module_converters.py b/fx2ait/fx2ait/converters/ait_module_converters.py index 8a063804e..f83e6a624 100644 --- a/fx2ait/fx2ait/converters/ait_module_converters.py +++ b/fx2ait/fx2ait/converters/ait_module_converters.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
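Every hunk above makes the same two-line substitution in a converter signature. For reference, the rewritten pattern in isolation, as a runnable sketch that substitutes plain aliases for the torch.fx types (example_converter and the aliases are illustrative, not part of fx2ait):

    from typing import Any

    # Stand-ins for torch.fx.node.Target / torch.fx.node.Argument, so the
    # sketch is self-contained.
    Target = Any
    Argument = Any

    # Before: args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
    # After (PEP 585 built-in generics, usable on Python 3.9+):
    def example_converter(
        target: Target,
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        name: str,
    ) -> Any:
        # Like the converters above, read the operand from kwargs first and
        # fall back to positional args.
        return kwargs["input"] if "input" in kwargs else args[0]

No converter body changes; only the annotations do, so the rewrite is behavior-preserving.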
diff --git a/fx2ait/fx2ait/converters/ait_module_converters.py b/fx2ait/fx2ait/converters/ait_module_converters.py
index 8a063804e..f83e6a624 100644
--- a/fx2ait/fx2ait/converters/ait_module_converters.py
+++ b/fx2ait/fx2ait/converters/ait_module_converters.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, Tuple
+from typing import Any

 import numpy as np

@@ -32,8 +32,8 @@
 def multi_head_attention_module(
     target: Target,
     submod: Any,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     # TODO fix arg/kwargs matching
diff --git a/fx2ait/fx2ait/converters/aten2ait_converters.py b/fx2ait/fx2ait/converters/aten2ait_converters.py
index 473686f3a..32c146e0b 100644
--- a/fx2ait/fx2ait/converters/aten2ait_converters.py
+++ b/fx2ait/fx2ait/converters/aten2ait_converters.py
@@ -16,7 +16,7 @@
 import torch  # isort:skip
 import copy
 import operator
-from typing import Dict, List, Tuple, Union
+from typing import Union

 import numpy

@@ -76,7 +76,7 @@
 # Logging
 logger: logging.Logger = logging.getLogger(__name__)

-ConverterOutput = Union[AITTensor, Tuple[AITTensor, ...], List[IntVar], IntVar]
+ConverterOutput = Union[AITTensor, tuple[AITTensor, ...], list[IntVar], IntVar]


 ## make sure the functions are place in alphabetic order
@@ -84,8 +84,8 @@
 @ait_converter(torch.ops.aten._adaptive_avg_pool2d.default)
 def aten_ops_adaptive_avg_pool2d(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     # TODO: @qxy11 Update once NCHW supported
@@ -113,8 +113,8 @@ def aten_ops_adaptive_avg_pool2d(
 @ait_converter(torch.ops.aten.avg_pool2d.default)
 def aten_ops_avg_pool2d(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     # TODO: @qxy11 Update once NCHW supported
@@ -133,8 +133,8 @@ def aten_ops_avg_pool2d(
 @ait_converter(torch.ops.aten.batch_norm)
 def aten_ops_batch_norm(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     # TODO @qxy11: Update channels-last assumption once AIT backend is updated
@@ -163,8 +163,8 @@ def aten_ops_batch_norm(
 @ait_converter(torch.ops.aten.add.Tensor)
 def aten_binary_ops_add(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     kwargs = {
@@ -177,8 +177,8 @@ def aten_binary_ops_add(
 @ait_converter(torch.ops.aten.div.Tensor)
 def aten_binary_ops_div(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     kwargs = {
@@ -191,8 +191,8 @@ def aten_binary_ops_div(
 @ait_converter(torch.ops.aten.mul.Tensor)
 def aten_binary_ops_mul(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     kwargs = {
@@ -205,8 +205,8 @@ def aten_binary_ops_mul(
 @ait_converter(torch.ops.aten.sub.Tensor)
 def aten_binary_ops_sub(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     kwargs = {
@@ -219,8 +219,8 @@ def aten_binary_ops_sub(
 @ait_converter(torch.ops.aten.cat.default)
 def aten_ops_cat(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     tensors = args[0]
@@ -279,8 +279,8 @@ def _choose_conv2d_op(
 @ait_converter(torch.ops.aten.convolution.default)
 def aten_ops_conv2d(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     # TODO: qxy11: Update once channels-first format is supported
@@ -389,8 +389,8 @@ def make_slice(x, dim, start, end, step, name):
 @ait_converter(torch.ops.aten.clone.default)
 def aten_unary_ops_clone(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -402,8 +402,8 @@ def aten_unary_ops_clone(
 @ait_converter(torch.ops.aten.cos.default)
 def aten_unary_ops_cos(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -414,8 +414,8 @@ def aten_unary_ops_cos(
 @ait_converter(torch.ops.aten.chunk.default)
 def aten_ops_chunk(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -435,8 +435,8 @@ def aten_ops_chunk(
 @ait_converter(torch.ops.aten.expand.default)
 def aten_ops_expand(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     # TODO expand is not functional yet but only for cases with dim=-1
@@ -467,8 +467,8 @@ def _is_int_list(iterable):
 @ait_converter(aten_operator_getitem)
 def aten_ops_getitem(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -494,8 +494,8 @@ def aten_ops_getitem(
 @ait_converter(torch.ops.aten.layer_norm.default)
 def aten_ops_layer_norm(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -521,8 +521,8 @@ def aten_ops_layer_norm(
 @ait_converter(torch.ops.aten.linear)
 def aten_ops_linear(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -539,8 +539,8 @@ def aten_ops_linear(
 @ait_converter(torch.ops.aten.max_pool2d)
 def aten_ops_max_pool2d(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     # TODO: @qxy11 Update once NCHW supported
@@ -564,8 +564,8 @@ def aten_ops_max_pool2d(
 @ait_converter(torch.ops.aten.bmm.default)
 def aten_ops_matmul(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     if len(args) > 2:
@@ -604,8 +604,8 @@ def aten_ops_matmul(
 @ait_converter(torch.ops.aten.mean.dim)
 def aten_ops_mean(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -631,8 +631,8 @@ def aten_ops_mean(
 @ait_converter(torch.ops.aten.nan_to_num.default)
 def aten_ops_nan_to_num(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -657,8 +657,8 @@ def aten_ops_nan_to_num(
 @ait_converter(torch.ops.aten.split.Tensor)
 def aten_ops_split(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -680,7 +680,7 @@ def aten_ops_split(

 @ait_converter(torch.ops.aten.sym_numel)
 def aten_ops_numel(
-    target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
+    target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str
 ) -> ConverterOutput:
     input_val = args[0]
     if not isinstance(input_val, AITTensor):
@@ -696,8 +696,8 @@ def aten_ops_numel(
 @ait_converter(torch.ops.aten.permute.default)
 def aten_ops_permute(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -714,8 +714,8 @@ def aten_ops_permute(
 @ait_converter(torch.ops.aten.pow.Tensor_Scalar)
 def aten_ops_pow(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -728,7 +728,7 @@ def aten_ops_pow(

 @ait_converter(torch.ops.aten.relu.default)
 def aten_ops_relu(
-    target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
+    target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str
 ) -> ConverterOutput:
     input_val = args[0]
     if not isinstance(input_val, AITTensor):
@@ -740,7 +740,7 @@ def aten_ops_relu(
 @ait_converter(torch.ops.aten.reshape)
 @ait_converter(torch.ops.aten.view.default)
 def aten_ops_reshape(
-    target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
+    target: Target, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str
 ) -> ConverterOutput:
     input_val = args[0]
     if not isinstance(input_val, AITTensor):
@@ -774,8 +774,8 @@ def aten_ops_reshape(
 @ait_converter(torch.ops.aten.sym_size)
 def aten_ops_size(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -790,8 +790,8 @@ def aten_ops_size(
 @ait_converter(torch.ops.aten.select.int)
 def aten_ops_slice(  # noqa: C901
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -886,8 +886,8 @@ def num_slice_types(slices):
 @ait_converter(torch.ops.aten.squeeze.dim)
 def aten_ops_squeeze(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -904,8 +904,8 @@ def aten_ops_squeeze(
 @ait_converter(torch.ops.aten.sum.dim_IntList)
 def aten_ops_sum(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -932,8 +932,8 @@ def aten_ops_sum(
 @ait_converter(torch.ops.aten.hardtanh.default)
 def aten_ops_hardtanh(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -952,8 +952,8 @@ def aten_ops_hardtanh(
 @ait_converter(torch.ops.aten.t.default)
 def aten_ops_transpose(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     # TODO: we will also support https://pytorch.org/docs/stable/generated/torch.transpose.html in the future
@@ -971,8 +971,8 @@ def aten_ops_transpose(
 @ait_converter(torch.ops.aten.unsqueeze.default)
 def aten_ops_unsqueeze(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -990,8 +990,8 @@ def aten_ops_unsqueeze(
 @ait_converter(operator.mul)
 def operator_ops_mul(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1013,8 +1013,8 @@ def operator_ops_mul(
 @ait_converter(operator.add)
 def operator_ops_add(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1036,8 +1036,8 @@ def operator_ops_add(
 @ait_converter(operator.sub)
 def operator_ops_sub(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1059,8 +1059,8 @@ def operator_ops_sub(
 @ait_converter(operator.floordiv)
 def operator_ops_floordiv(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1082,8 +1082,8 @@ def operator_ops_floordiv(
 @ait_converter(torch.ops.aten.abs.default)
 def aten_unary_ops_abs(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1096,8 +1096,8 @@ def aten_unary_ops_abs(
 @ait_converter(torch.ops.aten.clamp.default)
 def aten_unary_ops_clamp(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1117,8 +1117,8 @@ def aten_unary_ops_clamp(
 @ait_converter(torch.ops.aten.log.default)
 def aten_unary_ops_log(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1131,8 +1131,8 @@ def aten_unary_ops_log(
 @ait_converter(torch.ops.aten.sigmoid.default)
 def aten_unary_ops_sigmoid(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1144,8 +1144,8 @@ def aten_unary_ops_sigmoid(
 @ait_converter(torch.ops.aten.sign.default)
 def aten_unary_ops_sign(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1158,8 +1158,8 @@ def aten_unary_ops_sign(
 @ait_converter(torch.ops.aten.sin.default)
 def aten_unary_ops_sin(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1169,8 +1169,8 @@ def aten_unary_ops_sin(
 @ait_converter(torch.ops.aten.sqrt.default)
 def aten_unary_ops_sqrt(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
@@ -1180,8 +1180,8 @@ def aten_unary_ops_sqrt(
 @ait_converter(torch.ops.aten.tanh.default)
 def aten_unary_ops_tanh(
     target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> ConverterOutput:
     input_val = args[0]
diff --git a/fx2ait/fx2ait/converters/converter_registry.py b/fx2ait/fx2ait/converters/converter_registry.py
index 11663efa0..ed52fdabc 100644
--- a/fx2ait/fx2ait/converters/converter_registry.py
+++ b/fx2ait/fx2ait/converters/converter_registry.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Callable, Dict
+from collections.abc import Callable
+from typing import Any

 from torch.fx.node import Target

-AIT_CONVERTERS: Dict[Target, Any] = {}
+AIT_CONVERTERS: dict[Target, Any] = {}


 def ait_converter(key: Target, enabled: bool = True) -> Callable[[Any], Any]:
diff --git a/fx2ait/fx2ait/converters/utils.py b/fx2ait/fx2ait/converters/utils.py
index df25d9092..bbbce734f 100644
--- a/fx2ait/fx2ait/converters/utils.py
+++ b/fx2ait/fx2ait/converters/utils.py
@@ -14,7 +14,8 @@
 #
 import math
 import operator
-from typing import Any, Callable, Dict, List, Tuple, Union
+from collections.abc import Callable
+from typing import Any

 from aitemplate.compiler.base import IntImm, IntVar, IntVarTensor

@@ -42,7 +43,7 @@ def get_positive_dim(dim: int, dim_size: int) -> int:


 def create_reduce_op(
-    op_type: Any, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
+    op_type: Any, args: tuple[Argument, ...], kwargs: dict[str, Argument], name: str
 ) -> AITTensor:
     input_val = kwargs["input"]
     # TODO: remove once multiple reduction axes are supported
@@ -66,8 +67,8 @@ def create_reduce_op(

 def create_binary_op(
     op_type: FuncEnum,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> AITTensor:
     lhs = kwargs["input"]
@@ -121,8 +122,8 @@ def create_binary_op(

 def create_unary_op(
     op_type: FuncEnum,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
+    args: tuple[Argument, ...],
+    kwargs: dict[str, Argument],
     name: str,
 ) -> AITTensor:
     input = kwargs["input"] if "input" in kwargs else args[0]
@@ -185,11 +186,11 @@ def identical_elem_tuple_to_int(param):
     return param[0]


-def nchw2nhwc(shape: List[Union[int, IntVar]]) -> List[Union[int, IntVar]]:
+def nchw2nhwc(shape: list[int | IntVar]) -> list[int | IntVar]:
     return [shape[0], shape[2], shape[3], shape[1]]


-def ncdhw2ndhwc(shape: List[Union[int, IntVar]]) -> List[Union[int, IntVar]]:
+def ncdhw2ndhwc(shape: list[int | IntVar]) -> list[int | IntVar]:
     return [shape[0], shape[2], shape[3], shape[4], shape[1]]
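converter_registry.py and utils.py also move Callable to collections.abc, its preferred home since Python 3.9 (the typing aliases are deprecated from 3.9 on). A toy decorator-based registry in the same shape as AIT_CONVERTERS / ait_converter, runnable on its own; REGISTRY, register, and to_channels_last are illustrative names, not the fx2ait API:

    from collections.abc import Callable
    from typing import Any

    REGISTRY: dict[str, Callable[..., Any]] = {}

    def register(key: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        # Same decorator idiom as ait_converter: record the function under
        # its key, then hand it back unchanged.
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            REGISTRY[key] = fn
            return fn
        return decorator

    @register("nchw2nhwc")
    def to_channels_last(shape: list[int]) -> list[int]:
        # Mirrors the nchw2nhwc helper in utils.py above.
        return [shape[0], shape[2], shape[3], shape[1]]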
diff --git a/fx2ait/fx2ait/example/benchmark_utils.py b/fx2ait/fx2ait/example/benchmark_utils.py
index f2d308eb5..d995bf4ec 100644
--- a/fx2ait/fx2ait/example/benchmark_utils.py
+++ b/fx2ait/fx2ait/example/benchmark_utils.py
@@ -15,7 +15,6 @@
 import time
 import uuid
-from typing import List, Optional

 import torch
 from fx2ait.acc_tracer import acc_tracer
@@ -27,11 +26,11 @@

 def verify_accuracy(
     mod: torch.nn.Module,
-    inputs: List[torch.Tensor],
+    inputs: list[torch.Tensor],
     rtol: float = 1e-01,
     atol: float = 1e-01,
-    permute_inputs: Optional[List[int]] = None,
-    permute_outputs: Optional[List[int]] = None,
+    permute_inputs: list[int] | None = None,
+    permute_outputs: list[int] | None = None,
 ):
     # TODO: add precision to interpreter once AIT supports multiple precision level
     # TODO: @qxy11 remove permute options once AIT supports channels-first format
@@ -110,8 +109,8 @@ def benchmark_function(
     name: str,
     iters: int,
     mod: torch.nn.Module,
-    inputs: List[torch.Tensor],
-    permute_inputs: Optional[List[int]] = None,
+    inputs: list[torch.Tensor],
+    permute_inputs: list[int] | None = None,
     ait_mod: torch.nn.Module = None,
 ) -> float:
     mod.eval()
diff --git a/fx2ait/fx2ait/fx2ait.py b/fx2ait/fx2ait/fx2ait.py
index 76d06de00..6c1bee57f 100644
--- a/fx2ait/fx2ait/fx2ait.py
+++ b/fx2ait/fx2ait/fx2ait.py
@@ -17,8 +17,9 @@
 import os
 import tempfile
 import warnings
+from collections.abc import Sequence
 from datetime import datetime
-from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Union
+from typing import Any, NamedTuple

 import fx2ait.cache as cache

@@ -57,18 +58,18 @@ class AITInterpreter(torch.fx.Interpreter):
     def __init__(
         self,
         module: torch.fx.GraphModule,
-        input_specs: List[Union[TensorSpec, List[TensorSpec]]],
+        input_specs: list[TensorSpec | list[TensorSpec]],
         workdir: str,
         name: str,
         dll_name: str = "test.so",
         dynamic_profile_strategy=DynamicProfileStrategy.MAX,
         profile_devs=None,
         use_fp16_acc=True,
-        dump_ait_dir: Optional[str] = None,
-        keep_constants: Optional[bool] = None,
-        load_ait_dir: Optional[str] = None,
-        remote_cache_file_path: Optional[str] = None,
-        save_remote_cache: Optional[bool] = False,
+        dump_ait_dir: str | None = None,
+        keep_constants: bool | None = None,
+        load_ait_dir: str | None = None,
+        remote_cache_file_path: str | None = None,
+        save_remote_cache: bool | None = False,
         do_optimize_graph: bool = True,
         use_fast_math: bool = True,
         use_tanh_for_sigmoid: bool = False,
@@ -135,12 +136,12 @@ def __init__(
         self.dynamic_profile_strategy = dynamic_profile_strategy
         self.profile_devs = profile_devs

-        self._input_names: List[str] = []
-        self._input_dtypes: List[str] = []
-        self._output_names: List[str] = []
-        self._output_dtypes: List[str] = []
-        self._fx_input_names: List[str] = []
-        self._loaded_params: Dict[str, AITTensor] = {}
+        self._input_names: list[str] = []
+        self._input_dtypes: list[str] = []
+        self._output_names: list[str] = []
+        self._output_dtypes: list[str] = []
+        self._fx_input_names: list[str] = []
+        self._loaded_params: dict[str, AITTensor] = {}

         self.dump_ait_dir = dump_ait_dir
         self.keep_constants = keep_constants
@@ -280,7 +281,7 @@ def run(self) -> AITInterpreterResult:
             self._input_dtypes.append(ait_input_dtypes[name])

         for i, input_name in enumerate(self._fx_input_names):
-            _LOGGER.info("Set input{}: {}".format(i, input_name))
+            _LOGGER.info(f"Set input{i}: {input_name}")

         if self.engine is None:
             raise RuntimeError("Engine is missing!")
@@ -304,7 +305,7 @@ def run_node(self, n):
     def placeholder(self, target, args, kwargs):
         input_spec = self.input_specs[self.input_specs_iter]
         self.input_specs_iter += 1
-        if isinstance(input_spec, List):
+        if isinstance(input_spec, list):
             """
             List[Tensor] inputs are flattened in the compiled AIT engine.
             Pytorch module original forward:
diff --git a/fx2ait/fx2ait/lower/lower.py b/fx2ait/fx2ait/lower/lower.py
index 42e734af4..920c953b4 100644
--- a/fx2ait/fx2ait/lower/lower.py
+++ b/fx2ait/fx2ait/lower/lower.py
@@ -16,7 +16,8 @@
 import datetime
 import logging
 import operator
-from typing import Any, Callable, List, Optional, Sequence
+from collections.abc import Callable, Sequence
+from typing import Any

 import fx2ait.acc_tracer.acc_tracer as acc_tracer

@@ -58,7 +59,7 @@ def __call__(
         self,
         module_name: str,
         mod: fx.GraphModule,
-        inputs: List[torch.Tensor],
+        inputs: list[torch.Tensor],
     ) -> AITInterpreterResult:
         (additional_inputs,) = self.lower_settings.additional_inputs
         if additional_inputs is None:
@@ -191,7 +192,7 @@ def create(
         )

     def lower_func(
-        self, split_result: SplitResult, additional_inputs: Optional[Input] = None
+        self, split_result: SplitResult, additional_inputs: Input | None = None
     ) -> nn.Module:
         if additional_inputs:
             additional_submodule_inputs = generate_inputs_for_submodules(
@@ -231,7 +232,7 @@ def __call__(
         self,
         module: nn.Module,
         inputs: Input,
-        additional_inputs: Optional[Input] = None,
+        additional_inputs: Input | None = None,
     ) -> nn.Module:
         module.eval()
         module = acc_tracer.trace(
@@ -244,8 +245,8 @@ def __call__(


 def _precision_to_torch_type(
-    precision: Optional[LowerPrecision],
-) -> Optional[torch.dtype]:
+    precision: LowerPrecision | None,
+) -> torch.dtype | None:
     if precision == LowerPrecision.FP16:
         return torch.float16
     elif precision == LowerPrecision.FP32:
diff --git a/fx2ait/fx2ait/lower/lower_settings.py b/fx2ait/fx2ait/lower/lower_settings.py
index 340a0e00e..96e3a4a35 100644
--- a/fx2ait/fx2ait/lower/lower_settings.py
+++ b/fx2ait/fx2ait/lower/lower_settings.py
@@ -14,7 +14,7 @@
 #
 import dataclasses as dc
 from enum import Enum
-from typing import Any, List, Optional, Set, Type
+from typing import Any

 import torch

@@ -71,20 +71,20 @@ class LowerSettings:
     dynamic_size: int = -1
     profile_devs: Any = None
     # If None, infer the dtypes from the sample inputs.
-    precision: Optional[LowerPrecision] = LowerPrecision.FP16
+    precision: LowerPrecision | None = LowerPrecision.FP16
     use_fp16_acc: bool = True  # only valid for precision == FP16
     use_fast_math: bool = True  # Whether to use fast math in CUDA kernels
     allow_int_inputs: bool = False  # If AIT acc subgraph accept integer inputs
-    ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None
-    leaf_module_list: Optional[Set[Type[nn.Module]]] = None
+    ast_rewriter_allow_list: set[type[nn.Module]] | None = None
+    leaf_module_list: set[type[nn.Module]] | None = None
     # If None, infer the dtypes from the sample inputs.
-    output_precision: Optional[LowerPrecision] = LowerPrecision.FP16
-    additional_inputs: Optional[List[torch.Tensor]] = None
-    remote_cache_file_path: Optional[str] = None
-    save_remote_cache: Optional[bool] = None
-    dump_ait_dir: Optional[str] = None
-    keep_constants: Optional[bool] = None
-    load_ait_dir: Optional[str] = None
+    output_precision: LowerPrecision | None = LowerPrecision.FP16
+    additional_inputs: list[torch.Tensor] | None = None
+    remote_cache_file_path: str | None = None
+    save_remote_cache: bool | None = None
+    dump_ait_dir: str | None = None
+    keep_constants: bool | None = None
+    load_ait_dir: str | None = None
     # jit.trace AITModule
     trace_ait_module: bool = True
     # If True, optimize for compilation time (ie. compile w/ -O1 rather than -O3 and skip profiling codegen)
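lower_settings.py applies the same Optional[X] → X | None rewrite to evaluated dataclass defaults. Worth noting: in a dataclass the annotation is evaluated at class-definition time, so the X | None spelling needs Python 3.10+ at runtime, or from __future__ import annotations on 3.9. A minimal sketch of the pattern with hypothetical field names:

    import dataclasses as dc

    @dc.dataclass
    class ExampleSettings:
        # PEP 604 unions instead of Optional[...]; evaluated at runtime,
        # hence the Python 3.10+ floor without the __future__ import.
        precision: str | None = None
        additional_inputs: list[int] | None = None
        save_remote_cache: bool | None = None

    settings = ExampleSettings(precision="fp16")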
diff --git a/fx2ait/fx2ait/tensor_spec.py b/fx2ait/fx2ait/tensor_spec.py
index 4a1977f41..a622f4a11 100644
--- a/fx2ait/fx2ait/tensor_spec.py
+++ b/fx2ait/fx2ait/tensor_spec.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 #
 import logging
-from typing import Any, Dict, List, Optional, Set, Union
+from typing import Any

 import torch
 from aitemplate.compiler.public import IntImm, IntVar
@@ -24,7 +24,7 @@


 class TensorSpec:
-    def __init__(self, shape: List[IntVar], dtype: torch.dtype) -> None:
+    def __init__(self, shape: list[IntVar], dtype: torch.dtype) -> None:
         self.shape = shape
         self.dtype = dtype

@@ -51,9 +51,9 @@ def __repr__(self) -> str:
     @classmethod
     def from_two_input_lists(
         cls,
-        inputs1: List[Union[torch.Tensor, List[torch.Tensor]]],
-        inputs2: List[Union[torch.Tensor, List[torch.Tensor]]],
-    ) -> List["TensorSpec"]:
+        inputs1: list[torch.Tensor | list[torch.Tensor]],
+        inputs2: list[torch.Tensor | list[torch.Tensor]],
+    ) -> list["TensorSpec"]:
         """
         This function is useful when we expect multiple dynamic dims.
@@ -74,7 +74,7 @@ def from_two_input_lists(
                 f"Different number of inputs: {len(inputs1)} vs {len(inputs2)}"
             )

-        result: List[TensorSpec] = []
+        result: list[TensorSpec] = []

         for t1, t2 in zip(inputs1, inputs2):
             if isinstance(t1, list):
@@ -86,7 +86,7 @@
                 raise ValueError(
                     f"Different tensor sizes: {len(t1.shape)} vs {len(t2.shape)}"
                 )
-            shape: List[IntVar] = []
+            shape: list[IntVar] = []
             for i, (d1, d2) in enumerate(zip(t1.shape, t2.shape)):
                 if d1 == d2:
                     shape.append(IntImm(d1))
@@ -98,8 +98,8 @@

     @classmethod
     def from_two_input_lists_jagged_tensor(
-        cls, inputs1: List[torch.Tensor], inputs2: List[torch.Tensor]
-    ) -> List["TensorSpec"]:
+        cls, inputs1: list[torch.Tensor], inputs2: list[torch.Tensor]
+    ) -> list["TensorSpec"]:
         """
         This function is useful when we expect multiple dynamic dims.
@@ -120,7 +120,7 @@ def from_two_input_lists_jagged_tensor(
             f"Different number of inputs: {len(inputs1)} vs {len(inputs2)}"
         )

-        result: List[TensorSpec] = []
+        result: list[TensorSpec] = []
         dynamic_dict = {}
         num_dynamic = 0
         for t1, t2 in zip(inputs1, inputs2):
@@ -130,7 +130,7 @@
                 raise ValueError(
                     f"Different tensor sizes: {len(t1.shape)} vs {len(t2.shape)}"
                 )
-            shape: List[IntVar] = []
+            shape: list[IntVar] = []
             for _, (d1, d2) in enumerate(zip(t1.shape, t2.shape)):
                 if d1 == d2:
                     shape.append(IntImm(d1))
@@ -168,10 +168,10 @@ def gen_int_var_min_max(cls, vmin: int, vmax: int, name: str = None):  # noqa [B
         elif vmin < vmax:
             return IntVar([vmin, vmax], name=name)
         else:
-            raise RuntimeError("Unsupported int var definition: {}".format(values))
+            raise RuntimeError(f"Unsupported int var definition: vmin={vmin}, vmax={vmax}")

     @classmethod
-    def create_spec_from_int_vars(cls, int_vars: List[IntVar], dtype_list: torch.dtype):
+    def create_spec_from_int_vars(cls, int_vars: list[IntVar], dtype_list: torch.dtype):
         if len(int_vars) != len(dtype_list):
             raise ValueError(
                 f"Different number of int_var and dtype_list: {len(int_vars)} vs {len(dtype_list)}"
@@ -183,8 +183,8 @@ def create_spec_from_int_vars(cls, int_vars: List[IntVar], dtype_list: torch.dty

     @classmethod
     def create_spec_from_shapes(
-        cls, inputs_min: List[int], inputs_max: List[int], dtype_list: torch.dtype
-    ) -> List["TensorSpec"]:
+        cls, inputs_min: list[int], inputs_max: list[int], dtype_list: torch.dtype
+    ) -> list["TensorSpec"]:
         if len(inputs_min) != len(inputs_max):
             raise ValueError(
                 f"Different number of inputs: {len(inputs_min)} vs {len(inputs_max)}"
@@ -196,7 +196,7 @@ def create_spec_from_shapes(
                 f"Different number of input dims: {len(shape1)} vs {len(shape2)}"
             )

-            shape: List[IntVar] = []
+            shape: list[IntVar] = []
             for i, (d1, d2) in enumerate(zip(shape1, shape2)):
                 if d1 == d2:
                     shape.append(IntImm(d1))
@@ -225,7 +225,7 @@ def to_specific_tensor(self, use_lower_bound, specify_num):

     @classmethod
     def create_inputs_from_specs(
-        cls, input_specs: List["TensorSpec"], use_lower_bound: bool, specify_num=None
+        cls, input_specs: list["TensorSpec"], use_lower_bound: bool, specify_num=None
     ) -> torch.Tensor:
         result = []
         for inp in input_specs:
@@ -240,19 +240,19 @@ def create_inputs_from_specs(

     @classmethod
     def from_input_list_with_batch_size(
-        cls, inputs: List[torch.Tensor], max_batch_size: int, batch_dim: int = 0
-    ) -> List["TensorSpec"]:
+        cls, inputs: list[torch.Tensor], max_batch_size: int, batch_dim: int = 0
+    ) -> list["TensorSpec"]:
         """
         Most of the recommendation models will work fine using this function.

         We make an assumption that inferred lowerable subgraph inputs will
         have a single batch dimension with the same max batch size.
         """
-        result: List[TensorSpec] = []
+        result: list[TensorSpec] = []

         bs_dim = cls.find_batch_size_dim(inputs)
         for index, t in enumerate(inputs):
-            shape: List[IntVar] = []
+            shape: list[IntVar] = []
             for i, d in enumerate(t.shape):
                 if i == bs_dim[index]:
                     shape.append(IntVar([1, max_batch_size], "batch_size"))
@@ -264,9 +264,9 @@ def from_input_list_with_batch_size(

     @staticmethod
     def _get_max_seq_lens_from_offsets(
-        inputs: List[torch.Tensor],
-        jagged_offsets_batch_dims: Set[int],
-    ) -> Dict[int, int]:
+        inputs: list[torch.Tensor],
+        jagged_offsets_batch_dims: set[int],
+    ) -> dict[int, int]:
         """
         Get the maximum sequence length encoded in each offsets tensor.
@@ -311,10 +311,10 @@ def _get_max_seq_lens_from_offsets(
     @classmethod
     def try_getting_jagged_tensor_map(
         cls,
-        inputs: List[torch.Tensor],
-        jagged_tensor_batch_dims: Set[int],
-        fx_inputs: Optional[List[torch.fx.Node]] = None,
-    ) -> Optional[Dict[int, int]]:
+        inputs: list[torch.Tensor],
+        jagged_tensor_batch_dims: set[int],
+        fx_inputs: list[torch.fx.Node] | None = None,
+    ) -> dict[int, int] | None:
         """
         Try getting a map associating each jagged tensor input with the offsets.
@@ -366,16 +366,16 @@ def try_getting_jagged_tensor_map(
     @classmethod
     def from_input_list_with_batch_size_jagged_tensor(
         cls,
-        inputs: List[torch.Tensor],
+        inputs: list[torch.Tensor],
         max_batch_size: int,
         max_sequence_length: int,
-        jagged_tensor_batch_dims: Set[int],
-        jagged_offsets_batch_dims: Set[int],
-        additional_inputs: List[torch.Tensor] = None,
+        jagged_tensor_batch_dims: set[int],
+        jagged_offsets_batch_dims: set[int],
+        additional_inputs: list[torch.Tensor] = None,
         infer_max_seq_lens_from_offsets: bool = False,
-        fx_inputs: List[torch.fx.Node] = None,
-        jagged_tensor_map: Optional[Dict[int, int]] = None,
-    ) -> List["TensorSpec"]:
+        fx_inputs: list[torch.fx.Node] = None,
+        jagged_tensor_map: dict[int, int] | None = None,
+    ) -> list["TensorSpec"]:
         """
         Most of the recommendation models will work fine using this function.
@@ -389,11 +389,11 @@ def from_input_list_with_batch_size_jagged_tensor(
             jagged_offsets_batch_dims=jagged_offsets_batch_dims,
         )

-        result: List = []
-        result_unsorted: List = []
-        left_inputs: List = []
-        left_inputs_ind: List = []
-        left_additional_inputs: List = []
+        result: list = []
+        result_unsorted: list = []
+        left_inputs: list = []
+        left_inputs_ind: list = []
+        left_additional_inputs: list = []
         for ind, t in enumerate(inputs):
             batch_dim: int = t.shape[0]
             batch_dim_lower_bound: int = 0
@@ -422,7 +422,7 @@ def from_input_list_with_batch_size_jagged_tensor(
                 batch_dim_name = f"batch_size_jagged_offsets_{batch_dim}"

             if batch_dim_upper_bound > 0:
-                shape: List[IntVar] = []
+                shape: list[IntVar] = []
                 for i, d in enumerate(t.shape):
                     if i == 0:
                         shape.append(
@@ -457,7 +457,7 @@ def from_input_list_with_batch_size_jagged_tensor(
         else:
             bs_dim = cls.find_batch_size_dim(left_inputs)
             for index, t in enumerate(left_inputs):
-                shape: List[IntVar] = []
+                shape: list[IntVar] = []
                 for i, d in enumerate(t.shape):
                     if i == bs_dim[index]:
                         shape.append(IntVar([1, max_batch_size], "batch_size"))
@@ -489,18 +489,18 @@ def find_batch_size_dim(

     @classmethod
     def from_input_list_with_batch_size_static_batch(
-        cls, inputs: List[torch.Tensor], max_batch_size: int, batch_dim: int = 0
-    ) -> List["TensorSpec"]:
+        cls, inputs: list[torch.Tensor], max_batch_size: int, batch_dim: int = 0
+    ) -> list["TensorSpec"]:
         """
         Most of the recommendation models will work fine using this function.

         We make an assumption that inferred lowerable subgraph inputs will
         have a single batch dimension with the same max batch size.
         """
-        result: List[TensorSpec] = []
+        result: list[TensorSpec] = []

         for t in inputs:
-            shape: List[IntVar] = []
+            shape: list[IntVar] = []
             for _, d in enumerate(t.shape):
                 shape.append(IntImm(d))
             result.append(TensorSpec(shape, t.dtype))
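The TensorSpec changes above all serve one shape-inference rule: compare two sample inputs dim by dim, keep equal dims static (IntImm) and widen differing dims into a range (IntVar). A condensed sketch of that rule using stand-in classes rather than the real aitemplate.compiler.public types; infer_shape is an illustrative name:

    class IntImm:  # stand-in for AITemplate's static dimension
        def __init__(self, value: int) -> None:
            self.value = value

    class IntVar:  # stand-in for AITemplate's dynamic dimension
        def __init__(self, bounds: list[int], name: str | None = None) -> None:
            self.bounds = bounds
            self.name = name

    def infer_shape(shape1: list[int], shape2: list[int]) -> list[IntImm | IntVar]:
        # Equal dims stay static; differing dims become a [min, max] range.
        return [
            IntImm(d1) if d1 == d2 else IntVar([min(d1, d2), max(d1, d2)], f"dynamic_dim_{i}")
            for i, (d1, d2) in enumerate(zip(shape1, shape2))
        ]

For example, shapes [8, 128, 64] and [2, 128, 64] yield one dynamic batch dimension bounded by [2, 8] and two static dimensions.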
diff --git a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py
index fbedbc1e6..13820ebf8 100644
--- a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py
+++ b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 #
 import operator
-from typing import Callable, List, Tuple, Union
+from collections.abc import Callable

 import torch

@@ -79,7 +79,7 @@ def test_two_tensors(
         name: str,
         op: Callable,
         acc_op: Callable,
-        inputs: List[Tuple[torch.Tensor, torch.Tensor]],
+        inputs: list[tuple[torch.Tensor, torch.Tensor]],
     ) -> None:
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -102,7 +102,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         ]
     )
     def test_scalar_operand(
-        self, name: str, scalar: Union[int, float], op: Callable, acc_op: Callable
+        self, name: str, scalar: int | float, op: Callable, acc_op: Callable
     ) -> None:
         class TestModuleScalarLhs(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -137,8 +137,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     def test_constant_operand(
         self,
         name: str,
-        x: Union[int, float],
-        y: Union[int, float],
+        x: int | float,
+        y: int | float,
         op: Callable,
         acc_op: Callable,
     ) -> None:
diff --git a/fx2ait/fx2ait/test/converters/test_ait_reshape.py b/fx2ait/fx2ait/test/converters/test_ait_reshape.py
index 28cf42287..2251b8eef 100644
--- a/fx2ait/fx2ait/test/converters/test_ait_reshape.py
+++ b/fx2ait/fx2ait/test/converters/test_ait_reshape.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import List

 import torch
 from fx2ait.acc_tracer import acc_ops
@@ -34,7 +33,7 @@ class TestReshapeConverter(AITTestCase):
             [[2, 3, 4], [-1]],
         ]
     )
-    def test_simple(self, original_shape: List[int], final_shape: List[int]) -> None:
+    def test_simple(self, original_shape: list[int], final_shape: list[int]) -> None:
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return torch.reshape(x, final_shape)
diff --git a/fx2ait/fx2ait/test/converters/test_ait_split.py b/fx2ait/fx2ait/test/converters/test_ait_split.py
index ab0d81032..dba42755f 100644
--- a/fx2ait/fx2ait/test/converters/test_ait_split.py
+++ b/fx2ait/fx2ait/test/converters/test_ait_split.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import List, Union

 import torch
 from fx2ait.acc_tracer import ait_acc_ops
@@ -32,7 +31,7 @@ class TestSplitConverter(AITTestCase):
         ]
     )
     def test_with_dim(
-        self, input_shape: List[int], split_size_or_sections: Union[int, List[int]]
+        self, input_shape: list[int], split_size_or_sections: int | list[int]
     ) -> None:
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -52,7 +51,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ]
     )
     def test_without_dim(
-        self, input_shape: List[int], split_size_or_sections: Union[int, List[int]]
+        self, input_shape: list[int], split_size_or_sections: int | list[int]
     ) -> None:
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -72,7 +71,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ]
     )
     def test_tensor_split_with_dim(
-        self, input_shape: List[int], split_size_or_sections: Union[int, List[int]]
+        self, input_shape: list[int], split_size_or_sections: int | list[int]
     ) -> None:
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -92,7 +91,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ]
     )
     def test_tensor_split_without_dim(
-        self, input_shape: List[int], split_size_or_sections: Union[int, List[int]]
+        self, input_shape: list[int], split_size_or_sections: int | list[int]
     ) -> None:
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/fx2ait/fx2ait/test/converters/test_ait_topk.py b/fx2ait/fx2ait/test/converters/test_ait_topk.py
index 91615ac53..87b92bd90 100644
--- a/fx2ait/fx2ait/test/converters/test_ait_topk.py
+++ b/fx2ait/fx2ait/test/converters/test_ait_topk.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import List

 import torch
 from fx2ait.acc_tracer import acc_ops
@@ -29,7 +28,7 @@ class TestTopkConverter(AITTestCase):
             [[6], 6],
         ]
     )
-    def test_simple(self, input: List[int], k: int) -> None:
+    def test_simple(self, input: list[int], k: int) -> None:
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return torch.topk(x, k)
@@ -50,7 +49,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ]
     )
     def test_multi_dimensional(
-        self, input: List[int], k: int, dtype: torch.dtype
+        self, input: list[int], k: int, dtype: torch.dtype
     ) -> None:
         class TestModule(torch.nn.Module):
             def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py b/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py
index 672fecd30..70e73ed52 100644
--- a/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py
+++ b/fx2ait/fx2ait/test/converters/test_ait_unary_ops.py
@@ -14,7 +14,7 @@
 #
 import itertools
 import math
-from typing import Callable, Dict, Set
+from collections.abc import Callable

 import torch
 from aitemplate.testing.test_utils import filter_test_cases_by_params, TestEnv
@@ -40,7 +40,7 @@
     (torch.exp, acc_ops.exp),
 ]

-TestEnvToPrecision: Dict[TestEnv, Set[LowerPrecision]] = {
+TestEnvToPrecision: dict[TestEnv, set[LowerPrecision]] = {
     TestEnv.CUDA_LESS_THAN_SM80: [LowerPrecision.FP16, LowerPrecision.FP32],
     TestEnv.CUDA_SM80: [LowerPrecision.BF16],
     TestEnv.ROCM: [LowerPrecision.FP16],