changed all imports to be fully qualified

This commit is contained in:
SwiftyOS
2024-08-01 08:58:17 +02:00
parent 4e16366fda
commit 738ba79cff
8 changed files with 118 additions and 105 deletions

View File

@@ -1,61 +1,58 @@
import contextlib
import os
from contextlib import asynccontextmanager
import dotenv
import fastapi
import fastapi.middleware.gzip
import prisma
import sentry_sdk
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
from prisma import Prisma
from sentry_sdk.integrations.asyncio import AsyncioIntegration
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
import sentry_sdk.integrations.asyncio
import sentry_sdk.integrations.fastapi
import sentry_sdk.integrations.starlette
from market.routes import agents, search
import market.routes.agents
import market.routes.search
load_dotenv()
dotenv.load_dotenv()
if os.environ.get("SENTRY_DSN"):
sentry_sdk.init(
dsn=os.environ.get("SENTRY_DSN"),
# Set traces_sample_rate to 1.0 to capture 100%
# of transactions for performance monitoring.
traces_sample_rate=1.0,
# Set profiles_sample_rate to 1.0 to profile 100%
# of sampled transactions.
# We recommend adjusting this value in production.
profiles_sample_rate=1.0,
enable_tracing=True,
environment=os.environ.get("RUN_ENV", default="CLOUD").lower(),
integrations=[
StarletteIntegration(transaction_style="url"),
FastApiIntegration(transaction_style="url"),
AsyncioIntegration(),
sentry_sdk.integrations.starlette.StarletteIntegration(
transaction_style="url"
),
sentry_sdk.integrations.fastapi.FastApiIntegration(transaction_style="url"),
sentry_sdk.integrations.asyncio.AsyncioIntegration(),
],
)
db_client = Prisma(auto_register=True)
db_client = prisma.Prisma(auto_register=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
@contextlib.asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
await db_client.connect()
yield
await db_client.disconnect()
app = FastAPI(
title="Marketplace API",
description=(
"AutoGPT Marketplace API is a service that allows users to share AI agents."
),
app = fastapi.FastAPI(
title="Marketplace API",
description="AutoGPT Marketplace API is a service that allows users to share AI agents.",
summary="Maketplace API",
version="0.1",
lifespan=lifespan,
)
# Add gzip middleware to compress responses
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.include_router(agents.router, prefix="/market/agents", tags=["agents"])
app.include_router(search.router, prefix="/market/search", tags=["search"])
app.add_middleware(fastapi.middleware.gzip.GZipMiddleware, minimum_size=1000)
app.include_router(
market.routes.agents.router, prefix="/market/agents", tags=["agents"]
)
app.include_router(
market.routes.search.router, prefix="/market/search", tags=["search"]
)

View File

@@ -1,11 +1,12 @@
from typing import List, Literal
import typing
import fuzzywuzzy
import fuzzywuzzy.fuzz
import prisma.errors
import prisma.models
import prisma.types
from fuzzywuzzy import fuzz
from prisma.errors import PrismaError
from market.utils.extension_types import AgentsWithRank
import market.utils.extension_types
class AgentQueryError(Exception):
@@ -23,7 +24,7 @@ async def get_agents(
description: str | None = None,
description_threshold: int = 60,
sort_by: str = "createdAt",
sort_order: Literal["desc"] | Literal["asc"] = "desc",
sort_order: typing.Literal["desc"] | typing.Literal["asc"] = "desc",
):
"""
Retrieve a list of agents from the database based on the provided filters and pagination parameters.
@@ -68,7 +69,7 @@ async def get_agents(
skip=skip,
take=page_size,
)
except PrismaError as e:
except prisma.errors.PrismaError as e:
raise AgentQueryError(f"Database query failed: {str(e)}")
# Apply fuzzy search on description if provided
@@ -78,7 +79,7 @@ async def get_agents(
for agent in agents:
if (
agent.description
and fuzz.partial_ratio(
and fuzzywuzzy.fuzz.partial_ratio(
description.lower(), agent.description.lower()
)
>= description_threshold
@@ -135,7 +136,7 @@ async def get_agent_details(agent_id: str, version: int | None = None):
return agent
except PrismaError as e:
except prisma.errors.PrismaError as e:
raise AgentQueryError(f"Database query failed: {str(e)}")
except Exception as e:
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")
@@ -145,10 +146,10 @@ async def search_db(
query: str,
page: int = 1,
page_size: int = 10,
categories: List[str] | None = None,
categories: typing.List[str] | None = None,
description_threshold: int = 60,
sort_by: str = "rank",
sort_order: Literal["desc"] | Literal["asc"] = "desc",
sort_order: typing.Literal["desc"] | typing.Literal["asc"] = "desc",
):
"""Perform a search for agents based on the provided query string.
@@ -210,12 +211,12 @@ async def search_db(
results = await prisma.client.get_client().query_raw(
query=sql_query,
model=AgentsWithRank,
model=market.utils.extension_types.AgentsWithRank,
)
return results
except PrismaError as e:
except prisma.errors.PrismaError as e:
raise AgentQueryError(f"Database query failed: {str(e)}")
except Exception as e:
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")

View File

View File

@@ -1,31 +1,41 @@
import json
from tempfile import NamedTemporaryFile
from typing import Literal, Optional
import tempfile
import typing
from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
from fastapi.responses import FileResponse
from prisma import Json
import fastapi
import fastapi.responses
import prisma
import market.db
import market.model
from market.db import AgentQueryError, get_agent_details, get_agents
import market.utils.analytics
router = APIRouter()
router = fastapi.APIRouter()
@router.get("/agents", response_model=market.model.AgentListResponse)
async def list_agents(
page: int = Query(1, ge=1, description="Page number"),
page_size: int = Query(10, ge=1, le=100, description="Number of items per page"),
name: Optional[str] = Query(None, description="Filter by agent name"),
keyword: Optional[str] = Query(None, description="Filter by keyword"),
category: Optional[str] = Query(None, description="Filter by category"),
description: Optional[str] = Query(None, description="Fuzzy search in description"),
description_threshold: int = Query(
page: int = fastapi.Query(1, ge=1, description="Page number"),
page_size: int = fastapi.Query(
10, ge=1, le=100, description="Number of items per page"
),
name: typing.Optional[str] = fastapi.Query(
None, description="Filter by agent name"
),
keyword: typing.Optional[str] = fastapi.Query(
None, description="Filter by keyword"
),
category: typing.Optional[str] = fastapi.Query(
None, description="Filter by category"
),
description: typing.Optional[str] = fastapi.Query(
None, description="Fuzzy search in description"
),
description_threshold: int = fastapi.Query(
60, ge=0, le=100, description="Fuzzy search threshold"
),
sort_by: str = Query("createdAt", description="Field to sort by"),
sort_order: Literal["asc"] | Literal["desc"] = Query(
sort_by: str = fastapi.Query("createdAt", description="Field to sort by"),
sort_order: typing.Literal["asc", "desc"] = fastapi.Query(
"desc", description="Sort order (asc or desc)"
),
):
@@ -50,7 +60,7 @@ async def list_agents(
HTTPException: If there is a client error (status code 400) or an unexpected error (status code 500).
"""
try:
result = await get_agents(
result = await market.db.get_agents(
page=page,
page_size=page_size,
name=name,
@@ -62,7 +72,6 @@ async def list_agents(
sort_order=sort_order,
)
# Convert the result to the response model
agents = [
market.model.AgentResponse(**agent.dict()) for agent in result["agents"]
]
@@ -75,19 +84,21 @@ async def list_agents(
total_pages=result["total_pages"],
)
except AgentQueryError as e:
raise HTTPException(status_code=400, detail=str(e))
except market.db.AgentQueryError as e:
raise fastapi.HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(
raise fastapi.HTTPException(
status_code=500, detail=f"An unexpected error occurred: {e}"
)
@router.get("/agents/{agent_id}", response_model=market.model.AgentDetailResponse)
async def get_agent_details_endpoint(
background_tasks: BackgroundTasks,
agent_id: str = Path(..., description="The ID of the agent to retrieve"),
version: Optional[int] = Query(None, description="Specific version of the agent"),
background_tasks: fastapi.BackgroundTasks,
agent_id: str = fastapi.Path(..., description="The ID of the agent to retrieve"),
version: typing.Optional[int] = fastapi.Query(
None, description="Specific version of the agent"
),
):
"""
Retrieve details of a specific agent.
@@ -103,24 +114,26 @@ async def get_agent_details_endpoint(
HTTPException: If the agent is not found or an unexpected error occurs.
"""
try:
agent = await get_agent_details(agent_id, version)
agent = await market.db.get_agent_details(agent_id, version)
background_tasks.add_task(market.utils.analytics.track_view, agent_id)
return market.model.AgentDetailResponse(**agent.model_dump())
except AgentQueryError as e:
raise HTTPException(status_code=404, detail=str(e))
except market.db.AgentQueryError as e:
raise fastapi.HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(
raise fastapi.HTTPException(
status_code=500, detail=f"An unexpected error occurred: {str(e)}"
)
@router.get("/agents/{agent_id}/download")
async def download_agent(
background_tasks: BackgroundTasks,
agent_id: str = Path(..., description="The ID of the agent to download"),
version: Optional[int] = Query(None, description="Specific version of the agent"),
) -> FileResponse:
background_tasks: fastapi.BackgroundTasks,
agent_id: str = fastapi.Path(..., description="The ID of the agent to download"),
version: typing.Optional[int] = fastapi.Query(
None, description="Specific version of the agent"
),
) -> fastapi.responses.FileResponse:
"""
Download the agent file by streaming its content.
@@ -134,22 +147,20 @@ async def download_agent(
Raises:
HTTPException: If the agent is not found or an unexpected error occurs.
"""
agent = await get_agent_details(agent_id, version)
agent = await market.db.get_agent_details(agent_id, version)
# The agent.graph is already a JSON string, no need to parse and re-stringify
graph_data: Json = agent.graph
graph_data: prisma.Json = agent.graph
background_tasks.add_task(market.utils.analytics.track_download, agent_id)
# Prepare the file name for download
file_name = f"agent_{agent_id}_v{version or 'latest'}.json"
# Create a temporary file to store the graph data
with NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp_file:
with tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
) as tmp_file:
tmp_file.write(json.dumps(graph_data))
tmp_file.flush()
# Return the temporary file as a streaming response
return FileResponse(
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)

View File

@@ -1,27 +1,31 @@
from typing import List, Literal
import typing
from fastapi import APIRouter, Query
import fastapi
from market.db import search_db
from market.utils.extension_types import AgentsWithRank
import market.db
import market.utils.extension_types
router = APIRouter()
router = fastapi.APIRouter()
@router.get("/search")
async def search(
query: str,
page: int = Query(1, description="The pagination page to start on"),
page_size: int = Query(10, description="The number of items to return per page"),
categories: List[str] = Query(None, description="The categories to filter by"),
description_threshold: int = Query(
page: int = fastapi.Query(1, description="The pagination page to start on"),
page_size: int = fastapi.Query(
10, description="The number of items to return per page"
),
categories: typing.List[str] = fastapi.Query(
None, description="The categories to filter by"
),
description_threshold: int = fastapi.Query(
60, description="The number of characters to return from the description"
),
sort_by: str = Query("rank", description="Sorting by column"),
sort_order: Literal["desc", "asc"] = Query(
sort_by: str = fastapi.Query("rank", description="Sorting by column"),
sort_order: typing.Literal["desc", "asc"] = fastapi.Query(
"desc", description="The sort order based on sort_by"
),
) -> List[AgentsWithRank]:
) -> typing.List[market.utils.extension_types.AgentsWithRank]:
"""searches endpoint for agents
Args:
@@ -33,7 +37,7 @@ async def search(
sort_by (str, optional): Sorting by column. Defaults to "rank".
sort_order ('asc' | 'desc', optional): the sort order based on sort_by. Defaults to "desc".
"""
return await search_db(
return await market.db.search_db(
query=query,
page=page,
page_size=page_size,

View File

@@ -1,4 +1,4 @@
from prisma.models import AnalyticsTracker
import prisma.models
async def track_download(agent_id: str):
@@ -13,7 +13,7 @@ async def track_download(agent_id: str):
Exception: If there is an error tracking the download event.
"""
try:
await AnalyticsTracker.prisma().upsert(
await prisma.models.AnalyticsTracker.prisma().upsert(
where={"agentId": agent_id},
data={
"update": {"downloads": {"increment": 1}},
@@ -36,7 +36,7 @@ async def track_view(agent_id: str):
Exception: If there is an error tracking the view event.
"""
try:
await AnalyticsTracker.prisma().upsert(
await prisma.models.AnalyticsTracker.prisma().upsert(
where={"agentId": agent_id},
data={
"update": {"views": {"increment": 1}},
@@ -44,4 +44,4 @@ async def track_view(agent_id: str):
},
)
except Exception as e:
raise Exception(f"Error tracking view event: {str(e)}")
raise Exception(f"Error tracking view event: {str(e)}")

View File

@@ -1,5 +1,5 @@
from prisma.models import Agents
import prisma.models
class AgentsWithRank(Agents):
class AgentsWithRank(prisma.models.Agents):
rank: float

View File

@@ -1,6 +1,6 @@
from prisma.models import Agents
import prisma.models
Agents.create_partial(
prisma.models.Agents.create_partial(
"AgentOnlyDescriptionNameAuthorIdCategories",
include={"name", "author", "id", "categories"},
)