testing preset functionality

This commit is contained in:
SwiftyOS
2025-01-10 12:23:49 +01:00
parent 6007def168
commit 4d20e419e1
5 changed files with 272 additions and 26 deletions

View File

@@ -16,6 +16,7 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.server.v2.library.model
import backend.server.v2.library.routes
import backend.server.v2.store.routes
import backend.util.service
@@ -154,5 +155,58 @@ class AgentServer(backend.util.service.AppProcess):
async def test_delete_graph(graph_id: str, user_id: str):
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
@staticmethod
async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
return await backend.server.v2.library.routes.presets.get_presets(
user_id=user_id, page=page, page_size=page_size
)
@staticmethod
async def test_get_preset(preset_id: str, user_id: str):
return await backend.server.v2.library.routes.presets.get_preset(
preset_id=preset_id, user_id=user_id
)
@staticmethod
async def test_create_preset(
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.create_preset(
preset=preset, user_id=user_id
)
@staticmethod
async def test_update_preset(
preset_id: str,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.update_preset(
preset_id=preset_id, preset=preset, user_id=user_id
)
@staticmethod
async def test_delete_preset(preset_id: str, user_id: str):
return await backend.server.v2.library.routes.presets.delete_preset(
preset_id=preset_id, user_id=user_id
)
@staticmethod
async def test_execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
node_input: dict[typing.Any, typing.Any],
user_id: str,
):
return await backend.server.v2.library.routes.presets.execute_preset(
graph_id=graph_id,
graph_version=graph_version,
preset_id=preset_id,
node_input=node_input,
user_id=user_id,
)
def set_test_dependency_overrides(self, overrides: dict):
app.dependency_overrides.update(overrides)

View File

@@ -276,7 +276,7 @@ async def get_preset(
) -> backend.server.v2.library.model.LibraryAgentPreset | None:
try:
preset = await prisma.models.AgentPreset.prisma().find_unique(
where={"id": preset_id, "userId": user_id}
where={"id": preset_id, "userId": user_id}, include={"InputPresets": True}
)
if not preset:
return None
@@ -294,36 +294,32 @@ async def create_or_update_preset(
preset_id: str | None = None,
) -> backend.server.v2.library.model.LibraryAgentPreset:
try:
logger.info(f"DB Creating Preset with inputs: {preset.inputs}")
new_preset = await prisma.models.AgentPreset.prisma().upsert(
where={
"id": preset_id if preset_id else "",
},
data={
"create": prisma.types.AgentPresetCreateInput(
userId=user_id,
name=preset.name,
description=preset.description,
agentId=preset.agent_id,
agentVersion=preset.agent_version,
Agent=prisma.types.AgentGraphUpdateOneWithoutRelationsInput(
connect=prisma.types.AgentGraphWhereUniqueInput(
id=preset.agent_id,
version=preset.agent_version,
),
),
isActive=preset.is_active,
InputPresets={
"create": {
"userId": user_id,
"name": preset.name,
"description": preset.description,
"agentId": preset.agent_id,
"agentVersion": preset.agent_version,
"isActive": preset.is_active,
"InputPresets": {
"create": [
{"name": name, "data": json.dumps(data)}
for name, data in preset.inputs.items()
]
},
),
"update": prisma.types.AgentPresetUpdateInput(
name=preset.name,
description=preset.description,
isActive=preset.is_active,
),
},
"update": {
"name": preset.name,
"description": preset.description,
"isActive": preset.is_active,
},
},
)
return backend.server.v2.library.model.LibraryAgentPreset.from_db(new_preset)

View File

@@ -76,7 +76,7 @@ class LibraryAgentPreset(pydantic.BaseModel):
description: str
is_active: bool
inputs: dict[str, backend.data.block.BlockInput]
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
@staticmethod
def from_db(preset: prisma.models.AgentPreset):
@@ -105,7 +105,7 @@ class LibraryAgentPresetResponse(pydantic.BaseModel):
class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
name: str
description: str
inputs: dict[str, backend.data.block.BlockInput]
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
agent_id: str
agent_version: int
is_active: bool

View File

@@ -132,16 +132,24 @@ async def execute_preset(
if not preset:
raise fastapi.HTTPException(status_code=404, detail="Preset not found")
merged_input = {**preset.inputs, **node_input}
logger.info(f"Preset inputs: {preset.inputs}")
updated_node_input = node_input.copy()
# Merge in preset input values
for key, value in preset.inputs.items():
if key not in updated_node_input:
updated_node_input[key] = value
execution = execution_manager_client().add_execution(
graph_id=graph_id,
graph_version=graph_version,
data=merged_input,
data=updated_node_input,
user_id=user_id,
preset_id=preset_id,
)
logger.info(f"Execution added: {execution} with input: {updated_node_input}")
return {"id": execution.graph_exec_id}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")

View File

@@ -3,7 +3,8 @@ import logging
import pytest
from prisma.models import User
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
import backend.server.v2.library.model
from backend.blocks.basic import AgentInputBlock, FindInDictionaryBlock, StoreValueBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.server.model import CreateGraph
@@ -287,3 +288,190 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
assert exec_data.status == execution.ExecutionStatus.COMPLETED
assert exec_data.output_data == {"result": [9]}
logger.info("Completed test_static_input_link_on_graph")
@pytest.mark.asyncio(scope="session")
async def test_execute_preset(server: SpinTestServer):
"""
Test executing a preset.
This test ensures that:
1. A preset can be successfully executed
2. The execution results are correct
Args:
server (SpinTestServer): The test server instance.
"""
# Create test graph and user
nodes = [
graph.Node( # 0
block_id=AgentInputBlock().id,
input_default={"name": "dictionary"},
),
graph.Node( # 1
block_id=AgentInputBlock().id,
input_default={"name": "selected_value"},
),
graph.Node( # 2
block_id=StoreValueBlock().id,
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
),
graph.Node( # 3
block_id=FindInDictionaryBlock().id,
input_default={"key": "", "input": {}},
),
]
links = [
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="result",
sink_name="input",
),
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[3].id,
source_name="result",
sink_name="key",
),
graph.Link(
source_id=nodes[2].id,
sink_id=nodes[3].id,
source_name="output",
sink_name="input",
),
]
test_graph = graph.Graph(
name="TestGraph",
description="Test graph",
nodes=nodes,
links=links,
)
test_user = await create_test_user()
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
agent_id=test_graph.id,
agent_version=test_graph.version,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
is_active=True,
)
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
# Execute preset with overriding values
result = await server.agent_server.test_execute_preset(
graph_id=test_graph.id,
graph_version=test_graph.version,
preset_id=created_preset.id,
node_input={},
user_id=test_user.id,
)
# Verify execution
assert result is not None
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
assert len(executions) == 4
# FindInDictionaryBlock should wait for the input pin to be provided,
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
assert executions[3].status == execution.ExecutionStatus.COMPLETED
assert executions[3].output_data == {"output": ["World"]}
@pytest.mark.asyncio(scope="session")
async def test_execute_preset_with_clash(server: SpinTestServer):
"""
Test executing a preset with clashing input data.
"""
# Create test graph and user
nodes = [
graph.Node( # 0
block_id=AgentInputBlock().id,
input_default={"name": "dictionary"},
),
graph.Node( # 1
block_id=AgentInputBlock().id,
input_default={"name": "selected_value"},
),
graph.Node( # 2
block_id=StoreValueBlock().id,
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
),
graph.Node( # 3
block_id=FindInDictionaryBlock().id,
input_default={"key": "", "input": {}},
),
]
links = [
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="result",
sink_name="input",
),
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[3].id,
source_name="result",
sink_name="key",
),
graph.Link(
source_id=nodes[2].id,
sink_id=nodes[3].id,
source_name="output",
sink_name="input",
),
]
test_graph = graph.Graph(
name="TestGraph",
description="Test graph",
nodes=nodes,
links=links,
)
test_user = await create_test_user()
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
agent_id=test_graph.id,
agent_version=test_graph.version,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
is_active=True,
)
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
# Execute preset with overriding values
result = await server.agent_server.test_execute_preset(
graph_id=test_graph.id,
graph_version=test_graph.version,
preset_id=created_preset.id,
node_input={"selected_value": "key1"},
user_id=test_user.id,
)
# Verify execution
assert result is not None
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
assert len(executions) == 4
# FindInDictionaryBlock should wait for the input pin to be provided,
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
assert executions[3].status == execution.ExecutionStatus.COMPLETED
assert executions[3].output_data == {"output": ["Hello"]}