Skip to content

Commit cc27968

Browse files
committed
Don't break internal loops on CancelledError when the cancel is not triggered internally (avoid CTRL-C issue on python < 3.11)
1 parent 7e7883e commit cc27968

File tree

4 files changed

+81
-28
lines changed

4 files changed

+81
-28
lines changed

nats/aio/client.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def __init__(self) -> None:
251251
# New style request/response
252252
self._resp_map: Dict[str, asyncio.Future] = {}
253253
self._resp_sub_prefix: Optional[bytearray] = None
254+
self._sub_prefix_subscription: Optional[Subscription] = None
254255
self._nuid = NUID()
255256
self._inbox_prefix = bytearray(DEFAULT_INBOX_PREFIX)
256257
self._auth_configured: bool = False
@@ -680,11 +681,17 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
680681
if self.is_closed:
681682
self._status = status
682683
return
683-
self._status = Client.CLOSED
684+
685+
if self._sub_prefix_subscription is not None:
686+
subscription = self._sub_prefix_subscription
687+
self._sub_prefix_subscription = None
688+
await subscription.unsubscribe()
684689

685690
# Kick the flusher once again so that Task breaks and avoid pending futures.
686691
await self._flush_pending()
687692

693+
self._status = Client.CLOSED
694+
688695
if self._reading_task is not None and not self._reading_task.cancelled(
689696
):
690697
self._reading_task.cancel()
@@ -726,11 +733,7 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
726733
# Cleanup subscriptions since not reconnecting so no need
727734
# to replay the subscriptions anymore.
728735
for sub in self._subs.values():
729-
# Async subs use join when draining already so just cancel here.
730-
if sub._wait_for_msgs_task and not sub._wait_for_msgs_task.done():
731-
sub._wait_for_msgs_task.cancel()
732-
if sub._message_iterator:
733-
sub._message_iterator._cancel()
736+
sub._stop_processing()
734737
# Sync subs may have some inflight next_msg calls that could be blocking
735738
# so cancel them here to unblock them.
736739
if sub._pending_next_msgs_calls:
@@ -985,7 +988,7 @@ async def _init_request_sub(self) -> None:
985988
self._resp_sub_prefix.extend(b".")
986989
resp_mux_subject = self._resp_sub_prefix[:]
987990
resp_mux_subject.extend(b"*")
988-
await self.subscribe(
991+
self._sub_prefix_subscription = await self.subscribe(
989992
resp_mux_subject.decode(), cb=self._request_sub_callback
990993
)
991994

@@ -2068,23 +2071,28 @@ async def _flusher(self) -> None:
20682071
if not self.is_connected or self.is_connecting:
20692072
break
20702073

2071-
future: asyncio.Future = await self._flush_queue.get()
2072-
20732074
try:
2074-
if self._pending_data_size > 0:
2075-
self._transport.writelines(self._pending[:])
2076-
self._pending = []
2077-
self._pending_data_size = 0
2078-
await self._transport.drain()
2079-
except OSError as e:
2080-
await self._error_cb(e)
2081-
await self._process_op_err(e)
2082-
break
2083-
except (asyncio.CancelledError, RuntimeError, AttributeError):
2084-
# RuntimeError in case the event loop is closed
2085-
break
2086-
finally:
2087-
future.set_result(None)
2075+
future: asyncio.Future = await self._flush_queue.get()
2076+
try:
2077+
if self._pending_data_size > 0:
2078+
self._transport.writelines(self._pending[:])
2079+
self._pending = []
2080+
self._pending_data_size = 0
2081+
await self._transport.drain()
2082+
except OSError as e:
2083+
await self._error_cb(e)
2084+
await self._process_op_err(e)
2085+
break
2086+
except (RuntimeError, AttributeError):
2087+
# RuntimeError in case the event loop is closed
2088+
break
2089+
finally:
2090+
future.set_result(None)
2091+
except asyncio.CancelledError:
2092+
if self._status == Client.CLOSED:
2093+
break
2094+
else:
2095+
continue
20882096

20892097
async def _ping_interval(self) -> None:
20902098
while True:
@@ -2098,8 +2106,13 @@ async def _ping_interval(self) -> None:
20982106
await self._process_op_err(ErrStaleConnection())
20992107
return
21002108
await self._send_ping()
2101-
except (asyncio.CancelledError, RuntimeError, AttributeError):
2109+
except (RuntimeError, AttributeError):
21022110
break
2111+
except asyncio.CancelledError:
2112+
if self._status == Client.CLOSED:
2113+
break
2114+
else:
2115+
continue
21032116
# except asyncio.InvalidStateError:
21042117
# pass
21052118

@@ -2130,7 +2143,10 @@ async def _read_loop(self) -> None:
21302143
await self._process_op_err(e)
21312144
break
21322145
except asyncio.CancelledError:
2133-
break
2146+
if self._status == Client.CLOSED:
2147+
break
2148+
else:
2149+
continue
21342150
except Exception as ex:
21352151
_logger.error("nats: encountered error", exc_info=ex)
21362152
break

nats/aio/subscription.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,6 @@ async def unsubscribe(self, limit: int = 0):
284284
self._max_msgs = limit
285285
if limit == 0 or (self._received >= limit
286286
and self._pending_queue.empty()):
287-
self._closed = True
288287
self._stop_processing()
289288
self._conn._remove_sub(self._id)
290289

@@ -295,6 +294,7 @@ def _stop_processing(self) -> None:
295294
"""
296295
Stops the subscription from processing new messages.
297296
"""
297+
self._closed = True
298298
if self._wait_for_msgs_task and not self._wait_for_msgs_task.done():
299299
self._wait_for_msgs_task.cancel()
300300
if self._message_iterator:
@@ -333,7 +333,10 @@ async def _wait_for_msgs(self, error_cb) -> None:
333333
and self._pending_queue.empty):
334334
self._stop_processing()
335335
except asyncio.CancelledError:
336-
break
336+
if self._closed:
337+
break
338+
else:
339+
continue
337340

