Skip to content

Commit db6eeef

Browse files
authored
[VPC ChatCompletion] Add functionality for structured outputs per-field scoring (#127)
1 parent 96b118f commit db6eeef

File tree

5 files changed

+162
-79
lines changed

5 files changed

+162
-79
lines changed

CHANGELOG.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [1.1.36] - 2025-09-30
11+
12+
### Added
13+
14+
- Extend `TLMResponses` to work for OpenAI-built-in tools
15+
- Add per-field scoring functionality for structured outputs responses in VPC ChatCompletion module
16+
1017
## [1.1.35] - 2025-09-25
1118

1219
### Added
@@ -367,7 +374,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
367374

368375
- Release of the Cleanlab TLM Python client.
369376

370-
[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.35...HEAD
377+
[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.36...HEAD
378+
[1.1.36]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.35...v1.1.36
371379
[1.1.35]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.34...v1.1.35
372380
[1.1.34]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.33...v1.1.34
373381
[1.1.33]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.32...v1.1.33

src/cleanlab_tlm/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# SPDX-License-Identifier: MIT
2-
__version__ = "1.1.35"
2+
__version__ = "1.1.36"

src/cleanlab_tlm/utils/chat_completions.py

Lines changed: 8 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77

88
import asyncio
9-
import json
109
from typing import TYPE_CHECKING, Any, Optional, Union, cast
1110

1211
from cleanlab_tlm.internal.api.api import tlm_chat_completions_score
@@ -18,6 +17,7 @@
1817
from cleanlab_tlm.internal.types import TLMQualityPreset
1918
from cleanlab_tlm.tlm import TLM, TLMOptions, TLMResponse, TLMScore
2019
from cleanlab_tlm.utils.chat import _form_prompt_chat_completions_api, form_response_string_chat_completions
20+
from cleanlab_tlm.utils.per_field_score_utils import _get_untrustworthy_fields
2121

2222
if TYPE_CHECKING:
2323
from openai.types.chat import ChatCompletion, ChatCompletionMessage
@@ -251,75 +251,13 @@ def get_untrustworthy_fields(
251251
Returns:
252252
list[str]: The fields of the response that are considered untrustworthy by TLM
253253
"""
254-
try:
255-
from openai.types.chat import ChatCompletion
256-
except ImportError as e:
257-
raise ImportError(
258-
f"OpenAI is required to use the {self.__class__.__name__} class. Please install it with `pip install openai`."
259-
) from e
260-
261-
if isinstance(tlm_result, dict):
262-
if response is None:
263-
raise ValueError("'response' is required when tlm_result is a TLMScore object")
264-
265-
tlm_metadata = tlm_result
266-
response_text = response.choices[0].message.content or "{}"
267-
268-
elif isinstance(tlm_result, ChatCompletion):
269-
if getattr(tlm_result, "tlm_metadata", None) is None:
270-
raise ValueError("tlm_result must contain tlm_metadata.")
271-
272-
tlm_metadata = tlm_result.tlm_metadata # type: ignore
273-
response_text = tlm_result.choices[0].message.content or "{}"
274-
275-
else:
276-
raise TypeError("tlm_result must be a TLMScore or ChatCompletion object.")
277-
278-
if "per_field_score" not in tlm_metadata.get("log", {}):
279-
raise ValueError(
280-
"`per_field_score` is not present in the log.\n"
281-
"`get_untrustworthy_fields()` can only be called scoring structured outputs responses and specifying "
282-
"`per_field_score` in the `log` option for TLM."
283-
)
284-
285-
try:
286-
so_response = json.loads(response_text)
287-
except Exception:
288-
raise ValueError(
289-
"The LLM response must be a valid JSON output (use `response_format` to specify the output format)"
290-
)
291-
292-
per_field_score = tlm_metadata["log"]["per_field_score"]
293-
per_score_details = []
294-
295-
for key, value in per_field_score.items():
296-
score = value["score"]
297-
if float(score) < threshold:
298-
key_details = {
299-
"response": so_response[key],
300-
"score": score,
301-
"explanation": value["explanation"],
302-
}
303-
per_score_details.append({key: key_details})
304-
305-
per_score_details.sort(key=lambda x: next(iter(x.values()))["score"])
306-
untrustworthy_fields = [next(iter(item.keys())) for item in per_score_details]
307-
308-
if display_details:
309-
if len(untrustworthy_fields) == 0:
310-
print("No untrustworthy fields found")
311-
312-
else:
313-
print(f"Untrustworthy fields: {untrustworthy_fields}\n")
314-
for item in per_score_details:
315-
print(f"Field: {next(iter(item.keys()))}")
316-
details = next(iter(item.values()))
317-
print(f"Response: {details['response']}")
318-
print(f"Score: {details['score']}")
319-
print(f"Explanation: {details['explanation']}")
320-
print()
321-
322-
return untrustworthy_fields
254+
return _get_untrustworthy_fields(
255+
response=response,
256+
tlm_result=tlm_result,
257+
threshold=threshold,
258+
display_details=display_details,
259+
class_name=self.__class__.__name__,
260+
)
323261

324262
@staticmethod
325263
def _get_response_message(response: "ChatCompletion") -> "ChatCompletionMessage":
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import json
2+
from typing import TYPE_CHECKING, Optional, Union
3+
4+
from cleanlab_tlm.tlm import TLMScore
5+
6+
if TYPE_CHECKING:
7+
from openai.types.chat import ChatCompletion
8+
9+
10+
def _get_untrustworthy_fields(
11+
response: Optional["ChatCompletion"],
12+
tlm_result: Union[TLMScore, "ChatCompletion"],
13+
threshold: float,
14+
display_details: bool,
15+
class_name: str,
16+
) -> list[str]:
17+
try:
18+
from openai.types.chat import ChatCompletion
19+
except ImportError as e:
20+
raise ImportError(
21+
f"OpenAI is required to use the {class_name} class. Please install it with `pip install openai`."
22+
) from e
23+
24+
if isinstance(tlm_result, dict):
25+
if response is None:
26+
raise ValueError("'response' is required when tlm_result is a TLMScore object")
27+
28+
tlm_metadata = tlm_result
29+
response_text = response.choices[0].message.content or "{}"
30+
31+
elif isinstance(tlm_result, ChatCompletion):
32+
if getattr(tlm_result, "tlm_metadata", None) is None:
33+
raise ValueError("tlm_result must contain tlm_metadata.")
34+
35+
tlm_metadata = tlm_result.tlm_metadata # type: ignore
36+
response_text = tlm_result.choices[0].message.content or "{}"
37+
38+
else:
39+
raise TypeError("tlm_result must be a TLMScore or ChatCompletion object.")
40+
41+
if "per_field_score" not in tlm_metadata.get("log", {}):
42+
raise ValueError(
43+
"`per_field_score` is not present in the log.\n"
44+
"`get_untrustworthy_fields()` can only be called scoring structured outputs responses and specifying "
45+
"`per_field_score` in the `log` option for TLM."
46+
)
47+
48+
try:
49+
so_response = json.loads(response_text)
50+
except Exception:
51+
raise ValueError(
52+
"The LLM response must be a valid JSON output (use `response_format` to specify the output format)"
53+
)
54+
55+
per_field_score = tlm_metadata["log"]["per_field_score"]
56+
per_score_details = []
57+
58+
for key, value in per_field_score.items():
59+
score = value["score"]
60+
if float(score) < threshold:
61+
key_details = {
62+
"response": so_response[key],
63+
"score": score,
64+
"explanation": value["explanation"],
65+
}
66+
per_score_details.append({key: key_details})
67+
68+
per_score_details.sort(key=lambda x: next(iter(x.values()))["score"])
69+
untrustworthy_fields = [next(iter(item.keys())) for item in per_score_details]
70+
71+
if display_details:
72+
if len(untrustworthy_fields) == 0:
73+
print("No untrustworthy fields found")
74+
75+
else:
76+
print(f"Untrustworthy fields: {untrustworthy_fields}\n")
77+
for item in per_score_details:
78+
print(f"Field: {next(iter(item.keys()))}")
79+
details = next(iter(item.values()))
80+
print(f"Response: {details['response']}")
81+
print(f"Score: {details['score']}")
82+
print(f"Explanation: {details['explanation']}")
83+
print()
84+
85+
return untrustworthy_fields

src/cleanlab_tlm/utils/vpc/chat_completions.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,19 @@
1111
from __future__ import annotations
1212

1313
import os
14-
from typing import TYPE_CHECKING, Any, Optional
14+
from typing import TYPE_CHECKING, Any, Optional, Union, cast
1515

1616
import requests
1717

1818
from cleanlab_tlm.internal.base import BaseTLM
1919
from cleanlab_tlm.internal.constants import _VALID_TLM_QUALITY_PRESETS_CHAT_COMPLETIONS
20+
from cleanlab_tlm.tlm import TLMScore
21+
from cleanlab_tlm.utils.per_field_score_utils import _get_untrustworthy_fields
2022
from cleanlab_tlm.utils.vpc.tlm import VPCTLMOptions
2123

2224
if TYPE_CHECKING:
2325
from openai.types.chat import ChatCompletion
2426

25-
from cleanlab_tlm.internal.types import JSONDict
26-
2727

2828
class TLMChatCompletion(BaseTLM):
2929
"""
@@ -74,7 +74,7 @@ def score(
7474
*,
7575
response: ChatCompletion,
7676
**openai_kwargs: Any,
77-
) -> JSONDict:
77+
) -> TLMScore:
7878
"""Score the trustworthiness of an OpenAI ChatCompletion response.
7979
8080
Args:
@@ -84,12 +84,22 @@ def score(
8484
Returns:
8585
TLMScore: A dict containing the trustworthiness score and optional logs
8686
"""
87+
try:
88+
from openai.lib._parsing._completions import type_to_response_format_param
89+
except ImportError as e:
90+
raise ImportError(
91+
f"OpenAI is required to use the {self.__class__.__name__} class. Please install it with `pip install openai`."
92+
) from e
93+
8794
if (base_url := os.environ.get("BASE_URL")) is None:
8895
raise ValueError("BASE_URL is not set. Please set it in the environment variables.")
8996

9097
# replace the model used for scoring with the specified model in options
9198
openai_kwargs["model"] = self._options["model"]
9299

100+
if "response_format" in openai_kwargs:
101+
openai_kwargs["response_format"] = type_to_response_format_param(openai_kwargs["response_format"])
102+
93103
res = requests.post(
94104
f"{base_url}/chat/score",
95105
json={
@@ -112,7 +122,49 @@ def score(
112122

113123
res_json = res.json()
114124
tlm_result = {"trustworthiness_score": res_json["tlm_metadata"]["trustworthiness_score"]}
115-
if explanation := res_json["tlm_metadata"].get("log", {}).get("explanation"):
116-
tlm_result["log"] = {"explanation": explanation}
117125

118-
return tlm_result
126+
if self._return_log:
127+
log = {}
128+
129+
log_options = cast(list[str], self._options.get("log", []))
130+
if "explanation" in log_options:
131+
explanation = res_json["tlm_metadata"].get("log", {}).get("explanation")
132+
log["explanation"] = explanation
133+
134+
if "per_field_score" in log_options:
135+
per_field_score = res_json["tlm_metadata"].get("log", {}).get("per_field_score")
136+
log["per_field_score"] = per_field_score
137+
138+
tlm_result["log"] = log
139+
140+
return cast(TLMScore, tlm_result)
141+
142+
def get_untrustworthy_fields(
    self,
    *,
    response: Optional["ChatCompletion"] = None,
    tlm_result: "Union[TLMScore, ChatCompletion]",
    threshold: float = 0.8,
    display_details: bool = True,
) -> list[str]:
    """Gets the fields of a structured output response that are considered untrustworthy by TLM.
    Only works for responses that are valid JSON objects (uses `response_format` to specify the output format).
    Prints detailed information about the untrustworthy fields if `display_details` is True.

    Args:
        response (ChatCompletion): The OpenAI ChatCompletion response object to evaluate
        tlm_result (TLMScore | ChatCompletion): The result object from a previous TLM call
        threshold (float): The threshold for considering a field untrustworthy
        display_details (bool): Whether to display detailed information about the untrustworthy fields

    Returns:
        list[str]: The fields of the response that are considered untrustworthy by TLM
    """
    # Delegate to the shared per-field scoring helper; only the class name differs per caller.
    details_kwargs = {
        "response": response,
        "tlm_result": tlm_result,
        "threshold": threshold,
        "display_details": display_details,
        "class_name": type(self).__name__,
    }
    return _get_untrustworthy_fields(**details_kwargs)

0 commit comments

Comments
 (0)