@@ -35,7 +35,15 @@ def __init__(self, df: DataFrame):
3535 """
3636 self ._df = df
3737
38- def with_cluster_labels (self , by : ColumnOrName , num_clusters : int , label_column : str = "cluster_label" , centroid_column : Optional [str ] = None ) -> DataFrame :
38+ def with_cluster_labels (
39+ self ,
40+ by : ColumnOrName ,
41+ num_clusters : int ,
42+ max_iter : int = 300 ,
43+ num_init : int = 1 ,
44+ label_column : str = "cluster_label" ,
45+ centroid_column : Optional [str ] = None ,
46+ ) -> DataFrame :
3947 """Cluster rows using K-means and add cluster metadata columns.
4048
4149 This method clusters rows based on the given embedding column or expression using K-means.
@@ -45,6 +53,8 @@ def with_cluster_labels(self, by: ColumnOrName, num_clusters: int, label_column:
4553 Args:
4654 by: Column or expression producing embeddings to cluster (e.g., `embed(col("text"))`).
4755 num_clusters: Number of clusters to compute (must be > 0).
56+ max_iter: Maximum iterations for a single run of the k-means algorithm. The algorithm stops when it either converges or reaches this limit.
57+ num_init: Number of independent runs of k-means with different centroid seeds. The best result is selected.
4858 label_column: Name of the output column for cluster IDs. Default is "cluster_label".
4959 centroid_column: If provided, adds a column with this name containing the centroid embedding
5060 for each row's assigned cluster.
@@ -56,14 +66,16 @@ def with_cluster_labels(self, by: ColumnOrName, num_clusters: int, label_column:
5666
5767 Raises:
5868 ValidationError: If num_clusters is not a positive integer
69+ ValidationError: If max_iter is not a positive integer
70+ ValidationError: If num_init is not a positive integer
5971 ValidationError: If label_column is not a non-empty string
6072 ValidationError: If centroid_column is not a non-empty string
6173 TypeMismatchError: If the column is not an EmbeddingType
6274
6375 Example: Basic clustering
6476 ```python
6577 # Cluster customer feedback and add cluster metadata
66- clustered_df = df.semantic.with_cluster_labels("feedback_embeddings", 5)
78+ clustered_df = df.semantic.with_cluster_labels("feedback_embeddings", num_clusters=5)
6779
6880 # Then use regular operations to analyze clusters
6981 clustered_df.group_by("cluster_label").agg(count("*"), avg("rating"))
@@ -72,15 +84,23 @@ def with_cluster_labels(self, by: ColumnOrName, num_clusters: int, label_column:
7284 Example: Filter outliers using centroids
7385 ```python
7486 # Cluster and filter out rows far from their centroid
75- clustered_df = df.semantic.with_cluster_labels("embeddings", 3 , centroid_column="cluster_centroid")
87+ clustered_df = df.semantic.with_cluster_labels("embeddings", num_clusters=3, num_init=10, centroid_column="cluster_centroid")
7688 clean_df = clustered_df.filter(
7789 embedding.compute_similarity("embeddings", "cluster_centroid", metric="cosine") > 0.7
7890 )
7991 ```
8092 """
8193 # Validate num_clusters
8294 if not isinstance (num_clusters , int ) or num_clusters <= 0 :
83- raise ValidationError ("`num_clusters` must be a positive integer greater than 0." )
95+ raise ValidationError ("`num_clusters` must be a positive integer." )
96+
97+ # Validate max_iter
98+ if not isinstance (max_iter , int ) or max_iter <= 0 :
99+ raise ValidationError ("`max_iter` must be a positive integer." )
100+
101+ # Validate num_init
102+ if not isinstance (num_init , int ) or num_init <= 0 :
103+ raise ValidationError ("`num_init` must be a positive integer." )
84104
85105 # Validate clustering target
86106 if not isinstance (by , ColumnOrName ):
@@ -106,7 +126,13 @@ def with_cluster_labels(self, by: ColumnOrName, num_clusters: int, label_column:
106126
107127 return self ._df ._from_logical_plan (
108128 SemanticCluster (
109- self ._df ._logical_plan , by_expr , num_clusters , label_column , centroid_column
129+ self ._df ._logical_plan ,
130+ by_expr ,
131+ num_clusters = num_clusters ,
132+ max_iter = max_iter ,
133+ num_init = num_init ,
134+ label_column = label_column ,
135+ centroid_column = centroid_column ,
110136 )
111137 )
112138
0 commit comments