Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions jax/_src/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,15 @@ py_library(
"//jax/_src/lib",
] + py_deps("numpy"),
)

# Library target exposing pallas_test_util.py. It is marked testonly, so only
# test (or other testonly) targets may depend on it.
py_library(
name = "pallas_test_util",
testonly = True,
srcs = [
"pallas_test_util.py",
],
deps = [
":pallas",
"//jax/_src:test_util",
],
)
55 changes: 55 additions & 0 deletions jax/_src/pallas/pallas_test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pallas test utilities."""
import sys

from jax._src import test_util as jtu
from jax._src.pallas import pallas_call
from jax.experimental import pallas as pl

# Value of the `_PALLAS_USE_MOSAIC_GPU` config flag, read once when this module
# is imported; later changes to the flag are not reflected here.
use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value


@jtu.with_config(jax_traceback_filtering="off")
class PallasTest(jtu.JaxTestCase):
  """Common base class for Pallas test cases.

  `setUp` skips the test when the device under test cannot run it:

  * With `INTERPRET = True` the test is skipped on every device except CPU.
  * With `INTERPRET = False` the test is skipped on CPU, on CUDA devices
    below compute capability 8.0, on CUDA devices below 9.0 when
    `use_mosaic_gpu` is set, and on Windows.
  """

  # Subclasses set this to True to get the interpret-mode variant; the value
  # is forwarded as `interpret=` by `pallas_call` below.
  INTERPRET: bool = False

  def setUp(self):
    """Skips unsupported device/mode combinations, then runs the base setUp."""
    on_cpu = jtu.test_device_matches(["cpu"])
    if self.INTERPRET:
      if not on_cpu:
        self.skipTest('Only run interpret tests on CPU.')
    else:
      # Non-interpret mode: the test is expected to run on an accelerator.
      if on_cpu:
        self.skipTest("On CPU the test works only in interpret mode")
      if jtu.test_device_matches(["cuda"]):
        # Compute capability is only queried once we know this is CUDA.
        if not jtu.is_cuda_compute_capability_at_least("8.0"):
          self.skipTest("Only works on GPU with capability >= sm80")
        if use_mosaic_gpu and not jtu.is_cuda_compute_capability_at_least(
            "9.0"
        ):
          self.skipTest("Mosaic GPU requires capability >= sm90")
      if sys.platform == "win32":
        self.skipTest("Only works on non-Windows platforms")
    super().setUp()

  def pallas_call(self, *args, **kwargs):
    """Calls `pl.pallas_call`, passing `interpret=self.INTERPRET`.

    All positional and keyword arguments are forwarded unchanged. Because
    `interpret` is always supplied here, callers must not pass it themselves.
    """
    return pl.pallas_call(*args, interpret=self.INTERPRET, **kwargs)


class PallasTPUTest(PallasTest):
  """Pallas test base for TPU-targeted tests.

  A test is skipped unless the device under test is a TPU or the class runs
  in interpret mode (`INTERPRET = True`); the checks inherited from
  `PallasTest.setUp` are applied afterwards.
  """

  def setUp(self):
    """Skips when neither on TPU nor in interpret mode, then defers to base."""
    on_tpu = jtu.test_device_matches(["tpu"])
    if not (on_tpu or self.INTERPRET):
      self.skipTest('Test requires TPUs')
    super().setUp()
18 changes: 13 additions & 5 deletions tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ jax_multiplatform_test(
"//jax:pallas_gpu_ops",
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
"//jax/_src/pallas:pallas_test_util",
] + py_deps([
"absl/testing",
"numpy",
Expand Down Expand Up @@ -143,6 +144,7 @@ jax_multiplatform_test(
"//jax:pallas_gpu", # build_cleaner: keep
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
"//jax/_src/pallas:pallas_test_util",
] + py_deps([
"absl/testing:flagsaver",
"absl/testing",
Expand Down Expand Up @@ -187,6 +189,7 @@ jax_multiplatform_test(
"//jax:pallas_gpu", # build_cleaner: keep
"//jax:pallas_mosaic_gpu", # build_cleaner: keep
"//jax:pallas_tpu",
"//jax/_src/pallas:pallas_test_util",
] + py_deps([
"absl/testing:flagsaver",
"absl/testing",
Expand Down Expand Up @@ -418,7 +421,10 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "tpu_pallas_test",
srcs = ["tpu_pallas_test.py"],
enable_backends = ["tpu"],
enable_backends = [
"tpu",
"cpu",
],
enable_configs = [
"tpu_v5e",
"tpu_v5p",
Expand All @@ -428,6 +434,7 @@ jax_multiplatform_test(
"//jax:mesh_utils",
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
"//jax/_src/pallas:pallas_test_util",
"//jax/extend",
] + py_deps([
"absl/testing",
Expand Down Expand Up @@ -455,6 +462,7 @@ jax_multiplatform_test(
deps = [
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
"//jax/_src/pallas:pallas_test_util",
"//jax/extend",
] + py_deps([
"absl/testing",
Expand Down Expand Up @@ -499,16 +507,14 @@ jax_multiplatform_test(
srcs = [
"tpu_ops_test.py",
],
enable_backends = [
"cpu",
"tpu",
],
enable_backends = ["tpu"],
shard_count = 8,
deps = [
"//jax:pallas",
"//jax:pallas_gpu", # build_cleaner: keep
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
"//jax/_src/pallas:pallas_test_util",
] + py_deps([
"absl/testing",
"hypothesis",
Expand Down Expand Up @@ -757,6 +763,7 @@ jax_multiplatform_test(
],
deps = [
"//jax:pallas_tpu_ops",
"//jax/_src/pallas:pallas_test_util",
] + py_deps([
"absl/testing",
"numpy",
Expand All @@ -775,6 +782,7 @@ jax_multiplatform_test(
deps = [
"//jax:pallas_tpu",
"//jax:pallas_tpu_ops",
"//jax/_src/pallas:pallas_test_util",
"//jax/extend",
] + py_deps([
"absl/testing",
Expand Down
20 changes: 5 additions & 15 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from jax._src import test_util as jtu
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives as pallas_primitives
from jax._src.pallas import pallas_test_util as ptu
from jax.experimental import pallas as pl
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
Expand Down Expand Up @@ -274,21 +275,7 @@ def select_n_strategy(
]


class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False

def setUp(self):
if not self.INTERPRET:
if jtu.device_under_test() == "cpu":
self.skipTest("Only interpret mode supported on CPU")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPUs with capability >= sm80")
if (jtu.test_device_matches(["cuda"]) and use_mosaic_gpu and
not jtu.is_cuda_compute_capability_at_least("9.0")):
self.skipTest("Mosaic GPU requires capability >= sm90")

super().setUp()
class PallasBaseTest(ptu.PallasTest):

@classmethod
def pallas_call(cls, *args, **kwargs):
Expand Down Expand Up @@ -715,6 +702,9 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize):
or from_dtype in {"int2", "uint2"}
):
self.skipTest("sub-byte casts are buggy on GPU") # b/391292861
if self.INTERPRET and (to_dtype in {"int2", "uint2"} or
from_dtype in {"int2", "uint2"}):
self.skipTest("Test fails on CPU.")
if from_dtype == "float16" or to_dtype == "float16" and not sut_is_mosaic_gpu:
self.skipTest("float16 is only supported with Mosaic GPU")
if sut_is_mosaic_gpu:
Expand Down
46 changes: 14 additions & 32 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from jax._src import dtypes
from jax._src import hijax
from jax._src import test_util as jtu
from jax._src.pallas import pallas_test_util as ptu
from jax.experimental import pallas as pl
import jax.export
import jax.numpy as jnp
Expand Down Expand Up @@ -92,26 +93,7 @@ def body(i, acc):
return matmul_kernel(x, y)


@jtu.with_config(jax_traceback_filtering="off")
class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False

def setUp(self):
if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
self.skipTest("On CPU the test works only in interpret mode")
if (jtu.test_device_matches(["cuda"]) and
not jtu.is_cuda_compute_capability_at_least("8.0")):
self.skipTest("Only works on GPU with capability >= sm80")
if sys.platform == "win32" and not self.INTERPRET:
self.skipTest("Only works on non-Windows platforms")

super().setUp()

def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)


class PallasCallTest(PallasBaseTest):
class PallasCallTest(ptu.PallasTest):

def test_add_one(self):
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
Expand Down Expand Up @@ -757,7 +739,7 @@ class PallasCallInterpretTest(PallasCallTest):
INTERPRET = True


class PallasCallElementIndexingTest(PallasBaseTest):
class PallasCallElementIndexingTest(ptu.PallasTest):

def test_block_spec_element(self):
def show_program_ids(
Expand Down Expand Up @@ -893,7 +875,7 @@ class PallasCallElementIndexingInterpretTest(PallasCallElementIndexingTest):
INTERPRET = True


class PallasCallBoundedSliceIndexingTest(PallasBaseTest):
class PallasCallBoundedSliceIndexingTest(ptu.PallasTest):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -922,7 +904,7 @@ def kernel(x_ref, o_ref):
),
)(x)

class ApiErrorTest(PallasBaseTest):
class ApiErrorTest(ptu.PallasTest):
def test_pallas_call_kernel_args_mismatch(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref: None, # Missing o_ref
Expand Down Expand Up @@ -1174,7 +1156,7 @@ class ApiErrorInterpretTest(ApiErrorTest):
INTERPRET = True


class PallasCallInputOutputAliasingTest(PallasBaseTest):
class PallasCallInputOutputAliasingTest(ptu.PallasTest):

def test_vector_input_output_aliasing(self):
# Input needs to be big so it doesn't fit in VMEM
Expand Down Expand Up @@ -1285,11 +1267,11 @@ def f(x_scalar_in, x_vector_in):
print(x_vector)


class PallasCallInputOutputAliasingInterpretTest(PallasBaseTest):
class PallasCallInputOutputAliasingInterpretTest(ptu.PallasTest):
INTERPRET = True


class PallasControlFlowTest(PallasBaseTest):
class PallasControlFlowTest(ptu.PallasTest):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -2078,7 +2060,7 @@ class PallasControlFlowInterpretTest(PallasControlFlowTest):
]


class PallasCallAutodifferentiationTest(PallasBaseTest):
class PallasCallAutodifferentiationTest(ptu.PallasTest):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -2193,7 +2175,7 @@ class PallasCallAutodifferentiationInterpretTest(PallasCallAutodifferentiationTe
INTERPRET = True


class PallasOutOfBoundsInterpretTest(PallasBaseTest):
class PallasOutOfBoundsInterpretTest(ptu.PallasTest):
INTERPRET = True

def test_interpret_mode_out_of_bounds_access(self):
Expand Down Expand Up @@ -2273,7 +2255,7 @@ def _():
np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol)


class PallasCheckifyTest(PallasBaseTest):
class PallasCheckifyTest(ptu.PallasTest):
INTERPRET = False

def test_basic_runtime_assert(self):
Expand Down Expand Up @@ -2453,7 +2435,7 @@ class PallasCheckifyInterpretTest(PallasCheckifyTest):
INTERPRET = True


class PallasCallNamedGridTest(PallasBaseTest):
class PallasCallNamedGridTest(ptu.PallasTest):
def test_named_grid(self):

def kernel(x_ref, y_ref):
Expand Down Expand Up @@ -2553,7 +2535,7 @@ def kernel(x_ref, y_ref):
)


class SymbolicPallasTest(PallasBaseTest):
class SymbolicPallasTest(ptu.PallasTest):

def test_simple_symbolic_matmul_export(self):
if jtu.test_device_matches(["gpu"]):
Expand Down Expand Up @@ -2747,7 +2729,7 @@ def index_to_lojax(xt: jax.Ref) -> jax.Array:
index_p.to_lojax = index_to_lojax


class PallasHiJaxTest(PallasBaseTest):
class PallasHiJaxTest(ptu.PallasTest):

def test_pass_weird_tuple_into_pallas_call(self):

Expand Down
18 changes: 3 additions & 15 deletions tests/pallas/tpu_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from jax import lax
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.pallas import pallas_test_util as ptu
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -76,22 +77,8 @@ def rand(
raise NotImplementedError(f"Unsupported random data generation for {dtype=}")


class PallasBaseTest(jtu.JaxTestCase):
INTERPRET = False

def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Test only supported on TPU.")

super().setUp()

@classmethod
def pallas_call(cls, *args, **kwargs):
return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs)


@jtu.thread_unsafe_test_class(condition=not jtu.hypothesis_is_thread_safe())
class OpsTest(PallasBaseTest):
class OpsTest(ptu.PallasTPUTest):

@parameterized.product(
from_dtype=_JAX_DTYPES,
Expand Down Expand Up @@ -883,5 +870,6 @@ def kernel(x_ref, o_ref):

np.testing.assert_array_equal(result, expected)


if __name__ == "__main__":
absltest.main()
Loading
Loading