Move check_cudnn() and jurigged setup to startup_utils.py.

This commit is contained in:
Ryan Dick
2025-02-20 20:35:14 +00:00
parent 6f1dcf385b
commit 35910d3952
2 changed files with 35 additions and 27 deletions

View File

@@ -4,7 +4,6 @@ import mimetypes
from contextlib import asynccontextmanager
from pathlib import Path
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
@@ -39,7 +38,7 @@ from invokeai.app.util.custom_openapi import get_openapi_func
# for PyCharm:
# noinspection PyUnresolvedReferences
from invokeai.app.util.startup_utils import find_open_port
from invokeai.app.util.startup_utils import check_cudnn, enable_dev_reload, find_open_port
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -195,33 +194,9 @@ app.mount(
) # docs favicon is in here
def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)
def invoke_api() -> None:
if app_config.dev_reload:
try:
import jurigged
except ImportError as e:
logger.error(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
exc_info=e,
)
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
enable_dev_reload()
global port
port = find_open_port(app_config.port)

View File

@@ -1,5 +1,10 @@
import logging
import socket
import torch
from invokeai.backend.util.logging import InvokeAILogger
def find_open_port(port: int) -> int:
"""Find a port not in use starting at given port"""
@@ -11,3 +16,31 @@ def find_open_port(port: int) -> int:
return find_open_port(port=port + 1)
else:
return port
def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)
def enable_dev_reload() -> None:
"""Enable hot reloading on python file changes during development."""
try:
import jurigged
except ImportError as e:
raise RuntimeError(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
) from e
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)