mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): Add tool execution response on Smart Decision Block (#9552)
**Proposal / Request For Comments** The tool calling provided by the LLM provider requires an output generated by the tool to be looped back as part of the conversation history. The scope of this PR is trying to address the need to fulfill the feature expectation. ### Changes 🏗️ * `Last Tool Output` is introduced to loop back the output of the tool back to Smart Decision Block. * Smart Decision Block execution will be pending unless the `Last Tool Output` us provided where a pending tool call is present in the conversation history. * **Known hack**: The last tool output will prefill all the pending tool calls from the conversation history. * A few tweaks were required to allow a blocking loop to be executed without awaiting on all the inbound links. <img width="1395" alt="image" src="https://github.com/user-attachments/assets/fdad4407-621b-45d0-a457-76b2d4c853b9" /> ### Checklist 📋 #### For code changes: - [ ] I have clearly listed my changes in the PR description - [ ] I have made a test plan - [ ] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [ ] ... <details> <summary>Example test plan</summary> - [ ] Create from scratch and execute an agent with at least 3 blocks - [ ] Import an agent from file upload, and confirm it executes correctly - [ ] Upload agent to marketplace - [ ] Import an agent from marketplace and confirm it executes correctly - [ ] Edit an agent from monitor, and confirm it executes correctly </details> #### For configuration changes: - [ ] `.env.example` is updated or already compatible with my changes - [ ] `docker-compose.yml` is updated or already compatible with my changes - [ ] I have included a list of my configuration changes in the PR description (under **Changes**) <details> <summary>Examples of configuration changes</summary> - Changing ports - Adding new services that need to communicate with each other - Secrets or environment variable changes - New or infrastructure changes such as databases </details>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
@@ -13,6 +14,7 @@ from backend.data.block import (
|
||||
)
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,6 +44,23 @@ class AgentExecutorBlock(Block):
|
||||
input_schema: dict = SchemaField(description="Input schema for the graph")
|
||||
output_schema: dict = SchemaField(description="Output schema for the graph")
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
return data.get("input_schema", {})
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data.get("data", {})
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
return set(required_fields) - set(data)
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(cls.get_input_schema(data), data)
|
||||
|
||||
class Output(BlockSchema):
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
@@ -9,6 +10,7 @@ from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
@@ -31,6 +33,24 @@ def get_database_manager_client():
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
|
||||
"""
|
||||
All the tool calls entry in the conversation history requires a response.
|
||||
This function returns the pending tool calls that has not generated an output yet.
|
||||
|
||||
Return: dict[str, int] - A dictionary of pending tool call IDs with their count.
|
||||
"""
|
||||
pending_calls = Counter()
|
||||
for history in conversation_history:
|
||||
for call in history.get("tool_calls") or []:
|
||||
pending_calls[call.get("id")] += 1
|
||||
|
||||
if call_id := history.get("tool_call_id"):
|
||||
pending_calls[call_id] -= 1
|
||||
|
||||
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
|
||||
|
||||
|
||||
class SmartDecisionMakerBlock(Block):
|
||||
"""
|
||||
A block that uses a language model to make smart decisions based on a given prompt.
|
||||
@@ -57,6 +77,10 @@ class SmartDecisionMakerBlock(Block):
|
||||
default=[],
|
||||
description="The conversation history to provide context for the prompt.",
|
||||
)
|
||||
last_tool_output: Any = SchemaField(
|
||||
default=None,
|
||||
description="The output of the last tool that was called.",
|
||||
)
|
||||
retry: int = SchemaField(
|
||||
title="Retry Count",
|
||||
default=3,
|
||||
@@ -78,6 +102,32 @@ class SmartDecisionMakerBlock(Block):
|
||||
description="Ollama host for local models",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||
# conversation_history & last_tool_output validation is handled differently
|
||||
return super().get_missing_links(
|
||||
data,
|
||||
[
|
||||
link
|
||||
for link in links
|
||||
if link.sink_name
|
||||
not in ["conversation_history", "last_tool_output"]
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
if missing_input := super().get_missing_input(data):
|
||||
return missing_input
|
||||
|
||||
conversation_history = data.get("conversation_history", [])
|
||||
pending_tool_calls = get_pending_tool_calls(conversation_history)
|
||||
|
||||
last_tool_output = data.get("last_tool_output")
|
||||
if not last_tool_output and pending_tool_calls:
|
||||
return {"last_tool_output"}
|
||||
return set()
|
||||
|
||||
class Output(BlockSchema):
|
||||
error: str = SchemaField(description="Error message if the API call failed.")
|
||||
tools: Any = SchemaField(description="The tools that are available to use.")
|
||||
@@ -287,8 +337,31 @@ class SmartDecisionMakerBlock(Block):
|
||||
) -> BlockOutput:
|
||||
tool_functions = self._create_function_signature(node_id)
|
||||
|
||||
input_data.conversation_history = input_data.conversation_history or []
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history]
|
||||
|
||||
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
|
||||
if pending_tool_calls and not input_data.last_tool_output:
|
||||
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
|
||||
|
||||
# Prefill all missing tool calls with the last tool output/
|
||||
# TODO: we need a better way to handle this.
|
||||
tool_output = [
|
||||
{
|
||||
"role": "tool",
|
||||
"content": input_data.last_tool_output,
|
||||
"tool_call_id": pending_call_id,
|
||||
}
|
||||
for pending_call_id, count in pending_tool_calls.items()
|
||||
for _ in range(count)
|
||||
]
|
||||
if len(tool_output) > 1:
|
||||
logger.warning(
|
||||
f"[node_exec_id={node_exec_id}] Multiple pending tool calls are prefilled using a single output. Execution may not be accurate."
|
||||
)
|
||||
|
||||
prompt.extend(tool_output)
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
|
||||
@@ -312,14 +385,15 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
if not response.tool_calls:
|
||||
yield "finished", f"No Decision Made finishing task: {response.response}"
|
||||
else:
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
return
|
||||
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
input_data.conversation_history.extend(response.prompt)
|
||||
input_data.conversation_history.append(response.raw_response)
|
||||
yield "conversations", input_data.conversation_history
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
|
||||
|
||||
response.prompt.append(response.raw_response)
|
||||
yield "conversations", response.prompt
|
||||
yield "length", len(response.prompt)
|
||||
|
||||
@@ -2,6 +2,7 @@ import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Generator,
|
||||
@@ -28,6 +29,9 @@ from .model import (
|
||||
is_credentials_field_name,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .graph import Link
|
||||
|
||||
app_config = Config()
|
||||
|
||||
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
|
||||
@@ -187,6 +191,19 @@ class BlockSchema(BaseModel):
|
||||
)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data # Return as is, by default.
|
||||
|
||||
@classmethod
|
||||
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
|
||||
input_fields_from_nodes = {link.sink_name for link in links}
|
||||
return input_fields_from_nodes - set(data)
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
return cls.get_required_fields() - set(data)
|
||||
|
||||
|
||||
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
|
||||
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchema)
|
||||
@@ -407,7 +424,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
}
|
||||
|
||||
def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
# Merge the input data with the extra execution arguments, preferring the args for security
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
f"Unable to execute block with invalid input data: {error}"
|
||||
|
||||
@@ -415,20 +415,20 @@ class GraphModel(Graph):
|
||||
for link in self.links:
|
||||
source = (link.source_id, link.source_name)
|
||||
sink = (link.sink_id, link.sink_name)
|
||||
suffix = f"Link {source} <-> {sink}"
|
||||
prefix = f"Link {source} <-> {sink}"
|
||||
|
||||
for i, (node_id, name) in enumerate([source, sink]):
|
||||
node = node_map.get(node_id)
|
||||
if not node:
|
||||
raise ValueError(
|
||||
f"{suffix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
|
||||
f"{prefix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
|
||||
)
|
||||
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
blocks = {v().id: v().name for v in get_blocks().values()}
|
||||
raise ValueError(
|
||||
f"{suffix}, {node.block_id} is invalid block id, available blocks: {blocks}"
|
||||
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
|
||||
)
|
||||
|
||||
sanitized_name = sanitize(name)
|
||||
@@ -447,7 +447,7 @@ class GraphModel(Graph):
|
||||
)
|
||||
if sanitized_name not in fields and not name.startswith("tools_^_"):
|
||||
fields_msg = f"Allowed fields: {fields}"
|
||||
raise ValueError(f"{suffix}, `{name}` invalid, {fields_msg}")
|
||||
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
|
||||
|
||||
if is_static_output_block(link.source_id):
|
||||
link.is_static = True # Each value block output should be static.
|
||||
|
||||
@@ -412,46 +412,30 @@ def validate_exec(
|
||||
node_block: Block | None = get_block(node.block_id)
|
||||
if not node_block:
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
schema = node_block.input_schema
|
||||
|
||||
if isinstance(node_block, AgentExecutorBlock):
|
||||
# Validate the execution metadata for the agent executor block.
|
||||
try:
|
||||
exec_data = AgentExecutorBlock.Input(**node.input_default)
|
||||
except Exception as e:
|
||||
return None, f"Input data doesn't match {node_block.name}: {str(e)}"
|
||||
|
||||
# Validation input
|
||||
input_schema = exec_data.input_schema
|
||||
required_fields = set(input_schema["required"])
|
||||
input_default = exec_data.data
|
||||
else:
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in node_block.input_schema.__annotations__.items():
|
||||
if (value := data.get(name)) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Validation input
|
||||
input_schema = node_block.input_schema.jsonschema()
|
||||
required_fields = node_block.input_schema.get_required_fields()
|
||||
input_default = node.input_default
|
||||
# Convert non-matching data types to the expected input schema.
|
||||
for name, data_type in schema.__annotations__.items():
|
||||
if (value := data.get(name)) and (type(value) is not data_type):
|
||||
data[name] = convert(value, data_type)
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
|
||||
input_fields_from_nodes = {link.sink_name for link in node.input_links}
|
||||
if not input_fields_from_nodes.issubset(data):
|
||||
return None, f"{error_prefix} {input_fields_from_nodes - set(data)}"
|
||||
if missing_links := schema.get_missing_links(data, node.input_links):
|
||||
return None, f"{error_prefix} unpopulated links {missing_links}"
|
||||
|
||||
# Merge input data with default values and resolve dynamic dict/list/object pins.
|
||||
input_default = schema.get_input_defaults(node.input_default)
|
||||
data = {**input_default, **data}
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
if not required_fields.issubset(data):
|
||||
return None, f"{error_prefix} {required_fields - set(data)}"
|
||||
if missing_input := schema.get_missing_input(data):
|
||||
return None, f"{error_prefix} missing input {missing_input}"
|
||||
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := json.validate_with_jsonschema(schema=input_schema, data=data):
|
||||
if error := schema.validate_data(data):
|
||||
error_message = f"{error_prefix} {error}"
|
||||
logger.error(error_message)
|
||||
return None, error_message
|
||||
|
||||
Reference in New Issue
Block a user