Files
MIND/frontend/input_validation.py

589 lines
14 KiB
Python

#-*- coding: utf-8 -*-
"""
Input validation for the API
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
from os.path import splitext
from re import compile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type, Union
from apprise import Apprise
from flask import Blueprint, request
from flask.sansio.scaffold import T_route
from backend.custom_exceptions import (AccessUnauthorized, InvalidDatabaseFile,
InvalidKeyValue, InvalidTime,
KeyNotFound, NewAccountsNotAllowed,
NotificationServiceNotFound,
UsernameInvalid, UsernameTaken,
UserNotFound)
from backend.helpers import (RepeatQuantity, SortingMethod,
TimelessSortingMethod, folder_path)
from backend.server import SERVER
from backend.settings import _format_setting
if TYPE_CHECKING:
from flask import Request
# Pattern for '#' followed by six lowercase hex digits (not anchored).
color_regex = compile(r'#[0-9a-f]{6}')
# Endpoint documentation registry: filled by APIBlueprint.route, read by
# get_api_docs. Keys are the (possibly admin-extended) route rules.
api_docs: Dict[str, 'ApiDocEntry'] = dict()
class DataSource:
    """Request parts that input variables can be read from.

    The class constants identify a part. An instance maps each constant to
    the matching data of the given request.
    """
    DATA = 1    # parsed JSON body ({} when the request has no body)
    VALUES = 2  # request.values
    FILES = 3   # request.files

    def __init__(self, request: Request) -> None:
        json_body = request.get_json() if request.data else {}
        self.map: Dict[int, dict] = {
            DataSource.DATA: json_body,
            DataSource.VALUES: request.values,
            DataSource.FILES: request.files,
        }

    def __getitem__(self, key: int) -> dict:
        data = self.map[key]
        return data
class DataType:
    """Human-readable type names shown for input variables."""
    STR = 'string'
    INT = 'number'
    FLOAT = 'decimal number'
    BOOL = 'bool'
    INT_ARRAY = 'list of numbers'
    # Used when a variable has no meaningful type (e.g. file uploads).
    NA = 'N/A'
class InputVariable(ABC):
    """Abstract description of a single API input variable.

    Concrete subclasses declare the variable's metadata and implement
    `validate`, which checks (and may transform) `self.value`.
    """
    value: Any

    @abstractmethod
    def __init__(self, value: Any) -> None:
        """Store the raw value taken from the request."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Key under which the variable is looked up in the request."""

    @abstractmethod
    def validate(self) -> bool:
        """Return whether `self.value` is acceptable; may normalise it."""

    @property
    @abstractmethod
    def required(self) -> bool:
        """Whether a missing key raises KeyNotFound."""

    @property
    @abstractmethod
    def data_type(self) -> List[str]:
        """DataType names describing the accepted value types."""

    @property
    @abstractmethod
    def default(self) -> Any:
        """Value used for optional variables that were not supplied."""

    @property
    @abstractmethod
    def source(self) -> int:
        """DataSource constant saying which part of the request to read."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Human-readable description of the variable."""

    @property
    @abstractmethod
    def related_exceptions(self) -> List[Type[Exception]]:
        """Exception classes (not instances) associated with the variable."""
@dataclass(frozen=True)
class Method:
    """Description and input variables of one HTTP method of an endpoint."""
    description: str = ''
    vars: List[Type[InputVariable]] = field(default_factory=list)

    def __bool__(self) -> bool:
        # Truthy only when at least one input variable is declared.
        return not self.vars == []
@dataclass(frozen=True)
class Methods:
    """Per-HTTP-method documentation and input variables of an endpoint."""
    get: Method = Method()
    post: Method = Method()
    put: Method = Method()
    delete: Method = Method()

    def __getitem__(self, key: str) -> Method:
        """Return the `Method` for an HTTP method name (case-insensitive).

        Werkzeug automatically allows HEAD on GET rules and dispatches it to
        the GET view, so HEAD maps to the GET entry instead of failing. Any
        other name without a matching field raises AttributeError.
        """
        method_name = key.lower()
        if method_name == 'head':
            method_name = 'get'
        return getattr(self, method_name)

    def __bool__(self) -> bool:
        """Whether any HTTP method declares input variables."""
        return bool(self.get or self.post or self.put or self.delete)
@dataclass(frozen=True)
class ApiDocEntry:
    """Documentation record of one registered API endpoint."""
    # Rule under which the entry is stored in `api_docs`.
    endpoint: str
    # Free-form description of the endpoint.
    description: str
    # Whether the endpoint is documented as requiring authentication.
    requires_auth: bool
    # HTTP methods the route was registered with.
    used_methods: List[str]
    # Per-method descriptions and input variables.
    methods: Methods
