Skip to content

Commit 1b71a8c

Browse files
committed
support metadata in source
Signed-off-by: Sreekanth <[email protected]>
1 parent a2a28c7 commit 1b71a8c

File tree

9 files changed

+143
-35
lines changed

9 files changed

+143
-35
lines changed

packages/pynumaflow/pynumaflow/_metadata.py

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from collections import defaultdict
2+
from typing import Optional
33
from pynumaflow.proto.common import metadata_pb2
44

55
"""
@@ -46,13 +46,13 @@ def keys(self, group: str) -> list[str]:
4646
"""
4747
Returns the list of keys for a given group.
4848
"""
49-
return list(self._data[group].keys())
49+
return list(self._data.get(group, {}).keys())
5050

51-
def value(self, group: str, key: str) -> bytes:
51+
def value(self, group: str, key: str) -> Optional[bytes]:
5252
"""
5353
Returns the value for a given group and key.
5454
"""
55-
return self._data[group][key]
55+
return self._data.get(group, {}).get(key)
5656

5757

5858
@dataclass
@@ -61,49 +61,95 @@ class UserMetadata:
6161
UserMetadata wraps the user-generated metadata groups per message. It is read-write to UDFs.
6262
"""
6363

64-
_data: defaultdict[str, dict[str, bytes]] = field(default_factory=lambda: defaultdict(dict))
64+
_data: dict[str, dict[str, bytes]] = field(default_factory=dict)
6565

6666
def groups(self) -> list[str]:
6767
"""
6868
Returns the list of group names for the user metadata.
6969
"""
7070
return list(self._data.keys())
7171

72-
def keys(self, group: str) -> list[str]:
72+
def keys(self, group: str) -> Optional[list[str]]:
7373
"""
7474
Returns the list of keys for a given group.
7575
"""
76-
return list(self._data[group].keys())
76+
keys = self._data.get(group)
77+
if keys is None:
78+
return None
79+
return list(keys.keys())
7780

78-
def value(self, group: str, key: str) -> bytes:
81+
def __contains__(self, group: str) -> bool:
7982
"""
80-
Returns the value for a given group and key.
83+
Returns True if the group exists.
84+
"""
85+
return group in self._data
86+
87+
def __getitem__(self, group: str) -> dict[str, bytes]:
88+
"""
89+
Returns the data for a given group.
90+
Raises KeyError if the group does not exist.
91+
"""
92+
return self._data[group]
93+
94+
def __setitem__(self, group: str, data: dict[str, bytes]):
95+
"""
96+
Sets the data for a given group.
97+
"""
98+
self._data[group] = data
99+
100+
def __delitem__(self, group: str):
101+
"""
102+
Removes the group and all its keys and values.
103+
Raises KeyError if the group does not exist.
104+
"""
105+
del self._data[group]
106+
107+
def __len__(self) -> int:
81108
"""
82-
return self._data[group][key]
109+
Returns the number of groups.
110+
"""
111+
return len(self._data)
112+
113+
def value(self, group: str, key: str) -> Optional[bytes]:
114+
"""
115+
Returns the value for a given group and key. If the group or key does not exist, returns None.
116+
"""
117+
value = self._data.get(group)
118+
if value is None:
119+
return None
120+
return value.get(key)
83121

84122
def add(self, group: str, key: str, value: bytes):
85123
"""
86124
Adds the value for a given group and key.
87125
"""
88-
self._data[group][key] = value
126+
self._data.setdefault(group, {})[key] = value
89127

90128
def set_group(self, group: str, data: dict[str, bytes]):
91129
"""
92130
Sets the data for a given group.
93131
"""
94132
self._data[group] = data
95133

96-
def remove(self, group: str, key: str):
134+
def remove(self, group: str, key: str) -> Optional[bytes]:
97135
"""
98-
Removes the key and its value for a given group.
136+
Removes the key and its value for a given group and returns the value. If this key is the only key in the group, the group will be removed.
137+
Returns None if the group or key does not exist.
99138
"""
100-
del self._data[group][key]
139+
group_data = self._data.pop(group, None)
140+
if group_data is None:
141+
return None
142+
value = group_data.pop(key, None)
143+
if group_data:
144+
self._data[group] = group_data
145+
return value
101146

102-
def remove_group(self, group: str):
147+
def remove_group(self, group: str) -> Optional[dict[str, bytes]]:
103148
"""
104-
Removes the group and all its keys and values.
149+
Removes the group and all its keys and values and returns the data.
150+
Returns None if the group does not exist.
105151
"""
106-
del self._data[group]
152+
return self._data.pop(group, None)
107153

