Skip to content

Commit 7810f29

Browse files
authored
[Integration] Add openai integration (#112)
1 parent b5faf3f commit 7810f29

File tree

5 files changed

+517
-6
lines changed

5 files changed

+517
-6
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ extra-dependencies = [
4646
"smolagents; python_version >= '3.10'",
4747
"thefuzz",
4848
"langchain-core",
49-
"openai"
49+
"openai",
50+
"openai-agents",
5051
]
5152

5253
[[tool.hatch.envs.types.matrix]]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Methods to integrate with AI Agents built using the OpenAI Agents SDK."""
2+
3+
from cleanlab_codex.experimental.openai_agents.cleanlab_hook import CleanlabHook
4+
5+
__all__ = ["CleanlabHook"]
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
"""Methods to integrate with AI Agents built using the OpenAI Agents SDK."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any, Optional
6+
7+
if TYPE_CHECKING:
8+
from agents.items import ModelResponse, TResponseInputItem
9+
from codex.types.project_validate_response import ProjectValidateResponse
10+
from openai.types.chat import ChatCompletionMessageParam
11+
12+
from cleanlab_codex import Project
13+
14+
import secrets
15+
16+
from agents import FunctionTool
17+
from agents.lifecycle import RunHooks
18+
from agents.models.chatcmpl_converter import Converter
19+
from agents.run_context import RunContextWrapper, TContext
20+
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
21+
22+
from cleanlab_codex.experimental.openai_agents.utils import (
23+
form_response_string_responses_api_from_response,
24+
get_tool_result_as_text,
25+
)
26+
27+
28+
def _cleanlab_string_to_response_output_message(text: str, message_id: str | None = None) -> ResponseOutputMessage:
    """Wrap plain text in a completed OpenAI Responses-API assistant message.

    Args:
        text: The message text to wrap.
        message_id: Optional explicit message id; when omitted, a random
            cleanlab-prefixed id is generated.

    Returns:
        A ``ResponseOutputMessage`` containing a single ``output_text`` part.
    """
    if message_id is None:
        # TODO: Add support for marking cleanlab responses beyond adding cleanlab to ID
        message_id = f"msg_cleanlab{secrets.token_hex(16)}"
    text_part = ResponseOutputText(text=text, type="output_text", annotations=[])
    return ResponseOutputMessage(
        id=message_id,
        content=[text_part],
        role="assistant",
        type="message",
        status="completed",
    )
39+
40+
41+
def _rewrite_response_content_inplace(response: ModelResponse, new_content: str) -> None:
    """Replace ``response.output`` with a single cleanlab-authored message.

    Builds the replacement message first, then clears the existing output
    (which also drops any pending tool calls) and appends the new message.
    Mutates ``response`` in place; returns nothing.
    """
    replacement_message = _cleanlab_string_to_response_output_message(new_content)
    response.output.clear()
    response.output.append(replacement_message)
46+
47+
48+
class CleanlabHook(RunHooks[TContext]):
    """V3 hook with comprehensive text extraction for all OpenAI response types.

    Captures the conversation history in ``on_llm_start``, then in
    ``on_llm_end`` validates each model response with a Cleanlab project and
    rewrites the response in place (expert answer or fallback text) when
    validation says it should be replaced.
    """

    def __init__(
        self,
        *,
        cleanlab_project: Project,
        fallback_response: str = "Sorry I am unsure. You can try rephrasing your request.",
        skip_validating_tool_calls: bool = False,
        context_retrieval_tools: list[str] | None = None,
        validate_every_response: bool = True,
    ) -> None:
        """Initialize Cleanlab response rewriter hook V3.

        Args:
            cleanlab_project: Cleanlab project whose ``validate`` API scores responses.
            fallback_response: Text substituted when validation triggers a guardrail.
            skip_validating_tool_calls: When True, responses containing tool calls
                are not validated (they receive fixed perfect eval scores instead).
            context_retrieval_tools: Names of tools whose results are treated as
                retrieved context for validation.
            validate_every_response: Stored flag.
                NOTE(review): not read anywhere in this class — confirm intended use.
        """
        super().__init__()
        self.cleanlab_project = cleanlab_project
        self.fallback_response = fallback_response
        self.skip_validating_tool_calls = skip_validating_tool_calls
        self.context_retrieval_tools = context_retrieval_tools or []
        self.validate_every_response = validate_every_response

        # Populated by on_llm_start with actual conversation history
        self._conversation_history: list[ChatCompletionMessageParam] = []
        self._system_prompt: Optional[str] = None
        # Cache of the extracted response text; reset at the end of on_llm_end.
        self._latest_response_text: Optional[str] = None

    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Any,  # noqa: ARG002
        system_prompt: str | None,
        input_items: list[TResponseInputItem],
    ) -> None:
        """Capture the conversation history being sent to the LLM and set up context for storing results."""
        raw_messages = Converter.items_to_messages(input_items)
        self._conversation_history = raw_messages
        self._system_prompt = system_prompt
        # Ensure there is an object to hang validation results on in on_llm_end.
        if context.context is None:
            context.context = type("CleanlabContext", (), {})()

    async def on_llm_end(self, context: RunContextWrapper[TContext], agent: Any, response: ModelResponse) -> None:
        """Intercept and potentially rewrite model response before tool execution."""
        # Perform Cleanlab validation with actual conversation history
        validation_result = await self._cleanlab_validate(response, context, agent)

        # Rewrite response if validation indicates we should
        await self.cleanlab_get_final_response(response, validation_result)

        # Store validation result in context
        context.context.latest_cleanlab_validation_result = validation_result  # type: ignore[attr-defined]
        context.context.latest_initial_response_text = self._get_latest_response_text(response)  # type: ignore[attr-defined]

        # Clear state vars
        self._latest_response_text = None

    def _should_validate_response(self, response: ModelResponse) -> bool:
        """Determine if this response should be validated with Cleanlab."""
        if self.skip_validating_tool_calls and self._response_has_tool_calls(response):
            return False
        # Only validate responses that actually contain extractable text.
        return self._response_has_content(response)

    def _response_has_tool_calls(self, response: ModelResponse) -> bool:
        """Check if model response contains tool calls."""
        for item in response.output:
            # Check for tool calls in various formats
            if hasattr(item, "tool_calls") and item.tool_calls:
                return True
            if hasattr(item, "type") and "function_call" in str(item.type).lower():
                return True
            if "FunctionToolCall" in type(item).__name__:
                return True
        return False

    def _response_has_content(self, response: ModelResponse) -> bool:
        """Check if response has content that can be validated."""
        return bool(self._get_latest_response_text(response).strip())

    def _get_latest_response_text(self, response: ModelResponse) -> str:
        """Extract text content from model response (cached per on_llm_end cycle)."""
        if self._latest_response_text is None:
            self._latest_response_text = form_response_string_responses_api_from_response(response)
        return self._latest_response_text

    def _get_latest_user_query(self) -> str:
        """Extract the most recent user query from the actual conversation history.

        Returns the empty string when no user message with string content is found.
        """
        for item in reversed(self._conversation_history):
            if isinstance(item, dict) and item.get("role") == "user":
                content = item.get("content", "")
                # Only plain-string content is returned; structured content is skipped.
                if isinstance(content, str):
                    return content
        return ""

    def _get_context_as_string(self, messages: list[ChatCompletionMessageParam]) -> str:
        """Extract context from tool results in the agent's messages.

        Concatenates the text of each configured context-retrieval tool's result
        into a single labeled string; empty when no tool results are present.
        """
        context_parts = ""
        for tool_name in self.context_retrieval_tools:
            tool_result_text = get_tool_result_as_text(messages, tool_name)
            if tool_result_text:
                context_parts += f"Context from tool {tool_name}:\n{tool_result_text}\n\n"

        return context_parts

    async def _cleanlab_validate(
        self, response: ModelResponse, context: RunContextWrapper[TContext], agent: Any
    ) -> ProjectValidateResponse:
        """Validate the model response using Cleanlab with actual conversation history.

        Args:
            response: The model response to validate.
            context: Run context; its ``session_id`` (if any) is logged as thread id.
            agent: The agent whose tools are converted to OpenAI tool schemas.

        Returns:
            The raw Cleanlab validation response.
        """
        # Step 1 - Convert hook items to Cleanlab format
        tools_dict = (
            [Converter.tool_to_openai(tool) for tool in agent.tools if isinstance(tool, FunctionTool)]
            if agent.tools
            else None
        )
        cleanlab_messages = list(self._conversation_history)
        # Prepend the system prompt so Cleanlab sees the full instruction context.
        if self._system_prompt:
            cleanlab_messages.insert(
                0,
                {
                    "content": self._system_prompt,
                    "role": "system",
                },
            )

        session_id = getattr(context, "session_id", None) or "unknown"

        # Step 2 - Get additional validation fields
        validate_fields = self.cleanlab_get_validate_fields(cleanlab_messages)
        eval_scores = None
        # When this response is exempt from validation, submit perfect scores
        # so it is still logged without triggering remediation.
        if not self._should_validate_response(response):
            eval_scores = {
                "trustworthiness": 1.0,
                "response_helpfulness": 1.0,
                "context_sufficiency": 1.0,
                "query_ease": 1.0,
                "response_groundedness": 1.0,
            }

        # Step 3 - Run validation
        return self.cleanlab_project.validate(
            response=self._get_latest_response_text(response),
            messages=cleanlab_messages,
            tools=tools_dict,
            metadata={
                "thread_id": session_id,
                "agent_name": getattr(agent, "name", "unknown"),
            },
            eval_scores=eval_scores,
            **validate_fields,
        )

    def cleanlab_get_validate_fields(self, messages: list[ChatCompletionMessageParam]) -> dict[str, Any]:
        """
        Extract query and context fields from conversation messages for cleanlab validation.

        Processes conversation messages to extract the user query and any
        contextual information from specified tool results.

        Args:
            messages: Conversation messages to process

        Returns:
            Dictionary with 'query' and 'context' fields for validation
        """
        user_message = self._get_latest_user_query()
        context = self._get_context_as_string(messages)
        return {
            "query": user_message,
            "context": context,
        }

    async def cleanlab_get_final_response(
        self, response: ModelResponse, validation_result: ProjectValidateResponse
    ) -> None:
        """
        Apply the final response content based on cleanlab validation results.

        Checks validation results for an expert answer or a guardrail trigger
        and, if either applies, rewrites ``response`` in place; otherwise the
        original response is left untouched.

        Args:
            response: Model response to (possibly) rewrite in place
            validation_result: Validation results from cleanlab

        Returns:
            None. The expert answer takes precedence over the guardrail fallback.
        """
        replacement_text = None
        if validation_result.expert_answer:
            replacement_text = validation_result.expert_answer
        elif validation_result.should_guardrail:
            replacement_text = self.fallback_response

        if replacement_text:
            _rewrite_response_content_inplace(response, replacement_text)

0 commit comments

Comments
 (0)