@@ -5707,6 +5707,36 @@ def _compute_unresolved_rasa(
57075707
57085708 return unresolved_rasa .mean ()
57095709
5710+ @typecheck
5711+ def calc_atom_access_surface_score_from_structure (
5712+ self ,
5713+ structure : Structure ,
5714+ ** kwargs
5715+ ) -> Float ['m' ]:
5716+
5717+ # use the structure as source of truth, matching what xluo did
5718+
5719+ structure_atom_pos = []
5720+ structure_atom_type_for_radii = []
5721+ side_atom_index = len (self .atom_type_index )
5722+
5723+ for atom in structure .get_atoms ():
5724+
5725+ one_atom_pos = list (atom .get_vector ())
5726+ one_atom_type = self .atom_type_index .get (atom .name , side_atom_index )
5727+
5728+ structure_atom_pos .append (one_atom_pos )
5729+ structure_atom_type_for_radii .append (one_atom_type )
5730+
5731+ structure_atom_pos : Float [' m 3' ] = tensor (structure_atom_pos )
5732+ structure_atom_type_for_radii : Int [' m' ] = tensor (structure_atom_type_for_radii )
5733+
5734+ return self .calc_atom_access_surface_score (
5735+ atom_pos = structure_atom_pos ,
5736+ atom_type = structure_atom_type_for_radii ,
5737+ ** kwargs
5738+ )
5739+
57105740 @typecheck
57115741 def calc_atom_access_surface_score (
57125742 self ,
@@ -5749,7 +5779,7 @@ def calc_atom_access_surface_score(
57495779 lat .sin ()
57505780 ), dim = - 1 )
57515781
5752- # first get atom relative positions + distance
5782+ # get atom relative positions + distance
57535783 # for determining whether to include pairs of atom in calculation for the `free` adjective
57545784
57555785 atom_rel_pos = einx .subtract ('j c, i c -> i j c' , atom_pos , atom_pos )
@@ -5762,13 +5792,23 @@ def calc_atom_access_surface_score(
57625792 (atom_rel_dist_sq > atom_distance_min_thres )
57635793 )
57645794
5795+ # max included in calculation per row
5796+
5797+ max_included = include_in_free_calc .long ().sum (dim = - 1 ).amax ()
5798+
5799+ include_indices = include_in_free_calc .long ().topk (max_included , dim = - 1 ).indices
5800+
5801+ include_in_free_calc = einx .get_at ('i [m], i j -> i j' , include_in_free_calc , include_indices )
5802+ atom_rel_pos = einx .get_at ('i [m] c, i j -> i j c' , atom_rel_pos , include_indices )
5803+ target_atom_radii_sq = einx .get_at ('[m], i j -> i j' , atom_radii_sq , include_indices )
5804+
57655805 # overall logic
57665806
57675807 surface_dots = einx .multiply ('m, sd c -> m sd c' , atom_radii , unit_surface_dots )
57685808
57695809 dist_from_surface_dots_sq = einx .subtract ('i j c, i sd c -> i sd j c' , atom_rel_pos , surface_dots ).pow (2 ).sum (dim = - 1 )
57705810
5771- target_atom_close_to_surface_dots = einx .less ('j, i sd j -> i sd j' , atom_radii_sq , dist_from_surface_dots_sq )
5811+ target_atom_close_to_surface_dots = einx .less ('i j, i sd j -> i sd j' , target_atom_radii_sq , dist_from_surface_dots_sq )
57725812
57735813 target_atom_close_or_not_included = einx .logical_or ('i sd j, i j -> i sd j' , target_atom_close_to_surface_dots , ~ include_in_free_calc )
57745814
@@ -5831,28 +5871,10 @@ def _inhouse_compute_unresolved_rasa(
58315871 chain_atom_mask ,
58325872 )
58335873
5834- # use the structure as source of truth, matching what xluo did
5835-
5836- structure_atom_pos = []
5837- structure_atom_type_for_radii = []
5838- side_atom_index = len (self .atom_type_index )
5839-
5840- for atom in structure .get_atoms ():
5841-
5842- one_atom_pos = list (atom .get_vector ())
5843- one_atom_type = self .atom_type_index .get (atom .name , side_atom_index )
5844-
5845- structure_atom_pos .append (one_atom_pos )
5846- structure_atom_type_for_radii .append (one_atom_type )
5847-
5848- structure_atom_pos : Float [' m 3' ] = tensor (structure_atom_pos )
5849- structure_atom_type_for_radii : Int [' m' ] = tensor (structure_atom_type_for_radii )
5850-
58515874 # per atom rsa calculation
58525875
5853- per_atom_access_surface_score = self .calc_atom_access_surface_score (
5854- structure_atom_pos ,
5855- structure_atom_type_for_radii ,
5876+ per_atom_access_surface_score = self .calc_atom_access_surface_score_from_structure (
5877+ structure ,
58565878 ** rsa_calc_kwargs
58575879 )
58585880
0 commit comments