Skip to content

Commit baf3897

Browse files
author
Rohit Rastogi
authored
feat: support descriptions of class labels in semantic.classify (#102)
## Summary This PR enhances the semantic.classify API to support optional descriptions for classification classes, improving LLM performance by providing better context. The enhancement maintains backward compatibility with string list input while removing enum support. ## Motivation Previously, the classify API only accepted simple string labels or enums, which limited the LLM's understanding of nuanced classification categories. By allowing optional descriptions, we can provide richer context that helps the model make more accurate classification decisions. ## Changes ### New Features - **ClassDefinition model**: New Pydantic model with `label` and `description` fields for enhanced class definitions - **Enhanced prompt generation**: System messages now include class descriptions when provided, giving the LLM better context - **Flexible API**: Support for both simple string lists (backward compatible) and detailed ClassDefinition objects ### 🔧 Technical Changes - Replaced `List[str] | type[Enum]` API with `Union[List[str], List[ClassDefinition]]` for explicit type handling - Added internal `ResolvedClassDefinition` dataclass for normalized representation - Updated validation logic to work with label sets instead of enums ## Usage Examples ### Basic usage (backward compatible) ```python semantic.classify("support_ticket", ["technical_issue", "billing_inquiry", "general_support"]) ``` ### Enhanced usage with descriptions ```python semantic.classify("support_ticket", [ ClassDefinition( label="technical_issue", description="Problems with product functionality, bugs, or technical difficulties" ), ClassDefinition( label="billing_inquiry", description="Questions about charges, payments, subscriptions, or account billing" ), ClassDefinition( label="general_support", description="General questions, feature requests, or non-technical assistance" ) ]) ```
1 parent 7795497 commit baf3897

File tree

16 files changed

+288
-314
lines changed

16 files changed

+288
-314
lines changed

src/fenic/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from fenic.core import (
5353
ArrayType,
5454
BooleanType,
55+
ClassDefinition,
5556
ClassifyExample,
5657
ClassifyExampleCollection,
5758
ColumnField,
@@ -129,6 +130,7 @@
129130
"PredicateExample",
130131
"PredicateExampleCollection",
131132
"Schema",
133+
"ClassDefinition",
132134
"ClassifyExample",
133135
"ClassifyExampleCollection",
134136
"JoinExample",

src/fenic/_backends/local/semantic_operators/analyze_sentiment.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import json
22
import logging
3-
from enum import Enum
43
from typing import List, Optional
54

65
import polars as pl
@@ -110,17 +109,7 @@
110109
)
111110
)
112111

113-
SENTIMENT_ANALYSIS_MODEL = create_classification_pydantic_model(
114-
Enum(
115-
"Sentiment",
116-
[
117-
("POSITIVE", "positive"),
118-
("NEGATIVE", "negative"),
119-
("NEUTRAL", "neutral"),
120-
],
121-
)
122-
)
123-
112+
SENTIMENT_ANALYSIS_MODEL = create_classification_pydantic_model(["positive", "negative", "neutral"])
124113

