Files
MIND/backend/base/helpers.py
2025-08-30 23:12:50 +02:00

580 lines
16 KiB
Python

# -*- coding: utf-8 -*-
"""
General "helper" functions and classes
"""
from base64 import urlsafe_b64encode
from datetime import datetime
from hashlib import pbkdf2_hmac, sha256
from logging import WARNING
from os import makedirs, scandir, symlink
from os.path import abspath, dirname, exists, isfile, join, splitext
from secrets import token_bytes, token_hex
from shutil import copy2, move
from sys import base_exec_prefix, executable, platform, version_info
from threading import current_thread
from typing import (Any, Callable, Dict, Iterable, List,
Sequence, Set, Tuple, Union, cast)
from apprise import Apprise, LogCapture
from cron_converter import Cron
from dateutil.relativedelta import relativedelta
from backend.base.definitions import (WEEKDAY_NUMBER, Constants,
GeneralReminderData, JSONSerialisable,
RepeatQuantity, SendResult, T, U)
from backend.base.logging import LOGGER
# region Python
def get_python_version() -> str:
    """Return the running interpreter's version as a dotted string.

    Every field of `sys.version_info` is included, e.g. `"3.8.10.final.0"`.

    Returns:
        str: The Python version.
    """
    fields = map(str, version_info)
    return ".".join(fields)
def check_min_python_version(
    min_major: int,
    min_minor: int,
    min_micro: int
) -> bool:
    """Check that the running Python is at least the given version.

    A critical message is logged when the running version is too old.

    ```
    # On Python3.9.1
    >>> check_min_python_version(3, 8, 2)
    True
    >>> check_min_python_version(3, 10, 0)
    False
    ```

    Args:
        min_major (int): The minimum major version.
        min_minor (int): The minimum minor version.
        min_micro (int): The minimum micro version.

    Returns:
        bool: True when the running version is equal to or higher than the
            given one, False when it is lower.
    """
    required = (min_major, min_minor, min_micro)
    # Slicing `sys.version_info` yields a plain (major, minor, micro) tuple.
    running = tuple(version_info[:3])

    if running >= required:
        return True

    required_text = ".".join(map(str, required))
    running_text = ".".join(map(str, running))
    LOGGER.critical(
        "The minimum python version required is python"
        + required_text
        + " (currently " + running_text + ")."
    )
    return False
def get_python_exe() -> Union[str, None]:
    """Get the absolute filepath to the python executable.

    On macOS, if the interpreter's framework bundle binary
    (`<base_exec_prefix>/Resources/Python.app/Contents/MacOS/Python`) exists,
    a symlink named `python` pointing to it is created in a fresh temporary
    folder, and that symlink's path is returned. In every other case
    (other platforms, or macOS without that bundle) `sys.executable` is used.

    Returns:
        Union[str, None]: The python executable path, or `None` if not found.
    """
    filepath: Union[str, None] = None

    if platform.startswith('darwin'):
        bundle_path = join(
            base_exec_prefix,
            "Resources",
            "Python.app",
            "Contents",
            "MacOS",
            "Python"
        )
        if exists(bundle_path):
            from tempfile import mkdtemp
            # NOTE(review): the temporary folder holding the symlink is never
            # removed by this function.
            filepath = join(mkdtemp(), "python")
            symlink(bundle_path, filepath)

    if filepath is None:
        # Non-darwin platforms, and macOS builds without a framework bundle
        # (e.g. non-framework interpreters), fall back to sys.executable
        # instead of reporting that no interpreter was found.
        filepath = executable or None

    if filepath and not isfile(filepath):
        filepath = None

    return filepath
def get_version_from_pyproject(filepath: str) -> str:
    """Get the application version from the `pyproject.toml` file.

    The first unindented `version` key with a quoted value is used. Both
    double and single quotes are accepted, spacing around `=` is optional,
    and anything after the closing quote (such as a comment) is ignored.

    Args:
        filepath (str): The path to the `pyproject.toml` file.

    Raises:
        RuntimeError: Version not found in file.

    Returns:
        str: The version string, prefixed with `"V"` (e.g. `"V1.2.3"`).
    """
    # TOML files are UTF-8 by specification, so don't rely on the locale.
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            key, sep, value = line.partition("=")
            # Only an unindented `version` key counts (not e.g. `version_x`).
            if not sep or key.rstrip() != "version":
                continue

            value = value.strip()
            quote = value[:1]
            if quote not in ('"', "'"):
                continue

            end = value.find(quote, 1)
            if end != -1:
                return "V" + value[1:end]

    raise RuntimeError("Version not found in pyproject.toml")
