
Conversation

@wujingyue (Collaborator) commented Dec 21, 2025

cc @DejunL

@wujingyue wujingyue requested a review from Priya2698 December 21, 2025 00:01
github-actions bot commented Dec 21, 2025

Review updated until commit e293685

Description

  • Add contiguous() call for attn_bias in triangle attention ending nodes

  • Refactor triangle attention tests to support both incoming and outgoing directions

  • Add parameterization for direction-based attention tests

  • Update tensor definitions with contiguity requirements and add sigmoid activation

Changes walkthrough

Relevant files

Bug fix: internal_nodes.cpp (csrc/ir/internal_nodes.cpp, +8/-1)
Fix attn_bias contiguous call for triangle attention

  • Add contiguous() call to attn_bias for triangle attention ending nodes
  • Add detailed comment explaining why contiguous() is needed for ending nodes
  • Fix stride order issue when B and N dimensions are not adjacent

Tests: test_alphafold3.py (tests/python/direct/test_alphafold3.py, +54/-27)
Refactor triangle attention tests with direction parameterization

  • Add Direction enum with INCOMING and OUTGOING values
  • Refactor triangle attention tests to use parameterized direction
  • Add contiguity=True to all tensor definitions
  • Add sigmoid activation to gating mechanism
  • Update n_heads from 2 to 4 in default config

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Performance Impact

    The addition of the .contiguous() call on line 3437 may have performance implications. While necessary for correctness when the B and N dimensions are not adjacent in stride order, it creates an extra memory copy. The PR should include benchmarks comparing the ending-node implementation with the starting-node one to quantify this overhead and confirm it is acceptable for the use case.

    attn_bias = flattenBatchDims(attn_bias.contiguous());
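For intuition, here is a minimal plain-PyTorch sketch of the underlying issue. It is an illustration only, not nvFuser's flattenBatchDims; the concrete shapes and the use of view() for the flatten are assumptions.

```python
import torch

B, N, H, Q, K = 2, 3, 4, 5, 5

# Contiguous [B, N, H, Q, K] bias: merging B and N is a pure view, no copy.
bias = torch.randn(B, N, H, Q, K)
flat = bias.view(B * N, H, Q, K)

# A bias laid out as [B, H, N, Q, K] and permuted back to [B, N, H, Q, K]:
# B and N are now separated by H in stride order, so view() cannot merge them.
bias_perm = torch.randn(B, H, N, Q, K).permute(0, 2, 1, 3, 4)
try:
    bias_perm.view(B * N, H, Q, K)
except RuntimeError as err:
    print("view failed:", err)

# An explicit copy restores a row-major layout, after which the merge is valid
# (this copy is the overhead the review comment asks to benchmark).
flat_perm = bias_perm.contiguous().view(B * N, H, Q, K)
```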
    Test Coverage

    test_triangle_updates is currently empty (a pass statement) and test_triangle_attention is only partially implemented. While the main test_triangle_attention_starting_node function has been enhanced, the parameterized versions should include actual test logic to validate that both INCOMING and OUTGOING directions work correctly.

    def test_triangle_updates(direction):
        pass
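As a reference for what the parameterized tests could look like once filled in, here is a minimal pytest sketch; the Direction values follow the summary above, and the test body is a placeholder rather than the PR's actual fusion logic.

```python
from enum import Enum

import pytest


class Direction(Enum):
    INCOMING = "incoming"   # ending node: inputs permuted [b, i, j] -> [b, j, i]
    OUTGOING = "outgoing"   # starting node: inputs used as-is


@pytest.mark.parametrize("direction", [Direction.INCOMING, Direction.OUTGOING])
def test_triangle_attention(direction):
    # Placeholder: build the fusion, permute inputs for INCOMING, execute,
    # and compare against a reference implementation for both directions.
    ...
```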

    n_heads: int = 2


    _DEFAULT_CONFIG = ModelConfig()
    @wujingyue (Collaborator, Author) commented:
    @DejunL, what are the sizes people use in practice?

    Reply:
    The Boltz reference model typically has:

    • B or batch_size of 1
    • N, or token count: at inference, however many tokens are in the request, but typically ~100 to ~2000; in training, 256 (stage 1), 512 (stage 2), or 768 (stage 3)
    • c_z or token_z, the hidden dimension of the pair representation z, is 128
    • num_heads is 4
    • head_dim or c_hidden is 32

    Other models vary, but probably within the same order of magnitude. Structure prediction models are typically small in weight count but large in activations, so hidden dimensions tend to be small, though I do see some models experiment with larger hidden dimensions.
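For reference, those numbers could be collected into a small config sketch like the one below; the field names are hypothetical and not taken from Boltz or this PR.

```python
from dataclasses import dataclass


@dataclass
class TriangleAttentionConfig:
    batch_size: int = 1   # B
    n_tokens: int = 512   # N: ~100-2000 at inference; 256/512/768 across training stages
    c_z: int = 128        # hidden dimension of the pair representation z
    n_heads: int = 4
    head_dim: int = 32    # c_hidden
```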

    greptile-apps bot (Contributor) commented Dec 21, 2025

    Greptile Summary

    Implemented support for triangle attention ending nodes in AlphaFold3 by adding a contiguous() call before flattenBatchDims in the SDPA forward operation. This fixes a stride ordering issue where transposed masks cause batch dimensions B and N to become non-adjacent.

    Key changes:

    • csrc/ir/internal_nodes.cpp: Added contiguous() before flattenBatchDims(attn_bias) to handle non-contiguous stride patterns from transposed masks in ending nodes
    • tests/python/direct/test_alphafold3.py: Consolidated triangle attention tests using @pytest.mark.parametrize with Direction.INCOMING/OUTGOING, added sigmoid gating (fd.ops.sigmoid(g)), and added proper permutations for INCOMING direction

    Confidence Score: 4/5

    • Safe to merge with minor verification recommended
    • The contiguous() fix is a well-documented workaround for a known stride-ordering issue, and the test consolidation is clean and follows pytest best practices. The score reflects that this is a reference implementation (as stated in the PR title) and that the added memory-copy overhead from contiguous() should be verified in production.
    • No files require special attention

    Important Files Changed

    • csrc/ir/internal_nodes.cpp: Added contiguous() call before flattenBatchDims to handle non-adjacent strides in triangle attention ending nodes
    • tests/python/direct/test_alphafold3.py: Unified triangle attention tests with parametrization, added INCOMING/OUTGOING support with proper permutations and sigmoid gating

    Sequence Diagram

    sequenceDiagram
        participant Test as test_triangle_attention
        participant FD as FusionDefinition
        participant SDPA as SdpaFwdOp::evaluate
        participant Flatten as flattenBatchDims
        
        Test->>FD: define z_in [b,i,j,c_z]
        alt Direction.INCOMING
            FD->>FD: permute z_in [b,j,i,c_z]
            Test->>FD: define mask [b,i,j]
            FD->>FD: permute mask [b,j,i]
        end
        
        Test->>FD: linear ops (q, k, v, b, g)
        FD->>FD: reshape & permute to attention format
        FD->>FD: broadcast bias & mask to 5D
        
        FD->>SDPA: sdpfa_fwd(q_h, k_h, v_h, bias=b_h, mask=mask)
        
        Note over SDPA: attn_bias = bias + mask_bias
        alt attn_bias is non-contiguous (INCOMING)
            SDPA->>SDPA: attn_bias.contiguous()
            Note over SDPA: Fixes stride ordering<br/>B and N no longer adjacent
        end
        
        SDPA->>Flatten: flattenBatchDims(attn_bias)
        Flatten-->>SDPA: [B*N, H, Q, K]
        
        SDPA->>SDPA: _scaled_dot_product_attention
        SDPA-->>FD: attention output
        
        FD->>FD: apply sigmoid gating (g)
        FD->>FD: linear projection (w_o)
        
        alt Direction.INCOMING
            FD->>FD: permute z_out back [b,i,j,c_z]
        end
        
        FD-->>Test: z_out [b,i,j,c_z]
    
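To make the diagram concrete, here is a plain-PyTorch reference sketch of the same flow. It is an illustration under assumed weight shapes (w_q/w_k/w_v/w_g: [n_heads*head_dim, c_z], w_b: [n_heads, c_z], w_o: [c_z, n_heads*head_dim]), not the PR's FusionDefinition or nvFuser's SdpaFwdOp.

```python
import torch
import torch.nn.functional as F


def triangle_attention_ref(z, mask, w_q, w_k, w_v, w_b, w_g, w_o,
                           n_heads, incoming=False):
    # z: [b, n, n, c_z]; mask: [b, n, n] with 1 = keep, 0 = masked out.
    if incoming:                        # ending node: attend along the other axis
        z = z.transpose(1, 2)           # [b, j, i, c_z]
        mask = mask.transpose(1, 2)     # [b, j, i]

    b, n, _, c_z = z.shape
    head_dim = w_q.shape[0] // n_heads

    def to_heads(x):                    # [b, n, n, h*d] -> [b, n, h, n, d]
        return x.view(b, n, n, n_heads, head_dim).permute(0, 1, 3, 2, 4)

    q, k, v = to_heads(z @ w_q.T), to_heads(z @ w_k.T), to_heads(z @ w_v.T)

    bias = (z @ w_b.T).permute(0, 3, 1, 2).unsqueeze(1)          # [b, 1, h, n, n]
    mask_bias = (1.0 - mask.to(z.dtype)) * torch.finfo(z.dtype).min
    mask_bias = mask_bias[:, :, None, None, :]                   # [b, n, 1, 1, n]
    attn_bias = bias + mask_bias                                 # [b, n, h, n, n]

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)

    gate = torch.sigmoid(z @ w_g.T)                              # sigmoid gating
    out = out.permute(0, 1, 3, 2, 4).reshape(b, n, n, -1) * gate
    z_out = out @ w_o.T                                          # [b, n, n, c_z]

    if incoming:
        z_out = z_out.transpose(1, 2)                            # back to [b, i, j, c_z]
    return z_out
```

Whether the mask should be transposed along with z for the INCOMING case is a modeling choice; the sketch simply mirrors the diagram above.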

    greptile-apps bot left a comment

    2 files reviewed, 1 comment
