Skip to content

Commit f1d7a72

Browse files
committed
Fix lint errors
1 parent eaf6efc commit f1d7a72

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

jax/_src/lax/control_flow/common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from __future__ import annotations
1717

18-
from collections.abc import Callable, Sequence
18+
from collections.abc import Sequence
1919
import os
2020
from functools import partial
2121
from typing import Any
@@ -27,7 +27,7 @@
2727
from jax._src.util import weakref_lru_cache, safe_map
2828
from jax._src.interpreters import partial_eval as pe
2929
from jax._src.tree_util import (equality_errors_pytreedef, tree_map,
30-
tree_unflatten, keystr, PyTreeDef)
30+
tree_unflatten, keystr)
3131

3232
map, unsafe_map = safe_map, map
3333

@@ -73,7 +73,7 @@ def make_var(aq):
7373
*map(make_var, right), *jaxpr.invars[num_consts:]]
7474
effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects)
7575
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs))
76-
config.enable_checks.value and core.check_jaxpr(jaxpr)
76+
config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
7777
return jaxpr
7878

7979
@weakref_lru_cache

jax/_src/lax/control_flow/conditionals.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from collections.abc import Callable, Sequence
1919
import functools
2020
from functools import partial
21-
import inspect
2221
import itertools
2322
import operator
2423
from typing import Any, TypeVar

0 commit comments

Comments
 (0)