# Mirror of https://github.com/Casvt/MIND.git
# (synced 2026-02-19 11:54:46 -05:00; 369 lines, 11 KiB, Python)
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Setting up, running and shutting down the webserver.
Also handles the start types (timers that revert changes after a restart).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from os import urandom
|
|
from threading import Timer
|
|
from typing import Any, Callable, Iterable, Mapping, Union
|
|
|
|
from flask import Flask, request
|
|
from flask.json.provider import DefaultJSONProvider
|
|
from waitress.server import create_server
|
|
from waitress.task import ThreadedTaskDispatcher as TTD
|
|
from werkzeug.middleware.dispatcher import DispatcherMiddleware
|
|
|
|
from backend.base.custom_exceptions import (BadRequest, InternalError,
|
|
LogUnauthMindException,
|
|
MethodNotAllowed, NotFound)
|
|
from backend.base.definitions import (Constants, MindException,
|
|
StartType, StartTypeHandler)
|
|
from backend.base.helpers import Singleton, folder_path, return_api
|
|
from backend.base.logging import LOGGER
|
|
from backend.internals.db import DBConnectionManager, close_db
|
|
from backend.internals.db_backup_import import revert_db_import
|
|
from backend.internals.settings import Settings
|
|
|
|
|
|
# region Thread Manager
|
|
class ThreadedTaskDispatcher(TTD):
    """Waitress task dispatcher that closes a worker thread's DB connection
    right before waitress stops tracking that thread.
    """

    class _ConnectionClosingSet(set):
        """`set` whose `discard()` first calls
        `DBConnectionManager.close_connection_of_thread()` and only then
        removes the element.
        """

        def discard(self, element: Any) -> None:
            DBConnectionManager.close_connection_of_thread()
            return super().discard(element)

    def __init__(self) -> None:
        super().__init__()
        # The DB connection should be closed while the thread is ending, but
        # before it has actually ended. Waitress considers a thread closed
        # as soon as it is no longer in `self.threads`, regardless of
        # whether the thread has really ended/joined. Anything done after
        # that point could be cut short by the main thread exiting. So the
        # connection is closed just before the thread is discarded from the
        # set.
        self.threads = self._ConnectionClosingSet()

    def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
        """Log that MIND is shutting down, then delegate to waitress.

        Args:
            cancel_pending (bool, optional): Passed through to waitress.
                Defaults to True.

            timeout (int, optional): Passed through to waitress.
                Defaults to 5.

        Returns:
            bool: Whatever the waitress dispatcher's `shutdown()` returns.
        """
        # Empty line first, presumably so the log line doesn't follow e.g. a
        # ^C echoed on the terminal.
        print()
        LOGGER.info('Shutting down MIND')
        return super().shutdown(cancel_pending, timeout)
