Status and workarounds for slow lax.while_loop on GPU? #32791
Many an issue has been opened in relation to the slow execution of lax.while_loop on GPU. I use JAX for scientific computing and frequently can't get around the use of a while loop, not to mention the implicit use of lax.while_loop in differential equation solvers like Diffrax. Wondering if anyone has any home-cooked workarounds for improving runtime in specific cases, or if there are any updates from the JAX team about this issue.
Replies: 1 comment 1 reply
Hi - thanks for the question. I believe this is still an issue, and basically boils down to the fact that each loop iteration is a new dispatch that has some overhead associated with it. With that in mind, the way to avoid the slowdown is to make sure that each loop iteration does enough work that that dispatch overhead becomes less important.
One way you could approach this is to manually batch your `while_loop` content so that each loop iteration does multiple iterations of work, potentially within a `vmap` if possible. Whether or not this is doable depends on the details of the computation you're doing, so unfortunately there's not any real automated way to approach this kind of code change.
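
To make the manual batching idea concrete, here is a minimal sketch under stated assumptions, not a recommended or official recipe. The fixed-point problem `x = cos(x)`, the tolerance `TOL`, the chunk size `K`, and all function names are hypothetical and chosen only for illustration. Each outer `lax.while_loop` iteration runs `K` inner steps (unrolled at trace time), and each inner step is masked with `jnp.where` so the state stops changing once the stopping condition is met. The cost is up to `K - 1` wasted masked steps at the end, and whether the trade pays off on a given GPU has to be measured.

```python
import jax
import jax.numpy as jnp
from jax import lax

TOL = 1e-6  # hypothetical convergence tolerance
K = 8       # hypothetical number of inner steps per outer loop iteration


def not_converged(state):
    # state = (current iterate, previous iterate, step count)
    x, prev, _ = state
    return jnp.abs(x - prev) > TOL


def one_step(state):
    # Illustrative fixed-point update x <- cos(x); stands in for real work.
    x, _, n = state
    return jnp.cos(x), x, n + 1


def initial_state(x0):
    x0 = jnp.asarray(x0, dtype=jnp.float32)
    return jnp.cos(x0), x0, jnp.int32(1)


@jax.jit
def solve_plain(x0):
    # Baseline: one unit of work per while_loop iteration.
    return lax.while_loop(not_converged, one_step, initial_state(x0))


def chunk(state):
    # K steps per while_loop iteration. The Python loop is unrolled when
    # traced; steps taken after convergence are masked out so the result
    # matches the baseline.
    for _ in range(K):
        active = not_converged(state)
        proposed = one_step(state)
        state = jax.tree_util.tree_map(
            lambda new, old: jnp.where(active, new, old), proposed, state
        )
    return state


@jax.jit
def solve_chunked(x0):
    return lax.while_loop(not_converged, chunk, initial_state(x0))
```

Calling `solve_plain(1.0)` and `solve_chunked(1.0)` should return the same converged iterate; the chunked version just evaluates the loop condition roughly `K` times less often. If the loop body is expensive or the stopping point is very uneven across problems, the masked extra steps may cost more than the overhead they save.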