diff --git a/ads/aqua/common/enums.py b/ads/aqua/common/enums.py
index 21a71606c..3897eb049 100644
--- a/ads/aqua/common/enums.py
+++ b/ads/aqua/common/enums.py
@@ -24,6 +24,7 @@ class PredictEndpoints(ExtendedEnum):
     CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
     TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
     EMBEDDING_ENDPOINT = "/v1/embedding"
+    RESPONSES = "/v1/responses"
 
 
 class Tags(ExtendedEnum):
diff --git a/ads/aqua/extension/deployment_handler.py b/ads/aqua/extension/deployment_handler.py
index 3d23a1052..7b2b58f24 100644
--- a/ads/aqua/extension/deployment_handler.py
+++ b/ads/aqua/extension/deployment_handler.py
@@ -7,14 +7,15 @@
 
 from tornado.web import HTTPError
 
-from ads.aqua.app import logger
+from ads.aqua import logger
 from ads.aqua.client.client import Client, ExtendedRequestError
+from ads.aqua.client.openai_client import OpenAI
 from ads.aqua.common.decorator import handle_exceptions
 from ads.aqua.common.enums import PredictEndpoints
 from ads.aqua.extension.base_handler import AquaAPIhandler
 from ads.aqua.extension.errors import Errors
 from ads.aqua.modeldeployment import AquaDeploymentApp
 from ads.config import COMPARTMENT_OCID
 
 
 class AquaDeploymentHandler(AquaAPIhandler):
@@ -221,11 +222,102 @@ def list_shapes(self):
 
 
 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
+    def _extract_text_from_choice(self, choice) -> Optional[str]:
+        """
+        Extract text content from a single choice structure.
+
+        Handles both dictionary-based API responses and object-based SDK
+        responses. For dict choices, it checks delta-based streaming fields
+        (delta → content/text), then message-based non-streaming fields
+        (message → content/text), and finally top-level text/content keys.
+        For object choices, it inspects the same fields via the `.delta`,
+        `.message`, `.text`, and `.content` attributes.
+
+        Parameters
+        ----------
+        choice : dict or object
+            A choice entry from a model response: either a dict originating
+            from a JSON API response (streaming or non-streaming), or an
+            SDK-style object with attributes such as `delta`, `message`,
+            `text`, or `content`.
+
+        Returns
+        -------
+        str or None
+            The extracted text if present; otherwise None.
+        """
+        # choice may be a dict or an object
+        if isinstance(choice, dict):
+            # streaming chunk: {"delta": {"content": "..."}}
+            delta = choice.get("delta")
+            if isinstance(delta, dict):
+                return delta.get("content") or delta.get("text") or None
+            # non-streaming: {"message": {"content": "..."}}
+            msg = choice.get("message")
+            if isinstance(msg, dict):
+                return msg.get("content") or msg.get("text")
+            # fallback top-level fields
+            return choice.get("text") or choice.get("content")
+        # object-like choice
+        delta = getattr(choice, "delta", None)
+        if delta is not None:
+            return getattr(delta, "content", None) or getattr(delta, "text", None)
+        msg = getattr(choice, "message", None)
+        if msg is not None:
+            if isinstance(msg, str):
+                return msg
+            return getattr(msg, "content", None) or getattr(msg, "text", None)
+        return getattr(choice, "text", None) or getattr(choice, "content", None)
+
+    def _extract_text_from_chunk(self, chunk) -> Optional[str]:
+        """
+        Extract text content from a model response chunk.
+
+        Supports both dict-form chunks (streaming or non-streaming) and
+        SDK-style object chunks. When choices are present, extraction is
+        delegated to `_extract_text_from_choice`.
+        If no choices exist, top-level text/content fields or attributes
+        are used.
+
+        Parameters
+        ----------
+        chunk : dict or object
+            A chunk returned from a model stream or full response: either a
+            dict containing a `choices` list or top-level text/content
+            fields, or an SDK-style object with a `choices` attribute or
+            top-level `text`/`content` attributes.
+
+        Returns
+        -------
+        str or None
+            The extracted text if present; otherwise None.
+        """
+        if chunk:
+            if isinstance(chunk, dict):
+                choices = chunk.get("choices") or []
+                if choices:
+                    return self._extract_text_from_choice(choices[0])
+                # fallback top-level
+                return chunk.get("text") or chunk.get("content")
+            # object-like chunk
+            choices = getattr(chunk, "choices", None)
+            if choices:
+                return self._extract_text_from_choice(choices[0])
+            return getattr(chunk, "text", None) or getattr(chunk, "content", None)
+        return None
+
     def _get_model_deployment_response(
         self,
         model_deployment_id: str,
-        payload: dict,
-        route_override_header: Optional[str],
+        payload: dict,
     ):
         """
         Returns the model deployment inference response in a streaming fashion.
@@ -272,53 +364,173 @@ def _get_model_deployment_response(
         """
         model_deployment = AquaDeploymentApp().get(model_deployment_id)
