From 738ba79cffe232f075adb5df6f117a8be64e76f3 Mon Sep 17 00:00:00 2001 From: SwiftyOS Date: Thu, 1 Aug 2024 08:58:17 +0200 Subject: [PATCH] changed all imports to be fully qualified --- rnd/market/market/app.py | 61 +++++++-------- rnd/market/market/db.py | 25 +++--- rnd/market/market/routes/admin.py | 0 rnd/market/market/routes/agents.py | 91 ++++++++++++---------- rnd/market/market/routes/search.py | 30 +++---- rnd/market/market/utils/analytics.py | 8 +- rnd/market/market/utils/extension_types.py | 4 +- rnd/market/market/utils/partial_types.py | 4 +- 8 files changed, 118 insertions(+), 105 deletions(-) create mode 100644 rnd/market/market/routes/admin.py diff --git a/rnd/market/market/app.py b/rnd/market/market/app.py index fd745df174..e7c3216d7b 100644 --- a/rnd/market/market/app.py +++ b/rnd/market/market/app.py @@ -1,61 +1,58 @@ +import contextlib import os -from contextlib import asynccontextmanager +import dotenv +import fastapi +import fastapi.middleware.gzip +import prisma import sentry_sdk -from dotenv import load_dotenv -from fastapi import FastAPI -from fastapi.middleware.gzip import GZipMiddleware -from prisma import Prisma -from sentry_sdk.integrations.asyncio import AsyncioIntegration -from sentry_sdk.integrations.fastapi import FastApiIntegration -from sentry_sdk.integrations.starlette import StarletteIntegration +import sentry_sdk.integrations.asyncio +import sentry_sdk.integrations.fastapi +import sentry_sdk.integrations.starlette -from market.routes import agents, search +import market.routes.agents +import market.routes.search -load_dotenv() +dotenv.load_dotenv() if os.environ.get("SENTRY_DSN"): sentry_sdk.init( dsn=os.environ.get("SENTRY_DSN"), - # Set traces_sample_rate to 1.0 to capture 100% - # of transactions for performance monitoring. traces_sample_rate=1.0, - # Set profiles_sample_rate to 1.0 to profile 100% - # of sampled transactions. - # We recommend adjusting this value in production. profiles_sample_rate=1.0, enable_tracing=True, environment=os.environ.get("RUN_ENV", default="CLOUD").lower(), integrations=[ - StarletteIntegration(transaction_style="url"), - FastApiIntegration(transaction_style="url"), - AsyncioIntegration(), + sentry_sdk.integrations.starlette.StarletteIntegration( + transaction_style="url" + ), + sentry_sdk.integrations.fastapi.FastApiIntegration(transaction_style="url"), + sentry_sdk.integrations.asyncio.AsyncioIntegration(), ], ) -db_client = Prisma(auto_register=True) +db_client = prisma.Prisma(auto_register=True) -@asynccontextmanager -async def lifespan(app: FastAPI): +@contextlib.asynccontextmanager +async def lifespan(app: fastapi.FastAPI): await db_client.connect() yield await db_client.disconnect() -app = FastAPI( - title="Marketplace API", - description=( - "AutoGPT Marketplace API is a service that allows users to share AI agents." - ), +app = fastapi.FastAPI( + title="Marketplace API", + description="AutoGPT Marketplace API is a service that allows users to share AI agents.", summary="Maketplace API", version="0.1", lifespan=lifespan, ) -# Add gzip middleware to compress responses -app.add_middleware(GZipMiddleware, minimum_size=1000) - - -app.include_router(agents.router, prefix="/market/agents", tags=["agents"]) -app.include_router(search.router, prefix="/market/search", tags=["search"]) +app.add_middleware(fastapi.middleware.gzip.GZipMiddleware, minimum_size=1000) +app.include_router( + market.routes.agents.router, prefix="/market/agents", tags=["agents"] +) +app.include_router( + market.routes.search.router, prefix="/market/search", tags=["search"] +) diff --git a/rnd/market/market/db.py b/rnd/market/market/db.py index bea922f02a..ff7a182cc3 100644 --- a/rnd/market/market/db.py +++ b/rnd/market/market/db.py @@ -1,11 +1,12 @@ -from typing import List, Literal +import typing +import fuzzywuzzy +import fuzzywuzzy.fuzz +import prisma.errors import prisma.models import prisma.types -from fuzzywuzzy import fuzz -from prisma.errors import PrismaError -from market.utils.extension_types import AgentsWithRank +import market.utils.extension_types class AgentQueryError(Exception): @@ -23,7 +24,7 @@ async def get_agents( description: str | None = None, description_threshold: int = 60, sort_by: str = "createdAt", - sort_order: Literal["desc"] | Literal["asc"] = "desc", + sort_order: typing.Literal["desc"] | typing.Literal["asc"] = "desc", ): """ Retrieve a list of agents from the database based on the provided filters and pagination parameters. @@ -68,7 +69,7 @@ async def get_agents( skip=skip, take=page_size, ) - except PrismaError as e: + except prisma.errors.PrismaError as e: raise AgentQueryError(f"Database query failed: {str(e)}") # Apply fuzzy search on description if provided @@ -78,7 +79,7 @@ async def get_agents( for agent in agents: if ( agent.description - and fuzz.partial_ratio( + and fuzzywuzzy.fuzz.partial_ratio( description.lower(), agent.description.lower() ) >= description_threshold @@ -135,7 +136,7 @@ async def get_agent_details(agent_id: str, version: int | None = None): return agent - except PrismaError as e: + except prisma.errors.PrismaError as e: raise AgentQueryError(f"Database query failed: {str(e)}") except Exception as e: raise AgentQueryError(f"Unexpected error occurred: {str(e)}") @@ -145,10 +146,10 @@ async def search_db( query: str, page: int = 1, page_size: int = 10, - categories: List[str] | None = None, + categories: typing.List[str] | None = None, description_threshold: int = 60, sort_by: str = "rank", - sort_order: Literal["desc"] | Literal["asc"] = "desc", + sort_order: typing.Literal["desc"] | typing.Literal["asc"] = "desc", ): """Perform a search for agents based on the provided query string. @@ -210,12 +211,12 @@ async def search_db( results = await prisma.client.get_client().query_raw( query=sql_query, - model=AgentsWithRank, + model=market.utils.extension_types.AgentsWithRank, ) return results - except PrismaError as e: + except prisma.errors.PrismaError as e: raise AgentQueryError(f"Database query failed: {str(e)}") except Exception as e: raise AgentQueryError(f"Unexpected error occurred: {str(e)}") diff --git a/rnd/market/market/routes/admin.py b/rnd/market/market/routes/admin.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/rnd/market/market/routes/agents.py b/rnd/market/market/routes/agents.py index b12476ee38..91cb6c3fcb 100644 --- a/rnd/market/market/routes/agents.py +++ b/rnd/market/market/routes/agents.py @@ -1,31 +1,41 @@ import json -from tempfile import NamedTemporaryFile -from typing import Literal, Optional +import tempfile +import typing -from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query -from fastapi.responses import FileResponse -from prisma import Json +import fastapi +import fastapi.responses +import prisma +import market.db import market.model -from market.db import AgentQueryError, get_agent_details, get_agents import market.utils.analytics -router = APIRouter() +router = fastapi.APIRouter() @router.get("/agents", response_model=market.model.AgentListResponse) async def list_agents( - page: int = Query(1, ge=1, description="Page number"), - page_size: int = Query(10, ge=1, le=100, description="Number of items per page"), - name: Optional[str] = Query(None, description="Filter by agent name"), - keyword: Optional[str] = Query(None, description="Filter by keyword"), - category: Optional[str] = Query(None, description="Filter by category"), - description: Optional[str] = Query(None, description="Fuzzy search in description"), - description_threshold: int = Query( + page: int = fastapi.Query(1, ge=1, description="Page number"), + page_size: int = fastapi.Query( + 10, ge=1, le=100, description="Number of items per page" + ), + name: typing.Optional[str] = fastapi.Query( + None, description="Filter by agent name" + ), + keyword: typing.Optional[str] = fastapi.Query( + None, description="Filter by keyword" + ), + category: typing.Optional[str] = fastapi.Query( + None, description="Filter by category" + ), + description: typing.Optional[str] = fastapi.Query( + None, description="Fuzzy search in description" + ), + description_threshold: int = fastapi.Query( 60, ge=0, le=100, description="Fuzzy search threshold" ), - sort_by: str = Query("createdAt", description="Field to sort by"), - sort_order: Literal["asc"] | Literal["desc"] = Query( + sort_by: str = fastapi.Query("createdAt", description="Field to sort by"), + sort_order: typing.Literal["asc", "desc"] = fastapi.Query( "desc", description="Sort order (asc or desc)" ), ): @@ -50,7 +60,7 @@ async def list_agents( HTTPException: If there is a client error (status code 400) or an unexpected error (status code 500). """ try: - result = await get_agents( + result = await market.db.get_agents( page=page, page_size=page_size, name=name, @@ -62,7 +72,6 @@ async def list_agents( sort_order=sort_order, ) - # Convert the result to the response model agents = [ market.model.AgentResponse(**agent.dict()) for agent in result["agents"] ] @@ -75,19 +84,21 @@ async def list_agents( total_pages=result["total_pages"], ) - except AgentQueryError as e: - raise HTTPException(status_code=400, detail=str(e)) + except market.db.AgentQueryError as e: + raise fastapi.HTTPException(status_code=400, detail=str(e)) except Exception as e: - raise HTTPException( + raise fastapi.HTTPException( status_code=500, detail=f"An unexpected error occurred: {e}" ) @router.get("/agents/{agent_id}", response_model=market.model.AgentDetailResponse) async def get_agent_details_endpoint( - background_tasks: BackgroundTasks, - agent_id: str = Path(..., description="The ID of the agent to retrieve"), - version: Optional[int] = Query(None, description="Specific version of the agent"), + background_tasks: fastapi.BackgroundTasks, + agent_id: str = fastapi.Path(..., description="The ID of the agent to retrieve"), + version: typing.Optional[int] = fastapi.Query( + None, description="Specific version of the agent" + ), ): """ Retrieve details of a specific agent. @@ -103,24 +114,26 @@ async def get_agent_details_endpoint( HTTPException: If the agent is not found or an unexpected error occurs. """ try: - agent = await get_agent_details(agent_id, version) + agent = await market.db.get_agent_details(agent_id, version) background_tasks.add_task(market.utils.analytics.track_view, agent_id) return market.model.AgentDetailResponse(**agent.model_dump()) - except AgentQueryError as e: - raise HTTPException(status_code=404, detail=str(e)) + except market.db.AgentQueryError as e: + raise fastapi.HTTPException(status_code=404, detail=str(e)) except Exception as e: - raise HTTPException( + raise fastapi.HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) @router.get("/agents/{agent_id}/download") async def download_agent( - background_tasks: BackgroundTasks, - agent_id: str = Path(..., description="The ID of the agent to download"), - version: Optional[int] = Query(None, description="Specific version of the agent"), -) -> FileResponse: + background_tasks: fastapi.BackgroundTasks, + agent_id: str = fastapi.Path(..., description="The ID of the agent to download"), + version: typing.Optional[int] = fastapi.Query( + None, description="Specific version of the agent" + ), +) -> fastapi.responses.FileResponse: """ Download the agent file by streaming its content. @@ -134,22 +147,20 @@ async def download_agent( Raises: HTTPException: If the agent is not found or an unexpected error occurs. """ - agent = await get_agent_details(agent_id, version) + agent = await market.db.get_agent_details(agent_id, version) - # The agent.graph is already a JSON string, no need to parse and re-stringify - graph_data: Json = agent.graph + graph_data: prisma.Json = agent.graph background_tasks.add_task(market.utils.analytics.track_download, agent_id) - # Prepare the file name for download file_name = f"agent_{agent_id}_v{version or 'latest'}.json" - # Create a temporary file to store the graph data - with NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as tmp_file: tmp_file.write(json.dumps(graph_data)) tmp_file.flush() - # Return the temporary file as a streaming response - return FileResponse( + return fastapi.responses.FileResponse( tmp_file.name, filename=file_name, media_type="application/json" ) diff --git a/rnd/market/market/routes/search.py b/rnd/market/market/routes/search.py index 45c1004124..5964e2c8d7 100644 --- a/rnd/market/market/routes/search.py +++ b/rnd/market/market/routes/search.py @@ -1,27 +1,31 @@ -from typing import List, Literal +import typing -from fastapi import APIRouter, Query +import fastapi -from market.db import search_db -from market.utils.extension_types import AgentsWithRank +import market.db +import market.utils.extension_types -router = APIRouter() +router = fastapi.APIRouter() @router.get("/search") async def search( query: str, - page: int = Query(1, description="The pagination page to start on"), - page_size: int = Query(10, description="The number of items to return per page"), - categories: List[str] = Query(None, description="The categories to filter by"), - description_threshold: int = Query( + page: int = fastapi.Query(1, description="The pagination page to start on"), + page_size: int = fastapi.Query( + 10, description="The number of items to return per page" + ), + categories: typing.List[str] = fastapi.Query( + None, description="The categories to filter by" + ), + description_threshold: int = fastapi.Query( 60, description="The number of characters to return from the description" ), - sort_by: str = Query("rank", description="Sorting by column"), - sort_order: Literal["desc", "asc"] = Query( + sort_by: str = fastapi.Query("rank", description="Sorting by column"), + sort_order: typing.Literal["desc", "asc"] = fastapi.Query( "desc", description="The sort order based on sort_by" ), -) -> List[AgentsWithRank]: +) -> typing.List[market.utils.extension_types.AgentsWithRank]: """searches endpoint for agents Args: @@ -33,7 +37,7 @@ async def search( sort_by (str, optional): Sorting by column. Defaults to "rank". sort_order ('asc' | 'desc', optional): the sort order based on sort_by. Defaults to "desc". """ - return await search_db( + return await market.db.search_db( query=query, page=page, page_size=page_size, diff --git a/rnd/market/market/utils/analytics.py b/rnd/market/market/utils/analytics.py index d97bd59a02..71dcaf0786 100644 --- a/rnd/market/market/utils/analytics.py +++ b/rnd/market/market/utils/analytics.py @@ -1,4 +1,4 @@ -from prisma.models import AnalyticsTracker +import prisma.models async def track_download(agent_id: str): @@ -13,7 +13,7 @@ async def track_download(agent_id: str): Exception: If there is an error tracking the download event. """ try: - await AnalyticsTracker.prisma().upsert( + await prisma.models.AnalyticsTracker.prisma().upsert( where={"agentId": agent_id}, data={ "update": {"downloads": {"increment": 1}}, @@ -36,7 +36,7 @@ async def track_view(agent_id: str): Exception: If there is an error tracking the view event. """ try: - await AnalyticsTracker.prisma().upsert( + await prisma.models.AnalyticsTracker.prisma().upsert( where={"agentId": agent_id}, data={ "update": {"views": {"increment": 1}}, @@ -44,4 +44,4 @@ async def track_view(agent_id: str): }, ) except Exception as e: - raise Exception(f"Error tracking view event: {str(e)}") \ No newline at end of file + raise Exception(f"Error tracking view event: {str(e)}") diff --git a/rnd/market/market/utils/extension_types.py b/rnd/market/market/utils/extension_types.py index f03a2cbf63..d76bbb19f5 100644 --- a/rnd/market/market/utils/extension_types.py +++ b/rnd/market/market/utils/extension_types.py @@ -1,5 +1,5 @@ -from prisma.models import Agents +import prisma.models -class AgentsWithRank(Agents): +class AgentsWithRank(prisma.models.Agents): rank: float diff --git a/rnd/market/market/utils/partial_types.py b/rnd/market/market/utils/partial_types.py index b722b56d39..3ec3bdc2b8 100644 --- a/rnd/market/market/utils/partial_types.py +++ b/rnd/market/market/utils/partial_types.py @@ -1,6 +1,6 @@ -from prisma.models import Agents +import prisma.models -Agents.create_partial( +prisma.models.Agents.create_partial( "AgentOnlyDescriptionNameAuthorIdCategories", include={"name", "author", "id", "categories"}, )