11import asyncio
22import logging
3- from os import environ
3+ import multiprocessing
4+ import os
5+ from concurrent .futures import ThreadPoolExecutor
46from typing import Callable , Iterator
57
68import grpc
1618from pynumaflow .function .generated import udfunction_pb2_grpc
1719from pynumaflow .types import NumaflowServicerContext
1820
# Setting the PYTHONDEBUG environment variable to a non-empty value turns on
# DEBUG-level logging for the whole process.
if os.getenv("PYTHONDEBUG"):
    logging.basicConfig(level=logging.DEBUG)

_LOGGER = logging.getLogger(__name__)

# Signature of a user-supplied map handler: (key, datum) -> messages.
UDFMapCallable = Callable[[str, Datum], Messages]

_PROCESS_COUNT = multiprocessing.cpu_count()

# Thread cap for the gRPC servers. The MAX_THREADS environment variable
# overrides the default of four threads per processor. Zero or negative values
# are clamped to 0 so they fall through to the default instead of reaching
# ThreadPoolExecutor, which rejects a non-positive max_workers.
MAX_THREADS = max(int(os.getenv("MAX_THREADS", "0")), 0) or (_PROCESS_COUNT * 4)
2529
2630
2731class UserDefinedFunctionServicer (udfunction_pb2_grpc .UserDefinedFunctionServicer ):
@@ -33,15 +37,16 @@ class UserDefinedFunctionServicer(udfunction_pb2_grpc.UserDefinedFunctionService
3337 map_handler: Function callable following the type signature of UDFMapCallable
3438 sock_path: Path to the UNIX Domain Socket
3539 max_message_size: The max message size in bytes the server can receive and send
40+ max_threads: The max number of threads to be spawned;
41+ defaults to number of processors x4
3642
3743 Example invocation:
3844 >>> from pynumaflow.function import Messages, Message, Datum, UserDefinedFunctionServicer
3945 >>> def map_handler(key: str, datum: Datum) -> Messages:
4046 ... val = datum.value
4147 ... _ = datum.event_time
4248 ... _ = datum.watermark
43- ... messages = Messages()
44- ... messages.append(Message.to_vtx(key, val))
49+ ... messages = Messages(Message.to_vtx(key, val))
4550 ... return messages
4651 >>> grpc_server = UserDefinedFunctionServicer(map_handler)
4752 >>> grpc_server.start()
@@ -52,12 +57,19 @@ def __init__(
5257 map_handler : UDFMapCallable ,
5358 sock_path = FUNCTION_SOCK_PATH ,
5459 max_message_size = MAX_MESSAGE_SIZE ,
60+ max_threads = MAX_THREADS ,
5561 ):
5662 self .__map_handler : UDFMapCallable = map_handler
5763 self .sock_path = f"unix://{ sock_path } "
5864 self ._max_message_size = max_message_size
65+ self ._max_threads = max_threads
5966 self ._cleanup_coroutines = []
6067
68+ self ._server_options = [
69+ ("grpc.max_send_message_length" , self ._max_message_size ),
70+ ("grpc.max_receive_message_length" , self ._max_message_size ),
71+ ]
72+
6173 def MapFn (
6274 self , request : udfunction_pb2 .Datum , context : NumaflowServicerContext
6375 ) -> udfunction_pb2 .DatumList :
@@ -112,35 +124,51 @@ def IsReady(
112124 """
113125 return udfunction_pb2 .ReadyResponse (ready = True )
114126
async def __serve_async(self, server) -> None:
    """
    Register the servicer on ``server``, bind it to ``self.sock_path``,
    start it and wait until it terminates.

    A graceful-shutdown coroutine is queued on ``self._cleanup_coroutines``
    once the server has started, so the caller can run it afterwards.

    Args:
        server: The async gRPC server instance to configure and run.
    """
    # A fresh servicer wrapping the same map handler is what gets registered.
    servicer = UserDefinedFunctionServicer(self.__map_handler)
    udfunction_pb2_grpc.add_UserDefinedFunctionServicer_to_server(servicer, server)
    server.add_insecure_port(self.sock_path)
    _LOGGER.info("GRPC Async Server listening on: %s", self.sock_path)
    await server.start()

    async def _shutdown_gracefully():
        """
        Stop the server with a 5 second grace period. During the grace
        period the server won't accept new connections and allows
        existing RPCs to continue.
        """
        _LOGGER.info("Starting graceful shutdown...")
        await server.stop(5)

    # The coroutine object is created here but only awaited by the caller.
    shutdown_coro = _shutdown_gracefully()
    self._cleanup_coroutines.append(shutdown_coro)
    await server.wait_for_termination()
def start_async(self) -> None:
    """
    Starts the Async gRPC server on the given UNIX socket.

    Blocks until the server terminates. Afterwards every coroutine queued on
    ``self._cleanup_coroutines`` is run to completion and the event loop is
    closed, whether or not serving raised.
    """
    server = grpc.aio.server(
        ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options
    )
    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(self.__serve_async(server))
    finally:
        try:
            # Run the cleanups one at a time. Unpacking the list straight into
            # run_until_complete() raises TypeError when the list is empty
            # (serving failed before a cleanup was registered) or holds more
            # than one entry, which would mask the original exception.
            for cleanup in self._cleanup_coroutines:
                loop.run_until_complete(cleanup)
        finally:
            # A coroutine object cannot be awaited twice, so drop the
            # references, and always release the loop.
            self._cleanup_coroutines.clear()
            loop.close()
158+
def start(self) -> None:
    """
    Run the synchronous gRPC server on the configured UNIX socket.

    The thread pool handed to the server is capped at ``self._max_threads``
    workers. The call blocks until the server terminates.
    """
    worker_pool = ThreadPoolExecutor(max_workers=self._max_threads)
    grpc_server = grpc.server(worker_pool, options=self._server_options)

    # A fresh servicer wrapping the same map handler is what gets registered.
    servicer = UserDefinedFunctionServicer(self.__map_handler)
    udfunction_pb2_grpc.add_UserDefinedFunctionServicer_to_server(servicer, grpc_server)

    grpc_server.add_insecure_port(self.sock_path)
    grpc_server.start()
    _LOGGER.info(
        "GRPC Server listening on: %s with max threads: %s", self.sock_path, self._max_threads
    )
    grpc_server.wait_for_termination()
0 commit comments