55
66import logging
77import warnings
8- from typing import Literal
8+ from typing import Any , Literal
99
1010from pytensor .compile .function .types import Supervisor
1111from pytensor .configdefaults import config
@@ -62,23 +62,17 @@ def register_linker(name, linker):
6262 predefined_linkers [name ] = linker
6363
6464
65- # If a string is passed as the optimizer argument in the constructor
66- # for Mode, it will be used as the key to retrieve the real optimizer
67- # in this dictionary
68- exclude = []
69- if not config .cxx :
70- exclude = ["cxx_only" ]
71- OPT_NONE = RewriteDatabaseQuery (include = [], exclude = exclude )
65+ OPT_NONE = RewriteDatabaseQuery (include = [])
7266# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
73- OPT_MINIMUM = RewriteDatabaseQuery (include = ["minimum_compile" ], exclude = exclude )
67+ OPT_MINIMUM = RewriteDatabaseQuery (include = ["minimum_compile" ])
7468# Even if multiple merge optimizer call will be there, this shouldn't
7569# impact performance.
76- OPT_MERGE = RewriteDatabaseQuery (include = ["merge" ], exclude = exclude )
77- OPT_FAST_RUN = RewriteDatabaseQuery (include = ["fast_run" ], exclude = exclude )
70+ OPT_MERGE = RewriteDatabaseQuery (include = ["merge" ])
71+ OPT_FAST_RUN = RewriteDatabaseQuery (include = ["fast_run" ])
7872OPT_FAST_RUN_STABLE = OPT_FAST_RUN .requiring ("stable" )
7973
80- OPT_FAST_COMPILE = RewriteDatabaseQuery (include = ["fast_compile" ], exclude = exclude )
81- OPT_STABILIZE = RewriteDatabaseQuery (include = ["fast_run" ], exclude = exclude )
74+ OPT_FAST_COMPILE = RewriteDatabaseQuery (include = ["fast_compile" ])
75+ OPT_STABILIZE = RewriteDatabaseQuery (include = ["fast_run" ])
8276OPT_STABILIZE .position_cutoff = 1.5000001
8377OPT_NONE .name = "OPT_NONE"
8478OPT_MINIMUM .name = "OPT_MINIMUM"
@@ -316,6 +310,8 @@ def __init__(
316310 ):
317311 if linker is None :
318312 linker = config .linker
313+ if isinstance (linker , str ) and linker == "auto" :
314+ linker = "cvm" if config .cxx else "vm"
319315 if isinstance (optimizer , str ) and optimizer == "default" :
320316 optimizer = config .optimizer
321317
@@ -451,24 +447,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
451447 return new_mode
452448
453449
454- # If a string is passed as the mode argument in function or
455- # FunctionMaker, the Mode will be taken from this dictionary using the
456- # string as the key
457- # Use VM_linker to allow lazy evaluation by default.
458- FAST_COMPILE = Mode (
459- VMLinker (use_cloop = False , c_thunks = False ),
460- RewriteDatabaseQuery (include = ["fast_compile" , "py_only" ]),
461- )
462- if config .cxx :
463- FAST_RUN = Mode ("cvm" , "fast_run" )
464- else :
465- FAST_RUN = Mode (
466- "vm" ,
467- RewriteDatabaseQuery (include = ["fast_run" , "py_only" ]),
468- )
469-
470450C = Mode ("c" , "fast_run" )
471451CVM = Mode ("cvm" , "fast_run" )
452+ VM = (Mode ("vm" , "fast_run" ),)
472453
473454NUMBA = Mode (
474455 NumbaLinker (),
@@ -489,10 +470,19 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
489470 RewriteDatabaseQuery (include = ["fast_run" ]),
490471)
491472
473+ FAST_COMPILE = Mode (
474+ VMLinker (use_cloop = False , c_thunks = False ),
475+ RewriteDatabaseQuery (include = ["fast_compile" , "py_only" ]),
476+ )
477+
478+ fast_run_linkers_to_mode = {
479+ "cvm" : CVM ,
480+ "vm" : VM ,
481+ "numba" : NUMBA ,
482+ }
492483
493484predefined_modes = {
494485 "FAST_COMPILE" : FAST_COMPILE ,
495- "FAST_RUN" : FAST_RUN ,
496486 "C" : C ,
497487 "CVM" : CVM ,
498488 "JAX" : JAX ,
@@ -501,7 +491,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
501491 "MLX" : MLX ,
502492}
503493
504- _CACHED_RUNTIME_MODES : dict [str , Mode ] = {}
494+ _CACHED_RUNTIME_MODES : dict [Any , Mode ] = {}
505495
506496
507497def get_mode (orig_string ):
@@ -519,10 +509,20 @@ def get_mode(orig_string):
519509 if upper_string in predefined_modes :
520510 return predefined_modes [upper_string ]
521511
512+ if upper_string == "FAST_RUN" :
513+ linker = config .linker
514+ if linker == "auto" :
515+ return CVM if config .cxx else VM
516+ return fast_run_linkers_to_mode [linker ]
517+
522518 global _CACHED_RUNTIME_MODES
523519
524- if upper_string in _CACHED_RUNTIME_MODES :
525- return _CACHED_RUNTIME_MODES [upper_string ]
520+ cache_key = ("MODE" , config .linker ) if upper_string == "MODE" else upper_string
521+
522+ try :
523+ return _CACHED_RUNTIME_MODES [cache_key ]
524+ except KeyError :
525+ pass
526526
527527 # Need to define the mode for the first time
528528 if upper_string == "MODE" :
@@ -548,7 +548,7 @@ def get_mode(orig_string):
548548 if config .optimizer_requiring :
549549 ret = ret .requiring (* config .optimizer_requiring .split (":" ))
550550 # Cache the mode for next time
551- _CACHED_RUNTIME_MODES [upper_string ] = ret
551+ _CACHED_RUNTIME_MODES [cache_key ] = ret
552552
553553 return ret
554554
0 commit comments