mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(server): Make userId required on DB entities and apply default ID to existing entries (#7755)
* Make `userId` required on DB entities `AgentGraph`, `AgentGraphExecution`, and `AgentGraphExecutionSchedule` * Add SQLite and Postgres migrations to make `userId` required and set `userId` to `3e53486c-cf57-477e-ba2a-cb02dc828e1a` on existing entries without `userId` * Amend `create_graph` endpoint and `.data.graph`, `.data.execution` methods to handle required `user_id` * Add `.data.user.DEFAULT_USER_ID` constant to replace hardcoded literals
This commit is contained in:
committed by
GitHub
parent
582571631e
commit
2ff8a0743a
@@ -117,7 +117,10 @@ EXECUTION_RESULT_INCLUDE = {
|
||||
|
||||
|
||||
async def create_graph_execution(
|
||||
graph_id: str, graph_version: int, nodes_input: list[tuple[str, BlockInput]]
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
nodes_input: list[tuple[str, BlockInput]],
|
||||
user_id: str,
|
||||
) -> tuple[str, list[ExecutionResult]]:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
@@ -143,6 +146,7 @@ async def create_graph_execution(
|
||||
for node_id, node_input in nodes_input
|
||||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
},
|
||||
include={
|
||||
"AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE} # type: ignore
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import PrivateAttr
|
||||
from autogpt_server.blocks.basic import InputBlock, OutputBlock
|
||||
from autogpt_server.data.block import BlockInput, get_block
|
||||
from autogpt_server.data.db import BaseDbModel, transaction
|
||||
from autogpt_server.data.user import DEFAULT_USER_ID
|
||||
from autogpt_server.util import json
|
||||
|
||||
|
||||
@@ -151,7 +152,6 @@ class Graph(GraphMeta):
|
||||
}
|
||||
|
||||
def validate_graph(self, for_run: bool = False):
|
||||
|
||||
def sanitize(name):
|
||||
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
|
||||
@@ -368,9 +368,7 @@ async def set_graph_active_version(graph_id: str, version: int, user_id: str) ->
|
||||
)
|
||||
|
||||
|
||||
async def get_graph_all_versions(
|
||||
graph_id: str, user_id: str | None = None
|
||||
) -> list[Graph]:
|
||||
async def get_graph_all_versions(graph_id: str, user_id: str) -> list[Graph]:
|
||||
graph_versions = await AgentGraph.prisma().find_many(
|
||||
where={"id": graph_id, "userId": user_id},
|
||||
order={"version": "desc"},
|
||||
@@ -383,7 +381,7 @@ async def get_graph_all_versions(
|
||||
return [Graph.from_db(graph) for graph in graph_versions]
|
||||
|
||||
|
||||
async def create_graph(graph: Graph, user_id: str | None) -> Graph:
|
||||
async def create_graph(graph: Graph, user_id: str) -> Graph:
|
||||
async with transaction() as tx:
|
||||
await __create_graph(tx, graph, user_id)
|
||||
|
||||
@@ -395,7 +393,7 @@ async def create_graph(graph: Graph, user_id: str | None) -> Graph:
|
||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||
|
||||
|
||||
async def __create_graph(tx, graph: Graph, user_id: str | None):
|
||||
async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
await AgentGraph.prisma(tx).create(
|
||||
data={
|
||||
"id": graph.id,
|
||||
@@ -482,5 +480,5 @@ async def import_packaged_templates() -> None:
|
||||
exists := next((t for t in templates_in_db if t.id == template.id), None)
|
||||
) and exists.version >= template.version:
|
||||
continue
|
||||
await create_graph(template, None)
|
||||
await create_graph(template, DEFAULT_USER_ID)
|
||||
print(f"Loaded template '{template.name}' ({template.id})")
|
||||
|
||||
@@ -4,6 +4,8 @@ from prisma.models import User
|
||||
|
||||
from autogpt_server.data.db import prisma
|
||||
|
||||
DEFAULT_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
|
||||
async def get_or_create_user(user_data: dict) -> User:
|
||||
user = await prisma.user.find_unique(where={"id": user_data["sub"]})
|
||||
@@ -25,13 +27,11 @@ async def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
|
||||
async def create_default_user(enable_auth: str) -> Optional[User]:
|
||||
if not enable_auth.lower() == "true":
|
||||
user = await prisma.user.find_unique(
|
||||
where={"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a"}
|
||||
)
|
||||
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
|
||||
if not user:
|
||||
user = await prisma.user.create(
|
||||
data={
|
||||
"id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
"id": DEFAULT_USER_ID,
|
||||
"email": "default@example.com",
|
||||
"name": "Default User",
|
||||
}
|
||||
|
||||
@@ -165,7 +165,6 @@ def _enqueue_next_nodes(
|
||||
# To avoid same execution to be enqueued multiple times,
|
||||
# Or the same input to be consumed multiple times.
|
||||
with synchronized(api_client, ("upsert_input", next_node_id, graph_exec_id)):
|
||||
|
||||
# Add output data to the earliest incomplete execution, or create a new one.
|
||||
next_node_exec_id, next_node_input = wait(
|
||||
upsert_execution_input(
|
||||
@@ -442,6 +441,7 @@ class ExecutionManager(AppService):
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=nodes_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from autogpt_server.data.execution import (
|
||||
get_execution_results,
|
||||
list_executions,
|
||||
)
|
||||
from autogpt_server.data.user import get_or_create_user
|
||||
from autogpt_server.data.user import DEFAULT_USER_ID, get_or_create_user
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server.conn_manager import ConnectionManager
|
||||
from autogpt_server.server.model import (
|
||||
@@ -49,7 +49,7 @@ settings = Settings()
|
||||
def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
|
||||
if not payload:
|
||||
# This handles the case when authentication is disabled
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
return DEFAULT_USER_ID
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
@@ -73,8 +73,8 @@ class AgentServer(AppService):
|
||||
async def lifespan(self, _: FastAPI):
|
||||
await db.connect()
|
||||
await block.initialize_blocks()
|
||||
await graph_db.import_packaged_templates()
|
||||
await user_db.create_default_user(settings.config.enable_auth)
|
||||
if await user_db.create_default_user(settings.config.enable_auth):
|
||||
await graph_db.import_packaged_templates()
|
||||
asyncio.create_task(self.event_broadcaster())
|
||||
yield
|
||||
await db.disconnect()
|
||||
@@ -294,7 +294,7 @@ class AgentServer(AppService):
|
||||
await websocket.close(code=4003, reason="Invalid token")
|
||||
return ""
|
||||
else:
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
return user_db.DEFAULT_USER_ID
|
||||
|
||||
async def websocket_router(self, websocket: WebSocket):
|
||||
user_id = await self.authenticate_websocket(websocket)
|
||||
@@ -551,7 +551,12 @@ class AgentServer(AppService):
|
||||
|
||||
@classmethod
|
||||
async def create_graph(
|
||||
cls, create_graph: CreateGraph, is_template: bool, user_id: str
|
||||
cls,
|
||||
create_graph: CreateGraph,
|
||||
is_template: bool,
|
||||
# user_id doesn't have to be annotated like on other endpoints,
|
||||
# because create_graph isn't used directly as an endpoint
|
||||
user_id: str,
|
||||
) -> graph_db.Graph:
|
||||
if create_graph.graph:
|
||||
graph = create_graph.graph
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
-- RedefineTables
|
||||
PRAGMA foreign_keys=OFF;
|
||||
|
||||
CREATE TABLE "new_AgentGraphExecution" (
|
||||
"id" TEXT NOT NULL PRIMARY KEY,
|
||||
"agentGraphId" TEXT NOT NULL,
|
||||
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
|
||||
"userId" TEXT NOT NULL,
|
||||
CONSTRAINT "AgentGraphExecution_agentGraphId_agentGraphVersion_fkey"
|
||||
FOREIGN KEY ("agentGraphId", "agentGraphVersion")
|
||||
REFERENCES "AgentGraph" ("id", "version")
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
CONSTRAINT "AgentGraphExecution_userId_fkey"
|
||||
FOREIGN KEY ("userId")
|
||||
REFERENCES "User" ("id")
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_AgentGraphExecution" ("agentGraphId", "agentGraphVersion", "id", "userId")
|
||||
SELECT "agentGraphId",
|
||||
"agentGraphVersion",
|
||||
"id",
|
||||
CASE WHEN "userId" IS NULL THEN '3e53486c-cf57-477e-ba2a-cb02dc828e1a' ELSE "userId" END
|
||||
FROM "AgentGraphExecution";
|
||||
DROP TABLE "AgentGraphExecution";
|
||||
ALTER TABLE "new_AgentGraphExecution" RENAME TO "AgentGraphExecution";
|
||||
|
||||
CREATE TABLE "new_AgentGraph" (
|
||||
"id" TEXT NOT NULL,
|
||||
"version" INTEGER NOT NULL DEFAULT 1,
|
||||
"name" TEXT,
|
||||
"description" TEXT,
|
||||
"isActive" BOOLEAN NOT NULL DEFAULT true,
|
||||
"isTemplate" BOOLEAN NOT NULL DEFAULT false,
|
||||
"userId" TEXT NOT NULL,
|
||||
"agentGraphParentId" TEXT,
|
||||
PRIMARY KEY ("id", "version"),
|
||||
CONSTRAINT "AgentGraph_userId_fkey"
|
||||
FOREIGN KEY ("userId")
|
||||
REFERENCES "User" ("id")
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey"
|
||||
FOREIGN KEY ("agentGraphParentId", "version")
|
||||
REFERENCES "AgentGraph" ("id", "version")
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_AgentGraph" ("agentGraphParentId", "description", "id", "isActive", "isTemplate", "name", "userId", "version")
|
||||
SELECT "agentGraphParentId",
|
||||
"description",
|
||||
"id",
|
||||
"isActive",
|
||||
"isTemplate",
|
||||
"name",
|
||||
CASE WHEN "userId" IS NULL THEN '3e53486c-cf57-477e-ba2a-cb02dc828e1a' ELSE "userId" END,
|
||||
"version"
|
||||
FROM "AgentGraph";
|
||||
DROP TABLE "AgentGraph";
|
||||
ALTER TABLE "new_AgentGraph" RENAME TO "AgentGraph";
|
||||
|
||||
CREATE TABLE "new_AgentGraphExecutionSchedule" (
|
||||
"id" TEXT NOT NULL PRIMARY KEY,
|
||||
"agentGraphId" TEXT NOT NULL,
|
||||
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
|
||||
"schedule" TEXT NOT NULL,
|
||||
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
|
||||
"inputData" TEXT NOT NULL,
|
||||
"lastUpdated" DATETIME NOT NULL,
|
||||
"userId" TEXT NOT NULL,
|
||||
CONSTRAINT "AgentGraphExecutionSchedule_agentGraphId_agentGraphVersion_fkey"
|
||||
FOREIGN KEY ("agentGraphId", "agentGraphVersion")
|
||||
REFERENCES "AgentGraph" ("id", "version")
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE,
|
||||
CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey"
|
||||
FOREIGN KEY ("userId")
|
||||
REFERENCES "User" ("id")
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_AgentGraphExecutionSchedule" ("agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule", "userId")
|
||||
SELECT "agentGraphId",
|
||||
"agentGraphVersion",
|
||||
"id",
|
||||
"inputData",
|
||||
"isEnabled",
|
||||
"lastUpdated",
|
||||
"schedule",
|
||||
CASE WHEN "userId" IS NULL THEN '3e53486c-cf57-477e-ba2a-cb02dc828e1a' ELSE "userId" END
|
||||
FROM "AgentGraphExecutionSchedule";
|
||||
DROP TABLE "AgentGraphExecutionSchedule";
|
||||
ALTER TABLE "new_AgentGraphExecutionSchedule" RENAME TO "AgentGraphExecutionSchedule";
|
||||
CREATE INDEX "AgentGraphExecutionSchedule_isEnabled_idx" ON "AgentGraphExecutionSchedule"("isEnabled");
|
||||
|
||||
PRAGMA foreign_key_check;
|
||||
PRAGMA foreign_keys=ON;
|
||||
@@ -0,0 +1,25 @@
|
||||
-- Update existing entries with NULL userId
|
||||
UPDATE "AgentGraph" SET "userId" = '3e53486c-cf57-477e-ba2a-cb02dc828e1a' WHERE "userId" IS NULL;
|
||||
UPDATE "AgentGraphExecution" SET "userId" = '3e53486c-cf57-477e-ba2a-cb02dc828e1a' WHERE "userId" IS NULL;
|
||||
UPDATE "AgentGraphExecutionSchedule" SET "userId" = '3e53486c-cf57-477e-ba2a-cb02dc828e1a' WHERE "userId" IS NULL;
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraph" ALTER COLUMN "userId" SET NOT NULL;
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraphExecution" ALTER COLUMN "userId" SET NOT NULL;
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraphExecutionSchedule" ALTER COLUMN "userId" SET NOT NULL;
|
||||
|
||||
-- AlterForeignKey
|
||||
ALTER TABLE "AgentGraph" DROP CONSTRAINT "AgentGraph_userId_fkey";
|
||||
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AlterForeignKey
|
||||
ALTER TABLE "AgentGraphExecution" DROP CONSTRAINT "AgentGraphExecution_userId_fkey";
|
||||
ALTER TABLE "AgentGraphExecution" ADD CONSTRAINT "AgentGraphExecution_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AlterForeignKey
|
||||
ALTER TABLE "AgentGraphExecutionSchedule" DROP CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey";
|
||||
ALTER TABLE "AgentGraphExecutionSchedule" ADD CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
@@ -38,8 +38,8 @@ model AgentGraph {
|
||||
isTemplate Boolean @default(false)
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id])
|
||||
|
||||
AgentNodes AgentNode[]
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
@@ -122,8 +122,8 @@ model AgentGraphExecution {
|
||||
AgentNodeExecutions AgentNodeExecution[]
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id])
|
||||
}
|
||||
|
||||
// This model describes the execution of an AgentNode.
|
||||
@@ -184,8 +184,8 @@ model AgentGraphExecutionSchedule {
|
||||
lastUpdated DateTime @updatedAt
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id])
|
||||
|
||||
@@index([isEnabled])
|
||||
}
|
||||
|
||||
@@ -37,8 +37,8 @@ model AgentGraph {
|
||||
isTemplate Boolean @default(false)
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id])
|
||||
|
||||
AgentNodes AgentNode[]
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
@@ -121,8 +121,8 @@ model AgentGraphExecution {
|
||||
AgentNodeExecutions AgentNodeExecution[]
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id])
|
||||
}
|
||||
|
||||
// This model describes the execution of an AgentNode.
|
||||
@@ -183,8 +183,8 @@ model AgentGraphExecutionSchedule {
|
||||
lastUpdated DateTime @updatedAt
|
||||
|
||||
// Link to User model
|
||||
userId String?
|
||||
user User? @relation(fields: [userId], references: [id])
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id])
|
||||
|
||||
@@index([isEnabled])
|
||||
}
|
||||
Reference in New Issue
Block a user