
Commit 1936f1e

complete the main alphafold2 flow, sans diffusion module and losses + sampling
1 parent b0198d4 commit 1936f1e

File tree

3 files changed: 322 additions & 9 deletions

README.md

Lines changed: 50 additions & 0 deletions
@@ -6,6 +6,56 @@ Implementation of <a href="https://www.nature.com/articles/s41586-024-07487-w">A
 
 Getting a fair number of emails. You can chat with me about this work <a href="https://discord.gg/x6FuzQPQXY">here</a>
 
+## Install
+
+```bash
+$ pip install alphafold3-pytorch
+```
+
+## Usage
+
+```python
+import torch
+from alphafold3_pytorch import Alphafold3
+
+alphafold3 = Alphafold3(
+    dim_atom_inputs = 77,
+    dim_additional_residue_feats = 33,
+    dim_template_feats = 44
+)
+
+# mock inputs
+
+seq_len = 16
+atom_seq_len = seq_len * 27
+
+atom_inputs = torch.randn(2, atom_seq_len, 77)
+atom_mask = torch.ones((2, atom_seq_len)).bool()
+atompair_feats = torch.randn(2, atom_seq_len, atom_seq_len, 16)
+additional_residue_feats = torch.randn(2, seq_len, 33)
+
+template_feats = torch.randn(2, 2, seq_len, seq_len, 44)
+template_mask = torch.ones((2, 2)).bool()
+
+msa = torch.randn(2, 7, seq_len, 64)
+
+# train
+
+loss = alphafold3(
+    num_recycling_steps = 2,
+    atom_inputs = atom_inputs,
+    atom_mask = atom_mask,
+    atompair_feats = atompair_feats,
+    additional_residue_feats = additional_residue_feats,
+    msa = msa,
+    templates = template_feats,
+    template_mask = template_mask
+)
+
+loss.backward()
+
+```
+
 ## Citations
 
 ```bibtex
alphafold3_pytorch/alphafold3.py

Lines changed: 238 additions & 9 deletions
@@ -7,6 +7,7 @@
 i - residue sequence length (source)
 j - residue sequence length (target)
 m - atom sequence length
+c - coordinates (3 for spatial)
 d - feature dimension
 ds - feature dimension (single)
 dp - feature dimension (pairwise)
@@ -861,7 +862,7 @@ def __init__(
         # final projection of mean pooled repr -> out
 
         self.to_out = nn.Sequential(
-            LinearNoBias(dim, dim),
+            LinearNoBias(dim, dim_pairwise),
             nn.ReLU()
         )
 
@@ -873,7 +874,7 @@ def forward(
         template_mask: Bool['b t'],
         pairwise_repr: Float['b n n dp'],
         mask: Bool['b n'] | None = None,
-    ) -> Float['b n n d']:
+    ) -> Float['b n n dp']:
 
         num_templates = templates.shape[1]
 
@@ -884,7 +885,8 @@ def forward(
 
         v, merged_batch_ps = pack_one(v, '* i j d')
 
-        mask = repeat(mask, 'b n -> (b t) n', t = num_templates)
+        if exists(mask):
+            mask = repeat(mask, 'b n -> (b t) n', t = num_templates)
 
         for block in self.pairformer_stack:
             v = block(
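The added exists guard makes mask genuinely optional in this forward pass, matching the `Bool['b n'] | None` annotation fixed in the previous hunk. For readers viewing the hunk in isolation: exists is the usual two-line helper defined near the top of the module,

```python
def exists(v):
    return v is not None
```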
@@ -1815,7 +1817,7 @@ def forward(
         pairwise_repr: Float['b n n dp'],
         pred_atom_pos: Float['b n c'],
         mask: Bool['b n'] | None = None,
-        calc_pae_logits_and_loss = True
+        return_pae_logits = True
 
     ) -> ConfidenceHeadLogits[
         Float['b pae n n'] | None,
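The Float['b pae n n'] style annotations that @typecheck enforces throughout these hunks are jaxtyping-style shape specs (here: batch, PAE bins, rows, columns). A minimal sketch of the convention, assuming jaxtyping with beartype as the runtime checker (the repo wires up its own typecheck decorator along these lines; outer_sum is purely illustrative):

```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped

@jaxtyped(typechecker = beartype)
def outer_sum(x: Float[torch.Tensor, 'b n d']) -> Float[torch.Tensor, 'b n n d']:
    # pairwise sum of token features, shape-checked at call time
    return x[:, :, None, :] + x[:, None, :, :]

out = outer_sum(torch.randn(2, 4, 8))  # ok: returns shape (2, 4, 4, 8)
```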
@@ -1854,7 +1856,7 @@
 
         pae_logits = None
 
-        if calc_pae_logits_and_loss:
+        if return_pae_logits:
            pae_logits = self.to_pae_logits(pairwise_repr)
 
        # return all logits
@@ -1863,21 +1865,248 @@
 
 # main class
 
+LossBreakdown = namedtuple('LossBreakdown', [
+    'distogram',
+    'pae',
+    'pde',
+    'plddt',
+    'resolved'
+])
+
 class Alphafold3(Module):
+    """ Algorithm 1 """
+
+    @typecheck
     def __init__(
         self,
         *,
+        dim_atom_inputs,
+        dim_additional_residue_feats,
+        dim_template_feats,
+        dim_template_model = 64,
+        atoms_per_window = 27,
+        dim_atom = 128,
+        dim_atompair = 16,
+        dim_input_embedder_token = 384,
+        dim_single = 384,
+        dim_pairwise = 128,
+        atompair_dist_bins: Float[' dist_bins'] = torch.linspace(3, 20, 37),
+        ignore_index = -1,
+        num_dist_bins = 38,
+        num_plddt_bins = 50,
+        num_pde_bins = 64,
+        num_pae_bins = 64,
         loss_confidence_weight = 1e-4,
         loss_distogram_weight = 1e-2,
-        loss_diffusion = 4.
+        loss_diffusion_weight = 4.,
+        input_embedder_kwargs: dict = dict(
+            atom_transformer_blocks = 3,
+            atom_transformer_heads = 4,
+            atom_transformer_kwargs = dict()
+        ),
+        confidence_head_kwargs: dict = dict(
+            pairformer_depth = 4
+        ),
+        template_embedder_kwargs: dict = dict(
+            pairformer_stack_depth = 2,
+            pairwise_block_kwargs = dict(),
+        ),
+        msa_module_kwargs: dict = dict(
+            depth = 4,
+            dim_msa = 64,
+            dim_msa_input = None,
+            outer_product_mean_dim_hidden = 32,
+            msa_pwa_dropout_row_prob = 0.15,
+            msa_pwa_heads = 8,
+            msa_pwa_dim_head = 32,
+            pairwise_block_kwargs = dict()
+        ),
+        pairformer_stack: dict = dict(
+            depth = 48,
+            pair_bias_attn_dim_head = 64,
+            pair_bias_attn_heads = 16,
+            dropout_row_prob = 0.25,
+            pairwise_block_kwargs = dict()
+        )
     ):
         super().__init__()
 
+        self.atoms_per_window = atoms_per_window
+
+        # input feature embedder
+
+        self.input_embedder = InputFeatureEmbedder(
+            dim_atom_inputs = dim_atom_inputs,
+            dim_additional_residue_feats = dim_additional_residue_feats,
+            atoms_per_window = atoms_per_window,
+            dim_atom = dim_atom,
+            dim_atompair = dim_atompair,
+            dim_token = dim_input_embedder_token,
+            dim_single = dim_single,
+            dim_pairwise = dim_pairwise,
+            **input_embedder_kwargs
+        )
+
+        dim_single_inputs = dim_input_embedder_token + dim_additional_residue_feats
+
+        # templates
+
+        self.template_embedder = TemplateEmbedder(
+            dim_template_feats = dim_template_feats,
+            dim = dim_template_model,
+            dim_pairwise = dim_pairwise,
+            **template_embedder_kwargs
+        )
+
+        # msa
+
+        self.msa_module = MSAModule(
+            dim_single = dim_single,
+            dim_pairwise = dim_pairwise,
+            **msa_module_kwargs
+        )
+
+        # main pairformer trunk, 48 layers
+
+        self.pairformer = PairformerStack(
+            dim_single = dim_single,
+            dim_pairwise = dim_pairwise,
+            **pairformer_stack
+        )
+
+        # recycling related
+
+        self.recycle_single = nn.Sequential(
+            nn.LayerNorm(dim_single),
+            LinearNoBias(dim_single, dim_single)
+        )
+
+        self.recycle_pairwise = nn.Sequential(
+            nn.LayerNorm(dim_pairwise),
+            LinearNoBias(dim_pairwise, dim_pairwise)
+        )
+
+        # logit heads
+
+        self.distogram_head = DistogramHead(
+            dim_pairwise = dim_pairwise,
+            num_dist_bins = num_dist_bins
+        )
+
+        self.confidence_head = ConfidenceHead(
+            dim_single_inputs = dim_single_inputs,
+            atompair_dist_bins = atompair_dist_bins,
+            dim_single = dim_single,
+            dim_pairwise = dim_pairwise,
+            num_plddt_bins = num_plddt_bins,
+            num_pde_bins = num_pde_bins,
+            num_pae_bins = num_pae_bins,
+            **confidence_head_kwargs
+        )
+
+        # loss related
+
+        self.ignore_index = ignore_index
+        self.loss_distogram_weight = loss_distogram_weight
+        self.loss_confidence_weight = loss_confidence_weight
+        self.loss_diffusion_weight = loss_diffusion_weight
 
     @typecheck
     def forward(
         self,
         *,
-        include_pae_loss = False # turned on in latter part of training
-    ):
-        return
+        atom_inputs: Float['b m dai'],
+        atom_mask: Bool['b m'],
+        atompair_feats: Float['b m m dap'],
+        additional_residue_feats: Float['b n rf'],
+        msa: Float['b s n d'],
+        templates: Float['b t n n dt'],
+        template_mask: Bool['b t'],
+        num_recycling_steps: int = 1,
+        distance_labels: Int['b n n'] | None = None,
+        pae_labels: Int['b n n'] | None = None,
+        pde_labels: Int['b n n'] | None = None,
+        plddt_labels: Int['b n'] | None = None,
+        resolved_labels: Int['b n'] | None = None,
+    ) -> Float['b m c'] | Float['']:
+
+        # embed inputs
+
+        (
+            single_inputs,
+            single_init,
+            pairwise_init,
+            atom_feats,
+            atompair_feats
+        ) = self.input_embedder(
+            atom_inputs = atom_inputs,
+            atom_mask = atom_mask,
+            atompair_feats = atompair_feats,
+            additional_residue_feats = additional_residue_feats
+        )
+
+        w = self.atoms_per_window
+
+        mask = reduce(atom_mask, 'b (n w) -> b n', w = w, reduction = 'any')
+
+        # init recycled single and pairwise
+
+        recycled_pairwise = recycled_single = None
+        single = pairwise = None
+
+        # for each recycling step
+
+        for _ in range(num_recycling_steps):
+
+            # handle recycled single and pairwise if not first step
+
+            recycled_single = recycled_pairwise = 0.
+
+            if exists(single):
+                recycled_single = self.recycle_single(single)
+
+            if exists(pairwise):
+                recycled_pairwise = self.recycle_pairwise(pairwise)
+
+            single = single_init + recycled_single
+            pairwise = pairwise_init + recycled_pairwise
+
+            # then go through main transformer trunk from alphafold2
+
+            # templates
+
+            embedded_template = self.template_embedder(
+                templates = templates,
+                template_mask = template_mask,
+                pairwise_repr = pairwise,
+                mask = mask
+            )
+
+            pairwise = embedded_template + pairwise
+
+            # msa
+
+            embedded_msa = self.msa_module(
+                msa = msa,
+                single_repr = single,
+                pairwise_repr = pairwise,
+                mask = mask
+            )
+
+            pairwise = embedded_msa + pairwise
+
+            # main attention trunk (pairformer)
+
+            single, pairwise = self.pairformer(
+                single_repr = single,
+                pairwise_repr = pairwise,
+                mask = mask
+            )
+
+        # determine whether to return loss if any labels were passed in
+        # otherwise will sample the atomic coordinates
+
+        labels = (distance_labels, pae_labels, pde_labels, plddt_labels, resolved_labels)
+        return_loss = any([*filter(exists, labels)])
+
+        return torch.tensor(0.)
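The heart of the new forward pass is the recycling loop: each iteration re-projects the previous single and pairwise representations through a LayerNorm followed by a bias-free Linear (recycle_single / recycle_pairwise), adds them onto the initial embeddings, and runs the trunk again. A stripped-down sketch of just that control flow, with toy dimensions and a stand-in trunk (nothing here is the real model):

```python
import torch
from torch import nn

dim = 8                  # toy size; the model uses dim_single = 384
num_recycling_steps = 2

# mirrors recycle_single / recycle_pairwise: LayerNorm then bias-free Linear
recycle = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim, bias = False))

trunk = nn.Linear(dim, dim)  # stand-in for the 48-layer pairformer

state_init = torch.randn(2, 16, dim)
state = None

for _ in range(num_recycling_steps):
    # first pass: no recycled state; later passes: re-embed previous output
    recycled = recycle(state) if state is not None else 0.
    state = trunk(state_init + recycled)
```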
