diff --git a/backend/internals/db.py b/backend/internals/db.py index 85859b6..ca54a99 100644 --- a/backend/internals/db.py +++ b/backend/internals/db.py @@ -17,7 +17,6 @@ from flask import g from backend.base.definitions import Constants, ReminderType, T from backend.base.helpers import create_folder, folder_path, rename_file from backend.base.logging import LOGGER, set_log_level -from backend.internals.db_migration import migrate_db REMINDER_TO_KEY = { ReminderType.REMINDER: "reminder_id", @@ -277,6 +276,7 @@ def setup_db() -> None: Setup the database tables and default config when they aren't setup yet """ from backend.implementations.users import Users + from backend.internals.db_migration import migrate_db from backend.internals.settings import Settings cursor = get_db() diff --git a/backend/internals/db_migration.py b/backend/internals/db_migration.py index cf7e851..e890b53 100644 --- a/backend/internals/db_migration.py +++ b/backend/internals/db_migration.py @@ -1,29 +1,35 @@ # -*- coding: utf-8 -*- +from functools import lru_cache from typing import Dict, Type from backend.base.definitions import Constants, DBMigrator from backend.base.logging import LOGGER +from backend.internals.db import get_db, iter_commit -class VersionMappingContainer: - version_map: Dict[int, Type[DBMigrator]] = {} +@lru_cache(1) +def get_db_migration_map() -> Dict[int, Type[DBMigrator]]: + """Get a map of the database version to the migrator class for that version + to one database version higher. E.g. 2 -> Migrate2To3. - -def _load_version_map() -> None: - if VersionMappingContainer.version_map: - return - - VersionMappingContainer.version_map = { + Returns: + Dict[int, Type[DBMigrator]]: The map. + """ + return { m.start_version: m for m in DBMigrator.__subclasses__() } - return +@lru_cache(1) def get_latest_db_version() -> int: - _load_version_map() - return max(VersionMappingContainer.version_map) + 1 + """Get the latest database version supported. + + Returns: + int: The version. + """ + return max(get_db_migration_map()) + 1 def migrate_db() -> None: @@ -31,28 +37,31 @@ def migrate_db() -> None: Migrate a MIND database from it's current version to the newest version supported by the MIND version installed. """ - from backend.internals.db import iter_commit from backend.internals.settings import Settings s = Settings() - current_db_version = s.get_settings().database_version + current_db_version = s.sv.database_version newest_version = get_latest_db_version() if current_db_version == newest_version: + get_db_migration_map.cache_clear() return - LOGGER.info('Migrating database to newer version...') + LOGGER.info("Migrating database to newer version...") LOGGER.debug( "Database migration: %d -> %d", current_db_version, newest_version ) + db_migration_map = get_db_migration_map() for start_version in iter_commit(range(current_db_version, newest_version)): - if start_version not in VersionMappingContainer.version_map: + if start_version not in db_migration_map: continue - VersionMappingContainer.version_map[start_version]().run() - s.update({'database_version': start_version + 1}) + db_migration_map[start_version]().run() + s.update({"database_version": start_version + 1}) + get_db().execute("VACUUM;") s._fetch_settings() + get_db_migration_map.cache_clear() return @@ -66,8 +75,6 @@ class MigrateToUTC(DBMigrator): from datetime import datetime from time import time - from backend.internals.db import get_db - cursor = get_db() t = time() @@ -97,13 +104,11 @@ class MigrateAddColor(DBMigrator): def run(self) -> None: # V2 -> V3 - from backend.internals.db import get_db - get_db().executescript(""" ALTER TABLE reminders - ADD color VARCHAR(7); + ADD color VARCHAR(7); ALTER TABLE templates - ADD color VARCHAR(7); + ADD color VARCHAR(7); """) return @@ -115,8 +120,6 @@ class MigrateFixRQ(DBMigrator): def run(self) -> None: # V3 -> V4 - from backend.internals.db import get_db - get_db().executescript(""" UPDATE reminders SET repeat_quantity = repeat_quantity || 's' @@ -132,8 +135,6 @@ class MigrateToReminderServices(DBMigrator): def run(self) -> None: # V4 -> V5 - from backend.internals.db import get_db - get_db().executescript(""" BEGIN TRANSACTION; PRAGMA defer_foreign_keys = ON; @@ -146,12 +147,17 @@ class MigrateToReminderServices(DBMigrator): ); -- Reminders - INSERT INTO temp_reminder_services(reminder_id, notification_service_id) + INSERT INTO temp_reminder_services( + reminder_id, notification_service_id + ) SELECT id, notification_service FROM reminders; CREATE TEMPORARY TABLE temp_reminders AS - SELECT id, user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color + SELECT + id, user_id, title, text, + time, repeat_quantity, repeat_interval, original_time, + color FROM reminders; DROP TABLE reminders; CREATE TABLE reminders( @@ -173,7 +179,9 @@ class MigrateToReminderServices(DBMigrator): SELECT * FROM temp_reminders; -- Templates - INSERT INTO temp_reminder_services(template_id, notification_service_id) + INSERT INTO temp_reminder_services( + template_id, notification_service_id + ) SELECT id, notification_service FROM templates; @@ -227,11 +235,9 @@ class MigrateAddWeekdays(DBMigrator): def run(self) -> None: # V6 -> V7 - from backend.internals.db import get_db - get_db().executescript(""" ALTER TABLE reminders - ADD weekdays VARCHAR(13); + ADD weekdays VARCHAR(13); """) return @@ -244,7 +250,6 @@ class MigrateAddAdmin(DBMigrator): # V7 -> V8 from backend.implementations.users import Users - from backend.internals.db import get_db from backend.internals.settings import Settings cursor = get_db() @@ -261,7 +266,7 @@ class MigrateAddAdmin(DBMigrator): cursor.executescript(""" ALTER TABLE users - ADD admin BOOL NOT NULL DEFAULT 0; + ADD admin BOOL NOT NULL DEFAULT 0; """ ) users = Users() @@ -311,11 +316,9 @@ class MigrateAddEnabled(DBMigrator): def run(self) -> None: # V10 -> V11 - from backend.internals.db import get_db - get_db().execute(""" ALTER TABLE reminders - ADD enabled BOOL NOT NULL DEFAULT 1; + ADD enabled BOOL NOT NULL DEFAULT 1; """) return @@ -342,8 +345,6 @@ class MigrateAddCronScheduleColumn(DBMigrator): def run(self) -> None: # V12 -> V13 - from backend.internals.db import get_db - get_db().executescript(""" BEGIN TRANSACTION; PRAGMA defer_foreign_keys = ON;