mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 23:28:07 -05:00
feat(rnd): Add IPC support on autogpt_server (#7212)
### Background
This PR adds support on IPC on autogpt_server.
To make this happen, there are a couple of refactoring efforts being made (will be described in the `Changes` section).
Currently, there are three independent processes:
```
AgentServer ----> ExecutionManager
|
--> ExecutionScheduler
```
### Changes 🏗️
* Added Pyro5 for IPC support.
* Introduced `AppService`: a class to construct an independent process that can expose a method to other running processes (this is analogous to a microservice).
* Introduced `AppProcess`: used by `AppService` a class for creating a child process that can be executed in the background.
* Adapting existing codebase to user `AppService`.
This commit is contained in:
@@ -1,27 +1,34 @@
|
||||
from multiprocessing import freeze_support
|
||||
from multiprocessing.spawn import freeze_support as freeze_support_spawn
|
||||
|
||||
from autogpt_server.data.execution import ExecutionQueue
|
||||
from autogpt_server.executor import start_executor_manager
|
||||
from autogpt_server.server import start_server
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.util.process import AppProcess
|
||||
from autogpt_server.util.service import PyroNameServer
|
||||
|
||||
|
||||
def background_process() -> None:
|
||||
def run_processes(processes: list[AppProcess], **kwargs):
|
||||
"""
|
||||
Used by the cli to run the server and executor in the background.
|
||||
This function runs the server and starts the executor in the background.
|
||||
Execute all processes in the app. The last process is run in the foreground.
|
||||
"""
|
||||
# These directives are required to make multiprocessing work with cx_Freeze
|
||||
# and are both required and safe across platforms (Windows, macOS, Linux)
|
||||
# They must be placed at the beginning of the executions before any other
|
||||
# multiprocessing code is run
|
||||
freeze_support()
|
||||
freeze_support_spawn()
|
||||
# Start the application
|
||||
queue = ExecutionQueue()
|
||||
start_executor_manager(5, queue)
|
||||
start_server(queue)
|
||||
try:
|
||||
for process in processes[:-1]:
|
||||
process.start(background=True, **kwargs)
|
||||
processes[-1].start(background=False, **kwargs)
|
||||
except Exception as e:
|
||||
for process in processes:
|
||||
process.stop()
|
||||
raise e
|
||||
|
||||
|
||||
def main(**kwargs):
|
||||
run_processes(
|
||||
[
|
||||
PyroNameServer(),
|
||||
ExecutionScheduler(),
|
||||
ExecutionManager(pool_size=5),
|
||||
AgentServer(),
|
||||
],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
background_process()
|
||||
main()
|
||||
|
||||
@@ -2,25 +2,53 @@
|
||||
The command line interface for the agent server
|
||||
"""
|
||||
|
||||
from multiprocessing import freeze_support
|
||||
from multiprocessing.spawn import freeze_support as freeze_support_spawn
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import click
|
||||
import psutil
|
||||
|
||||
from autogpt_server import app
|
||||
from autogpt_server.util.process import AppProcess
|
||||
|
||||
|
||||
def get_pid_path() -> pathlib.Path:
|
||||
home_dir = pathlib.Path.home()
|
||||
new_dir = home_dir / ".config" / "agpt"
|
||||
file_path = new_dir / "running.tmp"
|
||||
return file_path
|
||||
|
||||
|
||||
def get_pid() -> int | None:
|
||||
file_path = get_pid_path()
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
os.makedirs(file_path.parent, exist_ok=True)
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
pid = file.read()
|
||||
try:
|
||||
return int(pid)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def write_pid(pid: int):
|
||||
file_path = get_pid_path()
|
||||
os.makedirs(file_path.parent, exist_ok=True)
|
||||
with open(file_path, "w", encoding="utf-8") as file:
|
||||
file.write(str(pid))
|
||||
|
||||
|
||||
class MainApp(AppProcess):
|
||||
def run(self):
|
||||
app.main(silent=True)
|
||||
|
||||
|
||||
@click.group()
|
||||
def main():
|
||||
"""AutoGPT Server CLI Tool"""
|
||||
|
||||
|
||||
@main.command()
|
||||
def background() -> None:
|
||||
"""
|
||||
Command to run the server in the background. Used by the run command
|
||||
"""
|
||||
from autogpt_server.app import background_process
|
||||
|
||||
background_process()
|
||||
pass
|
||||
|
||||
|
||||
@main.command()
|
||||
@@ -28,37 +56,22 @@ def start():
|
||||
"""
|
||||
Starts the server in the background and saves the PID
|
||||
"""
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import psutil
|
||||
|
||||
# Define the path for the new directory and file
|
||||
home_dir = pathlib.Path.home()
|
||||
new_dir = home_dir / ".config" / "agpt"
|
||||
file_path = new_dir / "running.tmp"
|
||||
pid = get_pid()
|
||||
if pid and psutil.pid_exists(pid):
|
||||
print("Server is already running")
|
||||
exit(1)
|
||||
elif pid:
|
||||
print("PID does not exist deleting file")
|
||||
os.remove(get_pid_path())
|
||||
|
||||
# Create the directory if it does not exist
|
||||
os.makedirs(new_dir, exist_ok=True)
|
||||
if file_path.exists():
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
pid = int(file.read())
|
||||
if psutil.pid_exists(pid):
|
||||
print("Server is already running")
|
||||
exit(1)
|
||||
else:
|
||||
print("PID does not exist deleting file")
|
||||
os.remove(file_path)
|
||||
print("Starting server")
|
||||
pid = MainApp().start(background=True, silent=True)
|
||||
print(f"Server running in process: {pid}")
|
||||
|
||||
sp = subprocess.Popen(
|
||||
["poetry", "run", "python", "autogpt_server/cli.py", "background"],
|
||||
stdout=subprocess.DEVNULL, # Redirect standard output to devnull
|
||||
stderr=subprocess.DEVNULL, # Redirect standard error to devnull
|
||||
)
|
||||
print(f"Server running in process: {sp.pid}")
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as file:
|
||||
file.write(str(sp.pid))
|
||||
write_pid(pid)
|
||||
print("done")
|
||||
os._exit(status=0)
|
||||
|
||||
|
||||
@main.command()
|
||||
@@ -66,22 +79,17 @@ def stop():
|
||||
"""
|
||||
Stops the server
|
||||
"""
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
|
||||
home_dir = pathlib.Path.home()
|
||||
new_dir = home_dir / ".config" / "agpt"
|
||||
file_path = new_dir / "running.tmp"
|
||||
if not file_path.exists():
|
||||
pid = get_pid()
|
||||
if not pid:
|
||||
print("Server is not running")
|
||||
return
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
pid = file.read()
|
||||
os.remove(file_path)
|
||||
os.remove(get_pid_path())
|
||||
process = psutil.Process(int(pid))
|
||||
for child in process.children(recursive=True):
|
||||
child.terminate()
|
||||
process.terminate()
|
||||
|
||||
subprocess.Popen(["kill", pid])
|
||||
print("Server Stopped")
|
||||
|
||||
|
||||
@@ -90,6 +98,7 @@ def test():
|
||||
"""
|
||||
Group for test commands
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@test.command()
|
||||
@@ -100,17 +109,5 @@ def event():
|
||||
print("Event sent")
|
||||
|
||||
|
||||
main.add_command(test)
|
||||
|
||||
|
||||
def start_cli() -> None:
|
||||
"""
|
||||
Entry point into the cli
|
||||
"""
|
||||
freeze_support()
|
||||
freeze_support_spawn()
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_cli()
|
||||
main()
|
||||
|
||||
@@ -2,9 +2,10 @@ import json
|
||||
import jsonschema
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from prisma.models import AgentBlock
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, ClassVar
|
||||
|
||||
BlockData = dict[str, Any]
|
||||
|
||||
@@ -49,7 +50,7 @@ class BlockSchema(BaseModel):
|
||||
self,
|
||||
properties: dict[str, str | dict],
|
||||
required: list[str] | None = None,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
):
|
||||
schema = {
|
||||
"type": "object",
|
||||
@@ -125,7 +126,7 @@ class Block(ABC, BaseModel):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
def run(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
"""
|
||||
Run the block with the given input data.
|
||||
Args:
|
||||
@@ -140,13 +141,21 @@ class Block(ABC, BaseModel):
|
||||
def name(cls):
|
||||
return cls.__name__
|
||||
|
||||
async def execute(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"inputSchema": self.input_schema.jsonschema,
|
||||
"outputSchema": self.output_schema.jsonschema,
|
||||
}
|
||||
|
||||
def execute(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
f"Unable to execute block with invalid input data: {error}"
|
||||
)
|
||||
|
||||
output_name, output_data = await self.run(input_data)
|
||||
output_name, output_data = self.run(input_data)
|
||||
|
||||
if error := self.output_schema.validate_field(output_name, output_data):
|
||||
raise ValueError(
|
||||
@@ -161,29 +170,37 @@ class Block(ABC, BaseModel):
|
||||
|
||||
class ParrotBlock(Block):
|
||||
id: ClassVar[str] = "1ff065e9-88e8-4358-9d82-8dc91f622ba9" # type: ignore
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema({ # type: ignore
|
||||
"input": "string",
|
||||
})
|
||||
output_schema: ClassVar[BlockSchema] = BlockSchema({ # type: ignore
|
||||
"output": "string",
|
||||
})
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"input": "string",
|
||||
}
|
||||
)
|
||||
output_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"output": "string",
|
||||
}
|
||||
)
|
||||
|
||||
async def run(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
def run(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
return "output", input_data["input"]
|
||||
|
||||
|
||||
class TextCombinerBlock(Block):
|
||||
id: ClassVar[str] = "db7d8f02-2f44-4c55-ab7a-eae0941f0c30" # type: ignore
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema({ # type: ignore
|
||||
"text1": "string",
|
||||
"text2": "string",
|
||||
"format": "string",
|
||||
})
|
||||
output_schema: ClassVar[BlockSchema] = BlockSchema({ # type: ignore
|
||||
"combined_text": "string",
|
||||
})
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"text1": "string",
|
||||
"text2": "string",
|
||||
"format": "string",
|
||||
}
|
||||
)
|
||||
output_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"combined_text": "string",
|
||||
}
|
||||
)
|
||||
|
||||
async def run(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
def run(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
return "combined_text", input_data["format"].format(
|
||||
text1=input_data["text1"],
|
||||
text2=input_data["text2"],
|
||||
@@ -192,15 +209,18 @@ class TextCombinerBlock(Block):
|
||||
|
||||
class PrintingBlock(Block):
|
||||
id: ClassVar[str] = "f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c" # type: ignore
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema({ # type: ignore
|
||||
"text": "string",
|
||||
})
|
||||
output_schema: ClassVar[BlockSchema] = BlockSchema({ # type: ignore
|
||||
"status": "string",
|
||||
})
|
||||
input_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"text": "string",
|
||||
}
|
||||
)
|
||||
output_schema: ClassVar[BlockSchema] = BlockSchema( # type: ignore
|
||||
{
|
||||
"status": "string",
|
||||
}
|
||||
)
|
||||
|
||||
async def run(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
print(input_data["text"])
|
||||
def run(self, input_data: BlockData) -> tuple[str, Any]:
|
||||
return "status", "printed"
|
||||
|
||||
|
||||
@@ -215,10 +235,7 @@ async def initialize_blocks() -> None:
|
||||
AVAILABLE_BLOCKS = {block.id: block() for block in Block.__subclasses__()}
|
||||
|
||||
for block in AVAILABLE_BLOCKS.values():
|
||||
existing_block = await AgentBlock.prisma().find_unique(
|
||||
where={"id": block.id}
|
||||
)
|
||||
if existing_block:
|
||||
if await AgentBlock.prisma().find_unique(where={"id": block.id}):
|
||||
continue
|
||||
|
||||
await AgentBlock.prisma().create(
|
||||
@@ -231,7 +248,13 @@ async def initialize_blocks() -> None:
|
||||
)
|
||||
|
||||
|
||||
async def get_block(block_id: str) -> Block:
|
||||
async def get_blocks() -> list[Block]:
|
||||
if not AVAILABLE_BLOCKS:
|
||||
await initialize_blocks()
|
||||
return AVAILABLE_BLOCKS[block_id]
|
||||
return list(AVAILABLE_BLOCKS.values())
|
||||
|
||||
|
||||
async def get_block(block_id: str) -> Block | None:
|
||||
if not AVAILABLE_BLOCKS:
|
||||
await initialize_blocks()
|
||||
return AVAILABLE_BLOCKS.get(block_id)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
from prisma import Prisma
|
||||
from pydantic import BaseModel
|
||||
@@ -6,16 +5,14 @@ from pydantic import BaseModel
|
||||
prisma = Prisma(auto_register=True)
|
||||
|
||||
|
||||
def connect_sync():
|
||||
asyncio.get_event_loop().run_until_complete(connect())
|
||||
|
||||
|
||||
async def connect():
|
||||
await prisma.connect()
|
||||
if not prisma.is_connected():
|
||||
await prisma.connect()
|
||||
|
||||
|
||||
async def disconnect():
|
||||
await prisma.disconnect()
|
||||
if prisma.is_connected():
|
||||
await prisma.disconnect()
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
|
||||
@@ -2,15 +2,16 @@ import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from multiprocessing import Queue
|
||||
from typing import Any
|
||||
|
||||
from prisma.models import AgentNodeExecution
|
||||
from typing import Any
|
||||
|
||||
from autogpt_server.data.db import BaseDbModel
|
||||
|
||||
|
||||
class Execution(BaseDbModel):
|
||||
"""Data model for an execution of an Agent"""
|
||||
|
||||
run_id: str
|
||||
node_id: str
|
||||
data: dict[str, Any]
|
||||
@@ -23,11 +24,6 @@ class ExecutionStatus(str, Enum):
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
# TODO: This shared class make api & executor coupled in one machine.
|
||||
# Replace this with a persistent & remote-hosted queue.
|
||||
# One very likely candidate would be persisted Redis (Redis Queue).
|
||||
# It will also open the possibility of using it for other purposes like
|
||||
# caching, execution engine broker (like Celery), user session management etc.
|
||||
class ExecutionQueue:
|
||||
"""
|
||||
Queue for managing the execution of agents.
|
||||
@@ -48,7 +44,38 @@ class ExecutionQueue:
|
||||
return self.queue.empty()
|
||||
|
||||
|
||||
async def add_execution(execution: Execution, queue: ExecutionQueue) -> Execution:
|
||||
class ExecutionResult(BaseDbModel):
|
||||
run_id: str
|
||||
execution_id: str
|
||||
node_id: str
|
||||
status: ExecutionStatus
|
||||
input_data: dict[str, Any]
|
||||
output_name: str
|
||||
output_data: Any
|
||||
creation_time: datetime
|
||||
start_time: datetime | None
|
||||
end_time: datetime | None
|
||||
|
||||
@staticmethod
|
||||
def from_db(execution: AgentNodeExecution):
|
||||
return ExecutionResult(
|
||||
run_id=execution.executionId,
|
||||
node_id=execution.agentNodeId,
|
||||
execution_id=execution.id,
|
||||
status=ExecutionStatus(execution.executionStatus),
|
||||
input_data=json.loads(execution.inputData or "{}"),
|
||||
output_name=execution.outputName or "",
|
||||
output_data=json.loads(execution.outputData or "{}"),
|
||||
creation_time=execution.creationTime,
|
||||
start_time=execution.startTime,
|
||||
end_time=execution.endTime,
|
||||
)
|
||||
|
||||
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
|
||||
async def enqueue_execution(execution: Execution) -> None:
|
||||
await AgentNodeExecution.prisma().create(
|
||||
data={
|
||||
"id": execution.id,
|
||||
@@ -59,7 +86,6 @@ async def add_execution(execution: Execution, queue: ExecutionQueue) -> Executio
|
||||
"creationTime": datetime.now(),
|
||||
}
|
||||
)
|
||||
return queue.add(execution)
|
||||
|
||||
|
||||
async def start_execution(exec_id: str) -> None:
|
||||
@@ -96,3 +122,12 @@ async def fail_execution(exec_id: str, error: Exception) -> None:
|
||||
"endTime": datetime.now(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_executions(run_id: str) -> list[ExecutionResult]:
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={"executionId": run_id},
|
||||
order={"startTime": "asc"},
|
||||
)
|
||||
res = [ExecutionResult.from_db(execution) for execution in executions]
|
||||
return res
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeLink, AgentNodeExecution
|
||||
from typing import Any
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeExecution, AgentNodeLink
|
||||
|
||||
from autogpt_server.data.db import BaseDbModel
|
||||
from autogpt_server.data.block import get_block
|
||||
|
||||
|
||||
class Node(BaseDbModel):
|
||||
@@ -33,10 +33,6 @@ class Node(BaseDbModel):
|
||||
self.output_nodes[source_name] = node.id
|
||||
node.input_nodes[sink_name] = self.id
|
||||
|
||||
@property
|
||||
async def block(self):
|
||||
return await get_block(self.block_id)
|
||||
|
||||
|
||||
class Graph(BaseDbModel):
|
||||
name: str
|
||||
@@ -64,6 +60,9 @@ EXECUTION_NODE_INCLUDE = {
|
||||
}
|
||||
|
||||
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
|
||||
async def get_node(node_id: str) -> Node | None:
|
||||
node = await AgentNode.prisma().find_unique_or_raise(
|
||||
where={"id": node_id},
|
||||
@@ -72,6 +71,10 @@ async def get_node(node_id: str) -> Node | None:
|
||||
return Node.from_db(node) if node else None
|
||||
|
||||
|
||||
async def get_graph_ids() -> list[str]:
|
||||
return [graph.id for graph in await AgentGraph.prisma().find_many()] # type: ignore
|
||||
|
||||
|
||||
async def get_graph(graph_id: str) -> Graph | None:
|
||||
graph = await AgentGraph.prisma().find_unique(
|
||||
where={"id": graph_id},
|
||||
@@ -89,7 +92,7 @@ async def get_node_input(node: Node, exec_id: str) -> dict[str, Any]:
|
||||
Returns:
|
||||
dictionary of input data, key is the input name, value is the input data.
|
||||
"""
|
||||
query = AgentNodeExecution.prisma().find_many(
|
||||
query = await AgentNodeExecution.prisma().find_many(
|
||||
where={ # type: ignore
|
||||
"executionId": exec_id,
|
||||
"agentNodeId": {"in": list(node.input_nodes.values())},
|
||||
@@ -100,7 +103,7 @@ async def get_node_input(node: Node, exec_id: str) -> dict[str, Any]:
|
||||
)
|
||||
|
||||
latest_executions: dict[str, AgentNodeExecution] = {
|
||||
execution.agentNodeId: execution for execution in await query
|
||||
execution.agentNodeId: execution for execution in query
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -114,6 +117,7 @@ async def get_node_input(node: Node, exec_id: str) -> dict[str, Any]:
|
||||
|
||||
|
||||
async def create_graph(graph: Graph) -> Graph:
|
||||
|
||||
await AgentGraph.prisma().create(
|
||||
data={
|
||||
"id": graph.id,
|
||||
@@ -123,19 +127,14 @@ async def create_graph(graph: Graph) -> Graph:
|
||||
)
|
||||
|
||||
# TODO: replace bulk creation using create_many
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentNode.prisma().create(
|
||||
{
|
||||
"id": node.id,
|
||||
"agentBlockId": node.block_id,
|
||||
"agentGraphId": graph.id,
|
||||
"constantInput": json.dumps(node.input_default),
|
||||
}
|
||||
)
|
||||
for node in graph.nodes
|
||||
]
|
||||
)
|
||||
await asyncio.gather(*[
|
||||
AgentNode.prisma().create({
|
||||
"id": node.id,
|
||||
"agentBlockId": node.block_id,
|
||||
"agentGraphId": graph.id,
|
||||
"constantInput": json.dumps(node.input_default),
|
||||
}) for node in graph.nodes
|
||||
])
|
||||
|
||||
edge_source_names = {
|
||||
(source_node.id, sink_node_id): output_name
|
||||
@@ -149,22 +148,16 @@ async def create_graph(graph: Graph) -> Graph:
|
||||
}
|
||||
|
||||
# TODO: replace bulk creation using create_many
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentNodeLink.prisma().create(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"sourceName": edge_source_names.get((input_node, output_node), ""),
|
||||
"sinkName": edge_sink_names.get((input_node, output_node), ""),
|
||||
"agentNodeSourceId": input_node,
|
||||
"agentNodeSinkId": output_node,
|
||||
}
|
||||
)
|
||||
for input_node, output_node in (
|
||||
edge_source_names.keys() | edge_sink_names.keys()
|
||||
)
|
||||
]
|
||||
)
|
||||
await asyncio.gather(*[
|
||||
AgentNodeLink.prisma().create({
|
||||
"id": str(uuid.uuid4()),
|
||||
"sourceName": edge_source_names.get((input_node, output_node), ""),
|
||||
"sinkName": edge_sink_names.get((input_node, output_node), ""),
|
||||
"agentNodeSourceId": input_node,
|
||||
"agentNodeSinkId": output_node,
|
||||
})
|
||||
for input_node, output_node in edge_source_names.keys() | edge_sink_names.keys()
|
||||
])
|
||||
|
||||
if created_graph := await get_graph(graph.id):
|
||||
return created_graph
|
||||
|
||||
@@ -1 +1,8 @@
|
||||
from .executor import start_executor_manager # type: ignore # noqa
|
||||
from .manager import ExecutionManager
|
||||
from .scheduler import ExecutionScheduler
|
||||
|
||||
__all__ = [
|
||||
"ExecutionManager",
|
||||
"ExecutionScheduler",
|
||||
]
|
||||
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from multiprocessing import Process
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_server.data import block, db, graph
|
||||
from autogpt_server.data.execution import (
|
||||
Execution,
|
||||
ExecutionQueue,
|
||||
add_execution,
|
||||
complete_execution,
|
||||
fail_execution,
|
||||
start_execution,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_log_prefix(run_id: str, exec_id: str, block_name: str = "-"):
|
||||
return f"[Execution graph-{run_id}|node-{exec_id}|{block_name}]"
|
||||
|
||||
|
||||
async def execute_node(data: Execution) -> Execution | None:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
persist the execution result, and return the subsequent node to be executed.
|
||||
|
||||
Args:
|
||||
data: The execution data for executing the current node.
|
||||
|
||||
Returns:
|
||||
The subsequent node to be enqueued, or None if there is no subsequent node.
|
||||
"""
|
||||
run_id = data.run_id
|
||||
exec_id = data.id
|
||||
exec_data = data.data
|
||||
node_id = data.node_id
|
||||
|
||||
node = await graph.get_node(node_id)
|
||||
if not node:
|
||||
logger.error(f"Node {node_id} not found.")
|
||||
return None
|
||||
|
||||
node_block = await block.get_block(node.block_id)
|
||||
if not node_block:
|
||||
logger.error(f"Block {node.block_id} not found.")
|
||||
return None
|
||||
|
||||
# Execute the node
|
||||
prefix = get_log_prefix(run_id, exec_id, node_block.name)
|
||||
logger.warning(f"{prefix} execute with input:\n{exec_data}")
|
||||
await start_execution(exec_id)
|
||||
|
||||
try:
|
||||
output_name, output_data = await node_block.execute(exec_data)
|
||||
logger.warning(f"{prefix} executed with output: `{output_name}`:{output_data}")
|
||||
await complete_execution(exec_id, (output_name, output_data))
|
||||
except Exception as e:
|
||||
logger.exception(f"{prefix} failed with error: %s", e)
|
||||
await fail_execution(exec_id, e)
|
||||
raise e
|
||||
|
||||
# Try to enqueue next eligible nodes
|
||||
if output_name not in node.output_nodes:
|
||||
logger.error(f"{prefix} output name `{output_name}` has no subsequent node.")
|
||||
return None
|
||||
|
||||
next_node_id = node.output_nodes[output_name]
|
||||
next_node = await graph.get_node(next_node_id)
|
||||
if not next_node:
|
||||
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
|
||||
return None
|
||||
|
||||
next_node_input = await graph.get_node_input(next_node, run_id)
|
||||
next_node_block = await next_node.block
|
||||
|
||||
if not set(next_node.input_nodes).issubset(next_node_input):
|
||||
logger.warning(f"{prefix} Skipped {next_node_id}-{next_node_block.name}, "
|
||||
f"missing: {set(next_node.input_nodes) - set(next_node_input)}")
|
||||
return None
|
||||
|
||||
if error := next_node_block.input_schema.validate_data(next_node_input):
|
||||
logger.warning(
|
||||
f"{prefix} Skipped {next_node_id}-{next_node_block.name}, {error}")
|
||||
return None
|
||||
|
||||
logger.warning(f"{prefix} Enqueue next node {next_node_id}-{next_node_block.name}")
|
||||
return Execution(
|
||||
run_id=run_id, node_id=next_node_id, data=next_node_input
|
||||
)
|
||||
|
||||
|
||||
def execute_node_sync(data: Execution) -> Optional[Execution | None]:
|
||||
"""
|
||||
A synchronous version of `execute_node`, to be used in the ProcessPoolExecutor.
|
||||
"""
|
||||
prefix = get_log_prefix(data.run_id, data.id)
|
||||
try:
|
||||
logger.warning(f"{prefix} Start execution")
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(execute_node(data))
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} Error: {e}")
|
||||
|
||||
|
||||
def start_executor(pool_size: int, queue: ExecutionQueue) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(db.connect())
|
||||
loop.run_until_complete(block.initialize_blocks())
|
||||
|
||||
def on_complete_execution(f: asyncio.Future[Execution | None]):
|
||||
exception = f.exception()
|
||||
if exception:
|
||||
logger.exception("Error during execution!! %s", exception)
|
||||
return exception
|
||||
|
||||
execution = f.result()
|
||||
if execution:
|
||||
loop.run_until_complete(add_execution(execution, queue))
|
||||
return exception
|
||||
|
||||
return None
|
||||
|
||||
logger.warning("Executor started!")
|
||||
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=pool_size,
|
||||
initializer=db.connect_sync,
|
||||
) as executor:
|
||||
while True:
|
||||
future = executor.submit(execute_node_sync, queue.get())
|
||||
future.add_done_callback(on_complete_execution) # type: ignore
|
||||
|
||||
|
||||
def start_executor_manager(pool_size: int, queue: ExecutionQueue) -> None:
|
||||
executor_process = Process(target=start_executor, args=(pool_size, queue))
|
||||
executor_process.start()
|
||||
168
rnd/autogpt_server/autogpt_server/executor/manager.py
Normal file
168
rnd/autogpt_server/autogpt_server/executor/manager.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Optional, Any
|
||||
|
||||
from autogpt_server.data import db
|
||||
from autogpt_server.data.block import Block, get_block
|
||||
from autogpt_server.data.graph import Node, get_node, get_node_input
|
||||
from autogpt_server.data.execution import (
|
||||
Execution,
|
||||
ExecutionQueue,
|
||||
enqueue_execution,
|
||||
complete_execution,
|
||||
fail_execution,
|
||||
start_execution,
|
||||
)
|
||||
from autogpt_server.util.service import AppService, expose
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_log_prefix(run_id: str, exec_id: str, block_name: str = "-"):
|
||||
return f"[Execution graph-{run_id}|node-{exec_id}|{block_name}]"
|
||||
|
||||
|
||||
def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> Execution | None:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
persist the execution result, and return the subsequent node to be executed.
|
||||
|
||||
Args:
|
||||
loop: The event loop to run the async functions.
|
||||
data: The execution data for executing the current node.
|
||||
|
||||
Returns:
|
||||
The subsequent node to be enqueued, or None if there is no subsequent node.
|
||||
"""
|
||||
run_id = data.run_id
|
||||
exec_id = data.id
|
||||
exec_data = data.data
|
||||
node_id = data.node_id
|
||||
|
||||
asyncio.set_event_loop(loop)
|
||||
wait = lambda f: loop.run_until_complete(f)
|
||||
|
||||
node: Optional[Node] = wait(get_node(node_id))
|
||||
if not node:
|
||||
logger.error(f"Node {node_id} not found.")
|
||||
return None
|
||||
|
||||
node_block: Optional[Block] = wait(get_block(node.block_id))
|
||||
if not node_block:
|
||||
logger.error(f"Block {node.block_id} not found.")
|
||||
return None
|
||||
|
||||
# Execute the node
|
||||
prefix = get_log_prefix(run_id, exec_id, node_block.name)
|
||||
logger.warning(f"{prefix} execute with input:\n`{exec_data}`")
|
||||
wait(start_execution(exec_id))
|
||||
|
||||
try:
|
||||
output_name, output_data = node_block.execute(exec_data)
|
||||
logger.warning(f"{prefix} executed with output [{output_name}]:`{output_data}`")
|
||||
wait(complete_execution(exec_id, (output_name, output_data)))
|
||||
except Exception as e:
|
||||
logger.exception(f"{prefix} failed with error: %s", e)
|
||||
wait(fail_execution(exec_id, e))
|
||||
raise e
|
||||
|
||||
# Try to enqueue next eligible nodes
|
||||
if output_name not in node.output_nodes:
|
||||
logger.error(f"{prefix} Output [{output_name}] has no subsequent node.")
|
||||
return None
|
||||
|
||||
next_node_id = node.output_nodes[output_name]
|
||||
next_node: Optional[Node] = wait(get_node(next_node_id))
|
||||
if not next_node:
|
||||
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
|
||||
return None
|
||||
|
||||
next_node_input: dict[str, Any] = wait(get_node_input(next_node, run_id))
|
||||
next_node_block: Block | None = wait(get_block(next_node.block_id))
|
||||
if not next_node_block:
|
||||
logger.error(f"{prefix} Error, next block {next_node.block_id} not found.")
|
||||
return None
|
||||
|
||||
if not set(next_node.input_nodes).issubset(next_node_input):
|
||||
logger.warning(
|
||||
f"{prefix} Skipped {next_node_id}-{next_node_block.name}, "
|
||||
f"missing: {set(next_node.input_nodes) - set(next_node_input)}"
|
||||
)
|
||||
return None
|
||||
|
||||
if error := next_node_block.input_schema.validate_data(next_node_input):
|
||||
logger.warning(
|
||||
f"{prefix} Skipped {next_node_id}-{next_node_block.name}, {error}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.warning(f"{prefix} Enqueue next node {next_node_id}-{next_node_block.name}")
|
||||
return Execution(run_id=run_id, node_id=next_node_id, data=next_node_input)
|
||||
|
||||
|
||||
class Executor:
|
||||
loop: asyncio.AbstractEventLoop
|
||||
|
||||
@classmethod
|
||||
def on_executor_start(cls):
|
||||
cls.loop = asyncio.new_event_loop()
|
||||
cls.loop.run_until_complete(db.connect())
|
||||
|
||||
@classmethod
|
||||
def on_start_execution(cls, data: Execution) -> Optional[Execution | None]:
|
||||
"""
|
||||
A synchronous version of `execute_node`, to be used in the ProcessPoolExecutor.
|
||||
"""
|
||||
prefix = get_log_prefix(data.run_id, data.id)
|
||||
try:
|
||||
logger.warning(f"{prefix} Start execution")
|
||||
return execute_node(cls.loop, data)
|
||||
except Exception as e:
|
||||
logger.error(f"{prefix} Error: {e}")
|
||||
|
||||
|
||||
class ExecutionManager(AppService):
|
||||
|
||||
def __init__(self, pool_size: int):
|
||||
self.pool_size = pool_size
|
||||
self.queue = ExecutionQueue()
|
||||
|
||||
def run_service(self):
|
||||
def on_complete_execution(f: asyncio.Future[Execution | None]):
|
||||
exception = f.exception()
|
||||
if exception:
|
||||
logger.exception("Error during execution!! %s", exception)
|
||||
return exception
|
||||
|
||||
execution = f.result()
|
||||
if execution:
|
||||
return self.__add_execution(execution)
|
||||
|
||||
return None
|
||||
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=Executor.on_executor_start,
|
||||
) as executor:
|
||||
logger.warning(f"Execution manager started with {self.pool_size} workers.")
|
||||
while True:
|
||||
future = executor.submit(
|
||||
Executor.on_start_execution,
|
||||
self.queue.get()
|
||||
)
|
||||
future.add_done_callback(on_complete_execution) # type: ignore
|
||||
|
||||
@expose
|
||||
def add_execution(self, run_id: str, node_id: str, data: dict[str, Any]) -> str:
|
||||
try:
|
||||
execution = Execution(run_id=run_id, node_id=node_id, data=data)
|
||||
self.__add_execution(execution)
|
||||
return execution.id
|
||||
except Exception as e:
|
||||
raise Exception("Error adding execution ", e)
|
||||
|
||||
def __add_execution(self, execution: Execution) -> Execution:
|
||||
self.run_and_wait(enqueue_execution(execution))
|
||||
return self.queue.add(execution)
|
||||
23
rnd/autogpt_server/autogpt_server/executor/scheduler.py
Normal file
23
rnd/autogpt_server/autogpt_server/executor/scheduler.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import time
|
||||
|
||||
from autogpt_server.util.service import AppService, expose
|
||||
|
||||
|
||||
class ExecutionScheduler(AppService):
|
||||
|
||||
def run_service(self):
|
||||
while True:
|
||||
time.sleep(1) # This will be replaced with apscheduler executor.
|
||||
|
||||
@expose
|
||||
def add_execution_schedule(self, agent_id: str, cron: str, input_data: dict) -> str:
|
||||
print(
|
||||
f"Adding execution schedule for agent {agent_id} with cron {cron} and "
|
||||
f"input data {input_data}"
|
||||
)
|
||||
return "dummy_schedule_id"
|
||||
|
||||
@expose
|
||||
def get_execution_schedules(self, agent_id: str) -> list[dict]:
|
||||
print(f"Getting execution schedules for agent {agent_id}")
|
||||
return [{"cron": "dummy_cron", "input_data": {"dummy_input": "dummy_value"}}]
|
||||
@@ -1 +1,3 @@
|
||||
from .server import start_server # type: ignore # noqa
|
||||
from .server import AgentServer
|
||||
|
||||
__all__ = ["AgentServer"]
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import uuid
|
||||
|
||||
import uvicorn
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import APIRouter, FastAPI, HTTPException
|
||||
|
||||
from autogpt_server.data import db, execution, graph
|
||||
from autogpt_server.data import db, execution, graph, block
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.util.process import AppProcess
|
||||
from autogpt_server.util.service import get_service_client
|
||||
|
||||
|
||||
class AgentServer:
|
||||
class AgentServer(AppProcess):
|
||||
|
||||
def __init__(self, queue: execution.ExecutionQueue):
|
||||
@asynccontextmanager
|
||||
async def lifespan(self, _: FastAPI):
|
||||
await db.connect()
|
||||
yield
|
||||
await db.disconnect()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await db.connect()
|
||||
yield
|
||||
await db.disconnect()
|
||||
|
||||
self.app = FastAPI(
|
||||
def run(self):
|
||||
app = FastAPI(
|
||||
title="AutoGPT Agent Server",
|
||||
description=(
|
||||
"This server is used to execute agents that are created by the "
|
||||
@@ -27,46 +27,139 @@ class AgentServer:
|
||||
),
|
||||
summary="AutoGPT Agent Server",
|
||||
version="0.1",
|
||||
lifespan=lifespan,
|
||||
lifespan=self.lifespan,
|
||||
)
|
||||
self.execution_queue = queue
|
||||
|
||||
# Define the API routes
|
||||
self.router = APIRouter()
|
||||
self.router.add_api_route(
|
||||
path="/agents/{agent_id}/execute",
|
||||
endpoint=self.execute_agent,
|
||||
router = APIRouter()
|
||||
router.add_api_route(
|
||||
path="/blocks",
|
||||
endpoint=AgentServer.get_agent_blocks,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/agents",
|
||||
endpoint=AgentServer.get_agents,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/agents/{agent_id}",
|
||||
endpoint=AgentServer.get_agent,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/agents",
|
||||
endpoint=AgentServer.create_agent,
|
||||
methods=["POST"],
|
||||
)
|
||||
self.app.include_router(self.router)
|
||||
router.add_api_route(
|
||||
path="/agents/{agent_id}/execute",
|
||||
endpoint=AgentServer.execute_agent,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/agents/{agent_id}/executions/{run_id}",
|
||||
endpoint=AgentServer.get_executions,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/agents/{agent_id}/schedules",
|
||||
endpoint=AgentServer.schedule_agent,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/agents/{agent_id}/schedules",
|
||||
endpoint=AgentServer.get_execution_schedules,
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
async def execute_agent(self, agent_id: str, node_input: dict):
|
||||
app.include_router(router)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
@staticmethod
|
||||
async def get_agent_blocks() -> list[dict]:
|
||||
return [v.to_dict() for v in await block.get_blocks()]
|
||||
|
||||
@staticmethod
|
||||
async def get_agents() -> list[str]:
|
||||
return await graph.get_graph_ids()
|
||||
|
||||
@staticmethod
|
||||
async def get_agent(agent_id: str) -> graph.Graph:
|
||||
agent = await graph.get_graph(agent_id)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found.")
|
||||
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
async def create_agent(agent: graph.Graph) -> graph.Graph:
|
||||
agent.id = str(uuid.uuid4())
|
||||
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in agent.nodes}
|
||||
for node in agent.nodes:
|
||||
node.id = id_map[node.id]
|
||||
node.input_nodes = {k: id_map[v] for k, v in node.input_nodes.items()}
|
||||
node.output_nodes = {k: id_map[v] for k, v in node.output_nodes.items()}
|
||||
|
||||
return await graph.create_graph(agent)
|
||||
|
||||
@staticmethod
|
||||
async def execute_agent(agent_id: str, node_input: dict) -> dict:
|
||||
agent = await graph.get_graph(agent_id)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found.")
|
||||
|
||||
run_id = str(uuid.uuid4())
|
||||
tasks = []
|
||||
executions = []
|
||||
execution_manager = get_service_client(ExecutionManager)
|
||||
|
||||
# Currently, there is no constraint on the number of root nodes in the graph.
|
||||
for node in agent.starting_nodes:
|
||||
block = await node.block
|
||||
if error := block.input_schema.validate_data(node_input):
|
||||
node_block = await block.get_block(node.block_id)
|
||||
if not node_block:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Block #{node.block_id} not found.",
|
||||
)
|
||||
if error := node_block.input_schema.validate_data(node_input):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Input data doesn't match {block.name} input: {error}",
|
||||
detail=f"Input data doesn't match {node_block.name} input: {error}",
|
||||
)
|
||||
|
||||
task = execution.add_execution(
|
||||
execution.Execution(run_id=run_id, node_id=node.id, data=node_input),
|
||||
self.execution_queue,
|
||||
exec_id = execution_manager.add_execution(
|
||||
run_id=run_id, node_id=node.id, data=node_input
|
||||
)
|
||||
executions.append({
|
||||
"exec_id": exec_id,
|
||||
"node_id": node.id,
|
||||
})
|
||||
|
||||
tasks.append(task)
|
||||
return {
|
||||
"run_id": run_id,
|
||||
"executions": executions,
|
||||
}
|
||||
|
||||
return await asyncio.gather(*tasks)
|
||||
@staticmethod
|
||||
async def get_executions(
|
||||
agent_id: str,
|
||||
run_id: str
|
||||
) -> list[execution.ExecutionResult]:
|
||||
agent = await graph.get_graph(agent_id)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found.")
|
||||
|
||||
return await execution.get_executions(run_id)
|
||||
|
||||
def start_server(queue: execution.ExecutionQueue):
|
||||
agent_server = AgentServer(queue)
|
||||
uvicorn.run(agent_server.app)
|
||||
@staticmethod
|
||||
def schedule_agent(agent_id: str, cron: str, input_data: dict) -> dict:
|
||||
execution_scheduler = get_service_client(ExecutionScheduler)
|
||||
return {
|
||||
"id": execution_scheduler.add_execution_schedule(agent_id, cron, input_data)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_execution_schedules(agent_id: str) -> list[dict]:
|
||||
execution_scheduler = get_service_client(ExecutionScheduler)
|
||||
return execution_scheduler.get_execution_schedules(agent_id)
|
||||
|
||||
73
rnd/autogpt_server/autogpt_server/util/process.py
Normal file
73
rnd/autogpt_server/autogpt_server/util/process.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from multiprocessing import Process, freeze_support, set_start_method
|
||||
from multiprocessing.spawn import freeze_support as freeze_support_spawn
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AppProcess(ABC):
|
||||
"""
|
||||
A class to represent an object that can be executed in a background process.
|
||||
"""
|
||||
process: Optional[Process] = None
|
||||
set_start_method('spawn', force=True)
|
||||
freeze_support()
|
||||
freeze_support_spawn()
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
"""
|
||||
The method that will be executed in the process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def execute_run_command(self, silent):
|
||||
try:
|
||||
if silent:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
self.run()
|
||||
except KeyboardInterrupt or SystemExit as e:
|
||||
print(f"Process terminated: {e}")
|
||||
|
||||
def __enter__(self):
|
||||
self.start(background=True)
|
||||
return self
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
self.stop()
|
||||
|
||||
def start(self, background: bool = False, silent: bool = False, **proc_args) -> int:
|
||||
"""
|
||||
Start the background process.
|
||||
Args:
|
||||
background: Whether to run the process in the background.
|
||||
silent: Whether to disable stdout and stderr.
|
||||
proc_args: Additional arguments to pass to the process.
|
||||
Returns:
|
||||
the process id or 0 if the process is not running in the background.
|
||||
"""
|
||||
if not background:
|
||||
self.execute_run_command(silent)
|
||||
return 0
|
||||
|
||||
self.process = Process(
|
||||
name=self.__class__.__name__,
|
||||
target=self.execute_run_command,
|
||||
args=(silent,),
|
||||
**proc_args,
|
||||
)
|
||||
self.process.start()
|
||||
return self.process.pid or 0
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop the background process.
|
||||
"""
|
||||
if not self.process:
|
||||
return
|
||||
|
||||
self.process.terminate()
|
||||
self.process.join()
|
||||
self.process = None
|
||||
100
rnd/autogpt_server/autogpt_server/util/service.py
Normal file
100
rnd/autogpt_server/autogpt_server/util/service.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Callable, Type, TypeVar, cast, Coroutine
|
||||
|
||||
from Pyro5 import api as pyro
|
||||
from Pyro5 import nameserver
|
||||
from tenacity import retry, stop_after_delay, wait_exponential
|
||||
|
||||
from autogpt_server.data import db
|
||||
from autogpt_server.util.process import AppProcess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
conn_retry = retry(stop=stop_after_delay(5), wait=wait_exponential(multiplier=0.1))
|
||||
expose = pyro.expose
|
||||
|
||||
|
||||
class PyroNameServer(AppProcess):
|
||||
def run(self):
|
||||
try:
|
||||
print("Starting NameServer loop")
|
||||
nameserver.start_ns_loop()
|
||||
except KeyboardInterrupt:
|
||||
print("Shutting down NameServer")
|
||||
|
||||
|
||||
class AppService(AppProcess):
|
||||
|
||||
shared_event_loop: asyncio.AbstractEventLoop
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def service_name(cls) -> str:
|
||||
return cls.__name__
|
||||
|
||||
@abstractmethod
|
||||
def run_service(self):
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
||||
def run_async(self, coro: Coroutine):
|
||||
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)
|
||||
|
||||
def run_and_wait(self, coro: Coroutine):
|
||||
future = self.run_async(coro)
|
||||
return future.result()
|
||||
|
||||
def run(self):
|
||||
self.shared_event_loop = asyncio.get_event_loop()
|
||||
self.shared_event_loop.run_until_complete(db.connect())
|
||||
|
||||
# Initialize the async loop.
|
||||
async_thread = threading.Thread(target=self.__start_async_loop)
|
||||
async_thread.daemon = True
|
||||
async_thread.start()
|
||||
|
||||
# Initialize pyro service
|
||||
daemon_thread = threading.Thread(target=self.__start_pyro)
|
||||
daemon_thread.daemon = True
|
||||
daemon_thread.start()
|
||||
|
||||
# Run the main service (if it's not implemented, just sleep).
|
||||
self.run_service()
|
||||
|
||||
@conn_retry
|
||||
def __start_pyro(self):
|
||||
daemon = pyro.Daemon()
|
||||
ns = pyro.locate_ns()
|
||||
uri = daemon.register(self)
|
||||
ns.register(self.service_name, uri)
|
||||
logger.warning(f"Service [{self.service_name}] Ready. Object URI = {uri}")
|
||||
daemon.requestLoop()
|
||||
|
||||
def __start_async_loop(self):
|
||||
# asyncio.set_event_loop(self.shared_event_loop)
|
||||
self.shared_event_loop.run_forever()
|
||||
|
||||
|
||||
AS = TypeVar("AS", bound=AppService)
|
||||
|
||||
|
||||
def get_service_client(service_type: Type[AS]) -> AS:
|
||||
service_name = service_type.service_name
|
||||
|
||||
class DynamicClient:
|
||||
|
||||
@conn_retry
|
||||
def __init__(self):
|
||||
ns = pyro.locate_ns()
|
||||
uri = ns.lookup(service_name)
|
||||
self.proxy = pyro.Proxy(uri)
|
||||
self.proxy._pyroBind()
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., Any]:
|
||||
return getattr(self.proxy, name)
|
||||
|
||||
return cast(AS, DynamicClient())
|
||||
77
rnd/autogpt_server/poetry.lock
generated
77
rnd/autogpt_server/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
@@ -216,18 +216,18 @@ all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)"
|
||||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.14.0"
|
||||
version = "3.15.1"
|
||||
description = "A platform independent file lock."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"},
|
||||
{file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"},
|
||||
{file = "filelock-3.15.1-py3-none-any.whl", hash = "sha256:71b3102950e91dfc1bb4209b64be4dc8854f40e5f534428d8684f953ac847fac"},
|
||||
{file = "filelock-3.15.1.tar.gz", hash = "sha256:58a2549afdf9e02e10720eaa4d4470f56386d7a6f72edd7d0596337af8ed7ad8"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
|
||||
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
|
||||
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
|
||||
typing = ["typing-extensions (>=4.8)"]
|
||||
|
||||
[[package]]
|
||||
@@ -695,13 +695,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.7.3"
|
||||
version = "2.7.4"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pydantic-2.7.3-py3-none-any.whl", hash = "sha256:ea91b002777bf643bb20dd717c028ec43216b24a6001a280f83877fd2655d0b4"},
|
||||
{file = "pydantic-2.7.3.tar.gz", hash = "sha256:c46c76a40bb1296728d7a8b99aa73dd70a48c3510111ff290034f860c99c419e"},
|
||||
{file = "pydantic-2.7.4-py3-none-any.whl", hash = "sha256:ee8538d41ccb9c0a9ad3e0e5f07bf15ed8015b481ced539a1759d8cc89ae90d0"},
|
||||
{file = "pydantic-2.7.4.tar.gz", hash = "sha256:0c84efd9548d545f63ac0060c1e4d39bb9b14db8b3c0652338aecc07b5adec52"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -814,6 +814,20 @@ files = [
|
||||
{file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyro5"
|
||||
version = "5.15"
|
||||
description = "Remote object communication library, fifth major version"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "Pyro5-5.15-py3-none-any.whl", hash = "sha256:4d85428ed75985e63f159d2486ad5680743ea76f766340fd30b65dd20f83d471"},
|
||||
{file = "Pyro5-5.15.tar.gz", hash = "sha256:82c3dfc9860b49f897b28ff24fe6716c841672c600af8fe40d0e3a7fac9a3f5e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
serpent = ">=1.41"
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.2.2"
|
||||
@@ -836,24 +850,6 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "0.23.7"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest_asyncio-0.23.7-py3-none-any.whl", hash = "sha256:009b48127fbe44518a547bddd25611551b0e43ccdbf1e67d12479f569832c20b"},
|
||||
{file = "pytest_asyncio-0.23.7.tar.gz", hash = "sha256:5f5c72948f4c49e7db4f29f2521d4031f1c27f86e57b046126654083d4770268"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0,<9"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-watcher"
|
||||
version = "0.4.2"
|
||||
@@ -908,6 +904,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
@@ -1091,6 +1088,17 @@ files = [
|
||||
{file = "ruff-0.4.8.tar.gz", hash = "sha256:16d717b1d57b2e2fd68bd0bf80fb43931b79d05a7131aa477d66fc40fbd86268"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serpent"
|
||||
version = "1.41"
|
||||
description = "Serialization based on ast.literal_eval"
|
||||
optional = false
|
||||
python-versions = ">=3.2"
|
||||
files = [
|
||||
{file = "serpent-1.41-py3-none-any.whl", hash = "sha256:5fd776b3420441985bc10679564c2c9b4a19f77bea59f018e473441d98ae5dd7"},
|
||||
{file = "serpent-1.41.tar.gz", hash = "sha256:0407035fe3c6644387d48cff1467d5aa9feff814d07372b78677ed0ee3ed7095"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "69.5.1"
|
||||
@@ -1151,6 +1159,21 @@ docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"]
|
||||
release = ["twine"]
|
||||
test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"]
|
||||
|
||||
[[package]]
|
||||
name = "tenacity"
|
||||
version = "8.3.0"
|
||||
description = "Retry code until it succeeds"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"},
|
||||
{file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
doc = ["reno", "sphinx"]
|
||||
test = ["pytest", "tornado (>=4.5)", "typeguard"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
@@ -1483,4 +1506,4 @@ test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "391567de870dbbf86ea217ff6b15f7c6d2c9406707c196661d29f45deb886812"
|
||||
content-hash = "de508427e9804ded3b3139e13f209baa6cc97bc138d83952ad2b129d3aedc4e2"
|
||||
|
||||
@@ -16,11 +16,12 @@ prisma = "^0.13.1"
|
||||
pytest = "^8.2.1"
|
||||
uvicorn = { extras = ["standard"], version = "^0.30.1" }
|
||||
fastapi = "^0.109.0"
|
||||
pytest-asyncio = "^0.23.7"
|
||||
ruff = "^0.4.8"
|
||||
flake8 = "^7.0.0"
|
||||
jsonschema = "^4.22.0"
|
||||
psutil = "^5.9.8"
|
||||
pyro5 = "^5.15"
|
||||
tenacity = "^8.3.0"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
@@ -35,7 +36,7 @@ build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
app = "autogpt_server.app:main"
|
||||
cli = "autogpt_server.cli:start_cli"
|
||||
cli = "autogpt_server.cli:main"
|
||||
|
||||
# https://poethepoet.natn.io/index.html
|
||||
[tool.poe]
|
||||
|
||||
@@ -29,11 +29,14 @@ setup(
|
||||
icon=icon,
|
||||
),
|
||||
Executable(
|
||||
"autogpt_server/cli.py", target_name="agpt_server_cli", base="console", icon=icon
|
||||
"autogpt_server/cli.py",
|
||||
target_name="agpt_server_cli",
|
||||
base="console",
|
||||
icon=icon,
|
||||
),
|
||||
],
|
||||
options={
|
||||
# Options for building all the executables
|
||||
# Options for building all the executables
|
||||
"build_exe": {
|
||||
"packages": packages,
|
||||
"includes": [
|
||||
|
||||
121
rnd/autogpt_server/test/executor/test_manager.py
Normal file
121
rnd/autogpt_server/test/executor/test_manager.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from autogpt_server.data import block, db, execution, graph
|
||||
from autogpt_server.executor import ExecutionManager
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.util.service import PyroNameServer
|
||||
|
||||
|
||||
async def create_test_graph() -> graph.Graph:
|
||||
"""
|
||||
ParrotBlock
|
||||
\
|
||||
---- TextCombinerBlock ---- PrintingBlock
|
||||
/
|
||||
ParrotBlock
|
||||
"""
|
||||
nodes = [
|
||||
graph.Node(block_id=block.ParrotBlock.id),
|
||||
graph.Node(block_id=block.ParrotBlock.id),
|
||||
graph.Node(
|
||||
block_id=block.TextCombinerBlock.id,
|
||||
input_default={"format": "{text1},{text2}"},
|
||||
),
|
||||
graph.Node(block_id=block.PrintingBlock.id),
|
||||
]
|
||||
nodes[0].connect(nodes[2], "output", "text1")
|
||||
nodes[1].connect(nodes[2], "output", "text2")
|
||||
nodes[2].connect(nodes[3], "combined_text", "text")
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
)
|
||||
await block.initialize_blocks()
|
||||
result = await graph.create_graph(test_graph)
|
||||
|
||||
# Assertions
|
||||
assert result.name == test_graph.name
|
||||
assert result.description == test_graph.description
|
||||
assert len(result.nodes) == len(test_graph.nodes)
|
||||
|
||||
return test_graph
|
||||
|
||||
|
||||
def execute_agent(test_manager: ExecutionManager, test_graph: graph.Graph, wait_db):
|
||||
# --- Test adding new executions --- #
|
||||
text = "Hello, World!"
|
||||
input_data = {"input": text}
|
||||
response = wait_db(AgentServer.execute_agent(test_graph.id, input_data))
|
||||
executions = response["executions"]
|
||||
run_id = response["run_id"]
|
||||
assert len(executions) == 2
|
||||
|
||||
async def is_execution_completed():
|
||||
execs = await AgentServer.get_executions(test_graph.id, run_id)
|
||||
return test_manager.queue.empty() and len(execs) == 4
|
||||
|
||||
# Wait for the executions to complete
|
||||
for i in range(10):
|
||||
if wait_db(is_execution_completed()):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
# Execution queue should be empty
|
||||
assert wait_db(is_execution_completed())
|
||||
executions = wait_db(AgentServer.get_executions(test_graph.id, run_id))
|
||||
|
||||
# Executing ParrotBlock1
|
||||
exec = executions[0]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.run_id == run_id
|
||||
assert exec.output_name == "output"
|
||||
assert exec.output_data == "Hello, World!"
|
||||
assert exec.input_data == input_data
|
||||
assert exec.node_id == test_graph.nodes[0].id
|
||||
|
||||
# Executing ParrotBlock2
|
||||
exec = executions[1]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.run_id == run_id
|
||||
assert exec.output_name == "output"
|
||||
assert exec.output_data == "Hello, World!"
|
||||
assert exec.input_data == input_data
|
||||
assert exec.node_id == test_graph.nodes[1].id
|
||||
|
||||
# Executing TextCombinerBlock
|
||||
exec = executions[2]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.run_id == run_id
|
||||
assert exec.output_name == "combined_text"
|
||||
assert exec.output_data == "Hello, World!,Hello, World!"
|
||||
assert exec.input_data == {
|
||||
"format": "{text1},{text2}",
|
||||
"text1": "Hello, World!",
|
||||
"text2": "Hello, World!",
|
||||
}
|
||||
assert exec.node_id == test_graph.nodes[2].id
|
||||
|
||||
# Executing PrintingBlock
|
||||
exec = executions[3]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.run_id == run_id
|
||||
assert exec.output_name == "status"
|
||||
assert exec.output_data == "printed"
|
||||
assert exec.input_data == {"text": "Hello, World!,Hello, World!"}
|
||||
assert exec.node_id == test_graph.nodes[3].id
|
||||
|
||||
|
||||
def test_agent_execution():
|
||||
with PyroNameServer():
|
||||
time.sleep(0.5)
|
||||
with ExecutionManager(1) as test_manager:
|
||||
loop = asyncio.new_event_loop()
|
||||
wait = loop.run_until_complete
|
||||
|
||||
wait(db.connect())
|
||||
test_graph = wait(create_test_graph())
|
||||
|
||||
execute_agent(test_manager, test_graph, wait)
|
||||
@@ -1,97 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from autogpt_server.data import block, db, graph
|
||||
from autogpt_server.data.execution import ExecutionQueue, add_execution
|
||||
from autogpt_server.executor import executor
|
||||
from autogpt_server.server import server
|
||||
|
||||
|
||||
async def create_test_graph() -> graph.Graph:
|
||||
"""
|
||||
ParrotBlock
|
||||
\
|
||||
---- TextCombinerBlock ---- PrintingBlock
|
||||
/
|
||||
ParrotBlock
|
||||
"""
|
||||
nodes = [
|
||||
graph.Node(block_id=block.ParrotBlock.id),
|
||||
graph.Node(block_id=block.ParrotBlock.id),
|
||||
graph.Node(
|
||||
block_id=block.TextCombinerBlock.id,
|
||||
input_default={"format": "{text1},{text2}"}
|
||||
),
|
||||
graph.Node(block_id=block.PrintingBlock.id),
|
||||
]
|
||||
nodes[0].connect(nodes[2], "output", "text1")
|
||||
nodes[1].connect(nodes[2], "output", "text2")
|
||||
nodes[2].connect(nodes[3], "combined_text", "text")
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
)
|
||||
await block.initialize_blocks()
|
||||
result = await graph.create_graph(test_graph)
|
||||
|
||||
# Assertions
|
||||
assert result.name == test_graph.name
|
||||
assert result.description == test_graph.description
|
||||
assert len(result.nodes) == len(test_graph.nodes)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def execute_node(queue: ExecutionQueue) -> dict | None:
|
||||
next_exec = await executor.execute_node(queue.get())
|
||||
if not next_exec:
|
||||
return None
|
||||
await add_execution(next_exec, queue)
|
||||
return next_exec.data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_execution():
|
||||
await db.connect()
|
||||
test_graph = await create_test_graph()
|
||||
test_queue = ExecutionQueue()
|
||||
test_server = server.AgentServer(test_queue)
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
text = "Hello, World!"
|
||||
input_data = {"input": text}
|
||||
executions = await test_server.execute_agent(test_graph.id, input_data)
|
||||
|
||||
# 2 executions should be created, one for each ParrotBlock, with same run_id.
|
||||
assert len(executions) == 2
|
||||
assert executions[0].run_id == executions[1].run_id
|
||||
assert executions[0].node_id != executions[1].node_id
|
||||
assert executions[0].data == executions[1].data == input_data
|
||||
|
||||
# --- Test Executing added tasks --- #
|
||||
|
||||
# Executing ParrotBlock1, TextCombinerBlock won't be enqueued yet.
|
||||
assert not test_queue.empty()
|
||||
next_execution = await execute_node(test_queue)
|
||||
assert next_execution is None
|
||||
|
||||
# Executing ParrotBlock2, TextCombinerBlock will be enqueued.
|
||||
assert not test_queue.empty()
|
||||
next_execution = await execute_node(test_queue)
|
||||
assert test_queue.empty()
|
||||
assert next_execution
|
||||
assert next_execution.keys() == {"text1", "text2", "format"}
|
||||
assert next_execution["text1"] == text
|
||||
assert next_execution["text2"] == text
|
||||
assert next_execution["format"] == "{text1},{text2}"
|
||||
|
||||
# Executing TextCombinerBlock, PrintingBlock will be enqueued.
|
||||
next_execution = await execute_node(test_queue)
|
||||
assert next_execution
|
||||
assert next_execution.keys() == {"text"}
|
||||
assert next_execution["text"] == f"{text},{text}"
|
||||
|
||||
# Executing PrintingBlock, no more tasks will be enqueued.
|
||||
next_execution = await execute_node(test_queue)
|
||||
assert next_execution is None
|
||||
38
rnd/autogpt_server/test/util/test_service.py
Normal file
38
rnd/autogpt_server/test/util/test_service.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import time
|
||||
|
||||
from autogpt_server.util.service import (
|
||||
AppService,
|
||||
PyroNameServer,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
|
||||
|
||||
class TestService(AppService):
|
||||
|
||||
def run_service(self):
|
||||
super().run_service()
|
||||
|
||||
@expose
|
||||
def add(self, a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
@expose
|
||||
def subtract(self, a: int, b: int) -> int:
|
||||
return a - b
|
||||
|
||||
@expose
|
||||
def fun_with_async(self, a: int, b: int) -> int:
|
||||
async def add_async(a: int, b: int) -> int:
|
||||
return a + b
|
||||
return self.run_and_wait(add_async(a, b))
|
||||
|
||||
|
||||
def test_service_creation():
|
||||
with PyroNameServer():
|
||||
time.sleep(0.5)
|
||||
with TestService():
|
||||
client = get_service_client(TestService)
|
||||
assert client.add(5, 3) == 8
|
||||
assert client.subtract(10, 4) == 6
|
||||
assert client.fun_with_async(5, 3) == 8
|
||||
Reference in New Issue
Block a user