Moved backup command to DB connection

This commit is contained in:
CasVT
2025-08-02 18:03:32 +02:00
parent 783a5664d5
commit ba2d89d4ce
4 changed files with 71 additions and 41 deletions

View File

@@ -34,6 +34,14 @@ class MindCursor(Cursor):
def lastrowid(self) -> int:
return super().lastrowid or 1
@property
def connection(self) -> DBConnection:
return super().connection # type: ignore
def __init__(self, cursor: DBConnection, /) -> None:
super().__init__(cursor)
return
def fetchonedict(self) -> Union[Dict[str, Any], None]:
"""Same as `fetchone` but convert the Row object to a dict.
@@ -98,30 +106,46 @@ class MindCursor(Cursor):
class DBConnectionManager(type):
instances: Dict[int, DBConnection] = {}
def __call__(cls, *args: Any, **kwargs: Any) -> DBConnection:
def __call__(cls, **kwargs: Any) -> DBConnection:
if kwargs.get('db_file'):
return super().__call__(**kwargs)
thread_id = current_thread_id()
if (
not thread_id in cls.instances
or cls.instances[thread_id].closed
):
cls.instances[thread_id] = super().__call__(*args, **kwargs)
cls.instances[thread_id] = super().__call__(**kwargs)
return cls.instances[thread_id]
class DBConnection(Connection, metaclass=DBConnectionManager):
file = ''
default_file = ''
def __init__(self, timeout: float) -> None:
def __init__(
self, *,
db_file: Union[str, None] = None,
timeout: float = Constants.DB_TIMEOUT
) -> None:
"""Create a connection with a database.
Args:
timeout (float): How long to wait before giving up on a command.
db_file (Union[str, None], optional): The database file to connect
to. If `None`, the default file will be used. If something else
than the default file is given, then a new connection will
always be returned.
Defaults to None.
timeout (float, optional): How long to wait before giving up
on a command.
Defaults to Constants.DB_TIMEOUT.
"""
LOGGER.debug(f'Creating connection {self}')
self.db_file = db_file or self.default_file
super().__init__(
self.file,
self.db_file,
timeout=timeout,
detect_types=PARSE_DECLTYPES
)
@@ -144,20 +168,35 @@ class DBConnection(Connection, metaclass=DBConnectionManager):
MindCursor: The database cursor.
"""
if not hasattr(g, 'cursors'):
g.cursors = []
g.cursors = {}
if not g.cursors:
if self.db_file not in g.cursors:
g.cursors[self.db_file] = []
if not g.cursors[self.db_file]:
c = MindCursor(self)
c.row_factory = Row
g.cursors.append(c)
g.cursors[self.db_file].append(c)
if not force_new:
return g.cursors[0]
return g.cursors[self.db_file][0]
else:
c = MindCursor(self)
c.row_factory = Row
g.cursors.append(c)
return g.cursors[-1]
g.cursors[self.db_file].append(c)
return g.cursors[self.db_file][-1]
def create_backup(self, filepath: str) -> None:
"""Create a backup of the current database.
Args:
filepath (str): What the filepath of the backup will be.
"""
self.execute(
"VACUUM INTO ?;",
(filepath,)
)
return
def close(self) -> None:
"""Close the database connection"""
@@ -205,7 +244,7 @@ def set_db_location(
db_file_location
)
DBConnection.file = db_file_location
DBConnection.default_file = db_file_location
SettingsValues.db_backup_folder = dirname(db_file_location)
return
@@ -222,11 +261,7 @@ def get_db(force_new: bool = False) -> MindCursor:
Returns:
MindCursor: Database cursor instance that outputs Row objects.
"""
cursor = (
DBConnection(timeout=Constants.DB_TIMEOUT)
.cursor(force_new=force_new)
)
return cursor
return DBConnection().cursor(force_new=force_new)
def commit() -> None:
@@ -271,13 +306,14 @@ def close_db(e: Union[None, BaseException] = None) -> None:
try:
cursors = g.cursors
db: DBConnection = cursors[0].connection
for c in cursors:
c.close()
for cursors in g.cursors.values():
db: DBConnection = cursors[0].connection
for c in cursors:
c.close()
db.commit()
if not current_thread().name.startswith('waitress-'):
db.close()
delattr(g, 'cursors')
db.commit()
if not current_thread().name.startswith('waitress-'):
db.close()
except ProgrammingError:
pass

View File

@@ -7,7 +7,7 @@ from os import remove
from os.path import basename, dirname, exists, join
from re import compile
from shutil import move
from sqlite3 import Connection, OperationalError, Row
from sqlite3 import OperationalError
from time import time
from typing import TYPE_CHECKING, List, Union
@@ -16,7 +16,7 @@ from backend.base.custom_exceptions import (DatabaseFileNotFound,
from backend.base.definitions import Constants, DatabaseBackupEntry, StartType
from backend.base.helpers import copy, folder_path, list_files
from backend.base.logging import LOGGER
from backend.internals.db import DBConnection, MindCursor, get_db
from backend.internals.db import DBConnection
from backend.internals.db_migration import get_latest_db_version
from backend.internals.db_models import ConfigDB
from backend.internals.settings import Settings
@@ -95,10 +95,7 @@ def create_database_copy(folder: str) -> str:
"""
current_date = datetime.now().strftime(r"%Y_%m_%d_%H_%M")
filename = join(folder, f'MIND_{current_date}.db')
get_db().execute(
"VACUUM INTO ?;",
(filename,)
)
DBConnection().create_backup(filename)
return filename
@@ -174,10 +171,10 @@ def revert_db_import(
Raises:
InvalidDatabaseFile: The other database file does not exist.
"""
original_db_file = DBConnection.file
original_db_file = DBConnection.default_file
if not other_db_file:
other_db_file = join(
dirname(DBConnection.file),
dirname(DBConnection.default_file),
Constants.DB_ORIGINAL_NAME
)
@@ -218,10 +215,7 @@ def import_db(
LOGGER.info(f"Importing new database; {copy_hosting_settings=}")
cursor_new = MindCursor(
Connection(new_db_file, timeout=Constants.DB_TIMEOUT)
)
cursor_new.row_factory = Row
cursor_new = DBConnection(db_file=new_db_file).cursor()
config_current = ConfigDB()
config_new = ConfigDB(cursor_new)
@@ -259,12 +253,12 @@ def import_db(
cursor_new.connection.close()
move(
DBConnection.file,
join(dirname(DBConnection.file), Constants.DB_ORIGINAL_NAME)
DBConnection.default_file,
join(dirname(DBConnection.default_file), Constants.DB_ORIGINAL_NAME)
)
move(
new_db_file,
DBConnection.file
DBConnection.default_file
)
Server().restart(StartType.RESTART_DB_CHANGES)

View File

@@ -40,7 +40,7 @@ def get_about_data() -> Dict[str, Any]:
"version": version,
"python_version": get_python_version(),
"database_version": get_latest_db_version(),
"database_location": DBConnection.file,
"database_location": DBConnection.default_file,
"data_folder": folder_path()
}

View File

@@ -12,7 +12,7 @@ class Test_DB(unittest.TestCase):
app = Flask(__name__)
app.teardown_appcontext(close_db)
DBConnection.file = join(
DBConnection.default_file = join(
folder_path(*Constants.DB_FOLDER),
Constants.DB_NAME
)