Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
cdbfb58
Add authorization requirement to all API endpoints
mina-parham Nov 21, 2025
ca8d4cb
Merge remote-tracking branch 'origin/fix/add-back-process-envs-in-ren…
mina-parham Nov 21, 2025
1dd72e3
Merge branch 'main' into add/auth-to-all-endpoints
mina-parham Nov 21, 2025
af78644
Add mock authentication to test fixtures
mina-parham Nov 21, 2025
c1cbe41
Merge branch 'add/auth-to-all-endpoints' of https://github.com/transf…
mina-parham Nov 21, 2025
a34ae66
Fix test so properly test authenticated endpoints
mina-parham Nov 21, 2025
6a74442
Create default admin user
mina-parham Nov 21, 2025
978bd91
Update live_server tests to manually authenticate with admin credentials
mina-parham Nov 21, 2025
67f1a83
Merge branch 'main' into add/auth-to-all-endpoints
mina-parham Nov 21, 2025
6e23b06
Ruff
mina-parham Nov 21, 2025
bdedffb
Merge branch 'add/auth-to-all-endpoints' of https://github.com/transf…
mina-parham Nov 21, 2025
dcc8f10
Debug admin user
mina-parham Nov 21, 2025
9ba61e2
Ruff
mina-parham Nov 21, 2025
41e7a16
Remove dependencies from team routers
mina-parham Nov 21, 2025
57c1df6
Debug
mina-parham Nov 21, 2025
8c6e318
Merge branch 'main' into add/auth-to-all-endpoints
mina-parham Nov 21, 2025
a8391e9
Fix config test
mina-parham Nov 21, 2025
ba375ba
Merge branch 'add/auth-to-all-endpoints' of https://github.com/transf…
mina-parham Nov 21, 2025
5c5b45d
Fix server test
mina-parham Nov 21, 2025
5c903de
Fix transformerlab/services/experiment_init.py
mina-parham Nov 21, 2025
3387223
Fix admin user verification in seed function
mina-parham Nov 21, 2025
d38bfb7
Add error logging
mina-parham Nov 21, 2025
dfbffd1
Add error logging
mina-parham Nov 21, 2025
a4f13d8
Fix admin user verification by re-fetching from database after creation
mina-parham Nov 21, 2025
5e6689e
add missing environment variables to CI
mina-parham Nov 21, 2025
cfc3549
Remove multi tenant
mina-parham Nov 21, 2025
6f911a2
Update macos workflow
mina-parham Nov 21, 2025
c01b32c
Update the pytest workflow
mina-parham Nov 21, 2025
05713cd
patch background function instead of asyncio.create_task in model tests
mina-parham Nov 21, 2025
ad6e98f
Revert pytest
mina-parham Nov 21, 2025
284f683
Remove cleanup_test_database
mina-parham Nov 21, 2025
e9bb3b8
Add auth to teams routers and skip test_teams.py
mina-parham Nov 24, 2025
0902cda
Revert test_model.py
mina-parham Nov 24, 2025
3cd7b63
Remove _init_db_and_admin from conftest
mina-parham Nov 24, 2025
232e57e
mark the workos routes as deprecated wth a tag for now
aliasaria Nov 24, 2025
fbca8e0
Fix test client to include X-Team-Id header for authenticated requests
mina-parham Nov 24, 2025
f36fed0
Merge branch 'add/auth-to-all-endpoints' of https://github.com/transf…
mina-parham Nov 24, 2025
2d2cb4d
Merge branch 'main' into add/auth-to-all-endpoints
mina-parham Nov 24, 2025
088cfa4
Fix test hanging by using in-memory database
mina-parham Nov 24, 2025
2566b90
Fix test_model.py
mina-parham Nov 24, 2025
4240a13
Merge branch 'main' into add/auth-to-all-endpoints
mina-parham Nov 24, 2025
35b0e8d
Ruff
mina-parham Nov 24, 2025
c40ac26
Add back create_db_and_tables
mina-parham Nov 24, 2025
2015ac4
Merge branch 'main' into add/auth-to-all-endpoints
mina-parham Nov 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/pytest-server-test-macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ jobs:
AUTH_SUCCESS_REDIRECT_URL: /
AUTH_ERROR_REDIRECT_URL: /
AUTH_LOGOUT_REDIRECT_URL: /
TRANSFORMERLAB_JWT_SECRET: ${{ secrets.TRANSFORMERLAB_JWT_SECRET }}
TRANSFORMERLAB_REFRESH_SECRET: ${{ secrets.TRANSFORMERLAB_REFRESH_SECRET }}
EMAIL_METHOD: "dev"
strategy:
fail-fast: false
matrix:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/pytest-server-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ jobs:
AUTH_SUCCESS_REDIRECT_URL: /
AUTH_ERROR_REDIRECT_URL: /
AUTH_LOGOUT_REDIRECT_URL: /
TRANSFORMERLAB_JWT_SECRET: ${{ secrets.TRANSFORMERLAB_JWT_SECRET }}
TRANSFORMERLAB_REFRESH_SECRET: ${{ secrets.TRANSFORMERLAB_REFRESH_SECRET }}
EMAIL_METHOD: "dev"
strategy:
fail-fast: false
matrix:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
uv pip install --system -r requirements-no-gpu-uv.txt --upgrade --index=https://download.pytorch.org/whl/cpu --index-strategy unsafe-best-match
- name: Test with pytest
run: |
pytest --cov=transformerlab --cov-branch --cov-report=xml
pytest --cov=transformerlab --cov-branch --cov-report=xml -k 'not test_teams'

