Conversation

rohan-tan-bhowmik
Collaborator

Enabled the mask and is_causal parameters for torch.aten.scaled_dot_product_attention, with relevant comments and tests.

The added tests highlight the new capabilities introduced in this PR, including:

Attention with F16 mask
Attention with Boolean mask
Causal attention with same Q K V shapes
Causal attention without same Q K V shapes

Also made sure that one cannot pass both mask and is_causal at the same time.
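For reference, the masking semantics the PR covers can be sketched outside of torch-mlir. This is a minimal NumPy model (not the actual lowering in this PR): a float mask is added to the attention logits, a Boolean mask excludes positions where it is False, is_causal applies a lower-triangular mask, and mask and is_causal are mutually exclusive, matching the check described above. The function name `sdpa` and its signature are illustrative only.

```python
import numpy as np

def sdpa(q, k, v, mask=None, is_causal=False):
    # Mirrors the constraint from the PR: mask and is_causal are exclusive.
    if mask is not None and is_causal:
        raise ValueError("cannot specify both mask and is_causal")
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)  # (..., L_q, L_k)
    if is_causal:
        # Lower-triangular mask: query position i attends to keys <= i.
        L_q, L_k = scores.shape[-2], scores.shape[-1]
        causal = np.tril(np.ones((L_q, L_k), dtype=bool))
        scores = np.where(causal, scores, -np.inf)
    elif mask is not None:
        if mask.dtype == bool:
            # Boolean mask: positions where mask is False are excluded.
            scores = np.where(mask, scores, -np.inf)
        else:
            # Float (e.g. f16) mask: added directly to the logits.
            scores = scores + mask
    # Numerically stable softmax over the key dimension.
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v
```

With is_causal=True, the first output row equals v[0], since query position 0 can only attend to key position 0; an all-True Boolean mask is equivalent to no mask at all.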

@rsuderman
Contributor

You still need to add the passing SDPA ops to the StableHLO tests.

@rsuderman rsuderman merged commit e86f56b into llvm:main Sep 9, 2024
3 checks passed
@rohan-tan-bhowmik rohan-tan-bhowmik deleted the sdpa_mask branch September 21, 2024 09:22