Refactored backend (Fixes #87)

This commit is contained in:
CasVT
2025-04-22 23:29:35 +02:00
parent 401c97308b
commit 0cbb03151f
65 changed files with 6974 additions and 5014 deletions

493
backend/internals/db.py Normal file
View File

@@ -0,0 +1,493 @@
# -*- coding: utf-8 -*-
"""
Setting up the database and handling connections
"""
from __future__ import annotations
from os import remove
from os.path import dirname, exists, isdir, isfile, join
from shutil import move
from sqlite3 import (PARSE_DECLTYPES, Connection, Cursor,
OperationalError, ProgrammingError, Row,
register_adapter, register_converter)
from threading import current_thread
from typing import Any, Dict, Generator, Iterable, List, Type, Union
from flask import g
from backend.base.custom_exceptions import InvalidDatabaseFile
from backend.base.definitions import Constants, ReminderType, StartType, T
from backend.base.helpers import create_folder, folder_path, rename_file
from backend.base.logging import LOGGER, set_log_level
from backend.internals.db_migration import get_latest_db_version, migrate_db
# Maps each reminder type to the foreign-key column it uses in the
# `reminder_services` link table.
REMINDER_TO_KEY = {
    ReminderType.REMINDER: "reminder_id",
    ReminderType.STATIC_REMINDER: "static_reminder_id",
    ReminderType.TEMPLATE: "template_id"
}
class MindCursor(Cursor):
    """Cursor subclass with dict-based fetch helpers and a truthy lastrowid."""

    row_factory: Union[Type[Row], None]  # type: ignore

    @property
    def lastrowid(self) -> int:
        # Fall back to 1 so callers always get a truthy row ID.
        return super().lastrowid or 1

    def fetchonedict(self) -> Union[Dict[str, Any], None]:
        """Same as `fetchone` but convert the Row object to a dict.

        Returns:
            Union[Dict[str, Any], None]: The dict or None i.c.o. no result.
        """
        row = self.fetchone()
        return None if row is None else dict(row)

    def fetchmanydict(self, size: Union[int, None] = 1) -> List[Dict[str, Any]]:
        """Same as `fetchmany` but convert the Row objects to dicts.

        Args:
            size (Union[int, None], optional): The amount of rows to return.
                Defaults to 1.

        Returns:
            List[Dict[str, Any]]: The rows.
        """
        return list(map(dict, self.fetchmany(size)))

    def fetchalldict(self) -> List[Dict[str, Any]]:
        """Same as `fetchall` but convert the Row objects to dicts.

        Returns:
            List[Dict[str, Any]]: The results.
        """
        return list(map(dict, self))

    def exists(self) -> Union[Any, None]:
        """Return the first column of the first row, or `None` if not found.

        Returns:
            Union[Any, None]: The value of the first column of the first row,
                or `None` if not found.
        """
        row = self.fetchone()
        return None if row is None else row[0]
class DBConnectionManager(type):
    """Metaclass that caches one connection instance per thread.

    Calling the managed class returns the thread's cached connection; a new
    one is only created when none exists yet or the cached one is closed.
    """

    instances: Dict[int, DBConnection] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> DBConnection:
        thread_id = current_thread().native_id or -1
        cached = cls.instances.get(thread_id)
        if cached is None or cached.closed:
            cached = super().__call__(*args, **kwargs)
            cls.instances[thread_id] = cached
        return cached
class DBConnection(Connection, metaclass=DBConnectionManager):
    """Thread-cached SQLite connection to the file in `DBConnection.file`."""

    file = ''

    def __init__(self, timeout: float) -> None:
        """Create a connection with a database.

        Args:
            timeout (float): How long to wait before giving up on a command.
        """
        LOGGER.debug(f'Creating connection {self}')
        super().__init__(
            self.file,
            timeout=timeout,
            detect_types=PARSE_DECLTYPES
        )
        # Foreign key enforcement is off by default in SQLite.
        super().cursor().execute("PRAGMA foreign_keys = ON;")
        self.closed = False

    def _new_cursor(self) -> MindCursor:
        # Build a cursor that yields sqlite3.Row objects.
        result = MindCursor(self)
        result.row_factory = Row
        return result

    def cursor(  # type: ignore
        self,
        force_new: bool = False
    ) -> MindCursor:
        """Get a database cursor from the connection.

        Args:
            force_new (bool, optional): Get a new cursor instead of the cached
                one.
                Defaults to False.

        Returns:
            MindCursor: The database cursor.
        """
        # Cursors are cached on the flask request context `g`.
        if not hasattr(g, 'cursors'):
            g.cursors = []
        if not g.cursors:
            g.cursors.append(self._new_cursor())
        if not force_new:
            return g.cursors[0]
        g.cursors.append(self._new_cursor())
        return g.cursors[-1]

    def close(self) -> None:
        """Close the database connection"""
        LOGGER.debug(f'Closing connection {self}')
        self.closed = True
        super().close()

    def __repr__(self) -> str:
        return f'<{self.__class__.__name__}; {current_thread().name}; {id(self)}>'
def set_db_location(
    db_folder: Union[str, None]
) -> None:
    """Setup database location. Create folder for database and set location for
    `db.DBConnection`.

    Args:
        db_folder (Union[str, None], optional): The folder in which the database
            will be stored or in which a database is for MIND to use. Give
            `None` for the default location.

    Raises:
        ValueError: Value of `db_folder` exists but is not a folder.
    """
    if db_folder and exists(db_folder) and not isdir(db_folder):
        raise ValueError('Database location is not a folder')

    target_folder = db_folder or folder_path(*Constants.DB_FOLDER)
    db_file_location = join(target_folder, Constants.DB_NAME)

    LOGGER.debug(f'Setting database location: {db_file_location}')
    create_folder(dirname(db_file_location))

    # Move the legacy database file to the new location, if present.
    legacy_db = folder_path('db', 'Noted.db')
    if isfile(legacy_db):
        rename_file(legacy_db, db_file_location)

    DBConnection.file = db_file_location
def get_db(force_new: bool = False) -> MindCursor:
    """Get a database cursor instance or create a new one if needed.

    Args:
        force_new (bool, optional): Decides if a new cursor is
            returned instead of the standard one.
            Defaults to False.

    Returns:
        MindCursor: Database cursor instance that outputs Row objects.
    """
    connection = DBConnection(timeout=Constants.DB_TIMEOUT)
    return connection.cursor(force_new=force_new)
def commit() -> None:
    """Commit the database"""
    get_db().connection.commit()
def iter_commit(iterable: Iterable[T]) -> Generator[T, Any, Any]:
    """Commit the database after each iteration. Also commits just before the
    first iteration starts.

    Args:
        iterable (Iterable[T]): Iterable that will be iterated over like normal.

    Yields:
        Generator[T, Any, Any]: Items of iterable.
    """
    # Bound once; named so it doesn't shadow the module-level `commit`.
    do_commit = get_db().connection.commit
    do_commit()
    for item in iterable:
        yield item
        do_commit()
def close_db(e: Union[None, BaseException] = None) -> None:
    """Close database cursor, commit database and close database.

    Args:
        e (Union[None, BaseException], optional): Error. Defaults to None.
    """
    try:
        request_cursors = g.cursors
        db: DBConnection = request_cursors[0].connection
        for open_cursor in request_cursors:
            open_cursor.close()
        delattr(g, 'cursors')
        db.commit()
        # Waitress worker threads keep their connection open for reuse.
        if not current_thread().name.startswith('waitress-'):
            db.close()
    except (AttributeError, ProgrammingError):
        # No cursors on `g`, or the connection was already closed.
        pass
def close_all_db() -> None:
    "Close all non-temporary database connections that are still open"
    LOGGER.debug('Closing any open database connections')
    for connection in DBConnectionManager.instances.values():
        if not connection.closed:
            connection.close()

    # Open one final connection to flush pending changes, then close it.
    final_connection = DBConnection(timeout=20.0)
    final_connection.commit()
    final_connection.close()
def setup_db() -> None:
    """
    Setup the database tables and default config when they aren't setup yet
    """
    # Imported here to avoid circular imports.
    from backend.implementations.users import Users
    from backend.internals.settings import Settings

    cursor = get_db()
    # WAL mode allows concurrent readers while a write is in progress.
    cursor.execute("PRAGMA journal_mode = wal;")

    # Store Python bools as 1/0 and read "BOOL" columns back as bools.
    register_adapter(bool, lambda b: int(b))
    register_converter("BOOL", lambda b: b == b'1')

    cursor.executescript("""
        CREATE TABLE IF NOT EXISTS users(
            id INTEGER PRIMARY KEY,
            username VARCHAR(255) UNIQUE NOT NULL,
            salt VARCHAR(40) NOT NULL,
            hash VARCHAR(100) NOT NULL,
            admin BOOL NOT NULL DEFAULT 0
        );
        CREATE TABLE IF NOT EXISTS notification_services(
            id INTEGER PRIMARY KEY,
            user_id INTEGER NOT NULL,
            title VARCHAR(255),
            url TEXT,
            FOREIGN KEY (user_id) REFERENCES users(id)
        );
        CREATE TABLE IF NOT EXISTS reminders(
            id INTEGER PRIMARY KEY,
            user_id INTEGER NOT NULL,
            title VARCHAR(255) NOT NULL,
            text TEXT,
            time INTEGER NOT NULL,
            repeat_quantity VARCHAR(15),
            repeat_interval INTEGER,
            original_time INTEGER,
            weekdays VARCHAR(13),
            color VARCHAR(7),
            FOREIGN KEY (user_id) REFERENCES users(id)
        );
        CREATE TABLE IF NOT EXISTS templates(
            id INTEGER PRIMARY KEY,
            user_id INTEGER NOT NULL,
            title VARCHAR(255) NOT NULL,
            text TEXT,
            color VARCHAR(7),
            FOREIGN KEY (user_id) REFERENCES users(id)
        );
        CREATE TABLE IF NOT EXISTS static_reminders(
            id INTEGER PRIMARY KEY,
            user_id INTEGER NOT NULL,
            title VARCHAR(255) NOT NULL,
            text TEXT,
            color VARCHAR(7),
            FOREIGN KEY (user_id) REFERENCES users(id)
        );
        CREATE TABLE IF NOT EXISTS reminder_services(
            reminder_id INTEGER,
            static_reminder_id INTEGER,
            template_id INTEGER,
            notification_service_id INTEGER NOT NULL,
            FOREIGN KEY (reminder_id) REFERENCES reminders(id)
                ON DELETE CASCADE,
            FOREIGN KEY (static_reminder_id) REFERENCES static_reminders(id)
                ON DELETE CASCADE,
            FOREIGN KEY (template_id) REFERENCES templates(id)
                ON DELETE CASCADE,
            FOREIGN KEY (notification_service_id) REFERENCES notification_services(id)
        );
        CREATE TABLE IF NOT EXISTS config(
            key VARCHAR(255) PRIMARY KEY,
            value BLOB NOT NULL
        );
        """)

    settings = Settings()
    settings_values = settings.get_settings()
    set_log_level(settings_values.log_level)

    migrate_db()
    # DB Migration might change settings, so update cache just to be sure.
    settings._fetch_settings()

    # Add admin user if it doesn't exist
    users = Users()
    if Constants.ADMIN_USERNAME not in users:
        users.add(
            Constants.ADMIN_USERNAME, Constants.ADMIN_PASSWORD,
            force=True,
            is_admin=True
        )
    return
def revert_db_import(
    swap: bool,
    imported_db_file: str = ''
) -> None:
    """Revert the database import process. The original_db_file is the file
    currently used (`DBConnection.file`).

    Args:
        swap (bool): Whether or not to keep the imported_db_file or not,
            instead of the original_db_file.
        imported_db_file (str, optional): The other database file. Keep empty
            to use `Constants.DB_ORIGINAL_FILENAME`.
            Defaults to ''.
    """
    original_db_file = DBConnection.file
    imported_db_file = imported_db_file or join(
        dirname(DBConnection.file),
        Constants.DB_ORIGINAL_NAME
    )

    if not swap:
        # Keep the current database; just discard the imported file.
        remove(imported_db_file)
        return

    # Replace the current database with the imported file.
    remove(original_db_file)
    move(imported_db_file, original_db_file)
def import_db(
    new_db_file: str,
    copy_hosting_settings: bool
) -> None:
    """Replace the current database with a new one.

    Args:
        new_db_file (str): The path to the new database file.
        copy_hosting_settings (bool): Keep the hosting settings from the current
            database.

    Raises:
        InvalidDatabaseFile: The new database file is invalid or unsupported.
    """
    LOGGER.info(f'Importing new database; {copy_hosting_settings=}')
    cursor = Connection(new_db_file, timeout=20.0).cursor()

    try:
        row = cursor.execute(
            "SELECT value FROM config WHERE key = 'database_version' LIMIT 1;"
        ).fetchone()
        # A missing row previously raised an uncaught TypeError on
        # subscripting None; treat it as an invalid file instead.
        if row is None or not isinstance(row[0], int):
            raise InvalidDatabaseFile(new_db_file)
        database_version: int = row[0]

    except (OperationalError, InvalidDatabaseFile):
        LOGGER.error('Uploaded database is not a MIND database file')
        cursor.connection.close()
        revert_db_import(
            swap=False,
            imported_db_file=new_db_file
        )
        raise InvalidDatabaseFile(new_db_file)

    if database_version > get_latest_db_version():
        LOGGER.error(
            'Uploaded database is higher version than this MIND installation can support')
        # Close the connection before revert_db_import deletes the file;
        # removing an open file fails on Windows and leaks the handle anywhere.
        cursor.connection.close()
        revert_db_import(
            swap=False,
            imported_db_file=new_db_file
        )
        raise InvalidDatabaseFile(new_db_file)

    if copy_hosting_settings:
        # Carry the current hosting config over into the uploaded database.
        hosting_settings = get_db().execute("""
            SELECT key, value
            FROM config
            WHERE key = 'host'
                OR key = 'port'
                OR key = 'url_prefix'
            LIMIT 3;
            """
        ).fetchalldict()
        cursor.executemany("""
            INSERT INTO config(key, value)
            VALUES (:key, :value)
            ON CONFLICT(key) DO
            UPDATE
            SET value = :value;
            """,
            hosting_settings
        )
        cursor.connection.commit()

    cursor.connection.close()

    # Keep the current database around so the import can still be reverted.
    move(
        DBConnection.file,
        join(dirname(DBConnection.file), Constants.DB_ORIGINAL_NAME)
    )
    move(
        new_db_file,
        DBConnection.file
    )

    from backend.internals.server import Server
    Server().restart(StartType.RESTART_DB_CHANGES)
    return

View File

@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
from typing import Dict, Type
from backend.base.definitions import Constants, DBMigrator
from backend.base.logging import LOGGER
class VersionMappingContainer:
    # Maps a database version to the migrator that upgrades from that version
    # to the next one. Lazily filled by `_load_version_map`.
    version_map: Dict[int, Type[DBMigrator]] = {}
def _load_version_map() -> None:
    """Fill the version map from `DBMigrator` subclasses. Idempotent."""
    if not VersionMappingContainer.version_map:
        VersionMappingContainer.version_map = {
            migrator.start_version: migrator
            for migrator in DBMigrator.__subclasses__()
        }
def get_latest_db_version() -> int:
    """Return the newest database version this installation supports."""
    _load_version_map()
    # One past the highest migrator start version (= resulting version).
    highest_start = max(VersionMappingContainer.version_map)
    return highest_start + 1
def migrate_db() -> None:
    """
    Migrate a MIND database from its current version
    to the newest version supported by the MIND version installed.
    """
    from backend.internals.db import iter_commit
    from backend.internals.settings import Settings

    settings = Settings()
    current_db_version = settings.get_settings().database_version
    newest_version = get_latest_db_version()
    if current_db_version == newest_version:
        return

    LOGGER.info('Migrating database to newer version...')
    LOGGER.debug(
        "Database migration: %d -> %d",
        current_db_version, newest_version
    )
    # Each completed step is committed and recorded in the settings.
    for start_version in iter_commit(range(current_db_version, newest_version)):
        migrator = VersionMappingContainer.version_map.get(start_version)
        if migrator is None:
            continue
        migrator().run()
        settings.update({'database_version': start_version + 1})
        settings._fetch_settings()
class MigrateToUTC(DBMigrator):
    """V1 -> V2: shift stored reminder times from local time to UTC."""
    start_version = 1

    def run(self) -> None:
        # V1 -> V2
        from datetime import datetime, timezone
        from time import time

        from backend.internals.db import get_db

        cursor = get_db()

        # Determine the local UTC offset. `datetime.utcfromtimestamp` is
        # deprecated since Python 3.12; derive the naive UTC datetime from an
        # aware one instead (identical result).
        t = time()
        utc_offset = (
            datetime.fromtimestamp(t)
            - datetime.fromtimestamp(t, timezone.utc).replace(tzinfo=None)
        )

        cursor.execute("SELECT time, id FROM reminders;")
        new_reminders = [
            [
                round((
                    datetime.fromtimestamp(r["time"]) - utc_offset
                ).timestamp()),
                r["id"]
            ]
            for r in cursor
        ]
        cursor.executemany(
            "UPDATE reminders SET time = ? WHERE id = ?;",
            new_reminders
        )
        return
class MigrateAddColor(DBMigrator):
    """V2 -> V3: add a `color` column to reminders and templates."""
    start_version = 2

    def run(self) -> None:
        # V2 -> V3
        from backend.internals.db import get_db

        get_db().executescript("""
            ALTER TABLE reminders
            ADD color VARCHAR(7);
            ALTER TABLE templates
            ADD color VARCHAR(7);
            """)
        return
class MigrateFixRQ(DBMigrator):
    """V3 -> V4: normalise repeat_quantity values to their plural form."""
    start_version = 3

    def run(self) -> None:
        # V3 -> V4
        from backend.internals.db import get_db

        # E.g. 'day' -> 'days'; values already ending in 's' are untouched.
        get_db().executescript("""
            UPDATE reminders
            SET repeat_quantity = repeat_quantity || 's'
            WHERE repeat_quantity NOT LIKE '%s';
            """)
        return
class MigrateToReminderServices(DBMigrator):
    """V4 -> V5: move notification service links out of the reminders and
    templates tables into the `reminder_services` link table."""
    start_version = 4

    def run(self) -> None:
        # V4 -> V5
        from backend.internals.db import get_db

        # Rebuild reminders and templates without their old
        # `notification_service` column, carrying the links over into
        # reminder_services. defer_foreign_keys keeps FK checks from firing
        # while tables are dropped and recreated mid-transaction.
        get_db().executescript("""
            BEGIN TRANSACTION;
            PRAGMA defer_foreign_keys = ON;

            CREATE TEMPORARY TABLE temp_reminder_services(
                reminder_id,
                static_reminder_id,
                template_id,
                notification_service_id
            );

            -- Reminders
            INSERT INTO temp_reminder_services(reminder_id, notification_service_id)
            SELECT id, notification_service
            FROM reminders;
            CREATE TEMPORARY TABLE temp_reminders AS
            SELECT id, user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color
            FROM reminders;
            DROP TABLE reminders;
            CREATE TABLE reminders(
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                title VARCHAR(255) NOT NULL,
                text TEXT,
                time INTEGER NOT NULL,
                repeat_quantity VARCHAR(15),
                repeat_interval INTEGER,
                original_time INTEGER,
                color VARCHAR(7),
                FOREIGN KEY (user_id) REFERENCES users(id)
            );
            INSERT INTO reminders
            SELECT * FROM temp_reminders;

            -- Templates
            INSERT INTO temp_reminder_services(template_id, notification_service_id)
            SELECT id, notification_service
            FROM templates;
            CREATE TEMPORARY TABLE temp_templates AS
            SELECT id, user_id, title, text, color
            FROM templates;
            DROP TABLE templates;
            CREATE TABLE templates(
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                title VARCHAR(255) NOT NULL,
                text TEXT,
                color VARCHAR(7),
                FOREIGN KEY (user_id) REFERENCES users(id)
            );
            INSERT INTO templates
            SELECT * FROM temp_templates;

            INSERT INTO reminder_services
            SELECT * FROM temp_reminder_services;

            COMMIT;
            """)
        return
class MigrateRemoveUser1(DBMigrator):
    """V5 -> V6: delete the old default example account, if it still exists."""
    start_version = 5

    def run(self) -> None:
        # V5 -> V6
        from backend.base.custom_exceptions import (AccessUnauthorized,
                                                    UserNotFound)
        from backend.implementations.users import Users

        try:
            Users().login('User1', 'Password1').delete()
        except (UserNotFound, AccessUnauthorized):
            # Account already removed, or its password was changed; in the
            # latter case the user is deliberately left alone.
            pass
        return
class MigrateAddWeekdays(DBMigrator):
    """V6 -> V7: add a `weekdays` column to reminders."""
    start_version = 6

    def run(self) -> None:
        # V6 -> V7
        from backend.internals.db import get_db

        get_db().executescript("""
            ALTER TABLE reminders
            ADD weekdays VARCHAR(13);
            """)
        return
class MigrateAddAdmin(DBMigrator):
    """V7 -> V8: rebuild the config table, add the `admin` user flag and
    create the built-in admin account."""
    start_version = 7

    def run(self) -> None:
        # V7 -> V8
        from backend.implementations.users import Users
        from backend.internals.db import get_db
        from backend.internals.settings import Settings

        cursor = get_db()
        # Old config layout is discarded entirely; defaults are re-inserted.
        cursor.executescript("""
            DROP TABLE config;
            CREATE TABLE IF NOT EXISTS config(
                key VARCHAR(255) PRIMARY KEY,
                value BLOB NOT NULL
            );
            """
        )
        Settings()._insert_missing_settings()

        cursor.executescript("""
            ALTER TABLE users
            ADD admin BOOL NOT NULL DEFAULT 0;
            """
        )

        # The name 'admin' is now reserved for the built-in account; rename
        # any pre-existing user that holds it.
        users = Users()
        if 'admin' in users:
            users.get_one(
                users.user_db.username_to_id('admin')
            ).update(
                new_username='admin_old',
                new_password=None
            )
        users.add(
            Constants.ADMIN_USERNAME, Constants.ADMIN_PASSWORD,
            force=True,
            is_admin=True
        )
        return
class MigrateHostSettingsToDB(DBMigrator):
    """V8 -> V9: originally moved hosting settings into the database; now a
    no-op because the settings layer already applies the defaults."""
    start_version = 8

    def run(self) -> None:
        # V8 -> V9
        # In newer versions, the variables don't exist anymore, and behaviour
        # was to then set the values to the default values. But that's already
        # taken care of by the settings, so nothing to do here anymore.
        return
class MigrateUpdateManifest(DBMigrator):
    """V9 -> V10: regenerate the web manifest once; no database changes."""
    start_version = 9

    def run(self) -> None:
        # V9 -> V10
        # Nothing is changed in the database
        # It's just that this code needs to run once
        # and the DB migration system does exactly that:
        # run pieces of code once.
        from backend.internals.settings import Settings, update_manifest

        update_manifest(
            Settings().get_settings().url_prefix
        )
        return

View File

@@ -0,0 +1,815 @@
# -*- coding: utf-8 -*-
from typing import List, Union
from backend.base.definitions import (NotificationServiceData, ReminderData,
ReminderType, StaticReminderData,
TemplateData, UserData)
from backend.base.helpers import first_of_column
from backend.internals.db import REMINDER_TO_KEY, get_db
class NotificationServicesDB:
    """Access to the `notification_services` table, scoped to one user."""

    def __init__(self, user_id: int) -> None:
        # All queries are restricted to this user's services.
        self.user_id = user_id
        return

    def exists(self, notification_service_id: int) -> bool:
        """Whether the service exists and belongs to this user."""
        return get_db().execute("""
            SELECT 1
            FROM notification_services
            WHERE id = :id
                AND user_id = :user_id
            LIMIT 1;
            """,
            {
                'user_id': self.user_id,
                'id': notification_service_id
            }
        ).fetchone() is not None

    def fetch(
        self,
        notification_service_id: Union[int, None] = None
    ) -> List[NotificationServiceData]:
        """Fetch all services of the user, or only the given one.

        Args:
            notification_service_id (Union[int, None], optional): Filter on
                this service ID.
                Defaults to None.

        Returns:
            List[NotificationServiceData]: The matching services.
        """
        id_filter = ""
        if notification_service_id:
            id_filter = "AND id = :ns_id"

        result = get_db().execute(f"""
            SELECT
                id, title, url
            FROM notification_services
            WHERE user_id = :user_id
                {id_filter}
            ORDER BY title, id;
            """,
            {
                "user_id": self.user_id,
                "ns_id": notification_service_id
            }
        ).fetchalldict()

        return [
            NotificationServiceData(**entry)
            for entry in result
        ]

    def add(
        self,
        title: str,
        url: str
    ) -> int:
        """Add a notification service for the user.

        Returns:
            int: The ID of the new service.
        """
        new_id = get_db().execute("""
            INSERT INTO notification_services(user_id, title, url)
            VALUES (?, ?, ?)
            """,
            (self.user_id, title, url)
        ).lastrowid
        return new_id

    def update(
        self,
        notification_service_id: int,
        title: str,
        url: str
    ) -> None:
        """Overwrite the title and URL of a service."""
        get_db().execute("""
            UPDATE notification_services
            SET title = :title, url = :url
            WHERE id = :ns_id;
            """,
            {
                "title": title,
                "url": url,
                "ns_id": notification_service_id
            }
        )
        return

    def delete(
        self,
        notification_service_id: int
    ) -> None:
        """Delete a service by ID."""
        get_db().execute(
            "DELETE FROM notification_services WHERE id = ?;",
            (notification_service_id,)
        )
        return
class ReminderServicesDB:
    """Access to the `reminder_services` link table for one reminder type."""

    def __init__(self, reminder_type: ReminderType) -> None:
        # The foreign-key column in `reminder_services` to filter on.
        self.key = REMINDER_TO_KEY[reminder_type]
        return

    def reminder_to_ns(
        self,
        reminder_id: int
    ) -> List[int]:
        """Get the ID's of the notification services that are linked to the given
        reminder, static reminder or template.

        Args:
            reminder_id (int): The ID of the reminder, static reminder or template.

        Returns:
            List[int]: A list of the notification service ID's that are linked to
                the given reminder, static reminder or template.
        """
        result = first_of_column(get_db().execute(
            f"""
            SELECT notification_service_id
            FROM reminder_services
            WHERE {self.key} = ?;
            """,
            (reminder_id,)
        ))
        return result

    def update_ns_bindings(
        self,
        reminder_id: int,
        notification_services: List[int]
    ) -> None:
        """Update the bindings of a reminder, static reminder or template to
        notification services.

        Args:
            reminder_id (int): The ID of the reminder, static reminder or template.
            notification_services (List[int]): The new list of notification services
                that should be linked to the reminder, static reminder or template.
        """
        cursor = get_db()
        # Manual transaction so delete + re-insert happen atomically;
        # isolation_level = None disables sqlite3's implicit transactions.
        cursor.connection.isolation_level = None
        cursor.execute("BEGIN TRANSACTION;")
        cursor.execute(
            f"""
            DELETE FROM reminder_services
            WHERE {self.key} = ?;
            """,
            (reminder_id,)
        )
        cursor.executemany(
            f"""
            INSERT INTO reminder_services(
                {self.key},
                notification_service_id
            )
            VALUES (?, ?);
            """,
            ((reminder_id, ns_id) for ns_id in notification_services)
        )
        cursor.execute("COMMIT;")
        cursor.connection.isolation_level = ""
        return

    def uses_ns(
        self,
        notification_service_id: int
    ) -> List[int]:
        """Get the ID's of the reminders (of given type) that use the given
        notification service.

        Args:
            notification_service_id (int): The ID of the notification service to
                check for.

        Returns:
            List[int]: The ID's of the reminders (only of the given type) that
                use the notification service.
        """
        # NOTE(review): LIMIT 1 returns at most one ID even though the
        # docstring suggests all of them — confirm callers only use this as
        # an existence check.
        return first_of_column(get_db().execute(
            f"""
            SELECT {self.key}
            FROM reminder_services
            WHERE notification_service_id = ?
                AND {self.key} IS NOT NULL
            LIMIT 1;
            """,
            (notification_service_id,)
        ))
class UsersDB:
    """Read and write access to the `users` table."""

    def exists(self, user_id: int) -> bool:
        """Whether a user with the given ID exists."""
        return get_db().execute("""
            SELECT 1
            FROM users
            WHERE id = ?
            LIMIT 1;
            """,
            (user_id,)
        ).fetchone() is not None

    def taken(self, username: str) -> bool:
        """Whether the given username is already in use."""
        return get_db().execute("""
            SELECT 1
            FROM users
            WHERE username = ?
            LIMIT 1;
            """,
            (username,)
        ).fetchone() is not None

    def username_to_id(self, username: str) -> int:
        """Map a username to its user ID. Assumes the username exists."""
        return get_db().execute("""
            SELECT id
            FROM users
            WHERE username = ?
            LIMIT 1;
            """,
            (username,)
        ).fetchone()[0]

    def fetch(
        self,
        user_id: Union[int, None] = None
    ) -> List[UserData]:
        """Fetch all users, or only the one with the given ID.

        Args:
            user_id (Union[int, None], optional): Filter on this user ID.
                Defaults to None.

        Returns:
            List[UserData]: The matching users.
        """
        id_filter = ""
        if user_id:
            id_filter = "WHERE id = :id"

        result = get_db().execute(f"""
            SELECT
                id, username, admin, salt, hash
            FROM users
            {id_filter}
            ORDER BY username, id;
            """,
            {
                "id": user_id
            }
        ).fetchalldict()

        return [
            UserData(**entry)
            for entry in result
        ]

    def add(
        self,
        username: str,
        salt: bytes,
        hash: bytes,
        admin: bool
    ) -> int:
        """Insert a user row.

        Returns:
            int: The ID of the new user.
        """
        user_id = get_db().execute(
            """
            INSERT INTO users(username, salt, hash, admin)
            VALUES (?, ?, ?, ?);
            """,
            (username, salt, hash, admin)
        ).lastrowid
        return user_id

    def update(
        self,
        user_id: int,
        username: str,
        hash: bytes
    ) -> None:
        """Overwrite the username and password hash of a user."""
        get_db().execute("""
            UPDATE users
            SET username = :username, hash = :hash
            WHERE id = :user_id;
            """,
            {
                "username": username,
                "hash": hash,
                "user_id": user_id
            }
        )
        return

    def delete(
        self,
        user_id: int
    ) -> None:
        """Delete a user and everything belonging to them.

        Args:
            user_id (int): The ID of the user to delete.
        """
        # Previously `user_id` was interpolated into an `executescript` call
        # with an f-string (SQL injection risk, since `executescript` cannot
        # take parameters). Use parameterized statements in an explicit
        # transaction instead, matching `ReminderServicesDB.update_ns_bindings`.
        cursor = get_db()
        cursor.connection.isolation_level = None
        cursor.execute("BEGIN TRANSACTION;")
        cursor.execute(
            "DELETE FROM reminders WHERE user_id = ?;", (user_id,))
        cursor.execute(
            "DELETE FROM templates WHERE user_id = ?;", (user_id,))
        cursor.execute(
            "DELETE FROM static_reminders WHERE user_id = ?;", (user_id,))
        cursor.execute(
            "DELETE FROM notification_services WHERE user_id = ?;", (user_id,))
        cursor.execute(
            "DELETE FROM users WHERE id = ?;", (user_id,))
        cursor.execute("COMMIT;")
        cursor.connection.isolation_level = ""
        return
class TemplatesDB:
    """Access to the `templates` table, scoped to one user."""

    def __init__(self, user_id: int) -> None:
        self.user_id = user_id
        # For the links between templates and notification services.
        self.rms_db = ReminderServicesDB(ReminderType.TEMPLATE)
        return

    def exists(self, template_id: int) -> bool:
        """Whether the template exists and belongs to this user."""
        return get_db().execute(
            "SELECT 1 FROM templates WHERE id = ? AND user_id = ? LIMIT 1;",
            (template_id, self.user_id)
        ).fetchone() is not None

    def fetch(
        self,
        template_id: Union[int, None] = None
    ) -> List[TemplateData]:
        """Fetch all templates of the user, or only the given one.

        Args:
            template_id (Union[int, None], optional): Filter on this ID.
                Defaults to None.

        Returns:
            List[TemplateData]: The matching templates, including their
                linked notification services.
        """
        id_filter = ""
        if template_id:
            id_filter = "AND id = :t_id"

        result = get_db().execute(f"""
            SELECT
                id, title, text, color
            FROM templates
            WHERE user_id = :user_id
                {id_filter}
            ORDER BY title, id;
            """,
            {
                "user_id": self.user_id,
                "t_id": template_id
            }
        ).fetchalldict()

        for r in result:
            r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])

        return [
            TemplateData(**entry)
            for entry in result
        ]

    def add(
        self,
        title: str,
        text: Union[str, None],
        color: Union[str, None],
        notification_services: List[int]
    ) -> int:
        """Add a template and link it to notification services.

        Returns:
            int: The ID of the new template.
        """
        new_id = get_db().execute("""
            INSERT INTO templates(user_id, title, text, color)
            VALUES (?, ?, ?, ?);
            """,
            (self.user_id, title, text, color)
        ).lastrowid
        self.rms_db.update_ns_bindings(
            new_id, notification_services
        )
        return new_id

    def update(
        self,
        template_id: int,
        title: str,
        text: Union[str, None],
        color: Union[str, None],
        notification_services: List[int]
    ) -> None:
        """Overwrite a template and its notification service links."""
        get_db().execute("""
            UPDATE templates
            SET
                title = :title,
                text = :text,
                color = :color
            WHERE id = :t_id;
            """,
            {
                "title": title,
                "text": text,
                "color": color,
                "t_id": template_id
            }
        )
        self.rms_db.update_ns_bindings(
            template_id,
            notification_services
        )
        return

    def delete(
        self,
        template_id: int
    ) -> None:
        """Delete a template by ID."""
        get_db().execute(
            "DELETE FROM templates WHERE id = ?;",
            (template_id,)
        )
        return
