Skip to content

Commit 702d5d4

Browse files
committed
Allow config.linker to change and affect both Mode and FAST_RUN abstract modes
1 parent c623124 commit 702d5d4

File tree

5 files changed

+47
-46
lines changed

5 files changed

+47
-46
lines changed

pytensor/compile/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
)
1818
from pytensor.compile.io import In, Out, SymbolicInput, SymbolicOutput
1919
from pytensor.compile.mode import (
20+
CVM,
2021
FAST_COMPILE,
21-
FAST_RUN,
2222
JAX,
2323
NUMBA,
2424
OPT_FAST_COMPILE,
@@ -33,6 +33,7 @@
3333
PYTORCH,
3434
AddDestroyHandler,
3535
AddFeatureOptimizer,
36+
C,
3637
Mode,
3738
PrintCurrentFunctionGraph,
3839
get_default_mode,

pytensor/compile/mode.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import logging
77
import warnings
8-
from typing import Literal
8+
from typing import Any, Literal
99

1010
from pytensor.compile.function.types import Supervisor
1111
from pytensor.configdefaults import config
@@ -62,23 +62,17 @@ def register_linker(name, linker):
6262
predefined_linkers[name] = linker
6363

6464

65-
# If a string is passed as the optimizer argument in the constructor
66-
# for Mode, it will be used as the key to retrieve the real optimizer
67-
# in this dictionary
68-
exclude = []
69-
if not config.cxx:
70-
exclude = ["cxx_only"]
71-
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
65+
OPT_NONE = RewriteDatabaseQuery(include=[])
7266
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
73-
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
67+
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"])
7468
# Even if multiple merge optimizer call will be there, this shouldn't
7569
# impact performance.
76-
OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
77-
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
70+
OPT_MERGE = RewriteDatabaseQuery(include=["merge"])
71+
OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"])
7872
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
7973

80-
OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclude)
81-
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
74+
OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"])
75+
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"])
8276
OPT_STABILIZE.position_cutoff = 1.5000001
8377
OPT_NONE.name = "OPT_NONE"
8478
OPT_MINIMUM.name = "OPT_MINIMUM"
@@ -316,6 +310,8 @@ def __init__(
316310
):
317311
if linker is None:
318312
linker = config.linker
313+
if isinstance(linker, str) and linker == "auto":
314+
linker = "cvm" if config.cxx else "vm"
319315
if isinstance(optimizer, str) and optimizer == "default":
320316
optimizer = config.optimizer
321317

@@ -451,24 +447,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
451447
return new_mode
452448

453449

454-
# If a string is passed as the mode argument in function or
455-
# FunctionMaker, the Mode will be taken from this dictionary using the
456-
# string as the key
457-
# Use VM_linker to allow lazy evaluation by default.
458-
FAST_COMPILE = Mode(
459-
VMLinker(use_cloop=False, c_thunks=False),
460-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
461-
)
462-
if config.cxx:
463-
FAST_RUN = Mode("cvm", "fast_run")
464-
else:
465-
FAST_RUN = Mode(
466-
"vm",
467-
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
468-
)
469-
470450
C = Mode("c", "fast_run")
471451
CVM = Mode("cvm", "fast_run")
452+
VM = (Mode("vm", "fast_run"),)
472453

473454
NUMBA = Mode(
474455
NumbaLinker(),
@@ -489,10 +470,19 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
489470
RewriteDatabaseQuery(include=["fast_run"]),
490471
)
491472

473+
FAST_COMPILE = Mode(
474+
VMLinker(use_cloop=False, c_thunks=False),
475+
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
476+
)
477+
478+
fast_run_linkers_to_mode = {
479+
"cvm": CVM,
480+
"vm": VM,
481+
"numba": NUMBA,
482+
}
492483

493484
predefined_modes = {
494485
"FAST_COMPILE": FAST_COMPILE,
495-
"FAST_RUN": FAST_RUN,
496486
"C": C,
497487
"CVM": CVM,
498488
"JAX": JAX,
@@ -501,7 +491,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
501491
"MLX": MLX,
502492
}
503493

504-
_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
494+
_CACHED_RUNTIME_MODES: dict[Any, Mode] = {}
505495

506496

507497
def get_mode(orig_string):
@@ -519,10 +509,20 @@ def get_mode(orig_string):
519509
if upper_string in predefined_modes:
520510
return predefined_modes[upper_string]
521511

