Commit dbeb411

better bulk data tests
1 parent f3880a9 commit dbeb411

6 files changed (+238 −68 lines)

requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,5 @@
 cachetools==5.3.3
-certifi==2024.6.2
+certifi==2024.7.4
 chardet==5.2.0
 charset-normalizer==3.3.2
 colorama==0.4.6
@@ -11,6 +11,7 @@ idna==3.7
 iniconfig==2.0.0
 jsonpath-ng==1.6.1
 leb128==1.0.8
+numpy~=2.1.2
 orjson==3.10.7
 packaging==24.1
 pandas==2.2.3
@@ -30,3 +31,4 @@ tox==4.15.1
 tzdata==2024.2
 urllib3==2.2.2
 virtualenv==20.26.2
+

src/module-api/src/bulk_data.py

Lines changed: 41 additions & 4 deletions
@@ -5,16 +5,38 @@
 
 @dataclass
 class BulkData:
+    # BulkData: a list of byte strings
     data: list[bytes]
 
     def __init__(self, data: list[bytes]):
         self.data = data
 
     @staticmethod
     def is_serialized_bulk_data(serialized: bytes) -> bool:
+        """
+        Check whether the byte array is a serialized BulkData.
+
+        A BulkData is serialized as a two-byte header (0x87 0x87) followed by
+        the LEB128-encoded number of items in the list; each item is then
+        encoded as its LEB128-encoded length followed by the item bytes.
+
+        :param serialized: the serialized byte array
+        :return: True if the byte array is a serialized BulkData, False otherwise
+        """
         return len(serialized) > 2 and serialized[0] == 0x87 and serialized[1] == 0x87
 
     def serialize(self) -> bytes:
+        """
+        Serialize the BulkData to a byte array.
+
+        The serialized form is a two-byte header (0x87 0x87) followed by the
+        LEB128-encoded number of items in the list; each item is then encoded
+        as its LEB128-encoded length followed by the item bytes.
+
+        :return: the serialized byte array
+        """
         result = io.BytesIO()
         # Write the header
         result.write(bytes([0x87, 0x87]))
@@ -28,14 +50,29 @@ def serialize(self) -> bytes:
 
     @classmethod
     def deserialize(cls, serialized: bytes) -> 'BulkData':
+        """
+        Deserialize a serialized byte array into a BulkData object.
+
+        The serialized byte array is expected to be a two-byte header
+        (0x87 0x87) followed by the LEB128-encoded number of items; each item
+        is encoded as its LEB128-encoded length followed by the item bytes.
+
+        :param serialized: the serialized byte array
+        :return: the deserialized BulkData object
+        """
         data = io.BytesIO(serialized)
         # read the first two bytes 0x87
-        _header = data.read(2)
+        header = data.read(2)
+        # assert the first two bytes are 0x87
+        assert header == bytes([0x87, 0x87])
         # read the number of items
         num_items, _num_bytes = u.decode_reader(data)
         # Preallocate the list for the result
-        result = [None] * num_items
-        for i in range(num_items):
+        result: list[bytes] = [b'' for _ in range(num_items)]
+        for i in range(num_items):
            item_length, _ = u.decode_reader(data)
-            result[i] = data.read(item_length)
+            b = data.read(item_length)
+            assert len(b) == item_length, f"Item {i} has length {len(b)}, expected {item_length} bytes"
+            result[i] = b
         return cls(result)
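
For context, a minimal round-trip sketch of the wire format exercised above, with hypothetical byte values; it assumes bulk_data.py is importable and that u is the unsigned LEB128 codec from the leb128 package pinned in requirements.txt:

from bulk_data import BulkData

items = [b'\x01\x02\x03', b'\x04\x05\x06', b'\x07' * 10]
blob = BulkData(items).serialize()

# Two-byte magic header, then the LEB128-encoded item count (3 fits in one byte)
assert blob[:2] == b'\x87\x87'
assert blob[2] == 0x03
assert BulkData.is_serialized_bulk_data(blob)

# deserialize() now checks the header and every item length before returning
assert BulkData.deserialize(blob).data == items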

src/module-api/src/kmip_post.py

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 import json
 import requests
-# import httpx
 from client_configuration import ClientConfiguration
 import logging
 

src/module-api/src/lru_cache.py

Lines changed: 69 additions & 6 deletions
@@ -14,6 +14,20 @@
 
 
 def key_hash(key: bytes | list[bytes]) -> int:
+    """
+    Compute the hash of the key.
+
+    This function computes the hash of the key using the xxh64 algorithm.
+    If the key is a list of byte strings, the hash is the digest of their
+    concatenation; if the key is a single bytes object, it is the digest
+    of the key bytes.
+
+    Args:
+        key: the key to hash
+
+    Returns:
+        the hash of the key
+    """
     h = xxh64()
     if isinstance(key, list):
         for k in key:
@@ -22,39 +36,88 @@ def key_hash(key: bytes | list[bytes]) -> int:
         h.update(key)
     return h.intdigest()
 
-
+###
+# The LRUCache is a least recently used cache. It is used to store the results
+# of the encrypt and decrypt operations in the KMS proxy. The cache is
+# implemented as a dictionary with a limited size (the capacity) and is
+# protected by a lock to prevent concurrent access from multiple threads.
+# The least recently used entry is evicted when the cache reaches capacity.
+###
 class LRUCache:
 
     def __init__(self, capacity):
-        self.cache = dict()
         self.capacity = capacity
-        self.access = deque()
+        self.cache = dict()
+        # The access list tracks the order of access of the cache entries.
+        # The most recently accessed entry is at the end of the list.
+        self.access = deque(maxlen=capacity)
         self.lock = threading.Lock()
 
     def get(self, key: bytes | list[bytes]) -> bytes | None:
+        """
+        Get the value associated with the key from the cache.
+
+        Args:
+            key: the key to look up
+
+        Returns:
+            the value associated with the key if the key is in the cache,
+            None otherwise
+        """
         key = key_hash(key)
         if key not in self.cache:
+            # The key is not in the cache, return None
             return None
         else:
-            # small race condition here with the test on self.cache
-            # but we do not want to delay self.cache
+            # The key is in the cache: move it to the end of the access
+            # list so it becomes the most recently accessed entry
             with self.lock:
                 if self.access[-1] != key:
+                    # The key is not already the most recently accessed:
+                    # remove it from the list and append it at the end
                    self.access.remove(key)
                    self.access.append(key)
+            # Return the value associated with the key
             return self.cache[key]
 
     def put(self, key: bytes | list[bytes], value: bytes):
+        """
+        Put a key/value pair in the cache.
+
+        Args:
+            key: the key to put in the cache
+            value: the value to associate with the key
+
+        Notes:
+            When the cache reaches its capacity, the least recently used
+            entry is removed from the cache.
+        """
         key = key_hash(key)
         with self.lock:
+            # If the key is already in the cache, remove it from the access list
             if key in self.cache:
                 self.access.remove(key)
+            # If the cache is full, remove the least recently used entry
             elif len(self.cache) == self.capacity:
                 oldest = self.access.popleft()
                 del self.cache[oldest]
+            # Store the key/value pair and append the key to the end of the
+            # access list
             self.cache[key] = value
             self.access.append(key)
 
     def print(self):
-        for key in self.access:
+        """
+        Print the content of the cache.
+
+        This method is useful for debugging. It prints the cache as a
+        sequence of key/value pairs, with the most recently accessed entry
+        first.
+        """
+        # Iterate over the access list in reverse order so the most recently
+        # accessed entry is printed first
+        for key in reversed(self.access):
             print(f"{key}: {self.cache[key]}")

tests/bulk_data_test.py

Lines changed: 73 additions & 35 deletions
@@ -1,9 +1,10 @@
 import numpy as np
 import numpy.testing as npt
-from bulk_data import BulkData
 import logging
 import time
 import random
+import unittest
+from bulk_data import BulkData
 
 logger = logging.getLogger(__name__)
 slog = logging.LoggerAdapter(logger, {
@@ -15,22 +16,6 @@
 })
 
 
-def test_bulk_data_test_vector():
-    data = np.array([
-        bytes([0x01, 0x02, 0x03]),
-        bytes([0x04, 0x05, 0x06]),
-        bytes([0x07] * 10)
-    ])
-    bulk_data = BulkData(data)
-    serialized = bulk_data.serialize()
-    assert list(serialized) == [
-        0x87, 0x87, 0x03, 0x03, 0x01, 0x02, 0x03, 0x03, 0x04, 0x05, 0x06, 0x0A, 0x07, 0x07,
-        0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07
-    ]
-    deserialized = BulkData.deserialize(serialized)
-    npt.assert_array_equal(data, deserialized.data)
-
-
 def benchmark_bulk_data(bulk_data) -> BulkData:
     t_start = time.perf_counter()
     serialized = bulk_data.serialize()
@@ -48,23 +33,76 @@ def random_bytes() -> bytes:
     return bytes(random.getrandbits(8) for _ in range(64))
 
 
-def test_bulk_data_benchmark():
-    num_samples = 5000000
-    slog.info(f"Testing performance with bulk data of {num_samples} samples")
-    t_start = time.perf_counter()
-    data = np.array([
-        random.randbytes(64) for _ in range(num_samples)
-    ], dtype=np.object_)
-    # check all samples have 64 bytes
-    for item in data:
-        assert len(item) == 64
-    bulk_data = BulkData(data)
-    t_generate = time.perf_counter() - t_start
-    slog.info(f"Generate: {t_generate}s")
+class TestBulkDataDeserialize(unittest.TestCase):
+    def test_valid_serialization_one_item(self):
+        serialized = b'\x87\x87\x01\x03abc'
+        expected = BulkData([b'abc'])
+        self.assertEqual(BulkData.deserialize(serialized), expected)
+
+    def test_valid_serialization_multiple_items(self):
+        serialized = b'\x87\x87\x02\x03abc\x03def'
+        expected = BulkData([b'abc', b'def'])
+        self.assertEqual(BulkData.deserialize(serialized), expected)
+
+    def test_invalid_serialization_incorrect_header(self):
+        serialized = b'\x88\x87\x01\x03abc'
+        with self.assertRaises(AssertionError):
+            BulkData.deserialize(serialized)
+
+    def test_invalid_serialization_incorrect_item_length(self):
+        serialized = b'\x87\x87\x01\x04abc'
+        with self.assertRaises(AssertionError):
+            BulkData.deserialize(serialized)
+
+    def test_invalid_serialization_truncated_data(self):
+        serialized = b'\x87\x87\x01\x03ab'
+        with self.assertRaises(AssertionError):
+            BulkData.deserialize(serialized)
+
+    def test_invalid_serialization_empty_data(self):
+        serialized = b''
+        with self.assertRaises(AssertionError):
+            BulkData.deserialize(serialized)
+
+    def test_bulk_data_test_vector(self):
+        data = np.array([
+            bytes([0x01, 0x02, 0x03]),
+            bytes([0x04, 0x05, 0x06]),
+            bytes([0x07] * 10)
+        ])
+        bulk_data = BulkData(data.tolist())
+        serialized = bulk_data.serialize()
+        assert list(serialized) == [
+            0x87, 0x87, 0x03, 0x03, 0x01, 0x02, 0x03, 0x03, 0x04, 0x05, 0x06, 0x0A, 0x07, 0x07,
+            0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07
+        ]
+        deserialized = BulkData.deserialize(serialized)
+        npt.assert_array_equal(data, deserialized.data)
+
+    def test_bulk_data_benchmark(self):
+        num_samples = 1000000
+        slog.info(f"Testing performance with bulk data of {num_samples} samples")
+        t_start = time.perf_counter()
+        data = np.array([
+            random.randbytes(64) for _ in range(num_samples)
+        ], dtype=np.object_)
+        # check all samples have 64 bytes
+        for item in data:
+            assert len(item) == 64
+        bulk_data = BulkData(data.tolist())
+        t_generate = time.perf_counter() - t_start
+        slog.info(f"Generate: {t_generate}s")
+        # serialize+deserialize
+        t_start = time.perf_counter()
+        recovered = benchmark_bulk_data(bulk_data)
+        t_all = time.perf_counter() - t_start
+        slog.info(f"serialize+deserialize: {t_all}s, i.e. {t_all / num_samples * 1000000:.6f}µs per item")
+        self.assertEqual(len(bulk_data.data), len(recovered.data))
+        # sample 100 random items from both lists and check they are equal
+        for _ in range(100):
+            i = random.randint(0, len(bulk_data.data) - 1)
+            assert np.array_equal(bulk_data.data[i], recovered.data[i])
 
-    t_start = time.perf_counter()
-    recovered = benchmark_bulk_data(bulk_data)
-    t_all = time.perf_counter() - t_start
-    slog.info(f"serialize+deserialize: {t_all}s, i.e. {t_all / num_samples * 1000000:.6f}µs per item")
 
-    assert np.array_equal(bulk_data.data, recovered.data)
+if __name__ == '__main__':
+    unittest.main()
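
One design note on these tests: the new validation in deserialize relies on assert statements, which Python removes when run with -O, so the invalid-input tests only hold in non-optimized runs. For callers that want a catchable, deliberate error type, a hedged sketch of a wrapper (safe_deserialize is a hypothetical helper, not part of this commit):

from bulk_data import BulkData

def safe_deserialize(serialized: bytes) -> BulkData:
    # Hypothetical helper: convert the AssertionError raised by
    # BulkData.deserialize into a ValueError that callers can catch.
    if not BulkData.is_serialized_bulk_data(serialized):
        raise ValueError("not a serialized BulkData")
    try:
        return BulkData.deserialize(serialized)
    except AssertionError as exc:
        raise ValueError(f"corrupt BulkData payload: {exc}") from exc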
