Skip to content

Commit fdea6af

Browse files
author
Rohit Rastogi
authored
feat: replace pylance kmeans implementation with scikit-learn and expose num_init and max_iter params in api (#104)
1 parent 845cb94 commit fdea6af

File tree

9 files changed

+252
-63
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ dependencies = [
2626
"numpy>=2.0.0",
2727
"polars>=1.20.0",
2828
"tiktoken>=0.9.0",
29-
"pylance>=0.23.2",
3029
"lancedb>=0.22.0",
3130
"openai>=1.82.0",
3231
"sqlglot>=26.25.3",
3332
"pandas>=2.2.2",
3433
"cloudpickle>=3.1.1",
3534
"jinja2>=3.1.6",
35+
"scikit-learn>=1.7.1",
3636
]
3737

3838
[project.urls]

src/fenic/_backends/local/physical_plan/transform.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from fenic._backends.local.lineage import OperatorLineage
1111
from fenic._backends.local.physical_plan.utils import apply_ingestion_coercions
1212
from fenic._backends.local.semantic_operators.cluster import Cluster
13-
from fenic.core._logical_plan.plans import CacheInfo
13+
from fenic.core._logical_plan.plans import CacheInfo, CentroidInfo
1414
from fenic.core.error import InternalError
1515

1616
if TYPE_CHECKING:
@@ -337,15 +337,19 @@ def __init__(
337337
by_expr: pl.Expr,
338338
by_expr_name: str,
339339
num_clusters: int,
340+
max_iter: int,
341+
num_init: int,
340342
label_column: str,
341-
centroid_info: Optional[Tuple[str, int]],
343+
centroid_info: Optional[CentroidInfo],
342344
cache_info: Optional[CacheInfo],
343345
session_state: LocalSessionState,
344346
):
345347
super().__init__([child], cache_info=cache_info, session_state=session_state)
346348
self.by_expr = by_expr
347349
self.by_expr_name = by_expr_name
348350
self.num_clusters = num_clusters
351+
self.max_iter = max_iter
352+
self.num_init = num_init
349353
self.label_column = label_column
350354
self.centroid_info = centroid_info
351355

@@ -359,9 +363,11 @@ def _execute(self, child_dfs: List[pl.DataFrame]) -> pl.DataFrame:
359363
clustered_df = Cluster(
360364
child_df,
361365
self.by_expr_name,
362-
self.num_clusters,
363-
self.label_column,
364-
self.centroid_info,
366+
num_clusters=self.num_clusters,
367+
max_iter=self.max_iter,
368+
num_init=self.num_init,
369+
label_column=self.label_column,
370+
centroid_info=self.centroid_info,
365371
).execute()
366372

367373
# Remove the temporary column we added for clustering if it wasn't in the original
src/fenic/_backends/local/semantic_operators/cluster.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import logging
2-
from typing import Optional, Tuple
2+
from typing import Optional
33

44
import numpy as np
55
import polars as pl
66
import pyarrow as pa
7-
from lance.util import KMeans
7+
from sklearn.cluster import KMeans
88

99
from fenic._backends.local.semantic_operators.utils import (
1010
filter_invalid_embeddings_expr,
1111
)
12+
from fenic.core._logical_plan.plans import CentroidInfo
1213

1314
logger = logging.getLogger(__name__)
1415

@@ -18,21 +19,23 @@ def __init__(
1819
self,
1920
input: pl.DataFrame,
2021
embedding_column_name: str,
21-
num_centroids: int,
22+
num_clusters: int,
23+
max_iter: int,
24+
num_init: int,
2225
label_column: str,
23-
centroid_info: Optional[Tuple[str, int]],
24-
num_iter: int = 50,
26+
centroid_info: Optional[CentroidInfo],
2527
):
2628
self.input = input
2729
self.embedding_column_name = embedding_column_name
2830
input_height = input.height
29-
if num_centroids > input_height:
31+
if num_clusters > input_height:
3032
logger.warning(
31-
f"`num_centroids` was set to {num_centroids}, but the input DataFrame only contains {input_height} rows. "
32-
f"Reducing `num_centroids` to {input_height} to match the available number of rows."
33+
f"`num_clusters` was set to {num_clusters}, but the input DataFrame only contains {input_height} rows. "
34+
f"Reducing `num_clusters` to {input_height} to match the available number of rows."
3335
)
34-
self.num_centroids = min(num_centroids, input_height)
35-
self.num_iter = num_iter
36+
self.num_clusters = min(num_clusters, input_height)
37+
self.max_iter = max_iter
38+
self.num_init = num_init
3639
self.label_column = label_column
3740
self.centroid_info = centroid_info
3841

@@ -47,10 +50,18 @@ def execute(self) -> pl.DataFrame:
4750
centroids = None
4851
if not valid_df.is_empty():
4952
embeddings = np.stack(valid_df[self.embedding_column_name])
50-
kmeans = KMeans(k=self.num_centroids, max_iters=self.num_iter)
51-
kmeans.fit(embeddings)
52-
predicted = kmeans.predict(embeddings).tolist()
53-
cluster_centroids = kmeans.centroids.to_numpy(zero_copy_only=False)
53+
54+
# Using sklearn KMeans with k-means++ initialization (default)
55+
kmeans = KMeans(
56+
n_clusters=self.num_clusters,
57+
max_iter=self.max_iter,
58+
init='k-means++', # This is the default, but being explicit
59+
n_init=self.num_init, # Number of times to run k-means with different centroid seeds
60+
random_state=42 # For reproducibility
61+
)
62+
63+
predicted = kmeans.fit_predict(embeddings)
64+
cluster_centroids = kmeans.cluster_centers_
5465

5566
if self.centroid_info is not None:
5667
centroids = [None] * df.height
@@ -65,8 +76,8 @@ def execute(self) -> pl.DataFrame:
6576
if self.centroid_info is not None:
6677
res = res.with_columns(
6778
pl.from_arrow(
68-
pa.array(centroids, type=pa.list_(pa.float32(), self.centroid_info[1]))
69-
).alias(self.centroid_info[0])
79+
pa.array(centroids, type=pa.list_(pa.float32(), self.centroid_info.num_dimensions))
80+
).alias(self.centroid_info.centroid_column)
7081
)
7182

