revert(app): use OR logic for workflow library filtering

This commit is contained in:
psychedelicious
2025-03-12 06:28:02 +10:00
parent 3b0fecafb0
commit 3ff529c718
3 changed files with 54 additions and 19 deletions

View File

@@ -221,17 +221,14 @@ async def get_workflow_thumbnail(
raise HTTPException(status_code=404)
@workflows_router.get("/tag_counts_with_filter", operation_id="get_tag_counts_with_filter")
async def get_tag_counts_with_filter(
tags_to_count: list[str] = Query(description="The tags to get counts for"),
selected_tags: Optional[list[str]] = Query(default=None, description="The tags to include"),
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
async def get_counts_by_tag(
tags: list[str] = Query(description="The tags to get counts for"),
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
) -> dict[str, int]:
"""Gets tag counts with a filter"""
return ApiDependencies.invoker.services.workflow_records.get_tag_counts_with_filter(
tags_to_count=tags_to_count, categories=categories, selected_tags=selected_tags
)
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(tags=tags, categories=categories)
@workflows_router.put(

View File

@@ -51,20 +51,12 @@ class WorkflowRecordsStorageBase(ABC):
pass
@abstractmethod
def get_tag_counts_with_filter(
def counts_by_tag(
self,
tags_to_count: list[str],
selected_tags: Optional[list[str]] = None,
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
) -> dict[str, int]:
"""
For each tag in tags_to_count, count workflows matching:
- All selected_tags (AND logic filter)
- AND the specific tag being counted
- Filtered by categories if provided
Returns a dictionary of tag -> count.
"""
"""Gets a dictionary of counts for each of the provided tags."""
pass
@abstractmethod

View File

@@ -166,7 +166,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " AND ".join(tags_conditions)
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# And the params for the tags, case-insensitive
@@ -230,6 +230,52 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
total=total,
)
def counts_by_tag(
self,
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
) -> dict[str, int]:
if not tags:
return {}
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
return result
def get_tag_counts_with_filter(
self,
tags_to_count: list[str],