From 2a74381ae8a31bc7b4078a4e21dcd3f850e8f6fb Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 8 Oct 2024 19:03:26 +0300 Subject: [PATCH] feat(platform): Add delete agent functionality (#8273) --- .../backend/backend/data/graph.py | 9 ++++ .../backend/backend/server/rest_api.py | 18 +++++++ .../migration.sql | 5 ++ autogpt_platform/backend/schema.prisma | 4 +- autogpt_platform/backend/test/conftest.py | 25 ++++++++++ .../backend/test/executor/test_manager.py | 12 +++-- .../backend/test/executor/test_scheduler.py | 9 +++- autogpt_platform/frontend/src/app/page.tsx | 5 ++ .../src/components/monitor/FlowInfo.tsx | 49 ++++++++++++++++++- .../src/lib/autogpt-server-api/baseClient.ts | 4 ++ 10 files changed, 130 insertions(+), 10 deletions(-) create mode 100644 autogpt_platform/backend/migrations/20241007115713_cascade_graph_deletion/migration.sql diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index a507dd275a..007c0030f1 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -500,6 +500,15 @@ async def get_graph_all_versions(graph_id: str, user_id: str) -> list[Graph]: return [Graph.from_db(graph) for graph in graph_versions] +async def delete_graph(graph_id: str, user_id: str) -> int: + entries_count = await AgentGraph.prisma().delete_many( + where={"id": graph_id, "userId": user_id} + ) + if entries_count: + logger.info(f"Deleted {entries_count} graph entries for Graph #{graph_id}") + return entries_count + + async def create_graph(graph: Graph, user_id: str) -> Graph: async with transaction() as tx: await __create_graph(tx, graph, user_id) diff --git a/autogpt_platform/backend/backend/server/rest_api.py b/autogpt_platform/backend/backend/server/rest_api.py index 99dba10009..0f190dea43 100644 --- a/autogpt_platform/backend/backend/server/rest_api.py +++ b/autogpt_platform/backend/backend/server/rest_api.py @@ -10,6 +10,7 @@ from autogpt_libs.auth.middleware import auth_middleware from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from typing_extensions import TypedDict from backend.data import block, db from backend.data import execution as execution_db @@ -168,6 +169,12 @@ class AgentServer(AppService): methods=["PUT"], tags=["templates", "graphs"], ) + api_router.add_api_route( + path="/graphs/{graph_id}", + endpoint=self.delete_graph, + methods=["DELETE"], + tags=["graphs"], + ) api_router.add_api_route( path="/graphs/{graph_id}/versions", endpoint=self.get_graph_all_versions, @@ -395,6 +402,17 @@ class AgentServer(AppService): ) -> graph_db.Graph: return await cls.create_graph(create_graph, is_template=True, user_id=user_id) + class DeleteGraphResponse(TypedDict): + version_counts: int + + @classmethod + async def delete_graph( + cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)] + ) -> DeleteGraphResponse: + return { + "version_counts": await graph_db.delete_graph(graph_id, user_id=user_id) + } + @classmethod async def create_graph( cls, diff --git a/autogpt_platform/backend/migrations/20241007115713_cascade_graph_deletion/migration.sql b/autogpt_platform/backend/migrations/20241007115713_cascade_graph_deletion/migration.sql new file mode 100644 index 0000000000..3b783a6d92 --- /dev/null +++ b/autogpt_platform/backend/migrations/20241007115713_cascade_graph_deletion/migration.sql @@ -0,0 +1,5 @@ +-- DropForeignKey +ALTER TABLE "AgentGraph" DROP CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey"; + +-- AddForeignKey +ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph"("id", "version") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/autogpt_platform/backend/schema.prisma b/autogpt_platform/backend/schema.prisma index d23ed6d9dd..3fab8dc259 100644 --- a/autogpt_platform/backend/schema.prisma +++ b/autogpt_platform/backend/schema.prisma @@ -53,7 +53,7 @@ model AgentGraph { // All sub-graphs are defined within this 1-level depth list (even if it's a nested graph). AgentSubGraphs AgentGraph[] @relation("AgentSubGraph") agentGraphParentId String? - AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version]) + AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version], onDelete: Cascade) @@id(name: "graphVersionId", [id, version]) } @@ -63,7 +63,7 @@ model AgentNode { id String @id @default(uuid()) agentBlockId String - AgentBlock AgentBlock @relation(fields: [agentBlockId], references: [id]) + AgentBlock AgentBlock @relation(fields: [agentBlockId], references: [id], onUpdate: Cascade) agentGraphId String agentGraphVersion Int @default(1) diff --git a/autogpt_platform/backend/test/conftest.py b/autogpt_platform/backend/test/conftest.py index b0b5c6cc68..59d6f70cf9 100644 --- a/autogpt_platform/backend/test/conftest.py +++ b/autogpt_platform/backend/test/conftest.py @@ -7,3 +7,28 @@ from backend.util.test import SpinTestServer async def server(): async with SpinTestServer() as server: yield server + + +@pytest.fixture(scope="session", autouse=True) +async def graph_cleanup(server): + created_graph_ids = [] + original_create_graph = server.agent_server.create_graph + + async def create_graph_wrapper(*args, **kwargs): + created_graph = await original_create_graph(*args, **kwargs) + # Extract user_id correctly + user_id = kwargs.get("user_id", args[2] if len(args) > 2 else None) + created_graph_ids.append((created_graph.id, user_id)) + return created_graph + + try: + server.agent_server.create_graph = create_graph_wrapper + yield # This runs the test function + finally: + server.agent_server.create_graph = original_create_graph + + # Delete the created graphs and assert they were deleted + for graph_id, user_id in created_graph_ids: + resp = await server.agent_server.delete_graph(graph_id, user_id) + num_deleted = resp["version_counts"] + assert num_deleted > 0, f"Graph {graph_id} was not deleted." diff --git a/autogpt_platform/backend/test/executor/test_manager.py b/autogpt_platform/backend/test/executor/test_manager.py index 66fb202240..3df990294f 100644 --- a/autogpt_platform/backend/test/executor/test_manager.py +++ b/autogpt_platform/backend/test/executor/test_manager.py @@ -5,10 +5,15 @@ from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock from backend.blocks.maths import CalculatorBlock, Operation from backend.data import execution, graph from backend.server import AgentServer +from backend.server.model import CreateGraph from backend.usecases.sample import create_test_graph, create_test_user from backend.util.test import SpinTestServer, wait_execution +async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph: + return await s.agent_server.create_graph(CreateGraph(graph=g), False, u.id) + + async def execute_graph( agent_server: AgentServer, test_graph: graph.Graph, @@ -99,9 +104,8 @@ async def assert_sample_graph_executions( @pytest.mark.asyncio(scope="session") async def test_agent_execution(server: SpinTestServer): - test_graph = create_test_graph() test_user = await create_test_user() - await graph.create_graph(test_graph, user_id=test_user.id) + test_graph = await create_graph(server, create_test_graph(), test_user) data = {"input_1": "Hello", "input_2": "World"} graph_exec_id = await execute_graph( server.agent_server, @@ -163,7 +167,7 @@ async def test_input_pin_always_waited(server: SpinTestServer): links=links, ) test_user = await create_test_user() - test_graph = await graph.create_graph(test_graph, user_id=test_user.id) + test_graph = await create_graph(server, test_graph, test_user) graph_exec_id = await execute_graph( server.agent_server, test_graph, test_user, {}, 3 ) @@ -244,7 +248,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer): links=links, ) test_user = await create_test_user() - test_graph = await graph.create_graph(test_graph, user_id=test_user.id) + test_graph = await create_graph(server, test_graph, test_user) graph_exec_id = await execute_graph( server.agent_server, test_graph, test_user, {}, 8 ) diff --git a/autogpt_platform/backend/test/executor/test_scheduler.py b/autogpt_platform/backend/test/executor/test_scheduler.py index 6c46110776..c0bcc83079 100644 --- a/autogpt_platform/backend/test/executor/test_scheduler.py +++ b/autogpt_platform/backend/test/executor/test_scheduler.py @@ -1,7 +1,8 @@ import pytest -from backend.data import db, graph +from backend.data import db from backend.executor import ExecutionScheduler +from backend.server.model import CreateGraph from backend.usecases.sample import create_test_graph, create_test_user from backend.util.service import get_service_client from backend.util.settings import Config @@ -12,7 +13,11 @@ from backend.util.test import SpinTestServer async def test_agent_schedule(server: SpinTestServer): await db.connect() test_user = await create_test_user() - test_graph = await graph.create_graph(create_test_graph(), user_id=test_user.id) + test_graph = await server.agent_server.create_graph( + create_graph=CreateGraph(graph=create_test_graph()), + is_template=False, + user_id=test_user.id, + ) scheduler = get_service_client( ExecutionScheduler, Config().execution_scheduler_port diff --git a/autogpt_platform/frontend/src/app/page.tsx b/autogpt_platform/frontend/src/app/page.tsx index 7924883160..756b1642f5 100644 --- a/autogpt_platform/frontend/src/app/page.tsx +++ b/autogpt_platform/frontend/src/app/page.tsx @@ -90,6 +90,11 @@ const Monitor = () => { flow={selectedFlow} flowRuns={flowRuns.filter((r) => r.graphID == selectedFlow.id)} className={column3} + refresh={() => { + fetchAgents(); + setSelectedFlow(null); + setSelectedRun(null); + }} /> )) || ( diff --git a/autogpt_platform/frontend/src/components/monitor/FlowInfo.tsx b/autogpt_platform/frontend/src/components/monitor/FlowInfo.tsx index 0cf53d78e3..0fdd2b1436 100644 --- a/autogpt_platform/frontend/src/components/monitor/FlowInfo.tsx +++ b/autogpt_platform/frontend/src/components/monitor/FlowInfo.tsx @@ -20,14 +20,24 @@ import { ClockIcon, ExitIcon, Pencil2Icon } from "@radix-ui/react-icons"; import Link from "next/link"; import { exportAsJSONFile } from "@/lib/utils"; import { FlowRunsStats } from "@/components/monitor/index"; +import { Trash2Icon } from "lucide-react"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogDescription, + DialogFooter, +} from "@/components/ui/dialog"; export const FlowInfo: React.FC< React.HTMLAttributes & { flow: GraphMeta; flowRuns: FlowRun[]; flowVersion?: number | "all"; + refresh: () => void; } -> = ({ flow, flowRuns, flowVersion, ...props }) => { +> = ({ flow, flowRuns, flowVersion, refresh, ...props }) => { const api = useMemo(() => new AutoGPTServerAPI(), []); const [flowVersions, setFlowVersions] = useState(null); @@ -39,6 +49,8 @@ export const FlowInfo: React.FC< v.version == (selectedVersion == "all" ? flow.version : selectedVersion), ); + const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); + useEffect(() => { api.getGraphAllVersions(flow.id).then((result) => setFlowVersions(result)); }, [flow.id, api]); @@ -96,7 +108,7 @@ export const FlowInfo: React.FC< className={buttonVariants({ variant: "outline" })} href={`/build?flowID=${flow.id}`} > - Edit + + @@ -128,6 +143,36 @@ export const FlowInfo: React.FC< )} /> + + + + Delete Agent + + Are you sure you want to delete this agent?
+ This action cannot be undone. +
+
+ + + + +
+
); }; diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/baseClient.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/baseClient.ts index 19d8550d87..28b14b3ff5 100644 --- a/autogpt_platform/frontend/src/lib/autogpt-server-api/baseClient.ts +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/baseClient.ts @@ -124,6 +124,10 @@ export default class BaseAutoGPTServerAPI { return this._request("PUT", `/templates/${id}`, template); } + deleteGraph(id: string): Promise { + return this._request("DELETE", `/graphs/${id}`); + } + setGraphActiveVersion(id: string, version: number): Promise { return this._request("PUT", `/graphs/${id}/versions/active`, { active_graph_version: version,