Skip to content

Conversation

@dougalm
Copy link
Collaborator

@dougalm dougalm commented Dec 3, 2025

This reduces the stack depth in cond and exposes fewer internals.

@dougalm dougalm added the pull ready Ready for copybara import and testing label Dec 3, 2025
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @dougalm, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the internal JAXPR tracing and constant handling mechanisms within JAX, particularly impacting control flow primitives like lax.cond. The changes aim to simplify the JAXPR representation by having trace_to_jaxpr directly return ClosedJaxpr objects, thereby reducing the number of intermediate steps and the overall stack depth during tracing. As a consequence of these internal improvements, the lax.cond API has been streamlined, deprecating its older form in favor of a more direct function-first signature. This work contributes to a more efficient and cleaner JAX core.

Highlights

  • Simplified JAXPR Tracing: The internal trace_to_jaxpr function now directly produces ClosedJaxpr objects, streamlining the process of capturing computation graphs and their associated constants. A new separate_consts utility is introduced for explicit constant extraction when needed.
  • Deprecation of Old lax.cond Form: The lax.cond API has been updated to a more consistent signature, requiring functions to be passed before their operands. The older, more complex wrapper that handled the previous argument style has been removed.
  • Reduced Stack Depth: By removing several intermediate tracing layers and simplifying the JAXPR construction, the pull request achieves a reduction in stack depth, which can improve performance and debugging experiences. This is reflected in updated traceback tests.
  • Internal API Cleanup: Several internal utility functions related to JAXPR handling (e.g., _initial_style_open_jaxpr, _initial_style_jaxpr) have been removed, leading to a cleaner and more maintainable codebase.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request is a nice refactoring that deprecates an old form of cond and simplifies some of the internal tracing logic. The changes are consistent across multiple files and the test suite is updated accordingly. I've found a few critical bugs in the new helper functions in jax/_src/lax/control_flow/common.py related to incorrect attribute access on ClosedJaxpr objects, which would lead to AttributeErrors. I also spotted an unpacking error in jax/_src/lax/control_flow/solves.py that would cause a ValueError at runtime. I've provided detailed comments and suggestions to fix these issues. Besides these, the refactoring looks solid.

Comment on lines +67 to 77
def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int,
left: tuple[core.AvalQDD, ...],
right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr:
def make_var(aq):
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)]
effs = pe._renumber_effects([*constvars, *jaxpr.invars],
[*jaxpr.constvars, *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, effects=effs)
config.enable_checks.value and core.check_jaxpr(jaxpr)
invars = [*map(make_var, left), *jaxpr.invars[:num_consts],
*map(make_var, right), *jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
return jaxpr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The jaxpr argument is a ClosedJaxpr, but you are trying to access attributes like invars and effects which belong to the inner Jaxpr object. This will raise an AttributeError. You should access them via jaxpr.jaxpr.

Suggested change
def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int,
left: tuple[core.AvalQDD, ...],
right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr:
def make_var(aq):
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)]
effs = pe._renumber_effects([*constvars, *jaxpr.invars],
[*jaxpr.constvars, *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, effects=effs)
config.enable_checks.value and core.check_jaxpr(jaxpr)
invars = [*map(make_var, left), *jaxpr.invars[:num_consts],
*map(make_var, right), *jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
return jaxpr
def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int,
left: tuple[core.AvalQDD, ...],
right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr:
def make_var(aq):
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
invars = [*map(make_var, left), *jaxpr.jaxpr.invars[:num_consts],
*map(make_var, right), *jaxpr.jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars, jaxpr.jaxpr.invars, jaxpr.jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
return jaxpr

Comment on lines 79 to 95
@weakref_lru_cache
def _dedup_consts(jaxpr, const_ids):
def _dedup_consts(jaxpr, num_consts, const_ids):
newvars = {}
canonicalize = {v: newvars.setdefault(constid, v)
for constid, v in zip(const_ids, jaxpr.constvars)}
for constid, v in zip(const_ids, jaxpr.invars[:num_consts])}
eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var)
else x for x in e.invars]) for e in jaxpr.eqns]
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
for x in jaxpr.outvars]
constvars = list(newvars.values())
effs = pe._renumber_effects(
[*constvars, *jaxpr.invars],
[*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars,
effects=effs)
invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars,
[*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]],
jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars,
effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to _pad_constvars, the jaxpr argument is a ClosedJaxpr, but you are accessing attributes of the inner Jaxpr object directly (e.g., jaxpr.invars, jaxpr.eqns). This will raise an AttributeError. You should use jaxpr.jaxpr to access them. Also, check_jaxpr at the end should be called on jaxpr.jaxpr. It would also be good to add type hints to this function for clarity.

