mirror of
https://github.com/Casvt/MIND.git
synced 2026-04-03 03:00:22 -04:00
Before it would be set via a method, but it's more logical for it to just be supplied when starting to run the server, alongside the host and port.
369 lines
11 KiB
Python
369 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Setting up, running and shutting down the webserver.
|
|
Also handling startup types.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from os import urandom
|
|
from threading import Timer
|
|
from typing import Any, Callable, Iterable, Mapping, Union
|
|
|
|
from flask import Flask, request
|
|
from flask.json.provider import DefaultJSONProvider
|
|
from waitress.server import create_server
|
|
from waitress.task import ThreadedTaskDispatcher as TTD
|
|
from werkzeug.middleware.dispatcher import DispatcherMiddleware
|
|
|
|
from backend.base.custom_exceptions import (BadRequest, InternalError,
|
|
LogUnauthMindException,
|
|
MethodNotAllowed, NotFound)
|
|
from backend.base.definitions import (Constants, MindException,
|
|
StartType, StartTypeHandler)
|
|
from backend.base.helpers import Singleton, folder_path, return_api
|
|
from backend.base.logging import LOGGER
|
|
from backend.internals.db import DBConnectionManager, close_db
|
|
from backend.internals.db_backup_import import revert_db_import
|
|
from backend.internals.settings import Settings
|
|
|
|
|
|
# region Thread Manager
|
|
class ThreadedTaskDispatcher(TTD):
    """Waitress task dispatcher that closes the database connection of a
    worker thread before that thread is removed from the thread set, and
    that logs a message when it is shut down.
    """

    def __init__(self) -> None:
        super().__init__()

        # Waitress treats a thread as finished as soon as it is no longer
        # in `self.threads`, whether or not the thread has really
        # ended/joined. Work done after that point can be cut short by the
        # main thread exiting, so the DB connection is closed right before
        # the thread is discarded from the set.
        class TDDSet(set):
            """A set that closes the calling thread's DB connection on
            every `discard()`."""

            def discard(self, element: Any) -> None:
                DBConnectionManager.close_connection_of_thread()
                super().discard(element)

        self.threads = TDDSet()

    def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
        """Log the shutdown and delegate to the parent implementation.

        Args:
            cancel_pending (bool, optional): Passed on to the parent.
                Defaults to True.
            timeout (int, optional): Passed on to the parent.
                Defaults to 5.

        Returns:
            bool: Whatever the parent `shutdown()` returns.
        """
        # Emit an empty line on stdout before the log message.
        print()
        LOGGER.info('Shutting down MIND')
        return super().shutdown(cancel_pending, timeout)
