Skip to content

Commit 8338452

Browse files
authored
Autotuning Progress Bar (#739)
1 parent e4df8ca commit 8338452

File tree

7 files changed

+75
-20
lines changed

7 files changed

+75
-20
lines changed

helion/autotuner/base_search.py

Lines changed: 42 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
import torch.multiprocessing as mp
3030
from torch.utils._pytree import tree_flatten
3131
from torch.utils._pytree import tree_map
32+
from tqdm.auto import tqdm
3233
from triton.testing import do_bench
3334

3435
from .. import exc
@@ -295,13 +296,14 @@ def extract_launcher(
295296
)
296297

297298
def parallel_benchmark(
298-
self, configs: list[Config]
299+
self, configs: list[Config], *, desc: str = "Benchmarking"
299300
) -> list[tuple[Config, Callable[..., object], float]]:
300301
"""
301302
Benchmark multiple configurations in parallel.
302303
303304
Args:
304305
configs: A list of configurations to benchmark.
306+
desc: Description for the progress bar.
305307
306308
Returns:
307309
A list of tuples containing configurations and their performance.
@@ -319,7 +321,16 @@ def parallel_benchmark(
319321
else:
320322
is_workings = [True] * len(configs)
321323
results = []
322-
for config, fn, is_working in zip(configs, fns, is_workings, strict=True):
324+
iterator = zip(configs, fns, is_workings, strict=True)
325+
if self.settings.autotune_progress_bar:
326+
iterator = tqdm(
327+
iterator,
328+
total=len(configs),
329+
desc=desc,
330+
unit="config",
331+
disable=not self.settings.autotune_progress_bar,
332+
)
333+
for config, fn, is_working in iterator:
323334
if is_working:
324335
# benchmark one-by-one to avoid noisy results
325336
results.append((config, fn, self.benchmark_function(config, fn)))
@@ -479,13 +490,19 @@ def make_unbenchmarked(self, flat_values: FlatConfig) -> PopulationMember:
479490
return PopulationMember(_unset_fn, [], flat_values, config)
480491

481492
def parallel_benchmark_population(
482-
self, members: list[PopulationMember]
493+
self, members: list[PopulationMember], *, desc: str = "Benchmarking"
483494
) -> list[PopulationMember]:
484495
"""
485496
Benchmark multiple population members in parallel. Members should be created with make_unbenchmarked.
497+
498+
Args:
499+
members: The list of population members to benchmark.
500+
desc: Description for the progress bar.
486501
"""
487502
for member, (config_out, fn, perf) in zip(
488-
members, self.parallel_benchmark([m.config for m in members]), strict=True
503+
members,
504+
self.parallel_benchmark([m.config for m in members], desc=desc),
505+
strict=True,
489506
):
490507
assert config_out is member.config
491508
member.perfs.append(perf)
@@ -523,30 +540,45 @@ def should_rebenchmark(self, member: PopulationMember) -> bool:
523540
and math.isfinite(member.perf)
524541
)
525542

526-
def rebenchmark(self, members: list[PopulationMember]) -> None:
543+
def rebenchmark(
544+
self, members: list[PopulationMember], *, desc: str = "Rebenchmarking"
545+
) -> None:
527546
"""
528547
Re-benchmark a list of population members to avoid outliers.
548+
549+
Args:
550+
members: The list of population members to rebenchmark.
551+
desc: Description for the progress bar.
529552
"""
530553
if len(members) < 2:
531554
return
532555
repeat = max(3, int(200 / self.best_perf_so_far))
533-
new_timings = interleaved_bench(
534-
[functools.partial(m.fn, *self.args) for m in members], repeat=repeat
535-
)
556+
iterator = [functools.partial(m.fn, *self.args) for m in members]
557+
if self.settings.autotune_progress_bar:
558+
new_timings = interleaved_bench(iterator, repeat=repeat, desc=desc)
559+
else:
560+
new_timings = interleaved_bench(iterator, repeat=repeat)
536561
for m, t in zip(members, new_timings, strict=True):
537562
m.perfs.append(t)
538563
if t < self.best_perf_so_far:
539564
self.best_perf_so_far = t
540565

541566
def rebenchmark_population(
542-
self, members: list[PopulationMember] | None = None
567+
self,
568+
members: list[PopulationMember] | None = None,
569+
*,
570+
desc: str = "Rebenchmarking",
543571
) -> None:
544572
"""
545573
Re-benchmark the entire population to avoid outliers.
574+
575+
Args:
576+
members: The list of population members to rebenchmark.
577+
desc: Description for the progress bar.
546578
"""
547579
if members is None:
548580
members = self.population
549-
self.rebenchmark([p for p in members if self.should_rebenchmark(p)])
581+
self.rebenchmark([p for p in members if self.should_rebenchmark(p)], desc=desc)
550582

551583
def statistics(self) -> str:
552584
"""

helion/autotuner/benchmarking.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,22 @@
44
import statistics
55
from typing import Callable
66

7+
from tqdm.auto import tqdm
78
from triton import runtime
89

910

10-
def interleaved_bench(fns: list[Callable[[], object]], *, repeat: int) -> list[float]:
11+
def interleaved_bench(
12+
fns: list[Callable[[], object]], *, repeat: int, desc: str | None = None
13+
) -> list[float]:
1114
"""
1215
Benchmark multiple functions at once, interleaving their executions to reduce
1316
the impact of external factors (e.g., load, temperature) on the
1417
measurements.
18+
19+
Args:
20+
fns: List of functions to benchmark
21+
repeat: Number of times to repeat each benchmark
22+
desc: Optional description for progress bar
1523
"""
1624
# warmup
1725
for fn in fns:
@@ -30,7 +38,10 @@ def interleaved_bench(fns: list[Callable[[], object]], *, repeat: int) -> list[f
3038
]
3139

3240
di.synchronize()
33-
for i in range(repeat):
41+
iterator = range(repeat)
42+
if desc is not None:
43+
iterator = tqdm(iterator, desc=desc, total=repeat, unit="round")
44+
for i in iterator:
3445
for j in range(len(fns)):
3546
clear_cache()
3647
start_events[j][i].record()

helion/autotuner/finite_search.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,9 @@ def __init__(
3535
def _autotune(self) -> Config:
3636
best_config = None
3737
best_time = float("inf")
38-
for config, _fn, time in self.parallel_benchmark(self.configs):
38+
for config, _fn, time in self.parallel_benchmark(
39+
self.configs, desc="Benchmarking"
40+
):
3941
if time < best_time:
4042
best_time = time
4143
best_config = config

helion/autotuner/pattern_search.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -57,9 +57,9 @@ def _autotune(self) -> Config:
5757
if member.config not in visited:
5858
visited.add(member.config)
5959
self.population.append(member)
60-
self.parallel_benchmark_population(self.population)
60+
self.parallel_benchmark_population(self.population, desc="Initial population")
6161
# again with higher accuracy
62-
self.rebenchmark_population(self.population)
62+
self.rebenchmark_population(self.population, desc="Initial rebench")
6363
self.population.sort(key=performance)
6464
starting_points = []
6565
for member in self.population[: self.copies]:
@@ -90,11 +90,15 @@ def _autotune(self) -> Config:
9090
break
9191
self.population = [*new_population.values()]
9292
# compile any unbenchmarked members in parallel
93-
self.parallel_benchmark_population(
94-
[m for m in self.population if len(m.perfs) == 0]
95-
)
93+
unbenchmarked = [m for m in self.population if len(m.perfs) == 0]
94+
if unbenchmarked:
95+
self.parallel_benchmark_population(
96+
unbenchmarked, desc=f"Gen {generation} neighbors"
97+
)
9698
# higher-accuracy rebenchmark
97-
self.rebenchmark_population(self.population)
99+
self.rebenchmark_population(
100+
self.population, desc=f"Gen {generation} rebench"
101+
)
98102
self.log(
99103
f"Generation {generation}, {num_neighbors} neighbors, {num_active} active:",
100104
self.statistics,

helion/runtime/settings.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,9 @@ class _Settings:
111111
autotune_rebenchmark_threshold: float = float(
112112
os.environ.get("HELION_REBENCHMARK_THRESHOLD", "1.5")
113113
)
114+
autotune_progress_bar: bool = (
115+
os.environ.get("HELION_AUTOTUNE_PROGRESS_BAR", "1") == "1"
116+
)
114117
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
115118
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
116119
allow_warp_specialize: bool = (
@@ -142,6 +145,7 @@ class Settings(_Settings):
142145
"autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.",
143146
"autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.",
144147
"autotune_rebenchmark_threshold": "If a config is within threshold*best_perf, re-benchmark it to avoid outliers. Default is 1.5x. Set to <1 to disable.",
148+
"autotune_progress_bar": "If True, show progress bar during autotuning. Default is True. Set HELION_AUTOTUNE_PROGRESS_BAR=0 to disable.",
145149
"print_output_code": "If True, print the output code of the kernel to stderr.",
146150
"force_autotune": "If True, force autotuning even if a config is provided.",
147151
"allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,8 @@ dependencies = [
2020
"torch>=2.7.0",
2121
"typing_extensions>=4.0.0",
2222
"filecheck",
23-
"psutil"
23+
"psutil",
24+
"tqdm"
2425
]
2526

2627
[project.optional-dependencies]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,3 +4,4 @@ pre-commit
44
filecheck
55
expecttest
66
numpy
7+
tqdm

0 commit comments

Comments
 (0)