Hi Flax community. I'd like to request adding an `is_causal` argument to the NNX `dot_product_attention` function.

Rough Proposal

Extend the NNX function signature with a keyword-only argument:

```python
def dot_product_attention(..., is_causal: bool = False):
    ...
```

Behavior-wise, you'd have two paths:

- jax.nn path: if dropout is not active and `module` is None, call `jax.nn.dot_product_attention(query, key, value, bias, mask, is_causal=is_causal)`.
- Manual path: when the manual path is used (dropout active or `module` non-None), we could interpret `is_causal` as:

```python
combined_mask = mask
if is_causal:
    causal_mask = ...  # standard [batch..., 1, q_length, kv_length] mask
    combined_mask = combine_masks(combined_mask, causal_mask)
```

If you'd prefer this to be limited to the jax.nn path only (or not added at all), I'm happy to adjust or drop the idea. If this sounds reasonable, I can open a small PR with the minimal implementation and tests. Thanks for considering it!
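To make the manual path concrete, here is a minimal sketch of how `is_causal` could fold into the mask. It is not the Flax implementation: `causal_mask_for` and `manual_attention` are hypothetical names, the inputs are assumed to be `[batch, length, heads, head_dim]`, and `jnp.logical_and` stands in for `combine_masks`. Dropout and bias are left out to keep it short.

```python
import jax
import jax.numpy as jnp


def causal_mask_for(query, key):
    """Hypothetical helper: boolean [1, 1, q_length, kv_length] mask, True = may attend."""
    q_len, kv_len = query.shape[-3], key.shape[-3]
    # Lower-triangular: query position i may attend to key positions <= i.
    return jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))[None, None, :, :]


def manual_attention(query, key, value, mask=None, *, is_causal=False):
    """Sketch of the manual path (assumed layout: [batch, length, heads, head_dim])."""
    combined_mask = mask
    if is_causal:
        causal = causal_mask_for(query, key)
        # Stand-in for combine_masks: a position must be allowed by both masks.
        combined_mask = (
            causal if combined_mask is None else jnp.logical_and(combined_mask, causal)
        )

    scale = query.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * scale
    if combined_mask is not None:
        # Masked-out positions get the most negative finite value before softmax.
        logits = jnp.where(combined_mask, logits, jnp.finfo(logits.dtype).min)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)
```

With this shape of API, `manual_attention(q, k, v, is_causal=True)` should match passing the same causal mask explicitly via `mask=`, which is the equivalence a test for the proposed argument would likely check.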
Replies: 1 comment 1 reply
Thanks for the suggestion, @ibbyml! I think it makes sense to add the `is_causal` argument and support it in both paths.