Skip to content

Commit ee753f5

Browse files
authored
feat: introduce handshake to client and gRPC server (#89)
Signed-off-by: Sidhant Kohli <[email protected]>
1 parent 6dcd082 commit ee753f5

File tree

16 files changed

+274
-19
lines changed

16 files changed

+274
-19
lines changed

examples/function/multiproc_map/README.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@ processes to bind to the same port.
1515

1616
To enable multiprocessing mode
1717

18-
1) Set the env flag `MAP_MULTIPROC="true"` for the numa container
19-
20-
2) Start the multiproc server in the UDF using the following command
18+
1) Start the multiproc server in the UDF using the following command
2119
```python
2220
if __name__ == "__main__":
2321
grpc_server = MultiProcServer(map_handler=my_handler)
2422
grpc_server.start()
2523
```
26-
3) Set the ENV var value `NUM_CPU_MULTIPROC="n"` for the UDF and numa container,
27-
to set the value of the number of processes to be created.
24+
2) Set the ENV var value `NUM_CPU_MULTIPROC="n"` for the UDF container,
25+
to set the value of the number of server instances (one for each subprocess) to be created.

examples/function/multiproc_map/pipeline.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@ spec:
3232
env:
3333
- name: NUMAFLOW_DEBUG
3434
value: "true" # DO NOT forget the double quotes!!!
35-
- name: MAP_MULTIPROC
36-
value: "true" # DO NOT forget the double quotes!!!
37-
- name: NUM_CPU_MULTIPROC
38-
value: "2" # DO NOT forget the double quotes!!!
3935

4036
- name: out
4137
sink:

pynumaflow/function/async_server.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from pynumaflow.function.proto import udfunction_pb2
2424
from pynumaflow.function.proto import udfunction_pb2_grpc
2525
from pynumaflow.types import NumaflowServicerContext
26+
from pynumaflow.info.server import get_sdk_version, write as info_server_write
27+
from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH
2628

2729
_LOGGER = setup_logging(__name__)
2830
if os.getenv("PYTHONDEBUG"):
@@ -306,6 +308,12 @@ async def __serve_async(self, server) -> None:
306308
server.add_insecure_port(self.sock_path)
307309
_LOGGER.info("GRPC Async Server listening on: %s", self.sock_path)
308310
await server.start()
311+
serv_info = ServerInfo(
312+
protocol=Protocol.UDS,
313+
language=Language.PYTHON,
314+
version=get_sdk_version(),
315+
)
316+
info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)
309317

