Backend Refactor

This commit is contained in:
CasVT
2024-02-01 14:42:10 +01:00
parent a8b85a975f
commit ccdb16eef5
20 changed files with 1735 additions and 1071 deletions

View File

@@ -7,5 +7,7 @@
"*_test.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true
"python.testing.unittestEnabled": true,
"python.analysis.autoImportCompletions": true,
"python.analysis.typeCheckingMode": "off"
}

78
MIND.py
View File

@@ -1,20 +1,24 @@
#!/usr/bin/env python3
#-*- coding: utf-8 -*-
"""
The main file where MIND is started from
"""
import logging
from os import makedirs, urandom
from os.path import abspath, dirname, isfile, join
from os.path import dirname, isfile
from shutil import move
from sys import version_info
from flask import Flask, render_template, request
from waitress.server import create_server
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from backend.db import DBConnection, ThreadedTaskDispatcher, close_db, setup_db
from frontend.api import (admin_api, admin_api_prefix, api, api_prefix,
reminder_handler)
from frontend.ui import ui
from backend.helpers import check_python_version, folder_path
from backend.reminders import ReminderHandler
from frontend.api import admin_api, admin_api_prefix, api, api_prefix
from frontend.ui import UIVariables, ui
HOST = '0.0.0.0'
PORT = '8080'
@@ -23,35 +27,33 @@ LOGGING_LEVEL = logging.INFO
THREADS = 10
DB_FILENAME = 'db', 'MIND.db'
UIVariables.url_prefix = URL_PREFIX
logging.basicConfig(
level=LOGGING_LEVEL,
format='[%(asctime)s][%(threadName)s][%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
def _folder_path(*folders) -> str:
"""Turn filepaths relative to the project folder into absolute paths
Returns:
str: The absolute filepath
"""
return join(dirname(abspath(__file__)), *folders)
def _create_app() -> Flask:
"""Create a Flask app instance
Returns:
Flask: The created app instance
"""
app = Flask(
__name__,
template_folder=_folder_path('frontend','templates'),
static_folder=_folder_path('frontend','static'),
template_folder=folder_path('frontend','templates'),
static_folder=folder_path('frontend','static'),
static_url_path='/static'
)
app.config['SECRET_KEY'] = urandom(32)
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
app.config['JSON_SORT_KEYS'] = False
app.config['APPLICATION_ROOT'] = URL_PREFIX
app.wsgi_app = DispatcherMiddleware(Flask(__name__), {URL_PREFIX: app.wsgi_app})
app.wsgi_app = DispatcherMiddleware(
Flask(__name__),
{URL_PREFIX: app.wsgi_app}
)
# Add error handlers
@app.errorhandler(400)
@@ -68,7 +70,7 @@ def _create_app() -> Flask:
@app.errorhandler(404)
def not_found(e):
if request.path.startswith('/api'):
if request.path.startswith(api_prefix):
return {'error': 'Not Found', 'result': {}}, 404
return render_template('page_not_found.html', url_prefix=logging.URL_PREFIX)
@@ -83,40 +85,42 @@ def _create_app() -> Flask:
def MIND() -> None:
"""The main function of MIND
Returns:
None
"""
# Check python version
if (version_info.major < 3) or (version_info.major == 3 and version_info.minor < 7):
logging.error('Error: the minimum python version required is python3.7 (currently ' + version_info.major + '.' + version_info.minor + '.' + version_info.micro + ')')
logging.info('Starting up MIND')
if not check_python_version():
exit(1)
if isfile(folder_path('db', 'Noted.db')):
move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME))
db_location = folder_path(*DB_FILENAME)
logging.debug(f'Database location: {db_location}')
makedirs(dirname(db_location), exist_ok=True)
DBConnection.file = db_location
# Register web server
# We need to get the value to ui.py but MIND.py imports from ui.py so we get an import loop.
# To go around this, we abuse the fact that the logging module is a singleton.
# We add an attribute to the logging module and in ui.py get the value this way.
logging.URL_PREFIX = URL_PREFIX
app = _create_app()
reminder_handler = ReminderHandler(app.app_context)
with app.app_context():
if isfile(_folder_path('db', 'Noted.db')):
move(_folder_path('db', 'Noted.db'), _folder_path(*DB_FILENAME))
db_location = _folder_path(*DB_FILENAME)
logging.debug(f'Database location: {db_location}')
makedirs(dirname(db_location), exist_ok=True)
DBConnection.file = db_location
setup_db()
reminder_handler.find_next_reminder()
# Create waitress server and run
dispatcher = ThreadedTaskDispatcher()
dispatcher.set_thread_count(THREADS)
server = create_server(app, _dispatcher=dispatcher, host=HOST, port=PORT, threads=THREADS)
server = create_server(
app,
_dispatcher=dispatcher,
host=HOST,
port=PORT,
threads=THREADS
)
logging.info(f'MIND running on http://{HOST}:{PORT}{URL_PREFIX}')
# =================
server.run()
# Stopping thread
# =================
reminder_handler.stop_handling()
return

View File

@@ -1,8 +1,17 @@
#-*- coding: utf-8 -*-
"""
All custom exceptions are defined here
"""
"""
Note: Not all CE's inherit from CustomException.
"""
import logging
from typing import Any, Dict
class CustomException(Exception):
def __init__(self, e=None) -> None:
logging.warning(self.__doc__)
@@ -13,10 +22,18 @@ class UsernameTaken(CustomException):
"""The username is already taken"""
api_response = {'error': 'UsernameTaken', 'result': {}, 'code': 400}
class UsernameInvalid(CustomException):
class UsernameInvalid(Exception):
"""The username contains invalid characters"""
api_response = {'error': 'UsernameInvalid', 'result': {}, 'code': 400}
def __init__(self, username: str):
self.username = username
super().__init__(self.username)
logging.warning(
f'The username contains invalid characters: {username}'
)
return
class UserNotFound(CustomException):
"""The user requested can not be found"""
api_response = {'error': 'UserNotFound', 'result': {}, 'code': 404}
@@ -33,59 +50,81 @@ class NotificationServiceNotFound(CustomException):
"""The notification service was not found"""
api_response = {'error': 'NotificationServiceNotFound', 'result': {}, 'code': 404}
class NotificationServiceInUse(CustomException):
"""The notification service is wished to be deleted but a reminder is still using it"""
class NotificationServiceInUse(Exception):
"""
The notification service is wished to be deleted
but a reminder is still using it
"""
def __init__(self, type: str=''):
self.type = type
super().__init__(self.type)
logging.warning(
f'The notification is wished to be deleted but a reminder of type {type} is still using it'
)
return
@property
def api_response(self) -> Dict[str, Any]:
return {'error': 'NotificationServiceInUse', 'result': {'type': self.type}, 'code': 400}
return {
'error': 'NotificationServiceInUse',
'result': {'type': self.type},
'code': 400
}
class InvalidTime(CustomException):
"""The time given is in the past"""
api_response = {'error': 'InvalidTime', 'result': {}, 'code': 400}
class KeyNotFound(CustomException):
class KeyNotFound(Exception):
"""A key was not found in the input that is required to be given"""
def __init__(self, key: str=''):
self.key = key
super().__init__(self.key)
logging.warning(
"This key was not found in the API request,"
+ f" eventhough it's required: {key}"
)
return
@property
def api_response(self) -> Dict[str, Any]:
return {'error': 'KeyNotFound', 'result': {'key': self.key}, 'code': 400}
return {
'error': 'KeyNotFound',
'result': {'key': self.key},
'code': 400
}
class InvalidKeyValue(CustomException):
class InvalidKeyValue(Exception):
"""The value of a key is invalid"""
def __init__(self, key: str='', value: str=''):
self.key = key
self.value = value
super().__init__(self.key)
logging.warning(
'This key in the API request has an invalid value: ' +
f'{key} = {value}'
)
@property
def api_response(self) -> Dict[str, Any]:
return {'error': 'InvalidKeyValue', 'result': {'key': self.key, 'value': self.value}, 'code': 400}
return {
'error': 'InvalidKeyValue',
'result': {'key': self.key, 'value': self.value},
'code': 400
}
class TemplateNotFound(CustomException):
"""The template was not found"""
api_response = {'error': 'TemplateNotFound', 'result': {}, 'code': 404}
class APIKeyInvalid(CustomException):
class APIKeyInvalid(Exception):
"""The API key is not correct"""
api_response = {'error': 'APIKeyInvalid', 'result': {}, 'code': 401}
def __init__(self, e=None) -> None:
return
class APIKeyExpired(CustomException):
class APIKeyExpired(Exception):
"""The API key has expired"""
api_response = {'error': 'APIKeyExpired', 'result': {}, 'code': 401}
def __init__(self, e=None) -> None:
return
class NewAccountsNotAllowed(CustomException):
"""It's not allowed to create a new account"""
api_response = {'error': 'NewAccountsNotAllowed', 'result': {}, 'code': 403}

View File

@@ -1,11 +1,15 @@
#-*- coding: utf-8 -*-
"""
Setting up and interacting with the database.
"""
import logging
from datetime import datetime
from sqlite3 import Connection, ProgrammingError, Row
from threading import current_thread, main_thread
from time import time
from typing import Union
from typing import Type, Union
from flask import g
from waitress.task import ThreadedTaskDispatcher as OldThreadedTaskDispatcher
@@ -14,14 +18,14 @@ from backend.custom_exceptions import AccessUnauthorized, UserNotFound
__DATABASE_VERSION__ = 8
class Singleton(type):
class DB_Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
i = f'{cls}{current_thread()}'
if (i not in cls._instances
or cls._instances[i].closed):
logging.debug(f'Creating singleton instance: {i}')
cls._instances[i] = super(Singleton, cls).__call__(*args, **kwargs)
cls._instances[i] = super(DB_Singleton, cls).__call__(*args, **kwargs)
return cls._instances[i]
@@ -29,19 +33,21 @@ class ThreadedTaskDispatcher(OldThreadedTaskDispatcher):
def handler_thread(self, thread_no: int) -> None:
super().handler_thread(thread_no)
i = f'{DBConnection}{current_thread()}'
if i in Singleton._instances and not Singleton._instances[i].closed:
if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed:
logging.debug(f'Closing singleton instance: {i}')
Singleton._instances[i].close()
DB_Singleton._instances[i].close()
return
def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
print()
logging.info('Shutting down MIND...')
super().shutdown(cancel_pending, timeout)
DBConnection(20.0).close()
return
class DBConnection(Connection, metaclass=Singleton):
class DBConnection(Connection, metaclass=DB_Singleton):
file = ''
def __init__(self, timeout: float) -> None:
logging.debug(f'Opening database connection for {current_thread()}')
super().__init__(self.file, timeout=timeout)
@@ -55,11 +61,13 @@ class DBConnection(Connection, metaclass=Singleton):
super().close()
return
def get_db(output_type: Union[dict, tuple]=tuple):
def get_db(output_type: Union[Type[dict], Type[tuple]]=tuple):
"""Get a database cursor instance. Coupled to Flask's g.
Args:
output_type (Union[dict, tuple], optional): The type of output: a tuple or dictionary with the row values. Defaults to tuple.
output_type (Union[Type[dict], Type[tuple]], optional):
The type of output: a tuple or dictionary with the row values.
Defaults to tuple.
Returns:
Cursor: The Cursor instance to use
@@ -220,7 +228,7 @@ def migrate_db(current_db_version: int) -> None:
if current_db_version == 7:
# V7 -> V8
from backend.settings import _format_setting, default_settings
from backend.users import register_user
from backend.users import Users
cursor.executescript("""
DROP TABLE config;
@@ -239,7 +247,7 @@ def migrate_db(current_db_version: int) -> None:
default_settings.items()
)
)
cursor.executescript("""
ALTER TABLE users
ADD admin BOOL NOT NULL DEFAULT 0;
@@ -248,9 +256,9 @@ def migrate_db(current_db_version: int) -> None:
SET username = 'admin_old'
WHERE username = 'admin';
""")
register_user('admin', 'admin')
Users().add('admin', 'admin', True)
cursor.execute("""
UPDATE users
SET admin = 1
@@ -264,6 +272,8 @@ def setup_db() -> None:
"""
from backend.settings import (_format_setting, default_settings, get_setting,
set_setting)
from backend.users import Users
cursor = get_db()
cursor.execute("PRAGMA journal_mode = wal;")
@@ -348,11 +358,22 @@ def setup_db() -> None:
default_settings.items()
)
)
current_db_version = get_setting('database_version')
logging.debug(f'Current database version {current_db_version} and desired database version {__DATABASE_VERSION__}')
if current_db_version < __DATABASE_VERSION__:
logging.debug(
f'Database migration: {current_db_version} -> {__DATABASE_VERSION__}'
)
migrate_db(current_db_version)
set_setting('database_version', __DATABASE_VERSION__)
users = Users()
if not 'admin' in users:
users.add('admin', 'admin', True)
cursor.execute("""
UPDATE users
SET admin = 1
WHERE username = 'admin';
""")
return

91
backend/helpers.py Normal file
View File