def get_api_docs(request: Request) -> ApiDocEntry:
    """Return the documentation entry of the endpoint serving `request`.

    Admin requests are looked up under `SERVER.admin_api_extension` plus the
    part of the rule after `SERVER.admin_prefix`. Other requests are looked
    up under the part after `SERVER.api_prefix`.

    Raises:
        KeyError: no entry was registered for the computed key.
    """
    is_admin = request.path.startswith(SERVER.admin_prefix)
    rule = request.url_rule.rule
    if is_admin:
        key = SERVER.admin_api_extension + rule.split(SERVER.admin_prefix)[1]
    else:
        key = rule.split(SERVER.api_prefix)[1]
    return api_docs[key]
class BaseInputVariable(InputVariable):
    """Default input variable: a required, non-empty string from the JSON body.

    Subclasses override the class attributes below and/or `validate`.
    """
    source = DataSource.DATA
    data_type = [DataType.STR]
    required = True
    default = None
    related_exceptions = [KeyNotFound, InvalidKeyValue]

    def __init__(self, value: Any) -> None:
        self.value = value

    def validate(self) -> bool:
        """Return True when the value is a non-empty string."""
        # Return a real bool; `isinstance(...) and self.value` leaked the str.
        return isinstance(self.value, str) and self.value != ''

    def __repr__(self) -> str:
        """Render as a Markdown-style row: name | required | type | description | values."""
        return f'| {self.name} | {"Yes" if self.required else "No"} | {",".join(self.data_type)} | {self.description} | N/A |'
class NonRequiredVersion(BaseInputVariable):
    """Base class, listed first among the bases, that makes a variable optional.

    A missing (None) value is replaced by the class' `default`, and a None
    value always passes `validate`. Other values go to the next class in
    the MRO.
    """
    required = False
    related_exceptions = [InvalidKeyValue]

    def __init__(self, value: Any) -> None:
        super().__init__(self.default if value is None else value)

    def validate(self) -> bool:
        if self.value is None:
            return True
        return super().validate()
class UsernameVariable(BaseInputVariable):
    """The required `username` input variable."""
    name = 'username'
    related_exceptions = [KeyNotFound, UserNotFound]
    description = 'The username of the user account'
class PasswordCreateVariable(BaseInputVariable):
    """The required `password` input variable (only KeyNotFound listed)."""
    related_exceptions = [KeyNotFound]
    description = 'The password of the user account'
    name = 'password'
class PasswordVariable(PasswordCreateVariable):
    """`password` variable that also lists AccessUnauthorized."""
    related_exceptions = [KeyNotFound, AccessUnauthorized]
class UsernameCreateVariable(UsernameVariable):
    """`username` variable listing the exceptions of account creation."""
    related_exceptions = [
        KeyNotFound,
        UsernameInvalid,
        UsernameTaken,
        NewAccountsNotAllowed,
    ]
class NewPasswordVariable(BaseInputVariable):
    """The required `new_password` input variable."""
    related_exceptions = [KeyNotFound]
    name = 'new_password'
    description = 'The new password of the user account'
class TitleVariable(BaseInputVariable):
    """The required, non-empty `title` input variable."""
    description = 'The title of the entry'
    name = 'title'
class URLVariable(BaseInputVariable):
    """The `url` input variable: a non-empty string that Apprise accepts."""
    description = 'The Apprise URL of the notification service'
    name = 'url'

    def validate(self) -> bool:
        base_ok = super().validate()
        # Only ask Apprise when the basic string check passed.
        return Apprise().add(self.value) if base_ok else base_ok
class EditTitleVariable(NonRequiredVersion, TitleVariable):
    """Optional version of TitleVariable (None is accepted)."""
class EditURLVariable(NonRequiredVersion, URLVariable):
    """Optional version of URLVariable (None is accepted)."""
class SortByVariable(NonRequiredVersion, BaseInputVariable):
    """Optional `sort_by` value from the request values.

    It is converted into a SortingMethod member.
    """
    name = 'sort_by'
    description = 'How to sort the result'
    source = DataSource.VALUES
    # Accepted values: lowercase SortingMethod member names; the first one is
    # used when the variable is omitted.
    _options = [member.lower() for member in SortingMethod._member_names_]
    default = SortingMethod._member_names_[0].lower()

    def validate(self) -> bool:
        if self.value not in self._options:
            return False
        self.value = SortingMethod[self.value.upper()]
        return True

    def __repr__(self) -> str:
        required = "Yes" if self.required else "No"
        types = ",".join(self.data_type)
        allowed = ", ".join(f'`{o}`' for o in self._options)
        return f'| {self.name} | {required} | {types} | {self.description} | {allowed} |'
class TimelessSortByVariable(SortByVariable):
    """Variant of SortByVariable that validates against TimelessSortingMethod."""
    _options = [member.lower() for member in TimelessSortingMethod._member_names_]
    default = TimelessSortingMethod._member_names_[0].lower()

    def validate(self) -> bool:
        valid = self.value in self._options
        if valid:
            self.value = TimelessSortingMethod[self.value.upper()]
        return valid
class TimeVariable(BaseInputVariable):
    """Required epoch timestamp (int or float) at which a reminder is sent."""
    name = 'time'
    description = 'The UTC epoch timestamp that the reminder should be sent at'
    data_type = [DataType.INT, DataType.FLOAT]
    related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime]

    def validate(self) -> bool:
        """Accept ints and finite floats only.

        `bool` is a subclass of `int`, so JSON `true`/`false` would otherwise
        pass as 1/0. Python's json module parses `NaN`/`Infinity`, which are
        not usable timestamps. The finiteness check is applied to floats only
        because `math.isfinite` overflows on very large ints.
        """
        from math import isfinite

        value = self.value
        if isinstance(value, bool):
            return False
        if isinstance(value, int):
            return True
        return isinstance(value, float) and isfinite(value)
class EditTimeVariable(NonRequiredVersion, TimeVariable):
    """Optional `time` variable; KeyNotFound is not listed because it may be omitted."""
    related_exceptions = [InvalidKeyValue, InvalidTime]
class NotificationServicesVariable(BaseInputVariable):
    """Required, non-empty list of notification service ids."""
    name = 'notification_services'
    description = "Array of the id's of the notification services to use to send the notification"
    data_type = [DataType.INT_ARRAY]
    related_exceptions = [
        KeyNotFound, InvalidKeyValue,
        NotificationServiceNotFound
    ]

    def validate(self) -> bool:
        """Accept a non-empty list whose elements are all ints.

        Bools are rejected explicitly: `bool` subclasses `int`, so JSON
        `true` would otherwise be taken as id 1.
        """
        return (
            isinstance(self.value, list)
            and len(self.value) > 0
            and all(
                isinstance(v, int) and not isinstance(v, bool)
                for v in self.value
            )
        )
class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesVariable):
    """Optional `notification_services` variable (None is accepted)."""
    related_exceptions = [InvalidKeyValue, NotificationServiceNotFound]
class TextVariable(NonRequiredVersion, BaseInputVariable):
    """Optional body text; an omitted value becomes the empty string."""
    default = ''
    description = 'The body of the entry'
    name = 'text'

    def validate(self) -> bool:
        # Any string is fine here, including the empty one.
        is_text = isinstance(self.value, str)
        return is_text
class RepeatQuantityVariable(NonRequiredVersion, BaseInputVariable):
    """Optional repeat quantity.

    It is converted into a RepeatQuantity member.
    """
    name = 'repeat_quantity'
    description = 'The quantity of the repeat_interval'
    # Accepted values: lowercase RepeatQuantity member names.
    _options = [member.lower() for member in RepeatQuantity._member_names_]

    def validate(self) -> bool:
        if self.value is not None:
            if self.value not in self._options:
                return False
            self.value = RepeatQuantity[self.value.upper()]
        # An omitted (None) value is accepted unchanged.
        return True

    def __repr__(self) -> str:
        cells = [
            self.name,
            "Yes" if self.required else "No",
            ",".join(self.data_type),
            self.description,
            ", ".join(f'`{o}`' for o in self._options),
        ]
        return '| ' + ' | '.join(cells) + ' |'
class RepeatIntervalVariable(NonRequiredVersion, BaseInputVariable):
    """Optional positive integer used together with `repeat_quantity`."""
    name = 'repeat_interval'
    description = 'The number of the interval'
    data_type = [DataType.INT]

    def validate(self) -> bool:
        """Accept None (omitted) or a positive int.

        Bools are rejected: `bool` subclasses `int`, so JSON `true` would
        otherwise pass as an interval of 1.
        """
        if self.value is None:
            return True
        return (
            isinstance(self.value, int)
            and not isinstance(self.value, bool)
            and self.value > 0
        )
