mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-19 20:18:22 -05:00
Compare commits
7 Commits
fix/undefi
...
zamilmajdy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64ecd28804 | ||
|
|
cac41edafc | ||
|
|
6cd5007857 | ||
|
|
d1badceb34 | ||
|
|
c24cfc3718 | ||
|
|
fb7480304a | ||
|
|
eb097eefab |
@@ -1,9 +1,7 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
import openai
|
||||
from pydantic import BaseModel
|
||||
|
||||
from enum import Enum
|
||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
||||
from autogpt_server.util import json
|
||||
|
||||
@@ -14,17 +12,13 @@ class LlmModel(str, Enum):
|
||||
openai_gpt4 = "gpt-4-turbo"
|
||||
|
||||
|
||||
class LlmConfig(BaseModel):
|
||||
model: LlmModel
|
||||
api_key: str
|
||||
|
||||
|
||||
class LlmCallBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
config: LlmConfig
|
||||
expected_format: dict[str, str]
|
||||
api_key: str
|
||||
prompt: str
|
||||
sys_prompt: str = ""
|
||||
usr_prompt: str = ""
|
||||
expected_format: dict[str, str] = {}
|
||||
model: LlmModel = LlmModel.openai_gpt4
|
||||
retry: int = 3
|
||||
|
||||
class Output(BlockSchema):
|
||||
@@ -37,18 +31,15 @@ class LlmCallBlock(Block):
|
||||
input_schema=LlmCallBlock.Input,
|
||||
output_schema=LlmCallBlock.Output,
|
||||
test_input={
|
||||
"config": {
|
||||
"model": "gpt-4-turbo",
|
||||
"api_key": "fake-api",
|
||||
},
|
||||
"model": "gpt-4-turbo",
|
||||
"api_key": "fake-api",
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
},
|
||||
"sys_prompt": "System prompt",
|
||||
"usr_prompt": "User prompt",
|
||||
"prompt": "User prompt",
|
||||
},
|
||||
test_output=("response", {"key1": "key1Value","key2": "key2Value"}),
|
||||
test_output=("response", {"key1": "key1Value", "key2": "key2Value"}),
|
||||
test_mock={"llm_call": lambda *args, **kwargs: json.dumps({
|
||||
"key1": "key1Value",
|
||||
"key2": "key2Value",
|
||||
@@ -56,36 +47,40 @@ class LlmCallBlock(Block):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def llm_call(api_key: str, model: LlmModel, prompt: list[dict]) -> str:
|
||||
def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json: bool) -> str:
|
||||
openai.api_key = api_key
|
||||
response = openai.chat.completions.create(
|
||||
model=model,
|
||||
messages=prompt, # type: ignore
|
||||
response_format={"type": "json_object"},
|
||||
response_format={"type": "json_object"} if json else None,
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
expected_format = [f'"{k}": "{v}"' for k, v in
|
||||
input_data.expected_format.items()]
|
||||
|
||||
format_prompt = ",\n ".join(expected_format)
|
||||
sys_prompt = f"""
|
||||
|{input_data.sys_prompt}
|
||||
|
|
||||
|Reply in json format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
"""
|
||||
usr_prompt = f"""
|
||||
|{input_data.usr_prompt}
|
||||
"""
|
||||
prompt = []
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
|
||||
if input_data.expected_format:
|
||||
expected_format = [f'"{k}": "{v}"' for k, v in
|
||||
input_data.expected_format.items()]
|
||||
|
||||
format_prompt = ",\n ".join(expected_format)
|
||||
sys_prompt = f"""
|
||||
|Reply in json format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
"""
|
||||
prompt.append({"role": "system", "content": trim_prompt(sys_prompt)})
|
||||
|
||||
prompt.append({"role": "user", "content": input_data.prompt})
|
||||
|
||||
def parse_response(resp: str) -> tuple[dict[str, str], str | None]:
|
||||
try:
|
||||
parsed = json.loads(resp)
|
||||
@@ -96,24 +91,24 @@ class LlmCallBlock(Block):
|
||||
except Exception as e:
|
||||
return {}, f"JSON decode error: {e}"
|
||||
|
||||
prompt = [
|
||||
{"role": "system", "content": trim_prompt(sys_prompt)},
|
||||
{"role": "user", "content": trim_prompt(usr_prompt)},
|
||||
]
|
||||
|
||||
logger.warning(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
for retry_count in range(input_data.retry):
|
||||
response_text = self.llm_call(
|
||||
input_data.config.api_key,
|
||||
input_data.config.model,
|
||||
prompt
|
||||
api_key=input_data.api_key,
|
||||
model=input_data.model,
|
||||
prompt=prompt,
|
||||
json=bool(input_data.expected_format)
|
||||
)
|
||||
logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
parsed_dict, parsed_error = parse_response(response_text)
|
||||
if not parsed_error:
|
||||
yield "response", {k: str(v) for k, v in parsed_dict.items()}
|
||||
if input_data.expected_format:
|
||||
parsed_dict, parsed_error = parse_response(response_text)
|
||||
if not parsed_error:
|
||||
yield "response", {k: str(v) for k, v in parsed_dict.items()}
|
||||
return
|
||||
else:
|
||||
yield "response", {"response": response_text}
|
||||
return
|
||||
|
||||
retry_prompt = f"""
|
||||
|
||||
@@ -138,7 +138,7 @@ def reddit(
|
||||
import requests
|
||||
|
||||
from autogpt_server.data.graph import Graph, Link, Node
|
||||
from autogpt_server.blocks.ai import LlmConfig, LlmCallBlock, LlmModel
|
||||
from autogpt_server.blocks.ai import LlmCallBlock, LlmModel
|
||||
from autogpt_server.blocks.reddit import (
|
||||
RedditCredentials,
|
||||
RedditGetPostsBlock,
|
||||
@@ -153,10 +153,7 @@ def reddit(
|
||||
password=password,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
openai_creds = LlmConfig(
|
||||
model=LlmModel.openai_gpt4,
|
||||
api_key="TODO_FILL_OUT_THIS",
|
||||
)
|
||||
openai_api_key = "TODO_FILL_OUT_THIS"
|
||||
|
||||
# Hardcoded inputs
|
||||
reddit_get_post_input = {
|
||||
@@ -179,7 +176,7 @@ The product you are marketing is: Auto-GPT an autonomous AI agent utilizing GPT
|
||||
You reply the post that you find it relevant to be replied with marketing text.
|
||||
Make sure to only comment on a relevant post.
|
||||
""",
|
||||
"config": openai_creds,
|
||||
"api_key": openai_api_key,
|
||||
"expected_format": {
|
||||
"post_id": "str, the reddit post id",
|
||||
"is_relevant": "bool, whether the post is relevant for marketing",
|
||||
@@ -219,7 +216,7 @@ Make sure to only comment on a relevant post.
|
||||
# Links
|
||||
links = [
|
||||
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
|
||||
Link(text_formatter_node.id, llm_call_node.id, "output", "usr_prompt"),
|
||||
Link(text_formatter_node.id, llm_call_node.id, "output", "prompt"),
|
||||
Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
|
||||
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
|
||||
Link(
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
from autogpt_server.data import block, db
|
||||
from autogpt_server.data.graph import Graph, Link, Node, create_graph
|
||||
from autogpt_server.data.execution import ExecutionStatus
|
||||
from autogpt_server.blocks.ai import LlmConfig, LlmCallBlock, LlmModel
|
||||
from autogpt_server.blocks.ai import LlmCallBlock, LlmModel
|
||||
from autogpt_server.blocks.reddit import (
|
||||
RedditCredentials,
|
||||
RedditGetPostsBlock,
|
||||
@@ -27,10 +27,7 @@ async def create_test_graph() -> Graph:
|
||||
password="TODO_FILL_OUT_THIS",
|
||||
user_agent="TODO_FILL_OUT_THIS",
|
||||
)
|
||||
openai_creds = LlmConfig(
|
||||
model=LlmModel.openai_gpt4,
|
||||
api_key="TODO_FILL_OUT_THIS",
|
||||
)
|
||||
openai_api_key = "TODO_FILL_OUT_THIS"
|
||||
|
||||
# Hardcoded inputs
|
||||
reddit_get_post_input = {
|
||||
@@ -53,7 +50,7 @@ The product you are marketing is: Auto-GPT an autonomous AI agent utilizing GPT
|
||||
You reply the post that you find it relevant to be replied with marketing text.
|
||||
Make sure to only comment on a relevant post.
|
||||
""",
|
||||
"config": openai_creds,
|
||||
"api_key": openai_api_key,
|
||||
"expected_format": {
|
||||
"post_id": "str, the reddit post id",
|
||||
"is_relevant": "bool, whether the post is relevant for marketing",
|
||||
@@ -96,7 +93,7 @@ Make sure to only comment on a relevant post.
|
||||
# Links
|
||||
links = [
|
||||
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
|
||||
Link(text_formatter_node.id, llm_call_node.id, "output", "usr_prompt"),
|
||||
Link(text_formatter_node.id, llm_call_node.id, "output", "prompt"),
|
||||
Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
|
||||
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
|
||||
Link(text_matcher_node.id, reddit_comment_node.id, "positive_#_post_id",
|
||||
|
||||
Reference in New Issue
Block a user