- name: Upload results to Codecov
uses: codecov/codecov-action@v5
Expand Down
41 changes: 22 additions & 19 deletions api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

# Using torch to test for CUDA and MPS support.
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
Expand All @@ -32,7 +32,7 @@

from transformerlab.services.experiment_service import experiment_get
from transformerlab.services.job_service import job_create, job_get, job_update_status
from transformerlab.services.experiment_init import seed_default_experiments, cancel_in_progress_jobs
from transformerlab.services.experiment_init import seed_default_experiments, cancel_in_progress_jobs, seed_default_admin_user
import transformerlab.db.session as db

from transformerlab.shared.ssl_utils import ensure_persistent_self_signed_cert
Expand All @@ -53,6 +53,7 @@
auth2,
teams,
)
from transformerlab.routers.auth2 import get_user_and_team
import torch

try:
Expand Down Expand Up @@ -115,6 +116,8 @@ async def lifespan(app: FastAPI):
await db.init() # This now runs Alembic migrations internally
# create_db_and_tables() is deprecated - migrations are handled in db.init()
print("✅ SEED DATA")
# Seed default admin user
await seed_default_admin_user()
# Initialize experiments and cancel any running jobs
seed_default_experiments()
cancel_in_progress_jobs()
Expand Down Expand Up @@ -218,23 +221,23 @@ async def validation_exception_handler(request, exc):
### END GENERAL API - NOT OPENAI COMPATIBLE ###


app.include_router(model.router)
app.include_router(serverinfo.router)
app.include_router(train.router)
app.include_router(data.router)
app.include_router(experiment.router)
app.include_router(plugins.router)
app.include_router(evals.router)
app.include_router(jobs.router)
app.include_router(tasks.router)
app.include_router(config.router)
app.include_router(prompts.router)
app.include_router(tools.router)
app.include_router(recipes.router)
app.include_router(batched_prompts.router)
app.include_router(remote.router)
app.include_router(fastchat_openai_api.router)
app.include_router(teams.router)
app.include_router(model.router, dependencies=[Depends(get_user_and_team)])
app.include_router(serverinfo.router, dependencies=[Depends(get_user_and_team)])
app.include_router(train.router, dependencies=[Depends(get_user_and_team)])
app.include_router(data.router, dependencies=[Depends(get_user_and_team)])
app.include_router(experiment.router, dependencies=[Depends(get_user_and_team)])
app.include_router(plugins.router, dependencies=[Depends(get_user_and_team)])
app.include_router(evals.router, dependencies=[Depends(get_user_and_team)])
app.include_router(jobs.router, dependencies=[Depends(get_user_and_team)])
app.include_router(tasks.router, dependencies=[Depends(get_user_and_team)])
app.include_router(config.router, dependencies=[Depends(get_user_and_team)])
app.include_router(prompts.router, dependencies=[Depends(get_user_and_team)])
app.include_router(tools.router, dependencies=[Depends(get_user_and_team)])
app.include_router(recipes.router, dependencies=[Depends(get_user_and_team)])
app.include_router(batched_prompts.router, dependencies=[Depends(get_user_and_team)])
app.include_router(remote.router, dependencies=[Depends(get_user_and_team)])
app.include_router(fastchat_openai_api.router, dependencies=[Depends(get_user_and_team)])
app.include_router(teams.router, dependencies=[Depends(get_user_and_team)])
app.include_router(auth2.router)

# Authentication and session management routes
Expand Down
70 changes: 58 additions & 12 deletions api/test/api/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from fastapi.testclient import TestClient
import os
import asyncio

# Create test directories before setting environment variables
os.makedirs("test/tmp/workspace", exist_ok=True)
Expand All @@ -12,25 +13,70 @@
os.environ["TRANSFORMERLAB_REFRESH_SECRET"] = "test-refresh-secret-for-testing-only"
os.environ["EMAIL_METHOD"] = "dev" # Use dev mode for tests (no actual email sending)

# Use in-memory database for tests
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"

from api import app # noqa: E402


