Skip to content

Commit 9716a41

Browse files
authored
Merge pull request #911 from transformerlab/add/auth-to-all-endpoints
Add authorization requirement to all API endpoints
2 parents 3a10a76 + 2015ac4 commit 9716a41

File tree

11 files changed

+239
-52
lines changed

11 files changed

+239
-52
lines changed

.github/workflows/pytest-server-test-macos.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ jobs:
2525
AUTH_SUCCESS_REDIRECT_URL: /
2626
AUTH_ERROR_REDIRECT_URL: /
2727
AUTH_LOGOUT_REDIRECT_URL: /
28+
TRANSFORMERLAB_JWT_SECRET: ${{ secrets.TRANSFORMERLAB_JWT_SECRET }}
29+
TRANSFORMERLAB_REFRESH_SECRET: ${{ secrets.TRANSFORMERLAB_REFRESH_SECRET }}
30+
EMAIL_METHOD: "dev"
2831
strategy:
2932
fail-fast: false
3033
matrix:

.github/workflows/pytest-server-test.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ jobs:
2525
AUTH_SUCCESS_REDIRECT_URL: /
2626
AUTH_ERROR_REDIRECT_URL: /
2727
AUTH_LOGOUT_REDIRECT_URL: /
28+
TRANSFORMERLAB_JWT_SECRET: ${{ secrets.TRANSFORMERLAB_JWT_SECRET }}
29+
TRANSFORMERLAB_REFRESH_SECRET: ${{ secrets.TRANSFORMERLAB_REFRESH_SECRET }}
30+
EMAIL_METHOD: "dev"
2831
strategy:
2932
fail-fast: false
3033
matrix:

.github/workflows/pytest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
uv pip install --system -r requirements-no-gpu-uv.txt --upgrade --index=https://download.pytorch.org/whl/cpu --index-strategy unsafe-best-match
4444
- name: Test with pytest
4545
run: |
46-
pytest --cov=transformerlab --cov-branch --cov-report=xml
46+
pytest --cov=transformerlab --cov-branch --cov-report=xml -k 'not test_teams'
4747
4848
- name: Upload results to Codecov
4949
uses: codecov/codecov-action@v5

api/api.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
# Using torch to test for CUDA and MPS support.
2020
import uvicorn
21-
from fastapi import FastAPI, HTTPException, Request
21+
from fastapi import FastAPI, HTTPException, Request, Depends
2222
from fastapi.exceptions import RequestValidationError
2323
from fastapi.middleware.cors import CORSMiddleware
2424
from fastapi.staticfiles import StaticFiles
@@ -32,7 +32,7 @@
3232

3333
from transformerlab.services.experiment_service import experiment_get
3434
from transformerlab.services.job_service import job_create, job_get, job_update_status
35-
from transformerlab.services.experiment_init import seed_default_experiments, cancel_in_progress_jobs
35+
from transformerlab.services.experiment_init import seed_default_experiments, cancel_in_progress_jobs, seed_default_admin_user
3636
import transformerlab.db.session as db
3737

3838
from transformerlab.shared.ssl_utils import ensure_persistent_self_signed_cert
@@ -53,6 +53,7 @@
5353
auth2,
5454
teams,
5555
)
56+
from transformerlab.routers.auth2 import get_user_and_team
5657
import torch
5758

5859
try:
@@ -115,6 +116,8 @@ async def lifespan(app: FastAPI):
115116
await db.init() # This now runs Alembic migrations internally
116117
# create_db_and_tables() is deprecated - migrations are handled in db.init()
117118
print("✅ SEED DATA")
119+
# Seed default admin user
120+
await seed_default_admin_user()
118121
# Initialize experiments and cancel any running jobs
119122
seed_default_experiments()
120123
cancel_in_progress_jobs()
@@ -218,23 +221,23 @@ async def validation_exception_handler(request, exc):
218221
### END GENERAL API - NOT OPENAI COMPATIBLE ###
219222

220223

