Compare commits

...

9 Commits

Author SHA1 Message Date
SwiftyOS
2a63093b14 adding tests 2024-09-18 12:31:22 +02:00
SwiftyOS
70f884367b added agent routes 2024-09-17 16:24:39 +02:00
SwiftyOS
10da2d3b2a updated tmp named new_rest_app 2024-09-17 15:30:53 +02:00
SwiftyOS
55e206fcd7 updated routes init 2024-09-17 15:30:31 +02:00
SwiftyOS
74db958932 updated integrations routes 2024-09-17 15:30:23 +02:00
SwiftyOS
72d539f777 add root routes 2024-09-17 15:30:07 +02:00
SwiftyOS
2142caf2ca added blocks routes 2024-09-17 15:30:00 +02:00
SwiftyOS
8d0bbc5ffe adding template for refactored rest service app 2024-09-17 11:51:46 +02:00
SwiftyOS
3d4aca9fcc Splitting rest services into routes 2024-09-16 16:56:09 +02:00
11 changed files with 961 additions and 105 deletions

View File

@@ -1,105 +0,0 @@
import logging
from typing import Annotated, Literal
from autogpt_libs.supabase_integration_credentials_store import (
SupabaseIntegrationCredentialsStore,
)
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel
from supabase import Client
from autogpt_server.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from autogpt_server.util.settings import Settings
from .utils import get_supabase, get_user_id
logger = logging.getLogger(__name__)
settings = Settings()
integrations_api_router = APIRouter()
def get_store(supabase: Client = Depends(get_supabase)):
return SupabaseIntegrationCredentialsStore(supabase)
class LoginResponse(BaseModel):
login_url: str
@integrations_api_router.get("/{provider}/login")
async def login(
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
scopes: Annotated[
str, Query(title="Comma-separated list of authorization scopes")
] = "",
) -> LoginResponse:
handler = _get_provider_oauth_handler(request, provider)
# Generate and store a secure random state token
state = await store.store_state_token(user_id, provider)
requested_scopes = scopes.split(",") if scopes else []
login_url = handler.get_login_url(requested_scopes, state)
return LoginResponse(login_url=login_url)
class CredentialsMetaResponse(BaseModel):
credentials_id: str
credentials_type: Literal["oauth2", "api_key"]
@integrations_api_router.post("/{provider}/callback")
async def callback(
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
code: Annotated[str, Body(title="Authorization code acquired by user login")],
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
) -> CredentialsMetaResponse:
handler = _get_provider_oauth_handler(request, provider)
# Verify the state token
if not await store.verify_state_token(user_id, state_token, provider):
raise HTTPException(status_code=400, detail="Invalid or expired state token")
try:
credentials = handler.exchange_code_for_tokens(code)
except Exception as e:
logger.warning(f"Code->Token exchange failed for provider {provider}: {e}")
raise HTTPException(status_code=400, detail=str(e))
store.add_creds(user_id, credentials)
return CredentialsMetaResponse(
credentials_id=credentials.id,
credentials_type=credentials.type,
)
# -------- UTILITIES --------- #
def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler:
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=404, detail=f"Unknown provider '{provider_name}'"
)
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
if not (client_id and client_secret):
raise HTTPException(
status_code=501,
detail=f"Integration with provider '{provider_name}' is not configured",
)
handler_class = HANDLERS_BY_NAME[provider_name]
return handler_class(
client_id=client_id,
client_secret=client_secret,
redirect_uri=str(req.url_for("callback", provider=provider_name)),
)

View File