class StaticRemindersDB:
    """Access to the `static_reminders` table, scoped to one user."""

    def __init__(self, user_id: int) -> None:
        self.user_id = user_id
        # For the links between static reminders and notification services.
        self.rms_db = ReminderServicesDB(ReminderType.STATIC_REMINDER)
        return

    def exists(self, reminder_id: int) -> bool:
        """Whether the static reminder exists and belongs to this user."""
        return get_db().execute("""
            SELECT 1
            FROM static_reminders
            WHERE id = ?
                AND user_id = ?
            LIMIT 1;
            """,
            (reminder_id, self.user_id)
        ).fetchone() is not None

    def fetch(
        self,
        reminder_id: Union[int, None] = None
    ) -> List[StaticReminderData]:
        """Fetch all static reminders of the user, or only the given one.

        Args:
            reminder_id (Union[int, None], optional): Filter on this ID.
                Defaults to None.

        Returns:
            List[StaticReminderData]: The matching static reminders, including
                their linked notification services.
        """
        id_filter = ""
        if reminder_id:
            id_filter = "AND id = :r_id"

        result = get_db().execute(f"""
            SELECT
                id, title, text, color
            FROM static_reminders
            WHERE user_id = :user_id
                {id_filter}
            ORDER BY title, id;
            """,
            {
                "user_id": self.user_id,
                "r_id": reminder_id
            }
        ).fetchalldict()

        for r in result:
            r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])

        return [
            StaticReminderData(**entry)
            for entry in result
        ]

    def add(
        self,
        title: str,
        text: Union[str, None],
        color: Union[str, None],
        notification_services: List[int]
    ) -> int:
        """Add a static reminder and link it to notification services.

        Returns:
            int: The ID of the new static reminder.
        """
        new_id = get_db().execute("""
            INSERT INTO static_reminders(user_id, title, text, color)
            VALUES (?, ?, ?, ?);
            """,
            (self.user_id, title, text, color)
        ).lastrowid
        self.rms_db.update_ns_bindings(
            new_id, notification_services
        )
        return new_id

    def update(
        self,
        reminder_id: int,
        title: str,
        text: Union[str, None],
        color: Union[str, None],
        notification_services: List[int]
    ) -> None:
        """Overwrite a static reminder and its notification service links."""
        get_db().execute("""
            UPDATE static_reminders
            SET
                title = :title,
                text = :text,
                color = :color
            WHERE id = :r_id;
            """,
            {
                "title": title,
                "text": text,
                "color": color,
                "r_id": reminder_id
            }
        )
        self.rms_db.update_ns_bindings(
            reminder_id,
            notification_services
        )
        return

    def delete(
        self,
        reminder_id: int
    ) -> None:
        """Delete a static reminder by ID."""
        get_db().execute(
            "DELETE FROM static_reminders WHERE id = ?;",
            (reminder_id,)
        )
        return
