addressing feedback

This commit is contained in:
Aarushi
2024-08-05 17:11:34 +01:00
parent bc19e168b5
commit d4d4dcdf81
14 changed files with 175 additions and 42 deletions

View File

@@ -17,7 +17,7 @@ export default class AutoGPTServerAPI {
private supabaseClient = createClient();
constructor(
baseUrl: string = process.env.AGPT_SERVER_URL || "http://localhost:8000/api"
baseUrl: string = process.env.NEXT_PUBLIC_AGPT_SERVER_URL || "http://localhost:8000/api"
) {
this.baseUrl = baseUrl;
this.wsUrl = `ws://${new URL(this.baseUrl).host}/ws`;

View File

@@ -10,6 +10,7 @@ from autogpt_server.util import json
class ExecutionSchedule(BaseDbModel):
graph_id: str
user_id: str
graph_version: int
schedule: str
is_enabled: bool
@@ -25,6 +26,7 @@ class ExecutionSchedule(BaseDbModel):
return ExecutionSchedule(
id=schedule.id,
graph_id=schedule.agentGraphId,
user_id=schedule.userId,
graph_version=schedule.agentGraphVersion,
schedule=schedule.schedule,
is_enabled=schedule.isEnabled,

View File

@@ -68,10 +68,16 @@ class ExecutionScheduler(AppService):
@expose
def add_execution_schedule(
self, graph_id: str, graph_version: int, cron: str, input_data: BlockInput
self,
graph_id: str,
graph_version: int,
cron: str,
input_data: BlockInput,
user_id: str,
) -> str:
schedule = model.ExecutionSchedule(
graph_id=graph_id,
user_id=user_id,
graph_version=graph_version,
schedule=cron,
input_data=input_data,

View File

@@ -1,9 +1,12 @@
import asyncio
import inspect
from collections import defaultdict
from contextlib import asynccontextmanager
from functools import wraps
from typing import Annotated, Any, Dict
import uvicorn
from autogpt_libs.auth.jwt_utils import parse_jwt_token
from autogpt_libs.auth.middleware import auth_middleware
from fastapi import (
APIRouter,
@@ -59,6 +62,7 @@ class AgentServer(AppService):
manager = ConnectionManager()
mutex = KeyedMutex()
use_db = False
_test_dependency_overrides = {}
async def event_broadcaster(self):
while True:
@@ -87,6 +91,9 @@ class AgentServer(AppService):
lifespan=self.lifespan,
)
if self._test_dependency_overrides:
app.dependency_overrides.update(self._test_dependency_overrides)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
@@ -222,6 +229,35 @@ class AgentServer(AppService):
uvicorn.run(app, host="0.0.0.0", port=8000)
def set_test_dependency_overrides(self, overrides: dict):
self._test_dependency_overrides = overrides
def _apply_overrides_to_methods(self):
for attr_name in dir(self):
attr = getattr(self, attr_name)
if callable(attr) and hasattr(attr, "__annotations__"):
setattr(self, attr_name, self._override_method(attr))
# TODO: fix this with some proper refactoring of the server
def _override_method(self, method):
@wraps(method)
async def wrapper(*args, **kwargs):
sig = inspect.signature(method)
for param_name, param in sig.parameters.items():
if param.annotation is inspect.Parameter.empty:
continue
if isinstance(param.annotation, Depends) or ( # type: ignore
isinstance(param.annotation, type) and issubclass(param.annotation, Depends) # type: ignore
):
dependency = param.annotation.dependency if isinstance(param.annotation, Depends) else param.annotation # type: ignore
if dependency in self._test_dependency_overrides:
kwargs[param_name] = self._test_dependency_overrides[
dependency
]()
return await method(*args, **kwargs)
return wrapper
@property
def execution_manager_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@@ -240,7 +276,27 @@ class AgentServer(AppService):
status_code=500,
)
async def authenticate_websocket(self, websocket: WebSocket) -> str:
token = websocket.query_params.get("token")
if not token:
await websocket.close(code=4001, reason="Missing authentication token")
return ""
try:
payload = parse_jwt_token(token)
user_id = payload.get("sub")
if not user_id:
await websocket.close(code=4002, reason="Invalid token")
return ""
return user_id
except ValueError:
await websocket.close(code=4003, reason="Invalid token")
return ""
async def websocket_router(self, websocket: WebSocket):
user_id = await self.authenticate_websocket(websocket)
if not user_id:
return
await self.manager.connect(websocket)
try:
while True:
@@ -279,7 +335,7 @@ class AgentServer(AppService):
).model_dump_json()
)
elif message.method == Methods.GET_GRAPHS:
data = await self.get_graphs()
data = await self.get_graphs(user_id=user_id)
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPHS,
@@ -290,7 +346,9 @@ class AgentServer(AppService):
print("Get graphs request received")
elif message.method == Methods.GET_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.get_graph(message.data["graph_id"])
data = await self.get_graph(
message.data["graph_id"], user_id=user_id
)
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPH,
@@ -302,7 +360,7 @@ class AgentServer(AppService):
elif message.method == Methods.CREATE_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
create_graph = CreateGraph.model_validate(message.data)
data = await self.create_new_graph(create_graph)
data = await self.create_new_graph(create_graph, user_id=user_id)
await websocket.send_text(
WsMessage(
method=Methods.CREATE_GRAPH,
@@ -315,7 +373,7 @@ class AgentServer(AppService):
elif message.method == Methods.RUN_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.execute_graph(
message.data["graph_id"], message.data["data"]
message.data["graph_id"], message.data["data"], user_id=user_id
)
await websocket.send_text(
WsMessage(
@@ -328,7 +386,9 @@ class AgentServer(AppService):
print("Run graph request received")
elif message.method == Methods.GET_GRAPH_RUNS:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.list_graph_runs(message.data["graph_id"])
data = await self.list_graph_runs(
message.data["graph_id"], user_id=user_id
)
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPH_RUNS,
@@ -344,6 +404,7 @@ class AgentServer(AppService):
message.data["graph_id"],
message.data["cron"],
message.data["data"],
user_id=user_id,
)
await websocket.send_text(
WsMessage(
@@ -356,7 +417,9 @@ class AgentServer(AppService):
print("Create scheduled run request received")
elif message.method == Methods.GET_SCHEDULED_RUNS:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.get_execution_schedules(message.data["graph_id"])
data = self.get_execution_schedules(
message.data["graph_id"], user_id=user_id
)
await websocket.send_text(
WsMessage(
method=Methods.GET_SCHEDULED_RUNS,
@@ -368,7 +431,7 @@ class AgentServer(AppService):
elif message.method == Methods.UPDATE_SCHEDULED_RUN:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.update_schedule(
message.data["schedule_id"], message.data
message.data["schedule_id"], message.data, user_id=user_id
)
await websocket.send_text(
WsMessage(
@@ -431,7 +494,7 @@ class AgentServer(AppService):
@classmethod
async def get_graphs(
cls, user_id: str = Depends(get_user_id)
cls, user_id: Annotated[str, Depends(get_user_id)]
) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(filter_by="active", user_id=user_id)
@@ -443,8 +506,8 @@ class AgentServer(AppService):
async def get_graph(
cls,
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
user_id: str = Depends(get_user_id),
) -> graph_db.Graph:
graph = await graph_db.get_graph(graph_id, version, user_id=user_id)
if not graph:
@@ -464,7 +527,7 @@ class AgentServer(AppService):
@classmethod
async def get_graph_all_versions(
cls, graph_id: str, user_id: str = Depends(get_user_id)
cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[graph_db.Graph]:
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
@@ -473,13 +536,13 @@ class AgentServer(AppService):
@classmethod
async def create_new_graph(
cls, create_graph: CreateGraph, user_id: str = Depends(get_user_id)
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=False, user_id=user_id)
@classmethod
async def create_new_template(
cls, create_graph: CreateGraph, user_id: str = Depends(get_user_id)
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=True, user_id=user_id)
@@ -515,7 +578,10 @@ class AgentServer(AppService):
@classmethod
async def update_graph(
cls, graph_id: str, graph: graph_db.Graph, user_id: str = Depends(get_user_id)
cls,
graph_id: str,
graph: graph_db.Graph,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.Graph:
# Sanity check
if graph.id and graph.id != graph_id:
@@ -555,7 +621,7 @@ class AgentServer(AppService):
cls,
graph_id: str,
request_body: SetGraphActiveVersion,
user_id: str = Depends(get_user_id),
user_id: Annotated[str, Depends(get_user_id)],
):
new_active_version = request_body.active_graph_version
if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id):
@@ -572,7 +638,7 @@ class AgentServer(AppService):
self,
graph_id: str,
node_input: dict[Any, Any],
user_id: str = Depends(get_user_id),
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
try:
return self.execution_manager_client.add_execution(
@@ -586,8 +652,8 @@ class AgentServer(AppService):
async def list_graph_runs(
cls,
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
graph_version: int | None = None,
user_id: str = Depends(get_user_id),
) -> list[str]:
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
if not graph:
@@ -600,7 +666,7 @@ class AgentServer(AppService):
@classmethod
async def get_run_execution_results(
cls, graph_id: str, run_id: str, user_id: str = Depends(get_user_id)
cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[ExecutionResult]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
@@ -613,7 +679,7 @@ class AgentServer(AppService):
graph_id: str,
cron: str,
input_data: dict[Any, Any],
user_id: str = Depends(get_user_id),
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
@@ -621,7 +687,7 @@ class AgentServer(AppService):
execution_scheduler = self.execution_scheduler_client
return {
"id": execution_scheduler.add_execution_schedule(
graph_id, graph.version, cron, input_data
graph_id, graph.version, cron, input_data, user_id=user_id
)
}
@@ -629,7 +695,7 @@ class AgentServer(AppService):
self,
schedule_id: str,
input_data: dict[Any, Any],
user_id: str = Depends(get_user_id),
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
execution_scheduler = self.execution_scheduler_client
is_enabled = input_data.get("is_enabled", False)
@@ -637,7 +703,7 @@ class AgentServer(AppService):
return {"id": schedule_id}
def get_execution_schedules(
self, graph_id: str, user_id: str = Depends(get_user_id)
self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, str]:
execution_scheduler = self.execution_scheduler_client
return execution_scheduler.get_execution_schedules(graph_id, user_id) # type: ignore

View File

@@ -253,7 +253,9 @@ async def block_autogen_agent():
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
response = await server.agent_server.execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)
result = await wait_execution(
exec_manager=test_manager,
@@ -261,6 +263,7 @@ async def block_autogen_agent():
graph_exec_id=response["id"],
num_execs=10,
timeout=1200,
user_id=test_user.id,
)
print(result)

View File

@@ -155,9 +155,13 @@ async def reddit_marketing_agent():
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
response = await server.agent_server.execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)
result = await wait_execution(exec_man, test_graph.id, response["id"], 13, 120)
result = await wait_execution(
exec_man, test_user.id, test_graph.id, response["id"], 13, 120
)
print(result)

View File

@@ -79,9 +79,13 @@ async def sample_agent():
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
response = await server.agent_server.execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)
result = await wait_execution(exec_man, test_graph.id, response["id"], 4, 10)
result = await wait_execution(
exec_man, test_user.id, test_graph.id, response["id"], 4, 10
)
print(result)

View File

@@ -5,6 +5,7 @@ from autogpt_server.data.block import Block, initialize_blocks
from autogpt_server.data.execution import ExecutionStatus
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server import AgentServer
from autogpt_server.server.server import get_user_id
from autogpt_server.util.service import PyroNameServer
log = print
@@ -17,14 +18,21 @@ class SpinTestServer:
self.agent_server = AgentServer()
self.scheduler = ExecutionScheduler()
@staticmethod
def test_get_user_id():
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
async def __aenter__(self):
self.name_server.__enter__()
self.setup_dependency_overrides()
self.agent_server.__enter__()
self.exec_manager.__enter__()
self.scheduler.__enter__()
await db.connect()
await initialize_blocks()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -35,16 +43,25 @@ class SpinTestServer:
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
self.scheduler.__exit__(exc_type, exc_val, exc_tb)
def setup_dependency_overrides(self):
# Override get_user_id for testing
self.agent_server.set_test_dependency_overrides(
{get_user_id: self.test_get_user_id}
)
async def wait_execution(
exec_manager: ExecutionManager,
user_id: str,
graph_id: str,
graph_exec_id: str,
num_execs: int,
timeout: int = 20,
) -> list:
async def is_execution_completed():
execs = await AgentServer().get_run_execution_results(graph_id, graph_exec_id)
execs = await AgentServer().get_run_execution_results(
graph_id, graph_exec_id, user_id
)
return (
exec_manager.queue.empty()
and len(execs) == num_execs
@@ -58,7 +75,7 @@ async def wait_execution(
for i in range(timeout):
if await is_execution_completed():
return await AgentServer().get_run_execution_results(
graph_id, graph_exec_id
graph_id, graph_exec_id, user_id
)
time.sleep(1)

View File

@@ -0,0 +1,5 @@
-- CreateIndex
CREATE INDEX "User_id_idx" ON "User"("id");
-- CreateIndex
CREATE INDEX "User_email_idx" ON "User"("email");

View File

@@ -0,0 +1,5 @@
-- CreateIndex
CREATE INDEX "User_id_idx" ON "User"("id");
-- CreateIndex
CREATE INDEX "User_email_idx" ON "User"("email");

View File

@@ -22,6 +22,9 @@ model User {
AgentGraphs AgentGraph[]
AgentGraphExecutions AgentGraphExecution[]
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
@@index([id])
@@index([email])
}
// This model describes the Agent Graph/Flow (Multi Agent System).

View File

@@ -21,6 +21,9 @@ model User {
AgentGraphs AgentGraph[]
AgentGraphExecutions AgentGraphExecution[]
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
@@index([id])
@@index([email])
}
// This model describes the Agent Graph/Flow (Multi Agent System).

View File

@@ -1,4 +1,5 @@
import pytest
from prisma.models import User
from autogpt_server.blocks.basic import ObjectLookupBlock, ValueBlock
from autogpt_server.blocks.maths import MathsBlock, Operation
@@ -13,24 +14,30 @@ async def execute_graph(
agent_server: AgentServer,
test_manager: ExecutionManager,
test_graph: graph.Graph,
test_user: User,
input_data: dict,
num_execs: int = 4,
) -> str:
# --- Test adding new executions --- #
response = await agent_server.execute_graph(test_graph.id, input_data)
response = await agent_server.execute_graph(test_graph.id, input_data, test_user.id)
graph_exec_id = response["id"]
# Execution queue should be empty
assert await wait_execution(test_manager, test_graph.id, graph_exec_id, num_execs)
assert await wait_execution(
test_manager, test_user.id, test_graph.id, graph_exec_id, num_execs
)
return graph_exec_id
async def assert_sample_graph_executions(
agent_server: AgentServer, test_graph: graph.Graph, graph_exec_id: str
agent_server: AgentServer,
test_graph: graph.Graph,
test_user: User,
graph_exec_id: str,
):
input = {"input_1": "Hello", "input_2": "World"}
executions = await agent_server.get_run_execution_results(
test_graph.id, graph_exec_id
test_graph.id, graph_exec_id, test_user.id
)
# Executing ValueBlock
@@ -79,9 +86,16 @@ async def test_agent_execution(server):
await graph.create_graph(test_graph, user_id=test_user.id)
data = {"input_1": "Hello", "input_2": "World"}
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, data, 4
server.agent_server,
server.exec_manager,
test_graph,
test_user,
data,
4,
)
await assert_sample_graph_executions(
server.agent_server, test_graph, test_user, graph_exec_id
)
await assert_sample_graph_executions(server.agent_server, test_graph, graph_exec_id)
@pytest.mark.asyncio(scope="session")
@@ -134,11 +148,11 @@ async def test_input_pin_always_waited(server):
test_user = await create_test_user()
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, {}, 3
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
)
executions = await server.agent_server.get_run_execution_results(
test_graph.id, graph_exec_id
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 3
# ObjectLookupBlock should wait for the input pin to be provided,
@@ -215,10 +229,10 @@ async def test_static_input_link_on_graph(server):
test_user = await create_test_user()
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, {}, 8
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
)
executions = await server.agent_server.get_run_execution_results(
test_graph.id, graph_exec_id
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 8
# The last 3 executions will be a+b=4+5=9

View File

@@ -14,18 +14,19 @@ async def test_agent_schedule(server):
scheduler = get_service_client(ExecutionScheduler)
schedules = scheduler.get_execution_schedules(test_graph.id)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0
schedule_id = scheduler.add_execution_schedule(
graph_id=test_graph.id,
user_id=test_user.id,
graph_version=1,
cron="0 0 * * *",
input_data={"input": "data"},
)
assert schedule_id
schedules = scheduler.get_execution_schedules(test_graph.id)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 1
assert schedules[schedule_id] == "0 0 * * *"