
MJX on macOS is >100x slower on jax>=0.7 than on jax<0.7 #2957

@hartikainen

Description


Intro

Hi!

My setup

$ uv pip list | grep "jax\|mjx\|mujoco"
jax                 0.8.1
jaxlib              0.8.1
mujoco              3.3.7
mujoco-mjx          3.3.7
$ python -c "import platform; print(f'{platform.system()=}, {platform.release()=}, {platform.machine()=}')"
platform.system()='Darwin', platform.release()='24.6.0', platform.machine()='arm64'
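If it helps when comparing setups, the same version and platform details can be collected from a single Python snippet. This is a minimal stdlib-only sketch; the helper name `environment_report` is hypothetical and not part of JAX, MuJoCo, or MJX.

```python
from importlib import metadata
import platform

# Distribution names to report; these mirror the `uv pip list | grep` filter above.
PACKAGES = ("jax", "jaxlib", "mujoco", "mujoco-mjx")


def environment_report(packages=PACKAGES):
    """Return installed versions (None if not installed) plus basic platform info."""
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    report["platform"] = f"{platform.system()} {platform.release()} {platform.machine()}"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key:12s} {value}")
```

JAX also ships `jax.print_environment_info()`, which prints a broader summary of the JAX installation if more detail is needed.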

What's happening? What did you expect?

Many of our models run >100x slower on jax>=0.7 than on jax<0.7 when using MJX on macOS. The branch main...hartikainen:mujoco:slow-jax includes a minimal model that demonstrates this.

When I run it, I get:

$ uv pip list | grep "jax\|mjx\|mujoco"
jax                 0.6.2
jaxlib              0.6.2
mujoco              3.3.7
mujoco-mjx          3.3.7
$ python ./slow_jax.py
Running benchmark with 9 bodies, 17 geoms, 8 sites on cpu
Compiled.
Time for 100 iterations: 0.0178s
Avg time: 0.1784ms
$ uv pip install -U "jax"
Resolved 6 packages in 110ms
Prepared 2 packages in 0.17ms
Uninstalled 2 packages in 51ms
Installed 2 packages in 7ms
 - jax==0.6.2
 + jax==0.8.1
 - jaxlib==0.6.2
 + jaxlib==0.8.1
$ python ./slow_jax.py
Running benchmark with 9 bodies, 17 geoms, 8 sites on cpu
Compiled.
Time for 100 iterations: 40.7880s
Avg time: 407.8800ms

Note the difference in average time per call: 0.1784 ms on jax 0.6.2 versus 407.88 ms on jax 0.8.1, a slowdown of roughly 2,300x.
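The slowdown factor follows directly from the two "Avg time" lines printed above:

```python
# Per-call averages reported by slow_jax.py above, in milliseconds.
avg_ms = {"0.6.2": 0.1784, "0.8.1": 407.8800}

slowdown = avg_ms["0.8.1"] / avg_ms["0.6.2"]
print(f"jax 0.8.1 is {slowdown:,.0f}x slower than jax 0.6.2")  # prints "jax 0.8.1 is 2,286x slower than jax 0.6.2"
```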

Steps for reproduction

Run the script below.

Code required for reproduction

import time
import jax
import mujoco
from mujoco import mjx
import numpy as np

# Create a chain of bodies with many geoms and sites
xml = "<mujoco>\n"
xml += "  <worldbody>\n"
xml += '    <body name="0" pos="0 0 0">\n'
xml += '      <joint type="free"/>\n'
xml += '      <geom size="0.1"/>\n'

depth = 9  # number of bodies: 1 free joint + 8 hinge joints
geoms_per_body = 1
sites_per_body = 1

for i in range(1, depth):
    xml += f'      <body name="{i}" pos="0.1 0 0">\n'
    xml += '        <joint type="hinge"/>\n'
    xml += '        <geom size="0.1"/>\n'
    for j in range(geoms_per_body):
        xml += f'        <geom size="0.01" pos="0 {j*0.01} 0"/>\n'
    for j in range(sites_per_body):
        xml += f'        <site name="s_{i}_{j}" pos="0 0 {j*0.01}"/>\n'

for i in range(depth):
    xml += "    </body>"

xml += "\n  </worldbody>\n"
xml += "</mujoco>\n"

m = mujoco.MjModel.from_xml_string(xml)
d = mujoco.MjData(m)
mx = mjx.put_model(m)
dx = mjx.put_data(m, d)

print(f"Running benchmark with {depth} bodies, {m.ngeom} geoms, {m.nsite} sites on {jax.default_backend()}")

# Compile
kinematics_jit = jax.jit(mjx.kinematics)
dx = kinematics_jit(mx, dx)
dx.qpos.block_until_ready()
print("Compiled.")

# Benchmark (the call above already compiled kinematics_jit)
N = 100
start = time.perf_counter()
for _ in range(N):
    dx = kinematics_jit(mx, dx)
    dx.qpos.block_until_ready()
end = time.perf_counter()

print(f"Time for {N} iterations: {end - start:.4f}s")
print(f"Avg time: {(end - start)/N*1000:.4f}ms")
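The model sizes printed by the script ("9 bodies, 17 geoms, 8 sites") can be checked by hand. The sketch below is a hypothetical helper that recomputes them from the generator parameters. It assumes the standard MuJoCo convention that a free joint contributes 7 `qpos` entries and 6 DoFs and a hinge contributes 1 of each, so this model has 9 joints, `nq = 15`, and `nv = 14`.

```python
def chain_model_size(depth, geoms_per_body=1, sites_per_body=1):
    """Recompute the sizes of the chain generated by the reproduction script.

    Body 0 has a free joint and one geom; each of the remaining depth - 1
    bodies has one hinge joint, one base geom, geoms_per_body extra geoms,
    and sites_per_body sites.
    """
    extra = depth - 1
    return {
        "nbody": depth,                           # excludes MuJoCo's world body
        "ngeom": 1 + extra * (1 + geoms_per_body),
        "nsite": extra * sites_per_body,
        "njnt": 1 + extra,                        # one free joint + hinges
        "nq": 7 + extra,                          # free joint: 3 position + 4 quaternion
        "nv": 6 + extra,                          # free joint: 6 DoF
    }


print(chain_model_size(9))
```

Note that MuJoCo's own `m.nbody` also counts the world body, so it would report one more body than `depth`.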
