Differentiable `fori_loop`?
#9699
According to The Sharp Bits, `fori_loop` is not reverse-mode differentiable. However, this actually works:

```python
import jax
import jax.numpy as jnp

@jax.jit
def bar(x):
    def _body_fun(i, val):
        # accumulate x[i]**2 into the running sum
        return val + x[i]**2
    return jax.lax.fori_loop(0, x.shape[0], _body_fun, 0.0)

x = jnp.arange(4, dtype=jnp.float32)
print(jax.grad(bar)(x))  # prints [0. 2. 4. 6.]
```

To my understanding … Thank you in advance for an explanation.
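As a quick sanity check on the result above, the reverse-mode gradient from `jax.grad` can be compared with forward mode (`jax.jacfwd`), which `fori_loop` supports whether or not its bounds are concrete. This is a minimal sketch that reuses the question's `bar`; since `bar` computes `sum(x**2)`, both modes should return `2 * x`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def bar(x):
    # Same function as in the question: sum of x[i]**2 via fori_loop.
    def _body_fun(i, val):
        return val + x[i] ** 2
    return jax.lax.fori_loop(0, x.shape[0], _body_fun, 0.0)


x = jnp.arange(4, dtype=jnp.float32)
rev = jax.grad(bar)(x)    # reverse mode
fwd = jax.jacfwd(bar)(x)  # forward mode, one JVP per input element
expected = 2 * x          # analytic gradient of sum(x**2)
print(rev, fwd, expected)
```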
Thanks for the question! That part of the doc should be updated: while it's true that in general `fori_loop` is not reverse-mode differentiable, in the special case of concrete start and end points, we lower `fori_loop` to `scan` to allow for reverse-mode differentiation. Here's where it happens in the source: https://github.com/google/jax/blob/d5a1c64d135ae8519c61e15a2f32a75d8de36ab3/jax/_src/lax/control_flow.py#L205-L217
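To illustrate the concrete-versus-dynamic distinction, here is a small sketch (not JAX's source) that runs the same sum with a concrete and a traced stop value. `body`, `loop_sum`, and `scan_sum` are hypothetical helper names. `scan_sum` is a hand-written `lax.scan` equivalent; JAX's own lowering is arranged differently (it carries the loop index in the scan carry) but computes the same thing. Passing `n` to a `jax.jit`-ed function without `static_argnums` makes it a tracer, so `fori_loop` falls back to a `while_loop`. In that case `jax.grad` raises an error, while forward mode (`jax.jvp`) still works.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(x):
    # Hypothetical helper: the loop body from the question, closed over x.
    def _body_fun(i, val):
        return val + x[i] ** 2
    return _body_fun


def loop_sum(x, n):
    # Sum of x[i]**2 for i in [0, n) via fori_loop.
    return lax.fori_loop(0, n, body(x), jnp.zeros((), x.dtype))


def scan_sum(x, n):
    # Hand-written scan over the indices 0..n-1 (n must be a Python int).
    def step(val, i):
        return body(x)(i, val), None
    total, _ = lax.scan(step, jnp.zeros((), x.dtype), jnp.arange(n))
    return total


x = jnp.arange(4, dtype=jnp.float32)

# Concrete stop value: n is a Python int, so reverse mode works.
g_loop = jax.grad(loop_sum)(x, 4)
g_scan = jax.grad(scan_sum)(x, 4)

# Marking n as static keeps it concrete under jit as well.
g_static = jax.grad(jax.jit(loop_sum, static_argnums=1))(x, 4)

# Dynamic stop value: without static_argnums, jit traces n, so fori_loop
# cannot use scan and reverse-mode differentiation raises an error.
loop_sum_traced = jax.jit(loop_sum)
try:
    jax.grad(loop_sum_traced)(x, 4)
    reverse_failed = False
except Exception as err:
    reverse_failed = True
    print(type(err).__name__, err)

# Forward mode still works with the traced stop value.
_, tangent = jax.jvp(lambda v: loop_sum_traced(v, 4), (x,), (jnp.ones_like(x),))

print(g_loop, g_scan, g_static, reverse_failed, tangent)
```

The usual fix when reverse mode is needed is to keep the trip count static, either by passing Python integers directly or by marking the bound with `static_argnums`, as in `g_static` above.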