Commit f8d50a1

[TLMChatCompletion] Add functionality for structured outputs per-field scoring (#120)

1 parent 64eb488 commit f8d50a1

File tree: 7 files changed, +217 -9 lines changed


CHANGELOG.md

Lines changed: 8 additions & 1 deletion
```diff
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.1.32] - 2025-09-22
+
+### Added
+
+- Add per-field scoring functionality for structured outputs responses in `TLMChatCompletion`
+
 ## [1.1.31] - 2025-09-18
 
 ### Added
@@ -343,7 +349,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Release of the Cleanlab TLM Python client.
 
-[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.31...HEAD
+[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.32...HEAD
+[1.1.32]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.31...v1.1.32
 [1.1.31]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.30...v1.1.31
 [1.1.30]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.29...v1.1.30
 [1.1.29]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.28...v1.1.29
```
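
Taken together with the code changes below, the new entry amounts to a workflow like the following. This is a minimal usage sketch, not taken from the repo's docs: it assumes OpenAI and Cleanlab TLM credentials are already configured in the environment, and the model name and JSON schema are illustrative values borrowed from the test added in this commit.

```python
# Minimal sketch of the new per-field scoring workflow (illustrative only).
from openai import OpenAI
from cleanlab_tlm.utils.chat_completions import TLMChatCompletion

openai_kwargs = {
    "model": "gpt-4.1-mini",  # any chat model; value taken from the test below
    "messages": [{"role": "user", "content": "how can I solve 8x + 7 = -23"}],
    "response_format": {  # structured-output schema; a small illustrative example
        "type": "json_schema",
        "json_schema": {
            "name": "answer",
            "schema": {
                "type": "object",
                "properties": {"final_answer": {"type": "string"}},
                "required": ["final_answer"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    },
}

# Get a structured-output completion from OpenAI, then score it with TLM.
response = OpenAI().chat.completions.create(**openai_kwargs)

tlm = TLMChatCompletion(options={"log": ["per_field_score"]})
score = tlm.score(response=response, **openai_kwargs)  # TLMScore dict with a per_field_score log

# List the response fields whose per-field trustworthiness falls below the threshold.
untrustworthy = tlm.get_untrustworthy_fields(response=response, tlm_result=score, threshold=0.8)
```

With `display_details` left at its default of `True`, `get_untrustworthy_fields()` also prints the flagged fields' values, scores, and explanations.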

src/cleanlab_tlm/__about__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,2 +1,2 @@
 # SPDX-License-Identifier: MIT
-__version__ = "1.1.31"
+__version__ = "1.1.32"
```

src/cleanlab_tlm/internal/api/api.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -41,6 +41,7 @@
     _TLM_TRUSTWORTHINESS_KEY,
     _TLM_USER_ID_KEY,
 )
+from cleanlab_tlm.internal.exception_handling import handle_tlm_exceptions
 from cleanlab_tlm.internal.types import JSONDict
 
 if TYPE_CHECKING:
@@ -533,6 +534,7 @@ async def tlm_rag_score(
 
 
 @tlm_retry
+@handle_tlm_exceptions(response_type="TLMScore")
 async def tlm_chat_completions_score(
     api_key: str,
     response: ChatCompletion,
@@ -577,7 +579,14 @@ async def tlm_chat_completions_score(
     if local_scoped_client:
         await client_session.close()
 
-    return cast(JSONDict, res_json)
+    tlm_result = {
+        "trustworthiness_score": res_json["trustworthiness_score"],
+    }
+
+    if "log" in input_kwargs:
+        tlm_result["log"] = res_json["log"]
+
+    return tlm_result
 
 
 @tlm_retry
```
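
For context, the filtering above means callers now get back only the documented `TLMScore` keys rather than the raw backend JSON. A rough sketch of the resulting shape, with made-up values (the backend payload here is hypothetical):

```python
# Made-up backend payload, only to illustrate the filtering added in this commit.
res_json = {"trustworthiness_score": 0.42, "log": {"per_field_score": {}}}
input_kwargs = {"log": ["per_field_score"]}  # what the caller requested

tlm_result = {"trustworthiness_score": res_json["trustworthiness_score"]}
if "log" in input_kwargs:
    tlm_result["log"] = res_json["log"]

print(tlm_result)  # {'trustworthiness_score': 0.42, 'log': {'per_field_score': {}}}
```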

src/cleanlab_tlm/internal/constants.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -66,7 +66,7 @@
     "discrepancy",
 }
 TLM_REASONING_EFFORT_VALUES: set[str] = {"none", "low", "medium", "high"}
-TLM_VALID_LOG_OPTIONS: set[str] = {"perplexity", "explanation"}
+TLM_VALID_LOG_OPTIONS: set[str] = {"perplexity", "explanation", "per_field_score"}
 TLM_VALID_GET_TRUSTWORTHINESS_SCORE_KWARGS: set[str] = {
     "perplexity",
     _TLM_CONSTRAIN_OUTPUTS_KEY,
```

src/cleanlab_tlm/internal/exception_handling.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -58,7 +58,7 @@ def decorator(
         async def wrapper(*args: Any, **kwargs: Any) -> ResponseT:
             capture_exceptions = kwargs.get("capture_exceptions", False)
             batch_index = kwargs.get("batch_index")
-            evals = getattr(args[0], "_evals", [])
+            evals = getattr(args[0], "_evals", []) if args else []
             try:
                 return await func(*args, **kwargs)
             except asyncio.TimeoutError:
```
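
The `if args else []` guard matters because the decorator is now also applied to a module-level coroutine (`tlm_chat_completions_score`), which may be invoked with keyword arguments only, leaving `args` empty. A standalone sketch of the failure mode, not taken from the repo:

```python
# Illustrative only: the old expression assumed at least one positional argument.
class FakeTLM:
    _evals = ["custom_eval"]

def old_lookup(*args):
    return getattr(args[0], "_evals", [])  # raises IndexError when args == ()

def new_lookup(*args):
    return getattr(args[0], "_evals", []) if args else []  # guarded, returns []

print(new_lookup(FakeTLM()))  # ['custom_eval']
print(new_lookup())           # []
```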

src/cleanlab_tlm/utils/chat_completions.py

Lines changed: 108 additions & 2 deletions
```diff
@@ -6,6 +6,7 @@
 """
 
 import asyncio
+import json
 from typing import TYPE_CHECKING, Any, Optional, Union, cast
 
 from cleanlab_tlm.internal.api.api import tlm_chat_completions_score
@@ -84,6 +85,13 @@ def score(
         Returns:
             TLMScore: A dict containing the trustworthiness score and optional logs
         """
+        try:
+            from openai.lib._parsing._completions import type_to_response_format_param
+        except ImportError as e:
+            raise ImportError(
+                f"OpenAI is required to use the {self.__class__.__name__} class. Please install it with `pip install openai`."
+            ) from e
+
         self._validate_chat_completion(response)
         if (messages := openai_kwargs.get("messages")) is None:
             raise ValueError("messages is a required OpenAI input argument.")
@@ -95,7 +103,14 @@ def score(
         }
 
         # handle structured outputs differently
-        if openai_kwargs.get("response_format"):
+        if combined_kwargs.get("response_format"):
+            if "log" in combined_kwargs and "explanation" in combined_kwargs["log"]:
+                raise ValueError(
+                    "`explanation` is not supported when `response_format` is specified, "
+                    "use `per_field_score` instead to get detailed explanations for each field"
+                )
+
+            combined_kwargs["response_format"] = type_to_response_format_param(combined_kwargs["response_format"])
             return cast(
                 TLMScore,
                 self._event_loop.run_until_complete(
@@ -111,7 +126,7 @@ def score(
             )
 
         # all other cases
-        tools = openai_kwargs.get("tools", None)
+        tools = combined_kwargs.get("tools")
 
         prompt_text = _form_prompt_chat_completions_api(messages, tools)
         response_text = form_response_string_chat_completions(response=response)
@@ -195,6 +210,97 @@ def get_explanation(
 
         raise TypeError("tlm_result must be a TLMScore or ChatCompletion object.")
 
+    def get_untrustworthy_fields(
+        self,
+        *,
+        response: Optional["ChatCompletion"] = None,
+        tlm_result: Union[TLMScore, "ChatCompletion"],
+        threshold: float = 0.8,
+        display_details: bool = True,
+    ) -> list[str]:
+        """Gets the fields of a structured output response that are considered untrustworthy by TLM.
+        Only works for responses that are valid JSON objects (uses `response_format` to specify the output format).
+        Prints detailed information about the untrustworthy fields if `display_details` is True.
+
+        Args:
+            response (ChatCompletion): The OpenAI ChatCompletion response object to evaluate
+            tlm_result (TLMScore | ChatCompletion): The result object from a previous TLM call
+            threshold (float): The threshold for considering a field untrustworthy
+            display_details (bool): Whether to display detailed information about the untrustworthy fields
+
+        Returns:
+            list[str]: The fields of the response that are considered untrustworthy by TLM
+        """
+        try:
+            from openai.types.chat import ChatCompletion
+        except ImportError as e:
+            raise ImportError(
+                f"OpenAI is required to use the {self.__class__.__name__} class. Please install it with `pip install openai`."
+            ) from e
+
+        if isinstance(tlm_result, dict):
+            if response is None:
+                raise ValueError("'response' is required when tlm_result is a TLMScore object")
+
+            tlm_metadata = tlm_result
+            response_text = response.choices[0].message.content or "{}"
+
+        elif isinstance(tlm_result, ChatCompletion):
+            if getattr(tlm_result, "tlm_metadata", None) is None:
+                raise ValueError("tlm_result must contain tlm_metadata.")
+
+            tlm_metadata = tlm_result.tlm_metadata  # type: ignore
+            response_text = tlm_result.choices[0].message.content or "{}"
+
+        else:
+            raise TypeError("tlm_result must be a TLMScore or ChatCompletion object.")
+
+        if "per_field_score" not in tlm_metadata.get("log", {}):
+            raise ValueError(
+                "`per_field_score` is not present in the log.\n"
+                "`get_untrustworthy_fields()` can only be called scoring structured outputs responses and specifying "
+                "`per_field_score` in the `log` option for TLM."
+            )
+
+        try:
+            so_response = json.loads(response_text)
+        except Exception:
+            raise ValueError(
+                "The LLM response must be a valid JSON output (use `response_format` to specify the output format)"
+            )
+
+        per_field_score = tlm_metadata["log"]["per_field_score"]
+        per_score_details = []
+
+        for key, value in per_field_score.items():
+            score = value["score"]
+            if float(score) < threshold:
+                key_details = {
+                    "response": so_response[key],
+                    "score": score,
+                    "explanation": value["explanation"],
+                }
+                per_score_details.append({key: key_details})
+
+        per_score_details.sort(key=lambda x: next(iter(x.values()))["score"])
+        untrustworthy_fields = [next(iter(item.keys())) for item in per_score_details]
+
+        if display_details:
+            if len(untrustworthy_fields) == 0:
+                print("No untrustworthy fields found")
+
+            else:
+                print(f"Untrustworthy fields: {untrustworthy_fields}\n")
+                for item in per_score_details:
+                    print(f"Field: {next(iter(item.keys()))}")
+                    details = next(iter(item.values()))
+                    print(f"Response: {details['response']}")
+                    print(f"Score: {details['score']}")
+                    print(f"Explanation: {details['explanation']}")
+                    print()
+
+        return untrustworthy_fields
+
     @staticmethod
     def _get_response_message(response: "ChatCompletion") -> "ChatCompletionMessage":
         return response.choices[0].message
```
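
To make the selection logic above concrete, here is a standalone re-implementation of the core of `get_untrustworthy_fields()` with made-up scores; in real use the `per_field_score` entries come back from the TLM backend in the result's `log`.

```python
# Standalone illustration of the field-selection logic (scores are invented).
per_field_score = {
    "steps": {"score": 0.93, "explanation": "Derivation is internally consistent."},
    "final_answer": {"score": 0.31, "explanation": "Final answer does not match the worked steps."},
}
threshold = 0.8

# Keep only fields scoring below the threshold, sorted from least to most trustworthy.
flagged = [
    {field: details}
    for field, details in per_field_score.items()
    if float(details["score"]) < threshold
]
flagged.sort(key=lambda item: next(iter(item.values()))["score"])

untrustworthy_fields = [next(iter(item)) for item in flagged]
print(untrustworthy_fields)  # ['final_answer']
```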

tests/test_chat_completions.py

Lines changed: 88 additions & 2 deletions
```diff
@@ -206,6 +206,86 @@ def test_tlm_chat_completion_score_with_structured_output() -> None:
     assert is_trustworthiness_score_json_format(score)
 
 
+def test_tlm_chat_completion_structured_output_per_field_scoring() -> None:
+    tlm_chat = TLMChatCompletion(options={"log": ["per_field_score"]})
+
+    openai_kwargs = {
+        "model": "gpt-4.1-mini",
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are a helpful math tutor. Guide the user through the solution step by step.",
+            },
+            {"role": "user", "content": "how can I solve 8x + 7 = -23"},
+        ],
+        "response_format": {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "math_reasoning",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "steps": {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "properties": {
+                                    "explanation": {"type": "string"},
+                                    "output": {"type": "string"},
+                                },
+                                "required": ["explanation", "output"],
+                                "additionalProperties": False,
+                            },
+                        },
+                        "final_answer": {"type": "string"},
+                    },
+                    "required": ["steps", "final_answer"],
+                    "additionalProperties": False,
+                },
+                "strict": True,
+            },
+        },
+    }
+    response = ChatCompletion(
+        id="test",
+        choices=[
+            Choice(
+                index=0,
+                message=ChatCompletionMessage(
+                    role="assistant",
+                    content='{"steps":[{"explanation":"Start with the original equation: 8x + 7 = -23","output":"8x + 7 = -23"},{"explanation":"Subtract 7 from both sides to isolate the term with x on one side. This will give us: 8x = -23 - 7","output":"8x = -30"},{"explanation":"Now simplify the right side: -23 - 7 equals -30, so we have 8x = -30","output":"8x = -30"},{"explanation":"Next, divide both sides by 8 to solve for x. This gives us: x = -30 / 8","output":"x = -3.75"},{"explanation":"We can also simplify -30 / 8 by dividing both the numerator and the denominator by 2. This leads to: x = -15 / 4","output":"x = -15/4 (or -3.75 as a decimal)"}],"final_answer":"x = -17/4"}',
+                ),
+                finish_reason="stop",
+            )
+        ],
+        usage=CompletionUsage(
+            completion_tokens=50,
+            completion_tokens_details=CompletionTokensDetails(
+                accepted_prediction_tokens=0,
+                audio_tokens=0,
+                reasoning_tokens=0,
+                rejected_prediction_tokens=0,
+            ),
+            prompt_tokens=50,
+            prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
+            total_tokens=100,
+        ),
+        created=1234567890,
+        model="test-model",
+        object="chat.completion",
+    )
+
+    score = tlm_chat.score(response=response, **openai_kwargs)
+
+    assert score is not None
+    assert is_trustworthiness_score_json_format(score)
+
+    # test per_field_score
+    assert len(score["log"]["per_field_score"]) == 2  # noqa: PLR2004
+    assert {"steps", "final_answer"} == set(score["log"]["per_field_score"].keys())
+    assert tlm_chat.get_untrustworthy_fields(response=response, tlm_result=score) == ["final_answer"]
+
+
 def test_tlm_chat_completion_score_invalid_response() -> None:
     tlm_chat = TLMChatCompletion()
     openai_kwargs = {
@@ -248,8 +328,14 @@ def test_tlm_chat_completion_score_missing_messages() -> None:
 @pytest.mark.parametrize(
     "arguments, condition",  # noqa: PT006
     [
-        (json.dumps({"query": "Capital of Germany"}), lambda score: score["trustworthiness_score"] < 0.5),  # noqa: PLR2004
-        (json.dumps({"query": "Capital of France"}), lambda score: score["trustworthiness_score"] >= 0.8),  # noqa: PLR2004
+        (
+            json.dumps({"query": "Capital of Germany"}),
+            lambda score: score["trustworthiness_score"] < 0.5,  # noqa: PLR2004
+        ),
+        (
+            json.dumps({"query": "Capital of France"}),
+            lambda score: score["trustworthiness_score"] >= 0.8,  # noqa: PLR2004
+        ),
     ],
     ids=["bad_arguments", "good_arguments"],
 )
```