512+
if upper_string == "FAST_RUN":
513+
linker = config.linker
514+
if linker == "auto":
515+
return CVM if config.cxx else VM
516+
return fast_run_linkers_to_mode[linker]
517+
522518
global _CACHED_RUNTIME_MODES
523519

524-
if upper_string in _CACHED_RUNTIME_MODES:
525-
return _CACHED_RUNTIME_MODES[upper_string]
520+
cache_key = ("MODE", config.linker) if upper_string == "MODE" else upper_string
521+
522+
try:
523+
return _CACHED_RUNTIME_MODES[cache_key]
524+
except KeyError:
525+
pass
526526

527527
# Need to define the mode for the first time
528528
if upper_string == "MODE":
@@ -548,7 +548,7 @@ def get_mode(orig_string):
548548
if config.optimizer_requiring:
549549
ret = ret.requiring(*config.optimizer_requiring.split(":"))
550550
# Cache the mode for next time
551-
_CACHED_RUNTIME_MODES[upper_string] = ret
551+
_CACHED_RUNTIME_MODES[cache_key] = ret
552552

553553
return ret
554554

pytensor/configdefaults.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,12 @@ def add_compile_configvars():
371371
)
372372
del param
373373

374-
default_linker = "cvm"
374+
default_linker = "auto"
375375

376376
if rc == 0 and config.cxx != "":
377377
# Keep the default linker the same as the one for the mode FAST_RUN
378378
linker_options = [
379+
"cvm",
379380
"c|py",
380381
"py",
381382
"c",
@@ -401,9 +402,8 @@ def add_compile_configvars():
401402

402403
config.add(
403404
"linker",
404-
"Default linker used if the pytensor flags mode is Mode",
405-
# Not mutable because the default mode is cached after the first use.
406-
EnumStr(default_linker, linker_options, mutable=False),
405+
"Default linker used if the pytensor flags mode is Mode or FAST_RUN",
406+
EnumStr(default_linker, linker_options, mutable=True),
407407
in_c_key=False,
408408
)
409409

pytensor/gradient.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,14 +1784,14 @@ def max_err(self, g_pt, abs_tol, rel_tol):
17841784

17851785
def mode_not_slow(mode):
17861786
from pytensor.compile.debugmode import DebugMode
1787-
from pytensor.compile.mode import FAST_RUN, get_mode
1787+
from pytensor.compile.mode import get_mode
17881788

17891789
if mode == "FAST_COMPILE":
1790-
return FAST_RUN
1790+
return get_mode("FAST_RUN")
17911791
mode = get_mode(mode)
17921792
if isinstance(mode, DebugMode):
17931793
opt = mode.optimizer
1794-
return FAST_RUN.clone(optimizer=opt)
1794+
return get_mode("FAST_RUN").clone(optimizer=opt)
17951795
else:
17961796
return mode
17971797

tests/scan/test_views.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytensor.tensor as pt
55
from pytensor import config, function, grad, shared
6-
from pytensor.compile.mode import FAST_RUN
6+
from pytensor.compile.mode import get_mode
77
from pytensor.link.basic import JITLinker
88
from pytensor.scan.views import filter as pt_filter
99
from pytensor.scan.views import foldl, foldr
@@ -65,7 +65,7 @@ def test_reduce_memory_consumption():
6565
pt.constant(np.asarray(0.0, dtype=config.floatX)),
6666
return_updates=False,
6767
)
68-
mode = FAST_RUN
68+
mode = get_mode("FAST_RUN")
6969
mode = mode.excluding("inplace")
7070
f1 = function([], o, mode=mode)
7171
inputs, outputs = clone_optimized_graph(f1)
@@ -106,7 +106,7 @@ def test_foldl_memory_consumption(return_updates):
106106
else:
107107
o = o_raw
108108

109-
mode = FAST_RUN
109+
mode = get_mode("FAST_RUN")
110110
mode = mode.excluding("inplace")
111111
f0 = function([], o, mode=mode)
112112
inputs, outputs = clone_optimized_graph(f0)
@@ -147,7 +147,7 @@ def test_foldr_memory_consumption(return_updates):
147147
else:
148148
o = o_raw
149149

150-
mode = FAST_RUN
150+
mode = get_mode("FAST_RUN")
151151
mode = mode.excluding("inplace")
152152
f1 = function([], o, mode=mode)
153153
inputs, outputs = clone_optimized_graph(f1)

0 commit comments

Comments
 (0)