From c0ed00f2755469d6b93812aa801641168fd077f2 Mon Sep 17 00:00:00 2001
From: PaulZhang12
Date: Wed, 15 Oct 2025 13:57:57 -0700
Subject: [PATCH] Add epilogue subtiling

stack-info: PR: https://github.com/pytorch/helion/pull/948, branch: PaulZhang12/stack/14
---
 helion/_compiler/compile_environment.py |   1 +
 helion/_compiler/device_function.py     |   9 ++
 helion/_compiler/device_ir.py           |  53 ++++++++----
 helion/_compiler/indexing_strategy.py   | 105 +++++++++++++++++++++++-
 helion/autotuner/config_spec.py         |  16 ++++
 helion/runtime/config.py                |   7 ++
 test/test_autotuner.expected            |  60 +++++++-------
 test/test_register_tunable.expected     |   8 +-
 8 files changed, 209 insertions(+), 50 deletions(-)

diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py
index 27c39ca34..dad51afc4 100644
--- a/helion/_compiler/compile_environment.py
+++ b/helion/_compiler/compile_environment.py
@@ -99,6 +99,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
         self.device_load_count = (
             0  # Track number of loads in all device code for eviction policy tuning
         )
+        self.device_store_count = 0  # Track number of stores for subtiling

     def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
         from .device_function import contains_only_block_size_symbols
diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py
index 49abcdded..bc56017ff 100644
--- a/helion/_compiler/device_function.py
+++ b/helion/_compiler/device_function.py
@@ -250,6 +250,9 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         self.rng_seed_count = 0
         self.device_load_index = 0  # Track which load in device code we're generating (for eviction policy tuning)
+        self.device_store_index = (
+            0  # Track which store in device code we're generating (for subtiling)
+        )
         # Name of the RNG seed buffer parameter in kernel signature
         self.rng_seed_buffer_param_name = None

     def has_rng_ops(self) -> bool:
@@ -420,9 +423,15 @@ def tensor_arg(
     def tensor_descriptor_arg(
         self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
     ) -> TensorDescriptorArg:
+        import re
+
         host_function = HostFunction.current()
         block_size_expr = ", ".join(map(self.literal_expr, block_size))
+        pattern = r"triton_helpers\.div_floor_integer\(([^,]+),\s*(\d+)\)"
+        replacement = r"\1 // \2"
+        block_size_expr = re.sub(pattern, replacement, block_size_expr)
         key = (fake_value, block_size_expr)
+
         if key not in self._tensor_descriptor_args:
             origin = host_function.tensor_to_origin[fake_value]
             desc_name = self.new_var(origin.suggest_var_name() + "_desc")
diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py
index 932235661..e59e862ea 100644
--- a/helion/_compiler/device_ir.py
+++ b/helion/_compiler/device_ir.py
@@ -1076,7 +1076,7 @@ def visit_For(self, node: ast.For) -> None:
         self.generic_visit(node)


-def _count_device_loads(device_ir: DeviceIR) -> int:
-    """Count the number of load operations in all device code for eviction policy tuning."""
+def _count_device_loads_and_stores(device_ir: DeviceIR) -> tuple[int, int]:
+    """Count load and store operations in all device code for eviction policy and epilogue subtile tuning."""
     from ..language import memory_ops

@@ -1087,7 +1087,7 @@
         if info.new_graph_id is not None
     }

-    load_count = 0
+    load_count, store_count = 0, 0
     # Walk all graphs except rolled duplicates
     for graph_info in device_ir.graphs:
         if graph_info.graph_id in rolled_graph_ids:
         for node in graph_info.graph.nodes:
             # Check if this is a load operation
- if node.op == "call_function" and node.target is memory_ops.load: - # Only count loads without explicit eviction policy - # (user can still specify eviction_policy to override tuning) - # Check kwargs first, then check if 4th arg (eviction_policy) is None - eviction_policy_arg = node.kwargs.get("eviction_policy") - if eviction_policy_arg is None: - # Check if eviction_policy was passed as positional arg (index 3) - if len(node.args) >= 4: - eviction_policy_arg = node.args[3] + if node.op == "call_function": + if node.target is memory_ops.load: + # Only count loads without explicit eviction policy + # (user can still specify eviction_policy to override tuning) + # Check kwargs first, then check if 4th arg (eviction_policy) is None + eviction_policy_arg = node.kwargs.get("eviction_policy") if eviction_policy_arg is None: - load_count += 1 - return load_count + # Check if eviction_policy was passed as positional arg (index 3) + if len(node.args) >= 4: + eviction_policy_arg = node.args[3] + if eviction_policy_arg is None: + load_count += 1 + elif node.target is memory_ops.store: + store_count += 1 + return load_count, store_count def _register_eviction_policy_tunable(load_count: int) -> None: @@ -1125,6 +1128,24 @@ def _register_eviction_policy_tunable(load_count: int) -> None: env.device_load_count = load_count +def _register_epilogue_subtile_tunable(store_count: int) -> None: + """Register the epilogue subtile tunable for all device stores.""" + if store_count == 0: + return + + from ..autotuner.config_fragment import EnumFragment + from ..autotuner.config_fragment import ListOf + from ..autotuner.config_spec import VALID_EPILOGUE_SUBTILE_SIZES + + env = CompileEnvironment.current() + # Register a tunable for epilogue subtile for all device stores + fragment = ListOf( + EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=store_count + ) + env.config_spec.epilogue_subtiling = fragment + env.device_store_count = store_count + + def lower_to_device_ir(func: HostFunction) -> DeviceIR: device_ir = DeviceIR() with func, device_ir, compile_lock: @@ -1148,9 +1169,13 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: CompileEnvironment.current().config_spec.disallow_pid_type("xyz") # Count all device loads and register eviction policy tunable - load_count = _count_device_loads(device_ir) + load_count, store_count = _count_device_loads_and_stores(device_ir) _register_eviction_policy_tunable(load_count) + # Epilogue subtiling only for Blackwell + if torch.cuda.get_device_capability() >= (10, 0): + _register_epilogue_subtile_tunable(store_count) + return device_ir diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index e10d34037..75552ec18 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -15,6 +15,7 @@ from .. 
import exc from .._compat import get_tensor_descriptor_fn_name from .ast_extension import expr_from_string +from .ast_extension import statement_from_string from .compile_environment import CompileEnvironment from .device_function import DeviceFunction from .host_function import HostFunction @@ -353,7 +354,6 @@ def codegen_load( ) assert extra_mask is None indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) - # Load from tensor descriptor with permuted offsets load_expr = expr_from_string( f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})" @@ -383,10 +383,12 @@ def codegen_store( ) assert extra_mask is None indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) + store_value = indexing.reshape_store(state, value) + config = DeviceFunction.current().config + epilogue_subtiles = state.config.epilogue_subtiling # Apply permutation to the value being stored if needed desc_arg = indexing.tensor_descriptor_arg(state) - store_value = indexing.reshape_store(state, value) if desc_arg.permutation is not None: # Apply permutation to the value @@ -395,11 +397,110 @@ def codegen_store( store_val=store_value, ) + if (idx := state.device_function.device_store_index) < len(epilogue_subtiles): + subtile_split = epilogue_subtiles[idx] + state.device_function.device_store_index += 1 + + subtile_codegen = self._codegen_epilogue_subtile_store( + state, fake_tensor, indexing, store_value, subtile_split, config + ) + if subtile_codegen is not None: + return subtile_codegen + return expr_from_string( f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})", value=store_value, ) + def _codegen_epilogue_subtile_store( + self, + state: CodegenState, + fake_tensor: torch.Tensor, + indexing: BlockedSubscriptIndexing, + store_value: ast.AST, + subtile_split: int, + config: Config, + ) -> ast.AST | None: + # Currently support 2D tiles without permutations + if ( + len(indexing.block_shape) != 2 + or len(indexing.offsets) != 2 + or subtile_split == 0 + ): + return None + + env = CompileEnvironment.current() + block_m, block_n = indexing.block_shape + try: + block_n_hint = env.size_hint(block_n) + block_idx = env.get_block_id(block_n) + block_size = env.block_sizes[block_idx].from_config(config) + except Exception: + return None + + if block_n_hint % 2 != 0 or block_size <= 16: + return None + + device_fn = state.device_function + codegen = state.codegen + + block_m_str = device_fn.literal_expr(block_m) + block_n_str = device_fn.literal_expr(block_n) + indexing.block_shape[1] //= subtile_split + + # TODO(PaulZhang12): Support more epilogue subtile configs besides 2 + block_n_half_str = f"({block_n_str} // {subtile_split})" + + # Lift the store value into a temporary variable for reuse + acc_var = codegen.lift(store_value, prefix="acc") + + reshape_expr = expr_from_string( + "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)", + acc=acc_var, + dim_m=expr_from_string(block_m_str), + dim_half=expr_from_string(block_n_half_str), + ) + reshape_var = codegen.lift(reshape_expr, prefix="acc") + + acc0_name = codegen.tmpvar(prefix="acc") + acc1_name = codegen.tmpvar(prefix="acc") + codegen.add_statement( + statement_from_string( + f"{acc0_name}, {acc1_name} = tl.split({{acc}})", + acc=reshape_var, + ) + ) + acc0 = expr_from_string(acc0_name) + acc1 = expr_from_string(acc1_name) + + desc_name = indexing.tensor_descriptor(state) + offset0 = expr_from_string(indexing.offsets[0]) + offset1 = 
expr_from_string(indexing.offsets[1]) + + # First subtile store + codegen.add_statement( + statement_from_string( + f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", + off0=offset0, + off1=offset1, + value=acc0, + ) + ) + + offset1_shifted = expr_from_string( + "({offset} + {half})", + offset=expr_from_string(indexing.offsets[1]), + half=expr_from_string(block_n_half_str), + ) + + # Emit second subtile store as the expression returned to the caller + return expr_from_string( + f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", + off0=offset0, + off1=offset1_shifted, + value=acc1, + ) + class StackIndexingStrategy: """ diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py index 3649b6ec0..86ec5764c 100644 --- a/helion/autotuner/config_spec.py +++ b/helion/autotuner/config_spec.py @@ -52,10 +52,12 @@ "pid_type", "indexing", "load_eviction_policies", + "epilogue_subtiling", ] ) VALID_PID_TYPES = ("flat", "xyz", "persistent_blocked", "persistent_interleaved") VALID_EVICTION_POLICIES = ("", "first", "last") +VALID_EPILOGUE_SUBTILE_SIZES = (0, 2) @dataclasses.dataclass @@ -105,6 +107,11 @@ class ConfigSpec: EnumFragment(choices=VALID_EVICTION_POLICIES), length=0 ) ) + epilogue_subtiling: ListOf = dataclasses.field( + default_factory=lambda: ListOf( + EnumFragment(choices=VALID_EPILOGUE_SUBTILE_SIZES), length=0 + ) + ) @staticmethod def _valid_indexing_types() -> tuple[IndexingLiteral, ...]: @@ -208,6 +215,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: "range_flattens", "static_ranges", "load_eviction_policies", + "epilogue_subtiling", ): if not config.get(name): config.pop(name, None) @@ -217,6 +225,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: config.setdefault( "load_eviction_policies", self.load_eviction_policies.default() ) + config.setdefault("epilogue_subtiling", self.epilogue_subtiling.default()) # TODO(jansel): include num_ctas and max_nreg for name, values in ( @@ -231,6 +240,10 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None: else: config[name] = values[0] + if config["indexing"] != "tensor_descriptor": + for i in range(len(config["epilogue_subtiling"])): + config["epilogue_subtiling"][i] = 0 + # Set default values for grid indices when pid_type is not persistent pid_type = config["pid_type"] if pid_type in ("flat", "xyz") and self.grid_block_ids: @@ -289,6 +302,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "indexing": fn(EnumFragment(self._valid_indexing_types())), "pid_type": fn(EnumFragment(self.allowed_pid_types)), "load_eviction_policies": fn(self.load_eviction_policies), + "epilogue_subtiling": fn(self.epilogue_subtiling), } # Add tunable parameters config.update( @@ -307,9 +321,11 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf "range_flattens", "static_ranges", "load_eviction_policies", + "epilogue_subtiling", ): if not config.get(name): config.pop(name, None) + self.normalize(config) return helion.Config(**config) diff --git a/helion/runtime/config.py b/helion/runtime/config.py index f55f6563e..2800d9cd6 100644 --- a/helion/runtime/config.py +++ b/helion/runtime/config.py @@ -39,6 +39,7 @@ def __init__( num_stages: int | None = None, pid_type: PidTypeLiteral | None = None, indexing: IndexingLiteral | None = None, + epilogue_subtiling: list[int] | None = None, # For user-defined properties **kwargs: object, ) -> None: @@ -61,6 +62,7 @@ def __init__( num_stages: Number of stages for 
software pipelining.
             pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
             indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
+            epilogue_subtiling: Per-store epilogue subtile split factors (0 disables subtiling for that store).
             **kwargs: Additional user-defined configuration parameters.
         """
         self.config = {}
@@ -81,6 +83,7 @@ def __init__(
             "num_stages": num_stages,
             "indexing": indexing,
             "pid_type": pid_type,
+            "epilogue_subtiling": epilogue_subtiling,
         }
         for key, value in core_props.items():
             if value is not None:
@@ -206,6 +209,10 @@ def load_eviction_policies(self) -> list[str]:
     def indexing(self) -> IndexingLiteral:
         return self.config.get("indexing", "pointer")  # type: ignore[return-value]

+    @property
+    def epilogue_subtiling(self) -> list[int]:
+        return cast("list[int]", self.config.get("epilogue_subtiling", []))
+

 def _to_hashable(x: object) -> object:
     if isinstance(x, list):
diff --git a/test/test_autotuner.expected b/test/test_autotuner.expected
index 319ab4fc5..540923dd3 100644
--- a/test/test_autotuner.expected
+++ b/test/test_autotuner.expected
@@ -2,40 +2,40 @@ This file is automatically generated by assertExpectedJournal calls in test_auto
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

 --- assertExpectedJournal(TestAutotuner.test_config_fragment0)
-helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None])
-helion.Config(block_sizes=[32, 128, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 0], range_warp_specializes=[None, True])
-helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 3], range_warp_specializes=[None, False])
-helion.Config(block_sizes=[16, 32, 256], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[2, 0], range_warp_specializes=[True, None])
-helion.Config(block_sizes=[64, 32, 16], indexing='block_ptr', l2_groupings=[2], load_eviction_policies=['first', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[0, 4], range_unroll_factors=[0, 1], range_warp_specializes=[None, None])
-helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[32], load_eviction_policies=['last', 'first'], loop_orders=[[0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None, False], range_multi_buffers=[None, None], range_num_stages=[0, 2], range_unroll_factors=[0, 2], range_warp_specializes=[None, False])
-helion.Config(block_sizes=[16, 32, 64], indexing='block_ptr', l2_groupings=[8],
load_eviction_policies=['last', 'first'], loop_orders=[[1, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, False], range_num_stages=[0, 3], range_unroll_factors=[0, 3], range_warp_specializes=[None, None]) -helion.Config(block_sizes=[16, 32, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=8, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[False, None], range_num_stages=[3, 3], range_unroll_factors=[2, 0], range_warp_specializes=[False, True]) -helion.Config(block_sizes=[256, 16, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=5, num_warps=32, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, False], range_num_stages=[0, 1], range_unroll_factors=[0, 0], range_warp_specializes=[None, True]) -helion.Config(block_sizes=[16, 64, 16], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=3, num_warps=32, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[False, None], range_num_stages=[3, 0], range_unroll_factors=[3, 0], range_warp_specializes=[False, True]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[0], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None]) +helion.Config(block_sizes=[32, 128, 64], epilogue_subtiling=[0], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 0], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[0], indexing='pointer', l2_groupings=[4], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=1, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[1, 2], range_unroll_factors=[0, 0], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[32, 256, 64], epilogue_subtiling=[0], indexing='block_ptr', l2_groupings=[16], load_eviction_policies=['', 'last'], loop_orders=[[0, 1]], num_stages=3, num_warps=16, pid_type='persistent_blocked', range_flattens=[False, None], range_multi_buffers=[None, True], range_num_stages=[4, 0], range_unroll_factors=[3, 0], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[2], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=1, num_warps=1, pid_type='persistent_blocked', range_flattens=[False, None], range_multi_buffers=[True, None], range_num_stages=[3, 4], range_unroll_factors=[2, 0], range_warp_specializes=[False, None]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[0], indexing='block_ptr', l2_groupings=[1], load_eviction_policies=['first', 'first'], loop_orders=[[1, 0]], num_stages=4, num_warps=16, pid_type='persistent_blocked', range_flattens=[False, True], range_multi_buffers=[None, False], range_num_stages=[3, 0], 
range_unroll_factors=[4, 0], range_warp_specializes=[True, False]) +helion.Config(block_sizes=[64, 128, 32], epilogue_subtiling=[0], indexing='pointer', l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[1, 0]], num_stages=5, num_warps=32, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[False, True], range_num_stages=[3, 3], range_unroll_factors=[2, 0], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[2], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=4, num_warps=2, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, False], range_num_stages=[0, 4], range_unroll_factors=[0, 4], range_warp_specializes=[None, False]) +helion.Config(block_sizes=[16, 128, 64], epilogue_subtiling=[2], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=6, num_warps=2, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, True]) +helion.Config(block_sizes=[16, 16, 16], epilogue_subtiling=[0], indexing='block_ptr', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[1, 0]], num_stages=4, num_warps=1, pid_type='persistent_interleaved', range_flattens=[None, False], range_multi_buffers=[True, True], range_num_stages=[1, 4], range_unroll_factors=[3, 0], range_warp_specializes=[True, False]) --- assertExpectedJournal(TestAutotuner.test_config_fragment1) -helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[1, 64, 64], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[2, 8, 512], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[1, 512, 1], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[1, 4, 256], flatten_loops=[True], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0, 2]], num_stages=2, num_warps=32, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[1, 128, 16], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1, 2]], num_stages=1, 
num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[4], range_warp_specializes=[None]) -helion.Config(block_sizes=[8, 32, 256], flatten_loops=[False], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', 'last'], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=8, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[None]) -helion.Config(block_sizes=[2, 64, 32], flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[4, 32, 1], flatten_loops=[True], indexing='pointer', l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[4, 2, 128], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['', 'first'], loop_orders=[[1, 2, 0]], num_stages=2, num_warps=4, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[1], range_warp_specializes=[False]) +helion.Config(block_sizes=[8, 16, 16], epilogue_subtiling=[0], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[1, 32, 32], epilogue_subtiling=[2], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[1, 64, 1], epilogue_subtiling=[0], flatten_loops=[True], indexing='block_ptr', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[2, 1, 0]], num_stages=4, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[None], range_unroll_factors=[3], range_warp_specializes=[False]) +helion.Config(block_sizes=[2, 16, 256], epilogue_subtiling=[0], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[64], load_eviction_policies=['last', ''], loop_orders=[[1, 0, 2]], num_stages=3, num_warps=8, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[4], range_warp_specializes=[None]) +helion.Config(block_sizes=[2, 4, 256], epilogue_subtiling=[0], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['last', ''], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 2, 128], epilogue_subtiling=[0], flatten_loops=[True], indexing='block_ptr', l2_groupings=[1], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1, 2]], num_stages=6, num_warps=1, 
pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[2, 16, 2], epilogue_subtiling=[0], flatten_loops=[True], indexing='block_ptr', l2_groupings=[64], load_eviction_policies=['first', 'first'], loop_orders=[[0, 2, 1]], num_stages=4, num_warps=16, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[4, 128, 16], epilogue_subtiling=[0], flatten_loops=[True], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0, 2]], num_stages=6, num_warps=4, pid_type='persistent_interleaved', range_flattens=[False], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 256, 32], epilogue_subtiling=[0], flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', ''], loop_orders=[[2, 1, 0]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[2, 128, 8], epilogue_subtiling=[0], flatten_loops=[True], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', ''], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=32, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) --- assertExpectedJournal(TestAutotuner.test_config_warp_specialize_unroll) -helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[1, 64, 64], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[2, 8, 512], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[1, 512, 1], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[1, 4, 256], flatten_loops=[True], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0, 2]], num_stages=2, num_warps=32, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[1, 128, 16], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1, 2]], num_stages=1, 
num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[8, 32, 256], flatten_loops=[False], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', 'last'], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=8, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[2, 64, 32], flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) -helion.Config(block_sizes=[4, 32, 1], flatten_loops=[True], indexing='pointer', l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) -helion.Config(block_sizes=[4, 2, 128], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['', 'first'], loop_orders=[[1, 2, 0]], num_stages=2, num_warps=4, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[8, 16, 16], epilogue_subtiling=[0], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) +helion.Config(block_sizes=[1, 32, 32], epilogue_subtiling=[2], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[1, 64, 1], epilogue_subtiling=[0], flatten_loops=[True], indexing='block_ptr', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[2, 1, 0]], num_stages=4, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[2, 16, 256], epilogue_subtiling=[0], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[64], load_eviction_policies=['last', ''], loop_orders=[[1, 0, 2]], num_stages=3, num_warps=8, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[2, 4, 256], epilogue_subtiling=[0], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['last', ''], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 2, 128], epilogue_subtiling=[0], flatten_loops=[True], indexing='block_ptr', l2_groupings=[1], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1, 2]], num_stages=6, num_warps=1, 
pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[2, 16, 2], epilogue_subtiling=[0], flatten_loops=[True], indexing='block_ptr', l2_groupings=[64], load_eviction_policies=['first', 'first'], loop_orders=[[0, 2, 1]], num_stages=4, num_warps=16, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 128, 16], epilogue_subtiling=[0], flatten_loops=[True], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0, 2]], num_stages=6, num_warps=4, pid_type='persistent_interleaved', range_flattens=[False], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[4, 256, 32], epilogue_subtiling=[0], flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', ''], loop_orders=[[2, 1, 0]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True]) +helion.Config(block_sizes=[2, 128, 8], epilogue_subtiling=[0], flatten_loops=[True], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', ''], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=32, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None]) --- assertExpectedJournal(TestAutotuner.test_save_load_config) { diff --git a/test/test_register_tunable.expected b/test/test_register_tunable.expected index ce67fb770..4fe321a76 100644 --- a/test/test_register_tunable.expected +++ b/test/test_register_tunable.expected @@ -2,7 +2,7 @@ This file is automatically generated by assertExpectedJournal calls in test_regi Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. 
--- assertExpectedJournal(TestRegisterTunable.test_integer_fragment) -helion.Config(block_sizes=[128], indexing='pointer', load_eviction_policies=[''], multiplier=3, num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]) +helion.Config(block_sizes=[128], epilogue_subtiling=[0], indexing='pointer', load_eviction_policies=[''], multiplier=3, num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[]) --- assertExpectedJournal(TestRegisterTunable.test_integer_fragment) from __future__ import annotations @@ -12,7 +12,7 @@ import triton import triton.language as tl from helion.runtime import default_launcher as _default_launcher -import test.test_register_tunable as _source_module +import __main__ as _source_module @triton.jit def _helion_kernel_with_int_param(x, out, multiplier, _BLOCK_SIZE_0: tl.constexpr): @@ -50,7 +50,7 @@ import triton import triton.language as tl from helion.runtime import default_launcher as _default_launcher -import test.test_register_tunable as _source_module +import __main__ as _source_module @triton.jit def _helion_matmul_split_k(x, y, out, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): @@ -125,7 +125,7 @@ import triton import triton.language as tl from helion.runtime import default_launcher as _default_launcher -import test.test_register_tunable as _source_module +import __main__ as _source_module @triton.jit def _helion_kernel_with_tunable(x, out, _BLOCK_SIZE_0: tl.constexpr):
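
The store transformation generated by `_codegen_epilogue_subtile_store` above turns one `[BLOCK_M, BLOCK_N]` tensor-descriptor store into two `[BLOCK_M, BLOCK_N // 2]` stores by emitting `tl.reshape(acc, [BLOCK_M, 2, BLOCK_N // 2]).permute(0, 2, 1)` followed by `tl.split` (which requires the trailing dimension to be exactly 2). The sketch below is host-side PyTorch, not the generated Triton, and the names `acc`, `block_m`, `block_n` are made up for illustration; it only checks that the reshape/permute/split sequence recovers the left and right column halves of the tile.

```python
import torch

# Illustrative only: mirrors the index math emitted for subtile_split == 2,
# the single non-zero value in VALID_EPILOGUE_SUBTILE_SIZES.
block_m, block_n = 64, 128  # hypothetical tile shape
acc = torch.arange(block_m * block_n, dtype=torch.float32).reshape(block_m, block_n)

# Generated code: tl.reshape(acc, [BLOCK_M, 2, BLOCK_N // 2]).permute(0, 2, 1)
reshaped = acc.reshape(block_m, 2, block_n // 2).permute(0, 2, 1)

# Generated code: acc0, acc1 = tl.split(reshaped)  # splits the trailing dim of size 2
acc0 = reshaped[..., 0]
acc1 = reshaped[..., 1]

# The two subtiles are exactly the left and right halves of the tile, so storing
# acc0 at [offset0, offset1] and acc1 at [offset0, offset1 + BLOCK_N // 2]
# writes the same elements as one full-width store.
assert torch.equal(acc0, acc[:, : block_n // 2])
assert torch.equal(acc1, acc[:, block_n // 2 :])
```

In a tuned config the knob appears as, for example, `epilogue_subtiling=[2]`, with one entry per device store; `ConfigSpec.normalize` resets every entry to 0 when `indexing` is not `'tensor_descriptor'`, and the tunable is only registered on devices with compute capability `(10, 0)` or newer.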