File tree Expand file tree Collapse file tree 4 files changed +1
-9
lines changed
Expand file tree Collapse file tree 4 files changed +1
-9
lines changed Original file line number Diff line number Diff line change @@ -2212,7 +2212,7 @@ def test_xla_computation_psum_constant(self):
22122212 f = lambda : jax .lax .psum (1 , "i" )
22132213 api .xla_computation (f , axis_env = [("i" , 2 )])() # doesn't crash
22142214
2215- @jtu .skip_on_devices ("cpu" , "gpu" )
2215+ @jtu .skip_on_devices ("cpu" )
22162216 @jtu .ignore_warning (message = "Some donated buffers were not usable" )
22172217 def test_xla_computation_donate_argnums (self ):
22182218 api .xla_computation (lambda x : None , donate_argnums = (0 ,))(3 ) # doesn't crash
Original file line number Diff line number Diff line change @@ -144,7 +144,6 @@ def testXmap(self):
144144 ans .block_until_ready ()
145145
146146 @jtu .ignore_warning (message = ".*is an experimental.*" )
147- @jtu .skip_on_devices ("cpu" , "gpu" )
148147 def testPjit (self ):
149148 if jax .device_count () < 2 :
150149 raise SkipTest ("test requires >=2 devices" )
Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- import unittest
1514
1615from absl .testing import absltest
1716from absl .testing import parameterized
@@ -190,8 +189,6 @@ def f(x):
190189 jax .make_jaxpr (f )(jnp .arange (jax .local_device_count ()))
191190
192191 def test_pjit_inherits_effects (self ):
193- if jax .default_backend () not in {'gpu' , 'tpu' }:
194- raise unittest .SkipTest ("pjit only supports GPU and TPU backends" )
195192 def f (x ):
196193 effect_p .bind (effect = 'foo' )
197194 effect_p .bind (effect = 'bar' )
Original file line number Diff line number Diff line change @@ -212,8 +212,6 @@ class XMapTestCase(jtu.BufferDonationTestCase):
212212# A mixin that enables SPMD lowering tests
213213class SPMDTestMixin :
214214 def setUp (self ):
215- if jtu .device_under_test () not in ['tpu' , 'gpu' ]:
216- raise SkipTest
217215 super ().setUp ()
218216 jtu .set_spmd_lowering_flag (True )
219217
@@ -223,8 +221,6 @@ def tearDown(self):
223221
224222class ManualSPMDTestMixin :
225223 def setUp (self ):
226- if jtu .device_under_test () not in ['tpu' , 'gpu' ]:
227- raise SkipTest
228224 if not hasattr (xla_client .OpSharding .Type , "MANUAL" ):
229225 raise SkipTest
230226 super ().setUp ()
You can’t perform that action at this time.
0 commit comments