class RemindersDB:
    """Access to the `reminders` table, scoped to one user."""

    def __init__(self, user_id: int) -> None:
        self.user_id = user_id
        # For the links between reminders and notification services.
        self.rms_db = ReminderServicesDB(ReminderType.REMINDER)
        return

    def exists(self, reminder_id: int) -> bool:
        """Whether the reminder exists and belongs to this user."""
        return get_db().execute("""
            SELECT 1
            FROM reminders
            WHERE id = ?
                AND user_id = ?
            LIMIT 1;
            """,
            (reminder_id, self.user_id)
        ).fetchone() is not None

    def fetch(
        self,
        reminder_id: Union[int, None] = None
    ) -> List[ReminderData]:
        """Fetch all reminders of the user, or only the given one.

        Args:
            reminder_id (Union[int, None], optional): Filter on this ID.
                Defaults to None.

        Returns:
            List[ReminderData]: The matching reminders, including their
                linked notification services.
        """
        id_filter = ""
        if reminder_id:
            id_filter = "AND id = :r_id"

        # `weekdays` is aliased to `_weekdays` to match the keyword argument
        # that ReminderData expects.
        result = get_db().execute(f"""
            SELECT
                id, title, text, color,
                time, original_time,
                repeat_quantity, repeat_interval,
                weekdays AS _weekdays
            FROM reminders
            WHERE user_id = :user_id
                {id_filter};
            """,
            {
                "user_id": self.user_id,
                "r_id": reminder_id
            }
        ).fetchalldict()

        for r in result:
            r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])

        return [
            ReminderData(**entry)
            for entry in result
        ]

    def add(
        self,
        title: str,
        text: Union[str, None],
        time: int,
        repeat_quantity: Union[str, None],
        repeat_interval: Union[int, None],
        weekdays: Union[str, None],
        original_time: Union[int, None],
        color: Union[str, None],
        notification_services: List[int]
    ) -> int:
        """Add a reminder and link it to notification services.

        Returns:
            int: The ID of the new reminder.
        """
        new_id = get_db().execute("""
            INSERT INTO reminders(
                user_id,
                title, text,
                time,
                repeat_quantity, repeat_interval,
                weekdays,
                original_time,
                color
            )
            VALUES (
                :user_id,
                :title, :text,
                :time,
                :rq, :ri,
                :wd,
                :ot,
                :color
            );
            """,
            {
                "user_id": self.user_id,
                "title": title,
                "text": text,
                "time": time,
                "rq": repeat_quantity,
                "ri": repeat_interval,
                "wd": weekdays,
                "ot": original_time,
                "color": color
            }
        ).lastrowid
        self.rms_db.update_ns_bindings(
            new_id, notification_services
        )
        return new_id

    def update(
        self,
        reminder_id: int,
        title: str,
        text: Union[str, None],
        time: int,
        repeat_quantity: Union[str, None],
        repeat_interval: Union[int, None],
        weekdays: Union[str, None],
        original_time: Union[int, None],
        color: Union[str, None],
        notification_services: List[int]
    ) -> None:
        """Overwrite a reminder and its notification service links."""
        get_db().execute("""
            UPDATE reminders
            SET
                title = :title,
                text = :text,
                time = :time,
                repeat_quantity = :rq,
                repeat_interval = :ri,
                weekdays = :wd,
                original_time = :ot,
                color = :color
            WHERE id = :r_id;
            """,
            {
                "title": title,
                "text": text,
                "time": time,
                "rq": repeat_quantity,
                "ri": repeat_interval,
                "wd": weekdays,
                "ot": original_time,
                "color": color,
                "r_id": reminder_id
            }
        )
        self.rms_db.update_ns_bindings(
            reminder_id,
            notification_services
        )
        return

    def delete(
        self,
        reminder_id: int
    ) -> None:
        """Delete a reminder by ID."""
        get_db().execute(
            "DELETE FROM reminders WHERE id = ?;",
            (reminder_id,)
        )
        return
