Merge branch 'master' into aarushikansal/open-1370-serve-frontend-on-server
@@ -160,10 +160,10 @@ Currently, the IPC is done using Pyro5 and abstracted in a way that allows a fun

## Adding a New Agent Block

To add a new agent block, you need to create a new class that inherits from `Block` and provides the following information:

* All the block code should live in the `blocks` (`autogpt_server.blocks`) module.
* `input_schema`: the schema of the input data, represented by a Pydantic object.
* `output_schema`: the schema of the output data, represented by a Pydantic object.
* `run` method: the main logic of the block.
* `test_input` & `test_output`: the sample input and output data for the block, which will be used to auto-test the block.
  * You can mock the functions declared in the block using the `test_mock` field for your unit tests.
* If you introduce a new module under the `blocks` package, you need to import the module in `blocks/__init__.py` to make it available to the server.
* Once you finish creating the block, you can test it by running `pytest test/block/test_block.py`.
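For illustration, a minimal block that follows the checklist above might look like the sketch below. This is a hypothetical example, not part of this commit: the class name and UUID are invented, and it assumes the `Block`/`BlockSchema`/`BlockOutput` API shown later in this diff.

```python
from autogpt_server.data.block import Block, BlockSchema, BlockOutput


class ReverseTextBlock(Block):
    class Input(BlockSchema):
        text: str

    class Output(BlockSchema):
        reversed: str

    def __init__(self):
        super().__init__(
            # Hardcoded UUID, as blocks require (made up for this example).
            id="00000000-0000-4000-8000-000000000001",
            input_schema=ReverseTextBlock.Input,
            output_schema=ReverseTextBlock.Output,
            test_input={"text": "Hello"},
            test_output=("reversed", "olleH"),
        )

    def run(self, input_data: Input) -> BlockOutput:
        # Blocks yield (output_name, output_data) pairs matching Output's fields.
        yield "reversed", input_data.text[::-1]
```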
@@ -1,9 +1,26 @@
-from autogpt_server.blocks import agent, sample, reddit, text, ai, wikipedia, discord
import os
import glob
import importlib
from pathlib import Path
from autogpt_server.data.block import Block

# Dynamically load all modules under autogpt_server.blocks
AVAILABLE_MODULES = []
current_dir = os.path.dirname(__file__)
modules = glob.glob(os.path.join(current_dir, "*.py"))
modules = [
    Path(f).stem
    for f in modules
    if os.path.isfile(f) and f.endswith(".py") and not f.endswith("__init__.py")
]
for module in modules:
    importlib.import_module(f".{module}", package=__name__)
    AVAILABLE_MODULES.append(module)

# Load all Block instances from the available modules
AVAILABLE_BLOCKS = {
    block.id: block
    for block in [v() for v in Block.__subclasses__()]
}

-__all__ = ["agent", "ai", "sample", "reddit", "text", "AVAILABLE_BLOCKS", "wikipedia", "discord"]
+__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]
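With this change, blocks self-register by their hardcoded UUID via `Block.__subclasses__()`. A hedged sketch of what a consumer of the registry can do (using `ValueBlock` from `blocks/basic.py` below; the exact server-side lookup path is an assumption):

```python
from autogpt_server.blocks import AVAILABLE_BLOCKS
from autogpt_server.blocks.basic import ValueBlock

# Blocks are keyed by their hardcoded UUID in AVAILABLE_BLOCKS.
block = AVAILABLE_BLOCKS[ValueBlock().id]

# run() yields (output_name, output_data) pairs.
for name, data in block.run(ValueBlock.Input(input="Hello")):
    print(name, data)  # -> output Hello
```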
73 rnd/autogpt_server/autogpt_server/blocks/basic.py (new file)
@@ -0,0 +1,73 @@
from autogpt_server.data.block import Block, BlockSchema, BlockOutput

from typing import Any
from pydantic import Field


class ValueBlock(Block):
    """
    This block provides a constant value as a block, in a stateless manner.
    The common use-case is to simply pass the `input` data through: it will `output`
    the same data. This does not retain state; once the block is executed, the
    output is consumed.

    To retain the state, you can feed the `output` back into the `data` input, so that
    the data is retained in the block for the next execution. You can then trigger the
    block by feeding the `input` pin any data, and the block will produce the value
    of `data`.

    Ex:
        <constant_data>   <any_trigger>
              ||               ||
        =====> `data`       `input`
        ||         \\        //
        ||         ValueBlock
        ||             ||
        ========== `output`
    """

    class Input(BlockSchema):
        input: Any = Field(description="Trigger the block to produce the output. "
                                       "The value is only used when `data` is None.")
        data: Any = Field(description="The constant data to be retained in the block. "
                                      "This value is passed as `output`.", default=None)

    class Output(BlockSchema):
        output: Any

    def __init__(self):
        super().__init__(
            id="1ff065e9-88e8-4358-9d82-8dc91f622ba9",
            input_schema=ValueBlock.Input,
            output_schema=ValueBlock.Output,
            test_input=[
                {"input": "Hello, World!"},
                {"input": "Hello, World!", "data": "Existing Data"},
            ],
            test_output=[
                ("output", "Hello, World!"),  # No data provided, so trigger is returned
                ("output", "Existing Data"),  # Data is provided, so data is returned.
            ],
        )

    def run(self, input_data: Input) -> BlockOutput:
        yield "output", input_data.data or input_data.input


class PrintingBlock(Block):
    class Input(BlockSchema):
        text: str

    class Output(BlockSchema):
        status: str

    def __init__(self):
        super().__init__(
            id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
            input_schema=PrintingBlock.Input,
            output_schema=PrintingBlock.Output,
            test_input={"text": "Hello, World!"},
            test_output=("status", "printed"),
        )

    def run(self, input_data: Input) -> BlockOutput:
        print(">>>>> Print: ", input_data.text)
        yield "status", "printed"
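The loopback pattern described in the `ValueBlock` docstring is wired with the graph primitives used elsewhere in this commit. A short sketch (node names are illustrative; the self-link mirrors the "Loopback for constant block" link in `usecases/block_autogen.py` below):

```python
from autogpt_server.blocks.basic import ValueBlock
from autogpt_server.data.graph import Link, Node

# A ValueBlock node that should remember `data` across executions.
constant = Node(block_id=ValueBlock().id, input_default={"data": None})

# Feed the block's own output back into its `data` pin, so the value
# produced on one execution is retained for the next one.
loopback = Link(
    source_id=constant.id,
    sink_id=constant.id,
    source_name="output",
    sink_name="data",
)
```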
61 rnd/autogpt_server/autogpt_server/blocks/block.py (new file)
@@ -0,0 +1,61 @@
import re
import os

from typing import Type
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
from autogpt_server.util.test import execute_block_test


class BlockInstallationBlock(Block):
    """
    This block allows the verification and installation of other blocks in the system.

    NOTE:
        This block allows remote code execution on the server, and it should be used
        for development purposes only.
    """

    class Input(BlockSchema):
        code: str

    class Output(BlockSchema):
        success: str
        error: str

    def __init__(self):
        super().__init__(
            id="45e78db5-03e9-447f-9395-308d712f5f08",
            input_schema=BlockInstallationBlock.Input,
            output_schema=BlockInstallationBlock.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        code = input_data.code

        if search := re.search(r"class (\w+)\(Block\):", code):
            class_name = search.group(1)
        else:
            yield "error", "No class found in the code."
            return

        if search := re.search(r"id=\"(\w+-\w+-\w+-\w+-\w+)\"", code):
            file_name = search.group(1)
        else:
            yield "error", "No UUID found in the code."
            return

        block_dir = os.path.dirname(__file__)
        file_path = f"{block_dir}/{file_name}.py"
        module_name = f"autogpt_server.blocks.{file_name}"
        with open(file_path, "w") as f:
            f.write(code)

        try:
            module = __import__(module_name, fromlist=[class_name])
            block_class: Type[Block] = getattr(module, class_name)
            block = block_class()
            execute_block_test(block)
            yield "success", "Block installed successfully."
        except Exception as e:
            os.remove(file_path)
            yield "error", f"[Code]\n{code}\n\n[Error]\n{str(e)}"
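To see what this block expects, here is a hedged sketch of driving it directly. The code payload, its class name, and its UUID are made up; as the NOTE says, this writes and executes arbitrary code, so it is for development only.

```python
from autogpt_server.blocks.block import BlockInstallationBlock

# Source for a hypothetical block; the `class <Name>(Block):` line and the
# hardcoded id=... string are what the two regexes above look for.
code = '''
from autogpt_server.data.block import Block, BlockSchema, BlockOutput

class EchoBlock(Block):
    class Input(BlockSchema):
        input: str

    class Output(BlockSchema):
        output: str

    def __init__(self):
        super().__init__(
            id="00000000-0000-4000-8000-000000000002",
            input_schema=EchoBlock.Input,
            output_schema=EchoBlock.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        yield "output", input_data.input
'''

installer = BlockInstallationBlock()
# Yields ("success", ...) after writing the file and running its tests,
# or ("error", ...) if the regexes or the block's own tests fail.
for name, data in installer.run(BlockInstallationBlock.Input(code=code)):
    print(name, data)
```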
@@ -1,30 +0,0 @@ (file deleted)
import requests
from autogpt_server.data.block import Block, BlockSchema, BlockOutput


class DiscordSendMessage(Block):
    class Input(BlockSchema):
        webhook_url: str
        message: str

    class Output(BlockSchema):
        status: str

    def __init__(self):
        super().__init__(
            id="b3a9c1f2-5d4e-47b3-9c4e-2b6e4d2c4f3e",
            input_schema=DiscordSendMessage.Input,
            output_schema=DiscordSendMessage.Output,
            test_input={
                "webhook_url": "https://discord.com/api/webhooks/your_webhook_url",
                "message": "Hello, Webhook!"
            },
            test_output=("status", "sent"),
        )

    def run(self, input_data: Input) -> BlockOutput:
        response = requests.post(input_data.webhook_url, json={"content": input_data.message})
        if response.status_code == 204:  # Discord webhook returns 204 No Content on success
            yield "status", "sent"
        else:
            yield "status", f"failed with status code {response.status_code}"
50 rnd/autogpt_server/autogpt_server/blocks/http.py (new file)
@@ -0,0 +1,50 @@
import requests

from enum import Enum
from autogpt_server.data.block import Block, BlockSchema, BlockOutput


class HttpMethod(Enum):
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
    OPTIONS = "OPTIONS"
    HEAD = "HEAD"


class HttpRequestBlock(Block):
    class Input(BlockSchema):
        url: str
        method: HttpMethod = HttpMethod.POST
        headers: dict[str, str] = {}
        body: object = {}

    class Output(BlockSchema):
        response: object
        client_error: object
        server_error: object

    def __init__(self):
        super().__init__(
            id="6595ae1f-b924-42cb-9a41-551a0611c4b4",
            input_schema=HttpRequestBlock.Input,
            output_schema=HttpRequestBlock.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        response = requests.request(
            input_data.method.value,
            input_data.url,
            headers=input_data.headers,
            json=input_data.body,
        )
        if response.status_code // 100 == 2:
            yield "response", response.json()
        elif response.status_code // 100 == 4:
            yield "client_error", response.json()
        elif response.status_code // 100 == 5:
            yield "server_error", response.json()
        else:
            raise ValueError(f"Unexpected status code: {response.status_code}")
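The `// 100` checks route outputs by status-code family: 2xx to `response`, 4xx to `client_error`, 5xx to `server_error`. A hedged sketch of calling the block directly (the URL is a placeholder, and running this performs a real HTTP request):

```python
from autogpt_server.blocks.http import HttpMethod, HttpRequestBlock

block = HttpRequestBlock()
input_data = HttpRequestBlock.Input(
    url="https://example.com/api",  # placeholder endpoint
    method=HttpMethod.POST,
    body={"query": "hello"},
)
# Exactly one of the three outputs is yielded per run, keyed by status family.
for name, data in block.run(input_data):
    print(name, data)
```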
@@ -16,7 +16,7 @@ class RedditCredentials(BaseModel):
    client_secret: BlockFieldSecret = BlockFieldSecret(key="reddit_client_secret")
    username: BlockFieldSecret = BlockFieldSecret(key="reddit_username")
    password: BlockFieldSecret = BlockFieldSecret(key="reddit_password")
-   user_agent: str | None = None
+   user_agent: str = "AutoGPT:1.0 (by /u/autogpt)"


class RedditPost(BaseModel):
@@ -1,44 +0,0 @@ (file deleted)
# type: ignore

from autogpt_server.data.block import Block, BlockSchema, BlockOutput


class ParrotBlock(Block):
    class Input(BlockSchema):
        input: str

    class Output(BlockSchema):
        output: str

    def __init__(self):
        super().__init__(
            id="1ff065e9-88e8-4358-9d82-8dc91f622ba9",
            input_schema=ParrotBlock.Input,
            output_schema=ParrotBlock.Output,
            test_input={"input": "Hello, World!"},
            test_output=("output", "Hello, World!"),
        )

    def run(self, input_data: Input) -> BlockOutput:
        yield "output", input_data.input


class PrintingBlock(Block):
    class Input(BlockSchema):
        text: str

    class Output(BlockSchema):
        status: str

    def __init__(self):
        super().__init__(
            id="f3b1c1b2-4c4f-4f0d-8d2f-4c4f0d8d2f4c",
            input_schema=PrintingBlock.Input,
            output_schema=PrintingBlock.Output,
            test_input={"text": "Hello, World!"},
            test_output=("status", "printed"),
        )

    def run(self, input_data: Input) -> BlockOutput:
        print(">>>>> Print: ", input_data.text)
        yield "status", "printed"
@@ -12,6 +12,7 @@ class TextMatcherBlock(Block):
        match: str = Field(description="Pattern (Regex) to match")
        data: Any = Field(description="Data to be forwarded to output")
        case_sensitive: bool = Field(description="Case sensitive match", default=True)
+       dot_all: bool = Field(description="Dot matches all", default=True)

    class Output(BlockSchema):
        positive: Any = Field(description="Output data if match is found")
@@ -38,13 +39,73 @@ class TextMatcherBlock(Block):

    def run(self, input_data: Input) -> BlockOutput:
        output = input_data.data or input_data.text
-       case = 0 if input_data.case_sensitive else re.IGNORECASE
-       if re.search(input_data.match, json.dumps(input_data.text), case):
+       flags = 0
+       if not input_data.case_sensitive:
+           flags = flags | re.IGNORECASE
+       if input_data.dot_all:
+           flags = flags | re.DOTALL
+
+       if isinstance(input_data.text, str):
+           text = input_data.text
+       else:
+           text = json.dumps(input_data.text)
+
+       if re.search(input_data.match, text, flags=flags):
            yield "positive", output
        else:
            yield "negative", output


class TextParserBlock(Block):
    class Input(BlockSchema):
        text: Any = Field(description="Text to parse")
        pattern: str = Field(description="Pattern (Regex) to parse")
        group: int = Field(description="Group number to extract", default=0)
        case_sensitive: bool = Field(description="Case sensitive match", default=True)
        dot_all: bool = Field(description="Dot matches all", default=True)

    class Output(BlockSchema):
        positive: str = Field(description="Extracted text")
        negative: str = Field(description="Original text")

    def __init__(self):
        super().__init__(
            id="3146e4fe-2cdd-4f29-bd12-0c9d5bb4deb0",
            input_schema=TextParserBlock.Input,
            output_schema=TextParserBlock.Output,
            test_input=[
                {"text": "Hello, World!", "pattern": "Hello, (.+)", "group": 1},
                {"text": "Hello, World!", "pattern": "Hello, (.+)", "group": 0},
                {"text": "Hello, World!", "pattern": "Hello, (.+)", "group": 2},
                {"text": "Hello, World!", "pattern": "hello,", "case_sensitive": False},
            ],
            test_output=[
                ("positive", "World!"),
                ("positive", "Hello, World!"),
                ("negative", "Hello, World!"),
                ("positive", "Hello,"),
            ],
        )

    def run(self, input_data: Input) -> BlockOutput:
        flags = 0
        if not input_data.case_sensitive:
            flags = flags | re.IGNORECASE
        if input_data.dot_all:
            flags = flags | re.DOTALL

        if isinstance(input_data.text, str):
            text = input_data.text
        else:
            text = json.dumps(input_data.text)

        match = re.search(input_data.pattern, text, flags)
        if match and input_data.group <= len(match.groups()):
            yield "positive", match.group(input_data.group)
        else:
            yield "negative", text


class TextFormatterBlock(Block):
    class Input(BlockSchema):
        texts: list[str] = Field(
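Both blocks compose `re` flags with bitwise OR, and the parser falls back to `negative` when the requested group does not exist. A small standalone illustration of the same stdlib `re` behavior:

```python
import re

flags = 0
case_sensitive = False
dot_all = True
if not case_sensitive:
    flags |= re.IGNORECASE   # lets "hello, (.+)" match "Hello, World!"
if dot_all:
    flags |= re.DOTALL       # "." also matches newlines

match = re.search("hello, (.+)", "Hello, World!", flags)
assert match is not None
assert match.group(1) == "World!"
# Asking for a group number beyond len(match.groups()) is what
# TextParserBlock treats as the "negative" case.
assert len(match.groups()) == 1
```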
@@ -103,164 +103,13 @@ def test():

@test.command()
@click.argument("server_address")
-@click.option(
-    "--client-id", required=True, help="Reddit client ID", default="TODO_FILL_OUT_THIS"
-)
-@click.option(
-    "--client-secret",
-    required=True,
-    help="Reddit client secret",
-    default="TODO_FILL_OUT_THIS",
-)
-@click.option(
-    "--username", required=True, help="Reddit username", default="TODO_FILL_OUT_THIS"
-)
-@click.option(
-    "--password", required=True, help="Reddit password", default="TODO_FILL_OUT_THIS"
-)
-@click.option(
-    "--user-agent",
-    required=True,
-    help="Reddit user agent",
-    default="TODO_FILL_OUT_THIS",
-)
-def reddit(
-    server_address: str,
-    client_id: str,
-    client_secret: str,
-    username: str,
-    password: str,
-    user_agent: str,
-):
+def reddit(server_address: str):
    """
    Create an event graph
    """
    import requests

-    from autogpt_server.data.graph import Graph, Link, Node
-    from autogpt_server.blocks.ai import LlmCallBlock, LlmModel
-    from autogpt_server.blocks.reddit import (
-        RedditCredentials,
-        RedditGetPostsBlock,
-        RedditPostCommentBlock,
-    )
-    from autogpt_server.blocks.text import TextFormatterBlock, TextMatcherBlock
-
-    reddit_creds = RedditCredentials(
-        client_id=client_id,
-        client_secret=client_secret,
-        username=username,
-        password=password,
-        user_agent=user_agent,
-    )
-    openai_api_key = "TODO_FILL_OUT_THIS"
-
-    # Hardcoded inputs
-    reddit_get_post_input = {
-        "creds": reddit_creds,
-        "last_minutes": 60,
-        "post_limit": 3,
-    }
-    text_formatter_input = {
-        "format": """
-Based on the following post, write your marketing comment:
-* Post ID: {id}
-* Post Subreddit: {subreddit}
-* Post Title: {title}
-* Post Body: {body}""".strip(),
-    }
-    llm_call_input = {
-        "sys_prompt": """
-You are an expert at marketing, and have been tasked with picking Reddit posts that are relevant to your product.
-The product you are marketing is: Auto-GPT an autonomous AI agent utilizing GPT model.
-You reply the post that you find it relevant to be replied with marketing text.
-Make sure to only comment on a relevant post.
-""",
-        "api_key": openai_api_key,
-        "expected_format": {
-            "post_id": "str, the reddit post id",
-            "is_relevant": "bool, whether the post is relevant for marketing",
-            "marketing_text": "str, marketing text, this is empty on irrelevant posts",
-        },
-    }
-    text_matcher_input = {"match": "true", "case_sensitive": False}
-    reddit_comment_input = {"creds": reddit_creds}
-
-    # Nodes
-    reddit_get_post_node = Node(
-        block_id=RedditGetPostsBlock().id,
-        input_default=reddit_get_post_input,
-    )
-    text_formatter_node = Node(
-        block_id=TextFormatterBlock().id,
-        input_default=text_formatter_input,
-    )
-    llm_call_node = Node(block_id=LlmCallBlock().id, input_default=llm_call_input)
-    text_matcher_node = Node(
-        block_id=TextMatcherBlock().id,
-        input_default=text_matcher_input,
-    )
-    reddit_comment_node = Node(
-        block_id=RedditPostCommentBlock().id,
-        input_default=reddit_comment_input,
-    )
-
-    nodes = [
-        reddit_get_post_node,
-        text_formatter_node,
-        llm_call_node,
-        text_matcher_node,
-        reddit_comment_node,
-    ]
-
-    # Links
-    links = [
-        Link(
-            source_id=reddit_get_post_node.id,
-            sink_id=text_formatter_node.id,
-            source_name="post",
-            sink_name="named_texts",
-        ),
-        Link(
-            source_id=text_formatter_node.id,
-            sink_id=llm_call_node.id,
-            source_name="output",
-            sink_name="prompt",
-        ),
-        Link(
-            source_id=llm_call_node.id,
-            sink_id=text_matcher_node.id,
-            source_name="response",
-            sink_name="data",
-        ),
-        Link(
-            source_id=llm_call_node.id,
-            sink_id=text_matcher_node.id,
-            source_name="response_#_is_relevant",
-            sink_name="text",
-        ),
-        Link(
-            source_id=text_matcher_node.id,
-            sink_id=reddit_comment_node.id,
-            source_name="positive_#_post_id",
-            sink_name="post_id",
-        ),
-        Link(
-            source_id=text_matcher_node.id,
-            sink_id=reddit_comment_node.id,
-            source_name="positive_#_marketing_text",
-            sink_name="comment",
-        ),
-    ]
-
-    # Create graph
-    test_graph = Graph(
-        name="RedditMarketingAgent",
-        description="Reddit marketing agent",
-        nodes=nodes,
-        links=links,
-    )
-
+    from autogpt_server.usecases.reddit_marketing import create_test_graph
+    test_graph = create_test_graph()
    url = f"{server_address}/graphs"
    headers = {"Content-Type": "application/json"}
    data = test_graph.model_dump_json()
@@ -278,50 +127,8 @@ def populate_db(server_address: str):
    Create an event graph
    """
    import requests

-    from autogpt_server.blocks.sample import ParrotBlock, PrintingBlock
-    from autogpt_server.blocks.text import TextFormatterBlock
-    from autogpt_server.data import graph
-
-    nodes = [
-        graph.Node(block_id=ParrotBlock().id),
-        graph.Node(block_id=ParrotBlock().id),
-        graph.Node(
-            block_id=TextFormatterBlock().id,
-            input_default={
-                "format": "{texts[0]},{texts[1]},{texts[2]}",
-                "texts_$_3": "!!!",
-            },
-        ),
-        graph.Node(block_id=PrintingBlock().id),
-    ]
-    links = [
-        graph.Link(
-            source_id=nodes[0].id,
-            sink_id=nodes[2].id,
-            source_name="output",
-            sink_name="texts_$_1",
-        ),
-        graph.Link(
-            source_id=nodes[1].id,
-            sink_id=nodes[2].id,
-            source_name="output",
-            sink_name="texts_$_2",
-        ),
-        graph.Link(
-            source_id=nodes[2].id,
-            sink_id=nodes[3].id,
-            source_name="output",
-            sink_name="text",
-        ),
-    ]
-    test_graph = graph.Graph(
-        name="TestGraph",
-        description="Test graph",
-        nodes=nodes,
-        links=links,
-    )
-
+    from autogpt_server.usecases.sample import create_test_graph
+    test_graph = create_test_graph()
    url = f"{server_address}/graphs"
    headers = {"Content-Type": "application/json"}
    data = test_graph.model_dump_json()
@@ -354,54 +161,10 @@ def graph(server_address: str):
    Create an event graph
    """
    import requests

-    from autogpt_server.blocks.sample import ParrotBlock, PrintingBlock
-    from autogpt_server.blocks.text import TextFormatterBlock
-    from autogpt_server.data import graph
-
-    nodes = [
-        graph.Node(block_id=ParrotBlock().id),
-        graph.Node(block_id=ParrotBlock().id),
-        graph.Node(
-            block_id=TextFormatterBlock().id,
-            input_default={
-                "format": "{texts[0]},{texts[1]},{texts[2]}",
-                "texts_$_3": "!!!",
-            },
-        ),
-        graph.Node(block_id=PrintingBlock().id),
-    ]
-    links = [
-        graph.Link(
-            source_id=nodes[0].id,
-            sink_id=nodes[2].id,
-            source_name="output",
-            sink_name="texts_$_1",
-        ),
-        graph.Link(
-            source_id=nodes[1].id,
-            sink_id=nodes[2].id,
-            source_name="output",
-            sink_name="texts_$_2",
-        ),
-        graph.Link(
-            source_id=nodes[2].id,
-            sink_id=nodes[3].id,
-            source_name="output",
-            sink_name="text",
-        ),
-    ]
-    test_graph = graph.Graph(
-        name="TestGraph",
-        description="Test graph",
-        nodes=nodes,
-        links=links,
-    )
-
+    from autogpt_server.usecases.sample import create_test_graph
    url = f"{server_address}/graphs"
    headers = {"Content-Type": "application/json"}
-    data = test_graph.model_dump_json()
+    data = create_test_graph().model_dump_json()
    response = requests.post(url, headers=headers, data=data)

    if response.status_code == 200:
@@ -418,7 +181,8 @@ def graph(server_address: str):

@test.command()
@click.argument("graph_id")
-def execute(graph_id: str):
+@click.argument("content")
+def execute(graph_id: str, content: dict):
    """
    Create an event graph
    """
@@ -427,9 +191,7 @@ def execute(graph_id: str):
    headers = {"Content-Type": "application/json"}

    execute_url = f"http://0.0.0.0:8000/graphs/{graph_id}/execute"
-    text = "Hello, World!"
-    input_data = {"input": text}
-    requests.post(execute_url, headers=headers, json=input_data)
+    requests.post(execute_url, headers=headers, json=content)


@test.command()
@@ -19,6 +19,6 @@ class BaseDbModel(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))

    @field_validator("id", mode="before")
-    def set_graph_id(cls, id: str) -> str:
+    def set_model_id(cls, id: str) -> str:
        # In case an empty ID is submitted
        return id or str(uuid4())
@@ -6,7 +6,6 @@ from typing import Any

from prisma.models import (
    AgentGraphExecution,
-    AgentNode,
    AgentNodeExecution,
    AgentNodeExecutionInputOutput,
)
@@ -95,9 +94,16 @@ class ExecutionResult(BaseModel):

# --------------------- Model functions --------------------- #

+EXECUTION_RESULT_INCLUDE = {
+    "Input": True,
+    "Output": True,
+    "AgentNode": True,
+    "AgentGraphExecution": True,
+}
+
+
async def create_graph_execution(
    graph_id: str, graph_version: int, node_ids: list[str], data: dict[str, Any]
) -> tuple[str, list[ExecutionResult]]:
    """
    Create a new AgentGraphExecution record.
@@ -179,9 +185,9 @@ async def upsert_execution_input(


async def upsert_execution_output(
    node_exec_id: str,
    output_name: str,
    output_data: Any,
) -> None:
    """
    Insert an AgentNodeExecutionInputOutput record as one of AgentNodeExecution.Output.
@@ -195,7 +201,10 @@ async def upsert_execution_output(
    )


-async def update_execution_status(node_exec_id: str, status: ExecutionStatus) -> None:
+async def update_execution_status(
+    node_exec_id: str,
+    status: ExecutionStatus
+) -> ExecutionResult:
    now = datetime.now(tz=timezone.utc)
    data = {
        **({"executionStatus": status}),
@@ -208,10 +217,13 @@ async def update_execution_status(node_exec_id: str, status: ExecutionStatus) ->
    res = await AgentNodeExecution.prisma().update(
        where={"id": node_exec_id},
        data=data,  # type: ignore
+        include=EXECUTION_RESULT_INCLUDE,  # type: ignore
    )
    if not res:
        raise ValueError(f"Execution {node_exec_id} not found.")

+    return ExecutionResult.from_db(res)
+

async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
    where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
@@ -224,25 +236,13 @@ async def list_executions(graph_id: str, graph_version: int | None = None) -> li
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
    executions = await AgentNodeExecution.prisma().find_many(
        where={"agentGraphExecutionId": graph_exec_id},
-        include={"Input": True, "Output": True, "AgentGraphExecution": True},
+        include=EXECUTION_RESULT_INCLUDE,  # type: ignore
        order={"addedTime": "asc"},
    )
    res = [ExecutionResult.from_db(execution) for execution in executions]
    return res


-async def get_execution_result(
-    graph_exec_id: str, node_exec_id: str
-) -> ExecutionResult:
-    execution = await AgentNodeExecution.prisma().find_first_or_raise(
-        where={"agentGraphExecutionId": graph_exec_id, "id": node_exec_id},
-        include={"Input": True, "Output": True, "AgentGraphExecution": True},
-        order={"addedTime": "asc"},
-    )
-    res = ExecutionResult.from_db(execution)
-    return res
-
-
async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
    """
    Get execution node input data from the previous node execution result.
@@ -252,10 +252,7 @@ async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
    """
    execution = await AgentNodeExecution.prisma().find_unique_or_raise(
        where={"id": node_exec_id},
-        include={
-            "Input": True,
-            "AgentNode": True,
-        },
+        include=EXECUTION_RESULT_INCLUDE,  # type: ignore
    )
    if not execution.AgentNode:
        raise ValueError(f"Node {execution.agentNodeId} not found.")
@@ -302,8 +299,9 @@ def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None:

def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
    # Merge all input with <input_name>_$_<index> into a single list.
+    items = list(data.items())
    list_input: list[Any] = []
-    for key, value in data.items():
+    for key, value in items:
        if LIST_SPLIT not in key:
            continue
        name, index = key.split(LIST_SPLIT)
@@ -317,7 +315,7 @@ def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
        data[name].append(value)

    # Merge all input with <input_name>_#_<index> into a single dict.
-    for key, value in data.items():
+    for key, value in items:
        if DICT_SPLIT not in key:
            continue
        name, index = key.split(DICT_SPLIT)
@@ -325,7 +323,7 @@ def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
        data[name][index] = value

    # Merge all input with <input_name>_@_<index> into a single object.
-    for key, value in data.items():
+    for key, value in items:
        if OBJC_SPLIT not in key:
            continue
        name, index = key.split(OBJC_SPLIT)
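Two things happen in this hunk: the loops now iterate a snapshot (`items`) instead of `data.items()`, which avoids mutating the dict while iterating it, and the `_$_`/`_#_`/`_@_` suffixes let several links feed one list, dict, or object input. A hedged, simplified re-implementation to illustrate the convention (assuming the split markers are the literal `"_$_"` and `"_#_"` strings seen in the graphs in this commit; index handling here is 0-based for simplicity):

```python
from typing import Any

LIST_SPLIT, DICT_SPLIT = "_$_", "_#_"  # assumed marker values

def merge(data: dict[str, Any]) -> dict[str, Any]:
    # Illustration only, not the actual merge_execution_input.
    for key, value in list(data.items()):  # snapshot: we mutate `data` below
        if LIST_SPLIT in key:
            name, index = key.split(LIST_SPLIT)
            bucket = data.setdefault(name, [])
            bucket += [None] * (int(index) + 1 - len(bucket))
            bucket[int(index)] = value
        elif DICT_SPLIT in key:
            name, index = key.split(DICT_SPLIT)
            data.setdefault(name, {})[index] = value
    return data

merged = merge({"texts_$_0": "Hello", "texts_$_1": "World", "named_texts_#_query": "hi"})
assert merged["texts"] == ["Hello", "World"]
assert merged["named_texts"] == {"query": "hi"}
```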
@@ -141,10 +141,10 @@ async def get_graphs_meta(
    Default behaviour is to get all currently active graphs.

    Args:
-        filter: An optional filter to either select templates or active graphs.
+        filter_by: An optional filter to either select templates or active graphs.

    Returns:
-        list[GraphMeta]: A list of GraphMeta objects representing the retrieved graph metadata.
+        list[GraphMeta]: A list of objects representing the retrieved graph metadata.
    """
    where_clause: prisma.types.AgentGraphWhereInput = {}
@@ -9,11 +9,10 @@ from autogpt_server.data import db
from autogpt_server.data.block import Block, get_block
from autogpt_server.data.execution import (
    create_graph_execution,
-    get_execution_result,
    get_node_execution_input,
    merge_execution_input,
    parse_execution_output,
-    update_execution_status as execution_update,
+    update_execution_status,
    upsert_execution_output,
    upsert_execution_input,
    NodeExecution as Execution,
@@ -21,7 +20,7 @@ from autogpt_server.data.execution import (
    ExecutionQueue,
)
from autogpt_server.data.graph import Link, Node, get_node, get_graph, Graph
-from autogpt_server.util.service import AppService, expose, get_service_client  # type: ignore
+from autogpt_server.util.service import AppService, expose, get_service_client

logger = logging.getLogger(__name__)
@@ -36,7 +35,7 @@ ExecutionStream = Generator[Execution, None, None]

def execute_node(
    loop: asyncio.AbstractEventLoop,
-    agent_server_client: "AgentServer",
+    api_client: "AgentServer",
    data: Execution
) -> ExecutionStream:
    """
@@ -45,7 +44,7 @@ def execute_node(

    Args:
        loop: The event loop to run the async functions.
-        agent_server_client: The client to send execution updates to the server.
+        api_client: The client to send execution updates to the server.
        data: The execution data for executing the current node.

    Returns:
@@ -60,6 +59,11 @@ def execute_node(

    def wait(f: Coroutine[T, Any, T]) -> T:
        return loop.run_until_complete(f)

+   def update_execution(status: ExecutionStatus):
+       api_client.send_execution_update(
+           wait(update_execution_status(node_exec_id, status)).model_dump()
+       )
+
    node = wait(get_node(node_id))
    if not node:
@@ -74,28 +78,16 @@ def execute_node(
    # Execute the node
    prefix = get_log_prefix(graph_exec_id, node_exec_id, node_block.name)
    logger.warning(f"{prefix} execute with input:\n`{exec_data}`")
-
-   wait(execution_update(node_exec_id, ExecutionStatus.RUNNING))
-
-   # TODO: Remove need for multiple database lookups
-   execution_result = wait(get_execution_result(
-       graph_exec_id, node_exec_id
-   ))
-   agent_server_client.send_execution_update(execution_result.model_dump())  # type: ignore
+   update_execution(ExecutionStatus.RUNNING)

    try:
        for output_name, output_data in node_block.execute(exec_data):
            logger.warning(f"{prefix} Executed, output [{output_name}]:`{output_data}`")
-           wait(execution_update(node_exec_id, ExecutionStatus.COMPLETED))
            wait(upsert_execution_output(node_exec_id, output_name, output_data))
-
-           # TODO: Remove need for multiple database lookups
-           execution_result = wait(get_execution_result(
-               graph_exec_id, node_exec_id
-           ))
-           agent_server_client.send_execution_update(execution_result.model_dump())  # type: ignore
+           update_execution(ExecutionStatus.COMPLETED)

            for execution in enqueue_next_nodes(
+               api_client=api_client,
                loop=loop,
                node=node,
                output=(output_name, output_data),
@@ -106,19 +98,14 @@ def execute_node(
    except Exception as e:
        error_msg = f"{e.__class__.__name__}: {e}"
        logger.exception(f"{prefix} failed with error. `%s`", error_msg)
-       wait(execution_update(node_exec_id, ExecutionStatus.FAILED))
        wait(upsert_execution_output(node_exec_id, "error", error_msg))
-
-       # TODO: Remove need for multiple database lookups
-       execution_result = wait(get_execution_result(
-           graph_exec_id, node_exec_id
-       ))
-       agent_server_client.send_execution_update(execution_result.model_dump())  # type: ignore
+       update_execution(ExecutionStatus.FAILED)

        raise e


def enqueue_next_nodes(
+   api_client: "AgentServer",
    loop: asyncio.AbstractEventLoop,
    node: Node,
    output: tuple[str, Any],
@@ -127,8 +114,13 @@ def enqueue_next_nodes(
) -> list[Execution]:
    def wait(f: Coroutine[T, Any, T]) -> T:
        return loop.run_until_complete(f)

+   def execution_update(node_exec_id: str, status: ExecutionStatus):
+       api_client.send_execution_update(
+           wait(update_execution_status(node_exec_id, status)).model_dump()
+       )
+
-   def get_next_node_execution(node_link: Link) -> Execution | None:
+   def update_execution_result(node_link: Link) -> Execution | None:
        next_output_name = node_link.source_name
        next_input_name = node_link.sink_name
        next_node_id = node_link.sink_id
@@ -148,8 +140,8 @@ def enqueue_next_nodes(
            input_name=next_input_name,
            data=next_data
        ))
-       next_node_input = wait(get_node_execution_input(next_node_exec_id))

+       next_node_input = wait(get_node_execution_input(next_node_exec_id))
        is_valid, validation_msg = validate_exec(next_node, next_node_input)
        suffix = f"{next_output_name}~{next_input_name}#{next_node_id}:{validation_msg}"
@@ -157,16 +149,20 @@ def enqueue_next_nodes(
        if not is_valid:
            logger.warning(f"{prefix} Skipped queueing {suffix}")
            return

        # Input is complete, enqueue the execution.
        logger.warning(f"{prefix} Enqueued {suffix}")
+       execution_update(next_node_exec_id, ExecutionStatus.QUEUED)
        return Execution(
            graph_exec_id=graph_exec_id,
            node_exec_id=next_node_exec_id,
-           node_id=next_node_id,
+           node_id=next_node.id,
            data=next_node_input,
        )

-   executions = [get_next_node_execution(link) for link in node.output_links]
-   return [v for v in executions if v]
+   return [
+       execution for link in node.output_links
+       if (execution := update_execution_result(link))
+   ]


def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
@@ -286,15 +282,6 @@ class ExecutionManager(AppService):
                    data=input_data,
                )
            )
-           # TODO: Remove need for multiple database lookups
-           execution_result = self.run_and_wait(get_execution_result(
-               node_exec.graph_exec_id, node_exec.node_exec_id
-           ))
-           try:
-               self.agent_server_client.send_execution_update(execution_result.model_dump())  # type: ignore
-           except Exception as e:
-               msg = f"Error sending execution of type {type(execution_result)}: {e}"
-               raise Exception(msg)

            executions.append(
                {
@@ -309,7 +296,9 @@ class ExecutionManager(AppService):
        }

    def add_node_execution(self, execution: Execution) -> Execution:
-       self.run_and_wait(
-           execution_update(execution.node_exec_id, ExecutionStatus.QUEUED)
-       )
+       res = self.run_and_wait(update_execution_status(
+           execution.node_exec_id,
+           ExecutionStatus.QUEUED
+       ))
+       self.agent_server_client.send_execution_update(res.model_dump())
        return self.queue.add(execution)
235 rnd/autogpt_server/autogpt_server/usecases/block_autogen.py (new file)
@@ -0,0 +1,235 @@
from pathlib import Path

from autogpt_server.blocks.ai import LlmCallBlock
from autogpt_server.blocks.basic import ValueBlock
from autogpt_server.blocks.block import BlockInstallationBlock
from autogpt_server.blocks.http import HttpRequestBlock
from autogpt_server.blocks.text import TextParserBlock, TextFormatterBlock
from autogpt_server.data.graph import Graph, Node, Link, create_graph
from autogpt_server.util.test import SpinTestServer, wait_execution


sample_block_modules = {
    "ai": "Block that calls the AI model to generate text.",
    "basic": "Block that does basic operations.",
    "text": "Blocks that do text operations.",
    "reddit": "Blocks that interact with Reddit.",
}
sample_block_codes = {}
for module, description in sample_block_modules.items():
    current_dir = Path(__file__).parent
    file_path = current_dir.parent / "blocks" / f"{module}.py"
    with open(file_path, "r") as f:
        code = "\n".join(["```python", f.read(), "```"])
    sample_block_codes[module] = f"[Example: {description}]\n{code}"


def create_test_graph() -> Graph:
    """
    ValueBlock (input)
        ||
        v
    TextFormatterBlock (input query)
        ||
        v
    HttpRequestBlock (browse)
        ||
        v
    ------> ValueBlock ==============
    |        |    |                ||
    |        ------                ||
    |                              ||
    |                              v
    |    LlmCallBlock <===== TextFormatterBlock (query)
    |        ||                    ^
    |        v                     ||
    |    TextParserBlock           ||
    |        ||                    ||
    |        v                     ||
    ------ BlockInstallationBlock ==
    """
    # ======= Nodes ========= #
    input_data = Node(
        block_id=ValueBlock().id
    )
    input_text_formatter = Node(
        block_id=TextFormatterBlock().id,
        input_default={
            "format": "Show me how to make a python code for this query: `{query}`",
        },
    )
    search_http_request = Node(
        block_id=HttpRequestBlock().id,
        input_default={
            "url": "https://osit-v2.bentlybro.com/search",
        },
    )
    search_result_constant = Node(
        block_id=ValueBlock().id,
        input_default={
            "data": None,
        },
    )
    prompt_text_formatter = Node(
        block_id=TextFormatterBlock().id,
        input_default={
            "format": """
Write me a full Block implementation for this query: `{query}`

Here is the information I get to write a Python code for that:
{search_result}

Here is your previous attempt:
{previous_attempt}
""",
            "named_texts_#_previous_attempt": "No previous attempt found."
        },
    )
    code_gen_llm_call = Node(
        block_id=LlmCallBlock().id,
        input_default={
            "sys_prompt": f"""
You are a software engineer and you are asked to write the full class implementation.
The class that you are implementing is extending a class called `Block`.
This class will be used as a node in a graph of other blocks to build a complex system.
This class has a method called `run` that takes an input and returns an output.
It also has an `id` attribute that is a UUID, input_schema, and output_schema.
For UUID, you have to hardcode it, like `d2e2ecd2-9ae6-422d-8dfe-ceca500ce6a6`,
don't use any automatic UUID generation, because it needs to be consistent.
To validate the correctness of your implementation, you can also define a test.
There is `test_input` and `test_output` you can use to validate your implementation.
There is also `test_mock` to mock a helper function on your block class for testing.

Feel free to start your answer by explaining your plan: what's required, how to test, etc.
But make sure to produce the fully working implementation at the end,
and it should be enclosed within this block format:
```python
<Your implementation here>
```

Here are a couple of samples of the Block class implementation:

{"--------------\n".join([sample_block_codes[v] for v in sample_block_modules])}
""",
        },
    )
    code_text_parser = Node(
        block_id=TextParserBlock().id,
        input_default={
            "pattern": "```python\n(.+?)\n```",
            "group": 1,
        },
    )
    block_installation = Node(
        block_id=BlockInstallationBlock().id,
    )
    nodes = [
        input_data,
        input_text_formatter,
        search_http_request,
        search_result_constant,
        prompt_text_formatter,
        code_gen_llm_call,
        code_text_parser,
        block_installation,
    ]

    # ======= Links ========= #
    links = [
        Link(
            source_id=input_data.id,
            sink_id=input_text_formatter.id,
            source_name="output",
            sink_name="named_texts_#_query"),
        Link(
            source_id=input_text_formatter.id,
            sink_id=search_http_request.id,
            source_name="output",
            sink_name="body_#_query"),
        Link(
            source_id=search_http_request.id,
            sink_id=search_result_constant.id,
            source_name="response_#_reply",
            sink_name="input"),
        Link(  # Loopback for constant block
            source_id=search_result_constant.id,
            sink_id=search_result_constant.id,
            source_name="output",
            sink_name="data"
        ),
        Link(
            source_id=search_result_constant.id,
            sink_id=prompt_text_formatter.id,
            source_name="output",
            sink_name="named_texts_#_search_result"
        ),
        Link(
            source_id=input_data.id,
            sink_id=prompt_text_formatter.id,
            source_name="output",
            sink_name="named_texts_#_query"
        ),
        Link(
            source_id=prompt_text_formatter.id,
            sink_id=code_gen_llm_call.id,
            source_name="output",
            sink_name="prompt"
        ),
        Link(
            source_id=code_gen_llm_call.id,
            sink_id=code_text_parser.id,
            source_name="response_#_response",
            sink_name="text"
        ),
        Link(
            source_id=code_text_parser.id,
            sink_id=block_installation.id,
            source_name="positive",
            sink_name="code"
        ),
        Link(
            source_id=block_installation.id,
            sink_id=prompt_text_formatter.id,
            source_name="error",
            sink_name="named_texts_#_previous_attempt"
        ),
        Link(  # Re-trigger search result.
            source_id=block_installation.id,
            sink_id=search_result_constant.id,
            source_name="error",
            sink_name="input"
        ),
    ]

    # ======= Graph ========= #
    return Graph(
        name="BlockAutoGen",
        description="Block auto generation agent",
        nodes=nodes,
        links=links,
    )


async def block_autogen_agent():
    async with SpinTestServer() as server:
        test_manager = server.exec_manager
        test_graph = await create_graph(create_test_graph())
        input_data = {"input": "Write me a block that writes a string into a file."}
        response = await server.agent_server.execute_graph(test_graph.id, input_data)
        print(response)
        result = await wait_execution(test_manager, test_graph.id, response["id"], 10, 1200)
        print(result)


if __name__ == "__main__":
    import asyncio

    asyncio.run(block_autogen_agent())
@@ -1,38 +1,19 @@
-import time
from autogpt_server.data import block, db
from autogpt_server.data.graph import Graph, Link, Node, create_graph
-from autogpt_server.data.execution import ExecutionStatus
-from autogpt_server.blocks.ai import LlmCallBlock, LlmModel
+from autogpt_server.blocks.ai import LlmCallBlock
from autogpt_server.blocks.reddit import (
    RedditCredentials,
    RedditGetPostsBlock,
    RedditPostCommentBlock,
)
from autogpt_server.blocks.text import TextFormatterBlock, TextMatcherBlock
-from autogpt_server.executor import ExecutionManager
-from autogpt_server.server import AgentServer
-from autogpt_server.util.service import PyroNameServer
+from autogpt_server.util.test import SpinTestServer, wait_execution


-async def create_test_graph() -> Graph:
+def create_test_graph() -> Graph:
    # /--- post_id -----------\                                            /--- post_id ---\
    # subreddit --> RedditGetPostsBlock ---- post_body -------- TextFormatterBlock ----- LlmCallBlock / TextRelevancy --- relevant/not -- TextMatcherBlock -- Yes {postid, text} --- RedditPostCommentBlock
    #               \--- post_title -------/                               \--- marketing_text ---/                                                      -- No

    # Creds
-   reddit_creds = RedditCredentials(
-       client_id="TODO_FILL_OUT_THIS",
-       client_secret="TODO_FILL_OUT_THIS",
-       username="TODO_FILL_OUT_THIS",
-       password="TODO_FILL_OUT_THIS",
-       user_agent="TODO_FILL_OUT_THIS",
-   )
-   openai_api_key = "TODO_FILL_OUT_THIS"

    # Hardcoded inputs
    reddit_get_post_input = {
-       "creds": reddit_creds,
        "last_minutes": 60,
        "post_limit": 3,
    }
    text_formatter_input = {
@@ -50,7 +31,6 @@ The product you are marketing is: Auto-GPT an autonomous AI agent utilizing GPT
You reply the post that you find it relevant to be replied with marketing text.
Make sure to only comment on a relevant post.
""",
-       "api_key": openai_api_key,
        "expected_format": {
            "post_id": "str, the reddit post id",
            "is_relevant": "bool, whether the post is relevant for marketing",
@@ -58,7 +38,7 @@ Make sure to only comment on a relevant post.
        },
    }
    text_matcher_input = {"match": "true", "case_sensitive": False}
-   reddit_comment_input = {"creds": reddit_creds}
+   reddit_comment_input = {}

    # Nodes
    reddit_get_post_node = Node(
@@ -134,56 +114,20 @@ Make sure to only comment on a relevant post.
        nodes=nodes,
        links=links,
    )
-   return await create_graph(test_graph)
-
-
-async def wait_execution(test_manager, graph_id, graph_exec_id) -> list:
-   async def is_execution_completed():
-       execs = await AgentServer().get_run_execution_results(graph_id, graph_exec_id)
-       """
-       List of execution:
-           reddit_get_post_node 1 (produced 3 posts)
-           text_formatter_node 3
-           llm_call_node 3 (assume 3 of them relevant)
-           text_matcher_node 3
-           reddit_comment_node 3
-       Total: 13
-       """
-       print("--------> Execution count: ", len(execs), [str(v.status) for v in execs])
-       return (
-           test_manager.queue.empty()
-           and len(execs) == 13
-           and all(
-               v.status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]
-               for v in execs
-           )
-       )
-
-   # Wait for the executions to complete
-   for i in range(120):
-       if await is_execution_completed():
-           return await AgentServer().get_run_execution_results(
-               graph_id, graph_exec_id
-           )
-       time.sleep(1)
-
-   assert False, "Execution did not complete in time."
+   return test_graph


async def reddit_marketing_agent():
-   with PyroNameServer():
-       with ExecutionManager(1) as test_manager:
-           await db.connect()
-           await block.initialize_blocks()
-           test_graph = await create_test_graph()
-           input_data = {"subreddit": "AutoGPT"}
-           response = await AgentServer().execute_graph(test_graph.id, input_data)
-           print(response)
-           result = await wait_execution(test_manager, test_graph.id, response["id"])
-           print(result)
+   async with SpinTestServer() as server:
+       exec_man = server.exec_manager
+       test_graph = await create_graph(create_test_graph())
+       input_data = {"subreddit": "AutoGPT"}
+       response = await server.agent_server.execute_graph(test_graph.id, input_data)
+       print(response)
+       result = await wait_execution(exec_man, test_graph.id, response["id"], 13, 120)
+       print(result)


if __name__ == "__main__":
    import asyncio

    asyncio.run(reddit_marketing_agent())
70 rnd/autogpt_server/autogpt_server/usecases/sample.py (new file)
@@ -0,0 +1,70 @@
from autogpt_server.blocks.basic import ValueBlock, PrintingBlock
from autogpt_server.blocks.text import TextFormatterBlock
from autogpt_server.data import graph
from autogpt_server.data.graph import create_graph
from autogpt_server.util.test import SpinTestServer, wait_execution


def create_test_graph() -> graph.Graph:
    """
    ValueBlock
              \
               ---- TextFormatterBlock ---- PrintingBlock
              /
    ValueBlock
    """
    nodes = [
        graph.Node(block_id=ValueBlock().id),
        graph.Node(block_id=ValueBlock().id),
        graph.Node(
            block_id=TextFormatterBlock().id,
            input_default={
                "format": "{texts[0]},{texts[1]},{texts[2]}",
                "texts_$_3": "!!!",
            },
        ),
        graph.Node(block_id=PrintingBlock().id),
    ]
    links = [
        graph.Link(
            source_id=nodes[0].id,
            sink_id=nodes[2].id,
            source_name="output",
            sink_name="texts_$_1"
        ),
        graph.Link(
            source_id=nodes[1].id,
            sink_id=nodes[2].id,
            source_name="output",
            sink_name="texts_$_2"
        ),
        graph.Link(
            source_id=nodes[2].id,
            sink_id=nodes[3].id,
            source_name="output",
            sink_name="text"
        ),
    ]

    return graph.Graph(
        name="TestGraph",
        description="Test graph",
        nodes=nodes,
        links=links,
    )


async def sample_agent():
    async with SpinTestServer() as server:
        exec_man = server.exec_manager
        test_graph = await create_graph(create_test_graph())
        input_data = {"input": "test!!"}
        response = await server.agent_server.execute_graph(test_graph.id, input_data)
        print(response)
        result = await wait_execution(exec_man, test_graph.id, response["id"], 4, 10)
        print(result)


if __name__ == "__main__":
    import asyncio
    asyncio.run(sample_agent())
108 rnd/autogpt_server/autogpt_server/util/test.py (new file)
@@ -0,0 +1,108 @@
import time

from autogpt_server.data.block import Block
from autogpt_server.data import block, db
from autogpt_server.data.execution import ExecutionStatus
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server import AgentServer
from autogpt_server.util.service import PyroNameServer

log = print


class SpinTestServer:
    def __init__(self):
        self.name_server = PyroNameServer()
        self.exec_manager = ExecutionManager(1)
        self.agent_server = AgentServer()
        self.scheduler = ExecutionScheduler()

    async def __aenter__(self):
        self.name_server.__enter__()
        self.agent_server.__enter__()
        self.exec_manager.__enter__()
        self.scheduler.__enter__()

        await db.connect()
        await block.initialize_blocks()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await db.disconnect()

        self.name_server.__exit__(exc_type, exc_val, exc_tb)
        self.agent_server.__exit__(exc_type, exc_val, exc_tb)
        self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
        self.scheduler.__exit__(exc_type, exc_val, exc_tb)


async def wait_execution(
    exec_manager: ExecutionManager,
    graph_id: str,
    graph_exec_id: str,
    num_execs: int,
    timeout: int = 20,
) -> list:
    async def is_execution_completed():
        execs = await AgentServer().get_run_execution_results(graph_id, graph_exec_id)
        return exec_manager.queue.empty() and len(execs) == num_execs and all(
            v.status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]
            for v in execs
        )

    # Wait for the executions to complete
    for i in range(timeout):
        if await is_execution_completed():
            return await AgentServer().get_run_execution_results(
                graph_id, graph_exec_id
            )
        time.sleep(1)

    assert False, "Execution did not complete in time."


def execute_block_test(block: Block):
    prefix = f"[Test-{block.name}]"

    if not block.test_input or not block.test_output:
        log(f"{prefix} No test data provided")
        return
    if not isinstance(block.test_input, list):
        block.test_input = [block.test_input]
    if not isinstance(block.test_output, list):
        block.test_output = [block.test_output]

    output_index = 0
    log(f"{prefix} Executing {len(block.test_input)} tests...")
    prefix = " " * 4 + prefix

    for mock_name, mock_obj in (block.test_mock or {}).items():
        log(f"{prefix} mocking {mock_name}...")
        setattr(block, mock_name, mock_obj)

    for input_data in block.test_input:
        log(f"{prefix} in: {input_data}")

        for output_name, output_data in block.execute(input_data):
            if output_index >= len(block.test_output):
                raise ValueError(f"{prefix} produced output more than expected")
            ex_output_name, ex_output_data = block.test_output[output_index]

            def compare(data, expected_data):
                if isinstance(expected_data, type):
                    is_matching = isinstance(data, expected_data)
                else:
                    is_matching = data == expected_data

                mark = "✅" if is_matching else "❌"
                log(f"{prefix} {mark} comparing `{data}` vs `{expected_data}`")
                if not is_matching:
                    raise ValueError(
                        f"{prefix}: wrong output {data} vs {expected_data}")

            compare(output_data, ex_output_data)
            compare(output_name, ex_output_name)
            output_index += 1

    if output_index < len(block.test_output):
        raise ValueError(f"{prefix} produced output less than expected")
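A hedged sketch of how a test module might drive `execute_block_test` over every registered block, mirroring the `test_available_blocks` test touched in the next hunk (it assumes `get_blocks()` yields `Block` instances, which this diff does not show):

```python
from autogpt_server.data.block import get_blocks
from autogpt_server.util.test import execute_block_test


def test_available_blocks():
    # Assumption: get_blocks() yields Block instances. execute_block_test
    # raises ValueError on any output mismatch, which fails the test;
    # blocks without test_input/test_output are skipped with a log line.
    for block in get_blocks():
        execute_block_test(block)
```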
@@ -1,52 +1,5 @@
import logging

from autogpt_server.data.block import Block, get_blocks

logger = logging.getLogger(__name__)
log = print


def execute_block_test(block: Block):
    prefix = f"[Test-{block.name}]"

    if not block.test_input or not block.test_output:
        log(f"{prefix} No test data provided")
        return
    if not isinstance(block.test_input, list):
        block.test_input = [block.test_input]
    if not isinstance(block.test_output, list):
        block.test_output = [block.test_output]

    output_index = 0
    log(f"{prefix} Executing {len(block.test_input)} tests...")
    prefix = " " * 4 + prefix

    for mock_name, mock_obj in (block.test_mock or {}).items():
        log(f"{prefix} mocking {mock_name}...")
        setattr(block, mock_name, mock_obj)

    for input_data in block.test_input:
        log(f"{prefix} in: {input_data}")

        for output_name, output_data in block.execute(input_data):
            if output_index >= len(block.test_output):
                raise ValueError(f"{prefix} produced output more than expected")
            ex_output_name, ex_output_data = block.test_output[output_index]

            def compare(data, expected_data):
                if isinstance(expected_data, type):
                    is_matching = isinstance(data, expected_data)
                else:
                    is_matching = data == expected_data

                mark = "✅" if is_matching else "❌"
                log(f"{prefix} {mark} comparing `{data}` vs `{expected_data}`")
                if not is_matching:
                    raise ValueError(f"{prefix}: wrong output {data} vs {expected_data}")

            compare(output_name, ex_output_name)
            compare(output_data, ex_output_data)
            output_index += 1
from autogpt_server.data.block import get_blocks
from autogpt_server.util.test import execute_block_test


def test_available_blocks():
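(The body of `test_available_blocks` is elided by this hunk; a plausible sketch, given the two imports above, is to run the shared data-driven test against every registered block. The loop below is an assumption about the file's contents, and it assumes `get_blocks()` returns a mapping of block id to block instance.)

    from autogpt_server.data.block import get_blocks
    from autogpt_server.util.test import execute_block_test

    def test_available_blocks():
        # Assumed body: exercise every block registered by the dynamic
        # loader through the shared execute_block_test helper.
        for block in get_blocks().values():
            execute_block_test(block)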
@@ -1,70 +1,10 @@
import time

import pytest

from autogpt_server.blocks.sample import ParrotBlock, PrintingBlock
from autogpt_server.blocks.text import TextFormatterBlock
from autogpt_server.data import block, db, execution, graph
from autogpt_server.data import execution, graph
from autogpt_server.executor import ExecutionManager
from autogpt_server.server import AgentServer
from autogpt_server.util.service import PyroNameServer


async def create_test_graph() -> graph.Graph:
    """
    ParrotBlock
               \
                ---- TextFormatterBlock ---- PrintingBlock
               /
    ParrotBlock
    """
    nodes = [
        graph.Node(block_id=ParrotBlock().id),
        graph.Node(block_id=ParrotBlock().id),
        graph.Node(
            block_id=TextFormatterBlock().id,
            input_default={
                "format": "{texts[0]},{texts[1]},{texts[2]}",
                "texts_$_3": "!!!",
            },
        ),
        graph.Node(block_id=PrintingBlock().id),
    ]
    links = [
        graph.Link(
            source_id=nodes[0].id,
            sink_id=nodes[2].id,
            source_name="output",
            sink_name="texts_$_1",
        ),
        graph.Link(
            source_id=nodes[1].id,
            sink_id=nodes[2].id,
            source_name="output",
            sink_name="texts_$_2",
        ),
        graph.Link(
            source_id=nodes[2].id,
            sink_id=nodes[3].id,
            source_name="output",
            sink_name="text",
        ),
    ]
    test_graph = graph.Graph(
        name="TestGraph",
        version=1,
        description="Test graph",
        nodes=nodes,
        links=links,
    )
    result = await graph.create_graph(test_graph)

    # Assertions
    assert result.name == test_graph.name
    assert result.description == test_graph.description
    assert len(result.nodes) == len(test_graph.nodes)

    return test_graph
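(A hedged reading of the `texts_$_N` names above: the `_$_N` suffix appears to pack each linked value into one list-typed `texts` input on TextFormatterBlock, which the `format` default then indexes. With hypothetical parrot outputs "a" and "b":)

    # texts_$_1 = "a" (via link), texts_$_2 = "b" (via link),
    # texts_$_3 = "!!!" (via input_default)
    #   -> texts == ["a", "b", "!!!"]
    #   -> "{texts[0]},{texts[1]},{texts[2]}".format(texts=["a", "b", "!!!"])
    #      evaluates to "a,b,!!!"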
from autogpt_server.util.test import SpinTestServer, wait_execution
from autogpt_server.usecases.sample import create_test_graph


async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph) -> str:
@@ -77,26 +17,8 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
    graph_exec_id = response["id"]
    assert len(executions) == 2

    async def is_execution_completed():
        execs = await agent_server.get_run_execution_results(
            test_graph.id, graph_exec_id
        )
        return (
            test_manager.queue.empty()
            and len(execs) == 4
            and all(
                exec.status == execution.ExecutionStatus.COMPLETED for exec in execs
            )
        )

    # Wait for the executions to complete
    for i in range(10):
        if await is_execution_completed():
            break
        time.sleep(1)

    # Execution queue should be empty
    assert await is_execution_completed()
    assert await wait_execution(test_manager, test_graph.id, graph_exec_id, 4)
    return graph_exec_id


@@ -107,7 +29,7 @@ async def assert_executions(test_graph: graph.Graph, graph_exec_id: str):
        test_graph.id, graph_exec_id
    )

    # Executing ParrotBlock1
    # Executing ConstantBlock1
    exec = executions[0]
    assert exec.status == execution.ExecutionStatus.COMPLETED
    assert exec.graph_exec_id == graph_exec_id
@@ -115,7 +37,7 @@ async def assert_executions(test_graph: graph.Graph, graph_exec_id: str):
    assert exec.input_data == {"input": text}
    assert exec.node_id == test_graph.nodes[0].id

    # Executing ParrotBlock2
    # Executing ConstantBlock2
    exec = executions[1]
    assert exec.status == execution.ExecutionStatus.COMPLETED
    assert exec.graph_exec_id == graph_exec_id
@@ -145,11 +67,8 @@ async def assert_executions(test_graph: graph.Graph, graph_exec_id: str):

@pytest.mark.asyncio(scope="session")
async def test_agent_execution():
    with PyroNameServer():
        with AgentServer():
            with ExecutionManager(1) as test_manager:
                await db.connect()
                await block.initialize_blocks()
                test_graph = await create_test_graph()
                graph_exec_id = await execute_graph(test_manager, test_graph)
                await assert_executions(test_graph, graph_exec_id)
    async with SpinTestServer() as server:
        test_graph = create_test_graph()
        await graph.create_graph(test_graph)
        graph_exec_id = await execute_graph(server.exec_manager, test_graph)
        await assert_executions(test_graph, graph_exec_id)

@@ -1,36 +1,34 @@
import pytest

import test_manager
from autogpt_server.executor.scheduler import ExecutionScheduler
from autogpt_server.util.service import PyroNameServer, get_service_client
from autogpt_server.server import AgentServer
from autogpt_server.data import db, graph
from autogpt_server.executor import ExecutionScheduler
from autogpt_server.util.service import get_service_client
from autogpt_server.util.test import SpinTestServer
from autogpt_server.usecases.sample import create_test_graph


@pytest.mark.asyncio(scope="session")
async def test_agent_schedule():
    await test_manager.db.connect()
    test_graph = await test_manager.create_test_graph()
    await db.connect()
    test_graph = await graph.create_graph(create_test_graph())

    with PyroNameServer():
        with AgentServer():
            with ExecutionScheduler():
                scheduler = get_service_client(ExecutionScheduler)
    async with SpinTestServer():
        scheduler = get_service_client(ExecutionScheduler)

                schedules = scheduler.get_execution_schedules(test_graph.id)
                assert len(schedules) == 0
        schedules = scheduler.get_execution_schedules(test_graph.id)
        assert len(schedules) == 0

                schedule_id = scheduler.add_execution_schedule(
                    graph_id=test_graph.id,
                    graph_version=1,
                    cron="0 0 * * *",
                    input_data={"input": "data"},
                )
                assert schedule_id
        schedule_id = scheduler.add_execution_schedule(
            graph_id=test_graph.id,
            graph_version=1,
            cron="0 0 * * *",
            input_data={"input": "data"},
        )
        assert schedule_id

                schedules = scheduler.get_execution_schedules(test_graph.id)
                assert len(schedules) == 1
                assert schedules[schedule_id] == "0 0 * * *"
        schedules = scheduler.get_execution_schedules(test_graph.id)
        assert len(schedules) == 1
        assert schedules[schedule_id] == "0 0 * * *"

                scheduler.update_schedule(schedule_id, is_enabled=False)
                schedules = scheduler.get_execution_schedules(test_graph.id)
                assert len(schedules) == 0
        scheduler.update_schedule(schedule_id, is_enabled=False)
        schedules = scheduler.get_execution_schedules(test_graph.id)
        assert len(schedules) == 0
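(For reference, the cron expression both versions of this test schedule with; this is standard five-field cron syntax rather than anything specific to ExecutionScheduler.)

    # "0 0 * * *" -> minute 0, hour 0, any day of month, any month, any weekday:
    # the scheduled graph run fires once per day at midnight.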