@pytest.fixture(scope="session", autouse=True)
def cleanup_test_database():
"""Clean up test database before and after test session"""
# Clean up before tests
db_path = "test/tmp/llmlab.sqlite3"
if os.path.exists(db_path):
os.remove(db_path)
class AuthenticatedTestClient(TestClient):
"""TestClient that automatically adds admin authentication headers to all requests"""

def __init__(self, app, *args, **kwargs):
super().__init__(app, *args, **kwargs)
self._token = None
self._team_id = None
self._get_token()

yield

# Clean up after tests
if os.path.exists(db_path):
os.remove(db_path)
def _get_token(self):
"""Get or refresh admin token and team"""
if self._token is None:
login_response = super().post(
"/auth/jwt/login",
data={"username": "[email protected]", "password": "admin123"}
)
if login_response.status_code != 200:
raise RuntimeError(f"Failed to get admin token: {login_response.text}")
self._token = login_response.json()["access_token"]

# Get user's teams
teams_response = super().get(
"/users/me/teams",
headers={"Authorization": f"Bearer {self._token}"}
)
if teams_response.status_code == 200:
teams = teams_response.json()["teams"]
if teams:
self._team_id = teams[0]["id"] # Use the first team
return self._token

def request(self, method, url, **kwargs):
"""Override request to add auth headers"""
# Don't add auth headers to auth endpoints
if "/auth/" not in url:
# Ensure headers dict exists
if "headers" not in kwargs or kwargs["headers"] is None:
kwargs["headers"] = {}
# Only add Authorization if not already present
if "Authorization" not in kwargs["headers"]:
kwargs["headers"]["Authorization"] = f"Bearer {self._get_token()}"
# Only add team header if not already present
if self._team_id and "X-Team-Id" not in kwargs["headers"]:
kwargs["headers"]["X-Team-Id"] = self._team_id
return super().request(method, url, **kwargs)



@pytest.fixture(scope="session")
def client():
with TestClient(app) as c:
# Initialize database tables for tests
from transformerlab.shared.models.user_model import create_db_and_tables # noqa: E402
from transformerlab.services.experiment_init import seed_default_admin_user # noqa: E402

asyncio.run(create_db_and_tables())
asyncio.run(seed_default_admin_user())

with AuthenticatedTestClient(app) as c:
yield c

12 changes: 6 additions & 6 deletions api/test/api/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ async def test_install_peft_base_model_adaptor_not_found(mock_run_script, mock_g
assert "adapter not found" in data["message"]


def test_install_peft_success(client):
@pytest.mark.asyncio
async def test_install_peft_success(client):
adapter_id = "tcotter/Llama-3.2-1B-Instruct-Mojo-Adapter"
model_id = "unsloth/Llama-3.2-1B-Instruct"

Expand All @@ -93,7 +94,6 @@ def test_install_peft_success(client):
patch("huggingface_hub.HfApi.model_info", return_value=make_mock_adapter_info()),
patch("transformerlab.routers.model.huggingfacemodel.get_model_details_from_huggingface", return_value={}),
patch("transformerlab.routers.model.job_service.job_create", return_value=123),
patch("transformerlab.routers.model.asyncio.create_task"),
):
response = client.post(
"/model/install_peft", params={"peft": adapter_id, "model_id": model_id, "experiment_id": 1}
Expand Down Expand Up @@ -130,7 +130,8 @@ def test_install_peft_adapter_info_fail(client):
assert response.json()["check_status"]["error"] == "not found"


def test_install_peft_architecture_detection_unknown(client):
@pytest.mark.asyncio
async def test_install_peft_architecture_detection_unknown(client):
adapter_info = make_mock_adapter_info()
with (
patch("transformerlab.routers.model.snapshot_download", return_value="/tmp/mock"),
Expand All @@ -139,7 +140,6 @@ def test_install_peft_architecture_detection_unknown(client):
patch("huggingface_hub.HfApi.model_info", return_value=adapter_info),
patch("transformerlab.routers.model.huggingfacemodel.get_model_details_from_huggingface", return_value={}),
patch("transformerlab.routers.model.job_service.job_create", return_value=123),
patch("transformerlab.routers.model.asyncio.create_task"),
):
response = client.post(
"/model/install_peft", params={"peft": "dummy", "model_id": "valid_model", "experiment_id": 1}
Expand All @@ -148,7 +148,8 @@ def test_install_peft_architecture_detection_unknown(client):
assert response.json()["check_status"]["architectures_status"] == "unknown"


def test_install_peft_unknown_field_status(client):
@pytest.mark.asyncio
async def test_install_peft_unknown_field_status(client):
adapter_info = make_mock_adapter_info(overrides={"config": {}})
with (
patch("transformerlab.routers.model.snapshot_download", return_value="/tmp/mock"),
Expand All @@ -157,7 +158,6 @@ def test_install_peft_unknown_field_status(client):
patch("huggingface_hub.HfApi.model_info", return_value=adapter_info),
patch("transformerlab.routers.model.huggingfacemodel.get_model_details_from_huggingface", return_value={}),
patch("transformerlab.routers.model.job_service.job_create", return_value=123),
patch("transformerlab.routers.model.asyncio.create_task"),
):
response = client.post(
"/model/install_peft", params={"peft": "dummy", "model_id": "valid_model", "experiment_id": 1}
Expand Down
38 changes: 36 additions & 2 deletions api/test/server/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,47 @@

@pytest.mark.live_server
def test_set(live_server):
response = requests.get(f"{live_server}/config/set", params={"k": "message", "v": "Hello, World!"})
# Get admin token for authentication
login_response = requests.post(
f"{live_server}/auth/jwt/login",
data={"username": "[email protected]", "password": "admin123"}
)
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
token = login_response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}"}

# Get user's team ID
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
assert teams_response.status_code == 200
teams_data = teams_response.json()
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
team_id = teams_data["teams"][0]["id"]
headers["X-Team-Id"] = team_id

response = requests.get(f"{live_server}/config/set", params={"k": "message", "v": "Hello, World!"}, headers=headers)
assert response.status_code == 200
assert response.json() == {"key": "message", "value": "Hello, World!"}


@pytest.mark.live_server
def test_get(live_server):
response = requests.get(f"{live_server}/config/get/message")
# Get admin token for authentication
login_response = requests.post(
f"{live_server}/auth/jwt/login",
data={"username": "[email protected]", "password": "admin123"}
)
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
token = login_response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}"}

# Get user's team ID
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
assert teams_response.status_code == 200
teams_data = teams_response.json()
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
team_id = teams_data["teams"][0]["id"]
headers["X-Team-Id"] = team_id

response = requests.get(f"{live_server}/config/get/message", headers=headers)
assert response.status_code == 200
assert response.json() == "Hello, World!"
57 changes: 54 additions & 3 deletions api/test/server/test_server_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,24 @@

@pytest.mark.live_server
def test_server_info(live_server):
response = requests.get(f"{live_server}/server/info")
# Get admin token for authentication
login_response = requests.post(
f"{live_server}/auth/jwt/login",
data={"username": "[email protected]", "password": "admin123"}
)
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
token = login_response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}"}

