Skip to content

Commit 634f58c

Browse files
hawkinspjax authors
authored andcommitted
Enable a number of tests on GPU.
In particular, pjit/xmap work on CPU these days. PiperOrigin-RevId: 446085110
1 parent b7293d5 commit 634f58c

File tree

4 files changed

+1
-9
lines changed

4 files changed

+1
-9
lines changed

tests/api_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2212,7 +2212,7 @@ def test_xla_computation_psum_constant(self):
22122212
f = lambda: jax.lax.psum(1, "i")
22132213
api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash
22142214

2215-
@jtu.skip_on_devices("cpu", "gpu")
2215+
@jtu.skip_on_devices("cpu")
22162216
@jtu.ignore_warning(message="Some donated buffers were not usable")
22172217
def test_xla_computation_donate_argnums(self):
22182218
api.xla_computation(lambda x: None, donate_argnums=(0,))(3) # doesn't crash

tests/debug_nans_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ def testXmap(self):
144144
ans.block_until_ready()
145145

146146
@jtu.ignore_warning(message=".*is an experimental.*")
147-
@jtu.skip_on_devices("cpu", "gpu")
148147
def testPjit(self):
149148
if jax.device_count() < 2:
150149
raise SkipTest("test requires >=2 devices")

tests/jaxpr_effects_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import unittest
1514

1615
from absl.testing import absltest
1716
from absl.testing import parameterized
@@ -190,8 +189,6 @@ def f(x):
190189
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
191190

192191
def test_pjit_inherits_effects(self):
193-
if jax.default_backend() not in {'gpu', 'tpu'}:
194-
raise unittest.SkipTest("pjit only supports GPU and TPU backends")
195192
def f(x):
196193
effect_p.bind(effect='foo')
197194
effect_p.bind(effect='bar')

tests/xmap_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,6 @@ class XMapTestCase(jtu.BufferDonationTestCase):
212212
# A mixin that enables SPMD lowering tests
213213
class SPMDTestMixin:
214214
def setUp(self):
215-
if jtu.device_under_test() not in ['tpu', 'gpu']:
216-
raise SkipTest
217215
super().setUp()
218216
jtu.set_spmd_lowering_flag(True)
219217

@@ -223,8 +221,6 @@ def tearDown(self):
223221

224222
class ManualSPMDTestMixin:
225223
def setUp(self):
226-
if jtu.device_under_test() not in ['tpu', 'gpu']:
227-
raise SkipTest
228224
if not hasattr(xla_client.OpSharding.Type, "MANUAL"):
229225
raise SkipTest
230226
super().setUp()

0 commit comments

Comments
 (0)