mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 07:38:04 -05:00
smol refactor for consistency & readability
This commit is contained in:
@@ -26,11 +26,6 @@ from autogpt_server.data import execution as execution_db
|
||||
from autogpt_server.data import graph as graph_db
|
||||
from autogpt_server.data import user as user_db
|
||||
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
|
||||
from autogpt_server.data.execution import (
|
||||
ExecutionResult,
|
||||
get_execution_results,
|
||||
list_executions,
|
||||
)
|
||||
from autogpt_server.data.user import DEFAULT_USER_ID, get_or_create_user
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server.conn_manager import ConnectionManager
|
||||
@@ -59,7 +54,7 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
|
||||
|
||||
|
||||
class AgentServer(AppService):
|
||||
event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
|
||||
event_queue: asyncio.Queue[execution_db.ExecutionResult] = asyncio.Queue()
|
||||
manager = ConnectionManager()
|
||||
mutex = KeyedMutex()
|
||||
use_db = False
|
||||
@@ -67,7 +62,7 @@ class AgentServer(AppService):
|
||||
|
||||
async def event_broadcaster(self):
|
||||
while True:
|
||||
event: ExecutionResult = await self.event_queue.get()
|
||||
event: execution_db.ExecutionResult = await self.event_queue.get()
|
||||
await self.manager.send_execution_result(event)
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -194,13 +189,13 @@ class AgentServer(AppService):
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions/{run_id}",
|
||||
endpoint=self.get_run_execution_results,
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
endpoint=self.get_graph_run_node_execution_results,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
|
||||
endpoint=self.stop_graph_execution,
|
||||
endpoint=self.stop_graph_run,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
@@ -524,11 +519,10 @@ class AgentServer(AppService):
|
||||
msg = e.__str__().encode().decode("unicode_escape")
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
async def stop_graph_execution(
|
||||
async def stop_graph_run(
|
||||
self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> execution_db.AgentGraphExecution:
|
||||
graph_exec = await execution_db.get_graph_execution(graph_exec_id, user_id)
|
||||
if not graph_exec:
|
||||
) -> list[execution_db.ExecutionResult]:
|
||||
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
|
||||
raise HTTPException(
|
||||
404, detail=f"Agent execution #{graph_exec_id} not found"
|
||||
)
|
||||
@@ -536,9 +530,7 @@ class AgentServer(AppService):
|
||||
self.execution_manager_client.cancel_execution(graph_exec_id)
|
||||
|
||||
# Retrieve & return canceled graph execution in its final state
|
||||
graph_exec = await execution_db.get_graph_execution(graph_exec_id, user_id)
|
||||
assert graph_exec
|
||||
return graph_exec
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
@classmethod
|
||||
async def list_graph_runs(
|
||||
@@ -554,17 +546,20 @@ class AgentServer(AppService):
|
||||
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
|
||||
)
|
||||
|
||||
return await list_executions(graph_id, graph_version)
|
||||
return await execution_db.list_executions(graph_id, graph_version)
|
||||
|
||||
@classmethod
|
||||
async def get_run_execution_results(
|
||||
cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[ExecutionResult]:
|
||||
async def get_graph_run_node_execution_results(
|
||||
cls,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[execution_db.ExecutionResult]:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
return await get_execution_results(run_id)
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
async def create_schedule(
|
||||
self,
|
||||
@@ -602,7 +597,7 @@ class AgentServer(AppService):
|
||||
|
||||
@expose
|
||||
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
|
||||
execution_result = ExecutionResult(**execution_result_dict)
|
||||
execution_result = execution_db.ExecutionResult(**execution_result_dict)
|
||||
self.run_and_wait(self.event_queue.put(execution_result))
|
||||
|
||||
@expose
|
||||
|
||||
@@ -23,7 +23,6 @@ class SpinTestServer:
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
async def __aenter__(self):
|
||||
|
||||
self.name_server.__enter__()
|
||||
self.setup_dependency_overrides()
|
||||
self.agent_server.__enter__()
|
||||
@@ -59,7 +58,7 @@ async def wait_execution(
|
||||
timeout: int = 20,
|
||||
) -> list:
|
||||
async def is_execution_completed():
|
||||
execs = await AgentServer().get_run_execution_results(
|
||||
execs = await AgentServer().get_graph_run_node_execution_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
return (
|
||||
@@ -74,7 +73,7 @@ async def wait_execution(
|
||||
# Wait for the executions to complete
|
||||
for i in range(timeout):
|
||||
if await is_execution_completed():
|
||||
return await AgentServer().get_run_execution_results(
|
||||
return await AgentServer().get_graph_run_node_execution_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
@@ -36,7 +36,7 @@ async def assert_sample_graph_executions(
|
||||
graph_exec_id: str,
|
||||
):
|
||||
input = {"input_1": "Hello", "input_2": "World"}
|
||||
executions = await agent_server.get_run_execution_results(
|
||||
executions = await agent_server.get_graph_run_node_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user