@@ -0,0 +1,91 @@
#-*- coding: utf-8 -*-
"""
General functions
"""
import logging
from enum import Enum
from os.path import abspath, dirname, join
from sys import version_info
def folder_path(*folders) -> str:
"""Turn filepaths relative to the project folder into absolute paths
Returns:
str: The absolute filepath
"""
return join(dirname(dirname(abspath(__file__))), *folders)
def check_python_version() -> bool:
"""Check if the python version that is used is a minimum version.
Returns:
bool: Whether or not the python version is version 3.8 or above or not.
"""
if not (version_info.major == 3 and version_info.minor >= 8):
logging.critical(
'The minimum python version required is python3.8 ' +
'(currently ' + version_info.major + '.' + version_info.minor + '.' + version_info.micro + ').'
)
return False
return True
def search_filter(query: str, result: dict) -> bool:
"""Filter library results based on a query.
Args:
query (str): The query to filter with.
result (dict): The library result to check.
Returns:
bool: Whether or not the result passes the filter.
"""
query = query.lower()
return (
query in result["title"].lower()
or query in result["text"].lower()
)
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
c = str(cls)
if c not in cls._instances:
cls._instances[c] = super().__call__(*args, **kwargs)
return cls._instances[c]
class BaseEnum(Enum):
def __eq__(self, other) -> bool:
return self.value == other
class TimelessSortingMethod(BaseEnum):
TITLE = (lambda r: (r['title'], r['text'], r['color']), False)
TITLE_REVERSED = (lambda r: (r['title'], r['text'], r['color']), True)
DATE_ADDED = (lambda r: r['id'], False)
DATE_ADDED_REVERSED = (lambda r: r['id'], True)
class SortingMethod(BaseEnum):
TIME = (lambda r: (r['time'], r['title'], r['text'], r['color']), False)
TIME_REVERSED = (lambda r: (r['time'], r['title'], r['text'], r['color']), True)
TITLE = (lambda r: (r['title'], r['time'], r['text'], r['color']), False)
TITLE_REVERSED = (lambda r: (r['title'], r['time'], r['text'], r['color']), True)
DATE_ADDED = (lambda r: r['id'], False)
DATE_ADDED_REVERSED = (lambda r: r['id'], True)
class RepeatQuantity(BaseEnum):
YEARS = "years"
MONTHS = "months"
WEEKS = "weeks"
DAYS = "days"
HOURS = "hours"
MINUTES = "minutes"

View File

