File tree Expand file tree Collapse file tree 2 files changed +3
-4
lines changed
jax/_src/lax/control_flow Expand file tree Collapse file tree 2 files changed +3
-4
lines changed Original file line number Diff line number Diff line change 1515
1616from __future__ import annotations
1717
18- from collections .abc import Callable , Sequence
18+ from collections .abc import Sequence
1919import os
2020from functools import partial
2121from typing import Any
2727from jax ._src .util import weakref_lru_cache , safe_map
2828from jax ._src .interpreters import partial_eval as pe
2929from jax ._src .tree_util import (equality_errors_pytreedef , tree_map ,
30- tree_unflatten , keystr , PyTreeDef )
30+ tree_unflatten , keystr )
3131
3232map , unsafe_map = safe_map , map
3333
@@ -73,7 +73,7 @@ def make_var(aq):
7373 * map (make_var , right ), * jaxpr .invars [num_consts :]]
7474 effs = pe ._renumber_effects (invars , jaxpr .invars , jaxpr .effects )
7575 jaxpr = jaxpr .replace (jaxpr = jaxpr .jaxpr .replace (invars = invars , effects = effs ))
76- config .enable_checks .value and core .check_jaxpr (jaxpr )
76+ config .enable_checks .value and core .check_jaxpr (jaxpr . jaxpr )
7777 return jaxpr
7878
7979@weakref_lru_cache
Original file line number Diff line number Diff line change 1818from collections .abc import Callable , Sequence
1919import functools
2020from functools import partial
21- import inspect
2221import itertools
2322import operator
2423from typing import Any , TypeVar
You can’t perform that action at this time.
0 commit comments