mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 23:28:07 -05:00
feature(platform) Smart Decision Maker Block (#9490)
## Task

The SmartDecisionMakerBlock is a specialized block in a graph-based system that leverages a language model (LLM) to make intelligent decisions about which tools or functions to invoke based on a user-provided prompt. It is designed to process input data, interact with a language model, and dynamically determine the appropriate tools to call from a set of available options, making it a powerful component for AI-driven workflows.

## How It Works in Practice

- **Scenario:** Imagine a workflow where a user inputs, "Send an email to John about the meeting." The SmartDecisionMakerBlock is connected to tools like send_email, schedule_meeting, and search_contacts.
- **Execution:**
  1. The block receives the prompt and system instructions (e.g., "Choose a function to call").
  2. It identifies the available tools from the graph and constructs their signatures (e.g., send_email(recipient, subject, body)).
  3. The LLM analyzes the prompt and decides to call send_email with arguments like recipient: "John", subject: "Meeting", body: "Let's discuss...".
  4. The block yields these tool-specific outputs, which can be picked up by downstream nodes to execute the email-sending action.

A hedged sketch of this flow is shown after this description.

## Changes 🏗️

- Add the Smart Decision Maker (SDM) block.
- Break circular imports in integration code.

## Work in Progress

⚠️ **Important note: this is a temporary UX for the system; UX will be addressed in a future PR** ⚠️

### Current Status

I'm currently focused on the smart decision logic. The main additions in the ongoing PR include:

- Defining function signatures for OpenAI function-calling schemas based on node links and the linked blocks.
- Adding tests for function signature generation.
- Forcing all tool calls to be made via an agent (currently commented out; needs uncommenting).
- Restricting each tool call entry to a single node.
- Simplifying the output emission process to emit each parameter one at a time.
- Changing the tests to use agents and hardcoding the output as I think it should work, to verify it actually works.
- Hooking up OpenAI in a simplified way to test the function calling (mocked for testing).
- Once all of the above is working, using the credentials system and building on llm.py.

### What's Next

- Review process.

### Reviewers Phase 1

This PR is now ready for review. During the first phase of reviews I'm looking for comments on approach and logic. Out of scope at this stage: code style and organization.

### Reviewers Phase 2

Once we are all happy with the approach and logic, we can open the review process to general code quality and nits.

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
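To make the email scenario concrete, here is a minimal, self-contained sketch of the data flow: an OpenAI-style function signature the block might construct for the hypothetical `send_email` tool, and the per-argument outputs emitted on `tools_^_<tool>_<arg>` pins (the naming convention introduced by this PR). The tool name, parameters, and values are illustrative only, not part of the real codebase.

```python
import json

# Illustrative only: an OpenAI function-calling signature the SDM block might
# construct for a hypothetical "send_email" tool linked in the graph.
send_email_signature = {
    "type": "function",
    "function": {
        "name": "send_email",
        "description": "Send an email to a contact.",
        "parameters": {
            "type": "object",
            "properties": {
                "recipient": {"type": "string", "description": "The recipient of the email"},
                "subject": {"type": "string", "description": "The subject of the email"},
                "body": {"type": "string", "description": "The body of the email"},
            },
            "required": [],
            "additionalProperties": False,
        },
    },
}

# A tool call as the LLM might return it (arguments arrive as a JSON string).
tool_call = {
    "name": "send_email",
    "arguments": json.dumps(
        {"recipient": "John", "subject": "Meeting", "body": "Let's discuss..."}
    ),
}

# The block then emits one output per argument using the tools_^_ pin convention,
# so downstream agent nodes can pick up each parameter individually.
for arg_name, arg_value in json.loads(tool_call["arguments"]).items():
    output_pin = f"tools_^_{tool_call['name']}_{arg_name}".lower()
    print(output_pin, "->", arg_value)
    # tools_^_send_email_recipient -> John
    # tools_^_send_email_subject -> Meeting
    # tools_^_send_email_body -> Let's discuss...
```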
@@ -4,9 +4,9 @@ from abc import ABC
|
||||
from enum import Enum, EnumMeta
|
||||
from json import JSONDecodeError
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
@@ -16,6 +16,8 @@ if TYPE_CHECKING:
|
||||
import anthropic
|
||||
import ollama
|
||||
import openai
|
||||
from anthropic._types import NotGiven
|
||||
from anthropic.types import ToolParam
|
||||
from groq import Groq
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
@@ -238,6 +240,291 @@ class Message(BlockSchema):
|
||||
content: str
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolContentBlock(BaseModel):
|
||||
id: str
|
||||
type: str
|
||||
function: ToolCall
|
||||
|
||||
|
||||
class LLMResponse(BaseModel):
|
||||
prompt: str
|
||||
response: str
|
||||
tool_calls: Optional[List[ToolContentBlock]]
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
|
||||
|
||||
def convert_openai_tool_fmt_to_anthropic(
|
||||
openai_tools: list[dict] | None = None,
|
||||
) -> Iterable[ToolParam] | NotGiven:
|
||||
"""
|
||||
Convert OpenAI tool format to Anthropic tool format.
|
||||
"""
|
||||
if not openai_tools or len(openai_tools) == 0:
|
||||
return anthropic.NOT_GIVEN
|
||||
|
||||
anthropic_tools = []
|
||||
for tool in openai_tools:
|
||||
if "function" in tool:
|
||||
# Handle case where tool is already in OpenAI format with "type" and "function"
|
||||
function_data = tool["function"]
|
||||
else:
|
||||
# Handle case where tool is just the function definition
|
||||
function_data = tool
|
||||
|
||||
anthropic_tool: anthropic.types.ToolParam = {
|
||||
"name": function_data["name"],
|
||||
"description": function_data.get("description", ""),
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": function_data.get("parameters", {}).get("properties", {}),
|
||||
"required": function_data.get("parameters", {}).get("required", []),
|
||||
},
|
||||
}
|
||||
anthropic_tools.append(anthropic_tool)
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def llm_call(
|
||||
credentials: APIKeyCredentials,
|
||||
llm_model: LlmModel,
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
max_tokens: int | None,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Make a call to a language model.
|
||||
|
||||
Args:
|
||||
credentials: The API key credentials to use.
|
||||
llm_model: The LLM model to use.
|
||||
prompt: The prompt to send to the LLM.
|
||||
json_format: Whether the response should be in JSON format.
|
||||
max_tokens: The maximum number of tokens to generate in the chat completion.
|
||||
tools: The tools to use in the chat completion.
|
||||
ollama_host: The host for ollama to use.
|
||||
|
||||
Returns:
|
||||
LLMResponse object containing:
|
||||
- prompt: The prompt sent to the LLM.
|
||||
- response: The text response from the LLM.
|
||||
- tool_calls: Any tool calls the model made, if applicable.
|
||||
- prompt_tokens: The number of tokens used in the prompt.
|
||||
- completion_tokens: The number of tokens used in the completion.
|
||||
"""
|
||||
provider = llm_model.metadata.provider
|
||||
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
|
||||
|
||||
if provider == "openai":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
prompt = [
|
||||
{"role": "user", "content": "\n".join(sys_messages)},
|
||||
{"role": "user", "content": "\n".join(usr_messages)},
|
||||
]
|
||||
elif json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = oai_client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
)
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name,
|
||||
arguments=tool.function.arguments,
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
return LLMResponse(
|
||||
prompt=json.dumps(prompt),
|
||||
response=response.choices[0].message.content or "",
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
|
||||
an_tools = convert_openai_tool_fmt_to_anthropic(tools)
|
||||
|
||||
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
sysprompt = " ".join(system_messages)
|
||||
|
||||
messages = []
|
||||
last_role = None
|
||||
for p in prompt:
|
||||
if p["role"] in ["user", "assistant"]:
|
||||
if p["role"] != last_role:
|
||||
messages.append({"role": p["role"], "content": p["content"]})
|
||||
last_role = p["role"]
|
||||
else:
|
||||
# If the role is the same as the last one, combine the content
|
||||
messages[-1]["content"] += "\n" + p["content"]
|
||||
|
||||
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
|
||||
try:
|
||||
resp = client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
tools=an_tools,
|
||||
)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
|
||||
tool_calls = None
|
||||
for content_block in resp.content:
|
||||
# Anthropic differs from OpenAI: we need to iterate through
# the content blocks to find the tool calls
|
||||
if content_block.type == "tool_use":
|
||||
if tool_calls is None:
|
||||
tool_calls = []
|
||||
tool_calls.append(
|
||||
ToolContentBlock(
|
||||
id=content_block.id,
|
||||
type=content_block.type,
|
||||
function=ToolCall(
|
||||
name=content_block.name,
|
||||
arguments=json.dumps(content_block.input),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls and resp.stop_reason == "tool_use":
|
||||
logger.warning(
|
||||
"Tool use stop reason but no tool calls found in content. %s", resp
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
prompt=json.dumps(prompt),
|
||||
response=(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else resp.content[0].text
|
||||
),
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=resp.usage.input_tokens,
|
||||
completion_tokens=resp.usage.output_tokens,
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
elif provider == "groq":
|
||||
if tools:
|
||||
raise ValueError("Groq does not support tools.")
|
||||
|
||||
client = Groq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
response = client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
return LLMResponse(
|
||||
prompt=json.dumps(prompt),
|
||||
response=response.choices[0].message.content or "",
|
||||
tool_calls=None,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
if tools:
|
||||
raise ValueError("Ollama does not support tools.")
|
||||
|
||||
client = ollama.Client(host=ollama_host)
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
response = client.generate(
|
||||
model=llm_model.value,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
)
|
||||
return LLMResponse(
|
||||
prompt=json.dumps(prompt),
|
||||
response=response.get("response") or "",
|
||||
tool_calls=None,
|
||||
prompt_tokens=response.get("prompt_eval_count") or 0,
|
||||
completion_tokens=response.get("eval_count") or 0,
|
||||
)
|
||||
elif provider == "open_router":
|
||||
tools_param = tools if tools else openai.NOT_GIVEN
|
||||
client = openai.OpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
tools=tools_param, # type: ignore
|
||||
)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
|
||||
if response.choices[0].message.tool_calls:
|
||||
tool_calls = [
|
||||
ToolContentBlock(
|
||||
id=tool.id,
|
||||
type=tool.type,
|
||||
function=ToolCall(
|
||||
name=tool.function.name, arguments=tool.function.arguments
|
||||
),
|
||||
)
|
||||
for tool in response.choices[0].message.tool_calls
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
return LLMResponse(
|
||||
prompt=json.dumps(prompt),
|
||||
response=response.choices[0].message.content or "",
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
|
||||
class AIBlockBase(Block, ABC):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -260,7 +547,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -311,7 +598,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
input_schema=AIStructuredResponseGeneratorBlock.Input,
|
||||
output_schema=AIStructuredResponseGeneratorBlock.Output,
|
||||
test_input={
|
||||
"model": LlmModel.GPT4_TURBO,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expected_format": {
|
||||
"key1": "value1",
|
||||
@@ -325,19 +612,20 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
("prompt", str),
|
||||
],
|
||||
test_mock={
|
||||
"llm_call": lambda *args, **kwargs: (
|
||||
json.dumps(
|
||||
"llm_call": lambda *args, **kwargs: LLMResponse(
|
||||
prompt="",
|
||||
response=json.dumps(
|
||||
{
|
||||
"key1": "key1Value",
|
||||
"key2": "key2Value",
|
||||
}
|
||||
),
|
||||
0,
|
||||
0,
|
||||
tool_calls=None,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
)
|
||||
},
|
||||
)
|
||||
self.prompt = ""
|
||||
|
||||
def llm_call(
|
||||
self,
|
||||
@@ -346,154 +634,22 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
prompt: list[dict],
|
||||
json_format: bool,
|
||||
max_tokens: int | None,
|
||||
tools: list[dict] | None = None,
|
||||
ollama_host: str = "localhost:11434",
|
||||
) -> tuple[str, int, int]:
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Args:
|
||||
credentials: The API key credentials to use.
|
||||
llm_model: The LLM model to use.
|
||||
prompt: The prompt to send to the LLM.
|
||||
json_format: Whether the response should be in JSON format.
|
||||
max_tokens: The maximum number of tokens to generate in the chat completion.
|
||||
ollama_host: The host for ollama to use
|
||||
|
||||
Returns:
|
||||
The response from the LLM.
|
||||
The number of tokens used in the prompt.
|
||||
The number of tokens used in the completion.
|
||||
Test mocks work only on class functions; this wraps the llm_call function
so that it can be mocked within the block testing framework.
|
||||
"""
|
||||
provider = llm_model.metadata.provider
|
||||
max_tokens = max_tokens or llm_model.max_output_tokens or 4096
|
||||
|
||||
if provider == "openai":
|
||||
oai_client = openai.OpenAI(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = None
|
||||
|
||||
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
prompt = [
|
||||
{"role": "user", "content": "\n".join(sys_messages)},
|
||||
{"role": "user", "content": "\n".join(usr_messages)},
|
||||
]
|
||||
elif json_format:
|
||||
response_format = {"type": "json_object"}
|
||||
|
||||
response = oai_client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_completion_tokens=max_tokens,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
|
||||
return (
|
||||
response.choices[0].message.content or "",
|
||||
response.usage.prompt_tokens if response.usage else 0,
|
||||
response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
elif provider == "anthropic":
|
||||
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
sysprompt = " ".join(system_messages)
|
||||
|
||||
messages = []
|
||||
last_role = None
|
||||
for p in prompt:
|
||||
if p["role"] in ["user", "assistant"]:
|
||||
if p["role"] != last_role:
|
||||
messages.append({"role": p["role"], "content": p["content"]})
|
||||
last_role = p["role"]
|
||||
else:
|
||||
# If the role is the same as the last one, combine the content
|
||||
messages[-1]["content"] += "\n" + p["content"]
|
||||
|
||||
client = anthropic.Anthropic(api_key=credentials.api_key.get_secret_value())
|
||||
try:
|
||||
resp = client.messages.create(
|
||||
model=llm_model.value,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
|
||||
if not resp.content:
|
||||
raise ValueError("No content returned from Anthropic.")
|
||||
|
||||
return (
|
||||
(
|
||||
resp.content[0].name
|
||||
if isinstance(resp.content[0], anthropic.types.ToolUseBlock)
|
||||
else resp.content[0].text
|
||||
),
|
||||
resp.usage.input_tokens,
|
||||
resp.usage.output_tokens,
|
||||
)
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
elif provider == "groq":
|
||||
client = Groq(api_key=credentials.api_key.get_secret_value())
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
response = client.chat.completions.create(
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
response_format=response_format, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
return (
|
||||
response.choices[0].message.content or "",
|
||||
response.usage.prompt_tokens if response.usage else 0,
|
||||
response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
elif provider == "ollama":
|
||||
client = ollama.Client(host=ollama_host)
|
||||
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
|
||||
response = client.generate(
|
||||
model=llm_model.value,
|
||||
prompt=f"{sys_messages}\n\n{usr_messages}",
|
||||
stream=False,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
return (
|
||||
response.get("response") or "",
|
||||
response.get("prompt_eval_count") or 0,
|
||||
response.get("eval_count") or 0,
|
||||
)
|
||||
elif provider == "open_router":
|
||||
client = openai.OpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=credentials.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
extra_headers={
|
||||
"HTTP-Referer": "https://agpt.co",
|
||||
"X-Title": "AutoGPT",
|
||||
},
|
||||
model=llm_model.value,
|
||||
messages=prompt, # type: ignore
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
self.prompt = json.dumps(prompt)
|
||||
|
||||
# If there's no response, raise an error
|
||||
if not response.choices:
|
||||
if response:
|
||||
raise ValueError(f"OpenRouter error: {response}")
|
||||
else:
|
||||
raise ValueError("No response from OpenRouter.")
|
||||
|
||||
return (
|
||||
response.choices[0].message.content or "",
|
||||
response.usage.prompt_tokens if response.usage else 0,
|
||||
response.usage.completion_tokens if response.usage else 0,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
return llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
json_format=json_format,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
ollama_host=ollama_host,
|
||||
)
|
||||
|
||||
def run(
|
||||
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
|
||||
@@ -549,7 +705,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
||||
for retry_count in range(input_data.retry):
|
||||
try:
|
||||
response_text, input_token, output_token = self.llm_call(
|
||||
llm_response = self.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=llm_model,
|
||||
prompt=prompt,
|
||||
@@ -557,10 +713,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
ollama_host=input_data.ollama_host,
|
||||
max_tokens=input_data.max_tokens,
|
||||
)
|
||||
response_text = llm_response.response
|
||||
self.merge_stats(
|
||||
{
|
||||
"input_token_count": input_token,
|
||||
"output_token_count": output_token,
|
||||
"input_token_count": llm_response.prompt_tokens,
|
||||
"output_token_count": llm_response.completion_tokens,
|
||||
}
|
||||
)
|
||||
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
@@ -621,7 +778,7 @@ class AITextGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
@@ -714,7 +871,7 @@ class AITextSummarizerBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for summarizing the text.",
|
||||
)
|
||||
focus: str = SchemaField(
|
||||
@@ -880,7 +1037,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for the conversation.",
|
||||
)
|
||||
credentials: AICredentials = AICredentialsField()
|
||||
@@ -919,7 +1076,7 @@ class AIConversationBlock(AIBlockBase):
|
||||
},
|
||||
{"role": "user", "content": "Where was it played?"},
|
||||
],
|
||||
"model": LlmModel.GPT4_TURBO,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
@@ -981,7 +1138,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4_TURBO,
|
||||
default=LlmModel.GPT4O,
|
||||
description="The language model to use for generating the list.",
|
||||
advanced=True,
|
||||
)
|
||||
@@ -1030,7 +1187,7 @@ class AIListGeneratorBlock(AIBlockBase):
|
||||
"drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
|
||||
"fictional worlds."
|
||||
),
|
||||
"model": LlmModel.GPT4_TURBO,
|
||||
"model": LlmModel.GPT4O,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_retries": 3,
|
||||
},
|
||||
|
||||
307 autogpt_platform/backend/backend/blocks/smart_decision_maker.py (new file)
@@ -0,0 +1,307 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, List
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema, BlockType
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_database_manager_client():
|
||||
from backend.executor import DatabaseManager
|
||||
from backend.util.service import get_service_client
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
class SmartDecisionMakerBlock(Block):
|
||||
"""
|
||||
A block that uses a language model to make smart decisions based on a given prompt.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
prompt: str = SchemaField(
|
||||
description="The prompt to send to the language model.",
|
||||
placeholder="Enter your prompt here...",
|
||||
)
|
||||
model: llm.LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=llm.LlmModel.GPT4O,
|
||||
description="The language model to use for answering the prompt.",
|
||||
advanced=False,
|
||||
)
|
||||
credentials: llm.AICredentials = llm.AICredentialsField()
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="Thinking carefully step by step decide which function to call. Always choose a function call from the list of function signatures.",
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
conversation_history: list[llm.Message] = SchemaField(
|
||||
default=[],
|
||||
description="The conversation history to provide context for the prompt.",
|
||||
)
|
||||
retry: int = SchemaField(
|
||||
title="Retry Count",
|
||||
default=3,
|
||||
description="Number of times to retry the LLM call if the response does not match the expected format.",
|
||||
)
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False,
|
||||
default={},
|
||||
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
|
||||
)
|
||||
max_tokens: int | None = SchemaField(
|
||||
advanced=True,
|
||||
default=None,
|
||||
description="The maximum number of tokens to generate in the chat completion.",
|
||||
)
|
||||
ollama_host: str = SchemaField(
|
||||
advanced=True,
|
||||
default="localhost:11434",
|
||||
description="Ollama host for local models",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
|
||||
prompt: str = SchemaField(description="The prompt sent to the language model.")
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
function_signatures: list[dict[str, Any]] = SchemaField(
|
||||
description="The function signatures that are sent to the language model."
|
||||
)
|
||||
tools: Any = SchemaField(description="The tools that are available to use.")
|
||||
finished: str = SchemaField(
|
||||
description="The finished message to display to the user."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
description="Uses AI to intelligently decide what tool to use.",
|
||||
categories={BlockCategory.AI},
|
||||
block_type=BlockType.AI,
|
||||
input_schema=SmartDecisionMakerBlock.Input,
|
||||
output_schema=SmartDecisionMakerBlock.Output,
|
||||
test_input={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_output=[],
|
||||
test_credentials=llm.TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
# If I import Graph here, it will break with a circular import.
|
||||
def _get_tool_graph_metadata(self, node_id: str, graph: Any) -> List[Any]:
|
||||
"""
|
||||
Retrieves metadata for tool graphs linked to a specified node within a graph.
|
||||
|
||||
This method identifies the tool links connected to the given node_id and fetches
|
||||
the metadata for each linked tool graph from the database.
|
||||
|
||||
Args:
|
||||
node_id (str): The ID of the node for which tool graph metadata is to be retrieved.
|
||||
graph (Any): The graph object containing nodes and links.
|
||||
|
||||
Returns:
|
||||
List[Any]: A list of metadata for the tool graphs linked to the specified node.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
graph_meta = []
|
||||
|
||||
tool_links = {
|
||||
link.sink_id
|
||||
for link in graph.links
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
}
|
||||
|
||||
for link_id in tool_links:
|
||||
node = next((node for node in graph.nodes if node.id == link_id), None)
|
||||
if node:
|
||||
node_graph_meta = db_client.get_graph_metadata(
|
||||
node.input_default["graph_id"], node.input_default["graph_version"]
|
||||
)
|
||||
if node_graph_meta:
|
||||
graph_meta.append(node_graph_meta)
|
||||
|
||||
return graph_meta
|
||||
|
||||
@staticmethod
|
||||
def _create_function_signature(
|
||||
# If I import Graph here, it will break with a circular import.
|
||||
node_id: str,
|
||||
graph: Any,
|
||||
tool_graph_metadata: List[Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Creates function signatures for tools linked to a specified node within a graph.
|
||||
|
||||
This method filters the graph links to identify those that are tools and are
|
||||
connected to the given node_id. It then constructs function signatures for each
|
||||
tool based on the metadata and input schema of the linked nodes.
|
||||
|
||||
Args:
|
||||
node_id (str): The ID of the node for which tool function signatures are to be created.
|
||||
graph (Any): The graph object containing nodes and links.
|
||||
tool_graph_metadata (List[Any]): Metadata for the tool graphs, used to retrieve
|
||||
names and descriptions for the tools.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
|
||||
for a tool, including its name, description, and parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: If no tool links are found for the specified node_id, or if a sink node
|
||||
or its metadata cannot be found.
|
||||
"""
|
||||
# Filter the graph links to find those that are tools and are linked to the specified node_id
|
||||
tool_links = [
|
||||
link
|
||||
for link in graph.links
|
||||
# NOTE: Maybe we can do a specific database call to only get relevant nodes
|
||||
# async def get_connected_output_nodes(source_node_id: str) -> list[Node]:
|
||||
# links = await AgentNodeLink.prisma().find_many(
|
||||
# where={"agentNodeSourceId": source_node_id},
|
||||
# include={"AgentNode": {"include": AGENT_NODE_INCLUDE}},
|
||||
# )
|
||||
# return [NodeModel.from_db(link.AgentNodeSink) for link in links]
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
]
|
||||
|
||||
if not tool_links:
|
||||
raise ValueError(
|
||||
f"Expected at least one tool link in the graph. Node ID: {node_id}. Graph: {graph.links}"
|
||||
)
|
||||
|
||||
return_tool_functions = []
|
||||
|
||||
grouped_tool_links = {}
|
||||
|
||||
for link in tool_links:
|
||||
grouped_tool_links.setdefault(link.sink_id, []).append(link)
|
||||
|
||||
for _, links in grouped_tool_links.items():
|
||||
sink_node = next(
|
||||
(node for node in graph.nodes if node.id == links[0].sink_id), None
|
||||
)
|
||||
|
||||
if not sink_node:
|
||||
raise ValueError(f"Sink node not found: {links[0].sink_id}")
|
||||
|
||||
graph_id = sink_node.input_default["graph_id"]
|
||||
graph_version = sink_node.input_default["graph_version"]
|
||||
|
||||
sink_graph_meta = next(
|
||||
(
|
||||
meta
|
||||
for meta in tool_graph_metadata
|
||||
if meta.id == graph_id and meta.version == graph_version
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not sink_graph_meta:
|
||||
raise ValueError(
|
||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||
)
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", sink_graph_meta.name).lower(),
|
||||
"description": sink_graph_meta.description,
|
||||
}
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
description = (
|
||||
sink_block_input_schema["properties"][link.sink_name]["description"]
|
||||
if "description"
|
||||
in sink_block_input_schema["properties"][link.sink_name]
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[link.sink_name.lower()] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
return_tool_functions.append(
|
||||
{"type": "function", "function": tool_function}
|
||||
)
|
||||
return return_tool_functions
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: llm.APIKeyCredentials,
|
||||
graph_id: str,
|
||||
node_id: str,
|
||||
graph_exec_id: str,
|
||||
node_exec_id: str,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
db_client = get_database_manager_client()
|
||||
|
||||
# Retrieve the current graph and node details
|
||||
graph = db_client.get_graph(graph_id=graph_id, user_id=user_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError(
|
||||
f"The currently running graph that is executing this node is not found {graph_id}"
|
||||
)
|
||||
|
||||
tool_graph_metadata = self._get_tool_graph_metadata(node_id, graph)
|
||||
|
||||
tool_functions = self._create_function_signature(
|
||||
node_id, graph, tool_graph_metadata
|
||||
)
|
||||
|
||||
prompt = [p.model_dump() for p in input_data.conversation_history]
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
|
||||
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
|
||||
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
|
||||
if input_data.prompt:
|
||||
prompt.append({"role": "user", "content": input_data.prompt})
|
||||
|
||||
response = llm.llm_call(
|
||||
credentials=credentials,
|
||||
llm_model=input_data.model,
|
||||
prompt=prompt,
|
||||
json_format=False,
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
|
||||
yield "finished", f"No Decision Made finishing task: {response.response}"
|
||||
|
||||
if response.tool_calls:
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
|
||||
@@ -44,6 +44,7 @@ class BlockType(Enum):
|
||||
WEBHOOK = "Webhook"
|
||||
WEBHOOK_MANUAL = "Webhook (manual)"
|
||||
AGENT = "Agent"
|
||||
AI = "AI"
|
||||
|
||||
|
||||
class BlockCategory(Enum):
|
||||
@@ -351,6 +352,14 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
Run the block with the given input data.
|
||||
Args:
|
||||
input_data: The input data with the structure of input_schema.
|
||||
|
||||
Kwargs: as of 14/02/2025, these include:
|
||||
graph_id: The ID of the graph.
|
||||
node_id: The ID of the node.
|
||||
graph_exec_id: The ID of the graph execution.
|
||||
node_exec_id: The ID of the node execution.
|
||||
user_id: The ID of the user.
|
||||
|
||||
Returns:
|
||||
A Generator that yields (output_name, output_data).
|
||||
output_name: One of the output names defined in the Block's output_schema.
|
||||
|
||||
@@ -399,7 +399,38 @@ OBJC_SPLIT = "_@_"
|
||||
|
||||
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
# Allow extracting partial output data by name.
|
||||
"""
|
||||
Extracts partial output data by name from a given BlockData.
|
||||
|
||||
The function supports extracting data from lists, dictionaries, and objects
|
||||
using specific naming conventions:
|
||||
- For lists: <output_name>_$_<index>
|
||||
- For dictionaries: <output_name>_#_<key>
|
||||
- For objects: <output_name>_@_<attribute>
|
||||
|
||||
Args:
|
||||
output (BlockData): A tuple containing the output name and data.
|
||||
name (str): The name used to extract specific data from the output.
|
||||
|
||||
Returns:
|
||||
Any | None: The extracted data if found, otherwise None.
|
||||
|
||||
Examples:
|
||||
>>> output = ("result", [10, 20, 30])
|
||||
>>> parse_execution_output(output, "result_$_1")
|
||||
20
|
||||
|
||||
>>> output = ("config", {"key1": "value1", "key2": "value2"})
|
||||
>>> parse_execution_output(output, "config_#_key1")
|
||||
'value1'
|
||||
|
||||
>>> class Sample:
|
||||
... attr1 = "value1"
|
||||
... attr2 = "value2"
|
||||
>>> output = ("object", Sample())
|
||||
>>> parse_execution_output(output, "object_@_attr1")
|
||||
'value1'
|
||||
"""
|
||||
output_name, output_data = output
|
||||
|
||||
if name == output_name:
|
||||
@@ -428,11 +459,37 @@ def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
|
||||
def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
"""
|
||||
Merge all dynamic input pins which described by the following pattern:
|
||||
- <input_name>_$_<index> for list input.
|
||||
- <input_name>_#_<index> for dict input.
|
||||
- <input_name>_@_<index> for object input.
|
||||
This function will construct pins with the same name into a single list/dict/object.
|
||||
Merges dynamic input pins into a single list, dictionary, or object based on naming patterns.
|
||||
|
||||
This function processes input keys that follow specific patterns to merge them into a unified structure:
|
||||
- `<input_name>_$_<index>` for list inputs.
|
||||
- `<input_name>_#_<index>` for dictionary inputs.
|
||||
- `<input_name>_@_<index>` for object inputs.
|
||||
|
||||
Args:
|
||||
data (BlockInput): A dictionary containing input keys and their corresponding values.
|
||||
|
||||
Returns:
|
||||
BlockInput: A dictionary with merged inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If a list index is not an integer.
|
||||
|
||||
Examples:
|
||||
>>> data = {
|
||||
... "list_$_0": "a",
|
||||
... "list_$_1": "b",
|
||||
... "dict_#_key1": "value1",
|
||||
... "dict_#_key2": "value2",
|
||||
... "object_@_attr1": "value1",
|
||||
... "object_@_attr2": "value2"
|
||||
... }
|
||||
>>> merge_execution_input(data)
|
||||
{
|
||||
"list": ["a", "b"],
|
||||
"dict": {"key1": "value1", "key2": "value2"},
|
||||
"object": <MockObject attr1="value1" attr2="value2">
|
||||
}
|
||||
"""
|
||||
|
||||
# Merge all input with <input_name>_$_<index> into a single list.
|
||||
|
||||
@@ -312,16 +312,49 @@ class GraphModel(Graph):
|
||||
|
||||
def validate_graph(self, for_run: bool = False):
|
||||
def sanitize(name):
|
||||
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
sanitized_name = name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
if sanitized_name.startswith("tools_^_"):
|
||||
return sanitized_name.split("_^_")[0]
|
||||
return sanitized_name
|
||||
|
||||
# Validate smart decision maker nodes
|
||||
smart_decision_maker_nodes = set()
|
||||
agent_nodes = set()
|
||||
nodes_block = {
|
||||
node.id: block
|
||||
for node in self.nodes
|
||||
if (block := get_block(node.block_id)) is not None
|
||||
}
|
||||
|
||||
for node in self.nodes:
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
# Smart decision maker nodes
|
||||
if block.block_type == BlockType.AI:
|
||||
smart_decision_maker_nodes.add(node.id)
|
||||
# Agent nodes
|
||||
elif block.block_type == BlockType.AGENT:
|
||||
agent_nodes.add(node.id)
|
||||
|
||||
input_links = defaultdict(list)
|
||||
|
||||
for link in self.links:
|
||||
input_links[link.sink_id].append(link)
|
||||
|
||||
# Check if the link is a tool link from a smart decision maker to a non-agent node
|
||||
if (
|
||||
link.source_id in smart_decision_maker_nodes
|
||||
and link.source_name.startswith("tools_^_")
|
||||
and link.sink_id not in agent_nodes
|
||||
):
|
||||
raise ValueError(
|
||||
f"Smart decision maker node {link.source_id} cannot link to non-agent node {link.sink_id}"
|
||||
)
|
||||
|
||||
# Nodes: required fields are filled or connected and dependencies are satisfied
|
||||
for node in self.nodes:
|
||||
block = get_block(node.block_id)
|
||||
if block is None:
|
||||
if (block := nodes_block.get(node.id)) is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
provided_inputs = set(
|
||||
@@ -341,6 +374,7 @@ class GraphModel(Graph):
|
||||
or block.block_type == BlockType.INPUT
|
||||
or block.block_type == BlockType.OUTPUT
|
||||
or block.block_type == BlockType.AGENT
|
||||
or block.block_type == BlockType.AI
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
@@ -410,16 +444,16 @@ class GraphModel(Graph):
|
||||
if i == 0:
|
||||
fields = (
|
||||
block.output_schema.get_fields()
|
||||
if block.block_type != BlockType.AGENT
|
||||
if block.block_type not in [BlockType.AGENT, BlockType.AI]
|
||||
else vals.get("output_schema", {}).get("properties", {}).keys()
|
||||
)
|
||||
else:
|
||||
fields = (
|
||||
block.input_schema.get_fields()
|
||||
if block.block_type != BlockType.AGENT
|
||||
if block.block_type not in [BlockType.AGENT, BlockType.AI]
|
||||
else vals.get("input_schema", {}).get("properties", {}).keys()
|
||||
)
|
||||
if sanitized_name not in fields:
|
||||
if sanitized_name not in fields and not name.startswith("tools_"):
|
||||
fields_msg = f"Allowed fields: {fields}"
|
||||
raise ValueError(f"{suffix}, `{name}` invalid, {fields_msg}")
|
||||
|
||||
@@ -610,6 +644,33 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
|
||||
return GraphExecution.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph | None:
|
||||
where_clause: AgentGraphWhereInput = {
|
||||
"id": graph_id,
|
||||
}
|
||||
|
||||
if version is not None:
|
||||
where_clause["version"] = version
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where=where_clause,
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
order={"version": "desc"},
|
||||
)
|
||||
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
return Graph(
|
||||
id=graph.id,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
version=graph.version,
|
||||
is_active=graph.isActive,
|
||||
is_template=graph.isTemplate,
|
||||
)
|
||||
|
||||
|
||||
async def get_graph(
|
||||
graph_id: str,
|
||||
version: int | None = None,
|
||||
|
||||
@@ -13,7 +13,7 @@ from backend.data.execution import (
|
||||
upsert_execution_input,
|
||||
upsert_execution_output,
|
||||
)
|
||||
from backend.data.graph import get_graph, get_node
|
||||
from backend.data.graph import get_graph, get_graph_metadata, get_node
|
||||
from backend.data.user import (
|
||||
get_user_integrations,
|
||||
get_user_metadata,
|
||||
@@ -60,6 +60,7 @@ class DatabaseManager(AppService):
|
||||
# Graphs
|
||||
get_node = exposed_run_and_wait(get_node)
|
||||
get_graph = exposed_run_and_wait(get_graph)
|
||||
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
|
||||
|
||||
# Credits
|
||||
spend_credits = exposed_run_and_wait(_spend_credits)
|
||||
|
||||
@@ -16,6 +16,7 @@ import backend.data.block
|
||||
import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
@@ -24,6 +25,8 @@ import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.server.external.api import external_app
|
||||
|
||||
settings = backend.util.settings.Settings()
|
||||
@@ -243,5 +246,15 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
):
|
||||
return await backend.server.v2.store.routes.review_submission(request, user)
|
||||
|
||||
@staticmethod
|
||||
def test_create_credentials(
|
||||
user_id: str,
|
||||
provider: ProviderName,
|
||||
credentials: Credentials,
|
||||
) -> Credentials:
|
||||
return backend.server.integrations.router.create_credentials(
|
||||
user_id=user_id, provider=provider, credentials=credentials
|
||||
)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
app.dependency_overrides.update(overrides)
|
||||
|
||||
@@ -40,7 +40,10 @@ def create_test_graph() -> graph.Graph:
|
||||
),
|
||||
graph.Node(
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "input_2"},
|
||||
input_default={
|
||||
"name": "input_2",
|
||||
"description": "This is my description of this parameter",
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=FillTextTemplateBlock().id,
|
||||
@@ -74,7 +77,7 @@ def create_test_graph() -> graph.Graph:
|
||||
|
||||
return graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
description="Test graph description",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
from backend.data.execution import merge_execution_input, parse_execution_output
|
||||
|
||||
|
||||
def test_parse_execution_output():
|
||||
# Test case for list extraction
|
||||
output = ("result", [10, 20, 30])
|
||||
assert parse_execution_output(output, "result_$_1") == 20
|
||||
assert parse_execution_output(output, "result_$_3") is None
|
||||
|
||||
# Test case for dictionary extraction
|
||||
output = ("config", {"key1": "value1", "key2": "value2"})
|
||||
assert parse_execution_output(output, "config_#_key1") == "value1"
|
||||
assert parse_execution_output(output, "config_#_key3") is None
|
||||
|
||||
# Test case for object extraction
|
||||
class Sample:
|
||||
attr1 = "value1"
|
||||
attr2 = "value2"
|
||||
|
||||
output = ("object", Sample())
|
||||
assert parse_execution_output(output, "object_@_attr1") == "value1"
|
||||
assert parse_execution_output(output, "object_@_attr3") is None
|
||||
|
||||
# Test case for direct match
|
||||
output = ("direct", "match")
|
||||
assert parse_execution_output(output, "direct") == "match"
|
||||
assert parse_execution_output(output, "nomatch") is None
|
||||
|
||||
|
||||
def test_merge_execution_input():
|
||||
# Test case for merging list inputs
|
||||
data = {"list_$_0": "a", "list_$_1": "b", "list_$_3": "d"}
|
||||
merged_data = merge_execution_input(data)
|
||||
assert merged_data["list"] == ["a", "b", "", "d"]
|
||||
|
||||
# Test case for merging dictionary inputs
|
||||
data = {"dict_#_key1": "value1", "dict_#_key2": "value2"}
|
||||
merged_data = merge_execution_input(data)
|
||||
assert merged_data["dict"] == {"key1": "value1", "key2": "value2"}
|
||||
|
||||
# Test case for merging object inputs
|
||||
data = {"object_@_attr1": "value1", "object_@_attr2": "value2"}
|
||||
merged_data = merge_execution_input(data)
|
||||
assert hasattr(merged_data["object"], "attr1")
|
||||
assert hasattr(merged_data["object"], "attr2")
|
||||
assert merged_data["object"].attr1 == "value1"
|
||||
assert merged_data["object"].attr2 == "value2"
|
||||
|
||||
# Test case for mixed inputs
|
||||
data = {"list_$_0": "a", "dict_#_key1": "value1", "object_@_attr1": "value1"}
|
||||
merged_data = merge_execution_input(data)
|
||||
assert merged_data["list"] == ["a"]
|
||||
assert merged_data["dict"] == {"key1": "value1"}
|
||||
assert hasattr(merged_data["object"], "attr1")
|
||||
assert merged_data["object"].attr1 == "value1"
|
||||
@@ -73,6 +73,7 @@ async def assert_sample_graph_executions(
|
||||
{
|
||||
"name": "input_2",
|
||||
"value": "World",
|
||||
"description": "This is my description of this parameter",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
249 autogpt_platform/backend/test/executor/test_tool_use.py (new file)
@@ -0,0 +1,249 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
import backend.blocks.llm as llm
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.data import graph
|
||||
from backend.data.model import ProviderName
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
|
||||
logger.info("Creating graph for user %s", u.id)
|
||||
return await s.agent_server.test_create_graph(CreateGraph(graph=g), u.id)
|
||||
|
||||
|
||||
def create_credentials(s: SpinTestServer, u: User):
|
||||
provider = ProviderName.OPENAI
|
||||
credentials = llm.TEST_CREDENTIALS
|
||||
try:
|
||||
s.agent_server.test_create_credentials(u.id, provider, credentials)
|
||||
except Exception:
|
||||
# A ValueError is raised when trying to recreate the same credentials,
# so we hide the error here.
|
||||
pass
|
||||
|
||||
|
||||
async def execute_graph(
|
||||
agent_server: AgentServer,
|
||||
test_graph: graph.Graph,
|
||||
test_user: User,
|
||||
input_data: dict,
|
||||
num_execs: int = 4,
|
||||
) -> str:
|
||||
logger.info("Executing graph %s for user %s", test_graph.id, test_user.id)
|
||||
logger.info("Input data: %s", input_data)
|
||||
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.test_execute_graph(
|
||||
user_id=test_user.id,
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
node_input=input_data,
|
||||
)
|
||||
graph_exec_id = response.graph_exec_id
|
||||
logger.info("Created execution with ID: %s", graph_exec_id)
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, test_graph.id, graph_exec_id, 30)
|
||||
logger.info("Execution completed with %d results", len(result))
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_correct(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=AgentExecutorBlock().id,
|
||||
input_default={
|
||||
"graph_id": test_tool_graph.id,
|
||||
"graph_version": test_tool_graph.version,
|
||||
"input_schema": test_tool_graph.input_schema,
|
||||
"output_schema": test_tool_graph.output_schema,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_1",
|
||||
sink_name="input_1",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_2",
|
||||
sink_name="input_2",
|
||||
),
|
||||
]
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_graph_validation_with_tool_nodes_raises_error(server: SpinTestServer):
|
||||
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=AgentExecutorBlock().id,
|
||||
input_default={
|
||||
"graph_id": test_tool_graph.id,
|
||||
"graph_version": test_tool_graph.version,
|
||||
"input_schema": test_tool_graph.input_schema,
|
||||
"output_schema": test_tool_graph.output_schema,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=StoreValueBlock().id,
|
||||
),
|
||||
]
|
||||
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_1",
|
||||
sink_name="input_1",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_2",
|
||||
sink_name="input_2",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="tools_^_store_value_input",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
test_user = await create_test_user()
|
||||
test_tool_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
create_credentials(server, test_user)
|
||||
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=SmartDecisionMakerBlock().id,
|
||||
input_default={
|
||||
"prompt": "Hello, World!",
|
||||
"credentials": llm.TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=AgentExecutorBlock().id,
|
||||
input_default={
|
||||
"graph_id": test_tool_graph.id,
|
||||
"graph_version": test_tool_graph.version,
|
||||
"input_schema": test_tool_graph.input_schema,
|
||||
"output_schema": test_tool_graph.output_schema,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_1",
|
||||
sink_name="input_1",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[1].id,
|
||||
source_name="tools_^_sample_tool_input_2",
|
||||
sink_name="input_2",
|
||||
),
|
||||
]
|
||||
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = SmartDecisionMakerBlock._create_function_signature(
|
||||
test_graph.nodes[0].id, test_graph, [test_tool_graph]
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
assert (
|
||||
len(tool_functions) == 1
|
||||
), f"Expected 1 tool function, got {len(tool_functions)}"
|
||||
|
||||
tool_function = next(
|
||||
filter(lambda x: x["function"]["name"] == "testgraph", tool_functions),
|
||||
None,
|
||||
)
|
||||
assert tool_function is not None, f"testgraph function not found: {tool_functions}"
|
||||
assert (
|
||||
tool_function["function"]["name"] == "testgraph"
|
||||
), "Incorrect function name for testgraph"
|
||||
assert (
|
||||
tool_function["function"]["parameters"]["properties"]["input_1"]["type"]
|
||||
== "string"
|
||||
), "Input type for input_1 should be 'string'"
|
||||
assert (
|
||||
tool_function["function"]["parameters"]["properties"]["input_2"]["type"]
|
||||
== "string"
|
||||
), "Input type for input_2 should be 'string'"
|
||||
assert (
|
||||
tool_function["function"]["parameters"]["required"] == []
|
||||
), "Required parameters should be an empty list"
|
||||
@@ -191,12 +191,15 @@ export default function useAgentGraph(
|
||||
});
|
||||
setEdges(() =>
|
||||
graph.links.map((link) => {
|
||||
const adjustedSourceName = link.source_name?.startsWith("tools_^_")
|
||||
? "tools"
|
||||
: link.source_name;
|
||||
return {
|
||||
id: formatEdgeID(link),
|
||||
type: "custom",
|
||||
data: {
|
||||
edgeColor: getTypeColor(
|
||||
getOutputType(newNodes, link.source_id, link.source_name!),
|
||||
getOutputType(newNodes, link.source_id, adjustedSourceName!),
|
||||
),
|
||||
sourcePos: newNodes.find((node) => node.id === link.source_id)
|
||||
?.position,
|
||||
@@ -209,12 +212,12 @@ export default function useAgentGraph(
|
||||
type: MarkerType.ArrowClosed,
|
||||
strokeWidth: 2,
|
||||
color: getTypeColor(
|
||||
getOutputType(newNodes, link.source_id, link.source_name!),
|
||||
getOutputType(newNodes, link.source_id, adjustedSourceName!),
|
||||
),
|
||||
},
|
||||
source: link.source_id,
|
||||
target: link.sink_id,
|
||||
sourceHandle: link.source_name || undefined,
|
||||
sourceHandle: adjustedSourceName || undefined,
|
||||
targetHandle: link.sink_name || undefined,
|
||||
};
|
||||
}),
|
||||
@@ -795,12 +798,35 @@ export default function useAgentGraph(
|
||||
};
|
||||
});
|
||||
|
||||
const links = edges.map((edge) => ({
|
||||
source_id: edge.source,
|
||||
sink_id: edge.target,
|
||||
source_name: edge.sourceHandle || "",
|
||||
sink_name: edge.targetHandle || "",
|
||||
}));
|
||||
const links = edges.map((edge) => {
|
||||
let sourceName = edge.sourceHandle || "";
|
||||
if (sourceName.toLowerCase() === "tools") {
|
||||
const sinkNode = nodes.find((node) => node.id === edge.target);
|
||||
|
||||
const sinkNodeName = sinkNode
|
||||
? sinkNode.data.block_id === "e189baac-8c20-45a1-94a7-55177ea42565" // AgentExecutorBlock ID
|
||||
? sinkNode.data.hardcodedValues?.graph_id
|
||||
? availableFlows
|
||||
.find(
|
||||
(flow) =>
|
||||
flow.id === sinkNode.data.hardcodedValues.graph_id,
|
||||
)
|
||||
?.name?.toLowerCase()
|
||||
.replace(/ /g, "_") || "agentexecutorblock"
|
||||
: "agentexecutorblock"
|
||||
: sinkNode.data.title.toLowerCase().replace(/ /g, "_")
|
||||
: "";
|
||||
|
||||
sourceName =
|
||||
`tools_^_${sinkNodeName}_${edge.targetHandle || ""}`.toLowerCase();
|
||||
}
|
||||
return {
|
||||
source_id: edge.source,
|
||||
sink_id: edge.target,
|
||||
source_name: sourceName,
|
||||
sink_name: edge.targetHandle || "",
|
||||
};
|
||||
});
|
||||
|
||||
const payload = {
|
||||
id: savedAgent?.id!,
|
||||
|
||||
@@ -411,6 +411,15 @@ export class BuildPage extends BasePage {
|
||||
};
|
||||
}
|
||||
|
||||
async getSmartDecisionMakerBlockDetails(): Promise<Block> {
|
||||
return {
|
||||
id: "3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||
name: "Smart Decision Maker",
|
||||
description:
|
||||
"This block is used to make a decision based on the input and the available tools.",
|
||||
};
|
||||
}
|
||||
|
||||
async nextTutorialStep(): Promise<void> {
|
||||
console.log(`clicking next tutorial step`);
|
||||
await this.page.getByRole("button", { name: "Next" }).click();
|
||||
@@ -487,6 +496,7 @@ export class BuildPage extends BasePage {
|
||||
(await this.getAgentInputBlockDetails()).id,
|
||||
(await this.getAgentOutputBlockDetails()).id,
|
||||
(await this.getGithubTriggerBlockDetails()).id,
|
||||
(await this.getSmartDecisionMakerBlockDetails()).id,
|
||||
];
|
||||
}
|
||||
|
||||
|
||||