mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'master' into aarushikansal/execution-manager
This commit is contained in:
@@ -58,13 +58,21 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes
|
||||
6. Migrate the database. Be careful because this deletes current data in the database.
|
||||
|
||||
```sh
|
||||
docker compose up postgres -d
|
||||
docker compose up postgres redis -d
|
||||
poetry run prisma migrate dev
|
||||
```
|
||||
|
||||
## Running The Server
|
||||
|
||||
### Starting the server directly
|
||||
### Starting the server without Docker
|
||||
|
||||
Run the following command to start the server:
|
||||
|
||||
```sh
|
||||
poetry run app
|
||||
```
|
||||
|
||||
### Starting the server with Docker
|
||||
|
||||
Run the following command to build the dockerfiles:
|
||||
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, List, TypeVar
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from autogpt_server.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockUIType,
|
||||
)
|
||||
from autogpt_server.data.model import SchemaField
|
||||
from autogpt_server.util.mock import MockObject
|
||||
|
||||
@@ -131,63 +136,151 @@ class FindInDictionaryBlock(Block):
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
class InputBlock(Block):
|
||||
"""
|
||||
This block is used to provide input to the graph.
|
||||
|
||||
It takes in a value, name, description, default values list and bool to limit selection to default values.
|
||||
|
||||
class InputOutputBlockInput(BlockSchema, Generic[T]):
|
||||
value: T = Field(description="The value to be passed as input/output.")
|
||||
name: str = Field(description="The name of the input/output.")
|
||||
It Outputs the value passed as input.
|
||||
"""
|
||||
|
||||
|
||||
class InputOutputBlockOutput(BlockSchema, Generic[T]):
|
||||
result: T = Field(description="The value passed as input/output.")
|
||||
|
||||
|
||||
class InputOutputBlockBase(Block, ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def block_id(self) -> str:
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
input_schema = InputOutputBlockInput[T]
|
||||
output_schema = InputOutputBlockOutput[T]
|
||||
|
||||
super().__init__(
|
||||
id=self.block_id(),
|
||||
description="This block is used to define the input & output of a graph.",
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
test_input=[
|
||||
{"value": {"apple": 1, "banana": 2, "cherry": 3}, "name": "input_1"},
|
||||
{"value": MockObject(value="!!", key="key"), "name": "input_2"},
|
||||
],
|
||||
test_output=[
|
||||
("result", {"apple": 1, "banana": 2, "cherry": 3}),
|
||||
("result", MockObject(value="!!", key="key")),
|
||||
],
|
||||
static_output=True,
|
||||
*args,
|
||||
**kwargs,
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(description="The value to be passed as input.")
|
||||
name: str = SchemaField(description="The name of the input.")
|
||||
description: str = SchemaField(description="The description of the input.")
|
||||
placeholder_values: List[Any] = SchemaField(
|
||||
description="The placeholder values to be passed as input."
|
||||
)
|
||||
limit_to_placeholder_values: bool = SchemaField(
|
||||
description="Whether to limit the selection to placeholder values.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
def run(self, input_data: InputOutputBlockInput[T]) -> BlockOutput:
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
description="This block is used to provide input to the graph.",
|
||||
input_schema=InputBlock.Input,
|
||||
output_schema=InputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_1",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
{
|
||||
"value": "Hello, World!",
|
||||
"name": "input_2",
|
||||
"description": "This is a test input.",
|
||||
"placeholder_values": ["Hello, World!"],
|
||||
"limit_to_placeholder_values": True,
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("result", "Hello, World!"),
|
||||
("result", "Hello, World!"),
|
||||
],
|
||||
categories={BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
ui_type=BlockUIType.INPUT,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
class InputBlock(InputOutputBlockBase[Any]):
|
||||
class OutputBlock(Block):
|
||||
"""
|
||||
Records the output of the graph for users to see.
|
||||
|
||||
Attributes:
|
||||
recorded_value: The value to be recorded as output.
|
||||
name: The name of the output.
|
||||
description: The description of the output.
|
||||
fmt_string: The format string to be used to format the recorded_value.
|
||||
|
||||
Outputs:
|
||||
output: The formatted recorded_value if fmt_string is provided and the recorded_value
|
||||
can be formatted, otherwise the raw recorded_value.
|
||||
|
||||
Behavior:
|
||||
If fmt_string is provided and the recorded_value is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the fmt_string.
|
||||
If formatting fails or no fmt_string is provided, the raw recorded_value is output.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
recorded_value: Any = SchemaField(
|
||||
description="The value to be recorded as output."
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
description: str = SchemaField(description="The description of the output.")
|
||||
fmt_string: str = SchemaField(
|
||||
description="The format string to be used to format the recorded_value."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(categories={BlockCategory.INPUT, BlockCategory.BASIC})
|
||||
super().__init__(
|
||||
id="363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
description=(
|
||||
"This block records the graph output. It takes a value to record, "
|
||||
"with a name, description, and optional format string. If a format "
|
||||
"string is given, it tries to format the recorded value. The "
|
||||
"formatted (or raw, if formatting fails) value is then output. "
|
||||
"This block is key for capturing and presenting final results or "
|
||||
"important intermediate outputs of the graph execution."
|
||||
),
|
||||
input_schema=OutputBlock.Input,
|
||||
output_schema=OutputBlock.Output,
|
||||
test_input=[
|
||||
{
|
||||
"recorded_value": "Hello, World!",
|
||||
"name": "output_1",
|
||||
"description": "This is a test output.",
|
||||
"fmt_string": "{value}",
|
||||
},
|
||||
{
|
||||
"recorded_value": 42,
|
||||
"name": "output_2",
|
||||
"description": "This is another test output.",
|
||||
"fmt_string": "{value}",
|
||||
},
|
||||
{
|
||||
"recorded_value": MockObject(value="!!", key="key"),
|
||||
"name": "output_3",
|
||||
"description": "This is a test output with a mock object.",
|
||||
"fmt_string": "{value}",
|
||||
},
|
||||
],
|
||||
test_output=[
|
||||
("output", "Hello, World!"),
|
||||
("output", 42),
|
||||
("output", MockObject(value="!!", key="key")),
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
ui_type=BlockUIType.OUTPUT,
|
||||
)
|
||||
|
||||
def block_id(self) -> str:
|
||||
return "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
|
||||
|
||||
class OutputBlock(InputOutputBlockBase[Any]):
|
||||
def __init__(self):
|
||||
super().__init__(categories={BlockCategory.OUTPUT, BlockCategory.BASIC})
|
||||
|
||||
def block_id(self) -> str:
|
||||
return "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
def run(self, input_data: Input) -> BlockOutput:
    """
    Yield the recorded value on the "output" pin, formatted when possible.

    The value is passed to `fmt_string.format(...)` as the single positional
    argument. When `fmt_string` is empty, or formatting raises, the raw
    recorded_value is yielded unchanged.
    """
    # NOTE(review): the value is supplied positionally, so only "{}" / "{0}"
    # style templates can succeed. A named field such as "{value}" raises
    # KeyError and therefore always falls back to the raw value.
    result = input_data.recorded_value
    if input_data.fmt_string:
        try:
            result = input_data.fmt_string.format(input_data.recorded_value)
        except Exception:
            # Formatting is best-effort; keep the raw value.
            result = input_data.recorded_value
    yield "output", result
|
||||
|
||||
|
||||
class AddToDictionaryBlock(Block):
|
||||
@@ -323,3 +416,24 @@ class AddToListBlock(Block):
|
||||
yield "updated_list", updated_list
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to add entry to list: {str(e)}"
|
||||
|
||||
|
||||
class NoteBlock(Block):
    """
    Block that displays a sticky note with the given text in the builder.

    It takes the note text as its only input, declares no output pins and
    produces no output when run.
    """

    class Input(BlockSchema):
        text: str = SchemaField(description="The text to display in the sticky note.")

    # Deliberately empty: a note exposes no output pins.
    class Output(BlockSchema): ...

    def __init__(self):
        super().__init__(
            # NOTE(review): this id contains the letter "o", so it is not a
            # well-formed hexadecimal UUID. It is left untouched because block
            # ids look like persisted identifiers; confirm before changing.
            id="31d1064e-7446-4693-o7d4-65e5ca9110d1",
            description="This block is used to display a sticky note with the given text.",
            categories={BlockCategory.BASIC},
            input_schema=NoteBlock.Input,
            output_schema=NoteBlock.Output,
            test_input={"text": "Hello, World!"},
            test_output=None,
            ui_type=BlockUIType.NOTE,
        )

    def run(self, input_data: Input) -> BlockOutput:
        """
        Produce no output; the note only carries text to display.

        Returns:
            An empty generator, so a caller that iterates over the result (as
            the `BlockOutput` generator annotation promises) sees zero items
            instead of failing on `None`.
        """
        # A bare `return` ahead of an unreachable `yield` makes this function
        # a generator that finishes immediately.
        return
        yield
|
||||
|
||||
@@ -16,6 +16,17 @@ BlockOutput = Generator[BlockData, None, None] # Output: 1 output pin produces
|
||||
CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a dict.
|
||||
|
||||
|
||||
class BlockUIType(Enum):
    """
    The type of Node UI to be displayed in the builder for this block.

    The member's string value is what a block reports under its "uiType" key.
    """

    STANDARD = "Standard"  # default `ui_type` of Block.__init__
    INPUT = "Input"  # passed by InputBlock
    OUTPUT = "Output"  # passed by OutputBlock
    NOTE = "Note"  # passed by NoteBlock
|
||||
|
||||
|
||||
class BlockCategory(Enum):
|
||||
AI = "Block that leverages AI to perform a task."
|
||||
SOCIAL = "Block that interacts with social media platforms."
|
||||
@@ -134,6 +145,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
test_mock: dict[str, Any] | None = None,
|
||||
disabled: bool = False,
|
||||
static_output: bool = False,
|
||||
ui_type: BlockUIType = BlockUIType.STANDARD,
|
||||
):
|
||||
"""
|
||||
Initialize the block with the given schema.
|
||||
@@ -163,6 +175,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.contributors = contributors or set()
|
||||
self.disabled = disabled
|
||||
self.static_output = static_output
|
||||
self.ui_type = ui_type
|
||||
|
||||
@abstractmethod
|
||||
def run(self, input_data: BlockSchemaInputType) -> BlockOutput:
|
||||
@@ -193,6 +206,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
contributor.model_dump() for contributor in self.contributors
|
||||
],
|
||||
"staticOutput": self.static_output,
|
||||
"uiType": self.ui_type.value,
|
||||
}
|
||||
|
||||
def execute(self, input_data: BlockInput) -> BlockOutput:
|
||||
|
||||
@@ -245,7 +245,7 @@ async def upsert_execution_input(
|
||||
async def upsert_execution_output(
|
||||
node_exec_id: str,
|
||||
output_name: str,
|
||||
output_data: str, # JSON serialized data.
|
||||
output_data: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Insert an AgentNodeExecutionInputOutput record as one of AgentNodeExecution.Output.
|
||||
@@ -253,7 +253,7 @@ async def upsert_execution_output(
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data={
|
||||
"name": output_name,
|
||||
"data": output_data,
|
||||
"data": json.dumps(output_data),
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
@@ -114,9 +118,7 @@ def execute_node(
|
||||
if input_data is None:
|
||||
logger.error(
|
||||
"Skip execution, input validation error",
|
||||
extra={
|
||||
"json_fields": {**log_metadata, "error": error},
|
||||
},
|
||||
extra={"json_fields": {**log_metadata, "error": error}},
|
||||
)
|
||||
return
|
||||
|
||||
@@ -132,13 +134,12 @@ def execute_node(
|
||||
output_size = 0
|
||||
try:
|
||||
for output_name, output_data in node_block.execute(input_data):
|
||||
output_data_str = json.dumps(output_data)
|
||||
output_size += len(output_data_str)
|
||||
output_size += len(json.dumps(output_data))
|
||||
logger.info(
|
||||
"Node produced output",
|
||||
extra={"json_fields": {**log_metadata, output_name: output_data_str}},
|
||||
extra={"json_fields": {**log_metadata, output_name: output_data}},
|
||||
)
|
||||
wait(upsert_execution_output(node_exec_id, output_name, output_data_str))
|
||||
wait(upsert_execution_output(node_exec_id, output_name, output_data))
|
||||
|
||||
for execution in _enqueue_next_nodes(
|
||||
api_client=api_client,
|
||||
@@ -254,22 +255,14 @@ def _enqueue_next_nodes(
|
||||
if not next_node_input:
|
||||
logger.warning(
|
||||
f"Skipped queueing {suffix}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_metadata,
|
||||
}
|
||||
},
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
return enqueued_executions
|
||||
|
||||
# Input is complete, enqueue the execution.
|
||||
logger.info(
|
||||
f"Enqueued {suffix}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_metadata,
|
||||
}
|
||||
},
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
enqueued_executions.append(
|
||||
add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
|
||||
@@ -402,29 +395,62 @@ class Executor:
|
||||
@classmethod
|
||||
def on_node_executor_start(cls):
|
||||
configure_logging()
|
||||
cls.logger = logging.getLogger("node_executor")
|
||||
|
||||
cls.loop = asyncio.new_event_loop()
|
||||
cls.pid = os.getpid()
|
||||
|
||||
cls.loop.run_until_complete(db.connect())
|
||||
cls.agent_server_client = get_agent_server_client()
|
||||
|
||||
# Set up shutdown handlers
|
||||
cls.shutdown_lock = threading.Lock()
|
||||
atexit.register(cls.on_node_executor_stop) # handle regular shutdown
|
||||
signal.signal( # handle termination
|
||||
signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
|
||||
)
|
||||
|
||||
@classmethod
def on_node_executor_stop(cls):
    """
    Disconnect this node executor process from the database on shutdown.

    The non-blocking acquire of `cls.shutdown_lock` works as a one-shot guard:
    the first shutdown path to take the lock does the cleanup. The lock is
    never released here, so any later call returns without doing anything.
    """
    won_shutdown = cls.shutdown_lock.acquire(blocking=False)
    if won_shutdown:
        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
        cls.loop.run_until_complete(db.disconnect())
        logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
    # Otherwise another shutdown path already holds the lock.
|
||||
|
||||
@classmethod
def on_node_executor_sigterm(cls):
    """
    Handle SIGTERM in a node executor: disconnect the DB, then exit.

    Progress is reported through `llprint` instead of the logger because this
    runs inside a signal handler. If `cls.shutdown_lock` is already held, a
    shutdown is in progress elsewhere and this call returns without exiting.

    Raises:
        SystemExit: with code 0, once this call has completed the cleanup.
    """
    llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")

    if cls.shutdown_lock.acquire(blocking=False):
        llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting DB...")
        cls.loop.run_until_complete(db.disconnect())
        llprint(f"[on_node_executor_sigterm {cls.pid}] ✅ Finished cleanup")
        sys.exit(0)
    # Lock not obtained: already shutting down, no need to self-terminate.
|
||||
|
||||
@classmethod
|
||||
@error_logged
|
||||
def on_node_execution(cls, q: ExecutionQueue[NodeExecution], data: NodeExecution):
|
||||
def on_node_execution(
|
||||
cls, q: ExecutionQueue[NodeExecution], node_exec: NodeExecution
|
||||
):
|
||||
log_metadata = get_log_metadata(
|
||||
graph_eid=data.graph_exec_id,
|
||||
graph_id=data.graph_id,
|
||||
node_eid=data.node_exec_id,
|
||||
node_id=data.node_id,
|
||||
graph_eid=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_eid=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_name="-",
|
||||
)
|
||||
|
||||
execution_stats = {}
|
||||
timing_info, _ = cls._on_node_execution(q, data, log_metadata, execution_stats)
|
||||
timing_info, _ = cls._on_node_execution(
|
||||
q, node_exec, log_metadata, execution_stats
|
||||
)
|
||||
execution_stats["walltime"] = timing_info.wall_time
|
||||
execution_stats["cputime"] = timing_info.cpu_time
|
||||
|
||||
cls.loop.run_until_complete(
|
||||
update_node_execution_stats(data.node_exec_id, execution_stats)
|
||||
update_node_execution_stats(node_exec.node_exec_id, execution_stats)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -432,32 +458,26 @@ class Executor:
|
||||
def _on_node_execution(
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
d: NodeExecution,
|
||||
node_exec: NodeExecution,
|
||||
log_metadata: dict,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
try:
|
||||
cls.logger.info(
|
||||
"Start node execution",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_metadata,
|
||||
}
|
||||
},
|
||||
logger.info(
|
||||
f"Start node execution {node_exec.node_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
for execution in execute_node(cls.loop, cls.agent_server_client, d, stats):
|
||||
for execution in execute_node(
|
||||
cls.loop, cls.agent_server_client, node_exec, stats
|
||||
):
|
||||
q.add(execution)
|
||||
cls.logger.info(
|
||||
"Finished node execution",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_metadata,
|
||||
}
|
||||
},
|
||||
logger.info(
|
||||
f"Finished node execution {node_exec.node_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
except Exception as e:
|
||||
cls.logger.exception(
|
||||
f"Failed node execution: {e}",
|
||||
logger.exception(
|
||||
f"Failed node execution {node_exec.node_exec_id}: {e}",
|
||||
extra={
|
||||
**log_metadata,
|
||||
},
|
||||
@@ -466,12 +486,26 @@ class Executor:
|
||||
@classmethod
|
||||
def on_graph_executor_start(cls):
|
||||
configure_logging()
|
||||
cls.logger = logging.getLogger("graph_executor")
|
||||
cls.loop = asyncio.new_event_loop()
|
||||
cls.loop.run_until_complete(db.connect())
|
||||
|
||||
cls.pool_size = Config().num_node_workers
|
||||
cls.loop = asyncio.new_event_loop()
|
||||
cls.pid = os.getpid()
|
||||
|
||||
cls.loop.run_until_complete(db.connect())
|
||||
cls._init_node_executor_pool()
|
||||
logger.info(f"Graph executor started with max-{cls.pool_size} node workers.")
|
||||
logger.info(
|
||||
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
|
||||
)
|
||||
|
||||
# Set up shutdown handler
|
||||
atexit.register(cls.on_graph_executor_stop)
|
||||
|
||||
@classmethod
def on_graph_executor_stop(cls):
    """
    Terminate the node executor pool owned by this graph executor process.

    Logs the action, tagged with this process's pid, and then calls
    `terminate()` on `cls.executor`.
    """
    notice = (
        f"[on_graph_executor_stop {cls.pid}] ⏳ Terminating node executor pool..."
    )
    logger.info(notice)
    cls.executor.terminate()
|
||||
|
||||
@classmethod
|
||||
def _init_node_executor_pool(cls):
|
||||
@@ -482,19 +516,21 @@ class Executor:
|
||||
|
||||
@classmethod
|
||||
@error_logged
|
||||
def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event):
|
||||
def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
|
||||
log_metadata = get_log_metadata(
|
||||
graph_eid=data.graph_exec_id,
|
||||
graph_id=data.graph_id,
|
||||
graph_eid=graph_exec.graph_exec_id,
|
||||
graph_id=graph_exec.graph_id,
|
||||
node_id="*",
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata)
|
||||
timing_info, node_count = cls._on_graph_execution(
|
||||
graph_exec, cancel, log_metadata
|
||||
)
|
||||
|
||||
cls.loop.run_until_complete(
|
||||
update_graph_execution_stats(
|
||||
data.graph_exec_id,
|
||||
graph_exec.graph_exec_id,
|
||||
{
|
||||
"walltime": timing_info.wall_time,
|
||||
"cputime": timing_info.cpu_time,
|
||||
@@ -506,15 +542,11 @@ class Executor:
|
||||
@classmethod
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict
|
||||
cls, graph_exec: GraphExecution, cancel: threading.Event, log_metadata: dict
|
||||
) -> int:
|
||||
cls.logger.info(
|
||||
"Start graph execution",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_metadata,
|
||||
}
|
||||
},
|
||||
logger.info(
|
||||
f"Start graph execution {graph_exec.graph_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
n_node_executions = 0
|
||||
finished = False
|
||||
@@ -526,7 +558,7 @@ class Executor:
|
||||
return
|
||||
cls.executor.terminate()
|
||||
logger.info(
|
||||
f"Terminated graph execution {graph_data.graph_exec_id}",
|
||||
f"Terminated graph execution {graph_exec.graph_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
cls._init_node_executor_pool()
|
||||
@@ -536,7 +568,7 @@ class Executor:
|
||||
|
||||
try:
|
||||
queue = ExecutionQueue[NodeExecution]()
|
||||
for node_exec in graph_data.start_node_execs:
|
||||
for node_exec in graph_exec.start_node_execs:
|
||||
queue.add(node_exec)
|
||||
|
||||
running_executions: dict[str, AsyncResult] = {}
|
||||
@@ -566,7 +598,11 @@ class Executor:
|
||||
# Re-enqueueing the data back to the queue will disrupt the order.
|
||||
execution.wait()
|
||||
|
||||
logger.debug(f"Dispatching execution of node {exec_data.node_id}")
|
||||
logger.debug(
|
||||
f"Dispatching node execution {exec_data.node_exec_id} "
|
||||
f"for node {exec_data.node_id}",
|
||||
extra={**log_metadata},
|
||||
)
|
||||
running_executions[exec_data.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(queue, exec_data),
|
||||
@@ -577,41 +613,30 @@ class Executor:
|
||||
while queue.empty() and running_executions:
|
||||
logger.debug(
|
||||
"Queue empty; running nodes: "
|
||||
f"{list(running_executions.keys())}"
|
||||
f"{list(running_executions.keys())}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
for node_id, execution in list(running_executions.items()):
|
||||
if cancel.is_set():
|
||||
return n_node_executions
|
||||
|
||||
if not queue.empty():
|
||||
logger.debug(
|
||||
"Queue no longer empty! Returning to dispatching loop."
|
||||
)
|
||||
break # yield to parent loop to execute new queue items
|
||||
|
||||
logger.debug(f"Waiting on execution of node {node_id}")
|
||||
execution.wait(3)
|
||||
logger.debug(
|
||||
f"State of execution of node {node_id} after waiting: "
|
||||
f"{'DONE' if execution.ready() else 'RUNNING'}"
|
||||
f"Waiting on execution of node {node_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
execution.wait(3)
|
||||
|
||||
cls.logger.info(
|
||||
"Finished graph execution",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_metadata,
|
||||
}
|
||||
},
|
||||
logger.info(
|
||||
f"Finished graph execution {graph_exec.graph_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed graph execution: {e}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
**log_metadata,
|
||||
}
|
||||
},
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {e}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
finally:
|
||||
if not cancel.is_set():
|
||||
@@ -629,29 +654,35 @@ class ExecutionManager(AppService):
|
||||
self.queue = ExecutionQueue[GraphExecution]()
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
|
||||
# def __del__(self):
|
||||
# self.sync_manager.shutdown()
|
||||
|
||||
def run_service(self):
|
||||
with ProcessPoolExecutor(
|
||||
self.executor = ProcessPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=Executor.on_graph_executor_start,
|
||||
) as executor:
|
||||
sync_manager = multiprocessing.Manager()
|
||||
logger.info(
|
||||
f"Execution manager started with max-{self.pool_size} graph workers."
|
||||
)
|
||||
sync_manager = multiprocessing.Manager()
|
||||
logger.info(
|
||||
f"[{self.service_name}] Started with max-{self.pool_size} graph workers"
|
||||
)
|
||||
while True:
|
||||
graph_exec_data = self.queue.get()
|
||||
graph_exec_id = graph_exec_data.graph_exec_id
|
||||
logger.debug(
|
||||
f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
|
||||
)
|
||||
while True:
|
||||
graph_exec_data = self.queue.get()
|
||||
graph_exec_id = graph_exec_data.graph_exec_id
|
||||
cancel_event = sync_manager.Event()
|
||||
future = executor.submit(
|
||||
Executor.on_graph_execution, graph_exec_data, cancel_event
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
future.add_done_callback(
|
||||
lambda _: self.active_graph_runs.pop(graph_exec_id)
|
||||
)
|
||||
cancel_event = sync_manager.Event()
|
||||
future = self.executor.submit(
|
||||
Executor.on_graph_execution, graph_exec_data, cancel_event
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
future.add_done_callback(
|
||||
lambda _: self.active_graph_runs.pop(graph_exec_id)
|
||||
)
|
||||
|
||||
def cleanup(self):
    """
    Shut down the graph executor pool, then run the base-class cleanup.

    Pending futures are cancelled (`cancel_futures=True`). The base-class
    cleanup always runs, even if the pool does not exist yet or shutting it
    down raises, so the resources it releases are not leaked.
    """
    # `self.executor` is assigned in run_service(); a termination request that
    # arrives before that point finds no pool to shut down.
    pool = getattr(self, "executor", None)
    try:
        if pool is not None:
            logger.info(
                f"[{__class__.__name__}] ⏳ Shutting down graph executor pool..."
            )
            pool.shutdown(cancel_futures=True)
    finally:
        super().cleanup()
|
||||
|
||||
@property
|
||||
def agent_server_client(self) -> "AgentServer":
|
||||
@@ -755,3 +786,12 @@ class ExecutionManager(AppService):
|
||||
)
|
||||
)
|
||||
self.agent_server_client.send_execution_update(exec_update.model_dump())
|
||||
|
||||
|
||||
def llprint(message: str):
    """
    Low-level print/log helper function for use in signal handlers.
    Regular log/print statements are not allowed in signal handlers.

    Writes `message` plus a newline directly to the stdout file descriptor,
    but only when the module logger's effective level is DEBUG or more
    verbose. The write is best-effort: if stdout has no usable file descriptor
    or the write fails, the message is dropped so that this diagnostic helper
    cannot abort the shutdown sequence that calls it.

    Args:
        message: Text to emit; a trailing newline is appended.
    """
    # Compare with `>` instead of testing for equality, so that levels more
    # verbose than DEBUG (custom trace levels, NOTSET) also print.
    if logger.getEffectiveLevel() > logging.DEBUG:
        return

    try:
        os.write(sys.stdout.fileno(), (message + "\n").encode())
    except (AttributeError, OSError, ValueError):
        # stdout is missing, replaced by an object without a descriptor, or
        # closed; there is nowhere to write to.
        pass
|
||||
|
||||
@@ -201,7 +201,7 @@ class AgentServer(AppService):
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, log_config=None)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
self._test_dependency_overrides = overrides
|
||||
|
||||
@@ -13,6 +13,7 @@ from autogpt_server.server.model import ExecutionSubscription, Methods, WsMessag
|
||||
from autogpt_server.util.service import AppProcess
|
||||
from autogpt_server.util.settings import Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
app = FastAPI()
|
||||
@@ -93,7 +94,7 @@ async def handle_subscribe(
|
||||
else:
|
||||
ex_sub = ExecutionSubscription.model_validate(message.data)
|
||||
await manager.subscribe(ex_sub.graph_id, websocket)
|
||||
print("subscribed")
|
||||
logger.debug(f"New execution subscription for graph {ex_sub.graph_id}")
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.SUBSCRIBE,
|
||||
@@ -117,7 +118,7 @@ async def handle_unsubscribe(
|
||||
else:
|
||||
ex_sub = ExecutionSubscription.model_validate(message.data)
|
||||
await manager.unsubscribe(ex_sub.graph_id, websocket)
|
||||
print("unsubscribed")
|
||||
logger.debug(f"Removed execution subscription for graph {ex_sub.graph_id}")
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.UNSUBSCRIBE,
|
||||
@@ -151,11 +152,12 @@ async def websocket_router(
|
||||
await handle_unsubscribe(websocket, manager, message)
|
||||
|
||||
elif message.method == Methods.ERROR:
|
||||
logging.error("WebSocket Error message received:", message.data)
|
||||
logger.error(f"WebSocket Error message received: {message.data}")
|
||||
|
||||
else:
|
||||
logging.info(
|
||||
f"Message type {message.method} is not processed by the server"
|
||||
logger.warning(
|
||||
f"Unknown WebSocket message type {message.method} received: "
|
||||
f"{message.data}"
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
@@ -167,7 +169,7 @@ async def websocket_router(
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
logging.info("Client Disconnected")
|
||||
logger.debug("WebSocket client disconnected")
|
||||
|
||||
|
||||
class WebsocketServer(AppProcess):
|
||||
|
||||
@@ -20,20 +20,30 @@ async def create_test_user() -> User:
|
||||
|
||||
def create_test_graph() -> graph.Graph:
|
||||
"""
|
||||
StoreValueBlock
|
||||
InputBlock
|
||||
\
|
||||
---- FillTextTemplateBlock ---- PrintToConsoleBlock
|
||||
/
|
||||
StoreValueBlock
|
||||
InputBlock
|
||||
"""
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=InputBlock().id,
|
||||
input_default={"name": "input_1"},
|
||||
input_default={
|
||||
"name": "input_1",
|
||||
"description": "First input value",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=InputBlock().id,
|
||||
input_default={"name": "input_2"},
|
||||
input_default={
|
||||
"name": "input_2",
|
||||
"description": "Second input value",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=FillTextTemplateBlock().id,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from multiprocessing import Process, set_start_method
|
||||
@@ -7,6 +9,8 @@ from typing import Optional
|
||||
from autogpt_server.util.logging import configure_logging
|
||||
from autogpt_server.util.metrics import sentry_init
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppProcess(ABC):
|
||||
"""
|
||||
@@ -19,6 +23,8 @@ class AppProcess(ABC):
|
||||
configure_logging()
|
||||
sentry_init()
|
||||
|
||||
# Methods that are executed INSIDE the process #
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
"""
|
||||
@@ -26,6 +32,13 @@ class AppProcess(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def cleanup(self):
    """
    Hook for post-execution cleanup; does nothing by default.

    Subclasses override this to release what they hold, e.g. disconnecting
    from a database or terminating child processes.
    """
    return None
|
||||
|
||||
def health_check(self):
|
||||
"""
|
||||
A method to check the health of the process.
|
||||
@@ -33,13 +46,22 @@ class AppProcess(ABC):
|
||||
pass
|
||||
|
||||
def execute_run_command(self, silent):
|
||||
signal.signal(signal.SIGTERM, self._self_terminate)
|
||||
|
||||
try:
|
||||
if silent:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
logger.info(f"[{self.__class__.__name__}] Starting...")
|
||||
self.run()
|
||||
except KeyboardInterrupt or SystemExit as e:
|
||||
print(f"Process terminated: {e}")
|
||||
except (KeyboardInterrupt, SystemExit) as e:
|
||||
logger.warning(f"[{self.__class__.__name__}] Terminated: {e}; quitting...")
|
||||
|
||||
def _self_terminate(self, signum: int, frame):
|
||||
self.cleanup()
|
||||
sys.exit(0)
|
||||
|
||||
# Methods that are executed OUTSIDE the process #
|
||||
|
||||
def __enter__(self):
|
||||
self.start(background=True)
|
||||
|
||||
@@ -46,19 +46,15 @@ def expose(func: C) -> C:
|
||||
|
||||
class PyroNameServer(AppProcess):
|
||||
def run(self):
|
||||
try:
|
||||
print("Starting NameServer loop")
|
||||
nameserver.start_ns_loop(host=pyro_host, port=9090)
|
||||
except KeyboardInterrupt:
|
||||
print("Shutting down NameServer")
|
||||
nameserver.start_ns_loop(host=pyro_host, port=9090)
|
||||
|
||||
@conn_retry
|
||||
def _wait_for_ns(self):
|
||||
pyro.locate_ns(host="localhost", port=9090)
|
||||
print("NameServer is ready")
|
||||
|
||||
def health_check(self):
|
||||
self._wait_for_ns()
|
||||
logger.info(f"{__class__.__name__} is ready")
|
||||
|
||||
|
||||
class AppService(AppProcess):
|
||||
@@ -108,6 +104,14 @@ class AppService(AppProcess):
|
||||
# Run the main service (if it's not implemented, just sleep).
|
||||
self.run_service()
|
||||
|
||||
def cleanup(self):
    """
    Close the connections this service was configured to use.

    The database is disconnected when `use_db` is set, and the event queue is
    closed when `use_redis` is set. Each step is logged first and then awaited
    through `run_and_wait`.
    """
    if self.use_db:
        logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
        db_disconnect = db.disconnect()
        self.run_and_wait(db_disconnect)

    if self.use_redis:
        logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
        queue_close = self.event_queue.close()
        self.run_and_wait(queue_close)
|
||||
|
||||
@conn_retry
|
||||
def __start_pyro(self):
|
||||
host = Config().pyro_host
|
||||
|
||||
@@ -73,10 +73,10 @@ class SpinTestServer:
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await db.disconnect()
|
||||
|
||||
self.name_server.__exit__(exc_type, exc_val, exc_tb)
|
||||
self.agent_server.__exit__(exc_type, exc_val, exc_tb)
|
||||
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
|
||||
self.scheduler.__exit__(exc_type, exc_val, exc_tb)
|
||||
self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
|
||||
self.agent_server.__exit__(exc_type, exc_val, exc_tb)
|
||||
self.name_server.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def setup_dependency_overrides(self):
|
||||
# Override get_user_id for testing
|
||||
|
||||
@@ -41,8 +41,20 @@ async def assert_sample_graph_executions(
|
||||
|
||||
output_list = [{"result": ["Hello"]}, {"result": ["World"]}]
|
||||
input_list = [
|
||||
{"value": "Hello", "name": "input_1"},
|
||||
{"value": "World", "name": "input_2"},
|
||||
{
|
||||
"name": "input_1",
|
||||
"description": "First input value",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
"value": "Hello",
|
||||
},
|
||||
{
|
||||
"name": "input_2",
|
||||
"description": "Second input value",
|
||||
"placeholder_values": [],
|
||||
"limit_to_placeholder_values": False,
|
||||
"value": "World",
|
||||
},
|
||||
]
|
||||
|
||||
# Executing StoreValueBlock
|
||||
|
||||
Reference in New Issue
Block a user