Skip to content

Commit 4c423c3

Browse files
hawkinspjax authors
authored andcommitted
Speed up check_jaxpr().
(check_jaxpr() is only used when debugging.) Don't eagerly pretty print jaxprs: only do so if we are going to raise an error. Don't eagerly form error messages. Delete typecheck_assert. PiperOrigin-RevId: 422594126
1 parent e30b96c commit 4c423c3

File tree

3 files changed

+93
-84
lines changed

3 files changed

+93
-84
lines changed

jax/_src/lax/control_flow.py

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,13 @@ def _abstractify(x):
121121
return raise_to_shaped(core.get_aval(x))
122122

123123
def _typecheck_param(prim, param, name, msg_required, pred):
124-
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
125-
f'{msg_required} required:')
126-
param_str = str(param)
127-
sep = os.linesep if os.linesep in param_str else ' '
128-
msg = sep.join([msg, param_str])
129-
core.typecheck_assert(pred, msg)
124+
if not pred:
125+
msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
126+
f'{msg_required} required:')
127+
param_str = str(param)
128+
sep = os.linesep if os.linesep in param_str else ' '
129+
msg = sep.join([msg, param_str])
130+
raise core.JaxprTypeError(msg)
130131

131132

132133
### fori_loop and while_loop
@@ -1340,47 +1341,45 @@ def _cond_typecheck(*avals, branches, linear):
13401341
tc(linear, 'linear', 'tuple of bool',
13411342
type(linear) is tuple and all(type(x) is bool for x in linear))
13421343

1343-
core.typecheck_assert(
1344-
len(branches) > 0,
1345-
'cond requires at least one branch function')
1346-
core.typecheck_assert(
1347-
len(linear) + 1 == len(avals),
1348-
f'cond given {len(linear)} linear flags for '
1349-
f'{len(avals) - 1} non-predicate operands')
1344+
if len(branches) == 0:
1345+
raise core.JaxprTypeError('cond requires at least one branch function')
1346+
if len(linear) + 1 != len(avals):
1347+
raise core.JaxprTypeError(f'cond given {len(linear)} linear flags for '
1348+
f'{len(avals) - 1} non-predicate operands')
13501349

13511350
jaxpr0 = branches[0]
13521351
jaxpr0_in_avals_str = _avals_short(jaxpr0.in_avals)
13531352
jaxpr0_out_avals_str = _avals_short(jaxpr0.out_avals)
13541353

13551354
for i, jaxpr in enumerate(branches[1:]):
1356-
core.typecheck_assert(
1357-
len(jaxpr0.in_avals) == len(jaxpr.in_avals),
1355+
if len(jaxpr0.in_avals) != len(jaxpr.in_avals):
1356+
raise core.JaxprTypeError(
13581357
f'cond branch 0 takes {len(jaxpr0.in_avals)} inputs, '
13591358
f'branch {i+1} takes {len(jaxpr.in_avals)}')
1360-
core.typecheck_assert(
1361-
len(jaxpr0.out_avals) == len(jaxpr.out_avals),
1359+
if len(jaxpr0.out_avals) != len(jaxpr.out_avals):
1360+
raise core.JaxprTypeError(
13621361
f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
13631362
f'branch {i+1} outputs {len(jaxpr.out_avals)}')
1364-
core.typecheck_assert(
1365-
all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)),
1363+
if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
1364+
raise core.JaxprTypeError(
13661365
f'cond branches 0 and {i+1} have mismatching input types: '
13671366
f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
1368-
core.typecheck_assert(
1369-
all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)),
1367+
if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
1368+
raise core.JaxprTypeError(
13701369
f'cond branches 0 and {i+1} have mismatching output types: '
13711370
f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')
13721371

1373-
core.typecheck_assert(
1374-
len(avals) == 1 + len(jaxpr0.in_avals),
1372+
if len(avals) != 1 + len(jaxpr0.in_avals):
1373+
raise core.JaxprTypeError(
13751374
f'cond called with {len(avals) - 1} non-predicate operands, '
13761375
f'but branches take {len(jaxpr0.in_avals)} inputs')
13771376

13781377
index_aval, *op_avals = avals
1379-
core.typecheck_assert(
1380-
index_aval.dtype == np.int32,
1378+
if index_aval.dtype != np.int32:
1379+
raise core.JaxprTypeError(
13811380
f'cond called with index of type {index_aval.dtype} instead of int32')
1382-
core.typecheck_assert(
1383-
all(_map(core.typecompat, jaxpr0.in_avals, op_avals)),
1381+
if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)):
1382+
raise core.JaxprTypeError(
13841383
f'cond branches take input types {jaxpr0_in_avals_str}, '
13851384
f'called with operands of type {_avals_short(op_avals)}')
13861385

@@ -2177,8 +2176,8 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
21772176
tc(length, 'length', 'non-negative int',
21782177
type(length) in length_types and length >= 0)
21792178

2180-
core.typecheck_assert(
2181-
len(linear) == len(avals),
2179+
if len(linear) != len(avals):
2180+
raise core.JaxprTypeError(
21822181
f'scan param linear has length {len(linear)} for {len(avals)} operands')
21832182

21842183
const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry])
@@ -2187,20 +2186,20 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
21872186
carry_avals_jaxpr, _ = split_list(jaxpr.out_avals, [num_carry])
21882187
x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals)
21892188

2190-
core.typecheck_assert(
2191-
all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)),
2189+
if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)):
2190+
raise core.JaxprTypeError(
21922191
f'scan input carry input and output types mismatch: '
21932192
f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}')
2194-
core.typecheck_assert(
2195-
all(_map(core.typecompat, const_avals_jaxpr, const_avals)),
2193+
if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)):
2194+
raise core.JaxprTypeError(
21962195
f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n'
21972196
f'called with consts of type\n{_avals_short(const_avals)}')
2198-
core.typecheck_assert(
2199-
all(_map(core.typecompat, init_avals_jaxpr, init_avals)),
2197+
if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)):
2198+
raise core.JaxprTypeError(
22002199
f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n'
22012200
f'called with initial carry of type\n{_avals_short(init_avals)}')
2202-
core.typecheck_assert(
2203-
all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)),
2201+
if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)):
2202+
raise core.JaxprTypeError(
22042203
f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
22052204
f'called with sequence of type\n{_avals_short(x_avals)}')
22062205

