From f6f54c397846ae0ff83ae328be626df790a0130f Mon Sep 17 00:00:00 2001 From: Tang Jie Date: Wed, 19 Mar 2025 07:49:14 +0000 Subject: [PATCH] fix: non-stream async invoke error. --- src/pipeline/backend/backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pipeline/backend/backend.py b/src/pipeline/backend/backend.py index f0279008..e1a025ce 100644 --- a/src/pipeline/backend/backend.py +++ b/src/pipeline/backend/backend.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Iterable, List +from typing import Iterable, List, AsyncGenerator from emd.models import Model,Engine from typing import Iterable, List import os @@ -200,7 +200,7 @@ def _transform_response(self, response): async def _atransform_response(self, response): # Transform response to sagemaker format - return self._aget_response(response) + return await self._aget_response(response) def _transform_streaming_response(self, response): # Transform response to sagemaker format @@ -226,7 +226,7 @@ def _get_streaming_response(self, response) -> Iterable[List[str]]: logger.error(traceback.format_exc()) yield self._format_streaming_response(json.dumps({"error": str(e)})) - async def _aget_streaming_response(self, response) -> Iterable[List[str]]: + async def _aget_streaming_response(self, response) -> AsyncGenerator[str, None]: try: async for chunk in response: logger.info(f"chunk: {chunk}")