1616import collections
1717from collections import namedtuple
1818from contextlib import contextmanager
19+ import functools
1920from functools import partial , partialmethod , total_ordering
2021import gc
2122import itertools as it
@@ -2050,10 +2051,6 @@ def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
20502051
20512052class JaxprTypeError (TypeError ): pass
20522053
2053- def typecheck_assert (pred , msg ):
2054- if not pred :
2055- raise JaxprTypeError (msg )
2056-
20572054custom_typechecks : Dict [Primitive , Callable ] = {}
20582055
20592056def check_jaxpr (jaxpr : Jaxpr ):
@@ -2067,13 +2064,17 @@ def check_jaxpr(jaxpr: Jaxpr):
20672064 Raises `JaxprTypeError` if `jaxpr` is determined invalid. Returns `None`
20682065 otherwise.
20692066 """
2070- ctx = JaxprPpContext ()
2071- try : pp_jaxpr (jaxpr , ctx ) # side-effect on ctx, build variable names
2072- except : pass
2067+ @functools .lru_cache (maxsize = None )
2068+ def ctx_factory ():
2069+ ctx = JaxprPpContext ()
2070+ try : pp_jaxpr (jaxpr , ctx ) # side-effect on ctx, build variable names
2071+ except : pass
2072+ return ctx
20732073
20742074 try :
2075- _check_jaxpr (ctx , jaxpr , [v .aval for v in jaxpr .invars ])
2075+ _check_jaxpr (ctx_factory , jaxpr , [v .aval for v in jaxpr .invars ])
20762076 except JaxprTypeError as e :
2077+ ctx = ctx_factory ()
20772078 if len (e .args ) == 2 :
20782079 msg , eqnidx = e .args
20792080 jaxpr_str = str (pp_jaxpr_eqn_range (jaxpr , eqnidx - 10 , eqnidx + 10 , ctx ))
@@ -2083,22 +2084,28 @@ def check_jaxpr(jaxpr: Jaxpr):
20832084 msg = "\n \n " .join ([msg , "while checking jaxpr:" , jaxpr_str ])
20842085 raise JaxprTypeError (msg ) from None
20852086
2086- def _check_jaxpr (ctx : 'JaxprPpContext' , jaxpr : Jaxpr ,
2087+ def _check_jaxpr (ctx_factory : Callable [[], 'JaxprPpContext' ] , jaxpr : Jaxpr ,
20872088 in_avals : Sequence [AbstractValue ]) -> None :
20882089
20892090 def read (v : Atom ) -> AbstractValue :
20902091 if isinstance (v , Literal ):
20912092 return raise_to_shaped (get_aval (v .val ))
20922093 else :
2093- typecheck_assert (v in env , f"Variable '{ pp_var (v , ctx )} ' not defined" )
2094+ if v not in env :
2095+ ctx = ctx_factory ()
2096+ raise JaxprTypeError (f"Variable '{ pp_var (v , ctx )} ' not defined" )
20942097 return env [v ]
20952098
20962099 def write (v : Var , a : AbstractValue ) -> None :
2097- typecheck_assert (v not in env , f"Variable '{ pp_var (v , ctx )} ' already bound" )
2100+ if v in env :
2101+ ctx = ctx_factory ()
2102+ raise JaxprTypeError (f"Variable '{ pp_var (v , ctx )} ' already bound" )
20982103 if not isinstance (v , DropVar ):
2099- typecheck_assert (typecompat (v .aval , a ),
2100- f"Variable '{ pp_var (v , ctx )} ' inconsistently typed as "
2101- f"{ pp_aval (a , ctx )} , bound as { pp_aval (v .aval , ctx )} " )
2104+ if not typecompat (v .aval , a ):
2105+ ctx = ctx_factory ()
2106+ raise JaxprTypeError (
2107+ f"Variable '{ pp_var (v , ctx )} ' inconsistently typed as "
2108+ f"{ pp_aval (a , ctx )} , bound as { pp_aval (v .aval , ctx )} " )
21022109 env [v ] = a
21032110
21042111 env : Dict [Var , AbstractValue ] = {}
@@ -2111,20 +2118,21 @@ def write(v: Var, a: AbstractValue) -> None:
21112118 prim = eqn .primitive
21122119 try :
21132120 in_avals = map (read , eqn .invars )
2114- typecheck_assert ( all ( not isinstance (ina , ConcreteArray ) for ina in in_avals ),
2115- "Equation given ConcreteArray type inputs" )
2121+ if any ( isinstance (ina , ConcreteArray ) for ina in in_avals ):
2122+ raise JaxprTypeError ( "Equation given ConcreteArray type inputs" )
21162123 if prim in custom_typechecks :
21172124 out_avals = custom_typechecks [prim ](* in_avals , ** eqn .params )
21182125 if out_avals is None :
21192126 out_avals = [v .aval for v in eqn .outvars ]
21202127 elif prim .call_primitive :
2121- out_avals = check_call (ctx , prim , in_avals , eqn .params )
2128+ out_avals = check_call (ctx_factory , prim , in_avals , eqn .params )
21222129 elif prim .map_primitive :
2123- out_avals = check_map (ctx , prim , in_avals , eqn .params )
2130+ out_avals = check_map (ctx_factory , prim , in_avals , eqn .params )
21242131 else :
21252132 out_avals = check_eqn (prim , in_avals , eqn .params )
21262133 map (write , eqn .outvars , out_avals )
21272134 except JaxprTypeError as e :
2135+ ctx = ctx_factory ()
21282136 msg , = e .args
21292137 src = source_info_util .summarize (eqn .source_info )
21302138 msg = "\n \n " .join ([msg , "in equation:" , str (pp .nest (2 , pp_eqn (eqn , ctx ))),
@@ -2142,57 +2150,58 @@ def check_eqn(prim, in_avals, params):
21422150 out_avals = [out_avals ]
21432151 return out_avals
21442152
2145- def check_call (ctx , prim , in_avals , params ):
2146- typecheck_assert ("call_jaxpr" in params ,
2147- f"Call primitive { prim } missing 'call_jaxpr' parameter" )
2153+ def check_call (ctx_factory , prim , in_avals , params ):
2154+ if "call_jaxpr" not in params :
2155+ raise JaxprTypeError (
2156+ f"Call primitive { prim } missing 'call_jaxpr' parameter" )
21482157 call_jaxpr = params ["call_jaxpr" ]
21492158
21502159 # These checks also happen in recursive call, but give better errors here.
2151- typecheck_assert ( len (in_avals ) == len (call_jaxpr .invars ),
2152- f"Call primitive { prim } with { len (call_jaxpr .invars )} "
2153- f"operands cannot call jaxpr with { len (call_jaxpr .invars )} "
2154- f"inputs" )
2160+ if len (in_avals ) != len (call_jaxpr .invars ):
2161+ raise JaxprTypeError ( f"Call primitive { prim } with { len (call_jaxpr .invars )} "
2162+ f"operands cannot call jaxpr with { len (call_jaxpr .invars )} "
2163+ f"inputs" )
21552164 binder_avals = [v .aval for v in call_jaxpr .invars ]
21562165 for binder_aval , in_aval in zip (binder_avals , in_avals ):
2157- typecheck_assert ( typecompat (binder_aval , in_aval ),
2158- f"Call primitive { prim } passes operand { in_aval } "
2159- f"to jaxpr expecting { binder_aval } " )
2166+ if not typecompat (binder_aval , in_aval ):
2167+ raise JaxprTypeError ( f"Call primitive { prim } passes operand { in_aval } "
2168+ f"to jaxpr expecting { binder_aval } " )
21602169
2161- _check_jaxpr (ctx , call_jaxpr , in_avals )
2170+ _check_jaxpr (ctx_factory , call_jaxpr , in_avals )
21622171
21632172 out_avals = [v .aval for v in call_jaxpr .outvars ]
21642173 return out_avals
21652174
2166- def check_map (ctx , prim , in_avals , params ):
2167- typecheck_assert ( "call_jaxpr" in params ,
2168- f"Map primitive { prim } missing 'call_jaxpr' parameter" )
2175+ def check_map (ctx_factory , prim , in_avals , params ):
2176+ if "call_jaxpr" not in params :
2177+ raise JaxprTypeError ( f"Map primitive { prim } missing 'call_jaxpr' parameter" )
21692178 call_jaxpr = params ["call_jaxpr" ]
2170- typecheck_assert ( "axis_size" in params ,
2171- f"Map primitive { prim } missing 'axis_size' parameter" )
2179+ if "axis_size" not in params :
2180+ raise JaxprTypeError ( f"Map primitive { prim } missing 'axis_size' parameter" )
21722181 axis_size = params ["axis_size" ]
2173- typecheck_assert ( "axis_name" in params ,
2174- f"Map primitive { prim } missing 'axis_name' parameter" )
2182+ if "axis_name" not in params :
2183+ raise JaxprTypeError ( f"Map primitive { prim } missing 'axis_name' parameter" )
21752184 axis_name = params ["axis_name" ]
2176- typecheck_assert ( "in_axes" in params ,
2177- f"Map primitive { prim } missing 'in_axes' parameter" )
2185+ if "in_axes" not in params :
2186+ raise JaxprTypeError ( f"Map primitive { prim } missing 'in_axes' parameter" )
21782187 in_axes = params ["in_axes" ]
2179- typecheck_assert ( "out_axes" in params ,
2180- f"Map primitive { prim } missing 'out_axes' parameter" )
2188+ if "out_axes" not in params :
2189+ raise JaxprTypeError ( f"Map primitive { prim } missing 'out_axes' parameter" )
21812190 out_axes = params ["out_axes" ]
21822191
21832192 binder_avals = [unmapped_aval (axis_size , axis_name , in_axis , v .aval )
21842193 if in_axis is not None else v .aval
21852194 for v , in_axis in zip (call_jaxpr .invars , in_axes )]
21862195 for binder_aval , in_aval in zip (binder_avals , in_avals ):
2187- typecheck_assert ( typecompat (binder_aval , in_aval ),
2188- f"Call primitive { prim } passes operand { in_aval } "
2189- f"to jaxpr expecting { binder_aval } " )
2196+ if not typecompat (binder_aval , in_aval ):
2197+ raise JaxprTypeError ( f"Call primitive { prim } passes operand { in_aval } "
2198+ f"to jaxpr expecting { binder_aval } " )
21902199
21912200 mapped_avals = [mapped_aval (axis_size , in_axis , aval )
21922201 if in_axis is not None else aval
21932202 for aval , in_axis in zip (in_avals , in_axes )]
21942203 with extend_axis_env (params ['axis_name' ], axis_size , None ):
2195- _check_jaxpr (ctx , call_jaxpr , mapped_avals )
2204+ _check_jaxpr (ctx_factory , call_jaxpr , mapped_avals )
21962205
21972206 mapped_out_avals = [v .aval for v in call_jaxpr .outvars ]
21982207 out_avals = [unmapped_aval (axis_size , axis_name , out_axis , aval ) if out_axis is not None else aval
0 commit comments