@@ -0,0 +1,92 @@
import contextlib
import fastapi
import fastapi.responses
import fastapi.middleware.cors
import autogpt_libs.auth.middleware
import autogpt_server.data.block
import autogpt_server.data.db
import autogpt_server.data.graph
import autogpt_server.data.user
import autogpt_server.server.routes
import autogpt_server.server.utils
import autogpt_server.util.settings
from autogpt_server.data import user as user_db
settings = autogpt_server.util.settings.Settings()
@contextlib.asynccontextmanager
async def app_lifespan(app: fastapi.FastAPI):
await autogpt_server.data.db.connect()
await autogpt_server.data.block.initialize_blocks()
if await user_db.create_default_user(settings.config.enable_auth):
await autogpt_server.data.graph.import_packaged_templates()
yield
await autogpt_server.data.db.disconnect()
app = fastapi.FastAPI(
title="AutoGPT Agent Server",
description=(
"This server is used to execute agents that are created by the "
"AutoGPT system."
),
summary="AutoGPT Agent Server",
version="0.1",
lifespan=app_lifespan,
)
api_router = fastapi.APIRouter(prefix="/api/v1")
api_router.dependencies.append(
fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)
)
api_router.include_router(autogpt_server.server.routes.root_router)
api_router.include_router(autogpt_server.server.routes.agents_router, tags=["agents"])
api_router.include_router(autogpt_server.server.routes.blocks_router, tags=["blocks"])
api_router.include_router(
autogpt_server.server.routes.integrations_router, prefix="/integrations"
)
app.include_router(api_router)
@app.exception_handler(500)
def handle_internal_http_error(request: fastapi.Request, exc: Exception):
return fastapi.responses.JSONResponse(
status_code=500,
content={"message": str(exc)},
)
@app.exception_handler(fastapi.exceptions.RequestValidationError)
async def validation_exception_handler(request: fastapi.Request, exc: fastapi.exceptions.RequestValidationError):
errors = []
for err in exc.errors():
error = {
"field": ".".join(err["loc"][1:]), # Skipping 'body' or 'query' etc.
"message": err["msg"],
"type": err["type"]
}
errors.append(error)
return fastapi.responses.JSONResponse(
status_code=422,
content={
"status": "fail",
"errors": errors
},
)
app.add_middleware(
fastapi.middleware.cors.CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)

View File

@@ -0,0 +1,12 @@
from .agents import router as agents_router
from .root import router as root_router
from .blocks import router as blocks_router
from .integrations import integrations_api_router as integrations_router
__all__ = [
"agents_router",
"root_router",
"blocks_router",
"integrations_router",
]

View File

