Skip to content

Commit 5a551fe

Browse files
committed
feat: Increase API concurrency control to avoid service downtime
1 parent 4f8d897 commit 5a551fe

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

mineru/cli/fast_api.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import zipfile
99
from pathlib import Path
1010
import glob
11-
from fastapi import FastAPI, UploadFile, File, Form
11+
from fastapi import Depends, FastAPI, HTTPException, UploadFile, File, Form
1212
from fastapi.middleware.gzip import GZipMiddleware
1313
from fastapi.responses import JSONResponse, FileResponse
1414
from starlette.background import BackgroundTask
@@ -21,7 +21,25 @@
2121
from mineru.utils.guess_suffix_or_lang import guess_suffix_by_path
2222
from mineru.version import __version__
2323

24-
app = FastAPI()
24+
# Concurrency controller (None means no request limit is enforced)
25+
_request_semaphore: Optional[asyncio.Semaphore] = None
26+
27+
28+
# Dependency function that enforces the concurrency limit
29+
async def limit_concurrency():
    """Gate a request on the module-level ``_request_semaphore``.

    Written as an async generator so it can be attached to a route through
    ``Depends``: everything before the ``yield`` runs ahead of the guarded
    code, and the semaphore slot is released when the generator is resumed
    or closed.

    Behaviour:
        * ``_request_semaphore`` is ``None``: no limit is applied.
        * The semaphore is already locked: the request is rejected
          immediately rather than queued.
        * Otherwise: one slot is held for the duration of the ``yield``.

    Raises:
        HTTPException: with status 503 when the semaphore is locked.
    """
    limiter = _request_semaphore

    # No limiter configured: let the request through untouched.
    if limiter is None:
        yield
        return

    # Fail fast instead of waiting for a free slot.
    if limiter.locked():
        raise HTTPException(
            status_code=503,
            detail="Server is at maximum capacity. Please try again later.",
        )

    async with limiter:
        yield
40+
41+
42+
app = FastAPI(openapi_url=None, docs_url=None, redoc_url=None)
2543
app.add_middleware(GZipMiddleware, minimum_size=1000)
2644

2745

@@ -60,7 +78,7 @@ def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str)
6078
return None
6179

6280

63-
@app.post(path="/file_parse",)
81+
@app.post(path="/file_parse", dependencies=[Depends(limit_concurrency)])
6482
async def parse_pdf(
6583
files: List[UploadFile] = File(...),
6684
output_dir: str = Form("./output"),
@@ -256,6 +274,14 @@ def main(ctx, host, port, reload, **kwargs):
256274

257275
kwargs.update(arg_parse(ctx))
258276

277+
# 初始化并发控制器
278+
global _request_semaphore
279+
max_concurrent_requests = int(kwargs.get("max_concurrent_requests", 0))
280+
if max_concurrent_requests > 0:
281+
_request_semaphore = asyncio.Semaphore(max_concurrent_requests)
282+
logger.info(f"Request concurrency limited to {max_concurrent_requests}")
283+
284+
259285
# 将配置参数存储到应用状态中
260286
app.state.config = kwargs
261287

0 commit comments

Comments
 (0)