Compare commits

...

7 Commits

Author SHA1 Message Date
Aarushi
64ecd28804 Merge branch 'master' into zamilmajdy/simplify-ai-block 2024-07-10 13:17:09 +01:00
Toran Bruce Richards
cac41edafc feat: Update LlmCallBlock to handle different response formats 2024-07-10 12:00:56 +01:00
Toran Bruce Richards
6cd5007857 Update response_format in LlmCallBlock to handle json parameter 2024-07-10 11:49:45 +01:00
Zamil Majdy
d1badceb34 Merge remote-tracking branch 'origin/zamilmajdy/simplify-ai-block' into zamilmajdy/simplify-ai-block 2024-07-10 17:02:32 +07:00
Zamil Majdy
c24cfc3718 feat(rnd): Simplify AI block 2024-07-10 17:02:19 +07:00
Zamil Majdy
fb7480304a Merge branch 'master' into zamilmajdy/simplify-ai-block 2024-07-10 13:58:55 +04:00
Zamil Majdy
eb097eefab feat(rnd): Simplify AI block 2024-07-10 16:57:42 +07:00
3 changed files with 49 additions and 60 deletions

View File

@@ -1,9 +1,7 @@
import logging import logging
from enum import Enum
import openai import openai
from pydantic import BaseModel
from enum import Enum
from autogpt_server.data.block import Block, BlockOutput, BlockSchema from autogpt_server.data.block import Block, BlockOutput, BlockSchema
from autogpt_server.util import json from autogpt_server.util import json
@@ -14,17 +12,13 @@ class LlmModel(str, Enum):
openai_gpt4 = "gpt-4-turbo" openai_gpt4 = "gpt-4-turbo"
class LlmConfig(BaseModel):
model: LlmModel
api_key: str
class LlmCallBlock(Block): class LlmCallBlock(Block):
class Input(BlockSchema): class Input(BlockSchema):
config: LlmConfig api_key: str
expected_format: dict[str, str] prompt: str
sys_prompt: str = "" sys_prompt: str = ""
usr_prompt: str = "" expected_format: dict[str, str] = {}
model: LlmModel = LlmModel.openai_gpt4
retry: int = 3 retry: int = 3
class Output(BlockSchema): class Output(BlockSchema):
@@ -37,18 +31,15 @@ class LlmCallBlock(Block):
input_schema=LlmCallBlock.Input, input_schema=LlmCallBlock.Input,
output_schema=LlmCallBlock.Output, output_schema=LlmCallBlock.Output,
test_input={ test_input={
"config": { "model": "gpt-4-turbo",
"model": "gpt-4-turbo", "api_key": "fake-api",
"api_key": "fake-api",
},
"expected_format": { "expected_format": {
"key1": "value1", "key1": "value1",
"key2": "value2", "key2": "value2",
}, },
"sys_prompt": "System prompt", "prompt": "User prompt",
"usr_prompt": "User prompt",
}, },
test_output=("response", {"key1": "key1Value","key2": "key2Value"}), test_output=("response", {"key1": "key1Value", "key2": "key2Value"}),
test_mock={"llm_call": lambda *args, **kwargs: json.dumps({ test_mock={"llm_call": lambda *args, **kwargs: json.dumps({
"key1": "key1Value", "key1": "key1Value",
"key2": "key2Value", "key2": "key2Value",
@@ -56,36 +47,40 @@ class LlmCallBlock(Block):
) )
@staticmethod @staticmethod
def llm_call(api_key: str, model: LlmModel, prompt: list[dict]) -> str: def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json: bool) -> str:
openai.api_key = api_key openai.api_key = api_key
response = openai.chat.completions.create( response = openai.chat.completions.create(
model=model, model=model,
messages=prompt, # type: ignore messages=prompt, # type: ignore
response_format={"type": "json_object"}, response_format={"type": "json_object"} if json else None,
) )
return response.choices[0].message.content or "" return response.choices[0].message.content or ""
def run(self, input_data: Input) -> BlockOutput: def run(self, input_data: Input) -> BlockOutput:
expected_format = [f'"{k}": "{v}"' for k, v in prompt = []
input_data.expected_format.items()]
format_prompt = ",\n ".join(expected_format)
sys_prompt = f"""
|{input_data.sys_prompt}
|
|Reply in json format:
|{{
| {format_prompt}
|}}
"""
usr_prompt = f"""
|{input_data.usr_prompt}
"""
def trim_prompt(s: str) -> str: def trim_prompt(s: str) -> str:
lines = s.strip().split("\n") lines = s.strip().split("\n")
return "\n".join([line.strip().lstrip("|") for line in lines]) return "\n".join([line.strip().lstrip("|") for line in lines])
if input_data.sys_prompt:
prompt.append({"role": "system", "content": input_data.sys_prompt})
if input_data.expected_format:
expected_format = [f'"{k}": "{v}"' for k, v in
input_data.expected_format.items()]
format_prompt = ",\n ".join(expected_format)
sys_prompt = f"""
|Reply in json format:
|{{
| {format_prompt}
|}}
"""
prompt.append({"role": "system", "content": trim_prompt(sys_prompt)})
prompt.append({"role": "user", "content": input_data.prompt})
def parse_response(resp: str) -> tuple[dict[str, str], str | None]: def parse_response(resp: str) -> tuple[dict[str, str], str | None]:
try: try:
parsed = json.loads(resp) parsed = json.loads(resp)
@@ -96,24 +91,24 @@ class LlmCallBlock(Block):
except Exception as e: except Exception as e:
return {}, f"JSON decode error: {e}" return {}, f"JSON decode error: {e}"
prompt = [
{"role": "system", "content": trim_prompt(sys_prompt)},
{"role": "user", "content": trim_prompt(usr_prompt)},
]
logger.warning(f"LLM request: {prompt}") logger.warning(f"LLM request: {prompt}")
retry_prompt = "" retry_prompt = ""
for retry_count in range(input_data.retry): for retry_count in range(input_data.retry):
response_text = self.llm_call( response_text = self.llm_call(
input_data.config.api_key, api_key=input_data.api_key,
input_data.config.model, model=input_data.model,
prompt prompt=prompt,
json=bool(input_data.expected_format)
) )
logger.warning(f"LLM attempt-{retry_count} response: {response_text}") logger.warning(f"LLM attempt-{retry_count} response: {response_text}")
parsed_dict, parsed_error = parse_response(response_text) if input_data.expected_format:
if not parsed_error: parsed_dict, parsed_error = parse_response(response_text)
yield "response", {k: str(v) for k, v in parsed_dict.items()} if not parsed_error:
yield "response", {k: str(v) for k, v in parsed_dict.items()}
return
else:
yield "response", {"response": response_text}
return return
retry_prompt = f""" retry_prompt = f"""

View File

@@ -138,7 +138,7 @@ def reddit(
import requests import requests
from autogpt_server.data.graph import Graph, Link, Node from autogpt_server.data.graph import Graph, Link, Node
from autogpt_server.blocks.ai import LlmConfig, LlmCallBlock, LlmModel from autogpt_server.blocks.ai import LlmCallBlock, LlmModel
from autogpt_server.blocks.reddit import ( from autogpt_server.blocks.reddit import (
RedditCredentials, RedditCredentials,
RedditGetPostsBlock, RedditGetPostsBlock,
@@ -153,10 +153,7 @@ def reddit(
password=password, password=password,
user_agent=user_agent, user_agent=user_agent,
) )
openai_creds = LlmConfig( openai_api_key = "TODO_FILL_OUT_THIS"
model=LlmModel.openai_gpt4,
api_key="TODO_FILL_OUT_THIS",
)
# Hardcoded inputs # Hardcoded inputs
reddit_get_post_input = { reddit_get_post_input = {
@@ -179,7 +176,7 @@ The product you are marketing is: Auto-GPT an autonomous AI agent utilizing GPT
You reply the post that you find it relevant to be replied with marketing text. You reply the post that you find it relevant to be replied with marketing text.
Make sure to only comment on a relevant post. Make sure to only comment on a relevant post.
""", """,
"config": openai_creds, "api_key": openai_api_key,
"expected_format": { "expected_format": {
"post_id": "str, the reddit post id", "post_id": "str, the reddit post id",
"is_relevant": "bool, whether the post is relevant for marketing", "is_relevant": "bool, whether the post is relevant for marketing",
@@ -219,7 +216,7 @@ Make sure to only comment on a relevant post.
# Links # Links
links = [ links = [
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"), Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
Link(text_formatter_node.id, llm_call_node.id, "output", "usr_prompt"), Link(text_formatter_node.id, llm_call_node.id, "output", "prompt"),
Link(llm_call_node.id, text_matcher_node.id, "response", "data"), Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"), Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
Link( Link(

View File

@@ -2,7 +2,7 @@ import time
from autogpt_server.data import block, db from autogpt_server.data import block, db
from autogpt_server.data.graph import Graph, Link, Node, create_graph from autogpt_server.data.graph import Graph, Link, Node, create_graph
from autogpt_server.data.execution import ExecutionStatus from autogpt_server.data.execution import ExecutionStatus
from autogpt_server.blocks.ai import LlmConfig, LlmCallBlock, LlmModel from autogpt_server.blocks.ai import LlmCallBlock, LlmModel
from autogpt_server.blocks.reddit import ( from autogpt_server.blocks.reddit import (
RedditCredentials, RedditCredentials,
RedditGetPostsBlock, RedditGetPostsBlock,
@@ -27,10 +27,7 @@ async def create_test_graph() -> Graph:
password="TODO_FILL_OUT_THIS", password="TODO_FILL_OUT_THIS",
user_agent="TODO_FILL_OUT_THIS", user_agent="TODO_FILL_OUT_THIS",
) )
openai_creds = LlmConfig( openai_api_key = "TODO_FILL_OUT_THIS"
model=LlmModel.openai_gpt4,
api_key="TODO_FILL_OUT_THIS",
)
# Hardcoded inputs # Hardcoded inputs
reddit_get_post_input = { reddit_get_post_input = {
@@ -53,7 +50,7 @@ The product you are marketing is: Auto-GPT an autonomous AI agent utilizing GPT
You reply the post that you find it relevant to be replied with marketing text. You reply the post that you find it relevant to be replied with marketing text.
Make sure to only comment on a relevant post. Make sure to only comment on a relevant post.
""", """,
"config": openai_creds, "api_key": openai_api_key,
"expected_format": { "expected_format": {
"post_id": "str, the reddit post id", "post_id": "str, the reddit post id",
"is_relevant": "bool, whether the post is relevant for marketing", "is_relevant": "bool, whether the post is relevant for marketing",
@@ -96,7 +93,7 @@ Make sure to only comment on a relevant post.
# Links # Links
links = [ links = [
Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"), Link(reddit_get_post_node.id, text_formatter_node.id, "post", "named_texts"),
Link(text_formatter_node.id, llm_call_node.id, "output", "usr_prompt"), Link(text_formatter_node.id, llm_call_node.id, "output", "prompt"),
Link(llm_call_node.id, text_matcher_node.id, "response", "data"), Link(llm_call_node.id, text_matcher_node.id, "response", "data"),
Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"), Link(llm_call_node.id, text_matcher_node.id, "response_#_is_relevant", "text"),
Link(text_matcher_node.id, reddit_comment_node.id, "positive_#_post_id", Link(text_matcher_node.id, reddit_comment_node.id, "positive_#_post_id",