1111from __future__ import annotations
1212
1313import os
14- from typing import TYPE_CHECKING , Any , Optional
14+ from typing import TYPE_CHECKING , Any , Optional , Union , cast
1515
1616import requests
1717
1818from cleanlab_tlm .internal .base import BaseTLM
1919from cleanlab_tlm .internal .constants import _VALID_TLM_QUALITY_PRESETS_CHAT_COMPLETIONS
20+ from cleanlab_tlm .tlm import TLMScore
21+ from cleanlab_tlm .utils .per_field_score_utils import _get_untrustworthy_fields
2022from cleanlab_tlm .utils .vpc .tlm import VPCTLMOptions
2123
2224if TYPE_CHECKING :
2325 from openai .types .chat import ChatCompletion
2426
25- from cleanlab_tlm .internal .types import JSONDict
26-
2727
2828class TLMChatCompletion (BaseTLM ):
2929 """
@@ -74,7 +74,7 @@ def score(
7474 * ,
7575 response : ChatCompletion ,
7676 ** openai_kwargs : Any ,
77- ) -> JSONDict :
77+ ) -> TLMScore :
7878 """Score the trustworthiness of an OpenAI ChatCompletion response.
7979
8080 Args:
@@ -84,12 +84,22 @@ def score(
8484 Returns:
8585 TLMScore: A dict containing the trustworthiness score and optional logs
8686 """
87+ try :
88+ from openai .lib ._parsing ._completions import type_to_response_format_param
89+ except ImportError as e :
90+ raise ImportError (
91+ f"OpenAI is required to use the { self .__class__ .__name__ } class. Please install it with `pip install openai`."
92+ ) from e
93+
8794 if (base_url := os .environ .get ("BASE_URL" )) is None :
8895 raise ValueError ("BASE_URL is not set. Please set it in the environment variables." )
8996
9097 # replace the model used for scoring with the specified model in options
9198 openai_kwargs ["model" ] = self ._options ["model" ]
9299
100+ if "response_format" in openai_kwargs :
101+ openai_kwargs ["response_format" ] = type_to_response_format_param (openai_kwargs ["response_format" ])
102+
93103 res = requests .post (
94104 f"{ base_url } /chat/score" ,
95105 json = {
@@ -112,7 +122,49 @@ def score(
112122
113123 res_json = res .json ()
114124 tlm_result = {"trustworthiness_score" : res_json ["tlm_metadata" ]["trustworthiness_score" ]}
115- if explanation := res_json ["tlm_metadata" ].get ("log" , {}).get ("explanation" ):
116- tlm_result ["log" ] = {"explanation" : explanation }
117125
118- return tlm_result
126+ if self ._return_log :
127+ log = {}
128+
129+ log_options = cast (list [str ], self ._options .get ("log" , []))
130+ if "explanation" in log_options :
131+ explanation = res_json ["tlm_metadata" ].get ("log" , {}).get ("explanation" )
132+ log ["explanation" ] = explanation
133+
134+ if "per_field_score" in log_options :
135+ per_field_score = res_json ["tlm_metadata" ].get ("log" , {}).get ("per_field_score" )
136+ log ["per_field_score" ] = per_field_score
137+
138+ tlm_result ["log" ] = log
139+
140+ return cast (TLMScore , tlm_result )
141+
def get_untrustworthy_fields(
    self,
    *,
    response: Optional[ChatCompletion] = None,
    tlm_result: Union[TLMScore, ChatCompletion],
    threshold: float = 0.8,
    display_details: bool = True,
) -> list[str]:
    """Return the structured-output fields that TLM considers untrustworthy.

    Intended for responses that are valid JSON objects, i.e. completions
    produced with a `response_format` describing the output schema. When
    `display_details` is True, detailed information about the flagged
    fields is printed as well.

    This method is a thin wrapper: all arguments are forwarded unchanged to
    the shared `_get_untrustworthy_fields` helper, together with this
    class's name.

    Args:
        response (ChatCompletion, optional): The OpenAI ChatCompletion
            response object to evaluate. Defaults to None.
            NOTE(review): presumably it may be omitted when `tlm_result`
            is itself a ChatCompletion -- confirm against the helper.
        tlm_result (TLMScore | ChatCompletion): The result object from a
            previous TLM call.
        threshold (float): The threshold for considering a field
            untrustworthy.
        display_details (bool): Whether to display detailed information
            about the untrustworthy fields.

    Returns:
        list[str]: The fields of the response that are considered
        untrustworthy by TLM.
    """
    # The helper receives the caller's class name alongside the arguments.
    caller_name = self.__class__.__name__

    untrustworthy_fields = _get_untrustworthy_fields(
        response=response,
        tlm_result=tlm_result,
        threshold=threshold,
        display_details=display_details,
        class_name=caller_name,
    )
    return untrustworthy_fields
0 commit comments