@@ -133,6 +133,15 @@ def to_json(self) -> bytes:
133133 ).encode ("ascii" )
134134
135135
class TpuSideEffectType(enum.Enum):
  """Side-effect classification for a TPU custom call.

  The members are ordered from least to most restrictive with respect to
  which of two transformations remain permitted: deduplicating identical
  calls, and removing a call whose results are unused.
  """

  # Free of side effects: may be deduplicated, and may be removed when unused.
  PURE = "pure"
  # Must not be deduplicated, but may still be removed when unused.
  DATAFLOW_SIDE_EFFECTING = "dataflow_side_effecting"
  # Must be neither deduplicated nor removed.
  SIDE_EFFECTING = "side_effecting"
143+
144+
136145@dataclasses .dataclass (frozen = True )
137146class CustomCallBackendConfig :
138147 """Represents an unserialized backend config for custom calls."""
@@ -304,7 +313,7 @@ def _tpu_custom_call_lowering(
304313 ctx : mlir .LoweringRuleContext ,
305314 * in_nodes , # pylint: disable=missing-function-docstring
306315 config : CustomCallBackendConfig ,
307- has_side_effects : bool ,
316+ has_side_effects : TpuSideEffectType ,
308317 kernel_name : str | None ,
309318 out_avals : Any ,
310319 input_output_aliases : tuple [tuple [int , int ], ...],
@@ -340,24 +349,27 @@ def _tpu_custom_call_lowering(
340349 # information.
341350 if kernel_name is not None :
342351 extra_attributes = dict (kernel_name = ir .StringAttr .get (kernel_name ))
343- has_side_effects = has_side_effects if has_side_effects is not None else False
344352 call = mlir .custom_call (
345353 "tpu_custom_call" ,
346354 result_types = result_types ,
347355 operands = in_nodes ,
348356 backend_config = config .to_json (),
349357 api_version = 1 ,
350- has_side_effect = has_side_effects ,
358+ has_side_effect = has_side_effects != TpuSideEffectType . PURE ,
351359 operand_output_aliases = dict (input_output_aliases ),
352360 operand_layouts = _avals_to_layouts (ctx .avals_in ),
353361 result_layouts = _avals_to_layouts (ctx .avals_out ),
354362 result_shapes = result_shapes ,
355363 extra_attributes = extra_attributes ,
356364 )
365+ metadata_dict = {}
357366 if metadata is not None :
358- call .attributes ["mhlo.frontend_attributes" ] = ir .DictAttr .get (
359- dict (kernel_metadata = ir .StringAttr .get (json .dumps (metadata )))
360- )
367+ metadata_dict ["kernel_metadata" ] = ir .StringAttr .get (json .dumps (metadata ))
368+ assert isinstance (has_side_effects , TpuSideEffectType )
369+ if has_side_effects == TpuSideEffectType .DATAFLOW_SIDE_EFFECTING :
370+ metadata_dict ["xla_allow_dce_side_effecting_op" ] = ir .StringAttr .get ("true" )
371+ if metadata_dict :
372+ call .attributes ["mhlo.frontend_attributes" ] = ir .DictAttr .get (metadata_dict )
361373 return call .results
362374
363375
@@ -643,7 +655,7 @@ def lower_module_to_custom_call(
643655 input_output_aliases : tuple [tuple [int , int ], ...],
644656 internal_scratch_in_bytes : int | None ,
645657 collective_id : int | None ,
646- has_side_effects : bool ,
658+ has_side_effects : bool | TpuSideEffectType ,
647659 serialization_format : int | None ,
648660 output_memory_spaces : tuple [MemorySpace | None , ...] | None ,
649661 disable_bounds_checks : bool = False ,
@@ -653,6 +665,12 @@ def lower_module_to_custom_call(
653665 allow_collective_id_without_custom_barrier : bool = False ,
654666 shape_invariant_numerics : bool = False ,
655667) -> Sequence [ir .Value ]:
668+ if isinstance (has_side_effects , bool ):
669+ has_side_effects = (
670+ TpuSideEffectType .PURE
671+ if not has_side_effects
672+ else TpuSideEffectType .DATAFLOW_SIDE_EFFECTING
673+ )
656674 config = _lower_to_custom_call_config (
657675 module ,
658676 vmem_limit_bytes = vmem_limit_bytes ,
@@ -694,7 +712,7 @@ def as_tpu_kernel(
694712 input_output_aliases : tuple [tuple [int , int ], ...] = (),
695713 internal_scratch_in_bytes : int | None = None ,
696714 collective_id : int | None = None ,
697- has_side_effects : bool = False ,
715+ has_side_effects : TpuSideEffectType = TpuSideEffectType . PURE ,
698716 serialization_format : int | None = 1 ,
699717 output_memory_spaces : tuple [MemorySpace | None , ...] | None = None ,
700718 disable_bounds_checks : bool = False ,
@@ -738,7 +756,7 @@ def lowered_as_tpu_kernel(
738756 needs_hlo_passes : bool = False ,
739757 needs_layout_passes : bool = False ,
740758 has_communication : bool = False ,
741- has_side_effects : bool = False ,
759+ has_side_effects : bool | TpuSideEffectType = False ,
742760 has_custom_barrier : bool = False ,
743761 kernel_name : str | None = None ,
744762 vmem_limit_bytes : int | None = None ,
@@ -755,6 +773,12 @@ def lowered_as_tpu_kernel(
755773 lowered_module_asm = lowered_module .operation .get_asm (
756774 binary = True , enable_debug_info = True
757775 )
776+ if isinstance (has_side_effects , bool ):
777+ has_side_effects = (
778+ TpuSideEffectType .PURE
779+ if not has_side_effects
780+ else TpuSideEffectType .DATAFLOW_SIDE_EFFECTING
781+ )
758782 config = _lowered_to_custom_call_config (
759783 lowered_module_asm ,
760784 vmem_limit_bytes = vmem_limit_bytes ,
@@ -784,7 +808,7 @@ def lowered_as_tpu_kernel(
784808
785809def _as_jax_callable (
786810 config : CustomCallBackendConfig ,
787- has_side_effects : bool ,
811+ has_side_effects : TpuSideEffectType ,
788812 out_type : Any ,
789813 * ,
790814 kernel_name : str | None ,
0 commit comments