mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(app): avoid nested cursors in model_records service
This commit is contained in:
@@ -78,7 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._logger = logger
|
||||
|
||||
@property
|
||||
@@ -97,7 +96,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
try:
|
||||
self._cursor.execute(
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
@@ -139,14 +139,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
try:
|
||||
self._cursor.execute(
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
@@ -163,7 +164,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
try:
|
||||
self._cursor.execute(
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
SET
|
||||
@@ -172,7 +174,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
""",
|
||||
(json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
@@ -189,28 +191,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
self._cursor.execute(
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
self._cursor.execute(
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
@@ -222,14 +226,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
self._cursor.execute(
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
count = cursor.fetchone()[0]
|
||||
return count > 0
|
||||
|
||||
def search_by_attr(
|
||||
@@ -277,7 +282,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
self._cursor.execute(
|
||||
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
@@ -286,7 +293,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
result = cursor.fetchall()
|
||||
|
||||
# Parse the model configs.
|
||||
results: list[AnyModelConfig] = []
|
||||
@@ -305,26 +312,28 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated path."""
|
||||
self._cursor.execute(
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated hash."""
|
||||
self._cursor.execute(
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()]
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def list_models(
|
||||
@@ -340,18 +349,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
cursor = self._db.conn.cursor()
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
total = int(cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
@@ -364,6 +375,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
rows = cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
|
||||
|
||||
Reference in New Issue
Block a user