Suggested change
@weakref_lru_cache
def _dedup_consts(jaxpr, const_ids):
def _dedup_consts(jaxpr, num_consts, const_ids):
newvars = {}
canonicalize = {v: newvars.setdefault(constid, v)
for constid, v in zip(const_ids, jaxpr.constvars)}
for constid, v in zip(const_ids, jaxpr.invars[:num_consts])}
eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var)
else x for x in e.invars]) for e in jaxpr.eqns]
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
for x in jaxpr.outvars]
constvars = list(newvars.values())
effs = pe._renumber_effects(
[*constvars, *jaxpr.invars],
[*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects)
jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars,
effects=effs)
invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars,
[*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]],
jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars,
effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr
@weakref_lru_cache
def _dedup_consts(jaxpr: core.ClosedJaxpr, num_consts: int, const_ids: tuple) -> core.ClosedJaxpr:
newvars = {}
canonicalize = {v: newvars.setdefault(constid, v)
for constid, v in zip(const_ids, jaxpr.jaxpr.invars[:num_consts])}
eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var)
else x for x in e.invars]) for e in jaxpr.jaxpr.eqns]
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
for x in jaxpr.jaxpr.outvars]
invars = [*list(newvars.values()), *jaxpr.jaxpr.invars[num_consts:]]
effs = pe._renumber_effects(invars,
[*map(canonicalize.get, jaxpr.jaxpr.invars[:num_consts]), *jaxpr.jaxpr.invars[num_consts:]],
jaxpr.jaxpr.effects)
jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars,
effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
return jaxpr

Comment on lines 108 to 109
solve_jaxpr, solution_tree, solve_consts = pe.trace_to_jaxpr(
partial(solve, f), in_args_tree, guess_avals, solve_debug)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

pe.trace_to_jaxpr returns a 2-tuple (ClosedJaxpr, PyTreeDef). You are trying to unpack it into three variables, which will cause a ValueError.

Suggested change
solve_jaxpr, solution_tree, solve_consts = pe.trace_to_jaxpr(
partial(solve, f), in_args_tree, guess_avals, solve_debug)
solve_jaxpr, solution_tree = pe.trace_to_jaxpr(
partial(solve, f), in_args_tree, guess_avals, solve_debug)

# TODO(dougalm): this seems way too complicated. Why not allow different consts for each
# branch of a switch?
def _merge_common_consts(
jaxprs: Sequence[core.Jaxpr],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type hint for jaxprs is Sequence[core.Jaxpr], but based on its usage (it's passed to _pad_constvars which expects a ClosedJaxpr) and the call site in conditionals.py, it should be Sequence[core.ClosedJaxpr].

Suggested change
jaxprs: Sequence[core.Jaxpr],
jaxprs: Sequence[core.ClosedJaxpr],

@dougalm dougalm force-pushed the remove-tracing-wrappers branch from 8b7db85 to ec3ecde Compare December 3, 2025 23:17
This reduces the stack depth in `cond` and exposes fewer internals.
@dougalm dougalm force-pushed the remove-tracing-wrappers branch from ec3ecde to 4cbcbe7 Compare December 3, 2025 23:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant