@@ -114,8 +114,7 @@ def _get_binary_result(self, score: float) -> str:
else:
return EVALUATION_PASS_FAIL_MAPPING[False]

@override
async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]: # type: ignore[override]
async def _do_eval_with_flow(self, eval_input: Dict, flow) -> Dict[str, Union[float, str]]: # type: ignore[override]
"""Do a relevance evaluation.

:param eval_input: The input to the evaluator. Expected to contain
@@ -134,7 +133,7 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]: # type: ignore[override]
target=ErrorTarget.CONVERSATION,
)
# Call the prompty flow to get the evaluation result.
prompty_output_dict = await self._flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)
prompty_output_dict = await flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)

score = math.nan
if prompty_output_dict:
@@ -190,6 +189,20 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]: # type: ignore[override]
f"{self._result_key}_threshold": self._threshold,
}

@override
async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]: # type: ignore[override]
"""Do a relevance evaluation with default flow.

:param eval_input: The input to the evaluator. Expected to contain
whatever inputs are needed for the _flow method, including context
and other fields depending on the child class.
:type eval_input: Dict
:return: The evaluation result.
:rtype: Dict
"""

Reviewer comment (Contributor): Optional: Remove redundant new line after doc string.

return await self._do_eval_with_flow(eval_input, self._flow)

@staticmethod
def _get_built_in_tool_definition(tool_name: str):
"""Get the definition for the built-in tool."""
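The change above factors the prompty call out of `_do_eval` into `_do_eval_with_flow`, which receives the flow as an argument; the overridden `_do_eval` keeps its old contract by passing `self._flow`. A minimal sketch of the pattern, with an illustrative class name and a stubbed flow (not the SDK's actual base class):

```python
from typing import Any, Awaitable, Callable, Dict

class PromptyEvaluatorSketch:
    """Illustrative only: mirrors the flow-injection refactor in this diff."""

    _LLM_CALL_TIMEOUT = 600  # assumed value; the real constant lives in the base class

    def __init__(self, default_flow: Callable[..., Awaitable[Any]]):
        self._flow = default_flow

    async def _do_eval_with_flow(self, eval_input: Dict, flow) -> Dict:
        # Scoring logic lives here; callers decide which flow to run.
        return await flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)

    async def _do_eval(self, eval_input: Dict) -> Dict:
        # Old entry point, now a thin wrapper over the injected-flow variant.
        return await self._do_eval_with_flow(eval_input, self._flow)
```

Subclasses that hold several preloaded flows, as GroundednessEvaluator does below, can then choose one per call without mutating shared state.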
@@ -102,7 +102,9 @@ class GroundednessEvaluator(PromptyEvaluatorBase[Union[str, float]]):
@override
def __init__(self, model_config, *, threshold=3, credential=None, **kwargs):
current_dir = os.path.dirname(__file__)
prompty_path = os.path.join(current_dir, self._PROMPTY_FILE_NO_QUERY) # Default to no query
prompty_path = os.path.join(
current_dir, self._PROMPTY_FILE_NO_QUERY
) # Default to no query

self._higher_is_better = True
super().__init__(
@@ -118,6 +120,17 @@ def __init__(self, model_config, *, threshold=3, credential=None, **kwargs):
self.threshold = threshold
# Needs to be set because it's used in call method to re-validate prompt if `query` is provided

# To make sure they're not used directly
self._flow = None
self._prompty_file = None

self._flow_with_query = self._load_flow(
self._PROMPTY_FILE_WITH_QUERY, credential=credential
)
self._flow_no_query = self._load_flow(
self._PROMPTY_FILE_NO_QUERY, credential=credential
)
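One plausible reading of this change, not stated in the diff itself: the old design lazily swapped `self._flow` when a query appeared, so a shared evaluator instance mutated state between calls, while preloading `_flow_with_query` and `_flow_no_query` once in `__init__` lets each call pick a flow locally. A hedged sketch of what that enables (`race_demo` and its inputs are hypothetical):

```python
import asyncio

async def race_demo(evaluator, with_query_input, no_query_input):
    # Under a lazy-swap design, overlapping calls could race: call A swaps
    # self._flow to the with-query prompty while call B still expects the
    # no-query one. With both flows preloaded, each _do_eval selects its
    # flow from fixed attributes, so interleaving the two calls is safe.
    return await asyncio.gather(
        evaluator._do_eval(with_query_input),
        evaluator._do_eval(no_query_input),
    )
```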

@overload
def __call__(
self,
@@ -201,31 +214,52 @@ def __call__( # pylint: disable=docstring-missing-param
:rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]]
"""

if kwargs.get("query", None):
self._ensure_query_prompty_loaded()

return super().__call__(*args, **kwargs)

def _ensure_query_prompty_loaded(self):
"""Switch to the query prompty file if not already loaded."""
def _load_flow(self, prompty_filename: str, **kwargs) -> AsyncPrompty:
"""Load the Prompty flow from the specified file.

:param prompty_filename: The filename of the Prompty flow to load.
:type prompty_filename: str
:return: The loaded Prompty flow.
:rtype: AsyncPrompty
"""

current_dir = os.path.dirname(__file__)
prompty_path = os.path.join(current_dir, self._PROMPTY_FILE_WITH_QUERY)
prompty_path = os.path.join(current_dir, prompty_filename)

self._prompty_file = prompty_path
prompty_model_config = construct_prompty_model_config(
validate_model_config(self._model_config),
self._DEFAULT_OPEN_API_VERSION,
UserAgentSingleton().value,
)
self._flow = AsyncPrompty.load(source=self._prompty_file, model=prompty_model_config)
flow = AsyncPrompty.load(
source=prompty_path,
model=prompty_model_config,
is_reasoning_model=self._is_reasoning_model,
**kwargs,
)

return flow

def _has_context(self, eval_input: dict) -> bool:
"""
Return True if eval_input contains a non-empty 'context' field.
Treats None, empty strings, empty lists, and lists of empty strings as no context.
"""
context = eval_input.get("context", None)
return self._validate_context(context)

def _validate_context(self, context) -> bool:
"""
Validate if the provided context is non-empty and meaningful.
Treats None, empty strings, empty lists, and lists of empty strings as no context.

:param context: The context to validate
:type context: Union[str, List, None]
:return: True if context is valid and non-empty, False otherwise
:rtype: bool
"""
if not context:
return False
if context == "<>": # Special marker for no context
@@ -239,12 +273,16 @@ def _has_context(self, eval_input: dict) -> bool:
@override
async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:
if eval_input.get("query", None) is None:
return await super()._do_eval(eval_input)
return await super()._do_eval_with_flow(eval_input, self._flow_no_query)

contains_context = self._has_context(eval_input)

simplified_query = simplify_messages(eval_input["query"], drop_tool_calls=contains_context)
simplified_response = simplify_messages(eval_input["response"], drop_tool_calls=False)
simplified_query = simplify_messages(
eval_input["query"], drop_tool_calls=contains_context
)
simplified_response = simplify_messages(
eval_input["response"], drop_tool_calls=False
)

# Build simplified input
simplified_eval_input = {
Expand All @@ -254,7 +292,9 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]:
}

# Replace and call the parent method
return await super()._do_eval(simplified_eval_input)
return await super()._do_eval_with_flow(
simplified_eval_input, self._flow_with_query
)
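At the public API this routing is invisible: the caller simply includes or omits `query`. A usage sketch (`model_config` and the sample strings are assumptions; the output keys come from the base class):

```python
from azure.ai.evaluation import GroundednessEvaluator

groundedness = GroundednessEvaluator(model_config)  # model_config defined elsewhere

# No query: evaluated with _flow_no_query.
no_query_result = groundedness(
    response="Paris is the capital of France.",
    context="France's capital city is Paris.",
)

# With query: evaluated with _flow_with_query, after simplify_messages has
# dropped tool calls from the query whenever usable context is present.
with_query_result = groundedness(
    query="What is the capital of France?",
    response="Paris is the capital of France.",
    context="France's capital city is Paris.",
)
```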

async def _real_call(self, **kwargs):
"""The asynchronous call where real end-to-end evaluation logic is performed.
Expand All @@ -278,16 +318,21 @@ async def _real_call(self, **kwargs):
else:
raise ex

def _is_single_entry(self, value):
"""Determine if the input value represents a single entry, unsure is returned as False."""
if isinstance(value, str):
return True
if isinstance(value, list) and len(value) == 1:
return True
return False
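
As the docstring says, the helper errs toward False. A few illustrative calls (the `evaluator` instance and inputs are hypothetical):

```python
evaluator._is_single_entry("What is AKS?")            # True: a bare string
evaluator._is_single_entry(["What is AKS?"])          # True: one-element list
evaluator._is_single_entry([{"role": "user"}, {}])    # False: multiple entries
evaluator._is_single_entry(None)                      # False: unsure -> False
```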

def _convert_kwargs_to_eval_input(self, **kwargs):
if kwargs.get("context") or kwargs.get("conversation"):
return super()._convert_kwargs_to_eval_input(**kwargs)
query = kwargs.get("query")
response = kwargs.get("response")
tool_definitions = kwargs.get("tool_definitions")

if query and self._prompty_file != self._PROMPTY_FILE_WITH_QUERY:
self._ensure_query_prompty_loaded()

if (not query) or (not response): # or not tool_definitions:
msg = f"{type(self).__name__}: Either 'conversation' or individual inputs must be provided. For Agent groundedness, 'query' and 'response' are required."
raise EvaluationException(
@@ -298,14 +343,39 @@ def _convert_kwargs_to_eval_input(self, **kwargs):
)
context = self._get_context_from_agent_response(response, tool_definitions)

filtered_response = self._filter_file_search_results(response)
return super()._convert_kwargs_to_eval_input(response=filtered_response, context=context, query=query)
if (
not self._validate_context(context)
and self._is_single_entry(response)
and self._is_single_entry(query)
):
msg = f"{type(self).__name__}: No valid context provided or could be extracted from the query or response."
raise EvaluationException(
message=msg,
blame=ErrorBlame.USER_ERROR,
category=ErrorCategory.NOT_APPLICABLE,
target=ErrorTarget.GROUNDEDNESS_EVALUATOR,
)

filtered_response = (
self._filter_file_search_results(response)
if self._validate_context(context)
else response
)
return super()._convert_kwargs_to_eval_input(
response=filtered_response, context=context, query=query
)
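
The new guard makes the no-context case explicit: a single-turn pair with nothing to ground against is flagged NOT_APPLICABLE rather than scored blindly. A sketch of the condition that triggers it (inputs are illustrative; whether the exception reaches the caller or is absorbed by `_real_call` depends on handling only partially shown in this diff):

```python
# Both inputs are single entries and no context can be extracted from tool
# results, so _convert_kwargs_to_eval_input raises EvaluationException with
# category NOT_APPLICABLE.
groundedness(
    query="What is the capital of France?",
    response="Paris.",  # no tool calls, hence no extractable context
)
```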

def _filter_file_search_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def _filter_file_search_results(
self, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Filter out file_search tool results from the messages."""
file_search_ids = self._get_file_search_tool_call_ids(messages)
return [
msg for msg in messages if not (msg.get("role") == "tool" and msg.get("tool_call_id") in file_search_ids)
msg
for msg in messages
if not (
msg.get("role") == "tool" and msg.get("tool_call_id") in file_search_ids
)
]
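
Concretely, the filter drops only the tool messages that answer a `file_search` call, leaving the assistant turns intact. A sketch, assuming the message shape that `_parse_tools_from_response` understands (the shape shown here is an assumption):

```python
messages = [
    {"role": "assistant", "content": [
        {"type": "tool_call", "tool_call_id": "call_1", "name": "file_search"},
    ]},
    {"role": "tool", "tool_call_id": "call_1", "content": "raw retrieved chunks"},
    {"role": "assistant", "content": "Answer grounded in the retrieved chunks."},
]

filtered = evaluator._filter_file_search_results(messages)
# The role == "tool" message for call_1 is gone; the raw retrieval payload
# stays out of the response because it already feeds the evaluation as context.
```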

def _get_context_from_agent_response(self, response, tool_definitions):
Expand All @@ -322,7 +392,10 @@ def _get_context_from_agent_response(self, response, tool_definitions):

context_lines = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict) or tool_call.get("type") != "tool_call":
if (
not isinstance(tool_call, dict)
or tool_call.get("type") != "tool_call"
):
continue

tool_name = tool_call.get("name")
@@ -351,4 +424,8 @@ def _get_file_search_tool_call_ids(self, query_or_response):
def _get_file_search_tool_call_ids(self, query_or_response):
"""Return a list of tool_call_ids for file search tool calls."""
tool_calls = self._parse_tools_from_response(query_or_response)
return [tc.get("tool_call_id") for tc in tool_calls if tc.get("name") == "file_search"]
return [
tc.get("tool_call_id")
for tc in tool_calls
if tc.get("name") == "file_search"
]