# -*- coding: utf-8 -*-

"""
General "helper" functions and classes
"""

from base64 import urlsafe_b64encode
from datetime import datetime, timezone
from hashlib import pbkdf2_hmac, sha256
from logging import WARNING
from os import makedirs, scandir, symlink
from os.path import abspath, dirname, exists, isfile, join, splitext
from secrets import randbelow, token_bytes, token_hex
from shutil import copy2, move
from sys import base_exec_prefix, executable, platform, version_info
from threading import current_thread
from typing import (Any, Callable, Dict, Iterable, List,
                    Sequence, Set, Tuple, Union, cast)

from apprise import Apprise, LogCapture
from cron_converter import Cron
from dateutil.relativedelta import relativedelta

from backend.base.definitions import (WEEKDAY_NUMBER, Constants,
                                      GeneralReminderData, JSONSerialisable,
                                      RepeatQuantity, SendResult, T, U)
from backend.base.logging import LOGGER


# region Python
def get_python_version() -> str:
    """Get the Python version as a string. E.g. `"3.8.10.final.0"`.

    Returns:
        str: The Python version.
    """
    return ".".join(
        str(i) for i in list(version_info)
    )


def check_min_python_version(
    min_major: int,
    min_minor: int,
    min_micro: int
) -> bool:
    """Check whether the version of Python that is used is equal or higher
    than the version given. Will log a critical error if not.

    ```
    # On Python3.9.1
    >>> check_min_python_version(3, 8, 2)
    True
    >>> check_min_python_version(3, 10, 0)
    False
    ```

    Args:
        min_major (int): The minimum major version.
        min_minor (int): The minimum minor version.
        min_micro (int): The minimum micro version.

    Returns:
        bool: Whether it's equal or higher than the version given or below it.
    """
    min_version = (
        min_major,
        min_minor,
        min_micro
    )
    current_version = (
        version_info.major,
        version_info.minor,
        version_info.micro
    )

    # Tuples compare lexicographically, i.e. major first, then minor, etc.
    if current_version < min_version:
        LOGGER.critical(
            "The minimum python version required is python"
            + ".".join(map(str, min_version))
            + " (currently " + ".".join(map(str, current_version)) + ")."
        )
        return False

    return True


def get_python_exe() -> Union[str, None]:
    """Get the absolute filepath to the python executable.

    On macOS, if the interpreter bundle inside `base_exec_prefix` exists, a
    symlink to it named `python` is created in a fresh temporary folder and
    that symlink path is returned.

    Returns:
        Union[str, None]: The python executable path, or `None` if not found.
    """
    if platform.startswith('darwin'):
        filepath = None
        bundle_path = join(
            base_exec_prefix,
            "Resources",
            "Python.app",
            "Contents",
            "MacOS",
            "Python"
        )
        if exists(bundle_path):
            from tempfile import mkdtemp
            filepath = join(mkdtemp(), "python")
            symlink(bundle_path, filepath)

    else:
        filepath = executable or None

    if filepath and not isfile(filepath):
        filepath = None

    return filepath


def get_version_from_pyproject(filepath: str) -> str:
    """Get the application version from the `pyproject.toml` file.

    The first line starting with `version = ` is used, and the text between
    its first pair of double quotes is returned, prefixed with `"V"`.

    Args:
        filepath (str): The path to the `pyproject.toml` file.

    Raises:
        RuntimeError: Version not found in file.

    Returns:
        str: The version string.
    """
    # TOML files are UTF-8 by specification; don't depend on the locale.
    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("version = "):
                return "V" + line.split('"')[1]
        else:
            raise RuntimeError("Version not found in pyproject.toml")


# region Generic
def first_of_subarrays(
    subarrays: Iterable[Sequence[T]]
) -> List[T]:
    """Get the first element of each sub-array.

    ```
    >>> first_of_subarrays([[1, 2], [3, 4]])
    [1, 3]
    ```

    Args:
        subarrays (Iterable[Sequence[T]]): List of sub-arrays.

    Returns:
        List[T]: List with first value of each sub-array.
    """
    return [e[0] for e in subarrays]


def when_not_none(
    value: Union[T, None],
    to_run: Callable[[T], U]
) -> Union[U, None]:
    """Run `to_run` with argument `value` iff `value is not None`. Else return
    `None`.

    Args:
        value (Union[T, None]): The value to check.
        to_run (Callable[[T], U]): The function to run.

    Returns:
        Union[U, None]: Either the return value of `to_run`, or `None`.
    """
    if value is None:
        return None
    else:
        return to_run(value)


