Refactored input validation

This commit is contained in:
CasVT
2024-02-25 22:59:23 +01:00
parent 191325c52e
commit 6f1c37b79c
4 changed files with 337 additions and 211 deletions

View File

@@ -4,9 +4,13 @@
Input validation for the API
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from os.path import splitext
from re import compile
from typing import Any, Callable, Dict, List, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union
from apprise import Apprise
from flask import Blueprint, request
@@ -19,23 +23,41 @@ from backend.custom_exceptions import (AccessUnauthorized, InvalidKeyValue,
UsernameInvalid, UsernameTaken,
UserNotFound)
from backend.helpers import (RepeatQuantity, SortingMethod,
TimelessSortingMethod)
TimelessSortingMethod, folder_path)
from backend.settings import _format_setting
if TYPE_CHECKING:
from flask import Request
api_prefix = "/api"
_admin_api_prefix = '/admin'
admin_api_prefix = api_prefix + _admin_api_prefix
color_regex = compile(r'#[0-9a-f]{6}')
api_docs: Dict[str, ApiDocEntry] = {}
class DataSource:
DATA = 1
VALUES = 2
FILES = 3
def __init__(self, request: Request) -> None:
self.map = {
self.DATA: request.get_json() if request.data else {},
self.VALUES: request.values,
self.FILES: request.files
}
return
def __getitem__(self, key: DataSource) -> dict:
return self.map[key]
class InputVariable(ABC):
value: Any
@abstractmethod
def __init__(self, value: Any) -> None:
pass
@@ -75,11 +97,51 @@ class InputVariable(ABC):
pass
class DefaultInputVariable(InputVariable):
@dataclass(frozen=True)
class Method:
description: str = ''
vars: List[InputVariable] = field(default_factory=list)
def __bool__(self) -> bool:
return self.vars != []
@dataclass(frozen=True)
class Methods:
get: Method = Method()
post: Method = Method()
put: Method = Method()
delete: Method = Method()
def __getitem__(self, key: str) -> Method:
return getattr(self, key.lower())
def __bool__(self) -> bool:
return bool(self.get or self.post or self.put or self.delete)
@dataclass(frozen=True)
class ApiDocEntry:
endpoint: str
description: str
requires_auth: bool
used_methods: List[str]
methods: Methods
def get_api_docs(request: Request) -> ApiDocEntry:
if request.path.startswith(admin_api_prefix):
url = _admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
else:
url = request.url_rule.rule.split(api_prefix)[1]
return api_docs[url]
class BaseInputVariable(InputVariable):
source = DataSource.DATA
required = True
default = None
related_exceptions = []
related_exceptions = [KeyNotFound, InvalidKeyValue]
def __init__(self, value: Any) -> None:
self.value = value
@@ -91,31 +153,38 @@ class DefaultInputVariable(InputVariable):
return f'| {self.name} | {"Yes" if self.required else "No"} | {self.description} | N/A |'
class NonRequiredVersion(InputVariable):
class NonRequiredVersion(BaseInputVariable):
required = False
related_exceptions = [InvalidKeyValue]
def __init__(self, value: Any) -> None:
super().__init__(
value
if value is not None else
self.default
)
return
def validate(self) -> bool:
return self.value is None or super().validate()
class UsernameVariable(DefaultInputVariable):
class UsernameVariable(BaseInputVariable):
name = 'username'
description = 'The username of the user account'
related_exceptions = [KeyNotFound, UserNotFound]
class PasswordVariable(DefaultInputVariable):
class PasswordCreateVariable(BaseInputVariable):
name = 'password'
description = 'The password of the user account'
related_exceptions = [KeyNotFound, AccessUnauthorized]
class NewPasswordVariable(PasswordVariable):
name = 'new_password'
description = 'The new password of the user account'
related_exceptions = [KeyNotFound]
class PasswordVariable(PasswordCreateVariable):
related_exceptions = [KeyNotFound, AccessUnauthorized]
class UsernameCreateVariable(UsernameVariable):
related_exceptions = [
KeyNotFound,
@@ -124,44 +193,39 @@ class UsernameCreateVariable(UsernameVariable):
]
class PasswordCreateVariable(PasswordVariable):
class NewPasswordVariable(BaseInputVariable):
name = 'new_password'
description = 'The new password of the user account'
related_exceptions = [KeyNotFound]
class TitleVariable(DefaultInputVariable):
class TitleVariable(BaseInputVariable):
name = 'title'
description = 'The title of the entry'
related_exceptions = [KeyNotFound]
class URLVariable(DefaultInputVariable):
class URLVariable(BaseInputVariable):
name = 'url'
description = 'The Apprise URL of the notification service'
related_exceptions = [KeyNotFound, InvalidKeyValue]
def validate(self) -> bool:
return Apprise().add(self.value)
return super().validate() and Apprise().add(self.value)
class EditTitleVariable(NonRequiredVersion, TitleVariable):
related_exceptions = []
pass
class EditURLVariable(NonRequiredVersion, URLVariable):
related_exceptions = [InvalidKeyValue]
pass
class SortByVariable(DefaultInputVariable):
class SortByVariable(NonRequiredVersion, BaseInputVariable):
name = 'sort_by'
description = 'How to sort the result'
required = False
source = DataSource.VALUES
_options = [k.lower() for k in SortingMethod._member_names_]
default = SortingMethod._member_names_[0].lower()
related_exceptions = [InvalidKeyValue]
def __init__(self, value: str) -> None:
self.value = value
def validate(self) -> bool:
if not self.value in self._options:
@@ -179,7 +243,7 @@ class SortByVariable(DefaultInputVariable):
)
class TemplateSortByVariable(SortByVariable):
class TimelessSortByVariable(SortByVariable):
_options = [k.lower() for k in TimelessSortingMethod._member_names_]
default = TimelessSortingMethod._member_names_[0].lower()
@@ -190,11 +254,8 @@ class TemplateSortByVariable(SortByVariable):
self.value = TimelessSortingMethod[self.value.upper()]
return True
class StaticReminderSortByVariable(TemplateSortByVariable):
pass
class TimeVariable(DefaultInputVariable):
class TimeVariable(BaseInputVariable):
name = 'time'
description = 'The UTC epoch timestamp that the reminder should be sent at'
related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime]
@@ -207,10 +268,13 @@ class EditTimeVariable(NonRequiredVersion, TimeVariable):
related_exceptions = [InvalidKeyValue, InvalidTime]
class NotificationServicesVariable(DefaultInputVariable):
class NotificationServicesVariable(BaseInputVariable):
name = 'notification_services'
description = "Array of the id's of the notification services to use to send the notification"
related_exceptions = [KeyNotFound, InvalidKeyValue, NotificationServiceNotFound]
related_exceptions = [
KeyNotFound, InvalidKeyValue,
NotificationServiceNotFound
]
def validate(self) -> bool:
if not isinstance(self.value, list):
@@ -227,7 +291,7 @@ class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesV
related_exceptions = [InvalidKeyValue, NotificationServiceNotFound]
class TextVariable(NonRequiredVersion, DefaultInputVariable):
class TextVariable(NonRequiredVersion, BaseInputVariable):
name = 'text'
description = 'The body of the entry'
default = ''
@@ -236,13 +300,10 @@ class TextVariable(NonRequiredVersion, DefaultInputVariable):
return isinstance(self.value, str)
class RepeatQuantityVariable(DefaultInputVariable):
class RepeatQuantityVariable(NonRequiredVersion, BaseInputVariable):
name = 'repeat_quantity'
description = 'The quantity of the repeat_interval'
required = False
_options = [m.lower() for m in RepeatQuantity._member_names_]
default = None
related_exceptions = [InvalidKeyValue]
def validate(self) -> bool:
if self.value is None:
@@ -263,12 +324,9 @@ class RepeatQuantityVariable(DefaultInputVariable):
)
class RepeatIntervalVariable(DefaultInputVariable):
class RepeatIntervalVariable(NonRequiredVersion, BaseInputVariable):
name = 'repeat_interval'
description = 'The number of the interval'
required = False
default = None
related_exceptions = [InvalidKeyValue]
def validate(self) -> bool:
return (
@@ -280,12 +338,9 @@ class RepeatIntervalVariable(DefaultInputVariable):
)
class WeekDaysVariable(DefaultInputVariable):
class WeekDaysVariable(NonRequiredVersion, BaseInputVariable):
name = 'weekdays'
description = 'On which days of the weeks to run the reminder'
required = False
default = None
related_exceptions = [InvalidKeyValue]
_options = {0, 1, 2, 3, 4, 5, 6}
def validate(self) -> bool:
@@ -296,26 +351,25 @@ class WeekDaysVariable(DefaultInputVariable):
)
class ColorVariable(DefaultInputVariable):
class ColorVariable(NonRequiredVersion, BaseInputVariable):
name = 'color'
description = 'The hex code of the color of the entry, which is shown in the web-ui'
required = False
default = None
related_exceptions = [InvalidKeyValue]
def validate(self) -> bool:
return self.value is None or color_regex.search(self.value)
super()
return self.value is None or (
isinstance(self.value, str)
and color_regex.search(self.value)
)
class QueryVariable(DefaultInputVariable):
class QueryVariable(BaseInputVariable):
name = 'query'
description = 'The search term'
source = DataSource.VALUES
class AdminSettingsVariable(DefaultInputVariable):
related_exceptions = [KeyNotFound, InvalidKeyValue]
class AdminSettingsVariable(BaseInputVariable):
def validate(self) -> bool:
try:
_format_setting(self.name, self.value)
@@ -352,54 +406,43 @@ def input_validation() -> Union[None, Dict[str, Any]]:
Union[None, Dict[str, Any]]: `None` if the endpoint + method doesn't require input variables.
Otherwise `Dict[str, Any]` with the input variables, checked and formatted.
"""
inputs = {}
result = {}
input_variables: Dict[str, List[Union[List[InputVariable], str]]]
if request.path.startswith(admin_api_prefix):
input_variables = api_docs[
_admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
]['input_variables']
else:
input_variables = api_docs[
request.url_rule.rule.split(api_prefix)[1]
]['input_variables']
methods = get_api_docs(request).methods
method = methods[request.method]
noted_variables = method.vars
if not input_variables:
return
if not methods:
return None
if input_variables.get(request.method) is None:
return inputs
if not method:
return result
given_variables = {}
given_variables[DataSource.DATA] = request.get_json() if request.data else {}
given_variables[DataSource.VALUES] = request.values
for input_variable in input_variables[request.method]:
given_variables = DataSource(request)
for noted_var in noted_variables:
if (
input_variable.required and
not input_variable.name in given_variables[input_variable.source]
noted_var.required and
not noted_var.name in given_variables[noted_var.source]
):
raise KeyNotFound(input_variable.name)
raise KeyNotFound(noted_var.name)
input_value = given_variables[input_variable.source].get(
input_variable.name,
input_variable.default
)
value: InputVariable = input_variable(input_value)
input_value = given_variables[noted_var.source].get(noted_var.name)
value: InputVariable = noted_var(input_value)
if not value.validate():
raise InvalidKeyValue(input_variable.name, input_value)
raise InvalidKeyValue(noted_var.name, input_value)
inputs[input_variable.name] = value.value
return inputs
result[noted_var.name] = value.value
return result
api_docs: Dict[str, Dict[str, Any]] = {}
class APIBlueprint(Blueprint):
def route(
self,
rule: str,
description: str = '',
input_variables: Dict[str, List[Union[List[InputVariable], str]]] = {},
input_variables: Methods = Methods(),
requires_auth: bool = True,
**options: Any
) -> Callable[[T_route], T_route]:
@@ -411,22 +454,13 @@ class APIBlueprint(Blueprint):
else:
raise NotImplementedError
api_docs[processed_rule] = {
'endpoint': processed_rule,
'description': description,
'requires_auth': requires_auth,
'methods': options['methods'],
'input_variables': {
k: v[0]
for k, v in input_variables.items()
if v and v[0]
},
'method_descriptions': {
k: v[1]
for k, v in input_variables.items()
if v and len(v) == 2 and v[1]
}
}
api_docs[processed_rule] = ApiDocEntry(
endpoint=processed_rule,
description=description,
requires_auth=requires_auth,
used_methods=options['methods'],
methods=input_variables
)
return super().route(rule, **options)