diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 764db7ea3..3b7e7fa60 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -112,6 +112,7 @@ def __init__( collections.Counter() ) self.specialized_vars: set[sympy.Symbol] = set() + self.specialized_strides: set[tuple[str, int]] = set() self.loop_dependency_checker = LoopDependencyChecker() self._symint_cache: dict[object, torch.SymInt] = {} self.device_load_count = ( diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index bfba224b1..a8203c0a6 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -14,6 +14,7 @@ import sympy import torch +from torch._dynamo.source import LocalSource from torch._inductor.codegen.triton import TritonPrinter from torch.fx.graph import _Namespace @@ -602,11 +603,23 @@ def tensor_size(self, fake_value: torch.Tensor, dim: int) -> Argument: return self._tensor_property(TensorSizeArg, fake_value, dim, "size") def tensor_stride(self, fake_value: torch.Tensor, dim: int) -> Argument: + v = fake_value.stride(dim) + env = CompileEnvironment.current() + # Check if this stride was explicitly specialized + source = env.input_sources.get(fake_value) if ( - isinstance(v := fake_value.stride(dim), int) - and CompileEnvironment.current().settings.static_shapes + isinstance(source, LocalSource) + and (source.local_name, dim) in env.specialized_strides ): - return StaticShape(v) + return StaticShape(int(v)) + if isinstance(v, int): + if env.settings.static_shapes: + return StaticShape(v) + else: + # Check if all free symbols are specialized + syms = v._sympy_().free_symbols + if syms and syms <= env.specialized_vars: + return StaticShape(int(v)) return self._tensor_property(TensorStrideArg, fake_value, dim, "stride") def sorted_args(self) -> list[Argument]: diff --git a/helion/_compiler/type_propagation.py 
b/helion/_compiler/type_propagation.py index c2d96c82b..db6dc427d 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -645,7 +645,7 @@ def propagate_call( attr = self.attr() if attr in {"dim", "ndimension"} and not (args or kwargs): return TypeInfo.from_example(self.tensor.fake_value.ndim, origin) - if attr in {"shape", "size"} and not kwargs: + if attr in {"shape", "size", "stride"} and not kwargs: fn = getattr(self.tensor.fake_value, attr) try: return TypeInfo.from_example( diff --git a/helion/exc.py b/helion/exc.py index f29e67d15..8ce2b3f51 100644 --- a/helion/exc.py +++ b/helion/exc.py @@ -186,7 +186,7 @@ class SpecializeOnDevice(BaseError): class SpecializeArgType(BaseError): - message = "hl.specialize() must be called on a size from an input tensor, got: {}" + message = "hl.specialize() must be called on a size or stride from an input tensor, got: {}" class StackTensorcOnHost(BaseError): diff --git a/helion/language/constexpr.py b/helion/language/constexpr.py index 4528d9e05..48aab38d5 100644 --- a/helion/language/constexpr.py +++ b/helion/language/constexpr.py @@ -6,6 +6,9 @@ from typing_extensions import TypeVar import torch +from torch._dynamo.source import LocalSource +from torch._dynamo.source import TensorProperty +from torch._dynamo.source import TensorPropertySource from .. 
import exc from .._compiler.ast_extension import expr_from_string @@ -87,7 +90,18 @@ def _(value: TypeInfo, *, origin: Origin) -> TypeInfo: env = CompileEnvironment.current() def handle_symint(symint: torch.SymInt) -> int: - env.specialized_vars.update(symint._sympy_().free_symbols) + syms = symint._sympy_().free_symbols + env.specialized_vars.update(syms) + # Track stride specializations + for sym in syms: + for source in env.shape_env.var_to_sources.get(sym, []): + if ( + isinstance(source, TensorPropertySource) + and source.prop == TensorProperty.STRIDE + and isinstance(source.base, LocalSource) + and source.idx is not None + ): + env.specialized_strides.add((source.base.local_name, source.idx)) return symint.__int__() specialized = _convert_specializable(proxy, on_symint=handle_symint) diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 88fec2868..cb5769a1a 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -624,12 +624,14 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]: def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]: if isinstance(v, TensorPropertySource): - assert v.prop == TensorProperty.SIZE index = v.idx assert index is not None inner = make_extractor(v.base) - - return lambda args: cast("torch.Tensor", inner(args)).size(index) + if v.prop == TensorProperty.SIZE: + return lambda args: cast("torch.Tensor", inner(args)).size(index) + if v.prop == TensorProperty.STRIDE: + return lambda args: cast("torch.Tensor", inner(args)).stride(index) + raise exc.SpecializeArgType(v) if isinstance(v, LocalSource): index = arg_name_to_index[v.local_name] return operator.itemgetter(index) diff --git a/test/test_specialize.expected b/test/test_specialize.expected index d397c81c1..03a52f23c 100644 --- a/test/test_specialize.expected +++ b/test/test_specialize.expected @@ -335,6 +335,119 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher): # src[test_specialize.py:N]: return 
out return out +--- assertExpectedJournal(TestSpecialize.test_specialize_size_becomes_static) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_fn(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): + # src[test_specialize.py:N]: for tile in hl.tile(n): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < 137 + # src[test_specialize.py:N]: out[tile] = x[tile] + 1 + load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) + v_0 = 1.0 + v_1 = load + v_0 + tl.store(out + indices_0 * out_stride_0, v_1, mask_0) + +def fn(x: torch.Tensor, *, _launcher=_default_launcher): + # src[test_specialize.py:N]: out = torch.empty_like(x) + out = torch.empty_like(x) + # src[test_specialize.py:N]: for tile in hl.tile(n): + _BLOCK_SIZE_0 = 32 + # src[test_specialize.py:N]: for tile in hl.tile(n): + # src[test_specialize.py:N]: out[tile] = x[tile] + 1 + _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1) + # src[test_specialize.py:N]: return out + return out + +--- assertExpectedJournal(TestSpecialize.test_specialize_stride_basic) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + # src[test_specialize.py:N]: for tile in hl.tile(x.size()): + num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + 
mask_0 = indices_0 < x_size_0 + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < x_size_1 + # src[test_specialize.py:N]: out[tile] = x[tile] + stride + load = tl.load(x + (indices_0[:, None] * 137 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + v_0 = 137.0 + v_1 = load + v_0 + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :]) + +def fn(x: torch.Tensor, *, _launcher=_default_launcher): + # src[test_specialize.py:N]: out = torch.empty_like(x) + out = torch.empty_like(x) + # src[test_specialize.py:N]: for tile in hl.tile(x.size()): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_specialize.py:N]: for tile in hl.tile(x.size()): + # src[test_specialize.py:N]: # Use stride in computation to verify it's a constant + # src[test_specialize.py:N]: out[tile] = x[tile] + stride + _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) + # src[test_specialize.py:N]: return out + return out + +--- assertExpectedJournal(TestSpecialize.test_specialize_stride_tuple) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): + # src[test_specialize.py:N]: for tile in hl.tile(x.size()): + num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < x_size_0 + offset_1 = pid_1 * 
_BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < x_size_1 + # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1 + load = tl.load(x + (indices_0[:, None] * 311 + indices_1[None, :] * 131), mask_0[:, None] & mask_1[None, :], other=0) + v_0 = 311.0 + v_1 = load + v_0 + v_2 = 131.0 + v_3 = v_1 + v_2 + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :]) + +def fn(x: torch.Tensor, *, _launcher=_default_launcher): + # src[test_specialize.py:N]: stride0, stride1 = hl.specialize((x.stride(0), x.stride(1))) + stride0, stride1 = (311, 131) + # src[test_specialize.py:N]: out = torch.empty_like(x) + out = torch.empty_like(x) + # src[test_specialize.py:N]: for tile in hl.tile(x.size()): + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + # src[test_specialize.py:N]: for tile in hl.tile(x.size()): + # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1 + _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1) + # src[test_specialize.py:N]: return out + return out + --- assertExpectedJournal(TestSpecialize.test_specialize_tuple_element) from __future__ import annotations diff --git a/test/test_specialize.py b/test/test_specialize.py index e7d53fca4..b4224f4b2 100644 --- a/test/test_specialize.py +++ b/test/test_specialize.py @@ -326,6 +326,125 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor: self.assertIn("65536", code) self.assertExpectedJournal(code) + def test_specialize_size_becomes_static(self): + """Test that hl.specialize on a size makes it NOT passed to the triton kernel.""" + + @helion.kernel(static_shapes=False) + def fn(x: torch.Tensor) -> torch.Tensor: + n = hl.specialize(x.size(0)) + out = torch.empty_like(x) + for tile in 
hl.tile(n): + out[tile] = x[tile] + 1 + return out + + x = torch.randn([137], device=DEVICE) # Use prime to avoid alignment + code, result = code_and_output(fn, (x,)) + torch.testing.assert_close(result, x + 1) + # Verify x_size_0 is NOT passed as an argument (it should be static) + self.assertNotIn("x_size_0", code) + self.assertExpectedJournal(code) + + def test_specialize_stride_basic(self): + """Test that hl.specialize works with tensor strides.""" + + @helion.kernel(static_shapes=False, autotune_effort="none") + def fn(x: torch.Tensor) -> torch.Tensor: + stride = hl.specialize(x.stride(0)) + out = torch.empty_like(x) + for tile in hl.tile(x.size()): + # Use stride in computation to verify it's a constant + out[tile] = x[tile] + stride + return out + + # Use torch.as_strided to create a tensor with a unique stride value (137) + # that won't be confused with shape values + size = (64, 64) + stride0 = 137 # Distinctive prime number for stride(0) + stride1 = 1 + # Need storage size to fit: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1 + storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1 + storage = torch.randn(storage_size, device=DEVICE) + x = torch.as_strided(storage, size, (stride0, stride1)) + + code, result = code_and_output(fn, (x,)) + torch.testing.assert_close(result, x + x.stride(0)) + # Verify the unique stride value 137 is inlined as a constant + self.assertIn("137", code) + # Verify x_stride_0 is NOT passed as an argument (it should be inlined) + self.assertNotIn("x_stride_0", code) + self.assertExpectedJournal(code) + + def test_specialize_stride_creates_different_variants(self): + """Test that different stride patterns create different kernel variants.""" + + @helion.kernel(static_shapes=False, autotune_effort="none") + def fn(x: torch.Tensor) -> torch.Tensor: + stride = hl.specialize(x.stride(0)) + out = torch.empty_like(x) + for tile in hl.tile(x.size()): + out[tile] = x[tile] + stride + return out + + # Create two tensors with 
different unique stride values using torch.as_strided + size = (64, 64) + + # First tensor with stride(0) = 173 (distinctive prime) + stride0_a = 173 + storage_size_a = (size[0] - 1) * stride0_a + (size[1] - 1) * 1 + 1 + storage_a = torch.randn(storage_size_a, device=DEVICE) + x_a = torch.as_strided(storage_a, size, (stride0_a, 1)) + + # Second tensor with stride(0) = 257 (different distinctive prime) + stride0_b = 257 + storage_size_b = (size[0] - 1) * stride0_b + (size[1] - 1) * 1 + 1 + storage_b = torch.randn(storage_size_b, device=DEVICE) + x_b = torch.as_strided(storage_b, size, (stride0_b, 1)) + + # These should create different bound kernels due to different strides + bound1 = fn.bind((x_a,)) + bound2 = fn.bind((x_b,)) + + # Verify different variants are used + self.assertTrueIfInNormalMode(bound1 is not bound2) + + # Verify correctness + result1 = fn(x_a) + result2 = fn(x_b) + torch.testing.assert_close(result1, x_a + stride0_a) + torch.testing.assert_close(result2, x_b + stride0_b) + + def test_specialize_stride_tuple(self): + """Test that hl.specialize works with tuple of strides.""" + + @helion.kernel(static_shapes=False, autotune_effort="none") + def fn(x: torch.Tensor) -> torch.Tensor: + stride0, stride1 = hl.specialize((x.stride(0), x.stride(1))) + out = torch.empty_like(x) + for tile in hl.tile(x.size()): + out[tile] = x[tile] + stride0 + stride1 + return out + + # Create tensor with unique stride values using torch.as_strided + # stride0 = 311, stride1 = 131 (distinctive primes unlikely to appear elsewhere) + size = (64, 64) + stride0 = 311 + stride1 = 131 + # Storage must fit the largest offset: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1 + storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1 + storage = torch.randn(storage_size, device=DEVICE) + x = torch.as_strided(storage, size, (stride0, stride1)) + + code, result = code_and_output(fn, (x,)) + expected = x + stride0 + stride1 + torch.testing.assert_close(result, expected) + # Verify 
both unique stride values appear in the generated code + self.assertIn("311", code) + self.assertIn("131", code) + # Verify both x_stride_0 and x_stride_1 are NOT passed as arguments (they should be inlined) + self.assertNotIn("x_stride_0", code) + self.assertNotIn("x_stride_1", code) + self.assertExpectedJournal(code) + if __name__ == "__main__": unittest.main()