feat(app): cancel by destination, not origin

When resetting the canvas or staging area, we don't want to cancel generations that are going to the gallery - only those going to the canvas.

Thus the method should not cancel by origin, but instead cancel by destination.

Update the queue method and route.
This commit is contained in:
psychedelicious
2024-09-06 10:30:57 +10:00
parent 97aad2ab2f
commit 480856a528
4 changed files with 21 additions and 19 deletions

View File

@@ -11,7 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByDestinationResult,
ClearResult,
EnqueueBatchResult,
PruneResult,
@@ -107,16 +107,18 @@ async def cancel_by_batch_ids(
@session_queue_router.put(
"/{queue_id}/cancel_by_origin",
operation_id="cancel_by_origin",
"/{queue_id}/cancel_by_destination",
operation_id="cancel_by_destination",
responses={200: {"model": CancelByBatchIDsResult}},
)
async def cancel_by_origin(
async def cancel_by_destination(
queue_id: str = Path(description="The queue id to perform this operation on"),
origin: str = Query(description="The origin to cancel all queue items for"),
) -> CancelByOriginResult:
destination: str = Query(description="The destination to cancel all queue items for"),
) -> CancelByDestinationResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
@session_queue_router.put(

View File

@@ -6,7 +6,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@@ -97,8 +97,8 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
"""Cancels all queue items with the given batch origin"""
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
"""Cancels all queue items with the given batch destination"""
pass
@abstractmethod

View File

@@ -346,10 +346,10 @@ class CancelByBatchIDsResult(BaseModel):
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByOriginResult(BaseModel):
"""Result of canceling by list of batch ids"""
class CancelByDestinationResult(CancelByBatchIDsResult):
"""Result of canceling by a destination"""
canceled: int = Field(..., description="Number of queue items canceled")
pass
class CancelByQueueIDResult(CancelByBatchIDsResult):

View File

@@ -10,7 +10,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchStatus,
CancelByBatchIDsResult,
CancelByOriginResult,
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
@@ -426,19 +426,19 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
try:
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id == ?
AND origin == ?
AND destination == ?
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = (queue_id, origin)
params = (queue_id, destination)
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
@@ -457,14 +457,14 @@ class SqliteSessionQueue(SessionQueueBase):
params,
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.origin == origin:
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByOriginResult(canceled=count)
return CancelByDestinationResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try: