Reference implementation for triangle attention ending #5730
base: wjy/bias
Conversation
Review updated until commit e293685

Description

| | Relevant files |
|---|---|
| Bug fix | |
| Tests | |
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests

⚡ Recommended focus areas for review
Performance Impact
The .contiguous() call on line 3437 may have performance implications. While it is necessary for correctness when the B and N dimensions are not adjacent in stride order, it creates an extra memory copy. The PR should include performance benchmarks comparing the ending-node implementation with the starting node to quantify this overhead and ensure it is acceptable for the use case.
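For intuition, here is a minimal PyTorch sketch of why flattening B and N into one batch dimension forces a copy when they are not adjacent in stride order. The shapes and the way the non-contiguous layout is constructed are assumptions for illustration, not the PR's actual tensors:

```python
import torch

B, N, H, Q, K = 2, 8, 4, 8, 8

# Start from a contiguous layout, then permute so the logical order is
# [B, N, H, Q, K] but B and N are not adjacent in stride order, mimicking
# the ending-node bias after the permutes in the fusion.
attn_bias = torch.randn(B, H, Q, K, N).permute(0, 4, 1, 2, 3)
assert not attn_bias.is_contiguous()

try:
    flat = attn_bias.view(B * N, H, Q, K)               # cannot be a pure view here
except RuntimeError:
    flat = attn_bias.contiguous().view(B * N, H, Q, K)  # one extra copy, then a cheap view
```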
`n_heads: int = 2`
…
`_DEFAULT_CONFIG = ModelConfig()`
@DejunL, what are the sizes people use in practice?
The Boltz reference model typically has:
- B or batch_size of 1
- N or token count of {inference: however many tokens are in the request, typically ~100 to ~2000; training: {stage1: 256, stage2: 512, stage3: 768}}
- c_z or token_z, the hidden dimension of the pair representation z, of 128
- num_heads of 4
- head_dim or c_hidden of 32

These vary somewhat in other models, but probably stay within the same order of magnitude. Structure prediction models are typically small in weight count but large in activations, so hidden dimensions tend to be small, though I do see some models experiment with larger hidden dimensions.
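For concreteness, the sizes quoted above could be collected in a config like the following. This is an illustrative dataclass with hypothetical field names, not the ModelConfig used by this PR's test:

```python
from dataclasses import dataclass


@dataclass
class BoltzLikeSizes:
    """Hypothetical sizes taken from the comment above; not the PR's ModelConfig."""
    batch_size: int = 1    # B
    num_tokens: int = 512  # N: ~100-2000 at inference; 256/512/768 across training stages 1-3
    c_z: int = 128         # hidden dim of the pair representation z (token_z)
    num_heads: int = 4
    c_hidden: int = 32     # head_dim
```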
Greptile Summary

Implemented support for triangle attention ending nodes in AlphaFold3. Key changes:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Test as test_triangle_attention
participant FD as FusionDefinition
participant SDPA as SdpaFwdOp::evaluate
participant Flatten as flattenBatchDims
Test->>FD: define z_in [b,i,j,c_z]
alt Direction.INCOMING
FD->>FD: permute z_in [b,j,i,c_z]
Test->>FD: define mask [b,i,j]
FD->>FD: permute mask [b,j,i]
end
Test->>FD: linear ops (q, k, v, b, g)
FD->>FD: reshape & permute to attention format
FD->>FD: broadcast bias & mask to 5D
FD->>SDPA: sdpfa_fwd(q_h, k_h, v_h, bias=b_h, mask=mask)
Note over SDPA: attn_bias = bias + mask_bias
alt attn_bias is non-contiguous (INCOMING)
SDPA->>SDPA: attn_bias.contiguous()
Note over SDPA: Fixes stride ordering<br/>B and N no longer adjacent
end
SDPA->>Flatten: flattenBatchDims(attn_bias)
Flatten-->>SDPA: [B*N, H, Q, K]
SDPA->>SDPA: _scaled_dot_product_attention
SDPA-->>FD: attention output
FD->>FD: apply sigmoid gating (g)
FD->>FD: linear projection (w_o)
alt Direction.INCOMING
FD->>FD: permute z_out back [b,i,j,c_z]
end
FD-->>Test: z_out [b,i,j,c_z]
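To make the diagram concrete, here is a hedged eager-mode PyTorch sketch of the same flow. It is not the nvFuser FusionDefinition used in the test: the weight shapes are assumptions, and F.scaled_dot_product_attention stands in for sdpfa_fwd with the pair bias and mask folded into attn_mask. The permute-in/permute-out pair is what produces the non-contiguous attn_bias discussed in the review comment above.

```python
import torch
import torch.nn.functional as F


def triangle_attention_ending(z, mask, w_q, w_k, w_v, w_b, w_g, w_o, num_heads=4):
    """Sketch of the ending-node (Direction.INCOMING) flow from the diagram.
    Assumed shapes: z [b, n, n, c_z]; mask [b, n, n] with 1 for valid pairs;
    w_q/w_k/w_v/w_g [c_z, h*d]; w_b [c_z, h]; w_o [h*d, c_z]."""
    b, n, _, c_z = z.shape
    h = num_heads
    d = w_q.shape[1] // h

    # Ending node: attend along the other pair axis, implemented by permuting
    # i <-> j on the way in and permuting back on the way out.
    z = z.transpose(1, 2)      # [b, j, i, c_z]
    mask = mask.transpose(1, 2)  # [b, j, i]

    def to_heads(x):  # [b, j, i, h*d] -> [b, j, h, i, d]
        return x.reshape(b, n, n, h, d).permute(0, 1, 3, 2, 4)

    q, k, v = to_heads(z @ w_q), to_heads(z @ w_k), to_heads(z @ w_v)
    g = torch.sigmoid(z @ w_g)  # gate, [b, j, i, h*d]

    # Pair bias broadcast over the leading pair axis, plus an additive mask
    # bias; their sum plays the role of attn_bias in SdpaFwdOp::evaluate.
    bias = (z @ w_b).permute(0, 3, 1, 2).unsqueeze(1)    # [b, 1, h, q_pos, k_pos]
    mask_bias = (mask - 1.0)[:, :, None, None, :] * 1e9  # [b, j, 1, 1, k_pos]
    attn_bias = bias + mask_bias                         # [b, j, h, q_pos, k_pos]

    o = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)  # [b, j, h, i, d]
    o = o.permute(0, 1, 3, 2, 4).reshape(b, n, n, h * d)              # [b, j, i, h*d]
    z_out = (g * o) @ w_o                                             # [b, j, i, c_z]
    return z_out.transpose(1, 2)                                      # back to [b, i, j, c_z]
```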
2 files reviewed, 1 comment
cc @DejunL