-
| Is there any way to indicate semantically that a single unmapped random key should supply independent randomness to each element of a mapped dimension? At the moment, I see behavior like this:
# a vector of logits
>>> logits
Array([-0.02830462,  0.46713185,  0.29570296,  0.15354592], dtype=float32)
# each sample is independent even though sampling from same distribution
>>> jax.random.categorical(jax.random.key(42), jnp.broadcast_to(logits, (10, *logits.shape)))
Array([1, 3, 3, 0, 3, 3, 1, 1, 1, 1], dtype=int32)
# each sample identical when calling under jax.vmap.
# unmapped `key` argument, but mapped `logits` argument
>>> jax.vmap(jax.random.categorical, in_axes=(None,0))(jax.random.key(42), jnp.broadcast_to(logits, (10, *logits.shape)))
Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)
For reference, jax.random.categorical has this signature:
categorical(
    key: 'ArrayLike',
    logits: 'RealArray',
    axis: 'int' = -1,
    shape: 'Shape | None' = None,
    replace: 'bool' = True
) -> 'Array' | 
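As a point of comparison for the signature above, here is a minimal sketch (not from the original post) of drawing several independent samples from one key without `jax.vmap`, by passing the `shape` argument. The variable names and the sample count of 10 are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Same logits vector as in the question.
logits = jnp.array([-0.02830462, 0.46713185, 0.29570296, 0.15354592])
key = jax.random.key(42)

# `shape` must be broadcast-compatible with the batch shape of `logits`
# (here the batch shape is (), because `logits` is a single vector).
# A single call then draws 10 samples from the one key.
samples = jax.random.categorical(key, logits, shape=(10,))
print(samples.shape, samples.dtype)
```

This only covers the case where every batch element shares the same kind of call; it is not a general replacement for mapping over keys.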
Replies: 2 comments 2 replies
-
| If you use a single unmapped random key, then the same key will be used for every element of the batch: this is working by design. If you'd like a different key for each batch element, you can split or fold the key and map over it, e.g. like this:
>>> jax.vmap(jax.random.categorical)(jax.random.split(keys, 10), jnp.broadcast_to(logits, (10, *logits.shape)))
Array([2, 1, 1, 3, 2, 0, 3, 3, 3, 1], dtype=int32) | 
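The reply mentions both splitting and folding the key. A self-contained sketch of the two options might look like the following; the names `batched_logits`, `subkeys`, and `sample_one` and the seed 42 are assumptions made for illustration, not part of the original answer.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([-0.02830462, 0.46713185, 0.29570296, 0.15354592])
batched_logits = jnp.broadcast_to(logits, (10, *logits.shape))
key = jax.random.key(42)

# Option 1: split the key into one subkey per batch element and map over both.
subkeys = jax.random.split(key, batched_logits.shape[0])
samples_split = jax.vmap(jax.random.categorical)(subkeys, batched_logits)

# Option 2: keep the key unmapped and fold the batch index into it,
# so each batch element derives its own key inside the mapped function.
def sample_one(i, row):
    return jax.random.categorical(jax.random.fold_in(key, i), row)

indices = jnp.arange(batched_logits.shape[0])
samples_fold = jax.vmap(sample_one)(indices, batched_logits)

# For contrast: an unmapped key reuses the same randomness for every row.
samples_same = jax.vmap(jax.random.categorical, in_axes=(None, 0))(key, batched_logits)

print(samples_split, samples_fold, samples_same)
```

The two options generally produce different sample values from each other, since they derive the per-element keys differently. Folding can be convenient when the batch index is already available inside the mapped function.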
-
| Yes, I know that is how it is meant to work. And, on second thought, having jax.vmap treat the key in some special way would be a pretty bad idea, I think. It's best the way it is. What I was more broadly wondering is: are there situations where hand-coding the batched version of a function allows a much more efficient compiled function? Or does jax.jit + jax.vmap basically always work close to optimally? | 