From 7082cdc4917de526a5f2fd50d3c1f8f64a929496 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 8 Oct 2025 14:43:48 -0700 Subject: [PATCH 1/2] test --- test/test_matmul.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/test/test_matmul.py b/test/test_matmul.py index f8e1cc4cc..55270bf25 100644 --- a/test/test_matmul.py +++ b/test/test_matmul.py @@ -199,6 +199,54 @@ def test_matmul_static_shapes3(self): torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2) self.assertExpectedJournal(code) + def test_matmul_packed_int4_block_size_constexpr(self): + torch.manual_seed(0) + M = N = K = 32 + + @helion.kernel(use_default_config=True, static_shapes=True) + def matmul_bf16_packed_int4( + A: torch.Tensor, B_packed: torch.Tensor, C: torch.Tensor + ) -> torch.Tensor: + M0, K0 = A.shape + _, N0 = B_packed.shape + + block_n = hl.register_block_size(N0) + block_k = hl.register_block_size(K0) + + for tile_m in hl.tile(M0): + for tile_n in hl.tile(N0, block_size=block_n): + acc = hl.zeros((tile_m, tile_n), dtype=torch.float32) + + for tile_k in hl.tile(K0, block_size=block_k): + tile_k_begin = tile_k.begin + b_tile = B_packed[ + tile_k_begin // 2 : tile_k_begin // 2 + block_k // 2, + tile_n, + ] + shift = hl.full((1,), 4, dtype=torch.int8) + b_lo = (b_tile << shift) >> shift + b_hi = b_tile >> shift + stacked = torch.stack( + (b_lo.to(torch.float16), b_hi.to(torch.float16)), dim=2 + ) + stacked = stacked.permute(0, 2, 1) + b_block = stacked.reshape([block_k, block_n]) + acc = hl.dot(A[tile_m, tile_k], b_block, acc=acc) + + C[tile_m, tile_n] = acc + + return C + + A = torch.randn((M, K), dtype=torch.bfloat16, device=DEVICE) + B_packed = torch.randint(0, 16, (K // 2, N), dtype=torch.int8, device=DEVICE) + C = torch.zeros((M, N), dtype=torch.float32, device=DEVICE) + + matmul_bf16_packed_int4(A, B_packed, C) + torch.cuda.synchronize() + + self.assertTrue(torch.isfinite(C).all()) + self.assertFalse(torch.allclose(C, torch.zeros_like(C))) + def test_matmul_split_k(self): @helion.kernel(dot_precision="ieee") def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: From 71521aca552dc93c0d8d9cd14bdc2beb9e2633dd Mon Sep 17 00:00:00 2001 From: Will Feng Date: Wed, 8 Oct 2025 14:45:30 -0700 Subject: [PATCH 2/2] up --- helion/_compiler/device_function.py | 59 +++++++++++++++++++++++---- helion/_compiler/inductor_lowering.py | 3 +- helion/_compiler/tile_strategy.py | 7 +--- test/test_constexpr.expected | 3 +- test/test_examples.expected | 58 ++++++++++---------------- test/test_indexing.expected | 2 - test/test_loops.expected | 6 ++- test/test_matmul.py | 2 +- test/test_misc.expected | 1 - test/test_reductions.expected | 10 +---- test/test_tensor_descriptor.expected | 8 ++-- test/test_views.expected | 8 ++-- 12 files changed, 91 insertions(+), 76 deletions(-) diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 9bec72f47..a75356095 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -207,6 +207,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: ] = {} self._expr_args: dict[sympy.Expr, SymbolArgument] = {} self._constexpr_args: dict[str, ConstExprArg] = {} + self._constexpr_host_defs: set[str] = set() self._tensor_properties: dict[ tuple[type[TensorPropertyArg], torch.Tensor, int], TensorPropertyArg ] = {} @@ -282,11 +283,7 @@ def block_size_var(self, block_id: int) -> str | None: var_name = self.new_var(f"_BLOCK_SIZE_{block_id}") self.block_size_var_cache[key] = var_name - host_expr = HostFunction.current().literal_expr(block_value) - if self.constexpr_arg(var_name, host_expr): - self.codegen.host_statements.append( - statement_from_string(f"{var_name} = {host_expr}") - ) + self.constexpr_arg_with_host_def(var_name, block_value) return self.block_size_var_cache[key] @@ -484,14 +481,55 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument: self._expr_args[sym] = arg return self._expr_args[sym] - def constexpr_arg(self, name: str, host_str: str | None = None) -> bool: + def constexpr_arg(self, name: str, value: object | None = None) -> bool: """Create a constexpr argument, returns True if created, False if already exists.""" if name in self._constexpr_args: return False - self._constexpr_args[name] = rv = ConstExprArg(name, host_str or name) + host_str = name if value is None else self._format_constexpr_value(value) + self._constexpr_args[name] = rv = ConstExprArg(name, host_str) self.arguments.append(rv) return True + def constexpr_arg_with_host_def(self, name: str, value: object) -> None: + """Create a constexpr argument and add its host-side definition if needed.""" + created = self.constexpr_arg(name, value) + host_expr = self._constexpr_args[name].host_str() + if created or name not in self._constexpr_host_defs: + self.codegen.host_statements.append( + statement_from_string(f"{name} = {host_expr}") + ) + self._constexpr_host_defs.add(name) + + def _format_constexpr_value(self, value: object) -> str: + if isinstance(value, str): + return value + if isinstance(value, (int, float, bool)): + return repr(value) + + # Extract sympy expression from torch symbolic types + if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)): + value = value._sympy_() + + # Handle sympy expressions (sanitize by replacing triton_helpers functions) + if isinstance(value, sympy.Expr): + expr = cast( + "sympy.Expr", + value.replace( + lambda node: isinstance(node, sympy.Function) + and getattr(node.func, "__name__", "") + == "triton_helpers.div_floor_integer", + lambda node: sympy.floor(node.args[0] / node.args[1]), # pyright: ignore[reportAttributeAccessIssue] + ).replace( + lambda node: isinstance(node, sympy.Function) + and getattr(node.func, "__name__", "") + == "triton_helpers.remainder_integer", + lambda node: sympy.Mod(node.args[0], node.args[1]), # pyright: ignore[reportAttributeAccessIssue] + ), + ) + return HostFunction.current().sympy_expr(expr) + + return HostFunction.current().literal_expr(value) + def _tensor_property( self, prop_cls: type[_P], @@ -556,7 +594,12 @@ def codegen_function_def(self) -> list[ast.stmt]: ] def codegen_function_call(self) -> ast.AST: - args = [arg.host_str() for arg in self.sorted_args()] + args = [] + for arg in self.sorted_args(): + if isinstance(arg, ConstExprArg) and arg.name in self._constexpr_host_defs: + args.append(arg.name) + else: + args.append(arg.host_str()) if self.has_rng_ops(): # Pass the host-side seed buffer variable to the kernel diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 2e0a6d919..dc4496c10 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -1241,8 +1241,7 @@ def _create_named_result(self, node: Node, result: ast.expr) -> str: ): # This expression is used in tl.arange, make it a constexpr name = self.cg.device_function.new_var(node.name) - host_expr = self.cg.device_function.sympy_expr(val._sympy_()) - self.cg.device_function.constexpr_arg(name, host_expr) + self.cg.device_function.constexpr_arg(name, val._sympy_()) return name # If the lowering produced a named value that is already defined elsewhere diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index 389a89a5d..bc1e83e03 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -244,12 +244,7 @@ def _setup_block_size_constexpr( self, state: CodegenState, block_size_var: str, block_size: SymIntLike ) -> None: """Helper to setup constexpr block size variable on host.""" - if state.device_function.constexpr_arg(block_size_var): - state.codegen.host_statements.append( - statement_from_string( - f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}" - ) - ) + state.device_function.constexpr_arg_with_host_def(block_size_var, block_size) class BlockSizeTileStrategy(TileStrategy): diff --git a/test/test_constexpr.expected b/test/test_constexpr.expected index dca993a14..c0cfcb515 100644 --- a/test/test_constexpr.expected +++ b/test/test_constexpr.expected @@ -68,8 +68,9 @@ def matmul_int4_block_expr(A: torch.Tensor, B: torch.Tensor, *, _launcher=_defau C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device) _NUM_SM = helion.runtime.get_num_sm(A.device) _BLOCK_SIZE_2 = 16 + _BLOCK_SIZE_1 = 1 _BLOCK_SIZE_0 = 1 - _launcher(_helion_matmul_int4_block_expr, (_NUM_SM,), B, A, C, _NUM_SM, _BLOCK_SIZE_2, 1, 1, 2 * _BLOCK_SIZE_0, num_warps=1, num_stages=8) + _launcher(_helion_matmul_int4_block_expr, (_NUM_SM,), B, A, C, _NUM_SM, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=1, num_stages=8) return C --- assertExpectedJournal(TestConstExpr.test_constexpr_float) diff --git a/test/test_examples.expected b/test/test_examples.expected index a27f62dd8..d6aac9b01 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -156,7 +156,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -244,7 +243,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -325,8 +323,9 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la out = torch.empty_like(q_view) _BLOCK_SIZE_1 = 64 _RDIM_SIZE_2 = 64 + _BLOCK_SIZE_0 = 1 _BLOCK_SIZE_3 = 32 - _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=2) return out.view(q_in.size()) --- assertExpectedJournal(TestExamples.test_attention_persistent_interleaved_l2_grouping) @@ -337,7 +336,6 @@ import helion import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -434,7 +432,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -511,8 +508,9 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la out = torch.empty_like(q_view) _BLOCK_SIZE_1 = 64 _RDIM_SIZE_2 = 64 + _BLOCK_SIZE_0 = 1 _BLOCK_SIZE_3 = 32 - _launcher(_helion_attention, (32 * triton.cdiv(512, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_attention, (32 * triton.cdiv(512, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=2) return out.view(q_in.size()) --- assertExpectedJournal(TestExamples.test_bf16xint16) @@ -767,7 +765,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime.triton_helpers import math as tl_math -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -816,7 +813,8 @@ def cross_entropy(logits: torch.Tensor, labels: torch.Tensor, *, _launcher=_defa losses = torch.zeros([n], dtype=logits.dtype, device=logits.device) logits_flat = logits.view(-1) _RDIM_SIZE_1 = triton.next_power_of_2(v) - _launcher(_helion_cross_entropy, (n,), labels, logits_flat, logits, losses, labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, 1, num_warps=4, num_stages=2) + _BLOCK_SIZE_0 = 1 + _launcher(_helion_cross_entropy, (n,), labels, logits_flat, logits, losses, labels.stride(0), logits.stride(0), logits.stride(1), logits_flat.stride(0), losses.stride(0), v, _RDIM_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) return losses.mean() --- assertExpectedJournal(TestExamples.test_embedding_block_ptr) @@ -945,7 +943,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -983,7 +980,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -1118,7 +1114,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime.triton_helpers import math as tl_math -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -1257,7 +1252,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -1393,7 +1387,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -1526,8 +1519,10 @@ def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, grou assert K == K2 out = torch.zeros(total_M, N, dtype=torch.promote_types(A_packed.dtype, B.dtype), device=A_packed.device) G = group_offsets.size(0) - 1 + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 _BLOCK_SIZE_5 = 16 - _launcher(_helion_grouped_gemm_jagged_persistent, (num_workers,), group_offsets, A_packed, B, out, A_packed.stride(0), A_packed.stride(1), B.stride(0), B.stride(1), group_offsets.stride(0), out.stride(0), out.stride(1), num_workers, G, N, K, 32, 32, _BLOCK_SIZE_5, num_warps=4, num_stages=2) + _launcher(_helion_grouped_gemm_jagged_persistent, (num_workers,), group_offsets, A_packed, B, out, A_packed.stride(0), A_packed.stride(1), B.stride(0), B.stride(1), group_offsets.stride(0), out.stride(0), out.stride(1), num_workers, G, N, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_5, num_warps=4, num_stages=2) return out --- assertExpectedJournal(TestExamples.test_int4_gemm) @@ -1792,7 +1787,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -2101,7 +2095,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -2326,7 +2319,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime.triton_helpers import math as tl_math -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -2460,7 +2452,8 @@ def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None if n_non_ignore == 0: return (torch.zeros([], dtype=_input.dtype, device=_input.device), torch.zeros_like(_input)) _BLOCK_SIZE_1 = 4096 - _launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), loss.stride(0), target.stride(0), target.stride(1), BT, V, beta, one_minus_beta, n_non_ignore, _BLOCK_SIZE_1, 1, num_warps=4, num_stages=3) + _BLOCK_SIZE_0 = 1 + _launcher(_helion_jsd_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), _input, target, loss, dX, _input.stride(0), _input.stride(1), dX.stride(0), loss.stride(0), target.stride(0), target.stride(1), BT, V, beta, one_minus_beta, n_non_ignore, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) final_loss = torch.sum(loss) return (final_loss, dX) @@ -2472,7 +2465,6 @@ import triton import triton.language as tl from torch._inductor.runtime import triton_helpers from torch._inductor.runtime.triton_helpers import math as tl_math -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -2538,7 +2530,8 @@ def kl_div_forward(y_pred: Tensor, y_true: Tensor, log_target: bool=False, reduc else: loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device) _BLOCK_SIZE_1 = 4096 - _launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, loss, loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, 1, num_warps=4, num_stages=3) + _BLOCK_SIZE_0 = 1 + _launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, loss, loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) if reduction == 'batchmean': final_loss = torch.sum(loss) / BT elif reduction == 'sum': @@ -2752,7 +2745,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -2823,7 +2815,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -2898,7 +2889,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -3149,7 +3139,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -3224,7 +3213,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -3630,7 +3618,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -3690,7 +3677,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher import helion._testing.segment_reduction as _source_module @@ -3774,7 +3760,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -3804,7 +3789,8 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher): n, _m = x.size() out = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(_m) - _launcher(_helion_softmax, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, 1, num_warps=4, num_stages=1) + _BLOCK_SIZE_0 = 1 + _launcher(_helion_softmax, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) return out --- assertExpectedJournal(TestExamples.test_softmax_decomposed) @@ -3813,7 +3799,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -3844,7 +3829,8 @@ def softmax_decomposed(x: torch.Tensor, *, _launcher=_default_launcher): n, _m = x.size() out = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(_m) - _launcher(_helion_softmax_decomposed, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, 1, num_warps=4, num_stages=1) + _BLOCK_SIZE_0 = 1 + _launcher(_helion_softmax_decomposed, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) return out --- assertExpectedJournal(TestExamples.test_softmax_looped) @@ -3854,7 +3840,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -3904,7 +3889,8 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher): n, _m = x.size() out = torch.empty_like(x) _REDUCTION_BLOCK_1 = 32 - _launcher(_helion_softmax, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, 1, num_warps=4, num_stages=1) + _BLOCK_SIZE_0 = 1 + _launcher(_helion_softmax, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, _BLOCK_SIZE_0, num_warps=4, num_stages=1) return out --- assertExpectedJournal(TestExamples.test_softmax_two_pass) @@ -3914,7 +3900,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -3984,7 +3969,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -4081,8 +4065,9 @@ def sum_kernel(x: torch.Tensor, *, _launcher=_default_launcher): """ m, n = x.shape out = torch.empty([m], dtype=x.dtype, device=x.device) + _BLOCK_SIZE_0 = 1 _REDUCTION_BLOCK_1 = 32768 - _launcher(_helion_sum_kernel, (m,), x, out, out.stride(0), x.stride(0), x.stride(1), n, 1, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) + _launcher(_helion_sum_kernel, (m,), x, out, out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) return out --- assertExpectedJournal(TestExamples.test_swiglu) @@ -4318,7 +4303,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit diff --git a/test/test_indexing.expected b/test/test_indexing.expected index 72974dd4d..37d0b6aa6 100644 --- a/test/test_indexing.expected +++ b/test/test_indexing.expected @@ -191,7 +191,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -227,7 +226,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit diff --git a/test/test_loops.expected b/test/test_loops.expected index d41f935ea..11f3b0e4a 100644 --- a/test/test_loops.expected +++ b/test/test_loops.expected @@ -132,7 +132,8 @@ def device_loop_3d(x: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_0 = 2 _BLOCK_SIZE_2 = 4 _BLOCK_SIZE_1 = 8 - _launcher(_helion_device_loop_3d, (triton.cdiv(a, _BLOCK_SIZE_0),), x, out, out.size(0), out.size(1), out.size(2), out.size(3), x.size(0), x.size(1), x.size(2), x.size(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), b, c, d, _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, 1, num_warps=4, num_stages=2) + _BLOCK_SIZE_3 = 1 + _launcher(_helion_device_loop_3d, (triton.cdiv(a, _BLOCK_SIZE_0),), x, out, out.size(0), out.size(1), out.size(2), out.size(3), x.size(0), x.size(1), x.size(2), x.size(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), b, c, d, _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) return out --- assertExpectedJournal(TestLoops.test_chebyshev_polynomials) @@ -906,11 +907,12 @@ def _helion_nested_loop_accumulator(x, out, out_stride_0, out_stride_1, out_stri def nested_loop_accumulator(x: torch.Tensor, *, _launcher=_default_launcher): B, N, M = x.size() out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 1 _BLOCK_SIZE_1 = 2 _BLOCK_SIZE_2 = 4 _BLOCK_SIZE_3 = 2 _BLOCK_SIZE_4 = 4 - _launcher(_helion_nested_loop_accumulator, (B,), x, out, out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), N, M, 1, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=2) + _launcher(_helion_nested_loop_accumulator, (B,), x, out, out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), N, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=2) return out --- assertExpectedJournal(TestLoops.test_pointwise_device_loop) diff --git a/test/test_matmul.py b/test/test_matmul.py index 55270bf25..4c93ebf4f 100644 --- a/test/test_matmul.py +++ b/test/test_matmul.py @@ -227,7 +227,7 @@ def matmul_bf16_packed_int4( b_lo = (b_tile << shift) >> shift b_hi = b_tile >> shift stacked = torch.stack( - (b_lo.to(torch.float16), b_hi.to(torch.float16)), dim=2 + (b_lo.to(A.dtype), b_hi.to(A.dtype)), dim=2 ) stacked = stacked.permute(0, 2, 1) b_block = stacked.reshape([block_k, block_n]) diff --git a/test/test_misc.expected b/test/test_misc.expected index d44823cd0..5e3b1e016 100644 --- a/test/test_misc.expected +++ b/test/test_misc.expected @@ -150,7 +150,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit diff --git a/test/test_reductions.expected b/test/test_reductions.expected index b4a08387b..cf046d6a0 100644 --- a/test/test_reductions.expected +++ b/test/test_reductions.expected @@ -58,8 +58,9 @@ def _helion_reduce_kernel(x, out, out_size_0, x_size_0, x_size_1, out_stride_0, def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32, *, _launcher=_default_launcher): n, _m = x.size() out = torch.empty([n], dtype=out_dtype, device=x.device) + _BLOCK_SIZE_0 = 1 _REDUCTION_BLOCK_1 = 16 - _launcher(_helion_reduce_kernel, (n,), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, 1, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) + _launcher(_helion_reduce_kernel, (n,), x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=2) return out --- assertExpectedJournal(TestReductions.test_broken_layernorm) @@ -68,7 +69,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -120,7 +120,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -148,7 +147,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime.triton_helpers import math as tl_math -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -210,7 +208,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -263,7 +260,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -332,7 +328,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -477,7 +472,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit diff --git a/test/test_tensor_descriptor.expected b/test/test_tensor_descriptor.expected index e51b79c75..81e2a1d76 100644 --- a/test/test_tensor_descriptor.expected +++ b/test/test_tensor_descriptor.expected @@ -8,7 +8,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -89,8 +88,9 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la out = torch.empty_like(q_view) _BLOCK_SIZE_1 = 16 _RDIM_SIZE_2 = 64 + _BLOCK_SIZE_0 = 1 _BLOCK_SIZE_3 = 16 - _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_attention, (q_in.size(1) * triton.cdiv(m_dim, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, q_in.size(1), k_view.stride(0), k_view.stride(1), k_view.stride(2), out.stride(0), out.stride(1), out.stride(2), q_view.stride(0), q_view.stride(1), q_view.stride(2), v_view.stride(0), v_view.stride(1), v_view.stride(2), m_dim, n_dim, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=2) return out.view(q_in.size()) --- assertExpectedJournal(TestTensorDescriptor.test_attention_tensor_descriptor) @@ -100,7 +100,6 @@ import torch import triton import triton.language as tl from torch._inductor.runtime import triton_helpers -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -181,6 +180,7 @@ def attention(q_in: torch.Tensor, k_in: torch.Tensor, v_in: torch.Tensor, *, _la out = torch.empty_like(q_view) _BLOCK_SIZE_1 = 128 _RDIM_SIZE_2 = 64 + _BLOCK_SIZE_0 = 1 _BLOCK_SIZE_3 = 64 - _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, 1, _BLOCK_SIZE_3, num_warps=4, num_stages=2) + _launcher(_helion_attention, (64 * triton.cdiv(1024, _BLOCK_SIZE_1),), q_view, k_view, v_view, out, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=4, num_stages=2) return out.view(q_in.size()) diff --git a/test/test_views.expected b/test/test_views.expected index a5d73fb14..182bd1e94 100644 --- a/test/test_views.expected +++ b/test/test_views.expected @@ -44,7 +44,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -72,7 +71,8 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher): n, _m = x.size() out = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(_m) - _launcher(_helion_softmax, (n,), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, 1, num_warps=4, num_stages=2) + _BLOCK_SIZE_0 = 1 + _launcher(_helion_softmax, (n,), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) return out --- assertExpectedJournal(TestViews.test_softmax_view_reshape) @@ -81,7 +81,6 @@ from __future__ import annotations import torch import triton import triton.language as tl -from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit @@ -109,7 +108,8 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher): n, _m = x.size() out = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(_m) - _launcher(_helion_softmax, (n,), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, 1, num_warps=4, num_stages=2) + _BLOCK_SIZE_0 = 1 + _launcher(_helion_softmax, (n,), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=2) return out --- assertExpectedJournal(TestViews.test_squeeze)