338341

339342
class _SubscriptionMessageIterator:

nats/js/client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,12 +968,19 @@ async def next_msg(self, timeout: Optional[float] = 1.0) -> Msg:
968968
self._sub._jsi._fcr = None
969969
return msg
970970

971+
async def drain(self):
972+
await self._sub.drain()
973+
self._closed = self._sub._closed
974+
971975
async def unsubscribe(self, limit: int = 0):
972976
"""
973977
Unsubscribes from a subscription, canceling any heartbeat and flow control tasks,
974978
and optionally limits the number of messages to process before unsubscribing.
979+
Nothing is really subscribed from this object, call unsubscribe on underlying sub
980+
and forward _closed flag.
975981
"""
976-
await super().unsubscribe(limit)
982+
await self._sub.unsubscribe(limit)
983+
self._closed = self._sub._closed
977984

978985
if self._sub._jsi._hbtask:
979986
self._sub._jsi._hbtask.cancel()

tests/test_js.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ async def f():
314314
await task
315315
assert received
316316

317+
await nc.close()
318+
317319
@async_test
318320
async def test_add_pull_consumer_via_jsm(self):
319321
nc = NATS()
@@ -339,6 +341,8 @@ async def test_add_pull_consumer_via_jsm(self):
339341
info = await js.consumer_info("events", "a")
340342
assert 0 == info.num_pending
341343

344+
await nc.close()
345+
342346
@async_long_test
343347
async def test_fetch_n(self):
344348
nc = NATS()
@@ -852,6 +856,7 @@ async def test_ephemeral_pull_subscribe(self):
852856
cinfo = await sub.consumer_info()
853857
self.assertTrue(cinfo.config.name != None)
854858
self.assertTrue(cinfo.config.durable_name == None)
859+
855860
await nc.close()
856861

857862
@async_test
@@ -896,6 +901,8 @@ async def test_consumer_with_multiple_filters(self):
896901
ok = await msgs[0].ack_sync()
897902
assert ok
898903

904+
await nc.close()
905+
899906
@async_long_test
900907
async def test_add_consumer_with_backoff(self):
901908
nc = NATS()
@@ -953,6 +960,7 @@ async def cb(msg):
953960

954961
# Confirm possible to unmarshal the consumer config.
955962
assert info.config.backoff == [1, 2]
963+
956964
await nc.close()
957965

958966
@async_long_test
@@ -1495,6 +1503,8 @@ async def test_jsm_stream_info_options(self):
14951503
assert si.state.messages == 5
14961504
assert si.state.subjects == None
14971505

1506+
await nc.close()
1507+
14981508

14991509
class SubscribeTest(SingleJetStreamServerTestCase):
15001510

@@ -1657,6 +1667,8 @@ async def test_ephemeral_subscribe(self):
16571667
assert len(info2.name) > 0
16581668
assert info1.name != info2.name
16591669

1670+
await nc.close()
1671+
16601672
@async_test
16611673
async def test_subscribe_bind(self):
16621674
nc = await nats.connect()
@@ -1702,6 +1714,8 @@ async def test_subscribe_bind(self):
17021714
assert info.num_ack_pending == 0
17031715
assert info.num_pending == 0
17041716

1717+
await nc.close()
1718+
17051719
@async_test
17061720
async def test_subscribe_custom_limits(self):
17071721
errors = []
@@ -1904,6 +1918,8 @@ async def test_ack_v2_tokens(self):
19041918
tzinfo=datetime.timezone.utc
19051919
)
19061920

1921+
await nc.close()
1922+
19071923
@async_test
19081924
async def test_double_acking_pull_subscribe(self):
19091925
nc = await nats.connect()
@@ -2031,6 +2047,8 @@ async def f():
20312047
assert task.done()
20322048
assert received
20332049

2050+
await nc.close()
2051+
20342052

20352053
class DiscardPolicyTest(SingleJetStreamServerTestCase):
20362054

@@ -2516,6 +2534,7 @@ async def cb(msg):
25162534
await asyncio.wait_for(done, 10)
25172535

25182536
await nc.close()
2537+
await nc2.close()
25192538

25202539
@async_test
25212540
async def test_recreate_consumer_on_failed_hbs(self):
@@ -2548,6 +2567,8 @@ async def error_handler(e):
25482567
self.assertTrue(orig_name != info.name)
25492568
await js.delete_stream("MY_STREAM")
25502569

2570+
await nc.close()
2571+
25512572

25522573
class KVTest(SingleJetStreamServerTestCase):
25532574

@@ -2667,6 +2688,8 @@ async def error_handler(e):
26672688
with pytest.raises(BadBucketError):
26682689
await js.key_value(bucket="TEST3")
26692690

2691+
await nc.close()
2692+
26702693
@async_test
26712694
async def test_kv_basic(self):
26722695
errors = []
@@ -2824,6 +2847,8 @@ async def error_handler(e):
28242847
entry = await kv.get("age")
28252848
assert entry.revision == 10
28262849

2850+
await nc.close()
2851+
28272852
@async_test
28282853
async def test_kv_direct_get_msg(self):
28292854
errors = []
@@ -2879,6 +2904,8 @@ async def error_handler(e):
28792904
)
28802905
assert msg.data == b"33"
28812906

2907+
await nc.close()
2908+
28822909
@async_test
28832910
async def test_kv_direct(self):
28842911
errors = []

0 commit comments

Comments
 (0)