|
1 |
| -""" |
2 |
| -域名限速器模块 |
3 |
| -""" |
4 |
| - |
5 | 1 | import time
|
6 | 2 | import threading
|
7 |
| -from collections import deque, defaultdict |
8 |
| -from urllib.parse import urlparse |
9 | 3 | from typing import Dict
|
| 4 | +from urllib.parse import urlparse |
10 | 5 |
|
11 | 6 | from mcim_sync.config import Config
|
12 | 7 |
|
13 | 8 |
|
| 9 | +class TokenBucket: |
| 10 | + """令牌桶""" |
| 11 | + |
| 12 | + def __init__(self, capacity: int, refill_rate: float, initial_tokens: int = None): |
| 13 | + self.capacity = capacity |
| 14 | + self.refill_rate = refill_rate |
| 15 | + self.tokens = initial_tokens if initial_tokens is not None else capacity |
| 16 | + self.last_refill_time = time.monotonic() |
| 17 | + self.condition = threading.Condition() |
| 18 | + self.waiting_count = 0 |
| 19 | + self._start_refill_thread() |
| 20 | + |
| 21 | + def _start_refill_thread(self): |
| 22 | + """启动后台线程定期补充令牌""" |
| 23 | + |
| 24 | + def refill_loop(): |
| 25 | + while True: |
| 26 | + with self.condition: |
| 27 | + old_tokens = self.tokens |
| 28 | + self._refill() |
| 29 | + new_tokens = int(self.tokens - old_tokens) |
| 30 | + if new_tokens > 0 and self.waiting_count > 0: |
| 31 | + for _ in range(min(new_tokens, self.waiting_count)): |
| 32 | + self.condition.notify() |
| 33 | + time.sleep(1.0 / self.refill_rate) |
| 34 | + |
| 35 | + thread = threading.Thread(target=refill_loop, daemon=True) |
| 36 | + thread.start() |
| 37 | + |
| 38 | + def _refill(self): |
| 39 | + """补充令牌 - 必须在持有锁的情况下调用""" |
| 40 | + current_time = time.monotonic() |
| 41 | + time_passed = current_time - self.last_refill_time |
| 42 | + tokens_to_add = time_passed * self.refill_rate |
| 43 | + self.tokens = min(self.capacity, self.tokens + tokens_to_add) |
| 44 | + self.last_refill_time = current_time |
| 45 | + |
| 46 | + def acquire(self, tokens: int = 1, timeout: float = None) -> bool: |
| 47 | + """ |
| 48 | + 获取令牌,如果没有则等待 |
| 49 | + """ |
| 50 | + with self.condition: |
| 51 | + self._refill() |
| 52 | + if self.tokens >= tokens: |
| 53 | + self.tokens -= tokens |
| 54 | + return True |
| 55 | + |
| 56 | + self.waiting_count += 1 |
| 57 | + end_time = None if timeout is None else time.monotonic() + timeout |
| 58 | + |
| 59 | + try: |
| 60 | + while self.tokens < tokens: |
| 61 | + if timeout is None: |
| 62 | + self.condition.wait() |
| 63 | + else: |
| 64 | + remaining = end_time - time.monotonic() |
| 65 | + if remaining <= 0 or not self.condition.wait(timeout=remaining): |
| 66 | + return False |
| 67 | + self._refill() |
| 68 | + self.tokens -= tokens |
| 69 | + return True |
| 70 | + finally: |
| 71 | + self.waiting_count -= 1 |
| 72 | + |
| 73 | + def get_status(self) -> Dict: |
| 74 | + """获取状态""" |
| 75 | + with self.condition: |
| 76 | + self._refill() |
| 77 | + return { |
| 78 | + "capacity": self.capacity, |
| 79 | + "current_tokens": self.tokens, |
| 80 | + "refill_rate": self.refill_rate, |
| 81 | + "waiting_requests": self.waiting_count, |
| 82 | + "utilization": (self.capacity - self.tokens) / self.capacity, |
| 83 | + } |
| 84 | + |
| 85 | + |
14 | 86 | class DomainRateLimiter:
|
15 |
| - """简单的域名限速器""" |
16 |
| - |
| 87 | + """基于令牌桶的域名限速器""" |
| 88 | + |
17 | 89 | def __init__(self):
|
18 | 90 | self.domain_rate_limits_config = Config.load().domain_rate_limits
|
19 |
| - self.domain_requests: Dict[str, deque] = defaultdict(deque) |
20 |
| - self.locks: Dict[str, threading.Lock] = defaultdict(threading.Lock) |
21 |
| - |
| 91 | + self.token_buckets: Dict[str, TokenBucket] = {} |
| 92 | + self.lock = threading.Lock() |
| 93 | + |
22 | 94 | def get_domain_from_url(self, url: str) -> str:
|
23 | 95 | """从URL中提取域名"""
|
24 | 96 | try:
|
25 | 97 | parsed = urlparse(url)
|
26 |
| - return parsed.netloc.lower() |
| 98 | + return parsed.hostname.lower() if parsed.hostname else "unknown" |
27 | 99 | except Exception:
|
28 | 100 | return "unknown"
|
29 |
| - |
30 |
| - def can_make_request(self, url: str) -> bool: |
31 |
| - """检查是否可以向指定URL发起请求""" |
| 101 | + |
| 102 | + def _get_token_bucket(self, domain: str) -> TokenBucket: |
| 103 | + """获取域名对应的令牌桶""" |
| 104 | + with self.lock: |
| 105 | + bucket = self.token_buckets.get(domain) |
| 106 | + if bucket is not None: |
| 107 | + return bucket |
| 108 | + |
| 109 | + config = self.domain_rate_limits_config.get(domain) |
| 110 | + if config is None: |
| 111 | + raise ValueError(f"Domain '{domain}' not configured in rate limiter") |
| 112 | + |
| 113 | + bucket = TokenBucket( |
| 114 | + capacity=config.capacity, |
| 115 | + refill_rate=config.refill_rate, |
| 116 | + initial_tokens=config.initial_tokens, |
| 117 | + ) |
| 118 | + self.token_buckets[domain] = bucket |
| 119 | + return bucket |
| 120 | + |
| 121 | + def acquire_token(self, url: str, timeout: float = None) -> bool: |
| 122 | + """ |
| 123 | + 获取令牌,如果没有则等待 |
| 124 | + """ |
32 | 125 | domain = self.get_domain_from_url(url)
|
33 |
| - |
34 |
| - # 如果域名没有配置限速,则允许请求 |
| 126 | + |
35 | 127 | if domain not in self.domain_rate_limits_config:
|
36 | 128 | return True
|
37 |
| - |
38 |
| - domain_config = self.domain_rate_limits_config[domain] |
39 |
| - current_time = time.time() |
40 |
| - |
41 |
| - with self.locks[domain]: |
42 |
| - requests = self.domain_requests[domain] |
43 |
| - |
44 |
| - # 清理超出时间窗口的请求记录 |
45 |
| - while requests and current_time - requests[0] > domain_config.time_window: |
46 |
| - requests.popleft() |
47 |
| - |
48 |
| - # 检查是否超过最大请求数 |
49 |
| - return len(requests) < domain_config.max_requests |
50 |
| - |
51 |
| - def record_request(self, url: str): |
52 |
| - """记录请求""" |
53 |
| - domain = self.get_domain_from_url(url) |
54 |
| - |
55 |
| - # 如果域名没有配置限速,则不记录 |
56 |
| - if domain not in self.domain_rate_limits_config: |
57 |
| - return |
58 |
| - |
59 |
| - current_time = time.time() |
60 |
| - |
61 |
| - with self.locks[domain]: |
62 |
| - self.domain_requests[domain].append(current_time) |
63 |
| - |
64 |
| - def wait_time(self, url: str) -> float: |
65 |
| - """计算需要等待的时间""" |
66 |
| - domain = self.get_domain_from_url(url) |
67 |
| - |
68 |
| - # 如果域名没有配置限速,则不需要等待 |
69 |
| - if domain not in self.domain_rate_limits_config: |
70 |
| - return 0.0 |
71 |
| - |
72 |
| - domain_config = self.domain_rate_limits_config[domain] |
73 |
| - current_time = time.time() |
74 |
| - |
75 |
| - with self.locks[domain]: |
76 |
| - requests = self.domain_requests[domain] |
77 |
| - |
78 |
| - # 清理超出时间窗口的请求记录 |
79 |
| - while requests and current_time - requests[0] > domain_config.time_window: |
80 |
| - requests.popleft() |
81 |
| - |
82 |
| - # 如果请求数已满,计算等待时间 |
83 |
| - if len(requests) >= domain_config.max_requests: |
84 |
| - oldest_request = requests[0] |
85 |
| - return domain_config.time_window - (current_time - oldest_request) |
86 |
| - |
87 |
| - return 0.0 |
88 |
| - |
| 129 | + |
| 130 | + try: |
| 131 | + bucket = self._get_token_bucket(domain) |
| 132 | + except ValueError: |
| 133 | + return True # fallback for dynamic change |
| 134 | + |
| 135 | + return bucket.acquire(timeout=timeout) |
| 136 | + |
89 | 137 | def get_domain_status(self, domain: str) -> Dict:
|
90 | 138 | """获取域名的限速状态"""
|
91 | 139 | if domain not in self.domain_rate_limits_config:
|
92 | 140 | return {"configured": False}
|
93 |
| - |
94 |
| - domain_config = self.domain_rate_limits_config[domain] |
95 |
| - current_time = time.time() |
96 |
| - |
97 |
| - with self.locks[domain]: |
98 |
| - requests = self.domain_requests[domain] |
99 |
| - |
100 |
| - # 清理超出时间窗口的请求记录 |
101 |
| - while requests and current_time - requests[0] > domain_config.time_window: |
102 |
| - requests.popleft() |
103 |
| - |
104 |
| - return { |
105 |
| - "configured": True, |
106 |
| - "max_requests": domain_config.max_requests, |
107 |
| - "time_window": domain_config.time_window, |
108 |
| - "current_requests": len(requests), |
109 |
| - "remaining_requests": domain_config.max_requests - len(requests), |
110 |
| - "next_reset_time": requests[0] + domain_config.time_window if requests else current_time |
111 |
| - } |
112 | 141 |
|
| 142 | + bucket = self._get_token_bucket(domain) |
| 143 | + status = bucket.get_status() |
| 144 | + config = self.domain_rate_limits_config[domain] |
| 145 | + |
| 146 | + return { |
| 147 | + "configured": True, |
| 148 | + "algorithm": "token_bucket", |
| 149 | + "capacity": config.capacity, |
| 150 | + "refill_rate": config.refill_rate, |
| 151 | + "current_tokens": status["current_tokens"], |
| 152 | + "waiting_requests": status["waiting_requests"], |
| 153 | + "utilization": status["utilization"], |
| 154 | + } |
113 | 155 |
|
114 |
| -# 全局限速器实例 |
115 | 156 | domain_rate_limiter = DomainRateLimiter()
|
0 commit comments