smol refactor for consistency & readability

This commit is contained in:
Reinier van der Leer
2024-08-26 16:42:14 +02:00
parent b6c4fc4742
commit 2973567010
3 changed files with 21 additions and 27 deletions

View File

@@ -26,11 +26,6 @@ from autogpt_server.data import execution as execution_db
from autogpt_server.data import graph as graph_db
from autogpt_server.data import user as user_db
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
from autogpt_server.data.execution import (
ExecutionResult,
get_execution_results,
list_executions,
)
from autogpt_server.data.user import DEFAULT_USER_ID, get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server.conn_manager import ConnectionManager
@@ -59,7 +54,7 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
class AgentServer(AppService):
event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
event_queue: asyncio.Queue[execution_db.ExecutionResult] = asyncio.Queue()
manager = ConnectionManager()
mutex = KeyedMutex()
use_db = False
@@ -67,7 +62,7 @@ class AgentServer(AppService):
async def event_broadcaster(self):
while True:
event: ExecutionResult = await self.event_queue.get()
event: execution_db.ExecutionResult = await self.event_queue.get()
await self.manager.send_execution_result(event)
@asynccontextmanager
@@ -194,13 +189,13 @@ class AgentServer(AppService):
methods=["GET"],
)
router.add_api_route(
path="/graphs/{graph_id}/executions/{run_id}",
endpoint=self.get_run_execution_results,
path="/graphs/{graph_id}/executions/{graph_exec_id}",
endpoint=self.get_graph_run_node_execution_results,
methods=["GET"],
)
router.add_api_route(
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
endpoint=self.stop_graph_execution,
endpoint=self.stop_graph_run,
methods=["POST"],
)
router.add_api_route(
@@ -524,11 +519,10 @@ class AgentServer(AppService):
msg = e.__str__().encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
async def stop_graph_execution(
async def stop_graph_run(
self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> execution_db.AgentGraphExecution:
graph_exec = await execution_db.get_graph_execution(graph_exec_id, user_id)
if not graph_exec:
) -> list[execution_db.ExecutionResult]:
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
raise HTTPException(
404, detail=f"Agent execution #{graph_exec_id} not found"
)
@@ -536,9 +530,7 @@ class AgentServer(AppService):
self.execution_manager_client.cancel_execution(graph_exec_id)
# Retrieve & return canceled graph execution in its final state
graph_exec = await execution_db.get_graph_execution(graph_exec_id, user_id)
assert graph_exec
return graph_exec
return await execution_db.get_execution_results(graph_exec_id)
@classmethod
async def list_graph_runs(
@@ -554,17 +546,20 @@ class AgentServer(AppService):
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)
return await list_executions(graph_id, graph_version)
return await execution_db.list_executions(graph_id, graph_version)
@classmethod
async def get_run_execution_results(
cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[ExecutionResult]:
async def get_graph_run_node_execution_results(
cls,
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> list[execution_db.ExecutionResult]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return await get_execution_results(run_id)
return await execution_db.get_execution_results(graph_exec_id)
async def create_schedule(
self,
@@ -602,7 +597,7 @@ class AgentServer(AppService):
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
execution_result = ExecutionResult(**execution_result_dict)
execution_result = execution_db.ExecutionResult(**execution_result_dict)
self.run_and_wait(self.event_queue.put(execution_result))
@expose

View File

@@ -23,7 +23,6 @@ class SpinTestServer:
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
async def __aenter__(self):
self.name_server.__enter__()
self.setup_dependency_overrides()
self.agent_server.__enter__()
@@ -59,7 +58,7 @@ async def wait_execution(
timeout: int = 20,
) -> list:
async def is_execution_completed():
execs = await AgentServer().get_run_execution_results(
execs = await AgentServer().get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
return (
@@ -74,7 +73,7 @@ async def wait_execution(
# Wait for the executions to complete
for i in range(timeout):
if await is_execution_completed():
return await AgentServer().get_run_execution_results(
return await AgentServer().get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
time.sleep(1)

View File

@@ -36,7 +36,7 @@ async def assert_sample_graph_executions(
graph_exec_id: str,
):
input = {"input_1": "Hello", "input_2": "World"}
executions = await agent_server.get_run_execution_results(
executions = await agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)