File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change 4444from jax ._src .state import discharge as state_discharge
4545from jax ._src .state import indexing
4646from jax ._src .state import types as state_types
47- from jax ._src import traceback_util
4847from jax ._src .state .types import TransformedRef
4948import jax .numpy as jnp
5049
@@ -480,8 +479,11 @@ class BlockSpec:
480479
481480 def __post_init__ (self ):
482481 if self .index_map is not None :
483- self .index_map = _IndexMapFunc (
484- traceback_util .api_boundary (self .index_map , repro_user_func = True ))
482+ # TODO(sharadmv): Add this once we have a better way to handle
483+ # index_map equality.
484+ # self.index_map = _IndexMapFunc(
485+ # traceback_util.api_boundary(self.index_map, repro_user_func=True))
486+ self .index_map = _IndexMapFunc (self .index_map )
485487
486488 def to_block_mapping (
487489 self ,
You can’t perform that action at this time.
0 commit comments