33import pickle
44import selectors
55from collections import deque
6- from contextlib import AsyncExitStack
76from functools import partial
7+ from threading import get_ident
88from typing import (
99 Any ,
1010 Callable ,
1313
1414from anyio import (
1515 Event ,
16- Lock ,
1716 TASK_STATUS_IGNORED ,
1817 create_task_group ,
1918 get_cancelled_exc_class ,
2019 sleep ,
2120 wait_readable ,
2221)
23- from anyio .abc import TaskStatus
22+ from anyio .abc import TaskGroup , TaskStatus
2423from anyioutils import FIRST_COMPLETED , Future , create_task , wait
2524
2625import zmq
@@ -157,15 +156,18 @@ class Socket(zmq.Socket):
157156 _fd = None
158157 _exit_stack = None
159158 _task_group = None
159+ __task_group = None
160+ _thread = None
160161 started = None
161162 stopped = None
162- _start_lock = None
163+ _starting = None
163164 _exited = None
164165
165166 def __init__ (
166167 self ,
167168 context_or_socket : zmq .Context | zmq .Socket ,
168169 socket_type : int = - 1 ,
170+ task_group : TaskGroup | None = None ,
169171 ** kwargs ,
170172 ) -> None :
171173 """
@@ -188,7 +190,7 @@ def __init__(
188190 self .started = Event ()
189191 self ._exited = Event ()
190192 self .stopped = Event ()
191- self ._start_lock = Lock ()
193+ self ._task_group = task_group
192194
193195 def get (self , key ):
194196 result = super ().get (key )
@@ -825,44 +827,56 @@ def _update_handler(self, state) -> None:
825827 self ._schedule_remaining_events ()
826828
827829 async def __aenter__ (self ) -> Socket :
828- assert self ._start_lock is not None
829- async with self ._start_lock :
830- if self ._task_group is None :
831- async with AsyncExitStack () as exit_stack :
832- self ._task_group = await exit_stack .enter_async_context (
833- create_task_group ()
834- )
835- self ._exit_stack = exit_stack .pop_all ()
836- await self ._task_group .start (self ._start )
830+ if self ._starting :
831+ return
832+
833+ self ._starting = True
834+ if self ._task_group is None :
835+ self .__task_group = create_task_group ()
836+ self ._task_group = await self .__task_group .__aenter__ ()
837+ await self ._task_group .start (self ._start )
837838
838839 return self
839840
840841 async def __aexit__ (self , exc_type , exc_value , exc_tb ):
841842 await self .stop ()
842- return await self ._exit_stack .__aexit__ (exc_type , exc_value , exc_tb )
843+ if self .__task_group is not None :
844+ return await self .__task_group .__aexit__ (exc_type , exc_value , exc_tb )
843845
844846 async def start (
845847 self ,
846848 * ,
847849 task_status : TaskStatus [None ] = TASK_STATUS_IGNORED ,
848850 ) -> None :
849- assert self ._start_lock is not None
850- async with self ._start_lock :
851- if self ._task_group is None :
852- async with create_task_group () as self ._task_group :
853- await self ._task_group .start (self ._start )
854- task_status .started ()
855- else :
851+ if self ._starting :
852+ return
853+
854+ self ._starting = True
855+ assert self .started is not None
856+ if self .started .is_set ():
857+ task_status .started ()
858+ return
859+
860+ if self ._task_group is None :
861+ async with create_task_group () as self ._task_group :
856862 await self ._task_group .start (self ._start )
857863 task_status .started ()
864+ else :
865+ await self ._task_group .start (self ._start )
866+ task_status .started ()
867+
868+ async def _start (self , * , task_status : TaskStatus [None ] = TASK_STATUS_IGNORED ):
869+ assert self .started is not None
870+ if self .started .is_set ():
871+ return
858872
859- async def _start (self , * , task_status : TaskStatus [None ]):
860873 assert self .started is not None
861874 assert self .stopped is not None
862875 assert self ._exited is not None
863876 assert self ._task_group is not None
864877 task_status .started ()
865878 self .started .set ()
879+ self ._thread = get_ident ()
866880 try :
867881 while True :
868882 wait_stopped_task = create_task (
@@ -922,6 +936,12 @@ def _check_started(self):
922936 "Socket must be used with async context manager (or `await sock.start()`)"
923937 )
924938
939+ self ._task_group .start_soon (self ._start )
940+
941+ assert self ._thread is not None
942+ if self ._thread != get_ident ():
943+ raise RuntimeError ("Socket must be used in the same thread" )
944+
925945
926946def ignore_exceptions (exc : BaseException ) -> bool :
927947 return True
0 commit comments