diff --git a/MIND.py b/MIND.py index 1c4d918..ea923b9 100644 --- a/MIND.py +++ b/MIND.py @@ -6,24 +6,13 @@ The main file where MIND is started from """ import logging -from os import execv, makedirs, urandom -from os.path import dirname, isfile -from shutil import move from sys import argv -from typing import Union -from flask import Flask, render_template, request -from waitress.server import create_server -from werkzeug.middleware.dispatcher import DispatcherMiddleware - -from backend.db import (DBConnection, ThreadedTaskDispatcher, close_db, - revert_db_import, setup_db) -from backend.helpers import RestartVars, check_python_version, folder_path +from backend.db import setup_db, setup_db_location +from backend.helpers import check_python_version from backend.reminders import ReminderHandler -from backend.settings import get_setting, restore_hosting_settings -from frontend.api import (APIVariables, admin_api, admin_api_prefix, api, - api_prefix, revert_db_thread, revert_hosting_thread) -from frontend.ui import UIVariables, ui +from backend.server import SERVER, handle_flags +from backend.settings import get_setting #============================= # WARNING: @@ -36,8 +25,6 @@ URL_PREFIX = '' # Must either be empty or start with '/' e.g. '/mind' #============================= LOGGING_LEVEL = logging.INFO -THREADS = 10 -DB_FILENAME = 'db', 'MIND.db' logging.basicConfig( level=LOGGING_LEVEL, @@ -45,97 +32,6 @@ logging.basicConfig( datefmt='%Y-%m-%d %H:%M:%S' ) -def _create_app() -> Flask: - """Create a Flask app instance - - Returns: - Flask: The created app instance - """ - app = Flask( - __name__, - template_folder=folder_path('frontend','templates'), - static_folder=folder_path('frontend','static'), - static_url_path='/static' - ) - app.config['SECRET_KEY'] = urandom(32) - app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True - app.config['JSON_SORT_KEYS'] = False - - # Add error handlers - @app.errorhandler(400) - def bad_request(e): - return {'error': 'Bad request', 'result': {}}, 400 - - @app.errorhandler(405) - def method_not_allowed(e): - return {'error': 'Method not allowed', 'result': {}}, 405 - - @app.errorhandler(500) - def internal_error(e): - return {'error': 'Internal error', 'result': {}}, 500 - - @app.errorhandler(404) - def not_found(e): - if request.path.startswith(api_prefix): - return {'error': 'Not Found', 'result': {}}, 404 - return render_template('page_not_found.html', url_prefix=UIVariables.url_prefix) - - app.register_blueprint(ui) - app.register_blueprint(api, url_prefix=api_prefix) - app.register_blueprint(admin_api, url_prefix=admin_api_prefix) - - # Setup closing database - app.teardown_appcontext(close_db) - - return app - -def _set_url_prefix(app: Flask, url_prefix: str) -> None: - """Change the URL prefix of the server. - - Args: - app (Flask): The `Flask` instance to change the URL prefix of. - url_prefix (str): The desired URL prefix to set it to. - """ - app.config["APPLICATION_ROOT"] = url_prefix - app.wsgi_app = DispatcherMiddleware( - Flask(__name__), - {url_prefix: app.wsgi_app} - ) - UIVariables.url_prefix = url_prefix - return - -def _handle_flags(flag: Union[None, str]) -> None: - """Run flag specific actions on startup. - - Args: - flag (Union[None, str]): The flag or `None` if there is no flag set. - """ - if flag == RestartVars.DB_IMPORT: - logging.info('Starting timer for database import') - revert_db_thread.start() - - elif flag == RestartVars.HOST_CHANGE: - logging.info('Starting timer for hosting changes') - revert_hosting_thread.start() - - return - -def _handle_flags_pre_restart(flag: Union[None, str]) -> None: - """Run flag specific actions just before restarting. - - Args: - flag (Union[None, str]): The flag or `None` if there is no flag set. - """ - if flag == RestartVars.DB_IMPORT: - revert_db_import(swap=True) - - elif flag == RestartVars.HOST_CHANGE: - with Flask(__name__).app_context(): - restore_hosting_settings() - close_db() - - return - def MIND() -> None: """The main function of MIND """ @@ -145,53 +41,30 @@ def MIND() -> None: exit(1) flag = argv[1] if len(argv) > 1 else None - _handle_flags(flag) + handle_flags(flag) - if isfile(folder_path('db', 'Noted.db')): - move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME)) + setup_db_location() - db_location = folder_path(*DB_FILENAME) - logging.debug(f'Database location: {db_location}') - makedirs(dirname(db_location), exist_ok=True) - - DBConnection.file = db_location - - app = _create_app() - reminder_handler = ReminderHandler(app.app_context) - with app.app_context(): + SERVER.create_app() + reminder_handler = ReminderHandler(SERVER.app.app_context) + with SERVER.app.app_context(): setup_db() host = get_setting("host") port = get_setting("port") url_prefix = get_setting("url_prefix") - _set_url_prefix(app, url_prefix) + SERVER.set_url_prefix(url_prefix) reminder_handler.find_next_reminder() - # Create waitress server and run - dispatcher = ThreadedTaskDispatcher() - dispatcher.set_thread_count(THREADS) - server = create_server( - app, - _dispatcher=dispatcher, - host=host, - port=port, - threads=THREADS - ) - APIVariables.server_instance = server - logging.info(f'MIND running on http://{host}:{port}{url_prefix}') # ================= - server.run() + SERVER.run(host, port) # ================= reminder_handler.stop_handling() - if APIVariables.restart: - if APIVariables.handle_flags: - _handle_flags_pre_restart(flag) - - logging.info('Restarting MIND') - execv(__file__, [argv[0], *APIVariables.restart_args]) + if SERVER.do_restart: + SERVER.handle_restart(flag) return diff --git a/backend/db.py b/backend/db.py index 3d1bb71..520c598 100644 --- a/backend/db.py +++ b/backend/db.py @@ -6,8 +6,8 @@ Setting up and interacting with the database. import logging from datetime import datetime -from os import remove -from os.path import dirname, join +from os import makedirs, remove +from os.path import dirname, isfile, join from shutil import move from sqlite3 import Connection, OperationalError, ProgrammingError, Row from threading import current_thread, main_thread @@ -15,12 +15,12 @@ from time import time from typing import Type, Union from flask import g -from waitress.task import ThreadedTaskDispatcher as OldThreadedTaskDispatcher from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile, UserNotFound) -from backend.helpers import RestartVars +from backend.helpers import RestartVars, folder_path +DB_FILENAME = 'db', 'MIND.db' __DATABASE_VERSION__ = 9 __DATEBASE_NAME_ORIGINAL__ = "MIND_original.db" @@ -34,20 +34,6 @@ class DB_Singleton(type): return cls._instances[i] -class ThreadedTaskDispatcher(OldThreadedTaskDispatcher): - def handler_thread(self, thread_no: int) -> None: - super().handler_thread(thread_no) - i = f'{DBConnection}{current_thread()}' - if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed: - DB_Singleton._instances[i].close() - return - - def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool: - print() - logging.info('Shutting down MIND') - result = super().shutdown(cancel_pending, timeout) - return result - class DBConnection(Connection, metaclass=DB_Singleton): file = '' @@ -67,6 +53,19 @@ class DBConnection(Connection, metaclass=DB_Singleton): def __repr__(self) -> str: return f'<{self.__class__.__name__}; {current_thread().name}; {id(self)}>' +def setup_db_location() -> None: + """Create folder for database and link file to DBConnection class + """ + if isfile(folder_path('db', 'Noted.db')): + move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME)) + + db_location = folder_path(*DB_FILENAME) + logging.debug(f'Database location: {db_location}') + makedirs(dirname(db_location), exist_ok=True) + + DBConnection.file = db_location + return + def get_db(output_type: Union[Type[dict], Type[tuple]]=tuple): """Get a database cursor instance. Coupled to Flask's g. @@ -473,8 +472,7 @@ def import_db(new_db_file: str) -> None: DBConnection.file ) - from frontend.api import APIVariables, restart_server_thread - APIVariables.restart_args = [RestartVars.DB_IMPORT.value] - restart_server_thread.start() + from backend.server import SERVER + SERVER.restart([RestartVars.DB_IMPORT.value]) return diff --git a/backend/helpers.py b/backend/helpers.py index b2568f9..61f6d9f 100644 --- a/backend/helpers.py +++ b/backend/helpers.py @@ -8,7 +8,7 @@ import logging from enum import Enum from os.path import abspath, dirname, join from sys import version_info -from typing import Any, Callable, TypeVar, Union +from typing import Callable, TypeVar, Union T = TypeVar('T') U = TypeVar('U') diff --git a/backend/server.py b/backend/server.py new file mode 100644 index 0000000..81aba78 --- /dev/null +++ b/backend/server.py @@ -0,0 +1,263 @@ +#-*- coding: utf-8 -*- + +from __future__ import annotations + +import logging +from os import execv, urandom +from sys import argv +from threading import Timer, current_thread +from typing import TYPE_CHECKING, List, NoReturn, Union + +from flask import Flask, render_template, request +from waitress import create_server +from waitress.task import ThreadedTaskDispatcher as TTD +from werkzeug.middleware.dispatcher import DispatcherMiddleware + +from backend.db import DB_Singleton, DBConnection, close_db, revert_db_import +from backend.helpers import RestartVars, Singleton, folder_path +from backend.settings import restore_hosting_settings + +if TYPE_CHECKING: + from waitress.server import TcpWSGIServer + +THREADS = 10 + +class ThreadedTaskDispatcher(TTD): + def handler_thread(self, thread_no: int) -> None: + super().handler_thread(thread_no) + i = f'{DBConnection}{current_thread()}' + if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed: + DB_Singleton._instances[i].close() + return + + def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool: + print() + logging.info('Shutting down MIND') + result = super().shutdown(cancel_pending, timeout) + return result + + +class Server(metaclass=Singleton): + api_prefix = "/api" + admin_api_extension = "/admin" + admin_prefix = "/api/admin" + + def __init__(self) -> None: + self.do_restart = False + "Restart instead of shutdown" + + self.restart_args: List[str] = [] + "Flag to run with when restarting" + + self.handle_flags: bool = False + "Run any flag specific actions before restarting" + + self.url_prefix = "" + + self.revert_db_timer = Timer(60.0, self.__revert_db) + self.revert_db_timer.name = "DatabaseImportHandler" + self.revert_hosting_timer = Timer(60.0, self.__revert_hosting) + self.revert_hosting_timer.name = "HostingHandler" + + return + + def create_app(self) -> None: + """Create a Flask app instance""" + from frontend.api import admin_api, api + from frontend.ui import ui + + app = Flask( + __name__, + template_folder=folder_path('frontend','templates'), + static_folder=folder_path('frontend','static'), + static_url_path='/static' + ) + app.config['SECRET_KEY'] = urandom(32) + app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True + app.config['JSON_SORT_KEYS'] = False + + # Add error handlers + @app.errorhandler(400) + def bad_request(e): + return {'error': 'Bad request', 'result': {}}, 400 + + @app.errorhandler(405) + def method_not_allowed(e): + return {'error': 'Method not allowed', 'result': {}}, 405 + + @app.errorhandler(500) + def internal_error(e): + return {'error': 'Internal error', 'result': {}}, 500 + + @app.errorhandler(404) + def not_found(e): + if request.path.startswith(self.api_prefix): + return {'error': 'Not Found', 'result': {}}, 404 + return render_template('page_not_found.html', url_prefix=self.url_prefix) + + app.register_blueprint(ui) + app.register_blueprint(api, url_prefix=self.api_prefix) + app.register_blueprint(admin_api, url_prefix=self.admin_prefix) + + # Setup closing database + app.teardown_appcontext(close_db) + + self.app = app + return + + def set_url_prefix(self, url_prefix: str) -> None: + """Change the URL prefix of the server. + + Args: + url_prefix (str): The desired URL prefix to set it to. + """ + self.app.config["APPLICATION_ROOT"] = url_prefix + self.app.wsgi_app = DispatcherMiddleware( + Flask(__name__), + {url_prefix: self.app.wsgi_app} + ) + self.url_prefix = url_prefix + return + + def __create_waitress_server( + self, + host: str, + port: int + ) -> TcpWSGIServer: + """From the `Flask` instance created in `self.create_app()`, create + a waitress server instance. + + Args: + host (str): The host to bind to. + port (int): The port to listen on. + + Returns: + TcpWSGIServer: The waitress server. + """ + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(THREADS) + server = create_server( + self.app, + _dispatcher=dispatcher, + host=host, + port=port, + threads=THREADS + ) + return server + + def run(self, host: str, port: int) -> None: + """Start the webserver. + + Args: + host (str): The host to bind to. + port (int): The port to listen on. + """ + self.server = self.__create_waitress_server(host, port) + logging.info(f'MIND running on http://{host}:{port}{self.url_prefix}') + self.server.run() + + return + + def __shutdown_thread_function(self) -> None: + """Shutdown waitress server. Intended to be run in a thread. + """ + self.server.close() + self.server.task_dispatcher.shutdown() + self.server._map.clear() + return + + def shutdown(self) -> None: + """Stop the waitress server. Starts a thread that + shuts down the server. + """ + t = Timer(1.0, self.__shutdown_thread_function) + t.name = "InternalStateHandler" + t.start() + return + + def restart( + self, + restart_args: List[str] = [], + handle_flags: bool = False + ) -> None: + """Same as `self.shutdown()`, but restart instead of shutting down. + + Args: + restart_args (List[str], optional): Any arguments to run the new instance with. + Defaults to []. + + handle_flags (bool, optional): Run flag specific actions just before restarting. + Defaults to False. + """ + self.do_restart = True + self.restart_args = restart_args + self.handle_flags = handle_flags + self.shutdown() + return + + def handle_restart(self, flag: Union[str, None]) -> NoReturn: + """Restart the interpreter. + + Args: + flag (Union[str, None]): Supplied flag, for flag handling. + + Returns: + NoReturn: No return because it replaces the interpreter. + """ + if self.handle_flags: + handle_flags_pre_restart(flag) + + logging.info('Restarting MIND') + from MIND import __file__ as mind_file + execv(folder_path(mind_file), [argv[0], *self.restart_args]) + + def __revert_db(self) -> None: + """Revert database import and restart. + """ + logging.warning(f'Timer for database import expired; reverting back to original file') + self.restart(handle_flags=True) + return + + def __revert_hosting(self) -> None: + """Revert the hosting changes. + """ + logging.warning(f'Timer for hosting changes expired; reverting back to original settings') + self.restart(handle_flags=True) + return + + +SERVER = Server() + + +def handle_flags(flag: Union[None, str]) -> None: + """Run flag specific actions on startup. + + Args: + flag (Union[None, str]): The flag or `None` if there is no flag set. + """ + if flag == RestartVars.DB_IMPORT: + logging.info('Starting timer for database import') + SERVER.revert_db_timer.start() + + elif flag == RestartVars.HOST_CHANGE: + logging.info('Starting timer for hosting changes') + SERVER.revert_hosting_timer.start() + + return + + +def handle_flags_pre_restart(flag: Union[None, str]) -> None: + """Run flag specific actions just before restarting. + + Args: + flag (Union[None, str]): The flag or `None` if there is no flag set. + """ + if flag == RestartVars.DB_IMPORT: + revert_db_import(swap=True) + + elif flag == RestartVars.HOST_CHANGE: + with SERVER.app.app_context(): + restore_hosting_settings() + close_db() + + return diff --git a/frontend/api.py b/frontend/api.py index a7775d0..ca5965d 100644 --- a/frontend/api.py +++ b/frontend/api.py @@ -8,9 +8,8 @@ from datetime import datetime from io import BytesIO from os import remove, urandom from os.path import basename -from threading import Timer from time import time as epoch_time -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple from flask import g, request, send_file @@ -26,6 +25,7 @@ from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired, from backend.db import get_db, import_db, revert_db_import from backend.helpers import RestartVars, folder_path from backend.notification_service import get_apprise_services +from backend.server import SERVER from backend.settings import (backup_hosting_settings, get_admin_settings, get_setting, set_setting) from backend.users import Users @@ -47,13 +47,10 @@ from frontend.input_validation import (AllowNewAccountsVariable, ColorVariable, UrlPrefixVariable, URLVariable, UsernameCreateVariable, UsernameVariable, WeekDaysVariable, - admin_api, admin_api_prefix, api, - api_prefix, get_api_docs, + admin_api, api, get_api_docs, input_validation) if TYPE_CHECKING: - from waitress.server import BaseWSGIServer - from backend.users import User @@ -61,58 +58,6 @@ if TYPE_CHECKING: # General variables and functions #=================== -class APIVariables: - server_instance: Union[BaseWSGIServer, None] = None - - restart: bool = False - "Restart instead of shutdown" - - restart_args: List[str] = [] - "Flag to run with when restarting" - - handle_flags: bool = False - "Run any flag specific actions before restarting" - -def shutdown_server() -> None: - """Stop server from running""" - APIVariables.server_instance.close() - APIVariables.server_instance.task_dispatcher.shutdown() - APIVariables.server_instance._map.clear() - return - -def restart_server() -> None: - """Restart server. - Will completely replace the current process. - """ - APIVariables.restart = True - shutdown_server() - return - -def revert_db() -> None: - """Revert database import and restart. - """ - logging.warning(f'Timer for database import expired; reverting back to original file') - APIVariables.handle_flags = True - restart_server() - return - -def revert_hosting() -> None: - """Revert the hosting changes. - """ - logging.warning(f'Timer for hosting changes expired; reverting back to original settings') - APIVariables.handle_flags = True - restart_server() - return - -shutdown_server_thread = Timer(1.0, shutdown_server) -shutdown_server_thread.name = "InternalStateHandler" -restart_server_thread = Timer(1.0, restart_server) -restart_server_thread.name = "InternalStateHandler" -revert_db_thread = Timer(60.0, revert_db) -revert_db_thread.name = "DatabaseImportHandler" -revert_hosting_thread = Timer(60.0, revert_hosting) -revert_hosting_thread.name = "HostingHandler" - @dataclass class ApiKeyEntry: exp: int @@ -140,14 +85,14 @@ def auth() -> None: if ( map_entry.user_data.admin and - not request.path.startswith((admin_api_prefix, api_prefix + '/auth')) + not request.path.startswith((SERVER.admin_prefix, SERVER.api_prefix + '/auth')) ): raise APIKeyInvalid if ( not map_entry.user_data.admin and - request.path.startswith(admin_api_prefix) + request.path.startswith(SERVER.admin_prefix) ): raise APIKeyInvalid @@ -216,17 +161,17 @@ def endpoint_wrapper(method: Callable) -> Callable: @endpoint_wrapper def api_login(inputs: Dict[str, str]): user = users.login(inputs['username'], inputs['password']) - + # Login successful - - if user.admin and revert_db_thread.is_alive(): + + if user.admin and SERVER.revert_db_timer.is_alive(): logging.info('Timer for database import diffused') - revert_db_thread.cancel() + SERVER.revert_db_timer.cancel() revert_db_import(swap=False) - elif user.admin and revert_hosting_thread.is_alive(): + elif user.admin and SERVER.revert_hosting_timer.is_alive(): logging.info('Timer for hosting changes diffused') - revert_hosting_thread.cancel() + SERVER.revert_hosting_timer.cancel() # Generate an API key until one # is generated that isn't used already @@ -722,7 +667,7 @@ def api_get_static_reminder(inputs: Dict[str, Any], s_id: int): ) @endpoint_wrapper def api_shutdown(): - shutdown_server_thread.start() + SERVER.shutdown() return return_api({}) @admin_api.route( @@ -732,7 +677,7 @@ def api_shutdown(): ) @endpoint_wrapper def api_restart(): - restart_server_thread.start() + SERVER.restart() return return_api({}) @api.route( @@ -782,8 +727,7 @@ def api_admin_settings(inputs: Dict[str, Any]): set_setting(k, v) if hosting_changes: - APIVariables.restart_args = [RestartVars.HOST_CHANGE.value] - restart_server_thread.start() + SERVER.restart([RestartVars.HOST_CHANGE.value]) return return_api({}) diff --git a/frontend/input_validation.py b/frontend/input_validation.py index beedb6f..2bca0c5 100644 --- a/frontend/input_validation.py +++ b/frontend/input_validation.py @@ -16,23 +16,20 @@ from apprise import Apprise from flask import Blueprint, request from flask.sansio.scaffold import T_route -from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile, InvalidKeyValue, - InvalidTime, KeyNotFound, - NewAccountsNotAllowed, +from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile, + InvalidKeyValue, InvalidTime, + KeyNotFound, NewAccountsNotAllowed, NotificationServiceNotFound, UsernameInvalid, UsernameTaken, UserNotFound) from backend.helpers import (RepeatQuantity, SortingMethod, TimelessSortingMethod, folder_path) +from backend.server import SERVER from backend.settings import _format_setting if TYPE_CHECKING: from flask import Request -api_prefix = "/api" -_admin_api_prefix = '/admin' -admin_api_prefix = api_prefix + _admin_api_prefix - color_regex = compile(r'#[0-9a-f]{6}') api_docs: Dict[str, ApiDocEntry] = {} @@ -130,10 +127,10 @@ class ApiDocEntry: def get_api_docs(request: Request) -> ApiDocEntry: - if request.path.startswith(admin_api_prefix): - url = _admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1] + if request.path.startswith(SERVER.admin_prefix): + url = SERVER.admin_api_extension + request.url_rule.rule.split(SERVER.admin_prefix)[1] else: - url = request.url_rule.rule.split(api_prefix)[1] + url = request.url_rule.rule.split(SERVER.api_prefix)[1] return api_docs[url] @@ -489,7 +486,7 @@ class APIBlueprint(Blueprint): if self == api: processed_rule = rule elif self == admin_api: - processed_rule = _admin_api_prefix + rule + processed_rule = SERVER.admin_api_extension + rule else: raise NotImplementedError diff --git a/frontend/ui.py b/frontend/ui.py index 4f953f6..35c3dd9 100644 --- a/frontend/ui.py +++ b/frontend/ui.py @@ -2,21 +2,20 @@ from flask import Blueprint, render_template +from backend.server import SERVER + ui = Blueprint('ui', __name__) methods = ['GET'] -class UIVariables: - url_prefix: str = '' - @ui.route('/', methods=methods) def ui_login(): - return render_template('login.html', url_prefix=UIVariables.url_prefix) + return render_template('login.html', url_prefix=SERVER.url_prefix) @ui.route('/reminders', methods=methods) def ui_reminders(): - return render_template('reminders.html', url_prefix=UIVariables.url_prefix) + return render_template('reminders.html', url_prefix=SERVER.url_prefix) @ui.route('/admin', methods=methods) def ui_admin(): - return render_template('admin.html', url_prefix=UIVariables.url_prefix) + return render_template('admin.html', url_prefix=SERVER.url_prefix) diff --git a/tests/MIND_test.py b/tests/MIND_test.py index 4e0e084..5a497a7 100644 --- a/tests/MIND_test.py +++ b/tests/MIND_test.py @@ -4,17 +4,19 @@ from flask import Flask from frontend.api import api from frontend.ui import ui -from MIND import _create_app +from backend.server import SERVER class Test_MIND(unittest.TestCase): def test_create_app(self): - result = _create_app() - self.assertIsInstance(result, Flask) + SERVER.create_app() + self.assertTrue(hasattr(SERVER, 'app')) + app = SERVER.app + self.assertIsInstance(app, Flask) - self.assertEqual(result.blueprints.get('ui'), ui) - self.assertEqual(result.blueprints.get('api'), api) + self.assertEqual(app.blueprints.get('ui'), ui) + self.assertEqual(app.blueprints.get('api'), api) - handlers = result.error_handler_spec[None].keys() + handlers = app.error_handler_spec[None].keys() required_handlers = 400, 405, 500 for handler in required_handlers: self.assertIn(handler, handlers) diff --git a/tests/db_test.py b/tests/db_test.py index 5f47389..3da2b31 100644 --- a/tests/db_test.py +++ b/tests/db_test.py @@ -1,12 +1,12 @@ import unittest -from backend.db import DBConnection +from backend.db import DB_FILENAME, DBConnection from backend.helpers import folder_path -from MIND import DB_FILENAME class Test_DB(unittest.TestCase): - def test_foreign_key(self): + def test_foreign_key_and_wal(self): DBConnection.file = folder_path(*DB_FILENAME) instance = DBConnection(timeout=20.0) self.assertEqual(instance.cursor().execute("PRAGMA foreign_keys;").fetchone()[0], 1) + self.assertEqual(instance.cursor().execute("PRAGMA journal_mode;").fetchone()[0], 'wal')