Skip to content

Commit be64693

Browse files
committed
Implementation and tests for DocumentChunk
1 parent 761b9fc commit be64693

File tree

3 files changed

+135
-5
lines changed

3 files changed

+135
-5
lines changed

src/coherence/ai.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66

77
import base64
88
from abc import ABC
9-
from typing import Optional
9+
from collections import OrderedDict
10+
from typing import Any, Dict, Optional, cast
1011

11-
from coherence.serialization import proxy
12+
import jsonpickle
13+
14+
from coherence.serialization import JavaProxyUnpickler, proxy
1215

1316

1417
class Vector(ABC):
@@ -54,3 +57,73 @@ class FloatVector(Vector):
5457
def __init__(self, float_array: list[float]):
5558
super().__init__()
5659
self.array = float_array
60+
61+
62+
class AbstractEvolvable(ABC):
    """Base class holding the ``dataVersion`` and ``binFuture`` attributes.

    :param data_version: value stored as ``dataVersion`` (defaults to ``0``)
    :param bin_future: value stored as ``binFuture`` (defaults to ``None``)
    """

    def __init__(self, data_version: int = 0, bin_future: Optional[Any] = None):
        # Attribute names are camelCase; DocumentChunkHandler.flatten emits
        # them under the same names as JSON keys.
        self.dataVersion, self.binFuture = data_version, bin_future
66+
67+
68+
@proxy("ai.DocumentChunk")
class DocumentChunk(AbstractEvolvable):
    """A piece of text with optional metadata and an optional vector.

    :param text: the text of this chunk
    :param metadata: mapping of metadata for the chunk; when ``None`` an
        empty ``OrderedDict`` is created for this instance
    :param vector: an optional ``Vector`` associated with the chunk
    """

    def __init__(
        self,
        text: str,
        metadata: Optional[dict[str, Any] | OrderedDict[str, Any]] = None,
        vector: Optional[Vector] = None,
    ):
        super().__init__()
        self.text = text
        # A fresh OrderedDict per instance, so no mapping is shared between chunks.
        self.metadata: Dict[str, Any] = OrderedDict() if metadata is None else metadata
        self.vector = vector
83+
84+
85+
@jsonpickle.handlers.register(DocumentChunk)
class DocumentChunkHandler(jsonpickle.handlers.BaseHandler):
    """jsonpickle handler converting ``DocumentChunk`` objects to and from dicts."""

    def flatten(self, obj: object, data: dict[str, Any]) -> dict[str, Any]:
        """Build the dict representation of a ``DocumentChunk``.

        Keys are inserted in the order ``@class``, ``dataVersion``,
        ``binFuture``, ``metadata``, ``vector``, ``text``; optional keys are
        omitted when the corresponding attribute is missing or ``None``.

        :param obj: the ``DocumentChunk`` to flatten
        :param data: dict supplied by jsonpickle; not used, a new dict is returned
        :return: the newly built dict
        """
        chunk: DocumentChunk = cast(DocumentChunk, obj)
        flat: Dict[Any, Any] = {
            "@class": "ai.DocumentChunk",
            "dataVersion": chunk.dataVersion,
        }

        bin_future = getattr(chunk, "binFuture", None)
        if bin_future is not None:
            flat["binFuture"] = bin_future

        metadata = getattr(chunk, "metadata", None)
        if metadata is not None:
            meta_json: Dict[Any, Any] = {}
            # Only an OrderedDict is flagged as ordered; a plain dict yields
            # just the "entries" list.
            if isinstance(metadata, OrderedDict):
                meta_json["@ordered"] = True
            meta_json["entries"] = [
                {"key": key, "value": value} for key, value in metadata.items()
            ]
            flat["metadata"] = meta_json

        vector = getattr(chunk, "vector", None)
        if vector is not None:
            # (vector type, "@class" value, attribute holding the payload);
            # checked in this order, first match wins. A vector of any other
            # type produces no "vector" key.
            for vector_type, class_name, payload_attr in (
                (BitVector, "ai.BitVector", "bits"),
                (ByteVector, "ai.Int8Vector", "array"),
                (FloatVector, "ai.Float32Vector", "array"),
            ):
                if isinstance(vector, vector_type):
                    flat["vector"] = {
                        "@class": class_name,
                        payload_attr: getattr(vector, payload_attr),
                    }
                    break

        # "text" is intentionally written last to keep the emitted key order.
        flat["text"] = chunk.text
        return flat

    def restore(self, obj: dict[str, Any]) -> DocumentChunk:
        """Restore a ``DocumentChunk`` from its dict representation.

        A ``DocumentChunk`` with empty text is created and handed, together
        with ``obj``, to ``JavaProxyUnpickler._restore_from_dict``.

        :param obj: the dict produced during deserialization
        :return: whatever ``_restore_from_dict`` returns for the new chunk
        """
        return JavaProxyUnpickler()._restore_from_dict(obj, DocumentChunk(""))

