mirror of
https://github.com/Casvt/MIND.git
synced 2026-04-25 03:00:20 -04:00
Refactored input validation
This commit is contained in:
@@ -4,9 +4,13 @@
|
||||
Input validation for the API
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from os.path import splitext
|
||||
from re import compile
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union
|
||||
|
||||
from apprise import Apprise
|
||||
from flask import Blueprint, request
|
||||
@@ -19,23 +23,41 @@ from backend.custom_exceptions import (AccessUnauthorized, InvalidKeyValue,
|
||||
UsernameInvalid, UsernameTaken,
|
||||
UserNotFound)
|
||||
from backend.helpers import (RepeatQuantity, SortingMethod,
|
||||
TimelessSortingMethod)
|
||||
TimelessSortingMethod, folder_path)
|
||||
from backend.settings import _format_setting
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flask import Request
|
||||
|
||||
api_prefix = "/api"
|
||||
_admin_api_prefix = '/admin'
|
||||
admin_api_prefix = api_prefix + _admin_api_prefix
|
||||
|
||||
color_regex = compile(r'#[0-9a-f]{6}')
|
||||
|
||||
api_docs: Dict[str, ApiDocEntry] = {}
|
||||
|
||||
|
||||
class DataSource:
|
||||
DATA = 1
|
||||
VALUES = 2
|
||||
FILES = 3
|
||||
|
||||
def __init__(self, request: Request) -> None:
|
||||
self.map = {
|
||||
self.DATA: request.get_json() if request.data else {},
|
||||
self.VALUES: request.values,
|
||||
self.FILES: request.files
|
||||
}
|
||||
return
|
||||
|
||||
def __getitem__(self, key: DataSource) -> dict:
|
||||
return self.map[key]
|
||||
|
||||
|
||||
class InputVariable(ABC):
|
||||
value: Any
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, value: Any) -> None:
|
||||
pass
|
||||
@@ -75,11 +97,51 @@ class InputVariable(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class DefaultInputVariable(InputVariable):
|
||||
@dataclass(frozen=True)
|
||||
class Method:
|
||||
description: str = ''
|
||||
vars: List[InputVariable] = field(default_factory=list)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.vars != []
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Methods:
|
||||
get: Method = Method()
|
||||
post: Method = Method()
|
||||
put: Method = Method()
|
||||
delete: Method = Method()
|
||||
|
||||
def __getitem__(self, key: str) -> Method:
|
||||
return getattr(self, key.lower())
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.get or self.post or self.put or self.delete)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ApiDocEntry:
|
||||
endpoint: str
|
||||
description: str
|
||||
requires_auth: bool
|
||||
used_methods: List[str]
|
||||
methods: Methods
|
||||
|
||||
|
||||
def get_api_docs(request: Request) -> ApiDocEntry:
|
||||
if request.path.startswith(admin_api_prefix):
|
||||
url = _admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
|
||||
else:
|
||||
url = request.url_rule.rule.split(api_prefix)[1]
|
||||
return api_docs[url]
|
||||
|
||||
|
||||
class BaseInputVariable(InputVariable):
|
||||
source = DataSource.DATA
|
||||
required = True
|
||||
default = None
|
||||
related_exceptions = []
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue]
|
||||
|
||||
def __init__(self, value: Any) -> None:
|
||||
self.value = value
|
||||
@@ -91,31 +153,38 @@ class DefaultInputVariable(InputVariable):
|
||||
return f'| {self.name} | {"Yes" if self.required else "No"} | {self.description} | N/A |'
|
||||
|
||||
|
||||
class NonRequiredVersion(InputVariable):
|
||||
class NonRequiredVersion(BaseInputVariable):
|
||||
required = False
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def __init__(self, value: Any) -> None:
|
||||
super().__init__(
|
||||
value
|
||||
if value is not None else
|
||||
self.default
|
||||
)
|
||||
return
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or super().validate()
|
||||
|
||||
|
||||
class UsernameVariable(DefaultInputVariable):
|
||||
class UsernameVariable(BaseInputVariable):
|
||||
name = 'username'
|
||||
description = 'The username of the user account'
|
||||
related_exceptions = [KeyNotFound, UserNotFound]
|
||||
|
||||
|
||||
class PasswordVariable(DefaultInputVariable):
|
||||
class PasswordCreateVariable(BaseInputVariable):
|
||||
name = 'password'
|
||||
description = 'The password of the user account'
|
||||
related_exceptions = [KeyNotFound, AccessUnauthorized]
|
||||
|
||||
|
||||
class NewPasswordVariable(PasswordVariable):
|
||||
name = 'new_password'
|
||||
description = 'The new password of the user account'
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
|
||||
class PasswordVariable(PasswordCreateVariable):
|
||||
related_exceptions = [KeyNotFound, AccessUnauthorized]
|
||||
|
||||
|
||||
class UsernameCreateVariable(UsernameVariable):
|
||||
related_exceptions = [
|
||||
KeyNotFound,
|
||||
@@ -124,44 +193,39 @@ class UsernameCreateVariable(UsernameVariable):
|
||||
]
|
||||
|
||||
|
||||
class PasswordCreateVariable(PasswordVariable):
|
||||
class NewPasswordVariable(BaseInputVariable):
|
||||
name = 'new_password'
|
||||
description = 'The new password of the user account'
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
|
||||
class TitleVariable(DefaultInputVariable):
|
||||
class TitleVariable(BaseInputVariable):
|
||||
name = 'title'
|
||||
description = 'The title of the entry'
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
|
||||
class URLVariable(DefaultInputVariable):
|
||||
class URLVariable(BaseInputVariable):
|
||||
name = 'url'
|
||||
description = 'The Apprise URL of the notification service'
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return Apprise().add(self.value)
|
||||
return super().validate() and Apprise().add(self.value)
|
||||
|
||||
|
||||
class EditTitleVariable(NonRequiredVersion, TitleVariable):
|
||||
related_exceptions = []
|
||||
pass
|
||||
|
||||
|
||||
class EditURLVariable(NonRequiredVersion, URLVariable):
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
pass
|
||||
|
||||
|
||||
class SortByVariable(DefaultInputVariable):
|
||||
class SortByVariable(NonRequiredVersion, BaseInputVariable):
|
||||
name = 'sort_by'
|
||||
description = 'How to sort the result'
|
||||
required = False
|
||||
source = DataSource.VALUES
|
||||
_options = [k.lower() for k in SortingMethod._member_names_]
|
||||
default = SortingMethod._member_names_[0].lower()
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def __init__(self, value: str) -> None:
|
||||
self.value = value
|
||||
|
||||
def validate(self) -> bool:
|
||||
if not self.value in self._options:
|
||||
@@ -179,7 +243,7 @@ class SortByVariable(DefaultInputVariable):
|
||||
)
|
||||
|
||||
|
||||
class TemplateSortByVariable(SortByVariable):
|
||||
class TimelessSortByVariable(SortByVariable):
|
||||
_options = [k.lower() for k in TimelessSortingMethod._member_names_]
|
||||
default = TimelessSortingMethod._member_names_[0].lower()
|
||||
|
||||
@@ -190,11 +254,8 @@ class TemplateSortByVariable(SortByVariable):
|
||||
self.value = TimelessSortingMethod[self.value.upper()]
|
||||
return True
|
||||
|
||||
class StaticReminderSortByVariable(TemplateSortByVariable):
|
||||
pass
|
||||
|
||||
|
||||
class TimeVariable(DefaultInputVariable):
|
||||
class TimeVariable(BaseInputVariable):
|
||||
name = 'time'
|
||||
description = 'The UTC epoch timestamp that the reminder should be sent at'
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue, InvalidTime]
|
||||
@@ -207,10 +268,13 @@ class EditTimeVariable(NonRequiredVersion, TimeVariable):
|
||||
related_exceptions = [InvalidKeyValue, InvalidTime]
|
||||
|
||||
|
||||
class NotificationServicesVariable(DefaultInputVariable):
|
||||
class NotificationServicesVariable(BaseInputVariable):
|
||||
name = 'notification_services'
|
||||
description = "Array of the id's of the notification services to use to send the notification"
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue, NotificationServiceNotFound]
|
||||
related_exceptions = [
|
||||
KeyNotFound, InvalidKeyValue,
|
||||
NotificationServiceNotFound
|
||||
]
|
||||
|
||||
def validate(self) -> bool:
|
||||
if not isinstance(self.value, list):
|
||||
@@ -227,7 +291,7 @@ class EditNotificationServicesVariable(NonRequiredVersion, NotificationServicesV
|
||||
related_exceptions = [InvalidKeyValue, NotificationServiceNotFound]
|
||||
|
||||
|
||||
class TextVariable(NonRequiredVersion, DefaultInputVariable):
|
||||
class TextVariable(NonRequiredVersion, BaseInputVariable):
|
||||
name = 'text'
|
||||
description = 'The body of the entry'
|
||||
default = ''
|
||||
@@ -236,13 +300,10 @@ class TextVariable(NonRequiredVersion, DefaultInputVariable):
|
||||
return isinstance(self.value, str)
|
||||
|
||||
|
||||
class RepeatQuantityVariable(DefaultInputVariable):
|
||||
class RepeatQuantityVariable(NonRequiredVersion, BaseInputVariable):
|
||||
name = 'repeat_quantity'
|
||||
description = 'The quantity of the repeat_interval'
|
||||
required = False
|
||||
_options = [m.lower() for m in RepeatQuantity._member_names_]
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
if self.value is None:
|
||||
@@ -263,12 +324,9 @@ class RepeatQuantityVariable(DefaultInputVariable):
|
||||
)
|
||||
|
||||
|
||||
class RepeatIntervalVariable(DefaultInputVariable):
|
||||
class RepeatIntervalVariable(NonRequiredVersion, BaseInputVariable):
|
||||
name = 'repeat_interval'
|
||||
description = 'The number of the interval'
|
||||
required = False
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return (
|
||||
@@ -280,12 +338,9 @@ class RepeatIntervalVariable(DefaultInputVariable):
|
||||
)
|
||||
|
||||
|
||||
class WeekDaysVariable(DefaultInputVariable):
|
||||
class WeekDaysVariable(NonRequiredVersion, BaseInputVariable):
|
||||
name = 'weekdays'
|
||||
description = 'On which days of the weeks to run the reminder'
|
||||
required = False
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
_options = {0, 1, 2, 3, 4, 5, 6}
|
||||
|
||||
def validate(self) -> bool:
|
||||
@@ -296,26 +351,25 @@ class WeekDaysVariable(DefaultInputVariable):
|
||||
)
|
||||
|
||||
|
||||
class ColorVariable(DefaultInputVariable):
|
||||
class ColorVariable(NonRequiredVersion, BaseInputVariable):
|
||||
name = 'color'
|
||||
description = 'The hex code of the color of the entry, which is shown in the web-ui'
|
||||
required = False
|
||||
default = None
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or color_regex.search(self.value)
|
||||
super()
|
||||
return self.value is None or (
|
||||
isinstance(self.value, str)
|
||||
and color_regex.search(self.value)
|
||||
)
|
||||
|
||||
|
||||
class QueryVariable(DefaultInputVariable):
|
||||
class QueryVariable(BaseInputVariable):
|
||||
name = 'query'
|
||||
description = 'The search term'
|
||||
source = DataSource.VALUES
|
||||
|
||||
|
||||
class AdminSettingsVariable(DefaultInputVariable):
|
||||
related_exceptions = [KeyNotFound, InvalidKeyValue]
|
||||
|
||||
class AdminSettingsVariable(BaseInputVariable):
|
||||
def validate(self) -> bool:
|
||||
try:
|
||||
_format_setting(self.name, self.value)
|
||||
@@ -352,54 +406,43 @@ def input_validation() -> Union[None, Dict[str, Any]]:
|
||||
Union[None, Dict[str, Any]]: `None` if the endpoint + method doesn't require input variables.
|
||||
Otherwise `Dict[str, Any]` with the input variables, checked and formatted.
|
||||
"""
|
||||
inputs = {}
|
||||
result = {}
|
||||
|
||||
input_variables: Dict[str, List[Union[List[InputVariable], str]]]
|
||||
if request.path.startswith(admin_api_prefix):
|
||||
input_variables = api_docs[
|
||||
_admin_api_prefix + request.url_rule.rule.split(admin_api_prefix)[1]
|
||||
]['input_variables']
|
||||
else:
|
||||
input_variables = api_docs[
|
||||
request.url_rule.rule.split(api_prefix)[1]
|
||||
]['input_variables']
|
||||
methods = get_api_docs(request).methods
|
||||
method = methods[request.method]
|
||||
noted_variables = method.vars
|
||||
|
||||
if not input_variables:
|
||||
return
|
||||
if not methods:
|
||||
return None
|
||||
|
||||
if input_variables.get(request.method) is None:
|
||||
return inputs
|
||||
if not method:
|
||||
return result
|
||||
|
||||
given_variables = {}
|
||||
given_variables[DataSource.DATA] = request.get_json() if request.data else {}
|
||||
given_variables[DataSource.VALUES] = request.values
|
||||
for input_variable in input_variables[request.method]:
|
||||
given_variables = DataSource(request)
|
||||
|
||||
for noted_var in noted_variables:
|
||||
if (
|
||||
input_variable.required and
|
||||
not input_variable.name in given_variables[input_variable.source]
|
||||
noted_var.required and
|
||||
not noted_var.name in given_variables[noted_var.source]
|
||||
):
|
||||
raise KeyNotFound(input_variable.name)
|
||||
raise KeyNotFound(noted_var.name)
|
||||
|
||||
input_value = given_variables[input_variable.source].get(
|
||||
input_variable.name,
|
||||
input_variable.default
|
||||
)
|
||||
value: InputVariable = input_variable(input_value)
|
||||
input_value = given_variables[noted_var.source].get(noted_var.name)
|
||||
value: InputVariable = noted_var(input_value)
|
||||
|
||||
if not value.validate():
|
||||
raise InvalidKeyValue(input_variable.name, input_value)
|
||||
raise InvalidKeyValue(noted_var.name, input_value)
|
||||
|
||||
inputs[input_variable.name] = value.value
|
||||
return inputs
|
||||
result[noted_var.name] = value.value
|
||||
return result
|
||||
|
||||
|
||||
api_docs: Dict[str, Dict[str, Any]] = {}
|
||||
class APIBlueprint(Blueprint):
|
||||
def route(
|
||||
self,
|
||||
rule: str,
|
||||
description: str = '',
|
||||
input_variables: Dict[str, List[Union[List[InputVariable], str]]] = {},
|
||||
input_variables: Methods = Methods(),
|
||||
requires_auth: bool = True,
|
||||
**options: Any
|
||||
) -> Callable[[T_route], T_route]:
|
||||
@@ -411,22 +454,13 @@ class APIBlueprint(Blueprint):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
api_docs[processed_rule] = {
|
||||
'endpoint': processed_rule,
|
||||
'description': description,
|
||||
'requires_auth': requires_auth,
|
||||
'methods': options['methods'],
|
||||
'input_variables': {
|
||||
k: v[0]
|
||||
for k, v in input_variables.items()
|
||||
if v and v[0]
|
||||
},
|
||||
'method_descriptions': {
|
||||
k: v[1]
|
||||
for k, v in input_variables.items()
|
||||
if v and len(v) == 2 and v[1]
|
||||
}
|
||||
}
|
||||
api_docs[processed_rule] = ApiDocEntry(
|
||||
endpoint=processed_rule,
|
||||
description=description,
|
||||
requires_auth=requires_auth,
|
||||
used_methods=options['methods'],
|
||||
methods=input_variables
|
||||
)
|
||||
|
||||
return super().route(rule, **options)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user