33 changes: 31 additions & 2 deletions src/cleanlab_tlm/utils/chat_completions.py
@@ -8,6 +8,8 @@
import asyncio
from typing import TYPE_CHECKING, Any, Optional, cast

import numpy as np

from cleanlab_tlm.internal.api.api import tlm_chat_completions_score
from cleanlab_tlm.internal.base import BaseTLM
from cleanlab_tlm.internal.constants import (
@@ -16,7 +18,10 @@
)
from cleanlab_tlm.internal.types import TLMQualityPreset
from cleanlab_tlm.tlm import TLM, TLMOptions, TLMScore
from cleanlab_tlm.utils.chat import _form_prompt_chat_completions_api, form_response_string_chat_completions
from cleanlab_tlm.utils.chat import (
_form_prompt_chat_completions_api,
form_response_string_chat_completions,
)

if TYPE_CHECKING:
from openai.types.chat import ChatCompletion, ChatCompletionMessage
@@ -116,7 +121,21 @@ def score(
prompt_text = _form_prompt_chat_completions_api(messages, tools)
response_text = form_response_string_chat_completions(response=response)

return cast(TLMScore, self._tlm.get_trustworthiness_score(prompt_text, response_text))
scoring_kwargs = {}

# add perplexity to tlm.get_trustworthiness_score kwargs if it exists
try:
perplexity = _extract_perplexity(response)
except Exception:
perplexity = None

if perplexity is not None:
scoring_kwargs["perplexity"] = perplexity

return cast(
TLMScore,
self._tlm.get_trustworthiness_score(prompt_text, response_text, **scoring_kwargs),
)

@staticmethod
def _get_response_message(response: "ChatCompletion") -> "ChatCompletionMessage":
@@ -136,3 +155,13 @@ def _validate_chat_completion(self, response: Any) -> None:
message = self._get_response_message(response)
if message.content is None and message.tool_calls is None:
raise ValueError("The OpenAI ChatCompletion object does not contain a message content or tool calls.")


def _extract_perplexity(response: "ChatCompletion") -> Optional[float]:
response_logprobs = response.choices[0].logprobs
Member:
This won't ever cause an error for any OpenAI LLM? I'd imagine it's safer to check whether the logprobs key exists.

Collaborator (Author):
Afaik this always exists. I've encapsulated this entire function in a try/except block though, so even if something unexpected happens while extracting logprobs, the entire call will not fail; it will just not include perplexity.
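For illustration only, the more defensive lookup suggested above might look roughly like the sketch below. This is not part of the PR; the getattr guards and the helper name are assumptions, and the merged code instead relies on the try/except wrapper around _extract_perplexity in score().

import numpy as np

def _extract_perplexity_defensive(response):
    # Hypothetical variant (not in this PR): guard the attribute lookups
    # instead of relying on the caller's try/except.
    choice = response.choices[0]
    response_logprobs = getattr(choice, "logprobs", None)
    if response_logprobs is None or getattr(response_logprobs, "content", None) is None:
        return None  # skip perplexity rather than raising

    logprobs_list = [token.logprob for token in response_logprobs.content]
    # Perplexity is exp of the mean token log-probability.
    return float(np.exp(np.mean(logprobs_list)))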

if response_logprobs is None or response_logprobs.content is None:
return None

logprobs_list = [completion.logprob for completion in response_logprobs.content]
perplexity = np.exp(np.mean(logprobs_list))
return float(perplexity)
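As context for the change above: perplexity can only be extracted when the ChatCompletion actually carries token logprobs. A minimal usage sketch follows, assuming an OpenAI client with an API key configured and logprobs requested at generation time; this end-to-end flow is illustrative and not part of this PR.

from openai import OpenAI

from cleanlab_tlm.utils.chat_completions import TLMChatCompletion

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
openai_kwargs = {
    "model": "gpt-4.1-mini",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
}

# Request token logprobs so _extract_perplexity has something to work with.
response = client.chat.completions.create(logprobs=True, **openai_kwargs)

tlm_chat = TLMChatCompletion(options={"log": ["perplexity"]})
score = tlm_chat.score(response=response, **openai_kwargs)
print(score["trustworthiness_score"], score["log"]["perplexity"])

If logprobs are absent, or anything goes wrong while reading them, score() simply omits perplexity from the scoring kwargs rather than failing.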
93 changes: 89 additions & 4 deletions tests/test_chat_completions.py
@@ -3,7 +3,8 @@

import pytest
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
from openai.types.completion_usage import (
CompletionTokensDetails,
CompletionUsage,
@@ -13,7 +14,7 @@
from cleanlab_tlm.internal.constants import _TLM_DEFAULT_MODEL
from cleanlab_tlm.internal.types import TLMQualityPreset
from cleanlab_tlm.tlm import TLMScore
from cleanlab_tlm.utils.chat_completions import TLMChatCompletion
from cleanlab_tlm.utils.chat_completions import TLMChatCompletion, _extract_perplexity
from tests.conftest import make_text_unique
from tests.constants import TEST_PROMPT, TEST_RESPONSE
from tests.openai_compat import ChatCompletionMessageToolCall, Function
@@ -85,6 +86,8 @@ def test_tlm_chat_completion_score_with_options() -> None:

assert score is not None
assert is_trustworthiness_score_json_format(score)
assert score["log"]["explanation"] is not None
assert score["log"]["perplexity"] is None


def test_tlm_chat_completion_score_with_tools() -> None:
@@ -132,6 +135,82 @@ def test_tlm_chat_completion_score_with_tools() -> None:
assert is_trustworthiness_score_json_format(score)


def test_tlm_chat_completion_score_with_perplexity() -> None:
tlm_chat = TLMChatCompletion(options={"log": ["perplexity"]})
openai_kwargs = {
"model": "gpt-4.1-mini",
"messages": [{"role": "user", "content": test_prompt}],
}
response = ChatCompletion(
id="test",
choices=[
Choice(
index=0,
message=ChatCompletionMessage(role="assistant", content=test_response),
finish_reason="stop",
logprobs=ChoiceLogprobs(
content=[
ChatCompletionTokenLogprob(
token="The", # noqa: S106
bytes=[84, 104, 101],
logprob=0.0,
top_logprobs=[],
),
ChatCompletionTokenLogprob(
token=" capital", # noqa: S106
bytes=[32, 99, 97, 112, 105, 116, 97, 108],
logprob=0.0,
top_logprobs=[],
),
ChatCompletionTokenLogprob(
token=" of", # noqa: S106
bytes=[32, 111, 102],
logprob=0.0,
top_logprobs=[],
),
ChatCompletionTokenLogprob(
token=" France", # noqa: S106
bytes=[32, 70, 114, 97, 110, 99, 101],
logprob=0.0,
top_logprobs=[],
),
ChatCompletionTokenLogprob(
token=" is", # noqa: S106
bytes=[32, 105, 115],
logprob=0.0,
top_logprobs=[],
),
ChatCompletionTokenLogprob(
token=" Paris", # noqa: S106
bytes=[32, 80, 97, 114, 105, 115],
logprob=0.0,
top_logprobs=[],
),
ChatCompletionTokenLogprob(
token=".", # noqa: S106
bytes=[46],
logprob=-1.9361264946837764e-07,
top_logprobs=[],
),
],
refusal=None,
),
)
],
created=1234567890,
model="test-model",
object="chat.completion",
)

manually_calculated_perplexity = _extract_perplexity(response)

score = tlm_chat.score(response=response, **openai_kwargs)
returned_perplexity = score["log"]["perplexity"]

assert returned_perplexity is not None
assert manually_calculated_perplexity == returned_perplexity


def test_tlm_chat_completion_score_with_structured_output() -> None:
tlm_chat = TLMChatCompletion()
openai_kwargs = {
@@ -248,8 +327,14 @@ def test_tlm_chat_completion_score_missing_messages() -> None:
@pytest.mark.parametrize(
"arguments, condition", # noqa: PT006
[
(json.dumps({"query": "Capital of Germany"}), lambda score: score["trustworthiness_score"] < 0.5), # noqa: PLR2004
(json.dumps({"query": "Capital of France"}), lambda score: score["trustworthiness_score"] >= 0.8), # noqa: PLR2004
(
json.dumps({"query": "Capital of Germany"}),
lambda score: score["trustworthiness_score"] < 0.5, # noqa: PLR2004
),
(
json.dumps({"query": "Capital of France"}),
lambda score: score["trustworthiness_score"] >= 0.8, # noqa: PLR2004
),
],
ids=["bad_arguments", "good_arguments"],
)