From ccdb16eef5c4d7d23e013fdc5ed11cc6aaa6342f Mon Sep 17 00:00:00 2001 From: CasVT Date: Thu, 1 Feb 2024 14:42:10 +0100 Subject: [PATCH] Backend Refactor --- .vscode/settings.json | 4 +- MIND.py | 78 +-- backend/custom_exceptions.py | 71 ++- backend/db.py | 55 +- backend/helpers.py | 91 +++ backend/notification_service.py | 24 +- backend/reminders.py | 748 +++++++++++++++--------- backend/security.py | 5 +- backend/settings.py | 6 +- backend/static_reminders.py | 204 ++++--- backend/templates.py | 219 ++++--- backend/users.py | 257 ++++---- frontend/api.py | 546 ++++------------- frontend/input_validation.py | 434 ++++++++++++++ frontend/ui.py | 10 +- project_management/generate_api_docs.py | 9 +- tests/custom_exceptions_test.py | 19 +- tests/db_test.py | 6 +- tests/reminders_test.py | 12 +- tests/users_test.py | 8 +- 20 files changed, 1735 insertions(+), 1071 deletions(-) create mode 100644 backend/helpers.py create mode 100644 frontend/input_validation.py diff --git a/.vscode/settings.json b/.vscode/settings.json index c79b2a6..88b7c9f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,7 @@ "*_test.py" ], "python.testing.pytestEnabled": false, - "python.testing.unittestEnabled": true + "python.testing.unittestEnabled": true, + "python.analysis.autoImportCompletions": true, + "python.analysis.typeCheckingMode": "off" } \ No newline at end of file diff --git a/MIND.py b/MIND.py index b2bc0c5..5116a5a 100644 --- a/MIND.py +++ b/MIND.py @@ -1,20 +1,24 @@ #!/usr/bin/env python3 #-*- coding: utf-8 -*- +""" +The main file where MIND is started from +""" + import logging from os import makedirs, urandom -from os.path import abspath, dirname, isfile, join +from os.path import dirname, isfile from shutil import move -from sys import version_info from flask import Flask, render_template, request from waitress.server import create_server from werkzeug.middleware.dispatcher import DispatcherMiddleware from backend.db import DBConnection, ThreadedTaskDispatcher, close_db, setup_db -from frontend.api import (admin_api, admin_api_prefix, api, api_prefix, - reminder_handler) -from frontend.ui import ui +from backend.helpers import check_python_version, folder_path +from backend.reminders import ReminderHandler +from frontend.api import admin_api, admin_api_prefix, api, api_prefix +from frontend.ui import UIVariables, ui HOST = '0.0.0.0' PORT = '8080' @@ -23,35 +27,33 @@ LOGGING_LEVEL = logging.INFO THREADS = 10 DB_FILENAME = 'db', 'MIND.db' +UIVariables.url_prefix = URL_PREFIX logging.basicConfig( level=LOGGING_LEVEL, format='[%(asctime)s][%(threadName)s][%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) -def _folder_path(*folders) -> str: - """Turn filepaths relative to the project folder into absolute paths - Returns: - str: The absolute filepath - """ - return join(dirname(abspath(__file__)), *folders) - def _create_app() -> Flask: """Create a Flask app instance + Returns: Flask: The created app instance """ app = Flask( __name__, - template_folder=_folder_path('frontend','templates'), - static_folder=_folder_path('frontend','static'), + template_folder=folder_path('frontend','templates'), + static_folder=folder_path('frontend','static'), static_url_path='/static' ) app.config['SECRET_KEY'] = urandom(32) app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True app.config['JSON_SORT_KEYS'] = False app.config['APPLICATION_ROOT'] = URL_PREFIX - app.wsgi_app = DispatcherMiddleware(Flask(__name__), {URL_PREFIX: app.wsgi_app}) + app.wsgi_app = DispatcherMiddleware( + Flask(__name__), + {URL_PREFIX: app.wsgi_app} + ) # Add error handlers @app.errorhandler(400) @@ -68,7 +70,7 @@ def _create_app() -> Flask: @app.errorhandler(404) def not_found(e): - if request.path.startswith('/api'): + if request.path.startswith(api_prefix): return {'error': 'Not Found', 'result': {}}, 404 return render_template('page_not_found.html', url_prefix=logging.URL_PREFIX) @@ -83,40 +85,42 @@ def _create_app() -> Flask: def MIND() -> None: """The main function of MIND - Returns: - None """ - # Check python version - if (version_info.major < 3) or (version_info.major == 3 and version_info.minor < 7): - logging.error('Error: the minimum python version required is python3.7 (currently ' + version_info.major + '.' + version_info.minor + '.' + version_info.micro + ')') + logging.info('Starting up MIND') + + if not check_python_version(): exit(1) + + if isfile(folder_path('db', 'Noted.db')): + move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME)) + + db_location = folder_path(*DB_FILENAME) + logging.debug(f'Database location: {db_location}') + makedirs(dirname(db_location), exist_ok=True) + + DBConnection.file = db_location - # Register web server - # We need to get the value to ui.py but MIND.py imports from ui.py so we get an import loop. - # To go around this, we abuse the fact that the logging module is a singleton. - # We add an attribute to the logging module and in ui.py get the value this way. - logging.URL_PREFIX = URL_PREFIX app = _create_app() + reminder_handler = ReminderHandler(app.app_context) with app.app_context(): - if isfile(_folder_path('db', 'Noted.db')): - move(_folder_path('db', 'Noted.db'), _folder_path(*DB_FILENAME)) - - db_location = _folder_path(*DB_FILENAME) - logging.debug(f'Database location: {db_location}') - makedirs(dirname(db_location), exist_ok=True) - - DBConnection.file = db_location setup_db() reminder_handler.find_next_reminder() # Create waitress server and run dispatcher = ThreadedTaskDispatcher() dispatcher.set_thread_count(THREADS) - server = create_server(app, _dispatcher=dispatcher, host=HOST, port=PORT, threads=THREADS) + server = create_server( + app, + _dispatcher=dispatcher, + host=HOST, + port=PORT, + threads=THREADS + ) logging.info(f'MIND running on http://{HOST}:{PORT}{URL_PREFIX}') + # ================= server.run() - - # Stopping thread + # ================= + reminder_handler.stop_handling() return diff --git a/backend/custom_exceptions.py b/backend/custom_exceptions.py index 230f471..b27a9b7 100644 --- a/backend/custom_exceptions.py +++ b/backend/custom_exceptions.py @@ -1,8 +1,17 @@ #-*- coding: utf-8 -*- +""" +All custom exceptions are defined here +""" + +""" +Note: Not all CE's inherit from CustomException. +""" + import logging from typing import Any, Dict + class CustomException(Exception): def __init__(self, e=None) -> None: logging.warning(self.__doc__) @@ -13,10 +22,18 @@ class UsernameTaken(CustomException): """The username is already taken""" api_response = {'error': 'UsernameTaken', 'result': {}, 'code': 400} -class UsernameInvalid(CustomException): +class UsernameInvalid(Exception): """The username contains invalid characters""" api_response = {'error': 'UsernameInvalid', 'result': {}, 'code': 400} + def __init__(self, username: str): + self.username = username + super().__init__(self.username) + logging.warning( + f'The username contains invalid characters: {username}' + ) + return + class UserNotFound(CustomException): """The user requested can not be found""" api_response = {'error': 'UserNotFound', 'result': {}, 'code': 404} @@ -33,59 +50,81 @@ class NotificationServiceNotFound(CustomException): """The notification service was not found""" api_response = {'error': 'NotificationServiceNotFound', 'result': {}, 'code': 404} -class NotificationServiceInUse(CustomException): - """The notification service is wished to be deleted but a reminder is still using it""" +class NotificationServiceInUse(Exception): + """ + The notification service is wished to be deleted + but a reminder is still using it + """ def __init__(self, type: str=''): self.type = type super().__init__(self.type) + logging.warning( + f'The notification is wished to be deleted but a reminder of type {type} is still using it' + ) + return @property def api_response(self) -> Dict[str, Any]: - return {'error': 'NotificationServiceInUse', 'result': {'type': self.type}, 'code': 400} + return { + 'error': 'NotificationServiceInUse', + 'result': {'type': self.type}, + 'code': 400 + } class InvalidTime(CustomException): """The time given is in the past""" api_response = {'error': 'InvalidTime', 'result': {}, 'code': 400} -class KeyNotFound(CustomException): +class KeyNotFound(Exception): """A key was not found in the input that is required to be given""" def __init__(self, key: str=''): self.key = key super().__init__(self.key) + logging.warning( + "This key was not found in the API request," + + f" eventhough it's required: {key}" + ) + return @property def api_response(self) -> Dict[str, Any]: - return {'error': 'KeyNotFound', 'result': {'key': self.key}, 'code': 400} + return { + 'error': 'KeyNotFound', + 'result': {'key': self.key}, + 'code': 400 + } -class InvalidKeyValue(CustomException): +class InvalidKeyValue(Exception): """The value of a key is invalid""" def __init__(self, key: str='', value: str=''): self.key = key self.value = value super().__init__(self.key) + logging.warning( + 'This key in the API request has an invalid value: ' + + f'{key} = {value}' + ) @property def api_response(self) -> Dict[str, Any]: - return {'error': 'InvalidKeyValue', 'result': {'key': self.key, 'value': self.value}, 'code': 400} + return { + 'error': 'InvalidKeyValue', + 'result': {'key': self.key, 'value': self.value}, + 'code': 400 + } class TemplateNotFound(CustomException): """The template was not found""" api_response = {'error': 'TemplateNotFound', 'result': {}, 'code': 404} -class APIKeyInvalid(CustomException): +class APIKeyInvalid(Exception): """The API key is not correct""" api_response = {'error': 'APIKeyInvalid', 'result': {}, 'code': 401} - - def __init__(self, e=None) -> None: - return -class APIKeyExpired(CustomException): +class APIKeyExpired(Exception): """The API key has expired""" api_response = {'error': 'APIKeyExpired', 'result': {}, 'code': 401} - def __init__(self, e=None) -> None: - return - class NewAccountsNotAllowed(CustomException): """It's not allowed to create a new account""" api_response = {'error': 'NewAccountsNotAllowed', 'result': {}, 'code': 403} diff --git a/backend/db.py b/backend/db.py index a7d113c..45b4a7a 100644 --- a/backend/db.py +++ b/backend/db.py @@ -1,11 +1,15 @@ #-*- coding: utf-8 -*- +""" +Setting up and interacting with the database. +""" + import logging from datetime import datetime from sqlite3 import Connection, ProgrammingError, Row from threading import current_thread, main_thread from time import time -from typing import Union +from typing import Type, Union from flask import g from waitress.task import ThreadedTaskDispatcher as OldThreadedTaskDispatcher @@ -14,14 +18,14 @@ from backend.custom_exceptions import AccessUnauthorized, UserNotFound __DATABASE_VERSION__ = 8 -class Singleton(type): +class DB_Singleton(type): _instances = {} def __call__(cls, *args, **kwargs): i = f'{cls}{current_thread()}' if (i not in cls._instances or cls._instances[i].closed): logging.debug(f'Creating singleton instance: {i}') - cls._instances[i] = super(Singleton, cls).__call__(*args, **kwargs) + cls._instances[i] = super(DB_Singleton, cls).__call__(*args, **kwargs) return cls._instances[i] @@ -29,19 +33,21 @@ class ThreadedTaskDispatcher(OldThreadedTaskDispatcher): def handler_thread(self, thread_no: int) -> None: super().handler_thread(thread_no) i = f'{DBConnection}{current_thread()}' - if i in Singleton._instances and not Singleton._instances[i].closed: + if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed: logging.debug(f'Closing singleton instance: {i}') - Singleton._instances[i].close() - + DB_Singleton._instances[i].close() + return + def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool: print() logging.info('Shutting down MIND...') super().shutdown(cancel_pending, timeout) DBConnection(20.0).close() + return -class DBConnection(Connection, metaclass=Singleton): +class DBConnection(Connection, metaclass=DB_Singleton): file = '' - + def __init__(self, timeout: float) -> None: logging.debug(f'Opening database connection for {current_thread()}') super().__init__(self.file, timeout=timeout) @@ -55,11 +61,13 @@ class DBConnection(Connection, metaclass=Singleton): super().close() return -def get_db(output_type: Union[dict, tuple]=tuple): +def get_db(output_type: Union[Type[dict], Type[tuple]]=tuple): """Get a database cursor instance. Coupled to Flask's g. Args: - output_type (Union[dict, tuple], optional): The type of output: a tuple or dictionary with the row values. Defaults to tuple. + output_type (Union[Type[dict], Type[tuple]], optional): + The type of output: a tuple or dictionary with the row values. + Defaults to tuple. Returns: Cursor: The Cursor instance to use @@ -220,7 +228,7 @@ def migrate_db(current_db_version: int) -> None: if current_db_version == 7: # V7 -> V8 from backend.settings import _format_setting, default_settings - from backend.users import register_user + from backend.users import Users cursor.executescript(""" DROP TABLE config; @@ -239,7 +247,7 @@ def migrate_db(current_db_version: int) -> None: default_settings.items() ) ) - + cursor.executescript(""" ALTER TABLE users ADD admin BOOL NOT NULL DEFAULT 0; @@ -248,9 +256,9 @@ def migrate_db(current_db_version: int) -> None: SET username = 'admin_old' WHERE username = 'admin'; """) - - register_user('admin', 'admin') - + + Users().add('admin', 'admin', True) + cursor.execute(""" UPDATE users SET admin = 1 @@ -264,6 +272,8 @@ def setup_db() -> None: """ from backend.settings import (_format_setting, default_settings, get_setting, set_setting) + from backend.users import Users + cursor = get_db() cursor.execute("PRAGMA journal_mode = wal;") @@ -348,11 +358,22 @@ def setup_db() -> None: default_settings.items() ) ) - + current_db_version = get_setting('database_version') - logging.debug(f'Current database version {current_db_version} and desired database version {__DATABASE_VERSION__}') if current_db_version < __DATABASE_VERSION__: + logging.debug( + f'Database migration: {current_db_version} -> {__DATABASE_VERSION__}' + ) migrate_db(current_db_version) set_setting('database_version', __DATABASE_VERSION__) + users = Users() + if not 'admin' in users: + users.add('admin', 'admin', True) + cursor.execute(""" + UPDATE users + SET admin = 1 + WHERE username = 'admin'; + """) + return diff --git a/backend/helpers.py b/backend/helpers.py new file mode 100644 index 0000000..a176692 --- /dev/null +++ b/backend/helpers.py @@ -0,0 +1,91 @@ +#-*- coding: utf-8 -*- + +""" +General functions +""" + +import logging +from enum import Enum +from os.path import abspath, dirname, join +from sys import version_info + + +def folder_path(*folders) -> str: + """Turn filepaths relative to the project folder into absolute paths + + Returns: + str: The absolute filepath + """ + return join(dirname(dirname(abspath(__file__))), *folders) + + +def check_python_version() -> bool: + """Check if the python version that is used is a minimum version. + + Returns: + bool: Whether or not the python version is version 3.8 or above or not. + """ + if not (version_info.major == 3 and version_info.minor >= 8): + logging.critical( + 'The minimum python version required is python3.8 ' + + '(currently ' + version_info.major + '.' + version_info.minor + '.' + version_info.micro + ').' + ) + return False + return True + + +def search_filter(query: str, result: dict) -> bool: + """Filter library results based on a query. + + Args: + query (str): The query to filter with. + result (dict): The library result to check. + + Returns: + bool: Whether or not the result passes the filter. + """ + query = query.lower() + return ( + query in result["title"].lower() + or query in result["text"].lower() + ) + + +class Singleton(type): + _instances = {} + def __call__(cls, *args, **kwargs): + c = str(cls) + if c not in cls._instances: + cls._instances[c] = super().__call__(*args, **kwargs) + + return cls._instances[c] + + +class BaseEnum(Enum): + def __eq__(self, other) -> bool: + return self.value == other + + +class TimelessSortingMethod(BaseEnum): + TITLE = (lambda r: (r['title'], r['text'], r['color']), False) + TITLE_REVERSED = (lambda r: (r['title'], r['text'], r['color']), True) + DATE_ADDED = (lambda r: r['id'], False) + DATE_ADDED_REVERSED = (lambda r: r['id'], True) + + +class SortingMethod(BaseEnum): + TIME = (lambda r: (r['time'], r['title'], r['text'], r['color']), False) + TIME_REVERSED = (lambda r: (r['time'], r['title'], r['text'], r['color']), True) + TITLE = (lambda r: (r['title'], r['time'], r['text'], r['color']), False) + TITLE_REVERSED = (lambda r: (r['title'], r['time'], r['text'], r['color']), True) + DATE_ADDED = (lambda r: r['id'], False) + DATE_ADDED_REVERSED = (lambda r: r['id'], True) + + +class RepeatQuantity(BaseEnum): + YEARS = "years" + MONTHS = "months" + WEEKS = "weeks" + DAYS = "days" + HOURS = "hours" + MINUTES = "minutes" diff --git a/backend/notification_service.py b/backend/notification_service.py index 713cf09..e62962e 100644 --- a/backend/notification_service.py +++ b/backend/notification_service.py @@ -74,7 +74,10 @@ def get_apprise_services() -> List[Dict[str, Union[str, Dict[str, list]]]]: 'prefix': content.get('prefix'), 'regex': process_regex(content.get('regex')) } - for content, _ in ((entry['details']['tokens'][e], handled_tokens.add(e)) for e in v['group']) + for content, _ in ( + (entry['details']['tokens'][e], handled_tokens.add(e)) + for e in v['group'] + ) ] } for k, v in @@ -175,8 +178,13 @@ class NotificationService: def __init__(self, user_id: int, notification_service_id: int) -> None: self.id = notification_service_id - if not get_db().execute( - "SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;", + if not get_db().execute(""" + SELECT 1 + FROM notification_services + WHERE id = ? + AND user_id = ? + LIMIT 1; + """, (self.id, user_id) ).fetchone(): raise NotificationServiceNotFound @@ -187,8 +195,12 @@ class NotificationService: Returns: dict: The info about the notification service """ - result = dict(get_db(dict).execute( - "SELECT id, title, url FROM notification_services WHERE id = ? LIMIT 1", + result = dict(get_db(dict).execute(""" + SELECT id, title, url + FROM notification_services + WHERE id = ? + LIMIT 1 + """, (self.id,) ).fetchone()) @@ -330,7 +342,7 @@ class NotificationServices: url (str): The apprise url of the service Returns: - dict: The info about the new service + NotificationService: The instance representing the new service """ logging.info(f'Adding notification service with {title=}, {url=}') diff --git a/backend/reminders.py b/backend/reminders.py index 84f6030..6ada583 100644 --- a/backend/reminders.py +++ b/backend/reminders.py @@ -4,28 +4,65 @@ import logging from datetime import datetime from sqlite3 import IntegrityError from threading import Timer -from typing import List, Literal +from typing import List, Literal, Union from apprise import Apprise -from dateutil.relativedelta import relativedelta, weekday -from flask import Flask +from dateutil.relativedelta import relativedelta +from dateutil.relativedelta import weekday as du_weekday from backend.custom_exceptions import (InvalidKeyValue, InvalidTime, NotificationServiceNotFound, ReminderNotFound) -from backend.db import close_db, get_db +from backend.db import get_db +from backend.helpers import RepeatQuantity, Singleton, SortingMethod, search_filter -filter_function = lambda query, p: ( - query in p["title"].lower() - or query in p["text"].lower() -) + +def __next_selected_day( + weekdays: List[int], + weekday: int +) -> int: + """Find the next allowed day in the week. + + Args: + weekdays (List[int]): The days of the week that are allowed. + Monday is 0, Sunday is 6. + weekday (int): The current weekday. + + Returns: + int: The next allowed weekday. + """ + return ( + # Get all days later than current, then grab first one. + [d for d in weekdays if weekday < d] + or + # weekday is last allowed day, so it should grab the first + # allowed day of the week. + weekdays + )[0] def _find_next_time( original_time: int, - repeat_quantity: Literal["years", "months", "weeks", "days", "hours", "minutes"], - repeat_interval: int, - weekdays: List[int] + repeat_quantity: Union[RepeatQuantity, None], + repeat_interval: Union[int, None], + weekdays: Union[List[int], None] ) -> int: + """Calculate the next timestep based on original time and repeat/interval + values. + + Args: + original_time (int): The original time of the repeating timestamp. + + repeat_quantity (Union[RepeatQuantity, None]): If set, what the quantity + is of the repetition. + + repeat_interval (Union[int, None]): If set, the value of the repetition. + + weekdays (Union[List[int], None]): If set, on which days the time can + continue. Monday is 0, Sunday is 6. + + Returns: + int: The next timestamp in the future. + """ if weekdays is not None: weekdays.sort() @@ -33,152 +70,50 @@ def _find_next_time( current_time = datetime.fromtimestamp(datetime.utcnow().timestamp()) if repeat_quantity is not None: - td = relativedelta(**{repeat_quantity: repeat_interval}) + td = relativedelta(**{repeat_quantity.value: repeat_interval}) while new_time <= current_time: new_time += td else: - next_day = ([d for d in weekdays if new_time.weekday() < d] or weekdays)[0] - proposed_time = new_time + relativedelta(weekday=weekday(next_day)) - if proposed_time == new_time: - proposed_time += relativedelta(weekday=weekday(next_day, 2)) - new_time = proposed_time - - while new_time <= current_time: - next_day = ([d for d in weekdays if new_time.weekday() < d] or weekdays)[0] - proposed_time = new_time + relativedelta(weekday=weekday(next_day)) + # We run the loop contents at least once and then actually use the cond. + # This is because we need to force the 'free' date to go to one of the + # selected weekdays. + # Say it's Monday, we set a reminder for Wednesday and make it repeat + # on Tuesday and Thursday. Then the first notification needs to go on + # Thurday, not Wednesday. So run code at least once to force that. + # Afterwards, it can run normally to push the timestamp into the future. + one_to_go = True + while one_to_go or new_time <= current_time: + next_day = __next_selected_day(weekdays, new_time.weekday()) + proposed_time = new_time + relativedelta(weekday=du_weekday(next_day)) if proposed_time == new_time: - proposed_time += relativedelta(weekday=weekday(next_day, 2)) + proposed_time += relativedelta(weekday=du_weekday(next_day, 2)) new_time = proposed_time + one_to_go = False result = int(new_time.timestamp()) logging.debug( - f'{original_time=}, {current_time=} and interval of {repeat_interval} {repeat_quantity} leads to {result}' + f'{original_time=}, {current_time=} ' + + f'and interval of {repeat_interval} {repeat_quantity} ' + + f'leads to {result}' ) return result -class ReminderHandler: - """Handle set reminders - """ - def __init__(self, context) -> None: - self.context = context - self.next_trigger = { - 'thread': None, - 'time': None - } - - return - - def __trigger_reminders(self, time: int) -> None: - """Trigger all reminders that are set for a certain time - - Args: - time (int): The time of the reminders to trigger - """ - with self.context(): - cursor = get_db(dict) - cursor.execute(""" - SELECT - r.id, - r.title, r.text, - r.repeat_quantity, r.repeat_interval, - r.weekdays, - r.original_time - FROM reminders r - WHERE time = ?; - """, (time,)) - reminders = list(map(dict, cursor)) - - for reminder in reminders: - cursor.execute(""" - SELECT url - FROM reminder_services rs - INNER JOIN notification_services ns - ON rs.notification_service_id = ns.id - WHERE rs.reminder_id = ?; - """, (reminder['id'],)) - - # Send of reminder - a = Apprise() - for url in cursor: - a.add(url['url']) - a.notify(title=reminder["title"], body=reminder["text"]) - - if reminder['repeat_quantity'] is None and reminder['weekdays'] is None: - # Delete the reminder from the database - cursor.execute( - "DELETE FROM reminders WHERE id = ?;", - (reminder['id'],) - ) - logging.info(f'Deleted reminder {reminder["id"]}') - else: - # Set next time - new_time = _find_next_time( - reminder['original_time'], - reminder['repeat_quantity'], - reminder['repeat_interval'], - [int(d) for d in reminder['weekdays'].split(',')] if reminder['weekdays'] is not None else None - ) - cursor.execute( - "UPDATE reminders SET time = ? WHERE id = ?;", - (new_time, reminder['id']) - ) - - self.next_trigger.update({ - 'thread': None, - 'time': None - }) - self.find_next_reminder() - - def find_next_reminder(self, time: int=None) -> None: - """Determine when the soonest reminder is and set the timer to that time - - Args: - time (int, optional): The timestamp to check for. Otherwise check soonest in database. Defaults to None. - """ - if not time: - with self.context(): - time = get_db().execute(""" - SELECT DISTINCT r1.time - FROM reminders r1 - LEFT JOIN reminders r2 - ON r1.time > r2.time - WHERE r2.id IS NULL; - """).fetchone() - if time is None: - return - time = time[0] - - if (self.next_trigger['thread'] is None - or time < self.next_trigger['time']): - if self.next_trigger['thread'] is not None: - self.next_trigger['thread'].cancel() - - t = time - datetime.utcnow().timestamp() - self.next_trigger['thread'] = Timer( - t, - self.__trigger_reminders, - (time,) - ) - self.next_trigger['thread'].name = "ReminderHandler" - self.next_trigger['thread'].start() - self.next_trigger['time'] = time - - def stop_handling(self) -> None: - """Stop the timer if it's active - """ - if self.next_trigger['thread'] is not None: - self.next_trigger['thread'].cancel() - return - -handler_context = Flask('handler') -handler_context.teardown_appcontext(close_db) -reminder_handler = ReminderHandler(handler_context.app_context) class Reminder: """Represents a reminder """ - def __init__(self, user_id: int, reminder_id: int): + def __init__(self, user_id: int, reminder_id: int) -> None: + """Create an instance. + + Args: + user_id (int): The ID of the user. + reminder_id (int): The ID of the reminder. + + Raises: + ReminderNotFound: Reminder with given ID does not exist or is not + owned by user. + """ self.id = reminder_id # Check if reminder exists @@ -188,6 +123,26 @@ class Reminder: ).fetchone(): raise ReminderNotFound + return + + def _get_notification_services(self) -> List[int]: + """Get ID's of notification services linked to the reminder. + + Returns: + List[int]: The list with ID's. + """ + result = [ + r[0] + for r in get_db().execute(""" + SELECT notification_service_id + FROM reminder_services + WHERE reminder_id = ?; + """, + (self.id,) + ) + ] + return result + def get(self) -> dict: """Get info about the reminder @@ -211,46 +166,65 @@ class Reminder: ).fetchone() reminder = dict(reminder) - reminder['notification_services'] = list(map(lambda r: r[0], get_db().execute(""" - SELECT notification_service_id - FROM reminder_services - WHERE reminder_id = ?; - """, (self.id,)))) + reminder['notification_services'] = self._get_notification_services() return reminder def update( self, - title: str = None, - time: int = None, - notification_services: List[int] = None, - text: str = None, - repeat_quantity: Literal["years", "months", "weeks", "days", "hours", "minutes"] = None, - repeat_interval: int = None, - weekdays: List[int] = None, - color: str = None + title: Union[None, str] = None, + time: Union[None, int] = None, + notification_services: Union[None, List[int]] = None, + text: Union[None, str] = None, + repeat_quantity: Union[None, RepeatQuantity] = None, + repeat_interval: Union[None, int] = None, + weekdays: Union[None, List[int]] = None, + color: Union[None, str] = None ) -> dict: - """Edit the reminder + """Edit the reminder. Args: - title (str): The new title of the entry. Defaults to None. - time (int): The new UTC epoch timestamp the the reminder should be send. Defaults to None. - notification_services (List[int]): The new list of id's of the notification services to use to send the reminder. Defaults to None. - text (str, optional): The new body of the reminder. Defaults to None. - repeat_quantity (Literal["years", "months", "weeks", "days", "hours", "minutes"], optional): The new quantity of the repeat specified for the reminder. Defaults to None. - repeat_interval (int, optional): The new amount of repeat_quantity, like "5" (hours). Defaults to None. - weekdays (List[int], optional): The new indexes of the days of the week that the reminder should run. Defaults to None. - color (str, optional): The new hex code of the color of the reminder, which is shown in the web-ui. Defaults to None. + title (Union[None, str]): The new title of the entry. + Defaults to None. + + time (Union[None, int]): The new UTC epoch timestamp when the + reminder should be send. + Defaults to None. + + notification_services (Union[None, List[int]]): The new list + of id's of the notification services to use to send the reminder. + Defaults to None. + + text (Union[None, str], optional): The new body of the reminder. + Defaults to None. + + repeat_quantity (Union[None, RepeatQuantity], optional): The new + quantity of the repeat specified for the reminder. + Defaults to None. + + repeat_interval (Union[None, int], optional): The new amount of + repeat_quantity, like "5" (hours). + Defaults to None. + + weekdays (Union[None, List[int]], optional): The new indexes of + the days of the week that the reminder should run. + Defaults to None. + + color (Union[None, str], optional): The new hex code of the color + of the reminder, which is shown in the web-ui. + Defaults to None. Note about args: - Either repeat_quantity and repeat_interval are given, weekdays is given or neither, but not both. + Either repeat_quantity and repeat_interval are given, weekdays is + given or neither, but not both. Raises: - NotificationServiceNotFound: One of the notification services was not found - InvalidKeyValue: The value of one of the keys is not valid or the "Note about args" is violated + NotificationServiceNotFound: One of the notification services was not found. + InvalidKeyValue: The value of one of the keys is not valid or + the "Note about args" is violated. Returns: - dict: The new reminder info + dict: The new reminder info. """ logging.info( f'Updating notification service {self.id}: ' @@ -264,8 +238,9 @@ class Reminder: raise InvalidKeyValue('repeat_quantity', repeat_quantity) elif repeat_quantity is not None and repeat_interval is None: raise InvalidKeyValue('repeat_interval', repeat_interval) - elif weekdays is not None and repeat_quantity is not None and repeat_interval is not None: + elif weekdays is not None and repeat_quantity is not None: raise InvalidKeyValue('weekdays', weekdays) + repeated_reminder = ( (repeat_quantity is not None and repeat_interval is not None) or weekdays is not None @@ -285,26 +260,35 @@ class Reminder: 'text': text, 'repeat_quantity': repeat_quantity, 'repeat_interval': repeat_interval, - 'weekdays': ",".join(map(str, sorted(weekdays))) if weekdays is not None else None, + 'weekdays': + ",".join(map(str, sorted(weekdays))) + if weekdays is not None else + None, 'color': color } for k, v in new_values.items(): - if k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color') or v is not None: + if ( + k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color') + or v is not None + ): data[k] = v # Update database if repeated_reminder: next_time = _find_next_time( data["time"], - data["repeat_quantity"], data["repeat_interval"], + RepeatQuantity(data["repeat_quantity"]), + data["repeat_interval"], weekdays ) cursor.execute(""" UPDATE reminders SET - title=?, text=?, + title=?, + text=?, time=?, - repeat_quantity=?, repeat_interval=?, + repeat_quantity=?, + repeat_interval=?, weekdays=?, original_time=?, color=? @@ -326,9 +310,11 @@ class Reminder: cursor.execute(""" UPDATE reminders SET - title=?, text=?, + title=?, + text=?, time=?, - repeat_quantity=?, repeat_interval=?, + repeat_quantity=?, + repeat_interval=?, weekdays=?, color=? WHERE id = ?; @@ -346,18 +332,29 @@ class Reminder: if notification_services: cursor.connection.isolation_level = None cursor.execute("BEGIN TRANSACTION;") - cursor.execute("DELETE FROM reminder_services WHERE reminder_id = ?", (self.id,)) + cursor.execute( + "DELETE FROM reminder_services WHERE reminder_id = ?", + (self.id,) + ) try: - cursor.executemany( - "INSERT INTO reminder_services(reminder_id, notification_service_id) VALUES (?,?)", + cursor.executemany(""" + INSERT INTO reminder_services( + reminder_id, + notification_service_id + ) + VALUES (?,?); + """, ((self.id, s) for s in notification_services) ) cursor.execute("COMMIT;") + except IntegrityError: raise NotificationServiceNotFound - cursor.connection.isolation_level = "" - reminder_handler.find_next_reminder(next_time) + finally: + cursor.connection.isolation_level = "" + + ReminderHandler().find_next_reminder(next_time) return self.get() def delete(self) -> None: @@ -365,74 +362,76 @@ class Reminder: """ logging.info(f'Deleting reminder {self.id}') get_db().execute("DELETE FROM reminders WHERE id = ?", (self.id,)) - reminder_handler.find_next_reminder() + ReminderHandler().find_next_reminder() return class Reminders: """Represents the reminder library of the user account - """ - sort_functions = { - 'time': (lambda r: (r['time'], r['title'], r['text'], r['color']), False), - 'time_reversed': (lambda r: (r['time'], r['title'], r['text'], r['color']), True), - 'title': (lambda r: (r['title'], r['time'], r['text'], r['color']), False), - 'title_reversed': (lambda r: (r['title'], r['time'], r['text'], r['color']), True), - 'date_added': (lambda r: r['id'], False), - 'date_added_reversed': (lambda r: r['id'], True) - } - - def __init__(self, user_id: int): - self.user_id = user_id + """ - def fetchall(self, sort_by: Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"] = "time") -> List[dict]: + def __init__(self, user_id: int) -> None: + """Create an instance. + + Args: + user_id (int): The ID of the user. + """ + self.user_id = user_id + return + + def fetchall( + self, + sort_by: SortingMethod = SortingMethod.TIME + ) -> List[dict]: """Get all reminders Args: - sort_by (Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "time". + sort_by (SortingMethod, optional): How to sort the result. + Defaults to SortingMethod.TIME. Returns: List[dict]: The id, title, text, time and color of each reminder """ - sort_function = self.sort_functions.get( - sort_by, - self.sort_functions['time'] - ) - - # Fetch all reminders - reminders: list = list(map(dict, get_db(dict).execute(""" - SELECT - id, - title, text, - time, - repeat_quantity, - repeat_interval, - weekdays, - color - FROM reminders - WHERE user_id = ?; - """, - (self.user_id,) - ))) + reminders = [ + dict(r) + for r in get_db(dict).execute(""" + SELECT + id, + title, text, + time, + repeat_quantity, + repeat_interval, + weekdays, + color + FROM reminders + WHERE user_id = ?; + """, + (self.user_id,) + ) + ] # Sort result - reminders.sort(key=sort_function[0], reverse=sort_function[1]) + reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1]) return reminders - def search(self, query: str, sort_by: Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"] = "time") -> List[dict]: + def search( + self, + query: str, + sort_by: SortingMethod = SortingMethod.TIME) -> List[dict]: """Search for reminders Args: - query (str): The term to search for - sort_by (Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "time". + query (str): The term to search for. + sort_by (SortingMethod, optional): How to sort the result. + Defaults to SortingMethod.TIME. Returns: List[dict]: All reminders that match. Similar output to self.fetchall - """ - query = query.lower() - reminders = list(filter( - lambda p: filter_function(query, p), - self.fetchall(sort_by) - )) + """ + reminders = [ + r for r in self.fetchall(sort_by) + if search_filter(query, r) + ] return reminders def fetchone(self, id: int) -> Reminder: @@ -452,32 +451,51 @@ class Reminders: time: int, notification_services: List[int], text: str = '', - repeat_quantity: Literal["years", "months", "weeks", "days", "hours", "minutes"] = None, - repeat_interval: int = None, - weekdays: List[int] = None, - color: str = None + repeat_quantity: Union[None, RepeatQuantity] = None, + repeat_interval: Union[None, int] = None, + weekdays: Union[None, List[int]] = None, + color: Union[None, str] = None ) -> Reminder: """Add a reminder Args: - title (str): The title of the entry + title (str): The title of the entry. + time (int): The UTC epoch timestamp the the reminder should be send. - notification_services (List[int]): The id's of the notification services to use to send the reminder. - text (str, optional): The body of the reminder. Defaults to ''. - repeat_quantity (Literal["years", "months", "weeks", "days", "hours", "minutes"], optional): The quantity of the repeat specified for the reminder. Defaults to None. - repeat_interval (int, optional): The amount of repeat_quantity, like "5" (hours). Defaults to None. - weekdays (List[int], optional): The indexes of the days of the week that the reminder should run. Defaults to None. - color (str, optional): The hex code of the color of the reminder, which is shown in the web-ui. Defaults to None. + + notification_services (List[int]): The id's of the notification services + to use to send the reminder. + + text (str, optional): The body of the reminder. + Defaults to ''. + + repeat_quantity (Union[None, RepeatQuantity], optional): The quantity + of the repeat specified for the reminder. + Defaults to None. + + repeat_interval (Union[None, int], optional): The amount of repeat_quantity, + like "5" (hours). + Defaults to None. + + weekdays (Union[None, List[int]], optional): The indexes of the days + of the week that the reminder should run. + Defaults to None. + + color (Union[None, str], optional): The hex code of the color of the + reminder, which is shown in the web-ui. + Defaults to None. Note about args: - Either repeat_quantity and repeat_interval are given, weekdays is given or neither, but not both. + Either repeat_quantity and repeat_interval are given, + weekdays is given or neither, but not both. Raises: - NotificationServiceNotFound: One of the notification services was not found - InvalidKeyValue: The value of one of the keys is not valid or the "Note about args" is violated + NotificationServiceNotFound: One of the notification services was not found. + InvalidKeyValue: The value of one of the keys is not valid + or the "Note about args" is violated. Returns: - dict: The info about the reminder + dict: The info about the reminder. """ logging.info( f'Adding reminder with {title=}, {time=}, {notification_services=}, ' @@ -492,50 +510,89 @@ class Reminders: raise InvalidKeyValue('repeat_quantity', repeat_quantity) elif repeat_quantity is not None and repeat_interval is None: raise InvalidKeyValue('repeat_interval', repeat_interval) - elif weekdays is not None and repeat_quantity is not None and repeat_interval is not None: + elif ( + weekdays is not None + and repeat_quantity is not None + and repeat_interval is not None + ): raise InvalidKeyValue('weekdays', weekdays) cursor = get_db() for service in notification_services: - if not cursor.execute( - "SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;", + if not cursor.execute(""" + SELECT 1 + FROM notification_services + WHERE id = ? + AND user_id = ? + LIMIT 1; + """, (service, self.user_id) ).fetchone(): raise NotificationServiceNotFound - if repeat_quantity is not None and repeat_interval is not None: - id = cursor.execute(""" - INSERT INTO reminders(user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color) - VALUES (?, ?, ?, ?, ?, ?, ?, ?); - """, (self.user_id, title, text, time, repeat_quantity, repeat_interval, time, color) - ).lastrowid - - elif weekdays is not None: - weekdays = ",".join(map(str, sorted(weekdays))) - id = cursor.execute(""" - INSERT INTO reminders(user_id, title, text, time, weekdays, original_time, color) - VALUES (?, ?, ?, ?, ?, ?, ?); - """, (self.user_id, title, text, time, weekdays, time, color) - ).lastrowid - + # Prepare args + if any((repeat_quantity, weekdays)): + original_time = time + time = _find_next_time( + original_time, + repeat_quantity, + repeat_interval, + weekdays + ) else: - id = cursor.execute(""" - INSERT INTO reminders(user_id, title, text, time, color) - VALUES (?, ?, ?, ?, ?); - """, (self.user_id, title, text, time, color) - ).lastrowid - + original_time = None + + if weekdays is not None: + weekdays = ",".join(map(str, sorted(weekdays))) + + if repeat_quantity is not None: + repeat_quantity = repeat_quantity.value + + cursor.connection.isolation_level = None + cursor.execute("BEGIN TRANSACTION;") + + id = cursor.execute(""" + INSERT INTO reminders( + user_id, + title, text, + time, + repeat_quantity, repeat_interval, + weekdays, + original_time, + color + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); + """, ( + self.user_id, + title, text, + time, + repeat_quantity, + repeat_interval, + weekdays, + original_time, + color + )).lastrowid + try: - cursor.executemany( - "INSERT INTO reminder_services(reminder_id, notification_service_id) VALUES (?, ?);", + cursor.executemany(""" + INSERT INTO reminder_services( + reminder_id, + notification_service_id + ) + VALUES (?, ?); + """, ((id, service) for service in notification_services) ) + cursor.execute("COMMIT;") + except IntegrityError: raise NotificationServiceNotFound - - reminder_handler.find_next_reminder(time) - # Return info + finally: + cursor.connection.isolation_level = '' + + ReminderHandler().find_next_reminder(time) + return self.fetchone(id) def test_reminder( @@ -544,23 +601,164 @@ class Reminders: notification_services: List[int], text: str = '' ) -> None: - """Test send a reminder draft + """Test send a reminder draft. Args: - title (str): Title title of the entry - notification_service (int): The id of the notification service to use to send the reminder - text (str, optional): The body of the reminder. Defaults to ''. + title (str): Title title of the entry. + + notification_service (int): The id of the notification service to + use to send the reminder. + + text (str, optional): The body of the reminder. + Defaults to ''. """ logging.info(f'Testing reminder with {title=}, {notification_services=}, {text=}') a = Apprise() cursor = get_db(dict) + for service in notification_services: - url = cursor.execute( - "SELECT url FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;", + url = cursor.execute(""" + SELECT url + FROM notification_services + WHERE id = ? + AND user_id = ? + LIMIT 1; + """, (service, self.user_id) ).fetchone() if not url: raise NotificationServiceNotFound a.add(url[0]) + a.notify(title=title, body=text) return + + +class ReminderHandler(metaclass=Singleton): + """Handle set reminders. + + Note: Singleton. + """ + def __init__(self, context) -> None: + """Create instance of handler. + + Args: + context (AppContext): `Flask.app_context` + """ + self.context = context + self.thread: Union[Timer, None] = None + self.time: Union[int, None] = None + return + + def __trigger_reminders(self, time: int) -> None: + """Trigger all reminders that are set for a certain time + + Args: + time (int): The time of the reminders to trigger + """ + with self.context(): + cursor = get_db(dict) + reminders = [ + dict(r) + for r in cursor.execute(""" + SELECT + id, user_id, + title, text, + repeat_quantity, repeat_interval, + weekdays, + original_time + FROM reminders + WHERE time = ?; + """, + (time,) + ) + ] + + for reminder in reminders: + cursor.execute(""" + SELECT url + FROM reminder_services rs + INNER JOIN notification_services ns + ON rs.notification_service_id = ns.id + WHERE rs.reminder_id = ?; + """, + (reminder['id'],) + ) + + # Send reminder + a = Apprise() + for url in cursor: + a.add(url['url']) + a.notify(title=reminder["title"], body=reminder["text"]) + + self.thread = None + self.time = None + + if (reminder['repeat_quantity'], reminder['weekdays']) == (None, None): + # Delete the reminder from the database + Reminder(reminder["user_id"], reminder["id"]).delete() + + else: + # Set next time + new_time = _find_next_time( + reminder['original_time'], + RepeatQuantity(reminder['repeat_quantity']), + reminder['repeat_interval'], + [int(d) for d in reminder['weekdays'].split(',')] + if reminder['weekdays'] is not None else + None + ) + cursor.execute( + "UPDATE reminders SET time = ? WHERE id = ?;", + (new_time, reminder['id']) + ) + + self.find_next_reminder() + return + + def find_next_reminder(self, time: int=None) -> None: + """Determine when the soonest reminder is and set the timer to that time + + Args: + time (int, optional): The timestamp to check for. + Otherwise check soonest in database. + Defaults to None. + """ + if not time: + with self.context(): + time = get_db().execute(""" + SELECT DISTINCT r1.time + FROM reminders r1 + LEFT JOIN reminders r2 + ON r1.time > r2.time + WHERE r2.id IS NULL; + """).fetchone() + if time is None: + return + time = time[0] + + if ( + self.thread is None + or time < self.time + ): + if self.thread is not None: + self.thread.cancel() + + t = time - datetime.utcnow().timestamp() + self.thread = Timer( + t, + self.__trigger_reminders, + (time,) + ) + self.thread.name = "ReminderHandler" + self.thread.start() + self.time = time + + return + + def stop_handling(self) -> None: + """Stop the timer if it's active + """ + if self.thread is not None: + self.thread.cancel() + return diff --git a/backend/security.py b/backend/security.py index 15a248d..b5afaa1 100644 --- a/backend/security.py +++ b/backend/security.py @@ -1,5 +1,9 @@ #-*- coding: utf-8 -*- +""" +Hashing and salting +""" + from base64 import urlsafe_b64encode from hashlib import pbkdf2_hmac from secrets import token_bytes @@ -29,7 +33,6 @@ def generate_salt_hash(password: str) -> Tuple[bytes, bytes]: Returns: Tuple[bytes, bytes]: The salt (1) and hashed_password (2) """ - # Hash the password salt = token_bytes() hashed_password = get_hash(salt, password) del password diff --git a/backend/settings.py b/backend/settings.py index 5ffc1c3..89c4649 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -1,5 +1,9 @@ #-*- coding: utf-8 -*- +""" +Getting and setting settings +""" + from backend.custom_exceptions import InvalidKeyValue, KeyNotFound from backend.db import __DATABASE_VERSION__, get_db @@ -85,7 +89,7 @@ def get_admin_settings() -> dict: """ return dict(( (key, _reverse_format_setting(key, value)) - for (key, value) in get_db().execute(""" + for key, value in get_db().execute(""" SELECT key, value FROM config WHERE diff --git a/backend/static_reminders.py b/backend/static_reminders.py index 6adf046..f16f365 100644 --- a/backend/static_reminders.py +++ b/backend/static_reminders.py @@ -2,23 +2,30 @@ import logging from sqlite3 import IntegrityError -from typing import List, Literal +from typing import List, Union from apprise import Apprise from backend.custom_exceptions import (NotificationServiceNotFound, ReminderNotFound) from backend.db import get_db +from backend.helpers import TimelessSortingMethod, search_filter -filter_function = lambda query, p: ( - query in p["title"].lower() - or query in p["text"].lower() -) class StaticReminder: """Represents a static reminder """ def __init__(self, user_id: int, reminder_id: int) -> None: + """Create an instance. + + Args: + user_id (int): The ID of the user. + reminder_id (int): The ID of the reminder. + + Raises: + ReminderNotFound: Reminder with given ID does not exist or is not + owned by user. + """ self.id = reminder_id # Check if reminder exists @@ -27,12 +34,32 @@ class StaticReminder: (self.id, user_id) ).fetchone(): raise ReminderNotFound - + + return + + def _get_notification_services(self) -> List[int]: + """Get ID's of notification services linked to the static reminder. + + Returns: + List[int]: The list with ID's. + """ + result = [ + r[0] + for r in get_db().execute(""" + SELECT notification_service_id + FROM reminder_services + WHERE static_reminder_id = ?; + """, + (self.id,) + ) + ] + return result + def get(self) -> dict: """Get info about the static reminder Returns: - dict: The info about the reminder + dict: The info about the static reminder """ reminder = get_db(dict).execute(""" SELECT @@ -47,28 +74,33 @@ class StaticReminder: ).fetchone() reminder = dict(reminder) - reminder['notification_services'] = list(map(lambda r: r[0], get_db().execute(""" - SELECT notification_service_id - FROM reminder_services - WHERE static_reminder_id = ?; - """, (self.id,)))) + reminder['notification_services'] = self._get_notification_services() return reminder def update( self, - title: str = None, - notification_services: List[int] = None, - text: str = None, - color: str = None + title: Union[str, None] = None, + notification_services: Union[List[int], None] = None, + text: Union[str, None] = None, + color: Union[str, None] = None ) -> dict: - """Edit the static reminder + """Edit the static reminder. Args: - title (str, optional): The new title of the entry. Defaults to None. - notification_services (List[int], optional): The new id's of the notification services to use to send the reminder. Defaults to None. - text (str, optional): The new body of the reminder. Defaults to None. - color (str, optional): The new hex code of the color of the reminder, which is shown in the web-ui. Defaults to None. + title (Union[str, None], optional): The new title of the entry. + Defaults to None. + + notification_services (Union[List[int], None], optional): + The new id's of the notification services to use to send the reminder. + Defaults to None. + + text (Union[str, None], optional): The new body of the reminder. + Defaults to None. + + color (Union[str, None], optional): The new hex code of the color + of the reminder, which is shown in the web-ui. + Defaults to None. Raises: NotificationServiceNotFound: One of the notification services was not found @@ -109,16 +141,27 @@ class StaticReminder: if notification_services: cursor.connection.isolation_level = None cursor.execute("BEGIN TRANSACTION;") - cursor.execute("DELETE FROM reminder_services WHERE static_reminder_id = ?", (self.id,)) + cursor.execute( + "DELETE FROM reminder_services WHERE static_reminder_id = ?", + (self.id,) + ) try: - cursor.executemany( - "INSERT INTO reminder_services(static_reminder_id, notification_service_id) VALUES (?,?)", + cursor.executemany(""" + INSERT INTO reminder_services( + static_reminder_id, + notification_service_id + ) + VALUES (?,?); + """, ((self.id, s) for s in notification_services) ) cursor.execute("COMMIT;") + except IntegrityError: raise NotificationServiceNotFound - cursor.connection.isolation_level = "" + + finally: + cursor.connection.isolation_level = "" return self.get() @@ -132,33 +175,32 @@ class StaticReminder: class StaticReminders: """Represents the static reminder library of the user account """ - sort_functions = { - 'title': (lambda r: (r['title'], r['text'], r['color']), False), - 'title_reversed': (lambda r: (r['title'], r['text'], r['color']), True), - 'date_added': (lambda r: r['id'], False), - 'date_added_reversed': (lambda r: r['id'], True) - } def __init__(self, user_id: int) -> None: + """Create an instance. + + Args: + user_id (int): The ID of the user. + """ self.user_id = user_id + return - def fetchall(self, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]: + def fetchall( + self, + sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE + ) -> List[dict]: """Get all static reminders Args: - sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title". + sort_by (TimelessSortingMethod, optional): How to sort the result. + Defaults to TimelessSortingMethod.TITLE. Returns: - List[dict]: The id, title, text and color of each static reminder + List[dict]: The id, title, text and color of each static reminder. """ - sort_function = self.sort_functions.get( - sort_by, - self.sort_functions['title'] - ) - - reminders: list = list(map( - dict, - get_db(dict).execute(""" + reminders = [ + dict(r) + for r in get_db(dict).execute(""" SELECT id, title, text, @@ -169,29 +211,36 @@ class StaticReminders: """, (self.user_id,) ) - )) - + ] + # Sort result - reminders.sort(key=sort_function[0], reverse=sort_function[1]) + reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1]) return reminders - def search(self, query: str, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]: + def search( + self, + query: str, + sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE + ) -> List[dict]: """Search for static reminders Args: - query (str): The term to search for - sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title". + query (str): The term to search for. + + sort_by (TimelessSortingMethod, optional): The sorting method of + the resulting list. + Defaults to TimelessSortingMethod.TITLE. Returns: - List[dict]: All static reminders that match. Similar output to self.fetchall - """ - query = query.lower() - reminders = list(filter( - lambda p: filter_function(query, p), - self.fetchall(sort_by) - )) - return reminders + List[dict]: All static reminders that match. + Similar output to `self.fetchall` + """ + static_reminders = [ + r for r in self.fetchall(sort_by) + if search_filter(query, r) + ] + return static_reminders def fetchone(self, id: int) -> StaticReminder: """Get one static reminder @@ -214,22 +263,32 @@ class StaticReminders: """Add a static reminder Args: - title (str): The title of the entry - notification_services (List[int]): The id's of the notification services to use to send the reminder. - text (str, optional): The body of the reminder. Defaults to ''. - color (str, optional): The hex code of the color of the reminder, which is shown in the web-ui. Defaults to None. + title (str): The title of the entry. + + notification_services (List[int]): The id's of the + notification services to use to send the reminder. + + text (str, optional): The body of the reminder. + Defaults to ''. + + color (str, optional): The hex code of the color of the template, + which is shown in the web-ui. + Defaults to None. Raises: NotificationServiceNotFound: One of the notification services was not found Returns: - StaticReminder: A StaticReminder instance representing the newly created static reminder + StaticReminder: The info about the static reminder """ logging.info( f'Adding static reminder with {title=}, {notification_services=}, {text=}, {color=}' ) cursor = get_db() + cursor.connection.isolation_level = None + cursor.execute("BEGIN TRANSACTION;") + id = cursor.execute(""" INSERT INTO static_reminders(user_id, title, text, color) VALUES (?,?,?,?); @@ -238,13 +297,22 @@ class StaticReminders: ).lastrowid try: - cursor.executemany( - "INSERT INTO reminder_services(static_reminder_id, notification_service_id) VALUES (?, ?);", + cursor.executemany(""" + INSERT INTO reminder_services( + static_reminder_id, + notification_service_id + ) + VALUES (?, ?); + """, ((id, service) for service in notification_services) ) + cursor.execute("COMMIT;") + except IntegrityError: raise NotificationServiceNotFound - + finally: + cursor.connection.isolation_level = "" + return self.fetchone(id) def trigger_reminder(self, id: int) -> None: @@ -265,7 +333,9 @@ class StaticReminders: id = ? AND user_id = ? LIMIT 1; - """, (id, self.user_id)).fetchone() + """, + (id, self.user_id) + ).fetchone() if not reminder: raise ReminderNotFound reminder = dict(reminder) @@ -277,7 +347,9 @@ class StaticReminders: INNER JOIN notification_services ns ON rs.notification_service_id = ns.id WHERE rs.static_reminder_id = ?; - """, (id,)) + """, + (id,) + ) for url in cursor: a.add(url['url']) a.notify(title=reminder['title'], body=reminder['text']) diff --git a/backend/templates.py b/backend/templates.py index 177a037..f151a68 100644 --- a/backend/templates.py +++ b/backend/templates.py @@ -2,21 +2,28 @@ import logging from sqlite3 import IntegrityError -from typing import List, Literal +from typing import List, Union from backend.custom_exceptions import (NotificationServiceNotFound, TemplateNotFound) from backend.db import get_db +from backend.helpers import TimelessSortingMethod, search_filter -filter_function = lambda query, p: ( - query in p["title"].lower() - or query in p["text"].lower() -) class Template: """Represents a template """ - def __init__(self, user_id: int, template_id: int): + def __init__(self, user_id: int, template_id: int) -> None: + """Create instance of class. + + Args: + user_id (int): The ID of the user. + template_id (int): The ID of the template. + + Raises: + TemplateNotFound: Template with given ID does not exist or is not + owned by user. + """ self.id = template_id exists = get_db().execute( @@ -25,7 +32,26 @@ class Template: ).fetchone() if not exists: raise TemplateNotFound - + return + + def _get_notification_services(self) -> List[int]: + """Get ID's of notification services linked to the template. + + Returns: + List[int]: The list with ID's. + """ + result = [ + r[0] + for r in get_db().execute(""" + SELECT notification_service_id + FROM reminder_services + WHERE template_id = ?; + """, + (self.id,) + ) + ] + return result + def get(self) -> dict: """Get info about the template @@ -45,27 +71,35 @@ class Template: ).fetchone() template = dict(template) - template['notification_services'] = list(map(lambda r: r[0], get_db().execute(""" - SELECT notification_service_id - FROM reminder_services - WHERE template_id = ?; - """, (self.id,)))) + template['notification_services'] = self._get_notification_services() return template def update(self, - title: str = None, - notification_services: List[int] = None, - text: str = None, - color: str = None + title: Union[str, None] = None, + notification_services: Union[List[int], None] = None, + text: Union[str, None] = None, + color: Union[str, None] = None ) -> dict: """Edit the template Args: - title (str): The new title of the entry. Defaults to None. - notification_services (List[int]): The new id's of the notification services to use to send the reminder. Defaults to None. - text (str, optional): The new body of the template. Defaults to None. - color (str, optional): The new hex code of the color of the template, which is shown in the web-ui. Defaults to None. + title (Union[str, None]): The new title of the entry. + Defaults to None. + + notification_services (Union[List[int], None]): The new id's of the + notification services to use to send the reminder. + Defaults to None. + + text (Union[str, None], optional): The new body of the template. + Defaults to None. + + color (Union[str, None], optional): The new hex code of the color of the template, + which is shown in the web-ui. + Defaults to None. + + Raises: + NotificationServiceNotFound: One of the notification services was not found Returns: dict: The new template info @@ -101,16 +135,27 @@ class Template: if notification_services: cursor.connection.isolation_level = None cursor.execute("BEGIN TRANSACTION;") - cursor.execute("DELETE FROM reminder_services WHERE template_id = ?", (self.id,)) + cursor.execute( + "DELETE FROM reminder_services WHERE template_id = ?", + (self.id,) + ) try: - cursor.executemany( - "INSERT INTO reminder_services(template_id, notification_service_id) VALUES (?,?)", + cursor.executemany(""" + INSERT INTO reminder_services( + template_id, + notification_service_id + ) + VALUES (?,?); + """, ((self.id, s) for s in notification_services) ) cursor.execute("COMMIT;") + except IntegrityError: raise NotificationServiceNotFound - cursor.connection.isolation_level = "" + + finally: + cursor.connection.isolation_level = "" return self.get() @@ -124,63 +169,72 @@ class Template: class Templates: """Represents the template library of the user account """ - sort_functions = { - 'title': (lambda r: (r['title'], r['text'], r['color']), False), - 'title_reversed': (lambda r: (r['title'], r['text'], r['color']), True), - 'date_added': (lambda r: r['id'], False), - 'date_added_reversed': (lambda r: r['id'], True) - } - def __init__(self, user_id: int): - self.user_id = user_id - - def fetchall(self, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]: - """Get all templates + def __init__(self, user_id: int) -> None: + """Create an instance. Args: - sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title". + user_id (int): The ID of the user. + """ + self.user_id = user_id + return + + def fetchall( + self, + sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE + ) -> List[dict]: + """Get all templates of the user. + + Args: + sort_by (TimelessSortingMethod, optional): The sorting method of + the resulting list. + Defaults to TimelessSortingMethod.TITLE. Returns: - List[dict]: The id, title, text and color + List[dict]: The id, title, text and color of each template. """ - sort_function = self.sort_functions.get( - sort_by, - self.sort_functions['title'] - ) - - templates: list = list(map(dict, get_db(dict).execute(""" - SELECT - id, - title, text, - color - FROM templates - WHERE user_id = ? - ORDER BY title, id; - """, - (self.user_id,) - ))) + templates = [ + dict(r) + for r in get_db(dict).execute(""" + SELECT + id, + title, text, + color + FROM templates + WHERE user_id = ? + ORDER BY title, id; + """, + (self.user_id,) + ) + ] # Sort result - templates.sort(key=sort_function[0], reverse=sort_function[1]) + templates.sort(key=sort_by.value[0], reverse=sort_by.value[1]) return templates - def search(self, query: str, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]: + def search( + self, + query: str, + sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE + ) -> List[dict]: """Search for templates Args: - query (str): The term to search for - sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title". + query (str): The term to search for. + + sort_by (TimelessSortingMethod, optional): The sorting method of + the resulting list. + Defaults to TimelessSortingMethod.TITLE. Returns: - List[dict]: All templates that match. Similar output to self.fetchall + List[dict]: All templates that match. Similar output to `self.fetchall` """ - query = query.lower() - reminders = list(filter( - lambda p: filter_function(query, p), - self.fetchall(sort_by) - )) - return reminders + templates = [ + r for r in self.fetchall(sort_by) + if search_filter(query, r) + ] + return templates def fetchone(self, id: int) -> Template: """Get one template @@ -203,10 +257,20 @@ class Templates: """Add a template Args: - title (str): The title of the entry - notification_services (List[int]): The id's of the notification services to use to send the reminder. - text (str, optional): The body of the reminder. Defaults to ''. - color (str, optional): The hex code of the color of the template, which is shown in the web-ui. Defaults to None. + title (str): The title of the entry. + + notification_services (List[int]): The id's of the + notification services to use to send the reminder. + + text (str, optional): The body of the reminder. + Defaults to ''. + + color (str, optional): The hex code of the color of the template, + which is shown in the web-ui. + Defaults to None. + + Raises: + NotificationServiceNotFound: One of the notification services was not found Returns: Template: The info about the template @@ -216,19 +280,32 @@ class Templates: ) cursor = get_db() + cursor.connection.isolation_level = None + cursor.execute("BEGIN TRANSACTION;") + id = cursor.execute(""" INSERT INTO templates(user_id, title, text, color) VALUES (?,?,?,?); """, (self.user_id, title, text, color) ).lastrowid - + try: - cursor.executemany( - "INSERT INTO reminder_services(template_id, notification_service_id) VALUES (?, ?);", + cursor.executemany(""" + INSERT INTO reminder_services( + template_id, + notification_service_id + ) + VALUES (?, ?); + """, ((id, service) for service in notification_services) ) + cursor.execute("COMMIT;") + except IntegrityError: raise NotificationServiceNotFound + finally: + cursor.connection.isolation_level = "" + return self.fetchone(id) diff --git a/backend/users.py b/backend/users.py index 732b466..e881ae2 100644 --- a/backend/users.py +++ b/backend/users.py @@ -1,7 +1,7 @@ #-*- coding: utf-8 -*- import logging -from typing import List, Union +from typing import List from backend.custom_exceptions import (AccessUnauthorized, NewAccountsNotAllowed, UsernameInvalid, @@ -19,32 +19,28 @@ ONEPASS_INVALID_USERNAMES = ['reminders', 'api'] class User: """Represents an user account - """ - def __init__(self, username: str, password: Union[str, None]=None): - # Fetch data of user to check if user exists and to check if password is correct + """ + + def __init__(self, id: int) -> None: result = get_db(dict).execute( - "SELECT id, salt, hash, admin FROM users WHERE username = ? LIMIT 1;", - (username,) + "SELECT username, admin FROM users WHERE id = ? LIMIT 1;", + (id,) ).fetchone() if not result: raise UserNotFound - self.username = username - self.salt = result['salt'] - self.user_id = result['id'] - self.admin = result['admin'] == 1 - # Check password - if password is not None: - hash_password = get_hash(result['salt'], password) - if not hash_password == result['hash']: - raise AccessUnauthorized - + self.username: str = result['username'] + self.user_id = id + self.admin: bool = result['admin'] == 1 + return + @property def reminders(self) -> Reminders: """Get access to the reminders of the user account Returns: - Reminders: Reminders instance that can be used to access the reminders of the user account + Reminders: Reminders instance that can be used to access the + reminders of the user account """ if not hasattr(self, 'reminders_instance'): self.reminders_instance = Reminders(self.user_id) @@ -55,7 +51,8 @@ class User: """Get access to the notification services of the user account Returns: - NotificationServices: NotificationServices instance that can be used to access the notification services of the user account + NotificationServices: NotificationServices instance that can be used + to access the notification services of the user account """ if not hasattr(self, 'notification_services_instance'): self.notification_services_instance = NotificationServices(self.user_id) @@ -66,7 +63,8 @@ class User: """Get access to the templates of the user account Returns: - Templates: Templates instance that can be used to access the templates of the user account + Templates: Templates instance that can be used to access the + templates of the user account """ if not hasattr(self, 'templates_instance'): self.templates_instance = Templates(self.user_id) @@ -77,7 +75,8 @@ class User: """Get access to the static reminders of the user account Returns: - StaticReminders: StaticReminders instance that can be used to access the static reminders of the user account + StaticReminders: StaticReminders instance that can be used to + access the static reminders of the user account """ if not hasattr(self, 'static_reminders_instance'): self.static_reminders_instance = StaticReminders(self.user_id) @@ -103,121 +102,133 @@ class User: def delete(self) -> None: """Delete the user account """ + if self.username == 'admin': + raise UserNotFound + logging.info(f'Deleting the user {self.username} ({self.user_id})') cursor = get_db() - cursor.execute("DELETE FROM reminders WHERE user_id = ?", (self.user_id,)) - cursor.execute("DELETE FROM templates WHERE user_id = ?", (self.user_id,)) - cursor.execute("DELETE FROM static_reminders WHERE user_id = ?", (self.user_id,)) - cursor.execute("DELETE FROM notification_services WHERE user_id = ?", (self.user_id,)) - cursor.execute("DELETE FROM users WHERE id = ?", (self.user_id,)) + cursor.execute( + "DELETE FROM reminders WHERE user_id = ?", + (self.user_id,) + ) + cursor.execute( + "DELETE FROM templates WHERE user_id = ?", + (self.user_id,) + ) + cursor.execute( + "DELETE FROM static_reminders WHERE user_id = ?", + (self.user_id,) + ) + cursor.execute( + "DELETE FROM notification_services WHERE user_id = ?", + (self.user_id,) + ) + cursor.execute( + "DELETE FROM users WHERE id = ?", + (self.user_id,) + ) return -def _check_username(username: str) -> None: - """Check if username is valid +class Users: + def _check_username(self, username: str) -> None: + """Check if username is valid - Args: - username (str): The username to check + Args: + username (str): The username to check - Raises: - UsernameInvalid: The username is not valid - """ - logging.debug(f'Checking the username {username}') - if username in ONEPASS_INVALID_USERNAMES or username.isdigit(): - raise UsernameInvalid - if list(filter(lambda c: not c in ONEPASS_USERNAME_CHARACTERS, username)): - raise UsernameInvalid - return - -def register_user(username: str, password: str, from_admin: bool=False) -> int: - """Add a user - - Args: - username (str): The username of the new user - password (str): The password of the new user - from_admin (bool, optional): Skip check if new accounts are allowed. - Defaults to False. - - Raises: - UsernameInvalid: Username not allowed or contains invalid characters - UsernameTaken: Username is already taken; usernames must be unique - NewAccountsNotAllowed: In the admin panel, new accounts are set to be - not allowed. - - Returns: - user_id (int): The id of the new user. User registered successful - """ - logging.info(f'Registering user with username {username}') - - if not from_admin and not get_setting('allow_new_accounts'): - raise NewAccountsNotAllowed - - # Check if username is valid - _check_username(username) - - cursor = get_db() - - # Check if username isn't already taken - if cursor.execute( - "SELECT 1 FROM users WHERE username = ? LIMIT 1", (username,) - ).fetchone(): - raise UsernameTaken - - # Generate salt and key exclusive for user - salt, hashed_password = generate_salt_hash(password) - del password - - # Add user to userlist - user_id = cursor.execute( + Raises: + UsernameInvalid: The username is not valid """ - INSERT INTO users(username, salt, hash) - VALUES (?,?,?); - """, - (username, salt, hashed_password) - ).lastrowid + logging.debug(f'Checking the username {username}') + if username in ONEPASS_INVALID_USERNAMES or username.isdigit(): + raise UsernameInvalid(username) + if list(filter(lambda c: not c in ONEPASS_USERNAME_CHARACTERS, username)): + raise UsernameInvalid(username) + return - logging.debug(f'Newly registered user has id {user_id}') - return user_id + def __contains__(self, username: str) -> bool: + result = get_db().execute( + "SELECT 1 FROM users WHERE username = ? LIMIT 1;", + (username,) + ).fetchone() + return result is not None -def get_users() -> List[dict]: - """Get all user info for the admin + def add(self, username: str, password: str, from_admin: bool=False) -> int: + """Add a user - Returns: - List[dict]: The info about all users - """ - result = [ - dict(u) - for u in get_db(dict).execute( - "SELECT id, username, admin FROM users ORDER BY username;" - ) - ] - return result + Args: + username (str): The username of the new user + password (str): The password of the new user + from_admin (bool, optional): Skip check if new accounts are allowed. + Defaults to False. -def edit_user_password(id: int, new_password: str) -> None: - """Change the password of a user for the admin + Raises: + UsernameInvalid: Username not allowed or contains invalid characters + UsernameTaken: Username is already taken; usernames must be unique + NewAccountsNotAllowed: In the admin panel, new accounts are set to be + not allowed. - Args: - id (int): The ID of the user to change the password of - new_password (str): The new password to set for the user - """ - username = (get_db().execute( - "SELECT username FROM users WHERE id = ? LIMIT 1;", - (id,) - ).fetchone() or [''])[0] - User(username).edit_password(new_password) - return + Returns: + int: The id of the new user. User registered successful + """ + logging.info(f'Registering user with username {username}') + + if not from_admin and not get_setting('allow_new_accounts'): + raise NewAccountsNotAllowed + + # Check if username is valid + self._check_username(username) -def delete_user(id: int) -> None: - """Delete a user for the admin + cursor = get_db() - Args: - id (int): The ID of the user to delete - """ - username = (get_db().execute( - "SELECT username FROM users WHERE id = ? LIMIT 1;", - (id,) - ).fetchone() or [''])[0] - if username == 'admin': - raise UserNotFound - User(username).delete() - return + # Check if username isn't already taken + if username in self: + raise UsernameTaken + + # Generate salt and key exclusive for user + salt, hashed_password = generate_salt_hash(password) + del password + + # Add user to userlist + user_id = cursor.execute( + """ + INSERT INTO users(username, salt, hash) + VALUES (?,?,?); + """, + (username, salt, hashed_password) + ).lastrowid + + logging.debug(f'Newly registered user has id {user_id}') + return user_id + + def get_all(self) -> List[dict]: + """Get all user info for the admin + + Returns: + List[dict]: The info about all users + """ + result = [ + dict(u) + for u in get_db(dict).execute( + "SELECT id, username, admin FROM users ORDER BY username;" + ) + ] + return result + + def login(self, username: str, password: str) -> User: + result = get_db(dict).execute( + "SELECT id, salt, hash FROM users WHERE username = ? LIMIT 1;", + (username,) + ).fetchone() + if not result: + raise UserNotFound + + hash_password = get_hash(result['salt'], password) + if not hash_password == result['hash']: + raise AccessUnauthorized + + return User(result['id']) + + def get_one(self, id: int) -> User: + return User(id) diff --git a/frontend/api.py b/frontend/api.py index 55e2c57..a353c01 100644 --- a/frontend/api.py +++ b/frontend/api.py @@ -1,16 +1,13 @@ #-*- coding: utf-8 -*- -from io import BytesIO import logging -from abc import ABC, abstractmethod +from dataclasses import dataclass +from io import BytesIO from os import urandom -from re import compile from time import time as epoch_time -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, Tuple -from apprise import Apprise -from flask import Blueprint, g, request, send_file -from flask.sansio.scaffold import T_route +from flask import g, request, send_file from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired, APIKeyInvalid, InvalidKeyValue, @@ -22,378 +19,39 @@ from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired, UsernameInvalid, UsernameTaken, UserNotFound) from backend.db import DBConnection -from backend.notification_service import (NotificationService, - NotificationServices, - get_apprise_services) -from backend.reminders import Reminders, reminder_handler -from backend.settings import (_format_setting, get_admin_settings, get_setting, - set_setting) -from backend.static_reminders import StaticReminders -from backend.templates import Template, Templates -from backend.users import (User, delete_user, edit_user_password, get_users, - register_user) - -#=================== -# Input validation -#=================== -color_regex = compile(r'#[0-9a-f]{6}') - -class DataSource: - DATA = 1 - VALUES = 2 - -class InputVariable(ABC): - @abstractmethod - def __init__(self, value: Any) -> None: - pass - - @property - @abstractmethod - def name() -> str: - pass - - @abstractmethod - def validate(self) -> bool: - pass - - @property - @abstractmethod - def required() -> bool: - pass - - @property - @abstractmethod - def default() -> Any: - pass - - @property - @abstractmethod - def source() -> int: - pass - - @property - @abstractmethod - def description() -> str: - pass - - @property - @abstractmethod - def related_exceptions() -> List[Exception]: - pass - -class DefaultInputVariable(InputVariable): - source = DataSource.DATA - required = True - default = None - related_exceptions = [] - - def __init__(self, value: Any) -> None: - self.value = value - - def validate(self) -> bool: - return isinstance(self.value, str) and self.value - - def __repr__(self) -> str: - return f'| {self.name} | {"Yes" if self.required else "No"} | {self.description} | N/A |' - -class NonRequiredVersion(InputVariable): - required = False - - def validate(self) -> bool: - return self.value is None or super().validate() - -class UsernameVariable(DefaultInputVariable): - name = 'username' - description = 'The username of the user account' - related_exceptions = [KeyNotFound, UserNotFound] - -class PasswordVariable(DefaultInputVariable): - name = 'password' - description = 'The password of the user account' - related_exceptions = [KeyNotFound, AccessUnauthorized] - -class NewPasswordVariable(PasswordVariable): - name = 'new_password' - description = 'The new password of the user account' - related_exceptions = [KeyNotFound] - -class UsernameCreateVariable(UsernameVariable): - related_exceptions = [ - KeyNotFound, - UsernameInvalid, UsernameTaken, - NewAccountsNotAllowed - ] - -class PasswordCreateVariable(PasswordVariable): - related_exceptions = [KeyNotFound] - -class TitleVariable(DefaultInputVariable): - name = 'title' - description = 'The title of the entry' - related_exceptions = [KeyNotFound] - -class URLVariable(DefaultInputVariable): - name = 'url' - description = 'The Apprise URL of the notification service' - related_exceptions = [KeyNotFound, InvalidKeyValue] - - def validate(self) -> bool: - return Apprise().add(self.value) - -class EditTitleVariable(NonRequiredVersion, TitleVariable): - related_exceptions = [] - -class EditURLVariable(NonRequiredVersion, URLVariable): - related_exceptions = [InvalidKeyValue] - -class SortByVariable(DefaultInputVariable): - name = 'sort_by' - description = 'How to sort the result' - required = False - source = DataSource.VALUES - _options = Reminders.sort_functions - default = next(iter(Reminders.sort_functions)) - related_exceptions = [InvalidKeyValue] - - def __init__(self, value: str) -> None: - self.value = value - - def validate(self) -> bool: - return self.value in self._options - - def __repr__(self) -> str: - return '| {n} | {r} | {d} | {v} |'.format( - n=self.name, - r="Yes" if self.required else "No", - d=self.description, - v=", ".join(f'`{o}`' for o in self._options) - ) - -class TemplateSortByVariable(SortByVariable): - _options = Templates.sort_functions - default = next(iter(Templates.sort_functions)) - -class StaticReminderSortByVariable(TemplateSortByVariable): - _options = StaticReminders.sort_functions - default = next(iter(StaticReminders.sort_functions)) - -class TimeVariable(DefaultInputVariable): - name = 'time' - description = 'The UTC epoch timestamp that the reminder should be sent at' - related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime] - - def validate(self) -> bool: - return isinstance(self.value, (float, int)) - -class EditTimeVariable(NonRequiredVersion, TimeVariable): - related_exceptions = [InvalidKeyValue, InvalidTime] - -class NotificationServicesVariable(DefaultInputVariable): - name = 'notification_services' - description = "Array of the id's of the notification services to use to send the notification" - related_exceptions = [KeyNotFound, InvalidKeyValue, NotificationServiceNotFound] - - def validate(self) -> bool: - if not isinstance(self.value, list): - return False - if not self.value: - return False - for v in self.value: - if not isinstance(v, int): - return False - return True - -class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable): - related_exceptions = [InvalidKeyValue, NotificationServiceNotFound] - -class TextVariable(NonRequiredVersion, DefaultInputVariable): - name = 'text' - description = 'The body of the entry' - default = '' - - def validate(self) -> bool: - return isinstance(self.value, str) - -class RepeatQuantityVariable(DefaultInputVariable): - name = 'repeat_quantity' - description = 'The quantity of the repeat_interval' - required = False - _options = ("years", "months", "weeks", "days", "hours", "minutes") - default = None - related_exceptions = [InvalidKeyValue] - - def validate(self) -> bool: - return self.value is None or self.value in self._options - - def __repr__(self) -> str: - return '| {n} | {r} | {d} | {v} |'.format( - n=self.name, - r="Yes" if self.required else "No", - d=self.description, - v=", ".join(f'`{o}`' for o in self._options) - ) - -class RepeatIntervalVariable(DefaultInputVariable): - name = 'repeat_interval' - description = 'The number of the interval' - required = False - default = None - related_exceptions = [InvalidKeyValue] - - def validate(self) -> bool: - return self.value is None or (isinstance(self.value, int) and self.value > 0) - -class WeekDaysVariable(DefaultInputVariable): - name = 'weekdays' - description = 'On which days of the weeks to run the reminder' - required = False - default = None - related_exceptions = [InvalidKeyValue] - _options = {0, 1, 2, 3, 4, 5, 6} - - def validate(self) -> bool: - return self.value is None or ( - isinstance(self.value, list) - and len(self.value) > 0 - and all(v in self._options for v in self.value) - ) - -class ColorVariable(DefaultInputVariable): - name = 'color' - description = 'The hex code of the color of the entry, which is shown in the web-ui' - required = False - default = None - related_exceptions = [InvalidKeyValue] - - def validate(self) -> bool: - return self.value is None or color_regex.search(self.value) - -class QueryVariable(DefaultInputVariable): - name = 'query' - description = 'The search term' - source = DataSource.VALUES - -class AdminSettingsVariable(DefaultInputVariable): - related_exceptions = [KeyNotFound, InvalidKeyValue] - - def validate(self) -> bool: - try: - _format_setting(self.name, self.value) - except InvalidKeyValue: - return False - return True - -class AllowNewAccountsVariable(AdminSettingsVariable): - name = 'allow_new_accounts' - description = ('Whether or not to allow users to register a new account. ' - + 'The admin can always add a new account.') - -class LoginTimeVariable(AdminSettingsVariable): - name = 'login_time' - description = ('How long a user stays logged in, in seconds. ' - + 'Between 1 min and 1 month (60 <= sec <= 2592000)') - -class LoginTimeResetVariable(AdminSettingsVariable): - name = 'login_time_reset' - description = 'If the Login Time timer should reset with each API request.' - -def input_validation() -> Union[None, Dict[str, Any]]: - """Checks, extracts and transforms inputs - - Raises: - KeyNotFound: A required key was not supplied - InvalidKeyValue: The value of a key is not valid - - Returns: - Union[None, Dict[str, Any]]: `None` if the endpoint + method doesn't require input variables. - Otherwise `Dict[str, Any]` with the input variables, checked and formatted. - """ - inputs = {} - - input_variables: Dict[str, List[Union[List[InputVariable], str]]] - if request.path.startswith(admin_api_prefix): - input_variables = api_docs[ - _admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1] - ]['input_variables'] - else: - input_variables = api_docs[ - request.url_rule.rule.split(api_prefix)[1] - ]['input_variables'] - - if not input_variables: - return - - if input_variables.get(request.method) is None: - return inputs - - given_variables = {} - given_variables[DataSource.DATA] = request.get_json() if request.data else {} - given_variables[DataSource.VALUES] = request.values - for input_variable in input_variables[request.method]: - if ( - input_variable.required and - not input_variable.name in given_variables[input_variable.source] - ): - raise KeyNotFound(input_variable.name) - - input_value = given_variables[input_variable.source].get( - input_variable.name, - input_variable.default - ) - - if not input_variable(input_value).validate(): - raise InvalidKeyValue(input_variable.name, input_value) - - inputs[input_variable.name] = input_value - return inputs +from backend.notification_service import get_apprise_services +from backend.settings import get_admin_settings, get_setting, set_setting +from backend.users import User, Users +from frontend.input_validation import (AllowNewAccountsVariable, ColorVariable, + EditNotificationServicesVariable, + EditTimeVariable, EditTitleVariable, + EditURLVariable, LoginTimeResetVariable, + LoginTimeVariable, NewPasswordVariable, + NotificationServicesVariable, + PasswordCreateVariable, + PasswordVariable, QueryVariable, + RepeatIntervalVariable, + RepeatQuantityVariable, SortByVariable, + StaticReminderSortByVariable, + TemplateSortByVariable, TextVariable, + TimeVariable, TitleVariable, + URLVariable, UsernameCreateVariable, + UsernameVariable, WeekDaysVariable, + _admin_api_prefix, admin_api, + admin_api_prefix, api, api_docs, + api_prefix, input_validation) #=================== # General variables and functions #=================== -api_docs: Dict[str, Dict[str, Any]] = {} -class APIBlueprint(Blueprint): - def route( - self, - rule: str, - description: str = '', - input_variables: Dict[str, List[Union[List[InputVariable], str]]] = {}, - requires_auth: bool = True, - **options: Any - ) -> Callable[[T_route], T_route]: +@dataclass +class ApiKeyEntry: + exp: int + user_data: User - if self == api: - processed_rule = rule - elif self == admin_api: - processed_rule = _admin_api_prefix + rule - else: - raise NotImplementedError - - api_docs[processed_rule] = { - 'endpoint': processed_rule, - 'description': description, - 'requires_auth': requires_auth, - 'methods': options['methods'], - 'input_variables': { - k: v[0] - for k, v in input_variables.items() - if v and v[0] - }, - 'method_descriptions': { - k: v[1] - for k, v in input_variables.items() - if v and len(v) == 2 and v[1] - } - } - - return super().route(rule, **options) - -api_prefix = "/api" -_admin_api_prefix = '/admin' -admin_api_prefix = api_prefix + _admin_api_prefix -api = APIBlueprint('api', __name__) -admin_api = APIBlueprint('admin_api', __name__) -api_key_map = {} +users = Users() +api_key_map: Dict[int, ApiKeyEntry] = {} def return_api(result: Any, error: str=None, code: int=200) -> Tuple[dict, int]: return {'error': error, 'result': result}, code @@ -409,33 +67,37 @@ def auth() -> None: if not hashed_api_key in api_key_map: raise APIKeyInvalid - if not ( - ( - api_key_map[hashed_api_key]['user_data'].admin - and request.path.startswith((admin_api_prefix, api_prefix + '/auth')) - ) - or - ( - not api_key_map[hashed_api_key]['user_data'].admin - and not request.path.startswith(admin_api_prefix) - ) + map_entry = api_key_map[hashed_api_key] + + if ( + map_entry.user_data.admin + and + not request.path.startswith((admin_api_prefix, api_prefix + '/auth')) + ): + raise APIKeyInvalid + + if ( + not map_entry.user_data.admin + and + request.path.startswith(admin_api_prefix) ): raise APIKeyInvalid - exp = api_key_map[hashed_api_key]['exp'] - if exp <= epoch_time(): + if map_entry.exp <= epoch_time(): raise APIKeyExpired # Api key valid if get_setting('login_time_reset'): - api_key_map[hashed_api_key]['exp'] = exp = ( + g.exp = map_entry.exp = ( epoch_time() + get_setting('login_time') ) + else: + g.exp = map_entry.exp g.hashed_api_key = hashed_api_key - g.exp = exp - g.user_data = api_key_map[hashed_api_key]['user_data'] + g.user_data = map_entry.user_data + return def endpoint_wrapper(method: Callable) -> Callable: @@ -458,14 +120,15 @@ def endpoint_wrapper(method: Callable) -> Callable: return method(*args, **kwargs) return method(inputs, *args, **kwargs) - except (UsernameTaken, UsernameInvalid, UserNotFound, - AccessUnauthorized, - ReminderNotFound, NotificationServiceNotFound, - NotificationServiceInUse, InvalidTime, - KeyNotFound, InvalidKeyValue, - APIKeyInvalid, APIKeyExpired, - TemplateNotFound, - NewAccountsNotAllowed) as e: + except (AccessUnauthorized, APIKeyExpired, + APIKeyInvalid, InvalidKeyValue, + InvalidTime, KeyNotFound, + NewAccountsNotAllowed, + NotificationServiceInUse, + NotificationServiceNotFound, + ReminderNotFound, TemplateNotFound, + UsernameInvalid, UsernameTaken, + UserNotFound) as e: return return_api(**e.api_response) wrapper.__name__ = method.__name__ @@ -484,7 +147,7 @@ def endpoint_wrapper(method: Callable) -> Callable: ) @endpoint_wrapper def api_login(inputs: Dict[str, str]): - user = User(inputs['username'], inputs['password']) + user = users.login(inputs['username'], inputs['password']) # Generate an API key until one # is generated that isn't used already @@ -496,12 +159,7 @@ def api_login(inputs: Dict[str, str]): login_time = get_setting('login_time') exp = epoch_time() + login_time - api_key_map.update({ - hashed_api_key: { - 'exp': exp, - 'user_data': user - } - }) + api_key_map[hashed_api_key] = ApiKeyEntry(exp, user) result = {'api_key': api_key, 'expires': exp, 'admin': user.admin} return return_api(result, code=201) @@ -523,17 +181,17 @@ def api_logout(): ) @endpoint_wrapper def api_status(): + map_entry = api_key_map[g.hashed_api_key] result = { - 'expires': api_key_map[g.hashed_api_key]['exp'], - 'username': api_key_map[g.hashed_api_key]['user_data'].username, - 'admin': api_key_map[g.hashed_api_key]['user_data'].admin + 'expires': map_entry.exp, + 'username': map_entry.user_data.username, + 'admin': map_entry.user_data.admin } return return_api(result) #=================== # User endpoints #=================== - @api.route( '/user/add', 'Create a new user account', @@ -543,7 +201,7 @@ def api_status(): ) @endpoint_wrapper def api_add_user(inputs: Dict[str, str]): - register_user(inputs['username'], inputs['password']) + users.add(inputs['username'], inputs['password']) return return_api({}, code=201) @api.route( @@ -557,12 +215,13 @@ def api_add_user(inputs: Dict[str, str]): ) @endpoint_wrapper def api_manage_user(inputs: Dict[str, str]): + user = api_key_map[g.hashed_api_key].user_data if request.method == 'PUT': - g.user_data.edit_password(inputs['new_password']) + user.edit_password(inputs['new_password']) return return_api({}) elif request.method == 'DELETE': - g.user_data.delete() + user.delete() api_key_map.pop(g.hashed_api_key) return return_api({}) @@ -581,7 +240,7 @@ def api_manage_user(inputs: Dict[str, str]): ) @endpoint_wrapper def api_notification_services_list(inputs: Dict[str, str]): - services: NotificationServices = g.user_data.notification_services + services = api_key_map[g.hashed_api_key].user_data.notification_services if request.method == 'GET': result = services.fetchall() @@ -610,7 +269,10 @@ def api_notification_service_available(): ) @endpoint_wrapper def api_test_service(inputs: Dict[str, Any]): - g.user_data.notification_services.test_service(inputs['url']) + (api_key_map[g.hashed_api_key] + .user_data + .notification_services + .test_service(inputs['url'])) return return_api({}, code=201) @api.route( @@ -624,17 +286,20 @@ def api_test_service(inputs: Dict[str, Any]): ) @endpoint_wrapper def api_notification_service(inputs: Dict[str, str], n_id: int): - service: NotificationService = g.user_data.notification_services.fetchone(n_id) - + service = (api_key_map[g.hashed_api_key] + .user_data + .notification_services + .fetchone(n_id)) + if request.method == 'GET': result = service.get() return return_api(result) - + elif request.method == 'PUT': result = service.update(title=inputs['title'], url=inputs['url']) return return_api(result) - + elif request.method == 'DELETE': service.delete() return return_api({}) @@ -659,7 +324,7 @@ def api_notification_service(inputs: Dict[str, str], n_id: int): ) @endpoint_wrapper def api_reminders_list(inputs: Dict[str, Any]): - reminders: Reminders = g.user_data.reminders + reminders = api_key_map[g.hashed_api_key].user_data.reminders if request.method == 'GET': result = reminders.fetchall(inputs['sort_by']) @@ -684,7 +349,10 @@ def api_reminders_list(inputs: Dict[str, Any]): ) @endpoint_wrapper def api_reminders_query(inputs: Dict[str, str]): - result = g.user_data.reminders.search(inputs['query'], inputs['sort_by']) + result = (api_key_map[g.hashed_api_key] + .user_data + .reminders + .search(inputs['query'], inputs['sort_by'])) return return_api(result) @api.route( @@ -696,7 +364,11 @@ def api_reminders_query(inputs: Dict[str, str]): ) @endpoint_wrapper def api_test_reminder(inputs: Dict[str, Any]): - g.user_data.reminders.test_reminder(inputs['title'], inputs['notification_services'], inputs['text']) + api_key_map[g.hashed_api_key].user_data.reminders.test_reminder( + inputs['title'], + inputs['notification_services'], + inputs['text'] + ) return return_api({}, code=201) @api.route( @@ -714,7 +386,8 @@ def api_test_reminder(inputs: Dict[str, Any]): ) @endpoint_wrapper def api_get_reminder(inputs: Dict[str, Any], r_id: int): - reminders: Reminders = g.user_data.reminders + reminders = api_key_map[g.hashed_api_key].user_data.reminders + if request.method == 'GET': result = reminders.fetchone(r_id).get() return return_api(result) @@ -750,7 +423,7 @@ def api_get_reminder(inputs: Dict[str, Any], r_id: int): ) @endpoint_wrapper def api_get_templates(inputs: Dict[str, Any]): - templates: Templates = g.user_data.templates + templates = api_key_map[g.hashed_api_key].user_data.templates if request.method == 'GET': result = templates.fetchall(inputs['sort_by']) @@ -771,7 +444,10 @@ def api_get_templates(inputs: Dict[str, Any]): ) @endpoint_wrapper def api_templates_query(inputs: Dict[str, str]): - result = g.user_data.templates.search(inputs['query'], inputs['sort_by']) + result = (api_key_map[g.hashed_api_key] + .user_data + .templates + .search(inputs['query'], inputs['sort_by'])) return return_api(result) @api.route( @@ -786,7 +462,10 @@ def api_templates_query(inputs: Dict[str, str]): ) @endpoint_wrapper def api_get_template(inputs: Dict[str, Any], t_id: int): - template: Template = g.user_data.templates.fetchone(t_id) + template = (api_key_map[g.hashed_api_key] + .user_data + .templates + .fetchone(t_id)) if request.method == 'GET': result = template.get() @@ -819,7 +498,7 @@ def api_get_template(inputs: Dict[str, Any], t_id: int): ) @endpoint_wrapper def api_static_reminders_list(inputs: Dict[str, Any]): - reminders: StaticReminders = g.user_data.static_reminders + reminders = api_key_map[g.hashed_api_key].user_data.static_reminders if request.method == 'GET': result = reminders.fetchall(inputs['sort_by']) @@ -840,7 +519,10 @@ def api_static_reminders_list(inputs: Dict[str, Any]): ) @endpoint_wrapper def api_static_reminders_query(inputs: Dict[str, str]): - result = g.user_data.static_reminders.search(inputs['query'], inputs['sort_by']) + result = (api_key_map[g.hashed_api_key] + .user_data + .static_reminders + .search(inputs['query'], inputs['sort_by'])) return return_api(result) @api.route( @@ -857,7 +539,8 @@ def api_static_reminders_query(inputs: Dict[str, str]): ) @endpoint_wrapper def api_get_static_reminder(inputs: Dict[str, Any], s_id: int): - reminders: StaticReminders = g.user_data.static_reminders + reminders = api_key_map[g.hashed_api_key].user_data.static_reminders + if request.method == 'GET': result = reminders.fetchone(s_id).get() return return_api(result) @@ -929,11 +612,11 @@ def api_admin_settings(inputs: Dict[str, Any]): @endpoint_wrapper def api_admin_users(inputs: Dict[str, Any]): if request.method == 'GET': - result = get_users() + result = users.get_all() return return_api(result) elif request.method == 'POST': - register_user(inputs['username'], inputs['password'], True) + users.add(inputs['username'], inputs['password'], True) return return_api({}, code=201) @admin_api.route( @@ -947,14 +630,15 @@ def api_admin_users(inputs: Dict[str, Any]): ) @endpoint_wrapper def api_admin_user(inputs: Dict[str, Any], u_id: int): + user = users.get_one(u_id) if request.method == 'PUT': - edit_user_password(u_id, inputs['new_password']) + user.edit_password(inputs['new_password']) return return_api({}) elif request.method == 'DELETE': - delete_user(u_id) + user.delete() for key, value in api_key_map.items(): - if value['user_data'].user_id == u_id: + if value.user_data.user_id == u_id: del api_key_map[key] break return return_api({}) diff --git a/frontend/input_validation.py b/frontend/input_validation.py new file mode 100644 index 0000000..d560516 --- /dev/null +++ b/frontend/input_validation.py @@ -0,0 +1,434 @@ +#-*- coding: utf-8 -*- + +""" +Input validation for the API +""" + +from abc import ABC, abstractmethod +from re import compile +from typing import Any, Callable, Dict, List, Union + +from apprise import Apprise +from flask import Blueprint, request +from flask.sansio.scaffold import T_route + +from backend.custom_exceptions import (AccessUnauthorized, InvalidKeyValue, + InvalidTime, KeyNotFound, + NewAccountsNotAllowed, + NotificationServiceNotFound, + UsernameInvalid, UsernameTaken, + UserNotFound) +from backend.helpers import (RepeatQuantity, SortingMethod, + TimelessSortingMethod) +from backend.settings import _format_setting + +api_prefix = "/api" +_admin_api_prefix = '/admin' +admin_api_prefix = api_prefix + _admin_api_prefix + +color_regex = compile(r'#[0-9a-f]{6}') + +class DataSource: + DATA = 1 + VALUES = 2 + + +class InputVariable(ABC): + value: Any + + @abstractmethod + def __init__(self, value: Any) -> None: + pass + + @property + @abstractmethod + def name() -> str: + pass + + @abstractmethod + def validate(self) -> bool: + pass + + @property + @abstractmethod + def required() -> bool: + pass + + @property + @abstractmethod + def default() -> Any: + pass + + @property + @abstractmethod + def source() -> int: + pass + + @property + @abstractmethod + def description() -> str: + pass + + @property + @abstractmethod + def related_exceptions() -> List[Exception]: + pass + + +class DefaultInputVariable(InputVariable): + source = DataSource.DATA + required = True + default = None + related_exceptions = [] + + def __init__(self, value: Any) -> None: + self.value = value + + def validate(self) -> bool: + return isinstance(self.value, str) and self.value + + def __repr__(self) -> str: + return f'| {self.name} | {"Yes" if self.required else "No"} | {self.description} | N/A |' + + +class NonRequiredVersion(InputVariable): + required = False + + def validate(self) -> bool: + return self.value is None or super().validate() + + +class UsernameVariable(DefaultInputVariable): + name = 'username' + description = 'The username of the user account' + related_exceptions = [KeyNotFound, UserNotFound] + + +class PasswordVariable(DefaultInputVariable): + name = 'password' + description = 'The password of the user account' + related_exceptions = [KeyNotFound, AccessUnauthorized] + + +class NewPasswordVariable(PasswordVariable): + name = 'new_password' + description = 'The new password of the user account' + related_exceptions = [KeyNotFound] + + +class UsernameCreateVariable(UsernameVariable): + related_exceptions = [ + KeyNotFound, + UsernameInvalid, UsernameTaken, + NewAccountsNotAllowed + ] + + +class PasswordCreateVariable(PasswordVariable): + related_exceptions = [KeyNotFound] + + +class TitleVariable(DefaultInputVariable): + name = 'title' + description = 'The title of the entry' + related_exceptions = [KeyNotFound] + + +class URLVariable(DefaultInputVariable): + name = 'url' + description = 'The Apprise URL of the notification service' + related_exceptions = [KeyNotFound, InvalidKeyValue] + + def validate(self) -> bool: + return Apprise().add(self.value) + + +class EditTitleVariable(NonRequiredVersion, TitleVariable): + related_exceptions = [] + + +class EditURLVariable(NonRequiredVersion, URLVariable): + related_exceptions = [InvalidKeyValue] + + +class SortByVariable(DefaultInputVariable): + name = 'sort_by' + description = 'How to sort the result' + required = False + source = DataSource.VALUES + _options = [k.lower() for k in SortingMethod._member_names_] + default = SortingMethod._member_names_[0].lower() + related_exceptions = [InvalidKeyValue] + + def __init__(self, value: str) -> None: + self.value = value + + def validate(self) -> bool: + if not self.value in self._options: + return False + + self.value = SortingMethod[self.value.upper()] + return True + + def __repr__(self) -> str: + return '| {n} | {r} | {d} | {v} |'.format( + n=self.name, + r="Yes" if self.required else "No", + d=self.description, + v=", ".join(f'`{o}`' for o in self._options) + ) + + +class TemplateSortByVariable(SortByVariable): + _options = [k.lower() for k in TimelessSortingMethod._member_names_] + default = TimelessSortingMethod._member_names_[0].lower() + + def validate(self) -> bool: + if not self.value in self._options: + return False + + self.value = TimelessSortingMethod[self.value.upper()] + return True + +class StaticReminderSortByVariable(TemplateSortByVariable): + pass + + +class TimeVariable(DefaultInputVariable): + name = 'time' + description = 'The UTC epoch timestamp that the reminder should be sent at' + related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime] + + def validate(self) -> bool: + return isinstance(self.value, (float, int)) + + +class EditTimeVariable(NonRequiredVersion, TimeVariable): + related_exceptions = [InvalidKeyValue, InvalidTime] + + +class NotificationServicesVariable(DefaultInputVariable): + name = 'notification_services' + description = "Array of the id's of the notification services to use to send the notification" + related_exceptions = [KeyNotFound, InvalidKeyValue, NotificationServiceNotFound] + + def validate(self) -> bool: + if not isinstance(self.value, list): + return False + if not self.value: + return False + for v in self.value: + if not isinstance(v, int): + return False + return True + + +class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable): + related_exceptions = [InvalidKeyValue, NotificationServiceNotFound] + + +class TextVariable(NonRequiredVersion, DefaultInputVariable): + name = 'text' + description = 'The body of the entry' + default = '' + + def validate(self) -> bool: + return isinstance(self.value, str) + + +class RepeatQuantityVariable(DefaultInputVariable): + name = 'repeat_quantity' + description = 'The quantity of the repeat_interval' + required = False + _options = [m.lower() for m in RepeatQuantity._member_names_] + default = None + related_exceptions = [InvalidKeyValue] + + def validate(self) -> bool: + if self.value is None: + return True + + if not self.value in self._options: + return False + + self.value = RepeatQuantity[self.value.upper()] + return True + + def __repr__(self) -> str: + return '| {n} | {r} | {d} | {v} |'.format( + n=self.name, + r="Yes" if self.required else "No", + d=self.description, + v=", ".join(f'`{o}`' for o in self._options) + ) + + +class RepeatIntervalVariable(DefaultInputVariable): + name = 'repeat_interval' + description = 'The number of the interval' + required = False + default = None + related_exceptions = [InvalidKeyValue] + + def validate(self) -> bool: + return ( + self.value is None + or ( + isinstance(self.value, int) + and self.value > 0 + ) + ) + + +class WeekDaysVariable(DefaultInputVariable): + name = 'weekdays' + description = 'On which days of the weeks to run the reminder' + required = False + default = None + related_exceptions = [InvalidKeyValue] + _options = {0, 1, 2, 3, 4, 5, 6} + + def validate(self) -> bool: + return self.value is None or ( + isinstance(self.value, list) + and len(self.value) > 0 + and all(v in self._options for v in self.value) + ) + + +class ColorVariable(DefaultInputVariable): + name = 'color' + description = 'The hex code of the color of the entry, which is shown in the web-ui' + required = False + default = None + related_exceptions = [InvalidKeyValue] + + def validate(self) -> bool: + return self.value is None or color_regex.search(self.value) + + +class QueryVariable(DefaultInputVariable): + name = 'query' + description = 'The search term' + source = DataSource.VALUES + + +class AdminSettingsVariable(DefaultInputVariable): + related_exceptions = [KeyNotFound, InvalidKeyValue] + + def validate(self) -> bool: + try: + _format_setting(self.name, self.value) + except InvalidKeyValue: + return False + return True + + +class AllowNewAccountsVariable(AdminSettingsVariable): + name = 'allow_new_accounts' + description = ('Whether or not to allow users to register a new account. ' + + 'The admin can always add a new account.') + + +class LoginTimeVariable(AdminSettingsVariable): + name = 'login_time' + description = ('How long a user stays logged in, in seconds. ' + + 'Between 1 min and 1 month (60 <= sec <= 2592000)') + + +class LoginTimeResetVariable(AdminSettingsVariable): + name = 'login_time_reset' + description = 'If the Login Time timer should reset with each API request.' + + +def input_validation() -> Union[None, Dict[str, Any]]: + """Checks, extracts and transforms inputs + + Raises: + KeyNotFound: A required key was not supplied + InvalidKeyValue: The value of a key is not valid + + Returns: + Union[None, Dict[str, Any]]: `None` if the endpoint + method doesn't require input variables. + Otherwise `Dict[str, Any]` with the input variables, checked and formatted. + """ + inputs = {} + + input_variables: Dict[str, List[Union[List[InputVariable], str]]] + if request.path.startswith(admin_api_prefix): + input_variables = api_docs[ + _admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1] + ]['input_variables'] + else: + input_variables = api_docs[ + request.url_rule.rule.split(api_prefix)[1] + ]['input_variables'] + + if not input_variables: + return + + if input_variables.get(request.method) is None: + return inputs + + given_variables = {} + given_variables[DataSource.DATA] = request.get_json() if request.data else {} + given_variables[DataSource.VALUES] = request.values + for input_variable in input_variables[request.method]: + if ( + input_variable.required and + not input_variable.name in given_variables[input_variable.source] + ): + raise KeyNotFound(input_variable.name) + + input_value = given_variables[input_variable.source].get( + input_variable.name, + input_variable.default + ) + value: InputVariable = input_variable(input_value) + + if not value.validate(): + raise InvalidKeyValue(input_variable.name, input_value) + + inputs[input_variable.name] = value.value + return inputs + + +api_docs: Dict[str, Dict[str, Any]] = {} +class APIBlueprint(Blueprint): + def route( + self, + rule: str, + description: str = '', + input_variables: Dict[str, List[Union[List[InputVariable], str]]] = {}, + requires_auth: bool = True, + **options: Any + ) -> Callable[[T_route], T_route]: + + if self == api: + processed_rule = rule + elif self == admin_api: + processed_rule = _admin_api_prefix + rule + else: + raise NotImplementedError + + api_docs[processed_rule] = { + 'endpoint': processed_rule, + 'description': description, + 'requires_auth': requires_auth, + 'methods': options['methods'], + 'input_variables': { + k: v[0] + for k, v in input_variables.items() + if v and v[0] + }, + 'method_descriptions': { + k: v[1] + for k, v in input_variables.items() + if v and len(v) == 2 and v[1] + } + } + + return super().route(rule, **options) + +api = APIBlueprint('api', __name__) +admin_api = APIBlueprint('admin_api', __name__) diff --git a/frontend/ui.py b/frontend/ui.py index 44fc476..4f953f6 100644 --- a/frontend/ui.py +++ b/frontend/ui.py @@ -1,20 +1,22 @@ #-*- coding: utf-8 -*- -import logging from flask import Blueprint, render_template ui = Blueprint('ui', __name__) methods = ['GET'] +class UIVariables: + url_prefix: str = '' + @ui.route('/', methods=methods) def ui_login(): - return render_template('login.html', url_prefix=logging.URL_PREFIX) + return render_template('login.html', url_prefix=UIVariables.url_prefix) @ui.route('/reminders', methods=methods) def ui_reminders(): - return render_template('reminders.html', url_prefix=logging.URL_PREFIX) + return render_template('reminders.html', url_prefix=UIVariables.url_prefix) @ui.route('/admin', methods=methods) def ui_admin(): - return render_template('admin.html', url_prefixx=logging.URL_PREFIX) + return render_template('admin.html', url_prefix=UIVariables.url_prefix) diff --git a/project_management/generate_api_docs.py b/project_management/generate_api_docs.py index 98c8935..b9e1c15 100644 --- a/project_management/generate_api_docs.py +++ b/project_management/generate_api_docs.py @@ -1,15 +1,18 @@ #!/usr/bin/env python3 #-*- coding: utf-8 -*- -from sys import path from os.path import dirname +from sys import path + +from frontend.input_validation import DataSource path.insert(0, dirname(path[0])) from subprocess import run from typing import Union -from frontend.api import (DataSource, NotificationServiceNotFound, - ReminderNotFound, TemplateNotFound, api_docs) + +from frontend.api import (NotificationServiceNotFound, ReminderNotFound, + TemplateNotFound, api_docs) from MIND import _folder_path, api_prefix url_var_map = { diff --git a/tests/custom_exceptions_test.py b/tests/custom_exceptions_test.py index 88388da..f086c25 100644 --- a/tests/custom_exceptions_test.py +++ b/tests/custom_exceptions_test.py @@ -8,7 +8,8 @@ import backend.custom_exceptions class Test_Custom_Exceptions(unittest.TestCase): def test_type(self): defined_exceptions: List[Exception] = filter( - lambda c: c.__module__ == 'backend.custom_exceptions' and c is not backend.custom_exceptions.CustomException, + lambda c: c.__module__ == 'backend.custom_exceptions' + and c is not backend.custom_exceptions.CustomException, map( lambda c: c[1], getmembers(modules['backend.custom_exceptions'], isclass) @@ -16,11 +17,21 @@ class Test_Custom_Exceptions(unittest.TestCase): ) for defined_exception in defined_exceptions: - self.assertEqual( + self.assertIn( getmro(defined_exception)[1], - backend.custom_exceptions.CustomException + ( + backend.custom_exceptions.CustomException, + Exception + ) ) - result = defined_exception().api_response + try: + result = defined_exception().api_response + except TypeError: + try: + result = defined_exception('1').api_response + except TypeError: + result = defined_exception('1', '2').api_response + self.assertIsInstance(result, dict) result['error'] result['result'] diff --git a/tests/db_test.py b/tests/db_test.py index 13b3802..5f47389 100644 --- a/tests/db_test.py +++ b/tests/db_test.py @@ -1,10 +1,12 @@ import unittest from backend.db import DBConnection -from MIND import DB_FILENAME, _folder_path +from backend.helpers import folder_path +from MIND import DB_FILENAME + class Test_DB(unittest.TestCase): def test_foreign_key(self): - DBConnection.file = _folder_path(*DB_FILENAME) + DBConnection.file = folder_path(*DB_FILENAME) instance = DBConnection(timeout=20.0) self.assertEqual(instance.cursor().execute("PRAGMA foreign_keys;").fetchone()[0], 1) diff --git a/tests/reminders_test.py b/tests/reminders_test.py index c394d43..1e04b94 100644 --- a/tests/reminders_test.py +++ b/tests/reminders_test.py @@ -1,20 +1,14 @@ import unittest -from backend.reminders import filter_function, ReminderHandler +from backend.helpers import search_filter class Test_Reminder_Handler(unittest.TestCase): - def test_starting_stopping(self): - context = 'test' - instance = ReminderHandler(context) - self.assertIs(context, instance.context) - def test_filter_function(self): p = { 'title': 'TITLE', 'text': 'TEXT' } for test_case in ('', 'title', 'ex'): - self.assertTrue(filter_function(test_case, p)) + self.assertTrue(search_filter(test_case, p)) for test_case in (' ', 'Hello'): - self.assertFalse(filter_function(test_case, p)) - + self.assertFalse(search_filter(test_case, p)) diff --git a/tests/users_test.py b/tests/users_test.py index f00e0ae..d1e7b23 100644 --- a/tests/users_test.py +++ b/tests/users_test.py @@ -1,14 +1,14 @@ import unittest from backend.custom_exceptions import UsernameInvalid -from backend.users import ONEPASS_INVALID_USERNAMES, _check_username +from backend.users import ONEPASS_INVALID_USERNAMES, Users class Test_Users(unittest.TestCase): def test_username_check(self): + users = Users() for test_case in ('', 'test'): - _check_username(test_case) + users._check_username(test_case) for test_case in (' ', ' ', '0', 'api', *ONEPASS_INVALID_USERNAMES): with self.assertRaises(UsernameInvalid): - _check_username(test_case) - \ No newline at end of file + users._check_username(test_case)