Skip to content

Commit a62926e

Browse files
tushar00jain authored and facebook-github-bot committed
fix typechecking (#275)
Summary: as title. Reviewed By: d4l3k. Differential Revision: D83488418
1 parent 9d8d80f commit a62926e

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

torchft/multiprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
class _MonitoredPipe:
11-
def __init__(self, pipe: "Connection[object, object]") -> None:
11+
def __init__(self, pipe: "Connection") -> None:
1212
self._pipe = pipe
1313

1414
def send(self, obj: object) -> None:

torchft/multiprocessing_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from torchft.multiprocessing import _MonitoredPipe
77

88

9-
def pipe_get(q: "Connection[object, object]") -> None:
9+
def pipe_get(q: "Connection") -> None:
1010
q.recv()
1111

1212

13-
def pipe_put(q: "Connection[object, object]") -> None:
13+
def pipe_put(q: "Connection") -> None:
1414
q.recv()
1515
q.send(1)
1616

torchft/process_group.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,8 +1443,8 @@ def _worker(
14431443
store_addr: str,
14441444
rank: int,
14451445
world_size: int,
1446-
req_pipe: "Connection[object, object]",
1447-
future_pipe: "Connection[object, object]",
1446+
req_pipe: "Connection",
1447+
future_pipe: "Connection",
14481448
curr_device: int,
14491449
) -> None:
14501450
try:

torchft/quantization.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _fused_kernel_quantize_into_fp8(
125125
# be written
126126
o_curr_ptr = o_ptr + o_offset
127127
o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
128-
o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))
128+
o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE)) # type: ignore
129129

130130
# Compute maximum for the current row block by block
131131
col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -233,7 +233,7 @@ def _fused_kernel_dequantize_from_fp8(
233233
# written
234234
o_curr_ptr = o_ptr + o_offset
235235
o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
236-
o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))
236+
o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE)) # type: ignore
237237

238238
# Load row scale
239239
i_row_scale = tl.load(o_scale_ptr)
@@ -342,7 +342,7 @@ def _fused_kernel_reduce_fp8(
342342
o_rank_row_ptr = o_ptr + all_reduce_rank * o_size_bytes_per_rank + o_offset
343343
o_rank_scale_ptr = o_rank_row_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
344344
o_rank_quant_ptr = (o_rank_row_ptr + SCALE_TL_DTYPE_BYTES).to(
345-
tl.pointer_type(TL_FP8_TYPE)
345+
tl.pointer_type(TL_FP8_TYPE) # type: ignore
346346
)
347347

348348
col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -411,7 +411,7 @@ def _fused_kernel_accumulate_block(
411411
# Load row scale and block of quantized row
412412
o_scale_ptr = o_row_ptr.to(tl.pointer_type(tl.float32))
413413
o_quant_ptr = (o_row_ptr + SCALE_TL_DTYPE_BYTES).to(
414-
tl.pointer_type(TL_FP8_TYPE)
414+
tl.pointer_type(TL_FP8_TYPE) # type: ignore
415415
)
416416

417417
o_row_scale = tl.load(o_scale_ptr)
@@ -580,7 +580,7 @@ def fused_quantize_into_fp8(
580580
output,
581581
output_size // all_reduce_group_size,
582582
all_reduce_group_size,
583-
BLOCK_SIZE=BLOCK_SIZE_T,
583+
BLOCK_SIZE=BLOCK_SIZE_T, # type: ignore
584584
TL_FP8_TYPE=_get_fp8_type(),
585585
TL_FP8_MAX=_get_fp8_max(),
586586
)
@@ -630,7 +630,7 @@ def fused_dequantize_from_fp8(
630630
output,
631631
output_size // all_reduce_group_size,
632632
all_reduce_group_size,
633-
BLOCK_SIZE=BLOCK_SIZE_T,
633+
BLOCK_SIZE=BLOCK_SIZE_T, # type: ignore
634634
TL_FP8_TYPE=_get_fp8_type(),
635635
)
636636

@@ -680,7 +680,7 @@ def fused_reduce_fp8(
680680
all_reduce_group_size,
681681
all_reduce_rank,
682682
1.0 if reduce_op == ReduceOp.SUM else float(all_reduce_group_size),
683-
BLOCK_SIZE=BLOCK_SIZE_T,
683+
BLOCK_SIZE=BLOCK_SIZE_T, # type: ignore
684684
TL_FP8_TYPE=_get_fp8_type(),
685685
TL_FP8_MAX=_get_fp8_max(),
686686
)

0 commit comments

Comments (0)