
Commit 811be91

Remove triton_helpers.* usage in lifted device function arguments (#849)
1 parent 0a4a7fa commit 811be91

12 files changed: +138 -75 lines

helion/_compiler/device_function.py

Lines changed: 51 additions & 8 deletions
@@ -207,6 +207,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         ] = {}
         self._expr_args: dict[sympy.Expr, SymbolArgument] = {}
         self._constexpr_args: dict[str, ConstExprArg] = {}
+        self._constexpr_host_defs: set[str] = set()
         self._tensor_properties: dict[
             tuple[type[TensorPropertyArg], torch.Tensor, int], TensorPropertyArg
         ] = {}
@@ -282,11 +283,7 @@ def block_size_var(self, block_id: int) -> str | None:
 
             var_name = self.new_var(f"_BLOCK_SIZE_{block_id}")
             self.block_size_var_cache[key] = var_name
-            host_expr = HostFunction.current().literal_expr(block_value)
-            if self.constexpr_arg(var_name, host_expr):
-                self.codegen.host_statements.append(
-                    statement_from_string(f"{var_name} = {host_expr}")
-                )
+            self.constexpr_arg_with_host_def(var_name, block_value)
 
         return self.block_size_var_cache[key]
 
@@ -484,14 +481,55 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
         self._expr_args[sym] = arg
         return self._expr_args[sym]
 
-    def constexpr_arg(self, name: str, host_str: str | None = None) -> bool:
+    def constexpr_arg(self, name: str, value: object | None = None) -> bool:
         """Create a constexpr argument, returns True if created, False if already exists."""
         if name in self._constexpr_args:
             return False
-        self._constexpr_args[name] = rv = ConstExprArg(name, host_str or name)
+        host_str = name if value is None else self._format_constexpr_value(value)
+        self._constexpr_args[name] = rv = ConstExprArg(name, host_str)
         self.arguments.append(rv)
         return True
 
+    def constexpr_arg_with_host_def(self, name: str, value: object) -> None:
+        """Create a constexpr argument and add its host-side definition if needed."""
+        created = self.constexpr_arg(name, value)
+        host_expr = self._constexpr_args[name].host_str()
+        if created or name not in self._constexpr_host_defs:
+            self.codegen.host_statements.append(
+                statement_from_string(f"{name} = {host_expr}")
+            )
+            self._constexpr_host_defs.add(name)
+
+    def _format_constexpr_value(self, value: object) -> str:
+        if isinstance(value, str):
+            return value
+        if isinstance(value, (int, float, bool)):
+            return repr(value)
+
+        # Extract sympy expression from torch symbolic types
+        if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
+            value = value._sympy_()
+
+        # Handle sympy expressions (sanitize by replacing triton_helpers functions)
+        if isinstance(value, sympy.Expr):
+            expr = cast(
+                "sympy.Expr",
+                value.replace(
+                    lambda node: isinstance(node, sympy.Function)
+                    and getattr(node.func, "__name__", "")
+                    == "triton_helpers.div_floor_integer",
+                    lambda node: sympy.floor(node.args[0] / node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
+                ).replace(
+                    lambda node: isinstance(node, sympy.Function)
+                    and getattr(node.func, "__name__", "")
+                    == "triton_helpers.remainder_integer",
+                    lambda node: sympy.Mod(node.args[0], node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
+                ),
+            )
+            return HostFunction.current().sympy_expr(expr)
+
+        return HostFunction.current().literal_expr(value)
+
     def _tensor_property(
         self,
         prop_cls: type[_P],
@@ -556,7 +594,12 @@ def codegen_function_def(self) -> list[ast.stmt]:
         ]
 
     def codegen_function_call(self) -> ast.AST:
-        args = [arg.host_str() for arg in self.sorted_args()]
+        args = []
+        for arg in self.sorted_args():
+            if isinstance(arg, ConstExprArg) and arg.name in self._constexpr_host_defs:
+                args.append(arg.name)
+            else:
+                args.append(arg.host_str())
 
         if self.has_rng_ops():
             # Pass the host-side seed buffer variable to the kernel

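For context, a minimal standalone sketch of the sympy rewrite that the new _format_constexpr_value applies, using plain sympy only. The symbol names and the assumption that the offending terms appear as sympy.Function applications named triton_helpers.* are illustrative, not taken from the commit:

    import sympy

    # Hypothetical stand-ins for lifted integer arguments (names assumed).
    a, b = sympy.symbols("a b", integer=True, positive=True)

    # Assumption: the problematic terms show up as applied sympy.Function
    # objects whose __name__ is the triton_helpers function name.
    div_floor = sympy.Function("triton_helpers.div_floor_integer")
    remainder = sympy.Function("triton_helpers.remainder_integer")
    expr = div_floor(a, b) + remainder(a, b)

    # Same two .replace() passes as _format_constexpr_value: floor-division
    # and remainder become plain sympy operations.
    clean = expr.replace(
        lambda node: isinstance(node, sympy.Function)
        and getattr(node.func, "__name__", "") == "triton_helpers.div_floor_integer",
        lambda node: sympy.floor(node.args[0] / node.args[1]),
    ).replace(
        lambda node: isinstance(node, sympy.Function)
        and getattr(node.func, "__name__", "") == "triton_helpers.remainder_integer",
        lambda node: sympy.Mod(node.args[0], node.args[1]),
    )

    print(clean)  # e.g. floor(a/b) + Mod(a, b)

After this rewrite the host-side string can be printed as ordinary Python/sympy operations rather than referencing triton_helpers, which is the point of the commit.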
helion/_compiler/inductor_lowering.py

Lines changed: 1 addition & 2 deletions
@@ -1241,8 +1241,7 @@ def _create_named_result(self, node: Node, result: ast.expr) -> str:
         ):
             # This expression is used in tl.arange, make it a constexpr
             name = self.cg.device_function.new_var(node.name)
-            host_expr = self.cg.device_function.sympy_expr(val._sympy_())
-            self.cg.device_function.constexpr_arg(name, host_expr)
+            self.cg.device_function.constexpr_arg(name, val._sympy_())
             return name
 
         # If the lowering produced a named value that is already defined elsewhere

helion/_compiler/tile_strategy.py

Lines changed: 1 addition & 6 deletions
@@ -244,12 +244,7 @@ def _setup_block_size_constexpr(
         self, state: CodegenState, block_size_var: str, block_size: SymIntLike
     ) -> None:
         """Helper to setup constexpr block size variable on host."""
-        if state.device_function.constexpr_arg(block_size_var):
-            state.codegen.host_statements.append(
-                statement_from_string(
-                    f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
-                )
-            )
+        state.device_function.constexpr_arg_with_host_def(block_size_var, block_size)
 
 
 class BlockSizeTileStrategy(TileStrategy):

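For illustration, a toy sketch of the define-once bookkeeping that constexpr_arg_with_host_def centralizes (assumed, simplified classes; not Helion's actual API), showing why callers such as _setup_block_size_constexpr no longer need to append host statements themselves and why repeat registrations do not emit duplicates:

    class ConstExprRegistry:
        """Toy model of the define-once bookkeeping; not Helion's real classes."""

        def __init__(self) -> None:
            self.args: dict[str, str] = {}    # constexpr kernel args -> host expr
            self.host_defs: set[str] = set()  # names already defined on the host
            self.host_statements: list[str] = []

        def constexpr_arg(self, name: str, host_str: str) -> bool:
            if name in self.args:
                return False
            self.args[name] = host_str
            return True

        def constexpr_arg_with_host_def(self, name: str, host_str: str) -> None:
            created = self.constexpr_arg(name, host_str)
            if created or name not in self.host_defs:
                self.host_statements.append(f"{name} = {host_str}")
                self.host_defs.add(name)


    reg = ConstExprRegistry()
    reg.constexpr_arg_with_host_def("_BLOCK_SIZE_0", "1")
    reg.constexpr_arg_with_host_def("_BLOCK_SIZE_0", "1")  # second caller: no duplicate
    print(reg.host_statements)  # ['_BLOCK_SIZE_0 = 1']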
test/test_constexpr.expected

Lines changed: 2 additions & 1 deletion
@@ -68,8 +68,9 @@ def matmul_int4_block_expr(A: torch.Tensor, B: torch.Tensor, *, _launcher=_defau
     C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
     _NUM_SM = helion.runtime.get_num_sm(A.device)
     _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_1 = 1
     _BLOCK_SIZE_0 = 1
-    _launcher(_helion_matmul_int4_block_expr, (_NUM_SM,), B, A, C, _NUM_SM, _BLOCK_SIZE_2, 1, 1, 2 * _BLOCK_SIZE_0, num_warps=1, num_stages=8)
+    _launcher(_helion_matmul_int4_block_expr, (_NUM_SM,), B, A, C, _NUM_SM, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=1, num_stages=8)
     return C
 
 --- assertExpectedJournal(TestConstExpr.test_constexpr_float)
