
Commit 9c9eea4

Fix hl.rand to use tile-specific offsets instead of fixed offsets, ensuring a unique random number per tile (#685)
1 parent ebbd2c4 commit 9c9eea4

5 files changed: +721 -0 lines changed
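What the fix means in practice: per the commit title, hl.rand previously derived its Philox offsets from fixed values, so every tile drew the same block of random numbers. The sketch below is adapted from the docstring example added in this commit; the trailing uniqueness check is illustrative only, not a test shipped here, and assumes a CUDA device:

import torch
import helion
import helion.language as hl


@helion.kernel
def rand_kernel(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    (m,) = x.shape
    for tile_m in hl.tile(m):
        # With tile-specific offsets, each tile draws from a distinct
        # region of the Philox stream instead of repeating the first block.
        out[tile_m] = hl.rand([tile_m], seed=42)
    return out


out = rand_kernel(torch.empty(1024, device="cuda"))
# With fixed offsets the two halves would match whenever the block size
# divides 512; with per-tile offsets they differ.
assert not torch.equal(out[:512], out[512:])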

helion/_compiler/indexing_strategy.py

Lines changed: 6 additions & 0 deletions

@@ -352,6 +352,12 @@ def get_broadcast_str(
 
         return stack_broadcast, tensor_broadcast
 
+    @staticmethod
+    def get_element_broadcast_slice(dim_index: int, total_dims: int) -> str:
+        broadcast_keys = ["None"] * total_dims
+        broadcast_keys[dim_index] = ":"
+        return f"[{', '.join(broadcast_keys)}]"
+
     @staticmethod
     def get_mask_expr(
         state: CodegenState,
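For reference, the new helper just renders a Triton-style broadcast subscript as a string. A minimal standalone sketch (the body copied verbatim, outside its class) shows what it produces:

def get_element_broadcast_slice(dim_index: int, total_dims: int) -> str:
    # "None" in every position except dim_index, which keeps ":".
    broadcast_keys = ["None"] * total_dims
    broadcast_keys[dim_index] = ":"
    return f"[{', '.join(broadcast_keys)}]"

assert get_element_broadcast_slice(0, 2) == "[:, None]"
assert get_element_broadcast_slice(1, 3) == "[None, :, None]"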

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@
 from .matmul_ops import dot as dot
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .random_ops import rand as rand
 from .reduce_ops import reduce as reduce
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
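This re-export makes the new op part of the public helion.language namespace, so kernel code can call it as hl.rand (assuming the conventional alias used in Helion examples):

import helion.language as hl

hl.rand  # same function object as helion.language.random_ops.rand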

helion/language/random_ops.py (new file)

Lines changed: 160 additions & 0 deletions

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from .._compiler.ast_extension import expr_from_string
from .._compiler.compile_environment import CompileEnvironment
from .._compiler.indexing_strategy import StackIndexingStrategy
from ..exc import NotInsideKernel
from . import _decorators
from .ref_tile import RefTile

if TYPE_CHECKING:
    import ast

    from .._compiler.inductor_lowering import CodegenState

__all__ = ["rand"]


@_decorators.api(tiles_as_sizes=True)
def rand(
    shape: list[object],
    seed: int | torch.Tensor,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    hl.rand provides a Philox-based pseudorandom number generator (PRNG) that
    operates independently of PyTorch's global random seed; it requires an
    explicit seed argument instead. Offsets are derived from the full logical
    sizes of the tiles specified in the shape argument.

    Args:
        shape: A list of sizes for the output tensor
        seed: A single-element int64 tensor or an int literal
        device: Optional output device; defaults to the kernel's compile device

    Returns:
        torch.Tensor: A float32 device tensor filled with uniform random
        values in [0, 1)

    Examples:
        .. code-block:: python

            @helion.kernel
            def process_kernel(x: torch.Tensor) -> torch.Tensor:
                output = torch.zeros_like(x)
                (m,) = x.shape
                for tile_m in hl.tile(m):
                    output[tile_m] = hl.rand([tile_m], seed=42)
                return output
    """
    raise NotInsideKernel


@_decorators.register_fake(rand)
def _rand_fake(
    shape: list[int | torch.SymInt],
    seed: int | torch.Tensor,
    device: torch.device | None = None,
) -> torch.Tensor:
    if not isinstance(shape, (list, tuple)):
        raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
    env = CompileEnvironment.current()
    env.add_kernel_tensor_size(shape)
    return torch.empty(
        [*shape],
        dtype=torch.float32,
        device=env.device if device is None else device,
    )


@_decorators.codegen(rand)
def _rand_codegen(state: CodegenState) -> ast.AST:
    """
    Generate tl.rand() code with global indices for deterministic RNG per
    element.

    This implementation uses improved dimension detection and broadcasting
    logic while maintaining compatibility with the existing approach.
    """
    fake_value = state.fake_value
    assert isinstance(fake_value, torch.Tensor)

    env = CompileEnvironment.current()
    tensor_shape = fake_value.size()
    ndim = len(tensor_shape)
    if ndim == 0:
        raise ValueError("hl.rand() requires at least one dimension")

    seed_ast = state.ast_arg(1)

    index_vars = []
    size_names = []
    for i in range(ndim):
        size = tensor_shape[i]
        block_id = env.get_block_id(size)
        if block_id is not None:
            index_vars.append(state.codegen.index_var(block_id))
            original_tensor_size = env.block_sizes[block_id].size
            assert isinstance(original_tensor_size, torch.SymInt), (
                f"Expected SymInt, got {type(original_tensor_size)}"
            )
            size_names.append(
                state.device_function.sympy_expr(original_tensor_size._sympy_())
            )
        else:
            rdim = env.allocate_reduction_dimension(size)
            index_vars.append(state.codegen.index_var(rdim.block_id))
            assert isinstance(rdim.var, torch.SymInt), (
                f"Expected SymInt, got {type(rdim.var)}"
            )
            size_names.append(state.device_function.sympy_expr(rdim.var._sympy_()))

    if ndim == 1:
        offset_expr = expr_from_string(index_vars[0])
    else:
        offset_parts = []
        for i in range(ndim):
            broadcast_slice = StackIndexingStrategy.get_element_broadcast_slice(i, ndim)
            broadcasted_index = f"{index_vars[i]}{broadcast_slice}"
            if i < ndim - 1:
                stride_expr = " * ".join(map("({})".format, size_names[i + 1 :]))
                offset_parts.append(f"{broadcasted_index} * {stride_expr}")
            else:
                offset_parts.append(broadcasted_index)
        offset_expr = expr_from_string(" + ".join(offset_parts))
    return expr_from_string(
        "tl.rand({seed}, {offset})", seed=seed_ast, offset=offset_expr
    )


@_decorators.get_masked_value(rand)
def _(
    node: torch.fx.Node,
) -> float:
    return 0


@_decorators.ref(rand)
def _(
    shape: list[int | RefTile],
    seed: int | torch.Tensor,
    device: torch.device | None = None,
) -> torch.Tensor:
    processed_shape: list[int] = []
    for s in shape:
        if isinstance(s, RefTile):
            processed_shape.append(s.end - s.begin)
        else:
            processed_shape.append(int(s))
    env = CompileEnvironment.current()
    gen = torch.Generator(device=env.device if device is None else device)
    if isinstance(seed, torch.Tensor):
        gen.manual_seed(int(seed.item()))
    else:
        gen.manual_seed(seed)
    return torch.rand(
        processed_shape,
        dtype=torch.float32,
        generator=gen,
        device=env.device if device is None else device,
    )
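Why the generated offsets are unique per element: for ndim > 1, _rand_codegen linearizes the global per-dimension index vectors into a single row-major offset, e.g. for 2-D it emits the equivalent of index0[:, None] * (size1) + index1[None, :], where size1 is the full logical size of the trailing dimension. A hedged torch sketch of that arithmetic (sizes and tile ranges invented for illustration):

import torch

size1 = 8                      # full logical size of the trailing dimension
index0 = torch.arange(2, 4)    # global indices this tile covers in dim 0
index1 = torch.arange(0, 8)    # global indices this tile covers in dim 1

# Row-major linearization mirroring the generated Triton expression.
offset = index0[:, None] * size1 + index1[None, :]

# Every element gets a distinct global offset, so tl.rand(seed, offset)
# yields a deterministic, non-repeating value per element regardless of
# how the iteration space is tiled.
assert offset.unique().numel() == offset.numel()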
