From 783a5664d56e67f936d1b66f842dbfc069dcc17a Mon Sep 17 00:00:00 2001 From: CasVT Date: Sat, 2 Aug 2025 17:07:45 +0200 Subject: [PATCH] Refactored db_backup_import.py --- backend/base/custom_exceptions.py | 10 +- backend/internals/db_backup_import.py | 146 +++++++++++++------------- backend/internals/db_models.py | 38 ++++++- frontend/input_validation.py | 5 +- 4 files changed, 119 insertions(+), 80 deletions(-) diff --git a/backend/base/custom_exceptions.py b/backend/base/custom_exceptions.py index eae315c..bbd9975 100644 --- a/backend/base/custom_exceptions.py +++ b/backend/base/custom_exceptions.py @@ -150,11 +150,12 @@ class NewAccountsNotAllowed(MindException): class InvalidDatabaseFile(MindException): "The uploaded database file is invalid or not supported" - def __init__(self, filepath_db: str) -> None: + def __init__(self, filepath_db: str, reason: str) -> None: self.filepath_db = filepath_db + self.reason = reason LOGGER.warning( - "The given database file is invalid: %s", - filepath_db + "The given database file is invalid: %s (reason=%s)", + filepath_db, reason ) return @@ -164,7 +165,8 @@ class InvalidDatabaseFile(MindException): 'code': 400, 'error': self.__class__.__name__, 'result': { - 'filepath_db': self.filepath_db + 'filepath_db': self.filepath_db, + 'reason': self.reason } } diff --git a/backend/internals/db_backup_import.py b/backend/internals/db_backup_import.py index f16460d..4bc3920 100644 --- a/backend/internals/db_backup_import.py +++ b/backend/internals/db_backup_import.py @@ -4,28 +4,28 @@ from __future__ import annotations from datetime import datetime from os import remove -from os.path import basename, dirname, join +from os.path import basename, dirname, exists, join from re import compile from shutil import move -from sqlite3 import Connection, OperationalError +from sqlite3 import Connection, OperationalError, Row from time import time from typing import TYPE_CHECKING, List, Union from backend.base.custom_exceptions import (DatabaseFileNotFound, InvalidDatabaseFile) from backend.base.definitions import Constants, DatabaseBackupEntry, StartType -from backend.base.helpers import Singleton, copy, folder_path, list_files +from backend.base.helpers import copy, folder_path, list_files from backend.base.logging import LOGGER -from backend.internals.db import DBConnection, get_db +from backend.internals.db import DBConnection, MindCursor, get_db from backend.internals.db_migration import get_latest_db_version +from backend.internals.db_models import ConfigDB from backend.internals.settings import Settings if TYPE_CHECKING: from threading import Timer -# =================== + # region Backup -# =================== DB_FILE_REGEX = compile( r'MIND_(?P\d{4})_(?P\d{2})_(?P\d{2})_(?P\d{2})_(?P\d{2}).db' ) @@ -70,7 +70,7 @@ def get_backup(index: int) -> DatabaseBackupEntry: """Get info on a specific database backup. Args: - index (int): The index (supplied by `get_backups()`) of the backup. + index (int): The index of the backup (supplied by `get_backups()`). Raises: DatabaseFileNotFound: No backup entry with the given index. @@ -126,7 +126,7 @@ def backup_database() -> None: return -class DatabaseBackupHandler(metaclass=Singleton): +class DatabaseBackupHandler: backup_timer: Union[Timer, None] = None def set_backup_timer(self) -> None: @@ -134,61 +134,68 @@ class DatabaseBackupHandler(metaclass=Singleton): already. Replace it if it does already exist, in case the interval setting has a new value. """ + from backend.internals.server import Server + sv = Settings().get_settings() - if self.backup_timer is not None: - self.backup_timer.cancel() + if self.__class__.backup_timer is not None: + self.__class__.backup_timer.cancel() - from backend.internals.server import Server - self.backup_timer = Server().get_db_timer_thread( + self.__class__.backup_timer = Server().get_db_timer_thread( sv.db_backup_last_run + sv.db_backup_interval - time(), backup_database, "DatabaseBackupHandler" ) - self.backup_timer.start() + self.__class__.backup_timer.start() return def stop_backup_timer(self) -> None: "If the backup timer is running, stop it" - if self.backup_timer is not None: - self.backup_timer.cancel() + if self.__class__.backup_timer is not None: + self.__class__.backup_timer.cancel() return -# =================== # region Import -# =================== def revert_db_import( swap: bool, - imported_db_file: str = '' + other_db_file: str = '' ) -> None: - """Revert the database import process. The original_db_file is the file - currently used (`DBConnection.file`). + """Revert the database import process. Args: - swap (bool): Whether or not to keep the imported_db_file or not, - instead of the original_db_file. + swap (bool): Keep the other database file instead of the current + database file. - imported_db_file (str, optional): The other database file. Keep empty - to use `Constants.DB_ORIGINAL_FILENAME`. + other_db_file (str, optional): The other database file. Keep empty to + use `Constants.DB_ORIGINAL_FILENAME`. Defaults to ''. + + Raises: + InvalidDatabaseFile: The other database file does not exist. """ original_db_file = DBConnection.file - if not imported_db_file: - imported_db_file = join( + if not other_db_file: + other_db_file = join( dirname(DBConnection.file), Constants.DB_ORIGINAL_NAME ) + if not exists(other_db_file): + raise InvalidDatabaseFile( + other_db_file, + "Database file does not exist" + ) + if swap: remove(original_db_file) move( - imported_db_file, + other_db_file, original_db_file ) else: - remove(imported_db_file) + remove(other_db_file) return @@ -202,60 +209,54 @@ def import_db( Args: new_db_file (str): The path to the new database file. copy_hosting_settings (bool): Keep the hosting settings from the current - database. + database. Raises: InvalidDatabaseFile: The new database file is invalid or unsupported. """ - LOGGER.info(f'Importing new database; {copy_hosting_settings=}') + from backend.internals.server import Server + + LOGGER.info(f"Importing new database; {copy_hosting_settings=}") + + cursor_new = MindCursor( + Connection(new_db_file, timeout=Constants.DB_TIMEOUT) + ) + cursor_new.row_factory = Row + config_current = ConfigDB() + config_new = ConfigDB(cursor_new) - cursor = Connection(new_db_file, timeout=20.0).cursor() try: - database_version = cursor.execute( - "SELECT value FROM config WHERE key = 'database_version' LIMIT 1;" - ).fetchone()[0] - if not isinstance(database_version, int): - raise InvalidDatabaseFile(new_db_file) + try: + database_version = config_new.fetch_key("database_version") + if not isinstance(database_version, int): + raise OperationalError + except OperationalError: + raise InvalidDatabaseFile( + new_db_file, + "Uploaded database is not a MIND database file" + ) - except (OperationalError, InvalidDatabaseFile): - LOGGER.error('Uploaded database is not a MIND database file') - cursor.connection.close() + if database_version > get_latest_db_version(): + raise InvalidDatabaseFile( + new_db_file, + "Uploaded database is higher version than this MIND installation can support") + + except InvalidDatabaseFile: + cursor_new.connection.close() revert_db_import( swap=False, - imported_db_file=new_db_file + other_db_file=new_db_file ) - raise InvalidDatabaseFile(new_db_file) - - if database_version > get_latest_db_version(): - LOGGER.error( - 'Uploaded database is higher version than this MIND installation can support') - revert_db_import( - swap=False, - imported_db_file=new_db_file - ) - raise InvalidDatabaseFile(new_db_file) + raise if copy_hosting_settings: - hosting_settings = get_db().execute(""" - SELECT key, value - FROM config - WHERE key = 'host' - OR key = 'port' - OR key = 'url_prefix' - LIMIT 3; - """ - ).fetchalldict() - cursor.executemany(""" - INSERT INTO config(key, value) - VALUES (:key, :value) - ON CONFLICT(key) DO - UPDATE - SET value = :value; - """, - hosting_settings - ) - cursor.connection.commit() - cursor.connection.close() + hosting_settings = config_current.fetch_all() + for key, value in hosting_settings: + if key in ('host', 'port', 'url_prefix'): + config_new.update(key, value) + + cursor_new.connection.commit() + cursor_new.connection.close() move( DBConnection.file, @@ -266,7 +267,6 @@ def import_db( DBConnection.file ) - from backend.internals.server import Server Server().restart(StartType.RESTART_DB_CHANGES) return @@ -279,9 +279,9 @@ def import_db_backup( """Replace the current database with a backup. Args: - index (int): The index (supplied by `get_backups()`) of the backup. + index (int): The index of the backup (supplied by `get_backups()`). copy_hosting_settings (bool): Keep the hosting settings from the current - database. + database. Raises: DatabaseFileNotFound: No backup entry with the given index. diff --git a/backend/internals/db_models.py b/backend/internals/db_models.py index 43c9cce..db652c7 100644 --- a/backend/internals/db_models.py +++ b/backend/internals/db_models.py @@ -1,12 +1,46 @@ # -*- coding: utf-8 -*- -from typing import List, Union +from typing import Any, List, Tuple, Union from backend.base.definitions import (NotificationServiceData, ReminderData, ReminderType, StaticReminderData, TemplateData, UserData) from backend.base.helpers import first_of_subarrays -from backend.internals.db import REMINDER_TO_KEY, get_db +from backend.internals.db import REMINDER_TO_KEY, MindCursor, get_db + + +class ConfigDB: + def __init__(self, cursor: Union[MindCursor, None] = None) -> None: + if cursor is None: + self.cursor = get_db() + else: + self.cursor = cursor + return + + def fetch_all(self) -> List[Tuple[str, Any]]: + return self.cursor.execute( + "SELECT key, value FROM config;" + ).fetchall() + + def fetch_key(self, key: str) -> Any: + return self.cursor.execute( + "SELECT value FROM config WHERE key = ? LIMIT 1;", + (key,) + ).exists() + + def insert(self, key: str, value: Any) -> None: + self.cursor.execute( + "INSERT OR IGNORE INTO config(key, value) VALUES (?, ?);", + (key, value) + ) + return + + def update(self, key: str, value: Any) -> None: + self.cursor.execute( + "UPDATE config SET value = ? WHERE key = ?;", + (value, key) + ) + return class NotificationServicesDB: diff --git a/frontend/input_validation.py b/frontend/input_validation.py index 7b6fd74..cda86d9 100644 --- a/frontend/input_validation.py +++ b/frontend/input_validation.py @@ -961,7 +961,10 @@ def input_validation() -> Dict[str, Any]: if not value.validate(): if isinstance(value, DatabaseFileVariable): - raise InvalidDatabaseFile(value.converted_value) + raise InvalidDatabaseFile( + value.converted_value, + "File is not a database file" + ) elif noted_var.source == DataSource.FILES: raise InvalidKeyValue(noted_var.name, input_value.filename) else: