feat(app): add noop api validation run stuff to routes and methods

This commit is contained in:
psychedelicious
2025-04-02 15:49:00 +10:00
parent 496f1262c6
commit 08ee08557b
6 changed files with 37 additions and 0 deletions

View File

@@ -15,6 +15,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
ClearResult,
EnqueueBatchResult,
FieldIdentifier,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
@@ -45,6 +46,16 @@ async def enqueue_batch(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch: Batch = Body(description="Batch to process"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
is_api_validation_run: bool = Body(
default=False,
description="Whether or not this is a validation run.",
),
api_input_fields: Optional[list[FieldIdentifier]] = Body(
default=None, description="The fields that were used as input to the API"
),
api_output_fields: Optional[list[FieldIdentifier]] = Body(
default=None, description="The fields that were used as output from the API"
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""

View File

@@ -106,6 +106,7 @@ async def list_workflows(
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_published: Optional[bool] = Query(default=None, description="Whether to include/exclude published workflows"),
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
"""Gets a page of workflows"""
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
@@ -118,6 +119,7 @@ async def list_workflows(
categories=categories,
tags=tags,
has_been_opened=has_been_opened,
is_published=is_published,
)
for workflow in workflows.items:
workflows_with_thumbnails.append(

View File

@@ -201,6 +201,12 @@ def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
return None
class FieldIdentifier(BaseModel):
kind: Literal["input", "output"] = Field(description="The kind of field")
node_id: str = Field(description="The ID of the node")
field_name: str = Field(description="The name of the field")
class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
@@ -237,6 +243,16 @@ class SessionQueueItemWithoutGraph(BaseModel):
retried_from_item_id: Optional[int] = Field(
default=None, description="The item_id of the queue item that this item was retried from"
)
is_api_validation_run: bool = Field(
default=False,
description="Whether this queue item is an API validation run.",
)
api_input_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The fields that were used as input to the API"
)
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":

View File

@@ -47,6 +47,7 @@ class WorkflowRecordsStorageBase(ABC):
query: Optional[str],
tags: Optional[list[str]],
has_been_opened: Optional[bool],
is_published: Optional[bool],
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""
pass
@@ -56,6 +57,7 @@ class WorkflowRecordsStorageBase(ABC):
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided categories."""
pass
@@ -66,6 +68,7 @@ class WorkflowRecordsStorageBase(ABC):
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided tags."""
pass

View File

@@ -67,6 +67,7 @@ class WorkflowWithoutID(BaseModel):
# This is typed as optional to prevent errors when pulling workflows from the DB. The frontend adds a default form if
# it is None.
form: dict[str, JsonValue] | None = Field(default=None, description="The form of the workflow.")
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
model_config = ConfigDict(extra="ignore")
@@ -101,6 +102,7 @@ class WorkflowRecordDTOBase(BaseModel):
opened_at: Optional[Union[datetime.datetime, str]] = Field(
default=None, description="The opened timestamp of the workflow."
)
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
class WorkflowRecordDTO(WorkflowRecordDTOBase):

View File

@@ -119,6 +119,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
query: Optional[str] = None,
tags: Optional[list[str]] = None,
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
# sanitize!
assert order_by in WorkflowRecordOrderBy
@@ -241,6 +242,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
if not tags:
return {}
@@ -292,6 +294,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
cursor = self._conn.cursor()
result: dict[str, int] = {}