1 change: 0 additions & 1 deletion pytensor/compile/function/types.py
@@ -1838,7 +1838,6 @@ def orig_function(
         profile.compile_time += t2 - t1
         # TODO: append
         profile.nb_nodes = len(fn.maker.fgraph.apply_nodes)
-
     return fn


51 changes: 36 additions & 15 deletions pytensor/compile/mode.py
@@ -50,6 +50,7 @@
"jax": JAXLinker(),
"pytorch": PytorchLinker(),
"numba": NumbaLinker(),
"numba_vm": NumbaLinker(vm=True),
}


@@ -63,9 +64,8 @@ def register_linker(name, linker):
 # If a string is passed as the optimizer argument in the constructor
 # for Mode, it will be used as the key to retrieve the real optimizer
 # in this dictionary
-exclude = []
-if not config.cxx:
-    exclude = ["cxx_only"]
+
+exclude = ["cxx_only", "BlasOpt"]
 OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
 # Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
 OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
@@ -351,6 +351,11 @@ def __setstate__(self, state):
             optimizer = predefined_optimizers[optimizer]
         if isinstance(optimizer, RewriteDatabaseQuery):
             self.provided_optimizer = optimizer
+
+        # Force numba-required rewrites if using NumbaLinker
+        if isinstance(linker, NumbaLinker):
+            optimizer = optimizer.including("numba")
+
         self._optimizer = optimizer
         self.call_time = 0
         self.fn_time = 0
@@ -448,19 +453,26 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 # string as the key
 # Use VM_linker to allow lazy evaluation by default.
 FAST_COMPILE = Mode(
-    VMLinker(use_cloop=False, c_thunks=False),
-    RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
+    "numba_vm",
+    # TODO: Fast_compile should just use python code, CHANGE ME!
+    RewriteDatabaseQuery(
+        include=["fast_compile", "numba"],
+        exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
+    ),
 )
-if config.cxx:
-    FAST_RUN = Mode("cvm", "fast_run")
-else:
-    FAST_RUN = Mode(
-        "vm",
-        RewriteDatabaseQuery(include=["fast_run", "py_only"]),
-    )
+FAST_RUN = Mode(
+    "numba",
+    RewriteDatabaseQuery(
+        include=["fast_run", "numba"],
+        exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
+    ),
+)
+
+C = Mode("c", "fast_run")
+C_VM = Mode("cvm", "fast_run")
 
 NUMBA = Mode(
-    NumbaLinker(),
+    "numba",
     RewriteDatabaseQuery(
         include=["fast_run", "numba"],
         exclude=[
@@ -472,8 +484,13 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     ),
 )
 
+NUMBA_VM = Mode(
+    "numba_vm",
+    NUMBA._optimizer,
+)
+
 JAX = Mode(
-    JAXLinker(),
+    "jax",
     RewriteDatabaseQuery(
         include=["fast_run", "jax"],
         exclude=[
@@ -489,7 +506,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     ),
 )
 PYTORCH = Mode(
-    PytorchLinker(),
+    "pytorch",
     RewriteDatabaseQuery(
         include=["fast_run"],
         exclude=[
@@ -508,8 +525,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 predefined_modes = {
     "FAST_COMPILE": FAST_COMPILE,
     "FAST_RUN": FAST_RUN,
+    "C": C,
+    "C_VM": C_VM,
     "JAX": JAX,
     "NUMBA": NUMBA,
+    "NUMBA_VM": NUMBA_VM,
     "PYTORCH": PYTORCH,
 }

@@ -574,6 +594,7 @@ def register_mode(name, mode):
     Add a `Mode` which can be referred to by `name` in `function`.
 
     """
+    # TODO: Remove me
     if name in predefined_modes:
         raise ValueError(f"Mode name already taken: {name}")
     predefined_modes[name] = mode
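
Note: once these modes are registered, they can be requested by name at compilation time. A minimal sketch of the intended usage, assuming the standard pytensor.function API (the toy graph is illustrative):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x).sum()

# "NUMBA_VM" resolves through predefined_modes: it reuses NUMBA's rewrite
# query but links with NumbaLinker(vm=True) via the "numba_vm" linker entry.
f = pytensor.function([x], y, mode="NUMBA_VM")
print(f([0.0, 1.0, 2.0]))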
17 changes: 14 additions & 3 deletions pytensor/configdefaults.py
@@ -370,11 +370,22 @@ def add_compile_configvars():

     if rc == 0 and config.cxx != "":
         # Keep the default linker the same as the one for the mode FAST_RUN
-        linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
+        linker_options = [
+            "cvm",
+            "c|py",
+            "py",
+            "c",
+            "c|py_nogc",
+            "vm",
+            "vm_nogc",
+            "cvm_nogc",
+            "jax",
+            "numba_vm",
+        ]
     else:
         # g++ is not present or the user disabled it,
         # linker should default to python only.
-        linker_options = ["py", "vm_nogc"]
+        linker_options = ["py", "vm", "vm_nogc", "cvm", "jax", "numba", "numba_vm"]
         if type(config).cxx.is_default:
             # If the user provided an empty value for cxx, do not warn.
             _logger.warning(
@@ -388,7 +399,7 @@ def add_compile_configvars():
"linker",
"Default linker used if the pytensor flags mode is Mode",
# Not mutable because the default mode is cached after the first use.
EnumStr("cvm", linker_options, mutable=False),
EnumStr("numba_vm", linker_options, mutable=False),
in_c_key=False,
)

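Note: because the option is created with mutable=False, the default linker can only be overridden before PyTensor caches the default mode, typically via the PYTENSOR_FLAGS environment variable. A minimal sketch (the chosen value "cvm" is illustrative):

import os

# Must be set before pytensor is imported; the linker option cannot be
# mutated afterwards because the default mode is cached on first use.
os.environ["PYTENSOR_FLAGS"] = "linker=cvm"

import pytensor

print(pytensor.config.linker)  # "cvm" rather than the new "numba_vm" default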
76 changes: 76 additions & 0 deletions pytensor/link/numba/cache.py
@@ -0,0 +1,76 @@
import weakref
from hashlib import sha256
from pathlib import Path

from numba.core.caching import CacheImpl, _CacheLocator

from pytensor import config
from pytensor.graph.basic import Apply


NUMBA_PYTENSOR_CACHE_ENABLED = True
NUMBA_CACHE_PATH = config.base_compiledir / "numba"
NUMBA_CACHE_PATH.mkdir(exist_ok=True)
CACHED_SRC_FUNCTIONS = weakref.WeakKeyDictionary()


class NumbaPyTensorCacheLocator(_CacheLocator):
def __init__(self, py_func, py_file, hash):
self._py_func = py_func
self._py_file = py_file
self._hash = hash
# src_hash = hash(pytensor_loader._module_sources[self._py_file])
# self._hash = hash((src_hash, py_file, pytensor.__version__))

def ensure_cache_path(self):
pass

def get_cache_path(self):
"""
Return the directory the function is cached in.
"""
return NUMBA_CACHE_PATH

def get_source_stamp(self):
"""
Get a timestamp representing the source code's freshness.
Can return any picklable Python object.
"""
return 0

def get_disambiguator(self):
"""
Get a string disambiguator for this locator's function.
It should allow disambiguating different but similarly-named functions.
"""
return self._hash

@classmethod
def from_function(cls, py_func, py_file):
"""
Create a locator instance for the given function located in the given file.
"""
# py_file = Path(py_file).parent
# if py_file == (config.base_compiledir / "numba"):
if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in CACHED_SRC_FUNCTIONS:
# print(f"Applies to {py_file}")
return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])


CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)


def cache_node_key(node: Apply, extra_key="") -> str:
op = node.op
return sha256(
str(
(
# Op signature
(type(op), op._props_dict() if hasattr(op, "_props_dict") else ""),
# Node signature
tuple((type(inp_type := inp.type), inp_type) for inp in node.inputs),
# Extra key given by the caller
extra_key,
),
).encode()
).hexdigest()
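
Note: the locator only claims functions that were first registered in CACHED_SRC_FUNCTIONS under a key produced by cache_node_key. A minimal sketch of that handshake, assuming the module lands at pytensor.link.numba.cache (compiled_fn is a hypothetical stand-in for a generated source function):

import pytensor.tensor as pt
from pytensor.link.numba.cache import CACHED_SRC_FUNCTIONS, cache_node_key

x = pt.vector("x")
node = pt.exp(x).owner  # Apply node whose Op will be compiled

key = cache_node_key(node, extra_key="v1")  # sha256 over Op props and input types

def compiled_fn(*args):
    ...  # hypothetical stand-in for the numba-generated function

# NumbaPyTensorCacheLocator.from_function() only returns a locator for
# functions registered here, so compiled_fn gets cached on disk under
# NUMBA_CACHE_PATH with `key` as its disambiguator.
CACHED_SRC_FUNCTIONS[compiled_fn] = key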