7283
return res

src/fenic/_backends/local/transpiler/plan_converter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,11 @@ def convert(
264264
child_physical,
265265
physical_by_expr,
266266
str(logical.by_expr()),
267-
logical.num_clusters(),
268-
logical.label_column(),
269-
logical.centroid_info(),
267+
num_clusters=logical.num_clusters(),
268+
max_iter=logical.max_iter(),
269+
num_init=logical.num_init(),
270+
label_column=logical.label_column(),
271+
centroid_info=logical.centroid_info(),
270272
cache_info=logical.cache_info,
271273
session_state=self.session_state,
272274
)

src/fenic/api/dataframe/semantic_extensions.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,15 @@ def __init__(self, df: DataFrame):
3535
"""
3636
self._df = df
3737

38-
def with_cluster_labels(self, by: ColumnOrName, num_clusters: int, label_column: str = "cluster_label", centroid_column: Optional[str] = None) -> DataFrame:
38+
def with_cluster_labels(
39+
self,
40+
by: ColumnOrName,
41+
num_clusters: int,
42+
max_iter: int = 300,
43+
num_init: int = 1,
44+
label_column: str = "cluster_label",
45+
centroid_column: Optional[str] = None,
46+
) -> DataFrame:
3947
"""Cluster rows using K-means and add cluster metadata columns.
4048
4149
This method clusters rows based on the given embedding column or expression using K-means.
@@ -45,6 +53,8 @@ def with_cluster_labels(self, by: ColumnOrName, num_clusters: int, label_column:
4553
Args:
4654
by: Column or expression producing embeddings to cluster (e.g., `embed(col("text"))`).
4755
num_clusters: Number of clusters to compute (must be > 0).
56+
max_iter: Maximum iterations for a single run of the k-means algorithm. The algorithm stops when it either converges or reaches this limit.
57+
num_init: Number of independent runs of k-means with different centroid seeds. The best result is selected.
4858
label_column: Name of the output column for cluster IDs. Default is "cluster_label".
4959
centroid_column: If provided, adds a column with this name containing the centroid embedding
5060
for each row's assigned cluster.
@@ -56,14 +66,16 @@ def with_cluster_labels(self, by: ColumnOrName, num_clusters: int, label_column:
5666
5767
Raises:
5868
ValidationError: If num_clusters is not a positive integer
69+
ValidationError: If max_iter is not a positive integer
70+
ValidationError: If num_init is not a positive integer
5971
ValidationError: If label_column is not a non-empty string
6072
ValidationError: If centroid_column is not a non-empty string
6173
TypeMismatchError: If the column is not an EmbeddingType
6274
6375
Example: Basic clustering
6476
```python
6577
# Cluster customer feedback and add cluster metadata
66-
clustered_df = df.semantic.with_cluster_labels("feedback_embeddings", 5)
78+
clustered_df = df.semantic.with_cluster_labels("feedback_embeddings", num_clusters=5)
6779
6880
# Then use regular operations to analyze clusters
6981
clustered_df.group_by("cluster_label").agg(count("*"), avg("rating"))
@@ -72,15 +84,23 @@ def with_cluster_labels(self, by: ColumnOrName, num_clusters: int, label_column:
7284
Example: Filter outliers using centroids
7385
```python
7486
# Cluster and filter out rows far from their centroid
75-
clustered_df = df.semantic.with_cluster_labels("embeddings", 3, centroid_column="cluster_centroid")
87+
clustered_df = df.semantic.with_cluster_labels("embeddings", num_clusters=3, num_init=10, centroid_column="cluster_centroid")
7688
clean_df = clustered_df.filter(
7789
embedding.compute_similarity("embeddings", "cluster_centroid", metric="cosine") > 0.7
7890
)
7991
```
8092
"""
8193
# Validate num_clusters
8294
if not isinstance(num_clusters, int) or num_clusters <= 0:
83-
raise ValidationError("`num_clusters` must be a positive integer greater than 0.")
95+
raise ValidationError("`num_clusters` must be a positive integer.")
96+
97+
# Validate max_iter
98+
if not isinstance(max_iter, int) or max_iter <= 0:
99+
raise ValidationError("`max_iter` must be a positive integer.")
100+
101+
# Validate num_init
102+
if not isinstance(num_init, int) or num_init <= 0:
103+
raise ValidationError("`num_init` must be a positive integer.")
84104

85105
# Validate clustering target
86106
if not isinstance(by, ColumnOrName):
@@ -106,7 +126,13 @@ def with_cluster_labels(self, by: ColumnOrName, num_clusters: int, label_column:
106126

107127
return self._df._from_logical_plan(
108128
SemanticCluster(
109-
self._df._logical_plan, by_expr, num_clusters, label_column, centroid_column
129+
self._df._logical_plan,
130+
by_expr,
131+
num_clusters=num_clusters,
132+
max_iter=max_iter,
133+
num_init=num_init,
134+
label_column=label_column,
135+
centroid_column=centroid_column,
110136
)
111137
)
112138

src/fenic/core/_logical_plan/plans/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from fenic.core._logical_plan.plans.transform import (
2020
SQL,
21+
CentroidInfo,
2122
DropDuplicates,
2223
Explode,
2324
Filter,
@@ -52,4 +53,5 @@
5253
"Sort",
5354
"Union",
5455
"Unnest",
56+
"CentroidInfo",
5557
]

src/fenic/core/_logical_plan/plans/transform.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import logging
44
import re
5+
from dataclasses import dataclass
56
from typing import Dict, List, Optional, Tuple
67

78
import duckdb
@@ -415,21 +416,30 @@ def with_children(self, children: List[LogicalPlan]) -> LogicalPlan:
415416
result.set_cache_info(self.cache_info)
416417
return result
417418

419+
@dataclass
420+
class CentroidInfo:
421+
centroid_column: str
422+
num_dimensions: int
423+
418424
class SemanticCluster(LogicalPlan):
419425
def __init__(
420426
self,
421427
input: LogicalPlan,
422428
by_expr: LogicalExpr,
423429
num_clusters: int,
430+
max_iter: int,
431+
num_init: int,
424432
label_column: str,
425433
centroid_column: Optional[str],
426434
):
427435
self._input = input
428436
self._by_expr = by_expr
429437
self._num_clusters = num_clusters
438+
self._max_iter = max_iter
439+
self._num_init = num_init
430440
self._label_column = label_column
431441
self._centroid_column = centroid_column
432-
self._centroid_info: Optional[Tuple[str, int]] = None
442+
self._centroid_info: Optional[CentroidInfo] = None
433443
super().__init__(self._input.session_state)
434444

435445
def children(self) -> List[LogicalPlan]:
@@ -446,7 +456,7 @@ def _build_schema(self) -> Schema:
446456
new_fields = [ColumnField(self._label_column, IntegerType)]
447457
if self._centroid_column:
448458
new_fields.append(ColumnField(self._centroid_column, by_expr_type))
449-
self._centroid_info = (self._centroid_column, by_expr_type.dimensions)
459+
self._centroid_info = CentroidInfo(self._centroid_column, by_expr_type.dimensions)
450460

451461
return Schema(column_fields=self._input.schema().column_fields + new_fields)
452462

@@ -456,7 +466,13 @@ def _repr(self) -> str:
456466
def num_clusters(self) -> int:
457467
return self._num_clusters
458468

459-
def centroid_info(self) -> Optional[Tuple[str, int]]:
469+
def max_iter(self) -> int:
470+
return self._max_iter
471+
472+
def num_init(self) -> int:
473+
return self._num_init
474+
475+
def centroid_info(self) -> Optional[CentroidInfo]:
460476
return self._centroid_info
461477

462478
def by_expr(self) -> LogicalExpr:
@@ -469,7 +485,13 @@ def with_children(self, children: List[LogicalPlan]) -> LogicalPlan:
469485
if len(children) != 1:
470486
raise ValueError("SemanticCluster must have exactly one child")
471487
result = SemanticCluster(
472-
children[0], self._by_expr, self._num_clusters, self._label_column, self._centroid_column
488+
children[0],
489+
self._by_expr,
490+
num_clusters=self._num_clusters,
491+
max_iter=self._max_iter,
492+
num_init=self._num_init,
493+
label_column=self._label_column,
494+
centroid_column=self._centroid_column,
473495
)
474496
result.set_cache_info(self.cache_info)
475497
return result

0 commit comments

Comments
 (0)