|
33 | 33 | from jax._src.pallas import core as pallas_core |
34 | 34 | from jax._src.pallas.mosaic import core as tpu_core |
35 | 35 | from jax._src.pallas.mosaic import lowering as tc_lowering |
| 36 | +from jax._src.pallas.mosaic import sc_core |
36 | 37 | from jax._src.pallas.mosaic import sc_lowering |
37 | 38 | from jax._src.state import primitives as state_primitives |
38 | 39 | from jax._src.state import types as state_types |
@@ -634,6 +635,72 @@ def _reduce_sum_lowering_rule( |
634 | 635 | _cumsum_lowering_rule(ctx, x, 0, reverse=False), [], [vec_dim - 1]) |
635 | 636 |
|
636 | 637 |
|
| 638 | +masked_sort_p = jax_core.Primitive("masked_sort") |
| 639 | +masked_sort_p.multiple_results = True |
| 640 | + |
| 641 | +@masked_sort_p.def_abstract_eval |
| 642 | +def _masked_sort_abstract_eval(keys, values, *maybe_mask, descending): |
| 643 | + del descending # Unused. |
| 644 | + supported_shape = (sc_core.get_sparse_core_info().num_lanes,) |
| 645 | + if keys.dtype not in (jnp.int32, jnp.float32): |
| 646 | + raise NotImplementedError( |
| 647 | + f"sort_key_val: keys dtype {keys.dtype} should be int32 or float32") |
| 648 | + if keys.shape != supported_shape: |
| 649 | + raise ValueError(f"keys shape {keys.shape} must be {supported_shape}") |
| 650 | + if jnp.dtype(values.dtype).itemsize != 4: |
| 651 | + raise NotImplementedError( |
| 652 | + f"sort_key_val: values dtype {values.dtype} should be 32 bits") |
| 653 | + if values.shape != supported_shape: |
| 654 | + raise ValueError(f"values shape {values.shape} must be {supported_shape}") |
| 655 | + if maybe_mask: |
| 656 | + [mask] = maybe_mask |
| 657 | + if not jnp.issubdtype(mask.dtype, jnp.bool): |
| 658 | + raise TypeError(f"mask dtype {mask.dtype} is not boolean") |
| 659 | + if mask.shape != supported_shape: |
| 660 | + raise ValueError(f"mask shape {mask.shape} must be {supported_shape}") |
| 661 | + return keys, values, *maybe_mask |
| 662 | + |
| 663 | +@sc_lowering.register_lowering_rule(masked_sort_p) |
| 664 | +def _masked_sort_lowering_rule( |
| 665 | + ctx: sc_lowering.LoweringRuleContext, keys, values, *maybe_mask, descending): |
| 666 | + del ctx # Unused. |
| 667 | + if maybe_mask: |
| 668 | + [mask] = maybe_mask |
| 669 | + else: |
| 670 | + mask_type = ir.VectorType.get( |
| 671 | + [sc_core.get_sparse_core_info().num_lanes], |
| 672 | + ir.IntegerType.get_signless(1)) |
| 673 | + mask = arith.constant(mask_type, ir.DenseElementsAttr.get_splat( |
| 674 | + mask_type, ir.BoolAttr.get(True))) |
| 675 | + out_mask, sorted_keys, sorted_values = tpu.sort( |
| 676 | + mask.type, keys.type, values.type, keys, values, mask=mask, |
| 677 | + descending=descending |
| 678 | + ) |
| 679 | + if maybe_mask: |
| 680 | + return sorted_keys, sorted_values, out_mask |
| 681 | + return sorted_keys, sorted_values |
| 682 | + |
| 683 | +def sort_key_val( |
| 684 | + keys: jax.Array, values: jax.Array, *, |
| 685 | + mask: jax.Array | None = None, descending: bool = False |
| 686 | +) -> jax.Array: |
| 687 | + """Sorts keys and values, pushing invalid elements to the last positions. |
| 688 | +
|
| 689 | + Args: |
| 690 | + keys: An array of integers or floats. |
| 691 | + values: An array of values corresponding to the keys. |
| 692 | + mask: An optional array of booleans, which specifies which elements of |
| 693 | + `keys` and `values` are valid. If `None`, all elements are valid. |
| 694 | + descending: Whether to sort in descending order. |
| 695 | +
|
| 696 | + Returns: |
| 697 | + sorted_keys, sorted_values, [output_mask]: The sorted keys and values, and, |
| 698 | + if a mask was given, the corresponding mask for output keys and values. |
| 699 | + """ |
| 700 | + maybe_mask = () if mask is None else (mask,) |
| 701 | + return masked_sort_p.bind(keys, values, *maybe_mask, descending=descending) |
| 702 | + |
| 703 | + |
637 | 704 | parallel_loop_p = jax_core.Primitive("parallel_loop") |
638 | 705 | parallel_loop_p.is_effectful = lambda params: bool(params["jaxpr"].effects) # type: ignore |
639 | 706 | parallel_loop_p.multiple_results = True |
|
0 commit comments