
Commit a8c8a2f

Add better structured outputs handling for ChatCompletions (#95)
1 parent fac93e7 commit a8c8a2f

File tree (5 files changed: +166 −6 lines)

  CHANGELOG.md
  src/cleanlab_tlm/__about__.py
  src/cleanlab_tlm/internal/api/api.py
  src/cleanlab_tlm/utils/chat_completions.py
  tests/test_chat_completions.py

CHANGELOG.md

Lines changed: 9 additions & 0 deletions

@@ -7,12 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.1.17] - 2025-07-18
+
+### Changed
+
+- Enabled `TLMChatCompletion.score()` to evaluate structured outputs in `ChatCompletion` objects
+
+
 ## [1.1.16] - 2025-07-15
 
 ### Changed
 
 - Add internal setting to bypass model validation check (for custom/VPC models)
 
+
 ## [1.1.15] - 2025-07-14
 
 ### Changed
@@ -28,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Support for `best` and `high` preset in `TrustworthyRAG`
 
 ### Changed
+
 - Deprecate `use_self_reflection`
 - Documentation updates for new default configurations
 
src/cleanlab_tlm/__about__.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 # SPDX-License-Identifier: MIT
-__version__ = "1.1.16"
+__version__ = "1.1.17"

src/cleanlab_tlm/internal/api/api.py

Lines changed: 52 additions & 2 deletions

@@ -45,6 +45,7 @@
 
 if TYPE_CHECKING:
     import requests
+    from openai.types.chat import ChatCompletion
 
     from cleanlab_tlm.internal.concurrency import TlmRateHandler
     from cleanlab_tlm.utils.rag import Eval
@@ -53,6 +54,7 @@
 base_url = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api")
 tlm_base_url = f"{base_url}/v0/trustworthy_llm"
 tlm_rag_base_url = f"{base_url}/v1/rag_trustworthy_llm"
+tlm_openai_base_url = f"{base_url}/v1/openai_trustworthy_llm"
 
 
 def _construct_headers(api_key: Optional[str], content_type: Optional[str] = "application/json") -> JSONDict:
@@ -115,7 +117,7 @@ def handle_rate_limit_error_from_resp(resp: aiohttp.ClientResponse) -> None:
     )
 
 
-async def handle_tlm_client_error_from_resp(resp: aiohttp.ClientResponse, batch_index: Optional[int]) -> None:
+async def handle_tlm_client_error_from_resp(resp: aiohttp.ClientResponse, batch_index: Optional[int] = None) -> None:
     """Catches 4XX (client error) errors."""
     if 400 <= resp.status < 500:  # noqa: PLR2004
         try:
@@ -133,7 +135,7 @@ async def handle_tlm_client_error_from_resp(resp: aiohttp.ClientResponse, batch_
         raise TlmBadRequestError(error_message, retryable)
 
 
-async def handle_tlm_api_error_from_resp(resp: aiohttp.ClientResponse, batch_index: Optional[int]) -> None:
+async def handle_tlm_api_error_from_resp(resp: aiohttp.ClientResponse, batch_index: Optional[int] = None) -> None:
     """Catches 5XX (server error) errors."""
     if 500 <= resp.status < 600:  # noqa: PLR2004
         try:
@@ -527,3 +529,51 @@ async def tlm_rag_score(
             ordered_res[evaluation.name] = res_json[evaluation.name]
 
     return ordered_res
+
+
+@tlm_retry
+async def tlm_chat_completions_score(
+    api_key: str,
+    response: ChatCompletion,
+    client_session: Optional[aiohttp.ClientSession] = None,
+    **input_kwargs: Any,
+) -> JSONDict:
+    """
+    Score an OpenAI ChatCompletion response using Trustworthy Language Model
+
+    Args:
+        api_key (str): API key for auth
+        response (ChatCompletion): response to be evaluated (OpenAI ChatCompletion object)
+        client_session (aiohttp.ClientSession): client session used to issue TLM request
+        **input_kwargs: additional keyword arguments (openai arguments or TLM options) to pass to the TLM request.
+    Returns:
+        JSONDict: dictionary with trustworthiness score and any evaluation results
+    """
+    local_scoped_client = False
+    if not client_session:
+        client_session = aiohttp.ClientSession()
+        local_scoped_client = True
+
+    try:
+        res = await client_session.post(
+            f"{tlm_openai_base_url}/score",
+            json={
+                "response": response.model_dump(),
+                **input_kwargs,
+            },
+            headers=_construct_headers(api_key),
+        )
+
+        res_json = await res.json()
+
+        await handle_api_key_error_from_resp(res)
+        await handle_http_bad_request_error_from_resp(res)
+        handle_rate_limit_error_from_resp(res)
+        await handle_tlm_client_error_from_resp(res)
+        await handle_tlm_api_error_from_resp(res)
+
+    finally:
+        if local_scoped_client:
+            await client_session.close()
+
+    return cast(JSONDict, res_json)
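For reference, this internal helper posts the serialized `ChatCompletion` plus any forwarded kwargs to the new `/score` route under `tlm_openai_base_url`. Below is a minimal sketch of calling it directly, assuming a valid Cleanlab API key and a previously obtained `ChatCompletion` object (both hypothetical here); when no `client_session` is supplied, the helper opens and closes its own `aiohttp` session, as the `local_scoped_client` branch above shows:

import asyncio

import aiohttp
from openai.types.chat import ChatCompletion

from cleanlab_tlm.internal.api.api import tlm_chat_completions_score


async def score_directly(api_key: str, response: ChatCompletion, messages: list) -> None:
    # Let the helper manage its own aiohttp session.
    result = await tlm_chat_completions_score(
        api_key=api_key,
        response=response,
        messages=messages,  # forwarded to the endpoint via **input_kwargs
        model="gpt-4.1-mini",
    )
    print(result)

    # Or reuse a caller-owned session across several scoring calls.
    async with aiohttp.ClientSession() as session:
        result = await tlm_chat_completions_score(
            api_key=api_key,
            response=response,
            client_session=session,
            messages=messages,
            model="gpt-4.1-mini",
        )
        print(result)


# asyncio.run(score_directly(api_key, response, messages))  # hypothetical inputs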

src/cleanlab_tlm/utils/chat_completions.py

Lines changed: 24 additions & 2 deletions

@@ -5,12 +5,14 @@
 It works for any OpenAI LLM model, as well as the many other non-OpenAI LLMs that are also usable via Chat Completions API (Gemini, DeepSeek, Llama, etc).
 """
 
+import asyncio
 from typing import TYPE_CHECKING, Any, Optional, cast
 
+from cleanlab_tlm.internal.api.api import tlm_chat_completions_score
 from cleanlab_tlm.internal.base import BaseTLM
 from cleanlab_tlm.internal.constants import (
     _DEFAULT_TLM_QUALITY_PRESET,
-    _VALID_TLM_QUALITY_PRESETS_CHAT_COMPLETIONS,
+    _VALID_TLM_QUALITY_PRESETS,
 )
 from cleanlab_tlm.internal.types import TLMQualityPreset
 from cleanlab_tlm.tlm import TLM, TLMOptions, TLMScore
@@ -52,7 +54,7 @@ def __init__(
         """
         super().__init__(
             quality_preset=quality_preset,
-            valid_quality_presets=_VALID_TLM_QUALITY_PRESETS_CHAT_COMPLETIONS,
+            valid_quality_presets=_VALID_TLM_QUALITY_PRESETS,
             support_custom_eval_criteria=True,
             api_key=api_key,
             options=options,
@@ -85,6 +87,26 @@ def score(
         self._validate_chat_completion(response)
         if (messages := openai_kwargs.get("messages")) is None:
             raise ValueError("messages is a required OpenAI input argument.")
+
+        combined_kwargs = {**openai_kwargs, **self._options}
+
+        # handle structured outputs differently
+        if openai_kwargs.get("response_format"):
+            return cast(
+                TLMScore,
+                self._event_loop.run_until_complete(
+                    asyncio.wait_for(
+                        tlm_chat_completions_score(
+                            api_key=self._api_key,
+                            response=response,
+                            **combined_kwargs,
+                        ),
+                        timeout=self._timeout,
+                    )
+                ),
+            )
+
+        # all other cases
         tools = openai_kwargs.get("tools", None)
 
         prompt_text = _form_prompt_chat_completions_api(messages, tools)
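The structured-output branch reuses the instance's event loop and bounds the request with `asyncio.wait_for`, so the synchronous `score()` call still respects `self._timeout`. A standalone sketch of that pattern, with illustrative names rather than the library's own:

import asyncio


async def slow_request() -> str:
    # Stand-in for the async scoring request.
    await asyncio.sleep(2)
    return "done"


loop = asyncio.new_event_loop()
try:
    # asyncio.wait_for raises asyncio.TimeoutError if the coroutine
    # does not finish within the timeout (here, 5 seconds).
    result = loop.run_until_complete(asyncio.wait_for(slow_request(), timeout=5))
    print(result)
finally:
    loop.close()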

tests/test_chat_completions.py

Lines changed: 80 additions & 1 deletion

@@ -5,6 +5,11 @@
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
+from openai.types.completion_usage import (
+    CompletionTokensDetails,
+    CompletionUsage,
+    PromptTokensDetails,
+)
 
 from cleanlab_tlm.internal.types import TLMQualityPreset
 from cleanlab_tlm.tlm import TLMScore
@@ -19,7 +24,7 @@
 
 @pytest.mark.parametrize(
     "quality_preset",
-    ["base", "low", "medium"],
+    ["base", "low", "medium", "high", "best"],
 )
 def test_tlm_chat_completion_score(quality_preset: TLMQualityPreset) -> None:
     tlm_chat = TLMChatCompletion(quality_preset=quality_preset)
@@ -118,6 +123,80 @@ def test_tlm_chat_completion_score_with_tools() -> None:
     assert is_trustworthiness_score_json_format(score)
 
 
+def test_tlm_chat_completion_score_with_structured_output() -> None:
+    tlm_chat = TLMChatCompletion()
+    openai_kwargs = {
+        "model": "gpt-4.1-mini",
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are a helpful math tutor. Guide the user through the solution step by step.",
+            },
+            {"role": "user", "content": "how can I solve 8x + 7 = -23"},
+        ],
+        "response_format": {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "math_reasoning",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "steps": {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "properties": {
+                                    "explanation": {"type": "string"},
+                                    "output": {"type": "string"},
+                                },
+                                "required": ["explanation", "output"],
+                                "additionalProperties": False,
+                            },
+                        },
+                        "final_answer": {"type": "string"},
+                    },
+                    "required": ["steps", "final_answer"],
+                    "additionalProperties": False,
+                },
+                "strict": True,
+            },
+        },
+    }
+    response = ChatCompletion(
+        id="test",
+        choices=[
+            Choice(
+                index=0,
+                message=ChatCompletionMessage(
+                    role="assistant",
+                    content='{"steps":[{"explanation":"Start with the original equation: 8x + 7 = -23","output":"8x + 7 = -23"},{"explanation":"Subtract 7 from both sides to isolate the term with x on one side. This will give us: 8x = -23 - 7","output":"8x = -30"},{"explanation":"Now simplify the right side: -23 - 7 equals -30, so we have 8x = -30","output":"8x = -30"},{"explanation":"Next, divide both sides by 8 to solve for x. This gives us: x = -30 / 8","output":"x = -3.75"},{"explanation":"We can also simplify -30 / 8 by dividing both the numerator and the denominator by 2. This leads to: x = -15 / 4","output":"x = -15/4 (or -3.75 as a decimal)"}],"final_answer":"x = -15/4 or x = -3.75"}',
+                ),
+                finish_reason="stop",
+            )
+        ],
+        usage=CompletionUsage(
+            completion_tokens=50,
+            completion_tokens_details=CompletionTokensDetails(
+                accepted_prediction_tokens=0,
+                audio_tokens=0,
+                reasoning_tokens=0,
+                rejected_prediction_tokens=0,
+            ),
+            prompt_tokens=50,
+            prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0),
+            total_tokens=100,
+        ),
+        created=1234567890,
+        model="test-model",
+        object="chat.completion",
+    )
+
+    score = tlm_chat.score(response=response, **openai_kwargs)
+
+    assert score is not None
+    assert is_trustworthiness_score_json_format(score)
+
+
 def test_tlm_chat_completion_score_invalid_response() -> None:
     tlm_chat = TLMChatCompletion()
     openai_kwargs = {
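The widened parametrization above reflects that `TLMChatCompletion` now accepts the full set of quality presets rather than only `base`, `low`, and `medium`. A minimal sketch, with the scoring call left as a comment since it needs a real `ChatCompletion` and configured API keys:

from cleanlab_tlm.utils.chat_completions import TLMChatCompletion

# "high" and "best" are now valid quality presets for TLMChatCompletion.
tlm_chat = TLMChatCompletion(quality_preset="best")

# score = tlm_chat.score(response=response, **openai_kwargs)  # hypothetical inputs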
