|
1 | 1 | """Semantic functions for Fenic DataFrames - LLM-based operations.""" |
2 | 2 |
|
3 | | -from enum import Enum |
4 | 3 | from typing import List, Optional, Union |
5 | 4 |
|
6 | 5 | from pydantic import BaseModel, ConfigDict, validate_call |
|
9 | 8 | from fenic.core._logical_plan.expressions import ( |
10 | 9 | AnalyzeSentimentExpr, |
11 | 10 | EmbeddingsExpr, |
| 11 | + ResolvedClassDefinition, |
12 | 12 | SemanticClassifyExpr, |
13 | 13 | SemanticExtractExpr, |
14 | 14 | SemanticMapExpr, |
|
22 | 22 | ) |
23 | 23 | from fenic.core.error import ValidationError |
24 | 24 | from fenic.core.types import ( |
| 25 | + ClassDefinition, |
25 | 26 | ClassifyExampleCollection, |
26 | 27 | KeyPoints, |
27 | 28 | MapExampleCollection, |
@@ -259,60 +260,86 @@ def reduce( |
259 | 260 | @validate_call(config=ConfigDict(strict=True, arbitrary_types_allowed=True)) |
260 | 261 | def classify( |
261 | 262 | column: ColumnOrName, |
262 | | - labels: List[str] | type[Enum], |
| 263 | + classes: Union[List[str], List[ClassDefinition]], |
263 | 264 | examples: Optional[ClassifyExampleCollection] = None, |
264 | 265 | model_alias: Optional[str] = None, |
265 | 266 | temperature: float = 0, |
266 | 267 | ) -> Column: |
267 | | - """Classifies a string column into one of the provided labels. |
| 268 | + """Classifies a string column into one of the provided classes. |
268 | 269 |
|
269 | 270 | This is useful for tagging incoming documents with predefined categories. |
270 | 271 |
|
271 | 272 | Args: |
272 | 273 | column: Column or column name containing text to classify. |
273 | | -
|
274 | | - labels: List of category strings or an Enum defining the categories to classify the text into. |
275 | | -
|
| 274 | + classes: List of class labels or ClassDefinition objects defining the available classes. Use ClassDefinition objects to provide descriptions for the classes. |
276 | 275 | examples: Optional collection of example classifications to guide the model. |
277 | 276 | Examples should be created using ClassifyExampleCollection.create_example(), |
278 | 277 | with instruction variables mapped to their expected classifications. |
279 | | -
|
280 | 278 | model_alias: Optional alias for the language model to use for the classification. If None, will use the language model configured as the default. |
281 | | -
|
282 | 279 | temperature: Temperature parameter for the language model. Defaults to 0. |
283 | 280 |
|
284 | 281 | Returns: |
285 | 282 | Column: Expression containing the classification results. |
286 | 283 |
|
287 | 284 | Raises: |
288 | | - ValueError: If column is invalid or categories is not a list of strings. |
| 285 | + ValidationError: If classes contains fewer than two entries or has duplicate labels. |
289 | 286 |
|
290 | 287 | Example: Categorizing incoming support requests |
291 | 288 | ```python |
292 | 289 | # Categorize incoming support requests |
293 | 290 | semantic.classify("message", ["Account Access", "Billing Issue", "Technical Problem"]) |
294 | 291 | ``` |
295 | 292 |
|
296 | | - Example: Categorizing incoming support requests with examples |
| 293 | + Example: Categorizing incoming support requests using ClassDefinition objects |
| 294 | + ```python |
| 295 | + # Categorize incoming support requests |
| 296 | + semantic.classify("message", [ |
| 297 | + ClassDefinition(label="Account Access", description="Problems logging in, resetting passwords, or accessing an account"), |
| 298 | + ClassDefinition(label="Billing Issue", description="Questions about charges, payments, subscriptions, or account billing"), |
| 299 | + ClassDefinition(label="Technical Problem", description="Problems with product functionality, bugs, or technical difficulties") |
| 300 | + ]) |
| 301 | + ``` |
| 302 | +
|
| 303 | + Example: Categorizing incoming support requests with ClassDefinition objects and examples |
297 | 304 | ```python |
298 | 305 | examples = ClassifyExampleCollection() |
| 306 | + class_definitions = [ |
| 307 | + ClassDefinition(label="Account Access", description="Problems logging in, resetting passwords, or accessing an account"), |
| 308 | + ClassDefinition(label="Billing Issue", description="Questions about charges, payments, subscriptions, or account billing"), |
| 309 | + ClassDefinition(label="Technical Problem", description="Problems with product functionality, bugs, or technical difficulties") |
| 310 | + ] |
299 | 311 | examples.create_example(ClassifyExample( |
300 | 312 | input="I can't reset my password or access my account.", |
301 | 313 | output="Account Access")) |
302 | 314 | examples.create_example(ClassifyExample( |
303 | 315 | input="You charged me twice for the same month.", |
304 | 316 | output="Billing Issue")) |
305 | | - semantic.classify("message", ["Account Access", "Billing Issue", "Technical Problem"], examples) |
| 317 | + semantic.classify("message", class_definitions, examples) |
306 | 318 | ``` |
307 | 319 | """ |
308 | | - if isinstance(labels, List) and len(labels) == 0: |
309 | | - raise ValueError( |
310 | | - f"Must specify the categories for classification, found: {len(labels)} categories" |
| 320 | + if len(classes) < 2: |
| 321 | + raise ValidationError( |
| 322 | + "The `classes` list must contain at least two ClassDefinition objects. " |
| 323 | + "You provided only one. Classification requires at least two possible labels." |
311 | 324 | ) |
| 325 | + |
| 326 | + # Normalize inputs to ResolvedClassDefinition, then validate that labels are unique |
| 327 | + if isinstance(classes[0], ClassDefinition): |
| 328 | + classes = [ResolvedClassDefinition(label=class_def.label, description=class_def.description) for class_def in classes] |
| 329 | + else: |
| 330 | + classes = [ResolvedClassDefinition(label=class_def, description=None) for class_def in classes] |
| 331 | + |
| 332 | + labels = [class_def.label for class_def in classes] |
| 333 | + duplicates = {label for label in labels if labels.count(label) > 1} |
| 334 | + if duplicates: |
| 335 | + raise ValidationError( |
| 336 | + f"Class labels must be unique. The following duplicate label(s) were found: {sorted(duplicates)}" |
| 337 | + ) |
| 338 | + |
312 | 339 | return Column._from_logical_expr( |
313 | 340 | SemanticClassifyExpr( |
314 | 341 | Column._from_col_or_name(column)._logical_expr, |
315 | | - labels, |
| 342 | + classes, |
316 | 343 | examples=examples, |
317 | 344 | model_alias=model_alias, |
318 | 345 | temperature=temperature, |
|
0 commit comments