|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 4 | +""" |
| 5 | +Script to autotune Helion kernels using the CustomOp registry. |
| 6 | +
|
| 7 | +This script discovers all registered Helion kernels and runs autotuning |
| 8 | +to generate optimized configurations for different input shapes. |
| 9 | +
|
| 10 | +Usage: |
| 11 | + # Autotune all Helion kernels |
| 12 | + python scripts/autotune_helion_kernels.py |
| 13 | +
|
| 14 | + # Autotune specific kernel |
| 15 | + python scripts/autotune_helion_kernels.py --kernel silu_mul_fp8_helion |
| 16 | +
|
| 17 | + # Autotune with custom output directory |
| 18 | + python scripts/autotune_helion_kernels.py --output-dir ./my_configs |
| 19 | +
|
| 20 | + # List available Helion kernels |
| 21 | + python scripts/autotune_helion_kernels.py --list |
| 22 | +
|
| 23 | +Requirements: |
| 24 | + - CUDA GPU available |
| 25 | + - Helion package installed |
| 26 | + - vLLM environment setup |
| 27 | +""" |
| 28 | + |
| 29 | +import argparse |
| 30 | +import os |
| 31 | +import sys |
| 32 | +import time |
| 33 | + |
| 34 | +import torch |
| 35 | + |
| 36 | +# Add vLLM to path if not already available |
| 37 | +try: |
| 38 | + from vllm.compilation.helion.config_manager import ConfigManager |
| 39 | + from vllm.compilation.helion.custom_op import HelionCustomOp |
| 40 | + from vllm.config import VllmConfig |
| 41 | + from vllm.config.compilation import CompilationConfig |
| 42 | + from vllm.config.vllm import set_current_vllm_config |
| 43 | + from vllm.logger import init_logger |
| 44 | + from vllm.model_executor.custom_op import CustomOp |
| 45 | +except ImportError as e: |
| 46 | + print(f"Error importing vLLM: {e}") |
| 47 | + print("Please ensure vLLM is installed and in your Python path") |
| 48 | + sys.exit(1) |
| 49 | + |
| 50 | +logger = init_logger("vllm.scripts.autotune_helion_kernels") |
| 51 | + |
| 52 | + |
def get_default_config_dir() -> str:
    """Return the default directory for Helion kernel configs.

    Delegates to ConfigManager so this script and the vLLM runtime agree
    on where config files live.

    Returns:
        Default path for Helion configs as a string.
    """
    return str(ConfigManager().get_base_dir())
| 62 | + |
| 63 | + |
def get_helion_kernels() -> dict[str, type[HelionCustomOp]]:
    """Discover all registered Helion kernels.

    Scans the CustomOp registry and keeps only entries whose class is a
    HelionCustomOp subclass.

    Returns:
        Dictionary mapping kernel names to their classes.
    """
    return {
        name: op_cls
        for name, op_cls in CustomOp.op_registry.items()
        if issubclass(op_cls, HelionCustomOp)
    }
| 78 | + |
| 79 | + |
def list_kernels():
    """Print every registered Helion kernel with a short description."""
    kernels = get_helion_kernels()

    if not kernels:
        print("No Helion kernels found in registry.")
        return

    print("Available Helion kernels:")
    print("=" * 50)

    for name, op_cls in kernels.items():
        # Use only the first docstring line as the one-line summary.
        description = (op_cls.__doc__ or "No description available").strip()
        first_line = description.split("\n")[0]
        print(f" {name:<30} - {first_line}")

    print(f"\nTotal: {len(kernels)} kernels")
| 98 | + |
| 99 | + |
def check_requirements() -> bool:
    """
    Check if all requirements are met for autotuning.

    Verifies, in order, that a CUDA GPU is visible and that the Helion
    package is installed; logs an error and stops at the first failure.

    Returns:
        True if requirements are met, False otherwise
    """
    requirement_checks = (
        (torch.cuda.is_available,
         "CUDA is not available. Helion autotuning requires GPU."),
        (HelionCustomOp.is_helion_available,
         "Helion is not installed. Please install Helion package."),
    )

    for is_satisfied, error_message in requirement_checks:
        if not is_satisfied():
            logger.error(error_message)
            return False

    return True
