1+ import asyncio
12from collections .abc import AsyncIterable
23
34from google .protobuf import empty_pb2 as _empty_pb2
45
6+ from pynumaflow .shared .asynciter import NonBlockingIterator
7+ from pynumaflow ._constants import _LOGGER , STREAM_EOF , ERR_UDF_EXCEPTION_STRING
58from pynumaflow .mapstreamer import Datum
69from pynumaflow .mapstreamer ._dtypes import MapStreamCallable , MapStreamError
710from pynumaflow .proto .mapper import map_pb2_grpc , map_pb2
811from pynumaflow .shared .server import handle_async_error
912from pynumaflow .types import NumaflowServicerContext
10- from pynumaflow ._constants import _LOGGER , ERR_UDF_EXCEPTION_STRING
1113
1214
1315class AsyncMapStreamServicer (map_pb2_grpc .MapServicer ):
1416 """
15- This class is used to create a new grpc Map Stream Servicer instance.
16- It implements the SyncMapServicer interface from the proto
17- map_pb2_grpc.py file.
18- Provides the functionality for the required rpc methods.
17+ Concurrent gRPC Map Stream Servicer.
18+ Spawns one background task per incoming MapRequest; each task streams
19+ results as produced and finally emits an EOT for that request.
1920 """
2021
21- def __init__ (
22- self ,
23- handler : MapStreamCallable ,
24- ):
22+ def __init__ (self , handler : MapStreamCallable ):
2523 self .__map_stream_handler : MapStreamCallable = handler
24+ self ._background_tasks : set [asyncio .Task ] = set ()
2625
2726 async def MapFn (
2827 self ,
@@ -31,51 +30,105 @@ async def MapFn(
3130 ) -> AsyncIterable [map_pb2 .MapResponse ]:
3231 """
3332 Applies a map function to a datum stream in streaming mode.
34- The pascal case function name comes from the proto map_pb2_grpc.py file.
33+ The PascalCase name comes from the generated map_pb2_grpc.py file.
3534 """
3635 try :
37- # The first message to be received should be a valid handshake
38- req = await request_iterator .__anext__ ()
39- # check if it is a valid handshake req
40- if not (req .handshake and req .handshake .sot ):
36+ # First message must be a handshake
37+ first = await request_iterator .__anext__ ()
38+ if not (first .handshake and first .handshake .sot ):
4139 raise MapStreamError ("MapStreamFn: expected handshake as the first message" )
40+ # Acknowledge handshake
4241 yield map_pb2 .MapResponse (handshake = map_pb2 .Handshake (sot = True ))
4342
44- # read for each input request
45- async for req in request_iterator :
46- # yield messages as received from the UDF
47- async for res in self .__invoke_map_stream (
48- list (req .request .keys ),
49- Datum (
50- keys = list (req .request .keys ),
51- value = req .request .value ,
52- event_time = req .request .event_time .ToDatetime (),
53- watermark = req .request .watermark .ToDatetime (),
54- headers = dict (req .request .headers ),
55- ),
56- ):
57- yield map_pb2 .MapResponse (results = [res ], id = req .id )
58- # send EOT to indicate end of transmission for a given message
59- yield map_pb2 .MapResponse (status = map_pb2 .TransmissionStatus (eot = True ), id = req .id )
60- except BaseException as err :
43+ # Global non-blocking queue for outbound responses / errors
44+ global_result_queue = NonBlockingIterator ()
45+
46+ # Start producer that turns each inbound request into a background task
47+ producer = asyncio .create_task (
48+ self ._process_inputs (request_iterator , global_result_queue )
49+ )
50+
51+ # Consume results as they arrive and stream them to the client
52+ async for msg in global_result_queue .read_iterator ():
53+ if isinstance (msg , BaseException ):
54+ await handle_async_error (context , msg , ERR_UDF_EXCEPTION_STRING )
55+ return
56+ else :
57+ # msg is a map_pb2.MapResponse, already formed
58+ yield msg
59+
60+ # Ensure producer has finished (covers graceful shutdown)
61+ await producer
62+
63+ except BaseException as e :
6164 _LOGGER .critical ("UDFError, re-raising the error" , exc_info = True )
62- await handle_async_error (context , err , ERR_UDF_EXCEPTION_STRING )
65+ await handle_async_error (context , e , ERR_UDF_EXCEPTION_STRING )
6366 return
6467
65- async def __invoke_map_stream (self , keys : list [str ], req : Datum ):
68+ async def _process_inputs (
69+ self ,
70+ request_iterator : AsyncIterable [map_pb2 .MapRequest ],
71+ result_queue : NonBlockingIterator ,
72+ ) -> None :
73+ """
74+ Reads MapRequests from the client and spawns a background task per request.
75+ Each task streams results to result_queue as they are produced.
76+ """
6677 try :
67- # Invoke the user handler for map stream
68- async for msg in self .__map_stream_handler (keys , req ):
69- yield map_pb2 .MapResponse .Result (keys = msg .keys , value = msg .value , tags = msg .tags )
78+ async for req in request_iterator :
79+ task = asyncio .create_task (self ._invoke_map_stream (req , result_queue ))
80+ self ._background_tasks .add (task )
81+ # Remove from the set when done to avoid memory growth
82+ task .add_done_callback (self ._background_tasks .discard )
83+
84+ # Wait for all in-flight tasks to complete
85+ if self ._background_tasks :
86+ await asyncio .gather (* list (self ._background_tasks ), return_exceptions = False )
87+
88+ # Signal end-of-stream to the consumer
89+ await result_queue .put (STREAM_EOF )
90+
91+ except BaseException as e :
92+ _LOGGER .critical ("MapFn Error, re-raising the error" , exc_info = True )
93+ # Surface the error to the consumer; MapFn will handle and close the RPC
94+ await result_queue .put (e )
95+
96+ async def _invoke_map_stream (
97+ self ,
98+ req : map_pb2 .MapRequest ,
99+ result_queue : NonBlockingIterator ,
100+ ) -> None :
101+ """
102+ Invokes the user-provided async generator for a single request and
103+ pushes each result onto the global queue, followed by an EOT for this id.
104+ """
105+ try :
106+ datum = Datum (
107+ keys = list (req .request .keys ),
108+ value = req .request .value ,
109+ event_time = req .request .event_time .ToDatetime (),
110+ watermark = req .request .watermark .ToDatetime (),
111+ headers = dict (req .request .headers ),
112+ )
113+
114+ # Stream results from the user handler as they are produced
115+ async for msg in self .__map_stream_handler (list (req .request .keys ), datum ):
116+ res = map_pb2 .MapResponse .Result (keys = msg .keys , value = msg .value , tags = msg .tags )
117+ await result_queue .put (map_pb2 .MapResponse (results = [res ], id = req .id ))
118+
119+ # Emit EOT for this request id
120+ await result_queue .put (
121+ map_pb2 .MapResponse (status = map_pb2 .TransmissionStatus (eot = True ), id = req .id )
122+ )
123+
70124 except BaseException as err :
71125 _LOGGER .critical ("MapFn handler error" , exc_info = True )
72- raise err
126+ # Surface handler error to the main producer;
127+ # it will call handle_async_error and end the RPC
128+ await result_queue .put (err )
73129
74130 async def IsReady (
75131 self , request : _empty_pb2 .Empty , context : NumaflowServicerContext
76132 ) -> map_pb2 .ReadyResponse :
77- """
78- IsReady is the heartbeat endpoint for gRPC.
79- The pascal case function name comes from the proto map_pb2_grpc.py file.
80- """
133+ """Heartbeat endpoint for gRPC."""
81134 return map_pb2 .ReadyResponse (ready = True )
0 commit comments