Skip to content

Commit b5cc3bf

Browse files
andrijapaumudit2812mlxd
authored
[Bug Fix Release] raise warning if jax>0.6.2 is installed (#7949)
**Context:** Jax released a [new version](https://github.com/jax-ml/jax/releases/tag/jax-v0.7.0) with a lot of breaking changes. Unfortunately our `v0.42` release is not compatible with this version and so there are runtime errors. ℹ️ : This PR was created by branching the v0.42.0 tag. **Description of the Change:** - The fix will check if the user has `jax` and if they do it will perform a version check to ensure compatibility. ```python >>> import pennylane as qml ``` ``` ~/Documents/pennylane/pennylane/__init__.py:29: RuntimeWarning: PennyLane is not yet compatible with JAX versions > 0.6.2. You have version 0.7.0 installed. Please downgrade JAX to <=0.6.2 to avoid runtime errors. ``` - We need to install the stable version of `pennylane-lightning` and `pennylane-catalyst` instead of the latest, because the latest version of these packages is now incompatible with this branch (they are using the `pennylane.exceptions` module for some custom exceptions). This change will not be merged into master. **Benefits:** Better UI. **Possible Drawbacks:** None. [sc-96104] --------- Co-authored-by: Mudit Pandey <[email protected]> Co-authored-by: Lee James O'Riordan <[email protected]>
1 parent 2a74df8 commit b5cc3bf

File tree

14 files changed

+61
-20
lines changed

14 files changed

+61
-20
lines changed

.github/workflows/interface-dependency-versions.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ jobs:
7575

7676
- name: Nightly Catalyst Version
7777
id: catalyst
78-
run: echo "nightly=--index https://test.pypi.org/simple/ --prerelease=allow --upgrade-package PennyLane-Catalyst PennyLane-Catalyst" >> $GITHUB_OUTPUT
78+
run: echo "nightly=--upgrade-package PennyLane-Catalyst PennyLane-Catalyst" >> $GITHUB_OUTPUT
7979

8080
- name: PennyLane-Lightning Latest Version
8181
id: pennylane-lightning
82-
run: echo "latest=--index https://test.pypi.org/simple/ --prerelease=allow --upgrade-package PennyLane-Lightning PennyLane-Lightning" >> $GITHUB_OUTPUT
82+
run: echo "latest=--upgrade-package PennyLane-Lightning PennyLane-Lightning" >> $GITHUB_OUTPUT
8383

8484
outputs:
8585
catalyst-jax-version: jax==${{ steps.catalyst-jax.outputs.version }} jaxlib==${{ steps.catalyst-jax.outputs.version }}

doc/development/release_notes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ Release notes
33

44
This page contains the release notes for PennyLane.
55

6+
.. mdinclude:: ../releases/changelog-0.42.1.md
7+
68
.. mdinclude:: ../releases/changelog-0.42.0.md
79

810
.. mdinclude:: ../releases/changelog-0.41.1.md

doc/releases/changelog-0.42.0.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
:orphan:
22

3-
# Release 0.42.0 (current release)
3+
# Release 0.42.0
44

55
<h3>New features since last release</h3>
66

doc/releases/changelog-0.42.1.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
:orphan:
2+
3+
# Release 0.42.1 (current release)
4+
5+
<h3>Bug fixes 🐛</h3>
6+
7+
* A warning is raised if PennyLane is imported and a version of JAX greater than 0.6.2 is installed.
8+
[(#7949)](https://github.com/PennyLaneAI/pennylane/pull/7949)
9+
10+
<h3>Contributors ✍️</h3>
11+
12+
This release contains contributions from (in alphabetical order):
13+
14+
Andrija Paurevic

pennylane/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,21 @@
186186
from pennylane.liealg import lie_closure, structure_constants, center
187187
import pennylane.qnn
188188

189+
190+
from importlib.metadata import version as _metadata_version
191+
from importlib.util import find_spec as _find_spec
192+
from packaging.version import Version as _Version
193+
194+
if _find_spec("jax") is not None:
195+
if (jax_version := _Version(_metadata_version("jax"))) > _Version("0.6.2"): # pragma: no cover
196+
warnings.warn(
197+
"PennyLane is not yet compatible with JAX versions > 0.6.2. "
198+
f"You have version {jax_version} installed. "
199+
"Please downgrade JAX to 0.6.2 to avoid runtime errors using "
200+
"python -m pip install jax~=0.6.0 jaxlib~=0.6.0",
201+
RuntimeWarning,
202+
)
203+
189204
# Look for an existing configuration file
190205
default_config = Configuration("config.toml")
191206

pennylane/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
Version number (major.minor.patch[-label])
1717
"""
1818

19-
__version__ = "0.42.0"
19+
__version__ = "0.42.1"

pennylane/capture/base_interpreter.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
"""
1515
This submodule defines a strategy structure for defining custom plxpr interpreters
1616
"""
17-
# pylint: disable=no-self-use
1817
from copy import copy
1918
from functools import partial, wraps
19+
20+
# pylint: disable=no-self-use, wrong-import-position
21+
from importlib.metadata import version
2022
from typing import Callable, Optional, Sequence
2123

2224
import jax
25+
from packaging.version import Version
2326

2427
import pennylane as qml
2528
from pennylane import math
@@ -635,15 +638,22 @@ class FlattenedInterpreter(PlxprInterpreter):
635638
"""
636639

637640

641+
jax_version = version("jax")
642+
if Version(jax_version) > Version("0.6.2"): # pragma: no cover
643+
from jax._src.pjit import jit_p as pjit_p
644+
else: # pragma: no cover
645+
from jax._src.pjit import pjit_p
646+
647+
638648
# pylint: disable=protected-access
639-
@FlattenedInterpreter.register_primitive(jax._src.pjit.pjit_p)
649+
@FlattenedInterpreter.register_primitive(pjit_p)
640650
def _(self, *invals, jaxpr, **params):
641651
if jax.config.jax_dynamic_shapes:
642652
# just evaluate it so it doesn't throw dynamic shape errors
643653
return copy(self).eval(jaxpr.jaxpr, jaxpr.consts, *invals)
644654

645-
subfuns, params = jax._src.pjit.pjit_p.get_bind_params({"jaxpr": jaxpr, **params})
646-
return jax._src.pjit.pjit_p.bind(*subfuns, *invals, **params)
655+
subfuns, params = pjit_p.get_bind_params({"jaxpr": jaxpr, **params})
656+
return pjit_p.bind(*subfuns, *invals, **params)
647657

648658

649659
@FlattenedInterpreter.register_primitive(while_loop_prim)

pennylane/capture/make_plxpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def fn(x):
144144
if not has_jax: # pragma: no cover
145145
raise ImportError(
146146
"Module jax is required for the ``make_plxpr`` function. "
147-
"You can install jax via: pip install jax"
147+
"You can install jax via: pip install jax~=0.6.0"
148148
)
149149

150150
if not qml.capture.enabled():

pennylane/devices/qubit/apply_operation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ def _evolve_state_vector_under_parametrized_evolution(
711711
except ImportError as e: # pragma: no cover
712712
raise ImportError(
713713
"Module jax is required for the ``ParametrizedEvolution`` class. "
714-
"You can install jax via: pip install jax"
714+
"You can install jax via: pip install jax~=0.6.0"
715715
) from e
716716

717717
if operation.data is None or operation.t is None:

pennylane/gradients/pulse_gradient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _assert_has_jax(transform_name):
5959
if not has_jax: # pragma: no cover
6060
raise ImportError(
6161
f"Module jax is required for the {transform_name} gradient transform. "
62-
"You can install jax via: pip install jax jaxlib"
62+
"You can install jax via: pip install jax~=0.6.0 jaxlib~=0.6.0"
6363
)
6464

6565

0 commit comments

Comments
 (0)