# Get user's team ID
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
assert teams_response.status_code == 200
teams_data = teams_response.json()
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
team_id = teams_data["teams"][0]["id"]
headers["X-Team-Id"] = team_id

response = requests.get(f"{live_server}/server/info", headers=headers)
assert response.status_code == 200
data = response.json()
assert isinstance(data, dict)
Expand All @@ -24,7 +41,24 @@ def test_server_info(live_server):

@pytest.mark.live_server
def test_server_python_libraries(live_server):
response = requests.get(f"{live_server}/server/python_libraries")
# Get admin token for authentication
login_response = requests.post(
f"{live_server}/auth/jwt/login",
data={"username": "[email protected]", "password": "admin123"}
)
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
token = login_response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}"}

# Get user's team ID
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
assert teams_response.status_code == 200
teams_data = teams_response.json()
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
team_id = teams_data["teams"][0]["id"]
headers["X-Team-Id"] = team_id

response = requests.get(f"{live_server}/server/python_libraries", headers=headers)
assert response.status_code == 200
data = response.json()
# assert it is an array of {"name": "package_name", "version": "version_number"} type things
Expand All @@ -38,7 +72,24 @@ def test_server_python_libraries(live_server):

@pytest.mark.live_server
def test_server_pytorch_collect_env(live_server):
response = requests.get(f"{live_server}/server/pytorch_collect_env")
# Get admin token for authentication
login_response = requests.post(
f"{live_server}/auth/jwt/login",
data={"username": "[email protected]", "password": "admin123"}
)
assert login_response.status_code == 200, f"Login failed with {login_response.status_code}: {login_response.text}"
token = login_response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}"}

# Get user's team ID
teams_response = requests.get(f"{live_server}/users/me/teams", headers=headers)
assert teams_response.status_code == 200
teams_data = teams_response.json()
assert "teams" in teams_data and len(teams_data["teams"]) > 0, "User has no teams"
team_id = teams_data["teams"][0]["id"]
headers["X-Team-Id"] = team_id

response = requests.get(f"{live_server}/server/pytorch_collect_env", headers=headers)
assert response.status_code == 200
data = response.text
assert "PyTorch" in data
2 changes: 1 addition & 1 deletion api/transformerlab/routers/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .api_key_auth import get_user_or_api_key

router = APIRouter(prefix="/auth", tags=["auth"])
router = APIRouter(prefix="/auth", tags=["auth-workos-deprecated"], deprecated=True)


@router.get("/login-url")
Expand Down
Loading
Loading