
Commit bf00b2e

feat: parallelized mapstreamer (#242)
Signed-off-by: kohlisid <[email protected]>
1 parent b944a5d · commit bf00b2e

File tree: 2 files changed — 146 additions, 65 deletions

Async map stream servicer
Lines changed: 94 additions & 41 deletions
@@ -1,28 +1,27 @@
+import asyncio
 from collections.abc import AsyncIterable
 
 from google.protobuf import empty_pb2 as _empty_pb2
 
+from pynumaflow.shared.asynciter import NonBlockingIterator
+from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
 from pynumaflow.mapstreamer import Datum
 from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError
 from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2
 from pynumaflow.shared.server import handle_async_error
 from pynumaflow.types import NumaflowServicerContext
-from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING
 
 
 class AsyncMapStreamServicer(map_pb2_grpc.MapServicer):
     """
-    This class is used to create a new grpc Map Stream Servicer instance.
-    It implements the SyncMapServicer interface from the proto
-    map_pb2_grpc.py file.
-    Provides the functionality for the required rpc methods.
+    Concurrent gRPC Map Stream Servicer.
+    Spawns one background task per incoming MapRequest; each task streams
+    results as produced and finally emits an EOT for that request.
     """
 
-    def __init__(
-        self,
-        handler: MapStreamCallable,
-    ):
+    def __init__(self, handler: MapStreamCallable):
         self.__map_stream_handler: MapStreamCallable = handler
+        self._background_tasks: set[asyncio.Task] = set()
 
     async def MapFn(
         self,
@@ -31,51 +30,105 @@ async def MapFn(
     ) -> AsyncIterable[map_pb2.MapResponse]:
         """
         Applies a map function to a datum stream in streaming mode.
-        The pascal case function name comes from the proto map_pb2_grpc.py file.
+        The PascalCase name comes from the generated map_pb2_grpc.py file.
         """
         try:
-            # The first message to be received should be a valid handshake
-            req = await request_iterator.__anext__()
-            # check if it is a valid handshake req
-            if not (req.handshake and req.handshake.sot):
+            # First message must be a handshake
+            first = await request_iterator.__anext__()
+            if not (first.handshake and first.handshake.sot):
                 raise MapStreamError("MapStreamFn: expected handshake as the first message")
+            # Acknowledge handshake
             yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))
 
-            # read for each input request
-            async for req in request_iterator:
-                # yield messages as received from the UDF
-                async for res in self.__invoke_map_stream(
-                    list(req.request.keys),
-                    Datum(
-                        keys=list(req.request.keys),
-                        value=req.request.value,
-                        event_time=req.request.event_time.ToDatetime(),
-                        watermark=req.request.watermark.ToDatetime(),
-                        headers=dict(req.request.headers),
-                    ),
-                ):
-                    yield map_pb2.MapResponse(results=[res], id=req.id)
-                # send EOT to indicate end of transmission for a given message
-                yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
-        except BaseException as err:
+            # Global non-blocking queue for outbound responses / errors
+            global_result_queue = NonBlockingIterator()
+
+            # Start producer that turns each inbound request into a background task
+            producer = asyncio.create_task(
+                self._process_inputs(request_iterator, global_result_queue)
+            )
+
+            # Consume results as they arrive and stream them to the client
+            async for msg in global_result_queue.read_iterator():
+                if isinstance(msg, BaseException):
+                    await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
+                    return
+                else:
+                    # msg is a map_pb2.MapResponse, already formed
+                    yield msg
+
+            # Ensure producer has finished (covers graceful shutdown)
+            await producer
+
+        except BaseException as e:
             _LOGGER.critical("UDFError, re-raising the error", exc_info=True)
-            await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
+            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
             return
 
-    async def __invoke_map_stream(self, keys: list[str], req: Datum):
+    async def _process_inputs(
+        self,
+        request_iterator: AsyncIterable[map_pb2.MapRequest],
+        result_queue: NonBlockingIterator,
+    ) -> None:
+        """
+        Reads MapRequests from the client and spawns a background task per request.
+        Each task streams results to result_queue as they are produced.
+        """
         try:
-            # Invoke the user handler for map stream
-            async for msg in self.__map_stream_handler(keys, req):
-                yield map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
+            async for req in request_iterator:
+                task = asyncio.create_task(self._invoke_map_stream(req, result_queue))
+                self._background_tasks.add(task)
+                # Remove from the set when done to avoid memory growth
+                task.add_done_callback(self._background_tasks.discard)
+
+            # Wait for all in-flight tasks to complete
+            if self._background_tasks:
+                await asyncio.gather(*list(self._background_tasks), return_exceptions=False)
+
+            # Signal end-of-stream to the consumer
+            await result_queue.put(STREAM_EOF)
+
+        except BaseException as e:
+            _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True)
+            # Surface the error to the consumer; MapFn will handle and close the RPC
+            await result_queue.put(e)
+
+    async def _invoke_map_stream(
+        self,
+        req: map_pb2.MapRequest,
+        result_queue: NonBlockingIterator,
+    ) -> None:
+        """
+        Invokes the user-provided async generator for a single request and
+        pushes each result onto the global queue, followed by an EOT for this id.
+        """
+        try:
+            datum = Datum(
+                keys=list(req.request.keys),
+                value=req.request.value,
+                event_time=req.request.event_time.ToDatetime(),
+                watermark=req.request.watermark.ToDatetime(),
+                headers=dict(req.request.headers),
+            )
+
+            # Stream results from the user handler as they are produced
+            async for msg in self.__map_stream_handler(list(req.request.keys), datum):
+                res = map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
+                await result_queue.put(map_pb2.MapResponse(results=[res], id=req.id))
+
+            # Emit EOT for this request id
+            await result_queue.put(
+                map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
+            )
+
         except BaseException as err:
             _LOGGER.critical("MapFn handler error", exc_info=True)
-            raise err
+            # Surface handler error to the main producer;
+            # it will call handle_async_error and end the RPC
+            await result_queue.put(err)
 
     async def IsReady(
         self, request: _empty_pb2.Empty, context: NumaflowServicerContext
     ) -> map_pb2.ReadyResponse:
-        """
-        IsReady is the heartbeat endpoint for gRPC.
-        The pascal case function name comes from the proto map_pb2_grpc.py file.
-        """
+        """Heartbeat endpoint for gRPC."""
        return map_pb2.ReadyResponse(ready=True)
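
The new servicer replaces the old one-request-at-a-time loop with a fan-out/fan-in pattern: each incoming MapRequest is handed to its own asyncio task, every task writes into one shared result queue, and MapFn drains that queue onto the gRPC response stream until a STREAM_EOF sentinel arrives. The sketch below is a standalone illustration of that pattern, not pynumaflow code: a plain asyncio.Queue and a local EOF object stand in for NonBlockingIterator and STREAM_EOF, and results are printed instead of being yielded as MapResponse messages.

# Standalone sketch of the fan-out/fan-in pattern used above (illustrative only).
import asyncio

EOF = object()  # sentinel marking "all requests finished"


async def handle_one_request(req_id: int, queue: asyncio.Queue) -> None:
    # Plays the role of _invoke_map_stream: stream a few results, then an EOT marker.
    for i in range(3):
        await asyncio.sleep(0.01 * (req_id + 1))  # requests finish at different speeds
        await queue.put(f"request {req_id}: result {i}")
    await queue.put(f"request {req_id}: EOT")


async def fan_out(queue: asyncio.Queue, num_requests: int) -> None:
    # Plays the role of _process_inputs: one task per request, held in a set so the
    # tasks keep a strong reference and can all be awaited before signalling EOF.
    tasks = {asyncio.create_task(handle_one_request(r, queue)) for r in range(num_requests)}
    await asyncio.gather(*tasks)
    await queue.put(EOF)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    producer = asyncio.create_task(fan_out(queue, num_requests=3))
    # Plays the role of MapFn's consumer loop: results arrive in completion order,
    # so output from different requests interleaves on the stream.
    while True:
        msg = await queue.get()
        if msg is EOF:
            break
        print(msg)
    await producer


if __name__ == "__main__":
    asyncio.run(main())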

tests/mapstream/test_async_map_stream.py

Lines changed: 52 additions & 24 deletions
@@ -95,39 +95,67 @@ def tearDownClass(cls) -> None:
 
     def test_map_stream(self) -> None:
         stub = self.__stub()
-        generator_response = None
+
+        # Send >1 requests
+        req_count = 3
         try:
-            generator_response = stub.MapFn(request_iterator=request_generator(count=1, session=1))
+            generator_response = stub.MapFn(
+                request_iterator=request_generator(count=req_count, session=1)
+            )
         except grpc.RpcError as e:
             logging.error(e)
+            self.fail(f"RPC failed: {e}")
 
+        # First message must be the handshake
         handshake = next(generator_response)
-        # assert that handshake response is received.
         self.assertTrue(handshake.handshake.sot)
-        data_resp = []
-        for r in generator_response:
-            data_resp.append(r)
-
-        self.assertEqual(11, len(data_resp))
 
-        idx = 0
-        while idx < len(data_resp) - 1:
+        # Expected: 10 results per request + 1 EOT per request
+        expected_result_msgs = req_count * 10
+        expected_eots = req_count
+
+        # Prepare expected payload
+        expected_payload = bytes(
+            "payload:test_mock_message "
+            "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00",
+            encoding="utf-8",
+        )
+
+        from collections import Counter
+
+        id_counter = Counter()
+        result_msg_count = 0
+        eot_count = 0
+
+        for msg in generator_response:
+            # Count EOTs wherever they show up
+            if hasattr(msg, "status") and msg.status.eot:
+                eot_count += 1
+                continue
+
+            # Otherwise, it's a data/result message; validate payload and tally by id
+            self.assertTrue(msg.results, "Expected results in MapResponse.")
+            self.assertEqual(expected_payload, msg.results[0].value)
+            id_counter[msg.id] += 1
+            result_msg_count += 1
+
+        # Validate totals
+        self.assertEqual(
+            expected_result_msgs,
+            result_msg_count,
+            f"Expected {expected_result_msgs} result messages, got {result_msg_count}",
+        )
+        self.assertEqual(
+            expected_eots, eot_count, f"Expected {expected_eots} EOT messages, got {eot_count}"
+        )
+
+        # Validate 10 messages per request id: test-id-0..test-id-(req_count-1)
+        for i in range(req_count):
             self.assertEqual(
-                bytes(
-                    "payload:test_mock_message "
-                    "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00",
-                    encoding="utf-8",
-                ),
-                data_resp[idx].results[0].value,
+                10,
+                id_counter[f"test-id-{i}"],
+                f"Expected 10 results for test-id-{i}, got {id_counter[f'test-id-{i}']}",
             )
-            _id = data_resp[idx].id
-            self.assertEqual(_id, "test-id-0")
-            # capture the output from the SinkFn generator and assert.
-            idx += 1
-        # EOT Response
-        self.assertEqual(data_resp[len(data_resp) - 1].status.eot, True)
-        # 10 sink responses + 1 EOT response
-        self.assertEqual(11, len(data_resp))
 
     def test_is_ready(self) -> None:
         with grpc.insecure_channel("unix:///tmp/async_map_stream.sock") as channel:
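
For UDF authors, the practical effect is that handler invocations for different in-flight requests now run concurrently, and responses for different request ids may interleave on the stream, each terminated by its own EOT (which is exactly what the updated test checks). The handler contract itself is unchanged: an async generator called with (keys, datum) that yields messages exposing .keys, .value, and .tags, as read by _invoke_map_stream above. The sketch below is illustrative only; the Result dataclass is a local stand-in for the library's own message type, not pynumaflow's API.

# Illustrative handler shape (not pynumaflow code): a local Result stand-in with the
# .keys / .value / .tags attributes the servicer reads from each yielded message.
from collections.abc import AsyncIterable
from dataclasses import dataclass, field


@dataclass
class Result:
    value: bytes
    keys: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)


async def my_map_stream_handler(keys: list[str], datum) -> AsyncIterable[Result]:
    # Stream one message per line of the input payload; with the parallelized
    # servicer, a slow payload here no longer blocks other in-flight requests.
    for line in datum.value.splitlines():
        yield Result(value=line, keys=keys)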
