fix(nodes): fix results service bugs

This commit is contained in:
psychedelicious
2023-05-17 19:35:34 +10:00
parent f0a9a4fb88
commit 6e4e0fe29e

View File

@@ -1,4 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
import json
import sqlite3
@@ -7,10 +10,11 @@ from typing import Union
from pydantic import BaseModel, Field, parse_raw_as
from invokeai.app.models.image import ImageField
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.services.graph import GraphExecutionState
from invokeai.app.services.item_storage import PaginatedResults
if TYPE_CHECKING:
from invokeai.app.models.image import ImageField
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.services.graph import GraphExecutionState
from invokeai.app.services.item_storage import PaginatedResults
class Result(BaseModel):
@@ -70,7 +74,7 @@ class SqliteResultsService(ResultsServiceABC):
try:
self._lock.acquire()
self._cursor.execute(
"""
"""--sql
CREATE TABLE IF NOT EXISTS results (
id TEXT PRIMARY KEY,
node_id TEXT,
@@ -80,7 +84,9 @@ class SqliteResultsService(ResultsServiceABC):
"""
)
self._cursor.execute(
"""CREATE UNIQUE INDEX IF NOT EXISTS result_id ON result(id);"""
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS result_id ON results(id);
"""
)
finally:
self._lock.release()
@@ -89,8 +95,8 @@ class SqliteResultsService(ResultsServiceABC):
try:
self._lock.acquire()
self._cursor.execute(
"""
SELECT results.data, graph_executions.state
"""--sql
SELECT results.data, graph_executions.item
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
WHERE results.id = ?
@@ -122,8 +128,8 @@ class SqliteResultsService(ResultsServiceABC):
try:
self._lock.acquire()
self._cursor.execute(
"""
SELECT results.data, graph_executions.state
"""--sql
SELECT results.data, graph_executions.item
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
LIMIT ? OFFSET ?;
@@ -160,8 +166,8 @@ class SqliteResultsService(ResultsServiceABC):
try:
self._lock.acquire()
self._cursor.execute(
"""
SELECT results.data, graph_executions.state
"""--sql
SELECT results.data, graph_executions.item
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
WHERE item LIKE ?
@@ -182,7 +188,9 @@ class SqliteResultsService(ResultsServiceABC):
)
)
self._cursor.execute(
f"""SELECT count(*) FROM results WHERE item LIKE ?;""",
"""--sql
SELECT count(*) FROM results WHERE item LIKE ?;
""",
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
@@ -199,23 +207,23 @@ class SqliteResultsService(ResultsServiceABC):
with self._conn as conn:
for node_id, result in session.results.items():
# We'll only process 'image_output' or 'latents_output'
if result["type"] not in ["image_output", "latents_output"]:
if result.type not in ["image_output", "latents_output"]:
continue
# The id depends on the result type
if result["type"] == "image_output":
id = result["image"]["image_name"]
if result.type == "image_output":
id = result.image.image_name
else: # 'latents_output'
id = result["latents"]["latents_name"]
id = result.latents.latents_name
# Stringify the entire result object for the data column
data = json.dumps(result)
data = json.dumps(result.dict())
# Insert the result into the results table, ignoring if it already exists
conn.execute(
"""
"""--sql
INSERT OR IGNORE INTO results (id, node_id, session_id, data)
VALUES (?, ?, ?, ?)
""",
""",
(id, node_id, session.id, data),
)