1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import enum
1516from dataclasses import dataclass
1617from functools import partial
1718import itertools as it
18- from typing import Union , Optional , Callable , Dict , Tuple , TypeVar
19+ from typing import Union , Optional , Callable , Dict , Tuple , TypeVar , Set , FrozenSet
1920
2021import numpy as np
2122
@@ -97,14 +98,21 @@ def __init__(self, trace, val):
9798class CheckifyTrace (core .Trace ):
9899 pure = lift = lambda self , val : CheckifyTracer (self , val )
99100
101+ def __init__ (self , main : core .MainTrace , sublevel : core .Sublevel ,
102+ enabled_errors : FrozenSet ['ErrorCategory' ]) -> None :
103+ self .main = main
104+ self .level = main .level
105+ self .sublevel = sublevel
106+ self .main .enabled_errors = enabled_errors
107+
100108 def sublift (self , tracer ):
101109 return CheckifyTracer (self , tracer .val )
102110
103111 def process_primitive (self , primitive , tracers , params ):
104112 in_vals = [t .val for t in tracers ]
105113 rule = error_checks .get (primitive )
106114 if rule :
107- out , self .main .error = rule (self .main .error , * in_vals , ** params ) # type: ignore
115+ out , self .main .error = rule (self .main .error , self . main . enabled_errors , * in_vals , ** params ) # type: ignore
108116 else :
109117 out = primitive .bind (* in_vals , ** params )
110118 if primitive .multiple_results :
@@ -166,18 +174,18 @@ def _reduce_any_error(errs, codes):
166174 errs_ , codes_ = lax .sort_key_val (errs , codes , dimension = 0 )
167175 return errs_ [- 1 ], codes_ [- 1 ]
168176
169- ErrorCheckRule = Callable
177+ ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
170178error_checks : Dict [core .Primitive , ErrorCheckRule ] = {}
171179
172- def checkify_flat (fun : lu .WrappedFun , * args ):
180+ def checkify_flat (fun : lu .WrappedFun , enabled_errors : FrozenSet [ 'ErrorCategory' ], * args ):
173181 fun , msgs = checkify_subtrace (fun )
174- fun = checkify_traceable (fun , tuple (init_error .msgs .items ()))
182+ fun = checkify_traceable (fun , tuple (init_error .msgs .items ()), enabled_errors )
175183 err , code , * outvals = fun .call_wrapped (init_error .err , init_error .code , * args )
176184 return (err , code , outvals ), msgs ()
177185
178186@lu .transformation
179- def checkify_traceable (msgs , err , code , * args ):
180- with core .new_main (CheckifyTrace ) as main :
187+ def checkify_traceable (msgs , enabled_errors , err , code , * args ):
188+ with core .new_main (CheckifyTrace , enabled_errors = enabled_errors ) as main :
181189 outs = yield (main , msgs , err , code , * args ), {}
182190 del main
183191 yield outs
@@ -196,13 +204,13 @@ def checkify_subtrace(main, msgs, err, code, *args):
196204
197205
198206# TODO take (error_aval, code_aval) instead of error here?
199- def checkify_jaxpr (jaxpr , error ):
207+ def checkify_jaxpr (jaxpr , error , enabled_errors ):
200208 f = lu .wrap_init (core .jaxpr_as_fun (jaxpr ))
201- return checkify_fun_to_jaxpr (f , error , jaxpr .in_avals )
209+ return checkify_fun_to_jaxpr (f , error , enabled_errors , jaxpr .in_avals )
202210
203- def checkify_fun_to_jaxpr (f , error , in_avals ):
211+ def checkify_fun_to_jaxpr (f , error , enabled_errors , in_avals ):
204212 f , msgs = checkify_subtrace (f )
205- f = checkify_traceable (f , tuple (error .msgs .items ()))
213+ f = checkify_traceable (f , tuple (error .msgs .items ()), enabled_errors )
206214 err_aval = core .raise_to_shaped (core .get_aval (error .err ))
207215 code_aval = core .raise_to_shaped (core .get_aval (error .code ))
208216 avals_in = [err_aval , code_aval , * in_avals ]
@@ -244,20 +252,25 @@ def assert_abstract_eval(pred, code, *, msgs):
244252def summary () -> str :
245253 return str (source_info_util .summarize (source_info_util .current ()))
246254
247- def nan_error_check (prim , error , * in_vals , ** params ):
255+ def nan_error_check (prim , error , enabled_errors , * in_vals , ** params ):
248256 out = prim .bind (* in_vals , ** params )
257+ if ErrorCategory .NAN not in enabled_errors :
258+ return out , error
249259 no_nans = jnp .logical_not (jnp .any (jnp .isnan (out )))
250260 msg = f"nan generated by primitive { prim .name } at { summary ()} "
251261 return out , assert_func (error , no_nans , msg )
252262
253- def gather_error_check (error , operand , start_indices , * ,
263+ def gather_error_check (error , enabled_errors , operand , start_indices , * ,
254264 dimension_numbers , slice_sizes , unique_indices ,
255265 indices_are_sorted , mode , fill_value ):
256266 out = lax .gather_p .bind (
257267 operand , start_indices , dimension_numbers = dimension_numbers ,
258268 slice_sizes = slice_sizes , unique_indices = unique_indices ,
259269 indices_are_sorted = indices_are_sorted , mode = mode , fill_value = fill_value )
260270
271+ if ErrorCategory .OOB not in enabled_errors :
272+ return out , error
273+
261274 # compare to OOB masking logic in lax._gather_translation_rule
262275 dnums = dimension_numbers
263276 operand_dims = np .array (operand .shape )
@@ -270,12 +283,13 @@ def gather_error_check(error, operand, start_indices, *,
270283 return out , assert_func (error , all_inbounds , msg )
271284error_checks [lax .gather_p ] = gather_error_check
272285
273- def div_error_check (error , x , y ):
286+ def div_error_check (error , enabled_errors , x , y ):
274287 """Checks for division by zero and NaN."""
275- all_nonzero = jnp .logical_not (jnp .any (jnp .equal (y , 0 )))
276- msg = f'divided by zero at { summary ()} '
277- div_by_zero_err = assert_func (error , all_nonzero , msg )
278- return nan_error_check (lax .div_p , div_by_zero_err , x , y )
288+ if ErrorCategory .DIV in enabled_errors :
289+ all_nonzero = jnp .logical_not (jnp .any (jnp .equal (y , 0 )))
290+ msg = f'divided by zero at { summary ()} '
291+ error = assert_func (error , all_nonzero , msg )
292+ return nan_error_check (lax .div_p , error , enabled_errors , x , y )
279293error_checks [lax .div_p ] = div_error_check
280294
281295def scatter_in_bounds (operand , indices , updates , dnums ):
@@ -300,17 +314,19 @@ def scatter_in_bounds(operand, indices, updates, dnums):
300314 upper_in_bounds = jnp .all (jnp .less_equal (indices , upper_bound ))
301315 return jnp .logical_and (lower_in_bounds , upper_in_bounds )
302316
303- def scatter_error_check (prim , error , operand , indices , updates , * ,
304- update_jaxpr , update_consts ,
305- dimension_numbers , indices_are_sorted ,
306- unique_indices , mode ):
317+ def scatter_error_check (prim , error , enabled_errors , operand , indices , updates ,
318+ * , update_jaxpr , update_consts , dimension_numbers ,
319+ indices_are_sorted , unique_indices , mode ):
307320 """Checks if indices are within bounds and update does not generate NaN."""
308321 out = prim .bind (
309322 operand , indices , updates , update_jaxpr = update_jaxpr ,
310323 update_consts = update_consts , dimension_numbers = dimension_numbers ,
311324 indices_are_sorted = indices_are_sorted , unique_indices = unique_indices ,
312325 mode = mode )
313326
327+ if ErrorCategory .OOB not in enabled_errors :
328+ return out , error
329+
314330 in_bounds = scatter_in_bounds (operand , indices , updates , dimension_numbers )
315331 oob_msg = f'out-of-bounds indexing while updating at { summary ()} '
316332 oob_error = assert_func (error , in_bounds , oob_msg )
@@ -324,8 +340,8 @@ def scatter_error_check(prim, error, operand, indices, updates, *,
324340error_checks [lax .scatter_min_p ] = partial (scatter_error_check , lax .scatter_min_p )
325341error_checks [lax .scatter_max_p ] = partial (scatter_error_check , lax .scatter_max_p )
326342
327- def cond_error_check (error , index , * ops , branches , linear ):
328- new_branches , msgs_ = unzip2 (checkify_jaxpr (jxpr , error ) for jxpr in branches )
343+ def cond_error_check (error , enabled_errors , index , * ops , branches , linear ):
344+ new_branches , msgs_ = unzip2 (checkify_jaxpr (jxpr , error , enabled_errors ) for jxpr in branches )
329345 new_linear = (False , False , * linear )
330346 err , code , * outs = lax .cond_p .bind (
331347 index , error .err , error .code , * ops ,
@@ -334,9 +350,9 @@ def cond_error_check(error, index, *ops, branches, linear):
334350 return outs , Error (err , code , new_msgs )
335351error_checks [lax .cond_p ] = cond_error_check
336352
337- def scan_error_check (error , * in_flat , reverse , length , jaxpr , num_consts , num_carry , linear , unroll ):
353+ def scan_error_check (error , enabled_errors , * in_flat , reverse , length , jaxpr , num_consts , num_carry , linear , unroll ):
338354 consts , carry , xs = split_list (in_flat , [num_consts , num_carry ])
339- checked_jaxpr , msgs_ = checkify_jaxpr (jaxpr , error )
355+ checked_jaxpr , msgs_ = checkify_jaxpr (jaxpr , error , enabled_errors )
340356 new_linear = (False , False , * linear )
341357 new_in_flat = [* consts , error .err , error .code , * carry , * xs ]
342358 err , code , * outs = lax .scan_p .bind (
@@ -348,14 +364,14 @@ def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_ca
348364 return outs , Error (err , code , new_msgs )
349365error_checks [lax .scan_p ] = scan_error_check
350366
351- def checkify_while_body_jaxpr (cond_jaxpr , body_jaxpr , error ):
367+ def checkify_while_body_jaxpr (cond_jaxpr , body_jaxpr , error , enabled_errors ):
352368 cond_f = core .jaxpr_as_fun (cond_jaxpr )
353369 body_f = core .jaxpr_as_fun (body_jaxpr )
354370 def new_body_f (* vals ):
355371 out = body_f (* vals )
356372 _ = cond_f (* out ) # this checks if the next cond application will error
357373 return out
358- return checkify_fun_to_jaxpr (lu .wrap_init (new_body_f ), error , body_jaxpr .in_avals )
374+ return checkify_fun_to_jaxpr (lu .wrap_init (new_body_f ), error , enabled_errors , body_jaxpr .in_avals )
359375
360376def ignore_errors_jaxpr (jaxpr , error ):
361377 """Constructs a jaxpr which takes two extra args but ignores them."""
@@ -369,13 +385,13 @@ def ignore_errors_jaxpr(jaxpr, error):
369385 jaxpr .outvars , jaxpr .eqns )
370386 return core .ClosedJaxpr (new_jaxpr , consts )
371387
372- def while_loop_error_check (error , * in_flat , cond_nconsts , cond_jaxpr , body_nconsts , body_jaxpr ):
373- checked_cond_jaxpr , msgs_cond = checkify_jaxpr (cond_jaxpr , error )
388+ def while_loop_error_check (error , enabled_errors , * in_flat , cond_nconsts , cond_jaxpr , body_nconsts , body_jaxpr ):
389+ checked_cond_jaxpr , msgs_cond = checkify_jaxpr (cond_jaxpr , error , enabled_errors )
374390 checked_cond_fun = core .jaxpr_as_fun (checked_cond_jaxpr )
375391 # Check if the first cond application will error.
376392 cond_err , cond_code , _ = checked_cond_fun (error .err , error .code , * in_flat )
377393
378- checked_body_jaxpr , msgs_body = checkify_while_body_jaxpr (cond_jaxpr , body_jaxpr , error )
394+ checked_body_jaxpr , msgs_body = checkify_while_body_jaxpr (cond_jaxpr , body_jaxpr , error , enabled_errors )
379395 compat_cond_jaxpr = ignore_errors_jaxpr (cond_jaxpr , error )
380396 c_consts , b_consts , carry = split_list (in_flat , [cond_nconsts , body_nconsts ])
381397 new_in_flat = [* c_consts , * b_consts , cond_err , cond_code , * carry ]
@@ -453,7 +469,10 @@ def add_nan_check(prim):
453469add_nan_check (lax .max_p )
454470add_nan_check (lax .min_p )
455471
456- def assert_discharge_rule (error , pred , code , * , msgs ):
472+ def assert_discharge_rule (error , enabled_errors , pred , code , * , msgs ):
473+ if ErrorCategory .ASSERT not in enabled_errors :
474+ return [], error
475+
457476 out_err = error .err | jnp .logical_not (pred )
458477 out_code = lax .select (error .err , error .code , code )
459478 return [], Error (out_err , out_code , {** error .msgs , ** msgs })
@@ -462,13 +481,24 @@ def assert_discharge_rule(error, pred, code, *, msgs):
462481
463482## checkify api
464483
484+ ErrorCategory = enum .Enum ('ErrorCategory' , ['NAN' , 'OOB' , 'DIV' , 'ASSERT' ])
485+
486+ float_errors = {ErrorCategory .NAN , ErrorCategory .DIV }
487+ index_errors = {ErrorCategory .OOB }
488+ automatic_errors = float_errors | index_errors
489+ user_asserts = {ErrorCategory .ASSERT }
490+
465491Out = TypeVar ('Out' )
466- def checkify (fun : Callable [..., Out ]) -> Callable [..., Tuple [Error , Out ]]:
492+ def checkify (fun : Callable [..., Out ], errors : Set [ErrorCategory ] = user_asserts ) -> Callable [..., Tuple [Error , Out ]]:
493+ if not errors :
494+ raise ValueError ('Checkify needs to be called with at least one enabled'
495+ ' ErrorCategory, was called with an empty errors set.' )
496+
467497 @traceback_util .api_boundary
468498 def checked_fun (* args , ** kwargs ):
469499 args_flat , in_tree = tree_flatten ((args , kwargs ))
470500 f , out_tree = flatten_fun (lu .wrap_init (fun ), in_tree )
471- (err , code , out_flat ), msgs = checkify_flat (f , * args_flat )
501+ (err , code , out_flat ), msgs = checkify_flat (f , frozenset ( errors ), * args_flat )
472502 out = tree_unflatten (out_tree (), out_flat )
473503 return Error (err , code , msgs ), out
474504 return checked_fun
0 commit comments