Skip to content

Commit 8c22bf9

Browse files
committed
refactor: Refactor domain rate limiting to token bucket algorithm
1 parent 98958db commit 8c22bf9

File tree

2 files changed

+134
-113
lines changed

2 files changed

+134
-113
lines changed

mcim_sync/utils/network/__init__.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44

55
import httpx
6-
import time
76
from typing import Optional, Union
87

98
from tenacity import retry, stop_after_attempt, retry_if_not_exception_type
@@ -63,15 +62,9 @@ def request(
6362
httpx.Response: 请求结果
6463
"""
6564
if not ignore_rate_limit:
66-
# 检查是否可以发起请求
67-
if not domain_rate_limiter.can_make_request(url):
68-
wait_time = domain_rate_limiter.wait_time(url)
69-
if wait_time > 0:
70-
time.sleep(wait_time)
71-
72-
# 记录请求
73-
domain_rate_limiter.record_request(url)
74-
65+
if not domain_rate_limiter.acquire_token(url):
66+
raise TimeoutError(f"Rate limit timeout for {url}")
67+
7568
# 执行实际请求
7669
if params is not None:
7770
params = {k: v for k, v in params.items() if v is not None}
@@ -105,17 +98,4 @@ def request(
10598
params=params,
10699
msg=res.text,
107100
)
108-
return res
109-
110-
111-
def get_domain_status(domain: str) -> dict:
112-
"""
113-
获取域名的限速状态
114-
115-
Args:
116-
domain (str): 域名
117-
118-
Returns:
119-
dict: 限速状态信息
120-
"""
121-
return domain_rate_limiter.get_domain_status(domain)
101+
return res

mcim_sync/utils/rate_limit.py

Lines changed: 130 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,115 +1,156 @@
1-
"""
2-
域名限速器模块
3-
"""
4-
51
import time
62
import threading
7-
from collections import deque, defaultdict
8-
from urllib.parse import urlparse
93
from typing import Dict
4+
from urllib.parse import urlparse
105

116
from mcim_sync.config import Config
127

138

9+
class TokenBucket:
10+
"""令牌桶"""
11+
12+
def __init__(self, capacity: int, refill_rate: float, initial_tokens: int = None):
13+
self.capacity = capacity
14+
self.refill_rate = refill_rate
15+
self.tokens = initial_tokens if initial_tokens is not None else capacity
16+
self.last_refill_time = time.monotonic()
17+
self.condition = threading.Condition()
18+
self.waiting_count = 0
19+
self._start_refill_thread()
20+
21+
def _start_refill_thread(self):
22+
"""启动后台线程定期补充令牌"""
23+
24+
def refill_loop():
25+
while True:
26+
with self.condition:
27+
old_tokens = self.tokens
28+
self._refill()
29+
new_tokens = int(self.tokens - old_tokens)
30+
if new_tokens > 0 and self.waiting_count > 0:
31+
for _ in range(min(new_tokens, self.waiting_count)):
32+
self.condition.notify()
33+
time.sleep(1.0 / self.refill_rate)
34+
35+
thread = threading.Thread(target=refill_loop, daemon=True)
36+
thread.start()
37+
38+
def _refill(self):
39+
"""补充令牌 - 必须在持有锁的情况下调用"""
40+
current_time = time.monotonic()
41+
time_passed = current_time - self.last_refill_time
42+
tokens_to_add = time_passed * self.refill_rate
43+
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
44+
self.last_refill_time = current_time
45+
46+
def acquire(self, tokens: int = 1, timeout: float = None) -> bool:
47+
"""
48+
获取令牌,如果没有则等待
49+
"""
50+
with self.condition:
51+
self._refill()
52+
if self.tokens >= tokens:
53+
self.tokens -= tokens
54+
return True
55+
56+
self.waiting_count += 1
57+
end_time = None if timeout is None else time.monotonic() + timeout
58+
59+
try:
60+
while self.tokens < tokens:
61+
if timeout is None:
62+
self.condition.wait()
63+
else:
64+
remaining = end_time - time.monotonic()
65+
if remaining <= 0 or not self.condition.wait(timeout=remaining):
66+
return False
67+
self._refill()
68+
self.tokens -= tokens
69+
return True
70+
finally:
71+
self.waiting_count -= 1
72+
73+
def get_status(self) -> Dict:
74+
"""获取状态"""
75+
with self.condition:
76+
self._refill()
77+
return {
78+
"capacity": self.capacity,
79+
"current_tokens": self.tokens,
80+
"refill_rate": self.refill_rate,
81+
"waiting_requests": self.waiting_count,
82+
"utilization": (self.capacity - self.tokens) / self.capacity,
83+
}
84+
85+
1486
class DomainRateLimiter:
15-
"""简单的域名限速器"""
16-
87+
"""基于令牌桶的域名限速器"""
88+
1789
def __init__(self):
1890
self.domain_rate_limits_config = Config.load().domain_rate_limits
19-
self.domain_requests: Dict[str, deque] = defaultdict(deque)
20-
self.locks: Dict[str, threading.Lock] = defaultdict(threading.Lock)
21-
91+
self.token_buckets: Dict[str, TokenBucket] = {}
92+
self.lock = threading.Lock()
93+
2294
def get_domain_from_url(self, url: str) -> str:
2395
"""从URL中提取域名"""
2496
try:
2597
parsed = urlparse(url)
26-
return parsed.netloc.lower()
98+
return parsed.hostname.lower() if parsed.hostname else "unknown"
2799
except Exception:
28100
return "unknown"
29-
30-
def can_make_request(self, url: str) -> bool:
31-
"""检查是否可以向指定URL发起请求"""
101+
102+
def _get_token_bucket(self, domain: str) -> TokenBucket:
103+
"""获取域名对应的令牌桶"""
104+
with self.lock:
105+
bucket = self.token_buckets.get(domain)
106+
if bucket is not None:
107+
return bucket
108+
109+
config = self.domain_rate_limits_config.get(domain)
110+
if config is None:
111+
raise ValueError(f"Domain '{domain}' not configured in rate limiter")
112+
113+
bucket = TokenBucket(
114+
capacity=config.capacity,
115+
refill_rate=config.refill_rate,
116+
initial_tokens=config.initial_tokens,
117+
)
118+
self.token_buckets[domain] = bucket
119+
return bucket
120+
121+
def acquire_token(self, url: str, timeout: float = None) -> bool:
122+
"""
123+
获取令牌,如果没有则等待
124+
"""
32125
domain = self.get_domain_from_url(url)
33-
34-
# 如果域名没有配置限速,则允许请求
126+
35127
if domain not in self.domain_rate_limits_config:
36128
return True
37-
38-
domain_config = self.domain_rate_limits_config[domain]
39-
current_time = time.time()
40-
41-
with self.locks[domain]:
42-
requests = self.domain_requests[domain]
43-
44-
# 清理超出时间窗口的请求记录
45-
while requests and current_time - requests[0] > domain_config.time_window:
46-
requests.popleft()
47-
48-
# 检查是否超过最大请求数
49-
return len(requests) < domain_config.max_requests
50-
51-
def record_request(self, url: str):
52-
"""记录请求"""
53-
domain = self.get_domain_from_url(url)
54-
55-
# 如果域名没有配置限速,则不记录
56-
if domain not in self.domain_rate_limits_config:
57-
return
58-
59-
current_time = time.time()
60-
61-
with self.locks[domain]:
62-
self.domain_requests[domain].append(current_time)
63-
64-
def wait_time(self, url: str) -> float:
65-
"""计算需要等待的时间"""
66-
domain = self.get_domain_from_url(url)
67-
68-
# 如果域名没有配置限速,则不需要等待
69-
if domain not in self.domain_rate_limits_config:
70-
return 0.0
71-
72-
domain_config = self.domain_rate_limits_config[domain]
73-
current_time = time.time()
74-
75-
with self.locks[domain]:
76-
requests = self.domain_requests[domain]
77-
78-
# 清理超出时间窗口的请求记录
79-
while requests and current_time - requests[0] > domain_config.time_window:
80-
requests.popleft()
81-
82-
# 如果请求数已满,计算等待时间
83-
if len(requests) >= domain_config.max_requests:
84-
oldest_request = requests[0]
85-
return domain_config.time_window - (current_time - oldest_request)
86-
87-
return 0.0
88-
129+
130+
try:
131+
bucket = self._get_token_bucket(domain)
132+
except ValueError:
133+
return True # fallback for dynamic change
134+
135+
return bucket.acquire(timeout=timeout)
136+
89137
def get_domain_status(self, domain: str) -> Dict:
90138
"""获取域名的限速状态"""
91139
if domain not in self.domain_rate_limits_config:
92140
return {"configured": False}
93-
94-
domain_config = self.domain_rate_limits_config[domain]
95-
current_time = time.time()
96-
97-
with self.locks[domain]:
98-
requests = self.domain_requests[domain]
99-
100-
# 清理超出时间窗口的请求记录
101-
while requests and current_time - requests[0] > domain_config.time_window:
102-
requests.popleft()
103-
104-
return {
105-
"configured": True,
106-
"max_requests": domain_config.max_requests,
107-
"time_window": domain_config.time_window,
108-
"current_requests": len(requests),
109-
"remaining_requests": domain_config.max_requests - len(requests),
110-
"next_reset_time": requests[0] + domain_config.time_window if requests else current_time
111-
}
112141

142+
bucket = self._get_token_bucket(domain)
143+
status = bucket.get_status()
144+
config = self.domain_rate_limits_config[domain]
145+
146+
return {
147+
"configured": True,
148+
"algorithm": "token_bucket",
149+
"capacity": config.capacity,
150+
"refill_rate": config.refill_rate,
151+
"current_tokens": status["current_tokens"],
152+
"waiting_requests": status["waiting_requests"],
153+
"utilization": status["utilization"],
154+
}
113155

114-
# 全局限速器实例
115156
domain_rate_limiter = DomainRateLimiter()

0 commit comments

Comments
 (0)