# region Generic
def first_of_subarrays(
    subarrays: Iterable[Sequence[T]]
) -> List[T]:
    """Collect the element at index 0 of every sub-array, in order.

    ```
    >>> first_of_subarrays([[1, 2], [3, 4]])
    [1, 3]
    ```

    Args:
        subarrays (Iterable[Sequence[T]]): List of sub-arrays.

    Returns:
        List[T]: List with first value of each sub-array.
    """
    heads: List[T] = []
    for sub in subarrays:
        heads.append(sub[0])
    return heads
def when_not_none(
    value: Union[T, None],
    to_run: Callable[[T], U]
) -> Union[U, None]:
    """Apply `to_run` to `value` unless `value` is `None`.

    Args:
        value (Union[T, None]): The value to check.
        to_run (Callable[[T], U]): The function to run.

    Returns:
        Union[U, None]: `None` when `value is None`, otherwise the return
            value of `to_run(value)`.
    """
    return None if value is None else to_run(value)
def search_filter(query: str, result: GeneralReminderData) -> bool:
    """Check whether a reminder matches a search query.

    Matching is a case-insensitive substring test that ignores spaces. The
    result's title and text are joined before matching, so a query can span
    the end of the title and the start of the text.

    Args:
        query (str): The query to filter with.
        result (GeneralReminderData): The reminder to check. Its `title` and
            optional `text` attributes are searched.

    Returns:
        bool: Whether the result passes the filter.
    """
    def _normalise(value: str) -> str:
        # Lowercase and drop spaces so that e.g. "Buy Milk" matches "buymilk".
        return value.lower().replace(' ', '')

    haystack = result.title + (result.text or '')
    return _normalise(query) in _normalise(haystack)
def current_thread_id() -> int:
    """Return the native (OS-assigned) ID of the calling thread.

    Returns:
        int: The thread's `native_id`, or -1 when that is not available.
    """
    native = current_thread().native_id
    if native:
        return native
    return -1
def return_api(
    result: JSONSerialisable,
    error: Union[str, None] = None,
    code: int = 200
) -> Tuple[Dict[str, Any], int]:
    """Build the (body, status code) pair returned by an API endpoint.

    Args:
        result (JSONSerialisable): The payload to put under the `result` key.
        error (Union[str, None], optional): The error to put under the
            `error` key.
            Defaults to None.
        code (int, optional): The status code to return alongside the body.
            Defaults to 200.

    Returns:
        Tuple[Dict[str, Any], int]: The response body and the status code.
    """
    body: Dict[str, Any] = {'error': error, 'result': result}
    return body, code
# region Security
def get_hash(salt: bytes, data: str) -> bytes:
    """Derive a salted hash of a string.

    Uses PBKDF2-HMAC-SHA256 with 100,000 iterations on the default
    (UTF-8) encoding of `data`.

    Args:
        salt (bytes): The salt to use when hashing.
        data (str): The data to hash.

    Returns:
        bytes: The URL-safe base64 encoding of the derived key.
    """
    derived_key = pbkdf2_hmac('sha256', data.encode(), salt, 100_000)
    return urlsafe_b64encode(derived_key)
def generate_salt_hash(password: str) -> Tuple[bytes, bytes]:
    """Create a fresh random salt and hash the password with it.

    Args:
        password (str): The password to generate for.

    Returns:
        Tuple[bytes, bytes]: The salt (1) and the hashed password (2), as
            produced by `get_hash`.
    """
    new_salt = token_bytes()
    return new_salt, get_hash(new_salt, password)
def hash_api_key(api_key: str) -> str:
    """Hash an API key with SHA-256.

    Args:
        api_key (str): The API key to hash.

    Returns:
        str: The hex digest of the UTF-8 encoded key.
    """
    digest = sha256()
    digest.update(api_key.encode('utf-8'))
    return digest.hexdigest()
