Refactored server code

This commit is contained in:
CasVT
2024-03-01 12:47:50 +01:00
parent 5a6ef16e95
commit 1e6ef57d6a
9 changed files with 334 additions and 258 deletions

153
MIND.py
View File

@@ -6,24 +6,13 @@ The main file where MIND is started from
"""
import logging
from os import execv, makedirs, urandom
from os.path import dirname, isfile
from shutil import move
from sys import argv
from typing import Union
from flask import Flask, render_template, request
from waitress.server import create_server
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from backend.db import (DBConnection, ThreadedTaskDispatcher, close_db,
revert_db_import, setup_db)
from backend.helpers import RestartVars, check_python_version, folder_path
from backend.db import setup_db, setup_db_location
from backend.helpers import check_python_version
from backend.reminders import ReminderHandler
from backend.settings import get_setting, restore_hosting_settings
from frontend.api import (APIVariables, admin_api, admin_api_prefix, api,
api_prefix, revert_db_thread, revert_hosting_thread)
from frontend.ui import UIVariables, ui
from backend.server import SERVER, handle_flags
from backend.settings import get_setting
#=============================
# WARNING:
@@ -36,8 +25,6 @@ URL_PREFIX = '' # Must either be empty or start with '/' e.g. '/mind'
#=============================
LOGGING_LEVEL = logging.INFO
THREADS = 10
DB_FILENAME = 'db', 'MIND.db'
logging.basicConfig(
level=LOGGING_LEVEL,
@@ -45,97 +32,6 @@ logging.basicConfig(
datefmt='%Y-%m-%d %H:%M:%S'
)
def _create_app() -> Flask:
"""Create a Flask app instance
Returns:
Flask: The created app instance
"""
app = Flask(
__name__,
template_folder=folder_path('frontend','templates'),
static_folder=folder_path('frontend','static'),
static_url_path='/static'
)
app.config['SECRET_KEY'] = urandom(32)
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
app.config['JSON_SORT_KEYS'] = False
# Add error handlers
@app.errorhandler(400)
def bad_request(e):
return {'error': 'Bad request', 'result': {}}, 400
@app.errorhandler(405)
def method_not_allowed(e):
return {'error': 'Method not allowed', 'result': {}}, 405
@app.errorhandler(500)
def internal_error(e):
return {'error': 'Internal error', 'result': {}}, 500
@app.errorhandler(404)
def not_found(e):
if request.path.startswith(api_prefix):
return {'error': 'Not Found', 'result': {}}, 404
return render_template('page_not_found.html', url_prefix=UIVariables.url_prefix)
app.register_blueprint(ui)
app.register_blueprint(api, url_prefix=api_prefix)
app.register_blueprint(admin_api, url_prefix=admin_api_prefix)
# Setup closing database
app.teardown_appcontext(close_db)
return app
def _set_url_prefix(app: Flask, url_prefix: str) -> None:
"""Change the URL prefix of the server.
Args:
app (Flask): The `Flask` instance to change the URL prefix of.
url_prefix (str): The desired URL prefix to set it to.
"""
app.config["APPLICATION_ROOT"] = url_prefix
app.wsgi_app = DispatcherMiddleware(
Flask(__name__),
{url_prefix: app.wsgi_app}
)
UIVariables.url_prefix = url_prefix
return
def _handle_flags(flag: Union[None, str]) -> None:
"""Run flag specific actions on startup.
Args:
flag (Union[None, str]): The flag or `None` if there is no flag set.
"""
if flag == RestartVars.DB_IMPORT:
logging.info('Starting timer for database import')
revert_db_thread.start()
elif flag == RestartVars.HOST_CHANGE:
logging.info('Starting timer for hosting changes')
revert_hosting_thread.start()
return
def _handle_flags_pre_restart(flag: Union[None, str]) -> None:
"""Run flag specific actions just before restarting.
Args:
flag (Union[None, str]): The flag or `None` if there is no flag set.
"""
if flag == RestartVars.DB_IMPORT:
revert_db_import(swap=True)
elif flag == RestartVars.HOST_CHANGE:
with Flask(__name__).app_context():
restore_hosting_settings()
close_db()
return
def MIND() -> None:
"""The main function of MIND
"""
@@ -145,53 +41,30 @@ def MIND() -> None:
exit(1)
flag = argv[1] if len(argv) > 1 else None
_handle_flags(flag)
handle_flags(flag)
if isfile(folder_path('db', 'Noted.db')):
move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME))
setup_db_location()
db_location = folder_path(*DB_FILENAME)
logging.debug(f'Database location: {db_location}')
makedirs(dirname(db_location), exist_ok=True)
DBConnection.file = db_location
app = _create_app()
reminder_handler = ReminderHandler(app.app_context)
with app.app_context():
SERVER.create_app()
reminder_handler = ReminderHandler(SERVER.app.app_context)
with SERVER.app.app_context():
setup_db()
host = get_setting("host")
port = get_setting("port")
url_prefix = get_setting("url_prefix")
_set_url_prefix(app, url_prefix)
SERVER.set_url_prefix(url_prefix)
reminder_handler.find_next_reminder()
# Create waitress server and run
dispatcher = ThreadedTaskDispatcher()
dispatcher.set_thread_count(THREADS)
server = create_server(
app,
_dispatcher=dispatcher,
host=host,
port=port,
threads=THREADS
)
APIVariables.server_instance = server
logging.info(f'MIND running on http://{host}:{port}{url_prefix}')
# =================
server.run()
SERVER.run(host, port)
# =================
reminder_handler.stop_handling()
if APIVariables.restart:
if APIVariables.handle_flags:
_handle_flags_pre_restart(flag)
logging.info('Restarting MIND')
execv(__file__, [argv[0], *APIVariables.restart_args])
if SERVER.do_restart:
SERVER.handle_restart(flag)
return

View File

@@ -6,8 +6,8 @@ Setting up and interacting with the database.
import logging
from datetime import datetime
from os import remove
from os.path import dirname, join
from os import makedirs, remove
from os.path import dirname, isfile, join
from shutil import move
from sqlite3 import Connection, OperationalError, ProgrammingError, Row
from threading import current_thread, main_thread
@@ -15,12 +15,12 @@ from time import time
from typing import Type, Union
from flask import g
from waitress.task import ThreadedTaskDispatcher as OldThreadedTaskDispatcher
from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile,
UserNotFound)
from backend.helpers import RestartVars
from backend.helpers import RestartVars, folder_path
DB_FILENAME = 'db', 'MIND.db'
__DATABASE_VERSION__ = 9
__DATEBASE_NAME_ORIGINAL__ = "MIND_original.db"
@@ -34,20 +34,6 @@ class DB_Singleton(type):
return cls._instances[i]
class ThreadedTaskDispatcher(OldThreadedTaskDispatcher):
def handler_thread(self, thread_no: int) -> None:
super().handler_thread(thread_no)
i = f'{DBConnection}{current_thread()}'
if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed:
DB_Singleton._instances[i].close()
return
def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
print()
logging.info('Shutting down MIND')
result = super().shutdown(cancel_pending, timeout)
return result
class DBConnection(Connection, metaclass=DB_Singleton):
file = ''
@@ -67,6 +53,19 @@ class DBConnection(Connection, metaclass=DB_Singleton):
def __repr__(self) -> str:
return f'<{self.__class__.__name__}; {current_thread().name}; {id(self)}>'
def setup_db_location() -> None:
"""Create folder for database and link file to DBConnection class
"""
if isfile(folder_path('db', 'Noted.db')):
move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME))
db_location = folder_path(*DB_FILENAME)
logging.debug(f'Database location: {db_location}')
makedirs(dirname(db_location), exist_ok=True)
DBConnection.file = db_location
return
def get_db(output_type: Union[Type[dict], Type[tuple]]=tuple):
"""Get a database cursor instance. Coupled to Flask's g.
@@ -473,8 +472,7 @@ def import_db(new_db_file: str) -> None:
DBConnection.file
)
from frontend.api import APIVariables, restart_server_thread
APIVariables.restart_args = [RestartVars.DB_IMPORT.value]
restart_server_thread.start()
from backend.server import SERVER
SERVER.restart([RestartVars.DB_IMPORT.value])
return

