Skip to content

Commit f5663f8

Browse files
committed
added the async herd client
1 parent 56bd719 commit f5663f8

File tree

2 files changed

+214
-0
lines changed

2 files changed

+214
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .default import AsyncDefaultClient
2+
from .herd import AsyncHerdClient
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import socket
2+
import time
3+
from typing import Tuple, Any
4+
5+
from django.conf import settings
6+
from valkey import Valkey
7+
from valkey.exceptions import ConnectionError, ResponseError, TimeoutError
8+
from valkey.typing import KeyT, EncodableT
9+
10+
from django_valkey.async_cache.client import AsyncDefaultClient
11+
from django_valkey.base_client import DEFAULT_TIMEOUT
12+
from django_valkey.client.herd import Marker, _is_expired
13+
from django_valkey.exceptions import ConnectionInterrupted
14+
15+
_main_exceptions = (ConnectionError, ResponseError, TimeoutError, socket.timeout)
16+
17+
18+
class AsyncHerdClient(AsyncDefaultClient):
19+
def __init__(self, *args, **kwargs):
20+
self._marker = Marker()
21+
self._herd_timeout: int = getattr(settings, "CACHE_HERD_TIMEOUT", 60)
22+
super().__init__(*args, **kwargs)
23+
24+
async def _pack(self, value: Any, timeout) -> Tuple[Marker, Any, int]:
25+
herd_timeout = (timeout or self._backend.default_timeout) + int(time.time())
26+
return self._marker, value, herd_timeout
27+
28+
async def _unpack(self, value: Tuple[Marker, Any, int]) -> Tuple[Any, bool]:
29+
try:
30+
marker, unpacked, herd_timeout = value
31+
except (ValueError, TypeError):
32+
return value, False
33+
34+
if not isinstance(marker, Marker):
35+
return value, False
36+
37+
now = time.time()
38+
if herd_timeout < now:
39+
x = now - herd_timeout
40+
return unpacked, _is_expired(x, self._herd_timeout)
41+
42+
return unpacked, False
43+
44+
async def set(
45+
self,
46+
key: KeyT,
47+
value: EncodableT,
48+
timeout: int | None = DEFAULT_TIMEOUT,
49+
version: int | None = None,
50+
client: Valkey | None = None,
51+
nx: bool = False,
52+
xx: bool = False,
53+
):
54+
if timeout is DEFAULT_TIMEOUT:
55+
timeout = self._backend.default_timeout
56+
57+
if timeout is None or timeout <= 0:
58+
return await super().aset(
59+
key,
60+
value,
61+
timeout=timeout,
62+
version=version,
63+
client=client,
64+
nx=nx,
65+
xx=xx,
66+
)
67+
68+
packed = await self._pack(value, timeout)
69+
real_timeout = timeout + self._herd_timeout
70+
71+
return await super().aset(
72+
key,
73+
packed,
74+
timeout=real_timeout,
75+
version=version,
76+
client=client,
77+
nx=nx,
78+
xx=xx,
79+
)
80+
81+
aset = set
82+
83+
async def get(self, key, default=None, version=None, client=None):
84+
packed = await super().aget(
85+
key, default=default, version=version, client=client
86+
)
87+
val, refresh = await self._unpack(packed)
88+
89+
if refresh:
90+
return default
91+
92+
return val
93+
94+
aget = get
95+
96+
async def get_many(self, keys, version=None, client=None):
97+
client = await self._get_client(write=False, client=client)
98+
99+
if not keys:
100+
return {}
101+
102+
recovered_data = {}
103+
104+
new_keys = [await self.make_key(key, version=version) for key in keys]
105+
map_keys = dict(zip(new_keys, keys))
106+
107+
try:
108+
pipeline = await client.pipeline()
109+
for key in new_keys:
110+
await pipeline.get(key)
111+
results = await pipeline.execute()
112+
except _main_exceptions as e:
113+
raise ConnectionInterrupted(connection=client) from e
114+
115+
for key, value in zip(new_keys, results):
116+
if value is None:
117+
continue
118+
119+
val, refresh = await self._unpack(await self.decode(value))
120+
recovered_data[map_keys[key]] = None if refresh else val
121+
122+
return recovered_data
123+
124+
aget_many = get_many
125+
126+
async def mget(self, keys, version=None, client=None):
127+
if not keys:
128+
return {}
129+
130+
client = await self._get_client(write=False, client=client)
131+
132+
recovered_data = {}
133+
134+
new_keys = [await self.make_key(key, version=version) for key in keys]
135+
map_keys = dict(zip(new_keys, keys))
136+
137+
try:
138+
results = await client.mget(new_keys)
139+
except _main_exceptions as e:
140+
raise ConnectionInterrupted(connection=client) from e
141+
142+
for key, value in zip(new_keys, results):
143+
if value is None:
144+
continue
145+
146+
val, refresh = await self._unpack(await self.decode(value))
147+
recovered_data[map_keys[key]] = None if refresh else val
148+
149+
return recovered_data
150+
151+
amget = mget
152+
153+
async def set_many(
154+
self, data, timeout=DEFAULT_TIMEOUT, version=None, client=None, herd=True
155+
):
156+
"""
157+
Set a bunch of values in the cache at once from a dict of key/value
158+
pairs. This is much more efficient than calling set() multiple times.
159+
160+
If timeout is given, that timeout will be used for the key; otherwise
161+
the default cache timeout will be used.
162+
"""
163+
client = await self._get_client(write=True, client=client)
164+
165+
set_function = self.aset if herd else super().aset
166+
167+
try:
168+
pipeline = await client.pipeline()
169+
for key, value in data.items():
170+
await set_function(
171+
key, value, timeout, version=version, client=pipeline
172+
)
173+
await pipeline.execute()
174+
except _main_exceptions as e:
175+
raise ConnectionInterrupted(connection=client) from e
176+
177+
ast_many = set_many
178+
179+
async def mset(self, data, timeout=DEFAULT_TIMEOUT, version=None, client=None):
180+
client = await self._get_client(write=True, client=client)
181+
data = {
182+
await self.make_key(k, version=version): await self.encode(v)
183+
for k, v in data.items()
184+
}
185+
186+
try:
187+
return await client.mset(data)
188+
except _main_exceptions as e:
189+
raise ConnectionInterrupted(connection=client) from e
190+
191+
amset = mset
192+
193+
def incr(self, *args, **kwargs):
194+
raise NotImplementedError
195+
196+
aincr = incr
197+
198+
def decr(self, *args, **kwargs):
199+
raise NotImplementedError
200+
201+
adecr = decr
202+
203+
async def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None):
204+
client = await self._get_client(write=True, client=client)
205+
206+
value = await self.aget(key, version=version, client=client)
207+
if value is None:
208+
return False
209+
210+
await self.aset(key, value, timeout=timeout, version=version, client=client)
211+
return True
212+
213+
atouch = touch

0 commit comments

Comments
 (0)