| 118 | + |
| 119 | + |
def autotune_kernel(
    kernel_name: str, op_cls: type[HelionCustomOp], output_dir: str, force: bool = False
) -> bool:
    """
    Autotune a specific Helion kernel and persist the generated configs.

    Args:
        kernel_name: Name of the kernel
        op_cls: Kernel class
        output_dir: Output directory for configs
        force: Force re-autotuning even if configs exist

    Returns:
        True if successful, False otherwise
    """
    try:
        # Create kernel instance
        logger.info("Autotuning kernel: %s", kernel_name)

        # Skip enabled check during autotuning - we want to force autotune for
        # Helion kernels. The issue is that the compilation config is being
        # reset to ['none'] by system defaults
        logger.info(
            "Forcing autotuning for %s (bypassing enabled check)", kernel_name
        )

        kernel_instance = op_cls()

        # One ConfigManager serves both the existence checks below and the
        # saving step at the end (the original created it twice).
        config_manager = ConfigManager(output_dir)

        # Get autotune inputs to check what will be generated
        autotune_inputs = kernel_instance.get_autotune_inputs()
        logger.info(
            "Will generate %d configs: %s",
            len(autotune_inputs),
            list(autotune_inputs.keys()),
        )

        # Filter out existing configs (unless forcing)
        configs_to_autotune = autotune_inputs
        if not force:
            existing_configs = []
            configs_to_autotune = {}

            for config_key, inputs in autotune_inputs.items():
                if config_manager.config_exists(kernel_instance.__class__, config_key):
                    existing_configs.append(config_key)
                    logger.info("Config %s already exists, skipping", config_key)
                else:
                    configs_to_autotune[config_key] = inputs

            if existing_configs and configs_to_autotune:
                logger.info(
                    "Found existing configs for %s, will autotune remaining: %s",
                    existing_configs,
                    list(configs_to_autotune.keys()),
                )
            elif existing_configs and not configs_to_autotune:
                logger.info(
                    "All configs already exist for %s, use --force to re-generate",
                    existing_configs,
                )
                return True
            elif not existing_configs:
                logger.info(
                    "No existing configs found, will autotune all: %s",
                    list(configs_to_autotune.keys()),
                )

        if not configs_to_autotune:
            logger.info("No configs to autotune for %s", kernel_name)
            return True

        # Run autotuning with filtered inputs
        start_time = time.time()
        configs = kernel_instance.autotune(configs_to_autotune)
        end_time = time.time()

        # Save the generated configs; a failure to save one config should
        # not prevent the others from being written.
        config_manager.ensure_base_dir_exists()

        saved_configs = []
        for config_key, config in configs.items():
            try:
                config_path = config_manager.save_config(
                    kernel_instance.__class__, config_key, config
                )
                saved_configs.append(config_key)
                logger.info("Saved config %s to: %s", config_key, config_path)
            except Exception:
                # logger.exception keeps the traceback, unlike the previous
                # logger.error(f"...{e}") form.
                logger.exception("Failed to save config %s", config_key)

        logger.info("Autotuning completed in %.2fs", end_time - start_time)
        logger.info(
            "Generated and saved %d configs for %s: %s",
            len(saved_configs),
            kernel_name,
            saved_configs,
        )

        return True

    except Exception:
        logger.exception("Failed to autotune %s", kernel_name)
        return False
| 219 | + |
| 220 | + |
def main():
    """Parse CLI arguments and autotune the selected Helion kernels."""
    # __doc__ is None when docstrings are stripped (python -OO); guard so
    # building the epilog never raises a TypeError.
    module_doc = __doc__ or ""
    parser = argparse.ArgumentParser(
        description="Autotune Helion kernels",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=module_doc.split("Usage:")[1] if "Usage:" in module_doc else "",
    )

    parser.add_argument(
        "--kernel", type=str, help="Specific kernel to autotune (default: all kernels)"
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for config files (default: <vllm_repo>/vllm/compilation/helion/configs)",
    )

    parser.add_argument(
        "--list", action="store_true", help="List available Helion kernels and exit"
    )

    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-autotuning even if config files already exist",
    )

    parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")

    args = parser.parse_args()

    # Set up logging
    if args.verbose:
        import logging

        logging.getLogger("vllm").setLevel(logging.DEBUG)

    # List kernels if requested
    if args.list:
        list_kernels()
        return

    # Check requirements (CUDA GPU + Helion package)
    if not check_requirements():
        sys.exit(1)

    # Configure vLLM to enable all custom ops for autotuning.
    # This overrides the default "none" setting when using inductor backend.
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            custom_ops=["all"]  # Enable all custom ops including Helion kernels
        )
    )

    # Set the config context for this autotuning session
    set_current_vllm_config(vllm_config)
    logger.info("Enabled all custom ops for autotuning")

    # Get available kernels
    helion_kernels = get_helion_kernels()

    if not helion_kernels:
        logger.error("No Helion kernels found in registry")
        sys.exit(1)

    # Filter to specific kernel if requested
    if args.kernel:
        if args.kernel not in helion_kernels:
            logger.error("Kernel '%s' not found. Available kernels:", args.kernel)
            for name in helion_kernels:
                logger.error("  - %s", name)
            sys.exit(1)
        helion_kernels = {args.kernel: helion_kernels[args.kernel]}

    # Determine output directory
    output_dir = args.output_dir if args.output_dir else get_default_config_dir()

    # Create output directory and verify it is writable before spending
    # time on autotuning.
    try:
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Output directory: %s", output_dir)

        test_file = os.path.join(output_dir, ".write_test")
        try:
            with open(test_file, "w") as f:
                f.write("test")
            os.remove(test_file)
        except Exception as e:
            logger.error("Output directory is not writable: %s", e)
            sys.exit(1)

    except Exception as e:
        logger.error("Failed to create output directory '%s': %s", output_dir, e)
        sys.exit(1)

    # Autotune kernels
    total_kernels = len(helion_kernels)
    successful = 0

    logger.info("Starting autotuning for %d kernel(s)", total_kernels)

    for kernel_name, op_cls in helion_kernels.items():
        if autotune_kernel(kernel_name, op_cls, output_dir, args.force):
            successful += 1
        else:
            logger.warning("Skipped or failed: %s", kernel_name)

    # Summary
    logger.info("=" * 50)
    logger.info(
        "Autotuning complete: %d/%d kernels successful", successful, total_kernels
    )

    if successful < total_kernels:
        logger.warning("%d kernels failed or were skipped", total_kernels - successful)
        sys.exit(1)
    else:
        logger.info("All kernels autotuned successfully!")
| 339 | + |
| 340 | + |
# Script entry point: run the autotuning CLI when executed directly.
if __name__ == "__main__":
    main()
0 commit comments