mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
testing preset functionality
This commit is contained in:
@@ -16,6 +16,7 @@ import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.library.routes
|
||||
import backend.server.v2.store.routes
|
||||
import backend.util.service
|
||||
@@ -154,5 +155,58 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
async def test_delete_graph(graph_id: str, user_id: str):
|
||||
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
|
||||
return await backend.server.v2.library.routes.presets.get_presets(
|
||||
user_id=user_id, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_preset(preset_id: str, user_id: str):
|
||||
return await backend.server.v2.library.routes.presets.get_preset(
|
||||
preset_id=preset_id, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_create_preset(
|
||||
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.server.v2.library.routes.presets.create_preset(
|
||||
preset=preset, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_update_preset(
|
||||
preset_id: str,
|
||||
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.server.v2.library.routes.presets.update_preset(
|
||||
preset_id=preset_id, preset=preset, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_delete_preset(preset_id: str, user_id: str):
|
||||
return await backend.server.v2.library.routes.presets.delete_preset(
|
||||
preset_id=preset_id, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_preset(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
preset_id: str,
|
||||
node_input: dict[typing.Any, typing.Any],
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.server.v2.library.routes.presets.execute_preset(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
preset_id=preset_id,
|
||||
node_input=node_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
app.dependency_overrides.update(overrides)
|
||||
|
||||
@@ -276,7 +276,7 @@ async def get_preset(
|
||||
) -> backend.server.v2.library.model.LibraryAgentPreset | None:
|
||||
try:
|
||||
preset = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id, "userId": user_id}
|
||||
where={"id": preset_id, "userId": user_id}, include={"InputPresets": True}
|
||||
)
|
||||
if not preset:
|
||||
return None
|
||||
@@ -294,36 +294,32 @@ async def create_or_update_preset(
|
||||
preset_id: str | None = None,
|
||||
) -> backend.server.v2.library.model.LibraryAgentPreset:
|
||||
try:
|
||||
|
||||
logger.info(f"DB Creating Preset with inputs: {preset.inputs}")
|
||||
new_preset = await prisma.models.AgentPreset.prisma().upsert(
|
||||
where={
|
||||
"id": preset_id if preset_id else "",
|
||||
},
|
||||
data={
|
||||
"create": prisma.types.AgentPresetCreateInput(
|
||||
userId=user_id,
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
agentId=preset.agent_id,
|
||||
agentVersion=preset.agent_version,
|
||||
Agent=prisma.types.AgentGraphUpdateOneWithoutRelationsInput(
|
||||
connect=prisma.types.AgentGraphWhereUniqueInput(
|
||||
id=preset.agent_id,
|
||||
version=preset.agent_version,
|
||||
),
|
||||
),
|
||||
isActive=preset.is_active,
|
||||
InputPresets={
|
||||
"create": {
|
||||
"userId": user_id,
|
||||
"name": preset.name,
|
||||
"description": preset.description,
|
||||
"agentId": preset.agent_id,
|
||||
"agentVersion": preset.agent_version,
|
||||
"isActive": preset.is_active,
|
||||
"InputPresets": {
|
||||
"create": [
|
||||
{"name": name, "data": json.dumps(data)}
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
},
|
||||
),
|
||||
"update": prisma.types.AgentPresetUpdateInput(
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
isActive=preset.is_active,
|
||||
),
|
||||
},
|
||||
"update": {
|
||||
"name": preset.name,
|
||||
"description": preset.description,
|
||||
"isActive": preset.is_active,
|
||||
},
|
||||
},
|
||||
)
|
||||
return backend.server.v2.library.model.LibraryAgentPreset.from_db(new_preset)
|
||||
|
||||
@@ -76,7 +76,7 @@ class LibraryAgentPreset(pydantic.BaseModel):
|
||||
description: str
|
||||
|
||||
is_active: bool
|
||||
inputs: dict[str, backend.data.block.BlockInput]
|
||||
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
|
||||
|
||||
@staticmethod
|
||||
def from_db(preset: prisma.models.AgentPreset):
|
||||
@@ -105,7 +105,7 @@ class LibraryAgentPresetResponse(pydantic.BaseModel):
|
||||
class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
inputs: dict[str, backend.data.block.BlockInput]
|
||||
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
is_active: bool
|
||||
|
||||
@@ -132,16 +132,24 @@ async def execute_preset(
|
||||
if not preset:
|
||||
raise fastapi.HTTPException(status_code=404, detail="Preset not found")
|
||||
|
||||
merged_input = {**preset.inputs, **node_input}
|
||||
logger.info(f"Preset inputs: {preset.inputs}")
|
||||
|
||||
updated_node_input = node_input.copy()
|
||||
# Merge in preset input values
|
||||
for key, value in preset.inputs.items():
|
||||
if key not in updated_node_input:
|
||||
updated_node_input[key] = value
|
||||
|
||||
execution = execution_manager_client().add_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
data=merged_input,
|
||||
data=updated_node_input,
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
logger.info(f"Execution added: {execution} with input: {updated_node_input}")
|
||||
|
||||
return {"id": execution.graph_exec_id}
|
||||
except Exception as e:
|
||||
msg = e.__str__().encode().decode("unicode_escape")
|
||||
|
||||
@@ -3,7 +3,8 @@ import logging
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
|
||||
import backend.server.v2.library.model
|
||||
from backend.blocks.basic import AgentInputBlock, FindInDictionaryBlock, StoreValueBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.server.model import CreateGraph
|
||||
@@ -287,3 +288,190 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
|
||||
assert exec_data.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec_data.output_data == {"result": [9]}
|
||||
logger.info("Completed test_static_input_link_on_graph")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_execute_preset(server: SpinTestServer):
|
||||
"""
|
||||
Test executing a preset.
|
||||
|
||||
This test ensures that:
|
||||
1. A preset can be successfully executed
|
||||
2. The execution results are correct
|
||||
|
||||
Args:
|
||||
server (SpinTestServer): The test server instance.
|
||||
"""
|
||||
# Create test graph and user
|
||||
nodes = [
|
||||
graph.Node( # 0
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "dictionary"},
|
||||
),
|
||||
graph.Node( # 1
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "selected_value"},
|
||||
),
|
||||
graph.Node( # 2
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
|
||||
),
|
||||
graph.Node( # 3
|
||||
block_id=FindInDictionaryBlock().id,
|
||||
input_default={"key": "", "input": {}},
|
||||
),
|
||||
]
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="result",
|
||||
sink_name="input",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[1].id,
|
||||
sink_id=nodes[3].id,
|
||||
source_name="result",
|
||||
sink_name="key",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[2].id,
|
||||
sink_id=nodes[3].id,
|
||||
source_name="output",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
inputs={
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
},
|
||||
is_active=True,
|
||||
)
|
||||
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||
|
||||
# Execute preset with overriding values
|
||||
result = await server.agent_server.test_execute_preset(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
preset_id=created_preset.id,
|
||||
node_input={},
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
# Verify execution
|
||||
assert result is not None
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||
assert len(executions) == 4
|
||||
|
||||
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
|
||||
assert executions[3].status == execution.ExecutionStatus.COMPLETED
|
||||
assert executions[3].output_data == {"output": ["World"]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
"""
|
||||
Test executing a preset with clashing input data.
|
||||
"""
|
||||
# Create test graph and user
|
||||
nodes = [
|
||||
graph.Node( # 0
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "dictionary"},
|
||||
),
|
||||
graph.Node( # 1
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "selected_value"},
|
||||
),
|
||||
graph.Node( # 2
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
|
||||
),
|
||||
graph.Node( # 3
|
||||
block_id=FindInDictionaryBlock().id,
|
||||
input_default={"key": "", "input": {}},
|
||||
),
|
||||
]
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="result",
|
||||
sink_name="input",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[1].id,
|
||||
sink_id=nodes[3].id,
|
||||
source_name="result",
|
||||
sink_name="key",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[2].id,
|
||||
sink_id=nodes[3].id,
|
||||
source_name="output",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
inputs={
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
},
|
||||
is_active=True,
|
||||
)
|
||||
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||
|
||||
# Execute preset with overriding values
|
||||
result = await server.agent_server.test_execute_preset(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
preset_id=created_preset.id,
|
||||
node_input={"selected_value": "key1"},
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
# Verify execution
|
||||
assert result is not None
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||
assert len(executions) == 4
|
||||
|
||||
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
|
||||
assert executions[3].status == execution.ExecutionStatus.COMPLETED
|
||||
assert executions[3].output_data == {"output": ["Hello"]}
|
||||
|
||||
Reference in New Issue
Block a user