 # limitations under the License.
 #
 from inspect import isfunction
-from typing import Optional
 
 from aitemplate.compiler import ops
 from aitemplate.frontend import nn, Tensor
@@ -279,10 +278,10 @@ def __init__(
     def forward(
         self,
         hidden_states: Tensor,
-        attention_mask: Optional[Tensor] = None,
-        causal_attention_mask: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = False,
-        residual: Optional[Tensor] = None,
+        attention_mask: Tensor | None = None,
+        causal_attention_mask: Tensor | None = None,
+        output_attentions: bool | None = False,
+        residual: Tensor | None = None,
     ):
         if residual is not None:
             self_output = self.attn(hidden_states, residual)
@@ -399,7 +398,7 @@ def __init__(
     def forward(
         self,
         hidden_states: Tensor,
-        output_attentions: Optional[bool] = False,
+        output_attentions: bool | None = False,
     ):
         """
         Args:
@@ -469,11 +468,11 @@ def __init__(
     def forward(
         self,
         inputs_embeds,
-        attention_mask: Optional[Tensor] = None,
-        causal_attention_mask: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        attention_mask: Tensor | None = None,
+        causal_attention_mask: Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
     ):
         r"""
         Args:
@@ -548,7 +547,7 @@ def forward(
         self,
         input_ids: Tensor,
         position_ids: Tensor,
-        inputs_embeds: Optional[Tensor] = None,
+        inputs_embeds: Tensor | None = None,
     ) -> Tensor:
         input_shape = ops.size()(input_ids)
 
@@ -612,12 +611,12 @@ def __init__(
 
     def forward(
         self,
-        input_ids: Optional[Tensor] = None,
-        attention_mask: Optional[Tensor] = None,
-        position_ids: Optional[Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        position_ids: Tensor | None = None,
+        output_attentions: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
     ):
         r"""
         Returns: