|
1 | 1 | """Test MolFilter, which invalidate molecules based on criteria defined in the respective filter.""" |
2 | 2 |
|
| 3 | +import json |
| 4 | +import tempfile |
3 | 5 | import unittest |
| 6 | +from pathlib import Path |
4 | 7 |
|
5 | 8 | from molpipeline import ErrorFilter, FilterReinserter, Pipeline |
6 | 9 | from molpipeline.any2mol import SmilesToMol |
|
14 | 17 | SmartsFilter, |
15 | 18 | SmilesFilter, |
16 | 19 | ) |
| 20 | +from molpipeline.utils.comparison import compare_recursive |
17 | 21 | from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json |
18 | 22 | from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange |
19 | 23 |
|
@@ -88,6 +92,34 @@ def test_element_filter(self) -> None: |
88 | 92 | filtered_smiles = pipeline.fit_transform(SMILES_LIST) |
89 | 93 | self.assertEqual(filtered_smiles, test_params["result"]) |
90 | 94 |
|
| 95 | + def test_json_roundtrip(self) -> None: |
| 96 | + """Test if ElementFilter can be serialized and deserialized. |
| 97 | +
|
| 98 | + Notes |
| 99 | + ----- |
| 100 | + It is important to save the ElementFilter as a JSON file and then load it back. |
| 101 | + This is because json.dumps() sets the keys of the dictionary to strings. |
| 102 | + """ |
| 103 | + element_filter = ElementFilter() |
| 104 | + json_object = recursive_to_json(element_filter) |
| 105 | + with tempfile.TemporaryDirectory() as temp_folder: |
| 106 | + temp_file_path = Path(temp_folder) / "test.json" |
| 107 | + with open(temp_file_path, "w", encoding="UTF-8") as out_file: |
| 108 | + json.dump(json_object, out_file) |
| 109 | + with open(temp_file_path, "r", encoding="UTF-8") as in_file: |
| 110 | + loaded_json_object = json.load(in_file) |
| 111 | + recreated_element_filter = recursive_from_json(loaded_json_object) |
| 112 | + |
| 113 | + original_params = element_filter.get_params() |
| 114 | + recreated_params = recreated_element_filter.get_params() |
| 115 | + self.assertEqual(original_params.keys(), recreated_params.keys()) |
| 116 | + for param_name, original_value in original_params.items(): |
| 117 | + with self.subTest(param_name=param_name): |
| 118 | + self.assertTrue( |
| 119 | + compare_recursive(original_value, recreated_params[param_name]), |
| 120 | + f"Original: {original_value}, Recreated: {recreated_params[param_name]}", |
| 121 | + ) |
| 122 | + |
91 | 123 |
|
92 | 124 | class ComplexFilterTest(unittest.TestCase): |
93 | 125 | """Unittest for ComplexFilter.""" |
|
0 commit comments