feat(rnd): Add user management (#7663)

* add user management support

* fixing tests

* fix formatting and linting

* default user, formatting & linting

* remove unused code

* remove default creation when auth enabled

* add client side calls

* addressing feedback

* prettier

* add defaults for websockets

* linting

---------

Co-authored-by: Swifty <craigswift13@gmail.com>
This commit is contained in:
Aarushi
2024-08-07 14:15:29 +01:00
committed by GitHub
parent 31dbb543a2
commit 1bad26657c
23 changed files with 589 additions and 115 deletions

View File

@@ -1,3 +1,4 @@
import { createClient } from "../supabase/client";
import {
Block,
Graph,
@@ -13,6 +14,7 @@ export default class AutoGPTServerAPI {
private wsUrl: string;
private socket: WebSocket | null = null;
private messageHandlers: { [key: string]: (data: any) => void } = {};
private supabaseClient = createClient();
constructor(
baseUrl: string = process.env.NEXT_PUBLIC_AGPT_SERVER_URL ||
@@ -141,18 +143,23 @@ export default class AutoGPTServerAPI {
console.debug(`${method} ${path} payload:`, payload);
}
const response = await fetch(
this.baseUrl + path,
method != "GET"
? {
method,
headers: {
const token =
(await this.supabaseClient?.auth.getSession())?.data.session
?.access_token || "";
const response = await fetch(this.baseUrl + path, {
method,
headers:
method != "GET"
? {
"Content-Type": "application/json",
Authorization: token ? `Bearer ${token}` : "",
}
: {
Authorization: token ? `Bearer ${token}` : "",
},
body: JSON.stringify(payload),
}
: undefined,
);
body: JSON.stringify(payload),
});
const response_data = await response.json();
if (!response.ok) {

View File

@@ -12,7 +12,7 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
:raises ValueError: If the token is invalid or expired
"""
try:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM], audience="authenticated")
return payload
except jwt.ExpiredSignatureError:
raise ValueError("Token has expired")

View File

@@ -272,7 +272,8 @@ async def get_node(node_id: str) -> Node | None:
async def get_graphs_meta(
filter_by: Literal["active", "template"] | None = "active"
filter_by: Literal["active", "template"] | None = "active",
user_id: str | None = None,
) -> list[GraphMeta]:
"""
Retrieves graph metadata objects.
@@ -291,6 +292,9 @@ async def get_graphs_meta(
elif filter_by == "template":
where_clause["isTemplate"] = True
if user_id and filter_by != "template":
where_clause["userId"] = user_id
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
distinct=["id"],
@@ -304,7 +308,10 @@ async def get_graphs_meta(
async def get_graph(
graph_id: str, version: int | None = None, template: bool = False
graph_id: str,
version: int | None = None,
template: bool = False,
user_id: str | None = None,
) -> Graph | None:
"""
Retrieves a graph from the DB.
@@ -322,6 +329,9 @@ async def get_graph(
elif not template:
where_clause["isActive"] = True
if user_id and not template:
where_clause["userId"] = user_id
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
@@ -330,10 +340,23 @@ async def get_graph(
return Graph.from_db(graph) if graph else None
async def set_graph_active_version(graph_id: str, version: int) -> None:
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
# Check if the graph belongs to the user
graph = await AgentGraph.prisma().find_first(
where={
"id": graph_id,
"version": version,
"userId": user_id,
}
)
if not graph:
raise Exception(f"Graph #{graph_id} v{version} not found or not owned by user")
updated_graph = await AgentGraph.prisma().update(
data={"isActive": True},
where={"graphVersionId": {"id": graph_id, "version": version}},
where={
"graphVersionId": {"id": graph_id, "version": version},
},
)
if not updated_graph:
raise Exception(f"Graph #{graph_id} v{version} not found")
@@ -341,13 +364,15 @@ async def set_graph_active_version(graph_id: str, version: int) -> None:
# Deactivate all other versions
await AgentGraph.prisma().update_many(
data={"isActive": False},
where={"id": graph_id, "version": {"not": version}},
where={"id": graph_id, "version": {"not": version}, "userId": user_id},
)
async def get_graph_all_versions(graph_id: str) -> list[Graph]:
async def get_graph_all_versions(
graph_id: str, user_id: str | None = None
) -> list[Graph]:
graph_versions = await AgentGraph.prisma().find_many(
where={"id": graph_id},
where={"id": graph_id, "userId": user_id},
order={"version": "desc"},
include=AGENT_GRAPH_INCLUDE,
)
@@ -358,17 +383,19 @@ async def get_graph_all_versions(graph_id: str) -> list[Graph]:
return [Graph.from_db(graph) for graph in graph_versions]
async def create_graph(graph: Graph) -> Graph:
async def create_graph(graph: Graph, user_id: str | None) -> Graph:
async with transaction() as tx:
await __create_graph(tx, graph)
await __create_graph(tx, graph, user_id)
if created_graph := await get_graph(graph.id, graph.version, graph.is_template):
if created_graph := await get_graph(
graph.id, graph.version, graph.is_template, user_id=user_id
):
return created_graph
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def __create_graph(tx, graph: Graph):
async def __create_graph(tx, graph: Graph, user_id: str | None):
await AgentGraph.prisma(tx).create(
data={
"id": graph.id,
@@ -377,6 +404,7 @@ async def __create_graph(tx, graph: Graph):
"description": graph.description,
"isTemplate": graph.is_template,
"isActive": graph.is_active,
"userId": user_id,
}
)
@@ -391,6 +419,7 @@ async def __create_graph(tx, graph: Graph):
"description": f"Sub-Graph of {graph.id}",
"isTemplate": graph.is_template,
"isActive": graph.is_active,
"userId": user_id,
}
)
for subgraph_id in graph.subgraphs
@@ -453,5 +482,5 @@ async def import_packaged_templates() -> None:
exists := next((t for t in templates_in_db if t.id == template.id), None)
) and exists.version >= template.version:
continue
await create_graph(template)
await create_graph(template, None)
print(f"Loaded template '{template.name}' ({template.id})")

View File

@@ -10,6 +10,7 @@ from autogpt_server.util import json
class ExecutionSchedule(BaseDbModel):
graph_id: str
user_id: str
graph_version: int
schedule: str
is_enabled: bool
@@ -25,6 +26,7 @@ class ExecutionSchedule(BaseDbModel):
return ExecutionSchedule(
id=schedule.id,
graph_id=schedule.agentGraphId,
user_id=schedule.userId,
graph_version=schedule.agentGraphVersion,
schedule=schedule.schedule,
is_enabled=schedule.isEnabled,
@@ -47,11 +49,12 @@ async def disable_schedule(schedule_id: str):
)
async def get_schedules(graph_id: str) -> list[ExecutionSchedule]:
async def get_schedules(graph_id: str, user_id: str) -> list[ExecutionSchedule]:
query = AgentGraphExecutionSchedule.prisma().find_many(
where={
"isEnabled": True,
"agentGraphId": graph_id,
"userId": user_id,
},
)
return [ExecutionSchedule.from_db(schedule) for schedule in await query]
@@ -71,7 +74,7 @@ async def add_schedule(schedule: ExecutionSchedule) -> ExecutionSchedule:
return ExecutionSchedule.from_db(obj)
async def update_schedule(schedule_id: str, is_enabled: bool):
async def update_schedule(schedule_id: str, is_enabled: bool, user_id: str):
await AgentGraphExecutionSchedule.prisma().update(
where={"id": schedule_id}, data={"isEnabled": is_enabled}
)

View File

@@ -0,0 +1,40 @@
from typing import Optional
from prisma.models import User
from autogpt_server.data.db import prisma
async def get_or_create_user(user_data: dict) -> User:
user = await prisma.user.find_unique(where={"id": user_data["sub"]})
if not user:
user = await prisma.user.create(
data={
"id": user_data["sub"],
"email": user_data["email"],
"name": user_data.get("user_metadata", {}).get("name"),
}
)
return User.model_validate(user)
async def get_user_by_id(user_id: str) -> Optional[User]:
user = await prisma.user.find_unique(where={"id": user_id})
return User.model_validate(user) if user else None
async def create_default_user(enable_auth: str) -> Optional[User]:
if not enable_auth.lower() == "true":
user = await prisma.user.find_unique(
where={"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a"}
)
if not user:
user = await prisma.user.create(
data={
"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"email": "default@example.com",
"name": "Default User",
}
)
return User.model_validate(user)
return None

View File

@@ -416,8 +416,10 @@ class ExecutionManager(AppService):
return get_agent_server_client()
@expose
def add_execution(self, graph_id: str, data: BlockInput) -> dict[Any, Any]:
graph: Graph | None = self.run_and_wait(get_graph(graph_id))
def add_execution(
self, graph_id: str, data: BlockInput, user_id: str
) -> dict[Any, Any]:
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
graph.validate_graph(for_run=True)

View File

@@ -62,16 +62,22 @@ class ExecutionScheduler(AppService):
logger.exception(f"Error executing graph {graph_id}: {e}")
@expose
def update_schedule(self, schedule_id: str, is_enabled: bool) -> str:
self.run_and_wait(model.update_schedule(schedule_id, is_enabled))
def update_schedule(self, schedule_id: str, is_enabled: bool, user_id: str) -> str:
self.run_and_wait(model.update_schedule(schedule_id, is_enabled, user_id))
return schedule_id
@expose
def add_execution_schedule(
self, graph_id: str, graph_version: int, cron: str, input_data: BlockInput
self,
graph_id: str,
graph_version: int,
cron: str,
input_data: BlockInput,
user_id: str,
) -> str:
schedule = model.ExecutionSchedule(
graph_id=graph_id,
user_id=user_id,
graph_version=graph_version,
schedule=cron,
input_data=input_data,
@@ -79,7 +85,7 @@ class ExecutionScheduler(AppService):
return self.run_and_wait(model.add_schedule(schedule)).id
@expose
def get_execution_schedules(self, graph_id: str) -> dict[str, str]:
query = model.get_schedules(graph_id)
def get_execution_schedules(self, graph_id: str, user_id: str) -> dict[str, str]:
query = model.get_schedules(graph_id, user_id=user_id)
schedules: list[model.ExecutionSchedule] = self.run_and_wait(query)
return {v.id: v.schedule for v in schedules}

View File

@@ -1,9 +1,12 @@
import asyncio
import inspect
from collections import defaultdict
from contextlib import asynccontextmanager
from functools import wraps
from typing import Annotated, Any, Dict
import uvicorn
from autogpt_libs.auth.jwt_utils import parse_jwt_token
from autogpt_libs.auth.middleware import auth_middleware
from fastapi import (
APIRouter,
@@ -20,12 +23,14 @@ from fastapi.responses import JSONResponse
import autogpt_server.server.ws_api
from autogpt_server.data import block, db
from autogpt_server.data import graph as graph_db
from autogpt_server.data import user as user_db
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
from autogpt_server.data.execution import (
ExecutionResult,
get_execution_results,
list_executions,
)
from autogpt_server.data.user import get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server.conn_manager import ConnectionManager
from autogpt_server.server.model import (
@@ -38,12 +43,26 @@ from autogpt_server.util.lock import KeyedMutex
from autogpt_server.util.service import AppService, expose, get_service_client
from autogpt_server.util.settings import Settings
settings = Settings()
def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
if not payload:
# This handles the case when authentication is disabled
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
user_id = payload.get("sub")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in token")
return user_id
class AgentServer(AppService):
event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
manager = ConnectionManager()
mutex = KeyedMutex()
use_db = False
_test_dependency_overrides = {}
async def event_broadcaster(self):
while True:
@@ -55,6 +74,7 @@ class AgentServer(AppService):
await db.connect()
await block.initialize_blocks()
await graph_db.import_packaged_templates()
await user_db.create_default_user(settings.config.enable_auth)
asyncio.create_task(self.event_broadcaster())
yield
await db.disconnect()
@@ -71,6 +91,9 @@ class AgentServer(AppService):
lifespan=self.lifespan,
)
if self._test_dependency_overrides:
app.dependency_overrides.update(self._test_dependency_overrides)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
@@ -83,6 +106,12 @@ class AgentServer(AppService):
router = APIRouter(prefix="/api")
router.dependencies.append(Depends(auth_middleware))
router.add_api_route(
path="/auth/user",
endpoint=self.get_or_create_user_route,
methods=["POST"],
)
router.add_api_route(
path="/blocks",
endpoint=self.get_graph_blocks, # type: ignore
@@ -200,6 +229,35 @@ class AgentServer(AppService):
uvicorn.run(app, host="0.0.0.0", port=8000)
def set_test_dependency_overrides(self, overrides: dict):
self._test_dependency_overrides = overrides
def _apply_overrides_to_methods(self):
for attr_name in dir(self):
attr = getattr(self, attr_name)
if callable(attr) and hasattr(attr, "__annotations__"):
setattr(self, attr_name, self._override_method(attr))
# TODO: fix this with some proper refactoring of the server
def _override_method(self, method):
@wraps(method)
async def wrapper(*args, **kwargs):
sig = inspect.signature(method)
for param_name, param in sig.parameters.items():
if param.annotation is inspect.Parameter.empty:
continue
if isinstance(param.annotation, Depends) or ( # type: ignore
isinstance(param.annotation, type) and issubclass(param.annotation, Depends) # type: ignore
):
dependency = param.annotation.dependency if isinstance(param.annotation, Depends) else param.annotation # type: ignore
if dependency in self._test_dependency_overrides:
kwargs[param_name] = self._test_dependency_overrides[
dependency
]()
return await method(*args, **kwargs)
return wrapper
@property
def execution_manager_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager)
@@ -218,7 +276,30 @@ class AgentServer(AppService):
status_code=500,
)
async def authenticate_websocket(self, websocket: WebSocket) -> str:
if settings.config.enable_auth.lower() == "true":
token = websocket.query_params.get("token")
if not token:
await websocket.close(code=4001, reason="Missing authentication token")
return ""
try:
payload = parse_jwt_token(token)
user_id = payload.get("sub")
if not user_id:
await websocket.close(code=4002, reason="Invalid token")
return ""
return user_id
except ValueError:
await websocket.close(code=4003, reason="Invalid token")
return ""
else:
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
async def websocket_router(self, websocket: WebSocket):
user_id = await self.authenticate_websocket(websocket)
if not user_id:
return
await self.manager.connect(websocket)
try:
while True:
@@ -257,7 +338,7 @@ class AgentServer(AppService):
).model_dump_json()
)
elif message.method == Methods.GET_GRAPHS:
data = await self.get_graphs()
data = await self.get_graphs(user_id=user_id)
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPHS,
@@ -268,7 +349,9 @@ class AgentServer(AppService):
print("Get graphs request received")
elif message.method == Methods.GET_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.get_graph(message.data["graph_id"])
data = await self.get_graph(
message.data["graph_id"], user_id=user_id
)
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPH,
@@ -280,7 +363,7 @@ class AgentServer(AppService):
elif message.method == Methods.CREATE_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
create_graph = CreateGraph.model_validate(message.data)
data = await self.create_new_graph(create_graph)
data = await self.create_new_graph(create_graph, user_id=user_id)
await websocket.send_text(
WsMessage(
method=Methods.CREATE_GRAPH,
@@ -293,7 +376,7 @@ class AgentServer(AppService):
elif message.method == Methods.RUN_GRAPH:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.execute_graph(
message.data["graph_id"], message.data["data"]
message.data["graph_id"], message.data["data"], user_id=user_id
)
await websocket.send_text(
WsMessage(
@@ -306,7 +389,9 @@ class AgentServer(AppService):
print("Run graph request received")
elif message.method == Methods.GET_GRAPH_RUNS:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = await self.list_graph_runs(message.data["graph_id"])
data = await self.list_graph_runs(
message.data["graph_id"], user_id=user_id
)
await websocket.send_text(
WsMessage(
method=Methods.GET_GRAPH_RUNS,
@@ -322,6 +407,7 @@ class AgentServer(AppService):
message.data["graph_id"],
message.data["cron"],
message.data["data"],
user_id=user_id,
)
await websocket.send_text(
WsMessage(
@@ -334,7 +420,9 @@ class AgentServer(AppService):
print("Create scheduled run request received")
elif message.method == Methods.GET_SCHEDULED_RUNS:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.get_execution_schedules(message.data["graph_id"])
data = self.get_execution_schedules(
message.data["graph_id"], user_id=user_id
)
await websocket.send_text(
WsMessage(
method=Methods.GET_SCHEDULED_RUNS,
@@ -346,7 +434,7 @@ class AgentServer(AppService):
elif message.method == Methods.UPDATE_SCHEDULED_RUN:
assert isinstance(message.data, dict), "Data must be a dictionary"
data = self.update_schedule(
message.data["schedule_id"], message.data
message.data["schedule_id"], message.data, user_id=user_id
)
await websocket.send_text(
WsMessage(
@@ -385,6 +473,11 @@ class AgentServer(AppService):
self.manager.disconnect(websocket)
print("Client Disconnected")
@classmethod
async def get_or_create_user_route(cls, user_data: dict = Depends(auth_middleware)):
user = await get_or_create_user(user_data)
return user.model_dump()
@classmethod
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
return [v.to_dict() for v in block.get_blocks().values()] # type: ignore
@@ -403,8 +496,10 @@ class AgentServer(AppService):
return output
@classmethod
async def get_graphs(cls) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(filter_by="active")
async def get_graphs(
cls, user_id: Annotated[str, Depends(get_user_id)]
) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(filter_by="active", user_id=user_id)
@classmethod
async def get_templates(cls) -> list[graph_db.GraphMeta]:
@@ -412,9 +507,12 @@ class AgentServer(AppService):
@classmethod
async def get_graph(
cls, graph_id: str, version: int | None = None
cls,
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
version: int | None = None,
) -> graph_db.Graph:
graph = await graph_db.get_graph(graph_id, version)
graph = await graph_db.get_graph(graph_id, version, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graph
@@ -431,30 +529,39 @@ class AgentServer(AppService):
return graph
@classmethod
async def get_graph_all_versions(cls, graph_id: str) -> list[graph_db.Graph]:
graphs = await graph_db.get_graph_all_versions(graph_id)
async def get_graph_all_versions(
cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[graph_db.Graph]:
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graphs
@classmethod
async def create_new_graph(cls, create_graph: CreateGraph) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=False)
async def create_new_graph(
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=False, user_id=user_id)
@classmethod
async def create_new_template(cls, create_graph: CreateGraph) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=True)
async def create_new_template(
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=True, user_id=user_id)
@classmethod
async def create_graph(
cls, create_graph: CreateGraph, is_template: bool
cls, create_graph: CreateGraph, is_template: bool, user_id: str
) -> graph_db.Graph:
if create_graph.graph:
graph = create_graph.graph
elif create_graph.template_id:
# Create a new graph from a template
graph = await graph_db.get_graph(
create_graph.template_id, create_graph.template_version, template=True
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
raise HTTPException(
@@ -470,16 +577,23 @@ class AgentServer(AppService):
graph.is_active = not is_template
graph.reassign_ids(reassign_graph_id=True)
return await graph_db.create_graph(graph)
return await graph_db.create_graph(graph, user_id=user_id)
@classmethod
async def update_graph(cls, graph_id: str, graph: graph_db.Graph) -> graph_db.Graph:
async def update_graph(
cls,
graph_id: str,
graph: graph_db.Graph,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.Graph:
# Sanity check
if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI")
# Determine new version
existing_versions = await graph_db.get_graph_all_versions(graph_id)
existing_versions = await graph_db.get_graph_all_versions(
graph_id, user_id=user_id
)
if not existing_versions:
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
@@ -495,43 +609,56 @@ class AgentServer(AppService):
graph.is_active = not graph.is_template
graph.reassign_ids()
new_graph_version = await graph_db.create_graph(graph)
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
if new_graph_version.is_active:
# Ensure new version is the only active version
await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
)
return new_graph_version
@classmethod
async def set_graph_active_version(
cls, graph_id: str, request_body: SetGraphActiveVersion
cls,
graph_id: str,
request_body: SetGraphActiveVersion,
user_id: Annotated[str, Depends(get_user_id)],
):
new_active_version = request_body.active_graph_version
if not await graph_db.get_graph(graph_id, new_active_version):
if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id):
raise HTTPException(
404, f"Graph #{graph_id} v{new_active_version} not found"
)
await graph_db.set_graph_active_version(
graph_id=graph_id, version=request_body.active_graph_version
graph_id=graph_id,
version=request_body.active_graph_version,
user_id=user_id,
)
async def execute_graph(
self, graph_id: str, node_input: dict[Any, Any]
self,
graph_id: str,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
try:
return self.execution_manager_client.add_execution(graph_id, node_input)
return self.execution_manager_client.add_execution(
graph_id, node_input, user_id=user_id
)
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@classmethod
async def list_graph_runs(
cls, graph_id: str, graph_version: int | None = None
cls,
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
graph_version: int | None = None,
) -> list[str]:
graph = await graph_db.get_graph(graph_id, graph_version)
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
if not graph:
rev = "" if graph_version is None else f" v{graph_version}"
raise HTTPException(
@@ -542,38 +669,47 @@ class AgentServer(AppService):
@classmethod
async def get_run_execution_results(
cls, graph_id: str, run_id: str
cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[ExecutionResult]:
graph = await graph_db.get_graph(graph_id)
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return await get_execution_results(run_id)
async def create_schedule(
self, graph_id: str, cron: str, input_data: dict[Any, Any]
self,
graph_id: str,
cron: str,
input_data: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
graph = await graph_db.get_graph(graph_id)
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
execution_scheduler = self.execution_scheduler_client
return {
"id": execution_scheduler.add_execution_schedule(
graph_id, graph.version, cron, input_data
graph_id, graph.version, cron, input_data, user_id=user_id
)
}
def update_schedule(
self, schedule_id: str, input_data: dict[Any, Any]
self,
schedule_id: str,
input_data: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
execution_scheduler = self.execution_scheduler_client
is_enabled = input_data.get("is_enabled", False)
execution_scheduler.update_schedule(schedule_id, is_enabled) # type: ignore
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id) # type: ignore
return {"id": schedule_id}
def get_execution_schedules(self, graph_id: str) -> dict[str, str]:
def get_execution_schedules(
self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> dict[str, str]:
execution_scheduler = self.execution_scheduler_client
return execution_scheduler.get_execution_schedules(graph_id) # type: ignore
return execution_scheduler.get_execution_schedules(graph_id, user_id) # type: ignore
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):

View File

@@ -1,11 +1,14 @@
from pathlib import Path
from prisma.models import User
from autogpt_server.blocks.basic import ValueBlock
from autogpt_server.blocks.block import BlockInstallationBlock
from autogpt_server.blocks.http import HttpRequestBlock
from autogpt_server.blocks.llm import TextLlmCallBlock
from autogpt_server.blocks.text import TextFormatterBlock, TextParserBlock
from autogpt_server.data.graph import Graph, Link, Node, create_graph
from autogpt_server.data.user import get_or_create_user
from autogpt_server.util.test import SpinTestServer, wait_execution
sample_block_modules = {
@@ -23,6 +26,16 @@ for module, description in sample_block_modules.items():
sample_block_codes[module] = f"[Example: {description}]\n{code}"
async def create_test_user() -> User:
test_user_data = {
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
"email": "testuser@example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
return user
def create_test_graph() -> Graph:
"""
ValueBlock (input)
@@ -237,9 +250,12 @@ Here are a couple of sample of the Block class implementation:
async def block_autogen_agent():
async with SpinTestServer() as server:
test_manager = server.exec_manager
test_graph = await create_graph(create_test_graph())
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
response = await server.agent_server.execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)
result = await wait_execution(
exec_manager=test_manager,
@@ -247,6 +263,7 @@ async def block_autogen_agent():
graph_exec_id=response["id"],
num_execs=10,
timeout=1200,
user_id=test_user.id,
)
print(result)

View File

@@ -1,7 +1,10 @@
from prisma.models import User
from autogpt_server.blocks.llm import ObjectLlmCallBlock
from autogpt_server.blocks.reddit import RedditGetPostsBlock, RedditPostCommentBlock
from autogpt_server.blocks.text import TextFormatterBlock, TextMatcherBlock
from autogpt_server.data.graph import Graph, Link, Node, create_graph
from autogpt_server.data.user import get_or_create_user
from autogpt_server.util.test import SpinTestServer, wait_execution
@@ -136,14 +139,29 @@ Make sure to only comment on a relevant post.
return test_graph
async def create_test_user() -> User:
test_user_data = {
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
"email": "testuser@example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
return user
async def reddit_marketing_agent():
async with SpinTestServer() as server:
exec_man = server.exec_manager
test_graph = await create_graph(create_test_graph())
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
response = await server.agent_server.execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)
result = await wait_execution(exec_man, test_graph.id, response["id"], 13, 120)
result = await wait_execution(
exec_man, test_user.id, test_graph.id, response["id"], 13, 120
)
print(result)

View File

@@ -1,10 +1,23 @@
from prisma.models import User
from autogpt_server.blocks.basic import InputBlock, PrintingBlock
from autogpt_server.blocks.text import TextFormatterBlock
from autogpt_server.data import graph
from autogpt_server.data.graph import create_graph
from autogpt_server.data.user import get_or_create_user
from autogpt_server.util.test import SpinTestServer, wait_execution
async def create_test_user() -> User:
test_user_data = {
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
"email": "testuser@example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
return user
def create_test_graph() -> graph.Graph:
"""
ValueBlock
@@ -63,11 +76,16 @@ def create_test_graph() -> graph.Graph:
async def sample_agent():
async with SpinTestServer() as server:
exec_man = server.exec_manager
test_graph = await create_graph(create_test_graph())
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
response = await server.agent_server.execute_graph(
test_graph.id, input_data, test_user.id
)
print(response)
result = await wait_execution(exec_man, test_graph.id, response["id"], 4, 10)
result = await wait_execution(
exec_man, test_user.id, test_graph.id, response["id"], 4, 10
)
print(result)

View File

@@ -57,6 +57,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default="localhost",
description="The default hostname of the Pyro server.",
)
enable_auth: str = Field(
default="false",
description="If authentication is enabled or not",
)
# Add more configuration fields as needed
model_config = SettingsConfigDict(

View File

@@ -5,6 +5,7 @@ from autogpt_server.data.block import Block, initialize_blocks
from autogpt_server.data.execution import ExecutionStatus
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server import AgentServer
from autogpt_server.server.server import get_user_id
from autogpt_server.util.service import PyroNameServer
log = print
@@ -17,14 +18,21 @@ class SpinTestServer:
self.agent_server = AgentServer()
self.scheduler = ExecutionScheduler()
@staticmethod
def test_get_user_id():
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
async def __aenter__(self):
self.name_server.__enter__()
self.setup_dependency_overrides()
self.agent_server.__enter__()
self.exec_manager.__enter__()
self.scheduler.__enter__()
await db.connect()
await initialize_blocks()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -35,16 +43,25 @@ class SpinTestServer:
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
self.scheduler.__exit__(exc_type, exc_val, exc_tb)
def setup_dependency_overrides(self):
# Override get_user_id for testing
self.agent_server.set_test_dependency_overrides(
{get_user_id: self.test_get_user_id}
)
async def wait_execution(
exec_manager: ExecutionManager,
user_id: str,
graph_id: str,
graph_exec_id: str,
num_execs: int,
timeout: int = 20,
) -> list:
async def is_execution_completed():
execs = await AgentServer().get_run_execution_results(graph_id, graph_exec_id)
execs = await AgentServer().get_run_execution_results(
graph_id, graph_exec_id, user_id
)
return (
exec_manager.queue.empty()
and len(execs) == num_execs
@@ -58,7 +75,7 @@ async def wait_execution(
for i in range(timeout):
if await is_execution_completed():
return await AgentServer().get_run_execution_results(
graph_id, graph_exec_id
graph_id, graph_exec_id, user_id
)
time.sleep(1)

View File

@@ -0,0 +1,60 @@
-- CreateTable
CREATE TABLE "User" (
"id" TEXT NOT NULL PRIMARY KEY,
"email" TEXT NOT NULL,
"name" TEXT,
"createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" DATETIME NOT NULL
);
-- RedefineTables
PRAGMA foreign_keys=OFF;
CREATE TABLE "new_AgentGraph" (
"id" TEXT NOT NULL,
"version" INTEGER NOT NULL DEFAULT 1,
"name" TEXT,
"description" TEXT,
"isActive" BOOLEAN NOT NULL DEFAULT true,
"isTemplate" BOOLEAN NOT NULL DEFAULT false,
"userId" TEXT,
"agentGraphParentId" TEXT,
PRIMARY KEY ("id", "version"),
CONSTRAINT "AgentGraph_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE,
CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraph" ("agentGraphParentId", "description", "id", "isActive", "isTemplate", "name", "version") SELECT "agentGraphParentId", "description", "id", "isActive", "isTemplate", "name", "version" FROM "AgentGraph";
DROP TABLE "AgentGraph";
ALTER TABLE "new_AgentGraph" RENAME TO "AgentGraph";
CREATE TABLE "new_AgentGraphExecution" (
"id" TEXT NOT NULL PRIMARY KEY,
"agentGraphId" TEXT NOT NULL,
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
"userId" TEXT,
CONSTRAINT "AgentGraphExecution_agentGraphId_agentGraphVersion_fkey" FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "AgentGraphExecution_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraphExecution" ("agentGraphId", "agentGraphVersion", "id") SELECT "agentGraphId", "agentGraphVersion", "id" FROM "AgentGraphExecution";
DROP TABLE "AgentGraphExecution";
ALTER TABLE "new_AgentGraphExecution" RENAME TO "AgentGraphExecution";
CREATE TABLE "new_AgentGraphExecutionSchedule" (
"id" TEXT NOT NULL PRIMARY KEY,
"agentGraphId" TEXT NOT NULL,
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
"schedule" TEXT NOT NULL,
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
"inputData" TEXT NOT NULL,
"lastUpdated" DATETIME NOT NULL,
"userId" TEXT,
CONSTRAINT "AgentGraphExecutionSchedule_agentGraphId_agentGraphVersion_fkey" FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraphExecutionSchedule" ("agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule") SELECT "agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule" FROM "AgentGraphExecutionSchedule";
DROP TABLE "AgentGraphExecutionSchedule";
ALTER TABLE "new_AgentGraphExecutionSchedule" RENAME TO "AgentGraphExecutionSchedule";
CREATE INDEX "AgentGraphExecutionSchedule_isEnabled_idx" ON "AgentGraphExecutionSchedule"("isEnabled");
PRAGMA foreign_key_check;
PRAGMA foreign_keys=ON;
-- CreateIndex
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");

View File

@@ -0,0 +1,5 @@
-- CreateIndex
CREATE INDEX "User_id_idx" ON "User"("id");
-- CreateIndex
CREATE INDEX "User_email_idx" ON "User"("email");

View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "agpt"
@@ -25,7 +25,7 @@ requests = "*"
sentry-sdk = "^1.40.4"
[package.extras]
benchmark = ["agbenchmark @ file:///home/bently/Desktop/autogpt-ui/AutoGPT/benchmark"]
benchmark = ["agbenchmark"]
[package.source]
type = "directory"
@@ -329,7 +329,7 @@ watchdog = "4.0.0"
webdriver-manager = "^4.0.1"
[package.extras]
benchmark = ["agbenchmark @ file:///home/bently/Desktop/autogpt-ui/AutoGPT/benchmark"]
benchmark = ["agbenchmark"]
[package.source]
type = "directory"
@@ -342,7 +342,7 @@ description = "Shared libraries across NextGen AutoGPT"
optional = false
python-versions = ">=3.10,<4.0"
files = []
develop = true
develop = false
[package.dependencies]
pyjwt = "^2.8.0"
@@ -4212,19 +4212,19 @@ windows-terminal = ["colorama (>=0.4.6)"]
[[package]]
name = "pyjwt"
version = "2.8.0"
version = "2.9.0"
description = "JSON Web Token implementation in Python"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
files = [
{file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
{file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
{file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"},
{file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"},
]
[package.extras]
crypto = ["cryptography (>=3.4.0)"]
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"]
docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"]
tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
[[package]]
@@ -6419,4 +6419,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "9991857e7076d3bfcbae7af6c2cec54dc943167a3adceb5a0ebf74d80c05778f"
content-hash = "003a4c89682abbf72c67631367f57e56d91d72b44f95e972b2326440199045e7"

View File

@@ -0,0 +1,31 @@
-- AlterTable
ALTER TABLE "AgentGraph" ADD COLUMN "userId" TEXT;
-- AlterTable
ALTER TABLE "AgentGraphExecution" ADD COLUMN "userId" TEXT;
-- AlterTable
ALTER TABLE "AgentGraphExecutionSchedule" ADD COLUMN "userId" TEXT;
-- CreateTable
CREATE TABLE "User" (
"id" TEXT NOT NULL,
"email" TEXT NOT NULL,
"name" TEXT,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
-- AddForeignKey
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "AgentGraphExecution" ADD CONSTRAINT "AgentGraphExecution_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "AgentGraphExecutionSchedule" ADD CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -0,0 +1,5 @@
-- CreateIndex
CREATE INDEX "User_id_idx" ON "User"("id");
-- CreateIndex
CREATE INDEX "User_email_idx" ON "User"("email");

View File

@@ -10,6 +10,23 @@ generator client {
interface = "asyncio"
}
// User model to mirror Auth provider users
model User {
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Relations
AgentGraphs AgentGraph[]
AgentGraphExecutions AgentGraphExecution[]
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
@@index([id])
@@index([email])
}
// This model describes the Agent Graph/Flow (Multi Agent System).
model AgentGraph {
id String @default(uuid())
@@ -20,6 +37,10 @@ model AgentGraph {
isActive Boolean @default(true)
isTemplate Boolean @default(false)
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
@@ -99,6 +120,10 @@ model AgentGraphExecution {
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
AgentNodeExecutions AgentNodeExecution[]
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
}
// This model describes the execution of an AgentNode.
@@ -158,5 +183,9 @@ model AgentGraphExecutionSchedule {
// default and set the value on each update, lastUpdated field has no time zone.
lastUpdated DateTime @updatedAt
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
@@index([isEnabled])
}

View File

@@ -41,7 +41,7 @@ python-dotenv = "^1.0.1"
expiringdict = "^1.2.2"
discord-py = "^2.4.0"
autogpt-libs = { path = "../autogpt_libs", develop = true }
autogpt-libs = {path = "../autogpt_libs"}
[tool.poetry.group.dev.dependencies]
cx-freeze = { git = "https://github.com/ntindle/cx_Freeze.git", rev = "main", develop = true }
poethepoet = "^0.26.1"

View File

@@ -9,6 +9,23 @@ generator client {
interface = "asyncio"
}
// User model to mirror Auth provider users
model User {
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Relations
AgentGraphs AgentGraph[]
AgentGraphExecutions AgentGraphExecution[]
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
@@index([id])
@@index([email])
}
// This model describes the Agent Graph/Flow (Multi Agent System).
model AgentGraph {
id String @default(uuid())
@@ -19,6 +36,10 @@ model AgentGraph {
isActive Boolean @default(true)
isTemplate Boolean @default(false)
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
@@ -98,6 +119,10 @@ model AgentGraphExecution {
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
AgentNodeExecutions AgentNodeExecution[]
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
}
// This model describes the execution of an AgentNode.
@@ -157,5 +182,9 @@ model AgentGraphExecutionSchedule {
// default and set the value on each update, lastUpdated field has no time zone.
lastUpdated DateTime @updatedAt
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
@@index([isEnabled])
}
}

View File

@@ -1,11 +1,12 @@
import pytest
from prisma.models import User
from autogpt_server.blocks.basic import ObjectLookupBlock, ValueBlock
from autogpt_server.blocks.maths import MathsBlock, Operation
from autogpt_server.data import execution, graph
from autogpt_server.executor import ExecutionManager
from autogpt_server.server import AgentServer
from autogpt_server.usecases.sample import create_test_graph
from autogpt_server.usecases.sample import create_test_graph, create_test_user
from autogpt_server.util.test import wait_execution
@@ -13,24 +14,30 @@ async def execute_graph(
agent_server: AgentServer,
test_manager: ExecutionManager,
test_graph: graph.Graph,
test_user: User,
input_data: dict,
num_execs: int = 4,
) -> str:
# --- Test adding new executions --- #
response = await agent_server.execute_graph(test_graph.id, input_data)
response = await agent_server.execute_graph(test_graph.id, input_data, test_user.id)
graph_exec_id = response["id"]
# Execution queue should be empty
assert await wait_execution(test_manager, test_graph.id, graph_exec_id, num_execs)
assert await wait_execution(
test_manager, test_user.id, test_graph.id, graph_exec_id, num_execs
)
return graph_exec_id
async def assert_sample_graph_executions(
agent_server: AgentServer, test_graph: graph.Graph, graph_exec_id: str
agent_server: AgentServer,
test_graph: graph.Graph,
test_user: User,
graph_exec_id: str,
):
input = {"input_1": "Hello", "input_2": "World"}
executions = await agent_server.get_run_execution_results(
test_graph.id, graph_exec_id
test_graph.id, graph_exec_id, test_user.id
)
# Executing ValueBlock
@@ -75,12 +82,20 @@ async def assert_sample_graph_executions(
@pytest.mark.asyncio(scope="session")
async def test_agent_execution(server):
test_graph = create_test_graph()
await graph.create_graph(test_graph)
test_user = await create_test_user()
await graph.create_graph(test_graph, user_id=test_user.id)
data = {"input_1": "Hello", "input_2": "World"}
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, data, 4
server.agent_server,
server.exec_manager,
test_graph,
test_user,
data,
4,
)
await assert_sample_graph_executions(
server.agent_server, test_graph, test_user, graph_exec_id
)
await assert_sample_graph_executions(server.agent_server, test_graph, graph_exec_id)
@pytest.mark.asyncio(scope="session")
@@ -130,14 +145,14 @@ async def test_input_pin_always_waited(server):
nodes=nodes,
links=links,
)
test_graph = await graph.create_graph(test_graph)
test_user = await create_test_user()
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, {}, 3
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
)
executions = await server.agent_server.get_run_execution_results(
test_graph.id, graph_exec_id
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 3
# ObjectLookupBlock should wait for the input pin to be provided,
@@ -211,13 +226,13 @@ async def test_static_input_link_on_graph(server):
nodes=nodes,
links=links,
)
test_graph = await graph.create_graph(test_graph)
test_user = await create_test_user()
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, {}, 8
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
)
executions = await server.agent_server.get_run_execution_results(
test_graph.id, graph_exec_id
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 8
# The last 3 executions will be a+b=4+5=9

View File

@@ -2,29 +2,32 @@ import pytest
from autogpt_server.data import db, graph
from autogpt_server.executor import ExecutionScheduler
from autogpt_server.usecases.sample import create_test_graph
from autogpt_server.usecases.sample import create_test_graph, create_test_user
from autogpt_server.util.service import get_service_client
@pytest.mark.skip(reason="flakey test, needs to be investigated")
@pytest.mark.asyncio(scope="session")
async def test_agent_schedule(server):
await db.connect()
test_graph = await graph.create_graph(create_test_graph())
test_user = await create_test_user()
test_graph = await graph.create_graph(create_test_graph(), user_id=test_user.id)
scheduler = get_service_client(ExecutionScheduler)
schedules = scheduler.get_execution_schedules(test_graph.id)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0
schedule_id = scheduler.add_execution_schedule(
graph_id=test_graph.id,
user_id=test_user.id,
graph_version=1,
cron="0 0 * * *",
input_data={"input": "data"},
)
assert schedule_id
schedules = scheduler.get_execution_schedules(test_graph.id)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 1
assert schedules[schedule_id] == "0 0 * * *"