feat(app): avoid nested cursors in model_records service

This commit is contained in:
psychedelicious
2025-03-04 06:32:20 +10:00
parent 657095d2e2
commit 028d8d8ead

View File

@@ -78,7 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
self._logger = logger
@property
@@ -97,7 +96,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
try:
self._cursor.execute(
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
@@ -139,14 +139,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
try:
self._cursor.execute(
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
DELETE FROM models
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
@@ -163,7 +164,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
json_serialized = record.model_dump_json()
try:
self._cursor.execute(
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
SET
@@ -172,7 +174,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""",
(json_serialized, key),
)
if self._cursor.rowcount == 0:
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
@@ -189,28 +191,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
self._cursor.execute(
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
self._cursor.execute(
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -222,14 +226,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
self._cursor.execute(
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
count = cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -277,7 +282,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
self._cursor.execute(
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
@@ -286,7 +293,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""",
tuple(bindings),
)
result = self._cursor.fetchall()
result = cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -305,26 +312,28 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
self._cursor.execute(
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
self._cursor.execute(
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def list_models(
@@ -340,18 +349,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
ModelRecordOrderBy.Format: "format",
}
cursor = self._db.conn.cursor()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
self._cursor.execute(
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(self._cursor.fetchone()[0])
total = int(cursor.fetchone()[0])
# query2: fetch key fields
self._cursor.execute(
cursor.execute(
f"""--sql
SELECT config
FROM models
@@ -364,6 +375,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
page * per_page,
),
)
rows = self._cursor.fetchall()
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)