def generate_api_key() -> str:
    """Generate a random hexadecimal API key.

    `token_hex(n)` returns 2 * n hex characters, so the key is
    `Constants.API_KEY_LENGTH` characters long when that value is even
    (one character shorter when it is odd).

    Returns:
        str: The API key.
    """
    byte_count = Constants.API_KEY_LENGTH // 2
    return token_hex(byte_count)
def generate_mfa_code() -> str:
    """Generate a random 6-digit MFA code.

    The code is drawn uniformly from 000000-999999 using the `secrets`
    CSPRNG and zero-padded to six digits.

    Returns:
        str: The code.
    """
    from secrets import randbelow

    # randbelow() is uniform. Reducing random bytes modulo 10**6 is not:
    # 2**24 is not a multiple of 10**6, so low codes would be more likely.
    return str(randbelow(1_000_000)).zfill(6)
# region Apprise
def send_apprise_notification(
    urls: List[str],
    title: str,
    text: Union[str, None] = None
) -> SendResult:
    """Send a notification to every given Apprise URL.

    All URLs are registered first; the first one Apprise refuses to add
    aborts the send. When `text` is empty or None, a zero-width space is
    sent as the body. Warnings logged by Apprise during sending are captured
    to tell connection failures apart from other rejections.

    Args:
        urls (List[str]): The Apprise URLs to send the notification to.
        title (str): The title of the notification.
        text (Union[str, None], optional): The optional body of the
            notification.
            Defaults to None.

    Returns:
        SendResult: Whether Apprise was successful.
    """
    notifier = Apprise()
    if not all(notifier.add(url) for url in urls):
        return SendResult.SYNTAX_INVALID_URL

    with LogCapture(level=WARNING) as captured:
        delivered = notifier.notify(
            title=title,
            body=text or '\u200B'
        )

    if delivered:
        return SendResult.SUCCESS

    if "socket exception" in captured.getvalue(): # type: ignore
        return SendResult.CONNECTION_ERROR
    return SendResult.REJECTED_URL
# region Time
def next_selected_day(
    allowed_weekdays: List[WEEKDAY_NUMBER],
    current_weekday: WEEKDAY_NUMBER
) -> WEEKDAY_NUMBER:
    """Return the first allowed weekday strictly after the current one.

    If no allowed day comes later in the week, wrap around to the first
    entry of `allowed_weekdays`. The list is expected to be sorted
    ascending.

    ```
    >>> next_selected_day([0, 4, 6], 4)
    6
    >>> next_selected_day([0, 4, 6], 6)
    0
    ```

    Args:
        allowed_weekdays (List[WEEKDAY_NUMBER]): The days of the week that are
            allowed. Monday is 0, Sunday is 6.
        current_weekday (WEEKDAY_NUMBER): The current weekday.

    Returns:
        WEEKDAY_NUMBER: The next allowed weekday.
    """
    later_days = (
        day
        for day in allowed_weekdays
        if day > current_weekday
    )
    # The fallback is evaluated eagerly, so an empty list raises IndexError.
    return next(later_days, allowed_weekdays[0])
