add user management support

This commit is contained in:
Aarushi
2024-08-01 10:44:32 +01:00
parent c9d41e69bd
commit c613a2b3ec
10 changed files with 321 additions and 44 deletions

View File

@@ -0,0 +1,31 @@
import logging
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class UserBase(BaseModel):
id: str = Field(None, description="The unique Supabase ID of the user")
class UserCreate(UserBase):
pass
class UserUpdate(BaseModel):
id: str = Field(..., description="The unique Supabase identifier of the user")
class UserResponse(BaseModel):
id: str
email: Optional[str] = None
name: Optional[str] = None
createdAt: datetime
updatedAt: datetime
class UsersListResponse(BaseModel):
users: List[UserResponse]

View File

@@ -128,7 +128,8 @@ async def get_node(node_id: str) -> Node | None:
async def get_graphs_meta(
filter_by: Literal["active", "template"] | None = "active"
filter_by: Literal["active", "template"] | None = "active",
user_id: str | None = None,
) -> list[GraphMeta]:
"""
Retrieves graph metadata objects.
@@ -147,6 +148,9 @@ async def get_graphs_meta(
elif filter_by == "template":
where_clause["isTemplate"] = True
if user_id and filter_by != "template":
where_clause["userId"] = user_id
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
distinct=["id"],
@@ -160,7 +164,10 @@ async def get_graphs_meta(
async def get_graph(
graph_id: str, version: int | None = None, template: bool = False
graph_id: str,
version: int | None = None,
template: bool = False,
user_id: str | None = None,
) -> Graph | None:
"""
Retrieves a graph from the DB.
@@ -178,6 +185,9 @@ async def get_graph(
elif not template:
where_clause["isActive"] = True
if user_id and not template:
where_clause["userId"] = user_id
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include={"AgentNodes": {"include": EXECUTION_NODE_INCLUDE}}, # type: ignore
@@ -186,10 +196,23 @@ async def get_graph(
return Graph.from_db(graph) if graph else None
async def set_graph_active_version(graph_id: str, version: int) -> None:
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
# Check if the graph belongs to the user
graph = await AgentGraph.prisma().find_first(
where={
"id": graph_id,
"version": version,
"userId": user_id,
}
)
if not graph:
raise Exception(f"Graph #{graph_id} v{version} not found or not owned by user")
updated_graph = await AgentGraph.prisma().update(
data={"isActive": True},
where={"graphVersionId": {"id": graph_id, "version": version}},
where={
"graphVersionId": {"id": graph_id, "version": version},
},
)
if not updated_graph:
raise Exception(f"Graph #{graph_id} v{version} not found")
@@ -197,13 +220,15 @@ async def set_graph_active_version(graph_id: str, version: int) -> None:
# Deactivate all other versions
await AgentGraph.prisma().update_many(
data={"isActive": False},
where={"id": graph_id, "version": {"not": version}},
where={"id": graph_id, "version": {"not": version}, "userId": user_id},
)
async def get_graph_all_versions(graph_id: str) -> list[Graph]:
async def get_graph_all_versions(
graph_id: str, user_id: str | None = None
) -> list[Graph]:
graph_versions = await AgentGraph.prisma().find_many(
where={"id": graph_id},
where={"id": graph_id, "userId": user_id},
order={"version": "desc"},
include={"AgentNodes": {"include": EXECUTION_NODE_INCLUDE}}, # type: ignore
)
@@ -214,7 +239,7 @@ async def get_graph_all_versions(graph_id: str) -> list[Graph]:
return [Graph.from_db(graph) for graph in graph_versions]
async def create_graph(graph: Graph) -> Graph:
async def create_graph(graph: Graph, user_id: str) -> Graph:
await AgentGraph.prisma().create(
data={
"id": graph.id,
@@ -223,6 +248,7 @@ async def create_graph(graph: Graph) -> Graph:
"description": graph.description,
"isTemplate": graph.is_template,
"isActive": graph.is_active,
"userId": user_id,
}
)
@@ -287,5 +313,5 @@ async def import_packaged_templates() -> None:
exists := next((t for t in templates_in_db if t.id == template.id), None)
) and exists.version >= template.version:
continue
await create_graph(template)
await create_graph(template, "")
print(f"Loaded template '{template.name}' ({template.id})")

View File

@@ -47,11 +47,12 @@ async def disable_schedule(schedule_id: str):
)
async def get_schedules(graph_id: str) -> list[ExecutionSchedule]:
async def get_schedules(graph_id: str, user_id: str) -> list[ExecutionSchedule]:
query = AgentGraphExecutionSchedule.prisma().find_many(
where={
"isEnabled": True,
"agentGraphId": graph_id,
"userId": user_id,
},
)
return [ExecutionSchedule.from_db(schedule) for schedule in await query]
@@ -71,7 +72,7 @@ async def add_schedule(schedule: ExecutionSchedule) -> ExecutionSchedule:
return ExecutionSchedule.from_db(obj)
async def update_schedule(schedule_id: str, is_enabled: bool):
async def update_schedule(schedule_id: str, is_enabled: bool, user_id: str):
await AgentGraphExecutionSchedule.prisma().update(
where={"id": schedule_id}, data={"isEnabled": is_enabled}
where={{"id": schedule_id}, {"userId": user_id}}, data={"isEnabled": is_enabled}
)

