From aa5c044b407fdefff8404b08c88c3f4f6bc9723f Mon Sep 17 00:00:00 2001 From: Aarushi Date: Mon, 5 Aug 2024 11:35:34 +0100 Subject: [PATCH] fixing tests --- .../autogpt_server/data/schedule.py | 2 +- .../autogpt_server/executor/manager.py | 6 ++++-- .../autogpt_server/server/server.py | 18 +++++++++++++----- .../autogpt_server/usecases/block_autogen.py | 16 +++++++++++++++- .../usecases/reddit_marketing.py | 16 +++++++++++++++- .../autogpt_server/usecases/sample.py | 16 +++++++++++++++- 6 files changed, 63 insertions(+), 11 deletions(-) diff --git a/rnd/autogpt_server/autogpt_server/data/schedule.py b/rnd/autogpt_server/autogpt_server/data/schedule.py index e53ce37200..c32fa00fc8 100644 --- a/rnd/autogpt_server/autogpt_server/data/schedule.py +++ b/rnd/autogpt_server/autogpt_server/data/schedule.py @@ -74,5 +74,5 @@ async def add_schedule(schedule: ExecutionSchedule) -> ExecutionSchedule: async def update_schedule(schedule_id: str, is_enabled: bool, user_id: str): await AgentGraphExecutionSchedule.prisma().update( - where={{"id": schedule_id}, {"userId": user_id}}, data={"isEnabled": is_enabled} + where={"id": schedule_id}, data={"isEnabled": is_enabled} ) diff --git a/rnd/autogpt_server/autogpt_server/executor/manager.py b/rnd/autogpt_server/autogpt_server/executor/manager.py index 951de00764..0f6650125b 100644 --- a/rnd/autogpt_server/autogpt_server/executor/manager.py +++ b/rnd/autogpt_server/autogpt_server/executor/manager.py @@ -416,8 +416,10 @@ class ExecutionManager(AppService): return get_agent_server_client() @expose - def add_execution(self, graph_id: str, data: BlockInput) -> dict[Any, Any]: - graph: Graph | None = self.run_and_wait(get_graph(graph_id)) + def add_execution( + self, graph_id: str, data: BlockInput, user_id: str + ) -> dict[Any, Any]: + graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id)) if not graph: raise Exception(f"Graph #{graph_id} not found.") graph.validate_graph() diff --git a/rnd/autogpt_server/autogpt_server/server/server.py b/rnd/autogpt_server/autogpt_server/server/server.py index 5a57b7c7e7..c94351ec20 100644 --- a/rnd/autogpt_server/autogpt_server/server/server.py +++ b/rnd/autogpt_server/autogpt_server/server/server.py @@ -569,7 +569,7 @@ class AgentServer(AppService): user_id: str = Depends(get_user_id), ): new_active_version = request_body.active_graph_version - if not await graph_db.get_graph(graph_id, new_active_version): + if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id): raise HTTPException( 404, f"Graph #{graph_id} v{new_active_version} not found" ) @@ -580,19 +580,27 @@ class AgentServer(AppService): ) async def execute_graph( - self, graph_id: str, node_input: dict[Any, Any] + self, + graph_id: str, + node_input: dict[Any, Any], + user_id: str = Depends(get_user_id), ) -> dict[Any, Any]: try: - return self.execution_manager_client.add_execution(graph_id, node_input) + return self.execution_manager_client.add_execution( + graph_id, node_input, user_id=user_id + ) except Exception as e: msg = e.__str__().encode().decode("unicode_escape") raise HTTPException(status_code=400, detail=msg) @classmethod async def list_graph_runs( - cls, graph_id: str, graph_version: int | None = None + cls, + graph_id: str, + graph_version: int | None = None, + user_id: str = Depends(get_user_id), ) -> list[str]: - graph = await graph_db.get_graph(graph_id, graph_version) + graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id) if not graph: rev = "" if graph_version is None else f" v{graph_version}" raise HTTPException( diff --git a/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py b/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py index 1ded46b33f..15e05bc0b1 100644 --- a/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py +++ b/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py @@ -1,11 +1,14 @@ from pathlib import Path +from prisma.models import User + from autogpt_server.blocks.basic import ValueBlock from autogpt_server.blocks.block import BlockInstallationBlock from autogpt_server.blocks.http import HttpRequestBlock from autogpt_server.blocks.llm import TextLlmCallBlock from autogpt_server.blocks.text import TextFormatterBlock, TextParserBlock from autogpt_server.data.graph import Graph, Link, Node, create_graph +from autogpt_server.data.user import get_or_create_user from autogpt_server.util.test import SpinTestServer, wait_execution sample_block_modules = { @@ -23,6 +26,16 @@ for module, description in sample_block_modules.items(): sample_block_codes[module] = f"[Example: {description}]\n{code}" +async def create_test_user() -> User: + test_user_data = { + "sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1", + "email": "testuser@example.com", + "name": "Test User", + } + user = await get_or_create_user(test_user_data) + return user + + def create_test_graph() -> Graph: """ ValueBlock (input) @@ -237,7 +250,8 @@ Here are a couple of sample of the Block class implementation: async def block_autogen_agent(): async with SpinTestServer() as server: test_manager = server.exec_manager - test_graph = await create_graph(create_test_graph()) + test_user = await create_test_user() + test_graph = await create_graph(create_test_graph(), user_id=test_user.id) input_data = {"input": "Write me a block that writes a string into a file."} response = await server.agent_server.execute_graph(test_graph.id, input_data) print(response) diff --git a/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py b/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py index 2505acd9eb..81b82d9deb 100644 --- a/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py +++ b/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py @@ -1,7 +1,10 @@ +from prisma.models import User + from autogpt_server.blocks.llm import ObjectLlmCallBlock from autogpt_server.blocks.reddit import RedditGetPostsBlock, RedditPostCommentBlock from autogpt_server.blocks.text import TextFormatterBlock, TextMatcherBlock from autogpt_server.data.graph import Graph, Link, Node, create_graph +from autogpt_server.data.user import get_or_create_user from autogpt_server.util.test import SpinTestServer, wait_execution @@ -136,10 +139,21 @@ Make sure to only comment on a relevant post. return test_graph +async def create_test_user() -> User: + test_user_data = { + "sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1", + "email": "testuser@example.com", + "name": "Test User", + } + user = await get_or_create_user(test_user_data) + return user + + async def reddit_marketing_agent(): async with SpinTestServer() as server: exec_man = server.exec_manager - test_graph = await create_graph(create_test_graph()) + test_user = await create_test_user() + test_graph = await create_graph(create_test_graph(), user_id=test_user.id) input_data = {"subreddit": "AutoGPT"} response = await server.agent_server.execute_graph(test_graph.id, input_data) print(response) diff --git a/rnd/autogpt_server/autogpt_server/usecases/sample.py b/rnd/autogpt_server/autogpt_server/usecases/sample.py index 9e7cb89bac..b2edee0bcb 100644 --- a/rnd/autogpt_server/autogpt_server/usecases/sample.py +++ b/rnd/autogpt_server/autogpt_server/usecases/sample.py @@ -1,10 +1,23 @@ +from prisma.models import User + from autogpt_server.blocks.basic import InputBlock, PrintingBlock from autogpt_server.blocks.text import TextFormatterBlock from autogpt_server.data import graph from autogpt_server.data.graph import create_graph +from autogpt_server.data.user import get_or_create_user from autogpt_server.util.test import SpinTestServer, wait_execution +async def create_test_user() -> User: + test_user_data = { + "sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1", + "email": "testuser@example.com", + "name": "Test User", + } + user = await get_or_create_user(test_user_data) + return user + + def create_test_graph() -> graph.Graph: """ ValueBlock @@ -63,7 +76,8 @@ def create_test_graph() -> graph.Graph: async def sample_agent(): async with SpinTestServer() as server: exec_man = server.exec_manager - test_graph = await create_graph(create_test_graph()) + test_user = await create_test_user() + test_graph = await create_graph(create_test_graph(), test_user.id) input_data = {"input_1": "Hello", "input_2": "World"} response = await server.agent_server.execute_graph(test_graph.id, input_data) print(response)