Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,11 @@ def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr:
constvars, envvars = partition_list(which, jaxpr.constvars)
return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars])

@weakref_lru_cache
def separate_consts(jaxpr: ClosedJaxpr) -> tuple[ClosedJaxpr, list[Any]]:
"""Moves the constvars to the start of invars and returns the consts explicitly."""
return ClosedJaxpr(convert_constvars_jaxpr(jaxpr.jaxpr), []), jaxpr.consts

@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
Expand Down Expand Up @@ -2400,7 +2405,7 @@ def trace_to_jaxpr(
in_tree: PyTreeDef,
in_avals_flat: Sequence[AbstractValue | core.AvalQDD],
debug_info: core.DebugInfo
) -> tuple[Jaxpr, PyTreeDef, list[Any]]:
) -> tuple[ClosedJaxpr, PyTreeDef]:
config.enable_checks.value and debug_info.assert_arg_names(len(in_avals_flat))
parent_trace = core.trace_ctx.trace
trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace)
Expand All @@ -2424,8 +2429,7 @@ def trace_to_jaxpr(
del trace, fun, in_tracers_flat, in_tracers, out_tracers, ans, ans_flat

config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, out_tree, consts

return ClosedJaxpr(jaxpr, consts), out_tree

# TODO(dougalm): remove in favor of `trace_to_jaxpr`
@profiler.annotate_function
Expand Down
3 changes: 0 additions & 3 deletions jax/_src/lax/control_flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@
# Private utilities used elsewhere in JAX
# TODO(sharadmv): lift them into a more common place
from jax._src.lax.control_flow.common import (
_initial_style_open_jaxpr as _initial_style_open_jaxpr,
_initial_style_jaxpr as _initial_style_jaxpr,
_initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts,
_check_tree_and_avals as _check_tree_and_avals,

)
Expand Down
84 changes: 30 additions & 54 deletions jax/_src/lax/control_flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from __future__ import annotations

from collections.abc import Callable, Sequence
from collections.abc import Sequence
import os
from functools import partial
from typing import Any
Expand All @@ -27,7 +27,7 @@
from jax._src.util import weakref_lru_cache, safe_map
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (equality_errors_pytreedef, tree_map,
tree_unflatten, keystr, PyTreeDef)
tree_unflatten, keystr)

map, unsafe_map = safe_map, map

Expand All @@ -43,78 +43,54 @@ def _typecheck_param(prim, param, name, msg_required, pred):
msg = sep.join([msg, param_str])
raise core.JaxprTypeError(msg)

# TODO(dougalm): this is a silly wrapper now. Delete it.
@weakref_lru_cache
def _initial_style_open_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue | core.AvalQDD],
debug_info: core.DebugInfo):
jaxpr, out_tree, consts = pe.trace_to_jaxpr(fun, in_tree, in_avals, debug_info)
return jaxpr, consts, out_tree

# TODO(dougalm): Delete. Make `trace_to_jaxpr` do the jaxpr-closing thing instead.
@weakref_lru_cache
def _initial_style_jaxpr(fun: Callable,
in_tree: PyTreeDef,
in_avals: Sequence[core.AbstractValue],
debug_info: core.DebugInfo) -> tuple[core.ClosedJaxpr, Sequence[Any], PyTreeDef]:
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
fun, in_tree, in_avals, debug_info)
closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
return closed_jaxpr, consts, out_tree

def _initial_style_jaxprs_with_common_consts(
funs: Sequence[Callable],
in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue | core.AvalQDD],
debug_infos: Sequence[core.DebugInfo]):
jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, debug_info)
for fn, debug_info in zip(funs, debug_infos)]
if not jaxpr_data: return [], [], []
jaxprs, all_consts, all_out_trees = zip(*jaxpr_data)

# TODO(dougalm): this seems way too complicated. Why not allow different consts for each
# branch of a switch?
def _merge_common_consts(
jaxprs: Sequence[core.ClosedJaxpr],
all_consts: Sequence[Sequence[Any]]
) -> tuple[Sequence[core.ClosedJaxpr], Sequence[Any]]:
# Jaxprs must share consts, so we concat consts and pad the jaxprs' constvars.
lens = map(len, all_consts)
consts = [c for cs in all_consts for c in cs]
avalqdds = tuple(map(core.cur_aval_qdd, consts))
jaxprs = [_pad_constvars(jaxpr, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):])
for i, jaxpr in enumerate(jaxprs)]
num_constss = [len(cs) for cs in all_consts]
jaxprs = [_pad_constvars(jaxpr, num_consts, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):])
for i, (jaxpr, num_consts) in enumerate(zip(jaxprs, num_constss))]
# De-duplicate shared constants.
const_ids = tuple(id(c) for c in consts)
seen = set()
consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore
jaxprs = [_dedup_consts(jaxpr, const_ids) for jaxpr in jaxprs]

closed_jaxprs = [pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
for jaxpr in jaxprs]
return closed_jaxprs, consts, all_out_trees
dd_consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore
jaxprs = [_dedup_consts(jaxpr, len(consts), const_ids) for jaxpr in jaxprs]
return jaxprs, dd_consts