class UserlessRemindersDB:
    """Access to the `reminders` table across all users (e.g. for the
    background handler that fires reminders)."""

    def __init__(self) -> None:
        # For the links between reminders and notification services.
        self.rms_db = ReminderServicesDB(ReminderType.REMINDER)
        return

    def exists(self, reminder_id: int) -> bool:
        """Whether a reminder with the given ID exists, regardless of owner."""
        return get_db().execute("""
            SELECT 1
            FROM reminders
            WHERE id = ?
            LIMIT 1;
            """,
            (reminder_id,)
        ).fetchone() is not None

    def reminder_id_to_user_id(self, reminder_id: int) -> int:
        """Map a reminder ID to its owner's user ID, or -1 if not found."""
        return get_db().execute(
            """
            SELECT user_id
            FROM reminders
            WHERE id = ?
            LIMIT 1;
            """,
            (reminder_id,)
        ).exists() or -1

    def get_soonest_time(self) -> Union[int, None]:
        """The earliest reminder time, or None when there are no reminders."""
        return get_db().execute("SELECT MIN(time) FROM reminders;").exists()

    def fetch(
        self,
        time: Union[int, None] = None
    ) -> List[ReminderData]:
        """Fetch all reminders, or only those triggering at the given time.

        Args:
            time (Union[int, None], optional): Filter on this trigger time.
                Defaults to None.

        Returns:
            List[ReminderData]: The matching reminders, including their
                linked notification services.
        """
        time_filter = ""
        if time:
            time_filter = "WHERE time = :time"

        # `weekdays` is aliased to `_weekdays` to match the keyword argument
        # that ReminderData expects.
        result = get_db().execute(f"""
            SELECT
                id,
                title, text, color,
                time, original_time,
                repeat_quantity, repeat_interval,
                weekdays AS _weekdays
            FROM reminders
            {time_filter};
            """,
            {
                "time": time
            }
        ).fetchalldict()

        for r in result:
            r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])

        return [
            ReminderData(**entry)
            for entry in result
        ]

    def add(
        self,
        user_id: int,
        title: str,
        text: Union[str, None],
        time: int,
        repeat_quantity: Union[str, None],
        repeat_interval: Union[int, None],
        weekdays: Union[str, None],
        original_time: Union[int, None],
        color: Union[str, None],
        notification_services: List[int]
    ) -> int:
        """Add a reminder for the given user and link it to notification
        services.

        Returns:
            int: The ID of the new reminder.
        """
        new_id = get_db().execute("""
            INSERT INTO reminders(
                user_id,
                title, text,
                time,
                repeat_quantity, repeat_interval,
                weekdays,
                original_time,
                color
            )
            VALUES (
                :user_id,
                :title, :text,
                :time,
                :rq, :ri,
                :wd,
                :ot,
                :color
            );
            """,
            {
                "user_id": user_id,
                "title": title,
                "text": text,
                "time": time,
                "rq": repeat_quantity,
                "ri": repeat_interval,
                "wd": weekdays,
                "ot": original_time,
                "color": color
            }
        ).lastrowid
        self.rms_db.update_ns_bindings(
            new_id, notification_services
        )
        return new_id

    def update(
        self,
        reminder_id: int,
        time: int
    ) -> None:
        """Set a new trigger time for a reminder (e.g. the next repetition)."""
        get_db().execute("""
            UPDATE reminders
            SET time = :time
            WHERE id = :r_id;
            """,
            {
                "time": time,
                "r_id": reminder_id
            }
        )
        return

    def delete(
        self,
        reminder_id: int
    ) -> None:
        """Delete a reminder by ID."""
        get_db().execute(
            "DELETE FROM reminders WHERE id = ?;",
            (reminder_id,)
        )
        return

