@@ -125,7 +125,7 @@ def _fused_kernel_quantize_into_fp8(
     # be written
     o_curr_ptr = o_ptr + o_offset
     o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
-    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))
+    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))  # type: ignore

     # Compute maximum for the current row block by block
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -233,7 +233,7 @@ def _fused_kernel_dequantize_from_fp8(
     # written
     o_curr_ptr = o_ptr + o_offset
     o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
-    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))
+    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))  # type: ignore

     # Load row scale
     i_row_scale = tl.load(o_scale_ptr)
@@ -342,7 +342,7 @@ def _fused_kernel_reduce_fp8(
     o_rank_row_ptr = o_ptr + all_reduce_rank * o_size_bytes_per_rank + o_offset
     o_rank_scale_ptr = o_rank_row_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
     o_rank_quant_ptr = (o_rank_row_ptr + SCALE_TL_DTYPE_BYTES).to(
-        tl.pointer_type(TL_FP8_TYPE)
+        tl.pointer_type(TL_FP8_TYPE)  # type: ignore
     )

     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -411,7 +411,7 @@ def _fused_kernel_accumulate_block(
     # Load row scale and block of quantized row
     o_scale_ptr = o_row_ptr.to(tl.pointer_type(tl.float32))
     o_quant_ptr = (o_row_ptr + SCALE_TL_DTYPE_BYTES).to(
-        tl.pointer_type(TL_FP8_TYPE)
+        tl.pointer_type(TL_FP8_TYPE)  # type: ignore
     )

     o_row_scale = tl.load(o_scale_ptr)
@@ -580,7 +580,7 @@ def fused_quantize_into_fp8(
         output,
         output_size // all_reduce_group_size,
         all_reduce_group_size,
-        BLOCK_SIZE=BLOCK_SIZE_T,
+        BLOCK_SIZE=BLOCK_SIZE_T,  # type: ignore
         TL_FP8_TYPE=_get_fp8_type(),
         TL_FP8_MAX=_get_fp8_max(),
     )
@@ -630,7 +630,7 @@ def fused_dequantize_from_fp8(
         output,
         output_size // all_reduce_group_size,
         all_reduce_group_size,
-        BLOCK_SIZE=BLOCK_SIZE_T,
+        BLOCK_SIZE=BLOCK_SIZE_T,  # type: ignore
         TL_FP8_TYPE=_get_fp8_type(),
     )

@@ -680,7 +680,7 @@ def fused_reduce_fp8(
         all_reduce_group_size,
         all_reduce_rank,
         1.0 if reduce_op == ReduceOp.SUM else float(all_reduce_group_size),
-        BLOCK_SIZE=BLOCK_SIZE_T,
+        BLOCK_SIZE=BLOCK_SIZE_T,  # type: ignore
         TL_FP8_TYPE=_get_fp8_type(),
         TL_FP8_MAX=_get_fp8_max(),
     )
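
Note: the change only adds "# type: ignore" markers to the Triton pointer bitcasts and to the BLOCK_SIZE constexpr launch argument, presumably because static type checkers cannot follow tl.pointer_type casts or tl.constexpr keyword arguments. The sketch below is not part of this commit; pack_row_fp8, the E4M3 target format, and the 4-byte scale header are assumptions. It illustrates on the host side, with plain PyTorch, the packed row layout these casts read: a float32 per-row scale followed by the FP8 payload.

    # Minimal host-side sketch (assumption, not from this commit) of the packed
    # row layout the kernels bitcast into: [float32 scale | fp8 payload].
    import torch

    def pack_row_fp8(row: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: choose a per-row scale so the largest magnitude
        # maps to the E4M3 finite max (448), quantize, and prepend the scale.
        # (An all-zero row would need a guard against dividing by zero.)
        scale = row.abs().max() / 448.0
        quant = (row / scale).to(torch.float8_e4m3fn)
        return torch.cat([
            scale.to(torch.float32).reshape(1).view(torch.uint8),  # 4 scale bytes
            quant.view(torch.uint8),                                # fp8 payload
        ])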