diff --git a/autogpt_platform/backend/backend/data/partial_types.py b/autogpt_platform/backend/backend/data/partial_types.py new file mode 100644 index 0000000000..befa32219f --- /dev/null +++ b/autogpt_platform/backend/backend/data/partial_types.py @@ -0,0 +1,5 @@ +import prisma.models + + +class StoreAgentWithRank(prisma.models.StoreAgent): + rank: float diff --git a/autogpt_platform/backend/backend/server/v2/store/db.py b/autogpt_platform/backend/backend/server/v2/store/db.py index 84e63b36ac..4ad0b72d29 100644 --- a/autogpt_platform/backend/backend/server/v2/store/db.py +++ b/autogpt_platform/backend/backend/server/v2/store/db.py @@ -1,5 +1,6 @@ import asyncio import logging +import typing from datetime import datetime, timezone import fastapi @@ -70,64 +71,176 @@ async def get_store_agents( logger.debug( f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}" ) - search_term = sanitize_query(search_query) - where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True} - if featured: - where_clause["featured"] = featured - if creators: - where_clause["creator_username"] = {"in": creators} - if category: - where_clause["categories"] = {"has": category} - - if search_term: - where_clause["OR"] = [ - {"agent_name": {"contains": search_term, "mode": "insensitive"}}, - {"description": {"contains": search_term, "mode": "insensitive"}}, - ] - - order_by = [] - if sorted_by == "rating": - order_by.append({"rating": "desc"}) - elif sorted_by == "runs": - order_by.append({"runs": "desc"}) - elif sorted_by == "name": - order_by.append({"agent_name": "asc"}) try: - agents = await prisma.models.StoreAgent.prisma().find_many( - where=where_clause, - order=order_by, - skip=(page - 1) * page_size, - take=page_size, - ) - - total = await prisma.models.StoreAgent.prisma().count(where=where_clause) - total_pages = (total + page_size - 1) // page_size - - store_agents: list[backend.server.v2.store.model.StoreAgent] = [] - for agent in agents: - try: - # Create the StoreAgent object safely - store_agent = backend.server.v2.store.model.StoreAgent( - slug=agent.slug, - agent_name=agent.agent_name, - agent_image=agent.agent_image[0] if agent.agent_image else "", - creator=agent.creator_username or "Needs Profile", - creator_avatar=agent.creator_avatar or "", - sub_heading=agent.sub_heading, - description=agent.description, - runs=agent.runs, - rating=agent.rating, + # If search_query is provided, use full-text search + if search_query: + search_term = sanitize_query(search_query) + if not search_term: + # Return empty results for invalid search query + return backend.server.v2.store.model.StoreAgentsResponse( + agents=[], + pagination=backend.server.v2.store.model.Pagination( + current_page=page, + total_items=0, + total_pages=0, + page_size=page_size, + ), ) - # Add to the list only if creation was successful - store_agents.append(store_agent) - except Exception as e: - # Skip this agent if there was an error - # You could log the error here if needed - logger.error( - f"Error parsing Store agent when getting store agents from db: {e}" - ) - continue + + offset = (page - 1) * page_size + + # Build filter conditions + filter_conditions = [] + filter_conditions.append("is_available = true") + + if featured: + filter_conditions.append("featured = true") + if creators: + creator_list = "','".join(creators) + filter_conditions.append(f"creator_username IN ('{creator_list}')") + if category: + filter_conditions.append(f"'{category}' = ANY(categories)") + + where_filter = ( + " AND ".join(filter_conditions) if filter_conditions else "1=1" + ) + + # Build ORDER BY clause + if sorted_by == "rating": + order_by_clause = "rating DESC, rank DESC" + elif sorted_by == "runs": + order_by_clause = "runs DESC, rank DESC" + elif sorted_by == "name": + order_by_clause = "agent_name ASC, rank DESC" + else: + order_by_clause = "rank DESC, updated_at DESC" + + # Execute full-text search query + sql_query = f""" + WITH query AS ( + SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q + FROM unnest(to_tsvector('{search_term}')) + ) + SELECT + slug, + agent_name, + agent_image, + creator_username, + creator_avatar, + sub_heading, + description, + runs, + rating, + categories, + featured, + is_available, + updated_at, + ts_rank(CAST(search AS tsvector), query.q) AS rank + FROM "StoreAgent", query + WHERE {where_filter} AND search @@ query.q + ORDER BY {order_by_clause} + LIMIT {page_size} + OFFSET {offset}; + """ + + # Count query for pagination + count_query = f""" + WITH query AS ( + SELECT to_tsquery(string_agg(lexeme || ':*', ' & ' ORDER BY positions)) AS q + FROM unnest(to_tsvector('{search_term}')) + ) + SELECT COUNT(*) as count + FROM "StoreAgent", query + WHERE {where_filter} AND search @@ query.q; + """ + + # Execute both queries + agents = await prisma.client.get_client().query_raw( + query=typing.cast(typing.LiteralString, sql_query) + ) + + count_result = await prisma.client.get_client().query_raw( + query=typing.cast(typing.LiteralString, count_query) + ) + + total = count_result[0]["count"] if count_result else 0 + total_pages = (total + page_size - 1) // page_size + + # Convert raw results to StoreAgent models + store_agents: list[backend.server.v2.store.model.StoreAgent] = [] + for agent in agents: + try: + store_agent = backend.server.v2.store.model.StoreAgent( + slug=agent["slug"], + agent_name=agent["agent_name"], + agent_image=( + agent["agent_image"][0] if agent["agent_image"] else "" + ), + creator=agent["creator_username"] or "Needs Profile", + creator_avatar=agent["creator_avatar"] or "", + sub_heading=agent["sub_heading"], + description=agent["description"], + runs=agent["runs"], + rating=agent["rating"], + ) + store_agents.append(store_agent) + except Exception as e: + logger.error(f"Error parsing Store agent from search results: {e}") + continue + + else: + # Non-search query path (original logic) + where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True} + if featured: + where_clause["featured"] = featured + if creators: + where_clause["creator_username"] = {"in": creators} + if category: + where_clause["categories"] = {"has": category} + + order_by = [] + if sorted_by == "rating": + order_by.append({"rating": "desc"}) + elif sorted_by == "runs": + order_by.append({"runs": "desc"}) + elif sorted_by == "name": + order_by.append({"agent_name": "asc"}) + + agents = await prisma.models.StoreAgent.prisma().find_many( + where=where_clause, + order=order_by, + skip=(page - 1) * page_size, + take=page_size, + ) + + total = await prisma.models.StoreAgent.prisma().count(where=where_clause) + total_pages = (total + page_size - 1) // page_size + + store_agents: list[backend.server.v2.store.model.StoreAgent] = [] + for agent in agents: + try: + # Create the StoreAgent object safely + store_agent = backend.server.v2.store.model.StoreAgent( + slug=agent.slug, + agent_name=agent.agent_name, + agent_image=agent.agent_image[0] if agent.agent_image else "", + creator=agent.creator_username or "Needs Profile", + creator_avatar=agent.creator_avatar or "", + sub_heading=agent.sub_heading, + description=agent.description, + runs=agent.runs, + rating=agent.rating, + ) + # Add to the list only if creation was successful + store_agents.append(store_agent) + except Exception as e: + # Skip this agent if there was an error + # You could log the error here if needed + logger.error( + f"Error parsing Store agent when getting store agents from db: {e}" + ) + continue logger.debug(f"Found {len(store_agents)} agents") return backend.server.v2.store.model.StoreAgentsResponse( diff --git a/autogpt_platform/backend/migrations/20251016093049_add_full_text_search/migration.sql b/autogpt_platform/backend/migrations/20251016093049_add_full_text_search/migration.sql new file mode 100644 index 0000000000..5c35ca7a65 --- /dev/null +++ b/autogpt_platform/backend/migrations/20251016093049_add_full_text_search/migration.sql @@ -0,0 +1,93 @@ +-- AlterTable +ALTER TABLE "StoreListingVersion" ADD COLUMN "search" tsvector DEFAULT ''::tsvector; + +-- Add trigger to update the search column with the tsvector of the agent +-- Function to be invoked by trigger + +CREATE OR REPLACE FUNCTION update_tsvector_column() RETURNS TRIGGER AS $$ + +BEGIN + +NEW.search := to_tsvector('english', COALESCE(NEW.description, '')|| ' ' ||COALESCE(NEW.name, '')|| ' ' ||COALESCE(NEW.subHeading, '')|| ' ' ||COALESCE(NEW.description, '')); + +RETURN NEW; + +END; + +$$ LANGUAGE plpgsql SECURITY definer SET search_path = public, pg_temp; + +-- Trigger that keeps the TSVECTOR up to date + +DROP TRIGGER IF EXISTS "update_tsvector" ON "StoreListingVersion"; + +CREATE TRIGGER "update_tsvector" + +BEFORE INSERT OR UPDATE ON "StoreListingVersion" + +FOR EACH ROW + +EXECUTE FUNCTION update_tsvector_column (); + +BEGIN; + +-- Drop and recreate the StoreAgent view with isAvailable field +DROP VIEW IF EXISTS "StoreAgent"; + +CREATE OR REPLACE VIEW "StoreAgent" AS +WITH latest_versions AS ( + SELECT + "storeListingId", + MAX(version) AS max_version + FROM "StoreListingVersion" + WHERE "submissionStatus" = 'APPROVED' + GROUP BY "storeListingId" +), +agent_versions AS ( + SELECT + "storeListingId", + array_agg(DISTINCT version::text ORDER BY version::text) AS versions + FROM "StoreListingVersion" + WHERE "submissionStatus" = 'APPROVED' + GROUP BY "storeListingId" +) +SELECT + sl.id AS listing_id, + slv.id AS "storeListingVersionId", + slv."createdAt" AS updated_at, + sl.slug, + COALESCE(slv.name, '') AS agent_name, + slv."videoUrl" AS agent_video, + COALESCE(slv."imageUrls", ARRAY[]::text[]) AS agent_image, + slv."isFeatured" AS featured, + p.username AS creator_username, -- Allow NULL for malformed sub-agents + p."avatarUrl" AS creator_avatar, -- Allow NULL for malformed sub-agents + slv."subHeading" AS sub_heading, + slv.description, + slv.categories, + slv.search, + COALESCE(ar.run_count, 0::bigint) AS runs, + COALESCE(rs.avg_rating, 0.0)::double precision AS rating, + COALESCE(av.versions, ARRAY[slv.version::text]) AS versions, + slv."isAvailable" AS is_available -- Add isAvailable field to filter sub-agents +FROM "StoreListing" sl +JOIN latest_versions lv + ON sl.id = lv."storeListingId" +JOIN "StoreListingVersion" slv + ON slv."storeListingId" = lv."storeListingId" + AND slv.version = lv.max_version + AND slv."submissionStatus" = 'APPROVED' +JOIN "AgentGraph" a + ON slv."agentGraphId" = a.id + AND slv."agentGraphVersion" = a.version +LEFT JOIN "Profile" p + ON sl."owningUserId" = p."userId" +LEFT JOIN "mv_review_stats" rs + ON sl.id = rs."storeListingId" +LEFT JOIN "mv_agent_run_counts" ar + ON a.id = ar."agentGraphId" +LEFT JOIN agent_versions av + ON sl.id = av."storeListingId" +WHERE sl."isDeleted" = false + AND sl."hasApprovedVersion" = true; + +COMMIT; \ No newline at end of file diff --git a/autogpt_platform/backend/schema.prisma b/autogpt_platform/backend/schema.prisma index 7556c45918..d1b878f6fe 100644 --- a/autogpt_platform/backend/schema.prisma +++ b/autogpt_platform/backend/schema.prisma @@ -5,10 +5,11 @@ datasource db { } generator client { - provider = "prisma-client-py" - recursive_type_depth = -1 - interface = "asyncio" - previewFeatures = ["views"] + provider = "prisma-client-py" + recursive_type_depth = -1 + interface = "asyncio" + previewFeatures = ["views", "fullTextSearch"] + partial_type_generator = "backend/data/partial_types.py" } // User model to mirror Auth provider users @@ -663,6 +664,7 @@ view StoreAgent { sub_heading String description String categories String[] + search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) runs Int rating Float versions String[] @@ -746,7 +748,7 @@ model StoreListing { slug String // Allow this agent to be used during onboarding - useForOnboarding Boolean @default(false) + useForOnboarding Boolean @default(false) // The currently active version that should be shown to users activeVersionId String? @unique @@ -797,6 +799,8 @@ model StoreListingVersion { // Old versions can be made unavailable by the author if desired isAvailable Boolean @default(true) + search Unsupported("tsvector")? @default(dbgenerated("''::tsvector")) + // Version workflow state submissionStatus SubmissionStatus @default(DRAFT) submittedAt DateTime?