Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ pytype_strict_library(
deps = [
":core",
":lowering",
":sc_core",
":sc_lowering",
"//jax",
"//jax/_src:core",
Expand Down
67 changes: 67 additions & 0 deletions jax/_src/pallas/mosaic/sc_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import lowering as tc_lowering
from jax._src.pallas.mosaic import sc_core
from jax._src.pallas.mosaic import sc_lowering
from jax._src.state import primitives as state_primitives
from jax._src.state import types as state_types
Expand Down Expand Up @@ -634,6 +635,72 @@ def _reduce_sum_lowering_rule(
_cumsum_lowering_rule(ctx, x, 0, reverse=False), [], [vec_dim - 1])


masked_sort_p = jax_core.Primitive("masked_sort")
masked_sort_p.multiple_results = True

@masked_sort_p.def_abstract_eval
def _masked_sort_abstract_eval(keys, values, *maybe_mask, descending):
del descending # Unused.
supported_shape = (sc_core.get_sparse_core_info().num_lanes,)
if keys.dtype not in (jnp.int32, jnp.float32):
raise NotImplementedError(
f"sort_key_val: keys dtype {keys.dtype} should be int32 or float32")
if keys.shape != supported_shape:
raise ValueError(f"keys shape {keys.shape} must be {supported_shape}")
if jnp.dtype(values.dtype).itemsize != 4:
raise NotImplementedError(
f"sort_key_val: values dtype {values.dtype} should be 32 bits")
if values.shape != supported_shape:
raise ValueError(f"values shape {values.shape} must be {supported_shape}")
if maybe_mask:
[mask] = maybe_mask
if not jnp.issubdtype(mask.dtype, jnp.bool):
raise TypeError(f"mask dtype {mask.dtype} is not boolean")
if mask.shape != supported_shape:
raise ValueError(f"mask shape {mask.shape} must be {supported_shape}")
return keys, values, *maybe_mask

@sc_lowering.register_lowering_rule(masked_sort_p)
def _masked_sort_lowering_rule(
ctx: sc_lowering.LoweringRuleContext, keys, values, *maybe_mask, descending):
del ctx # Unused.
if maybe_mask:
[mask] = maybe_mask
else:
mask_type = ir.VectorType.get(
[sc_core.get_sparse_core_info().num_lanes],
ir.IntegerType.get_signless(1))
mask = arith.constant(mask_type, ir.DenseElementsAttr.get_splat(
mask_type, ir.BoolAttr.get(True)))
out_mask, sorted_keys, sorted_values = tpu.sort(
mask.type, keys.type, values.type, keys, values, mask=mask,
descending=descending
)
if maybe_mask:
return sorted_keys, sorted_values, out_mask
return sorted_keys, sorted_values

def sort_key_val(
keys: jax.Array, values: jax.Array, *,
mask: jax.Array | None = None, descending: bool = False
) -> jax.Array:
"""Sorts keys and values, pushing invalid elements to the last positions.

Args:
keys: An array of integers or floats.
values: An array of values corresponding to the keys.
mask: An optional array of booleans, which specifies which elements of
`keys` and `values` are valid. If `None`, all elements are valid.
descending: Whether to sort in descending order.

Returns:
sorted_keys, sorted_values, [output_mask]: The sorted keys and values, and,
if a mask was given, the corresponding mask for output keys and values.
"""
maybe_mask = () if mask is None else (mask,)
return masked_sort_p.bind(keys, values, *maybe_mask, descending=descending)


parallel_loop_p = jax_core.Primitive("parallel_loop")
parallel_loop_p.is_effectful = lambda params: bool(params["jaxpr"].effects) # type: ignore
parallel_loop_p.multiple_results = True
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/pallas/tpu_sc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from jax._src.pallas.mosaic.sc_primitives import PackFormat as PackFormat
from jax._src.pallas.mosaic.sc_primitives import parallel_loop as parallel_loop
from jax._src.pallas.mosaic.sc_primitives import scan_count as scan_count
from jax._src.pallas.mosaic.sc_primitives import sort_key_val as sort_key_val
from jax._src.pallas.mosaic.sc_primitives import store_compressed as store_compressed
from jax._src.pallas.mosaic.sc_primitives import store_scatter as store_scatter
from jax._src.pallas.mosaic.sc_primitives import subcore_barrier as subcore_barrier
Expand Down
50 changes: 50 additions & 0 deletions tests/pallas/tpu_sparsecore_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,56 @@ def kernel(x_ref, indices_ref, out_ref):

np.testing.assert_array_equal(kernel(x, indices), x[indices])

@parameterized.product(
keys_dtype=[np.int32, np.float32],
values_dtype=[np.int32, np.float32],
use_mask=[False, True],
descending=[False, True],
)
def test_sort_key_val(self, keys_dtype, values_dtype, use_mask, descending):
if not jtu.is_cloud_tpu_at_least(2025, 12, 2):
self.skipTest("Test requires a newer libtpu")

vec_dim = self.sc_info.num_lanes
keys = np.arange(vec_dim, dtype=keys_dtype)
np.random.shuffle(keys)
keys[3] = keys[1] # Verify sort stability.
values = np.arange(vec_dim, dtype=values_dtype)
np.random.shuffle(values)
mask = np.random.choice([True, False], size=vec_dim) if use_mask else None
maybe_mask_arg = (mask.astype(jnp.int32),) if use_mask else ()

@self.vector_subcore_kernel(out_shape=(keys, values, *maybe_mask_arg))
def kernel(*args):
if use_mask:
mask_ref, *args, o_mask_ref = args
mask = mask_ref[...].astype(jnp.bool)
else:
mask, o_mask_ref = None, None
keys_ref, values_ref, o_keys_ref, o_vals_ref = args
o_keys_ref[...], o_vals_ref[...], *maybe_out_mask = plsc.sort_key_val(
keys_ref[...], values_ref[...], mask=mask, descending=descending)
if use_mask:
[out_mask] = maybe_out_mask
o_mask_ref[...] = out_mask.astype(jnp.int32)

out_keys, out_values, *maybe_out_mask = kernel(
*maybe_mask_arg, keys, values)

keys_arg = keys
if descending:
keys_arg = -keys_arg
if use_mask:
keys_arg = jnp.where(mask, keys_arg, 100)
_, gt_keys = jax.lax.sort_key_val(keys_arg, keys)
_, gt_values = jax.lax.sort_key_val(keys_arg, values)
if use_mask:
[out_mask] = maybe_out_mask
gt_out_mask = jnp.arange(vec_dim) < mask.sum()
np.testing.assert_array_equal(out_mask, gt_out_mask.astype(jnp.int32))
np.testing.assert_array_equal(out_keys, gt_keys)
np.testing.assert_array_equal(out_values, gt_values)

@parameterized.product(dtype=[np.int32, np.float32])
def test_rev_and_sort_desc(self, dtype):
if not jtu.is_cloud_tpu_at_least(2025, 12, 2):
Expand Down
Loading