|
|
|
|
|
|
# region Server
|
|
class Server(metaclass=Singleton):
    """Owner of the Flask application and the waitress server hosting it.

    Offers methods to run, shut down and restart the server, and to create
    timer threads that execute inside the Flask app context.
    """

    # URL prefix that the endpoints are hosted on. Set by `run()`.
    url_prefix = ''

    def __init__(self) -> None:
        # Filled in by `restart()`; returned by `run()` once the server
        # has stopped. Stays `None` for a plain shutdown.
        self.__start_type: Union[StartType, None] = None
        self.app = self._create_app()
        return

    @staticmethod
    def _create_app() -> Flask:
        """Creates a flask app instance that can be used to start a web server.

        Returns:
            Flask: The instance.
        """
        # Imported here instead of at module level.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from frontend.api import admin_api, api
        from frontend.ui import render, ui

        app = Flask(
            __name__,
            template_folder=folder_path('frontend', 'templates'),
            static_folder=folder_path('frontend', 'static'),
            static_url_path='/static'
        )
        # A fresh random key is generated every time an app is created.
        app.config['SECRET_KEY'] = urandom(32)

        # JSON output: keep key order as-is and don't use compact output.
        json_provider = DefaultJSONProvider(app)
        json_provider.sort_keys = False
        json_provider.compact = False
        app.json = json_provider

        # Add error handlers
        @app.errorhandler(400)
        def bad_request(e):
            return return_api(**BadRequest().api_response)

        @app.errorhandler(404)
        def not_found(e):
            # API and admin paths get an API response, everything else
            # gets the rendered "page not found" template.
            if request.path.startswith(
                (Constants.API_PREFIX, Constants.ADMIN_PREFIX)
            ):
                return return_api(**NotFound().api_response)
            return render("page_not_found.html")

        @app.errorhandler(405)
        def method_not_allowed(e):
            return return_api(**MethodNotAllowed().api_response)

        @app.errorhandler(500)
        def internal_error(e):
            return return_api(**InternalError().api_response)

        @app.errorhandler(MindException)
        def mind_exception(e: MindException):
            if isinstance(e, LogUnauthMindException):
                # Use the forwarded-for header when it is present,
                # otherwise the address of the direct peer.
                ip = request.environ.get(
                    'HTTP_X_FORWARDED_FOR',
                    request.remote_addr
                )
                LOGGER.warning(f'Unauthorised request from {ip}')

            return return_api(**e.api_response)

        # Add endpoints
        app.register_blueprint(ui)
        app.register_blueprint(api, url_prefix=Constants.API_PREFIX)
        app.register_blueprint(admin_api, url_prefix=Constants.ADMIN_PREFIX)

        # Setup db handling
        app.teardown_appcontext(close_db)

        return app

    def run(
        self,
        host: str,
        port: int,
        url_prefix: str
    ) -> Union[StartType, None]:
        """Start the webserver.

        Args:
            host (str): IP address to bind to, or `0.0.0.0` for all.
            port (int): The port to listen on.
            url_prefix (str): The url prefix/base to host the endpoints on, or
                an empty string for no prefix.

        Returns:
            Union[StartType, None]: `None` on shutdown, `StartType` on restart.
        """
        # Mount the real app under the prefix; a bare Flask app is the
        # fallback application of the dispatcher middleware.
        self.app.config["APPLICATION_ROOT"] = url_prefix
        self.app.wsgi_app = DispatcherMiddleware(  # type: ignore
            Flask(__name__),
            {url_prefix: self.app.wsgi_app}
        )
        self.__class__.url_prefix = url_prefix

        dispatcher = ThreadedTaskDispatcher()
        dispatcher.set_thread_count(Constants.HOSTING_THREADS)

        self.server = create_server(
            self.app,
            _dispatcher=dispatcher,
            host=host,
            port=port,
            threads=Constants.HOSTING_THREADS
        )

        LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
        self.server.run()

        # Reached once the server has stopped running.
        return self.__start_type

    def __trigger_server_shutdown(self) -> None:
        """Shutdown waitress server. Intended to be run in a thread."""
        # `self.server` only exists after `run()` has created it.
        if not hasattr(self, 'server'):
            return

        self.server.task_dispatcher.shutdown()
        self.server.close()
        # NOTE(review): presumably clears the socket map so that the
        # server loop has nothing left to serve and returns - confirm.
        self.server._map.clear()  # type: ignore
        return

    def shutdown(self) -> None:
        """
        Stop the waitress server. Starts a thread that will trigger the server
        shutdown after one second.
        """
        self.get_db_timer_thread(
            1.0,
            self.__trigger_server_shutdown,
            "InternalStateHandler"
        ).start()
        return

    def restart(
        self,
        start_type: StartType = StartType.RESTART
    ) -> None:
        """Same as `self.shutdown()`, but restart instead of shutting down.

        Args:
            start_type (StartType, optional): Why MIND should restart.
                Defaults to StartType.RESTART.
        """
        # Stored so that `run()` returns it once the server has stopped.
        self.__start_type = start_type
        self.shutdown()
        return

    def get_db_timer_thread(
        self,
        interval: float,
        target: Callable[..., object],
        name: Union[str, None] = None,
        args: Iterable[Any] = (),
        kwargs: Union[Mapping[str, Any], None] = None
    ) -> Timer:
        """Create a timer thread that runs under Flask app context.

        Args:
            interval (float): The time to wait before running the target.

            target (Callable[..., object]): The function to run in the thread.

            name (Union[str, None], optional): The name of the thread.
                Defaults to None.

            args (Iterable[Any], optional): The arguments to pass to the function.
                Defaults to ().

            kwargs (Union[Mapping[str, Any], None], optional): The keyword
            arguments to pass to the function. `None` means no keyword
            arguments.
                Defaults to None.

        Returns:
            Timer: The timer thread instance. It is not started yet.
        """
        def db_thread(*t_args, **t_kwargs) -> None:
            # Run the target inside an app context of this server's app.
            with self.app.app_context():
                target(*t_args, **t_kwargs)
            return

        t = Timer(
            interval=interval,
            function=db_thread,
            args=args,
            # Every timer gets its own dict, so no mapping is shared
            # between timers or with a default argument value.
            kwargs=dict(kwargs or {})
        )
        if name:
            t.name = name
        return t
