@@ -335,6 +335,119 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_specialize.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestSpecialize.test_specialize_size_becomes_static)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(n):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < 137
+    # src[test_specialize.py:N]: out[tile] = x[tile] + 1
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_0 = 1.0
+    v_1 = load + v_0
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(n):
+    _BLOCK_SIZE_0 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(n):
+    # src[test_specialize.py:N]: out[tile] = x[tile] + 1
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
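For orientation, here is a minimal sketch of the Helion source that could produce the journal entry above, reconstructed from its `src[...]` comments. The `hl.specialize(x.size(0))` line is an assumption inferred from the test name and from the size 137 baked into the generated mask (`indices_0 < 137`); the loop body, allocation, and return are quoted from the comments.

```python
# Hypothetical reconstruction, not the verbatim test source.
import torch

import helion
import helion.language as hl


@helion.kernel
def fn(x: torch.Tensor) -> torch.Tensor:
    # Assumed: specializing the size makes it a compile-time constant,
    # so the generated kernel hard-codes 137 instead of taking x_size_0.
    n = hl.specialize(x.size(0))
    out = torch.empty_like(x)
    for tile in hl.tile(n):
        out[tile] = x[tile] + 1
    return out
```

Consistent with this reading, the generated `_helion_fn` takes no size argument at all; the specialized extent survives only as the literal 137 in the bounds mask.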
+--- assertExpectedJournal(TestSpecialize.test_specialize_stride_basic)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
+    load = tl.load(x + (indices_0[:, None] * 137 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 137.0
+    v_1 = load + v_0
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: # Use stride in computation to verify it's a constant
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
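As with the size case, the Helion source behind this entry can be sketched from its `src[...]` comments. The `hl.specialize(x.stride(0))` call is an assumption, inferred from the test name and from the constant 137 that the generated code bakes into both the load's row stride (`indices_0[:, None] * 137`) and the scalar addend (`v_0 = 137.0`); the variable name `stride` and the loop body are quoted from the comments.

```python
# Hypothetical reconstruction, not the verbatim test source.
import torch

import helion
import helion.language as hl


@helion.kernel
def fn(x: torch.Tensor) -> torch.Tensor:
    # Assumed: the specialized stride is folded into the generated
    # pointer arithmetic and into the added constant.
    stride = hl.specialize(x.stride(0))
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        # Use stride in computation to verify it's a constant
        out[tile] = x[tile] + stride
    return out
```

Note that `x_stride_0` is absent from the generated kernel's signature while the unspecialized `x_stride_1` is still passed at runtime.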
+--- assertExpectedJournal(TestSpecialize.test_specialize_stride_tuple)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_fn(x, out, x_size_0, x_size_1, out_stride_0, out_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < x_size_1
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
+    load = tl.load(x + (indices_0[:, None] * 311 + indices_1[None, :] * 131), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = 311.0
+    v_1 = load + v_0
+    v_2 = 131.0
+    v_3 = v_1 + v_2
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
+
+def fn(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_specialize.py:N]: stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
+    stride0, stride1 = (311, 131)
+    # src[test_specialize.py:N]: out = torch.empty_like(x)
+    out = torch.empty_like(x)
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    _BLOCK_SIZE_0 = 32
+    _BLOCK_SIZE_1 = 32
+    # src[test_specialize.py:N]: for tile in hl.tile(x.size()):
+    # src[test_specialize.py:N]: out[tile] = x[tile] + stride0 + stride1
+    _launcher(_helion_fn, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * triton.cdiv(x.size(1), _BLOCK_SIZE_1),), x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_specialize.py:N]: return out
+    return out
+
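The tuple case is the best grounded of the three: the first `src[...]` comment in the entry quotes the specialize call verbatim, so the kernel can be sketched directly from the comments. Only the decorator and imports are assumptions here.

```python
# Sketch assembled from the src[...] comments above; the body lines are
# quoted from the journal, the surrounding scaffolding is assumed.
import torch

import helion
import helion.language as hl


@helion.kernel
def fn(x: torch.Tensor) -> torch.Tensor:
    # Specializing a tuple yields one compile-time constant per element.
    stride0, stride1 = hl.specialize((x.stride(0), x.stride(1)))
    out = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] + stride0 + stride1
    return out
```

Both specialized strides (311 and 131) are folded into the generated load's pointer arithmetic and reappear as the float constants `v_0` and `v_2` in the addition.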
 --- assertExpectedJournal(TestSpecialize.test_specialize_tuple_element)
 from __future__ import annotations
 