Skip to content

Commit 6c200b7

Browse files
Check if socket started when calling async methods (#11)
1 parent 5f1ca66 commit 6c200b7

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

src/zmq_anyio/_socket.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -195,18 +195,24 @@ def __init__(
195195
self.started = Event()
196196

197197
def close(self, linger: int | None = None) -> None:
198-
if not self.closed and self._fd is not None:
199-
event_list: list[_FutureEvent] = list(
200-
chain(self._recv_futures or [], self._send_futures or [])
201-
)
202-
for event in event_list:
203-
if not event.future.done():
204-
try:
205-
event.future.cancel(raise_exception=False)
206-
except RuntimeError:
207-
# RuntimeError may be called during teardown
208-
pass
209-
super().close(linger=linger)
198+
try:
199+
if not self.closed and self._fd is not None:
200+
event_list: list[_FutureEvent] = list(
201+
chain(self._recv_futures or [], self._send_futures or [])
202+
)
203+
for event in event_list:
204+
if not event.future.done():
205+
try:
206+
event.future.cancel(raise_exception=False)
207+
except RuntimeError:
208+
# RuntimeError may be called during teardown
209+
pass
210+
super().close(linger=linger)
211+
except BaseException:
212+
pass
213+
214+
if self._task_group is not None:
215+
self._task_group.cancel_scope.cancel()
210216

211217
close.__doc__ = zmq.Socket.close.__doc__
212218

@@ -224,6 +230,7 @@ async def arecv(
224230
copy: bool = True,
225231
track: bool = False,
226232
) -> bytes | zmq.Frame:
233+
self._check_started()
227234
return await self._add_recv_event(
228235
"recv", dict(flags=flags, copy=copy, track=track)
229236
)
@@ -315,6 +322,7 @@ async def arecv_multipart(
315322
copy: bool = True,
316323
track: bool = False,
317324
) -> list[bytes] | list[zmq.Frame]:
325+
self._check_started()
318326
return await self._add_recv_event(
319327
"recv_multipart", dict(flags=flags, copy=copy, track=track)
320328
)
@@ -339,6 +347,7 @@ async def asend(
339347
track: bool = False,
340348
**kwargs: Any,
341349
) -> zmq.MessageTracker | None:
350+
self._check_started()
342351
kwargs["flags"] = flags
343352
kwargs["copy"] = copy
344353
kwargs["track"] = track
@@ -431,6 +440,7 @@ async def asend_multipart(
431440
track: bool = False,
432441
**kwargs,
433442
) -> zmq.MessageTracker | None:
443+
self._check_started()
434444
kwargs["flags"] = flags
435445
kwargs["copy"] = copy
436446
kwargs["track"] = track
@@ -447,7 +457,7 @@ async def apoll(self, timeout=None, flags=zmq.POLLIN) -> int: # type: ignore
447457
448458
returns a Future for the poll results.
449459
"""
450-
460+
self._check_started()
451461
if self.closed:
452462
raise zmq.ZMQError(zmq.ENOTSUP)
453463

@@ -783,10 +793,6 @@ async def __aenter__(self) -> Socket:
783793
return self
784794

785795
async def __aexit__(self, exc_type, exc_value, exc_tb):
786-
try:
787-
self.close()
788-
except BaseException:
789-
pass
790796
await self.stop()
791797
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)
792798

@@ -796,29 +802,32 @@ async def start(
796802
if self._task_group is None:
797803
async with create_task_group() as self._task_group:
798804
await self._task_group.start(self._start)
805+
task_status.started()
799806
else:
800807
await self._task_group.start(self._start)
801-
task_status.started()
808+
task_status.started()
802809

803810
async def stop(self):
804-
if self._task_group is None:
805-
return
806-
807-
self._task_group.cancel_scope.cancel()
811+
self.close()
808812

809813
async def _start(self, *, task_status: TaskStatus[None]):
810814
_set_selector_windows()
811815
assert self._task_group is not None
812816
assert self.started is not None
817+
task_status.started()
813818
if self.started.is_set():
814-
task_status.started()
815819
return
816820

817821
self.started.set()
818-
task_status.started()
819822
try:
820823
while True:
821824
await wait_socket_readable(self._shadow_sock.FD) # type: ignore[arg-type]
822825
await self._handle_events()
823826
except Exception:
824827
pass
828+
829+
def _check_started(self):
830+
if self._task_group is None:
831+
raise RuntimeError(
832+
"Socket must be used with async context manager (or `await sock.start()`)"
833+
)

0 commit comments

Comments
 (0)