feat(backend): Add tool execution response on Smart Decision Block (#9552)

**Proposal / Request For Comments** 

The tool-calling API provided by the LLM provider requires the output
generated by each tool to be looped back as part of the conversation
history. This PR addresses that requirement so the feature behaves as
expected.

### Changes 🏗️

* `Last Tool Output` is introduced to loop the output of the tool
back to the Smart Decision Block.
* Smart Decision Block execution will stay pending until the `Last Tool
Output` is provided whenever a pending tool call is present in the
conversation history.
* **Known hack**: The last tool output will prefill all the pending tool
calls from the conversation history.
* A few tweaks were required to allow a blocking loop to be executed
without awaiting on all the inbound links.

<img width="1395" alt="image"
src="https://github.com/user-attachments/assets/fdad4407-621b-45d0-a457-76b2d4c853b9"
/>


### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [ ] ...

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
  - [ ] Import an agent from file upload, and confirm it executes correctly
  - [ ] Upload agent to marketplace
  - [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my
changes
- [ ] I have included a list of my configuration changes in the PR
description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or changed infrastructure such as databases
</details>
This commit is contained in:
Zamil Majdy
2025-03-02 17:46:41 +07:00
committed by GitHub
parent 1e31136358
commit 36447a0c58
5 changed files with 134 additions and 41 deletions

View File

@@ -1,4 +1,5 @@
import logging
from typing import Any
from autogpt_libs.utils.cache import thread_cached
@@ -13,6 +14,7 @@ from backend.data.block import (
)
from backend.data.execution import ExecutionStatus
from backend.data.model import SchemaField
from backend.util import json
logger = logging.getLogger(__name__)
@@ -42,6 +44,23 @@ class AgentExecutorBlock(Block):
input_schema: dict = SchemaField(description="Input schema for the graph")
output_schema: dict = SchemaField(description="Output schema for the graph")
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
    """
    Return the JSON schema stored under the `input_schema` key of `data`.

    Args:
        data: Raw block input; only its `input_schema` entry is read.

    Returns:
        The stored schema, or an empty dict when the entry is missing or
        falsy (for example explicitly set to `None`).
    """
    # `data.get(key, {})` only covers a missing key. An explicit `None` would
    # be returned as-is and break callers that use dict methods on the result.
    return data.get("input_schema") or {}
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
    """
    Return the default input values stored under the `data` key of `data`.

    Args:
        data: Raw block input; only its `data` entry is read.

    Returns:
        The stored defaults, or an empty dict when the entry is missing or
        falsy (for example explicitly set to `None`).
    """
    # The result is meant to be a mapping of defaults, so never hand back
    # `None`: `data.get(key, {})` would do so for an explicit `None` value.
    return data.get("data") or {}
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
    """
    Return the required field names that have no entry in `data`.

    The required names are read from the `required` list of the schema
    returned by `get_input_schema(data)`; a schema without that key
    requires nothing.
    """
    schema = cls.get_input_schema(data)
    return {
        field_name
        for field_name in schema.get("required", [])
        if field_name not in data
    }
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
    """
    Validate `data` against the schema returned by `get_input_schema(data)`.

    Returns:
        Whatever `json.validate_with_jsonschema` returns for that schema and
        data; callers treat a truthy result as a validation error.
    """
    schema = cls.get_input_schema(data)
    return json.validate_with_jsonschema(schema, data)
class Output(BlockSchema):
    # Empty schema: this class declares no output fields of its own.
    pass

View File

@@ -1,5 +1,6 @@
import logging
import re
from collections import Counter
from typing import TYPE_CHECKING, Any
from autogpt_libs.utils.cache import thread_cached
@@ -9,6 +10,7 @@ from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import (
Block,
BlockCategory,
BlockInput,
BlockOutput,
BlockSchema,
BlockType,
@@ -31,6 +33,24 @@ def get_database_manager_client():
return get_service_client(DatabaseManager)
def get_pending_tool_calls(conversation_history: list[Any]) -> dict[str, int]:
"""
All the tool calls entry in the conversation history requires a response.
This function returns the pending tool calls that has not generated an output yet.
Return: dict[str, int] - A dictionary of pending tool call IDs with their count.
"""
pending_calls = Counter()
for history in conversation_history:
for call in history.get("tool_calls") or []:
pending_calls[call.get("id")] += 1
if call_id := history.get("tool_call_id"):
pending_calls[call_id] -= 1
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
class SmartDecisionMakerBlock(Block):
"""
A block that uses a language model to make smart decisions based on a given prompt.
@@ -57,6 +77,10 @@ class SmartDecisionMakerBlock(Block):
default=[],
description="The conversation history to provide context for the prompt.",
)
last_tool_output: Any = SchemaField(
default=None,
description="The output of the last tool that was called.",
)
retry: int = SchemaField(
title="Retry Count",
default=3,
@@ -78,6 +102,32 @@ class SmartDecisionMakerBlock(Block):
description="Ollama host for local models",
)
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
    """
    Return the linked input names that are not populated in `data`.

    Links whose sink is `conversation_history` or `last_tool_output` are
    left out before delegating to the parent implementation.
    """
    # conversation_history & last_tool_output validation is handled differently
    deferred_sinks = ["conversation_history", "last_tool_output"]
    checked_links = []
    for link in links:
        if link.sink_name in deferred_sinks:
            continue
        checked_links.append(link)
    return super().get_missing_links(data, checked_links)
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
    """
    Return the names of the inputs that are still missing from `data`.

    The parent check runs first and its result is returned when it reports
    anything. Otherwise `last_tool_output` is reported as missing when the
    conversation history contains pending tool calls and no tool output has
    been provided.

    Returns:
        The missing input names, or an empty set when nothing is missing.
    """
    if missing_input := super().get_missing_input(data):
        return missing_input

    # Treat a history that is explicitly None like an empty one, matching the
    # `or []` normalisation applied to the conversation history in `run()`.
    conversation_history = data.get("conversation_history") or []
    pending_tool_calls = get_pending_tool_calls(conversation_history)
    last_tool_output = data.get("last_tool_output")
    if pending_tool_calls and not last_tool_output:
        return {"last_tool_output"}
    return set()
class Output(BlockSchema):
error: str = SchemaField(description="Error message if the API call failed.")
tools: Any = SchemaField(description="The tools that are available to use.")
@@ -287,8 +337,31 @@ class SmartDecisionMakerBlock(Block):
) -> BlockOutput:
tool_functions = self._create_function_signature(node_id)
input_data.conversation_history = input_data.conversation_history or []
prompt = [json.to_dict(p) for p in input_data.conversation_history]
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
if pending_tool_calls and not input_data.last_tool_output:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
# Prefill all missing tool calls with the last tool output.
# TODO: we need a better way to handle this.
tool_output = [
{
"role": "tool",
"content": input_data.last_tool_output,
"tool_call_id": pending_call_id,
}
for pending_call_id, count in pending_tool_calls.items()
for _ in range(count)
]
if len(tool_output) > 1:
logger.warning(
f"[node_exec_id={node_exec_id}] Multiple pending tool calls are prefilled using a single output. Execution may not be accurate."
)
prompt.extend(tool_output)
values = input_data.prompt_values
if values:
input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
@@ -312,14 +385,15 @@ class SmartDecisionMakerBlock(Block):
if not response.tool_calls:
yield "finished", f"No Decision Made finishing task: {response.response}"
else:
for tool_call in response.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
return
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
for tool_call in response.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
input_data.conversation_history.extend(response.prompt)
input_data.conversation_history.append(response.raw_response)
yield "conversations", input_data.conversation_history
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
response.prompt.append(response.raw_response)
yield "conversations", response.prompt
yield "length", len(response.prompt)

View File

@@ -2,6 +2,7 @@ import inspect
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generator,
@@ -28,6 +29,9 @@ from .model import (
is_credentials_field_name,
)
if TYPE_CHECKING:
from .graph import Link
app_config = Config()
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
@@ -187,6 +191,19 @@ class BlockSchema(BaseModel):
)
}
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
    """
    Return the default input values to use for `data`.

    This implementation hands `data` back unchanged.
    """
    return data  # Return as is, by default.
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
    """
    Return the sink names of `links` that have no entry in `data`.

    Args:
        data: Input values collected so far, keyed by input name.
        links: Inbound links; only each link's `sink_name` is read.
    """
    populated = set(data)
    return {
        link.sink_name for link in links if link.sink_name not in populated
    }
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
    """
    Return the required fields that have no entry in `data`.

    The required names come from `cls.get_required_fields()`.
    """
    required_fields = cls.get_required_fields()
    provided_fields = set(data)
    return required_fields - provided_fields
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchema)
@@ -407,7 +424,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
}
def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
# Merge the input data with the extra execution arguments, preferring the args for security
if error := self.input_schema.validate_data(input_data):
raise ValueError(
f"Unable to execute block with invalid input data: {error}"

View File

@@ -415,20 +415,20 @@ class GraphModel(Graph):
for link in self.links:
source = (link.source_id, link.source_name)
sink = (link.sink_id, link.sink_name)
suffix = f"Link {source} <-> {sink}"
prefix = f"Link {source} <-> {sink}"
for i, (node_id, name) in enumerate([source, sink]):
node = node_map.get(node_id)
if not node:
raise ValueError(
f"{suffix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
f"{prefix}, {node_id} is invalid node id, available nodes: {node_map.keys()}"
)
block = get_block(node.block_id)
if not block:
blocks = {v().id: v().name for v in get_blocks().values()}
raise ValueError(
f"{suffix}, {node.block_id} is invalid block id, available blocks: {blocks}"
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
)
sanitized_name = sanitize(name)
@@ -447,7 +447,7 @@ class GraphModel(Graph):
)
if sanitized_name not in fields and not name.startswith("tools_^_"):
fields_msg = f"Allowed fields: {fields}"
raise ValueError(f"{suffix}, `{name}` invalid, {fields_msg}")
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
if is_static_output_block(link.source_id):
link.is_static = True # Each value block output should be static.

View File

@@ -412,46 +412,30 @@ def validate_exec(
node_block: Block | None = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
if isinstance(node_block, AgentExecutorBlock):
# Validate the execution metadata for the agent executor block.
try:
exec_data = AgentExecutorBlock.Input(**node.input_default)
except Exception as e:
return None, f"Input data doesn't match {node_block.name}: {str(e)}"
# Validation input
input_schema = exec_data.input_schema
required_fields = set(input_schema["required"])
input_default = exec_data.data
else:
# Convert non-matching data types to the expected input schema.
for name, data_type in node_block.input_schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Validation input
input_schema = node_block.input_schema.jsonschema()
required_fields = node_block.input_schema.get_required_fields()
input_default = node.input_default
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
if (value := data.get(name)) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
input_fields_from_nodes = {link.sink_name for link in node.input_links}
if not input_fields_from_nodes.issubset(data):
return None, f"{error_prefix} {input_fields_from_nodes - set(data)}"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Input data post-merge should contain all required fields from the schema.
if not required_fields.issubset(data):
return None, f"{error_prefix} {required_fields - set(data)}"
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
# Last validation: Validate the input values against the schema.
if error := json.validate_with_jsonschema(schema=input_schema, data=data):
if error := schema.validate_data(data):
error_message = f"{error_prefix} {error}"
logger.error(error_message)
return None, error_message