diff --git a/backend/base/definitions.py b/backend/base/definitions.py index 680bf33..a00f6f9 100644 --- a/backend/base/definitions.py +++ b/backend/base/definitions.py @@ -38,8 +38,8 @@ EndpointResponse = Union[ None ] EndpointHandler = Union[ - Callable[[Dict[str, Any]], EndpointResponse], - Callable[[Dict[str, Any], int], EndpointResponse] + Callable[[], EndpointResponse], + Callable[[int], EndpointResponse] ] diff --git a/frontend/api.py b/frontend/api.py index 73e5013..6d5912e 100644 --- a/frontend/api.py +++ b/frontend/api.py @@ -5,14 +5,14 @@ from io import BytesIO, StringIO from os import remove, urandom from os.path import basename, exists from time import time as epoch_time -from typing import Any, Callable, Dict +from typing import Any, Dict from flask import g, request, send_file from backend.base.custom_exceptions import (APIKeyExpired, APIKeyInvalid, LogFileNotFound) from backend.base.definitions import (ApiKeyEntry, Constants, - EndpointHandler, SendResult, StartType) + SendResult, StartType) from backend.base.helpers import folder_path, return_api from backend.base.logging import LOGGER, get_log_filepath from backend.implementations.apprise_parser import get_apprise_services @@ -46,7 +46,7 @@ from frontend.input_validation import (AboutData, AuthLoginData, UsersManagementData, admin_api, api, get_api_docs, input_validation) -# region General variables and functions +# region Auth and input users = Users() api_key_map: Dict[int, ApiKeyEntry] = {} @@ -100,24 +100,22 @@ def auth() -> None: return -def endpoint_wrapper(method: EndpointHandler) -> Callable: - def wrapper(*args, **kwargs): - requires_auth = get_api_docs(request).requires_auth +@api.before_request +@admin_api.before_request +def api_auth_and_input_validation() -> None: + requires_auth = get_api_docs(request).requires_auth - if requires_auth: - auth() + if requires_auth: + auth() - inputs = input_validation() - return method(inputs, *args, **kwargs) - - wrapper.__name__ = method.__name__ - return wrapper + g.inputs = input_validation() + return # region Auth @api.route('/auth/login', AuthLoginData) -@endpoint_wrapper -def api_login(inputs: Dict[str, Any]): +def api_login(): + inputs: Dict[str, Any] = g.inputs user = users.login(inputs['username'], inputs['password']) # Login successful @@ -146,15 +144,13 @@ def api_login(inputs: Dict[str, Any]): @api.route('/auth/logout', AuthLogoutData) -@endpoint_wrapper -def api_logout(inputs: Dict[str, Any]): +def api_logout(): api_key_map.pop(g.hashed_api_key) return return_api({}, code=201) @api.route('/auth/status', AuthStatusData) -@endpoint_wrapper -def api_status(inputs: Dict[str, Any]): +def api_status(): map_entry = api_key_map[g.hashed_api_key] user_data = map_entry.user_data.get() result = { @@ -167,17 +163,17 @@ def api_status(inputs: Dict[str, Any]): # region User @api.route('/user/add', UsersAddData) -@endpoint_wrapper -def api_add_user(inputs: Dict[str, str]): +def api_add_user(): + inputs: Dict[str, Any] = g.inputs users.add(inputs['username'], inputs['password']) return return_api({}, code=201) @api.route('/user', UsersData) -@endpoint_wrapper -def api_manage_user(inputs: Dict[str, Any]): +def api_manage_user(): user = api_key_map[g.hashed_api_key].user_data if request.method == 'PUT': + inputs: Dict[str, Any] = g.inputs if inputs['new_username']: user.update_username(inputs['new_username']) if inputs['new_password']: @@ -192,8 +188,7 @@ def api_manage_user(inputs: Dict[str, Any]): # region Notification Service @api.route('/notificationservices', NotificationServicesData) -@endpoint_wrapper -def api_notification_services_list(inputs: Dict[str, str]): +def api_notification_services_list(): services = NotificationServices( api_key_map[g.hashed_api_key].user_data.user_id ) @@ -203,6 +198,7 @@ def api_notification_services_list(inputs: Dict[str, str]): return return_api(result=[r.todict() for r in result]) elif request.method == 'POST': + inputs: Dict[str, Any] = g.inputs result = services.add( title=inputs['title'], url=inputs['url'] @@ -211,16 +207,14 @@ def api_notification_services_list(inputs: Dict[str, str]): @api.route('/notificationservices/available', AvailableNotificationServicesData) -@endpoint_wrapper -def api_notification_service_available(inputs: Dict[str, str]): +def api_notification_service_available(): result = get_apprise_services() return return_api(result) # type: ignore @api.route('/notificationservices/test', TestNotificationServiceURLData) -@endpoint_wrapper -def api_test_service(inputs: Dict[str, Any]): - success = NotificationServices.test(inputs['url']) +def api_test_service(): + success = NotificationServices.test(g.inputs['url']) return return_api( { 'success': success == SendResult.SUCCESS, @@ -231,8 +225,8 @@ def api_test_service(inputs: Dict[str, Any]): @api.route('/notificationservices/', NotificationServiceData) -@endpoint_wrapper -def api_notification_service(inputs: Dict[str, Any], n_id: int): +def api_notification_service(n_id: int): + inputs: Dict[str, Any] = g.inputs user_id = api_key_map[g.hashed_api_key].user_data.user_id service = NotificationServices(user_id).get_one(n_id) @@ -256,8 +250,8 @@ def api_notification_service(inputs: Dict[str, Any], n_id: int): # region Library @api.route('/reminders', RemindersData) -@endpoint_wrapper -def api_reminders_list(inputs: Dict[str, Any]): +def api_reminders_list(): + inputs: Dict[str, Any] = g.inputs reminders = Reminders(api_key_map[g.hashed_api_key].user_data.user_id) if request.method == 'GET': @@ -281,16 +275,16 @@ def api_reminders_list(inputs: Dict[str, Any]): @api.route('/reminders/search', SearchRemindersData) -@endpoint_wrapper -def api_reminders_query(inputs: Dict[str, Any]): +def api_reminders_query(): + inputs: Dict[str, Any] = g.inputs reminders = Reminders(api_key_map[g.hashed_api_key].user_data.user_id) result = reminders.search(inputs['query'], inputs['sort_by']) return return_api([r.todict() for r in result]) @api.route('/reminders/test', TestRemindersData) -@endpoint_wrapper -def api_test_reminder(inputs: Dict[str, Any]): +def api_test_reminder(): + inputs: Dict[str, Any] = g.inputs Reminders( api_key_map[g.hashed_api_key].user_data.user_id ).test_reminder( @@ -302,8 +296,7 @@ def api_test_reminder(inputs: Dict[str, Any]): @api.route('/reminders/', ReminderData) -@endpoint_wrapper -def api_get_reminder(inputs: Dict[str, Any], r_id: int): +def api_get_reminder(r_id: int): reminders = Reminders( api_key_map[g.hashed_api_key].user_data.user_id ) @@ -313,6 +306,7 @@ def api_get_reminder(inputs: Dict[str, Any], r_id: int): return return_api(result.todict()) elif request.method == 'PUT': + inputs: Dict[str, Any] = g.inputs result = reminders.get_one(r_id).update( title=inputs['title'], time=inputs['time'], @@ -334,8 +328,8 @@ def api_get_reminder(inputs: Dict[str, Any], r_id: int): # region Template @api.route('/templates', TemplatesData) -@endpoint_wrapper -def api_get_templates(inputs: Dict[str, Any]): +def api_get_templates(): + inputs: Dict[str, Any] = g.inputs templates = Templates( api_key_map[g.hashed_api_key].user_data.user_id ) @@ -355,8 +349,8 @@ def api_get_templates(inputs: Dict[str, Any]): @api.route('/templates/search', SearchTemplatesData) -@endpoint_wrapper -def api_templates_query(inputs: Dict[str, Any]): +def api_templates_query(): + inputs: Dict[str, Any] = g.inputs templates = Templates( api_key_map[g.hashed_api_key].user_data.user_id ) @@ -365,8 +359,7 @@ def api_templates_query(inputs: Dict[str, Any]): @api.route('/templates/', TemplateData) -@endpoint_wrapper -def api_get_template(inputs: Dict[str, Any], t_id: int): +def api_get_template(t_id: int): template = Templates( api_key_map[g.hashed_api_key].user_data.user_id ).get_one(t_id) @@ -376,6 +369,7 @@ def api_get_template(inputs: Dict[str, Any], t_id: int): return return_api(result.todict()) elif request.method == 'PUT': + inputs: Dict[str, Any] = g.inputs result = template.update( title=inputs['title'], notification_services=inputs['notification_services'], @@ -391,8 +385,8 @@ def api_get_template(inputs: Dict[str, Any], t_id: int): # region Static Reminder @api.route('/staticreminders', StaticRemindersData) -@endpoint_wrapper -def api_static_reminders_list(inputs: Dict[str, Any]): +def api_static_reminders_list(): + inputs: Dict[str, Any] = g.inputs reminders = StaticReminders( api_key_map[g.hashed_api_key].user_data.user_id ) @@ -412,8 +406,8 @@ def api_static_reminders_list(inputs: Dict[str, Any]): @api.route('/staticreminders/search', SearchStaticRemindersData) -@endpoint_wrapper -def api_static_reminders_query(inputs: Dict[str, Any]): +def api_static_reminders_query(): + inputs: Dict[str, Any] = g.inputs result = StaticReminders( api_key_map[g.hashed_api_key].user_data.user_id ).search(inputs['query'], inputs['sort_by']) @@ -421,8 +415,7 @@ def api_static_reminders_query(inputs: Dict[str, Any]): @api.route('/staticreminders/', StaticReminderData) -@endpoint_wrapper -def api_get_static_reminder(inputs: Dict[str, Any], s_id: int): +def api_get_static_reminder(s_id: int): reminders = StaticReminders( api_key_map[g.hashed_api_key].user_data.user_id ) @@ -436,6 +429,7 @@ def api_get_static_reminder(inputs: Dict[str, Any], s_id: int): return return_api({}, code=201) elif request.method == 'PUT': + inputs: Dict[str, Any] = g.inputs result = reminders.get_one(s_id).update( title=inputs['title'], notification_services=inputs['notification_services'], @@ -451,34 +445,30 @@ def api_get_static_reminder(inputs: Dict[str, Any], s_id: int): # region Admin Panel @admin_api.route('/shutdown', ShutdownData) -@endpoint_wrapper -def api_shutdown(inputs: Dict[str, Any]): +def api_shutdown(): Server().shutdown() return return_api({}) @admin_api.route('/restart', RestartData) -@endpoint_wrapper -def api_restart(inputs: Dict[str, Any]): +def api_restart(): Server().restart() return return_api({}) @api.route('/settings', PublicSettingsData) -@endpoint_wrapper -def api_settings(inputs: Dict[str, Any]): +def api_settings(): return return_api(Settings().get_public_settings().todict()) @api.route('/about', AboutData) -@endpoint_wrapper -def api_about(inputs: Dict[str, Any]): +def api_about(): return return_api(get_about_data()) @admin_api.route('/settings', SettingsData) -@endpoint_wrapper -def api_admin_settings(inputs: Dict[str, Any]): +def api_admin_settings(): + inputs: Dict[str, Any] = g.inputs settings = Settings() if request.method == 'GET': @@ -530,8 +520,7 @@ def api_admin_settings(inputs: Dict[str, Any]): @admin_api.route('/logs', LogfileData) -@endpoint_wrapper -def api_admin_logs(inputs: Dict[str, Any]): +def api_admin_logs(): file = get_log_filepath() if not exists(file): raise LogFileNotFound(file) @@ -552,22 +541,22 @@ def api_admin_logs(inputs: Dict[str, Any]): @admin_api.route('/users', UsersManagementData) -@endpoint_wrapper -def api_admin_users(inputs: Dict[str, Any]): +def api_admin_users(): if request.method == 'GET': result = users.get_all() return return_api([r.todict() for r in result]) elif request.method == 'POST': + inputs: Dict[str, Any] = g.inputs users.add(inputs['username'], inputs['password'], True) return return_api({}, code=201) @admin_api.route('/users/', UserManagementData) -@endpoint_wrapper -def api_admin_user(inputs: Dict[str, Any], u_id: int): +def api_admin_user(u_id: int): user = users.get_one(u_id) if request.method == 'PUT': + inputs: Dict[str, Any] = g.inputs if inputs['new_username']: user.update_username(inputs['new_username']) if inputs['new_password']: @@ -584,8 +573,7 @@ def api_admin_user(inputs: Dict[str, Any], u_id: int): @admin_api.route('/database', DatabaseData) -@endpoint_wrapper -def api_admin_database(inputs: Dict[str, Any]): +def api_admin_database(): if request.method == "GET": filename = create_database_copy(folder_path('db')) @@ -602,19 +590,18 @@ def api_admin_database(inputs: Dict[str, Any]): ), 200 elif request.method == "POST": + inputs: Dict[str, Any] = g.inputs import_db(inputs['file'], inputs['copy_hosting_settings']) return return_api({}) @admin_api.route('/database/backups', BackupsData) -@endpoint_wrapper -def api_admin_backups(inputs: Dict[str, Any]): +def api_admin_backups(): return return_api(get_backups()) @admin_api.route('/database/backups/', BackupData) -@endpoint_wrapper -def api_admin_backup(inputs: Dict[str, Any], b_idx: int): +def api_admin_backup(b_idx: int): if request.method == "GET": return send_file( get_backup(b_idx)['filepath'], @@ -622,5 +609,6 @@ def api_admin_backup(inputs: Dict[str, Any], b_idx: int): ), 200 elif request.method == "POST": + inputs: Dict[str, Any] = g.inputs import_db_backup(b_idx, inputs['copy_hosting_settings']) return return_api({}) diff --git a/frontend/input_validation.py b/frontend/input_validation.py index 26c97a5..8110694 100644 --- a/frontend/input_validation.py +++ b/frontend/input_validation.py @@ -22,14 +22,14 @@ from backend.base.custom_exceptions import (AccessUnauthorized, NotificationServiceNotFound, UsernameInvalid, UsernameTaken) from backend.base.definitions import (Constants, DataSource, DataType, - MindException, RepeatQuantity, - SortingMethod, TimelessSortingMethod) + EndpointHandler, MindException, + RepeatQuantity, SortingMethod, + TimelessSortingMethod) from backend.base.helpers import folder_path from backend.internals.settings import SettingsValues if TYPE_CHECKING: from flask import Request - from flask.sansio.scaffold import T_route # =================== @@ -980,8 +980,7 @@ class APIBlueprint(Blueprint): rule: str, endpoint_data: Type[EndpointData], **options: Any - ) -> Callable[[T_route], T_route]: - + ) -> Callable[[EndpointHandler], EndpointHandler]: if self == api: processed_rule = rule elif self == admin_api: @@ -994,7 +993,7 @@ class APIBlueprint(Blueprint): if "methods" not in options: options["methods"] = API_DOCS[processed_rule].methods.used_methods() - return super().route(rule, **options) + return super().route(rule, **options) # type: ignore api = APIBlueprint('api', __name__)