jax/core.py

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import collections
1717
from collections import namedtuple
1818
from contextlib import contextmanager
19+
import functools
1920
from functools import partial, partialmethod, total_ordering
2021
import gc
2122
import itertools as it
@@ -2050,10 +2051,6 @@ def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
20502051

20512052
class JaxprTypeError(TypeError): pass
20522053

2053-
def typecheck_assert(pred, msg):
2054-
if not pred:
2055-
raise JaxprTypeError(msg)
2056-
20572054
custom_typechecks: Dict[Primitive, Callable] = {}
20582055

20592056
def check_jaxpr(jaxpr: Jaxpr):
@@ -2067,13 +2064,17 @@ def check_jaxpr(jaxpr: Jaxpr):
20672064
Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
20682065
otherwise.
20692066
"""
2070-
ctx = JaxprPpContext()
2071-
try: pp_jaxpr(jaxpr, ctx) # side-effect on ctx, build variable names
2072-
except: pass
2067+
@functools.lru_cache(maxsize=None)
2068+
def ctx_factory():
2069+
ctx = JaxprPpContext()
2070+
try: pp_jaxpr(jaxpr, ctx) # side-effect on ctx, build variable names
2071+
except: pass
2072+
return ctx
20732073

20742074
try:
2075-
_check_jaxpr(ctx, jaxpr, [v.aval for v in jaxpr.invars])
2075+
_check_jaxpr(ctx_factory, jaxpr, [v.aval for v in jaxpr.invars])
20762076
except JaxprTypeError as e:
2077+
ctx = ctx_factory()
20772078
if len(e.args) == 2:
20782079
msg, eqnidx = e.args
20792080
jaxpr_str = str(pp_jaxpr_eqn_range(jaxpr, eqnidx - 10, eqnidx + 10, ctx))
@@ -2083,22 +2084,28 @@ def check_jaxpr(jaxpr: Jaxpr):
20832084
msg = "\n\n".join([msg, "while checking jaxpr:", jaxpr_str])
20842085
raise JaxprTypeError(msg) from None
20852086

2086-
def _check_jaxpr(ctx: 'JaxprPpContext', jaxpr: Jaxpr,
2087+
def _check_jaxpr(ctx_factory: Callable[[], 'JaxprPpContext'], jaxpr: Jaxpr,
20872088
in_avals: Sequence[AbstractValue]) -> None:
20882089

20892090
def read(v: Atom) -> AbstractValue:
20902091
if isinstance(v, Literal):
20912092
return raise_to_shaped(get_aval(v.val))
20922093
else:
2093-
typecheck_assert(v in env, f"Variable '{pp_var(v, ctx)}' not defined")
2094+
if v not in env:
2095+
ctx = ctx_factory()
2096+
raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' not defined")
20942097
return env[v]
20952098

20962099
def write(v: Var, a: AbstractValue) -> None:
2097-
typecheck_assert(v not in env, f"Variable '{pp_var(v, ctx)}' already bound")
2100+
if v in env:
2101+
ctx = ctx_factory()
2102+
raise JaxprTypeError(f"Variable '{pp_var(v, ctx)}' already bound")
20982103
if not isinstance(v, DropVar):
2099-
typecheck_assert(typecompat(v.aval, a),
2100-
f"Variable '{pp_var(v, ctx)}' inconsistently typed as "
2101-
f"{pp_aval(a, ctx)}, bound as {pp_aval(v.aval, ctx)}")
2104+
if not typecompat(v.aval, a):
2105+
ctx = ctx_factory()
2106+
raise JaxprTypeError(
2107+
f"Variable '{pp_var(v, ctx)}' inconsistently typed as "
2108+
f"{pp_aval(a, ctx)}, bound as {pp_aval(v.aval, ctx)}")
21022109
env[v] = a
21032110

21042111
env : Dict[Var, AbstractValue] = {}
@@ -2111,20 +2118,21 @@ def write(v: Var, a: AbstractValue) -> None:
21112118
prim = eqn.primitive
21122119
try:
21132120
in_avals = map(read, eqn.invars)
2114-
typecheck_assert(all(not isinstance(ina, ConcreteArray) for ina in in_avals),
2115-
"Equation given ConcreteArray type inputs")
2121+
if any(isinstance(ina, ConcreteArray) for ina in in_avals):
2122+
raise JaxprTypeError("Equation given ConcreteArray type inputs")
21162123
if prim in custom_typechecks:
21172124
out_avals = custom_typechecks[prim](*in_avals, **eqn.params)
21182125
if out_avals is None:
21192126
out_avals = [v.aval for v in eqn.outvars]
21202127
elif prim.call_primitive:
2121-
out_avals = check_call(ctx, prim, in_avals, eqn.params)
2128+
out_avals = check_call(ctx_factory, prim, in_avals, eqn.params)
21222129
elif prim.map_primitive:
2123-
out_avals = check_map(ctx, prim, in_avals, eqn.params)
2130+
out_avals = check_map(ctx_factory, prim, in_avals, eqn.params)
21242131
else:
21252132
out_avals = check_eqn(prim, in_avals, eqn.params)
21262133
map(write, eqn.outvars, out_avals)
21272134
except JaxprTypeError as e:
2135+
ctx = ctx_factory()
21282136
msg, = e.args
21292137
src = source_info_util.summarize(eqn.source_info)
21302138
msg = "\n\n".join([msg, "in equation:", str(pp.nest(2, pp_eqn(eqn, ctx))),
@@ -2142,57 +2150,58 @@ def check_eqn(prim, in_avals, params):
21422150
out_avals = [out_avals]
21432151
return out_avals
21442152

2145-
def check_call(ctx, prim, in_avals, params):
2146-
typecheck_assert("call_jaxpr" in params,
2147-
f"Call primitive {prim} missing 'call_jaxpr' parameter")
2153+
def check_call(ctx_factory, prim, in_avals, params):
2154+
if "call_jaxpr" not in params:
2155+
raise JaxprTypeError(
2156+
f"Call primitive {prim} missing 'call_jaxpr' parameter")
21482157
call_jaxpr = params["call_jaxpr"]
21492158

21502159
# These checks also happen in recursive call, but give better errors here.
2151-
typecheck_assert(len(in_avals) == len(call_jaxpr.invars),
2152-
f"Call primitive {prim} with {len(call_jaxpr.invars)} "
2153-
f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
2154-
f"inputs")
2160+
if len(in_avals) != len(call_jaxpr.invars):
2161+
raise JaxprTypeError(f"Call primitive {prim} with {len(call_jaxpr.invars)} "
2162+
f"operands cannot call jaxpr with {len(call_jaxpr.invars)} "
2163+
f"inputs")
21552164
binder_avals = [v.aval for v in call_jaxpr.invars]
21562165
for binder_aval, in_aval in zip(binder_avals, in_avals):
2157-
typecheck_assert(typecompat(binder_aval, in_aval),
2158-
f"Call primitive {prim} passes operand {in_aval} "
2159-
f"to jaxpr expecting {binder_aval}")
2166+
if not typecompat(binder_aval, in_aval):
2167+
raise JaxprTypeError(f"Call primitive {prim} passes operand {in_aval} "
2168+
f"to jaxpr expecting {binder_aval}")
21602169

2161-
_check_jaxpr(ctx, call_jaxpr, in_avals)
2170+
_check_jaxpr(ctx_factory, call_jaxpr, in_avals)
21622171

21632172
out_avals = [v.aval for v in call_jaxpr.outvars]
21642173
return out_avals
21652174

2166-
def check_map(ctx, prim, in_avals, params):
2167-
typecheck_assert("call_jaxpr" in params,
2168-
f"Map primitive {prim} missing 'call_jaxpr' parameter")
2175+
def check_map(ctx_factory, prim, in_avals, params):
2176+
if "call_jaxpr" not in params:
2177+
raise JaxprTypeError(f"Map primitive {prim} missing 'call_jaxpr' parameter")
21692178
call_jaxpr = params["call_jaxpr"]
2170-
typecheck_assert("axis_size" in params,
2171-
f"Map primitive {prim} missing 'axis_size' parameter")
2179+
if "axis_size" not in params:
2180+
raise JaxprTypeError(f"Map primitive {prim} missing 'axis_size' parameter")
21722181
axis_size = params["axis_size"]
2173-
typecheck_assert("axis_name" in params,
2174-
f"Map primitive {prim} missing 'axis_name' parameter")
2182+
if "axis_name" not in params:
2183+
raise JaxprTypeError(f"Map primitive {prim} missing 'axis_name' parameter")
21752184
axis_name = params["axis_name"]
2176-
typecheck_assert("in_axes" in params,
2177-
f"Map primitive {prim} missing 'in_axes' parameter")
2185+
if "in_axes" not in params:
2186+
raise JaxprTypeError(f"Map primitive {prim} missing 'in_axes' parameter")
21782187
in_axes = params["in_axes"]
2179-
typecheck_assert("out_axes" in params,
2180-
f"Map primitive {prim} missing 'out_axes' parameter")
2188+
if "out_axes" not in params:
2189+
raise JaxprTypeError(f"Map primitive {prim} missing 'out_axes' parameter")
21812190
out_axes = params["out_axes"]
21822191

21832192
binder_avals = [unmapped_aval(axis_size, axis_name, in_axis, v.aval)
21842193
if in_axis is not None else v.aval
21852194
for v, in_axis in zip(call_jaxpr.invars, in_axes)]
21862195
for binder_aval, in_aval in zip(binder_avals, in_avals):
2187-
typecheck_assert(typecompat(binder_aval, in_aval),
2188-
f"Call primitive {prim} passes operand {in_aval} "
2189-
f"to jaxpr expecting {binder_aval}")
2196+
if not typecompat(binder_aval, in_aval):
2197+
raise JaxprTypeError(f"Call primitive {prim} passes operand {in_aval} "
2198+
f"to jaxpr expecting {binder_aval}")
21902199

21912200
mapped_avals = [mapped_aval(axis_size, in_axis, aval)
21922201
if in_axis is not None else aval
21932202
for aval, in_axis in zip(in_avals, in_axes)]
21942203
with extend_axis_env(params['axis_name'], axis_size, None):
2195-
_check_jaxpr(ctx, call_jaxpr, mapped_avals)
2204+
_check_jaxpr(ctx_factory, call_jaxpr, mapped_avals)
21962205

21972206
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
21982207
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval

jax/experimental/maps.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -951,14 +951,15 @@ def _typecheck_xmap(
951951
binder_in_avals = [_insert_aval_axes(v.aval, a_in_axes, local_axis_sizes)
952952
for v, a_in_axes in zip(call_jaxpr.invars, in_axes)]
953953
for binder_in_aval, in_aval in zip(binder_in_avals, in_avals):
954-
core.typecheck_assert(
955-
core.typecompat(binder_in_aval, in_aval),
954+
if not core.typecompat(binder_in_aval, in_aval):
955+
raise core.JaxprTypeError(
956956
f"xmap passes operand {in_aval} to jaxpr expecting {binder_in_aval}")
957957

958958
mapped_in_avals = [_delete_aval_axes(a, a_in_axes, global_axis_sizes)
959959
for a, a_in_axes in zip(in_avals, in_axes)]
960960
with core.extend_axis_env_nd(global_axis_sizes.items()):
961-
core._check_jaxpr(core.JaxprPpContext(), call_jaxpr, mapped_in_avals)
961+
core._check_jaxpr(lambda: core.JaxprPpContext(), call_jaxpr,
962+
mapped_in_avals)
962963

963964
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
964965
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)

0 commit comments

Comments
 (0)