def find_next_time(
    original_time: int,
    repeat_quantity: Union[RepeatQuantity, None],
    repeat_interval: Union[int, None],
    weekdays: Union[List[WEEKDAY_NUMBER], None],
    cron_schedule: Union[str, None]
) -> int:
    """Calculate the next timestamp based on original time and repeat/interval
    values.

    The first schedule type that is set is used, in this order: cron
    schedule, repeat interval (quantity + interval), weekdays. If none is
    set, the original time is returned. The `weekdays` list passed in is not
    modified.

    Args:
        original_time (int): The original time of the repeating timestamp.
        repeat_quantity (Union[RepeatQuantity, None]): If set, what the quantity
            is of the time interval.
        repeat_interval (Union[int, None]): If set, the value of the time
            interval.
        weekdays (Union[List[WEEKDAY_NUMBER], None]): If set, on which days the
            timestamp can be. Monday is 0, Sunday is 6.
        cron_schedule (Union[str, None]): If set, the cron schedule to follow.

    Raises:
        ValueError: The repeat interval is used but is zero or negative, which
            would never move the time into the future.

    Returns:
        int: The next timestamp in the future.
    """
    if weekdays is not None:
        # next_selected_day() needs ascending order. Sort a copy so that the
        # caller's list is not reordered as a side effect.
        weekdays = sorted(weekdays)

    # NOTE(review): `datetime.utcnow()` is naive UTC wall-clock time, and the
    # timestamp()/fromtimestamp() round trip (both local-time based) keeps it
    # as UTC wall-clock time. `original_datetime` below is local wall-clock
    # time. The two frames only agree when the host timezone is UTC; confirm
    # this is intended.
    current_time = datetime.fromtimestamp(datetime.utcnow().timestamp())
    original_datetime = datetime.fromtimestamp(original_time)
    new_time = datetime.fromtimestamp(original_time)

    if cron_schedule is not None:
        cron_instance = Cron(cron_schedule)
        schedule = cron_instance.schedule(current_time)
        new_time = schedule.next()
        while new_time <= current_time:
            new_time = schedule.next()

    elif (
        repeat_quantity is not None
        and repeat_interval is not None
    ):
        if repeat_interval <= 0:
            # A non-positive interval never reaches the future, so the loop
            # below would spin forever.
            raise ValueError(
                f"Repeat interval must be positive, got {repeat_interval}"
            )

        # Add the interval to the original time until we are in the future.
        # We need to multiply the interval and add it to the original time
        # instead of just adding the interval once each time to the original
        # time, because otherwise date jumping could happen. Say original time
        # is a leap day with an interval of 1 year. Then next date would be the
        # day before leap day, as leap day doesn't exist in the next year. But
        # if we then keep adding 1 year to this time, we would keep getting the
        # day before leap day, a year later. So we need to multiply the interval
        # and add the whole interval to the original time in one go. This way
        # after four years we will get the leap day again.
        interval = relativedelta(
            **{repeat_quantity.value: repeat_interval} # type: ignore
        )
        multiplier = 1
        while new_time <= current_time:
            new_time = original_datetime + (interval * multiplier)
            multiplier += 1

    elif weekdays is not None:
        if (
            current_time.weekday() in weekdays
            and current_time.time() < original_datetime.time()
        ):
            # Next reminder is later today, so target weekday is current weekday
            weekday = current_time.weekday()
            search_start = current_time
        else:
            # Next reminder is not today or earlier today, so target weekday
            # is next selected one
            weekday = next_selected_day(
                weekdays,
                cast(WEEKDAY_NUMBER, current_time.weekday())
            )
            # relativedelta(weekday=...) keeps the date when it already falls
            # on that weekday. Start from tomorrow so that a single allowed
            # weekday equal to today (whose time has passed) rolls over to
            # next week instead of yielding a time in the past.
            search_start = current_time + relativedelta(days=1)

        new_time = search_start + relativedelta(
            # Move to upcoming weekday (possibly the start day itself)
            weekday=weekday,
            # Also move current time to set time
            hour=original_datetime.hour,
            minute=original_datetime.minute,
            second=original_datetime.second
        )

    result = int(new_time.timestamp())
    LOGGER.debug(
        f'{original_datetime=}, {current_time=} ' +
        f'and interval of {repeat_interval} {repeat_quantity} ' +
        f'and weekdays {weekdays} ' +
        f'leads to {result}'
    )
    return result
# region Files
def folder_path(*folders: str) -> str:
    """Resolve a path relative to the project folder into an absolute path.

    The project folder is taken to be three directory levels above this
    source file.

    Args:
        *folders (str): Path components relative to the project folder.

    Returns:
        str: The absolute filepath.
    """
    project_root = abspath(__file__)
    for _ in range(3):
        project_root = dirname(project_root)
    return join(project_root, *folders)