310318
async def server_graceful_shutdown():
311319
"""

pynumaflow/function/multiproc_server.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@
2121
from pynumaflow.function.proto import udfunction_pb2
2222
from pynumaflow.function.proto import udfunction_pb2_grpc
2323
from pynumaflow.types import NumaflowServicerContext
24+
from pynumaflow.info.server import (
25+
get_sdk_version,
26+
write as info_server_write,
27+
get_metadata_env,
28+
)
29+
from pynumaflow.info.types import (
30+
ServerInfo,
31+
Protocol,
32+
Language,
33+
SERVER_INFO_FILE_PATH,
34+
METADATA_ENVS,
35+
)
2436

2537
_LOGGER = setup_logging(__name__)
2638
if os.getenv("PYTHONDEBUG"):
@@ -29,8 +41,6 @@
2941
UDFMapCallable = Callable[[List[str], Datum], Messages]
3042
UDFMapTCallable = Callable[[List[str], Datum], MessageTs]
3143
UDFReduceCallable = Callable[[List[str], AsyncIterable[Datum], Metadata], Messages]
32-
_PROCESS_COUNT = int(os.getenv("NUM_CPU_MULTIPROC", multiprocessing.cpu_count()))
33-
MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4)
3444

3545

3646
class MultiProcServer(udfunction_pb2_grpc.UserDefinedFunctionServicer):
@@ -84,7 +94,6 @@ def __init__(
8494
reduce_handler: UDFReduceCallable = None,
8595
sock_path=MULTIPROC_FUNCTION_SOCK_PORT,
8696
max_message_size=MAX_MESSAGE_SIZE,
87-
max_threads=MAX_THREADS,
8897
):
8998
if not (map_handler or mapt_handler or reduce_handler):
9099
raise ValueError("Require a valid map/mapt handler and/or a valid reduce handler.")
@@ -93,7 +102,6 @@ def __init__(
93102
self.__mapt_handler: UDFMapTCallable = mapt_handler
94103
self.__reduce_handler: UDFReduceCallable = reduce_handler
95104
self._max_message_size = max_message_size
96-
self._max_threads = max_threads
97105
self.cleanup_coroutines = []
98106
# Collection for storing strong references to all running tasks.
99107
# Event loop only keeps a weak reference, which can cause it to
@@ -107,8 +115,10 @@ def __init__(
107115
("grpc.so_reuseaddr", 1),
108116
]
109117
self._sock_path = sock_path
110-
self._process_count = int(os.getenv("NUM_CPU_MULTIPROC", multiprocessing.cpu_count()))
111-
self._thread_concurrency = MAX_THREADS
118+
self._process_count = int(
119+
os.getenv("NUM_CPU_MULTIPROC") or os.getenv("NUMAFLOW_CPU_LIMIT", 1)
120+
)
121+
self._thread_concurrency = int(os.getenv("MAX_THREADS", 0)) or (self._process_count * 4)
112122

113123
def MapFn(
114124
self, request: udfunction_pb2.DatumRequest, context: NumaflowServicerContext
@@ -229,6 +239,16 @@ def _run_server(self, bind_address):
229239
udfunction_pb2_grpc.add_UserDefinedFunctionServicer_to_server(self, server)
230240
server.add_insecure_port(bind_address)
231241
server.start()
242+
serv_info = ServerInfo(
243+
protocol=Protocol.TCP,
244+
language=Language.PYTHON,
245+
version=get_sdk_version(),
246+
metadata=get_metadata_env(envs=METADATA_ENVS),
247+
)
248+
# Overwrite the CPU_LIMIT metadata using user input
249+
serv_info.metadata["CPU_LIMIT"] = str(self._process_count)
250+
info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)
251+
232252
_LOGGER.info("GRPC Multi-Processor Server listening on: %s %d", bind_address, os.getpid())
233253
server.wait_for_termination()
234254

pynumaflow/function/server.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import grpc
88
from google.protobuf import empty_pb2 as _empty_pb2
99
from google.protobuf import timestamp_pb2 as _timestamp_pb2
10+
from pynumaflow.info.server import get_sdk_version, write as info_server_write
11+
from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH
1012

1113
from pynumaflow import setup_logging
1214
from pynumaflow._constants import (
@@ -219,6 +221,12 @@ def start(self) -> None:
219221
udfunction_pb2_grpc.add_UserDefinedFunctionServicer_to_server(self, server)
220222
server.add_insecure_port(self.sock_path)
221223
server.start()
224+
serv_info = ServerInfo(
225+
protocol=Protocol.UDS,
226+
language=Language.PYTHON,
227+
version=get_sdk_version(),
228+
)
229+
info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)
222230
_LOGGER.info(
223231
"GRPC Server listening on: %s with max threads: %s", self.sock_path, self._max_threads
224232
)

pynumaflow/info/__init__.py

Whitespace-only changes.

pynumaflow/info/server.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import os
2+
from importlib.metadata import version
3+
from typing import Any
4+
5+
from pynumaflow import setup_logging
6+
from pynumaflow.info.types import ServerInfo, EOF
7+
import json
8+
import logging
9+
10+
_LOGGER = setup_logging(__name__)
11+
if os.getenv("PYTHONDEBUG"):
12+
_LOGGER.setLevel(logging.DEBUG)
13+
14+
15+
def get_sdk_version() -> str:
16+
"""
17+
Return the pynumaflow SDK version
18+
"""
19+
try:
20+
return version("pynumaflow")
21+
except Exception as e:
22+
# Adding this to handle the case for local test/CI where pynumaflow
23+
# will not be installed as a package
24+
_LOGGER.error("Could not read SDK version %r", e, exc_info=True)
25+
return ""
26+
27+
28+
def write(server_info: ServerInfo, info_file: str):
29+
"""
30+
Write the ServerInfo to a file , shared with the client (numa container).
31+
32+
args:
33+
serv: The ServerInfo object to be shared
34+
info_file: the shared file path
35+
"""
36+
try:
37+
data = server_info.__dict__
38+
with open(info_file, "w+") as f:
39+
json.dump(data, f, ensure_ascii=False)
40+
f.write(EOF)
41+
except Exception as err:
42+
_LOGGER.critical("Could not write data to Info-Server %r", err, exc_info=True)
43+
raise err
44+
45+
46+
def get_metadata_env(envs: list[tuple[str, str]]) -> dict[str, Any]:
47+
"""
48+
Extract the environment var value from the provided list,
49+
and assign them to the given key in the metadata
50+
51+
args:
52+
envs: List of tuples (key, env_var)
53+
"""
54+
meta = {}
55+
for key, val in envs:
56+
res = os.getenv(val, None)
57+
if res:
58+
meta[key] = res
59+
return meta

pynumaflow/info/types.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from dataclasses import dataclass, field
2+
from enum import Enum
3+
4+
# Constants for using in the info-server
5+
# Need to keep consistent with all SDKs and client
6+
SERVER_INFO_FILE_PATH = "/var/run/numaflow/server-info"
7+
EOF = "U+005C__END__"
8+
9+
# Env variables to be passed in the info server metadata.
10+
# These need to be accessed in the client using the same key.
11+
# Format - (key, env_var)
12+
METADATA_ENVS = [("CPU_LIMIT", "NUMAFLOW_CPU_LIMIT")]
13+
14+
15+
class Protocol(str, Enum):
16+
"""
17+
Enumerate grpc server connection protocol.
18+
"""
19+
20+
UDS = "uds"
21+
TCP = "tcp"
22+
23+
24+
class Language(str, Enum):
25+
"""
26+
Enumerate Numaflow SDK language.
27+
"""
28+
29+
GO = "go"
30+
PYTHON = "python"
31+
JAVA = "java"
32+
33+
34+
@dataclass
35+
class ServerInfo:
36+
"""
37+
ServerInfo is used for the gRPC server to provide the information such as protocol,
38+
sdk version, language, metadata to the client.
39+
Args:
40+
protocol: Protocol to use (UDS or TCP)
41+
language: Language used by the server(Python, Golang, Java)
42+
version: Numaflow sdk version used by the server
43+
metadata: Any additional information to be provided (env vars)
44+
"""
45+
46+
protocol: Protocol
47+
language: Language
48+
version: str
49+
metadata: dict = field(default_factory=dict)

pynumaflow/sink/async_sink.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
SINK_SOCK_PATH,
1212
MAX_MESSAGE_SIZE,
1313
)
14+
from pynumaflow.info.server import get_sdk_version, write as info_server_write
15+
from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH
1416
from pynumaflow.sink import Responses, Datum, Response
1517
from pynumaflow.sink.proto import udsink_pb2_grpc, udsink_pb2
1618
from pynumaflow.types import NumaflowServicerContext
@@ -128,6 +130,12 @@ async def __serve_async(self, server) -> None:
128130
server.add_insecure_port(self.sock_path)
129131
_LOGGER.info("GRPC Async Server listening on: %s", self.sock_path)
130132
await server.start()
133+
serv_info = ServerInfo(
134+
protocol=Protocol.UDS,
135+
language=Language.PYTHON,
136+
version=get_sdk_version(),
137+
)
138+
info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)
131139

132140
async def server_graceful_shutdown():
133141
_LOGGER.info("Starting graceful shutdown...")

pynumaflow/sink/server.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
SINK_SOCK_PATH,
1313
MAX_MESSAGE_SIZE,
1414
)
15+
from pynumaflow.info.server import get_sdk_version, write as info_server_write
16+
from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH
1517
from pynumaflow.sink import Responses, Datum, Response
1618
from pynumaflow.sink.proto import udsink_pb2_grpc, udsink_pb2
1719
from pynumaflow.types import NumaflowServicerContext
@@ -124,6 +126,13 @@ def start(self) -> None:
124126
udsink_pb2_grpc.add_UserDefinedSinkServicer_to_server(Sink(self.__sink_handler), server)
125127
server.add_insecure_port(self.sock_path)
126128
server.start()
129+
serv_info = ServerInfo(
130+
protocol=Protocol.UDS,
131+
language=Language.PYTHON,
132+
version=get_sdk_version(),
133+
)
134+
info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)
135+
127136
_LOGGER.info(
128137
"GRPC Server listening on: %s with max threads: %s", self.sock_path, self._max_threads
129138
)

0 commit comments

Comments
 (0)