
Commit ec0806b

test
1 parent 7aada66 commit ec0806b

2 files changed: +232 -0 lines changed

test/test_specialize.expected

Lines changed: 113 additions & 0 deletions
@@ -335,6 +335,119 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
    # src[test_specialize.py:N]: return out
    return out

--- assertExpectedJournal(TestSpecialize.test_specialize_size_becomes_static)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fn(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
    # src[test_specialize.py:N]: for tile in hl.tile(n):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < 137
    # src[test_specialize.py:N]: out[tile] = x[tile] + 1
    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
    v_0 = 1.0
    v_1 = load + v_0
    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)

def fn(x: torch.Tensor, *, _launcher=_default_launcher):
    # src[test_specialize.py:N]: out = torch.empty_like(x)
    out = torch.empty_like(x)
    # src[test_specialize.py:N]: for tile in hl.tile(n):
    _BLOCK_SIZE_0 = 32
    # src[test_specialize.py:N]: for tile in hl.tile(n):
    # src[test_specialize.py:N]: out[tile] = x[tile] + 1
    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1)
    # src[test_specialize.py:N]: return out
    return out
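
A minimal sketch (not part of the diff) of the launch arithmetic implied by the journal above: the specialized size appears as the literal 137 in mask_0, while _BLOCK_SIZE_0 = 32 remains a tunable constant. The cdiv helper below is a stand-in for triton.cdiv, added only for illustration.

def cdiv(a: int, b: int) -> int:
    # Stand-in for triton.cdiv: ceiling division.
    return -(-a // b)

n, block = 137, 32             # specialized size and _BLOCK_SIZE_0 from the journal above
grid = cdiv(n, block)          # 5 programs are launched
assert grid == 5
assert grid * block - n == 23  # the trailing 23 lanes are disabled by mask_0 = indices_0 < 137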

--- assertExpectedJournal(TestSpecialize.test_specialize_stride_basic)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < x_size_0
    offset_1 = pid_1 * _BLOCK_SIZE_1
    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    mask_1 = indices_1 < x_size_1
    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
    load = tl.load(x + (indices_0[:, None] * 137 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
    v_0 = 137.0
    v_1 = load + v_0
    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])

def fn(x: torch.Tensor, *, _launcher=_default_launcher):
    # src[test_specialize.py:N]: out = torch.empty_like(x)
    out = torch.empty_like(x)
    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 32
    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
    # src[test_specialize.py:N]: # Use stride in computation to verify it's a constant
    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
    # src[test_specialize.py:N]: return out
    return out

--- assertExpectedJournal(TestSpecialize.test_specialize_stride_tuple)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    mask_0 = indices_0 < x_size_0
    offset_1 = pid_1 * _BLOCK_SIZE_1
    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    mask_1 = indices_1 < x_size_1
    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
    load = tl.load(x + (indices_0[:, None] * 311 + indices_1[None, :] * 131), mask_0[:, None] & mask_1[None, :], other=0)
    v_0 = 311.0
    v_1 = load + v_0
    v_2 = 131.0
    v_3 = v_1 + v_2
    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])

def fn(x: torch.Tensor, *, _launcher=_default_launcher):
    # src[test_specialize.py:N]: stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
    stride0, stride1 = (311, 131)
    # src[test_specialize.py:N]: out = torch.empty_like(x)
    out = torch.empty_like(x)
    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
    _BLOCK_SIZE_0 = 32
    _BLOCK_SIZE_1 = 32
    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
    # src[test_specialize.py:N]: return out
    return out

--- assertExpectedJournal(TestSpecialize.test_specialize_tuple_element)
from __future__ import annotations

test/test_specialize.py

Lines changed: 119 additions & 0 deletions
@@ -326,6 +326,125 @@ def foo(x: torch.Tensor, bitshift: tuple[int, int]) -> torch.Tensor:
        self.assertIn("65536", code)
        self.assertExpectedJournal(code)

    def test_specialize_size_becomes_static(self):
        """Test that hl.specialize on a size makes it NOT passed to the triton kernel."""

        @helion.kernel(static_shapes=False)
        def fn(x: torch.Tensor) -> torch.Tensor:
            n = hl.specialize(x.size(0))
            out = torch.empty_like(x)
            for tile in hl.tile(n):
                out[tile] = x[tile] + 1
            return out

        x = torch.randn([137], device=DEVICE)  # Use prime to avoid alignment
        code, result = code_and_output(fn, (x,))
        torch.testing.assert_close(result, x + 1)
        # Verify x_size_0 is NOT passed as an argument (it should be static)
        self.assertNotIn("x_size_0", code)
        self.assertExpectedJournal(code)

    def test_specialize_stride_basic(self):
        """Test that hl.specialize works with tensor strides."""

        @helion.kernel(static_shapes=False, autotune_effort="none")
        def fn(x: torch.Tensor) -> torch.Tensor:
            stride = hl.specialize(x.stride(0))
            out = torch.empty_like(x)
            for tile in hl.tile(x.size()):
                # Use stride in computation to verify it's a constant
                out[tile] = x[tile] + stride
            return out

        # Use as_strided to create a tensor with a unique stride value (137)
        # that won't be confused with shape values
        size = (64, 64)
        stride0 = 137  # Distinctive prime number for stride(0)
        stride1 = 1
        # Need storage size to fit: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
        storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
        storage = torch.randn(storage_size, device=DEVICE)
        x = torch.as_strided(storage, size, (stride0, stride1))

        code, result = code_and_output(fn, (x,))
        torch.testing.assert_close(result, x + x.stride(0))
        # Verify the unique stride value 137 is inlined as a constant
        self.assertIn("137", code)
        # Verify x_stride_0 is NOT passed as an argument (it should be inlined)
        self.assertNotIn("x_stride_0", code)
        self.assertExpectedJournal(code)
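
A small stand-alone check (not from the diff) of the storage-size formula the test above relies on, using the same size and strides; only plain PyTorch calls are assumed.

import torch

size = (64, 64)
strides = (137, 1)
# The last element of the strided view sits at linear offset
# (64 - 1) * 137 + (64 - 1) * 1 = 8694, so the storage needs 8695 elements.
storage_size = (size[0] - 1) * strides[0] + (size[1] - 1) * strides[1] + 1
assert storage_size == 8695
x = torch.as_strided(torch.randn(storage_size), size, strides)
assert x.size() == size and x.stride() == strides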

    def test_specialize_stride_creates_different_variants(self):
        """Test that different stride patterns create different kernel variants."""

        @helion.kernel(static_shapes=False, autotune_effort="none")
        def fn(x: torch.Tensor) -> torch.Tensor:
            stride = hl.specialize(x.stride(0))
            out = torch.empty_like(x)
            for tile in hl.tile(x.size()):
                out[tile] = x[tile] + stride
            return out

        # Create two tensors with different unique stride values using as_strided
        size = (64, 64)

        # First tensor with stride(0) = 173 (distinctive prime)
        stride0_a = 173
        storage_size_a = (size[0] - 1) * stride0_a + (size[1] - 1) * 1 + 1
        storage_a = torch.randn(storage_size_a, device=DEVICE)
        x_a = torch.as_strided(storage_a, size, (stride0_a, 1))

        # Second tensor with stride(0) = 257 (a different distinctive prime)
        stride0_b = 257
        storage_size_b = (size[0] - 1) * stride0_b + (size[1] - 1) * 1 + 1
        storage_b = torch.randn(storage_size_b, device=DEVICE)
        x_b = torch.as_strided(storage_b, size, (stride0_b, 1))

        # These should create different bound kernels due to different strides
        bound1 = fn.bind((x_a,))
        bound2 = fn.bind((x_b,))

        # Verify different variants are used
        self.assertTrueIfInNormalMode(bound1 is not bound2)

        # Verify correctness
        result1 = fn(x_a)
        result2 = fn(x_b)
        torch.testing.assert_close(result1, x_a + stride0_a)
        torch.testing.assert_close(result2, x_b + stride0_b)
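
As an aside (not from the diff), the crux of the test above is that the two inputs agree on shape but not on stride(0), which is what the specialization key has to distinguish; a minimal stand-alone illustration with the same sizes and strides:

import torch

size = (64, 64)
x_a = torch.as_strided(torch.randn((size[0] - 1) * 173 + (size[1] - 1) + 1), size, (173, 1))
x_b = torch.as_strided(torch.randn((size[0] - 1) * 257 + (size[1] - 1) + 1), size, (257, 1))
assert x_a.shape == x_b.shape          # identical shapes
assert x_a.stride(0) != x_b.stride(0)  # 173 vs 257, so the specialized kernels cannot be shared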

    def test_specialize_stride_tuple(self):
        """Test that hl.specialize works with a tuple of strides."""

        @helion.kernel(static_shapes=False, autotune_effort="none")
        def fn(x: torch.Tensor) -> torch.Tensor:
            stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
            out = torch.empty_like(x)
            for tile in hl.tile(x.size()):
                out[tile] = x[tile] + stride0 + stride1
            return out

        # Create a tensor with unique stride values using as_strided
        # stride0 = 311, stride1 = 131 (distinctive primes unlikely to appear elsewhere)
        size = (64, 64)
        stride0 = 311
        stride1 = 131
        # Storage must fit the largest offset: (size[0]-1)*stride0 + (size[1]-1)*stride1 + 1
        storage_size = (size[0] - 1) * stride0 + (size[1] - 1) * stride1 + 1
        storage = torch.randn(storage_size, device=DEVICE)
        x = torch.as_strided(storage, size, (stride0, stride1))

        code, result = code_and_output(fn, (x,))
        expected = x + stride0 + stride1
        torch.testing.assert_close(result, expected)
        # Verify both unique stride values appear in the generated code
        self.assertIn("311", code)
        self.assertIn("131", code)
        # Verify x_stride_0 and x_stride_1 are NOT passed as arguments (they should be inlined)
        self.assertNotIn("x_stride_0", code)
        self.assertNotIn("x_stride_1", code)
        self.assertExpectedJournal(code)


if __name__ == "__main__":
    unittest.main()