@@ -74,7 +74,10 @@ def get_apprise_services() -> List[Dict[str, Union[str, Dict[str, list]]]]:
'prefix': content.get('prefix'),
'regex': process_regex(content.get('regex'))
}
for content, _ in ((entry['details']['tokens'][e], handled_tokens.add(e)) for e in v['group'])
for content, _ in (
(entry['details']['tokens'][e], handled_tokens.add(e))
for e in v['group']
)
]
}
for k, v in
@@ -175,8 +178,13 @@ class NotificationService:
def __init__(self, user_id: int, notification_service_id: int) -> None:
self.id = notification_service_id
if not get_db().execute(
"SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;",
if not get_db().execute("""
SELECT 1
FROM notification_services
WHERE id = ?
AND user_id = ?
LIMIT 1;
""",
(self.id, user_id)
).fetchone():
raise NotificationServiceNotFound
@@ -187,8 +195,12 @@ class NotificationService:
Returns:
dict: The info about the notification service
"""
result = dict(get_db(dict).execute(
"SELECT id, title, url FROM notification_services WHERE id = ? LIMIT 1",
result = dict(get_db(dict).execute("""
SELECT id, title, url
FROM notification_services
WHERE id = ?
LIMIT 1
""",
(self.id,)
).fetchone())
@@ -330,7 +342,7 @@ class NotificationServices:
url (str): The apprise url of the service
Returns:
dict: The info about the new service
NotificationService: The instance representing the new service
"""
logging.info(f'Adding notification service with {title=}, {url=}')

View File

@@ -4,28 +4,65 @@ import logging
from datetime import datetime
from sqlite3 import IntegrityError
from threading import Timer
from typing import List, Literal
from typing import List, Literal, Union
from apprise import Apprise
from dateutil.relativedelta import relativedelta, weekday
from flask import Flask
from dateutil.relativedelta import relativedelta
from dateutil.relativedelta import weekday as du_weekday
from backend.custom_exceptions import (InvalidKeyValue, InvalidTime,
NotificationServiceNotFound,
ReminderNotFound)
from backend.db import close_db, get_db
from backend.db import get_db
from backend.helpers import RepeatQuantity, Singleton, SortingMethod, search_filter
filter_function = lambda query, p: (
query in p["title"].lower()
or query in p["text"].lower()
)
def __next_selected_day(
weekdays: List[int],
weekday: int
) -> int:
"""Find the next allowed day in the week.
Args:
weekdays (List[int]): The days of the week that are allowed.
Monday is 0, Sunday is 6.
weekday (int): The current weekday.
Returns:
int: The next allowed weekday.
"""
return (
# Get all days later than current, then grab first one.
[d for d in weekdays if weekday < d]
or
# weekday is last allowed day, so it should grab the first
# allowed day of the week.
weekdays
)[0]
def _find_next_time(
original_time: int,
repeat_quantity: Literal["years", "months", "weeks", "days", "hours", "minutes"],
repeat_interval: int,
weekdays: List[int]
repeat_quantity: Union[RepeatQuantity, None],
repeat_interval: Union[int, None],
weekdays: Union[List[int], None]
) -> int:
"""Calculate the next timestep based on original time and repeat/interval
values.
Args:
original_time (int): The original time of the repeating timestamp.
repeat_quantity (Union[RepeatQuantity, None]): If set, what the quantity
is of the repetition.
repeat_interval (Union[int, None]): If set, the value of the repetition.
weekdays (Union[List[int], None]): If set, on which days the time can
continue. Monday is 0, Sunday is 6.
Returns:
int: The next timestamp in the future.
"""
if weekdays is not None:
weekdays.sort()
@@ -33,152 +70,50 @@ def _find_next_time(
current_time = datetime.fromtimestamp(datetime.utcnow().timestamp())
if repeat_quantity is not None:
td = relativedelta(**{repeat_quantity: repeat_interval})
td = relativedelta(**{repeat_quantity.value: repeat_interval})
while new_time <= current_time:
new_time += td
else:
next_day = ([d for d in weekdays if new_time.weekday() < d] or weekdays)[0]
proposed_time = new_time + relativedelta(weekday=weekday(next_day))
if proposed_time == new_time:
proposed_time += relativedelta(weekday=weekday(next_day, 2))
new_time = proposed_time
while new_time <= current_time:
next_day = ([d for d in weekdays if new_time.weekday() < d] or weekdays)[0]
proposed_time = new_time + relativedelta(weekday=weekday(next_day))
# We run the loop contents at least once and then actually use the cond.
# This is because we need to force the 'free' date to go to one of the
# selected weekdays.
# Say it's Monday, we set a reminder for Wednesday and make it repeat
# on Tuesday and Thursday. Then the first notification needs to go on
# Thurday, not Wednesday. So run code at least once to force that.
# Afterwards, it can run normally to push the timestamp into the future.
one_to_go = True
while one_to_go or new_time <= current_time:
next_day = __next_selected_day(weekdays, new_time.weekday())
proposed_time = new_time + relativedelta(weekday=du_weekday(next_day))
if proposed_time == new_time:
proposed_time += relativedelta(weekday=weekday(next_day, 2))
proposed_time += relativedelta(weekday=du_weekday(next_day, 2))
new_time = proposed_time
one_to_go = False
result = int(new_time.timestamp())
logging.debug(
f'{original_time=}, {current_time=} and interval of {repeat_interval} {repeat_quantity} leads to {result}'
f'{original_time=}, {current_time=} ' +
f'and interval of {repeat_interval} {repeat_quantity} ' +
f'leads to {result}'
)
return result
class ReminderHandler:
"""Handle set reminders
"""
def __init__(self, context) -> None:
self.context = context
self.next_trigger = {
'thread': None,
'time': None
}
return
def __trigger_reminders(self, time: int) -> None:
"""Trigger all reminders that are set for a certain time
Args:
time (int): The time of the reminders to trigger
"""
with self.context():
cursor = get_db(dict)
cursor.execute("""
SELECT
r.id,
r.title, r.text,
r.repeat_quantity, r.repeat_interval,
r.weekdays,
r.original_time
FROM reminders r
WHERE time = ?;
""", (time,))
reminders = list(map(dict, cursor))
for reminder in reminders:
cursor.execute("""
SELECT url
FROM reminder_services rs
INNER JOIN notification_services ns
ON rs.notification_service_id = ns.id
WHERE rs.reminder_id = ?;
""", (reminder['id'],))
# Send of reminder
a = Apprise()
for url in cursor:
a.add(url['url'])
a.notify(title=reminder["title"], body=reminder["text"])
if reminder['repeat_quantity'] is None and reminder['weekdays'] is None:
# Delete the reminder from the database
cursor.execute(
"DELETE FROM reminders WHERE id = ?;",
(reminder['id'],)
)
logging.info(f'Deleted reminder {reminder["id"]}')
else:
# Set next time
new_time = _find_next_time(
reminder['original_time'],
reminder['repeat_quantity'],
reminder['repeat_interval'],
[int(d) for d in reminder['weekdays'].split(',')] if reminder['weekdays'] is not None else None
)
cursor.execute(
"UPDATE reminders SET time = ? WHERE id = ?;",
(new_time, reminder['id'])
)
self.next_trigger.update({
'thread': None,
'time': None
})
self.find_next_reminder()
def find_next_reminder(self, time: int=None) -> None:
"""Determine when the soonest reminder is and set the timer to that time
Args:
time (int, optional): The timestamp to check for. Otherwise check soonest in database. Defaults to None.
"""
if not time:
with self.context():
time = get_db().execute("""
SELECT DISTINCT r1.time
FROM reminders r1
LEFT JOIN reminders r2
ON r1.time > r2.time
WHERE r2.id IS NULL;
""").fetchone()
if time is None:
return
time = time[0]
if (self.next_trigger['thread'] is None
or time < self.next_trigger['time']):
if self.next_trigger['thread'] is not None:
self.next_trigger['thread'].cancel()
t = time - datetime.utcnow().timestamp()
self.next_trigger['thread'] = Timer(
t,
self.__trigger_reminders,
(time,)
)
self.next_trigger['thread'].name = "ReminderHandler"
self.next_trigger['thread'].start()
self.next_trigger['time'] = time
def stop_handling(self) -> None:
"""Stop the timer if it's active
"""
if self.next_trigger['thread'] is not None:
self.next_trigger['thread'].cancel()
return
handler_context = Flask('handler')
handler_context.teardown_appcontext(close_db)
reminder_handler = ReminderHandler(handler_context.app_context)
class Reminder:
"""Represents a reminder
"""
def __init__(self, user_id: int, reminder_id: int):
def __init__(self, user_id: int, reminder_id: int) -> None:
"""Create an instance.
Args:
user_id (int): The ID of the user.
reminder_id (int): The ID of the reminder.
Raises:
ReminderNotFound: Reminder with given ID does not exist or is not
owned by user.
"""
self.id = reminder_id
# Check if reminder exists
@@ -188,6 +123,26 @@ class Reminder:
).fetchone():
raise ReminderNotFound
return
def _get_notification_services(self) -> List[int]:
"""Get ID's of notification services linked to the reminder.
Returns:
List[int]: The list with ID's.
"""
result = [
r[0]
for r in get_db().execute("""
SELECT notification_service_id
FROM reminder_services
WHERE reminder_id = ?;
""",
(self.id,)
)
]
return result
def get(self) -> dict:
"""Get info about the reminder
@@ -211,46 +166,65 @@ class Reminder:
).fetchone()
reminder = dict(reminder)
reminder['notification_services'] = list(map(lambda r: r[0], get_db().execute("""
SELECT notification_service_id
FROM reminder_services
WHERE reminder_id = ?;
""", (self.id,))))
reminder['notification_services'] = self._get_notification_services()
return reminder
def update(
self,
title: str = None,
time: int = None,
notification_services: List[int] = None,
text: str = None,
repeat_quantity: Literal["years", "months", "weeks", "days", "hours", "minutes"] = None,
repeat_interval: int = None,
weekdays: List[int] = None,
color: str = None
title: Union[None, str] = None,
time: Union[None, int] = None,
notification_services: Union[None, List[int]] = None,
text: Union[None, str] = None,
repeat_quantity: Union[None, RepeatQuantity] = None,
repeat_interval: Union[None, int] = None,
weekdays: Union[None, List[int]] = None,
color: Union[None, str] = None
) -> dict:
"""Edit the reminder
"""Edit the reminder.
Args:
title (str): The new title of the entry. Defaults to None.
time (int): The new UTC epoch timestamp the the reminder should be send. Defaults to None.
notification_services (List[int]): The new list of id's of the notification services to use to send the reminder. Defaults to None.
text (str, optional): The new body of the reminder. Defaults to None.
repeat_quantity (Literal["years", "months", "weeks", "days", "hours", "minutes"], optional): The new quantity of the repeat specified for the reminder. Defaults to None.
repeat_interval (int, optional): The new amount of repeat_quantity, like "5" (hours). Defaults to None.
weekdays (List[int], optional): The new indexes of the days of the week that the reminder should run. Defaults to None.
color (str, optional): The new hex code of the color of the reminder, which is shown in the web-ui. Defaults to None.
title (Union[None, str]): The new title of the entry.
Defaults to None.
time (Union[None, int]): The new UTC epoch timestamp when the
reminder should be send.
Defaults to None.
notification_services (Union[None, List[int]]): The new list
of id's of the notification services to use to send the reminder.
Defaults to None.
text (Union[None, str], optional): The new body of the reminder.
Defaults to None.
repeat_quantity (Union[None, RepeatQuantity], optional): The new
quantity of the repeat specified for the reminder.
Defaults to None.
repeat_interval (Union[None, int], optional): The new amount of
repeat_quantity, like "5" (hours).
Defaults to None.
weekdays (Union[None, List[int]], optional): The new indexes of
the days of the week that the reminder should run.
Defaults to None.
color (Union[None, str], optional): The new hex code of the color
of the reminder, which is shown in the web-ui.
Defaults to None.
Note about args:
Either repeat_quantity and repeat_interval are given, weekdays is given or neither, but not both.
Either repeat_quantity and repeat_interval are given, weekdays is
given or neither, but not both.
Raises:
NotificationServiceNotFound: One of the notification services was not found
InvalidKeyValue: The value of one of the keys is not valid or the "Note about args" is violated
NotificationServiceNotFound: One of the notification services was not found.
InvalidKeyValue: The value of one of the keys is not valid or
the "Note about args" is violated.
Returns:
dict: The new reminder info
dict: The new reminder info.
"""
logging.info(
f'Updating notification service {self.id}: '
@@ -264,8 +238,9 @@ class Reminder:
raise InvalidKeyValue('repeat_quantity', repeat_quantity)
elif repeat_quantity is not None and repeat_interval is None:
raise InvalidKeyValue('repeat_interval', repeat_interval)
elif weekdays is not None and repeat_quantity is not None and repeat_interval is not None:
elif weekdays is not None and repeat_quantity is not None:
raise InvalidKeyValue('weekdays', weekdays)
repeated_reminder = (
(repeat_quantity is not None and repeat_interval is not None)
or weekdays is not None
@@ -285,26 +260,35 @@ class Reminder:
'text': text,
'repeat_quantity': repeat_quantity,
'repeat_interval': repeat_interval,
'weekdays': ",".join(map(str, sorted(weekdays))) if weekdays is not None else None,
'weekdays':
",".join(map(str, sorted(weekdays)))
if weekdays is not None else
None,
'color': color
}
for k, v in new_values.items():
if k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color') or v is not None:
if (
k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color')
or v is not None
):
data[k] = v
# Update database
if repeated_reminder:
next_time = _find_next_time(
data["time"],
data["repeat_quantity"], data["repeat_interval"],
RepeatQuantity(data["repeat_quantity"]),
data["repeat_interval"],
weekdays
)
cursor.execute("""
UPDATE reminders
SET
title=?, text=?,
title=?,
text=?,
time=?,
repeat_quantity=?, repeat_interval=?,
repeat_quantity=?,
repeat_interval=?,
weekdays=?,
original_time=?,
color=?
@@ -326,9 +310,11 @@ class Reminder:
cursor.execute("""
UPDATE reminders
SET
title=?, text=?,
title=?,
text=?,
time=?,
repeat_quantity=?, repeat_interval=?,
repeat_quantity=?,
repeat_interval=?,
weekdays=?,
color=?
WHERE id = ?;
@@ -346,18 +332,29 @@ class Reminder:
if notification_services:
cursor.connection.isolation_level = None
cursor.execute("BEGIN TRANSACTION;")
cursor.execute("DELETE FROM reminder_services WHERE reminder_id = ?", (self.id,))
cursor.execute(
"DELETE FROM reminder_services WHERE reminder_id = ?",
(self.id,)
)
try:
cursor.executemany(
"INSERT INTO reminder_services(reminder_id, notification_service_id) VALUES (?,?)",
cursor.executemany("""
INSERT INTO reminder_services(
reminder_id,
notification_service_id
)
VALUES (?,?);
""",
((self.id, s) for s in notification_services)
)
cursor.execute("COMMIT;")
except IntegrityError:
raise NotificationServiceNotFound
cursor.connection.isolation_level = ""
reminder_handler.find_next_reminder(next_time)
finally:
cursor.connection.isolation_level = ""
ReminderHandler().find_next_reminder(next_time)
return self.get()
def delete(self) -> None:
@@ -365,74 +362,76 @@ class Reminder:
"""
logging.info(f'Deleting reminder {self.id}')
get_db().execute("DELETE FROM reminders WHERE id = ?", (self.id,))
reminder_handler.find_next_reminder()
ReminderHandler().find_next_reminder()
return
class Reminders:
"""Represents the reminder library of the user account
"""
sort_functions = {
'time': (lambda r: (r['time'], r['title'], r['text'], r['color']), False),
'time_reversed': (lambda r: (r['time'], r['title'], r['text'], r['color']), True),
'title': (lambda r: (r['title'], r['time'], r['text'], r['color']), False),
'title_reversed': (lambda r: (r['title'], r['time'], r['text'], r['color']), True),
'date_added': (lambda r: r['id'], False),
'date_added_reversed': (lambda r: r['id'], True)
}
def __init__(self, user_id: int):
self.user_id = user_id
"""
def fetchall(self, sort_by: Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"] = "time") -> List[dict]:
def __init__(self, user_id: int) -> None:
"""Create an instance.
Args:
user_id (int): The ID of the user.
"""
self.user_id = user_id
return
def fetchall(
self,
sort_by: SortingMethod = SortingMethod.TIME
) -> List[dict]:
"""Get all reminders
Args:
sort_by (Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "time".
sort_by (SortingMethod, optional): How to sort the result.
Defaults to SortingMethod.TIME.
Returns:
List[dict]: The id, title, text, time and color of each reminder
"""
sort_function = self.sort_functions.get(
sort_by,
self.sort_functions['time']
)
# Fetch all reminders
reminders: list = list(map(dict, get_db(dict).execute("""
SELECT
id,
title, text,
time,
repeat_quantity,
repeat_interval,
weekdays,
color
FROM reminders
WHERE user_id = ?;
""",
(self.user_id,)
)))
reminders = [
dict(r)
for r in get_db(dict).execute("""
SELECT
id,
title, text,
time,
repeat_quantity,
repeat_interval,
weekdays,
color
FROM reminders
WHERE user_id = ?;
""",
(self.user_id,)
)
]
# Sort result
reminders.sort(key=sort_function[0], reverse=sort_function[1])
reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1])
return reminders
def search(self, query: str, sort_by: Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"] = "time") -> List[dict]:
def search(
self,
query: str,
sort_by: SortingMethod = SortingMethod.TIME) -> List[dict]:
"""Search for reminders
Args:
query (str): The term to search for
sort_by (Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "time".
query (str): The term to search for.
sort_by (SortingMethod, optional): How to sort the result.
Defaults to SortingMethod.TIME.
Returns:
List[dict]: All reminders that match. Similar output to self.fetchall
"""
query = query.lower()
reminders = list(filter(
lambda p: filter_function(query, p),
self.fetchall(sort_by)
))
"""
reminders = [
r for r in self.fetchall(sort_by)
if search_filter(query, r)
]
return reminders
def fetchone(self, id: int) -> Reminder:
@@ -452,32 +451,51 @@ class Reminders:
time: int,
notification_services: List[int],
text: str = '',
repeat_quantity: Literal["years", "months", "weeks", "days", "hours", "minutes"] = None,
repeat_interval: int = None,
weekdays: List[int] = None,
color: str = None
repeat_quantity: Union[None, RepeatQuantity] = None,
repeat_interval: Union[None, int] = None,
weekdays: Union[None, List[int]] = None,
color: Union[None, str] = None
) -> Reminder:
"""Add a reminder
Args:
title (str): The title of the entry
title (str): The title of the entry.
time (int): The UTC epoch timestamp the the reminder should be send.
notification_services (List[int]): The id's of the notification services to use to send the reminder.
text (str, optional): The body of the reminder. Defaults to ''.
repeat_quantity (Literal["years", "months", "weeks", "days", "hours", "minutes"], optional): The quantity of the repeat specified for the reminder. Defaults to None.
repeat_interval (int, optional): The amount of repeat_quantity, like "5" (hours). Defaults to None.
weekdays (List[int], optional): The indexes of the days of the week that the reminder should run. Defaults to None.
color (str, optional): The hex code of the color of the reminder, which is shown in the web-ui. Defaults to None.
notification_services (List[int]): The id's of the notification services
to use to send the reminder.
text (str, optional): The body of the reminder.
Defaults to ''.
repeat_quantity (Union[None, RepeatQuantity], optional): The quantity
of the repeat specified for the reminder.
Defaults to None.
repeat_interval (Union[None, int], optional): The amount of repeat_quantity,
like "5" (hours).
Defaults to None.
weekdays (Union[None, List[int]], optional): The indexes of the days
of the week that the reminder should run.
Defaults to None.
color (Union[None, str], optional): The hex code of the color of the
reminder, which is shown in the web-ui.
Defaults to None.
Note about args:
Either repeat_quantity and repeat_interval are given, weekdays is given or neither, but not both.
Either repeat_quantity and repeat_interval are given,
weekdays is given or neither, but not both.
Raises:
NotificationServiceNotFound: One of the notification services was not found
InvalidKeyValue: The value of one of the keys is not valid or the "Note about args" is violated
NotificationServiceNotFound: One of the notification services was not found.
InvalidKeyValue: The value of one of the keys is not valid
or the "Note about args" is violated.
Returns:
dict: The info about the reminder
dict: The info about the reminder.
"""
logging.info(
f'Adding reminder with {title=}, {time=}, {notification_services=}, '
@@ -492,50 +510,89 @@ class Reminders:
raise InvalidKeyValue('repeat_quantity', repeat_quantity)
elif repeat_quantity is not None and repeat_interval is None:
raise InvalidKeyValue('repeat_interval', repeat_interval)
elif weekdays is not None and repeat_quantity is not None and repeat_interval is not None:
elif (
weekdays is not None
and repeat_quantity is not None
and repeat_interval is not None
):
raise InvalidKeyValue('weekdays', weekdays)
cursor = get_db()
for service in notification_services:
if not cursor.execute(
"SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;",
if not cursor.execute("""
SELECT 1
FROM notification_services
WHERE id = ?
AND user_id = ?
LIMIT 1;
""",
(service, self.user_id)
).fetchone():
raise NotificationServiceNotFound
if repeat_quantity is not None and repeat_interval is not None:
id = cursor.execute("""
INSERT INTO reminders(user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
""", (self.user_id, title, text, time, repeat_quantity, repeat_interval, time, color)
).lastrowid
elif weekdays is not None:
weekdays = ",".join(map(str, sorted(weekdays)))
id = cursor.execute("""
INSERT INTO reminders(user_id, title, text, time, weekdays, original_time, color)
VALUES (?, ?, ?, ?, ?, ?, ?);
""", (self.user_id, title, text, time, weekdays, time, color)
).lastrowid
# Prepare args
if any((repeat_quantity, weekdays)):
original_time = time
time = _find_next_time(
original_time,
repeat_quantity,
repeat_interval,
weekdays
)
else:
id = cursor.execute("""
INSERT INTO reminders(user_id, title, text, time, color)
VALUES (?, ?, ?, ?, ?);
""", (self.user_id, title, text, time, color)
).lastrowid
original_time = None
if weekdays is not None:
weekdays = ",".join(map(str, sorted(weekdays)))
if repeat_quantity is not None:
repeat_quantity = repeat_quantity.value
cursor.connection.isolation_level = None
cursor.execute("BEGIN TRANSACTION;")
id = cursor.execute("""
INSERT INTO reminders(
user_id,
title, text,
time,
repeat_quantity, repeat_interval,
weekdays,
original_time,
color
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""", (
self.user_id,
title, text,
time,
repeat_quantity,
repeat_interval,
weekdays,
original_time,
color
)).lastrowid
try:
cursor.executemany(
"INSERT INTO reminder_services(reminder_id, notification_service_id) VALUES (?, ?);",
cursor.executemany("""
INSERT INTO reminder_services(
reminder_id,
notification_service_id
)
VALUES (?, ?);
""",
((id, service) for service in notification_services)
)
cursor.execute("COMMIT;")
except IntegrityError:
raise NotificationServiceNotFound
reminder_handler.find_next_reminder(time)
# Return info
finally:
cursor.connection.isolation_level = ''
ReminderHandler().find_next_reminder(time)
return self.fetchone(id)
def test_reminder(
@@ -544,23 +601,164 @@ class Reminders:
notification_services: List[int],
text: str = ''
) -> None:
"""Test send a reminder draft
"""Test send a reminder draft.
Args:
title (str): Title title of the entry
notification_service (int): The id of the notification service to use to send the reminder
text (str, optional): The body of the reminder. Defaults to ''.
title (str): Title title of the entry.
notification_service (int): The id of the notification service to
use to send the reminder.
text (str, optional): The body of the reminder.
Defaults to ''.
"""
logging.info(f'Testing reminder with {title=}, {notification_services=}, {text=}')
a = Apprise()
cursor = get_db(dict)
for service in notification_services:
url = cursor.execute(
"SELECT url FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;",
url = cursor.execute("""
SELECT url
FROM notification_services
WHERE id = ?
AND user_id = ?
LIMIT 1;
""",
(service, self.user_id)
).fetchone()
if not url:
raise NotificationServiceNotFound
a.add(url[0])
a.notify(title=title, body=text)
return
class ReminderHandler(metaclass=Singleton):
"""Handle set reminders.
Note: Singleton.
"""
def __init__(self, context) -> None:
"""Create instance of handler.
Args:
context (AppContext): `Flask.app_context`
"""
self.context = context
self.thread: Union[Timer, None] = None
self.time: Union[int, None] = None
return
def __trigger_reminders(self, time: int) -> None:
"""Trigger all reminders that are set for a certain time
Args:
time (int): The time of the reminders to trigger
"""
with self.context():
cursor = get_db(dict)
reminders = [
dict(r)
for r in cursor.execute("""
SELECT
id, user_id,
title, text,
repeat_quantity, repeat_interval,
weekdays,
original_time
FROM reminders
WHERE time = ?;
""",
(time,)
)
]
for reminder in reminders:
cursor.execute("""
SELECT url
FROM reminder_services rs
INNER JOIN notification_services ns
ON rs.notification_service_id = ns.id
WHERE rs.reminder_id = ?;
""",
(reminder['id'],)
)
# Send reminder
a = Apprise()
for url in cursor:
a.add(url['url'])
a.notify(title=reminder["title"], body=reminder["text"])
self.thread = None
self.time = None
if (reminder['repeat_quantity'], reminder['weekdays']) == (None, None):
# Delete the reminder from the database
Reminder(reminder["user_id"], reminder["id"]).delete()
else:
# Set next time
new_time = _find_next_time(
reminder['original_time'],
RepeatQuantity(reminder['repeat_quantity']),
reminder['repeat_interval'],
[int(d) for d in reminder['weekdays'].split(',')]
if reminder['weekdays'] is not None else
None
)
cursor.execute(
"UPDATE reminders SET time = ? WHERE id = ?;",
(new_time, reminder['id'])
)
self.find_next_reminder()
return
def find_next_reminder(self, time: int=None) -> None:
"""Determine when the soonest reminder is and set the timer to that time
Args:
time (int, optional): The timestamp to check for.
Otherwise check soonest in database.
Defaults to None.
"""
if not time:
with self.context():
time = get_db().execute("""
SELECT DISTINCT r1.time
FROM reminders r1
LEFT JOIN reminders r2
ON r1.time > r2.time
WHERE r2.id IS NULL;
""").fetchone()
if time is None:
return
time = time[0]
if (
self.thread is None
or time < self.time
):
if self.thread is not None:
self.thread.cancel()
t = time - datetime.utcnow().timestamp()
self.thread = Timer(
t,
self.__trigger_reminders,
(time,)
)
self.thread.name = "ReminderHandler"
self.thread.start()
self.time = time
return
def stop_handling(self) -> None:
"""Stop the timer if it's active
"""
if self.thread is not None:
self.thread.cancel()
return

View File

@@ -1,5 +1,9 @@
#-*- coding: utf-8 -*-
"""
Hashing and salting
"""
from base64 import urlsafe_b64encode
from hashlib import pbkdf2_hmac
from secrets import token_bytes
@@ -29,7 +33,6 @@ def generate_salt_hash(password: str) -> Tuple[bytes, bytes]:
Returns:
Tuple[bytes, bytes]: The salt (1) and hashed_password (2)
"""
# Hash the password
salt = token_bytes()
hashed_password = get_hash(salt, password)
del password

View File

@@ -1,5 +1,9 @@
#-*- coding: utf-8 -*-
"""
Getting and setting settings
"""
from backend.custom_exceptions import InvalidKeyValue, KeyNotFound
from backend.db import __DATABASE_VERSION__, get_db
@@ -85,7 +89,7 @@ def get_admin_settings() -> dict:
"""
return dict((
(key, _reverse_format_setting(key, value))
for (key, value) in get_db().execute("""
for key, value in get_db().execute("""
SELECT key, value
FROM config
WHERE

View File

@@ -2,23 +2,30 @@
import logging
from sqlite3 import IntegrityError
from typing import List, Literal
from typing import List, Union
from apprise import Apprise
from backend.custom_exceptions import (NotificationServiceNotFound,
ReminderNotFound)
from backend.db import get_db
from backend.helpers import TimelessSortingMethod, search_filter
filter_function = lambda query, p: (
query in p["title"].lower()
or query in p["text"].lower()
)
class StaticReminder:
"""Represents a static reminder
"""
def __init__(self, user_id: int, reminder_id: int) -> None:
"""Create an instance.
Args:
user_id (int): The ID of the user.
reminder_id (int): The ID of the reminder.
Raises:
ReminderNotFound: Reminder with given ID does not exist or is not
owned by user.
"""
self.id = reminder_id
# Check if reminder exists
@@ -27,12 +34,32 @@ class StaticReminder:
(self.id, user_id)
).fetchone():
raise ReminderNotFound
return
def _get_notification_services(self) -> List[int]:
"""Get ID's of notification services linked to the static reminder.
Returns:
List[int]: The list with ID's.
"""
result = [
r[0]
for r in get_db().execute("""
SELECT notification_service_id
FROM reminder_services
WHERE static_reminder_id = ?;
""",
(self.id,)
)
]
return result
def get(self) -> dict:
"""Get info about the static reminder
Returns:
dict: The info about the reminder
dict: The info about the static reminder
"""
reminder = get_db(dict).execute("""
SELECT
@@ -47,28 +74,33 @@ class StaticReminder:
).fetchone()
reminder = dict(reminder)
reminder['notification_services'] = list(map(lambda r: r[0], get_db().execute("""
SELECT notification_service_id
FROM reminder_services
WHERE static_reminder_id = ?;
""", (self.id,))))
reminder['notification_services'] = self._get_notification_services()
return reminder
def update(
self,
title: str = None,
notification_services: List[int] = None,
text: str = None,
color: str = None
title: Union[str, None] = None,
notification_services: Union[List[int], None] = None,
text: Union[str, None] = None,
color: Union[str, None] = None
) -> dict:
"""Edit the static reminder
"""Edit the static reminder.
Args:
title (str, optional): The new title of the entry. Defaults to None.
notification_services (List[int], optional): The new id's of the notification services to use to send the reminder. Defaults to None.
text (str, optional): The new body of the reminder. Defaults to None.
color (str, optional): The new hex code of the color of the reminder, which is shown in the web-ui. Defaults to None.
title (Union[str, None], optional): The new title of the entry.
Defaults to None.
notification_services (Union[List[int], None], optional):
The new id's of the notification services to use to send the reminder.
Defaults to None.
text (Union[str, None], optional): The new body of the reminder.
Defaults to None.
color (Union[str, None], optional): The new hex code of the color
of the reminder, which is shown in the web-ui.
Defaults to None.
Raises:
NotificationServiceNotFound: One of the notification services was not found
@@ -109,16 +141,27 @@ class StaticReminder:
if notification_services:
cursor.connection.isolation_level = None
cursor.execute("BEGIN TRANSACTION;")
cursor.execute("DELETE FROM reminder_services WHERE static_reminder_id = ?", (self.id,))
cursor.execute(
"DELETE FROM reminder_services WHERE static_reminder_id = ?",
(self.id,)
)
try:
cursor.executemany(
"INSERT INTO reminder_services(static_reminder_id, notification_service_id) VALUES (?,?)",
cursor.executemany("""
INSERT INTO reminder_services(
static_reminder_id,
notification_service_id
)
VALUES (?,?);
""",
((self.id, s) for s in notification_services)
)
cursor.execute("COMMIT;")
except IntegrityError:
raise NotificationServiceNotFound
cursor.connection.isolation_level = ""
finally:
cursor.connection.isolation_level = ""
return self.get()
@@ -132,33 +175,32 @@ class StaticReminder:
class StaticReminders:
"""Represents the static reminder library of the user account
"""
sort_functions = {
'title': (lambda r: (r['title'], r['text'], r['color']), False),
'title_reversed': (lambda r: (r['title'], r['text'], r['color']), True),
'date_added': (lambda r: r['id'], False),
'date_added_reversed': (lambda r: r['id'], True)
}
def __init__(self, user_id: int) -> None:
"""Create an instance.
Args:
user_id (int): The ID of the user.
"""
self.user_id = user_id
return
def fetchall(self, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]:
def fetchall(
self,
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
) -> List[dict]:
"""Get all static reminders
Args:
sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title".
sort_by (TimelessSortingMethod, optional): How to sort the result.
Defaults to TimelessSortingMethod.TITLE.
Returns:
List[dict]: The id, title, text and color of each static reminder
List[dict]: The id, title, text and color of each static reminder.
"""
sort_function = self.sort_functions.get(
sort_by,
self.sort_functions['title']
)
reminders: list = list(map(
dict,
get_db(dict).execute("""
reminders = [
dict(r)
for r in get_db(dict).execute("""
SELECT
id,
title, text,
@@ -169,29 +211,36 @@ class StaticReminders:
""",
(self.user_id,)
)
))
]
# Sort result
reminders.sort(key=sort_function[0], reverse=sort_function[1])
reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1])
return reminders
def search(self, query: str, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]:
def search(
self,
query: str,
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
) -> List[dict]:
"""Search for static reminders
Args:
query (str): The term to search for
sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title".
query (str): The term to search for.
sort_by (TimelessSortingMethod, optional): The sorting method of
the resulting list.
Defaults to TimelessSortingMethod.TITLE.
Returns:
List[dict]: All static reminders that match. Similar output to self.fetchall
"""
query = query.lower()
reminders = list(filter(
lambda p: filter_function(query, p),
self.fetchall(sort_by)
))
return reminders
List[dict]: All static reminders that match.
Similar output to `self.fetchall`
"""
static_reminders = [
r for r in self.fetchall(sort_by)
if search_filter(query, r)
]
return static_reminders
def fetchone(self, id: int) -> StaticReminder:
"""Get one static reminder
@@ -214,22 +263,32 @@ class StaticReminders:
"""Add a static reminder
Args:
title (str): The title of the entry
notification_services (List[int]): The id's of the notification services to use to send the reminder.
text (str, optional): The body of the reminder. Defaults to ''.
color (str, optional): The hex code of the color of the reminder, which is shown in the web-ui. Defaults to None.
title (str): The title of the entry.
notification_services (List[int]): The id's of the
notification services to use to send the reminder.
text (str, optional): The body of the reminder.
Defaults to ''.
color (str, optional): The hex code of the color of the template,
which is shown in the web-ui.
Defaults to None.
Raises:
NotificationServiceNotFound: One of the notification services was not found
Returns:
StaticReminder: A StaticReminder instance representing the newly created static reminder
StaticReminder: The info about the static reminder
"""
logging.info(
f'Adding static reminder with {title=}, {notification_services=}, {text=}, {color=}'
)
cursor = get_db()
cursor.connection.isolation_level = None
cursor.execute("BEGIN TRANSACTION;")
id = cursor.execute("""
INSERT INTO static_reminders(user_id, title, text, color)
VALUES (?,?,?,?);
@@ -238,13 +297,22 @@ class StaticReminders:
).lastrowid
try:
cursor.executemany(
"INSERT INTO reminder_services(static_reminder_id, notification_service_id) VALUES (?, ?);",
cursor.executemany("""
INSERT INTO reminder_services(
static_reminder_id,
notification_service_id
)
VALUES (?, ?);
""",
((id, service) for service in notification_services)
)
cursor.execute("COMMIT;")
except IntegrityError:
raise NotificationServiceNotFound
finally:
cursor.connection.isolation_level = ""
return self.fetchone(id)
def trigger_reminder(self, id: int) -> None:
@@ -265,7 +333,9 @@ class StaticReminders:
id = ?
AND user_id = ?
LIMIT 1;
""", (id, self.user_id)).fetchone()
""",
(id, self.user_id)
).fetchone()
if not reminder:
raise ReminderNotFound
reminder = dict(reminder)
@@ -277,7 +347,9 @@ class StaticReminders:
INNER JOIN notification_services ns
ON rs.notification_service_id = ns.id
WHERE rs.static_reminder_id = ?;
""", (id,))
""",
(id,)
)
for url in cursor:
a.add(url['url'])
a.notify(title=reminder['title'], body=reminder['text'])

View File

@@ -2,21 +2,28 @@
import logging
from sqlite3 import IntegrityError
from typing import List, Literal
from typing import List, Union
from backend.custom_exceptions import (NotificationServiceNotFound,
TemplateNotFound)
from backend.db import get_db
from backend.helpers import TimelessSortingMethod, search_filter
filter_function = lambda query, p: (
query in p["title"].lower()
or query in p["text"].lower()
)
class Template:
"""Represents a template
"""
def __init__(self, user_id: int, template_id: int):
def __init__(self, user_id: int, template_id: int) -> None:
"""Create instance of class.
Args:
user_id (int): The ID of the user.
template_id (int): The ID of the template.
Raises:
TemplateNotFound: Template with given ID does not exist or is not
owned by user.
"""
self.id = template_id
exists = get_db().execute(
@@ -25,7 +32,26 @@ class Template:
).fetchone()
if not exists:
raise TemplateNotFound
return
def _get_notification_services(self) -> List[int]:
"""Get ID's of notification services linked to the template.
Returns:
List[int]: The list with ID's.
"""
result = [
r[0]
for r in get_db().execute("""
SELECT notification_service_id
FROM reminder_services
WHERE template_id = ?;
""",
(self.id,)
)
]
return result
def get(self) -> dict:
"""Get info about the template
@@ -45,27 +71,35 @@ class Template:
).fetchone()
template = dict(template)
template['notification_services'] = list(map(lambda r: r[0], get_db().execute("""
SELECT notification_service_id
FROM reminder_services
WHERE template_id = ?;
""", (self.id,))))
template['notification_services'] = self._get_notification_services()
return template
def update(self,
title: str = None,
notification_services: List[int] = None,
text: str = None,
color: str = None
title: Union[str, None] = None,
notification_services: Union[List[int], None] = None,
text: Union[str, None] = None,
color: Union[str, None] = None
) -> dict:
"""Edit the template
Args:
title (str): The new title of the entry. Defaults to None.
notification_services (List[int]): The new id's of the notification services to use to send the reminder. Defaults to None.
text (str, optional): The new body of the template. Defaults to None.
color (str, optional): The new hex code of the color of the template, which is shown in the web-ui. Defaults to None.
title (Union[str, None]): The new title of the entry.
Defaults to None.
notification_services (Union[List[int], None]): The new id's of the
notification services to use to send the reminder.
Defaults to None.
text (Union[str, None], optional): The new body of the template.
Defaults to None.
color (Union[str, None], optional): The new hex code of the color of the template,
which is shown in the web-ui.
Defaults to None.
Raises:
NotificationServiceNotFound: One of the notification services was not found
Returns:
dict: The new template info
@@ -101,16 +135,27 @@ class Template:
if notification_services:
cursor.connection.isolation_level = None
cursor.execute("BEGIN TRANSACTION;")
cursor.execute("DELETE FROM reminder_services WHERE template_id = ?", (self.id,))
cursor.execute(
"DELETE FROM reminder_services WHERE template_id = ?",
(self.id,)
)
try:
cursor.executemany(
"INSERT INTO reminder_services(template_id, notification_service_id) VALUES (?,?)",
cursor.executemany("""
INSERT INTO reminder_services(
template_id,
notification_service_id
)
VALUES (?,?);
""",
((self.id, s) for s in notification_services)
)
cursor.execute("COMMIT;")
except IntegrityError:
raise NotificationServiceNotFound
cursor.connection.isolation_level = ""
finally:
cursor.connection.isolation_level = ""
return self.get()
@@ -124,63 +169,72 @@ class Template:
class Templates:
"""Represents the template library of the user account
"""
sort_functions = {
'title': (lambda r: (r['title'], r['text'], r['color']), False),
'title_reversed': (lambda r: (r['title'], r['text'], r['color']), True),
'date_added': (lambda r: r['id'], False),
'date_added_reversed': (lambda r: r['id'], True)
}
def __init__(self, user_id: int):
self.user_id = user_id
def fetchall(self, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]:
"""Get all templates
def __init__(self, user_id: int) -> None:
"""Create an instance.
Args:
sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title".
user_id (int): The ID of the user.
"""
self.user_id = user_id
return
def fetchall(
self,
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
) -> List[dict]:
"""Get all templates of the user.
Args:
sort_by (TimelessSortingMethod, optional): The sorting method of
the resulting list.
Defaults to TimelessSortingMethod.TITLE.
Returns:
List[dict]: The id, title, text and color
List[dict]: The id, title, text and color of each template.
"""
sort_function = self.sort_functions.get(
sort_by,
self.sort_functions['title']
)
templates: list = list(map(dict, get_db(dict).execute("""
SELECT
id,
title, text,
color
FROM templates
WHERE user_id = ?
ORDER BY title, id;
""",
(self.user_id,)
)))
templates = [
dict(r)
for r in get_db(dict).execute("""
SELECT
id,
title, text,
color
FROM templates
WHERE user_id = ?
ORDER BY title, id;
""",
(self.user_id,)
)
]
# Sort result
templates.sort(key=sort_function[0], reverse=sort_function[1])
templates.sort(key=sort_by.value[0], reverse=sort_by.value[1])
return templates
def search(self, query: str, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]:
def search(
self,
query: str,
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
) -> List[dict]:
"""Search for templates
Args:
query (str): The term to search for
sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title".
query (str): The term to search for.
sort_by (TimelessSortingMethod, optional): The sorting method of
the resulting list.
Defaults to TimelessSortingMethod.TITLE.
Returns:
List[dict]: All templates that match. Similar output to self.fetchall
List[dict]: All templates that match. Similar output to `self.fetchall`
"""
query = query.lower()
reminders = list(filter(
lambda p: filter_function(query, p),
self.fetchall(sort_by)
))
return reminders
templates = [
r for r in self.fetchall(sort_by)
if search_filter(query, r)
]
return templates
def fetchone(self, id: int) -> Template:
"""Get one template
@@ -203,10 +257,20 @@ class Templates:
"""Add a template
Args:
title (str): The title of the entry
notification_services (List[int]): The id's of the notification services to use to send the reminder.
text (str, optional): The body of the reminder. Defaults to ''.
color (str, optional): The hex code of the color of the template, which is shown in the web-ui. Defaults to None.
title (str): The title of the entry.
notification_services (List[int]): The id's of the
notification services to use to send the reminder.
text (str, optional): The body of the reminder.
Defaults to ''.
color (str, optional): The hex code of the color of the template,
which is shown in the web-ui.
Defaults to None.
Raises:
NotificationServiceNotFound: One of the notification services was not found
Returns:
Template: The info about the template
@@ -216,19 +280,32 @@ class Templates:
)
cursor = get_db()
cursor.connection.isolation_level = None
cursor.execute("BEGIN TRANSACTION;")
id = cursor.execute("""
INSERT INTO templates(user_id, title, text, color)
VALUES (?,?,?,?);
""",
(self.user_id, title, text, color)
).lastrowid
try:
cursor.executemany(
"INSERT INTO reminder_services(template_id, notification_service_id) VALUES (?, ?);",
cursor.executemany("""
INSERT INTO reminder_services(
template_id,
notification_service_id
)
VALUES (?, ?);
""",
((id, service) for service in notification_services)
)
cursor.execute("COMMIT;")
except IntegrityError:
raise NotificationServiceNotFound
finally:
cursor.connection.isolation_level = ""
return self.fetchone(id)

View File

@@ -1,7 +1,7 @@
#-*- coding: utf-8 -*-
import logging
from typing import List, Union
from typing import List
from backend.custom_exceptions import (AccessUnauthorized,
NewAccountsNotAllowed, UsernameInvalid,
@@ -19,32 +19,28 @@ ONEPASS_INVALID_USERNAMES = ['reminders', 'api']
class User:
"""Represents an user account
"""
def __init__(self, username: str, password: Union[str, None]=None):
# Fetch data of user to check if user exists and to check if password is correct
"""
def __init__(self, id: int) -> None:
result = get_db(dict).execute(
"SELECT id, salt, hash, admin FROM users WHERE username = ? LIMIT 1;",
(username,)
"SELECT username, admin FROM users WHERE id = ? LIMIT 1;",
(id,)
).fetchone()
if not result:
raise UserNotFound
self.username = username
self.salt = result['salt']
self.user_id = result['id']
self.admin = result['admin'] == 1
# Check password
if password is not None:
hash_password = get_hash(result['salt'], password)
if not hash_password == result['hash']:
raise AccessUnauthorized
self.username: str = result['username']
self.user_id = id
self.admin: bool = result['admin'] == 1
return
@property
def reminders(self) -> Reminders:
"""Get access to the reminders of the user account
Returns:
Reminders: Reminders instance that can be used to access the reminders of the user account
Reminders: Reminders instance that can be used to access the
reminders of the user account
"""
if not hasattr(self, 'reminders_instance'):
self.reminders_instance = Reminders(self.user_id)
@@ -55,7 +51,8 @@ class User:
"""Get access to the notification services of the user account
Returns:
NotificationServices: NotificationServices instance that can be used to access the notification services of the user account
NotificationServices: NotificationServices instance that can be used
to access the notification services of the user account
"""
if not hasattr(self, 'notification_services_instance'):
self.notification_services_instance = NotificationServices(self.user_id)
@@ -66,7 +63,8 @@ class User:
"""Get access to the templates of the user account
Returns:
Templates: Templates instance that can be used to access the templates of the user account
Templates: Templates instance that can be used to access the
templates of the user account
"""
if not hasattr(self, 'templates_instance'):
self.templates_instance = Templates(self.user_id)
@@ -77,7 +75,8 @@ class User:
"""Get access to the static reminders of the user account
Returns:
StaticReminders: StaticReminders instance that can be used to access the static reminders of the user account
StaticReminders: StaticReminders instance that can be used to
access the static reminders of the user account
"""
if not hasattr(self, 'static_reminders_instance'):
self.static_reminders_instance = StaticReminders(self.user_id)
@@ -103,121 +102,133 @@ class User:
def delete(self) -> None:
"""Delete the user account
"""
if self.username == 'admin':
raise UserNotFound
logging.info(f'Deleting the user {self.username} ({self.user_id})')
cursor = get_db()
cursor.execute("DELETE FROM reminders WHERE user_id = ?", (self.user_id,))
cursor.execute("DELETE FROM templates WHERE user_id = ?", (self.user_id,))
cursor.execute("DELETE FROM static_reminders WHERE user_id = ?", (self.user_id,))
cursor.execute("DELETE FROM notification_services WHERE user_id = ?", (self.user_id,))
cursor.execute("DELETE FROM users WHERE id = ?", (self.user_id,))
cursor.execute(
"DELETE FROM reminders WHERE user_id = ?",
(self.user_id,)
)
cursor.execute(
"DELETE FROM templates WHERE user_id = ?",
(self.user_id,)
)
cursor.execute(
"DELETE FROM static_reminders WHERE user_id = ?",
(self.user_id,)
)
cursor.execute(
"DELETE FROM notification_services WHERE user_id = ?",
(self.user_id,)
)
cursor.execute(
"DELETE FROM users WHERE id = ?",
(self.user_id,)
)
return
def _check_username(username: str) -> None:
"""Check if username is valid
class Users:
def _check_username(self, username: str) -> None:
"""Check if username is valid
Args:
username (str): The username to check
Args:
username (str): The username to check
Raises:
UsernameInvalid: The username is not valid
"""
logging.debug(f'Checking the username {username}')
if username in ONEPASS_INVALID_USERNAMES or username.isdigit():
raise UsernameInvalid
if list(filter(lambda c: not c in ONEPASS_USERNAME_CHARACTERS, username)):
raise UsernameInvalid
return
def register_user(username: str, password: str, from_admin: bool=False) -> int:
"""Add a user
Args:
username (str): The username of the new user
password (str): The password of the new user
from_admin (bool, optional): Skip check if new accounts are allowed.
Defaults to False.
Raises:
UsernameInvalid: Username not allowed or contains invalid characters
UsernameTaken: Username is already taken; usernames must be unique
NewAccountsNotAllowed: In the admin panel, new accounts are set to be
not allowed.
Returns:
user_id (int): The id of the new user. User registered successful
"""
logging.info(f'Registering user with username {username}')
if not from_admin and not get_setting('allow_new_accounts'):
raise NewAccountsNotAllowed
# Check if username is valid
_check_username(username)
cursor = get_db()
# Check if username isn't already taken
if cursor.execute(
"SELECT 1 FROM users WHERE username = ? LIMIT 1", (username,)
).fetchone():
raise UsernameTaken
# Generate salt and key exclusive for user
salt, hashed_password = generate_salt_hash(password)
del password
# Add user to userlist
user_id = cursor.execute(
Raises:
UsernameInvalid: The username is not valid
"""
INSERT INTO users(username, salt, hash)
VALUES (?,?,?);
""",
(username, salt, hashed_password)
).lastrowid
logging.debug(f'Checking the username {username}')
if username in ONEPASS_INVALID_USERNAMES or username.isdigit():
raise UsernameInvalid(username)
if list(filter(lambda c: not c in ONEPASS_USERNAME_CHARACTERS, username)):
raise UsernameInvalid(username)
return
logging.debug(f'Newly registered user has id {user_id}')
return user_id
def __contains__(self, username: str) -> bool:
result = get_db().execute(
"SELECT 1 FROM users WHERE username = ? LIMIT 1;",
(username,)
).fetchone()
return result is not None
def get_users() -> List[dict]:
"""Get all user info for the admin
def add(self, username: str, password: str, from_admin: bool=False) -> int:
"""Add a user
Returns:
List[dict]: The info about all users
"""
result = [
dict(u)
for u in get_db(dict).execute(
"SELECT id, username, admin FROM users ORDER BY username;"
)
]
return result
Args:
username (str): The username of the new user
password (str): The password of the new user
from_admin (bool, optional): Skip check if new accounts are allowed.
Defaults to False.
def edit_user_password(id: int, new_password: str) -> None:
"""Change the password of a user for the admin
Raises:
UsernameInvalid: Username not allowed or contains invalid characters
UsernameTaken: Username is already taken; usernames must be unique
NewAccountsNotAllowed: In the admin panel, new accounts are set to be
not allowed.
Args:
id (int): The ID of the user to change the password of
new_password (str): The new password to set for the user
"""
username = (get_db().execute(
"SELECT username FROM users WHERE id = ? LIMIT 1;",
(id,)
).fetchone() or [''])[0]
User(username).edit_password(new_password)
return
Returns:
int: The id of the new user. User registered successful
"""
logging.info(f'Registering user with username {username}')
if not from_admin and not get_setting('allow_new_accounts'):
raise NewAccountsNotAllowed
# Check if username is valid
self._check_username(username)
def delete_user(id: int) -> None:
"""Delete a user for the admin
cursor = get_db()
Args:
id (int): The ID of the user to delete
"""
username = (get_db().execute(
"SELECT username FROM users WHERE id = ? LIMIT 1;",
(id,)
).fetchone() or [''])[0]
if username == 'admin':
raise UserNotFound
User(username).delete()
return
# Check if username isn't already taken
if username in self:
raise UsernameTaken
# Generate salt and key exclusive for user
salt, hashed_password = generate_salt_hash(password)
del password
# Add user to userlist
user_id = cursor.execute(
"""
INSERT INTO users(username, salt, hash)
VALUES (?,?,?);
""",
(username, salt, hashed_password)
).lastrowid
logging.debug(f'Newly registered user has id {user_id}')
return user_id
def get_all(self) -> List[dict]:
"""Get all user info for the admin
Returns:
List[dict]: The info about all users
"""
result = [
dict(u)
for u in get_db(dict).execute(
"SELECT id, username, admin FROM users ORDER BY username;"
)
]
return result
def login(self, username: str, password: str) -> User:
result = get_db(dict).execute(
"SELECT id, salt, hash FROM users WHERE username = ? LIMIT 1;",
(username,)
).fetchone()
if not result:
raise UserNotFound
hash_password = get_hash(result['salt'], password)
if not hash_password == result['hash']:
raise AccessUnauthorized
return User(result['id'])
def get_one(self, id: int) -> User:
return User(id)

View File

@@ -1,16 +1,13 @@
#-*- coding: utf-8 -*-
from io import BytesIO
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from io import BytesIO
from os import urandom
from re import compile
from time import time as epoch_time
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, Tuple
from apprise import Apprise
from flask import Blueprint, g, request, send_file
from flask.sansio.scaffold import T_route
from flask import g, request, send_file
from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired,
APIKeyInvalid, InvalidKeyValue,
@@ -22,378 +19,39 @@ from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired,
UsernameInvalid, UsernameTaken,
UserNotFound)
from backend.db import DBConnection
from backend.notification_service import (NotificationService,
NotificationServices,
get_apprise_services)
from backend.reminders import Reminders, reminder_handler
from backend.settings import (_format_setting, get_admin_settings, get_setting,
set_setting)
from backend.static_reminders import StaticReminders
from backend.templates import Template, Templates
from backend.users import (User, delete_user, edit_user_password, get_users,
register_user)
#===================
# Input validation
#===================
color_regex = compile(r'#[0-9a-f]{6}')
class DataSource:
DATA = 1
VALUES = 2
class InputVariable(ABC):
@abstractmethod
def __init__(self, value: Any) -> None:
pass
@property
@abstractmethod
def name() -> str:
pass
@abstractmethod
def validate(self) -> bool:
pass
@property
@abstractmethod
def required() -> bool:
pass
@property
@abstractmethod
def default() -> Any:
pass
@property
@abstractmethod
def source() -> int:
pass
@property
@abstractmethod
def description() -> str:
pass
@property
@abstractmethod
def related_exceptions() -> List[Exception]:
pass
class DefaultInputVariable(InputVariable):
source = DataSource.DATA
required = True
default = None
related_exceptions = []
def __init__(self, value: Any) -> None:
self.value = value
def validate(self) -> bool:
return isinstance(self.value, str) and self.value
def __repr__(self) -> str:
return f'| {self.name} | {"Yes" if self.required else "No"} | {self.description} | N/A |'
class NonRequiredVersion(InputVariable):
required = False
def validate(self) -> bool:
return self.value is None or super().validate()
class UsernameVariable(DefaultInputVariable):
name = 'username'
description = 'The username of the user account'
related_exceptions = [KeyNotFound, UserNotFound]
class PasswordVariable(DefaultInputVariable):
name = 'password'
description = 'The password of the user account'
related_exceptions = [KeyNotFound, AccessUnauthorized]
class NewPasswordVariable(PasswordVariable):
name = 'new_password'
description = 'The new password of the user account'
related_exceptions = [KeyNotFound]
class UsernameCreateVariable(UsernameVariable):
related_exceptions = [
KeyNotFound,
UsernameInvalid, UsernameTaken,
NewAccountsNotAllowed
]
class PasswordCreateVariable(PasswordVariable):
related_exceptions = [KeyNotFound]
class TitleVariable(DefaultInputVariable):
name = 'title'
description = 'The title of the entry'
related_exceptions = [KeyNotFound]
class URLVariable(DefaultInputVariable):
name = 'url'
description = 'The Apprise URL of the notification service'
related_exceptions = [KeyNotFound, InvalidKeyValue]
def validate(self) -> bool:
return Apprise().add(self.value)
class EditTitleVariable(NonRequiredVersion, TitleVariable):
related_exceptions = []
class EditURLVariable(NonRequiredVersion, URLVariable):
related_exceptions = [InvalidKeyValue]
class SortByVariable(DefaultInputVariable):
name = 'sort_by'
description = 'How to sort the result'
required = False
source = DataSource.VALUES
_options = Reminders.sort_functions
default = next(iter(Reminders.sort_functions))
related_exceptions = [InvalidKeyValue]
def __init__(self, value: str) -> None:
self.value = value
def validate(self) -> bool:
return self.value in self._options
def __repr__(self) -> str:
return '| {n} | {r} | {d} | {v} |'.format(
n=self.name,
r="Yes" if self.required else "No",
d=self.description,
v=", ".join(f'`{o}`' for o in self._options)
)
class TemplateSortByVariable(SortByVariable):
_options = Templates.sort_functions
default = next(iter(Templates.sort_functions))
class StaticReminderSortByVariable(TemplateSortByVariable):
_options = StaticReminders.sort_functions
default = next(iter(StaticReminders.sort_functions))
class TimeVariable(DefaultInputVariable):
name = 'time'
description = 'The UTC epoch timestamp that the reminder should be sent at'
related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime]
def validate(self) -> bool:
return isinstance(self.value, (float, int))
class EditTimeVariable(NonRequiredVersion, TimeVariable):
related_exceptions = [InvalidKeyValue, InvalidTime]
class NotificationServicesVariable(DefaultInputVariable):
name = 'notification_services'
description = "Array of the id's of the notification services to use to send the notification"
related_exceptions = [KeyNotFound, InvalidKeyValue, NotificationServiceNotFound]
def validate(self) -> bool:
if not isinstance(self.value, list):
return False
if not self.value:
return False
for v in self.value:
if not isinstance(v, int):
return False
return True
class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable):
related_exceptions = [InvalidKeyValue, NotificationServiceNotFound]
class TextVariable(NonRequiredVersion, DefaultInputVariable):
name = 'text'
description = 'The body of the entry'
default = ''
def validate(self) -> bool:
return isinstance(self.value, str)
class RepeatQuantityVariable(DefaultInputVariable):
name = 'repeat_quantity'
description = 'The quantity of the repeat_interval'
required = False
_options = ("years", "months", "weeks", "days", "hours", "minutes")
default = None
related_exceptions = [InvalidKeyValue]
def validate(self) -> bool:
return self.value is None or self.value in self._options
def __repr__(self) -> str:
return '| {n} | {r} | {d} | {v} |'.format(
n=self.name,
r="Yes" if self.required else "No",
d=self.description,
v=", ".join(f'`{o}`' for o in self._options)
)
class RepeatIntervalVariable(DefaultInputVariable):
name = 'repeat_interval'
description = 'The number of the interval'
required = False
default = None
related_exceptions = [InvalidKeyValue]
def validate(self) -> bool:
return self.value is None or (isinstance(self.value, int) and self.value > 0)
class WeekDaysVariable(DefaultInputVariable):
name = 'weekdays'
description = 'On which days of the weeks to run the reminder'
required = False
default = None
related_exceptions = [InvalidKeyValue]
_options = {0, 1, 2, 3, 4, 5, 6}
def validate(self) -> bool:
return self.value is None or (
isinstance(self.value, list)
and len(self.value) > 0
and all(v in self._options for v in self.value)
)
class ColorVariable(DefaultInputVariable):
name = 'color'
description = 'The hex code of the color of the entry, which is shown in the web-ui'
required = False
default = None
related_exceptions = [InvalidKeyValue]
def validate(self) -> bool:
return self.value is None or color_regex.search(self.value)
class QueryVariable(DefaultInputVariable):
name = 'query'
description = 'The search term'
source = DataSource.VALUES
class AdminSettingsVariable(DefaultInputVariable):
related_exceptions = [KeyNotFound, InvalidKeyValue]
def validate(self) -> bool:
try:
_format_setting(self.name, self.value)
except InvalidKeyValue:
return False
return True
class AllowNewAccountsVariable(AdminSettingsVariable):
name = 'allow_new_accounts'
description = ('Whether or not to allow users to register a new account. '
+ 'The admin can always add a new account.')
class LoginTimeVariable(AdminSettingsVariable):
name = 'login_time'
description = ('How long a user stays logged in, in seconds. '
+ 'Between 1 min and 1 month (60 <= sec <= 2592000)')
class LoginTimeResetVariable(AdminSettingsVariable):
name = 'login_time_reset'
description = 'If the Login Time timer should reset with each API request.'
def input_validation() -> Union[None, Dict[str, Any]]:
"""Checks, extracts and transforms inputs
Raises:
KeyNotFound: A required key was not supplied
InvalidKeyValue: The value of a key is not valid
Returns:
Union[None, Dict[str, Any]]: `None` if the endpoint + method doesn't require input variables.
Otherwise `Dict[str, Any]` with the input variables, checked and formatted.
"""
inputs = {}
input_variables: Dict[str, List[Union[List[InputVariable], str]]]
if request.path.startswith(admin_api_prefix):
input_variables = api_docs[
_admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
]['input_variables']
else:
input_variables = api_docs[
request.url_rule.rule.split(api_prefix)[1]
]['input_variables']
if not input_variables:
return
if input_variables.get(request.method) is None:
return inputs
given_variables = {}
given_variables[DataSource.DATA] = request.get_json() if request.data else {}
given_variables[DataSource.VALUES] = request.values
for input_variable in input_variables[request.method]:
if (
input_variable.required and
not input_variable.name in given_variables[input_variable.source]
):
raise KeyNotFound(input_variable.name)
input_value = given_variables[input_variable.source].get(
input_variable.name,
input_variable.default
)
if not input_variable(input_value).validate():
raise InvalidKeyValue(input_variable.name, input_value)
inputs[input_variable.name] = input_value
return inputs
from backend.notification_service import get_apprise_services
from backend.settings import get_admin_settings, get_setting, set_setting
from backend.users import User, Users
from frontend.input_validation import (AllowNewAccountsVariable, ColorVariable,
EditNotificationServicesVariable,
EditTimeVariable, EditTitleVariable,
EditURLVariable, LoginTimeResetVariable,
LoginTimeVariable, NewPasswordVariable,
NotificationServicesVariable,
PasswordCreateVariable,
PasswordVariable, QueryVariable,
RepeatIntervalVariable,
RepeatQuantityVariable, SortByVariable,
StaticReminderSortByVariable,
TemplateSortByVariable, TextVariable,
TimeVariable, TitleVariable,
URLVariable, UsernameCreateVariable,
UsernameVariable, WeekDaysVariable,
_admin_api_prefix, admin_api,
admin_api_prefix, api, api_docs,
api_prefix, input_validation)
#===================
# General variables and functions
#===================
api_docs: Dict[str, Dict[str, Any]] = {}
class APIBlueprint(Blueprint):
def route(
self,
rule: str,
description: str = '',
input_variables: Dict[str, List[Union[List[InputVariable], str]]] = {},
requires_auth: bool = True,
**options: Any
) -> Callable[[T_route], T_route]:
@dataclass
class ApiKeyEntry:
exp: int
user_data: User
if self == api:
processed_rule = rule
elif self == admin_api:
processed_rule = _admin_api_prefix + rule
else:
raise NotImplementedError
api_docs[processed_rule] = {
'endpoint': processed_rule,
'description': description,
'requires_auth': requires_auth,
'methods': options['methods'],
'input_variables': {
k: v[0]
for k, v in input_variables.items()
if v and v[0]
},
'method_descriptions': {
k: v[1]
for k, v in input_variables.items()
if v and len(v) == 2 and v[1]
}
}
return super().route(rule, **options)
api_prefix = "/api"
_admin_api_prefix = '/admin'
admin_api_prefix = api_prefix + _admin_api_prefix
api = APIBlueprint('api', __name__)
admin_api = APIBlueprint('admin_api', __name__)
api_key_map = {}
users = Users()
api_key_map: Dict[int, ApiKeyEntry] = {}
def return_api(result: Any, error: str=None, code: int=200) -> Tuple[dict, int]:
return {'error': error, 'result': result}, code
@@ -409,33 +67,37 @@ def auth() -> None:
if not hashed_api_key in api_key_map:
raise APIKeyInvalid
if not (
(
api_key_map[hashed_api_key]['user_data'].admin
and request.path.startswith((admin_api_prefix, api_prefix + '/auth'))
)
or
(
not api_key_map[hashed_api_key]['user_data'].admin
and not request.path.startswith(admin_api_prefix)
)
map_entry = api_key_map[hashed_api_key]
if (
map_entry.user_data.admin
and
not request.path.startswith((admin_api_prefix, api_prefix + '/auth'))
):
raise APIKeyInvalid
if (
not map_entry.user_data.admin
and
request.path.startswith(admin_api_prefix)
):
raise APIKeyInvalid
exp = api_key_map[hashed_api_key]['exp']
if exp <= epoch_time():
if map_entry.exp <= epoch_time():
raise APIKeyExpired
# Api key valid
if get_setting('login_time_reset'):
api_key_map[hashed_api_key]['exp'] = exp = (
g.exp = map_entry.exp = (
epoch_time() + get_setting('login_time')
)
else:
g.exp = map_entry.exp
g.hashed_api_key = hashed_api_key
g.exp = exp
g.user_data = api_key_map[hashed_api_key]['user_data']
g.user_data = map_entry.user_data
return
def endpoint_wrapper(method: Callable) -> Callable:
@@ -458,14 +120,15 @@ def endpoint_wrapper(method: Callable) -> Callable:
return method(*args, **kwargs)
return method(inputs, *args, **kwargs)
except (UsernameTaken, UsernameInvalid, UserNotFound,
AccessUnauthorized,
ReminderNotFound, NotificationServiceNotFound,
NotificationServiceInUse, InvalidTime,
KeyNotFound, InvalidKeyValue,
APIKeyInvalid, APIKeyExpired,
TemplateNotFound,
NewAccountsNotAllowed) as e:
except (AccessUnauthorized, APIKeyExpired,
APIKeyInvalid, InvalidKeyValue,
InvalidTime, KeyNotFound,
NewAccountsNotAllowed,
NotificationServiceInUse,
NotificationServiceNotFound,
ReminderNotFound, TemplateNotFound,
UsernameInvalid, UsernameTaken,
UserNotFound) as e:
return return_api(**e.api_response)
wrapper.__name__ = method.__name__
@@ -484,7 +147,7 @@ def endpoint_wrapper(method: Callable) -> Callable:
)
@endpoint_wrapper
def api_login(inputs: Dict[str, str]):
user = User(inputs['username'], inputs['password'])
user = users.login(inputs['username'], inputs['password'])
# Generate an API key until one
# is generated that isn't used already
@@ -496,12 +159,7 @@ def api_login(inputs: Dict[str, str]):
login_time = get_setting('login_time')
exp = epoch_time() + login_time
api_key_map.update({
hashed_api_key: {
'exp': exp,
'user_data': user
}
})
api_key_map[hashed_api_key] = ApiKeyEntry(exp, user)
result = {'api_key': api_key, 'expires': exp, 'admin': user.admin}
return return_api(result, code=201)
@@ -523,17 +181,17 @@ def api_logout():
)
@endpoint_wrapper
def api_status():
map_entry = api_key_map[g.hashed_api_key]
result = {
'expires': api_key_map[g.hashed_api_key]['exp'],
'username': api_key_map[g.hashed_api_key]['user_data'].username,
'admin': api_key_map[g.hashed_api_key]['user_data'].admin
'expires': map_entry.exp,
'username': map_entry.user_data.username,
'admin': map_entry.user_data.admin
}
return return_api(result)
#===================
# User endpoints
#===================
@api.route(
'/user/add',
'Create a new user account',
@@ -543,7 +201,7 @@ def api_status():
)
@endpoint_wrapper
def api_add_user(inputs: Dict[str, str]):
register_user(inputs['username'], inputs['password'])
users.add(inputs['username'], inputs['password'])
return return_api({}, code=201)
@api.route(
@@ -557,12 +215,13 @@ def api_add_user(inputs: Dict[str, str]):
)
@endpoint_wrapper
def api_manage_user(inputs: Dict[str, str]):
user = api_key_map[g.hashed_api_key].user_data
if request.method == 'PUT':
g.user_data.edit_password(inputs['new_password'])
user.edit_password(inputs['new_password'])
return return_api({})
elif request.method == 'DELETE':
g.user_data.delete()
user.delete()
api_key_map.pop(g.hashed_api_key)
return return_api({})
@@ -581,7 +240,7 @@ def api_manage_user(inputs: Dict[str, str]):
)
@endpoint_wrapper
def api_notification_services_list(inputs: Dict[str, str]):
services: NotificationServices = g.user_data.notification_services
services = api_key_map[g.hashed_api_key].user_data.notification_services
if request.method == 'GET':
result = services.fetchall()
@@ -610,7 +269,10 @@ def api_notification_service_available():
)
@endpoint_wrapper
def api_test_service(inputs: Dict[str, Any]):
g.user_data.notification_services.test_service(inputs['url'])
(api_key_map[g.hashed_api_key]
.user_data
.notification_services
.test_service(inputs['url']))
return return_api({}, code=201)
@api.route(
@@ -624,17 +286,20 @@ def api_test_service(inputs: Dict[str, Any]):
)
@endpoint_wrapper
def api_notification_service(inputs: Dict[str, str], n_id: int):
service: NotificationService = g.user_data.notification_services.fetchone(n_id)
service = (api_key_map[g.hashed_api_key]
.user_data
.notification_services
.fetchone(n_id))
if request.method == 'GET':
result = service.get()
return return_api(result)
elif request.method == 'PUT':
result = service.update(title=inputs['title'],
url=inputs['url'])
return return_api(result)
elif request.method == 'DELETE':
service.delete()
return return_api({})
@@ -659,7 +324,7 @@ def api_notification_service(inputs: Dict[str, str], n_id: int):
)
@endpoint_wrapper
def api_reminders_list(inputs: Dict[str, Any]):
reminders: Reminders = g.user_data.reminders
reminders = api_key_map[g.hashed_api_key].user_data.reminders
if request.method == 'GET':
result = reminders.fetchall(inputs['sort_by'])
@@ -684,7 +349,10 @@ def api_reminders_list(inputs: Dict[str, Any]):
)
@endpoint_wrapper
def api_reminders_query(inputs: Dict[str, str]):
result = g.user_data.reminders.search(inputs['query'], inputs['sort_by'])
result = (api_key_map[g.hashed_api_key]
.user_data
.reminders
.search(inputs['query'], inputs['sort_by']))
return return_api(result)
@api.route(
@@ -696,7 +364,11 @@ def api_reminders_query(inputs: Dict[str, str]):
)
@endpoint_wrapper
def api_test_reminder(inputs: Dict[str, Any]):
g.user_data.reminders.test_reminder(inputs['title'], inputs['notification_services'], inputs['text'])
api_key_map[g.hashed_api_key].user_data.reminders.test_reminder(
inputs['title'],
inputs['notification_services'],
inputs['text']
)
return return_api({}, code=201)
@api.route(
@@ -714,7 +386,8 @@ def api_test_reminder(inputs: Dict[str, Any]):
)
@endpoint_wrapper
def api_get_reminder(inputs: Dict[str, Any], r_id: int):
reminders: Reminders = g.user_data.reminders
reminders = api_key_map[g.hashed_api_key].user_data.reminders
if request.method == 'GET':
result = reminders.fetchone(r_id).get()
return return_api(result)
@@ -750,7 +423,7 @@ def api_get_reminder(inputs: Dict[str, Any], r_id: int):
)
@endpoint_wrapper
def api_get_templates(inputs: Dict[str, Any]):
templates: Templates = g.user_data.templates
templates = api_key_map[g.hashed_api_key].user_data.templates
if request.method == 'GET':
result = templates.fetchall(inputs['sort_by'])
@@ -771,7 +444,10 @@ def api_get_templates(inputs: Dict[str, Any]):
)
@endpoint_wrapper
def api_templates_query(inputs: Dict[str, str]):
result = g.user_data.templates.search(inputs['query'], inputs['sort_by'])
result = (api_key_map[g.hashed_api_key]
.user_data
.templates
.search(inputs['query'], inputs['sort_by']))
return return_api(result)
@api.route(
@@ -786,7 +462,10 @@ def api_templates_query(inputs: Dict[str, str]):
)
@endpoint_wrapper
def api_get_template(inputs: Dict[str, Any], t_id: int):
template: Template = g.user_data.templates.fetchone(t_id)
template = (api_key_map[g.hashed_api_key]
.user_data
.templates
.fetchone(t_id))
if request.method == 'GET':
result = template.get()
@@ -819,7 +498,7 @@ def api_get_template(inputs: Dict[str, Any], t_id: int):
)
@endpoint_wrapper
def api_static_reminders_list(inputs: Dict[str, Any]):
reminders: StaticReminders = g.user_data.static_reminders
reminders = api_key_map[g.hashed_api_key].user_data.static_reminders
if request.method == 'GET':
result = reminders.fetchall(inputs['sort_by'])
@@ -840,7 +519,10 @@ def api_static_reminders_list(inputs: Dict[str, Any]):
)
@endpoint_wrapper
def api_static_reminders_query(inputs: Dict[str, str]):
result = g.user_data.static_reminders.search(inputs['query'], inputs['sort_by'])
result = (api_key_map[g.hashed_api_key]
.user_data
.static_reminders
.search(inputs['query'], inputs['sort_by']))
return return_api(result)
@api.route(
@@ -857,7 +539,8 @@ def api_static_reminders_query(inputs: Dict[str, str]):
)
@endpoint_wrapper
def api_get_static_reminder(inputs: Dict[str, Any], s_id: int):
reminders: StaticReminders = g.user_data.static_reminders
reminders = api_key_map[g.hashed_api_key].user_data.static_reminders
if request.method == 'GET':
result = reminders.fetchone(s_id).get()
return return_api(result)
@@ -929,11 +612,11 @@ def api_admin_settings(inputs: Dict[str, Any]):
@endpoint_wrapper
def api_admin_users(inputs: Dict[str, Any]):
if request.method == 'GET':
result = get_users()
result = users.get_all()
return return_api(result)
elif request.method == 'POST':
register_user(inputs['username'], inputs['password'], True)
users.add(inputs['username'], inputs['password'], True)
return return_api({}, code=201)
@admin_api.route(
@@ -947,14 +630,15 @@ def api_admin_users(inputs: Dict[str, Any]):
)
@endpoint_wrapper
def api_admin_user(inputs: Dict[str, Any], u_id: int):
user = users.get_one(u_id)
if request.method == 'PUT':
edit_user_password(u_id, inputs['new_password'])
user.edit_password(inputs['new_password'])
return return_api({})
elif request.method == 'DELETE':
delete_user(u_id)
user.delete()
for key, value in api_key_map.items():
if value['user_data'].user_id == u_id:
if value.user_data.user_id == u_id:
del api_key_map[key]
break
return return_api({})