-        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
-        endpoint_type = model_deployment.environment_variables.get(
-            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
-        )
-        aqua_client = Client(endpoint=endpoint)
-
-        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
-            endpoint_type,
-            route_override_header,
-        ):
+        endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"
+
+        required_keys = ["endpoint_type", "prompt", "model"]
+        missing = [k for k in required_keys if k not in payload]
+        if missing:
+            raise HTTPError(
+                400, f"Missing required payload keys: {', '.join(missing)}"
+            )
+
+        endpoint_type = payload["endpoint_type"]
+        aqua_client = OpenAI(base_url=endpoint)
+
+        allowed = {
+            "max_tokens",
+            "temperature",
+            "top_p",
+            "stop",
+            "n",
+            "presence_penalty",
+            "frequency_penalty",
+            "logprobs",
+            "user",
+            "echo",
+        }
+        responses_allowed = {"temperature", "top_p"}
+
+        # Normalize and filter the raw payload down to supported kwargs.
+        if payload.get("stop") == []:
+            payload["stop"] = None
+
+        encoded_image = payload.get("encoded_image", "NA")
+
+        model = payload.pop("model")
+        filtered = {k: v for k, v in payload.items() if k in allowed}
+        responses_filtered = {
+            k: v for k, v in payload.items() if k in responses_allowed
+        }
+
+        if (
+            endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
+            and encoded_image == "NA"
+        ):
             try:
-                for chunk in aqua_client.chat(
-                    messages=payload.pop("messages"),
-                    payload=payload,
-                    stream=True,
-                ):
-                    try:
-                        if "text" in chunk["choices"][0]:
-                            yield chunk["choices"][0]["text"]
-                        elif "content" in chunk["choices"][0]["delta"]:
-                            yield chunk["choices"][0]["delta"]["content"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                api_kwargs = {
+                    "model": model,
+                    "messages": [{"role": "user", "content": payload["prompt"]}],
+                    "stream": True,
+                    **filtered,
+                }
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                stream = aqua_client.chat.completions.create(**api_kwargs)
+
+                for chunk in stream:
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
+
+        elif (
+            endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
+            and encoded_image != "NA"
+        ):
+            file_type = payload.pop("file_type", "")
+            if file_type.startswith("image"):
+                api_kwargs = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": payload["prompt"]},
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": encoded_image},
+                                },
+                            ],
+                        }
+                    ],
+                    "stream": True,
+                    **filtered,
+                }
+                # Add chat_template for image-based chat completions
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                response = aqua_client.chat.completions.create(**api_kwargs)
+            elif file_type.startswith("audio"):
+                api_kwargs = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": payload["prompt"]},
+                                {
+                                    "type": "audio_url",
+                                    "audio_url": {"url": encoded_image},
+                                },
+                            ],
+                        }
+                    ],
+                    "stream": True,
+                    **filtered,
+                }
+                # Add chat_template for audio-based chat completions
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                response = aqua_client.chat.completions.create(**api_kwargs)
+            else:
+                # Guard against an undefined `response` below.
+                raise HTTPError(400, f"Unsupported file_type: {file_type}")
+
+            try:
+                for chunk in response:
+                    piece = self._extract_text_from_chunk(chunk)
+                    if piece:
+                        yield piece
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
 
         elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
             try:
-                for chunk in aqua_client.generate(
-                    prompt=payload.pop("prompt"),
-                    payload=payload,
-                    stream=True,
+                for chunk in aqua_client.completions.create(
+                    prompt=payload["prompt"], stream=True, model=model, **filtered
                 ):
-                    try:
-                        yield chunk["choices"][0]["text"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
+
+        elif endpoint_type == PredictEndpoints.RESPONSES:
+            api_kwargs = {
+                "model": model,
+                "input": payload["prompt"],
+                "stream": True,
+                **responses_filtered,
+            }
+
+            response = aqua_client.responses.create(**api_kwargs)
+            try:
+                for chunk in response:
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
+        else:
+            raise HTTPError(400, f"Unsupported endpoint_type: {endpoint_type}")
 
     @handle_exceptions
     def post(self, model_deployment_id):
@@ -340,16 +552,16 @@ def post(self, model_deployment_id):
 
         prompt = input_data.get("prompt")
         messages = input_data.get("messages")
+
         if not prompt and not messages:
             raise HTTPError(
                 400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
             )
         if not input_data.get("model"):
             raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
-        route_override_header = self.request.headers.get("route", None)
         self.set_header("Content-Type", "text/event-stream")
         response_gen = self._get_model_deployment_response(
-            model_deployment_id, input_data, route_override_header
+            model_deployment_id, input_data
         )
         try:
             for chunk in response_gen:
@@ -357,7 +569,7 @@ def post(self, model_deployment_id):
                 self.flush()
             self.finish()
         except Exception as ex:
-            self.set_status(ex.status_code)
+            self.set_status(getattr(ex, "status_code", 500))
             self.write({"message": "Error occurred", "reason": str(ex)})
             self.finish()
 
diff --git a/tests/unitary/with_extras/aqua/test_deployment_handler.py b/tests/unitary/with_extras/aqua/test_deployment_handler.py
index f6ca6d271..c3529e748 100644
--- a/tests/unitary/with_extras/aqua/test_deployment_handler.py
+++ b/tests/unitary/with_extras/aqua/test_deployment_handler.py
@@ -8,9 +8,12 @@
 import unittest
 from importlib import reload
 from unittest.mock import MagicMock, patch
+from urllib.error import HTTPError
 
+from ads.aqua.common.enums import PredictEndpoints
 from notebook.base.handlers import IPythonHandler
 from parameterized import parameterized
+import openai
 
 import ads.aqua
 import ads.config
@@ -245,6 +248,9 @@ def test_validate_deployment_params(
 
 
 class TestAquaDeploymentStreamingInferenceHandler(unittest.TestCase):
+    EXPECTED_OCID = "ocid1.compartment.oc1..aaaaaaaaser65kfcfht7iddoioa4s6xos3vi53d3i7bi3czjkqyluawp2itq"
+
     @patch.object(IPythonHandler, "__init__")
     def setUp(self, ipython_init_mock) -> None:
         ipython_init_mock.return_value = None
@@ -274,13 +280,85 @@ def test_post(self, mock_get_model_deployment_response):
 
         mock_get_model_deployment_response.assert_called_with(
             "mock-deployment-id",
-            {"prompt": "Hello", "model": "some-model"},
-            "test-route",
+            {"prompt": "Hello", "model": "some-model"}
         )
         self.handler.write.assert_any_call("chunk1")
         self.handler.write.assert_any_call("chunk2")
         self.handler.finish.assert_called_once()
 
+    def test_extract_text_from_choice_dict_delta_content(self):
+        """Test dict choice with delta.content."""
+        choice = {"delta": {"content": "hello"}}
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "hello")
+
+    def test_extract_text_from_choice_dict_delta_text(self):
+        """Test dict choice with delta.text fallback."""
+        choice = {"delta": {"text": "world"}}
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "world")
+
+    def test_extract_text_from_choice_dict_message_content(self):
+        """Test dict choice with message.content."""
+        choice = {"message": {"content": "foo"}}
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "foo")
+
+    def test_extract_text_from_choice_dict_top_level_text(self):
+        """Test dict choice with top-level text."""
+        choice = {"text": "bar"}
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "bar")
+
+    def test_extract_text_from_choice_object_delta_content(self):
+        """Test object choice with delta.content attribute."""
+        choice = MagicMock()
+        choice.delta = MagicMock(content="obj-content", text=None)
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "obj-content")
+
+    def test_extract_text_from_choice_object_message_str(self):
+        """Test object choice with message as string."""
+        choice = MagicMock()
+        choice.delta = None  # No delta, so message takes precedence
+        choice.message = "direct-string"
+        result = self.handler._extract_text_from_choice(choice)
+        self.assertEqual(result, "direct-string")
+
+    def test_extract_text_from_choice_none_return(self):
"""Test choice with no text content returns None.""" + choice = {} + result = self.handler._extract_text_from_choice(choice) + self.assertIsNone(result) + + def test_extract_text_from_chunk_dict_with_choices(self): + """Test chunk dict with choices list.""" + chunk = {"choices": [{"delta": {"content": "chunk-text"}}]} + result = self.handler._extract_text_from_chunk(chunk) + self.assertEqual(result, "chunk-text") + + def test_extract_text_from_chunk_dict_top_level_content(self): + """Test chunk dict with top-level content (no choices).""" + chunk = {"content": "direct-content"} + result = self.handler._extract_text_from_chunk(chunk) + self.assertEqual(result, "direct-content") + + def test_extract_text_from_chunk_object_choices(self): + """Test object chunk with choices attribute.""" + chunk = MagicMock() + chunk.choices = [{"message": {"content": "obj-chunk"}}] + result = self.handler._extract_text_from_chunk(chunk) + self.assertEqual(result, "obj-chunk") + + def test_extract_text_from_chunk_empty(self): + """Test empty/None chunk returns None.""" + result = self.handler._extract_text_from_chunk({}) + self.assertIsNone(result) + result = self.handler._extract_text_from_chunk(None) + self.assertIsNone(result) + + + class AquaModelListHandlerTestCase(unittest.TestCase): default_params = {