@@ -0,0 +1,335 @@
import typing
import fastapi
import autogpt_server.data.graph
import autogpt_server.data.queue
import autogpt_server.data.execution
import autogpt_server.executor
import autogpt_server.server.model
import autogpt_server.server.utils
import autogpt_server.util.service
import autogpt_server.util.settings
router = fastapi.APIRouter()
def execution_manager_client() -> autogpt_server.executor.ExecutionManager:
return autogpt_server.util.service.get_service_client(
autogpt_server.executor.ExecutionManager,
autogpt_server.util.settings.Config().execution_manager_port,
)
def execution_scheduler_client() -> autogpt_server.executor.ExecutionScheduler:
return autogpt_server.util.service.get_service_client(
autogpt_server.executor.ExecutionScheduler,
autogpt_server.util.settings.Config().execution_scheduler_port,
)
@router.get("/graphs")
async def get_graphs(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
]
) -> list[autogpt_server.data.graph.GraphMeta]:
return await autogpt_server.data.graph.get_graphs_meta(
filter_by="active", user_id=user_id
)
@router.post("/graphs")
async def create_new_graph(
create_graph: autogpt_server.server.model.CreateGraph,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
) -> autogpt_server.data.graph.Graph:
if create_graph.graph:
graph = create_graph.graph
elif create_graph.template_id:
# Create a new graph from a template
graph = await autogpt_server.data.graph.get_graph(
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
raise fastapi.HTTPException(
400, detail=f"Template #{create_graph.template_id} not found"
)
graph.version = 1
else:
raise fastapi.HTTPException(
status_code=400, detail="Either graph or template_id must be provided."
)
graph.is_template = False
graph.is_active = True
graph.reassign_ids(reassign_graph_id=True)
return await autogpt_server.data.graph.create_graph(graph, user_id=user_id)
return await cls.create_graph(create_graph, is_template=False, user_id=user_id)
@router.get("/graphs/{graph_id}")
async def get_graph(
graph_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
version: typing.Optional[int] = None,
) -> autogpt_server.data.graph.Graph:
graph = await autogpt_server.data.graph.get_graph(
graph_id, version, user_id=user_id
)
if graph and graph.id != graph_id:
raise fastapi.HTTPException(400, detail="Graph ID does not match ID in URI")
if not graph:
raise fastapi.HTTPException(
status_code=404, detail=f"Graph #{graph_id} not found."
)
return graph
@router.put("/graphs/{graph_id}")
async def update_graph(
graph_id: str,
graph: autogpt_server.data.graph.Graph,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
) -> autogpt_server.data.graph.Graph:
# Sanity check
if graph.id and graph.id != graph_id:
raise fastapi.HTTPException(400, detail="Graph ID does not match ID in URI")
# Determine new version
existing_versions = await autogpt_server.data.graph.get_graph_all_versions(
graph_id, user_id=user_id
)
if not existing_versions:
raise fastapi.HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
graph.version = latest_version_number + 1
latest_version_graph = next(
v for v in existing_versions if v.version == latest_version_number
)
if latest_version_graph.is_template != graph.is_template:
raise fastapi.HTTPException(
400, detail="Changing is_template on an existing graph is forbidden"
)
graph.is_active = not graph.is_template
graph.reassign_ids()
new_graph_version = await autogpt_server.data.graph.create_graph(
graph, user_id=user_id
)
if new_graph_version.is_active:
# Ensure new version is the only active version
await autogpt_server.data.graph.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
)
return new_graph_version
@router.get("/graphs/{graph_id}/versions")
async def get_graph_all_versions(
graph_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
) -> list[autogpt_server.data.graph.Graph]:
graphs = await autogpt_server.data.graph.get_graph_all_versions(
graph_id, user_id=user_id
)
if not graphs:
raise fastapi.HTTPException(
status_code=404, detail=f"Graph #{graph_id} not found."
)
return graphs
@router.get("/graphs/{graph_id}/versions/{version}")
async def get_graph_version(
graph_id: str,
version: int,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
) -> autogpt_server.data.graph.Graph:
graph = await autogpt_server.data.graph.get_graph(
graph_id, version, user_id=user_id
)
if not graph:
raise fastapi.HTTPException(
status_code=404, detail=f"Graph #{graph_id} not found."
)
return graph
@router.put("/graphs/{graph_id}/versions/active")
async def set_graph_active_version(
graph_id: str,
request_body: autogpt_server.server.model.SetGraphActiveVersion,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
):
new_active_version = request_body.active_graph_version
if not await autogpt_server.data.graph.get_graph(
graph_id, new_active_version, user_id=user_id
):
raise fastapi.HTTPException(
status_code=404, detail=f"Graph #{graph_id} v{new_active_version} not found"
)
await autogpt_server.data.graph.set_graph_active_version(
graph_id=graph_id,
version=request_body.active_graph_version,
user_id=user_id,
)
@router.post("/graphs/{graph_id}/execute")
async def execute_graph(
graph_id: str,
node_input: dict[typing.Any, typing.Any],
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
execution_manager: typing.Annotated[
autogpt_server.executor.ExecutionManager,
fastapi.Depends(execution_manager_client),
],
) -> dict[str, typing.Any]: # FIXME: add proper return type
try:
graph_exec = execution_manager.add_execution(
graph_id, node_input, user_id=user_id
)
return {"id": graph_exec["graph_exec_id"]}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise fastapi.HTTPException(status_code=400, detail=msg)
@router.get("/graphs/{graph_id}/executions")
async def list_graph_runs(
graph_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
graph_version: typing.Optional[int] = None,
) -> list[str]:
graph = await autogpt_server.data.graph.get_graph(
graph_id, graph_version, user_id=user_id
)
if not graph:
rev = "" if graph_version is None else f" v{graph_version}"
raise fastapi.HTTPException(
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)
return await autogpt_server.data.execution.list_executions(graph_id, graph_version)
@router.get("/graphs/{graph_id}/executions/{graph_exec_id}")
async def get_graph_run_node_execution_results(
graph_id: str,
graph_exec_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
) -> list[autogpt_server.data.execution.ExecutionResult]:
graph = await autogpt_server.data.graph.get_graph(graph_id, user_id=user_id)
if not graph:
raise fastapi.HTTPException(
status_code=404, detail=f"Graph #{graph_id} not found."
)
return await autogpt_server.data.execution.get_execution_results(graph_exec_id)
@router.post("/graphs/{graph_id}/executions/{graph_exec_id}/stop")
async def stop_graph_run(
graph_exec_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
execution_manager: typing.Annotated[
autogpt_server.executor.ExecutionManager,
fastapi.Depends(execution_manager_client),
],
) -> list[autogpt_server.data.execution.ExecutionResult]:
if not await autogpt_server.data.execution.get_graph_execution(
graph_exec_id, user_id
):
raise fastapi.HTTPException(
404, detail=f"Agent execution #{graph_exec_id} not found"
)
execution_manager.cancel_execution(graph_exec_id)
# Retrieve & return canceled graph execution in its final state
return await autogpt_server.data.execution.get_execution_results(graph_exec_id)
@router.post("/graphs/{graph_id}/schedules")
async def create_schedule(
graph_id: str,
cron: str,
input_data: dict[typing.Any, typing.Any],
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
execution_scheduler: typing.Annotated[
autogpt_server.executor.ExecutionScheduler,
fastapi.Depends(execution_scheduler_client),
],
) -> dict[typing.Any, typing.Any]:
graph = await autogpt_server.data.graph.get_graph(graph_id, user_id=user_id)
if not graph:
raise fastapi.HTTPException(
status_code=404, detail=f"Graph #{graph_id} not found."
)
return {
"id": execution_scheduler.add_execution_schedule(
graph_id, graph.version, cron, input_data, user_id=user_id
)
}
@router.get("/graphs/{graph_id}/schedules")
async def get_execution_schedules(
graph_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
execution_scheduler: typing.Annotated[
autogpt_server.executor.ExecutionScheduler,
fastapi.Depends(execution_scheduler_client),
],
) -> dict[str, str]:
return execution_scheduler.get_execution_schedules(graph_id, user_id)
@router.put("/graphs/schedules/{schedule_id}")
async def update_schedule(
schedule_id: str,
input_data: dict[typing.Any, typing.Any],
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
execution_scheduler: typing.Annotated[
autogpt_server.executor.ExecutionScheduler,
fastapi.Depends(execution_scheduler_client),
],
) -> dict[typing.Any, typing.Any]:
is_enabled = input_data.get("is_enabled", False)
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id)
return {"id": schedule_id}

