Skip to content

Commit e6469df

Browse files
authored
Add get_explanation() API (#116)
1 parent 4dc15e0 commit e6469df

File tree

10 files changed

+1099
-31
lines changed

10 files changed

+1099
-31
lines changed

CHANGELOG.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [1.1.30] - 2025-09-09
11+
12+
### Added
13+
14+
- Add `get_explanation()` API for TLM, TrustworthyRAG and TLMChatCompletions
15+
1016
## [1.1.29] - 2025-09-03
1117

1218
### Added
@@ -341,7 +347,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
341347
- Release of the Cleanlab TLM Python client.
342348

343349

344-
[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.29...HEAD
350+
[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.30...HEAD
351+
[1.1.30]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.29...v1.1.30
345352
[1.1.29]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.28...v1.1.29
346353
[1.1.28]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.27...v1.1.28
347354
[1.1.27]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.26...v1.1.27

src/cleanlab_tlm/__about__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# SPDX-License-Identifier: MIT
2-
__version__ = "1.1.29"
2+
__version__ = "1.1.30"

src/cleanlab_tlm/internal/api/api.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
tlm_base_url = f"{base_url}/v0/trustworthy_llm"
5656
tlm_rag_base_url = f"{base_url}/v1/rag_trustworthy_llm"
5757
tlm_openai_base_url = f"{base_url}/v1/openai_trustworthy_llm"
58+
tlm_explanation_base_url = f"{base_url}/v1/tlm_explanation"
5859

5960

6061
def _construct_headers(api_key: Optional[str], content_type: Optional[str] = "application/json") -> JSONDict:
@@ -577,3 +578,45 @@ async def tlm_chat_completions_score(
577578
await client_session.close()
578579

579580
return cast(JSONDict, res_json)
581+
582+
583+
@tlm_retry
async def tlm_get_explanation(
    api_key: str,
    prompt: str,
    formatted_tlm_result: dict[str, Any],
    options: Optional[JSONDict],
    rate_handler: TlmRateHandler,
    client_session: Optional[aiohttp.ClientSession] = None,
    batch_index: Optional[int] = None,
) -> JSONDict:
    """Fetch an explanation for a previously obtained TLM result from the explanation endpoint.

    Args:
        api_key: studio API key for auth.
        prompt: the original prompt the TLM result was produced for.
        formatted_tlm_result: normalized TLM result dict (response + trustworthiness score).
        options: TLM options to forward to the backend, or None for defaults.
        rate_handler: concurrency limiter the request is executed under.
        client_session: optional shared aiohttp session; one is created (and closed) locally when omitted.
        batch_index: position of this query within a batch, used only to annotate error messages.

    Returns:
        JSONDict: parsed JSON body of the backend response.
    """
    # Track whether this call owns the session so a caller-provided one is never closed here.
    owns_session = client_session is None
    if owns_session:
        client_session = aiohttp.ClientSession()

    try:
        async with rate_handler:
            res = await client_session.post(
                f"{tlm_explanation_base_url}/get_explanation",
                json={
                    _TLM_PROMPT_KEY: prompt,
                    _TLM_RESPONSE_KEY: formatted_tlm_result,
                    _TLM_OPTIONS_KEY: options or {},
                },
                headers=_construct_headers(api_key),
            )

            res_json = await res.json()

            # Surface auth / client / rate-limit / server errors as typed exceptions
            # (retried where appropriate by the @tlm_retry decorator).
            await handle_api_key_error_from_resp(res)
            await handle_http_bad_request_error_from_resp(res)
            handle_rate_limit_error_from_resp(res)
            await handle_tlm_client_error_from_resp(res, batch_index)
            await handle_tlm_api_error_from_resp(res, batch_index)

    finally:
        if owns_session:
            await client_session.close()

    return cast(JSONDict, res_json)

src/cleanlab_tlm/internal/validation.py

Lines changed: 161 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import warnings
55
from collections.abc import Sequence
6-
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
6+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
77

88
from cleanlab_tlm.errors import ValidationError
99
from cleanlab_tlm.internal.constants import (
@@ -26,8 +26,8 @@
2626
from cleanlab_tlm.internal.types import Task
2727

2828
if TYPE_CHECKING:
29-
from cleanlab_tlm.tlm import TLMOptions
30-
from cleanlab_tlm.utils.rag import Eval
29+
from cleanlab_tlm.tlm import TLMOptions, TLMResponse, TLMScore
30+
from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGResponse, TrustworthyRAGScore
3131

3232
SKIP_VALIDATE_TLM_OPTIONS: bool = os.environ.get("CLEANLAB_TLM_SKIP_VALIDATE_TLM_OPTIONS", "false").lower() == "true"
3333

@@ -366,6 +366,164 @@ def tlm_score_process_response_and_kwargs(
366366
return [dict(zip(combined_response_keys, values)) for values in combined_response_values_transposed]
367367

368368

369+
def tlm_explanation_format_tlm_result(
    tlm_result: Union[TLMResponse, Sequence[TLMResponse], TLMScore, Sequence[TLMScore]],
    response: Optional[Union[str, Sequence[str]]] = None,
) -> Union[dict[str, Any], list[dict[str, Any]]]:
    """Validate and normalize a TLM result (or batch of results) for the explanation API.

    Accepts output of either ``TLM.prompt()`` (result dicts already contain a
    ``"response"`` key) or ``TLM.get_trustworthiness_score()`` (the response text
    must be supplied separately via the ``response`` argument). The response may be
    provided in exactly one of those two places, never both and never neither.

    Args:
        tlm_result: a TLM result dict, or a sequence of them for batched calls.
            Every dict must contain a ``"trustworthiness_score"`` key.
        response: the response text(s) scored, required only when ``tlm_result``
            does not already carry a ``"response"`` key; must match the shape
            (single string vs. sequence of strings) of ``tlm_result``.

    Returns:
        A dict (or list of dicts) that always contains both ``"response"`` and
        ``"trustworthiness_score"`` keys.

    Raises:
        ValidationError: if shapes/keys are inconsistent, or the response is
            provided both ways (or neither way).
    """
    if isinstance(tlm_result, Sequence):
        if not all(isinstance(r, dict) for r in tlm_result):
            raise ValidationError("all items in the tlm_result sequence must be dicts")

        if not all("trustworthiness_score" in r for r in tlm_result):
            raise ValidationError("all items in the tlm_result sequence must contain a 'trustworthiness_score' key")

        # for .get_trustworthiness_score() cases, the response is passed in as a separate argument
        if not all("response" in r for r in tlm_result):
            if response is None:
                raise ValidationError(
                    "'response' is required if not provided in tlm_result, pass it in using the 'response' argument"
                )
            # a plain str is a Sequence, so reject it explicitly for the batched case
            if not isinstance(response, Sequence) or isinstance(response, str):
                raise ValidationError("response must be a sequence when tlm_result is a sequence")
            if len(response) != len(tlm_result):
                raise ValidationError("response and score sequences must have the same length")
            if not all(isinstance(r, str) for r in response):
                raise ValidationError("all items in the response sequence must be strings")

            # renamed loop variables: the original comprehension shadowed `tlm_result`
            return [{"response": resp, **result} for resp, result in zip(response, tlm_result)]

        # for .prompt() cases, the response is provided in the tlm_result dict
        if response is not None:
            raise ValidationError(
                "response should only be provided once, either using the 'response' argument or in 'tlm_result'"
            )

        return cast(list[dict[str, Any]], tlm_result)

    if not isinstance(tlm_result, dict):
        raise ValidationError("tlm_result must be a dict or a sequence of dicts")

    if "trustworthiness_score" not in tlm_result:
        # bug fix: the error message previously named a 'trustworthiness' key,
        # but the condition checks for 'trustworthiness_score'
        raise ValidationError("tlm_result must contain a 'trustworthiness_score' key")

    # the .get_trustworthiness_score() case
    if "response" not in tlm_result:
        if response is None:
            raise ValidationError(
                "'response' is required if not provided in tlm_result, pass it in using the 'response' argument"
            )
        if not isinstance(response, str):
            raise ValidationError("response must be a string when tlm_result is a dict")
        return {"response": response, **tlm_result}

    # the .prompt() case
    if response is not None:
        raise ValidationError(
            "response should only be provided once, either using the 'response' argument or in 'tlm_result'"
        )
    return cast(dict[str, Any], tlm_result)
425+
426+
427+
def tlm_explanation_format_trustworthy_rag_result(
    tlm_result: Union[
        TrustworthyRAGResponse,
        Sequence[TrustworthyRAGResponse],
        TrustworthyRAGScore,
        Sequence[TrustworthyRAGScore],
    ],
    response: Optional[Union[str, Sequence[str]]] = None,
) -> Union[dict[str, Any], list[dict[str, Any]]]:
    """Validate and flatten a TrustworthyRAG result (or batch) for the explanation API.

    Accepts output of either ``TrustworthyRAG.generate()`` (result dicts already
    contain a ``"response"`` key) or ``TrustworthyRAG.score()`` (the response text
    must be supplied separately via the ``response`` argument). The nested
    ``"trustworthiness"`` dict is flattened: its ``"score"`` becomes a top-level
    ``"trustworthiness_score"`` and its remaining keys are copied alongside it.

    Args:
        tlm_result: a TrustworthyRAG result dict, or a sequence of them; each must
            hold a ``"trustworthiness"`` dict with a non-None ``"score"``.
        response: the response text(s), required only when ``tlm_result`` lacks a
            ``"response"`` key; shape must match ``tlm_result``.

    Returns:
        A flat dict (or list of flat dicts) with ``"response"``,
        ``"trustworthiness_score"``, and any extra trustworthiness sub-keys.

    Raises:
        ValidationError: on shape/key inconsistencies or when the response is
            provided both ways (or neither way).
    """

    def _trust_ok(item: Any) -> bool:
        # a usable item carries a 'trustworthiness' dict with a non-None 'score'
        return (
            "trustworthiness" in item
            and isinstance(item["trustworthiness"], dict)
            and "score" in item["trustworthiness"]
            and item["trustworthiness"]["score"] is not None
        )

    def _flatten(resp_text: Any, item: Any) -> dict[str, Any]:
        # hoist the score to the top level, keep the other trustworthiness keys
        trust = dict(item["trustworthiness"])
        score = trust.pop("score")
        return {"response": resp_text, "trustworthiness_score": score, **trust}

    if isinstance(tlm_result, Sequence):
        if not all(isinstance(item, dict) for item in tlm_result):
            raise ValidationError("all items in the tlm_result sequence must be dicts")

        if not all(_trust_ok(item) for item in tlm_result):
            raise ValidationError(
                "all items in the tlm_result sequence must contain a 'trustworthiness' dict with a non-None 'score' key"
            )

        # for .score() cases, the response is passed in as a separate argument
        if not all("response" in item for item in tlm_result):
            if response is None:
                raise ValidationError(
                    "'response' is required if not provided in tlm_result, pass it in using the 'response' argument"
                )
            # a plain str is itself a Sequence, so rule it out explicitly
            if isinstance(response, str) or not isinstance(response, Sequence):
                raise ValidationError("response must be a sequence when tlm_result is a sequence")
            if len(response) != len(tlm_result):
                raise ValidationError("response and score sequences must have the same length")
            if not all(isinstance(item, str) for item in response):
                raise ValidationError("all items in the response sequence must be strings")

            return [_flatten(resp_text, item) for resp_text, item in zip(response, tlm_result)]

        # for .generate() cases, the response is provided in the tlm_result dict
        if response is not None:
            raise ValidationError(
                "response should only be provided once, either using the 'response' argument or in 'tlm_result'"
            )

        return [_flatten(item["response"], item) for item in tlm_result]

    if not isinstance(tlm_result, dict):
        raise ValidationError("tlm_result must be a dict or a sequence of dicts")

    if not _trust_ok(tlm_result):
        raise ValidationError("tlm_result must contain a 'trustworthiness' dict with a non-None 'score' key")

    # the .score() case
    if "response" not in tlm_result:
        if response is None:
            raise ValidationError(
                "'response' is required if not provided in tlm_result, pass it in using the 'response' argument"
            )
        if not isinstance(response, str):
            raise ValidationError("response must be a string when tlm_result is a dict")

        return _flatten(response, tlm_result)

    # the .generate() case
    if response is not None:
        raise ValidationError(
            "response should only be provided once, either using the 'response' argument or in 'tlm_result'"
        )

    return _flatten(tlm_result["response"], tlm_result)
525+
526+
369527
def validate_tlm_lite_score_options(score_options: Any) -> None:
370528
invalid_score_keys = set(score_options.keys()).intersection(INVALID_SCORE_OPTIONS)
371529
if invalid_score_keys:

0 commit comments

Comments
 (0)