Files
MIND/backend/internals/server.py

369 lines
11 KiB
Python

# -*- coding: utf-8 -*-
"""
Setting up, running and shutting down the webserver.
Also handling startup types.
"""
from __future__ import annotations
from os import urandom
from threading import Timer
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Union
from flask import Flask, request
from waitress.server import create_server
from waitress.task import ThreadedTaskDispatcher as TTD
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from backend.base.definitions import Constants, StartType, StartTypeHandler
from backend.base.helpers import Singleton, folder_path
from backend.base.logging import LOGGER
from backend.internals.db import DBConnectionManager, close_db
from backend.internals.db_backup_import import revert_db_import
from backend.internals.settings import Settings
if TYPE_CHECKING:
from waitress.server import BaseWSGIServer, MultiSocketServer
# region Thread Manager
class ThreadedTaskDispatcher(TTD):
    """Waitress task dispatcher that closes DB connections of ending threads
    and logs when the dispatcher is shut down.
    """

    def __init__(self) -> None:
        super().__init__()

        class _DBClosingSet(set):
            """A set that closes the current thread's DB connection right
            before an element is discarded from it."""

            def discard(self, element: Any) -> None:
                DBConnectionManager.close_connection_of_thread()
                set.discard(self, element)

        # Waitress treats a thread as closed as soon as it is no longer in
        # `self.threads`, whether or not the thread has really ended/joined.
        # Work done after that point could be cut short by the main thread
        # ending, so the DB connection has to be closed *before* the thread
        # is discarded from the set. Hooking `discard` achieves exactly that.
        self.threads = _DBClosingSet()

    def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
        """Log the shutdown and then let waitress shut the dispatcher down.

        Args:
            cancel_pending (bool, optional): Passed on to the parent
            implementation.
                Defaults to True.

            timeout (int, optional): Passed on to the parent implementation.
                Defaults to 5.

        Returns:
            bool: Whatever the parent implementation returns.
        """
        # Write an empty line to stdout before the log entry.
        print()
        LOGGER.info('Shutting down MIND')
        return super().shutdown(cancel_pending, timeout)