247
backend/internals/server.py Normal file
View File

@@ -0,0 +1,247 @@
# -*- coding: utf-8 -*-
"""
Setting up, running and shutting down the API and web-ui
"""
from __future__ import annotations
from os import urandom
from threading import Timer, current_thread
from typing import TYPE_CHECKING, Union
from flask import Flask, render_template, request
from waitress.server import create_server
from waitress.task import ThreadedTaskDispatcher as TTD
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from backend.base.definitions import Constants, StartType
from backend.base.helpers import Singleton, folder_path
from backend.base.logging import LOGGER
from backend.internals.db import (DBConnectionManager,
close_db, revert_db_import)
from backend.internals.settings import Settings
if TYPE_CHECKING:
from waitress.server import BaseWSGIServer, MultiSocketServer
class ThreadedTaskDispatcher(TTD):
    """Waitress task dispatcher that cleans up per-thread DB connections."""

    def handler_thread(self, thread_no: int) -> None:
        # Run the normal waitress handler loop; when it exits, close this
        # thread's cached database connection, if one is still open.
        super().handler_thread(thread_no)
        thread_id = current_thread().native_id or -1
        if (
            thread_id in DBConnectionManager.instances
            and not DBConnectionManager.instances[thread_id].closed
        ):
            DBConnectionManager.instances[thread_id].close()
        return

    def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
        # Blank line keeps the console log clean after e.g. a ^C.
        print()
        LOGGER.info('Shutting down MIND')
        result = super().shutdown(cancel_pending, timeout)
        return result