class WeekDaysVariable(NonRequiredVersion, BaseInputVariable):
    """Optional non-empty list of weekday numbers (0-6)."""
    name = 'weekdays'
    description = 'On which days of the weeks to run the reminder'
    data_type = [DataType.INT_ARRAY]
    _options = {0, 1, 2, 3, 4, 5, 6}

    def validate(self) -> bool:
        """Accept None (omitted) or a non-empty list of ints in `_options`.

        Each element is type-checked before the set lookup. Unhashable
        elements (e.g. nested lists) then fail validation instead of raising
        TypeError, and bools/floats such as `true` or `1.0` (equal to set
        members) are no longer accepted.
        """
        if self.value is None:
            return True
        return (
            isinstance(self.value, list)
            and len(self.value) > 0
            and all(
                isinstance(v, int)
                and not isinstance(v, bool)
                and v in self._options
                for v in self.value
            )
        )

    def __repr__(self) -> str:
        """Render as a Markdown-style row listing the allowed values in order."""
        return '| {n} | {r} | {t} | {d} | {v} |'.format(
            n=self.name,
            r="Yes" if self.required else "No",
            t=",".join(self.data_type),
            d=self.description,
            v=", ".join(f'`{o}`' for o in sorted(self._options))
        )
class ColorVariable(NonRequiredVersion, BaseInputVariable):
    """Optional '#rrggbb' colour code (lowercase hex digits)."""
    name = 'color'
    description = 'The hex code of the color of the entry, which is shown in the web-ui'

    def validate(self) -> bool:
        """Accept None (omitted) or a string that is exactly a colour code.

        `fullmatch` requires the whole string to match `color_regex`. The
        previous `search` accepted any string that merely contained a colour
        code (e.g. '#ffffff; anything').
        """
        return self.value is None or (
            isinstance(self.value, str)
            and color_regex.fullmatch(self.value) is not None
        )
class QueryVariable(BaseInputVariable):
    """Required, non-empty search term read from the request values."""
    source = DataSource.VALUES
    description = 'The search term'
    name = 'query'
class DeleteRemindersUsingVariable(NonRequiredVersion, BaseInputVariable):
    """Optional 'true'/'false' request value, converted into a bool."""
    name = 'delete_reminders_using'
    description = 'Instead of throwing an error when there are still reminders using the service, delete the reminders.'
    source = DataSource.VALUES
    default = 'false'
    data_type = [DataType.BOOL]

    def validate(self) -> bool:
        for text, flag in (('true', True), ('false', False)):
            if self.value == text:
                self.value = flag
                return True
        # Anything other than the two literal strings is rejected.
        return False
class AdminSettingsVariable(BaseInputVariable):
    """Base for admin setting variables.

    Validation is delegated to `_format_setting`; an InvalidKeyValue raised
    by it means the value is rejected.
    """
    def validate(self) -> bool:
        try:
            _format_setting(self.name, self.value)
        except InvalidKeyValue:
            return False
        else:
            return True
class AllowNewAccountsVariable(NonRequiredVersion, AdminSettingsVariable):
    """Optional admin setting `allow_new_accounts`."""
    data_type = [DataType.BOOL]
    name = 'allow_new_accounts'
    description = (
        'Whether or not to allow users to register a new account. '
        'The admin can always add a new account.'
    )
class LoginTimeVariable(NonRequiredVersion, AdminSettingsVariable):
    """Optional admin setting `login_time`."""
    data_type = [DataType.INT]
    name = 'login_time'
    description = (
        'How long a user stays logged in, in seconds. '
        'Between 1 min and 1 month (60 <= sec <= 2592000)'
    )
class LoginTimeResetVariable(NonRequiredVersion, AdminSettingsVariable):
    """Optional admin setting `login_time_reset`."""
    data_type = [DataType.BOOL]
    description = 'If the Login Time timer should reset with each API request.'
    name = 'login_time_reset'
class HostVariable(NonRequiredVersion, AdminSettingsVariable):
    """Optional admin setting `host`."""
    description = 'The IP to bind to. Use 0.0.0.0 to bind to all addresses.'
    name = 'host'
class PortVariable(NonRequiredVersion, AdminSettingsVariable):
    """Optional admin setting `port`."""
    data_type = [DataType.INT]
    description = 'The port to listen on.'
    name = 'port'
class UrlPrefixVariable(NonRequiredVersion, AdminSettingsVariable):
    """Optional admin setting `url_prefix`."""
    description = 'The base url to run on. Useful for reverse proxies. Empty string to disable.'
    name = 'url_prefix'