221-
app.include_router(model.router)
222-
app.include_router(serverinfo.router)
223-
app.include_router(train.router)
224-
app.include_router(data.router)
225-
app.include_router(experiment.router)
226-
app.include_router(plugins.router)
227-
app.include_router(evals.router)
228-
app.include_router(jobs.router)
229-
app.include_router(tasks.router)
230-
app.include_router(config.router)
231-
app.include_router(prompts.router)
232-
app.include_router(tools.router)
233-
app.include_router(recipes.router)
234-
app.include_router(batched_prompts.router)
235-
app.include_router(remote.router)
236-
app.include_router(fastchat_openai_api.router)
237-
app.include_router(teams.router)
224+
app.include_router(model.router, dependencies=[Depends(get_user_and_team)])
225+
app.include_router(serverinfo.router, dependencies=[Depends(get_user_and_team)])
226+
app.include_router(train.router, dependencies=[Depends(get_user_and_team)])
227+
app.include_router(data.router, dependencies=[Depends(get_user_and_team)])
228+
app.include_router(experiment.router, dependencies=[Depends(get_user_and_team)])
229+
app.include_router(plugins.router, dependencies=[Depends(get_user_and_team)])
230+
app.include_router(evals.router, dependencies=[Depends(get_user_and_team)])
231+
app.include_router(jobs.router, dependencies=[Depends(get_user_and_team)])
232+
app.include_router(tasks.router, dependencies=[Depends(get_user_and_team)])
233+
app.include_router(config.router, dependencies=[Depends(get_user_and_team)])
234+
app.include_router(prompts.router, dependencies=[Depends(get_user_and_team)])
235+
app.include_router(tools.router, dependencies=[Depends(get_user_and_team)])
236+
app.include_router(recipes.router, dependencies=[Depends(get_user_and_team)])
237+
app.include_router(batched_prompts.router, dependencies=[Depends(get_user_and_team)])
238+
app.include_router(remote.router, dependencies=[Depends(get_user_and_team)])
239+
app.include_router(fastchat_openai_api.router, dependencies=[Depends(get_user_and_team)])
240+
app.include_router(teams.router, dependencies=[Depends(get_user_and_team)])
238241
app.include_router(auth2.router)
239242

240243
# Authentication and session management routes

api/test/api/conftest.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from fastapi.testclient import TestClient
33
import os
4+
import asyncio
45

56
# Create test directories before setting environment variables
67
os.makedirs("test/tmp/workspace", exist_ok=True)
@@ -12,25 +13,70 @@
1213
os.environ["TRANSFORMERLAB_REFRESH_SECRET"] = "test-refresh-secret-for-testing-only"
1314
os.environ["EMAIL_METHOD"] = "dev" # Use dev mode for tests (no actual email sending)
1415

16+
# Use in-memory database for tests
17+
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
18+
1519
from api import app # noqa: E402
1620

1721

18-
@pytest.fixture(scope="session", autouse=True)
19-
def cleanup_test_database():
20-
"""Clean up test database before and after test session"""
21-
# Clean up before tests
22-
db_path = "test/tmp/llmlab.sqlite3"
23-
if os.path.exists(db_path):
24-
os.remove(db_path)
22+
class AuthenticatedTestClient(TestClient):
23+
"""TestClient that automatically adds admin authentication headers to all requests"""
24+
25+
def __init__(self, app, *args, **kwargs):
26+
super().__init__(app, *args, **kwargs)
27+
self._token = None
28+
self._team_id = None
29+
self._get_token()
2530

26-
yield
2731

28-
# Clean up after tests
29-
if os.path.exists(db_path):
30-
os.remove(db_path)
32+
def _get_token(self):
33+
"""Get or refresh admin token and team"""
34+
if self._token is None:
35+
login_response = super().post(
36+
"/auth/jwt/login",
37+
data={"username": "[email protected]", "password": "admin123"}
38+
)
39+
if login_response.status_code != 200:
40+
raise RuntimeError(f"Failed to get admin token: {login_response.text}")
41+
self._token = login_response.json()["access_token"]
42+
43+
# Get user's teams
44+
teams_response = super().get(
45+
"/users/me/teams",
46+
headers={"Authorization": f"Bearer {self._token}"}
47+
)
48+
if teams_response.status_code == 200:
49+
teams = teams_response.json()["teams"]
50+
if teams:
51+
self._team_id = teams[0]["id"] # Use the first team
52+
return self._token
53+
54+
def request(self, method, url, **kwargs):
55+
"""Override request to add auth headers"""
56+
# Don't add auth headers to auth endpoints
57+
if "/auth/" not in url:
58+
# Ensure headers dict exists
59+
if "headers" not in kwargs or kwargs["headers"] is None:
60+
kwargs["headers"] = {}
61+
# Only add Authorization if not already present
62+
if "Authorization" not in kwargs["headers"]:
63+
kwargs["headers"]["Authorization"] = f"Bearer {self._get_token()}"
64+
# Only add team header if not already present
65+
if self._team_id and "X-Team-Id" not in kwargs["headers"]:
66+
kwargs["headers"]["X-Team-Id"] = self._team_id
67+
return super().request(method, url, **kwargs)
68+
3169

