Indexing a Jax Tracer? #32408
Hi all! I am trying to index a JAX tracer. Initially I looked into converting the tracer to a regular JAX array, and then I came across this FAQ entry: https://docs.jax.dev/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array. The main problem I am running into is that I need this to happen:

And when I checked with

The function works without any issues outside of `grad`, but under `grad` it runs into this tracer problem, which is understandable. For more context, I am trying to implement a graph edge kernel in GPJax and am running into indexing issues. Any help is much appreciated!
From the error message, it sounds like The fix here would be to avoid indexing a one-dimensional array or tracer with two indices.
This probably indicates that you're applying a batching transformation (such as `vmap`) along with the gradient. We'll need more information in order to help: could you share a minimal reproducible example that shows this behavior?
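A minimal sketch of the suspected failure mode, assuming a hypothetical function `edge_feature` written for a full 2D `(n, d)` array (the original code is not included in the thread). Inside `jax.vmap` the function receives one 1D row at a time, so `x[:, 0]` becomes a two-index lookup on a 1D tracer and raises `IndexError`, with or without `jax.grad` applied on top. The helpers `batched_loss`, `shape_seen_inside_vmap`, and `raises_index_error` are illustrative names, not part of JAX.

```python
import jax
import jax.numpy as jnp


def edge_feature(x):
    # Hypothetical function written as if x were the full 2D (n, d) array.
    return jnp.sum(x[:, 0] ** 2)


def batched_loss(x):
    # vmap maps edge_feature over rows, so each call sees a 1D row.
    return jax.vmap(edge_feature)(x).sum()


def shape_seen_inside_vmap(x):
    """Return the argument shape as seen from inside jax.vmap."""
    seen = []

    def probe(row):
        seen.append(row.shape)  # recorded once, at trace time
        return row.sum()

    jax.vmap(probe)(x)
    return seen[0]


def raises_index_error(fn, *args):
    """Return True if calling fn(*args) raises IndexError."""
    try:
        fn(*args)
    except IndexError:
        return True
    return False


x = jnp.ones((4, 3))
print(edge_feature(x))                                # fine: x is 2D here
print(shape_seen_inside_vmap(x))                      # (3,)
print(raises_index_error(batched_loss, x))            # True
print(raises_index_error(jax.grad(batched_loss), x))  # True
```

The error comes from the shape the function sees under `vmap`, not from `grad` itself; `grad` only surfaces it because the batched function is what gets traced.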
Thank you for the explanation. This is making tonnes of sense now.
Thanks for the repro! This is behaving as expected.

When you wrap a function in `vmap`, the function effectively operates on a single element of the batched input. Your input `x` is two-dimensional, so each batch element is logically a one-dimensional vector. When you write `x[:, 0]` within your function, you are attempting a two-dimensional indexing operation on a one-dimensional array, which leads to the error you're seeing.

To fix this, you'll need to either (1) not wrap your function in `vmap`, or (2) rewrite your function so that it accepts one-dimensional inputs (a single row of `x`).

You can read more about `vmap` and automatic vectorization at https://docs.jax.dev/en/latest/automatic-vectorization.html.
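A sketch of the two suggested fixes, reusing a hypothetical `edge_feature` that sums the squared first column (an assumption for illustration, not the poster's GPJax kernel). Option (1) calls the 2D function directly without `vmap`; option (2) rewrites it as a hypothetical `edge_feature_row`, which takes a single 1D row and indexes it with one index before `jax.vmap` maps it over rows. Both should give the same value and the same gradient.

```python
import jax
import jax.numpy as jnp


def edge_feature(x):
    # Hypothetical function written for the full 2D (n, d) array.
    return jnp.sum(x[:, 0] ** 2)


def edge_feature_row(row):
    # Same computation, rewritten for a single 1D row of shape (d,).
    return row[0] ** 2


def loss_no_vmap(x):
    # Option (1): skip vmap and let the function see the whole 2D array.
    return edge_feature(x)


def loss_with_vmap(x):
    # Option (2): vmap a function that indexes a 1D row with one index.
    return jax.vmap(edge_feature_row)(x).sum()


x = jnp.arange(12.0).reshape(4, 3)
g1 = jax.grad(loss_no_vmap)(x)
g2 = jax.grad(loss_with_vmap)(x)
print(jnp.allclose(g1, g2))  # True
```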