|
|
|
|
|
|
# region StartType Handling
|
|
class StartTypeHandlers:
    """Registry of start type handlers, plus the bookkeeping for the one
    timeout timer that is tracked at a time.
    """

    # One handler instance per registered start type.
    handlers: dict[StartType, StartTypeHandler] = {}
    # The most recently started timer, or None when there is none.
    timeout_thread: Union[Timer, None] = None
    # The start type that the current timer was started for.
    running_handler: Union[StartType, None] = None

    @classmethod
    def register_handler(cls, start_type: StartType):
        """Build a class decorator that registers an instance of the
        decorated StartTypeHandler class for the given start type.

        ```
        @StartTypeHandlers.register_handler(example_type)
        class ExampleHandler(StartTypeHandler):
            ...
        ```

        Args:
            start_type (StartType): The start type to register the
                decorated handler class for.
        """
        def wrapper(
            handler_class: type[StartTypeHandler]
        ) -> type[StartTypeHandler]:
            # Store an instance; hand the class itself back untouched.
            instance = handler_class()
            cls.handlers[start_type] = instance
            return handler_class

        return wrapper

    @staticmethod
    def _on_timeout_wrapper(
        on_timeout: Callable[[], None],
        restart_on_timeout: bool
    ) -> None:
        """Run the timeout callback and then, if requested, restart the
        server.

        Args:
            on_timeout (Callable[[], None]): The callback to run.
            restart_on_timeout (bool): Whether to restart the server after
                the callback has run.
        """
        on_timeout()
        if not restart_on_timeout:
            return
        Server().restart()

    @classmethod
    def start_timer(cls, start_type: StartType) -> None:
        """Start the timeout timer of the handler for a start type. Does
        nothing when no handler is registered for the start type.

        Args:
            start_type (StartType): The start type to start the timer for.
        """
        try:
            handler = cls.handlers[start_type]
        except KeyError:
            return

        # Cancel a timer that is still running before replacing it.
        previous = cls.timeout_thread
        if previous is not None and previous.is_alive():
            previous.cancel()

        cls.running_handler = start_type
        timer = Server().get_db_timer_thread(
            interval=handler.timeout,
            target=cls._on_timeout_wrapper,
            name="StartTypeHandler",
            args=(handler.on_timeout, handler.restart_on_timeout)
        )
        cls.timeout_thread = timer
        timer.start()

        LOGGER.info(
            "Starting timer for %s (%d seconds)",
            handler.description, handler.timeout
        )

    @classmethod
    def diffuse_timer(cls, start_type: StartType) -> None:
        """Cancel the running timer for a start type and run the handler's
        `on_diffuse()`. Does nothing when the timer was started for another
        start type, or when there is no timer that is still alive.

        Args:
            start_type (StartType): The start type to stop the timer for.
        """
        if cls.running_handler != start_type:
            return

        timer = cls.timeout_thread
        if timer is None or not timer.is_alive():
            return

        handler = cls.handlers[start_type]
        LOGGER.info(
            "Timer for %s diffused",
            handler.description
        )
        timer.cancel()
        cls.timeout_thread = None
        cls.running_handler = None
        handler.on_diffuse()
|
|
|
|
|
|
@StartTypeHandlers.register_handler(StartType.RESTART_HOSTING_CHANGES)
class HostingChangesHandler(StartTypeHandler):
    """Handler for the `RESTART_HOSTING_CHANGES` start type."""

    description = "hosting changes"
    timeout = Constants.HOSTING_REVERT_TIME
    restart_on_timeout = True

    def on_timeout(self) -> None:
        """Restore the hosting settings when the timer runs out."""
        Settings().restore_hosting_settings()

    def on_diffuse(self) -> None:
        """Nothing needs to be done when the timer is diffused."""
|
|
|
|
|
|
@StartTypeHandlers.register_handler(StartType.RESTART_DB_CHANGES)
class DatabaseChangesHandler(StartTypeHandler):
    """Handler for the `RESTART_DB_CHANGES` start type."""

    description = "database import"
    timeout = Constants.DB_REVERT_TIME
    restart_on_timeout = True

    def on_timeout(self) -> None:
        """Call `revert_db_import` with `swap=True` when the timer runs
        out."""
        revert_db_import(swap=True)

    def on_diffuse(self) -> None:
        """Call `revert_db_import` with `swap=False` when the timer is
        diffused."""
        revert_db_import(swap=False)
|