|
14 | 14 | """ |
15 | 15 | This submodule defines a strategy structure for defining custom plxpr interpreters |
16 | 16 | """ |
17 | | -# pylint: disable=no-self-use |
18 | 17 | from copy import copy |
19 | 18 | from functools import partial, wraps |
| 19 | + |
| 20 | +# pylint: disable=no-self-use, wrong-import-position |
| 21 | +from importlib.metadata import version |
20 | 22 | from typing import Callable, Optional, Sequence |
21 | 23 |
|
22 | 24 | import jax |
| 25 | +from packaging.version import Version |
23 | 26 |
|
24 | 27 | import pennylane as qml |
25 | 28 | from pennylane import math |
@@ -635,15 +638,22 @@ class FlattenedInterpreter(PlxprInterpreter): |
635 | 638 | """ |
636 | 639 |
|
637 | 640 |
|
| 641 | +jax_version = version("jax") |
| 642 | +if Version(jax_version) > Version("0.6.2"): # pragma: no cover |
| 643 | + from jax._src.pjit import jit_p as pjit_p |
| 644 | +else: # pragma: no cover |
| 645 | + from jax._src.pjit import pjit_p |
| 646 | + |
| 647 | + |
638 | 648 | # pylint: disable=protected-access |
639 | | -@FlattenedInterpreter.register_primitive(jax._src.pjit.pjit_p) |
| 649 | +@FlattenedInterpreter.register_primitive(pjit_p) |
640 | 650 | def _(self, *invals, jaxpr, **params): |
641 | 651 | if jax.config.jax_dynamic_shapes: |
642 | 652 | # just evaluate it so it doesn't throw dynamic shape errors |
643 | 653 | return copy(self).eval(jaxpr.jaxpr, jaxpr.consts, *invals) |
644 | 654 |
|
645 | | - subfuns, params = jax._src.pjit.pjit_p.get_bind_params({"jaxpr": jaxpr, **params}) |
646 | | - return jax._src.pjit.pjit_p.bind(*subfuns, *invals, **params) |
| 655 | + subfuns, params = pjit_p.get_bind_params({"jaxpr": jaxpr, **params}) |
| 656 | + return pjit_p.bind(*subfuns, *invals, **params) |
647 | 657 |
|
648 | 658 |
|
649 | 659 | @FlattenedInterpreter.register_primitive(while_loop_prim) |
|
0 commit comments