@@ -195,18 +195,24 @@ def __init__(
195195 self .started = Event ()
196196
197197 def close (self , linger : int | None = None ) -> None :
198- if not self .closed and self ._fd is not None :
199- event_list : list [_FutureEvent ] = list (
200- chain (self ._recv_futures or [], self ._send_futures or [])
201- )
202- for event in event_list :
203- if not event .future .done ():
204- try :
205- event .future .cancel (raise_exception = False )
206- except RuntimeError :
207- # RuntimeError may be called during teardown
208- pass
209- super ().close (linger = linger )
198+ try :
199+ if not self .closed and self ._fd is not None :
200+ event_list : list [_FutureEvent ] = list (
201+ chain (self ._recv_futures or [], self ._send_futures or [])
202+ )
203+ for event in event_list :
204+ if not event .future .done ():
205+ try :
206+ event .future .cancel (raise_exception = False )
207+ except RuntimeError :
208+ # RuntimeError may be called during teardown
209+ pass
210+ super ().close (linger = linger )
211+ except BaseException :
212+ pass
213+
214+ if self ._task_group is not None :
215+ self ._task_group .cancel_scope .cancel ()
210216
211217 close .__doc__ = zmq .Socket .close .__doc__
212218
@@ -224,6 +230,7 @@ async def arecv(
224230 copy : bool = True ,
225231 track : bool = False ,
226232 ) -> bytes | zmq .Frame :
233+ self ._check_started ()
227234 return await self ._add_recv_event (
228235 "recv" , dict (flags = flags , copy = copy , track = track )
229236 )
@@ -315,6 +322,7 @@ async def arecv_multipart(
315322 copy : bool = True ,
316323 track : bool = False ,
317324 ) -> list [bytes ] | list [zmq .Frame ]:
325+ self ._check_started ()
318326 return await self ._add_recv_event (
319327 "recv_multipart" , dict (flags = flags , copy = copy , track = track )
320328 )
@@ -339,6 +347,7 @@ async def asend(
339347 track : bool = False ,
340348 ** kwargs : Any ,
341349 ) -> zmq .MessageTracker | None :
350+ self ._check_started ()
342351 kwargs ["flags" ] = flags
343352 kwargs ["copy" ] = copy
344353 kwargs ["track" ] = track
@@ -431,6 +440,7 @@ async def asend_multipart(
431440 track : bool = False ,
432441 ** kwargs ,
433442 ) -> zmq .MessageTracker | None :
443+ self ._check_started ()
434444 kwargs ["flags" ] = flags
435445 kwargs ["copy" ] = copy
436446 kwargs ["track" ] = track
@@ -447,7 +457,7 @@ async def apoll(self, timeout=None, flags=zmq.POLLIN) -> int: # type: ignore
447457
448458 returns a Future for the poll results.
449459 """
450-
460+ self . _check_started ()
451461 if self .closed :
452462 raise zmq .ZMQError (zmq .ENOTSUP )
453463
@@ -783,10 +793,6 @@ async def __aenter__(self) -> Socket:
783793 return self
784794
785795 async def __aexit__ (self , exc_type , exc_value , exc_tb ):
786- try :
787- self .close ()
788- except BaseException :
789- pass
790796 await self .stop ()
791797 return await self ._exit_stack .__aexit__ (exc_type , exc_value , exc_tb )
792798
@@ -796,29 +802,32 @@ async def start(
796802 if self ._task_group is None :
797803 async with create_task_group () as self ._task_group :
798804 await self ._task_group .start (self ._start )
805+ task_status .started ()
799806 else :
800807 await self ._task_group .start (self ._start )
801- task_status .started ()
808+ task_status .started ()
802809
803810 async def stop (self ):
804- if self ._task_group is None :
805- return
806-
807- self ._task_group .cancel_scope .cancel ()
811+ self .close ()
808812
809813 async def _start (self , * , task_status : TaskStatus [None ]):
810814 _set_selector_windows ()
811815 assert self ._task_group is not None
812816 assert self .started is not None
817+ task_status .started ()
813818 if self .started .is_set ():
814- task_status .started ()
815819 return
816820
817821 self .started .set ()
818- task_status .started ()
819822 try :
820823 while True :
821824 await wait_socket_readable (self ._shadow_sock .FD ) # type: ignore[arg-type]
822825 await self ._handle_events ()
823826 except Exception :
824827 pass
828+
829+ def _check_started (self ):
830+ if self ._task_group is None :
831+ raise RuntimeError (
832+ "Socket must be used with async context manager (or `await sock.start()`)"
833+ )
0 commit comments