Skip to content

Commit e30b96c

Browse files
author
jax authors
committed
Merge pull request #9201 from LenaMartens:changelist/420794552
PiperOrigin-RevId: 422589737
2 parents 6411f8a + 8ea8576 commit e30b96c

File tree

2 files changed

+144
-67
lines changed

2 files changed

+144
-67
lines changed

jax/experimental/checkify.py

Lines changed: 64 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import enum
1516
from dataclasses import dataclass
1617
from functools import partial
1718
import itertools as it
18-
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar
19+
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, Set, FrozenSet
1920

2021
import numpy as np
2122

@@ -97,14 +98,21 @@ def __init__(self, trace, val):
9798
class CheckifyTrace(core.Trace):
9899
pure = lift = lambda self, val: CheckifyTracer(self, val)
99100

101+
def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
             enabled_errors: FrozenSet['ErrorCategory']) -> None:
  """Initialise the trace and record which error categories are enabled.

  The enabled set is stored on `main` rather than on the trace itself;
  `process_primitive` reads it back as `self.main.enabled_errors` when it
  invokes an error-check rule.

  Args:
    main: the main trace this trace belongs to.
    sublevel: the sublevel of this trace.
    enabled_errors: the error categories whose checks should be emitted.
  """
  self.main, self.level, self.sublevel = main, main.level, sublevel
  main.enabled_errors = enabled_errors
107+
100108
def sublift(self, tracer):
  """Wrap the value carried by `tracer` in a new tracer owned by this trace."""
  carried = tracer.val
  return CheckifyTracer(self, carried)
102110

103111
def process_primitive(self, primitive, tracers, params):
104112
in_vals = [t.val for t in tracers]
105113
rule = error_checks.get(primitive)
106114
if rule:
107-
out, self.main.error = rule(self.main.error, *in_vals, **params) # type: ignore
115+
out, self.main.error = rule(self.main.error, self.main.enabled_errors, *in_vals, **params) # type: ignore
108116
else:
109117
out = primitive.bind(*in_vals, **params)
110118
if primitive.multiple_results:
@@ -166,18 +174,18 @@ def _reduce_any_error(errs, codes):
166174
errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0)
167175
return errs_[-1], codes_[-1]
168176

169-
ErrorCheckRule = Callable
177+
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
170178
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
171179

172-
def checkify_flat(fun: lu.WrappedFun,
                  enabled_errors: FrozenSet['ErrorCategory'], *args):
  """Run a flat function under the checkify transformations.

  Args:
    fun: the wrapped function to run, taking flat positional arguments.
    enabled_errors: the error categories whose checks should be emitted.
    *args: flat arguments passed through to `fun`.

  Returns:
    A pair `((err, code, outvals), msgs)`: the error flag, the error code and
    the list of outputs produced by the transformed call, followed by the
    messages reported by the `checkify_subtrace` auxiliary output.
  """
  traced, msgs_thunk = checkify_subtrace(fun)
  initial_msgs = tuple(init_error.msgs.items())
  traced = checkify_traceable(traced, initial_msgs, enabled_errors)
  err, code, *outvals = traced.call_wrapped(
      init_error.err, init_error.code, *args)
  # The auxiliary output is read only after the wrapped call has run.
  collected_msgs = msgs_thunk()
  return (err, code, outvals), collected_msgs
177185

178186
@lu.transformation
def checkify_traceable(msgs, enabled_errors, err, code, *args):
  """Transformation that runs the wrapped function under a new CheckifyTrace.

  Opens a new main trace of type `CheckifyTrace`, forwarding `enabled_errors`
  to it, and hands `(main, msgs, err, code, *args)` to the function being
  transformed. Whatever that function produces is yielded back unchanged.
  """
  with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main:
    # The argument tuple is built inline so that no other local keeps a
    # reference to `main` once it is deleted below.
    results = yield (main, msgs, err, code, *args), {}
    del main
  yield results
@@ -196,13 +204,13 @@ def checkify_subtrace(main, msgs, err, code, *args):
196204

197205

198206
# TODO take (error_aval, code_aval) instead of error here?
199-
def checkify_jaxpr(jaxpr, error, enabled_errors):
  """Checkify a closed jaxpr by turning it into a function first.

  Args:
    jaxpr: the jaxpr to transform; its `in_avals` give the input avals.
    error: the incoming `Error` value threaded into the checked computation.
    enabled_errors: the error categories whose checks should be emitted.

  Returns:
    Whatever `checkify_fun_to_jaxpr` returns for the wrapped jaxpr function.
  """
  as_fun = core.jaxpr_as_fun(jaxpr)
  wrapped = lu.wrap_init(as_fun)
  return checkify_fun_to_jaxpr(wrapped, error, enabled_errors, jaxpr.in_avals)
202210

