fixing tests

This commit is contained in:
Aarushi
2024-08-05 11:35:34 +01:00
parent 5c9c65806b
commit aa5c044b40
6 changed files with 63 additions and 11 deletions

View File

@@ -74,5 +74,5 @@ async def add_schedule(schedule: ExecutionSchedule) -> ExecutionSchedule:
async def update_schedule(schedule_id: str, is_enabled: bool, user_id: str):
await AgentGraphExecutionSchedule.prisma().update(
where={{"id": schedule_id}, {"userId": user_id}}, data={"isEnabled": is_enabled}
where={"id": schedule_id}, data={"isEnabled": is_enabled}
)

View File

@@ -416,8 +416,10 @@ class ExecutionManager(AppService):
return get_agent_server_client()
@expose
def add_execution(self, graph_id: str, data: BlockInput) -> dict[Any, Any]:
graph: Graph | None = self.run_and_wait(get_graph(graph_id))
def add_execution(
self, graph_id: str, data: BlockInput, user_id: str
) -> dict[Any, Any]:
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
graph.validate_graph()

View File

@@ -569,7 +569,7 @@ class AgentServer(AppService):
user_id: str = Depends(get_user_id),
):
new_active_version = request_body.active_graph_version
if not await graph_db.get_graph(graph_id, new_active_version):
if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id):
raise HTTPException(
404, f"Graph #{graph_id} v{new_active_version} not found"
)
@@ -580,19 +580,27 @@ class AgentServer(AppService):
)
async def execute_graph(
self, graph_id: str, node_input: dict[Any, Any]
self,
graph_id: str,
node_input: dict[Any, Any],
user_id: str = Depends(get_user_id),
) -> dict[Any, Any]:
try:
return self.execution_manager_client.add_execution(graph_id, node_input)
return self.execution_manager_client.add_execution(
graph_id, node_input, user_id=user_id
)
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@classmethod
async def list_graph_runs(
cls, graph_id: str, graph_version: int | None = None
cls,
graph_id: str,
graph_version: int | None = None,
user_id: str = Depends(get_user_id),
) -> list[str]:
graph = await graph_db.get_graph(graph_id, graph_version)
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
if not graph:
rev = "" if graph_version is None else f" v{graph_version}"
raise HTTPException(

View File

@@ -1,11 +1,14 @@
from pathlib import Path
from prisma.models import User
from autogpt_server.blocks.basic import ValueBlock
from autogpt_server.blocks.block import BlockInstallationBlock
from autogpt_server.blocks.http import HttpRequestBlock
from autogpt_server.blocks.llm import TextLlmCallBlock
from autogpt_server.blocks.text import TextFormatterBlock, TextParserBlock
from autogpt_server.data.graph import Graph, Link, Node, create_graph
from autogpt_server.data.user import get_or_create_user
from autogpt_server.util.test import SpinTestServer, wait_execution
sample_block_modules = {
@@ -23,6 +26,16 @@ for module, description in sample_block_modules.items():
sample_block_codes[module] = f"[Example: {description}]\n{code}"
async def create_test_user() -> User:
test_user_data = {
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
"email": "testuser@example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
return user
def create_test_graph() -> Graph:
"""
ValueBlock (input)
@@ -237,7 +250,8 @@ Here are a couple of sample of the Block class implementation:
async def block_autogen_agent():
async with SpinTestServer() as server:
test_manager = server.exec_manager
test_graph = await create_graph(create_test_graph())
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
print(response)

View File

@@ -1,7 +1,10 @@
from prisma.models import User
from autogpt_server.blocks.llm import ObjectLlmCallBlock
from autogpt_server.blocks.reddit import RedditGetPostsBlock, RedditPostCommentBlock
from autogpt_server.blocks.text import TextFormatterBlock, TextMatcherBlock
from autogpt_server.data.graph import Graph, Link, Node, create_graph
from autogpt_server.data.user import get_or_create_user
from autogpt_server.util.test import SpinTestServer, wait_execution
@@ -136,10 +139,21 @@ Make sure to only comment on a relevant post.
return test_graph
async def create_test_user() -> User:
test_user_data = {
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
"email": "testuser@example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
return user
async def reddit_marketing_agent():
async with SpinTestServer() as server:
exec_man = server.exec_manager
test_graph = await create_graph(create_test_graph())
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
print(response)

View File

@@ -1,10 +1,23 @@
from prisma.models import User
from autogpt_server.blocks.basic import InputBlock, PrintingBlock
from autogpt_server.blocks.text import TextFormatterBlock
from autogpt_server.data import graph
from autogpt_server.data.graph import create_graph
from autogpt_server.data.user import get_or_create_user
from autogpt_server.util.test import SpinTestServer, wait_execution
async def create_test_user() -> User:
test_user_data = {
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
"email": "testuser@example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
return user
def create_test_graph() -> graph.Graph:
"""
ValueBlock
@@ -63,7 +76,8 @@ def create_test_graph() -> graph.Graph:
async def sample_agent():
async with SpinTestServer() as server:
exec_man = server.exec_manager
test_graph = await create_graph(create_test_graph())
test_user = await create_test_user()
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
print(response)