mirror of
https://github.com/Casvt/MIND.git
synced 2026-02-19 11:54:46 -05:00
Backend Refactor
This commit is contained in:
4
.vscode/settings.json
vendored
4
.vscode/settings.json
vendored
@@ -7,5 +7,7 @@
|
||||
"*_test.py"
|
||||
],
|
||||
"python.testing.pytestEnabled": false,
|
||||
"python.testing.unittestEnabled": true
|
||||
"python.testing.unittestEnabled": true,
|
||||
"python.analysis.autoImportCompletions": true,
|
||||
"python.analysis.typeCheckingMode": "off"
|
||||
}
|
||||
78
MIND.py
78
MIND.py
@@ -1,20 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
The main file where MIND is started from
|
||||
"""
|
||||
|
||||
import logging
|
||||
from os import makedirs, urandom
|
||||
from os.path import abspath, dirname, isfile, join
|
||||
from os.path import dirname, isfile
|
||||
from shutil import move
|
||||
from sys import version_info
|
||||
|
||||
from flask import Flask, render_template, request
|
||||
from waitress.server import create_server
|
||||
from werkzeug.middleware.dispatcher import DispatcherMiddleware
|
||||
|
||||
from backend.db import DBConnection, ThreadedTaskDispatcher, close_db, setup_db
|
||||
from frontend.api import (admin_api, admin_api_prefix, api, api_prefix,
|
||||
reminder_handler)
|
||||
from frontend.ui import ui
|
||||
from backend.helpers import check_python_version, folder_path
|
||||
from backend.reminders import ReminderHandler
|
||||
from frontend.api import admin_api, admin_api_prefix, api, api_prefix
|
||||
from frontend.ui import UIVariables, ui
|
||||
|
||||
HOST = '0.0.0.0'
|
||||
PORT = '8080'
|
||||
@@ -23,35 +27,33 @@ LOGGING_LEVEL = logging.INFO
|
||||
THREADS = 10
|
||||
DB_FILENAME = 'db', 'MIND.db'
|
||||
|
||||
UIVariables.url_prefix = URL_PREFIX
|
||||
logging.basicConfig(
|
||||
level=LOGGING_LEVEL,
|
||||
format='[%(asctime)s][%(threadName)s][%(levelname)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
def _folder_path(*folders) -> str:
|
||||
"""Turn filepaths relative to the project folder into absolute paths
|
||||
Returns:
|
||||
str: The absolute filepath
|
||||
"""
|
||||
return join(dirname(abspath(__file__)), *folders)
|
||||
|
||||
def _create_app() -> Flask:
|
||||
"""Create a Flask app instance
|
||||
|
||||
Returns:
|
||||
Flask: The created app instance
|
||||
"""
|
||||
app = Flask(
|
||||
__name__,
|
||||
template_folder=_folder_path('frontend','templates'),
|
||||
static_folder=_folder_path('frontend','static'),
|
||||
template_folder=folder_path('frontend','templates'),
|
||||
static_folder=folder_path('frontend','static'),
|
||||
static_url_path='/static'
|
||||
)
|
||||
app.config['SECRET_KEY'] = urandom(32)
|
||||
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
|
||||
app.config['JSON_SORT_KEYS'] = False
|
||||
app.config['APPLICATION_ROOT'] = URL_PREFIX
|
||||
app.wsgi_app = DispatcherMiddleware(Flask(__name__), {URL_PREFIX: app.wsgi_app})
|
||||
app.wsgi_app = DispatcherMiddleware(
|
||||
Flask(__name__),
|
||||
{URL_PREFIX: app.wsgi_app}
|
||||
)
|
||||
|
||||
# Add error handlers
|
||||
@app.errorhandler(400)
|
||||
@@ -68,7 +70,7 @@ def _create_app() -> Flask:
|
||||
|
||||
@app.errorhandler(404)
|
||||
def not_found(e):
|
||||
if request.path.startswith('/api'):
|
||||
if request.path.startswith(api_prefix):
|
||||
return {'error': 'Not Found', 'result': {}}, 404
|
||||
return render_template('page_not_found.html', url_prefix=logging.URL_PREFIX)
|
||||
|
||||
@@ -83,40 +85,42 @@ def _create_app() -> Flask:
|
||||
|
||||
def MIND() -> None:
|
||||
"""The main function of MIND
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Check python version
|
||||
if (version_info.major < 3) or (version_info.major == 3 and version_info.minor < 7):
|
||||
logging.error('Error: the minimum python version required is python3.7 (currently ' + version_info.major + '.' + version_info.minor + '.' + version_info.micro + ')')
|
||||
logging.info('Starting up MIND')
|
||||
|
||||
if not check_python_version():
|
||||
exit(1)
|
||||
|
||||
if isfile(folder_path('db', 'Noted.db')):
|
||||
move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME))
|
||||
|
||||
db_location = folder_path(*DB_FILENAME)
|
||||
logging.debug(f'Database location: {db_location}')
|
||||
makedirs(dirname(db_location), exist_ok=True)
|
||||
|
||||
DBConnection.file = db_location
|
||||
|
||||
# Register web server
|
||||
# We need to get the value to ui.py but MIND.py imports from ui.py so we get an import loop.
|
||||
# To go around this, we abuse the fact that the logging module is a singleton.
|
||||
# We add an attribute to the logging module and in ui.py get the value this way.
|
||||
logging.URL_PREFIX = URL_PREFIX
|
||||
app = _create_app()
|
||||
reminder_handler = ReminderHandler(app.app_context)
|
||||
with app.app_context():
|
||||
if isfile(_folder_path('db', 'Noted.db')):
|
||||
move(_folder_path('db', 'Noted.db'), _folder_path(*DB_FILENAME))
|
||||
|
||||
db_location = _folder_path(*DB_FILENAME)
|
||||
logging.debug(f'Database location: {db_location}')
|
||||
makedirs(dirname(db_location), exist_ok=True)
|
||||
|
||||
DBConnection.file = db_location
|
||||
setup_db()
|
||||
reminder_handler.find_next_reminder()
|
||||
|
||||
# Create waitress server and run
|
||||
dispatcher = ThreadedTaskDispatcher()
|
||||
dispatcher.set_thread_count(THREADS)
|
||||
server = create_server(app, _dispatcher=dispatcher, host=HOST, port=PORT, threads=THREADS)
|
||||
server = create_server(
|
||||
app,
|
||||
_dispatcher=dispatcher,
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
threads=THREADS
|
||||
)
|
||||
logging.info(f'MIND running on http://{HOST}:{PORT}{URL_PREFIX}')
|
||||
# =================
|
||||
server.run()
|
||||
|
||||
# Stopping thread
|
||||
# =================
|
||||
|
||||
reminder_handler.stop_handling()
|
||||
|
||||
return
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
All custom exceptions are defined here
|
||||
"""
|
||||
|
||||
"""
|
||||
Note: Not all CE's inherit from CustomException.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class CustomException(Exception):
|
||||
def __init__(self, e=None) -> None:
|
||||
logging.warning(self.__doc__)
|
||||
@@ -13,10 +22,18 @@ class UsernameTaken(CustomException):
|
||||
"""The username is already taken"""
|
||||
api_response = {'error': 'UsernameTaken', 'result': {}, 'code': 400}
|
||||
|
||||
class UsernameInvalid(CustomException):
|
||||
class UsernameInvalid(Exception):
|
||||
"""The username contains invalid characters"""
|
||||
api_response = {'error': 'UsernameInvalid', 'result': {}, 'code': 400}
|
||||
|
||||
def __init__(self, username: str):
|
||||
self.username = username
|
||||
super().__init__(self.username)
|
||||
logging.warning(
|
||||
f'The username contains invalid characters: {username}'
|
||||
)
|
||||
return
|
||||
|
||||
class UserNotFound(CustomException):
|
||||
"""The user requested can not be found"""
|
||||
api_response = {'error': 'UserNotFound', 'result': {}, 'code': 404}
|
||||
@@ -33,59 +50,81 @@ class NotificationServiceNotFound(CustomException):
|
||||
"""The notification service was not found"""
|
||||
api_response = {'error': 'NotificationServiceNotFound', 'result': {}, 'code': 404}
|
||||
|
||||
class NotificationServiceInUse(CustomException):
|
||||
"""The notification service is wished to be deleted but a reminder is still using it"""
|
||||
class NotificationServiceInUse(Exception):
|
||||
"""
|
||||
The notification service is wished to be deleted
|
||||
but a reminder is still using it
|
||||
"""
|
||||
def __init__(self, type: str=''):
|
||||
self.type = type
|
||||
super().__init__(self.type)
|
||||
logging.warning(
|
||||
f'The notification is wished to be deleted but a reminder of type {type} is still using it'
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> Dict[str, Any]:
|
||||
return {'error': 'NotificationServiceInUse', 'result': {'type': self.type}, 'code': 400}
|
||||
return {
|
||||
'error': 'NotificationServiceInUse',
|
||||
'result': {'type': self.type},
|
||||
'code': 400
|
||||
}
|
||||
|
||||
class InvalidTime(CustomException):
|
||||
"""The time given is in the past"""
|
||||
api_response = {'error': 'InvalidTime', 'result': {}, 'code': 400}
|
||||
|
||||
class KeyNotFound(CustomException):
|
||||
class KeyNotFound(Exception):
|
||||
"""A key was not found in the input that is required to be given"""
|
||||
def __init__(self, key: str=''):
|
||||
self.key = key
|
||||
super().__init__(self.key)
|
||||
logging.warning(
|
||||
"This key was not found in the API request,"
|
||||
+ f" eventhough it's required: {key}"
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> Dict[str, Any]:
|
||||
return {'error': 'KeyNotFound', 'result': {'key': self.key}, 'code': 400}
|
||||
return {
|
||||
'error': 'KeyNotFound',
|
||||
'result': {'key': self.key},
|
||||
'code': 400
|
||||
}
|
||||
|
||||
class InvalidKeyValue(CustomException):
|
||||
class InvalidKeyValue(Exception):
|
||||
"""The value of a key is invalid"""
|
||||
def __init__(self, key: str='', value: str=''):
|
||||
self.key = key
|
||||
self.value = value
|
||||
super().__init__(self.key)
|
||||
logging.warning(
|
||||
'This key in the API request has an invalid value: ' +
|
||||
f'{key} = {value}'
|
||||
)
|
||||
|
||||
@property
|
||||
def api_response(self) -> Dict[str, Any]:
|
||||
return {'error': 'InvalidKeyValue', 'result': {'key': self.key, 'value': self.value}, 'code': 400}
|
||||
return {
|
||||
'error': 'InvalidKeyValue',
|
||||
'result': {'key': self.key, 'value': self.value},
|
||||
'code': 400
|
||||
}
|
||||
|
||||
class TemplateNotFound(CustomException):
|
||||
"""The template was not found"""
|
||||
api_response = {'error': 'TemplateNotFound', 'result': {}, 'code': 404}
|
||||
|
||||
class APIKeyInvalid(CustomException):
|
||||
class APIKeyInvalid(Exception):
|
||||
"""The API key is not correct"""
|
||||
api_response = {'error': 'APIKeyInvalid', 'result': {}, 'code': 401}
|
||||
|
||||
def __init__(self, e=None) -> None:
|
||||
return
|
||||
|
||||
class APIKeyExpired(CustomException):
|
||||
class APIKeyExpired(Exception):
|
||||
"""The API key has expired"""
|
||||
api_response = {'error': 'APIKeyExpired', 'result': {}, 'code': 401}
|
||||
|
||||
def __init__(self, e=None) -> None:
|
||||
return
|
||||
|
||||
class NewAccountsNotAllowed(CustomException):
|
||||
"""It's not allowed to create a new account"""
|
||||
api_response = {'error': 'NewAccountsNotAllowed', 'result': {}, 'code': 403}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Setting up and interacting with the database.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from sqlite3 import Connection, ProgrammingError, Row
|
||||
from threading import current_thread, main_thread
|
||||
from time import time
|
||||
from typing import Union
|
||||
from typing import Type, Union
|
||||
|
||||
from flask import g
|
||||
from waitress.task import ThreadedTaskDispatcher as OldThreadedTaskDispatcher
|
||||
@@ -14,14 +18,14 @@ from backend.custom_exceptions import AccessUnauthorized, UserNotFound
|
||||
|
||||
__DATABASE_VERSION__ = 8
|
||||
|
||||
class Singleton(type):
|
||||
class DB_Singleton(type):
|
||||
_instances = {}
|
||||
def __call__(cls, *args, **kwargs):
|
||||
i = f'{cls}{current_thread()}'
|
||||
if (i not in cls._instances
|
||||
or cls._instances[i].closed):
|
||||
logging.debug(f'Creating singleton instance: {i}')
|
||||
cls._instances[i] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
cls._instances[i] = super(DB_Singleton, cls).__call__(*args, **kwargs)
|
||||
|
||||
return cls._instances[i]
|
||||
|
||||
@@ -29,19 +33,21 @@ class ThreadedTaskDispatcher(OldThreadedTaskDispatcher):
|
||||
def handler_thread(self, thread_no: int) -> None:
|
||||
super().handler_thread(thread_no)
|
||||
i = f'{DBConnection}{current_thread()}'
|
||||
if i in Singleton._instances and not Singleton._instances[i].closed:
|
||||
if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed:
|
||||
logging.debug(f'Closing singleton instance: {i}')
|
||||
Singleton._instances[i].close()
|
||||
|
||||
DB_Singleton._instances[i].close()
|
||||
return
|
||||
|
||||
def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
|
||||
print()
|
||||
logging.info('Shutting down MIND...')
|
||||
super().shutdown(cancel_pending, timeout)
|
||||
DBConnection(20.0).close()
|
||||
return
|
||||
|
||||
class DBConnection(Connection, metaclass=Singleton):
|
||||
class DBConnection(Connection, metaclass=DB_Singleton):
|
||||
file = ''
|
||||
|
||||
|
||||
def __init__(self, timeout: float) -> None:
|
||||
logging.debug(f'Opening database connection for {current_thread()}')
|
||||
super().__init__(self.file, timeout=timeout)
|
||||
@@ -55,11 +61,13 @@ class DBConnection(Connection, metaclass=Singleton):
|
||||
super().close()
|
||||
return
|
||||
|
||||
def get_db(output_type: Union[dict, tuple]=tuple):
|
||||
def get_db(output_type: Union[Type[dict], Type[tuple]]=tuple):
|
||||
"""Get a database cursor instance. Coupled to Flask's g.
|
||||
|
||||
Args:
|
||||
output_type (Union[dict, tuple], optional): The type of output: a tuple or dictionary with the row values. Defaults to tuple.
|
||||
output_type (Union[Type[dict], Type[tuple]], optional):
|
||||
The type of output: a tuple or dictionary with the row values.
|
||||
Defaults to tuple.
|
||||
|
||||
Returns:
|
||||
Cursor: The Cursor instance to use
|
||||
@@ -220,7 +228,7 @@ def migrate_db(current_db_version: int) -> None:
|
||||
if current_db_version == 7:
|
||||
# V7 -> V8
|
||||
from backend.settings import _format_setting, default_settings
|
||||
from backend.users import register_user
|
||||
from backend.users import Users
|
||||
|
||||
cursor.executescript("""
|
||||
DROP TABLE config;
|
||||
@@ -239,7 +247,7 @@ def migrate_db(current_db_version: int) -> None:
|
||||
default_settings.items()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
cursor.executescript("""
|
||||
ALTER TABLE users
|
||||
ADD admin BOOL NOT NULL DEFAULT 0;
|
||||
@@ -248,9 +256,9 @@ def migrate_db(current_db_version: int) -> None:
|
||||
SET username = 'admin_old'
|
||||
WHERE username = 'admin';
|
||||
""")
|
||||
|
||||
register_user('admin', 'admin')
|
||||
|
||||
|
||||
Users().add('admin', 'admin', True)
|
||||
|
||||
cursor.execute("""
|
||||
UPDATE users
|
||||
SET admin = 1
|
||||
@@ -264,6 +272,8 @@ def setup_db() -> None:
|
||||
"""
|
||||
from backend.settings import (_format_setting, default_settings, get_setting,
|
||||
set_setting)
|
||||
from backend.users import Users
|
||||
|
||||
cursor = get_db()
|
||||
cursor.execute("PRAGMA journal_mode = wal;")
|
||||
|
||||
@@ -348,11 +358,22 @@ def setup_db() -> None:
|
||||
default_settings.items()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
current_db_version = get_setting('database_version')
|
||||
logging.debug(f'Current database version {current_db_version} and desired database version {__DATABASE_VERSION__}')
|
||||
if current_db_version < __DATABASE_VERSION__:
|
||||
logging.debug(
|
||||
f'Database migration: {current_db_version} -> {__DATABASE_VERSION__}'
|
||||
)
|
||||
migrate_db(current_db_version)
|
||||
set_setting('database_version', __DATABASE_VERSION__)
|
||||
|
||||
users = Users()
|
||||
if not 'admin' in users:
|
||||
users.add('admin', 'admin', True)
|
||||
cursor.execute("""
|
||||
UPDATE users
|
||||
SET admin = 1
|
||||
WHERE username = 'admin';
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
91
backend/helpers.py
Normal file
91
backend/helpers.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
General functions
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from os.path import abspath, dirname, join
|
||||
from sys import version_info
|
||||
|
||||
|
||||
def folder_path(*folders) -> str:
|
||||
"""Turn filepaths relative to the project folder into absolute paths
|
||||
|
||||
Returns:
|
||||
str: The absolute filepath
|
||||
"""
|
||||
return join(dirname(dirname(abspath(__file__))), *folders)
|
||||
|
||||
|
||||
def check_python_version() -> bool:
|
||||
"""Check if the python version that is used is a minimum version.
|
||||
|
||||
Returns:
|
||||
bool: Whether or not the python version is version 3.8 or above or not.
|
||||
"""
|
||||
if not (version_info.major == 3 and version_info.minor >= 8):
|
||||
logging.critical(
|
||||
'The minimum python version required is python3.8 ' +
|
||||
'(currently ' + version_info.major + '.' + version_info.minor + '.' + version_info.micro + ').'
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def search_filter(query: str, result: dict) -> bool:
|
||||
"""Filter library results based on a query.
|
||||
|
||||
Args:
|
||||
query (str): The query to filter with.
|
||||
result (dict): The library result to check.
|
||||
|
||||
Returns:
|
||||
bool: Whether or not the result passes the filter.
|
||||
"""
|
||||
query = query.lower()
|
||||
return (
|
||||
query in result["title"].lower()
|
||||
or query in result["text"].lower()
|
||||
)
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
_instances = {}
|
||||
def __call__(cls, *args, **kwargs):
|
||||
c = str(cls)
|
||||
if c not in cls._instances:
|
||||
cls._instances[c] = super().__call__(*args, **kwargs)
|
||||
|
||||
return cls._instances[c]
|
||||
|
||||
|
||||
class BaseEnum(Enum):
|
||||
def __eq__(self, other) -> bool:
|
||||
return self.value == other
|
||||
|
||||
|
||||
class TimelessSortingMethod(BaseEnum):
|
||||
TITLE = (lambda r: (r['title'], r['text'], r['color']), False)
|
||||
TITLE_REVERSED = (lambda r: (r['title'], r['text'], r['color']), True)
|
||||
DATE_ADDED = (lambda r: r['id'], False)
|
||||
DATE_ADDED_REVERSED = (lambda r: r['id'], True)
|
||||
|
||||
|
||||
class SortingMethod(BaseEnum):
|
||||
TIME = (lambda r: (r['time'], r['title'], r['text'], r['color']), False)
|
||||
TIME_REVERSED = (lambda r: (r['time'], r['title'], r['text'], r['color']), True)
|
||||
TITLE = (lambda r: (r['title'], r['time'], r['text'], r['color']), False)
|
||||
TITLE_REVERSED = (lambda r: (r['title'], r['time'], r['text'], r['color']), True)
|
||||
DATE_ADDED = (lambda r: r['id'], False)
|
||||
DATE_ADDED_REVERSED = (lambda r: r['id'], True)
|
||||
|
||||
|
||||
class RepeatQuantity(BaseEnum):
|
||||
YEARS = "years"
|
||||
MONTHS = "months"
|
||||
WEEKS = "weeks"
|
||||
DAYS = "days"
|
||||
HOURS = "hours"
|
||||
MINUTES = "minutes"
|
||||
@@ -74,7 +74,10 @@ def get_apprise_services() -> List[Dict[str, Union[str, Dict[str, list]]]]:
|
||||
'prefix': content.get('prefix'),
|
||||
'regex': process_regex(content.get('regex'))
|
||||
}
|
||||
for content, _ in ((entry['details']['tokens'][e], handled_tokens.add(e)) for e in v['group'])
|
||||
for content, _ in (
|
||||
(entry['details']['tokens'][e], handled_tokens.add(e))
|
||||
for e in v['group']
|
||||
)
|
||||
]
|
||||
}
|
||||
for k, v in
|
||||
@@ -175,8 +178,13 @@ class NotificationService:
|
||||
def __init__(self, user_id: int, notification_service_id: int) -> None:
|
||||
self.id = notification_service_id
|
||||
|
||||
if not get_db().execute(
|
||||
"SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;",
|
||||
if not get_db().execute("""
|
||||
SELECT 1
|
||||
FROM notification_services
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(self.id, user_id)
|
||||
).fetchone():
|
||||
raise NotificationServiceNotFound
|
||||
@@ -187,8 +195,12 @@ class NotificationService:
|
||||
Returns:
|
||||
dict: The info about the notification service
|
||||
"""
|
||||
result = dict(get_db(dict).execute(
|
||||
"SELECT id, title, url FROM notification_services WHERE id = ? LIMIT 1",
|
||||
result = dict(get_db(dict).execute("""
|
||||
SELECT id, title, url
|
||||
FROM notification_services
|
||||
WHERE id = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(self.id,)
|
||||
).fetchone())
|
||||
|
||||
@@ -330,7 +342,7 @@ class NotificationServices:
|
||||
url (str): The apprise url of the service
|
||||
|
||||
Returns:
|
||||
dict: The info about the new service
|
||||
NotificationService: The instance representing the new service
|
||||
"""
|
||||
logging.info(f'Adding notification service with {title=}, {url=}')
|
||||
|
||||
|
||||
@@ -4,28 +4,65 @@ import logging
|
||||
from datetime import datetime
|
||||
from sqlite3 import IntegrityError
|
||||
from threading import Timer
|
||||
from typing import List, Literal
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from apprise import Apprise
|
||||
from dateutil.relativedelta import relativedelta, weekday
|
||||
from flask import Flask
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from dateutil.relativedelta import weekday as du_weekday
|
||||
|
||||
from backend.custom_exceptions import (InvalidKeyValue, InvalidTime,
|
||||
NotificationServiceNotFound,
|
||||
ReminderNotFound)
|
||||
from backend.db import close_db, get_db
|
||||
from backend.db import get_db
|
||||
from backend.helpers import RepeatQuantity, Singleton, SortingMethod, search_filter
|
||||
|
||||
filter_function = lambda query, p: (
|
||||
query in p["title"].lower()
|
||||
or query in p["text"].lower()
|
||||
)
|
||||
|
||||
def __next_selected_day(
|
||||
weekdays: List[int],
|
||||
weekday: int
|
||||
) -> int:
|
||||
"""Find the next allowed day in the week.
|
||||
|
||||
Args:
|
||||
weekdays (List[int]): The days of the week that are allowed.
|
||||
Monday is 0, Sunday is 6.
|
||||
weekday (int): The current weekday.
|
||||
|
||||
Returns:
|
||||
int: The next allowed weekday.
|
||||
"""
|
||||
return (
|
||||
# Get all days later than current, then grab first one.
|
||||
[d for d in weekdays if weekday < d]
|
||||
or
|
||||
# weekday is last allowed day, so it should grab the first
|
||||
# allowed day of the week.
|
||||
weekdays
|
||||
)[0]
|
||||
|
||||
def _find_next_time(
|
||||
original_time: int,
|
||||
repeat_quantity: Literal["years", "months", "weeks", "days", "hours", "minutes"],
|
||||
repeat_interval: int,
|
||||
weekdays: List[int]
|
||||
repeat_quantity: Union[RepeatQuantity, None],
|
||||
repeat_interval: Union[int, None],
|
||||
weekdays: Union[List[int], None]
|
||||
) -> int:
|
||||
"""Calculate the next timestep based on original time and repeat/interval
|
||||
values.
|
||||
|
||||
Args:
|
||||
original_time (int): The original time of the repeating timestamp.
|
||||
|
||||
repeat_quantity (Union[RepeatQuantity, None]): If set, what the quantity
|
||||
is of the repetition.
|
||||
|
||||
repeat_interval (Union[int, None]): If set, the value of the repetition.
|
||||
|
||||
weekdays (Union[List[int], None]): If set, on which days the time can
|
||||
continue. Monday is 0, Sunday is 6.
|
||||
|
||||
Returns:
|
||||
int: The next timestamp in the future.
|
||||
"""
|
||||
if weekdays is not None:
|
||||
weekdays.sort()
|
||||
|
||||
@@ -33,152 +70,50 @@ def _find_next_time(
|
||||
current_time = datetime.fromtimestamp(datetime.utcnow().timestamp())
|
||||
|
||||
if repeat_quantity is not None:
|
||||
td = relativedelta(**{repeat_quantity: repeat_interval})
|
||||
td = relativedelta(**{repeat_quantity.value: repeat_interval})
|
||||
while new_time <= current_time:
|
||||
new_time += td
|
||||
|
||||
else:
|
||||
next_day = ([d for d in weekdays if new_time.weekday() < d] or weekdays)[0]
|
||||
proposed_time = new_time + relativedelta(weekday=weekday(next_day))
|
||||
if proposed_time == new_time:
|
||||
proposed_time += relativedelta(weekday=weekday(next_day, 2))
|
||||
new_time = proposed_time
|
||||
|
||||
while new_time <= current_time:
|
||||
next_day = ([d for d in weekdays if new_time.weekday() < d] or weekdays)[0]
|
||||
proposed_time = new_time + relativedelta(weekday=weekday(next_day))
|
||||
# We run the loop contents at least once and then actually use the cond.
|
||||
# This is because we need to force the 'free' date to go to one of the
|
||||
# selected weekdays.
|
||||
# Say it's Monday, we set a reminder for Wednesday and make it repeat
|
||||
# on Tuesday and Thursday. Then the first notification needs to go on
|
||||
# Thurday, not Wednesday. So run code at least once to force that.
|
||||
# Afterwards, it can run normally to push the timestamp into the future.
|
||||
one_to_go = True
|
||||
while one_to_go or new_time <= current_time:
|
||||
next_day = __next_selected_day(weekdays, new_time.weekday())
|
||||
proposed_time = new_time + relativedelta(weekday=du_weekday(next_day))
|
||||
if proposed_time == new_time:
|
||||
proposed_time += relativedelta(weekday=weekday(next_day, 2))
|
||||
proposed_time += relativedelta(weekday=du_weekday(next_day, 2))
|
||||
new_time = proposed_time
|
||||
one_to_go = False
|
||||
|
||||
result = int(new_time.timestamp())
|
||||
logging.debug(
|
||||
f'{original_time=}, {current_time=} and interval of {repeat_interval} {repeat_quantity} leads to {result}'
|
||||
f'{original_time=}, {current_time=} ' +
|
||||
f'and interval of {repeat_interval} {repeat_quantity} ' +
|
||||
f'leads to {result}'
|
||||
)
|
||||
return result
|
||||
|
||||
class ReminderHandler:
|
||||
"""Handle set reminders
|
||||
"""
|
||||
def __init__(self, context) -> None:
|
||||
self.context = context
|
||||
self.next_trigger = {
|
||||
'thread': None,
|
||||
'time': None
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
def __trigger_reminders(self, time: int) -> None:
|
||||
"""Trigger all reminders that are set for a certain time
|
||||
|
||||
Args:
|
||||
time (int): The time of the reminders to trigger
|
||||
"""
|
||||
with self.context():
|
||||
cursor = get_db(dict)
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
r.id,
|
||||
r.title, r.text,
|
||||
r.repeat_quantity, r.repeat_interval,
|
||||
r.weekdays,
|
||||
r.original_time
|
||||
FROM reminders r
|
||||
WHERE time = ?;
|
||||
""", (time,))
|
||||
reminders = list(map(dict, cursor))
|
||||
|
||||
for reminder in reminders:
|
||||
cursor.execute("""
|
||||
SELECT url
|
||||
FROM reminder_services rs
|
||||
INNER JOIN notification_services ns
|
||||
ON rs.notification_service_id = ns.id
|
||||
WHERE rs.reminder_id = ?;
|
||||
""", (reminder['id'],))
|
||||
|
||||
# Send of reminder
|
||||
a = Apprise()
|
||||
for url in cursor:
|
||||
a.add(url['url'])
|
||||
a.notify(title=reminder["title"], body=reminder["text"])
|
||||
|
||||
if reminder['repeat_quantity'] is None and reminder['weekdays'] is None:
|
||||
# Delete the reminder from the database
|
||||
cursor.execute(
|
||||
"DELETE FROM reminders WHERE id = ?;",
|
||||
(reminder['id'],)
|
||||
)
|
||||
logging.info(f'Deleted reminder {reminder["id"]}')
|
||||
else:
|
||||
# Set next time
|
||||
new_time = _find_next_time(
|
||||
reminder['original_time'],
|
||||
reminder['repeat_quantity'],
|
||||
reminder['repeat_interval'],
|
||||
[int(d) for d in reminder['weekdays'].split(',')] if reminder['weekdays'] is not None else None
|
||||
)
|
||||
cursor.execute(
|
||||
"UPDATE reminders SET time = ? WHERE id = ?;",
|
||||
(new_time, reminder['id'])
|
||||
)
|
||||
|
||||
self.next_trigger.update({
|
||||
'thread': None,
|
||||
'time': None
|
||||
})
|
||||
self.find_next_reminder()
|
||||
|
||||
def find_next_reminder(self, time: int=None) -> None:
|
||||
"""Determine when the soonest reminder is and set the timer to that time
|
||||
|
||||
Args:
|
||||
time (int, optional): The timestamp to check for. Otherwise check soonest in database. Defaults to None.
|
||||
"""
|
||||
if not time:
|
||||
with self.context():
|
||||
time = get_db().execute("""
|
||||
SELECT DISTINCT r1.time
|
||||
FROM reminders r1
|
||||
LEFT JOIN reminders r2
|
||||
ON r1.time > r2.time
|
||||
WHERE r2.id IS NULL;
|
||||
""").fetchone()
|
||||
if time is None:
|
||||
return
|
||||
time = time[0]
|
||||
|
||||
if (self.next_trigger['thread'] is None
|
||||
or time < self.next_trigger['time']):
|
||||
if self.next_trigger['thread'] is not None:
|
||||
self.next_trigger['thread'].cancel()
|
||||
|
||||
t = time - datetime.utcnow().timestamp()
|
||||
self.next_trigger['thread'] = Timer(
|
||||
t,
|
||||
self.__trigger_reminders,
|
||||
(time,)
|
||||
)
|
||||
self.next_trigger['thread'].name = "ReminderHandler"
|
||||
self.next_trigger['thread'].start()
|
||||
self.next_trigger['time'] = time
|
||||
|
||||
def stop_handling(self) -> None:
|
||||
"""Stop the timer if it's active
|
||||
"""
|
||||
if self.next_trigger['thread'] is not None:
|
||||
self.next_trigger['thread'].cancel()
|
||||
return
|
||||
|
||||
handler_context = Flask('handler')
|
||||
handler_context.teardown_appcontext(close_db)
|
||||
reminder_handler = ReminderHandler(handler_context.app_context)
|
||||
|
||||
class Reminder:
|
||||
"""Represents a reminder
|
||||
"""
|
||||
def __init__(self, user_id: int, reminder_id: int):
|
||||
def __init__(self, user_id: int, reminder_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
reminder_id (int): The ID of the reminder.
|
||||
|
||||
Raises:
|
||||
ReminderNotFound: Reminder with given ID does not exist or is not
|
||||
owned by user.
|
||||
"""
|
||||
self.id = reminder_id
|
||||
|
||||
# Check if reminder exists
|
||||
@@ -188,6 +123,26 @@ class Reminder:
|
||||
).fetchone():
|
||||
raise ReminderNotFound
|
||||
|
||||
return
|
||||
|
||||
def _get_notification_services(self) -> List[int]:
|
||||
"""Get ID's of notification services linked to the reminder.
|
||||
|
||||
Returns:
|
||||
List[int]: The list with ID's.
|
||||
"""
|
||||
result = [
|
||||
r[0]
|
||||
for r in get_db().execute("""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE reminder_id = ?;
|
||||
""",
|
||||
(self.id,)
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
def get(self) -> dict:
|
||||
"""Get info about the reminder
|
||||
|
||||
@@ -211,46 +166,65 @@ class Reminder:
|
||||
).fetchone()
|
||||
reminder = dict(reminder)
|
||||
|
||||
reminder['notification_services'] = list(map(lambda r: r[0], get_db().execute("""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE reminder_id = ?;
|
||||
""", (self.id,))))
|
||||
reminder['notification_services'] = self._get_notification_services()
|
||||
|
||||
return reminder
|
||||
|
||||
def update(
|
||||
self,
|
||||
title: str = None,
|
||||
time: int = None,
|
||||
notification_services: List[int] = None,
|
||||
text: str = None,
|
||||
repeat_quantity: Literal["years", "months", "weeks", "days", "hours", "minutes"] = None,
|
||||
repeat_interval: int = None,
|
||||
weekdays: List[int] = None,
|
||||
color: str = None
|
||||
title: Union[None, str] = None,
|
||||
time: Union[None, int] = None,
|
||||
notification_services: Union[None, List[int]] = None,
|
||||
text: Union[None, str] = None,
|
||||
repeat_quantity: Union[None, RepeatQuantity] = None,
|
||||
repeat_interval: Union[None, int] = None,
|
||||
weekdays: Union[None, List[int]] = None,
|
||||
color: Union[None, str] = None
|
||||
) -> dict:
|
||||
"""Edit the reminder
|
||||
"""Edit the reminder.
|
||||
|
||||
Args:
|
||||
title (str): The new title of the entry. Defaults to None.
|
||||
time (int): The new UTC epoch timestamp the the reminder should be send. Defaults to None.
|
||||
notification_services (List[int]): The new list of id's of the notification services to use to send the reminder. Defaults to None.
|
||||
text (str, optional): The new body of the reminder. Defaults to None.
|
||||
repeat_quantity (Literal["years", "months", "weeks", "days", "hours", "minutes"], optional): The new quantity of the repeat specified for the reminder. Defaults to None.
|
||||
repeat_interval (int, optional): The new amount of repeat_quantity, like "5" (hours). Defaults to None.
|
||||
weekdays (List[int], optional): The new indexes of the days of the week that the reminder should run. Defaults to None.
|
||||
color (str, optional): The new hex code of the color of the reminder, which is shown in the web-ui. Defaults to None.
|
||||
title (Union[None, str]): The new title of the entry.
|
||||
Defaults to None.
|
||||
|
||||
time (Union[None, int]): The new UTC epoch timestamp when the
|
||||
reminder should be send.
|
||||
Defaults to None.
|
||||
|
||||
notification_services (Union[None, List[int]]): The new list
|
||||
of id's of the notification services to use to send the reminder.
|
||||
Defaults to None.
|
||||
|
||||
text (Union[None, str], optional): The new body of the reminder.
|
||||
Defaults to None.
|
||||
|
||||
repeat_quantity (Union[None, RepeatQuantity], optional): The new
|
||||
quantity of the repeat specified for the reminder.
|
||||
Defaults to None.
|
||||
|
||||
repeat_interval (Union[None, int], optional): The new amount of
|
||||
repeat_quantity, like "5" (hours).
|
||||
Defaults to None.
|
||||
|
||||
weekdays (Union[None, List[int]], optional): The new indexes of
|
||||
the days of the week that the reminder should run.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[None, str], optional): The new hex code of the color
|
||||
of the reminder, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Note about args:
|
||||
Either repeat_quantity and repeat_interval are given, weekdays is given or neither, but not both.
|
||||
Either repeat_quantity and repeat_interval are given, weekdays is
|
||||
given or neither, but not both.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
InvalidKeyValue: The value of one of the keys is not valid or the "Note about args" is violated
|
||||
NotificationServiceNotFound: One of the notification services was not found.
|
||||
InvalidKeyValue: The value of one of the keys is not valid or
|
||||
the "Note about args" is violated.
|
||||
|
||||
Returns:
|
||||
dict: The new reminder info
|
||||
dict: The new reminder info.
|
||||
"""
|
||||
logging.info(
|
||||
f'Updating notification service {self.id}: '
|
||||
@@ -264,8 +238,9 @@ class Reminder:
|
||||
raise InvalidKeyValue('repeat_quantity', repeat_quantity)
|
||||
elif repeat_quantity is not None and repeat_interval is None:
|
||||
raise InvalidKeyValue('repeat_interval', repeat_interval)
|
||||
elif weekdays is not None and repeat_quantity is not None and repeat_interval is not None:
|
||||
elif weekdays is not None and repeat_quantity is not None:
|
||||
raise InvalidKeyValue('weekdays', weekdays)
|
||||
|
||||
repeated_reminder = (
|
||||
(repeat_quantity is not None and repeat_interval is not None)
|
||||
or weekdays is not None
|
||||
@@ -285,26 +260,35 @@ class Reminder:
|
||||
'text': text,
|
||||
'repeat_quantity': repeat_quantity,
|
||||
'repeat_interval': repeat_interval,
|
||||
'weekdays': ",".join(map(str, sorted(weekdays))) if weekdays is not None else None,
|
||||
'weekdays':
|
||||
",".join(map(str, sorted(weekdays)))
|
||||
if weekdays is not None else
|
||||
None,
|
||||
'color': color
|
||||
}
|
||||
for k, v in new_values.items():
|
||||
if k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color') or v is not None:
|
||||
if (
|
||||
k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color')
|
||||
or v is not None
|
||||
):
|
||||
data[k] = v
|
||||
|
||||
# Update database
|
||||
if repeated_reminder:
|
||||
next_time = _find_next_time(
|
||||
data["time"],
|
||||
data["repeat_quantity"], data["repeat_interval"],
|
||||
RepeatQuantity(data["repeat_quantity"]),
|
||||
data["repeat_interval"],
|
||||
weekdays
|
||||
)
|
||||
cursor.execute("""
|
||||
UPDATE reminders
|
||||
SET
|
||||
title=?, text=?,
|
||||
title=?,
|
||||
text=?,
|
||||
time=?,
|
||||
repeat_quantity=?, repeat_interval=?,
|
||||
repeat_quantity=?,
|
||||
repeat_interval=?,
|
||||
weekdays=?,
|
||||
original_time=?,
|
||||
color=?
|
||||
@@ -326,9 +310,11 @@ class Reminder:
|
||||
cursor.execute("""
|
||||
UPDATE reminders
|
||||
SET
|
||||
title=?, text=?,
|
||||
title=?,
|
||||
text=?,
|
||||
time=?,
|
||||
repeat_quantity=?, repeat_interval=?,
|
||||
repeat_quantity=?,
|
||||
repeat_interval=?,
|
||||
weekdays=?,
|
||||
color=?
|
||||
WHERE id = ?;
|
||||
@@ -346,18 +332,29 @@ class Reminder:
|
||||
if notification_services:
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
cursor.execute("DELETE FROM reminder_services WHERE reminder_id = ?", (self.id,))
|
||||
cursor.execute(
|
||||
"DELETE FROM reminder_services WHERE reminder_id = ?",
|
||||
(self.id,)
|
||||
)
|
||||
try:
|
||||
cursor.executemany(
|
||||
"INSERT INTO reminder_services(reminder_id, notification_service_id) VALUES (?,?)",
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
reminder_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
((self.id, s) for s in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
reminder_handler.find_next_reminder(next_time)
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
ReminderHandler().find_next_reminder(next_time)
|
||||
return self.get()
|
||||
|
||||
def delete(self) -> None:
|
||||
@@ -365,74 +362,76 @@ class Reminder:
|
||||
"""
|
||||
logging.info(f'Deleting reminder {self.id}')
|
||||
get_db().execute("DELETE FROM reminders WHERE id = ?", (self.id,))
|
||||
reminder_handler.find_next_reminder()
|
||||
ReminderHandler().find_next_reminder()
|
||||
return
|
||||
|
||||
class Reminders:
|
||||
"""Represents the reminder library of the user account
|
||||
"""
|
||||
sort_functions = {
|
||||
'time': (lambda r: (r['time'], r['title'], r['text'], r['color']), False),
|
||||
'time_reversed': (lambda r: (r['time'], r['title'], r['text'], r['color']), True),
|
||||
'title': (lambda r: (r['title'], r['time'], r['text'], r['color']), False),
|
||||
'title_reversed': (lambda r: (r['title'], r['time'], r['text'], r['color']), True),
|
||||
'date_added': (lambda r: r['id'], False),
|
||||
'date_added_reversed': (lambda r: r['id'], True)
|
||||
}
|
||||
|
||||
def __init__(self, user_id: int):
|
||||
self.user_id = user_id
|
||||
"""
|
||||
|
||||
def fetchall(self, sort_by: Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"] = "time") -> List[dict]:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
return
|
||||
|
||||
def fetchall(
|
||||
self,
|
||||
sort_by: SortingMethod = SortingMethod.TIME
|
||||
) -> List[dict]:
|
||||
"""Get all reminders
|
||||
|
||||
Args:
|
||||
sort_by (Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "time".
|
||||
sort_by (SortingMethod, optional): How to sort the result.
|
||||
Defaults to SortingMethod.TIME.
|
||||
|
||||
Returns:
|
||||
List[dict]: The id, title, text, time and color of each reminder
|
||||
"""
|
||||
sort_function = self.sort_functions.get(
|
||||
sort_by,
|
||||
self.sort_functions['time']
|
||||
)
|
||||
|
||||
# Fetch all reminders
|
||||
reminders: list = list(map(dict, get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity,
|
||||
repeat_interval,
|
||||
weekdays,
|
||||
color
|
||||
FROM reminders
|
||||
WHERE user_id = ?;
|
||||
""",
|
||||
(self.user_id,)
|
||||
)))
|
||||
reminders = [
|
||||
dict(r)
|
||||
for r in get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity,
|
||||
repeat_interval,
|
||||
weekdays,
|
||||
color
|
||||
FROM reminders
|
||||
WHERE user_id = ?;
|
||||
""",
|
||||
(self.user_id,)
|
||||
)
|
||||
]
|
||||
|
||||
# Sort result
|
||||
reminders.sort(key=sort_function[0], reverse=sort_function[1])
|
||||
reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1])
|
||||
|
||||
return reminders
|
||||
|
||||
def search(self, query: str, sort_by: Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"] = "time") -> List[dict]:
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
sort_by: SortingMethod = SortingMethod.TIME) -> List[dict]:
|
||||
"""Search for reminders
|
||||
|
||||
Args:
|
||||
query (str): The term to search for
|
||||
sort_by (Literal["time", "time_reversed", "title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "time".
|
||||
query (str): The term to search for.
|
||||
sort_by (SortingMethod, optional): How to sort the result.
|
||||
Defaults to SortingMethod.TIME.
|
||||
|
||||
Returns:
|
||||
List[dict]: All reminders that match. Similar output to self.fetchall
|
||||
"""
|
||||
query = query.lower()
|
||||
reminders = list(filter(
|
||||
lambda p: filter_function(query, p),
|
||||
self.fetchall(sort_by)
|
||||
))
|
||||
"""
|
||||
reminders = [
|
||||
r for r in self.fetchall(sort_by)
|
||||
if search_filter(query, r)
|
||||
]
|
||||
return reminders
|
||||
|
||||
def fetchone(self, id: int) -> Reminder:
|
||||
@@ -452,32 +451,51 @@ class Reminders:
|
||||
time: int,
|
||||
notification_services: List[int],
|
||||
text: str = '',
|
||||
repeat_quantity: Literal["years", "months", "weeks", "days", "hours", "minutes"] = None,
|
||||
repeat_interval: int = None,
|
||||
weekdays: List[int] = None,
|
||||
color: str = None
|
||||
repeat_quantity: Union[None, RepeatQuantity] = None,
|
||||
repeat_interval: Union[None, int] = None,
|
||||
weekdays: Union[None, List[int]] = None,
|
||||
color: Union[None, str] = None
|
||||
) -> Reminder:
|
||||
"""Add a reminder
|
||||
|
||||
Args:
|
||||
title (str): The title of the entry
|
||||
title (str): The title of the entry.
|
||||
|
||||
time (int): The UTC epoch timestamp the the reminder should be send.
|
||||
notification_services (List[int]): The id's of the notification services to use to send the reminder.
|
||||
text (str, optional): The body of the reminder. Defaults to ''.
|
||||
repeat_quantity (Literal["years", "months", "weeks", "days", "hours", "minutes"], optional): The quantity of the repeat specified for the reminder. Defaults to None.
|
||||
repeat_interval (int, optional): The amount of repeat_quantity, like "5" (hours). Defaults to None.
|
||||
weekdays (List[int], optional): The indexes of the days of the week that the reminder should run. Defaults to None.
|
||||
color (str, optional): The hex code of the color of the reminder, which is shown in the web-ui. Defaults to None.
|
||||
|
||||
notification_services (List[int]): The id's of the notification services
|
||||
to use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
repeat_quantity (Union[None, RepeatQuantity], optional): The quantity
|
||||
of the repeat specified for the reminder.
|
||||
Defaults to None.
|
||||
|
||||
repeat_interval (Union[None, int], optional): The amount of repeat_quantity,
|
||||
like "5" (hours).
|
||||
Defaults to None.
|
||||
|
||||
weekdays (Union[None, List[int]], optional): The indexes of the days
|
||||
of the week that the reminder should run.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[None, str], optional): The hex code of the color of the
|
||||
reminder, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Note about args:
|
||||
Either repeat_quantity and repeat_interval are given, weekdays is given or neither, but not both.
|
||||
Either repeat_quantity and repeat_interval are given,
|
||||
weekdays is given or neither, but not both.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
InvalidKeyValue: The value of one of the keys is not valid or the "Note about args" is violated
|
||||
NotificationServiceNotFound: One of the notification services was not found.
|
||||
InvalidKeyValue: The value of one of the keys is not valid
|
||||
or the "Note about args" is violated.
|
||||
|
||||
Returns:
|
||||
dict: The info about the reminder
|
||||
dict: The info about the reminder.
|
||||
"""
|
||||
logging.info(
|
||||
f'Adding reminder with {title=}, {time=}, {notification_services=}, '
|
||||
@@ -492,50 +510,89 @@ class Reminders:
|
||||
raise InvalidKeyValue('repeat_quantity', repeat_quantity)
|
||||
elif repeat_quantity is not None and repeat_interval is None:
|
||||
raise InvalidKeyValue('repeat_interval', repeat_interval)
|
||||
elif weekdays is not None and repeat_quantity is not None and repeat_interval is not None:
|
||||
elif (
|
||||
weekdays is not None
|
||||
and repeat_quantity is not None
|
||||
and repeat_interval is not None
|
||||
):
|
||||
raise InvalidKeyValue('weekdays', weekdays)
|
||||
|
||||
cursor = get_db()
|
||||
for service in notification_services:
|
||||
if not cursor.execute(
|
||||
"SELECT 1 FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;",
|
||||
if not cursor.execute("""
|
||||
SELECT 1
|
||||
FROM notification_services
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(service, self.user_id)
|
||||
).fetchone():
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
if repeat_quantity is not None and repeat_interval is not None:
|
||||
id = cursor.execute("""
|
||||
INSERT INTO reminders(user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""", (self.user_id, title, text, time, repeat_quantity, repeat_interval, time, color)
|
||||
).lastrowid
|
||||
|
||||
elif weekdays is not None:
|
||||
weekdays = ",".join(map(str, sorted(weekdays)))
|
||||
id = cursor.execute("""
|
||||
INSERT INTO reminders(user_id, title, text, time, weekdays, original_time, color)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?);
|
||||
""", (self.user_id, title, text, time, weekdays, time, color)
|
||||
).lastrowid
|
||||
|
||||
# Prepare args
|
||||
if any((repeat_quantity, weekdays)):
|
||||
original_time = time
|
||||
time = _find_next_time(
|
||||
original_time,
|
||||
repeat_quantity,
|
||||
repeat_interval,
|
||||
weekdays
|
||||
)
|
||||
else:
|
||||
id = cursor.execute("""
|
||||
INSERT INTO reminders(user_id, title, text, time, color)
|
||||
VALUES (?, ?, ?, ?, ?);
|
||||
""", (self.user_id, title, text, time, color)
|
||||
).lastrowid
|
||||
|
||||
original_time = None
|
||||
|
||||
if weekdays is not None:
|
||||
weekdays = ",".join(map(str, sorted(weekdays)))
|
||||
|
||||
if repeat_quantity is not None:
|
||||
repeat_quantity = repeat_quantity.value
|
||||
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
|
||||
id = cursor.execute("""
|
||||
INSERT INTO reminders(
|
||||
user_id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays,
|
||||
original_time,
|
||||
color
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""", (
|
||||
self.user_id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity,
|
||||
repeat_interval,
|
||||
weekdays,
|
||||
original_time,
|
||||
color
|
||||
)).lastrowid
|
||||
|
||||
try:
|
||||
cursor.executemany(
|
||||
"INSERT INTO reminder_services(reminder_id, notification_service_id) VALUES (?, ?);",
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
reminder_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
((id, service) for service in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
reminder_handler.find_next_reminder(time)
|
||||
|
||||
# Return info
|
||||
finally:
|
||||
cursor.connection.isolation_level = ''
|
||||
|
||||
ReminderHandler().find_next_reminder(time)
|
||||
|
||||
return self.fetchone(id)
|
||||
|
||||
def test_reminder(
|
||||
@@ -544,23 +601,164 @@ class Reminders:
|
||||
notification_services: List[int],
|
||||
text: str = ''
|
||||
) -> None:
|
||||
"""Test send a reminder draft
|
||||
"""Test send a reminder draft.
|
||||
|
||||
Args:
|
||||
title (str): Title title of the entry
|
||||
notification_service (int): The id of the notification service to use to send the reminder
|
||||
text (str, optional): The body of the reminder. Defaults to ''.
|
||||
title (str): Title title of the entry.
|
||||
|
||||
notification_service (int): The id of the notification service to
|
||||
use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
"""
|
||||
logging.info(f'Testing reminder with {title=}, {notification_services=}, {text=}')
|
||||
a = Apprise()
|
||||
cursor = get_db(dict)
|
||||
|
||||
for service in notification_services:
|
||||
url = cursor.execute(
|
||||
"SELECT url FROM notification_services WHERE id = ? AND user_id = ? LIMIT 1;",
|
||||
url = cursor.execute("""
|
||||
SELECT url
|
||||
FROM notification_services
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(service, self.user_id)
|
||||
).fetchone()
|
||||
if not url:
|
||||
raise NotificationServiceNotFound
|
||||
a.add(url[0])
|
||||
|
||||
a.notify(title=title, body=text)
|
||||
return
|
||||
|
||||
|
||||
class ReminderHandler(metaclass=Singleton):
|
||||
"""Handle set reminders.
|
||||
|
||||
Note: Singleton.
|
||||
"""
|
||||
def __init__(self, context) -> None:
|
||||
"""Create instance of handler.
|
||||
|
||||
Args:
|
||||
context (AppContext): `Flask.app_context`
|
||||
"""
|
||||
self.context = context
|
||||
self.thread: Union[Timer, None] = None
|
||||
self.time: Union[int, None] = None
|
||||
return
|
||||
|
||||
def __trigger_reminders(self, time: int) -> None:
|
||||
"""Trigger all reminders that are set for a certain time
|
||||
|
||||
Args:
|
||||
time (int): The time of the reminders to trigger
|
||||
"""
|
||||
with self.context():
|
||||
cursor = get_db(dict)
|
||||
reminders = [
|
||||
dict(r)
|
||||
for r in cursor.execute("""
|
||||
SELECT
|
||||
id, user_id,
|
||||
title, text,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays,
|
||||
original_time
|
||||
FROM reminders
|
||||
WHERE time = ?;
|
||||
""",
|
||||
(time,)
|
||||
)
|
||||
]
|
||||
|
||||
for reminder in reminders:
|
||||
cursor.execute("""
|
||||
SELECT url
|
||||
FROM reminder_services rs
|
||||
INNER JOIN notification_services ns
|
||||
ON rs.notification_service_id = ns.id
|
||||
WHERE rs.reminder_id = ?;
|
||||
""",
|
||||
(reminder['id'],)
|
||||
)
|
||||
|
||||
# Send reminder
|
||||
a = Apprise()
|
||||
for url in cursor:
|
||||
a.add(url['url'])
|
||||
a.notify(title=reminder["title"], body=reminder["text"])
|
||||
|
||||
self.thread = None
|
||||
self.time = None
|
||||
|
||||
if (reminder['repeat_quantity'], reminder['weekdays']) == (None, None):
|
||||
# Delete the reminder from the database
|
||||
Reminder(reminder["user_id"], reminder["id"]).delete()
|
||||
|
||||
else:
|
||||
# Set next time
|
||||
new_time = _find_next_time(
|
||||
reminder['original_time'],
|
||||
RepeatQuantity(reminder['repeat_quantity']),
|
||||
reminder['repeat_interval'],
|
||||
[int(d) for d in reminder['weekdays'].split(',')]
|
||||
if reminder['weekdays'] is not None else
|
||||
None
|
||||
)
|
||||
cursor.execute(
|
||||
"UPDATE reminders SET time = ? WHERE id = ?;",
|
||||
(new_time, reminder['id'])
|
||||
)
|
||||
|
||||
self.find_next_reminder()
|
||||
return
|
||||
|
||||
def find_next_reminder(self, time: int=None) -> None:
|
||||
"""Determine when the soonest reminder is and set the timer to that time
|
||||
|
||||
Args:
|
||||
time (int, optional): The timestamp to check for.
|
||||
Otherwise check soonest in database.
|
||||
Defaults to None.
|
||||
"""
|
||||
if not time:
|
||||
with self.context():
|
||||
time = get_db().execute("""
|
||||
SELECT DISTINCT r1.time
|
||||
FROM reminders r1
|
||||
LEFT JOIN reminders r2
|
||||
ON r1.time > r2.time
|
||||
WHERE r2.id IS NULL;
|
||||
""").fetchone()
|
||||
if time is None:
|
||||
return
|
||||
time = time[0]
|
||||
|
||||
if (
|
||||
self.thread is None
|
||||
or time < self.time
|
||||
):
|
||||
if self.thread is not None:
|
||||
self.thread.cancel()
|
||||
|
||||
t = time - datetime.utcnow().timestamp()
|
||||
self.thread = Timer(
|
||||
t,
|
||||
self.__trigger_reminders,
|
||||
(time,)
|
||||
)
|
||||
self.thread.name = "ReminderHandler"
|
||||
self.thread.start()
|
||||
self.time = time
|
||||
|
||||
return
|
||||
|
||||
def stop_handling(self) -> None:
|
||||
"""Stop the timer if it's active
|
||||
"""
|
||||
if self.thread is not None:
|
||||
self.thread.cancel()
|
||||
return
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Hashing and salting
|
||||
"""
|
||||
|
||||
from base64 import urlsafe_b64encode
|
||||
from hashlib import pbkdf2_hmac
|
||||
from secrets import token_bytes
|
||||
@@ -29,7 +33,6 @@ def generate_salt_hash(password: str) -> Tuple[bytes, bytes]:
|
||||
Returns:
|
||||
Tuple[bytes, bytes]: The salt (1) and hashed_password (2)
|
||||
"""
|
||||
# Hash the password
|
||||
salt = token_bytes()
|
||||
hashed_password = get_hash(salt, password)
|
||||
del password
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Getting and setting settings
|
||||
"""
|
||||
|
||||
from backend.custom_exceptions import InvalidKeyValue, KeyNotFound
|
||||
from backend.db import __DATABASE_VERSION__, get_db
|
||||
|
||||
@@ -85,7 +89,7 @@ def get_admin_settings() -> dict:
|
||||
"""
|
||||
return dict((
|
||||
(key, _reverse_format_setting(key, value))
|
||||
for (key, value) in get_db().execute("""
|
||||
for key, value in get_db().execute("""
|
||||
SELECT key, value
|
||||
FROM config
|
||||
WHERE
|
||||
|
||||
@@ -2,23 +2,30 @@
|
||||
|
||||
import logging
|
||||
from sqlite3 import IntegrityError
|
||||
from typing import List, Literal
|
||||
from typing import List, Union
|
||||
|
||||
from apprise import Apprise
|
||||
|
||||
from backend.custom_exceptions import (NotificationServiceNotFound,
|
||||
ReminderNotFound)
|
||||
from backend.db import get_db
|
||||
from backend.helpers import TimelessSortingMethod, search_filter
|
||||
|
||||
filter_function = lambda query, p: (
|
||||
query in p["title"].lower()
|
||||
or query in p["text"].lower()
|
||||
)
|
||||
|
||||
class StaticReminder:
|
||||
"""Represents a static reminder
|
||||
"""
|
||||
def __init__(self, user_id: int, reminder_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
reminder_id (int): The ID of the reminder.
|
||||
|
||||
Raises:
|
||||
ReminderNotFound: Reminder with given ID does not exist or is not
|
||||
owned by user.
|
||||
"""
|
||||
self.id = reminder_id
|
||||
|
||||
# Check if reminder exists
|
||||
@@ -27,12 +34,32 @@ class StaticReminder:
|
||||
(self.id, user_id)
|
||||
).fetchone():
|
||||
raise ReminderNotFound
|
||||
|
||||
|
||||
return
|
||||
|
||||
def _get_notification_services(self) -> List[int]:
|
||||
"""Get ID's of notification services linked to the static reminder.
|
||||
|
||||
Returns:
|
||||
List[int]: The list with ID's.
|
||||
"""
|
||||
result = [
|
||||
r[0]
|
||||
for r in get_db().execute("""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE static_reminder_id = ?;
|
||||
""",
|
||||
(self.id,)
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
def get(self) -> dict:
|
||||
"""Get info about the static reminder
|
||||
|
||||
Returns:
|
||||
dict: The info about the reminder
|
||||
dict: The info about the static reminder
|
||||
"""
|
||||
reminder = get_db(dict).execute("""
|
||||
SELECT
|
||||
@@ -47,28 +74,33 @@ class StaticReminder:
|
||||
).fetchone()
|
||||
reminder = dict(reminder)
|
||||
|
||||
reminder['notification_services'] = list(map(lambda r: r[0], get_db().execute("""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE static_reminder_id = ?;
|
||||
""", (self.id,))))
|
||||
reminder['notification_services'] = self._get_notification_services()
|
||||
|
||||
return reminder
|
||||
|
||||
def update(
|
||||
self,
|
||||
title: str = None,
|
||||
notification_services: List[int] = None,
|
||||
text: str = None,
|
||||
color: str = None
|
||||
title: Union[str, None] = None,
|
||||
notification_services: Union[List[int], None] = None,
|
||||
text: Union[str, None] = None,
|
||||
color: Union[str, None] = None
|
||||
) -> dict:
|
||||
"""Edit the static reminder
|
||||
"""Edit the static reminder.
|
||||
|
||||
Args:
|
||||
title (str, optional): The new title of the entry. Defaults to None.
|
||||
notification_services (List[int], optional): The new id's of the notification services to use to send the reminder. Defaults to None.
|
||||
text (str, optional): The new body of the reminder. Defaults to None.
|
||||
color (str, optional): The new hex code of the color of the reminder, which is shown in the web-ui. Defaults to None.
|
||||
title (Union[str, None], optional): The new title of the entry.
|
||||
Defaults to None.
|
||||
|
||||
notification_services (Union[List[int], None], optional):
|
||||
The new id's of the notification services to use to send the reminder.
|
||||
Defaults to None.
|
||||
|
||||
text (Union[str, None], optional): The new body of the reminder.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[str, None], optional): The new hex code of the color
|
||||
of the reminder, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
@@ -109,16 +141,27 @@ class StaticReminder:
|
||||
if notification_services:
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
cursor.execute("DELETE FROM reminder_services WHERE static_reminder_id = ?", (self.id,))
|
||||
cursor.execute(
|
||||
"DELETE FROM reminder_services WHERE static_reminder_id = ?",
|
||||
(self.id,)
|
||||
)
|
||||
try:
|
||||
cursor.executemany(
|
||||
"INSERT INTO reminder_services(static_reminder_id, notification_service_id) VALUES (?,?)",
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
static_reminder_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
((self.id, s) for s in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
return self.get()
|
||||
|
||||
@@ -132,33 +175,32 @@ class StaticReminder:
|
||||
class StaticReminders:
|
||||
"""Represents the static reminder library of the user account
|
||||
"""
|
||||
sort_functions = {
|
||||
'title': (lambda r: (r['title'], r['text'], r['color']), False),
|
||||
'title_reversed': (lambda r: (r['title'], r['text'], r['color']), True),
|
||||
'date_added': (lambda r: r['id'], False),
|
||||
'date_added_reversed': (lambda r: r['id'], True)
|
||||
}
|
||||
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
return
|
||||
|
||||
def fetchall(self, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]:
|
||||
def fetchall(
|
||||
self,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[dict]:
|
||||
"""Get all static reminders
|
||||
|
||||
Args:
|
||||
sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title".
|
||||
sort_by (TimelessSortingMethod, optional): How to sort the result.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[dict]: The id, title, text and color of each static reminder
|
||||
List[dict]: The id, title, text and color of each static reminder.
|
||||
"""
|
||||
sort_function = self.sort_functions.get(
|
||||
sort_by,
|
||||
self.sort_functions['title']
|
||||
)
|
||||
|
||||
reminders: list = list(map(
|
||||
dict,
|
||||
get_db(dict).execute("""
|
||||
reminders = [
|
||||
dict(r)
|
||||
for r in get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
@@ -169,29 +211,36 @@ class StaticReminders:
|
||||
""",
|
||||
(self.user_id,)
|
||||
)
|
||||
))
|
||||
|
||||
]
|
||||
|
||||
# Sort result
|
||||
reminders.sort(key=sort_function[0], reverse=sort_function[1])
|
||||
reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1])
|
||||
|
||||
return reminders
|
||||
|
||||
def search(self, query: str, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]:
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[dict]:
|
||||
"""Search for static reminders
|
||||
|
||||
Args:
|
||||
query (str): The term to search for
|
||||
sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title".
|
||||
query (str): The term to search for.
|
||||
|
||||
sort_by (TimelessSortingMethod, optional): The sorting method of
|
||||
the resulting list.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[dict]: All static reminders that match. Similar output to self.fetchall
|
||||
"""
|
||||
query = query.lower()
|
||||
reminders = list(filter(
|
||||
lambda p: filter_function(query, p),
|
||||
self.fetchall(sort_by)
|
||||
))
|
||||
return reminders
|
||||
List[dict]: All static reminders that match.
|
||||
Similar output to `self.fetchall`
|
||||
"""
|
||||
static_reminders = [
|
||||
r for r in self.fetchall(sort_by)
|
||||
if search_filter(query, r)
|
||||
]
|
||||
return static_reminders
|
||||
|
||||
def fetchone(self, id: int) -> StaticReminder:
|
||||
"""Get one static reminder
|
||||
@@ -214,22 +263,32 @@ class StaticReminders:
|
||||
"""Add a static reminder
|
||||
|
||||
Args:
|
||||
title (str): The title of the entry
|
||||
notification_services (List[int]): The id's of the notification services to use to send the reminder.
|
||||
text (str, optional): The body of the reminder. Defaults to ''.
|
||||
color (str, optional): The hex code of the color of the reminder, which is shown in the web-ui. Defaults to None.
|
||||
title (str): The title of the entry.
|
||||
|
||||
notification_services (List[int]): The id's of the
|
||||
notification services to use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
color (str, optional): The hex code of the color of the template,
|
||||
which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
|
||||
Returns:
|
||||
StaticReminder: A StaticReminder instance representing the newly created static reminder
|
||||
StaticReminder: The info about the static reminder
|
||||
"""
|
||||
logging.info(
|
||||
f'Adding static reminder with {title=}, {notification_services=}, {text=}, {color=}'
|
||||
)
|
||||
|
||||
cursor = get_db()
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
|
||||
id = cursor.execute("""
|
||||
INSERT INTO static_reminders(user_id, title, text, color)
|
||||
VALUES (?,?,?,?);
|
||||
@@ -238,13 +297,22 @@ class StaticReminders:
|
||||
).lastrowid
|
||||
|
||||
try:
|
||||
cursor.executemany(
|
||||
"INSERT INTO reminder_services(static_reminder_id, notification_service_id) VALUES (?, ?);",
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
static_reminder_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
((id, service) for service in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
return self.fetchone(id)
|
||||
|
||||
def trigger_reminder(self, id: int) -> None:
|
||||
@@ -265,7 +333,9 @@ class StaticReminders:
|
||||
id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""", (id, self.user_id)).fetchone()
|
||||
""",
|
||||
(id, self.user_id)
|
||||
).fetchone()
|
||||
if not reminder:
|
||||
raise ReminderNotFound
|
||||
reminder = dict(reminder)
|
||||
@@ -277,7 +347,9 @@ class StaticReminders:
|
||||
INNER JOIN notification_services ns
|
||||
ON rs.notification_service_id = ns.id
|
||||
WHERE rs.static_reminder_id = ?;
|
||||
""", (id,))
|
||||
""",
|
||||
(id,)
|
||||
)
|
||||
for url in cursor:
|
||||
a.add(url['url'])
|
||||
a.notify(title=reminder['title'], body=reminder['text'])
|
||||
|
||||
@@ -2,21 +2,28 @@
|
||||
|
||||
import logging
|
||||
from sqlite3 import IntegrityError
|
||||
from typing import List, Literal
|
||||
from typing import List, Union
|
||||
|
||||
from backend.custom_exceptions import (NotificationServiceNotFound,
|
||||
TemplateNotFound)
|
||||
from backend.db import get_db
|
||||
from backend.helpers import TimelessSortingMethod, search_filter
|
||||
|
||||
filter_function = lambda query, p: (
|
||||
query in p["title"].lower()
|
||||
or query in p["text"].lower()
|
||||
)
|
||||
|
||||
class Template:
|
||||
"""Represents a template
|
||||
"""
|
||||
def __init__(self, user_id: int, template_id: int):
|
||||
def __init__(self, user_id: int, template_id: int) -> None:
|
||||
"""Create instance of class.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
template_id (int): The ID of the template.
|
||||
|
||||
Raises:
|
||||
TemplateNotFound: Template with given ID does not exist or is not
|
||||
owned by user.
|
||||
"""
|
||||
self.id = template_id
|
||||
|
||||
exists = get_db().execute(
|
||||
@@ -25,7 +32,26 @@ class Template:
|
||||
).fetchone()
|
||||
if not exists:
|
||||
raise TemplateNotFound
|
||||
|
||||
return
|
||||
|
||||
def _get_notification_services(self) -> List[int]:
|
||||
"""Get ID's of notification services linked to the template.
|
||||
|
||||
Returns:
|
||||
List[int]: The list with ID's.
|
||||
"""
|
||||
result = [
|
||||
r[0]
|
||||
for r in get_db().execute("""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE template_id = ?;
|
||||
""",
|
||||
(self.id,)
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
def get(self) -> dict:
|
||||
"""Get info about the template
|
||||
|
||||
@@ -45,27 +71,35 @@ class Template:
|
||||
).fetchone()
|
||||
template = dict(template)
|
||||
|
||||
template['notification_services'] = list(map(lambda r: r[0], get_db().execute("""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE template_id = ?;
|
||||
""", (self.id,))))
|
||||
template['notification_services'] = self._get_notification_services()
|
||||
|
||||
return template
|
||||
|
||||
def update(self,
|
||||
title: str = None,
|
||||
notification_services: List[int] = None,
|
||||
text: str = None,
|
||||
color: str = None
|
||||
title: Union[str, None] = None,
|
||||
notification_services: Union[List[int], None] = None,
|
||||
text: Union[str, None] = None,
|
||||
color: Union[str, None] = None
|
||||
) -> dict:
|
||||
"""Edit the template
|
||||
|
||||
Args:
|
||||
title (str): The new title of the entry. Defaults to None.
|
||||
notification_services (List[int]): The new id's of the notification services to use to send the reminder. Defaults to None.
|
||||
text (str, optional): The new body of the template. Defaults to None.
|
||||
color (str, optional): The new hex code of the color of the template, which is shown in the web-ui. Defaults to None.
|
||||
title (Union[str, None]): The new title of the entry.
|
||||
Defaults to None.
|
||||
|
||||
notification_services (Union[List[int], None]): The new id's of the
|
||||
notification services to use to send the reminder.
|
||||
Defaults to None.
|
||||
|
||||
text (Union[str, None], optional): The new body of the template.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[str, None], optional): The new hex code of the color of the template,
|
||||
which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
|
||||
Returns:
|
||||
dict: The new template info
|
||||
@@ -101,16 +135,27 @@ class Template:
|
||||
if notification_services:
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
cursor.execute("DELETE FROM reminder_services WHERE template_id = ?", (self.id,))
|
||||
cursor.execute(
|
||||
"DELETE FROM reminder_services WHERE template_id = ?",
|
||||
(self.id,)
|
||||
)
|
||||
try:
|
||||
cursor.executemany(
|
||||
"INSERT INTO reminder_services(template_id, notification_service_id) VALUES (?,?)",
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
template_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
((self.id, s) for s in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
return self.get()
|
||||
|
||||
@@ -124,63 +169,72 @@ class Template:
|
||||
class Templates:
|
||||
"""Represents the template library of the user account
|
||||
"""
|
||||
sort_functions = {
|
||||
'title': (lambda r: (r['title'], r['text'], r['color']), False),
|
||||
'title_reversed': (lambda r: (r['title'], r['text'], r['color']), True),
|
||||
'date_added': (lambda r: r['id'], False),
|
||||
'date_added_reversed': (lambda r: r['id'], True)
|
||||
}
|
||||
|
||||
def __init__(self, user_id: int):
|
||||
self.user_id = user_id
|
||||
|
||||
def fetchall(self, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]:
|
||||
"""Get all templates
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title".
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
return
|
||||
|
||||
def fetchall(
|
||||
self,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[dict]:
|
||||
"""Get all templates of the user.
|
||||
|
||||
Args:
|
||||
sort_by (TimelessSortingMethod, optional): The sorting method of
|
||||
the resulting list.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[dict]: The id, title, text and color
|
||||
List[dict]: The id, title, text and color of each template.
|
||||
"""
|
||||
sort_function = self.sort_functions.get(
|
||||
sort_by,
|
||||
self.sort_functions['title']
|
||||
)
|
||||
|
||||
templates: list = list(map(dict, get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
color
|
||||
FROM templates
|
||||
WHERE user_id = ?
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
(self.user_id,)
|
||||
)))
|
||||
templates = [
|
||||
dict(r)
|
||||
for r in get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
color
|
||||
FROM templates
|
||||
WHERE user_id = ?
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
(self.user_id,)
|
||||
)
|
||||
]
|
||||
|
||||
# Sort result
|
||||
templates.sort(key=sort_function[0], reverse=sort_function[1])
|
||||
templates.sort(key=sort_by.value[0], reverse=sort_by.value[1])
|
||||
|
||||
return templates
|
||||
|
||||
def search(self, query: str, sort_by: Literal["title", "title_reversed", "date_added", "date_added_reversed"] = "title") -> List[dict]:
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[dict]:
|
||||
"""Search for templates
|
||||
|
||||
Args:
|
||||
query (str): The term to search for
|
||||
sort_by (Literal["title", "title_reversed", "date_added", "date_added_reversed"], optional): How to sort the result. Defaults to "title".
|
||||
query (str): The term to search for.
|
||||
|
||||
sort_by (TimelessSortingMethod, optional): The sorting method of
|
||||
the resulting list.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[dict]: All templates that match. Similar output to self.fetchall
|
||||
List[dict]: All templates that match. Similar output to `self.fetchall`
|
||||
"""
|
||||
query = query.lower()
|
||||
reminders = list(filter(
|
||||
lambda p: filter_function(query, p),
|
||||
self.fetchall(sort_by)
|
||||
))
|
||||
return reminders
|
||||
templates = [
|
||||
r for r in self.fetchall(sort_by)
|
||||
if search_filter(query, r)
|
||||
]
|
||||
return templates
|
||||
|
||||
def fetchone(self, id: int) -> Template:
|
||||
"""Get one template
|
||||
@@ -203,10 +257,20 @@ class Templates:
|
||||
"""Add a template
|
||||
|
||||
Args:
|
||||
title (str): The title of the entry
|
||||
notification_services (List[int]): The id's of the notification services to use to send the reminder.
|
||||
text (str, optional): The body of the reminder. Defaults to ''.
|
||||
color (str, optional): The hex code of the color of the template, which is shown in the web-ui. Defaults to None.
|
||||
title (str): The title of the entry.
|
||||
|
||||
notification_services (List[int]): The id's of the
|
||||
notification services to use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
color (str, optional): The hex code of the color of the template,
|
||||
which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
|
||||
Returns:
|
||||
Template: The info about the template
|
||||
@@ -216,19 +280,32 @@ class Templates:
|
||||
)
|
||||
|
||||
cursor = get_db()
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
|
||||
id = cursor.execute("""
|
||||
INSERT INTO templates(user_id, title, text, color)
|
||||
VALUES (?,?,?,?);
|
||||
""",
|
||||
(self.user_id, title, text, color)
|
||||
).lastrowid
|
||||
|
||||
|
||||
try:
|
||||
cursor.executemany(
|
||||
"INSERT INTO reminder_services(template_id, notification_service_id) VALUES (?, ?);",
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
template_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
((id, service) for service in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
return self.fetchone(id)
|
||||
|
||||
257
backend/users.py
257
backend/users.py
@@ -1,7 +1,7 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from typing import List, Union
|
||||
from typing import List
|
||||
|
||||
from backend.custom_exceptions import (AccessUnauthorized,
|
||||
NewAccountsNotAllowed, UsernameInvalid,
|
||||
@@ -19,32 +19,28 @@ ONEPASS_INVALID_USERNAMES = ['reminders', 'api']
|
||||
|
||||
class User:
|
||||
"""Represents an user account
|
||||
"""
|
||||
def __init__(self, username: str, password: Union[str, None]=None):
|
||||
# Fetch data of user to check if user exists and to check if password is correct
|
||||
"""
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
result = get_db(dict).execute(
|
||||
"SELECT id, salt, hash, admin FROM users WHERE username = ? LIMIT 1;",
|
||||
(username,)
|
||||
"SELECT username, admin FROM users WHERE id = ? LIMIT 1;",
|
||||
(id,)
|
||||
).fetchone()
|
||||
if not result:
|
||||
raise UserNotFound
|
||||
self.username = username
|
||||
self.salt = result['salt']
|
||||
self.user_id = result['id']
|
||||
self.admin = result['admin'] == 1
|
||||
|
||||
# Check password
|
||||
if password is not None:
|
||||
hash_password = get_hash(result['salt'], password)
|
||||
if not hash_password == result['hash']:
|
||||
raise AccessUnauthorized
|
||||
|
||||
self.username: str = result['username']
|
||||
self.user_id = id
|
||||
self.admin: bool = result['admin'] == 1
|
||||
return
|
||||
|
||||
@property
|
||||
def reminders(self) -> Reminders:
|
||||
"""Get access to the reminders of the user account
|
||||
|
||||
Returns:
|
||||
Reminders: Reminders instance that can be used to access the reminders of the user account
|
||||
Reminders: Reminders instance that can be used to access the
|
||||
reminders of the user account
|
||||
"""
|
||||
if not hasattr(self, 'reminders_instance'):
|
||||
self.reminders_instance = Reminders(self.user_id)
|
||||
@@ -55,7 +51,8 @@ class User:
|
||||
"""Get access to the notification services of the user account
|
||||
|
||||
Returns:
|
||||
NotificationServices: NotificationServices instance that can be used to access the notification services of the user account
|
||||
NotificationServices: NotificationServices instance that can be used
|
||||
to access the notification services of the user account
|
||||
"""
|
||||
if not hasattr(self, 'notification_services_instance'):
|
||||
self.notification_services_instance = NotificationServices(self.user_id)
|
||||
@@ -66,7 +63,8 @@ class User:
|
||||
"""Get access to the templates of the user account
|
||||
|
||||
Returns:
|
||||
Templates: Templates instance that can be used to access the templates of the user account
|
||||
Templates: Templates instance that can be used to access the
|
||||
templates of the user account
|
||||
"""
|
||||
if not hasattr(self, 'templates_instance'):
|
||||
self.templates_instance = Templates(self.user_id)
|
||||
@@ -77,7 +75,8 @@ class User:
|
||||
"""Get access to the static reminders of the user account
|
||||
|
||||
Returns:
|
||||
StaticReminders: StaticReminders instance that can be used to access the static reminders of the user account
|
||||
StaticReminders: StaticReminders instance that can be used to
|
||||
access the static reminders of the user account
|
||||
"""
|
||||
if not hasattr(self, 'static_reminders_instance'):
|
||||
self.static_reminders_instance = StaticReminders(self.user_id)
|
||||
@@ -103,121 +102,133 @@ class User:
|
||||
def delete(self) -> None:
|
||||
"""Delete the user account
|
||||
"""
|
||||
if self.username == 'admin':
|
||||
raise UserNotFound
|
||||
|
||||
logging.info(f'Deleting the user {self.username} ({self.user_id})')
|
||||
|
||||
cursor = get_db()
|
||||
cursor.execute("DELETE FROM reminders WHERE user_id = ?", (self.user_id,))
|
||||
cursor.execute("DELETE FROM templates WHERE user_id = ?", (self.user_id,))
|
||||
cursor.execute("DELETE FROM static_reminders WHERE user_id = ?", (self.user_id,))
|
||||
cursor.execute("DELETE FROM notification_services WHERE user_id = ?", (self.user_id,))
|
||||
cursor.execute("DELETE FROM users WHERE id = ?", (self.user_id,))
|
||||
cursor.execute(
|
||||
"DELETE FROM reminders WHERE user_id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
cursor.execute(
|
||||
"DELETE FROM templates WHERE user_id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
cursor.execute(
|
||||
"DELETE FROM static_reminders WHERE user_id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
cursor.execute(
|
||||
"DELETE FROM notification_services WHERE user_id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
cursor.execute(
|
||||
"DELETE FROM users WHERE id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
return
|
||||
|
||||
def _check_username(username: str) -> None:
|
||||
"""Check if username is valid
|
||||
class Users:
|
||||
def _check_username(self, username: str) -> None:
|
||||
"""Check if username is valid
|
||||
|
||||
Args:
|
||||
username (str): The username to check
|
||||
Args:
|
||||
username (str): The username to check
|
||||
|
||||
Raises:
|
||||
UsernameInvalid: The username is not valid
|
||||
"""
|
||||
logging.debug(f'Checking the username {username}')
|
||||
if username in ONEPASS_INVALID_USERNAMES or username.isdigit():
|
||||
raise UsernameInvalid
|
||||
if list(filter(lambda c: not c in ONEPASS_USERNAME_CHARACTERS, username)):
|
||||
raise UsernameInvalid
|
||||
return
|
||||
|
||||
def register_user(username: str, password: str, from_admin: bool=False) -> int:
|
||||
"""Add a user
|
||||
|
||||
Args:
|
||||
username (str): The username of the new user
|
||||
password (str): The password of the new user
|
||||
from_admin (bool, optional): Skip check if new accounts are allowed.
|
||||
Defaults to False.
|
||||
|
||||
Raises:
|
||||
UsernameInvalid: Username not allowed or contains invalid characters
|
||||
UsernameTaken: Username is already taken; usernames must be unique
|
||||
NewAccountsNotAllowed: In the admin panel, new accounts are set to be
|
||||
not allowed.
|
||||
|
||||
Returns:
|
||||
user_id (int): The id of the new user. User registered successful
|
||||
"""
|
||||
logging.info(f'Registering user with username {username}')
|
||||
|
||||
if not from_admin and not get_setting('allow_new_accounts'):
|
||||
raise NewAccountsNotAllowed
|
||||
|
||||
# Check if username is valid
|
||||
_check_username(username)
|
||||
|
||||
cursor = get_db()
|
||||
|
||||
# Check if username isn't already taken
|
||||
if cursor.execute(
|
||||
"SELECT 1 FROM users WHERE username = ? LIMIT 1", (username,)
|
||||
).fetchone():
|
||||
raise UsernameTaken
|
||||
|
||||
# Generate salt and key exclusive for user
|
||||
salt, hashed_password = generate_salt_hash(password)
|
||||
del password
|
||||
|
||||
# Add user to userlist
|
||||
user_id = cursor.execute(
|
||||
Raises:
|
||||
UsernameInvalid: The username is not valid
|
||||
"""
|
||||
INSERT INTO users(username, salt, hash)
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(username, salt, hashed_password)
|
||||
).lastrowid
|
||||
logging.debug(f'Checking the username {username}')
|
||||
if username in ONEPASS_INVALID_USERNAMES or username.isdigit():
|
||||
raise UsernameInvalid(username)
|
||||
if list(filter(lambda c: not c in ONEPASS_USERNAME_CHARACTERS, username)):
|
||||
raise UsernameInvalid(username)
|
||||
return
|
||||
|
||||
logging.debug(f'Newly registered user has id {user_id}')
|
||||
return user_id
|
||||
def __contains__(self, username: str) -> bool:
|
||||
result = get_db().execute(
|
||||
"SELECT 1 FROM users WHERE username = ? LIMIT 1;",
|
||||
(username,)
|
||||
).fetchone()
|
||||
return result is not None
|
||||
|
||||
def get_users() -> List[dict]:
|
||||
"""Get all user info for the admin
|
||||
def add(self, username: str, password: str, from_admin: bool=False) -> int:
|
||||
"""Add a user
|
||||
|
||||
Returns:
|
||||
List[dict]: The info about all users
|
||||
"""
|
||||
result = [
|
||||
dict(u)
|
||||
for u in get_db(dict).execute(
|
||||
"SELECT id, username, admin FROM users ORDER BY username;"
|
||||
)
|
||||
]
|
||||
return result
|
||||
Args:
|
||||
username (str): The username of the new user
|
||||
password (str): The password of the new user
|
||||
from_admin (bool, optional): Skip check if new accounts are allowed.
|
||||
Defaults to False.
|
||||
|
||||
def edit_user_password(id: int, new_password: str) -> None:
|
||||
"""Change the password of a user for the admin
|
||||
Raises:
|
||||
UsernameInvalid: Username not allowed or contains invalid characters
|
||||
UsernameTaken: Username is already taken; usernames must be unique
|
||||
NewAccountsNotAllowed: In the admin panel, new accounts are set to be
|
||||
not allowed.
|
||||
|
||||
Args:
|
||||
id (int): The ID of the user to change the password of
|
||||
new_password (str): The new password to set for the user
|
||||
"""
|
||||
username = (get_db().execute(
|
||||
"SELECT username FROM users WHERE id = ? LIMIT 1;",
|
||||
(id,)
|
||||
).fetchone() or [''])[0]
|
||||
User(username).edit_password(new_password)
|
||||
return
|
||||
Returns:
|
||||
int: The id of the new user. User registered successful
|
||||
"""
|
||||
logging.info(f'Registering user with username {username}')
|
||||
|
||||
if not from_admin and not get_setting('allow_new_accounts'):
|
||||
raise NewAccountsNotAllowed
|
||||
|
||||
# Check if username is valid
|
||||
self._check_username(username)
|
||||
|
||||
def delete_user(id: int) -> None:
|
||||
"""Delete a user for the admin
|
||||
cursor = get_db()
|
||||
|
||||
Args:
|
||||
id (int): The ID of the user to delete
|
||||
"""
|
||||
username = (get_db().execute(
|
||||
"SELECT username FROM users WHERE id = ? LIMIT 1;",
|
||||
(id,)
|
||||
).fetchone() or [''])[0]
|
||||
if username == 'admin':
|
||||
raise UserNotFound
|
||||
User(username).delete()
|
||||
return
|
||||
# Check if username isn't already taken
|
||||
if username in self:
|
||||
raise UsernameTaken
|
||||
|
||||
# Generate salt and key exclusive for user
|
||||
salt, hashed_password = generate_salt_hash(password)
|
||||
del password
|
||||
|
||||
# Add user to userlist
|
||||
user_id = cursor.execute(
|
||||
"""
|
||||
INSERT INTO users(username, salt, hash)
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(username, salt, hashed_password)
|
||||
).lastrowid
|
||||
|
||||
logging.debug(f'Newly registered user has id {user_id}')
|
||||
return user_id
|
||||
|
||||
def get_all(self) -> List[dict]:
|
||||
"""Get all user info for the admin
|
||||
|
||||
Returns:
|
||||
List[dict]: The info about all users
|
||||
"""
|
||||
result = [
|
||||
dict(u)
|
||||
for u in get_db(dict).execute(
|
||||
"SELECT id, username, admin FROM users ORDER BY username;"
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
def login(self, username: str, password: str) -> User:
|
||||
result = get_db(dict).execute(
|
||||
"SELECT id, salt, hash FROM users WHERE username = ? LIMIT 1;",
|
||||
(username,)
|
||||
).fetchone()
|
||||
if not result:
|
||||
raise UserNotFound
|
||||
|
||||
hash_password = get_hash(result['salt'], password)
|
||||
if not hash_password == result['hash']:
|
||||
raise AccessUnauthorized
|
||||
|
||||
return User(result['id'])
|
||||
|
||||
def get_one(self, id: int) -> User:
|
||||
return User(id)
|
||||
|
||||
546
frontend/api.py
546
frontend/api.py
@@ -1,16 +1,13 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from os import urandom
|
||||
from re import compile
|
||||
from time import time as epoch_time
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
from apprise import Apprise
|
||||
from flask import Blueprint, g, request, send_file
|
||||
from flask.sansio.scaffold import T_route
|
||||
from flask import g, request, send_file
|
||||
|
||||
from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired,
|
||||
APIKeyInvalid, InvalidKeyValue,
|
||||
@@ -22,378 +19,39 @@ from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired,
|
||||
UsernameInvalid, UsernameTaken,
|
||||
UserNotFound)
|
||||
from backend.db import DBConnection
|
||||
from backend.notification_service import (NotificationService,
|
||||
NotificationServices,
|
||||
get_apprise_services)
|
||||
from backend.reminders import Reminders, reminder_handler
|
||||
from backend.settings import (_format_setting, get_admin_settings, get_setting,
|
||||
set_setting)
|
||||
from backend.static_reminders import StaticReminders
|
||||
from backend.templates import Template, Templates
|
||||
from backend.users import (User, delete_user, edit_user_password, get_users,
|
||||
register_user)
|
||||
|
||||
#===================
|
||||
# Input validation
|
||||
#===================
|
||||
color_regex = compile(r'#[0-9a-f]{6}')
|
||||
|
||||
class DataSource:
|
||||
DATA = 1
|
||||
VALUES = 2
|
||||
|
||||
class InputVariable(ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, value: Any) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name() -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate(self) -> bool:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def required() -> bool:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def default() -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def source() -> int:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description() -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def related_exceptions() -> List[Exception]:
|
||||
pass
|
||||
|
||||
class DefaultInputVariable(InputVariable):
|
||||
source = DataSource.DATA
|
||||
required = True
|
||||
default = None
|
||||
related_exceptions = []
|
||||
|
||||
def __init__(self, value: Any) -> None:
|
||||
self.value = value
|
||||
|
||||
def validate(self) -> bool:
|
||||
return isinstance(self.value, str) and self.value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'| {self.name} | {"Yes" if self.required else "No"} | {self.description} | N/A |'
|
||||
|
||||
class NonRequiredVersion(InputVariable):
|
||||
required = False
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or super().validate()
|
||||
|
||||
class UsernameVariable(DefaultInputVariable):
|
||||
name = 'username'
|
||||
description = 'The username of the user account'
|
||||
related_exceptions = [KeyNotFound, UserNotFound]
|
||||
|
||||
class PasswordVariable(DefaultInputVariable):
|
||||
name = 'password'
|
||||
description = 'The password of the user account'
|
||||
related_exceptions = [KeyNotFound, AccessUnauthorized]
|
||||
|
||||
class NewPasswordVariable(PasswordVariable):
|
||||
name = 'new_password'
|
||||
description = 'The new password of the user account'
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
class UsernameCreateVariable(UsernameVariable):
|
||||
related_exceptions = [
|
||||
KeyNotFound,
|
||||
UsernameInvalid, UsernameTaken,
|
||||
NewAccountsNotAllowed
|
||||
]
|
||||
|
||||
class PasswordCreateVariable(PasswordVariable):
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
class TitleVariable(DefaultInputVariable):
|
||||
name = 'title'
|
||||
description = 'The title of the entry'
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
class URLVariable(DefaultInputVariable):
|
||||
name = 'url'
|
||||
description = 'The Apprise URL of the notification service'
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return Apprise().add(self.value)
|
||||
|
||||
class EditTitleVariable(NonRequiredVersion, TitleVariable):
|
||||
related_exceptions = []
|
||||
|
||||
class EditURLVariable(NonRequiredVersion, URLVariable):
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
class SortByVariable(DefaultInputVariable):
|
||||
name = 'sort_by'
|
||||
description = 'How to sort the result'
|
||||
required = False
|
||||
source = DataSource.VALUES
|
||||
_options = Reminders.sort_functions
|
||||
default = next(iter(Reminders.sort_functions))
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = value
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value in self._options
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '| {n} | {r} | {d} | {v} |'.format(
|
||||
n=self.name,
|
||||
r="Yes" if self.required else "No",
|
||||
d=self.description,
|
||||
v=", ".join(f'`{o}`' for o in self._options)
|
||||
)
|
||||
|
||||
class TemplateSortByVariable(SortByVariable):
|
||||
_options = Templates.sort_functions
|
||||
default = next(iter(Templates.sort_functions))
|
||||
|
||||
class StaticReminderSortByVariable(TemplateSortByVariable):
|
||||
_options = StaticReminders.sort_functions
|
||||
default = next(iter(StaticReminders.sort_functions))
|
||||
|
||||
class TimeVariable(DefaultInputVariable):
|
||||
name = 'time'
|
||||
description = 'The UTC epoch timestamp that the reminder should be sent at'
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return isinstance(self.value, (float, int))
|
||||
|
||||
class EditTimeVariable(NonRequiredVersion, TimeVariable):
|
||||
related_exceptions = [InvalidKeyValue, InvalidTime]
|
||||
|
||||
class NotificationServicesVariable(DefaultInputVariable):
|
||||
name = 'notification_services'
|
||||
description = "Array of the id's of the notification services to use to send the notification"
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue, NotificationServiceNotFound]
|
||||
|
||||
def validate(self) -> bool:
|
||||
if not isinstance(self.value, list):
|
||||
return False
|
||||
if not self.value:
|
||||
return False
|
||||
for v in self.value:
|
||||
if not isinstance(v, int):
|
||||
return False
|
||||
return True
|
||||
|
||||
class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable):
|
||||
related_exceptions = [InvalidKeyValue, NotificationServiceNotFound]
|
||||
|
||||
class TextVariable(NonRequiredVersion, DefaultInputVariable):
|
||||
name = 'text'
|
||||
description = 'The body of the entry'
|
||||
default = ''
|
||||
|
||||
def validate(self) -> bool:
|
||||
return isinstance(self.value, str)
|
||||
|
||||
class RepeatQuantityVariable(DefaultInputVariable):
|
||||
name = 'repeat_quantity'
|
||||
description = 'The quantity of the repeat_interval'
|
||||
required = False
|
||||
_options = ("years", "months", "weeks", "days", "hours", "minutes")
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or self.value in self._options
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '| {n} | {r} | {d} | {v} |'.format(
|
||||
n=self.name,
|
||||
r="Yes" if self.required else "No",
|
||||
d=self.description,
|
||||
v=", ".join(f'`{o}`' for o in self._options)
|
||||
)
|
||||
|
||||
class RepeatIntervalVariable(DefaultInputVariable):
|
||||
name = 'repeat_interval'
|
||||
description = 'The number of the interval'
|
||||
required = False
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or (isinstance(self.value, int) and self.value > 0)
|
||||
|
||||
class WeekDaysVariable(DefaultInputVariable):
|
||||
name = 'weekdays'
|
||||
description = 'On which days of the weeks to run the reminder'
|
||||
required = False
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
_options = {0, 1, 2, 3, 4, 5, 6}
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or (
|
||||
isinstance(self.value, list)
|
||||
and len(self.value) > 0
|
||||
and all(v in self._options for v in self.value)
|
||||
)
|
||||
|
||||
class ColorVariable(DefaultInputVariable):
|
||||
name = 'color'
|
||||
description = 'The hex code of the color of the entry, which is shown in the web-ui'
|
||||
required = False
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or color_regex.search(self.value)
|
||||
|
||||
class QueryVariable(DefaultInputVariable):
|
||||
name = 'query'
|
||||
description = 'The search term'
|
||||
source = DataSource.VALUES
|
||||
|
||||
class AdminSettingsVariable(DefaultInputVariable):
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
try:
|
||||
_format_setting(self.name, self.value)
|
||||
except InvalidKeyValue:
|
||||
return False
|
||||
return True
|
||||
|
||||
class AllowNewAccountsVariable(AdminSettingsVariable):
|
||||
name = 'allow_new_accounts'
|
||||
description = ('Whether or not to allow users to register a new account. '
|
||||
+ 'The admin can always add a new account.')
|
||||
|
||||
class LoginTimeVariable(AdminSettingsVariable):
|
||||
name = 'login_time'
|
||||
description = ('How long a user stays logged in, in seconds. '
|
||||
+ 'Between 1 min and 1 month (60 <= sec <= 2592000)')
|
||||
|
||||
class LoginTimeResetVariable(AdminSettingsVariable):
|
||||
name = 'login_time_reset'
|
||||
description = 'If the Login Time timer should reset with each API request.'
|
||||
|
||||
def input_validation() -> Union[None, Dict[str, Any]]:
|
||||
"""Checks, extracts and transforms inputs
|
||||
|
||||
Raises:
|
||||
KeyNotFound: A required key was not supplied
|
||||
InvalidKeyValue: The value of a key is not valid
|
||||
|
||||
Returns:
|
||||
Union[None, Dict[str, Any]]: `None` if the endpoint + method doesn't require input variables.
|
||||
Otherwise `Dict[str, Any]` with the input variables, checked and formatted.
|
||||
"""
|
||||
inputs = {}
|
||||
|
||||
input_variables: Dict[str, List[Union[List[InputVariable], str]]]
|
||||
if request.path.startswith(admin_api_prefix):
|
||||
input_variables = api_docs[
|
||||
_admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
|
||||
]['input_variables']
|
||||
else:
|
||||
input_variables = api_docs[
|
||||
request.url_rule.rule.split(api_prefix)[1]
|
||||
]['input_variables']
|
||||
|
||||
if not input_variables:
|
||||
return
|
||||
|
||||
if input_variables.get(request.method) is None:
|
||||
return inputs
|
||||
|
||||
given_variables = {}
|
||||
given_variables[DataSource.DATA] = request.get_json() if request.data else {}
|
||||
given_variables[DataSource.VALUES] = request.values
|
||||
for input_variable in input_variables[request.method]:
|
||||
if (
|
||||
input_variable.required and
|
||||
not input_variable.name in given_variables[input_variable.source]
|
||||
):
|
||||
raise KeyNotFound(input_variable.name)
|
||||
|
||||
input_value = given_variables[input_variable.source].get(
|
||||
input_variable.name,
|
||||
input_variable.default
|
||||
)
|
||||
|
||||
if not input_variable(input_value).validate():
|
||||
raise InvalidKeyValue(input_variable.name, input_value)
|
||||
|
||||
inputs[input_variable.name] = input_value
|
||||
return inputs
|
||||
from backend.notification_service import get_apprise_services
|
||||
from backend.settings import get_admin_settings, get_setting, set_setting
|
||||
from backend.users import User, Users
|
||||
from frontend.input_validation import (AllowNewAccountsVariable, ColorVariable,
|
||||
EditNotificationServicesVariable,
|
||||
EditTimeVariable, EditTitleVariable,
|
||||
EditURLVariable, LoginTimeResetVariable,
|
||||
LoginTimeVariable, NewPasswordVariable,
|
||||
NotificationServicesVariable,
|
||||
PasswordCreateVariable,
|
||||
PasswordVariable, QueryVariable,
|
||||
RepeatIntervalVariable,
|
||||
RepeatQuantityVariable, SortByVariable,
|
||||
StaticReminderSortByVariable,
|
||||
TemplateSortByVariable, TextVariable,
|
||||
TimeVariable, TitleVariable,
|
||||
URLVariable, UsernameCreateVariable,
|
||||
UsernameVariable, WeekDaysVariable,
|
||||
_admin_api_prefix, admin_api,
|
||||
admin_api_prefix, api, api_docs,
|
||||
api_prefix, input_validation)
|
||||
|
||||
#===================
|
||||
# General variables and functions
|
||||
#===================
|
||||
|
||||
api_docs: Dict[str, Dict[str, Any]] = {}
|
||||
class APIBlueprint(Blueprint):
|
||||
def route(
|
||||
self,
|
||||
rule: str,
|
||||
description: str = '',
|
||||
input_variables: Dict[str, List[Union[List[InputVariable], str]]] = {},
|
||||
requires_auth: bool = True,
|
||||
**options: Any
|
||||
) -> Callable[[T_route], T_route]:
|
||||
@dataclass
|
||||
class ApiKeyEntry:
|
||||
exp: int
|
||||
user_data: User
|
||||
|
||||
if self == api:
|
||||
processed_rule = rule
|
||||
elif self == admin_api:
|
||||
processed_rule = _admin_api_prefix + rule
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
api_docs[processed_rule] = {
|
||||
'endpoint': processed_rule,
|
||||
'description': description,
|
||||
'requires_auth': requires_auth,
|
||||
'methods': options['methods'],
|
||||
'input_variables': {
|
||||
k: v[0]
|
||||
for k, v in input_variables.items()
|
||||
if v and v[0]
|
||||
},
|
||||
'method_descriptions': {
|
||||
k: v[1]
|
||||
for k, v in input_variables.items()
|
||||
if v and len(v) == 2 and v[1]
|
||||
}
|
||||
}
|
||||
|
||||
return super().route(rule, **options)
|
||||
|
||||
api_prefix = "/api"
|
||||
_admin_api_prefix = '/admin'
|
||||
admin_api_prefix = api_prefix + _admin_api_prefix
|
||||
api = APIBlueprint('api', __name__)
|
||||
admin_api = APIBlueprint('admin_api', __name__)
|
||||
api_key_map = {}
|
||||
users = Users()
|
||||
api_key_map: Dict[int, ApiKeyEntry] = {}
|
||||
|
||||
def return_api(result: Any, error: str=None, code: int=200) -> Tuple[dict, int]:
|
||||
return {'error': error, 'result': result}, code
|
||||
@@ -409,33 +67,37 @@ def auth() -> None:
|
||||
if not hashed_api_key in api_key_map:
|
||||
raise APIKeyInvalid
|
||||
|
||||
if not (
|
||||
(
|
||||
api_key_map[hashed_api_key]['user_data'].admin
|
||||
and request.path.startswith((admin_api_prefix, api_prefix + '/auth'))
|
||||
)
|
||||
or
|
||||
(
|
||||
not api_key_map[hashed_api_key]['user_data'].admin
|
||||
and not request.path.startswith(admin_api_prefix)
|
||||
)
|
||||
map_entry = api_key_map[hashed_api_key]
|
||||
|
||||
if (
|
||||
map_entry.user_data.admin
|
||||
and
|
||||
not request.path.startswith((admin_api_prefix, api_prefix + '/auth'))
|
||||
):
|
||||
raise APIKeyInvalid
|
||||
|
||||
if (
|
||||
not map_entry.user_data.admin
|
||||
and
|
||||
request.path.startswith(admin_api_prefix)
|
||||
):
|
||||
raise APIKeyInvalid
|
||||
|
||||
exp = api_key_map[hashed_api_key]['exp']
|
||||
if exp <= epoch_time():
|
||||
if map_entry.exp <= epoch_time():
|
||||
raise APIKeyExpired
|
||||
|
||||
# Api key valid
|
||||
|
||||
if get_setting('login_time_reset'):
|
||||
api_key_map[hashed_api_key]['exp'] = exp = (
|
||||
g.exp = map_entry.exp = (
|
||||
epoch_time() + get_setting('login_time')
|
||||
)
|
||||
else:
|
||||
g.exp = map_entry.exp
|
||||
|
||||
g.hashed_api_key = hashed_api_key
|
||||
g.exp = exp
|
||||
g.user_data = api_key_map[hashed_api_key]['user_data']
|
||||
g.user_data = map_entry.user_data
|
||||
|
||||
return
|
||||
|
||||
def endpoint_wrapper(method: Callable) -> Callable:
|
||||
@@ -458,14 +120,15 @@ def endpoint_wrapper(method: Callable) -> Callable:
|
||||
return method(*args, **kwargs)
|
||||
return method(inputs, *args, **kwargs)
|
||||
|
||||
except (UsernameTaken, UsernameInvalid, UserNotFound,
|
||||
AccessUnauthorized,
|
||||
ReminderNotFound, NotificationServiceNotFound,
|
||||
NotificationServiceInUse, InvalidTime,
|
||||
KeyNotFound, InvalidKeyValue,
|
||||
APIKeyInvalid, APIKeyExpired,
|
||||
TemplateNotFound,
|
||||
NewAccountsNotAllowed) as e:
|
||||
except (AccessUnauthorized, APIKeyExpired,
|
||||
APIKeyInvalid, InvalidKeyValue,
|
||||
InvalidTime, KeyNotFound,
|
||||
NewAccountsNotAllowed,
|
||||
NotificationServiceInUse,
|
||||
NotificationServiceNotFound,
|
||||
ReminderNotFound, TemplateNotFound,
|
||||
UsernameInvalid, UsernameTaken,
|
||||
UserNotFound) as e:
|
||||
return return_api(**e.api_response)
|
||||
|
||||
wrapper.__name__ = method.__name__
|
||||
@@ -484,7 +147,7 @@ def endpoint_wrapper(method: Callable) -> Callable:
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_login(inputs: Dict[str, str]):
|
||||
user = User(inputs['username'], inputs['password'])
|
||||
user = users.login(inputs['username'], inputs['password'])
|
||||
|
||||
# Generate an API key until one
|
||||
# is generated that isn't used already
|
||||
@@ -496,12 +159,7 @@ def api_login(inputs: Dict[str, str]):
|
||||
|
||||
login_time = get_setting('login_time')
|
||||
exp = epoch_time() + login_time
|
||||
api_key_map.update({
|
||||
hashed_api_key: {
|
||||
'exp': exp,
|
||||
'user_data': user
|
||||
}
|
||||
})
|
||||
api_key_map[hashed_api_key] = ApiKeyEntry(exp, user)
|
||||
|
||||
result = {'api_key': api_key, 'expires': exp, 'admin': user.admin}
|
||||
return return_api(result, code=201)
|
||||
@@ -523,17 +181,17 @@ def api_logout():
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_status():
|
||||
map_entry = api_key_map[g.hashed_api_key]
|
||||
result = {
|
||||
'expires': api_key_map[g.hashed_api_key]['exp'],
|
||||
'username': api_key_map[g.hashed_api_key]['user_data'].username,
|
||||
'admin': api_key_map[g.hashed_api_key]['user_data'].admin
|
||||
'expires': map_entry.exp,
|
||||
'username': map_entry.user_data.username,
|
||||
'admin': map_entry.user_data.admin
|
||||
}
|
||||
return return_api(result)
|
||||
|
||||
#===================
|
||||
# User endpoints
|
||||
#===================
|
||||
|
||||
@api.route(
|
||||
'/user/add',
|
||||
'Create a new user account',
|
||||
@@ -543,7 +201,7 @@ def api_status():
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_add_user(inputs: Dict[str, str]):
|
||||
register_user(inputs['username'], inputs['password'])
|
||||
users.add(inputs['username'], inputs['password'])
|
||||
return return_api({}, code=201)
|
||||
|
||||
@api.route(
|
||||
@@ -557,12 +215,13 @@ def api_add_user(inputs: Dict[str, str]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_manage_user(inputs: Dict[str, str]):
|
||||
user = api_key_map[g.hashed_api_key].user_data
|
||||
if request.method == 'PUT':
|
||||
g.user_data.edit_password(inputs['new_password'])
|
||||
user.edit_password(inputs['new_password'])
|
||||
return return_api({})
|
||||
|
||||
elif request.method == 'DELETE':
|
||||
g.user_data.delete()
|
||||
user.delete()
|
||||
api_key_map.pop(g.hashed_api_key)
|
||||
return return_api({})
|
||||
|
||||
@@ -581,7 +240,7 @@ def api_manage_user(inputs: Dict[str, str]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_notification_services_list(inputs: Dict[str, str]):
|
||||
services: NotificationServices = g.user_data.notification_services
|
||||
services = api_key_map[g.hashed_api_key].user_data.notification_services
|
||||
|
||||
if request.method == 'GET':
|
||||
result = services.fetchall()
|
||||
@@ -610,7 +269,10 @@ def api_notification_service_available():
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_test_service(inputs: Dict[str, Any]):
|
||||
g.user_data.notification_services.test_service(inputs['url'])
|
||||
(api_key_map[g.hashed_api_key]
|
||||
.user_data
|
||||
.notification_services
|
||||
.test_service(inputs['url']))
|
||||
return return_api({}, code=201)
|
||||
|
||||
@api.route(
|
||||
@@ -624,17 +286,20 @@ def api_test_service(inputs: Dict[str, Any]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_notification_service(inputs: Dict[str, str], n_id: int):
|
||||
service: NotificationService = g.user_data.notification_services.fetchone(n_id)
|
||||
|
||||
service = (api_key_map[g.hashed_api_key]
|
||||
.user_data
|
||||
.notification_services
|
||||
.fetchone(n_id))
|
||||
|
||||
if request.method == 'GET':
|
||||
result = service.get()
|
||||
return return_api(result)
|
||||
|
||||
|
||||
elif request.method == 'PUT':
|
||||
result = service.update(title=inputs['title'],
|
||||
url=inputs['url'])
|
||||
return return_api(result)
|
||||
|
||||
|
||||
elif request.method == 'DELETE':
|
||||
service.delete()
|
||||
return return_api({})
|
||||
@@ -659,7 +324,7 @@ def api_notification_service(inputs: Dict[str, str], n_id: int):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_reminders_list(inputs: Dict[str, Any]):
|
||||
reminders: Reminders = g.user_data.reminders
|
||||
reminders = api_key_map[g.hashed_api_key].user_data.reminders
|
||||
|
||||
if request.method == 'GET':
|
||||
result = reminders.fetchall(inputs['sort_by'])
|
||||
@@ -684,7 +349,10 @@ def api_reminders_list(inputs: Dict[str, Any]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_reminders_query(inputs: Dict[str, str]):
|
||||
result = g.user_data.reminders.search(inputs['query'], inputs['sort_by'])
|
||||
result = (api_key_map[g.hashed_api_key]
|
||||
.user_data
|
||||
.reminders
|
||||
.search(inputs['query'], inputs['sort_by']))
|
||||
return return_api(result)
|
||||
|
||||
@api.route(
|
||||
@@ -696,7 +364,11 @@ def api_reminders_query(inputs: Dict[str, str]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_test_reminder(inputs: Dict[str, Any]):
|
||||
g.user_data.reminders.test_reminder(inputs['title'], inputs['notification_services'], inputs['text'])
|
||||
api_key_map[g.hashed_api_key].user_data.reminders.test_reminder(
|
||||
inputs['title'],
|
||||
inputs['notification_services'],
|
||||
inputs['text']
|
||||
)
|
||||
return return_api({}, code=201)
|
||||
|
||||
@api.route(
|
||||
@@ -714,7 +386,8 @@ def api_test_reminder(inputs: Dict[str, Any]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_get_reminder(inputs: Dict[str, Any], r_id: int):
|
||||
reminders: Reminders = g.user_data.reminders
|
||||
reminders = api_key_map[g.hashed_api_key].user_data.reminders
|
||||
|
||||
if request.method == 'GET':
|
||||
result = reminders.fetchone(r_id).get()
|
||||
return return_api(result)
|
||||
@@ -750,7 +423,7 @@ def api_get_reminder(inputs: Dict[str, Any], r_id: int):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_get_templates(inputs: Dict[str, Any]):
|
||||
templates: Templates = g.user_data.templates
|
||||
templates = api_key_map[g.hashed_api_key].user_data.templates
|
||||
|
||||
if request.method == 'GET':
|
||||
result = templates.fetchall(inputs['sort_by'])
|
||||
@@ -771,7 +444,10 @@ def api_get_templates(inputs: Dict[str, Any]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_templates_query(inputs: Dict[str, str]):
|
||||
result = g.user_data.templates.search(inputs['query'], inputs['sort_by'])
|
||||
result = (api_key_map[g.hashed_api_key]
|
||||
.user_data
|
||||
.templates
|
||||
.search(inputs['query'], inputs['sort_by']))
|
||||
return return_api(result)
|
||||
|
||||
@api.route(
|
||||
@@ -786,7 +462,10 @@ def api_templates_query(inputs: Dict[str, str]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_get_template(inputs: Dict[str, Any], t_id: int):
|
||||
template: Template = g.user_data.templates.fetchone(t_id)
|
||||
template = (api_key_map[g.hashed_api_key]
|
||||
.user_data
|
||||
.templates
|
||||
.fetchone(t_id))
|
||||
|
||||
if request.method == 'GET':
|
||||
result = template.get()
|
||||
@@ -819,7 +498,7 @@ def api_get_template(inputs: Dict[str, Any], t_id: int):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_static_reminders_list(inputs: Dict[str, Any]):
|
||||
reminders: StaticReminders = g.user_data.static_reminders
|
||||
reminders = api_key_map[g.hashed_api_key].user_data.static_reminders
|
||||
|
||||
if request.method == 'GET':
|
||||
result = reminders.fetchall(inputs['sort_by'])
|
||||
@@ -840,7 +519,10 @@ def api_static_reminders_list(inputs: Dict[str, Any]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_static_reminders_query(inputs: Dict[str, str]):
|
||||
result = g.user_data.static_reminders.search(inputs['query'], inputs['sort_by'])
|
||||
result = (api_key_map[g.hashed_api_key]
|
||||
.user_data
|
||||
.static_reminders
|
||||
.search(inputs['query'], inputs['sort_by']))
|
||||
return return_api(result)
|
||||
|
||||
@api.route(
|
||||
@@ -857,7 +539,8 @@ def api_static_reminders_query(inputs: Dict[str, str]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_get_static_reminder(inputs: Dict[str, Any], s_id: int):
|
||||
reminders: StaticReminders = g.user_data.static_reminders
|
||||
reminders = api_key_map[g.hashed_api_key].user_data.static_reminders
|
||||
|
||||
if request.method == 'GET':
|
||||
result = reminders.fetchone(s_id).get()
|
||||
return return_api(result)
|
||||
@@ -929,11 +612,11 @@ def api_admin_settings(inputs: Dict[str, Any]):
|
||||
@endpoint_wrapper
|
||||
def api_admin_users(inputs: Dict[str, Any]):
|
||||
if request.method == 'GET':
|
||||
result = get_users()
|
||||
result = users.get_all()
|
||||
return return_api(result)
|
||||
|
||||
elif request.method == 'POST':
|
||||
register_user(inputs['username'], inputs['password'], True)
|
||||
users.add(inputs['username'], inputs['password'], True)
|
||||
return return_api({}, code=201)
|
||||
|
||||
@admin_api.route(
|
||||
@@ -947,14 +630,15 @@ def api_admin_users(inputs: Dict[str, Any]):
|
||||
)
|
||||
@endpoint_wrapper
|
||||
def api_admin_user(inputs: Dict[str, Any], u_id: int):
|
||||
user = users.get_one(u_id)
|
||||
if request.method == 'PUT':
|
||||
edit_user_password(u_id, inputs['new_password'])
|
||||
user.edit_password(inputs['new_password'])
|
||||
return return_api({})
|
||||
|
||||
elif request.method == 'DELETE':
|
||||
delete_user(u_id)
|
||||
user.delete()
|
||||
for key, value in api_key_map.items():
|
||||
if value['user_data'].user_id == u_id:
|
||||
if value.user_data.user_id == u_id:
|
||||
del api_key_map[key]
|
||||
break
|
||||
return return_api({})
|
||||
|
||||
434
frontend/input_validation.py
Normal file
434
frontend/input_validation.py
Normal file
@@ -0,0 +1,434 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Input validation for the API
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from re import compile
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
|
||||
from apprise import Apprise
|
||||
from flask import Blueprint, request
|
||||
from flask.sansio.scaffold import T_route
|
||||
|
||||
from backend.custom_exceptions import (AccessUnauthorized, InvalidKeyValue,
|
||||
InvalidTime, KeyNotFound,
|
||||
NewAccountsNotAllowed,
|
||||
NotificationServiceNotFound,
|
||||
UsernameInvalid, UsernameTaken,
|
||||
UserNotFound)
|
||||
from backend.helpers import (RepeatQuantity, SortingMethod,
|
||||
TimelessSortingMethod)
|
||||
from backend.settings import _format_setting
|
||||
|
||||
api_prefix = "/api"
|
||||
_admin_api_prefix = '/admin'
|
||||
admin_api_prefix = api_prefix + _admin_api_prefix
|
||||
|
||||
color_regex = compile(r'#[0-9a-f]{6}')
|
||||
|
||||
class DataSource:
|
||||
DATA = 1
|
||||
VALUES = 2
|
||||
|
||||
|
||||
class InputVariable(ABC):
|
||||
value: Any
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, value: Any) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name() -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate(self) -> bool:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def required() -> bool:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def default() -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def source() -> int:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description() -> str:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def related_exceptions() -> List[Exception]:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultInputVariable(InputVariable):
|
||||
source = DataSource.DATA
|
||||
required = True
|
||||
default = None
|
||||
related_exceptions = []
|
||||
|
||||
def __init__(self, value: Any) -> None:
|
||||
self.value = value
|
||||
|
||||
def validate(self) -> bool:
|
||||
return isinstance(self.value, str) and self.value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'| {self.name} | {"Yes" if self.required else "No"} | {self.description} | N/A |'
|
||||
|
||||
|
||||
class NonRequiredVersion(InputVariable):
|
||||
required = False
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or super().validate()
|
||||
|
||||
|
||||
class UsernameVariable(DefaultInputVariable):
|
||||
name = 'username'
|
||||
description = 'The username of the user account'
|
||||
related_exceptions = [KeyNotFound, UserNotFound]
|
||||
|
||||
|
||||
class PasswordVariable(DefaultInputVariable):
|
||||
name = 'password'
|
||||
description = 'The password of the user account'
|
||||
related_exceptions = [KeyNotFound, AccessUnauthorized]
|
||||
|
||||
|
||||
class NewPasswordVariable(PasswordVariable):
|
||||
name = 'new_password'
|
||||
description = 'The new password of the user account'
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
|
||||
class UsernameCreateVariable(UsernameVariable):
|
||||
related_exceptions = [
|
||||
KeyNotFound,
|
||||
UsernameInvalid, UsernameTaken,
|
||||
NewAccountsNotAllowed
|
||||
]
|
||||
|
||||
|
||||
class PasswordCreateVariable(PasswordVariable):
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
|
||||
class TitleVariable(DefaultInputVariable):
|
||||
name = 'title'
|
||||
description = 'The title of the entry'
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
|
||||
class URLVariable(DefaultInputVariable):
|
||||
name = 'url'
|
||||
description = 'The Apprise URL of the notification service'
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return Apprise().add(self.value)
|
||||
|
||||
|
||||
class EditTitleVariable(NonRequiredVersion, TitleVariable):
|
||||
related_exceptions = []
|
||||
|
||||
|
||||
class EditURLVariable(NonRequiredVersion, URLVariable):
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
|
||||
class SortByVariable(DefaultInputVariable):
|
||||
name = 'sort_by'
|
||||
description = 'How to sort the result'
|
||||
required = False
|
||||
source = DataSource.VALUES
|
||||
_options = [k.lower() for k in SortingMethod._member_names_]
|
||||
default = SortingMethod._member_names_[0].lower()
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = value
|
||||
|
||||
def validate(self) -> bool:
|
||||
if not self.value in self._options:
|
||||
return False
|
||||
|
||||
self.value = SortingMethod[self.value.upper()]
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '| {n} | {r} | {d} | {v} |'.format(
|
||||
n=self.name,
|
||||
r="Yes" if self.required else "No",
|
||||
d=self.description,
|
||||
v=", ".join(f'`{o}`' for o in self._options)
|
||||
)
|
||||
|
||||
|
||||
class TemplateSortByVariable(SortByVariable):
|
||||
_options = [k.lower() for k in TimelessSortingMethod._member_names_]
|
||||
default = TimelessSortingMethod._member_names_[0].lower()
|
||||
|
||||
def validate(self) -> bool:
|
||||
if not self.value in self._options:
|
||||
return False
|
||||
|
||||
self.value = TimelessSortingMethod[self.value.upper()]
|
||||
return True
|
||||
|
||||
class StaticReminderSortByVariable(TemplateSortByVariable):
|
||||
pass
|
||||
|
||||
|
||||
class TimeVariable(DefaultInputVariable):
|
||||
name = 'time'
|
||||
description = 'The UTC epoch timestamp that the reminder should be sent at'
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return isinstance(self.value, (float, int))
|
||||
|
||||
|
||||
class EditTimeVariable(NonRequiredVersion, TimeVariable):
|
||||
related_exceptions = [InvalidKeyValue, InvalidTime]
|
||||
|
||||
|
||||
class NotificationServicesVariable(DefaultInputVariable):
|
||||
name = 'notification_services'
|
||||
description = "Array of the id's of the notification services to use to send the notification"
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue, NotificationServiceNotFound]
|
||||
|
||||
def validate(self) -> bool:
|
||||
if not isinstance(self.value, list):
|
||||
return False
|
||||
if not self.value:
|
||||
return False
|
||||
for v in self.value:
|
||||
if not isinstance(v, int):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable):
|
||||
related_exceptions = [InvalidKeyValue, NotificationServiceNotFound]
|
||||
|
||||
|
||||
class TextVariable(NonRequiredVersion, DefaultInputVariable):
|
||||
name = 'text'
|
||||
description = 'The body of the entry'
|
||||
default = ''
|
||||
|
||||
def validate(self) -> bool:
|
||||
return isinstance(self.value, str)
|
||||
|
||||
|
||||
class RepeatQuantityVariable(DefaultInputVariable):
|
||||
name = 'repeat_quantity'
|
||||
description = 'The quantity of the repeat_interval'
|
||||
required = False
|
||||
_options = [m.lower() for m in RepeatQuantity._member_names_]
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
if self.value is None:
|
||||
return True
|
||||
|
||||
if not self.value in self._options:
|
||||
return False
|
||||
|
||||
self.value = RepeatQuantity[self.value.upper()]
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '| {n} | {r} | {d} | {v} |'.format(
|
||||
n=self.name,
|
||||
r="Yes" if self.required else "No",
|
||||
d=self.description,
|
||||
v=", ".join(f'`{o}`' for o in self._options)
|
||||
)
|
||||
|
||||
|
||||
class RepeatIntervalVariable(DefaultInputVariable):
|
||||
name = 'repeat_interval'
|
||||
description = 'The number of the interval'
|
||||
required = False
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return (
|
||||
self.value is None
|
||||
or (
|
||||
isinstance(self.value, int)
|
||||
and self.value > 0
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class WeekDaysVariable(DefaultInputVariable):
|
||||
name = 'weekdays'
|
||||
description = 'On which days of the weeks to run the reminder'
|
||||
required = False
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
_options = {0, 1, 2, 3, 4, 5, 6}
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or (
|
||||
isinstance(self.value, list)
|
||||
and len(self.value) > 0
|
||||
and all(v in self._options for v in self.value)
|
||||
)
|
||||
|
||||
|
||||
class ColorVariable(DefaultInputVariable):
|
||||
name = 'color'
|
||||
description = 'The hex code of the color of the entry, which is shown in the web-ui'
|
||||
required = False
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or color_regex.search(self.value)
|
||||
|
||||
|
||||
class QueryVariable(DefaultInputVariable):
|
||||
name = 'query'
|
||||
description = 'The search term'
|
||||
source = DataSource.VALUES
|
||||
|
||||
|
||||
class AdminSettingsVariable(DefaultInputVariable):
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
try:
|
||||
_format_setting(self.name, self.value)
|
||||
except InvalidKeyValue:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class AllowNewAccountsVariable(AdminSettingsVariable):
|
||||
name = 'allow_new_accounts'
|
||||
description = ('Whether or not to allow users to register a new account. '
|
||||
+ 'The admin can always add a new account.')
|
||||
|
||||
|
||||
class LoginTimeVariable(AdminSettingsVariable):
|
||||
name = 'login_time'
|
||||
description = ('How long a user stays logged in, in seconds. '
|
||||
+ 'Between 1 min and 1 month (60 <= sec <= 2592000)')
|
||||
|
||||
|
||||
class LoginTimeResetVariable(AdminSettingsVariable):
|
||||
name = 'login_time_reset'
|
||||
description = 'If the Login Time timer should reset with each API request.'
|
||||
|
||||
|
||||
def input_validation() -> Union[None, Dict[str, Any]]:
|
||||
"""Checks, extracts and transforms inputs
|
||||
|
||||
Raises:
|
||||
KeyNotFound: A required key was not supplied
|
||||
InvalidKeyValue: The value of a key is not valid
|
||||
|
||||
Returns:
|
||||
Union[None, Dict[str, Any]]: `None` if the endpoint + method doesn't require input variables.
|
||||
Otherwise `Dict[str, Any]` with the input variables, checked and formatted.
|
||||
"""
|
||||
inputs = {}
|
||||
|
||||
input_variables: Dict[str, List[Union[List[InputVariable], str]]]
|
||||
if request.path.startswith(admin_api_prefix):
|
||||
input_variables = api_docs[
|
||||
_admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
|
||||
]['input_variables']
|
||||
else:
|
||||
input_variables = api_docs[
|
||||
request.url_rule.rule.split(api_prefix)[1]
|
||||
]['input_variables']
|
||||
|
||||
if not input_variables:
|
||||
return
|
||||
|
||||
if input_variables.get(request.method) is None:
|
||||
return inputs
|
||||
|
||||
given_variables = {}
|
||||
given_variables[DataSource.DATA] = request.get_json() if request.data else {}
|
||||
given_variables[DataSource.VALUES] = request.values
|
||||
for input_variable in input_variables[request.method]:
|
||||
if (
|
||||
input_variable.required and
|
||||
not input_variable.name in given_variables[input_variable.source]
|
||||
):
|
||||
raise KeyNotFound(input_variable.name)
|
||||
|
||||
input_value = given_variables[input_variable.source].get(
|
||||
input_variable.name,
|
||||
input_variable.default
|
||||
)
|
||||
value: InputVariable = input_variable(input_value)
|
||||
|
||||
if not value.validate():
|
||||
raise InvalidKeyValue(input_variable.name, input_value)
|
||||
|
||||
inputs[input_variable.name] = value.value
|
||||
return inputs
|
||||
|
||||
|
||||
api_docs: Dict[str, Dict[str, Any]] = {}
|
||||
class APIBlueprint(Blueprint):
|
||||
def route(
|
||||
self,
|
||||
rule: str,
|
||||
description: str = '',
|
||||
input_variables: Dict[str, List[Union[List[InputVariable], str]]] = {},
|
||||
requires_auth: bool = True,
|
||||
**options: Any
|
||||
) -> Callable[[T_route], T_route]:
|
||||
|
||||
if self == api:
|
||||
processed_rule = rule
|
||||
elif self == admin_api:
|
||||
processed_rule = _admin_api_prefix + rule
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
api_docs[processed_rule] = {
|
||||
'endpoint': processed_rule,
|
||||
'description': description,
|
||||
'requires_auth': requires_auth,
|
||||
'methods': options['methods'],
|
||||
'input_variables': {
|
||||
k: v[0]
|
||||
for k, v in input_variables.items()
|
||||
if v and v[0]
|
||||
},
|
||||
'method_descriptions': {
|
||||
k: v[1]
|
||||
for k, v in input_variables.items()
|
||||
if v and len(v) == 2 and v[1]
|
||||
}
|
||||
}
|
||||
|
||||
return super().route(rule, **options)
|
||||
|
||||
api = APIBlueprint('api', __name__)
|
||||
admin_api = APIBlueprint('admin_api', __name__)
|
||||
@@ -1,20 +1,22 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from flask import Blueprint, render_template
|
||||
|
||||
ui = Blueprint('ui', __name__)
|
||||
|
||||
methods = ['GET']
|
||||
|
||||
class UIVariables:
|
||||
url_prefix: str = ''
|
||||
|
||||
@ui.route('/', methods=methods)
|
||||
def ui_login():
|
||||
return render_template('login.html', url_prefix=logging.URL_PREFIX)
|
||||
return render_template('login.html', url_prefix=UIVariables.url_prefix)
|
||||
|
||||
@ui.route('/reminders', methods=methods)
|
||||
def ui_reminders():
|
||||
return render_template('reminders.html', url_prefix=logging.URL_PREFIX)
|
||||
return render_template('reminders.html', url_prefix=UIVariables.url_prefix)
|
||||
|
||||
@ui.route('/admin', methods=methods)
|
||||
def ui_admin():
|
||||
return render_template('admin.html', url_prefixx=logging.URL_PREFIX)
|
||||
return render_template('admin.html', url_prefix=UIVariables.url_prefix)
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
from sys import path
|
||||
from os.path import dirname
|
||||
from sys import path
|
||||
|
||||
from frontend.input_validation import DataSource
|
||||
|
||||
path.insert(0, dirname(path[0]))
|
||||
|
||||
from subprocess import run
|
||||
from typing import Union
|
||||
from frontend.api import (DataSource, NotificationServiceNotFound,
|
||||
ReminderNotFound, TemplateNotFound, api_docs)
|
||||
|
||||
from frontend.api import (NotificationServiceNotFound, ReminderNotFound,
|
||||
TemplateNotFound, api_docs)
|
||||
from MIND import _folder_path, api_prefix
|
||||
|
||||
url_var_map = {
|
||||
|
||||
@@ -8,7 +8,8 @@ import backend.custom_exceptions
|
||||
class Test_Custom_Exceptions(unittest.TestCase):
|
||||
def test_type(self):
|
||||
defined_exceptions: List[Exception] = filter(
|
||||
lambda c: c.__module__ == 'backend.custom_exceptions' and c is not backend.custom_exceptions.CustomException,
|
||||
lambda c: c.__module__ == 'backend.custom_exceptions'
|
||||
and c is not backend.custom_exceptions.CustomException,
|
||||
map(
|
||||
lambda c: c[1],
|
||||
getmembers(modules['backend.custom_exceptions'], isclass)
|
||||
@@ -16,11 +17,21 @@ class Test_Custom_Exceptions(unittest.TestCase):
|
||||
)
|
||||
|
||||
for defined_exception in defined_exceptions:
|
||||
self.assertEqual(
|
||||
self.assertIn(
|
||||
getmro(defined_exception)[1],
|
||||
backend.custom_exceptions.CustomException
|
||||
(
|
||||
backend.custom_exceptions.CustomException,
|
||||
Exception
|
||||
)
|
||||
)
|
||||
result = defined_exception().api_response
|
||||
try:
|
||||
result = defined_exception().api_response
|
||||
except TypeError:
|
||||
try:
|
||||
result = defined_exception('1').api_response
|
||||
except TypeError:
|
||||
result = defined_exception('1', '2').api_response
|
||||
|
||||
self.assertIsInstance(result, dict)
|
||||
result['error']
|
||||
result['result']
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import unittest
|
||||
|
||||
from backend.db import DBConnection
|
||||
from MIND import DB_FILENAME, _folder_path
|
||||
from backend.helpers import folder_path
|
||||
from MIND import DB_FILENAME
|
||||
|
||||
|
||||
class Test_DB(unittest.TestCase):
|
||||
def test_foreign_key(self):
|
||||
DBConnection.file = _folder_path(*DB_FILENAME)
|
||||
DBConnection.file = folder_path(*DB_FILENAME)
|
||||
instance = DBConnection(timeout=20.0)
|
||||
self.assertEqual(instance.cursor().execute("PRAGMA foreign_keys;").fetchone()[0], 1)
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
import unittest
|
||||
|
||||
from backend.reminders import filter_function, ReminderHandler
|
||||
from backend.helpers import search_filter
|
||||
|
||||
class Test_Reminder_Handler(unittest.TestCase):
|
||||
def test_starting_stopping(self):
|
||||
context = 'test'
|
||||
instance = ReminderHandler(context)
|
||||
self.assertIs(context, instance.context)
|
||||
|
||||
def test_filter_function(self):
|
||||
p = {
|
||||
'title': 'TITLE',
|
||||
'text': 'TEXT'
|
||||
}
|
||||
for test_case in ('', 'title', 'ex'):
|
||||
self.assertTrue(filter_function(test_case, p))
|
||||
self.assertTrue(search_filter(test_case, p))
|
||||
for test_case in (' ', 'Hello'):
|
||||
self.assertFalse(filter_function(test_case, p))
|
||||
|
||||
self.assertFalse(search_filter(test_case, p))
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import unittest
|
||||
|
||||
from backend.custom_exceptions import UsernameInvalid
|
||||
from backend.users import ONEPASS_INVALID_USERNAMES, _check_username
|
||||
from backend.users import ONEPASS_INVALID_USERNAMES, Users
|
||||
|
||||
class Test_Users(unittest.TestCase):
|
||||
def test_username_check(self):
|
||||
users = Users()
|
||||
for test_case in ('', 'test'):
|
||||
_check_username(test_case)
|
||||
users._check_username(test_case)
|
||||
|
||||
for test_case in (' ', ' ', '0', 'api', *ONEPASS_INVALID_USERNAMES):
|
||||
with self.assertRaises(UsernameInvalid):
|
||||
_check_username(test_case)
|
||||
|
||||
users._check_username(test_case)
|
||||
|
||||
Reference in New Issue
Block a user