125114
class AnalyzeSentiment(BaseSingleColumnInputOperator[str, str]):
126115
SYSTEM_PROMPT = """You are a sentiment analysis expert.
@@ -166,6 +155,11 @@ def postprocess(self, responses: List[Optional[str]]) -> List[Optional[str]]:
166155
try:
167156
data = json.loads(response)["output"]
168157
predictions.append(data)
158+
if data not in ["positive", "negative", "neutral"]:
159+
logger.warning(
160+
f"Model returned invalid label '{data}'. Valid labels: positive, negative, neutral"
161+
)
162+
predictions.append(None)
169163
except Exception as e:
170164
logger.warning(
171165
f"Invalid model output: {response} for semantic.analyze_sentiment: {e}"

src/fenic/_backends/local/semantic_operators/classify.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import json
22
import logging
3-
from enum import Enum
43
from typing import List, Optional, Type
54

65
import polars as pl
@@ -12,14 +11,14 @@
1211
)
1312
from fenic._backends.local.semantic_operators.utils import (
1413
create_classification_pydantic_model,
15-
stringify_enum_type,
1614
)
1715
from fenic._constants import (
1816
MAX_TOKENS_DETERMINISTIC_OUTPUT_SIZE,
1917
TOKEN_OVERHEAD_JSON,
2018
TOKEN_OVERHEAD_MISC,
2119
)
2220
from fenic._inference.language_model import InferenceConfiguration, LanguageModel
21+
from fenic.core._logical_plan.expressions import ResolvedClassDefinition
2322
from fenic.core.types import ClassifyExample, ClassifyExampleCollection
2423

2524
logger = logging.getLogger(__name__)
@@ -28,20 +27,24 @@
2827
class Classify(BaseSingleColumnInputOperator[str, str]):
2928
SYSTEM_PROMPT = (
3029
"You are a text classification expert. "
31-
"Classify the following document into one of the following labels: {labels}. "
30+
"Classify the following document into one of the following labels:"
31+
"\n{classes}\n"
3232
"Respond with *only* the predicted label."
3333
)
3434

3535
def __init__(
3636
self,
3737
input: pl.Series,
38-
labels: Type[Enum],
38+
classes: List[ResolvedClassDefinition],
3939
model: LanguageModel,
4040
temperature: float,
4141
examples: Optional[ClassifyExampleCollection] = None,
4242
):
43+
self.classes = classes
44+
self.valid_labels = {class_def.label for class_def in classes}
45+
# Create output model from class labels
46+
labels = [class_def.label for class_def in classes]
4347
self.output_model = create_classification_pydantic_model(labels)
44-
self.labels = labels
4548
super().__init__(
4649
input,
4750
CompletionOnlyRequestSender(
@@ -57,7 +60,17 @@ def __init__(
5760
)
5861

5962
def build_system_message(self) -> str:
60-
return self.SYSTEM_PROMPT.format(labels=stringify_enum_type(self.labels))
63+
"""Build system message with class descriptions."""
64+
class_descriptions = []
65+
for class_def in self.classes:
66+
if class_def.description:
67+
class_descriptions.append(f"- {class_def.label}: {class_def.description}")
68+
else:
69+
class_descriptions.append(f"- {class_def.label}")
70+
71+
classes_text = "\n".join(class_descriptions)
72+
73+
return self.SYSTEM_PROMPT.format(classes=classes_text)
6174

6275
def postprocess(self, responses: List[Optional[str]]) -> List[Optional[str]]:
6376
predictions = []
@@ -67,7 +80,14 @@ def postprocess(self, responses: List[Optional[str]]) -> List[Optional[str]]:
6780
else:
6881
try:
6982
data = json.loads(response)["output"]
70-
predictions.append(data)
83+
# Validate the response is one of the valid labels
84+
if data not in self.valid_labels:
85+
logger.warning(
86+
f"Model returned invalid label '{data}'. Valid labels: {self.valid_labels}"
87+
)
88+
predictions.append(None)
89+
else:
90+
predictions.append(data)
7191
except Exception as e:
7292
logger.warning(
7393
f"Invalid model output: {response} for semantic.classify: {e}"
@@ -82,5 +102,5 @@ def convert_example_to_assistant_message(self, example: ClassifyExample) -> str:
82102
return self.output_model(output=example.output).model_dump_json()
83103

84104
def get_max_tokens(self) -> int:
85-
max_label_length = max(len(str(label.value)) for label in self.labels)
105+
max_label_length = max(len(class_def.label) for class_def in self.classes)
86106
return max(MAX_TOKENS_DETERMINISTIC_OUTPUT_SIZE, TOKEN_OVERHEAD_JSON + TOKEN_OVERHEAD_MISC + max_label_length)

src/fenic/_backends/local/semantic_operators/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
from enum import Enum
3-
from typing import Any, Dict, Type
3+
from typing import Any, Dict, List, Type
44

55
import polars as pl
66
from pydantic import BaseModel, create_model
@@ -37,15 +37,19 @@ def stringify_enum_type(enum_type: Type[Enum]) -> str:
3737
return ", ".join(f"{label.value}" for label in enum_type)
3838

3939

40-
def create_classification_pydantic_model(enum_cls: Type[Enum]) -> Type[BaseModel]:
41-
"""Creates a Pydantic model from an Enum class.
40+
def create_classification_pydantic_model(allowed_values: List[str]) -> type[BaseModel]:
41+
"""Creates a Pydantic model from a list of allowed string values using a dynamic Enum.
4242
4343
Args:
44-
enum_cls (Type[Enum]): The Enum class to convert.
44+
allowed_values (List[str]): The list of allowed string values.
4545
4646
Returns:
4747
Type[BaseModel]: A Pydantic model class with a field for the Enum values.
4848
"""
49+
enum_name = "LabelEnum"
50+
enum_members = {value.upper(): value for value in allowed_values}
51+
enum_cls = Enum(enum_name, enum_members)
52+
4953
return create_model(
5054
"EnumModel",
5155
output=(enum_cls, ...),

src/fenic/_backends/local/transpiler/expr_converter.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -554,16 +554,12 @@ def sem_predicate_fn(batch: pl.Series) -> pl.Series:
554554
@_convert_expr.register(SemanticClassifyExpr)
555555
def _convert_semantic_classify_expr(self, logical: SemanticClassifyExpr) -> pl.Expr:
556556
def sem_classify_fn(batch: pl.Series) -> pl.Series:
557-
labels_enum = (
558-
SemanticClassifyExpr.transform_labels_list_into_enum(logical.labels)
559-
if isinstance(logical.labels, list)
560-
else logical.labels
561-
)
562557
return SemanticClassify(
563558
input=batch,
564-
labels=labels_enum,
559+
classes=logical.classes,
565560
model=self.session_state.get_language_model(logical.model_alias),
566561
temperature=logical.temperature,
562+
examples=logical.examples,
567563
).execute()
568564

569565
return self._convert_expr(logical.expr).map_batches(

src/fenic/api/functions/semantic.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Semantic functions for Fenic DataFrames - LLM-based operations."""
22

3-
from enum import Enum
43
from typing import List, Optional, Union
54

65
from pydantic import BaseModel, ConfigDict, validate_call
@@ -9,6 +8,7 @@
98
from fenic.core._logical_plan.expressions import (
109
AnalyzeSentimentExpr,
1110
EmbeddingsExpr,
11+
ResolvedClassDefinition,
1212
SemanticClassifyExpr,
1313
SemanticExtractExpr,
1414
SemanticMapExpr,
@@ -22,6 +22,7 @@
2222
)
2323
from fenic.core.error import ValidationError
2424
from fenic.core.types import (
25+
ClassDefinition,
2526
ClassifyExampleCollection,
2627
KeyPoints,
2728
MapExampleCollection,
@@ -259,60 +260,86 @@ def reduce(
259260
@validate_call(config=ConfigDict(strict=True, arbitrary_types_allowed=True))
260261
def classify(
261262
column: ColumnOrName,
262-
labels: List[str] | type[Enum],
263+
classes: Union[List[str], List[ClassDefinition]],
263264
examples: Optional[ClassifyExampleCollection] = None,
264265
model_alias: Optional[str] = None,
265266
temperature: float = 0,
266267
) -> Column:
267-
"""Classifies a string column into one of the provided labels.
268+
"""Classifies a string column into one of the provided classes.
268269
269270
This is useful for tagging incoming documents with predefined categories.
270271
271272
Args:
272273
column: Column or column name containing text to classify.
273-
274-
labels: List of category strings or an Enum defining the categories to classify the text into.
275-
274+
classes: List of class labels or ClassDefinition objects defining the available classes. Use ClassDefinition objects to provide descriptions for the classes.
276275
examples: Optional collection of example classifications to guide the model.
277276
Examples should be created using ClassifyExampleCollection.create_example(),
278277
with instruction variables mapped to their expected classifications.
279-
280278
model_alias: Optional alias for the language model to use for the mapping. If None, will use the language model configured as the default.
281-
282279
temperature: Optional temperature parameter for the language model. If None, will use the default temperature (0.0).
283280
284281
Returns:
285282
Column: Expression containing the classification results.
286283
287284
Raises:
288-
ValueError: If column is invalid or categories is not a list of strings.
285+
ValueError: If column is invalid or classes is empty or has duplicate labels.
289286
290287
Example: Categorizing incoming support requests
291288
```python
292289
# Categorize incoming support requests
293290
semantic.classify("message", ["Account Access", "Billing Issue", "Technical Problem"])
294291
```
295292
296-
Example: Categorizing incoming support requests with examples
293+
Example: Categorizing incoming support requests using ClassDefinition objects
294+
```python
295+
# Categorize incoming support requests
296+
semantic.classify("message", [
297+
ClassDefinition(label="Account Access", description="General questions, feature requests, or non-technical assistance"),
298+
ClassDefinition(label="Billing Issue", description="Questions about charges, payments, subscriptions, or account billing"),
299+
ClassDefinition(label="Technical Problem", description="Problems with product functionality, bugs, or technical difficulties")
300+
])
301+
```
302+
303+
Example: Categorizing incoming support requests with ClassDefinition objects and examples
297304
```python
298305
examples = ClassifyExampleCollection()
306+
class_definitions = [
307+
ClassDefinition(label="Account Access", description="General questions, feature requests, or non-technical assistance"),
308+
ClassDefinition(label="Billing Issue", description="Questions about charges, payments, subscriptions, or account billing"),
309+
ClassDefinition(label="Technical Problem", description="Problems with product functionality, bugs, or technical difficulties")
310+
]
299311
examples.create_example(ClassifyExample(
300312
input="I can't reset my password or access my account.",
301313
output="Account Access"))
302314
examples.create_example(ClassifyExample(
303315
input="You charged me twice for the same month.",
304316
output="Billing Issue"))
305-
semantic.classify("message", ["Account Access", "Billing Issue", "Technical Problem"], examples)
317+
semantic.classify("message", class_definitions, examples)
306318
```
307319
"""
308-
if isinstance(labels, List) and len(labels) == 0:
309-
raise ValueError(
310-
f"Must specify the categories for classification, found: {len(labels)} categories"
320+
if len(classes) < 2:
321+
raise ValidationError(
322+
"The `classes` list must contain at least two ClassDefinition objects. "
323+
"You provided only one. Classification requires at least two possible labels."
311324
)
325+
326+
# Validate unique labels
327+
if isinstance(classes[0], ClassDefinition):
328+
classes = [ResolvedClassDefinition(label=class_def.label, description=class_def.description) for class_def in classes]
329+
else:
330+
classes = [ResolvedClassDefinition(label=class_def, description=None) for class_def in classes]
331+
332+
labels = [class_def.label for class_def in classes]
333+
duplicates = {label for label in labels if labels.count(label) > 1}
334+
if duplicates:
335+
raise ValidationError(
336+
f"Class labels must be unique. The following duplicate label(s) were found: {sorted(duplicates)}"
337+
)
338+
312339
return Column._from_logical_expr(
313340
SemanticClassifyExpr(
314341
Column._from_col_or_name(column)._logical_expr,
315-
labels,
342+
classes,
316343
examples=examples,
317344
model_alias=model_alias,
318345
temperature=temperature,

src/fenic/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ArrayType,
1111
BooleanType,
1212
BranchSide,
13+
ClassDefinition,
1314
ClassifyExample,
1415
ClassifyExampleCollection,
1516
ColumnField,
@@ -61,6 +62,7 @@
6162
"TranscriptType",
6263
"ColumnField",
6364
"Schema",
65+
"ClassDefinition",
6466
"ClassifyExample",
6567
"ClassifyExampleCollection",
6668
"JoinExample",

src/fenic/core/_logical_plan/expressions/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@
7575
from fenic.core._logical_plan.expressions.semantic import (
7676
EmbeddingsExpr as EmbeddingsExpr,
7777
)
78+
from fenic.core._logical_plan.expressions.semantic import (
79+
ResolvedClassDefinition as ResolvedClassDefinition,
80+
)
7881
from fenic.core._logical_plan.expressions.semantic import (
7982
SemanticClassifyExpr as SemanticClassifyExpr,
8083
)

0 commit comments

Comments
 (0)