mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-12 15:55:03 -05:00
Compare commits
7 Commits
fix/claude
...
zamilmajdy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64ecd28804 | ||
|
|
cac41edafc | ||
|
|
6cd5007857 | ||
|
|
d1badceb34 | ||
|
|
c24cfc3718 | ||
|
|
fb7480304a | ||
|
|
eb097eefab |
@@ -1,9 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
from autogpt_server.data.block import Block, BlockOutput, BlockSchema
|
||||||
from autogpt_server.util import json
|
from autogpt_server.util import json
|
||||||
|
|
||||||
@@ -14,17 +12,13 @@ class LlmModel(str, Enum):
|
|||||||
openai_gpt4 = "gpt-4-turbo"
|
openai_gpt4 = "gpt-4-turbo"
|
||||||
|
|
||||||
|
|
||||||
class LlmConfig(BaseModel):
|
|
||||||
model: LlmModel
|
|
||||||
api_key: str
|
|
||||||
|
|
||||||
|
|
||||||
class LlmCallBlock(Block):
|
class LlmCallBlock(Block):
|
||||||
class Input(BlockSchema):
|
class Input(BlockSchema):
|
||||||
config: LlmConfig
|
api_key: str
|
||||||
expected_format: dict[str, str]
|
prompt: str
|
||||||
sys_prompt: str = ""
|
sys_prompt: str = ""
|
||||||
usr_prompt: str = ""
|
expected_format: dict[str, str] = {}
|
||||||
|
model: LlmModel = LlmModel.openai_gpt4
|
||||||
retry: int = 3
|
retry: int = 3
|
||||||
|
|
||||||
class Output(BlockSchema):
|
class Output(BlockSchema):
|
||||||
@@ -37,18 +31,15 @@ class LlmCallBlock(Block):
|
|||||||
input_schema=LlmCallBlock.Input,
|
input_schema=LlmCallBlock.Input,
|
||||||
output_schema=LlmCallBlock.Output,
|
output_schema=LlmCallBlock.Output,
|
||||||
test_input={
|
test_input={
|
||||||
"config": {
|
"model": "gpt-4-turbo",
|
||||||
"model": "gpt-4-turbo",
|
"api_key": "fake-api",
|
||||||
"api_key": "fake-api",
|
|
||||||
},
|
|
||||||
"expected_format": {
|
"expected_format": {
|
||||||
"key1": "value1",
|
"key1": "value1",
|
||||||
"key2": "value2",
|
"key2": "value2",
|
||||||
},
|
},
|
||||||
"sys_prompt": "System prompt",
|
"prompt": "User prompt",
|
||||||
"usr_prompt": "User prompt",
|
|
||||||
},
|
},
|
||||||
test_output=("response", {"key1": "key1Value","key2": "key2Value"}),
|
test_output=("response", {"key1": "key1Value", "key2": "key2Value"}),
|
||||||
test_mock={"llm_call": lambda *args, **kwargs: json.dumps({
|
test_mock={"llm_call": lambda *args, **kwargs: json.dumps({
|
||||||
"key1": "key1Value",
|
"key1": "key1Value",
|
||||||
"key2": "key2Value",
|
"key2": "key2Value",
|
||||||
@@ -56,36 +47,40 @@ class LlmCallBlock(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def llm_call(api_key: str, model: LlmModel, prompt: list[dict]) -> str:
|
def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json: bool) -> str:
|
||||||
openai.api_key = api_key
|
openai.api_key = api_key
|
||||||
response = openai.chat.completions.create(
|
response = openai.chat.completions.create(
|
||||||
model=model,
|
model=model,
|
||||||
messages=prompt, # type: ignore
|
messages=prompt, # type: ignore
|
||||||
response_format={"type": "json_object"},
|
response_format={"type": "json_object"} if json else None,
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content or ""
|
return response.choices[0].message.content or ""
|
||||||
|
|
||||||
def run(self, input_data: Input) -> BlockOutput:
|
def run(self, input_data: Input) -> BlockOutput:
|
||||||
expected_format = [f'"{k}": "{v}"' for k, v in
|
prompt = []
|
||||||
input_data.expected_format.items()]
|
|
||||||
|
|
||||||
format_prompt = ",\n ".join(expected_format)
|
|
||||||
sys_prompt = f"""
|
|
||||||
|{input_data.sys_prompt}
|
|
||||||
|
|
|
||||||
|Reply in json format:
|
|
||||||
|{{
|
|
||||||
| {format_prompt}
|
|
||||||
|}}
|
|
||||||
"""
|
|
||||||
usr_prompt = f"""
|
|
||||||
|{input_data.usr_prompt}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def trim_prompt(s: str) -> str:
|
def trim_prompt(s: str) -> str:
|
||||||
lines = s.strip().split("\n")
|
lines = s.strip().split("\n")
|
||||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||||
|
|
||||||
|
if input_data.sys_prompt:
|
||||||
|
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||||
|
|
||||||
|
if input_data.expected_format:
|
||||||
|
expected_format = [f'"{k}": "{v}"' for k, v in
|
||||||
|
input_data.expected_format.items()]
|
||||||
|
|
||||||
|
format_prompt = ",\n ".join(expected_format)
|
||||||
|
sys_prompt = f"""
|
||||||
|
|Reply in json format:
|
||||||
|
|{{
|
||||||
|
| {format_prompt}
|
||||||
|
|}}
|
||||||
|
"""
|
||||||
|
prompt.append({"role": "system", "content": trim_prompt(sys_prompt)})
|
||||||
|
|
||||||
|
prompt.append({"role": "user", "content": input_data.prompt})
|
||||||
|
|
||||||
def parse_response(resp: str) -> tuple[dict[str, str], str | None]:
|
def parse_response(resp: str) -> tuple[dict[str, str], str | None]:
|
||||||
try:
|
try:
|
||||||
parsed = json.loads(resp)
|
parsed = json.loads(resp)
|
||||||
@@ -96,24 +91,24 @@ class LlmCallBlock(Block):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {}, f"JSON decode error: {e}"
|
return {}, f"JSON decode error: {e}"
|
||||||
|
|
||||||
prompt = [
|
|
||||||
{"role": "system", "content": trim_prompt(sys_prompt)},
|
|
||||||
{"role": "user", "content": trim_prompt(usr_prompt)},
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.warning(f"LLM request: {prompt}")
|
logger.warning(f"LLM request: {prompt}")
|
||||||
retry_prompt = ""
|
retry_prompt = ""
|
||||||
for retry_count in range(input_data.retry):
|
for retry_count in range(input_data.retry):
|
||||||
response_text = self.llm_call(
|
response_text = self.llm_call(
|
||||||
input_data.config.api_key,
|
api_key=input_data.api_key,
|
||||||
input_data.config.model,
|
model=input_data.model,
|
||||||
prompt
|
prompt=prompt,
|
||||||
|
json=bool(input_data.expected_format)
|
||||||
)
|
)
|
||||||
logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
|
logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
|
||||||
|
|
||||||
parsed_dict, parsed_error = parse_response(response_text)
|
if input_data.expected_format:
|
||||||
if not parsed_error:
|
parsed_dict, parsed_error = parse_response(response_text)
|
||||||
yield "response", {k: str(v) for k, v in parsed_dict.items()}
|
if not parsed_error:
|
||||||
|
yield "response", {k: str(v) for k, v in parsed_dict.items()}
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
yield "response", {"response": response_text}
|
||||||
return
|
return
|
||||||
|
|
||||||
retry_prompt = f"""
|
retry_prompt = f"""
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ def reddit(
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from autogpt_server.data.graph import Graph, Link, Node
|
from autogpt_server.data.graph import Graph, Link, Node
|
||||||
from autogpt_server.blocks.ai import LlmConfig, LlmCallBlock, LlmModel
|
from autogpt_server.blocks.ai import LlmCallBlock, LlmModel
|
||||||
from autogpt_server.blocks.reddit import (
|
from autogpt_server.blocks.reddit import (
|
||||||
RedditCredentials,
|
RedditCredentials,
|
||||||
RedditGetPostsBlock,
|
RedditGetPostsBlock,
|
||||||
@@ -153,10 +153,7 @@ def reddit(
|
|||||||
password=password,
|
password=password,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
openai_creds = LlmConfig(
|
openai_api_key = "TODO_FILL_OUT_THIS"
|
||||||
model=LlmModel.openai_gpt4,
|
|
||||||
api_key="TODO_FILL_OUT_THIS",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hardcoded inputs
|
# Hardcoded inputs
|
||||||
reddit_get_post_input = {
|
reddit_get_post_input = {
|
||||||
@@ -179,7 +176,7 @@ The product you are marketing is: Auto-GPT an autonomous AI agent utilizing GPT
|
|||||||
You reply the post that you find it relevant to be replied with marketing text.
|
You reply the post that you find it relevant to be replied with marketing text.
|
||||||
Make sure to only comment on a relevant post.
|
Make sure to only comment on a relevant post.
|
||||||
""",
|
""",
|
||||||
"config": openai_creds,
|
"api_key": openai_api_key,
|
||||||
"expected_format": {
|
"expected_format": {
|
||||||
"post_id": "str, the reddit post id",
|
"post_id": "str, the reddit post id",
|
||||||
"is_relevant": "bool, whether the post is relevant for marketing",
|
"is_relevant": "bool, whether the post is relevant for marketing",
|
||||||
@@ -219,7 +216,7 @@ Make sure to only comment on a relevant post.
|
|||||||
# Links
|
# Links
|
||||||
links = [
|
links = [
|
||||||
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
|
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
|
||||||
Link(text_formatter_node.id, llm_call_node.id, "output", "usr_prompt"),
|
Link(text_formatter_node.id, llm_call_node.id, "output", "prompt"),
|
||||||
Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
|
Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
|
||||||
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
|
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
|
||||||
Link(
|
Link(
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import time
|
|||||||
from autogpt_server.data import block, db
|
from autogpt_server.data import block, db
|
||||||
from autogpt_server.data.graph import Graph, Link, Node, create_graph
|
from autogpt_server.data.graph import Graph, Link, Node, create_graph
|
||||||
from autogpt_server.data.execution import ExecutionStatus
|
from autogpt_server.data.execution import ExecutionStatus
|
||||||
from autogpt_server.blocks.ai import LlmConfig, LlmCallBlock, LlmModel
|
from autogpt_server.blocks.ai import LlmCallBlock, LlmModel
|
||||||
from autogpt_server.blocks.reddit import (
|
from autogpt_server.blocks.reddit import (
|
||||||
RedditCredentials,
|
RedditCredentials,
|
||||||
RedditGetPostsBlock,
|
RedditGetPostsBlock,
|
||||||
@@ -27,10 +27,7 @@ async def create_test_graph() -> Graph:
|
|||||||
password="TODO_FILL_OUT_THIS",
|
password="TODO_FILL_OUT_THIS",
|
||||||
user_agent="TODO_FILL_OUT_THIS",
|
user_agent="TODO_FILL_OUT_THIS",
|
||||||
)
|
)
|
||||||
openai_creds = LlmConfig(
|
openai_api_key = "TODO_FILL_OUT_THIS"
|
||||||
model=LlmModel.openai_gpt4,
|
|
||||||
api_key="TODO_FILL_OUT_THIS",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hardcoded inputs
|
# Hardcoded inputs
|
||||||
reddit_get_post_input = {
|
reddit_get_post_input = {
|
||||||
@@ -53,7 +50,7 @@ The product you are marketing is: Auto-GPT an autonomous AI agent utilizing GPT
|
|||||||
You reply the post that you find it relevant to be replied with marketing text.
|
You reply the post that you find it relevant to be replied with marketing text.
|
||||||
Make sure to only comment on a relevant post.
|
Make sure to only comment on a relevant post.
|
||||||
""",
|
""",
|
||||||
"config": openai_creds,
|
"api_key": openai_api_key,
|
||||||
"expected_format": {
|
"expected_format": {
|
||||||
"post_id": "str, the reddit post id",
|
"post_id": "str, the reddit post id",
|
||||||
"is_relevant": "bool, whether the post is relevant for marketing",
|
"is_relevant": "bool, whether the post is relevant for marketing",
|
||||||
@@ -96,7 +93,7 @@ Make sure to only comment on a relevant post.
|
|||||||
# Links
|
# Links
|
||||||
links = [
|
links = [
|
||||||
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
|
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
|
||||||
Link(text_formatter_node.id, llm_call_node.id, "output", "usr_prompt"),
|
Link(text_formatter_node.id, llm_call_node.id, "output", "prompt"),
|
||||||
Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
|
Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
|
||||||
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
|
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
|
||||||
Link(text_matcher_node.id, reddit_comment_node.id, "positive_#_post_id",
|
Link(text_matcher_node.id, reddit_comment_node.id, "positive_#_post_id",
|
||||||
|
|||||||
Reference in New Issue
Block a user