View File

@@ -0,0 +1,434 @@
#-*- coding: utf-8 -*-
"""
Input validation for the API
"""
from abc import ABC, abstractmethod
from re import compile
from typing import Any, Callable, Dict, List, Union
from apprise import Apprise
from flask import Blueprint, request
from flask.sansio.scaffold import T_route
from backend.custom_exceptions import (AccessUnauthorized, InvalidKeyValue,
InvalidTime, KeyNotFound,
NewAccountsNotAllowed,
NotificationServiceNotFound,
UsernameInvalid, UsernameTaken,
UserNotFound)
from backend.helpers import (RepeatQuantity, SortingMethod,
TimelessSortingMethod)
from backend.settings import _format_setting
api_prefix = "/api"
_admin_api_prefix = '/admin'
admin_api_prefix = api_prefix + _admin_api_prefix
color_regex = compile(r'#[0-9a-f]{6}')
class DataSource:
DATA = 1
VALUES = 2
class InputVariable(ABC):
value: Any
@abstractmethod
def __init__(self, value: Any) -> None:
pass
@property
@abstractmethod
def name() -> str:
pass
@abstractmethod
def validate(self) -> bool:
pass
@property
@abstractmethod
def required() -> bool:
pass
@property
@abstractmethod
def default() -> Any:
pass
@property
@abstractmethod
def source() -> int:
pass
@property
@abstractmethod
def description() -> str:
pass
@property
@abstractmethod
def related_exceptions() -> List[Exception]:
pass
class DefaultInputVariable(InputVariable):
source = DataSource.DATA
required = True
default = None
related_exceptions = []
def __init__(self, value: Any) -> None:
self.value = value
def validate(self) -> bool:
return isinstance(self.value, str) and self.value
def __repr__(self) -> str:
return f'| {self.name} | {"Yes" if self.required else "No"} | {self.description} | N/A |'
class NonRequiredVersion(InputVariable):
required = False
def validate(self) -> bool:
return self.value is None or super().validate()
class UsernameVariable(DefaultInputVariable):
name = 'username'
description = 'The username of the user account'
related_exceptions = [KeyNotFound, UserNotFound]
class PasswordVariable(DefaultInputVariable):
name = 'password'
description = 'The password of the user account'
related_exceptions = [KeyNotFound, AccessUnauthorized]
class NewPasswordVariable(PasswordVariable):
name = 'new_password'
description = 'The new password of the user account'
related_exceptions = [KeyNotFound]
class UsernameCreateVariable(UsernameVariable):
related_exceptions = [
KeyNotFound,
UsernameInvalid, UsernameTaken,
NewAccountsNotAllowed
]
class PasswordCreateVariable(PasswordVariable):
related_exceptions = [KeyNotFound]
class TitleVariable(DefaultInputVariable):
name = 'title'
description = 'The title of the entry'
related_exceptions = [KeyNotFound]
class URLVariable(DefaultInputVariable):
name = 'url'
description = 'The Apprise URL of the notification service'
related_exceptions = [KeyNotFound, InvalidKeyValue]
def validate(self) -> bool:
return Apprise().add(self.value)
class EditTitleVariable(NonRequiredVersion, TitleVariable):
related_exceptions = []
class EditURLVariable(NonRequiredVersion, URLVariable):
related_exceptions = [InvalidKeyValue]
class SortByVariable(DefaultInputVariable):
name = 'sort_by'
description = 'How to sort the result'
required = False
source = DataSource.VALUES
_options = [k.lower() for k in SortingMethod._member_names_]
default = SortingMethod._member_names_[0].lower()
related_exceptions = [InvalidKeyValue]
def __init__(self, value: str) -> None:
self.value = value
def validate(self) -> bool:
if not self.value in self._options:
return False
self.value = SortingMethod[self.value.upper()]
return True
def __repr__(self) -> str:
return '| {n} | {r} | {d} | {v} |'.format(
n=self.name,
r="Yes" if self.required else "No",
d=self.description,
v=", ".join(f'`{o}`' for o in self._options)
)
class TemplateSortByVariable(SortByVariable):
_options = [k.lower() for k in TimelessSortingMethod._member_names_]
default = TimelessSortingMethod._member_names_[0].lower()
def validate(self) -> bool:
if not self.value in self._options:
return False
self.value = TimelessSortingMethod[self.value.upper()]
return True
class StaticReminderSortByVariable(TemplateSortByVariable):
pass
class TimeVariable(DefaultInputVariable):
name = 'time'
description = 'The UTC epoch timestamp that the reminder should be sent at'
related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime]
def validate(self) -> bool:
return isinstance(self.value, (float, int))
class EditTimeVariable(NonRequiredVersion, TimeVariable):
related_exceptions = [InvalidKeyValue, InvalidTime]
class NotificationServicesVariable(DefaultInputVariable):
name = 'notification_services'
description = "Array of the id's of the notification services to use to send the notification"
related_exceptions = [KeyNotFound, InvalidKeyValue, NotificationServiceNotFound]
def validate(self) -> bool:
if not isinstance(self.value, list):
return False
if not self.value:
return False
for v in self.value:
if not isinstance(v, int):
return False
return True
class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable):
related_exceptions = [InvalidKeyValue, NotificationServiceNotFound]
class TextVariable(NonRequiredVersion, DefaultInputVariable):
name = 'text'
description = 'The body of the entry'
default = ''
def validate(self) -> bool:
return isinstance(self.value, str)
class RepeatQuantityVariable(DefaultInputVariable):
name = 'repeat_quantity'
description = 'The quantity of the repeat_interval'
required = False
_options = [m.lower() for m in RepeatQuantity._member_names_]
default = None
related_exceptions = [InvalidKeyValue]
def validate(self) -> bool:
if self.value is None:
return True
if not self.value in self._options:
return False
self.value = RepeatQuantity[self.value.upper()]
return True
def __repr__(self) -> str:
return '| {n} | {r} | {d} | {v} |'.format(
n=self.name,
r="Yes" if self.required else "No",
d=self.description,
v=", ".join(f'`{o}`' for o in self._options)
)
class RepeatIntervalVariable(DefaultInputVariable):
name = 'repeat_interval'
description = 'The number of the interval'
required = False
default = None
related_exceptions = [InvalidKeyValue]
def validate(self) -> bool:
return (
self.value is None
or (
isinstance(self.value, int)
and self.value > 0
)
)
class WeekDaysVariable(DefaultInputVariable):
name = 'weekdays'
description = 'On which days of the weeks to run the reminder'
required = False
default = None
related_exceptions = [InvalidKeyValue]
_options = {0, 1, 2, 3, 4, 5, 6}
def validate(self) -> bool:
return self.value is None or (
isinstance(self.value, list)
and len(self.value) > 0
and all(v in self._options for v in self.value)
)
class ColorVariable(DefaultInputVariable):
name = 'color'
description = 'The hex code of the color of the entry, which is shown in the web-ui'
required = False
default = None
related_exceptions = [InvalidKeyValue]
def validate(self) -> bool:
return self.value is None or color_regex.search(self.value)
class QueryVariable(DefaultInputVariable):
name = 'query'
description = 'The search term'
source = DataSource.VALUES
class AdminSettingsVariable(DefaultInputVariable):
related_exceptions = [KeyNotFound, InvalidKeyValue]
def validate(self) -> bool:
try:
_format_setting(self.name, self.value)
except InvalidKeyValue:
return False
return True
class AllowNewAccountsVariable(AdminSettingsVariable):
name = 'allow_new_accounts'
description = ('Whether or not to allow users to register a new account. '
+ 'The admin can always add a new account.')
class LoginTimeVariable(AdminSettingsVariable):
name = 'login_time'
description = ('How long a user stays logged in, in seconds. '
+ 'Between 1 min and 1 month (60 <= sec <= 2592000)')
class LoginTimeResetVariable(AdminSettingsVariable):
name = 'login_time_reset'
description = 'If the Login Time timer should reset with each API request.'
def input_validation() -> Union[None, Dict[str, Any]]:
"""Checks, extracts and transforms inputs
Raises:
KeyNotFound: A required key was not supplied
InvalidKeyValue: The value of a key is not valid
Returns:
Union[None, Dict[str, Any]]: `None` if the endpoint + method doesn't require input variables.
Otherwise `Dict[str, Any]` with the input variables, checked and formatted.
"""
inputs = {}
input_variables: Dict[str, List[Union[List[InputVariable], str]]]
if request.path.startswith(admin_api_prefix):
input_variables = api_docs[
_admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
]['input_variables']
else:
input_variables = api_docs[
request.url_rule.rule.split(api_prefix)[1]
]['input_variables']
if not input_variables:
return
if input_variables.get(request.method) is None:
return inputs
given_variables = {}
given_variables[DataSource.DATA] = request.get_json() if request.data else {}
given_variables[DataSource.VALUES] = request.values
for input_variable in input_variables[request.method]:
if (
input_variable.required and
not input_variable.name in given_variables[input_variable.source]
):
raise KeyNotFound(input_variable.name)
input_value = given_variables[input_variable.source].get(
input_variable.name,
input_variable.default
)
value: InputVariable = input_variable(input_value)
if not value.validate():
raise InvalidKeyValue(input_variable.name, input_value)
inputs[input_variable.name] = value.value
return inputs
api_docs: Dict[str, Dict[str, Any]] = {}
class APIBlueprint(Blueprint):
def route(
self,
rule: str,
description: str = '',
input_variables: Dict[str, List[Union[List[InputVariable], str]]] = {},
requires_auth: bool = True,
**options: Any
) -> Callable[[T_route], T_route]:
if self == api:
processed_rule = rule
elif self == admin_api:
processed_rule = _admin_api_prefix + rule
else:
raise NotImplementedError
api_docs[processed_rule] = {
'endpoint': processed_rule,
'description': description,
'requires_auth': requires_auth,
'methods': options['methods'],
'input_variables': {
k: v[0]
for k, v in input_variables.items()
if v and v[0]
},
'method_descriptions': {
k: v[1]
for k, v in input_variables.items()
if v and len(v) == 2 and v[1]
}
}
return super().route(rule, **options)
api = APIBlueprint('api', __name__)
admin_api = APIBlueprint('admin_api', __name__)

