Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 65 additions & 48 deletions nats/src/nats/js/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,14 @@
import time
from email.parser import BytesParser
from secrets import token_hex
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
)
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional

import nats.errors
import nats.js.errors
from nats.aio.msg import Msg
from nats.aio.subscription import Subscription
from nats.js import api
from nats.js.errors import (
BadBucketError,
BucketNotFoundError,
FetchTimeoutError,
InvalidBucketNameError,
NotFoundError,
)
from nats.js.errors import BadBucketError, BucketNotFoundError, FetchTimeoutError, InvalidBucketNameError, NotFoundError
from nats.js.kv import KeyValue
from nats.js.manager import JetStreamManager
from nats.js.object_store import (
Expand Down Expand Up @@ -123,7 +109,8 @@ def __init__(
self._publish_async_completed_event = asyncio.Event()
self._publish_async_completed_event.set()

self._publish_async_pending_semaphore = asyncio.Semaphore(publish_async_max_pending)
self._publish_async_pending_semaphore = asyncio.Semaphore(
publish_async_max_pending)

@property
def _jsm(self) -> JetStreamManager:
Expand All @@ -144,10 +131,11 @@ async def _init_async_reply(self) -> None:
async_reply_subject = self._async_reply_prefix[:]
async_reply_subject.extend(b"*")

await self._nc.subscribe(async_reply_subject.decode(), cb=self._handle_async_reply)
await self._nc.subscribe(async_reply_subject.decode(),
cb=self._handle_async_reply)

async def _handle_async_reply(self, msg: Msg) -> None:
token = msg.subject[len(self._nc._inbox_prefix) + 22 + 2 :]
token = msg.subject[len(self._nc._inbox_prefix) + 22 + 2:]
future = self._publish_async_futures.get(token)

if not future:
Expand All @@ -157,7 +145,8 @@ async def _handle_async_reply(self, msg: Msg) -> None:
return

# Handle no responders
if msg.headers and msg.headers.get(api.Header.STATUS) == NO_RESPONDERS_STATUS:
if msg.headers and msg.headers.get(
api.Header.STATUS) == NO_RESPONDERS_STATUS:
future.set_exception(nats.js.errors.NoStreamResponseError)
return

Expand Down Expand Up @@ -229,7 +218,9 @@ async def publish_async(
hdr[api.Header.EXPECTED_STREAM] = stream

try:
await asyncio.wait_for(self._publish_async_pending_semaphore.acquire(), timeout=wait_stall)
await asyncio.wait_for(
self._publish_async_pending_semaphore.acquire(),
timeout=wait_stall)
except (asyncio.TimeoutError, asyncio.CancelledError):
raise nats.js.errors.TooManyStalledMsgsError

Expand All @@ -256,7 +247,10 @@ def handle_done(future):
if self._publish_async_completed_event.is_set():
self._publish_async_completed_event.clear()

await self._nc.publish(subject, payload, reply=inbox.decode(), headers=hdr)
await self._nc.publish(subject,
payload,
reply=inbox.decode(),
headers=hdr)

return future

Expand Down Expand Up @@ -349,7 +343,9 @@ async def cb(msg):
# If using a queue, that will be the consumer/durable name.
if queue:
if durable and durable != queue:
raise nats.js.errors.Error(f"cannot create queue subscription '{queue}' to consumer '{durable}'")
raise nats.js.errors.Error(
f"cannot create queue subscription '{queue}' to consumer '{durable}'"
)
else:
durable = queue

Expand Down Expand Up @@ -381,7 +377,8 @@ async def cb(msg):
elif consumer_info.push_bound:
# Need to reject a non queue subscription to a non queue consumer
# if the consumer is already bound.
raise nats.js.errors.Error("consumer is already bound to a subscription")
raise nats.js.errors.Error(
"consumer is already bound to a subscription")
else:
if not queue:
raise nats.js.errors.Error(
Expand All @@ -390,8 +387,7 @@ async def cb(msg):
elif queue != deliver_group:
raise nats.js.errors.Error(
f"cannot create a queue subscription {queue} for a consumer "
f"with a deliver group {deliver_group}"
)
f"with a deliver group {deliver_group}")
elif should_create:
# Auto-create consumer if none found.
if config is None:
Expand All @@ -414,7 +410,10 @@ async def cb(msg):
config.deliver_subject = deliver

# Auto created consumers use the filter subject.
config.filter_subject = subject
# Use filter_subjects if already set (modern multi-filter API)
# Otherwise use filter_subject (legacy single-filter API)
if not config.filter_subjects:
config.filter_subject = subject

# Heartbeats / FlowControl
config.flow_control = flow_control
Expand Down Expand Up @@ -469,7 +468,8 @@ async def subscribe_bind(
#
# In case ack policy is none then we also do not require to ack.
#
if cb and (not manual_ack) and (config.ack_policy is not api.AckPolicy.NONE):
if cb and (not manual_ack) and (config.ack_policy
is not api.AckPolicy.NONE):
cb = self._auto_ack_callback(cb)
if config.deliver_subject is None:
raise TypeError("config.deliver_subject is required")
Expand Down Expand Up @@ -497,12 +497,14 @@ async def subscribe_bind(
sub._jsi._hbtask = asyncio.create_task(sub._jsi.activity_check())

if ordered_consumer:
sub._jsi._fctask = asyncio.create_task(sub._jsi.check_flow_control_response())
sub._jsi._fctask = asyncio.create_task(
sub._jsi.check_flow_control_response())

return psub

@staticmethod
def _auto_ack_callback(callback: Callback) -> Callback:

async def new_callback(msg: Msg) -> None:
await callback(msg)
try:
Expand Down Expand Up @@ -669,11 +671,9 @@ def _is_processable_msg(cls, status: Optional[str], msg: Msg) -> bool:

@classmethod
def _is_temporary_error(cls, status: Optional[str]) -> bool:
if (
status == api.StatusCode.NO_MESSAGES
or status == api.StatusCode.CONFLICT
or status == api.StatusCode.REQUEST_TIMEOUT
):
if (status == api.StatusCode.NO_MESSAGES
or status == api.StatusCode.CONFLICT
or status == api.StatusCode.REQUEST_TIMEOUT):
return True
else:
return False
Expand All @@ -686,12 +686,14 @@ def _is_heartbeat(cls, status: Optional[str]) -> bool:
return False

@classmethod
def _time_until(cls, timeout: Optional[float], start_time: float) -> Optional[float]:
def _time_until(cls, timeout: Optional[float],
start_time: float) -> Optional[float]:
if timeout is None:
return None
return timeout - (time.monotonic() - start_time)

class _JSI:

def __init__(
self,
js: JetStreamContext,
Expand Down Expand Up @@ -766,7 +768,8 @@ async def check_flow_control_response(self):
if self._conn.is_closed:
break

if (self._fciseq - self._psub._pending_queue.qsize()) >= self._fcd:
if (self._fciseq -
self._psub._pending_queue.qsize()) >= self._fcd:
fc_reply = self._fcr
try:
if fc_reply:
Expand All @@ -779,7 +782,8 @@ async def check_flow_control_response(self):
except asyncio.CancelledError:
break

async def check_for_sequence_mismatch(self, msg: Msg) -> Optional[bool]:
async def check_for_sequence_mismatch(self,
msg: Msg) -> Optional[bool]:
self._active = True
if not self._cmeta:
return None
Expand All @@ -797,7 +801,8 @@ async def check_for_sequence_mismatch(self, msg: Msg) -> Optional[bool]:
sseq = int(tokens[5]) # stream sequence

if self._ordered:
did_reset = await self.reset_ordered_consumer(self._sseq + 1)
did_reset = await self.reset_ordered_consumer(self._sseq +
1)
else:
ecs = nats.js.errors.ConsumerSequenceMismatchError(
stream_resume_sequence=sseq,
Expand Down Expand Up @@ -852,7 +857,10 @@ async def reset_ordered_consumer(self, sseq: Optional[int]) -> bool:

async def recreate_consumer(self) -> None:
try:
cinfo = await self._js._jsm.add_consumer(self._stream, config=self._ccreq, timeout=self._js._timeout)
cinfo = await self._js._jsm.add_consumer(
self._stream,
config=self._ccreq,
timeout=self._js._timeout)
self._psub._consumer = cinfo.name
except Exception as err:
await self._conn._error_cb(err)
Expand Down Expand Up @@ -1017,7 +1025,8 @@ async def consumer_info(self) -> api.ConsumerInfo:
"""
consumer_info gets the current info of the consumer from this subscription.
"""
info = await self._js._jsm.consumer_info(self._stream, self._consumer)
info = await self._js._jsm.consumer_info(self._stream,
self._consumer)
return info

async def fetch(
Expand Down Expand Up @@ -1063,7 +1072,8 @@ async def main():
if timeout is not None and timeout <= 0:
raise ValueError("nats: invalid fetch timeout")

expires = int(timeout * 1_000_000_000) - 100_000 if timeout else None
expires = int(timeout *
1_000_000_000) - 100_000 if timeout else None
if batch == 1:
msg = await self._fetch_one(expires, timeout, heartbeat)
return [msg]
Expand Down Expand Up @@ -1099,7 +1109,8 @@ async def _fetch_one(
if expires:
next_req["expires"] = int(expires)
if heartbeat:
next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds
next_req["idle_heartbeat"] = int(
heartbeat * 1_000_000_000) # to nanoseconds

await self._nc.publish(
self._nms,
Expand All @@ -1111,7 +1122,8 @@ async def _fetch_one(
got_any_response = False
while True:
try:
deadline = JetStreamContext._time_until(timeout, start_time)
deadline = JetStreamContext._time_until(
timeout, start_time)
# Wait for the response or raise timeout.
msg = await self._sub.next_msg(timeout=deadline)

Expand All @@ -1131,7 +1143,8 @@ async def _fetch_one(
else:
return msg
except asyncio.TimeoutError:
deadline = JetStreamContext._time_until(timeout, start_time)
deadline = JetStreamContext._time_until(
timeout, start_time)
if deadline is not None and deadline < 0:
# No response from the consumer could have been
# due to a reconnect while the fetch request,
Expand Down Expand Up @@ -1178,7 +1191,8 @@ async def _fetch_n(
if expires:
next_req["expires"] = expires
if heartbeat:
next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds
next_req["idle_heartbeat"] = int(
heartbeat * 1_000_000_000) # to nanoseconds
next_req["no_wait"] = True
await self._nc.publish(
self._nms,
Expand Down Expand Up @@ -1210,7 +1224,8 @@ async def _fetch_n(

try:
for i in range(0, needed):
deadline = JetStreamContext._time_until(timeout, start_time)
deadline = JetStreamContext._time_until(
timeout, start_time)
msg = await self._sub.next_msg(timeout=deadline)
status = JetStreamContext.is_status_msg(msg)
if status == api.StatusCode.NO_MESSAGES or status == api.StatusCode.REQUEST_TIMEOUT:
Expand Down Expand Up @@ -1240,7 +1255,8 @@ async def _fetch_n(
if expires:
next_req["expires"] = expires
if heartbeat:
next_req["idle_heartbeat"] = int(heartbeat * 1_000_000_000) # to nanoseconds
next_req["idle_heartbeat"] = int(
heartbeat * 1_000_000_000) # to nanoseconds

await self._nc.publish(
self._nms,
Expand Down Expand Up @@ -1296,7 +1312,8 @@ async def _fetch_n(
# Wait for the rest of the messages to be delivered to the internal pending queue.
try:
for _ in range(needed):
deadline = JetStreamContext._time_until(timeout, start_time)
deadline = JetStreamContext._time_until(
timeout, start_time)
if deadline is not None and deadline < 0:
return msgs

Expand Down
Loading
Loading