Skip to content

Commit 1b43e21

Browse files
jwmueller and elisno authored
Support for scoring tool calls in ChatCompletions API (#93)
Co-authored-by: Elías Snorrason <[email protected]>
1 parent 78c3264 commit 1b43e21

File tree

6 files changed

+335
-47
lines changed

6 files changed

+335
-47
lines changed

CHANGELOG.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [1.1.15] - 2025-07-14
11+
12+
### Changed
13+
14+
- Enabled `TLMChatCompletion.score()` to evaluate tool calls in `ChatCompletion` objects
15+
16+
1017
## [1.1.14] - 2025-07-08
1118

1219
### Added
@@ -237,7 +244,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
237244
- Release of the Cleanlab TLM Python client.
238245

239246

240-
[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.14...HEAD
247+
[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.15...HEAD
248+
[1.1.15]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.14...v1.1.15
241249
[1.1.14]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.13...v1.1.14
242250
[1.1.13]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.12...v1.1.13
243251
[1.1.12]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.11...v1.1.12

src/cleanlab_tlm/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# SPDX-License-Identifier: MIT
2-
__version__ = "1.1.14"
2+
__version__ = "1.1.15"

src/cleanlab_tlm/utils/chat.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
import json
88
import warnings
9-
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
9+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
1010

1111
if TYPE_CHECKING:
12-
from openai.types.chat import ChatCompletionMessageParam
12+
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
1313

1414
# Define message prefixes
1515
_SYSTEM_PREFIX = "System: "
@@ -443,41 +443,44 @@ def form_prompt_string(
443443
)
444444

445445

446-
def form_response_string_chat_completions_api(response: dict[str, Any]) -> str:
446+
def form_response_string_chat_completions_api(response: Union[dict[str, Any], "ChatCompletionMessage"]) -> str:
447447
"""
448448
Format an assistant response message dictionary from the Chat Completions API into a single string.
449449
450-
This function takes a response.choices[0].message.to_dict() from a chat.completions.create()
451-
and formats it into a string that includes both content and tool calls (if present).
450+
Given a ChatCompletion object `response` from `chat.completions.create()`,
451+
this function can take either a ChatCompletionMessage object from `response.choices[0].message`
452+
or a dictionary from `response.choices[0].message.to_dict()`.
453+
454+
All inputs are formatted into a string that includes both content and tool calls (if present).
452455
Tool calls are formatted using XML tags with JSON content, consistent with the format
453456
used in `form_prompt_string`.
454457
455458
Args:
456-
response (dict[str, Any]): A chat completion response message dictionary, containing:
457-
- 'content' (str): The main response content from the LLM
458-
- 'tool_calls' (List[Dict], optional): List of tool calls made by the LLM,
459-
where each tool call contains function name and arguments
459+
response (Union[dict[str, Any], ChatCompletionMessage]): Either:
460+
- A ChatCompletionMessage object from the OpenAI response
461+
- A chat completion response message dictionary, containing:
462+
- 'content' (str): The main response content from the LLM
463+
- 'tool_calls' (List[Dict], optional): List of tool calls made by the LLM,
464+
where each tool call contains function name and arguments
460465
461466
Returns:
462467
str: A formatted string containing the response content and any tool calls.
463468
Tool calls are formatted as XML tags containing JSON with function
464469
name and arguments.
465470
466471
Raises:
467-
TypeError: If response is not a dictionary.
472+
TypeError: If response is not a dictionary or ChatCompletionMessage object.
468473
"""
469-
if not isinstance(response, dict):
470-
raise TypeError(f"Expected response to be a dict, got {type(response).__name__}")
471-
472-
content = response.get("content") or ""
473-
474-
if "tool_calls" in response:
474+
response_dict = _response_to_dict(response)
475+
content = response_dict.get("content") or ""
476+
tool_calls = response_dict.get("tool_calls")
477+
if tool_calls is not None:
475478
try:
476-
tool_calls = "\n".join(
479+
tool_calls_str = "\n".join(
477480
f"{_TOOL_CALL_TAG_START}\n{json.dumps({'name': call['function']['name'], 'arguments': json.loads(call['function']['arguments']) if call['function']['arguments'] else {}}, indent=2)}\n{_TOOL_CALL_TAG_END}"
478-
for call in response["tool_calls"]
481+
for call in tool_calls
479482
)
480-
return f"{content}\n{tool_calls}".strip() if content else tool_calls
483+
return f"{content}\n{tool_calls_str}".strip() if content else tool_calls_str
481484
except (KeyError, TypeError, json.JSONDecodeError) as e:
482485
# Log the error but continue with just the content
483486
warnings.warn(
@@ -487,3 +490,24 @@ def form_response_string_chat_completions_api(response: dict[str, Any]) -> str:
487490
)
488491

489492
return str(content)
493+
494+
495+
def _response_to_dict(response: Any) -> dict[str, Any]:
496+
# `response` should be a Union[dict[str, Any], ChatCompletionMessage], but last isinstance check wouldn't be reachable
497+
if isinstance(response, dict):
498+
# Start with this isinstance check first to import `openai` lazily
499+
return response
500+
501+
try:
502+
from openai.types.chat import ChatCompletionMessage
503+
except ImportError as e:
504+
raise ImportError(
505+
"OpenAI is required to handle ChatCompletionMessage objects directly. Please install it with `pip install openai`."
506+
) from e
507+
508+
if not isinstance(response, ChatCompletionMessage):
509+
raise TypeError(
510+
f"Expected response to be a dict or ChatCompletionMessage object, got {type(response).__name__}"
511+
)
512+
513+
return response.model_dump()

src/cleanlab_tlm/utils/chat_completions.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
)
1515
from cleanlab_tlm.internal.types import TLMQualityPreset
1616
from cleanlab_tlm.tlm import TLM, TLMOptions, TLMScore
17-
from cleanlab_tlm.utils.chat import form_prompt_string
17+
from cleanlab_tlm.utils.chat import _form_prompt_chat_completions_api, form_response_string_chat_completions_api
1818

1919
if TYPE_CHECKING:
20-
from openai.types.chat import ChatCompletion
20+
from openai.types.chat import ChatCompletion, ChatCompletionMessage
2121

2222

2323
class TLMChatCompletion(BaseTLM):
@@ -82,26 +82,31 @@ def score(
8282
Returns:
8383
TLMScore: A dict containing the trustworthiness score and optional logs
8484
"""
85+
self._validate_chat_completion(response)
8586
if (messages := openai_kwargs.get("messages")) is None:
8687
raise ValueError("messages is a required OpenAI input argument.")
8788
tools = openai_kwargs.get("tools", None)
8889

89-
prompt_text = form_prompt_string(messages, tools)
90-
response_text = _get_string_response(response)
90+
prompt_text = _form_prompt_chat_completions_api(messages, tools)
91+
response_text = form_response_string_chat_completions_api(response=self._get_response_message(response))
9192

9293
return cast(TLMScore, self._tlm.get_trustworthiness_score(prompt_text, response_text))
9394

94-
95-
def _get_string_response(response: "ChatCompletion") -> str:
96-
try:
97-
from openai.types.chat import ChatCompletion
98-
except ImportError:
99-
raise ImportError(
100-
"OpenAI is required to use the TLMChatCompletion class. Please install it with `pip install openai`."
101-
)
102-
103-
if not isinstance(response, ChatCompletion):
104-
raise TypeError("The response is not an OpenAI ChatCompletion object.")
105-
if response.choices[0].message.content is None:
106-
raise ValueError("The OpenAI ChatCompletion object does not contain a message content.")
107-
return str(response.choices[0].message.content)
95+
@staticmethod
96+
def _get_response_message(response: "ChatCompletion") -> "ChatCompletionMessage":
97+
return response.choices[0].message
98+
99+
def _validate_chat_completion(self, response: Any) -> None:
100+
# `response` should be a ChatCompletion, but isinstance checks wouldn't be reachable
101+
try:
102+
from openai.types.chat import ChatCompletion
103+
except ImportError as e:
104+
raise ImportError(
105+
f"OpenAI is required to use the {self.__class__.__name__} class. Please install it with `pip install openai`."
106+
) from e
107+
if not isinstance(response, ChatCompletion):
108+
raise TypeError("The response is not an OpenAI ChatCompletion object.")
109+
110+
message = self._get_response_message(response)
111+
if message.content is None and message.tool_calls is None:
112+
raise ValueError("The OpenAI ChatCompletion object does not contain a message content or tool calls.")

tests/test_chat.py

Lines changed: 186 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import TYPE_CHECKING, Any, cast
22

33
import pytest
4+
from openai.types.chat import ChatCompletionMessage
5+
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
46

57
from cleanlab_tlm.utils.chat import (
68
_form_prompt_chat_completions_api,
@@ -1366,13 +1368,13 @@ def test_form_response_string_chat_completions_api_empty_arguments() -> None:
13661368

13671369
def test_form_response_string_chat_completions_api_invalid_input() -> None:
13681370
"""Test form_response_string_chat_completions_api raises TypeError for invalid input."""
1369-
with pytest.raises(TypeError, match="Expected response to be a dict, got str"):
1371+
with pytest.raises(TypeError, match="Expected response to be a dict or ChatCompletionMessage object, got str"):
13701372
form_response_string_chat_completions_api("not a dict") # type: ignore[arg-type]
13711373

1372-
with pytest.raises(TypeError, match="Expected response to be a dict, got list"):
1374+
with pytest.raises(TypeError, match="Expected response to be a dict or ChatCompletionMessage object, got list"):
13731375
form_response_string_chat_completions_api([]) # type: ignore[arg-type]
13741376

1375-
with pytest.raises(TypeError, match="Expected response to be a dict, got NoneType"):
1377+
with pytest.raises(TypeError, match="Expected response to be a dict or ChatCompletionMessage object, got NoneType"):
13761378
form_response_string_chat_completions_api(None) # type: ignore[arg-type]
13771379

13781380

@@ -1406,3 +1408,184 @@ def test_form_response_string_chat_completions_api_malformed_tool_calls() -> Non
14061408
with pytest.warns(UserWarning, match="Error formatting tool_calls in response.*Returning content only"):
14071409
result = form_response_string_chat_completions_api(response)
14081410
assert result == "Let me check that."
1411+
1412+
1413+
############## ChatCompletionMessage tests ##############
1414+
1415+
1416+
def test_form_response_string_chat_completions_api_chatcompletion_message_just_content() -> None:
1417+
"""Test form_response_string_chat_completions_api with ChatCompletionMessage containing just content."""
1418+
1419+
content = "Hello, how can I help you today?"
1420+
message = ChatCompletionMessage(
1421+
role="assistant",
1422+
content=content,
1423+
)
1424+
result = form_response_string_chat_completions_api(message)
1425+
assert result == content
1426+
1427+
1428+
def test_form_response_string_chat_completions_api_chatcompletion_message_just_tool_calls() -> None:
1429+
"""Test form_response_string_chat_completions_api with ChatCompletionMessage containing just tool calls."""
1430+
message = ChatCompletionMessage(
1431+
role="assistant",
1432+
content=None,
1433+
tool_calls=[
1434+
ChatCompletionMessageToolCall(
1435+
id="call_123",
1436+
function=Function(
1437+
name="search_restaurants",
1438+
arguments='{"city": "Tokyo", "cuisine_type": "sushi", "max_price": 150, "dietary_restrictions": ["vegetarian", "gluten-free"], "open_now": true}',
1439+
),
1440+
type="function",
1441+
)
1442+
],
1443+
)
1444+
expected = (
1445+
"<tool_call>\n"
1446+
"{\n"
1447+
' "name": "search_restaurants",\n'
1448+
' "arguments": {\n'
1449+
' "city": "Tokyo",\n'
1450+
' "cuisine_type": "sushi",\n'
1451+
' "max_price": 150,\n'
1452+
' "dietary_restrictions": [\n'
1453+
' "vegetarian",\n'
1454+
' "gluten-free"\n'
1455+
" ],\n"
1456+
' "open_now": true\n'
1457+
" }\n"
1458+
"}\n"
1459+
"</tool_call>"
1460+
)
1461+
result = form_response_string_chat_completions_api(message)
1462+
assert result == expected
1463+
1464+
1465+
def test_form_response_string_chat_completions_api_chatcompletion_message_content_and_tool_calls() -> None:
1466+
"""Test form_response_string_chat_completions_api with ChatCompletionMessage containing both content and tool calls."""
1467+
message = ChatCompletionMessage(
1468+
role="assistant",
1469+
content="I'll check the weather for you.",
1470+
tool_calls=[
1471+
ChatCompletionMessageToolCall(
1472+
id="call_123",
1473+
function=Function(
1474+
name="get_weather",
1475+
arguments='{"location": "Paris"}',
1476+
),
1477+
type="function",
1478+
)
1479+
],
1480+
)
1481+
expected = (
1482+
"I'll check the weather for you.\n"
1483+
"<tool_call>\n"
1484+
"{\n"
1485+
' "name": "get_weather",\n'
1486+
' "arguments": {\n'
1487+
' "location": "Paris"\n'
1488+
" }\n"
1489+
"}\n"
1490+
"</tool_call>"
1491+
)
1492+
result = form_response_string_chat_completions_api(message)
1493+
assert result == expected
1494+
1495+
1496+
def test_form_response_string_chat_completions_api_chatcompletion_message_multiple_tool_calls() -> None:
1497+
"""Test form_response_string_chat_completions_api with ChatCompletionMessage containing multiple tool calls."""
1498+
message = ChatCompletionMessage(
1499+
role="assistant",
1500+
content="Let me check multiple things for you.",
1501+
tool_calls=[
1502+
ChatCompletionMessageToolCall(
1503+
id="call_123",
1504+
function=Function(
1505+
name="get_weather",
1506+
arguments='{"location": "Paris"}',
1507+
),
1508+
type="function",
1509+
),
1510+
ChatCompletionMessageToolCall(
1511+
id="call_456",
1512+
function=Function(
1513+
name="get_time",
1514+
arguments='{"timezone": "UTC"}',
1515+
),
1516+
type="function",
1517+
),
1518+
],
1519+
)
1520+
expected = (
1521+
"Let me check multiple things for you.\n"
1522+
"<tool_call>\n"
1523+
"{\n"
1524+
' "name": "get_weather",\n'
1525+
' "arguments": {\n'
1526+
' "location": "Paris"\n'
1527+
" }\n"
1528+
"}\n"
1529+
"</tool_call>\n"
1530+
"<tool_call>\n"
1531+
"{\n"
1532+
' "name": "get_time",\n'
1533+
' "arguments": {\n'
1534+
' "timezone": "UTC"\n'
1535+
" }\n"
1536+
"}\n"
1537+
"</tool_call>"
1538+
)
1539+
result = form_response_string_chat_completions_api(message)
1540+
assert result == expected
1541+
1542+
1543+
def test_form_response_string_chat_completions_api_chatcompletion_message_empty_content() -> None:
1544+
"""Test form_response_string_chat_completions_api with ChatCompletionMessage containing empty content."""
1545+
message = ChatCompletionMessage(
1546+
role="assistant",
1547+
content="",
1548+
)
1549+
expected = ""
1550+
result = form_response_string_chat_completions_api(message)
1551+
assert result == expected
1552+
1553+
1554+
def test_form_response_string_chat_completions_api_chatcompletion_message_empty_arguments() -> None:
1555+
"""Test form_response_string_chat_completions_api with ChatCompletionMessage containing empty arguments."""
1556+
message = ChatCompletionMessage(
1557+
role="assistant",
1558+
content="Running action",
1559+
tool_calls=[
1560+
ChatCompletionMessageToolCall(
1561+
id="call_123",
1562+
function=Function(
1563+
name="execute_action",
1564+
arguments="",
1565+
),
1566+
type="function",
1567+
)
1568+
],
1569+
)
1570+
expected = (
1571+
"Running action\n"
1572+
"<tool_call>\n"
1573+
"{\n"
1574+
' "name": "execute_action",\n'
1575+
' "arguments": {}\n'
1576+
"}\n"
1577+
"</tool_call>"
1578+
)
1579+
result = form_response_string_chat_completions_api(message)
1580+
assert result == expected
1581+
1582+
1583+
def test_form_response_string_chat_completions_api_chatcompletion_message_none_content() -> None:
1584+
"""Test form_response_string_chat_completions_api with ChatCompletionMessage containing None content."""
1585+
message = ChatCompletionMessage(
1586+
role="assistant",
1587+
content=None,
1588+
)
1589+
expected = ""
1590+
result = form_response_string_chat_completions_api(message)
1591+
assert result == expected

0 commit comments

Comments
 (0)