View File

@@ -0,0 +1,23 @@
from typing import Optional
from prisma.models import User
from autogpt_server.data.db import prisma
async def get_or_create_user(user_data: dict) -> User:
user = await prisma.user.find_unique(where={"id": user_data["sub"]})
if not user:
user = await prisma.user.create(
data={
"id": user_data["sub"],
"email": user_data["email"],
"name": user_data.get("user_metadata", {}).get("name"),
}
)
return User.model_validate(user)
async def get_user_by_id(user_id: str) -> Optional[User]:
user = await prisma.user.find_unique(where={"id": user_id})
return User.model_validate(user) if user else None

View File

@@ -62,8 +62,8 @@ class ExecutionScheduler(AppService):
logger.exception(f"Error executing graph {graph_id}: {e}")
@expose
def update_schedule(self, schedule_id: str, is_enabled: bool) -> str:
self.run_and_wait(model.update_schedule(schedule_id, is_enabled))
def update_schedule(self, schedule_id: str, is_enabled: bool, user_id: str) -> str:
self.run_and_wait(model.update_schedule(schedule_id, is_enabled, user_id))
return schedule_id
@expose
@@ -79,7 +79,7 @@ class ExecutionScheduler(AppService):
return self.run_and_wait(model.add_schedule(schedule)).id
@expose
def get_execution_schedules(self, graph_id: str) -> dict[str, str]:
query = model.get_schedules(graph_id)
def get_execution_schedules(self, graph_id: str, user_id: str) -> dict[str, str]:
query = model.get_schedules(graph_id, user_id=user_id)
schedules: list[model.ExecutionSchedule] = self.run_and_wait(query)
return {v.id: v.schedule for v in schedules}

View File

