Skip to content

Commit 4d7ff8d

Browse files
Filter update (#67)
* remove unnecessary inits and refactor * include smarts filter, smiles filter, descriptors filter * Fix wrong typing that caused thousands of type ignores * linting and fix element number test * reset name typing * Christians first review * more changes * linting * pylint * rewrite filter logic (#71) * Combine filters with one base logic * change dict to Mapping * isort * Include comments * linting * linting and ComplexFilter * typing, tests, complex filter naming * finalize filter refactoring * review Christian * pylint * include check for failed patterns in init * final review * final linting * final final linting * final final final linting --------- Co-authored-by: Christian W. Feldmann <[email protected]> Co-authored-by: Christian Feldmann <[email protected]>
1 parent 294e923 commit 4d7ff8d

File tree

7 files changed

+1139
-122
lines changed

7 files changed

+1139
-122
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Initialize the module for abstract mol2mol elements."""
2+
3+
from molpipeline.abstract_pipeline_elements.mol2mol.filter import (
4+
BaseKeepMatchesFilter,
5+
BasePatternsFilter,
6+
)
7+
8+
__all__ = ["BasePatternsFilter", "BaseKeepMatchesFilter"]
Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
"""Abstract classes for filters."""
2+
3+
import abc
4+
from typing import Any, Literal, Mapping, Optional, Sequence, TypeAlias, Union
5+
6+
try:
7+
from typing import Self # type: ignore[attr-defined]
8+
except ImportError:
9+
from typing_extensions import Self
10+
11+
from molpipeline.abstract_pipeline_elements.core import (
12+
InvalidInstance,
13+
MolToMolPipelineElement,
14+
OptionalMol,
15+
RDKitMol,
16+
)
17+
from molpipeline.utils.molpipeline_types import (
18+
FloatCountRange,
19+
IntCountRange,
20+
IntOrIntCountRange,
21+
)
22+
from molpipeline.utils.value_conversions import count_value_to_tuple
23+
24+
# possible mode types for a KeepMatchesFilter:
25+
# - "any" means one match is enough
26+
# - "all" means all elements must be matched
27+
FilterModeType: TypeAlias = Literal["any", "all"]
28+
29+
30+
def _within_boundaries(
31+
lower_bound: Optional[float], upper_bound: Optional[float], property_value: float
32+
) -> bool:
33+
"""Check if a value is within the specified boundaries.
34+
35+
Boundaries given as None are ignored.
36+
37+
Parameters
38+
----------
39+
lower_bound: Optional[float]
40+
Lower boundary.
41+
upper_bound: Optional[float]
42+
Upper boundary.
43+
property_value: float
44+
Property value to check.
45+
46+
Returns
47+
-------
48+
bool
49+
True if the value is within the boundaries, else False.
50+
"""
51+
if lower_bound is not None and property_value < lower_bound:
52+
return False
53+
if upper_bound is not None and property_value > upper_bound:
54+
return False
55+
return True
56+
57+
58+
class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC):
59+
"""Filter to keep or remove molecules based on patterns.
60+
61+
Notes
62+
-----
63+
There are four possible scenarios:
64+
- mode = "any" & keep_matches = True: Needs to match at least one filter element.
65+
- mode = "any" & keep_matches = False: Must not match any filter element.
66+
- mode = "all" & keep_matches = True: Needs to match all filter elements.
67+
- mode = "all" & keep_matches = False: Must not match all filter elements.
68+
"""
69+
70+
keep_matches: bool
71+
mode: FilterModeType
72+
73+
def __init__(
74+
self,
75+
filter_elements: Union[
76+
Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]],
77+
Sequence[Any],
78+
],
79+
keep_matches: bool = True,
80+
mode: FilterModeType = "any",
81+
name: Optional[str] = None,
82+
n_jobs: int = 1,
83+
uuid: Optional[str] = None,
84+
) -> None:
85+
"""Initialize BasePatternsFilter.
86+
87+
Parameters
88+
----------
89+
filter_elements: Union[Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]], Sequence[Any]]
90+
List of filter elements. Typically can be a list of patterns or a dictionary with patterns as keys and
91+
an int for exact count or a tuple of minimum and maximum.
92+
NOTE: for each child class, the type of filter_elements must be specified by the filter_elements setter.
93+
keep_matches: bool, optional (default: True)
94+
If True, molecules containing the specified patterns are kept, else removed.
95+
mode: FilterModeType, optional (default: "any")
96+
If "any", at least one of the specified patterns must be present in the molecule.
97+
If "all", all of the specified patterns must be present in the molecule.
98+
name: Optional[str], optional (default: None)
99+
Name of the pipeline element.
100+
n_jobs: int, optional (default: 1)
101+
Number of parallel jobs to use.
102+
uuid: str, optional (default: None)
103+
Unique identifier of the pipeline element.
104+
"""
105+
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
106+
self.filter_elements = filter_elements # type: ignore
107+
self.keep_matches = keep_matches
108+
self.mode = mode
109+
110+
@property
111+
@abc.abstractmethod
112+
def filter_elements(
113+
self,
114+
) -> Mapping[Any, FloatCountRange]:
115+
"""Get filter elements as dict."""
116+
117+
@filter_elements.setter
118+
@abc.abstractmethod
119+
def filter_elements(
120+
self,
121+
filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]],
122+
) -> None:
123+
"""Set filter elements as dict.
124+
125+
Parameters
126+
----------
127+
filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]]
128+
List of filter elements.
129+
"""
130+
131+
def set_params(self, **parameters: Any) -> Self:
132+
"""Set parameters of BaseKeepMatchesFilter.
133+
134+
Parameters
135+
----------
136+
parameters: Any
137+
Parameters to set.
138+
139+
Returns
140+
-------
141+
Self
142+
Self.
143+
"""
144+
parameter_copy = dict(parameters)
145+
if "keep_matches" in parameter_copy:
146+
self.keep_matches = parameter_copy.pop("keep_matches")
147+
if "mode" in parameter_copy:
148+
self.mode = parameter_copy.pop("mode")
149+
if "filter_elements" in parameter_copy:
150+
self.filter_elements = parameter_copy.pop("filter_elements")
151+
super().set_params(**parameter_copy)
152+
return self
153+
154+
def get_params(self, deep: bool = True) -> dict[str, Any]:
155+
"""Get parameters of PatternFilter.
156+
157+
Parameters
158+
----------
159+
deep: bool, optional (default: True)
160+
If True, return the parameters of all subobjects that are PipelineElements.
161+
162+
Returns
163+
-------
164+
dict[str, Any]
165+
Parameters of BaseKeepMatchesFilter.
166+
"""
167+
params = super().get_params(deep=deep)
168+
params["keep_matches"] = self.keep_matches
169+
params["mode"] = self.mode
170+
params["filter_elements"] = self.filter_elements
171+
return params
172+
173+
def pretransform_single(self, value: RDKitMol) -> OptionalMol:
174+
"""Invalidate or validate molecule based on specified filter.
175+
176+
There are four possible scenarios:
177+
- mode = "any" & keep_matches = True: Needs to match at least one filter element.
178+
- mode = "any" & keep_matches = False: Must not match any filter element.
179+
- mode = "all" & keep_matches = True: Needs to match all filter elements.
180+
- mode = "all" & keep_matches = False: Must not match all filter elements.
181+
182+
Parameters
183+
----------
184+
value: RDKitMol
185+
Molecule to check.
186+
187+
Returns
188+
-------
189+
OptionalMol
190+
Molecule that matches defined filter elements, else InvalidInstance.
191+
"""
192+
for filter_element, (lower_limit, upper_limit) in self.filter_elements.items():
193+
property_value = self._calculate_single_element_value(filter_element, value)
194+
if _within_boundaries(lower_limit, upper_limit, property_value):
195+
# For "any" mode we can return early if a match is found
196+
if self.mode == "any":
197+
if not self.keep_matches:
198+
value = InvalidInstance(
199+
self.uuid,
200+
f"Molecule contains forbidden filter element {filter_element}.",
201+
self.name,
202+
)
203+
return value
204+
else:
205+
# For "all" mode we can return early if a match is not found
206+
if self.mode == "all":
207+
if self.keep_matches:
208+
value = InvalidInstance(
209+
self.uuid,
210+
f"Molecule does not contain required filter element {filter_element}.",
211+
self.name,
212+
)
213+
return value
214+
215+
# If this point is reached, no or all patterns were found
216+
# If mode is "any", finishing the loop means no match was found
217+
if self.mode == "any":
218+
if self.keep_matches:
219+
value = InvalidInstance(
220+
self.uuid,
221+
"Molecule does not match any of the required filter elements.",
222+
self.name,
223+
)
224+
# else: No match with forbidden filter elements was found, return original molecule
225+
return value
226+
227+
if self.mode == "all":
228+
if not self.keep_matches:
229+
value = InvalidInstance(
230+
self.uuid,
231+
"Molecule matches all forbidden filter elements.",
232+
self.name,
233+
)
234+
# else: All required filter elements were found, return original molecule
235+
return value
236+
237+
raise ValueError(f"Invalid mode: {self.mode}")
238+
239+
@abc.abstractmethod
240+
def _calculate_single_element_value(
241+
self, filter_element: Any, value: RDKitMol
242+
) -> float:
243+
"""Calculate the value of a single match.
244+
245+
Parameters
246+
----------
247+
filter_element: Any
248+
Match case to calculate.
249+
value: RDKitMol
250+
Molecule to calculate the match for.
251+
252+
Returns
253+
-------
254+
float
255+
Value of the match.
256+
"""
257+
258+
259+
class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC):
260+
"""Filter to keep or remove molecules based on patterns.
261+
262+
Attributes
263+
----------
264+
filter_elements: Union[Sequence[str], Mapping[str, IntOrIntCountRange]]
265+
List of patterns to allow in molecules.
266+
Alternatively, a dictionary can be passed with patterns as keys
267+
and an int for exact count or a tuple of minimum and maximum.
268+
[...]
269+
270+
Notes
271+
-----
272+
There are four possible scenarios:
273+
- mode = "any" & keep_matches = True: Needs to match at least one filter element.
274+
- mode = "any" & keep_matches = False: Must not match any filter element.
275+
- mode = "all" & keep_matches = True: Needs to match all filter elements.
276+
- mode = "all" & keep_matches = False: Must not match all filter elements.
277+
"""
278+
279+
_filter_elements: Mapping[str, IntCountRange]
280+
281+
@property
282+
def filter_elements(self) -> Mapping[str, IntCountRange]:
283+
"""Get allowed filter elements (patterns) as dict."""
284+
return self._filter_elements
285+
286+
@filter_elements.setter
287+
def filter_elements(
288+
self,
289+
patterns: Union[list[str], Mapping[str, IntOrIntCountRange]],
290+
) -> None:
291+
"""Set allowed filter elements (patterns) as dict.
292+
293+
Parameters
294+
----------
295+
patterns: Union[list[str], Mapping[str, IntOrIntCountRange]]
296+
List of patterns.
297+
"""
298+
if isinstance(patterns, (list, set)):
299+
self._filter_elements = {pat: (1, None) for pat in patterns}
300+
else:
301+
self._filter_elements = {
302+
pat: count_value_to_tuple(count) for pat, count in patterns.items()
303+
}
304+
self.patterns_mol_dict = list(self._filter_elements.keys()) # type: ignore
305+
306+
@property
307+
def patterns_mol_dict(self) -> Mapping[str, RDKitMol]:
308+
"""Get patterns as dict with RDKitMol objects."""
309+
return self._patterns_mol_dict
310+
311+
@patterns_mol_dict.setter
312+
def patterns_mol_dict(self, patterns: Sequence[str]) -> None:
313+
"""Set patterns as dict with RDKitMol objects.
314+
315+
Parameters
316+
----------
317+
patterns: Sequence[str]
318+
List of patterns.
319+
"""
320+
self._patterns_mol_dict = {pat: self._pattern_to_mol(pat) for pat in patterns}
321+
failed_patterns = [
322+
pat for pat, mol in self._patterns_mol_dict.items() if not mol
323+
]
324+
if failed_patterns:
325+
raise ValueError("Invalid pattern(s): " + ", ".join(failed_patterns))
326+
327+
@abc.abstractmethod
328+
def _pattern_to_mol(self, pattern: str) -> RDKitMol:
329+
"""Convert pattern to Rdkitmol object.
330+
331+
Parameters
332+
----------
333+
pattern: str
334+
Pattern to convert.
335+
336+
Returns
337+
-------
338+
RDKitMol
339+
RDKitMol object of the pattern.
340+
"""
341+
342+
def _calculate_single_element_value(
343+
self, filter_element: Any, value: RDKitMol
344+
) -> int:
345+
"""Calculate a single match count for a molecule.
346+
347+
Parameters
348+
----------
349+
filter_element: Any
350+
smarts to calculate match count for.
351+
value: RDKitMol
352+
Molecule to calculate smarts match count for.
353+
354+
Returns
355+
-------
356+
int
357+
smarts match count value.
358+
"""
359+
return len(value.GetSubstructMatches(self.patterns_mol_dict[filter_element]))

molpipeline/mol2mol/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""Init the module for mol2mol pipeline elements."""
22

33
from molpipeline.mol2mol.filter import (
4+
ComplexFilter,
45
ElementFilter,
56
EmptyMoleculeFilter,
67
InorganicsFilter,
78
MixtureFilter,
9+
RDKitDescriptorsFilter,
10+
SmartsFilter,
11+
SmilesFilter,
812
)
913
from molpipeline.mol2mol.reaction import MolToMolReaction
1014
from molpipeline.mol2mol.scaffolds import MakeScaffoldGeneric, MurckoScaffold
@@ -41,4 +45,8 @@
4145
"SolventRemover",
4246
"Uncharger",
4347
"InorganicsFilter",
48+
"SmartsFilter",
49+
"SmilesFilter",
50+
"RDKitDescriptorsFilter",
51+
"ComplexFilter",
4452
)

0 commit comments

Comments
 (0)