helion/_compiler/device_function.py (51 additions, 8 deletions)
@@ -207,6 +207,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         ] = {}
         self._expr_args: dict[sympy.Expr, SymbolArgument] = {}
         self._constexpr_args: dict[str, ConstExprArg] = {}
+        self._constexpr_host_defs: set[str] = set()
         self._tensor_properties: dict[
             tuple[type[TensorPropertyArg], torch.Tensor, int], TensorPropertyArg
         ] = {}
@@ -282,11 +283,7 @@ def block_size_var(self, block_id: int) -> str | None:

         var_name = self.new_var(f"_BLOCK_SIZE_{block_id}")
         self.block_size_var_cache[key] = var_name
-        host_expr = HostFunction.current().literal_expr(block_value)
-        if self.constexpr_arg(var_name, host_expr):
-            self.codegen.host_statements.append(
-                statement_from_string(f"{var_name} = {host_expr}")
-            )
+        self.constexpr_arg_with_host_def(var_name, block_value)

         return self.block_size_var_cache[key]

@@ -484,14 +481,55 @@ def expr_arg(self, sym: sympy.Expr, origin: Origin) -> SymbolArgument:
         self._expr_args[sym] = arg
         return self._expr_args[sym]

-    def constexpr_arg(self, name: str, host_str: str | None = None) -> bool:
+    def constexpr_arg(self, name: str, value: object | None = None) -> bool:
         """Create a constexpr argument, returns True if created, False if already exists."""
         if name in self._constexpr_args:
             return False
-        self._constexpr_args[name] = rv = ConstExprArg(name, host_str or name)
+        host_str = name if value is None else self._format_constexpr_value(value)
+        self._constexpr_args[name] = rv = ConstExprArg(name, host_str)
         self.arguments.append(rv)
         return True

+    def constexpr_arg_with_host_def(self, name: str, value: object) -> None:
+        """Create a constexpr argument and add its host-side definition if needed."""
+        created = self.constexpr_arg(name, value)
+        host_expr = self._constexpr_args[name].host_str()
+        if created or name not in self._constexpr_host_defs:
+            self.codegen.host_statements.append(
+                statement_from_string(f"{name} = {host_expr}")
+            )
+            self._constexpr_host_defs.add(name)
+
+    def _format_constexpr_value(self, value: object) -> str:
+        if isinstance(value, str):
+            return value
+        if isinstance(value, (int, float, bool)):
+            return repr(value)
+
+        # Extract sympy expression from torch symbolic types
+        if isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
+            value = value._sympy_()
+
+        # Handle sympy expressions (sanitize by replacing triton_helpers functions)
+        if isinstance(value, sympy.Expr):
+            expr = cast(
+                "sympy.Expr",
+                value.replace(
+                    lambda node: isinstance(node, sympy.Function)
+                    and getattr(node.func, "__name__", "")
+                    == "triton_helpers.div_floor_integer",
+                    lambda node: sympy.floor(node.args[0] / node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
+                ).replace(
+                    lambda node: isinstance(node, sympy.Function)
+                    and getattr(node.func, "__name__", "")
+                    == "triton_helpers.remainder_integer",
+                    lambda node: sympy.Mod(node.args[0], node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
+                ),
+            )
+            return HostFunction.current().sympy_expr(expr)
+
+        return HostFunction.current().literal_expr(value)
+
     def _tensor_property(
         self,
         prop_cls: type[_P],
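A note on the two replace() passes in `_format_constexpr_value` above: inductor can leave opaque `triton_helpers.*` calls inside a sympy expression, and those are not valid host-side Python, so they are rewritten into plain sympy equivalents before printing. Below is a minimal standalone sketch of that sanitation; the symbols `x` and `y` are hypothetical stand-ins, and only `sympy` is required:

    import sympy

    x, y = sympy.symbols("x y", integer=True, positive=True)

    # An opaque call of the kind inductor can embed in an expression.
    opaque = sympy.Function("triton_helpers.div_floor_integer")(x, y)

    # Rewrite it into an equivalent plain-sympy form, mirroring the
    # first replace() pass in _format_constexpr_value.
    clean = opaque.replace(
        lambda node: isinstance(node, sympy.Function)
        and getattr(node.func, "__name__", "") == "triton_helpers.div_floor_integer",
        lambda node: sympy.floor(node.args[0] / node.args[1]),
    )
    print(clean)  # floor(x/y)

The second pass treats `triton_helpers.remainder_integer` the same way, mapping it to `sympy.Mod`.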
@@ -556,7 +594,12 @@ def codegen_function_def(self) -> list[ast.stmt]:
         ]

     def codegen_function_call(self) -> ast.AST:
-        args = [arg.host_str() for arg in self.sorted_args()]
+        args = []
+        for arg in self.sorted_args():
+            if isinstance(arg, ConstExprArg) and arg.name in self._constexpr_host_defs:
+                args.append(arg.name)
+            else:
+                args.append(arg.host_str())

         if self.has_rng_ops():
             # Pass the host-side seed buffer variable to the kernel
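Taken together, the changes in this file give each constexpr at most one host-side `name = expr` definition (tracked in `_constexpr_host_defs`), and `codegen_function_call` then passes such arguments by name instead of re-expanding their defining expression. A minimal sketch of that contract, using hypothetical free-standing names in place of the `DeviceFunction` state:

    host_statements: list[str] = []
    host_defs: set[str] = set()

    def constexpr_with_host_def(name: str, host_expr: str) -> None:
        # Emit the host-side definition only once per name.
        if name not in host_defs:
            host_statements.append(f"{name} = {host_expr}")
            host_defs.add(name)

    def call_arg(name: str, host_expr: str) -> str:
        # At the kernel call site, prefer the defined variable over
        # re-expanding the expression.
        return name if name in host_defs else host_expr

    constexpr_with_host_def("_BLOCK_SIZE_0", "1")
    constexpr_with_host_def("_BLOCK_SIZE_0", "1")  # second call is a no-op
    print(host_statements)                 # ['_BLOCK_SIZE_0 = 1']
    print(call_arg("_BLOCK_SIZE_0", "1"))  # '_BLOCK_SIZE_0'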
helion/_compiler/inductor_lowering.py (1 addition, 2 deletions)
@@ -1241,8 +1241,7 @@ def _create_named_result(self, node: Node, result: ast.expr) -> str:
         ):
             # This expression is used in tl.arange, make it a constexpr
             name = self.cg.device_function.new_var(node.name)
-            host_expr = self.cg.device_function.sympy_expr(val._sympy_())
-            self.cg.device_function.constexpr_arg(name, host_expr)
+            self.cg.device_function.constexpr_arg(name, val._sympy_())
             return name

         # If the lowering produced a named value that is already defined elsewhere
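With this change the call site hands `constexpr_arg` the raw sympy expression behind a `torch.SymInt` instead of a pre-formatted string, leaving formatting to `_format_constexpr_value`. A small sketch of where that expression comes from, assuming a recent PyTorch in which `ShapeEnv.create_unbacked_symint` is available:

    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # A SymInt wraps a sympy expression; _sympy_() exposes it, which is
    # what the updated call above forwards to constexpr_arg.
    s = ShapeEnv().create_unbacked_symint()
    print(type(s).__name__)  # SymInt
    print(s._sympy_())       # an unbacked symbol such as u0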
helion/_compiler/tile_strategy.py (1 addition, 6 deletions)
@@ -244,12 +244,7 @@ def _setup_block_size_constexpr(
         self, state: CodegenState, block_size_var: str, block_size: SymIntLike
     ) -> None:
         """Helper to setup constexpr block size variable on host."""
-        if state.device_function.constexpr_arg(block_size_var):
-            state.codegen.host_statements.append(
-                statement_from_string(
-                    f"{block_size_var} = {HostFunction.current().literal_expr(block_size)}"
-                )
-            )
+        state.device_function.constexpr_arg_with_host_def(block_size_var, block_size)


 class BlockSizeTileStrategy(TileStrategy):
test/test_constexpr.expected (2 additions, 1 deletion)
@@ -68,8 +68,9 @@ def matmul_int4_block_expr(A: torch.Tensor, B: torch.Tensor, *, _launcher=_defau
     C = torch.zeros(M, N, dtype=torch.bfloat16, device=A.device)
     _NUM_SM = helion.runtime.get_num_sm(A.device)
     _BLOCK_SIZE_2 = 16
+    _BLOCK_SIZE_1 = 1
     _BLOCK_SIZE_0 = 1
-    _launcher(_helion_matmul_int4_block_expr, (_NUM_SM,), B, A, C, _NUM_SM, _BLOCK_SIZE_2, 1, 1, 2 * _BLOCK_SIZE_0, num_warps=1, num_stages=8)
+    _launcher(_helion_matmul_int4_block_expr, (_NUM_SM,), B, A, C, _NUM_SM, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_0, 2 * _BLOCK_SIZE_0, num_warps=1, num_stages=8)
     return C

 --- assertExpectedJournal(TestConstExpr.test_constexpr_float)