Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-02-02 10:55:14 -05:00

Compare commits: test/verif...swiftyos/s (9 commits)

| SHA1 |
|---|
| 2a63093b14 |
| 70f884367b |
| 10da2d3b2a |
| 55e206fcd7 |
| 74db958932 |
| 72d539f777 |
| 2142caf2ca |
| 8d0bbc5ffe |
| 3d4aca9fcc |
@@ -1,105 +0,0 @@
import logging
from typing import Annotated, Literal

from autogpt_libs.supabase_integration_credentials_store import (
    SupabaseIntegrationCredentialsStore,
)
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel
from supabase import Client

from autogpt_server.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from autogpt_server.util.settings import Settings

from .utils import get_supabase, get_user_id

logger = logging.getLogger(__name__)
settings = Settings()
integrations_api_router = APIRouter()


def get_store(supabase: Client = Depends(get_supabase)):
    return SupabaseIntegrationCredentialsStore(supabase)


class LoginResponse(BaseModel):
    login_url: str


@integrations_api_router.get("/{provider}/login")
async def login(
    provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
    user_id: Annotated[str, Depends(get_user_id)],
    request: Request,
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
    scopes: Annotated[
        str, Query(title="Comma-separated list of authorization scopes")
    ] = "",
) -> LoginResponse:
    handler = _get_provider_oauth_handler(request, provider)

    # Generate and store a secure random state token
    state = await store.store_state_token(user_id, provider)

    requested_scopes = scopes.split(",") if scopes else []
    login_url = handler.get_login_url(requested_scopes, state)

    return LoginResponse(login_url=login_url)


class CredentialsMetaResponse(BaseModel):
    credentials_id: str
    credentials_type: Literal["oauth2", "api_key"]


@integrations_api_router.post("/{provider}/callback")
async def callback(
    provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
    code: Annotated[str, Body(title="Authorization code acquired by user login")],
    state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
    user_id: Annotated[str, Depends(get_user_id)],
    request: Request,
) -> CredentialsMetaResponse:
    handler = _get_provider_oauth_handler(request, provider)

    # Verify the state token
    if not await store.verify_state_token(user_id, state_token, provider):
        raise HTTPException(status_code=400, detail="Invalid or expired state token")

    try:
        credentials = handler.exchange_code_for_tokens(code)
    except Exception as e:
        logger.warning(f"Code->Token exchange failed for provider {provider}: {e}")
        raise HTTPException(status_code=400, detail=str(e))

    store.add_creds(user_id, credentials)
    return CredentialsMetaResponse(
        credentials_id=credentials.id,
        credentials_type=credentials.type,
    )


# -------- UTILITIES --------- #


def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler:
    if provider_name not in HANDLERS_BY_NAME:
        raise HTTPException(
            status_code=404, detail=f"Unknown provider '{provider_name}'"
        )

    client_id = getattr(settings.secrets, f"{provider_name}_client_id")
    client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
    if not (client_id and client_secret):
        raise HTTPException(
            status_code=501,
            detail=f"Integration with provider '{provider_name}' is not configured",
        )

    handler_class = HANDLERS_BY_NAME[provider_name]
    return handler_class(
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=str(req.url_for("callback", provider=provider_name)),
    )
rnd/autogpt_server/autogpt_server/server/new_rest_app.py (new file, 92 lines)
@@ -0,0 +1,92 @@
import contextlib

import fastapi
import fastapi.responses
import fastapi.middleware.cors

import autogpt_libs.auth.middleware
import autogpt_server.data.block
import autogpt_server.data.db
import autogpt_server.data.graph
import autogpt_server.data.user
import autogpt_server.server.routes
import autogpt_server.server.utils
import autogpt_server.util.settings

from autogpt_server.data import user as user_db

settings = autogpt_server.util.settings.Settings()


@contextlib.asynccontextmanager
async def app_lifespan(app: fastapi.FastAPI):
    await autogpt_server.data.db.connect()
    await autogpt_server.data.block.initialize_blocks()

    if await user_db.create_default_user(settings.config.enable_auth):
        await autogpt_server.data.graph.import_packaged_templates()
    yield

    await autogpt_server.data.db.disconnect()


app = fastapi.FastAPI(
    title="AutoGPT Agent Server",
    description=(
        "This server is used to execute agents that are created by the "
        "AutoGPT system."
    ),
    summary="AutoGPT Agent Server",
    version="0.1",
    lifespan=app_lifespan,
)

api_router = fastapi.APIRouter(prefix="/api/v1")
api_router.dependencies.append(
    fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)
)

api_router.include_router(autogpt_server.server.routes.root_router)
api_router.include_router(autogpt_server.server.routes.agents_router, tags=["agents"])
api_router.include_router(autogpt_server.server.routes.blocks_router, tags=["blocks"])
api_router.include_router(
    autogpt_server.server.routes.integrations_router, prefix="/integrations"
)

app.include_router(api_router)


@app.exception_handler(500)
def handle_internal_http_error(request: fastapi.Request, exc: Exception):
    return fastapi.responses.JSONResponse(
        status_code=500,
        content={"message": str(exc)},
    )


@app.exception_handler(fastapi.exceptions.RequestValidationError)
async def validation_exception_handler(
    request: fastapi.Request, exc: fastapi.exceptions.RequestValidationError
):
    errors = []
    for err in exc.errors():
        error = {
            "field": ".".join(err["loc"][1:]),  # Skipping 'body' or 'query' etc.
            "message": err["msg"],
            "type": err["type"]
        }
        errors.append(error)

    return fastapi.responses.JSONResponse(
        status_code=422,
        content={
            "status": "fail",
            "errors": errors
        },
    )


app.add_middleware(
    fastapi.middleware.cors.CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
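For manual testing, the app above can be served with Uvicorn. The following is a hedged sketch and not part of the diff: it assumes `uvicorn` is installed and that the module path matches the file location in the header; the host and port are arbitrary.

```python
# Hypothetical local run script (not part of the diff).
import uvicorn

from autogpt_server.server.new_rest_app import app

if __name__ == "__main__":
    # Serve the FastAPI app on an arbitrary local port.
    uvicorn.run(app, host="127.0.0.1", port=8000)
```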
rnd/autogpt_server/autogpt_server/server/routes/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from .agents import router as agents_router
from .root import router as root_router
from .blocks import router as blocks_router
from .integrations import integrations_api_router as integrations_router


__all__ = [
    "agents_router",
    "root_router",
    "blocks_router",
    "integrations_router",
]
rnd/autogpt_server/autogpt_server/server/routes/agents.py (new file, 335 lines)
@@ -0,0 +1,335 @@
import typing

import fastapi

import autogpt_server.data.graph
import autogpt_server.data.queue
import autogpt_server.data.execution
import autogpt_server.executor
import autogpt_server.server.model
import autogpt_server.server.utils
import autogpt_server.util.service
import autogpt_server.util.settings

router = fastapi.APIRouter()


def execution_manager_client() -> autogpt_server.executor.ExecutionManager:
    return autogpt_server.util.service.get_service_client(
        autogpt_server.executor.ExecutionManager,
        autogpt_server.util.settings.Config().execution_manager_port,
    )


def execution_scheduler_client() -> autogpt_server.executor.ExecutionScheduler:
    return autogpt_server.util.service.get_service_client(
        autogpt_server.executor.ExecutionScheduler,
        autogpt_server.util.settings.Config().execution_scheduler_port,
    )


@router.get("/graphs")
async def get_graphs(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ]
) -> list[autogpt_server.data.graph.GraphMeta]:
    return await autogpt_server.data.graph.get_graphs_meta(
        filter_by="active", user_id=user_id
    )


@router.post("/graphs")
async def create_new_graph(
    create_graph: autogpt_server.server.model.CreateGraph,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
) -> autogpt_server.data.graph.Graph:
    if create_graph.graph:
        graph = create_graph.graph
    elif create_graph.template_id:
        # Create a new graph from a template
        graph = await autogpt_server.data.graph.get_graph(
            create_graph.template_id,
            create_graph.template_version,
            template=True,
            user_id=user_id,
        )
        if not graph:
            raise fastapi.HTTPException(
                400, detail=f"Template #{create_graph.template_id} not found"
            )
        graph.version = 1
    else:
        raise fastapi.HTTPException(
            status_code=400, detail="Either graph or template_id must be provided."
        )

    graph.is_template = False
    graph.is_active = True
    graph.reassign_ids(reassign_graph_id=True)

    return await autogpt_server.data.graph.create_graph(graph, user_id=user_id)


@router.get("/graphs/{graph_id}")
async def get_graph(
    graph_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
    version: typing.Optional[int] = None,
) -> autogpt_server.data.graph.Graph:
    graph = await autogpt_server.data.graph.get_graph(
        graph_id, version, user_id=user_id
    )
    if graph and graph.id != graph_id:
        raise fastapi.HTTPException(400, detail="Graph ID does not match ID in URI")
    if not graph:
        raise fastapi.HTTPException(
            status_code=404, detail=f"Graph #{graph_id} not found."
        )
    return graph


@router.put("/graphs/{graph_id}")
async def update_graph(
    graph_id: str,
    graph: autogpt_server.data.graph.Graph,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
) -> autogpt_server.data.graph.Graph:
    # Sanity check
    if graph.id and graph.id != graph_id:
        raise fastapi.HTTPException(400, detail="Graph ID does not match ID in URI")

    # Determine new version
    existing_versions = await autogpt_server.data.graph.get_graph_all_versions(
        graph_id, user_id=user_id
    )
    if not existing_versions:
        raise fastapi.HTTPException(404, detail=f"Graph #{graph_id} not found")
    latest_version_number = max(g.version for g in existing_versions)
    graph.version = latest_version_number + 1

    latest_version_graph = next(
        v for v in existing_versions if v.version == latest_version_number
    )
    if latest_version_graph.is_template != graph.is_template:
        raise fastapi.HTTPException(
            400, detail="Changing is_template on an existing graph is forbidden"
        )
    graph.is_active = not graph.is_template
    graph.reassign_ids()

    new_graph_version = await autogpt_server.data.graph.create_graph(
        graph, user_id=user_id
    )

    if new_graph_version.is_active:
        # Ensure new version is the only active version
        await autogpt_server.data.graph.set_graph_active_version(
            graph_id=graph_id, version=new_graph_version.version, user_id=user_id
        )

    return new_graph_version


@router.get("/graphs/{graph_id}/versions")
async def get_graph_all_versions(
    graph_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
) -> list[autogpt_server.data.graph.Graph]:
    graphs = await autogpt_server.data.graph.get_graph_all_versions(
        graph_id, user_id=user_id
    )
    if not graphs:
        raise fastapi.HTTPException(
            status_code=404, detail=f"Graph #{graph_id} not found."
        )
    return graphs


@router.get("/graphs/{graph_id}/versions/{version}")
async def get_graph_version(
    graph_id: str,
    version: int,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
) -> autogpt_server.data.graph.Graph:
    graph = await autogpt_server.data.graph.get_graph(
        graph_id, version, user_id=user_id
    )
    if not graph:
        raise fastapi.HTTPException(
            status_code=404, detail=f"Graph #{graph_id} not found."
        )
    return graph


@router.put("/graphs/{graph_id}/versions/active")
async def set_graph_active_version(
    graph_id: str,
    request_body: autogpt_server.server.model.SetGraphActiveVersion,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
):
    new_active_version = request_body.active_graph_version
    if not await autogpt_server.data.graph.get_graph(
        graph_id, new_active_version, user_id=user_id
    ):
        raise fastapi.HTTPException(
            status_code=404, detail=f"Graph #{graph_id} v{new_active_version} not found"
        )
    await autogpt_server.data.graph.set_graph_active_version(
        graph_id=graph_id,
        version=request_body.active_graph_version,
        user_id=user_id,
    )


@router.post("/graphs/{graph_id}/execute")
async def execute_graph(
    graph_id: str,
    node_input: dict[typing.Any, typing.Any],
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
    execution_manager: typing.Annotated[
        autogpt_server.executor.ExecutionManager,
        fastapi.Depends(execution_manager_client),
    ],
) -> dict[str, typing.Any]:  # FIXME: add proper return type
    try:
        graph_exec = execution_manager.add_execution(
            graph_id, node_input, user_id=user_id
        )
        return {"id": graph_exec["graph_exec_id"]}
    except Exception as e:
        msg = e.__str__().encode().decode("unicode_escape")
        raise fastapi.HTTPException(status_code=400, detail=msg)


@router.get("/graphs/{graph_id}/executions")
async def list_graph_runs(
    graph_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
    graph_version: typing.Optional[int] = None,
) -> list[str]:
    graph = await autogpt_server.data.graph.get_graph(
        graph_id, graph_version, user_id=user_id
    )
    if not graph:
        rev = "" if graph_version is None else f" v{graph_version}"
        raise fastapi.HTTPException(
            status_code=404, detail=f"Agent #{graph_id}{rev} not found."
        )

    return await autogpt_server.data.execution.list_executions(graph_id, graph_version)


@router.get("/graphs/{graph_id}/executions/{graph_exec_id}")
async def get_graph_run_node_execution_results(
    graph_id: str,
    graph_exec_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
) -> list[autogpt_server.data.execution.ExecutionResult]:
    graph = await autogpt_server.data.graph.get_graph(graph_id, user_id=user_id)
    if not graph:
        raise fastapi.HTTPException(
            status_code=404, detail=f"Graph #{graph_id} not found."
        )

    return await autogpt_server.data.execution.get_execution_results(graph_exec_id)


@router.post("/graphs/{graph_id}/executions/{graph_exec_id}/stop")
async def stop_graph_run(
    graph_exec_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
    execution_manager: typing.Annotated[
        autogpt_server.executor.ExecutionManager,
        fastapi.Depends(execution_manager_client),
    ],
) -> list[autogpt_server.data.execution.ExecutionResult]:
    if not await autogpt_server.data.execution.get_graph_execution(
        graph_exec_id, user_id
    ):
        raise fastapi.HTTPException(
            404, detail=f"Agent execution #{graph_exec_id} not found"
        )

    execution_manager.cancel_execution(graph_exec_id)

    # Retrieve & return canceled graph execution in its final state
    return await autogpt_server.data.execution.get_execution_results(graph_exec_id)


@router.post("/graphs/{graph_id}/schedules")
async def create_schedule(
    graph_id: str,
    cron: str,
    input_data: dict[typing.Any, typing.Any],
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
    execution_scheduler: typing.Annotated[
        autogpt_server.executor.ExecutionScheduler,
        fastapi.Depends(execution_scheduler_client),
    ],
) -> dict[typing.Any, typing.Any]:
    graph = await autogpt_server.data.graph.get_graph(graph_id, user_id=user_id)
    if not graph:
        raise fastapi.HTTPException(
            status_code=404, detail=f"Graph #{graph_id} not found."
        )
    return {
        "id": execution_scheduler.add_execution_schedule(
            graph_id, graph.version, cron, input_data, user_id=user_id
        )
    }


@router.get("/graphs/{graph_id}/schedules")
async def get_execution_schedules(
    graph_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
    execution_scheduler: typing.Annotated[
        autogpt_server.executor.ExecutionScheduler,
        fastapi.Depends(execution_scheduler_client),
    ],
) -> dict[str, str]:
    return execution_scheduler.get_execution_schedules(graph_id, user_id)


@router.put("/graphs/schedules/{schedule_id}")
async def update_schedule(
    schedule_id: str,
    input_data: dict[typing.Any, typing.Any],
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
    execution_scheduler: typing.Annotated[
        autogpt_server.executor.ExecutionScheduler,
        fastapi.Depends(execution_scheduler_client),
    ],
) -> dict[typing.Any, typing.Any]:
    is_enabled = input_data.get("is_enabled", False)
    execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id)
    return {"id": schedule_id}
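For orientation, here is a client-side sketch of the agent routes defined above, using `requests`. It is an assumption-laden example rather than documented usage: the base URL and port are made up, and the bearer-token header is only a guess based on the `auth_middleware` dependency attached to `/api/v1`.

```python
# Hypothetical HTTP client sketch (not part of the diff).
import requests

BASE = "http://127.0.0.1:8000/api/v1"  # assumed host/port
HEADERS = {"Authorization": "Bearer <access-token>"}  # assumed auth scheme

# List the caller's active graphs.
graphs = requests.get(f"{BASE}/graphs", headers=HEADERS).json()

if graphs:
    graph_id = graphs[0]["id"]
    # The execute route takes the node input dict as the request body.
    run = requests.post(
        f"{BASE}/graphs/{graph_id}/execute",
        json={"text": "hello"},
        headers=HEADERS,
    ).json()
    print("started execution:", run["id"])

    # Executions for this graph can then be listed by graph ID.
    print(requests.get(f"{BASE}/graphs/{graph_id}/executions", headers=HEADERS).json())
```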
rnd/autogpt_server/autogpt_server/server/routes/agents_tests.py (new file, 237 lines)
@@ -0,0 +1,237 @@
import autogpt_server.server
import autogpt_server.server.routes
import fastapi
import fastapi.testclient
import pytest
import unittest.mock

from autogpt_server.server.new_rest_app import app
import autogpt_server.server.utils as utils
import autogpt_server.executor as executor
import autogpt_server.data.graph


client = fastapi.testclient.TestClient(app)


async def override_get_user_id():
    return "test_user_id"


async def override_execution_manager_client():
    class MockExecutionManager:

        def add_execution(self, graph_id, node_input, user_id):
            return {"graph_exec_id": "exec1"}

        def cancel_execution(self, graph_exec_id):
            return {"graph_exec_id": "exec1"}

        def get_execution_results(self, graph_exec_id):
            return [{"node_id": "node1", "output": "result"}]

    return MockExecutionManager()


async def override_scheduler_client():
    return executor.ExecutionScheduler()


app.dependency_overrides[utils.get_user_id] = override_get_user_id
app.dependency_overrides[
    autogpt_server.server.routes.agents.execution_manager_client
] = override_execution_manager_client
app.dependency_overrides[
    autogpt_server.server.routes.agents.execution_scheduler_client
] = override_scheduler_client


@pytest.mark.asyncio
async def test_get_graphs():
    mock_get_graphs_meta = unittest.mock.AsyncMock()
    mock_get_graphs_meta.return_value = [
        autogpt_server.data.graph.GraphMeta(
            id="graph1",
            version=1,
            name="Test Graph 1",
            description="Test Graph 1 Description",
            is_active=True,
            is_template=False,
        ),
        autogpt_server.data.graph.GraphMeta(
            id="graph2",
            version=1,
            name="Test Graph 2",
            description="Test Graph 2 Description",
            is_active=True,
            is_template=False,
        ),
        autogpt_server.data.graph.GraphMeta(
            id="graph3",
            version=2,
            name="Test Graph 3",
            description="Test Graph 3 Description",
            is_active=False,
            is_template=True,
        ),
    ]

    with unittest.mock.patch(
        'autogpt_server.data.graph.get_graphs_meta', mock_get_graphs_meta
    ):
        response = client.get("/api/v1/graphs")
        assert response.status_code == 200
        assert response.json() == [
            {
                "id": "graph1",
                "version": 1,
                "name": "Test Graph 1",
                "description": "Test Graph 1 Description",
                "is_active": True,
                "is_template": False,
            },
            {
                "id": "graph2",
                "version": 1,
                "name": "Test Graph 2",
                "description": "Test Graph 2 Description",
                "is_active": True,
                "is_template": False,
            },
            {
                "id": "graph3",
                "version": 2,
                "name": "Test Graph 3",
                "description": "Test Graph 3 Description",
                "is_active": False,
                "is_template": True,
            },
        ]


@pytest.mark.asyncio
async def test_create_new_graph():
    sample_graph = autogpt_server.data.graph.Graph(
        version=1,
        is_active=True,
        is_template=False,
        name="New Graph",
        description="New Graph Description",
        nodes=[],
        links=[],
    )

    mock_create_graph = unittest.mock.AsyncMock()
    mock_create_graph.return_value = sample_graph

    with unittest.mock.patch('autogpt_server.data.graph.create_graph', mock_create_graph):
        response = client.post("/api/v1/graphs", json={"graph": sample_graph.model_dump()})
        assert response.status_code == 200
        assert response.json() == sample_graph.model_dump()


@pytest.mark.asyncio
async def test_create_new_graph_invalid_request():
    # Test case for missing both graph and template_id
    response = client.post("/api/v1/graphs", json={})
    assert response.status_code == 400
    assert response.json() == {"detail": "Either graph or template_id must be provided."}

    # Test case for non-existent template
    mock_get_graph = unittest.mock.AsyncMock(return_value=None)
    with unittest.mock.patch('autogpt_server.data.graph.get_graph', mock_get_graph):
        response = client.post("/api/v1/graphs", json={"template_id": "non_existent_template"})
        assert response.status_code == 400
        assert response.json() == {"detail": "Template #non_existent_template not found"}

    # Test case for invalid graph structure
    invalid_graph = {
        "version": 1,
        "is_active": True,
        "is_template": False,
        "name": "Invalid Graph",
        "description": "Invalid Graph Description",
        "nodes": "not a list",  # This should be a list
        "links": []
    }
    response = client.post("/api/v1/graphs", json={"graph": invalid_graph})
    assert response.status_code == 422  # Unprocessable Entity
    assert "status" in response.json()
    assert "errors" in response.json()


@pytest.mark.asyncio
async def test_get_graph_details():
    sample_graph = autogpt_server.data.graph.Graph(
        id="graph1",
        version=1,
        is_active=True,
        is_template=False,
        name="Test Graph",
        description="Test Graph Description",
        nodes=[],
        links=[]
    )

    mock_get_graph = unittest.mock.AsyncMock(return_value=sample_graph)

    with unittest.mock.patch('autogpt_server.data.graph.get_graph', mock_get_graph):
        response = client.get("/api/v1/graphs/graph1")
        assert response.status_code == 200
        assert response.json() == sample_graph.model_dump()

    # Test case for non-existent graph
    mock_get_graph.return_value = None
    with unittest.mock.patch('autogpt_server.data.graph.get_graph', mock_get_graph):
        response = client.get("/api/v1/graphs/non_existent_graph")
        assert response.status_code == 404
        assert response.json() == {"detail": "Graph #non_existent_graph not found."}


@pytest.mark.asyncio
async def test_update_graph():
    sample_graph = autogpt_server.data.graph.Graph(
        id="graph1",
        version=1,
        is_active=True,
        is_template=False,
        name="Test Graph",
        description="Test Graph Description",
        nodes=[],
        links=[]
    )

    mock_get_graph_all_versions = unittest.mock.AsyncMock(return_value=[sample_graph])
    mock_create_graph = unittest.mock.AsyncMock(return_value=sample_graph)
    mock_set_graph_active_version = unittest.mock.AsyncMock()

    with unittest.mock.patch(
        'autogpt_server.data.graph.get_graph_all_versions', mock_get_graph_all_versions
    ), unittest.mock.patch(
        'autogpt_server.data.graph.create_graph', mock_create_graph
    ), unittest.mock.patch(
        'autogpt_server.data.graph.set_graph_active_version', mock_set_graph_active_version
    ):
        updated_graph = sample_graph.model_copy(update={"name": "Updated Graph", "version": 2})
        response = client.put("/api/v1/graphs/graph1", json=updated_graph.model_dump())

        assert response.status_code == 200
        assert response.json()['description'] == updated_graph.model_dump()['description']

    non_existent_graph = autogpt_server.data.graph.Graph(
        id="non_existent_graph",
        version=1,
        is_active=True,
        is_template=False,
        name="Test Graph",
        description="Test Graph Description",
        nodes=[],
        links=[]
    )

    # Test case for non-existent graph
    mock_get_graph_all_versions.return_value = []
    with unittest.mock.patch(
        'autogpt_server.data.graph.get_graph_all_versions', mock_get_graph_all_versions
    ):
        response = client.put(
            "/api/v1/graphs/non_existent_graph", json=non_existent_graph.model_dump()
        )
        assert response.status_code == 404
        assert response.json() == {"detail": "Graph #non_existent_graph not found"}

    # Test case for mismatched graph ID
    response = client.put("/api/v1/graphs/graph2", json=updated_graph.model_dump())
    assert response.status_code == 400
    assert response.json() == {"detail": "Graph ID does not match ID in URI"}


@pytest.mark.asyncio
async def test_execute_graph():
    mock_add_execution = unittest.mock.AsyncMock()
    mock_add_execution.return_value = {"graph_exec_id": "exec1"}

    response = client.post("/api/v1/graphs/graph1/execute", json={"node_input": {"text": "hi"}})
    assert response.status_code == 200
    assert response.json() == {"id": "exec1"}
rnd/autogpt_server/autogpt_server/server/routes/blocks.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import collections

import fastapi
import autogpt_server.data.block
import autogpt_server.data.credit

router = fastapi.APIRouter()


@router.get("/blocks")
async def get_graph_blocks():
    return [v.to_dict() for v in autogpt_server.data.block.get_blocks().values()]


@router.get("/blocks/costs")
async def get_graph_block_costs():
    return autogpt_server.data.credit.get_block_costs()


@router.post("/blocks/{block_id}/execute")
async def execute_graph_block(
    block_id: str, data: autogpt_server.data.block.BlockInput
) -> autogpt_server.data.block.CompletedBlockOutput:
    obj = autogpt_server.data.block.get_block(block_id)
    if not obj:
        raise fastapi.HTTPException(
            status_code=404, detail=f"Block #{block_id} not found."
        )

    output = collections.defaultdict(list)
    for name, data in obj.execute(data):
        output[name].append(data)
    return output
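The `execute_graph_block` route above drains a block's stream of `(output_name, value)` pairs into per-name lists with `collections.defaultdict`. A standalone sketch of that aggregation pattern, using a stand-in generator rather than a real block:

```python
import collections


def fake_block_execute(data: dict):
    # Stand-in for a block's execute(): yields (output_name, value) pairs.
    yield "result", data["text"].upper()
    yield "result", data["text"].lower()
    yield "length", len(data["text"])


output = collections.defaultdict(list)
for name, value in fake_block_execute({"text": "Hi"}):
    output[name].append(value)

print(dict(output))  # {'result': ['HI', 'hi'], 'length': [2]}
```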
@@ -0,0 +1,39 @@
import pytest
import fastapi.testclient
import autogpt_server.server.routes.blocks
import fastapi

client = fastapi.testclient.TestClient(autogpt_server.server.routes.blocks.router)


@pytest.mark.asyncio
async def test_get_graph_blocks():
    response = client.get("/blocks")
    assert response.status_code == 200
    assert isinstance(response.json(), list)


@pytest.mark.asyncio
async def test_get_graph_block_costs():
    response = client.get("/blocks/costs")
    assert response.status_code == 200
    assert isinstance(response.json(), dict)


@pytest.mark.asyncio
async def test_execute_graph_block_success():
    block_id = "f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c"
    data = {"text": "hi"}
    response = client.post(f"/blocks/{block_id}/execute", json=data)
    assert response.status_code == 200
    assert isinstance(response.json(), dict)


@pytest.mark.asyncio
async def test_execute_graph_block_not_found():
    block_id = "invalid_block_id"
    data = {"input": "test data"}
    with pytest.raises(fastapi.HTTPException) as exc_info:
        client.post(f"/blocks/{block_id}/execute", json=data)
    assert exc_info.value.status_code == 404
    assert exc_info.value.detail == f"Block #{block_id} not found."
rnd/autogpt_server/autogpt_server/server/routes/integrations.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import logging
import typing

import autogpt_libs.supabase_integration_credentials_store
import fastapi
import pydantic
import supabase

import autogpt_server.integrations.oauth
import autogpt_server.util.settings

import autogpt_server.server.utils

logger = logging.getLogger(__name__)
settings = autogpt_server.util.settings.Settings()
integrations_api_router = fastapi.APIRouter()


def get_store(
    supabase: supabase.Client = fastapi.Depends(
        autogpt_server.server.utils.get_supabase
    ),
):
    return autogpt_libs.supabase_integration_credentials_store.SupabaseIntegrationCredentialsStore(
        supabase
    )


class LoginResponse(pydantic.BaseModel):
    login_url: str


@integrations_api_router.get("/{provider}/login")
async def login(
    provider: typing.Annotated[
        str, fastapi.Path(title="The provider to initiate an OAuth flow for")
    ],
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
    request: fastapi.Request,
    store: typing.Annotated[
        autogpt_libs.supabase_integration_credentials_store.SupabaseIntegrationCredentialsStore,
        fastapi.Depends(get_store),
    ],
    scopes: typing.Annotated[
        str, fastapi.Query(title="Comma-separated list of authorization scopes")
    ] = "",
) -> LoginResponse:
    handler = _get_provider_oauth_handler(request, provider)

    # Generate and store a secure random state token
    state = await store.store_state_token(user_id, provider)

    requested_scopes = scopes.split(",") if scopes else []
    login_url = handler.get_login_url(requested_scopes, state)

    return LoginResponse(login_url=login_url)


class CredentialsMetaResponse(pydantic.BaseModel):
    credentials_id: str
    credentials_type: typing.Literal["oauth2", "api_key"]


@integrations_api_router.post("/{provider}/callback")
async def callback(
    provider: typing.Annotated[
        str, fastapi.Path(title="The target provider for this OAuth exchange")
    ],
    code: typing.Annotated[
        str, fastapi.Body(title="Authorization code acquired by user login")
    ],
    state_token: typing.Annotated[str, fastapi.Body(title="Anti-CSRF nonce")],
    store: typing.Annotated[
        autogpt_libs.supabase_integration_credentials_store.SupabaseIntegrationCredentialsStore,
        fastapi.Depends(get_store),
    ],
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ],
    request: fastapi.Request,
) -> CredentialsMetaResponse:
    handler = _get_provider_oauth_handler(request, provider)

    # Verify the state token
    if not await store.verify_state_token(user_id, state_token, provider):
        raise fastapi.HTTPException(
            status_code=400, detail="Invalid or expired state token"
        )

    try:
        credentials = handler.exchange_code_for_tokens(code)
    except Exception as e:
        logger.warning(f"Code->Token exchange failed for provider {provider}: {e}")
        raise fastapi.HTTPException(status_code=400, detail=str(e))

    store.add_creds(user_id, credentials)
    return CredentialsMetaResponse(
        credentials_id=credentials.id,
        credentials_type=credentials.type,
    )


# -------- UTILITIES --------- #


def _get_provider_oauth_handler(
    req: fastapi.Request, provider_name: str
) -> autogpt_server.integrations.oauth.BaseOAuthHandler:
    if provider_name not in autogpt_server.integrations.oauth.HANDLERS_BY_NAME:
        raise fastapi.HTTPException(
            status_code=404, detail=f"Unknown provider '{provider_name}'"
        )

    client_id = getattr(settings.secrets, f"{provider_name}_client_id")
    client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
    if not (client_id and client_secret):
        raise fastapi.HTTPException(
            status_code=501,
            detail=f"Integration with provider '{provider_name}' is not configured",
        )

    handler_class = autogpt_server.integrations.oauth.HANDLERS_BY_NAME[provider_name]
    return handler_class(
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=str(req.url_for("callback", provider=provider_name)),
    )
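From a client's point of view, the integrations router above implies a two-step flow: fetch the provider's login URL (which also stores a state token for the user), send the user there, then post the authorization code and state token back to the callback route. A hedged sketch with `requests` follows; the base URL, provider name, scopes, and auth header are assumptions, not part of the diff.

```python
# Hypothetical OAuth client sketch (not part of the diff).
import requests

BASE = "http://127.0.0.1:8000/api/v1/integrations"  # assumed host/port
HEADERS = {"Authorization": "Bearer <access-token>"}  # assumed auth scheme

# Step 1: request a login URL for an assumed provider and scopes.
login = requests.get(
    f"{BASE}/github/login", params={"scopes": "repo,user"}, headers=HEADERS
).json()
print("Redirect the user to:", login["login_url"])

# Step 2: after the user authorizes, exchange the returned code via the callback.
creds = requests.post(
    f"{BASE}/github/callback",
    json={"code": "<authorization-code>", "state_token": "<state-from-redirect>"},
    headers=HEADERS,
).json()
print(creds["credentials_id"], creds["credentials_type"])
```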
rnd/autogpt_server/autogpt_server/server/routes/root.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import typing
import fastapi

import autogpt_libs.auth.middleware
import autogpt_server.data.credit
import autogpt_server.data.user
import autogpt_server.server.utils


router = fastapi.APIRouter()

_user_credit_model = autogpt_server.data.credit.get_user_credit_model()


@router.get("/")
async def root():
    return {"message": "Welcome to the Autogpt Server API"}


@router.post("/auth/user")
async def get_or_create_user_route(
    user_data: dict = fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware),
):
    user = await autogpt_server.data.user.get_or_create_user(user_data)
    return user.model_dump()


@router.get("/credits")
async def get_user_credits(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
    ]
):
    return {"credits": await _user_credit_model.get_or_refill_credit(user_id)}
@@ -0,0 +1,50 @@
import fastapi
import fastapi.testclient
import pytest

from autogpt_server.server.new_rest_app import app
import autogpt_server.server.utils as utils
import autogpt_libs.auth.middleware as auth_middleware
import unittest.mock

client = fastapi.testclient.TestClient(app)


async def override_get_user_id():
    return "test_user_id"


async def override_user_data():
    return {"id": "test_user_id", "name": "Test User"}


app.dependency_overrides[utils.get_user_id] = override_get_user_id
app.dependency_overrides[auth_middleware.auth_middleware] = override_user_data


@pytest.mark.asyncio
async def test_root():
    response = client.get("/api/v1/")
    assert response.status_code == 200
    assert response.json() == {"message": "Welcome to the Autogpt Server API"}


@pytest.mark.asyncio
async def test_get_or_create_user_route():
    # Create a mock for the get_or_create_user function
    mock_get_or_create_user = unittest.mock.AsyncMock()
    mock_get_or_create_user.return_value = unittest.mock.Mock(
        model_dump=lambda: {"id": "test_user_id", "name": "Test User"}
    )

    # Apply the mock using patch
    with unittest.mock.patch('autogpt_server.data.user.get_or_create_user', mock_get_or_create_user):
        response = client.post("/api/v1/auth/user")
        assert response.status_code == 200
        assert response.json() == {"id": "test_user_id", "name": "Test User"}


@pytest.mark.asyncio
async def test_get_user_credits():
    response = client.get("/api/v1/credits")
    assert response.status_code == 200
    assert response.json() == {"credits": 0}