Skip to content

Commit 35b6bc1

Browse files
committed
do a small optimization for the torch rsa calculation for model selection
1 parent 097cbff commit 35b6bc1

File tree

2 files changed

+45
-23
lines changed

2 files changed

+45
-23
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5707,6 +5707,36 @@ def _compute_unresolved_rasa(
57075707

57085708
return unresolved_rasa.mean()
57095709

5710+
@typecheck
5711+
def calc_atom_access_surface_score_from_structure(
5712+
self,
5713+
structure: Structure,
5714+
**kwargs
5715+
) -> Float['m']:
5716+
5717+
# use the structure as source of truth, matching what xluo did
5718+
5719+
structure_atom_pos = []
5720+
structure_atom_type_for_radii = []
5721+
side_atom_index = len(self.atom_type_index)
5722+
5723+
for atom in structure.get_atoms():
5724+
5725+
one_atom_pos = list(atom.get_vector())
5726+
one_atom_type = self.atom_type_index.get(atom.name, side_atom_index)
5727+
5728+
structure_atom_pos.append(one_atom_pos)
5729+
structure_atom_type_for_radii.append(one_atom_type)
5730+
5731+
structure_atom_pos: Float[' m 3'] = tensor(structure_atom_pos)
5732+
structure_atom_type_for_radii: Int[' m'] = tensor(structure_atom_type_for_radii)
5733+
5734+
return self.calc_atom_access_surface_score(
5735+
atom_pos = structure_atom_pos,
5736+
atom_type = structure_atom_type_for_radii,
5737+
**kwargs
5738+
)
5739+
57105740
@typecheck
57115741
def calc_atom_access_surface_score(
57125742
self,
@@ -5749,7 +5779,7 @@ def calc_atom_access_surface_score(
57495779
lat.sin()
57505780
), dim = -1)
57515781

5752-
# first get atom relative positions + distance
5782+
# get atom relative positions + distance
57535783
# for determining whether to include pairs of atom in calculation for the `free` adjective
57545784

57555785
atom_rel_pos = einx.subtract('j c, i c -> i j c', atom_pos, atom_pos)
@@ -5762,13 +5792,23 @@ def calc_atom_access_surface_score(
57625792
(atom_rel_dist_sq > atom_distance_min_thres)
57635793
)
57645794

5795+
# max included in calculation per row
5796+
5797+
max_included = include_in_free_calc.long().sum(dim = -1).amax()
5798+
5799+
include_indices = include_in_free_calc.long().topk(max_included, dim = -1).indices
5800+
5801+
include_in_free_calc = einx.get_at('i [m], i j -> i j', include_in_free_calc, include_indices)
5802+
atom_rel_pos = einx.get_at('i [m] c, i j -> i j c', atom_rel_pos, include_indices)
5803+
target_atom_radii_sq = einx.get_at('[m], i j -> i j', atom_radii_sq, include_indices)
5804+
57655805
# overall logic
57665806

57675807
surface_dots = einx.multiply('m, sd c -> m sd c', atom_radii, unit_surface_dots)
57685808

57695809
dist_from_surface_dots_sq = einx.subtract('i j c, i sd c -> i sd j c', atom_rel_pos, surface_dots).pow(2).sum(dim = -1)
57705810

5771-
target_atom_close_to_surface_dots = einx.less('j, i sd j -> i sd j', atom_radii_sq, dist_from_surface_dots_sq)
5811+
target_atom_close_to_surface_dots = einx.less('i j, i sd j -> i sd j', target_atom_radii_sq, dist_from_surface_dots_sq)
57725812

57735813
target_atom_close_or_not_included = einx.logical_or('i sd j, i j -> i sd j', target_atom_close_to_surface_dots, ~include_in_free_calc)
57745814

@@ -5831,28 +5871,10 @@ def _inhouse_compute_unresolved_rasa(
58315871
chain_atom_mask,
58325872
)
58335873

5834-
# use the structure as source of truth, matching what xluo did
5835-
5836-
structure_atom_pos = []
5837-
structure_atom_type_for_radii = []
5838-
side_atom_index = len(self.atom_type_index)
5839-
5840-
for atom in structure.get_atoms():
5841-
5842-
one_atom_pos = list(atom.get_vector())
5843-
one_atom_type = self.atom_type_index.get(atom.name, side_atom_index)
5844-
5845-
structure_atom_pos.append(one_atom_pos)
5846-
structure_atom_type_for_radii.append(one_atom_type)
5847-
5848-
structure_atom_pos: Float[' m 3'] = tensor(structure_atom_pos)
5849-
structure_atom_type_for_radii: Int[' m'] = tensor(structure_atom_type_for_radii)
5850-
58515874
# per atom rsa calculation
58525875

5853-
per_atom_access_surface_score = self.calc_atom_access_surface_score(
5854-
structure_atom_pos,
5855-
structure_atom_type_for_radii,
5876+
per_atom_access_surface_score = self.calc_atom_access_surface_score_from_structure(
5877+
structure,
58565878
**rsa_calc_kwargs
58575879
)
58585880

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.48"
3+
version = "0.5.49"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)