@@ -241,7 +241,7 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_0: tl.constexpr,
241241 # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
242242 k = tl.load(tl.make_block_ptr(k_view, [64, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [_BLOCK_SIZE_0, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
243243 # src[attention.py:N]: qk = torch.bmm(q, k)
244- qk = tl.dot(tl.cast(q_copy_0, tl.float16), tl.cast(k, tl.float16), input_precision='tf32', out_dtype=tl.float32)
244+        qk = tl.cast(tl.dot(tl.cast(q_copy_0, tl.float16), tl.cast(k, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
245245 # src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
246246 amax = tl.cast(tl.max(qk, 2), tl.float16)
247247 v_0 = 0.18033688
@@ -519,7 +519,7 @@ def _helion_attention(q_view, k_view, v_view, out, _NUM_SM: tl.constexpr, _BLOCK
519519 # src[attention.py:N]: k = k_view[tile_b, :, tile_n]
520520 k = tl.load(tl.make_block_ptr(k_view, [32, 64, 512], [32768, 1, 64], [offset_0, 0, offset_2], [_BLOCK_SIZE_0, 64, _BLOCK_SIZE_3], [2, 0, 1]), boundary_check=[0, 1, 2], padding_option='zero')
521521 # src[attention.py:N]: qk = torch.bmm(q, k)
522- qk = tl.dot(tl.cast(q_copy_0, tl.float16), tl.cast(k, tl.float16), input_precision='tf32', out_dtype=tl.float32)
522+        qk = tl.cast(tl.dot(tl.cast(q_copy_0, tl.float16), tl.cast(k, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
523523 # src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
524524 amax = tl.cast(tl.max(qk, 2), tl.float16)
525525 v_0 = 0.18033688
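
For context, the change in the two attention hunks above reduces to casting the fp32 result of tl.dot back to fp16, which is the dtype torch.bmm returns for fp16 operands. A minimal sketch of the before/after pattern, using illustrative names (q_f16, k_f16) rather than the generated variable names:

    # before: qk is left in fp32, the dot's accumulator dtype
    qk = tl.dot(q_f16, k_f16, input_precision='tf32', out_dtype=tl.float32)

    # after: cast the fp32 accumulator back to fp16 so downstream ops
    # (row max, softmax) see the dtype the eager torch.bmm(q, k) would produce
    qk = tl.cast(
        tl.dot(q_f16, k_f16, input_precision='tf32', out_dtype=tl.float32),
        tl.float16,
    )
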
@@ -2413,7 +2413,7 @@ def _helion__helion_jagged_attention_kernel(seq_offsets, q, k, v, out, max_seq_l
24132413 v_blk = tl.load(v + (v_16[:, None] * 256 + offset_1 * 32 + indices_5[None, :] * 1), mask_4[:, None], other=0)
24142414 # src[jagged_hstu_attn.py:N]: torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha)
24152415 permute = tl.permute(k_blk, [1, 0])
2416- mm = tl.dot(tl.cast(q_blk_copy_0, tl.bfloat16), tl.cast(permute, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
2416+        mm = tl.cast(tl.dot(tl.cast(q_blk_copy_0, tl.bfloat16), tl.cast(permute, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16)
24172417 v_17 = tl.cast(alpha, tl.bfloat16)
24182418 v_18 = mm * v_17
24192419 v_19 = tl.cast(v_18, tl.float32)
@@ -2448,7 +2448,7 @@ def _helion__helion_jagged_attention_kernel(seq_offsets, q, k, v, out, max_seq_l
24482448 v_30 = tl.where(v_27, v_24, v_29)
24492449 # src[jagged_hstu_attn.py:N]: acc += torch.matmul(scores.to(v.dtype), v_blk)
24502450 _mask_to_2 = tl.where(mask_2[:, None] & mask_4[None, :], v_30, tl.full([], 0, tl.bfloat16))
2451- mm_1 = tl.dot(tl.cast(_mask_to_2, tl.bfloat16), tl.cast(v_blk, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
2451+        mm_1 = tl.cast(tl.dot(tl.cast(_mask_to_2, tl.bfloat16), tl.cast(v_blk, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32), tl.bfloat16)
24522452 v_31 = tl.cast(mm_1, tl.float32)
24532453 acc = acc_copy_0 + v_31
24542454 # src[jagged_hstu_attn.py:N]: out[tile_q.index + starts, tile_h.begin, :] = acc.to(out.dtype)
@@ -5559,22 +5559,20 @@ def _helion_squeeze_and_excitation_net_bwd_da(grad_out, x, d, b, c, grad_a, _BLO
55595559 # src[squeeze_and_excitation_net.py:N]: grad_to_c = grad_to_cb @ b[tile_k, :].T
55605560 load_4 = tl.load(b + (indices_1[:, None] * 256 + indices_3[None, :] * 1), None)
55615561 permute = tl.permute(load_4, [1, 0])
5562- grad_to_c = tl.dot(tl.cast(v_4, tl.float16), tl.cast(permute, tl.float16), input_precision='tf32', out_dtype=tl.float32)
5562+        grad_to_c = tl.cast(tl.dot(tl.cast(v_4, tl.float16), tl.cast(permute, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
55635563 # src[squeeze_and_excitation_net.py:N]: grad_through_relu = grad_to_c * (c[tile_m, tile_k] > 0)
55645564 load_5 = tl.load(c + (indices_2[:, None] * 256 + indices_1[None, :] * 1), None)
55655565 v_5 = 0.0
55665566 v_6 = load_5 > v_5
55675567 v_7 = tl.cast(v_6, tl.float16)
55685568 v_8 = grad_to_c * v_7
5569- # src[squeeze_and_excitation_net.py:N]: acc_a += x[tile_m, tile_n].T @ grad_through_relu
5569+ # src[squeeze_and_excitation_net.py:N]: acc_a = torch.addmm(acc_a, x[tile_m, tile_n].T, grad_through_relu)
55705570 load_6 = tl.load(x + (indices_2[:, None] * 256 + indices_0[None, :] * 1), None)
55715571 permute_1 = tl.permute(load_6, [1, 0])
5572- mm_1 = tl.dot(tl.cast(permute_1, tl.float16), tl.cast(v_8, tl.float16), input_precision='tf32', out_dtype=tl.float32)
5573- v_9 = tl.cast(mm_1, tl.float32)
5574- acc_a = acc_a_copy_0 + v_9
5572+ acc_a = tl.dot(tl.cast(permute_1, tl.float16), tl.cast(v_8, tl.float16), acc=acc_a_copy_0, input_precision='tf32', out_dtype=tl.float32)
55755573 # src[squeeze_and_excitation_net.py:N]: grad_a[tile_n, tile_k] = acc_a
5576- v_11 = tl.cast(acc_a, tl.float16)
5577-        tl.store(grad_a + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_11, None)
5574+ v_9 = tl.cast(acc_a, tl.float16)
5575+        tl.store(grad_a + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_9, None)
55785576
55795577def squeeze_and_excitation_net_bwd_da(grad_out: Tensor, x: Tensor, b: Tensor, c: Tensor, d: Tensor, *, _launcher=_default_launcher):
55805578 """
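
The squeeze-and-excitation hunks handle the accumulating matmuls differently: the source-level `acc += A @ B` becomes `acc = torch.addmm(acc, A, B)`, so the separate dot, cast, and add collapse into a single tl.dot that receives the running accumulator through its acc argument. A minimal sketch of that rewrite, again with illustrative names (a_f16, b_f16):

    # before: dot into fp32, cast, then add into the running accumulator
    mm = tl.dot(a_f16, b_f16, input_precision='tf32', out_dtype=tl.float32)
    acc = acc + tl.cast(mm, tl.float32)

    # after: pass the fp32 accumulator directly to the MMA
    acc = tl.dot(a_f16, b_f16, acc=acc, input_precision='tf32', out_dtype=tl.float32)
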
@@ -5654,15 +5652,13 @@ def _helion_squeeze_and_excitation_net_bwd_db(grad_out, x, d, c, grad_b, _BLOCK_
56545652 # src[squeeze_and_excitation_net.py:N]: * d[tile_m, tile_n]
56555653 # src[squeeze_and_excitation_net.py:N-N]: ...
56565654 v_4 = v_1 * v_3
5657- # src[squeeze_and_excitation_net.py:N]: acc += c[tile_m, tile_k].T @ grad_d
5655+ # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, c[tile_m, tile_k].T, grad_d)
56585656 load_4 = tl.load(c + (indices_2[:, None] * 256 + indices_0[None, :] * 1), None)
56595657 permute = tl.permute(load_4, [1, 0])
5660- mm = tl.dot(tl.cast(permute, tl.float16), tl.cast(v_4, tl.float16), input_precision='tf32', out_dtype=tl.float32)
5661- v_5 = tl.cast(mm, tl.float32)
5662- acc = acc_copy_0 + v_5
5658+ acc = tl.dot(tl.cast(permute, tl.float16), tl.cast(v_4, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
56635659 # src[squeeze_and_excitation_net.py:N]: grad_b[tile_k, tile_n] = acc
5664- v_7 = tl.cast(acc, tl.float16)
5665-        tl.store(grad_b + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_7, None)
5660+ v_5 = tl.cast(acc, tl.float16)
5661+        tl.store(grad_b + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_5, None)
56665662
56675663def squeeze_and_excitation_net_bwd_db(grad_out: Tensor, x: Tensor, d: Tensor, c: Tensor, *, _launcher=_default_launcher):
56685664 """
@@ -5739,22 +5735,20 @@ def _helion_squeeze_and_excitation_net_bwd_dx(grad_out, d, x, b, c, a, grad_x, _
57395735 # src[squeeze_and_excitation_net.py:N]: grad_to_c = grad_to_d @ b[tile_k, :].T
57405736 load_6 = tl.load(b + (indices_2[:, None] * 256 + indices_3[None, :] * 1), None)
57415737 permute = tl.permute(load_6, [1, 0])
5742- grad_to_c = tl.dot(tl.cast(v_7, tl.float16), tl.cast(permute, tl.float16), input_precision='tf32', out_dtype=tl.float32)
5738+        grad_to_c = tl.cast(tl.dot(tl.cast(v_7, tl.float16), tl.cast(permute, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
57435739 # src[squeeze_and_excitation_net.py:N]: grad_c_masked = grad_to_c * (c[tile_m, tile_k] > 0)
57445740 load_7 = tl.load(c + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None)
57455741 v_8 = 0.0
57465742 v_9 = load_7 > v_8
57475743 v_10 = tl.cast(v_9, tl.float16)
57485744 v_11 = grad_to_c * v_10
5749- # src[squeeze_and_excitation_net.py:N]: acc += grad_c_masked @ a[tile_n, tile_k].T
5745+ # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, grad_c_masked, a[tile_n, tile_k].T)
57505746 load_8 = tl.load(a + (indices_1[:, None] * 256 + indices_2[None, :] * 1), None)
57515747 permute_1 = tl.permute(load_8, [1, 0])
5752- mm_1 = tl.dot(tl.cast(v_11, tl.float16), tl.cast(permute_1, tl.float16), input_precision='tf32', out_dtype=tl.float32)
5753- v_12 = tl.cast(mm_1, tl.float32)
5754- v_2 = v_2_copy_0 + v_12
5748+ v_2 = tl.dot(tl.cast(v_11, tl.float16), tl.cast(permute_1, tl.float16), acc=v_2_copy_0, input_precision='tf32', out_dtype=tl.float32)
57555749 # src[squeeze_and_excitation_net.py:N]: grad_x[tile_m, tile_n] = acc
5756- v_14 = tl.cast(v_2, tl.float16)
5757-        tl.store(grad_x + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_14, None)
5750+ v_12 = tl.cast(v_2, tl.float16)
5751+        tl.store(grad_x + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_12, None)
57585752
57595753def squeeze_and_excitation_net_bwd_dx(grad_out: Tensor, x: Tensor, a: Tensor, b: Tensor, c: Tensor, d: Tensor, *, _launcher=_default_launcher):
57605754 """
@@ -5815,7 +5809,7 @@ def _helion_squeeze_and_excitation_net_fwd(x, a, c, b, d, out, _BLOCK_SIZE_0: tl
58155809 # src[squeeze_and_excitation_net.py:N]: partial_xa = x[tile_m, :] @ a[:, tile_k]
58165810 load = tl.load(x + (indices_0[:, None] * 1024 + indices_2[None, :] * 1), None)
58175811 load_1 = tl.load(a + (indices_2[:, None] * 1024 + indices_1[None, :] * 1), None)
5818- partial_xa = tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), input_precision='tf32', out_dtype=tl.float32)
5812+        partial_xa = tl.cast(tl.dot(tl.cast(load, tl.float16), tl.cast(load_1, tl.float16), input_precision='tf32', out_dtype=tl.float32), tl.float16)
58195813 # src[squeeze_and_excitation_net.py:N]: c[tile_m, tile_k] = torch.relu(partial_xa)
58205814 v_0 = tl.full([], 0, tl.int32)
58215815 v_1 = triton_helpers.maximum(v_0, partial_xa)
@@ -5829,26 +5823,24 @@ def _helion_squeeze_and_excitation_net_fwd(x, a, c, b, d, out, _BLOCK_SIZE_0: tl
58295823 # src[squeeze_and_excitation_net.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
58305824 acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_3], 0.0, tl.float32)
58315825 # src[squeeze_and_excitation_net.py:N]: for tile_k in hl.tile(k):
5832- # src[squeeze_and_excitation_net.py:N]: acc += c[tile_m, tile_k] @ b[tile_k, tile_n]
5826+ # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, c[tile_m, tile_k], b[tile_k, tile_n])
58335827 for offset_4 in tl.range(0, 1024, _BLOCK_SIZE_4):
58345828 indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
58355829 acc_copy = acc
58365830 acc_copy_0 = acc_copy
5837- # src[squeeze_and_excitation_net.py:N]: acc += c[tile_m, tile_k] @ b[tile_k, tile_n]
5831+ # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, c[tile_m, tile_k], b[tile_k, tile_n])
58385832 load_2 = tl.load(c + (indices_0[:, None] * 1024 + indices_4[None, :] * 1), None)
58395833 load_3 = tl.load(b + (indices_4[:, None] * 1024 + indices_3[None, :] * 1), None)
5840- mm = tl.dot(tl.cast(load_2, tl.float16), tl.cast(load_3, tl.float16), input_precision='tf32', out_dtype=tl.float32)
5841- v_2 = tl.cast(mm, tl.float32)
5842- acc = acc_copy_0 + v_2
5834+ acc = tl.dot(tl.cast(load_2, tl.float16), tl.cast(load_3, tl.float16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
58435835 # src[squeeze_and_excitation_net.py:N]: d[tile_m, tile_n] = torch.sigmoid(acc)
5844- v_4 = tl.sigmoid(tl.cast(acc, tl.float32))
5845-        v_5 = tl.cast(v_4, tl.float16)
5846-        tl.store(d + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), v_5, None)
5836+ v_2 = tl.sigmoid(tl.cast(acc, tl.float32))
5837+        v_3 = tl.cast(v_2, tl.float16)
5838+        tl.store(d + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), v_3, None)
58475839 # src[squeeze_and_excitation_net.py:N]: out[tile_m, tile_n] = x[tile_m, tile_n] * d[tile_m, tile_n]
58485840 load_4 = tl.load(x + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), None)
58495841 load_5 = tl.load(d + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), None)
5850- v_6 = load_4 * load_5
5851-        tl.store(out + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), v_6, None)
5842+ v_4 = load_4 * load_5
5843+        tl.store(out + (indices_0[:, None] * 1024 + indices_3[None, :] * 1), v_4, None)
58525844
58535845def squeeze_and_excitation_net_fwd(x: Tensor, a: Tensor, b: Tensor, *, _launcher=_default_launcher):
58545846 """
@@ -5885,7 +5877,7 @@ def squeeze_and_excitation_net_fwd(x: Tensor, a: Tensor, b: Tensor, *, _launcher
58855877 # src[squeeze_and_excitation_net.py:N-N]: ...
58865878 _BLOCK_SIZE_3 = 16
58875879 # src[squeeze_and_excitation_net.py:N]: for tile_k in hl.tile(k):
5888- # src[squeeze_and_excitation_net.py:N]: acc += c[tile_m, tile_k] @ b[tile_k, tile_n]
5880+ # src[squeeze_and_excitation_net.py:N]: acc = torch.addmm(acc, c[tile_m, tile_k], b[tile_k, tile_n])
58895881 _BLOCK_SIZE_4 = 16
58905882 # src[squeeze_and_excitation_net.py:N]: for tile_m in hl.tile(m):
58915883 # src[squeeze_and_excitation_net.py:N]: # Compute c = relu(x @ a) for this tile_m