108154
def clear(self):
109155
"""
Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
import asyncio
2+
from typing import Generic, TypeVar
3+
from collections.abc import AsyncIterator
24

35
from pynumaflow._constants import STREAM_EOF
46

7+
T = TypeVar("T")
58

6-
class NonBlockingIterator:
9+
10+
class NonBlockingIterator(Generic[T]):
711
"""An Async Interator backed by a queue"""
812

913
__slots__ = "_queue"
1014

11-
def __init__(self, size=0):
12-
self._queue = asyncio.Queue(maxsize=size)
15+
def __init__(self, size: int = 0) -> None:
16+
self._queue: asyncio.Queue[T] = asyncio.Queue(maxsize=size)
1317

14-
async def read_iterator(self):
18+
async def read_iterator(self) -> AsyncIterator[T]:
1519
item = await self._queue.get()
1620
while True:
1721
if item == STREAM_EOF:
1822
break
1923
yield item
2024
item = await self._queue.get()
2125

22-
async def put(self, item):
26+
async def put(self, item: T) -> None:
2327
await self._queue.put(item)

packages/pynumaflow/pynumaflow/sinker/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22

33
from pynumaflow.sinker.server import SinkServer
44

5+
from pynumaflow._metadata import UserMetadata, SystemMetadata
56
from pynumaflow.sinker._dtypes import Response, Responses, Datum, Sinker
67

7-
__all__ = ["Response", "Responses", "Datum", "Sinker", "SinkAsyncServer", "SinkServer"]
8+
__all__ = [
9+
"Response",
10+
"Responses",
11+
"Datum",
12+
"Sinker",
13+
"SinkAsyncServer",
14+
"SinkServer",
15+
"UserMetadata",
16+
"SystemMetadata",
17+
]

packages/pynumaflow/pynumaflow/sourcer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Sourcer,
1111
SourceCallable,
1212
)
13+
from pynumaflow._metadata import UserMetadata
1314
from pynumaflow.sourcer.async_server import SourceAsyncServer
1415

1516
__all__ = [
@@ -24,4 +25,5 @@
2425
"Sourcer",
2526
"SourceAsyncServer",
2627
"SourceCallable",
28+
"UserMetadata",
2729
]

packages/pynumaflow/pynumaflow/sourcer/_dtypes.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datetime import datetime
66
from typing import Callable, Optional
77

8+
from pynumaflow._metadata import UserMetadata
89
from pynumaflow.shared.asynciter import NonBlockingIterator
910

1011

@@ -56,15 +57,17 @@ class Message:
5657
event_time: event time of the message, usually extracted from the payload.
5758
keys: []string keys for vertex (optional)
5859
headers: dict of headers for the message (optional)
60+
user_metadata: metadata for the message (optional)
5961
"""
6062

61-
__slots__ = ("_payload", "_offset", "_event_time", "_keys", "_headers")
63+
__slots__ = ("_payload", "_offset", "_event_time", "_keys", "_headers", "_user_metadata")
6264

6365
_payload: bytes
6466
_offset: Offset
6567
_event_time: datetime
6668
_keys: list[str]
6769
_headers: dict[str, str]
70+
_user_metadata: UserMetadata
6871

