Skip to content

Commit a2592f1

Browse files
committed
avoid get_at when calculating surface score
1 parent a04e0cc commit a2592f1

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5701,11 +5701,12 @@ def calc_atom_access_surface_score(
57015701

57025702
max_included = include_in_free_calc.long().sum(dim = -1).amax()
57035703

5704-
include_indices = include_in_free_calc.long().topk(max_included, dim = -1).indices
5704+
include_in_free_calc, include_indices = include_in_free_calc.long().topk(max_included, dim = -1)
57055705

5706-
include_in_free_calc = einx.get_at('i [m], i j -> i j', include_in_free_calc, include_indices)
5707-
atom_rel_pos = einx.get_at('i [m] c, i j -> i j c', atom_rel_pos, include_indices)
5708-
target_atom_radii_sq = einx.get_at('[m], i j -> i j', atom_radii_sq, include_indices)
5706+
# atom_rel_pos = einx.get_at('i [m] c, i j -> i j c', atom_rel_pos, include_indices)
5707+
5708+
atom_rel_pos = atom_rel_pos.gather(1, repeat(include_indices, 'i j -> i j c', c = 3))
5709+
target_atom_radii_sq = atom_radii_sq[include_indices]
57095710

57105711
# overall logic
57115712

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.5.52"
3+
version = "0.5.53"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)