Moved API inputs to context globals

This commit is contained in:
CasVT
2025-08-25 23:29:16 +02:00
parent 6244aff737
commit 42891f5f32
3 changed files with 69 additions and 82 deletions

View File

@@ -38,8 +38,8 @@ EndpointResponse = Union[
None
]
EndpointHandler = Union[
Callable[[Dict[str, Any]], EndpointResponse],
Callable[[Dict[str, Any], int], EndpointResponse]
Callable[[], EndpointResponse],
Callable[[int], EndpointResponse]
]

View File

@@ -5,14 +5,14 @@ from io import BytesIO, StringIO
from os import remove, urandom
from os.path import basename, exists
from time import time as epoch_time
from typing import Any, Callable, Dict
from typing import Any, Dict
from flask import g, request, send_file
from backend.base.custom_exceptions import (APIKeyExpired, APIKeyInvalid,
LogFileNotFound)
from backend.base.definitions import (ApiKeyEntry, Constants,
EndpointHandler, SendResult, StartType)
SendResult, StartType)
from backend.base.helpers import folder_path, return_api
from backend.base.logging import LOGGER, get_log_filepath
from backend.implementations.apprise_parser import get_apprise_services
@@ -46,7 +46,7 @@ from frontend.input_validation import (AboutData, AuthLoginData,
UsersManagementData, admin_api, api,
get_api_docs, input_validation)
# region General variables and functions
# region Auth and input
users = Users()
api_key_map: Dict[int, ApiKeyEntry] = {}
@@ -100,24 +100,22 @@ def auth() -> None:
return
def endpoint_wrapper(method: EndpointHandler) -> Callable:
def wrapper(*args, **kwargs):
requires_auth = get_api_docs(request).requires_auth
@api.before_request
@admin_api.before_request
def api_auth_and_input_validation() -> None:
requires_auth = get_api_docs(request).requires_auth
if requires_auth:
auth()
if requires_auth:
auth()
inputs = input_validation()
return method(inputs, *args, **kwargs)
wrapper.__name__ = method.__name__
return wrapper
g.inputs = input_validation()
return
# region Auth
@api.route('/auth/login', AuthLoginData)
@endpoint_wrapper
def api_login(inputs: Dict[str, Any]):
def api_login():
inputs: Dict[str, Any] = g.inputs
user = users.login(inputs['username'], inputs['password'])
# Login successful
@@ -146,15 +144,13 @@ def api_login(inputs: Dict[str, Any]):
@api.route('/auth/logout', AuthLogoutData)
@endpoint_wrapper
def api_logout(inputs: Dict[str, Any]):
def api_logout():
api_key_map.pop(g.hashed_api_key)
return return_api({}, code=201)
@api.route('/auth/status', AuthStatusData)
@endpoint_wrapper
def api_status(inputs: Dict[str, Any]):
def api_status():
map_entry = api_key_map[g.hashed_api_key]
user_data = map_entry.user_data.get()
result = {
@@ -167,17 +163,17 @@ def api_status(inputs: Dict[str, Any]):
# region User
@api.route('/user/add', UsersAddData)
@endpoint_wrapper
def api_add_user(inputs: Dict[str, str]):
def api_add_user():
inputs: Dict[str, Any] = g.inputs
users.add(inputs['username'], inputs['password'])
return return_api({}, code=201)
@api.route('/user', UsersData)
@endpoint_wrapper
def api_manage_user(inputs: Dict[str, Any]):
def api_manage_user():
user = api_key_map[g.hashed_api_key].user_data
if request.method == 'PUT':
inputs: Dict[str, Any] = g.inputs
if inputs['new_username']:
user.update_username(inputs['new_username'])
if inputs['new_password']:
@@ -192,8 +188,7 @@ def api_manage_user(inputs: Dict[str, Any]):
# region Notification Service
@api.route('/notificationservices', NotificationServicesData)
@endpoint_wrapper
def api_notification_services_list(inputs: Dict[str, str]):
def api_notification_services_list():
services = NotificationServices(
api_key_map[g.hashed_api_key].user_data.user_id
)
@@ -203,6 +198,7 @@ def api_notification_services_list(inputs: Dict[str, str]):
return return_api(result=[r.todict() for r in result])
elif request.method == 'POST':
inputs: Dict[str, Any] = g.inputs
result = services.add(
title=inputs['title'],
url=inputs['url']
@@ -211,16 +207,14 @@ def api_notification_services_list(inputs: Dict[str, str]):
@api.route('/notificationservices/available', AvailableNotificationServicesData)
@endpoint_wrapper
def api_notification_service_available(inputs: Dict[str, str]):
def api_notification_service_available():
result = get_apprise_services()
return return_api(result) # type: ignore
@api.route('/notificationservices/test', TestNotificationServiceURLData)
@endpoint_wrapper
def api_test_service(inputs: Dict[str, Any]):
success = NotificationServices.test(inputs['url'])
def api_test_service():
success = NotificationServices.test(g.inputs['url'])
return return_api(
{
'success': success == SendResult.SUCCESS,
@@ -231,8 +225,8 @@ def api_test_service(inputs: Dict[str, Any]):
@api.route('/notificationservices/<int:n_id>', NotificationServiceData)
@endpoint_wrapper
def api_notification_service(inputs: Dict[str, Any], n_id: int):
def api_notification_service(n_id: int):
inputs: Dict[str, Any] = g.inputs
user_id = api_key_map[g.hashed_api_key].user_data.user_id
service = NotificationServices(user_id).get_one(n_id)
@@ -256,8 +250,8 @@ def api_notification_service(inputs: Dict[str, Any], n_id: int):
# region Library
@api.route('/reminders', RemindersData)
@endpoint_wrapper
def api_reminders_list(inputs: Dict[str, Any]):
def api_reminders_list():
inputs: Dict[str, Any] = g.inputs
reminders = Reminders(api_key_map[g.hashed_api_key].user_data.user_id)
if request.method == 'GET':
@@ -281,16 +275,16 @@ def api_reminders_list(inputs: Dict[str, Any]):
@api.route('/reminders/search', SearchRemindersData)
@endpoint_wrapper
def api_reminders_query(inputs: Dict[str, Any]):
def api_reminders_query():
inputs: Dict[str, Any] = g.inputs
reminders = Reminders(api_key_map[g.hashed_api_key].user_data.user_id)
result = reminders.search(inputs['query'], inputs['sort_by'])
return return_api([r.todict() for r in result])
@api.route('/reminders/test', TestRemindersData)
@endpoint_wrapper
def api_test_reminder(inputs: Dict[str, Any]):
def api_test_reminder():
inputs: Dict[str, Any] = g.inputs
Reminders(
api_key_map[g.hashed_api_key].user_data.user_id
).test_reminder(
@@ -302,8 +296,7 @@ def api_test_reminder(inputs: Dict[str, Any]):
@api.route('/reminders/<int:r_id>', ReminderData)
@endpoint_wrapper
def api_get_reminder(inputs: Dict[str, Any], r_id: int):
def api_get_reminder(r_id: int):
reminders = Reminders(
api_key_map[g.hashed_api_key].user_data.user_id
)
@@ -313,6 +306,7 @@ def api_get_reminder(inputs: Dict[str, Any], r_id: int):
return return_api(result.todict())
elif request.method == 'PUT':
inputs: Dict[str, Any] = g.inputs
result = reminders.get_one(r_id).update(
title=inputs['title'],
time=inputs['time'],
@@ -334,8 +328,8 @@ def api_get_reminder(inputs: Dict[str, Any], r_id: int):
# region Template
@api.route('/templates', TemplatesData)
@endpoint_wrapper
def api_get_templates(inputs: Dict[str, Any]):
def api_get_templates():
inputs: Dict[str, Any] = g.inputs
templates = Templates(
api_key_map[g.hashed_api_key].user_data.user_id
)
@@ -355,8 +349,8 @@ def api_get_templates(inputs: Dict[str, Any]):
@api.route('/templates/search', SearchTemplatesData)
@endpoint_wrapper
def api_templates_query(inputs: Dict[str, Any]):
def api_templates_query():
inputs: Dict[str, Any] = g.inputs
templates = Templates(
api_key_map[g.hashed_api_key].user_data.user_id
)
@@ -365,8 +359,7 @@ def api_templates_query(inputs: Dict[str, Any]):
@api.route('/templates/<int:t_id>', TemplateData)
@endpoint_wrapper
def api_get_template(inputs: Dict[str, Any], t_id: int):
def api_get_template(t_id: int):
template = Templates(
api_key_map[g.hashed_api_key].user_data.user_id
).get_one(t_id)
@@ -376,6 +369,7 @@ def api_get_template(inputs: Dict[str, Any], t_id: int):
return return_api(result.todict())
elif request.method == 'PUT':
inputs: Dict[str, Any] = g.inputs
result = template.update(
title=inputs['title'],
notification_services=inputs['notification_services'],
@@ -391,8 +385,8 @@ def api_get_template(inputs: Dict[str, Any], t_id: int):
# region Static Reminder
@api.route('/staticreminders', StaticRemindersData)
@endpoint_wrapper
def api_static_reminders_list(inputs: Dict[str, Any]):
def api_static_reminders_list():
inputs: Dict[str, Any] = g.inputs
reminders = StaticReminders(
api_key_map[g.hashed_api_key].user_data.user_id
)
@@ -412,8 +406,8 @@ def api_static_reminders_list(inputs: Dict[str, Any]):
@api.route('/staticreminders/search', SearchStaticRemindersData)
@endpoint_wrapper
def api_static_reminders_query(inputs: Dict[str, Any]):
def api_static_reminders_query():
inputs: Dict[str, Any] = g.inputs
result = StaticReminders(
api_key_map[g.hashed_api_key].user_data.user_id
).search(inputs['query'], inputs['sort_by'])
@@ -421,8 +415,7 @@ def api_static_reminders_query(inputs: Dict[str, Any]):
@api.route('/staticreminders/<int:s_id>', StaticReminderData)
@endpoint_wrapper
def api_get_static_reminder(inputs: Dict[str, Any], s_id: int):
def api_get_static_reminder(s_id: int):
reminders = StaticReminders(
api_key_map[g.hashed_api_key].user_data.user_id
)
@@ -436,6 +429,7 @@ def api_get_static_reminder(inputs: Dict[str, Any], s_id: int):
return return_api({}, code=201)
elif request.method == 'PUT':
inputs: Dict[str, Any] = g.inputs
result = reminders.get_one(s_id).update(
title=inputs['title'],
notification_services=inputs['notification_services'],
@@ -451,34 +445,30 @@ def api_get_static_reminder(inputs: Dict[str, Any], s_id: int):
# region Admin Panel
@admin_api.route('/shutdown', ShutdownData)
@endpoint_wrapper
def api_shutdown(inputs: Dict[str, Any]):
def api_shutdown():
Server().shutdown()
return return_api({})
@admin_api.route('/restart', RestartData)
@endpoint_wrapper
def api_restart(inputs: Dict[str, Any]):
def api_restart():
Server().restart()
return return_api({})
@api.route('/settings', PublicSettingsData)
@endpoint_wrapper
def api_settings(inputs: Dict[str, Any]):
def api_settings():
return return_api(Settings().get_public_settings().todict())
@api.route('/about', AboutData)
@endpoint_wrapper
def api_about(inputs: Dict[str, Any]):
def api_about():
return return_api(get_about_data())
@admin_api.route('/settings', SettingsData)
@endpoint_wrapper
def api_admin_settings(inputs: Dict[str, Any]):
def api_admin_settings():
inputs: Dict[str, Any] = g.inputs
settings = Settings()
if request.method == 'GET':
@@ -530,8 +520,7 @@ def api_admin_settings(inputs: Dict[str, Any]):
@admin_api.route('/logs', LogfileData)
@endpoint_wrapper
def api_admin_logs(inputs: Dict[str, Any]):
def api_admin_logs():
file = get_log_filepath()
if not exists(file):
raise LogFileNotFound(file)
@@ -552,22 +541,22 @@ def api_admin_logs(inputs: Dict[str, Any]):
@admin_api.route('/users', UsersManagementData)
@endpoint_wrapper
def api_admin_users(inputs: Dict[str, Any]):
def api_admin_users():
if request.method == 'GET':
result = users.get_all()
return return_api([r.todict() for r in result])
elif request.method == 'POST':
inputs: Dict[str, Any] = g.inputs
users.add(inputs['username'], inputs['password'], True)
return return_api({}, code=201)
@admin_api.route('/users/<int:u_id>', UserManagementData)
@endpoint_wrapper
def api_admin_user(inputs: Dict[str, Any], u_id: int):
def api_admin_user(u_id: int):
user = users.get_one(u_id)
if request.method == 'PUT':
inputs: Dict[str, Any] = g.inputs
if inputs['new_username']:
user.update_username(inputs['new_username'])
if inputs['new_password']:
@@ -584,8 +573,7 @@ def api_admin_user(inputs: Dict[str, Any], u_id: int):
@admin_api.route('/database', DatabaseData)
@endpoint_wrapper
def api_admin_database(inputs: Dict[str, Any]):
def api_admin_database():
if request.method == "GET":
filename = create_database_copy(folder_path('db'))
@@ -602,19 +590,18 @@ def api_admin_database(inputs: Dict[str, Any]):
), 200
elif request.method == "POST":
inputs: Dict[str, Any] = g.inputs
import_db(inputs['file'], inputs['copy_hosting_settings'])
return return_api({})
@admin_api.route('/database/backups', BackupsData)
@endpoint_wrapper
def api_admin_backups(inputs: Dict[str, Any]):
def api_admin_backups():
return return_api(get_backups())
@admin_api.route('/database/backups/<int:b_idx>', BackupData)
@endpoint_wrapper
def api_admin_backup(inputs: Dict[str, Any], b_idx: int):
def api_admin_backup(b_idx: int):
if request.method == "GET":
return send_file(
get_backup(b_idx)['filepath'],
@@ -622,5 +609,6 @@ def api_admin_backup(inputs: Dict[str, Any], b_idx: int):
), 200
elif request.method == "POST":
inputs: Dict[str, Any] = g.inputs
import_db_backup(b_idx, inputs['copy_hosting_settings'])
return return_api({})

View File

@@ -22,14 +22,14 @@ from backend.base.custom_exceptions import (AccessUnauthorized,
NotificationServiceNotFound,
UsernameInvalid, UsernameTaken)
from backend.base.definitions import (Constants, DataSource, DataType,
MindException, RepeatQuantity,
SortingMethod, TimelessSortingMethod)
EndpointHandler, MindException,
RepeatQuantity, SortingMethod,
TimelessSortingMethod)
from backend.base.helpers import folder_path
from backend.internals.settings import SettingsValues
if TYPE_CHECKING:
from flask import Request
from flask.sansio.scaffold import T_route
# ===================
@@ -980,8 +980,7 @@ class APIBlueprint(Blueprint):
rule: str,
endpoint_data: Type[EndpointData],
**options: Any
) -> Callable[[T_route], T_route]:
) -> Callable[[EndpointHandler], EndpointHandler]:
if self == api:
processed_rule = rule
elif self == admin_api:
@@ -994,7 +993,7 @@ class APIBlueprint(Blueprint):
if "methods" not in options:
options["methods"] = API_DOCS[processed_rule].methods.used_methods()
return super().route(rule, **options)
return super().route(rule, **options) # type: ignore
api = APIBlueprint('api', __name__)