From 984c0fc8009fd473fc23c05816b245bdb0485904 Mon Sep 17 00:00:00 2001
From: Oguz Ulgen
Date: Wed, 24 Sep 2025 22:05:48 +0000
Subject: [PATCH] [RFC] Add support for device for-loop indexing

This PR was mostly vibe-coded via Cursor using GPT-5, but it has been
heavily modified for correctness and simplicity.

Fixes #598

stack-info: PR: https://github.com/pytorch/helion/pull/673, branch: oulgen/stack/99
---
 helion/_compiler/device_ir.py        | 31 ++++++++++
 helion/_compiler/type_propagation.py | 93 +++++++++++++++++++++++++++-
 test/test_type_propagation.expected  | 68 ++++++++++++++++++++
 test/test_type_propagation.py        | 34 ++++++++++
 4 files changed, 225 insertions(+), 1 deletion(-)

diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py
index 8b60c9f5a..e5833ced9 100644
--- a/helion/_compiler/device_ir.py
+++ b/helion/_compiler/device_ir.py
@@ -590,6 +590,32 @@ def visit_For(self, node: ast.For) -> None:
             self._assign(node.target, inner_type.proxy())
             self._body(node.body)
         elif node._loop_type == LoopType.DEVICE:
+            # Try static unrolling when begin/end are compile-time ints
+            begin, end = self._extract_tile_begin_end(node)
+            if isinstance(inner_type, SequenceType):
+                iter_vars = inner_type.unpack()
+                if begin is None:
+                    begin_list = [0] * len(iter_vars)
+                else:
+                    begin_list = begin if isinstance(begin, (list, tuple)) else [begin]
+                end_list = end if isinstance(end, (list, tuple)) else [end]
+                try_static = all(
+                    isinstance(b, int) and isinstance(e, int)
+                    for b, e in zip(begin_list, end_list, strict=True)
+                )
+                if try_static:
+                    # Assign inner proxy to target and then unroll nested ranges over scalar indices
+                    self._assign(node.target, inner_type.proxy())
+                    self._body(node.body)
+                    return
+            else:
+                # 1D case
+                b0 = 0 if begin is None else begin
+                if isinstance(b0, int) and isinstance(end, int):
+                    for iv in range(b0, end):
+                        self._assign(node.target, iv)
+                        self._body(node.body)
+                    return
             rw: ReadWrites = ReadWrites.from_ast(node)
             inputs: LiftTensorArgs = LiftTensorArgs(
                 {
@@ -947,6 +973,11 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
         assert isinstance(value, ExtendedAST)
         type_info = value._type_info
         if isinstance(type_info, SequenceType):
+            index_val = self.visit(node.slice)
+            if isinstance(index_val, int):
+                sequence_val = self.visit(value)
+                assert isinstance(sequence_val, (list, tuple))
+                return sequence_val[index_val]
             if isinstance(node.slice, ast.Constant):
                 return self.visit(value)[self.visit(node.slice)]  # pyright: ignore[reportIndexIssue]
             raise exc.InvalidSequenceSubscription(node.slice)
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index 5bae0d3ea..cd91ee455 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -1268,7 +1268,18 @@ def populate_symbol_origins(self, origin: Origin) -> None:
             subtype.populate_symbol_origins(GetItemOrigin(origin, i))
 
     def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
-        return super().propagate_getitem(key, origin)
+        # Try literal indexing first
+        try:
+            return super().propagate_getitem(key, origin)
+        except exc.TypeInferenceError:
+            # If indexing with a symbolic/grid index on device and the sequence length is known,
+            # conservatively merge all possible element types.
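+            # Illustrative example (assuming a kernel argument such as
+            # As: list[torch.Tensor]): when key is a device-loop index,
+            # As[key] cannot be resolved to one literal element, so the result
+            # is element_types[0].merge(element_types[1]).merge(...), i.e. a
+            # single type that is compatible with every element of the sequence.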
+            if origin.is_device() and isinstance(key, (SymIntType, GridIndexType)):
+                merged: TypeInfo = self.element_types[0]
+                for candidate in self.element_types[1:]:
+                    merged = merged.merge(candidate)
+                return merged
+            raise
 
     def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo:
         if isinstance(other, SequenceType):
@@ -2161,6 +2172,86 @@ def visit_For(self, node: ast.For) -> TypeInfo:
                 raise exc.NestedGridLoop
 
         self.device_loop_depth += device_loop
+
+        # Try static unrolling for device grid loops when iteration count is known
+        try:
+            if node._loop_type != LoopType.HOST and isinstance(node.iter, ast.Call):
+                call_node = node.iter
+                # Extract begin, end, step; support only 1D grid here
+                begin_val: int | None
+                end_val: int | None
+                step_val: int | None
+
+                if len(call_node.args) == 1:
+                    begin_val = 0
+                    end_type = self.visit(call_node.args[0])
+                    step_type: TypeInfo | None = None
+                else:
+                    begin_type = self.visit(call_node.args[0])
+                    end_type = self.visit(call_node.args[1])
+                    step_type = (
+                        self.visit(call_node.args[2])
+                        if len(call_node.args) >= 3
+                        else None
+                    )
+                    begin_val = (
+                        begin_type.as_literal() if begin_type.is_literal() else None
+                    )  # type: ignore[assignment]
+
+                for kw in call_node.keywords:
+                    if kw.arg == "step" and step_type is None:
+                        step_type = self.visit(kw.value)
+
+                end_val = end_type.as_literal() if end_type.is_literal() else None  # type: ignore[assignment]
+                step_val = (
+                    step_type.as_literal()
+                    if (step_type is not None and step_type.is_literal())
+                    else 1
+                )  # type: ignore[assignment]
+
+                if (
+                    isinstance(begin_val, int)
+                    and isinstance(end_val, int)
+                    and isinstance(step_val, int)
+                ):
+                    # Build concrete iteration values
+                    iter_values = list(range(begin_val, end_val, step_val))
+                    # Small guard to avoid excessive compile-time blowups
+                    if len(iter_values) <= 64:
+                        merged_scope: LocalScope | None = None
+                        for iv in iter_values:
+                            # Emulate _loop_body with loop index bound to a literal
+                            self.push_scope()
+                            self._assign(node.target, LiteralType(self.origin(), iv))
+                            exit_scopes = [self.scope]
+                            for stmt in node.body:
+                                self.visit(stmt)
+                                if isinstance(stmt, (ast.Break, ast.Continue)):
+                                    exit_scopes.append(self.scope.clone())
+                            # Reset loop variable back to its GridIndexType to avoid control-flow merging issues
+                            self._assign(
+                                node.target,
+                                iter_type.propagate_iter(self.origin()),
+                            )
+                            self.pop_scope()
+                            iter_scope = functools.reduce(
+                                lambda x, y: x.merge(y), exit_scopes
+                            )
+                            if merged_scope is None:
+                                merged_scope = iter_scope
+                            else:
+                                merged_scope.merge(iter_scope)
+
+                        if merged_scope is not None:
+                            body = merged_scope
+                            orelse = self._body(node.orelse)
+                            self.scope.merge_if_else(body, orelse)
+                            self.device_loop_depth -= device_loop
+                            return NoType(origin=self.origin())
+        except NotImplementedError:
+            # Fall back to generic handling if we can't statically determine iterations
+            pass
+
         body = self._loop_body(node.body)
         with self.swap_scope(body):
             # second pass for fixed point
diff --git a/test/test_type_propagation.expected b/test/test_type_propagation.expected
index 2a103e16f..1e5cd02c3 100644
--- a/test/test_type_propagation.expected
+++ b/test/test_type_propagation.expected
@@ -537,6 +537,74 @@ def root_graph_2():
     _for_loop = helion_language__tracing_ops__for_loop(1, [0], [x_size0], []); x_size0 = _for_loop = None
     return None
 
+--- assertExpectedJournal(TestTypePropagation.test_for_loop_indexing_in_device_code0)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel(out, As_item_0, As_item_1, As_item_2, As_item_3, out_size_0, As_item_0_stride_0, As_item_1_stride_0, As_item_2_stride_0, As_item_3_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < out_size_0
+    load = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    load_1 = tl.load(As_item_0 + indices_0 * As_item_0_stride_0, mask_0, other=0)
+    v_0 = load + load_1
+    tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
+    load_2 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    load_3 = tl.load(As_item_1 + indices_0 * As_item_1_stride_0, mask_0, other=0)
+    v_1 = load_2 + load_3
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+    load_4 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    load_5 = tl.load(As_item_2 + indices_0 * As_item_2_stride_0, mask_0, other=0)
+    v_2 = load_4 + load_5
+    tl.store(out + indices_0 * out_stride_0, v_2, mask_0)
+    load_6 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    load_7 = tl.load(As_item_3 + indices_0 * As_item_3_stride_0, mask_0, other=0)
+    v_3 = load_6 + load_7
+    tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
+
+def kernel(As: list[torch.Tensor], *, _launcher=_default_launcher):
+    out = torch.zeros_like(As[0])
+    _BLOCK_SIZE_0 = 16
+    _launcher(_helion_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, As[0], As[1], As[2], As[3], out.size(0), As[0].stride(0), As[1].stride(0), As[2].stride(0), As[3].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+--- assertExpectedJournal(TestTypePropagation.test_for_loop_indexing_in_device_code1)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_kernel(out, As_item_0, As_item_1, As_item_2, As_item_3, out_size_0, As_item_0_stride_0, As_item_1_stride_0, As_item_2_stride_0, As_item_3_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < out_size_0
+    acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    load = tl.load(As_item_0 + indices_0 * As_item_0_stride_0, mask_0, other=0)
+    v_0 = acc + load
+    load_1 = tl.load(As_item_1 + indices_0 * As_item_1_stride_0, mask_0, other=0)
+    v_1 = v_0 + load_1
+    load_2 = tl.load(As_item_2 + indices_0 * As_item_2_stride_0, mask_0, other=0)
+    v_2 = v_1 + load_2
+    load_3 = tl.load(As_item_3 + indices_0 * As_item_3_stride_0, mask_0, other=0)
+    v_3 = v_2 + load_3
+    tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
+
+def kernel(As: list[torch.Tensor], *, _launcher=_default_launcher):
+    out = torch.zeros_like(As[0])
+    _BLOCK_SIZE_0 = 16
+    _launcher(_helion_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, As[0], As[1], As[2], As[3], out.size(0), As[0].stride(0), As[1].stride(0), As[2].stride(0), As[3].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestTypePropagation.test_hl_full_usage)
 def hl_full_usage(x: torch.Tensor):
     # Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=)
diff --git a/test/test_type_propagation.py b/test/test_type_propagation.py
index 79fa3a6b6..2c1530e6e 100644
--- a/test/test_type_propagation.py
+++ b/test/test_type_propagation.py
@@ -8,8 +8,10 @@
 import helion
 from helion import exc
+from helion._testing import DEVICE
 from helion._testing import RefEagerTestDisabled
 from helion._testing import TestCase
+from helion._testing import code_and_output
 from helion._testing import import_path
 import helion.language as hl
@@ -132,6 +134,38 @@ def use_unsupported_property(x: torch.Tensor) -> torch.Tensor:
         ):
             type_propagation_report(use_unsupported_property, x)
 
+    def test_for_loop_indexing_in_device_code0(self):
+        @helion.kernel
+        def kernel(As: list[torch.Tensor]) -> torch.Tensor:
+            out = torch.zeros_like(As[0])
+            for tile in hl.tile(out.size()):
+                for i in range(len(As)):
+                    a = As[i]
+                    out[tile] += a[tile]
+            return out
+
+        args = [torch.randn(16, device=DEVICE) for _ in range(4)]
+        code, result = code_and_output(kernel, (args,))
+        torch.testing.assert_close(result, sum(args))
+        self.assertExpectedJournal(code)
+
+    def test_for_loop_indexing_in_device_code1(self):
+        @helion.kernel
+        def kernel(As: list[torch.Tensor]) -> torch.Tensor:
+            out = torch.zeros_like(As[0])
+            for tile in hl.tile(out.size()):
+                acc = hl.zeros(tile)
+                for i in range(len(As)):
+                    a = As[i]
+                    acc = acc + a[tile]
+                out[tile] = acc
+            return out
+
+        args = [torch.randn(16, device=DEVICE) for _ in range(4)]
+        code, result = code_and_output(kernel, (args,))
+        torch.testing.assert_close(result, sum(args))
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
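
Usage sketch mirroring test_for_loop_indexing_in_device_code1 above (the kernel
name sum_list and the "cuda" device string are illustrative; the tests use
helion._testing.DEVICE instead):

    import torch

    import helion
    import helion.language as hl


    @helion.kernel
    def sum_list(As: list[torch.Tensor]) -> torch.Tensor:
        # Elementwise sum of a Python list of same-shaped tensors.
        out = torch.zeros_like(As[0])
        for tile in hl.tile(out.size()):
            acc = hl.zeros(tile)
            # len(As) is a compile-time constant, so this device-side loop is
            # statically unrolled and each As[i] resolves to a concrete tensor.
            for i in range(len(As)):
                a = As[i]
                acc = acc + a[tile]
            out[tile] = acc
        return out


    tensors = [torch.randn(16, device="cuda") for _ in range(4)]
    result = sum_list(tensors)
    torch.testing.assert_close(result, sum(tensors))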