Hi, sorry if this question has already been answered (I am sure it must be), but somehow I could not find any resource. I can differentiate the function explicitly, but I get the following error:
There's no way to make `vmap(grad(fv))` work for this function, because `grad` requires a scalar-output function, and your `fv` function has a vector output that does not correspond to any vector input. If you're interested in the gradient of a single output, you could do something like this: `jax.grad(lambda x: fv(x)[0])(x)`. If you're interested in computing the gradients for all outputs at once, you could do something like this: `jax.jacrev(fv)(x)`.
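To make the two options concrete, here is a minimal sketch. The `fv` below is a hypothetical stand-in for the vector-output function from the question (the original definition isn't shown in the thread):

```python
import jax
import jax.numpy as jnp

# Hypothetical vector-output function standing in for the fv in the question:
# maps a 3-vector to a 2-vector, so grad(fv) alone is not defined.
def fv(x):
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.array([1.0, 2.0, 3.0])

# Option 1: gradient of a single output. Selecting one component makes the
# function scalar-valued, which is what jax.grad requires.
g0 = jax.grad(lambda x: fv(x)[0])(x)  # gradient of the first output w.r.t. x

# Option 2: gradients of all outputs at once, i.e. the full Jacobian.
# Row i of J is the gradient of output i.
J = jax.jacrev(fv)(x)  # shape (2, 3)
```

Note that `J[0]` equals `g0`, since the first row of the Jacobian is exactly the gradient of the first output.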