|  | 
| 1 | 1 | import queue | 
| 2 | 2 | import time | 
| 3 | 3 | from datetime import timedelta | 
|  | 4 | +from multiprocessing.connection import Connection | 
| 4 | 5 | from typing import Union | 
| 5 | 6 | 
 | 
| 6 | 7 | import torch.multiprocessing as mp | 
| 7 | 8 | 
 | 
| 8 | 9 | 
 | 
| 9 |  | -class _MonitoredQueue: | 
| 10 |  | -    def __init__( | 
| 11 |  | -        self, | 
| 12 |  | -        p: mp.Process, | 
| 13 |  | -        q: mp.Queue, | 
| 14 |  | -        poll_interval: timedelta = timedelta(seconds=1), | 
| 15 |  | -    ) -> None: | 
| 16 |  | -        """ | 
| 17 |  | -        Args: | 
| 18 |  | -            p: process to monitor | 
| 19 |  | -            q: queue to monitor | 
| 20 |  | -            poll_interval: interval to poll the Process health when calling get/put | 
| 21 |  | -        """ | 
| 22 |  | -        self._p = p | 
| 23 |  | -        self._q = q | 
| 24 |  | -        self._poll_interval_s: float = poll_interval.total_seconds() | 
|  | 10 | +class _MonitoredPipe: | 
|  | 11 | +    def __init__(self, pipe: "Connection[object, object]") -> None: | 
|  | 12 | +        self._pipe = pipe | 
| 25 | 13 | 
 | 
| 26 |  | -    def get(self, timeout: Union[float, timedelta]) -> object: | 
| 27 |  | -        """ | 
| 28 |  | -        Get an item from the queue. If the process is not alive, raise RuntimeError. | 
| 29 |  | -        If the queue is empty, wait for up to timeout seconds for an item to be | 
| 30 |  | -        available. If no item is available after timeout seconds, raise TimeoutError. | 
| 31 |  | -
 | 
| 32 |  | -        Args: | 
| 33 |  | -            timeout: timeout in seconds | 
| 34 |  | -        """ | 
|  | 14 | +    def send(self, obj: object) -> None: | 
|  | 15 | +        self._pipe.send(obj) | 
| 35 | 16 | 
 | 
|  | 17 | +    def recv(self, timeout: Union[float, timedelta]) -> object: | 
| 36 | 18 |         if isinstance(timeout, timedelta): | 
| 37 | 19 |             timeout = timeout.total_seconds() | 
| 38 |  | - | 
| 39 |  | -        start = time.perf_counter() | 
| 40 |  | -        while True: | 
| 41 |  | -            try: | 
| 42 |  | -                v = self._q.get(timeout=self._poll_interval_s) | 
| 43 |  | -                break | 
| 44 |  | -            except queue.Empty: | 
| 45 |  | -                pass | 
| 46 |  | - | 
| 47 |  | -            elapsed = time.perf_counter() - start | 
| 48 |  | -            if elapsed > timeout: | 
| 49 |  | -                raise TimeoutError(f"queue.get() timed out after {timeout} seconds") | 
| 50 |  | - | 
| 51 |  | -            # polling the process can be slow so we only do it every poll_interval | 
| 52 |  | -            if not self._p.is_alive(): | 
| 53 |  | -                raise RuntimeError(f"process is not alive {self._p.exitcode}") | 
| 54 |  | - | 
| 55 |  | -        if isinstance(v, Exception): | 
| 56 |  | -            raise v | 
| 57 |  | -        return v | 
| 58 |  | - | 
| 59 |  | -    def put(self, obj: object, timeout: Union[float, timedelta]) -> None: | 
| 60 |  | -        """ | 
| 61 |  | -        Put an item into the queue. If the process is not alive, raise RuntimeError. | 
| 62 |  | -        If the queue is full, wait for up to timeout seconds for an item to be | 
| 63 |  | -        available. If queue is full after timeout seconds, raise TimeoutError. | 
| 64 |  | -
 | 
| 65 |  | -        If an exception is put into the queue, it will be raised when calling get(). | 
| 66 |  | -
 | 
| 67 |  | -        Args: | 
| 68 |  | -            obj: object to put into the queue | 
| 69 |  | -            timeout: timeout in seconds | 
| 70 |  | -        """ | 
| 71 |  | -        if isinstance(timeout, timedelta): | 
| 72 |  | -            timeout = timeout.total_seconds() | 
| 73 |  | - | 
| 74 |  | -        start = time.perf_counter() | 
| 75 |  | -        while True: | 
| 76 |  | -            try: | 
| 77 |  | -                self._q.put(obj, timeout=self._poll_interval_s) | 
| 78 |  | -                break | 
| 79 |  | -            except queue.Full: | 
| 80 |  | -                pass | 
| 81 |  | - | 
| 82 |  | -            elapsed = time.perf_counter() - start | 
| 83 |  | -            if elapsed > timeout: | 
| 84 |  | -                raise TimeoutError(f"queue.put() timed out after {timeout} seconds") | 
| 85 |  | - | 
| 86 |  | -            # polling the process can be slow so we only do it every poll_interval | 
| 87 |  | -            if not self._p.is_alive(): | 
| 88 |  | -                raise RuntimeError(f"process is not alive {self._p.exitcode}") | 
|  | 20 | +        if self._pipe.poll(timeout): | 
|  | 21 | +            out = self._pipe.recv() | 
|  | 22 | +            if isinstance(out, Exception): | 
|  | 23 | +                raise out | 
|  | 24 | +            return out | 
|  | 25 | +        else: | 
|  | 26 | +            raise TimeoutError(f"pipe.recv() timed out after {timeout} seconds") | 
| 89 | 27 | 
 | 
| 90 | 28 |     def close(self) -> None: | 
| 91 |  | -        self._q.close() | 
|  | 29 | +        self._pipe.close() | 
| 92 | 30 | 
 | 
| 93 | 31 |     def closed(self) -> bool: | 
| 94 |  | -        # pyre-ignore[16]: no attribute _closed | 
| 95 |  | -        return self._q._closed | 
|  | 32 | +        return self._pipe.closed | 
0 commit comments