Skip to content

Commit 059342b

Browse files
brianwa84Google-ML-Automation
authored andcommitted
[Pallas:SC] Add plsc.sort_key_val to give access to the mask and descending args.
PiperOrigin-RevId: 839663879
1 parent 1b53ea1 commit 059342b

File tree

4 files changed

+119
-0
lines changed

4 files changed

+119
-0
lines changed

jax/_src/pallas/mosaic/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ pytype_strict_library(
175175
deps = [
176176
":core",
177177
":lowering",
178+
":sc_core",
178179
":sc_lowering",
179180
"//jax",
180181
"//jax/_src:core",

jax/_src/pallas/mosaic/sc_primitives.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from jax._src.pallas import core as pallas_core
3434
from jax._src.pallas.mosaic import core as tpu_core
3535
from jax._src.pallas.mosaic import lowering as tc_lowering
36+
from jax._src.pallas.mosaic import sc_core
3637
from jax._src.pallas.mosaic import sc_lowering
3738
from jax._src.state import primitives as state_primitives
3839
from jax._src.state import types as state_types
@@ -634,6 +635,72 @@ def _reduce_sum_lowering_rule(
634635
_cumsum_lowering_rule(ctx, x, 0, reverse=False), [], [vec_dim - 1])
635636

636637

638+
masked_sort_p = jax_core.Primitive("masked_sort")
639+
masked_sort_p.multiple_results = True
640+
641+
@masked_sort_p.def_abstract_eval
642+
def _masked_sort_abstract_eval(keys, values, *maybe_mask, descending):
643+
del descending # Unused.
644+
supported_shape = (sc_core.get_sparse_core_info().num_lanes,)
645+
if keys.dtype not in (jnp.int32, jnp.float32):
646+
raise NotImplementedError(
647+
f"sort_key_val: keys dtype {keys.dtype} should be int32 or float32")
648+
if keys.shape != supported_shape:
649+
raise ValueError(f"keys shape {keys.shape} must be {supported_shape}")
650+
if jnp.dtype(values.dtype).itemsize != 4:
651+
raise NotImplementedError(
652+
f"sort_key_val: values dtype {values.dtype} should be 32 bits")
653+
if values.shape != supported_shape:
654+
raise ValueError(f"values shape {values.shape} must be {supported_shape}")
655+
if maybe_mask:
656+
[mask] = maybe_mask
657+
if not jnp.issubdtype(mask.dtype, jnp.bool):
658+
raise TypeError(f"mask dtype {mask.dtype} is not boolean")
659+
if mask.shape != supported_shape:
660+
raise ValueError(f"mask shape {mask.shape} must be {supported_shape}")
661+
return keys, values, *maybe_mask
662+
663+
@sc_lowering.register_lowering_rule(masked_sort_p)
664+
def _masked_sort_lowering_rule(
665+
ctx: sc_lowering.LoweringRuleContext, keys, values, *maybe_mask, descending):
666+
del ctx # Unused.
667+
if maybe_mask:
668+
[mask] = maybe_mask
669+
else:
670+
mask_type = ir.VectorType.get(
671+
[sc_core.get_sparse_core_info().num_lanes],
672+
ir.IntegerType.get_signless(1))
673+
mask = arith.constant(mask_type, ir.DenseElementsAttr.get_splat(
674+
mask_type, ir.BoolAttr.get(True)))
675+
out_mask, sorted_keys, sorted_values = tpu.sort(
676+
mask.type, keys.type, values.type, keys, values, mask=mask,
677+
descending=descending
678+
)
679+
if maybe_mask:
680+
return sorted_keys, sorted_values, out_mask
681+
return sorted_keys, sorted_values
682+
683+
def sort_key_val(
684+
keys: jax.Array, values: jax.Array, *,
685+
mask: jax.Array | None = None, descending: bool = False
686+
) -> jax.Array:
687+
"""Sorts keys and values, pushing invalid elements to the last positions.
688+
689+
Args:
690+
keys: An array of integers or floats.
691+
values: An array of values corresponding to the keys.
692+
mask: An optional array of booleans, which specifies which elements of
693+
`keys` and `values` are valid. If `None`, all elements are valid.
694+
descending: Whether to sort in descending order.
695+
696+
Returns:
697+
sorted_keys, sorted_values, [output_mask]: The sorted keys and values, and,
698+
if a mask was given, the corresponding mask for output keys and values.
699+
"""
700+
maybe_mask = () if mask is None else (mask,)
701+
return masked_sort_p.bind(keys, values, *maybe_mask, descending=descending)
702+
703+
637704
parallel_loop_p = jax_core.Primitive("parallel_loop")
638705
parallel_loop_p.is_effectful = lambda params: bool(params["jaxpr"].effects) # type: ignore
639706
parallel_loop_p.multiple_results = True

jax/experimental/pallas/tpu_sc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from jax._src.pallas.mosaic.sc_primitives import PackFormat as PackFormat
3333
from jax._src.pallas.mosaic.sc_primitives import parallel_loop as parallel_loop
3434
from jax._src.pallas.mosaic.sc_primitives import scan_count as scan_count
35+
from jax._src.pallas.mosaic.sc_primitives import sort_key_val as sort_key_val
3536
from jax._src.pallas.mosaic.sc_primitives import store_compressed as store_compressed
3637
from jax._src.pallas.mosaic.sc_primitives import store_scatter as store_scatter
3738
from jax._src.pallas.mosaic.sc_primitives import subcore_barrier as subcore_barrier

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,56 @@ def kernel(x_ref, indices_ref, out_ref):
16541654

16551655
np.testing.assert_array_equal(kernel(x, indices), x[indices])
16561656

1657+
@parameterized.product(
1658+
keys_dtype=[np.int32, np.float32],
1659+
values_dtype=[np.int32, np.float32],
1660+
use_mask=[False, True],
1661+
descending=[False, True],
1662+
)
1663+
def test_sort_key_val(self, keys_dtype, values_dtype, use_mask, descending):
1664+
if not jtu.is_cloud_tpu_at_least(2025, 12, 2):
1665+
self.skipTest("Test requires a newer libtpu")
1666+
1667+
vec_dim = self.sc_info.num_lanes
1668+
keys = np.arange(vec_dim, dtype=keys_dtype)
1669+
np.random.shuffle(keys)
1670+
keys[3] = keys[1] # Verify sort stability.
1671+
values = np.arange(vec_dim, dtype=values_dtype)
1672+
np.random.shuffle(values)
1673+
mask = np.random.choice([True, False], size=vec_dim) if use_mask else None
1674+
maybe_mask_arg = (mask.astype(jnp.int32),) if use_mask else ()
1675+
1676+
@self.vector_subcore_kernel(out_shape=(keys, values, *maybe_mask_arg))
1677+
def kernel(*args):
1678+
if use_mask:
1679+
mask_ref, *args, o_mask_ref = args
1680+
mask = mask_ref[...].astype(jnp.bool)
1681+
else:
1682+
mask, o_mask_ref = None, None
1683+
keys_ref, values_ref, o_keys_ref, o_vals_ref = args
1684+
o_keys_ref[...], o_vals_ref[...], *maybe_out_mask = plsc.sort_key_val(
1685+
keys_ref[...], values_ref[...], mask=mask, descending=descending)
1686+
if use_mask:
1687+
[out_mask] = maybe_out_mask
1688+
o_mask_ref[...] = out_mask.astype(jnp.int32)
1689+
1690+
out_keys, out_values, *maybe_out_mask = kernel(
1691+
*maybe_mask_arg, keys, values)
1692+
1693+
keys_arg = keys
1694+
if descending:
1695+
keys_arg = -keys_arg
1696+
if use_mask:
1697+
keys_arg = jnp.where(mask, keys_arg, 100)
1698+
_, gt_keys = jax.lax.sort_key_val(keys_arg, keys)
1699+
_, gt_values = jax.lax.sort_key_val(keys_arg, values)
1700+
if use_mask:
1701+
[out_mask] = maybe_out_mask
1702+
gt_out_mask = jnp.arange(vec_dim) < mask.sum()
1703+
np.testing.assert_array_equal(out_mask, gt_out_mask.astype(jnp.int32))
1704+
np.testing.assert_array_equal(out_keys, gt_keys)
1705+
np.testing.assert_array_equal(out_values, gt_values)
1706+
16571707
@parameterized.product(dtype=[np.int32, np.float32])
16581708
def test_rev_and_sort_desc(self, dtype):
16591709
if not jtu.is_cloud_tpu_at_least(2025, 12, 2):

0 commit comments

Comments
 (0)