@weakref_lru_cache
def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...],
right: tuple[core.AbstractValue, ...]) -> core.Jaxpr:
def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int,
left: tuple[core.AvalQDD, ...],
right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr:
def make_var(aq):
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)]
effs = pe._renumber_effects([*constvars, *jaxpr.invars],
[*jaxpr.constvars, *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, effects=effs)
config.enable_checks.value and core.check_jaxpr(jaxpr)
invars = [*map(make_var, left), *jaxpr.invars[:num_consts],
*map(make_var, right), *jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
return jaxpr
Comment on lines +67 to 77
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The jaxpr argument is a ClosedJaxpr, but you are trying to access attributes like invars and effects which belong to the inner Jaxpr object. This will raise an AttributeError. You should access them via jaxpr.jaxpr.

Suggested change
def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int,
left: tuple[core.AvalQDD, ...],
right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr:
def make_var(aq):
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)]
effs = pe._renumber_effects([*constvars, *jaxpr.invars],
[*jaxpr.constvars, *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, effects=effs)
config.enable_checks.value and core.check_jaxpr(jaxpr)
invars = [*map(make_var, left), *jaxpr.invars[:num_consts],
*map(make_var, right), *jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
return jaxpr
def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int,
left: tuple[core.AvalQDD, ...],
right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr:
def make_var(aq):
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
invars = [*map(make_var, left), *jaxpr.jaxpr.invars[:num_consts],
*map(make_var, right), *jaxpr.jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars, jaxpr.jaxpr.invars, jaxpr.jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
return jaxpr


@weakref_lru_cache
def _dedup_consts(jaxpr, const_ids):
def _dedup_consts(jaxpr, num_consts, const_ids):
newvars = {}
canonicalize = {v: newvars.setdefault(constid, v)
for constid, v in zip(const_ids, jaxpr.constvars)}
for constid, v in zip(const_ids, jaxpr.invars[:num_consts])}
eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var)
else x for x in e.invars]) for e in jaxpr.eqns]
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
for x in jaxpr.outvars]
constvars = list(newvars.values())
effs = pe._renumber_effects(
[*constvars, *jaxpr.invars],
[*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars,
effects=effs)
invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars,
[*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]],
jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars,
effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr
Comment on lines 79 to 95
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to _pad_constvars, the jaxpr argument is a ClosedJaxpr, but you are accessing attributes of the inner Jaxpr object directly (e.g., jaxpr.invars, jaxpr.eqns). This will raise an AttributeError. You should use jaxpr.jaxpr to access them. Also, check_jaxpr at the end should be called on jaxpr.jaxpr. It would also be good to add type hints to this function for clarity.

