Skip to content

Commit 33c9a3d

Browse files
authored
[GH-2331] Geopandas: Document differences of sindex compared to gpd + sindex fixes (#2332)
1 parent e6e7c88 commit 33c9a3d

File tree

3 files changed

+100
-32
lines changed

3 files changed

+100
-32
lines changed

python/sedona/spark/geopandas/geoseries.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -751,9 +751,7 @@ def sindex(self) -> SpatialIndex:
751751
if geometry_column is None:
752752
raise ValueError("No geometry column found in GeoSeries")
753753
if self._sindex is None:
754-
self._sindex = SpatialIndex(
755-
self._internal.spark_frame, column_name=geometry_column
756-
)
754+
self._sindex = SpatialIndex(self)
757755
return self._sindex
758756

759757
@property

python/sedona/spark/geopandas/sindex.py

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,23 @@ def __init__(self, geometry, index_type="strtree", column_name=None):
3838
3939
Parameters
4040
----------
41-
geometry : np.array of Shapely geometries, PySparkDataFrame column, or PySparkDataFrame
41+
geometry : np.array of Shapely geometries, GeoSeries, or PySparkDataFrame
4242
index_type : str, default "strtree"
4343
The type of spatial index to use.
4444
column_name : str, optional
4545
The column name to extract geometry from if `geometry` is a PySparkDataFrame.
46+
47+
Note: query methods (ie. query, nearest, intersection) have different behaviors depending on how the index is constructed.
48+
When constructed from a np.array, the query methods return indices like original geopandas.
49+
When constructed from a GeoSeries or PySparkDataFrame, the query methods return geometries.
4650
"""
51+
from sedona.spark.geopandas import GeoSeries
52+
53+
if isinstance(geometry, GeoSeries):
54+
from sedona.spark.geopandas.geoseries import _get_series_col_name
55+
56+
column_name = _get_series_col_name(geometry)
57+
geometry = geometry._internal.spark_frame
4758

4859
if isinstance(geometry, np.ndarray):
4960
self.geometry = geometry
@@ -65,7 +76,7 @@ def __init__(self, geometry, index_type="strtree", column_name=None):
6576
self._build_spark_index(column_name)
6677
else:
6778
raise TypeError(
68-
"Invalid type for `geometry`. Expected np.array or PySparkDataFrame."
79+
"Invalid type for `geometry`. Expected np.array, GeoSeries, or PySparkDataFrame."
6980
)
7081

7182
def query(self, geometry: BaseGeometry, predicate: str = None, sort: bool = False):
@@ -82,12 +93,18 @@ def query(self, geometry: BaseGeometry, predicate: str = None, sort: bool = Fals
8293
sort : bool, optional, default False
8394
Whether to sort the results.
8495
96+
Note: query() has different behaviors depending on how the index is constructed.
97+
When constructed from a np.array, this method returns indices like original geopandas.
98+
When constructed from a GeoSeries or PySparkDataFrame, this method returns geometries.
99+
85100
Note: Unlike Geopandas, Sedona does not support geometry input of type np.array or GeoSeries.
101+
It is recommended to instead use GeoSeries.intersects directly.
86102
87103
Returns
88104
-------
89105
list
90-
List of indices of matching geometries.
106+
List of geometries if constructed from a GeoSeries or PySparkDataFrame.
107+
List of the corresponding indices if constructed from a np.array.
91108
"""
92109

93110
if not isinstance(geometry, BaseGeometry):
@@ -96,7 +113,7 @@ def query(self, geometry: BaseGeometry, predicate: str = None, sort: bool = Fals
96113
)
97114

98115
log_advice(
99-
"`query` returns local list of indices of matching geometries onto driver's memory. "
116+
"`query` returns a local list onto driver's memory. "
100117
"It should only be used if the resulting collection is expected to be small."
101118
)
102119

@@ -170,10 +187,15 @@ def nearest(
170187
171188
Note: Unlike Geopandas, Sedona does not support geometry input of type np.array or GeoSeries.
172189
190+
Note: nearest() has different behaviors depending on how the index is constructed.
191+
When constructed from a np.array, this method returns indices like original geopandas.
192+
When constructed from a GeoSeries or PySparkDataFrame, this method returns geometries.
193+
173194
Returns
174195
-------
175196
list or tuple
176-
List of indices of nearest geometries, optionally with distances.
197+
List of geometries if constructed from a GeoSeries or PySparkDataFrame.
198+
List of the corresponding indices if constructed from a np.array.
177199
"""
178200

179201
if not isinstance(geometry, BaseGeometry):
@@ -194,15 +216,18 @@ def nearest(
194216
from sedona.spark.core.spatialOperator import KNNQuery
195217

196218
# Execute the KNN query
197-
results = KNNQuery.SpatialKnnQuery(self._indexed_rdd, geometry, k, False)
219+
geo_data_list = KNNQuery.SpatialKnnQuery(
220+
self._indexed_rdd, geometry, k, False
221+
)
222+
223+
# No need to keep the userData field, so convert it directly to a list of geometries
224+
geoms_list = [row.geom for row in geo_data_list]
198225

199226
if return_distance:
200227
# Calculate distances if requested
201-
distances = [
202-
geom.distance(geometry) for geom in [row.geom for row in results]
203-
]
204-
return results, distances
205-
return results
228+
distances = [geom.distance(geometry) for geom in geoms_list]
229+
return geoms_list, distances
230+
return geoms_list
206231
else:
207232
# For local spatial index based on Shapely STRtree
208233
if k > len(self.geometry):
@@ -220,20 +245,29 @@ def nearest(
220245

221246
def intersection(self, bounds):
222247
"""
223-
Find geometries that intersect the given bounding box.
248+
Find geometries that intersect the given bounding box. Similar to the Geopandas version,
249+
this is a compatibility wrapper for rtree.index.Index.intersection, use query instead.
224250
225251
Parameters
226252
----------
227253
bounds : tuple
228254
Bounding box as (min_x, min_y, max_x, max_y).
229255
256+
Note: intersection() has different behaviors depending on how the index is constructed.
257+
When constructed from a np.array, this method returns indices like original geopandas.
258+
When constructed from a GeoSeries or PySparkDataFrame, this method returns geometries.
259+
260+
Note: Unlike Geopandas, Sedona does not support geometry input of type np.array or GeoSeries.
261+
It is recommended to instead use GeoSeries.intersects directly.
262+
230263
Returns
231264
-------
232265
list
233-
List of indices of matching geometries.
266+
List of geometries if constructed from a GeoSeries or PySparkDataFrame.
267+
List of the corresponding indices if constructed from a np.array.
234268
"""
235269
log_advice(
236-
"`intersection` returns local list of indices of matching geometries onto driver's memory. "
270+
"`intersection` returns local list of matching geometries onto driver's memory. "
237271
"It should only be used if the resulting collection is expected to be small."
238272
)
239273

@@ -246,16 +280,7 @@ def intersection(self, bounds):
246280
bbox = box(*bounds)
247281

248282
if self._is_spark:
249-
# For Spark-based spatial index
250-
from sedona.spark.core.spatialOperator import RangeQuery
251-
252-
# Execute the spatial range query with the bounding box
253-
result_rdd = RangeQuery.SpatialRangeQuery(
254-
self._indexed_rdd, bbox, True, True
255-
)
256-
257-
results = result_rdd.collect()
258-
return results
283+
return self.query(bbox, predicate="intersects")
259284
else:
260285
# For local spatial index based on Shapely STRtree
261286
try:

python/tests/geopandas/test_sindex.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020
import shapely
2121
from pyspark.sql.functions import expr
22-
from shapely.geometry import Point, Polygon, LineString
22+
from shapely.geometry import Point, Polygon, LineString, box
2323

2424
from tests.test_base import TestBase
2525
from sedona.spark.geopandas import GeoSeries
@@ -63,6 +63,31 @@ def setup_method(self):
6363
]
6464
)
6565

66+
def test_construct_from_geoseries(self):
67+
# Construct from a GeoSeries
68+
gs = GeoSeries([Point(x, x) for x in range(5)])
69+
sindex = SpatialIndex(gs)
70+
result = sindex.query(Point(2, 2))
71+
# SpatialIndex constructed from GeoSeries return geometries
72+
assert result == [Point(2, 2)]
73+
74+
def test_construct_from_pyspark_dataframe(self):
75+
# Construct from PySparkDataFrame
76+
df = self.spark.createDataFrame(
77+
[(Point(x, x),) for x in range(5)], ["geometry"]
78+
)
79+
sindex = SpatialIndex(df, column_name="geometry")
80+
result = sindex.query(Point(2, 2))
81+
assert result == [Point(2, 2)]
82+
83+
def test_construct_from_nparray(self):
84+
# Construct from np.array
85+
array = np.array([Point(x, x) for x in range(5)])
86+
sindex = SpatialIndex(array)
87+
result = sindex.query(Point(2, 2))
88+
# Returns indices like original geopandas
89+
assert result == np.array([2])
90+
6691
def test_geoseries_sindex_property_exists(self):
6792
"""Test that the sindex property exists on GeoSeries."""
6893
assert hasattr(self.points, "sindex")
@@ -182,7 +207,7 @@ def test_nearest_method(self):
182207
assert len(nearest_result) == 1
183208

184209
# The nearest point should have id=2 (POINT(1 1))
185-
assert nearest_result[0].geom.wkt == "POINT (1 1)"
210+
assert nearest_result[0].wkt == "POINT (1 1)"
186211

187212
# Test finding k=2 nearest neighbors
188213
nearest_2_results = spark_sindex.nearest(query_point, k=2)
@@ -219,7 +244,7 @@ def test_nearest_spark_with_various_geometries(self):
219244

220245
# Should find polygon containing the point
221246
assert len(nearest_geom) == 1
222-
assert "POLYGON" in nearest_geom[0].geom.wkt
247+
assert "POLYGON" in nearest_geom[0].wkt
223248

224249
# Test with linestring query
225250
query_line = LineString([(1.5, 1.5), (2.5, 2.5)])
@@ -343,7 +368,12 @@ def test_intersection_method(self):
343368
result_rows = spark_sindex.intersection(bounds)
344369

345370
# Verify correct results are returned
346-
assert len(result_rows) >= 2
371+
expected = [
372+
Polygon([(1, 1), (2, 1), (2, 2), (1, 2), (1, 1)]),
373+
Polygon([(2, 2), (3, 2), (3, 3), (2, 3), (2, 2)]),
374+
Polygon([(3, 3), (4, 3), (4, 4), (3, 4), (3, 3)]),
375+
]
376+
assert result_rows == expected
347377

348378
# Test with bounds that don't intersect any geometry
349379
empty_bounds = (10, 10, 11, 11)
@@ -353,7 +383,14 @@ def test_intersection_method(self):
353383
# Test with bounds that cover all geometries
354384
full_bounds = (-1, -1, 6, 6)
355385
full_results = spark_sindex.intersection(full_bounds)
356-
assert len(full_results) == 5 # Should match all 5 polygons
386+
expected = [
387+
Polygon([(0, 0), (1, 0), (1, 1), (0, 1), (0, 0)]),
388+
Polygon([(1, 1), (2, 1), (2, 2), (1, 2), (1, 1)]),
389+
Polygon([(2, 2), (3, 2), (3, 3), (2, 3), (2, 2)]),
390+
Polygon([(3, 3), (4, 3), (4, 4), (3, 4), (3, 3)]),
391+
Polygon([(4, 4), (5, 4), (5, 5), (4, 5), (4, 4)]),
392+
]
393+
assert full_results == expected
357394

358395
def test_intersection_with_points(self):
359396
"""Test the intersection method with point geometries."""
@@ -426,3 +463,11 @@ def test_intersection_with_mixed_geometries(self):
426463

427464
# Verify results
428465
assert len(results) == 2
466+
467+
# test from the geopandas docstring
468+
def test_geoseries_sindex_intersection(self):
469+
gs = GeoSeries([Point(x, x) for x in range(10)])
470+
result = gs.sindex.intersection(box(1, 1, 3, 3).bounds)
471+
# Unlike original geopandas, this returns geometries instead of indices
472+
expected = [Point(1, 1), Point(2, 2), Point(3, 3)]
473+
assert result == expected

0 commit comments

Comments
 (0)