mirror of
https://github.com/Casvt/MIND.git
synced 2026-02-19 11:54:46 -05:00
Added 2FA in the backend
This commit is contained in:
@@ -158,6 +158,18 @@ class APIKeyExpired(LogUnauthMindException):
|
||||
}
|
||||
|
||||
|
||||
class MFACodeRequired(MindException):
|
||||
"An MFA code is sent and now expected to be supplied"
|
||||
|
||||
@property
|
||||
def api_response(self) -> ApiResponse:
|
||||
return {
|
||||
'code': 200,
|
||||
'error': self.__class__.__name__,
|
||||
'result': {}
|
||||
}
|
||||
|
||||
|
||||
# region Admin Operations
|
||||
class OperationNotAllowed(MindException):
|
||||
"What was requested to be done is not allowed"
|
||||
|
||||
@@ -9,18 +9,15 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal,
|
||||
Sequence, Tuple, TypedDict, TypeVar, Union, cast)
|
||||
from typing import (Any, Callable, Dict, List, Literal, Sequence,
|
||||
Tuple, TypedDict, TypeVar, Union, cast)
|
||||
|
||||
from flask import Response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.implementations.users import User
|
||||
|
||||
|
||||
# region Types
|
||||
T = TypeVar('T')
|
||||
U = TypeVar('U')
|
||||
MISSING = object()
|
||||
WEEKDAY_NUMBER = Literal[0, 1, 2, 3, 4, 5, 6]
|
||||
|
||||
BaseJSONSerialisable = Union[
|
||||
@@ -56,6 +53,7 @@ class Constants:
|
||||
ADMIN_PREFIX = API_PREFIX + ADMIN_API_EXTENSION
|
||||
API_KEY_LENGTH = 32 # hexadecimal characters
|
||||
API_KEY_CLEANUP_INTERVAL = 86400 # seconds
|
||||
MFA_CODE_TIMEOUT = 300 # seconds
|
||||
|
||||
DB_FOLDER = ("db",)
|
||||
DB_NAME = "MIND.db"
|
||||
@@ -238,12 +236,6 @@ class StartTypeHandler(ABC):
|
||||
|
||||
|
||||
# region Dataclasses
|
||||
@dataclass
|
||||
class ApiKeyEntry:
|
||||
exp: int
|
||||
user_data: UserData
|
||||
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class NotificationServiceData:
|
||||
id: int
|
||||
@@ -261,12 +253,13 @@ class UserData:
|
||||
admin: bool
|
||||
salt: bytes
|
||||
hash: bytes
|
||||
mfa_apprise_url: Union[str, None]
|
||||
|
||||
def todict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if k in ('id', 'username', 'admin')
|
||||
if k in ('id', 'username', 'admin', 'mfa_apprise_url')
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -262,6 +262,15 @@ def generate_api_key() -> str:
|
||||
return token_hex(Constants.API_KEY_LENGTH // 2)
|
||||
|
||||
|
||||
def generate_mfa_code() -> str:
|
||||
"""Generate a 6-digit MFA code.
|
||||
|
||||
Returns:
|
||||
str: The code.
|
||||
"""
|
||||
return str(int.from_bytes(token_bytes(3), 'big') % 1_000_000).zfill(6)
|
||||
|
||||
|
||||
# region Apprise
|
||||
def send_apprise_notification(
|
||||
urls: List[str],
|
||||
|
||||
@@ -82,10 +82,13 @@ class User:
|
||||
if self.user_db.taken(new_username):
|
||||
raise UsernameTaken(new_username)
|
||||
|
||||
user_data = self.get()
|
||||
|
||||
self.user_db.update(
|
||||
self.user_id,
|
||||
new_username,
|
||||
self.get().hash
|
||||
user_data.hash,
|
||||
user_data.mfa_apprise_url
|
||||
)
|
||||
|
||||
LOGGER.info(
|
||||
@@ -106,7 +109,8 @@ class User:
|
||||
self.user_db.update(
|
||||
self.user_id,
|
||||
user_data.username,
|
||||
hash_password
|
||||
hash_password,
|
||||
user_data.mfa_apprise_url
|
||||
)
|
||||
|
||||
LOGGER.info(
|
||||
@@ -114,6 +118,26 @@ class User:
|
||||
)
|
||||
return
|
||||
|
||||
def update_mfa_apprise_url(
|
||||
self,
|
||||
new_mfa_apprise_url: Union[str, None]
|
||||
) -> None:
|
||||
"""Change the MFA Apprise URL of the account.
|
||||
|
||||
Args:
|
||||
new_mfa_apprise_url (Union[str, None]): The new MFA Apprise URL.
|
||||
"""
|
||||
user_data = self.get()
|
||||
|
||||
self.user_db.update(
|
||||
self.user_id,
|
||||
user_data.username,
|
||||
user_data.hash,
|
||||
new_mfa_apprise_url
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete the user. The instance should not be used after calling this
|
||||
method.
|
||||
|
||||
@@ -387,7 +387,8 @@ DB_SCHEMA = """
|
||||
username VARCHAR(255) UNIQUE NOT NULL,
|
||||
salt VARCHAR(40) NOT NULL,
|
||||
hash VARCHAR(100) NOT NULL,
|
||||
admin BOOL NOT NULL DEFAULT 0
|
||||
admin BOOL NOT NULL DEFAULT 0,
|
||||
mfa_apprise_url TEXT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS notification_services(
|
||||
id INTEGER PRIMARY KEY,
|
||||
|
||||
@@ -382,3 +382,16 @@ class MigrateAddCronScheduleColumn(DBMigrator):
|
||||
COMMIT;
|
||||
PRAGMA foreign_keys = ON;
|
||||
""")
|
||||
|
||||
|
||||
class MigrateAddMFAColumn(DBMigrator):
|
||||
start_version = 13
|
||||
|
||||
def run(self) -> None:
|
||||
# V13 -> V14
|
||||
|
||||
get_db().executescript("""
|
||||
ALTER TABLE users
|
||||
ADD mfa_apprise_url TEXT;
|
||||
""")
|
||||
return
|
||||
|
||||
@@ -275,7 +275,7 @@ class UsersDB:
|
||||
|
||||
result = get_db().execute(f"""
|
||||
SELECT
|
||||
id, username, admin, salt, hash
|
||||
id, username, admin, salt, hash, mfa_apprise_url
|
||||
FROM users
|
||||
{id_filter}
|
||||
ORDER BY admin DESC, LOWER(username);
|
||||
@@ -310,16 +310,20 @@ class UsersDB:
|
||||
self,
|
||||
user_id: int,
|
||||
username: str,
|
||||
hash: bytes
|
||||
hash: bytes,
|
||||
mfa_apprise_url: Union[str, None]
|
||||
) -> None:
|
||||
get_db().execute("""
|
||||
UPDATE users
|
||||
SET username = :username, hash = :hash
|
||||
SET username = :username,
|
||||
hash = :hash,
|
||||
mfa_apprise_url = :mfa_apprise_url
|
||||
WHERE id = :user_id;
|
||||
""",
|
||||
{
|
||||
"username": username,
|
||||
"hash": hash,
|
||||
"mfa_apprise_url": mfa_apprise_url or None,
|
||||
"user_id": user_id
|
||||
}
|
||||
)
|
||||
|
||||
155
frontend/api.py
155
frontend/api.py
@@ -5,15 +5,17 @@ from io import BytesIO
|
||||
from os import remove
|
||||
from os.path import basename
|
||||
from time import time as epoch_time
|
||||
from typing import TYPE_CHECKING, Any, Dict, cast
|
||||
from typing import TYPE_CHECKING, Any, Dict, Tuple, cast
|
||||
|
||||
from flask import after_this_request, g as flask_g, request, send_file
|
||||
|
||||
from backend.base.custom_exceptions import APIKeyExpired, APIKeyInvalid
|
||||
from backend.base.definitions import (ApiKeyEntry, Constants,
|
||||
from backend.base.custom_exceptions import (AccessUnauthorized, APIKeyExpired,
|
||||
APIKeyInvalid, MFACodeRequired)
|
||||
from backend.base.definitions import (MISSING, Constants, Interval,
|
||||
SendResult, StartType, UserData)
|
||||
from backend.base.helpers import (folder_path, generate_api_key,
|
||||
hash_api_key, return_api)
|
||||
generate_mfa_code, hash_api_key, return_api,
|
||||
send_apprise_notification)
|
||||
from backend.base.logging import LOGGER, get_log_file_contents
|
||||
from backend.implementations.apprise_parser import get_apprise_services
|
||||
from backend.implementations.notification_services import NotificationServices
|
||||
@@ -46,40 +48,8 @@ from frontend.input_validation import (AboutData, AuthLoginData,
|
||||
UsersManagementData, admin_api, api,
|
||||
get_api_docs, input_validation)
|
||||
|
||||
# region Auth and input
|
||||
# region Auth Management and Input
|
||||
users = Users()
|
||||
api_key_map: Dict[str, ApiKeyEntry] = {}
|
||||
|
||||
|
||||
class ApiKeyMapping:
|
||||
_next_run: int = 0
|
||||
|
||||
@classmethod
|
||||
def cleanup(cls) -> None:
|
||||
"""Cleans up expired API keys from the mapping."""
|
||||
now = int(epoch_time())
|
||||
if now < cls._next_run:
|
||||
return
|
||||
cls._next_run = now + Constants.API_KEY_CLEANUP_INTERVAL
|
||||
|
||||
to_delete = [
|
||||
k
|
||||
for k, v in api_key_map.items()
|
||||
if v.exp + 86400 <= now
|
||||
]
|
||||
for k in to_delete:
|
||||
del api_key_map[k]
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def remove_user(user_id: int) -> None:
|
||||
for key, value in api_key_map.items():
|
||||
if value.user_data.id == user_id:
|
||||
del api_key_map[key]
|
||||
break
|
||||
return
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
class TypedAppCtxGlobals:
|
||||
@@ -93,6 +63,50 @@ else:
|
||||
g = flask_g
|
||||
|
||||
|
||||
class AuthManager:
|
||||
_next_run: int = 0
|
||||
api_key_map: Dict[str, Tuple[UserData, int]] = {}
|
||||
mfa_code_map: Dict[int, Tuple[str, int]] = {}
|
||||
|
||||
@classmethod
|
||||
def cleanup(cls) -> None:
|
||||
"""Cleans up expired API keys and MFA codes"""
|
||||
|
||||
now = int(epoch_time())
|
||||
if now < cls._next_run:
|
||||
return
|
||||
|
||||
cls._next_run = now + Constants.API_KEY_CLEANUP_INTERVAL
|
||||
|
||||
expired_api_keys = [
|
||||
k
|
||||
for k, (_, exp) in cls.api_key_map.items()
|
||||
# Allow one day to respond with expired key
|
||||
# instead of invalid key
|
||||
if exp + Interval.ONE_DAY.value <= now
|
||||
]
|
||||
for k in expired_api_keys:
|
||||
del cls.api_key_map[k]
|
||||
|
||||
expired_mfa_codes = [
|
||||
k
|
||||
for k, (_, exp) in cls.mfa_code_map.items()
|
||||
if exp <= now
|
||||
]
|
||||
for k in expired_mfa_codes:
|
||||
del cls.mfa_code_map[k]
|
||||
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def remove_user(cls, user_id: int) -> None:
|
||||
for key, value in cls.api_key_map.items():
|
||||
if value[0].id == user_id:
|
||||
del cls.api_key_map[key]
|
||||
break
|
||||
return
|
||||
|
||||
|
||||
def auth() -> None:
|
||||
"""Checks if the client is logged in.
|
||||
|
||||
@@ -103,11 +117,10 @@ def auth() -> None:
|
||||
api_key = request.values.get('api_key', '')
|
||||
hashed_api_key = hash_api_key(api_key)
|
||||
|
||||
if hashed_api_key not in api_key_map:
|
||||
if hashed_api_key not in AuthManager.api_key_map:
|
||||
raise APIKeyInvalid(api_key)
|
||||
|
||||
map_entry = api_key_map[hashed_api_key]
|
||||
user_data = map_entry.user_data
|
||||
user_data, exp = AuthManager.api_key_map[hashed_api_key]
|
||||
|
||||
if (
|
||||
user_data.admin
|
||||
@@ -124,20 +137,20 @@ def auth() -> None:
|
||||
):
|
||||
raise APIKeyInvalid(api_key)
|
||||
|
||||
if map_entry.exp <= epoch_time():
|
||||
del api_key_map[hashed_api_key]
|
||||
if exp <= epoch_time():
|
||||
del AuthManager.api_key_map[hashed_api_key]
|
||||
raise APIKeyExpired(api_key)
|
||||
|
||||
# Api key valid
|
||||
sv = Settings().get_settings()
|
||||
if sv.login_time_reset:
|
||||
map_entry.exp = (
|
||||
exp = (
|
||||
int(epoch_time()) + sv.login_time
|
||||
)
|
||||
|
||||
g.hashed_api_key = hashed_api_key
|
||||
g.user_data = user_data
|
||||
g.exp = map_entry.exp
|
||||
g.exp = exp
|
||||
|
||||
return
|
||||
|
||||
@@ -159,22 +172,50 @@ def api_auth_and_input_validation() -> None:
|
||||
def api_login():
|
||||
user_data = users.login(g.inputs['username'], g.inputs['password']).get()
|
||||
|
||||
# Credentials valid
|
||||
if user_data.mfa_apprise_url:
|
||||
if g.inputs['mfa_code'] is not None:
|
||||
# Validate code
|
||||
mfa_code, exp = AuthManager.mfa_code_map.get(
|
||||
user_data.id, ('', 0)
|
||||
)
|
||||
if not (
|
||||
g.inputs['mfa_code'] == mfa_code
|
||||
and exp > epoch_time()
|
||||
):
|
||||
raise AccessUnauthorized()
|
||||
|
||||
else:
|
||||
mfa_code = generate_mfa_code()
|
||||
|
||||
AuthManager.mfa_code_map[user_data.id] = (
|
||||
mfa_code, int(epoch_time()) + Constants.MFA_CODE_TIMEOUT
|
||||
)
|
||||
|
||||
send_apprise_notification(
|
||||
[user_data.mfa_apprise_url],
|
||||
"MIND MFA Login Code",
|
||||
f"Your login code is: {mfa_code}"
|
||||
)
|
||||
|
||||
raise MFACodeRequired()
|
||||
|
||||
# Login successful
|
||||
|
||||
StartTypeHandlers.diffuse_timer(StartType.RESTART_DB_CHANGES)
|
||||
StartTypeHandlers.diffuse_timer(StartType.RESTART_HOSTING_CHANGES)
|
||||
ApiKeyMapping.cleanup()
|
||||
AuthManager.cleanup()
|
||||
|
||||
# Generate an API key until one is generated that isn't used already
|
||||
while True:
|
||||
api_key = generate_api_key()
|
||||
hashed_api_key = hash_api_key(api_key)
|
||||
if hashed_api_key not in api_key_map:
|
||||
if hashed_api_key not in AuthManager.api_key_map:
|
||||
break
|
||||
|
||||
login_time = Settings().sv.login_time
|
||||
exp = int(epoch_time()) + login_time
|
||||
api_key_map[hashed_api_key] = ApiKeyEntry(exp, user_data)
|
||||
AuthManager.api_key_map[hashed_api_key] = (user_data, exp)
|
||||
|
||||
result = {
|
||||
'api_key': api_key,
|
||||
@@ -186,7 +227,7 @@ def api_login():
|
||||
|
||||
@api.route('/auth/logout', AuthLogoutData)
|
||||
def api_logout():
|
||||
del api_key_map[g.hashed_api_key]
|
||||
del AuthManager.api_key_map[g.hashed_api_key]
|
||||
return return_api({}, code=201)
|
||||
|
||||
|
||||
@@ -212,20 +253,27 @@ def api_add_user():
|
||||
def api_manage_user():
|
||||
user = users.get_one(g.user_data.id)
|
||||
|
||||
if request.method == 'PUT':
|
||||
if request.method == 'GET':
|
||||
result = user.get()
|
||||
return return_api(result.todict())
|
||||
|
||||
elif request.method == 'PUT':
|
||||
new_username = g.inputs['new_username']
|
||||
new_password = g.inputs['new_password']
|
||||
new_mfa_apprise_url = g.inputs['new_mfa_apprise_url']
|
||||
|
||||
if new_username:
|
||||
user.update_username(new_username)
|
||||
if new_password:
|
||||
user.update_password(new_password)
|
||||
if new_mfa_apprise_url != MISSING:
|
||||
user.update_mfa_apprise_url(new_mfa_apprise_url)
|
||||
|
||||
return return_api({})
|
||||
return return_api(user.get().todict())
|
||||
|
||||
elif request.method == 'DELETE':
|
||||
user.delete()
|
||||
del api_key_map[g.hashed_api_key]
|
||||
del AuthManager.api_key_map[g.hashed_api_key]
|
||||
return return_api({})
|
||||
|
||||
|
||||
@@ -590,17 +638,20 @@ def api_admin_user(u_id: int):
|
||||
if request.method == 'PUT':
|
||||
new_username = g.inputs['new_username']
|
||||
new_password = g.inputs['new_password']
|
||||
new_mfa_apprise_url = g.inputs['mfa_apprise_url']
|
||||
|
||||
if new_username:
|
||||
user.update_username(new_username)
|
||||
if new_password:
|
||||
user.update_password(new_password)
|
||||
if new_mfa_apprise_url != MISSING:
|
||||
user.update_mfa_apprise_url(new_mfa_apprise_url)
|
||||
|
||||
return return_api({})
|
||||
|
||||
elif request.method == 'DELETE':
|
||||
user.delete()
|
||||
ApiKeyMapping.remove_user(u_id)
|
||||
AuthManager.remove_user(u_id)
|
||||
return return_api({})
|
||||
|
||||
|
||||
|
||||
@@ -18,10 +18,11 @@ from flask import Blueprint, Request, request
|
||||
from backend.base.custom_exceptions import (AccessUnauthorized,
|
||||
InvalidDatabaseFile,
|
||||
InvalidKeyValue, InvalidTime,
|
||||
KeyNotFound, NewAccountsNotAllowed,
|
||||
KeyNotFound, MFACodeRequired,
|
||||
NewAccountsNotAllowed,
|
||||
NotificationServiceNotFound,
|
||||
UsernameInvalid, UsernameTaken)
|
||||
from backend.base.definitions import (Constants, DataSource, DataType,
|
||||
from backend.base.definitions import (MISSING, Constants, DataSource, DataType,
|
||||
EndpointHandler, MindException,
|
||||
RepeatQuantity, SortingMethod,
|
||||
TimelessSortingMethod)
|
||||
@@ -125,6 +126,19 @@ class PasswordVariable(InputVariable):
|
||||
related_exceptions = [KeyNotFound, AccessUnauthorized]
|
||||
|
||||
|
||||
class MfaCodeVariable(NonRequiredInputVariable):
|
||||
name = "mfa_code"
|
||||
description = "The MFA code sent to the user using the set Apprise URL"
|
||||
related_exceptions = [MFACodeRequired, AccessUnauthorized]
|
||||
|
||||
def validate(self) -> bool:
|
||||
return self.value is None or (
|
||||
isinstance(self.value, str)
|
||||
and len(self.value) == 6
|
||||
and self.value.isdigit()
|
||||
)
|
||||
|
||||
|
||||
class CreatePasswordVariable(PasswordVariable):
|
||||
related_exceptions = [KeyNotFound]
|
||||
|
||||
@@ -150,6 +164,18 @@ class NewPasswordVariable(NonRequiredInputVariable):
|
||||
related_exceptions = [InvalidKeyValue]
|
||||
|
||||
|
||||
class NewMfaAppriseURLVariable(NonRequiredInputVariable):
|
||||
name = "new_mfa_apprise_url"
|
||||
description = "The Apprise URL to use for sending the MFA codes"
|
||||
default = MISSING
|
||||
|
||||
def validate(self) -> bool:
|
||||
return super().validate() and (
|
||||
not isinstance(self.value, str)
|
||||
or Apprise().add(self.value)
|
||||
)
|
||||
|
||||
|
||||
class TitleVariable(InputVariable):
|
||||
name = "title"
|
||||
description = "The title of the entry"
|
||||
@@ -540,7 +566,7 @@ class AuthLoginData(EndpointData):
|
||||
description = "Login to a user account"
|
||||
requires_auth = False
|
||||
methods = Methods(
|
||||
post=("", [UsernameVariable, PasswordVariable])
|
||||
post=("", [UsernameVariable, PasswordVariable, MfaCodeVariable])
|
||||
)
|
||||
|
||||
|
||||
@@ -565,9 +591,10 @@ class UsersAddData(EndpointData):
|
||||
class UsersData(EndpointData):
|
||||
description = "Manage a user account"
|
||||
methods = Methods(
|
||||
get=("Get info of the user account", []),
|
||||
put=(
|
||||
"Change the password of the user account",
|
||||
[NewUsernameVariable, NewPasswordVariable]
|
||||
"Change the settings of the user account",
|
||||
[NewUsernameVariable, NewPasswordVariable, NewMfaAppriseURLVariable]
|
||||
),
|
||||
delete=(
|
||||
"Delete the user account",
|
||||
@@ -839,8 +866,8 @@ class UserManagementData(EndpointData):
|
||||
description = "Manage a specific user"
|
||||
methods = Methods(
|
||||
put=(
|
||||
"Change the password of the user account",
|
||||
[NewUsernameVariable, NewPasswordVariable]
|
||||
"Change the settings of the user account",
|
||||
[NewUsernameVariable, NewPasswordVariable, NewMfaAppriseURLVariable]
|
||||
),
|
||||
delete=(
|
||||
"Delete the user account",
|
||||
|
||||
Reference in New Issue
Block a user