mirror of
https://github.com/Casvt/MIND.git
synced 2026-04-03 03:00:22 -04:00
Refactored backend (Fixes #87)
This commit is contained in:
493
backend/internals/db.py
Normal file
493
backend/internals/db.py
Normal file
@@ -0,0 +1,493 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Setting up the database and handling connections
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from os import remove
|
||||
from os.path import dirname, exists, isdir, isfile, join
|
||||
from shutil import move
|
||||
from sqlite3 import (PARSE_DECLTYPES, Connection, Cursor,
|
||||
OperationalError, ProgrammingError, Row,
|
||||
register_adapter, register_converter)
|
||||
from threading import current_thread
|
||||
from typing import Any, Dict, Generator, Iterable, List, Type, Union
|
||||
|
||||
from flask import g
|
||||
|
||||
from backend.base.custom_exceptions import InvalidDatabaseFile
|
||||
from backend.base.definitions import Constants, ReminderType, StartType, T
|
||||
from backend.base.helpers import create_folder, folder_path, rename_file
|
||||
from backend.base.logging import LOGGER, set_log_level
|
||||
from backend.internals.db_migration import get_latest_db_version, migrate_db
|
||||
|
||||
REMINDER_TO_KEY = {
|
||||
ReminderType.REMINDER: "reminder_id",
|
||||
ReminderType.STATIC_REMINDER: "static_reminder_id",
|
||||
ReminderType.TEMPLATE: "template_id"
|
||||
}
|
||||
|
||||
|
||||
class MindCursor(Cursor):
|
||||
|
||||
row_factory: Union[Type[Row], None] # type: ignore
|
||||
|
||||
@property
|
||||
def lastrowid(self) -> int:
|
||||
return super().lastrowid or 1
|
||||
|
||||
def fetchonedict(self) -> Union[Dict[str, Any], None]:
|
||||
"""Same as `fetchone` but convert the Row object to a dict.
|
||||
|
||||
Returns:
|
||||
Union[Dict[str, Any], None]: The dict or None i.c.o. no result.
|
||||
"""
|
||||
r = self.fetchone()
|
||||
if r is None:
|
||||
return r
|
||||
return dict(r)
|
||||
|
||||
def fetchmanydict(self, size: Union[int, None] = 1) -> List[Dict[str, Any]]:
|
||||
"""Same as `fetchmany` but convert the Row object to a dict.
|
||||
|
||||
Args:
|
||||
size (Union[int, None], optional): The amount of rows to return.
|
||||
Defaults to 1.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The rows.
|
||||
"""
|
||||
return [dict(e) for e in self.fetchmany(size)]
|
||||
|
||||
def fetchalldict(self) -> List[Dict[str, Any]]:
|
||||
"""Same as `fetchall` but convert the Row object to a dict.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The results.
|
||||
"""
|
||||
return [dict(e) for e in self]
|
||||
|
||||
def exists(self) -> Union[Any, None]:
|
||||
"""Return the first column of the first row, or `None` if not found.
|
||||
|
||||
Returns:
|
||||
Union[Any, None]: The value of the first column of the first row,
|
||||
or `None` if not found.
|
||||
"""
|
||||
r = self.fetchone()
|
||||
if r is None:
|
||||
return r
|
||||
return r[0]
|
||||
|
||||
|
||||
class DBConnectionManager(type):
|
||||
instances: Dict[int, DBConnection] = {}
|
||||
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> DBConnection:
|
||||
thread_id = current_thread().native_id or -1
|
||||
|
||||
if (
|
||||
not thread_id in cls.instances
|
||||
or cls.instances[thread_id].closed
|
||||
):
|
||||
cls.instances[thread_id] = super().__call__(*args, **kwargs)
|
||||
|
||||
return cls.instances[thread_id]
|
||||
|
||||
|
||||
class DBConnection(Connection, metaclass=DBConnectionManager):
|
||||
file = ''
|
||||
|
||||
def __init__(self, timeout: float) -> None:
|
||||
"""Create a connection with a database.
|
||||
|
||||
Args:
|
||||
timeout (float): How long to wait before giving up on a command.
|
||||
"""
|
||||
LOGGER.debug(f'Creating connection {self}')
|
||||
super().__init__(
|
||||
self.file,
|
||||
timeout=timeout,
|
||||
detect_types=PARSE_DECLTYPES
|
||||
)
|
||||
super().cursor().execute("PRAGMA foreign_keys = ON;")
|
||||
self.closed = False
|
||||
return
|
||||
|
||||
def cursor( # type: ignore
|
||||
self,
|
||||
force_new: bool = False
|
||||
) -> MindCursor:
|
||||
"""Get a database cursor from the connection.
|
||||
|
||||
Args:
|
||||
force_new (bool, optional): Get a new cursor instead of the cached
|
||||
one.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
MindCursor: The database cursor.
|
||||
"""
|
||||
if not hasattr(g, 'cursors'):
|
||||
g.cursors = []
|
||||
|
||||
if not g.cursors:
|
||||
c = MindCursor(self)
|
||||
c.row_factory = Row
|
||||
g.cursors.append(c)
|
||||
|
||||
if not force_new:
|
||||
return g.cursors[0]
|
||||
else:
|
||||
c = MindCursor(self)
|
||||
c.row_factory = Row
|
||||
g.cursors.append(c)
|
||||
return g.cursors[-1]
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the database connection"""
|
||||
LOGGER.debug(f'Closing connection {self}')
|
||||
self.closed = True
|
||||
super().close()
|
||||
return
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__}; {current_thread().name}; {id(self)}>'
|
||||
|
||||
|
||||
def set_db_location(
|
||||
db_folder: Union[str, None]
|
||||
) -> None:
|
||||
"""Setup database location. Create folder for database and set location for
|
||||
`db.DBConnection`.
|
||||
|
||||
Args:
|
||||
db_folder (Union[str, None], optional): The folder in which the database
|
||||
will be stored or in which a database is for MIND to use. Give
|
||||
`None` for the default location.
|
||||
|
||||
Raises:
|
||||
ValueError: Value of `db_folder` exists but is not a folder.
|
||||
"""
|
||||
if db_folder:
|
||||
if exists(db_folder) and not isdir(db_folder):
|
||||
raise ValueError('Database location is not a folder')
|
||||
|
||||
db_file_location = join(
|
||||
db_folder or folder_path(*Constants.DB_FOLDER),
|
||||
Constants.DB_NAME
|
||||
)
|
||||
|
||||
LOGGER.debug(f'Setting database location: {db_file_location}')
|
||||
|
||||
create_folder(dirname(db_file_location))
|
||||
|
||||
if isfile(folder_path('db', 'Noted.db')):
|
||||
rename_file(
|
||||
folder_path('db', 'Noted.db'),
|
||||
db_file_location
|
||||
)
|
||||
|
||||
DBConnection.file = db_file_location
|
||||
|
||||
return
|
||||
|
||||
|
||||
def get_db(force_new: bool = False) -> MindCursor:
|
||||
"""Get a database cursor instance or create a new one if needed.
|
||||
|
||||
Args:
|
||||
force_new (bool, optional): Decides if a new cursor is
|
||||
returned instead of the standard one.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
MindCursor: Database cursor instance that outputs Row objects.
|
||||
"""
|
||||
cursor = (
|
||||
DBConnection(timeout=Constants.DB_TIMEOUT)
|
||||
.cursor(force_new=force_new)
|
||||
)
|
||||
return cursor
|
||||
|
||||
|
||||
def commit() -> None:
|
||||
"""Commit the database"""
|
||||
get_db().connection.commit()
|
||||
return
|
||||
|
||||
|
||||
def iter_commit(iterable: Iterable[T]) -> Generator[T, Any, Any]:
|
||||
"""Commit the database after each iteration. Also commits just before the
|
||||
first iteration starts.
|
||||
|
||||
Args:
|
||||
iterable (Iterable[T]): Iterable that will be iterated over like normal.
|
||||
|
||||
Yields:
|
||||
Generator[T, Any, Any]: Items of iterable.
|
||||
"""
|
||||
commit = get_db().connection.commit
|
||||
commit()
|
||||
for i in iterable:
|
||||
yield i
|
||||
commit()
|
||||
return
|
||||
|
||||
|
||||
def close_db(e: Union[None, BaseException] = None) -> None:
|
||||
"""Close database cursor, commit database and close database.
|
||||
|
||||
Args:
|
||||
e (Union[None, BaseException], optional): Error. Defaults to None.
|
||||
"""
|
||||
try:
|
||||
cursors = g.cursors
|
||||
db: DBConnection = cursors[0].connection
|
||||
for c in cursors:
|
||||
c.close()
|
||||
delattr(g, 'cursors')
|
||||
db.commit()
|
||||
if not current_thread().name.startswith('waitress-'):
|
||||
db.close()
|
||||
|
||||
except (AttributeError, ProgrammingError):
|
||||
pass
|
||||
|
||||
return
|
||||
|
||||
|
||||
def close_all_db() -> None:
|
||||
"Close all non-temporary database connections that are still open"
|
||||
LOGGER.debug('Closing any open database connections')
|
||||
|
||||
for i in DBConnectionManager.instances.values():
|
||||
if not i.closed:
|
||||
i.close()
|
||||
|
||||
c = DBConnection(timeout=20.0)
|
||||
c.commit()
|
||||
c.close()
|
||||
return
|
||||
|
||||
|
||||
def setup_db() -> None:
|
||||
"""
|
||||
Setup the database tables and default config when they aren't setup yet
|
||||
"""
|
||||
from backend.implementations.users import Users
|
||||
from backend.internals.settings import Settings
|
||||
|
||||
cursor = get_db()
|
||||
cursor.execute("PRAGMA journal_mode = wal;")
|
||||
register_adapter(bool, lambda b: int(b))
|
||||
register_converter("BOOL", lambda b: b == b'1')
|
||||
|
||||
cursor.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS users(
|
||||
id INTEGER PRIMARY KEY,
|
||||
username VARCHAR(255) UNIQUE NOT NULL,
|
||||
salt VARCHAR(40) NOT NULL,
|
||||
hash VARCHAR(100) NOT NULL,
|
||||
admin BOOL NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS notification_services(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255),
|
||||
url TEXT,
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS reminders(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
time INTEGER NOT NULL,
|
||||
|
||||
repeat_quantity VARCHAR(15),
|
||||
repeat_interval INTEGER,
|
||||
original_time INTEGER,
|
||||
weekdays VARCHAR(13),
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS templates(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS static_reminders(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS reminder_services(
|
||||
reminder_id INTEGER,
|
||||
static_reminder_id INTEGER,
|
||||
template_id INTEGER,
|
||||
notification_service_id INTEGER NOT NULL,
|
||||
|
||||
FOREIGN KEY (reminder_id) REFERENCES reminders(id)
|
||||
ON DELETE CASCADE,
|
||||
FOREIGN KEY (static_reminder_id) REFERENCES static_reminders(id)
|
||||
ON DELETE CASCADE,
|
||||
FOREIGN KEY (template_id) REFERENCES templates(id)
|
||||
ON DELETE CASCADE,
|
||||
FOREIGN KEY (notification_service_id) REFERENCES notification_services(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS config(
|
||||
key VARCHAR(255) PRIMARY KEY,
|
||||
value BLOB NOT NULL
|
||||
);
|
||||
""")
|
||||
|
||||
settings = Settings()
|
||||
settings_values = settings.get_settings()
|
||||
|
||||
set_log_level(settings_values.log_level)
|
||||
|
||||
migrate_db()
|
||||
|
||||
# DB Migration might change settings, so update cache just to be sure.
|
||||
settings._fetch_settings()
|
||||
|
||||
# Add admin user if it doesn't exist
|
||||
users = Users()
|
||||
if Constants.ADMIN_USERNAME not in users:
|
||||
users.add(
|
||||
Constants.ADMIN_USERNAME, Constants.ADMIN_PASSWORD,
|
||||
force=True,
|
||||
is_admin=True
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def revert_db_import(
|
||||
swap: bool,
|
||||
imported_db_file: str = ''
|
||||
) -> None:
|
||||
"""Revert the database import process. The original_db_file is the file
|
||||
currently used (`DBConnection.file`).
|
||||
|
||||
Args:
|
||||
swap (bool): Whether or not to keep the imported_db_file or not,
|
||||
instead of the original_db_file.
|
||||
|
||||
imported_db_file (str, optional): The other database file. Keep empty
|
||||
to use `Constants.DB_ORIGINAL_FILENAME`.
|
||||
Defaults to ''.
|
||||
"""
|
||||
original_db_file = DBConnection.file
|
||||
if not imported_db_file:
|
||||
imported_db_file = join(
|
||||
dirname(DBConnection.file),
|
||||
Constants.DB_ORIGINAL_NAME
|
||||
)
|
||||
|
||||
if swap:
|
||||
remove(original_db_file)
|
||||
move(
|
||||
imported_db_file,
|
||||
original_db_file
|
||||
)
|
||||
|
||||
else:
|
||||
remove(imported_db_file)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def import_db(
|
||||
new_db_file: str,
|
||||
copy_hosting_settings: bool
|
||||
) -> None:
|
||||
"""Replace the current database with a new one.
|
||||
|
||||
Args:
|
||||
new_db_file (str): The path to the new database file.
|
||||
copy_hosting_settings (bool): Keep the hosting settings from the current
|
||||
database.
|
||||
|
||||
Raises:
|
||||
InvalidDatabaseFile: The new database file is invalid or unsupported.
|
||||
"""
|
||||
LOGGER.info(f'Importing new database; {copy_hosting_settings=}')
|
||||
|
||||
cursor = Connection(new_db_file, timeout=20.0).cursor()
|
||||
try:
|
||||
database_version = cursor.execute(
|
||||
"SELECT value FROM config WHERE key = 'database_version' LIMIT 1;"
|
||||
).fetchone()[0]
|
||||
if not isinstance(database_version, int):
|
||||
raise InvalidDatabaseFile(new_db_file)
|
||||
|
||||
except (OperationalError, InvalidDatabaseFile):
|
||||
LOGGER.error('Uploaded database is not a MIND database file')
|
||||
cursor.connection.close()
|
||||
revert_db_import(
|
||||
swap=False,
|
||||
imported_db_file=new_db_file
|
||||
)
|
||||
raise InvalidDatabaseFile(new_db_file)
|
||||
|
||||
if database_version > get_latest_db_version():
|
||||
LOGGER.error(
|
||||
'Uploaded database is higher version than this MIND installation can support')
|
||||
revert_db_import(
|
||||
swap=False,
|
||||
imported_db_file=new_db_file
|
||||
)
|
||||
raise InvalidDatabaseFile(new_db_file)
|
||||
|
||||
if copy_hosting_settings:
|
||||
hosting_settings = get_db().execute("""
|
||||
SELECT key, value
|
||||
FROM config
|
||||
WHERE key = 'host'
|
||||
OR key = 'port'
|
||||
OR key = 'url_prefix'
|
||||
LIMIT 3;
|
||||
"""
|
||||
).fetchalldict()
|
||||
cursor.executemany("""
|
||||
INSERT INTO config(key, value)
|
||||
VALUES (:key, :value)
|
||||
ON CONFLICT(key) DO
|
||||
UPDATE
|
||||
SET value = :value;
|
||||
""",
|
||||
hosting_settings
|
||||
)
|
||||
cursor.connection.commit()
|
||||
cursor.connection.close()
|
||||
|
||||
move(
|
||||
DBConnection.file,
|
||||
join(dirname(DBConnection.file), Constants.DB_ORIGINAL_NAME)
|
||||
)
|
||||
move(
|
||||
new_db_file,
|
||||
DBConnection.file
|
||||
)
|
||||
|
||||
from backend.internals.server import Server
|
||||
Server().restart(StartType.RESTART_DB_CHANGES)
|
||||
|
||||
return
|
||||
312
backend/internals/db_migration.py
Normal file
312
backend/internals/db_migration.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from backend.base.definitions import Constants, DBMigrator
|
||||
from backend.base.logging import LOGGER
|
||||
|
||||
|
||||
class VersionMappingContainer:
|
||||
version_map: Dict[int, Type[DBMigrator]] = {}
|
||||
|
||||
|
||||
def _load_version_map() -> None:
|
||||
if VersionMappingContainer.version_map:
|
||||
return
|
||||
|
||||
VersionMappingContainer.version_map = {
|
||||
m.start_version: m
|
||||
for m in DBMigrator.__subclasses__()
|
||||
}
|
||||
return
|
||||
|
||||
|
||||
def get_latest_db_version() -> int:
|
||||
_load_version_map()
|
||||
return max(VersionMappingContainer.version_map) + 1
|
||||
|
||||
|
||||
def migrate_db() -> None:
|
||||
"""
|
||||
Migrate a MIND database from it's current version
|
||||
to the newest version supported by the MIND version installed.
|
||||
"""
|
||||
from backend.internals.db import iter_commit
|
||||
from backend.internals.settings import Settings
|
||||
|
||||
s = Settings()
|
||||
current_db_version = s.get_settings().database_version
|
||||
newest_version = get_latest_db_version()
|
||||
if current_db_version == newest_version:
|
||||
return
|
||||
|
||||
LOGGER.info('Migrating database to newer version...')
|
||||
LOGGER.debug(
|
||||
"Database migration: %d -> %d",
|
||||
current_db_version, newest_version
|
||||
)
|
||||
|
||||
for start_version in iter_commit(range(current_db_version, newest_version)):
|
||||
if start_version not in VersionMappingContainer.version_map:
|
||||
continue
|
||||
VersionMappingContainer.version_map[start_version]().run()
|
||||
s.update({'database_version': start_version + 1})
|
||||
|
||||
s._fetch_settings()
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateToUTC(DBMigrator):
|
||||
start_version = 1
|
||||
|
||||
def run(self) -> None:
|
||||
# V1 -> V2
|
||||
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
cursor = get_db()
|
||||
|
||||
t = time()
|
||||
utc_offset = datetime.fromtimestamp(t) - datetime.utcfromtimestamp(t)
|
||||
|
||||
cursor.execute("SELECT time, id FROM reminders;")
|
||||
new_reminders = [
|
||||
[
|
||||
round((
|
||||
datetime.fromtimestamp(r["time"]) - utc_offset
|
||||
).timestamp()),
|
||||
r["id"]
|
||||
]
|
||||
for r in cursor
|
||||
]
|
||||
|
||||
cursor.executemany(
|
||||
"UPDATE reminders SET time = ? WHERE id = ?;",
|
||||
new_reminders
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class MigrateAddColor(DBMigrator):
|
||||
start_version = 2
|
||||
|
||||
def run(self) -> None:
|
||||
# V2 -> V3
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
get_db().executescript("""
|
||||
ALTER TABLE reminders
|
||||
ADD color VARCHAR(7);
|
||||
ALTER TABLE templates
|
||||
ADD color VARCHAR(7);
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateFixRQ(DBMigrator):
|
||||
start_version = 3
|
||||
|
||||
def run(self) -> None:
|
||||
# V3 -> V4
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
get_db().executescript("""
|
||||
UPDATE reminders
|
||||
SET repeat_quantity = repeat_quantity || 's'
|
||||
WHERE repeat_quantity NOT LIKE '%s';
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateToReminderServices(DBMigrator):
|
||||
start_version = 4
|
||||
|
||||
def run(self) -> None:
|
||||
# V4 -> V5
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
get_db().executescript("""
|
||||
BEGIN TRANSACTION;
|
||||
PRAGMA defer_foreign_keys = ON;
|
||||
|
||||
CREATE TEMPORARY TABLE temp_reminder_services(
|
||||
reminder_id,
|
||||
static_reminder_id,
|
||||
template_id,
|
||||
notification_service_id
|
||||
);
|
||||
|
||||
-- Reminders
|
||||
INSERT INTO temp_reminder_services(reminder_id, notification_service_id)
|
||||
SELECT id, notification_service
|
||||
FROM reminders;
|
||||
|
||||
CREATE TEMPORARY TABLE temp_reminders AS
|
||||
SELECT id, user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color
|
||||
FROM reminders;
|
||||
DROP TABLE reminders;
|
||||
CREATE TABLE reminders(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
time INTEGER NOT NULL,
|
||||
|
||||
repeat_quantity VARCHAR(15),
|
||||
repeat_interval INTEGER,
|
||||
original_time INTEGER,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
INSERT INTO reminders
|
||||
SELECT * FROM temp_reminders;
|
||||
|
||||
-- Templates
|
||||
INSERT INTO temp_reminder_services(template_id, notification_service_id)
|
||||
SELECT id, notification_service
|
||||
FROM templates;
|
||||
|
||||
CREATE TEMPORARY TABLE temp_templates AS
|
||||
SELECT id, user_id, title, text, color
|
||||
FROM templates;
|
||||
DROP TABLE templates;
|
||||
CREATE TABLE templates(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
INSERT INTO templates
|
||||
SELECT * FROM temp_templates;
|
||||
|
||||
INSERT INTO reminder_services
|
||||
SELECT * FROM temp_reminder_services;
|
||||
|
||||
COMMIT;
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateRemoveUser1(DBMigrator):
|
||||
start_version = 5
|
||||
|
||||
def run(self) -> None:
|
||||
# V5 -> V6
|
||||
from backend.base.custom_exceptions import (AccessUnauthorized,
|
||||
UserNotFound)
|
||||
from backend.implementations.users import Users
|
||||
|
||||
try:
|
||||
Users().login('User1', 'Password1').delete()
|
||||
|
||||
except (UserNotFound, AccessUnauthorized):
|
||||
pass
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateAddWeekdays(DBMigrator):
|
||||
start_version = 6
|
||||
|
||||
def run(self) -> None:
|
||||
# V6 -> V7
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
get_db().executescript("""
|
||||
ALTER TABLE reminders
|
||||
ADD weekdays VARCHAR(13);
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateAddAdmin(DBMigrator):
|
||||
start_version = 7
|
||||
|
||||
def run(self) -> None:
|
||||
# V7 -> V8
|
||||
|
||||
from backend.implementations.users import Users
|
||||
from backend.internals.db import get_db
|
||||
from backend.internals.settings import Settings
|
||||
|
||||
cursor = get_db()
|
||||
|
||||
cursor.executescript("""
|
||||
DROP TABLE config;
|
||||
CREATE TABLE IF NOT EXISTS config(
|
||||
key VARCHAR(255) PRIMARY KEY,
|
||||
value BLOB NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
Settings()._insert_missing_settings()
|
||||
|
||||
cursor.executescript("""
|
||||
ALTER TABLE users
|
||||
ADD admin BOOL NOT NULL DEFAULT 0;
|
||||
"""
|
||||
)
|
||||
users = Users()
|
||||
if 'admin' in users:
|
||||
users.get_one(
|
||||
users.user_db.username_to_id('admin')
|
||||
).update(
|
||||
new_username='admin_old',
|
||||
new_password=None
|
||||
)
|
||||
|
||||
users.add(
|
||||
Constants.ADMIN_USERNAME, Constants.ADMIN_PASSWORD,
|
||||
force=True,
|
||||
is_admin=True
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateHostSettingsToDB(DBMigrator):
|
||||
start_version = 8
|
||||
|
||||
def run(self) -> None:
|
||||
# V8 -> V9
|
||||
# In newer versions, the variables don't exist anymore, and behaviour
|
||||
# was to then set the values to the default values. But that's already
|
||||
# taken care of by the settings, so nothing to do here anymore.
|
||||
return
|
||||
|
||||
|
||||
class MigrateUpdateManifest(DBMigrator):
|
||||
start_version = 9
|
||||
|
||||
def run(self) -> None:
|
||||
# V9 -> V10
|
||||
|
||||
# Nothing is changed in the database
|
||||
# It's just that this code needs to run once
|
||||
# and the DB migration system does exactly that:
|
||||
# run pieces of code once.
|
||||
from backend.internals.settings import Settings, update_manifest
|
||||
|
||||
update_manifest(
|
||||
Settings().get_settings().url_prefix
|
||||
)
|
||||
|
||||
return
|
||||
815
backend/internals/db_models.py
Normal file
815
backend/internals/db_models.py
Normal file
@@ -0,0 +1,815 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from backend.base.definitions import (NotificationServiceData, ReminderData,
|
||||
ReminderType, StaticReminderData,
|
||||
TemplateData, UserData)
|
||||
from backend.base.helpers import first_of_column
|
||||
from backend.internals.db import REMINDER_TO_KEY, get_db
|
||||
|
||||
|
||||
class NotificationServicesDB:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
self.user_id = user_id
|
||||
return
|
||||
|
||||
def exists(self, notification_service_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM notification_services
|
||||
WHERE id = :id
|
||||
AND user_id = :user_id
|
||||
LIMIT 1;
|
||||
""",
|
||||
{
|
||||
'user_id': self.user_id,
|
||||
'id': notification_service_id
|
||||
}
|
||||
).fetchone() is not None
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
notification_service_id: Union[int, None] = None
|
||||
) -> List[NotificationServiceData]:
|
||||
id_filter = ""
|
||||
if notification_service_id:
|
||||
id_filter = "AND id = :ns_id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, title, url
|
||||
FROM notification_services
|
||||
WHERE user_id = :user_id
|
||||
{id_filter}
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"ns_id": notification_service_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
return [
|
||||
NotificationServiceData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
url: str
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO notification_services(user_id, title, url)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(self.user_id, title, url)
|
||||
).lastrowid
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
notification_service_id: int,
|
||||
title: str,
|
||||
url: str
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE notification_services
|
||||
SET title = :title, url = :url
|
||||
WHERE id = :ns_id;
|
||||
""",
|
||||
{
|
||||
"title": title,
|
||||
"url": url,
|
||||
"ns_id": notification_service_id
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
notification_service_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM notification_services WHERE id = ?;",
|
||||
(notification_service_id,)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class ReminderServicesDB:
|
||||
def __init__(self, reminder_type: ReminderType) -> None:
|
||||
self.key = REMINDER_TO_KEY[reminder_type]
|
||||
return
|
||||
|
||||
def reminder_to_ns(
|
||||
self,
|
||||
reminder_id: int
|
||||
) -> List[int]:
|
||||
"""Get the ID's of the notification services that are linked to the given
|
||||
reminder, static reminder or template.
|
||||
|
||||
Args:
|
||||
reminder_id (int): The ID of the reminder, static reminder or template.
|
||||
|
||||
Returns:
|
||||
List[int]: A list of the notification service ID's that are linked to
|
||||
the given reminder, static reminder or template.
|
||||
"""
|
||||
result = first_of_column(get_db().execute(
|
||||
f"""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE {self.key} = ?;
|
||||
""",
|
||||
(reminder_id,)
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
def update_ns_bindings(
|
||||
self,
|
||||
reminder_id: int,
|
||||
notification_services: List[int]
|
||||
) -> None:
|
||||
"""Update the bindings of a reminder, static reminder or template to
|
||||
notification services.
|
||||
|
||||
Args:
|
||||
reminder_id (int): The ID of the reminder, static reminder or template.
|
||||
|
||||
notification_services (List[int]): The new list of notification services
|
||||
that should be linked to the reminder, static reminder or template.
|
||||
"""
|
||||
cursor = get_db()
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
|
||||
cursor.execute(
|
||||
f"""
|
||||
DELETE FROM reminder_services
|
||||
WHERE {self.key} = ?;
|
||||
""",
|
||||
(reminder_id,)
|
||||
)
|
||||
cursor.executemany(
|
||||
f"""
|
||||
INSERT INTO reminder_services(
|
||||
{self.key},
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
((reminder_id, ns_id) for ns_id in notification_services)
|
||||
)
|
||||
|
||||
cursor.execute("COMMIT;")
|
||||
cursor.connection.isolation_level = ""
|
||||
return
|
||||
|
||||
def uses_ns(
|
||||
self,
|
||||
notification_service_id: int
|
||||
) -> List[int]:
|
||||
"""Get the ID's of the reminders (of given type) that use the given
|
||||
notification service.
|
||||
|
||||
Args:
|
||||
notification_service_id (int): The ID of the notification service to
|
||||
check for.
|
||||
|
||||
Returns:
|
||||
List[int]: The ID's of the reminders (only of the given type) that
|
||||
use the notification service.
|
||||
"""
|
||||
return first_of_column(get_db().execute(
|
||||
f"""
|
||||
SELECT {self.key}
|
||||
FROM reminder_services
|
||||
WHERE notification_service_id = ?
|
||||
AND {self.key} IS NOT NULL
|
||||
LIMIT 1;
|
||||
""",
|
||||
(notification_service_id,)
|
||||
))
|
||||
|
||||
|
||||
class UsersDB:
|
||||
def exists(self, user_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM users
|
||||
WHERE id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(user_id,)
|
||||
).fetchone() is not None
|
||||
|
||||
def taken(self, username: str) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(username,)
|
||||
).fetchone() is not None
|
||||
|
||||
def username_to_id(self, username: str) -> int:
|
||||
return get_db().execute("""
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(username,)
|
||||
).fetchone()[0]
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
user_id: Union[int, None] = None
|
||||
) -> List[UserData]:
|
||||
id_filter = ""
|
||||
if user_id:
|
||||
id_filter = "WHERE id = :id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, username, admin, salt, hash
|
||||
FROM users
|
||||
{id_filter}
|
||||
ORDER BY username, id;
|
||||
""",
|
||||
{
|
||||
"id": user_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
return [
|
||||
UserData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
username: str,
|
||||
salt: bytes,
|
||||
hash: bytes,
|
||||
admin: bool
|
||||
) -> int:
|
||||
user_id = get_db().execute(
|
||||
"""
|
||||
INSERT INTO users(username, salt, hash, admin)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(username, salt, hash, admin)
|
||||
).lastrowid
|
||||
return user_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
user_id: int,
|
||||
username: str,
|
||||
hash: bytes
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE users
|
||||
SET username = :username, hash = :hash
|
||||
WHERE id = :user_id;
|
||||
""",
|
||||
{
|
||||
"username": username,
|
||||
"hash": hash,
|
||||
"user_id": user_id
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
user_id: int
|
||||
) -> None:
|
||||
get_db().executescript(f"""
|
||||
BEGIN TRANSACTION;
|
||||
|
||||
DELETE FROM reminders WHERE user_id = {user_id};
|
||||
DELETE FROM templates WHERE user_id = {user_id};
|
||||
DELETE FROM static_reminders WHERE user_id = {user_id};
|
||||
DELETE FROM notification_services WHERE user_id = {user_id};
|
||||
DELETE FROM users WHERE id = {user_id};
|
||||
|
||||
COMMIT;
|
||||
""")
|
||||
return
|
||||
|
||||
|
||||
class TemplatesDB:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
self.user_id = user_id
|
||||
self.rms_db = ReminderServicesDB(ReminderType.TEMPLATE)
|
||||
return
|
||||
|
||||
def exists(self, template_id: int) -> bool:
|
||||
return get_db().execute(
|
||||
"SELECT 1 FROM templates WHERE id = ? AND user_id = ? LIMIT 1;",
|
||||
(template_id, self.user_id)
|
||||
).fetchone() is not None
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
template_id: Union[int, None] = None
|
||||
) -> List[TemplateData]:
|
||||
id_filter = ""
|
||||
if template_id:
|
||||
id_filter = "AND id = :t_id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, title, text, color
|
||||
FROM templates
|
||||
WHERE user_id = :user_id
|
||||
{id_filter}
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"t_id": template_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
for r in result:
|
||||
r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])
|
||||
|
||||
return [
|
||||
TemplateData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO templates(user_id, title, text, color)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(self.user_id, title, text, color)
|
||||
).lastrowid
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
new_id, notification_services
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
template_id: int,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE templates
|
||||
SET
|
||||
title = :title,
|
||||
text = :text,
|
||||
color = :color
|
||||
WHERE id = :t_id;
|
||||
""",
|
||||
{
|
||||
"title": title,
|
||||
"text": text,
|
||||
"color": color,
|
||||
"t_id": template_id
|
||||
}
|
||||
)
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
template_id,
|
||||
notification_services
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
template_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM templates WHERE id = ?;",
|
||||
(template_id,)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class StaticRemindersDB:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
self.user_id = user_id
|
||||
self.rms_db = ReminderServicesDB(ReminderType.STATIC_REMINDER)
|
||||
return
|
||||
|
||||
def exists(self, reminder_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM static_reminders
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(reminder_id, self.user_id)
|
||||
).fetchone() is not None
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
reminder_id: Union[int, None] = None
|
||||
) -> List[StaticReminderData]:
|
||||
id_filter = ""
|
||||
if reminder_id:
|
||||
id_filter = "AND id = :r_id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, title, text, color
|
||||
FROM static_reminders
|
||||
WHERE user_id = :user_id
|
||||
{id_filter}
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
for r in result:
|
||||
r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])
|
||||
|
||||
return [
|
||||
StaticReminderData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO static_reminders(user_id, title, text, color)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(self.user_id, title, text, color)
|
||||
).lastrowid
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
new_id, notification_services
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
reminder_id: int,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE static_reminders
|
||||
SET
|
||||
title = :title,
|
||||
text = :text,
|
||||
color = :color
|
||||
WHERE id = :r_id;
|
||||
""",
|
||||
{
|
||||
"title": title,
|
||||
"text": text,
|
||||
"color": color,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
)
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
reminder_id,
|
||||
notification_services
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
reminder_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM static_reminders WHERE id = ?;",
|
||||
(reminder_id,)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class RemindersDB:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
self.user_id = user_id
|
||||
self.rms_db = ReminderServicesDB(ReminderType.REMINDER)
|
||||
return
|
||||
|
||||
def exists(self, reminder_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM reminders
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(reminder_id, self.user_id)
|
||||
).fetchone() is not None
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
reminder_id: Union[int, None] = None
|
||||
) -> List[ReminderData]:
|
||||
id_filter = ""
|
||||
if reminder_id:
|
||||
id_filter = "AND id = :r_id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, title, text, color,
|
||||
time, original_time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays AS _weekdays
|
||||
FROM reminders
|
||||
WHERE user_id = :user_id
|
||||
{id_filter};
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
for r in result:
|
||||
r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])
|
||||
|
||||
return [
|
||||
ReminderData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
time: int,
|
||||
repeat_quantity: Union[str, None],
|
||||
repeat_interval: Union[int, None],
|
||||
weekdays: Union[str, None],
|
||||
original_time: Union[int, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO reminders(
|
||||
user_id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays,
|
||||
original_time,
|
||||
color
|
||||
)
|
||||
VALUES (
|
||||
:user_id,
|
||||
:title, :text,
|
||||
:time,
|
||||
:rq, :ri,
|
||||
:wd,
|
||||
:ot,
|
||||
:color
|
||||
);
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"title": title,
|
||||
"text": text,
|
||||
"time": time,
|
||||
"rq": repeat_quantity,
|
||||
"ri": repeat_interval,
|
||||
"wd": weekdays,
|
||||
"ot": original_time,
|
||||
"color": color
|
||||
}
|
||||
).lastrowid
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
new_id, notification_services
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
reminder_id: int,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
time: int,
|
||||
repeat_quantity: Union[str, None],
|
||||
repeat_interval: Union[int, None],
|
||||
weekdays: Union[str, None],
|
||||
original_time: Union[int, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE reminders
|
||||
SET
|
||||
title = :title,
|
||||
text = :text,
|
||||
time = :time,
|
||||
repeat_quantity = :rq,
|
||||
repeat_interval = :ri,
|
||||
weekdays = :wd,
|
||||
original_time = :ot,
|
||||
color = :color
|
||||
WHERE id = :r_id;
|
||||
""",
|
||||
{
|
||||
"title": title,
|
||||
"text": text,
|
||||
"time": time,
|
||||
"rq": repeat_quantity,
|
||||
"ri": repeat_interval,
|
||||
"wd": weekdays,
|
||||
"ot": original_time,
|
||||
"color": color,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
)
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
reminder_id,
|
||||
notification_services
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
reminder_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM reminders WHERE id = ?;",
|
||||
(reminder_id,)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class UserlessRemindersDB:
|
||||
def __init__(self) -> None:
|
||||
self.rms_db = ReminderServicesDB(ReminderType.REMINDER)
|
||||
return
|
||||
|
||||
def exists(self, reminder_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM reminders
|
||||
WHERE id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(reminder_id,)
|
||||
).fetchone() is not None
|
||||
|
||||
def reminder_id_to_user_id(self, reminder_id: int) -> int:
|
||||
return get_db().execute(
|
||||
"""
|
||||
SELECT user_id
|
||||
FROM reminders
|
||||
WHERE id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(reminder_id,)
|
||||
).exists() or -1
|
||||
|
||||
def get_soonest_time(self) -> Union[int, None]:
|
||||
return get_db().execute("SELECT MIN(time) FROM reminders;").exists()
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
time: Union[int, None] = None
|
||||
) -> List[ReminderData]:
|
||||
time_filter = ""
|
||||
if time:
|
||||
time_filter = "WHERE time = :time"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id,
|
||||
title, text, color,
|
||||
time, original_time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays AS _weekdays
|
||||
FROM reminders
|
||||
{time_filter};
|
||||
""",
|
||||
{
|
||||
"time": time
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
for r in result:
|
||||
r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])
|
||||
|
||||
return [
|
||||
ReminderData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
user_id: int,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
time: int,
|
||||
repeat_quantity: Union[str, None],
|
||||
repeat_interval: Union[int, None],
|
||||
weekdays: Union[str, None],
|
||||
original_time: Union[int, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO reminders(
|
||||
user_id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays,
|
||||
original_time,
|
||||
color
|
||||
)
|
||||
VALUES (
|
||||
:user_id,
|
||||
:title, :text,
|
||||
:time,
|
||||
:rq, :ri,
|
||||
:wd,
|
||||
:ot,
|
||||
:color
|
||||
);
|
||||
""",
|
||||
{
|
||||
"user_id": user_id,
|
||||
"title": title,
|
||||
"text": text,
|
||||
"time": time,
|
||||
"rq": repeat_quantity,
|
||||
"ri": repeat_interval,
|
||||
"wd": weekdays,
|
||||
"ot": original_time,
|
||||
"color": color
|
||||
}
|
||||
).lastrowid
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
new_id, notification_services
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
reminder_id: int,
|
||||
time: int
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE reminders
|
||||
SET time = :time
|
||||
WHERE id = :r_id;
|
||||
""",
|
||||
{
|
||||
"time": time,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
reminder_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM reminders WHERE id = ?;",
|
||||
(reminder_id,)
|
||||
)
|
||||
return
|
||||
247
backend/internals/server.py
Normal file
247
backend/internals/server.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Setting up, running and shutting down the API and web-ui
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from os import urandom
|
||||
from threading import Timer, current_thread
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from flask import Flask, render_template, request
|
||||
from waitress.server import create_server
|
||||
from waitress.task import ThreadedTaskDispatcher as TTD
|
||||
from werkzeug.middleware.dispatcher import DispatcherMiddleware
|
||||
|
||||
from backend.base.definitions import Constants, StartType
|
||||
from backend.base.helpers import Singleton, folder_path
|
||||
from backend.base.logging import LOGGER
|
||||
from backend.internals.db import (DBConnectionManager,
|
||||
close_db, revert_db_import)
|
||||
from backend.internals.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from waitress.server import BaseWSGIServer, MultiSocketServer
|
||||
|
||||
|
||||
class ThreadedTaskDispatcher(TTD):
|
||||
def handler_thread(self, thread_no: int) -> None:
|
||||
super().handler_thread(thread_no)
|
||||
|
||||
thread_id = current_thread().native_id or -1
|
||||
if (
|
||||
thread_id in DBConnectionManager.instances
|
||||
and not DBConnectionManager.instances[thread_id].closed
|
||||
):
|
||||
DBConnectionManager.instances[thread_id].close()
|
||||
|
||||
return
|
||||
|
||||
def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
|
||||
print()
|
||||
LOGGER.info('Shutting down MIND')
|
||||
result = super().shutdown(cancel_pending, timeout)
|
||||
return result
|
||||
|
||||
|
||||
def handle_start_type(start_type: StartType) -> None:
|
||||
"""Do special actions needed based on restart version.
|
||||
|
||||
Args:
|
||||
start_type (StartType): The restart version.
|
||||
"""
|
||||
if start_type == StartType.RESTART_HOSTING_CHANGES:
|
||||
LOGGER.info("Starting timer for hosting changes")
|
||||
Server().revert_hosting_timer.start()
|
||||
|
||||
elif start_type == StartType.RESTART_DB_CHANGES:
|
||||
LOGGER.info("Starting timer for database import")
|
||||
Server().revert_db_timer.start()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def diffuse_timers() -> None:
|
||||
"""Stop any timers running after doing a special restart."""
|
||||
|
||||
SERVER = Server()
|
||||
|
||||
if SERVER.revert_hosting_timer.is_alive():
|
||||
LOGGER.info("Timer for hosting changes diffused")
|
||||
SERVER.revert_hosting_timer.cancel()
|
||||
|
||||
elif SERVER.revert_db_timer.is_alive():
|
||||
LOGGER.info("Timer for database import diffused")
|
||||
SERVER.revert_db_timer.cancel()
|
||||
revert_db_import(swap=False)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class Server(metaclass=Singleton):
|
||||
api_prefix = "/api"
|
||||
admin_api_extension = "/admin"
|
||||
admin_prefix = "/api/admin"
|
||||
url_prefix = ''
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.start_type = None
|
||||
|
||||
self.revert_db_timer = Timer(
|
||||
Constants.DB_REVERT_TIME,
|
||||
revert_db_import,
|
||||
kwargs={"swap": True}
|
||||
)
|
||||
self.revert_db_timer.name = "DatabaseImportHandler"
|
||||
|
||||
self.revert_hosting_timer = Timer(
|
||||
Constants.HOSTING_REVERT_TIME,
|
||||
self.restore_hosting_settings
|
||||
)
|
||||
self.revert_hosting_timer.name = "HostingHandler"
|
||||
|
||||
return
|
||||
|
||||
def create_app(self) -> None:
|
||||
"""Creates an flask app instance that can be used to start a web server"""
|
||||
|
||||
from frontend.api import admin_api, api
|
||||
from frontend.ui import ui
|
||||
|
||||
app = Flask(
|
||||
__name__,
|
||||
template_folder=folder_path('frontend', 'templates'),
|
||||
static_folder=folder_path('frontend', 'static'),
|
||||
static_url_path='/static'
|
||||
)
|
||||
app.config['SECRET_KEY'] = urandom(32)
|
||||
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
|
||||
app.config['JSON_SORT_KEYS'] = False
|
||||
|
||||
# Add error handlers
|
||||
@app.errorhandler(400)
|
||||
def bad_request(e):
|
||||
return {'error': "BadRequest", "result": {}}, 400
|
||||
|
||||
@app.errorhandler(405)
|
||||
def method_not_allowed(e):
|
||||
return {'error': "MethodNotAllowed", "result": {}}, 405
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(e):
|
||||
return {'error': "InternalError", "result": {}}, 500
|
||||
|
||||
# Add endpoints
|
||||
app.register_blueprint(ui)
|
||||
app.register_blueprint(api, url_prefix=self.api_prefix)
|
||||
app.register_blueprint(admin_api, url_prefix=self.admin_prefix)
|
||||
|
||||
# Setup db handling
|
||||
app.teardown_appcontext(close_db)
|
||||
|
||||
self.app = app
|
||||
return
|
||||
|
||||
def set_url_prefix(self, url_prefix: str) -> None:
|
||||
"""Change the URL prefix of the server.
|
||||
|
||||
Args:
|
||||
url_prefix (str): The desired URL prefix to set it to.
|
||||
"""
|
||||
self.app.config["APPLICATION_ROOT"] = url_prefix
|
||||
self.app.wsgi_app = DispatcherMiddleware( # type: ignore
|
||||
Flask(__name__),
|
||||
{url_prefix: self.app.wsgi_app}
|
||||
)
|
||||
self.url_prefix = url_prefix
|
||||
return
|
||||
|
||||
def __create_waitress_server(
|
||||
self,
|
||||
host: str,
|
||||
port: int
|
||||
) -> Union[MultiSocketServer, BaseWSGIServer]:
|
||||
"""From the `Flask` instance created in `self.create_app()`, create
|
||||
a waitress server instance.
|
||||
|
||||
Args:
|
||||
host (str): Where to host the server on (e.g. `0.0.0.0`).
|
||||
port (int): The port to host the server on (e.g. `5656`).
|
||||
|
||||
Returns:
|
||||
Union[MultiSocketServer, BaseWSGIServer]: The waitress server instance.
|
||||
"""
|
||||
dispatcher = ThreadedTaskDispatcher()
|
||||
dispatcher.set_thread_count(Constants.HOSTING_THREADS)
|
||||
|
||||
server = create_server(
|
||||
self.app,
|
||||
_dispatcher=dispatcher,
|
||||
host=host,
|
||||
port=port,
|
||||
threads=Constants.HOSTING_THREADS
|
||||
)
|
||||
return server
|
||||
|
||||
def run(self, host: str, port: int) -> None:
|
||||
"""Start the webserver.
|
||||
|
||||
Args:
|
||||
host (str): Where to host the server on (e.g. `0.0.0.0`).
|
||||
port (int): The port to host the server on (e.g. `5656`).
|
||||
"""
|
||||
self.server = self.__create_waitress_server(host, port)
|
||||
LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
|
||||
self.server.run()
|
||||
|
||||
return
|
||||
|
||||
def __shutdown_thread_function(self) -> None:
|
||||
"""Shutdown waitress server. Intended to be run in a thread.
|
||||
"""
|
||||
if not hasattr(self, 'server'):
|
||||
return
|
||||
|
||||
self.server.task_dispatcher.shutdown()
|
||||
self.server.close()
|
||||
self.server._map.clear() # type: ignore
|
||||
return
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Stop the waitress server. Starts a thread that shuts down the server.
|
||||
"""
|
||||
t = Timer(1.0, self.__shutdown_thread_function)
|
||||
t.name = "InternalStateHandler"
|
||||
t.start()
|
||||
return
|
||||
|
||||
def restart(
|
||||
self,
|
||||
start_type: StartType = StartType.STARTUP
|
||||
) -> None:
|
||||
"""Same as `self.shutdown()`, but restart instead of shutting down.
|
||||
|
||||
Args:
|
||||
start_type (StartType, optional): Why Kapowarr should
|
||||
restart.
|
||||
Defaults to StartType.STARTUP.
|
||||
"""
|
||||
self.start_type = start_type
|
||||
self.shutdown()
|
||||
return
|
||||
|
||||
def restore_hosting_settings(self) -> None:
|
||||
with self.app.app_context():
|
||||
settings = Settings()
|
||||
values = settings.get_settings()
|
||||
main_settings = {
|
||||
'host': values.backup_host,
|
||||
'port': values.backup_port,
|
||||
'url_prefix': values.backup_url_prefix
|
||||
}
|
||||
settings.update(main_settings)
|
||||
self.restart()
|
||||
return
|
||||
255
backend/internals/settings.py
Normal file
255
backend/internals/settings.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from dataclasses import _MISSING_TYPE, asdict, dataclass
|
||||
from functools import lru_cache
|
||||
from json import dump, load
|
||||
from logging import DEBUG, INFO
|
||||
from typing import Any, Dict, Mapping
|
||||
|
||||
from backend.base.custom_exceptions import InvalidKeyValue, KeyNotFound
|
||||
from backend.base.helpers import (Singleton, folder_path,
|
||||
get_python_version, reversed_tuples)
|
||||
from backend.base.logging import LOGGER, set_log_level
|
||||
from backend.internals.db import DBConnection, commit, get_db
|
||||
from backend.internals.db_migration import get_latest_db_version
|
||||
|
||||
THIRTY_DAYS = 2592000
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def get_about_data() -> Dict[str, Any]:
|
||||
"""Get data about the application and it's environment.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the version is not found in the pyproject.toml file.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The information.
|
||||
"""
|
||||
with open(folder_path("pyproject.toml"), "r") as f:
|
||||
for line in f:
|
||||
if line.startswith("version = "):
|
||||
version = "V" + line.split('"')[1]
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Version not found in pyproject.toml")
|
||||
|
||||
return {
|
||||
"version": version,
|
||||
"python_version": get_python_version(),
|
||||
"database_version": get_latest_db_version(),
|
||||
"database_location": DBConnection.file,
|
||||
"data_folder": folder_path()
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SettingsValues:
|
||||
database_version: int = get_latest_db_version()
|
||||
log_level: int = INFO
|
||||
|
||||
host: str = '0.0.0.0'
|
||||
port: int = 8080
|
||||
url_prefix: str = ''
|
||||
backup_host: str = '0.0.0.0'
|
||||
backup_port: int = 8080
|
||||
backup_url_prefix: str = ''
|
||||
|
||||
allow_new_accounts: bool = True
|
||||
login_time: int = 3600
|
||||
login_time_reset: bool = True
|
||||
|
||||
def todict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if not k.startswith('backup_')
|
||||
}
|
||||
|
||||
|
||||
class Settings(metaclass=Singleton):
|
||||
def __init__(self) -> None:
|
||||
self._insert_missing_settings()
|
||||
self._fetch_settings()
|
||||
return
|
||||
|
||||
def _insert_missing_settings(self) -> None:
|
||||
"Insert any missing keys from the settings into the database."
|
||||
get_db().executemany(
|
||||
"INSERT OR IGNORE INTO config(key, value) VALUES (?, ?);",
|
||||
asdict(SettingsValues()).items()
|
||||
)
|
||||
commit()
|
||||
return
|
||||
|
||||
def _fetch_settings(self) -> None:
|
||||
"Load the settings from the database into the cache."
|
||||
db_values = {
|
||||
k: v
|
||||
for k, v in get_db().execute(
|
||||
"SELECT key, value FROM config;"
|
||||
)
|
||||
if k in SettingsValues.__dataclass_fields__
|
||||
}
|
||||
|
||||
for b_key in ('allow_new_accounts', 'login_time_reset'):
|
||||
db_values[b_key] = bool(db_values[b_key])
|
||||
|
||||
self.__cached_values = SettingsValues(**db_values)
|
||||
return
|
||||
|
||||
def get_settings(self) -> SettingsValues:
|
||||
"""Get the settings from the cache.
|
||||
|
||||
Returns:
|
||||
SettingsValues: The settings.
|
||||
"""
|
||||
return self.__cached_values
|
||||
|
||||
# Alias, better in one-liners
|
||||
# sv = Settings Values
|
||||
@property
|
||||
def sv(self) -> SettingsValues:
|
||||
"""Get the settings from the cache.
|
||||
|
||||
Returns:
|
||||
SettingsValues: The settings.
|
||||
"""
|
||||
return self.__cached_values
|
||||
|
||||
def update(
|
||||
self,
|
||||
data: Mapping[str, Any]
|
||||
) -> None:
|
||||
"""Change the settings, in a `dict.update()` type of way.
|
||||
|
||||
Args:
|
||||
data (Mapping[str, Any]): The keys and their new values.
|
||||
|
||||
Raises:
|
||||
KeyNotFound: Key is not a setting.
|
||||
InvalidKeyValue: Value of the key is not allowed.
|
||||
"""
|
||||
formatted_data = {}
|
||||
for key, value in data.items():
|
||||
formatted_data[key] = self.__format_setting(key, value)
|
||||
|
||||
get_db().executemany(
|
||||
"UPDATE config SET value = ? WHERE key = ?;",
|
||||
reversed_tuples(formatted_data.items())
|
||||
)
|
||||
|
||||
for key, handler in (
|
||||
('url_prefix', update_manifest),
|
||||
('log_level', set_log_level)
|
||||
):
|
||||
if (
|
||||
key in data
|
||||
and formatted_data[key] != getattr(self.get_settings(), key)
|
||||
):
|
||||
handler(formatted_data[key])
|
||||
|
||||
self._fetch_settings()
|
||||
|
||||
LOGGER.info(f"Settings changed: {formatted_data}")
|
||||
|
||||
return
|
||||
|
||||
def reset(self, key: str) -> None:
|
||||
"""Reset the value of the key to the default value.
|
||||
|
||||
Args:
|
||||
key (str): The key of which to reset the value.
|
||||
|
||||
Raises:
|
||||
KeyNotFound: Key is not a setting.
|
||||
"""
|
||||
LOGGER.debug(f'Setting reset: {key}')
|
||||
|
||||
if not isinstance(
|
||||
SettingsValues.__dataclass_fields__[key].default_factory,
|
||||
_MISSING_TYPE
|
||||
):
|
||||
self.update({
|
||||
key: SettingsValues.__dataclass_fields__[key].default_factory()
|
||||
})
|
||||
else:
|
||||
self.update({
|
||||
key: SettingsValues.__dataclass_fields__[key].default
|
||||
})
|
||||
|
||||
return
|
||||
|
||||
def backup_hosting_settings(self) -> None:
|
||||
"Backup the hosting settings in the database."
|
||||
s = self.get_settings()
|
||||
backup_settings = {
|
||||
'backup_host': s.host,
|
||||
'backup_port': s.port,
|
||||
'backup_url_prefix': s.url_prefix
|
||||
}
|
||||
self.update(backup_settings)
|
||||
return
|
||||
|
||||
def __format_setting(self, key: str, value: Any) -> Any:
|
||||
"""Check if the value of a setting is allowed and convert if needed.
|
||||
|
||||
Args:
|
||||
key (str): Key of setting.
|
||||
value (Any): Value of setting.
|
||||
|
||||
Raises:
|
||||
KeyNotFound: Key is not a setting.
|
||||
InvalidKeyValue: Value is not allowed.
|
||||
|
||||
Returns:
|
||||
Any: (Converted) Setting value.
|
||||
"""
|
||||
converted_value = value
|
||||
|
||||
if key not in SettingsValues.__dataclass_fields__:
|
||||
raise KeyNotFound(key)
|
||||
|
||||
key_data = SettingsValues.__dataclass_fields__[key]
|
||||
|
||||
if not isinstance(value, key_data.type):
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
if key == 'login_time':
|
||||
if not 60 <= value <= THIRTY_DAYS:
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
elif key in ('port', 'backup_port'):
|
||||
if not 1 <= value <= 65535:
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
elif key in ('url_prefix', 'backup_url_prefix'):
|
||||
if value:
|
||||
converted_value = ('/' + value.lstrip('/')).rstrip('/')
|
||||
|
||||
elif key == 'log_level':
|
||||
if value not in (INFO, DEBUG):
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
return converted_value
|
||||
|
||||
|
||||
def update_manifest(url_base: str) -> None:
|
||||
"""Update the url's in the manifest file.
|
||||
Needs to happen when url base changes.
|
||||
|
||||
Args:
|
||||
url_base (str): The url base to use in the file.
|
||||
"""
|
||||
filename = folder_path('frontend', 'static', 'json', 'pwa_manifest.json')
|
||||
|
||||
with open(filename, 'r') as f:
|
||||
manifest = load(f)
|
||||
manifest['start_url'] = url_base + '/'
|
||||
manifest['scope'] = url_base + '/'
|
||||
manifest['icons'][0]['src'] = f'{url_base}/static/img/favicon.svg'
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
dump(manifest, f, indent=4)
|
||||
|
||||
return
|
||||
Reference in New Issue
Block a user