Merge branch 'master' into aarushikansal/execution-manager

This commit is contained in:
Aarushi
2024-09-06 16:07:29 +01:00
committed by GitHub
12 changed files with 407 additions and 181 deletions

View File

@@ -58,13 +58,21 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes
6. Migrate the database. Be careful because this deletes current data in the database.
```sh
docker compose up postgres -d
docker compose up postgres redis -d
poetry run prisma migrate dev
```
## Running The Server
### Starting the server directly
### Starting the server without Docker
Run the following command to build the dockerfiles:
```sh
poetry run app
```
### Starting the server with Docker
Run the following command to build the dockerfiles:

View File

@@ -1,9 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, List, TypeVar
from typing import Any, List
from pydantic import Field
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from autogpt_server.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchema,
BlockUIType,
)
from autogpt_server.data.model import SchemaField
from autogpt_server.util.mock import MockObject
@@ -131,63 +136,151 @@ class FindInDictionaryBlock(Block):
yield "missing", input_data.input
T = TypeVar("T")
class InputBlock(Block):
"""
This block is used to provide input to the graph.
It takes in a value, name, description, default values list and bool to limit selection to default values.
class InputOutputBlockInput(BlockSchema, Generic[T]):
value: T = Field(description="The value to be passed as input/output.")
name: str = Field(description="The name of the input/output.")
It Outputs the value passed as input.
"""
class InputOutputBlockOutput(BlockSchema, Generic[T]):
result: T = Field(description="The value passed as input/output.")
class InputOutputBlockBase(Block, ABC, Generic[T]):
@abstractmethod
def block_id(self) -> str:
pass
def __init__(self, *args, **kwargs):
input_schema = InputOutputBlockInput[T]
output_schema = InputOutputBlockOutput[T]
super().__init__(
id=self.block_id(),
description="This block is used to define the input & output of a graph.",
input_schema=input_schema,
output_schema=output_schema,
test_input=[
{"value": {"apple": 1, "banana": 2, "cherry": 3}, "name": "input_1"},
{"value": MockObject(value="!!", key="key"), "name": "input_2"},
],
test_output=[
("result", {"apple": 1, "banana": 2, "cherry": 3}),
("result", MockObject(value="!!", key="key")),
],
static_output=True,
*args,
**kwargs,
class Input(BlockSchema):
value: Any = SchemaField(description="The value to be passed as input.")
name: str = SchemaField(description="The name of the input.")
description: str = SchemaField(description="The description of the input.")
placeholder_values: List[Any] = SchemaField(
description="The placeholder values to be passed as input."
)
limit_to_placeholder_values: bool = SchemaField(
description="Whether to limit the selection to placeholder values.",
default=False,
)
def run(self, input_data: InputOutputBlockInput[T]) -> BlockOutput:
class Output(BlockSchema):
result: Any = SchemaField(description="The value passed as input.")
def __init__(self):
super().__init__(
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
description="This block is used to provide input to the graph.",
input_schema=InputBlock.Input,
output_schema=InputBlock.Output,
test_input=[
{
"value": "Hello, World!",
"name": "input_1",
"description": "This is a test input.",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
{
"value": "Hello, World!",
"name": "input_2",
"description": "This is a test input.",
"placeholder_values": ["Hello, World!"],
"limit_to_placeholder_values": True,
},
],
test_output=[
("result", "Hello, World!"),
("result", "Hello, World!"),
],
categories={BlockCategory.INPUT, BlockCategory.BASIC},
ui_type=BlockUIType.INPUT,
)
def run(self, input_data: Input) -> BlockOutput:
yield "result", input_data.value
class InputBlock(InputOutputBlockBase[Any]):
class OutputBlock(Block):
"""
Records the output of the graph for users to see.
Attributes:
recorded_value: The value to be recorded as output.
name: The name of the output.
description: The description of the output.
fmt_string: The format string to be used to format the recorded_value.
Outputs:
output: The formatted recorded_value if fmt_string is provided and the recorded_value
can be formatted, otherwise the raw recorded_value.
Behavior:
If fmt_string is provided and the recorded_value is of a type that can be formatted,
the block attempts to format the recorded_value using the fmt_string.
If formatting fails or no fmt_string is provided, the raw recorded_value is output.
"""
class Input(BlockSchema):
recorded_value: Any = SchemaField(
description="The value to be recorded as output."
)
name: str = SchemaField(description="The name of the output.")
description: str = SchemaField(description="The description of the output.")
fmt_string: str = SchemaField(
description="The format string to be used to format the recorded_value."
)
class Output(BlockSchema):
output: Any = SchemaField(description="The value recorded as output.")
def __init__(self):
super().__init__(categories={BlockCategory.INPUT, BlockCategory.BASIC})
super().__init__(
id="363ae599-353e-4804-937e-b2ee3cef3da4",
description=(
"This block records the graph output. It takes a value to record, "
"with a name, description, and optional format string. If a format "
"string is given, it tries to format the recorded value. The "
"formatted (or raw, if formatting fails) value is then output. "
"This block is key for capturing and presenting final results or "
"important intermediate outputs of the graph execution."
),
input_schema=OutputBlock.Input,
output_schema=OutputBlock.Output,
test_input=[
{
"recorded_value": "Hello, World!",
"name": "output_1",
"description": "This is a test output.",
"fmt_string": "{value}",
},
{
"recorded_value": 42,
"name": "output_2",
"description": "This is another test output.",
"fmt_string": "{value}",
},
{
"recorded_value": MockObject(value="!!", key="key"),
"name": "output_3",
"description": "This is a test output with a mock object.",
"fmt_string": "{value}",
},
],
test_output=[
("output", "Hello, World!"),
("output", 42),
("output", MockObject(value="!!", key="key")),
],
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
ui_type=BlockUIType.OUTPUT,
)
def block_id(self) -> str:
return "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
class OutputBlock(InputOutputBlockBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.OUTPUT, BlockCategory.BASIC})
def block_id(self) -> str:
return "363ae599-353e-4804-937e-b2ee3cef3da4"
def run(self, input_data: Input) -> BlockOutput:
"""
Attempts to format the recorded_value using the fmt_string if provided.
If formatting fails or no fmt_string is given, returns the original recorded_value.
"""
if input_data.fmt_string:
try:
yield "output", input_data.fmt_string.format(input_data.recorded_value)
except Exception:
yield "output", input_data.recorded_value
else:
yield "output", input_data.recorded_value
class AddToDictionaryBlock(Block):
@@ -323,3 +416,24 @@ class AddToListBlock(Block):
yield "updated_list", updated_list
except Exception as e:
yield "error", f"Failed to add entry to list: {str(e)}"
class NoteBlock(Block):
class Input(BlockSchema):
text: str = SchemaField(description="The text to display in the sticky note.")
class Output(BlockSchema): ...
def __init__(self):
super().__init__(
id="31d1064e-7446-4693-o7d4-65e5ca9110d1",
description="This block is used to display a sticky note with the given text.",
categories={BlockCategory.BASIC},
input_schema=NoteBlock.Input,
output_schema=NoteBlock.Output,
test_input={"text": "Hello, World!"},
test_output=None,
ui_type=BlockUIType.NOTE,
)
def run(self, input_data: Input) -> BlockOutput: ...

View File

@@ -16,6 +16,17 @@ BlockOutput = Generator[BlockData, None, None] # Output: 1 output pin produces
CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a dict.
class BlockUIType(Enum):
"""
The type of Node UI to be displayed in the builder for this block.
"""
STANDARD = "Standard"
INPUT = "Input"
OUTPUT = "Output"
NOTE = "Note"
class BlockCategory(Enum):
AI = "Block that leverages AI to perform a task."
SOCIAL = "Block that interacts with social media platforms."
@@ -134,6 +145,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
test_mock: dict[str, Any] | None = None,
disabled: bool = False,
static_output: bool = False,
ui_type: BlockUIType = BlockUIType.STANDARD,
):
"""
Initialize the block with the given schema.
@@ -163,6 +175,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.contributors = contributors or set()
self.disabled = disabled
self.static_output = static_output
self.ui_type = ui_type
@abstractmethod
def run(self, input_data: BlockSchemaInputType) -> BlockOutput:
@@ -193,6 +206,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
contributor.model_dump() for contributor in self.contributors
],
"staticOutput": self.static_output,
"uiType": self.ui_type.value,
}
def execute(self, input_data: BlockInput) -> BlockOutput:

View File

@@ -245,7 +245,7 @@ async def upsert_execution_input(
async def upsert_execution_output(
node_exec_id: str,
output_name: str,
output_data: str, # JSON serialized data.
output_data: Any,
) -> None:
"""
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
@@ -253,7 +253,7 @@ async def upsert_execution_output(
await AgentNodeExecutionInputOutput.prisma().create(
data={
"name": output_name,
"data": output_data,
"data": json.dumps(output_data),
"referencedByOutputExecId": node_exec_id,
}
)

View File

@@ -1,6 +1,10 @@
import asyncio
import atexit
import logging
import multiprocessing
import os
import signal
import sys
import threading
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
@@ -114,9 +118,7 @@ def execute_node(
if input_data is None:
logger.error(
"Skip execution, input validation error",
extra={
"json_fields": {**log_metadata, "error": error},
},
extra={"json_fields": {**log_metadata, "error": error}},
)
return
@@ -132,13 +134,12 @@ def execute_node(
output_size = 0
try:
for output_name, output_data in node_block.execute(input_data):
output_data_str = json.dumps(output_data)
output_size += len(output_data_str)
output_size += len(json.dumps(output_data))
logger.info(
"Node produced output",
extra={"json_fields": {**log_metadata, output_name: output_data_str}},
extra={"json_fields": {**log_metadata, output_name: output_data}},
)
wait(upsert_execution_output(node_exec_id, output_name, output_data_str))
wait(upsert_execution_output(node_exec_id, output_name, output_data))
for execution in _enqueue_next_nodes(
api_client=api_client,
@@ -254,22 +255,14 @@ def _enqueue_next_nodes(
if not next_node_input:
logger.warning(
f"Skipped queueing {suffix}",
extra={
"json_fields": {
**log_metadata,
}
},
extra={"json_fields": {**log_metadata}},
)
return enqueued_executions
# Input is complete, enqueue the execution.
logger.info(
f"Enqueued {suffix}",
extra={
"json_fields": {
**log_metadata,
}
},
extra={"json_fields": {**log_metadata}},
)
enqueued_executions.append(
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
@@ -402,29 +395,62 @@ class Executor:
@classmethod
def on_node_executor_start(cls):
configure_logging()
cls.logger = logging.getLogger("node_executor")
cls.loop = asyncio.new_event_loop()
cls.pid = os.getpid()
cls.loop.run_until_complete(db.connect())
cls.agent_server_client = get_agent_server_client()
# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
atexit.register(cls.on_node_executor_stop) # handle regular shutdown
signal.signal( # handle termination
signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
)
@classmethod
def on_node_executor_stop(cls):
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
cls.loop.run_until_complete(db.disconnect())
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
@classmethod
def on_node_executor_sigterm(cls):
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down, no need to self-terminate
llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting DB...")
cls.loop.run_until_complete(db.disconnect())
llprint(f"[on_node_executor_sigterm {cls.pid}] ✅ Finished cleanup")
sys.exit(0)
@classmethod
@error_logged
def on_node_execution(cls, q: ExecutionQueue[NodeExecution], data: NodeExecution):
def on_node_execution(
cls, q: ExecutionQueue[NodeExecution], node_exec: NodeExecution
):
log_metadata = get_log_metadata(
graph_eid=data.graph_exec_id,
graph_id=data.graph_id,
node_eid=data.node_exec_id,
node_id=data.node_id,
graph_eid=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_eid=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_name="-",
)
execution_stats = {}
timing_info, _ = cls._on_node_execution(q, data, log_metadata, execution_stats)
timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, execution_stats
)
execution_stats["walltime"] = timing_info.wall_time
execution_stats["cputime"] = timing_info.cpu_time
cls.loop.run_until_complete(
update_node_execution_stats(data.node_exec_id, execution_stats)
update_node_execution_stats(node_exec.node_exec_id, execution_stats)
)
@classmethod
@@ -432,32 +458,26 @@ class Executor:
def _on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
d: NodeExecution,
node_exec: NodeExecution,
log_metadata: dict,
stats: dict[str, Any] | None = None,
):
try:
cls.logger.info(
"Start node execution",
extra={
"json_fields": {
**log_metadata,
}
},
logger.info(
f"Start node execution {node_exec.node_exec_id}",
extra={"json_fields": {**log_metadata}},
)
for execution in execute_node(cls.loop, cls.agent_server_client, d, stats):
for execution in execute_node(
cls.loop, cls.agent_server_client, node_exec, stats
):
q.add(execution)
cls.logger.info(
"Finished node execution",
extra={
"json_fields": {
**log_metadata,
}
},
logger.info(
f"Finished node execution {node_exec.node_exec_id}",
extra={"json_fields": {**log_metadata}},
)
except Exception as e:
cls.logger.exception(
f"Failed node execution: {e}",
logger.exception(
f"Failed node execution {node_exec.node_exec_id}: {e}",
extra={
**log_metadata,
},
@@ -466,12 +486,26 @@ class Executor:
@classmethod
def on_graph_executor_start(cls):
configure_logging()
cls.logger = logging.getLogger("graph_executor")
cls.loop = asyncio.new_event_loop()
cls.loop.run_until_complete(db.connect())
cls.pool_size = Config().num_node_workers
cls.loop = asyncio.new_event_loop()
cls.pid = os.getpid()
cls.loop.run_until_complete(db.connect())
cls._init_node_executor_pool()
logger.info(f"Graph executor started with max-{cls.pool_size} node workers.")
logger.info(
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
)
# Set up shutdown handler
atexit.register(cls.on_graph_executor_stop)
@classmethod
def on_graph_executor_stop(cls):
logger.info(
f"[on_graph_executor_stop {cls.pid}] ⏳ Terminating node executor pool..."
)
cls.executor.terminate()
@classmethod
def _init_node_executor_pool(cls):
@@ -482,19 +516,21 @@ class Executor:
@classmethod
@error_logged
def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event):
def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
log_metadata = get_log_metadata(
graph_eid=data.graph_exec_id,
graph_id=data.graph_id,
graph_eid=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
node_id="*",
node_eid="*",
block_name="-",
)
timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata)
timing_info, node_count = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
cls.loop.run_until_complete(
update_graph_execution_stats(
data.graph_exec_id,
graph_exec.graph_exec_id,
{
"walltime": timing_info.wall_time,
"cputime": timing_info.cpu_time,
@@ -506,15 +542,11 @@ class Executor:
@classmethod
@time_measured
def _on_graph_execution(
cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict
cls, graph_exec: GraphExecution, cancel: threading.Event, log_metadata: dict
) -> int:
cls.logger.info(
"Start graph execution",
extra={
"json_fields": {
**log_metadata,
}
},
logger.info(
f"Start graph execution {graph_exec.graph_exec_id}",
extra={"json_fields": {**log_metadata}},
)
n_node_executions = 0
finished = False
@@ -526,7 +558,7 @@ class Executor:
return
cls.executor.terminate()
logger.info(
f"Terminated graph execution {graph_data.graph_exec_id}",
f"Terminated graph execution {graph_exec.graph_exec_id}",
extra={"json_fields": {**log_metadata}},
)
cls._init_node_executor_pool()
@@ -536,7 +568,7 @@ class Executor:
try:
queue = ExecutionQueue[NodeExecution]()
for node_exec in graph_data.start_node_execs:
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
running_executions: dict[str, AsyncResult] = {}
@@ -566,7 +598,11 @@ class Executor:
# Re-enqueueing the data back to the queue will disrupt the order.
execution.wait()
logger.debug(f"Dispatching execution of node {exec_data.node_id}")
logger.debug(
f"Dispatching node execution {exec_data.node_exec_id} "
f"for node {exec_data.node_id}",
extra={**log_metadata},
)
running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
@@ -577,41 +613,30 @@ class Executor:
while queue.empty() and running_executions:
logger.debug(
"Queue empty; running nodes: "
f"{list(running_executions.keys())}"
f"{list(running_executions.keys())}",
extra={"json_fields": {**log_metadata}},
)
for node_id, execution in list(running_executions.items()):
if cancel.is_set():
return n_node_executions
if not queue.empty():
logger.debug(
"Queue no longer empty! Returning to dispatching loop."
)
break # yield to parent loop to execute new queue items
logger.debug(f"Waiting on execution of node {node_id}")
execution.wait(3)
logger.debug(
f"State of execution of node {node_id} after waiting: "
f"{'DONE' if execution.ready() else 'RUNNING'}"
f"Waiting on execution of node {node_id}",
extra={"json_fields": {**log_metadata}},
)
execution.wait(3)
cls.logger.info(
"Finished graph execution",
extra={
"json_fields": {
**log_metadata,
}
},
logger.info(
f"Finished graph execution {graph_exec.graph_exec_id}",
extra={"json_fields": {**log_metadata}},
)
except Exception as e:
logger.exception(
f"Failed graph execution: {e}",
extra={
"json_fields": {
**log_metadata,
}
},
f"Failed graph execution {graph_exec.graph_exec_id}: {e}",
extra={"json_fields": {**log_metadata}},
)
finally:
if not cancel.is_set():
@@ -629,29 +654,35 @@ class ExecutionManager(AppService):
self.queue = ExecutionQueue[GraphExecution]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
# def __del__(self):
# self.sync_manager.shutdown()
def run_service(self):
with ProcessPoolExecutor(
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
initializer=Executor.on_graph_executor_start,
) as executor:
sync_manager = multiprocessing.Manager()
logger.info(
f"Execution manager started with max-{self.pool_size} graph workers."
)
sync_manager = multiprocessing.Manager()
logger.info(
f"[{self.service_name}] Started with max-{self.pool_size} graph workers"
)
while True:
graph_exec_data = self.queue.get()
graph_exec_id = graph_exec_data.graph_exec_id
logger.debug(
f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
)
while True:
graph_exec_data = self.queue.get()
graph_exec_id = graph_exec_data.graph_exec_id
cancel_event = sync_manager.Event()
future = executor.submit(
Executor.on_graph_execution, graph_exec_data, cancel_event
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
future.add_done_callback(
lambda _: self.active_graph_runs.pop(graph_exec_id)
)
cancel_event = sync_manager.Event()
future = self.executor.submit(
Executor.on_graph_execution, graph_exec_data, cancel_event
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
future.add_done_callback(
lambda _: self.active_graph_runs.pop(graph_exec_id)
)
def cleanup(self):
logger.info(f"[{__class__.__name__}] ⏳ Shutting down graph executor pool...")
self.executor.shutdown(cancel_futures=True)
super().cleanup()
@property
def agent_server_client(self) -> "AgentServer":
@@ -755,3 +786,12 @@ class ExecutionManager(AppService):
)
)
self.agent_server_client.send_execution_update(exec_update.model_dump())
def llprint(message: str):
"""
Low-level print/log helper function for use in signal handlers.
Regular log/print statements are not allowed in signal handlers.
"""
if logger.getEffectiveLevel() == logging.DEBUG:
os.write(sys.stdout.fileno(), (message + "\n").encode())

View File

@@ -201,7 +201,7 @@ class AgentServer(AppService):
app.include_router(router)
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000, log_config=None)
def set_test_dependency_overrides(self, overrides: dict):
self._test_dependency_overrides = overrides

View File

@@ -13,6 +13,7 @@ from autogpt_server.server.model import ExecutionSubscription, Methods, WsMessag
from autogpt_server.util.service import AppProcess
from autogpt_server.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
app = FastAPI()
@@ -93,7 +94,7 @@ async def handle_subscribe(
else:
ex_sub = ExecutionSubscription.model_validate(message.data)
await manager.subscribe(ex_sub.graph_id, websocket)
print("subscribed")
logger.debug(f"New execution subscription for graph {ex_sub.graph_id}")
await websocket.send_text(
WsMessage(
method=Methods.SUBSCRIBE,
@@ -117,7 +118,7 @@ async def handle_unsubscribe(
else:
ex_sub = ExecutionSubscription.model_validate(message.data)
await manager.unsubscribe(ex_sub.graph_id, websocket)
print("unsubscribed")
logger.debug(f"Removed execution subscription for graph {ex_sub.graph_id}")
await websocket.send_text(
WsMessage(
method=Methods.UNSUBSCRIBE,
@@ -151,11 +152,12 @@ async def websocket_router(
await handle_unsubscribe(websocket, manager, message)
elif message.method == Methods.ERROR:
logging.error("WebSocket Error message received:", message.data)
logger.error(f"WebSocket Error message received: {message.data}")
else:
logging.info(
f"Message type {message.method} is not processed by the server"
logger.warning(
f"Unknown WebSocket message type {message.method} received: "
f"{message.data}"
)
await websocket.send_text(
WsMessage(
@@ -167,7 +169,7 @@ async def websocket_router(
except WebSocketDisconnect:
manager.disconnect(websocket)
logging.info("Client Disconnected")
logger.debug("WebSocket client disconnected")
class WebsocketServer(AppProcess):

View File

@@ -20,20 +20,30 @@ async def create_test_user() -> User:
def create_test_graph() -> graph.Graph:
"""
StoreValueBlock
InputBlock
\
---- FillTextTemplateBlock ---- PrintToConsoleBlock
/
StoreValueBlock
InputBlock
"""
nodes = [
graph.Node(
block_id=InputBlock().id,
input_default={"name": "input_1"},
input_default={
"name": "input_1",
"description": "First input value",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
),
graph.Node(
block_id=InputBlock().id,
input_default={"name": "input_2"},
input_default={
"name": "input_2",
"description": "Second input value",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
),
graph.Node(
block_id=FillTextTemplateBlock().id,

View File

@@ -1,4 +1,6 @@
import logging
import os
import signal
import sys
from abc import ABC, abstractmethod
from multiprocessing import Process, set_start_method
@@ -7,6 +9,8 @@ from typing import Optional
from autogpt_server.util.logging import configure_logging
from autogpt_server.util.metrics import sentry_init
logger = logging.getLogger(__name__)
class AppProcess(ABC):
"""
@@ -19,6 +23,8 @@ class AppProcess(ABC):
configure_logging()
sentry_init()
# Methods that are executed INSIDE the process #
@abstractmethod
def run(self):
"""
@@ -26,6 +32,13 @@ class AppProcess(ABC):
"""
pass
def cleanup(self):
"""
Implement this method on a subclass to do post-execution cleanup,
e.g. disconnecting from a database or terminating child processes.
"""
pass
def health_check(self):
"""
A method to check the health of the process.
@@ -33,13 +46,22 @@ class AppProcess(ABC):
pass
def execute_run_command(self, silent):
signal.signal(signal.SIGTERM, self._self_terminate)
try:
if silent:
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
logger.info(f"[{self.__class__.__name__}] Starting...")
self.run()
except KeyboardInterrupt or SystemExit as e:
print(f"Process terminated: {e}")
except (KeyboardInterrupt, SystemExit) as e:
logger.warning(f"[{self.__class__.__name__}] Terminated: {e}; quitting...")
def _self_terminate(self, signum: int, frame):
self.cleanup()
sys.exit(0)
# Methods that are executed OUTSIDE the process #
def __enter__(self):
self.start(background=True)

View File

@@ -46,19 +46,15 @@ def expose(func: C) -> C:
class PyroNameServer(AppProcess):
def run(self):
try:
print("Starting NameServer loop")
nameserver.start_ns_loop(host=pyro_host, port=9090)
except KeyboardInterrupt:
print("Shutting down NameServer")
nameserver.start_ns_loop(host=pyro_host, port=9090)
@conn_retry
def _wait_for_ns(self):
pyro.locate_ns(host="localhost", port=9090)
print("NameServer is ready")
def health_check(self):
self._wait_for_ns()
logger.info(f"{__class__.__name__} is ready")
class AppService(AppProcess):
@@ -108,6 +104,14 @@ class AppService(AppProcess):
# Run the main service (if it's not implemented, just sleep).
self.run_service()
def cleanup(self):
if self.use_db:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
self.run_and_wait(db.disconnect())
if self.use_redis:
logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
self.run_and_wait(self.event_queue.close())
@conn_retry
def __start_pyro(self):
host = Config().pyro_host

View File

@@ -73,10 +73,10 @@ class SpinTestServer:
async def __aexit__(self, exc_type, exc_val, exc_tb):
await db.disconnect()
self.name_server.__exit__(exc_type, exc_val, exc_tb)
self.agent_server.__exit__(exc_type, exc_val, exc_tb)
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
self.scheduler.__exit__(exc_type, exc_val, exc_tb)
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
self.agent_server.__exit__(exc_type, exc_val, exc_tb)
self.name_server.__exit__(exc_type, exc_val, exc_tb)
def setup_dependency_overrides(self):
# Override get_user_id for testing

View File

@@ -41,8 +41,20 @@ async def assert_sample_graph_executions(
output_list = [{"result": ["Hello"]}, {"result": ["World"]}]
input_list = [
{"value": "Hello", "name": "input_1"},
{"value": "World", "name": "input_2"},
{
"name": "input_1",
"description": "First input value",
"placeholder_values": [],
"limit_to_placeholder_values": False,
"value": "Hello",
},
{
"name": "input_2",
"description": "Second input value",
"placeholder_values": [],
"limit_to_placeholder_values": False,
"value": "World",
},
]
# Executing StoreValueBlock