Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,32 @@ def visit_For(self, node: ast.For) -> None:
self._assign(node.target, inner_type.proxy())
self._body(node.body)
elif node._loop_type == LoopType.DEVICE:
# Try static unrolling when begin/end are compile-time ints
begin, end = self._extract_tile_begin_end(node)
if isinstance(inner_type, SequenceType):
iter_vars = inner_type.unpack()
if begin is None:
begin_list = [0] * len(iter_vars)
else:
begin_list = begin if isinstance(begin, (list, tuple)) else [begin]
end_list = end if isinstance(end, (list, tuple)) else [end]
try_static = all(
isinstance(b, int) and isinstance(e, int)
for b, e in zip(begin_list, end_list, strict=True)
)
if try_static:
# Assign inner proxy to target and then unroll nested ranges over scalar indices
self._assign(node.target, inner_type.proxy())
self._body(node.body)
return
else:
# 1D case
b0 = 0 if begin is None else begin
if isinstance(b0, int) and isinstance(end, int):
for iv in range(b0, end):
self._assign(node.target, iv)
self._body(node.body)
return
rw: ReadWrites = ReadWrites.from_ast(node)
inputs: LiftTensorArgs = LiftTensorArgs(
{
Expand Down Expand Up @@ -947,6 +973,11 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
assert isinstance(value, ExtendedAST)
type_info = value._type_info
if isinstance(type_info, SequenceType):
index_val = self.visit(node.slice)
if isinstance(index_val, int):
sequence_val = self.visit(value)
assert isinstance(sequence_val, (list, tuple))
return sequence_val[index_val]
if isinstance(node.slice, ast.Constant):
return self.visit(value)[self.visit(node.slice)] # pyright: ignore[reportIndexIssue]
raise exc.InvalidSequenceSubscription(node.slice)
Expand Down
93 changes: 92 additions & 1 deletion helion/_compiler/type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,18 @@ def populate_symbol_origins(self, origin: Origin) -> None:
subtype.populate_symbol_origins(GetItemOrigin(origin, i))

def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
    """Return the type produced by subscripting this sequence with ``key``.

    The parent implementation is tried first. If it raises
    ``exc.TypeInferenceError`` and the subscript is a ``SymIntType`` or
    ``GridIndexType`` key whose ``origin`` is on device, the result is the
    merge of every entry in ``self.element_types``, since any of them may
    be selected.

    Args:
        key: Type of the subscript expression.
        origin: Origin of the subscript; ``origin.is_device()`` gates the
            merged-type fallback.

    Returns:
        The parent's result when it succeeds, otherwise the merged element
        type.

    Raises:
        exc.TypeInferenceError: Re-raised from the parent when the fallback
            does not apply, or when there are no element types to merge.
    """
    try:
        return super().propagate_getitem(key, origin)
    except exc.TypeInferenceError:
        dynamic_device_index = origin.is_device() and isinstance(
            key, (SymIntType, GridIndexType)
        )
        # With no element types there is nothing to merge; surface the
        # original inference error rather than an IndexError.
        if not dynamic_device_index or not self.element_types:
            raise
        merged: TypeInfo = self.element_types[0]
        for candidate in self.element_types[1:]:
            merged = merged.merge(candidate)
        return merged

def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo:
if isinstance(other, SequenceType):
Expand Down Expand Up @@ -2161,6 +2172,86 @@ def visit_For(self, node: ast.For) -> TypeInfo:
raise exc.NestedGridLoop

self.device_loop_depth += device_loop

# Try static unrolling for device grid loops when iteration count is known
try:
if node._loop_type != LoopType.HOST and isinstance(node.iter, ast.Call):
call_node = node.iter
# Extract begin, end, step; support only 1D grid here
begin_val: int | None
end_val: int | None
step_val: int | None

if len(call_node.args) == 1:
begin_val = 0
end_type = self.visit(call_node.args[0])
step_type: TypeInfo | None = None
else:
begin_type = self.visit(call_node.args[0])
end_type = self.visit(call_node.args[1])
step_type = (
self.visit(call_node.args[2])
if len(call_node.args) >= 3
else None
)
begin_val = (
begin_type.as_literal() if begin_type.is_literal() else None
) # type: ignore[assignment]

for kw in call_node.keywords:
if kw.arg == "step" and step_type is None:
step_type = self.visit(kw.value)

end_val = end_type.as_literal() if end_type.is_literal() else None # type: ignore[assignment]
step_val = (
step_type.as_literal()
if (step_type is not None and step_type.is_literal())
else 1
) # type: ignore[assignment]

if (
isinstance(begin_val, int)
and isinstance(end_val, int)
and isinstance(step_val, int)
):
# Build concrete iteration values
iter_values = list(range(begin_val, end_val, step_val))
# Small guard to avoid excessive compile-time blowups
if len(iter_values) <= 64:
merged_scope: LocalScope | None = None
for iv in iter_values:
# Emulate _loop_body with loop index bound to a literal
self.push_scope()
self._assign(node.target, LiteralType(self.origin(), iv))
exit_scopes = [self.scope]
for stmt in node.body:
self.visit(stmt)
if isinstance(stmt, (ast.Break, ast.Continue)):
exit_scopes.append(self.scope.clone())
# Reset loop variable back to its GridIndexType to avoid control-flow merging issues
self._assign(
node.target,
iter_type.propagate_iter(self.origin()),
)
self.pop_scope()
iter_scope = functools.reduce(
lambda x, y: x.merge(y), exit_scopes
)
if merged_scope is None:
merged_scope = iter_scope
else:
merged_scope.merge(iter_scope)

if merged_scope is not None:
body = merged_scope
orelse = self._body(node.orelse)
self.scope.merge_if_else(body, orelse)
self.device_loop_depth -= device_loop
return NoType(origin=self.origin())
except NotImplementedError:
# Fall back to generic handling if we can't statically determine iterations
pass

body = self._loop_body(node.body)
with self.swap_scope(body):
# second pass for fixed point
Expand Down
68 changes: 68 additions & 0 deletions test/test_type_propagation.expected
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,74 @@ def root_graph_2():
_for_loop = helion_language__tracing_ops__for_loop(1, [0], [x_size0], []); x_size0 = _for_loop = None
return None

--- assertExpectedJournal(TestTypePropagation.test_for_loop_indexing_in_device_code0)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel(out, As_item_0, As_item_1, As_item_2, As_item_3, out_size_0, As_item_0_stride_0, As_item_1_stride_0, As_item_2_stride_0, As_item_3_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < out_size_0
load = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
load_1 = tl.load(As_item_0 + indices_0 * As_item_0_stride_0, mask_0, other=0)
v_0 = load + load_1
tl.store(out + indices_0 * out_stride_0, v_0, mask_0)
load_2 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
load_3 = tl.load(As_item_1 + indices_0 * As_item_1_stride_0, mask_0, other=0)
v_1 = load_2 + load_3
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
load_4 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
load_5 = tl.load(As_item_2 + indices_0 * As_item_2_stride_0, mask_0, other=0)
v_2 = load_4 + load_5
tl.store(out + indices_0 * out_stride_0, v_2, mask_0)
load_6 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
load_7 = tl.load(As_item_3 + indices_0 * As_item_3_stride_0, mask_0, other=0)
v_3 = load_6 + load_7
tl.store(out + indices_0 * out_stride_0, v_3, mask_0)

def kernel(As: list[torch.Tensor], *, _launcher=_default_launcher):
out = torch.zeros_like(As[0])
_BLOCK_SIZE_0 = 16
_launcher(_helion_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, As[0], As[1], As[2], As[3], out.size(0), As[0].stride(0), As[1].stride(0), As[2].stride(0), As[3].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestTypePropagation.test_for_loop_indexing_in_device_code1)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel(out, As_item_0, As_item_1, As_item_2, As_item_3, out_size_0, As_item_0_stride_0, As_item_1_stride_0, As_item_2_stride_0, As_item_3_stride_0, out_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < out_size_0
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
load = tl.load(As_item_0 + indices_0 * As_item_0_stride_0, mask_0, other=0)
v_0 = acc + load
load_1 = tl.load(As_item_1 + indices_0 * As_item_1_stride_0, mask_0, other=0)
v_1 = v_0 + load_1
load_2 = tl.load(As_item_2 + indices_0 * As_item_2_stride_0, mask_0, other=0)
v_2 = v_1 + load_2
load_3 = tl.load(As_item_3 + indices_0 * As_item_3_stride_0, mask_0, other=0)
v_3 = v_2 + load_3
tl.store(out + indices_0 * out_stride_0, v_3, mask_0)

def kernel(As: list[torch.Tensor], *, _launcher=_default_launcher):
out = torch.zeros_like(As[0])
_BLOCK_SIZE_0 = 16
_launcher(_helion_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, As[0], As[1], As[2], As[3], out.size(0), As[0].stride(0), As[1].stride(0), As[2].stride(0), As[3].stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestTypePropagation.test_hl_full_usage)
def hl_full_usage(x: torch.Tensor):
# Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=<SourceLocation basic_kernels.py:38>)
Expand Down
34 changes: 34 additions & 0 deletions test/test_type_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

import helion
from helion import exc
from helion._testing import DEVICE
from helion._testing import RefEagerTestDisabled
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import import_path
import helion.language as hl

Expand Down Expand Up @@ -132,6 +134,38 @@ def use_unsupported_property(x: torch.Tensor) -> torch.Tensor:
):
type_propagation_report(use_unsupported_property, x)

# Exercises indexing a list of tensors with a Python loop variable inside an
# hl.tile loop, accumulating in place into `out`.
def test_for_loop_indexing_in_device_code0(self):
@helion.kernel
def kernel(As: list[torch.Tensor]) -> torch.Tensor:
out = torch.zeros_like(As[0])
for tile in hl.tile(out.size()):
# As[i] is subscripted with the loop variable `i`, not a constant.
for i in range(len(As)):
a = As[i]
out[tile] += a[tile]
return out

# Four random inputs of 16 elements each, created on DEVICE.
args = [torch.randn(16, device=DEVICE) for _ in range(4)]
code, result = code_and_output(kernel, (args,))
# The kernel result must match the elementwise sum of all inputs.
torch.testing.assert_close(result, sum(args))
# `code` is checked against the stored expected journal, so the kernel source
# above must stay exactly as written.
self.assertExpectedJournal(code)

# Exercises indexing a list of tensors with a Python loop variable inside an
# hl.tile loop, accumulating into a local `acc` that is stored to `out` once
# per tile.
def test_for_loop_indexing_in_device_code1(self):
@helion.kernel
def kernel(As: list[torch.Tensor]) -> torch.Tensor:
out = torch.zeros_like(As[0])
for tile in hl.tile(out.size()):
# Per-tile accumulator created with hl.zeros.
acc = hl.zeros(tile)
# As[i] is subscripted with the loop variable `i`, not a constant.
for i in range(len(As)):
a = As[i]
acc = acc + a[tile]
out[tile] = acc
return out

# Four random inputs of 16 elements each, created on DEVICE.
args = [torch.randn(16, device=DEVICE) for _ in range(4)]
code, result = code_and_output(kernel, (args,))
# The kernel result must match the elementwise sum of all inputs.
torch.testing.assert_close(result, sum(args))
# `code` is checked against the stored expected journal, so the kernel source
# above (including the name `acc`) must stay exactly as written.
self.assertExpectedJournal(code)


if __name__ == "__main__":
unittest.main()
Loading