add full text search

This commit is contained in:
Swifty
2025-10-16 11:49:36 +02:00
parent 8b995c2394
commit 794aee25ab
4 changed files with 275 additions and 60 deletions

View File

@@ -0,0 +1,5 @@
import prisma.models
class StoreAgentWithRank(prisma.models.StoreAgent):
rank: float

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import typing
from datetime import datetime, timezone
import fastapi
@@ -70,64 +71,176 @@ async def get_store_agents(
logger.debug(
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
)
search_term = sanitize_query(search_query)
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
if featured:
where_clause["featured"] = featured
if creators:
where_clause["creator_username"] = {"in": creators}
if category:
where_clause["categories"] = {"has": category}
if search_term:
where_clause["OR"] = [
{"agent_name": {"contains": search_term, "mode": "insensitive"}},
{"description": {"contains": search_term, "mode": "insensitive"}},
]
order_by = []
if sorted_by == "rating":
order_by.append({"rating": "desc"})
elif sorted_by == "runs":
order_by.append({"runs": "desc"})
elif sorted_by == "name":
order_by.append({"agent_name": "asc"})
try:
agents = await prisma.models.StoreAgent.prisma().find_many(
where=where_clause,
order=order_by,
skip=(page - 1) * page_size,
take=page_size,
)
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
# Create the StoreAgent object safely
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
# If search_query is provided, use full-text search
if search_query:
search_term = sanitize_query(search_query)
if not search_term:
# Return empty results for invalid search query
return backend.server.v2.store.model.StoreAgentsResponse(
agents=[],
pagination=backend.server.v2.store.model.Pagination(
current_page=page,
total_items=0,
total_pages=0,
page_size=page_size,
),
)
# Add to the list only if creation was successful
store_agents.append(store_agent)
except Exception as e:
# Skip this agent if there was an error
# You could log the error here if needed
logger.error(
f"Error parsing Store agent when getting store agents from db: {e}"
)
continue
offset = (page - 1) * page_size
# Build filter conditions
filter_conditions = []
filter_conditions.append("is_available = true")
if featured:
filter_conditions.append("featured = true")
if creators:
creator_list = "','".join(creators)
filter_conditions.append(f"creator_username IN ('{creator_list}')")
if category:
filter_conditions.append(f"'{category}' = ANY(categories)")
where_filter = (
" AND ".join(filter_conditions) if filter_conditions else "1=1"
)
# Build ORDER BY clause
if sorted_by == "rating":
order_by_clause = "rating DESC, rank DESC"
elif sorted_by == "runs":
order_by_clause = "runs DESC, rank DESC"
elif sorted_by == "name":
order_by_clause = "agent_name ASC, rank DESC"
else:
order_by_clause = "rank DESC, updated_at DESC"
# Execute full-text search query
sql_query = f"""
WITH query AS (
SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q
FROM unnest(to_tsvector('{search_term}'))
)
SELECT
slug,
agent_name,
agent_image,
creator_username,
creator_avatar,
sub_heading,
description,
runs,
rating,
categories,
featured,
is_available,
updated_at,
ts_rank(CAST(search AS tsvector), query.q) AS rank
FROM "StoreAgent", query
WHERE {where_filter} AND search @@ query.q
ORDER BY {order_by_clause}
LIMIT {page_size}
OFFSET {offset};
"""
# Count query for pagination
count_query = f"""
WITH query AS (
SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q
FROM unnest(to_tsvector('{search_term}'))
)
SELECT COUNT(*) as count
FROM "StoreAgent", query
WHERE {where_filter} AND search @@ query.q;
"""
# Execute both queries
agents = await prisma.client.get_client().query_raw(
query=typing.cast(typing.LiteralString, sql_query)
)
count_result = await prisma.client.get_client().query_raw(
query=typing.cast(typing.LiteralString, count_query)
)
total = count_result[0]["count"] if count_result else 0
total_pages = (total + page_size - 1) // page_size
# Convert raw results to StoreAgent models
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent["slug"],
agent_name=agent["agent_name"],
agent_image=(
agent["agent_image"][0] if agent["agent_image"] else ""
),
creator=agent["creator_username"] or "Needs Profile",
creator_avatar=agent["creator_avatar"] or "",
sub_heading=agent["sub_heading"],
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
)
store_agents.append(store_agent)
except Exception as e:
logger.error(f"Error parsing Store agent from search results: {e}")
continue
else:
# Non-search query path (original logic)
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
if featured:
where_clause["featured"] = featured
if creators:
where_clause["creator_username"] = {"in": creators}
if category:
where_clause["categories"] = {"has": category}
order_by = []
if sorted_by == "rating":
order_by.append({"rating": "desc"})
elif sorted_by == "runs":
order_by.append({"runs": "desc"})
elif sorted_by == "name":
order_by.append({"agent_name": "asc"})
agents = await prisma.models.StoreAgent.prisma().find_many(
where=where_clause,
order=order_by,
skip=(page - 1) * page_size,
take=page_size,
)
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
total_pages = (total + page_size - 1) // page_size
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
for agent in agents:
try:
# Create the StoreAgent object safely
store_agent = backend.server.v2.store.model.StoreAgent(
slug=agent.slug,
agent_name=agent.agent_name,
agent_image=agent.agent_image[0] if agent.agent_image else "",
creator=agent.creator_username or "Needs Profile",
creator_avatar=agent.creator_avatar or "",
sub_heading=agent.sub_heading,
description=agent.description,
runs=agent.runs,
rating=agent.rating,
)
# Add to the list only if creation was successful
store_agents.append(store_agent)
except Exception as e:
# Skip this agent if there was an error
# You could log the error here if needed
logger.error(
f"Error parsing Store agent when getting store agents from db: {e}"
)
continue
logger.debug(f"Found {len(store_agents)} agents")
return backend.server.v2.store.model.StoreAgentsResponse(

View File

@@ -0,0 +1,93 @@
-- AlterTable
ALTER TABLE "StoreListingVersion" ADD COLUMN "search" tsvector DEFAULT ''::tsvector;
-- Add trigger to update the search column with the tsvector of the agent
-- Function to be invoked by trigger
CREATE OR REPLACE FUNCTION update_tsvector_column() RETURNS TRIGGER AS $$
BEGIN
NEW.search := to_tsvector('english', COALESCE(NEW.description, '')|| ' ' ||COALESCE(NEW.name, '')|| ' ' ||COALESCE(NEW.subHeading, '')|| ' ' ||COALESCE(NEW.description, ''));
RETURN NEW;
END;
$$ LANGUAGE plpgsql SECURITY definer SET search_path = public, pg_temp;
-- Trigger that keeps the TSVECTOR up to date
DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion";
CREATE TRIGGER "update_tsvector"
BEFORE INSERT OR UPDATE ON "StoreListingVersion"
FOR EACH ROW
EXECUTE FUNCTION update_tsvector_column ();
BEGIN;
-- Drop and recreate the StoreAgent view with isAvailable field
DROP VIEW IF EXISTS "StoreAgent";
CREATE OR REPLACE VIEW "StoreAgent" AS
WITH latest_versions AS (
SELECT
"storeListingId",
MAX(version) AS max_version
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
),
agent_versions AS (
SELECT
"storeListingId",
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
FROM "StoreListingVersion"
WHERE "submissionStatus" = 'APPROVED'
GROUP BY "storeListingId"
)
SELECT
sl.id AS listing_id,
slv.id AS "storeListingVersionId",
slv."createdAt" AS updated_at,
sl.slug,
COALESCE(slv.name, '') AS agent_name,
slv."videoUrl" AS agent_video,
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
slv."isFeatured" AS featured,
p.username AS creator_username, -- Allow NULL for malformed sub-agents
p."avatarUrl" AS creator_avatar, -- Allow NULL for malformed sub-agents
slv."subHeading" AS sub_heading,
slv.description,
slv.categories,
slv.search,
COALESCE(ar.run_count, 0::bigint) AS runs,
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
slv."isAvailable" AS is_available -- Add isAvailable field to filter sub-agents
FROM "StoreListing" sl
JOIN latest_versions lv
ON sl.id = lv."storeListingId"
JOIN "StoreListingVersion" slv
ON slv."storeListingId" = lv."storeListingId"
AND slv.version = lv.max_version
AND slv."submissionStatus" = 'APPROVED'
JOIN "AgentGraph" a
ON slv."agentGraphId" = a.id
AND slv."agentGraphVersion" = a.version
LEFT JOIN "Profile" p
ON sl."owningUserId" = p."userId"
LEFT JOIN "mv_review_stats" rs
ON sl.id = rs."storeListingId"
LEFT JOIN "mv_agent_run_counts" ar
ON a.id = ar."agentGraphId"
LEFT JOIN agent_versions av
ON sl.id = av."storeListingId"
WHERE sl."isDeleted" = false
AND sl."hasApprovedVersion" = true;
COMMIT;

View File

@@ -5,10 +5,11 @@ datasource db {
}
generator client {
provider = "prisma-client-py"
recursive_type_depth = -1
interface = "asyncio"
previewFeatures = ["views"]
provider = "prisma-client-py"
recursive_type_depth = -1
interface = "asyncio"
previewFeatures = ["views", "fullTextSearch"]
partial_type_generator = "backend/data/partial_types.py"
}
// User model to mirror Auth provider users
@@ -663,6 +664,7 @@ view StoreAgent {
sub_heading String
description String
categories String[]
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
runs Int
rating Float
versions String[]
@@ -746,7 +748,7 @@ model StoreListing {
slug String
// Allow this agent to be used during onboarding
useForOnboarding Boolean @default(false)
useForOnboarding Boolean @default(false)
// The currently active version that should be shown to users
activeVersionId String? @unique
@@ -797,6 +799,8 @@ model StoreListingVersion {
// Old versions can be made unavailable by the author if desired
isAvailable Boolean @default(true)
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
// Version workflow state
submissionStatus SubmissionStatus @default(DRAFT)
submittedAt DateTime?