From 3ff529c71812fa500214468feef6ff0633729c06 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 12 Mar 2025 06:28:02 +1000 Subject: [PATCH] revert(app): use OR logic for workflow library filtering --- invokeai/app/api/routers/workflows.py | 11 ++--- .../workflow_records/workflow_records_base.py | 14 ++---- .../workflow_records_sqlite.py | 48 ++++++++++++++++++- 3 files changed, 54 insertions(+), 19 deletions(-) diff --git a/invokeai/app/api/routers/workflows.py b/invokeai/app/api/routers/workflows.py index 8ef6cb9b3e..4e32307cad 100644 --- a/invokeai/app/api/routers/workflows.py +++ b/invokeai/app/api/routers/workflows.py @@ -221,17 +221,14 @@ async def get_workflow_thumbnail( raise HTTPException(status_code=404) -@workflows_router.get("/tag_counts_with_filter", operation_id="get_tag_counts_with_filter") -async def get_tag_counts_with_filter( - tags_to_count: list[str] = Query(description="The tags to get counts for"), - selected_tags: Optional[list[str]] = Query(default=None, description="The tags to include"), +@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag") +async def get_counts_by_tag( + tags: list[str] = Query(description="The tags to get counts for"), categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"), ) -> dict[str, int]: """Gets tag counts with a filter""" - return ApiDependencies.invoker.services.workflow_records.get_tag_counts_with_filter( - tags_to_count=tags_to_count, categories=categories, selected_tags=selected_tags - ) + return ApiDependencies.invoker.services.workflow_records.counts_by_tag(tags=tags, categories=categories) @workflows_router.put( diff --git a/invokeai/app/services/workflow_records/workflow_records_base.py b/invokeai/app/services/workflow_records/workflow_records_base.py index ebc858d943..aea13f55c8 100644 --- a/invokeai/app/services/workflow_records/workflow_records_base.py +++ b/invokeai/app/services/workflow_records/workflow_records_base.py @@ -51,20 +51,12 @@ class WorkflowRecordsStorageBase(ABC): pass @abstractmethod - def get_tag_counts_with_filter( + def counts_by_tag( self, - tags_to_count: list[str], - selected_tags: Optional[list[str]] = None, + tags: list[str], categories: Optional[list[WorkflowCategory]] = None, ) -> dict[str, int]: - """ - For each tag in tags_to_count, count workflows matching: - - All selected_tags (AND logic filter) - - AND the specific tag being counted - - Filtered by categories if provided - - Returns a dictionary of tag -> count. - """ + """Gets a dictionary of counts for each of the provided tags.""" pass @abstractmethod diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index d4f16be15e..6e3cb1fdbd 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -166,7 +166,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): # Construct a list of conditions for each tag tags_conditions = ["tags LIKE ?" for _ in tags] - tags_conditions_joined = " AND ".join(tags_conditions) + tags_conditions_joined = " OR ".join(tags_conditions) tags_condition = f"({tags_conditions_joined})" # And the params for the tags, case-insensitive @@ -230,6 +230,52 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): total=total, ) + def counts_by_tag( + self, + tags: list[str], + categories: Optional[list[WorkflowCategory]] = None, + ) -> dict[str, int]: + if not tags: + return {} + + cursor = self._conn.cursor() + result: dict[str, int] = {} + # Base conditions for categories and selected tags + base_conditions: list[str] = [] + base_params: list[str | int] = [] + + # Add category conditions + if categories: + assert all(c in WorkflowCategory for c in categories) + placeholders = ", ".join("?" for _ in categories) + base_conditions.append(f"category IN ({placeholders})") + base_params.extend([category.value for category in categories]) + + # For each tag to count, run a separate query + for tag in tags: + # Start with the base conditions + conditions = base_conditions.copy() + params = base_params.copy() + + # Add this specific tag condition + conditions.append("tags LIKE ?") + params.append(f"%{tag.strip()}%") + + # Construct the full query + stmt = """--sql + SELECT COUNT(*) + FROM workflow_library + """ + + if conditions: + stmt += " WHERE " + " AND ".join(conditions) + + cursor.execute(stmt, params) + count = cursor.fetchone()[0] + result[tag] = count + + return result + def get_tag_counts_with_filter( self, tags_to_count: list[str],