View File

@@ -0,0 +1,237 @@
import autogpt_server.server
import autogpt_server.server.routes
import fastapi
import fastapi.testclient
import pytest
import unittest.mock
from autogpt_server.server.new_rest_app import app
import autogpt_server.server.utils as utils
import autogpt_server.executor as executor
import autogpt_server.data.graph
client = fastapi.testclient.TestClient(app)
async def override_get_user_id():
return "test_user_id"
async def override_execution_manager_client():
class MockExecutionManager:
def add_execution(self, graph_id, node_input, user_id):
return {"graph_exec_id": "exec1"}
def cancel_execution(self, graph_exec_id):
return {"graph_exec_id": "exec1"}
def get_execution_results(self, graph_exec_id):
return [{"node_id": "node1", "output": "result"}]
return MockExecutionManager()
async def override_scheduler_client():
return executor.ExecutionScheduler()
app.dependency_overrides[utils.get_user_id] = override_get_user_id
app.dependency_overrides[autogpt_server.server.routes.agents.execution_manager_client] = override_execution_manager_client
app.dependency_overrides[autogpt_server.server.routes.agents.execution_scheduler_client] = override_scheduler_client
@pytest.mark.asyncio
async def test_get_graphs():
mock_get_graphs_meta = unittest.mock.AsyncMock()
mock_get_graphs_meta.return_value = [
autogpt_server.data.graph.GraphMeta(
id="graph1",
version=1,
name="Test Graph 1",
description="Test Graph 1 Description",
is_active=True,
is_template=False,
),
autogpt_server.data.graph.GraphMeta(
id="graph2",
version=1,
name="Test Graph 2",
description="Test Graph 2 Description",
is_active=True,
is_template=False,
),
autogpt_server.data.graph.GraphMeta(
id="graph3",
version=2,
name="Test Graph 3",
description="Test Graph 3 Description",
is_active=False,
is_template=True,
)
]
with unittest.mock.patch('autogpt_server.data.graph.get_graphs_meta', mock_get_graphs_meta):
response = client.get("/api/v1/graphs")
assert response.status_code == 200
assert response.json() == [
{
"id": "graph1",
"version": 1,
"name": "Test Graph 1",
"description": "Test Graph 1 Description",
"is_active": True,
"is_template": False,
},
{
"id": "graph2",
"version": 1,
"name": "Test Graph 2",
"description": "Test Graph 2 Description",
"is_active": True,
"is_template": False,
},
{
"id": "graph3",
"version": 2,
"name": "Test Graph 3",
"description": "Test Graph 3 Description",
"is_active": False,
"is_template": True,
}
]
@pytest.mark.asyncio
async def test_create_new_graph():
sample_graph = autogpt_server.data.graph.Graph(
version=1,
is_active=True,
is_template=False,
name="New Graph",
description="New Graph Description",
nodes=[],
links=[]
)
mock_create_graph = unittest.mock.AsyncMock()
mock_create_graph.return_value = sample_graph
with unittest.mock.patch('autogpt_server.data.graph.create_graph', mock_create_graph):
response = client.post("/api/v1/graphs", json={"graph": sample_graph.model_dump()})
assert response.status_code == 200
assert response.json() == sample_graph.model_dump()
@pytest.mark.asyncio
async def test_create_new_graph_invalid_request():
# Test case for missing both graph and template_id
response = client.post("/api/v1/graphs", json={})
assert response.status_code == 400
assert response.json() == {"detail": "Either graph or template_id must be provided."}
# Test case for non-existent template
mock_get_graph = unittest.mock.AsyncMock(return_value=None)
with unittest.mock.patch('autogpt_server.data.graph.get_graph', mock_get_graph):
response = client.post("/api/v1/graphs", json={"template_id": "non_existent_template"})
assert response.status_code == 400
assert response.json() == {"detail": "Template #non_existent_template not found"}
# Test case for invalid graph structure
invalid_graph = {
"version": 1,
"is_active": True,
"is_template": False,
"name": "Invalid Graph",
"description": "Invalid Graph Description",
"nodes": "not a list", # This should be a list
"links": []
}
response = client.post("/api/v1/graphs", json={"graph": invalid_graph})
assert response.status_code == 422 # Unprocessable Entity
assert "status" in response.json()
assert "errors" in response.json()
@pytest.mark.asyncio
async def test_get_graph_details():
sample_graph = autogpt_server.data.graph.Graph(
id="graph1",
version=1,
is_active=True,
is_template=False,
name="Test Graph",
description="Test Graph Description",
nodes=[],
links=[]
)
mock_get_graph = unittest.mock.AsyncMock(return_value=sample_graph)
with unittest.mock.patch('autogpt_server.data.graph.get_graph', mock_get_graph):
response = client.get("/api/v1/graphs/graph1")
assert response.status_code == 200
assert response.json() == sample_graph.model_dump()
# Test case for non-existent graph
mock_get_graph.return_value = None
with unittest.mock.patch('autogpt_server.data.graph.get_graph', mock_get_graph):
response = client.get("/api/v1/graphs/non_existent_graph")
assert response.status_code == 404
assert response.json() == {"detail": "Graph #non_existent_graph not found."}
@pytest.mark.asyncio
async def test_update_graph():
sample_graph = autogpt_server.data.graph.Graph(
id="graph1",
version=1,
is_active=True,
is_template=False,
name="Test Graph",
description="Test Graph Description",
nodes=[],
links=[]
)
mock_get_graph_all_versions = unittest.mock.AsyncMock(return_value=[sample_graph])
mock_create_graph = unittest.mock.AsyncMock(return_value=sample_graph)
mock_set_graph_active_version = unittest.mock.AsyncMock()
with unittest.mock.patch('autogpt_server.data.graph.get_graph_all_versions', mock_get_graph_all_versions), \
unittest.mock.patch('autogpt_server.data.graph.create_graph', mock_create_graph), \
unittest.mock.patch('autogpt_server.data.graph.set_graph_active_version', mock_set_graph_active_version):
updated_graph = sample_graph.model_copy(update={"name": "Updated Graph", "version": 2})
response = client.put("/api/v1/graphs/graph1", json=updated_graph.model_dump())
assert response.status_code == 200
assert response.json()['description'] == updated_graph.model_dump()['description']
non_existent_graph = autogpt_server.data.graph.Graph(
id="non_existent_graph",
version=1,
is_active=True,
is_template=False,
name="Test Graph",
description="Test Graph Description",
nodes=[],
links=[]
)
# Test case for non-existent graph
mock_get_graph_all_versions.return_value = []
with unittest.mock.patch('autogpt_server.data.graph.get_graph_all_versions', mock_get_graph_all_versions):
response = client.put("/api/v1/graphs/non_existent_graph", json=non_existent_graph.model_dump())
assert response.status_code == 404
assert response.json() == {"detail": "Graph #non_existent_graph not found"}
# Test case for mismatched graph ID
response = client.put("/api/v1/graphs/graph2", json=updated_graph.model_dump())
assert response.status_code == 400
assert response.json() == {"detail": "Graph ID does not match ID in URI"}
@pytest.mark.asyncio
async def test_execute_graph():
mock_add_execution = unittest.mock.AsyncMock()
mock_add_execution.return_value = {"graph_exec_id": "exec1"}
response = client.post("/api/v1/graphs/graph1/execute", json={"node_input": {"text": "hi"}})
assert response.status_code == 200
assert response.json() == {"id": "exec1"}