src/coherence/serialization.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from __future__ import annotations
66

7+
import collections
78
from abc import ABC, abstractmethod
89
from decimal import Decimal
910
from typing import Any, Callable, Dict, Final, Optional, Type, TypeVar, cast
@@ -18,6 +19,7 @@
1819
_META_CLASS: Final[str] = "@class"
1920
_META_VERSION: Final[str] = "@version"
2021
_META_ENUM: Final[str] = "enum"
22+
_META_ORDERED: Final[str] = "@ordered"
2123

2224
_JSON_KEY = "key"
2325
_JSON_VALUE = "value"
@@ -205,10 +207,10 @@ class JavaProxyUnpickler(jsonpickle.Unpickler):
205207
# noinspection PyUnresolvedReferences
206208
def _restore(self, obj: Any) -> Any:
207209
if isinstance(obj, dict):
208-
metadata: str = obj.get(_META_CLASS, None)
210+
metadata: Any = obj.get(_META_CLASS, None)
209211
if metadata is not None:
210212
type_: Optional[Type[Any]] = _type_for(metadata)
211-
actual: dict[str, Any] = dict()
213+
actual: dict[Any, Any] = dict()
212214
if type_ is None:
213215
if "map" in metadata.lower():
214216
for entry in obj[_JSON_ENTRIES]:
@@ -228,6 +230,25 @@ def _restore(self, obj: Any) -> Any:
228230

229231
return super().restore(actual, reset=False)
230232

233+
# When "@ordered" is true, convert the "entries" list into an OrderedDict
234+
metadata = obj.get(_META_ORDERED, False)
235+
if metadata is True:
236+
o = collections.OrderedDict()
237+
entries = obj.get(_JSON_ENTRIES, None)
238+
if entries is not None:
239+
for entry in obj[_JSON_ENTRIES]:
240+
o[entry[_JSON_KEY]] = entry[_JSON_VALUE]
241+
return o
242+
243+
# When "@ordered" is not set and the object contains only the "entries" list
244+
if len(obj) == 1:
245+
entries = obj.get(_JSON_ENTRIES, None)
246+
if entries is not None:
247+
actual = dict()
248+
for entry in obj[_JSON_ENTRIES]:
249+
actual[entry[_JSON_KEY]] = entry[_JSON_VALUE]
250+
return super().restore(actual, reset=False)
251+
231252
return super()._restore(obj)
232253

233254

tests/test_ai.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the Universal Permissive License v 1.0 as shown at
33
# https://oss.oracle.com/licenses/upl.
44

5-
from coherence.ai import BitVector, ByteVector, FloatVector
5+
from coherence.ai import BitVector, ByteVector, FloatVector, DocumentChunk
66
from coherence.serialization import JSONSerializer, SerializerRegistry
77

