mirror of
https://github.com/Casvt/MIND.git
synced 2026-02-19 11:54:46 -05:00
Refactored server code
This commit is contained in:
153
MIND.py
153
MIND.py
@@ -6,24 +6,13 @@ The main file where MIND is started from
|
||||
"""
|
||||
|
||||
import logging
|
||||
from os import execv, makedirs, urandom
|
||||
from os.path import dirname, isfile
|
||||
from shutil import move
|
||||
from sys import argv
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, render_template, request
|
||||
from waitress.server import create_server
|
||||
from werkzeug.middleware.dispatcher import DispatcherMiddleware
|
||||
|
||||
from backend.db import (DBConnection, ThreadedTaskDispatcher, close_db,
|
||||
revert_db_import, setup_db)
|
||||
from backend.helpers import RestartVars, check_python_version, folder_path
|
||||
from backend.db import setup_db, setup_db_location
|
||||
from backend.helpers import check_python_version
|
||||
from backend.reminders import ReminderHandler
|
||||
from backend.settings import get_setting, restore_hosting_settings
|
||||
from frontend.api import (APIVariables, admin_api, admin_api_prefix, api,
|
||||
api_prefix, revert_db_thread, revert_hosting_thread)
|
||||
from frontend.ui import UIVariables, ui
|
||||
from backend.server import SERVER, handle_flags
|
||||
from backend.settings import get_setting
|
||||
|
||||
#=============================
|
||||
# WARNING:
|
||||
@@ -36,8 +25,6 @@ URL_PREFIX = '' # Must either be empty or start with '/' e.g. '/mind'
|
||||
#=============================
|
||||
|
||||
LOGGING_LEVEL = logging.INFO
|
||||
THREADS = 10
|
||||
DB_FILENAME = 'db', 'MIND.db'
|
||||
|
||||
logging.basicConfig(
|
||||
level=LOGGING_LEVEL,
|
||||
@@ -45,97 +32,6 @@ logging.basicConfig(
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
def _create_app() -> Flask:
|
||||
"""Create a Flask app instance
|
||||
|
||||
Returns:
|
||||
Flask: The created app instance
|
||||
"""
|
||||
app = Flask(
|
||||
__name__,
|
||||
template_folder=folder_path('frontend','templates'),
|
||||
static_folder=folder_path('frontend','static'),
|
||||
static_url_path='/static'
|
||||
)
|
||||
app.config['SECRET_KEY'] = urandom(32)
|
||||
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
|
||||
app.config['JSON_SORT_KEYS'] = False
|
||||
|
||||
# Add error handlers
|
||||
@app.errorhandler(400)
|
||||
def bad_request(e):
|
||||
return {'error': 'Bad request', 'result': {}}, 400
|
||||
|
||||
@app.errorhandler(405)
|
||||
def method_not_allowed(e):
|
||||
return {'error': 'Method not allowed', 'result': {}}, 405
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(e):
|
||||
return {'error': 'Internal error', 'result': {}}, 500
|
||||
|
||||
@app.errorhandler(404)
|
||||
def not_found(e):
|
||||
if request.path.startswith(api_prefix):
|
||||
return {'error': 'Not Found', 'result': {}}, 404
|
||||
return render_template('page_not_found.html', url_prefix=UIVariables.url_prefix)
|
||||
|
||||
app.register_blueprint(ui)
|
||||
app.register_blueprint(api, url_prefix=api_prefix)
|
||||
app.register_blueprint(admin_api, url_prefix=admin_api_prefix)
|
||||
|
||||
# Setup closing database
|
||||
app.teardown_appcontext(close_db)
|
||||
|
||||
return app
|
||||
|
||||
def _set_url_prefix(app: Flask, url_prefix: str) -> None:
|
||||
"""Change the URL prefix of the server.
|
||||
|
||||
Args:
|
||||
app (Flask): The `Flask` instance to change the URL prefix of.
|
||||
url_prefix (str): The desired URL prefix to set it to.
|
||||
"""
|
||||
app.config["APPLICATION_ROOT"] = url_prefix
|
||||
app.wsgi_app = DispatcherMiddleware(
|
||||
Flask(__name__),
|
||||
{url_prefix: app.wsgi_app}
|
||||
)
|
||||
UIVariables.url_prefix = url_prefix
|
||||
return
|
||||
|
||||
def _handle_flags(flag: Union[None, str]) -> None:
|
||||
"""Run flag specific actions on startup.
|
||||
|
||||
Args:
|
||||
flag (Union[None, str]): The flag or `None` if there is no flag set.
|
||||
"""
|
||||
if flag == RestartVars.DB_IMPORT:
|
||||
logging.info('Starting timer for database import')
|
||||
revert_db_thread.start()
|
||||
|
||||
elif flag == RestartVars.HOST_CHANGE:
|
||||
logging.info('Starting timer for hosting changes')
|
||||
revert_hosting_thread.start()
|
||||
|
||||
return
|
||||
|
||||
def _handle_flags_pre_restart(flag: Union[None, str]) -> None:
|
||||
"""Run flag specific actions just before restarting.
|
||||
|
||||
Args:
|
||||
flag (Union[None, str]): The flag or `None` if there is no flag set.
|
||||
"""
|
||||
if flag == RestartVars.DB_IMPORT:
|
||||
revert_db_import(swap=True)
|
||||
|
||||
elif flag == RestartVars.HOST_CHANGE:
|
||||
with Flask(__name__).app_context():
|
||||
restore_hosting_settings()
|
||||
close_db()
|
||||
|
||||
return
|
||||
|
||||
def MIND() -> None:
|
||||
"""The main function of MIND
|
||||
"""
|
||||
@@ -145,53 +41,30 @@ def MIND() -> None:
|
||||
exit(1)
|
||||
|
||||
flag = argv[1] if len(argv) > 1 else None
|
||||
_handle_flags(flag)
|
||||
handle_flags(flag)
|
||||
|
||||
if isfile(folder_path('db', 'Noted.db')):
|
||||
move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME))
|
||||
setup_db_location()
|
||||
|
||||
db_location = folder_path(*DB_FILENAME)
|
||||
logging.debug(f'Database location: {db_location}')
|
||||
makedirs(dirname(db_location), exist_ok=True)
|
||||
|
||||
DBConnection.file = db_location
|
||||
|
||||
app = _create_app()
|
||||
reminder_handler = ReminderHandler(app.app_context)
|
||||
with app.app_context():
|
||||
SERVER.create_app()
|
||||
reminder_handler = ReminderHandler(SERVER.app.app_context)
|
||||
with SERVER.app.app_context():
|
||||
setup_db()
|
||||
|
||||
host = get_setting("host")
|
||||
port = get_setting("port")
|
||||
url_prefix = get_setting("url_prefix")
|
||||
_set_url_prefix(app, url_prefix)
|
||||
SERVER.set_url_prefix(url_prefix)
|
||||
|
||||
reminder_handler.find_next_reminder()
|
||||
|
||||
# Create waitress server and run
|
||||
dispatcher = ThreadedTaskDispatcher()
|
||||
dispatcher.set_thread_count(THREADS)
|
||||
server = create_server(
|
||||
app,
|
||||
_dispatcher=dispatcher,
|
||||
host=host,
|
||||
port=port,
|
||||
threads=THREADS
|
||||
)
|
||||
APIVariables.server_instance = server
|
||||
logging.info(f'MIND running on http://{host}:{port}{url_prefix}')
|
||||
# =================
|
||||
server.run()
|
||||
SERVER.run(host, port)
|
||||
# =================
|
||||
|
||||
reminder_handler.stop_handling()
|
||||
|
||||
if APIVariables.restart:
|
||||
if APIVariables.handle_flags:
|
||||
_handle_flags_pre_restart(flag)
|
||||
|
||||
logging.info('Restarting MIND')
|
||||
execv(__file__, [argv[0], *APIVariables.restart_args])
|
||||
if SERVER.do_restart:
|
||||
SERVER.handle_restart(flag)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ Setting up and interacting with the database.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from os import remove
|
||||
from os.path import dirname, join
|
||||
from os import makedirs, remove
|
||||
from os.path import dirname, isfile, join
|
||||
from shutil import move
|
||||
from sqlite3 import Connection, OperationalError, ProgrammingError, Row
|
||||
from threading import current_thread, main_thread
|
||||
@@ -15,12 +15,12 @@ from time import time
|
||||
from typing import Type, Union
|
||||
|
||||
from flask import g
|
||||
from waitress.task import ThreadedTaskDispatcher as OldThreadedTaskDispatcher
|
||||
|
||||
from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile,
|
||||
UserNotFound)
|
||||
from backend.helpers import RestartVars
|
||||
from backend.helpers import RestartVars, folder_path
|
||||
|
||||
DB_FILENAME = 'db', 'MIND.db'
|
||||
__DATABASE_VERSION__ = 9
|
||||
__DATEBASE_NAME_ORIGINAL__ = "MIND_original.db"
|
||||
|
||||
@@ -34,20 +34,6 @@ class DB_Singleton(type):
|
||||
|
||||
return cls._instances[i]
|
||||
|
||||
class ThreadedTaskDispatcher(OldThreadedTaskDispatcher):
|
||||
def handler_thread(self, thread_no: int) -> None:
|
||||
super().handler_thread(thread_no)
|
||||
i = f'{DBConnection}{current_thread()}'
|
||||
if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed:
|
||||
DB_Singleton._instances[i].close()
|
||||
return
|
||||
|
||||
def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
|
||||
print()
|
||||
logging.info('Shutting down MIND')
|
||||
result = super().shutdown(cancel_pending, timeout)
|
||||
return result
|
||||
|
||||
class DBConnection(Connection, metaclass=DB_Singleton):
|
||||
file = ''
|
||||
|
||||
@@ -67,6 +53,19 @@ class DBConnection(Connection, metaclass=DB_Singleton):
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__}; {current_thread().name}; {id(self)}>'
|
||||
|
||||
def setup_db_location() -> None:
|
||||
"""Create folder for database and link file to DBConnection class
|
||||
"""
|
||||
if isfile(folder_path('db', 'Noted.db')):
|
||||
move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME))
|
||||
|
||||
db_location = folder_path(*DB_FILENAME)
|
||||
logging.debug(f'Database location: {db_location}')
|
||||
makedirs(dirname(db_location), exist_ok=True)
|
||||
|
||||
DBConnection.file = db_location
|
||||
return
|
||||
|
||||
def get_db(output_type: Union[Type[dict], Type[tuple]]=tuple):
|
||||
"""Get a database cursor instance. Coupled to Flask's g.
|
||||
|
||||
@@ -473,8 +472,7 @@ def import_db(new_db_file: str) -> None:
|
||||
DBConnection.file
|
||||
)
|
||||
|
||||
from frontend.api import APIVariables, restart_server_thread
|
||||
APIVariables.restart_args = [RestartVars.DB_IMPORT.value]
|
||||
restart_server_thread.start()
|
||||
from backend.server import SERVER
|
||||
SERVER.restart([RestartVars.DB_IMPORT.value])
|
||||
|
||||
return
|
||||
|
||||
@@ -8,7 +8,7 @@ import logging
|
||||
from enum import Enum
|
||||
from os.path import abspath, dirname, join
|
||||
from sys import version_info
|
||||
from typing import Any, Callable, TypeVar, Union
|
||||
from typing import Callable, TypeVar, Union
|
||||
|
||||
T = TypeVar('T')
|
||||
U = TypeVar('U')
|
||||
|
||||
263
backend/server.py
Normal file
263
backend/server.py
Normal file
@@ -0,0 +1,263 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from os import execv, urandom
|
||||
from sys import argv
|
||||
from threading import Timer, current_thread
|
||||
from typing import TYPE_CHECKING, List, NoReturn, Union
|
||||
|
||||
from flask import Flask, render_template, request
|
||||
from waitress import create_server
|
||||
from waitress.task import ThreadedTaskDispatcher as TTD
|
||||
from werkzeug.middleware.dispatcher import DispatcherMiddleware
|
||||
|
||||
from backend.db import DB_Singleton, DBConnection, close_db, revert_db_import
|
||||
from backend.helpers import RestartVars, Singleton, folder_path
|
||||
from backend.settings import restore_hosting_settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from waitress.server import TcpWSGIServer
|
||||
|
||||
THREADS = 10
|
||||
|
||||
class ThreadedTaskDispatcher(TTD):
|
||||
def handler_thread(self, thread_no: int) -> None:
|
||||
super().handler_thread(thread_no)
|
||||
i = f'{DBConnection}{current_thread()}'
|
||||
if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed:
|
||||
DB_Singleton._instances[i].close()
|
||||
return
|
||||
|
||||
def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
|
||||
print()
|
||||
logging.info('Shutting down MIND')
|
||||
result = super().shutdown(cancel_pending, timeout)
|
||||
return result
|
||||
|
||||
|
||||
class Server(metaclass=Singleton):
|
||||
api_prefix = "/api"
|
||||
admin_api_extension = "/admin"
|
||||
admin_prefix = "/api/admin"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.do_restart = False
|
||||
"Restart instead of shutdown"
|
||||
|
||||
self.restart_args: List[str] = []
|
||||
"Flag to run with when restarting"
|
||||
|
||||
self.handle_flags: bool = False
|
||||
"Run any flag specific actions before restarting"
|
||||
|
||||
self.url_prefix = ""
|
||||
|
||||
self.revert_db_timer = Timer(60.0, self.__revert_db)
|
||||
self.revert_db_timer.name = "DatabaseImportHandler"
|
||||
self.revert_hosting_timer = Timer(60.0, self.__revert_hosting)
|
||||
self.revert_hosting_timer.name = "HostingHandler"
|
||||
|
||||
return
|
||||
|
||||
def create_app(self) -> None:
|
||||
"""Create a Flask app instance"""
|
||||
from frontend.api import admin_api, api
|
||||
from frontend.ui import ui
|
||||
|
||||
app = Flask(
|
||||
__name__,
|
||||
template_folder=folder_path('frontend','templates'),
|
||||
static_folder=folder_path('frontend','static'),
|
||||
static_url_path='/static'
|
||||
)
|
||||
app.config['SECRET_KEY'] = urandom(32)
|
||||
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
|
||||
app.config['JSON_SORT_KEYS'] = False
|
||||
|
||||
# Add error handlers
|
||||
@app.errorhandler(400)
|
||||
def bad_request(e):
|
||||
return {'error': 'Bad request', 'result': {}}, 400
|
||||
|
||||
@app.errorhandler(405)
|
||||
def method_not_allowed(e):
|
||||
return {'error': 'Method not allowed', 'result': {}}, 405
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(e):
|
||||
return {'error': 'Internal error', 'result': {}}, 500
|
||||
|
||||
@app.errorhandler(404)
|
||||
def not_found(e):
|
||||
if request.path.startswith(self.api_prefix):
|
||||
return {'error': 'Not Found', 'result': {}}, 404
|
||||
return render_template('page_not_found.html', url_prefix=self.url_prefix)
|
||||
|
||||
app.register_blueprint(ui)
|
||||
app.register_blueprint(api, url_prefix=self.api_prefix)
|
||||
app.register_blueprint(admin_api, url_prefix=self.admin_prefix)
|
||||
|
||||
# Setup closing database
|
||||
app.teardown_appcontext(close_db)
|
||||
|
||||
self.app = app
|
||||
return
|
||||
|
||||
def set_url_prefix(self, url_prefix: str) -> None:
|
||||
"""Change the URL prefix of the server.
|
||||
|
||||
Args:
|
||||
url_prefix (str): The desired URL prefix to set it to.
|
||||
"""
|
||||
self.app.config["APPLICATION_ROOT"] = url_prefix
|
||||
self.app.wsgi_app = DispatcherMiddleware(
|
||||
Flask(__name__),
|
||||
{url_prefix: self.app.wsgi_app}
|
||||
)
|
||||
self.url_prefix = url_prefix
|
||||
return
|
||||
|
||||
def __create_waitress_server(
|
||||
self,
|
||||
host: str,
|
||||
port: int
|
||||
) -> TcpWSGIServer:
|
||||
"""From the `Flask` instance created in `self.create_app()`, create
|
||||
a waitress server instance.
|
||||
|
||||
Args:
|
||||
host (str): The host to bind to.
|
||||
port (int): The port to listen on.
|
||||
|
||||
Returns:
|
||||
TcpWSGIServer: The waitress server.
|
||||
"""
|
||||
dispatcher = ThreadedTaskDispatcher()
|
||||
dispatcher.set_thread_count(THREADS)
|
||||
server = create_server(
|
||||
self.app,
|
||||
_dispatcher=dispatcher,
|
||||
host=host,
|
||||
port=port,
|
||||
threads=THREADS
|
||||
)
|
||||
return server
|
||||
|
||||
def run(self, host: str, port: int) -> None:
|
||||
"""Start the webserver.
|
||||
|
||||
Args:
|
||||
host (str): The host to bind to.
|
||||
port (int): The port to listen on.
|
||||
"""
|
||||
self.server = self.__create_waitress_server(host, port)
|
||||
logging.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
|
||||
self.server.run()
|
||||
|
||||
return
|
||||
|
||||
def __shutdown_thread_function(self) -> None:
|
||||
"""Shutdown waitress server. Intended to be run in a thread.
|
||||
"""
|
||||
self.server.close()
|
||||
self.server.task_dispatcher.shutdown()
|
||||
self.server._map.clear()
|
||||
return
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Stop the waitress server. Starts a thread that
|
||||
shuts down the server.
|
||||
"""
|
||||
t = Timer(1.0, self.__shutdown_thread_function)
|
||||
t.name = "InternalStateHandler"
|
||||
t.start()
|
||||
return
|
||||
|
||||
def restart(
|
||||
self,
|
||||
restart_args: List[str] = [],
|
||||
handle_flags: bool = False
|
||||
) -> None:
|
||||
"""Same as `self.shutdown()`, but restart instead of shutting down.
|
||||
|
||||
Args:
|
||||
restart_args (List[str], optional): Any arguments to run the new instance with.
|
||||
Defaults to [].
|
||||
|
||||
handle_flags (bool, optional): Run flag specific actions just before restarting.
|
||||
Defaults to False.
|
||||
"""
|
||||
self.do_restart = True
|
||||
self.restart_args = restart_args
|
||||
self.handle_flags = handle_flags
|
||||
self.shutdown()
|
||||
return
|
||||
|
||||
def handle_restart(self, flag: Union[str, None]) -> NoReturn:
|
||||
"""Restart the interpreter.
|
||||
|
||||
Args:
|
||||
flag (Union[str, None]): Supplied flag, for flag handling.
|
||||
|
||||
Returns:
|
||||
NoReturn: No return because it replaces the interpreter.
|
||||
"""
|
||||
if self.handle_flags:
|
||||
handle_flags_pre_restart(flag)
|
||||
|
||||
logging.info('Restarting MIND')
|
||||
from MIND import __file__ as mind_file
|
||||
execv(folder_path(mind_file), [argv[0], *self.restart_args])
|
||||
|
||||
def __revert_db(self) -> None:
|
||||
"""Revert database import and restart.
|
||||
"""
|
||||
logging.warning(f'Timer for database import expired; reverting back to original file')
|
||||
self.restart(handle_flags=True)
|
||||
return
|
||||
|
||||
def __revert_hosting(self) -> None:
|
||||
"""Revert the hosting changes.
|
||||
"""
|
||||
logging.warning(f'Timer for hosting changes expired; reverting back to original settings')
|
||||
self.restart(handle_flags=True)
|
||||
return
|
||||
|
||||
|
||||
SERVER = Server()
|
||||
|
||||
|
||||
def handle_flags(flag: Union[None, str]) -> None:
|
||||
"""Run flag specific actions on startup.
|
||||
|
||||
Args:
|
||||
flag (Union[None, str]): The flag or `None` if there is no flag set.
|
||||
"""
|
||||
if flag == RestartVars.DB_IMPORT:
|
||||
logging.info('Starting timer for database import')
|
||||
SERVER.revert_db_timer.start()
|
||||
|
||||
elif flag == RestartVars.HOST_CHANGE:
|
||||
logging.info('Starting timer for hosting changes')
|
||||
SERVER.revert_hosting_timer.start()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def handle_flags_pre_restart(flag: Union[None, str]) -> None:
|
||||
"""Run flag specific actions just before restarting.
|
||||
|
||||
Args:
|
||||
flag (Union[None, str]): The flag or `None` if there is no flag set.
|
||||
"""
|
||||
if flag == RestartVars.DB_IMPORT:
|
||||
revert_db_import(swap=True)
|
||||
|
||||
elif flag == RestartVars.HOST_CHANGE:
|
||||
with SERVER.app.app_context():
|
||||
restore_hosting_settings()
|
||||
close_db()
|
||||
|
||||
return
|
||||
@@ -8,9 +8,8 @@ from datetime import datetime
|
||||
from io import BytesIO
|
||||
from os import remove, urandom
|
||||
from os.path import basename
|
||||
from threading import Timer
|
||||
from time import time as epoch_time
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple
|
||||
|
||||
from flask import g, request, send_file
|
||||
|
||||
@@ -26,6 +25,7 @@ from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired,
|
||||
from backend.db import get_db, import_db, revert_db_import
|
||||
from backend.helpers import RestartVars, folder_path
|
||||
from backend.notification_service import get_apprise_services
|
||||
from backend.server import SERVER
|
||||
from backend.settings import (backup_hosting_settings, get_admin_settings,
|
||||
get_setting, set_setting)
|
||||
from backend.users import Users
|
||||
@@ -47,13 +47,10 @@ from frontend.input_validation import (AllowNewAccountsVariable, ColorVariable,
|
||||
UrlPrefixVariable, URLVariable,
|
||||
UsernameCreateVariable,
|
||||
UsernameVariable, WeekDaysVariable,
|
||||
admin_api, admin_api_prefix, api,
|
||||
api_prefix, get_api_docs,
|
||||
admin_api, api, get_api_docs,
|
||||
input_validation)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from waitress.server import BaseWSGIServer
|
||||
|
||||
from backend.users import User
|
||||
|
||||
|
||||
@@ -61,58 +58,6 @@ if TYPE_CHECKING:
|
||||
# General variables and functions
|
||||
#===================
|
||||
|
||||
class APIVariables:
|
||||
server_instance: Union[BaseWSGIServer, None] = None
|
||||
|
||||
restart: bool = False
|
||||
"Restart instead of shutdown"
|
||||
|
||||
restart_args: List[str] = []
|
||||
"Flag to run with when restarting"
|
||||
|
||||
handle_flags: bool = False
|
||||
"Run any flag specific actions before restarting"
|
||||
|
||||
def shutdown_server() -> None:
|
||||
"""Stop server from running"""
|
||||
APIVariables.server_instance.close()
|
||||
APIVariables.server_instance.task_dispatcher.shutdown()
|
||||
APIVariables.server_instance._map.clear()
|
||||
return
|
||||
|
||||
def restart_server() -> None:
|
||||
"""Restart server.
|
||||
Will completely replace the current process.
|
||||
"""
|
||||
APIVariables.restart = True
|
||||
shutdown_server()
|
||||
return
|
||||
|
||||
def revert_db() -> None:
|
||||
"""Revert database import and restart.
|
||||
"""
|
||||
logging.warning(f'Timer for database import expired; reverting back to original file')
|
||||
APIVariables.handle_flags = True
|
||||
restart_server()
|
||||
return
|
||||
|
||||
def revert_hosting() -> None:
|
||||
"""Revert the hosting changes.
|
||||
"""
|
||||
logging.warning(f'Timer for hosting changes expired; reverting back to original settings')
|
||||
APIVariables.handle_flags = True
|
||||
restart_server()
|
||||
return
|
||||
|
||||
shutdown_server_thread = Timer(1.0, shutdown_server)
|
||||
shutdown_server_thread.name = "InternalStateHandler"
|
||||
restart_server_thread = Timer(1.0, restart_server)
|
||||
restart_server_thread.name = "InternalStateHandler"
|
||||
revert_db_thread = Timer(60.0, revert_db)
|
||||
revert_db_thread.name = "DatabaseImportHandler"
|
||||
revert_hosting_thread = Timer(60.0, revert_hosting)
|
||||
revert_hosting_thread.name = "HostingHandler"
|
||||
|
||||
@dataclass
|
||||
class ApiKeyEntry:
|
||||
exp: int
|
||||
@@ -140,14 +85,14 @@ def auth() -> None:
|
||||
if (
|
||||
map_entry.user_data.admin
|
||||
and
|
||||
not request.path.startswith((admin_api_prefix, api_prefix + '/auth'))
|
||||
not request.path.startswith((SERVER.admin_prefix, SERVER.api_prefix + '/auth'))
|
||||
):
|
||||
raise APIKeyInvalid
|
||||
|
||||
if (
|
||||
not map_entry.user_data.admin
|
||||
and
|
||||
request.path.startswith(admin_api_prefix)
|
||||
request.path.startswith(SERVER.admin_prefix)
|
||||
):
|
||||
raise APIKeyInvalid
|
||||
|
||||
@@ -216,17 +161,17 @@ def endpoint_wrapper(method: Callable) -> Callable:
|
||||
@endpoint_wrapper
|
||||
def api_login(inputs: Dict[str, str]):
|
||||
user = users.login(inputs['username'], inputs['password'])
|
||||
|
||||
|
||||
# Login successful
|
||||
|
||||
if user.admin and revert_db_thread.is_alive():
|
||||
|
||||
if user.admin and SERVER.revert_db_timer.is_alive():
|
||||
logging.info('Timer for database import diffused')
|
||||
revert_db_thread.cancel()
|
||||
SERVER.revert_db_timer.cancel()
|
||||
revert_db_import(swap=False)
|
||||
|
||||
elif user.admin and revert_hosting_thread.is_alive():
|
||||
elif user.admin and SERVER.revert_hosting_timer.is_alive():
|
||||
logging.info('Timer for hosting changes diffused')
|
||||
revert_hosting_thread.cancel()
|
||||
SERVER.revert_hosting_timer.cancel()
|
||||
|
||||
# Generate an API key until one
|
||||
# is generated that isn't used already
|
||||
@@ -722,7 +667,7 @@ def api_get_static_reminder(inputs: Dict[str, Any], s_id: int):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_shutdown():
|
||||
shutdown_server_thread.start()
|
||||
SERVER.shutdown()
|
||||
return return_api({})
|
||||
|
||||
@admin_api.route(
|
||||
@@ -732,7 +677,7 @@ def api_shutdown():
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_restart():
|
||||
restart_server_thread.start()
|
||||
SERVER.restart()
|
||||
return return_api({})
|
||||
|
||||
@api.route(
|
||||
@@ -782,8 +727,7 @@ def api_admin_settings(inputs: Dict[str, Any]):
|
||||
set_setting(k, v)
|
||||
|
||||
if hosting_changes:
|
||||
APIVariables.restart_args = [RestartVars.HOST_CHANGE.value]
|
||||
restart_server_thread.start()
|
||||
SERVER.restart([RestartVars.HOST_CHANGE.value])
|
||||
|
||||
return return_api({})
|
||||
|
||||
|
||||
@@ -16,23 +16,20 @@ from apprise import Apprise
|
||||
from flask import Blueprint, request
|
||||
from flask.sansio.scaffold import T_route
|
||||
|
||||
from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile, InvalidKeyValue,
|
||||
InvalidTime, KeyNotFound,
|
||||
NewAccountsNotAllowed,
|
||||
from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile,
|
||||
InvalidKeyValue, InvalidTime,
|
||||
KeyNotFound, NewAccountsNotAllowed,
|
||||
NotificationServiceNotFound,
|
||||
UsernameInvalid, UsernameTaken,
|
||||
UserNotFound)
|
||||
from backend.helpers import (RepeatQuantity, SortingMethod,
|
||||
TimelessSortingMethod, folder_path)
|
||||
from backend.server import SERVER
|
||||
from backend.settings import _format_setting
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flask import Request
|
||||
|
||||
api_prefix = "/api"
|
||||
_admin_api_prefix = '/admin'
|
||||
admin_api_prefix = api_prefix + _admin_api_prefix
|
||||
|
||||
color_regex = compile(r'#[0-9a-f]{6}')
|
||||
|
||||
api_docs: Dict[str, ApiDocEntry] = {}
|
||||
@@ -130,10 +127,10 @@ class ApiDocEntry:
|
||||
|
||||
|
||||
def get_api_docs(request: Request) -> ApiDocEntry:
|
||||
if request.path.startswith(admin_api_prefix):
|
||||
url = _admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
|
||||
if request.path.startswith(SERVER.admin_prefix):
|
||||
url = SERVER.admin_api_extension + request.url_rule.rule.split(SERVER.admin_prefix)[1]
|
||||
else:
|
||||
url = request.url_rule.rule.split(api_prefix)[1]
|
||||
url = request.url_rule.rule.split(SERVER.api_prefix)[1]
|
||||
return api_docs[url]
|
||||
|
||||
|
||||
@@ -489,7 +486,7 @@ class APIBlueprint(Blueprint):
|
||||
if self == api:
|
||||
processed_rule = rule
|
||||
elif self == admin_api:
|
||||
processed_rule = _admin_api_prefix + rule
|
||||
processed_rule = SERVER.admin_api_extension + rule
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -2,21 +2,20 @@
|
||||
|
||||
from flask import Blueprint, render_template
|
||||
|
||||
from backend.server import SERVER
|
||||
|
||||
ui = Blueprint('ui', __name__)
|
||||
|
||||
methods = ['GET']
|
||||
|
||||
class UIVariables:
|
||||
url_prefix: str = ''
|
||||
|
||||
@ui.route('/', methods=methods)
|
||||
def ui_login():
|
||||
return render_template('login.html', url_prefix=UIVariables.url_prefix)
|
||||
return render_template('login.html', url_prefix=SERVER.url_prefix)
|
||||
|
||||
@ui.route('/reminders', methods=methods)
|
||||
def ui_reminders():
|
||||
return render_template('reminders.html', url_prefix=UIVariables.url_prefix)
|
||||
return render_template('reminders.html', url_prefix=SERVER.url_prefix)
|
||||
|
||||
@ui.route('/admin', methods=methods)
|
||||
def ui_admin():
|
||||
return render_template('admin.html', url_prefix=UIVariables.url_prefix)
|
||||
return render_template('admin.html', url_prefix=SERVER.url_prefix)
|
||||
|
||||
@@ -4,17 +4,19 @@ from flask import Flask
|
||||
|
||||
from frontend.api import api
|
||||
from frontend.ui import ui
|
||||
from MIND import _create_app
|
||||
from backend.server import SERVER
|
||||
|
||||
class Test_MIND(unittest.TestCase):
|
||||
def test_create_app(self):
|
||||
result = _create_app()
|
||||
self.assertIsInstance(result, Flask)
|
||||
SERVER.create_app()
|
||||
self.assertTrue(hasattr(SERVER, 'app'))
|
||||
app = SERVER.app
|
||||
self.assertIsInstance(app, Flask)
|
||||
|
||||
self.assertEqual(result.blueprints.get('ui'), ui)
|
||||
self.assertEqual(result.blueprints.get('api'), api)
|
||||
self.assertEqual(app.blueprints.get('ui'), ui)
|
||||
self.assertEqual(app.blueprints.get('api'), api)
|
||||
|
||||
handlers = result.error_handler_spec[None].keys()
|
||||
handlers = app.error_handler_spec[None].keys()
|
||||
required_handlers = 400, 405, 500
|
||||
for handler in required_handlers:
|
||||
self.assertIn(handler, handlers)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import unittest
|
||||
|
||||
from backend.db import DBConnection
|
||||
from backend.db import DB_FILENAME, DBConnection
|
||||
from backend.helpers import folder_path
|
||||
from MIND import DB_FILENAME
|
||||
|
||||
|
||||
class Test_DB(unittest.TestCase):
|
||||
def test_foreign_key(self):
|
||||
def test_foreign_key_and_wal(self):
|
||||
DBConnection.file = folder_path(*DB_FILENAME)
|
||||
instance = DBConnection(timeout=20.0)
|
||||
self.assertEqual(instance.cursor().execute("PRAGMA foreign_keys;").fetchone()[0], 1)
|
||||
self.assertEqual(instance.cursor().execute("PRAGMA journal_mode;").fetchone()[0], 'wal')
|
||||
|
||||
Reference in New Issue
Block a user