mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): dual-write tenancy fields on create/update paths (PR3)
Update core create/update codepaths to write organizationId and orgWorkspaceId alongside legacy userId fields. Reads remain userId-based for backward compatibility. Updated functions: - graph.create_graph / __create_graph / fork_graph — accept and write organization_id + org_workspace_id to AgentGraphCreateInput - execution.create_graph_execution — accept and write tenancy fields - copilot/db.create_chat_session — accept and write tenancy fields - executor/utils.add_graph_execution — thread tenancy params through Updated routes (v1.py): - create_new_graph — resolves RequestContext, passes org/workspace IDs - update_graph — resolves RequestContext, passes org/workspace IDs - execute_graph — resolves RequestContext, passes to execution chain Test helpers in rest_api.py updated with synthetic RequestContext for backward compatibility. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -9,8 +9,9 @@ from typing import Annotated, Any, Sequence, get_args
|
||||
|
||||
import pydantic
|
||||
import stripe
|
||||
from autogpt_libs.auth import get_user_id, requires_user
|
||||
from autogpt_libs.auth import get_request_context, get_user_id, requires_user
|
||||
from autogpt_libs.auth.jwt_utils import get_jwt_payload
|
||||
from autogpt_libs.auth.models import RequestContext
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
@@ -814,12 +815,18 @@ async def get_graph_all_versions(
|
||||
async def create_new_graph(
|
||||
create_graph: CreateGraph,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
ctx: Annotated[RequestContext, Security(get_request_context)],
|
||||
) -> graph_db.GraphModel:
|
||||
graph = graph_db.make_graph_model(create_graph.graph, user_id)
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
await graph_db.create_graph(graph, user_id=user_id)
|
||||
await graph_db.create_graph(
|
||||
graph,
|
||||
user_id=user_id,
|
||||
organization_id=ctx.org_id,
|
||||
org_workspace_id=ctx.workspace_id,
|
||||
)
|
||||
await library_db.create_library_agent(graph, user_id)
|
||||
activated_graph = await on_graph_activate(graph, user_id=user_id)
|
||||
|
||||
@@ -856,6 +863,7 @@ async def update_graph(
|
||||
graph_id: str,
|
||||
graph: graph_db.Graph,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
ctx: Annotated[RequestContext, Security(get_request_context)],
|
||||
) -> graph_db.GraphModel:
|
||||
if graph.id and graph.id != graph_id:
|
||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||
@@ -871,7 +879,12 @@ async def update_graph(
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||
new_graph_version = await graph_db.create_graph(
|
||||
graph,
|
||||
user_id=user_id,
|
||||
organization_id=ctx.org_id,
|
||||
org_workspace_id=ctx.workspace_id,
|
||||
)
|
||||
|
||||
if new_graph_version.is_active:
|
||||
await library_db.update_library_agent_version_and_settings(
|
||||
@@ -973,6 +986,7 @@ async def update_graph_settings(
|
||||
async def execute_graph(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Security(get_user_id)],
|
||||
ctx: Annotated[RequestContext, Security(get_request_context)],
|
||||
inputs: Annotated[dict[str, Any], Body(..., embed=True, default_factory=dict)],
|
||||
credentials_inputs: Annotated[
|
||||
dict[str, CredentialsMetaInput], Body(..., embed=True, default_factory=dict)
|
||||
@@ -1000,6 +1014,8 @@ async def execute_graph(
|
||||
graph_version=graph_version,
|
||||
graph_credentials_inputs=credentials_inputs,
|
||||
dry_run=dry_run,
|
||||
organization_id=ctx.org_id,
|
||||
org_workspace_id=ctx.workspace_id,
|
||||
)
|
||||
# Record successful graph execution
|
||||
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
|
||||
|
||||
@@ -417,8 +417,22 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
graph_version: Optional[int] = None,
|
||||
node_input: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
from autogpt_libs.auth.models import RequestContext
|
||||
|
||||
ctx = RequestContext(
|
||||
user_id=user_id,
|
||||
org_id="test-org",
|
||||
workspace_id="test-workspace",
|
||||
is_org_owner=True,
|
||||
is_org_admin=True,
|
||||
is_org_billing_manager=False,
|
||||
is_workspace_admin=True,
|
||||
is_workspace_billing_manager=False,
|
||||
seat_status="ACTIVE",
|
||||
)
|
||||
return await backend.api.features.v1.execute_graph(
|
||||
user_id=user_id,
|
||||
ctx=ctx,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
inputs=node_input or {},
|
||||
@@ -441,7 +455,22 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
create_graph: backend.api.features.v1.CreateGraph,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.api.features.v1.create_new_graph(create_graph, user_id)
|
||||
from autogpt_libs.auth.models import RequestContext
|
||||
|
||||
ctx = RequestContext(
|
||||
user_id=user_id,
|
||||
org_id="test-org",
|
||||
workspace_id="test-workspace",
|
||||
is_org_owner=True,
|
||||
is_org_admin=True,
|
||||
is_org_billing_manager=False,
|
||||
is_workspace_admin=True,
|
||||
is_workspace_billing_manager=False,
|
||||
seat_status="ACTIVE",
|
||||
)
|
||||
return await backend.api.features.v1.create_new_graph(
|
||||
create_graph, user_id, ctx
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
|
||||
|
||||
@@ -35,6 +35,9 @@ async def get_chat_session(session_id: str) -> ChatSession | None:
|
||||
async def create_chat_session(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
*,
|
||||
organization_id: str | None = None,
|
||||
org_workspace_id: str | None = None,
|
||||
) -> ChatSessionInfo:
|
||||
"""Create a new chat session in the database."""
|
||||
data = ChatSessionCreateInput(
|
||||
@@ -43,6 +46,9 @@ async def create_chat_session(
|
||||
credentials=SafeJson({}),
|
||||
successfulAgentRuns=SafeJson({}),
|
||||
successfulAgentSchedules=SafeJson({}),
|
||||
# Tenancy dual-write fields
|
||||
**({"organizationId": organization_id} if organization_id else {}),
|
||||
**({"orgWorkspaceId": org_workspace_id} if org_workspace_id else {}),
|
||||
)
|
||||
prisma_session = await PrismaChatSession.prisma().create(data=data)
|
||||
return ChatSessionInfo.from_db(prisma_session)
|
||||
|
||||
@@ -730,6 +730,8 @@ async def create_graph_execution(
|
||||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||||
parent_graph_exec_id: Optional[str] = None,
|
||||
is_dry_run: bool = False,
|
||||
organization_id: Optional[str] = None,
|
||||
org_workspace_id: Optional[str] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
@@ -768,6 +770,9 @@ async def create_graph_execution(
|
||||
"agentPresetId": preset_id,
|
||||
"parentGraphExecutionId": parent_graph_exec_id,
|
||||
**({"stats": Json({"is_dry_run": True})} if is_dry_run else {}),
|
||||
# Tenancy dual-write fields
|
||||
**({"organizationId": organization_id} if organization_id else {}),
|
||||
**({"orgWorkspaceId": org_workspace_id} if org_workspace_id else {}),
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
@@ -1494,9 +1494,21 @@ async def is_graph_published_in_marketplace(graph_id: str, graph_version: int) -
|
||||
return marketplace_listing is not None
|
||||
|
||||
|
||||
async def create_graph(graph: Graph, user_id: str) -> GraphModel:
|
||||
async def create_graph(
|
||||
graph: Graph,
|
||||
user_id: str,
|
||||
*,
|
||||
organization_id: str | None = None,
|
||||
org_workspace_id: str | None = None,
|
||||
) -> GraphModel:
|
||||
async with transaction() as tx:
|
||||
await __create_graph(tx, graph, user_id)
|
||||
await __create_graph(
|
||||
tx,
|
||||
graph,
|
||||
user_id,
|
||||
organization_id=organization_id,
|
||||
org_workspace_id=org_workspace_id,
|
||||
)
|
||||
|
||||
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
|
||||
return created_graph
|
||||
@@ -1504,7 +1516,14 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
|
||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||
|
||||
|
||||
async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphModel:
|
||||
async def fork_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
user_id: str,
|
||||
*,
|
||||
organization_id: str | None = None,
|
||||
org_workspace_id: str | None = None,
|
||||
) -> GraphModel:
|
||||
"""
|
||||
Forks a graph by copying it and all its nodes and links to a new graph.
|
||||
"""
|
||||
@@ -1520,12 +1539,25 @@ async def fork_graph(graph_id: str, graph_version: int, user_id: str) -> GraphMo
|
||||
graph.validate_graph(for_run=False)
|
||||
|
||||
async with transaction() as tx:
|
||||
await __create_graph(tx, graph, user_id)
|
||||
await __create_graph(
|
||||
tx,
|
||||
graph,
|
||||
user_id,
|
||||
organization_id=organization_id,
|
||||
org_workspace_id=org_workspace_id,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
async def __create_graph(
|
||||
tx,
|
||||
graph: Graph,
|
||||
user_id: str,
|
||||
*,
|
||||
organization_id: str | None = None,
|
||||
org_workspace_id: str | None = None,
|
||||
):
|
||||
graphs = [graph] + graph.sub_graphs
|
||||
|
||||
# Auto-increment version for any graph entry (parent or sub-graph) whose
|
||||
@@ -1562,6 +1594,9 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
userId=user_id,
|
||||
forkedFromId=graph.forked_from_id,
|
||||
forkedFromVersion=graph.forked_from_version,
|
||||
# Tenancy dual-write fields
|
||||
organizationId=organization_id,
|
||||
orgWorkspaceId=org_workspace_id,
|
||||
)
|
||||
for graph in graphs
|
||||
]
|
||||
|
||||
@@ -869,6 +869,8 @@ async def add_graph_execution(
|
||||
execution_context: Optional[ExecutionContext] = None,
|
||||
graph_exec_id: Optional[str] = None,
|
||||
dry_run: bool = False,
|
||||
organization_id: Optional[str] = None,
|
||||
org_workspace_id: Optional[str] = None,
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Adds a graph execution to the queue and returns the execution entry.
|
||||
@@ -948,6 +950,8 @@ async def add_graph_execution(
|
||||
preset_id=preset_id,
|
||||
parent_graph_exec_id=parent_exec_id,
|
||||
is_dry_run=dry_run,
|
||||
organization_id=organization_id,
|
||||
org_workspace_id=org_workspace_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user