Skip to content

Commit 26730ca

Browse files
authored
Tidy app entrypoint (#7668)
## Summary Prior to this PR, most of the app setup was being done in `api_app.py` at import time. This PR cleans this up, by: - Splitting app setup into more modular functions - Narrower responsibility for the `api_app.py` file - it just initializes the `FastAPI` app The main motivation for this change is to make it easier to support an upcoming torch configuration feature that requires more careful ordering of app initialization steps. ## Related Issues / Discussions N/A ## QA Instructions - [x] Launch the app via invokeai-web.py and smoke test it. - [ ] Launch the app via the installer and smoke test it. - [x] Test that generate_openapi_schema.py produces the same result before and after the change. - [x] No regression in unit tests that directly interact with the app. (test_images.py) ## Merge Plan - [x] Check to see if there are any commercial implications to modifying the app entrypoint. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [x] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 84c9ecc + 1e2c7c5 commit 26730ca

File tree

3 files changed

+133
-109
lines changed

3 files changed

+133
-109
lines changed

invokeai/app/api_app.py

Lines changed: 1 addition & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
import asyncio
22
import logging
3-
import mimetypes
4-
import socket
53
from contextlib import asynccontextmanager
64
from pathlib import Path
75

8-
import torch
9-
import uvicorn
106
from fastapi import FastAPI, Request
117
from fastapi.middleware.cors import CORSMiddleware
128
from fastapi.middleware.gzip import GZipMiddleware
@@ -15,11 +11,7 @@
1511
from fastapi_events.handlers.local import local_handler
1612
from fastapi_events.middleware import EventHandlerASGIMiddleware
1713
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
18-
from torch.backends.mps import is_available as is_mps_available
1914

20-
# for PyCharm:
21-
# noinspection PyUnresolvedReferences
22-
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
2315
import invokeai.frontend.web as web_dir
2416
from invokeai.app.api.dependencies import ApiDependencies
2517
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
@@ -36,39 +28,15 @@
3628
workflows,
3729
)
3830
from invokeai.app.api.sockets import SocketIO
39-
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
4031
from invokeai.app.services.config.config_default import get_config
4132
from invokeai.app.util.custom_openapi import get_openapi_func
42-
from invokeai.backend.util.devices import TorchDevice
4333
from invokeai.backend.util.logging import InvokeAILogger
4434

4535
app_config = get_config()
46-
47-
48-
if is_mps_available():
49-
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
50-
51-
5236
logger = InvokeAILogger.get_logger(config=app_config)
53-
# fix for windows mimetypes registry entries being borked
54-
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
55-
mimetypes.add_type("application/javascript", ".js")
56-
mimetypes.add_type("text/css", ".css")
57-
58-
torch_device_name = TorchDevice.get_torch_device_name()
59-
logger.info(f"Using torch device: {torch_device_name}")
6037

6138
loop = asyncio.new_event_loop()
6239

63-
# We may change the port if the default is in use, this global variable is used to store the port so that we can log
64-
# the correct port when the server starts in the lifespan handler.
65-
port = app_config.port
66-
67-
# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
68-
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
69-
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
70-
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path)
71-
7240

7341
@asynccontextmanager
7442
async def lifespan(app: FastAPI):
@@ -77,7 +45,7 @@ async def lifespan(app: FastAPI):
7745

7846
# Log the server address when it starts - in case the network log level is not high enough to see the startup log
7947
proto = "https" if app_config.ssl_certfile else "http"
80-
msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
48+
msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)"
8149

