Files
MIND/backend/internals/server.py
CasVT 89ff90a9a4 Fixed closing DB conns in time on shutdown
We needed the close_all_db function to close any remaining database connections on shutdown. But why were there any unclosed DB connections at all? The connections were only closed after the thread had already been marked as shut down, so on exit the thread sometimes hadn't finished closing its database connection yet, which resulted in an improper shutdown. Now the threads are only marked as stopped once their database connection is also closed, which guarantees that all connections are closed by the time the server returns. So there is no longer any need to explicitly close remaining connections afterwards.
2025-04-23 23:10:01 +02:00

276 lines
8.6 KiB
Python

# -*- coding: utf-8 -*-
"""
Setting up, running and shutting down the API and web-ui
"""
from __future__ import annotations
from os import urandom
from threading import Timer, current_thread
from typing import TYPE_CHECKING, Union
from flask import Flask
from waitress.server import create_server
from waitress.task import ThreadedTaskDispatcher as TTD
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from backend.base.definitions import Constants, StartType
from backend.base.helpers import Singleton, folder_path
from backend.base.logging import LOGGER
from backend.internals.db import (DBConnectionManager,
close_db, revert_db_import)
from backend.internals.settings import Settings
if TYPE_CHECKING:
from waitress.server import BaseWSGIServer, MultiSocketServer
class ThreadedTaskDispatcher(TTD):
    """Waitress task dispatcher that closes a worker thread's database
    connection before the thread is marked as stopped.

    A worker thread is considered stopped once its number is removed from
    `self.threads`. Closing the thread's database connection before that
    point means no connection is left open once the thread counts as
    stopped.
    """

    @staticmethod
    def _close_thread_db_connection() -> None:
        """Close the database connection of the current thread, if it has
        one that isn't closed yet.
        """
        thread_id = current_thread().native_id or -1
        instances = DBConnectionManager.instances
        if thread_id in instances and not instances[thread_id].closed:
            instances[thread_id].close()
        return

    def handler_thread(self, thread_no: int) -> None:
        """Worker loop of a dispatcher thread.

        Most of this method's content is copied straight from waitress,
        except for the part marked below.

        Args:
            thread_no (int): The number under which this worker thread is
                registered in `self.threads`.
        """
        while True:
            with self.lock:
                while not self.queue and self.stop_count == 0:
                    # Mark ourselves as idle before waiting to be
                    # woken up, then we will once again be active
                    self.active_count -= 1
                    self.queue_cv.wait()
                    self.active_count += 1

                if self.stop_count > 0:
                    self.active_count -= 1
                    self.stop_count -= 1

                    # =================
                    # MIND part
                    # The thread is considered to be stopped when it's
                    # removed from self.threads, so close its database
                    # connection before that. A failing close must not
                    # escape from here: the thread would then die without
                    # ever being removed from self.threads.
                    try:
                        self._close_thread_db_connection()
                    except Exception:
                        self.logger.exception(
                            "Exception when closing database connection"
                        )
                    # =================

                    self.threads.discard(thread_no)
                    self.thread_exit_cv.notify()
                    break

                task = self.queue.popleft()

            try:
                task.service()
            except BaseException:
                self.logger.exception("Exception when servicing %r", task)
        return

    def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
        """Log that MIND is shutting down, then shut down the dispatcher.

        Args:
            cancel_pending (bool, optional): Passed through to waitress'
                `shutdown()`.
                Defaults to True.
            timeout (int, optional): Passed through to waitress'
                `shutdown()`.
                Defaults to 5.

        Returns:
            bool: The result of waitress' own `shutdown()`.
        """
        # Empty line first, presumably to separate the message from
        # terminal output such as an echoed ^C.
        print()
        LOGGER.info('Shutting down MIND')
        result = super().shutdown(cancel_pending, timeout)
        return result
def handle_start_type(start_type: StartType) -> None:
    """Start the revert timer that belongs to a special restart, if any.

    A restart because of hosting changes starts the hosting revert timer,
    a restart because of a database import starts the database revert
    timer. Any other start type does nothing.

    Args:
        start_type (StartType): The reason the server was (re)started.
    """
    if start_type == StartType.RESTART_HOSTING_CHANGES:
        LOGGER.info("Starting timer for hosting changes")
        timer = Server().revert_hosting_timer
    elif start_type == StartType.RESTART_DB_CHANGES:
        LOGGER.info("Starting timer for database import")
        timer = Server().revert_db_timer
    else:
        # Regular start-up: there's nothing to revert
        return

    timer.start()
    return
def diffuse_timers() -> None:
    """Cancel the revert timer that a special restart left running.

    The hosting timer is checked first. Only if it isn't running is the
    database import timer checked; cancelling that one also calls
    `revert_db_import(swap=False)`.
    """
    server = Server()

    hosting_timer = server.revert_hosting_timer
    if hosting_timer.is_alive():
        LOGGER.info("Timer for hosting changes diffused")
        hosting_timer.cancel()
        return

    db_timer = server.revert_db_timer
    if db_timer.is_alive():
        LOGGER.info("Timer for database import diffused")
        db_timer.cancel()
        revert_db_import(swap=False)

    return
