11import asyncio
22import threading
3+ from contextlib import contextmanager
34from datetime import timedelta
4- from typing import Optional , TypeVar
5+ from typing import Callable , Generator , Optional , TypeVar
56from unittest .mock import Mock
67
8+ import torch
79from torch .futures import Future
810
911T = TypeVar ("T" )
1214class _TimerHandle :
1315 def __init__ (self ) -> None :
1416 self ._lock = threading .Lock ()
15- self ._lock .acquire ()
1617 self ._timer_handle : Optional [asyncio .TimerHandle ] = None
18+ self ._cancelled = False
1719
18- def set_timer (self , timer_handle : asyncio .TimerHandle ) -> None :
19- assert self ._lock .locked ()
20-
21- self ._timer_handle = timer_handle
22- self ._lock .release ()
20+ def set_timer_handle (self , timer_handle : asyncio .TimerHandle ) -> None :
21+ with self ._lock :
22+ if self ._cancelled :
23+ timer_handle .cancel ()
24+ self ._timer_handle = None
25+ else :
26+ self ._timer_handle = timer_handle
2327
2428 def cancel (self ) -> None :
2529 with self ._lock :
26- assert self ._timer_handle is not None
27- self ._timer_handle .cancel ()
28- self ._timer_handle = None
30+ assert not self ._cancelled , "timer can only be cancelled once"
31+ self ._cancelled = True
32+ if self ._timer_handle is not None :
33+ self ._timer_handle .cancel ()
34+ self ._timer_handle = None
2935
3036
3137class _TimeoutManager :
@@ -81,8 +87,16 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
8187 # pyre-fixme[29]: Future is not a function
8288 timed_fut : Future [T ] = Future ()
8389 handle : _TimerHandle = _TimerHandle ()
84- # pyre-fixme[6]: *args
85- loop .call_soon_threadsafe (self ._register , loop , timed_fut , timeout , handle )
90+ loop .call_soon_threadsafe (
91+ self ._register_callback ,
92+ loop ,
93+ lambda : timed_fut .set_exception (
94+ # pyre-fixme[6]: e is not T
95+ TimeoutError (f"future did not complete within { timeout } " )
96+ ),
97+ timeout ,
98+ handle ,
99+ )
86100
87101 def callback (fut : Future [T ]) -> None :
88102 handle .cancel ()
@@ -99,22 +113,48 @@ def callback(fut: Future[T]) -> None:
99113 fut .add_done_callback (callback )
100114 return timed_fut
101115
116+ def stream_timeout (self , callback : Callable [[], None ], timeout : timedelta ) -> None :
117+ loop = self ._maybe_start_event_loop ()
118+
119+ event : torch .cuda .Event = torch .cuda .Event ()
120+ event .record ()
121+
122+ def handler () -> None :
123+ if not event .query ():
124+ callback ()
125+
126+ loop .call_soon_threadsafe (
127+ self ._register_callback , loop , handler , timeout , _TimerHandle ()
128+ )
129+
102130 @classmethod
103- def _register (
131+ def _register_callback (
104132 cls ,
105133 loop : asyncio .AbstractEventLoop ,
106- fut : Future [ T ],
134+ callback : Callable [[], None ],
107135 timeout : timedelta ,
108136 handle : _TimerHandle ,
109137 ) -> None :
110138 timer_handle = loop .call_later (
111139 timeout .total_seconds (),
112- lambda : fut .set_exception (
113- # pyre-fixme[6]: e is not T
114- TimeoutError (f"future did not complete within { timeout } " )
115- ),
140+ callback ,
116141 )
117- handle .set_timer (timer_handle )
142+ handle .set_timer_handle (timer_handle )
143+
144+ @contextmanager
145+ def context_timeout (
146+ self , callback : Callable [[], None ], timeout : timedelta
147+ ) -> Generator [None , None , None ]:
148+ loop = self ._maybe_start_event_loop ()
149+ handle = _TimerHandle ()
150+
151+ loop .call_soon_threadsafe (
152+ self ._register_callback , loop , callback , timeout , handle
153+ )
154+
155+ yield
156+
157+ handle .cancel ()
118158
119159
120160_TIMEOUT_MANAGER = _TimeoutManager ()
@@ -163,3 +203,35 @@ def callback(fut: Future[T]) -> T:
163203 raise TimeoutError (f"future did not complete within { timeout } " )
164204
165205 return fut .wait ()
206+
207+
208+ def stream_timeout (callback : Callable [[], None ], timeout : timedelta ) -> None :
209+ """
210+ Registers a callback that will be called after the specified timeout if
211+ the current stream doesn't complete in time.
212+
213+ This uses a cuda Event to track the completion of the current stream. If
214+ the stream is not complete after the timeout, the callback is called.
215+
216+ Args:
217+ callback: The callback to call if the stream doesn't complete in time.
218+ timeout: The timeout to wait for the stream to complete.
219+ """
220+ _TIMEOUT_MANAGER .stream_timeout (callback , timeout )
221+
222+
223+ @contextmanager
224+ def context_timeout (
225+ callback : Callable [[], None ], timeout : timedelta
226+ ) -> Generator [None , None , None ]:
227+ """
228+ Registers a callback that will be called after the specified timeout if
229+ the current contextmanager doesn't exit in time.
230+
231+ Args:
232+ callback: The callback to call if we time out.
233+ timeout: How long to wait for the contextmanager to exit.
234+ """
235+
236+ with _TIMEOUT_MANAGER .context_timeout (callback , timeout ):
237+ yield
0 commit comments