Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 57 additions & 31 deletions e2e-chatbot-app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import os
import streamlit as st
from model_serving_utils import (
endpoint_supports_feedback,
query_endpoint,
query_endpoint_stream,
_get_endpoint_task_type,
_extract_trace_id,
)
from collections import OrderedDict
from messages import UserMessage, AssistantResponse, render_message
from messages import UserMessage, AssistantResponse, render_message, render_assistant_message_feedback

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand All @@ -21,7 +21,6 @@
"'serving_endpoint' with CAN_QUERY permissions, as described in "
"https://docs.databricks.com/aws/en/generative-ai/agent-framework/chat-app#deploy-the-databricks-app")

ENDPOINT_SUPPORTS_FEEDBACK = endpoint_supports_feedback(SERVING_ENDPOINT)

def reduce_chat_agent_chunks(chunks):
"""
Expand Down Expand Up @@ -120,12 +119,13 @@ def query_chat_completions_endpoint_and_render(input_messages):

accumulated_content = ""
request_id = None
trace_id = None

try:
for chunk in query_endpoint_stream(
endpoint_name=SERVING_ENDPOINT,
messages=input_messages,
return_traces=ENDPOINT_SUPPORTS_FEEDBACK
return_traces=True
):
if "choices" in chunk and chunk["choices"]:
delta = chunk["choices"][0].get("delta", {})
Expand All @@ -138,23 +138,32 @@ def query_chat_completions_endpoint_and_render(input_messages):
req_id = chunk["databricks_output"].get("databricks_request_id")
if req_id:
request_id = req_id
chunk_trace_id = _extract_trace_id(chunk)
if chunk_trace_id:
trace_id = chunk_trace_id

if trace_id is not None:
render_assistant_message_feedback(len(st.session_state.history), trace_id)

return AssistantResponse(
messages=[{"role": "assistant", "content": accumulated_content}],
request_id=request_id
request_id=request_id,
trace_id=trace_id
)
except Exception:
response_area.markdown("_Ran into an error. Retrying without streaming..._")
messages, request_id = query_endpoint(
messages, request_id, trace_id = query_endpoint(
endpoint_name=SERVING_ENDPOINT,
messages=input_messages,
return_traces=ENDPOINT_SUPPORTS_FEEDBACK
return_traces=True
)
response_area.empty()
with response_area.container():
for message in messages:
render_message(message)
return AssistantResponse(messages=messages, request_id=request_id)
if trace_id is not None:
render_assistant_message_feedback(len(st.session_state.history), trace_id)
return AssistantResponse(messages=messages, request_id=request_id, trace_id=trace_id)


def query_chat_agent_endpoint_and_render(input_messages):
Expand All @@ -167,21 +176,32 @@ def query_chat_agent_endpoint_and_render(input_messages):

message_buffers = OrderedDict()
request_id = None
trace_id = None

try:
for raw_chunk in query_endpoint_stream(
endpoint_name=SERVING_ENDPOINT,
messages=input_messages,
return_traces=ENDPOINT_SUPPORTS_FEEDBACK
return_traces=True
):
# Extract trace metadata from databricks_output before parsing as ChatAgentChunk
if "databricks_output" in raw_chunk:
req_id = raw_chunk["databricks_output"].get("databricks_request_id")
if req_id:
request_id = req_id
chunk_trace_id = _extract_trace_id(raw_chunk)
if chunk_trace_id:
trace_id = chunk_trace_id

# Skip chunks that only carry databricks_output (no agent message data)
if "delta" not in raw_chunk and "choices" not in raw_chunk and "type" not in raw_chunk:
continue

response_area.empty()
chunk = ChatAgentChunk.model_validate(raw_chunk)
delta = chunk.delta
message_id = delta.id

req_id = raw_chunk.get("databricks_output", {}).get("databricks_request_id")
if req_id:
request_id = req_id
if message_id not in message_buffers:
message_buffers[message_id] = {
"chunks": [],
Expand All @@ -199,22 +219,28 @@ def query_chat_agent_endpoint_and_render(input_messages):
for msg_id, msg_info in message_buffers.items():
messages.append(reduce_chat_agent_chunks(msg_info["chunks"]))

if trace_id is not None:
render_assistant_message_feedback(len(st.session_state.history), trace_id)

return AssistantResponse(
messages=[message.model_dump_compat(exclude_none=True) for message in messages],
request_id=request_id
request_id=request_id,
trace_id=trace_id
)
except Exception:
response_area.markdown("_Ran into an error. Retrying without streaming..._")
messages, request_id = query_endpoint(
messages, request_id, trace_id = query_endpoint(
endpoint_name=SERVING_ENDPOINT,
messages=input_messages,
return_traces=ENDPOINT_SUPPORTS_FEEDBACK
return_traces=True
)
response_area.empty()
with response_area.container():
for message in messages:
render_message(message)
return AssistantResponse(messages=messages, request_id=request_id)
if trace_id is not None:
render_assistant_message_feedback(len(st.session_state.history), trace_id)
return AssistantResponse(messages=messages, request_id=request_id, trace_id=trace_id)


def query_responses_endpoint_and_render(input_messages):
Expand All @@ -225,31 +251,31 @@ def query_responses_endpoint_and_render(input_messages):
response_area = st.empty()
response_area.markdown("_Thinking..._")

# Track all the messages that need to be rendered in order
all_messages = []
request_id = None
trace_id = None

try:
for raw_event in query_endpoint_stream(
endpoint_name=SERVING_ENDPOINT,
messages=input_messages,
return_traces=ENDPOINT_SUPPORTS_FEEDBACK
return_traces=True
):
# Extract databricks_output for request_id
if "databricks_output" in raw_event:
req_id = raw_event["databricks_output"].get("databricks_request_id")
if req_id:
request_id = req_id
event_trace_id = _extract_trace_id(raw_event)
if event_trace_id:
trace_id = event_trace_id

# Parse using MLflow streaming event types, similar to ChatAgentChunk
if "type" in raw_event:
event = ResponsesAgentStreamEvent.model_validate(raw_event)

if hasattr(event, 'item') and event.item:
item = event.item # This is a dict, not a parsed object
item = event.item

if item.get("type") == "message":
# Extract text content from message if present
content_parts = item.get("content", [])
for content_part in content_parts:
if content_part.get("type") == "output_text":
Expand All @@ -261,12 +287,10 @@ def query_responses_endpoint_and_render(input_messages):
})

elif item.get("type") == "function_call":
# Tool call
call_id = item.get("call_id")
function_name = item.get("name")
arguments = item.get("arguments", "")

# Add to messages for history
all_messages.append({
"role": "assistant",
"content": "",
Expand All @@ -281,36 +305,38 @@ def query_responses_endpoint_and_render(input_messages):
})

elif item.get("type") == "function_call_output":
# Tool call output/result
call_id = item.get("call_id")
output = item.get("output", "")

# Add to messages for history
all_messages.append({
"role": "tool",
"content": output,
"tool_call_id": call_id
})

# Update the display by rendering all accumulated messages
if all_messages:
with response_area.container():
for msg in all_messages:
render_message(msg)

return AssistantResponse(messages=all_messages, request_id=request_id)
if trace_id is not None:
render_assistant_message_feedback(len(st.session_state.history), trace_id)

return AssistantResponse(messages=all_messages, request_id=request_id, trace_id=trace_id)
except Exception:
response_area.markdown("_Ran into an error. Retrying without streaming..._")
messages, request_id = query_endpoint(
messages, request_id, trace_id = query_endpoint(
endpoint_name=SERVING_ENDPOINT,
messages=input_messages,
return_traces=ENDPOINT_SUPPORTS_FEEDBACK
return_traces=True
)
response_area.empty()
with response_area.container():
for message in messages:
render_message(message)
return AssistantResponse(messages=messages, request_id=request_id)
if trace_id is not None:
render_assistant_message_feedback(len(st.session_state.history), trace_id)
return AssistantResponse(messages=messages, request_id=request_id, trace_id=trace_id)



Expand Down
22 changes: 9 additions & 13 deletions e2e-chatbot-app/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def render(self, _):


class AssistantResponse(Message):
def __init__(self, messages, request_id):
def __init__(self, messages, request_id, trace_id=None):
super().__init__()
self.messages = messages
# Request ID tracked to enable submitting feedback on assistant responses via the feedback endpoint
self.request_id = request_id
self.trace_id = trace_id

def to_input_messages(self):
return self.messages
Expand All @@ -55,8 +55,8 @@ def render(self, idx):
for msg in self.messages:
render_message(msg)

if self.request_id is not None:
render_assistant_message_feedback(idx, self.request_id)
if self.trace_id is not None:
render_assistant_message_feedback(idx, self.trace_id)


def render_message(msg):
Expand All @@ -78,18 +78,14 @@ def render_message(msg):


@st.fragment
def render_assistant_message_feedback(i, request_id):
def render_assistant_message_feedback(i, trace_id):
"""Render feedback UI for assistant messages."""
from model_serving_utils import submit_feedback
import os

def save_feedback(index):
serving_endpoint = os.getenv('SERVING_ENDPOINT')
if serving_endpoint:
submit_feedback(
endpoint=serving_endpoint,
request_id=request_id,
rating=st.session_state[f"feedback_{index}"]
)
submit_feedback(
trace_id=trace_id,
rating=st.session_state[f"feedback_{index}"],
)

st.feedback("thumbs", key=f"feedback_{i}", on_change=save_feedback, args=[i])
Loading