From 0cbb03151f0c0c2b6f0360c84772fb742444cd8e Mon Sep 17 00:00:00 2001 From: CasVT Date: Tue, 22 Apr 2025 23:29:35 +0200 Subject: [PATCH] Refactored backend (Fixes #87) --- .dockerignore | 11 +- .github/workflows/build_docs.yml | 2 +- .gitignore | 4 +- .pre-commit-config.yaml | 52 + .vscode/settings.json | 54 +- .vscode/tasks.json | 18 + CONTRIBUTING.md | 73 +- Dockerfile | 2 +- MIND.py | 312 +++- README.md | 49 +- backend/base/custom_exceptions.py | 428 ++++++ backend/base/definitions.py | 315 ++++ backend/base/helpers.py | 394 +++++ backend/base/logging.py | 163 ++ backend/custom_exceptions.py | 138 -- backend/db.py | 521 ------- backend/features/reminder_handler.py | 132 ++ backend/features/reminders.py | 418 +++++ backend/features/static_reminders.py | 244 +++ backend/features/templates.py | 223 +++ backend/helpers.py | 116 -- backend/implementations/apprise_parser.py | 178 +++ .../implementations/notification_services.py | 205 +++ backend/implementations/users.py | 292 ++++ backend/internals/db.py | 493 ++++++ backend/internals/db_migration.py | 312 ++++ backend/internals/db_models.py | 815 ++++++++++ backend/internals/server.py | 247 +++ backend/internals/settings.py | 255 ++++ backend/logging.py | 141 -- backend/notification_service.py | 405 ----- backend/reminders.py | 796 ---------- backend/security.py | 40 - backend/server.py | 264 ---- backend/settings.py | 245 --- backend/static_reminders.py | 356 ----- backend/templates.py | 311 ---- backend/users.py | 235 --- docs/index.md | 2 +- frontend/api.py | 1360 +++++++++-------- frontend/input_validation.py | 845 +++++----- frontend/static/json/pwa_manifest.json | 1 + frontend/ui.py | 26 +- project_management/docs-requirements.txt | 6 - project_management/generate_api_docs.py | 164 +- project_management/requirements-docs.txt | 6 + pyproject.toml | 32 + requirements-dev.txt | 4 + requirements.txt | 8 +- tests/MIND_test.py | 22 - tests/Tbackend/MIND_test.py | 24 + tests/Tbackend/__init__.py | 0 tests/Tbackend/custom_exceptions_test.py | 5 + tests/Tbackend/db_test.py | 26 + tests/Tbackend/reminders_test.py | 19 + tests/Tbackend/security_test.py | 10 + tests/Tbackend/ui_test.py | 21 + tests/Tbackend/users_test.py | 15 + tests/api_test.py | 25 - tests/custom_exceptions_test.py | 39 - tests/db_test.py | 11 - tests/reminders_test.py | 14 - tests/security_test.py | 10 - tests/ui_test.py | 20 - tests/users_test.py | 14 - 65 files changed, 6974 insertions(+), 5014 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 .vscode/tasks.json create mode 100644 backend/base/custom_exceptions.py create mode 100644 backend/base/definitions.py create mode 100644 backend/base/helpers.py create mode 100644 backend/base/logging.py delete mode 100644 backend/custom_exceptions.py delete mode 100644 backend/db.py create mode 100644 backend/features/reminder_handler.py create mode 100644 backend/features/reminders.py create mode 100644 backend/features/static_reminders.py create mode 100644 backend/features/templates.py delete mode 100644 backend/helpers.py create mode 100644 backend/implementations/apprise_parser.py create mode 100644 backend/implementations/notification_services.py create mode 100644 backend/implementations/users.py create mode 100644 backend/internals/db.py create mode 100644 backend/internals/db_migration.py create mode 100644 backend/internals/db_models.py create mode 100644 backend/internals/server.py create mode 100644 backend/internals/settings.py delete mode 100644 backend/logging.py delete mode 100644 backend/notification_service.py delete mode 100644 backend/reminders.py delete mode 100644 backend/security.py delete mode 100644 backend/server.py delete mode 100644 backend/settings.py delete mode 100644 backend/static_reminders.py delete mode 100644 backend/templates.py delete mode 100644 backend/users.py delete mode 100644 project_management/docs-requirements.txt create mode 100644 project_management/requirements-docs.txt create mode 100644 pyproject.toml create mode 100644 requirements-dev.txt delete mode 100644 tests/MIND_test.py create mode 100644 tests/Tbackend/MIND_test.py create mode 100644 tests/Tbackend/__init__.py create mode 100644 tests/Tbackend/custom_exceptions_test.py create mode 100644 tests/Tbackend/db_test.py create mode 100644 tests/Tbackend/reminders_test.py create mode 100644 tests/Tbackend/security_test.py create mode 100644 tests/Tbackend/ui_test.py create mode 100644 tests/Tbackend/users_test.py delete mode 100644 tests/api_test.py delete mode 100644 tests/custom_exceptions_test.py delete mode 100644 tests/db_test.py delete mode 100644 tests/reminders_test.py delete mode 100644 tests/security_test.py delete mode 100644 tests/ui_test.py delete mode 100644 tests/users_test.py diff --git a/.dockerignore b/.dockerignore index f90f8f0..ba35e7f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -129,15 +129,14 @@ dmypy.json .pyre/ # Database -**/*.db -**/*.db-shm -**/*.db-wal +db/ # VS code *.code-workspace .vscode/ # Docker +Dockerfile .dockerignore docker-compose.yml @@ -154,8 +153,6 @@ LICENSE tests/ # Project management files -release.sh docs/ -docs-requirements.txt -mkdocs.yml -generate_api_docs.py +project_management/ +requirements-dev.txt diff --git a/.github/workflows/build_docs.yml b/.github/workflows/build_docs.yml index 97c59b8..3b3b3e5 100644 --- a/.github/workflows/build_docs.yml +++ b/.github/workflows/build_docs.yml @@ -17,7 +17,7 @@ jobs: with: python-version: 3.8 cache: 'pip' - - run: pip install -r requirements.txt -r project_management/docs-requirements.txt + - run: pip install -r requirements.txt -r project_management/requirements-docs.txt name: Install dependencies - run: python3 project_management/generate_api_docs.py name: Generate API docs diff --git a/.gitignore b/.gitignore index 033e14a..c4b537e 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,6 @@ dist/ downloads/ eggs/ .eggs/ -lib/ lib64/ parts/ sdist/ @@ -57,6 +56,7 @@ coverage.xml # Django stuff: *.log +*.log.* local_settings.py db.sqlite3 db.sqlite3-journal @@ -137,4 +137,4 @@ dmypy.json *.code-workspace # Project management files -release.sh +release*.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..0f8749a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,52 @@ +repos: +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort + additional_dependencies: [ + apprise ~= 1.4, + python-dateutil ~= 2.8, + Flask ~= 3.0, + waitress ~= 2.1 + ] + +- repo: local + hooks: + - id: mypy + name: mypy + language: python + pass_filenames: false + additional_dependencies: [ + mypy ~= 1.10, + + apprise ~= 1.4, + python-dateutil ~= 2.8, + Flask ~= 3.0, + waitress ~= 2.1 + ] + entry: python -m mypy --explicit-package-bases . + + - id: unittest + name: unittest + language: python + pass_filenames: false + additional_dependencies: [ + apprise ~= 1.4, + python-dateutil ~= 2.8, + Flask ~= 3.0, + waitress ~= 2.1 + ] + entry: python -m unittest discover -s ./tests -p '*.py' + +- repo: https://github.com/hhatto/autopep8 + rev: v2.2.0 + hooks: + - id: autopep8 + name: autopep8 + additional_dependencies: [ + apprise ~= 1.4, + python-dateutil ~= 2.8, + Flask ~= 3.0, + waitress ~= 2.1 + ] diff --git a/.vscode/settings.json b/.vscode/settings.json index 88b7c9f..7196aa8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,13 +1,49 @@ { - "python.testing.unittestArgs": [ - "-v", - "-s", - "./tests", - "-p", - "*_test.py" + "editor.insertSpaces": true, + "editor.tabSize": 4, + + "[python]": { + "editor.formatOnSave": true, + "editor.defaultFormatter": "ms-python.autopep8", + "editor.codeActionsOnSave": { + "source.organizeImports": "always", + } + }, + "isort.check": true, + "isort.severity": { + "W": "Warning", + "E": "Warning" + }, + "isort.args": [ + "--jobs", "-1" ], - "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true, - "python.analysis.autoImportCompletions": true, - "python.analysis.typeCheckingMode": "off" + "python.testing.unittestArgs": [ + "-s", "./tests", + "-p", "*.py" + ], + + "python.analysis.typeCheckingMode": "standard", + "python.analysis.diagnosticMode": "workspace", + + "mypy-type-checker.reportingScope": "workspace", + "mypy-type-checker.preferDaemon": false, + "mypy-type-checker.args": [ + "--explicit-package-bases" + ], + + "cSpell.words": [ + "behaviour", + "customisable", + "customised", + "noqa", + "traceback" + ], + "cSpell.languageSettings": [ + { + "languageId": "log", + "enabled": false + } + ] } \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..bbe2380 --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,18 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Format All", + "type": "shell", + "command": "python3 -m isort .; python3 -m autopep8 --in-place -r .", + "windows": { + "command": "python -m isort .; python -m autopep8 --in-place -r ." + }, + "group": "build", + "presentation": { + "reveal": "silent", + "clear": true + } + } + ] +} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b3380c5..6311469 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,32 +2,63 @@ ## General steps Contributing to MIND consists of 5 steps, listed hereunder. -1. Make a [contributing request](https://github.com/Casvt/MIND/issues/new?template=contribute-request.md), where you describe what you plan on doing. This request needs to get approved before you can start, or your pull request won't be accepted. This is to avoid multiple people from doing the same thing and to avoid you wasting your time if we do not wish the changes. This is also where discussions can be held about how something will be implemented. -2. When the request is accepted, start your local development (more info about this below). -3. When done, create a pull request to the Development branch, where you mention again what you've changed/added and give a link to the original contributing request issue. -4. The PR will be reviewed and if requested, changes will need to be made before it is accepted. +1. Make a [contributing request](https://github.com/Casvt/MIND/issues/new?template=contribute-request.md), where you describe what you plan on doing. _This request needs to get approved before you can start._ The contributing request has multiple uses: + 1. Avoid multiple people working on the same thing. + 2. Avoid you wasting your time on changes that we do not wish for. + 3. If needed, have discussions about how something will be implemented. + 4. A place for contact, be it questions, status updates or something else. +2. When the request is accepted, start your local development (more info on this below). +3. When done, create a pull request to the Development branch, where you quickly mention what has changed and give a link to the original contributing request issue. +4. The PR will be reviewed. Changes might need to be made in order for it to be merged. 5. When everything is okay, the PR will be accepted and you'll be done! -## Local development steps -Once your request is accepted, you can start your local development. +## Local development -1. Clone the repository onto your computer and open it using your preferred IDE (Visual Studio Code is used by us). -2. Make the changes needed and write accompanying tests if needed. -3. Check if the code written follows the styling guide below. -4. Run the finished version, using python 3.8, to check if you've made any errors. -5. Run the tests (unittest is used). This can be done with a button click within VS Code, or with the following command where you need to be inside the root folder of the project: +Once your contribution request has been accepted, you can start your local development. + +### IDE + +It's up to you how you make the changes, but we use Visual Studio Code as the IDE. A workspace settings file is included that takes care of some styling, testing and formatting of the backend code. + +1. The vs code extension `ms-python.vscode-pylance` in combination with the settings file with enable type checking. +2. The vs code extension `ms-python.mypy-type-checker` in combination with the settings file will enable mypy checking. +3. The vs code extension `ms-python.autopep8` in combination with the settings file will format code on save. +4. The vs code extension `ms-python.isort` in combination with the settings file will sort the import statements on save. +5. The settings file sets up the testing suite in VS Code such that you can just click the test button to run all tests. + +If you do not use VS Code with the mentioned extensions, then below are some commands that you can manually run in the base directory to achieve similar results. + +1. **Mypy**: +```bash +mypy --explicit-package-bases . +``` +2. **autopep8**: +```bash +autopep8 --recursive --in-place . +``` +3. **isort**: +```bash +isort . +``` +4. **unittest** ```bash python3 -m unittest discover -s ./tests -p '*.py' ``` -6. Test your version thoroughly to catch as many bugs as possible (if any). -## Styling guide -The code of MIND is written in such way that it follows the following rules. Your code should too. +### Strict rules -1. Compatible with python 3.8 to 3.11 . -2. Tabs (4 space size) are used for indentation. -3. Use type hints as much as possible. If you encounter an import loop because something needs to be imported for type hinting, utilise [`typing.TYPE_CHECKING`](https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING). -4. Each function in the backend needs a doc string describing the function, what the inputs are, what errors could be raised from within the function and what the output is. -5. The imports need to be sorted (the extension `isort` is used in VS Code). -6. The code needs to be compatible with Linux, MacOS, Windows and the Docker container. -7. The code should, though not strictly enforced, reasonably comply with the rule of 80 characters per line. +There are a few conditions that should always be met: + +1. MIND should support Python version 3.8 and higher. +2. MIND should be compatible with Linux, MacOS, Windows and the Docker container. +3. The tests should all pass. + +### Styling guide + +Following the styling guide for the backend code is not a strict rule, but effort should be put in to conform to it as much as possible. Running autopep8 and isort handles most of this. + +1. Indentation is done with 4 spaces. Not using tabs. +2. Use type hints as much as possible. If you encounter an import loop because something needs to be imported for type hinting, utilise [`typing.TYPE_CHECKING`](https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING). +3. A function in the backend needs a doc string describing the function, what the inputs are, what errors could be raised from within the function and what the output is. +4. The imports need to be sorted. +5. The code should, though not strictly enforced, reasonably comply with the rule of 80 characters per line. diff --git a/Dockerfile b/Dockerfile index 5162770..cc6aa5a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ STOPSIGNAL SIGINT WORKDIR /app COPY requirements.txt requirements.txt -RUN pip3 install -r requirements.txt +RUN pip3 install --no-cache-dir -r requirements.txt COPY . . diff --git a/MIND.py b/MIND.py index 0a1245c..e55116b 100644 --- a/MIND.py +++ b/MIND.py @@ -1,65 +1,283 @@ #!/usr/bin/env python3 -#-*- coding: utf-8 -*- +# -*- coding: utf-8 -*- -""" -The main file where MIND is started from -""" +from argparse import ArgumentParser +from atexit import register +from os import environ, name +from signal import SIGINT, SIGTERM, signal +from subprocess import Popen +from sys import argv, exit +from typing import NoReturn, Union -from sys import argv +from backend.base.custom_exceptions import InvalidKeyValue +from backend.base.definitions import Constants, StartType +from backend.base.helpers import check_python_version, get_python_exe +from backend.base.logging import LOGGER, setup_logging +from backend.features.reminder_handler import ReminderHandler +from backend.internals.db import close_all_db, set_db_location, setup_db +from backend.internals.server import Server, handle_start_type +from backend.internals.settings import Settings -from backend.db import setup_db, setup_db_location -from backend.helpers import check_python_version -from backend.logging import LOGGER, setup_logging -from backend.reminders import ReminderHandler -from backend.server import SERVER, handle_flags -from backend.settings import get_setting -#============================= -# WARNING: -# These settings have moved into the admin panel. Their current value has been -# taken over. The values will from now on be ignored, and the variables will -# be deleted next version. -HOST = '0.0.0.0' -PORT = '8080' -URL_PREFIX = '' # Must either be empty or start with '/' e.g. '/mind' -#============================= +def _main( + start_type: StartType, + db_folder: Union[str, None] = None, + log_folder: Union[str, None] = None, + host: Union[str, None] = None, + port: Union[int, None] = None, + url_prefix: Union[str, None] = None +) -> NoReturn: + """The main function of the MIND sub-process -def MIND() -> None: - """The main function of MIND - """ - setup_logging() - LOGGER.info('Starting up MIND') + Args: + start_type (StartType): The type of (re)start. + db_folder (Union[str, None], optional): The folder in which the database + will be stored or in which a database is for MIND to use. + Defaults to None. + log_folder (Union[str, None], optional): The folder in which the logs + from MIND will be stored. + Defaults to None. + host (Union[str, None], optional): The host to bind the server to. + Defaults to None. + port (Union[int, None], optional): The port to bind the server to. + Defaults to None. + url_prefix (Union[str, None], optional): The URL prefix to use for the + server. + Defaults to None. - if not check_python_version(): - exit(1) + Raises: + ValueError: One of the arguments has an invalid value. - flag = argv[1] if len(argv) > 1 else None - handle_flags(flag) + Returns: + NoReturn: Exit code 0 means to shutdown. + Exit code 131 or higher means to restart with possibly special reasons. + """ + setup_logging(log_folder) + LOGGER.info('Starting up MIND') - setup_db_location() + if not check_python_version(): + exit(1) - SERVER.create_app() - reminder_handler = ReminderHandler(SERVER.app.app_context) - with SERVER.app.app_context(): - setup_db() + set_db_location(db_folder) - host = get_setting("host") - port = get_setting("port") - url_prefix = get_setting("url_prefix") - SERVER.set_url_prefix(url_prefix) + SERVER = Server() + SERVER.create_app() + with SERVER.app.app_context(): + handle_start_type(start_type) + setup_db() - reminder_handler.find_next_reminder() + s = Settings() - # ================= - SERVER.run(host, port) - # ================= + if host: + try: + s.update({"host": host}) + except InvalidKeyValue: + raise ValueError("Invalid host value") - reminder_handler.stop_handling() + if port: + try: + s.update({"port": port}) + except InvalidKeyValue: + raise ValueError("Invalid port value") - if SERVER.do_restart: - SERVER.handle_restart(flag) + if url_prefix: + try: + s.update({"url_prefix": url_prefix}) + except InvalidKeyValue: + raise ValueError("Invalid url prefix value") + + settings = s.get_settings() + SERVER.set_url_prefix(settings.url_prefix) + + reminder_handler = ReminderHandler() + reminder_handler.find_next_reminder() + + try: + # ================= + SERVER.run(settings.host, settings.port) + # ================= + + finally: + reminder_handler.stop_handling() + # close_all_db() + + if SERVER.start_type is not None: + LOGGER.info("Restarting MIND") + exit(SERVER.start_type.value) + + exit(0) + + +def _stop_sub_process(proc: Popen) -> None: + """Gracefully stop the sub-process unless that fails. Then terminate it. + + Args: + proc (Popen): The sub-process to stop. + """ + if proc.returncode is not None: + return + + try: + if name != 'nt': + try: + proc.send_signal(SIGINT) + except ProcessLookupError: + pass + else: + import win32api # type: ignore + import win32con # type: ignore + try: + win32api.GenerateConsoleCtrlEvent( + win32con.CTRL_C_EVENT, proc.pid + ) + except KeyboardInterrupt: + pass + except BaseException: + proc.terminate() + + +def _run_sub_process( + start_type: StartType = StartType.STARTUP +) -> int: + """Start the sub-process that MIND will be run in. + + Args: + start_type (StartType, optional): Why MIND was started. + Defaults to `StartType.STARTUP`. + + Returns: + int: The return code from the sub-process. + """ + env = { + **environ, + "MIND_RUN_MAIN": "1", + "MIND_START_TYPE": str(start_type.value) + } + + comm = [get_python_exe(), "-u", __file__] + argv[1:] + proc = Popen( + comm, + env=env + ) + proc._sigint_wait_secs = Constants.SUB_PROCESS_TIMEOUT # type: ignore + register(_stop_sub_process, proc=proc) + signal(SIGTERM, lambda signal_no, frame: _stop_sub_process(proc)) + + try: + return proc.wait() + except (KeyboardInterrupt, SystemExit, ChildProcessError): + return 0 + + +def MIND() -> int: + """The main function of MIND. + + Returns: + int: The return code. + """ + rc = StartType.STARTUP.value + while rc in StartType._member_map_.values(): + rc = _run_sub_process( + StartType(rc) + ) + + return rc - return if __name__ == "__main__": - MIND() + if environ.get("MIND_RUN_MAIN") == "1": + + parser = ArgumentParser( + description="MIND is a simple self hosted reminder application that can send push notifications to your device. Set the reminder and forget about it!") + + fs = parser.add_argument_group(title="Folders") + fs.add_argument( + '-d', '--DatabaseFolder', + type=str, + help="The folder in which the database will be stored or in which a database is for MIND to use" + ) + fs.add_argument( + '-l', '--LogFolder', + type=str, + help="The folder in which the logs from MIND will be stored" + ) + + hs = parser.add_argument_group(title="Hosting Settings") + hs.add_argument( + '-o', '--Host', + type=str, + help="The host to bind the server to" + ) + hs.add_argument( + '-p', '--Port', + type=int, + help="The port to bind the server to" + ) + hs.add_argument( + '-u', '--UrlPrefix', + type=str, + help="The URL prefix to use for the server" + ) + + args = parser.parse_args() + + st = StartType(int(environ.get( + "MIND_START_TYPE", + StartType.STARTUP.value + ))) + + db_folder: Union[str, None] = args.DatabaseFolder + log_folder: Union[str, None] = args.LogFolder + host: Union[str, None] = None + port: Union[int, None] = None + url_prefix: Union[str, None] = None + if st == StartType.STARTUP: + host = args.Host + port = args.Port + url_prefix = args.UrlPrefix + + try: + _main( + start_type=st, + db_folder=db_folder, + log_folder=log_folder, + host=host, + port=port, + url_prefix=url_prefix + ) + + except ValueError as e: + if not e.args: + raise e + + elif e.args[0] == 'Database location is not a folder': + parser.error( + 'The value for -d/--DatabaseFolder is not a folder' + ) + + elif e.args[0] == 'Logging folder is not a folder': + parser.error( + 'The value for -l/--LogFolder is not a folder' + ) + + elif e.args[0] == 'Invalid host value': + parser.error( + 'The value for -h/--Host is not valid' + ) + + elif e.args[0] == 'Invalid port value': + parser.error( + 'The value for -p/--Port is not valid' + ) + + elif e.args[0] == 'Invalid url prefix value': + parser.error( + 'The value for -u/--UrlPrefix is not valid' + ) + + else: + raise e + + else: + rc = MIND() + exit(rc) diff --git a/README.md b/README.md index 984f7eb..0ef67aa 100644 --- a/README.md +++ b/README.md @@ -1,45 +1,38 @@ +

+ MIND +

+

+ + + +

+ # MIND -[![Docker Pulls](https://img.shields.io/docker/pulls/mrcas/mind.svg)](https://hub.docker.com/r/mrcas/mind) +MIND is a simple self hosted reminder application that can send push notifications to your device. Set the reminder and forget about it! -__A simple self hosted reminder application that can send push notifications to your device. Set the reminder and forget about it!__ - -Mind is a simple self hosted application for creating reminders that get pushed to your device using the [Apprise](https://github.com/caronc/apprise) API. You can send messages to just about every platform, including scheduled emails! - -## Workings - -MIND can be used for sending notifications at the desired time. This can be a set time, like a yearly reminder for a birthday, or at a button click, to easily send a predefined notification when you want to. The notification can be sent to 80+ platforms with the integration of [Apprise](https://github.com/caronc/apprise). +MIND allows you to set reminders for a given time. They can run just once, or be repeated at a given interval. Or use static reminders to send notifications by pressing a button. Whether you want to remind yourself of a meeting, set yearly repeating reminders for birthdays or be able to notify your family that you're late with the press of a button, it's all possible with MIND! The notifications can be sent using over 100+ platforms with the integration of [Apprise](https://github.com/caronc/apprise). If you want to send a notification, Apprise probably supports it. ## Features -- Works cross-timezone - - Notifications are sent with second-precision - -- Fine control over repetition: single time, time interval, certain weekdays or manual trigger. - -- Uses the [Apprise library](https://github.com/caronc/apprise), giving you 80+ platforms to send notifications to and the option to send to multiple platforms for each reminder - -- Easily manage the reminders with sorting options, search ability and color coding - +- Fine control over repetition: single time, time interval, certain weekdays or manual trigger +- Uses the [Apprise library](https://github.com/caronc/apprise), giving you 100+ platforms to send notifications to and the option to send to multiple platforms for each reminder +- Works cross-timezone +- Easily manage the reminders with sorting options, color coding and search - An admin panel for user management, settings and backups - -- Docker image available - +- Support for all major OS'es and Docker image available - Mobile friendly web-interface - - API available ## Installation, support and documentation -- For instructions on how to install MIND, see the [installation documentation](https://casvt.github.io/MIND/installation/installation) - -- For support, a [discord server](https://discord.gg/nMNdgG7vsE) is available or [make an issue](https://github.com/Casvt/MIND/issues) - +- For instructions on how to install MIND, see the [installation documentation](https://casvt.github.io/MIND/installation/installation). +- For support, a [discord server](https://discord.gg/nMNdgG7vsE) is available or [make an issue](https://github.com/Casvt/MIND/issues). - For all documentation, see the [documentation hub](https://casvt.github.io/MIND). ## Screenshots - - - +![](https://github.com/Casvt/Kapowarr/assets/88994465/f55c895b-7975-4a3e-88a0-f8e2a148bf8a) +![](https://github.com/Casvt/Kapowarr/assets/88994465/63d72943-0c88-4315-9a8a-01a5dc5f6f15) +![](https://github.com/Casvt/Kapowarr/assets/88994465/1f9cc9a2-ced5-49a2-b779-93528bb50bd4) diff --git a/backend/base/custom_exceptions.py b/backend/base/custom_exceptions.py new file mode 100644 index 0000000..fb3a354 --- /dev/null +++ b/backend/base/custom_exceptions.py @@ -0,0 +1,428 @@ +# -*- coding: utf-8 -*- + +from typing import Any, Union + +from backend.base.definitions import (ApiResponse, InvalidUsernameReason, + MindException) +from backend.base.logging import LOGGER + + +# region Input/Output +class KeyNotFound(MindException): + "A key was not found in the input that is required to be given." + + def __init__(self, key: str) -> None: + self.key = key + LOGGER.warning( + "This key was not found in the API request," + " eventhough it's required: %s", + key + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 400, + 'error': self.__class__.__name__, + 'result': { + 'key': self.key + } + } + + +class InvalidKeyValue(MindException): + "The value of a key is invalid." + + def __init__(self, key: str, value: Any) -> None: + self.key = key + self.value = value + LOGGER.warning( + "This key in the API request has an invalid value: " + "%s = %", + key, value + ) + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 400, + 'error': self.__class__.__name__, + 'result': { + 'key': self.key, + 'value': self.value + } + } + + +# region Auth +class AccessUnauthorized(MindException): + "The password given is not correct" + + def __init__(self) -> None: + LOGGER.warning( + "The password given is not correct" + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 401, + 'error': self.__class__.__name__, + 'result': {} + } + + +class APIKeyInvalid(MindException): + "The API key is not correct" + + def __init__(self, api_key: str) -> None: + self.api_key = api_key + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 401, + 'error': self.__class__.__name__, + 'result': { + 'api_key': self.api_key + } + } + + +class APIKeyExpired(MindException): + "The API key has expired" + + def __init__(self, api_key: str) -> None: + self.api_key = api_key + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 401, + 'error': self.__class__.__name__, + 'result': { + 'api_key': self.api_key + } + } + + +# region Admin Operations +class OperationNotAllowed(MindException): + "What was requested to be done is not allowed" + + def __init__(self, operation: str) -> None: + LOGGER.warning( + "Operation not allowed: %s", + operation + ) + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 403, + 'error': self.__class__.__name__, + 'result': {} + } + + +class NewAccountsNotAllowed(MindException): + "It's not allowed to create a new account except for the admin" + + def __init__(self) -> None: + LOGGER.warning( + "The creation of a new account was attempted but it's disabled by the admin" + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 403, + 'error': self.__class__.__name__, + 'result': {} + } + + +class InvalidDatabaseFile(MindException): + "The uploaded database file is invalid or not supported" + + def __init__(self, filepath_db: str) -> None: + self.filepath_db = filepath_db + LOGGER.warning( + "The given database file is invalid: %s", + filepath_db + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 400, + 'error': self.__class__.__name__, + 'result': { + 'filepath_db': self.filepath_db + } + } + + +class LogFileNotFound(MindException): + "The log file was not found" + + def __init__(self, log_file: str) -> None: + self.log_file = log_file + LOGGER.warning( + "The log file was not found: %s", + log_file + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 404, + 'error': self.__class__.__name__, + 'result': { + 'log_file': self.log_file + } + } + + +# region Users +class UsernameTaken(MindException): + "The username is already taken" + + def __init__(self, username: str) -> None: + self.username = username + LOGGER.warning( + "The username is already taken: %s", + username + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 400, + 'error': self.__class__.__name__, + 'result': { + 'username': self.username + } + } + + +class UsernameInvalid(MindException): + "The username contains invalid characters or is not allowed" + + def __init__( + self, + username: str, + reason: InvalidUsernameReason + ) -> None: + self.username = username + self.reason = reason + LOGGER.warning( + "The username '%s' is invalid for the following reason: %s", + username, reason.value + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 400, + 'error': self.__class__.__name__, + 'result': { + 'username': self.username, + 'reason': self.reason.value + } + } + + +class UserNotFound(MindException): + "The user requested can not be found" + + def __init__( + self, + username: Union[str, None], + user_id: Union[int, None] + ) -> None: + self.username = username + self.user_id = user_id + if username: + LOGGER.warning( + "The user can not be found: %s", + username + ) + + elif user_id: + LOGGER.warning( + "The user can not be found: ID %d", + user_id + ) + + else: + LOGGER.warning( + "The user can not be found" + ) + + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 404, + 'error': self.__class__.__name__, + 'result': { + 'username': self.username, + 'user_id': self.user_id + } + } + + +# region Notification Services +class NotificationServiceNotFound(MindException): + "The notification service was not found" + + def __init__(self, notification_service_id: int) -> None: + self.notification_service_id = notification_service_id + LOGGER.warning( + "The notification service with the given ID cannot be found: %d", + notification_service_id + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 404, + 'error': self.__class__.__name__, + 'result': { + 'notification_service_id': self.notification_service_id + } + } + + +class NotificationServiceInUse(MindException): + """ + The notification service is wished to be deleted + but a reminder is still using it + """ + + def __init__( + self, + notification_service_id: int, + reminder_type: str + ) -> None: + self.notification_service_id = notification_service_id + self.reminder_type = reminder_type + LOGGER.warning( + "The notification service with ID %d is wished to be deleted " + "but a reminder of type %s is still using it", + notification_service_id, + reminder_type + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 404, + 'error': self.__class__.__name__, + 'result': { + 'notification_service_id': self.notification_service_id, + 'reminder_type': self.reminder_type + } + } + + +class URLInvalid(MindException): + "The Apprise URL is invalid" + + def __init__(self, url: str) -> None: + self.url = url + LOGGER.warning( + "The Apprise URL given is invalid: %s", + url + ) + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 400, + 'error': self.__class__.__name__, + 'result': { + 'url': self.url + } + } + + +# region Templates +class TemplateNotFound(MindException): + "The template was not found" + + def __init__(self, template_id: int) -> None: + self.template_id = template_id + LOGGER.warning( + "The template with the given ID cannot be found: %d", + template_id + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 404, + 'error': self.__class__.__name__, + 'result': { + 'template_id': self.template_id + } + } + + +# region Reminders +class ReminderNotFound(MindException): + "The reminder was not found" + + def __init__(self, reminder_id: int) -> None: + self.reminder_id = reminder_id + LOGGER.warning( + "The reminder with the given ID cannot be found: %d", + reminder_id + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 404, + 'error': self.__class__.__name__, + 'result': { + 'reminder_id': self.reminder_id + } + } + + +class InvalidTime(MindException): + "The time given is in the past" + + def __init__(self, time: int) -> None: + self.time = time + LOGGER.warning( + "The given time is invalid: %d", + time + ) + return + + @property + def api_response(self) -> ApiResponse: + return { + 'code': 400, + 'error': self.__class__.__name__, + 'result': { + 'time': self.time + } + } diff --git a/backend/base/definitions.py b/backend/base/definitions.py new file mode 100644 index 0000000..2c3783f --- /dev/null +++ b/backend/base/definitions.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- + +""" +Definitions of basic types, abstract classes, enums, etc. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, + Tuple, Type, TypedDict, TypeVar, Union, cast) + +if TYPE_CHECKING: + from backend.implementations.users import User + + +# region Types +T = TypeVar('T') +U = TypeVar('U') +WEEKDAY_NUMBER = Literal[0, 1, 2, 3, 4, 5, 6] +BaseSerialisable = Union[ + int, float, bool, str, None +] +Serialisable = Union[ + List[Union[ + BaseSerialisable, + List[BaseSerialisable], + Dict[str, BaseSerialisable] + ]], + Dict[str, Union[ + BaseSerialisable, + List[BaseSerialisable], + Dict[str, BaseSerialisable] + ]], +] + + +# region Constants +class Constants: + SUB_PROCESS_TIMEOUT = 20.0 # seconds + + HOSTING_THREADS = 10 + HOSTING_REVERT_TIME = 60.0 # seconds + + DB_FOLDER = ("db",) + DB_NAME = "MIND.db" + DB_ORIGINAL_NAME = 'MIND_original.db' + DB_TIMEOUT = 10.0 # seconds + DB_REVERT_TIME = 60.0 # seconds + + LOGGER_NAME = "MIND" + LOGGER_FILENAME = "MIND.log" + + ADMIN_USERNAME = "admin" + ADMIN_PASSWORD = "admin" + INVALID_USERNAMES = ("reminders", "api") + USERNAME_CHARACTERS = 'abcedfghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!@$' + + CONNECTION_ERROR_TIMEOUT = 120 # seconds + + APPRISE_TEST_TITLE = "MIND: Test title" + APPRISE_TEST_BODY = "MIND: Test body" + + +# region Enums +class BaseEnum(Enum): + def __eq__(self, other) -> bool: + return self.value == other + + def __hash__(self) -> int: + return id(self.value) + + +class StartType(BaseEnum): + STARTUP = 130 + RESTART = 131 + RESTART_HOSTING_CHANGES = 132 + RESTART_DB_CHANGES = 133 + + +class InvalidUsernameReason(BaseEnum): + ONLY_NUMBERS = "Username can not only be numbers" + NOT_ALLOWED = "Username is not allowed" + INVALID_CHARACTER = "Username contains an invalid character" + + +class SendResult(BaseEnum): + SUCCESS = "Success" + CONNECTION_ERROR = "Connection error" + SYNTAX_INVALID_URL = "Syntax of URL invalid" + REJECTED_URL = "Values in URL rejected by service (e.g. invalid API token)" + + +class ReminderType(BaseEnum): + REMINDER = "Reminder" + STATIC_REMINDER = "Static Reminder" + TEMPLATE = "Template" + + +class RepeatQuantity(BaseEnum): + YEARS = "years" + MONTHS = "months" + WEEKS = "weeks" + DAYS = "days" + HOURS = "hours" + MINUTES = "minutes" + + +def sort_by_timeless_title(r: GeneralReminderData) -> Tuple[str, str, str]: + return (r.title, r.text or '', r.color or '') + + +def sort_by_time(r: ReminderData) -> Tuple[int, str, str, str]: + return (r.time, r.title, r.text or '', r.color or '') + + +def sort_by_timed_title(r: ReminderData) -> Tuple[str, int, str, str]: + return (r.title, r.time, r.text or '', r.color or '') + + +def sort_by_id(r: GeneralReminderData) -> int: + return r.id + + +class TimelessSortingMethod(BaseEnum): + TITLE = sort_by_timeless_title, False + TITLE_REVERSED = sort_by_timeless_title, True + DATE_ADDED = sort_by_id, False + DATE_ADDED_REVERSED = sort_by_id, True + + +class SortingMethod(BaseEnum): + TIME = sort_by_time, False + TIME_REVERSED = sort_by_time, True + TITLE = sort_by_timed_title, False + TITLE_REVERSED = sort_by_timed_title, True + DATE_ADDED = sort_by_id, False + DATE_ADDED_REVERSED = sort_by_id, True + + +class DataType(BaseEnum): + STR = 'string' + INT = 'number' + FLOAT = 'decimal number' + BOOL = 'bool' + INT_ARRAY = 'list of numbers' + NA = 'N/A' + + +class DataSource(BaseEnum): + DATA = 1 + VALUES = 2 + FILES = 3 + + +# region TypedDicts +class ApiResponse(TypedDict): + result: Any + error: Union[str, None] + code: int + + +# region Abstract Classes +class DBMigrator(ABC): + start_version: int + + @abstractmethod + def run(self) -> None: + ... + + +class MindException(Exception, ABC): + """An exception specific to MIND""" + + @property + @abstractmethod + def api_response(self) -> ApiResponse: + ... + + +# region Dataclasses +@dataclass +class ApiKeyEntry: + exp: int + user_data: User + + +def _return_exceptions() -> List[Type[MindException]]: + from backend.base.custom_exceptions import InvalidKeyValue, KeyNotFound + return [KeyNotFound, InvalidKeyValue] + + +@dataclass +class InputVariable(ABC): + value: Any + name: str + description: str + required: bool = True + default: Any = None + data_type: List[DataType] = field(default_factory=lambda: [DataType.STR]) + source: DataSource = DataSource.DATA + related_exceptions: List[Type[MindException]] = field( + default_factory=_return_exceptions + ) + + def validate(self) -> bool: + return isinstance(self.value, str) and bool(self.value) + + +@dataclass(frozen=True) +class Method: + description: str = '' + vars: List[Type[InputVariable]] = field(default_factory=list) + + +@dataclass(frozen=True) +class Methods: + get: Union[Method, None] = None + post: Union[Method, None] = None + put: Union[Method, None] = None + delete: Union[Method, None] = None + + def __getitem__(self, key: str) -> Union[Method, None]: + return getattr(self, key.lower()) + + def used_methods(self) -> List[str]: + result = [] + for method in ('get', 'post', 'put', 'delete'): + if getattr(self, method) is not None: + result.append(method) + return result + + +@dataclass(frozen=True) +class ApiDocEntry: + endpoint: str + description: str + methods: Methods + requires_auth: bool + + +@dataclass(frozen=True, order=True) +class NotificationServiceData: + id: int + title: str + url: str + + def todict(self) -> Dict[str, Any]: + return self.__dict__ + + +@dataclass(frozen=True, order=True) +class UserData: + id: int + username: str + admin: bool + salt: bytes + hash: bytes + + def todict(self) -> Dict[str, Any]: + return { + k: v + for k, v in self.__dict__.items() + if k in ('id', 'username', 'admin') + } + + +@dataclass(order=True) +class GeneralReminderData: + id: int + title: str + text: Union[str, None] + color: Union[str, None] + notification_services: List[int] + + def todict(self) -> Dict[str, Any]: + return self.__dict__ + + +@dataclass(order=True) +class TemplateData(GeneralReminderData): + ... + + +@dataclass(order=True) +class StaticReminderData(GeneralReminderData): + ... + + +@dataclass(order=True) +class ReminderData(GeneralReminderData): + time: int + original_time: Union[int, None] + repeat_quantity: Union[str, None] + repeat_interval: Union[int, None] + _weekdays: Union[str, None] + + def __post_init__(self) -> None: + if self._weekdays is not None: + self.weekdays: Union[List[WEEKDAY_NUMBER], None] = [ + cast(WEEKDAY_NUMBER, int(n)) + for n in self._weekdays.split(',') + if n + ] + else: + self.weekdays = None + + def todict(self) -> Dict[str, Any]: + return { + k: v + for k, v in self.__dict__.items() + if k != '_weekdays' + } diff --git a/backend/base/helpers.py b/backend/base/helpers.py new file mode 100644 index 0000000..36da172 --- /dev/null +++ b/backend/base/helpers.py @@ -0,0 +1,394 @@ +# -*- coding: utf-8 -*- + +""" +General "helper" function and classes +""" + +from base64 import urlsafe_b64encode +from datetime import datetime +from hashlib import pbkdf2_hmac +from logging import WARNING +from os import makedirs, symlink +from os.path import abspath, dirname, exists, join +from secrets import token_bytes +from shutil import copy2, move +from sys import base_exec_prefix, executable, platform, version_info +from typing import (Any, Callable, Generator, Iterable, + List, Sequence, Tuple, Union, cast) + +from apprise import Apprise, LogCapture +from dateutil.relativedelta import relativedelta + +from backend.base.definitions import (WEEKDAY_NUMBER, GeneralReminderData, + RepeatQuantity, SendResult, T, U) + + +def get_python_version() -> str: + """Get python version as string + + Returns: + str: The python version + """ + return ".".join( + str(i) for i in list(version_info) + ) + + +def check_python_version() -> bool: + """Check if the python version that is used is a minimum version. + + Returns: + bool: Whether or not the python version is version 3.8 or above or not. + """ + if not (version_info.major == 3 and version_info.minor >= 8): + from backend.base.logging import LOGGER + LOGGER.critical( + 'The minimum python version required is python3.8 ' + '(currently ' + str(version_info.major) + '.' + + str(version_info.minor) + '.' + str(version_info.micro) + ').' + ) + return False + return True + + +def get_python_exe() -> str: + """Get the path to the python executable. + + Returns: + str: The python executable path. + """ + if platform.startswith('darwin'): + bundle_path = join( + base_exec_prefix, + "Resources", + "Python.app", + "Contents", + "MacOS", + "Python" + ) + if exists(bundle_path): + from tempfile import mkdtemp + python_path = join(mkdtemp(), "python") + symlink(bundle_path, python_path) + + return python_path + + return executable + + +def reversed_tuples( + i: Iterable[Tuple[T, U]] +) -> Generator[Tuple[U, T], Any, Any]: + """Yield sub-tuples in reversed order. + + Args: + i (Iterable[Tuple[T, U]]): Iterator. + + Yields: + Generator[Tuple[U, T], Any, Any]: Sub-tuple with reversed order. + """ + for entry_1, entry_2 in i: + yield entry_2, entry_1 + + +def first_of_column( + columns: Iterable[Sequence[T]] +) -> List[T]: + """Get the first element of each sub-array. + + Args: + columns (Iterable[Sequence[T]]): List of + sub-arrays. + + Returns: + List[T]: List with first value of each sub-array. + """ + return [e[0] for e in columns] + + +def when_not_none( + value: Union[T, None], + to_run: Callable[[T], U] +) -> Union[U, None]: + """Run `to_run` with argument `value` iff `value is not None`. Else return + `None`. + + Args: + value (Union[T, None]): The value to check. + to_run (Callable[[T], U]): The function to run. + + Returns: + Union[U, None]: Either the return value of `to_run`, or `None`. + """ + if value is None: + return None + else: + return to_run(value) + + +def search_filter(query: str, result: GeneralReminderData) -> bool: + """Filter library results based on a query. + + Args: + query (str): The query to filter with. + result (GeneralReminderData): The library result to check. + + Returns: + bool: Whether or not the result passes the filter. + """ + query = query.lower() + return ( + query in result.title.lower() + or query in (result.text or '').lower() + ) + + +def get_hash(salt: bytes, data: str) -> bytes: + """Hash a string using the supplied salt + + Args: + salt (bytes): The salt to use when hashing + data (str): The data to hash + + Returns: + bytes: The b64 encoded hash of the supplied string + """ + return urlsafe_b64encode( + pbkdf2_hmac('sha256', data.encode(), salt, 100_000) + ) + + +def generate_salt_hash(password: str) -> Tuple[bytes, bytes]: + """Generate a salt and get the hash of the password + + Args: + password (str): The password to generate for + + Returns: + Tuple[bytes, bytes]: The salt (1) and hashed_password (2) + """ + salt = token_bytes() + hashed_password = get_hash(salt, password) + return salt, hashed_password + + +def send_apprise_notification( + urls: List[str], + title: str, + text: Union[str, None] = None +) -> SendResult: + """Send a notification to all Apprise URL's given. + + Args: + urls (List[str]): The Apprise URL's to send the notification to. + + title (str): The title of the notification. + + text (Union[str, None], optional): The optional body of the + notification. + Defaults to None. + + Returns: + SendResult: Whether or not it was successful. + """ + a = Apprise() + + for url in urls: + if not a.add(url): + return SendResult.SYNTAX_INVALID_URL + + with LogCapture(level=WARNING) as log: + result = a.notify( + title=title, + body=text or '\u200B' + ) + if not result: + if "socket exception" in log.getvalue(): # type: ignore + return SendResult.CONNECTION_ERROR + else: + return SendResult.REJECTED_URL + + return SendResult.SUCCESS + + +def next_selected_day( + weekdays: List[WEEKDAY_NUMBER], + weekday: WEEKDAY_NUMBER +) -> WEEKDAY_NUMBER: + """Find the next allowed day in the week. + + Args: + weekdays (List[WEEKDAY_NUMBER]): The days of the week that are allowed. + Monday is 0, Sunday is 6. + weekday (WEEKDAY_NUMBER): The current weekday. + + Returns: + WEEKDAY_NUMBER: The next allowed weekday. + """ + for d in weekdays: + if weekday < d: + return d + return weekdays[0] + + +def find_next_time( + original_time: int, + repeat_quantity: Union[RepeatQuantity, None], + repeat_interval: Union[int, None], + weekdays: Union[List[WEEKDAY_NUMBER], None] +) -> int: + """Calculate the next timestep based on original time and repeat/interval + values. + + Args: + original_time (int): The original time of the repeating timestamp. + + repeat_quantity (Union[RepeatQuantity, None]): If set, what the quantity + is of the repetition. + + repeat_interval (Union[int, None]): If set, the value of the repetition. + + weekdays (Union[List[WEEKDAY_NUMBER], None]): If set, on which days the + time can continue. Monday is 0, Sunday is 6. + + Returns: + int: The next timestamp in the future. + """ + if weekdays is not None: + weekdays.sort() + + current_time = datetime.fromtimestamp(datetime.utcnow().timestamp()) + original_datetime = datetime.fromtimestamp(original_time) + new_time = datetime.fromtimestamp(original_time) + + if ( + repeat_quantity is not None + and repeat_interval is not None + ): + # Add the interval to the original time until we are in the future. + # We need to multiply the interval and add it to the original time + # instead of just adding the interval once each time to the original + # time, because otherwise date jumping could happen. Say original time + # is a leap day with an interval of 1 year. Then next date would be the + # day before leap day, as leap day doesn't exist in the next year. But + # if we then keep adding 1 year to this time, we would keep getting the + # day before leap day, a year later. So we need to multiply the interval + # and add the whole interval to the original time in one go. This way + # after four years we will get the leap day again. + interval = relativedelta( + **{repeat_quantity.value: repeat_interval} # type: ignore + ) + multiplier = 1 + while new_time <= current_time: + new_time = original_datetime + (interval * multiplier) + multiplier += 1 + + elif weekdays is not None: + if ( + current_time.weekday() in weekdays + and current_time.time() < original_datetime.time() + ): + # Next reminder is later today, so target weekday is current weekday + weekday = current_time.weekday() + + else: + # Next reminder is not today or earlier today, so target weekday + # is next selected one + weekday = next_selected_day( + weekdays, + cast(WEEKDAY_NUMBER, current_time.weekday()) + ) + + new_time = current_time + relativedelta( + # Move to upcoming weekday (possibly today) + weekday=weekday, + # Also move current time to set time + hour=original_datetime.hour, + minute=original_datetime.minute, + second=original_datetime.second + ) + + result = int(new_time.timestamp()) + # LOGGER.debug( + # f'{original_datetime=}, {current_time=} ' + + # f'and interval of {repeat_interval} {repeat_quantity} ' + + # f'and weekdays {weekdays} ' + + # f'leads to {result}' + # ) + return result + + +def folder_path(*folders: str) -> str: + """Turn filepaths relative to the project folder into absolute paths. + + Returns: + str: The absolute filepath. + """ + return join( + dirname(dirname(dirname(abspath(__file__)))), + *folders + ) + + +def create_folder( + folder: str +) -> None: + """Create a folder, if it doesn't exist already. + + Args: + folder (str): The path to the folder to create. + """ + makedirs(folder, exist_ok=True) + return + + +def __copy2(src, dst, *, follow_symlinks=True): + try: + return copy2(src, dst, follow_symlinks=follow_symlinks) + + except PermissionError as pe: + if pe.errno == 1: + # NFS file system doesn't allow/support chmod. + # This is done after the file is already copied. So just accept that + # it isn't possible to change the permissions. Continue like normal. + return dst + + raise + + except OSError as oe: + if oe.errno == 524: + # NFS file system doesn't allow/support setting extended attributes. + # This is done after the file is already copied. So just accept that + # it isn't possible to set them. Continue like normal. + return dst + + raise + + +def rename_file( + before: str, + after: str +) -> None: + """Rename a file, taking care of new folder locations and + the possible complications with files on OS'es. + + Args: + before (str): The current filepath of the file. + after (str): The new desired filepath of the file. + """ + create_folder(dirname(after)) + + move(before, after, copy_function=__copy2) + + return + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args: Any, **kwargs: Any): + c = str(cls) + if c not in cls._instances: + cls._instances[c] = super().__call__(*args, **kwargs) + + return cls._instances[c] diff --git a/backend/base/logging.py b/backend/base/logging.py new file mode 100644 index 0000000..ff5a6a9 --- /dev/null +++ b/backend/base/logging.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- + +import logging +import logging.config +from os.path import exists, isdir, join +from typing import Any, Union + +from backend.base.definitions import Constants +from backend.base.helpers import create_folder, folder_path + + +class UpToInfoFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + return record.levelno <= logging.INFO + + +class ErrorColorFormatter(logging.Formatter): + def format(self, record: logging.LogRecord) -> Any: + result = super().format(record) + return f'\033[1;31:40m{result}\033[0m' + + +LOGGER = logging.getLogger(Constants.LOGGER_NAME) +LOGGING_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "simple": { + "format": "[%(asctime)s][%(levelname)s] %(message)s", + "datefmt": "%H:%M:%S" + }, + "simple_red": { + "()": ErrorColorFormatter, + "format": "[%(asctime)s][%(levelname)s] %(message)s", + "datefmt": "%H:%M:%S" + }, + "detailed": { + "format": "%(asctime)s | %(threadName)s | %(filename)sL%(lineno)s | %(levelname)s | %(message)s", + "datefmt": "%Y-%m-%dT%H:%M:%S%z", + } + }, + "filters": { + "up_to_info": { + "()": UpToInfoFilter + }, + }, + "handlers": { + "console_error": { + "class": "logging.StreamHandler", + "level": "WARNING", + "formatter": "simple_red", + "stream": "ext://sys.stderr" + }, + "console": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "simple", + "filters": ["up_to_info"], + "stream": "ext://sys.stdout" + }, + "file": { + "class": "logging.handlers.RotatingFileHandler", + "level": "DEBUG", + "formatter": "detailed", + "filename": "", + "maxBytes": 1_000_000, + "backupCount": 1 + } + }, + "loggers": { + Constants.LOGGER_NAME: {} + }, + "root": { + "level": "INFO", + "handlers": [ + "console", + "console_error", + "file" + ] + } +} + + +def setup_logging(log_folder: Union[str, None]) -> None: + """Setup the basic config of the logging module. + + Args: + log_folder (Union[str, None]): The folder to put the log file in. + If `None`, the log file will be in the same folder as the + application folder. + + Raises: + ValueError: The given log folder is not a folder. + """ + if log_folder: + if exists(log_folder) and not isdir(log_folder): + raise ValueError("Logging folder is not a folder") + + create_folder(log_folder) + + if log_folder is None: + LOGGING_CONFIG["handlers"]["file"]["filename"] = folder_path( + Constants.LOGGER_FILENAME + ) + else: + LOGGING_CONFIG["handlers"]["file"]["filename"] = join( + log_folder, + Constants.LOGGER_FILENAME + ) + + logging.config.dictConfig(LOGGING_CONFIG) + + # Log uncaught exceptions using the logger instead of printing the stderr + # Logger goes to stderr anyway, so still visible in console but also logs + # to file, so that downloaded log file also contains any errors. + import sys + import threading + from traceback import format_exception + + def log_uncaught_exceptions(e_type, value, tb): + LOGGER.error( + "UNCAUGHT EXCEPTION:\n" + + ''.join(format_exception(e_type, value, tb)) + ) + return + + def log_uncaught_threading_exceptions(args): + LOGGER.exception( + f"UNCAUGHT EXCEPTION IN THREAD: {args.exc_value}" + ) + return + + sys.excepthook = log_uncaught_exceptions + threading.excepthook = log_uncaught_threading_exceptions + + return + + +def get_log_filepath() -> str: + "Get the filepath to the logging file" + return LOGGING_CONFIG["handlers"]["file"]["filename"] + + +def set_log_level( + level: Union[int, str], +) -> None: + """Change the logging level. + + Args: + level (Union[int, str]): The level to set the logging to. + Should be a logging level, like `logging.INFO` or `"DEBUG"`. + """ + if isinstance(level, str): + level = logging._nameToLevel[level.upper()] + + root_logger = logging.getLogger() + if root_logger.level == level: + return + + LOGGER.debug(f'Setting logging level: {level}') + root_logger.setLevel(level) + + return diff --git a/backend/custom_exceptions.py b/backend/custom_exceptions.py deleted file mode 100644 index d7487d2..0000000 --- a/backend/custom_exceptions.py +++ /dev/null @@ -1,138 +0,0 @@ -#-*- coding: utf-8 -*- - -""" -All custom exceptions are defined here -""" - -""" -Note: Not all CE's inherit from CustomException. -""" - -from typing import Any, Dict -from backend.logging import LOGGER - - -class CustomException(Exception): - def __init__(self, e=None) -> None: - LOGGER.warning(self.__doc__) - super().__init__(e) - return - -class UsernameTaken(CustomException): - """The username is already taken""" - api_response = {'error': 'UsernameTaken', 'result': {}, 'code': 400} - -class UsernameInvalid(Exception): - """The username contains invalid characters""" - api_response = {'error': 'UsernameInvalid', 'result': {}, 'code': 400} - - def __init__(self, username: str): - self.username = username - super().__init__(self.username) - LOGGER.warning( - f'The username contains invalid characters: {username}' - ) - return - -class UserNotFound(CustomException): - """The user requested can not be found""" - api_response = {'error': 'UserNotFound', 'result': {}, 'code': 404} - -class AccessUnauthorized(CustomException): - """The password given is not correct""" - api_response = {'error': 'AccessUnauthorized', 'result': {}, 'code': 401} - -class ReminderNotFound(CustomException): - """The reminder with the id can not be found""" - api_response = {'error': 'ReminderNotFound', 'result': {}, 'code': 404} - -class NotificationServiceNotFound(CustomException): - """The notification service was not found""" - api_response = {'error': 'NotificationServiceNotFound', 'result': {}, 'code': 404} - -class NotificationServiceInUse(Exception): - """ - The notification service is wished to be deleted - but a reminder is still using it - """ - def __init__(self, type: str=''): - self.type = type - super().__init__(self.type) - LOGGER.warning( - f'The notification is wished to be deleted but a reminder of type {type} is still using it' - ) - return - - @property - def api_response(self) -> Dict[str, Any]: - return { - 'error': 'NotificationServiceInUse', - 'result': {'type': self.type}, - 'code': 400 - } - -class InvalidTime(CustomException): - """The time given is in the past""" - api_response = {'error': 'InvalidTime', 'result': {}, 'code': 400} - -class KeyNotFound(Exception): - """A key was not found in the input that is required to be given""" - def __init__(self, key: str=''): - self.key = key - super().__init__(self.key) - LOGGER.warning( - "This key was not found in the API request," - + f" eventhough it's required: {key}" - ) - return - - @property - def api_response(self) -> Dict[str, Any]: - return { - 'error': 'KeyNotFound', - 'result': {'key': self.key}, - 'code': 400 - } - -class InvalidKeyValue(Exception): - """The value of a key is invalid""" - def __init__(self, key: str = '', value: Any = ''): - self.key = key - self.value = value - super().__init__(self.key) - LOGGER.warning( - 'This key in the API request has an invalid value: ' + - f'{key} = {value}' - ) - - @property - def api_response(self) -> Dict[str, Any]: - return { - 'error': 'InvalidKeyValue', - 'result': {'key': self.key, 'value': self.value}, - 'code': 400 - } - -class TemplateNotFound(CustomException): - """The template was not found""" - api_response = {'error': 'TemplateNotFound', 'result': {}, 'code': 404} - -class APIKeyInvalid(Exception): - """The API key is not correct""" - api_response = {'error': 'APIKeyInvalid', 'result': {}, 'code': 401} - -class APIKeyExpired(Exception): - """The API key has expired""" - api_response = {'error': 'APIKeyExpired', 'result': {}, 'code': 401} - -class NewAccountsNotAllowed(CustomException): - """It's not allowed to create a new account except for the admin""" - api_response = {'error': 'NewAccountsNotAllowed', 'result': {}, 'code': 403} - -class InvalidDatabaseFile(CustomException): - """The uploaded database file is invalid or not supported""" - api_response = {'error': 'InvalidDatabaseFile', 'result': {}, 'code': 400} - -class LogFileNotFound(CustomException): - """No log file was found""" - api_response = {'error': 'LogFileNotFound', 'result': {}, 'code': 404} diff --git a/backend/db.py b/backend/db.py deleted file mode 100644 index 9612f89..0000000 --- a/backend/db.py +++ /dev/null @@ -1,521 +0,0 @@ -#-*- coding: utf-8 -*- - -""" -Setting up and interacting with the database. -""" - -from datetime import datetime -from os import makedirs, remove -from os.path import dirname, isfile, join -from shutil import move -from sqlite3 import Connection, OperationalError, ProgrammingError, Row -from threading import current_thread, main_thread -from time import time -from typing import Type, Union - -from flask import g - -from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile, - UserNotFound) -from backend.helpers import RestartVars, folder_path -from backend.logging import LOGGER, set_log_level - -DB_FILENAME = 'db', 'MIND.db' -__DATABASE_VERSION__ = 10 -__DATEBASE_NAME_ORIGINAL__ = "MIND_original.db" - -class DB_Singleton(type): - _instances = {} - def __call__(cls, *args, **kwargs): - i = f'{cls}{current_thread()}' - if (i not in cls._instances - or cls._instances[i].closed): - cls._instances[i] = super(DB_Singleton, cls).__call__(*args, **kwargs) - - return cls._instances[i] - -class DBConnection(Connection, metaclass=DB_Singleton): - file = '' - - def __init__(self, timeout: float) -> None: - LOGGER.debug(f'Creating connection {self}') - super().__init__(self.file, timeout=timeout) - super().cursor().execute("PRAGMA foreign_keys = ON;") - self.closed = False - return - - def close(self) -> None: - LOGGER.debug(f'Closing connection {self}') - self.closed = True - super().close() - return - - def __repr__(self) -> str: - return f'<{self.__class__.__name__}; {current_thread().name}; {id(self)}>' - -def setup_db_location() -> None: - """Create folder for database and link file to DBConnection class - """ - if isfile(folder_path('db', 'Noted.db')): - move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME)) - - db_location = folder_path(*DB_FILENAME) - makedirs(dirname(db_location), exist_ok=True) - - DBConnection.file = db_location - return - -def get_db(output_type: Union[Type[dict], Type[tuple]]=tuple): - """Get a database cursor instance. Coupled to Flask's g. - - Args: - output_type (Union[Type[dict], Type[tuple]], optional): - The type of output: a tuple or dictionary with the row values. - Defaults to tuple. - - Returns: - Cursor: The Cursor instance to use - """ - try: - cursor = g.cursor - except AttributeError: - db = DBConnection(timeout=20.0) - cursor = g.cursor = db.cursor() - - if output_type is dict: - cursor.row_factory = Row - else: - cursor.row_factory = None - - return g.cursor - -def close_db(e=None) -> None: - """Savely closes the database connection - """ - try: - cursor = g.cursor - db: DBConnection = cursor.connection - cursor.close() - delattr(g, 'cursor') - db.commit() - if current_thread() is main_thread(): - db.close() - except (AttributeError, ProgrammingError): - pass - return - -def migrate_db(current_db_version: int) -> None: - """ - Migrate a MIND database from it's current version - to the newest version supported by the MIND version installed. - """ - LOGGER.info('Migrating database to newer version...') - cursor = get_db() - if current_db_version == 1: - # V1 -> V2 - t = time() - utc_offset = datetime.fromtimestamp(t) - datetime.utcfromtimestamp(t) - cursor.execute("SELECT time, id FROM reminders;") - new_reminders = [] - new_reminders_append = new_reminders.append - for reminder in cursor: - new_reminders_append([round((datetime.fromtimestamp(reminder[0]) - utc_offset).timestamp()), reminder[1]]) - cursor.executemany("UPDATE reminders SET time = ? WHERE id = ?;", new_reminders) - current_db_version = 2 - - if current_db_version == 2: - # V2 -> V3 - cursor.executescript(""" - ALTER TABLE reminders - ADD color VARCHAR(7); - ALTER TABLE templates - ADD color VARCHAR(7); - """) - current_db_version = 3 - - if current_db_version == 3: - # V3 -> V4 - cursor.executescript(""" - UPDATE reminders - SET repeat_quantity = repeat_quantity || 's' - WHERE repeat_quantity NOT LIKE '%s'; - """) - current_db_version = 4 - - if current_db_version == 4: - # V4 -> V5 - cursor.executescript(""" - BEGIN TRANSACTION; - PRAGMA defer_foreign_keys = ON; - - CREATE TEMPORARY TABLE temp_reminder_services( - reminder_id, - static_reminder_id, - template_id, - notification_service_id - ); - - -- Reminders - INSERT INTO temp_reminder_services(reminder_id, notification_service_id) - SELECT id, notification_service - FROM reminders; - - CREATE TEMPORARY TABLE temp_reminders AS - SELECT id, user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color - FROM reminders; - DROP TABLE reminders; - CREATE TABLE reminders( - id INTEGER PRIMARY KEY, - user_id INTEGER NOT NULL, - title VARCHAR(255) NOT NULL, - text TEXT, - time INTEGER NOT NULL, - - repeat_quantity VARCHAR(15), - repeat_interval INTEGER, - original_time INTEGER, - - color VARCHAR(7), - - FOREIGN KEY (user_id) REFERENCES users(id) - ); - INSERT INTO reminders - SELECT * FROM temp_reminders; - - -- Templates - INSERT INTO temp_reminder_services(template_id, notification_service_id) - SELECT id, notification_service - FROM templates; - - CREATE TEMPORARY TABLE temp_templates AS - SELECT id, user_id, title, text, color - FROM templates; - DROP TABLE templates; - CREATE TABLE templates( - id INTEGER PRIMARY KEY, - user_id INTEGER NOT NULL, - title VARCHAR(255) NOT NULL, - text TEXT, - - color VARCHAR(7), - - FOREIGN KEY (user_id) REFERENCES users(id) - ); - INSERT INTO templates - SELECT * FROM temp_templates; - - INSERT INTO reminder_services - SELECT * FROM temp_reminder_services; - - COMMIT; - """) - current_db_version = 5 - - if current_db_version == 5: - # V5 -> V6 - from backend.users import User - try: - User('User1', 'Password1').delete() - except (UserNotFound, AccessUnauthorized): - pass - - current_db_version = 6 - - if current_db_version == 6: - # V6 -> V7 - cursor.executescript(""" - ALTER TABLE reminders - ADD weekdays VARCHAR(13); - """) - current_db_version = 7 - - if current_db_version == 7: - # V7 -> V8 - from backend.settings import _format_setting, default_settings - from backend.users import Users - - cursor.executescript(""" - DROP TABLE config; - CREATE TABLE IF NOT EXISTS config( - key VARCHAR(255) PRIMARY KEY, - value BLOB NOT NULL - ); - """ - ) - cursor.executemany(""" - INSERT OR IGNORE INTO config(key, value) - VALUES (?, ?); - """, - map( - lambda kv: (kv[0], _format_setting(*kv)), - default_settings.items() - ) - ) - - cursor.executescript(""" - ALTER TABLE users - ADD admin BOOL NOT NULL DEFAULT 0; - - UPDATE users - SET username = 'admin_old' - WHERE username = 'admin'; - """) - - Users().add('admin', 'admin', True) - - cursor.execute(""" - UPDATE users - SET admin = 1 - WHERE username = 'admin'; - """) - - current_db_version = 8 - - if current_db_version == 8: - # V8 -> V9 - from backend.settings import set_setting - from MIND import HOST, PORT, URL_PREFIX - - set_setting('host', HOST) - set_setting('port', int(PORT)) - set_setting('url_prefix', URL_PREFIX) - - current_db_version = 9 - - if current_db_version == 9: - # V9 -> V10 - - # Nothing is changed in the database - # It's just that this code needs to run once - # and the DB migration system does exactly that: - # run pieces of code once. - from backend.settings import update_manifest - - url_prefix: str = cursor.execute( - "SELECT value FROM config WHERE key = 'url_prefix' LIMIT 1;" - ).fetchone()[0] - update_manifest(url_prefix) - - current_db_version = 10 - - return - -def setup_db() -> None: - """Setup the database - """ - from backend.settings import (_format_setting, default_settings, get_setting, - set_setting, update_manifest) - from backend.users import Users - - cursor = get_db() - cursor.execute("PRAGMA journal_mode = wal;") - - cursor.executescript(""" - CREATE TABLE IF NOT EXISTS users( - id INTEGER PRIMARY KEY, - username VARCHAR(255) UNIQUE NOT NULL, - salt VARCHAR(40) NOT NULL, - hash VARCHAR(100) NOT NULL, - admin BOOL NOT NULL DEFAULT 0 - ); - CREATE TABLE IF NOT EXISTS notification_services( - id INTEGER PRIMARY KEY, - user_id INTEGER NOT NULL, - title VARCHAR(255), - url TEXT, - - FOREIGN KEY (user_id) REFERENCES users(id) - ); - CREATE TABLE IF NOT EXISTS reminders( - id INTEGER PRIMARY KEY, - user_id INTEGER NOT NULL, - title VARCHAR(255) NOT NULL, - text TEXT, - time INTEGER NOT NULL, - - repeat_quantity VARCHAR(15), - repeat_interval INTEGER, - original_time INTEGER, - weekdays VARCHAR(13), - - color VARCHAR(7), - - FOREIGN KEY (user_id) REFERENCES users(id) - ); - CREATE TABLE IF NOT EXISTS templates( - id INTEGER PRIMARY KEY, - user_id INTEGER NOT NULL, - title VARCHAR(255) NOT NULL, - text TEXT, - - color VARCHAR(7), - - FOREIGN KEY (user_id) REFERENCES users(id) - ); - CREATE TABLE IF NOT EXISTS static_reminders( - id INTEGER PRIMARY KEY, - user_id INTEGER NOT NULL, - title VARCHAR(255) NOT NULL, - text TEXT, - - color VARCHAR(7), - - FOREIGN KEY (user_id) REFERENCES users(id) - ); - CREATE TABLE IF NOT EXISTS reminder_services( - reminder_id INTEGER, - static_reminder_id INTEGER, - template_id INTEGER, - notification_service_id INTEGER NOT NULL, - - FOREIGN KEY (reminder_id) REFERENCES reminders(id) - ON DELETE CASCADE, - FOREIGN KEY (static_reminder_id) REFERENCES static_reminders(id) - ON DELETE CASCADE, - FOREIGN KEY (template_id) REFERENCES templates(id) - ON DELETE CASCADE, - FOREIGN KEY (notification_service_id) REFERENCES notification_services(id) - ); - CREATE TABLE IF NOT EXISTS config( - key VARCHAR(255) PRIMARY KEY, - value BLOB NOT NULL - ); - """) - - cursor.executemany(""" - INSERT OR IGNORE INTO config(key, value) - VALUES (?, ?); - """, - map( - lambda kv: (kv[0], _format_setting(*kv)), - default_settings.items() - ) - ) - - set_log_level(get_setting('log_level'), clear_file=False) - update_manifest(get_setting('url_prefix')) - - current_db_version = get_setting('database_version') - if current_db_version < __DATABASE_VERSION__: - LOGGER.debug( - f'Database migration: {current_db_version} -> {__DATABASE_VERSION__}' - ) - migrate_db(current_db_version) - set_setting('database_version', __DATABASE_VERSION__) - - users = Users() - if not 'admin' in users: - users.add('admin', 'admin', True) - cursor.execute(""" - UPDATE users - SET admin = 1 - WHERE username = 'admin'; - """) - - return - -def revert_db_import( - swap: bool, - imported_db_file: str = '' -) -> None: - """Revert the database import process. The original_db_file is the file - currently used (`DBConnection.file`). - - Args: - swap (bool): Whether or not to keep the imported_db_file or not, - instead of the original_db_file. - imported_db_file (str, optional): The other database file. Keep empty - to use `__DATABASE_NAME_ORIGINAL__`. Defaults to ''. - """ - original_db_file = DBConnection.file - if not imported_db_file: - imported_db_file = join(dirname(DBConnection.file), __DATEBASE_NAME_ORIGINAL__) - - if swap: - remove(original_db_file) - move( - imported_db_file, - original_db_file - ) - - else: - remove(imported_db_file) - - return - -def import_db( - new_db_file: str, - copy_hosting_settings: bool -) -> None: - """Replace the current database with a new one. - - Args: - new_db_file (str): The path to the new database file. - copy_hosting_settings (bool): Keep the hosting settings from the current - database. - - Raises: - InvalidDatabaseFile: The new database file is invalid or unsupported. - """ - LOGGER.info(f'Importing new database; {copy_hosting_settings=}') - try: - cursor = Connection(new_db_file, timeout=20.0).cursor() - - database_version = cursor.execute( - "SELECT value FROM config WHERE key = 'database_version' LIMIT 1;" - ).fetchone()[0] - if not isinstance(database_version, int): - raise InvalidDatabaseFile - - except (OperationalError, InvalidDatabaseFile): - LOGGER.error('Uploaded database is not a MIND database file') - cursor.connection.close() - revert_db_import( - swap=False, - imported_db_file=new_db_file - ) - raise InvalidDatabaseFile - - if database_version > __DATABASE_VERSION__: - LOGGER.error('Uploaded database is higher version than this MIND installation can support') - revert_db_import( - swap=False, - imported_db_file=new_db_file - ) - raise InvalidDatabaseFile - - if copy_hosting_settings: - hosting_settings = get_db().execute(""" - SELECT key, value, value - FROM config - WHERE key = 'host' - OR key = 'port' - OR key = 'url_prefix' - LIMIT 3; - """ - ) - cursor.executemany(""" - INSERT INTO config(key, value) - VALUES (?, ?) - ON CONFLICT(key) DO - UPDATE - SET value = ?; - """, - hosting_settings - ) - cursor.connection.commit() - cursor.connection.close() - - move( - DBConnection.file, - join(dirname(DBConnection.file), __DATEBASE_NAME_ORIGINAL__) - ) - move( - new_db_file, - DBConnection.file - ) - - from backend.server import SERVER - SERVER.restart([RestartVars.DB_IMPORT.value]) - - return diff --git a/backend/features/reminder_handler.py b/backend/features/reminder_handler.py new file mode 100644 index 0000000..84a3a34 --- /dev/null +++ b/backend/features/reminder_handler.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- + +from datetime import datetime +from threading import Timer +from typing import Union + +from backend.base.definitions import Constants, RepeatQuantity, SendResult +from backend.base.helpers import (Singleton, find_next_time, + send_apprise_notification, when_not_none) +from backend.base.logging import LOGGER +from backend.implementations.notification_services import NotificationService +from backend.internals.db_models import UserlessRemindersDB +from backend.internals.server import Server + + +class ReminderHandler(metaclass=Singleton): + """ + Handle set reminders. This class is a singleton. + """ + + def __init__(self) -> None: + "Create instance of handler" + self.thread: Union[Timer, None] = None + self.time: Union[int, None] = None + self.reminder_db = UserlessRemindersDB() + return + + def __trigger_reminders(self, time: int) -> None: + """Trigger all reminders that are set for a certain time. + + Args: + time (int): The time of the reminders to trigger. + """ + with Server().app.app_context(): + for reminder in self.reminder_db.fetch(time): + try: + user_id = self.reminder_db.reminder_id_to_user_id( + reminder.id) + result = send_apprise_notification( + [ + NotificationService(user_id, ns).get().url + for ns in reminder.notification_services + ], + reminder.title, + reminder.text + ) + + self.thread = None + self.time = None + + if result == SendResult.CONNECTION_ERROR: + # Retry sending the notification in a few minutes + self.reminder_db.update( + reminder.id, + time + Constants.CONNECTION_ERROR_TIMEOUT + ) + + elif ( + reminder.repeat_quantity, + reminder.weekdays + ) == (None, None): + # Delete the reminder from the database + self.reminder_db.delete(reminder.id) + + else: + # Set next time + new_time = find_next_time( + reminder.original_time or -1, + when_not_none( + reminder.repeat_quantity, + lambda q: RepeatQuantity(q) + ), + reminder.repeat_interval, + reminder.weekdays + ) + + self.reminder_db.update(reminder.id, new_time) + + except Exception: + # If the notification fails, we don't want to crash the whole program + # Just log the error and continue + LOGGER.exception( + "Failed to send notification for reminder %s: ", + reminder.id + ) + + finally: + self.find_next_reminder() + + return + + def find_next_reminder(self, time: Union[int, None] = None) -> None: + """Determine when the soonest reminder is and set the timer to that time. + + Args: + time (Union[int, None], optional): The timestamp to check for. + Otherwise check soonest in database. + Defaults to None. + """ + if time is None: + time = self.reminder_db.get_soonest_time() + if not time: + return + + if ( + self.thread is None + or ( + self.time is not None + and time < self.time + ) + ): + if self.thread is not None: + self.thread.cancel() + + delta_t = time - datetime.utcnow().timestamp() + self.thread = Timer( + delta_t, + self.__trigger_reminders, + (time,) + ) + self.thread.name = "ReminderHandler" + self.thread.start() + self.time = time + + return + + def stop_handling(self) -> None: + """Stop the timer if it's active + """ + if self.thread is not None: + self.thread.cancel() + return diff --git a/backend/features/reminders.py b/backend/features/reminders.py new file mode 100644 index 0000000..7bbf673 --- /dev/null +++ b/backend/features/reminders.py @@ -0,0 +1,418 @@ +# -*- coding: utf-8 -*- + +from dataclasses import asdict +from datetime import datetime +from typing import List, Union + +from backend.base.custom_exceptions import (InvalidKeyValue, InvalidTime, + ReminderNotFound) +from backend.base.definitions import (WEEKDAY_NUMBER, ReminderData, + RepeatQuantity, SendResult, + SortingMethod) +from backend.base.helpers import (find_next_time, search_filter, + send_apprise_notification, when_not_none) +from backend.base.logging import LOGGER +from backend.features.reminder_handler import ReminderHandler +from backend.implementations.notification_services import NotificationService +from backend.internals.db_models import RemindersDB + +REMINDER_HANDLER = ReminderHandler() + + +class Reminder: + def __init__(self, user_id: int, reminder_id: int) -> None: + """Represent a reminder. + + Args: + user_id (int): The ID of the user. + reminder_id (int): The ID of the reminder. + + Raises: + ReminderNotFound: Reminder with given ID does not exist or is not + owned by user. + """ + self.user_id = user_id + self.id = reminder_id + + self.reminder_db = RemindersDB(self.user_id) + + if not self.reminder_db.exists(self.id): + raise ReminderNotFound(reminder_id) + return + + def get(self) -> ReminderData: + """Get info about the reminder. + + Returns: + ReminderData: The info about the reminder. + """ + return self.reminder_db.fetch(self.id)[0] + + def update( + self, + title: Union[None, str] = None, + time: Union[None, int] = None, + notification_services: Union[None, List[int]] = None, + text: Union[None, str] = None, + repeat_quantity: Union[None, RepeatQuantity] = None, + repeat_interval: Union[None, int] = None, + weekdays: Union[None, List[WEEKDAY_NUMBER]] = None, + color: Union[None, str] = None + ) -> ReminderData: + """Edit the reminder. + + Args: + title (Union[None, str]): The new title of the entry. + Defaults to None. + + time (Union[None, int]): The new UTC epoch timestamp when the + reminder should be send. + Defaults to None. + + notification_services (Union[None, List[int]]): The new list + of id's of the notification services to use to send the reminder. + Defaults to None. + + text (Union[None, str], optional): The new body of the reminder. + Defaults to None. + + repeat_quantity (Union[None, RepeatQuantity], optional): The new + quantity of the repeat specified for the reminder. + Defaults to None. + + repeat_interval (Union[None, int], optional): The new amount of + repeat_quantity, like "5" (hours). + Defaults to None. + + weekdays (Union[None, List[WEEKDAY_NUMBER]], optional): The new + indexes of the days of the week that the reminder should run. + Defaults to None. + + color (Union[None, str], optional): The new hex code of the color + of the reminder, which is shown in the web-ui. + Defaults to None. + + Note about args: + Either repeat_quantity and repeat_interval are given, weekdays is + given or neither, but not both. + + Raises: + NotificationServiceNotFound: One of the notification services was not found. + InvalidKeyValue: The value of one of the keys is not valid or + the "Note about args" is violated. + + Returns: + ReminderData: The new reminder info. + """ + LOGGER.info( + f'Updating notification service {self.id}: ' + + f'{title=}, {time=}, {notification_services=}, {text=}, ' + + f'{repeat_quantity=}, {repeat_interval=}, {weekdays=}, {color=}' + ) + + # Validate data + if repeat_quantity is None and repeat_interval is not None: + raise InvalidKeyValue('repeat_quantity', repeat_quantity) + elif repeat_quantity is not None and repeat_interval is None: + raise InvalidKeyValue('repeat_interval', repeat_interval) + elif weekdays is not None and repeat_quantity is not None: + raise InvalidKeyValue('weekdays', weekdays) + + repeated_reminder = ( + (repeat_quantity is not None and repeat_interval is not None) + or weekdays is not None + ) + + if time is not None: + if not repeated_reminder: + if time < datetime.utcnow().timestamp(): + raise InvalidTime(time) + time = round(time) + + if notification_services: + # Check if all notification services exist + for ns in notification_services: + NotificationService(self.user_id, ns) + + # Get current data and update it with new values + data = asdict(self.get()) + + new_values = { + 'title': title, + 'time': time, + 'text': text, + 'repeat_quantity': when_not_none( + repeat_quantity, + lambda q: q.value + ), + 'repeat_interval': repeat_interval, + 'weekdays': when_not_none( + weekdays, + lambda w: ",".join(map(str, sorted(w))) + ), + 'color': color, + 'notification_services': notification_services + } + for k, v in new_values.items(): + if ( + k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color') + or v is not None + ): + data[k] = v + + if repeated_reminder: + next_time = find_next_time( + data["time"], + data["repeat_quantity"], + data["repeat_interval"], + weekdays + ) + self.reminder_db.update( + self.id, + data["title"], + data["text"], + next_time, + data["repeat_quantity"], + data["repeat_interval"], + data["weekdays"], + data["time"], + data["color"], + data["notification_services"] + ) + + else: + next_time = data["time"] + self.reminder_db.update( + self.id, + data["title"], + data["text"], + next_time, + data["repeat_quantity"], + data["repeat_interval"], + data["weekdays"], + data["original_time"], + data["color"], + data["notification_services"] + ) + + REMINDER_HANDLER.find_next_reminder(next_time) + return self.get() + + def delete(self) -> None: + "Delete the reminder" + LOGGER.info(f'Deleting reminder {self.id}') + self.reminder_db.delete(self.id) + REMINDER_HANDLER.find_next_reminder() + return + + +class Reminders: + def __init__(self, user_id: int) -> None: + """Create an instance. + + Args: + user_id (int): The ID of the user. + """ + self.user_id = user_id + self.reminder_db = RemindersDB(self.user_id) + return + + def fetchall( + self, + sort_by: SortingMethod = SortingMethod.TIME + ) -> List[ReminderData]: + """Get all reminders. + + Args: + sort_by (SortingMethod, optional): How to sort the result. + Defaults to SortingMethod.TIME. + + Returns: + List[ReminderData]: The info of each reminder. + """ + reminders = self.reminder_db.fetch() + reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1]) + return reminders + + def search( + self, + query: str, + sort_by: SortingMethod = SortingMethod.TIME + ) -> List[ReminderData]: + """Search for reminders. + + Args: + query (str): The term to search for. + sort_by (SortingMethod, optional): How to sort the result. + Defaults to SortingMethod.TIME. + + Returns: + List[ReminderData]: All reminders that match. Similar output to + self.fetchall. + """ + reminders = [ + r + for r in self.fetchall(sort_by) + if search_filter(query, r) + ] + return reminders + + def fetchone(self, id: int) -> Reminder: + """Get one reminder. + + Args: + id (int): The ID of the reminder to fetch. + + Raises: + ReminderNotFound: The reminder with the given ID does not exist + or is not owned by the user. + + Returns: + Reminder: A Reminder instance. + """ + return Reminder(self.user_id, id) + + def add( + self, + title: str, + time: int, + notification_services: List[int], + text: str = '', + repeat_quantity: Union[None, RepeatQuantity] = None, + repeat_interval: Union[None, int] = None, + weekdays: Union[None, List[WEEKDAY_NUMBER]] = None, + color: Union[None, str] = None + ) -> Reminder: + """Add a reminder. + + Args: + title (str): The title of the entry. + + time (int): The UTC epoch timestamp the the reminder should be send. + + notification_services (List[int]): The id's of the notification + services to use to send the reminder. + + text (str, optional): The body of the reminder. + Defaults to ''. + + repeat_quantity (Union[None, RepeatQuantity], optional): The quantity + of the repeat specified for the reminder. + Defaults to None. + + repeat_interval (Union[None, int], optional): The amount of + repeat_quantity, like "5" (hours). + Defaults to None. + + weekdays (Union[None, List[WEEKDAY_NUMBER]], optional): The indexes + of the days of the week that the reminder should run. + Defaults to None. + + color (Union[None, str], optional): The hex code of the color of the + reminder, which is shown in the web-ui. + Defaults to None. + + Note about args: + Either repeat_quantity and repeat_interval are given, + weekdays is given or neither, but not both. + + Raises: + NotificationServiceNotFound: One of the notification services was + not found. + InvalidKeyValue: The value of one of the keys is not valid + or the "Note about args" is violated. + + Returns: + Reminder: The info about the reminder. + """ + LOGGER.info( + f'Adding reminder with {title=}, {time=}, {notification_services=}, ' + + f'{text=}, {repeat_quantity=}, {repeat_interval=}, {weekdays=}, {color=}') + + # Validate data + if time < datetime.utcnow().timestamp(): + raise InvalidTime(time) + time = round(time) + + if repeat_quantity is None and repeat_interval is not None: + raise InvalidKeyValue('repeat_quantity', repeat_quantity) + elif repeat_quantity is not None and repeat_interval is None: + raise InvalidKeyValue('repeat_interval', repeat_interval) + elif ( + weekdays is not None + and repeat_quantity is not None + and repeat_interval is not None + ): + raise InvalidKeyValue('weekdays', weekdays) + + # Check if all notification services exist + for ns in notification_services: + NotificationService(self.user_id, ns) + + # Prepare args + if any((repeat_quantity, weekdays)): + original_time = time + time = find_next_time( + original_time, + repeat_quantity, + repeat_interval, + weekdays + ) + else: + original_time = None + + weekdays_str = when_not_none( + weekdays, + lambda w: ",".join(map(str, sorted(w))) + ) + repeat_quantity_str = when_not_none( + repeat_quantity, + lambda q: q.value + ) + + new_id = self.reminder_db.add( + title, text, + time, repeat_quantity_str, + repeat_interval, + weekdays_str, + original_time, + color, + notification_services + ) + + REMINDER_HANDLER.find_next_reminder(time) + + return self.fetchone(new_id) + + def test_reminder( + self, + title: str, + notification_services: List[int], + text: str = '' + ) -> SendResult: + """Test send a reminder draft. + + Args: + title (str): Title title of the entry. + + notification_service (int): The id of the notification service to + use to send the reminder. + + text (str, optional): The body of the reminder. + Defaults to ''. + + Returns: + SendResult: Whether or not it was successful. + """ + LOGGER.info( + f'Testing reminder with {title=}, {notification_services=}, {text=}' + ) + + return send_apprise_notification( + [ + NotificationService(self.user_id, ns_id).get().url + for ns_id in notification_services + ], + title, + text + ) diff --git a/backend/features/static_reminders.py b/backend/features/static_reminders.py new file mode 100644 index 0000000..c67303a --- /dev/null +++ b/backend/features/static_reminders.py @@ -0,0 +1,244 @@ +# -*- coding: utf-8 -*- + +from dataclasses import asdict +from typing import List, Union + +from backend.base.custom_exceptions import ReminderNotFound +from backend.base.definitions import (SendResult, StaticReminderData, + TimelessSortingMethod) +from backend.base.helpers import search_filter, send_apprise_notification +from backend.base.logging import LOGGER +from backend.implementations.notification_services import NotificationService +from backend.internals.db_models import StaticRemindersDB + + +class StaticReminder: + def __init__(self, user_id: int, reminder_id: int) -> None: + """Represent a static reminder. + + Args: + user_id (int): The ID of the user. + reminder_id (int): The ID of the reminder. + + Raises: + ReminderNotFound: Reminder with given ID does not exist or is not + owned by user. + """ + self.user_id = user_id + self.id = reminder_id + + self.reminder_db = StaticRemindersDB(self.user_id) + + if not self.reminder_db.exists(self.id): + raise ReminderNotFound(reminder_id) + return + + def get(self) -> StaticReminderData: + """Get info about the static reminder. + + Returns: + StaticReminderData: The info about the static reminder. + """ + return self.reminder_db.fetch(self.id)[0] + + def trigger_reminder(self) -> SendResult: + """Send the reminder. + + Returns: + SendResult: The result of the sending process. + """ + LOGGER.info(f'Triggering static reminder {self.id}') + + reminder_data = self.get() + + return send_apprise_notification( + [ + NotificationService(self.user_id, ns_id).get().url + for ns_id in reminder_data.notification_services + ], + reminder_data.title, + reminder_data.text + ) + + def update( + self, + title: Union[str, None] = None, + notification_services: Union[List[int], None] = None, + text: Union[str, None] = None, + color: Union[str, None] = None + ) -> StaticReminderData: + """Edit the static reminder. + + Args: + title (Union[str, None], optional): The new title of the entry. + Defaults to None. + + notification_services (Union[List[int], None], optional): The new + id's of the notification services to use to send the reminder. + Defaults to None. + + text (Union[str, None], optional): The new body of the reminder. + Defaults to None. + + color (Union[str, None], optional): The new hex code of the color + of the reminder, which is shown in the web-ui. + Defaults to None. + + Raises: + NotificationServiceNotFound: One of the notification services was + not found. + + Returns: + StaticReminderData: The new static reminder info. + """ + LOGGER.info( + f'Updating static reminder {self.id}: ' + + f'{title=}, {notification_services=}, {text=}, {color=}' + ) + + if notification_services: + # Check whether all notification services exist + for ns in notification_services: + NotificationService(self.user_id, ns) + + # Get current data and update it with new values + data = asdict(self.get()) + + new_values = { + 'title': title, + 'text': text, + 'color': color, + 'notification_services': notification_services + } + for k, v in new_values.items(): + if k in ('color',) or v is not None: + data[k] = v + + self.reminder_db.update( + self.id, + data['title'], + data['text'], + data['color'], + data['notification_services'] + ) + + return self.get() + + def delete(self) -> None: + "Delete the static reminder" + LOGGER.info(f'Deleting static reminder {self.id}') + self.reminder_db.delete(self.id) + return + + +class StaticReminders: + def __init__(self, user_id: int) -> None: + """Create an instance. + + Args: + user_id (int): The ID of the user. + """ + self.user_id = user_id + self.reminder_db = StaticRemindersDB(self.user_id) + return + + def fetchall( + self, + sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE + ) -> List[StaticReminderData]: + """Get all static reminders. + + Args: + sort_by (TimelessSortingMethod, optional): How to sort the result. + Defaults to TimelessSortingMethod.TITLE. + + Returns: + List[StaticReminderData]: The info of each static reminder. + """ + reminders = self.reminder_db.fetch() + reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1]) + return reminders + + def search( + self, + query: str, + sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE + ) -> List[StaticReminderData]: + """Search for static reminders. + + Args: + query (str): The term to search for. + + sort_by (TimelessSortingMethod, optional): The sorting method of + the resulting list. + Defaults to TimelessSortingMethod.TITLE. + + Returns: + List[StaticReminderData]: All static reminders that match. + Similar output to `self.fetchall` + """ + static_reminders = [ + r + for r in self.fetchall(sort_by) + if search_filter(query, r) + ] + return static_reminders + + def fetchone(self, reminder_id: int) -> StaticReminder: + """Get one static reminder. + + Args: + reminder_id (int): The id of the static reminder to fetch. + + Raises: + ReminderNotFound: The static reminder with the given ID does not + exist or is not owned by the user. + + Returns: + StaticReminder: A StaticReminder instance. + """ + return StaticReminder(self.user_id, reminder_id) + + def add( + self, + title: str, + notification_services: List[int], + text: str = '', + color: Union[str, None] = None + ) -> StaticReminder: + """Add a static reminder. + + Args: + title (str): The title of the entry. + + notification_services (List[int]): The id's of the + notification services to use to send the reminder. + + text (str, optional): The body of the reminder. + Defaults to ''. + + color (Union[str, None], optional): The hex code of the color of the + template, which is shown in the web-ui. + Defaults to None. + + Raises: + NotificationServiceNotFound: One of the notification services was + not found. + + Returns: + StaticReminder: The info about the static reminder + """ + LOGGER.info( + f'Adding static reminder with {title=}, {notification_services=}, {text=}, {color=}' + ) + + # Check if all notification services exist + for ns in notification_services: + NotificationService(self.user_id, ns) + + new_id = self.reminder_db.add( + title, text, color, + notification_services + ) + + return self.fetchone(new_id) diff --git a/backend/features/templates.py b/backend/features/templates.py new file mode 100644 index 0000000..359f7b1 --- /dev/null +++ b/backend/features/templates.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- + +from dataclasses import asdict +from typing import List, Union + +from backend.base.custom_exceptions import TemplateNotFound +from backend.base.definitions import TemplateData, TimelessSortingMethod +from backend.base.helpers import search_filter +from backend.base.logging import LOGGER +from backend.implementations.notification_services import NotificationService +from backend.internals.db_models import TemplatesDB + + +class Template: + def __init__(self, user_id: int, template_id: int) -> None: + """Represent a template. + + Args: + user_id (int): The ID of the user. + template_id (int): The ID of the template. + + Raises: + TemplateNotFound: Template with given ID does not exist or is not + owned by user. + """ + self.user_id = user_id + self.id = template_id + + self.template_db = TemplatesDB(self.user_id) + + if not self.template_db.exists(self.id): + raise TemplateNotFound(self.id) + return + + def get(self) -> TemplateData: + """Get info about the template. + + Returns: + TemplateData: The info about the template. + """ + return self.template_db.fetch(self.id)[0] + + def update(self, + title: Union[str, None] = None, + notification_services: Union[List[int], None] = None, + text: Union[str, None] = None, + color: Union[str, None] = None + ) -> TemplateData: + """Edit the template. + + Args: + title (Union[str, None]): The new title of the entry. + Defaults to None. + + notification_services (Union[List[int], None]): The new id's of the + notification services to use to send the reminder. + Defaults to None. + + text (Union[str, None], optional): The new body of the template. + Defaults to None. + + color (Union[str, None], optional): The new hex code of the color of + the template, which is shown in the web-ui. + Defaults to None. + + Raises: + NotificationServiceNotFound: One of the notification services was + not found. + + Returns: + TemplateData: The new template info. + """ + LOGGER.info( + f'Updating template {self.id}: ' + + f'{title=}, {notification_services=}, {text=}, {color=}' + ) + + if notification_services: + # Check if all notification services exist + for ns in notification_services: + NotificationService(self.user_id, ns) + + data = asdict(self.get()) + + new_values = { + 'title': title, + 'text': text, + 'color': color, + 'notification_services': notification_services + } + for k, v in new_values.items(): + if k in ('color',) or v is not None: + data[k] = v + + self.template_db.update( + self.id, + data['title'], + data['text'], + data['color'], + data['notification_services'] + ) + + return self.get() + + def delete(self) -> None: + "Delete the template" + LOGGER.info(f'Deleting template {self.id}') + self.template_db.delete(self.id) + return + + +class Templates: + def __init__(self, user_id: int) -> None: + """Create an instance. + + Args: + user_id (int): The ID of the user. + """ + self.user_id = user_id + self.template_db = TemplatesDB(self.user_id) + return + + def fetchall( + self, + sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE + ) -> List[TemplateData]: + """Get all templates of the user. + + Args: + sort_by (TimelessSortingMethod, optional): The sorting method of + the resulting list. + Defaults to TimelessSortingMethod.TITLE. + + Returns: + List[TemplateData]: The id, title, text and color of each template. + """ + templates = self.template_db.fetch() + templates.sort(key=sort_by.value[0], reverse=sort_by.value[1]) + return templates + + def search( + self, + query: str, + sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE + ) -> List[TemplateData]: + """Search for templates. + + Args: + query (str): The term to search for. + + sort_by (TimelessSortingMethod, optional): The sorting method of + the resulting list. + Defaults to TimelessSortingMethod.TITLE. + + Returns: + List[TemplateData]: All templates that match. Similar output to + `self.fetchall`. + """ + templates = [ + r + for r in self.fetchall(sort_by) + if search_filter(query, r) + ] + return templates + + def fetchone(self, template_id: int) -> Template: + """Get one template. + + Args: + template_id (int): The id of the template to fetch. + + Raises: + TemplateNotFound: Template with given ID does not exist or is not + owned by user. + + Returns: + Template: A Template instance. + """ + return Template(self.user_id, template_id) + + def add( + self, + title: str, + notification_services: List[int], + text: str = '', + color: Union[str, None] = None + ) -> Template: + """Add a template. + + Args: + title (str): The title of the entry. + + notification_services (List[int]): The id's of the + notification services to use to send the reminder. + + text (str, optional): The body of the reminder. + Defaults to ''. + + color (Union[str, None], optional): The hex code of the color of the + template, which is shown in the web-ui. + Defaults to None. + + Raises: + NotificationServiceNotFound: One of the notification services was + not found. + + Returns: + Template: The info about the template. + """ + LOGGER.info( + f'Adding template with {title=}, {notification_services=}, {text=}, {color=}' + ) + + # Check if all notification services exist + for ns in notification_services: + NotificationService(self.user_id, ns) + + new_id = self.template_db.add( + title, text, color, + notification_services + ) + + return self.fetchone(new_id) diff --git a/backend/helpers.py b/backend/helpers.py deleted file mode 100644 index 602d517..0000000 --- a/backend/helpers.py +++ /dev/null @@ -1,116 +0,0 @@ -#-*- coding: utf-8 -*- - -""" -General functions -""" - -from enum import Enum -from os.path import abspath, dirname, join -from sys import version_info -from typing import Callable, TypeVar, Union - -T = TypeVar('T') -U = TypeVar('U') - -def folder_path(*folders) -> str: - """Turn filepaths relative to the project folder into absolute paths - - Returns: - str: The absolute filepath - """ - return join(dirname(dirname(abspath(__file__))), *folders) - - -def check_python_version() -> bool: - """Check if the python version that is used is a minimum version. - - Returns: - bool: Whether or not the python version is version 3.8 or above or not. - """ - if not (version_info.major == 3 and version_info.minor >= 8): - from backend.logging import LOGGER - - LOGGER.critical( - 'The minimum python version required is python3.8 ' + - '(currently ' + str(version_info.major) + '.' + str(version_info.minor) + '.' + str(version_info.micro) + ').' - ) - return False - return True - - -def search_filter(query: str, result: dict) -> bool: - """Filter library results based on a query. - - Args: - query (str): The query to filter with. - result (dict): The library result to check. - - Returns: - bool: Whether or not the result passes the filter. - """ - query = query.lower() - return ( - query in result["title"].lower() - or query in result["text"].lower() - ) - - -def when_not_none(value: Union[T, None], to_run: Callable[[T], U]) -> Union[U, None]: - """Run `to_run` with argument `value` iff `value is not None`. Else return - `None`. - - Args: - value (Union[T, None]): The value to check. - to_run (Callable[[T], U]): The function to run. - - Returns: - Union[U, None]: Either the return value of `to_run`, or `None`. - """ - if value is None: - return None - else: - return to_run(value) - - -class Singleton(type): - _instances = {} - def __call__(cls, *args, **kwargs): - c = str(cls) - if c not in cls._instances: - cls._instances[c] = super().__call__(*args, **kwargs) - - return cls._instances[c] - - -class BaseEnum(Enum): - def __eq__(self, other) -> bool: - return self.value == other - - -class TimelessSortingMethod(BaseEnum): - TITLE = (lambda r: (r['title'], r['text'], r['color']), False) - TITLE_REVERSED = (lambda r: (r['title'], r['text'], r['color']), True) - DATE_ADDED = (lambda r: r['id'], False) - DATE_ADDED_REVERSED = (lambda r: r['id'], True) - - -class SortingMethod(BaseEnum): - TIME = (lambda r: (r['time'], r['title'], r['text'], r['color']), False) - TIME_REVERSED = (lambda r: (r['time'], r['title'], r['text'], r['color']), True) - TITLE = (lambda r: (r['title'], r['time'], r['text'], r['color']), False) - TITLE_REVERSED = (lambda r: (r['title'], r['time'], r['text'], r['color']), True) - DATE_ADDED = (lambda r: r['id'], False) - DATE_ADDED_REVERSED = (lambda r: r['id'], True) - - -class RepeatQuantity(BaseEnum): - YEARS = "years" - MONTHS = "months" - WEEKS = "weeks" - DAYS = "days" - HOURS = "hours" - MINUTES = "minutes" - -class RestartVars(BaseEnum): - DB_IMPORT = "db_import" - HOST_CHANGE = "host_change" diff --git a/backend/implementations/apprise_parser.py b/backend/implementations/apprise_parser.py new file mode 100644 index 0000000..2d99322 --- /dev/null +++ b/backend/implementations/apprise_parser.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- + +from re import compile +from typing import Any, Dict, List, Tuple, Union + +from apprise import Apprise + +from backend.base.helpers import when_not_none + +remove_named_groups = compile(r'(?<=\()\?P<\w+>') +IGNORED_ARGS = ('cto', 'format', 'overflow', 'rto', 'verify') + + +def process_regex( + regex: Union[Tuple[str, str], None] +) -> Union[Tuple[str, str], None]: + return when_not_none( + regex, + lambda r: (remove_named_groups.sub('', r[0]), r[1]) + ) + + +def _sort_tokens(t: Dict[str, Any]) -> List[int]: + result = [ + int(not t['required']) + ] + + if t['name'] == 'Schema': + result.append(0) + + if t['type'] == 'choice': + result.append(1) + + elif t['type'] != 'list': + result.append(2) + + else: + result.append(3) + + return result + + +def get_apprise_services() -> List[Dict[str, Any]]: + apprise_services = [] + + raw = Apprise().details()['schemas'] + for entry in raw: + result = { + 'name': str(entry['service_name']), + 'doc_url': entry['setup_url'], + 'details': { + 'templates': entry['details']['templates'], + 'tokens': [], + 'args': [] + } + } + + handled_tokens = set() + for k, v in entry['details']['tokens'].items(): + if not v['type'].startswith('list:'): + continue + + list_entry = { + 'name': v['name'], + 'map_to': k, + 'required': v['required'], + 'type': 'list', + 'delim': v['delim'][0], + 'content': [] + } + + for content in v['group']: + token = entry['details']['tokens'][content] + list_entry['content'].append({ + 'name': token['name'], + 'required': token['required'], + 'type': token['type'], + 'prefix': token.get('prefix'), + 'regex': process_regex(token.get('regex')) + }) + handled_tokens.add(content) + + result['details']['tokens'].append(list_entry) + handled_tokens.add(k) + + for k, v in entry['details']['tokens'].items(): + if k in handled_tokens: + continue + + normal_entry = { + 'name': v['name'], + 'map_to': k, + 'required': v['required'], + 'type': v['type'].split(':')[0] + } + + if v['type'].startswith('choice'): + normal_entry.update({ + 'options': v.get('values'), + 'default': v.get('default') + }) + + else: + normal_entry.update({ + 'prefix': v.get('prefix'), + 'min': v.get('min'), + 'max': v.get('max'), + 'regex': process_regex(v.get('regex')) + }) + + result['details']['tokens'].append(normal_entry) + + for k, v in entry['details']['args'].items(): + if ( + v.get('alias_of') is not None + or k in IGNORED_ARGS + ): + continue + + args_entry = { + 'name': v.get('name', k), + 'map_to': k, + 'required': v.get('required', False), + 'type': v['type'].split(':')[0], + } + + if v['type'].startswith('list'): + args_entry.update({ + 'delim': v['delim'][0], + 'content': [] + }) + + elif v['type'].startswith('choice'): + args_entry.update({ + 'options': v['values'], + 'default': v.get('default') + }) + + elif v['type'] == 'bool': + args_entry.update({ + 'default': v['default'] + }) + + else: + args_entry.update({ + 'min': v.get('min'), + 'max': v.get('max'), + 'regex': process_regex(v.get('regex')) + }) + + result['details']['args'].append(args_entry) + + result['details']['tokens'].sort(key=_sort_tokens) + result['details']['args'].sort(key=_sort_tokens) + apprise_services.append(result) + + apprise_services.sort(key=lambda s: s['name'].lower()) + + apprise_services.insert(0, { + 'name': 'Custom URL', + 'doc_url': 'https://github.com/caronc/apprise#supported-notifications', + 'details': { + 'templates': ['{url}'], + 'tokens': [{ + 'name': 'Apprise URL', + 'map_to': 'url', + 'required': True, + 'type': 'string', + 'prefix': None, + 'min': None, + 'max': None, + 'regex': None + }], + 'args': [] + } + }) + + return apprise_services diff --git a/backend/implementations/notification_services.py b/backend/implementations/notification_services.py new file mode 100644 index 0000000..a387540 --- /dev/null +++ b/backend/implementations/notification_services.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- + +from dataclasses import asdict +from typing import List, Union + +from backend.base.custom_exceptions import (NotificationServiceInUse, + NotificationServiceNotFound, + URLInvalid) +from backend.base.definitions import (Constants, NotificationServiceData, + ReminderType, SendResult) +from backend.base.helpers import send_apprise_notification +from backend.base.logging import LOGGER +from backend.internals.db_models import (NotificationServicesDB, + ReminderServicesDB) + + +class NotificationService: + def __init__( + self, + user_id: int, + notification_service_id: int + ) -> None: + """Create an representation of a notification service. + + Args: + user_id (int): The ID that the service belongs to. + notification_service_id (int): The ID of the service itself. + + Raises: + NotificationServiceNotFound: The user does not own a notification + service with the given ID. + """ + self.user_id = user_id + self.id = notification_service_id + + self.ns_db = NotificationServicesDB(self.user_id) + + if not self.ns_db.exists(self.id): + raise NotificationServiceNotFound(self.id) + + return + + def get(self) -> NotificationServiceData: + """Get the info about the notification service. + + Returns: + NotificationServiceData: The info about the notification service. + """ + return self.ns_db.fetch(self.id)[0] + + def update( + self, + title: Union[str, None] = None, + url: Union[str, None] = None + ) -> NotificationServiceData: + """Edit the notification service. The URL is tested by sending a test + notification to it. + + Args: + title (Union[str, None], optional): The new title of the service. + Defaults to None. + + url (Union[str, None], optional): The new url of the service. + Defaults to None. + + Returns: + NotificationServiceData: The new info about the service. + """ + LOGGER.info( + f'Updating notification service {self.id}: {title=}, {url=}' + ) + + # Get current data and update it with new values + data = asdict(self.get()) + test_url = data["url"] != url + + new_values = { + 'title': title, + 'url': url + } + for k, v in new_values.items(): + if v is not None: + data[k] = v + + if test_url and NotificationServices(self.user_id).test( + data['url'] + ) != SendResult.SUCCESS: + raise URLInvalid(data['url']) + + self.ns_db.update(self.id, data["title"], data["url"]) + + return self.get() + + def delete( + self, + delete_reminders_using: bool = False + ) -> None: + """Delete the service. + + Args: + delete_reminders_using (bool, optional): Instead of throwing an + error when there are still reminders using the service, delete + the reminders. + Defaults to False. + + Raises: + NotificationServiceInUse: The service is still used by a reminder. + """ + from backend.features.reminders import Reminder + from backend.features.static_reminders import StaticReminder + from backend.features.templates import Template + + LOGGER.info(f'Deleting notification service {self.id}') + + for r_type, RClass in ( + (ReminderType.REMINDER, Reminder), + (ReminderType.STATIC_REMINDER, StaticReminder), + (ReminderType.TEMPLATE, Template) + ): + uses = ReminderServicesDB(r_type).uses_ns(self.id) + if uses: + if not delete_reminders_using: + raise NotificationServiceInUse( + self.id, + r_type.value + ) + + for r_id in uses: + RClass(self.user_id, r_id).delete() + + self.ns_db.delete(self.id) + return + + +class NotificationServices: + def __init__(self, user_id: int) -> None: + """Represent the notification services of a user. + + Args: + user_id (int): The ID of the user. + """ + self.user_id = user_id + self.ns_db = NotificationServicesDB(self.user_id) + return + + def fetchall(self) -> List[NotificationServiceData]: + """Get a list of all notification services. + + Returns: + List[NotificationServiceData]: The list of all notification services. + """ + return self.ns_db.fetch() + + def fetchone(self, notification_service_id: int) -> NotificationService: + """Get one notification service based on it's id. + + Args: + notification_service_id (int): The id of the desired service. + + Raises: + NotificationServiceNotFound: The user does not own a notification + service with the given ID. + + Returns: + NotificationService: Instance of NotificationService. + """ + return NotificationService(self.user_id, notification_service_id) + + def add(self, title: str, url: str) -> NotificationService: + """Add a notification service. The service is tested by sending a test + notification to it. + + Args: + title (str): The title of the service. + url (str): The apprise url of the service. + + Raises: + URLInvalid: The url is invalid. + + Returns: + NotificationService: The instance representing the new service. + """ + LOGGER.info(f'Adding notification service with {title=}, {url=}') + + if self.test(url) != SendResult.SUCCESS: + raise URLInvalid(url) + + new_id = self.ns_db.add(title, url) + + return self.fetchone(new_id) + + def test(self, url: str) -> SendResult: + """Test a notification service by sending a test notification to it. + + Args: + url (str): The apprise url of the service. + + Returns: + SendResult: The result of the test. + """ + return send_apprise_notification( + [url], + Constants.APPRISE_TEST_TITLE, + Constants.APPRISE_TEST_BODY + ) diff --git a/backend/implementations/users.py b/backend/implementations/users.py new file mode 100644 index 0000000..be9617f --- /dev/null +++ b/backend/implementations/users.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- + +from typing import List, Union + +from backend.base.custom_exceptions import (AccessUnauthorized, + NewAccountsNotAllowed, + OperationNotAllowed, + UsernameInvalid, UsernameTaken, + UserNotFound) +from backend.base.definitions import Constants, InvalidUsernameReason, UserData +from backend.base.helpers import Singleton, generate_salt_hash, get_hash +from backend.base.logging import LOGGER +from backend.internals.db_models import UsersDB +from backend.internals.settings import Settings + + +def is_valid_username(username: str) -> None: + """Check if username is valid. + + Args: + username (str): The username to check. + + Raises: + UsernameInvalid: The username is not valid. + """ + if username in Constants.INVALID_USERNAMES: + raise UsernameInvalid(username, InvalidUsernameReason.NOT_ALLOWED) + + if username.isdigit(): + raise UsernameInvalid(username, InvalidUsernameReason.ONLY_NUMBERS) + + if any( + c not in Constants.USERNAME_CHARACTERS + for c in username + ): + raise UsernameInvalid( + username, + InvalidUsernameReason.INVALID_CHARACTER + ) + + return + + +class User: + def __init__(self, id: int) -> None: + """Create a representation of a user. + + Args: + id (int): The ID of the user. + + Raises: + UserNotFound: The user does not exist. + """ + self.user_db = UsersDB() + self.user_id = id + + if not self.user_db.exists(self.user_id): + raise UserNotFound(None, id) + + return + + def get(self) -> UserData: + """Get the info about the user. + + Returns: + UserData: The info about the user. + """ + return self.user_db.fetch(self.user_id)[0] + + def update( + self, + new_username: Union[str, None], + new_password: Union[str, None] + ) -> None: + """Change the username and/or password of the account. + + Args: + new_username (Union[str, None]): The new username, or None if it + should not be changed. + new_password (Union[str, None]): The new password, or None if it + should not be changed. + + Raises: + OperationNotAllowed: The user is an admin and is trying to change + the username. + UsernameInvalid: The new username is not valid. + UsernameTaken: The new username is already taken. + """ + user_data = self.get() + + if new_username is not None: + if user_data.admin: + raise OperationNotAllowed( + "Changing the username of an admin account" + ) + + is_valid_username(new_username) + + if self.user_db.taken(new_username): + raise UsernameTaken(new_username) + + self.user_db.update( + self.user_id, + new_username, + user_data.hash + ) + + LOGGER.info( + f"The user with ID {self.user_id} has a changed username: {new_username}" + ) + + user_data = self.get() + + if new_password is not None: + hash_password = get_hash(user_data.salt, new_password) + + self.user_db.update( + self.user_id, + user_data.username, + hash_password + ) + + LOGGER.info( + f'The user with ID {self.user_id} changed their password' + ) + + return + + def delete(self) -> None: + """Delete the user. + + Raises: + OperationNotAllowed: The admin account cannot be deleted. + """ + user_data = self.get() + if user_data.admin: + raise OperationNotAllowed( + "The admin account cannot be deleted" + ) + + LOGGER.info(f'Deleting the user with ID {self.user_id}') + + self.user_db.delete(self.user_id) + + return + + +class Users(metaclass=Singleton): + def __init__(self) -> None: + self.user_db = UsersDB() + return + + def get_all(self) -> List[UserData]: + """Get all user info for the admin + + Returns: + List[UserData]: The info about all users + """ + result = self.user_db.fetch() + return result + + def get_one(self, id: int) -> User: + """Get a user instance based on the ID. + + Args: + id (int): The ID of the user. + + Returns: + User: The user instance. + """ + return User(id) + + def __contains__(self, username_or_id: Union[str, int]) -> bool: + if isinstance(username_or_id, str): + return self.username_taken(username_or_id) + else: + return self.id_taken(username_or_id) + + def username_taken(self, username: str) -> bool: + """Check if a username is taken. + + Args: + username (str): The username to check. + + Returns: + bool: True if the username is taken, False otherwise. + """ + return self.user_db.taken(username) + + def id_taken(self, id: int) -> bool: + """Check if a user ID is taken. + + Args: + id (int): The user ID to check. + + Returns: + bool: True if the user ID is taken, False otherwise. + """ + return self.user_db.exists(id) + + def login( + self, + username: str, + password: str + ) -> User: + """Login into an user account. + + Args: + username (str): The username of the user. + password (str): The password of the user. + + Raises: + UserNotFound: There is no user with the given username. + AccessUnauthorized: The password is incorrect. + + Returns: + User: The user that was logged into. + """ + if not self.user_db.taken(username): + raise UserNotFound(username, None) + + user_data = self.user_db.fetch( + self.user_db.username_to_id(username) + )[0] + + hash_password = get_hash(user_data.salt, password) + # Comparing hashes, not password strings, so no need for + # constant time comparison + if not hash_password == user_data.hash: + raise AccessUnauthorized + + return User(user_data.id) + + def add( + self, + username: str, + password: str, + force: bool = False, + is_admin: bool = False + ) -> int: + """Add a user. + + Args: + username (str): The username of the new user. + + password (str): The password of the new user. + + force (bool, optional): Skip check for whether new accounts are + allowed. + Defaults to False. + + is_admin (bool, optional): The account is the admin account. + Defaults to False. + + Raises: + UsernameInvalid: Username not allowed or contains invalid characters. + UsernameTaken: Username is already taken; usernames must be unique. + NewAccountsNotAllowed: In the admin panel, new accounts are set to be + not allowed. + + Returns: + int: The ID of the new user. User registered successfully. + """ + LOGGER.info(f'Registering user with username {username}') + + if not force and not Settings().get_settings().allow_new_accounts: + raise NewAccountsNotAllowed + + is_valid_username(username) + + if self.user_db.taken(username): + raise UsernameTaken(username) + + if is_admin: + if self.user_db.taken(Constants.ADMIN_USERNAME): + # Attempted to add admin account (only done internally), + # but admin account already exists + raise RuntimeError("Admin account already exists") + + # Generate salt and key exclusive for user + salt, hashed_password = generate_salt_hash(password) + + # Add user to database + user_id = self.user_db.add( + username, + salt, + hashed_password, + is_admin + ) + + LOGGER.debug(f'Newly registered user has id {user_id}') + return user_id diff --git a/backend/internals/db.py b/backend/internals/db.py new file mode 100644 index 0000000..cb2c4bf --- /dev/null +++ b/backend/internals/db.py @@ -0,0 +1,493 @@ +# -*- coding: utf-8 -*- + +""" +Setting up the database and handling connections +""" + +from __future__ import annotations + +from os import remove +from os.path import dirname, exists, isdir, isfile, join +from shutil import move +from sqlite3 import (PARSE_DECLTYPES, Connection, Cursor, + OperationalError, ProgrammingError, Row, + register_adapter, register_converter) +from threading import current_thread +from typing import Any, Dict, Generator, Iterable, List, Type, Union + +from flask import g + +from backend.base.custom_exceptions import InvalidDatabaseFile +from backend.base.definitions import Constants, ReminderType, StartType, T +from backend.base.helpers import create_folder, folder_path, rename_file +from backend.base.logging import LOGGER, set_log_level +from backend.internals.db_migration import get_latest_db_version, migrate_db + +REMINDER_TO_KEY = { + ReminderType.REMINDER: "reminder_id", + ReminderType.STATIC_REMINDER: "static_reminder_id", + ReminderType.TEMPLATE: "template_id" +} + + +class MindCursor(Cursor): + + row_factory: Union[Type[Row], None] # type: ignore + + @property + def lastrowid(self) -> int: + return super().lastrowid or 1 + + def fetchonedict(self) -> Union[Dict[str, Any], None]: + """Same as `fetchone` but convert the Row object to a dict. + + Returns: + Union[Dict[str, Any], None]: The dict or None i.c.o. no result. + """ + r = self.fetchone() + if r is None: + return r + return dict(r) + + def fetchmanydict(self, size: Union[int, None] = 1) -> List[Dict[str, Any]]: + """Same as `fetchmany` but convert the Row object to a dict. + + Args: + size (Union[int, None], optional): The amount of rows to return. + Defaults to 1. + + Returns: + List[Dict[str, Any]]: The rows. + """ + return [dict(e) for e in self.fetchmany(size)] + + def fetchalldict(self) -> List[Dict[str, Any]]: + """Same as `fetchall` but convert the Row object to a dict. + + Returns: + List[Dict[str, Any]]: The results. + """ + return [dict(e) for e in self] + + def exists(self) -> Union[Any, None]: + """Return the first column of the first row, or `None` if not found. + + Returns: + Union[Any, None]: The value of the first column of the first row, + or `None` if not found. + """ + r = self.fetchone() + if r is None: + return r + return r[0] + + +class DBConnectionManager(type): + instances: Dict[int, DBConnection] = {} + + def __call__(cls, *args: Any, **kwargs: Any) -> DBConnection: + thread_id = current_thread().native_id or -1 + + if ( + not thread_id in cls.instances + or cls.instances[thread_id].closed + ): + cls.instances[thread_id] = super().__call__(*args, **kwargs) + + return cls.instances[thread_id] + + +class DBConnection(Connection, metaclass=DBConnectionManager): + file = '' + + def __init__(self, timeout: float) -> None: + """Create a connection with a database. + + Args: + timeout (float): How long to wait before giving up on a command. + """ + LOGGER.debug(f'Creating connection {self}') + super().__init__( + self.file, + timeout=timeout, + detect_types=PARSE_DECLTYPES + ) + super().cursor().execute("PRAGMA foreign_keys = ON;") + self.closed = False + return + + def cursor( # type: ignore + self, + force_new: bool = False + ) -> MindCursor: + """Get a database cursor from the connection. + + Args: + force_new (bool, optional): Get a new cursor instead of the cached + one. + Defaults to False. + + Returns: + MindCursor: The database cursor. + """ + if not hasattr(g, 'cursors'): + g.cursors = [] + + if not g.cursors: + c = MindCursor(self) + c.row_factory = Row + g.cursors.append(c) + + if not force_new: + return g.cursors[0] + else: + c = MindCursor(self) + c.row_factory = Row + g.cursors.append(c) + return g.cursors[-1] + + def close(self) -> None: + """Close the database connection""" + LOGGER.debug(f'Closing connection {self}') + self.closed = True + super().close() + return + + def __repr__(self) -> str: + return f'<{self.__class__.__name__}; {current_thread().name}; {id(self)}>' + + +def set_db_location( + db_folder: Union[str, None] +) -> None: + """Setup database location. Create folder for database and set location for + `db.DBConnection`. + + Args: + db_folder (Union[str, None], optional): The folder in which the database + will be stored or in which a database is for MIND to use. Give + `None` for the default location. + + Raises: + ValueError: Value of `db_folder` exists but is not a folder. + """ + if db_folder: + if exists(db_folder) and not isdir(db_folder): + raise ValueError('Database location is not a folder') + + db_file_location = join( + db_folder or folder_path(*Constants.DB_FOLDER), + Constants.DB_NAME + ) + + LOGGER.debug(f'Setting database location: {db_file_location}') + + create_folder(dirname(db_file_location)) + + if isfile(folder_path('db', 'Noted.db')): + rename_file( + folder_path('db', 'Noted.db'), + db_file_location + ) + + DBConnection.file = db_file_location + + return + + +def get_db(force_new: bool = False) -> MindCursor: + """Get a database cursor instance or create a new one if needed. + + Args: + force_new (bool, optional): Decides if a new cursor is + returned instead of the standard one. + Defaults to False. + + Returns: + MindCursor: Database cursor instance that outputs Row objects. + """ + cursor = ( + DBConnection(timeout=Constants.DB_TIMEOUT) + .cursor(force_new=force_new) + ) + return cursor + + +def commit() -> None: + """Commit the database""" + get_db().connection.commit() + return + + +def iter_commit(iterable: Iterable[T]) -> Generator[T, Any, Any]: + """Commit the database after each iteration. Also commits just before the + first iteration starts. + + Args: + iterable (Iterable[T]): Iterable that will be iterated over like normal. + + Yields: + Generator[T, Any, Any]: Items of iterable. + """ + commit = get_db().connection.commit + commit() + for i in iterable: + yield i + commit() + return + + +def close_db(e: Union[None, BaseException] = None) -> None: + """Close database cursor, commit database and close database. + + Args: + e (Union[None, BaseException], optional): Error. Defaults to None. + """ + try: + cursors = g.cursors + db: DBConnection = cursors[0].connection + for c in cursors: + c.close() + delattr(g, 'cursors') + db.commit() + if not current_thread().name.startswith('waitress-'): + db.close() + + except (AttributeError, ProgrammingError): + pass + + return + + +def close_all_db() -> None: + "Close all non-temporary database connections that are still open" + LOGGER.debug('Closing any open database connections') + + for i in DBConnectionManager.instances.values(): + if not i.closed: + i.close() + + c = DBConnection(timeout=20.0) + c.commit() + c.close() + return + + +def setup_db() -> None: + """ + Setup the database tables and default config when they aren't setup yet + """ + from backend.implementations.users import Users + from backend.internals.settings import Settings + + cursor = get_db() + cursor.execute("PRAGMA journal_mode = wal;") + register_adapter(bool, lambda b: int(b)) + register_converter("BOOL", lambda b: b == b'1') + + cursor.executescript(""" + CREATE TABLE IF NOT EXISTS users( + id INTEGER PRIMARY KEY, + username VARCHAR(255) UNIQUE NOT NULL, + salt VARCHAR(40) NOT NULL, + hash VARCHAR(100) NOT NULL, + admin BOOL NOT NULL DEFAULT 0 + ); + CREATE TABLE IF NOT EXISTS notification_services( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + title VARCHAR(255), + url TEXT, + + FOREIGN KEY (user_id) REFERENCES users(id) + ); + CREATE TABLE IF NOT EXISTS reminders( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + title VARCHAR(255) NOT NULL, + text TEXT, + time INTEGER NOT NULL, + + repeat_quantity VARCHAR(15), + repeat_interval INTEGER, + original_time INTEGER, + weekdays VARCHAR(13), + + color VARCHAR(7), + + FOREIGN KEY (user_id) REFERENCES users(id) + ); + CREATE TABLE IF NOT EXISTS templates( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + title VARCHAR(255) NOT NULL, + text TEXT, + + color VARCHAR(7), + + FOREIGN KEY (user_id) REFERENCES users(id) + ); + CREATE TABLE IF NOT EXISTS static_reminders( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + title VARCHAR(255) NOT NULL, + text TEXT, + + color VARCHAR(7), + + FOREIGN KEY (user_id) REFERENCES users(id) + ); + CREATE TABLE IF NOT EXISTS reminder_services( + reminder_id INTEGER, + static_reminder_id INTEGER, + template_id INTEGER, + notification_service_id INTEGER NOT NULL, + + FOREIGN KEY (reminder_id) REFERENCES reminders(id) + ON DELETE CASCADE, + FOREIGN KEY (static_reminder_id) REFERENCES static_reminders(id) + ON DELETE CASCADE, + FOREIGN KEY (template_id) REFERENCES templates(id) + ON DELETE CASCADE, + FOREIGN KEY (notification_service_id) REFERENCES notification_services(id) + ); + CREATE TABLE IF NOT EXISTS config( + key VARCHAR(255) PRIMARY KEY, + value BLOB NOT NULL + ); + """) + + settings = Settings() + settings_values = settings.get_settings() + + set_log_level(settings_values.log_level) + + migrate_db() + + # DB Migration might change settings, so update cache just to be sure. + settings._fetch_settings() + + # Add admin user if it doesn't exist + users = Users() + if Constants.ADMIN_USERNAME not in users: + users.add( + Constants.ADMIN_USERNAME, Constants.ADMIN_PASSWORD, + force=True, + is_admin=True + ) + + return + + +def revert_db_import( + swap: bool, + imported_db_file: str = '' +) -> None: + """Revert the database import process. The original_db_file is the file + currently used (`DBConnection.file`). + + Args: + swap (bool): Whether or not to keep the imported_db_file or not, + instead of the original_db_file. + + imported_db_file (str, optional): The other database file. Keep empty + to use `Constants.DB_ORIGINAL_FILENAME`. + Defaults to ''. + """ + original_db_file = DBConnection.file + if not imported_db_file: + imported_db_file = join( + dirname(DBConnection.file), + Constants.DB_ORIGINAL_NAME + ) + + if swap: + remove(original_db_file) + move( + imported_db_file, + original_db_file + ) + + else: + remove(imported_db_file) + + return + + +def import_db( + new_db_file: str, + copy_hosting_settings: bool +) -> None: + """Replace the current database with a new one. + + Args: + new_db_file (str): The path to the new database file. + copy_hosting_settings (bool): Keep the hosting settings from the current + database. + + Raises: + InvalidDatabaseFile: The new database file is invalid or unsupported. + """ + LOGGER.info(f'Importing new database; {copy_hosting_settings=}') + + cursor = Connection(new_db_file, timeout=20.0).cursor() + try: + database_version = cursor.execute( + "SELECT value FROM config WHERE key = 'database_version' LIMIT 1;" + ).fetchone()[0] + if not isinstance(database_version, int): + raise InvalidDatabaseFile(new_db_file) + + except (OperationalError, InvalidDatabaseFile): + LOGGER.error('Uploaded database is not a MIND database file') + cursor.connection.close() + revert_db_import( + swap=False, + imported_db_file=new_db_file + ) + raise InvalidDatabaseFile(new_db_file) + + if database_version > get_latest_db_version(): + LOGGER.error( + 'Uploaded database is higher version than this MIND installation can support') + revert_db_import( + swap=False, + imported_db_file=new_db_file + ) + raise InvalidDatabaseFile(new_db_file) + + if copy_hosting_settings: + hosting_settings = get_db().execute(""" + SELECT key, value + FROM config + WHERE key = 'host' + OR key = 'port' + OR key = 'url_prefix' + LIMIT 3; + """ + ).fetchalldict() + cursor.executemany(""" + INSERT INTO config(key, value) + VALUES (:key, :value) + ON CONFLICT(key) DO + UPDATE + SET value = :value; + """, + hosting_settings + ) + cursor.connection.commit() + cursor.connection.close() + + move( + DBConnection.file, + join(dirname(DBConnection.file), Constants.DB_ORIGINAL_NAME) + ) + move( + new_db_file, + DBConnection.file + ) + + from backend.internals.server import Server + Server().restart(StartType.RESTART_DB_CHANGES) + + return diff --git a/backend/internals/db_migration.py b/backend/internals/db_migration.py new file mode 100644 index 0000000..beb6c56 --- /dev/null +++ b/backend/internals/db_migration.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Type + +from backend.base.definitions import Constants, DBMigrator +from backend.base.logging import LOGGER + + +class VersionMappingContainer: + version_map: Dict[int, Type[DBMigrator]] = {} + + +def _load_version_map() -> None: + if VersionMappingContainer.version_map: + return + + VersionMappingContainer.version_map = { + m.start_version: m + for m in DBMigrator.__subclasses__() + } + return + + +def get_latest_db_version() -> int: + _load_version_map() + return max(VersionMappingContainer.version_map) + 1 + + +def migrate_db() -> None: + """ + Migrate a MIND database from it's current version + to the newest version supported by the MIND version installed. + """ + from backend.internals.db import iter_commit + from backend.internals.settings import Settings + + s = Settings() + current_db_version = s.get_settings().database_version + newest_version = get_latest_db_version() + if current_db_version == newest_version: + return + + LOGGER.info('Migrating database to newer version...') + LOGGER.debug( + "Database migration: %d -> %d", + current_db_version, newest_version + ) + + for start_version in iter_commit(range(current_db_version, newest_version)): + if start_version not in VersionMappingContainer.version_map: + continue + VersionMappingContainer.version_map[start_version]().run() + s.update({'database_version': start_version + 1}) + + s._fetch_settings() + + return + + +class MigrateToUTC(DBMigrator): + start_version = 1 + + def run(self) -> None: + # V1 -> V2 + + from datetime import datetime + from time import time + + from backend.internals.db import get_db + + cursor = get_db() + + t = time() + utc_offset = datetime.fromtimestamp(t) - datetime.utcfromtimestamp(t) + + cursor.execute("SELECT time, id FROM reminders;") + new_reminders = [ + [ + round(( + datetime.fromtimestamp(r["time"]) - utc_offset + ).timestamp()), + r["id"] + ] + for r in cursor + ] + + cursor.executemany( + "UPDATE reminders SET time = ? WHERE id = ?;", + new_reminders + ) + return + + +class MigrateAddColor(DBMigrator): + start_version = 2 + + def run(self) -> None: + # V2 -> V3 + + from backend.internals.db import get_db + + get_db().executescript(""" + ALTER TABLE reminders + ADD color VARCHAR(7); + ALTER TABLE templates + ADD color VARCHAR(7); + """) + + return + + +class MigrateFixRQ(DBMigrator): + start_version = 3 + + def run(self) -> None: + # V3 -> V4 + + from backend.internals.db import get_db + + get_db().executescript(""" + UPDATE reminders + SET repeat_quantity = repeat_quantity || 's' + WHERE repeat_quantity NOT LIKE '%s'; + """) + + return + + +class MigrateToReminderServices(DBMigrator): + start_version = 4 + + def run(self) -> None: + # V4 -> V5 + + from backend.internals.db import get_db + + get_db().executescript(""" + BEGIN TRANSACTION; + PRAGMA defer_foreign_keys = ON; + + CREATE TEMPORARY TABLE temp_reminder_services( + reminder_id, + static_reminder_id, + template_id, + notification_service_id + ); + + -- Reminders + INSERT INTO temp_reminder_services(reminder_id, notification_service_id) + SELECT id, notification_service + FROM reminders; + + CREATE TEMPORARY TABLE temp_reminders AS + SELECT id, user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color + FROM reminders; + DROP TABLE reminders; + CREATE TABLE reminders( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + title VARCHAR(255) NOT NULL, + text TEXT, + time INTEGER NOT NULL, + + repeat_quantity VARCHAR(15), + repeat_interval INTEGER, + original_time INTEGER, + + color VARCHAR(7), + + FOREIGN KEY (user_id) REFERENCES users(id) + ); + INSERT INTO reminders + SELECT * FROM temp_reminders; + + -- Templates + INSERT INTO temp_reminder_services(template_id, notification_service_id) + SELECT id, notification_service + FROM templates; + + CREATE TEMPORARY TABLE temp_templates AS + SELECT id, user_id, title, text, color + FROM templates; + DROP TABLE templates; + CREATE TABLE templates( + id INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + title VARCHAR(255) NOT NULL, + text TEXT, + + color VARCHAR(7), + + FOREIGN KEY (user_id) REFERENCES users(id) + ); + INSERT INTO templates + SELECT * FROM temp_templates; + + INSERT INTO reminder_services + SELECT * FROM temp_reminder_services; + + COMMIT; + """) + + return + + +class MigrateRemoveUser1(DBMigrator): + start_version = 5 + + def run(self) -> None: + # V5 -> V6 + from backend.base.custom_exceptions import (AccessUnauthorized, + UserNotFound) + from backend.implementations.users import Users + + try: + Users().login('User1', 'Password1').delete() + + except (UserNotFound, AccessUnauthorized): + pass + + return + + +class MigrateAddWeekdays(DBMigrator): + start_version = 6 + + def run(self) -> None: + # V6 -> V7 + + from backend.internals.db import get_db + + get_db().executescript(""" + ALTER TABLE reminders + ADD weekdays VARCHAR(13); + """) + + return + + +class MigrateAddAdmin(DBMigrator): + start_version = 7 + + def run(self) -> None: + # V7 -> V8 + + from backend.implementations.users import Users + from backend.internals.db import get_db + from backend.internals.settings import Settings + + cursor = get_db() + + cursor.executescript(""" + DROP TABLE config; + CREATE TABLE IF NOT EXISTS config( + key VARCHAR(255) PRIMARY KEY, + value BLOB NOT NULL + ); + """ + ) + Settings()._insert_missing_settings() + + cursor.executescript(""" + ALTER TABLE users + ADD admin BOOL NOT NULL DEFAULT 0; + """ + ) + users = Users() + if 'admin' in users: + users.get_one( + users.user_db.username_to_id('admin') + ).update( + new_username='admin_old', + new_password=None + ) + + users.add( + Constants.ADMIN_USERNAME, Constants.ADMIN_PASSWORD, + force=True, + is_admin=True + ) + + return + + +class MigrateHostSettingsToDB(DBMigrator): + start_version = 8 + + def run(self) -> None: + # V8 -> V9 + # In newer versions, the variables don't exist anymore, and behaviour + # was to then set the values to the default values. But that's already + # taken care of by the settings, so nothing to do here anymore. + return + + +class MigrateUpdateManifest(DBMigrator): + start_version = 9 + + def run(self) -> None: + # V9 -> V10 + + # Nothing is changed in the database + # It's just that this code needs to run once + # and the DB migration system does exactly that: + # run pieces of code once. + from backend.internals.settings import Settings, update_manifest + + update_manifest( + Settings().get_settings().url_prefix + ) + + return diff --git a/backend/internals/db_models.py b/backend/internals/db_models.py new file mode 100644 index 0000000..22fbd1c --- /dev/null +++ b/backend/internals/db_models.py @@ -0,0 +1,815 @@ +# -*- coding: utf-8 -*- + +from typing import List, Union + +from backend.base.definitions import (NotificationServiceData, ReminderData, + ReminderType, StaticReminderData, + TemplateData, UserData) +from backend.base.helpers import first_of_column +from backend.internals.db import REMINDER_TO_KEY, get_db + + +class NotificationServicesDB: + def __init__(self, user_id: int) -> None: + self.user_id = user_id + return + + def exists(self, notification_service_id: int) -> bool: + return get_db().execute(""" + SELECT 1 + FROM notification_services + WHERE id = :id + AND user_id = :user_id + LIMIT 1; + """, + { + 'user_id': self.user_id, + 'id': notification_service_id + } + ).fetchone() is not None + + def fetch( + self, + notification_service_id: Union[int, None] = None + ) -> List[NotificationServiceData]: + id_filter = "" + if notification_service_id: + id_filter = "AND id = :ns_id" + + result = get_db().execute(f""" + SELECT + id, title, url + FROM notification_services + WHERE user_id = :user_id + {id_filter} + ORDER BY title, id; + """, + { + "user_id": self.user_id, + "ns_id": notification_service_id + } + ).fetchalldict() + + return [ + NotificationServiceData(**entry) + for entry in result + ] + + def add( + self, + title: str, + url: str + ) -> int: + new_id = get_db().execute(""" + INSERT INTO notification_services(user_id, title, url) + VALUES (?, ?, ?) + """, + (self.user_id, title, url) + ).lastrowid + return new_id + + def update( + self, + notification_service_id: int, + title: str, + url: str + ) -> None: + get_db().execute(""" + UPDATE notification_services + SET title = :title, url = :url + WHERE id = :ns_id; + """, + { + "title": title, + "url": url, + "ns_id": notification_service_id + } + ) + return + + def delete( + self, + notification_service_id: int + ) -> None: + get_db().execute( + "DELETE FROM notification_services WHERE id = ?;", + (notification_service_id,) + ) + return + + +class ReminderServicesDB: + def __init__(self, reminder_type: ReminderType) -> None: + self.key = REMINDER_TO_KEY[reminder_type] + return + + def reminder_to_ns( + self, + reminder_id: int + ) -> List[int]: + """Get the ID's of the notification services that are linked to the given + reminder, static reminder or template. + + Args: + reminder_id (int): The ID of the reminder, static reminder or template. + + Returns: + List[int]: A list of the notification service ID's that are linked to + the given reminder, static reminder or template. + """ + result = first_of_column(get_db().execute( + f""" + SELECT notification_service_id + FROM reminder_services + WHERE {self.key} = ?; + """, + (reminder_id,) + )) + + return result + + def update_ns_bindings( + self, + reminder_id: int, + notification_services: List[int] + ) -> None: + """Update the bindings of a reminder, static reminder or template to + notification services. + + Args: + reminder_id (int): The ID of the reminder, static reminder or template. + + notification_services (List[int]): The new list of notification services + that should be linked to the reminder, static reminder or template. + """ + cursor = get_db() + cursor.connection.isolation_level = None + cursor.execute("BEGIN TRANSACTION;") + + cursor.execute( + f""" + DELETE FROM reminder_services + WHERE {self.key} = ?; + """, + (reminder_id,) + ) + cursor.executemany( + f""" + INSERT INTO reminder_services( + {self.key}, + notification_service_id + ) + VALUES (?, ?); + """, + ((reminder_id, ns_id) for ns_id in notification_services) + ) + + cursor.execute("COMMIT;") + cursor.connection.isolation_level = "" + return + + def uses_ns( + self, + notification_service_id: int + ) -> List[int]: + """Get the ID's of the reminders (of given type) that use the given + notification service. + + Args: + notification_service_id (int): The ID of the notification service to + check for. + + Returns: + List[int]: The ID's of the reminders (only of the given type) that + use the notification service. + """ + return first_of_column(get_db().execute( + f""" + SELECT {self.key} + FROM reminder_services + WHERE notification_service_id = ? + AND {self.key} IS NOT NULL + LIMIT 1; + """, + (notification_service_id,) + )) + + +class UsersDB: + def exists(self, user_id: int) -> bool: + return get_db().execute(""" + SELECT 1 + FROM users + WHERE id = ? + LIMIT 1; + """, + (user_id,) + ).fetchone() is not None + + def taken(self, username: str) -> bool: + return get_db().execute(""" + SELECT 1 + FROM users + WHERE username = ? + LIMIT 1; + """, + (username,) + ).fetchone() is not None + + def username_to_id(self, username: str) -> int: + return get_db().execute(""" + SELECT id + FROM users + WHERE username = ? + LIMIT 1; + """, + (username,) + ).fetchone()[0] + + def fetch( + self, + user_id: Union[int, None] = None + ) -> List[UserData]: + id_filter = "" + if user_id: + id_filter = "WHERE id = :id" + + result = get_db().execute(f""" + SELECT + id, username, admin, salt, hash + FROM users + {id_filter} + ORDER BY username, id; + """, + { + "id": user_id + } + ).fetchalldict() + + return [ + UserData(**entry) + for entry in result + ] + + def add( + self, + username: str, + salt: bytes, + hash: bytes, + admin: bool + ) -> int: + user_id = get_db().execute( + """ + INSERT INTO users(username, salt, hash, admin) + VALUES (?, ?, ?, ?); + """, + (username, salt, hash, admin) + ).lastrowid + return user_id + + def update( + self, + user_id: int, + username: str, + hash: bytes + ) -> None: + get_db().execute(""" + UPDATE users + SET username = :username, hash = :hash + WHERE id = :user_id; + """, + { + "username": username, + "hash": hash, + "user_id": user_id + } + ) + return + + def delete( + self, + user_id: int + ) -> None: + get_db().executescript(f""" + BEGIN TRANSACTION; + + DELETE FROM reminders WHERE user_id = {user_id}; + DELETE FROM templates WHERE user_id = {user_id}; + DELETE FROM static_reminders WHERE user_id = {user_id}; + DELETE FROM notification_services WHERE user_id = {user_id}; + DELETE FROM users WHERE id = {user_id}; + + COMMIT; + """) + return + + +class TemplatesDB: + def __init__(self, user_id: int) -> None: + self.user_id = user_id + self.rms_db = ReminderServicesDB(ReminderType.TEMPLATE) + return + + def exists(self, template_id: int) -> bool: + return get_db().execute( + "SELECT 1 FROM templates WHERE id = ? AND user_id = ? LIMIT 1;", + (template_id, self.user_id) + ).fetchone() is not None + + def fetch( + self, + template_id: Union[int, None] = None + ) -> List[TemplateData]: + id_filter = "" + if template_id: + id_filter = "AND id = :t_id" + + result = get_db().execute(f""" + SELECT + id, title, text, color + FROM templates + WHERE user_id = :user_id + {id_filter} + ORDER BY title, id; + """, + { + "user_id": self.user_id, + "t_id": template_id + } + ).fetchalldict() + + for r in result: + r['notification_services'] = self.rms_db.reminder_to_ns(r['id']) + + return [ + TemplateData(**entry) + for entry in result + ] + + def add( + self, + title: str, + text: Union[str, None], + color: Union[str, None], + notification_services: List[int] + ) -> int: + new_id = get_db().execute(""" + INSERT INTO templates(user_id, title, text, color) + VALUES (?, ?, ?, ?); + """, + (self.user_id, title, text, color) + ).lastrowid + + self.rms_db.update_ns_bindings( + new_id, notification_services + ) + + return new_id + + def update( + self, + template_id: int, + title: str, + text: Union[str, None], + color: Union[str, None], + notification_services: List[int] + ) -> None: + get_db().execute(""" + UPDATE templates + SET + title = :title, + text = :text, + color = :color + WHERE id = :t_id; + """, + { + "title": title, + "text": text, + "color": color, + "t_id": template_id + } + ) + + self.rms_db.update_ns_bindings( + template_id, + notification_services + ) + + return + + def delete( + self, + template_id: int + ) -> None: + get_db().execute( + "DELETE FROM templates WHERE id = ?;", + (template_id,) + ) + return + + +class StaticRemindersDB: + def __init__(self, user_id: int) -> None: + self.user_id = user_id + self.rms_db = ReminderServicesDB(ReminderType.STATIC_REMINDER) + return + + def exists(self, reminder_id: int) -> bool: + return get_db().execute(""" + SELECT 1 + FROM static_reminders + WHERE id = ? + AND user_id = ? + LIMIT 1; + """, + (reminder_id, self.user_id) + ).fetchone() is not None + + def fetch( + self, + reminder_id: Union[int, None] = None + ) -> List[StaticReminderData]: + id_filter = "" + if reminder_id: + id_filter = "AND id = :r_id" + + result = get_db().execute(f""" + SELECT + id, title, text, color + FROM static_reminders + WHERE user_id = :user_id + {id_filter} + ORDER BY title, id; + """, + { + "user_id": self.user_id, + "r_id": reminder_id + } + ).fetchalldict() + + for r in result: + r['notification_services'] = self.rms_db.reminder_to_ns(r['id']) + + return [ + StaticReminderData(**entry) + for entry in result + ] + + def add( + self, + title: str, + text: Union[str, None], + color: Union[str, None], + notification_services: List[int] + ) -> int: + new_id = get_db().execute(""" + INSERT INTO static_reminders(user_id, title, text, color) + VALUES (?, ?, ?, ?); + """, + (self.user_id, title, text, color) + ).lastrowid + + self.rms_db.update_ns_bindings( + new_id, notification_services + ) + + return new_id + + def update( + self, + reminder_id: int, + title: str, + text: Union[str, None], + color: Union[str, None], + notification_services: List[int] + ) -> None: + get_db().execute(""" + UPDATE static_reminders + SET + title = :title, + text = :text, + color = :color + WHERE id = :r_id; + """, + { + "title": title, + "text": text, + "color": color, + "r_id": reminder_id + } + ) + + self.rms_db.update_ns_bindings( + reminder_id, + notification_services + ) + + return + + def delete( + self, + reminder_id: int + ) -> None: + get_db().execute( + "DELETE FROM static_reminders WHERE id = ?;", + (reminder_id,) + ) + return + + +class RemindersDB: + def __init__(self, user_id: int) -> None: + self.user_id = user_id + self.rms_db = ReminderServicesDB(ReminderType.REMINDER) + return + + def exists(self, reminder_id: int) -> bool: + return get_db().execute(""" + SELECT 1 + FROM reminders + WHERE id = ? + AND user_id = ? + LIMIT 1; + """, + (reminder_id, self.user_id) + ).fetchone() is not None + + def fetch( + self, + reminder_id: Union[int, None] = None + ) -> List[ReminderData]: + id_filter = "" + if reminder_id: + id_filter = "AND id = :r_id" + + result = get_db().execute(f""" + SELECT + id, title, text, color, + time, original_time, + repeat_quantity, repeat_interval, + weekdays AS _weekdays + FROM reminders + WHERE user_id = :user_id + {id_filter}; + """, + { + "user_id": self.user_id, + "r_id": reminder_id + } + ).fetchalldict() + + for r in result: + r['notification_services'] = self.rms_db.reminder_to_ns(r['id']) + + return [ + ReminderData(**entry) + for entry in result + ] + + def add( + self, + title: str, + text: Union[str, None], + time: int, + repeat_quantity: Union[str, None], + repeat_interval: Union[int, None], + weekdays: Union[str, None], + original_time: Union[int, None], + color: Union[str, None], + notification_services: List[int] + ) -> int: + new_id = get_db().execute(""" + INSERT INTO reminders( + user_id, + title, text, + time, + repeat_quantity, repeat_interval, + weekdays, + original_time, + color + ) + VALUES ( + :user_id, + :title, :text, + :time, + :rq, :ri, + :wd, + :ot, + :color + ); + """, + { + "user_id": self.user_id, + "title": title, + "text": text, + "time": time, + "rq": repeat_quantity, + "ri": repeat_interval, + "wd": weekdays, + "ot": original_time, + "color": color + } + ).lastrowid + + self.rms_db.update_ns_bindings( + new_id, notification_services + ) + + return new_id + + def update( + self, + reminder_id: int, + title: str, + text: Union[str, None], + time: int, + repeat_quantity: Union[str, None], + repeat_interval: Union[int, None], + weekdays: Union[str, None], + original_time: Union[int, None], + color: Union[str, None], + notification_services: List[int] + ) -> None: + get_db().execute(""" + UPDATE reminders + SET + title = :title, + text = :text, + time = :time, + repeat_quantity = :rq, + repeat_interval = :ri, + weekdays = :wd, + original_time = :ot, + color = :color + WHERE id = :r_id; + """, + { + "title": title, + "text": text, + "time": time, + "rq": repeat_quantity, + "ri": repeat_interval, + "wd": weekdays, + "ot": original_time, + "color": color, + "r_id": reminder_id + } + ) + + self.rms_db.update_ns_bindings( + reminder_id, + notification_services + ) + + return + + def delete( + self, + reminder_id: int + ) -> None: + get_db().execute( + "DELETE FROM reminders WHERE id = ?;", + (reminder_id,) + ) + return + + +class UserlessRemindersDB: + def __init__(self) -> None: + self.rms_db = ReminderServicesDB(ReminderType.REMINDER) + return + + def exists(self, reminder_id: int) -> bool: + return get_db().execute(""" + SELECT 1 + FROM reminders + WHERE id = ? + LIMIT 1; + """, + (reminder_id,) + ).fetchone() is not None + + def reminder_id_to_user_id(self, reminder_id: int) -> int: + return get_db().execute( + """ + SELECT user_id + FROM reminders + WHERE id = ? + LIMIT 1; + """, + (reminder_id,) + ).exists() or -1 + + def get_soonest_time(self) -> Union[int, None]: + return get_db().execute("SELECT MIN(time) FROM reminders;").exists() + + def fetch( + self, + time: Union[int, None] = None + ) -> List[ReminderData]: + time_filter = "" + if time: + time_filter = "WHERE time = :time" + + result = get_db().execute(f""" + SELECT + id, + title, text, color, + time, original_time, + repeat_quantity, repeat_interval, + weekdays AS _weekdays + FROM reminders + {time_filter}; + """, + { + "time": time + } + ).fetchalldict() + + for r in result: + r['notification_services'] = self.rms_db.reminder_to_ns(r['id']) + + return [ + ReminderData(**entry) + for entry in result + ] + + def add( + self, + user_id: int, + title: str, + text: Union[str, None], + time: int, + repeat_quantity: Union[str, None], + repeat_interval: Union[int, None], + weekdays: Union[str, None], + original_time: Union[int, None], + color: Union[str, None], + notification_services: List[int] + ) -> int: + new_id = get_db().execute(""" + INSERT INTO reminders( + user_id, + title, text, + time, + repeat_quantity, repeat_interval, + weekdays, + original_time, + color + ) + VALUES ( + :user_id, + :title, :text, + :time, + :rq, :ri, + :wd, + :ot, + :color + ); + """, + { + "user_id": user_id, + "title": title, + "text": text, + "time": time, + "rq": repeat_quantity, + "ri": repeat_interval, + "wd": weekdays, + "ot": original_time, + "color": color + } + ).lastrowid + + self.rms_db.update_ns_bindings( + new_id, notification_services + ) + + return new_id + + def update( + self, + reminder_id: int, + time: int + ) -> None: + get_db().execute(""" + UPDATE reminders + SET time = :time + WHERE id = :r_id; + """, + { + "time": time, + "r_id": reminder_id + } + ) + + return + + def delete( + self, + reminder_id: int + ) -> None: + get_db().execute( + "DELETE FROM reminders WHERE id = ?;", + (reminder_id,) + ) + return diff --git a/backend/internals/server.py b/backend/internals/server.py new file mode 100644 index 0000000..df43976 --- /dev/null +++ b/backend/internals/server.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- + +""" +Setting up, running and shutting down the API and web-ui +""" + +from __future__ import annotations + +from os import urandom +from threading import Timer, current_thread +from typing import TYPE_CHECKING, Union + +from flask import Flask, render_template, request +from waitress.server import create_server +from waitress.task import ThreadedTaskDispatcher as TTD +from werkzeug.middleware.dispatcher import DispatcherMiddleware + +from backend.base.definitions import Constants, StartType +from backend.base.helpers import Singleton, folder_path +from backend.base.logging import LOGGER +from backend.internals.db import (DBConnectionManager, + close_db, revert_db_import) +from backend.internals.settings import Settings + +if TYPE_CHECKING: + from waitress.server import BaseWSGIServer, MultiSocketServer + + +class ThreadedTaskDispatcher(TTD): + def handler_thread(self, thread_no: int) -> None: + super().handler_thread(thread_no) + + thread_id = current_thread().native_id or -1 + if ( + thread_id in DBConnectionManager.instances + and not DBConnectionManager.instances[thread_id].closed + ): + DBConnectionManager.instances[thread_id].close() + + return + + def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool: + print() + LOGGER.info('Shutting down MIND') + result = super().shutdown(cancel_pending, timeout) + return result + + +def handle_start_type(start_type: StartType) -> None: + """Do special actions needed based on restart version. + + Args: + start_type (StartType): The restart version. + """ + if start_type == StartType.RESTART_HOSTING_CHANGES: + LOGGER.info("Starting timer for hosting changes") + Server().revert_hosting_timer.start() + + elif start_type == StartType.RESTART_DB_CHANGES: + LOGGER.info("Starting timer for database import") + Server().revert_db_timer.start() + + return + + +def diffuse_timers() -> None: + """Stop any timers running after doing a special restart.""" + + SERVER = Server() + + if SERVER.revert_hosting_timer.is_alive(): + LOGGER.info("Timer for hosting changes diffused") + SERVER.revert_hosting_timer.cancel() + + elif SERVER.revert_db_timer.is_alive(): + LOGGER.info("Timer for database import diffused") + SERVER.revert_db_timer.cancel() + revert_db_import(swap=False) + + return + + +class Server(metaclass=Singleton): + api_prefix = "/api" + admin_api_extension = "/admin" + admin_prefix = "/api/admin" + url_prefix = '' + + def __init__(self) -> None: + self.start_type = None + + self.revert_db_timer = Timer( + Constants.DB_REVERT_TIME, + revert_db_import, + kwargs={"swap": True} + ) + self.revert_db_timer.name = "DatabaseImportHandler" + + self.revert_hosting_timer = Timer( + Constants.HOSTING_REVERT_TIME, + self.restore_hosting_settings + ) + self.revert_hosting_timer.name = "HostingHandler" + + return + + def create_app(self) -> None: + """Creates an flask app instance that can be used to start a web server""" + + from frontend.api import admin_api, api + from frontend.ui import ui + + app = Flask( + __name__, + template_folder=folder_path('frontend', 'templates'), + static_folder=folder_path('frontend', 'static'), + static_url_path='/static' + ) + app.config['SECRET_KEY'] = urandom(32) + app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True + app.config['JSON_SORT_KEYS'] = False + + # Add error handlers + @app.errorhandler(400) + def bad_request(e): + return {'error': "BadRequest", "result": {}}, 400 + + @app.errorhandler(405) + def method_not_allowed(e): + return {'error': "MethodNotAllowed", "result": {}}, 405 + + @app.errorhandler(500) + def internal_error(e): + return {'error': "InternalError", "result": {}}, 500 + + # Add endpoints + app.register_blueprint(ui) + app.register_blueprint(api, url_prefix=self.api_prefix) + app.register_blueprint(admin_api, url_prefix=self.admin_prefix) + + # Setup db handling + app.teardown_appcontext(close_db) + + self.app = app + return + + def set_url_prefix(self, url_prefix: str) -> None: + """Change the URL prefix of the server. + + Args: + url_prefix (str): The desired URL prefix to set it to. + """ + self.app.config["APPLICATION_ROOT"] = url_prefix + self.app.wsgi_app = DispatcherMiddleware( # type: ignore + Flask(__name__), + {url_prefix: self.app.wsgi_app} + ) + self.url_prefix = url_prefix + return + + def __create_waitress_server( + self, + host: str, + port: int + ) -> Union[MultiSocketServer, BaseWSGIServer]: + """From the `Flask` instance created in `self.create_app()`, create + a waitress server instance. + + Args: + host (str): Where to host the server on (e.g. `0.0.0.0`). + port (int): The port to host the server on (e.g. `5656`). + + Returns: + Union[MultiSocketServer, BaseWSGIServer]: The waitress server instance. + """ + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(Constants.HOSTING_THREADS) + + server = create_server( + self.app, + _dispatcher=dispatcher, + host=host, + port=port, + threads=Constants.HOSTING_THREADS + ) + return server + + def run(self, host: str, port: int) -> None: + """Start the webserver. + + Args: + host (str): Where to host the server on (e.g. `0.0.0.0`). + port (int): The port to host the server on (e.g. `5656`). + """ + self.server = self.__create_waitress_server(host, port) + LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}') + self.server.run() + + return + + def __shutdown_thread_function(self) -> None: + """Shutdown waitress server. Intended to be run in a thread. + """ + if not hasattr(self, 'server'): + return + + self.server.task_dispatcher.shutdown() + self.server.close() + self.server._map.clear() # type: ignore + return + + def shutdown(self) -> None: + """ + Stop the waitress server. Starts a thread that shuts down the server. + """ + t = Timer(1.0, self.__shutdown_thread_function) + t.name = "InternalStateHandler" + t.start() + return + + def restart( + self, + start_type: StartType = StartType.STARTUP + ) -> None: + """Same as `self.shutdown()`, but restart instead of shutting down. + + Args: + start_type (StartType, optional): Why Kapowarr should + restart. + Defaults to StartType.STARTUP. + """ + self.start_type = start_type + self.shutdown() + return + + def restore_hosting_settings(self) -> None: + with self.app.app_context(): + settings = Settings() + values = settings.get_settings() + main_settings = { + 'host': values.backup_host, + 'port': values.backup_port, + 'url_prefix': values.backup_url_prefix + } + settings.update(main_settings) + self.restart() + return diff --git a/backend/internals/settings.py b/backend/internals/settings.py new file mode 100644 index 0000000..3e18c79 --- /dev/null +++ b/backend/internals/settings.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- + +from dataclasses import _MISSING_TYPE, asdict, dataclass +from functools import lru_cache +from json import dump, load +from logging import DEBUG, INFO +from typing import Any, Dict, Mapping + +from backend.base.custom_exceptions import InvalidKeyValue, KeyNotFound +from backend.base.helpers import (Singleton, folder_path, + get_python_version, reversed_tuples) +from backend.base.logging import LOGGER, set_log_level +from backend.internals.db import DBConnection, commit, get_db +from backend.internals.db_migration import get_latest_db_version + +THIRTY_DAYS = 2592000 + + +@lru_cache(1) +def get_about_data() -> Dict[str, Any]: + """Get data about the application and it's environment. + + Raises: + RuntimeError: If the version is not found in the pyproject.toml file. + + Returns: + Dict[str, Any]: The information. + """ + with open(folder_path("pyproject.toml"), "r") as f: + for line in f: + if line.startswith("version = "): + version = "V" + line.split('"')[1] + break + else: + raise RuntimeError("Version not found in pyproject.toml") + + return { + "version": version, + "python_version": get_python_version(), + "database_version": get_latest_db_version(), + "database_location": DBConnection.file, + "data_folder": folder_path() + } + + +@dataclass(frozen=True) +class SettingsValues: + database_version: int = get_latest_db_version() + log_level: int = INFO + + host: str = '0.0.0.0' + port: int = 8080 + url_prefix: str = '' + backup_host: str = '0.0.0.0' + backup_port: int = 8080 + backup_url_prefix: str = '' + + allow_new_accounts: bool = True + login_time: int = 3600 + login_time_reset: bool = True + + def todict(self) -> Dict[str, Any]: + return { + k: v + for k, v in self.__dict__.items() + if not k.startswith('backup_') + } + + +class Settings(metaclass=Singleton): + def __init__(self) -> None: + self._insert_missing_settings() + self._fetch_settings() + return + + def _insert_missing_settings(self) -> None: + "Insert any missing keys from the settings into the database." + get_db().executemany( + "INSERT OR IGNORE INTO config(key, value) VALUES (?, ?);", + asdict(SettingsValues()).items() + ) + commit() + return + + def _fetch_settings(self) -> None: + "Load the settings from the database into the cache." + db_values = { + k: v + for k, v in get_db().execute( + "SELECT key, value FROM config;" + ) + if k in SettingsValues.__dataclass_fields__ + } + + for b_key in ('allow_new_accounts', 'login_time_reset'): + db_values[b_key] = bool(db_values[b_key]) + + self.__cached_values = SettingsValues(**db_values) + return + + def get_settings(self) -> SettingsValues: + """Get the settings from the cache. + + Returns: + SettingsValues: The settings. + """ + return self.__cached_values + + # Alias, better in one-liners + # sv = Settings Values + @property + def sv(self) -> SettingsValues: + """Get the settings from the cache. + + Returns: + SettingsValues: The settings. + """ + return self.__cached_values + + def update( + self, + data: Mapping[str, Any] + ) -> None: + """Change the settings, in a `dict.update()` type of way. + + Args: + data (Mapping[str, Any]): The keys and their new values. + + Raises: + KeyNotFound: Key is not a setting. + InvalidKeyValue: Value of the key is not allowed. + """ + formatted_data = {} + for key, value in data.items(): + formatted_data[key] = self.__format_setting(key, value) + + get_db().executemany( + "UPDATE config SET value = ? WHERE key = ?;", + reversed_tuples(formatted_data.items()) + ) + + for key, handler in ( + ('url_prefix', update_manifest), + ('log_level', set_log_level) + ): + if ( + key in data + and formatted_data[key] != getattr(self.get_settings(), key) + ): + handler(formatted_data[key]) + + self._fetch_settings() + + LOGGER.info(f"Settings changed: {formatted_data}") + + return + + def reset(self, key: str) -> None: + """Reset the value of the key to the default value. + + Args: + key (str): The key of which to reset the value. + + Raises: + KeyNotFound: Key is not a setting. + """ + LOGGER.debug(f'Setting reset: {key}') + + if not isinstance( + SettingsValues.__dataclass_fields__[key].default_factory, + _MISSING_TYPE + ): + self.update({ + key: SettingsValues.__dataclass_fields__[key].default_factory() + }) + else: + self.update({ + key: SettingsValues.__dataclass_fields__[key].default + }) + + return + + def backup_hosting_settings(self) -> None: + "Backup the hosting settings in the database." + s = self.get_settings() + backup_settings = { + 'backup_host': s.host, + 'backup_port': s.port, + 'backup_url_prefix': s.url_prefix + } + self.update(backup_settings) + return + + def __format_setting(self, key: str, value: Any) -> Any: + """Check if the value of a setting is allowed and convert if needed. + + Args: + key (str): Key of setting. + value (Any): Value of setting. + + Raises: + KeyNotFound: Key is not a setting. + InvalidKeyValue: Value is not allowed. + + Returns: + Any: (Converted) Setting value. + """ + converted_value = value + + if key not in SettingsValues.__dataclass_fields__: + raise KeyNotFound(key) + + key_data = SettingsValues.__dataclass_fields__[key] + + if not isinstance(value, key_data.type): + raise InvalidKeyValue(key, value) + + if key == 'login_time': + if not 60 <= value <= THIRTY_DAYS: + raise InvalidKeyValue(key, value) + + elif key in ('port', 'backup_port'): + if not 1 <= value <= 65535: + raise InvalidKeyValue(key, value) + + elif key in ('url_prefix', 'backup_url_prefix'): + if value: + converted_value = ('/' + value.lstrip('/')).rstrip('/') + + elif key == 'log_level': + if value not in (INFO, DEBUG): + raise InvalidKeyValue(key, value) + + return converted_value + + +def update_manifest(url_base: str) -> None: + """Update the url's in the manifest file. + Needs to happen when url base changes. + + Args: + url_base (str): The url base to use in the file. + """ + filename = folder_path('frontend', 'static', 'json', 'pwa_manifest.json') + + with open(filename, 'r') as f: + manifest = load(f) + manifest['start_url'] = url_base + '/' + manifest['scope'] = url_base + '/' + manifest['icons'][0]['src'] = f'{url_base}/static/img/favicon.svg' + + with open(filename, 'w') as f: + dump(manifest, f, indent=4) + + return diff --git a/backend/logging.py b/backend/logging.py deleted file mode 100644 index a93ad74..0000000 --- a/backend/logging.py +++ /dev/null @@ -1,141 +0,0 @@ -#-*- coding: utf-8 -*- - -import logging -import logging.config -from os.path import exists -from typing import Any - -from backend.helpers import folder_path - - -class InfoOnlyFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - return record.levelno == logging.INFO - - -class DebuggingOnlyFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - return LOGGER.level == logging.DEBUG - - -class ErrorColorFormatter(logging.Formatter): - def format(self, record: logging.LogRecord) -> Any: - result = super().format(record) - return f'\033[1;31:40m{result}\033[0m' - - -LOGGER_NAME = "MIND" -LOGGER_DEBUG_FILENAME = "MIND_debug.log" -LOGGER = logging.getLogger(LOGGER_NAME) -LOGGING_CONFIG = { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "simple": { - "format": "[%(asctime)s][%(levelname)s] %(message)s", - "datefmt": "%H:%M:%S" - }, - "simple_red": { - "()": ErrorColorFormatter, - "format": "[%(asctime)s][%(levelname)s] %(message)s", - "datefmt": "%H:%M:%S" - }, - "detailed": { - "format": "%(asctime)s | %(threadName)s | %(filename)sL%(lineno)s | %(levelname)s | %(message)s", - "datefmt": "%Y-%m-%dT%H:%M:%S%z", - } - }, - "filters": { - "only_info": { - "()": InfoOnlyFilter - }, - "only_if_debugging": { - "()": DebuggingOnlyFilter - } - }, - "handlers": { - "console_error": { - "class": "logging.StreamHandler", - "level": "WARNING", - "formatter": "simple_red", - "stream": "ext://sys.stderr" - }, - "console": { - "class": "logging.StreamHandler", - "level": "INFO", - "formatter": "simple", - "filters": ["only_info"], - "stream": "ext://sys.stdout" - }, - "debug_file": { - "class": "logging.StreamHandler", - "level": "DEBUG", - "formatter": "detailed", - "filters": ["only_if_debugging"], - "stream": "" - } - }, - "loggers": { - LOGGER_NAME: { - "level": "INFO" - } - }, - "root": { - "level": "DEBUG", - "handlers": [ - "console", - "console_error", - "debug_file" - ] - } -} - -def setup_logging() -> None: - "Setup the basic config of the logging module" - logging.config.dictConfig(LOGGING_CONFIG) - return - -def get_debug_log_filepath() -> str: - """ - Get the filepath to the debug logging file. - Not in a global variable to avoid unnecessary computation. - """ - return folder_path(LOGGER_DEBUG_FILENAME) - -def set_log_level( - level: int, - clear_file: bool = True -) -> None: - """Change the logging level - - Args: - level (int): The level to set the logging to. - Should be a logging level, like `logging.INFO` or `logging.DEBUG`. - - clear_file (bool, optional): Empty the debug logging file. - Defaults to True. - """ - if LOGGER.level == level: - return - - LOGGER.debug(f'Setting logging level: {level}') - LOGGER.setLevel(level) - - if level == logging.DEBUG: - stream_handler = logging.getLogger().handlers[ - LOGGING_CONFIG["root"]["handlers"].index('debug_file') - ] - - file = get_debug_log_filepath() - - if clear_file: - if exists(file): - open(file, "w").close() - else: - open(file, "x").close() - - stream_handler.setStream( - open(file, "a") - ) - - return diff --git a/backend/notification_service.py b/backend/notification_service.py deleted file mode 100644 index 2612793..0000000 --- a/backend/notification_service.py +++ /dev/null @@ -1,405 +0,0 @@ -#-*- coding: utf-8 -*- - -from re import compile -from typing import Dict, List, Optional, Union - -from apprise import Apprise - -from backend.custom_exceptions import (NotificationServiceInUse, - NotificationServiceNotFound) -from backend.db import get_db -from backend.helpers import when_not_none -from backend.logging import LOGGER - -remove_named_groups = compile(r'(?<=\()\?P<\w+>') - -def process_regex(regex: Union[List[str], None]) -> Union[None, List[str]]: - return when_not_none( - regex, - lambda r: [remove_named_groups.sub('', r[0]), r[1]] - ) - -def _sort_tokens(t: dict) -> List[int]: - result = [ - int(not t['required']) - ] - - if t['type'] == 'choice': - result.append(0) - elif t['type'] != 'list': - result.append(1) - else: - result.append(2) - - return result - -def get_apprise_services() -> List[Dict[str, Union[str, Dict[str, list]]]]: - apprise_services = [] - raw = Apprise().details() - for entry in raw['schemas']: - entry: Dict[str, Union[str, dict]] - result: Dict[str, Union[str, Dict[str, list]]] = { - 'name': str(entry['service_name']), - 'doc_url': entry['setup_url'], - 'details': { - 'templates': entry['details']['templates'], - 'tokens': [], - 'args': [] - } - } - - schema = entry['details']['tokens']['schema'] - result['details']['tokens'].append({ - 'name': schema['name'], - 'map_to': 'schema', - 'required': schema['required'], - 'type': 'choice', - 'options': schema['values'], - 'default': schema.get('default') - }) - - handled_tokens = {'schema'} - result['details']['tokens'] += [ - { - 'name': v['name'], - 'map_to': k, - 'required': v['required'], - 'type': 'list', - 'delim': v['delim'][0], - 'content': [ - { - 'name': content['name'], - 'required': content['required'], - 'type': content['type'], - 'prefix': content.get('prefix'), - 'regex': process_regex(content.get('regex')) - } - for content, _ in ( - (entry['details']['tokens'][e], handled_tokens.add(e)) - for e in v['group'] - ) - ] - } - for k, v in - filter( - lambda t: t[1]['type'].startswith('list:'), - entry['details']['tokens'].items() - ) - ] - handled_tokens.update( - set(map(lambda e: e[0], - filter(lambda e: e[1]['type'].startswith('list:'), - entry['details']['tokens'].items()) - )) - ) - - result['details']['tokens'] += [ - { - 'name': v['name'], - 'map_to': k, - 'required': v['required'], - 'type': v['type'].split(':')[0], - **({ - 'options': v.get('values'), - 'default': v.get('default') - } if v['type'].startswith('choice') else { - 'prefix': v.get('prefix'), - 'min': v.get('min'), - 'max': v.get('max'), - 'regex': process_regex(v.get('regex')) - }) - } - for k, v in - filter( - lambda t: not t[0] in handled_tokens, - entry['details']['tokens'].items() - ) - ] - - result['details']['tokens'].sort(key=_sort_tokens) - - result['details']['args'] += [ - { - 'name': v.get('name', k), - 'map_to': k, - 'required': v.get('required', False), - 'type': v['type'].split(':')[0], - **({ - 'delim': v['delim'][0], - 'content': [] - } if v['type'].startswith('list') else { - 'options': v['values'], - 'default': v.get('default') - } if v['type'].startswith('choice') else { - 'default': v['default'] - } if v['type'] == 'bool' else { - 'min': v.get('min'), - 'max': v.get('max'), - 'regex': process_regex(v.get('regex')) - }) - } - for k, v in - filter( - lambda a: ( - a[1].get('alias_of') is None - and not a[0] in ('cto', 'format', 'overflow', 'rto', 'verify') - ), - entry['details']['args'].items() - ) - ] - result['details']['args'].sort(key=_sort_tokens) - - apprise_services.append(result) - - apprise_services.sort(key=lambda s: s['name'].lower()) - - apprise_services.insert(0, { - 'name': 'Custom URL', - 'doc_url': 'https://github.com/caronc/apprise#supported-notifications', - 'details': { - 'templates': ['{url}'], - 'tokens': [{ - 'name': 'Apprise URL', - 'map_to': 'url', - 'required': True, - 'type': 'string', - 'prefix': None, - 'min': None, - 'max': None, - 'regex': None - }], - 'args': [] - } - }) - - return apprise_services - -class NotificationService: - def __init__(self, user_id: int, notification_service_id: int) -> None: - self.id = notification_service_id - - if not get_db().execute(""" - SELECT 1 - FROM notification_services - WHERE id = ? - AND user_id = ? - LIMIT 1; - """, - (self.id, user_id) - ).fetchone(): - raise NotificationServiceNotFound - - def get(self) -> dict: - """Get the info about the notification service - - Returns: - dict: The info about the notification service - """ - result = dict(get_db(dict).execute(""" - SELECT id, title, url - FROM notification_services - WHERE id = ? - LIMIT 1 - """, - (self.id,) - ).fetchone()) - - return result - - def update( - self, - title: Optional[str] = None, - url: Optional[str] = None - ) -> dict: - """Edit the notification service - - Args: - title (Optional[str], optional): The new title of the service. Defaults to None. - url (Optional[str], optional): The new url of the service. Defaults to None. - - Returns: - dict: The new info about the service - """ - LOGGER.info(f'Updating notification service {self.id}: {title=}, {url=}') - - # Get current data and update it with new values - data = self.get() - new_values = { - 'title': title, - 'url': url - } - for k, v in new_values.items(): - if v is not None: - data[k] = v - - # Update database - get_db().execute(""" - UPDATE notification_services - SET title = ?, url = ? - WHERE id = ?; - """, - ( - data["title"], - data["url"], - self.id - ) - ) - - return self.get() - - def delete( - self, - delete_reminders_using: bool = False - ) -> None: - """Delete the service. - - Args: - delete_reminders_using (bool, optional): Instead of throwing an - error when there are still reminders using the service, delete - the reminders. - Defaults to False. - - Raises: - NotificationServiceInUse: The service is still used by a reminder. - """ - LOGGER.info(f'Deleting notification service {self.id}') - - cursor = get_db() - if not delete_reminders_using: - # Check if no reminders exist with this service - cursor.execute(""" - SELECT 1 - FROM reminder_services - WHERE notification_service_id = ? - AND reminder_id IS NOT NULL - LIMIT 1; - """, - (self.id,) - ) - if cursor.fetchone(): - raise NotificationServiceInUse('reminder') - - # Check if no templates exist with this service - cursor.execute(""" - SELECT 1 - FROM reminder_services - WHERE notification_service_id = ? - AND template_id IS NOT NULL - LIMIT 1; - """, - (self.id,) - ) - if cursor.fetchone(): - raise NotificationServiceInUse('template') - - # Check if no static reminders exist with this service - cursor.execute(""" - SELECT 1 - FROM reminder_services - WHERE notification_service_id = ? - AND static_reminder_id IS NOT NULL - LIMIT 1; - """, - (self.id,) - ) - if cursor.fetchone(): - raise NotificationServiceInUse('static reminder') - - else: - cursor.execute(""" - DELETE FROM reminders - WHERE id IN ( - SELECT reminder_id AS id FROM reminder_services - WHERE notification_service_id = ? - ); - """, (self.id,)) - cursor.execute(""" - DELETE FROM static_reminders - WHERE id IN ( - SELECT static_reminder_id AS id FROM reminder_services - WHERE notification_service_id = ? - ); - """, (self.id,)) - cursor.execute(""" - DELETE FROM templates - WHERE id IN ( - SELECT template_id AS id FROM reminder_services - WHERE notification_service_id = ? - ); - """, (self.id,)) - - cursor.execute( - "DELETE FROM notification_services WHERE id = ?", - (self.id,) - ) - return - -class NotificationServices: - def __init__(self, user_id: int) -> None: - self.user_id = user_id - - def fetchall(self) -> List[dict]: - """Get a list of all notification services - - Returns: - List[dict]: The list of all notification services - """ - result = list(map(dict, get_db(dict).execute(""" - SELECT - id, title, url - FROM notification_services - WHERE user_id = ? - ORDER BY title, id; - """, - (self.user_id,) - ))) - - return result - - def fetchone(self, notification_service_id: int) -> NotificationService: - """Get one notification service based on it's id - - Args: - notification_service_id (int): The id of the desired service - - Returns: - NotificationService: Instance of NotificationService - """ - return NotificationService(self.user_id, notification_service_id) - - def add(self, title: str, url: str) -> NotificationService: - """Add a notification service - - Args: - title (str): The title of the service - url (str): The apprise url of the service - - Returns: - NotificationService: The instance representing the new service - """ - LOGGER.info(f'Adding notification service with {title=}, {url=}') - - new_id = get_db().execute(""" - INSERT INTO notification_services(user_id, title, url) - VALUES (?,?,?) - """, - (self.user_id, title, url) - ).lastrowid - - return self.fetchone(new_id) - - def test_service( - self, - url: str - ) -> None: - """Send a test notification using the supplied Apprise URL - - Args: - url (str): The Apprise URL to use to send the test notification - """ - LOGGER.info(f'Testing service with {url=}') - a = Apprise() - a.add(url) - a.notify(title='MIND: Test title', body='MIND: Test body') - return - \ No newline at end of file diff --git a/backend/reminders.py b/backend/reminders.py deleted file mode 100644 index 5d2b0a3..0000000 --- a/backend/reminders.py +++ /dev/null @@ -1,796 +0,0 @@ -#-*- coding: utf-8 -*- - -from __future__ import annotations - -from datetime import datetime -from sqlite3 import IntegrityError -from threading import Timer -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union - -from apprise import Apprise -from dateutil.relativedelta import relativedelta -from dateutil.relativedelta import weekday as du_weekday - -from backend.custom_exceptions import (InvalidKeyValue, InvalidTime, - NotificationServiceNotFound, - ReminderNotFound) -from backend.db import get_db -from backend.helpers import (RepeatQuantity, Singleton, SortingMethod, - search_filter, when_not_none) -from backend.logging import LOGGER - -if TYPE_CHECKING: - from flask.ctx import AppContext - - -def __next_selected_day( - weekdays: List[int], - weekday: int -) -> int: - """Find the next allowed day in the week. - - Args: - weekdays (List[int]): The days of the week that are allowed. - Monday is 0, Sunday is 6. - weekday (int): The current weekday. - - Returns: - int: The next allowed weekday. - """ - return ( - # Get all days later than current, then grab first one. - [d for d in weekdays if weekday < d] - or - # weekday is last allowed day, so it should grab the first - # allowed day of the week. - weekdays - )[0] - -def _find_next_time( - original_time: int, - repeat_quantity: Union[RepeatQuantity, None], - repeat_interval: Union[int, None], - weekdays: Union[List[int], None] -) -> int: - """Calculate the next timestep based on original time and repeat/interval - values. - - Args: - original_time (int): The original time of the repeating timestamp. - - repeat_quantity (Union[RepeatQuantity, None]): If set, what the quantity - is of the repetition. - - repeat_interval (Union[int, None]): If set, the value of the repetition. - - weekdays (Union[List[int], None]): If set, on which days the time can - continue. Monday is 0, Sunday is 6. - - Returns: - int: The next timestamp in the future. - """ - if weekdays is not None: - weekdays.sort() - - new_time = datetime.fromtimestamp(original_time) - current_time = datetime.fromtimestamp(datetime.utcnow().timestamp()) - - if repeat_quantity is not None: - td = relativedelta(**{repeat_quantity.value: repeat_interval}) - while new_time <= current_time: - new_time += td - - elif weekdays is not None: - # We run the loop contents at least once and then actually use the cond. - # This is because we need to force the 'free' date to go to one of the - # selected weekdays. - # Say it's Monday, we set a reminder for Wednesday and make it repeat - # on Tuesday and Thursday. Then the first notification needs to go on - # Thurday, not Wednesday. So run code at least once to force that. - # Afterwards, it can run normally to push the timestamp into the future. - one_to_go = True - while one_to_go or new_time <= current_time: - next_day = __next_selected_day(weekdays, new_time.weekday()) - proposed_time = new_time + relativedelta(weekday=du_weekday(next_day)) - if proposed_time == new_time: - proposed_time += relativedelta(weekday=du_weekday(next_day, 2)) - new_time = proposed_time - one_to_go = False - - result = int(new_time.timestamp()) - LOGGER.debug( - f'{original_time=}, {current_time=} ' + - f'and interval of {repeat_interval} {repeat_quantity} ' + - f'leads to {result}' - ) - return result - - -class Reminder: - """Represents a reminder - """ - def __init__(self, user_id: int, reminder_id: int) -> None: - """Create an instance. - - Args: - user_id (int): The ID of the user. - reminder_id (int): The ID of the reminder. - - Raises: - ReminderNotFound: Reminder with given ID does not exist or is not - owned by user. - """ - self.id = reminder_id - - # Check if reminder exists - if not get_db().execute( - "SELECT 1 FROM reminders WHERE id = ? AND user_id = ? LIMIT 1", - (self.id, user_id) - ).fetchone(): - raise ReminderNotFound - - return - - def _get_notification_services(self) -> List[int]: - """Get ID's of notification services linked to the reminder. - - Returns: - List[int]: The list with ID's. - """ - result = [ - r[0] - for r in get_db().execute(""" - SELECT notification_service_id - FROM reminder_services - WHERE reminder_id = ?; - """, - (self.id,) - ) - ] - return result - - def get(self) -> dict: - """Get info about the reminder - - Returns: - dict: The info about the reminder - """ - reminder = get_db(dict).execute(""" - SELECT - id, - title, text, - time, - repeat_quantity, - repeat_interval, - weekdays, - color - FROM reminders - WHERE id = ? - LIMIT 1; - """, - (self.id,) - ).fetchone() - reminder = dict(reminder) - - reminder["weekdays"] = [ - int(n) - for n in reminder["weekdays"].split(",") - if n - ] if reminder["weekdays"] else None - reminder['notification_services'] = self._get_notification_services() - - return reminder - - def update( - self, - title: Union[None, str] = None, - time: Union[None, int] = None, - notification_services: Union[None, List[int]] = None, - text: Union[None, str] = None, - repeat_quantity: Union[None, RepeatQuantity] = None, - repeat_interval: Union[None, int] = None, - weekdays: Union[None, List[int]] = None, - color: Union[None, str] = None - ) -> dict: - """Edit the reminder. - - Args: - title (Union[None, str]): The new title of the entry. - Defaults to None. - - time (Union[None, int]): The new UTC epoch timestamp when the - reminder should be send. - Defaults to None. - - notification_services (Union[None, List[int]]): The new list - of id's of the notification services to use to send the reminder. - Defaults to None. - - text (Union[None, str], optional): The new body of the reminder. - Defaults to None. - - repeat_quantity (Union[None, RepeatQuantity], optional): The new - quantity of the repeat specified for the reminder. - Defaults to None. - - repeat_interval (Union[None, int], optional): The new amount of - repeat_quantity, like "5" (hours). - Defaults to None. - - weekdays (Union[None, List[int]], optional): The new indexes of - the days of the week that the reminder should run. - Defaults to None. - - color (Union[None, str], optional): The new hex code of the color - of the reminder, which is shown in the web-ui. - Defaults to None. - - Note about args: - Either repeat_quantity and repeat_interval are given, weekdays is - given or neither, but not both. - - Raises: - NotificationServiceNotFound: One of the notification services was not found. - InvalidKeyValue: The value of one of the keys is not valid or - the "Note about args" is violated. - - Returns: - dict: The new reminder info. - """ - LOGGER.info( - f'Updating notification service {self.id}: ' - + f'{title=}, {time=}, {notification_services=}, {text=}, ' - + f'{repeat_quantity=}, {repeat_interval=}, {weekdays=}, {color=}' - ) - cursor = get_db() - - # Validate data - if repeat_quantity is None and repeat_interval is not None: - raise InvalidKeyValue('repeat_quantity', repeat_quantity) - elif repeat_quantity is not None and repeat_interval is None: - raise InvalidKeyValue('repeat_interval', repeat_interval) - elif weekdays is not None and repeat_quantity is not None: - raise InvalidKeyValue('weekdays', weekdays) - - repeated_reminder = ( - (repeat_quantity is not None and repeat_interval is not None) - or weekdays is not None - ) - - if time is not None: - if not repeated_reminder: - if time < datetime.utcnow().timestamp(): - raise InvalidTime - time = round(time) - - # Get current data and update it with new values - data = self.get() - new_values = { - 'title': title, - 'time': time, - 'text': text, - 'repeat_quantity': repeat_quantity, - 'repeat_interval': repeat_interval, - 'weekdays': when_not_none( - weekdays, - lambda w: ",".join(map(str, sorted(w))) - ), - 'color': color - } - for k, v in new_values.items(): - if ( - k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color') - or v is not None - ): - data[k] = v - - # Update database - rq = when_not_none( - data["repeat_quantity"], - lambda q: q.value - ) - if repeated_reminder: - next_time = _find_next_time( - data["time"], - data["repeat_quantity"], - data["repeat_interval"], - weekdays - ) - cursor.execute(""" - UPDATE reminders - SET - title=?, - text=?, - time=?, - repeat_quantity=?, - repeat_interval=?, - weekdays=?, - original_time=?, - color=? - WHERE id = ?; - """, ( - data["title"], - data["text"], - next_time, - rq, - data["repeat_interval"], - data["weekdays"], - data["time"], - data["color"], - self.id - )) - - else: - next_time = data["time"] - cursor.execute(""" - UPDATE reminders - SET - title=?, - text=?, - time=?, - repeat_quantity=?, - repeat_interval=?, - weekdays=?, - color=? - WHERE id = ?; - """, ( - data["title"], - data["text"], - data["time"], - rq, - data["repeat_interval"], - data["weekdays"], - data["color"], - self.id - )) - - if notification_services: - cursor.connection.isolation_level = None - cursor.execute("BEGIN TRANSACTION;") - cursor.execute( - "DELETE FROM reminder_services WHERE reminder_id = ?", - (self.id,) - ) - try: - cursor.executemany(""" - INSERT INTO reminder_services( - reminder_id, - notification_service_id - ) - VALUES (?,?); - """, - ((self.id, s) for s in notification_services) - ) - cursor.execute("COMMIT;") - - except IntegrityError: - raise NotificationServiceNotFound - - finally: - cursor.connection.isolation_level = "" - - ReminderHandler().find_next_reminder(next_time) - return self.get() - - def delete(self) -> None: - """Delete the reminder - """ - LOGGER.info(f'Deleting reminder {self.id}') - get_db().execute("DELETE FROM reminders WHERE id = ?", (self.id,)) - ReminderHandler().find_next_reminder() - return - -class Reminders: - """Represents the reminder library of the user account - """ - - def __init__(self, user_id: int) -> None: - """Create an instance. - - Args: - user_id (int): The ID of the user. - """ - self.user_id = user_id - return - - def fetchall( - self, - sort_by: SortingMethod = SortingMethod.TIME - ) -> List[dict]: - """Get all reminders - - Args: - sort_by (SortingMethod, optional): How to sort the result. - Defaults to SortingMethod.TIME. - - Returns: - List[dict]: The id, title, text, time and color of each reminder - """ - reminders = [ - dict(r) - for r in get_db(dict).execute(""" - SELECT - id, - title, text, - time, - repeat_quantity, - repeat_interval, - weekdays, - color - FROM reminders - WHERE user_id = ?; - """, - (self.user_id,) - ) - ] - for r in reminders: - r["weekdays"] = [ - int(n) - for n in r["weekdays"].split(",") - if n - ] if r["weekdays"] else None - - # Sort result - reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1]) - - return reminders - - def search( - self, - query: str, - sort_by: SortingMethod = SortingMethod.TIME) -> List[dict]: - """Search for reminders - - Args: - query (str): The term to search for. - sort_by (SortingMethod, optional): How to sort the result. - Defaults to SortingMethod.TIME. - - Returns: - List[dict]: All reminders that match. Similar output to self.fetchall - """ - reminders = [ - r for r in self.fetchall(sort_by) - if search_filter(query, r) - ] - return reminders - - def fetchone(self, id: int) -> Reminder: - """Get one reminder - - Args: - id (int): The id of the reminder to fetch - - Returns: - Reminder: A Reminder instance - """ - return Reminder(self.user_id, id) - - def add( - self, - title: str, - time: int, - notification_services: List[int], - text: str = '', - repeat_quantity: Union[None, RepeatQuantity] = None, - repeat_interval: Union[None, int] = None, - weekdays: Union[None, List[int]] = None, - color: Union[None, str] = None - ) -> Reminder: - """Add a reminder - - Args: - title (str): The title of the entry. - - time (int): The UTC epoch timestamp the the reminder should be send. - - notification_services (List[int]): The id's of the notification services - to use to send the reminder. - - text (str, optional): The body of the reminder. - Defaults to ''. - - repeat_quantity (Union[None, RepeatQuantity], optional): The quantity - of the repeat specified for the reminder. - Defaults to None. - - repeat_interval (Union[None, int], optional): The amount of repeat_quantity, - like "5" (hours). - Defaults to None. - - weekdays (Union[None, List[int]], optional): The indexes of the days - of the week that the reminder should run. - Defaults to None. - - color (Union[None, str], optional): The hex code of the color of the - reminder, which is shown in the web-ui. - Defaults to None. - - Note about args: - Either repeat_quantity and repeat_interval are given, - weekdays is given or neither, but not both. - - Raises: - NotificationServiceNotFound: One of the notification services was not found. - InvalidKeyValue: The value of one of the keys is not valid - or the "Note about args" is violated. - - Returns: - dict: The info about the reminder. - """ - LOGGER.info( - f'Adding reminder with {title=}, {time=}, {notification_services=}, ' - + f'{text=}, {repeat_quantity=}, {repeat_interval=}, {weekdays=}, {color=}' - ) - - if time < datetime.utcnow().timestamp(): - raise InvalidTime - time = round(time) - - if repeat_quantity is None and repeat_interval is not None: - raise InvalidKeyValue('repeat_quantity', repeat_quantity) - elif repeat_quantity is not None and repeat_interval is None: - raise InvalidKeyValue('repeat_interval', repeat_interval) - elif ( - weekdays is not None - and repeat_quantity is not None - and repeat_interval is not None - ): - raise InvalidKeyValue('weekdays', weekdays) - - cursor = get_db() - for service in notification_services: - if not cursor.execute(""" - SELECT 1 - FROM notification_services - WHERE id = ? - AND user_id = ? - LIMIT 1; - """, - (service, self.user_id) - ).fetchone(): - raise NotificationServiceNotFound - - # Prepare args - if any((repeat_quantity, weekdays)): - original_time = time - time = _find_next_time( - original_time, - repeat_quantity, - repeat_interval, - weekdays - ) - else: - original_time = None - - weekdays_str = when_not_none( - weekdays, - lambda w: ",".join(map(str, sorted(w))) - ) - repeat_quantity_str = when_not_none( - repeat_quantity, - lambda q: q.value - ) - - cursor.connection.isolation_level = None - cursor.execute("BEGIN TRANSACTION;") - - id = cursor.execute(""" - INSERT INTO reminders( - user_id, - title, text, - time, - repeat_quantity, repeat_interval, - weekdays, - original_time, - color - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); - """, ( - self.user_id, - title, text, - time, - repeat_quantity_str, - repeat_interval, - weekdays_str, - original_time, - color - )).lastrowid - - try: - cursor.executemany(""" - INSERT INTO reminder_services( - reminder_id, - notification_service_id - ) - VALUES (?, ?); - """, - ((id, service) for service in notification_services) - ) - cursor.execute("COMMIT;") - - except IntegrityError: - raise NotificationServiceNotFound - - finally: - cursor.connection.isolation_level = '' - - ReminderHandler().find_next_reminder(time) - - return self.fetchone(id) - - def test_reminder( - self, - title: str, - notification_services: List[int], - text: str = '' - ) -> None: - """Test send a reminder draft. - - Args: - title (str): Title title of the entry. - - notification_service (int): The id of the notification service to - use to send the reminder. - - text (str, optional): The body of the reminder. - Defaults to ''. - """ - LOGGER.info(f'Testing reminder with {title=}, {notification_services=}, {text=}') - a = Apprise() - cursor = get_db(dict) - - for service in notification_services: - url = cursor.execute(""" - SELECT url - FROM notification_services - WHERE id = ? - AND user_id = ? - LIMIT 1; - """, - (service, self.user_id) - ).fetchone() - if not url: - raise NotificationServiceNotFound - a.add(url[0]) - - a.notify(title=title, body=text or '\u200B') - return - - -class ReminderHandler(metaclass=Singleton): - """Handle set reminders. - - Note: Singleton. - """ - def __init__( - self, - context: Callable[[], AppContext] - ) -> None: - """Create instance of handler. - - Args: - context (Optional[AppContext], optional): `Flask.app_context`. - Defaults to None. - """ - self.context = context - self.thread: Union[Timer, None] = None - self.time: Union[int, None] = None - return - - def __trigger_reminders(self, time: int) -> None: - """Trigger all reminders that are set for a certain time - - Args: - time (int): The time of the reminders to trigger - """ - with self.context(): - cursor = get_db(dict) - reminders = [ - dict(r) - for r in cursor.execute(""" - SELECT - id, user_id, - title, text, - repeat_quantity, repeat_interval, - weekdays, - original_time - FROM reminders - WHERE time = ?; - """, - (time,) - ) - ] - - for reminder in reminders: - cursor.execute(""" - SELECT url - FROM reminder_services rs - INNER JOIN notification_services ns - ON rs.notification_service_id = ns.id - WHERE rs.reminder_id = ?; - """, - (reminder['id'],) - ) - - # Send reminder - a = Apprise() - for url in cursor: - a.add(url['url']) - a.notify(title=reminder["title"], body=reminder["text"] or '\u200B') - - self.thread = None - self.time = None - - if (reminder['repeat_quantity'], reminder['weekdays']) == (None, None): - # Delete the reminder from the database - Reminder(reminder["user_id"], reminder["id"]).delete() - - else: - # Set next time - new_time = _find_next_time( - reminder['original_time'], - when_not_none( - reminder["repeat_quantity"], - lambda q: RepeatQuantity(q) - ), - reminder['repeat_interval'], - when_not_none( - reminder["weekdays"], - lambda w: [int(d) for d in w.split(',')] - ) - ) - cursor.execute( - "UPDATE reminders SET time = ? WHERE id = ?;", - (new_time, reminder['id']) - ) - - self.find_next_reminder() - return - - def find_next_reminder(self, time: Optional[int] = None) -> None: - """Determine when the soonest reminder is and set the timer to that time - - Args: - time (Optional[int], optional): The timestamp to check for. - Otherwise check soonest in database. - Defaults to None. - """ - if time is None: - with self.context(): - soonest_time: Union[Tuple[int], None] = get_db().execute(""" - SELECT DISTINCT r1.time - FROM reminders r1 - LEFT JOIN reminders r2 - ON r1.time > r2.time - WHERE r2.id IS NULL; - """).fetchone() - if soonest_time is None: - return - time = soonest_time[0] - - if ( - self.thread is None - or time < self.time - ): - if self.thread is not None: - self.thread.cancel() - - t = time - datetime.utcnow().timestamp() - self.thread = Timer( - t, - self.__trigger_reminders, - (time,) - ) - self.thread.name = "ReminderHandler" - self.thread.start() - self.time = time - - return - - def stop_handling(self) -> None: - """Stop the timer if it's active - """ - if self.thread is not None: - self.thread.cancel() - return diff --git a/backend/security.py b/backend/security.py deleted file mode 100644 index b5afaa1..0000000 --- a/backend/security.py +++ /dev/null @@ -1,40 +0,0 @@ -#-*- coding: utf-8 -*- - -""" -Hashing and salting -""" - -from base64 import urlsafe_b64encode -from hashlib import pbkdf2_hmac -from secrets import token_bytes -from typing import Tuple - - -def get_hash(salt: bytes, data: str) -> bytes: - """Hash a string using the supplied salt - - Args: - salt (bytes): The salt to use when hashing - data (str): The data to hash - - Returns: - bytes: The b64 encoded hash of the supplied string - """ - return urlsafe_b64encode( - pbkdf2_hmac('sha256', data.encode(), salt, 100_000) - ) - -def generate_salt_hash(password: str) -> Tuple[bytes, bytes]: - """Generate a salt and get the hash of the password - - Args: - password (str): The password to generate for - - Returns: - Tuple[bytes, bytes]: The salt (1) and hashed_password (2) - """ - salt = token_bytes() - hashed_password = get_hash(salt, password) - del password - - return salt, hashed_password diff --git a/backend/server.py b/backend/server.py deleted file mode 100644 index 9f133a3..0000000 --- a/backend/server.py +++ /dev/null @@ -1,264 +0,0 @@ -#-*- coding: utf-8 -*- - -from __future__ import annotations - -from os import execv, urandom -from sys import argv -from threading import Timer, current_thread -from typing import TYPE_CHECKING, List, NoReturn, Union - -from flask import Flask, render_template, request -from waitress import create_server -from waitress.task import ThreadedTaskDispatcher as TTD -from werkzeug.middleware.dispatcher import DispatcherMiddleware - -from backend.db import DB_Singleton, DBConnection, close_db, revert_db_import -from backend.helpers import RestartVars, Singleton, folder_path -from backend.logging import LOGGER -from backend.settings import restore_hosting_settings - -if TYPE_CHECKING: - from waitress.server import TcpWSGIServer - -THREADS = 10 - -class ThreadedTaskDispatcher(TTD): - def handler_thread(self, thread_no: int) -> None: - super().handler_thread(thread_no) - i = f'{DBConnection}{current_thread()}' - if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed: - DB_Singleton._instances[i].close() - return - - def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool: - print() - LOGGER.info('Shutting down MIND') - result = super().shutdown(cancel_pending, timeout) - DBConnection(timeout=20.0).close() - return result - - -class Server(metaclass=Singleton): - api_prefix = "/api" - admin_api_extension = "/admin" - admin_prefix = "/api/admin" - - def __init__(self) -> None: - self.do_restart = False - "Restart instead of shutdown" - - self.restart_args: List[str] = [] - "Flag to run with when restarting" - - self.handle_flags: bool = False - "Run any flag specific actions before restarting" - - self.url_prefix = "" - - self.revert_db_timer = Timer(60.0, self.__revert_db) - self.revert_db_timer.name = "DatabaseImportHandler" - self.revert_hosting_timer = Timer(60.0, self.__revert_hosting) - self.revert_hosting_timer.name = "HostingHandler" - - return - - def create_app(self) -> None: - """Create a Flask app instance""" - from frontend.api import admin_api, api - from frontend.ui import ui - - app = Flask( - __name__, - template_folder=folder_path('frontend','templates'), - static_folder=folder_path('frontend','static'), - static_url_path='/static' - ) - app.config['SECRET_KEY'] = urandom(32) - app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True - app.config['JSON_SORT_KEYS'] = False - - # Add error handlers - @app.errorhandler(400) - def bad_request(e): - return {'error': 'Bad request', 'result': {}}, 400 - - @app.errorhandler(405) - def method_not_allowed(e): - return {'error': 'Method not allowed', 'result': {}}, 405 - - @app.errorhandler(500) - def internal_error(e): - return {'error': 'Internal error', 'result': {}}, 500 - - @app.errorhandler(404) - def not_found(e): - if request.path.startswith(self.api_prefix): - return {'error': 'Not Found', 'result': {}}, 404 - return render_template('page_not_found.html', url_prefix=self.url_prefix) - - app.register_blueprint(ui) - app.register_blueprint(api, url_prefix=self.api_prefix) - app.register_blueprint(admin_api, url_prefix=self.admin_prefix) - - # Setup closing database - app.teardown_appcontext(close_db) - - self.app = app - return - - def set_url_prefix(self, url_prefix: str) -> None: - """Change the URL prefix of the server. - - Args: - url_prefix (str): The desired URL prefix to set it to. - """ - self.app.config["APPLICATION_ROOT"] = url_prefix - self.app.wsgi_app = DispatcherMiddleware( - Flask(__name__), - {url_prefix: self.app.wsgi_app} - ) - self.url_prefix = url_prefix - return - - def __create_waitress_server( - self, - host: str, - port: int - ) -> TcpWSGIServer: - """From the `Flask` instance created in `self.create_app()`, create - a waitress server instance. - - Args: - host (str): The host to bind to. - port (int): The port to listen on. - - Returns: - TcpWSGIServer: The waitress server. - """ - dispatcher = ThreadedTaskDispatcher() - dispatcher.set_thread_count(THREADS) - server = create_server( - self.app, - _dispatcher=dispatcher, - host=host, - port=port, - threads=THREADS - ) - return server - - def run(self, host: str, port: int) -> None: - """Start the webserver. - - Args: - host (str): The host to bind to. - port (int): The port to listen on. - """ - self.server = self.__create_waitress_server(host, port) - LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}') - self.server.run() - - return - - def __shutdown_thread_function(self) -> None: - """Shutdown waitress server. Intended to be run in a thread. - """ - self.server.close() - self.server.task_dispatcher.shutdown() - self.server._map.clear() - return - - def shutdown(self) -> None: - """Stop the waitress server. Starts a thread that - shuts down the server. - """ - t = Timer(1.0, self.__shutdown_thread_function) - t.name = "InternalStateHandler" - t.start() - return - - def restart( - self, - restart_args: List[str] = [], - handle_flags: bool = False - ) -> None: - """Same as `self.shutdown()`, but restart instead of shutting down. - - Args: - restart_args (List[str], optional): Any arguments to run the new instance with. - Defaults to []. - - handle_flags (bool, optional): Run flag specific actions just before restarting. - Defaults to False. - """ - self.do_restart = True - self.restart_args = restart_args - self.handle_flags = handle_flags - self.shutdown() - return - - def handle_restart(self, flag: Union[str, None]) -> NoReturn: - """Restart the interpreter. - - Args: - flag (Union[str, None]): Supplied flag, for flag handling. - - Returns: - NoReturn: No return because it replaces the interpreter. - """ - if self.handle_flags: - handle_flags_pre_restart(flag) - - LOGGER.info('Restarting MIND') - from MIND import __file__ as mind_file - execv(folder_path(mind_file), [argv[0], *self.restart_args]) - - def __revert_db(self) -> None: - """Revert database import and restart. - """ - LOGGER.warning(f'Timer for database import expired; reverting back to original file') - self.restart(handle_flags=True) - return - - def __revert_hosting(self) -> None: - """Revert the hosting changes. - """ - LOGGER.warning(f'Timer for hosting changes expired; reverting back to original settings') - self.restart(handle_flags=True) - return - - -SERVER = Server() - - -def handle_flags(flag: Union[None, str]) -> None: - """Run flag specific actions on startup. - - Args: - flag (Union[None, str]): The flag or `None` if there is no flag set. - """ - if flag == RestartVars.DB_IMPORT: - LOGGER.info('Starting timer for database import') - SERVER.revert_db_timer.start() - - elif flag == RestartVars.HOST_CHANGE: - LOGGER.info('Starting timer for hosting changes') - SERVER.revert_hosting_timer.start() - - return - - -def handle_flags_pre_restart(flag: Union[None, str]) -> None: - """Run flag specific actions just before restarting. - - Args: - flag (Union[None, str]): The flag or `None` if there is no flag set. - """ - if flag == RestartVars.DB_IMPORT: - revert_db_import(swap=True) - - elif flag == RestartVars.HOST_CHANGE: - with SERVER.app.app_context(): - restore_hosting_settings() - close_db() - - return diff --git a/backend/settings.py b/backend/settings.py deleted file mode 100644 index 85ac198..0000000 --- a/backend/settings.py +++ /dev/null @@ -1,245 +0,0 @@ -#-*- coding: utf-8 -*- - -""" -Getting and setting settings -""" - -import logging -from json import dump, load -from typing import Any - -from backend.custom_exceptions import InvalidKeyValue, KeyNotFound -from backend.db import __DATABASE_VERSION__, get_db -from backend.helpers import folder_path -from backend.logging import set_log_level - -default_settings = { - 'allow_new_accounts': True, - 'login_time': 3600, - 'login_time_reset': True, - - 'database_version': __DATABASE_VERSION__, - - 'host': '0.0.0.0', - 'port': 8080, - 'url_prefix': '', - - 'log_level': logging.INFO -} - -def _format_setting(key: str, value): - """Turn python value in to database value. - - Args: - key (str): The key of the value. - value (Any): The value itself. - - Raises: - InvalidKeyValue: The value is not valid. - - Returns: - Any: The converted value. - """ - if key == 'database_version': - try: - value = int(value) - except ValueError: - raise InvalidKeyValue(key, value) - - elif key in ('allow_new_accounts', 'login_time_reset'): - if not isinstance(value, bool): - raise InvalidKeyValue(key, value) - value = int(value) - - elif key == 'login_time': - if not isinstance(value, int) or not 60 <= value <= 2592000: - raise InvalidKeyValue(key, value) - - elif key == 'host': - if not isinstance(value, str): - raise InvalidKeyValue(key, value) - - elif key == 'port': - if not isinstance(value, int) or not 1 <= value <= 65535: - raise InvalidKeyValue(key, value) - - elif key == 'url_prefix': - if not isinstance(value, str): - raise InvalidKeyValue(key, value) - - if value == '/': - value = '' - - elif value: - value = '/' + value.strip('/') - - elif key == 'log_level' and not value in (logging.INFO, logging.DEBUG): - raise InvalidKeyValue(key, value) - - return value - -def _reverse_format_setting(key: str, value: Any) -> Any: - """Turn database value in to python value. - - Args: - key (str): The key of the value. - value (Any): The value itself. - - Returns: - Any: The converted value. - """ - if key in ('allow_new_accounts', 'login_time_reset'): - value = value == 1 - - elif key in ('log_level', 'database_version', 'login_time'): - value = int(value) - - return value - -def get_setting(key: str) -> Any: - """Get a value from the config. - - Args: - key (str): The key of which to get the value. - - Raises: - KeyNotFound: Key is not in config. - - Returns: - Any: The value of the key. - """ - result = get_db().execute( - "SELECT value FROM config WHERE key = ? LIMIT 1;", - (key,) - ).fetchone() - if result is None: - raise KeyNotFound(key) - - result = _reverse_format_setting(key, result[0]) - - return result - -def get_admin_settings() -> dict: - """Get all admin settings - - Returns: - dict: The admin settings - """ - return dict(( - (key, _reverse_format_setting(key, value)) - for key, value in get_db().execute(""" - SELECT key, value - FROM config - WHERE - key = 'allow_new_accounts' - OR key = 'login_time' - OR key = 'login_time_reset' - OR key = 'host' - OR key = 'port' - OR key = 'url_prefix' - OR key = 'log_level'; - """ - ) - )) - -def set_setting(key: str, value: Any) -> None: - """Set a value in the config - - Args: - key (str): The key for which to set the value - value (Any): The value to give to the key - - Raises: - KeyNotFound: The key is not in the config - InvalidKeyValue: The value is not allowed for the key - """ - if not key in (*default_settings, 'database_version'): - raise KeyNotFound(key) - - value = _format_setting(key, value) - - get_db().execute( - "UPDATE config SET value = ? WHERE key = ?;", - (value, key) - ) - - if key == 'url_prefix': - update_manifest(value) - - elif key == 'log_level': - set_log_level(value) - - return - -def update_manifest(url_base: str) -> None: - """Update the url's in the manifest file. - Needs to happen when url base changes. - - Args: - url_base (str): The url base to use in the file. - """ - filename = folder_path('frontend', 'static', 'json', 'pwa_manifest.json') - - with open(filename, 'r') as f: - manifest = load(f) - manifest['start_url'] = url_base + '/' - manifest['icons'][0]['src'] = f'{url_base}/static/img/favicon.svg' - - with open(filename, 'w') as f: - dump(manifest, f, indent=4) - - return - -def backup_hosting_settings() -> None: - """Copy current hosting settings to backup values. - """ - cursor = get_db() - hosting_settings = dict(cursor.execute(""" - SELECT key, value - FROM config - WHERE key = 'host' - OR key = 'port' - OR key = 'url_prefix' - LIMIT 3; - """ - )) - hosting_settings = {f'{k}_backup': v for k, v in hosting_settings.items()} - - cursor.executemany(""" - INSERT INTO config(key, value) - VALUES (?, ?) - ON CONFLICT(key) DO - UPDATE - SET value = ?; - """, - ((k, v, v) for k, v in hosting_settings.items()) - ) - - return - -def restore_hosting_settings() -> None: - """Copy the hosting settings from the backup over to the main keys. - """ - cursor = get_db() - hosting_settings = dict(cursor.execute(""" - SELECT key, value - FROM config - WHERE key = 'host_backup' - OR key = 'port_backup' - OR key = 'url_prefix_backup' - LIMIT 3; - """ - )) - if len(hosting_settings) < 3: - return - - hosting_settings = {k.split('_backup')[0]: v for k, v in hosting_settings.items()} - - cursor.executemany( - "UPDATE config SET value = ? WHERE key = ?", - ((v, k) for k, v in hosting_settings.items()) - ) - - update_manifest(hosting_settings['url_prefix']) - - return diff --git a/backend/static_reminders.py b/backend/static_reminders.py deleted file mode 100644 index 4574130..0000000 --- a/backend/static_reminders.py +++ /dev/null @@ -1,356 +0,0 @@ -#-*- coding: utf-8 -*- - -from sqlite3 import IntegrityError -from typing import List, Optional, Union - -from apprise import Apprise - -from backend.custom_exceptions import (NotificationServiceNotFound, - ReminderNotFound) -from backend.db import get_db -from backend.helpers import TimelessSortingMethod, search_filter -from backend.logging import LOGGER - - -class StaticReminder: - """Represents a static reminder - """ - def __init__(self, user_id: int, reminder_id: int) -> None: - """Create an instance. - - Args: - user_id (int): The ID of the user. - reminder_id (int): The ID of the reminder. - - Raises: - ReminderNotFound: Reminder with given ID does not exist or is not - owned by user. - """ - self.id = reminder_id - - # Check if reminder exists - if not get_db().execute( - "SELECT 1 FROM static_reminders WHERE id = ? AND user_id = ? LIMIT 1;", - (self.id, user_id) - ).fetchone(): - raise ReminderNotFound - - return - - def _get_notification_services(self) -> List[int]: - """Get ID's of notification services linked to the static reminder. - - Returns: - List[int]: The list with ID's. - """ - result = [ - r[0] - for r in get_db().execute(""" - SELECT notification_service_id - FROM reminder_services - WHERE static_reminder_id = ?; - """, - (self.id,) - ) - ] - return result - - def get(self) -> dict: - """Get info about the static reminder - - Returns: - dict: The info about the static reminder - """ - reminder = get_db(dict).execute(""" - SELECT - id, - title, text, - color - FROM static_reminders - WHERE id = ? - LIMIT 1; - """, - (self.id,) - ).fetchone() - reminder = dict(reminder) - - reminder['notification_services'] = self._get_notification_services() - - return reminder - - def update( - self, - title: Union[str, None] = None, - notification_services: Union[List[int], None] = None, - text: Union[str, None] = None, - color: Union[str, None] = None - ) -> dict: - """Edit the static reminder. - - Args: - title (Union[str, None], optional): The new title of the entry. - Defaults to None. - - notification_services (Union[List[int], None], optional): - The new id's of the notification services to use to send the reminder. - Defaults to None. - - text (Union[str, None], optional): The new body of the reminder. - Defaults to None. - - color (Union[str, None], optional): The new hex code of the color - of the reminder, which is shown in the web-ui. - Defaults to None. - - Raises: - NotificationServiceNotFound: One of the notification services was not found - - Returns: - dict: The new static reminder info - """ - LOGGER.info( - f'Updating static reminder {self.id}: ' - + f'{title=}, {notification_services=}, {text=}, {color=}' - ) - - # Get current data and update it with new values - data = self.get() - new_values = { - 'title': title, - 'text': text, - 'color': color - } - for k, v in new_values.items(): - if k == 'color' or v is not None: - data[k] = v - - # Update database - cursor = get_db() - cursor.execute(""" - UPDATE static_reminders - SET - title = ?, text = ?, - color = ? - WHERE id = ?; - """, - (data['title'], data['text'], - data['color'], - self.id) - ) - - if notification_services: - cursor.connection.isolation_level = None - cursor.execute("BEGIN TRANSACTION;") - cursor.execute( - "DELETE FROM reminder_services WHERE static_reminder_id = ?", - (self.id,) - ) - try: - cursor.executemany(""" - INSERT INTO reminder_services( - static_reminder_id, - notification_service_id - ) - VALUES (?,?); - """, - ((self.id, s) for s in notification_services) - ) - cursor.execute("COMMIT;") - - except IntegrityError: - raise NotificationServiceNotFound - - finally: - cursor.connection.isolation_level = "" - - return self.get() - - def delete(self) -> None: - """Delete the static reminder - """ - LOGGER.info(f'Deleting static reminder {self.id}') - get_db().execute("DELETE FROM static_reminders WHERE id = ?", (self.id,)) - return - -class StaticReminders: - """Represents the static reminder library of the user account - """ - - def __init__(self, user_id: int) -> None: - """Create an instance. - - Args: - user_id (int): The ID of the user. - """ - self.user_id = user_id - return - - def fetchall( - self, - sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE - ) -> List[dict]: - """Get all static reminders - - Args: - sort_by (TimelessSortingMethod, optional): How to sort the result. - Defaults to TimelessSortingMethod.TITLE. - - Returns: - List[dict]: The id, title, text and color of each static reminder. - """ - reminders = [ - dict(r) - for r in get_db(dict).execute(""" - SELECT - id, - title, text, - color - FROM static_reminders - WHERE user_id = ? - ORDER BY title, id; - """, - (self.user_id,) - ) - ] - - # Sort result - reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1]) - - return reminders - - def search( - self, - query: str, - sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE - ) -> List[dict]: - """Search for static reminders - - Args: - query (str): The term to search for. - - sort_by (TimelessSortingMethod, optional): The sorting method of - the resulting list. - Defaults to TimelessSortingMethod.TITLE. - - Returns: - List[dict]: All static reminders that match. - Similar output to `self.fetchall` - """ - static_reminders = [ - r for r in self.fetchall(sort_by) - if search_filter(query, r) - ] - return static_reminders - - def fetchone(self, id: int) -> StaticReminder: - """Get one static reminder - - Args: - id (int): The id of the static reminder to fetch - - Returns: - StaticReminder: A StaticReminder instance - """ - return StaticReminder(self.user_id, id) - - def add( - self, - title: str, - notification_services: List[int], - text: str = '', - color: Optional[str] = None - ) -> StaticReminder: - """Add a static reminder - - Args: - title (str): The title of the entry. - - notification_services (List[int]): The id's of the - notification services to use to send the reminder. - - text (str, optional): The body of the reminder. - Defaults to ''. - - color (Optional[str], optional): The hex code of the color of the template, - which is shown in the web-ui. - Defaults to None. - - Raises: - NotificationServiceNotFound: One of the notification services was not found - - Returns: - StaticReminder: The info about the static reminder - """ - LOGGER.info( - f'Adding static reminder with {title=}, {notification_services=}, {text=}, {color=}' - ) - - cursor = get_db() - cursor.connection.isolation_level = None - cursor.execute("BEGIN TRANSACTION;") - - id = cursor.execute(""" - INSERT INTO static_reminders(user_id, title, text, color) - VALUES (?,?,?,?); - """, - (self.user_id, title, text, color) - ).lastrowid - - try: - cursor.executemany(""" - INSERT INTO reminder_services( - static_reminder_id, - notification_service_id - ) - VALUES (?, ?); - """, - ((id, service) for service in notification_services) - ) - cursor.execute("COMMIT;") - - except IntegrityError: - raise NotificationServiceNotFound - finally: - cursor.connection.isolation_level = "" - - return self.fetchone(id) - - def trigger_reminder(self, id: int) -> None: - """Trigger a static reminder to send it's reminder - - Args: - id (int): The id of the static reminder to trigger - - Raises: - ReminderNotFound: The static reminder with the given id was not found - """ - LOGGER.info(f'Triggering static reminder {id}') - cursor = get_db(dict) - reminder = cursor.execute(""" - SELECT title, text - FROM static_reminders - WHERE - id = ? - AND user_id = ? - LIMIT 1; - """, - (id, self.user_id) - ).fetchone() - if not reminder: - raise ReminderNotFound - reminder = dict(reminder) - - a = Apprise() - cursor.execute(""" - SELECT url - FROM reminder_services rs - INNER JOIN notification_services ns - ON rs.notification_service_id = ns.id - WHERE rs.static_reminder_id = ?; - """, - (id,) - ) - for url in cursor: - a.add(url['url']) - a.notify(title=reminder['title'], body=reminder['text'] or '\u200B') - return diff --git a/backend/templates.py b/backend/templates.py deleted file mode 100644 index e569f63..0000000 --- a/backend/templates.py +++ /dev/null @@ -1,311 +0,0 @@ -#-*- coding: utf-8 -*- - -from sqlite3 import IntegrityError -from typing import List, Optional, Union - -from backend.custom_exceptions import (NotificationServiceNotFound, - TemplateNotFound) -from backend.db import get_db -from backend.helpers import TimelessSortingMethod, search_filter -from backend.logging import LOGGER - - -class Template: - """Represents a template - """ - def __init__(self, user_id: int, template_id: int) -> None: - """Create instance of class. - - Args: - user_id (int): The ID of the user. - template_id (int): The ID of the template. - - Raises: - TemplateNotFound: Template with given ID does not exist or is not - owned by user. - """ - self.id = template_id - - exists = get_db().execute( - "SELECT 1 FROM templates WHERE id = ? AND user_id = ? LIMIT 1;", - (self.id, user_id) - ).fetchone() - if not exists: - raise TemplateNotFound - return - - def _get_notification_services(self) -> List[int]: - """Get ID's of notification services linked to the template. - - Returns: - List[int]: The list with ID's. - """ - result = [ - r[0] - for r in get_db().execute(""" - SELECT notification_service_id - FROM reminder_services - WHERE template_id = ?; - """, - (self.id,) - ) - ] - return result - - def get(self) -> dict: - """Get info about the template - - Returns: - dict: The info about the template - """ - template = get_db(dict).execute(""" - SELECT - id, - title, text, - color - FROM templates - WHERE id = ? - LIMIT 1; - """, - (self.id,) - ).fetchone() - template = dict(template) - - template['notification_services'] = self._get_notification_services() - - return template - - def update(self, - title: Union[str, None] = None, - notification_services: Union[List[int], None] = None, - text: Union[str, None] = None, - color: Union[str, None] = None - ) -> dict: - """Edit the template - - Args: - title (Union[str, None]): The new title of the entry. - Defaults to None. - - notification_services (Union[List[int], None]): The new id's of the - notification services to use to send the reminder. - Defaults to None. - - text (Union[str, None], optional): The new body of the template. - Defaults to None. - - color (Union[str, None], optional): The new hex code of the color of the template, - which is shown in the web-ui. - Defaults to None. - - Raises: - NotificationServiceNotFound: One of the notification services was not found - - Returns: - dict: The new template info - """ - LOGGER.info( - f'Updating template {self.id}: ' - + f'{title=}, {notification_services=}, {text=}, {color=}' - ) - - cursor = get_db() - - data = self.get() - new_values = { - 'title': title, - 'text': text, - 'color': color - } - for k, v in new_values.items(): - if k in ('color',) or v is not None: - data[k] = v - - cursor.execute(""" - UPDATE templates - SET title=?, text=?, color=? - WHERE id = ?; - """, ( - data['title'], - data['text'], - data['color'], - self.id - )) - - if notification_services: - cursor.connection.isolation_level = None - cursor.execute("BEGIN TRANSACTION;") - cursor.execute( - "DELETE FROM reminder_services WHERE template_id = ?", - (self.id,) - ) - try: - cursor.executemany(""" - INSERT INTO reminder_services( - template_id, - notification_service_id - ) - VALUES (?,?); - """, - ((self.id, s) for s in notification_services) - ) - cursor.execute("COMMIT;") - - except IntegrityError: - raise NotificationServiceNotFound - - finally: - cursor.connection.isolation_level = "" - - return self.get() - - def delete(self) -> None: - """Delete the template - """ - LOGGER.info(f'Deleting template {self.id}') - get_db().execute("DELETE FROM templates WHERE id = ?;", (self.id,)) - return - -class Templates: - """Represents the template library of the user account - """ - - def __init__(self, user_id: int) -> None: - """Create an instance. - - Args: - user_id (int): The ID of the user. - """ - self.user_id = user_id - return - - def fetchall( - self, - sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE - ) -> List[dict]: - """Get all templates of the user. - - Args: - sort_by (TimelessSortingMethod, optional): The sorting method of - the resulting list. - Defaults to TimelessSortingMethod.TITLE. - - Returns: - List[dict]: The id, title, text and color of each template. - """ - templates = [ - dict(r) - for r in get_db(dict).execute(""" - SELECT - id, - title, text, - color - FROM templates - WHERE user_id = ? - ORDER BY title, id; - """, - (self.user_id,) - ) - ] - - # Sort result - templates.sort(key=sort_by.value[0], reverse=sort_by.value[1]) - - return templates - - def search( - self, - query: str, - sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE - ) -> List[dict]: - """Search for templates - - Args: - query (str): The term to search for. - - sort_by (TimelessSortingMethod, optional): The sorting method of - the resulting list. - Defaults to TimelessSortingMethod.TITLE. - - Returns: - List[dict]: All templates that match. Similar output to `self.fetchall` - """ - templates = [ - r for r in self.fetchall(sort_by) - if search_filter(query, r) - ] - return templates - - def fetchone(self, id: int) -> Template: - """Get one template - - Args: - id (int): The id of the template to fetch - - Returns: - Template: A Template instance - """ - return Template(self.user_id, id) - - def add( - self, - title: str, - notification_services: List[int], - text: str = '', - color: Optional[str] = None - ) -> Template: - """Add a template - - Args: - title (str): The title of the entry. - - notification_services (List[int]): The id's of the - notification services to use to send the reminder. - - text (str, optional): The body of the reminder. - Defaults to ''. - - color (Optional[str], optional): The hex code of the color of the template, - which is shown in the web-ui. - Defaults to None. - - Raises: - NotificationServiceNotFound: One of the notification services was not found - - Returns: - Template: The info about the template - """ - LOGGER.info( - f'Adding template with {title=}, {notification_services=}, {text=}, {color=}' - ) - - cursor = get_db() - cursor.connection.isolation_level = None - cursor.execute("BEGIN TRANSACTION;") - - id = cursor.execute(""" - INSERT INTO templates(user_id, title, text, color) - VALUES (?,?,?,?); - """, - (self.user_id, title, text, color) - ).lastrowid - - try: - cursor.executemany(""" - INSERT INTO reminder_services( - template_id, - notification_service_id - ) - VALUES (?, ?); - """, - ((id, service) for service in notification_services) - ) - cursor.execute("COMMIT;") - - except IntegrityError: - raise NotificationServiceNotFound - - finally: - cursor.connection.isolation_level = "" - - return self.fetchone(id) diff --git a/backend/users.py b/backend/users.py deleted file mode 100644 index 259a398..0000000 --- a/backend/users.py +++ /dev/null @@ -1,235 +0,0 @@ -#-*- coding: utf-8 -*- - -from typing import List - -from backend.custom_exceptions import (AccessUnauthorized, - NewAccountsNotAllowed, UsernameInvalid, - UsernameTaken, UserNotFound) -from backend.db import get_db -from backend.logging import LOGGER -from backend.notification_service import NotificationServices -from backend.reminders import Reminders -from backend.security import generate_salt_hash, get_hash -from backend.settings import get_setting -from backend.static_reminders import StaticReminders -from backend.templates import Templates - -ONEPASS_USERNAME_CHARACTERS = 'abcedfghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!@$' -ONEPASS_INVALID_USERNAMES = ['reminders', 'api'] - -class User: - """Represents an user account - """ - - def __init__(self, id: int) -> None: - result = get_db(dict).execute( - "SELECT username, admin, salt FROM users WHERE id = ? LIMIT 1;", - (id,) - ).fetchone() - if not result: - raise UserNotFound - - self.username: str = result['username'] - self.user_id = id - self.admin: bool = result['admin'] == 1 - self.salt: bytes = result['salt'] - return - - @property - def reminders(self) -> Reminders: - """Get access to the reminders of the user account - - Returns: - Reminders: Reminders instance that can be used to access the - reminders of the user account - """ - if not hasattr(self, 'reminders_instance'): - self.reminders_instance = Reminders(self.user_id) - return self.reminders_instance - - @property - def notification_services(self) -> NotificationServices: - """Get access to the notification services of the user account - - Returns: - NotificationServices: NotificationServices instance that can be used - to access the notification services of the user account - """ - if not hasattr(self, 'notification_services_instance'): - self.notification_services_instance = NotificationServices(self.user_id) - return self.notification_services_instance - - @property - def templates(self) -> Templates: - """Get access to the templates of the user account - - Returns: - Templates: Templates instance that can be used to access the - templates of the user account - """ - if not hasattr(self, 'templates_instance'): - self.templates_instance = Templates(self.user_id) - return self.templates_instance - - @property - def static_reminders(self) -> StaticReminders: - """Get access to the static reminders of the user account - - Returns: - StaticReminders: StaticReminders instance that can be used to - access the static reminders of the user account - """ - if not hasattr(self, 'static_reminders_instance'): - self.static_reminders_instance = StaticReminders(self.user_id) - return self.static_reminders_instance - - def edit_password(self, new_password: str) -> None: - """Change the password of the account - - Args: - new_password (str): The new password - """ - # Encrypt raw key with new password - hash_password = get_hash(self.salt, new_password) - - # Update database - get_db().execute( - "UPDATE users SET hash = ? WHERE id = ?", - (hash_password, self.user_id) - ) - LOGGER.info(f'The user {self.username} ({self.user_id}) changed their password') - return - - def delete(self) -> None: - """Delete the user account - """ - if self.username == 'admin': - raise UserNotFound - - LOGGER.info(f'Deleting the user {self.username} ({self.user_id})') - - cursor = get_db() - cursor.execute( - "DELETE FROM reminders WHERE user_id = ?", - (self.user_id,) - ) - cursor.execute( - "DELETE FROM templates WHERE user_id = ?", - (self.user_id,) - ) - cursor.execute( - "DELETE FROM static_reminders WHERE user_id = ?", - (self.user_id,) - ) - cursor.execute( - "DELETE FROM notification_services WHERE user_id = ?", - (self.user_id,) - ) - cursor.execute( - "DELETE FROM users WHERE id = ?", - (self.user_id,) - ) - return - -class Users: - def _check_username(self, username: str) -> None: - """Check if username is valid - - Args: - username (str): The username to check - - Raises: - UsernameInvalid: The username is not valid - """ - LOGGER.debug(f'Checking the username {username}') - if username in ONEPASS_INVALID_USERNAMES or username.isdigit(): - raise UsernameInvalid(username) - if list(filter(lambda c: not c in ONEPASS_USERNAME_CHARACTERS, username)): - raise UsernameInvalid(username) - return - - def __contains__(self, username: str) -> bool: - result = get_db().execute( - "SELECT 1 FROM users WHERE username = ? LIMIT 1;", - (username,) - ).fetchone() - return result is not None - - def add(self, username: str, password: str, from_admin: bool=False) -> int: - """Add a user - - Args: - username (str): The username of the new user - password (str): The password of the new user - from_admin (bool, optional): Skip check if new accounts are allowed. - Defaults to False. - - Raises: - UsernameInvalid: Username not allowed or contains invalid characters - UsernameTaken: Username is already taken; usernames must be unique - NewAccountsNotAllowed: In the admin panel, new accounts are set to be - not allowed. - - Returns: - int: The id of the new user. User registered successful - """ - LOGGER.info(f'Registering user with username {username}') - - if not from_admin and not get_setting('allow_new_accounts'): - raise NewAccountsNotAllowed - - # Check if username is valid - self._check_username(username) - - cursor = get_db() - - # Check if username isn't already taken - if username in self: - raise UsernameTaken - - # Generate salt and key exclusive for user - salt, hashed_password = generate_salt_hash(password) - del password - - # Add user to userlist - user_id = cursor.execute( - """ - INSERT INTO users(username, salt, hash) - VALUES (?,?,?); - """, - (username, salt, hashed_password) - ).lastrowid - - LOGGER.debug(f'Newly registered user has id {user_id}') - return user_id - - def get_all(self) -> List[dict]: - """Get all user info for the admin - - Returns: - List[dict]: The info about all users - """ - result = [ - dict(u) - for u in get_db(dict).execute( - "SELECT id, username, admin FROM users ORDER BY username;" - ) - ] - return result - - def login(self, username: str, password: str) -> User: - result = get_db(dict).execute( - "SELECT id, salt, hash FROM users WHERE username = ? LIMIT 1;", - (username,) - ).fetchone() - if not result: - raise UserNotFound - - hash_password = get_hash(result['salt'], password) - if not hash_password == result['hash']: - raise AccessUnauthorized - - return User(result['id']) - - def get_one(self, id: int) -> User: - return User(id) diff --git a/docs/index.md b/docs/index.md index 6aeee28..d33a9f6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -7,7 +7,7 @@ hide: __A simple self hosted reminder application that can send push notifications to your device. Set the reminder and forget about it!__ -MIND can be used for sending notifications at the desired time. This can be a set time, like a yearly reminder for a birthday, or at a button click, to easily send a predefined notification when you want to. The notification can be sent to 80+ platforms with the integration of [Apprise](https://github.com/caronc/apprise). +MIND can be used for sending notifications at the desired time. This can be a set time, like a yearly reminder for a birthday, or at a button click, to easily send a predefined notification when you want to. The notification can be sent to 100+ platforms with the integration of [Apprise](https://github.com/caronc/apprise). ## Quick Links diff --git a/frontend/api.py b/frontend/api.py index 3bf5d07..5bc1766 100644 --- a/frontend/api.py +++ b/frontend/api.py @@ -1,35 +1,30 @@ -#-*- coding: utf-8 -*- +# -*- coding: utf-8 -*- -from __future__ import annotations - -from dataclasses import dataclass from datetime import datetime -from io import BytesIO +from io import BytesIO, StringIO from os import remove, urandom -from os.path import basename, exists +from os.path import exists from time import time as epoch_time -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Tuple, Union -from flask import g, request, send_file +from flask import Response, g, request, send_file -from backend.custom_exceptions import (AccessUnauthorized, APIKeyExpired, - APIKeyInvalid, InvalidDatabaseFile, - InvalidKeyValue, InvalidTime, - KeyNotFound, LogFileNotFound, - NewAccountsNotAllowed, - NotificationServiceInUse, - NotificationServiceNotFound, - ReminderNotFound, TemplateNotFound, - UsernameInvalid, UsernameTaken, - UserNotFound) -from backend.db import get_db, import_db, revert_db_import -from backend.helpers import RestartVars, folder_path -from backend.logging import LOGGER, get_debug_log_filepath -from backend.notification_service import get_apprise_services -from backend.server import SERVER -from backend.settings import (backup_hosting_settings, get_admin_settings, - get_setting, set_setting) -from backend.users import Users +from backend.base.custom_exceptions import (APIKeyExpired, APIKeyInvalid, + LogFileNotFound) +from backend.base.definitions import (ApiKeyEntry, Method, Methods, + MindException, SendResult, + Serialisable, StartType) +from backend.base.helpers import folder_path +from backend.base.logging import LOGGER, get_log_filepath +from backend.features.reminders import Reminders +from backend.features.static_reminders import StaticReminders +from backend.features.templates import Templates +from backend.implementations.apprise_parser import get_apprise_services +from backend.implementations.notification_services import NotificationServices +from backend.implementations.users import Users +from backend.internals.db import get_db, import_db +from backend.internals.server import Server, diffuse_timers +from backend.internals.settings import Settings, get_about_data from frontend.input_validation import (AllowNewAccountsVariable, ColorVariable, CopyHostingSettingsVariable, DatabaseFileVariable, @@ -39,7 +34,7 @@ from frontend.input_validation import (AllowNewAccountsVariable, ColorVariable, EditURLVariable, HostVariable, LoginTimeResetVariable, LoginTimeVariable, LogLevelVariable, - Method, Methods, NewPasswordVariable, + NewPasswordVariable, NotificationServicesVariable, PasswordCreateVariable, PasswordVariable, PortVariable, @@ -53,798 +48,867 @@ from frontend.input_validation import (AllowNewAccountsVariable, ColorVariable, admin_api, api, get_api_docs, input_validation) -if TYPE_CHECKING: - from backend.users import User - - -#=================== -# General variables and functions -#=================== - -@dataclass -class ApiKeyEntry: - exp: int - user_data: User - +# =================== +# region General variables and functions +# =================== users = Users() api_key_map: Dict[int, ApiKeyEntry] = {} + def return_api( - result: Any, - error: Optional[str] = None, - code: int = 200 -) -> Tuple[dict, int]: - return {'error': error, 'result': result}, code + result: Serialisable, + error: Union[str, None] = None, + code: int = 200 +) -> Tuple[Dict[str, Any], int]: + return {'error': error, 'result': result}, code + def auth() -> None: - """Checks if the client is logged in + """Checks if the client is logged in. - Raises: - APIKeyInvalid: The api key supplied is invalid - APIKeyExpired: The api key supplied has expired - """ - hashed_api_key = hash(request.values.get('api_key','')) - if not hashed_api_key in api_key_map: - raise APIKeyInvalid + Raises: + APIKeyInvalid: The api key supplied is invalid. + APIKeyExpired: The api key supplied has expired. + """ + api_key = request.values.get('api_key', '') + hashed_api_key = hash(api_key) - map_entry = api_key_map[hashed_api_key] + if hashed_api_key not in api_key_map: + raise APIKeyInvalid(api_key) - if ( - map_entry.user_data.admin - and - not request.path.startswith((SERVER.admin_prefix, SERVER.api_prefix + '/auth')) - ): - raise APIKeyInvalid - - if ( - not map_entry.user_data.admin - and - request.path.startswith(SERVER.admin_prefix) - ): - raise APIKeyInvalid + map_entry = api_key_map[hashed_api_key] + user_data = map_entry.user_data.get() - if map_entry.exp <= epoch_time(): - raise APIKeyExpired + if ( + user_data.admin + and not request.path.startswith( + (Server.admin_prefix, Server.api_prefix + '/auth') + ) + ): + raise APIKeyInvalid(api_key) - # Api key valid - - if get_setting('login_time_reset'): - g.exp = map_entry.exp = ( - epoch_time() + get_setting('login_time') - ) - else: - g.exp = map_entry.exp + if ( + not user_data.admin + and + request.path.startswith(Server.admin_prefix) + ): + raise APIKeyInvalid(api_key) - g.hashed_api_key = hashed_api_key - g.user_data = map_entry.user_data + if map_entry.exp <= epoch_time(): + raise APIKeyExpired(api_key) - return + # Api key valid + sv = Settings().get_settings() + if sv.login_time_reset: + g.exp = map_entry.exp = ( + int(epoch_time()) + sv.login_time + ) + else: + g.exp = map_entry.exp -def endpoint_wrapper(method: Callable) -> Callable: - def wrapper(*args, **kwargs): - requires_auth = get_api_docs(request).requires_auth + g.hashed_api_key = hashed_api_key + g.user_data = map_entry.user_data - try: - if requires_auth: - auth() + return - inputs = input_validation() - if inputs is None: - return method(*args, **kwargs) - return method(inputs, *args, **kwargs) +def endpoint_wrapper( + method: Union[ + Callable[[Dict[str, Any]], Union[Tuple[Union[Dict[str, Any], Response], int], None]], + Callable[[Dict[str, Any], int], Union[Tuple[Union[Dict[str, Any], Response], int], None]] + ] +) -> Callable: + def wrapper(*args, **kwargs): + requires_auth = get_api_docs(request).requires_auth - except ( - AccessUnauthorized, APIKeyExpired, - APIKeyInvalid, InvalidDatabaseFile, - InvalidKeyValue, InvalidTime, - KeyNotFound, LogFileNotFound, - NewAccountsNotAllowed, - NotificationServiceInUse, - NotificationServiceNotFound, - ReminderNotFound, TemplateNotFound, - UsernameInvalid, UsernameTaken, - UserNotFound - ) as e: - return return_api(**e.api_response) - - wrapper.__name__ = method.__name__ - return wrapper + try: + if requires_auth: + auth() -#=================== -# Authentication endpoints -#=================== + inputs = input_validation() + result = method(inputs, *args, **kwargs) + except MindException as e: + result = return_api(**e.api_response) + + return result + + wrapper.__name__ = method.__name__ + return wrapper + + +# =================== +# region General Handling +# =================== +@api.errorhandler(404) +def api_not_found(e): + return {'error': "NotFound", "result": {}}, 404 + + +# =================== +# region Auth +# =================== @api.route( - '/auth/login', - 'Login to a user account', - Methods( - post=Method( - vars=[UsernameVariable, PasswordVariable] - ) - ), - requires_auth=False, - methods=['POST'] + '/auth/login', + 'Login to a user account', + Methods( + post=Method( + vars=[UsernameVariable, PasswordVariable] + ) + ), + requires_auth=False ) @endpoint_wrapper -def api_login(inputs: Dict[str, str]): - user = users.login(inputs['username'], inputs['password']) +def api_login(inputs: Dict[str, Any]): + user = users.login(inputs['username'], inputs['password']) - # Login successful + # Login successful - if user.admin and SERVER.revert_db_timer.is_alive(): - LOGGER.info('Timer for database import diffused') - SERVER.revert_db_timer.cancel() - revert_db_import(swap=False) + diffuse_timers() - elif user.admin and SERVER.revert_hosting_timer.is_alive(): - LOGGER.info('Timer for hosting changes diffused') - SERVER.revert_hosting_timer.cancel() + # Generate an API key until one + # is generated that isn't used already + while True: + api_key = urandom(16).hex() # <- length api key / 2 + hashed_api_key = hash(api_key) + if hashed_api_key not in api_key_map: + break - # Generate an API key until one - # is generated that isn't used already - while True: - api_key = urandom(16).hex() # <- length api key / 2 - hashed_api_key = hash(api_key) - if not hashed_api_key in api_key_map: - break + login_time = Settings().sv.login_time + exp = int(epoch_time()) + login_time + api_key_map[hashed_api_key] = ApiKeyEntry(exp, user) - login_time = get_setting('login_time') - exp = epoch_time() + login_time - api_key_map[hashed_api_key] = ApiKeyEntry(exp, user) + result = { + 'api_key': api_key, + 'expires': exp, + 'admin': user.get().admin + } + return return_api(result, code=201) - result = {'api_key': api_key, 'expires': exp, 'admin': user.admin} - return return_api(result, code=201) @api.route( - '/auth/logout', - 'Logout of a user account', - methods=['POST'] + '/auth/logout', + 'Logout of a user account', + methods=['POST'] ) @endpoint_wrapper -def api_logout(): - api_key_map.pop(g.hashed_api_key) - return return_api({}, code=201) +def api_logout(inputs: Dict[str, Any]): + api_key_map.pop(g.hashed_api_key) + return return_api({}, code=201) + @api.route( - '/auth/status', - 'Get current status of login', - methods=['GET'] + '/auth/status', + 'Get current status of login', + methods=['GET'] ) @endpoint_wrapper -def api_status(): - map_entry = api_key_map[g.hashed_api_key] - result = { - 'expires': map_entry.exp, - 'username': map_entry.user_data.username, - 'admin': map_entry.user_data.admin - } - return return_api(result) +def api_status(inputs: Dict[str, Any]): + map_entry = api_key_map[g.hashed_api_key] + user_data = map_entry.user_data.get() + result = { + 'expires': map_entry.exp, + 'username': user_data.username, + 'admin': user_data.admin + } + return return_api(result) -#=================== -# User endpoints -#=================== + +# =================== +# region User +# =================== @api.route( - '/user/add', - 'Create a new user account', - Methods( - post=Method( - vars=[UsernameCreateVariable, PasswordCreateVariable] - ) - ), - requires_auth=False, - methods=['POST'] + '/user/add', + 'Create a new user account', + Methods( + post=Method( + vars=[UsernameCreateVariable, PasswordCreateVariable] + ) + ), + requires_auth=False, + methods=['POST'] ) @endpoint_wrapper def api_add_user(inputs: Dict[str, str]): - users.add(inputs['username'], inputs['password']) - return return_api({}, code=201) + users.add(inputs['username'], inputs['password']) + return return_api({}, code=201) + @api.route( - '/user', - 'Manage a user account', - Methods( - put=Method( - vars=[NewPasswordVariable], - description="Change the password of the user account" - ), - delete=Method( - description='Delete the user account' - ) - ), - methods=['PUT', 'DELETE'] + '/user', + 'Manage a user account', + Methods( + put=Method( + vars=[NewPasswordVariable], + description="Change the password of the user account" + ), + delete=Method( + description='Delete the user account' + ) + ), + methods=['PUT', 'DELETE'] ) @endpoint_wrapper -def api_manage_user(inputs: Dict[str, str]): - user = api_key_map[g.hashed_api_key].user_data - if request.method == 'PUT': - user.edit_password(inputs['new_password']) - return return_api({}) - - elif request.method == 'DELETE': - user.delete() - api_key_map.pop(g.hashed_api_key) - return return_api({}) +def api_manage_user(inputs: Dict[str, Any]): + user = api_key_map[g.hashed_api_key].user_data + if request.method == 'PUT': + user.update(None, inputs['new_password']) + return return_api({}) -#=================== -# Notification service endpoints -#=================== + elif request.method == 'DELETE': + user.delete() + api_key_map.pop(g.hashed_api_key) + return return_api({}) + +# =================== +# region Notification Service +# =================== @api.route( - '/notificationservices', - 'Manage the notification services', - Methods( - get=Method( - description='Get a list of all notification services' - ), - post=Method( - vars=[TitleVariable, URLVariable], - description='Add a notification service' - ) - ), - methods=['GET', 'POST'] + '/notificationservices', + 'Manage the notification services', + Methods( + get=Method( + description='Get a list of all notification services' + ), + post=Method( + vars=[TitleVariable, URLVariable], + description='Add a notification service' + ) + ), + methods=['GET', 'POST'] ) @endpoint_wrapper def api_notification_services_list(inputs: Dict[str, str]): - services = api_key_map[g.hashed_api_key].user_data.notification_services + services = NotificationServices( + api_key_map[g.hashed_api_key].user_data.user_id + ) + + if request.method == 'GET': + result = services.fetchall() + return return_api(result=[r.todict() for r in result]) + + elif request.method == 'POST': + result = services.add( + title=inputs['title'], + url=inputs['url'] + ).get() + return return_api(result.todict(), code=201) - if request.method == 'GET': - result = services.fetchall() - return return_api(result) - - elif request.method == 'POST': - result = services.add(title=inputs['title'], - url=inputs['url']).get() - return return_api(result, code=201) @api.route( - '/notificationservices/available', - 'Get all available notification services and their url layout', - methods=['GET'] + '/notificationservices/available', + 'Get all available notification services and their url layout', + methods=['GET'] ) @endpoint_wrapper -def api_notification_service_available(): - result = get_apprise_services() - return return_api(result) +def api_notification_service_available(inputs: Dict[str, str]): + result = get_apprise_services() + return return_api(result) # type: ignore + @api.route( - '/notificationservices/test', - 'Send a test notification using the supplied Apprise URL', - Methods( - post=Method( - vars=[URLVariable] - ) - ), - methods=['POST'] + '/notificationservices/test', + 'Send a test notification using the supplied Apprise URL', + Methods( + post=Method( + vars=[URLVariable] + ) + ), + methods=['POST'] ) @endpoint_wrapper def api_test_service(inputs: Dict[str, Any]): - (api_key_map[g.hashed_api_key] - .user_data - .notification_services - .test_service(inputs['url'])) - return return_api({}, code=201) + user_id = api_key_map[g.hashed_api_key].user_data.user_id + + success = NotificationServices(user_id).test(inputs['url']) + return return_api( + { + 'success': success == SendResult.SUCCESS, + 'description': success.value + }, + code=201 + ) + @api.route( - '/notificationservices/', - 'Manage a specific notification service', - Methods( - put=Method( - vars=[EditTitleVariable, EditURLVariable], - description='Edit the notification service' - ), - delete=Method( - vars=[DeleteRemindersUsingVariable], - description='Delete the notification service' - ) - ), - methods=['GET', 'PUT', 'DELETE'] + '/notificationservices/', + 'Manage a specific notification service', + Methods( + put=Method( + vars=[EditTitleVariable, EditURLVariable], + description='Edit the notification service' + ), + delete=Method( + vars=[DeleteRemindersUsingVariable], + description='Delete the notification service' + ) + ), + methods=['GET', 'PUT', 'DELETE'] ) @endpoint_wrapper -def api_notification_service(inputs: Dict[str, str], n_id: int): - service = (api_key_map[g.hashed_api_key] - .user_data - .notification_services - .fetchone(n_id)) +def api_notification_service(inputs: Dict[str, Any], n_id: int): + user_id = api_key_map[g.hashed_api_key].user_data.user_id + service = NotificationServices(user_id).fetchone(n_id) - if request.method == 'GET': - result = service.get() - return return_api(result) + if request.method == 'GET': + result = service.get() + return return_api(result.todict()) - elif request.method == 'PUT': - result = service.update(title=inputs['title'], - url=inputs['url']) - return return_api(result) + elif request.method == 'PUT': + result = service.update( + title=inputs['title'], + url=inputs['url'] + ) + return return_api(result.todict()) - elif request.method == 'DELETE': - service.delete( - inputs['delete_reminders_using'] - ) - return return_api({}) + elif request.method == 'DELETE': + service.delete( + inputs['delete_reminders_using'] + ) + return return_api({}) -#=================== -# Library endpoints -#=================== +# =================== +# region Library +# =================== @api.route( - '/reminders', - 'Manage the reminders', - Methods( - get=Method( - vars=[SortByVariable], - description='Get a list of all reminders' - ), - post=Method( - vars=[TitleVariable, TimeVariable, - NotificationServicesVariable, TextVariable, - RepeatQuantityVariable, RepeatIntervalVariable, - WeekDaysVariable, - ColorVariable], - description='Add a reminder' - ), - ), - methods=['GET', 'POST'] + '/reminders', + 'Manage the reminders', + Methods( + get=Method( + vars=[SortByVariable], + description='Get a list of all reminders' + ), + post=Method( + vars=[TitleVariable, TimeVariable, + NotificationServicesVariable, TextVariable, + RepeatQuantityVariable, RepeatIntervalVariable, + WeekDaysVariable, + ColorVariable], + description='Add a reminder' + ), + ), + methods=['GET', 'POST'] ) @endpoint_wrapper def api_reminders_list(inputs: Dict[str, Any]): - reminders = api_key_map[g.hashed_api_key].user_data.reminders - - if request.method == 'GET': - result = reminders.fetchall(inputs['sort_by']) - return return_api(result) + reminders = Reminders(api_key_map[g.hashed_api_key].user_data.user_id) + + if request.method == 'GET': + result = reminders.fetchall(inputs['sort_by']) + return return_api([r.todict() for r in result]) + + elif request.method == 'POST': + result = reminders.add( + title=inputs['title'], + time=inputs['time'], + notification_services=inputs['notification_services'], + text=inputs['text'], + repeat_quantity=inputs['repeat_quantity'], + repeat_interval=inputs['repeat_interval'], + weekdays=inputs['weekdays'], + color=inputs['color'] + ) + return return_api(result.get().todict(), code=201) - elif request.method == 'POST': - result = reminders.add(title=inputs['title'], - time=inputs['time'], - notification_services=inputs['notification_services'], - text=inputs['text'], - repeat_quantity=inputs['repeat_quantity'], - repeat_interval=inputs['repeat_interval'], - weekdays=inputs['weekdays'], - color=inputs['color']) - return return_api(result.get(), code=201) @api.route( - '/reminders/search', - 'Search through the list of reminders', - Methods( - get=Method( - vars=[SortByVariable, QueryVariable] - ) - ), - methods=['GET'] + '/reminders/search', + 'Search through the list of reminders', + Methods( + get=Method( + vars=[SortByVariable, QueryVariable] + ) + ), + methods=['GET'] ) @endpoint_wrapper -def api_reminders_query(inputs: Dict[str, str]): - result = (api_key_map[g.hashed_api_key] - .user_data - .reminders - .search(inputs['query'], inputs['sort_by'])) - return return_api(result) +def api_reminders_query(inputs: Dict[str, Any]): + reminders = Reminders(api_key_map[g.hashed_api_key].user_data.user_id) + result = reminders.search(inputs['query'], inputs['sort_by']) + return return_api([r.todict() for r in result]) + @api.route( - '/reminders/test', - 'Test send a reminder draft', - Methods( - post=Method( - vars=[TitleVariable, NotificationServicesVariable, - TextVariable] - ) - ), - methods=['POST'] + '/reminders/test', + 'Test send a reminder draft', + Methods( + post=Method( + vars=[TitleVariable, NotificationServicesVariable, + TextVariable] + ) + ), + methods=['POST'] ) @endpoint_wrapper def api_test_reminder(inputs: Dict[str, Any]): - api_key_map[g.hashed_api_key].user_data.reminders.test_reminder( - inputs['title'], - inputs['notification_services'], - inputs['text'] - ) - return return_api({}, code=201) + Reminders( + api_key_map[g.hashed_api_key].user_data.user_id + ).test_reminder( + inputs['title'], + inputs['notification_services'], + inputs['text'] + ) + return return_api({}, code=201) + @api.route( - '/reminders/', - 'Manage a specific reminder', - Methods( - put=Method( - vars=[EditTitleVariable, EditTimeVariable, - EditNotificationServicesVariable, TextVariable, - RepeatQuantityVariable, RepeatIntervalVariable, - WeekDaysVariable, - ColorVariable], - description='Edit the reminder' - ), - delete=Method( - description='Delete the reminder' - ) - ), - methods=['GET', 'PUT', 'DELETE'] + '/reminders/', + 'Manage a specific reminder', + Methods( + put=Method( + vars=[EditTitleVariable, EditTimeVariable, + EditNotificationServicesVariable, TextVariable, + RepeatQuantityVariable, RepeatIntervalVariable, + WeekDaysVariable, + ColorVariable], + description='Edit the reminder' + ), + delete=Method( + description='Delete the reminder' + ) + ), + methods=['GET', 'PUT', 'DELETE'] ) @endpoint_wrapper def api_get_reminder(inputs: Dict[str, Any], r_id: int): - reminders = api_key_map[g.hashed_api_key].user_data.reminders + reminders = Reminders( + api_key_map[g.hashed_api_key].user_data.user_id + ) - if request.method == 'GET': - result = reminders.fetchone(r_id).get() - return return_api(result) + if request.method == 'GET': + result = reminders.fetchone(r_id).get() + return return_api(result.todict()) - elif request.method == 'PUT': - result = reminders.fetchone(r_id).update(title=inputs['title'], - time=inputs['time'], - notification_services=inputs['notification_services'], - text=inputs['text'], - repeat_quantity=inputs['repeat_quantity'], - repeat_interval=inputs['repeat_interval'], - weekdays=inputs['weekdays'], - color=inputs['color']) - return return_api(result) + elif request.method == 'PUT': + result = reminders.fetchone(r_id).update( + title=inputs['title'], + time=inputs['time'], + notification_services=inputs['notification_services'], + text=inputs['text'], + repeat_quantity=inputs['repeat_quantity'], + repeat_interval=inputs['repeat_interval'], + weekdays=inputs['weekdays'], + color=inputs['color'] + ) + return return_api(result.todict()) - elif request.method == 'DELETE': - reminders.fetchone(r_id).delete() - return return_api({}) + elif request.method == 'DELETE': + reminders.fetchone(r_id).delete() + return return_api({}) -#=================== -# Template endpoints -#=================== +# =================== +# region Template +# =================== @api.route( - '/templates', - 'Manage the templates', - Methods( - get=Method( - vars=[TimelessSortByVariable], - description='Get a list of all templates' - ), - post=Method( - vars=[TitleVariable, NotificationServicesVariable, - TextVariable, ColorVariable], - description='Add a template' - ) - ), - methods=['GET', 'POST'] + '/templates', + 'Manage the templates', + Methods( + get=Method( + vars=[TimelessSortByVariable], + description='Get a list of all templates' + ), + post=Method( + vars=[TitleVariable, NotificationServicesVariable, + TextVariable, ColorVariable], + description='Add a template' + ) + ), + methods=['GET', 'POST'] ) @endpoint_wrapper def api_get_templates(inputs: Dict[str, Any]): - templates = api_key_map[g.hashed_api_key].user_data.templates - - if request.method == 'GET': - result = templates.fetchall(inputs['sort_by']) - return return_api(result) - - elif request.method == 'POST': - result = templates.add(title=inputs['title'], - notification_services=inputs['notification_services'], - text=inputs['text'], - color=inputs['color']) - return return_api(result.get(), code=201) + templates = Templates( + api_key_map[g.hashed_api_key].user_data.user_id + ) + + if request.method == 'GET': + result = templates.fetchall(inputs['sort_by']) + return return_api([r.todict() for r in result]) + + elif request.method == 'POST': + result = templates.add( + title=inputs['title'], + notification_services=inputs['notification_services'], + text=inputs['text'], + color=inputs['color'] + ) + return return_api(result.get().todict(), code=201) + @api.route( - '/templates/search', - 'Search through the list of templates', - Methods( - get=Method( - vars=[TimelessSortByVariable, QueryVariable] - ) - ), - methods=['GET'] + '/templates/search', + 'Search through the list of templates', + Methods( + get=Method( + vars=[TimelessSortByVariable, QueryVariable] + ) + ), + methods=['GET'] ) @endpoint_wrapper -def api_templates_query(inputs: Dict[str, str]): - result = (api_key_map[g.hashed_api_key] - .user_data - .templates - .search(inputs['query'], inputs['sort_by'])) - return return_api(result) +def api_templates_query(inputs: Dict[str, Any]): + templates = Templates( + api_key_map[g.hashed_api_key].user_data.user_id + ) + result = templates.search(inputs['query'], inputs['sort_by']) + return return_api([r.todict() for r in result]) + @api.route( - '/templates/', - 'Manage a specific template', - Methods( - put=Method( - vars=[EditTitleVariable, EditNotificationServicesVariable, - TextVariable, ColorVariable], - description='Edit the template' - ), - delete=Method( - description='Delete the template' - ) - ), - methods=['GET', 'PUT', 'DELETE'] + '/templates/', + 'Manage a specific template', + Methods( + put=Method( + vars=[EditTitleVariable, EditNotificationServicesVariable, + TextVariable, ColorVariable], + description='Edit the template' + ), + delete=Method( + description='Delete the template' + ) + ), + methods=['GET', 'PUT', 'DELETE'] ) @endpoint_wrapper def api_get_template(inputs: Dict[str, Any], t_id: int): - template = (api_key_map[g.hashed_api_key] - .user_data - .templates - .fetchone(t_id)) - - if request.method == 'GET': - result = template.get() - return return_api(result) - - elif request.method == 'PUT': - result = template.update(title=inputs['title'], - notification_services=inputs['notification_services'], - text=inputs['text'], - color=inputs['color']) - return return_api(result) + template = Templates( + api_key_map[g.hashed_api_key].user_data.user_id + ).fetchone(t_id) - elif request.method == 'DELETE': - template.delete() - return return_api({}) + if request.method == 'GET': + result = template.get() + return return_api(result.todict()) -#=================== -# Static reminder endpoints -#=================== + elif request.method == 'PUT': + result = template.update( + title=inputs['title'], + notification_services=inputs['notification_services'], + text=inputs['text'], + color=inputs['color'] + ) + return return_api(result.todict()) + elif request.method == 'DELETE': + template.delete() + return return_api({}) + + +# =================== +# region Static Reminder +# =================== @api.route( - '/staticreminders', - 'Manage the static reminders', - Methods( - get=Method( - vars=[TimelessSortByVariable], - description='Get a list of all static reminders' - ), - post=Method( - vars=[TitleVariable, NotificationServicesVariable, - TextVariable, ColorVariable], - description='Add a static reminder' - ) - ), - methods=['GET', 'POST'] + '/staticreminders', + 'Manage the static reminders', + Methods( + get=Method( + vars=[TimelessSortByVariable], + description='Get a list of all static reminders' + ), + post=Method( + vars=[TitleVariable, NotificationServicesVariable, + TextVariable, ColorVariable], + description='Add a static reminder' + ) + ), + methods=['GET', 'POST'] ) @endpoint_wrapper def api_static_reminders_list(inputs: Dict[str, Any]): - reminders = api_key_map[g.hashed_api_key].user_data.static_reminders - - if request.method == 'GET': - result = reminders.fetchall(inputs['sort_by']) - return return_api(result) - - elif request.method == 'POST': - result = reminders.add(title=inputs['title'], - notification_services=inputs['notification_services'], - text=inputs['text'], - color=inputs['color']) - return return_api(result.get(), code=201) + reminders = StaticReminders( + api_key_map[g.hashed_api_key].user_data.user_id + ) + + if request.method == 'GET': + result = reminders.fetchall(inputs['sort_by']) + return return_api([r.todict() for r in result]) + + elif request.method == 'POST': + result = reminders.add( + title=inputs['title'], + notification_services=inputs['notification_services'], + text=inputs['text'], + color=inputs['color'] + ) + return return_api(result.get().todict(), code=201) + @api.route( - '/staticreminders/search', - 'Search through the list of staticreminders', - Methods( - get=Method( - vars=[TimelessSortByVariable, QueryVariable] - ) - ), - methods=['GET'] + '/staticreminders/search', + 'Search through the list of staticreminders', + Methods( + get=Method( + vars=[TimelessSortByVariable, QueryVariable] + ) + ), + methods=['GET'] ) @endpoint_wrapper -def api_static_reminders_query(inputs: Dict[str, str]): - result = (api_key_map[g.hashed_api_key] - .user_data - .static_reminders - .search(inputs['query'], inputs['sort_by'])) - return return_api(result) +def api_static_reminders_query(inputs: Dict[str, Any]): + result = StaticReminders( + api_key_map[g.hashed_api_key].user_data.user_id + ).search(inputs['query'], inputs['sort_by']) + return return_api([r.todict() for r in result]) + @api.route( - '/staticreminders/', - 'Manage a specific static reminder', - Methods( - post=Method( - description='Trigger the static reminder' - ), - put=Method( - vars=[EditTitleVariable, EditNotificationServicesVariable, - TextVariable, ColorVariable], - description='Edit the static reminder' - ), - delete=Method( - description='Delete the static reminder' - ) - ), - methods=['GET', 'POST', 'PUT', 'DELETE'] + '/staticreminders/', + 'Manage a specific static reminder', + Methods( + post=Method( + description='Trigger the static reminder' + ), + put=Method( + vars=[EditTitleVariable, EditNotificationServicesVariable, + TextVariable, ColorVariable], + description='Edit the static reminder' + ), + delete=Method( + description='Delete the static reminder' + ) + ), + methods=['GET', 'POST', 'PUT', 'DELETE'] ) @endpoint_wrapper def api_get_static_reminder(inputs: Dict[str, Any], s_id: int): - reminders = api_key_map[g.hashed_api_key].user_data.static_reminders + reminders = StaticReminders( + api_key_map[g.hashed_api_key].user_data.user_id + ) - if request.method == 'GET': - result = reminders.fetchone(s_id).get() - return return_api(result) + if request.method == 'GET': + result = reminders.fetchone(s_id).get() + return return_api(result.todict()) - elif request.method == 'POST': - reminders.trigger_reminder(s_id) - return return_api({}, code=201) + elif request.method == 'POST': + reminders.fetchone(s_id).trigger_reminder() + return return_api({}, code=201) - elif request.method == 'PUT': - result = reminders.fetchone(s_id).update(title=inputs['title'], - notification_services=inputs['notification_services'], - text=inputs['text'], - color=inputs['color']) - return return_api(result) + elif request.method == 'PUT': + result = reminders.fetchone(s_id).update( + title=inputs['title'], + notification_services=inputs['notification_services'], + text=inputs['text'], + color=inputs['color'] + ) + return return_api(result.todict()) - elif request.method == 'DELETE': - reminders.fetchone(s_id).delete() - return return_api({}) + elif request.method == 'DELETE': + reminders.fetchone(s_id).delete() + return return_api({}) -#=================== -# Admin panel endpoints -#=================== +# =================== +# region Admin Panel +# =================== @admin_api.route( - '/shutdown', - 'Shut down the application', - methods=['POST'] + '/shutdown', + 'Shut down the application', + methods=['POST'] ) @endpoint_wrapper -def api_shutdown(): - SERVER.shutdown() - return return_api({}) +def api_shutdown(inputs: Dict[str, Any]): + Server().shutdown() + return return_api({}) + @admin_api.route( - '/restart', - 'Restart the application', - methods=['POST'] + '/restart', + 'Restart the application', + methods=['POST'] ) @endpoint_wrapper -def api_restart(): - SERVER.restart() - return return_api({}) +def api_restart(inputs: Dict[str, Any]): + Server().restart() + return return_api({}) + @api.route( - '/settings', - 'Get the admin settings', - requires_auth=False, - methods=['GET'] + '/settings', + 'Get the admin settings', + requires_auth=False, + methods=['GET'] ) @endpoint_wrapper -def api_settings(): - return return_api(get_admin_settings()) +def api_settings(inputs: Dict[str, Any]): + return return_api(Settings().get_settings().todict()) + + +@api.route( + '/about', + "Get data about the application and it's environment", + requires_auth=False, + methods=['GET'] +) +@endpoint_wrapper +def api_about(inputs: Dict[str, Any]): + return return_api(get_about_data()) + @admin_api.route( - '/settings', - 'Interact with the admin settings', - Methods( - get=Method( - description='Get the admin settings' - ), - put=Method( - vars=[AllowNewAccountsVariable, LoginTimeVariable, - LoginTimeResetVariable, HostVariable, PortVariable, - UrlPrefixVariable, LogLevelVariable], - description='Edit the admin settings. Supplying a hosting setting will automatically restart MIND.' - ) - ), - methods=['GET', 'PUT'] + '/settings', + 'Interact with the admin settings', + Methods( + get=Method( + description='Get the admin settings' + ), + put=Method( + vars=[AllowNewAccountsVariable, LoginTimeVariable, + LoginTimeResetVariable, HostVariable, PortVariable, + UrlPrefixVariable, LogLevelVariable], + description='Edit the admin settings. Supplying a hosting setting will automatically restart MIND.' + ) + ), + methods=['GET', 'PUT'] ) @endpoint_wrapper def api_admin_settings(inputs: Dict[str, Any]): - if request.method == 'GET': - return return_api(get_admin_settings()) + settings = Settings() - elif request.method == 'PUT': - LOGGER.info(f'Submitting admin settings: {inputs}') + if request.method == 'GET': + return return_api(settings.get_settings().todict()) - hosting_changes = any( - inputs[s] is not None - for s in ('host', 'port', 'url_prefix') - ) + elif request.method == 'PUT': + LOGGER.info(f'Submitting admin settings: {inputs}') - if hosting_changes: - backup_hosting_settings() - - for k, v in inputs.items(): - if v is not None: - set_setting(k, v) + hosting_changes = any( + inputs[s] is not None + for s in ('host', 'port', 'url_prefix') + ) + + if hosting_changes: + settings.backup_hosting_settings() + + settings.update({ + k: v + for k, v in inputs.items() + if v is not None + }) + + if hosting_changes: + Server().restart(StartType.RESTART_HOSTING_CHANGES) + + return return_api({}) - if hosting_changes: - SERVER.restart([RestartVars.HOST_CHANGE.value]) - - return return_api({}) @admin_api.route( - '/logs', - 'Get the debug logs', - methods=['GET'] + '/logs', + 'Get the logs as a file', + methods=['GET'] ) @endpoint_wrapper -def api_admin_logs(): - file = get_debug_log_filepath() - if not exists(file): - raise LogFileNotFound +def api_admin_logs(inputs: Dict[str, Any]): + file = get_log_filepath() + if not exists(file): + raise LogFileNotFound(file) + + sio = StringIO() + for ext in ('.1', ''): + lf = file + ext + if not exists(lf): + continue + with open(lf, 'r') as f: + sio.writelines(f) + + return send_file( + BytesIO(sio.getvalue().encode('utf-8')), + mimetype="application/octet-stream", + download_name=f'MIND_log_{datetime.now().strftime("%Y_%m_%d_%H_%M")}.txt' + ), 200 - return send_file(file), 200 @admin_api.route( - '/users', - 'Get all users or add one', - Methods( - get=Method( - description='Get all users' - ), - post=Method( - vars=[UsernameCreateVariable, PasswordCreateVariable], - description='Add a new user' - ) - ), - methods=['GET', 'POST'] + '/users', + 'Get all users or add one', + Methods( + get=Method( + description='Get all users' + ), + post=Method( + vars=[UsernameCreateVariable, PasswordCreateVariable], + description='Add a new user' + ) + ), + methods=['GET', 'POST'] ) @endpoint_wrapper def api_admin_users(inputs: Dict[str, Any]): - if request.method == 'GET': - result = users.get_all() - return return_api(result) + if request.method == 'GET': + result = users.get_all() + return return_api([r.todict() for r in result]) + + elif request.method == 'POST': + users.add(inputs['username'], inputs['password'], True) + return return_api({}, code=201) - elif request.method == 'POST': - users.add(inputs['username'], inputs['password'], True) - return return_api({}, code=201) @admin_api.route( - '/users/', - 'Manage a specific user', - Methods( - put=Method( - vars=[NewPasswordVariable], - description='Change the password of the user account' - ), - delete=Method( - description='Delete the user account' - ) - ), - methods=['PUT', 'DELETE'] + '/users/', + 'Manage a specific user', + Methods( + put=Method( + vars=[NewPasswordVariable], + description='Change the password of the user account' + ), + delete=Method( + description='Delete the user account' + ) + ), + methods=['PUT', 'DELETE'] ) @endpoint_wrapper def api_admin_user(inputs: Dict[str, Any], u_id: int): - user = users.get_one(u_id) - if request.method == 'PUT': - user.edit_password(inputs['new_password']) - return return_api({}) - - elif request.method == 'DELETE': - user.delete() - for key, value in api_key_map.items(): - if value.user_data.user_id == u_id: - del api_key_map[key] - break - return return_api({}) + user = users.get_one(u_id) + if request.method == 'PUT': + user.update(None, inputs['new_password']) + return return_api({}) + + elif request.method == 'DELETE': + user.delete() + for key, value in api_key_map.items(): + if value.user_data.user_id == u_id: + del api_key_map[key] + break + return return_api({}) + @admin_api.route( - '/database', - 'Download and upload the database', - Methods( - get=Method( - description="Download the database file" - ), - post=Method( - vars=[DatabaseFileVariable, CopyHostingSettingsVariable], - description="Upload and apply a database file. Will automatically restart MIND." - ) - ), - methods=['GET', 'POST'] + '/database', + 'Download and upload the database', + Methods( + get=Method( + description="Download the database file" + ), + post=Method( + vars=[DatabaseFileVariable, CopyHostingSettingsVariable], + description="Upload and apply a database file. Will automatically restart MIND." + ) + ), + methods=['GET', 'POST'] ) @endpoint_wrapper def api_admin_database(inputs: Dict[str, Any]): - if request.method == "GET": - current_date = datetime.now().strftime(r"%Y_%m_%d_%H_%M") - filename = folder_path( - 'db', f'MIND_{current_date}.db' - ) - get_db().execute( - "VACUUM INTO ?;", - (filename,) - ) + if request.method == "GET": + current_date = datetime.now().strftime(r"%Y_%m_%d_%H_%M") + filename = folder_path( + 'db', f'MIND_{current_date}.db' + ) + get_db().execute( + "VACUUM INTO ?;", + (filename,) + ) - with open(filename, 'rb') as database_file: - bi = BytesIO(database_file.read()) + with open(filename, 'rb') as database_file: + bi = BytesIO(database_file.read()) - remove(filename) + remove(filename) + return send_file( + bi, + mimetype="application/x-sqlite3", + download_name=f'MIND_{current_date}.db' + ), 200 - return send_file( - bi, - mimetype='application/x-sqlite3', - download_name=basename(filename) - ), 200 - - elif request.method == "POST": - import_db(inputs['file'], inputs['copy_hosting_settings']) - return return_api({}) + elif request.method == "POST": + import_db(inputs['file'], inputs['copy_hosting_settings']) + return return_api({}) diff --git a/frontend/input_validation.py b/frontend/input_validation.py index 49d0d8d..eac4a7a 100644 --- a/frontend/input_validation.py +++ b/frontend/input_validation.py @@ -1,4 +1,4 @@ -#-*- coding: utf-8 -*- +# -*- coding: utf-8 -*- """ Input validation for the API @@ -6,583 +6,542 @@ Input validation for the API from __future__ import annotations -from abc import ABC, abstractmethod from dataclasses import dataclass, field -import logging +from logging import DEBUG, INFO from os.path import splitext from re import compile -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type from apprise import Apprise from flask import Blueprint, request -from flask.sansio.scaffold import T_route -from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile, - InvalidKeyValue, InvalidTime, - KeyNotFound, NewAccountsNotAllowed, - NotificationServiceNotFound, - UsernameInvalid, UsernameTaken, - UserNotFound) -from backend.helpers import (RepeatQuantity, SortingMethod, - TimelessSortingMethod, folder_path) -from backend.server import SERVER -from backend.settings import _format_setting +from backend.base.custom_exceptions import (AccessUnauthorized, + InvalidDatabaseFile, + InvalidKeyValue, InvalidTime, + KeyNotFound, NewAccountsNotAllowed, + NotificationServiceNotFound, + UsernameInvalid, UsernameTaken) +from backend.base.definitions import (ApiDocEntry, DataSource, DataType, + InputVariable, Methods, MindException, + SortingMethod, T, TimelessSortingMethod) +from backend.base.helpers import RepeatQuantity, folder_path +from backend.internals.server import Server if TYPE_CHECKING: - from flask import Request + from flask import Request + from flask.sansio.scaffold import T_route + color_regex = compile(r'#[0-9a-f]{6}') - api_docs: Dict[str, ApiDocEntry] = {} -class DataSource: - DATA = 1 - VALUES = 2 - FILES = 3 +def request_data(request: Request) -> Dict[DataSource, Dict[str, Any]]: + """Returns the request data in a dictionary. - def __init__(self, request: Request) -> None: - self.map: Dict[int, dict] = { - self.DATA: request.get_json() if request.data else {}, - self.VALUES: request.values, - self.FILES: request.files - } - return + Args: + request (Request): The request object. - def __getitem__(self, key: int) -> dict: - return self.map[key] - - -class DataType: - STR = 'string' - INT = 'number' - FLOAT = 'decimal number' - BOOL = 'bool' - INT_ARRAY = 'list of numbers' - NA = 'N/A' - - -class InputVariable(ABC): - value: Any - - @abstractmethod - def __init__(self, value: Any) -> None: - pass - - @property - @abstractmethod - def name(self) -> str: - pass - - @abstractmethod - def validate(self) -> bool: - pass - - @property - @abstractmethod - def required(self) -> bool: - pass - - @property - @abstractmethod - def data_type(self) -> List[str]: - pass - - @property - @abstractmethod - def default(self) -> Any: - pass - - @property - @abstractmethod - def source(self) -> int: - pass - - @property - @abstractmethod - def description(self) -> str: - pass - - @property - @abstractmethod - def related_exceptions(self) -> List[Exception]: - pass - - -@dataclass(frozen=True) -class Method: - description: str = '' - vars: List[Type[InputVariable]] = field(default_factory=list) - - def __bool__(self) -> bool: - return self.vars != [] - - -@dataclass(frozen=True) -class Methods: - get: Method = Method() - post: Method = Method() - put: Method = Method() - delete: Method = Method() - - def __getitem__(self, key: str) -> Method: - return getattr(self, key.lower()) - - def __bool__(self) -> bool: - return bool(self.get or self.post or self.put or self.delete) - - -@dataclass(frozen=True) -class ApiDocEntry: - endpoint: str - description: str - requires_auth: bool - used_methods: List[str] - methods: Methods + Returns: + Dict[DataSource, Dict[str, Any]]: The request data. + """ + return { + DataSource.DATA: request.get_json() if request.data else {}, + DataSource.VALUES: request.values, + DataSource.FILES: request.files + } def get_api_docs(request: Request) -> ApiDocEntry: - if request.path.startswith(SERVER.admin_prefix): - url = SERVER.admin_api_extension + request.url_rule.rule.split(SERVER.admin_prefix)[1] - else: - url = request.url_rule.rule.split(SERVER.api_prefix)[1] - return api_docs[url] + """Returns the API documentation for the given request. + + Args: + request (Request): The request object. + + Returns: + ApiDocEntry: The API documentation for the used endpoint. + """ + assert (request.url_rule is not None) + + if request.path.startswith(Server.admin_prefix): + url = ( + Server.admin_api_extension + + request.url_rule.rule.split(Server.admin_prefix)[1] + ) + else: + url = request.url_rule.rule.split(Server.api_prefix)[1] + + return api_docs[url] -class BaseInputVariable(InputVariable): - source = DataSource.DATA - data_type = [DataType.STR] - required = True - default = None - related_exceptions = [KeyNotFound, InvalidKeyValue] - - def __init__(self, value: Any) -> None: - self.value = value - - def validate(self) -> bool: - return isinstance(self.value, str) and self.value - - def __repr__(self) -> str: - return f'| {self.name} | {"Yes" if self.required else "No"} | {",".join(self.data_type)} | {self.description} | N/A |' +def dl(*args: T) -> List[T]: + return field(default_factory=lambda: list(args)) -class NonRequiredVersion(BaseInputVariable): - required = False - related_exceptions = [InvalidKeyValue] +@dataclass +class NonRequiredVersion(InputVariable): + required: bool = False + related_exceptions: List[Type[MindException]] = dl(InvalidKeyValue) - def __init__(self, value: Any) -> None: - super().__init__( - value - if value is not None else - self.default - ) - return + def __post_init__(self) -> None: + if self.value is None: + self.value = self.default + return - def validate(self) -> bool: - return self.value is None or super().validate() + def validate(self) -> bool: + return self.value is None or super().validate() -class UsernameVariable(BaseInputVariable): - name = 'username' - description = 'The username of the user account' - related_exceptions = [KeyNotFound, UserNotFound] +# =================== +# region Variables +# =================== +@dataclass +class UsernameVariable(InputVariable): + name: str = 'username' + description: str = 'The username of the user account' + related_exceptions: List[Type[MindException]] = dl( + KeyNotFound, UsernameInvalid + ) -class PasswordCreateVariable(BaseInputVariable): - name = 'password' - description = 'The password of the user account' - related_exceptions = [KeyNotFound] +@dataclass +class PasswordCreateVariable(InputVariable): + name: str = 'password' + description: str = 'The password of the user account' + related_exceptions: List[Type[MindException]] = dl(KeyNotFound) +@dataclass class PasswordVariable(PasswordCreateVariable): - related_exceptions = [KeyNotFound, AccessUnauthorized] + related_exceptions: List[Type[MindException]] = dl( + KeyNotFound, AccessUnauthorized) +@dataclass class UsernameCreateVariable(UsernameVariable): - related_exceptions = [ - KeyNotFound, - UsernameInvalid, UsernameTaken, - NewAccountsNotAllowed - ] + related_exceptions: List[Type[MindException]] = dl( + KeyNotFound, + UsernameInvalid, UsernameTaken, + NewAccountsNotAllowed + ) -class NewPasswordVariable(BaseInputVariable): - name = 'new_password' - description = 'The new password of the user account' - related_exceptions = [KeyNotFound] +@dataclass +class NewPasswordVariable(InputVariable): + name: str = 'new_password' + description: str = 'The new password of the user account' + related_exceptions: List[Type[MindException]] = dl(KeyNotFound) -class TitleVariable(BaseInputVariable): - name = 'title' - description = 'The title of the entry' +@dataclass +class TitleVariable(InputVariable): + name: str = 'title' + description: str = 'The title of the entry' -class URLVariable(BaseInputVariable): - name = 'url' - description = 'The Apprise URL of the notification service' +@dataclass +class URLVariable(InputVariable): + name: str = 'url' + description: str = 'The Apprise URL of the notification service' - def validate(self) -> bool: - return super().validate() and Apprise().add(self.value) + def validate(self) -> bool: + return super().validate() and Apprise().add(self.value) +@dataclass class EditTitleVariable(NonRequiredVersion, TitleVariable): - pass + pass +@dataclass class EditURLVariable(NonRequiredVersion, URLVariable): - pass + pass -class SortByVariable(NonRequiredVersion, BaseInputVariable): - name = 'sort_by' - description = 'How to sort the result' - source = DataSource.VALUES - _options = [k.lower() for k in SortingMethod._member_names_] - default = SortingMethod._member_names_[0].lower() +@dataclass +class SortByVariable(NonRequiredVersion, InputVariable): + name: str = 'sort_by' + description: str = 'How to sort the result' + source: DataSource = DataSource.VALUES + _options: List[str] = dl(*(k.lower() for k in SortingMethod._member_names_)) + default: Any = SortingMethod.TIME - def validate(self) -> bool: - if not self.value in self._options: - return False + def validate(self) -> bool: + if self.value not in self._options: + return False - self.value = SortingMethod[self.value.upper()] - return True + self.value = SortingMethod[self.value.upper()] + return True - def __repr__(self) -> str: - return '| {n} | {r} | {t} | {d} | {v} |'.format( - n=self.name, - r="Yes" if self.required else "No", - t=",".join(self.data_type), - d=self.description, - v=", ".join(f'`{o}`' for o in self._options) - ) + def __repr__(self) -> str: + return '| {n} | {r} | {t} | {d} | {v} |'.format( + n=self.name, + r="Yes" if self.required else "No", + t=",".join(d.value for d in self.data_type), + d=self.description, + v=", ".join(f'`{o}`' for o in self._options) + ) +@dataclass class TimelessSortByVariable(SortByVariable): - _options = [k.lower() for k in TimelessSortingMethod._member_names_] - default = TimelessSortingMethod._member_names_[0].lower() + _options: List[str] = dl(*(k.lower() + for k in TimelessSortingMethod._member_names_)) + default: Any = TimelessSortingMethod.TITLE - def validate(self) -> bool: - if not self.value in self._options: - return False + def validate(self) -> bool: + if self.value not in self._options: + return False - self.value = TimelessSortingMethod[self.value.upper()] - return True + self.value = TimelessSortingMethod[self.value.upper()] + return True -class TimeVariable(BaseInputVariable): - name = 'time' - description = 'The UTC epoch timestamp that the reminder should be sent at' - data_type = [DataType.INT, DataType.FLOAT] - related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime] +@dataclass +class TimeVariable(InputVariable): + name: str = 'time' + description: str = 'The UTC epoch timestamp that the reminder should be sent at' + data_type: List[DataType] = dl(DataType.INT, DataType.FLOAT) + related_exceptions: List[Type[MindException]] = dl( + KeyNotFound, InvalidKeyValue, InvalidTime) - def validate(self) -> bool: - return isinstance(self.value, (float, int)) + def validate(self) -> bool: + return isinstance(self.value, (float, int)) +@dataclass class EditTimeVariable(NonRequiredVersion, TimeVariable): - related_exceptions = [InvalidKeyValue, InvalidTime] + related_exceptions: List[Type[MindException]] = dl( + InvalidKeyValue, InvalidTime) -class NotificationServicesVariable(BaseInputVariable): - name = 'notification_services' - description = "Array of the id's of the notification services to use to send the notification" - data_type = [DataType.INT_ARRAY] - related_exceptions = [ - KeyNotFound, InvalidKeyValue, - NotificationServiceNotFound - ] +@dataclass +class NotificationServicesVariable(InputVariable): + name: str = 'notification_services' + description: str = "Array of the id's of the notification services to use to send the notification" + data_type: List[DataType] = dl(DataType.INT_ARRAY) + related_exceptions: List[Type[MindException]] = dl( + KeyNotFound, InvalidKeyValue, + NotificationServiceNotFound + ) - def validate(self) -> bool: - if not isinstance(self.value, list): - return False - if not self.value: - return False - for v in self.value: - if not isinstance(v, int): - return False - return True + def validate(self) -> bool: + if not isinstance(self.value, list): + return False + if not self.value: + return False + for v in self.value: + if not isinstance(v, int): + return False + return True -class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable): - related_exceptions = [InvalidKeyValue, NotificationServiceNotFound] +@dataclass +class EditNotificationServicesVariable( + NonRequiredVersion, + NotificationServicesVariable +): + related_exceptions: List[Type[MindException]] = dl( + InvalidKeyValue, NotificationServiceNotFound) -class TextVariable(NonRequiredVersion, BaseInputVariable): - name = 'text' - description = 'The body of the entry' - default = '' +@dataclass +class TextVariable(NonRequiredVersion): + name: str = 'text' + description: str = 'The body of the entry' + default: Any = '' - def validate(self) -> bool: - return isinstance(self.value, str) + def validate(self) -> bool: + return isinstance(self.value, str) -class RepeatQuantityVariable(NonRequiredVersion, BaseInputVariable): - name = 'repeat_quantity' - description = 'The quantity of the repeat_interval' - _options = [m.lower() for m in RepeatQuantity._member_names_] +@dataclass +class RepeatQuantityVariable(NonRequiredVersion): + name: str = 'repeat_quantity' + description: str = 'The quantity of the repeat_interval' + _options: List[str] = dl(*(m.lower() + for m in RepeatQuantity._member_names_)) - def validate(self) -> bool: - if self.value is None: - return True + def validate(self) -> bool: + if self.value is None: + return True - if not self.value in self._options: - return False + if self.value not in self._options: + return False - self.value = RepeatQuantity[self.value.upper()] - return True + self.value = RepeatQuantity[self.value.upper()] + return True - def __repr__(self) -> str: - return '| {n} | {r} | {t} | {d} | {v} |'.format( - n=self.name, - r="Yes" if self.required else "No", - t=",".join(self.data_type), - d=self.description, - v=", ".join(f'`{o}`' for o in self._options) - ) + def __repr__(self) -> str: + return '| {n} | {r} | {t} | {d} | {v} |'.format( + n=self.name, + r="Yes" if self.required else "No", + t=",".join(d.value for d in self.data_type), + d=self.description, + v=", ".join(f'`{o}`' for o in self._options) + ) -class RepeatIntervalVariable(NonRequiredVersion, BaseInputVariable): - name = 'repeat_interval' - description = 'The number of the interval' - data_type = [DataType.INT] +@dataclass +class RepeatIntervalVariable(NonRequiredVersion): + name: str = 'repeat_interval' + description: str = 'The number of the interval' + data_type: List[DataType] = dl(DataType.INT) - def validate(self) -> bool: - return ( - self.value is None - or ( - isinstance(self.value, int) - and self.value > 0 - ) - ) + def validate(self) -> bool: + return ( + self.value is None + or ( + isinstance(self.value, int) + and self.value > 0 + ) + ) -class WeekDaysVariable(NonRequiredVersion, BaseInputVariable): - name = 'weekdays' - description = 'On which days of the weeks to run the reminder' - data_type = [DataType.INT_ARRAY] - _options = {0, 1, 2, 3, 4, 5, 6} +@dataclass +class WeekDaysVariable(NonRequiredVersion): + name: str = 'weekdays' + description: str = 'On which days of the weeks to run the reminder' + data_type: List[DataType] = dl(DataType.INT_ARRAY) + _options = {0, 1, 2, 3, 4, 5, 6} - def validate(self) -> bool: - return self.value is None or ( - isinstance(self.value, list) - and len(self.value) > 0 - and all(v in self._options for v in self.value) - ) + def validate(self) -> bool: + return self.value is None or ( + isinstance(self.value, list) + and len(self.value) > 0 + and all(v in self._options for v in self.value) + ) - def __repr__(self) -> str: - return '| {n} | {r} | {t} | {d} | {v} |'.format( - n=self.name, - r="Yes" if self.required else "No", - t=",".join(self.data_type), - d=self.description, - v=", ".join(f'`{o}`' for o in self._options) - ) - -class ColorVariable(NonRequiredVersion, BaseInputVariable): - name = 'color' - description = 'The hex code of the color of the entry, which is shown in the web-ui' - - def validate(self) -> bool: - return self.value is None or ( - isinstance(self.value, str) - and color_regex.search(self.value) - ) + def __repr__(self) -> str: + return '| {n} | {r} | {t} | {d} | {v} |'.format( + n=self.name, + r="Yes" if self.required else "No", + t=",".join(d.value for d in self.data_type), + d=self.description, + v=", ".join(f'`{o}`' for o in self._options) + ) -class QueryVariable(BaseInputVariable): - name = 'query' - description = 'The search term' - source = DataSource.VALUES +@dataclass +class ColorVariable(NonRequiredVersion): + name: str = 'color' + description: str = 'The hex code of the color of the entry, which is shown in the web-ui' + + def validate(self) -> bool: + return self.value is None or ( + isinstance(self.value, str) + and color_regex.search(self.value) is not None + ) -class DeleteRemindersUsingVariable(NonRequiredVersion, BaseInputVariable): - name = 'delete_reminders_using' - description = 'Instead of throwing an error when there are still reminders using the service, delete the reminders.' - source = DataSource.VALUES - default = 'false' - data_type = [DataType.BOOL] - - def validate(self) -> bool: - if self.value == 'true': - self.value = True - return True - - elif self.value == 'false': - self.value = False - return True - - else: - return False +@dataclass +class QueryVariable(InputVariable): + name: str = 'query' + description: str = 'The search term' + source: DataSource = DataSource.VALUES -class AdminSettingsVariable(BaseInputVariable): - def validate(self) -> bool: - try: - _format_setting(self.name, self.value) - except InvalidKeyValue: - return False - return True +@dataclass +class DeleteRemindersUsingVariable(NonRequiredVersion): + name: str = 'delete_reminders_using' + description: str = 'Instead of throwing an error when there are still reminders using the service, delete the reminders.' + source: DataSource = DataSource.VALUES + default: Any = 'false' + data_type: List[DataType] = dl(DataType.BOOL) + + def validate(self) -> bool: + if self.value == 'true': + self.value = True + return True + + elif self.value == 'false': + self.value = False + return True + + else: + return False +@dataclass +class AdminSettingsVariable(InputVariable): + def validate(self) -> bool: + # @dataclassValidation is done in + # the settings class + return True + + +@dataclass class AllowNewAccountsVariable(NonRequiredVersion, AdminSettingsVariable): - name = 'allow_new_accounts' - description = ('Whether or not to allow users to register a new account. ' - + 'The admin can always add a new account.') - data_type = [DataType.BOOL] + name: str = 'allow_new_accounts' + description: str = ( + 'Whether or not to allow users to register a new account. ' + + 'The admin can always add a new account.') + data_type: List[DataType] = dl(DataType.BOOL) +@dataclass class LoginTimeVariable(NonRequiredVersion, AdminSettingsVariable): - name = 'login_time' - description = ('How long a user stays logged in, in seconds. ' - + 'Between 1 min and 1 month (60 <= sec <= 2592000)') - data_type = [DataType.INT] + name: str = 'login_time' + description: str = ('How long a user stays logged in, in seconds. ' + + 'Between 1 min and 1 month (60 <= sec <= 2592000)') + data_type: List[DataType] = dl(DataType.INT) +@dataclass class LoginTimeResetVariable(NonRequiredVersion, AdminSettingsVariable): - name = 'login_time_reset' - description = 'If the Login Time timer should reset with each API request.' - data_type = [DataType.BOOL] + name: str = 'login_time_reset' + description: str = 'If the Login Time timer should reset with each API request.' + data_type: List[DataType] = dl(DataType.BOOL) +@dataclass class HostVariable(NonRequiredVersion, AdminSettingsVariable): - name = 'host' - description = 'The IP to bind to. Use 0.0.0.0 to bind to all addresses.' + name: str = 'host' + description: str = 'The IP to bind to. Use 0.0.0.0 to bind to all addresses.' +@dataclass class PortVariable(NonRequiredVersion, AdminSettingsVariable): - name = 'port' - description = 'The port to listen on.' - data_type = [DataType.INT] + name: str = 'port' + description: str = 'The port to listen on.' + data_type: List[DataType] = dl(DataType.INT) +@dataclass class UrlPrefixVariable(NonRequiredVersion, AdminSettingsVariable): - name = 'url_prefix' - description = 'The base url to run on. Useful for reverse proxies. Empty string to disable.' + name: str = 'url_prefix' + description: str = 'The base url to run on. Useful for reverse proxies. Empty string to disable.' +@dataclass class LogLevelVariable(NonRequiredVersion, AdminSettingsVariable): - name = 'log_level' - description = 'The level to log on.' - data_type = [DataType.INT] - _options = [logging.INFO, logging.DEBUG] + name: str = 'log_level' + description: str = 'The level to log on.' + data_type: List[DataType] = dl(DataType.INT) + _options = [INFO, DEBUG] - def __repr__(self) -> str: - return '| {n} | {r} | {t} | {d} | {v} |'.format( - n=self.name, - r="Yes" if self.required else "No", - t=",".join(self.data_type), - d=self.description, - v=", ".join(f'`{o}`' for o in self._options) - ) + def __repr__(self) -> str: + return '| {n} | {r} | {t} | {d} | {v} |'.format( + n=self.name, + r="Yes" if self.required else "No", + t=",".join(d.value for d in self.data_type), + d=self.description, + v=", ".join(f'`{o}`' for o in self._options) + ) -class DatabaseFileVariable(BaseInputVariable): - name = 'file' - description = 'The MIND database file' - data_type = [DataType.NA] - source = DataSource.FILES - related_exceptions = [KeyNotFound, InvalidDatabaseFile] +@dataclass +class DatabaseFileVariable(InputVariable): + name: str = 'file' + description: str = 'The MIND database file' + data_type: List[DataType] = dl(DataType.NA) + source: DataSource = DataSource.FILES + related_exceptions: List[Type[MindException]] = dl( + KeyNotFound, InvalidDatabaseFile) - def validate(self) -> bool: - if ( - self.value.filename - and splitext(self.value.filename)[1] == '.db' - ): - path = folder_path('db', 'MIND_upload.db') - self.value.save(path) - self.value = path - return True - else: - return False + def validate(self) -> bool: + if ( + self.value.filename + and splitext(self.value.filename)[1] == '.db' + ): + path = folder_path('db', 'MIND_upload.db') + self.value.save(path) + self.value = path + return True + + return False -class CopyHostingSettingsVariable(BaseInputVariable): - name = 'copy_hosting_settings' - description = 'Copy the hosting settings from the current database' - data_type = [DataType.BOOL] - source = DataSource.VALUES +@dataclass +class CopyHostingSettingsVariable(InputVariable): + name: str = 'copy_hosting_settings' + description: str = 'Copy the hosting settings from the current database' + data_type: List[DataType] = dl(DataType.BOOL) + source: DataSource = DataSource.VALUES - def validate(self) -> bool: - if not self.value in ('true', 'false'): - return False + def validate(self) -> bool: + if self.value not in ('true', 'false'): + return False - self.value = self.value == 'true' - return True + self.value = self.value == 'true' + return True -def input_validation() -> Union[None, Dict[str, Any]]: - """Checks, extracts and transforms inputs +# =================== +# region Endpoints +# =================== +def input_validation() -> Dict[str, Any]: + """Checks, extracts and transforms inputs. - Raises: - KeyNotFound: A required key was not supplied - InvalidKeyValue: The value of a key is not valid + Raises: + KeyNotFound: A required key was not supplied. + InvalidKeyValue: The value of a key is not valid. - Returns: - Union[None, Dict[str, Any]]: `None` if the endpoint + method doesn't require input variables. - Otherwise `Dict[str, Any]` with the input variables, checked and formatted. - """ - result = {} + Returns: + Dict[str, Any]: The input variables, checked and formatted. + """ + method = get_api_docs(request).methods[request.method] + if not method: + return {} - methods = get_api_docs(request).methods - method = methods[request.method] - noted_variables = method.vars + result = {} + noted_variables = method.vars + given_variables = request_data(request) + for noted_var in noted_variables: + if noted_var.name not in given_variables[noted_var.source]: + if noted_var.required: + # Variable not given while required + raise KeyNotFound(noted_var.name) + else: + # Variable not given while not required, so set to default + result[noted_var.name] = noted_var.default + continue - if not methods: - return None + input_value = given_variables[noted_var.source][noted_var.name] + value = noted_var(input_value) # type: ignore - if not method: - return result + if not value.validate(): + if isinstance(value, DatabaseFileVariable): + raise InvalidDatabaseFile(value.value) + elif noted_var.source == DataSource.FILES: + raise InvalidKeyValue(noted_var.name, input_value.filename) + else: + raise InvalidKeyValue(noted_var.name, input_value) - given_variables = DataSource(request) + result[noted_var.name] = value.value - for noted_var in noted_variables: - if ( - noted_var.required and - not noted_var.name in given_variables[noted_var.source] - ): - raise KeyNotFound(noted_var.name) - - input_value = given_variables[noted_var.source].get(noted_var.name) - value: InputVariable = noted_var(input_value) - - if not value.validate(): - if noted_var.__class__.__name__ == DatabaseFileVariable.__name__: - raise InvalidDatabaseFile - elif noted_var.source == DataSource.FILES: - raise InvalidKeyValue(noted_var.name, input_value.filename) - else: - raise InvalidKeyValue(noted_var.name, input_value) - - result[noted_var.name] = value.value - return result + return result class APIBlueprint(Blueprint): - def route( - self, - rule: str, - description: str = '', - input_variables: Methods = Methods(), - requires_auth: bool = True, - **options: Any - ) -> Callable[[T_route], T_route]: + def route( + self, + rule: str, + description: str = '', + input_variables: Methods = Methods(), + requires_auth: bool = True, + **options: Any + ) -> Callable[[T_route], T_route]: - if self == api: - processed_rule = rule - elif self == admin_api: - processed_rule = SERVER.admin_api_extension + rule - else: - raise NotImplementedError + if self == api: + processed_rule = rule + elif self == admin_api: + processed_rule = Server.admin_api_extension + rule + else: + raise NotImplementedError - api_docs[processed_rule] = ApiDocEntry( - endpoint=processed_rule, - description=description, - requires_auth=requires_auth, - used_methods=options['methods'], - methods=input_variables - ) + api_docs[processed_rule] = ApiDocEntry( + endpoint=processed_rule, + description=description, + requires_auth=requires_auth, + methods=input_variables + ) + + if "methods" not in options: + options["methods"] = api_docs[processed_rule].methods.used_methods() + + return super().route(rule, **options) - return super().route(rule, **options) api = APIBlueprint('api', __name__) admin_api = APIBlueprint('admin_api', __name__) diff --git a/frontend/static/json/pwa_manifest.json b/frontend/static/json/pwa_manifest.json index a9ab530..32919ca 100644 --- a/frontend/static/json/pwa_manifest.json +++ b/frontend/static/json/pwa_manifest.json @@ -2,6 +2,7 @@ "name": "MIND", "short_name": "MIND", "start_url": "/", + "scope": "/", "display": "standalone", "background_color": "#1b1b1b", "theme_color": "#6b6b6b", diff --git a/frontend/ui.py b/frontend/ui.py index 35c3dd9..3e90d14 100644 --- a/frontend/ui.py +++ b/frontend/ui.py @@ -1,21 +1,35 @@ -#-*- coding: utf-8 -*- +# -*- coding: utf-8 -*- + +from typing import Any from flask import Blueprint, render_template -from backend.server import SERVER +from backend.internals.server import Server ui = Blueprint('ui', __name__) - methods = ['GET'] +SERVER = Server() + + +def render(filename: str, **kwargs: Any) -> str: + return render_template(filename, url_prefix=SERVER.url_prefix, **kwargs) + + +@ui.errorhandler(404) +def ui_not_found(e): + return render('page_not_found.html') + @ui.route('/', methods=methods) def ui_login(): - return render_template('login.html', url_prefix=SERVER.url_prefix) + return render('login.html') + @ui.route('/reminders', methods=methods) def ui_reminders(): - return render_template('reminders.html', url_prefix=SERVER.url_prefix) + return render('reminders.html') + @ui.route('/admin', methods=methods) def ui_admin(): - return render_template('admin.html', url_prefix=SERVER.url_prefix) + return render('admin.html') diff --git a/project_management/docs-requirements.txt b/project_management/docs-requirements.txt deleted file mode 100644 index 37c7ed3..0000000 --- a/project_management/docs-requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -wheel>=0.38.4 -mkdocs-material>=8.5.11 -mkdocs-redirects>=1.2.0 -mkdocs-git-revision-date-localized-plugin>=1.1.0 -Pygments>=2.13.0 -pymdown-extensions>=9.9 diff --git a/project_management/generate_api_docs.py b/project_management/generate_api_docs.py index ca8115c..f3f1b5b 100644 --- a/project_management/generate_api_docs.py +++ b/project_management/generate_api_docs.py @@ -1,39 +1,49 @@ #!/usr/bin/env python3 -#-*- coding: utf-8 -*- +# -*- coding: utf-8 -*- + +# autopep8: off from os.path import dirname -from sys import path - -path.insert(0, dirname(path[0])) - from subprocess import run -from typing import Union +from sys import path +from typing import Type -from backend.helpers import folder_path -from backend.server import SERVER -from frontend.api import (NotificationServiceNotFound, ReminderNotFound, - TemplateNotFound) +path.insert(0, dirname(dirname(__file__))) + +import frontend.api +from backend.base.custom_exceptions import (NotificationServiceNotFound, + ReminderNotFound, TemplateNotFound) +from backend.base.definitions import MindException, StartType +from backend.base.helpers import folder_path +from backend.internals.server import Server from frontend.input_validation import DataSource, api_docs -api_prefix = SERVER.api_prefix -admin_prefix = SERVER.admin_prefix +# autopep8: on + +api_prefix = Server.api_prefix +admin_prefix = Server.admin_prefix api_file = folder_path('docs', 'other_docs', 'api.md') url_var_map = { - 'int:n_id': NotificationServiceNotFound, - 'int:r_id': ReminderNotFound, - 'int:t_id': TemplateNotFound, - 'int:s_id': ReminderNotFound + 'int:n_id': NotificationServiceNotFound, + 'int:r_id': ReminderNotFound, + 'int:t_id': TemplateNotFound, + 'int:s_id': ReminderNotFound } -def make_exception_instance(cls: Exception) -> Exception: - try: - return cls() - except TypeError: - try: - return cls('1') - except TypeError: - return cls('1', '2') + +def make_exception_instance(cls: Type[MindException]) -> MindException: + try: + return cls() + except TypeError: + try: + return cls('1') + except TypeError: + try: + return cls('1', '2') + except AttributeError: + return cls('1', StartType.STARTUP) + result = f"""# API Below is the API documentation. Report an issue on [GitHub](https://github.com/Casvt/MIND/issues). @@ -86,76 +96,96 @@ The following is automatically generated. Please report any issues on [GitHub](h """ for rule, data in api_docs.items(): - result += f"""### `{rule}` + result += f"""### `{rule}` | Requires being logged in | Description | | ------------------------ | ----------- | -| {'Yes' if data.requires_auth else 'No'} | {data.description} | +| {'Yes' if data.requires_auth else 'No'} | {data.description} | """ - url_var = rule.replace('<', '>').split('>') - url_var: Union[str, None] = None if len(url_var) == 1 else url_var[1] + url_var = rule.replace('<', '>').split('>') + url_var = None if len(url_var) == 1 else url_var[1] - if url_var: - result += f""" + if url_var: + result += f""" Replace `<{url_var}>` with the ID of the entry. For example: `{rule.replace(f'<{url_var}>', '2')}`. """ - for m_name, method in ((m, data.methods[m]) for m in data.used_methods): - result += f"\n??? {m_name}\n" + for m_name, method in ((m, data.methods[m]) + for m in data.methods.used_methods()): + if method is None: + continue - if method.description: - result += f"\n {method.description}\n" + result += f"\n??? {m_name}\n" - var_types = { - 'url': [v for v in method.vars if v.source == DataSource.VALUES], - 'body': [v for v in method.vars if v.source == DataSource.DATA], - 'file': [v for v in method.vars if v.source == DataSource.FILES] - } + if method.description: + result += f"\n {method.description}\n" - for var_type, entries in var_types.items(): - if entries: - result += f""" + var_types = { + 'url': [ + v for v in method.vars if v.source == DataSource.VALUES + ], + 'body': [ + v for v in method.vars if v.source == DataSource.DATA + ], + 'file': [ + v for v in method.vars if v.source == DataSource.FILES + ] + } + + for var_type, entries in var_types.items(): + if entries: + result += f""" **Parameters ({var_type})** | Name | Required | Data type | Description | Allowed values | | ---- | -------- | --------- | ----------- | -------------- | """ - for entry in entries: - result += f" {entry('')}\n" - - result += f""" + for entry in entries: + result += f" {super(entry, entry('', entry.name, entry.description)).__repr__()}\n" + + result += f""" **Returns** - + | Code | Error | Description | | ---- | ----- | ----------- | | {201 if m_name == 'POST' else 200} | N/A | Success | """ - url_exception = [url_var_map[url_var]] if url_var in url_var_map else [] - variable_exceptions = [e for v in method.vars for e in v.related_exceptions] - related_exceptions = sorted( - (make_exception_instance(e) for e in set(variable_exceptions + url_exception)), - key=lambda e: (e.api_response['code'], e.api_response['error']) - ) - for related_exception in related_exceptions: - ar = related_exception.api_response - result += f" | {ar['code']} | {ar['error']} | {related_exception.__doc__} |\n" + url_exception = [url_var_map[url_var]] if url_var in url_var_map else [] + variable_exceptions = [ + e + for v in method.vars + for e in v('t', v.name, v.description).related_exceptions + ] + related_exceptions = sorted( + ( + make_exception_instance(e) + for e in set(variable_exceptions + url_exception) + ), + key=lambda e: ( + e.api_response["code"], + e.api_response["error"] + ) + ) + for related_exception in related_exceptions: + ar = related_exception.api_response + result += f" | {ar['code']} | {ar['error']} | {related_exception.__doc__} |\n" - result += '\n' + result += '\n' with open(api_file, 'r') as f: - current_content = f.read() + current_content = f.read() if current_content == result: - print('Nothing changed') + print('Nothing changed') else: - with open(api_file, 'w+') as f: - f.write(result) + with open(api_file, 'w+') as f: + f.write(result) - # run(["git", "config", "--global", "user.email", '"casvantijn@gmail.com"']) - # run(["git", "config", "--global", "user.name", '"CasVT"']) - # run(["git", "checkout", "Development"]) - # run(["git", "add", api_file]) - # run(["git", "commit", "-m", "Updated API docs"]) - # run(["git", "push"]) + run(["git", "config", "--global", "user.email", '"casvantijn@gmail.com"']) + run(["git", "config", "--global", "user.name", '"CasVT"']) + run(["git", "checkout", "Development"]) + run(["git", "add", api_file]) + run(["git", "commit", "-m", "Updated API docs"]) + run(["git", "push"]) diff --git a/project_management/requirements-docs.txt b/project_management/requirements-docs.txt new file mode 100644 index 0000000..6e40420 --- /dev/null +++ b/project_management/requirements-docs.txt @@ -0,0 +1,6 @@ +wheel >= 0.38.4 +mkdocs-material >= 8.5.11 +mkdocs-redirects >= 1.2.0 +mkdocs-git-revision-date-localized-plugin >= 1.1.0 +Pygments >= 2.13.0 +pymdown-extensions >= 9.9 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2485f8b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,32 @@ +[project] +name = "MIND" +version = "1.4.1" +description = "MIND is a simple self hosted reminder application that can send push notifications to your device. Set the reminder and forget about it!" +authors = [ + {name = "Cas van Tijn"} +] +readme = "README.md" +license = {file = "LICENSE"} +requires-python = ">= 3.8" + +[tool.mypy] +warn_unused_configs = true +sqlite_cache = true +cache_fine_grained = true + +ignore_missing_imports = true +disable_error_code = ["abstract", "annotation-unchecked", "arg-type", "assert-type", "assignment", "attr-defined", "await-not-async", "call-arg", "call-overload", "dict-item", "empty-body", "exit-return", "func-returns-value", "has-type", "import", "import-not-found", "import-untyped", "index", "list-item", "literal-required", "method-assign", "misc", "name-defined", "name-match", "no-overload-impl", "no-redef", "operator", "override", "return", "return-value", "safe-super", "str-bytes-safe", "str-format", "syntax", "top-level-await", "truthy-function", "type-abstract", "type-var", "typeddict-item", "typeddict-unknown-key", "union-attr", "unused-coroutine", "used-before-def", "valid-newtype", "valid-type", "var-annotated"] +enable_error_code = ["method-assign", "func-returns-value", "name-match", "no-overload-impl", "unused-coroutine", "top-level-await", "await-not-async", "str-format", "redundant-expr", "unused-awaitable"] + +[tool.isort] +balanced_wrapping = true +combine_as_imports = true +combine_star = true +honor_noqa = true +remove_redundant_aliases = true + +[tool.autopep8] +aggressive = 3 +experimental = true +max_line_length = 80 +ignore = ["E124", "E125", "E126", "E128", "E261"] diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..aaae271 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,4 @@ +autopep8 ~= 2.2 +isort ~= 5.13 +mypy ~= 1.10 +pre-commit ~= 3.5 diff --git a/requirements.txt b/requirements.txt index 5238c51..aaf6164 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -apprise~=1.4 -python-dateutil~=2.8 -Flask~=3.0 -waitress~=2.1 +apprise ~= 1.4 +python-dateutil ~= 2.8 +Flask ~= 3.0 +waitress ~= 2.1 diff --git a/tests/MIND_test.py b/tests/MIND_test.py deleted file mode 100644 index 5a497a7..0000000 --- a/tests/MIND_test.py +++ /dev/null @@ -1,22 +0,0 @@ -import unittest - -from flask import Flask - -from frontend.api import api -from frontend.ui import ui -from backend.server import SERVER - -class Test_MIND(unittest.TestCase): - def test_create_app(self): - SERVER.create_app() - self.assertTrue(hasattr(SERVER, 'app')) - app = SERVER.app - self.assertIsInstance(app, Flask) - - self.assertEqual(app.blueprints.get('ui'), ui) - self.assertEqual(app.blueprints.get('api'), api) - - handlers = app.error_handler_spec[None].keys() - required_handlers = 400, 405, 500 - for handler in required_handlers: - self.assertIn(handler, handlers) diff --git a/tests/Tbackend/MIND_test.py b/tests/Tbackend/MIND_test.py new file mode 100644 index 0000000..5706325 --- /dev/null +++ b/tests/Tbackend/MIND_test.py @@ -0,0 +1,24 @@ +import unittest + +from flask import Flask + +from backend.internals.server import Server +from frontend.api import api +from frontend.ui import ui + + +class Test_MIND(unittest.TestCase): + def test_create_app(self): + SERVER = Server() + SERVER.create_app() + self.assertTrue(hasattr(SERVER, 'app')) + app = SERVER.app + self.assertIsInstance(app, Flask) + + self.assertEqual(app.blueprints.get('ui'), ui) + self.assertEqual(app.blueprints.get('api'), api) + + handlers = app.error_handler_spec[None].keys() + required_handlers = 400, 405, 500 + for handler in required_handlers: + self.assertIn(handler, handlers) diff --git a/tests/Tbackend/__init__.py b/tests/Tbackend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/Tbackend/custom_exceptions_test.py b/tests/Tbackend/custom_exceptions_test.py new file mode 100644 index 0000000..47dc983 --- /dev/null +++ b/tests/Tbackend/custom_exceptions_test.py @@ -0,0 +1,5 @@ +import unittest + + +class Test_Custom_Exceptions(unittest.TestCase): + pass diff --git a/tests/Tbackend/db_test.py b/tests/Tbackend/db_test.py new file mode 100644 index 0000000..9d41052 --- /dev/null +++ b/tests/Tbackend/db_test.py @@ -0,0 +1,26 @@ +import unittest +from os.path import join + +from flask import Flask + +from backend.base.helpers import folder_path +from backend.internals.db import Constants, DBConnection, close_db + + +class Test_DB(unittest.TestCase): + def test_foreign_key_and_wal(self): + app = Flask(__name__) + app.teardown_appcontext(close_db) + + DBConnection.file = join( + folder_path(*Constants.DB_FOLDER), + Constants.DB_NAME + ) + with app.app_context(): + instance = DBConnection(timeout=Constants.DB_TIMEOUT) + self.assertEqual( + instance.cursor().execute( + "PRAGMA foreign_keys;" + ).fetchone()[0], + 1 + ) diff --git a/tests/Tbackend/reminders_test.py b/tests/Tbackend/reminders_test.py new file mode 100644 index 0000000..19226c2 --- /dev/null +++ b/tests/Tbackend/reminders_test.py @@ -0,0 +1,19 @@ +import unittest + +from backend.base.definitions import GeneralReminderData +from backend.base.helpers import search_filter + + +class Test_Reminder_Handler(unittest.TestCase): + def test_filter_function(self): + p = GeneralReminderData( + id=1, + title='TITLE', + text='TEXT', + color=None, + notification_services=[] + ) + for test_case in ('', 'title', 'ex'): + self.assertTrue(search_filter(test_case, p)) + for test_case in (' ', 'Hello'): + self.assertFalse(search_filter(test_case, p)) diff --git a/tests/Tbackend/security_test.py b/tests/Tbackend/security_test.py new file mode 100644 index 0000000..d9b82e6 --- /dev/null +++ b/tests/Tbackend/security_test.py @@ -0,0 +1,10 @@ +import unittest + +from backend.base.helpers import generate_salt_hash, get_hash + + +class Test_Security(unittest.TestCase): + def test_hash(self): + for test_case in ('test', ''): + result = generate_salt_hash(test_case) + self.assertEqual(result[1], get_hash(result[0], test_case)) diff --git a/tests/Tbackend/ui_test.py b/tests/Tbackend/ui_test.py new file mode 100644 index 0000000..332b79b --- /dev/null +++ b/tests/Tbackend/ui_test.py @@ -0,0 +1,21 @@ +import unittest + +from flask import Blueprint, Flask + +from frontend.ui import methods, ui + + +class Test_UI(unittest.TestCase): + def test_methods(self): + self.assertEqual(len(methods), 1) + self.assertEqual(methods[0], 'GET') + + def test_blueprint(self): + self.assertIsInstance(ui, Blueprint) + + def test_route_methods(self): + temp_app = Flask(__name__) + temp_app.register_blueprint(ui) + for rule in temp_app.url_map.iter_rules(): + self.assertEqual(len(rule.methods or []), 3) + self.assertIn(methods[0], rule.methods or []) diff --git a/tests/Tbackend/users_test.py b/tests/Tbackend/users_test.py new file mode 100644 index 0000000..c8551cd --- /dev/null +++ b/tests/Tbackend/users_test.py @@ -0,0 +1,15 @@ +import unittest + +from backend.base.custom_exceptions import UsernameInvalid +from backend.base.definitions import Constants +from backend.implementations.users import is_valid_username + + +class Test_Users(unittest.TestCase): + def test_username_check(self): + for test_case in ('', 'test'): + is_valid_username(test_case) + + for test_case in (' ', ' ', '0', 'api', *Constants.INVALID_USERNAMES): + with self.assertRaises(UsernameInvalid): + is_valid_username(test_case) diff --git a/tests/api_test.py b/tests/api_test.py deleted file mode 100644 index 1e7671b..0000000 --- a/tests/api_test.py +++ /dev/null @@ -1,25 +0,0 @@ -import unittest - -from flask import Blueprint - -from backend.custom_exceptions import * -from frontend.api import api, return_api - -class Test_API(unittest.TestCase): - def test_blueprint(self): - self.assertIsInstance(api, Blueprint) - - def test_return_api(self): - for case in ({'result': {}, 'error': 'Error', 'code': 201}, - {'result': ''}): - result = return_api(**case) - self.assertEqual(result[0]['result'], case['result']) - if case.get('error'): - self.assertEqual(result[0]['error'], case['error']) - else: - self.assertIsNone(result[0]['error']) - if case.get('code'): - self.assertEqual(result[1], case['code']) - else: - self.assertEqual(result[1], 200) - \ No newline at end of file diff --git a/tests/custom_exceptions_test.py b/tests/custom_exceptions_test.py deleted file mode 100644 index f086c25..0000000 --- a/tests/custom_exceptions_test.py +++ /dev/null @@ -1,39 +0,0 @@ -import unittest -from inspect import getmembers, getmro, isclass -from sys import modules -from typing import List - -import backend.custom_exceptions - -class Test_Custom_Exceptions(unittest.TestCase): - def test_type(self): - defined_exceptions: List[Exception] = filter( - lambda c: c.__module__ == 'backend.custom_exceptions' - and c is not backend.custom_exceptions.CustomException, - map( - lambda c: c[1], - getmembers(modules['backend.custom_exceptions'], isclass) - ) - ) - - for defined_exception in defined_exceptions: - self.assertIn( - getmro(defined_exception)[1], - ( - backend.custom_exceptions.CustomException, - Exception - ) - ) - try: - result = defined_exception().api_response - except TypeError: - try: - result = defined_exception('1').api_response - except TypeError: - result = defined_exception('1', '2').api_response - - self.assertIsInstance(result, dict) - result['error'] - result['result'] - result['code'] - self.assertIsInstance(result['code'], int) diff --git a/tests/db_test.py b/tests/db_test.py deleted file mode 100644 index 8b2e195..0000000 --- a/tests/db_test.py +++ /dev/null @@ -1,11 +0,0 @@ -import unittest - -from backend.db import DB_FILENAME, DBConnection -from backend.helpers import folder_path - - -class Test_DB(unittest.TestCase): - def test_foreign_key_and_wal(self): - DBConnection.file = folder_path(*DB_FILENAME) - instance = DBConnection(timeout=20.0) - self.assertEqual(instance.cursor().execute("PRAGMA foreign_keys;").fetchone()[0], 1) diff --git a/tests/reminders_test.py b/tests/reminders_test.py deleted file mode 100644 index 1e04b94..0000000 --- a/tests/reminders_test.py +++ /dev/null @@ -1,14 +0,0 @@ -import unittest - -from backend.helpers import search_filter - -class Test_Reminder_Handler(unittest.TestCase): - def test_filter_function(self): - p = { - 'title': 'TITLE', - 'text': 'TEXT' - } - for test_case in ('', 'title', 'ex'): - self.assertTrue(search_filter(test_case, p)) - for test_case in (' ', 'Hello'): - self.assertFalse(search_filter(test_case, p)) diff --git a/tests/security_test.py b/tests/security_test.py deleted file mode 100644 index 3fcddc5..0000000 --- a/tests/security_test.py +++ /dev/null @@ -1,10 +0,0 @@ -import unittest - -from backend.security import generate_salt_hash, get_hash - -class Test_Security(unittest.TestCase): - def test_hash(self): - for test_case in ('test', ''): - result = generate_salt_hash(test_case) - self.assertEqual(result[1], get_hash(result[0], test_case)) - \ No newline at end of file diff --git a/tests/ui_test.py b/tests/ui_test.py deleted file mode 100644 index 2d611c7..0000000 --- a/tests/ui_test.py +++ /dev/null @@ -1,20 +0,0 @@ -import unittest - -from flask import Blueprint, Flask - -from frontend.ui import methods, ui - -class Test_UI(unittest.TestCase): - def test_methods(self): - self.assertEqual(len(methods), 1) - self.assertEqual(methods[0], 'GET') - - def test_blueprint(self): - self.assertIsInstance(ui, Blueprint) - - def test_route_methods(self): - temp_app = Flask(__name__) - temp_app.register_blueprint(ui) - for rule in temp_app.url_map.iter_rules(): - self.assertEqual(len(rule.methods), 3) - self.assertIn(methods[0], rule.methods) diff --git a/tests/users_test.py b/tests/users_test.py deleted file mode 100644 index d1e7b23..0000000 --- a/tests/users_test.py +++ /dev/null @@ -1,14 +0,0 @@ -import unittest - -from backend.custom_exceptions import UsernameInvalid -from backend.users import ONEPASS_INVALID_USERNAMES, Users - -class Test_Users(unittest.TestCase): - def test_username_check(self): - users = Users() - for test_case in ('', 'test'): - users._check_username(test_case) - - for test_case in (' ', ' ', '0', 'api', *ONEPASS_INVALID_USERNAMES): - with self.assertRaises(UsernameInvalid): - users._check_username(test_case)