mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 23:28:07 -05:00
changed all imports to be fully qualified
This commit is contained in:
@@ -1,61 +1,58 @@
|
||||
import contextlib
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import dotenv
|
||||
import fastapi
|
||||
import fastapi.middleware.gzip
|
||||
import prisma
|
||||
import sentry_sdk
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from prisma import Prisma
|
||||
from sentry_sdk.integrations.asyncio import AsyncioIntegration
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
import sentry_sdk.integrations.asyncio
|
||||
import sentry_sdk.integrations.fastapi
|
||||
import sentry_sdk.integrations.starlette
|
||||
|
||||
from market.routes import agents, search
|
||||
import market.routes.agents
|
||||
import market.routes.search
|
||||
|
||||
load_dotenv()
|
||||
dotenv.load_dotenv()
|
||||
|
||||
if os.environ.get("SENTRY_DSN"):
|
||||
sentry_sdk.init(
|
||||
dsn=os.environ.get("SENTRY_DSN"),
|
||||
# Set traces_sample_rate to 1.0 to capture 100%
|
||||
# of transactions for performance monitoring.
|
||||
traces_sample_rate=1.0,
|
||||
# Set profiles_sample_rate to 1.0 to profile 100%
|
||||
# of sampled transactions.
|
||||
# We recommend adjusting this value in production.
|
||||
profiles_sample_rate=1.0,
|
||||
enable_tracing=True,
|
||||
environment=os.environ.get("RUN_ENV", default="CLOUD").lower(),
|
||||
integrations=[
|
||||
StarletteIntegration(transaction_style="url"),
|
||||
FastApiIntegration(transaction_style="url"),
|
||||
AsyncioIntegration(),
|
||||
sentry_sdk.integrations.starlette.StarletteIntegration(
|
||||
transaction_style="url"
|
||||
),
|
||||
sentry_sdk.integrations.fastapi.FastApiIntegration(transaction_style="url"),
|
||||
sentry_sdk.integrations.asyncio.AsyncioIntegration(),
|
||||
],
|
||||
)
|
||||
|
||||
db_client = Prisma(auto_register=True)
|
||||
db_client = prisma.Prisma(auto_register=True)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: fastapi.FastAPI):
|
||||
await db_client.connect()
|
||||
yield
|
||||
await db_client.disconnect()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Marketplace API",
|
||||
description=(
|
||||
"AutoGPT Marketplace API is a service that allows users to share AI agents."
|
||||
),
|
||||
app = fastapi.FastAPI(
|
||||
title="Marketplace API",
|
||||
description="AutoGPT Marketplace API is a service that allows users to share AI agents.",
|
||||
summary="Maketplace API",
|
||||
version="0.1",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add gzip middleware to compress responses
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
|
||||
app.include_router(agents.router, prefix="/market/agents", tags=["agents"])
|
||||
app.include_router(search.router, prefix="/market/search", tags=["search"])
|
||||
app.add_middleware(fastapi.middleware.gzip.GZipMiddleware, minimum_size=1000)
|
||||
app.include_router(
|
||||
market.routes.agents.router, prefix="/market/agents", tags=["agents"]
|
||||
)
|
||||
app.include_router(
|
||||
market.routes.search.router, prefix="/market/search", tags=["search"]
|
||||
)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from typing import List, Literal
|
||||
import typing
|
||||
|
||||
import fuzzywuzzy
|
||||
import fuzzywuzzy.fuzz
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from fuzzywuzzy import fuzz
|
||||
from prisma.errors import PrismaError
|
||||
|
||||
from market.utils.extension_types import AgentsWithRank
|
||||
import market.utils.extension_types
|
||||
|
||||
|
||||
class AgentQueryError(Exception):
|
||||
@@ -23,7 +24,7 @@ async def get_agents(
|
||||
description: str | None = None,
|
||||
description_threshold: int = 60,
|
||||
sort_by: str = "createdAt",
|
||||
sort_order: Literal["desc"] | Literal["asc"] = "desc",
|
||||
sort_order: typing.Literal["desc"] | typing.Literal["asc"] = "desc",
|
||||
):
|
||||
"""
|
||||
Retrieve a list of agents from the database based on the provided filters and pagination parameters.
|
||||
@@ -68,7 +69,7 @@ async def get_agents(
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
)
|
||||
except PrismaError as e:
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise AgentQueryError(f"Database query failed: {str(e)}")
|
||||
|
||||
# Apply fuzzy search on description if provided
|
||||
@@ -78,7 +79,7 @@ async def get_agents(
|
||||
for agent in agents:
|
||||
if (
|
||||
agent.description
|
||||
and fuzz.partial_ratio(
|
||||
and fuzzywuzzy.fuzz.partial_ratio(
|
||||
description.lower(), agent.description.lower()
|
||||
)
|
||||
>= description_threshold
|
||||
@@ -135,7 +136,7 @@ async def get_agent_details(agent_id: str, version: int | None = None):
|
||||
|
||||
return agent
|
||||
|
||||
except PrismaError as e:
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise AgentQueryError(f"Database query failed: {str(e)}")
|
||||
except Exception as e:
|
||||
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")
|
||||
@@ -145,10 +146,10 @@ async def search_db(
|
||||
query: str,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
categories: List[str] | None = None,
|
||||
categories: typing.List[str] | None = None,
|
||||
description_threshold: int = 60,
|
||||
sort_by: str = "rank",
|
||||
sort_order: Literal["desc"] | Literal["asc"] = "desc",
|
||||
sort_order: typing.Literal["desc"] | typing.Literal["asc"] = "desc",
|
||||
):
|
||||
"""Perform a search for agents based on the provided query string.
|
||||
|
||||
@@ -210,12 +211,12 @@ async def search_db(
|
||||
|
||||
results = await prisma.client.get_client().query_raw(
|
||||
query=sql_query,
|
||||
model=AgentsWithRank,
|
||||
model=market.utils.extension_types.AgentsWithRank,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
except PrismaError as e:
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise AgentQueryError(f"Database query failed: {str(e)}")
|
||||
except Exception as e:
|
||||
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")
|
||||
|
||||
0
rnd/market/market/routes/admin.py
Normal file
0
rnd/market/market/routes/admin.py
Normal file
@@ -1,31 +1,41 @@
|
||||
import json
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Literal, Optional
|
||||
import tempfile
|
||||
import typing
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Path, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from prisma import Json
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import prisma
|
||||
|
||||
import market.db
|
||||
import market.model
|
||||
from market.db import AgentQueryError, get_agent_details, get_agents
|
||||
import market.utils.analytics
|
||||
|
||||
router = APIRouter()
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
|
||||
@router.get("/agents", response_model=market.model.AgentListResponse)
|
||||
async def list_agents(
|
||||
page: int = Query(1, ge=1, description="Page number"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Number of items per page"),
|
||||
name: Optional[str] = Query(None, description="Filter by agent name"),
|
||||
keyword: Optional[str] = Query(None, description="Filter by keyword"),
|
||||
category: Optional[str] = Query(None, description="Filter by category"),
|
||||
description: Optional[str] = Query(None, description="Fuzzy search in description"),
|
||||
description_threshold: int = Query(
|
||||
page: int = fastapi.Query(1, ge=1, description="Page number"),
|
||||
page_size: int = fastapi.Query(
|
||||
10, ge=1, le=100, description="Number of items per page"
|
||||
),
|
||||
name: typing.Optional[str] = fastapi.Query(
|
||||
None, description="Filter by agent name"
|
||||
),
|
||||
keyword: typing.Optional[str] = fastapi.Query(
|
||||
None, description="Filter by keyword"
|
||||
),
|
||||
category: typing.Optional[str] = fastapi.Query(
|
||||
None, description="Filter by category"
|
||||
),
|
||||
description: typing.Optional[str] = fastapi.Query(
|
||||
None, description="Fuzzy search in description"
|
||||
),
|
||||
description_threshold: int = fastapi.Query(
|
||||
60, ge=0, le=100, description="Fuzzy search threshold"
|
||||
),
|
||||
sort_by: str = Query("createdAt", description="Field to sort by"),
|
||||
sort_order: Literal["asc"] | Literal["desc"] = Query(
|
||||
sort_by: str = fastapi.Query("createdAt", description="Field to sort by"),
|
||||
sort_order: typing.Literal["asc", "desc"] = fastapi.Query(
|
||||
"desc", description="Sort order (asc or desc)"
|
||||
),
|
||||
):
|
||||
@@ -50,7 +60,7 @@ async def list_agents(
|
||||
HTTPException: If there is a client error (status code 400) or an unexpected error (status code 500).
|
||||
"""
|
||||
try:
|
||||
result = await get_agents(
|
||||
result = await market.db.get_agents(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
name=name,
|
||||
@@ -62,7 +72,6 @@ async def list_agents(
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
# Convert the result to the response model
|
||||
agents = [
|
||||
market.model.AgentResponse(**agent.dict()) for agent in result["agents"]
|
||||
]
|
||||
@@ -75,19 +84,21 @@ async def list_agents(
|
||||
total_pages=result["total_pages"],
|
||||
)
|
||||
|
||||
except AgentQueryError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except market.db.AgentQueryError as e:
|
||||
raise fastapi.HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail=f"An unexpected error occurred: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/agents/{agent_id}", response_model=market.model.AgentDetailResponse)
|
||||
async def get_agent_details_endpoint(
|
||||
background_tasks: BackgroundTasks,
|
||||
agent_id: str = Path(..., description="The ID of the agent to retrieve"),
|
||||
version: Optional[int] = Query(None, description="Specific version of the agent"),
|
||||
background_tasks: fastapi.BackgroundTasks,
|
||||
agent_id: str = fastapi.Path(..., description="The ID of the agent to retrieve"),
|
||||
version: typing.Optional[int] = fastapi.Query(
|
||||
None, description="Specific version of the agent"
|
||||
),
|
||||
):
|
||||
"""
|
||||
Retrieve details of a specific agent.
|
||||
@@ -103,24 +114,26 @@ async def get_agent_details_endpoint(
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
try:
|
||||
agent = await get_agent_details(agent_id, version)
|
||||
agent = await market.db.get_agent_details(agent_id, version)
|
||||
background_tasks.add_task(market.utils.analytics.track_view, agent_id)
|
||||
return market.model.AgentDetailResponse(**agent.model_dump())
|
||||
|
||||
except AgentQueryError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except market.db.AgentQueryError as e:
|
||||
raise fastapi.HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail=f"An unexpected error occurred: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/agents/{agent_id}/download")
|
||||
async def download_agent(
|
||||
background_tasks: BackgroundTasks,
|
||||
agent_id: str = Path(..., description="The ID of the agent to download"),
|
||||
version: Optional[int] = Query(None, description="Specific version of the agent"),
|
||||
) -> FileResponse:
|
||||
background_tasks: fastapi.BackgroundTasks,
|
||||
agent_id: str = fastapi.Path(..., description="The ID of the agent to download"),
|
||||
version: typing.Optional[int] = fastapi.Query(
|
||||
None, description="Specific version of the agent"
|
||||
),
|
||||
) -> fastapi.responses.FileResponse:
|
||||
"""
|
||||
Download the agent file by streaming its content.
|
||||
|
||||
@@ -134,22 +147,20 @@ async def download_agent(
|
||||
Raises:
|
||||
HTTPException: If the agent is not found or an unexpected error occurs.
|
||||
"""
|
||||
agent = await get_agent_details(agent_id, version)
|
||||
agent = await market.db.get_agent_details(agent_id, version)
|
||||
|
||||
# The agent.graph is already a JSON string, no need to parse and re-stringify
|
||||
graph_data: Json = agent.graph
|
||||
graph_data: prisma.Json = agent.graph
|
||||
|
||||
background_tasks.add_task(market.utils.analytics.track_download, agent_id)
|
||||
|
||||
# Prepare the file name for download
|
||||
file_name = f"agent_{agent_id}_v{version or 'latest'}.json"
|
||||
|
||||
# Create a temporary file to store the graph data
|
||||
with NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp_file:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as tmp_file:
|
||||
tmp_file.write(json.dumps(graph_data))
|
||||
tmp_file.flush()
|
||||
|
||||
# Return the temporary file as a streaming response
|
||||
return FileResponse(
|
||||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
@@ -1,27 +1,31 @@
|
||||
from typing import List, Literal
|
||||
import typing
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
import fastapi
|
||||
|
||||
from market.db import search_db
|
||||
from market.utils.extension_types import AgentsWithRank
|
||||
import market.db
|
||||
import market.utils.extension_types
|
||||
|
||||
router = APIRouter()
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search(
|
||||
query: str,
|
||||
page: int = Query(1, description="The pagination page to start on"),
|
||||
page_size: int = Query(10, description="The number of items to return per page"),
|
||||
categories: List[str] = Query(None, description="The categories to filter by"),
|
||||
description_threshold: int = Query(
|
||||
page: int = fastapi.Query(1, description="The pagination page to start on"),
|
||||
page_size: int = fastapi.Query(
|
||||
10, description="The number of items to return per page"
|
||||
),
|
||||
categories: typing.List[str] = fastapi.Query(
|
||||
None, description="The categories to filter by"
|
||||
),
|
||||
description_threshold: int = fastapi.Query(
|
||||
60, description="The number of characters to return from the description"
|
||||
),
|
||||
sort_by: str = Query("rank", description="Sorting by column"),
|
||||
sort_order: Literal["desc", "asc"] = Query(
|
||||
sort_by: str = fastapi.Query("rank", description="Sorting by column"),
|
||||
sort_order: typing.Literal["desc", "asc"] = fastapi.Query(
|
||||
"desc", description="The sort order based on sort_by"
|
||||
),
|
||||
) -> List[AgentsWithRank]:
|
||||
) -> typing.List[market.utils.extension_types.AgentsWithRank]:
|
||||
"""searches endpoint for agents
|
||||
|
||||
Args:
|
||||
@@ -33,7 +37,7 @@ async def search(
|
||||
sort_by (str, optional): Sorting by column. Defaults to "rank".
|
||||
sort_order ('asc' | 'desc', optional): the sort order based on sort_by. Defaults to "desc".
|
||||
"""
|
||||
return await search_db(
|
||||
return await market.db.search_db(
|
||||
query=query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from prisma.models import AnalyticsTracker
|
||||
import prisma.models
|
||||
|
||||
|
||||
async def track_download(agent_id: str):
|
||||
@@ -13,7 +13,7 @@ async def track_download(agent_id: str):
|
||||
Exception: If there is an error tracking the download event.
|
||||
"""
|
||||
try:
|
||||
await AnalyticsTracker.prisma().upsert(
|
||||
await prisma.models.AnalyticsTracker.prisma().upsert(
|
||||
where={"agentId": agent_id},
|
||||
data={
|
||||
"update": {"downloads": {"increment": 1}},
|
||||
@@ -36,7 +36,7 @@ async def track_view(agent_id: str):
|
||||
Exception: If there is an error tracking the view event.
|
||||
"""
|
||||
try:
|
||||
await AnalyticsTracker.prisma().upsert(
|
||||
await prisma.models.AnalyticsTracker.prisma().upsert(
|
||||
where={"agentId": agent_id},
|
||||
data={
|
||||
"update": {"views": {"increment": 1}},
|
||||
@@ -44,4 +44,4 @@ async def track_view(agent_id: str):
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Error tracking view event: {str(e)}")
|
||||
raise Exception(f"Error tracking view event: {str(e)}")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from prisma.models import Agents
|
||||
import prisma.models
|
||||
|
||||
|
||||
class AgentsWithRank(Agents):
|
||||
class AgentsWithRank(prisma.models.Agents):
|
||||
rank: float
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from prisma.models import Agents
|
||||
import prisma.models
|
||||
|
||||
Agents.create_partial(
|
||||
prisma.models.Agents.create_partial(
|
||||
"AgentOnlyDescriptionNameAuthorIdCategories",
|
||||
include={"name", "author", "id", "categories"},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user