diff --git a/MIND.py b/MIND.py index 646f244..40a4511 100644 --- a/MIND.py +++ b/MIND.py @@ -12,7 +12,7 @@ from waitress.server import create_server from werkzeug.middleware.dispatcher import DispatcherMiddleware from backend.db import DBConnection, close_db, setup_db -from frontend.api import api, reminder_handler +from frontend.api import api, api_prefix, reminder_handler from frontend.ui import ui HOST = '0.0.0.0' @@ -65,7 +65,7 @@ def _create_app() -> Flask: return render_template('page_not_found.html', url_prefix=logging.URL_PREFIX) app.register_blueprint(ui) - app.register_blueprint(api, url_prefix="/api") + app.register_blueprint(api, url_prefix=api_prefix) # Setup closing database app.teardown_appcontext(close_db) diff --git a/backend/custom_exceptions.py b/backend/custom_exceptions.py index 6059ce2..35b5d03 100644 --- a/backend/custom_exceptions.py +++ b/backend/custom_exceptions.py @@ -41,10 +41,6 @@ class InvalidTime(Exception): """The time given is in the past""" api_response = {'error': 'InvalidTime', 'result': {}, 'code': 400} -class InvalidURL(Exception): - """The apprise url is invalid""" - api_response = {'error': 'InvalidURL', 'result': {}, 'code': 400} - class KeyNotFound(Exception): """A key was not found in the input that is required to be given""" def __init__(self, key: str=''): diff --git a/backend/notification_service.py b/backend/notification_service.py index 6c3ea78..9f23938 100644 --- a/backend/notification_service.py +++ b/backend/notification_service.py @@ -2,20 +2,18 @@ from typing import List -from apprise import Apprise - -from backend.custom_exceptions import (InvalidURL, NotificationServiceInUse, +from backend.custom_exceptions import (NotificationServiceInUse, NotificationServiceNotFound) from backend.db import get_db class NotificationService: - def __init__(self, notification_service_id: int) -> None: + def __init__(self, user_id: int, notification_service_id: int) -> None: self.id = notification_service_id if not get_db().execute( - "SELECT 1 FROM notification_services WHERE id = ? LIMIT 1;", - (self.id,) + "SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;", + (self.id, user_id) ).fetchone(): raise NotificationServiceNotFound @@ -46,9 +44,7 @@ class NotificationService: Returns: dict: The new info about the service """ - if not Apprise().add(url): - raise InvalidURL - + # Get current data and update it with new values data = self.get() new_values = { @@ -157,7 +153,7 @@ class NotificationServices: Returns: NotificationService: Instance of NotificationService """ - return NotificationService(notification_service_id) + return NotificationService(self.user_id, notification_service_id) def add(self, title: str, url: str) -> NotificationService: """Add a notification service @@ -166,15 +162,10 @@ class NotificationServices: title (str): The title of the service url (str): The apprise url of the service - Raises: - InvalidURL: The apprise url is invalid - Returns: dict: The info about the new service """ - if not Apprise().add(url): - raise InvalidURL - + new_id = get_db().execute(""" INSERT INTO notification_services(user_id, title, url) VALUES (?,?,?) diff --git a/backend/reminders.py b/backend/reminders.py index 96bc8f5..62e99e2 100644 --- a/backend/reminders.py +++ b/backend/reminders.py @@ -142,13 +142,13 @@ reminder_handler = ReminderHandler(handler_context.app_context) class Reminder: """Represents a reminder """ - def __init__(self, reminder_id: int): + def __init__(self, user_id: int, reminder_id: int): self.id = reminder_id # Check if reminder exists if not get_db().execute( - "SELECT 1 FROM reminders WHERE id = ? LIMIT 1", - (self.id,) + "SELECT 1 FROM reminders WHERE id = ? AND user_id = ? LIMIT 1", + (self.id, user_id) ).fetchone(): raise ReminderNotFound @@ -381,7 +381,7 @@ class Reminders: Returns: Reminder: A Reminder instance """ - return Reminder(id) + return Reminder(self.user_id, id) def add( self, @@ -420,6 +420,13 @@ class Reminders: raise InvalidKeyValue('repeat_interval', repeat_interval) cursor = get_db() + for service in notification_services: + if not cursor.execute( + "SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;", + (service, self.user_id) + ).fetchone(): + raise NotificationServiceNotFound + if repeat_quantity is None and repeat_interval is None: id = cursor.execute(""" INSERT INTO reminders(user_id, title, text, time, color) @@ -446,27 +453,28 @@ class Reminders: # Return info return self.fetchone(id) -def test_reminder( - title: str, - notification_services: List[int], - text: str = '' -) -> None: - """Test send a reminder draft + def test_reminder( + self, + title: str, + notification_services: List[int], + text: str = '' + ) -> None: + """Test send a reminder draft - Args: - title (str): Title title of the entry - notification_service (int): The id of the notification service to use to send the reminder - text (str, optional): The body of the reminder. Defaults to ''. - """ - a = Apprise() - cursor = get_db(dict) - for service in notification_services: - url = cursor.execute( - "SELECT url FROM notification_services WHERE id = ? LIMIT 1;", - (service,) - ).fetchone() - if not url: - raise NotificationServiceNotFound - a.add(url[0]) - a.notify(title=title, body=text) - return + Args: + title (str): Title title of the entry + notification_service (int): The id of the notification service to use to send the reminder + text (str, optional): The body of the reminder. Defaults to ''. + """ + a = Apprise() + cursor = get_db(dict) + for service in notification_services: + url = cursor.execute( + "SELECT url FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;", + (service, self.user_id) + ).fetchone() + if not url: + raise NotificationServiceNotFound + a.add(url[0]) + a.notify(title=title, body=text) + return diff --git a/backend/static_reminders.py b/backend/static_reminders.py index d3a684b..1f1330a 100644 --- a/backend/static_reminders.py +++ b/backend/static_reminders.py @@ -17,13 +17,13 @@ filter_function = lambda query, p: ( class StaticReminder: """Represents a static reminder """ - def __init__(self, reminder_id: int) -> None: + def __init__(self, user_id: int, reminder_id: int) -> None: self.id = reminder_id # Check if reminder exists if not get_db().execute( - "SELECT 1 FROM static_reminders WHERE id = ? LIMIT 1;", - (self.id,) + "SELECT 1 FROM static_reminders WHERE id = ? AND user_id = ? LIMIT 1;", + (self.id, user_id) ).fetchone(): raise ReminderNotFound @@ -195,7 +195,7 @@ class StaticReminders: Returns: StaticReminder: A StaticReminder instance """ - return StaticReminder(id) + return StaticReminder(self.user_id, id) def add( self, @@ -249,9 +249,11 @@ class StaticReminders: reminder = cursor.execute(""" SELECT title, text FROM static_reminders - WHERE id = ? + WHERE + id = ? + AND user_id = ? LIMIT 1; - """, (id,)).fetchone() + """, (id, self.user_id)).fetchone() if not reminder: raise ReminderNotFound reminder = dict(reminder) diff --git a/backend/templates.py b/backend/templates.py index 7824701..84749ea 100644 --- a/backend/templates.py +++ b/backend/templates.py @@ -15,12 +15,12 @@ filter_function = lambda query, p: ( class Template: """Represents a template """ - def __init__(self, template_id: int): + def __init__(self, user_id: int, template_id: int): self.id = template_id exists = get_db().execute( - "SELECT 1 FROM templates WHERE id = ? LIMIT 1;", - (self.id,) + "SELECT 1 FROM templates WHERE id = ? AND user_id = ? LIMIT 1;", + (self.id, user_id) ).fetchone() if not exists: raise TemplateNotFound @@ -184,7 +184,7 @@ class Templates: Returns: Template: A Template instance """ - return Template(id) + return Template(self.user_id, id) def add( self, diff --git a/frontend/api.py b/frontend/api.py index b49e2a7..90a6712 100644 --- a/frontend/api.py +++ b/frontend/api.py @@ -1,25 +1,28 @@ #-*- coding: utf-8 -*- +from abc import ABC, abstractmethod from os import urandom from re import compile from time import time as epoch_time -from typing import Any, Tuple +from typing import Any, Dict, List, Tuple +from apprise import Apprise from flask import Blueprint, g, request from backend.custom_exceptions import (AccessUnauthorized, InvalidKeyValue, - InvalidTime, InvalidURL, KeyNotFound, + InvalidTime, KeyNotFound, NotificationServiceInUse, NotificationServiceNotFound, ReminderNotFound, UsernameInvalid, UsernameTaken, UserNotFound) from backend.notification_service import (NotificationService, NotificationServices) -from backend.reminders import Reminders, reminder_handler, test_reminder +from backend.reminders import Reminders, reminder_handler from backend.static_reminders import StaticReminders from backend.templates import Template, Templates from backend.users import User, register_user +api_prefix = "/api" api = Blueprint('api', __name__) api_key_map = {} color_regex = compile(r'#[0-9a-f]{6}') @@ -66,69 +69,285 @@ def error_handler(method): return method(*args, **kwargs) except (UsernameTaken, UsernameInvalid, UserNotFound, AccessUnauthorized, - ReminderNotFound, NotificationServiceNotFound, NotificationServiceInUse, - InvalidTime, InvalidURL, + ReminderNotFound, NotificationServiceNotFound, + NotificationServiceInUse, InvalidTime, KeyNotFound, InvalidKeyValue) as e: return return_api(**e.api_response) wrapper.__name__ = method.__name__ return wrapper -def extract_key(values: dict, key: str, check_existence: bool=True, sort_options: dict=None) -> Any: - value: str = values.get(key) - if check_existence and value is None: - raise KeyNotFound(key) - - if value is not None: - # Check value and optionally convert - if key == 'time': - try: - value = int(value) - except (ValueError, TypeError): - raise InvalidKeyValue(key, value) - - elif key == 'repeat_interval': - try: - value = int(value) - if value <= 0: - raise ValueError - except (ValueError, TypeError): - raise InvalidKeyValue(key, value) - - elif key == 'sort_by': - if not value in sort_options: - raise InvalidKeyValue(key, value) - - elif key == 'repeat_quantity': - if not value in ("years", "months", "weeks", "days", "hours", "minutes"): - raise InvalidKeyValue(key, value) - - elif key in ('username', 'password', 'new_password', 'title', 'url', - 'text', 'query'): - if not isinstance(value, str): - raise InvalidKeyValue(key, value) - - elif key == 'color': - if not color_regex.search(value): - raise InvalidKeyValue(key, value) +#=================== +# Input validation +#=================== - elif key == 'notification_services': - if not value: - raise KeyNotFound(key) - if not isinstance(value, list): - raise InvalidKeyValue(key, value) - for v in value: - if not isinstance(v, int): - raise InvalidKeyValue(key, value) +class DataSource: + DATA = 1 + VALUES = 2 - else: - if key == 'sort_by': - value = next(iter(sort_options)) - - elif key == 'text': - value = '' +class InputVariable(ABC): + @abstractmethod + def __init__(self, value: Any) -> None: + pass + + @property + @abstractmethod + def name() -> str: + pass + + @abstractmethod + def validate(self) -> bool: + pass + + @property + @abstractmethod + def required() -> bool: + pass + + @property + @abstractmethod + def default() -> Any: + pass + + @property + @abstractmethod + def source() -> int: + pass + + @property + @abstractmethod + def description() -> str: + pass + +class DefaultInputVariable(InputVariable): + source = DataSource.DATA + required = True + default = None + + def __init__(self, value: Any) -> None: + self.value = value + + def validate(self) -> bool: + return isinstance(self.value, str) and self.value + +class NonRequiredVersion(InputVariable): + required = False + + def validate(self) -> bool: + return self.value is None or super().validate() + +class UsernameVariable(DefaultInputVariable): + name = 'username' + description = 'The username of the user account' + +class PasswordVariable(DefaultInputVariable): + name = 'password' + description = 'The password of the user account' + +class NewPasswordVariable(PasswordVariable): + name = 'new_password' + description = 'The new password of the user account' + +class TitleVariable(DefaultInputVariable): + name = 'title' + description = 'The title of the entry' + +class URLVariable(DefaultInputVariable): + name = 'url' + description = 'The Apprise URL of the notification service' + + def validate(self) -> bool: + return Apprise().add(self.value) + +class EditTitleVariable(NonRequiredVersion, TitleVariable): + pass + +class EditURLVariable(NonRequiredVersion, URLVariable): + pass + +class SortByVariable(DefaultInputVariable): + name = 'sort_by' + description = "How to sort the result. Allowed values are 'title', 'title_reversed', 'time', 'time_reversed', 'date_added' and 'date_added_reversed'" + required = False + source = DataSource.VALUES + _options = Reminders.sort_functions + default = next(iter(Reminders.sort_functions)) + + def __init__(self, value: str) -> None: + self.value = value + + def validate(self) -> bool: + return self.value in self._options + +class TemplateSortByVariable(SortByVariable): + description = "How to sort the result. Allowed values are 'title', 'title_reversed', 'date_added' and 'date_added_reversed'" + _options = Templates.sort_functions + default = next(iter(Templates.sort_functions)) + +class StaticReminderSortByVariable(TemplateSortByVariable): + _options = StaticReminders.sort_functions + default = next(iter(StaticReminders.sort_functions)) + +class TimeVariable(DefaultInputVariable): + name = 'time' + description = 'The UTC epoch timestamp that the reminder should be sent at' + + def validate(self) -> bool: + return isinstance(self.value, (float, int)) + +class EditTimeVariable(NonRequiredVersion, TimeVariable): + pass + +class NotificationServicesVariable(DefaultInputVariable): + name = 'notification_services' + description = "Array of the id's of the notification services to use to send the notification" + + def validate(self) -> bool: + if not isinstance(self.value, list): + return False + if not self.value: + return False + for v in self.value: + if not isinstance(v, int): + return False + return True + +class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable): + pass + +class TextVariable(NonRequiredVersion, DefaultInputVariable): + name = 'text' + description = 'The body of the entry' + default = '' + + def validate(self) -> bool: + return isinstance(self.value, str) + +class RepeatQuantityVariable(DefaultInputVariable): + name = 'repeat_quantity' + description = 'The quantity of the repeat_interval' + required = False + _options = ("years", "months", "weeks", "days", "hours", "minutes") + default = None + + def validate(self) -> bool: + return self.value is None or self.value in self._options + +class RepeatIntervalVariable(DefaultInputVariable): + name = 'repeat_interval' + description = 'The number of the interval' + required = False + default = None + + def validate(self) -> bool: + return self.value is None or isinstance(self.value, int) + +class ColorVariable(DefaultInputVariable): + name = 'color' + description = 'The hex code of the color of the entry, which is shown in the web-ui' + required = False + default = None + + def validate(self) -> None: + return self.value is None or color_regex.search(self.value) + +class QueryVariable(DefaultInputVariable): + name = 'query' + description = 'The search term' + source = DataSource.VALUES + +endpoint_variables: Dict[str, Dict[str, List[InputVariable]]] = { + '/auth/login': { + 'POST': [UsernameVariable, PasswordVariable] + }, + '/user/add': { + 'POST': [UsernameVariable, PasswordVariable] + }, + '/user': { + 'PUT': [NewPasswordVariable] + }, + '/notificationservices': { + 'POST': [TitleVariable, URLVariable] + }, + '/notificationservices/': { + 'PUT': [EditTitleVariable, EditURLVariable] + }, + '/reminders': { + 'GET': [SortByVariable], + 'POST': [TitleVariable, TimeVariable, + NotificationServicesVariable, TextVariable, + RepeatQuantityVariable, RepeatIntervalVariable, + ColorVariable] + }, + '/reminders/search': { + 'GET': [SortByVariable, QueryVariable] + }, + '/reminders/test': { + 'POST': [TitleVariable, NotificationServicesVariable, + TextVariable] + }, + '/reminders/': { + 'PUT': [EditTitleVariable, EditTimeVariable, + EditNotificationServicesVariable, TextVariable, + RepeatQuantityVariable, RepeatIntervalVariable, + ColorVariable] + }, + '/templates': { + 'GET': [TemplateSortByVariable], + 'POST': [TitleVariable, NotificationServicesVariable, + TextVariable, ColorVariable] + }, + '/templates/search': { + 'GET': [TemplateSortByVariable, QueryVariable] + }, + '/templates/': { + 'PUT': [EditTitleVariable, EditNotificationServicesVariable, + TextVariable, ColorVariable] + }, + '/staticreminders': { + 'GET': [StaticReminderSortByVariable], + 'POST': [TitleVariable, NotificationServicesVariable, + TextVariable, ColorVariable] + }, + '/staticreminders/search': { + 'GET': [StaticReminderSortByVariable, QueryVariable] + }, + '/staticreminders/': { + 'PUT': [EditTitleVariable, EditNotificationServicesVariable, + TextVariable, ColorVariable] + } +} + +def input_validation(method): + """Checks, extracts and transforms inputs + """ + def wrapper(*args, **kwargs): + inputs = {} + endpoint = request.url_rule.rule.split(api_prefix)[1] + input_variables = endpoint_variables.get(endpoint, {}).get(request.method) + if input_variables is not None: + given_variables = {} + given_variables[DataSource.DATA] = request.get_json() if request.data else {} + given_variables[DataSource.VALUES] = request.values + + for input_variable in input_variables: + if ( + input_variable.required and + not input_variable.name in given_variables[input_variable.source] + ): + raise KeyNotFound(input_variable.name) + + input_value = given_variables[input_variable.source].get(input_variable.name, input_variable.default) - return value + if not input_variable(input_value).validate(): + raise InvalidKeyValue(input_variable.name, input_value) + + inputs[input_variable.name] = input_value + + return method(inputs, *args, **kwargs) + + wrapper.__name__ = method.__name__ + return wrapper #=================== # Authentication endpoints @@ -136,7 +355,8 @@ def extract_key(values: dict, key: str, check_existence: bool=True, sort_options @api.route('/auth/login', methods=['POST']) @error_handler -def api_login(): +@input_validation +def api_login(inputs: Dict[str, str]): """ Endpoint: /auth/login Description: Login to a user account @@ -156,21 +376,17 @@ def api_login(): 404: UsernameNotFound: The username was not found """ - data = request.get_json() - # Check if required keys are given - username = extract_key(data, 'username') - password = extract_key(data, 'password') + user = User(inputs['username'], inputs['password']) - # Check credentials - user = User(username, password) - - # Login valid + # Generate an API key until one + # is generated that isn't used already while True: api_key = urandom(16).hex() # <- length api key / 2 hashed_api_key = hash(api_key) if not hashed_api_key in api_key_map: break + exp = epoch_time() + 3600 api_key_map.update({ hashed_api_key: { @@ -225,7 +441,8 @@ def api_status(): @api.route('/user/add', methods=['POST']) @error_handler -def api_add_user(): +@input_validation +def api_add_user(inputs: Dict[str, str]): """ Endpoint: /user/add Description: Create a new user account @@ -243,20 +460,14 @@ def api_add_user(): UsernameInvalid: The username given is not allowed UsernameTaken: The username given is already in use """ - data = request.get_json() - - # Check if required keys are given - username = extract_key(data, 'username') - password = extract_key(data, 'password') - - # Add user - user_id = register_user(username, password) + user_id = register_user(inputs['username'], inputs['password']) return return_api({'user_id': user_id}, code=201) @api.route('/user', methods=['PUT', 'DELETE']) @error_handler @auth -def api_manage_user(): +@input_validation +def api_manage_user(inputs: Dict[str, str]): """ Endpoint: /user Description: Manage a user account @@ -278,17 +489,10 @@ def api_manage_user(): Account deleted successfully """ if request.method == 'PUT': - data = request.get_json() - - # Check if required key is given - new_password = extract_key(data, 'new_password') - - # Edit user - g.user_data.edit_password(new_password) + g.user_data.edit_password(inputs['new_password']) return return_api({}) elif request.method == 'DELETE': - # Delete user g.user_data.delete() api_key_map.pop(g.hashed_api_key) return return_api({}) @@ -300,7 +504,8 @@ def api_manage_user(): @api.route('/notificationservices', methods=['GET', 'POST']) @error_handler @auth -def api_notification_services_list(): +@input_validation +def api_notification_services_list(inputs: Dict[str, str]): """ Endpoint: /notificationservices Description: Manage the notification services @@ -323,23 +528,21 @@ def api_notification_services_list(): KeyNotFound: One of the required parameters was not given """ services: NotificationServices = g.user_data.notification_services - + if request.method == 'GET': result = services.fetchall() return return_api(result) elif request.method == 'POST': - data = request.get_json() - title = extract_key(data, 'title') - url = extract_key(data, 'url') - result = services.add(title=title, - url=url).get() + result = services.add(title=inputs['title'], + url=inputs['url']).get() return return_api(result, code=201) @api.route('/notificationservices/', methods=['GET', 'PUT', 'DELETE']) @error_handler @auth -def api_notification_service(n_id: int): +@input_validation +def api_notification_service(inputs: Dict[str, str], n_id: int): """ Endpoint: /notificationservices/ Description: Manage a specific notification service @@ -381,12 +584,8 @@ def api_notification_service(n_id: int): return return_api(result) elif request.method == 'PUT': - data = request.get_json() - title = extract_key(data, 'title', check_existence=False) - url = extract_key(data, 'url', check_existence=False) - - result = service.update(title=title, - url=url) + result = service.update(title=inputs['title'], + url=inputs['url']) return return_api(result) elif request.method == 'DELETE': @@ -400,7 +599,8 @@ def api_notification_service(n_id: int): @api.route('/reminders', methods=['GET', 'POST']) @error_handler @auth -def api_reminders_list(): +@input_validation +def api_reminders_list(inputs: Dict[str, Any]): """ Endpoint: /reminders Description: Manage the reminders @@ -434,33 +634,24 @@ def api_reminders_list(): reminders: Reminders = g.user_data.reminders if request.method == 'GET': - sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=Reminders.sort_functions) - result = reminders.fetchall(sort_by) + result = reminders.fetchall(inputs['sort_by']) return return_api(result) - - elif request.method == 'POST': - data = request.get_json() - title = extract_key(data, 'title') - time = extract_key(data, 'time') - notification_services = extract_key(data, 'notification_services') - text = extract_key(data, 'text', check_existence=False) - repeat_quantity = extract_key(data, 'repeat_quantity', check_existence=False) - repeat_interval = extract_key(data, 'repeat_interval', check_existence=False) - color = extract_key(data, 'color', check_existence=False) - result = reminders.add(title=title, - time=time, - notification_services=notification_services, - text=text, - repeat_quantity=repeat_quantity, - repeat_interval=repeat_interval, - color=color) + elif request.method == 'POST': + result = reminders.add(title=inputs['title'], + time=inputs['time'], + notification_services=inputs['notification_services'], + text=inputs['text'], + repeat_quantity=inputs['repeat_quantity'], + repeat_interval=inputs['repeat_interval'], + color=inputs['color']) return return_api(result.get(), code=201) @api.route('/reminders/search', methods=['GET']) @error_handler @auth -def api_reminders_query(): +@input_validation +def api_reminders_query(inputs: Dict[str, str]): """ Endpoint: /reminders/search Description: Search through the list of reminders @@ -476,16 +667,14 @@ def api_reminders_query(): 400: KeyNotFound: One of the required parameters was not given """ - query = extract_key(request.values, 'query') - sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=Reminders.sort_functions) - - result = g.user_data.reminders.search(query, sort_by) + result = g.user_data.reminders.search(inputs['query'], inputs['sort_by']) return return_api(result) @api.route('/reminders/test', methods=['POST']) @error_handler @auth -def api_test_reminder(): +@input_validation +def api_test_reminder(inputs: Dict[str, Any]): """ Endpoint: /reminders/test Description: Test send a reminder draft @@ -504,18 +693,14 @@ def api_test_reminder(): 404: NotificationServiceNotFound: The notification service given was not found """ - data = request.get_json() - title = extract_key(data, 'title') - notification_services = extract_key(data, 'notification_services') - text = extract_key(data, 'text', check_existence=False) - - test_reminder(title, notification_services, text) + g.user_data.reminders.test_reminder(inputs['title'], inputs['notification_services'], inputs['text']) return return_api({}, code=201) @api.route('/reminders/', methods=['GET', 'PUT', 'DELETE']) @error_handler @auth -def api_get_reminder(r_id: int): +@input_validation +def api_get_reminder(inputs: Dict[str, Any], r_id: int): """ Endpoint: /reminders/ Description: Manage a specific reminder @@ -560,22 +745,13 @@ def api_get_reminder(r_id: int): return return_api(result) elif request.method == 'PUT': - data = request.get_json() - title = extract_key(data, 'title', check_existence=False) - time = extract_key(data, 'time', check_existence=False) - notification_services = extract_key(data, 'notification_services', check_existence=False) - text = extract_key(data, 'text', check_existence=False) - repeat_quantity = extract_key(data, 'repeat_quantity', check_existence=False) - repeat_interval = extract_key(data, 'repeat_interval', check_existence=False) - color = extract_key(data, 'color', check_existence=False) - - result = reminders.fetchone(r_id).update(title=title, - time=time, - notification_services=notification_services, - text=text, - repeat_quantity=repeat_quantity, - repeat_interval=repeat_interval, - color=color) + result = reminders.fetchone(r_id).update(title=inputs['title'], + time=inputs['time'], + notification_services=inputs['notification_services'], + text=inputs['text'], + repeat_quantity=inputs['repeat_quantity'], + repeat_interval=inputs['repeat_interval'], + color=inputs['color']) return return_api(result) elif request.method == 'DELETE': @@ -589,7 +765,8 @@ def api_get_reminder(r_id: int): @api.route('/templates', methods=['GET', 'POST']) @error_handler @auth -def api_get_templates(): +@input_validation +def api_get_templates(inputs: Dict[str, Any]): """ Endpoint: /templates Description: Manage the templates @@ -620,27 +797,21 @@ def api_get_templates(): templates: Templates = g.user_data.templates if request.method == 'GET': - sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=Templates.sort_functions) - result = templates.fetchall(sort_by) + result = templates.fetchall(inputs['sort_by']) return return_api(result) elif request.method == 'POST': - data = request.get_json() - title = extract_key(data, 'title') - notification_services = extract_key(data, 'notification_services') - text = extract_key(data, 'text', check_existence=False) - color = extract_key(data, 'color', check_existence=False) - - result = templates.add(title=title, - notification_services=notification_services, - text=text, - color=color) + result = templates.add(title=inputs['title'], + notification_services=inputs['notification_services'], + text=inputs['text'], + color=inputs['color']) return return_api(result.get(), code=201) @api.route('/templates/search', methods=['GET']) @error_handler @auth -def api_templates_query(): +@input_validation +def api_templates_query(inputs: Dict[str, str]): """ Endpoint: /templates/search Description: Search through the list of templates @@ -656,16 +827,14 @@ def api_templates_query(): 400: KeyNotFound: One of the required parameters was not given """ - query = extract_key(request.values, 'query') - sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=Templates.sort_functions) - - result = g.user_data.templates.search(query, sort_by) + result = g.user_data.templates.search(inputs['query'], inputs['sort_by']) return return_api(result) @api.route('/templates/', methods=['GET', 'PUT', 'DELETE']) @error_handler @auth -def api_get_template(t_id: int): +@input_validation +def api_get_template(inputs: Dict[str, Any], t_id: int): """ Endpoint: /templates/ Description: Manage a specific template @@ -708,16 +877,10 @@ def api_get_template(t_id: int): return return_api(result) elif request.method == 'PUT': - data = request.get_json() - title = extract_key(data, 'title', check_existence=False) - notification_services = extract_key(data, 'notification_services', check_existence=False) - text = extract_key(data, 'text', check_existence=False) - color = extract_key(data, 'color', check_existence=False) - - result = template.update(title=title, - notification_services=notification_services, - text=text, - color=color) + result = template.update(title=inputs['title'], + notification_services=inputs['notification_services'], + text=inputs['text'], + color=inputs['color']) return return_api(result) elif request.method == 'DELETE': @@ -731,7 +894,8 @@ def api_get_template(t_id: int): @api.route('/staticreminders', methods=['GET', 'POST']) @error_handler @auth -def api_static_reminders_list(): +@input_validation +def api_static_reminders_list(inputs: Dict[str, Any]): """ Endpoint: /staticreminders Description: Manage the static reminders @@ -762,27 +926,21 @@ def api_static_reminders_list(): reminders: StaticReminders = g.user_data.static_reminders if request.method == 'GET': - sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=StaticReminders.sort_functions) - result = reminders.fetchall(sort_by) + result = reminders.fetchall(inputs['sort_by']) return return_api(result) elif request.method == 'POST': - data = request.get_json() - title = extract_key(data, 'title') - notification_services = extract_key(data, 'notification_services') - text = extract_key(data, 'text', check_existence=False) - color = extract_key(data, 'color', check_existence=False) - - result = reminders.add(title=title, - notification_services=notification_services, - text=text, - color=color) + result = reminders.add(title=inputs['title'], + notification_services=inputs['notification_services'], + text=inputs['text'], + color=inputs['color']) return return_api(result.get(), code=201) @api.route('/staticreminders/search', methods=['GET']) @error_handler @auth -def api_static_reminders_query(): +@input_validation +def api_static_reminders_query(inputs: Dict[str, str]): """ Endpoint: /staticreminders/search Description: Search through the list of staticreminders @@ -798,16 +956,14 @@ def api_static_reminders_query(): 400: KeyNotFound: One of the required parameters was not given """ - query = extract_key(request.values, 'query') - sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=StaticReminders.sort_functions) - - result = g.user_data.static_reminders.search(query, sort_by) + result = g.user_data.static_reminders.search(inputs['query'], inputs['sort_by']) return return_api(result) @api.route('/staticreminders/', methods=['GET', 'POST', 'PUT', 'DELETE']) @error_handler @auth -def api_get_static_reminder(r_id: int): +@input_validation +def api_get_static_reminder(inputs: Dict[str, Any], r_id: int): """ Endpoint: /staticreminders/ Description: Manage a specific static reminder @@ -858,16 +1014,10 @@ def api_get_static_reminder(r_id: int): return return_api({}) elif request.method == 'PUT': - data = request.get_json() - title = extract_key(data, 'title', check_existence=False) - notification_services = extract_key(data, 'notification_services', check_existence=False) - text = extract_key(data, 'text', check_existence=False) - color = extract_key(data, 'color', check_existence=False) - - result = reminders.fetchone(r_id).update(title=title, - notification_services=notification_services, - text=text, - color=color) + result = reminders.fetchone(r_id).update(title=inputs['title'], + notification_services=inputs['notification_services'], + text=inputs['text'], + color=inputs['color']) return return_api(result) elif request.method == 'DELETE': diff --git a/frontend/static/js/window.js b/frontend/static/js/window.js index 09522d0..d8c147a 100644 --- a/frontend/static/js/window.js +++ b/frontend/static/js/window.js @@ -94,7 +94,7 @@ function testReminder() { const ns = [... document.querySelectorAll('.notification-service-list input[type="checkbox"]:checked') ].map(c => parseInt(c.dataset.id)) - if (!ns) { + if (!ns.length) { input.classList.add('error-input'); input.title = 'No notification service set'; return diff --git a/tests/api_test.py b/tests/api_test.py index c062c4d..b504492 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -41,7 +41,6 @@ class Test_API(unittest.TestCase): self.assertEqual(result(ReminderNotFound), return_api(**ReminderNotFound.api_response)) self.assertEqual(result(NotificationServiceNotFound), return_api(**NotificationServiceNotFound.api_response)) self.assertEqual(result(InvalidTime), return_api(**InvalidTime.api_response)) - self.assertEqual(result(InvalidURL), return_api(**InvalidURL.api_response)) self.assertEqual(result(NotificationServiceInUse, 'test'), return_api(**NotificationServiceInUse('test').api_response)) self.assertEqual(result(KeyNotFound, 'test'), return_api(**KeyNotFound('test').api_response)) self.assertEqual(result(InvalidKeyValue, 'test', 'value'), return_api(**InvalidKeyValue('test', 'value').api_response))