mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
feat(backend): Add conversation history as Smart Decision Block output (#9540)
Now that the MVP for Smar Decision Block is available, we need to add this info in the block output: Messages (list) - The messages list sent to the LLM plus its generated response as the latest assistant entry. This is a single list of dictionaries in standard LLM API format). ### Changes 🏗️ * Add `conversations` output pin that populates the update `conversation_history` input pin with the assistant response. * Refactored `Smart Decision Block` to avoid downloading the whole graph on each execution, remove code duplication, declutter data fetching. * Minor UI issue on the smart decision block entry in the search bar. ### Checklist 📋 #### For code changes: - [ ] I have clearly listed my changes in the PR description - [ ] I have made a test plan - [ ] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [ ] ... <details> <summary>Example test plan</summary> - [ ] Create from scratch and execute an agent with at least 3 blocks - [ ] Import an agent from file upload, and confirm it executes correctly - [ ] Upload agent to marketplace - [ ] Import an agent from marketplace and confirm it executes correctly - [ ] Edit an agent from monitor, and confirm it executes correctly </details> #### For configuration changes: - [ ] `.env.example` is updated or already compatible with my changes - [ ] `docker-compose.yml` is updated or already compatible with my changes - [ ] I have included a list of my configuration changes in the PR description (under **Changes**) <details> <summary>Examples of configuration changes</summary> - Changing ports - Adding new services that need to communicate with each other - Secrets or environment variable changes - New or infrastructure changes such as databases </details> Co-authored-by: Swifty <craigswift13@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
|
||||
@@ -13,12 +13,12 @@ from backend.data.block import (
|
||||
BlockOutput,
|
||||
BlockSchema,
|
||||
BlockType,
|
||||
get_blocks,
|
||||
get_block,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.graph import Link, Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,6 +89,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
finished: str = SchemaField(
|
||||
description="The finished message to display to the user."
|
||||
)
|
||||
conversations: list[llm.Message] = SchemaField(
|
||||
description="The conversation history to provide context for the prompt."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -106,41 +109,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
test_credentials=llm.TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
# If I import Graph here, it will break with a circular import.
|
||||
def _get_tool_graph_metadata(self, node_id: str, graph: "Graph") -> List["Graph"]:
|
||||
"""
|
||||
Retrieves metadata for tool graphs linked to a specified node within a graph.
|
||||
|
||||
This method identifies the tool links connected to the given node_id and fetches
|
||||
the metadata for each linked tool graph from the database.
|
||||
|
||||
Args:
|
||||
node_id (str): The ID of the node for which tool graph metadata is to be retrieved.
|
||||
graph (Any): The graph object containing nodes and links.
|
||||
|
||||
Returns:
|
||||
List[Any]: A list of metadata for the tool graphs linked to the specified node.
|
||||
"""
|
||||
db_client = get_database_manager_client()
|
||||
graph_meta = []
|
||||
|
||||
tool_links = {
|
||||
link.sink_id
|
||||
for link in graph.links
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
}
|
||||
|
||||
for link_id in tool_links:
|
||||
node = next((node for node in graph.nodes if node.id == link_id), None)
|
||||
if node and node.block_id == AgentExecutorBlock().id:
|
||||
node_graph_meta = db_client.get_graph_metadata(
|
||||
node.input_default["graph_id"], node.input_default["graph_version"]
|
||||
)
|
||||
if node_graph_meta:
|
||||
graph_meta.append(node_graph_meta)
|
||||
|
||||
return graph_meta
|
||||
|
||||
@staticmethod
|
||||
def _create_block_function_signature(
|
||||
sink_node: "Node", links: list["Link"]
|
||||
@@ -158,20 +126,20 @@ class SmartDecisionMakerBlock(Block):
|
||||
Raises:
|
||||
ValueError: If the block specified by sink_node.block_id is not found.
|
||||
"""
|
||||
block = get_blocks()[sink_node.block_id]
|
||||
block = get_block(sink_node.block_id)
|
||||
if not block:
|
||||
raise ValueError(f"Block not found: {sink_node.block_id}")
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block().name).lower(),
|
||||
"description": block().description,
|
||||
"name": re.sub(r"[^a-zA-Z0-9_-]", "_", block.name).lower(),
|
||||
"description": block.description,
|
||||
}
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = block().input_schema
|
||||
sink_block_input_schema = block.input_schema
|
||||
description = (
|
||||
sink_block_input_schema.model_fields[link.sink_name].description
|
||||
if link.sink_name in sink_block_input_schema.model_fields
|
||||
@@ -195,7 +163,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
|
||||
@staticmethod
|
||||
def _create_agent_function_signature(
|
||||
sink_node: "Node", links: list["Link"], tool_graph_metadata: list["Graph"]
|
||||
sink_node: "Node", links: list["Link"]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Creates a function signature for an agent node.
|
||||
@@ -203,7 +171,6 @@ class SmartDecisionMakerBlock(Block):
|
||||
Args:
|
||||
sink_node: The agent node for which to create a function signature.
|
||||
links: The list of links connected to the sink node.
|
||||
tool_graph_metadata: List of metadata for available tool graphs.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the function signature in the format expected by LLM tools.
|
||||
@@ -211,18 +178,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
Raises:
|
||||
ValueError: If the graph metadata for the specified graph_id and graph_version is not found.
|
||||
"""
|
||||
graph_id = sink_node.input_default["graph_id"]
|
||||
graph_version = sink_node.input_default["graph_version"]
|
||||
|
||||
sink_graph_meta = next(
|
||||
(
|
||||
meta
|
||||
for meta in tool_graph_metadata
|
||||
if meta.id == graph_id and meta.version == graph_version
|
||||
),
|
||||
None,
|
||||
)
|
||||
graph_id = sink_node.input_default.get("graph_id")
|
||||
graph_version = sink_node.input_default.get("graph_version")
|
||||
if not graph_id or not graph_version:
|
||||
raise ValueError("Graph ID or Graph Version not found in sink node.")
|
||||
|
||||
db_client = get_database_manager_client()
|
||||
sink_graph_meta = db_client.get_graph_metadata(graph_id.graph_version)
|
||||
if not sink_graph_meta:
|
||||
raise ValueError(
|
||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||
@@ -260,12 +222,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
def _create_function_signature(
|
||||
# If I import Graph here, it will break with a circular import.
|
||||
node_id: str,
|
||||
graph: "Graph",
|
||||
tool_graph_metadata: List["Graph"],
|
||||
) -> list[dict[str, Any]]:
|
||||
def _create_function_signature(node_id: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Creates function signatures for tools linked to a specified node within a graph.
|
||||
|
||||
@@ -274,10 +231,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool based on the metadata and input schema of the linked nodes.
|
||||
|
||||
Args:
|
||||
node_id (str): The ID of the node for which tool function signatures are to be created.
|
||||
graph (Any): The graph object containing nodes and links.
|
||||
tool_graph_metadata (List[Any]): Metadata for the tool graphs, used to retrieve
|
||||
names and descriptions for the tools.
|
||||
node_id: The node_id for which to create function signatures.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: A list of dictionaries, each representing a function signature
|
||||
@@ -287,44 +241,32 @@ class SmartDecisionMakerBlock(Block):
|
||||
ValueError: If no tool links are found for the specified node_id, or if a sink node
|
||||
or its metadata cannot be found.
|
||||
"""
|
||||
# Filter the graph links to find those that are tools and are linked to the specified node_id
|
||||
tool_links = [
|
||||
link
|
||||
for link in graph.links
|
||||
# NOTE: Maybe we can do a specific database call to only get relevant nodes
|
||||
# async def get_connected_output_nodes(source_node_id: str) -> list[Node]:
|
||||
# links = await AgentNodeLink.prisma().find_many(
|
||||
# where={"agentNodeSourceId": source_node_id},
|
||||
# include={"AgentNode": {"include": AGENT_NODE_INCLUDE}},
|
||||
# )
|
||||
# return [NodeModel.from_db(link.AgentNodeSink) for link in links]
|
||||
db_client = get_database_manager_client()
|
||||
tools = [
|
||||
(link, node)
|
||||
for link, node in db_client.get_connected_output_nodes(node_id)
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
]
|
||||
|
||||
if not tool_links:
|
||||
raise ValueError(
|
||||
f"Expected at least one tool link in the graph. Node ID: {node_id}. Graph: {graph.links}"
|
||||
)
|
||||
if not tools:
|
||||
raise ValueError("There is no next node to execute.")
|
||||
|
||||
return_tool_functions = []
|
||||
|
||||
grouped_tool_links = {}
|
||||
|
||||
for link in tool_links:
|
||||
grouped_tool_links.setdefault(link.sink_id, []).append(link)
|
||||
|
||||
for _, links in grouped_tool_links.items():
|
||||
sink_node = next(
|
||||
(node for node in graph.nodes if node.id == links[0].sink_id), None
|
||||
)
|
||||
grouped_tool_links: dict[str, tuple["Node", list["Link"]]] = {}
|
||||
for link, node in tools:
|
||||
if link.sink_id not in grouped_tool_links:
|
||||
grouped_tool_links[link.sink_id] = (node, [link])
|
||||
else:
|
||||
grouped_tool_links[link.sink_id][1].append(link)
|
||||
|
||||
for sink_node, links in grouped_tool_links.values():
|
||||
if not sink_node:
|
||||
raise ValueError(f"Sink node not found: {links[0].sink_id}")
|
||||
|
||||
if sink_node.block_id == AgentExecutorBlock().id:
|
||||
return_tool_functions.append(
|
||||
SmartDecisionMakerBlock._create_agent_function_signature(
|
||||
sink_node, links, tool_graph_metadata
|
||||
sink_node, links
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -348,21 +290,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
db_client = get_database_manager_client()
|
||||
|
||||
# Retrieve the current graph and node details
|
||||
graph = db_client.get_graph(graph_id=graph_id, user_id=user_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError(
|
||||
f"The currently running graph that is executing this node is not found {graph_id}"
|
||||
)
|
||||
|
||||
tool_graph_metadata = self._get_tool_graph_metadata(node_id, graph)
|
||||
|
||||
tool_functions = self._create_function_signature(
|
||||
node_id, graph, tool_graph_metadata
|
||||
)
|
||||
tool_functions = self._create_function_signature(node_id)
|
||||
|
||||
prompt = [p.model_dump() for p in input_data.conversation_history]
|
||||
|
||||
@@ -388,13 +316,22 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
|
||||
yield "finished", f"No Decision Made finishing task: {response.response}"
|
||||
|
||||
if response.tool_calls:
|
||||
assistant_response = response.response
|
||||
else:
|
||||
for tool_call in response.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_{arg_name}".lower(), arg_value
|
||||
|
||||
assistant_response = "\n".join(
|
||||
f"[{c.function.name}] called with arguments: {c.function.arguments}"
|
||||
for c in response.tool_calls
|
||||
)
|
||||
|
||||
input_data.conversation_history.append(
|
||||
llm.Message(role=llm.MessageRole.ASSISTANT, content=assistant_response)
|
||||
)
|
||||
yield "conversations", input_data.conversation_history
|
||||
|
||||
@@ -72,7 +72,7 @@ class NodeModel(Node):
|
||||
webhook: Optional[Webhook] = None
|
||||
|
||||
@staticmethod
|
||||
def from_db(node: AgentNode):
|
||||
def from_db(node: AgentNode) -> "NodeModel":
|
||||
obj = NodeModel(
|
||||
id=node.id,
|
||||
block_id=node.agentBlockId,
|
||||
@@ -711,6 +711,18 @@ async def get_graph(
|
||||
return GraphModel.from_db(graph, for_export)
|
||||
|
||||
|
||||
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
|
||||
links = await AgentNodeLink.prisma().find_many(
|
||||
where={"agentNodeSourceId": node_id},
|
||||
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, # type: ignore
|
||||
)
|
||||
return [
|
||||
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
|
||||
for link in links
|
||||
if link.AgentNodeSink
|
||||
]
|
||||
|
||||
|
||||
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
|
||||
# Activate the requested version if it exists and is owned by the user.
|
||||
updated_count = await AgentGraph.prisma().update_many(
|
||||
|
||||
@@ -14,7 +14,12 @@ from backend.data.execution import (
|
||||
upsert_execution_input,
|
||||
upsert_execution_output,
|
||||
)
|
||||
from backend.data.graph import get_graph, get_graph_metadata, get_node
|
||||
from backend.data.graph import (
|
||||
get_connected_output_nodes,
|
||||
get_graph,
|
||||
get_graph_metadata,
|
||||
get_node,
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_user_integrations,
|
||||
get_user_metadata,
|
||||
@@ -64,6 +69,7 @@ class DatabaseManager(AppService):
|
||||
# Graphs
|
||||
get_node = exposed_run_and_wait(get_node)
|
||||
get_graph = exposed_run_and_wait(get_graph)
|
||||
get_connected_output_nodes = exposed_run_and_wait(get_connected_output_nodes)
|
||||
get_graph_metadata = exposed_run_and_wait(get_graph_metadata)
|
||||
|
||||
# Credits
|
||||
|
||||
@@ -230,7 +230,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = SmartDecisionMakerBlock._create_function_signature(
|
||||
test_graph.nodes[0].id, test_graph, [test_tool_graph]
|
||||
test_graph.nodes[0].id
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
|
||||
|
||||
@@ -219,7 +219,7 @@ export const BlocksControl: React.FC<BlocksControlProps> = ({
|
||||
</CardHeader>
|
||||
<CardContent className="overflow-scroll border-t border-t-gray-200 p-0 dark:border-t-slate-700">
|
||||
<ScrollArea
|
||||
className="h-[60vh] w-fit"
|
||||
className="h-[60vh] w-full"
|
||||
data-id="blocks-control-scroll-area"
|
||||
>
|
||||
{filteredAvailableBlocks.map((block) => (
|
||||
|
||||
Reference in New Issue
Block a user