mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
addressing feedback
This commit is contained in:
@@ -17,7 +17,7 @@ export default class AutoGPTServerAPI {
|
||||
private supabaseClient = createClient();
|
||||
|
||||
constructor(
|
||||
baseUrl: string = process.env.AGPT_SERVER_URL || "http://localhost:8000/api"
|
||||
baseUrl: string = process.env.NEXT_PUBLIC_AGPT_SERVER_URL || "http://localhost:8000/api"
|
||||
) {
|
||||
this.baseUrl = baseUrl;
|
||||
this.wsUrl = `ws://${new URL(this.baseUrl).host}/ws`;
|
||||
|
||||
@@ -10,6 +10,7 @@ from autogpt_server.util import json
|
||||
|
||||
class ExecutionSchedule(BaseDbModel):
|
||||
graph_id: str
|
||||
user_id: str
|
||||
graph_version: int
|
||||
schedule: str
|
||||
is_enabled: bool
|
||||
@@ -25,6 +26,7 @@ class ExecutionSchedule(BaseDbModel):
|
||||
return ExecutionSchedule(
|
||||
id=schedule.id,
|
||||
graph_id=schedule.agentGraphId,
|
||||
user_id=schedule.userId,
|
||||
graph_version=schedule.agentGraphVersion,
|
||||
schedule=schedule.schedule,
|
||||
is_enabled=schedule.isEnabled,
|
||||
|
||||
@@ -68,10 +68,16 @@ class ExecutionScheduler(AppService):
|
||||
|
||||
@expose
|
||||
def add_execution_schedule(
|
||||
self, graph_id: str, graph_version: int, cron: str, input_data: BlockInput
|
||||
self,
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
cron: str,
|
||||
input_data: BlockInput,
|
||||
user_id: str,
|
||||
) -> str:
|
||||
schedule = model.ExecutionSchedule(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
graph_version=graph_version,
|
||||
schedule=cron,
|
||||
input_data=input_data,
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import wraps
|
||||
from typing import Annotated, Any, Dict
|
||||
|
||||
import uvicorn
|
||||
from autogpt_libs.auth.jwt_utils import parse_jwt_token
|
||||
from autogpt_libs.auth.middleware import auth_middleware
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
@@ -59,6 +62,7 @@ class AgentServer(AppService):
|
||||
manager = ConnectionManager()
|
||||
mutex = KeyedMutex()
|
||||
use_db = False
|
||||
_test_dependency_overrides = {}
|
||||
|
||||
async def event_broadcaster(self):
|
||||
while True:
|
||||
@@ -87,6 +91,9 @@ class AgentServer(AppService):
|
||||
lifespan=self.lifespan,
|
||||
)
|
||||
|
||||
if self._test_dependency_overrides:
|
||||
app.dependency_overrides.update(self._test_dependency_overrides)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins
|
||||
@@ -222,6 +229,35 @@ class AgentServer(AppService):
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
self._test_dependency_overrides = overrides
|
||||
|
||||
def _apply_overrides_to_methods(self):
|
||||
for attr_name in dir(self):
|
||||
attr = getattr(self, attr_name)
|
||||
if callable(attr) and hasattr(attr, "__annotations__"):
|
||||
setattr(self, attr_name, self._override_method(attr))
|
||||
|
||||
# TODO: fix this with some proper refactoring of the server
|
||||
def _override_method(self, method):
|
||||
@wraps(method)
|
||||
async def wrapper(*args, **kwargs):
|
||||
sig = inspect.signature(method)
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
continue
|
||||
if isinstance(param.annotation, Depends) or ( # type: ignore
|
||||
isinstance(param.annotation, type) and issubclass(param.annotation, Depends) # type: ignore
|
||||
):
|
||||
dependency = param.annotation.dependency if isinstance(param.annotation, Depends) else param.annotation # type: ignore
|
||||
if dependency in self._test_dependency_overrides:
|
||||
kwargs[param_name] = self._test_dependency_overrides[
|
||||
dependency
|
||||
]()
|
||||
return await method(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@property
|
||||
def execution_manager_client(self) -> ExecutionManager:
|
||||
return get_service_client(ExecutionManager)
|
||||
@@ -240,7 +276,27 @@ class AgentServer(AppService):
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
async def authenticate_websocket(self, websocket: WebSocket) -> str:
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=4001, reason="Missing authentication token")
|
||||
return ""
|
||||
|
||||
try:
|
||||
payload = parse_jwt_token(token)
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
await websocket.close(code=4002, reason="Invalid token")
|
||||
return ""
|
||||
return user_id
|
||||
except ValueError:
|
||||
await websocket.close(code=4003, reason="Invalid token")
|
||||
return ""
|
||||
|
||||
async def websocket_router(self, websocket: WebSocket):
|
||||
user_id = await self.authenticate_websocket(websocket)
|
||||
if not user_id:
|
||||
return
|
||||
await self.manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
@@ -279,7 +335,7 @@ class AgentServer(AppService):
|
||||
).model_dump_json()
|
||||
)
|
||||
elif message.method == Methods.GET_GRAPHS:
|
||||
data = await self.get_graphs()
|
||||
data = await self.get_graphs(user_id=user_id)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_GRAPHS,
|
||||
@@ -290,7 +346,9 @@ class AgentServer(AppService):
|
||||
print("Get graphs request received")
|
||||
elif message.method == Methods.GET_GRAPH:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.get_graph(message.data["graph_id"])
|
||||
data = await self.get_graph(
|
||||
message.data["graph_id"], user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_GRAPH,
|
||||
@@ -302,7 +360,7 @@ class AgentServer(AppService):
|
||||
elif message.method == Methods.CREATE_GRAPH:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
create_graph = CreateGraph.model_validate(message.data)
|
||||
data = await self.create_new_graph(create_graph)
|
||||
data = await self.create_new_graph(create_graph, user_id=user_id)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.CREATE_GRAPH,
|
||||
@@ -315,7 +373,7 @@ class AgentServer(AppService):
|
||||
elif message.method == Methods.RUN_GRAPH:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.execute_graph(
|
||||
message.data["graph_id"], message.data["data"]
|
||||
message.data["graph_id"], message.data["data"], user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
@@ -328,7 +386,9 @@ class AgentServer(AppService):
|
||||
print("Run graph request received")
|
||||
elif message.method == Methods.GET_GRAPH_RUNS:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = await self.list_graph_runs(message.data["graph_id"])
|
||||
data = await self.list_graph_runs(
|
||||
message.data["graph_id"], user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_GRAPH_RUNS,
|
||||
@@ -344,6 +404,7 @@ class AgentServer(AppService):
|
||||
message.data["graph_id"],
|
||||
message.data["cron"],
|
||||
message.data["data"],
|
||||
user_id=user_id,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
@@ -356,7 +417,9 @@ class AgentServer(AppService):
|
||||
print("Create scheduled run request received")
|
||||
elif message.method == Methods.GET_SCHEDULED_RUNS:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = self.get_execution_schedules(message.data["graph_id"])
|
||||
data = self.get_execution_schedules(
|
||||
message.data["graph_id"], user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.GET_SCHEDULED_RUNS,
|
||||
@@ -368,7 +431,7 @@ class AgentServer(AppService):
|
||||
elif message.method == Methods.UPDATE_SCHEDULED_RUN:
|
||||
assert isinstance(message.data, dict), "Data must be a dictionary"
|
||||
data = self.update_schedule(
|
||||
message.data["schedule_id"], message.data
|
||||
message.data["schedule_id"], message.data, user_id=user_id
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
@@ -431,7 +494,7 @@ class AgentServer(AppService):
|
||||
|
||||
@classmethod
|
||||
async def get_graphs(
|
||||
cls, user_id: str = Depends(get_user_id)
|
||||
cls, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(filter_by="active", user_id=user_id)
|
||||
|
||||
@@ -443,8 +506,8 @@ class AgentServer(AppService):
|
||||
async def get_graph(
|
||||
cls,
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
version: int | None = None,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> graph_db.Graph:
|
||||
graph = await graph_db.get_graph(graph_id, version, user_id=user_id)
|
||||
if not graph:
|
||||
@@ -464,7 +527,7 @@ class AgentServer(AppService):
|
||||
|
||||
@classmethod
|
||||
async def get_graph_all_versions(
|
||||
cls, graph_id: str, user_id: str = Depends(get_user_id)
|
||||
cls, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.Graph]:
|
||||
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||
if not graphs:
|
||||
@@ -473,13 +536,13 @@ class AgentServer(AppService):
|
||||
|
||||
@classmethod
|
||||
async def create_new_graph(
|
||||
cls, create_graph: CreateGraph, user_id: str = Depends(get_user_id)
|
||||
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.Graph:
|
||||
return await cls.create_graph(create_graph, is_template=False, user_id=user_id)
|
||||
|
||||
@classmethod
|
||||
async def create_new_template(
|
||||
cls, create_graph: CreateGraph, user_id: str = Depends(get_user_id)
|
||||
cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.Graph:
|
||||
return await cls.create_graph(create_graph, is_template=True, user_id=user_id)
|
||||
|
||||
@@ -515,7 +578,10 @@ class AgentServer(AppService):
|
||||
|
||||
@classmethod
|
||||
async def update_graph(
|
||||
cls, graph_id: str, graph: graph_db.Graph, user_id: str = Depends(get_user_id)
|
||||
cls,
|
||||
graph_id: str,
|
||||
graph: graph_db.Graph,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> graph_db.Graph:
|
||||
# Sanity check
|
||||
if graph.id and graph.id != graph_id:
|
||||
@@ -555,7 +621,7 @@ class AgentServer(AppService):
|
||||
cls,
|
||||
graph_id: str,
|
||||
request_body: SetGraphActiveVersion,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
):
|
||||
new_active_version = request_body.active_graph_version
|
||||
if not await graph_db.get_graph(graph_id, new_active_version, user_id=user_id):
|
||||
@@ -572,7 +638,7 @@ class AgentServer(AppService):
|
||||
self,
|
||||
graph_id: str,
|
||||
node_input: dict[Any, Any],
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
try:
|
||||
return self.execution_manager_client.add_execution(
|
||||
@@ -586,8 +652,8 @@ class AgentServer(AppService):
|
||||
async def list_graph_runs(
|
||||
cls,
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
graph_version: int | None = None,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> list[str]:
|
||||
graph = await graph_db.get_graph(graph_id, graph_version, user_id=user_id)
|
||||
if not graph:
|
||||
@@ -600,7 +666,7 @@ class AgentServer(AppService):
|
||||
|
||||
@classmethod
|
||||
async def get_run_execution_results(
|
||||
cls, graph_id: str, run_id: str, user_id: str = Depends(get_user_id)
|
||||
cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[ExecutionResult]:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
@@ -613,7 +679,7 @@ class AgentServer(AppService):
|
||||
graph_id: str,
|
||||
cron: str,
|
||||
input_data: dict[Any, Any],
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
@@ -621,7 +687,7 @@ class AgentServer(AppService):
|
||||
execution_scheduler = self.execution_scheduler_client
|
||||
return {
|
||||
"id": execution_scheduler.add_execution_schedule(
|
||||
graph_id, graph.version, cron, input_data
|
||||
graph_id, graph.version, cron, input_data, user_id=user_id
|
||||
)
|
||||
}
|
||||
|
||||
@@ -629,7 +695,7 @@ class AgentServer(AppService):
|
||||
self,
|
||||
schedule_id: str,
|
||||
input_data: dict[Any, Any],
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
execution_scheduler = self.execution_scheduler_client
|
||||
is_enabled = input_data.get("is_enabled", False)
|
||||
@@ -637,7 +703,7 @@ class AgentServer(AppService):
|
||||
return {"id": schedule_id}
|
||||
|
||||
def get_execution_schedules(
|
||||
self, graph_id: str, user_id: str = Depends(get_user_id)
|
||||
self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> dict[str, str]:
|
||||
execution_scheduler = self.execution_scheduler_client
|
||||
return execution_scheduler.get_execution_schedules(graph_id, user_id) # type: ignore
|
||||
|
||||
@@ -253,7 +253,9 @@ async def block_autogen_agent():
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"input": "Write me a block that writes a string into a file."}
|
||||
response = await server.agent_server.execute_graph(test_graph.id, input_data)
|
||||
response = await server.agent_server.execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
exec_manager=test_manager,
|
||||
@@ -261,6 +263,7 @@ async def block_autogen_agent():
|
||||
graph_exec_id=response["id"],
|
||||
num_execs=10,
|
||||
timeout=1200,
|
||||
user_id=test_user.id,
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
@@ -155,9 +155,13 @@ async def reddit_marketing_agent():
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"subreddit": "AutoGPT"}
|
||||
response = await server.agent_server.execute_graph(test_graph.id, input_data)
|
||||
response = await server.agent_server.execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(exec_man, test_graph.id, response["id"], 13, 120)
|
||||
result = await wait_execution(
|
||||
exec_man, test_user.id, test_graph.id, response["id"], 13, 120
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
@@ -79,9 +79,13 @@ async def sample_agent():
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(create_test_graph(), test_user.id)
|
||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||
response = await server.agent_server.execute_graph(test_graph.id, input_data)
|
||||
response = await server.agent_server.execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(exec_man, test_graph.id, response["id"], 4, 10)
|
||||
result = await wait_execution(
|
||||
exec_man, test_user.id, test_graph.id, response["id"], 4, 10
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from autogpt_server.data.block import Block, initialize_blocks
|
||||
from autogpt_server.data.execution import ExecutionStatus
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.server.server import get_user_id
|
||||
from autogpt_server.util.service import PyroNameServer
|
||||
|
||||
log = print
|
||||
@@ -17,14 +18,21 @@ class SpinTestServer:
|
||||
self.agent_server = AgentServer()
|
||||
self.scheduler = ExecutionScheduler()
|
||||
|
||||
@staticmethod
|
||||
def test_get_user_id():
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
async def __aenter__(self):
|
||||
|
||||
self.name_server.__enter__()
|
||||
self.setup_dependency_overrides()
|
||||
self.agent_server.__enter__()
|
||||
self.exec_manager.__enter__()
|
||||
self.scheduler.__enter__()
|
||||
|
||||
await db.connect()
|
||||
await initialize_blocks()
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
@@ -35,16 +43,25 @@ class SpinTestServer:
|
||||
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
|
||||
self.scheduler.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def setup_dependency_overrides(self):
|
||||
# Override get_user_id for testing
|
||||
self.agent_server.set_test_dependency_overrides(
|
||||
{get_user_id: self.test_get_user_id}
|
||||
)
|
||||
|
||||
|
||||
async def wait_execution(
|
||||
exec_manager: ExecutionManager,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
num_execs: int,
|
||||
timeout: int = 20,
|
||||
) -> list:
|
||||
async def is_execution_completed():
|
||||
execs = await AgentServer().get_run_execution_results(graph_id, graph_exec_id)
|
||||
execs = await AgentServer().get_run_execution_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
return (
|
||||
exec_manager.queue.empty()
|
||||
and len(execs) == num_execs
|
||||
@@ -58,7 +75,7 @@ async def wait_execution(
|
||||
for i in range(timeout):
|
||||
if await is_execution_completed():
|
||||
return await AgentServer().get_run_execution_results(
|
||||
graph_id, graph_exec_id
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_id_idx" ON "User"("id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_email_idx" ON "User"("email");
|
||||
@@ -0,0 +1,5 @@
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_id_idx" ON "User"("id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "User_email_idx" ON "User"("email");
|
||||
@@ -22,6 +22,9 @@ model User {
|
||||
AgentGraphs AgentGraph[]
|
||||
AgentGraphExecutions AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
|
||||
|
||||
@@index([id])
|
||||
@@index([email])
|
||||
}
|
||||
|
||||
// This model describes the Agent Graph/Flow (Multi Agent System).
|
||||
|
||||
@@ -21,6 +21,9 @@ model User {
|
||||
AgentGraphs AgentGraph[]
|
||||
AgentGraphExecutions AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
|
||||
|
||||
@@index([id])
|
||||
@@index([email])
|
||||
}
|
||||
|
||||
// This model describes the Agent Graph/Flow (Multi Agent System).
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
from autogpt_server.blocks.basic import ObjectLookupBlock, ValueBlock
|
||||
from autogpt_server.blocks.maths import MathsBlock, Operation
|
||||
@@ -13,24 +14,30 @@ async def execute_graph(
|
||||
agent_server: AgentServer,
|
||||
test_manager: ExecutionManager,
|
||||
test_graph: graph.Graph,
|
||||
test_user: User,
|
||||
input_data: dict,
|
||||
num_execs: int = 4,
|
||||
) -> str:
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.execute_graph(test_graph.id, input_data)
|
||||
response = await agent_server.execute_graph(test_graph.id, input_data, test_user.id)
|
||||
graph_exec_id = response["id"]
|
||||
|
||||
# Execution queue should be empty
|
||||
assert await wait_execution(test_manager, test_graph.id, graph_exec_id, num_execs)
|
||||
assert await wait_execution(
|
||||
test_manager, test_user.id, test_graph.id, graph_exec_id, num_execs
|
||||
)
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
async def assert_sample_graph_executions(
|
||||
agent_server: AgentServer, test_graph: graph.Graph, graph_exec_id: str
|
||||
agent_server: AgentServer,
|
||||
test_graph: graph.Graph,
|
||||
test_user: User,
|
||||
graph_exec_id: str,
|
||||
):
|
||||
input = {"input_1": "Hello", "input_2": "World"}
|
||||
executions = await agent_server.get_run_execution_results(
|
||||
test_graph.id, graph_exec_id
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
|
||||
# Executing ValueBlock
|
||||
@@ -79,9 +86,16 @@ async def test_agent_execution(server):
|
||||
await graph.create_graph(test_graph, user_id=test_user.id)
|
||||
data = {"input_1": "Hello", "input_2": "World"}
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, data, 4
|
||||
server.agent_server,
|
||||
server.exec_manager,
|
||||
test_graph,
|
||||
test_user,
|
||||
data,
|
||||
4,
|
||||
)
|
||||
await assert_sample_graph_executions(
|
||||
server.agent_server, test_graph, test_user, graph_exec_id
|
||||
)
|
||||
await assert_sample_graph_executions(server.agent_server, test_graph, graph_exec_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
@@ -134,11 +148,11 @@ async def test_input_pin_always_waited(server):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, {}, 3
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
|
||||
)
|
||||
|
||||
executions = await server.agent_server.get_run_execution_results(
|
||||
test_graph.id, graph_exec_id
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 3
|
||||
# ObjectLookupBlock should wait for the input pin to be provided,
|
||||
@@ -215,10 +229,10 @@ async def test_static_input_link_on_graph(server):
|
||||
test_user = await create_test_user()
|
||||
test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, {}, 8
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
|
||||
)
|
||||
executions = await server.agent_server.get_run_execution_results(
|
||||
test_graph.id, graph_exec_id
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 8
|
||||
# The last 3 executions will be a+b=4+5=9
|
||||
|
||||
@@ -14,18 +14,19 @@ async def test_agent_schedule(server):
|
||||
|
||||
scheduler = get_service_client(ExecutionScheduler)
|
||||
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 0
|
||||
|
||||
schedule_id = scheduler.add_execution_schedule(
|
||||
graph_id=test_graph.id,
|
||||
user_id=test_user.id,
|
||||
graph_version=1,
|
||||
cron="0 0 * * *",
|
||||
input_data={"input": "data"},
|
||||
)
|
||||
assert schedule_id
|
||||
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
|
||||
assert len(schedules) == 1
|
||||
assert schedules[schedule_id] == "0 0 * * *"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user