203-
def checkify_fun_to_jaxpr(f, error, in_avals):
211+
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
204212
f, msgs = checkify_subtrace(f)
205-
f = checkify_traceable(f, tuple(error.msgs.items()))
213+
f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors)
206214
err_aval = core.raise_to_shaped(core.get_aval(error.err))
207215
code_aval = core.raise_to_shaped(core.get_aval(error.code))
208216
avals_in = [err_aval, code_aval, *in_avals]
@@ -244,20 +252,25 @@ def assert_abstract_eval(pred, code, *, msgs):
244252
def summary() -> str:
  """Return a string summary of the current source info.

  Used to build the location part of the error messages in this module.
  """
  here = source_info_util.current()
  return str(source_info_util.summarize(here))
246254

247-
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Bind `prim` and, if NaN checks are enabled, assert its output has no NaN.

  Args:
    prim: the primitive to bind.
    error: the incoming `Error` value.
    enabled_errors: the enabled error categories; the NaN assertion is only
      added when `ErrorCategory.NAN` is among them.
    *in_vals: positional operands for `prim`.
    **params: parameters for `prim`.

  Returns:
    A pair `(out, error)` with the primitive's output and the (possibly
    updated) error value.
  """
  out = prim.bind(*in_vals, **params)
  if ErrorCategory.NAN in enabled_errors:
    has_nan = jnp.any(jnp.isnan(out))
    no_nans = jnp.logical_not(has_nan)
    msg = f"nan generated by primitive {prim.name} at {summary()}"
    return out, assert_func(error, no_nans, msg)
  # NaN checking disabled: pass the incoming error through untouched.
  return out, error
252262

253-
def gather_error_check(error, operand, start_indices, *,
263+
def gather_error_check(error, enabled_errors, operand, start_indices, *,
254264
dimension_numbers, slice_sizes, unique_indices,
255265
indices_are_sorted, mode, fill_value):
256266
out = lax.gather_p.bind(
257267
operand, start_indices, dimension_numbers=dimension_numbers,
258268
slice_sizes=slice_sizes, unique_indices=unique_indices,
259269
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
260270

271+
if ErrorCategory.OOB not in enabled_errors:
272+
return out, error
273+
261274
# compare to OOB masking logic in lax._gather_translation_rule
262275
dnums = dimension_numbers
263276
operand_dims = np.array(operand.shape)
@@ -270,12 +283,13 @@ def gather_error_check(error, operand, start_indices, *,
270283
return out, assert_func(error, all_inbounds, msg)
271284
error_checks[lax.gather_p] = gather_error_check
272285

273-
def div_error_check(error, enabled_errors, x, y):
  """Check a division for a zero divisor and for NaN results.

  The divide-by-zero assertion is added only when `ErrorCategory.DIV` is
  enabled. The division itself, together with the optional NaN assertion, is
  delegated to `nan_error_check` on `lax.div_p`.

  Returns:
    A pair `(out, error)` as returned by `nan_error_check`.
  """
  checked = error
  if ErrorCategory.DIV in enabled_errors:
    any_zero = jnp.any(jnp.equal(y, 0))
    checked = assert_func(checked, jnp.logical_not(any_zero),
                          f'divided by zero at {summary()}')
  return nan_error_check(lax.div_p, checked, enabled_errors, x, y)
279293
error_checks[lax.div_p] = div_error_check
280294

281295
def scatter_in_bounds(operand, indices, updates, dnums):
@@ -300,17 +314,19 @@ def scatter_in_bounds(operand, indices, updates, dnums):
300314
upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound))
301315
return jnp.logical_and(lower_in_bounds, upper_in_bounds)
302316

303-
def scatter_error_check(prim, error, operand, indices, updates, *,
304-
update_jaxpr, update_consts,
305-
dimension_numbers, indices_are_sorted,
306-
unique_indices, mode):
317+
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
318+
*, update_jaxpr, update_consts, dimension_numbers,
319+
indices_are_sorted, unique_indices, mode):
307320
"""Checks if indices are within bounds and update does not generate NaN."""
308321
out = prim.bind(
309322
operand, indices, updates, update_jaxpr=update_jaxpr,
310323
update_consts=update_consts, dimension_numbers=dimension_numbers,
311324
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
312325
mode=mode)
313326

327+
if ErrorCategory.OOB not in enabled_errors:
328+
return out, error
329+
314330
in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
315331
oob_msg = f'out-of-bounds indexing while updating at {summary()}'
316332
oob_error = assert_func(error, in_bounds, oob_msg)
@@ -324,8 +340,8 @@ def scatter_error_check(prim, error, operand, indices, updates, *,
324340
error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p)
325341
error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p)
326342

327-
def cond_error_check(error, index, *ops, branches, linear):
328-
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches)
343+
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
344+
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) for jxpr in branches)
329345
new_linear = (False, False, *linear)
330346
err, code, *outs = lax.cond_p.bind(
331347
index, error.err, error.code, *ops,
@@ -334,9 +350,9 @@ def cond_error_check(error, index, *ops, branches, linear):
334350
return outs, Error(err, code, new_msgs)
335351
error_checks[lax.cond_p] = cond_error_check
336352

337-
def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
353+
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
338354
consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
339-
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error)
355+
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
340356
new_linear = (False, False, *linear)
341357
new_in_flat = [*consts, error.err, error.code, *carry, *xs]
342358
err, code, *outs = lax.scan_p.bind(
@@ -348,14 +364,14 @@ def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_ca
348364
return outs, Error(err, code, new_msgs)
349365
error_checks[lax.scan_p] = scan_error_check
350366

351-
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors):
  """Checkify a while-loop body together with the following cond evaluation.

  Builds a function that runs the body and then evaluates the cond on the
  body's outputs (discarding the cond result), and checkifies that function
  using the body's input avals.

  Returns:
    Whatever `checkify_fun_to_jaxpr` returns for the combined function.
  """
  cond_f = core.jaxpr_as_fun(cond_jaxpr)
  body_f = core.jaxpr_as_fun(body_jaxpr)

  def new_body_f(*vals):
    carry = body_f(*vals)
    # Evaluate cond purely for its checks: this detects whether the next
    # cond application will error. Its result is intentionally unused.
    cond_f(*carry)
    return carry

  wrapped_body = lu.wrap_init(new_body_f)
  return checkify_fun_to_jaxpr(
      wrapped_body, error, enabled_errors, body_jaxpr.in_avals)
359375

360376
def ignore_errors_jaxpr(jaxpr, error):
361377
"""Constructs a jaxpr which takes two extra args but ignores them."""
@@ -369,13 +385,13 @@ def ignore_errors_jaxpr(jaxpr, error):
369385
jaxpr.outvars, jaxpr.eqns)
370386
return core.ClosedJaxpr(new_jaxpr, consts)
371387

372-
def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
373-
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error)
388+
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
389+
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
374390
checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
375391
# Check if the first cond application will error.
376392
cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)
377393

378-
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
394+
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors)
379395
compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
380396
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
381397
new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
@@ -453,7 +469,10 @@ def add_nan_check(prim):
453469
add_nan_check(lax.max_p)
454470
add_nan_check(lax.min_p)
455471

456-
def assert_discharge_rule(error, enabled_errors, pred, code, *, msgs):
  """Fold an assert into the running error value.

  When `ErrorCategory.ASSERT` is not enabled the incoming error is returned
  unchanged. Otherwise the error flag is OR-ed with `not pred`, the existing
  error code is kept wherever `error.err` is already set (else `code` is
  used), and `msgs` is merged into the error's messages.

  Returns:
    A pair `([], error)`: the empty list of outputs and the resulting error.
  """
  if ErrorCategory.ASSERT in enabled_errors:
    failed = jnp.logical_not(pred)
    out_err = error.err | failed
    out_code = lax.select(error.err, error.code, code)
    merged_msgs = {**error.msgs, **msgs}
    return [], Error(out_err, out_code, merged_msgs)
  return [], error
@@ -462,13 +481,24 @@ def assert_discharge_rule(error, pred, code, *, msgs):
462481

463482
## checkify api
464483

484+
class ErrorCategory(enum.Enum):
  """Categories of checks that can be enabled when checkifying a function."""
  NAN = 1     # a primitive generated a NaN
  OOB = 2     # out-of-bounds indexing (gather/scatter)
  DIV = 3     # division by zero
  ASSERT = 4  # user-added asserts


# Convenience groupings of categories to pass as the set of enabled errors.
float_errors = {ErrorCategory.NAN, ErrorCategory.DIV}
index_errors = {ErrorCategory.OOB}
automatic_errors = float_errors.union(index_errors)
user_asserts = {ErrorCategory.ASSERT}
490+
465491
Out = TypeVar('Out')
466-
def checkify(fun: Callable[..., Out],
             errors: Set[ErrorCategory] = user_asserts
             ) -> Callable[..., Tuple[Error, Out]]:
  """Wrap `fun` so that it returns an `Error` value alongside its output.

  Args:
    fun: the function to add checks to.
    errors: the error categories to enable; defaults to `user_asserts`. The
      collection is copied into a frozenset when `checkify` is called, so
      mutating it afterwards does not affect the returned function.

  Returns:
    A function taking the same arguments as `fun` and returning a pair
    `(error, out)`, where `error` is an `Error` and `out` is `fun`'s output.

  Raises:
    ValueError: if `errors` is empty.
  """
  # Snapshot the enabled categories once. `errors` is usually a mutable set
  # (the default is the module-level `user_asserts`); re-reading it on every
  # call would let later mutation change, or empty out, the enabled checks
  # after the validation below has already passed.
  enabled_errors = frozenset(errors) if errors else frozenset()
  if not enabled_errors:
    raise ValueError('Checkify needs to be called with at least one enabled'
                     ' ErrorCategory, was called with an empty errors set.')

  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    (err, code, out_flat), msgs = checkify_flat(f, enabled_errors, *args_flat)
    out = tree_unflatten(out_tree(), out_flat)
    return Error(err, code, msgs), out
  return checked_fun

0 commit comments

Comments
 (0)