Intro
Hi!
My setup
```
$ uv pip list | grep "jax\|mjx\|mujoco"
jax 0.8.1
jaxlib 0.8.1
mujoco 3.3.7
mujoco-mjx 3.3.7
$ python -c "import platform; print(f'{platform.system()=}, {platform.release()=}, {platform.machine()=}')"
platform.system()='Darwin', platform.release()='24.6.0', platform.machine()='arm64'
```
What's happening? What did you expect?
Many of our models run more than 100x slower with jax>=0.7 than with jax<0.7 when using MJX on macOS. The comparison main...hartikainen:mujoco:slow-jax includes a minimal model that demonstrates this.
When run, I get:
```
$ uv pip list | grep "jax\|mjx\|mujoco"
jax 0.6.2
jaxlib 0.6.2
mujoco 3.3.7
mujoco-mjx 3.3.7
$ python ./slow_jax.py
Running benchmark with 9 bodies, 17 geoms, 8 sites on cpu
Compiled.
Time for 100 iterations: 0.0178s
Avg time: 0.1784ms
$ uv pip install -U "jax"
Resolved 6 packages in 110ms
Prepared 2 packages in 0.17ms
Uninstalled 2 packages in 51ms
Installed 2 packages in 7ms
 - jax==0.6.2
 + jax==0.8.1
 - jaxlib==0.6.2
 + jaxlib==0.8.1
$ python ./slow_jax.py
Running benchmark with 9 bodies, 17 geoms, 8 sites on cpu
Compiled.
Time for 100 iterations: 40.7880s
Avg time: 407.8800ms
```

Note the difference in average time per call: 0.18 ms with jax 0.6.2 versus 408 ms with jax 0.8.1, roughly a 2,300x slowdown.
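As a quick sanity check on these numbers, the sketch below parses the `Time for ... iterations` and `Avg time` lines from both runs, confirms that each average equals total time divided by iteration count, and recomputes the slowdown ratio. The `parse_run` helper is hypothetical and assumes only the exact log format shown above.

```python
import re

# Output lines copied from the two runs above (jax 0.6.2 first, then jax 0.8.1).
RUNS = {
    "jax 0.6.2": "Time for 100 iterations: 0.0178s\nAvg time: 0.1784ms",
    "jax 0.8.1": "Time for 100 iterations: 40.7880s\nAvg time: 407.8800ms",
}


def parse_run(log: str) -> dict:
    """Hypothetical helper: extract iteration count, total seconds and average ms."""
    total = re.search(r"Time for (\d+) iterations: ([\d.]+)s", log)
    avg = re.search(r"Avg time: ([\d.]+)ms", log)
    return {
        "iterations": int(total.group(1)),
        "total_s": float(total.group(2)),
        "avg_ms": float(avg.group(1)),
    }


fast = parse_run(RUNS["jax 0.6.2"])
slow = parse_run(RUNS["jax 0.8.1"])

# The reported average should match total / iterations, up to the
# 4-decimal rounding of the printed total time.
for run in (fast, slow):
    recomputed_ms = run["total_s"] / run["iterations"] * 1000
    assert abs(recomputed_ms - run["avg_ms"]) < 0.01, run

slowdown = slow["avg_ms"] / fast["avg_ms"]
print(f"slowdown: {slowdown:.0f}x")  # prints "slowdown: 2286x"
```

Both runs report 100 iterations, so the per-call averages are directly comparable.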
Steps for reproduction
Run the script below, first with jax 0.6.2 and then with jax 0.8.1.
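To exercise both versions the same way, the two runs can be scripted. This is a minimal sketch that assumes `uv` is on the `PATH`, that the script is saved as `./slow_jax.py`, and that pinning `jax` and `jaxlib` with `uv pip install` is enough to switch versions in the active environment. `build_commands` and `run_all` are hypothetical helpers, not part of MJX. By default the sketch only prints the commands.

```python
import shlex
import subprocess

# Versions taken from the report above; the second run is the slow one on macOS/arm64.
JAX_VERSIONS = ["0.6.2", "0.8.1"]


def build_commands(versions, script="./slow_jax.py"):
    """Return, for each version: pin jax/jaxlib, print the active jax version, run the script."""
    commands = []
    for v in versions:
        commands.append(["uv", "pip", "install", f"jax=={v}", f"jaxlib=={v}"])
        commands.append(["python", "-c", "import jax; print(jax.__version__)"])
        commands.append(["python", script])
    return commands


def run_all(versions, script="./slow_jax.py", dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_commands(versions, script):
        print("$", shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(JAX_VERSIONS)  # pass dry_run=False to actually execute the commands
```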
Code required for reproduction
```python
import time

import jax
import mujoco
from mujoco import mjx
import numpy as np

# Create a chain of bodies with many geoms and sites
xml = "<mujoco>\n"
xml += "  <worldbody>\n"
xml += '    <body name="0" pos="0 0 0">\n'
xml += '      <joint type="free"/>\n'
xml += '      <geom size="0.1"/>\n'
depth = 9  # 9 bodies: 1 free joint + 8 hinge joints
geoms_per_body = 1
sites_per_body = 1
for i in range(1, depth):
    # Each body is nested inside the previous one; the closing tags are added below.
    xml += f'    <body name="{i}" pos="0.1 0 0">\n'
    xml += '      <joint type="hinge"/>\n'
    xml += '      <geom size="0.1"/>\n'
    for j in range(geoms_per_body):
        xml += f'      <geom size="0.01" pos="0 {j*0.01} 0"/>\n'
    for j in range(sites_per_body):
        xml += f'      <site name="s_{i}_{j}" pos="0 0 {j*0.01}"/>\n'
for i in range(depth):
    xml += "    </body>"
xml += "\n  </worldbody>\n"
xml += "</mujoco>\n"

m = mujoco.MjModel.from_xml_string(xml)
d = mujoco.MjData(m)
mx = mjx.put_model(m)
dx = mjx.put_data(m, d)
print(f"Running benchmark with {depth} bodies, {m.ngeom} geoms, {m.nsite} sites on {jax.default_backend()}")

# Compile (the first call traces and compiles, so it is excluded from the timing)
kinematics_jit = jax.jit(mjx.kinematics)
dx = kinematics_jit(mx, dx)
dx.qpos.block_until_ready()
print("Compiled.")

# Benchmark
start = time.time()
N = 100  # the timings above were recorded with N = 100; lower it for a quicker run on jax 0.8.1
for _ in range(N):
    dx = kinematics_jit(mx, dx)
# Wait for the last result before stopping the clock
dx.qpos.block_until_ready()
end = time.time()
print(f"Time for {N} iterations: {end - start:.4f}s")
print(f"Avg time: {(end - start)/N*1000:.4f}ms")
```

Confirmations
- I searched the latest documentation thoroughly before posting.
- I searched previous Issues and Discussions, I am certain this has not been raised before.
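If it helps with triage, the benchmark loop in the reproduction script can be extended to report the per-call distribution instead of only the mean, which separates a uniformly slow call from occasional stalls. This is a sketch: `time_calls` is a hypothetical helper, not part of JAX or MJX. Unlike the script above, it reuses the same inputs on every call instead of feeding each output back in.

```python
import statistics
import time
from typing import Any, Callable


def time_calls(fn: Callable[..., Any], *args: Any, n: int = 100,
               block: Callable[[Any], Any] = lambda x: x) -> dict:
    """Time n calls of fn(*args) after one untimed warm-up call.

    `block` should wait for asynchronous results (for a jitted JAX function,
    e.g. jax.block_until_ready); the default does nothing.
    """
    block(fn(*args))  # warm-up: for jitted functions this triggers tracing/compilation
    samples_ms = []
    for _ in range(n):
        start = time.perf_counter()
        block(fn(*args))
        samples_ms.append((time.perf_counter() - start) * 1000)
    return {
        "n": n,
        "min_ms": min(samples_ms),
        "median_ms": statistics.median(samples_ms),
        "mean_ms": statistics.fmean(samples_ms),
    }


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without JAX installed.
    print(time_calls(sum, list(range(1000)), n=10))
```

With the reproduction script, a call would look something like `time_calls(kinematics_jit, mx, dx, n=100, block=jax.block_until_ready)`.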