Conversation

@ricardoV94 (Member) commented Sep 2, 2025:

Disclaimer: this is still 100% on hack status, and I don't understand half of the things I did

When we tried #811 it was obvious that numba compile times were prohibitive.

This PR tries a different approach (still at the hacking stage): a mode more like the CVM, where individual nodes are jit-compiled but the whole graph (i.e., the "VM") is not. This allows reusing pre-compiled/cached nodes across different functions, bringing the compilation cost down.
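A toy illustration of the idea (not the PR's code; `add`/`mul` stand in for per-node kernels): each node is jitted and cached on its own, while a plain-Python driver plays the role of the VM, so there is no whole-graph compilation to pay for.

```python
import numba
import numpy as np


@numba.njit(cache=True)  # each node kernel is compiled and cached independently
def add(x, y):
    return x + y


@numba.njit(cache=True)
def mul(x, y):
    return x * y


def python_vm(x, y):
    # the "VM": plain Python orchestrating pre-compiled nodes; nothing here
    # is jitted, so compiling this "graph" costs essentially nothing
    t = add(x, y)
    return mul(t, y)


print(python_vm(np.ones(3), np.full(3, 2.0)))  # [6. 6. 6.]
```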

Making this work requires interfacing with the numba cache locator to point it at cached objects, which means defining our own cache keys. Numba usually uses the file line position and contents as the cache key, but this doesn't work for dynamically generated files (at least not ones stored in a random temp file), nor really for nested functions like those built for Elemwise. Also, some Ops are string-generated, while others are regular Python functions with globals, which numba can usually cache. All this has to be re-examined.
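A rough sketch of what hooking the locator can look like, using numba's *private* caching API (`numba.core.caching`, subject to change between releases); `PYTENSOR_KEYS` and the cache path are made up for the example:

```python
import hashlib
from numba.core.caching import CacheImpl, _CacheLocator

PYTENSOR_KEYS = {}  # filename -> precomputed stable key (hypothetical registry)


class GeneratedCodeLocator(_CacheLocator):
    def __init__(self, py_func, py_file):
        self._py_file = py_file

    def get_cache_path(self):
        return "/tmp/pytensor_numba_cache"  # wherever the on-disk cache should live

    def get_source_stamp(self):
        # replace numba's default (mtime, size) stamp with a content-derived key,
        # so regenerated temp files with identical code hit the same cache entry
        return PYTENSOR_KEYS[self._py_file]

    def get_disambiguator(self):
        return hashlib.sha256(self._py_file.encode()).hexdigest()[:10]

    @classmethod
    def from_function(cls, py_func, py_file):
        if py_file in PYTENSOR_KEYS:  # only claim files we generated ourselves
            return cls(py_func, py_file)
        return None


# locators are tried in order; registering ours first lets it claim our files
CacheImpl._locator_classes.insert(0, GeneratedCodeLocator)
```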

We are also not calling njit on the inner Ops of Elemwise (the store_core_outputs / ScalarOp), but instead using register_jitable. This was needed for caching to work: if we njit a function we always get a new object, and once serialized the numba cache key will differ, whereas register_jitable overloads the function but returns it unchanged, which doesn't change the cache key.

This requires us to move the jit away from the dispatch functionality.
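To illustrate the register_jitable point: njit returns a fresh Dispatcher object every time, while register_jitable hands back the original function, so anything referring to it (and hence the serialized cache key) stays stable. A minimal standalone check (not the PR's code):

```python
import numba
from numba.extending import register_jitable


def scalar_add(x, y):
    return x + y


same_fn = register_jitable(scalar_add)
assert same_fn is scalar_add  # unchanged object, but now callable from jitted code

wrapped = numba.njit(scalar_add)
assert wrapped is not scalar_add  # a fresh Dispatcher on every njit call


@numba.njit(cache=True)
def outer(x, y):
    return scalar_add(x, y)  # resolved via the register_jitable overload


assert outer(1.0, 2.0) == 3.0
```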

Results:

Second pass over tests/tensor/rewriting/test_basic.py (to allow everything to be compiled first):

  • 2s with the C_VM backend
  • 54s with the Numba backend
  • 34s with the Numba VM without cache
  • 4s with the Numba VM with cache

We're finally approaching the speed of the previous backend (at least for single-function compilation + eval). We could probably get all the way there with more optimization, but a small slowdown is acceptable.

TODO:

  • We are still writing Python strings to the filesystem to compile them; this is probably not needed, as explored in Cache numba stuff #1326 (last commit?)
  • We have to compile some functions that don't really need it, just so we can cache them, such as with Elemwise. This is related to https://numba.discourse.group/t/caching-redefined-functions/3057, but I don't yet have a clear picture.
  • Proper cache keys; I just hacked some quick things. Perhaps use the source code of the generated functions, as in the sketch after this list?
    • Composite key is certainly broken
    • Cache the whole FunctionGraph. This would avoid recompiling identical graphs in the regular Numba mode, not just NumbaCVM (it's also needed for correct caching of Composite/Blockwise/Scan/OpFromGraph, i.e., anything with inner Ops).
  • Figure out what happens with Ops that run in object mode?
  • Handle functions with pointers / large constants that can't traditionally be cached (not sure what's happening now). Related to cache=True failures with locally defined functions numba/numba#10098
  • Benchmark the slowdown from the "VM" approach on realistic functions. Consider using/adapting the CVM to orchestrate the calls to the individual nodes (would need to use the thunk approach). Right now the VM is the Python source code generated for the outermost unjitted FunctionGraph.
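One possible shape for the source-based cache key mentioned above (purely a sketch; salting with library versions is an assumption about what the key needs to capture):

```python
import hashlib

import numba
import pytensor


def source_cache_key(src: str) -> str:
    """Hash the generated source itself, rather than file path + line numbers."""
    h = hashlib.sha256()
    h.update(src.encode())
    h.update(numba.__version__.encode())     # compiled artifacts depend on numba
    h.update(pytensor.__version__.encode())  # ...and on how pytensor generates code
    return h.hexdigest()
```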


@jorenham (Contributor) left a comment:

left some comments; hope you don't mind

kwargs.setdefault("cache", config.numba__cache)
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)
kwargs.setdefault("cache", True)
@jorenham (Contributor) commented:

Note that numba currently can't detect changes in other modules, which could lead to outdated caches. I'm guessing that's why they've decided to make caching opt-in.

@ricardoV94 (Member, Author) commented:

Not sure what you mean? So far numba caching has been fine, to the extent that it was actually used (not much). It plays a much bigger role in this PR's approach, but we are also pretty much customizing its behavior completely.

@jorenham (Contributor) commented:

I was talking about this note from the docs (https://numba.readthedocs.io/en/stable/user/jit.html#cache):

Caching of compiled functions has several known limitations:

  • The caching of compiled functions is not performed on a function-by-function basis. The cached function is the main jit function, and all secondary functions (those called by the main function) are incorporated in the cache of the main function.
  • Cache invalidation fails to recognize changes in functions defined in a different file. This means that when a main jit function calls functions that were imported from a different module, a change in those other modules will not be detected and the cache will not be updated. This carries the risk that “old” function code might be used in the calculations.
  • Global variables are treated as constants. The cache will remember the value of the global variable at compilation time. On cache load, the cached function will not rebind to the new value of the global variable.
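For concreteness, a minimal standalone demonstration of the third point (globals baked in as compile-time constants):

```python
import numba

SCALE = 2.0


@numba.njit
def scaled(x):
    return x * SCALE


print(scaled(3.0))  # 6.0: compiled with SCALE == 2.0
SCALE = 10.0
print(scaled(3.0))  # still 6.0: the old value was frozen into the compiled code
```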

@ricardoV94 (Member, Author) commented Sep 30, 2025:

Thanks for the pointer!

Most relevant is point 2, but I don't think that's a big problem right now. Caching wasn't really working until now for our outer functions (FunctionGraph, Elemwise, Blockwise, RV), and the inner functions for which caching was working don't use any shared functionality other than intrinsics like to_fixed_tuple.

After this PR we'll move to our own caching system and define the cache keys ourselves for the functions we jit, so we can avoid that gotcha. Perhaps we have to do it for every function, not just the ones we get from codegen. Or we add the pytensor version as part of the key...

Points 1 and 3 aren't an issue for us AFAICT.

Comment on lines 19 to 23
```python
else:
    from pytensor.link.numba.dispatch.basic import numba_njit

    jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
    return jitted_fn
```
@jorenham (Contributor) commented:

I suppose there's no need for this else branch 🤷🏻

@ricardoV94 (Member, Author) commented Sep 14, 2025:

It's a stylistic choice. Should be enforced by ruff I guess

@ricardoV94 (Member, Author) commented:

Feel free to open an issue to add https://docs.astral.sh/ruff/rules/superfluous-else-return/ to our rules.

@jorenham (Contributor) commented:

> Feel free to open an issue to add docs.astral.sh/ruff/rules/superfluous-else-return to our rules.

Nah, you're right; it's personal preference. I guess I've just been brainwashed by the default ruff style haha.

Comment on lines 594 to 595
```python
with open(filename, "wb") as f:
    f.write(src.encode())
```
@jorenham (Contributor) commented:

Suggested change:

```diff
-with open(filename, "wb") as f:
-    f.write(src.encode())
+filename.write_bytes(src.encode())
```

@ricardoV94 (Member, Author) commented:

TIL
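For reference, the suggestion assumes `filename` is a `pathlib.Path`; both forms below write the same bytes (the names are illustrative, not the PR's):

```python
from pathlib import Path

filename = Path("generated_mod.py")  # hypothetical generated-module path
src = "def f(x):\n    return x\n"

# what the original lines do:
with open(filename, "wb") as f:
    f.write(src.encode())

# the suggested one-liner, equivalent for a Path object:
filename.write_bytes(src.encode())
```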

@ricardoV94 (Member, Author) commented:

> left some comments; hope you don't mind

I don't, but it's still too early for that sort of feedback. I'm just tinkering around at this point.

@ricardoV94 force-pushed the numba_cache branch 2 times, most recently from aff2a9a to 978b701 on September 14, 2025 at 10:30
@ricardoV94 changed the title from "Implement Numba VM" to "Implement Numba VM with caching of individual nodes" on Sep 14, 2025
@ricardoV94 changed the title from "Implement Numba VM with caching of individual nodes" to "HACK: Implement Numba VM with caching of individual nodes" on Sep 14, 2025
```python
)

signature = create_numba_signature(node, force_scalar=True)
# signature = create_numba_signature(node, force_scalar=True)
```
@ricardoV94 (Member, Author) commented:

This was causing eager compilation during dispatch, nullifying any caching benefits
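For context, a minimal demonstration of the cost difference: giving numba an explicit signature compiles eagerly at decoration time, while omitting it defers compilation to the first call.

```python
import numba


@numba.njit("float64(float64)")  # explicit signature: compiled right here
def eager(x):
    return x + 1.0


@numba.njit  # no signature: compilation deferred to the first call
def lazy(x):
    return x + 1.0


assert len(eager.signatures) == 1  # already compiled at decoration time
assert len(lazy.signatures) == 0   # nothing compiled yet
lazy(1.0)
assert len(lazy.signatures) == 1   # compiled on first call
```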

```python
output_core_shapes,
)

core_signature = typingctx.resolve_function_type(
```
@ricardoV94 (Member, Author) commented:

This causes compilation of the inner function, so I moved it to the codegen; otherwise we'd have to pay that cost even when we could cache it? Not sure. But even if not, I guess we generally want to compile lazily by default?