|
|
|
|
|
|
# region Server
|
|
class Server(metaclass=Singleton):
    """Owns the Flask app and the waitress server that hosts it.

    Uses the project's `Singleton` metaclass (presumably so that every
    `Server()` call returns the same instance; confirm in
    backend.base.helpers).
    """

    # URL prefix/base the endpoints are hosted under. `run()` stores it on
    # the class, so it can also be read as `Server.url_prefix`.
    url_prefix = ''

    def __init__(self) -> None:
        # Value returned by `run()` once the server stops: `None` for a plain
        # shutdown, or the `StartType` passed to `restart()`.
        self.__start_type: Union[StartType, None] = None
        self.app = self._create_app()
        return

    @staticmethod
    def _create_app() -> Flask:
        """Creates a flask app instance that can be used to start a web server.

        Configures JSON output and registers the error handlers, the UI, API
        and admin API blueprints, and the DB teardown function.

        Returns:
            Flask: The instance.
        """
        # Imported locally rather than at module level; presumably to avoid
        # a circular import with the frontend package (confirm).
        from frontend.api import admin_api, api
        from frontend.ui import render, ui

        app = Flask(
            __name__,
            template_folder=folder_path('frontend', 'templates'),
            static_folder=folder_path('frontend', 'static'),
            static_url_path='/static'
        )
        # A new random secret key is generated every time the app is created.
        app.config['SECRET_KEY'] = urandom(32)

        # Keep keys in their original order and use the non-compact
        # (indented) JSON representation for responses.
        json_provider = DefaultJSONProvider(app)
        json_provider.sort_keys = False
        json_provider.compact = False
        app.json = json_provider

        # Add error handlers
        @app.errorhandler(400)
        def bad_request(e):
            return return_api(**BadRequest().api_response)

        @app.errorhandler(404)
        def not_found(e):
            # API and admin API paths get a JSON error response; any other
            # path gets the HTML "page not found" page.
            if request.path.startswith(
                (Constants.API_PREFIX, Constants.ADMIN_PREFIX)
            ):
                return return_api(**NotFound().api_response)
            return render("page_not_found.html")

        @app.errorhandler(405)
        def method_not_allowed(e):
            return return_api(**MethodNotAllowed().api_response)

        @app.errorhandler(500)
        def internal_error(e):
            return return_api(**InternalError().api_response)

        @app.errorhandler(MindException)
        def mind_exception(e: MindException):
            # Log the client's address for unauthorised requests. The
            # X-Forwarded-For header takes precedence over the socket
            # address, so the original client is shown when behind a proxy.
            if isinstance(e, LogUnauthMindException):
                ip = request.environ.get(
                    'HTTP_X_FORWARDED_FOR',
                    request.remote_addr
                )
                LOGGER.warning(f'Unauthorised request from {ip}')

            return return_api(**e.api_response)

        # Add endpoints
        app.register_blueprint(ui)
        app.register_blueprint(api, url_prefix=Constants.API_PREFIX)
        app.register_blueprint(admin_api, url_prefix=Constants.ADMIN_PREFIX)

        # Setup db handling: `close_db` runs whenever an app context ends.
        app.teardown_appcontext(close_db)

        return app

    def run(
        self,
        host: str,
        port: int,
        url_prefix: str
    ) -> Union[StartType, None]:
        """Start the webserver. Blocks until the server is shut down.

        Args:
            host (str): IP address to bind to, or `0.0.0.0` for all.
            port (int): The port to listen on.
            url_prefix (str): The url prefix/base to host the endpoints on, or
                an empty string for no prefix.

        Returns:
            Union[StartType, None]: `None` on shutdown, `StartType` on restart.
        """
        self.app.config["APPLICATION_ROOT"] = url_prefix
        # Mount the real app under the prefix. Requests outside the prefix
        # fall through to an empty Flask app with no routes registered.
        self.app.wsgi_app = DispatcherMiddleware( # type: ignore
            Flask(__name__),
            {url_prefix: self.app.wsgi_app}
        )
        self.__class__.url_prefix = url_prefix

        # Custom dispatcher so that worker threads close their DB
        # connection before waitress forgets about them.
        dispatcher = ThreadedTaskDispatcher()
        dispatcher.set_thread_count(Constants.HOSTING_THREADS)

        self.server = create_server(
            self.app,
            _dispatcher=dispatcher,
            host=host,
            port=port,
            threads=Constants.HOSTING_THREADS
        )

        LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
        self.server.run()

        return self.__start_type

    def __trigger_server_shutdown(self) -> None:
        """Shutdown waitress server. Intended to be run in a thread.

        Does nothing if `run()` has not created a server yet.
        """
        if not hasattr(self, 'server'):
            return

        self.server.task_dispatcher.shutdown()
        self.server.close()
        # Empty waitress' socket map (a private attribute).
        self.server._map.clear() # type: ignore
        return

    def shutdown(self) -> None:
        """
        Stop the waitress server. Starts a thread that will trigger the server
        shutdown after one second.
        """
        self.get_db_timer_thread(
            1.0,
            self.__trigger_server_shutdown,
            "InternalStateHandler"
        ).start()
        return

    def restart(
        self,
        start_type: StartType = StartType.RESTART
    ) -> None:
        """Same as `self.shutdown()`, but restart instead of shutting down.

        Args:
            start_type (StartType, optional): Why MIND should
                restart. It is returned by `run()` once the server stops.
                Defaults to StartType.RESTART.
        """
        self.__start_type = start_type
        self.shutdown()
        return

    def get_db_timer_thread(
        self,
        interval: float,
        target: Callable[..., object],
        name: Union[str, None] = None,
        args: Iterable[Any] = (),
        kwargs: Union[Mapping[str, Any], None] = None
    ) -> Timer:
        """Create a timer thread that runs under Flask app context.

        Because the target runs inside an app context, the app's
        `teardown_appcontext` handlers (such as `close_db`) run when it
        finishes. The returned thread is not started.

        Args:
            interval (float): The time to wait before running the target.

            target (Callable[..., object]): The function to run in the thread.

            name (Union[str, None], optional): The name of the thread.
                Defaults to None.

            args (Iterable[Any], optional): The arguments to pass to the function.
                Defaults to ().

            kwargs (Union[Mapping[str, Any], None], optional): The keyword
                arguments to pass to the function. `None` means none.
                Defaults to None.

        Returns:
            Timer: The timer thread instance.
        """
        def db_thread(*args, **kwargs) -> None:
            with self.app.app_context():
                target(*args, **kwargs)
            return

        t = Timer(
            interval=interval,
            function=db_thread,
            args=args,
            # A fresh dict per call. The previous mutable `{}` default was
            # shared between all calls.
            kwargs={} if kwargs is None else kwargs
        )
        if name:
            t.name = name
        return t
