mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
add library management functions
This commit is contained in:
@@ -13,6 +13,7 @@ from typing_extensions import Optional, TypedDict
|
||||
import backend.data.block
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import (
|
||||
@@ -180,11 +181,6 @@ async def get_graph(
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(
|
||||
path="/templates/{graph_id}/versions",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graph_all_versions(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
@@ -223,6 +219,13 @@ async def do_create_graph(
|
||||
400, detail=f"Template #{create_graph.template_id} not found"
|
||||
)
|
||||
graph.version = 1
|
||||
|
||||
# Create a library agent for the new graph
|
||||
await backend.server.v2.library.db.create_library_agent(
|
||||
graph.id,
|
||||
graph.version,
|
||||
user_id,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either graph or template_id must be provided."
|
||||
@@ -257,11 +260,6 @@ async def delete_graph(
|
||||
@v1_router.put(
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
@v1_router.put(
|
||||
path="/templates/{graph_id}",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def update_graph(
|
||||
graph_id: str,
|
||||
graph: graph_db.Graph,
|
||||
@@ -294,6 +292,11 @@ async def update_graph(
|
||||
|
||||
if new_graph_version.is_active:
|
||||
|
||||
# Keep the library agent up to date with the new active version
|
||||
await backend.server.v2.library.db.update_agent_version_in_library(
|
||||
user_id, graph.id, graph.version
|
||||
)
|
||||
|
||||
def get_credentials(credentials_id: str) -> "Credentials | None":
|
||||
return integration_creds_manager.get(user_id, credentials_id)
|
||||
|
||||
@@ -349,6 +352,12 @@ async def set_graph_active_version(
|
||||
version=new_active_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Keep the library agent up to date with the new active version
|
||||
await backend.server.v2.library.db.update_agent_version_in_library(
|
||||
user_id, new_active_graph.id, new_active_graph.version
|
||||
)
|
||||
|
||||
if current_active_graph and current_active_graph.version != new_active_version:
|
||||
# Handle deactivation of the previously active version
|
||||
await on_graph_deactivate(
|
||||
@@ -425,47 +434,6 @@ async def get_graph_run_node_execution_results(
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Templates ########################
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/templates",
|
||||
tags=["graphs", "templates"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_templates(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
return await graph_db.get_graphs(filter_by="template", user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/templates/{graph_id}",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_template(
|
||||
graph_id: str, version: int | None = None
|
||||
) -> graph_db.GraphModel:
|
||||
graph = await graph_db.get_graph(graph_id, version)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
|
||||
return graph
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/templates",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def create_new_template(
|
||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.GraphModel:
|
||||
return await do_create_graph(create_graph, user_id=user_id)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Schedules ########################
|
||||
########################################################
|
||||
|
||||
@@ -17,24 +17,51 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_library_agents(
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
) -> List[backend.server.v2.library.model.LibraryAgent]:
|
||||
"""
|
||||
Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table)
|
||||
Returns all agents (AgentGraph) that belong to the user and all agents in their library (LibraryAgent table)
|
||||
"""
|
||||
logger.debug(f"Getting library agents for user {user_id}")
|
||||
logger.debug(
|
||||
f"Getting library agents for user {user_id} with search query: {search_query}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get agents created by user with nodes and links
|
||||
user_created = await prisma.models.AgentGraph.prisma().find_many(
|
||||
where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True),
|
||||
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
|
||||
# Sanitize and validate search query by escaping special characters
|
||||
# Build where clause with sanitized inputs
|
||||
where_clause = prisma.types.LibraryAgentWhereInput(
|
||||
userId=user_id,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
**(
|
||||
{
|
||||
"OR": [
|
||||
{
|
||||
"Agent": {
|
||||
"name": {
|
||||
"contains": search_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"Agent": {
|
||||
"description": {
|
||||
"contains": search_query,
|
||||
"mode": "insensitive",
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
if search_query
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
# Get agents in user's library with nodes and links
|
||||
library_agents = await prisma.models.UserAgent.prisma().find_many(
|
||||
where=prisma.types.UserAgentWhereInput(
|
||||
userId=user_id, isDeleted=False, isArchived=False
|
||||
),
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include={
|
||||
"Agent": {
|
||||
"include": {
|
||||
@@ -42,26 +69,17 @@ async def get_library_agents(
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"Webhook": True,
|
||||
"AgentBlock": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"AgentPreset": {"include": {"InputPresets": True}},
|
||||
},
|
||||
order=[{"updatedAt": "desc"}],
|
||||
)
|
||||
|
||||
# Convert to Graph models first
|
||||
graphs = []
|
||||
|
||||
# Add user created agents
|
||||
for agent in user_created:
|
||||
try:
|
||||
graphs.append(backend.data.graph.GraphModel.from_db(agent))
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing user created agent {agent.id}: {e}")
|
||||
continue
|
||||
|
||||
# Add library agents
|
||||
for agent in library_agents:
|
||||
if agent.Agent:
|
||||
@@ -71,21 +89,10 @@ async def get_library_agents(
|
||||
logger.error(f"Error processing library agent {agent.agentId}: {e}")
|
||||
continue
|
||||
|
||||
# Convert Graph models to LibraryAgent models
|
||||
result = []
|
||||
for graph in graphs:
|
||||
result.append(
|
||||
backend.server.v2.library.model.LibraryAgent(
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.is_active,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
isCreatedByUser=any(a.id == graph.id for a in user_created),
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
)
|
||||
)
|
||||
result = [
|
||||
backend.server.v2.library.model.LibraryAgent.from_db(agent)
|
||||
for agent in library_agents
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(result)} library agents")
|
||||
return result
|
||||
@@ -99,7 +106,7 @@ async def get_library_agents(
|
||||
|
||||
async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None:
|
||||
"""
|
||||
Finds the agent from the store listing version and adds it to the user's library (UserAgent table)
|
||||
Finds the agent from the store listing version and adds it to the user's library (LibraryAgent table)
|
||||
if they don't already have it
|
||||
"""
|
||||
logger.debug(
|
||||
@@ -133,7 +140,7 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
|
||||
)
|
||||
|
||||
# Check if user already has this agent
|
||||
existing_user_agent = await prisma.models.UserAgent.prisma().find_first(
|
||||
existing_user_agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentId": agent.id,
|
||||
@@ -147,9 +154,9 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
|
||||
)
|
||||
return
|
||||
|
||||
# Create UserAgent entry
|
||||
await prisma.models.UserAgent.prisma().create(
|
||||
data=prisma.types.UserAgentCreateInput(
|
||||
# Create LibraryAgent entry
|
||||
await prisma.models.LibraryAgent.prisma().create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId=user_id,
|
||||
agentId=agent.id,
|
||||
agentVersion=agent.version,
|
||||
|
||||
@@ -37,7 +37,7 @@ async def test_get_library_agents(mocker):
|
||||
]
|
||||
|
||||
mock_library_agents = [
|
||||
prisma.models.UserAgent(
|
||||
prisma.models.LibraryAgent(
|
||||
id="ua1",
|
||||
userId="test-user",
|
||||
agentId="agent2",
|
||||
@@ -48,6 +48,7 @@ async def test_get_library_agents(mocker):
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
Agent=prisma.models.AgentGraph(
|
||||
id="agent2",
|
||||
version=1,
|
||||
@@ -67,8 +68,8 @@ async def test_get_library_agents(mocker):
|
||||
return_value=mock_user_created
|
||||
)
|
||||
|
||||
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
|
||||
mock_user_agent.return_value.find_many = mocker.AsyncMock(
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
|
||||
@@ -80,19 +81,19 @@ async def test_get_library_agents(mocker):
|
||||
assert result[0].id == "agent1"
|
||||
assert result[0].name == "Test Agent 1"
|
||||
assert result[0].description == "Test Description 1"
|
||||
assert result[0].isCreatedByUser is True
|
||||
assert result[0].is_created_by_user is True
|
||||
assert result[1].id == "agent2"
|
||||
assert result[1].name == "Test Agent 2"
|
||||
assert result[1].description == "Test Description 2"
|
||||
assert result[1].isCreatedByUser is False
|
||||
assert result[1].is_created_by_user is False
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_agent_graph.return_value.find_many.assert_called_once_with(
|
||||
where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True),
|
||||
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
mock_user_agent.return_value.find_many.assert_called_once_with(
|
||||
where=prisma.types.UserAgentWhereInput(
|
||||
mock_library_agent.return_value.find_many.assert_called_once_with(
|
||||
where=prisma.types.LibraryAgentWhereInput(
|
||||
userId="test-user", isDeleted=False, isArchived=False
|
||||
),
|
||||
include={
|
||||
@@ -152,26 +153,26 @@ async def test_add_agent_to_library(mocker):
|
||||
return_value=mock_store_listing
|
||||
)
|
||||
|
||||
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
|
||||
mock_user_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_user_agent.return_value.create = mocker.AsyncMock()
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock()
|
||||
|
||||
# Call function
|
||||
await db.add_agent_to_library("version123", "test-user")
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"Agent": True}
|
||||
)
|
||||
mock_user_agent.return_value.find_first.assert_called_once_with(
|
||||
mock_library_agent.return_value.find_first.assert_called_once_with(
|
||||
where={
|
||||
"userId": "test-user",
|
||||
"agentId": "agent1",
|
||||
"agentVersion": 1,
|
||||
}
|
||||
)
|
||||
mock_user_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.UserAgentCreateInput(
|
||||
mock_library_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
|
||||
)
|
||||
)
|
||||
@@ -189,7 +190,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
|
||||
await db.add_agent_to_library("version123", "test-user")
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
|
||||
@@ -6,21 +6,64 @@ import prisma.models
|
||||
import pydantic
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.graph
|
||||
import backend.server.model
|
||||
|
||||
|
||||
class LibraryAgent(pydantic.BaseModel):
|
||||
id: str # Changed from agent_id to match GraphMeta
|
||||
version: int # Changed from agent_version to match GraphMeta
|
||||
is_active: bool # Added to match GraphMeta
|
||||
|
||||
agent_id: str
|
||||
agent_version: int # Changed from agent_version to match GraphMeta
|
||||
|
||||
preset_id: str | None
|
||||
|
||||
updated_at: datetime.datetime
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
isCreatedByUser: bool
|
||||
# Made input_schema and output_schema match GraphMeta's type
|
||||
input_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
output_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
|
||||
is_favorite: bool
|
||||
is_created_by_user: bool
|
||||
|
||||
is_latest_version: bool
|
||||
|
||||
@staticmethod
|
||||
def from_db(agent: prisma.models.LibraryAgent):
|
||||
if not agent.Agent:
|
||||
raise ValueError("AgentGraph is required")
|
||||
|
||||
graph = backend.data.graph.GraphModel.from_db(agent.Agent)
|
||||
|
||||
agent_updated_at = agent.Agent.updatedAt
|
||||
lib_agent_updated_at = agent.updatedAt
|
||||
|
||||
# Take the latest updated_at timestamp either when the graph was updated or the library agent was updated
|
||||
updated_at = (
|
||||
max(agent_updated_at, lib_agent_updated_at)
|
||||
if agent_updated_at
|
||||
else lib_agent_updated_at
|
||||
)
|
||||
|
||||
return LibraryAgent(
|
||||
id=agent.id,
|
||||
agent_id=agent.agentId,
|
||||
agent_version=agent.agentVersion,
|
||||
updated_at=updated_at,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
is_favorite=agent.isFavorite,
|
||||
is_created_by_user=agent.isCreatedByUser,
|
||||
is_latest_version=graph.is_active,
|
||||
preset_id=agent.AgentPreset.id if agent.AgentPreset else None,
|
||||
)
|
||||
|
||||
|
||||
class LibraryAgentPreset(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
@@ -10,20 +10,26 @@ import backend.server.v2.library.model
|
||||
def test_library_agent():
|
||||
agent = backend.server.v2.library.model.LibraryAgent(
|
||||
id="test-agent-123",
|
||||
version=1,
|
||||
is_active=True,
|
||||
agent_id="agent-123",
|
||||
agent_version=1,
|
||||
preset_id=None,
|
||||
updated_at=datetime.datetime.now(),
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
isCreatedByUser=False,
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
is_favorite=False,
|
||||
is_created_by_user=False,
|
||||
is_latest_version=True,
|
||||
)
|
||||
assert agent.id == "test-agent-123"
|
||||
assert agent.version == 1
|
||||
assert agent.is_active is True
|
||||
assert agent.agent_id == "agent-123"
|
||||
assert agent.agent_version == 1
|
||||
assert agent.name == "Test Agent"
|
||||
assert agent.description == "Test description"
|
||||
assert agent.isCreatedByUser is False
|
||||
assert agent.is_favorite is False
|
||||
assert agent.is_created_by_user is False
|
||||
assert agent.is_latest_version is True
|
||||
assert agent.input_schema == {"type": "object", "properties": {}}
|
||||
assert agent.output_schema == {"type": "object", "properties": {}}
|
||||
|
||||
@@ -31,20 +37,26 @@ def test_library_agent():
|
||||
def test_library_agent_with_user_created():
|
||||
agent = backend.server.v2.library.model.LibraryAgent(
|
||||
id="user-agent-456",
|
||||
version=2,
|
||||
is_active=True,
|
||||
agent_id="agent-456",
|
||||
agent_version=2,
|
||||
preset_id=None,
|
||||
updated_at=datetime.datetime.now(),
|
||||
name="User Created Agent",
|
||||
description="An agent created by the user",
|
||||
isCreatedByUser=True,
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
is_favorite=False,
|
||||
is_created_by_user=True,
|
||||
is_latest_version=True,
|
||||
)
|
||||
assert agent.id == "user-agent-456"
|
||||
assert agent.version == 2
|
||||
assert agent.is_active is True
|
||||
assert agent.agent_id == "agent-456"
|
||||
assert agent.agent_version == 2
|
||||
assert agent.name == "User Created Agent"
|
||||
assert agent.description == "An agent created by the user"
|
||||
assert agent.isCreatedByUser is True
|
||||
assert agent.is_favorite is False
|
||||
assert agent.is_created_by_user is True
|
||||
assert agent.is_latest_version is True
|
||||
assert agent.input_schema == {"type": "object", "properties": {}}
|
||||
assert agent.output_schema == {"type": "object", "properties": {}}
|
||||
|
||||
@@ -71,7 +83,11 @@ def test_library_agent_preset():
|
||||
assert preset.agent_id == "test-agent-123"
|
||||
assert preset.agent_version == 1
|
||||
assert preset.is_active is True
|
||||
assert preset.inputs == {"input1": "test value"}
|
||||
assert preset.inputs == {
|
||||
"input1": backend.data.block.BlockInput(
|
||||
name="input1", data={"type": "string", "value": "test value"}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def test_library_agent_preset_response():
|
||||
@@ -127,7 +143,11 @@ def test_create_library_agent_preset_request():
|
||||
assert request.agent_id == "agent-123"
|
||||
assert request.agent_version == 1
|
||||
assert request.is_active is True
|
||||
assert request.inputs == {"input1": "test value"}
|
||||
assert request.inputs == {
|
||||
"input1": backend.data.block.BlockInput(
|
||||
name="input1", data={"type": "string", "value": "test value"}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def test_library_agent_from_db():
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
import fastapi
|
||||
|
||||
import backend.server.v2.library.routes.agents
|
||||
import backend.server.v2.library.routes.presets
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
router.include_router(backend.server.v2.library.routes.presets.router)
|
||||
router.include_router(backend.server.v2.library.routes.agents.router)
|
||||
@@ -0,0 +1,148 @@
|
||||
import logging
|
||||
import typing
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
import autogpt_libs.auth.middleware
|
||||
import autogpt_libs.utils.cache
|
||||
import fastapi
|
||||
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.exceptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
async def get_library_agents(
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
]
|
||||
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
|
||||
"""
|
||||
Get all agents in the user's library, including both created and saved agents.
|
||||
"""
|
||||
try:
|
||||
agents = await backend.server.v2.library.db.get_library_agents(user_id)
|
||||
return agents
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred whilst getting library agents: {e}")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to get library agents"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/{store_listing_version_id}",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
status_code=201,
|
||||
)
|
||||
async def add_agent_to_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
) -> fastapi.Response:
|
||||
"""
|
||||
Add an agent from the store to the user's library.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): ID of the store listing version to add
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
fastapi.Response: 201 status code on success
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error adding the agent to the library
|
||||
"""
|
||||
try:
|
||||
# Use the database function to add the agent to the library
|
||||
await backend.server.v2.library.db.add_store_agent_to_library(
|
||||
store_listing_version_id, user_id
|
||||
)
|
||||
return fastapi.Response(status_code=201)
|
||||
|
||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
except backend.server.v2.store.exceptions.DatabaseError as e:
|
||||
logger.exception(f"Database error occurred whilst adding agent to library: {e}")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to add agent to library"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unexpected exception occurred whilst adding agent to library: {e}"
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to add agent to library"
|
||||
)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/agents/{library_agent_id}",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
status_code=204,
|
||||
)
|
||||
async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
auto_update_version: bool = False,
|
||||
is_favorite: bool = False,
|
||||
is_archived: bool = False,
|
||||
is_deleted: bool = False,
|
||||
) -> fastapi.Response:
|
||||
"""
|
||||
Update the library agent with the given fields.
|
||||
|
||||
Args:
|
||||
library_agent_id (str): ID of the library agent to update
|
||||
user_id (str): ID of the authenticated user
|
||||
auto_update_version (bool): Whether to auto-update the agent version
|
||||
is_favorite (bool): Whether the agent is marked as favorite
|
||||
is_archived (bool): Whether the agent is archived
|
||||
is_deleted (bool): Whether the agent is deleted
|
||||
|
||||
Returns:
|
||||
fastapi.Response: 204 status code on success
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error updating the library agent
|
||||
"""
|
||||
try:
|
||||
# Use the database function to update the library agent
|
||||
await backend.server.v2.library.db.update_library_agent(
|
||||
library_agent_id,
|
||||
user_id,
|
||||
auto_update_version,
|
||||
is_favorite,
|
||||
is_archived,
|
||||
is_deleted,
|
||||
)
|
||||
return fastapi.Response(status_code=204)
|
||||
|
||||
except backend.server.v2.store.exceptions.DatabaseError as e:
|
||||
logger.exception(f"Database error occurred whilst updating library agent: {e}")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to update library agent"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unexpected exception occurred whilst updating library agent: {e}"
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to update library agent"
|
||||
)
|
||||
@@ -5,7 +5,6 @@ import autogpt_libs.auth.depends
|
||||
import autogpt_libs.auth.middleware
|
||||
import autogpt_libs.utils.cache
|
||||
import fastapi
|
||||
import prisma
|
||||
|
||||
import backend.data.graph
|
||||
import backend.executor
|
||||
@@ -28,109 +27,6 @@ def execution_manager_client() -> backend.executor.ExecutionManager:
|
||||
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
async def get_library_agents(
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
]
|
||||
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
|
||||
"""
|
||||
Get all agents in the user's library, including both created and saved agents.
|
||||
"""
|
||||
try:
|
||||
agents = await backend.server.v2.library.db.get_library_agents(user_id)
|
||||
return agents
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred whilst getting library agents: {e}")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to get library agents"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/{store_listing_version_id}",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
status_code=201,
|
||||
)
|
||||
async def add_agent_to_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
) -> fastapi.Response:
|
||||
"""
|
||||
Add an agent from the store to the user's library.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): ID of the store listing version to add
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
fastapi.Response: 201 status code on success
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error adding the agent to the library
|
||||
"""
|
||||
try:
|
||||
# Get the graph from the store listing
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"Agent": True}
|
||||
)
|
||||
)
|
||||
|
||||
if not store_listing_version or not store_listing_version.Agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
agent = store_listing_version.Agent
|
||||
|
||||
if agent.userId == user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400, detail="Cannot add own agent to library"
|
||||
)
|
||||
|
||||
# Create a new graph from the template
|
||||
graph = await backend.data.graph.get_graph(
|
||||
agent.id, agent.version, user_id=user_id
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent {agent.id} not found"
|
||||
)
|
||||
|
||||
# Create a deep copy with new IDs
|
||||
graph.version = 1
|
||||
graph.is_template = False
|
||||
graph.is_active = True
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
|
||||
# Save the new graph
|
||||
graph = await backend.data.graph.create_graph(graph, user_id=user_id)
|
||||
graph = (
|
||||
await backend.integrations.webhooks.graph_lifecycle_hooks.on_graph_activate(
|
||||
graph,
|
||||
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
|
||||
)
|
||||
)
|
||||
|
||||
return fastapi.Response(status_code=201)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred whilst adding agent to library: {e}")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to add agent to library"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/presets")
|
||||
async def get_presets(
|
||||
user_id: typing.Annotated[
|
||||
@@ -1,3 +1,5 @@
|
||||
import datetime
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
import autogpt_libs.auth.middleware
|
||||
import fastapi
|
||||
@@ -35,21 +37,29 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
mocked_value = [
|
||||
backend.server.v2.library.model.LibraryAgent(
|
||||
id="test-agent-1",
|
||||
version=1,
|
||||
is_active=True,
|
||||
agent_id="test-agent-1",
|
||||
agent_version=1,
|
||||
preset_id="preset-1",
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
is_favorite=False,
|
||||
is_created_by_user=True,
|
||||
is_latest_version=True,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
isCreatedByUser=True,
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
),
|
||||
backend.server.v2.library.model.LibraryAgent(
|
||||
id="test-agent-2",
|
||||
version=1,
|
||||
is_active=True,
|
||||
agent_id="test-agent-2",
|
||||
agent_version=1,
|
||||
preset_id="preset-2",
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
is_favorite=False,
|
||||
is_created_by_user=False,
|
||||
is_latest_version=True,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
isCreatedByUser=False,
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
),
|
||||
@@ -65,10 +75,10 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||
for agent in response.json()
|
||||
]
|
||||
assert len(data) == 2
|
||||
assert data[0].id == "test-agent-1"
|
||||
assert data[0].isCreatedByUser is True
|
||||
assert data[1].id == "test-agent-2"
|
||||
assert data[1].isCreatedByUser is False
|
||||
assert data[0].agent_id == "test-agent-1"
|
||||
assert data[0].is_created_by_user is True
|
||||
assert data[1].agent_id == "test-agent-2"
|
||||
assert data[1].is_created_by_user is False
|
||||
mock_db_call.assert_called_once_with("test-user-id")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the `UserAgent` table. If the table is not empty, all the data it contains will be lost.
|
||||
|
||||
*/
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentId_agentVersion_fkey";
|
||||
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentPresetId_fkey";
|
||||
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_userId_fkey";
|
||||
|
||||
-- DropTable
|
||||
DROP TABLE "UserAgent";
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LibraryAgent" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT NOT NULL,
|
||||
"agentId" TEXT NOT NULL,
|
||||
"agentVersion" INTEGER NOT NULL,
|
||||
"agentPresetId" TEXT,
|
||||
"isFavorite" BOOLEAN NOT NULL DEFAULT false,
|
||||
"isCreatedByUser" BOOLEAN NOT NULL DEFAULT false,
|
||||
"isArchived" BOOLEAN NOT NULL DEFAULT false,
|
||||
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
|
||||
|
||||
CONSTRAINT "LibraryAgent_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LibraryAgent_userId_idx" ON "LibraryAgent"("userId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "LibraryAgent" ADD COLUMN "useGraphIsActiveVersion" BOOLEAN NOT NULL DEFAULT false;
|
||||
@@ -30,7 +30,7 @@ model User {
|
||||
CreditTransaction CreditTransaction[]
|
||||
|
||||
AgentPreset AgentPreset[]
|
||||
UserAgent UserAgent[]
|
||||
LibraryAgent LibraryAgent[]
|
||||
|
||||
Profile Profile[]
|
||||
StoreListing StoreListing[]
|
||||
@@ -65,7 +65,7 @@ model AgentGraph {
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
|
||||
AgentPreset AgentPreset[]
|
||||
UserAgent UserAgent[]
|
||||
LibraryAgent LibraryAgent[]
|
||||
StoreListing StoreListing[]
|
||||
StoreListingVersion StoreListingVersion?
|
||||
|
||||
@@ -103,7 +103,7 @@ model AgentPreset {
|
||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
||||
|
||||
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
||||
UserAgents UserAgent[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
AgentExecution AgentGraphExecution[]
|
||||
|
||||
isDeleted Boolean @default(false)
|
||||
@@ -113,7 +113,7 @@ model AgentPreset {
|
||||
|
||||
// For the library page
|
||||
// It is a user controlled list of agents, that they will see in there library
|
||||
model UserAgent {
|
||||
model LibraryAgent {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
@@ -128,6 +128,8 @@ model UserAgent {
|
||||
agentPresetId String?
|
||||
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
||||
|
||||
useGraphIsActiveVersion Boolean @default(false)
|
||||
|
||||
isFavorite Boolean @default(false)
|
||||
isCreatedByUser Boolean @default(false)
|
||||
isArchived Boolean @default(false)
|
||||
|
||||
@@ -140,10 +140,10 @@ async def main():
|
||||
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
|
||||
for user in users:
|
||||
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
|
||||
for _ in range(num_agents): # Create 1 UserAgent per user
|
||||
for _ in range(num_agents): # Create 1 LibraryAgent per user
|
||||
graph = random.choice(agent_graphs)
|
||||
preset = random.choice(agent_presets)
|
||||
user_agent = await db.useragent.create(
|
||||
user_agent = await db.libraryagent.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"agentId": graph.id,
|
||||
|
||||
Reference in New Issue
Block a user