def search_filter(query: str, result: GeneralReminderData) -> bool:
    """Filter library results based on a query.

    The match is case-insensitive and ignores spaces. The query is searched
    for in the title and text of the result combined.

    Args:
        query (str): The query to filter with.
        result (GeneralReminderData): The library result to check.

    Returns:
        bool: Whether the result passes the filter.
    """
    def _normalise(value: str) -> str:
        """Lowercase a string and remove all spaces from it."""
        return value.lower().replace(' ', '')

    return (
        _normalise(query)
        in _normalise(result.title + (result.text or ''))
    )


def current_thread_id() -> int:
    """Get the ID of the current thread.

    Returns:
        int: The ID, or -1 if the native ID is not available.
    """
    return current_thread().native_id or -1


def return_api(
    result: JSONSerialisable,
    error: Union[str, None] = None,
    code: int = 200
) -> Tuple[Dict[str, Any], int]:
    """Build an API response body and status code.

    Args:
        result (JSONSerialisable): The value to put under the `'result'` key.
        error (Union[str, None], optional): The value to put under the
            `'error'` key.
            Defaults to None.
        code (int, optional): The status code to return alongside the body.
            Defaults to 200.

    Returns:
        Tuple[Dict[str, Any], int]: The response body and the status code.
    """
    return {'error': error, 'result': result}, code


# region Security
def get_hash(salt: bytes, data: str) -> bytes:
    """Hash a string using the supplied salt.

    Args:
        salt (bytes): The salt to use when hashing.
        data (str): The data to hash.

    Returns:
        bytes: The b64 encoded hash of the supplied string.
    """
    return urlsafe_b64encode(
        pbkdf2_hmac('sha256', data.encode(), salt, 100_000)
    )


def generate_salt_hash(password: str) -> Tuple[bytes, bytes]:
    """Generate a salt and get the hash of the password.

    Args:
        password (str): The password to generate for.

    Returns:
        Tuple[bytes, bytes]: The salt (1) and hashed_password (2).
    """
    salt = token_bytes()
    hashed_password = get_hash(salt, password)
    return salt, hashed_password


def hash_api_key(api_key: str) -> str:
    """Hashes an API key using SHA-256.

    Args:
        api_key (str): The API key to hash.

    Returns:
        str: The hashed API key as a hexadecimal string.
    """
    return sha256(api_key.encode('utf-8')).hexdigest()