|
|
|
|
|
|
# region StartType Handling
|
|
class StartTypeHandlers:
    """Registry of `StartTypeHandler` instances, plus the (at most one)
    timeout timer currently scheduled for one of them.
    """

    # StartType -> handler instance; filled by `register_handler`.
    handlers: dict[StartType, StartTypeHandler] = {}
    # Timer thread of the most recently started timer, if any.
    timeout_thread: Union[Timer, None] = None
    # Start type whose timer was started last; cleared when diffused.
    running_handler: Union[StartType, None] = None

    @classmethod
    def register_handler(cls, start_type: StartType):
        """Class decorator to register a StartTypeHandler for a certain start
        type. The decorated class is instantiated once, at decoration time.

        ```
        @StartTypeHandlers.register_handler(example_type)
        class ExampleHandler(StartTypeHandler):
            ...
        ```

        Args:
            start_type (StartType): The start type that the StartTypeHandler is
                for.
        """
        def wrapper(
            handler_class: type[StartTypeHandler]
        ) -> type[StartTypeHandler]:
            cls.handlers[start_type] = handler_class()
            return handler_class

        return wrapper

    @staticmethod
    def _on_timeout_wrapper(
        on_timeout: Callable[[], None],
        restart_on_timeout: bool
    ) -> None:
        """Timer target: run the handler's timeout callback. If requested,
        then restart the server (with `Server.restart()`'s default start
        type).
        """
        on_timeout()
        if not restart_on_timeout:
            return
        Server().restart()

    @classmethod
    def start_timer(cls, start_type: StartType) -> None:
        """Start the timer for a start type.

        Start types without a registered handler are ignored. A timer that is
        still alive is cancelled first, so only one timer runs at a time.

        Args:
            start_type (StartType): The start type to start the timer for.
        """
        try:
            handler = cls.handlers[start_type]
        except KeyError:
            return

        previous = cls.timeout_thread
        if previous is not None and previous.is_alive():
            previous.cancel()

        cls.running_handler = start_type
        timer = Server().get_db_timer_thread(
            interval=handler.timeout,
            target=cls._on_timeout_wrapper,
            name="StartTypeHandler",
            args=(handler.on_timeout, handler.restart_on_timeout)
        )
        cls.timeout_thread = timer
        timer.start()
        LOGGER.info(
            "Starting timer for %s (%d seconds)",
            handler.description, handler.timeout
        )

    @classmethod
    def diffuse_timer(cls, start_type: StartType) -> None:
        """Stop/Diffuse the timer for a start type.

        Only acts when `start_type` is the type whose timer was started last
        and that timer thread is still alive. In that case the timer is
        cancelled, the state is cleared and the handler's `on_diffuse()` is
        called.

        Args:
            start_type (StartType): The start type to stop the timer for.
        """
        if cls.running_handler != start_type:
            return

        timer = cls.timeout_thread
        # NOTE(review): a Timer thread is also alive while its callback is
        # running. So a diffuse racing the timeout still calls on_diffuse.
        if timer is None or not timer.is_alive():
            return

        handler = cls.handlers[start_type]
        LOGGER.info(
            "Timer for %s diffused",
            handler.description
        )
        timer.cancel()
        cls.timeout_thread = None
        cls.running_handler = None
        handler.on_diffuse()
|
|
|
|
|
|
@StartTypeHandlers.register_handler(StartType.RESTART_HOSTING_CHANGES)
class HostingChangesHandler(StartTypeHandler):
    """Handler for a restart caused by changed hosting settings.

    When the timer expires, the hosting settings are restored through
    `Settings().restore_hosting_settings()` and the server is restarted.
    Diffusing the timer keeps the new settings.
    """

    description = "hosting changes"
    timeout = Constants.HOSTING_REVERT_TIME
    restart_on_timeout = True

    def on_timeout(self) -> None:
        """Timer expired: restore the hosting settings (presumably to the
        values from before the change; confirm in Settings)."""
        Settings().restore_hosting_settings()

    def on_diffuse(self) -> None:
        """Timer diffused: nothing to clean up."""
|
|
|
|
|
|
@StartTypeHandlers.register_handler(StartType.RESTART_DB_CHANGES)
class DatabaseChangesHandler(StartTypeHandler):
    """Handler for a restart caused by a database import.

    Both outcomes call `revert_db_import`:
    - When the timer expires, it is called with `swap=True` and the server is
      restarted.
    - When the timer is diffused, it is called with `swap=False`.
    """

    description = "database import"
    timeout = Constants.DB_REVERT_TIME
    restart_on_timeout = True

    def on_timeout(self) -> None:
        """Timer expired: `revert_db_import(swap=True)`. Presumably this swaps
        the pre-import database back in; confirm in db_backup_import."""
        revert_db_import(swap=True)

    def on_diffuse(self) -> None:
        """Timer diffused: `revert_db_import(swap=False)`. Presumably this
        keeps the imported database and cleans up; confirm in
        db_backup_import."""
        revert_db_import(swap=False)
|