Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-04-08 03:00:28 -04:00.
Add dict_split and list_split to output, add more blocks
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from autogpt_server.blocks import sample, reddit
|
||||
from autogpt_server.blocks import sample, reddit, text, object, ai
|
||||
from autogpt_server.data.block import Block
|
||||
|
||||
AVAILABLE_BLOCKS = {
|
||||
@@ -6,4 +6,4 @@ AVAILABLE_BLOCKS = {
|
||||
for block in [v() for v in Block.__subclasses__()]
|
||||
}
|
||||
|
||||
__all__ = ["sample", "reddit", "AVAILABLE_BLOCKS"]
|
||||
__all__ = ["ai", "object", "sample", "reddit", "text", "AVAILABLE_BLOCKS"]
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
import openai
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlmModel(str, Enum):
    # Supported LLM model identifiers; each value is the model name passed
    # verbatim to the OpenAI API (`model=` argument of the chat call).
    openai_gpt4 = "gpt-4-turbo"
|
||||
|
||||
|
||||
class LlmConfig(BaseModel):
    # Connection settings for a single LLM call.
    model: LlmModel  # which model to invoke
    api_key: str  # OpenAI API key used for the request
|
||||
|
||||
|
||||
class LlmCallBlock(Block):
    """Block that calls an OpenAI chat model and parses its JSON reply.

    The caller supplies ``expected_format``, a mapping of required JSON
    key -> description; the system prompt instructs the model to answer
    with a JSON object containing exactly those keys.  The call is retried
    up to ``retry`` times — each retry feeds the model its previous invalid
    answer plus the parse error — before giving up with an ``error`` output.
    """

    class Input(BlockSchema):
        config: LlmConfig  # model + API key
        expected_format: dict[str, str]  # required JSON keys -> descriptions
        sys_prompt: str = ""  # prepended to the generated format instructions
        usr_prompt: str = ""  # the user-level request
        retry: int = 3  # max attempts before yielding an error

    class Output(BlockSchema):
        response: dict[str, str]  # parsed JSON object on success
        error: str  # last error-carrying prompt when all attempts fail

    def __init__(self):
        super().__init__(
            id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
            input_schema=LlmCallBlock.Input,
            output_schema=LlmCallBlock.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        openai.api_key = input_data.config.api_key
        expected_format = [f'"{k}": "{v}"' for k, v in
                           input_data.expected_format.items()]

        # BUGFIX: hoist the join out of the f-string — backslashes inside
        # f-string expressions are a SyntaxError before Python 3.12.
        format_lines = ",\n  ".join(expected_format)
        sys_prompt = f"""
          |{input_data.sys_prompt}
          |
          |Reply in json format:
          |{{
          |  {format_lines}
          |}}
        """
        usr_prompt = f"""
          |{input_data.usr_prompt}
        """

        def trim_prompt(s: str) -> str:
            # Drop surrounding whitespace and the leading "|" margin
            # marker from every line of the template above.
            return "\n".join([line.strip()[1:] for line in s.strip().split("\n")])

        def parse_response(resp: str) -> tuple[dict[str, str], str | None]:
            # Returns (parsed_dict, error_message or None). A reply that
            # parses but misses required keys is still returned, with an error.
            try:
                parsed = json.loads(resp)
                miss_keys = set(input_data.expected_format.keys()) - set(parsed.keys())
                if miss_keys:
                    return parsed, f"Missing keys: {miss_keys}"
                return parsed, None
            except json.JSONDecodeError as e:
                return {}, f"JSON decode error: {e}"

        prompt = [
            {"role": "system", "content": trim_prompt(sys_prompt)},
            {"role": "user", "content": trim_prompt(usr_prompt)},
        ]

        logger.warning(f"LLM request: {prompt}")
        for retry_count in range(input_data.retry):
            response = openai.chat.completions.create(
                model=input_data.config.model,
                messages=prompt,  # type: ignore
                response_format={"type": "json_object"},
            )
            response_text = response.choices[0].message.content or ""
            logger.warning(f"LLM attempt-{retry_count} response: {response_text}")

            parsed_dict, parsed_error = parse_response(response_text)
            if not parsed_error:
                yield "response", parsed_dict
                # BUGFIX: return instead of break — the original fell
                # through to the trailing `yield "error", ...` below even
                # after a successful response.
                return

            retry_prompt = f"""
              |This is your previous error response:
              |--
              |{response_text}
              |--
              |
              |And this is the error:
              |--
              |{parsed_error}
              |--
            """
            prompt.append({"role": "user", "content": trim_prompt(retry_prompt)})

        # All retries exhausted: surface the last error-carrying prompt.
        yield "error", prompt[-1]["content"]
|
||||
|
||||
33
rnd/autogpt_server/autogpt_server/blocks/object.py
Normal file
33
rnd/autogpt_server/autogpt_server/blocks/object.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Any
|
||||
|
||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
||||
|
||||
|
||||
class ObjectParser(Block):
    """Block that extracts a nested field from an object by dotted path.

    ``field_path`` such as ``"a.b.c"`` is resolved one segment at a time,
    via dict lookup or attribute access.  On success the resolved value is
    yielded as ``field_value``; if any segment cannot be resolved, the
    original object is yielded as ``error``.
    """

    class Input(BlockSchema):
        object: Any  # source object (dict or arbitrary object)
        field_path: str  # dot-separated path, e.g. "user.address.city"

    class Output(BlockSchema):
        field_value: Any  # value found at field_path
        # BUGFIX: declared so the "error" yield in run() matches the output
        # schema, consistent with other blocks (e.g. LlmCallBlock.Output).
        error: Any

    def __init__(self):
        super().__init__(
            id="be45299a-193b-4852-bda4-510883d21814",
            input_schema=ObjectParser.Input,
            output_schema=ObjectParser.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        field_value = input_data.object
        for field in input_data.field_path.split("."):
            if isinstance(field_value, dict) and field in field_value:
                field_value = field_value.get(field)
            elif hasattr(field_value, field):
                # The original guarded this with `isinstance(field_value,
                # object)`, which is always True — hasattr alone is equivalent.
                field_value = getattr(field_value, field)
            else:
                yield "error", input_data.object
                return

        yield "field_value", field_value
|
||||
@@ -21,25 +21,6 @@ class ParrotBlock(Block):
|
||||
yield "output", input_data.input
|
||||
|
||||
|
||||
class TextFormatterBlock(Block):
    """Block that renders a format template against a list of texts.

    The ``format`` template is applied with ``str.format``, receiving the
    whole input list under the keyword ``texts``; the rendered string is
    yielded as ``combined_text``.
    """

    class Input(BlockSchema):
        texts: list[str]
        format: str

    class Output(BlockSchema):
        combined_text: str

    def __init__(self):
        super().__init__(
            id="db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
            input_schema=TextFormatterBlock.Input,
            output_schema=TextFormatterBlock.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        template = input_data.format
        rendered = template.format(texts=input_data.texts)
        yield "combined_text", rendered
|
||||
|
||||
|
||||
class PrintingBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str
|
||||
|
||||
52
rnd/autogpt_server/autogpt_server/blocks/text.py
Normal file
52
rnd/autogpt_server/autogpt_server/blocks/text.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import re
|
||||
|
||||
from typing import Any
|
||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
||||
|
||||
|
||||
class TextMatcherBlock(Block):
    """Block that routes a payload based on a regex match against text.

    If ``match`` (a regular expression) is found anywhere in ``text``,
    the payload is yielded as ``positive``; otherwise as ``negative``.
    The payload is ``data`` when provided, else the text itself.
    """

    class Input(BlockSchema):
        text: str  # text searched for the pattern
        match: str  # regular expression passed to re.search
        data: Any | None = None  # optional payload forwarded instead of text

    class Output(BlockSchema):
        positive: Any  # payload, when the pattern matched
        negative: Any  # payload, when it did not

    def __init__(self):
        super().__init__(
            id="3060088f-6ed9-4928-9ba7-9c92823a7ccd",
            input_schema=TextMatcherBlock.Input,
            output_schema=TextMatcherBlock.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        # BUGFIX: explicit None check — the original `input_data.data or
        # input_data.text` silently discarded falsy payloads (0, "", [], {}).
        output = input_data.data if input_data.data is not None else input_data.text
        if re.search(input_data.match, input_data.text):
            yield "positive", output
        else:
            yield "negative", output
|
||||
|
||||
|
||||
class TextFormatterBlock(Block):
    """Block that renders a format template from list and named inputs.

    ``format`` is rendered with ``str.format``: the whole ``texts`` list is
    available as ``{texts}`` and every entry of ``named_texts`` as its own
    ``{<key>}`` placeholder.  The rendered string is yielded as ``output``.
    """

    class Input(BlockSchema):
        texts: list[str] = []
        named_texts: dict[str, str] = {}
        format: str

    class Output(BlockSchema):
        output: str

    def __init__(self):
        super().__init__(
            id="db7d8f02-2f44-4c55-ab7a-eae0941f0c30",
            input_schema=TextFormatterBlock.Input,
            output_schema=TextFormatterBlock.Output,
        )

    def run(self, input_data: Input) -> BlockOutput:
        template = input_data.format
        rendered = template.format(texts=input_data.texts, **input_data.named_texts)
        yield "output", rendered
|
||||
@@ -244,6 +244,18 @@ LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
|
||||
|
||||
def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None:
    """Extract (part of) an execution output addressed by `name`.

    `output` is a `(output_name, output_data)` pair.  `name` may be the
    plain output name, `<output_name>_$_<index>` to pick a list element,
    or `<output_name>_#_<key>` to pick a dict entry.  Returns None when
    `name` does not address this output or the element/key is absent.
    """
    output_name, output_data = output
    if name == output_name:
        return output_data
    if isinstance(output_data, list) and name.startswith(f"{output_name}{LIST_SPLIT}"):
        # Robustness fix: a malformed or out-of-range index now returns
        # None (name addresses nothing) instead of raising ValueError /
        # IndexError into the executor; callers already treat None as
        # "not for this node".
        try:
            index = int(name.split(LIST_SPLIT)[1])
        except ValueError:
            return None
        if -len(output_data) <= index < len(output_data):
            return output_data[index]
        return None
    if isinstance(output_data, dict) and name.startswith(f"{output_name}{DICT_SPLIT}"):
        # dict.get: a missing key yields None rather than KeyError.
        return output_data.get(name.split(DICT_SPLIT)[1])
    return None
|
||||
|
||||
|
||||
def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
|
||||
# Merge all input with <input_name>_$_<index> into a single list.
|
||||
list_input = []
|
||||
@@ -259,7 +271,7 @@ def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
|
||||
for name, value, _ in sorted(list_input, key=lambda x: x[2]):
|
||||
data[name] = data.get(name, [])
|
||||
data[name].append(value)
|
||||
|
||||
|
||||
# Merge all input with <input_name>_#_<index> into a single dict.
|
||||
for key, value in data.items():
|
||||
if DICT_SPLIT not in key:
|
||||
|
||||
@@ -9,6 +9,7 @@ from autogpt_server.data.execution import (
|
||||
create_graph_execution,
|
||||
get_node_execution_input,
|
||||
merge_execution_input,
|
||||
parse_execution_output,
|
||||
update_execution_status as execution_update,
|
||||
upsert_execution_output,
|
||||
upsert_execution_input,
|
||||
@@ -98,13 +99,11 @@ def enqueue_next_nodes(
|
||||
prefix = get_log_prefix(graph_exec_id, node.id)
|
||||
node_id = node.id
|
||||
|
||||
# Try to enqueue next eligible nodes
|
||||
next_node_ids = [nid for name, nid in node.output_nodes if name == output_name]
|
||||
if not next_node_ids:
|
||||
logger.error(f"{prefix} Output [{output_name}] has no subsequent node.")
|
||||
return []
|
||||
def validate_next_node_execution(next_output_name: str, next_node_id: str):
|
||||
next_data = parse_execution_output((output_name, output_data), next_output_name)
|
||||
if next_data is None:
|
||||
return
|
||||
|
||||
def validate_node_execution(next_node_id: str):
|
||||
next_node = wait(get_node(next_node_id))
|
||||
if not next_node:
|
||||
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
|
||||
@@ -117,7 +116,7 @@ def enqueue_next_nodes(
|
||||
node_id=next_node_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
input_name=next_node_input_name,
|
||||
data=output_data
|
||||
data=next_data
|
||||
))
|
||||
|
||||
next_node_input = wait(get_node_execution_input(next_node_exec_id))
|
||||
@@ -135,9 +134,14 @@ def enqueue_next_nodes(
|
||||
)
|
||||
|
||||
executions = []
|
||||
for nid in next_node_ids:
|
||||
if execution := validate_node_execution(nid):
|
||||
for name, nid in node.output_nodes:
|
||||
if execution := validate_next_node_execution(name, nid):
|
||||
executions.append(execution)
|
||||
|
||||
if not executions:
|
||||
logger.error(f"{prefix} Output [{output_name}] has no subsequent node.")
|
||||
return []
|
||||
|
||||
return executions
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ from autogpt_server.data import block, db, execution, graph
|
||||
from autogpt_server.executor import ExecutionManager
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.util.service import PyroNameServer
|
||||
from autogpt_server.blocks.sample import ParrotBlock, TextFormatterBlock, PrintingBlock
|
||||
from autogpt_server.blocks.sample import ParrotBlock, PrintingBlock
|
||||
from autogpt_server.blocks.text import TextFormatterBlock
|
||||
|
||||
|
||||
async def create_test_graph() -> graph.Graph:
|
||||
@@ -30,7 +31,7 @@ async def create_test_graph() -> graph.Graph:
|
||||
]
|
||||
nodes[0].connect(nodes[2], "output", "texts_$_1")
|
||||
nodes[1].connect(nodes[2], "output", "texts_$_2")
|
||||
nodes[2].connect(nodes[3], "combined_text", "text")
|
||||
nodes[2].connect(nodes[3], "output", "text")
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
@@ -91,7 +92,7 @@ async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph)
|
||||
exec = executions[2]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!,!!!"]}
|
||||
assert exec.output_data == {"output": ["Hello, World!,Hello, World!,!!!"]}
|
||||
assert exec.input_data == {
|
||||
"texts_$_1": "Hello, World!",
|
||||
"texts_$_2": "Hello, World!",
|
||||
|
||||
Reference in New Issue
Block a user