8250
# Logging this way ignores the logger's log level and _always_ logs the message
8351
record = logger.makeRecord(
@@ -192,73 +160,3 @@ def overridden_redoc() -> HTMLResponse:
192160
app.mount(
193161
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
194162
) # docs favicon is in here
195-
196-
197-
def check_cudnn(logger: logging.Logger) -> None:
198-
"""Check for cuDNN issues that could be causing degraded performance."""
199-
if torch.backends.cudnn.is_available():
200-
try:
201-
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
202-
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
203-
cudnn_version = torch.backends.cudnn.version()
204-
logger.info(f"cuDNN version: {cudnn_version}")
205-
except RuntimeError as e:
206-
logger.warning(
207-
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
208-
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
209-
f"system. Full error message:\n{e}"
210-
)
211-
212-
213-
def invoke_api() -> None:
214-
def find_port(port: int) -> int:
215-
"""Find a port not in use starting at given port"""
216-
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
217-
# https://github.com/WaylonWalker
218-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
219-
s.settimeout(1)
220-
if s.connect_ex(("localhost", port)) == 0:
221-
return find_port(port=port + 1)
222-
else:
223-
return port
224-
225-
if app_config.dev_reload:
226-
try:
227-
import jurigged
228-
except ImportError as e:
229-
logger.error(
230-
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
231-
exc_info=e,
232-
)
233-
else:
234-
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
235-
236-
global port
237-
port = find_port(app_config.port)
238-
if port != app_config.port:
239-
logger.warn(f"Port {app_config.port} in use, using port {port}")
240-
241-
check_cudnn(logger)
242-
243-
config = uvicorn.Config(
244-
app=app,
245-
host=app_config.host,
246-
port=port,
247-
loop="asyncio",
248-
log_level=app_config.log_level_network,
249-
ssl_certfile=app_config.ssl_certfile,
250-
ssl_keyfile=app_config.ssl_keyfile,
251-
)
252-
server = uvicorn.Server(config)
253-
254-
# replace uvicorn's loggers with InvokeAI's for consistent appearance
255-
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
256-
uvicorn_logger.handlers.clear()
257-
for hdlr in logger.handlers:
258-
uvicorn_logger.addHandler(hdlr)
259-
260-
loop.run_until_complete(server.serve())
261-
262-
263-
if __name__ == "__main__":
264-
invoke_api()

invokeai/app/run_app.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,74 @@
1-
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
1+
import uvicorn
22

3+
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
4+
from invokeai.app.services.config.config_default import get_config
5+
from invokeai.app.util.startup_utils import (
6+
apply_monkeypatches,
7+
check_cudnn,
8+
enable_dev_reload,
9+
find_open_port,
10+
register_mime_types,
11+
)
12+
from invokeai.backend.util.logging import InvokeAILogger
13+
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
314

4-
def run_app() -> None:
5-
# Before doing _anything_, parse CLI args!
6-
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
715

16+
def get_app():
    """Import and return the FastAPI app and its event loop.

    Wrapped in a function so the caller explicitly controls when the import
    happens - importing api_app performs a bunch of setup work, so it behaves
    more like calling a function than importing a module.
    """
    import invokeai.app.api_app as api_app

    return api_app.app, api_app.loop
23+
24+
25+
def run_app() -> None:
    """The main entrypoint for the app."""
    # CLI args must be parsed before anything else touches configuration.
    InvokeAIArgs.parse_args()

    # Load config.
    app_config = get_config()

    logger = InvokeAILogger.get_logger(config=app_config)

    # Find an open port, and modify the config accordingly.
    orig_config_port = app_config.port
    app_config.port = find_open_port(app_config.port)
    if orig_config_port != app_config.port:
        logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")

    # Miscellaneous startup tasks.
    apply_monkeypatches()
    register_mime_types()
    if app_config.dev_reload:
        enable_dev_reload()
    check_cudnn(logger)

    # Initialize the app and event loop. Deferred until here so the steps
    # above run before api_app's import-time setup work.
    app, loop = get_app()

    # Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
    # invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
    # core nodes have been imported so that we can catch when a custom node clobbers a core node.
    load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path)

    # Build the server, honoring the (possibly updated) port and any SSL settings.
    uvicorn_config = uvicorn.Config(
        app=app,
        host=app_config.host,
        port=app_config.port,
        loop="asyncio",
        log_level=app_config.log_level_network,
        ssl_certfile=app_config.ssl_certfile,
        ssl_keyfile=app_config.ssl_keyfile,
    )
    server = uvicorn.Server(uvicorn_config)

    # Replace uvicorn's loggers with InvokeAI's for a consistent appearance.
    uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
    uvicorn_logger.handlers.clear()
    for handler in logger.handlers:
        uvicorn_logger.addHandler(handler)

    # Blocks until the server shuts down.
    loop.run_until_complete(server.serve())

invokeai/app/util/startup_utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import logging
2+
import mimetypes
3+
import socket
4+
5+
import torch
6+
7+
8+
def find_open_port(port: int) -> int:
    """Find a port not in use, starting at the given port.

    Each candidate port is probed by attempting a TCP connection to
    localhost; a successful connection means something is already listening
    there, so the next port is tried.

    Args:
        port: The first port number to probe.

    Returns:
        The first port number >= `port` that did not accept a connection.
    """
    # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
    # https://github.com/WaylonWalker
    # Iterative rather than recursive so a long run of occupied ports cannot
    # exhaust the interpreter's recursion limit.
    while True:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(1)
            # connect_ex returns 0 when the connection succeeds (port busy).
            if s.connect_ex(("localhost", port)) != 0:
                return port
        port += 1
18+
19+
20+
def check_cudnn(logger: logging.Logger) -> None:
    """Check for cuDNN issues that could be causing degraded performance."""
    # Nothing to check when cuDNN isn't present at all.
    if not torch.backends.cudnn.is_available():
        return
    try:
        # Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
        # time it is called. Subsequent calls will return the version number without complaining about a mismatch.
        cudnn_version = torch.backends.cudnn.version()
    except RuntimeError as e:
        logger.warning(
            "Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
            "caused by an incompatible cuDNN version installed in your python environment, or on the host "
            f"system. Full error message:\n{e}"
        )
    else:
        logger.info(f"cuDNN version: {cudnn_version}")
34+
35+
36+
def enable_dev_reload() -> None:
    """Enable hot reloading on python file changes during development."""
    from invokeai.backend.util.logging import InvokeAILogger

    try:
        import jurigged
    except ImportError as e:
        # jurigged is an optional dev dependency; fail with a actionable message.
        raise RuntimeError(
            'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
        ) from e

    # The except branch always raises, so reaching here means the import succeeded.
    jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
48+
49+
50+
def apply_monkeypatches() -> None:
    """Apply monkeypatches to fix issues with third-party libraries."""
    mps_available = torch.backends.mps.is_available()

    import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)

    # The MPS patches are only applied when an MPS device is available.
    if mps_available:
        import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)
57+
58+
59+
def register_mime_types() -> None:
    """Register additional mime types for windows."""
    # Fix for windows mimetypes registry entries being borked.
    # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
    extra_types = (
        ("application/javascript", ".js"),
        ("text/css", ".css"),
    )
    for mime_type, extension in extra_types:
        mimetypes.add_type(mime_type, extension)

0 commit comments

Comments
 (0)