77i - residue sequence length (source) 
88j - residue sequence length (target) 
99m - atom sequence length 
10+ c - coordinates (3 for spatial) 
1011d - feature dimension 
1112ds - feature dimension (single) 
1213dp - feature dimension (pairwise) 
@@ -861,7 +862,7 @@ def __init__(
861862        # final projection of mean pooled repr -> out 
862863
863864        self .to_out  =  nn .Sequential (
864-             LinearNoBias (dim , dim ),
865+             LinearNoBias (dim , dim_pairwise ),
865866            nn .ReLU ()
866867        )
867868
@@ -873,7 +874,7 @@ def forward(
873874        template_mask : Bool ['b t' ],
874875        pairwise_repr : Float ['b n n dp' ],
875876        mask : Bool ['b n' ] |  None  =  None ,
876-     ) ->  Float ['b n n d ' ]:
877+     ) ->  Float ['b n n dp ' ]:
877878
878879        num_templates  =  templates .shape [1 ]
879880
@@ -884,7 +885,8 @@ def forward(
884885
885886        v , merged_batch_ps  =  pack_one (v , '* i j d' )
886887
887-         mask  =  repeat (mask , 'b n -> (b t) n' , t  =  num_templates )
888+         if  exists (mask ):
889+             mask  =  repeat (mask , 'b n -> (b t) n' , t  =  num_templates )
888890
889891        for  block  in  self .pairformer_stack :
890892            v  =  block (
@@ -1815,7 +1817,7 @@ def forward(
18151817        pairwise_repr : Float ['b n n dp' ],
18161818        pred_atom_pos : Float ['b n c' ],
18171819        mask : Bool ['b n' ] |  None  =  None ,
1818-         calc_pae_logits_and_loss  =  True 
1820+         return_pae_logits  =  True 
18191821
18201822    ) ->  ConfidenceHeadLogits [
18211823        Float ['b pae n n' ] |  None ,
@@ -1854,7 +1856,7 @@ def forward(
18541856
18551857        pae_logits  =  None 
18561858
1857-         if  calc_pae_logits_and_loss :
1859+         if  return_pae_logits :
18581860            pae_logits  =  self .to_pae_logits (pairwise_repr )
18591861
18601862        # return all logits 
@@ -1863,21 +1865,248 @@ def forward(
18631865
18641866# main class 
18651867
1868+ LossBreakdown  =  namedtuple ('LossBreakdown' , [
1869+     'distogram' ,
1870+     'pae' ,
1871+     'pdt' ,
1872+     'plddt' ,
1873+     'resolved' 
1874+ ])
1875+ 
18661876class  Alphafold3 (Module ):
1877+     """ Algorithm 1 """ 
1878+ 
1879+     @typecheck  
18671880    def  __init__ (
18681881        self ,
18691882        * ,
1883+         dim_atom_inputs ,
1884+         dim_additional_residue_feats ,
1885+         dim_template_feats ,
1886+         dim_template_model  =  64 ,
1887+         atoms_per_window  =  27 ,
1888+         dim_atom  =  128 ,
1889+         dim_atompair  =  16 ,
1890+         dim_input_embedder_token  =  384 ,
1891+         dim_single  =  384 ,
1892+         dim_pairwise  =  128 ,
1893+         atompair_dist_bins : Float [' dist_bins' ] =  torch .linspace (3 , 20 , 37 ),
1894+         ignore_index  =  - 1 ,
1895+         num_dist_bins  =  38 ,
1896+         num_plddt_bins  =  50 ,
1897+         num_pde_bins  =  64 ,
1898+         num_pae_bins  =  64 ,
18701899        loss_confidence_weight  =  1e-4 ,
18711900        loss_distogram_weight  =  1e-2 ,
1872-         loss_diffusion  =  4. 
1901+         loss_diffusion_weight  =  4. ,
1902+         input_embedder_kwargs : dict  =  dict (
1903+             atom_transformer_blocks  =  3 ,
1904+             atom_transformer_heads  =  4 ,
1905+             atom_transformer_kwargs  =  dict ()
1906+         ),
1907+         confidence_head_kwargs : dict  =  dict (
1908+             pairformer_depth  =  4 
1909+         ),
1910+         template_embedder_kwargs : dict  =  dict (
1911+             pairformer_stack_depth  =  2 ,
1912+             pairwise_block_kwargs  =  dict (),
1913+         ),
1914+         msa_module_kwargs : dict  =  dict (
1915+             depth  =  4 ,
1916+             dim_msa  =  64 ,
1917+             dim_msa_input  =  None ,
1918+             outer_product_mean_dim_hidden  =  32 ,
1919+             msa_pwa_dropout_row_prob  =  0.15 ,
1920+             msa_pwa_heads  =  8 ,
1921+             msa_pwa_dim_head  =  32 ,
1922+             pairwise_block_kwargs  =  dict ()
1923+         ),
1924+         pairformer_stack : dict  =  dict (
1925+             depth  =  48 ,
1926+             pair_bias_attn_dim_head  =  64 ,
1927+             pair_bias_attn_heads  =  16 ,
1928+             dropout_row_prob  =  0.25 ,
1929+             pairwise_block_kwargs  =  dict ()
1930+         )
18731931    ):
18741932        super ().__init__ ()
18751933
1934+         self .atoms_per_window  =  atoms_per_window 
1935+ 
1936+         # input feature embedder 
1937+ 
1938+         self .input_embedder  =  InputFeatureEmbedder (
1939+             dim_atom_inputs  =  dim_atom_inputs ,
1940+             dim_additional_residue_feats  =  dim_additional_residue_feats ,
1941+             atoms_per_window  =  atoms_per_window ,
1942+             dim_atom  =  dim_atom ,
1943+             dim_atompair  =  dim_atompair ,
1944+             dim_token  =  dim_input_embedder_token ,
1945+             dim_single  =  dim_single ,
1946+             dim_pairwise  =  dim_pairwise ,
1947+             ** input_embedder_kwargs 
1948+         )
1949+ 
1950+         dim_single_inputs  =  dim_input_embedder_token  +  dim_additional_residue_feats 
1951+ 
1952+         # templates 
1953+ 
1954+         self .template_embedder  =  TemplateEmbedder (
1955+             dim_template_feats  =  dim_template_feats ,
1956+             dim  =  dim_template_model ,
1957+             dim_pairwise  =  dim_pairwise ,
1958+             ** template_embedder_kwargs 
1959+         )
1960+ 
1961+         # msa 
1962+ 
1963+         self .msa_module  =  MSAModule (
1964+             dim_single  =  dim_single ,
1965+             dim_pairwise  =  dim_pairwise ,
1966+             ** msa_module_kwargs 
1967+         )
1968+ 
1969+         # main pairformer trunk, 48 layers 
1970+ 
1971+         self .pairformer  =  PairformerStack (
1972+             dim_single  =  dim_single ,
1973+             dim_pairwise  =  dim_pairwise ,
1974+             ** pairformer_stack 
1975+         )
1976+ 
1977+         # recycling related 
1978+ 
1979+         self .recycle_single  =  nn .Sequential (
1980+             nn .LayerNorm (dim_single ),
1981+             LinearNoBias (dim_single , dim_single )
1982+         )
1983+ 
1984+         self .recycle_pairwise  =  nn .Sequential (
1985+             nn .LayerNorm (dim_pairwise ),
1986+             LinearNoBias (dim_pairwise , dim_pairwise )
1987+         )
1988+ 
1989+         # logit heads 
1990+ 
1991+         self .distogram_head  =  DistogramHead (
1992+             dim_pairwise  =  dim_pairwise ,
1993+             num_dist_bins  =  num_dist_bins 
1994+         )
1995+ 
1996+         self .confidence_head  =  ConfidenceHead (
1997+             dim_single_inputs  =  dim_single_inputs ,
1998+             atompair_dist_bins  =  atompair_dist_bins ,
1999+             dim_single  =  dim_single ,
2000+             dim_pairwise  =  dim_pairwise ,
2001+             num_plddt_bins  =  num_plddt_bins ,
2002+             num_pde_bins  =  num_pde_bins ,
2003+             num_pae_bins  =  num_pae_bins ,
2004+             ** confidence_head_kwargs 
2005+         )
2006+ 
2007+         # loss related 
2008+ 
2009+         self .ignore_index  =  ignore_index 
2010+         self .loss_distogram_weight  =  loss_distogram_weight 
2011+         self .loss_confidence_weight  =  loss_confidence_weight 
2012+         self .loss_diffusion_weight  =  loss_diffusion_weight 
18762013
18772014    @typecheck  
18782015    def  forward (
18792016        self ,
18802017        * ,
1881-         include_pae_loss  =  False  # turned on in latter part of training 
1882-     ):
1883-         return 
2018+         atom_inputs : Float ['b m dai' ],
2019+         atom_mask : Bool ['b m' ],
2020+         atompair_feats : Float ['b m m dap' ],
2021+         additional_residue_feats : Float ['b n rf' ],
2022+         msa : Float ['b s n d' ],
2023+         templates : Float ['b t n n dt' ],
2024+         template_mask : Bool ['b t' ],
2025+         num_recycling_steps : int  =  1 ,
2026+         distance_labels : Int ['b n n' ] |  None  =  None ,
2027+         pae_labels : Int ['b n n' ] |  None  =  None ,
2028+         pde_labels : Int ['b n n' ] |  None  =  None ,
2029+         plddt_labels : Int ['b n' ] |  None  =  None ,
2030+         resolved_labels : Int ['b n' ] |  None  =  None ,
2031+     ) ->  Float ['b m c' ] |  Float ['' ]:
2032+ 
2033+         # embed inputs 
2034+ 
2035+         (
2036+             single_inputs ,
2037+             single_init ,
2038+             pairwise_init ,
2039+             atom_feats ,
2040+             atompair_feats 
2041+         ) =  self .input_embedder (
2042+             atom_inputs  =  atom_inputs ,
2043+             atom_mask  =  atom_mask ,
2044+             atompair_feats  =  atompair_feats ,
2045+             additional_residue_feats  =  additional_residue_feats 
2046+         )
2047+ 
2048+         w  =  self .atoms_per_window 
2049+ 
2050+         mask  =  reduce (atom_mask , 'b (n w) -> b n' , w  =  w , reduction  =  'any' )
2051+ 
2052+         # init recycled single and pairwise 
2053+ 
2054+         recycled_pairwise  =  recycled_single  =  None 
2055+         single  =  pairwise  =  None 
2056+ 
2057+         # for each recycling step 
2058+ 
2059+         for  _  in  range (num_recycling_steps ):
2060+ 
2061+             # handle recycled single and pairwise if not first step 
2062+ 
2063+             recycled_single  =  recycled_pairwise  =  0. 
2064+ 
2065+             if  exists (single ):
2066+                 recycled_single  =  self .recycle_single (single )
2067+ 
2068+             if  exists (pairwise ):
2069+                 recycled_pairwise  =  self .recycle_pairwise (pairwise )
2070+ 
2071+             single  =  single_init  +  recycled_single 
2072+             pairwise  =  pairwise_init  +  recycled_pairwise 
2073+ 
2074+             # else go through main transformer trunk from alphafold2 
2075+ 
2076+             # templates 
2077+ 
2078+             embedded_template  =  self .template_embedder (
2079+                 templates  =  templates ,
2080+                 template_mask  =  template_mask ,
2081+                 pairwise_repr  =  pairwise ,
2082+                 mask  =  mask 
2083+             )
2084+ 
2085+             pairwise  =  embedded_template  +  pairwise 
2086+ 
2087+             # msa 
2088+ 
2089+             embedded_msa  =  self .msa_module (
2090+                 msa  =  msa ,
2091+                 single_repr  =  single ,
2092+                 pairwise_repr  =  pairwise ,
2093+                 mask  =  mask 
2094+             )
2095+ 
2096+             pairwise  =  embedded_msa  +  pairwise 
2097+ 
2098+             # main attention trunk (pairformer) 
2099+ 
2100+             single , pairwise  =  self .pairformer (
2101+                 single_repr  =  single ,
2102+                 pairwise_repr  =  pairwise ,
2103+                 mask  =  mask 
2104+             )
2105+ 
2106+         # determine whether to return loss if any labels were to be passed in 
2107+         # otherwise will sample the atomic coordinates 
2108+ 
2109+         labels  =  (distance_labels , pae_labels , pde_labels , plddt_labels , resolved_labels )
2110+         return_loss  =  any ([* filter (exists , labels )])
2111+ 
2112+         return  torch .tensor (0. )
0 commit comments