# region Server
class Server(metaclass=Singleton):
    """Singleton owning the Flask app and the waitress server that hosts it.

    `create_app()` has to be called before any method that uses `self.app`
    (`set_url_prefix()`, `run()`, `get_db_timer_thread()`).
    """

    # URL prefix the app is mounted under; updated by `set_url_prefix()`.
    url_prefix = ''

    def __init__(self) -> None:
        # Value returned by `run()` once the server has stopped. Stays `None`
        # unless `restart()` was called.
        self.__start_type: Union[StartType, None] = None
        return

    def create_app(self) -> None:
        """Create a Flask app instance that can be used to start a web server.

        The app is stored in `self.app`.
        """
        # Imported here instead of at module level.
        # NOTE(review): presumably to avoid a circular import with the
        # frontend package — confirm.
        from frontend.api import admin_api, api
        from frontend.ui import render, ui

        app = Flask(
            __name__,
            template_folder=folder_path('frontend', 'templates'),
            static_folder=folder_path('frontend', 'static'),
            static_url_path='/static'
        )
        # A fresh random key on every start.
        app.config['SECRET_KEY'] = urandom(32)
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
        app.config['JSON_SORT_KEYS'] = False

        # Add error handlers
        @app.errorhandler(400)
        def bad_request(e):
            return {'error': "BadRequest", "result": {}}, 400

        @app.errorhandler(404)
        def not_found(e):
            # API routes get a JSON error, everything else the HTML page.
            if request.path.startswith(
                (Constants.API_PREFIX, Constants.ADMIN_PREFIX)
            ):
                return {'error': "NotFound", "result": {}}, 404
            # NOTE(review): no explicit status code is attached to the
            # rendered page here — confirm whether 404 should be returned.
            return render("page_not_found.html")

        @app.errorhandler(405)
        def method_not_allowed(e):
            return {'error': "MethodNotAllowed", "result": {}}, 405

        @app.errorhandler(500)
        def internal_error(e):
            return {'error': "InternalError", "result": {}}, 500

        # Add endpoints
        app.register_blueprint(ui)
        app.register_blueprint(api, url_prefix=Constants.API_PREFIX)
        app.register_blueprint(admin_api, url_prefix=Constants.ADMIN_PREFIX)

        # Setup db handling
        app.teardown_appcontext(close_db)

        self.app = app
        return

    def set_url_prefix(self, url_prefix: str) -> None:
        """Change the URL prefix of the server.

        Args:
            url_prefix (str): The desired URL prefix to set it to.
        """
        self.app.config["APPLICATION_ROOT"] = url_prefix
        # Mount the real app under the prefix; an empty Flask app is used for
        # every path outside of the prefix.
        self.app.wsgi_app = DispatcherMiddleware(  # type: ignore
            Flask(__name__),
            {url_prefix: self.app.wsgi_app}
        )
        self.url_prefix = url_prefix
        return

    def __create_waitress_server(
        self,
        host: str,
        port: int
    ) -> Union[MultiSocketServer, BaseWSGIServer]:
        """From the `Flask` instance created in `self.create_app()`, create
        a waitress server instance.

        Args:
            host (str): Where to host the server on (e.g. `0.0.0.0`).
            port (int): The port to host the server on (e.g. `5656`).

        Returns:
            Union[MultiSocketServer, BaseWSGIServer]: The waitress server
            instance.
        """
        # Use our own dispatcher so that DB connections of worker threads
        # are closed when the threads end.
        dispatcher = ThreadedTaskDispatcher()
        dispatcher.set_thread_count(Constants.HOSTING_THREADS)

        server = create_server(
            self.app,
            _dispatcher=dispatcher,
            host=host,
            port=port,
            threads=Constants.HOSTING_THREADS
        )
        return server

    def run(self, host: str, port: int) -> Union[StartType, None]:
        """Start the webserver.

        Args:
            host (str): Where to host the server on (e.g. `0.0.0.0`).
            port (int): The port to host the server on (e.g. `5656`).

        Returns:
            Union[StartType, None]: `None` on shutdown, `StartType` on restart.
        """
        self.server = self.__create_waitress_server(host, port)
        LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
        self.server.run()

        return self.__start_type

    def __trigger_server_shutdown(self) -> None:
        """Shutdown waitress server. Intended to be run in a thread."""
        # Nothing to shut down when `run()` has not created a server yet.
        if not hasattr(self, 'server'):
            return

        self.server.task_dispatcher.shutdown()
        self.server.close()
        self.server._map.clear()  # type: ignore
        return

    def shutdown(self) -> None:
        """
        Stop the waitress server. Starts a thread that will trigger the server
        shutdown after one second.
        """
        self.get_db_timer_thread(
            1.0,
            self.__trigger_server_shutdown,
            "InternalStateHandler"
        ).start()
        return

    def restart(
        self,
        start_type: StartType = StartType.RESTART
    ) -> None:
        """Same as `self.shutdown()`, but restart instead of shutting down.

        Args:
            start_type (StartType, optional): Why MIND should restart.
                Defaults to StartType.RESTART.
        """
        # Stored so that `run()` returns it once the server has stopped.
        self.__start_type = start_type
        self.shutdown()
        return

    def get_db_timer_thread(
        self,
        interval: float,
        target: Callable[..., object],
        name: Union[str, None] = None,
        args: Iterable[Any] = (),
        kwargs: Union[Mapping[str, Any], None] = None
    ) -> Timer:
        """Create a timer thread that runs under Flask app context.

        The thread is returned unstarted. The DB connection of the thread is
        closed after the target has run, also when the target raises.

        Args:
            interval (float): The time to wait before running the target.

            target (Callable[..., object]): The function to run in the thread.

            name (Union[str, None], optional): The name of the thread.
                Defaults to None.

            args (Iterable[Any], optional): The arguments to pass to the
            function.
                Defaults to ().

            kwargs (Union[Mapping[str, Any], None], optional): The keyword
            arguments to pass to the function. `None` means no keyword
            arguments.
                Defaults to None.

        Returns:
            Timer: The timer thread instance.
        """
        def db_thread(*args, **kwargs) -> None:
            try:
                with self.app.app_context():
                    target(*args, **kwargs)
            finally:
                # Always release the connection of this thread, even when
                # the target raised, so that it is not leaked.
                DBConnectionManager.close_connection_of_thread()
            return

        # `Timer` treats `kwargs=None` as "no keyword arguments".
        t = Timer(
            interval=interval,
            function=db_thread,
            args=args,
            kwargs=kwargs
        )
        if name:
            t.name = name
        return t