@@ -27,6 +27,7 @@ from autogpt_server.data.execution import (
get_execution_results,
list_executions,
)
from autogpt_server.data.user import get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server.conn_manager import ConnectionManager
from autogpt_server.server.model import (
@@ -40,6 +41,17 @@ from autogpt_server.util.service import AppService, expose, get_service_client
from autogpt_server.util.settings import Settings
def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
if not payload:
# This handles the case when authentication is disabled
return "default_user_id"
user_id = payload.get("sub")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in token")
return user_id
class AgentServer(AppService):
event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
manager = ConnectionManager()
@@ -84,6 +96,12 @@ class AgentServer(AppService):
router = APIRouter(prefix="/api")
router.dependencies.append(Depends(auth_middleware))
router.add_api_route(
path="/auth/user",
endpoint=self.get_or_create_user_route,
methods=["POST"],
)
router.add_api_route(
path="/blocks",
endpoint=self.get_graph_blocks, # type: ignore
@@ -386,6 +404,11 @@ class AgentServer(AppService):
self.manager.disconnect(websocket)
print("Client Disconnected")
@classmethod
async def get_or_create_user_route(cls, user_data: dict = Depends(auth_middleware)):
user = await get_or_create_user(user_data)
return user.model_dump()
@classmethod
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
return [v.to_dict() for v in block.get_blocks().values()] # type: ignore
@@ -404,8 +427,10 @@ class AgentServer(AppService):
return output
@classmethod
async def get_graphs(cls) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(filter_by="active")
async def get_graphs(
cls, user_id: str = Depends(get_user_id)
) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(filter_by="active", user_id=user_id)
@classmethod
async def get_templates(cls) -> list[graph_db.GraphMeta]:
@@ -413,9 +438,12 @@ class AgentServer(AppService):
@classmethod
async def get_graph(
cls, graph_id: str, version: int | None = None
cls,
graph_id: str,
version: int | None = None,
user_id: str = Depends(get_user_id),
) -> graph_db.Graph:
graph = await graph_db.get_graph(graph_id, version)
graph = await graph_db.get_graph(graph_id, version, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graph
@@ -432,29 +460,38 @@ class AgentServer(AppService):
return graph
@classmethod
async def get_graph_all_versions(cls, graph_id: str) -> list[graph_db.Graph]:
graphs = await graph_db.get_graph_all_versions(graph_id)
async def get_graph_all_versions(
cls, graph_id: str, user_id: str = Depends(get_user_id)
) -> list[graph_db.Graph]:
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graphs
@classmethod
async def create_new_graph(cls, create_graph: CreateGraph) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=False)
async def create_new_graph(
cls, create_graph: CreateGraph, user_id: str = Depends(get_user_id)
) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=False, user_id=user_id)
@classmethod
async def create_new_template(cls, create_graph: CreateGraph) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=True)
async def create_new_template(
cls, create_graph: CreateGraph, user_id: str = Depends(get_user_id)
) -> graph_db.Graph:
return await cls.create_graph(create_graph, is_template=True, user_id=user_id)
@classmethod
async def create_graph(
cls, create_graph: CreateGraph, is_template: bool
cls, create_graph: CreateGraph, is_template: bool, user_id: str
) -> graph_db.Graph:
if create_graph.graph:
graph = create_graph.graph
elif create_graph.template_id:
graph = await graph_db.get_graph(
create_graph.template_id, create_graph.template_version, template=True
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
raise HTTPException(
@@ -478,16 +515,20 @@ class AgentServer(AppService):
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
return await graph_db.create_graph(graph)
return await graph_db.create_graph(graph, user_id=user_id)
@classmethod
async def update_graph(cls, graph_id: str, graph: graph_db.Graph) -> graph_db.Graph:
async def update_graph(
cls, graph_id: str, graph: graph_db.Graph, user_id: str = Depends(get_user_id)
) -> graph_db.Graph:
# Sanity check
if graph.id and graph.id != graph_id:
raise HTTPException(400, detail="Graph ID does not match ID in URI")
# Determine new version
existing_versions = await graph_db.get_graph_all_versions(graph_id)
existing_versions = await graph_db.get_graph_all_versions(
graph_id, user_id=user_id
)
if not existing_versions:
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
latest_version_number = max(g.version for g in existing_versions)
@@ -510,19 +551,22 @@ class AgentServer(AppService):
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
new_graph_version = await graph_db.create_graph(graph)
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
if new_graph_version.is_active:
# Ensure new version is the only active version
await graph_db.set_graph_active_version(
graph_id=graph_id, version=new_graph_version.version
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
)
return new_graph_version
@classmethod
async def set_graph_active_version(
cls, graph_id: str, request_body: SetGraphActiveVersion
cls,
graph_id: str,
request_body: SetGraphActiveVersion,
user_id: str = Depends(get_user_id),
):
new_active_version = request_body.active_graph_version
if not await graph_db.get_graph(graph_id, new_active_version):
@@ -530,7 +574,9 @@ class AgentServer(AppService):
404, f"Graph #{graph_id} v{new_active_version} not found"
)
await graph_db.set_graph_active_version(
graph_id=graph_id, version=request_body.active_graph_version
graph_id=graph_id,
version=request_body.active_graph_version,
user_id=user_id,
)
async def execute_graph(
@@ -557,18 +603,22 @@ class AgentServer(AppService):
@classmethod
async def get_run_execution_results(
cls, graph_id: str, run_id: str
cls, graph_id: str, run_id: str, user_id: str = Depends(get_user_id)
) -> list[ExecutionResult]:
graph = await graph_db.get_graph(graph_id)
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return await get_execution_results(run_id)
async def create_schedule(
self, graph_id: str, cron: str, input_data: dict[Any, Any]
self,
graph_id: str,
cron: str,
input_data: dict[Any, Any],
user_id: str = Depends(get_user_id),
) -> dict[Any, Any]:
graph = await graph_db.get_graph(graph_id)
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
execution_scheduler = self.execution_scheduler_client
@@ -579,16 +629,21 @@ class AgentServer(AppService):
}
def update_schedule(
self, schedule_id: str, input_data: dict[Any, Any]
self,
schedule_id: str,
input_data: dict[Any, Any],
user_id: str = Depends(get_user_id),
) -> dict[Any, Any]:
execution_scheduler = self.execution_scheduler_client
is_enabled = input_data.get("is_enabled", False)
execution_scheduler.update_schedule(schedule_id, is_enabled) # type: ignore
execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id) # type: ignore
return {"id": schedule_id}
def get_execution_schedules(self, graph_id: str) -> dict[str, str]:
def get_execution_schedules(
self, graph_id: str, user_id: str = Depends(get_user_id)
) -> dict[str, str]:
execution_scheduler = self.execution_scheduler_client
return execution_scheduler.get_execution_schedules(graph_id) # type: ignore
return execution_scheduler.get_execution_schedules(graph_id, user_id) # type: ignore
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):