View File

@@ -1,20 +1,22 @@
#-*- coding: utf-8 -*-
import logging
from flask import Blueprint, render_template
ui = Blueprint('ui', __name__)
methods = ['GET']
class UIVariables:
url_prefix: str = ''
@ui.route('/', methods=methods)
def ui_login():
return render_template('login.html', url_prefix=logging.URL_PREFIX)
return render_template('login.html', url_prefix=UIVariables.url_prefix)
@ui.route('/reminders', methods=methods)
def ui_reminders():
return render_template('reminders.html', url_prefix=logging.URL_PREFIX)
return render_template('reminders.html', url_prefix=UIVariables.url_prefix)
@ui.route('/admin', methods=methods)
def ui_admin():
return render_template('admin.html', url_prefixx=logging.URL_PREFIX)
return render_template('admin.html', url_prefix=UIVariables.url_prefix)

View File

@@ -1,15 +1,18 @@
#!/usr/bin/env python3
#-*- coding: utf-8 -*-
from sys import path
from os.path import dirname
from sys import path
from frontend.input_validation import DataSource
path.insert(0, dirname(path[0]))
from subprocess import run
from typing import Union
from frontend.api import (DataSource, NotificationServiceNotFound,
ReminderNotFound, TemplateNotFound, api_docs)
from frontend.api import (NotificationServiceNotFound, ReminderNotFound,
TemplateNotFound, api_docs)
from MIND import _folder_path, api_prefix
url_var_map = {

View File

@@ -8,7 +8,8 @@ import backend.custom_exceptions
class Test_Custom_Exceptions(unittest.TestCase):
def test_type(self):
defined_exceptions: List[Exception] = filter(
lambda c: c.__module__ == 'backend.custom_exceptions' and c is not backend.custom_exceptions.CustomException,
lambda c: c.__module__ == 'backend.custom_exceptions'
and c is not backend.custom_exceptions.CustomException,
map(
lambda c: c[1],
getmembers(modules['backend.custom_exceptions'], isclass)
@@ -16,11 +17,21 @@ class Test_Custom_Exceptions(unittest.TestCase):
)
for defined_exception in defined_exceptions:
self.assertEqual(
self.assertIn(
getmro(defined_exception)[1],
backend.custom_exceptions.CustomException
(
backend.custom_exceptions.CustomException,
Exception
)
)
result = defined_exception().api_response
try:
result = defined_exception().api_response
except TypeError:
try:
result = defined_exception('1').api_response
except TypeError:
result = defined_exception('1', '2').api_response
self.assertIsInstance(result, dict)
result['error']
result['result']

View File

@@ -1,10 +1,12 @@
import unittest
from backend.db import DBConnection
from MIND import DB_FILENAME, _folder_path
from backend.helpers import folder_path
from MIND import DB_FILENAME
class Test_DB(unittest.TestCase):
def test_foreign_key(self):
DBConnection.file = _folder_path(*DB_FILENAME)
DBConnection.file = folder_path(*DB_FILENAME)
instance = DBConnection(timeout=20.0)
self.assertEqual(instance.cursor().execute("PRAGMA foreign_keys;").fetchone()[0], 1)

View File

@@ -1,20 +1,14 @@
import unittest
from backend.reminders import filter_function, ReminderHandler
from backend.helpers import search_filter
class Test_Reminder_Handler(unittest.TestCase):
def test_starting_stopping(self):
context = 'test'
instance = ReminderHandler(context)
self.assertIs(context, instance.context)
def test_filter_function(self):
p = {
'title': 'TITLE',
'text': 'TEXT'
}
for test_case in ('', 'title', 'ex'):
self.assertTrue(filter_function(test_case, p))
self.assertTrue(search_filter(test_case, p))
for test_case in (' ', 'Hello'):
self.assertFalse(filter_function(test_case, p))
self.assertFalse(search_filter(test_case, p))

View File

@@ -1,14 +1,14 @@
import unittest
from backend.custom_exceptions import UsernameInvalid
from backend.users import ONEPASS_INVALID_USERNAMES, _check_username
from backend.users import ONEPASS_INVALID_USERNAMES, Users
class Test_Users(unittest.TestCase):
def test_username_check(self):
users = Users()
for test_case in ('', 'test'):
_check_username(test_case)
users._check_username(test_case)
for test_case in (' ', ' ', '0', 'api', *ONEPASS_INVALID_USERNAMES):
with self.assertRaises(UsernameInvalid):
_check_username(test_case)
users._check_username(test_case)