# region StartType Handling
class StartTypeHandlers:
    """Registry of `StartTypeHandler` instances, plus the single timer that
    can be running for one of them at a time."""

    # Maps a start type to the handler instance registered for it.
    handlers: dict[StartType, StartTypeHandler] = {}
    # The most recently started timer thread, if any.
    timeout_thread: Union[Timer, None] = None
    # The start type that the current timer was started for, if any.
    running_handler: Union[StartType, None] = None

    @classmethod
    def register_handler(cls, start_type: StartType):
        """Class decorator that instantiates the decorated `StartTypeHandler`
        and registers the instance for the given start type.

        ```
        @StartTypeHandlers.register_handler(example_type)
        class ExampleHandler(StartTypeHandler):
            ...
        ```

        Args:
            start_type (StartType): The start type that the StartTypeHandler
            is for.
        """
        def decorator(
            klass: type[StartTypeHandler]
        ) -> type[StartTypeHandler]:
            # Store an instance, but hand the class itself back unchanged.
            cls.handlers[start_type] = klass()
            return klass

        return decorator

    @staticmethod
    def _on_timeout_wrapper(
        on_timeout: Callable[[], None],
        restart_on_timeout: bool
    ) -> None:
        """Run the timeout callback and restart the server afterwards if
        requested.

        Args:
            on_timeout (Callable[[], None]): The callback to run.
            restart_on_timeout (bool): Whether to restart the server after
            the callback has run.
        """
        on_timeout()
        if not restart_on_timeout:
            return
        Server().restart()

    @classmethod
    def start_timer(cls, start_type: StartType) -> None:
        """Start the timer for a start type. Does nothing when no handler is
        registered for it.

        Args:
            start_type (StartType): The start type to start the timer for.
        """
        try:
            handler = cls.handlers[start_type]
        except KeyError:
            return

        # Only one timer at a time: cancel one that is still pending.
        previous = cls.timeout_thread
        if previous is not None and previous.is_alive():
            previous.cancel()

        cls.running_handler = start_type
        timer = Server().get_db_timer_thread(
            interval=handler.timeout,
            target=cls._on_timeout_wrapper,
            name="StartTypeHandler",
            args=(handler.on_timeout, handler.restart_on_timeout)
        )
        cls.timeout_thread = timer
        timer.start()

        LOGGER.info(
            "Starting timer for %s (%d seconds)",
            handler.description, handler.timeout
        )

    @classmethod
    def diffuse_timer(cls, start_type: StartType) -> None:
        """Stop/Diffuse the timer for a start type. Does nothing when the
        timer is not running or was started for another start type.

        Args:
            start_type (StartType): The start type to stop the timer for.
        """
        if cls.running_handler != start_type:
            return

        timer = cls.timeout_thread
        if not (timer and timer.is_alive()):
            return

        handler = cls.handlers[start_type]
        LOGGER.info("Timer for %s diffused", handler.description)

        timer.cancel()
        cls.timeout_thread = None
        cls.running_handler = None
        handler.on_diffuse()
@StartTypeHandlers.register_handler(StartType.RESTART_HOSTING_CHANGES)
class HostingChangesHandler(StartTypeHandler):
    """Handler registered for `StartType.RESTART_HOSTING_CHANGES`."""

    description = "hosting changes"
    restart_on_timeout = True
    timeout = Constants.HOSTING_REVERT_TIME

    def on_timeout(self) -> None:
        """Call `Settings().restore_hosting_settings()` when the timer runs
        out."""
        settings = Settings()
        settings.restore_hosting_settings()

    def on_diffuse(self) -> None:
        """Nothing has to be done when the timer is diffused."""
        return None
@StartTypeHandlers.register_handler(StartType.RESTART_DB_CHANGES)
class DatabaseChangesHandler(StartTypeHandler):
    """Handler registered for `StartType.RESTART_DB_CHANGES`."""

    description = "database import"
    restart_on_timeout = True
    timeout = Constants.DB_REVERT_TIME

    def on_timeout(self) -> None:
        """Call `revert_db_import` with `swap=True` when the timer runs
        out."""
        # NOTE(review): presumably `swap=True` puts the previous database
        # back in place — confirm against `revert_db_import`.
        revert_db_import(swap=True)

    def on_diffuse(self) -> None:
        """Call `revert_db_import` with `swap=False` when the timer is
        diffused."""
        revert_db_import(swap=False)