View File

@@ -8,7 +8,7 @@ import logging
from enum import Enum
from os.path import abspath, dirname, join
from sys import version_info
from typing import Any, Callable, TypeVar, Union
from typing import Callable, TypeVar, Union
T = TypeVar('T')
U = TypeVar('U')

263
backend/server.py Normal file
View File

@@ -0,0 +1,263 @@
#-*- coding: utf-8 -*-
from __future__ import annotations
import logging
from os import execv, urandom
from sys import argv
from threading import Timer, current_thread
from typing import TYPE_CHECKING, List, NoReturn, Union
from flask import Flask, render_template, request
from waitress import create_server
from waitress.task import ThreadedTaskDispatcher as TTD
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from backend.db import DB_Singleton, DBConnection, close_db, revert_db_import
from backend.helpers import RestartVars, Singleton, folder_path
from backend.settings import restore_hosting_settings
if TYPE_CHECKING:
from waitress.server import TcpWSGIServer
THREADS = 10
class ThreadedTaskDispatcher(TTD):
def handler_thread(self, thread_no: int) -> None:
super().handler_thread(thread_no)
i = f'{DBConnection}{current_thread()}'
if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed:
DB_Singleton._instances[i].close()
return
def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
print()
logging.info('Shutting down MIND')
result = super().shutdown(cancel_pending, timeout)
return result
class Server(metaclass=Singleton):
api_prefix = "/api"
admin_api_extension = "/admin"
admin_prefix = "/api/admin"
def __init__(self) -> None:
self.do_restart = False
"Restart instead of shutdown"
self.restart_args: List[str] = []
"Flag to run with when restarting"
self.handle_flags: bool = False
"Run any flag specific actions before restarting"
self.url_prefix = ""
self.revert_db_timer = Timer(60.0, self.__revert_db)
self.revert_db_timer.name = "DatabaseImportHandler"
self.revert_hosting_timer = Timer(60.0, self.__revert_hosting)
self.revert_hosting_timer.name = "HostingHandler"
return
def create_app(self) -> None:
"""Create a Flask app instance"""
from frontend.api import admin_api, api
from frontend.ui import ui
app = Flask(
__name__,
template_folder=folder_path('frontend','templates'),
static_folder=folder_path('frontend','static'),
static_url_path='/static'
)
app.config['SECRET_KEY'] = urandom(32)
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
app.config['JSON_SORT_KEYS'] = False
# Add error handlers
@app.errorhandler(400)
def bad_request(e):
return {'error': 'Bad request', 'result': {}}, 400
@app.errorhandler(405)
def method_not_allowed(e):
return {'error': 'Method not allowed', 'result': {}}, 405
@app.errorhandler(500)
def internal_error(e):
return {'error': 'Internal error', 'result': {}}, 500
@app.errorhandler(404)
def not_found(e):
if request.path.startswith(self.api_prefix):
return {'error': 'Not Found', 'result': {}}, 404
return render_template('page_not_found.html', url_prefix=self.url_prefix)
app.register_blueprint(ui)
app.register_blueprint(api, url_prefix=self.api_prefix)
app.register_blueprint(admin_api, url_prefix=self.admin_prefix)
# Setup closing database
app.teardown_appcontext(close_db)
self.app = app
return
def set_url_prefix(self, url_prefix: str) -> None:
"""Change the URL prefix of the server.
Args:
url_prefix (str): The desired URL prefix to set it to.
"""
self.app.config["APPLICATION_ROOT"] = url_prefix
self.app.wsgi_app = DispatcherMiddleware(
Flask(__name__),
{url_prefix: self.app.wsgi_app}
)
self.url_prefix = url_prefix
return
def __create_waitress_server(
self,
host: str,
port: int
) -> TcpWSGIServer:
"""From the `Flask` instance created in `self.create_app()`, create
a waitress server instance.
Args:
host (str): The host to bind to.
port (int): The port to listen on.
Returns:
TcpWSGIServer: The waitress server.
"""
dispatcher = ThreadedTaskDispatcher()
dispatcher.set_thread_count(THREADS)
server = create_server(
self.app,
_dispatcher=dispatcher,
host=host,
port=port,
threads=THREADS
)
return server
def run(self, host: str, port: int) -> None:
"""Start the webserver.
Args:
host (str): The host to bind to.
port (int): The port to listen on.
"""
self.server = self.__create_waitress_server(host, port)
logging.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
self.server.run()
return
def __shutdown_thread_function(self) -> None:
"""Shutdown waitress server. Intended to be run in a thread.
"""
self.server.close()
self.server.task_dispatcher.shutdown()
self.server._map.clear()
return
def shutdown(self) -> None:
"""Stop the waitress server. Starts a thread that
shuts down the server.
"""
t = Timer(1.0, self.__shutdown_thread_function)
t.name = "InternalStateHandler"
t.start()
return
def restart(
self,
restart_args: List[str] = [],
handle_flags: bool = False
) -> None:
"""Same as `self.shutdown()`, but restart instead of shutting down.
Args:
restart_args (List[str], optional): Any arguments to run the new instance with.
Defaults to [].
handle_flags (bool, optional): Run flag specific actions just before restarting.
Defaults to False.
"""
self.do_restart = True
self.restart_args = restart_args
self.handle_flags = handle_flags
self.shutdown()
return
def handle_restart(self, flag: Union[str, None]) -> NoReturn:
"""Restart the interpreter.
Args:
flag (Union[str, None]): Supplied flag, for flag handling.
Returns:
NoReturn: No return because it replaces the interpreter.
"""
if self.handle_flags:
handle_flags_pre_restart(flag)
logging.info('Restarting MIND')
from MIND import __file__ as mind_file
execv(folder_path(mind_file), [argv[0], *self.restart_args])
def __revert_db(self) -> None:
"""Revert database import and restart.
"""
logging.warning(f'Timer for database import expired; reverting back to original file')
self.restart(handle_flags=True)
return
def __revert_hosting(self) -> None:
"""Revert the hosting changes.
"""
logging.warning(f'Timer for hosting changes expired; reverting back to original settings')
self.restart(handle_flags=True)
return
SERVER = Server()
def handle_flags(flag: Union[None, str]) -> None:
"""Run flag specific actions on startup.
Args:
flag (Union[None, str]): The flag or `None` if there is no flag set.
"""
if flag == RestartVars.DB_IMPORT:
logging.info('Starting timer for database import')
SERVER.revert_db_timer.start()
elif flag == RestartVars.HOST_CHANGE:
logging.info('Starting timer for hosting changes')
SERVER.revert_hosting_timer.start()
return
def handle_flags_pre_restart(flag: Union[None, str]) -> None:
"""Run flag specific actions just before restarting.
Args:
flag (Union[None, str]): The flag or `None` if there is no flag set.
"""
if flag == RestartVars.DB_IMPORT:
revert_db_import(swap=True)
elif flag == RestartVars.HOST_CHANGE:
with SERVER.app.app_context():
restore_hosting_settings()
close_db()
return

