-
Notifications
You must be signed in to change notification settings - Fork 143
HACK: Implement Numba VM with caching of individual nodes #1604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
left some comments; hope you don't mind
kwargs.setdefault("cache", config.numba__cache) | ||
kwargs.setdefault("no_cpython_wrapper", True) | ||
kwargs.setdefault("no_cfunc_wrapper", True) | ||
kwargs.setdefault("cache", True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that the numba currently can't detect changes in other modules, which could lead to outdated caches. I'm guessing that's why they've decided to make it opt-in
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what you mean? So far numba caching has been fine, to the extent that it was actually used (not much). It plays a much bigger role in this PR approach, but we are also pretty-much customizing it's behavior completely.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was talking about this note from the docs (https://numba.readthedocs.io/en/stable/user/jit.html#cache):
Caching of compiled functions has several known limitations:
- The caching of compiled functions is not performed on a function-by-function basis. The cached function is the the main jit function, and all secondary functions (those called by the main function) are incorporated in the cache of the main function.
- Cache invalidation fails to recognize changes in functions defined in a different file. This means that when a main jit function calls functions that were imported from a different module, a change in those other modules will not be detected and the cache will not be updated. This carries the risk that “old” function code might be used in the calculations.
- Global variables are treated as constants. The cache will remember the value of the global variable at compilation time. On cache load, the cached function will not rebind to the new value of the global variable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the pointer!
Most relevant is point 2, but I don't think that's a big problem right now. Caching wasn't really working until now for our outer functions (FunctionGraph, Elemwise, Blockwise, RV), and the inner functions for which caching was working aren't using any shared functionality other than intrinsics like to_fixed_tuple
.
After this PR we'll move to our own caching system, and define the caching keys ourselves for the function we jit, and we can avoid that gotcha. Perhaps we have to do it for every function, not just the ones we get from codegen. Or we add the pytensor version as part of the key...
Points 1-3 aren't an issue for us AFAICT
else: | ||
from pytensor.link.numba.dispatch.basic import numba_njit | ||
|
||
jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False) | ||
return jitted_fn | ||
jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False) | ||
return jitted_fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose there's no need for this else
branch 🤷🏻
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a stylistic choice. Should be enforced by ruff I guess
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feel free to open an issue to add https://docs.astral.sh/ruff/rules/superfluous-else-return/ to our rules.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feel free to open an issue to add docs.astral.sh/ruff/rules/superfluous-else-return to our rules.
Nah you're right; it's personal preference. I guess I've just been brainwashed by the default ruff style haha.
pytensor/link/utils.py
Outdated
with open(filename, "wb") as f: | ||
f.write(src.encode()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with open(filename, "wb") as f: | |
f.write(src.encode()) | |
filename.write_bytes(src.encode()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TIL
I don't, but it's still too early for that sort of feedback. I'm just thinkering around at this point. |
aff2a9a
to
978b701
Compare
) | ||
|
||
signature = create_numba_signature(node, force_scalar=True) | ||
# signature = create_numba_signature(node, force_scalar=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was causing eager compilation during dispatch, nullifying any caching benefits
output_core_shapes, | ||
) | ||
|
||
core_signature = typingctx.resolve_function_type( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This causes compilation of the inner function, so I moved it to the codegen, otherwise we had to pay the cost, even if we could cache it? Not sure. But even if not, I guess we generally want to lazy compile by default?
978b701
to
3677f33
Compare
And make that the default backend
8709393
to
5bb8c9b
Compare
Disclaimer: this is still 100% on hack status, and I don't understand half of the things I did
When we tried #811 it was obvious that numba compile times were prohibitive.
This PR tries a different approach (still at the hacking stage), of using a mode more like the CVM, where individual nodes are jit_compiled, but the whole graph (i.e., the "VM") is not. This allows reusing pre-compiled/cached nodes across different functions, bringing the compilation cost down.
It requires interfacing with the numba cache locator to direct it to cached objects, which requires defining our own cache keys. Numba usually uses the file line position and contents as the cache key, but this doesn't work for dynamically generated files (at least not if stored in a random temp file) nor really for nested functions like those built for
Elemwise
, Also some Ops are string-generated, and others are regular python functions with globals which numba can usually cache. All this has to be re-examined.We are also not calling njit on inner Ops (the
store_core_outputs
/ScalarOp
) ofElemwise
, but instead doingregister_jittable
. This was needed for caching to work, because if we njit a function we always get a new object and once serialized the numba cache key will differ, whereasregister_jitable
overloads the function but returns it unchanged, which doesn't change the cache key.This requires us to move the jit away from the dispatch functionality.
Results:
We're finally approaching the speed of the previous backend (at least for single function compilation + eval). Probably could get it there with more optimizing, but a small slowdown is acceptable.
TODO:
cache=True
failures with locally defined functions numba/numba#10098