mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 20:18:07 -05:00
revert(app): use OR logic for workflow library filtering
This commit is contained in:
@@ -221,17 +221,14 @@ async def get_workflow_thumbnail(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@workflows_router.get("/tag_counts_with_filter", operation_id="get_tag_counts_with_filter")
|
||||
async def get_tag_counts_with_filter(
|
||||
tags_to_count: list[str] = Query(description="The tags to get counts for"),
|
||||
selected_tags: Optional[list[str]] = Query(default=None, description="The tags to include"),
|
||||
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
|
||||
async def get_counts_by_tag(
|
||||
tags: list[str] = Query(description="The tags to get counts for"),
|
||||
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
|
||||
) -> dict[str, int]:
|
||||
"""Gets tag counts with a filter"""
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.get_tag_counts_with_filter(
|
||||
tags_to_count=tags_to_count, categories=categories, selected_tags=selected_tags
|
||||
)
|
||||
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(tags=tags, categories=categories)
|
||||
|
||||
|
||||
@workflows_router.put(
|
||||
|
||||
@@ -51,20 +51,12 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_tag_counts_with_filter(
|
||||
def counts_by_tag(
|
||||
self,
|
||||
tags_to_count: list[str],
|
||||
selected_tags: Optional[list[str]] = None,
|
||||
tags: list[str],
|
||||
categories: Optional[list[WorkflowCategory]] = None,
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
For each tag in tags_to_count, count workflows matching:
|
||||
- All selected_tags (AND logic filter)
|
||||
- AND the specific tag being counted
|
||||
- Filtered by categories if provided
|
||||
|
||||
Returns a dictionary of tag -> count.
|
||||
"""
|
||||
"""Gets a dictionary of counts for each of the provided tags."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -166,7 +166,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
|
||||
# Construct a list of conditions for each tag
|
||||
tags_conditions = ["tags LIKE ?" for _ in tags]
|
||||
tags_conditions_joined = " AND ".join(tags_conditions)
|
||||
tags_conditions_joined = " OR ".join(tags_conditions)
|
||||
tags_condition = f"({tags_conditions_joined})"
|
||||
|
||||
# And the params for the tags, case-insensitive
|
||||
@@ -230,6 +230,52 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
total=total,
|
||||
)
|
||||
|
||||
def counts_by_tag(
|
||||
self,
|
||||
tags: list[str],
|
||||
categories: Optional[list[WorkflowCategory]] = None,
|
||||
) -> dict[str, int]:
|
||||
if not tags:
|
||||
return {}
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories and selected tags
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
|
||||
# For each tag to count, run a separate query
|
||||
for tag in tags:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
|
||||
# Add this specific tag condition
|
||||
conditions.append("tags LIKE ?")
|
||||
params.append(f"%{tag.strip()}%")
|
||||
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[tag] = count
|
||||
|
||||
return result
|
||||
|
||||
def get_tag_counts_with_filter(
|
||||
self,
|
||||
tags_to_count: list[str],
|
||||
|
||||
Reference in New Issue
Block a user