def handle_start_type(start_type: StartType) -> None:
    """Do special actions needed based on restart version.

    Args:
        start_type (StartType): The restart version.
    """
    # Map each special start type to its log message and revert timer
    actions = {
        StartType.RESTART_HOSTING_CHANGES: (
            "Starting timer for hosting changes",
            "revert_hosting_timer"
        ),
        StartType.RESTART_DB_CHANGES: (
            "Starting timer for database import",
            "revert_db_timer"
        )
    }

    entry = actions.get(start_type)
    if entry is not None:
        message, timer_attr = entry
        LOGGER.info(message)
        getattr(Server(), timer_attr).start()

    return
def diffuse_timers() -> None:
    """Stop any timers running after doing a special restart."""
    server = Server()

    if server.revert_hosting_timer.is_alive():
        LOGGER.info("Timer for hosting changes diffused")
        server.revert_hosting_timer.cancel()

    elif server.revert_db_timer.is_alive():
        LOGGER.info("Timer for database import diffused")
        server.revert_db_timer.cancel()
        # Keep the imported database instead of swapping back
        revert_db_import(swap=False)

    return
class Server(metaclass=Singleton):
    """Manages the Flask application and the waitress server hosting it:
    creation, running, special (re)starts and shutdown."""

    # URL prefixes under which the API blueprints are registered
    api_prefix = "/api"
    admin_api_extension = "/admin"
    admin_prefix = "/api/admin"
    # Extra URL prefix the whole app is served under ('' = served at root)
    url_prefix = ''

    def __init__(self) -> None:
        # Why the server should (re)start; set by self.restart()
        self.start_type: Union[StartType, None] = None

        # Reverts a database import if it isn't confirmed in time
        # (started by handle_start_type, cancelled by diffuse_timers)
        self.revert_db_timer = Timer(
            Constants.DB_REVERT_TIME,
            revert_db_import,
            kwargs={"swap": True}
        )
        self.revert_db_timer.name = "DatabaseImportHandler"

        # Restores the previous hosting settings if the new ones aren't
        # confirmed in time (started by handle_start_type)
        self.revert_hosting_timer = Timer(
            Constants.HOSTING_REVERT_TIME,
            self.restore_hosting_settings
        )
        self.revert_hosting_timer.name = "HostingHandler"
        return

    def create_app(self) -> None:
        """Create a Flask app instance that can be used to start a web server."""
        from frontend.api import admin_api, api
        from frontend.ui import ui

        app = Flask(
            __name__,
            template_folder=folder_path('frontend', 'templates'),
            static_folder=folder_path('frontend', 'static'),
            static_url_path='/static'
        )
        # Random secret key, so sessions don't survive a restart
        app.config['SECRET_KEY'] = urandom(32)
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
        app.config['JSON_SORT_KEYS'] = False

        # Add error handlers
        @app.errorhandler(400)
        def bad_request(e):
            return {'error': "BadRequest", "result": {}}, 400

        @app.errorhandler(405)
        def method_not_allowed(e):
            return {'error': "MethodNotAllowed", "result": {}}, 405

        @app.errorhandler(500)
        def internal_error(e):
            return {'error': "InternalError", "result": {}}, 500

        # Add endpoints
        app.register_blueprint(ui)
        app.register_blueprint(api, url_prefix=self.api_prefix)
        app.register_blueprint(admin_api, url_prefix=self.admin_prefix)

        # Setup db handling
        app.teardown_appcontext(close_db)

        self.app = app
        return

    def set_url_prefix(self, url_prefix: str) -> None:
        """Change the URL prefix of the server.

        Args:
            url_prefix (str): The desired URL prefix to set it to.
        """
        self.app.config["APPLICATION_ROOT"] = url_prefix
        # Mount the app under the prefix; requests outside it hit the
        # empty fallback app
        self.app.wsgi_app = DispatcherMiddleware( # type: ignore
            Flask(__name__),
            {url_prefix: self.app.wsgi_app}
        )
        self.url_prefix = url_prefix
        return

    def __create_waitress_server(
        self,
        host: str,
        port: int
    ) -> Union[MultiSocketServer, BaseWSGIServer]:
        """From the `Flask` instance created in `self.create_app()`, create
        a waitress server instance.

        Args:
            host (str): Where to host the server on (e.g. `0.0.0.0`).
            port (int): The port to host the server on (e.g. `5656`).

        Returns:
            Union[MultiSocketServer, BaseWSGIServer]: The waitress server instance.
        """
        dispatcher = ThreadedTaskDispatcher()
        dispatcher.set_thread_count(Constants.HOSTING_THREADS)
        server = create_server(
            self.app,
            _dispatcher=dispatcher,
            host=host,
            port=port,
            threads=Constants.HOSTING_THREADS
        )
        return server

    def run(self, host: str, port: int) -> None:
        """Start the webserver.

        Args:
            host (str): Where to host the server on (e.g. `0.0.0.0`).
            port (int): The port to host the server on (e.g. `5656`).
        """
        self.server = self.__create_waitress_server(host, port)
        LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
        # Blocks until the server is shut down
        self.server.run()
        return

    def __shutdown_thread_function(self) -> None:
        """Shutdown waitress server. Intended to be run in a thread.
        """
        if not hasattr(self, 'server'):
            # Server was never started (self.run() not called)
            return
        self.server.task_dispatcher.shutdown()
        self.server.close()
        self.server._map.clear() # type: ignore
        return

    def shutdown(self) -> None:
        """
        Stop the waitress server. Starts a thread that shuts down the server.
        """
        # Delay gives the request that triggered the shutdown time to finish
        t = Timer(1.0, self.__shutdown_thread_function)
        t.name = "InternalStateHandler"
        t.start()
        return

    def restart(
        self,
        start_type: StartType = StartType.STARTUP
    ) -> None:
        """Same as `self.shutdown()`, but restart instead of shutting down.

        Args:
            start_type (StartType, optional): Why MIND should
                restart.
                Defaults to StartType.STARTUP.
        """
        self.start_type = start_type
        self.shutdown()
        return

    def restore_hosting_settings(self) -> None:
        """Revert the hosting settings to their backed-up values and
        restart the server so they take effect."""
        with self.app.app_context():
            settings = Settings()
            values = settings.get_settings()
            main_settings = {
                'host': values.backup_host,
                'port': values.backup_port,
                'url_prefix': values.backup_url_prefix
            }
            settings.update(main_settings)
        self.restart()
        return

View File

@@ -0,0 +1,255 @@
# -*- coding: utf-8 -*-
from dataclasses import _MISSING_TYPE, asdict, dataclass
from functools import lru_cache
from json import dump, load
from logging import DEBUG, INFO
from typing import Any, Dict, Mapping
from backend.base.custom_exceptions import InvalidKeyValue, KeyNotFound
from backend.base.helpers import (Singleton, folder_path,
get_python_version, reversed_tuples)
from backend.base.logging import LOGGER, set_log_level
from backend.internals.db import DBConnection, commit, get_db
from backend.internals.db_migration import get_latest_db_version
THIRTY_DAYS = 2592000  # 30 days in seconds; upper bound for the login_time setting
@lru_cache(1)
def get_about_data() -> Dict[str, Any]:
    """Get data about the application and its environment.

    Raises:
        RuntimeError: If the version is not found in the pyproject.toml file.

    Returns:
        Dict[str, Any]: The information.
    """
    version = None
    with open(folder_path("pyproject.toml"), "r") as f:
        for line in f:
            # Version line looks like: version = "x.y.z"
            if line.startswith("version = "):
                version = "V" + line.split('"')[1]
                break

    if version is None:
        raise RuntimeError("Version not found in pyproject.toml")

    return {
        "version": version,
        "python_version": get_python_version(),
        "database_version": get_latest_db_version(),
        "database_location": DBConnection.file,
        "data_folder": folder_path()
    }
