Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ backend/yarn.lock
backend/static

/process-compose.yml

.next
next-env.d.ts
2 changes: 1 addition & 1 deletion backend/Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ click = ">=8.1.3"
databases = {extras = ["asyncpg"], version = "<=0.8.0"}
fastapi = "0.115.6"
fastapi-cache2 = ">=0.2.0"
fastapi-sso = ">=0.6.4"
fastapi-sso = "0.17.0"
gunicorn = ">=20.1.0"
heliclockter = ">=1.3.0"
parameterized = ">=0.8.1"
Expand Down
9 changes: 8 additions & 1 deletion backend/bracket/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import Field, PostgresDsn
from pydantic_settings import BaseSettings, SettingsConfigDict

from bracket.models.sso import SSOProvider
from bracket.utils.types import EnumAutoStr


Expand All @@ -29,8 +30,8 @@ def get_log_level(self) -> int:
class Config(BaseSettings):
admin_email: str | None = None
admin_password: str | None = None
allow_insecure_http_sso: bool = False
allow_user_registration: bool = True
allow_user_basic_login: bool = True
allow_demo_user_registration: bool = True
captcha_secret: str | None = None
base_url: str = "http://localhost:8400"
Expand All @@ -41,6 +42,12 @@ class Config(BaseSettings):
pg_dsn: PostgresDsn = "postgresql://user:pass@localhost:5432/db" # type: ignore[assignment]
sentry_dsn: str | None = None

sso_1_provider: SSOProvider | None = None
sso_1_client_id: str | None = None
sso_1_allow_insecure_http_sso: bool = False
sso_1_openid_discovery_url: str | None = None
sso_1_openid_scopes: str | None = None

def is_cors_enabled(self) -> bool:
return self.cors_origins != "*"

Expand Down
85 changes: 85 additions & 0 deletions backend/bracket/logic/sso.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Any, assert_never

import aiohttp
from fastapi_sso import GoogleSSO, OpenID, SSOBase, create_provider
from fastapi_sso.sso.base import DiscoveryDocument
from httpx import AsyncClient

from bracket.config import config
from bracket.models.sso import SSOID, SSOConfig, SSOProvider

providers_cache: dict[SSOID, SSOBase] | None = None


async def get_discovery_document(discovery_url: str) -> DiscoveryDocument:
async with aiohttp.ClientSession() as session:
response = await session.get(discovery_url)
response.raise_for_status()
response_json = await response.json()
return {
"authorization_endpoint": response_json["authorization_endpoint"],
"token_endpoint": response_json["token_endpoint"],
"userinfo_endpoint": response_json["userinfo_endpoint"],
}


async def get_openid_provider(sso_config: SSOConfig) -> type[SSOBase]:
assert sso_config.openid_discovery_url is not None, (
"`openid_discovery_url` should be set for OpenID SSO"
)
assert sso_config.openid_scopes is not None, "`openid_scopes` should be set for OpenID SSO"

def convert_openid(response: dict[str, Any], _client: AsyncClient | None) -> OpenID:
return OpenID(display_name=response["sub"])

return create_provider(
name="oidc",
discovery_document=await get_discovery_document(sso_config.openid_discovery_url),
response_convertor=convert_openid,
default_scope=sso_config.openid_scopes.split(","),
)


async def build_sso(sso_config: SSOConfig) -> SSOBase:
match sso_config.provider:
case SSOProvider.google:
provider: type[SSOBase] = GoogleSSO
case SSOProvider.github:
provider = GoogleSSO
case SSOProvider.openid:
provider = await get_openid_provider(sso_config)
case _ as fallback:
assert_never(fallback)

return provider(
client_id=sso_config.client_id,
client_secret=sso_config.client_secret,
redirect_uri=sso_config.redirect_uri,
allow_insecure_http=sso_config.allow_insecure_http,
)


async def get_sso_providers() -> dict[SSOID, SSOBase]:
global providers_cache # noqa: PLW0603

if providers_cache is not None:
return providers_cache

configs = []
if (
config.sso_1_provider is not None
and config.sso_1_client_id is not None
):
configs.append(
SSOConfig(
id=SSOID(1),
provider=config.sso_1_provider,
client_id=config.sso_1_client_id,
redirect_uri=f"{config.base_url}/sso-callback/1",
allow_insecure_http=config.sso_1_allow_insecure_http_sso,
)
)

providers = {sso_config.id: await build_sso(sso_config) for sso_config in configs}
providers_cache = providers
return providers_cache
25 changes: 25 additions & 0 deletions backend/bracket/models/sso.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from enum import StrEnum, auto
from typing import NewType

from pydantic import BaseModel

SSOID = NewType("SSOID", int)


class SSOConfig(BaseModel):
id: SSOID
provider: SSOProvider
client_id: str
client_secret: str
redirect_uri: str
allow_insecure_http: bool
openid_discovery_url: str | None = None
openid_scopes: str | None = None


class SSOProvider(StrEnum):
google = auto()
github = auto()
openid = auto()
56 changes: 25 additions & 31 deletions backend/bracket/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from jwt import DecodeError, ExpiredSignatureError
from pydantic import BaseModel
from starlette.requests import Request
from starlette.responses import RedirectResponse

from bracket.config import config
from bracket.database import database
from bracket.logic.sso import get_sso_providers
from bracket.models.db.tournament import Tournament
from bracket.models.db.user import UserInDB, UserPublic
from bracket.models.sso import SSOID
from bracket.schema import tournaments
from bracket.sql.tournaments import sql_get_tournament_by_endpoint_name
from bracket.sql.users import get_user, get_user_access_to_club, get_user_access_to_tournament
Expand All @@ -28,20 +31,8 @@
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


# def convert_openid(response: dict[str, Any]) -> OpenID:
# """Convert user information returned by OIDC"""
# return OpenID(display_name=response["sub"])


# os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"

# sso = GoogleSSO(
# client_id="test",
# client_secret="secret",
# redirect_uri="http://localhost:8080/sso_callback",
# allow_insecure_http=config.allow_insecure_http_sso,
# )


class Token(BaseModel):
access_token: str
Expand Down Expand Up @@ -184,22 +175,25 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(
return Token(access_token=access_token, token_type="bearer", user_id=user.id)


# @router.get("/login", summary='SSO login')
# async def sso_login() -> RedirectResponse:
# """Generate login url and redirect"""
# return cast(RedirectResponse, await sso.get_login_redirect())
#
#
# @router.get("/sso_callback", summary='SSO callback')
# async def sso_callback(request: Request) -> dict[str, Any]:
# """Process login response from OIDC and return user info"""
# user = await sso.verify_and_process(request)
# if user is None:
# raise HTTPException(401, "Failed to fetch user information")
# return {
# "id": user.id,
# "picture": user.picture,
# "display_name": user.display_name,
# "email": user.email,
# "provider": user.provider,
# }
@router.get("/sso-login/{sso_id}")
async def sso_login(sso_id: SSOID) -> RedirectResponse:
"""Generate login url and redirect"""
sso_providers = await get_sso_providers()
return await sso_providers[sso_id].get_login_redirect()


@router.get("/sso-callback/{sso_id}")
async def sso_callback(request: Request, sso_id: SSOID) -> dict[str, Any]:
"""Process login response from OIDC and return user info"""
sso_providers = await get_sso_providers()
user = await sso_providers[sso_id].verify_and_process(request)
if user is None:
raise HTTPException(401, "Failed to fetch user information")

return {
"id": user.id,
"picture": user.picture,
"display_name": user.display_name,
"email": user.email,
"provider": user.provider,
}