diff --git a/backend/db.py b/backend/db.py index fed13b5..2c380d3 100644 --- a/backend/db.py +++ b/backend/db.py @@ -1,8 +1,8 @@ #-*- coding: utf-8 -*- from datetime import datetime -from sqlite3 import Connection, Row -from threading import current_thread +from sqlite3 import Connection, ProgrammingError, Row +from threading import current_thread, main_thread from time import time from typing import Union @@ -14,7 +14,8 @@ class Singleton(type): _instances = {} def __call__(cls, *args, **kwargs): i = f'{cls}{current_thread()}' - if i not in cls._instances: + if (i not in cls._instances + or cls._instances[i].closed): cls._instances[i] = super(Singleton, cls).__call__(*args, **kwargs) return cls._instances[i] @@ -25,6 +26,12 @@ class DBConnection(Connection, metaclass=Singleton): def __init__(self, timeout: float) -> None: super().__init__(self.file, timeout=timeout) super().cursor().execute("PRAGMA foreign_keys = ON;") + self.closed = False + return + + def close(self) -> None: + self.closed = True + super().close() return def get_db(output_type: Union[dict, tuple]=tuple): @@ -54,11 +61,13 @@ def close_db(e=None) -> None: """ try: cursor = g.cursor - db = cursor.connection + db: DBConnection = cursor.connection cursor.close() delattr(g, 'cursor') db.commit() - except AttributeError: + if current_thread() is main_thread(): + db.close() + except (AttributeError, ProgrammingError): pass return