Skip to content

Commit 87fe4a6

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
Unbreak blockspec caching issue
PiperOrigin-RevId: 829650158
1 parent c308d91 commit 87fe4a6

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

jax/_src/pallas/core.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
from jax._src.state import discharge as state_discharge
4545
from jax._src.state import indexing
4646
from jax._src.state import types as state_types
47-
from jax._src import traceback_util
4847
from jax._src.state.types import TransformedRef
4948
import jax.numpy as jnp
5049

@@ -480,8 +479,11 @@ class BlockSpec:
480479

481480
def __post_init__(self):
482481
if self.index_map is not None:
483-
self.index_map = _IndexMapFunc(
484-
traceback_util.api_boundary(self.index_map, repro_user_func=True))
482+
# TODO(sharadmv): Add this once we have a better way to handle
483+
# index_map equality.
484+
# self.index_map = _IndexMapFunc(
485+
# traceback_util.api_boundary(self.index_map, repro_user_func=True))
486+
self.index_map = _IndexMapFunc(self.index_map)
485487

486488
def to_block_mapping(
487489
self,

0 commit comments

Comments
 (0)