|
26 | 26 | from mujoco.mjx._src.types import DataJAX |
27 | 27 | from mujoco.mjx._src.types import DisableBit |
28 | 28 | from mujoco.mjx._src.types import EqType |
| 29 | +from mujoco.mjx._src.types import Impl |
29 | 30 | from mujoco.mjx._src.types import JointType |
30 | 31 | from mujoco.mjx._src.types import Model |
31 | 32 | from mujoco.mjx._src.types import ModelJAX |
32 | 33 | from mujoco.mjx._src.types import ObjType |
33 | 34 | from mujoco.mjx._src.types import TrnType |
34 | 35 | from mujoco.mjx._src.types import WrapType |
35 | 36 | # pylint: enable=g-importing-member |
| 37 | +import mujoco.mjx.warp as mjxw |
36 | 38 | import numpy as np |
37 | 39 |
|
38 | 40 |
|
39 | 41 | def kinematics(m: Model, d: Data) -> Data: |
40 | 42 | """Converts position/velocity from generalized coordinates to maximal.""" |
| 43 | + if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED: |
| 44 | + from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error |
| 45 | + return mjxw_smooth.kinematics(m, d) |
| 46 | + |
41 | 47 | def fn(carry, jnt_typs, jnt_pos, jnt_axis, qpos, qpos0, pos, quat): |
42 | 48 | # calculate joint anchors, axes, body pos and quat in global frame |
43 | 49 | # also normalize qpos while we're at it |
@@ -844,6 +850,10 @@ def _forward(carry, cfrc_ext, cinert, cvel, body_dofadr, body_dofnum): |
844 | 850 |
|
845 | 851 | def tendon(m: Model, d: Data) -> Data: |
846 | 852 | """Computes tendon lengths and moments.""" |
| 853 | + if m.impl == Impl.WARP and d.impl == Impl.WARP and mjxw.WARP_INSTALLED: |
| 854 | + from mujoco.mjx.warp import smooth as mjxw_smooth # pylint: disable=g-import-not-at-top # pytype: disable=import-error |
| 855 | + return mjxw_smooth.tendon(m, d) |
| 856 | + |
847 | 857 | if not isinstance(m._impl, ModelJAX) or not isinstance(d._impl, DataJAX): |
848 | 858 | raise ValueError('tendon requires JAX backend implementation.') |
849 | 859 |
|
@@ -1091,7 +1101,7 @@ def _distance(p0, p1): |
1091 | 1101 |
|
1092 | 1102 | # assemble length and moment |
1093 | 1103 | ten_length = ( |
1094 | | - jp.zeros_like(d._impl.ten_length).at[tendon_id_jnt].set(length_jnt) |
| 1104 | + jp.zeros_like(d.ten_length).at[tendon_id_jnt].set(length_jnt) |
1095 | 1105 | ) |
1096 | 1106 | ten_length = ten_length.at[tendon_id_site].add(length_site) |
1097 | 1107 | ten_length = ten_length.at[tendon_id_geom].add(length_geom) |
@@ -1161,7 +1171,7 @@ def _distance(p0, p1): |
1161 | 1171 | ).reshape((m.nwrap, 2)) |
1162 | 1172 |
|
1163 | 1173 | return d.tree_replace({ |
1164 | | - '_impl.ten_length': ten_length, |
| 1174 | + 'ten_length': ten_length, |
1165 | 1175 | '_impl.ten_J': ten_moment, |
1166 | 1176 | '_impl.ten_wrapadr': jp.array(ten_wrapadr, dtype=int), |
1167 | 1177 | '_impl.ten_wrapnum': jp.array(ten_wrapnum, dtype=int), |
@@ -1263,7 +1273,7 @@ def fn( |
1263 | 1273 | wrench = jp.concatenate((frame_xmat @ gear[:3], frame_xmat @ gear[3:])) |
1264 | 1274 | moment = jac @ wrench |
1265 | 1275 | elif trntype == TrnType.TENDON: |
1266 | | - length = d._impl.ten_length[trnid[0]] * gear[:1] |
| 1276 | + length = d.ten_length[trnid[0]] * gear[:1] |
1267 | 1277 | moment = d._impl.ten_J[trnid[0]] * gear[0] |
1268 | 1278 | else: |
1269 | 1279 | raise RuntimeError(f'unrecognized trntype: {TrnType(trntype)}') |
|
0 commit comments