diff --git a/src/pipeline/backend/backend.py b/src/pipeline/backend/backend.py index f0279008..e1a025ce 100644 --- a/src/pipeline/backend/backend.py +++ b/src/pipeline/backend/backend.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Iterable, List +from typing import Iterable, List, AsyncGenerator from emd.models import Model,Engine from typing import Iterable, List import os @@ -200,7 +200,7 @@ def _transform_response(self, response): async def _atransform_response(self, response): # Transform response to sagemaker format - return self._aget_response(response) + return await self._aget_response(response) def _transform_streaming_response(self, response): # Transform response to sagemaker format @@ -226,7 +226,7 @@ def _get_streaming_response(self, response) -> Iterable[List[str]]: logger.error(traceback.format_exc()) yield self._format_streaming_response(json.dumps({"error": str(e)})) - async def _aget_streaming_response(self, response) -> Iterable[List[str]]: + async def _aget_streaming_response(self, response) -> AsyncGenerator[str, None]: try: async for chunk in response: logger.info(f"chunk: {chunk}")