Suggested change
@weakref_lru_cache
def _dedup_consts(jaxpr, const_ids):
def _dedup_consts(jaxpr, num_consts, const_ids):
newvars = {}
canonicalize = {v: newvars.setdefault(constid, v)
for constid, v in zip(const_ids, jaxpr.constvars)}
for constid, v in zip(const_ids, jaxpr.invars[:num_consts])}
eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var)
else x for x in e.invars]) for e in jaxpr.eqns]
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
for x in jaxpr.outvars]
constvars = list(newvars.values())
effs = pe._renumber_effects(
[*constvars, *jaxpr.invars],
[*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars,
effects=effs)
invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars,
[*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]],
jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars,
effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr
@weakref_lru_cache
def _dedup_consts(jaxpr: core.ClosedJaxpr, num_consts: int, const_ids: tuple) -> core.ClosedJaxpr:
newvars = {}
canonicalize = {v: newvars.setdefault(constid, v)
for constid, v in zip(const_ids, jaxpr.jaxpr.invars[:num_consts])}
eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var)
else x for x in e.invars]) for e in jaxpr.jaxpr.eqns]
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
for x in jaxpr.jaxpr.outvars]
invars = [*list(newvars.values()), *jaxpr.jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars,
[*map(canonicalize.get, jaxpr.jaxpr.invars[:num_consts]), *jaxpr.jaxpr.invars[num_consts:]],
jaxpr.jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars,
effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
return jaxpr


Expand Down
68 changes: 16 additions & 52 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from collections.abc import Callable, Sequence
import functools
from functools import partial
import inspect
import itertools
import operator
from typing import Any, TypeVar
Expand Down Expand Up @@ -53,7 +52,7 @@
import numpy as np

from jax._src.lax.control_flow.common import (
_avals_short, _typecheck_param, _initial_style_jaxprs_with_common_consts,
_avals_short, _typecheck_param, _merge_common_consts,
_make_closed_jaxpr, _prune_zeros)

map, unsafe_map = safe_map, map
Expand Down Expand Up @@ -149,8 +148,11 @@ def _switch_internal(
if config.mutable_array_checks.value:
api_util.check_no_aliased_ref_args(lambda: dbgs[0], ops_avals, ops)

jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals, dbgs)
jaxprs_, out_trees = zip(*[pe.trace_to_jaxpr(
branch, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbgs)])
jaxprs_, all_consts = zip(*[pe.separate_consts(j) for j in jaxprs_])
jaxprs, consts = _merge_common_consts(jaxprs_, all_consts)

if config.mutable_array_checks.value:
api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops)
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
Expand Down Expand Up @@ -184,7 +186,7 @@ def _switch_internal(
return tree_unflatten(out_trees[0], out)

@partial(api_boundary, repro_api_name="jax_cond")
def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
def cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel):
"""Conditionally apply ``true_fun`` or ``false_fun``.

Expand Down Expand Up @@ -270,14 +272,18 @@ def cond(pred, true_fun, false_fun, *operands):
if config.mutable_array_checks.value:
api_util.check_no_aliased_ref_args(lambda: dbg_true_fun, ops_avals, ops)
dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {})
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals,
[dbg_true_fun, dbg_false_fun])
true_jaxpr, false_jaxpr = jaxprs

true_jaxpr_, out_tree = pe.trace_to_jaxpr(
true_fun, ops_tree, ops_avals, dbg_true_fun)
true_jaxpr_, true_consts = pe.separate_consts(true_jaxpr_)
false_jaxpr_, false_out_tree = pe.trace_to_jaxpr(
false_fun, ops_tree, ops_avals, dbg_false_fun)
false_jaxpr_, false_consts = pe.separate_consts(false_jaxpr_)
(true_jaxpr, false_jaxpr), consts = _merge_common_consts(
(true_jaxpr_, false_jaxpr_), (true_consts, false_consts))
if config.mutable_array_checks.value:
api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops)

out_tree, false_out_tree = out_trees
if any(isinstance(out_aval, AbstractRef) for out_aval in
true_jaxpr.out_avals + false_jaxpr.out_avals):
raise ValueError("Cannot return `Ref`s from `cond`.")
Expand Down Expand Up @@ -399,48 +405,6 @@ def _capitalize(s):
# s.capitalize() converts s[1:] to lowercase which we don't want.
return s[0].capitalize() + s[1:]

@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):
# detect an attempt to call the former, deprecated cond
try:
ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
except TypeError:
pass
else:
assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch
_, true_operand, true_fun, false_operand, false_fun = ba.args
if callable(true_operand) and callable(true_fun):
# treat this as modern cond (with two operands)
return _cond(*args, **kwargs)
if callable(true_fun) and callable(false_fun):
return _cond_with_per_branch_args(*ba.args)

return _cond(*args, **kwargs)

@partial(api_boundary, repro_api_name="jax_cond_with_per_branch_args")
def _cond_with_per_branch_args(pred,
true_operand, true_fun: Callable,
false_operand, false_fun: Callable):
"""Conditionally apply ``true_fun`` or ``false_fun``.

Has equivalent semantics to this Python implementation::

def cond(pred, true_operand, true_fun, false_operand, false_fun):
if pred:
return true_fun(true_operand)
else:
return false_fun(false_operand)

Pred has to be a scalar type, collection types (list, tuple) are not supported
"""
if not (callable(true_fun) and callable(false_fun)):
raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
return _cond(pred,
lambda op: true_fun(op[0]),
lambda op: false_fun(op[1]),
(true_operand, false_operand))

def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects:
joined_effects = set()
for b in branches:
Expand Down
14 changes: 7 additions & 7 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lax.control_flow.common import (
_avals_short, _initial_style_jaxpr, _prune_zeros, _typecheck_param,
_avals_short, _prune_zeros, _typecheck_param,
_make_closed_jaxpr)
from jax._src.lax.other import logaddexp
from jax._src.pjit import auto_axes, PartitionSpec as P, reshard
Expand Down Expand Up @@ -281,9 +281,9 @@ def _create_jaxpr(init):
init_flat, init_tree = tree_flatten(init)
in_flat, in_tree = tree_flatten((init, xs))
carry_avals = tuple(_map(core.get_aval, init_flat))
open_jaxpr, out_tree, consts = pe.trace_to_jaxpr(
jaxpr, out_tree = pe.trace_to_jaxpr(
f, in_tree, (*carry_avals, *x_avals), debug_info=dbg_body)
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(open_jaxpr))
jaxpr, consts = pe.separate_consts(jaxpr)
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg_body, (*jaxpr.consts, *consts), in_flat)
out_tree_children = out_tree.children()
Expand Down Expand Up @@ -1712,11 +1712,11 @@ def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(core.get_aval, init_vals))
cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
cond_fun, in_tree, init_avals, cond_dbg)
cond_jaxpr, cond_tree = pe.trace_to_jaxpr(cond_fun, in_tree, init_avals, cond_dbg)
cond_jaxpr, cond_consts = pe.separate_consts(cond_jaxpr)
body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
body_fun, in_tree, init_avals, body_dbg)
body_jaxpr, body_tree = pe.trace_to_jaxpr(body_fun, in_tree, init_avals, body_dbg)
body_jaxpr, body_consts = pe.separate_consts(body_jaxpr)
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
raise TypeError(msg.format(cond_tree))
Expand Down
Loading
Loading