Skip to content

Commit 744f6b4

Browse files
skyejax authors
authored andcommitted
Update xla_client._version and add missing version checks to JAX
PiperOrigin-RevId: 449021408
1 parent bd20f0f commit 744f6b4

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

jax/experimental/compilation_cache/compilation_cache.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,11 @@ def _hash_computation(hash_obj, xla_computation):
113113
hash_obj.update(scrubbed_hlo)
114114

115115
def _hash_compile_options(hash_obj, compile_options_obj):
116-
assert len(dir(compile_options_obj)) == 32, (
116+
if xla_client._version >= 68: # Remove when minimum jaxlib version >= 0.3.11
117+
expected_num_compile_options = 32
118+
else:
119+
expected_num_compile_options = 31
120+
assert len(dir(compile_options_obj)) == expected_num_compile_options, (
117121
f"Unexpected number of CompileOption fields: "
118122
f"{len(dir(compile_options_obj))}. This likely: means that an extra "
119123
f"field was added, and this function needs to be updated.")
@@ -126,7 +130,8 @@ def _hash_compile_options(hash_obj, compile_options_obj):
126130
_hash_bool(hash_obj, compile_options_obj.tuple_arguments)
127131
_hash_int(hash_obj, compile_options_obj.num_replicas)
128132
_hash_int(hash_obj, compile_options_obj.num_partitions)
129-
_hash_int(hash_obj, compile_options_obj.profile_version)
133+
if xla_client._version >= 68: # Remove when minimum jaxlib version >= 0.3.11
134+
_hash_int(hash_obj, compile_options_obj.profile_version)
130135
if compile_options_obj.device_assignment is not None:
131136
hash_obj.update(compile_options_obj.device_assignment.serialize())
132137

0 commit comments

Comments
 (0)