Skip to content

Commit 51c36e4

Browse files
committed
Fix serialization bug and packaging issues for macOS/Windows
This commit addresses two critical issues blocking CI and adds regression tests:

1. Serialization bug (GitHub issue comment):
   - Added idx2exact to the serialize(), save(), and load() methods.
   - This fixes the loss of correctness after save/load for trees built from float64 input.
   - The idx2exact map stores the double-precision coordinates used for refinement.
   - Without this fix, trees lose precision after a save/load cycle.

2. Packaging issue (macOS/Windows test failures):
   - Added CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE for Windows.
   - Added CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE for macOS.
   - Fixes ModuleNotFoundError on macOS/Windows caused by multi-config generators.
   - Windows treats .pyd files as RUNTIME artifacts; macOS needs config-specific output directories.

3. Comprehensive regression tests:
   - test_save_load_float64_matteo_case: tests the Matteo bug case with float64 input.
   - test_save_load_float32_no_regression: ensures the float32 path still works.
   - Both tests verify that correctness survives save/load cycles.
   - Both tests call gc.collect() explicitly to avoid file-locking problems on Windows.

All 123 tests pass locally (121 original + 2 new regression tests).

Related to PR #45
1 parent beb0131 commit 51c36e4

File tree

3 files changed

+103
-4
lines changed

3 files changed

+103
-4
lines changed

cpp/prtree.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ class PRTree
756756
template <class Archive>
757757
void serialize(Archive &archive)
758758
{
759-
archive(flat_tree, idx2bb, idx2data, global_idx, n_at_build);
759+
archive(flat_tree, idx2bb, idx2data, global_idx, n_at_build, idx2exact);
760760
}
761761

762762
void save(std::string fname)
@@ -769,7 +769,8 @@ class PRTree
769769
cereal::make_nvp("idx2bb", idx2bb),
770770
cereal::make_nvp("idx2data", idx2data),
771771
cereal::make_nvp("global_idx", global_idx),
772-
cereal::make_nvp("n_at_build", n_at_build));
772+
cereal::make_nvp("n_at_build", n_at_build),
773+
cereal::make_nvp("idx2exact", idx2exact));
773774
}
774775
}
775776
}
@@ -784,7 +785,8 @@ class PRTree
784785
cereal::make_nvp("idx2bb", idx2bb),
785786
cereal::make_nvp("idx2data", idx2data),
786787
cereal::make_nvp("global_idx", global_idx),
787-
cereal::make_nvp("n_at_build", n_at_build));
788+
cereal::make_nvp("n_at_build", n_at_build),
789+
cereal::make_nvp("idx2exact", idx2exact));
788790
}
789791
}
790792
}

setup.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,18 @@ def build_extension(self, ext):
6565

6666
if platform.system() == "Windows":
6767
cmake_args += [
68-
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)
68+
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir),
69+
"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)
6970
]
7071
if sys.maxsize > 2**32:
7172
cmake_args += ["-A", "x64"]
7273
build_args += ["--", "/m"]
74+
elif platform.system() == "Darwin":
75+
cmake_args += [
76+
"-DCMAKE_BUILD_TYPE=" + cfg,
77+
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)
78+
]
79+
build_args += ["--", "-j" + str(cpu_count())]
7380
else:
7481
cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
7582
build_args += ["--", "-j" + str(cpu_count())]

tests/test_PRTree.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,93 @@ def test_query_vs_batch_query_consistency(PRTree, dim):
291291
single_result = tree.query(query)
292292
assert set(batch_results[i]) == set(single_result), \
293293
f"Query {i}: batch_query returned {batch_results[i]}, query returned {single_result}"
294+
295+
296+
def test_save_load_float64_matteo_case(tmp_path):
297+
"""Regression test: ensure idx2exact survives save/load for float64 input.
298+
299+
This tests the fix for the serialization bug where idx2exact was not being
300+
archived, causing trees built from float64 input to lose correctness after
301+
save/load. The Matteo bug case has boxes separated by ~5.4e-6, which requires
302+
double-precision refinement to correctly identify as disjoint.
303+
"""
304+
import gc
305+
306+
A = np.array([[72.47410062, 80.52848893, 54.68197159, 75.02750896, 85.40646976, 62.42859506]], dtype=np.float64)
307+
B = np.array([[75.02751435, 74.65699325, 61.09751679, 78.71358218, 82.4585436, 67.24904609]], dtype=np.float64)
308+
309+
assert A[0][3] < B[0][0], "Test setup error: boxes should be disjoint"
310+
gap = B[0][0] - A[0][3]
311+
assert 5e-6 < gap < 6e-6, f"Test setup error: expected gap ~5.4e-6, got {gap}"
312+
313+
tree = PRTree3D(np.array([0], dtype=np.int64), A)
314+
315+
result_before = tree.batch_query(B)
316+
assert result_before == [[]], f"Before save: Expected [[]] (disjoint), got {result_before}"
317+
318+
fname = tmp_path / "tree_float64.bin"
319+
fname = str(fname)
320+
tree.save(fname)
321+
322+
del tree
323+
gc.collect()
324+
325+
tree_loaded = PRTree3D(fname)
326+
327+
result_after = tree_loaded.batch_query(B)
328+
assert result_after == [[]], f"After load: Expected [[]] (disjoint), got {result_after}"
329+
330+
np.random.seed(42)
331+
queries = np.random.rand(10, 6).astype(np.float64) * 100
332+
for i in range(3):
333+
queries[:, i + 3] += queries[:, i] + 1e-5 # Small gaps
334+
335+
results_before_save = tree_loaded.batch_query(queries)
336+
337+
fname2 = tmp_path / "tree_float64_2.bin"
338+
fname2 = str(fname2)
339+
tree_loaded.save(fname2)
340+
del tree_loaded
341+
gc.collect()
342+
343+
tree_loaded2 = PRTree3D(fname2)
344+
results_after_save = tree_loaded2.batch_query(queries)
345+
346+
assert results_before_save == results_after_save, \
347+
"Random queries: results changed after save/load cycle"
348+
349+
350+
def test_save_load_float32_no_regression(tmp_path):
351+
"""Regression test: ensure float32 path still works correctly after save/load.
352+
353+
This tests that the serialization fix (adding idx2exact to archive) doesn't
354+
break the float32 path, which doesn't use idx2exact.
355+
"""
356+
import gc
357+
358+
np.random.seed(42)
359+
N = 100
360+
idx = np.arange(N, dtype=np.int64)
361+
x = np.random.rand(N, 6).astype(np.float32) * 100
362+
for i in range(3):
363+
x[:, i + 3] += x[:, i] + 1.0 # Ensure valid boxes
364+
365+
tree = PRTree3D(idx, x)
366+
367+
queries = np.random.rand(20, 6).astype(np.float32) * 100
368+
for i in range(3):
369+
queries[:, i + 3] += queries[:, i] + 1.0
370+
371+
results_before = tree.batch_query(queries)
372+
373+
fname = tmp_path / "tree_float32.bin"
374+
fname = str(fname)
375+
tree.save(fname)
376+
del tree
377+
gc.collect()
378+
379+
tree_loaded = PRTree3D(fname)
380+
results_after = tree_loaded.batch_query(queries)
381+
382+
assert results_before == results_after, \
383+
"Float32 path: results changed after save/load cycle"

0 commit comments

Comments
 (0)