|
18 | 18 |
|
19 | 19 | import asyncio
|
20 | 20 | import logging
|
21 |
| -from typing import Optional, Type |
| 21 | +from typing import Optional |
22 | 22 |
|
23 |
| -from .transport import TCP, TCPAbridged |
| 23 | +from .transport import * |
24 | 24 | from ..session.internals import DataCenter
|
25 | 25 |
|
26 | 26 | log = logging.getLogger(__name__)
|
27 | 27 |
|
28 | 28 |
|
29 | 29 | class Connection:
|
30 |
| - MAX_CONNECTION_ATTEMPTS = 3 |
| 30 | + MAX_RETRIES = 3 |
31 | 31 |
|
32 |
| - def __init__( |
33 |
| - self, |
34 |
| - dc_id: int, |
35 |
| - test_mode: bool, |
36 |
| - ipv6: bool, |
37 |
| - proxy: dict, |
38 |
| - media: bool = False, |
39 |
| - protocol_factory: Type[TCP] = TCPAbridged |
40 |
| - ) -> None: |
| 32 | + MODES = { |
| 33 | + 0: TCPFull, |
| 34 | + 1: TCPAbridged, |
| 35 | + 2: TCPIntermediate, |
| 36 | + 3: TCPAbridgedO, |
| 37 | + 4: TCPIntermediateO |
| 38 | + } |
| 39 | + |
| 40 | + def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3): |
41 | 41 | self.dc_id = dc_id
|
42 | 42 | self.test_mode = test_mode
|
43 | 43 | self.ipv6 = ipv6
|
44 | 44 | self.proxy = proxy
|
45 | 45 | self.media = media
|
46 |
| - self.protocol_factory = protocol_factory |
47 |
| - |
48 | 46 | self.address = DataCenter(dc_id, test_mode, ipv6, media)
|
49 |
| - self.protocol: Optional[TCP] = None |
| 47 | + self.mode = self.MODES.get(mode, TCPAbridged) |
| 48 | + |
| 49 | + self.protocol = None # type: TCP |
50 | 50 |
|
51 |
| - async def connect(self) -> None: |
52 |
| - for i in range(Connection.MAX_CONNECTION_ATTEMPTS): |
53 |
| - self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy) |
| 51 | + async def connect(self): |
| 52 | + for i in range(Connection.MAX_RETRIES): |
| 53 | + self.protocol = self.mode(self.ipv6, self.proxy) |
54 | 54 |
|
55 | 55 | try:
|
56 | 56 | log.info("Connecting...")
|
57 | 57 | await self.protocol.connect(self.address)
|
58 | 58 | except OSError as e:
|
59 |
| - log.warning("Unable to connect due to network issues: %s", e) |
60 |
| - await self.protocol.close() |
| 59 | + log.warning(f"Unable to connect due to network issues: {e}") |
| 60 | + self.protocol.close() |
61 | 61 | await asyncio.sleep(1)
|
62 | 62 | else:
|
63 |
| - log.info("Connected! %s DC%s%s - IPv%s", |
64 |
| - "Test" if self.test_mode else "Production", |
65 |
| - self.dc_id, |
66 |
| - " (media)" if self.media else "", |
67 |
| - "6" if self.ipv6 else "4") |
| 63 | + log.info("Connected! {} DC{}{} - IPv{} - {}".format( |
| 64 | + "Test" if self.test_mode else "Production", |
| 65 | + self.dc_id, |
| 66 | + " (media)" if self.media else "", |
| 67 | + "6" if self.ipv6 else "4", |
| 68 | + self.mode.__name__, |
| 69 | + )) |
68 | 70 | break
|
69 | 71 | else:
|
70 | 72 | log.warning("Connection failed! Trying again...")
|
71 |
| - raise ConnectionError |
| 73 | + raise TimeoutError |
72 | 74 |
|
73 |
| - async def close(self) -> None: |
74 |
| - await self.protocol.close() |
| 75 | + def close(self): |
| 76 | + self.protocol.close() |
75 | 77 | log.info("Disconnected")
|
76 | 78 |
|
77 |
| - async def send(self, data: bytes) -> None: |
78 |
| - await self.protocol.send(data) |
| 79 | + async def send(self, data: bytes): |
| 80 | + try: |
| 81 | + await self.protocol.send(data) |
| 82 | + except Exception as e: |
| 83 | + raise OSError(e) |
79 | 84 |
|
80 | 85 | async def recv(self) -> Optional[bytes]:
|
81 | 86 | return await self.protocol.recv()
|
0 commit comments