Skip to content

Commit 0a5251d

Browse files
authored
Fix bug: ElementFilter: roundtrip from recursive_to_json to recursive_from_json messes up atom types (#139)
* Add test to recreate bug * cast atom number to int
1 parent 84d09fd commit 0a5251d

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

molpipeline/mol2mol/filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def allowed_element_numbers(
147147
}
148148
else:
149149
self._allowed_element_numbers = {
150-
atom_number: count_value_to_tuple(count)
150+
int(atom_number): count_value_to_tuple(count)
151151
for atom_number, count in allowed_element_numbers.items()
152152
}
153153

tests/test_elements/test_mol2mol/test_mol2mol_filter.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Test MolFilter, which invalidate molecules based on criteria defined in the respective filter."""
22

3+
import json
4+
import tempfile
35
import unittest
6+
from pathlib import Path
47

58
from molpipeline import ErrorFilter, FilterReinserter, Pipeline
69
from molpipeline.any2mol import SmilesToMol
@@ -14,6 +17,7 @@
1417
SmartsFilter,
1518
SmilesFilter,
1619
)
20+
from molpipeline.utils.comparison import compare_recursive
1721
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json
1822
from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange
1923

@@ -88,6 +92,34 @@ def test_element_filter(self) -> None:
8892
filtered_smiles = pipeline.fit_transform(SMILES_LIST)
8993
self.assertEqual(filtered_smiles, test_params["result"])
9094

95+
def test_json_roundtrip(self) -> None:
96+
"""Test if ElementFilter can be serialized and deserialized.
97+
98+
Notes
99+
-----
100+
It is important to save the ElementFilter as a JSON file and then load it back.
101+
This is because json.dumps() sets the keys of the dictionary to strings.
102+
"""
103+
element_filter = ElementFilter()
104+
json_object = recursive_to_json(element_filter)
105+
with tempfile.TemporaryDirectory() as temp_folder:
106+
temp_file_path = Path(temp_folder) / "test.json"
107+
with open(temp_file_path, "w", encoding="UTF-8") as out_file:
108+
json.dump(json_object, out_file)
109+
with open(temp_file_path, "r", encoding="UTF-8") as in_file:
110+
loaded_json_object = json.load(in_file)
111+
recreated_element_filter = recursive_from_json(loaded_json_object)
112+
113+
original_params = element_filter.get_params()
114+
recreated_params = recreated_element_filter.get_params()
115+
self.assertEqual(original_params.keys(), recreated_params.keys())
116+
for param_name, original_value in original_params.items():
117+
with self.subTest(param_name=param_name):
118+
self.assertTrue(
119+
compare_recursive(original_value, recreated_params[param_name]),
120+
f"Original: {original_value}, Recreated: {recreated_params[param_name]}",
121+
)
122+
91123

92124
class ComplexFilterTest(unittest.TestCase):
93125
"""Unittest for ComplexFilter."""

0 commit comments

Comments
 (0)