11import asyncio
2+ import queue
3+ import sys
24import threading
35from contextlib import contextmanager
46from datetime import timedelta
@@ -36,8 +38,13 @@ def cancel(self) -> None:
3638
3739class _TimeoutManager :
3840 """
39- This class manages timeouts for futures. It uses a background thread with an
40- event loop to schedule the timeouts.
41+ This class manages timeouts for code blocks, futures and CUDA events. It
42+ uses a background thread with an event loop to schedule the timeouts and
43+ call the callback function when the timeout is reached.
44+
45+ Generally there is a single instance of this class that is used for all
46+ timeouts. The callbacks should not block otherwise other timeouts may not
47+ be processed.
4148 """
4249
4350 def __init__ (self ) -> None :
@@ -46,6 +53,10 @@ def __init__(self) -> None:
4653 self ._event_loop_thread : Optional [threading .Thread ] = None
4754 self ._next_timer_id = 0
4855
56+ # This queue is used to delete events on the main thread as cudaEventDestroy
57+ # can block if the CUDA queue is full.
58+ self ._del_queue : queue .SimpleQueue [object ] = queue .SimpleQueue ()
59+
4960 def _maybe_start_event_loop (self ) -> asyncio .AbstractEventLoop :
5061 """
5162 Start the event loop if it has not already been started.
@@ -82,6 +93,8 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
8293 if isinstance (fut , Mock ):
8394 return fut
8495
96+ self ._clear_del_queue ()
97+
8598 loop = self ._maybe_start_event_loop ()
8699
87100 # pyre-fixme[29]: Future is not a function
@@ -114,6 +127,8 @@ def callback(fut: Future[T]) -> None:
114127 return timed_fut
115128
116129 def stream_timeout (self , callback : Callable [[], None ], timeout : timedelta ) -> None :
130+ self ._clear_del_queue ()
131+
117132 loop = self ._maybe_start_event_loop ()
118133
119134 event : torch .cuda .Event = torch .cuda .Event ()
@@ -123,6 +138,11 @@ def handler() -> None:
123138 if not event .query ():
124139 callback ()
125140
141+ # cudaEventDestroy can block so we never want to delete in the event
142+ # loop. Put it on the del queue so we can delete it in the main
143+ # thread.
144+ self ._del_queue .put (event )
145+
126146 loop .call_soon_threadsafe (
127147 self ._register_callback , loop , handler , timeout , _TimerHandle ()
128148 )
@@ -145,6 +165,8 @@ def _register_callback(
145165 def context_timeout (
146166 self , callback : Callable [[], None ], timeout : timedelta
147167 ) -> Generator [None , None , None ]:
168+ self ._clear_del_queue ()
169+
148170 loop = self ._maybe_start_event_loop ()
149171 handle = _TimerHandle ()
150172
@@ -156,6 +178,31 @@ def context_timeout(
156178
157179 handle .cancel ()
158180
181+ def _clear_del_queue (self ) -> int :
182+ """
183+ Clear the queue of futures to be deleted.
184+
185+ Returns the number of items deleted.
186+ """
187+ count = 0
188+ while True :
189+ try :
190+ # get and immediately discard item
191+ item = self ._del_queue .get_nowait ()
192+ refcount = sys .getrefcount (item )
193+ assert (
194+ # 1 from item, 1 from getrefcount
195+ refcount
196+ == 2
197+ ), f"items in del_queue reference should not have other references, found { refcount = } "
198+ del item
199+
200+ count += 1
201+ except queue .Empty :
202+ break
203+
204+ return count
205+
159206
160207_TIMEOUT_MANAGER = _TimeoutManager ()
161208
0 commit comments