View File

@@ -8,9 +8,8 @@ from datetime import datetime
from io import BytesIO
from os import remove, urandom
from os.path import basename
from threading import Timer
from time import time as epoch_time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple
from flask import g, request, send_file
@@ -26,6 +25,7 @@ from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired,
from backend.db import get_db, import_db, revert_db_import
from backend.helpers import RestartVars, folder_path
from backend.notification_service import get_apprise_services
from backend.server import SERVER
from backend.settings import (backup_hosting_settings, get_admin_settings,
get_setting, set_setting)
from backend.users import Users
@@ -47,13 +47,10 @@ from frontend.input_validation import (AllowNewAccountsVariable, ColorVariable,
UrlPrefixVariable, URLVariable,
UsernameCreateVariable,
UsernameVariable, WeekDaysVariable,
admin_api, admin_api_prefix, api,
api_prefix, get_api_docs,
admin_api, api, get_api_docs,
input_validation)
if TYPE_CHECKING:
from waitress.server import BaseWSGIServer
from backend.users import User
@@ -61,58 +58,6 @@ if TYPE_CHECKING:
# General variables and functions
#===================
class APIVariables:
server_instance: Union[BaseWSGIServer, None] = None
restart: bool = False
"Restart instead of shutdown"
restart_args: List[str] = []
"Flag to run with when restarting"
handle_flags: bool = False
"Run any flag specific actions before restarting"
def shutdown_server() -> None:
"""Stop server from running"""
APIVariables.server_instance.close()
APIVariables.server_instance.task_dispatcher.shutdown()
APIVariables.server_instance._map.clear()
return
def restart_server() -> None:
"""Restart server.
Will completely replace the current process.
"""
APIVariables.restart = True
shutdown_server()
return
def revert_db() -> None:
"""Revert database import and restart.
"""
logging.warning(f'Timer for database import expired; reverting back to original file')
APIVariables.handle_flags = True
restart_server()
return
def revert_hosting() -> None:
"""Revert the hosting changes.
"""
logging.warning(f'Timer for hosting changes expired; reverting back to original settings')
APIVariables.handle_flags = True
restart_server()
return
shutdown_server_thread = Timer(1.0, shutdown_server)
shutdown_server_thread.name = "InternalStateHandler"
restart_server_thread = Timer(1.0, restart_server)
restart_server_thread.name = "InternalStateHandler"
revert_db_thread = Timer(60.0, revert_db)
revert_db_thread.name = "DatabaseImportHandler"
revert_hosting_thread = Timer(60.0, revert_hosting)
revert_hosting_thread.name = "HostingHandler"
@dataclass
class ApiKeyEntry:
exp: int
@@ -140,14 +85,14 @@ def auth() -> None:
if (
map_entry.user_data.admin
and
not request.path.startswith((admin_api_prefix, api_prefix + '/auth'))
not request.path.startswith((SERVER.admin_prefix, SERVER.api_prefix + '/auth'))
):
raise APIKeyInvalid
if (
not map_entry.user_data.admin
and
request.path.startswith(admin_api_prefix)
request.path.startswith(SERVER.admin_prefix)
):
raise APIKeyInvalid
@@ -216,17 +161,17 @@ def endpoint_wrapper(method: Callable) -> Callable:
@endpoint_wrapper
def api_login(inputs: Dict[str, str]):
user = users.login(inputs['username'], inputs['password'])
# Login successful
if user.admin and revert_db_thread.is_alive():
if user.admin and SERVER.revert_db_timer.is_alive():
logging.info('Timer for database import diffused')
revert_db_thread.cancel()
SERVER.revert_db_timer.cancel()
revert_db_import(swap=False)
elif user.admin and revert_hosting_thread.is_alive():
elif user.admin and SERVER.revert_hosting_timer.is_alive():
logging.info('Timer for hosting changes diffused')
revert_hosting_thread.cancel()
SERVER.revert_hosting_timer.cancel()
# Generate an API key until one
# is generated that isn't used already
@@ -722,7 +667,7 @@ def api_get_static_reminder(inputs: Dict[str, Any], s_id: int):
)
@endpoint_wrapper
def api_shutdown():
shutdown_server_thread.start()
SERVER.shutdown()
return return_api({})
@admin_api.route(
@@ -732,7 +677,7 @@ def api_shutdown():
)
@endpoint_wrapper
def api_restart():
restart_server_thread.start()
SERVER.restart()
return return_api({})
@api.route(
@@ -782,8 +727,7 @@ def api_admin_settings(inputs: Dict[str, Any]):
set_setting(k, v)
if hosting_changes:
APIVariables.restart_args = [RestartVars.HOST_CHANGE.value]
restart_server_thread.start()
SERVER.restart([RestartVars.HOST_CHANGE.value])
return return_api({})

View File

@@ -16,23 +16,20 @@ from apprise import Apprise
from flask import Blueprint, request
from flask.sansio.scaffold import T_route
from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile, InvalidKeyValue,
InvalidTime, KeyNotFound,
NewAccountsNotAllowed,
from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile,
InvalidKeyValue, InvalidTime,
KeyNotFound, NewAccountsNotAllowed,
NotificationServiceNotFound,
UsernameInvalid, UsernameTaken,
UserNotFound)
from backend.helpers import (RepeatQuantity, SortingMethod,
TimelessSortingMethod, folder_path)
from backend.server import SERVER
from backend.settings import _format_setting
if TYPE_CHECKING:
from flask import Request
api_prefix = "/api"
_admin_api_prefix = '/admin'
admin_api_prefix = api_prefix + _admin_api_prefix
color_regex = compile(r'#[0-9a-f]{6}')
api_docs: Dict[str, ApiDocEntry] = {}
@@ -130,10 +127,10 @@ class ApiDocEntry:
def get_api_docs(request: Request) -> ApiDocEntry:
if request.path.startswith(admin_api_prefix):
url = _admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
if request.path.startswith(SERVER.admin_prefix):
url = SERVER.admin_api_extension + request.url_rule.rule.split(SERVER.admin_prefix)[1]
else:
url = request.url_rule.rule.split(api_prefix)[1]
url = request.url_rule.rule.split(SERVER.api_prefix)[1]
return api_docs[url]
@@ -489,7 +486,7 @@ class APIBlueprint(Blueprint):
if self == api:
processed_rule = rule
elif self == admin_api:
processed_rule = _admin_api_prefix + rule
processed_rule = SERVER.admin_api_extension + rule
else:
raise NotImplementedError

