Skip to content

Commit aa18a98

Browse files
sharadmv authored and Google-ML-Automation committed
Allow DCE side-effecting custom calls
PiperOrigin-RevId: 829639173
1 parent 1fb2976 commit aa18a98

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 34 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -133,6 +133,15 @@ def to_json(self) -> bytes:
133133
).encode("ascii")
134134

135135

136+
class TpuSideEffectType(enum.Enum):
137+
# No side effects, can be deduplicated / removed if unused.
138+
PURE = "pure"
139+
# Cannot be deduplicated, but can be removed if unused.
140+
DATAFLOW_SIDE_EFFECTING = "dataflow_side_effecting"
141+
# Cannot be deduplicated or removed.
142+
SIDE_EFFECTING = "side_effecting"
143+
144+
136145
@dataclasses.dataclass(frozen=True)
137146
class CustomCallBackendConfig:
138147
"""Represents an unserialized backend config for custom calls."""
@@ -304,7 +313,7 @@ def _tpu_custom_call_lowering(
304313
ctx: mlir.LoweringRuleContext,
305314
*in_nodes, # pylint: disable=missing-function-docstring
306315
config: CustomCallBackendConfig,
307-
has_side_effects: bool,
316+
has_side_effects: TpuSideEffectType,
308317
kernel_name: str | None,
309318
out_avals: Any,
310319
input_output_aliases: tuple[tuple[int, int], ...],
@@ -340,24 +349,27 @@ def _tpu_custom_call_lowering(
340349
# information.
341350
if kernel_name is not None:
342351
extra_attributes = dict(kernel_name=ir.StringAttr.get(kernel_name))
343-
has_side_effects = has_side_effects if has_side_effects is not None else False
344352
call = mlir.custom_call(
345353
"tpu_custom_call",
346354
result_types=result_types,
347355
operands=in_nodes,
348356
backend_config=config.to_json(),
349357
api_version=1,
350-
has_side_effect=has_side_effects,
358+
has_side_effect=has_side_effects != TpuSideEffectType.PURE,
351359
operand_output_aliases=dict(input_output_aliases),
352360
operand_layouts=_avals_to_layouts(ctx.avals_in),
353361
result_layouts=_avals_to_layouts(ctx.avals_out),
354362
result_shapes=result_shapes,
355363
extra_attributes=extra_attributes,
356364
)
365+
metadata_dict = {}
357366
if metadata is not None:
358-
call.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
359-
dict(kernel_metadata=ir.StringAttr.get(json.dumps(metadata)))
360-
)
367+
metadata_dict["kernel_metadata"] = ir.StringAttr.get(json.dumps(metadata))
368+
assert isinstance(has_side_effects, TpuSideEffectType)
369+
if has_side_effects == TpuSideEffectType.DATAFLOW_SIDE_EFFECTING:
370+
metadata_dict["xla_allow_dce_side_effecting_op"] = ir.StringAttr.get("true")
371+
if metadata_dict:
372+
call.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(metadata_dict)
361373
return call.results
362374

363375

@@ -643,7 +655,7 @@ def lower_module_to_custom_call(
643655
input_output_aliases: tuple[tuple[int, int], ...],
644656
internal_scratch_in_bytes: int | None,
645657
collective_id: int | None,
646-
has_side_effects: bool,
658+
has_side_effects: bool | TpuSideEffectType,
647659
serialization_format: int | None,
648660
output_memory_spaces: tuple[MemorySpace | None, ...] | None,
649661
disable_bounds_checks: bool = False,
@@ -653,6 +665,12 @@ def lower_module_to_custom_call(
653665
allow_collective_id_without_custom_barrier: bool = False,
654666
shape_invariant_numerics: bool = False,
655667
) -> Sequence[ir.Value]:
668+
if isinstance(has_side_effects, bool):
669+
has_side_effects = (
670+
TpuSideEffectType.PURE
671+
if not has_side_effects
672+
else TpuSideEffectType.DATAFLOW_SIDE_EFFECTING
673+
)
656674
config = _lower_to_custom_call_config(
657675
module,
658676
vmem_limit_bytes=vmem_limit_bytes,
@@ -694,7 +712,7 @@ def as_tpu_kernel(
694712
input_output_aliases: tuple[tuple[int, int], ...] = (),
695713
internal_scratch_in_bytes: int | None = None,
696714
collective_id: int | None = None,
697-
has_side_effects: bool = False,
715+
has_side_effects: TpuSideEffectType = TpuSideEffectType.PURE,
698716
serialization_format: int | None = 1,
699717
output_memory_spaces: tuple[MemorySpace | None, ...] | None = None,
700718
disable_bounds_checks: bool = False,
@@ -738,7 +756,7 @@ def lowered_as_tpu_kernel(
738756
needs_hlo_passes: bool = False,
739757
needs_layout_passes: bool = False,
740758
has_communication: bool = False,
741-
has_side_effects: bool = False,
759+
has_side_effects: bool | TpuSideEffectType = False,
742760
has_custom_barrier: bool = False,
743761
kernel_name: str | None = None,
744762
vmem_limit_bytes: int | None = None,
@@ -755,6 +773,12 @@ def lowered_as_tpu_kernel(
755773
lowered_module_asm = lowered_module.operation.get_asm(
756774
binary=True, enable_debug_info=True
757775
)
776+
if isinstance(has_side_effects, bool):
777+
has_side_effects = (
778+
TpuSideEffectType.PURE
779+
if not has_side_effects
780+
else TpuSideEffectType.DATAFLOW_SIDE_EFFECTING
781+
)
758782
config = _lowered_to_custom_call_config(
759783
lowered_module_asm,
760784
vmem_limit_bytes=vmem_limit_bytes,
@@ -784,7 +808,7 @@ def lowered_as_tpu_kernel(
784808

785809
def _as_jax_callable(
786810
config: CustomCallBackendConfig,
787-
has_side_effects: bool,
811+
has_side_effects: TpuSideEffectType,
788812
out_type: Any,
789813
*,
790814
kernel_name: str | None,

0 commit comments

Comments
 (0)