mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-20 20:48:11 -05:00
fix(backend): Fix static link resolving behaviour on concurrent output
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from multiprocessing import Manager
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, Type, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Generic, Type, TypeVar
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
@@ -20,6 +20,9 @@ from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util import mock, type
|
||||
from backend.util.settings import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class GraphExecutionEntry(BaseModel):
|
||||
user_id: str
|
||||
@@ -528,19 +531,38 @@ def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
return data
|
||||
|
||||
|
||||
async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult | None:
|
||||
execution = await AgentNodeExecution.prisma().find_first(
|
||||
async def get_output_from_links(
|
||||
links: dict[str, tuple[str, str]], graph_eid: str
|
||||
) -> BlockInput:
|
||||
"""
|
||||
Get the latest output from the inbound static links of a node.
|
||||
Args:
|
||||
links: dict[node_id, (source_name, sink_name)] of the links to get the output from.
|
||||
graph_eid: the id of the graph execution to get the output from.
|
||||
|
||||
Returns:
|
||||
BlockInput: a dict of the latest output from the links.
|
||||
"""
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={
|
||||
"agentNodeId": node_id,
|
||||
"agentNodeId": {"in": list(links.keys())},
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
},
|
||||
order={"queuedTime": "desc"},
|
||||
distinct=["agentNodeId"],
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not execution:
|
||||
return None
|
||||
return ExecutionResult.from_db(execution)
|
||||
|
||||
latest_output = {}
|
||||
for e in executions:
|
||||
execution = ExecutionResult.from_db(e)
|
||||
source_name, sink_name = links[execution.node_id]
|
||||
if value := execution.output_data.get(source_name):
|
||||
latest_output[sink_name] = value[-1]
|
||||
|
||||
print(">>>>>>>>> from links", links, "latest_output", latest_output)
|
||||
return latest_output
|
||||
|
||||
|
||||
async def get_incomplete_executions(
|
||||
|
||||
@@ -6,7 +6,7 @@ from backend.data.execution import (
|
||||
create_graph_execution,
|
||||
get_execution_results,
|
||||
get_incomplete_executions,
|
||||
get_latest_execution,
|
||||
get_output_from_links,
|
||||
update_execution_status,
|
||||
update_graph_execution_start_time,
|
||||
update_graph_execution_stats,
|
||||
@@ -56,7 +56,7 @@ class DatabaseManager(AppService):
|
||||
create_graph_execution = exposed_run_and_wait(create_graph_execution)
|
||||
get_execution_results = exposed_run_and_wait(get_execution_results)
|
||||
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
|
||||
get_latest_execution = exposed_run_and_wait(get_latest_execution)
|
||||
get_output_from_links = exposed_run_and_wait(get_output_from_links)
|
||||
update_execution_status = exposed_run_and_wait(update_execution_status)
|
||||
update_graph_execution_start_time = exposed_run_and_wait(
|
||||
update_graph_execution_start_time
|
||||
|
||||
@@ -109,7 +109,7 @@ class LogMetadata:
|
||||
logger.exception(msg, extra={"json_fields": {**self.metadata, **extra}})
|
||||
|
||||
def _wrap(self, msg: str, **extra):
|
||||
return f"{self.prefix} {msg} {extra}"
|
||||
return f"{self.prefix} {msg} {extra or ""}"
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -317,18 +317,18 @@ def _enqueue_next_nodes(
|
||||
)
|
||||
|
||||
# Complete missing static input pins data using the last execution input.
|
||||
static_link_names = {
|
||||
link.sink_name
|
||||
static_links = {
|
||||
link.source_id: (link.source_name, link.sink_name)
|
||||
for link in next_node.input_links
|
||||
if link.is_static and link.sink_name not in next_node_input
|
||||
if link.is_static
|
||||
}
|
||||
if static_link_names and (
|
||||
latest_execution := db_client.get_latest_execution(
|
||||
next_node_id, graph_exec_id
|
||||
)
|
||||
):
|
||||
for name in static_link_names:
|
||||
next_node_input[name] = latest_execution.input_data.get(name)
|
||||
static_output = (
|
||||
db_client.get_output_from_links(static_links, graph_exec_id)
|
||||
if static_links
|
||||
else {}
|
||||
)
|
||||
for name, value in static_output.items():
|
||||
next_node_input[name] = next_node_input.get(name, value)
|
||||
|
||||
# Validate the input data for the next node.
|
||||
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
|
||||
@@ -362,13 +362,8 @@ def _enqueue_next_nodes(
|
||||
idata = iexec.input_data
|
||||
ineid = iexec.node_exec_id
|
||||
|
||||
static_link_names = {
|
||||
link.sink_name
|
||||
for link in next_node.input_links
|
||||
if link.is_static and link.sink_name not in idata
|
||||
}
|
||||
for input_name in static_link_names:
|
||||
idata[input_name] = next_node_input[input_name]
|
||||
for input_name, input_value in static_output.items():
|
||||
idata[input_name] = idata.get(input_name, input_value)
|
||||
|
||||
idata, msg = validate_exec(next_node, idata)
|
||||
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
|
||||
|
||||
@@ -154,11 +154,11 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
||||
even when default value on that pin is defined, the value has to be ignored.
|
||||
|
||||
Test scenario:
|
||||
StoreValueBlock1
|
||||
StoreValueBlock
|
||||
\\ input
|
||||
>------- FindInDictionaryBlock | input_default: key: "", input: {}
|
||||
// key
|
||||
StoreValueBlock2
|
||||
AgentInputBlock
|
||||
"""
|
||||
logger.info("Starting test_input_pin_always_waited")
|
||||
nodes = [
|
||||
@@ -167,8 +167,8 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
||||
input_default={"input": {"key1": "value1", "key2": "value2"}},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={"input": "key2"},
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "input", "value": "key2"},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=FindInDictionaryBlock().id,
|
||||
@@ -185,7 +185,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
||||
graph.Link(
|
||||
source_id=nodes[1].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="output",
|
||||
source_name="result",
|
||||
sink_name="key",
|
||||
),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user