View File

@@ -2,21 +2,20 @@
from flask import Blueprint, render_template
from backend.server import SERVER
ui = Blueprint('ui', __name__)
methods = ['GET']
class UIVariables:
url_prefix: str = ''
@ui.route('/', methods=methods)
def ui_login():
return render_template('login.html', url_prefix=UIVariables.url_prefix)
return render_template('login.html', url_prefix=SERVER.url_prefix)
@ui.route('/reminders', methods=methods)
def ui_reminders():
return render_template('reminders.html', url_prefix=UIVariables.url_prefix)
return render_template('reminders.html', url_prefix=SERVER.url_prefix)
@ui.route('/admin', methods=methods)
def ui_admin():
return render_template('admin.html', url_prefix=UIVariables.url_prefix)
return render_template('admin.html', url_prefix=SERVER.url_prefix)

View File

@@ -4,17 +4,19 @@ from flask import Flask
from frontend.api import api
from frontend.ui import ui
from MIND import _create_app
from backend.server import SERVER
class Test_MIND(unittest.TestCase):
def test_create_app(self):
result = _create_app()
self.assertIsInstance(result, Flask)
SERVER.create_app()
self.assertTrue(hasattr(SERVER, 'app'))
app = SERVER.app
self.assertIsInstance(app, Flask)
self.assertEqual(result.blueprints.get('ui'), ui)
self.assertEqual(result.blueprints.get('api'), api)
self.assertEqual(app.blueprints.get('ui'), ui)
self.assertEqual(app.blueprints.get('api'), api)
handlers = result.error_handler_spec[None].keys()
handlers = app.error_handler_spec[None].keys()
required_handlers = 400, 405, 500
for handler in required_handlers:
self.assertIn(handler, handlers)

View File

@@ -1,12 +1,12 @@
import unittest
from backend.db import DBConnection
from backend.db import DB_FILENAME, DBConnection
from backend.helpers import folder_path
from MIND import DB_FILENAME
class Test_DB(unittest.TestCase):
def test_foreign_key(self):
def test_foreign_key_and_wal(self):
DBConnection.file = folder_path(*DB_FILENAME)
instance = DBConnection(timeout=20.0)
self.assertEqual(instance.cursor().execute("PRAGMA foreign_keys;").fetchone()[0], 1)
self.assertEqual(instance.cursor().execute("PRAGMA journal_mode;").fetchone()[0], 'wal')