Skip to content

Commit 0fbecc1

Browse files
authored
feat: introduces sync gRPC server (#27)
* udf multi-threaded sync server
* udsink multi-threaded sync server
* udf and udsink async servers with migration threadpool
* convert Message to a dataclass type

Signed-off-by: Avik Basu <[email protected]>
1 parent b6a4bc0 commit 0fbecc1

File tree

7 files changed

+136
-75
lines changed

7 files changed

+136
-75
lines changed

.codecov.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ coverage:
33
project:
44
default:
55
target: auto
6-
threshold: 1%
6+
threshold: 3%
77
patch:
88
default:
99
target: auto
10-
threshold: 1%
10+
threshold: 5%
1111

1212
ignore:
1313
- "examples/"

.coveragerc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,9 @@ source = pynumaflow
55
omit =
66
pynumaflow/tests/*
77
examples/*
8+
9+
[report]
10+
exclude_lines =
11+
def start
12+
def start_async
13+
def __serve_async

pynumaflow/function/_dtypes.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from dataclasses import dataclass
12
from datetime import datetime
3+
from functools import partialmethod
24
from typing import TypeVar, Type, List
35

46
DROP = b"U+005C__DROP__"
@@ -9,41 +11,34 @@
911
Ms = TypeVar("Ms", bound="Messages")
1012

1113

14+
@dataclass(frozen=True)
1215
class Message:
13-
def __init__(self, key: str, value: bytes):
14-
self._key = key or ""
15-
self._value = value or b""
16-
17-
def __str__(self):
18-
return str({self._key: self._value})
19-
20-
def __repr__(self):
21-
return str(self)
16+
"""
17+
Basic datatype for data passing to the next vertex/vertices.
2218
23-
@property
24-
def key(self) -> str:
25-
return self._key
19+
Args:
20+
key: string key for vertex;
21+
special values are ALL (send to all), DROP (drop message)
22+
value: data in bytes
23+
"""
2624

27-
@property
28-
def value(self) -> bytes:
29-
return self._value
25+
key: str = ""
26+
value: bytes = b""
3027

3128
@classmethod
3229
def to_vtx(cls: Type[M], key: str, value: bytes) -> M:
30+
"""
31+
Returns a Message object to send value to a vertex.
32+
"""
3333
return cls(key, value)
3434

35-
@classmethod
36-
def to_all(cls: Type[M], value: bytes) -> M:
37-
return cls(ALL, value)
38-
39-
@classmethod
40-
def to_drop(cls: Type[M]) -> M:
41-
return cls(DROP, b"")
35+
to_all = partialmethod(to_vtx, ALL)
36+
to_drop = partialmethod(to_vtx, DROP, b"")
4237

4338

4439
class Messages:
45-
def __init__(self):
46-
self._messages = []
40+
def __init__(self, *messages: M):
41+
self._messages = list(messages) or []
4742

4843
def __str__(self):
4944
return str(self._messages)

pynumaflow/function/server.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
22
import logging
3-
from os import environ
3+
import multiprocessing
4+
import os
5+
from concurrent.futures import ThreadPoolExecutor
46
from typing import Callable, Iterator
57

68
import grpc
@@ -16,12 +18,14 @@
1618
from pynumaflow.function.generated import udfunction_pb2_grpc
1719
from pynumaflow.types import NumaflowServicerContext
1820

19-
if environ.get("PYTHONDEBUG"):
21+
if os.getenv("PYTHONDEBUG"):
2022
logging.basicConfig(level=logging.DEBUG)
2123

2224
_LOGGER = logging.getLogger(__name__)
2325

2426
UDFMapCallable = Callable[[str, Datum], Messages]
27+
_PROCESS_COUNT = multiprocessing.cpu_count()
28+
MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4)
2529

2630

2731
class UserDefinedFunctionServicer(udfunction_pb2_grpc.UserDefinedFunctionServicer):
@@ -33,15 +37,16 @@ class UserDefinedFunctionServicer(udfunction_pb2_grpc.UserDefinedFunctionService
3337
map_handler: Function callable following the type signature of UDFMapCallable
3438
sock_path: Path to the UNIX Domain Socket
3539
max_message_size: The max message size in bytes the server can receive and send
40+
max_threads: The max number of threads to be spawned;
41+
defaults to number of processors x4
3642
3743
Example invocation:
3844
>>> from pynumaflow.function import Messages, Message, Datum, UserDefinedFunctionServicer
3945
>>> def map_handler(key: str, datum: Datum) -> Messages:
4046
... val = datum.value
4147
... _ = datum.event_time
4248
... _ = datum.watermark
43-
... messages = Messages()
44-
... messages.append(Message.to_vtx(key, val))
49+
... messages = Messages(Message.to_vtx(key, val))
4550
... return messages
4651
>>> grpc_server = UserDefinedFunctionServicer(map_handler)
4752
>>> grpc_server.start()
@@ -52,12 +57,19 @@ def __init__(
5257
map_handler: UDFMapCallable,
5358
sock_path=FUNCTION_SOCK_PATH,
5459
max_message_size=MAX_MESSAGE_SIZE,
60+
max_threads=MAX_THREADS,
5561
):
5662
self.__map_handler: UDFMapCallable = map_handler
5763
self.sock_path = f"unix://{sock_path}"
5864
self._max_message_size = max_message_size
65+
self._max_threads = max_threads
5966
self._cleanup_coroutines = []
6067

68+
self._server_options = [
69+
("grpc.max_send_message_length", self._max_message_size),
70+
("grpc.max_receive_message_length", self._max_message_size),
71+
]
72+
6173
def MapFn(
6274
self, request: udfunction_pb2.Datum, context: NumaflowServicerContext
6375
) -> udfunction_pb2.DatumList:
@@ -112,35 +124,51 @@ def IsReady(
112124
"""
113125
return udfunction_pb2.ReadyResponse(ready=True)
114126

115-
async def __serve(self) -> None:
116-
server = grpc.aio.server(
117-
options=[
118-
("grpc.max_send_message_length", self._max_message_size),
119-
("grpc.max_receive_message_length", self._max_message_size),
120-
]
121-
)
127+
async def __serve_async(self, server) -> None:
122128
udfunction_pb2_grpc.add_UserDefinedFunctionServicer_to_server(
123129
UserDefinedFunctionServicer(self.__map_handler), server
124130
)
125131
server.add_insecure_port(self.sock_path)
126-
_LOGGER.info("Server listening on: %s", self.sock_path)
132+
_LOGGER.info("GRPC Async Server listening on: %s", self.sock_path)
127133
await server.start()
128134

129135
async def server_graceful_shutdown():
130-
logging.info("Starting graceful shutdown...")
131-
# Shuts down the server with 5 seconds of grace period. During the
132-
# grace period, the server won't accept new connections and allow
133-
# existing RPCs to continue within the grace period.
136+
"""
137+
Shuts down the server with 5 seconds of grace period. During the
138+
grace period, the server won't accept new connections and allow
139+
existing RPCs to continue within the grace period.
140+
"""
141+
_LOGGER.info("Starting graceful shutdown...")
134142
await server.stop(5)
135143

136144
self._cleanup_coroutines.append(server_graceful_shutdown())
137145
await server.wait_for_termination()
138146

139-
def start(self) -> None:
140-
"""Starts the server on the given UNIX socket."""
147+
def start_async(self) -> None:
148+
"""Starts the Async gRPC server on the given UNIX socket."""
149+
server = grpc.aio.server(
150+
ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options
151+
)
141152
loop = asyncio.get_event_loop()
142153
try:
143-
loop.run_until_complete(self.__serve())
154+
loop.run_until_complete(self.__serve_async(server))
144155
finally:
145156
loop.run_until_complete(*self._cleanup_coroutines)
146157
loop.close()
158+
159+
def start(self) -> None:
160+
"""
161+
Starts the gRPC server on the given UNIX socket with given max threads.
162+
"""
163+
server = grpc.server(
164+
ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options
165+
)
166+
udfunction_pb2_grpc.add_UserDefinedFunctionServicer_to_server(
167+
UserDefinedFunctionServicer(self.__map_handler), server
168+
)
169+
server.add_insecure_port(self.sock_path)
170+
server.start()
171+
_LOGGER.info(
172+
"GRPC Server listening on: %s with max threads: %s", self.sock_path, self._max_threads
173+
)
174+
server.wait_for_termination()

pynumaflow/sink/_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
Rs = TypeVar("Rs", bound="Responses")
77

88

9-
@dataclass
9+
@dataclass(frozen=True)
1010
class Response:
1111
id: str
1212
success: bool

pynumaflow/sink/server.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
22
import logging
3-
from os import environ
3+
import multiprocessing
4+
import os
5+
from concurrent.futures import ThreadPoolExecutor
46
from typing import Callable, List
57

68
import grpc
@@ -14,12 +16,14 @@
1416
from pynumaflow.sink.generated import udsink_pb2_grpc, udsink_pb2
1517
from pynumaflow.types import NumaflowServicerContext
1618

17-
if environ.get("PYTHONDEBUG"):
19+
if os.getenv("PYTHONDEBUG"):
1820
logging.basicConfig(level=logging.DEBUG)
1921

2022
_LOGGER = logging.getLogger(__name__)
2123

2224
UDSinkCallable = Callable[[List[Datum]], Responses]
25+
_PROCESS_COUNT = multiprocessing.cpu_count()
26+
MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4)
2327

2428

2529
class UserDefinedSinkServicer(udsink_pb2_grpc.UserDefinedSinkServicer):
@@ -31,14 +35,15 @@ class UserDefinedSinkServicer(udsink_pb2_grpc.UserDefinedSinkServicer):
3135
sink_handler: Function callable following the type signature of UDSinkCallable
3236
sock_path: Path to the UNIX Domain Socket
3337
max_message_size: The max message size in bytes the server can receive and send
38+
max_threads: The max number of threads to be spawned;
39+
defaults to number of processors x 4
3440
3541
Example invocation:
3642
>>> from typing import List
3743
>>> from pynumaflow.sink import Datum, Responses, Response, UserDefinedSinkServicer
3844
>>> def udsink_handler(datums: List[Datum]) -> Responses:
3945
... responses = Responses()
4046
... for msg in datums:
41-
... print("User Defined Sink", msg)
4247
... responses.append(Response.as_success(msg.id))
4348
... return responses
4449
>>> grpc_server = UserDefinedSinkServicer(udsink_handler)
@@ -50,12 +55,19 @@ def __init__(
5055
sink_handler: UDSinkCallable,
5156
sock_path=SINK_SOCK_PATH,
5257
max_message_size=MAX_MESSAGE_SIZE,
58+
max_threads=MAX_THREADS,
5359
):
5460
self.__sink_handler: UDSinkCallable = sink_handler
5561
self.sock_path = f"unix://{sock_path}"
5662
self._max_message_size = max_message_size
63+
self._max_threads = max_threads
5764
self._cleanup_coroutines = []
5865

66+
self._server_options = [
67+
("grpc.max_send_message_length", self._max_message_size),
68+
("grpc.max_receive_message_length", self._max_message_size),
69+
]
70+
5971
def SinkFn(
6072
self, request: udsink_pb2.DatumList, context: NumaflowServicerContext
6173
) -> udsink_pb2.ResponseList:
@@ -90,35 +102,51 @@ def IsReady(
90102
"""
91103
return udsink_pb2.ReadyResponse(ready=True)
92104

93-
async def __serve(self) -> None:
105+
async def __serve_async(self) -> None:
94106
server = grpc.aio.server(
95-
options=[
96-
("grpc.max_send_message_length", self._max_message_size),
97-
("grpc.max_receive_message_length", self._max_message_size),
98-
]
107+
ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options
99108
)
100109
udsink_pb2_grpc.add_UserDefinedSinkServicer_to_server(
101110
UserDefinedSinkServicer(self.__sink_handler), server
102111
)
103112
server.add_insecure_port(self.sock_path)
104-
_LOGGER.info("Server listening on: %s", self.sock_path)
113+
_LOGGER.info("GRPC Async Server listening on: %s", self.sock_path)
105114
await server.start()
106115

107116
async def server_graceful_shutdown():
108-
logging.info("Starting graceful shutdown...")
109-
# Shuts down the server with 5 seconds of grace period. During the
110-
# grace period, the server won't accept new connections and allow
111-
# existing RPCs to continue within the grace period.
117+
_LOGGER.info("Starting graceful shutdown...")
118+
"""
119+
Shuts down the server with 5 seconds of grace period. During the
120+
grace period, the server won't accept new connections and allow
121+
existing RPCs to continue within the grace period.
112122
await server.stop(5)
123+
"""
113124

114125
self._cleanup_coroutines.append(server_graceful_shutdown())
115126
await server.wait_for_termination()
116127

117-
def start(self) -> None:
118-
"""Starts the server on the given UNIX socket."""
128+
def start_async(self) -> None:
129+
"""Starts the Async gRPC server on the given UNIX socket."""
119130
loop = asyncio.get_event_loop()
120131
try:
121-
loop.run_until_complete(self.__serve())
132+
loop.run_until_complete(self.__serve_async())
122133
finally:
123134
loop.run_until_complete(*self._cleanup_coroutines)
124135
loop.close()
136+
137+
def start(self) -> None:
138+
"""
139+
Starts the gRPC server on the given UNIX socket with given max threads.
140+
"""
141+
server = grpc.server(
142+
ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options
143+
)
144+
udsink_pb2_grpc.add_UserDefinedSinkServicer_to_server(
145+
UserDefinedSinkServicer(self.__sink_handler), server
146+
)
147+
server.add_insecure_port(self.sock_path)
148+
server.start()
149+
_LOGGER.info(
150+
"GRPC Server listening on: %s with max threads: %s", self.sock_path, self._max_threads
151+
)
152+
server.wait_for_termination()

0 commit comments

Comments
 (0)