fix(server): Make userId required on DB entities and apply default ID to existing entries (#7755)

* Make `userId` required on DB entities `AgentGraph`, `AgentGraphExecution`, and `AgentGraphExecutionSchedule`

* Add SQLite and Postgres migrations to make `userId` required and set `userId` to `3e53486c-cf57-477e-ba2a-cb02dc828e1a` on existing entries without `userId`

* Amend `create_graph` endpoint and `.data.graph`, `.data.execution` methods to handle required `user_id`

* Add `.data.user.DEFAULT_USER_ID` constant to replace hardcoded literals
This commit is contained in:
Reinier van der Leer
2024-08-08 16:57:59 +02:00
committed by GitHub
parent 582571631e
commit 2ff8a0743a
9 changed files with 155 additions and 31 deletions

View File

@@ -117,7 +117,10 @@ EXECUTION_RESULT_INCLUDE = {
async def create_graph_execution(
graph_id: str, graph_version: int, nodes_input: list[tuple[str, BlockInput]]
graph_id: str,
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
user_id: str,
) -> tuple[str, list[ExecutionResult]]:
"""
Create a new AgentGraphExecution record.
@@ -143,6 +146,7 @@ async def create_graph_execution(
for node_id, node_input in nodes_input
]
},
"userId": user_id,
},
include={
"AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE} # type: ignore

View File

@@ -10,6 +10,7 @@ from pydantic import PrivateAttr
from autogpt_server.blocks.basic import InputBlock, OutputBlock
from autogpt_server.data.block import BlockInput, get_block
from autogpt_server.data.db import BaseDbModel, transaction
from autogpt_server.data.user import DEFAULT_USER_ID
from autogpt_server.util import json
@@ -151,7 +152,6 @@ class Graph(GraphMeta):
}
def validate_graph(self, for_run: bool = False):
def sanitize(name):
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
@@ -368,9 +368,7 @@ async def set_graph_active_version(graph_id: str, version: int, user_id: str) ->
)
async def get_graph_all_versions(
graph_id: str, user_id: str | None = None
) -> list[Graph]:
async def get_graph_all_versions(graph_id: str, user_id: str) -> list[Graph]:
graph_versions = await AgentGraph.prisma().find_many(
where={"id": graph_id, "userId": user_id},
order={"version": "desc"},
@@ -383,7 +381,7 @@ async def get_graph_all_versions(
return [Graph.from_db(graph) for graph in graph_versions]
async def create_graph(graph: Graph, user_id: str | None) -> Graph:
async def create_graph(graph: Graph, user_id: str) -> Graph:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)
@@ -395,7 +393,7 @@ async def create_graph(graph: Graph, user_id: str | None) -> Graph:
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def __create_graph(tx, graph: Graph, user_id: str | None):
async def __create_graph(tx, graph: Graph, user_id: str):
await AgentGraph.prisma(tx).create(
data={
"id": graph.id,
@@ -482,5 +480,5 @@ async def import_packaged_templates() -> None:
exists := next((t for t in templates_in_db if t.id == template.id), None)
) and exists.version >= template.version:
continue
await create_graph(template, None)
await create_graph(template, DEFAULT_USER_ID)
print(f"Loaded template '{template.name}' ({template.id})")

View File

@@ -4,6 +4,8 @@ from prisma.models import User
from autogpt_server.data.db import prisma
DEFAULT_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
async def get_or_create_user(user_data: dict) -> User:
user = await prisma.user.find_unique(where={"id": user_data["sub"]})
@@ -25,13 +27,11 @@ async def get_user_by_id(user_id: str) -> Optional[User]:
async def create_default_user(enable_auth: str) -> Optional[User]:
if not enable_auth.lower() == "true":
user = await prisma.user.find_unique(
where={"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a"}
)
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
if not user:
user = await prisma.user.create(
data={
"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
"id": DEFAULT_USER_ID,
"email": "default@example.com",
"name": "Default User",
}

View File

@@ -165,7 +165,6 @@ def _enqueue_next_nodes(
# To avoid same execution to be enqueued multiple times,
# Or the same input to be consumed multiple times.
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):
# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec_id, next_node_input = wait(
upsert_execution_input(
@@ -442,6 +441,7 @@ class ExecutionManager(AppService):
graph_id=graph_id,
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
)
)

View File

@@ -30,7 +30,7 @@ from autogpt_server.data.execution import (
get_execution_results,
list_executions,
)
from autogpt_server.data.user import get_or_create_user
from autogpt_server.data.user import DEFAULT_USER_ID, get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server.conn_manager import ConnectionManager
from autogpt_server.server.model import (
@@ -49,7 +49,7 @@ settings = Settings()
def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
if not payload:
# This handles the case when authentication is disabled
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
return DEFAULT_USER_ID
user_id = payload.get("sub")
if not user_id:
@@ -73,8 +73,8 @@ class AgentServer(AppService):
async def lifespan(self, _: FastAPI):
await db.connect()
await block.initialize_blocks()
await graph_db.import_packaged_templates()
await user_db.create_default_user(settings.config.enable_auth)
if await user_db.create_default_user(settings.config.enable_auth):
await graph_db.import_packaged_templates()
asyncio.create_task(self.event_broadcaster())
yield
await db.disconnect()
@@ -294,7 +294,7 @@ class AgentServer(AppService):
await websocket.close(code=4003, reason="Invalid token")
return ""
else:
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
return user_db.DEFAULT_USER_ID
async def websocket_router(self, websocket: WebSocket):
user_id = await self.authenticate_websocket(websocket)
@@ -551,7 +551,12 @@ class AgentServer(AppService):
@classmethod
async def create_graph(
cls, create_graph: CreateGraph, is_template: bool, user_id: str
cls,
create_graph: CreateGraph,
is_template: bool,
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
) -> graph_db.Graph:
if create_graph.graph:
graph = create_graph.graph

View File

@@ -0,0 +1,92 @@
-- RedefineTables
PRAGMA foreign_keys=OFF;
CREATE TABLE "new_AgentGraphExecution" (
"id" TEXT NOT NULL PRIMARY KEY,
"agentGraphId" TEXT NOT NULL,
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
"userId" TEXT NOT NULL,
CONSTRAINT "AgentGraphExecution_agentGraphId_agentGraphVersion_fkey"
FOREIGN KEY ("agentGraphId", "agentGraphVersion")
REFERENCES "AgentGraph" ("id", "version")
ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "AgentGraphExecution_userId_fkey"
FOREIGN KEY ("userId")
REFERENCES "User" ("id")
ON DELETE RESTRICT ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraphExecution" ("agentGraphId", "agentGraphVersion", "id", "userId")
SELECT "agentGraphId",
"agentGraphVersion",
"id",
CASE WHEN "userId" IS NULL THEN '3e53486c-cf57-477e-ba2a-cb02dc828e1a' ELSE "userId" END
FROM "AgentGraphExecution";
DROP TABLE "AgentGraphExecution";
ALTER TABLE "new_AgentGraphExecution" RENAME TO "AgentGraphExecution";
CREATE TABLE "new_AgentGraph" (
"id" TEXT NOT NULL,
"version" INTEGER NOT NULL DEFAULT 1,
"name" TEXT,
"description" TEXT,
"isActive" BOOLEAN NOT NULL DEFAULT true,
"isTemplate" BOOLEAN NOT NULL DEFAULT false,
"userId" TEXT NOT NULL,
"agentGraphParentId" TEXT,
PRIMARY KEY ("id", "version"),
CONSTRAINT "AgentGraph_userId_fkey"
FOREIGN KEY ("userId")
REFERENCES "User" ("id")
ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey"
FOREIGN KEY ("agentGraphParentId", "version")
REFERENCES "AgentGraph" ("id", "version")
ON DELETE RESTRICT ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraph" ("agentGraphParentId", "description", "id", "isActive", "isTemplate", "name", "userId", "version")
SELECT "agentGraphParentId",
"description",
"id",
"isActive",
"isTemplate",
"name",
CASE WHEN "userId" IS NULL THEN '3e53486c-cf57-477e-ba2a-cb02dc828e1a' ELSE "userId" END,
"version"
FROM "AgentGraph";
DROP TABLE "AgentGraph";
ALTER TABLE "new_AgentGraph" RENAME TO "AgentGraph";
CREATE TABLE "new_AgentGraphExecutionSchedule" (
"id" TEXT NOT NULL PRIMARY KEY,
"agentGraphId" TEXT NOT NULL,
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
"schedule" TEXT NOT NULL,
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
"inputData" TEXT NOT NULL,
"lastUpdated" DATETIME NOT NULL,
"userId" TEXT NOT NULL,
CONSTRAINT "AgentGraphExecutionSchedule_agentGraphId_agentGraphVersion_fkey"
FOREIGN KEY ("agentGraphId", "agentGraphVersion")
REFERENCES "AgentGraph" ("id", "version")
ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey"
FOREIGN KEY ("userId")
REFERENCES "User" ("id")
ON DELETE RESTRICT ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraphExecutionSchedule" ("agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule", "userId")
SELECT "agentGraphId",
"agentGraphVersion",
"id",
"inputData",
"isEnabled",
"lastUpdated",
"schedule",
CASE WHEN "userId" IS NULL THEN '3e53486c-cf57-477e-ba2a-cb02dc828e1a' ELSE "userId" END
FROM "AgentGraphExecutionSchedule";
DROP TABLE "AgentGraphExecutionSchedule";
ALTER TABLE "new_AgentGraphExecutionSchedule" RENAME TO "AgentGraphExecutionSchedule";
CREATE INDEX "AgentGraphExecutionSchedule_isEnabled_idx" ON "AgentGraphExecutionSchedule"("isEnabled");
PRAGMA foreign_key_check;
PRAGMA foreign_keys=ON;

View File

@@ -0,0 +1,25 @@
-- Update existing entries with NULL userId
UPDATE "AgentGraph" SET "userId" = '3e53486c-cf57-477e-ba2a-cb02dc828e1a' WHERE "userId" IS NULL;
UPDATE "AgentGraphExecution" SET "userId" = '3e53486c-cf57-477e-ba2a-cb02dc828e1a' WHERE "userId" IS NULL;
UPDATE "AgentGraphExecutionSchedule" SET "userId" = '3e53486c-cf57-477e-ba2a-cb02dc828e1a' WHERE "userId" IS NULL;
-- AlterTable
ALTER TABLE "AgentGraph" ALTER COLUMN "userId" SET NOT NULL;
-- AlterTable
ALTER TABLE "AgentGraphExecution" ALTER COLUMN "userId" SET NOT NULL;
-- AlterTable
ALTER TABLE "AgentGraphExecutionSchedule" ALTER COLUMN "userId" SET NOT NULL;
-- AlterForeignKey
ALTER TABLE "AgentGraph" DROP CONSTRAINT "AgentGraph_userId_fkey";
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AlterForeignKey
ALTER TABLE "AgentGraphExecution" DROP CONSTRAINT "AgentGraphExecution_userId_fkey";
ALTER TABLE "AgentGraphExecution" ADD CONSTRAINT "AgentGraphExecution_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AlterForeignKey
ALTER TABLE "AgentGraphExecutionSchedule" DROP CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey";
ALTER TABLE "AgentGraphExecutionSchedule" ADD CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;

View File

@@ -38,8 +38,8 @@ model AgentGraph {
isTemplate Boolean @default(false)
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
userId String
user User @relation(fields: [userId], references: [id])
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
@@ -122,8 +122,8 @@ model AgentGraphExecution {
AgentNodeExecutions AgentNodeExecution[]
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
userId String
user User @relation(fields: [userId], references: [id])
}
// This model describes the execution of an AgentNode.
@@ -184,8 +184,8 @@ model AgentGraphExecutionSchedule {
lastUpdated DateTime @updatedAt
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
userId String
user User @relation(fields: [userId], references: [id])
@@index([isEnabled])
}

View File

@@ -37,8 +37,8 @@ model AgentGraph {
isTemplate Boolean @default(false)
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
userId String
user User @relation(fields: [userId], references: [id])
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
@@ -121,8 +121,8 @@ model AgentGraphExecution {
AgentNodeExecutions AgentNodeExecution[]
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
userId String
user User @relation(fields: [userId], references: [id])
}
// This model describes the execution of an AgentNode.
@@ -183,8 +183,8 @@ model AgentGraphExecutionSchedule {
lastUpdated DateTime @updatedAt
// Link to User model
userId String?
user User? @relation(fields: [userId], references: [id])
userId String
user User @relation(fields: [userId], references: [id])
@@index([isEnabled])
}