Skip to content

Commit 70b043d

Browse files
committed
Make HELION_FORCE_AUTOTUNE or kernel.autotune() skip the cache
stack-info: PR: #930, branch: jansel/stack/190
1 parent ac80a10 commit 70b043d

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

helion/autotuner/base_cache.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,13 @@ def _get_cache_info_message(self) -> str:
157157
"""Return a message describing where the cache is and how to clear it."""
158158
return ""
159159

160-
def autotune(self) -> Config:
161-
if os.environ.get("HELION_SKIP_CACHE", "") not in {"", "0", "false", "False"}:
160+
def autotune(self, *, skip_cache: bool = False) -> Config:
161+
if skip_cache or os.environ.get("HELION_SKIP_CACHE", "") not in {
162+
"",
163+
"0",
164+
"false",
165+
"False",
166+
}:
162167
return self.autotuner.autotune()
163168

164169
if (config := self.get()) is not None:

helion/autotuner/base_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class BaseAutotuner(abc.ABC):
6464
"""
6565

6666
@abc.abstractmethod
67-
def autotune(self) -> Config:
67+
def autotune(self, *, skip_cache: bool = False) -> Config:
6868
raise NotImplementedError
6969

7070

@@ -369,7 +369,7 @@ def parallel_benchmark(
369369
results.append((config, fn, inf))
370370
return results
371371

372-
def autotune(self) -> Config:
372+
def autotune(self, *, skip_cache: bool = False) -> Config:
373373
"""
374374
Perform autotuning to find the best configuration.
375375

helion/runtime/kernel.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def autotune(
253253
self,
254254
args: Sequence[object],
255255
*,
256-
force: bool = False,
256+
force: bool = True,
257257
**options: object,
258258
) -> Config:
259259
"""
@@ -475,7 +475,7 @@ def autotune(
475475
self,
476476
args: Sequence[object],
477477
*,
478-
force: bool = False,
478+
force: bool = True,
479479
**kwargs: object,
480480
) -> Config:
481481
"""
@@ -508,7 +508,9 @@ def autotune(
508508
config = FiniteSearch(self, args, self.configs).autotune()
509509
else:
510510
self.settings.check_autotuning_disabled()
511-
config = self.settings.autotuner_fn(self, args, **kwargs).autotune()
511+
config = self.settings.autotuner_fn(self, args, **kwargs).autotune(
512+
skip_cache=force
513+
)
512514

513515
self.set_config(config)
514516
return config
@@ -623,7 +625,7 @@ def __call__(self, *args: object) -> _R:
623625
if (config := self._implicit_config()) is not None:
624626
self.set_config(config)
625627
else:
626-
self.autotune(args)
628+
self.autotune(args, force=False)
627629
assert self._run is not None
628630

629631
assert self._config is not None

0 commit comments

Comments
 (0)