View File

@@ -0,0 +1,58 @@
-- CreateTable
CREATE TABLE "User" (
"id" TEXT NOT NULL PRIMARY KEY,
"email" TEXT NOT NULL,
"name" TEXT,
"createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" DATETIME NOT NULL
);
-- RedefineTables
PRAGMA foreign_keys=OFF;
CREATE TABLE "new_AgentGraph" (
"id" TEXT NOT NULL,
"version" INTEGER NOT NULL DEFAULT 1,
"name" TEXT,
"description" TEXT,
"isActive" BOOLEAN NOT NULL DEFAULT true,
"isTemplate" BOOLEAN NOT NULL DEFAULT false,
"userId" TEXT,
PRIMARY KEY ("id", "version"),
CONSTRAINT "AgentGraph_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraph" ("description", "id", "isActive", "isTemplate", "name", "version") SELECT "description", "id", "isActive", "isTemplate", "name", "version" FROM "AgentGraph";
DROP TABLE "AgentGraph";
ALTER TABLE "new_AgentGraph" RENAME TO "AgentGraph";
CREATE TABLE "new_AgentGraphExecution" (
"id" TEXT NOT NULL PRIMARY KEY,
"agentGraphId" TEXT NOT NULL,
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
"userId" TEXT,
CONSTRAINT "AgentGraphExecution_agentGraphId_agentGraphVersion_fkey" FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "AgentGraphExecution_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraphExecution" ("agentGraphId", "agentGraphVersion", "id") SELECT "agentGraphId", "agentGraphVersion", "id" FROM "AgentGraphExecution";
DROP TABLE "AgentGraphExecution";
ALTER TABLE "new_AgentGraphExecution" RENAME TO "AgentGraphExecution";
CREATE TABLE "new_AgentGraphExecutionSchedule" (
"id" TEXT NOT NULL PRIMARY KEY,
"agentGraphId" TEXT NOT NULL,
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
"schedule" TEXT NOT NULL,
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
"inputData" TEXT NOT NULL,
"lastUpdated" DATETIME NOT NULL,
"userId" TEXT,
CONSTRAINT "AgentGraphExecutionSchedule_agentGraphId_agentGraphVersion_fkey" FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraphExecutionSchedule" ("agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule") SELECT "agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule" FROM "AgentGraphExecutionSchedule";
DROP TABLE "AgentGraphExecutionSchedule";
ALTER TABLE "new_AgentGraphExecutionSchedule" RENAME TO "AgentGraphExecutionSchedule";
CREATE INDEX "AgentGraphExecutionSchedule_isEnabled_idx" ON "AgentGraphExecutionSchedule"("isEnabled");
PRAGMA foreign_key_check;
PRAGMA foreign_keys=ON;
-- CreateIndex
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");

View File

@@ -0,0 +1,31 @@
-- AlterTable
ALTER TABLE "AgentGraph" ADD COLUMN "userId" TEXT;
-- AlterTable
ALTER TABLE "AgentGraphExecution" ADD COLUMN "userId" TEXT;
-- AlterTable
ALTER TABLE "AgentGraphExecutionSchedule" ADD COLUMN "userId" TEXT;
-- CreateTable
CREATE TABLE "User" (
"id" TEXT NOT NULL,
"email" TEXT NOT NULL,
"name" TEXT,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
CONSTRAINT "User_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE UNIQUE INDEX "User_email_key" ON "User"("email");
-- AddForeignKey
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "AgentGraphExecution" ADD CONSTRAINT "AgentGraphExecution_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "AgentGraphExecutionSchedule" ADD CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -10,6 +10,20 @@ generator client {
interface = "asyncio"
}
// User model to mirror Auth provider users
model User {
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Relations
AgentGraphs AgentGraph[]
AgentGraphExecutions AgentGraphExecution[]
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
}
// This model describes the Agent Graph/Flow (Multi Agent System).
model AgentGraph {
id String @default(uuid())
@@ -20,6 +34,10 @@ model AgentGraph {
isActive Boolean @default(true)
isTemplate Boolean @default(false)
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
@@ -94,6 +112,10 @@ model AgentGraphExecution {
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
AgentNodeExecutions AgentNodeExecution[]
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
}
// This model describes the execution of an AgentNode.
@@ -153,5 +175,9 @@ model AgentGraphExecutionSchedule {
// default and set the value on each update, lastUpdated field has no time zone.
lastUpdated DateTime @updatedAt
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
@@index([isEnabled])
}

View File

@@ -9,6 +9,20 @@ generator client {
interface = "asyncio"
}
// User model to mirror Auth provider users
model User {
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Relations
AgentGraphs AgentGraph[]
AgentGraphExecutions AgentGraphExecution[]
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
}
// This model describes the Agent Graph/Flow (Multi Agent System).
model AgentGraph {
id String @default(uuid())
@@ -19,6 +33,10 @@ model AgentGraph {
isActive Boolean @default(true)
isTemplate Boolean @default(false)
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
@@ -93,6 +111,10 @@ model AgentGraphExecution {
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
AgentNodeExecutions AgentNodeExecution[]
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
}
// This model describes the execution of an AgentNode.
@@ -152,5 +174,9 @@ model AgentGraphExecutionSchedule {
// default and set the value on each update, lastUpdated field has no time zone.
lastUpdated DateTime @updatedAt
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
@@index([isEnabled])
}
}