1 change: 1 addition & 0 deletions examples/matmul.py
@@ -28,6 +28,7 @@
@helion.kernel(
    # static_shapes=True gives a performance boost for matmuls
    static_shapes=True,
    autotune_config_overrides={"indexing": "tensor_descriptor"},
)
def matmul(
    x: Tensor,
3 changes: 3 additions & 0 deletions helion/_compiler/compile_environment.py
@@ -99,6 +99,9 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
        self.device_load_count = (
            0  # Track number of loads in all device code for eviction policy tuning
        )
        self.device_store_count = (
            0  # Track number of stores in all device code for epilogue subtiling
        )

    def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
        from .device_function import contains_only_block_size_symbols
6 changes: 6 additions & 0 deletions helion/_compiler/device_function.py
@@ -250,6 +250,7 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
        self.rng_seed_count = 0
        self.device_load_index = 0  # Track which load in device code we're generating (for eviction policy tuning)
        self.device_store_index = 0  # Track which store in device code we're generating (for subtiling)
        # Name of the RNG seed buffer parameter in kernel signature
        self.rng_seed_buffer_param_name = None

    def has_rng_ops(self) -> bool:
@@ -420,9 +421,14 @@ def tensor_arg(
    def tensor_descriptor_arg(
        self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
    ) -> TensorDescriptorArg:
        import re

        host_function = HostFunction.current()
        block_size_expr = ", ".join(map(self.literal_expr, block_size))
        # Rewrite triton_helpers.div_floor_integer(x, n) in the block-size
        # expression into plain floor division,
        # e.g. "triton_helpers.div_floor_integer(x, 8)" -> "x // 8"
        pattern = r'triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)'
Contributor
@yf225 didn't you add something to fix this somewhere else?

Contributor
yes - we have a sanitization pass for triton_helpers.* right now at:

if isinstance(value, sympy.Expr):
    sanitized = value.replace(  # pyright: ignore[reportAttributeAccessIssue]
        lambda node: isinstance(node, sympy.Function)
        and getattr(node.func, "__name__", "")
        == "triton_helpers.div_floor_integer",
        lambda node: sympy.floor(node.args[0] / node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
    ).replace(  # pyright: ignore[reportAttributeAccessIssue]
        lambda node: isinstance(node, sympy.Function)
        and getattr(node.func, "__name__", "")
        == "triton_helpers.remainder_integer",
        lambda node: sympy.Mod(node.args[0], node.args[1]),  # pyright: ignore[reportAttributeAccessIssue]
    )
    expr = cast("sympy.Expr", sanitized)
    return HostFunction.current().sympy_expr(expr)
That handles the constexpr arg; maybe we can extract a common util function to be used in both sites.
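
A minimal sketch of such a shared helper (the name and placement are hypothetical, not from the codebase):

    import sympy

    def sanitize_triton_helper_calls(value: sympy.Expr) -> sympy.Expr:
        """Rewrite triton_helpers.* integer ops into plain sympy equivalents."""
        return value.replace(
            lambda node: isinstance(node, sympy.Function)
            and getattr(node.func, "__name__", "") == "triton_helpers.div_floor_integer",
            lambda node: sympy.floor(node.args[0] / node.args[1]),
        ).replace(
            lambda node: isinstance(node, sympy.Function)
            and getattr(node.func, "__name__", "") == "triton_helpers.remainder_integer",
            lambda node: sympy.Mod(node.args[0], node.args[1]),
        )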

        replacement = r'\1 // \2'
        block_size_expr = re.sub(pattern, replacement, block_size_expr)
        key = (fake_value, block_size_expr)

        if key not in self._tensor_descriptor_args:
            origin = host_function.tensor_to_origin[fake_value]
            desc_name = self.new_var(origin.suggest_var_name() + "_desc")
47 changes: 33 additions & 14 deletions helion/_compiler/device_ir.py
@@ -1076,7 +1076,7 @@ def visit_For(self, node: ast.For) -> None:
        self.generic_visit(node)


def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int]:
    """Count the load and store operations in all device code (for eviction
    policy tuning and epilogue subtiling, respectively)."""
    from ..language import memory_ops

@@ -1087,26 +1087,29 @@ def _count_device_loads(device_ir: DeviceIR) -> int:
        if info.new_graph_id is not None
    }

    load_count, store_count = 0, 0
    # Walk all graphs except rolled duplicates
    for graph_info in device_ir.graphs:
        if graph_info.graph_id in rolled_graph_ids:
            continue

        for node in graph_info.graph.nodes:
            # Check if this is a load or store operation
            if node.op == "call_function":
                if node.target is memory_ops.load:
                    # Only count loads without explicit eviction policy
                    # (user can still specify eviction_policy to override tuning)
                    # Check kwargs first, then check if 4th arg (eviction_policy) is None
                    eviction_policy_arg = node.kwargs.get("eviction_policy")
                    if eviction_policy_arg is None:
                        # Check if eviction_policy was passed as positional arg (index 3)
                        if len(node.args) >= 4:
                            eviction_policy_arg = node.args[3]
                        if eviction_policy_arg is None:
                            load_count += 1
                elif node.target is memory_ops.store:
                    store_count += 1
    return load_count, store_count


def _register_eviction_policy_tunable(load_count: int) -> None:
@@ -1124,6 +1127,21 @@ def _register_eviction_policy_tunable(load_count: int) -> None:
    env.config_spec.load_eviction_policies = fragment
    env.device_load_count = load_count

def _register_epilogue_subtile_tunable(store_count: int) -> None:
    """Register the epilogue subtile tunable for all device stores."""
    if store_count == 0:
        return

    from ..autotuner.config_fragment import EnumFragment
    from ..autotuner.config_fragment import ListOf
    from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES

    env = CompileEnvironment.current()
    # Register a tunable epilogue subtile size for each device store
    fragment = ListOf(EnumFragment(VALID_EPILOGUE_SUBTILE_SIZES), length=store_count)
    env.config_spec.epilogue_subtiling = fragment
    env.device_store_count = store_count
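
With a single device store, for example, this adds one dimension to the autotune search space: epilogue_subtiling becomes one of [0], [2], or [4], per VALID_EPILOGUE_SUBTILE_SIZES.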


def lower_to_device_ir(func: HostFunction) -> DeviceIR:
    device_ir = DeviceIR()
@@ -1148,8 +1166,9 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
    CompileEnvironment.current().config_spec.disallow_pid_type("xyz")

    # Count all device loads and stores, then register the eviction policy
    # and epilogue subtile tunables
    load_count, store_count = _count_device_loads_and_stores(device_ir)
    _register_eviction_policy_tunable(load_count)
    _register_epilogue_subtile_tunable(store_count)

    return device_ir

104 changes: 103 additions & 1 deletion helion/_compiler/indexing_strategy.py
@@ -15,6 +15,7 @@
from .. import exc
from .._compat import get_tensor_descriptor_fn_name
from .ast_extension import expr_from_string
from .ast_extension import statement_from_string
from .compile_environment import CompileEnvironment
from .device_function import DeviceFunction
from .host_function import HostFunction
@@ -103,6 +104,12 @@ def _get_tile_with_offset_info(

    return None

def _supports_epilogue_subtiling() -> bool:
    # Epilogue subtiling is only attempted on CUDA devices with compute
    # capability >= (10, 0), and can be disabled via settings
    env = CompileEnvironment.current()
    if env.device.type != "cuda" or not env.settings.allow_epilogue_subtiling:
        return False
    return torch.cuda.get_device_capability() >= (10, 0)


class IndexingStrategy:
    def codegen_store(

@@ -376,6 +383,7 @@ def codegen_store(
        subscript: list[object],
        value: ast.AST,
        extra_mask: ast.AST | None,
        epilogue_subtile: int | None,
    ) -> ast.AST:
        if not self.is_supported(state, fake_tensor, subscript, extra_mask):
            return PointerIndexingStrategy().codegen_store(
@@ -384,6 +392,10 @@
        assert extra_mask is None
        indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript)

        config = DeviceFunction.current().config
        if _supports_epilogue_subtiling() and config.epilogue_subtiling:
            subtile_store = self._codegen_epilogue_subtile_store(
                state, fake_tensor, indexing, indexing.reshape_store(state, value)
            )
            if subtile_store is not None:
                return subtile_store

        # Apply permutation to the value being stored if needed
        desc_arg = indexing.tensor_descriptor_arg(state)
        store_value = indexing.reshape_store(state, value)
@@ -394,12 +406,102 @@
                f"tl.permute({{store_val}}, {desc_arg.permutation!r})",
                store_val=store_value,
            )

        return expr_from_string(
            f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})",
            value=store_value,
        )

    def _codegen_epilogue_subtile_store(
Contributor
Is applying subtiling to the store() only sufficient? Or do we want to have a graph pass that collects any pointwise ops flowing into the store?

You will want any pointwise ops as well.

Contributor Author
Oh, so basically performing the epilogue on the subtile?

Yes, exactly. Epilogue subtiling is to avoid materializing all the registers needed to compute the result over the entire TMEM-allocated tile.

        self,
        state: CodegenState,
        fake_tensor: torch.Tensor,
        indexing: BlockedSubscriptIndexing,
        store_value: ast.AST,
    ) -> ast.AST | None:
        # Currently supports only 2D tiles without permutations
        if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2:
            return None

        env = CompileEnvironment.current()
        block_m, block_n = indexing.block_shape
        try:
            block_n_hint = env.size_hint(block_n)
        except Exception:
            return None

        if block_n_hint % 2 != 0:
            return None

        device_fn = state.device_function
        codegen = state.codegen

        block_m_str = device_fn.literal_expr(block_m)
        block_n_str = device_fn.literal_expr(block_n)
        # Halve the descriptor's N block size; the two halves are stored separately
        indexing.block_shape[1] //= 2
        desc_arg = indexing.tensor_descriptor_arg(state)

        if desc_arg.permutation is not None:
            return None


        block_n_half_str = f"({block_n_str} // 2)"

        # Lift the store value into a temporary variable for reuse
        acc_var = codegen.lift(store_value, prefix="acc")

        reshape_expr = expr_from_string(
            "tl.reshape({acc}, [{dim_m}, 2, {dim_half}])",
            acc=acc_var,
            dim_m=expr_from_string(block_m_str),
            dim_half=expr_from_string(block_n_half_str),
        )
        reshape_var = codegen.lift(reshape_expr, prefix="acc")

        permute_expr = expr_from_string(
            "tl.permute({acc}, [0, 2, 1])",
            acc=reshape_var,
        )
        permute_var = codegen.lift(permute_expr, prefix="acc")

        acc0_name = codegen.tmpvar(prefix="acc")
        acc1_name = codegen.tmpvar(prefix="acc")
        codegen.add_statement(
            statement_from_string(
                f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
                acc=permute_var,
            )
        )
        acc0 = expr_from_string(acc0_name)
        acc1 = expr_from_string(acc1_name)

        desc_name = indexing.tensor_descriptor(state)
        offset0 = expr_from_string(indexing.offsets[0])
        offset1 = expr_from_string(indexing.offsets[1])

        # First subtile store
        codegen.add_statement(
            statement_from_string(
                f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
                off0=offset0,
                off1=offset1,
                value=acc0,
            )
        )

        offset1_shifted = expr_from_string(
            "({offset} + {half})",
            offset=expr_from_string(indexing.offsets[1]),
            half=expr_from_string(block_n_half_str),
        )

        # Emit the second subtile store as the expression returned to the caller
        return expr_from_string(
            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
            off0=offset0,
            off1=offset1_shifted,
            value=acc1,
        )
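
For reference, the Triton this emits for one store looks roughly like the following (variable and offset names are illustrative; the real ones come from lift()/tmpvar()):

    acc = tl.reshape(acc, [BLOCK_M, 2, BLOCK_N // 2])
    acc = tl.permute(acc, [0, 2, 1])
    acc_0, acc_1 = tl.split(acc)  # two [BLOCK_M, BLOCK_N // 2] subtiles
    out_desc.store([offset_m, offset_n], acc_0)
    out_desc.store([offset_m, offset_n + BLOCK_N // 2], acc_1)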

class StackIndexingStrategy:
    """
11 changes: 11 additions & 0 deletions helion/autotuner/config_spec.py
@@ -52,10 +52,12 @@
"pid_type",
"indexing",
"load_eviction_policies",
"epilogue_subtiling"
]
)
VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved")
VALID_EVICTION_POLICIES = ("", "first", "last")
VALID_EPILOGUE_SUBTILE_SIZES = (0, 2, 4)


@dataclasses.dataclass
@@ -105,6 +107,9 @@ class ConfigSpec:
            EnumFragment(choices=VALID_EVICTION_POLICIES), length=0
        )
    )
    epilogue_subtiling: ListOf = dataclasses.field(
        default_factory=lambda: ListOf(
            EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=0
        )
    )

    @staticmethod
    def _valid_indexing_types() -> tuple[IndexingLiteral, ...]:
@@ -208,6 +213,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"range_flattens",
"static_ranges",
"load_eviction_policies",
"epilogue_subtiling",
):
if not config.get(name):
config.pop(name, None)
@@ -217,6 +223,7 @@
        config.setdefault(
            "load_eviction_policies", self.load_eviction_policies.default()
        )
        config.setdefault("epilogue_subtiling", self.epilogue_subtiling.default())
        # TODO(jansel): include num_ctas and max_nreg

        for name, values in (
@@ -231,6 +238,9 @@
            else:
                config[name] = values[0]

if config["indexing"] != "tensor_descriptor" or any(block_id < 16 for block_id in config["block_sizes"]):
for i in range(len(config["epilogue_subtiling"])):
config["epilogue_subtiling"][i] = 0
        # Set default values for grid indices when pid_type is not persistent
        pid_type = config["pid_type"]
        if pid_type in ("flat", "xyz") and self.grid_block_ids:
@@ -289,6 +299,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Config:
"indexing": fn(EnumFragment(self._valid_indexing_types())),
"pid_type": fn(EnumFragment(self.allowed_pid_types)),
"load_eviction_policies": fn(self.load_eviction_policies),
"epilogue_subtiling": fn(self.epilogue_subtiling),
}
# Add tunable parameters
config.update(
7 changes: 5 additions & 2 deletions helion/language/memory_ops.py
@@ -55,6 +55,7 @@ def _(
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
    epilogue_subtile: int | None = None,
) -> tuple[
    torch.Tensor | tuple,
    list[object],
@@ -68,10 +69,10 @@
    index = Tile._tiles_to_sizes(index)

    if isinstance(tensor, StackTensor):
        return (tuple(tensor), index, value, extra_mask, epilogue_subtile)

    if isinstance(tensor, torch.Tensor):
        return (tensor, index, value, extra_mask, epilogue_subtile)

    raise NotImplementedError(f"Cannot store to type: {type(tensor)}")
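
If the public hl.store wrapper forwards this argument through (a sketch; the wiring beyond these stubs is not shown in this diff), a kernel could pin the choice per store instead of leaving it to the autotuner:

    import helion.language as hl

    # Hypothetical per-store override inside a kernel body:
    # epilogue_subtile=2 requests a 2-way split of this store's epilogue;
    # None (the default) leaves the decision to the autotuner.
    hl.store(out, [tile_m, tile_n], acc, epilogue_subtile=2)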

@@ -82,6 +83,7 @@ def _(
    index: list[object],
    value: torch.Tensor | torch.SymInt | float,
    extra_mask: torch.Tensor | None = None,
    epilogue_subtile: int | None = None,
) -> None:
    return None

@@ -93,6 +95,7 @@ def _(state: CodegenState) -> ast.AST:
    assert isinstance(subscript, (list, tuple))
    value = state.ast_arg(2)
    extra_mask = state.ast_args[3]
    assert isinstance(extra_mask, (type(None), ast.AST))

    if isinstance(tensor, torch.Tensor):
6 changes: 6 additions & 0 deletions helion/runtime/config.py
@@ -39,6 +39,7 @@
        num_stages: int | None = None,
        pid_type: PidTypeLiteral | None = None,
        indexing: IndexingLiteral | None = None,
        epilogue_subtiling: list[int] | None = None,
        # For user-defined properties
        **kwargs: object,
    ) -> None:
@@ -61,6 +62,7 @@
            num_stages: Number of stages for software pipelining.
            pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
            indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
            epilogue_subtiling: Per-store epilogue subtile sizes (0 disables subtiling for that store).
            **kwargs: Additional user-defined configuration parameters.
        """
        self.config = {}
@@ -81,6 +83,7 @@
"num_stages": num_stages,
"indexing": indexing,
"pid_type": pid_type,
"epilogue_subtiling": epilogue_subtiling,
}
for key, value in core_props.items():
if value is not None:
Expand Down Expand Up @@ -206,6 +209,9 @@ def load_eviction_policies(self) -> list[str]:
    def indexing(self) -> IndexingLiteral:
        return self.config.get("indexing", "pointer")  # type: ignore[return-value]

    @property
    def epilogue_subtiling(self) -> list[int]:
        return cast("list[int]", self.config.get("epilogue_subtiling", []))
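
As a usage sketch (values illustrative), the new knob rides along in a Config like any other tunable:

    import helion

    # One entry per device store; 0 disables subtiling for that store.
    # normalize() zeroes entries unless indexing="tensor_descriptor" and
    # all block sizes are at least 16.
    config = helion.Config(
        block_sizes=[64, 64, 32],
        indexing="tensor_descriptor",
        epilogue_subtiling=[2],
    )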

def _to_hashable(x: object) -> object:
    if isinstance(x, list):
Expand Down