diff --git a/e2e-chatbot-app/app.py b/e2e-chatbot-app/app.py index e483eca4..baadfb0e 100644 --- a/e2e-chatbot-app/app.py +++ b/e2e-chatbot-app/app.py @@ -2,13 +2,13 @@ import os import streamlit as st from model_serving_utils import ( - endpoint_supports_feedback, query_endpoint, query_endpoint_stream, _get_endpoint_task_type, + _extract_trace_id, ) from collections import OrderedDict -from messages import UserMessage, AssistantResponse, render_message +from messages import UserMessage, AssistantResponse, render_message, render_assistant_message_feedback logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -21,7 +21,6 @@ "'serving_endpoint' with CAN_QUERY permissions, as described in " "https://docs.databricks.com/aws/en/generative-ai/agent-framework/chat-app#deploy-the-databricks-app") -ENDPOINT_SUPPORTS_FEEDBACK = endpoint_supports_feedback(SERVING_ENDPOINT) def reduce_chat_agent_chunks(chunks): """ @@ -120,12 +119,13 @@ def query_chat_completions_endpoint_and_render(input_messages): accumulated_content = "" request_id = None + trace_id = None try: for chunk in query_endpoint_stream( endpoint_name=SERVING_ENDPOINT, messages=input_messages, - return_traces=ENDPOINT_SUPPORTS_FEEDBACK + return_traces=True ): if "choices" in chunk and chunk["choices"]: delta = chunk["choices"][0].get("delta", {}) @@ -138,23 +138,32 @@ def query_chat_completions_endpoint_and_render(input_messages): req_id = chunk["databricks_output"].get("databricks_request_id") if req_id: request_id = req_id + chunk_trace_id = _extract_trace_id(chunk) + if chunk_trace_id: + trace_id = chunk_trace_id + if trace_id is not None: + render_assistant_message_feedback(len(st.session_state.history), trace_id) + return AssistantResponse( messages=[{"role": "assistant", "content": accumulated_content}], - request_id=request_id + request_id=request_id, + trace_id=trace_id ) except Exception: response_area.markdown("_Ran into an error. Retrying without streaming..._") - messages, request_id = query_endpoint( + messages, request_id, trace_id = query_endpoint( endpoint_name=SERVING_ENDPOINT, messages=input_messages, - return_traces=ENDPOINT_SUPPORTS_FEEDBACK + return_traces=True ) response_area.empty() with response_area.container(): for message in messages: render_message(message) - return AssistantResponse(messages=messages, request_id=request_id) + if trace_id is not None: + render_assistant_message_feedback(len(st.session_state.history), trace_id) + return AssistantResponse(messages=messages, request_id=request_id, trace_id=trace_id) def query_chat_agent_endpoint_and_render(input_messages): @@ -167,21 +176,32 @@ def query_chat_agent_endpoint_and_render(input_messages): message_buffers = OrderedDict() request_id = None + trace_id = None try: for raw_chunk in query_endpoint_stream( endpoint_name=SERVING_ENDPOINT, messages=input_messages, - return_traces=ENDPOINT_SUPPORTS_FEEDBACK + return_traces=True ): + # Extract trace metadata from databricks_output before parsing as ChatAgentChunk + if "databricks_output" in raw_chunk: + req_id = raw_chunk["databricks_output"].get("databricks_request_id") + if req_id: + request_id = req_id + chunk_trace_id = _extract_trace_id(raw_chunk) + if chunk_trace_id: + trace_id = chunk_trace_id + + # Skip chunks that only carry databricks_output (no agent message data) + if "delta" not in raw_chunk and "choices" not in raw_chunk and "type" not in raw_chunk: + continue + response_area.empty() chunk = ChatAgentChunk.model_validate(raw_chunk) delta = chunk.delta message_id = delta.id - req_id = raw_chunk.get("databricks_output", {}).get("databricks_request_id") - if req_id: - request_id = req_id if message_id not in message_buffers: message_buffers[message_id] = { "chunks": [], @@ -199,22 +219,28 @@ def query_chat_agent_endpoint_and_render(input_messages): for msg_id, msg_info in message_buffers.items(): messages.append(reduce_chat_agent_chunks(msg_info["chunks"])) + if trace_id is not None: + render_assistant_message_feedback(len(st.session_state.history), trace_id) + return AssistantResponse( messages=[message.model_dump_compat(exclude_none=True) for message in messages], - request_id=request_id + request_id=request_id, + trace_id=trace_id ) except Exception: response_area.markdown("_Ran into an error. Retrying without streaming..._") - messages, request_id = query_endpoint( + messages, request_id, trace_id = query_endpoint( endpoint_name=SERVING_ENDPOINT, messages=input_messages, - return_traces=ENDPOINT_SUPPORTS_FEEDBACK + return_traces=True ) response_area.empty() with response_area.container(): for message in messages: render_message(message) - return AssistantResponse(messages=messages, request_id=request_id) + if trace_id is not None: + render_assistant_message_feedback(len(st.session_state.history), trace_id) + return AssistantResponse(messages=messages, request_id=request_id, trace_id=trace_id) def query_responses_endpoint_and_render(input_messages): @@ -225,31 +251,31 @@ def query_responses_endpoint_and_render(input_messages): response_area = st.empty() response_area.markdown("_Thinking..._") - # Track all the messages that need to be rendered in order all_messages = [] request_id = None + trace_id = None try: for raw_event in query_endpoint_stream( endpoint_name=SERVING_ENDPOINT, messages=input_messages, - return_traces=ENDPOINT_SUPPORTS_FEEDBACK + return_traces=True ): - # Extract databricks_output for request_id if "databricks_output" in raw_event: req_id = raw_event["databricks_output"].get("databricks_request_id") if req_id: request_id = req_id + event_trace_id = _extract_trace_id(raw_event) + if event_trace_id: + trace_id = event_trace_id - # Parse using MLflow streaming event types, similar to ChatAgentChunk if "type" in raw_event: event = ResponsesAgentStreamEvent.model_validate(raw_event) if hasattr(event, 'item') and event.item: - item = event.item # This is a dict, not a parsed object + item = event.item if item.get("type") == "message": - # Extract text content from message if present content_parts = item.get("content", []) for content_part in content_parts: if content_part.get("type") == "output_text": @@ -261,12 +287,10 @@ def query_responses_endpoint_and_render(input_messages): }) elif item.get("type") == "function_call": - # Tool call call_id = item.get("call_id") function_name = item.get("name") arguments = item.get("arguments", "") - # Add to messages for history all_messages.append({ "role": "assistant", "content": "", @@ -281,36 +305,38 @@ def query_responses_endpoint_and_render(input_messages): }) elif item.get("type") == "function_call_output": - # Tool call output/result call_id = item.get("call_id") output = item.get("output", "") - # Add to messages for history all_messages.append({ "role": "tool", "content": output, "tool_call_id": call_id }) - # Update the display by rendering all accumulated messages if all_messages: with response_area.container(): for msg in all_messages: render_message(msg) - return AssistantResponse(messages=all_messages, request_id=request_id) + if trace_id is not None: + render_assistant_message_feedback(len(st.session_state.history), trace_id) + + return AssistantResponse(messages=all_messages, request_id=request_id, trace_id=trace_id) except Exception: response_area.markdown("_Ran into an error. Retrying without streaming..._") - messages, request_id = query_endpoint( + messages, request_id, trace_id = query_endpoint( endpoint_name=SERVING_ENDPOINT, messages=input_messages, - return_traces=ENDPOINT_SUPPORTS_FEEDBACK + return_traces=True ) response_area.empty() with response_area.container(): for message in messages: render_message(message) - return AssistantResponse(messages=messages, request_id=request_id) + if trace_id is not None: + render_assistant_message_feedback(len(st.session_state.history), trace_id) + return AssistantResponse(messages=messages, request_id=request_id, trace_id=trace_id) diff --git a/e2e-chatbot-app/messages.py b/e2e-chatbot-app/messages.py index 3c97cb02..4c09f954 100644 --- a/e2e-chatbot-app/messages.py +++ b/e2e-chatbot-app/messages.py @@ -41,11 +41,11 @@ def render(self, _): class AssistantResponse(Message): - def __init__(self, messages, request_id): + def __init__(self, messages, request_id, trace_id=None): super().__init__() self.messages = messages - # Request ID tracked to enable submitting feedback on assistant responses via the feedback endpoint self.request_id = request_id + self.trace_id = trace_id def to_input_messages(self): return self.messages @@ -55,8 +55,8 @@ def render(self, idx): for msg in self.messages: render_message(msg) - if self.request_id is not None: - render_assistant_message_feedback(idx, self.request_id) + if self.trace_id is not None: + render_assistant_message_feedback(idx, self.trace_id) def render_message(msg): @@ -78,18 +78,14 @@ def render_message(msg): @st.fragment -def render_assistant_message_feedback(i, request_id): +def render_assistant_message_feedback(i, trace_id): """Render feedback UI for assistant messages.""" from model_serving_utils import submit_feedback - import os def save_feedback(index): - serving_endpoint = os.getenv('SERVING_ENDPOINT') - if serving_endpoint: - submit_feedback( - endpoint=serving_endpoint, - request_id=request_id, - rating=st.session_state[f"feedback_{index}"] - ) + submit_feedback( + trace_id=trace_id, + rating=st.session_state[f"feedback_{index}"], + ) st.feedback("thumbs", key=f"feedback_{i}", on_change=save_feedback, args=[i]) \ No newline at end of file diff --git a/e2e-chatbot-app/model_serving_utils.py b/e2e-chatbot-app/model_serving_utils.py index a2bf360f..bc9f81ff 100644 --- a/e2e-chatbot-app/model_serving_utils.py +++ b/e2e-chatbot-app/model_serving_utils.py @@ -23,6 +23,16 @@ def _get_endpoint_task_type(endpoint_name: str) -> str: except Exception: return "chat/completions" +def _extract_trace_id(response): + """Extract trace_id from a Databricks model serving response.""" + try: + return (response.get("databricks_output", {}) + .get("trace", {}) + .get("info", {}) + .get("trace_id")) + except (AttributeError, TypeError): + return None + def _convert_to_responses_format(messages): """Convert chat messages to ResponsesAgent API format.""" input_messages = [] @@ -116,8 +126,8 @@ def _query_responses_endpoint_stream(endpoint_name: str, messages: list[dict[str def query_endpoint(endpoint_name, messages, return_traces): """ - Query an endpoint, returning the string message content and request - ID for feedback + Query an endpoint, returning the string message content, request + ID, and trace_id for feedback """ task_type = _get_endpoint_task_type(endpoint_name) @@ -137,8 +147,9 @@ def _query_chat_endpoint(endpoint_name, messages, return_traces): inputs=inputs, ) request_id = res.get("databricks_output", {}).get("databricks_request_id") + trace_id = _extract_trace_id(res) if "messages" in res: - return res["messages"], request_id + return res["messages"], request_id, trace_id elif "choices" in res: choice_message = res["choices"][0]["message"] choice_content = choice_message.get("content") @@ -150,11 +161,11 @@ def _query_chat_endpoint(endpoint_name, messages, return_traces): "role": choice_message.get("role"), "content": combined_content } - return [reformatted_message], request_id + return [reformatted_message], request_id, trace_id # Case 2: The content is a simple string elif isinstance(choice_content, str): - return [choice_message], request_id + return [choice_message], request_id, trace_id _throw_unexpected_endpoint_format() @@ -178,6 +189,7 @@ def _query_responses_endpoint(endpoint_name, messages, return_traces): # Extract messages from the response result_messages = [] request_id = response.get("databricks_output", {}).get("databricks_request_id") + trace_id = _extract_trace_id(response) # Process the output items from ResponsesAgent response output_items = response.get("output", []) @@ -231,40 +243,21 @@ def _query_responses_endpoint(endpoint_name, messages, return_traces): "tool_call_id": call_id }) - return result_messages or [{"role": "assistant", "content": "No response found"}], request_id + return result_messages or [{"role": "assistant", "content": "No response found"}], request_id, trace_id -def submit_feedback(endpoint, request_id, rating): - """Submit feedback to the agent.""" - rating_string = "positive" if rating == 1 else "negative" - text_assessments = [] if rating is None else [{ - "ratings": { - "answer_correct": {"value": rating_string}, - }, - "free_text_comment": None - }] +def submit_feedback(trace_id, rating, user_id="chatbot-user"): + """Submit feedback on an agent response using MLflow trace feedback.""" + import mlflow + from mlflow.entities import AssessmentSource - proxy_payload = { - "dataframe_records": [ - { - "source": json.dumps({ - "id": "e2e-chatbot-app", # Or extract from auth - "type": "human" - }), - "request_id": request_id, - "text_assessments": json.dumps(text_assessments), - "retrieval_assessments": json.dumps([]), - } - ] - } - w = WorkspaceClient() - return w.api_client.do( - method='POST', - path=f"/serving-endpoints/{endpoint}/served-models/feedback/invocations", - body=proxy_payload, + mlflow.set_tracking_uri("databricks") + + is_correct = rating == 1 + mlflow.log_feedback( + trace_id=trace_id, + name="User feedback", + value=is_correct, + source=AssessmentSource(source_type="HUMAN", source_id=user_id), ) -def endpoint_supports_feedback(endpoint_name): - w = WorkspaceClient() - endpoint = w.serving_endpoints.get(endpoint_name) - return "feedback" in [entity.name for entity in endpoint.config.served_entities]