88
s = SerializerRegistry.serializer(JSONSerializer.SER_FORMAT)
@@ -12,27 +12,63 @@ def test_BitVector_serialization() -> None:
1212
coh_bv = BitVector(hex_string="AABBCC")
1313
ser = s.serialize(coh_bv)
1414
assert ser == b'\x15{"@class": "ai.BitVector", "bits": "0xAABBCC"}'
15+
o = s.deserialize(ser)
16+
assert isinstance(o, BitVector)
1517

1618
coh_bv = BitVector(hex_string="0xAABBCC")
1719
ser = s.serialize(coh_bv)
1820
assert ser == b'\x15{"@class": "ai.BitVector", "bits": "0xAABBCC"}'
21+
o = s.deserialize(ser)
22+
assert isinstance(o, BitVector)
1923

2024
coh_bv = BitVector(hex_string=None, byte_array=bytes([1, 2, 10]))
2125
ser = s.serialize(coh_bv)
2226
assert ser == b'\x15{"@class": "ai.BitVector", "bits": "0x01020a"}'
27+
o = s.deserialize(ser)
28+
assert isinstance(o, BitVector)
2329

2430
coh_bv = BitVector(hex_string=None, int_array=[1234, 1235])
2531
ser = s.serialize(coh_bv)
2632
assert ser == b'\x15{"@class": "ai.BitVector", "bits": "0x4d24d3"}'
33+
o = s.deserialize(ser)
34+
assert isinstance(o, BitVector)
2735

2836

2937
def test_ByteVector_serialization() -> None:
    """Serialize a ByteVector, check the exact payload, then deserialize it."""
    expected = b'\x15{"@class": "ai.Int8Vector", "array": "AQIDBA=="}'
    payload = s.serialize(ByteVector(bytes([1, 2, 3, 4])))
    assert payload == expected
    restored = s.deserialize(payload)
    assert isinstance(restored, ByteVector)
3343

3444

3545
def test_FloatVector_serialization() -> None:
    """Serialize a FloatVector, check the exact payload, then deserialize it."""
    expected = b'\x15{"@class": "ai.Float32Vector", "array": [1.0, 2.0, 3.0]}'
    payload = s.serialize(FloatVector([1.0, 2.0, 3.0]))
    assert payload == expected
    restored = s.deserialize(payload)
    assert isinstance(restored, FloatVector)
51+
52+
53+
def test_DocumentChunk_serialization() -> None:
    """Check the serialized form and deserialized type of DocumentChunk.

    Three cases are covered: text only, text with a plain-dict metadata, and
    text with metadata plus a FloatVector.
    """

    def check(chunk: DocumentChunk, expected: bytes) -> None:
        # Serialize, compare against the exact bytes, then deserialize.
        payload = s.serialize(chunk)
        assert payload == expected
        restored = s.deserialize(payload)
        assert isinstance(restored, DocumentChunk)

    # No metadata given: the default OrderedDict is flagged with "@ordered".
    check(
        DocumentChunk("test"),
        b'\x15{"@class": "ai.DocumentChunk", "dataVersion": 0, "metadata": {"@ordered": true, "entries": []}, "text": "test"}',
    )

    # Plain dict metadata: only the "entries" list is emitted.
    metadata = {"one": "one-value", "two": "two-value"}
    check(
        DocumentChunk("test", metadata),
        b'\x15{"@class": "ai.DocumentChunk", "dataVersion": 0, "metadata": {"entries": [{"key": "one", "value": "one-value"}, {"key": "two", "value": "two-value"}]}, "text": "test"}',
    )

    # Metadata plus a FloatVector: a nested "vector" object is emitted.
    float_vector = FloatVector([1.0, 2.0, 3.0])
    metadata = {"one": "one-value", "two": "two-value"}
    check(
        DocumentChunk("test", metadata, float_vector),
        b'\x15{"@class": "ai.DocumentChunk", "dataVersion": 0, "metadata": {"entries": [{"key": "one", "value": "one-value"}, {"key": "two", "value": "two-value"}]}, "vector": {"@class": "ai.Float32Vector", "array": [1.0, 2.0, 3.0]}, "text": "test"}',
    )
74+

0 commit comments

Comments
 (0)