|
| 1 | +"""Abstract classes for filters.""" |
| 2 | + |
| 3 | +import abc |
| 4 | +from typing import Any, Literal, Mapping, Optional, Sequence, TypeAlias, Union |
| 5 | + |
| 6 | +try: |
| 7 | + from typing import Self # type: ignore[attr-defined] |
| 8 | +except ImportError: |
| 9 | + from typing_extensions import Self |
| 10 | + |
| 11 | +from molpipeline.abstract_pipeline_elements.core import ( |
| 12 | + InvalidInstance, |
| 13 | + MolToMolPipelineElement, |
| 14 | + OptionalMol, |
| 15 | + RDKitMol, |
| 16 | +) |
| 17 | +from molpipeline.utils.molpipeline_types import ( |
| 18 | + FloatCountRange, |
| 19 | + IntCountRange, |
| 20 | + IntOrIntCountRange, |
| 21 | +) |
| 22 | +from molpipeline.utils.value_conversions import count_value_to_tuple |
| 23 | + |
# Matching modes accepted by BaseKeepMatchesFilter (see its ``mode`` attribute):
# - "any": a single matching filter element is enough.
# - "all": every filter element must match.
FilterModeType: TypeAlias = Literal["any", "all"]
| 28 | + |
| 29 | + |
| 30 | +def _within_boundaries( |
| 31 | + lower_bound: Optional[float], upper_bound: Optional[float], property_value: float |
| 32 | +) -> bool: |
| 33 | + """Check if a value is within the specified boundaries. |
| 34 | +
|
| 35 | + Boundaries given as None are ignored. |
| 36 | +
|
| 37 | + Parameters |
| 38 | + ---------- |
| 39 | + lower_bound: Optional[float] |
| 40 | + Lower boundary. |
| 41 | + upper_bound: Optional[float] |
| 42 | + Upper boundary. |
| 43 | + property_value: float |
| 44 | + Property value to check. |
| 45 | +
|
| 46 | + Returns |
| 47 | + ------- |
| 48 | + bool |
| 49 | + True if the value is within the boundaries, else False. |
| 50 | + """ |
| 51 | + if lower_bound is not None and property_value < lower_bound: |
| 52 | + return False |
| 53 | + if upper_bound is not None and property_value > upper_bound: |
| 54 | + return False |
| 55 | + return True |
| 56 | + |
| 57 | + |
class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC):
    """Filter to keep or remove molecules based on filter elements.

    Every filter element is associated with a (lower, upper) range. An element
    counts as matched when the value returned by ``_calculate_single_element_value``
    lies within that range (bounds inclusive, a bound of None is not checked).

    Notes
    -----
    There are four possible scenarios:
        - mode = "any" & keep_matches = True: Needs to match at least one filter element.
        - mode = "any" & keep_matches = False: Must not match any filter element.
        - mode = "all" & keep_matches = True: Needs to match all filter elements.
        - mode = "all" & keep_matches = False: Must not match all filter elements.
    """

    keep_matches: bool
    mode: FilterModeType

    def __init__(
        self,
        filter_elements: Union[
            Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]],
            Sequence[Any],
        ],
        keep_matches: bool = True,
        mode: FilterModeType = "any",
        name: Optional[str] = None,
        n_jobs: int = 1,
        uuid: Optional[str] = None,
    ) -> None:
        """Initialize BaseKeepMatchesFilter.

        Parameters
        ----------
        filter_elements: Union[Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]], Sequence[Any]]
            List of filter elements. Typically can be a list of patterns or a dictionary with patterns as keys and
            an int for exact count or a tuple of minimum and maximum.
            NOTE: for each child class, the type of filter_elements must be specified by the filter_elements setter.
        keep_matches: bool, optional (default: True)
            If True, molecules matching the filter elements are kept, else removed.
        mode: FilterModeType, optional (default: "any")
            If "any", at least one of the filter elements must match the molecule.
            If "all", all of the filter elements must match the molecule.
        name: Optional[str], optional (default: None)
            Name of the pipeline element.
        n_jobs: int, optional (default: 1)
            Number of parallel jobs to use.
        uuid: str, optional (default: None)
            Unique identifier of the pipeline element.
        """
        super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
        # Goes through the (subclass-defined) property setter, which normalizes the input.
        self.filter_elements = filter_elements  # type: ignore
        self.keep_matches = keep_matches
        self.mode = mode

    @property
    @abc.abstractmethod
    def filter_elements(
        self,
    ) -> Mapping[Any, FloatCountRange]:
        """Get filter elements as dict."""

    @filter_elements.setter
    @abc.abstractmethod
    def filter_elements(
        self,
        filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]],
    ) -> None:
        """Set filter elements as dict.

        Parameters
        ----------
        filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]]
            List of filter elements.
        """

    def set_params(self, **parameters: Any) -> Self:
        """Set parameters of BaseKeepMatchesFilter.

        ``keep_matches``, ``mode`` and ``filter_elements`` are handled here; all
        remaining parameters are forwarded to the parent class.

        Parameters
        ----------
        parameters: Any
            Parameters to set.

        Returns
        -------
        Self
            Self.
        """
        # Work on a copy so the caller's dictionary is not modified by pop().
        parameter_copy = dict(parameters)
        if "keep_matches" in parameter_copy:
            self.keep_matches = parameter_copy.pop("keep_matches")
        if "mode" in parameter_copy:
            self.mode = parameter_copy.pop("mode")
        if "filter_elements" in parameter_copy:
            self.filter_elements = parameter_copy.pop("filter_elements")
        super().set_params(**parameter_copy)
        return self

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Get parameters of BaseKeepMatchesFilter.

        Parameters
        ----------
        deep: bool, optional (default: True)
            If True, ``filter_elements`` is returned as a new dict instead of the
            mapping held by this object.

        Returns
        -------
        dict[str, Any]
            Parameters of BaseKeepMatchesFilter.
        """
        params = super().get_params(deep=deep)
        params["keep_matches"] = self.keep_matches
        params["mode"] = self.mode
        if deep:
            # Hand out a copy so that mutating the returned parameters cannot
            # silently alter the configuration of this filter.
            params["filter_elements"] = dict(self.filter_elements)
        else:
            params["filter_elements"] = self.filter_elements
        return params

    def pretransform_single(self, value: RDKitMol) -> OptionalMol:
        """Invalidate or validate molecule based on specified filter.

        There are four possible scenarios:
            - mode = "any" & keep_matches = True: Needs to match at least one filter element.
            - mode = "any" & keep_matches = False: Must not match any filter element.
            - mode = "all" & keep_matches = True: Needs to match all filter elements.
            - mode = "all" & keep_matches = False: Must not match all filter elements.

        Parameters
        ----------
        value: RDKitMol
            Molecule to check.

        Raises
        ------
        ValueError
            If mode is neither "any" nor "all".

        Returns
        -------
        OptionalMol
            Molecule that matches defined filter elements, else InvalidInstance.
        """
        # Fail fast: no filter element needs to be evaluated for an unusable mode.
        if self.mode not in ("any", "all"):
            raise ValueError(f"Invalid mode: {self.mode}")
        any_mode = self.mode == "any"

        for filter_element, (lower_limit, upper_limit) in self.filter_elements.items():
            property_value = self._calculate_single_element_value(filter_element, value)
            is_match = _within_boundaries(lower_limit, upper_limit, property_value)

            if any_mode and is_match:
                # In "any" mode the first match decides the outcome.
                if self.keep_matches:
                    return value
                return InvalidInstance(
                    self.uuid,
                    f"Molecule contains forbidden filter element {filter_element}.",
                    self.name,
                )

            if not any_mode and not is_match:
                # In "all" mode the first miss decides the outcome.
                if not self.keep_matches:
                    return value
                return InvalidInstance(
                    self.uuid,
                    f"Molecule does not contain required filter element {filter_element}.",
                    self.name,
                )

        # Loop exhausted: in "any" mode nothing matched, in "all" mode everything matched.
        if any_mode:
            if self.keep_matches:
                return InvalidInstance(
                    self.uuid,
                    "Molecule does not match any of the required filter elements.",
                    self.name,
                )
            # No match with forbidden filter elements was found, return original molecule.
            return value

        if not self.keep_matches:
            return InvalidInstance(
                self.uuid,
                "Molecule matches all forbidden filter elements.",
                self.name,
            )
        # All required filter elements were found, return original molecule.
        return value

    @abc.abstractmethod
    def _calculate_single_element_value(
        self, filter_element: Any, value: RDKitMol
    ) -> float:
        """Calculate the value of a single match.

        Parameters
        ----------
        filter_element: Any
            Match case to calculate.
        value: RDKitMol
            Molecule to calculate the match for.

        Returns
        -------
        float
            Value of the match.
        """
| 257 | + |
| 258 | + |
class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC):
    """Filter to keep or remove molecules based on patterns.

    Attributes
    ----------
    filter_elements: Union[Sequence[str], Mapping[str, IntOrIntCountRange]]
        List of patterns to allow in molecules.
        Alternatively, a dictionary can be passed with patterns as keys
        and an int for exact count or a tuple of minimum and maximum.

    Notes
    -----
    There are four possible scenarios:
        - mode = "any" & keep_matches = True: Needs to match at least one filter element.
        - mode = "any" & keep_matches = False: Must not match any filter element.
        - mode = "all" & keep_matches = True: Needs to match all filter elements.
        - mode = "all" & keep_matches = False: Must not match all filter elements.
    """

    # Normalized filter elements: pattern -> (minimum count, maximum count).
    _filter_elements: Mapping[str, IntCountRange]
    # Molecule objects built from the patterns, keyed by the pattern string.
    _patterns_mol_dict: Mapping[str, RDKitMol]

    @property
    def filter_elements(self) -> Mapping[str, IntCountRange]:
        """Get allowed filter elements (patterns) as dict."""
        return self._filter_elements

    @filter_elements.setter
    def filter_elements(
        self,
        patterns: Union[Sequence[str], Mapping[str, IntOrIntCountRange]],
    ) -> None:
        """Set allowed filter elements (patterns) as dict.

        A plain collection of patterns (list, tuple, set or frozenset) is
        expanded to a count range of (1, None) per pattern. For a mapping, each
        count is converted with ``count_value_to_tuple``.

        Parameters
        ----------
        patterns: Union[Sequence[str], Mapping[str, IntOrIntCountRange]]
            List of patterns or mapping from pattern to count (range).

        Raises
        ------
        ValueError
            If at least one pattern cannot be converted to a molecule. In that
            case the previously set filter elements are left untouched.
        """
        new_elements: dict[str, IntCountRange]
        if isinstance(patterns, (list, tuple, set, frozenset)):
            new_elements = {pat: (1, None) for pat in patterns}
        else:
            new_elements = {
                pat: count_value_to_tuple(count) for pat, count in patterns.items()
            }
        # Convert (and thereby validate) the patterns first: if this raises,
        # neither the filter elements nor the molecule dict have been changed.
        self.patterns_mol_dict = list(new_elements)  # type: ignore
        self._filter_elements = new_elements

    @property
    def patterns_mol_dict(self) -> Mapping[str, RDKitMol]:
        """Get patterns as dict with RDKitMol objects."""
        return self._patterns_mol_dict

    @patterns_mol_dict.setter
    def patterns_mol_dict(self, patterns: Sequence[str]) -> None:
        """Set patterns as dict with RDKitMol objects.

        Parameters
        ----------
        patterns: Sequence[str]
            List of patterns.

        Raises
        ------
        ValueError
            If ``_pattern_to_mol`` yields a falsy result for any pattern. The
            previously stored dict is kept in that case.
        """
        converted = {pat: self._pattern_to_mol(pat) for pat in patterns}
        failed_patterns = [pat for pat, mol in converted.items() if not mol]
        if failed_patterns:
            raise ValueError("Invalid pattern(s): " + ", ".join(failed_patterns))
        # Only commit once every pattern has been converted successfully.
        self._patterns_mol_dict = converted

    @abc.abstractmethod
    def _pattern_to_mol(self, pattern: str) -> RDKitMol:
        """Convert pattern to RDKitMol object.

        Parameters
        ----------
        pattern: str
            Pattern to convert.

        Returns
        -------
        RDKitMol
            RDKitMol object of the pattern.
        """

    def _calculate_single_element_value(
        self, filter_element: Any, value: RDKitMol
    ) -> int:
        """Calculate a single match count for a molecule.

        Parameters
        ----------
        filter_element: Any
            Pattern (key of ``patterns_mol_dict``) to calculate the match count for.
        value: RDKitMol
            Molecule to calculate the pattern match count for.

        Returns
        -------
        int
            Number of substructure matches of the pattern in the molecule.
        """
        pattern_mol = self.patterns_mol_dict[filter_element]
        return len(value.GetSubstructMatches(pattern_mol))
0 commit comments