2121from pynumaflow .function .proto import udfunction_pb2
2222from pynumaflow .function .proto import udfunction_pb2_grpc
2323from pynumaflow .types import NumaflowServicerContext
24+ from pynumaflow .info .server import (
25+ get_sdk_version ,
26+ write as info_server_write ,
27+ get_metadata_env ,
28+ )
29+ from pynumaflow .info .types import (
30+ ServerInfo ,
31+ Protocol ,
32+ Language ,
33+ SERVER_INFO_FILE_PATH ,
34+ METADATA_ENVS ,
35+ )
2436
2537_LOGGER = setup_logging (__name__ )
2638if os .getenv ("PYTHONDEBUG" ):
2941UDFMapCallable = Callable [[List [str ], Datum ], Messages ]
3042UDFMapTCallable = Callable [[List [str ], Datum ], MessageTs ]
3143UDFReduceCallable = Callable [[List [str ], AsyncIterable [Datum ], Metadata ], Messages ]
32- _PROCESS_COUNT = int (os .getenv ("NUM_CPU_MULTIPROC" , multiprocessing .cpu_count ()))
33- MAX_THREADS = int (os .getenv ("MAX_THREADS" , 0 )) or (_PROCESS_COUNT * 4 )
3444
3545
3646class MultiProcServer (udfunction_pb2_grpc .UserDefinedFunctionServicer ):
@@ -84,7 +94,6 @@ def __init__(
8494 reduce_handler : UDFReduceCallable = None ,
8595 sock_path = MULTIPROC_FUNCTION_SOCK_PORT ,
8696 max_message_size = MAX_MESSAGE_SIZE ,
87- max_threads = MAX_THREADS ,
8897 ):
8998 if not (map_handler or mapt_handler or reduce_handler ):
9099 raise ValueError ("Require a valid map/mapt handler and/or a valid reduce handler." )
@@ -93,7 +102,6 @@ def __init__(
93102 self .__mapt_handler : UDFMapTCallable = mapt_handler
94103 self .__reduce_handler : UDFReduceCallable = reduce_handler
95104 self ._max_message_size = max_message_size
96- self ._max_threads = max_threads
97105 self .cleanup_coroutines = []
98106 # Collection for storing strong references to all running tasks.
99107 # Event loop only keeps a weak reference, which can cause it to
@@ -107,8 +115,10 @@ def __init__(
107115 ("grpc.so_reuseaddr" , 1 ),
108116 ]
109117 self ._sock_path = sock_path
110- self ._process_count = int (os .getenv ("NUM_CPU_MULTIPROC" , multiprocessing .cpu_count ()))
111- self ._thread_concurrency = MAX_THREADS
118+ self ._process_count = int (
119+ os .getenv ("NUM_CPU_MULTIPROC" ) or os .getenv ("NUMAFLOW_CPU_LIMIT" , 1 )
120+ )
121+ self ._thread_concurrency = int (os .getenv ("MAX_THREADS" , 0 )) or (self ._process_count * 4 )
112122
113123 def MapFn (
114124 self , request : udfunction_pb2 .DatumRequest , context : NumaflowServicerContext
@@ -229,6 +239,16 @@ def _run_server(self, bind_address):
229239 udfunction_pb2_grpc .add_UserDefinedFunctionServicer_to_server (self , server )
230240 server .add_insecure_port (bind_address )
231241 server .start ()
242+ serv_info = ServerInfo (
243+ protocol = Protocol .TCP ,
244+ language = Language .PYTHON ,
245+ version = get_sdk_version (),
246+ metadata = get_metadata_env (envs = METADATA_ENVS ),
247+ )
248+ # Overwrite the CPU_LIMIT metadata using user input
249+ serv_info .metadata ["CPU_LIMIT" ] = str (self ._process_count )
250+ info_server_write (server_info = serv_info , info_file = SERVER_INFO_FILE_PATH )
251+
232252 _LOGGER .info ("GRPC Multi-Processor Server listening on: %s %d" , bind_address , os .getpid ())
233253 server .wait_for_termination ()
234254
0 commit comments