@dataclass(frozen=True)
class SettingsValues:
    """All settings with their default values."""
    database_version: int = get_latest_db_version()
    log_level: int = INFO
    host: str = '0.0.0.0'
    port: int = 8080
    url_prefix: str = ''
    backup_host: str = '0.0.0.0'
    backup_port: int = 8080
    backup_url_prefix: str = ''
    allow_new_accounts: bool = True
    login_time: int = 3600
    login_time_reset: bool = True

    def todict(self) -> Dict[str, Any]:
        """Return the settings as a dict, leaving out the backup_* entries."""
        result = {}
        for key, value in self.__dict__.items():
            if key.startswith('backup_'):
                continue
            result[key] = value
        return result
class Settings(metaclass=Singleton):
    """Access to the application settings, stored in the `config` table of
    the database and cached in-memory."""

    def __init__(self) -> None:
        self._insert_missing_settings()
        self._fetch_settings()
        return

    def _insert_missing_settings(self) -> None:
        "Insert any missing keys from the settings into the database."
        # INSERT OR IGNORE leaves already-existing keys untouched
        get_db().executemany(
            "INSERT OR IGNORE INTO config(key, value) VALUES (?, ?);",
            asdict(SettingsValues()).items()
        )
        commit()
        return

    def _fetch_settings(self) -> None:
        "Load the settings from the database into the cache."
        # Skip any unknown keys that may be left over in the table
        db_values = {
            k: v
            for k, v in get_db().execute(
                "SELECT key, value FROM config;"
            )
            if k in SettingsValues.__dataclass_fields__
        }
        # SQLite stores booleans as integers; convert them back
        for b_key in ('allow_new_accounts', 'login_time_reset'):
            db_values[b_key] = bool(db_values[b_key])
        self.__cached_values = SettingsValues(**db_values)
        return

    def get_settings(self) -> SettingsValues:
        """Get the settings from the cache.

        Returns:
            SettingsValues: The settings.
        """
        return self.__cached_values

    # Alias, better in one-liners
    # sv = Settings Values
    @property
    def sv(self) -> SettingsValues:
        """Get the settings from the cache.

        Returns:
            SettingsValues: The settings.
        """
        return self.__cached_values

    def update(
        self,
        data: Mapping[str, Any]
    ) -> None:
        """Change the settings, in a `dict.update()` type of way.

        Args:
            data (Mapping[str, Any]): The keys and their new values.

        Raises:
            KeyNotFound: Key is not a setting.
            InvalidKeyValue: Value of the key is not allowed.
        """
        # Validate and convert all values before writing anything
        formatted_data = {}
        for key, value in data.items():
            formatted_data[key] = self.__format_setting(key, value)

        get_db().executemany(
            "UPDATE config SET value = ? WHERE key = ?;",
            reversed_tuples(formatted_data.items())
        )

        # Run side-effect handlers for keys whose value actually changed
        # (self.get_settings() still returns the pre-update cache here)
        for key, handler in (
            ('url_prefix', update_manifest),
            ('log_level', set_log_level)
        ):
            if (
                key in data
                and formatted_data[key] != getattr(self.get_settings(), key)
            ):
                handler(formatted_data[key])

        self._fetch_settings()
        LOGGER.info(f"Settings changed: {formatted_data}")
        return

    def reset(self, key: str) -> None:
        """Reset the value of the key to the default value.

        Args:
            key (str): The key of which to reset the value.

        Raises:
            KeyNotFound: Key is not a setting.
        """
        LOGGER.debug(f'Setting reset: {key}')

        # BUGFIX: an unknown key used to raise a bare KeyError below;
        # raise the documented exception instead (consistent with
        # __format_setting).
        if key not in SettingsValues.__dataclass_fields__:
            raise KeyNotFound(key)

        # A field has either a default value or a default factory;
        # use whichever is set
        if not isinstance(
            SettingsValues.__dataclass_fields__[key].default_factory,
            _MISSING_TYPE
        ):
            self.update({
                key: SettingsValues.__dataclass_fields__[key].default_factory()
            })
        else:
            self.update({
                key: SettingsValues.__dataclass_fields__[key].default
            })
        return

    def backup_hosting_settings(self) -> None:
        "Backup the hosting settings in the database."
        s = self.get_settings()
        backup_settings = {
            'backup_host': s.host,
            'backup_port': s.port,
            'backup_url_prefix': s.url_prefix
        }
        self.update(backup_settings)
        return

    def __format_setting(self, key: str, value: Any) -> Any:
        """Check if the value of a setting is allowed and convert if needed.

        Args:
            key (str): Key of setting.
            value (Any): Value of setting.

        Raises:
            KeyNotFound: Key is not a setting.
            InvalidKeyValue: Value is not allowed.

        Returns:
            Any: (Converted) Setting value.
        """
        converted_value = value

        if key not in SettingsValues.__dataclass_fields__:
            raise KeyNotFound(key)

        # Type check against the field's declared type
        key_data = SettingsValues.__dataclass_fields__[key]
        if not isinstance(value, key_data.type):
            raise InvalidKeyValue(key, value)

        if key == 'login_time':
            # Between one minute and thirty days
            if not 60 <= value <= THIRTY_DAYS:
                raise InvalidKeyValue(key, value)

        elif key in ('port', 'backup_port'):
            if not 1 <= value <= 65535:
                raise InvalidKeyValue(key, value)

        elif key in ('url_prefix', 'backup_url_prefix'):
            # Normalise to '/prefix' with no trailing slash ('' stays '')
            if value:
                converted_value = ('/' + value.lstrip('/')).rstrip('/')

        elif key == 'log_level':
            if value not in (INFO, DEBUG):
                raise InvalidKeyValue(key, value)

        return converted_value
def update_manifest(url_base: str) -> None:
    """Update the url's in the manifest file.
    Needs to happen when url base changes.

    Args:
        url_base (str): The url base to use in the file.
    """
    manifest_file = folder_path('frontend', 'static', 'json', 'pwa_manifest.json')

    with open(manifest_file, 'r') as f:
        manifest = load(f)

    root = url_base + '/'
    manifest['start_url'] = root
    manifest['scope'] = root
    manifest['icons'][0]['src'] = f'{url_base}/static/img/favicon.svg'

    with open(manifest_file, 'w') as f:
        dump(manifest, f, indent=4)
    return