Improved API input validation

This commit is contained in:
CasVT
2023-06-23 00:14:46 +02:00
parent 348b0b4f0a
commit 32d4faaa16
9 changed files with 412 additions and 266 deletions

View File

@@ -12,7 +12,7 @@ from waitress.server import create_server
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from backend.db import DBConnection, close_db, setup_db
from frontend.api import api, reminder_handler
from frontend.api import api, api_prefix, reminder_handler
from frontend.ui import ui
HOST = '0.0.0.0'
@@ -65,7 +65,7 @@ def _create_app() -> Flask:
return render_template('page_not_found.html', url_prefix=logging.URL_PREFIX)
app.register_blueprint(ui)
app.register_blueprint(api, url_prefix="/api")
app.register_blueprint(api, url_prefix=api_prefix)
# Setup closing database
app.teardown_appcontext(close_db)

View File

@@ -41,10 +41,6 @@ class InvalidTime(Exception):
"""The time given is in the past"""
api_response = {'error': 'InvalidTime', 'result': {}, 'code': 400}
class InvalidURL(Exception):
"""The apprise url is invalid"""
api_response = {'error': 'InvalidURL', 'result': {}, 'code': 400}
class KeyNotFound(Exception):
"""A key was not found in the input that is required to be given"""
def __init__(self, key: str=''):

View File