def generate_api_key() -> str:
    """Generate an API key.

    Returns:
        str: The API key.
    """
    # Each byte is represented by two hexadecimal characters, so halve
    # the desired amount of bytes.
    return token_hex(Constants.API_KEY_LENGTH // 2)


def generate_mfa_code() -> str:
    """Generate a 6-digit MFA code.

    Returns:
        str: The code, zero-padded to 6 digits.
    """
    # `randbelow` draws uniformly from [0, 1_000_000). Taking random bytes
    # modulo 1_000_000 would favour the lower codes, because 2**24 is not a
    # multiple of 1_000_000.
    return str(randbelow(1_000_000)).zfill(6)


# region Apprise
def send_apprise_notification(
    urls: List[str],
    title: str,
    text: Union[str, None] = None
) -> SendResult:
    """Send a notification to all Apprise URLs given.

    Args:
        urls (List[str]): The Apprise URLs to send the notification to.
        title (str): The title of the notification.
        text (Union[str, None], optional): The optional body of the
            notification.
            Defaults to None.

    Returns:
        SendResult: Whether Apprise was successful.
    """
    a = Apprise()

    for url in urls:
        if not a.add(url):
            return SendResult.SYNTAX_INVALID_URL

    with LogCapture(level=WARNING) as log:
        # A zero-width space is sent when there is no body text.
        result = a.notify(
            title=title,
            body=text or '\u200B'
        )
        if not result:
            if "socket exception" in log.getvalue():  # type: ignore
                return SendResult.CONNECTION_ERROR
            else:
                return SendResult.REJECTED_URL

    return SendResult.SUCCESS


# region Time
def next_selected_day(
    allowed_weekdays: List[WEEKDAY_NUMBER],
    current_weekday: WEEKDAY_NUMBER
) -> WEEKDAY_NUMBER:
    """Find the next allowed day in the week.

    ```
    >>> next_selected_day([0, 4, 6], 4)
    6
    >>> next_selected_day([0, 4, 6], 6)
    0
    ```

    Args:
        allowed_weekdays (List[WEEKDAY_NUMBER]): The days of the week that are
            allowed, sorted in ascending order. Monday is 0, Sunday is 6.
        current_weekday (WEEKDAY_NUMBER): The current weekday.

    Returns:
        WEEKDAY_NUMBER: The next allowed weekday. This wraps around to the
            first allowed weekday when none comes after `current_weekday`.
    """
    for d in allowed_weekdays:
        if current_weekday < d:
            return d

    return allowed_weekdays[0]


def find_next_time(
    original_time: int,
    repeat_quantity: Union[RepeatQuantity, None],
    repeat_interval: Union[int, None],
    weekdays: Union[List[WEEKDAY_NUMBER], None],
    cron_schedule: Union[str, None]
) -> int:
    """Calculate the next timestamp based on original time and repeat/interval
    values.

    Args:
        original_time (int): The original time of the repeating timestamp.

        repeat_quantity (Union[RepeatQuantity, None]): If set, what the
            quantity is of the time interval.

        repeat_interval (Union[int, None]): If set, the value of the time
            interval.

        weekdays (Union[List[WEEKDAY_NUMBER], None]): If set, on which days the
            timestamp can be. Monday is 0, Sunday is 6. The list is not
            modified.

        cron_schedule (Union[str, None]): If set, the cron schedule to follow.

    Returns:
        int: The next timestamp in the future.
    """
    if weekdays is not None:
        # Sort a copy so that the caller's list isn't mutated.
        weekdays = sorted(weekdays)

    # NOTE(review): this is the current UTC wall-clock time as a naive
    # datetime, while `original_datetime` below is in local wall-clock time.
    # The two only line up when the process runs in UTC. Confirm this is
    # intended. `datetime.now(timezone.utc)` replaces the deprecated
    # `datetime.utcnow()` with identical semantics.
    current_time = datetime.fromtimestamp(
        datetime.now(timezone.utc).replace(tzinfo=None).timestamp()
    )
    original_datetime = datetime.fromtimestamp(original_time)
    new_time = datetime.fromtimestamp(original_time)

    if cron_schedule is not None:
        cron_instance = Cron(cron_schedule)
        schedule = cron_instance.schedule(current_time)
        new_time = schedule.next()
        while new_time <= current_time:
            new_time = schedule.next()

    elif (
        repeat_quantity is not None
        and repeat_interval is not None
    ):
        # Add the interval to the original time until we are in the future.
        # We need to multiply the interval and add it to the original time
        # instead of just adding the interval once each time to the original
        # time, because otherwise date jumping could happen. Say original time
        # is a leap day with an interval of 1 year. Then next date would be the
        # day before leap day, as leap day doesn't exist in the next year. But
        # if we then keep adding 1 year to this time, we would keep getting the
        # day before leap day, a year later. So we need to multiply the interval
        # and add the whole interval to the original time in one go. This way
        # after four years we will get the leap day again.
        interval = relativedelta(
            **{repeat_quantity.value: repeat_interval}  # type: ignore
        )
        multiplier = 1
        while new_time <= current_time:
            new_time = original_datetime + (interval * multiplier)
            multiplier += 1

    elif weekdays is not None:
        if (
            current_time.weekday() in weekdays
            and current_time.time() < original_datetime.time()
        ):
            # Next reminder is later today, so target weekday is current
            # weekday.
            weekday = current_time.weekday()
            skip_days = 0

        else:
            # Next reminder is not today, or today's time already passed. So
            # target the next selected weekday, searching from tomorrow on.
            # Starting from tomorrow matters when the only selected weekday is
            # today. The target weekday then equals the current one and must
            # resolve to next week, not to the already-passed time today.
            weekday = next_selected_day(
                weekdays,
                cast(WEEKDAY_NUMBER, current_time.weekday())
            )
            skip_days = 1

        new_time = current_time + relativedelta(
            # relativedelta adds `days` before resolving `weekday`. So this
            # moves to the first matching weekday on or after
            # today + skip_days.
            days=skip_days,
            weekday=weekday,
            # Also move current time to set time
            hour=original_datetime.hour,
            minute=original_datetime.minute,
            second=original_datetime.second
        )

    result = int(new_time.timestamp())
    LOGGER.debug(
        f'{original_datetime=}, {current_time=} ' +
        f'and interval of {repeat_interval} {repeat_quantity} ' +
        f'and weekdays {weekdays} ' +
        f'leads to {result}'
    )
    return result


# region Files
def folder_path(*folders: str) -> str:
    """Turn filepaths relative to the project folder into absolute paths.

    The project folder is taken to be three directory levels above this
    file.

    Returns:
        str: The absolute filepath.
    """
    return join(
        dirname(dirname(dirname(abspath(__file__)))),
        *folders
    )


def list_files(folder: str, ext: Iterable[str] = ()) -> List[str]:
    """List all files in a folder recursively with absolute paths. Hidden files
    (files starting with `.`) are ignored.

    Args:
        folder (str): The base folder to search through.

        ext (Iterable[str], optional): File extensions to only include.
            Dot-prefix not necessary. Matching is case-insensitive.
            Leave empty to allow all extensions.
            Defaults to ().

    Returns:
        List[str]: The paths of the files in the folder.
    """
    files: List[str] = []

    def _list_files(folder: str, ext: Set[str]) -> None:
        """Internal function to add all files in a folder to the files list.

        Args:
            folder (str): The base folder to search through.

            ext (Set[str]): A set of lowercase, dot-prefixed, extensions to
                filter for or empty for no filter.
        """
        # Use the iterator as a context manager so that its directory handle
        # is closed deterministically instead of whenever it's collected.
        with scandir(folder) as entries:
            for f in entries:
                if f.is_dir():
                    _list_files(f.path, ext)

                elif (
                    f.is_file()
                    and not f.name.startswith('.')
                    and (
                        not ext
                        or (splitext(f.name)[1].lower() in ext)
                    )
                ):
                    files.append(f.path)

    # Normalise to lowercase, dot-prefixed extensions.
    normalised_ext = {'.' + e.lower().lstrip('.') for e in ext}
    _list_files(folder, normalised_ext)
    return files


def create_folder(folder: str) -> None:
    """Create a folder. Also creates any parent folders if they don't exist
    already. Allows folder to already exist.

    Args:
        folder (str): The path to the folder to create.
    """
    makedirs(folder, exist_ok=True)
    return


def copy(
    src: str,
    dst: str,
    *,
    follow_symlinks: bool = True
) -> str:
    """Copy a file or folder.

    Args:
        src (str): The source file or folder.
        dst (str): The destination of the copy.
        follow_symlinks (bool, optional): Whether to follow symlinks.
            Defaults to True.

    Returns:
        str: The destination.
    """
    try:
        return copy2(src, dst, follow_symlinks=follow_symlinks)

    except PermissionError as pe:
        if pe.errno == 1:
            # NFS file system doesn't allow/support chmod.
            # This is done after the file is already copied. So just accept that
            # it isn't possible to change the permissions. Continue like normal.
            return dst

        raise

    except OSError as oe:
        if oe.errno == 524:
            # NFS file system doesn't allow/support setting extended attributes.
            # This is done after the file is already copied. So just accept that
            # it isn't possible to set them. Continue like normal.
            return dst

        raise


def rename_file(
    before: str,
    after: str
) -> None:
    """Rename a file, taking care of new folder locations and the possible
    complications with files on OS'es.

    Args:
        before (str): The current filepath of the file.
        after (str): The new desired filepath of the file.
    """
    # `after` may be a bare filename (no folder component). In that case
    # there's no folder to create, and `makedirs('')` would raise.
    target_folder = dirname(after)
    if target_folder:
        create_folder(target_folder)

    # Use the NFS-tolerant `copy` for cross-filesystem moves.
    move(before, after, copy_function=copy)
    return


# region Classes
class Singleton(type):
    """
    Make each initialisation of a class return the same instance by setting
    this as the metaclass. Works across threads, but not spawned subprocesses.
    """

    # Maps "module.ClassName" to the single instance of that class. It lives
    # on the metaclass, so it's shared by all classes using this metaclass.
    _instances: Dict[str, Any] = {}

    def __call__(cls, *args, **kwargs):
        c_term = cls.__module__ + '.' + cls.__name__
        if c_term not in cls._instances:
            cls._instances[c_term] = super().__call__(*args, **kwargs)

        return cls._instances[c_term]