View File

@@ -0,0 +1,33 @@
import collections
import fastapi
import autogpt_server.data.block
import autogpt_server.data.credit
router = fastapi.APIRouter()
@router.get("/blocks")
async def get_graph_blocks():
return [v.to_dict() for v in autogpt_server.data.block.get_blocks().values()]
@router.get("/blocks/costs")
async def get_graph_block_costs():
return autogpt_server.data.credit.get_block_costs()
@router.post("/blocks/{block_id}/execute")
async def execute_graph_block(
block_id: str, data: autogpt_server.data.block.BlockInput
) -> autogpt_server.data.block.CompletedBlockOutput:
obj = autogpt_server.data.block.get_block(block_id)
if not obj:
raise fastapi.HTTPException(
status_code=404, detail=f"Block #{block_id} not found."
)
output = collections.defaultdict(list)
for name, data in obj.execute(data):
output[name].append(data)
return output

View File

@@ -0,0 +1,39 @@
import pytest
import fastapi.testclient
import autogpt_server.server.routes.blocks
import fastapi
client = fastapi.testclient.TestClient(autogpt_server.server.routes.blocks.router)
@pytest.mark.asyncio
async def test_get_graph_blocks():
response = client.get("/blocks")
assert response.status_code == 200
assert isinstance(response.json(), list)
@pytest.mark.asyncio
async def test_get_graph_block_costs():
response = client.get("/blocks/costs")
assert response.status_code == 200
assert isinstance(response.json(), dict)
@pytest.mark.asyncio
async def test_execute_graph_block_success():
block_id = "f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c"
data = {"text": "hi"}
response = client.post(f"/blocks/{block_id}/execute", json=data)
assert response.status_code == 200
assert isinstance(response.json(), dict)
@pytest.mark.asyncio
async def test_execute_graph_block_not_found():
block_id = "invalid_block_id"
data = {"input": "test data"}
with pytest.raises(fastapi.HTTPException) as exc_info:
client.post(f"/blocks/{block_id}/execute", json=data)
assert exc_info.value.status_code == 404
assert exc_info.value.detail == f"Block #{block_id} not found."