def list_files(folder: str, ext: Iterable[str] = ()) -> List[str]:
    """List all files in a folder recursively with absolute paths.

    Hidden files (names starting with `.`) are skipped. Hidden folders are
    still descended into, and symlinks to folders are followed
    (`DirEntry.is_dir()` follows symlinks by default).

    Args:
        folder (str): The base folder to search through.
        ext (Iterable[str], optional): File extensions to only include.
            Dot-prefix not necessary and matching is case-insensitive. A
            single string is treated as one extension. Leave empty to allow
            all extensions.
            Defaults to ().

    Returns:
        List[str]: The paths of the files in the folder.
    """
    if isinstance(ext, str):
        # Iterating a bare string would yield single characters as extensions.
        ext = (ext,)

    # Normalise to lowercase, dot-prefixed extensions for the suffix check.
    allowed = {'.' + e.lower().lstrip('.') for e in ext}
    files: List[str] = []

    def _list_files(current: str) -> None:
        """Append the matching files under `current` to `files`."""
        # Use the iterator as a context manager so the directory handle is
        # closed even if an error interrupts the scan.
        with scandir(current) as entries:
            for entry in entries:
                if entry.is_dir():
                    _list_files(entry.path)
                elif (
                    entry.is_file()
                    and not entry.name.startswith('.')
                    and (
                        not allowed
                        or splitext(entry.name)[1].lower() in allowed
                    )
                ):
                    files.append(entry.path)

    _list_files(folder)
    return files
def create_folder(folder: str) -> None:
    """Ensure a folder exists, creating it and any missing parents.

    Nothing happens if the folder already exists.

    Args:
        folder (str): The path to the folder to create.
    """
    makedirs(name=folder, exist_ok=True)
def copy(
    src: str,
    dst: str,
    *,
    follow_symlinks: bool = True
) -> str:
    """Copy a file, including its metadata, using `shutil.copy2`.

    If `dst` is a directory, the file is copied into it under its own name.
    Two metadata failures that network file systems such as NFS can raise
    after the data has been copied are ignored: EPERM when changing
    permissions, and errno 524 when setting extended attributes.

    Args:
        src (str): The source file.
        dst (str): The destination file or directory of the copy.
        follow_symlinks (bool, optional): Whether to follow symlinks.
            Defaults to True.

    Returns:
        str: The path of the created file. This is the same value
            `shutil.copy2` returns, also when one of the ignored errors
            occurred.
    """
    from errno import EPERM
    from os.path import basename, isdir

    # Resolve the destination the same way copy2 does, so the fallback
    # returns below name the created file rather than its parent directory.
    target = join(dst, basename(src)) if isdir(dst) else dst

    try:
        return copy2(src, dst, follow_symlinks=follow_symlinks)

    except PermissionError as pe:
        if pe.errno == EPERM:
            # NFS file system doesn't allow/support chmod.
            # This is done after the file is already copied. So just accept that
            # it isn't possible to change the permissions. Continue like normal.
            # NOTE(review): an EPERM raised while copying the data itself would
            # also be treated as success here; confirm that can't happen.
            return target
        raise

    except OSError as oe:
        # 524 is the Linux-internal ENOTSUPP, which the errno module lacks.
        if oe.errno == 524:
            # NFS file system doesn't allow/support setting extended attributes.
            # This is done after the file is already copied. So just accept that
            # it isn't possible to set them. Continue like normal.
            return target
        raise
def rename_file(
    before: str,
    after: str
) -> None:
    """Move a file to a new path, creating the destination folder first.

    The move uses `shutil.move` with this module's `copy` as the copy
    function. `shutil.move` only falls back to copying when a plain rename
    isn't possible; `copy` then tolerates metadata errors on file systems
    such as NFS.

    Args:
        before (str): The current filepath of the file.
        after (str): The new desired filepath of the file.
    """
    target_folder = dirname(after)
    # For a bare filename, dirname() is '' and makedirs('') raises
    # FileNotFoundError, so only create a folder when one is named.
    if target_folder:
        create_folder(target_folder)

    move(before, after, copy_function=copy)
    return
# region Classes
class Singleton(type):
    """
    Metaclass that makes every instantiation of a class return one shared
    instance. The first call constructs the instance and every later call
    returns it, ignoring the new arguments. The registry lives in this
    process's memory, so the instance is shared between threads but not with
    spawned subprocesses.

    NOTE(review): creation is not guarded by a lock, so two threads that
    instantiate the same class for the first time at the same moment could
    both construct it. Confirm singletons are created before worker threads
    start.
    """
    # Registry of created instances, keyed by "<module>.<class name>". It is
    # defined on the metaclass, so all classes using it share this one dict.
    _instances: Dict[str, Any] = {}

    def __call__(cls, *args, **kwargs):
        key = '.'.join((cls.__module__, cls.__name__))
        registry = cls._instances
        if key not in registry:
            registry[key] = super().__call__(*args, **kwargs)
        return registry[key]