class LogLevelVariable(NonRequiredVersion, AdminSettingsVariable):
    """Optional admin setting `log_level`; `_options` lists the documented values."""
    name = 'log_level'
    description = 'The level to log on.'
    data_type = [DataType.INT]
    _options = [logging.INFO, logging.DEBUG]

    def __repr__(self) -> str:
        cells = [
            self.name,
            "Yes" if self.required else "No",
            ",".join(self.data_type),
            self.description,
            ", ".join(f'`{o}`' for o in self._options),
        ]
        return '| ' + ' | '.join(cells) + ' |'
class DatabaseFileVariable(BaseInputVariable):
    """Required uploaded database file (must have a '.db' extension)."""
    name = 'file'
    description = 'The MIND database file'
    data_type = [DataType.NA]
    source = DataSource.FILES
    related_exceptions = [KeyNotFound, InvalidDatabaseFile]

    def validate(self) -> bool:
        """Check the upload and, on success, save it to disk.

        The upload is written to `folder_path('db', 'MIND_upload.db')`, and
        `self.value` is replaced by that path.
        """
        upload = self.value
        if not upload.filename or splitext(upload.filename)[1] != '.db':
            return False
        path = folder_path('db', 'MIND_upload.db')
        upload.save(path)
        self.value = path
        return True
class CopyHostingSettingsVariable(BaseInputVariable):
    """Required 'true'/'false' request value, converted into a bool."""
    name = 'copy_hosting_settings'
    description = 'Copy the hosting settings from the current database'
    data_type = [DataType.BOOL]
    source = DataSource.VALUES

    def validate(self) -> bool:
        accepted = self.value in ('true', 'false')
        if accepted:
            self.value = self.value == 'true'
        return accepted
def input_validation() -> Union[None, Dict[str, Any]]:
    """Checks, extracts and transforms inputs

    Raises:
        KeyNotFound: A required key was not supplied
        InvalidDatabaseFile: The uploaded database file is not valid
        InvalidKeyValue: The value of a key is not valid

    Returns:
        Union[None, Dict[str, Any]]: `None` if no method of the endpoint
        declares input variables; an empty dict if only the current method
        declares none; otherwise the input variables, checked and formatted.
    """
    result: Dict[str, Any] = {}
    methods = get_api_docs(request).methods
    method = methods[request.method]

    if not methods:
        return None
    if not method:
        return result

    given_variables = DataSource(request)
    for noted_var in method.vars:
        source = given_variables[noted_var.source]
        if noted_var.required and noted_var.name not in source:
            raise KeyNotFound(noted_var.name)

        input_value = source.get(noted_var.name)
        value: InputVariable = noted_var(input_value)

        if not value.validate():
            # `noted_var` is a class, so compare classes directly. The old
            # `noted_var.__class__.__name__` gave the metaclass name
            # ('ABCMeta'), so this branch could never be reached.
            if issubclass(noted_var, DatabaseFileVariable):
                raise InvalidDatabaseFile
            if noted_var.source == DataSource.FILES:
                raise InvalidKeyValue(noted_var.name, input_value.filename)
            raise InvalidKeyValue(noted_var.name, input_value)

        result[noted_var.name] = value.value

    return result
class APIBlueprint(Blueprint):
    """Blueprint whose `route` also records an ApiDocEntry in `api_docs`."""

    def route(
        self,
        rule: str,
        description: str = '',
        input_variables: Methods = Methods(),
        requires_auth: bool = True,
        **options: Any
    ) -> Callable[[T_route], T_route]:
        """Register a route and document it in `api_docs`.

        Args:
            rule: URL rule relative to the blueprint.
            description: Description stored in the doc entry.
            input_variables: Per-method input variables.
            requires_auth: Stored as `requires_auth` in the doc entry.
            **options: Forwarded to `Blueprint.route` (e.g. `methods`).

        Raises:
            NotImplementedError: Called on a blueprint other than `api` or
                `admin_api`.
        """
        if self is api:
            processed_rule = rule
        elif self is admin_api:
            processed_rule = SERVER.admin_api_extension + rule
        else:
            raise NotImplementedError

        api_docs[processed_rule] = ApiDocEntry(
            endpoint=processed_rule,
            description=description,
            requires_auth=requires_auth,
            # Flask defaults to GET when `methods` is omitted; mirror that
            # instead of raising KeyError.
            used_methods=options.get('methods', ['GET']),
            methods=input_variables
        )
        return super().route(rule, **options)
# Blueprints for the regular API and the admin API. APIBlueprint.route only
# accepts these two instances.
api, admin_api = APIBlueprint('api', __name__), APIBlueprint('admin_api', __name__)