View File

@@ -0,0 +1,129 @@
import logging
import typing
import autogpt_libs.supabase_integration_credentials_store
import fastapi
import pydantic
import supabase
import autogpt_server.integrations.oauth
import autogpt_server.util.settings
import autogpt_server.server.utils
logger = logging.getLogger(__name__)
settings = autogpt_server.util.settings.Settings()
integrations_api_router = fastapi.APIRouter()
def get_store(
supabase: supabase.Client = fastapi.Depends(
autogpt_server.server.utils.get_supabase
),
):
return autogpt_libs.supabase_integration_credentials_store.SupabaseIntegrationCredentialsStore(
supabase
)
class LoginResponse(pydantic.BaseModel):
login_url: str
@integrations_api_router.get("/{provider}/login")
async def login(
provider: typing.Annotated[
str, fastapi.Path(title="The provider to initiate an OAuth flow for")
],
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
request: fastapi.Request,
store: typing.Annotated[
autogpt_libs.supabase_integration_credentials_store.SupabaseIntegrationCredentialsStore,
fastapi.Depends(get_store),
],
scopes: typing.Annotated[
str, fastapi.Query(title="Comma-separated list of authorization scopes")
] = "",
) -> LoginResponse:
handler = _get_provider_oauth_handler(request, provider)
# Generate and store a secure random state token
state = await store.store_state_token(user_id, provider)
requested_scopes = scopes.split(",") if scopes else []
login_url = handler.get_login_url(requested_scopes, state)
return LoginResponse(login_url=login_url)
class CredentialsMetaResponse(pydantic.BaseModel):
credentials_id: str
credentials_type: typing.Literal["oauth2", "api_key"]
@integrations_api_router.post("/{provider}/callback")
async def callback(
provider: typing.Annotated[
str, fastapi.Path(title="The target provider for this OAuth exchange")
],
code: typing.Annotated[
str, fastapi.Body(title="Authorization code acquired by user login")
],
state_token: typing.Annotated[str, fastapi.Body(title="Anti-CSRF nonce")],
store: typing.Annotated[
autogpt_libs.supabase_integration_credentials_store.SupabaseIntegrationCredentialsStore,
fastapi.Depends(get_store),
],
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
],
request: fastapi.Request,
) -> CredentialsMetaResponse:
handler = _get_provider_oauth_handler(request, provider)
# Verify the state token
if not await store.verify_state_token(user_id, state_token, provider):
raise fastapi.HTTPException(
status_code=400, detail="Invalid or expired state token"
)
try:
credentials = handler.exchange_code_for_tokens(code)
except Exception as e:
logger.warning(f"Code->Token exchange failed for provider {provider}: {e}")
raise fastapi.HTTPException(status_code=400, detail=str(e))
store.add_creds(user_id, credentials)
return CredentialsMetaResponse(
credentials_id=credentials.id,
credentials_type=credentials.type,
)
# -------- UTILITIES --------- #
def _get_provider_oauth_handler(
req: fastapi.Request, provider_name: str
) -> autogpt_server.integrations.oauth.BaseOAuthHandler:
if provider_name not in autogpt_server.integrations.oauth.HANDLERS_BY_NAME:
raise fastapi.HTTPException(
status_code=404, detail=f"Unknown provider '{provider_name}'"
)
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
if not (client_id and client_secret):
raise fastapi.HTTPException(
status_code=501,
detail=f"Integration with provider '{provider_name}' is not configured",
)
handler_class = autogpt_server.integrations.oauth.HANDLERS_BY_NAME[provider_name]
return handler_class(
client_id=client_id,
client_secret=client_secret,
redirect_uri=str(req.url_for("callback", provider=provider_name)),
)

