Hi, thanks for the question. I believe this is still an issue, and it basically boils down to the fact that each loop iteration is a new dispatch with some fixed overhead. With that in mind, the way to avoid the slowdown is to make sure each loop iteration does enough work that the dispatch overhead becomes a small fraction of the total.
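A rough way to see this overhead for yourself is to time a `lax.while_loop` whose body does almost nothing and check how the time per iteration behaves as the iteration count grows. This is a minimal sketch under assumptions: `step`, `run`, and `time_run` are hypothetical names, `step` stands in for a deliberately tiny unit of work, and the actual numbers will depend on your backend (CPU vs. GPU), hardware, and JAX version.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # Hypothetical "tiny" unit of work: far too little to amortize
    # any fixed per-iteration cost.
    return x * 0.999 + 1.0


@jax.jit
def run(x, n_iters):
    # Carry is (iteration counter, state); one step of work per iteration.
    def cond(carry):
        i, _ = carry
        return i < n_iters

    def body(carry):
        i, x = carry
        return i + 1, step(x)

    _, x = lax.while_loop(cond, body, (0, x))
    return x


def time_run(n_iters, x):
    # First call compiles; n_iters is traced, so later calls reuse the compilation.
    run(x, n_iters).block_until_ready()
    start = time.perf_counter()
    run(x, n_iters).block_until_ready()
    return time.perf_counter() - start


if __name__ == "__main__":
    x = jnp.ones(8)
    for n in (1_000, 10_000, 100_000):
        t = time_run(n, x)
        print(f"{n:>7} iterations: {t * 1e3:8.2f} ms  ({t / n * 1e6:.2f} us/iter)")
```

If the per-iteration cost stays roughly flat even though `step` is trivial, the loop is dominated by overhead rather than by the work itself.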

One way to approach this is to manually batch the contents of your `while_loop` body so that each loop iteration performs several steps of work, potentially within a `vmap` if possible. Whether this is doable depends on the details of the computation you're doing, so unfortunately there's no automated way to make this kind of code change.
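A minimal sketch of both ideas, under assumptions: `step` is a hypothetical stand-in for your real loop body, and `run_one_per_iter`, `run_batched`, `map_chunked`, `k`, and `chunk` are illustrative names, not part of any JAX API. `run_batched` does up to `k` dependent steps per `while_loop` iteration (unrolled at trace time), masking the extra steps with `jnp.where` so the result matches the one-step-per-iteration baseline. `map_chunked` covers the case where the per-iteration work items are independent, processing a chunk of them per iteration with `jax.vmap`.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # Hypothetical unit of work; stands in for your real loop body.
    return x * 0.999 + 1.0


@jax.jit
def run_one_per_iter(x, n_iters):
    # Baseline: one step of work per while_loop iteration.
    def body(carry):
        i, x = carry
        return i + 1, step(x)

    _, x = lax.while_loop(lambda c: c[0] < n_iters, body, (0, x))
    return x


@functools.partial(jax.jit, static_argnames="k")
def run_batched(x, n_iters, k=16):
    # Manually batched: each while_loop iteration performs up to k steps,
    # so the loop runs about n_iters / k times. The inner Python loop is
    # unrolled at trace time, which is why k must be a static Python int.
    def body(carry):
        i, x = carry
        for j in range(k):
            # Skip steps past n_iters so the result matches the baseline
            # even when n_iters is not a multiple of k.
            x = jnp.where(i + j < n_iters, step(x), x)
        return i + k, x

    _, x = lax.while_loop(lambda c: c[0] < n_iters, body, (0, x))
    return x


@functools.partial(jax.jit, static_argnames="chunk")
def map_chunked(xs, chunk=64):
    # For independent work items (here: rows of xs), handle `chunk` rows per
    # loop iteration with vmap instead of one row per iteration.
    # Assumes xs.shape[0] is a multiple of chunk (pad the input otherwise).
    n_chunks = xs.shape[0] // chunk

    def body(i, out):
        rows = lax.dynamic_slice_in_dim(xs, i * chunk, chunk, axis=0)
        return lax.dynamic_update_slice_in_dim(
            out, jax.vmap(step)(rows), i * chunk, axis=0
        )

    return lax.fori_loop(0, n_chunks, body, jnp.zeros_like(xs))
```

One trade-off in `run_batched`: the final loop iteration still evaluates `step` `k` times and discards up to `k - 1` of the results, and larger `k` means a larger unrolled body to compile, so `k` is worth tuning rather than maximizing.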

Answer selected by cgiovanetti