mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 07:38:04 -05:00
add full text search
This commit is contained in:
5
autogpt_platform/backend/backend/data/partial_types.py
Normal file
5
autogpt_platform/backend/backend/data/partial_types.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import prisma.models
|
||||
|
||||
|
||||
class StoreAgentWithRank(prisma.models.StoreAgent):
|
||||
rank: float
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import fastapi
|
||||
@@ -70,64 +71,176 @@ async def get_store_agents(
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
search_term = sanitize_query(search_query)
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
if search_term:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_term, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_term, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == "rating":
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == "runs":
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == "name":
|
||||
order_by.append({"agent_name": "asc"})
|
||||
|
||||
try:
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
# Create the StoreAgent object safely
|
||||
store_agent = backend.server.v2.store.model.StoreAgent(
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||
creator=agent.creator_username or "Needs Profile",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
# If search_query is provided, use full-text search
|
||||
if search_query:
|
||||
search_term = sanitize_query(search_query)
|
||||
if not search_term:
|
||||
# Return empty results for invalid search query
|
||||
return backend.server.v2.store.model.StoreAgentsResponse(
|
||||
agents=[],
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=page,
|
||||
total_items=0,
|
||||
total_pages=0,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
# Add to the list only if creation was successful
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
# You could log the error here if needed
|
||||
logger.error(
|
||||
f"Error parsing Store agent when getting store agents from db: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Build filter conditions
|
||||
filter_conditions = []
|
||||
filter_conditions.append("is_available = true")
|
||||
|
||||
if featured:
|
||||
filter_conditions.append("featured = true")
|
||||
if creators:
|
||||
creator_list = "','".join(creators)
|
||||
filter_conditions.append(f"creator_username IN ('{creator_list}')")
|
||||
if category:
|
||||
filter_conditions.append(f"'{category}' = ANY(categories)")
|
||||
|
||||
where_filter = (
|
||||
" AND ".join(filter_conditions) if filter_conditions else "1=1"
|
||||
)
|
||||
|
||||
# Build ORDER BY clause
|
||||
if sorted_by == "rating":
|
||||
order_by_clause = "rating DESC, rank DESC"
|
||||
elif sorted_by == "runs":
|
||||
order_by_clause = "runs DESC, rank DESC"
|
||||
elif sorted_by == "name":
|
||||
order_by_clause = "agent_name ASC, rank DESC"
|
||||
else:
|
||||
order_by_clause = "rank DESC, updated_at DESC"
|
||||
|
||||
# Execute full-text search query
|
||||
sql_query = f"""
|
||||
WITH query AS (
|
||||
SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q
|
||||
FROM unnest(to_tsvector('{search_term}'))
|
||||
)
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
ts_rank(CAST(search AS tsvector), query.q) AS rank
|
||||
FROM "StoreAgent", query
|
||||
WHERE {where_filter} AND search @@ query.q
|
||||
ORDER BY {order_by_clause}
|
||||
LIMIT {page_size}
|
||||
OFFSET {offset};
|
||||
"""
|
||||
|
||||
# Count query for pagination
|
||||
count_query = f"""
|
||||
WITH query AS (
|
||||
SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q
|
||||
FROM unnest(to_tsvector('{search_term}'))
|
||||
)
|
||||
SELECT COUNT(*) as count
|
||||
FROM "StoreAgent", query
|
||||
WHERE {where_filter} AND search @@ query.q;
|
||||
"""
|
||||
|
||||
# Execute both queries
|
||||
agents = await prisma.client.get_client().query_raw(
|
||||
query=typing.cast(typing.LiteralString, sql_query)
|
||||
)
|
||||
|
||||
count_result = await prisma.client.get_client().query_raw(
|
||||
query=typing.cast(typing.LiteralString, count_query)
|
||||
)
|
||||
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert raw results to StoreAgent models
|
||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = backend.server.v2.store.model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Store agent from search results: {e}")
|
||||
continue
|
||||
|
||||
else:
|
||||
# Non-search query path (original logic)
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
if creators:
|
||||
where_clause["creator_username"] = {"in": creators}
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
order_by = []
|
||||
if sorted_by == "rating":
|
||||
order_by.append({"rating": "desc"})
|
||||
elif sorted_by == "runs":
|
||||
order_by.append({"runs": "desc"})
|
||||
elif sorted_by == "name":
|
||||
order_by.append({"agent_name": "asc"})
|
||||
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
total = await prisma.models.StoreAgent.prisma().count(where=where_clause)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[backend.server.v2.store.model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
# Create the StoreAgent object safely
|
||||
store_agent = backend.server.v2.store.model.StoreAgent(
|
||||
slug=agent.slug,
|
||||
agent_name=agent.agent_name,
|
||||
agent_image=agent.agent_image[0] if agent.agent_image else "",
|
||||
creator=agent.creator_username or "Needs Profile",
|
||||
creator_avatar=agent.creator_avatar or "",
|
||||
sub_heading=agent.sub_heading,
|
||||
description=agent.description,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
)
|
||||
# Add to the list only if creation was successful
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
# You could log the error here if needed
|
||||
logger.error(
|
||||
f"Error parsing Store agent when getting store agents from db: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(store_agents)} agents")
|
||||
return backend.server.v2.store.model.StoreAgentsResponse(
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "StoreListingVersion" ADD COLUMN "search" tsvector DEFAULT ''::tsvector;
|
||||
|
||||
-- Add trigger to update the search column with the tsvector of the agent
|
||||
-- Function to be invoked by trigger
|
||||
|
||||
CREATE OR REPLACE FUNCTION update_tsvector_column() RETURNS TRIGGER AS $$
|
||||
|
||||
BEGIN
|
||||
|
||||
NEW.search := to_tsvector('english', COALESCE(NEW.description, '')|| ' ' ||COALESCE(NEW.name, '')|| ' ' ||COALESCE(NEW.subHeading, '')|| ' ' ||COALESCE(NEW.description, ''));
|
||||
|
||||
RETURN NEW;
|
||||
|
||||
END;
|
||||
|
||||
$$ LANGUAGE plpgsql SECURITY definer SET search_path = public, pg_temp;
|
||||
|
||||
-- Trigger that keeps the TSVECTOR up to date
|
||||
|
||||
DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion";
|
||||
|
||||
CREATE TRIGGER "update_tsvector"
|
||||
|
||||
BEFORE INSERT OR UPDATE ON "StoreListingVersion"
|
||||
|
||||
FOR EACH ROW
|
||||
|
||||
EXECUTE FUNCTION update_tsvector_column ();
|
||||
|
||||
BEGIN;
|
||||
|
||||
-- Drop and recreate the StoreAgent view with isAvailable field
|
||||
DROP VIEW IF EXISTS "StoreAgent";
|
||||
|
||||
CREATE OR REPLACE VIEW "StoreAgent" AS
|
||||
WITH latest_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
MAX(version) AS max_version
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
),
|
||||
agent_versions AS (
|
||||
SELECT
|
||||
"storeListingId",
|
||||
array_agg(DISTINCT version::text ORDER BY version::text) AS versions
|
||||
FROM "StoreListingVersion"
|
||||
WHERE "submissionStatus" = 'APPROVED'
|
||||
GROUP BY "storeListingId"
|
||||
)
|
||||
SELECT
|
||||
sl.id AS listing_id,
|
||||
slv.id AS "storeListingVersionId",
|
||||
slv."createdAt" AS updated_at,
|
||||
sl.slug,
|
||||
COALESCE(slv.name, '') AS agent_name,
|
||||
slv."videoUrl" AS agent_video,
|
||||
COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image,
|
||||
slv."isFeatured" AS featured,
|
||||
p.username AS creator_username, -- Allow NULL for malformed sub-agents
|
||||
p."avatarUrl" AS creator_avatar, -- Allow NULL for malformed sub-agents
|
||||
slv."subHeading" AS sub_heading,
|
||||
slv.description,
|
||||
slv.categories,
|
||||
slv.search,
|
||||
COALESCE(ar.run_count, 0::bigint) AS runs,
|
||||
COALESCE(rs.avg_rating, 0.0)::double precision AS rating,
|
||||
COALESCE(av.versions, ARRAY[slv.version::text]) AS versions,
|
||||
slv."isAvailable" AS is_available -- Add isAvailable field to filter sub-agents
|
||||
FROM "StoreListing" sl
|
||||
JOIN latest_versions lv
|
||||
ON sl.id = lv."storeListingId"
|
||||
JOIN "StoreListingVersion" slv
|
||||
ON slv."storeListingId" = lv."storeListingId"
|
||||
AND slv.version = lv.max_version
|
||||
AND slv."submissionStatus" = 'APPROVED'
|
||||
JOIN "AgentGraph" a
|
||||
ON slv."agentGraphId" = a.id
|
||||
AND slv."agentGraphVersion" = a.version
|
||||
LEFT JOIN "Profile" p
|
||||
ON sl."owningUserId" = p."userId"
|
||||
LEFT JOIN "mv_review_stats" rs
|
||||
ON sl.id = rs."storeListingId"
|
||||
LEFT JOIN "mv_agent_run_counts" ar
|
||||
ON a.id = ar."agentGraphId"
|
||||
LEFT JOIN agent_versions av
|
||||
ON sl.id = av."storeListingId"
|
||||
WHERE sl."isDeleted" = false
|
||||
AND sl."hasApprovedVersion" = true;
|
||||
|
||||
COMMIT;
|
||||
@@ -5,10 +5,11 @@ datasource db {
|
||||
}
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views"]
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views", "fullTextSearch"]
|
||||
partial_type_generator = "backend/data/partial_types.py"
|
||||
}
|
||||
|
||||
// User model to mirror Auth provider users
|
||||
@@ -663,6 +664,7 @@ view StoreAgent {
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
@@ -746,7 +748,7 @@ model StoreListing {
|
||||
slug String
|
||||
|
||||
// Allow this agent to be used during onboarding
|
||||
useForOnboarding Boolean @default(false)
|
||||
useForOnboarding Boolean @default(false)
|
||||
|
||||
// The currently active version that should be shown to users
|
||||
activeVersionId String? @unique
|
||||
@@ -797,6 +799,8 @@ model StoreListingVersion {
|
||||
// Old versions can be made unavailable by the author if desired
|
||||
isAvailable Boolean @default(true)
|
||||
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||
|
||||
// Version workflow state
|
||||
submissionStatus SubmissionStatus @default(DRAFT)
|
||||
submittedAt DateTime?
|
||||
|
||||
Reference in New Issue
Block a user