diff --git a/backend/internals/db.py b/backend/internals/db.py index 4aee2aa..626ab1f 100644 --- a/backend/internals/db.py +++ b/backend/internals/db.py @@ -120,6 +120,18 @@ class DBConnectionManager(type): return cls.instances[thread_id] + @classmethod + def close_connection_of_thread(cls) -> None: + """Close the DB connection of the current thread""" + thread_id = current_thread_id() + if ( + thread_id in cls.instances + and not cls.instances[thread_id].closed + ): + cls.instances[thread_id].close() + del cls.instances[thread_id] + return + class DBConnection(Connection, metaclass=DBConnectionManager): default_file = '' diff --git a/backend/internals/server.py b/backend/internals/server.py index 9742c6e..28de3d7 100644 --- a/backend/internals/server.py +++ b/backend/internals/server.py @@ -7,7 +7,7 @@ Setting up, running and shutting down the API and web-ui from __future__ import annotations from os import urandom -from threading import Timer, current_thread +from threading import Timer from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Union from flask import Flask, request @@ -27,44 +27,21 @@ if TYPE_CHECKING: class ThreadedTaskDispatcher(TTD): - def handler_thread(self, thread_no: int) -> None: - # Most of this method's content is copied straight from waitress - # except for the the part marked. The thread is considered to be - # stopped when it's removed from self.threads, so we need to close - # the database connection before it. - while True: - with self.lock: - while not self.queue and self.stop_count == 0: - # Mark ourselves as idle before waiting to be - # woken up, then we will once again be active - self.active_count -= 1 - self.queue_cv.wait() - self.active_count += 1 + def __init__(self) -> None: + super().__init__() - if self.stop_count > 0: - self.active_count -= 1 - self.stop_count -= 1 - - # ================= - # Kapowarr part - thread_id = current_thread().native_id or -1 - if ( - thread_id in DBConnectionManager.instances - and not DBConnectionManager.instances[thread_id].closed - ): - DBConnectionManager.instances[thread_id].close() - # ================= - - self.threads.discard(thread_no) - self.thread_exit_cv.notify() - break - - task = self.queue.popleft() - try: - task.service() - except BaseException: - self.logger.exception("Exception when servicing %r", task) + # The DB connection should be closed when the thread is ending, but + # right before it actually has. Waitress will consider a thread closed + # once it's not in the self.threads set anymore, regardless of whether + # the thread has actually ended/joined, so anything we do after that + # could be cut short by the main thread ending. So we need to close + # the DB connection before the thread is discarded from the set. + class TDDSet(set): + def discard(self, element: Any) -> None: + DBConnectionManager.close_connection_of_thread() + return super().discard(element) + self.threads = TDDSet() return def shutdown(self, cancel_pending: bool = True, timeout: int = 5) -> bool: @@ -314,14 +291,7 @@ class Server(metaclass=Singleton): with self.app.app_context(): target(*args, **kwargs) - thread_id = current_thread().native_id or -1 - if ( - thread_id in DBConnectionManager.instances - and - not DBConnectionManager.instances[thread_id].closed - ): - DBConnectionManager.instances[thread_id].close() - + DBConnectionManager.close_connection_of_thread() return t = Timer(