6972
def __init__(
7073
self,
@@ -73,6 +76,7 @@ def __init__(
7376
event_time: datetime,
7477
keys: list[str] = None,
7578
headers: Optional[dict[str, str]] = None,
79+
user_metadata: Optional[UserMetadata] = None,
7680
):
7781
"""
7882
Creates a Message object to send value to a vertex.
@@ -82,6 +86,7 @@ def __init__(
8286
self._event_time = event_time
8387
self._keys = keys or []
8488
self._headers = headers or {}
89+
self._user_metadata = user_metadata or UserMetadata()
8590

8691
@property
8792
def payload(self) -> bytes:
@@ -103,6 +108,11 @@ def event_time(self) -> datetime:
103108
def headers(self) -> dict[str, str]:
104109
return self._headers
105110

111+
@property
112+
def user_metadata(self) -> UserMetadata:
113+
"""Returns the user metadata of the message."""
114+
return self._user_metadata
115+
106116

107117
@dataclass(init=False)
108118
class ReadRequest:

packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
from collections.abc import AsyncIterator
3+
from typing import Union
34

45
from google.protobuf import timestamp_pb2 as _timestamp_pb2
56
from google.protobuf import empty_pb2 as _empty_pb2
@@ -9,6 +10,7 @@
910
from pynumaflow.sourcer import ReadRequest, Offset, NackRequest, AckRequest, SourceCallable
1011
from pynumaflow.proto.sourcer import source_pb2
1112
from pynumaflow.proto.sourcer import source_pb2_grpc
13+
from pynumaflow.sourcer._dtypes import Message
1214
from pynumaflow.types import NumaflowServicerContext
1315
from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING
1416

@@ -31,7 +33,7 @@ def _create_ack_handshake_response():
3133
)
3234

3335

34-
def _create_read_response(response):
36+
def _create_read_response(response: Message):
3537
"""Create a read response from the handler result."""
3638
event_time_timestamp = _timestamp_pb2.Timestamp()
3739
event_time_timestamp.FromDatetime(dt=response.event_time)
@@ -41,6 +43,7 @@ def _create_read_response(response):
4143
offset=response.offset.as_dict,
4244
event_time=event_time_timestamp,
4345
headers=response.headers,
46+
metadata=response.user_metadata._to_proto(),
4447
)
4548
status = source_pb2.ReadResponse.Status(eot=False, code=source_pb2.ReadResponse.Status.SUCCESS)
4649
return source_pb2.ReadResponse(result=result, status=status)
@@ -98,7 +101,7 @@ async def ReadFn(
98101
async for req in request_iterator:
99102
# create an iterator to be provided to the user function where the responses will
100103
# be streamed
101-
niter = NonBlockingIterator()
104+
niter: NonBlockingIterator[Union[Message, Exception]] = NonBlockingIterator()
102105
riter = niter.read_iterator()
103106
task = asyncio.create_task(self.__invoke_read(req, niter))
104107
# Save a reference to the result of this function, to avoid a
@@ -121,7 +124,9 @@ async def ReadFn(
121124
_LOGGER.critical("User-Defined Source ReadFn error", exc_info=True)
122125
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
123126

124-
async def __invoke_read(self, req, niter):
127+
async def __invoke_read(
128+
self, req: source_pb2.ReadRequest, niter: NonBlockingIterator[Union[Message, Exception]]
129+
):
125130
"""Invoke the read handler and manage the iterator."""
126131
try:
127132
await self.__source_read_handler(

packages/pynumaflow/tests/source/test_async_source.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import asyncio
2+
from collections.abc import Iterator
23
import logging
34
import threading
45
import unittest
56

67
import grpc
78
from google.protobuf import empty_pb2 as _empty_pb2
8-
from grpc.aio._server import Server
9+
from grpc.aio import Server
910

1011
from pynumaflow import setup_logging
12+
from pynumaflow._metadata import _user_and_system_metadata_from_proto
1113
from pynumaflow.proto.sourcer import source_pb2_grpc, source_pb2
1214
from pynumaflow.sourcer import (
1315
SourceAsyncServer,
@@ -100,13 +102,13 @@ def test_read_source(self) -> None:
100102
stub = source_pb2_grpc.SourceStub(channel)
101103

102104
request = read_req_source_fn()
103-
generator_response = None
104105
try:
105-
generator_response = stub.ReadFn(
106+
generator_response: Iterator[source_pb2.ReadResponse] = stub.ReadFn(
106107
request_iterator=request_generator(1, request, "read")
107108
)
108109
except grpc.RpcError as e:
109110
logging.error(e)
111+
raise
110112

111113
counter = 0
112114
first = True
@@ -139,6 +141,20 @@ def test_read_source(self) -> None:
139141
r.result.offset.partition_id,
140142
)
141143

144+
print(r.result)
145+
(user_metadata, sys_metadata) = _user_and_system_metadata_from_proto(
146+
r.result.metadata
147+
)
148+
print(user_metadata)
149+
150+
self.assertCountEqual(user_metadata.groups(), ["custom_info", "test_info"])
151+
self.assertCountEqual(
152+
user_metadata.keys("custom_info"), ["custom_key", "custom_key2"]
153+
)
154+
self.assertIsNone(user_metadata.value("custom_info", "test_key"))
155+
self.assertEqual(user_metadata.value("custom_info", "custom_key"), b"custom_value")
156+
self.assertEqual(user_metadata.value("test_info", "test_key"), b"test_value")
157+
142158
self.assertFalse(first)
143159
self.assertTrue(last)
144160

packages/pynumaflow/tests/source/test_async_source_err.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import grpc
88

9-
from grpc.aio._server import Server
9+
from grpc.aio import Server
1010

1111
from pynumaflow import setup_logging
1212
from pynumaflow.proto.sourcer import source_pb2_grpc

0 commit comments

Comments
 (0)