Commit 7db09b7

Fix matmul output dtype to match PyTorch eager behavior (#1044)
1 parent 6f0ce18 commit 7db09b7
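
The behavior being matched: in eager PyTorch, a matmul of low-precision inputs returns a result in that same dtype, with any wider accumulation kept internal to the kernel. A minimal standalone check (not from this repo):

import torch

a = torch.randn(16, 32, dtype=torch.bfloat16)
b = torch.randn(32, 8, dtype=torch.bfloat16)

# Eager returns the inputs' dtype; the wider accumulation is internal.
assert (a @ b).dtype == torch.bfloat16
assert torch.mm(a, b).dtype == torch.bfloat16

Previously the generated Triton code left the tl.dot result in float32; after this change it is cast back to the eager dtype, as the expected-output diffs below show.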

8 files changed: 73 additions & 45 deletions

examples/squeeze_and_excitation_net.py

Lines changed: 4 additions & 4 deletions
@@ -53,7 +53,7 @@ def squeeze_and_excitation_net_fwd(
         for tile_n in hl.tile(n):
             acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
             for tile_k in hl.tile(k):
-                acc += c[tile_m, tile_k] @ b[tile_k, tile_n]
+                acc = torch.addmm(acc, c[tile_m, tile_k], b[tile_k, tile_n])
             d[tile_m, tile_n] = torch.sigmoid(acc)
             out[tile_m, tile_n] = x[tile_m, tile_n] * d[tile_m, tile_n]

@@ -103,7 +103,7 @@ def squeeze_and_excitation_net_bwd_dx(

             # Backprop through (x @ a): grad_x_contribution = grad_c_masked @ a.T
             # [tile_m, tile_k] @ [tile_k, tile_n] = [tile_m, tile_n]
-            acc += grad_c_masked @ a[tile_n, tile_k].T
+            acc = torch.addmm(acc, grad_c_masked, a[tile_n, tile_k].T)

         grad_x[tile_m, tile_n] = acc

@@ -136,7 +136,7 @@ def squeeze_and_excitation_net_bwd_da(
             # Backprop through relu
             grad_through_relu = grad_to_c * (c[tile_m, tile_k] > 0)
             # Accumulate x.T @ grad_c: [tile_n, tile_m] @ [tile_m, tile_k] = [tile_n, tile_k]
-            acc_a += x[tile_m, tile_n].T @ grad_through_relu
+            acc_a = torch.addmm(acc_a, x[tile_m, tile_n].T, grad_through_relu)
         grad_a[tile_n, tile_k] = acc_a

     return grad_a
@@ -164,7 +164,7 @@ def squeeze_and_excitation_net_bwd_db(
                 * d[tile_m, tile_n]
                 * (1.0 - d[tile_m, tile_n])
             )
-            acc += c[tile_m, tile_k].T @ grad_d
+            acc = torch.addmm(acc, c[tile_m, tile_k].T, grad_d)
         grad_b[tile_k, tile_n] = acc

     return grad_b
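
Why the example switches from `acc += ... @ ...` to `torch.addmm`: with matmul now returning the eager dtype, a bare `@` on float16/bfloat16 tiles produces a low-precision product that is rounded before it reaches the float32 accumulator, whereas `torch.addmm` lets Helion fuse the accumulator directly into `tl.dot` (the `acc=` form visible in test_examples.expected below), so the accumulation stays in float32. A small eager-mode sketch of the dtype behavior, using made-up shapes:

import torch

c = torch.randn(32, 64, dtype=torch.bfloat16)
b = torch.randn(64, 16, dtype=torch.bfloat16)
acc = torch.zeros(32, 16, dtype=torch.float32)

prod = c @ b                  # bfloat16 under eager semantics
assert prod.dtype == torch.bfloat16
acc = acc + prod              # the product is rounded to bfloat16 before the float32 add
assert acc.dtype == torch.float32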

helion/_compiler/inductor_lowering.py

Lines changed: 6 additions & 0 deletions
@@ -1091,13 +1091,19 @@ def reduce_3d_dot(
         else None
     )  # pyright: ignore[reportOptionalMemberAccess]

+    # Extract expected output dtype from FX node to match PyTorch eager mode behavior
+    out_dtype: torch.dtype | None = None
+    if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+        out_dtype = node.meta["val"].dtype
+
     return emit_tl_dot_with_padding(
         lhs,
         rhs,
         acc if with_acc else None,
         lhs_dtype,
         rhs_dtype,
         acc_dtype=acc_dtype_meta if with_acc else None,
+        out_dtype=out_dtype,
         lhs_shape=lhs_shape,
         rhs_shape=rhs_shape,
         acc_shape=acc_shape,
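
For reference, the metadata lookup added above written as a standalone sketch (hypothetical helper name; assumes the node was produced with fake-tensor propagation so meta["val"] holds the result tensor):

import torch
from torch.fx import Node

def expected_out_dtype(node: Node) -> torch.dtype | None:
    # meta["val"] is the fake tensor for the node's result; its dtype is what
    # eager PyTorch would return for this op.
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor):
        return val.dtype
    return None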

helion/_compiler/matmul_utils.py

Lines changed: 23 additions & 1 deletion
@@ -196,9 +196,25 @@ def emit_tl_dot_with_padding(
     acc_out = acc if not fuse_acc else None
     acc_for_dot = acc if fuse_acc else None
     acc_cast_dtype = acc_dtype if not fuse_acc else None
-    dot_out_dtype = out_dtype or (
+
+    # Determine the out_dtype to use for tl.dot operation, and whether to
+    # explicitly cast the tl.dot result to the expected output dtype
+    expected_out_dtype = out_dtype or (
         acc_dtype if fuse_acc else _compute_out_dtype(lhs_dtype, rhs_dtype)
     )
+    if expected_out_dtype == torch.float32:
+        dot_out_dtype = torch.float32
+    elif expected_out_dtype == torch.float16:
+        dot_out_dtype = (
+            torch.float32
+            if common_dtype in {torch.float16, torch.bfloat16} and not fuse_acc
+            else torch.float16
+        )
+    elif common_dtype == torch.int8 and expected_out_dtype == torch.int32:
+        dot_out_dtype = torch.int32
+    else:
+        # Unsupported dtype (like bfloat16), use float32 and cast afterward
+        dot_out_dtype = torch.float32

     # Squeeze 3D shapes to 2D when leading dims map to block size 1 for both operands.
     need_squeeze_dim = (
@@ -320,6 +336,12 @@

     if acc_cast_dtype is not None:
         result = cast_ast(result, acc_cast_dtype)
+
+    # Explicitly cast to expected output dtype if we used a different out_dtype for tl.dot and haven't already cast
+    if dot_out_dtype != expected_out_dtype and acc_cast_dtype != expected_out_dtype:
+        assert expected_out_dtype is not None
+        result = cast_ast(result, expected_out_dtype)
+
     return (
         expr_from_string("{acc} + {mm}", acc=acc_out, mm=result)
         if not fuse_acc and acc_out is not None
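
Read in isolation, the new selection logic amounts to the sketch below (hypothetical function name, mirroring the hunk above rather than the real helper's signature): the hunk treats dtypes like bfloat16 as unsupported for tl.dot's out_dtype, so it picks the closest legal out_dtype and relies on the later explicit cast when that differs from the expected dtype.

import torch

def pick_dot_out_dtype(
    expected_out_dtype: torch.dtype,
    common_dtype: torch.dtype,
    fuse_acc: bool,
) -> torch.dtype:
    if expected_out_dtype == torch.float32:
        return torch.float32
    if expected_out_dtype == torch.float16:
        # fp16/bf16 inputs without a fused accumulator still accumulate in fp32.
        if common_dtype in {torch.float16, torch.bfloat16} and not fuse_acc:
            return torch.float32
        return torch.float16
    if common_dtype == torch.int8 and expected_out_dtype == torch.int32:
        return torch.int32
    # Unsupported expected dtype (e.g. bfloat16): use float32 and cast afterward.
    return torch.float32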

test/test_dot.expected

Lines changed: 2 additions & 2 deletions
@@ -2255,7 +2255,7 @@ def _helion_mm_small_dims(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1:
     # src[test_dot.py:N]: acc = mm_func(acc, x[tile_m, tile_k], y[tile_k, tile_n])
     load = tl.load(x + (indices_0[:, None] * 6 + indices_2[None, :] * 1), mask_0[:, None] & mask_2[None, :], other=0)
     load_1 = tl.load(y + (indices_2[:, None] * 7 + indices_1[None, :] * 1), mask_2[:, None] & mask_1[None, :], other=0)
-    mm = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
+    mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16)
     v_0 = tl.cast(mm, tl.float32)
     acc = acc_copy_0 + v_0
     # src[test_dot.py:N]: out[tile_m, tile_n] = acc
@@ -2316,7 +2316,7 @@ def _helion_mm_small_dims(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1:
     # src[test_dot.py:N]: acc = mm_func(acc, x[tile_m, tile_k], y[tile_k, tile_n])
     load = tl.load(x + (indices_0[:, None] * 6 + indices_2[None, :] * 1), mask_0[:, None] & mask_2[None, :], other=0)
     load_1 = tl.load(y + (indices_2[:, None] * 7 + indices_1[None, :] * 1), mask_2[:, None] & mask_1[None, :], other=0)
-    mm = tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
+    mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16)
     v_0 = tl.cast(mm, tl.float32)
     acc = acc_copy_0 + v_0
     # src[test_dot.py:N]: out[tile_m, tile_n] = acc

test/test_dot.py

Lines changed: 8 additions & 0 deletions
@@ -350,6 +350,7 @@ def _test_small_dims(
         n_dim,
         mm_func,
         check_code=False,
+        check_matmul_cast_pattern=False,
         *,
         rtol: float = 1e-2,
         atol: float = 1e-3,
@@ -376,6 +377,11 @@ def mm_small_dims(
         if check_code:
             code, result = code_and_output(mm_small_dims, (x, y, mm_func))
             self.assertExpectedJournal(code)
+            if check_matmul_cast_pattern:
+                self.assertIn(
+                    "mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16)",
+                    code,
+                )
         else:
             result = mm_small_dims(x, y, mm_func)

@@ -773,6 +779,7 @@ def test_mm_multiple_small_dims(self):
             n_dim=7,
             mm_func=lambda acc, a, b: acc + torch.mm(a, b),
             check_code=True,
+            check_matmul_cast_pattern=True,
         )

     def test_mm_reshape_m_1(self):
@@ -850,6 +857,7 @@ def test_matmul_multiple_small_dims(self):
             n_dim=7,
             mm_func=lambda acc, a, b: acc + torch.matmul(a, b),
             check_code=True,
+            check_matmul_cast_pattern=True,
         )

     def test_matmul_reshape_m_1(self):

test/test_examples.expected

Lines changed: 28 additions & 36 deletions
@@ -241,7 +241,7 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_0: tl.constexpr,
     # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
     k = tl.load(tl.make_block_ptr(k_view, [64, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [_BLOCK_SIZE_0, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
     # src[attention.py:N]: qk = torch.bmm(q, k)
-    qk = tl.dot(tl.cast(q_copy_0, tl.float16), tl.cast(k, tl.float16), input_precision='tf32', out_dtype=tl.float32)
+    qk = tl.cast(tl.dot(tl.cast(q_copy_0, tl.float16), tl.cast(k, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
     # src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
     amax = tl.cast(tl.max(qk, 2), tl.float16)
     v_0 = 0.18033688
@@ -519,7 +519,7 @@ def _helion_attention(q_view, k_view, v_view, out, _NUM_SM: tl.constexpr, _BLOCK
     # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
     k = tl.load(tl.make_block_ptr(k_view, [32, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [_BLOCK_SIZE_0, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
     # src[attention.py:N]: qk = torch.bmm(q, k)
-    qk = tl.dot(tl.cast(q_copy_0, tl.float16), tl.cast(k, tl.float16), input_precision='tf32', out_dtype=tl.float32)
+    qk = tl.cast(tl.dot(tl.cast(q_copy_0, tl.float16), tl.cast(k, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
     # src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
     amax = tl.cast(tl.max(qk, 2), tl.float16)
     v_0 = 0.18033688
@@ -2413,7 +2413,7 @@ def _helion__helion_jagged_attention_kernel(seq_offsets, q, k, v, out, max_seq_l
     v_blk = tl.load(v + (v_16[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), mask_4[:, None], other=0)
     # src[jagged_hstu_attn.py:N]: torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha)
     permute = tl.permute(k_blk, [1, 0])
-    mm = tl.dot(tl.cast(q_blk_copy_0, tl.bfloat16), tl.cast(permute, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
+    mm = tl.cast(tl.dot(tl.cast(q_blk_copy_0, tl.bfloat16), tl.cast(permute, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16)
     v_17 = tl.cast(alpha, tl.bfloat16)
     v_18 = mm * v_17
     v_19 = tl.cast(v_18, tl.float32)
@@ -2448,7 +2448,7 @@ def _helion__helion_jagged_attention_kernel(seq_offsets, q, k, v, out, max_seq_l
     v_30 = tl.where(v_27, v_24, v_29)
     # src[jagged_hstu_attn.py:N]: acc += torch.matmul(scores.to(v.dtype), v_blk)
     _mask_to_2 = tl.where(mask_2[:, None] & mask_4[None, :], v_30, tl.full([], 0, tl.bfloat16))
-    mm_1 = tl.dot(tl.cast(_mask_to_2, tl.bfloat16), tl.cast(v_blk, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
+    mm_1 = tl.cast(tl.dot(tl.cast(_mask_to_2, tl.bfloat16), tl.cast(v_blk, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16)
     v_31 = tl.cast(mm_1, tl.float32)
     acc = acc_copy_0 + v_31
     # src[jagged_hstu_attn.py:N]: out[tile_q.index + starts, tile_h.begin, :] = acc.to(out.dtype)
@@ -5559,22 +5559,20 @@ def _helion_squeeze_and_excitation_net_bwd_da(grad_out, x, d, b, c, grad_a, _BLO
     # src[squeeze_and_excitation_net.py:N]: grad_to_c = grad_to_cb @ b[tile_k, :].T
     load_4 = tl.load(b + (indices_1[:, None] * 256 + indices_3[None, :] * 1), None)
     permute = tl.permute(load_4, [1, 0])
-    grad_to_c = tl.dot(tl.cast(v_4, tl.float16), tl.cast(permute, tl.float16), input_precision='tf32', out_dtype=tl.float32)
+    grad_to_c = tl.cast(tl.dot(tl.cast(v_4, tl.float16), tl.cast(permute, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
     # src[squeeze_and_excitation_net.py:N]: grad_through_relu = grad_to_c * (c[tile_m, tile_k] > 0)
     load_5 = tl.load(c + (indices_2[:, None] * 256 + indices_1[None, :] * 1), None)
     v_5 = 0.0
     v_6 = load_5 > v_5
     v_7 = tl.cast(v_6, tl.float16)
     v_8 = grad_to_c * v_7
-    # src[squeeze_and_excitation_net.py:N]: acc_a += x[tile_m, tile_n].T @ grad_through_relu
+    # src[squeeze_and_excitation_net.py:N]: acc_a = torch.addmm(acc_a, x[tile_m, tile_n].T, grad_through_relu)
     load_6 = tl.load(x + (indices_2[:, None] * 256 + indices_0[None, :] * 1), None)
     permute_1 = tl.permute(load_6, [1, 0])
-    mm_1 = tl.dot(tl.cast(permute_1, tl.float16), tl.cast(v_8, tl.float16), input_precision='tf32', out_dtype=tl.float32)
-    v_9 = tl.cast(mm_1, tl.float32)
-    acc_a = acc_a_copy_0 + v_9
+    acc_a = tl.dot(tl.cast(permute_1, tl.float16), tl.cast(v_8, tl.float16), acc=acc_a_copy_0, input_precision='tf32', out_dtype=tl.float32)
     # src[squeeze_and_excitation_net.py:N]: grad_a[tile_n, tile_k] = acc_a
-    v_11 = tl.cast(acc_a, tl.float16)
-    tl.store(grad_a + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_11, None)
+    v_9 = tl.cast(acc_a, tl.float16)
+    tl.store(grad_a + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_9, None)

 def squeeze_and_excitation_net_bwd_da(grad_out: Tensor, x: Tensor, b: Tensor, c: Tensor, d: Tensor, *, _launcher=_default_launcher):
     """
@@ -5654,15 +5652,13 @@ def _helion_squeeze_and_excitation_net_bwd_db(grad_out, x, d, c, grad_b, _BLOCK_
     # src[squeeze_and_excitation_net.py:N]: * d[tile_m, tile_n]
     # src[squeeze_and_excitation_net.py:N-N]: ...
     v_4 = v_1 * v_3
-    # src[squeeze_and_excitation_net.py:N]: acc += c[tile_m, tile_k].T @ grad_d
+    # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, c[tile_m, tile_k].T, grad_d)
     load_4 = tl.load(c + (indices_2[:, None] * 256 + indices_0[None, :] * 1), None)
     permute = tl.permute(load_4, [1, 0])
-    mm = tl.dot(tl.cast(permute, tl.float16), tl.cast(v_4, tl.float16), input_precision='tf32', out_dtype=tl.float32)
-    v_5 = tl.cast(mm, tl.float32)
-    acc = acc_copy_0 + v_5
+    acc = tl.dot(tl.cast(permute, tl.float16), tl.cast(v_4, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
     # src[squeeze_and_excitation_net.py:N]: grad_b[tile_k, tile_n] = acc
-    v_7 = tl.cast(acc, tl.float16)
-    tl.store(grad_b + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_7, None)
+    v_5 = tl.cast(acc, tl.float16)
+    tl.store(grad_b + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_5, None)

 def squeeze_and_excitation_net_bwd_db(grad_out: Tensor, x: Tensor, d: Tensor, c: Tensor, *, _launcher=_default_launcher):
     """
@@ -5739,22 +5735,20 @@ def _helion_squeeze_and_excitation_net_bwd_dx(grad_out, d, x, b, c, a, grad_x, _
     # src[squeeze_and_excitation_net.py:N]: grad_to_c = grad_to_d @ b[tile_k, :].T
     load_6 = tl.load(b + (indices_2[:, None] * 256 + indices_3[None, :] * 1), None)
     permute = tl.permute(load_6, [1, 0])
-    grad_to_c = tl.dot(tl.cast(v_7, tl.float16), tl.cast(permute, tl.float16), input_precision='tf32', out_dtype=tl.float32)
+    grad_to_c = tl.cast(tl.dot(tl.cast(v_7, tl.float16), tl.cast(permute, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
     # src[squeeze_and_excitation_net.py:N]: grad_c_masked = grad_to_c * (c[tile_m, tile_k] > 0)
     load_7 = tl.load(c + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None)
     v_8 = 0.0
     v_9 = load_7 > v_8
     v_10 = tl.cast(v_9, tl.float16)
     v_11 = grad_to_c * v_10
-    # src[squeeze_and_excitation_net.py:N]: acc += grad_c_masked @ a[tile_n, tile_k].T
+    # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, grad_c_masked, a[tile_n, tile_k].T)
     load_8 = tl.load(a + (indices_1[:, None] * 256 + indices_2[None, :] * 1), None)
     permute_1 = tl.permute(load_8, [1, 0])
-    mm_1 = tl.dot(tl.cast(v_11, tl.float16), tl.cast(permute_1, tl.float16), input_precision='tf32', out_dtype=tl.float32)
-    v_12 = tl.cast(mm_1, tl.float32)
-    v_2 = v_2_copy_0 + v_12
+    v_2 = tl.dot(tl.cast(v_11, tl.float16), tl.cast(permute_1, tl.float16), acc=v_2_copy_0, input_precision='tf32', out_dtype=tl.float32)
     # src[squeeze_and_excitation_net.py:N]: grad_x[tile_m, tile_n] = acc
-    v_14 = tl.cast(v_2, tl.float16)
-    tl.store(grad_x + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_14, None)
+    v_12 = tl.cast(v_2, tl.float16)
+    tl.store(grad_x + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_12, None)

 def squeeze_and_excitation_net_bwd_dx(grad_out: Tensor, x: Tensor, a: Tensor, b: Tensor, c: Tensor, d: Tensor, *, _launcher=_default_launcher):
     """
@@ -5815,7 +5809,7 @@ def _helion_squeeze_and_excitation_net_fwd(x, a, c, b, d, out, _BLOCK_SIZE_0: tl
     # src[squeeze_and_excitation_net.py:N]: partial_xa = x[tile_m, :] @ a[:, tile_k]
     load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
     load_1 = tl.load(a + (indices_2[:, None] * 1024 + indices_1[None, :] * 1), None)
-    partial_xa = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), input_precision='tf32', out_dtype=tl.float32)
+    partial_xa = tl.cast(tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
     # src[squeeze_and_excitation_net.py:N]: c[tile_m, tile_k] = torch.relu(partial_xa)
     v_0 = tl.full([], 0, tl.int32)
     v_1 = triton_helpers.maximum(v_0, partial_xa)
@@ -5829,26 +5823,24 @@ def _helion_squeeze_and_excitation_net_fwd(x, a, c, b, d, out, _BLOCK_SIZE_0: tl
     # src[squeeze_and_excitation_net.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_3], 0.0, tl.float32)
     # src[squeeze_and_excitation_net.py:N]: for tile_k in hl.tile(k):
-    # src[squeeze_and_excitation_net.py:N]: acc += c[tile_m, tile_k] @ b[tile_k, tile_n]
+    # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, c[tile_m, tile_k], b[tile_k, tile_n])
     for offset_4 in tl.range(0, 1024, _BLOCK_SIZE_4):
         indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
         acc_copy = acc
         acc_copy_0 = acc_copy
-        # src[squeeze_and_excitation_net.py:N]: acc += c[tile_m, tile_k] @ b[tile_k, tile_n]
+        # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, c[tile_m, tile_k], b[tile_k, tile_n])
         load_2 = tl.load(c + (indices_0[:, None] * 1024 + indices_4[None, :] * 1), None)
         load_3 = tl.load(b + (indices_4[:, None] * 1024 + indices_3[None, :] * 1), None)
-        mm = tl.dot(tl.cast(load_2, tl.float16), tl.cast(load_3, tl.float16), input_precision='tf32', out_dtype=tl.float32)
-        v_2 = tl.cast(mm, tl.float32)
-        acc = acc_copy_0 + v_2
+        acc = tl.dot(tl.cast(load_2, tl.float16), tl.cast(load_3, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
     # src[squeeze_and_excitation_net.py:N]: d[tile_m, tile_n] = torch.sigmoid(acc)
-    v_4 = tl.sigmoid(tl.cast(acc, tl.float32))
-    v_5 = tl.cast(v_4, tl.float16)
-    tl.store(d + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), v_5, None)
+    v_2 = tl.sigmoid(tl.cast(acc, tl.float32))
+    v_3 = tl.cast(v_2, tl.float16)
+    tl.store(d + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), v_3, None)
     # src[squeeze_and_excitation_net.py:N]: out[tile_m, tile_n] = x[tile_m, tile_n] * d[tile_m, tile_n]
     load_4 = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), None)
     load_5 = tl.load(d + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), None)
-    v_6 = load_4 * load_5
-    tl.store(out + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), v_6, None)
+    v_4 = load_4 * load_5
+    tl.store(out + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), v_4, None)

 def squeeze_and_excitation_net_fwd(x: Tensor, a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     """
@@ -5885,7 +5877,7 @@ def squeeze_and_excitation_net_fwd(x: Tensor, a: Tensor, b: Tensor, *, _launcher
     # src[squeeze_and_excitation_net.py:N-N]: ...
     _BLOCK_SIZE_3 = 16
     # src[squeeze_and_excitation_net.py:N]: for tile_k in hl.tile(k):
-    # src[squeeze_and_excitation_net.py:N]: acc += c[tile_m, tile_k] @ b[tile_k, tile_n]
+    # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, c[tile_m, tile_k], b[tile_k, tile_n])
     _BLOCK_SIZE_4 = 16
     # src[squeeze_and_excitation_net.py:N]: for tile_m in hl.tile(m):
     # src[squeeze_and_excitation_net.py:N]: # Compute c = relu(x @ a) for this tile_m

0 commit comments
