mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(rnd): Add user management (#7663)
* add user management support * fixing tests * fix formatting and linting * default user, formatting & linting * remove unused code * remove default creation when auth enabled * add client side calls * addressing feedback * prettier * add defaults for websockets * linting --------- Co-authored-by: Swifty <craigswift13@gmail.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import { createClient } from "../supabase/client";
|
||||
import {
|
||||
Block,
|
||||
Graph,
|
||||
@@ -13,6 +14,7 @@ export default class AutoGPTServerAPI {
|
||||
private wsUrl: string;
|
||||
private socket: WebSocket | null = null;
|
||||
private messageHandlers: { [key: string]: (data: any) => void } = {};
|
||||
private supabaseClient = createClient();
|
||||
|
||||
constructor(
|
||||
baseUrl: string = process.env.NEXT_PUBLIC_AGPT_SERVER_URL ||
|
||||
@@ -141,18 +143,23 @@ export default class AutoGPTServerAPI {
|
||||
console.debug(`${method} ${path} payload:`, payload);
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
this.baseUrl + path,
|
||||
method != "GET"
|
||||
? {
|
||||
method,
|
||||
headers: {
|
||||
const token =
|
||||
(await this.supabaseClient?.auth.getSession())?.data.session
|
||||
?.access_token || "";
|
||||
|
||||
const response = await fetch(this.baseUrl + path, {
|
||||
method,
|
||||
headers:
|
||||
method != "GET"
|
||||
? {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: token ? `Bearer ${token}` : "",
|
||||
}
|
||||
: {
|
||||
Authorization: token ? `Bearer ${token}` : "",
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
}
|
||||
: undefined,
|
||||
);
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
const response_data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
|
||||
@@ -12,7 +12,7 @@ def parse_jwt_token(token: str) -> Dict[str, Any]:
|
||||
:raises ValueError: If the token is invalid or expired
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM], audience="authenticated")
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise ValueError("Token has expired")
|
||||
|
||||
@@ -272,7 +272,8 @@ async def get_node(node_id: str) -> Node | None:
|
||||
|
||||
|
||||
async def get_graphs_meta(
|
||||
filter_by: Literal["active", "template"] | None = "active"
|
||||
filter_by: Literal["active", "template"] | None = "active",
|
||||
user_id: str | None = None,
|
||||
) -> list[GraphMeta]:
|
||||
"""
|
||||
Retrieves graph metadata objects.
|
||||
@@ -291,6 +292,9 @@ async def get_graphs_meta(
|
||||
elif filter_by == "template":
|
||||
where_clause["isTemplate"] = True
|
||||
|
||||
if user_id and filter_by != "template":
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graphs = await AgentGraph.prisma().find_many(
|
||||
where=where_clause,
|
||||
distinct=["id"],
|
||||
@@ -304,7 +308,10 @@ async def get_graphs_meta(
|
||||
|
||||
|
||||
async def get_graph(
|
||||
graph_id: str, version: int | None = None, template: bool = False
|
||||
graph_id: str,
|
||||
version: int | None = None,
|
||||
template: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> Graph | None:
|
||||
"""
|
||||
Retrieves a graph from the DB.
|
||||
@@ -322,6 +329,9 @@ async def get_graph(
|
||||
elif not template:
|
||||
where_clause["isActive"] = True
|
||||
|
||||
if user_id and not template:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where=where_clause,
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
@@ -330,10 +340,23 @@ async def get_graph(
|
||||
return Graph.from_db(graph) if graph else None
|
||||
|
||||
|
||||
async def set_graph_active_version(graph_id: str, version: int) -> None:
|
||||
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
|
||||
# Check if the graph belongs to the user
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where={
|
||||
"id": graph_id,
|
||||
"version": version,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
if not graph:
|
||||
raise Exception(f"Graph #{graph_id} v{version} not found or not owned by user")
|
||||
|
||||
updated_graph = await AgentGraph.prisma().update(
|
||||
data={"isActive": True},
|
||||
where={"graphVersionId": {"id": graph_id, "version": version}},
|
||||
where={
|
||||
"graphVersionId": {"id": graph_id, "version": version},
|
||||
},
|
||||
)
|
||||
if not updated_graph:
|
||||
raise Exception(f"Graph #{graph_id} v{version} not found")
|
||||
@@ -341,13 +364,15 @@ async def set_graph_active_version(graph_id: str, version: int) -> None:
|
||||
# Deactivate all other versions
|
||||
await AgentGraph.prisma().update_many(
|
||||
data={"isActive": False},
|
||||
where={"id": graph_id, "version": {"not": version}},
|
||||
where={"id": graph_id, "version": {"not": version}, "userId": user_id},
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_all_versions(graph_id: str) -> list[Graph]:
|
||||
async def get_graph_all_versions(
|
||||
graph_id: str, user_id: str | None = None
|
||||
) -> list[Graph]:
|
||||
graph_versions = await AgentGraph.prisma().find_many(
|
||||
where={"id": graph_id},
|
||||
where={"id": graph_id, "userId": user_id},
|
||||
order={"version": "desc"},
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
@@ -358,17 +383,19 @@ async def get_graph_all_versions(graph_id: str) -> list[Graph]:
|
||||
return [Graph.from_db(graph) for graph in graph_versions]
|
||||
|
||||
|
||||
async def create_graph(graph: Graph) -> Graph:
|
||||
async def create_graph(graph: Graph, user_id: str | None) -> Graph:
|
||||
async with transaction() as tx:
|
||||
await __create_graph(tx, graph)
|
||||
await __create_graph(tx, graph, user_id)
|
||||
|
||||
if created_graph := await get_graph(graph.id, graph.version, graph.is_template):
|
||||
if created_graph := await get_graph(
|
||||
graph.id, graph.version, graph.is_template, user_id=user_id
|
||||
):
|
||||
return created_graph
|
||||
|
||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||
|
||||
|
||||
async def __create_graph(tx, graph: Graph):
|
||||
async def __create_graph(tx, graph: Graph, user_id: str | None):
|
||||
await AgentGraph.prisma(tx).create(
|
||||
data={
|
||||
"id": graph.id,
|
||||
@@ -377,6 +404,7 @@ async def __create_graph(tx, graph: Graph):
|
||||
"description": graph.description,
|
||||
"isTemplate": graph.is_template,
|
||||
"isActive": graph.is_active,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -391,6 +419,7 @@ async def __create_graph(tx, graph: Graph):
|
||||
"description": f"Sub-Graph of {graph.id}",
|
||||
"isTemplate": graph.is_template,
|
||||
"isActive": graph.is_active,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
for subgraph_id in graph.subgraphs
|
||||
@@ -453,5 +482,5 @@ async def import_packaged_templates() -> None:
|
||||
exists := next((t for t in templates_in_db if t.id == template.id), None)
|
||||
) and exists.version >= template.version:
|
||||
continue
|
||||
await create_graph(template)
|
||||
await create_graph(template, None)
|
||||
print(f"Loaded template '{template.name}' ({template.id})")
|
||||
|
||||
@@ -10,6 +10,7 @@ from autogpt_server.util import json
|
||||
|
||||
class ExecutionSchedule(BaseDbModel):
|
||||
graph_id: str
|
||||
user_id: str
|
||||
graph_version: int
|
||||
schedule: str
|
||||
is_enabled: bool
|
||||
@@ -25,6 +26,7 @@ class ExecutionSchedule(BaseDbModel):
|
||||
return ExecutionSchedule(
|
||||
id=schedule.id,
|
||||
graph_id=schedule.agentGraphId,
|
||||
user_id=schedule.userId,
|
||||
graph_version=schedule.agentGraphVersion,
|
||||
schedule=schedule.schedule,
|
||||
is_enabled=schedule.isEnabled,
|
||||
@@ -47,11 +49,12 @@ async def disable_schedule(schedule_id: str):
|
||||
)
|
||||
|
||||
|
||||
async def get_schedules(graph_id: str) -> list[ExecutionSchedule]:
|
||||
async def get_schedules(graph_id: str, user_id: str) -> list[ExecutionSchedule]:
|
||||
query = AgentGraphExecutionSchedule.prisma().find_many(
|
||||
where={
|
||||
"isEnabled": True,
|
||||
"agentGraphId": graph_id,
|
||||
"userId": user_id,
|
||||
},
|
||||
)
|
||||
return [ExecutionSchedule.from_db(schedule) for schedule in await query]
|
||||
@@ -71,7 +74,7 @@ async def add_schedule(schedule: ExecutionSchedule) -> ExecutionSchedule:
|
||||
return ExecutionSchedule.from_db(obj)
|
||||
|
||||
|
||||
async def update_schedule(schedule_id: str, is_enabled: bool):
|
||||
async def update_schedule(schedule_id: str, is_enabled: bool, user_id: str):
|
||||
await AgentGraphExecutionSchedule.prisma().update(
|
||||
where={"id": schedule_id}, data={"isEnabled": is_enabled}
|
||||
)
|
||||
|
||||
40
rnd/autogpt_server/autogpt_server/data/user.py
Normal file
40
rnd/autogpt_server/autogpt_server/data/user.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
|
||||
from prisma.models import User
|
||||
|
||||
from autogpt_server.data.db import prisma
|
||||
|
||||
|
||||
async def get_or_create_user(user_data: dict) -> User:
|
||||
user = await prisma.user.find_unique(where={"id": user_data["sub"]})
|
||||
if not user:
|
||||
user = await prisma.user.create(
|
||||
data={
|
||||
"id": user_data["sub"],
|
||||
"email": user_data["email"],
|
||||
"name": user_data.get("user_metadata", {}).get("name"),
|
||||
}
|
||||
)
|
||||
return User.model_validate(user)
|
||||
|
||||
|
||||
async def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
return User.model_validate(user) if user else None
|
||||
|
||||
|
||||
async def create_default_user(enable_auth: str) -> Optional[User]:
|
||||
if not enable_auth.lower() == "true":
|
||||
user = await prisma.user.find_unique(
|
||||
where={"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a"}
|
||||
)
|
||||
if not user:
|
||||
user = await prisma.user.create(
|
||||
data={
|
||||
"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"email": "default@example.com",
|
||||
"name": "Default User",
|
||||
}
|
||||
)
|
||||
return User.model_validate(user)
|
||||
return None
|
||||
@@ -416,8 +416,10 @@ class ExecutionManager(AppService):
|
||||
return get_agent_server_client()
|
||||
|
||||
@expose
|
||||
def add_execution(self, graph_id: str, data: BlockInput) -> dict[Any, Any]:
|
||||
graph: Graph | None = self.run_and_wait(get_graph(graph_id))
|
||||
def add_execution(
|
||||
self, graph_id: str, data: BlockInput, user_id: str
|
||||
) -> dict[Any, Any]:
|
||||
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
|
||||
if not graph:
|
||||
raise Exception(f"Graph #{graph_id} not found.")
|
||||
graph.validate_graph(for_run=True)
|
||||
|
||||
@@ -62,16 +62,22 @@ class ExecutionScheduler(AppService):
|
||||
logger.exception(f"Error executing graph {graph_id}: {e}")
|
||||
|
||||
@expose
|
||||
def update_schedule(self, schedule_id: str, is_enabled: bool) -> str:
|
||||
self.run_and_wait(model.update_schedule(schedule_id, is_enabled))
|
||||
def update_schedule(self, schedule_id: str, is_enabled: bool, user_id: str) -> str:
|
||||
self.run_and_wait(model.update_schedule(schedule_id, is_enabled, user_id))
|
||||
return schedule_id
|
||||
|
||||
@expose
|
||||
def add_execution_schedule(
|
||||
self, graph_id: str, graph_version: int, cron: str, input_data: BlockInput
|
||||
self,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
cron: str,
|
||||
input_data: BlockInput,
|
||||
user_id: str,
|
||||
) -> str:
|
||||
schedule = model.ExecutionSchedule(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
graph_version=graph_version,
|
||||
schedule=cron,
|
||||
input_data=input_data,
|
||||
@@ -79,7 +85,7 @@ class ExecutionScheduler(AppService):
|
||||
return self.run_and_wait(model.add_schedule(schedule)).id
|
||||
|
||||
@expose
|
||||
def get_execution_schedules(self, graph_id: str) -> dict[str, str]:
|
||||
query = model.get_schedules(graph_id)
|
||||
def get_execution_schedules(self, graph_id: str, user_id: str) -> dict[str, str]:
|
||||
query = model.get_schedules(graph_id, user_id=user_id)
|
||||
schedules: list[model.ExecutionSchedule] = self.run_and_wait(query)
|
||||
return {v.id: v.schedule for v in schedules}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import wraps
|
||||
from typing import Annotated, Any, Dict
|
||||
|
||||
import uvicorn
|
||||
from autogpt_libs.auth.jwt_utils import parse_jwt_token
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
@@ -20,12 +23,14 @@ from fastapi.responses import JSONResponse
|
||||
import autogpt_server.server.ws_api
|
||||
from autogpt_server.data import block, db
|
||||
from autogpt_server.data import graph as graph_db
|
||||
from autogpt_server.data import user as user_db
|
||||
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
|
||||
from autogpt_server.data.execution import (
|
||||
ExecutionResult,
|
||||
get_execution_results,
|
||||
list_executions,
|
||||
)
|
||||
from autogpt_server.data.user import get_or_create_user
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server.conn_manager import ConnectionManager
|
||||
from autogpt_server.server.model import (
|
||||
@@ -38,12 +43,26 @@ from autogpt_server.util.lock import KeyedMutex
|
||||
from autogpt_server.util.service import AppService, expose, get_service_client
|
||||
from autogpt_server.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
|
||||
if not payload:
|
||||
# This handles the case when authentication is disabled
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found in token")
|
||||
return user_id
|
||||
|
||||
|
||||
class AgentServer(AppService):
|
||||
event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
|
||||
manager = ConnectionManager()
|
||||
mutex = KeyedMutex()
|
||||
use_db = False
|
||||
_test_dependency_overrides = {}
|
||||
|
||||
async def event_broadcaster(self):
|
||||
while True:
|
||||
@@ -55,6 +74,7 @@ class AgentServer(AppService):
|
||||
await db.connect()
|
||||
await block.initialize_blocks()
|
||||
await graph_db.import_packaged_templates()
|
||||
await user_db.create_default_user(settings.config.enable_auth)
|
||||
asyncio.create_task(self.event_broadcaster())
|
||||
yield
|
||||
await db.disconnect()
|
||||
@@ -71,6 +91,9 @@ class AgentServer(AppService):
|
||||
lifespan=self.lifespan,
|
||||
)
|
||||
|
||||
if self._test_dependency_overrides:
|
||||
app.dependency_overrides.update(self._test_dependency_overrides)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins
|
||||
@@ -83,6 +106,12 @@ class AgentServer(AppService):
|
||||
router = APIRouter(prefix="/api")
|
||||
router.dependencies.append(Depends(auth_middleware))
|
||||
|
||||
router.add_api_route(
|
||||
path="/auth/user",
|
||||
endpoint=self.get_or_create_user_route,
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
router.add_api_route(
|
||||
path="/blocks",
|
||||
endpoint=self.get_graph_blocks, # type: ignore
|
||||
@@ -200,6 +229,35 @@ class AgentServer(AppService):
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
self._test_dependency_overrides = overrides
|
||||
|
||||
def _apply_overrides_to_methods(self):
|
||||
for attr_name in dir(self):
|
||||
attr = getattr(self, attr_name)
|
||||
if callable(attr) and hasattr(attr, "__annotations__"):
|
||||
setattr(self, attr_name, self._override_method(attr))
|
||||
|
||||
# TODO: fix this with some proper refactoring of the server
|
||||
def _override_method(self, method):
|
||||
@wraps(method)
|
||||
async def wrapper(*args, **kwargs):
|
||||
sig = inspect.signature(method)
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
continue
|
||||
if isinstance(param.annotation, Depends) or ( # type: ignore
|
||||
isinstance(param.annotation, type) and issubclass(param.annotation, Depends) # type: ignore
|
||||
):
|
||||
dependency = param.annotation.dependency if isinstance(param.annotation, Depends) else param.annotation # type: ignore
|
||||
if dependency in self._test_dependency_overrides:
|
||||
kwargs[param_name] = self._test_dependency_overrides[
|
||||
dependency
|
||||
]()
|
||||
return await method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@property
|
||||
def execution_manager_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
@@ -218,7 +276,30 @@ class AgentServer(AppService):
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
async def authenticate_websocket(self, websocket: WebSocket) -> str:
|
||||
if settings.config.enable_auth.lower() == "true":
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=4001, reason="Missing authentication token")
|
||||
return ""
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(token)
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
await websocket.close(code=4002, reason="Invalid token")
|
||||
return ""
|
||||
return user_id
|
||||
except ValueError:
|
||||
await websocket.close(code=4003, reason="Invalid token")
|
||||
return ""
|
||||
else:
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
async def websocket_router(self, websocket: WebSocket):
|
||||
user_id = await self.authenticate_websocket(websocket)
|
||||
if not user_id:
|
||||
return
|
||||
await self.manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
@@ -257,7 +338,7 @@ class AgentServer(AppService):
|
||||
).model_dump_json()
|
||||
)
|
||||
elif message.method == Methods.GET_GRAPHS:
|
||||
data = await self.get_graphs()
|
||||
data = await self.get_graphs(user_id=user_id)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_GRAPHS,
|
||||
@@ -268,7 +349,9 @@ class AgentServer(AppService):
|
||||
print("Get graphs request received")
|
||||
elif message.method == Methods.GET_GRAPH:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.get_graph(message.data["graph_id"])
|
||||
data = await self.get_graph(
|
||||
message.data["graph_id"], user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_GRAPH,
|
||||
@@ -280,7 +363,7 @@ class AgentServer(AppService):
|
||||
elif message.method == Methods.CREATE_GRAPH:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
create_graph = CreateGraph.model_validate(message.data)
|
||||
data = await self.create_new_graph(create_graph)
|
||||
data = await self.create_new_graph(create_graph, user_id=user_id)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.CREATE_GRAPH,
|
||||
@@ -293,7 +376,7 @@ class AgentServer(AppService):
|
||||
elif message.method == Methods.RUN_GRAPH:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.execute_graph(
|
||||
message.data["graph_id"], message.data["data"]
|
||||
message.data["graph_id"], message.data["data"], user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
@@ -306,7 +389,9 @@ class AgentServer(AppService):
|
||||
print("Run graph request received")
|
||||
elif message.method == Methods.GET_GRAPH_RUNS:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.list_graph_runs(message.data["graph_id"])
|
||||
data = await self.list_graph_runs(
|
||||
message.data["graph_id"], user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_GRAPH_RUNS,
|
||||
@@ -322,6 +407,7 @@ class AgentServer(AppService):
|
||||
message.data["graph_id"],
|
||||
message.data["cron"],
|
||||
message.data["data"],
|
||||
user_id=user_id,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
@@ -334,7 +420,9 @@ class AgentServer(AppService):
|
||||
print("Create scheduled run request received")
|
||||
elif message.method == Methods.GET_SCHEDULED_RUNS:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = self.get_execution_schedules(message.data["graph_id"])
|
||||
data = self.get_execution_schedules(
|
||||
message.data["graph_id"], user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_SCHEDULED_RUNS,
|
||||
@@ -346,7 +434,7 @@ class AgentServer(AppService):
|
||||
elif message.method == Methods.UPDATE_SCHEDULED_RUN:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = self.update_schedule(
|
||||
message.data["schedule_id"], message.data
|
||||
message.data["schedule_id"], message.data, user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
@@ -385,6 +473,11 @@ class AgentServer(AppService):
|
||||
self.manager.disconnect(websocket)
|
||||
print("Client Disconnected")
|
||||
|
||||
@classmethod
|
||||
async def get_or_create_user_route(cls, user_data: dict = Depends(auth_middleware)):
|
||||
user = await get_or_create_user(user_data)
|
||||
return user.model_dump()
|
||||
|
||||
@classmethod
|
||||
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
|
||||
return [v.to_dict() for v in block.get_blocks().values()] # type: ignore
|
||||
@@ -403,8 +496,10 @@ class AgentServer(AppService):
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
async def get_graphs(cls) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(filter_by="active")
|
||||
async def get_graphs(
|
||||
cls, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(filter_by="active", user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
async def get_templates(cls) -> list[graph_db.GraphMeta]:
|
||||
@@ -412,9 +507,12 @@ class AgentServer(AppService):
|
||||
|
||||
@classmethod
|
||||
async def get_graph(
|
||||
cls, graph_id: str, version: int | None = None
|
||||
cls,
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
version: int | None = None,
|
||||
) -> graph_db.Graph:
|
||||
graph = await graph_db.get_graph(graph_id, version)
|
||||
graph = await graph_db.get_graph(graph_id, version, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
return graph
|
||||
@@ -431,30 +529,39 @@ class AgentServer(AppService):
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
async def get_graph_all_versions(cls, graph_id: str) -> list[graph_db.Graph]:
|
||||
graphs = await graph_db.get_graph_all_versions(graph_id)
|
||||
async def get_graph_all_versions(
|
||||
cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.Graph]:
|
||||
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||
if not graphs:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
return graphs
|
||||
|
||||
@classmethod
|
||||
async def create_new_graph(cls, create_graph: CreateGraph) -> graph_db.Graph:
|
||||
return await cls.create_graph(create_graph, is_template=False)
|
||||
async def create_new_graph(
|
||||
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.Graph:
|
||||
return await cls.create_graph(create_graph, is_template=False, user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
async def create_new_template(cls, create_graph: CreateGraph) -> graph_db.Graph:
|
||||
return await cls.create_graph(create_graph, is_template=True)
|
||||
async def create_new_template(
|
||||
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.Graph:
|
||||
return await cls.create_graph(create_graph, is_template=True, user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
async def create_graph(
|
||||
cls, create_graph: CreateGraph, is_template: bool
|
||||
cls, create_graph: CreateGraph, is_template: bool, user_id: str
|
||||
) -> graph_db.Graph:
|
||||
if create_graph.graph:
|
||||
graph = create_graph.graph
|
||||
elif create_graph.template_id:
|
||||
# Create a new graph from a template
|
||||
graph = await graph_db.get_graph(
|
||||
create_graph.template_id, create_graph.template_version, template=True
|
||||
create_graph.template_id,
|
||||
create_graph.template_version,
|
||||
template=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
@@ -470,16 +577,23 @@ class AgentServer(AppService):
|
||||
graph.is_active = not is_template
|
||||
graph.reassign_ids(reassign_graph_id=True)
|
||||
|
||||
return await graph_db.create_graph(graph)
|
||||
return await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
async def update_graph(cls, graph_id: str, graph: graph_db.Graph) -> graph_db.Graph:
|
||||
async def update_graph(
|
||||
cls,
|
||||
graph_id: str,
|
||||
graph: graph_db.Graph,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> graph_db.Graph:
|
||||
# Sanity check
|
||||
if graph.id and graph.id != graph_id:
|
||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||
|
||||
# Determine new version
|
||||
existing_versions = await graph_db.get_graph_all_versions(graph_id)
|
||||
existing_versions = await graph_db.get_graph_all_versions(
|
||||
graph_id, user_id=user_id
|
||||
)
|
||||
if not existing_versions:
|
||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||
latest_version_number = max(g.version for g in existing_versions)
|
||||
@@ -495,43 +609,56 @@ class AgentServer(AppService):
|
||||
graph.is_active = not graph.is_template
|
||||
graph.reassign_ids()
|
||||
|
||||
new_graph_version = await graph_db.create_graph(graph)
|
||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
||||
if new_graph_version.is_active:
|
||||
# Ensure new version is the only active version
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id, version=new_graph_version.version
|
||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||
)
|
||||
|
||||
return new_graph_version
|
||||
|
||||
@classmethod
|
||||
async def set_graph_active_version(
|
||||
cls, graph_id: str, request_body: SetGraphActiveVersion
|
||||
cls,
|
||||
graph_id: str,
|
||||
request_body: SetGraphActiveVersion,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
):
|
||||
new_active_version = request_body.active_graph_version
|
||||
if not await graph_db.get_graph(graph_id, new_active_version):
|
||||
if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id):
|
||||
raise HTTPException(
|
||||
404, f"Graph #{graph_id} v{new_active_version} not found"
|
||||
)
|
||||
await graph_db.set_graph_active_version(
|
||||
graph_id=graph_id, version=request_body.active_graph_version
|
||||
graph_id=graph_id,
|
||||
version=request_body.active_graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
async def execute_graph(
|
||||
self, graph_id: str, node_input: dict[Any, Any]
|
||||
self,
|
||||
graph_id: str,
|
||||
node_input: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
try:
|
||||
return self.execution_manager_client.add_execution(graph_id, node_input)
|
||||
return self.execution_manager_client.add_execution(
|
||||
graph_id, node_input, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
msg = e.__str__().encode().decode("unicode_escape")
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
@classmethod
|
||||
async def list_graph_runs(
|
||||
cls, graph_id: str, graph_version: int | None = None
|
||||
cls,
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_version: int | None = None,
|
||||
) -> list[str]:
|
||||
graph = await graph_db.get_graph(graph_id, graph_version)
|
||||
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
|
||||
if not graph:
|
||||
rev = "" if graph_version is None else f" v{graph_version}"
|
||||
raise HTTPException(
|
||||
@@ -542,38 +669,47 @@ class AgentServer(AppService):
|
||||
|
||||
@classmethod
|
||||
async def get_run_execution_results(
|
||||
cls, graph_id: str, run_id: str
|
||||
cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[ExecutionResult]:
|
||||
graph = await graph_db.get_graph(graph_id)
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
return await get_execution_results(run_id)
|
||||
|
||||
async def create_schedule(
|
||||
self, graph_id: str, cron: str, input_data: dict[Any, Any]
|
||||
self,
|
||||
graph_id: str,
|
||||
cron: str,
|
||||
input_data: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
graph = await graph_db.get_graph(graph_id)
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
execution_scheduler = self.execution_scheduler_client
|
||||
return {
|
||||
"id": execution_scheduler.add_execution_schedule(
|
||||
graph_id, graph.version, cron, input_data
|
||||
graph_id, graph.version, cron, input_data, user_id=user_id
|
||||
)
|
||||
}
|
||||
|
||||
def update_schedule(
|
||||
self, schedule_id: str, input_data: dict[Any, Any]
|
||||
self,
|
||||
schedule_id: str,
|
||||
input_data: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
execution_scheduler = self.execution_scheduler_client
|
||||
is_enabled = input_data.get("is_enabled", False)
|
||||
execution_scheduler.update_schedule(schedule_id, is_enabled) # type: ignore
|
||||
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id) # type: ignore
|
||||
return {"id": schedule_id}
|
||||
|
||||
def get_execution_schedules(self, graph_id: str) -> dict[str, str]:
|
||||
def get_execution_schedules(
|
||||
self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> dict[str, str]:
|
||||
execution_scheduler = self.execution_scheduler_client
|
||||
return execution_scheduler.get_execution_schedules(graph_id) # type: ignore
|
||||
return execution_scheduler.get_execution_schedules(graph_id, user_id) # type: ignore
|
||||
|
||||
@expose
|
||||
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from pathlib import Path
|
||||
|
||||
from prisma.models import User
|
||||
|
||||
from autogpt_server.blocks.basic import ValueBlock
|
||||
from autogpt_server.blocks.block import BlockInstallationBlock
|
||||
from autogpt_server.blocks.http import HttpRequestBlock
|
||||
from autogpt_server.blocks.llm import TextLlmCallBlock
|
||||
from autogpt_server.blocks.text import TextFormatterBlock, TextParserBlock
|
||||
from autogpt_server.data.graph import Graph, Link, Node, create_graph
|
||||
from autogpt_server.data.user import get_or_create_user
|
||||
from autogpt_server.util.test import SpinTestServer, wait_execution
|
||||
|
||||
sample_block_modules = {
|
||||
@@ -23,6 +26,16 @@ for module, description in sample_block_modules.items():
|
||||
sample_block_codes[module] = f"[Example: {description}]\n{code}"
|
||||
|
||||
|
||||
async def create_test_user() -> User:
|
||||
test_user_data = {
|
||||
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
|
||||
"email": "testuser@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
user = await get_or_create_user(test_user_data)
|
||||
return user
|
||||
|
||||
|
||||
def create_test_graph() -> Graph:
|
||||
"""
|
||||
ValueBlock (input)
|
||||
@@ -237,9 +250,12 @@ Here are a couple of sample of the Block class implementation:
|
||||
async def block_autogen_agent():
|
||||
async with SpinTestServer() as server:
|
||||
test_manager = server.exec_manager
|
||||
test_graph = await create_graph(create_test_graph())
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"input": "Write me a block that writes a string into a file."}
|
||||
response = await server.agent_server.execute_graph(test_graph.id, input_data)
|
||||
response = await server.agent_server.execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
exec_manager=test_manager,
|
||||
@@ -247,6 +263,7 @@ async def block_autogen_agent():
|
||||
graph_exec_id=response["id"],
|
||||
num_execs=10,
|
||||
timeout=1200,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from prisma.models import User
|
||||
|
||||
from autogpt_server.blocks.llm import ObjectLlmCallBlock
|
||||
from autogpt_server.blocks.reddit import RedditGetPostsBlock, RedditPostCommentBlock
|
||||
from autogpt_server.blocks.text import TextFormatterBlock, TextMatcherBlock
|
||||
from autogpt_server.data.graph import Graph, Link, Node, create_graph
|
||||
from autogpt_server.data.user import get_or_create_user
|
||||
from autogpt_server.util.test import SpinTestServer, wait_execution
|
||||
|
||||
|
||||
@@ -136,14 +139,29 @@ Make sure to only comment on a relevant post.
|
||||
return test_graph
|
||||
|
||||
|
||||
async def create_test_user() -> User:
|
||||
test_user_data = {
|
||||
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
|
||||
"email": "testuser@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
user = await get_or_create_user(test_user_data)
|
||||
return user
|
||||
|
||||
|
||||
async def reddit_marketing_agent():
|
||||
async with SpinTestServer() as server:
|
||||
exec_man = server.exec_manager
|
||||
test_graph = await create_graph(create_test_graph())
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"subreddit": "AutoGPT"}
|
||||
response = await server.agent_server.execute_graph(test_graph.id, input_data)
|
||||
response = await server.agent_server.execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(exec_man, test_graph.id, response["id"], 13, 120)
|
||||
result = await wait_execution(
|
||||
exec_man, test_user.id, test_graph.id, response["id"], 13, 120
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,23 @@
|
||||
from prisma.models import User
|
||||
|
||||
from autogpt_server.blocks.basic import InputBlock, PrintingBlock
|
||||
from autogpt_server.blocks.text import TextFormatterBlock
|
||||
from autogpt_server.data import graph
|
||||
from autogpt_server.data.graph import create_graph
|
||||
from autogpt_server.data.user import get_or_create_user
|
||||
from autogpt_server.util.test import SpinTestServer, wait_execution
|
||||
|
||||
|
||||
async def create_test_user() -> User:
|
||||
test_user_data = {
|
||||
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
|
||||
"email": "testuser@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
user = await get_or_create_user(test_user_data)
|
||||
return user
|
||||
|
||||
|
||||
def create_test_graph() -> graph.Graph:
|
||||
"""
|
||||
ValueBlock
|
||||
@@ -63,11 +76,16 @@ def create_test_graph() -> graph.Graph:
|
||||
async def sample_agent():
|
||||
async with SpinTestServer() as server:
|
||||
exec_man = server.exec_manager
|
||||
test_graph = await create_graph(create_test_graph())
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), test_user.id)
|
||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||
response = await server.agent_server.execute_graph(test_graph.id, input_data)
|
||||
response = await server.agent_server.execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(exec_man, test_graph.id, response["id"], 4, 10)
|
||||
result = await wait_execution(
|
||||
exec_man, test_user.id, test_graph.id, response["id"], 4, 10
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
@@ -57,6 +57,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default="localhost",
|
||||
description="The default hostname of the Pyro server.",
|
||||
)
|
||||
enable_auth: str = Field(
|
||||
default="false",
|
||||
description="If authentication is enabled or not",
|
||||
)
|
||||
# Add more configuration fields as needed
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
|
||||
@@ -5,6 +5,7 @@ from autogpt_server.data.block import Block, initialize_blocks
|
||||
from autogpt_server.data.execution import ExecutionStatus
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.server.server import get_user_id
|
||||
from autogpt_server.util.service import PyroNameServer
|
||||
|
||||
log = print
|
||||
@@ -17,14 +18,21 @@ class SpinTestServer:
|
||||
self.agent_server = AgentServer()
|
||||
self.scheduler = ExecutionScheduler()
|
||||
|
||||
@staticmethod
|
||||
def test_get_user_id():
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
async def __aenter__(self):
|
||||
|
||||
self.name_server.__enter__()
|
||||
self.setup_dependency_overrides()
|
||||
self.agent_server.__enter__()
|
||||
self.exec_manager.__enter__()
|
||||
self.scheduler.__enter__()
|
||||
|
||||
await db.connect()
|
||||
await initialize_blocks()
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
@@ -35,16 +43,25 @@ class SpinTestServer:
|
||||
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
|
||||
self.scheduler.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def setup_dependency_overrides(self):
|
||||
# Override get_user_id for testing
|
||||
self.agent_server.set_test_dependency_overrides(
|
||||
{get_user_id: self.test_get_user_id}
|
||||
)
|
||||
|
||||
|
||||
async def wait_execution(
|
||||
exec_manager: ExecutionManager,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
num_execs: int,
|
||||
timeout: int = 20,
|
||||
) -> list:
|
||||
async def is_execution_completed():
|
||||
execs = await AgentServer().get_run_execution_results(graph_id, graph_exec_id)
|
||||
execs = await AgentServer().get_run_execution_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
return (
|
||||
exec_manager.queue.empty()
|
||||
and len(execs) == num_execs
|
||||
@@ -58,7 +75,7 @@ async def wait_execution(
|
||||
for i in range(timeout):
|
||||
if await is_execution_completed():
|
||||
return await AgentServer().get_run_execution_results(
|
||||
graph_id, graph_exec_id
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "User" (
|
||||
"id" TEXT NOT NULL PRIMARY KEY,
|
||||
"email" TEXT NOT NULL,
|
||||
"name" TEXT,
|
||||
"createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" DATETIME NOT NULL
|
||||
);
|
||||
|
||||
-- RedefineTables
|
||||
PRAGMA foreign_keys=OFF;
|
||||
CREATE TABLE "new_AgentGraph" (
|
||||
"id" TEXT NOT NULL,
|
||||
"version" INTEGER NOT NULL DEFAULT 1,
|
||||
"name" TEXT,
|
||||
"description" TEXT,
|
||||
"isActive" BOOLEAN NOT NULL DEFAULT true,
|
||||
"isTemplate" BOOLEAN NOT NULL DEFAULT false,
|
||||
"userId" TEXT,
|
||||
"agentGraphParentId" TEXT,
|
||||
|
||||
PRIMARY KEY ("id", "version"),
|
||||
CONSTRAINT "AgentGraph_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE,
|
||||
CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_AgentGraph" ("agentGraphParentId", "description", "id", "isActive", "isTemplate", "name", "version") SELECT "agentGraphParentId", "description", "id", "isActive", "isTemplate", "name", "version" FROM "AgentGraph";
|
||||
DROP TABLE "AgentGraph";
|
||||
ALTER TABLE "new_AgentGraph" RENAME TO "AgentGraph";
|
||||
CREATE TABLE "new_AgentGraphExecution" (
|
||||
"id" TEXT NOT NULL PRIMARY KEY,
|
||||
"agentGraphId" TEXT NOT NULL,
|
||||
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
|
||||
"userId" TEXT,
|
||||
CONSTRAINT "AgentGraphExecution_agentGraphId_agentGraphVersion_fkey" FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
CONSTRAINT "AgentGraphExecution_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_AgentGraphExecution" ("agentGraphId", "agentGraphVersion", "id") SELECT "agentGraphId", "agentGraphVersion", "id" FROM "AgentGraphExecution";
|
||||
DROP TABLE "AgentGraphExecution";
|
||||
ALTER TABLE "new_AgentGraphExecution" RENAME TO "AgentGraphExecution";
|
||||
CREATE TABLE "new_AgentGraphExecutionSchedule" (
|
||||
"id" TEXT NOT NULL PRIMARY KEY,
|
||||
"agentGraphId" TEXT NOT NULL,
|
||||
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
|
||||
"schedule" TEXT NOT NULL,
|
||||
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
|
||||
"inputData" TEXT NOT NULL,
|
||||
"lastUpdated" DATETIME NOT NULL,
|
||||
"userId" TEXT,
|
||||
CONSTRAINT "AgentGraphExecutionSchedule_agentGraphId_agentGraphVersion_fkey" FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_AgentGraphExecutionSchedule" ("agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule") SELECT "agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule" FROM "AgentGraphExecutionSchedule";
|
||||
DROP TABLE "AgentGraphExecutionSchedule";
|
||||
ALTER TABLE "new_AgentGraphExecutionSchedule" RENAME TO "AgentGraphExecutionSchedule";
|
||||
CREATE INDEX "AgentGraphExecutionSchedule_isEnabled_idx" ON "AgentGraphExecutionSchedule"("isEnabled");
|
||||
PRAGMA foreign_key_check;
|
||||
PRAGMA foreign_keys=ON;
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
|
||||
@@ -0,0 +1,5 @@
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_id_idx" ON "User"("id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_email_idx" ON "User"("email");
|
||||
22
rnd/autogpt_server/poetry.lock
generated
22
rnd/autogpt_server/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "agpt"
|
||||
@@ -25,7 +25,7 @@ requests = "*"
|
||||
sentry-sdk = "^1.40.4"
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["agbenchmark @ file:///home/bently/Desktop/autogpt-ui/AutoGPT/benchmark"]
|
||||
benchmark = ["agbenchmark"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
@@ -329,7 +329,7 @@ watchdog = "4.0.0"
|
||||
webdriver-manager = "^4.0.1"
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["agbenchmark @ file:///home/bently/Desktop/autogpt-ui/AutoGPT/benchmark"]
|
||||
benchmark = ["agbenchmark"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
@@ -342,7 +342,7 @@ description = "Shared libraries across NextGen AutoGPT"
|
||||
optional = false
|
||||
python-versions = ">=3.10,<4.0"
|
||||
files = []
|
||||
develop = true
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
pyjwt = "^2.8.0"
|
||||
@@ -4212,19 +4212,19 @@ windows-terminal = ["colorama (>=0.4.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyjwt"
|
||||
version = "2.8.0"
|
||||
version = "2.9.0"
|
||||
description = "JSON Web Token implementation in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"},
|
||||
{file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"},
|
||||
{file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"},
|
||||
{file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
crypto = ["cryptography (>=3.4.0)"]
|
||||
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
|
||||
docs = ["sphinx (>=4.5.0,<5.0.0)", "sphinx-rtd-theme", "zope.interface"]
|
||||
dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"]
|
||||
docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"]
|
||||
tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"]
|
||||
|
||||
[[package]]
|
||||
@@ -6419,4 +6419,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "9991857e7076d3bfcbae7af6c2cec54dc943167a3adceb5a0ebf74d80c05778f"
|
||||
content-hash = "003a4c89682abbf72c67631367f57e56d91d72b44f95e972b2326440199045e7"
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraph" ADD COLUMN "userId" TEXT;
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraphExecution" ADD COLUMN "userId" TEXT;
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraphExecutionSchedule" ADD COLUMN "userId" TEXT;
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "User" (
|
||||
"id" TEXT NOT NULL,
|
||||
"email" TEXT NOT NULL,
|
||||
"name" TEXT,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "AgentGraphExecution" ADD CONSTRAINT "AgentGraphExecution_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "AgentGraphExecutionSchedule" ADD CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,5 @@
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_id_idx" ON "User"("id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_email_idx" ON "User"("email");
|
||||
@@ -10,6 +10,23 @@ generator client {
|
||||
interface = "asyncio"
|
||||
}
|
||||
|
||||
// User model to mirror Auth provider users
|
||||
model User {
|
||||
id String @id // This should match the Supabase user ID
|
||||
email String @unique
|
||||
name String?
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Relations
|
||||
AgentGraphs AgentGraph[]
|
||||
AgentGraphExecutions AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
|
||||
|
||||
@@index([id])
|
||||
@@index([email])
|
||||
}
|
||||
|
||||
// This model describes the Agent Graph/Flow (Multi Agent System).
|
||||
model AgentGraph {
|
||||
id String @default(uuid())
|
||||
@@ -20,6 +37,10 @@ model AgentGraph {
|
||||
isActive Boolean @default(true)
|
||||
isTemplate Boolean @default(false)
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
|
||||
AgentNodes AgentNode[]
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
|
||||
@@ -99,6 +120,10 @@ model AgentGraphExecution {
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
||||
|
||||
AgentNodeExecutions AgentNodeExecution[]
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
}
|
||||
|
||||
// This model describes the execution of an AgentNode.
|
||||
@@ -158,5 +183,9 @@ model AgentGraphExecutionSchedule {
|
||||
// default and set the value on each update, lastUpdated field has no time zone.
|
||||
lastUpdated DateTime @updatedAt
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
|
||||
@@index([isEnabled])
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ python-dotenv = "^1.0.1"
|
||||
expiringdict = "^1.2.2"
|
||||
discord-py = "^2.4.0"
|
||||
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
autogpt-libs = {path = "../autogpt_libs"}
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
cx-freeze = { git = "https://github.com/ntindle/cx_Freeze.git", rev = "main", develop = true }
|
||||
poethepoet = "^0.26.1"
|
||||
|
||||
@@ -9,6 +9,23 @@ generator client {
|
||||
interface = "asyncio"
|
||||
}
|
||||
|
||||
// User model to mirror Auth provider users
|
||||
model User {
|
||||
id String @id // This should match the Supabase user ID
|
||||
email String @unique
|
||||
name String?
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Relations
|
||||
AgentGraphs AgentGraph[]
|
||||
AgentGraphExecutions AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
|
||||
|
||||
@@index([id])
|
||||
@@index([email])
|
||||
}
|
||||
|
||||
// This model describes the Agent Graph/Flow (Multi Agent System).
|
||||
model AgentGraph {
|
||||
id String @default(uuid())
|
||||
@@ -19,6 +36,10 @@ model AgentGraph {
|
||||
isActive Boolean @default(true)
|
||||
isTemplate Boolean @default(false)
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
|
||||
AgentNodes AgentNode[]
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
|
||||
@@ -98,6 +119,10 @@ model AgentGraphExecution {
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
||||
|
||||
AgentNodeExecutions AgentNodeExecution[]
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
}
|
||||
|
||||
// This model describes the execution of an AgentNode.
|
||||
@@ -157,5 +182,9 @@ model AgentGraphExecutionSchedule {
|
||||
// default and set the value on each update, lastUpdated field has no time zone.
|
||||
lastUpdated DateTime @updatedAt
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
|
||||
@@index([isEnabled])
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,12 @@
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
from autogpt_server.blocks.basic import ObjectLookupBlock, ValueBlock
|
||||
from autogpt_server.blocks.maths import MathsBlock, Operation
|
||||
from autogpt_server.data import execution, graph
|
||||
from autogpt_server.executor import ExecutionManager
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.usecases.sample import create_test_graph
|
||||
from autogpt_server.usecases.sample import create_test_graph, create_test_user
|
||||
from autogpt_server.util.test import wait_execution
|
||||
|
||||
|
||||
@@ -13,24 +14,30 @@ async def execute_graph(
|
||||
agent_server: AgentServer,
|
||||
test_manager: ExecutionManager,
|
||||
test_graph: graph.Graph,
|
||||
test_user: User,
|
||||
input_data: dict,
|
||||
num_execs: int = 4,
|
||||
) -> str:
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.execute_graph(test_graph.id, input_data)
|
||||
response = await agent_server.execute_graph(test_graph.id, input_data, test_user.id)
|
||||
graph_exec_id = response["id"]
|
||||
|
||||
# Execution queue should be empty
|
||||
assert await wait_execution(test_manager, test_graph.id, graph_exec_id, num_execs)
|
||||
assert await wait_execution(
|
||||
test_manager, test_user.id, test_graph.id, graph_exec_id, num_execs
|
||||
)
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
async def assert_sample_graph_executions(
|
||||
agent_server: AgentServer, test_graph: graph.Graph, graph_exec_id: str
|
||||
agent_server: AgentServer,
|
||||
test_graph: graph.Graph,
|
||||
test_user: User,
|
||||
graph_exec_id: str,
|
||||
):
|
||||
input = {"input_1": "Hello", "input_2": "World"}
|
||||
executions = await agent_server.get_run_execution_results(
|
||||
test_graph.id, graph_exec_id
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
|
||||
# Executing ValueBlock
|
||||
@@ -75,12 +82,20 @@ async def assert_sample_graph_executions(
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_agent_execution(server):
|
||||
test_graph = create_test_graph()
|
||||
await graph.create_graph(test_graph)
|
||||
test_user = await create_test_user()
|
||||
await graph.create_graph(test_graph, user_id=test_user.id)
|
||||
data = {"input_1": "Hello", "input_2": "World"}
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, data, 4
|
||||
server.agent_server,
|
||||
server.exec_manager,
|
||||
test_graph,
|
||||
test_user,
|
||||
data,
|
||||
4,
|
||||
)
|
||||
await assert_sample_graph_executions(
|
||||
server.agent_server, test_graph, test_user, graph_exec_id
|
||||
)
|
||||
await assert_sample_graph_executions(server.agent_server, test_graph, graph_exec_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@@ -130,14 +145,14 @@ async def test_input_pin_always_waited(server):
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
|
||||
test_graph = await graph.create_graph(test_graph)
|
||||
test_user = await create_test_user()
|
||||
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, {}, 3
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
|
||||
)
|
||||
|
||||
executions = await server.agent_server.get_run_execution_results(
|
||||
test_graph.id, graph_exec_id
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 3
|
||||
# ObjectLookupBlock should wait for the input pin to be provided,
|
||||
@@ -211,13 +226,13 @@ async def test_static_input_link_on_graph(server):
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
|
||||
test_graph = await graph.create_graph(test_graph)
|
||||
test_user = await create_test_user()
|
||||
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, {}, 8
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
|
||||
)
|
||||
executions = await server.agent_server.get_run_execution_results(
|
||||
test_graph.id, graph_exec_id
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 8
|
||||
# The last 3 executions will be a+b=4+5=9
|
||||
|
||||
@@ -2,29 +2,32 @@ import pytest
|
||||
|
||||
from autogpt_server.data import db, graph
|
||||
from autogpt_server.executor import ExecutionScheduler
|
||||
from autogpt_server.usecases.sample import create_test_graph
|
||||
from autogpt_server.usecases.sample import create_test_graph, create_test_user
|
||||
from autogpt_server.util.service import get_service_client
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flakey test, needs to be investigated")
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_agent_schedule(server):
|
||||
await db.connect()
|
||||
test_graph = await graph.create_graph(create_test_graph())
|
||||
test_user = await create_test_user()
|
||||
test_graph = await graph.create_graph(create_test_graph(), user_id=test_user.id)
|
||||
|
||||
scheduler = get_service_client(ExecutionScheduler)
|
||||
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 0
|
||||
|
||||
schedule_id = scheduler.add_execution_schedule(
|
||||
graph_id=test_graph.id,
|
||||
user_id=test_user.id,
|
||||
graph_version=1,
|
||||
cron="0 0 * * *",
|
||||
input_data={"input": "data"},
|
||||
)
|
||||
assert schedule_id
|
||||
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 1
|
||||
assert schedules[schedule_id] == "0 0 * * *"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user