View File

@@ -0,0 +1,34 @@
import typing
import fastapi
import autogpt_libs.auth.middleware
import autogpt_server.data.credit
import autogpt_server.data.user
import autogpt_server.server.utils
router = fastapi.APIRouter()
_user_credit_model = autogpt_server.data.credit.get_user_credit_model()
@router.get("/")
async def root():
return {"message": "Welcome to the Autogpt Server API"}
@router.post("/auth/user")
async def get_or_create_user_route(
user_data: dict = fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware),
):
user = await autogpt_server.data.user.get_or_create_user(user_data)
return user.model_dump()
@router.get("/credits")
async def get_user_credits(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_server.server.utils.get_user_id)
]
):
return {"credits": await _user_credit_model.get_or_refill_credit(user_id)}

View File

@@ -0,0 +1,50 @@
import fastapi
import fastapi.testclient
import pytest
from autogpt_server.server.new_rest_app import app
import autogpt_server.server.utils as utils
import autogpt_libs.auth.middleware as auth_middleware
import unittest.mock
client = fastapi.testclient.TestClient(app)
async def override_get_user_id():
return "test_user_id"
async def override_user_data():
return {"id": "test_user_id", "name": "Test User"}
app.dependency_overrides[utils.get_user_id] = override_get_user_id
app.dependency_overrides[auth_middleware.auth_middleware] = override_user_data
@pytest.mark.asyncio
async def test_root():
response = client.get("/api/v1/")
assert response.status_code == 200
assert response.json() == {"message": "Welcome to the Autogpt Server API"}
@pytest.mark.asyncio
async def test_get_or_create_user_route():
# Create a mock for the get_or_create_user function
mock_get_or_create_user = unittest.mock.AsyncMock()
mock_get_or_create_user.return_value = unittest.mock.Mock(
model_dump=lambda: {"id": "test_user_id", "name": "Test User"}
)
# Apply the mock using patch
with unittest.mock.patch('autogpt_server.data.user.get_or_create_user', mock_get_or_create_user):
response = client.post("/api/v1/auth/user")
assert response.status_code == 200
assert response.json() == {"id": "test_user_id", "name": "Test User"}
@pytest.mark.asyncio
async def test_get_user_credits():
response = client.get("/api/v1/credits")
assert response.status_code == 200
assert response.json() == {"credits": 0}