mirror of
https://github.com/Casvt/MIND.git
synced 2026-04-03 03:00:22 -04:00
Refactored backend (Fixes #87)
This commit is contained in:
428
backend/base/custom_exceptions.py
Normal file
428
backend/base/custom_exceptions.py
Normal file
@@ -0,0 +1,428 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
from backend.base.definitions import (ApiResponse, InvalidUsernameReason,
|
||||
MindException)
|
||||
from backend.base.logging import LOGGER
|
||||
|
||||
|
||||
# region Input/Output
|
||||
class KeyNotFound(MindException):
|
||||
"A key was not found in the input that is required to be given."
|
||||
|
||||
def __init__(self, key: str) -> None:
|
||||
self.key = key
|
||||
LOGGER.warning(
|
||||
"This key was not found in the API request,"
|
||||
" eventhough it's required: %s",
|
||||
key
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 400,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'key': self.key
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class InvalidKeyValue(MindException):
|
||||
"The value of a key is invalid."
|
||||
|
||||
def __init__(self, key: str, value: Any) -> None:
|
||||
self.key = key
|
||||
self.value = value
|
||||
LOGGER.warning(
|
||||
"This key in the API request has an invalid value: "
|
||||
"%s = %",
|
||||
key, value
|
||||
)
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 400,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'key': self.key,
|
||||
'value': self.value
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# region Auth
|
||||
class AccessUnauthorized(MindException):
|
||||
"The password given is not correct"
|
||||
|
||||
def __init__(self) -> None:
|
||||
LOGGER.warning(
|
||||
"The password given is not correct"
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 401,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {}
|
||||
}
|
||||
|
||||
|
||||
class APIKeyInvalid(MindException):
|
||||
"The API key is not correct"
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 401,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'api_key': self.api_key
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class APIKeyExpired(MindException):
|
||||
"The API key has expired"
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 401,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'api_key': self.api_key
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# region Admin Operations
|
||||
class OperationNotAllowed(MindException):
|
||||
"What was requested to be done is not allowed"
|
||||
|
||||
def __init__(self, operation: str) -> None:
|
||||
LOGGER.warning(
|
||||
"Operation not allowed: %s",
|
||||
operation
|
||||
)
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 403,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {}
|
||||
}
|
||||
|
||||
|
||||
class NewAccountsNotAllowed(MindException):
|
||||
"It's not allowed to create a new account except for the admin"
|
||||
|
||||
def __init__(self) -> None:
|
||||
LOGGER.warning(
|
||||
"The creation of a new account was attempted but it's disabled by the admin"
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 403,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {}
|
||||
}
|
||||
|
||||
|
||||
class InvalidDatabaseFile(MindException):
|
||||
"The uploaded database file is invalid or not supported"
|
||||
|
||||
def __init__(self, filepath_db: str) -> None:
|
||||
self.filepath_db = filepath_db
|
||||
LOGGER.warning(
|
||||
"The given database file is invalid: %s",
|
||||
filepath_db
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 400,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'filepath_db': self.filepath_db
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class LogFileNotFound(MindException):
|
||||
"The log file was not found"
|
||||
|
||||
def __init__(self, log_file: str) -> None:
|
||||
self.log_file = log_file
|
||||
LOGGER.warning(
|
||||
"The log file was not found: %s",
|
||||
log_file
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 404,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'log_file': self.log_file
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# region Users
|
||||
class UsernameTaken(MindException):
|
||||
"The username is already taken"
|
||||
|
||||
def __init__(self, username: str) -> None:
|
||||
self.username = username
|
||||
LOGGER.warning(
|
||||
"The username is already taken: %s",
|
||||
username
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 400,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'username': self.username
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class UsernameInvalid(MindException):
|
||||
"The username contains invalid characters or is not allowed"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
username: str,
|
||||
reason: InvalidUsernameReason
|
||||
) -> None:
|
||||
self.username = username
|
||||
self.reason = reason
|
||||
LOGGER.warning(
|
||||
"The username '%s' is invalid for the following reason: %s",
|
||||
username, reason.value
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 400,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'username': self.username,
|
||||
'reason': self.reason.value
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class UserNotFound(MindException):
|
||||
"The user requested can not be found"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
username: Union[str, None],
|
||||
user_id: Union[int, None]
|
||||
) -> None:
|
||||
self.username = username
|
||||
self.user_id = user_id
|
||||
if username:
|
||||
LOGGER.warning(
|
||||
"The user can not be found: %s",
|
||||
username
|
||||
)
|
||||
|
||||
elif user_id:
|
||||
LOGGER.warning(
|
||||
"The user can not be found: ID %d",
|
||||
user_id
|
||||
)
|
||||
|
||||
else:
|
||||
LOGGER.warning(
|
||||
"The user can not be found"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 404,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'username': self.username,
|
||||
'user_id': self.user_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# region Notification Services
|
||||
class NotificationServiceNotFound(MindException):
|
||||
"The notification service was not found"
|
||||
|
||||
def __init__(self, notification_service_id: int) -> None:
|
||||
self.notification_service_id = notification_service_id
|
||||
LOGGER.warning(
|
||||
"The notification service with the given ID cannot be found: %d",
|
||||
notification_service_id
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 404,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'notification_service_id': self.notification_service_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class NotificationServiceInUse(MindException):
|
||||
"""
|
||||
The notification service is wished to be deleted
|
||||
but a reminder is still using it
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
notification_service_id: int,
|
||||
reminder_type: str
|
||||
) -> None:
|
||||
self.notification_service_id = notification_service_id
|
||||
self.reminder_type = reminder_type
|
||||
LOGGER.warning(
|
||||
"The notification service with ID %d is wished to be deleted "
|
||||
"but a reminder of type %s is still using it",
|
||||
notification_service_id,
|
||||
reminder_type
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 404,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'notification_service_id': self.notification_service_id,
|
||||
'reminder_type': self.reminder_type
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class URLInvalid(MindException):
|
||||
"The Apprise URL is invalid"
|
||||
|
||||
def __init__(self, url: str) -> None:
|
||||
self.url = url
|
||||
LOGGER.warning(
|
||||
"The Apprise URL given is invalid: %s",
|
||||
url
|
||||
)
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 400,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'url': self.url
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# region Templates
|
||||
class TemplateNotFound(MindException):
|
||||
"The template was not found"
|
||||
|
||||
def __init__(self, template_id: int) -> None:
|
||||
self.template_id = template_id
|
||||
LOGGER.warning(
|
||||
"The template with the given ID cannot be found: %d",
|
||||
template_id
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 404,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'template_id': self.template_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# region Reminders
|
||||
class ReminderNotFound(MindException):
|
||||
"The reminder was not found"
|
||||
|
||||
def __init__(self, reminder_id: int) -> None:
|
||||
self.reminder_id = reminder_id
|
||||
LOGGER.warning(
|
||||
"The reminder with the given ID cannot be found: %d",
|
||||
reminder_id
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 404,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'reminder_id': self.reminder_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class InvalidTime(MindException):
|
||||
"The time given is in the past"
|
||||
|
||||
def __init__(self, time: int) -> None:
|
||||
self.time = time
|
||||
LOGGER.warning(
|
||||
"The given time is invalid: %d",
|
||||
time
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 400,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {
|
||||
'time': self.time
|
||||
}
|
||||
}
|
||||
315
backend/base/definitions.py
Normal file
315
backend/base/definitions.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Definitions of basic types, abstract classes, enums, etc.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Literal,
|
||||
Tuple, Type, TypedDict, TypeVar, Union, cast)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.implementations.users import User
|
||||
|
||||
|
||||
# region Types
|
||||
T = TypeVar('T')
|
||||
U = TypeVar('U')
|
||||
WEEKDAY_NUMBER = Literal[0, 1, 2, 3, 4, 5, 6]
|
||||
BaseSerialisable = Union[
|
||||
int, float, bool, str, None
|
||||
]
|
||||
Serialisable = Union[
|
||||
List[Union[
|
||||
BaseSerialisable,
|
||||
List[BaseSerialisable],
|
||||
Dict[str, BaseSerialisable]
|
||||
]],
|
||||
Dict[str, Union[
|
||||
BaseSerialisable,
|
||||
List[BaseSerialisable],
|
||||
Dict[str, BaseSerialisable]
|
||||
]],
|
||||
]
|
||||
|
||||
|
||||
# region Constants
|
||||
class Constants:
|
||||
SUB_PROCESS_TIMEOUT = 20.0 # seconds
|
||||
|
||||
HOSTING_THREADS = 10
|
||||
HOSTING_REVERT_TIME = 60.0 # seconds
|
||||
|
||||
DB_FOLDER = ("db",)
|
||||
DB_NAME = "MIND.db"
|
||||
DB_ORIGINAL_NAME = 'MIND_original.db'
|
||||
DB_TIMEOUT = 10.0 # seconds
|
||||
DB_REVERT_TIME = 60.0 # seconds
|
||||
|
||||
LOGGER_NAME = "MIND"
|
||||
LOGGER_FILENAME = "MIND.log"
|
||||
|
||||
ADMIN_USERNAME = "admin"
|
||||
ADMIN_PASSWORD = "admin"
|
||||
INVALID_USERNAMES = ("reminders", "api")
|
||||
USERNAME_CHARACTERS = 'abcedfghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!@$'
|
||||
|
||||
CONNECTION_ERROR_TIMEOUT = 120 # seconds
|
||||
|
||||
APPRISE_TEST_TITLE = "MIND: Test title"
|
||||
APPRISE_TEST_BODY = "MIND: Test body"
|
||||
|
||||
|
||||
# region Enums
|
||||
class BaseEnum(Enum):
|
||||
def __eq__(self, other) -> bool:
|
||||
return self.value == other
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self.value)
|
||||
|
||||
|
||||
class StartType(BaseEnum):
|
||||
STARTUP = 130
|
||||
RESTART = 131
|
||||
RESTART_HOSTING_CHANGES = 132
|
||||
RESTART_DB_CHANGES = 133
|
||||
|
||||
|
||||
class InvalidUsernameReason(BaseEnum):
|
||||
ONLY_NUMBERS = "Username can not only be numbers"
|
||||
NOT_ALLOWED = "Username is not allowed"
|
||||
INVALID_CHARACTER = "Username contains an invalid character"
|
||||
|
||||
|
||||
class SendResult(BaseEnum):
|
||||
SUCCESS = "Success"
|
||||
CONNECTION_ERROR = "Connection error"
|
||||
SYNTAX_INVALID_URL = "Syntax of URL invalid"
|
||||
REJECTED_URL = "Values in URL rejected by service (e.g. invalid API token)"
|
||||
|
||||
|
||||
class ReminderType(BaseEnum):
|
||||
REMINDER = "Reminder"
|
||||
STATIC_REMINDER = "Static Reminder"
|
||||
TEMPLATE = "Template"
|
||||
|
||||
|
||||
class RepeatQuantity(BaseEnum):
|
||||
YEARS = "years"
|
||||
MONTHS = "months"
|
||||
WEEKS = "weeks"
|
||||
DAYS = "days"
|
||||
HOURS = "hours"
|
||||
MINUTES = "minutes"
|
||||
|
||||
|
||||
def sort_by_timeless_title(r: GeneralReminderData) -> Tuple[str, str, str]:
|
||||
return (r.title, r.text or '', r.color or '')
|
||||
|
||||
|
||||
def sort_by_time(r: ReminderData) -> Tuple[int, str, str, str]:
|
||||
return (r.time, r.title, r.text or '', r.color or '')
|
||||
|
||||
|
||||
def sort_by_timed_title(r: ReminderData) -> Tuple[str, int, str, str]:
|
||||
return (r.title, r.time, r.text or '', r.color or '')
|
||||
|
||||
|
||||
def sort_by_id(r: GeneralReminderData) -> int:
|
||||
return r.id
|
||||
|
||||
|
||||
class TimelessSortingMethod(BaseEnum):
|
||||
TITLE = sort_by_timeless_title, False
|
||||
TITLE_REVERSED = sort_by_timeless_title, True
|
||||
DATE_ADDED = sort_by_id, False
|
||||
DATE_ADDED_REVERSED = sort_by_id, True
|
||||
|
||||
|
||||
class SortingMethod(BaseEnum):
|
||||
TIME = sort_by_time, False
|
||||
TIME_REVERSED = sort_by_time, True
|
||||
TITLE = sort_by_timed_title, False
|
||||
TITLE_REVERSED = sort_by_timed_title, True
|
||||
DATE_ADDED = sort_by_id, False
|
||||
DATE_ADDED_REVERSED = sort_by_id, True
|
||||
|
||||
|
||||
class DataType(BaseEnum):
|
||||
STR = 'string'
|
||||
INT = 'number'
|
||||
FLOAT = 'decimal number'
|
||||
BOOL = 'bool'
|
||||
INT_ARRAY = 'list of numbers'
|
||||
NA = 'N/A'
|
||||
|
||||
|
||||
class DataSource(BaseEnum):
|
||||
DATA = 1
|
||||
VALUES = 2
|
||||
FILES = 3
|
||||
|
||||
|
||||
# region TypedDicts
|
||||
class ApiResponse(TypedDict):
|
||||
result: Any
|
||||
error: Union[str, None]
|
||||
code: int
|
||||
|
||||
|
||||
# region Abstract Classes
|
||||
class DBMigrator(ABC):
|
||||
start_version: int
|
||||
|
||||
@abstractmethod
|
||||
def run(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MindException(Exception, ABC):
|
||||
"""An exception specific to MIND"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def api_response(self) -> ApiResponse:
|
||||
...
|
||||
|
||||
|
||||
# region Dataclasses
|
||||
@dataclass
|
||||
class ApiKeyEntry:
|
||||
exp: int
|
||||
user_data: User
|
||||
|
||||
|
||||
def _return_exceptions() -> List[Type[MindException]]:
|
||||
from backend.base.custom_exceptions import InvalidKeyValue, KeyNotFound
|
||||
return [KeyNotFound, InvalidKeyValue]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputVariable(ABC):
|
||||
value: Any
|
||||
name: str
|
||||
description: str
|
||||
required: bool = True
|
||||
default: Any = None
|
||||
data_type: List[DataType] = field(default_factory=lambda: [DataType.STR])
|
||||
source: DataSource = DataSource.DATA
|
||||
related_exceptions: List[Type[MindException]] = field(
|
||||
default_factory=_return_exceptions
|
||||
)
|
||||
|
||||
def validate(self) -> bool:
|
||||
return isinstance(self.value, str) and bool(self.value)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Method:
|
||||
description: str = ''
|
||||
vars: List[Type[InputVariable]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Methods:
|
||||
get: Union[Method, None] = None
|
||||
post: Union[Method, None] = None
|
||||
put: Union[Method, None] = None
|
||||
delete: Union[Method, None] = None
|
||||
|
||||
def __getitem__(self, key: str) -> Union[Method, None]:
|
||||
return getattr(self, key.lower())
|
||||
|
||||
def used_methods(self) -> List[str]:
|
||||
result = []
|
||||
for method in ('get', 'post', 'put', 'delete'):
|
||||
if getattr(self, method) is not None:
|
||||
result.append(method)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ApiDocEntry:
|
||||
endpoint: str
|
||||
description: str
|
||||
methods: Methods
|
||||
requires_auth: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class NotificationServiceData:
|
||||
id: int
|
||||
title: str
|
||||
url: str
|
||||
|
||||
def todict(self) -> Dict[str, Any]:
|
||||
return self.__dict__
|
||||
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class UserData:
|
||||
id: int
|
||||
username: str
|
||||
admin: bool
|
||||
salt: bytes
|
||||
hash: bytes
|
||||
|
||||
def todict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if k in ('id', 'username', 'admin')
|
||||
}
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class GeneralReminderData:
|
||||
id: int
|
||||
title: str
|
||||
text: Union[str, None]
|
||||
color: Union[str, None]
|
||||
notification_services: List[int]
|
||||
|
||||
def todict(self) -> Dict[str, Any]:
|
||||
return self.__dict__
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class TemplateData(GeneralReminderData):
|
||||
...
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class StaticReminderData(GeneralReminderData):
|
||||
...
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class ReminderData(GeneralReminderData):
|
||||
time: int
|
||||
original_time: Union[int, None]
|
||||
repeat_quantity: Union[str, None]
|
||||
repeat_interval: Union[int, None]
|
||||
_weekdays: Union[str, None]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self._weekdays is not None:
|
||||
self.weekdays: Union[List[WEEKDAY_NUMBER], None] = [
|
||||
cast(WEEKDAY_NUMBER, int(n))
|
||||
for n in self._weekdays.split(',')
|
||||
if n
|
||||
]
|
||||
else:
|
||||
self.weekdays = None
|
||||
|
||||
def todict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if k != '_weekdays'
|
||||
}
|
||||
394
backend/base/helpers.py
Normal file
394
backend/base/helpers.py
Normal file
@@ -0,0 +1,394 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
General "helper" function and classes
|
||||
"""
|
||||
|
||||
from base64 import urlsafe_b64encode
|
||||
from datetime import datetime
|
||||
from hashlib import pbkdf2_hmac
|
||||
from logging import WARNING
|
||||
from os import makedirs, symlink
|
||||
from os.path import abspath, dirname, exists, join
|
||||
from secrets import token_bytes
|
||||
from shutil import copy2, move
|
||||
from sys import base_exec_prefix, executable, platform, version_info
|
||||
from typing import (Any, Callable, Generator, Iterable,
|
||||
List, Sequence, Tuple, Union, cast)
|
||||
|
||||
from apprise import Apprise, LogCapture
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from backend.base.definitions import (WEEKDAY_NUMBER, GeneralReminderData,
|
||||
RepeatQuantity, SendResult, T, U)
|
||||
|
||||
|
||||
def get_python_version() -> str:
|
||||
"""Get python version as string
|
||||
|
||||
Returns:
|
||||
str: The python version
|
||||
"""
|
||||
return ".".join(
|
||||
str(i) for i in list(version_info)
|
||||
)
|
||||
|
||||
|
||||
def check_python_version() -> bool:
|
||||
"""Check if the python version that is used is a minimum version.
|
||||
|
||||
Returns:
|
||||
bool: Whether or not the python version is version 3.8 or above or not.
|
||||
"""
|
||||
if not (version_info.major == 3 and version_info.minor >= 8):
|
||||
from backend.base.logging import LOGGER
|
||||
LOGGER.critical(
|
||||
'The minimum python version required is python3.8 '
|
||||
'(currently ' + str(version_info.major) + '.' +
|
||||
str(version_info.minor) + '.' + str(version_info.micro) + ').'
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_python_exe() -> str:
|
||||
"""Get the path to the python executable.
|
||||
|
||||
Returns:
|
||||
str: The python executable path.
|
||||
"""
|
||||
if platform.startswith('darwin'):
|
||||
bundle_path = join(
|
||||
base_exec_prefix,
|
||||
"Resources",
|
||||
"Python.app",
|
||||
"Contents",
|
||||
"MacOS",
|
||||
"Python"
|
||||
)
|
||||
if exists(bundle_path):
|
||||
from tempfile import mkdtemp
|
||||
python_path = join(mkdtemp(), "python")
|
||||
symlink(bundle_path, python_path)
|
||||
|
||||
return python_path
|
||||
|
||||
return executable
|
||||
|
||||
|
||||
def reversed_tuples(
|
||||
i: Iterable[Tuple[T, U]]
|
||||
) -> Generator[Tuple[U, T], Any, Any]:
|
||||
"""Yield sub-tuples in reversed order.
|
||||
|
||||
Args:
|
||||
i (Iterable[Tuple[T, U]]): Iterator.
|
||||
|
||||
Yields:
|
||||
Generator[Tuple[U, T], Any, Any]: Sub-tuple with reversed order.
|
||||
"""
|
||||
for entry_1, entry_2 in i:
|
||||
yield entry_2, entry_1
|
||||
|
||||
|
||||
def first_of_column(
|
||||
columns: Iterable[Sequence[T]]
|
||||
) -> List[T]:
|
||||
"""Get the first element of each sub-array.
|
||||
|
||||
Args:
|
||||
columns (Iterable[Sequence[T]]): List of
|
||||
sub-arrays.
|
||||
|
||||
Returns:
|
||||
List[T]: List with first value of each sub-array.
|
||||
"""
|
||||
return [e[0] for e in columns]
|
||||
|
||||
|
||||
def when_not_none(
|
||||
value: Union[T, None],
|
||||
to_run: Callable[[T], U]
|
||||
) -> Union[U, None]:
|
||||
"""Run `to_run` with argument `value` iff `value is not None`. Else return
|
||||
`None`.
|
||||
|
||||
Args:
|
||||
value (Union[T, None]): The value to check.
|
||||
to_run (Callable[[T], U]): The function to run.
|
||||
|
||||
Returns:
|
||||
Union[U, None]: Either the return value of `to_run`, or `None`.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
return to_run(value)
|
||||
|
||||
|
||||
def search_filter(query: str, result: GeneralReminderData) -> bool:
|
||||
"""Filter library results based on a query.
|
||||
|
||||
Args:
|
||||
query (str): The query to filter with.
|
||||
result (GeneralReminderData): The library result to check.
|
||||
|
||||
Returns:
|
||||
bool: Whether or not the result passes the filter.
|
||||
"""
|
||||
query = query.lower()
|
||||
return (
|
||||
query in result.title.lower()
|
||||
or query in (result.text or '').lower()
|
||||
)
|
||||
|
||||
|
||||
def get_hash(salt: bytes, data: str) -> bytes:
|
||||
"""Hash a string using the supplied salt
|
||||
|
||||
Args:
|
||||
salt (bytes): The salt to use when hashing
|
||||
data (str): The data to hash
|
||||
|
||||
Returns:
|
||||
bytes: The b64 encoded hash of the supplied string
|
||||
"""
|
||||
return urlsafe_b64encode(
|
||||
pbkdf2_hmac('sha256', data.encode(), salt, 100_000)
|
||||
)
|
||||
|
||||
|
||||
def generate_salt_hash(password: str) -> Tuple[bytes, bytes]:
|
||||
"""Generate a salt and get the hash of the password
|
||||
|
||||
Args:
|
||||
password (str): The password to generate for
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, bytes]: The salt (1) and hashed_password (2)
|
||||
"""
|
||||
salt = token_bytes()
|
||||
hashed_password = get_hash(salt, password)
|
||||
return salt, hashed_password
|
||||
|
||||
|
||||
def send_apprise_notification(
|
||||
urls: List[str],
|
||||
title: str,
|
||||
text: Union[str, None] = None
|
||||
) -> SendResult:
|
||||
"""Send a notification to all Apprise URL's given.
|
||||
|
||||
Args:
|
||||
urls (List[str]): The Apprise URL's to send the notification to.
|
||||
|
||||
title (str): The title of the notification.
|
||||
|
||||
text (Union[str, None], optional): The optional body of the
|
||||
notification.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
SendResult: Whether or not it was successful.
|
||||
"""
|
||||
a = Apprise()
|
||||
|
||||
for url in urls:
|
||||
if not a.add(url):
|
||||
return SendResult.SYNTAX_INVALID_URL
|
||||
|
||||
with LogCapture(level=WARNING) as log:
|
||||
result = a.notify(
|
||||
title=title,
|
||||
body=text or '\u200B'
|
||||
)
|
||||
if not result:
|
||||
if "socket exception" in log.getvalue(): # type: ignore
|
||||
return SendResult.CONNECTION_ERROR
|
||||
else:
|
||||
return SendResult.REJECTED_URL
|
||||
|
||||
return SendResult.SUCCESS
|
||||
|
||||
|
||||
def next_selected_day(
|
||||
weekdays: List[WEEKDAY_NUMBER],
|
||||
weekday: WEEKDAY_NUMBER
|
||||
) -> WEEKDAY_NUMBER:
|
||||
"""Find the next allowed day in the week.
|
||||
|
||||
Args:
|
||||
weekdays (List[WEEKDAY_NUMBER]): The days of the week that are allowed.
|
||||
Monday is 0, Sunday is 6.
|
||||
weekday (WEEKDAY_NUMBER): The current weekday.
|
||||
|
||||
Returns:
|
||||
WEEKDAY_NUMBER: The next allowed weekday.
|
||||
"""
|
||||
for d in weekdays:
|
||||
if weekday < d:
|
||||
return d
|
||||
return weekdays[0]
|
||||
|
||||
|
||||
def find_next_time(
|
||||
original_time: int,
|
||||
repeat_quantity: Union[RepeatQuantity, None],
|
||||
repeat_interval: Union[int, None],
|
||||
weekdays: Union[List[WEEKDAY_NUMBER], None]
|
||||
) -> int:
|
||||
"""Calculate the next timestep based on original time and repeat/interval
|
||||
values.
|
||||
|
||||
Args:
|
||||
original_time (int): The original time of the repeating timestamp.
|
||||
|
||||
repeat_quantity (Union[RepeatQuantity, None]): If set, what the quantity
|
||||
is of the repetition.
|
||||
|
||||
repeat_interval (Union[int, None]): If set, the value of the repetition.
|
||||
|
||||
weekdays (Union[List[WEEKDAY_NUMBER], None]): If set, on which days the
|
||||
time can continue. Monday is 0, Sunday is 6.
|
||||
|
||||
Returns:
|
||||
int: The next timestamp in the future.
|
||||
"""
|
||||
if weekdays is not None:
|
||||
weekdays.sort()
|
||||
|
||||
current_time = datetime.fromtimestamp(datetime.utcnow().timestamp())
|
||||
original_datetime = datetime.fromtimestamp(original_time)
|
||||
new_time = datetime.fromtimestamp(original_time)
|
||||
|
||||
if (
|
||||
repeat_quantity is not None
|
||||
and repeat_interval is not None
|
||||
):
|
||||
# Add the interval to the original time until we are in the future.
|
||||
# We need to multiply the interval and add it to the original time
|
||||
# instead of just adding the interval once each time to the original
|
||||
# time, because otherwise date jumping could happen. Say original time
|
||||
# is a leap day with an interval of 1 year. Then next date would be the
|
||||
# day before leap day, as leap day doesn't exist in the next year. But
|
||||
# if we then keep adding 1 year to this time, we would keep getting the
|
||||
# day before leap day, a year later. So we need to multiply the interval
|
||||
# and add the whole interval to the original time in one go. This way
|
||||
# after four years we will get the leap day again.
|
||||
interval = relativedelta(
|
||||
**{repeat_quantity.value: repeat_interval} # type: ignore
|
||||
)
|
||||
multiplier = 1
|
||||
while new_time <= current_time:
|
||||
new_time = original_datetime + (interval * multiplier)
|
||||
multiplier += 1
|
||||
|
||||
elif weekdays is not None:
|
||||
if (
|
||||
current_time.weekday() in weekdays
|
||||
and current_time.time() < original_datetime.time()
|
||||
):
|
||||
# Next reminder is later today, so target weekday is current weekday
|
||||
weekday = current_time.weekday()
|
||||
|
||||
else:
|
||||
# Next reminder is not today or earlier today, so target weekday
|
||||
# is next selected one
|
||||
weekday = next_selected_day(
|
||||
weekdays,
|
||||
cast(WEEKDAY_NUMBER, current_time.weekday())
|
||||
)
|
||||
|
||||
new_time = current_time + relativedelta(
|
||||
# Move to upcoming weekday (possibly today)
|
||||
weekday=weekday,
|
||||
# Also move current time to set time
|
||||
hour=original_datetime.hour,
|
||||
minute=original_datetime.minute,
|
||||
second=original_datetime.second
|
||||
)
|
||||
|
||||
result = int(new_time.timestamp())
|
||||
# LOGGER.debug(
|
||||
# f'{original_datetime=}, {current_time=} ' +
|
||||
# f'and interval of {repeat_interval} {repeat_quantity} ' +
|
||||
# f'and weekdays {weekdays} ' +
|
||||
# f'leads to {result}'
|
||||
# )
|
||||
return result
|
||||
|
||||
|
||||
def folder_path(*folders: str) -> str:
|
||||
"""Turn filepaths relative to the project folder into absolute paths.
|
||||
|
||||
Returns:
|
||||
str: The absolute filepath.
|
||||
"""
|
||||
return join(
|
||||
dirname(dirname(dirname(abspath(__file__)))),
|
||||
*folders
|
||||
)
|
||||
|
||||
|
||||
def create_folder(
|
||||
folder: str
|
||||
) -> None:
|
||||
"""Create a folder, if it doesn't exist already.
|
||||
|
||||
Args:
|
||||
folder (str): The path to the folder to create.
|
||||
"""
|
||||
makedirs(folder, exist_ok=True)
|
||||
return
|
||||
|
||||
|
||||
def __copy2(src, dst, *, follow_symlinks=True):
|
||||
try:
|
||||
return copy2(src, dst, follow_symlinks=follow_symlinks)
|
||||
|
||||
except PermissionError as pe:
|
||||
if pe.errno == 1:
|
||||
# NFS file system doesn't allow/support chmod.
|
||||
# This is done after the file is already copied. So just accept that
|
||||
# it isn't possible to change the permissions. Continue like normal.
|
||||
return dst
|
||||
|
||||
raise
|
||||
|
||||
except OSError as oe:
|
||||
if oe.errno == 524:
|
||||
# NFS file system doesn't allow/support setting extended attributes.
|
||||
# This is done after the file is already copied. So just accept that
|
||||
# it isn't possible to set them. Continue like normal.
|
||||
return dst
|
||||
|
||||
raise
|
||||
|
||||
|
||||
def rename_file(
|
||||
before: str,
|
||||
after: str
|
||||
) -> None:
|
||||
"""Rename a file, taking care of new folder locations and
|
||||
the possible complications with files on OS'es.
|
||||
|
||||
Args:
|
||||
before (str): The current filepath of the file.
|
||||
after (str): The new desired filepath of the file.
|
||||
"""
|
||||
create_folder(dirname(after))
|
||||
|
||||
move(before, after, copy_function=__copy2)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args: Any, **kwargs: Any):
|
||||
c = str(cls)
|
||||
if c not in cls._instances:
|
||||
cls._instances[c] = super().__call__(*args, **kwargs)
|
||||
|
||||
return cls._instances[c]
|
||||
163
backend/base/logging.py
Normal file
163
backend/base/logging.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
from os.path import exists, isdir, join
|
||||
from typing import Any, Union
|
||||
|
||||
from backend.base.definitions import Constants
|
||||
from backend.base.helpers import create_folder, folder_path
|
||||
|
||||
|
||||
class UpToInfoFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return record.levelno <= logging.INFO
|
||||
|
||||
|
||||
class ErrorColorFormatter(logging.Formatter):
|
||||
def format(self, record: logging.LogRecord) -> Any:
|
||||
result = super().format(record)
|
||||
return f'\033[1;31:40m{result}\033[0m'
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(Constants.LOGGER_NAME)
|
||||
LOGGING_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"simple": {
|
||||
"format": "[%(asctime)s][%(levelname)s] %(message)s",
|
||||
"datefmt": "%H:%M:%S"
|
||||
},
|
||||
"simple_red": {
|
||||
"()": ErrorColorFormatter,
|
||||
"format": "[%(asctime)s][%(levelname)s] %(message)s",
|
||||
"datefmt": "%H:%M:%S"
|
||||
},
|
||||
"detailed": {
|
||||
"format": "%(asctime)s | %(threadName)s | %(filename)sL%(lineno)s | %(levelname)s | %(message)s",
|
||||
"datefmt": "%Y-%m-%dT%H:%M:%S%z",
|
||||
}
|
||||
},
|
||||
"filters": {
|
||||
"up_to_info": {
|
||||
"()": UpToInfoFilter
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"console_error": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "WARNING",
|
||||
"formatter": "simple_red",
|
||||
"stream": "ext://sys.stderr"
|
||||
},
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"formatter": "simple",
|
||||
"filters": ["up_to_info"],
|
||||
"stream": "ext://sys.stdout"
|
||||
},
|
||||
"file": {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": "DEBUG",
|
||||
"formatter": "detailed",
|
||||
"filename": "",
|
||||
"maxBytes": 1_000_000,
|
||||
"backupCount": 1
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
Constants.LOGGER_NAME: {}
|
||||
},
|
||||
"root": {
|
||||
"level": "INFO",
|
||||
"handlers": [
|
||||
"console",
|
||||
"console_error",
|
||||
"file"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def setup_logging(log_folder: Union[str, None]) -> None:
|
||||
"""Setup the basic config of the logging module.
|
||||
|
||||
Args:
|
||||
log_folder (Union[str, None]): The folder to put the log file in.
|
||||
If `None`, the log file will be in the same folder as the
|
||||
application folder.
|
||||
|
||||
Raises:
|
||||
ValueError: The given log folder is not a folder.
|
||||
"""
|
||||
if log_folder:
|
||||
if exists(log_folder) and not isdir(log_folder):
|
||||
raise ValueError("Logging folder is not a folder")
|
||||
|
||||
create_folder(log_folder)
|
||||
|
||||
if log_folder is None:
|
||||
LOGGING_CONFIG["handlers"]["file"]["filename"] = folder_path(
|
||||
Constants.LOGGER_FILENAME
|
||||
)
|
||||
else:
|
||||
LOGGING_CONFIG["handlers"]["file"]["filename"] = join(
|
||||
log_folder,
|
||||
Constants.LOGGER_FILENAME
|
||||
)
|
||||
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
|
||||
# Log uncaught exceptions using the logger instead of printing the stderr
|
||||
# Logger goes to stderr anyway, so still visible in console but also logs
|
||||
# to file, so that downloaded log file also contains any errors.
|
||||
import sys
|
||||
import threading
|
||||
from traceback import format_exception
|
||||
|
||||
def log_uncaught_exceptions(e_type, value, tb):
|
||||
LOGGER.error(
|
||||
"UNCAUGHT EXCEPTION:\n" +
|
||||
''.join(format_exception(e_type, value, tb))
|
||||
)
|
||||
return
|
||||
|
||||
def log_uncaught_threading_exceptions(args):
|
||||
LOGGER.exception(
|
||||
f"UNCAUGHT EXCEPTION IN THREAD: {args.exc_value}"
|
||||
)
|
||||
return
|
||||
|
||||
sys.excepthook = log_uncaught_exceptions
|
||||
threading.excepthook = log_uncaught_threading_exceptions
|
||||
|
||||
return
|
||||
|
||||
|
||||
def get_log_filepath() -> str:
|
||||
"Get the filepath to the logging file"
|
||||
return LOGGING_CONFIG["handlers"]["file"]["filename"]
|
||||
|
||||
|
||||
def set_log_level(
|
||||
level: Union[int, str],
|
||||
) -> None:
|
||||
"""Change the logging level.
|
||||
|
||||
Args:
|
||||
level (Union[int, str]): The level to set the logging to.
|
||||
Should be a logging level, like `logging.INFO` or `"DEBUG"`.
|
||||
"""
|
||||
if isinstance(level, str):
|
||||
level = logging._nameToLevel[level.upper()]
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
if root_logger.level == level:
|
||||
return
|
||||
|
||||
LOGGER.debug(f'Setting logging level: {level}')
|
||||
root_logger.setLevel(level)
|
||||
|
||||
return
|
||||
@@ -1,138 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
All custom exceptions are defined here
|
||||
"""
|
||||
|
||||
"""
|
||||
Note: Not all CE's inherit from CustomException.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from backend.logging import LOGGER
|
||||
|
||||
|
||||
class CustomException(Exception):
|
||||
def __init__(self, e=None) -> None:
|
||||
LOGGER.warning(self.__doc__)
|
||||
super().__init__(e)
|
||||
return
|
||||
|
||||
class UsernameTaken(CustomException):
|
||||
"""The username is already taken"""
|
||||
api_response = {'error': 'UsernameTaken', 'result': {}, 'code': 400}
|
||||
|
||||
class UsernameInvalid(Exception):
|
||||
"""The username contains invalid characters"""
|
||||
api_response = {'error': 'UsernameInvalid', 'result': {}, 'code': 400}
|
||||
|
||||
def __init__(self, username: str):
|
||||
self.username = username
|
||||
super().__init__(self.username)
|
||||
LOGGER.warning(
|
||||
f'The username contains invalid characters: {username}'
|
||||
)
|
||||
return
|
||||
|
||||
class UserNotFound(CustomException):
|
||||
"""The user requested can not be found"""
|
||||
api_response = {'error': 'UserNotFound', 'result': {}, 'code': 404}
|
||||
|
||||
class AccessUnauthorized(CustomException):
|
||||
"""The password given is not correct"""
|
||||
api_response = {'error': 'AccessUnauthorized', 'result': {}, 'code': 401}
|
||||
|
||||
class ReminderNotFound(CustomException):
|
||||
"""The reminder with the id can not be found"""
|
||||
api_response = {'error': 'ReminderNotFound', 'result': {}, 'code': 404}
|
||||
|
||||
class NotificationServiceNotFound(CustomException):
|
||||
"""The notification service was not found"""
|
||||
api_response = {'error': 'NotificationServiceNotFound', 'result': {}, 'code': 404}
|
||||
|
||||
class NotificationServiceInUse(Exception):
|
||||
"""
|
||||
The notification service is wished to be deleted
|
||||
but a reminder is still using it
|
||||
"""
|
||||
def __init__(self, type: str=''):
|
||||
self.type = type
|
||||
super().__init__(self.type)
|
||||
LOGGER.warning(
|
||||
f'The notification is wished to be deleted but a reminder of type {type} is still using it'
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'error': 'NotificationServiceInUse',
|
||||
'result': {'type': self.type},
|
||||
'code': 400
|
||||
}
|
||||
|
||||
class InvalidTime(CustomException):
|
||||
"""The time given is in the past"""
|
||||
api_response = {'error': 'InvalidTime', 'result': {}, 'code': 400}
|
||||
|
||||
class KeyNotFound(Exception):
|
||||
"""A key was not found in the input that is required to be given"""
|
||||
def __init__(self, key: str=''):
|
||||
self.key = key
|
||||
super().__init__(self.key)
|
||||
LOGGER.warning(
|
||||
"This key was not found in the API request,"
|
||||
+ f" eventhough it's required: {key}"
|
||||
)
|
||||
return
|
||||
|
||||
@property
|
||||
def api_response(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'error': 'KeyNotFound',
|
||||
'result': {'key': self.key},
|
||||
'code': 400
|
||||
}
|
||||
|
||||
class InvalidKeyValue(Exception):
|
||||
"""The value of a key is invalid"""
|
||||
def __init__(self, key: str = '', value: Any = ''):
|
||||
self.key = key
|
||||
self.value = value
|
||||
super().__init__(self.key)
|
||||
LOGGER.warning(
|
||||
'This key in the API request has an invalid value: ' +
|
||||
f'{key} = {value}'
|
||||
)
|
||||
|
||||
@property
|
||||
def api_response(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'error': 'InvalidKeyValue',
|
||||
'result': {'key': self.key, 'value': self.value},
|
||||
'code': 400
|
||||
}
|
||||
|
||||
class TemplateNotFound(CustomException):
|
||||
"""The template was not found"""
|
||||
api_response = {'error': 'TemplateNotFound', 'result': {}, 'code': 404}
|
||||
|
||||
class APIKeyInvalid(Exception):
|
||||
"""The API key is not correct"""
|
||||
api_response = {'error': 'APIKeyInvalid', 'result': {}, 'code': 401}
|
||||
|
||||
class APIKeyExpired(Exception):
|
||||
"""The API key has expired"""
|
||||
api_response = {'error': 'APIKeyExpired', 'result': {}, 'code': 401}
|
||||
|
||||
class NewAccountsNotAllowed(CustomException):
|
||||
"""It's not allowed to create a new account except for the admin"""
|
||||
api_response = {'error': 'NewAccountsNotAllowed', 'result': {}, 'code': 403}
|
||||
|
||||
class InvalidDatabaseFile(CustomException):
|
||||
"""The uploaded database file is invalid or not supported"""
|
||||
api_response = {'error': 'InvalidDatabaseFile', 'result': {}, 'code': 400}
|
||||
|
||||
class LogFileNotFound(CustomException):
|
||||
"""No log file was found"""
|
||||
api_response = {'error': 'LogFileNotFound', 'result': {}, 'code': 404}
|
||||
521
backend/db.py
521
backend/db.py
@@ -1,521 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Setting up and interacting with the database.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from os import makedirs, remove
|
||||
from os.path import dirname, isfile, join
|
||||
from shutil import move
|
||||
from sqlite3 import Connection, OperationalError, ProgrammingError, Row
|
||||
from threading import current_thread, main_thread
|
||||
from time import time
|
||||
from typing import Type, Union
|
||||
|
||||
from flask import g
|
||||
|
||||
from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile,
|
||||
UserNotFound)
|
||||
from backend.helpers import RestartVars, folder_path
|
||||
from backend.logging import LOGGER, set_log_level
|
||||
|
||||
DB_FILENAME = 'db', 'MIND.db'
|
||||
__DATABASE_VERSION__ = 10
|
||||
__DATEBASE_NAME_ORIGINAL__ = "MIND_original.db"
|
||||
|
||||
class DB_Singleton(type):
|
||||
_instances = {}
|
||||
def __call__(cls, *args, **kwargs):
|
||||
i = f'{cls}{current_thread()}'
|
||||
if (i not in cls._instances
|
||||
or cls._instances[i].closed):
|
||||
cls._instances[i] = super(DB_Singleton, cls).__call__(*args, **kwargs)
|
||||
|
||||
return cls._instances[i]
|
||||
|
||||
class DBConnection(Connection, metaclass=DB_Singleton):
|
||||
file = ''
|
||||
|
||||
def __init__(self, timeout: float) -> None:
|
||||
LOGGER.debug(f'Creating connection {self}')
|
||||
super().__init__(self.file, timeout=timeout)
|
||||
super().cursor().execute("PRAGMA foreign_keys = ON;")
|
||||
self.closed = False
|
||||
return
|
||||
|
||||
def close(self) -> None:
|
||||
LOGGER.debug(f'Closing connection {self}')
|
||||
self.closed = True
|
||||
super().close()
|
||||
return
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__}; {current_thread().name}; {id(self)}>'
|
||||
|
||||
def setup_db_location() -> None:
|
||||
"""Create folder for database and link file to DBConnection class
|
||||
"""
|
||||
if isfile(folder_path('db', 'Noted.db')):
|
||||
move(folder_path('db', 'Noted.db'), folder_path(*DB_FILENAME))
|
||||
|
||||
db_location = folder_path(*DB_FILENAME)
|
||||
makedirs(dirname(db_location), exist_ok=True)
|
||||
|
||||
DBConnection.file = db_location
|
||||
return
|
||||
|
||||
def get_db(output_type: Union[Type[dict], Type[tuple]]=tuple):
|
||||
"""Get a database cursor instance. Coupled to Flask's g.
|
||||
|
||||
Args:
|
||||
output_type (Union[Type[dict], Type[tuple]], optional):
|
||||
The type of output: a tuple or dictionary with the row values.
|
||||
Defaults to tuple.
|
||||
|
||||
Returns:
|
||||
Cursor: The Cursor instance to use
|
||||
"""
|
||||
try:
|
||||
cursor = g.cursor
|
||||
except AttributeError:
|
||||
db = DBConnection(timeout=20.0)
|
||||
cursor = g.cursor = db.cursor()
|
||||
|
||||
if output_type is dict:
|
||||
cursor.row_factory = Row
|
||||
else:
|
||||
cursor.row_factory = None
|
||||
|
||||
return g.cursor
|
||||
|
||||
def close_db(e=None) -> None:
|
||||
"""Savely closes the database connection
|
||||
"""
|
||||
try:
|
||||
cursor = g.cursor
|
||||
db: DBConnection = cursor.connection
|
||||
cursor.close()
|
||||
delattr(g, 'cursor')
|
||||
db.commit()
|
||||
if current_thread() is main_thread():
|
||||
db.close()
|
||||
except (AttributeError, ProgrammingError):
|
||||
pass
|
||||
return
|
||||
|
||||
def migrate_db(current_db_version: int) -> None:
|
||||
"""
|
||||
Migrate a MIND database from it's current version
|
||||
to the newest version supported by the MIND version installed.
|
||||
"""
|
||||
LOGGER.info('Migrating database to newer version...')
|
||||
cursor = get_db()
|
||||
if current_db_version == 1:
|
||||
# V1 -> V2
|
||||
t = time()
|
||||
utc_offset = datetime.fromtimestamp(t) - datetime.utcfromtimestamp(t)
|
||||
cursor.execute("SELECT time, id FROM reminders;")
|
||||
new_reminders = []
|
||||
new_reminders_append = new_reminders.append
|
||||
for reminder in cursor:
|
||||
new_reminders_append([round((datetime.fromtimestamp(reminder[0]) - utc_offset).timestamp()), reminder[1]])
|
||||
cursor.executemany("UPDATE reminders SET time = ? WHERE id = ?;", new_reminders)
|
||||
current_db_version = 2
|
||||
|
||||
if current_db_version == 2:
|
||||
# V2 -> V3
|
||||
cursor.executescript("""
|
||||
ALTER TABLE reminders
|
||||
ADD color VARCHAR(7);
|
||||
ALTER TABLE templates
|
||||
ADD color VARCHAR(7);
|
||||
""")
|
||||
current_db_version = 3
|
||||
|
||||
if current_db_version == 3:
|
||||
# V3 -> V4
|
||||
cursor.executescript("""
|
||||
UPDATE reminders
|
||||
SET repeat_quantity = repeat_quantity || 's'
|
||||
WHERE repeat_quantity NOT LIKE '%s';
|
||||
""")
|
||||
current_db_version = 4
|
||||
|
||||
if current_db_version == 4:
|
||||
# V4 -> V5
|
||||
cursor.executescript("""
|
||||
BEGIN TRANSACTION;
|
||||
PRAGMA defer_foreign_keys = ON;
|
||||
|
||||
CREATE TEMPORARY TABLE temp_reminder_services(
|
||||
reminder_id,
|
||||
static_reminder_id,
|
||||
template_id,
|
||||
notification_service_id
|
||||
);
|
||||
|
||||
-- Reminders
|
||||
INSERT INTO temp_reminder_services(reminder_id, notification_service_id)
|
||||
SELECT id, notification_service
|
||||
FROM reminders;
|
||||
|
||||
CREATE TEMPORARY TABLE temp_reminders AS
|
||||
SELECT id, user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color
|
||||
FROM reminders;
|
||||
DROP TABLE reminders;
|
||||
CREATE TABLE reminders(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
time INTEGER NOT NULL,
|
||||
|
||||
repeat_quantity VARCHAR(15),
|
||||
repeat_interval INTEGER,
|
||||
original_time INTEGER,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
INSERT INTO reminders
|
||||
SELECT * FROM temp_reminders;
|
||||
|
||||
-- Templates
|
||||
INSERT INTO temp_reminder_services(template_id, notification_service_id)
|
||||
SELECT id, notification_service
|
||||
FROM templates;
|
||||
|
||||
CREATE TEMPORARY TABLE temp_templates AS
|
||||
SELECT id, user_id, title, text, color
|
||||
FROM templates;
|
||||
DROP TABLE templates;
|
||||
CREATE TABLE templates(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
INSERT INTO templates
|
||||
SELECT * FROM temp_templates;
|
||||
|
||||
INSERT INTO reminder_services
|
||||
SELECT * FROM temp_reminder_services;
|
||||
|
||||
COMMIT;
|
||||
""")
|
||||
current_db_version = 5
|
||||
|
||||
if current_db_version == 5:
|
||||
# V5 -> V6
|
||||
from backend.users import User
|
||||
try:
|
||||
User('User1', 'Password1').delete()
|
||||
except (UserNotFound, AccessUnauthorized):
|
||||
pass
|
||||
|
||||
current_db_version = 6
|
||||
|
||||
if current_db_version == 6:
|
||||
# V6 -> V7
|
||||
cursor.executescript("""
|
||||
ALTER TABLE reminders
|
||||
ADD weekdays VARCHAR(13);
|
||||
""")
|
||||
current_db_version = 7
|
||||
|
||||
if current_db_version == 7:
|
||||
# V7 -> V8
|
||||
from backend.settings import _format_setting, default_settings
|
||||
from backend.users import Users
|
||||
|
||||
cursor.executescript("""
|
||||
DROP TABLE config;
|
||||
CREATE TABLE IF NOT EXISTS config(
|
||||
key VARCHAR(255) PRIMARY KEY,
|
||||
value BLOB NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.executemany("""
|
||||
INSERT OR IGNORE INTO config(key, value)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
map(
|
||||
lambda kv: (kv[0], _format_setting(*kv)),
|
||||
default_settings.items()
|
||||
)
|
||||
)
|
||||
|
||||
cursor.executescript("""
|
||||
ALTER TABLE users
|
||||
ADD admin BOOL NOT NULL DEFAULT 0;
|
||||
|
||||
UPDATE users
|
||||
SET username = 'admin_old'
|
||||
WHERE username = 'admin';
|
||||
""")
|
||||
|
||||
Users().add('admin', 'admin', True)
|
||||
|
||||
cursor.execute("""
|
||||
UPDATE users
|
||||
SET admin = 1
|
||||
WHERE username = 'admin';
|
||||
""")
|
||||
|
||||
current_db_version = 8
|
||||
|
||||
if current_db_version == 8:
|
||||
# V8 -> V9
|
||||
from backend.settings import set_setting
|
||||
from MIND import HOST, PORT, URL_PREFIX
|
||||
|
||||
set_setting('host', HOST)
|
||||
set_setting('port', int(PORT))
|
||||
set_setting('url_prefix', URL_PREFIX)
|
||||
|
||||
current_db_version = 9
|
||||
|
||||
if current_db_version == 9:
|
||||
# V9 -> V10
|
||||
|
||||
# Nothing is changed in the database
|
||||
# It's just that this code needs to run once
|
||||
# and the DB migration system does exactly that:
|
||||
# run pieces of code once.
|
||||
from backend.settings import update_manifest
|
||||
|
||||
url_prefix: str = cursor.execute(
|
||||
"SELECT value FROM config WHERE key = 'url_prefix' LIMIT 1;"
|
||||
).fetchone()[0]
|
||||
update_manifest(url_prefix)
|
||||
|
||||
current_db_version = 10
|
||||
|
||||
return
|
||||
|
||||
def setup_db() -> None:
|
||||
"""Setup the database
|
||||
"""
|
||||
from backend.settings import (_format_setting, default_settings, get_setting,
|
||||
set_setting, update_manifest)
|
||||
from backend.users import Users
|
||||
|
||||
cursor = get_db()
|
||||
cursor.execute("PRAGMA journal_mode = wal;")
|
||||
|
||||
cursor.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS users(
|
||||
id INTEGER PRIMARY KEY,
|
||||
username VARCHAR(255) UNIQUE NOT NULL,
|
||||
salt VARCHAR(40) NOT NULL,
|
||||
hash VARCHAR(100) NOT NULL,
|
||||
admin BOOL NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS notification_services(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255),
|
||||
url TEXT,
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS reminders(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
time INTEGER NOT NULL,
|
||||
|
||||
repeat_quantity VARCHAR(15),
|
||||
repeat_interval INTEGER,
|
||||
original_time INTEGER,
|
||||
weekdays VARCHAR(13),
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS templates(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS static_reminders(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS reminder_services(
|
||||
reminder_id INTEGER,
|
||||
static_reminder_id INTEGER,
|
||||
template_id INTEGER,
|
||||
notification_service_id INTEGER NOT NULL,
|
||||
|
||||
FOREIGN KEY (reminder_id) REFERENCES reminders(id)
|
||||
ON DELETE CASCADE,
|
||||
FOREIGN KEY (static_reminder_id) REFERENCES static_reminders(id)
|
||||
ON DELETE CASCADE,
|
||||
FOREIGN KEY (template_id) REFERENCES templates(id)
|
||||
ON DELETE CASCADE,
|
||||
FOREIGN KEY (notification_service_id) REFERENCES notification_services(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS config(
|
||||
key VARCHAR(255) PRIMARY KEY,
|
||||
value BLOB NOT NULL
|
||||
);
|
||||
""")
|
||||
|
||||
cursor.executemany("""
|
||||
INSERT OR IGNORE INTO config(key, value)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
map(
|
||||
lambda kv: (kv[0], _format_setting(*kv)),
|
||||
default_settings.items()
|
||||
)
|
||||
)
|
||||
|
||||
set_log_level(get_setting('log_level'), clear_file=False)
|
||||
update_manifest(get_setting('url_prefix'))
|
||||
|
||||
current_db_version = get_setting('database_version')
|
||||
if current_db_version < __DATABASE_VERSION__:
|
||||
LOGGER.debug(
|
||||
f'Database migration: {current_db_version} -> {__DATABASE_VERSION__}'
|
||||
)
|
||||
migrate_db(current_db_version)
|
||||
set_setting('database_version', __DATABASE_VERSION__)
|
||||
|
||||
users = Users()
|
||||
if not 'admin' in users:
|
||||
users.add('admin', 'admin', True)
|
||||
cursor.execute("""
|
||||
UPDATE users
|
||||
SET admin = 1
|
||||
WHERE username = 'admin';
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
def revert_db_import(
|
||||
swap: bool,
|
||||
imported_db_file: str = ''
|
||||
) -> None:
|
||||
"""Revert the database import process. The original_db_file is the file
|
||||
currently used (`DBConnection.file`).
|
||||
|
||||
Args:
|
||||
swap (bool): Whether or not to keep the imported_db_file or not,
|
||||
instead of the original_db_file.
|
||||
imported_db_file (str, optional): The other database file. Keep empty
|
||||
to use `__DATABASE_NAME_ORIGINAL__`. Defaults to ''.
|
||||
"""
|
||||
original_db_file = DBConnection.file
|
||||
if not imported_db_file:
|
||||
imported_db_file = join(dirname(DBConnection.file), __DATEBASE_NAME_ORIGINAL__)
|
||||
|
||||
if swap:
|
||||
remove(original_db_file)
|
||||
move(
|
||||
imported_db_file,
|
||||
original_db_file
|
||||
)
|
||||
|
||||
else:
|
||||
remove(imported_db_file)
|
||||
|
||||
return
|
||||
|
||||
def import_db(
|
||||
new_db_file: str,
|
||||
copy_hosting_settings: bool
|
||||
) -> None:
|
||||
"""Replace the current database with a new one.
|
||||
|
||||
Args:
|
||||
new_db_file (str): The path to the new database file.
|
||||
copy_hosting_settings (bool): Keep the hosting settings from the current
|
||||
database.
|
||||
|
||||
Raises:
|
||||
InvalidDatabaseFile: The new database file is invalid or unsupported.
|
||||
"""
|
||||
LOGGER.info(f'Importing new database; {copy_hosting_settings=}')
|
||||
try:
|
||||
cursor = Connection(new_db_file, timeout=20.0).cursor()
|
||||
|
||||
database_version = cursor.execute(
|
||||
"SELECT value FROM config WHERE key = 'database_version' LIMIT 1;"
|
||||
).fetchone()[0]
|
||||
if not isinstance(database_version, int):
|
||||
raise InvalidDatabaseFile
|
||||
|
||||
except (OperationalError, InvalidDatabaseFile):
|
||||
LOGGER.error('Uploaded database is not a MIND database file')
|
||||
cursor.connection.close()
|
||||
revert_db_import(
|
||||
swap=False,
|
||||
imported_db_file=new_db_file
|
||||
)
|
||||
raise InvalidDatabaseFile
|
||||
|
||||
if database_version > __DATABASE_VERSION__:
|
||||
LOGGER.error('Uploaded database is higher version than this MIND installation can support')
|
||||
revert_db_import(
|
||||
swap=False,
|
||||
imported_db_file=new_db_file
|
||||
)
|
||||
raise InvalidDatabaseFile
|
||||
|
||||
if copy_hosting_settings:
|
||||
hosting_settings = get_db().execute("""
|
||||
SELECT key, value, value
|
||||
FROM config
|
||||
WHERE key = 'host'
|
||||
OR key = 'port'
|
||||
OR key = 'url_prefix'
|
||||
LIMIT 3;
|
||||
"""
|
||||
)
|
||||
cursor.executemany("""
|
||||
INSERT INTO config(key, value)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(key) DO
|
||||
UPDATE
|
||||
SET value = ?;
|
||||
""",
|
||||
hosting_settings
|
||||
)
|
||||
cursor.connection.commit()
|
||||
cursor.connection.close()
|
||||
|
||||
move(
|
||||
DBConnection.file,
|
||||
join(dirname(DBConnection.file), __DATEBASE_NAME_ORIGINAL__)
|
||||
)
|
||||
move(
|
||||
new_db_file,
|
||||
DBConnection.file
|
||||
)
|
||||
|
||||
from backend.server import SERVER
|
||||
SERVER.restart([RestartVars.DB_IMPORT.value])
|
||||
|
||||
return
|
||||
132
backend/features/reminder_handler.py
Normal file
132
backend/features/reminder_handler.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from datetime import datetime
|
||||
from threading import Timer
|
||||
from typing import Union
|
||||
|
||||
from backend.base.definitions import Constants, RepeatQuantity, SendResult
|
||||
from backend.base.helpers import (Singleton, find_next_time,
|
||||
send_apprise_notification, when_not_none)
|
||||
from backend.base.logging import LOGGER
|
||||
from backend.implementations.notification_services import NotificationService
|
||||
from backend.internals.db_models import UserlessRemindersDB
|
||||
from backend.internals.server import Server
|
||||
|
||||
|
||||
class ReminderHandler(metaclass=Singleton):
|
||||
"""
|
||||
Handle set reminders. This class is a singleton.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"Create instance of handler"
|
||||
self.thread: Union[Timer, None] = None
|
||||
self.time: Union[int, None] = None
|
||||
self.reminder_db = UserlessRemindersDB()
|
||||
return
|
||||
|
||||
def __trigger_reminders(self, time: int) -> None:
|
||||
"""Trigger all reminders that are set for a certain time.
|
||||
|
||||
Args:
|
||||
time (int): The time of the reminders to trigger.
|
||||
"""
|
||||
with Server().app.app_context():
|
||||
for reminder in self.reminder_db.fetch(time):
|
||||
try:
|
||||
user_id = self.reminder_db.reminder_id_to_user_id(
|
||||
reminder.id)
|
||||
result = send_apprise_notification(
|
||||
[
|
||||
NotificationService(user_id, ns).get().url
|
||||
for ns in reminder.notification_services
|
||||
],
|
||||
reminder.title,
|
||||
reminder.text
|
||||
)
|
||||
|
||||
self.thread = None
|
||||
self.time = None
|
||||
|
||||
if result == SendResult.CONNECTION_ERROR:
|
||||
# Retry sending the notification in a few minutes
|
||||
self.reminder_db.update(
|
||||
reminder.id,
|
||||
time + Constants.CONNECTION_ERROR_TIMEOUT
|
||||
)
|
||||
|
||||
elif (
|
||||
reminder.repeat_quantity,
|
||||
reminder.weekdays
|
||||
) == (None, None):
|
||||
# Delete the reminder from the database
|
||||
self.reminder_db.delete(reminder.id)
|
||||
|
||||
else:
|
||||
# Set next time
|
||||
new_time = find_next_time(
|
||||
reminder.original_time or -1,
|
||||
when_not_none(
|
||||
reminder.repeat_quantity,
|
||||
lambda q: RepeatQuantity(q)
|
||||
),
|
||||
reminder.repeat_interval,
|
||||
reminder.weekdays
|
||||
)
|
||||
|
||||
self.reminder_db.update(reminder.id, new_time)
|
||||
|
||||
except Exception:
|
||||
# If the notification fails, we don't want to crash the whole program
|
||||
# Just log the error and continue
|
||||
LOGGER.exception(
|
||||
"Failed to send notification for reminder %s: ",
|
||||
reminder.id
|
||||
)
|
||||
|
||||
finally:
|
||||
self.find_next_reminder()
|
||||
|
||||
return
|
||||
|
||||
def find_next_reminder(self, time: Union[int, None] = None) -> None:
|
||||
"""Determine when the soonest reminder is and set the timer to that time.
|
||||
|
||||
Args:
|
||||
time (Union[int, None], optional): The timestamp to check for.
|
||||
Otherwise check soonest in database.
|
||||
Defaults to None.
|
||||
"""
|
||||
if time is None:
|
||||
time = self.reminder_db.get_soonest_time()
|
||||
if not time:
|
||||
return
|
||||
|
||||
if (
|
||||
self.thread is None
|
||||
or (
|
||||
self.time is not None
|
||||
and time < self.time
|
||||
)
|
||||
):
|
||||
if self.thread is not None:
|
||||
self.thread.cancel()
|
||||
|
||||
delta_t = time - datetime.utcnow().timestamp()
|
||||
self.thread = Timer(
|
||||
delta_t,
|
||||
self.__trigger_reminders,
|
||||
(time,)
|
||||
)
|
||||
self.thread.name = "ReminderHandler"
|
||||
self.thread.start()
|
||||
self.time = time
|
||||
|
||||
return
|
||||
|
||||
def stop_handling(self) -> None:
|
||||
"""Stop the timer if it's active
|
||||
"""
|
||||
if self.thread is not None:
|
||||
self.thread.cancel()
|
||||
return
|
||||
418
backend/features/reminders.py
Normal file
418
backend/features/reminders.py
Normal file
@@ -0,0 +1,418 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from typing import List, Union
|
||||
|
||||
from backend.base.custom_exceptions import (InvalidKeyValue, InvalidTime,
|
||||
ReminderNotFound)
|
||||
from backend.base.definitions import (WEEKDAY_NUMBER, ReminderData,
|
||||
RepeatQuantity, SendResult,
|
||||
SortingMethod)
|
||||
from backend.base.helpers import (find_next_time, search_filter,
|
||||
send_apprise_notification, when_not_none)
|
||||
from backend.base.logging import LOGGER
|
||||
from backend.features.reminder_handler import ReminderHandler
|
||||
from backend.implementations.notification_services import NotificationService
|
||||
from backend.internals.db_models import RemindersDB
|
||||
|
||||
REMINDER_HANDLER = ReminderHandler()
|
||||
|
||||
|
||||
class Reminder:
|
||||
def __init__(self, user_id: int, reminder_id: int) -> None:
|
||||
"""Represent a reminder.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
reminder_id (int): The ID of the reminder.
|
||||
|
||||
Raises:
|
||||
ReminderNotFound: Reminder with given ID does not exist or is not
|
||||
owned by user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.id = reminder_id
|
||||
|
||||
self.reminder_db = RemindersDB(self.user_id)
|
||||
|
||||
if not self.reminder_db.exists(self.id):
|
||||
raise ReminderNotFound(reminder_id)
|
||||
return
|
||||
|
||||
def get(self) -> ReminderData:
|
||||
"""Get info about the reminder.
|
||||
|
||||
Returns:
|
||||
ReminderData: The info about the reminder.
|
||||
"""
|
||||
return self.reminder_db.fetch(self.id)[0]
|
||||
|
||||
def update(
|
||||
self,
|
||||
title: Union[None, str] = None,
|
||||
time: Union[None, int] = None,
|
||||
notification_services: Union[None, List[int]] = None,
|
||||
text: Union[None, str] = None,
|
||||
repeat_quantity: Union[None, RepeatQuantity] = None,
|
||||
repeat_interval: Union[None, int] = None,
|
||||
weekdays: Union[None, List[WEEKDAY_NUMBER]] = None,
|
||||
color: Union[None, str] = None
|
||||
) -> ReminderData:
|
||||
"""Edit the reminder.
|
||||
|
||||
Args:
|
||||
title (Union[None, str]): The new title of the entry.
|
||||
Defaults to None.
|
||||
|
||||
time (Union[None, int]): The new UTC epoch timestamp when the
|
||||
reminder should be send.
|
||||
Defaults to None.
|
||||
|
||||
notification_services (Union[None, List[int]]): The new list
|
||||
of id's of the notification services to use to send the reminder.
|
||||
Defaults to None.
|
||||
|
||||
text (Union[None, str], optional): The new body of the reminder.
|
||||
Defaults to None.
|
||||
|
||||
repeat_quantity (Union[None, RepeatQuantity], optional): The new
|
||||
quantity of the repeat specified for the reminder.
|
||||
Defaults to None.
|
||||
|
||||
repeat_interval (Union[None, int], optional): The new amount of
|
||||
repeat_quantity, like "5" (hours).
|
||||
Defaults to None.
|
||||
|
||||
weekdays (Union[None, List[WEEKDAY_NUMBER]], optional): The new
|
||||
indexes of the days of the week that the reminder should run.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[None, str], optional): The new hex code of the color
|
||||
of the reminder, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Note about args:
|
||||
Either repeat_quantity and repeat_interval are given, weekdays is
|
||||
given or neither, but not both.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found.
|
||||
InvalidKeyValue: The value of one of the keys is not valid or
|
||||
the "Note about args" is violated.
|
||||
|
||||
Returns:
|
||||
ReminderData: The new reminder info.
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Updating notification service {self.id}: '
|
||||
+ f'{title=}, {time=}, {notification_services=}, {text=}, '
|
||||
+ f'{repeat_quantity=}, {repeat_interval=}, {weekdays=}, {color=}'
|
||||
)
|
||||
|
||||
# Validate data
|
||||
if repeat_quantity is None and repeat_interval is not None:
|
||||
raise InvalidKeyValue('repeat_quantity', repeat_quantity)
|
||||
elif repeat_quantity is not None and repeat_interval is None:
|
||||
raise InvalidKeyValue('repeat_interval', repeat_interval)
|
||||
elif weekdays is not None and repeat_quantity is not None:
|
||||
raise InvalidKeyValue('weekdays', weekdays)
|
||||
|
||||
repeated_reminder = (
|
||||
(repeat_quantity is not None and repeat_interval is not None)
|
||||
or weekdays is not None
|
||||
)
|
||||
|
||||
if time is not None:
|
||||
if not repeated_reminder:
|
||||
if time < datetime.utcnow().timestamp():
|
||||
raise InvalidTime(time)
|
||||
time = round(time)
|
||||
|
||||
if notification_services:
|
||||
# Check if all notification services exist
|
||||
for ns in notification_services:
|
||||
NotificationService(self.user_id, ns)
|
||||
|
||||
# Get current data and update it with new values
|
||||
data = asdict(self.get())
|
||||
|
||||
new_values = {
|
||||
'title': title,
|
||||
'time': time,
|
||||
'text': text,
|
||||
'repeat_quantity': when_not_none(
|
||||
repeat_quantity,
|
||||
lambda q: q.value
|
||||
),
|
||||
'repeat_interval': repeat_interval,
|
||||
'weekdays': when_not_none(
|
||||
weekdays,
|
||||
lambda w: ",".join(map(str, sorted(w)))
|
||||
),
|
||||
'color': color,
|
||||
'notification_services': notification_services
|
||||
}
|
||||
for k, v in new_values.items():
|
||||
if (
|
||||
k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color')
|
||||
or v is not None
|
||||
):
|
||||
data[k] = v
|
||||
|
||||
if repeated_reminder:
|
||||
next_time = find_next_time(
|
||||
data["time"],
|
||||
data["repeat_quantity"],
|
||||
data["repeat_interval"],
|
||||
weekdays
|
||||
)
|
||||
self.reminder_db.update(
|
||||
self.id,
|
||||
data["title"],
|
||||
data["text"],
|
||||
next_time,
|
||||
data["repeat_quantity"],
|
||||
data["repeat_interval"],
|
||||
data["weekdays"],
|
||||
data["time"],
|
||||
data["color"],
|
||||
data["notification_services"]
|
||||
)
|
||||
|
||||
else:
|
||||
next_time = data["time"]
|
||||
self.reminder_db.update(
|
||||
self.id,
|
||||
data["title"],
|
||||
data["text"],
|
||||
next_time,
|
||||
data["repeat_quantity"],
|
||||
data["repeat_interval"],
|
||||
data["weekdays"],
|
||||
data["original_time"],
|
||||
data["color"],
|
||||
data["notification_services"]
|
||||
)
|
||||
|
||||
REMINDER_HANDLER.find_next_reminder(next_time)
|
||||
return self.get()
|
||||
|
||||
def delete(self) -> None:
|
||||
"Delete the reminder"
|
||||
LOGGER.info(f'Deleting reminder {self.id}')
|
||||
self.reminder_db.delete(self.id)
|
||||
REMINDER_HANDLER.find_next_reminder()
|
||||
return
|
||||
|
||||
|
||||
class Reminders:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.reminder_db = RemindersDB(self.user_id)
|
||||
return
|
||||
|
||||
def fetchall(
|
||||
self,
|
||||
sort_by: SortingMethod = SortingMethod.TIME
|
||||
) -> List[ReminderData]:
|
||||
"""Get all reminders.
|
||||
|
||||
Args:
|
||||
sort_by (SortingMethod, optional): How to sort the result.
|
||||
Defaults to SortingMethod.TIME.
|
||||
|
||||
Returns:
|
||||
List[ReminderData]: The info of each reminder.
|
||||
"""
|
||||
reminders = self.reminder_db.fetch()
|
||||
reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1])
|
||||
return reminders
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
sort_by: SortingMethod = SortingMethod.TIME
|
||||
) -> List[ReminderData]:
|
||||
"""Search for reminders.
|
||||
|
||||
Args:
|
||||
query (str): The term to search for.
|
||||
sort_by (SortingMethod, optional): How to sort the result.
|
||||
Defaults to SortingMethod.TIME.
|
||||
|
||||
Returns:
|
||||
List[ReminderData]: All reminders that match. Similar output to
|
||||
self.fetchall.
|
||||
"""
|
||||
reminders = [
|
||||
r
|
||||
for r in self.fetchall(sort_by)
|
||||
if search_filter(query, r)
|
||||
]
|
||||
return reminders
|
||||
|
||||
def fetchone(self, id: int) -> Reminder:
|
||||
"""Get one reminder.
|
||||
|
||||
Args:
|
||||
id (int): The ID of the reminder to fetch.
|
||||
|
||||
Raises:
|
||||
ReminderNotFound: The reminder with the given ID does not exist
|
||||
or is not owned by the user.
|
||||
|
||||
Returns:
|
||||
Reminder: A Reminder instance.
|
||||
"""
|
||||
return Reminder(self.user_id, id)
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
time: int,
|
||||
notification_services: List[int],
|
||||
text: str = '',
|
||||
repeat_quantity: Union[None, RepeatQuantity] = None,
|
||||
repeat_interval: Union[None, int] = None,
|
||||
weekdays: Union[None, List[WEEKDAY_NUMBER]] = None,
|
||||
color: Union[None, str] = None
|
||||
) -> Reminder:
|
||||
"""Add a reminder.
|
||||
|
||||
Args:
|
||||
title (str): The title of the entry.
|
||||
|
||||
time (int): The UTC epoch timestamp the the reminder should be send.
|
||||
|
||||
notification_services (List[int]): The id's of the notification
|
||||
services to use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
repeat_quantity (Union[None, RepeatQuantity], optional): The quantity
|
||||
of the repeat specified for the reminder.
|
||||
Defaults to None.
|
||||
|
||||
repeat_interval (Union[None, int], optional): The amount of
|
||||
repeat_quantity, like "5" (hours).
|
||||
Defaults to None.
|
||||
|
||||
weekdays (Union[None, List[WEEKDAY_NUMBER]], optional): The indexes
|
||||
of the days of the week that the reminder should run.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[None, str], optional): The hex code of the color of the
|
||||
reminder, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Note about args:
|
||||
Either repeat_quantity and repeat_interval are given,
|
||||
weekdays is given or neither, but not both.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was
|
||||
not found.
|
||||
InvalidKeyValue: The value of one of the keys is not valid
|
||||
or the "Note about args" is violated.
|
||||
|
||||
Returns:
|
||||
Reminder: The info about the reminder.
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Adding reminder with {title=}, {time=}, {notification_services=}, ' +
|
||||
f'{text=}, {repeat_quantity=}, {repeat_interval=}, {weekdays=}, {color=}')
|
||||
|
||||
# Validate data
|
||||
if time < datetime.utcnow().timestamp():
|
||||
raise InvalidTime(time)
|
||||
time = round(time)
|
||||
|
||||
if repeat_quantity is None and repeat_interval is not None:
|
||||
raise InvalidKeyValue('repeat_quantity', repeat_quantity)
|
||||
elif repeat_quantity is not None and repeat_interval is None:
|
||||
raise InvalidKeyValue('repeat_interval', repeat_interval)
|
||||
elif (
|
||||
weekdays is not None
|
||||
and repeat_quantity is not None
|
||||
and repeat_interval is not None
|
||||
):
|
||||
raise InvalidKeyValue('weekdays', weekdays)
|
||||
|
||||
# Check if all notification services exist
|
||||
for ns in notification_services:
|
||||
NotificationService(self.user_id, ns)
|
||||
|
||||
# Prepare args
|
||||
if any((repeat_quantity, weekdays)):
|
||||
original_time = time
|
||||
time = find_next_time(
|
||||
original_time,
|
||||
repeat_quantity,
|
||||
repeat_interval,
|
||||
weekdays
|
||||
)
|
||||
else:
|
||||
original_time = None
|
||||
|
||||
weekdays_str = when_not_none(
|
||||
weekdays,
|
||||
lambda w: ",".join(map(str, sorted(w)))
|
||||
)
|
||||
repeat_quantity_str = when_not_none(
|
||||
repeat_quantity,
|
||||
lambda q: q.value
|
||||
)
|
||||
|
||||
new_id = self.reminder_db.add(
|
||||
title, text,
|
||||
time, repeat_quantity_str,
|
||||
repeat_interval,
|
||||
weekdays_str,
|
||||
original_time,
|
||||
color,
|
||||
notification_services
|
||||
)
|
||||
|
||||
REMINDER_HANDLER.find_next_reminder(time)
|
||||
|
||||
return self.fetchone(new_id)
|
||||
|
||||
def test_reminder(
|
||||
self,
|
||||
title: str,
|
||||
notification_services: List[int],
|
||||
text: str = ''
|
||||
) -> SendResult:
|
||||
"""Test send a reminder draft.
|
||||
|
||||
Args:
|
||||
title (str): Title title of the entry.
|
||||
|
||||
notification_service (int): The id of the notification service to
|
||||
use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
Returns:
|
||||
SendResult: Whether or not it was successful.
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Testing reminder with {title=}, {notification_services=}, {text=}'
|
||||
)
|
||||
|
||||
return send_apprise_notification(
|
||||
[
|
||||
NotificationService(self.user_id, ns_id).get().url
|
||||
for ns_id in notification_services
|
||||
],
|
||||
title,
|
||||
text
|
||||
)
|
||||
244
backend/features/static_reminders.py
Normal file
244
backend/features/static_reminders.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from dataclasses import asdict
|
||||
from typing import List, Union
|
||||
|
||||
from backend.base.custom_exceptions import ReminderNotFound
|
||||
from backend.base.definitions import (SendResult, StaticReminderData,
|
||||
TimelessSortingMethod)
|
||||
from backend.base.helpers import search_filter, send_apprise_notification
|
||||
from backend.base.logging import LOGGER
|
||||
from backend.implementations.notification_services import NotificationService
|
||||
from backend.internals.db_models import StaticRemindersDB
|
||||
|
||||
|
||||
class StaticReminder:
|
||||
def __init__(self, user_id: int, reminder_id: int) -> None:
|
||||
"""Represent a static reminder.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
reminder_id (int): The ID of the reminder.
|
||||
|
||||
Raises:
|
||||
ReminderNotFound: Reminder with given ID does not exist or is not
|
||||
owned by user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.id = reminder_id
|
||||
|
||||
self.reminder_db = StaticRemindersDB(self.user_id)
|
||||
|
||||
if not self.reminder_db.exists(self.id):
|
||||
raise ReminderNotFound(reminder_id)
|
||||
return
|
||||
|
||||
def get(self) -> StaticReminderData:
|
||||
"""Get info about the static reminder.
|
||||
|
||||
Returns:
|
||||
StaticReminderData: The info about the static reminder.
|
||||
"""
|
||||
return self.reminder_db.fetch(self.id)[0]
|
||||
|
||||
def trigger_reminder(self) -> SendResult:
|
||||
"""Send the reminder.
|
||||
|
||||
Returns:
|
||||
SendResult: The result of the sending process.
|
||||
"""
|
||||
LOGGER.info(f'Triggering static reminder {self.id}')
|
||||
|
||||
reminder_data = self.get()
|
||||
|
||||
return send_apprise_notification(
|
||||
[
|
||||
NotificationService(self.user_id, ns_id).get().url
|
||||
for ns_id in reminder_data.notification_services
|
||||
],
|
||||
reminder_data.title,
|
||||
reminder_data.text
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
title: Union[str, None] = None,
|
||||
notification_services: Union[List[int], None] = None,
|
||||
text: Union[str, None] = None,
|
||||
color: Union[str, None] = None
|
||||
) -> StaticReminderData:
|
||||
"""Edit the static reminder.
|
||||
|
||||
Args:
|
||||
title (Union[str, None], optional): The new title of the entry.
|
||||
Defaults to None.
|
||||
|
||||
notification_services (Union[List[int], None], optional): The new
|
||||
id's of the notification services to use to send the reminder.
|
||||
Defaults to None.
|
||||
|
||||
text (Union[str, None], optional): The new body of the reminder.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[str, None], optional): The new hex code of the color
|
||||
of the reminder, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was
|
||||
not found.
|
||||
|
||||
Returns:
|
||||
StaticReminderData: The new static reminder info.
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Updating static reminder {self.id}: '
|
||||
+ f'{title=}, {notification_services=}, {text=}, {color=}'
|
||||
)
|
||||
|
||||
if notification_services:
|
||||
# Check whether all notification services exist
|
||||
for ns in notification_services:
|
||||
NotificationService(self.user_id, ns)
|
||||
|
||||
# Get current data and update it with new values
|
||||
data = asdict(self.get())
|
||||
|
||||
new_values = {
|
||||
'title': title,
|
||||
'text': text,
|
||||
'color': color,
|
||||
'notification_services': notification_services
|
||||
}
|
||||
for k, v in new_values.items():
|
||||
if k in ('color',) or v is not None:
|
||||
data[k] = v
|
||||
|
||||
self.reminder_db.update(
|
||||
self.id,
|
||||
data['title'],
|
||||
data['text'],
|
||||
data['color'],
|
||||
data['notification_services']
|
||||
)
|
||||
|
||||
return self.get()
|
||||
|
||||
def delete(self) -> None:
|
||||
"Delete the static reminder"
|
||||
LOGGER.info(f'Deleting static reminder {self.id}')
|
||||
self.reminder_db.delete(self.id)
|
||||
return
|
||||
|
||||
|
||||
class StaticReminders:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.reminder_db = StaticRemindersDB(self.user_id)
|
||||
return
|
||||
|
||||
def fetchall(
|
||||
self,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[StaticReminderData]:
|
||||
"""Get all static reminders.
|
||||
|
||||
Args:
|
||||
sort_by (TimelessSortingMethod, optional): How to sort the result.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[StaticReminderData]: The info of each static reminder.
|
||||
"""
|
||||
reminders = self.reminder_db.fetch()
|
||||
reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1])
|
||||
return reminders
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[StaticReminderData]:
|
||||
"""Search for static reminders.
|
||||
|
||||
Args:
|
||||
query (str): The term to search for.
|
||||
|
||||
sort_by (TimelessSortingMethod, optional): The sorting method of
|
||||
the resulting list.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[StaticReminderData]: All static reminders that match.
|
||||
Similar output to `self.fetchall`
|
||||
"""
|
||||
static_reminders = [
|
||||
r
|
||||
for r in self.fetchall(sort_by)
|
||||
if search_filter(query, r)
|
||||
]
|
||||
return static_reminders
|
||||
|
||||
def fetchone(self, reminder_id: int) -> StaticReminder:
|
||||
"""Get one static reminder.
|
||||
|
||||
Args:
|
||||
reminder_id (int): The id of the static reminder to fetch.
|
||||
|
||||
Raises:
|
||||
ReminderNotFound: The static reminder with the given ID does not
|
||||
exist or is not owned by the user.
|
||||
|
||||
Returns:
|
||||
StaticReminder: A StaticReminder instance.
|
||||
"""
|
||||
return StaticReminder(self.user_id, reminder_id)
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
notification_services: List[int],
|
||||
text: str = '',
|
||||
color: Union[str, None] = None
|
||||
) -> StaticReminder:
|
||||
"""Add a static reminder.
|
||||
|
||||
Args:
|
||||
title (str): The title of the entry.
|
||||
|
||||
notification_services (List[int]): The id's of the
|
||||
notification services to use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
color (Union[str, None], optional): The hex code of the color of the
|
||||
template, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was
|
||||
not found.
|
||||
|
||||
Returns:
|
||||
StaticReminder: The info about the static reminder
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Adding static reminder with {title=}, {notification_services=}, {text=}, {color=}'
|
||||
)
|
||||
|
||||
# Check if all notification services exist
|
||||
for ns in notification_services:
|
||||
NotificationService(self.user_id, ns)
|
||||
|
||||
new_id = self.reminder_db.add(
|
||||
title, text, color,
|
||||
notification_services
|
||||
)
|
||||
|
||||
return self.fetchone(new_id)
|
||||
223
backend/features/templates.py
Normal file
223
backend/features/templates.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from dataclasses import asdict
|
||||
from typing import List, Union
|
||||
|
||||
from backend.base.custom_exceptions import TemplateNotFound
|
||||
from backend.base.definitions import TemplateData, TimelessSortingMethod
|
||||
from backend.base.helpers import search_filter
|
||||
from backend.base.logging import LOGGER
|
||||
from backend.implementations.notification_services import NotificationService
|
||||
from backend.internals.db_models import TemplatesDB
|
||||
|
||||
|
||||
class Template:
|
||||
def __init__(self, user_id: int, template_id: int) -> None:
|
||||
"""Represent a template.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
template_id (int): The ID of the template.
|
||||
|
||||
Raises:
|
||||
TemplateNotFound: Template with given ID does not exist or is not
|
||||
owned by user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.id = template_id
|
||||
|
||||
self.template_db = TemplatesDB(self.user_id)
|
||||
|
||||
if not self.template_db.exists(self.id):
|
||||
raise TemplateNotFound(self.id)
|
||||
return
|
||||
|
||||
def get(self) -> TemplateData:
|
||||
"""Get info about the template.
|
||||
|
||||
Returns:
|
||||
TemplateData: The info about the template.
|
||||
"""
|
||||
return self.template_db.fetch(self.id)[0]
|
||||
|
||||
def update(self,
|
||||
title: Union[str, None] = None,
|
||||
notification_services: Union[List[int], None] = None,
|
||||
text: Union[str, None] = None,
|
||||
color: Union[str, None] = None
|
||||
) -> TemplateData:
|
||||
"""Edit the template.
|
||||
|
||||
Args:
|
||||
title (Union[str, None]): The new title of the entry.
|
||||
Defaults to None.
|
||||
|
||||
notification_services (Union[List[int], None]): The new id's of the
|
||||
notification services to use to send the reminder.
|
||||
Defaults to None.
|
||||
|
||||
text (Union[str, None], optional): The new body of the template.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[str, None], optional): The new hex code of the color of
|
||||
the template, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was
|
||||
not found.
|
||||
|
||||
Returns:
|
||||
TemplateData: The new template info.
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Updating template {self.id}: '
|
||||
+ f'{title=}, {notification_services=}, {text=}, {color=}'
|
||||
)
|
||||
|
||||
if notification_services:
|
||||
# Check if all notification services exist
|
||||
for ns in notification_services:
|
||||
NotificationService(self.user_id, ns)
|
||||
|
||||
data = asdict(self.get())
|
||||
|
||||
new_values = {
|
||||
'title': title,
|
||||
'text': text,
|
||||
'color': color,
|
||||
'notification_services': notification_services
|
||||
}
|
||||
for k, v in new_values.items():
|
||||
if k in ('color',) or v is not None:
|
||||
data[k] = v
|
||||
|
||||
self.template_db.update(
|
||||
self.id,
|
||||
data['title'],
|
||||
data['text'],
|
||||
data['color'],
|
||||
data['notification_services']
|
||||
)
|
||||
|
||||
return self.get()
|
||||
|
||||
def delete(self) -> None:
|
||||
"Delete the template"
|
||||
LOGGER.info(f'Deleting template {self.id}')
|
||||
self.template_db.delete(self.id)
|
||||
return
|
||||
|
||||
|
||||
class Templates:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.template_db = TemplatesDB(self.user_id)
|
||||
return
|
||||
|
||||
def fetchall(
|
||||
self,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[TemplateData]:
|
||||
"""Get all templates of the user.
|
||||
|
||||
Args:
|
||||
sort_by (TimelessSortingMethod, optional): The sorting method of
|
||||
the resulting list.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[TemplateData]: The id, title, text and color of each template.
|
||||
"""
|
||||
templates = self.template_db.fetch()
|
||||
templates.sort(key=sort_by.value[0], reverse=sort_by.value[1])
|
||||
return templates
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[TemplateData]:
|
||||
"""Search for templates.
|
||||
|
||||
Args:
|
||||
query (str): The term to search for.
|
||||
|
||||
sort_by (TimelessSortingMethod, optional): The sorting method of
|
||||
the resulting list.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[TemplateData]: All templates that match. Similar output to
|
||||
`self.fetchall`.
|
||||
"""
|
||||
templates = [
|
||||
r
|
||||
for r in self.fetchall(sort_by)
|
||||
if search_filter(query, r)
|
||||
]
|
||||
return templates
|
||||
|
||||
def fetchone(self, template_id: int) -> Template:
|
||||
"""Get one template.
|
||||
|
||||
Args:
|
||||
template_id (int): The id of the template to fetch.
|
||||
|
||||
Raises:
|
||||
TemplateNotFound: Template with given ID does not exist or is not
|
||||
owned by user.
|
||||
|
||||
Returns:
|
||||
Template: A Template instance.
|
||||
"""
|
||||
return Template(self.user_id, template_id)
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
notification_services: List[int],
|
||||
text: str = '',
|
||||
color: Union[str, None] = None
|
||||
) -> Template:
|
||||
"""Add a template.
|
||||
|
||||
Args:
|
||||
title (str): The title of the entry.
|
||||
|
||||
notification_services (List[int]): The id's of the
|
||||
notification services to use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
color (Union[str, None], optional): The hex code of the color of the
|
||||
template, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was
|
||||
not found.
|
||||
|
||||
Returns:
|
||||
Template: The info about the template.
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Adding template with {title=}, {notification_services=}, {text=}, {color=}'
|
||||
)
|
||||
|
||||
# Check if all notification services exist
|
||||
for ns in notification_services:
|
||||
NotificationService(self.user_id, ns)
|
||||
|
||||
new_id = self.template_db.add(
|
||||
title, text, color,
|
||||
notification_services
|
||||
)
|
||||
|
||||
return self.fetchone(new_id)
|
||||
@@ -1,116 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
General functions
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from os.path import abspath, dirname, join
|
||||
from sys import version_info
|
||||
from typing import Callable, TypeVar, Union
|
||||
|
||||
T = TypeVar('T')
|
||||
U = TypeVar('U')
|
||||
|
||||
def folder_path(*folders) -> str:
|
||||
"""Turn filepaths relative to the project folder into absolute paths
|
||||
|
||||
Returns:
|
||||
str: The absolute filepath
|
||||
"""
|
||||
return join(dirname(dirname(abspath(__file__))), *folders)
|
||||
|
||||
|
||||
def check_python_version() -> bool:
|
||||
"""Check if the python version that is used is a minimum version.
|
||||
|
||||
Returns:
|
||||
bool: Whether or not the python version is version 3.8 or above or not.
|
||||
"""
|
||||
if not (version_info.major == 3 and version_info.minor >= 8):
|
||||
from backend.logging import LOGGER
|
||||
|
||||
LOGGER.critical(
|
||||
'The minimum python version required is python3.8 ' +
|
||||
'(currently ' + str(version_info.major) + '.' + str(version_info.minor) + '.' + str(version_info.micro) + ').'
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def search_filter(query: str, result: dict) -> bool:
|
||||
"""Filter library results based on a query.
|
||||
|
||||
Args:
|
||||
query (str): The query to filter with.
|
||||
result (dict): The library result to check.
|
||||
|
||||
Returns:
|
||||
bool: Whether or not the result passes the filter.
|
||||
"""
|
||||
query = query.lower()
|
||||
return (
|
||||
query in result["title"].lower()
|
||||
or query in result["text"].lower()
|
||||
)
|
||||
|
||||
|
||||
def when_not_none(value: Union[T, None], to_run: Callable[[T], U]) -> Union[U, None]:
|
||||
"""Run `to_run` with argument `value` iff `value is not None`. Else return
|
||||
`None`.
|
||||
|
||||
Args:
|
||||
value (Union[T, None]): The value to check.
|
||||
to_run (Callable[[T], U]): The function to run.
|
||||
|
||||
Returns:
|
||||
Union[U, None]: Either the return value of `to_run`, or `None`.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
return to_run(value)
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
_instances = {}
|
||||
def __call__(cls, *args, **kwargs):
|
||||
c = str(cls)
|
||||
if c not in cls._instances:
|
||||
cls._instances[c] = super().__call__(*args, **kwargs)
|
||||
|
||||
return cls._instances[c]
|
||||
|
||||
|
||||
class BaseEnum(Enum):
|
||||
def __eq__(self, other) -> bool:
|
||||
return self.value == other
|
||||
|
||||
|
||||
class TimelessSortingMethod(BaseEnum):
|
||||
TITLE = (lambda r: (r['title'], r['text'], r['color']), False)
|
||||
TITLE_REVERSED = (lambda r: (r['title'], r['text'], r['color']), True)
|
||||
DATE_ADDED = (lambda r: r['id'], False)
|
||||
DATE_ADDED_REVERSED = (lambda r: r['id'], True)
|
||||
|
||||
|
||||
class SortingMethod(BaseEnum):
|
||||
TIME = (lambda r: (r['time'], r['title'], r['text'], r['color']), False)
|
||||
TIME_REVERSED = (lambda r: (r['time'], r['title'], r['text'], r['color']), True)
|
||||
TITLE = (lambda r: (r['title'], r['time'], r['text'], r['color']), False)
|
||||
TITLE_REVERSED = (lambda r: (r['title'], r['time'], r['text'], r['color']), True)
|
||||
DATE_ADDED = (lambda r: r['id'], False)
|
||||
DATE_ADDED_REVERSED = (lambda r: r['id'], True)
|
||||
|
||||
|
||||
class RepeatQuantity(BaseEnum):
|
||||
YEARS = "years"
|
||||
MONTHS = "months"
|
||||
WEEKS = "weeks"
|
||||
DAYS = "days"
|
||||
HOURS = "hours"
|
||||
MINUTES = "minutes"
|
||||
|
||||
class RestartVars(BaseEnum):
|
||||
DB_IMPORT = "db_import"
|
||||
HOST_CHANGE = "host_change"
|
||||
178
backend/implementations/apprise_parser.py
Normal file
178
backend/implementations/apprise_parser.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from re import compile
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from apprise import Apprise
|
||||
|
||||
from backend.base.helpers import when_not_none
|
||||
|
||||
remove_named_groups = compile(r'(?<=\()\?P<\w+>')
|
||||
IGNORED_ARGS = ('cto', 'format', 'overflow', 'rto', 'verify')
|
||||
|
||||
|
||||
def process_regex(
|
||||
regex: Union[Tuple[str, str], None]
|
||||
) -> Union[Tuple[str, str], None]:
|
||||
return when_not_none(
|
||||
regex,
|
||||
lambda r: (remove_named_groups.sub('', r[0]), r[1])
|
||||
)
|
||||
|
||||
|
||||
def _sort_tokens(t: Dict[str, Any]) -> List[int]:
|
||||
result = [
|
||||
int(not t['required'])
|
||||
]
|
||||
|
||||
if t['name'] == 'Schema':
|
||||
result.append(0)
|
||||
|
||||
if t['type'] == 'choice':
|
||||
result.append(1)
|
||||
|
||||
elif t['type'] != 'list':
|
||||
result.append(2)
|
||||
|
||||
else:
|
||||
result.append(3)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_apprise_services() -> List[Dict[str, Any]]:
|
||||
apprise_services = []
|
||||
|
||||
raw = Apprise().details()['schemas']
|
||||
for entry in raw:
|
||||
result = {
|
||||
'name': str(entry['service_name']),
|
||||
'doc_url': entry['setup_url'],
|
||||
'details': {
|
||||
'templates': entry['details']['templates'],
|
||||
'tokens': [],
|
||||
'args': []
|
||||
}
|
||||
}
|
||||
|
||||
handled_tokens = set()
|
||||
for k, v in entry['details']['tokens'].items():
|
||||
if not v['type'].startswith('list:'):
|
||||
continue
|
||||
|
||||
list_entry = {
|
||||
'name': v['name'],
|
||||
'map_to': k,
|
||||
'required': v['required'],
|
||||
'type': 'list',
|
||||
'delim': v['delim'][0],
|
||||
'content': []
|
||||
}
|
||||
|
||||
for content in v['group']:
|
||||
token = entry['details']['tokens'][content]
|
||||
list_entry['content'].append({
|
||||
'name': token['name'],
|
||||
'required': token['required'],
|
||||
'type': token['type'],
|
||||
'prefix': token.get('prefix'),
|
||||
'regex': process_regex(token.get('regex'))
|
||||
})
|
||||
handled_tokens.add(content)
|
||||
|
||||
result['details']['tokens'].append(list_entry)
|
||||
handled_tokens.add(k)
|
||||
|
||||
for k, v in entry['details']['tokens'].items():
|
||||
if k in handled_tokens:
|
||||
continue
|
||||
|
||||
normal_entry = {
|
||||
'name': v['name'],
|
||||
'map_to': k,
|
||||
'required': v['required'],
|
||||
'type': v['type'].split(':')[0]
|
||||
}
|
||||
|
||||
if v['type'].startswith('choice'):
|
||||
normal_entry.update({
|
||||
'options': v.get('values'),
|
||||
'default': v.get('default')
|
||||
})
|
||||
|
||||
else:
|
||||
normal_entry.update({
|
||||
'prefix': v.get('prefix'),
|
||||
'min': v.get('min'),
|
||||
'max': v.get('max'),
|
||||
'regex': process_regex(v.get('regex'))
|
||||
})
|
||||
|
||||
result['details']['tokens'].append(normal_entry)
|
||||
|
||||
for k, v in entry['details']['args'].items():
|
||||
if (
|
||||
v.get('alias_of') is not None
|
||||
or k in IGNORED_ARGS
|
||||
):
|
||||
continue
|
||||
|
||||
args_entry = {
|
||||
'name': v.get('name', k),
|
||||
'map_to': k,
|
||||
'required': v.get('required', False),
|
||||
'type': v['type'].split(':')[0],
|
||||
}
|
||||
|
||||
if v['type'].startswith('list'):
|
||||
args_entry.update({
|
||||
'delim': v['delim'][0],
|
||||
'content': []
|
||||
})
|
||||
|
||||
elif v['type'].startswith('choice'):
|
||||
args_entry.update({
|
||||
'options': v['values'],
|
||||
'default': v.get('default')
|
||||
})
|
||||
|
||||
elif v['type'] == 'bool':
|
||||
args_entry.update({
|
||||
'default': v['default']
|
||||
})
|
||||
|
||||
else:
|
||||
args_entry.update({
|
||||
'min': v.get('min'),
|
||||
'max': v.get('max'),
|
||||
'regex': process_regex(v.get('regex'))
|
||||
})
|
||||
|
||||
result['details']['args'].append(args_entry)
|
||||
|
||||
result['details']['tokens'].sort(key=_sort_tokens)
|
||||
result['details']['args'].sort(key=_sort_tokens)
|
||||
apprise_services.append(result)
|
||||
|
||||
apprise_services.sort(key=lambda s: s['name'].lower())
|
||||
|
||||
apprise_services.insert(0, {
|
||||
'name': 'Custom URL',
|
||||
'doc_url': 'https://github.com/caronc/apprise#supported-notifications',
|
||||
'details': {
|
||||
'templates': ['{url}'],
|
||||
'tokens': [{
|
||||
'name': 'Apprise URL',
|
||||
'map_to': 'url',
|
||||
'required': True,
|
||||
'type': 'string',
|
||||
'prefix': None,
|
||||
'min': None,
|
||||
'max': None,
|
||||
'regex': None
|
||||
}],
|
||||
'args': []
|
||||
}
|
||||
})
|
||||
|
||||
return apprise_services
|
||||
205
backend/implementations/notification_services.py
Normal file
205
backend/implementations/notification_services.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from dataclasses import asdict
|
||||
from typing import List, Union
|
||||
|
||||
from backend.base.custom_exceptions import (NotificationServiceInUse,
|
||||
NotificationServiceNotFound,
|
||||
URLInvalid)
|
||||
from backend.base.definitions import (Constants, NotificationServiceData,
|
||||
ReminderType, SendResult)
|
||||
from backend.base.helpers import send_apprise_notification
|
||||
from backend.base.logging import LOGGER
|
||||
from backend.internals.db_models import (NotificationServicesDB,
|
||||
ReminderServicesDB)
|
||||
|
||||
|
||||
class NotificationService:
|
||||
def __init__(
|
||||
self,
|
||||
user_id: int,
|
||||
notification_service_id: int
|
||||
) -> None:
|
||||
"""Create an representation of a notification service.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID that the service belongs to.
|
||||
notification_service_id (int): The ID of the service itself.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: The user does not own a notification
|
||||
service with the given ID.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.id = notification_service_id
|
||||
|
||||
self.ns_db = NotificationServicesDB(self.user_id)
|
||||
|
||||
if not self.ns_db.exists(self.id):
|
||||
raise NotificationServiceNotFound(self.id)
|
||||
|
||||
return
|
||||
|
||||
def get(self) -> NotificationServiceData:
|
||||
"""Get the info about the notification service.
|
||||
|
||||
Returns:
|
||||
NotificationServiceData: The info about the notification service.
|
||||
"""
|
||||
return self.ns_db.fetch(self.id)[0]
|
||||
|
||||
def update(
|
||||
self,
|
||||
title: Union[str, None] = None,
|
||||
url: Union[str, None] = None
|
||||
) -> NotificationServiceData:
|
||||
"""Edit the notification service. The URL is tested by sending a test
|
||||
notification to it.
|
||||
|
||||
Args:
|
||||
title (Union[str, None], optional): The new title of the service.
|
||||
Defaults to None.
|
||||
|
||||
url (Union[str, None], optional): The new url of the service.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
NotificationServiceData: The new info about the service.
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Updating notification service {self.id}: {title=}, {url=}'
|
||||
)
|
||||
|
||||
# Get current data and update it with new values
|
||||
data = asdict(self.get())
|
||||
test_url = data["url"] != url
|
||||
|
||||
new_values = {
|
||||
'title': title,
|
||||
'url': url
|
||||
}
|
||||
for k, v in new_values.items():
|
||||
if v is not None:
|
||||
data[k] = v
|
||||
|
||||
if test_url and NotificationServices(self.user_id).test(
|
||||
data['url']
|
||||
) != SendResult.SUCCESS:
|
||||
raise URLInvalid(data['url'])
|
||||
|
||||
self.ns_db.update(self.id, data["title"], data["url"])
|
||||
|
||||
return self.get()
|
||||
|
||||
def delete(
|
||||
self,
|
||||
delete_reminders_using: bool = False
|
||||
) -> None:
|
||||
"""Delete the service.
|
||||
|
||||
Args:
|
||||
delete_reminders_using (bool, optional): Instead of throwing an
|
||||
error when there are still reminders using the service, delete
|
||||
the reminders.
|
||||
Defaults to False.
|
||||
|
||||
Raises:
|
||||
NotificationServiceInUse: The service is still used by a reminder.
|
||||
"""
|
||||
from backend.features.reminders import Reminder
|
||||
from backend.features.static_reminders import StaticReminder
|
||||
from backend.features.templates import Template
|
||||
|
||||
LOGGER.info(f'Deleting notification service {self.id}')
|
||||
|
||||
for r_type, RClass in (
|
||||
(ReminderType.REMINDER, Reminder),
|
||||
(ReminderType.STATIC_REMINDER, StaticReminder),
|
||||
(ReminderType.TEMPLATE, Template)
|
||||
):
|
||||
uses = ReminderServicesDB(r_type).uses_ns(self.id)
|
||||
if uses:
|
||||
if not delete_reminders_using:
|
||||
raise NotificationServiceInUse(
|
||||
self.id,
|
||||
r_type.value
|
||||
)
|
||||
|
||||
for r_id in uses:
|
||||
RClass(self.user_id, r_id).delete()
|
||||
|
||||
self.ns_db.delete(self.id)
|
||||
return
|
||||
|
||||
|
||||
class NotificationServices:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Represent the notification services of a user.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.ns_db = NotificationServicesDB(self.user_id)
|
||||
return
|
||||
|
||||
def fetchall(self) -> List[NotificationServiceData]:
|
||||
"""Get a list of all notification services.
|
||||
|
||||
Returns:
|
||||
List[NotificationServiceData]: The list of all notification services.
|
||||
"""
|
||||
return self.ns_db.fetch()
|
||||
|
||||
def fetchone(self, notification_service_id: int) -> NotificationService:
|
||||
"""Get one notification service based on it's id.
|
||||
|
||||
Args:
|
||||
notification_service_id (int): The id of the desired service.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: The user does not own a notification
|
||||
service with the given ID.
|
||||
|
||||
Returns:
|
||||
NotificationService: Instance of NotificationService.
|
||||
"""
|
||||
return NotificationService(self.user_id, notification_service_id)
|
||||
|
||||
def add(self, title: str, url: str) -> NotificationService:
|
||||
"""Add a notification service. The service is tested by sending a test
|
||||
notification to it.
|
||||
|
||||
Args:
|
||||
title (str): The title of the service.
|
||||
url (str): The apprise url of the service.
|
||||
|
||||
Raises:
|
||||
URLInvalid: The url is invalid.
|
||||
|
||||
Returns:
|
||||
NotificationService: The instance representing the new service.
|
||||
"""
|
||||
LOGGER.info(f'Adding notification service with {title=}, {url=}')
|
||||
|
||||
if self.test(url) != SendResult.SUCCESS:
|
||||
raise URLInvalid(url)
|
||||
|
||||
new_id = self.ns_db.add(title, url)
|
||||
|
||||
return self.fetchone(new_id)
|
||||
|
||||
def test(self, url: str) -> SendResult:
|
||||
"""Test a notification service by sending a test notification to it.
|
||||
|
||||
Args:
|
||||
url (str): The apprise url of the service.
|
||||
|
||||
Returns:
|
||||
SendResult: The result of the test.
|
||||
"""
|
||||
return send_apprise_notification(
|
||||
[url],
|
||||
Constants.APPRISE_TEST_TITLE,
|
||||
Constants.APPRISE_TEST_BODY
|
||||
)
|
||||
292
backend/implementations/users.py
Normal file
292
backend/implementations/users.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from backend.base.custom_exceptions import (AccessUnauthorized,
|
||||
NewAccountsNotAllowed,
|
||||
OperationNotAllowed,
|
||||
UsernameInvalid, UsernameTaken,
|
||||
UserNotFound)
|
||||
from backend.base.definitions import Constants, InvalidUsernameReason, UserData
|
||||
from backend.base.helpers import Singleton, generate_salt_hash, get_hash
|
||||
from backend.base.logging import LOGGER
|
||||
from backend.internals.db_models import UsersDB
|
||||
from backend.internals.settings import Settings
|
||||
|
||||
|
||||
def is_valid_username(username: str) -> None:
|
||||
"""Check if username is valid.
|
||||
|
||||
Args:
|
||||
username (str): The username to check.
|
||||
|
||||
Raises:
|
||||
UsernameInvalid: The username is not valid.
|
||||
"""
|
||||
if username in Constants.INVALID_USERNAMES:
|
||||
raise UsernameInvalid(username, InvalidUsernameReason.NOT_ALLOWED)
|
||||
|
||||
if username.isdigit():
|
||||
raise UsernameInvalid(username, InvalidUsernameReason.ONLY_NUMBERS)
|
||||
|
||||
if any(
|
||||
c not in Constants.USERNAME_CHARACTERS
|
||||
for c in username
|
||||
):
|
||||
raise UsernameInvalid(
|
||||
username,
|
||||
InvalidUsernameReason.INVALID_CHARACTER
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class User:
|
||||
def __init__(self, id: int) -> None:
|
||||
"""Create a representation of a user.
|
||||
|
||||
Args:
|
||||
id (int): The ID of the user.
|
||||
|
||||
Raises:
|
||||
UserNotFound: The user does not exist.
|
||||
"""
|
||||
self.user_db = UsersDB()
|
||||
self.user_id = id
|
||||
|
||||
if not self.user_db.exists(self.user_id):
|
||||
raise UserNotFound(None, id)
|
||||
|
||||
return
|
||||
|
||||
def get(self) -> UserData:
|
||||
"""Get the info about the user.
|
||||
|
||||
Returns:
|
||||
UserData: The info about the user.
|
||||
"""
|
||||
return self.user_db.fetch(self.user_id)[0]
|
||||
|
||||
def update(
|
||||
self,
|
||||
new_username: Union[str, None],
|
||||
new_password: Union[str, None]
|
||||
) -> None:
|
||||
"""Change the username and/or password of the account.
|
||||
|
||||
Args:
|
||||
new_username (Union[str, None]): The new username, or None if it
|
||||
should not be changed.
|
||||
new_password (Union[str, None]): The new password, or None if it
|
||||
should not be changed.
|
||||
|
||||
Raises:
|
||||
OperationNotAllowed: The user is an admin and is trying to change
|
||||
the username.
|
||||
UsernameInvalid: The new username is not valid.
|
||||
UsernameTaken: The new username is already taken.
|
||||
"""
|
||||
user_data = self.get()
|
||||
|
||||
if new_username is not None:
|
||||
if user_data.admin:
|
||||
raise OperationNotAllowed(
|
||||
"Changing the username of an admin account"
|
||||
)
|
||||
|
||||
is_valid_username(new_username)
|
||||
|
||||
if self.user_db.taken(new_username):
|
||||
raise UsernameTaken(new_username)
|
||||
|
||||
self.user_db.update(
|
||||
self.user_id,
|
||||
new_username,
|
||||
user_data.hash
|
||||
)
|
||||
|
||||
LOGGER.info(
|
||||
f"The user with ID {self.user_id} has a changed username: {new_username}"
|
||||
)
|
||||
|
||||
user_data = self.get()
|
||||
|
||||
if new_password is not None:
|
||||
hash_password = get_hash(user_data.salt, new_password)
|
||||
|
||||
self.user_db.update(
|
||||
self.user_id,
|
||||
user_data.username,
|
||||
hash_password
|
||||
)
|
||||
|
||||
LOGGER.info(
|
||||
f'The user with ID {self.user_id} changed their password'
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete the user.
|
||||
|
||||
Raises:
|
||||
OperationNotAllowed: The admin account cannot be deleted.
|
||||
"""
|
||||
user_data = self.get()
|
||||
if user_data.admin:
|
||||
raise OperationNotAllowed(
|
||||
"The admin account cannot be deleted"
|
||||
)
|
||||
|
||||
LOGGER.info(f'Deleting the user with ID {self.user_id}')
|
||||
|
||||
self.user_db.delete(self.user_id)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class Users(metaclass=Singleton):
|
||||
def __init__(self) -> None:
|
||||
self.user_db = UsersDB()
|
||||
return
|
||||
|
||||
def get_all(self) -> List[UserData]:
|
||||
"""Get all user info for the admin
|
||||
|
||||
Returns:
|
||||
List[UserData]: The info about all users
|
||||
"""
|
||||
result = self.user_db.fetch()
|
||||
return result
|
||||
|
||||
def get_one(self, id: int) -> User:
|
||||
"""Get a user instance based on the ID.
|
||||
|
||||
Args:
|
||||
id (int): The ID of the user.
|
||||
|
||||
Returns:
|
||||
User: The user instance.
|
||||
"""
|
||||
return User(id)
|
||||
|
||||
def __contains__(self, username_or_id: Union[str, int]) -> bool:
|
||||
if isinstance(username_or_id, str):
|
||||
return self.username_taken(username_or_id)
|
||||
else:
|
||||
return self.id_taken(username_or_id)
|
||||
|
||||
def username_taken(self, username: str) -> bool:
|
||||
"""Check if a username is taken.
|
||||
|
||||
Args:
|
||||
username (str): The username to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the username is taken, False otherwise.
|
||||
"""
|
||||
return self.user_db.taken(username)
|
||||
|
||||
def id_taken(self, id: int) -> bool:
|
||||
"""Check if a user ID is taken.
|
||||
|
||||
Args:
|
||||
id (int): The user ID to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the user ID is taken, False otherwise.
|
||||
"""
|
||||
return self.user_db.exists(id)
|
||||
|
||||
def login(
|
||||
self,
|
||||
username: str,
|
||||
password: str
|
||||
) -> User:
|
||||
"""Login into an user account.
|
||||
|
||||
Args:
|
||||
username (str): The username of the user.
|
||||
password (str): The password of the user.
|
||||
|
||||
Raises:
|
||||
UserNotFound: There is no user with the given username.
|
||||
AccessUnauthorized: The password is incorrect.
|
||||
|
||||
Returns:
|
||||
User: The user that was logged into.
|
||||
"""
|
||||
if not self.user_db.taken(username):
|
||||
raise UserNotFound(username, None)
|
||||
|
||||
user_data = self.user_db.fetch(
|
||||
self.user_db.username_to_id(username)
|
||||
)[0]
|
||||
|
||||
hash_password = get_hash(user_data.salt, password)
|
||||
# Comparing hashes, not password strings, so no need for
|
||||
# constant time comparison
|
||||
if not hash_password == user_data.hash:
|
||||
raise AccessUnauthorized
|
||||
|
||||
return User(user_data.id)
|
||||
|
||||
def add(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
force: bool = False,
|
||||
is_admin: bool = False
|
||||
) -> int:
|
||||
"""Add a user.
|
||||
|
||||
Args:
|
||||
username (str): The username of the new user.
|
||||
|
||||
password (str): The password of the new user.
|
||||
|
||||
force (bool, optional): Skip check for whether new accounts are
|
||||
allowed.
|
||||
Defaults to False.
|
||||
|
||||
is_admin (bool, optional): The account is the admin account.
|
||||
Defaults to False.
|
||||
|
||||
Raises:
|
||||
UsernameInvalid: Username not allowed or contains invalid characters.
|
||||
UsernameTaken: Username is already taken; usernames must be unique.
|
||||
NewAccountsNotAllowed: In the admin panel, new accounts are set to be
|
||||
not allowed.
|
||||
|
||||
Returns:
|
||||
int: The ID of the new user. User registered successfully.
|
||||
"""
|
||||
LOGGER.info(f'Registering user with username {username}')
|
||||
|
||||
if not force and not Settings().get_settings().allow_new_accounts:
|
||||
raise NewAccountsNotAllowed
|
||||
|
||||
is_valid_username(username)
|
||||
|
||||
if self.user_db.taken(username):
|
||||
raise UsernameTaken(username)
|
||||
|
||||
if is_admin:
|
||||
if self.user_db.taken(Constants.ADMIN_USERNAME):
|
||||
# Attempted to add admin account (only done internally),
|
||||
# but admin account already exists
|
||||
raise RuntimeError("Admin account already exists")
|
||||
|
||||
# Generate salt and key exclusive for user
|
||||
salt, hashed_password = generate_salt_hash(password)
|
||||
|
||||
# Add user to database
|
||||
user_id = self.user_db.add(
|
||||
username,
|
||||
salt,
|
||||
hashed_password,
|
||||
is_admin
|
||||
)
|
||||
|
||||
LOGGER.debug(f'Newly registered user has id {user_id}')
|
||||
return user_id
|
||||
493
backend/internals/db.py
Normal file
493
backend/internals/db.py
Normal file
@@ -0,0 +1,493 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Setting up the database and handling connections
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from os import remove
|
||||
from os.path import dirname, exists, isdir, isfile, join
|
||||
from shutil import move
|
||||
from sqlite3 import (PARSE_DECLTYPES, Connection, Cursor,
|
||||
OperationalError, ProgrammingError, Row,
|
||||
register_adapter, register_converter)
|
||||
from threading import current_thread
|
||||
from typing import Any, Dict, Generator, Iterable, List, Type, Union
|
||||
|
||||
from flask import g
|
||||
|
||||
from backend.base.custom_exceptions import InvalidDatabaseFile
|
||||
from backend.base.definitions import Constants, ReminderType, StartType, T
|
||||
from backend.base.helpers import create_folder, folder_path, rename_file
|
||||
from backend.base.logging import LOGGER, set_log_level
|
||||
from backend.internals.db_migration import get_latest_db_version, migrate_db
|
||||
|
||||
REMINDER_TO_KEY = {
|
||||
ReminderType.REMINDER: "reminder_id",
|
||||
ReminderType.STATIC_REMINDER: "static_reminder_id",
|
||||
ReminderType.TEMPLATE: "template_id"
|
||||
}
|
||||
|
||||
|
||||
class MindCursor(Cursor):
|
||||
|
||||
row_factory: Union[Type[Row], None] # type: ignore
|
||||
|
||||
@property
|
||||
def lastrowid(self) -> int:
|
||||
return super().lastrowid or 1
|
||||
|
||||
def fetchonedict(self) -> Union[Dict[str, Any], None]:
|
||||
"""Same as `fetchone` but convert the Row object to a dict.
|
||||
|
||||
Returns:
|
||||
Union[Dict[str, Any], None]: The dict or None i.c.o. no result.
|
||||
"""
|
||||
r = self.fetchone()
|
||||
if r is None:
|
||||
return r
|
||||
return dict(r)
|
||||
|
||||
def fetchmanydict(self, size: Union[int, None] = 1) -> List[Dict[str, Any]]:
|
||||
"""Same as `fetchmany` but convert the Row object to a dict.
|
||||
|
||||
Args:
|
||||
size (Union[int, None], optional): The amount of rows to return.
|
||||
Defaults to 1.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The rows.
|
||||
"""
|
||||
return [dict(e) for e in self.fetchmany(size)]
|
||||
|
||||
def fetchalldict(self) -> List[Dict[str, Any]]:
|
||||
"""Same as `fetchall` but convert the Row object to a dict.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The results.
|
||||
"""
|
||||
return [dict(e) for e in self]
|
||||
|
||||
def exists(self) -> Union[Any, None]:
|
||||
"""Return the first column of the first row, or `None` if not found.
|
||||
|
||||
Returns:
|
||||
Union[Any, None]: The value of the first column of the first row,
|
||||
or `None` if not found.
|
||||
"""
|
||||
r = self.fetchone()
|
||||
if r is None:
|
||||
return r
|
||||
return r[0]
|
||||
|
||||
|
||||
class DBConnectionManager(type):
|
||||
instances: Dict[int, DBConnection] = {}
|
||||
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> DBConnection:
|
||||
thread_id = current_thread().native_id or -1
|
||||
|
||||
if (
|
||||
not thread_id in cls.instances
|
||||
or cls.instances[thread_id].closed
|
||||
):
|
||||
cls.instances[thread_id] = super().__call__(*args, **kwargs)
|
||||
|
||||
return cls.instances[thread_id]
|
||||
|
||||
|
||||
class DBConnection(Connection, metaclass=DBConnectionManager):
|
||||
file = ''
|
||||
|
||||
def __init__(self, timeout: float) -> None:
|
||||
"""Create a connection with a database.
|
||||
|
||||
Args:
|
||||
timeout (float): How long to wait before giving up on a command.
|
||||
"""
|
||||
LOGGER.debug(f'Creating connection {self}')
|
||||
super().__init__(
|
||||
self.file,
|
||||
timeout=timeout,
|
||||
detect_types=PARSE_DECLTYPES
|
||||
)
|
||||
super().cursor().execute("PRAGMA foreign_keys = ON;")
|
||||
self.closed = False
|
||||
return
|
||||
|
||||
def cursor( # type: ignore
|
||||
self,
|
||||
force_new: bool = False
|
||||
) -> MindCursor:
|
||||
"""Get a database cursor from the connection.
|
||||
|
||||
Args:
|
||||
force_new (bool, optional): Get a new cursor instead of the cached
|
||||
one.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
MindCursor: The database cursor.
|
||||
"""
|
||||
if not hasattr(g, 'cursors'):
|
||||
g.cursors = []
|
||||
|
||||
if not g.cursors:
|
||||
c = MindCursor(self)
|
||||
c.row_factory = Row
|
||||
g.cursors.append(c)
|
||||
|
||||
if not force_new:
|
||||
return g.cursors[0]
|
||||
else:
|
||||
c = MindCursor(self)
|
||||
c.row_factory = Row
|
||||
g.cursors.append(c)
|
||||
return g.cursors[-1]
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the database connection"""
|
||||
LOGGER.debug(f'Closing connection {self}')
|
||||
self.closed = True
|
||||
super().close()
|
||||
return
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'<{self.__class__.__name__}; {current_thread().name}; {id(self)}>'
|
||||
|
||||
|
||||
def set_db_location(
|
||||
db_folder: Union[str, None]
|
||||
) -> None:
|
||||
"""Setup database location. Create folder for database and set location for
|
||||
`db.DBConnection`.
|
||||
|
||||
Args:
|
||||
db_folder (Union[str, None], optional): The folder in which the database
|
||||
will be stored or in which a database is for MIND to use. Give
|
||||
`None` for the default location.
|
||||
|
||||
Raises:
|
||||
ValueError: Value of `db_folder` exists but is not a folder.
|
||||
"""
|
||||
if db_folder:
|
||||
if exists(db_folder) and not isdir(db_folder):
|
||||
raise ValueError('Database location is not a folder')
|
||||
|
||||
db_file_location = join(
|
||||
db_folder or folder_path(*Constants.DB_FOLDER),
|
||||
Constants.DB_NAME
|
||||
)
|
||||
|
||||
LOGGER.debug(f'Setting database location: {db_file_location}')
|
||||
|
||||
create_folder(dirname(db_file_location))
|
||||
|
||||
if isfile(folder_path('db', 'Noted.db')):
|
||||
rename_file(
|
||||
folder_path('db', 'Noted.db'),
|
||||
db_file_location
|
||||
)
|
||||
|
||||
DBConnection.file = db_file_location
|
||||
|
||||
return
|
||||
|
||||
|
||||
def get_db(force_new: bool = False) -> MindCursor:
|
||||
"""Get a database cursor instance or create a new one if needed.
|
||||
|
||||
Args:
|
||||
force_new (bool, optional): Decides if a new cursor is
|
||||
returned instead of the standard one.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
MindCursor: Database cursor instance that outputs Row objects.
|
||||
"""
|
||||
cursor = (
|
||||
DBConnection(timeout=Constants.DB_TIMEOUT)
|
||||
.cursor(force_new=force_new)
|
||||
)
|
||||
return cursor
|
||||
|
||||
|
||||
def commit() -> None:
|
||||
"""Commit the database"""
|
||||
get_db().connection.commit()
|
||||
return
|
||||
|
||||
|
||||
def iter_commit(iterable: Iterable[T]) -> Generator[T, Any, Any]:
|
||||
"""Commit the database after each iteration. Also commits just before the
|
||||
first iteration starts.
|
||||
|
||||
Args:
|
||||
iterable (Iterable[T]): Iterable that will be iterated over like normal.
|
||||
|
||||
Yields:
|
||||
Generator[T, Any, Any]: Items of iterable.
|
||||
"""
|
||||
commit = get_db().connection.commit
|
||||
commit()
|
||||
for i in iterable:
|
||||
yield i
|
||||
commit()
|
||||
return
|
||||
|
||||
|
||||
def close_db(e: Union[None, BaseException] = None) -> None:
|
||||
"""Close database cursor, commit database and close database.
|
||||
|
||||
Args:
|
||||
e (Union[None, BaseException], optional): Error. Defaults to None.
|
||||
"""
|
||||
try:
|
||||
cursors = g.cursors
|
||||
db: DBConnection = cursors[0].connection
|
||||
for c in cursors:
|
||||
c.close()
|
||||
delattr(g, 'cursors')
|
||||
db.commit()
|
||||
if not current_thread().name.startswith('waitress-'):
|
||||
db.close()
|
||||
|
||||
except (AttributeError, ProgrammingError):
|
||||
pass
|
||||
|
||||
return
|
||||
|
||||
|
||||
def close_all_db() -> None:
|
||||
"Close all non-temporary database connections that are still open"
|
||||
LOGGER.debug('Closing any open database connections')
|
||||
|
||||
for i in DBConnectionManager.instances.values():
|
||||
if not i.closed:
|
||||
i.close()
|
||||
|
||||
c = DBConnection(timeout=20.0)
|
||||
c.commit()
|
||||
c.close()
|
||||
return
|
||||
|
||||
|
||||
def setup_db() -> None:
|
||||
"""
|
||||
Setup the database tables and default config when they aren't setup yet
|
||||
"""
|
||||
from backend.implementations.users import Users
|
||||
from backend.internals.settings import Settings
|
||||
|
||||
cursor = get_db()
|
||||
cursor.execute("PRAGMA journal_mode = wal;")
|
||||
register_adapter(bool, lambda b: int(b))
|
||||
register_converter("BOOL", lambda b: b == b'1')
|
||||
|
||||
cursor.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS users(
|
||||
id INTEGER PRIMARY KEY,
|
||||
username VARCHAR(255) UNIQUE NOT NULL,
|
||||
salt VARCHAR(40) NOT NULL,
|
||||
hash VARCHAR(100) NOT NULL,
|
||||
admin BOOL NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS notification_services(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255),
|
||||
url TEXT,
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS reminders(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
time INTEGER NOT NULL,
|
||||
|
||||
repeat_quantity VARCHAR(15),
|
||||
repeat_interval INTEGER,
|
||||
original_time INTEGER,
|
||||
weekdays VARCHAR(13),
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS templates(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS static_reminders(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS reminder_services(
|
||||
reminder_id INTEGER,
|
||||
static_reminder_id INTEGER,
|
||||
template_id INTEGER,
|
||||
notification_service_id INTEGER NOT NULL,
|
||||
|
||||
FOREIGN KEY (reminder_id) REFERENCES reminders(id)
|
||||
ON DELETE CASCADE,
|
||||
FOREIGN KEY (static_reminder_id) REFERENCES static_reminders(id)
|
||||
ON DELETE CASCADE,
|
||||
FOREIGN KEY (template_id) REFERENCES templates(id)
|
||||
ON DELETE CASCADE,
|
||||
FOREIGN KEY (notification_service_id) REFERENCES notification_services(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS config(
|
||||
key VARCHAR(255) PRIMARY KEY,
|
||||
value BLOB NOT NULL
|
||||
);
|
||||
""")
|
||||
|
||||
settings = Settings()
|
||||
settings_values = settings.get_settings()
|
||||
|
||||
set_log_level(settings_values.log_level)
|
||||
|
||||
migrate_db()
|
||||
|
||||
# DB Migration might change settings, so update cache just to be sure.
|
||||
settings._fetch_settings()
|
||||
|
||||
# Add admin user if it doesn't exist
|
||||
users = Users()
|
||||
if Constants.ADMIN_USERNAME not in users:
|
||||
users.add(
|
||||
Constants.ADMIN_USERNAME, Constants.ADMIN_PASSWORD,
|
||||
force=True,
|
||||
is_admin=True
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def revert_db_import(
|
||||
swap: bool,
|
||||
imported_db_file: str = ''
|
||||
) -> None:
|
||||
"""Revert the database import process. The original_db_file is the file
|
||||
currently used (`DBConnection.file`).
|
||||
|
||||
Args:
|
||||
swap (bool): Whether or not to keep the imported_db_file or not,
|
||||
instead of the original_db_file.
|
||||
|
||||
imported_db_file (str, optional): The other database file. Keep empty
|
||||
to use `Constants.DB_ORIGINAL_FILENAME`.
|
||||
Defaults to ''.
|
||||
"""
|
||||
original_db_file = DBConnection.file
|
||||
if not imported_db_file:
|
||||
imported_db_file = join(
|
||||
dirname(DBConnection.file),
|
||||
Constants.DB_ORIGINAL_NAME
|
||||
)
|
||||
|
||||
if swap:
|
||||
remove(original_db_file)
|
||||
move(
|
||||
imported_db_file,
|
||||
original_db_file
|
||||
)
|
||||
|
||||
else:
|
||||
remove(imported_db_file)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def import_db(
|
||||
new_db_file: str,
|
||||
copy_hosting_settings: bool
|
||||
) -> None:
|
||||
"""Replace the current database with a new one.
|
||||
|
||||
Args:
|
||||
new_db_file (str): The path to the new database file.
|
||||
copy_hosting_settings (bool): Keep the hosting settings from the current
|
||||
database.
|
||||
|
||||
Raises:
|
||||
InvalidDatabaseFile: The new database file is invalid or unsupported.
|
||||
"""
|
||||
LOGGER.info(f'Importing new database; {copy_hosting_settings=}')
|
||||
|
||||
cursor = Connection(new_db_file, timeout=20.0).cursor()
|
||||
try:
|
||||
database_version = cursor.execute(
|
||||
"SELECT value FROM config WHERE key = 'database_version' LIMIT 1;"
|
||||
).fetchone()[0]
|
||||
if not isinstance(database_version, int):
|
||||
raise InvalidDatabaseFile(new_db_file)
|
||||
|
||||
except (OperationalError, InvalidDatabaseFile):
|
||||
LOGGER.error('Uploaded database is not a MIND database file')
|
||||
cursor.connection.close()
|
||||
revert_db_import(
|
||||
swap=False,
|
||||
imported_db_file=new_db_file
|
||||
)
|
||||
raise InvalidDatabaseFile(new_db_file)
|
||||
|
||||
if database_version > get_latest_db_version():
|
||||
LOGGER.error(
|
||||
'Uploaded database is higher version than this MIND installation can support')
|
||||
revert_db_import(
|
||||
swap=False,
|
||||
imported_db_file=new_db_file
|
||||
)
|
||||
raise InvalidDatabaseFile(new_db_file)
|
||||
|
||||
if copy_hosting_settings:
|
||||
hosting_settings = get_db().execute("""
|
||||
SELECT key, value
|
||||
FROM config
|
||||
WHERE key = 'host'
|
||||
OR key = 'port'
|
||||
OR key = 'url_prefix'
|
||||
LIMIT 3;
|
||||
"""
|
||||
).fetchalldict()
|
||||
cursor.executemany("""
|
||||
INSERT INTO config(key, value)
|
||||
VALUES (:key, :value)
|
||||
ON CONFLICT(key) DO
|
||||
UPDATE
|
||||
SET value = :value;
|
||||
""",
|
||||
hosting_settings
|
||||
)
|
||||
cursor.connection.commit()
|
||||
cursor.connection.close()
|
||||
|
||||
move(
|
||||
DBConnection.file,
|
||||
join(dirname(DBConnection.file), Constants.DB_ORIGINAL_NAME)
|
||||
)
|
||||
move(
|
||||
new_db_file,
|
||||
DBConnection.file
|
||||
)
|
||||
|
||||
from backend.internals.server import Server
|
||||
Server().restart(StartType.RESTART_DB_CHANGES)
|
||||
|
||||
return
|
||||
312
backend/internals/db_migration.py
Normal file
312
backend/internals/db_migration.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from backend.base.definitions import Constants, DBMigrator
|
||||
from backend.base.logging import LOGGER
|
||||
|
||||
|
||||
class VersionMappingContainer:
|
||||
version_map: Dict[int, Type[DBMigrator]] = {}
|
||||
|
||||
|
||||
def _load_version_map() -> None:
|
||||
if VersionMappingContainer.version_map:
|
||||
return
|
||||
|
||||
VersionMappingContainer.version_map = {
|
||||
m.start_version: m
|
||||
for m in DBMigrator.__subclasses__()
|
||||
}
|
||||
return
|
||||
|
||||
|
||||
def get_latest_db_version() -> int:
|
||||
_load_version_map()
|
||||
return max(VersionMappingContainer.version_map) + 1
|
||||
|
||||
|
||||
def migrate_db() -> None:
|
||||
"""
|
||||
Migrate a MIND database from it's current version
|
||||
to the newest version supported by the MIND version installed.
|
||||
"""
|
||||
from backend.internals.db import iter_commit
|
||||
from backend.internals.settings import Settings
|
||||
|
||||
s = Settings()
|
||||
current_db_version = s.get_settings().database_version
|
||||
newest_version = get_latest_db_version()
|
||||
if current_db_version == newest_version:
|
||||
return
|
||||
|
||||
LOGGER.info('Migrating database to newer version...')
|
||||
LOGGER.debug(
|
||||
"Database migration: %d -> %d",
|
||||
current_db_version, newest_version
|
||||
)
|
||||
|
||||
for start_version in iter_commit(range(current_db_version, newest_version)):
|
||||
if start_version not in VersionMappingContainer.version_map:
|
||||
continue
|
||||
VersionMappingContainer.version_map[start_version]().run()
|
||||
s.update({'database_version': start_version + 1})
|
||||
|
||||
s._fetch_settings()
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateToUTC(DBMigrator):
|
||||
start_version = 1
|
||||
|
||||
def run(self) -> None:
|
||||
# V1 -> V2
|
||||
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
cursor = get_db()
|
||||
|
||||
t = time()
|
||||
utc_offset = datetime.fromtimestamp(t) - datetime.utcfromtimestamp(t)
|
||||
|
||||
cursor.execute("SELECT time, id FROM reminders;")
|
||||
new_reminders = [
|
||||
[
|
||||
round((
|
||||
datetime.fromtimestamp(r["time"]) - utc_offset
|
||||
).timestamp()),
|
||||
r["id"]
|
||||
]
|
||||
for r in cursor
|
||||
]
|
||||
|
||||
cursor.executemany(
|
||||
"UPDATE reminders SET time = ? WHERE id = ?;",
|
||||
new_reminders
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class MigrateAddColor(DBMigrator):
|
||||
start_version = 2
|
||||
|
||||
def run(self) -> None:
|
||||
# V2 -> V3
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
get_db().executescript("""
|
||||
ALTER TABLE reminders
|
||||
ADD color VARCHAR(7);
|
||||
ALTER TABLE templates
|
||||
ADD color VARCHAR(7);
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateFixRQ(DBMigrator):
|
||||
start_version = 3
|
||||
|
||||
def run(self) -> None:
|
||||
# V3 -> V4
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
get_db().executescript("""
|
||||
UPDATE reminders
|
||||
SET repeat_quantity = repeat_quantity || 's'
|
||||
WHERE repeat_quantity NOT LIKE '%s';
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateToReminderServices(DBMigrator):
|
||||
start_version = 4
|
||||
|
||||
def run(self) -> None:
|
||||
# V4 -> V5
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
get_db().executescript("""
|
||||
BEGIN TRANSACTION;
|
||||
PRAGMA defer_foreign_keys = ON;
|
||||
|
||||
CREATE TEMPORARY TABLE temp_reminder_services(
|
||||
reminder_id,
|
||||
static_reminder_id,
|
||||
template_id,
|
||||
notification_service_id
|
||||
);
|
||||
|
||||
-- Reminders
|
||||
INSERT INTO temp_reminder_services(reminder_id, notification_service_id)
|
||||
SELECT id, notification_service
|
||||
FROM reminders;
|
||||
|
||||
CREATE TEMPORARY TABLE temp_reminders AS
|
||||
SELECT id, user_id, title, text, time, repeat_quantity, repeat_interval, original_time, color
|
||||
FROM reminders;
|
||||
DROP TABLE reminders;
|
||||
CREATE TABLE reminders(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
time INTEGER NOT NULL,
|
||||
|
||||
repeat_quantity VARCHAR(15),
|
||||
repeat_interval INTEGER,
|
||||
original_time INTEGER,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
INSERT INTO reminders
|
||||
SELECT * FROM temp_reminders;
|
||||
|
||||
-- Templates
|
||||
INSERT INTO temp_reminder_services(template_id, notification_service_id)
|
||||
SELECT id, notification_service
|
||||
FROM templates;
|
||||
|
||||
CREATE TEMPORARY TABLE temp_templates AS
|
||||
SELECT id, user_id, title, text, color
|
||||
FROM templates;
|
||||
DROP TABLE templates;
|
||||
CREATE TABLE templates(
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
text TEXT,
|
||||
|
||||
color VARCHAR(7),
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
INSERT INTO templates
|
||||
SELECT * FROM temp_templates;
|
||||
|
||||
INSERT INTO reminder_services
|
||||
SELECT * FROM temp_reminder_services;
|
||||
|
||||
COMMIT;
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateRemoveUser1(DBMigrator):
|
||||
start_version = 5
|
||||
|
||||
def run(self) -> None:
|
||||
# V5 -> V6
|
||||
from backend.base.custom_exceptions import (AccessUnauthorized,
|
||||
UserNotFound)
|
||||
from backend.implementations.users import Users
|
||||
|
||||
try:
|
||||
Users().login('User1', 'Password1').delete()
|
||||
|
||||
except (UserNotFound, AccessUnauthorized):
|
||||
pass
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateAddWeekdays(DBMigrator):
|
||||
start_version = 6
|
||||
|
||||
def run(self) -> None:
|
||||
# V6 -> V7
|
||||
|
||||
from backend.internals.db import get_db
|
||||
|
||||
get_db().executescript("""
|
||||
ALTER TABLE reminders
|
||||
ADD weekdays VARCHAR(13);
|
||||
""")
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateAddAdmin(DBMigrator):
|
||||
start_version = 7
|
||||
|
||||
def run(self) -> None:
|
||||
# V7 -> V8
|
||||
|
||||
from backend.implementations.users import Users
|
||||
from backend.internals.db import get_db
|
||||
from backend.internals.settings import Settings
|
||||
|
||||
cursor = get_db()
|
||||
|
||||
cursor.executescript("""
|
||||
DROP TABLE config;
|
||||
CREATE TABLE IF NOT EXISTS config(
|
||||
key VARCHAR(255) PRIMARY KEY,
|
||||
value BLOB NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
Settings()._insert_missing_settings()
|
||||
|
||||
cursor.executescript("""
|
||||
ALTER TABLE users
|
||||
ADD admin BOOL NOT NULL DEFAULT 0;
|
||||
"""
|
||||
)
|
||||
users = Users()
|
||||
if 'admin' in users:
|
||||
users.get_one(
|
||||
users.user_db.username_to_id('admin')
|
||||
).update(
|
||||
new_username='admin_old',
|
||||
new_password=None
|
||||
)
|
||||
|
||||
users.add(
|
||||
Constants.ADMIN_USERNAME, Constants.ADMIN_PASSWORD,
|
||||
force=True,
|
||||
is_admin=True
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class MigrateHostSettingsToDB(DBMigrator):
|
||||
start_version = 8
|
||||
|
||||
def run(self) -> None:
|
||||
# V8 -> V9
|
||||
# In newer versions, the variables don't exist anymore, and behaviour
|
||||
# was to then set the values to the default values. But that's already
|
||||
# taken care of by the settings, so nothing to do here anymore.
|
||||
return
|
||||
|
||||
|
||||
class MigrateUpdateManifest(DBMigrator):
|
||||
start_version = 9
|
||||
|
||||
def run(self) -> None:
|
||||
# V9 -> V10
|
||||
|
||||
# Nothing is changed in the database
|
||||
# It's just that this code needs to run once
|
||||
# and the DB migration system does exactly that:
|
||||
# run pieces of code once.
|
||||
from backend.internals.settings import Settings, update_manifest
|
||||
|
||||
update_manifest(
|
||||
Settings().get_settings().url_prefix
|
||||
)
|
||||
|
||||
return
|
||||
815
backend/internals/db_models.py
Normal file
815
backend/internals/db_models.py
Normal file
@@ -0,0 +1,815 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from backend.base.definitions import (NotificationServiceData, ReminderData,
|
||||
ReminderType, StaticReminderData,
|
||||
TemplateData, UserData)
|
||||
from backend.base.helpers import first_of_column
|
||||
from backend.internals.db import REMINDER_TO_KEY, get_db
|
||||
|
||||
|
||||
class NotificationServicesDB:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
self.user_id = user_id
|
||||
return
|
||||
|
||||
def exists(self, notification_service_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM notification_services
|
||||
WHERE id = :id
|
||||
AND user_id = :user_id
|
||||
LIMIT 1;
|
||||
""",
|
||||
{
|
||||
'user_id': self.user_id,
|
||||
'id': notification_service_id
|
||||
}
|
||||
).fetchone() is not None
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
notification_service_id: Union[int, None] = None
|
||||
) -> List[NotificationServiceData]:
|
||||
id_filter = ""
|
||||
if notification_service_id:
|
||||
id_filter = "AND id = :ns_id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, title, url
|
||||
FROM notification_services
|
||||
WHERE user_id = :user_id
|
||||
{id_filter}
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"ns_id": notification_service_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
return [
|
||||
NotificationServiceData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
url: str
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO notification_services(user_id, title, url)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(self.user_id, title, url)
|
||||
).lastrowid
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
notification_service_id: int,
|
||||
title: str,
|
||||
url: str
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE notification_services
|
||||
SET title = :title, url = :url
|
||||
WHERE id = :ns_id;
|
||||
""",
|
||||
{
|
||||
"title": title,
|
||||
"url": url,
|
||||
"ns_id": notification_service_id
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
notification_service_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM notification_services WHERE id = ?;",
|
||||
(notification_service_id,)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class ReminderServicesDB:
|
||||
def __init__(self, reminder_type: ReminderType) -> None:
|
||||
self.key = REMINDER_TO_KEY[reminder_type]
|
||||
return
|
||||
|
||||
def reminder_to_ns(
|
||||
self,
|
||||
reminder_id: int
|
||||
) -> List[int]:
|
||||
"""Get the ID's of the notification services that are linked to the given
|
||||
reminder, static reminder or template.
|
||||
|
||||
Args:
|
||||
reminder_id (int): The ID of the reminder, static reminder or template.
|
||||
|
||||
Returns:
|
||||
List[int]: A list of the notification service ID's that are linked to
|
||||
the given reminder, static reminder or template.
|
||||
"""
|
||||
result = first_of_column(get_db().execute(
|
||||
f"""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE {self.key} = ?;
|
||||
""",
|
||||
(reminder_id,)
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
def update_ns_bindings(
|
||||
self,
|
||||
reminder_id: int,
|
||||
notification_services: List[int]
|
||||
) -> None:
|
||||
"""Update the bindings of a reminder, static reminder or template to
|
||||
notification services.
|
||||
|
||||
Args:
|
||||
reminder_id (int): The ID of the reminder, static reminder or template.
|
||||
|
||||
notification_services (List[int]): The new list of notification services
|
||||
that should be linked to the reminder, static reminder or template.
|
||||
"""
|
||||
cursor = get_db()
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
|
||||
cursor.execute(
|
||||
f"""
|
||||
DELETE FROM reminder_services
|
||||
WHERE {self.key} = ?;
|
||||
""",
|
||||
(reminder_id,)
|
||||
)
|
||||
cursor.executemany(
|
||||
f"""
|
||||
INSERT INTO reminder_services(
|
||||
{self.key},
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
((reminder_id, ns_id) for ns_id in notification_services)
|
||||
)
|
||||
|
||||
cursor.execute("COMMIT;")
|
||||
cursor.connection.isolation_level = ""
|
||||
return
|
||||
|
||||
def uses_ns(
|
||||
self,
|
||||
notification_service_id: int
|
||||
) -> List[int]:
|
||||
"""Get the ID's of the reminders (of given type) that use the given
|
||||
notification service.
|
||||
|
||||
Args:
|
||||
notification_service_id (int): The ID of the notification service to
|
||||
check for.
|
||||
|
||||
Returns:
|
||||
List[int]: The ID's of the reminders (only of the given type) that
|
||||
use the notification service.
|
||||
"""
|
||||
return first_of_column(get_db().execute(
|
||||
f"""
|
||||
SELECT {self.key}
|
||||
FROM reminder_services
|
||||
WHERE notification_service_id = ?
|
||||
AND {self.key} IS NOT NULL
|
||||
LIMIT 1;
|
||||
""",
|
||||
(notification_service_id,)
|
||||
))
|
||||
|
||||
|
||||
class UsersDB:
|
||||
def exists(self, user_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM users
|
||||
WHERE id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(user_id,)
|
||||
).fetchone() is not None
|
||||
|
||||
def taken(self, username: str) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(username,)
|
||||
).fetchone() is not None
|
||||
|
||||
def username_to_id(self, username: str) -> int:
|
||||
return get_db().execute("""
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(username,)
|
||||
).fetchone()[0]
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
user_id: Union[int, None] = None
|
||||
) -> List[UserData]:
|
||||
id_filter = ""
|
||||
if user_id:
|
||||
id_filter = "WHERE id = :id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, username, admin, salt, hash
|
||||
FROM users
|
||||
{id_filter}
|
||||
ORDER BY username, id;
|
||||
""",
|
||||
{
|
||||
"id": user_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
return [
|
||||
UserData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
username: str,
|
||||
salt: bytes,
|
||||
hash: bytes,
|
||||
admin: bool
|
||||
) -> int:
|
||||
user_id = get_db().execute(
|
||||
"""
|
||||
INSERT INTO users(username, salt, hash, admin)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(username, salt, hash, admin)
|
||||
).lastrowid
|
||||
return user_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
user_id: int,
|
||||
username: str,
|
||||
hash: bytes
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE users
|
||||
SET username = :username, hash = :hash
|
||||
WHERE id = :user_id;
|
||||
""",
|
||||
{
|
||||
"username": username,
|
||||
"hash": hash,
|
||||
"user_id": user_id
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
user_id: int
|
||||
) -> None:
|
||||
get_db().executescript(f"""
|
||||
BEGIN TRANSACTION;
|
||||
|
||||
DELETE FROM reminders WHERE user_id = {user_id};
|
||||
DELETE FROM templates WHERE user_id = {user_id};
|
||||
DELETE FROM static_reminders WHERE user_id = {user_id};
|
||||
DELETE FROM notification_services WHERE user_id = {user_id};
|
||||
DELETE FROM users WHERE id = {user_id};
|
||||
|
||||
COMMIT;
|
||||
""")
|
||||
return
|
||||
|
||||
|
||||
class TemplatesDB:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
self.user_id = user_id
|
||||
self.rms_db = ReminderServicesDB(ReminderType.TEMPLATE)
|
||||
return
|
||||
|
||||
def exists(self, template_id: int) -> bool:
|
||||
return get_db().execute(
|
||||
"SELECT 1 FROM templates WHERE id = ? AND user_id = ? LIMIT 1;",
|
||||
(template_id, self.user_id)
|
||||
).fetchone() is not None
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
template_id: Union[int, None] = None
|
||||
) -> List[TemplateData]:
|
||||
id_filter = ""
|
||||
if template_id:
|
||||
id_filter = "AND id = :t_id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, title, text, color
|
||||
FROM templates
|
||||
WHERE user_id = :user_id
|
||||
{id_filter}
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"t_id": template_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
for r in result:
|
||||
r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])
|
||||
|
||||
return [
|
||||
TemplateData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO templates(user_id, title, text, color)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(self.user_id, title, text, color)
|
||||
).lastrowid
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
new_id, notification_services
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
template_id: int,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE templates
|
||||
SET
|
||||
title = :title,
|
||||
text = :text,
|
||||
color = :color
|
||||
WHERE id = :t_id;
|
||||
""",
|
||||
{
|
||||
"title": title,
|
||||
"text": text,
|
||||
"color": color,
|
||||
"t_id": template_id
|
||||
}
|
||||
)
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
template_id,
|
||||
notification_services
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
template_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM templates WHERE id = ?;",
|
||||
(template_id,)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class StaticRemindersDB:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
self.user_id = user_id
|
||||
self.rms_db = ReminderServicesDB(ReminderType.STATIC_REMINDER)
|
||||
return
|
||||
|
||||
def exists(self, reminder_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM static_reminders
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(reminder_id, self.user_id)
|
||||
).fetchone() is not None
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
reminder_id: Union[int, None] = None
|
||||
) -> List[StaticReminderData]:
|
||||
id_filter = ""
|
||||
if reminder_id:
|
||||
id_filter = "AND id = :r_id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, title, text, color
|
||||
FROM static_reminders
|
||||
WHERE user_id = :user_id
|
||||
{id_filter}
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
for r in result:
|
||||
r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])
|
||||
|
||||
return [
|
||||
StaticReminderData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO static_reminders(user_id, title, text, color)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(self.user_id, title, text, color)
|
||||
).lastrowid
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
new_id, notification_services
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
reminder_id: int,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE static_reminders
|
||||
SET
|
||||
title = :title,
|
||||
text = :text,
|
||||
color = :color
|
||||
WHERE id = :r_id;
|
||||
""",
|
||||
{
|
||||
"title": title,
|
||||
"text": text,
|
||||
"color": color,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
)
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
reminder_id,
|
||||
notification_services
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
reminder_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM static_reminders WHERE id = ?;",
|
||||
(reminder_id,)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class RemindersDB:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
self.user_id = user_id
|
||||
self.rms_db = ReminderServicesDB(ReminderType.REMINDER)
|
||||
return
|
||||
|
||||
def exists(self, reminder_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM reminders
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(reminder_id, self.user_id)
|
||||
).fetchone() is not None
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
reminder_id: Union[int, None] = None
|
||||
) -> List[ReminderData]:
|
||||
id_filter = ""
|
||||
if reminder_id:
|
||||
id_filter = "AND id = :r_id"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, title, text, color,
|
||||
time, original_time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays AS _weekdays
|
||||
FROM reminders
|
||||
WHERE user_id = :user_id
|
||||
{id_filter};
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
for r in result:
|
||||
r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])
|
||||
|
||||
return [
|
||||
ReminderData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
time: int,
|
||||
repeat_quantity: Union[str, None],
|
||||
repeat_interval: Union[int, None],
|
||||
weekdays: Union[str, None],
|
||||
original_time: Union[int, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO reminders(
|
||||
user_id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays,
|
||||
original_time,
|
||||
color
|
||||
)
|
||||
VALUES (
|
||||
:user_id,
|
||||
:title, :text,
|
||||
:time,
|
||||
:rq, :ri,
|
||||
:wd,
|
||||
:ot,
|
||||
:color
|
||||
);
|
||||
""",
|
||||
{
|
||||
"user_id": self.user_id,
|
||||
"title": title,
|
||||
"text": text,
|
||||
"time": time,
|
||||
"rq": repeat_quantity,
|
||||
"ri": repeat_interval,
|
||||
"wd": weekdays,
|
||||
"ot": original_time,
|
||||
"color": color
|
||||
}
|
||||
).lastrowid
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
new_id, notification_services
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
reminder_id: int,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
time: int,
|
||||
repeat_quantity: Union[str, None],
|
||||
repeat_interval: Union[int, None],
|
||||
weekdays: Union[str, None],
|
||||
original_time: Union[int, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE reminders
|
||||
SET
|
||||
title = :title,
|
||||
text = :text,
|
||||
time = :time,
|
||||
repeat_quantity = :rq,
|
||||
repeat_interval = :ri,
|
||||
weekdays = :wd,
|
||||
original_time = :ot,
|
||||
color = :color
|
||||
WHERE id = :r_id;
|
||||
""",
|
||||
{
|
||||
"title": title,
|
||||
"text": text,
|
||||
"time": time,
|
||||
"rq": repeat_quantity,
|
||||
"ri": repeat_interval,
|
||||
"wd": weekdays,
|
||||
"ot": original_time,
|
||||
"color": color,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
)
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
reminder_id,
|
||||
notification_services
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
reminder_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM reminders WHERE id = ?;",
|
||||
(reminder_id,)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
class UserlessRemindersDB:
|
||||
def __init__(self) -> None:
|
||||
self.rms_db = ReminderServicesDB(ReminderType.REMINDER)
|
||||
return
|
||||
|
||||
def exists(self, reminder_id: int) -> bool:
|
||||
return get_db().execute("""
|
||||
SELECT 1
|
||||
FROM reminders
|
||||
WHERE id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(reminder_id,)
|
||||
).fetchone() is not None
|
||||
|
||||
def reminder_id_to_user_id(self, reminder_id: int) -> int:
|
||||
return get_db().execute(
|
||||
"""
|
||||
SELECT user_id
|
||||
FROM reminders
|
||||
WHERE id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(reminder_id,)
|
||||
).exists() or -1
|
||||
|
||||
def get_soonest_time(self) -> Union[int, None]:
|
||||
return get_db().execute("SELECT MIN(time) FROM reminders;").exists()
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
time: Union[int, None] = None
|
||||
) -> List[ReminderData]:
|
||||
time_filter = ""
|
||||
if time:
|
||||
time_filter = "WHERE time = :time"
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id,
|
||||
title, text, color,
|
||||
time, original_time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays AS _weekdays
|
||||
FROM reminders
|
||||
{time_filter};
|
||||
""",
|
||||
{
|
||||
"time": time
|
||||
}
|
||||
).fetchalldict()
|
||||
|
||||
for r in result:
|
||||
r['notification_services'] = self.rms_db.reminder_to_ns(r['id'])
|
||||
|
||||
return [
|
||||
ReminderData(**entry)
|
||||
for entry in result
|
||||
]
|
||||
|
||||
def add(
|
||||
self,
|
||||
user_id: int,
|
||||
title: str,
|
||||
text: Union[str, None],
|
||||
time: int,
|
||||
repeat_quantity: Union[str, None],
|
||||
repeat_interval: Union[int, None],
|
||||
weekdays: Union[str, None],
|
||||
original_time: Union[int, None],
|
||||
color: Union[str, None],
|
||||
notification_services: List[int]
|
||||
) -> int:
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO reminders(
|
||||
user_id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays,
|
||||
original_time,
|
||||
color
|
||||
)
|
||||
VALUES (
|
||||
:user_id,
|
||||
:title, :text,
|
||||
:time,
|
||||
:rq, :ri,
|
||||
:wd,
|
||||
:ot,
|
||||
:color
|
||||
);
|
||||
""",
|
||||
{
|
||||
"user_id": user_id,
|
||||
"title": title,
|
||||
"text": text,
|
||||
"time": time,
|
||||
"rq": repeat_quantity,
|
||||
"ri": repeat_interval,
|
||||
"wd": weekdays,
|
||||
"ot": original_time,
|
||||
"color": color
|
||||
}
|
||||
).lastrowid
|
||||
|
||||
self.rms_db.update_ns_bindings(
|
||||
new_id, notification_services
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
def update(
|
||||
self,
|
||||
reminder_id: int,
|
||||
time: int
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE reminders
|
||||
SET time = :time
|
||||
WHERE id = :r_id;
|
||||
""",
|
||||
{
|
||||
"time": time,
|
||||
"r_id": reminder_id
|
||||
}
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(
|
||||
self,
|
||||
reminder_id: int
|
||||
) -> None:
|
||||
get_db().execute(
|
||||
"DELETE FROM reminders WHERE id = ?;",
|
||||
(reminder_id,)
|
||||
)
|
||||
return
|
||||
247
backend/internals/server.py
Normal file
247
backend/internals/server.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Setting up, running and shutting down the API and web-ui
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from os import urandom
|
||||
from threading import Timer, current_thread
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from flask import Flask, render_template, request
|
||||
from waitress.server import create_server
|
||||
from waitress.task import ThreadedTaskDispatcher as TTD
|
||||
from werkzeug.middleware.dispatcher import DispatcherMiddleware
|
||||
|
||||
from backend.base.definitions import Constants, StartType
|
||||
from backend.base.helpers import Singleton, folder_path
|
||||
from backend.base.logging import LOGGER
|
||||
from backend.internals.db import (DBConnectionManager,
|
||||
close_db, revert_db_import)
|
||||
from backend.internals.settings import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from waitress.server import BaseWSGIServer, MultiSocketServer
|
||||
|
||||
|
||||
class ThreadedTaskDispatcher(TTD):
|
||||
def handler_thread(self, thread_no: int) -> None:
|
||||
super().handler_thread(thread_no)
|
||||
|
||||
thread_id = current_thread().native_id or -1
|
||||
if (
|
||||
thread_id in DBConnectionManager.instances
|
||||
and not DBConnectionManager.instances[thread_id].closed
|
||||
):
|
||||
DBConnectionManager.instances[thread_id].close()
|
||||
|
||||
return
|
||||
|
||||
def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
|
||||
print()
|
||||
LOGGER.info('Shutting down MIND')
|
||||
result = super().shutdown(cancel_pending, timeout)
|
||||
return result
|
||||
|
||||
|
||||
def handle_start_type(start_type: StartType) -> None:
|
||||
"""Do special actions needed based on restart version.
|
||||
|
||||
Args:
|
||||
start_type (StartType): The restart version.
|
||||
"""
|
||||
if start_type == StartType.RESTART_HOSTING_CHANGES:
|
||||
LOGGER.info("Starting timer for hosting changes")
|
||||
Server().revert_hosting_timer.start()
|
||||
|
||||
elif start_type == StartType.RESTART_DB_CHANGES:
|
||||
LOGGER.info("Starting timer for database import")
|
||||
Server().revert_db_timer.start()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def diffuse_timers() -> None:
|
||||
"""Stop any timers running after doing a special restart."""
|
||||
|
||||
SERVER = Server()
|
||||
|
||||
if SERVER.revert_hosting_timer.is_alive():
|
||||
LOGGER.info("Timer for hosting changes diffused")
|
||||
SERVER.revert_hosting_timer.cancel()
|
||||
|
||||
elif SERVER.revert_db_timer.is_alive():
|
||||
LOGGER.info("Timer for database import diffused")
|
||||
SERVER.revert_db_timer.cancel()
|
||||
revert_db_import(swap=False)
|
||||
|
||||
return
|
||||
|
||||
|
||||
class Server(metaclass=Singleton):
|
||||
api_prefix = "/api"
|
||||
admin_api_extension = "/admin"
|
||||
admin_prefix = "/api/admin"
|
||||
url_prefix = ''
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.start_type = None
|
||||
|
||||
self.revert_db_timer = Timer(
|
||||
Constants.DB_REVERT_TIME,
|
||||
revert_db_import,
|
||||
kwargs={"swap": True}
|
||||
)
|
||||
self.revert_db_timer.name = "DatabaseImportHandler"
|
||||
|
||||
self.revert_hosting_timer = Timer(
|
||||
Constants.HOSTING_REVERT_TIME,
|
||||
self.restore_hosting_settings
|
||||
)
|
||||
self.revert_hosting_timer.name = "HostingHandler"
|
||||
|
||||
return
|
||||
|
||||
def create_app(self) -> None:
|
||||
"""Creates an flask app instance that can be used to start a web server"""
|
||||
|
||||
from frontend.api import admin_api, api
|
||||
from frontend.ui import ui
|
||||
|
||||
app = Flask(
|
||||
__name__,
|
||||
template_folder=folder_path('frontend', 'templates'),
|
||||
static_folder=folder_path('frontend', 'static'),
|
||||
static_url_path='/static'
|
||||
)
|
||||
app.config['SECRET_KEY'] = urandom(32)
|
||||
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
|
||||
app.config['JSON_SORT_KEYS'] = False
|
||||
|
||||
# Add error handlers
|
||||
@app.errorhandler(400)
|
||||
def bad_request(e):
|
||||
return {'error': "BadRequest", "result": {}}, 400
|
||||
|
||||
@app.errorhandler(405)
|
||||
def method_not_allowed(e):
|
||||
return {'error': "MethodNotAllowed", "result": {}}, 405
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(e):
|
||||
return {'error': "InternalError", "result": {}}, 500
|
||||
|
||||
# Add endpoints
|
||||
app.register_blueprint(ui)
|
||||
app.register_blueprint(api, url_prefix=self.api_prefix)
|
||||
app.register_blueprint(admin_api, url_prefix=self.admin_prefix)
|
||||
|
||||
# Setup db handling
|
||||
app.teardown_appcontext(close_db)
|
||||
|
||||
self.app = app
|
||||
return
|
||||
|
||||
def set_url_prefix(self, url_prefix: str) -> None:
|
||||
"""Change the URL prefix of the server.
|
||||
|
||||
Args:
|
||||
url_prefix (str): The desired URL prefix to set it to.
|
||||
"""
|
||||
self.app.config["APPLICATION_ROOT"] = url_prefix
|
||||
self.app.wsgi_app = DispatcherMiddleware( # type: ignore
|
||||
Flask(__name__),
|
||||
{url_prefix: self.app.wsgi_app}
|
||||
)
|
||||
self.url_prefix = url_prefix
|
||||
return
|
||||
|
||||
def __create_waitress_server(
|
||||
self,
|
||||
host: str,
|
||||
port: int
|
||||
) -> Union[MultiSocketServer, BaseWSGIServer]:
|
||||
"""From the `Flask` instance created in `self.create_app()`, create
|
||||
a waitress server instance.
|
||||
|
||||
Args:
|
||||
host (str): Where to host the server on (e.g. `0.0.0.0`).
|
||||
port (int): The port to host the server on (e.g. `5656`).
|
||||
|
||||
Returns:
|
||||
Union[MultiSocketServer, BaseWSGIServer]: The waitress server instance.
|
||||
"""
|
||||
dispatcher = ThreadedTaskDispatcher()
|
||||
dispatcher.set_thread_count(Constants.HOSTING_THREADS)
|
||||
|
||||
server = create_server(
|
||||
self.app,
|
||||
_dispatcher=dispatcher,
|
||||
host=host,
|
||||
port=port,
|
||||
threads=Constants.HOSTING_THREADS
|
||||
)
|
||||
return server
|
||||
|
||||
def run(self, host: str, port: int) -> None:
|
||||
"""Start the webserver.
|
||||
|
||||
Args:
|
||||
host (str): Where to host the server on (e.g. `0.0.0.0`).
|
||||
port (int): The port to host the server on (e.g. `5656`).
|
||||
"""
|
||||
self.server = self.__create_waitress_server(host, port)
|
||||
LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
|
||||
self.server.run()
|
||||
|
||||
return
|
||||
|
||||
def __shutdown_thread_function(self) -> None:
|
||||
"""Shutdown waitress server. Intended to be run in a thread.
|
||||
"""
|
||||
if not hasattr(self, 'server'):
|
||||
return
|
||||
|
||||
self.server.task_dispatcher.shutdown()
|
||||
self.server.close()
|
||||
self.server._map.clear() # type: ignore
|
||||
return
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""
|
||||
Stop the waitress server. Starts a thread that shuts down the server.
|
||||
"""
|
||||
t = Timer(1.0, self.__shutdown_thread_function)
|
||||
t.name = "InternalStateHandler"
|
||||
t.start()
|
||||
return
|
||||
|
||||
def restart(
|
||||
self,
|
||||
start_type: StartType = StartType.STARTUP
|
||||
) -> None:
|
||||
"""Same as `self.shutdown()`, but restart instead of shutting down.
|
||||
|
||||
Args:
|
||||
start_type (StartType, optional): Why Kapowarr should
|
||||
restart.
|
||||
Defaults to StartType.STARTUP.
|
||||
"""
|
||||
self.start_type = start_type
|
||||
self.shutdown()
|
||||
return
|
||||
|
||||
def restore_hosting_settings(self) -> None:
|
||||
with self.app.app_context():
|
||||
settings = Settings()
|
||||
values = settings.get_settings()
|
||||
main_settings = {
|
||||
'host': values.backup_host,
|
||||
'port': values.backup_port,
|
||||
'url_prefix': values.backup_url_prefix
|
||||
}
|
||||
settings.update(main_settings)
|
||||
self.restart()
|
||||
return
|
||||
255
backend/internals/settings.py
Normal file
255
backend/internals/settings.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from dataclasses import _MISSING_TYPE, asdict, dataclass
|
||||
from functools import lru_cache
|
||||
from json import dump, load
|
||||
from logging import DEBUG, INFO
|
||||
from typing import Any, Dict, Mapping
|
||||
|
||||
from backend.base.custom_exceptions import InvalidKeyValue, KeyNotFound
|
||||
from backend.base.helpers import (Singleton, folder_path,
|
||||
get_python_version, reversed_tuples)
|
||||
from backend.base.logging import LOGGER, set_log_level
|
||||
from backend.internals.db import DBConnection, commit, get_db
|
||||
from backend.internals.db_migration import get_latest_db_version
|
||||
|
||||
THIRTY_DAYS = 2592000
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def get_about_data() -> Dict[str, Any]:
|
||||
"""Get data about the application and it's environment.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the version is not found in the pyproject.toml file.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The information.
|
||||
"""
|
||||
with open(folder_path("pyproject.toml"), "r") as f:
|
||||
for line in f:
|
||||
if line.startswith("version = "):
|
||||
version = "V" + line.split('"')[1]
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Version not found in pyproject.toml")
|
||||
|
||||
return {
|
||||
"version": version,
|
||||
"python_version": get_python_version(),
|
||||
"database_version": get_latest_db_version(),
|
||||
"database_location": DBConnection.file,
|
||||
"data_folder": folder_path()
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SettingsValues:
|
||||
database_version: int = get_latest_db_version()
|
||||
log_level: int = INFO
|
||||
|
||||
host: str = '0.0.0.0'
|
||||
port: int = 8080
|
||||
url_prefix: str = ''
|
||||
backup_host: str = '0.0.0.0'
|
||||
backup_port: int = 8080
|
||||
backup_url_prefix: str = ''
|
||||
|
||||
allow_new_accounts: bool = True
|
||||
login_time: int = 3600
|
||||
login_time_reset: bool = True
|
||||
|
||||
def todict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if not k.startswith('backup_')
|
||||
}
|
||||
|
||||
|
||||
class Settings(metaclass=Singleton):
|
||||
def __init__(self) -> None:
|
||||
self._insert_missing_settings()
|
||||
self._fetch_settings()
|
||||
return
|
||||
|
||||
def _insert_missing_settings(self) -> None:
|
||||
"Insert any missing keys from the settings into the database."
|
||||
get_db().executemany(
|
||||
"INSERT OR IGNORE INTO config(key, value) VALUES (?, ?);",
|
||||
asdict(SettingsValues()).items()
|
||||
)
|
||||
commit()
|
||||
return
|
||||
|
||||
def _fetch_settings(self) -> None:
|
||||
"Load the settings from the database into the cache."
|
||||
db_values = {
|
||||
k: v
|
||||
for k, v in get_db().execute(
|
||||
"SELECT key, value FROM config;"
|
||||
)
|
||||
if k in SettingsValues.__dataclass_fields__
|
||||
}
|
||||
|
||||
for b_key in ('allow_new_accounts', 'login_time_reset'):
|
||||
db_values[b_key] = bool(db_values[b_key])
|
||||
|
||||
self.__cached_values = SettingsValues(**db_values)
|
||||
return
|
||||
|
||||
def get_settings(self) -> SettingsValues:
|
||||
"""Get the settings from the cache.
|
||||
|
||||
Returns:
|
||||
SettingsValues: The settings.
|
||||
"""
|
||||
return self.__cached_values
|
||||
|
||||
# Alias, better in one-liners
|
||||
# sv = Settings Values
|
||||
@property
|
||||
def sv(self) -> SettingsValues:
|
||||
"""Get the settings from the cache.
|
||||
|
||||
Returns:
|
||||
SettingsValues: The settings.
|
||||
"""
|
||||
return self.__cached_values
|
||||
|
||||
def update(
|
||||
self,
|
||||
data: Mapping[str, Any]
|
||||
) -> None:
|
||||
"""Change the settings, in a `dict.update()` type of way.
|
||||
|
||||
Args:
|
||||
data (Mapping[str, Any]): The keys and their new values.
|
||||
|
||||
Raises:
|
||||
KeyNotFound: Key is not a setting.
|
||||
InvalidKeyValue: Value of the key is not allowed.
|
||||
"""
|
||||
formatted_data = {}
|
||||
for key, value in data.items():
|
||||
formatted_data[key] = self.__format_setting(key, value)
|
||||
|
||||
get_db().executemany(
|
||||
"UPDATE config SET value = ? WHERE key = ?;",
|
||||
reversed_tuples(formatted_data.items())
|
||||
)
|
||||
|
||||
for key, handler in (
|
||||
('url_prefix', update_manifest),
|
||||
('log_level', set_log_level)
|
||||
):
|
||||
if (
|
||||
key in data
|
||||
and formatted_data[key] != getattr(self.get_settings(), key)
|
||||
):
|
||||
handler(formatted_data[key])
|
||||
|
||||
self._fetch_settings()
|
||||
|
||||
LOGGER.info(f"Settings changed: {formatted_data}")
|
||||
|
||||
return
|
||||
|
||||
def reset(self, key: str) -> None:
|
||||
"""Reset the value of the key to the default value.
|
||||
|
||||
Args:
|
||||
key (str): The key of which to reset the value.
|
||||
|
||||
Raises:
|
||||
KeyNotFound: Key is not a setting.
|
||||
"""
|
||||
LOGGER.debug(f'Setting reset: {key}')
|
||||
|
||||
if not isinstance(
|
||||
SettingsValues.__dataclass_fields__[key].default_factory,
|
||||
_MISSING_TYPE
|
||||
):
|
||||
self.update({
|
||||
key: SettingsValues.__dataclass_fields__[key].default_factory()
|
||||
})
|
||||
else:
|
||||
self.update({
|
||||
key: SettingsValues.__dataclass_fields__[key].default
|
||||
})
|
||||
|
||||
return
|
||||
|
||||
def backup_hosting_settings(self) -> None:
|
||||
"Backup the hosting settings in the database."
|
||||
s = self.get_settings()
|
||||
backup_settings = {
|
||||
'backup_host': s.host,
|
||||
'backup_port': s.port,
|
||||
'backup_url_prefix': s.url_prefix
|
||||
}
|
||||
self.update(backup_settings)
|
||||
return
|
||||
|
||||
def __format_setting(self, key: str, value: Any) -> Any:
|
||||
"""Check if the value of a setting is allowed and convert if needed.
|
||||
|
||||
Args:
|
||||
key (str): Key of setting.
|
||||
value (Any): Value of setting.
|
||||
|
||||
Raises:
|
||||
KeyNotFound: Key is not a setting.
|
||||
InvalidKeyValue: Value is not allowed.
|
||||
|
||||
Returns:
|
||||
Any: (Converted) Setting value.
|
||||
"""
|
||||
converted_value = value
|
||||
|
||||
if key not in SettingsValues.__dataclass_fields__:
|
||||
raise KeyNotFound(key)
|
||||
|
||||
key_data = SettingsValues.__dataclass_fields__[key]
|
||||
|
||||
if not isinstance(value, key_data.type):
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
if key == 'login_time':
|
||||
if not 60 <= value <= THIRTY_DAYS:
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
elif key in ('port', 'backup_port'):
|
||||
if not 1 <= value <= 65535:
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
elif key in ('url_prefix', 'backup_url_prefix'):
|
||||
if value:
|
||||
converted_value = ('/' + value.lstrip('/')).rstrip('/')
|
||||
|
||||
elif key == 'log_level':
|
||||
if value not in (INFO, DEBUG):
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
return converted_value
|
||||
|
||||
|
||||
def update_manifest(url_base: str) -> None:
|
||||
"""Update the url's in the manifest file.
|
||||
Needs to happen when url base changes.
|
||||
|
||||
Args:
|
||||
url_base (str): The url base to use in the file.
|
||||
"""
|
||||
filename = folder_path('frontend', 'static', 'json', 'pwa_manifest.json')
|
||||
|
||||
with open(filename, 'r') as f:
|
||||
manifest = load(f)
|
||||
manifest['start_url'] = url_base + '/'
|
||||
manifest['scope'] = url_base + '/'
|
||||
manifest['icons'][0]['src'] = f'{url_base}/static/img/favicon.svg'
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
dump(manifest, f, indent=4)
|
||||
|
||||
return
|
||||
@@ -1,141 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
from os.path import exists
|
||||
from typing import Any
|
||||
|
||||
from backend.helpers import folder_path
|
||||
|
||||
|
||||
class InfoOnlyFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return record.levelno == logging.INFO
|
||||
|
||||
|
||||
class DebuggingOnlyFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return LOGGER.level == logging.DEBUG
|
||||
|
||||
|
||||
class ErrorColorFormatter(logging.Formatter):
|
||||
def format(self, record: logging.LogRecord) -> Any:
|
||||
result = super().format(record)
|
||||
return f'\033[1;31:40m{result}\033[0m'
|
||||
|
||||
|
||||
LOGGER_NAME = "MIND"
|
||||
LOGGER_DEBUG_FILENAME = "MIND_debug.log"
|
||||
LOGGER = logging.getLogger(LOGGER_NAME)
|
||||
LOGGING_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"simple": {
|
||||
"format": "[%(asctime)s][%(levelname)s] %(message)s",
|
||||
"datefmt": "%H:%M:%S"
|
||||
},
|
||||
"simple_red": {
|
||||
"()": ErrorColorFormatter,
|
||||
"format": "[%(asctime)s][%(levelname)s] %(message)s",
|
||||
"datefmt": "%H:%M:%S"
|
||||
},
|
||||
"detailed": {
|
||||
"format": "%(asctime)s | %(threadName)s | %(filename)sL%(lineno)s | %(levelname)s | %(message)s",
|
||||
"datefmt": "%Y-%m-%dT%H:%M:%S%z",
|
||||
}
|
||||
},
|
||||
"filters": {
|
||||
"only_info": {
|
||||
"()": InfoOnlyFilter
|
||||
},
|
||||
"only_if_debugging": {
|
||||
"()": DebuggingOnlyFilter
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console_error": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "WARNING",
|
||||
"formatter": "simple_red",
|
||||
"stream": "ext://sys.stderr"
|
||||
},
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "INFO",
|
||||
"formatter": "simple",
|
||||
"filters": ["only_info"],
|
||||
"stream": "ext://sys.stdout"
|
||||
},
|
||||
"debug_file": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"formatter": "detailed",
|
||||
"filters": ["only_if_debugging"],
|
||||
"stream": ""
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
LOGGER_NAME: {
|
||||
"level": "INFO"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"level": "DEBUG",
|
||||
"handlers": [
|
||||
"console",
|
||||
"console_error",
|
||||
"debug_file"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
def setup_logging() -> None:
|
||||
"Setup the basic config of the logging module"
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
return
|
||||
|
||||
def get_debug_log_filepath() -> str:
|
||||
"""
|
||||
Get the filepath to the debug logging file.
|
||||
Not in a global variable to avoid unnecessary computation.
|
||||
"""
|
||||
return folder_path(LOGGER_DEBUG_FILENAME)
|
||||
|
||||
def set_log_level(
|
||||
level: int,
|
||||
clear_file: bool = True
|
||||
) -> None:
|
||||
"""Change the logging level
|
||||
|
||||
Args:
|
||||
level (int): The level to set the logging to.
|
||||
Should be a logging level, like `logging.INFO` or `logging.DEBUG`.
|
||||
|
||||
clear_file (bool, optional): Empty the debug logging file.
|
||||
Defaults to True.
|
||||
"""
|
||||
if LOGGER.level == level:
|
||||
return
|
||||
|
||||
LOGGER.debug(f'Setting logging level: {level}')
|
||||
LOGGER.setLevel(level)
|
||||
|
||||
if level == logging.DEBUG:
|
||||
stream_handler = logging.getLogger().handlers[
|
||||
LOGGING_CONFIG["root"]["handlers"].index('debug_file')
|
||||
]
|
||||
|
||||
file = get_debug_log_filepath()
|
||||
|
||||
if clear_file:
|
||||
if exists(file):
|
||||
open(file, "w").close()
|
||||
else:
|
||||
open(file, "x").close()
|
||||
|
||||
stream_handler.setStream(
|
||||
open(file, "a")
|
||||
)
|
||||
|
||||
return
|
||||
@@ -1,405 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
from re import compile
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from apprise import Apprise
|
||||
|
||||
from backend.custom_exceptions import (NotificationServiceInUse,
|
||||
NotificationServiceNotFound)
|
||||
from backend.db import get_db
|
||||
from backend.helpers import when_not_none
|
||||
from backend.logging import LOGGER
|
||||
|
||||
remove_named_groups = compile(r'(?<=\()\?P<\w+>')
|
||||
|
||||
def process_regex(regex: Union[List[str], None]) -> Union[None, List[str]]:
|
||||
return when_not_none(
|
||||
regex,
|
||||
lambda r: [remove_named_groups.sub('', r[0]), r[1]]
|
||||
)
|
||||
|
||||
def _sort_tokens(t: dict) -> List[int]:
|
||||
result = [
|
||||
int(not t['required'])
|
||||
]
|
||||
|
||||
if t['type'] == 'choice':
|
||||
result.append(0)
|
||||
elif t['type'] != 'list':
|
||||
result.append(1)
|
||||
else:
|
||||
result.append(2)
|
||||
|
||||
return result
|
||||
|
||||
def get_apprise_services() -> List[Dict[str, Union[str, Dict[str, list]]]]:
|
||||
apprise_services = []
|
||||
raw = Apprise().details()
|
||||
for entry in raw['schemas']:
|
||||
entry: Dict[str, Union[str, dict]]
|
||||
result: Dict[str, Union[str, Dict[str, list]]] = {
|
||||
'name': str(entry['service_name']),
|
||||
'doc_url': entry['setup_url'],
|
||||
'details': {
|
||||
'templates': entry['details']['templates'],
|
||||
'tokens': [],
|
||||
'args': []
|
||||
}
|
||||
}
|
||||
|
||||
schema = entry['details']['tokens']['schema']
|
||||
result['details']['tokens'].append({
|
||||
'name': schema['name'],
|
||||
'map_to': 'schema',
|
||||
'required': schema['required'],
|
||||
'type': 'choice',
|
||||
'options': schema['values'],
|
||||
'default': schema.get('default')
|
||||
})
|
||||
|
||||
handled_tokens = {'schema'}
|
||||
result['details']['tokens'] += [
|
||||
{
|
||||
'name': v['name'],
|
||||
'map_to': k,
|
||||
'required': v['required'],
|
||||
'type': 'list',
|
||||
'delim': v['delim'][0],
|
||||
'content': [
|
||||
{
|
||||
'name': content['name'],
|
||||
'required': content['required'],
|
||||
'type': content['type'],
|
||||
'prefix': content.get('prefix'),
|
||||
'regex': process_regex(content.get('regex'))
|
||||
}
|
||||
for content, _ in (
|
||||
(entry['details']['tokens'][e], handled_tokens.add(e))
|
||||
for e in v['group']
|
||||
)
|
||||
]
|
||||
}
|
||||
for k, v in
|
||||
filter(
|
||||
lambda t: t[1]['type'].startswith('list:'),
|
||||
entry['details']['tokens'].items()
|
||||
)
|
||||
]
|
||||
handled_tokens.update(
|
||||
set(map(lambda e: e[0],
|
||||
filter(lambda e: e[1]['type'].startswith('list:'),
|
||||
entry['details']['tokens'].items())
|
||||
))
|
||||
)
|
||||
|
||||
result['details']['tokens'] += [
|
||||
{
|
||||
'name': v['name'],
|
||||
'map_to': k,
|
||||
'required': v['required'],
|
||||
'type': v['type'].split(':')[0],
|
||||
**({
|
||||
'options': v.get('values'),
|
||||
'default': v.get('default')
|
||||
} if v['type'].startswith('choice') else {
|
||||
'prefix': v.get('prefix'),
|
||||
'min': v.get('min'),
|
||||
'max': v.get('max'),
|
||||
'regex': process_regex(v.get('regex'))
|
||||
})
|
||||
}
|
||||
for k, v in
|
||||
filter(
|
||||
lambda t: not t[0] in handled_tokens,
|
||||
entry['details']['tokens'].items()
|
||||
)
|
||||
]
|
||||
|
||||
result['details']['tokens'].sort(key=_sort_tokens)
|
||||
|
||||
result['details']['args'] += [
|
||||
{
|
||||
'name': v.get('name', k),
|
||||
'map_to': k,
|
||||
'required': v.get('required', False),
|
||||
'type': v['type'].split(':')[0],
|
||||
**({
|
||||
'delim': v['delim'][0],
|
||||
'content': []
|
||||
} if v['type'].startswith('list') else {
|
||||
'options': v['values'],
|
||||
'default': v.get('default')
|
||||
} if v['type'].startswith('choice') else {
|
||||
'default': v['default']
|
||||
} if v['type'] == 'bool' else {
|
||||
'min': v.get('min'),
|
||||
'max': v.get('max'),
|
||||
'regex': process_regex(v.get('regex'))
|
||||
})
|
||||
}
|
||||
for k, v in
|
||||
filter(
|
||||
lambda a: (
|
||||
a[1].get('alias_of') is None
|
||||
and not a[0] in ('cto', 'format', 'overflow', 'rto', 'verify')
|
||||
),
|
||||
entry['details']['args'].items()
|
||||
)
|
||||
]
|
||||
result['details']['args'].sort(key=_sort_tokens)
|
||||
|
||||
apprise_services.append(result)
|
||||
|
||||
apprise_services.sort(key=lambda s: s['name'].lower())
|
||||
|
||||
apprise_services.insert(0, {
|
||||
'name': 'Custom URL',
|
||||
'doc_url': 'https://github.com/caronc/apprise#supported-notifications',
|
||||
'details': {
|
||||
'templates': ['{url}'],
|
||||
'tokens': [{
|
||||
'name': 'Apprise URL',
|
||||
'map_to': 'url',
|
||||
'required': True,
|
||||
'type': 'string',
|
||||
'prefix': None,
|
||||
'min': None,
|
||||
'max': None,
|
||||
'regex': None
|
||||
}],
|
||||
'args': []
|
||||
}
|
||||
})
|
||||
|
||||
return apprise_services
|
||||
|
||||
class NotificationService:
|
||||
def __init__(self, user_id: int, notification_service_id: int) -> None:
|
||||
self.id = notification_service_id
|
||||
|
||||
if not get_db().execute("""
|
||||
SELECT 1
|
||||
FROM notification_services
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(self.id, user_id)
|
||||
).fetchone():
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
def get(self) -> dict:
|
||||
"""Get the info about the notification service
|
||||
|
||||
Returns:
|
||||
dict: The info about the notification service
|
||||
"""
|
||||
result = dict(get_db(dict).execute("""
|
||||
SELECT id, title, url
|
||||
FROM notification_services
|
||||
WHERE id = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(self.id,)
|
||||
).fetchone())
|
||||
|
||||
return result
|
||||
|
||||
def update(
|
||||
self,
|
||||
title: Optional[str] = None,
|
||||
url: Optional[str] = None
|
||||
) -> dict:
|
||||
"""Edit the notification service
|
||||
|
||||
Args:
|
||||
title (Optional[str], optional): The new title of the service. Defaults to None.
|
||||
url (Optional[str], optional): The new url of the service. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: The new info about the service
|
||||
"""
|
||||
LOGGER.info(f'Updating notification service {self.id}: {title=}, {url=}')
|
||||
|
||||
# Get current data and update it with new values
|
||||
data = self.get()
|
||||
new_values = {
|
||||
'title': title,
|
||||
'url': url
|
||||
}
|
||||
for k, v in new_values.items():
|
||||
if v is not None:
|
||||
data[k] = v
|
||||
|
||||
# Update database
|
||||
get_db().execute("""
|
||||
UPDATE notification_services
|
||||
SET title = ?, url = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(
|
||||
data["title"],
|
||||
data["url"],
|
||||
self.id
|
||||
)
|
||||
)
|
||||
|
||||
return self.get()
|
||||
|
||||
def delete(
|
||||
self,
|
||||
delete_reminders_using: bool = False
|
||||
) -> None:
|
||||
"""Delete the service.
|
||||
|
||||
Args:
|
||||
delete_reminders_using (bool, optional): Instead of throwing an
|
||||
error when there are still reminders using the service, delete
|
||||
the reminders.
|
||||
Defaults to False.
|
||||
|
||||
Raises:
|
||||
NotificationServiceInUse: The service is still used by a reminder.
|
||||
"""
|
||||
LOGGER.info(f'Deleting notification service {self.id}')
|
||||
|
||||
cursor = get_db()
|
||||
if not delete_reminders_using:
|
||||
# Check if no reminders exist with this service
|
||||
cursor.execute("""
|
||||
SELECT 1
|
||||
FROM reminder_services
|
||||
WHERE notification_service_id = ?
|
||||
AND reminder_id IS NOT NULL
|
||||
LIMIT 1;
|
||||
""",
|
||||
(self.id,)
|
||||
)
|
||||
if cursor.fetchone():
|
||||
raise NotificationServiceInUse('reminder')
|
||||
|
||||
# Check if no templates exist with this service
|
||||
cursor.execute("""
|
||||
SELECT 1
|
||||
FROM reminder_services
|
||||
WHERE notification_service_id = ?
|
||||
AND template_id IS NOT NULL
|
||||
LIMIT 1;
|
||||
""",
|
||||
(self.id,)
|
||||
)
|
||||
if cursor.fetchone():
|
||||
raise NotificationServiceInUse('template')
|
||||
|
||||
# Check if no static reminders exist with this service
|
||||
cursor.execute("""
|
||||
SELECT 1
|
||||
FROM reminder_services
|
||||
WHERE notification_service_id = ?
|
||||
AND static_reminder_id IS NOT NULL
|
||||
LIMIT 1;
|
||||
""",
|
||||
(self.id,)
|
||||
)
|
||||
if cursor.fetchone():
|
||||
raise NotificationServiceInUse('static reminder')
|
||||
|
||||
else:
|
||||
cursor.execute("""
|
||||
DELETE FROM reminders
|
||||
WHERE id IN (
|
||||
SELECT reminder_id AS id FROM reminder_services
|
||||
WHERE notification_service_id = ?
|
||||
);
|
||||
""", (self.id,))
|
||||
cursor.execute("""
|
||||
DELETE FROM static_reminders
|
||||
WHERE id IN (
|
||||
SELECT static_reminder_id AS id FROM reminder_services
|
||||
WHERE notification_service_id = ?
|
||||
);
|
||||
""", (self.id,))
|
||||
cursor.execute("""
|
||||
DELETE FROM templates
|
||||
WHERE id IN (
|
||||
SELECT template_id AS id FROM reminder_services
|
||||
WHERE notification_service_id = ?
|
||||
);
|
||||
""", (self.id,))
|
||||
|
||||
cursor.execute(
|
||||
"DELETE FROM notification_services WHERE id = ?",
|
||||
(self.id,)
|
||||
)
|
||||
return
|
||||
|
||||
class NotificationServices:
|
||||
def __init__(self, user_id: int) -> None:
|
||||
self.user_id = user_id
|
||||
|
||||
def fetchall(self) -> List[dict]:
|
||||
"""Get a list of all notification services
|
||||
|
||||
Returns:
|
||||
List[dict]: The list of all notification services
|
||||
"""
|
||||
result = list(map(dict, get_db(dict).execute("""
|
||||
SELECT
|
||||
id, title, url
|
||||
FROM notification_services
|
||||
WHERE user_id = ?
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
(self.user_id,)
|
||||
)))
|
||||
|
||||
return result
|
||||
|
||||
def fetchone(self, notification_service_id: int) -> NotificationService:
|
||||
"""Get one notification service based on it's id
|
||||
|
||||
Args:
|
||||
notification_service_id (int): The id of the desired service
|
||||
|
||||
Returns:
|
||||
NotificationService: Instance of NotificationService
|
||||
"""
|
||||
return NotificationService(self.user_id, notification_service_id)
|
||||
|
||||
def add(self, title: str, url: str) -> NotificationService:
|
||||
"""Add a notification service
|
||||
|
||||
Args:
|
||||
title (str): The title of the service
|
||||
url (str): The apprise url of the service
|
||||
|
||||
Returns:
|
||||
NotificationService: The instance representing the new service
|
||||
"""
|
||||
LOGGER.info(f'Adding notification service with {title=}, {url=}')
|
||||
|
||||
new_id = get_db().execute("""
|
||||
INSERT INTO notification_services(user_id, title, url)
|
||||
VALUES (?,?,?)
|
||||
""",
|
||||
(self.user_id, title, url)
|
||||
).lastrowid
|
||||
|
||||
return self.fetchone(new_id)
|
||||
|
||||
def test_service(
|
||||
self,
|
||||
url: str
|
||||
) -> None:
|
||||
"""Send a test notification using the supplied Apprise URL
|
||||
|
||||
Args:
|
||||
url (str): The Apprise URL to use to send the test notification
|
||||
"""
|
||||
LOGGER.info(f'Testing service with {url=}')
|
||||
a = Apprise()
|
||||
a.add(url)
|
||||
a.notify(title='MIND: Test title', body='MIND: Test body')
|
||||
return
|
||||
|
||||
@@ -1,796 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from sqlite3 import IntegrityError
|
||||
from threading import Timer
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
|
||||
|
||||
from apprise import Apprise
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from dateutil.relativedelta import weekday as du_weekday
|
||||
|
||||
from backend.custom_exceptions import (InvalidKeyValue, InvalidTime,
|
||||
NotificationServiceNotFound,
|
||||
ReminderNotFound)
|
||||
from backend.db import get_db
|
||||
from backend.helpers import (RepeatQuantity, Singleton, SortingMethod,
|
||||
search_filter, when_not_none)
|
||||
from backend.logging import LOGGER
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flask.ctx import AppContext
|
||||
|
||||
|
||||
def __next_selected_day(
|
||||
weekdays: List[int],
|
||||
weekday: int
|
||||
) -> int:
|
||||
"""Find the next allowed day in the week.
|
||||
|
||||
Args:
|
||||
weekdays (List[int]): The days of the week that are allowed.
|
||||
Monday is 0, Sunday is 6.
|
||||
weekday (int): The current weekday.
|
||||
|
||||
Returns:
|
||||
int: The next allowed weekday.
|
||||
"""
|
||||
return (
|
||||
# Get all days later than current, then grab first one.
|
||||
[d for d in weekdays if weekday < d]
|
||||
or
|
||||
# weekday is last allowed day, so it should grab the first
|
||||
# allowed day of the week.
|
||||
weekdays
|
||||
)[0]
|
||||
|
||||
def _find_next_time(
|
||||
original_time: int,
|
||||
repeat_quantity: Union[RepeatQuantity, None],
|
||||
repeat_interval: Union[int, None],
|
||||
weekdays: Union[List[int], None]
|
||||
) -> int:
|
||||
"""Calculate the next timestep based on original time and repeat/interval
|
||||
values.
|
||||
|
||||
Args:
|
||||
original_time (int): The original time of the repeating timestamp.
|
||||
|
||||
repeat_quantity (Union[RepeatQuantity, None]): If set, what the quantity
|
||||
is of the repetition.
|
||||
|
||||
repeat_interval (Union[int, None]): If set, the value of the repetition.
|
||||
|
||||
weekdays (Union[List[int], None]): If set, on which days the time can
|
||||
continue. Monday is 0, Sunday is 6.
|
||||
|
||||
Returns:
|
||||
int: The next timestamp in the future.
|
||||
"""
|
||||
if weekdays is not None:
|
||||
weekdays.sort()
|
||||
|
||||
new_time = datetime.fromtimestamp(original_time)
|
||||
current_time = datetime.fromtimestamp(datetime.utcnow().timestamp())
|
||||
|
||||
if repeat_quantity is not None:
|
||||
td = relativedelta(**{repeat_quantity.value: repeat_interval})
|
||||
while new_time <= current_time:
|
||||
new_time += td
|
||||
|
||||
elif weekdays is not None:
|
||||
# We run the loop contents at least once and then actually use the cond.
|
||||
# This is because we need to force the 'free' date to go to one of the
|
||||
# selected weekdays.
|
||||
# Say it's Monday, we set a reminder for Wednesday and make it repeat
|
||||
# on Tuesday and Thursday. Then the first notification needs to go on
|
||||
# Thurday, not Wednesday. So run code at least once to force that.
|
||||
# Afterwards, it can run normally to push the timestamp into the future.
|
||||
one_to_go = True
|
||||
while one_to_go or new_time <= current_time:
|
||||
next_day = __next_selected_day(weekdays, new_time.weekday())
|
||||
proposed_time = new_time + relativedelta(weekday=du_weekday(next_day))
|
||||
if proposed_time == new_time:
|
||||
proposed_time += relativedelta(weekday=du_weekday(next_day, 2))
|
||||
new_time = proposed_time
|
||||
one_to_go = False
|
||||
|
||||
result = int(new_time.timestamp())
|
||||
LOGGER.debug(
|
||||
f'{original_time=}, {current_time=} ' +
|
||||
f'and interval of {repeat_interval} {repeat_quantity} ' +
|
||||
f'leads to {result}'
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class Reminder:
|
||||
"""Represents a reminder
|
||||
"""
|
||||
def __init__(self, user_id: int, reminder_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
reminder_id (int): The ID of the reminder.
|
||||
|
||||
Raises:
|
||||
ReminderNotFound: Reminder with given ID does not exist or is not
|
||||
owned by user.
|
||||
"""
|
||||
self.id = reminder_id
|
||||
|
||||
# Check if reminder exists
|
||||
if not get_db().execute(
|
||||
"SELECT 1 FROM reminders WHERE id = ? AND user_id = ? LIMIT 1",
|
||||
(self.id, user_id)
|
||||
).fetchone():
|
||||
raise ReminderNotFound
|
||||
|
||||
return
|
||||
|
||||
def _get_notification_services(self) -> List[int]:
|
||||
"""Get ID's of notification services linked to the reminder.
|
||||
|
||||
Returns:
|
||||
List[int]: The list with ID's.
|
||||
"""
|
||||
result = [
|
||||
r[0]
|
||||
for r in get_db().execute("""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE reminder_id = ?;
|
||||
""",
|
||||
(self.id,)
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
def get(self) -> dict:
|
||||
"""Get info about the reminder
|
||||
|
||||
Returns:
|
||||
dict: The info about the reminder
|
||||
"""
|
||||
reminder = get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity,
|
||||
repeat_interval,
|
||||
weekdays,
|
||||
color
|
||||
FROM reminders
|
||||
WHERE id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(self.id,)
|
||||
).fetchone()
|
||||
reminder = dict(reminder)
|
||||
|
||||
reminder["weekdays"] = [
|
||||
int(n)
|
||||
for n in reminder["weekdays"].split(",")
|
||||
if n
|
||||
] if reminder["weekdays"] else None
|
||||
reminder['notification_services'] = self._get_notification_services()
|
||||
|
||||
return reminder
|
||||
|
||||
def update(
|
||||
self,
|
||||
title: Union[None, str] = None,
|
||||
time: Union[None, int] = None,
|
||||
notification_services: Union[None, List[int]] = None,
|
||||
text: Union[None, str] = None,
|
||||
repeat_quantity: Union[None, RepeatQuantity] = None,
|
||||
repeat_interval: Union[None, int] = None,
|
||||
weekdays: Union[None, List[int]] = None,
|
||||
color: Union[None, str] = None
|
||||
) -> dict:
|
||||
"""Edit the reminder.
|
||||
|
||||
Args:
|
||||
title (Union[None, str]): The new title of the entry.
|
||||
Defaults to None.
|
||||
|
||||
time (Union[None, int]): The new UTC epoch timestamp when the
|
||||
reminder should be send.
|
||||
Defaults to None.
|
||||
|
||||
notification_services (Union[None, List[int]]): The new list
|
||||
of id's of the notification services to use to send the reminder.
|
||||
Defaults to None.
|
||||
|
||||
text (Union[None, str], optional): The new body of the reminder.
|
||||
Defaults to None.
|
||||
|
||||
repeat_quantity (Union[None, RepeatQuantity], optional): The new
|
||||
quantity of the repeat specified for the reminder.
|
||||
Defaults to None.
|
||||
|
||||
repeat_interval (Union[None, int], optional): The new amount of
|
||||
repeat_quantity, like "5" (hours).
|
||||
Defaults to None.
|
||||
|
||||
weekdays (Union[None, List[int]], optional): The new indexes of
|
||||
the days of the week that the reminder should run.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[None, str], optional): The new hex code of the color
|
||||
of the reminder, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Note about args:
|
||||
Either repeat_quantity and repeat_interval are given, weekdays is
|
||||
given or neither, but not both.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found.
|
||||
InvalidKeyValue: The value of one of the keys is not valid or
|
||||
the "Note about args" is violated.
|
||||
|
||||
Returns:
|
||||
dict: The new reminder info.
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Updating notification service {self.id}: '
|
||||
+ f'{title=}, {time=}, {notification_services=}, {text=}, '
|
||||
+ f'{repeat_quantity=}, {repeat_interval=}, {weekdays=}, {color=}'
|
||||
)
|
||||
cursor = get_db()
|
||||
|
||||
# Validate data
|
||||
if repeat_quantity is None and repeat_interval is not None:
|
||||
raise InvalidKeyValue('repeat_quantity', repeat_quantity)
|
||||
elif repeat_quantity is not None and repeat_interval is None:
|
||||
raise InvalidKeyValue('repeat_interval', repeat_interval)
|
||||
elif weekdays is not None and repeat_quantity is not None:
|
||||
raise InvalidKeyValue('weekdays', weekdays)
|
||||
|
||||
repeated_reminder = (
|
||||
(repeat_quantity is not None and repeat_interval is not None)
|
||||
or weekdays is not None
|
||||
)
|
||||
|
||||
if time is not None:
|
||||
if not repeated_reminder:
|
||||
if time < datetime.utcnow().timestamp():
|
||||
raise InvalidTime
|
||||
time = round(time)
|
||||
|
||||
# Get current data and update it with new values
|
||||
data = self.get()
|
||||
new_values = {
|
||||
'title': title,
|
||||
'time': time,
|
||||
'text': text,
|
||||
'repeat_quantity': repeat_quantity,
|
||||
'repeat_interval': repeat_interval,
|
||||
'weekdays': when_not_none(
|
||||
weekdays,
|
||||
lambda w: ",".join(map(str, sorted(w)))
|
||||
),
|
||||
'color': color
|
||||
}
|
||||
for k, v in new_values.items():
|
||||
if (
|
||||
k in ('repeat_quantity', 'repeat_interval', 'weekdays', 'color')
|
||||
or v is not None
|
||||
):
|
||||
data[k] = v
|
||||
|
||||
# Update database
|
||||
rq = when_not_none(
|
||||
data["repeat_quantity"],
|
||||
lambda q: q.value
|
||||
)
|
||||
if repeated_reminder:
|
||||
next_time = _find_next_time(
|
||||
data["time"],
|
||||
data["repeat_quantity"],
|
||||
data["repeat_interval"],
|
||||
weekdays
|
||||
)
|
||||
cursor.execute("""
|
||||
UPDATE reminders
|
||||
SET
|
||||
title=?,
|
||||
text=?,
|
||||
time=?,
|
||||
repeat_quantity=?,
|
||||
repeat_interval=?,
|
||||
weekdays=?,
|
||||
original_time=?,
|
||||
color=?
|
||||
WHERE id = ?;
|
||||
""", (
|
||||
data["title"],
|
||||
data["text"],
|
||||
next_time,
|
||||
rq,
|
||||
data["repeat_interval"],
|
||||
data["weekdays"],
|
||||
data["time"],
|
||||
data["color"],
|
||||
self.id
|
||||
))
|
||||
|
||||
else:
|
||||
next_time = data["time"]
|
||||
cursor.execute("""
|
||||
UPDATE reminders
|
||||
SET
|
||||
title=?,
|
||||
text=?,
|
||||
time=?,
|
||||
repeat_quantity=?,
|
||||
repeat_interval=?,
|
||||
weekdays=?,
|
||||
color=?
|
||||
WHERE id = ?;
|
||||
""", (
|
||||
data["title"],
|
||||
data["text"],
|
||||
data["time"],
|
||||
rq,
|
||||
data["repeat_interval"],
|
||||
data["weekdays"],
|
||||
data["color"],
|
||||
self.id
|
||||
))
|
||||
|
||||
if notification_services:
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
cursor.execute(
|
||||
"DELETE FROM reminder_services WHERE reminder_id = ?",
|
||||
(self.id,)
|
||||
)
|
||||
try:
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
reminder_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
((self.id, s) for s in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
ReminderHandler().find_next_reminder(next_time)
|
||||
return self.get()
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete the reminder
|
||||
"""
|
||||
LOGGER.info(f'Deleting reminder {self.id}')
|
||||
get_db().execute("DELETE FROM reminders WHERE id = ?", (self.id,))
|
||||
ReminderHandler().find_next_reminder()
|
||||
return
|
||||
|
||||
class Reminders:
|
||||
"""Represents the reminder library of the user account
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
return
|
||||
|
||||
def fetchall(
|
||||
self,
|
||||
sort_by: SortingMethod = SortingMethod.TIME
|
||||
) -> List[dict]:
|
||||
"""Get all reminders
|
||||
|
||||
Args:
|
||||
sort_by (SortingMethod, optional): How to sort the result.
|
||||
Defaults to SortingMethod.TIME.
|
||||
|
||||
Returns:
|
||||
List[dict]: The id, title, text, time and color of each reminder
|
||||
"""
|
||||
reminders = [
|
||||
dict(r)
|
||||
for r in get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity,
|
||||
repeat_interval,
|
||||
weekdays,
|
||||
color
|
||||
FROM reminders
|
||||
WHERE user_id = ?;
|
||||
""",
|
||||
(self.user_id,)
|
||||
)
|
||||
]
|
||||
for r in reminders:
|
||||
r["weekdays"] = [
|
||||
int(n)
|
||||
for n in r["weekdays"].split(",")
|
||||
if n
|
||||
] if r["weekdays"] else None
|
||||
|
||||
# Sort result
|
||||
reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1])
|
||||
|
||||
return reminders
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
sort_by: SortingMethod = SortingMethod.TIME) -> List[dict]:
|
||||
"""Search for reminders
|
||||
|
||||
Args:
|
||||
query (str): The term to search for.
|
||||
sort_by (SortingMethod, optional): How to sort the result.
|
||||
Defaults to SortingMethod.TIME.
|
||||
|
||||
Returns:
|
||||
List[dict]: All reminders that match. Similar output to self.fetchall
|
||||
"""
|
||||
reminders = [
|
||||
r for r in self.fetchall(sort_by)
|
||||
if search_filter(query, r)
|
||||
]
|
||||
return reminders
|
||||
|
||||
def fetchone(self, id: int) -> Reminder:
|
||||
"""Get one reminder
|
||||
|
||||
Args:
|
||||
id (int): The id of the reminder to fetch
|
||||
|
||||
Returns:
|
||||
Reminder: A Reminder instance
|
||||
"""
|
||||
return Reminder(self.user_id, id)
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
time: int,
|
||||
notification_services: List[int],
|
||||
text: str = '',
|
||||
repeat_quantity: Union[None, RepeatQuantity] = None,
|
||||
repeat_interval: Union[None, int] = None,
|
||||
weekdays: Union[None, List[int]] = None,
|
||||
color: Union[None, str] = None
|
||||
) -> Reminder:
|
||||
"""Add a reminder
|
||||
|
||||
Args:
|
||||
title (str): The title of the entry.
|
||||
|
||||
time (int): The UTC epoch timestamp the the reminder should be send.
|
||||
|
||||
notification_services (List[int]): The id's of the notification services
|
||||
to use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
repeat_quantity (Union[None, RepeatQuantity], optional): The quantity
|
||||
of the repeat specified for the reminder.
|
||||
Defaults to None.
|
||||
|
||||
repeat_interval (Union[None, int], optional): The amount of repeat_quantity,
|
||||
like "5" (hours).
|
||||
Defaults to None.
|
||||
|
||||
weekdays (Union[None, List[int]], optional): The indexes of the days
|
||||
of the week that the reminder should run.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[None, str], optional): The hex code of the color of the
|
||||
reminder, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Note about args:
|
||||
Either repeat_quantity and repeat_interval are given,
|
||||
weekdays is given or neither, but not both.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found.
|
||||
InvalidKeyValue: The value of one of the keys is not valid
|
||||
or the "Note about args" is violated.
|
||||
|
||||
Returns:
|
||||
dict: The info about the reminder.
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Adding reminder with {title=}, {time=}, {notification_services=}, '
|
||||
+ f'{text=}, {repeat_quantity=}, {repeat_interval=}, {weekdays=}, {color=}'
|
||||
)
|
||||
|
||||
if time < datetime.utcnow().timestamp():
|
||||
raise InvalidTime
|
||||
time = round(time)
|
||||
|
||||
if repeat_quantity is None and repeat_interval is not None:
|
||||
raise InvalidKeyValue('repeat_quantity', repeat_quantity)
|
||||
elif repeat_quantity is not None and repeat_interval is None:
|
||||
raise InvalidKeyValue('repeat_interval', repeat_interval)
|
||||
elif (
|
||||
weekdays is not None
|
||||
and repeat_quantity is not None
|
||||
and repeat_interval is not None
|
||||
):
|
||||
raise InvalidKeyValue('weekdays', weekdays)
|
||||
|
||||
cursor = get_db()
|
||||
for service in notification_services:
|
||||
if not cursor.execute("""
|
||||
SELECT 1
|
||||
FROM notification_services
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(service, self.user_id)
|
||||
).fetchone():
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
# Prepare args
|
||||
if any((repeat_quantity, weekdays)):
|
||||
original_time = time
|
||||
time = _find_next_time(
|
||||
original_time,
|
||||
repeat_quantity,
|
||||
repeat_interval,
|
||||
weekdays
|
||||
)
|
||||
else:
|
||||
original_time = None
|
||||
|
||||
weekdays_str = when_not_none(
|
||||
weekdays,
|
||||
lambda w: ",".join(map(str, sorted(w)))
|
||||
)
|
||||
repeat_quantity_str = when_not_none(
|
||||
repeat_quantity,
|
||||
lambda q: q.value
|
||||
)
|
||||
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
|
||||
id = cursor.execute("""
|
||||
INSERT INTO reminders(
|
||||
user_id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays,
|
||||
original_time,
|
||||
color
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""", (
|
||||
self.user_id,
|
||||
title, text,
|
||||
time,
|
||||
repeat_quantity_str,
|
||||
repeat_interval,
|
||||
weekdays_str,
|
||||
original_time,
|
||||
color
|
||||
)).lastrowid
|
||||
|
||||
try:
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
reminder_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
((id, service) for service in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
finally:
|
||||
cursor.connection.isolation_level = ''
|
||||
|
||||
ReminderHandler().find_next_reminder(time)
|
||||
|
||||
return self.fetchone(id)
|
||||
|
||||
def test_reminder(
|
||||
self,
|
||||
title: str,
|
||||
notification_services: List[int],
|
||||
text: str = ''
|
||||
) -> None:
|
||||
"""Test send a reminder draft.
|
||||
|
||||
Args:
|
||||
title (str): Title title of the entry.
|
||||
|
||||
notification_service (int): The id of the notification service to
|
||||
use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
"""
|
||||
LOGGER.info(f'Testing reminder with {title=}, {notification_services=}, {text=}')
|
||||
a = Apprise()
|
||||
cursor = get_db(dict)
|
||||
|
||||
for service in notification_services:
|
||||
url = cursor.execute("""
|
||||
SELECT url
|
||||
FROM notification_services
|
||||
WHERE id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(service, self.user_id)
|
||||
).fetchone()
|
||||
if not url:
|
||||
raise NotificationServiceNotFound
|
||||
a.add(url[0])
|
||||
|
||||
a.notify(title=title, body=text or '\u200B')
|
||||
return
|
||||
|
||||
|
||||
class ReminderHandler(metaclass=Singleton):
|
||||
"""Handle set reminders.
|
||||
|
||||
Note: Singleton.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
context: Callable[[], AppContext]
|
||||
) -> None:
|
||||
"""Create instance of handler.
|
||||
|
||||
Args:
|
||||
context (Optional[AppContext], optional): `Flask.app_context`.
|
||||
Defaults to None.
|
||||
"""
|
||||
self.context = context
|
||||
self.thread: Union[Timer, None] = None
|
||||
self.time: Union[int, None] = None
|
||||
return
|
||||
|
||||
def __trigger_reminders(self, time: int) -> None:
|
||||
"""Trigger all reminders that are set for a certain time
|
||||
|
||||
Args:
|
||||
time (int): The time of the reminders to trigger
|
||||
"""
|
||||
with self.context():
|
||||
cursor = get_db(dict)
|
||||
reminders = [
|
||||
dict(r)
|
||||
for r in cursor.execute("""
|
||||
SELECT
|
||||
id, user_id,
|
||||
title, text,
|
||||
repeat_quantity, repeat_interval,
|
||||
weekdays,
|
||||
original_time
|
||||
FROM reminders
|
||||
WHERE time = ?;
|
||||
""",
|
||||
(time,)
|
||||
)
|
||||
]
|
||||
|
||||
for reminder in reminders:
|
||||
cursor.execute("""
|
||||
SELECT url
|
||||
FROM reminder_services rs
|
||||
INNER JOIN notification_services ns
|
||||
ON rs.notification_service_id = ns.id
|
||||
WHERE rs.reminder_id = ?;
|
||||
""",
|
||||
(reminder['id'],)
|
||||
)
|
||||
|
||||
# Send reminder
|
||||
a = Apprise()
|
||||
for url in cursor:
|
||||
a.add(url['url'])
|
||||
a.notify(title=reminder["title"], body=reminder["text"] or '\u200B')
|
||||
|
||||
self.thread = None
|
||||
self.time = None
|
||||
|
||||
if (reminder['repeat_quantity'], reminder['weekdays']) == (None, None):
|
||||
# Delete the reminder from the database
|
||||
Reminder(reminder["user_id"], reminder["id"]).delete()
|
||||
|
||||
else:
|
||||
# Set next time
|
||||
new_time = _find_next_time(
|
||||
reminder['original_time'],
|
||||
when_not_none(
|
||||
reminder["repeat_quantity"],
|
||||
lambda q: RepeatQuantity(q)
|
||||
),
|
||||
reminder['repeat_interval'],
|
||||
when_not_none(
|
||||
reminder["weekdays"],
|
||||
lambda w: [int(d) for d in w.split(',')]
|
||||
)
|
||||
)
|
||||
cursor.execute(
|
||||
"UPDATE reminders SET time = ? WHERE id = ?;",
|
||||
(new_time, reminder['id'])
|
||||
)
|
||||
|
||||
self.find_next_reminder()
|
||||
return
|
||||
|
||||
def find_next_reminder(self, time: Optional[int] = None) -> None:
|
||||
"""Determine when the soonest reminder is and set the timer to that time
|
||||
|
||||
Args:
|
||||
time (Optional[int], optional): The timestamp to check for.
|
||||
Otherwise check soonest in database.
|
||||
Defaults to None.
|
||||
"""
|
||||
if time is None:
|
||||
with self.context():
|
||||
soonest_time: Union[Tuple[int], None] = get_db().execute("""
|
||||
SELECT DISTINCT r1.time
|
||||
FROM reminders r1
|
||||
LEFT JOIN reminders r2
|
||||
ON r1.time > r2.time
|
||||
WHERE r2.id IS NULL;
|
||||
""").fetchone()
|
||||
if soonest_time is None:
|
||||
return
|
||||
time = soonest_time[0]
|
||||
|
||||
if (
|
||||
self.thread is None
|
||||
or time < self.time
|
||||
):
|
||||
if self.thread is not None:
|
||||
self.thread.cancel()
|
||||
|
||||
t = time - datetime.utcnow().timestamp()
|
||||
self.thread = Timer(
|
||||
t,
|
||||
self.__trigger_reminders,
|
||||
(time,)
|
||||
)
|
||||
self.thread.name = "ReminderHandler"
|
||||
self.thread.start()
|
||||
self.time = time
|
||||
|
||||
return
|
||||
|
||||
def stop_handling(self) -> None:
|
||||
"""Stop the timer if it's active
|
||||
"""
|
||||
if self.thread is not None:
|
||||
self.thread.cancel()
|
||||
return
|
||||
@@ -1,40 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Hashing and salting
|
||||
"""
|
||||
|
||||
from base64 import urlsafe_b64encode
|
||||
from hashlib import pbkdf2_hmac
|
||||
from secrets import token_bytes
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def get_hash(salt: bytes, data: str) -> bytes:
|
||||
"""Hash a string using the supplied salt
|
||||
|
||||
Args:
|
||||
salt (bytes): The salt to use when hashing
|
||||
data (str): The data to hash
|
||||
|
||||
Returns:
|
||||
bytes: The b64 encoded hash of the supplied string
|
||||
"""
|
||||
return urlsafe_b64encode(
|
||||
pbkdf2_hmac('sha256', data.encode(), salt, 100_000)
|
||||
)
|
||||
|
||||
def generate_salt_hash(password: str) -> Tuple[bytes, bytes]:
|
||||
"""Generate a salt and get the hash of the password
|
||||
|
||||
Args:
|
||||
password (str): The password to generate for
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, bytes]: The salt (1) and hashed_password (2)
|
||||
"""
|
||||
salt = token_bytes()
|
||||
hashed_password = get_hash(salt, password)
|
||||
del password
|
||||
|
||||
return salt, hashed_password
|
||||
@@ -1,264 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from os import execv, urandom
|
||||
from sys import argv
|
||||
from threading import Timer, current_thread
|
||||
from typing import TYPE_CHECKING, List, NoReturn, Union
|
||||
|
||||
from flask import Flask, render_template, request
|
||||
from waitress import create_server
|
||||
from waitress.task import ThreadedTaskDispatcher as TTD
|
||||
from werkzeug.middleware.dispatcher import DispatcherMiddleware
|
||||
|
||||
from backend.db import DB_Singleton, DBConnection, close_db, revert_db_import
|
||||
from backend.helpers import RestartVars, Singleton, folder_path
|
||||
from backend.logging import LOGGER
|
||||
from backend.settings import restore_hosting_settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from waitress.server import TcpWSGIServer
|
||||
|
||||
THREADS = 10
|
||||
|
||||
class ThreadedTaskDispatcher(TTD):
|
||||
def handler_thread(self, thread_no: int) -> None:
|
||||
super().handler_thread(thread_no)
|
||||
i = f'{DBConnection}{current_thread()}'
|
||||
if i in DB_Singleton._instances and not DB_Singleton._instances[i].closed:
|
||||
DB_Singleton._instances[i].close()
|
||||
return
|
||||
|
||||
def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool:
|
||||
print()
|
||||
LOGGER.info('Shutting down MIND')
|
||||
result = super().shutdown(cancel_pending, timeout)
|
||||
DBConnection(timeout=20.0).close()
|
||||
return result
|
||||
|
||||
|
||||
class Server(metaclass=Singleton):
|
||||
api_prefix = "/api"
|
||||
admin_api_extension = "/admin"
|
||||
admin_prefix = "/api/admin"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.do_restart = False
|
||||
"Restart instead of shutdown"
|
||||
|
||||
self.restart_args: List[str] = []
|
||||
"Flag to run with when restarting"
|
||||
|
||||
self.handle_flags: bool = False
|
||||
"Run any flag specific actions before restarting"
|
||||
|
||||
self.url_prefix = ""
|
||||
|
||||
self.revert_db_timer = Timer(60.0, self.__revert_db)
|
||||
self.revert_db_timer.name = "DatabaseImportHandler"
|
||||
self.revert_hosting_timer = Timer(60.0, self.__revert_hosting)
|
||||
self.revert_hosting_timer.name = "HostingHandler"
|
||||
|
||||
return
|
||||
|
||||
def create_app(self) -> None:
|
||||
"""Create a Flask app instance"""
|
||||
from frontend.api import admin_api, api
|
||||
from frontend.ui import ui
|
||||
|
||||
app = Flask(
|
||||
__name__,
|
||||
template_folder=folder_path('frontend','templates'),
|
||||
static_folder=folder_path('frontend','static'),
|
||||
static_url_path='/static'
|
||||
)
|
||||
app.config['SECRET_KEY'] = urandom(32)
|
||||
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
|
||||
app.config['JSON_SORT_KEYS'] = False
|
||||
|
||||
# Add error handlers
|
||||
@app.errorhandler(400)
|
||||
def bad_request(e):
|
||||
return {'error': 'Bad request', 'result': {}}, 400
|
||||
|
||||
@app.errorhandler(405)
|
||||
def method_not_allowed(e):
|
||||
return {'error': 'Method not allowed', 'result': {}}, 405
|
||||
|
||||
@app.errorhandler(500)
|
||||
def internal_error(e):
|
||||
return {'error': 'Internal error', 'result': {}}, 500
|
||||
|
||||
@app.errorhandler(404)
|
||||
def not_found(e):
|
||||
if request.path.startswith(self.api_prefix):
|
||||
return {'error': 'Not Found', 'result': {}}, 404
|
||||
return render_template('page_not_found.html', url_prefix=self.url_prefix)
|
||||
|
||||
app.register_blueprint(ui)
|
||||
app.register_blueprint(api, url_prefix=self.api_prefix)
|
||||
app.register_blueprint(admin_api, url_prefix=self.admin_prefix)
|
||||
|
||||
# Setup closing database
|
||||
app.teardown_appcontext(close_db)
|
||||
|
||||
self.app = app
|
||||
return
|
||||
|
||||
def set_url_prefix(self, url_prefix: str) -> None:
|
||||
"""Change the URL prefix of the server.
|
||||
|
||||
Args:
|
||||
url_prefix (str): The desired URL prefix to set it to.
|
||||
"""
|
||||
self.app.config["APPLICATION_ROOT"] = url_prefix
|
||||
self.app.wsgi_app = DispatcherMiddleware(
|
||||
Flask(__name__),
|
||||
{url_prefix: self.app.wsgi_app}
|
||||
)
|
||||
self.url_prefix = url_prefix
|
||||
return
|
||||
|
||||
def __create_waitress_server(
|
||||
self,
|
||||
host: str,
|
||||
port: int
|
||||
) -> TcpWSGIServer:
|
||||
"""From the `Flask` instance created in `self.create_app()`, create
|
||||
a waitress server instance.
|
||||
|
||||
Args:
|
||||
host (str): The host to bind to.
|
||||
port (int): The port to listen on.
|
||||
|
||||
Returns:
|
||||
TcpWSGIServer: The waitress server.
|
||||
"""
|
||||
dispatcher = ThreadedTaskDispatcher()
|
||||
dispatcher.set_thread_count(THREADS)
|
||||
server = create_server(
|
||||
self.app,
|
||||
_dispatcher=dispatcher,
|
||||
host=host,
|
||||
port=port,
|
||||
threads=THREADS
|
||||
)
|
||||
return server
|
||||
|
||||
def run(self, host: str, port: int) -> None:
|
||||
"""Start the webserver.
|
||||
|
||||
Args:
|
||||
host (str): The host to bind to.
|
||||
port (int): The port to listen on.
|
||||
"""
|
||||
self.server = self.__create_waitress_server(host, port)
|
||||
LOGGER.info(f'MIND running on http://{host}:{port}{self.url_prefix}')
|
||||
self.server.run()
|
||||
|
||||
return
|
||||
|
||||
def __shutdown_thread_function(self) -> None:
|
||||
"""Shutdown waitress server. Intended to be run in a thread.
|
||||
"""
|
||||
self.server.close()
|
||||
self.server.task_dispatcher.shutdown()
|
||||
self.server._map.clear()
|
||||
return
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Stop the waitress server. Starts a thread that
|
||||
shuts down the server.
|
||||
"""
|
||||
t = Timer(1.0, self.__shutdown_thread_function)
|
||||
t.name = "InternalStateHandler"
|
||||
t.start()
|
||||
return
|
||||
|
||||
def restart(
|
||||
self,
|
||||
restart_args: List[str] = [],
|
||||
handle_flags: bool = False
|
||||
) -> None:
|
||||
"""Same as `self.shutdown()`, but restart instead of shutting down.
|
||||
|
||||
Args:
|
||||
restart_args (List[str], optional): Any arguments to run the new instance with.
|
||||
Defaults to [].
|
||||
|
||||
handle_flags (bool, optional): Run flag specific actions just before restarting.
|
||||
Defaults to False.
|
||||
"""
|
||||
self.do_restart = True
|
||||
self.restart_args = restart_args
|
||||
self.handle_flags = handle_flags
|
||||
self.shutdown()
|
||||
return
|
||||
|
||||
def handle_restart(self, flag: Union[str, None]) -> NoReturn:
|
||||
"""Restart the interpreter.
|
||||
|
||||
Args:
|
||||
flag (Union[str, None]): Supplied flag, for flag handling.
|
||||
|
||||
Returns:
|
||||
NoReturn: No return because it replaces the interpreter.
|
||||
"""
|
||||
if self.handle_flags:
|
||||
handle_flags_pre_restart(flag)
|
||||
|
||||
LOGGER.info('Restarting MIND')
|
||||
from MIND import __file__ as mind_file
|
||||
execv(folder_path(mind_file), [argv[0], *self.restart_args])
|
||||
|
||||
def __revert_db(self) -> None:
|
||||
"""Revert database import and restart.
|
||||
"""
|
||||
LOGGER.warning(f'Timer for database import expired; reverting back to original file')
|
||||
self.restart(handle_flags=True)
|
||||
return
|
||||
|
||||
def __revert_hosting(self) -> None:
|
||||
"""Revert the hosting changes.
|
||||
"""
|
||||
LOGGER.warning(f'Timer for hosting changes expired; reverting back to original settings')
|
||||
self.restart(handle_flags=True)
|
||||
return
|
||||
|
||||
|
||||
SERVER = Server()
|
||||
|
||||
|
||||
def handle_flags(flag: Union[None, str]) -> None:
|
||||
"""Run flag specific actions on startup.
|
||||
|
||||
Args:
|
||||
flag (Union[None, str]): The flag or `None` if there is no flag set.
|
||||
"""
|
||||
if flag == RestartVars.DB_IMPORT:
|
||||
LOGGER.info('Starting timer for database import')
|
||||
SERVER.revert_db_timer.start()
|
||||
|
||||
elif flag == RestartVars.HOST_CHANGE:
|
||||
LOGGER.info('Starting timer for hosting changes')
|
||||
SERVER.revert_hosting_timer.start()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def handle_flags_pre_restart(flag: Union[None, str]) -> None:
|
||||
"""Run flag specific actions just before restarting.
|
||||
|
||||
Args:
|
||||
flag (Union[None, str]): The flag or `None` if there is no flag set.
|
||||
"""
|
||||
if flag == RestartVars.DB_IMPORT:
|
||||
revert_db_import(swap=True)
|
||||
|
||||
elif flag == RestartVars.HOST_CHANGE:
|
||||
with SERVER.app.app_context():
|
||||
restore_hosting_settings()
|
||||
close_db()
|
||||
|
||||
return
|
||||
@@ -1,245 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Getting and setting settings
|
||||
"""
|
||||
|
||||
import logging
|
||||
from json import dump, load
|
||||
from typing import Any
|
||||
|
||||
from backend.custom_exceptions import InvalidKeyValue, KeyNotFound
|
||||
from backend.db import __DATABASE_VERSION__, get_db
|
||||
from backend.helpers import folder_path
|
||||
from backend.logging import set_log_level
|
||||
|
||||
default_settings = {
|
||||
'allow_new_accounts': True,
|
||||
'login_time': 3600,
|
||||
'login_time_reset': True,
|
||||
|
||||
'database_version': __DATABASE_VERSION__,
|
||||
|
||||
'host': '0.0.0.0',
|
||||
'port': 8080,
|
||||
'url_prefix': '',
|
||||
|
||||
'log_level': logging.INFO
|
||||
}
|
||||
|
||||
def _format_setting(key: str, value):
|
||||
"""Turn python value in to database value.
|
||||
|
||||
Args:
|
||||
key (str): The key of the value.
|
||||
value (Any): The value itself.
|
||||
|
||||
Raises:
|
||||
InvalidKeyValue: The value is not valid.
|
||||
|
||||
Returns:
|
||||
Any: The converted value.
|
||||
"""
|
||||
if key == 'database_version':
|
||||
try:
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
elif key in ('allow_new_accounts', 'login_time_reset'):
|
||||
if not isinstance(value, bool):
|
||||
raise InvalidKeyValue(key, value)
|
||||
value = int(value)
|
||||
|
||||
elif key == 'login_time':
|
||||
if not isinstance(value, int) or not 60 <= value <= 2592000:
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
elif key == 'host':
|
||||
if not isinstance(value, str):
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
elif key == 'port':
|
||||
if not isinstance(value, int) or not 1 <= value <= 65535:
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
elif key == 'url_prefix':
|
||||
if not isinstance(value, str):
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
if value == '/':
|
||||
value = ''
|
||||
|
||||
elif value:
|
||||
value = '/' + value.strip('/')
|
||||
|
||||
elif key == 'log_level' and not value in (logging.INFO, logging.DEBUG):
|
||||
raise InvalidKeyValue(key, value)
|
||||
|
||||
return value
|
||||
|
||||
def _reverse_format_setting(key: str, value: Any) -> Any:
|
||||
"""Turn database value in to python value.
|
||||
|
||||
Args:
|
||||
key (str): The key of the value.
|
||||
value (Any): The value itself.
|
||||
|
||||
Returns:
|
||||
Any: The converted value.
|
||||
"""
|
||||
if key in ('allow_new_accounts', 'login_time_reset'):
|
||||
value = value == 1
|
||||
|
||||
elif key in ('log_level', 'database_version', 'login_time'):
|
||||
value = int(value)
|
||||
|
||||
return value
|
||||
|
||||
def get_setting(key: str) -> Any:
|
||||
"""Get a value from the config.
|
||||
|
||||
Args:
|
||||
key (str): The key of which to get the value.
|
||||
|
||||
Raises:
|
||||
KeyNotFound: Key is not in config.
|
||||
|
||||
Returns:
|
||||
Any: The value of the key.
|
||||
"""
|
||||
result = get_db().execute(
|
||||
"SELECT value FROM config WHERE key = ? LIMIT 1;",
|
||||
(key,)
|
||||
).fetchone()
|
||||
if result is None:
|
||||
raise KeyNotFound(key)
|
||||
|
||||
result = _reverse_format_setting(key, result[0])
|
||||
|
||||
return result
|
||||
|
||||
def get_admin_settings() -> dict:
|
||||
"""Get all admin settings
|
||||
|
||||
Returns:
|
||||
dict: The admin settings
|
||||
"""
|
||||
return dict((
|
||||
(key, _reverse_format_setting(key, value))
|
||||
for key, value in get_db().execute("""
|
||||
SELECT key, value
|
||||
FROM config
|
||||
WHERE
|
||||
key = 'allow_new_accounts'
|
||||
OR key = 'login_time'
|
||||
OR key = 'login_time_reset'
|
||||
OR key = 'host'
|
||||
OR key = 'port'
|
||||
OR key = 'url_prefix'
|
||||
OR key = 'log_level';
|
||||
"""
|
||||
)
|
||||
))
|
||||
|
||||
def set_setting(key: str, value: Any) -> None:
|
||||
"""Set a value in the config
|
||||
|
||||
Args:
|
||||
key (str): The key for which to set the value
|
||||
value (Any): The value to give to the key
|
||||
|
||||
Raises:
|
||||
KeyNotFound: The key is not in the config
|
||||
InvalidKeyValue: The value is not allowed for the key
|
||||
"""
|
||||
if not key in (*default_settings, 'database_version'):
|
||||
raise KeyNotFound(key)
|
||||
|
||||
value = _format_setting(key, value)
|
||||
|
||||
get_db().execute(
|
||||
"UPDATE config SET value = ? WHERE key = ?;",
|
||||
(value, key)
|
||||
)
|
||||
|
||||
if key == 'url_prefix':
|
||||
update_manifest(value)
|
||||
|
||||
elif key == 'log_level':
|
||||
set_log_level(value)
|
||||
|
||||
return
|
||||
|
||||
def update_manifest(url_base: str) -> None:
|
||||
"""Update the url's in the manifest file.
|
||||
Needs to happen when url base changes.
|
||||
|
||||
Args:
|
||||
url_base (str): The url base to use in the file.
|
||||
"""
|
||||
filename = folder_path('frontend', 'static', 'json', 'pwa_manifest.json')
|
||||
|
||||
with open(filename, 'r') as f:
|
||||
manifest = load(f)
|
||||
manifest['start_url'] = url_base + '/'
|
||||
manifest['icons'][0]['src'] = f'{url_base}/static/img/favicon.svg'
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
dump(manifest, f, indent=4)
|
||||
|
||||
return
|
||||
|
||||
def backup_hosting_settings() -> None:
|
||||
"""Copy current hosting settings to backup values.
|
||||
"""
|
||||
cursor = get_db()
|
||||
hosting_settings = dict(cursor.execute("""
|
||||
SELECT key, value
|
||||
FROM config
|
||||
WHERE key = 'host'
|
||||
OR key = 'port'
|
||||
OR key = 'url_prefix'
|
||||
LIMIT 3;
|
||||
"""
|
||||
))
|
||||
hosting_settings = {f'{k}_backup': v for k, v in hosting_settings.items()}
|
||||
|
||||
cursor.executemany("""
|
||||
INSERT INTO config(key, value)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(key) DO
|
||||
UPDATE
|
||||
SET value = ?;
|
||||
""",
|
||||
((k, v, v) for k, v in hosting_settings.items())
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def restore_hosting_settings() -> None:
|
||||
"""Copy the hosting settings from the backup over to the main keys.
|
||||
"""
|
||||
cursor = get_db()
|
||||
hosting_settings = dict(cursor.execute("""
|
||||
SELECT key, value
|
||||
FROM config
|
||||
WHERE key = 'host_backup'
|
||||
OR key = 'port_backup'
|
||||
OR key = 'url_prefix_backup'
|
||||
LIMIT 3;
|
||||
"""
|
||||
))
|
||||
if len(hosting_settings) < 3:
|
||||
return
|
||||
|
||||
hosting_settings = {k.split('_backup')[0]: v for k, v in hosting_settings.items()}
|
||||
|
||||
cursor.executemany(
|
||||
"UPDATE config SET value = ? WHERE key = ?",
|
||||
((v, k) for k, v in hosting_settings.items())
|
||||
)
|
||||
|
||||
update_manifest(hosting_settings['url_prefix'])
|
||||
|
||||
return
|
||||
@@ -1,356 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
from sqlite3 import IntegrityError
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from apprise import Apprise
|
||||
|
||||
from backend.custom_exceptions import (NotificationServiceNotFound,
|
||||
ReminderNotFound)
|
||||
from backend.db import get_db
|
||||
from backend.helpers import TimelessSortingMethod, search_filter
|
||||
from backend.logging import LOGGER
|
||||
|
||||
|
||||
class StaticReminder:
|
||||
"""Represents a static reminder
|
||||
"""
|
||||
def __init__(self, user_id: int, reminder_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
reminder_id (int): The ID of the reminder.
|
||||
|
||||
Raises:
|
||||
ReminderNotFound: Reminder with given ID does not exist or is not
|
||||
owned by user.
|
||||
"""
|
||||
self.id = reminder_id
|
||||
|
||||
# Check if reminder exists
|
||||
if not get_db().execute(
|
||||
"SELECT 1 FROM static_reminders WHERE id = ? AND user_id = ? LIMIT 1;",
|
||||
(self.id, user_id)
|
||||
).fetchone():
|
||||
raise ReminderNotFound
|
||||
|
||||
return
|
||||
|
||||
def _get_notification_services(self) -> List[int]:
|
||||
"""Get ID's of notification services linked to the static reminder.
|
||||
|
||||
Returns:
|
||||
List[int]: The list with ID's.
|
||||
"""
|
||||
result = [
|
||||
r[0]
|
||||
for r in get_db().execute("""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE static_reminder_id = ?;
|
||||
""",
|
||||
(self.id,)
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
def get(self) -> dict:
|
||||
"""Get info about the static reminder
|
||||
|
||||
Returns:
|
||||
dict: The info about the static reminder
|
||||
"""
|
||||
reminder = get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
color
|
||||
FROM static_reminders
|
||||
WHERE id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(self.id,)
|
||||
).fetchone()
|
||||
reminder = dict(reminder)
|
||||
|
||||
reminder['notification_services'] = self._get_notification_services()
|
||||
|
||||
return reminder
|
||||
|
||||
def update(
|
||||
self,
|
||||
title: Union[str, None] = None,
|
||||
notification_services: Union[List[int], None] = None,
|
||||
text: Union[str, None] = None,
|
||||
color: Union[str, None] = None
|
||||
) -> dict:
|
||||
"""Edit the static reminder.
|
||||
|
||||
Args:
|
||||
title (Union[str, None], optional): The new title of the entry.
|
||||
Defaults to None.
|
||||
|
||||
notification_services (Union[List[int], None], optional):
|
||||
The new id's of the notification services to use to send the reminder.
|
||||
Defaults to None.
|
||||
|
||||
text (Union[str, None], optional): The new body of the reminder.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[str, None], optional): The new hex code of the color
|
||||
of the reminder, which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
|
||||
Returns:
|
||||
dict: The new static reminder info
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Updating static reminder {self.id}: '
|
||||
+ f'{title=}, {notification_services=}, {text=}, {color=}'
|
||||
)
|
||||
|
||||
# Get current data and update it with new values
|
||||
data = self.get()
|
||||
new_values = {
|
||||
'title': title,
|
||||
'text': text,
|
||||
'color': color
|
||||
}
|
||||
for k, v in new_values.items():
|
||||
if k == 'color' or v is not None:
|
||||
data[k] = v
|
||||
|
||||
# Update database
|
||||
cursor = get_db()
|
||||
cursor.execute("""
|
||||
UPDATE static_reminders
|
||||
SET
|
||||
title = ?, text = ?,
|
||||
color = ?
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(data['title'], data['text'],
|
||||
data['color'],
|
||||
self.id)
|
||||
)
|
||||
|
||||
if notification_services:
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
cursor.execute(
|
||||
"DELETE FROM reminder_services WHERE static_reminder_id = ?",
|
||||
(self.id,)
|
||||
)
|
||||
try:
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
static_reminder_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
((self.id, s) for s in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
return self.get()
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete the static reminder
|
||||
"""
|
||||
LOGGER.info(f'Deleting static reminder {self.id}')
|
||||
get_db().execute("DELETE FROM static_reminders WHERE id = ?", (self.id,))
|
||||
return
|
||||
|
||||
class StaticReminders:
|
||||
"""Represents the static reminder library of the user account
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
return
|
||||
|
||||
def fetchall(
|
||||
self,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[dict]:
|
||||
"""Get all static reminders
|
||||
|
||||
Args:
|
||||
sort_by (TimelessSortingMethod, optional): How to sort the result.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[dict]: The id, title, text and color of each static reminder.
|
||||
"""
|
||||
reminders = [
|
||||
dict(r)
|
||||
for r in get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
color
|
||||
FROM static_reminders
|
||||
WHERE user_id = ?
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
(self.user_id,)
|
||||
)
|
||||
]
|
||||
|
||||
# Sort result
|
||||
reminders.sort(key=sort_by.value[0], reverse=sort_by.value[1])
|
||||
|
||||
return reminders
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[dict]:
|
||||
"""Search for static reminders
|
||||
|
||||
Args:
|
||||
query (str): The term to search for.
|
||||
|
||||
sort_by (TimelessSortingMethod, optional): The sorting method of
|
||||
the resulting list.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[dict]: All static reminders that match.
|
||||
Similar output to `self.fetchall`
|
||||
"""
|
||||
static_reminders = [
|
||||
r for r in self.fetchall(sort_by)
|
||||
if search_filter(query, r)
|
||||
]
|
||||
return static_reminders
|
||||
|
||||
def fetchone(self, id: int) -> StaticReminder:
|
||||
"""Get one static reminder
|
||||
|
||||
Args:
|
||||
id (int): The id of the static reminder to fetch
|
||||
|
||||
Returns:
|
||||
StaticReminder: A StaticReminder instance
|
||||
"""
|
||||
return StaticReminder(self.user_id, id)
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
notification_services: List[int],
|
||||
text: str = '',
|
||||
color: Optional[str] = None
|
||||
) -> StaticReminder:
|
||||
"""Add a static reminder
|
||||
|
||||
Args:
|
||||
title (str): The title of the entry.
|
||||
|
||||
notification_services (List[int]): The id's of the
|
||||
notification services to use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
color (Optional[str], optional): The hex code of the color of the template,
|
||||
which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
|
||||
Returns:
|
||||
StaticReminder: The info about the static reminder
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Adding static reminder with {title=}, {notification_services=}, {text=}, {color=}'
|
||||
)
|
||||
|
||||
cursor = get_db()
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
|
||||
id = cursor.execute("""
|
||||
INSERT INTO static_reminders(user_id, title, text, color)
|
||||
VALUES (?,?,?,?);
|
||||
""",
|
||||
(self.user_id, title, text, color)
|
||||
).lastrowid
|
||||
|
||||
try:
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
static_reminder_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
((id, service) for service in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
return self.fetchone(id)
|
||||
|
||||
def trigger_reminder(self, id: int) -> None:
|
||||
"""Trigger a static reminder to send it's reminder
|
||||
|
||||
Args:
|
||||
id (int): The id of the static reminder to trigger
|
||||
|
||||
Raises:
|
||||
ReminderNotFound: The static reminder with the given id was not found
|
||||
"""
|
||||
LOGGER.info(f'Triggering static reminder {id}')
|
||||
cursor = get_db(dict)
|
||||
reminder = cursor.execute("""
|
||||
SELECT title, text
|
||||
FROM static_reminders
|
||||
WHERE
|
||||
id = ?
|
||||
AND user_id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(id, self.user_id)
|
||||
).fetchone()
|
||||
if not reminder:
|
||||
raise ReminderNotFound
|
||||
reminder = dict(reminder)
|
||||
|
||||
a = Apprise()
|
||||
cursor.execute("""
|
||||
SELECT url
|
||||
FROM reminder_services rs
|
||||
INNER JOIN notification_services ns
|
||||
ON rs.notification_service_id = ns.id
|
||||
WHERE rs.static_reminder_id = ?;
|
||||
""",
|
||||
(id,)
|
||||
)
|
||||
for url in cursor:
|
||||
a.add(url['url'])
|
||||
a.notify(title=reminder['title'], body=reminder['text'] or '\u200B')
|
||||
return
|
||||
@@ -1,311 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
from sqlite3 import IntegrityError
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from backend.custom_exceptions import (NotificationServiceNotFound,
|
||||
TemplateNotFound)
|
||||
from backend.db import get_db
|
||||
from backend.helpers import TimelessSortingMethod, search_filter
|
||||
from backend.logging import LOGGER
|
||||
|
||||
|
||||
class Template:
|
||||
"""Represents a template
|
||||
"""
|
||||
def __init__(self, user_id: int, template_id: int) -> None:
|
||||
"""Create instance of class.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
template_id (int): The ID of the template.
|
||||
|
||||
Raises:
|
||||
TemplateNotFound: Template with given ID does not exist or is not
|
||||
owned by user.
|
||||
"""
|
||||
self.id = template_id
|
||||
|
||||
exists = get_db().execute(
|
||||
"SELECT 1 FROM templates WHERE id = ? AND user_id = ? LIMIT 1;",
|
||||
(self.id, user_id)
|
||||
).fetchone()
|
||||
if not exists:
|
||||
raise TemplateNotFound
|
||||
return
|
||||
|
||||
def _get_notification_services(self) -> List[int]:
|
||||
"""Get ID's of notification services linked to the template.
|
||||
|
||||
Returns:
|
||||
List[int]: The list with ID's.
|
||||
"""
|
||||
result = [
|
||||
r[0]
|
||||
for r in get_db().execute("""
|
||||
SELECT notification_service_id
|
||||
FROM reminder_services
|
||||
WHERE template_id = ?;
|
||||
""",
|
||||
(self.id,)
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
def get(self) -> dict:
|
||||
"""Get info about the template
|
||||
|
||||
Returns:
|
||||
dict: The info about the template
|
||||
"""
|
||||
template = get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
color
|
||||
FROM templates
|
||||
WHERE id = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(self.id,)
|
||||
).fetchone()
|
||||
template = dict(template)
|
||||
|
||||
template['notification_services'] = self._get_notification_services()
|
||||
|
||||
return template
|
||||
|
||||
def update(self,
|
||||
title: Union[str, None] = None,
|
||||
notification_services: Union[List[int], None] = None,
|
||||
text: Union[str, None] = None,
|
||||
color: Union[str, None] = None
|
||||
) -> dict:
|
||||
"""Edit the template
|
||||
|
||||
Args:
|
||||
title (Union[str, None]): The new title of the entry.
|
||||
Defaults to None.
|
||||
|
||||
notification_services (Union[List[int], None]): The new id's of the
|
||||
notification services to use to send the reminder.
|
||||
Defaults to None.
|
||||
|
||||
text (Union[str, None], optional): The new body of the template.
|
||||
Defaults to None.
|
||||
|
||||
color (Union[str, None], optional): The new hex code of the color of the template,
|
||||
which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
|
||||
Returns:
|
||||
dict: The new template info
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Updating template {self.id}: '
|
||||
+ f'{title=}, {notification_services=}, {text=}, {color=}'
|
||||
)
|
||||
|
||||
cursor = get_db()
|
||||
|
||||
data = self.get()
|
||||
new_values = {
|
||||
'title': title,
|
||||
'text': text,
|
||||
'color': color
|
||||
}
|
||||
for k, v in new_values.items():
|
||||
if k in ('color',) or v is not None:
|
||||
data[k] = v
|
||||
|
||||
cursor.execute("""
|
||||
UPDATE templates
|
||||
SET title=?, text=?, color=?
|
||||
WHERE id = ?;
|
||||
""", (
|
||||
data['title'],
|
||||
data['text'],
|
||||
data['color'],
|
||||
self.id
|
||||
))
|
||||
|
||||
if notification_services:
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
cursor.execute(
|
||||
"DELETE FROM reminder_services WHERE template_id = ?",
|
||||
(self.id,)
|
||||
)
|
||||
try:
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
template_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
((self.id, s) for s in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
return self.get()
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete the template
|
||||
"""
|
||||
LOGGER.info(f'Deleting template {self.id}')
|
||||
get_db().execute("DELETE FROM templates WHERE id = ?;", (self.id,))
|
||||
return
|
||||
|
||||
class Templates:
|
||||
"""Represents the template library of the user account
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: int) -> None:
|
||||
"""Create an instance.
|
||||
|
||||
Args:
|
||||
user_id (int): The ID of the user.
|
||||
"""
|
||||
self.user_id = user_id
|
||||
return
|
||||
|
||||
def fetchall(
|
||||
self,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[dict]:
|
||||
"""Get all templates of the user.
|
||||
|
||||
Args:
|
||||
sort_by (TimelessSortingMethod, optional): The sorting method of
|
||||
the resulting list.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[dict]: The id, title, text and color of each template.
|
||||
"""
|
||||
templates = [
|
||||
dict(r)
|
||||
for r in get_db(dict).execute("""
|
||||
SELECT
|
||||
id,
|
||||
title, text,
|
||||
color
|
||||
FROM templates
|
||||
WHERE user_id = ?
|
||||
ORDER BY title, id;
|
||||
""",
|
||||
(self.user_id,)
|
||||
)
|
||||
]
|
||||
|
||||
# Sort result
|
||||
templates.sort(key=sort_by.value[0], reverse=sort_by.value[1])
|
||||
|
||||
return templates
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
sort_by: TimelessSortingMethod = TimelessSortingMethod.TITLE
|
||||
) -> List[dict]:
|
||||
"""Search for templates
|
||||
|
||||
Args:
|
||||
query (str): The term to search for.
|
||||
|
||||
sort_by (TimelessSortingMethod, optional): The sorting method of
|
||||
the resulting list.
|
||||
Defaults to TimelessSortingMethod.TITLE.
|
||||
|
||||
Returns:
|
||||
List[dict]: All templates that match. Similar output to `self.fetchall`
|
||||
"""
|
||||
templates = [
|
||||
r for r in self.fetchall(sort_by)
|
||||
if search_filter(query, r)
|
||||
]
|
||||
return templates
|
||||
|
||||
def fetchone(self, id: int) -> Template:
|
||||
"""Get one template
|
||||
|
||||
Args:
|
||||
id (int): The id of the template to fetch
|
||||
|
||||
Returns:
|
||||
Template: A Template instance
|
||||
"""
|
||||
return Template(self.user_id, id)
|
||||
|
||||
def add(
|
||||
self,
|
||||
title: str,
|
||||
notification_services: List[int],
|
||||
text: str = '',
|
||||
color: Optional[str] = None
|
||||
) -> Template:
|
||||
"""Add a template
|
||||
|
||||
Args:
|
||||
title (str): The title of the entry.
|
||||
|
||||
notification_services (List[int]): The id's of the
|
||||
notification services to use to send the reminder.
|
||||
|
||||
text (str, optional): The body of the reminder.
|
||||
Defaults to ''.
|
||||
|
||||
color (Optional[str], optional): The hex code of the color of the template,
|
||||
which is shown in the web-ui.
|
||||
Defaults to None.
|
||||
|
||||
Raises:
|
||||
NotificationServiceNotFound: One of the notification services was not found
|
||||
|
||||
Returns:
|
||||
Template: The info about the template
|
||||
"""
|
||||
LOGGER.info(
|
||||
f'Adding template with {title=}, {notification_services=}, {text=}, {color=}'
|
||||
)
|
||||
|
||||
cursor = get_db()
|
||||
cursor.connection.isolation_level = None
|
||||
cursor.execute("BEGIN TRANSACTION;")
|
||||
|
||||
id = cursor.execute("""
|
||||
INSERT INTO templates(user_id, title, text, color)
|
||||
VALUES (?,?,?,?);
|
||||
""",
|
||||
(self.user_id, title, text, color)
|
||||
).lastrowid
|
||||
|
||||
try:
|
||||
cursor.executemany("""
|
||||
INSERT INTO reminder_services(
|
||||
template_id,
|
||||
notification_service_id
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
((id, service) for service in notification_services)
|
||||
)
|
||||
cursor.execute("COMMIT;")
|
||||
|
||||
except IntegrityError:
|
||||
raise NotificationServiceNotFound
|
||||
|
||||
finally:
|
||||
cursor.connection.isolation_level = ""
|
||||
|
||||
return self.fetchone(id)
|
||||
235
backend/users.py
235
backend/users.py
@@ -1,235 +0,0 @@
|
||||
#-*- coding: utf-8 -*-
|
||||
|
||||
from typing import List
|
||||
|
||||
from backend.custom_exceptions import (AccessUnauthorized,
|
||||
NewAccountsNotAllowed, UsernameInvalid,
|
||||
UsernameTaken, UserNotFound)
|
||||
from backend.db import get_db
|
||||
from backend.logging import LOGGER
|
||||
from backend.notification_service import NotificationServices
|
||||
from backend.reminders import Reminders
|
||||
from backend.security import generate_salt_hash, get_hash
|
||||
from backend.settings import get_setting
|
||||
from backend.static_reminders import StaticReminders
|
||||
from backend.templates import Templates
|
||||
|
||||
ONEPASS_USERNAME_CHARACTERS = 'abcedfghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.!@$'
|
||||
ONEPASS_INVALID_USERNAMES = ['reminders', 'api']
|
||||
|
||||
class User:
|
||||
"""Represents an user account
|
||||
"""
|
||||
|
||||
def __init__(self, id: int) -> None:
|
||||
result = get_db(dict).execute(
|
||||
"SELECT username, admin, salt FROM users WHERE id = ? LIMIT 1;",
|
||||
(id,)
|
||||
).fetchone()
|
||||
if not result:
|
||||
raise UserNotFound
|
||||
|
||||
self.username: str = result['username']
|
||||
self.user_id = id
|
||||
self.admin: bool = result['admin'] == 1
|
||||
self.salt: bytes = result['salt']
|
||||
return
|
||||
|
||||
@property
|
||||
def reminders(self) -> Reminders:
|
||||
"""Get access to the reminders of the user account
|
||||
|
||||
Returns:
|
||||
Reminders: Reminders instance that can be used to access the
|
||||
reminders of the user account
|
||||
"""
|
||||
if not hasattr(self, 'reminders_instance'):
|
||||
self.reminders_instance = Reminders(self.user_id)
|
||||
return self.reminders_instance
|
||||
|
||||
@property
|
||||
def notification_services(self) -> NotificationServices:
|
||||
"""Get access to the notification services of the user account
|
||||
|
||||
Returns:
|
||||
NotificationServices: NotificationServices instance that can be used
|
||||
to access the notification services of the user account
|
||||
"""
|
||||
if not hasattr(self, 'notification_services_instance'):
|
||||
self.notification_services_instance = NotificationServices(self.user_id)
|
||||
return self.notification_services_instance
|
||||
|
||||
@property
|
||||
def templates(self) -> Templates:
|
||||
"""Get access to the templates of the user account
|
||||
|
||||
Returns:
|
||||
Templates: Templates instance that can be used to access the
|
||||
templates of the user account
|
||||
"""
|
||||
if not hasattr(self, 'templates_instance'):
|
||||
self.templates_instance = Templates(self.user_id)
|
||||
return self.templates_instance
|
||||
|
||||
@property
|
||||
def static_reminders(self) -> StaticReminders:
|
||||
"""Get access to the static reminders of the user account
|
||||
|
||||
Returns:
|
||||
StaticReminders: StaticReminders instance that can be used to
|
||||
access the static reminders of the user account
|
||||
"""
|
||||
if not hasattr(self, 'static_reminders_instance'):
|
||||
self.static_reminders_instance = StaticReminders(self.user_id)
|
||||
return self.static_reminders_instance
|
||||
|
||||
def edit_password(self, new_password: str) -> None:
|
||||
"""Change the password of the account
|
||||
|
||||
Args:
|
||||
new_password (str): The new password
|
||||
"""
|
||||
# Encrypt raw key with new password
|
||||
hash_password = get_hash(self.salt, new_password)
|
||||
|
||||
# Update database
|
||||
get_db().execute(
|
||||
"UPDATE users SET hash = ? WHERE id = ?",
|
||||
(hash_password, self.user_id)
|
||||
)
|
||||
LOGGER.info(f'The user {self.username} ({self.user_id}) changed their password')
|
||||
return
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete the user account
|
||||
"""
|
||||
if self.username == 'admin':
|
||||
raise UserNotFound
|
||||
|
||||
LOGGER.info(f'Deleting the user {self.username} ({self.user_id})')
|
||||
|
||||
cursor = get_db()
|
||||
cursor.execute(
|
||||
"DELETE FROM reminders WHERE user_id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
cursor.execute(
|
||||
"DELETE FROM templates WHERE user_id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
cursor.execute(
|
||||
"DELETE FROM static_reminders WHERE user_id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
cursor.execute(
|
||||
"DELETE FROM notification_services WHERE user_id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
cursor.execute(
|
||||
"DELETE FROM users WHERE id = ?",
|
||||
(self.user_id,)
|
||||
)
|
||||
return
|
||||
|
||||
class Users:
|
||||
def _check_username(self, username: str) -> None:
|
||||
"""Check if username is valid
|
||||
|
||||
Args:
|
||||
username (str): The username to check
|
||||
|
||||
Raises:
|
||||
UsernameInvalid: The username is not valid
|
||||
"""
|
||||
LOGGER.debug(f'Checking the username {username}')
|
||||
if username in ONEPASS_INVALID_USERNAMES or username.isdigit():
|
||||
raise UsernameInvalid(username)
|
||||
if list(filter(lambda c: not c in ONEPASS_USERNAME_CHARACTERS, username)):
|
||||
raise UsernameInvalid(username)
|
||||
return
|
||||
|
||||
def __contains__(self, username: str) -> bool:
|
||||
result = get_db().execute(
|
||||
"SELECT 1 FROM users WHERE username = ? LIMIT 1;",
|
||||
(username,)
|
||||
).fetchone()
|
||||
return result is not None
|
||||
|
||||
def add(self, username: str, password: str, from_admin: bool=False) -> int:
|
||||
"""Add a user
|
||||
|
||||
Args:
|
||||
username (str): The username of the new user
|
||||
password (str): The password of the new user
|
||||
from_admin (bool, optional): Skip check if new accounts are allowed.
|
||||
Defaults to False.
|
||||
|
||||
Raises:
|
||||
UsernameInvalid: Username not allowed or contains invalid characters
|
||||
UsernameTaken: Username is already taken; usernames must be unique
|
||||
NewAccountsNotAllowed: In the admin panel, new accounts are set to be
|
||||
not allowed.
|
||||
|
||||
Returns:
|
||||
int: The id of the new user. User registered successful
|
||||
"""
|
||||
LOGGER.info(f'Registering user with username {username}')
|
||||
|
||||
if not from_admin and not get_setting('allow_new_accounts'):
|
||||
raise NewAccountsNotAllowed
|
||||
|
||||
# Check if username is valid
|
||||
self._check_username(username)
|
||||
|
||||
cursor = get_db()
|
||||
|
||||
# Check if username isn't already taken
|
||||
if username in self:
|
||||
raise UsernameTaken
|
||||
|
||||
# Generate salt and key exclusive for user
|
||||
salt, hashed_password = generate_salt_hash(password)
|
||||
del password
|
||||
|
||||
# Add user to userlist
|
||||
user_id = cursor.execute(
|
||||
"""
|
||||
INSERT INTO users(username, salt, hash)
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(username, salt, hashed_password)
|
||||
).lastrowid
|
||||
|
||||
LOGGER.debug(f'Newly registered user has id {user_id}')
|
||||
return user_id
|
||||
|
||||
def get_all(self) -> List[dict]:
|
||||
"""Get all user info for the admin
|
||||
|
||||
Returns:
|
||||
List[dict]: The info about all users
|
||||
"""
|
||||
result = [
|
||||
dict(u)
|
||||
for u in get_db(dict).execute(
|
||||
"SELECT id, username, admin FROM users ORDER BY username;"
|
||||
)
|
||||
]
|
||||
return result
|
||||
|
||||
def login(self, username: str, password: str) -> User:
|
||||
result = get_db(dict).execute(
|
||||
"SELECT id, salt, hash FROM users WHERE username = ? LIMIT 1;",
|
||||
(username,)
|
||||
).fetchone()
|
||||
if not result:
|
||||
raise UserNotFound
|
||||
|
||||
hash_password = get_hash(result['salt'], password)
|
||||
if not hash_password == result['hash']:
|
||||
raise AccessUnauthorized
|
||||
|
||||
return User(result['id'])
|
||||
|
||||
def get_one(self, id: int) -> User:
|
||||
return User(id)
|
||||
Reference in New Issue
Block a user