class Server(metaclass=Singleton):
    """Holds the Flask app and the waitress server that hosts it. Also takes
    care of shutting down, restarting and reverting special restarts.
    """

    api_prefix = "/api"
    admin_api_extension = "/admin"
    # Composed from the two values above instead of repeated as a literal,
    # so they can't drift apart. The value is still "/api/admin".
    admin_prefix = api_prefix + admin_api_extension
    url_prefix = ''

    def __init__(self) -> None:
        # Set by `restart()` to the reason for restarting
        self.start_type = None

        # Calls `revert_db_import(swap=True)` after
        # `Constants.DB_REVERT_TIME` seconds, unless cancelled before that
        self.revert_db_timer = Timer(
            Constants.DB_REVERT_TIME,
            revert_db_import,
            kwargs={"swap": True}
        )
        self.revert_db_timer.name = "DatabaseImportHandler"

        # Calls `self.restore_hosting_settings()` after
        # `Constants.HOSTING_REVERT_TIME` seconds, unless cancelled before
        # that
        self.revert_hosting_timer = Timer(
            Constants.HOSTING_REVERT_TIME,
            self.restore_hosting_settings
        )
        self.revert_hosting_timer.name = "HostingHandler"
        return

    def create_app(self) -> None:
        """Create a Flask app instance that can be used to start a web
        server, and store it in `self.app`.
        """
        # Imported here instead of at module level; presumably to avoid a
        # circular import with the frontend package.
        from frontend.api import admin_api, api
        from frontend.ui import ui

        app = Flask(
            __name__,
            template_folder=folder_path('frontend', 'templates'),
            static_folder=folder_path('frontend', 'static'),
            static_url_path='/static'
        )
        # A fresh random key is generated on every call
        app.config['SECRET_KEY'] = urandom(32)
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
        app.config['JSON_SORT_KEYS'] = False

        # Add error handlers
        @app.errorhandler(400)
        def bad_request(e):
            return {'error': "BadRequest", "result": {}}, 400

        @app.errorhandler(405)
        def method_not_allowed(e):
            return {'error': "MethodNotAllowed", "result": {}}, 405

        @app.errorhandler(500)
        def internal_error(e):
            return {'error': "InternalError", "result": {}}, 500

        # Add endpoints
        app.register_blueprint(ui)
        app.register_blueprint(api, url_prefix=self.api_prefix)
        app.register_blueprint(admin_api, url_prefix=self.admin_prefix)

        # Setup db handling
        app.teardown_appcontext(close_db)

        self.app = app
        return

    def set_url_prefix(self, url_prefix: str) -> None:
        """Change the URL prefix of the server.

        Args:
            url_prefix (str): The desired URL prefix to set it to.
        """
        self.app.config["APPLICATION_ROOT"] = url_prefix
        # Mount the real app under the prefix; requests outside of it are
        # handled by an empty Flask app that has no routes.
        self.app.wsgi_app = DispatcherMiddleware(  # type: ignore
            Flask(__name__),
            {url_prefix: self.app.wsgi_app}
        )
        self.url_prefix = url_prefix
        return

    def __create_waitress_server(
        self,
        host: str,
        port: int
    ) -> Union[MultiSocketServer, BaseWSGIServer]:
        """From the `Flask` instance created in `self.create_app()`, create
        a waitress server instance.

        Args:
            host (str): Where to host the server on (e.g. `0.0.0.0`).
            port (int): The port to host the server on (e.g. `5656`).

        Returns:
            Union[MultiSocketServer, BaseWSGIServer]: The waitress server instance.
        """
        # Use the custom dispatcher (defined in this module) so that worker
        # threads close their DB connection before counting as stopped
        dispatcher = ThreadedTaskDispatcher()
        dispatcher.set_thread_count(Constants.HOSTING_THREADS)
        server = create_server(
            self.app,
            _dispatcher=dispatcher,
            host=host,
            port=port,
            threads=Constants.HOSTING_THREADS
        )
        return server

    def run(self, host: str, port: int) -> None:
        """Start the webserver. Blocks for as long as the waitress server
        runs.

        Args:
            host (str): Where to host the server on (e.g. `0.0.0.0`).
            port (int): The port to host the server on (e.g. `5656`).
        """
        self.server = self.__create_waitress_server(host, port)
        LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
        self.server.run()
        return

    def __shutdown_thread_function(self) -> None:
        """Shutdown waitress server. Intended to be run in a thread.
        Does nothing if `self.run()` never created a server.
        """
        if not hasattr(self, 'server'):
            return

        self.server.task_dispatcher.shutdown()
        self.server.close()
        self.server._map.clear()  # type: ignore
        return

    def shutdown(self) -> None:
        """
        Stop the waitress server. Starts a thread that shuts down the server
        after a delay of one second.
        """
        # NOTE(review): the delay presumably gives the request that
        # triggered the shutdown time to send its response.
        t = Timer(1.0, self.__shutdown_thread_function)
        t.name = "InternalStateHandler"
        t.start()
        return

    def restart(
        self,
        start_type: StartType = StartType.STARTUP
    ) -> None:
        """Same as `self.shutdown()`, but restart instead of shutting down.

        Args:
            start_type (StartType, optional): Why MIND should
                restart.
                Defaults to StartType.STARTUP.
        """
        self.start_type = start_type
        self.shutdown()
        return

    def restore_hosting_settings(self) -> None:
        """Replace the host, port and URL prefix settings with their backup
        values, then restart the server.

        This is the target of `self.revert_hosting_timer`, so it runs in that
        timer's thread. An app context is therefore pushed explicitly;
        presumably `Settings` needs one for its database access.
        """
        with self.app.app_context():
            settings = Settings()
            values = settings.get_settings()
            main_settings = {
                'host': values.backup_host,
                'port': values.backup_port,
                'url_prefix': values.backup_url_prefix
            }
            settings.update(main_settings)
            self.restart()
        return