@@ -113,7 +113,11 @@ def _hash_computation(hash_obj, xla_computation):
113113 hash_obj .update (scrubbed_hlo )
114114
115115def _hash_compile_options (hash_obj , compile_options_obj ):
116- assert len (dir (compile_options_obj )) == 32 , (
116+ if xla_client ._version >= 68 : # Remove when minimum jaxlib version >= 0.3.11
117+ expected_num_compile_options = 32
118+ else :
119+ expected_num_compile_options = 31
120+ assert len (dir (compile_options_obj )) == expected_num_compile_options , (
117121 f"Unexpected number of CompileOption fields: "
118122 f"{ len (dir (compile_options_obj ))} . This likely: means that an extra "
119123 f"field was added, and this function needs to be updated." )
@@ -126,7 +130,8 @@ def _hash_compile_options(hash_obj, compile_options_obj):
126130 _hash_bool (hash_obj , compile_options_obj .tuple_arguments )
127131 _hash_int (hash_obj , compile_options_obj .num_replicas )
128132 _hash_int (hash_obj , compile_options_obj .num_partitions )
129- _hash_int (hash_obj , compile_options_obj .profile_version )
133+ if xla_client ._version >= 68 : # Remove when minimum jaxlib version >= 0.3.11
134+ _hash_int (hash_obj , compile_options_obj .profile_version )
130135 if compile_options_obj .device_assignment is not None :
131136 hash_obj .update (compile_options_obj .device_assignment .serialize ())
132137
0 commit comments