1 change: 1 addition & 0 deletions helion/_compiler/compile_environment.py
@@ -112,6 +112,7 @@ def __init__(
collections.Counter()
)
self.specialized_vars: set[sympy.Symbol] = set()
self.specialized_strides: set[tuple[str, int]] = set()
self.loop_dependency_checker = LoopDependencyChecker()
self._symint_cache: dict[object, torch.SymInt] = {}
self.device_load_count = (
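Note: entries in specialized_strides are (argument name, dimension) pairs; they are recorded in helion/language/constexpr.py below and consulted in device_function.py. A minimal illustration of the bookkeeping (hypothetical values, not the real CompileEnvironment object):

    # After hl.specialize(x.stride(0)) in a kernel whose parameter is named "x",
    # the environment records the pair ("x", 0):
    specialized_strides: set[tuple[str, int]] = set()
    specialized_strides.add(("x", 0))  # (local argument name, stride dimension)
    assert ("x", 0) in specialized_strides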
19 changes: 16 additions & 3 deletions helion/_compiler/device_function.py
@@ -14,6 +14,7 @@

import sympy
import torch
from torch._dynamo.source import LocalSource
from torch._inductor.codegen.triton import TritonPrinter
from torch.fx.graph import _Namespace

@@ -602,11 +603,23 @@ def tensor_size(self, fake_value: torch.Tensor, dim: int) -> Argument:
return self._tensor_property(TensorSizeArg, fake_value, dim, "size")

def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument:
v = fake_value.stride(dim)
env = CompileEnvironment.current()
# Check if this stride was explicitly specialized
source = env.input_sources.get(fake_value)
if (
isinstance(v := fake_value.stride(dim), int)
and CompileEnvironment.current().settings.static_shapes
isinstance(source, LocalSource)
and (source.local_name, dim) in env.specialized_strides
):
return StaticShape(v)
return StaticShape(int(v))
if isinstance(v, int):
if env.settings.static_shapes:
return StaticShape(v)
else:
# Check if all free symbols are specialized
syms = v._sympy_().free_symbols
if syms and syms <= env.specialized_vars:
return StaticShape(int(v))
return self._tensor_property(TensorStrideArg, fake_value, dim, "stride")

def sorted_args(self) -> list[Argument]:
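For review purposes, a simplified standalone sketch of the decision the new tensor_stride makes; it is illustrative only (the real code returns StaticShape versus a TensorStrideArg kernel argument, and the symbol bookkeeping lives in CompileEnvironment):

    import sympy

    def stride_is_constant(stride, explicitly_specialized, static_shapes, specialized_vars):
        # 1) hl.specialize(x.stride(dim)) was called on this exact (argument, dim) pair
        if explicitly_specialized:
            return True
        # 2) plain int stride: folded to a constant only under static_shapes
        if isinstance(stride, int):
            return static_shapes
        # 3) symbolic stride whose free symbols were all specialized (e.g. s0*s1)
        syms = stride.free_symbols
        return bool(syms) and syms <= specialized_vars

    s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
    assert stride_is_constant(s0 * s1, False, False, {s0, s1})   # inlined as a constant
    assert not stride_is_constant(s0 * s1, False, False, {s0})   # stays a kernel argument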
2 changes: 1 addition & 1 deletion helion/_compiler/type_propagation.py
@@ -645,7 +645,7 @@ def propagate_call(
attr = self.attr()
if attr in {"dim", "ndimension"} and not (args or kwargs):
return TypeInfo.from_example(self.tensor.fake_value.ndim, origin)
if attr in {"shape", "size"} and not kwargs:
if attr in {"shape", "size", "stride"} and not kwargs:
fn = getattr(self.tensor.fake_value, attr)
try:
return TypeInfo.from_example(
2 changes: 1 addition & 1 deletion helion/exc.py
@@ -186,7 +186,7 @@ class SpecializeOnDevice(BaseError):


class SpecializeArgType(BaseError):
message = "hl.specialize() must be called on a size from an input tensor, got: {}"
message = "hl.specialize() must be called on a size or stride from an input tensor, got: {}"


class StackTensorcOnHost(BaseError):
16 changes: 15 additions & 1 deletion helion/language/constexpr.py
@@ -6,6 +6,9 @@
from typing_extensions import TypeVar

import torch
from torch._dynamo.source import LocalSource
from torch._dynamo.source import TensorProperty
from torch._dynamo.source import TensorPropertySource

from .. import exc
from .._compiler.ast_extension import expr_from_string
@@ -87,7 +90,18 @@ def _(value: TypeInfo, *, origin: Origin) -> TypeInfo:
env = CompileEnvironment.current()

def handle_symint(symint: torch.SymInt) -> int:
env.specialized_vars.update(symint._sympy_().free_symbols)
syms = symint._sympy_().free_symbols
env.specialized_vars.update(syms)
# Track stride specializations
for sym in syms:
for source in env.shape_env.var_to_sources.get(sym, []):
if (
isinstance(source, TensorPropertySource)
and source.prop == TensorProperty.STRIDE
and isinstance(source.base, LocalSource)
and source.idx is not None
):
env.specialized_strides.add((source.base.local_name, source.idx))
return symint.__int__()

specialized = _convert_specializable(proxy, on_symint=handle_symint)
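A rough illustration of the source matching above, using the torch._dynamo.source dataclasses imported in this file (the source constructed here is hypothetical; the real ones come from env.shape_env.var_to_sources during fake-tensor construction):

    from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource

    src = TensorPropertySource(LocalSource("x"), TensorProperty.STRIDE, 0)
    if (
        isinstance(src, TensorPropertySource)
        and src.prop == TensorProperty.STRIDE
        and isinstance(src.base, LocalSource)
        and src.idx is not None
    ):
        key = (src.base.local_name, src.idx)
        assert key == ("x", 0)  # this is the pair added to env.specialized_strides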
8 changes: 5 additions & 3 deletions helion/runtime/kernel.py
@@ -624,12 +624,14 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]:

def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]:
if isinstance(v, TensorPropertySource):
assert v.prop == TensorProperty.SIZE
index = v.idx
assert index is not None
inner = make_extractor(v.base)

return lambda args: cast("torch.Tensor", inner(args)).size(index)
if v.prop == TensorProperty.SIZE:
return lambda args: cast("torch.Tensor", inner(args)).size(index)
if v.prop == TensorProperty.STRIDE:
return lambda args: cast("torch.Tensor", inner(args)).stride(index)
raise exc.SpecializeArgType(v)
if isinstance(v, LocalSource):
index = arg_name_to_index[v.local_name]
return operator.itemgetter(index)
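The net effect is that each specialized stride contributes the runtime stride value to the specialization key, so tensors that differ only in stride bind to different compiled variants (exercised by test_specialize_stride_creates_different_variants below). A minimal sketch of what the returned extractor computes, with a hypothetical helper name (the real extractors are composed from dynamo Source objects):

    from typing import Callable, Hashable, Sequence, cast

    import torch

    def make_stride_extractor(arg_index: int, dim: int) -> Callable[[Sequence[object]], Hashable]:
        # mirrors: lambda args: cast("torch.Tensor", inner(args)).stride(index)
        return lambda args: cast(torch.Tensor, args[arg_index]).stride(dim)

    extractor = make_stride_extractor(0, 0)
    assert extractor((torch.empty(4, 8),)) == 8  # a contiguous (4, 8) tensor has stride(0) == 8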
113 changes: 113 additions & 0 deletions test/test_specialize.expected
@@ -335,6 +335,119 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_specialize.py:N]: return out
return out

--- assertExpectedJournal(TestSpecialize.test_specialize_size_becomes_static)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fn(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
# src[test_specialize.py:N]: for tile in hl.tile(n):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < 137
# src[test_specialize.py:N]: out[tile] = x[tile] + 1
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
v_0 = 1.0
v_1 = load + v_0
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)

def fn(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_specialize.py:N]: out = torch.empty_like(x)
out = torch.empty_like(x)
# src[test_specialize.py:N]: for tile in hl.tile(n):
_BLOCK_SIZE_0 = 32
# src[test_specialize.py:N]: for tile in hl.tile(n):
# src[test_specialize.py:N]: out[tile] = x[tile] + 1
_launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1)
# src[test_specialize.py:N]: return out
return out

--- assertExpectedJournal(TestSpecialize.test_specialize_stride_basic)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
# src[test_specialize.py:N]: for tile in hl.tile(x.size()):
num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
offset_1 = pid_1 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < x_size_1
# src[test_specialize.py:N]: out[tile] = x[tile] + stride
load = tl.load(x + (indices_0[:, None] * 137 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_0 = 137.0
v_1 = load + v_0
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])

def fn(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_specialize.py:N]: out = torch.empty_like(x)
out = torch.empty_like(x)
# src[test_specialize.py:N]: for tile in hl.tile(x.size()):
_BLOCK_SIZE_0 = 32
_BLOCK_SIZE_1 = 32
# src[test_specialize.py:N]: for tile in hl.tile(x.size()):
# src[test_specialize.py:N]: # Use stride in computation to verify it's a constant
# src[test_specialize.py:N]: out[tile] = x[tile] + stride
_launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
# src[test_specialize.py:N]: return out
return out

--- assertExpectedJournal(TestSpecialize.test_specialize_stride_tuple)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
# src[test_specialize.py:N]: for tile in hl.tile(x.size()):
num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
pid_1 = tl.program_id(0) // num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
offset_1 = pid_1 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
mask_1 = indices_1 < x_size_1
# src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
load = tl.load(x + (indices_0[:, None] * 311 + indices_1[None, :] * 131), mask_0[:, None] & mask_1[None, :], other=0)
v_0 = 311.0
v_1 = load + v_0
v_2 = 131.0
v_3 = v_1 + v_2
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])

def fn(x: torch.Tensor, *, _launcher=_default_launcher):
# src[test_specialize.py:N]: stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
stride0, stride1 = (311, 131)
# src[test_specialize.py:N]: out = torch.empty_like(x)
out = torch.empty_like(x)
# src[test_specialize.py:N]: for tile in hl.tile(x.size()):
_BLOCK_SIZE_0 = 32
_BLOCK_SIZE_1 = 32
# src[test_specialize.py:N]: for tile in hl.tile(x.size()):
# src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
_launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
# src[test_specialize.py:N]: return out
return out

--- assertExpectedJournal(TestSpecialize.test_specialize_tuple_element)
from __future__ import annotations

119 changes: 119 additions & 0 deletions test/test_specialize.py
@@ -326,6 +326,125 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
self.assertIn("65536", code)
self.assertExpectedJournal(code)

def test_specialize_size_becomes_static(self):
"""Test that hl.specialize on a size makes it NOT passed to the triton kernel."""

@helion.kernel(static_shapes=False)
def fn(x: torch.Tensor) -> torch.Tensor:
n = hl.specialize(x.size(0))
out = torch.empty_like(x)
for tile in hl.tile(n):
out[tile] = x[tile] + 1
return out

x = torch.randn([137], device=DEVICE) # Use prime to avoid alignment
code, result = code_and_output(fn, (x,))
torch.testing.assert_close(result, x + 1)
# Verify x_size_0 is NOT passed as an argument (it should be static)
self.assertNotIn("x_size_0", code)
self.assertExpectedJournal(code)

def test_specialize_stride_basic(self):
"""Test that hl.specialize works with tensor strides."""

@helion.kernel(static_shapes=False, autotune_effort="none")
def fn(x: torch.Tensor) -> torch.Tensor:
stride = hl.specialize(x.stride(0))
out = torch.empty_like(x)
for tile in hl.tile(x.size()):
# Use stride in computation to verify it's a constant
out[tile] = x[tile] + stride
return out

# Use as_strided to create a tensor with a unique stride value (137)
# that won't be confused with shape values
size = (64, 64)
stride0 = 137 # Distinctive prime number for stride(0)
stride1 = 1
# Need storage size to fit: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
storage = torch.randn(storage_size, device=DEVICE)
x = torch.as_strided(storage, size, (stride0, stride1))

code, result = code_and_output(fn, (x,))
torch.testing.assert_close(result, x + x.stride(0))
# Verify the unique stride value 137 is inlined as a constant
self.assertIn("137", code)
# Verify x_stride_0 is NOT passed as an argument (it should be inlined)
self.assertNotIn("x_stride_0", code)
self.assertExpectedJournal(code)

def test_specialize_stride_creates_different_variants(self):
"""Test that different stride patterns create different kernel variants."""

@helion.kernel(static_shapes=False, autotune_effort="none")
def fn(x: torch.Tensor) -> torch.Tensor:
stride = hl.specialize(x.stride(0))
out = torch.empty_like(x)
for tile in hl.tile(x.size()):
out[tile] = x[tile] + stride
return out

# Create two tensors with different unique stride values using as_strided
size = (64, 64)

# First tensor with stride(0) = 173 (distinctive prime)
stride0_a = 173
storage_size_a = (size[0] - 1) * stride0_a + (size[1] - 1) * 1 + 1
storage_a = torch.randn(storage_size_a, device=DEVICE)
x_a = torch.as_strided(storage_a, size, (stride0_a, 1))

# Second tensor with stride(0) = 257 (different distinctive prime)
stride0_b = 257
storage_size_b = (size[0] - 1) * stride0_b + (size[1] - 1) * 1 + 1
storage_b = torch.randn(storage_size_b, device=DEVICE)
x_b = torch.as_strided(storage_b, size, (stride0_b, 1))

# These should create different bound kernels due to different strides
bound1 = fn.bind((x_a,))
bound2 = fn.bind((x_b,))

# Verify different variants are used
self.assertTrueIfInNormalMode(bound1 is not bound2)

# Verify correctness
result1 = fn(x_a)
result2 = fn(x_b)
torch.testing.assert_close(result1, x_a + stride0_a)
torch.testing.assert_close(result2, x_b + stride0_b)

def test_specialize_stride_tuple(self):
"""Test that hl.specialize works with tuple of strides."""

@helion.kernel(static_shapes=False, autotune_effort="none")
def fn(x: torch.Tensor) -> torch.Tensor:
stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
out = torch.empty_like(x)
for tile in hl.tile(x.size()):
out[tile] = x[tile] + stride0 + stride1
return out

# Create a tensor with unique stride values using as_strided
# stride0 = 311, stride1 = 131 (distinctive primes unlikely to appear elsewhere)
size = (64, 64)
stride0 = 311
stride1 = 131
# Storage must fit the largest offset: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
storage = torch.randn(storage_size, device=DEVICE)
x = torch.as_strided(storage, size, (stride0, stride1))

code, result = code_and_output(fn, (x,))
expected = x + stride0 + stride1
torch.testing.assert_close(result, expected)
# Verify both unique stride values appear in the generated code
self.assertIn("311", code)
self.assertIn("131", code)
# Verify both x_stride_0 and x_stride_1 are NOT passed as arguments (they should be inlined)
self.assertNotIn("x_stride_0", code)
self.assertNotIn("x_stride_1", code)
self.assertExpectedJournal(code)


if __name__ == "__main__":
unittest.main()