diff --git a/piccolo_admin/endpoints.py b/piccolo_admin/endpoints.py index 8bdbdec3..dba77101 100644 --- a/piccolo_admin/endpoints.py +++ b/piccolo_admin/endpoints.py @@ -50,9 +50,9 @@ from piccolo_api.session_auth.middleware import SessionsAuthBackend from piccolo_api.session_auth.tables import SessionsBase from pydantic import BaseModel, Field, ValidationError +from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware -from starlette.middleware.exceptions import HTTPException from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, Response from starlette.staticfiles import StaticFiles @@ -412,6 +412,13 @@ class FormConfigResponseModel(BaseModel): description: Optional[str] = None +@dataclass +class SessionExpiryConfig: + session_expiry: timedelta + max_session_expiry: timedelta + increase_expiry: Optional[timedelta] = timedelta(minutes=20) + + def handle_auth_exception(request: Request, exc: Exception): return JSONResponse({"error": "Auth failed"}, status_code=401) @@ -438,14 +445,6 @@ async def log_error(request: Request, exc: HTTPException): class AdminRouter(FastAPI): - """ - The root returns a single page app. The other URLs are REST endpoints. - """ - - table: list[Table] = [] - auth_table: type[BaseUser] = BaseUser - template: str = "" - def __init__( self, *tables: Union[type[Table], TableConfig], @@ -467,6 +466,68 @@ def __init__( sidebar_links: dict[str, str] = {}, mfa_providers: Optional[Sequence[MFAProvider]] = None, ) -> None: + self._init_fastapi_app(site_name, allowed_hosts, debug) + + ####################################################################### + # Convert any table arguments which are plain ``Table`` classes into + # ``TableConfig`` instances. + + self.table_configs = self._init_table_configs(tables) + self.table_config_map = { + table_config.table_class._meta.tablename: table_config + for table_config in self.table_configs + } + + self._check_media_storage() + + ####################################################################### + + self.default_language_code = default_language_code + self.translations_map = { + translation.language_code.lower(): translation + for translation in (translations or TRANSLATIONS) + } + + ####################################################################### + + # Public attribute retained for backwards compatibility. + # To be deprecated + self.auth_table = auth_table + self.site_name = site_name + self.forms = forms + self.read_only = read_only + self.page_size = page_size + self.sidebar_links = sidebar_links + self.form_config_map = {form.slug: form for form in self.forms} + + self.template = self._init_index_template() + + # Private attributes used internally to initialise the FastAPI app. + # Details and implementation subject to change + self._auth_table = auth_table + self._session_table = session_table + self._rate_limit_provider: RateLimitProvider = ( + rate_limit_provider + or InMemoryLimitProvider(limit=20, timespan=300) + ) + self._session_expiry_config = SessionExpiryConfig( + session_expiry=session_expiry, + max_session_expiry=max_session_expiry, + increase_expiry=increase_expiry, + ) + self._mfa_providers: Sequence[MFAProvider] = mfa_providers or [] + self._production = production + + self._init_app() + + def _init_fastapi_app( + self, site_name: str, allowed_hosts: Sequence[str], debug: bool + ): + """Call the parent FastAPI constructor, which provisions the top-level + ASGI app and ensures .mount etc. are safe to call from our code. + + This adds middleware for CSRF protection and custom logging. + """ super().__init__( title=site_name, description=f"{site_name} documentation", @@ -483,93 +544,147 @@ def __init__( redoc_url=None, ) + def _init_app(self): + """ + Configure all the Piccolo admin endpoints and sub-apps. + The high-level structure of the admin app consists of: + * /api - endpoints protected by auth, used for the admin UI + * /public - endpoints for login/logout etc. that do not require auth + and are used on the login page + * /assets - static files for the admin UI + """ + # App for authenticated endpoints: + api_app = self._init_api_app( + superuser_tables=(self._auth_table, self._session_table), + ) ####################################################################### - # Convert any table arguments which are plain ``Table`` classes into - # ``TableConfig`` instances. + # Add /user and /change-password endpoints + self._init_api_auth_endpoints(api_app) - table_configs: list[TableConfig] = [] + # Optional components of the API app: + self._init_mfa_provider(api_app) - for table in tables: - if isinstance(table, TableConfig): - table_configs.append(table) - else: - table_configs.append(TableConfig(table_class=table)) + self._init_auth_middleware(api_app) - self.table_configs = sorted( - table_configs, - key=lambda table_config: table_config.table_class._meta.tablename, - ) - self.table_config_map = { - table_config.table_class._meta.tablename: table_config - for table_config in self.table_configs - } + # App for unauthenticated endpoints: + public_app = self._init_public_app() + + # Add /login and /logout endpoints + self._init_public_auth_endpoints(public_app) + assets_app = StaticFiles(directory=os.path.join(ASSET_PATH, "assets")) ####################################################################### - # Make sure columns are configured properly. - for table_config in table_configs: - table_class = table_config.table_class - for column in table_class._meta.columns: - if column._meta.secret and column._meta.required: - message = ( - f"{table_class._meta.tablename}." - f"{column._meta._name} is using `secret` and " - f"`required` column args which are incompatible. " - f"You may encounter unexpected behavior when using " - f"this table within Piccolo Admin." - ) - colored_warning(message, level=Level.high) + self.add_route(path="/", route=self.get_root, methods=["GET"]) + + self.mount(path="/api", app=api_app) + self.mount(path="/public", app=public_app) + self.mount(path="/assets", app=assets_app) ####################################################################### - # Make sure media storage is configured properly. - media_storage = [ - i - for i in itertools.chain( - *[ - table_config.media_storage or [] - for table_config in table_configs - ] - ) - ] + def _init_public_app(self) -> FastAPI: + """Creates a sub-app for public endpoints.""" + public_app = FastAPI( + redoc_url=None, + docs_url=None, + debug=self.debug, + exception_handlers={500: log_error}, + ) + public_app.mount("/docs/", swagger_ui(schema_url="../openapi.json")) - if len(media_storage) != len(set(media_storage)): - raise ValueError( - "Media storage is misconfigured - multiple columns are saving " - "to the same location." - ) + # We make the meta endpoint available without auth, because it contains + # the site name. + public_app.add_api_route( + "/meta/", + endpoint=self.get_meta, + tags=["Meta"], # type: ignore + ) - ####################################################################### + # The translations are public, because we need them on the login page. + public_app.add_api_route( + "/translations/", + endpoint=self.get_translation_list, # type: ignore + methods=["GET"], + tags=["Translations"], + response_model=TranslationListResponse, + ) - self.default_language_code = default_language_code - self.translations_map = { - translation.language_code.lower(): translation - for translation in (translations or TRANSLATIONS) - } + public_app.add_api_route( + "/translations/{language_code:str}/", + endpoint=self.get_translation, # type: ignore + methods=["GET"], + tags=["Translations"], + response_model=Translation, + ) - ####################################################################### + return public_app - self.auth_table = auth_table - self.site_name = site_name - self.forms = forms - self.read_only = read_only - self.sidebar_links = sidebar_links - self.form_config_map = {form.slug: form for form in self.forms} + def _init_public_auth_endpoints(self, public_app: FastAPI): + """The unauthenticated user is supposed to hit the /login/ endpoint to + obtain credentials. + This modifies a FastAPI app to add /login/ and /logout/ endpoints. + """ + public_app.mount( + path="/login/", + # This rate limiting is to prevent brute forcing password login, + # and MFA codes. + app=RateLimitingMiddleware( + app=session_login( + auth_table=self._auth_table, + session_table=self._session_table, + session_expiry=self._session_expiry_config.session_expiry, + max_session_expiry=( + self._session_expiry_config.max_session_expiry + ), + redirect_to=None, + production=self._production, + mfa_providers=self._mfa_providers, + ), + provider=self._rate_limit_provider, + ), + ) - with open(os.path.join(ASSET_PATH, "index.html")) as f: - self.template = f.read() + public_app.add_route( + path="/logout/", + route=session_logout( + session_table=self._session_table # type: ignore + ), + methods=["POST"], + ) - ####################################################################### + def _init_auth_middleware(self, api_app: FastAPI): + """Add middleware to a FastAPI to provide auth on all endpoints. + This produces an instance of starlette's AuthenticationMiddleware and + adds it to the app. + """ + auth_middleware = partial( + AuthenticationMiddleware, + backend=SessionsAuthBackend( + auth_table=self._auth_table, + session_table=self._session_table, + admin_only=True, + increase_expiry=self._session_expiry_config.increase_expiry, + ), + on_error=handle_auth_exception, # type: ignore + ) + api_app.add_middleware(auth_middleware) - private_app = FastAPI( + def _init_api_app(self, superuser_tables: Any) -> FastAPI: + """Provision the API app with all authenticated endpoints necessary for + the admin UI. + This is a FastAPI app, meaning it can have its own middleware and can + be mounted on a parent app. + """ + api_app = FastAPI( docs_url=None, redoc_url=None, - debug=debug, + debug=self.debug, exception_handlers={500: log_error}, ) - private_app.mount("/docs/", swagger_ui(schema_url="../openapi.json")) + api_app.mount("/docs/", swagger_ui(schema_url="../openapi.json")) - for table_config in table_configs: + for table_config in self.table_configs: table_class = table_config.table_class visible_column_names = table_config.get_visible_column_names() visible_filter_names = table_config.get_visible_filter_names() @@ -581,17 +696,17 @@ def __init__( order_by = table_config.get_order_by() time_resolution = table_config.get_time_resolution() validators = table_config.validators - if table_class in (auth_table, session_table): + if table_class in superuser_tables: validators = validators or Validators() validators.every = [superuser_validators, *validators.every] FastAPIWrapper( root_url=f"/tables/{table_class._meta.tablename}/", - fastapi_app=private_app, + fastapi_app=api_app, piccolo_crud=PiccoloCRUD( table=table_class, - read_only=read_only, - page_size=page_size, + read_only=self.read_only, + page_size=self.page_size, schema_extra={ "visible_column_names": visible_column_names, "visible_filter_names": visible_filter_names, @@ -611,7 +726,7 @@ def __init__( ), ) - private_app.add_api_route( + api_app.add_api_route( path="/tables/", endpoint=self.get_table_list, # type: ignore methods=["GET"], @@ -619,7 +734,7 @@ def __init__( tags=["Tables"], ) - private_app.add_api_route( + api_app.add_api_route( path="/tables/grouped/", endpoint=self.get_table_list_grouped, # type: ignore methods=["GET"], @@ -627,14 +742,14 @@ def __init__( tags=["Tables"], ) - private_app.add_api_route( + api_app.add_api_route( path="/links/", endpoint=self.get_sidebar_links, # type: ignore methods=["GET"], tags=["Links"], ) - private_app.add_api_route( + api_app.add_api_route( path="/forms/", endpoint=self.get_forms, # type: ignore methods=["GET"], @@ -642,7 +757,7 @@ def __init__( response_model=list[FormConfigResponseModel], ) - private_app.add_api_route( + api_app.add_api_route( path="/forms/grouped/", endpoint=self.get_grouped_forms, # type: ignore methods=["GET"], @@ -650,28 +765,35 @@ def __init__( tags=["Forms"], ) - private_app.add_api_route( + api_app.add_api_route( path="/forms/{form_slug:str}/", endpoint=self.get_single_form, # type: ignore methods=["GET"], tags=["Forms"], ) - private_app.add_api_route( + api_app.add_api_route( path="/forms/{form_slug:str}/schema/", endpoint=self.get_single_form_schema, # type: ignore methods=["GET"], tags=["Forms"], ) - private_app.add_api_route( + api_app.add_api_route( path="/forms/{form_slug:str}/", endpoint=self.post_single_form, # type: ignore methods=["POST"], tags=["Forms"], ) - private_app.add_api_route( + self._init_media_endpoints(api_app) + return api_app + + def _init_api_auth_endpoints(self, api_app: FastAPI): + """Modify an app by adding endpoints for self-service for an + authenticated user: /user/ and /change-password/ + """ + api_app.add_api_route( path="/user/", endpoint=self.get_user, # type: ignore methods=["GET"], @@ -679,20 +801,19 @@ def __init__( response_model=UserResponseModel, ) - private_app.add_route( + api_app.add_route( path="/change-password/", route=change_password( # type: ignore login_url="./../../public/login/", - session_table=session_table, - read_only=read_only, + session_table=self._session_table, + read_only=self.read_only, ), methods=["POST"], ) - ####################################################################### - # Media - - private_app.add_api_route( + def _init_media_endpoints(self, api_app: FastAPI): + """Add /media/ endpoints required by some admin UI functionality.""" + api_app.add_api_route( path="/media/", endpoint=self.store_file, # type: ignore methods=["POST"], @@ -700,7 +821,7 @@ def __init__( response_model=StoreFileResponseModel, ) - private_app.add_api_route( + api_app.add_api_route( path="/media/generate-file-url/", endpoint=self.generate_file_url, # type: ignore methods=["POST"], @@ -717,7 +838,7 @@ def __init__( if isinstance(media_storage, LocalMediaStorage): # We apply a restrictive CSP here to mitigate SVG # files being used maliciously when viewed by admins - private_app.mount( + api_app.mount( path=f"/media-files/{column._meta.table._meta.tablename}/{column._meta.name}/", # noqa: E501 app=CSPMiddleware( StaticFiles( @@ -727,123 +848,84 @@ def __init__( ), ) - ####################################################################### - # MFA - - if mfa_providers: - if len(mfa_providers) > 1: - raise ValueError( - "Only a single mfa_provider is currently supported." - ) + def _init_mfa_provider(self, api_app: FastAPI): + """Add /mfa-setup/ endpoint to an app.""" + if len(self._mfa_providers) > 1: + raise ValueError( + "Only a single mfa_provider is currently supported." + ) - for mfa_provider in mfa_providers: - private_app.mount( - path="/mfa-setup/", - # This rate limiting is because some of the forms accept - # a password, and generating recovery codes is somewhat - # expensive, so we want to prevent abuse. - app=RateLimitingMiddleware( - app=mfa_setup( - provider=mfa_provider, - auth_table=self.auth_table, - ), - provider=InMemoryLimitProvider(limit=20, timespan=300), + for mfa_provider in self._mfa_providers: + api_app.mount( + path="/mfa-setup/", + # This rate limiting is because some of the forms accept + # a password, and generating recovery codes is somewhat + # expensive, so we want to prevent abuse. + app=RateLimitingMiddleware( + app=mfa_setup( + provider=mfa_provider, + auth_table=self._auth_table, ), - ) + provider=InMemoryLimitProvider(limit=20, timespan=300), + ), + ) - ####################################################################### + def _init_index_template(self) -> str: + with open(os.path.join(ASSET_PATH, "index.html")) as f: + return f.read() - public_app = FastAPI( - redoc_url=None, - docs_url=None, - debug=debug, - exception_handlers={500: log_error}, - ) - public_app.mount("/docs/", swagger_ui(schema_url="../openapi.json")) + def _check_media_storage(self): + # Make sure media storage is configured properly. - if not rate_limit_provider: - rate_limit_provider = InMemoryLimitProvider( - limit=20, - timespan=300, + media_storage = [ + i + for i in itertools.chain( + *[ + table_config.media_storage or [] + for table_config in self.table_configs + ] ) + ] - public_app.mount( - path="/login/", - # This rate limiting is to prevent brute forcing password login, - # and MFA codes. - app=RateLimitingMiddleware( - app=session_login( - auth_table=self.auth_table, - session_table=session_table, - session_expiry=session_expiry, - max_session_expiry=max_session_expiry, - redirect_to=None, - production=production, - mfa_providers=mfa_providers, - ), - provider=rate_limit_provider, - ), - ) - - public_app.add_route( - path="/logout/", - route=session_logout(session_table=session_table), # type: ignore - methods=["POST"], - ) - - # We make the meta endpoint available without auth, because it contains - # the site name. - public_app.add_api_route( - "/meta/", endpoint=self.get_meta, tags=["Meta"] # type: ignore - ) + if len(media_storage) != len(set(media_storage)): + raise ValueError( + "Media storage is misconfigured - multiple columns are saving " + "to the same location." + ) - # The translations are public, because we need them on the login page. - public_app.add_api_route( - "/translations/", - endpoint=self.get_translation_list, # type: ignore - methods=["GET"], - tags=["Translations"], - response_model=TranslationListResponse, - ) + def _init_table_configs( + self, tables: Sequence[Union[type[Table], TableConfig]] + ) -> list[TableConfig]: + """Validate and structure information about the database tables.""" + table_configs: list[TableConfig] = [] - public_app.add_api_route( - "/translations/{language_code:str}/", - endpoint=self.get_translation, # type: ignore - methods=["GET"], - tags=["Translations"], - response_model=Translation, - ) + for table in tables: + if isinstance(table, TableConfig): + table_configs.append(table) + else: + table_configs.append(TableConfig(table_class=table)) ####################################################################### + # Make sure columns are configured properly. - self.router.add_route( - path="/", endpoint=self.get_root, methods=["GET"] - ) - - self.mount( - path="/assets", - app=StaticFiles(directory=os.path.join(ASSET_PATH, "assets")), - ) + for table_config in table_configs: + table_class = table_config.table_class + for column in table_class._meta.columns: + if column._meta.secret and column._meta.required: + message = ( + f"{table_class._meta.tablename}." + f"{column._meta._name} is using `secret` and " + f"`required` column args which are incompatible. " + f"You may encounter unexpected behavior when using " + f"this table within Piccolo Admin." + ) + colored_warning(message, level=Level.high) - auth_middleware = partial( - AuthenticationMiddleware, - backend=SessionsAuthBackend( - auth_table=auth_table, - session_table=session_table, - admin_only=True, - increase_expiry=increase_expiry, - ), - on_error=handle_auth_exception, # type: ignore + return sorted( + table_configs, + key=lambda table_config: table_config.table_class._meta.tablename, ) - self.mount(path="/api", app=auth_middleware(private_app)) - self.mount(path="/public", app=public_app) - - async def get_root(self, request: Request) -> HTMLResponse: - return HTMLResponse(self.template) - - ########################################################################### - def _get_media_storage( self, table_name: str, column_name: str ) -> MediaStorage: @@ -876,7 +958,7 @@ def _get_media_storage( detail="No such column found.", ) - media_storage = media_columns.get(column) + media_storage = media_columns.get(column) # type: ignore if not media_storage: raise HTTPException( @@ -886,6 +968,9 @@ def _get_media_storage( return media_storage + async def get_root(self, request: Request) -> HTMLResponse: + return HTMLResponse(self.template) + async def store_file( self, request: Request,