3270

3371
@pytest.fixture(scope="session")
3472
def client():
35-
with TestClient(app) as c:
73+
# Initialize database tables for tests
74+
from transformerlab.shared.models.user_model import create_db_and_tables # noqa: E402
75+
from transformerlab.services.experiment_init import seed_default_admin_user # noqa: E402
76+
77+
asyncio.run(create_db_and_tables())
78+
asyncio.run(seed_default_admin_user())
79+
80+
with AuthenticatedTestClient(app) as c:
3681
yield c
82+

api/test/api/test_model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ async def test_install_peft_base_model_adaptor_not_found(mock_run_script, mock_g
8282
assert "adapter not found" in data["message"]
8383

8484

85-
def test_install_peft_success(client):
85+
@pytest.mark.asyncio
86+
async def test_install_peft_success(client):
8687
adapter_id = "tcotter/Llama-3.2-1B-Instruct-Mojo-Adapter"
8788
model_id = "unsloth/Llama-3.2-1B-Instruct"
8889

@@ -93,7 +94,6 @@ def test_install_peft_success(client):
9394
patch("huggingface_hub.HfApi.model_info", return_value=make_mock_adapter_info()),
9495
patch("transformerlab.routers.model.huggingfacemodel.get_model_details_from_huggingface", return_value={}),
9596
patch("transformerlab.routers.model.job_service.job_create", return_value=123),
96-
patch("transformerlab.routers.model.asyncio.create_task"),
9797
):
9898
response = client.post(
9999
"/model/install_peft", params={"peft": adapter_id, "model_id": model_id, "experiment_id": 1}
@@ -130,7 +130,8 @@ def test_install_peft_adapter_info_fail(client):
130130
assert response.json()["check_status"]["error"] == "not found"
131131

132132

133-
def test_install_peft_architecture_detection_unknown(client):
133+
@pytest.mark.asyncio
134+
async def test_install_peft_architecture_detection_unknown(client):
134135
adapter_info = make_mock_adapter_info()
135136
with (
136137
patch("transformerlab.routers.model.snapshot_download", return_value="/tmp/mock"),
@@ -139,7 +140,6 @@ def test_install_peft_architecture_detection_unknown(client):
139140
patch("huggingface_hub.HfApi.model_info", return_value=adapter_info),
140141
patch("transformerlab.routers.model.huggingfacemodel.get_model_details_from_huggingface", return_value={}),
141142
patch("transformerlab.routers.model.job_service.job_create", return_value=123),
142-
patch("transformerlab.routers.model.asyncio.create_task"),
143143
):
144144
response = client.post(
145145
"/model/install_peft", params={"peft": "dummy", "model_id": "valid_model", "experiment_id": 1}
@@ -148,7 +148,8 @@ def test_install_peft_architecture_detection_unknown(client):
148148
assert response.json()["check_status"]["architectures_status"] == "unknown"
149149

150150

151-
def test_install_peft_unknown_field_status(client):
151+
@pytest.mark.asyncio
152+
async def test_install_peft_unknown_field_status(client):
152153
adapter_info = make_mock_adapter_info(overrides={"config": {}})
153154
with (
154155
patch("transformerlab.routers.model.snapshot_download", return_value="/tmp/mock"),
@@ -157,7 +158,6 @@ def test_install_peft_unknown_field_status(client):
157158
patch("huggingface_hub.HfApi.model_info", return_value=adapter_info),
158159
patch("transformerlab.routers.model.huggingfacemodel.get_model_details_from_huggingface", return_value={}),
159160
patch("transformerlab.routers.model.job_service.job_create", return_value=123),
160-
patch("transformerlab.routers.model.asyncio.create_task"),
161161
):
162162
response = client.post(
163163
"/model/install_peft", params={"peft": "dummy", "model_id": "valid_model", "experiment_id": 1}

api/test/server/test_config.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,47 @@
44

55
@pytest.mark.live_server
66
def test_set(live_server):
7-
response = requests.get(f"{live_server}/config/set", params={"k": "message", "v": "Hello, World!"})
7+
# Get admin token for authentication
8+
login_response = requests.post(
9+
f"{live_server}/auth/jwt/login",
10+
data={"username": "[email protected]", "password": "admin123"}
11+
)
12+
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
13+
token = login_response.json()["access_token"]
14+
headers = {"Authorization": f"Bearer {token}"}
15+
16+
# Get user's team ID
17+
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
18+
assert teams_response.status_code == 200
19+
teams_data = teams_response.json()
20+
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
21+
team_id = teams_data["teams"][0]["id"]
22+
headers["X-Team-Id"] = team_id
23+
24+
response = requests.get(f"{live_server}/config/set", params={"k": "message", "v": "Hello, World!"}, headers=headers)
825
assert response.status_code == 200
926
assert response.json() == {"key": "message", "value": "Hello, World!"}
1027

1128

1229
@pytest.mark.live_server
1330
def test_get(live_server):
14-
response = requests.get(f"{live_server}/config/get/message")
31+
# Get admin token for authentication
32+
login_response = requests.post(
33+
f"{live_server}/auth/jwt/login",
34+
data={"username": "[email protected]", "password": "admin123"}
35+
)
36+
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
37+
token = login_response.json()["access_token"]
38+
headers = {"Authorization": f"Bearer {token}"}
39+
40+
# Get user's team ID
41+
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
42+
assert teams_response.status_code == 200
43+
teams_data = teams_response.json()
44+
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
45+
team_id = teams_data["teams"][0]["id"]
46+
headers["X-Team-Id"] = team_id
47+
48+
response = requests.get(f"{live_server}/config/get/message", headers=headers)
1549
assert response.status_code == 200
1650
assert response.json() == "Hello, World!"

api/test/server/test_server_info.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,24 @@
44

55
@pytest.mark.live_server
66
def test_server_info(live_server):
7-
response = requests.get(f"{live_server}/server/info")
7+
# Get admin token for authentication
8+
login_response = requests.post(
9+
f"{live_server}/auth/jwt/login",
10+
data={"username": "[email protected]", "password": "admin123"}
11+
)
12+
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
13+
token = login_response.json()["access_token"]
14+
headers = {"Authorization": f"Bearer {token}"}
15+
16+
# Get user's team ID
17+
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
18+
assert teams_response.status_code == 200
19+
teams_data = teams_response.json()
20+
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
21+
team_id = teams_data["teams"][0]["id"]
22+
headers["X-Team-Id"] = team_id
23+
24+
response = requests.get(f"{live_server}/server/info", headers=headers)
825
assert response.status_code == 200
926
data = response.json()
1027
assert isinstance(data, dict)
@@ -24,7 +41,24 @@ def test_server_info(live_server):
2441

2542
@pytest.mark.live_server
2643
def test_server_python_libraries(live_server):
27-
response = requests.get(f"{live_server}/server/python_libraries")
44+
# Get admin token for authentication
45+
login_response = requests.post(
46+
f"{live_server}/auth/jwt/login",
47+
data={"username": "[email protected]", "password": "admin123"}
48+
)
49+
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
50+
token = login_response.json()["access_token"]
51+
headers = {"Authorization": f"Bearer {token}"}
52+
53+
# Get user's team ID
54+
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
55+
assert teams_response.status_code == 200
56+
teams_data = teams_response.json()
57+
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
58+
team_id = teams_data["teams"][0]["id"]
59+
headers["X-Team-Id"] = team_id
60+
61+
response = requests.get(f"{live_server}/server/python_libraries", headers=headers)
2862
assert response.status_code == 200
2963
data = response.json()
3064
# assert it is an array of {"name": "package_name", "version": "version_number"} type things
@@ -38,7 +72,24 @@ def test_server_python_libraries(live_server):
3872

3973
@pytest.mark.live_server
4074
def test_server_pytorch_collect_env(live_server):
41-
response = requests.get(f"{live_server}/server/pytorch_collect_env")
75+
# Get admin token for authentication
76+
login_response = requests.post(
77+
f"{live_server}/auth/jwt/login",
78+
data={"username": "[email protected]", "password": "admin123"}
79+
)
80+
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
81+
token = login_response.json()["access_token"]
82+
headers = {"Authorization": f"Bearer {token}"}
83+
84+
# Get user's team ID
85+
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
86+
assert teams_response.status_code == 200
87+
teams_data = teams_response.json()
88+
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
89+
team_id = teams_data["teams"][0]["id"]
90+
headers["X-Team-Id"] = team_id
91+
92+
response = requests.get(f"{live_server}/server/pytorch_collect_env", headers=headers)
4293
assert response.status_code == 200
4394
data = response.text
4495
assert "PyTorch" in data

api/transformerlab/routers/auth/routes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .api_key_auth import get_user_or_api_key
1212

13-
router = APIRouter(prefix="/auth", tags=["auth"])
13+
router = APIRouter(prefix="/auth", tags=["auth-workos-deprecated"], deprecated=True)
1414

1515

1616
@router.get("/login-url")

0 commit comments

Comments
 (0)