@@ -2,20 +2,18 @@
from typing import List
from apprise import Apprise
from backend.custom_exceptions import (InvalidURL, NotificationServiceInUse,
from backend.custom_exceptions import (NotificationServiceInUse,
NotificationServiceNotFound)
from backend.db import get_db
class NotificationService:
def __init__(self, notification_service_id: int) -> None:
def __init__(self, user_id: int, notification_service_id: int) -> None:
self.id = notification_service_id
if not get_db().execute(
"SELECT 1 FROM notification_services WHERE id = ? LIMIT 1;",
(self.id,)
"SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;",
(self.id, user_id)
).fetchone():
raise NotificationServiceNotFound
@@ -46,9 +44,7 @@ class NotificationService:
Returns:
dict: The new info about the service
"""
if not Apprise().add(url):
raise InvalidURL
# Get current data and update it with new values
data = self.get()
new_values = {
@@ -157,7 +153,7 @@ class NotificationServices:
Returns:
NotificationService: Instance of NotificationService
"""
return NotificationService(notification_service_id)
return NotificationService(self.user_id, notification_service_id)
def add(self, title: str, url: str) -> NotificationService:
"""Add a notification service
@@ -166,15 +162,10 @@ class NotificationServices:
title (str): The title of the service
url (str): The apprise url of the service
Raises:
InvalidURL: The apprise url is invalid
Returns:
dict: The info about the new service
"""
if not Apprise().add(url):
raise InvalidURL
new_id = get_db().execute("""
INSERT INTO notification_services(user_id, title, url)
VALUES (?,?,?)

View File

@@ -142,13 +142,13 @@ reminder_handler = ReminderHandler(handler_context.app_context)
class Reminder:
"""Represents a reminder
"""
def __init__(self, reminder_id: int):
def __init__(self, user_id: int, reminder_id: int):
self.id = reminder_id
# Check if reminder exists
if not get_db().execute(
"SELECT 1 FROM reminders WHERE id = ? LIMIT 1",
(self.id,)
"SELECT 1 FROM reminders WHERE id = ? AND user_id = ? LIMIT 1",
(self.id, user_id)
).fetchone():
raise ReminderNotFound
@@ -381,7 +381,7 @@ class Reminders:
Returns:
Reminder: A Reminder instance
"""
return Reminder(id)
return Reminder(self.user_id, id)
def add(
self,
@@ -420,6 +420,13 @@ class Reminders:
raise InvalidKeyValue('repeat_interval', repeat_interval)
cursor = get_db()
for service in notification_services:
if not cursor.execute(
"SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;",
(service, self.user_id)
).fetchone():
raise NotificationServiceNotFound
if repeat_quantity is None and repeat_interval is None:
id = cursor.execute("""
INSERT INTO reminders(user_id, title, text, time, color)
@@ -446,27 +453,28 @@ class Reminders:
# Return info
return self.fetchone(id)
def test_reminder(
title: str,
notification_services: List[int],
text: str = ''
) -> None:
"""Test send a reminder draft
def test_reminder(
self,
title: str,
notification_services: List[int],
text: str = ''
) -> None:
"""Test send a reminder draft
Args:
title (str): Title title of the entry
notification_service (int): The id of the notification service to use to send the reminder
text (str, optional): The body of the reminder. Defaults to ''.
"""
a = Apprise()
cursor = get_db(dict)
for service in notification_services:
url = cursor.execute(
"SELECT url FROM notification_services WHERE id = ? LIMIT 1;",
(service,)
).fetchone()
if not url:
raise NotificationServiceNotFound
a.add(url[0])
a.notify(title=title, body=text)
return
Args:
title (str): Title title of the entry
notification_service (int): The id of the notification service to use to send the reminder
text (str, optional): The body of the reminder. Defaults to ''.
"""
a = Apprise()
cursor = get_db(dict)
for service in notification_services:
url = cursor.execute(
"SELECT url FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;",
(service, self.user_id)
).fetchone()
if not url:
raise NotificationServiceNotFound
a.add(url[0])
a.notify(title=title, body=text)
return

View File

@@ -17,13 +17,13 @@ filter_function = lambda query, p: (
class StaticReminder:
"""Represents a static reminder
"""
def __init__(self, reminder_id: int) -> None:
def __init__(self, user_id: int, reminder_id: int) -> None:
self.id = reminder_id
# Check if reminder exists
if not get_db().execute(
"SELECT 1 FROM static_reminders WHERE id = ? LIMIT 1;",
(self.id,)
"SELECT 1 FROM static_reminders WHERE id = ? AND user_id = ? LIMIT 1;",
(self.id, user_id)
).fetchone():
raise ReminderNotFound
@@ -195,7 +195,7 @@ class StaticReminders:
Returns:
StaticReminder: A StaticReminder instance
"""
return StaticReminder(id)
return StaticReminder(self.user_id, id)
def add(
self,
@@ -249,9 +249,11 @@ class StaticReminders:
reminder = cursor.execute("""
SELECT title, text
FROM static_reminders
WHERE id = ?
WHERE
id = ?
AND user_id = ?
LIMIT 1;
""", (id,)).fetchone()
""", (id, self.user_id)).fetchone()
if not reminder:
raise ReminderNotFound
reminder = dict(reminder)

View File

@@ -15,12 +15,12 @@ filter_function = lambda query, p: (
class Template:
"""Represents a template
"""
def __init__(self, template_id: int):
def __init__(self, user_id: int, template_id: int):
self.id = template_id
exists = get_db().execute(
"SELECT 1 FROM templates WHERE id = ? LIMIT 1;",
(self.id,)
"SELECT 1 FROM templates WHERE id = ? AND user_id = ? LIMIT 1;",
(self.id, user_id)
).fetchone()
if not exists:
raise TemplateNotFound
@@ -184,7 +184,7 @@ class Templates:
Returns:
Template: A Template instance
"""
return Template(id)
return Template(self.user_id, id)
def add(
self,

View File

@@ -1,25 +1,28 @@
#-*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from os import urandom
from re import compile
from time import time as epoch_time
from typing import Any, Tuple
from typing import Any, Dict, List, Tuple
from apprise import Apprise
from flask import Blueprint, g, request
from backend.custom_exceptions import (AccessUnauthorized, InvalidKeyValue,
InvalidTime, InvalidURL, KeyNotFound,
InvalidTime, KeyNotFound,
NotificationServiceInUse,
NotificationServiceNotFound,
ReminderNotFound, UsernameInvalid,
UsernameTaken, UserNotFound)
from backend.notification_service import (NotificationService,
NotificationServices)
from backend.reminders import Reminders, reminder_handler, test_reminder
from backend.reminders import Reminders, reminder_handler
from backend.static_reminders import StaticReminders
from backend.templates import Template, Templates
from backend.users import User, register_user
api_prefix = "/api"
api = Blueprint('api', __name__)
api_key_map = {}
color_regex = compile(r'#[0-9a-f]{6}')
@@ -66,69 +69,285 @@ def error_handler(method):
return method(*args, **kwargs)
except (UsernameTaken, UsernameInvalid, UserNotFound,
AccessUnauthorized,
ReminderNotFound, NotificationServiceNotFound, NotificationServiceInUse,
InvalidTime, InvalidURL,
ReminderNotFound, NotificationServiceNotFound,
NotificationServiceInUse, InvalidTime,
KeyNotFound, InvalidKeyValue) as e:
return return_api(**e.api_response)
wrapper.__name__ = method.__name__
return wrapper
def extract_key(values: dict, key: str, check_existence: bool=True, sort_options: dict=None) -> Any:
value: str = values.get(key)
if check_existence and value is None:
raise KeyNotFound(key)
if value is not None:
# Check value and optionally convert
if key == 'time':
try:
value = int(value)
except (ValueError, TypeError):
raise InvalidKeyValue(key, value)
elif key == 'repeat_interval':
try:
value = int(value)
if value <= 0:
raise ValueError
except (ValueError, TypeError):
raise InvalidKeyValue(key, value)
elif key == 'sort_by':
if not value in sort_options:
raise InvalidKeyValue(key, value)
elif key == 'repeat_quantity':
if not value in ("years", "months", "weeks", "days", "hours", "minutes"):
raise InvalidKeyValue(key, value)
elif key in ('username', 'password', 'new_password', 'title', 'url',
'text', 'query'):
if not isinstance(value, str):
raise InvalidKeyValue(key, value)
elif key == 'color':
if not color_regex.search(value):
raise InvalidKeyValue(key, value)
#===================
# Input validation
#===================
elif key == 'notification_services':
if not value:
raise KeyNotFound(key)
if not isinstance(value, list):
raise InvalidKeyValue(key, value)
for v in value:
if not isinstance(v, int):
raise InvalidKeyValue(key, value)
class DataSource:
DATA = 1
VALUES = 2
else:
if key == 'sort_by':
value = next(iter(sort_options))
elif key == 'text':
value = ''
class InputVariable(ABC):
@abstractmethod
def __init__(self, value: Any) -> None:
pass
@property
@abstractmethod
def name() -> str:
pass
@abstractmethod
def validate(self) -> bool:
pass
@property
@abstractmethod
def required() -> bool:
pass
@property
@abstractmethod
def default() -> Any:
pass
@property
@abstractmethod
def source() -> int:
pass
@property
@abstractmethod
def description() -> str:
pass
class DefaultInputVariable(InputVariable):
source = DataSource.DATA
required = True
default = None
def __init__(self, value: Any) -> None:
self.value = value
def validate(self) -> bool:
return isinstance(self.value, str) and self.value
class NonRequiredVersion(InputVariable):
required = False
def validate(self) -> bool:
return self.value is None or super().validate()
class UsernameVariable(DefaultInputVariable):
name = 'username'
description = 'The username of the user account'
class PasswordVariable(DefaultInputVariable):
name = 'password'
description = 'The password of the user account'
class NewPasswordVariable(PasswordVariable):
name = 'new_password'
description = 'The new password of the user account'
class TitleVariable(DefaultInputVariable):
name = 'title'
description = 'The title of the entry'
class URLVariable(DefaultInputVariable):
name = 'url'
description = 'The Apprise URL of the notification service'
def validate(self) -> bool:
return Apprise().add(self.value)
class EditTitleVariable(NonRequiredVersion, TitleVariable):
pass
class EditURLVariable(NonRequiredVersion, URLVariable):
pass
class SortByVariable(DefaultInputVariable):
name = 'sort_by'
description = "How to sort the result. Allowed values are 'title', 'title_reversed', 'time', 'time_reversed', 'date_added' and 'date_added_reversed'"
required = False
source = DataSource.VALUES
_options = Reminders.sort_functions
default = next(iter(Reminders.sort_functions))
def __init__(self, value: str) -> None:
self.value = value
def validate(self) -> bool:
return self.value in self._options
class TemplateSortByVariable(SortByVariable):
description = "How to sort the result. Allowed values are 'title', 'title_reversed', 'date_added' and 'date_added_reversed'"
_options = Templates.sort_functions
default = next(iter(Templates.sort_functions))
class StaticReminderSortByVariable(TemplateSortByVariable):
_options = StaticReminders.sort_functions
default = next(iter(StaticReminders.sort_functions))
class TimeVariable(DefaultInputVariable):
name = 'time'
description = 'The UTC epoch timestamp that the reminder should be sent at'
def validate(self) -> bool:
return isinstance(self.value, (float, int))
class EditTimeVariable(NonRequiredVersion, TimeVariable):
pass
class NotificationServicesVariable(DefaultInputVariable):
name = 'notification_services'
description = "Array of the id's of the notification services to use to send the notification"
def validate(self) -> bool:
if not isinstance(self.value, list):
return False
if not self.value:
return False
for v in self.value:
if not isinstance(v, int):
return False
return True
class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable):
pass
class TextVariable(NonRequiredVersion, DefaultInputVariable):
name = 'text'
description = 'The body of the entry'
default = ''
def validate(self) -> bool:
return isinstance(self.value, str)
class RepeatQuantityVariable(DefaultInputVariable):
name = 'repeat_quantity'
description = 'The quantity of the repeat_interval'
required = False
_options = ("years", "months", "weeks", "days", "hours", "minutes")
default = None
def validate(self) -> bool:
return self.value is None or self.value in self._options
class RepeatIntervalVariable(DefaultInputVariable):
name = 'repeat_interval'
description = 'The number of the interval'
required = False
default = None
def validate(self) -> bool:
return self.value is None or isinstance(self.value, int)
class ColorVariable(DefaultInputVariable):
name = 'color'
description = 'The hex code of the color of the entry, which is shown in the web-ui'
required = False
default = None
def validate(self) -> None:
return self.value is None or color_regex.search(self.value)
class QueryVariable(DefaultInputVariable):
name = 'query'
description = 'The search term'
source = DataSource.VALUES
endpoint_variables: Dict[str, Dict[str, List[InputVariable]]] = {
'/auth/login': {
'POST': [UsernameVariable, PasswordVariable]
},
'/user/add': {
'POST': [UsernameVariable, PasswordVariable]
},
'/user': {
'PUT': [NewPasswordVariable]
},
'/notificationservices': {
'POST': [TitleVariable, URLVariable]
},
'/notificationservices/<int:n_id>': {
'PUT': [EditTitleVariable, EditURLVariable]
},
'/reminders': {
'GET': [SortByVariable],
'POST': [TitleVariable, TimeVariable,
NotificationServicesVariable, TextVariable,
RepeatQuantityVariable, RepeatIntervalVariable,
ColorVariable]
},
'/reminders/search': {
'GET': [SortByVariable, QueryVariable]
},
'/reminders/test': {
'POST': [TitleVariable, NotificationServicesVariable,
TextVariable]
},
'/reminders/<int:r_id>': {
'PUT': [EditTitleVariable, EditTimeVariable,
EditNotificationServicesVariable, TextVariable,
RepeatQuantityVariable, RepeatIntervalVariable,
ColorVariable]
},
'/templates': {
'GET': [TemplateSortByVariable],
'POST': [TitleVariable, NotificationServicesVariable,
TextVariable, ColorVariable]
},
'/templates/search': {
'GET': [TemplateSortByVariable, QueryVariable]
},
'/templates/<int:t_id>': {
'PUT': [EditTitleVariable, EditNotificationServicesVariable,
TextVariable, ColorVariable]
},
'/staticreminders': {
'GET': [StaticReminderSortByVariable],
'POST': [TitleVariable, NotificationServicesVariable,
TextVariable, ColorVariable]
},
'/staticreminders/search': {
'GET': [StaticReminderSortByVariable, QueryVariable]
},
'/staticreminders/<int:r_id>': {
'PUT': [EditTitleVariable, EditNotificationServicesVariable,
TextVariable, ColorVariable]
}
}
def input_validation(method):
"""Checks, extracts and transforms inputs
"""
def wrapper(*args, **kwargs):
inputs = {}
endpoint = request.url_rule.rule.split(api_prefix)[1]
input_variables = endpoint_variables.get(endpoint, {}).get(request.method)
if input_variables is not None:
given_variables = {}
given_variables[DataSource.DATA] = request.get_json() if request.data else {}
given_variables[DataSource.VALUES] = request.values
for input_variable in input_variables:
if (
input_variable.required and
not input_variable.name in given_variables[input_variable.source]
):
raise KeyNotFound(input_variable.name)
input_value = given_variables[input_variable.source].get(input_variable.name, input_variable.default)
return value
if not input_variable(input_value).validate():
raise InvalidKeyValue(input_variable.name, input_value)
inputs[input_variable.name] = input_value
return method(inputs, *args, **kwargs)
wrapper.__name__ = method.__name__
return wrapper
#===================
# Authentication endpoints
@@ -136,7 +355,8 @@ def extract_key(values: dict, key: str, check_existence: bool=True, sort_options
@api.route('/auth/login', methods=['POST'])
@error_handler
def api_login():
@input_validation
def api_login(inputs: Dict[str, str]):
"""
Endpoint: /auth/login
Description: Login to a user account
@@ -156,21 +376,17 @@ def api_login():
404:
UsernameNotFound: The username was not found
"""
data = request.get_json()
# Check if required keys are given
username = extract_key(data, 'username')
password = extract_key(data, 'password')
user = User(inputs['username'], inputs['password'])
# Check credentials
user = User(username, password)
# Login valid
# Generate an API key until one
# is generated that isn't used already
while True:
api_key = urandom(16).hex() # <- length api key / 2
hashed_api_key = hash(api_key)
if not hashed_api_key in api_key_map:
break
exp = epoch_time() + 3600
api_key_map.update({
hashed_api_key: {
@@ -225,7 +441,8 @@ def api_status():
@api.route('/user/add', methods=['POST'])
@error_handler
def api_add_user():
@input_validation
def api_add_user(inputs: Dict[str, str]):
"""
Endpoint: /user/add
Description: Create a new user account
@@ -243,20 +460,14 @@ def api_add_user():
UsernameInvalid: The username given is not allowed
UsernameTaken: The username given is already in use
"""
data = request.get_json()
# Check if required keys are given
username = extract_key(data, 'username')
password = extract_key(data, 'password')
# Add user
user_id = register_user(username, password)
user_id = register_user(inputs['username'], inputs['password'])
return return_api({'user_id': user_id}, code=201)
@api.route('/user', methods=['PUT', 'DELETE'])
@error_handler
@auth
def api_manage_user():
@input_validation
def api_manage_user(inputs: Dict[str, str]):
"""
Endpoint: /user
Description: Manage a user account
@@ -278,17 +489,10 @@ def api_manage_user():
Account deleted successfully
"""
if request.method == 'PUT':
data = request.get_json()
# Check if required key is given
new_password = extract_key(data, 'new_password')
# Edit user
g.user_data.edit_password(new_password)
g.user_data.edit_password(inputs['new_password'])
return return_api({})
elif request.method == 'DELETE':
# Delete user
g.user_data.delete()
api_key_map.pop(g.hashed_api_key)
return return_api({})
@@ -300,7 +504,8 @@ def api_manage_user():
@api.route('/notificationservices', methods=['GET', 'POST'])
@error_handler
@auth
def api_notification_services_list():
@input_validation
def api_notification_services_list(inputs: Dict[str, str]):
"""
Endpoint: /notificationservices
Description: Manage the notification services
@@ -323,23 +528,21 @@ def api_notification_services_list():
KeyNotFound: One of the required parameters was not given
"""
services: NotificationServices = g.user_data.notification_services
if request.method == 'GET':
result = services.fetchall()
return return_api(result)
elif request.method == 'POST':
data = request.get_json()
title = extract_key(data, 'title')
url = extract_key(data, 'url')
result = services.add(title=title,
url=url).get()
result = services.add(title=inputs['title'],
url=inputs['url']).get()
return return_api(result, code=201)
@api.route('/notificationservices/<int:n_id>', methods=['GET', 'PUT', 'DELETE'])
@error_handler
@auth
def api_notification_service(n_id: int):
@input_validation
def api_notification_service(inputs: Dict[str, str], n_id: int):
"""
Endpoint: /notificationservices/<n_id>
Description: Manage a specific notification service
@@ -381,12 +584,8 @@ def api_notification_service(n_id: int):
return return_api(result)
elif request.method == 'PUT':
data = request.get_json()
title = extract_key(data, 'title', check_existence=False)
url = extract_key(data, 'url', check_existence=False)
result = service.update(title=title,
url=url)
result = service.update(title=inputs['title'],
url=inputs['url'])
return return_api(result)
elif request.method == 'DELETE':
@@ -400,7 +599,8 @@ def api_notification_service(n_id: int):
@api.route('/reminders', methods=['GET', 'POST'])
@error_handler
@auth
def api_reminders_list():
@input_validation
def api_reminders_list(inputs: Dict[str, Any]):
"""
Endpoint: /reminders
Description: Manage the reminders
@@ -434,33 +634,24 @@ def api_reminders_list():
reminders: Reminders = g.user_data.reminders
if request.method == 'GET':
sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=Reminders.sort_functions)
result = reminders.fetchall(sort_by)
result = reminders.fetchall(inputs['sort_by'])
return return_api(result)
elif request.method == 'POST':
data = request.get_json()
title = extract_key(data, 'title')
time = extract_key(data, 'time')
notification_services = extract_key(data, 'notification_services')
text = extract_key(data, 'text', check_existence=False)
repeat_quantity = extract_key(data, 'repeat_quantity', check_existence=False)
repeat_interval = extract_key(data, 'repeat_interval', check_existence=False)
color = extract_key(data, 'color', check_existence=False)
result = reminders.add(title=title,
time=time,
notification_services=notification_services,
text=text,
repeat_quantity=repeat_quantity,
repeat_interval=repeat_interval,
color=color)
elif request.method == 'POST':
result = reminders.add(title=inputs['title'],
time=inputs['time'],
notification_services=inputs['notification_services'],
text=inputs['text'],
repeat_quantity=inputs['repeat_quantity'],
repeat_interval=inputs['repeat_interval'],
color=inputs['color'])
return return_api(result.get(), code=201)
@api.route('/reminders/search', methods=['GET'])
@error_handler
@auth
def api_reminders_query():
@input_validation
def api_reminders_query(inputs: Dict[str, str]):
"""
Endpoint: /reminders/search
Description: Search through the list of reminders
@@ -476,16 +667,14 @@ def api_reminders_query():
400:
KeyNotFound: One of the required parameters was not given
"""
query = extract_key(request.values, 'query')
sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=Reminders.sort_functions)
result = g.user_data.reminders.search(query, sort_by)
result = g.user_data.reminders.search(inputs['query'], inputs['sort_by'])
return return_api(result)
@api.route('/reminders/test', methods=['POST'])
@error_handler
@auth
def api_test_reminder():
@input_validation
def api_test_reminder(inputs: Dict[str, Any]):
"""
Endpoint: /reminders/test
Description: Test send a reminder draft
@@ -504,18 +693,14 @@ def api_test_reminder():
404:
NotificationServiceNotFound: The notification service given was not found
"""
data = request.get_json()
title = extract_key(data, 'title')
notification_services = extract_key(data, 'notification_services')
text = extract_key(data, 'text', check_existence=False)
test_reminder(title, notification_services, text)
g.user_data.reminders.test_reminder(inputs['title'], inputs['notification_services'], inputs['text'])
return return_api({}, code=201)
@api.route('/reminders/<int:r_id>', methods=['GET', 'PUT', 'DELETE'])
@error_handler
@auth
def api_get_reminder(r_id: int):
@input_validation
def api_get_reminder(inputs: Dict[str, Any], r_id: int):
"""
Endpoint: /reminders/<r_id>
Description: Manage a specific reminder
@@ -560,22 +745,13 @@ def api_get_reminder(r_id: int):
return return_api(result)
elif request.method == 'PUT':
data = request.get_json()
title = extract_key(data, 'title', check_existence=False)
time = extract_key(data, 'time', check_existence=False)
notification_services = extract_key(data, 'notification_services', check_existence=False)
text = extract_key(data, 'text', check_existence=False)
repeat_quantity = extract_key(data, 'repeat_quantity', check_existence=False)
repeat_interval = extract_key(data, 'repeat_interval', check_existence=False)
color = extract_key(data, 'color', check_existence=False)
result = reminders.fetchone(r_id).update(title=title,
time=time,
notification_services=notification_services,
text=text,
repeat_quantity=repeat_quantity,
repeat_interval=repeat_interval,
color=color)
result = reminders.fetchone(r_id).update(title=inputs['title'],
time=inputs['time'],
notification_services=inputs['notification_services'],
text=inputs['text'],
repeat_quantity=inputs['repeat_quantity'],
repeat_interval=inputs['repeat_interval'],
color=inputs['color'])
return return_api(result)
elif request.method == 'DELETE':
@@ -589,7 +765,8 @@ def api_get_reminder(r_id: int):
@api.route('/templates', methods=['GET', 'POST'])
@error_handler
@auth
def api_get_templates():
@input_validation
def api_get_templates(inputs: Dict[str, Any]):
"""
Endpoint: /templates
Description: Manage the templates
@@ -620,27 +797,21 @@ def api_get_templates():
templates: Templates = g.user_data.templates
if request.method == 'GET':
sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=Templates.sort_functions)
result = templates.fetchall(sort_by)
result = templates.fetchall(inputs['sort_by'])
return return_api(result)
elif request.method == 'POST':
data = request.get_json()
title = extract_key(data, 'title')
notification_services = extract_key(data, 'notification_services')
text = extract_key(data, 'text', check_existence=False)
color = extract_key(data, 'color', check_existence=False)
result = templates.add(title=title,
notification_services=notification_services,
text=text,
color=color)
result = templates.add(title=inputs['title'],
notification_services=inputs['notification_services'],
text=inputs['text'],
color=inputs['color'])
return return_api(result.get(), code=201)
@api.route('/templates/search', methods=['GET'])
@error_handler
@auth
def api_templates_query():
@input_validation
def api_templates_query(inputs: Dict[str, str]):
"""
Endpoint: /templates/search
Description: Search through the list of templates
@@ -656,16 +827,14 @@ def api_templates_query():
400:
KeyNotFound: One of the required parameters was not given
"""
query = extract_key(request.values, 'query')
sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=Templates.sort_functions)
result = g.user_data.templates.search(query, sort_by)
result = g.user_data.templates.search(inputs['query'], inputs['sort_by'])
return return_api(result)
@api.route('/templates/<int:t_id>', methods=['GET', 'PUT', 'DELETE'])
@error_handler
@auth
def api_get_template(t_id: int):
@input_validation
def api_get_template(inputs: Dict[str, Any], t_id: int):
"""
Endpoint: /templates/<t_id>
Description: Manage a specific template
@@ -708,16 +877,10 @@ def api_get_template(t_id: int):
return return_api(result)
elif request.method == 'PUT':
data = request.get_json()
title = extract_key(data, 'title', check_existence=False)
notification_services = extract_key(data, 'notification_services', check_existence=False)
text = extract_key(data, 'text', check_existence=False)
color = extract_key(data, 'color', check_existence=False)
result = template.update(title=title,
notification_services=notification_services,
text=text,
color=color)
result = template.update(title=inputs['title'],
notification_services=inputs['notification_services'],
text=inputs['text'],
color=inputs['color'])
return return_api(result)
elif request.method == 'DELETE':
@@ -731,7 +894,8 @@ def api_get_template(t_id: int):
@api.route('/staticreminders', methods=['GET', 'POST'])
@error_handler
@auth
def api_static_reminders_list():
@input_validation
def api_static_reminders_list(inputs: Dict[str, Any]):
"""
Endpoint: /staticreminders
Description: Manage the static reminders
@@ -762,27 +926,21 @@ def api_static_reminders_list():
reminders: StaticReminders = g.user_data.static_reminders
if request.method == 'GET':
sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=StaticReminders.sort_functions)
result = reminders.fetchall(sort_by)
result = reminders.fetchall(inputs['sort_by'])
return return_api(result)
elif request.method == 'POST':
data = request.get_json()
title = extract_key(data, 'title')
notification_services = extract_key(data, 'notification_services')
text = extract_key(data, 'text', check_existence=False)
color = extract_key(data, 'color', check_existence=False)
result = reminders.add(title=title,
notification_services=notification_services,
text=text,
color=color)
result = reminders.add(title=inputs['title'],
notification_services=inputs['notification_services'],
text=inputs['text'],
color=inputs['color'])
return return_api(result.get(), code=201)
@api.route('/staticreminders/search', methods=['GET'])
@error_handler
@auth
def api_static_reminders_query():
@input_validation
def api_static_reminders_query(inputs: Dict[str, str]):
"""
Endpoint: /staticreminders/search
Description: Search through the list of staticreminders
@@ -798,16 +956,14 @@ def api_static_reminders_query():
400:
KeyNotFound: One of the required parameters was not given
"""
query = extract_key(request.values, 'query')
sort_by = extract_key(request.values, 'sort_by', check_existence=False, sort_options=StaticReminders.sort_functions)
result = g.user_data.static_reminders.search(query, sort_by)
result = g.user_data.static_reminders.search(inputs['query'], inputs['sort_by'])
return return_api(result)
@api.route('/staticreminders/<int:r_id>', methods=['GET', 'POST', 'PUT', 'DELETE'])
@error_handler
@auth
def api_get_static_reminder(r_id: int):
@input_validation
def api_get_static_reminder(inputs: Dict[str, Any], r_id: int):
"""
Endpoint: /staticreminders/<r_id>
Description: Manage a specific static reminder
@@ -858,16 +1014,10 @@ def api_get_static_reminder(r_id: int):
return return_api({})
elif request.method == 'PUT':
data = request.get_json()
title = extract_key(data, 'title', check_existence=False)
notification_services = extract_key(data, 'notification_services', check_existence=False)
text = extract_key(data, 'text', check_existence=False)
color = extract_key(data, 'color', check_existence=False)
result = reminders.fetchone(r_id).update(title=title,
notification_services=notification_services,
text=text,
color=color)
result = reminders.fetchone(r_id).update(title=inputs['title'],
notification_services=inputs['notification_services'],
text=inputs['text'],
color=inputs['color'])
return return_api(result)
elif request.method == 'DELETE':

View File

@@ -94,7 +94,7 @@ function testReminder() {
const ns = [...
document.querySelectorAll('.notification-service-list input[type="checkbox"]:checked')
].map(c => parseInt(c.dataset.id))
if (!ns) {
if (!ns.length) {
input.classList.add('error-input');
input.title = 'No notification service set';
return

View File

@@ -41,7 +41,6 @@ class Test_API(unittest.TestCase):
self.assertEqual(result(ReminderNotFound), return_api(**ReminderNotFound.api_response))
self.assertEqual(result(NotificationServiceNotFound), return_api(**NotificationServiceNotFound.api_response))
self.assertEqual(result(InvalidTime), return_api(**InvalidTime.api_response))
self.assertEqual(result(InvalidURL), return_api(**InvalidURL.api_response))
self.assertEqual(result(NotificationServiceInUse, 'test'), return_api(**NotificationServiceInUse('test').api_response))
self.assertEqual(result(KeyNotFound, 'test'), return_api(**KeyNotFound('test').api_response))
self.assertEqual(result(InvalidKeyValue, 'test', 'value'), return_api(**InvalidKeyValue('test', 'value').api_response))