fix(backend): Fix static link resolving behaviour on concurrent output

This commit is contained in:
Zamil Majdy
2025-03-06 23:30:03 +07:00
parent 5a9235bcf9
commit 4e4a047a40
4 changed files with 49 additions and 32 deletions

View File

@@ -1,7 +1,7 @@
from collections import defaultdict
from datetime import datetime, timezone
from multiprocessing import Manager
from typing import Any, AsyncGenerator, Generator, Generic, Type, TypeVar
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Generic, Type, TypeVar
from prisma import Json
from prisma.enums import AgentExecutionStatus
@@ -20,6 +20,9 @@ from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock, type
from backend.util.settings import Config
if TYPE_CHECKING:
pass
class GraphExecutionEntry(BaseModel):
user_id: str
@@ -528,19 +531,38 @@ def merge_execution_input(data: BlockInput) -> BlockInput:
return data
async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult | None:
execution = await AgentNodeExecution.prisma().find_first(
async def get_output_from_links(
links: dict[str, tuple[str, str]], graph_eid: str
) -> BlockInput:
"""
Get the latest output from the inbound static links of a node.
Args:
links: dict[node_id, (source_name, sink_name)] of the links to get the output from.
graph_eid: the id of the graph execution to get the output from.
Returns:
BlockInput: a dict of the latest output from the links.
"""
executions = await AgentNodeExecution.prisma().find_many(
where={
"agentNodeId": node_id,
"agentNodeId": {"in": list(links.keys())},
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
},
order={"queuedTime": "desc"},
distinct=["agentNodeId"],
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:
return None
return ExecutionResult.from_db(execution)
latest_output = {}
for e in executions:
execution = ExecutionResult.from_db(e)
source_name, sink_name = links[execution.node_id]
if value := execution.output_data.get(source_name):
latest_output[sink_name] = value[-1]
print(">>>>>>>>> from links", links, "latest_output", latest_output)
return latest_output
async def get_incomplete_executions(

View File

@@ -6,7 +6,7 @@ from backend.data.execution import (
create_graph_execution,
get_execution_results,
get_incomplete_executions,
get_latest_execution,
get_output_from_links,
update_execution_status,
update_graph_execution_start_time,
update_graph_execution_stats,
@@ -56,7 +56,7 @@ class DatabaseManager(AppService):
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_execution_results = exposed_run_and_wait(get_execution_results)
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_latest_execution = exposed_run_and_wait(get_latest_execution)
get_output_from_links = exposed_run_and_wait(get_output_from_links)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time

View File

@@ -109,7 +109,7 @@ class LogMetadata:
logger.exception(msg, extra={"json_fields": {**self.metadata, **extra}})
def _wrap(self, msg: str, **extra):
return f"{self.prefix} {msg} {extra}"
return f"{self.prefix} {msg} {extra or ""}"
T = TypeVar("T")
@@ -317,18 +317,18 @@ def _enqueue_next_nodes(
)
# Complete missing static input pins data using the last execution input.
static_link_names = {
link.sink_name
static_links = {
link.source_id: (link.source_name, link.sink_name)
for link in next_node.input_links
if link.is_static and link.sink_name not in next_node_input
if link.is_static
}
if static_link_names and (
latest_execution := db_client.get_latest_execution(
next_node_id, graph_exec_id
)
):
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)
static_output = (
db_client.get_output_from_links(static_links, graph_exec_id)
if static_links
else {}
)
for name, value in static_output.items():
next_node_input[name] = next_node_input.get(name, value)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
@@ -362,13 +362,8 @@ def _enqueue_next_nodes(
idata = iexec.input_data
ineid = iexec.node_exec_id
static_link_names = {
link.sink_name
for link in next_node.input_links
if link.is_static and link.sink_name not in idata
}
for input_name in static_link_names:
idata[input_name] = next_node_input[input_name]
for input_name, input_value in static_output.items():
idata[input_name] = idata.get(input_name, input_value)
idata, msg = validate_exec(next_node, idata)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"

View File

@@ -154,11 +154,11 @@ async def test_input_pin_always_waited(server: SpinTestServer):
even when default value on that pin is defined, the value has to be ignored.
Test scenario:
StoreValueBlock1
StoreValueBlock
\\ input
>------- FindInDictionaryBlock | input_default: key: "", input: {}
// key
StoreValueBlock2
AgentInputBlock
"""
logger.info("Starting test_input_pin_always_waited")
nodes = [
@@ -167,8 +167,8 @@ async def test_input_pin_always_waited(server: SpinTestServer):
input_default={"input": {"key1": "value1", "key2": "value2"}},
),
graph.Node(
block_id=StoreValueBlock().id,
input_default={"input": "key2"},
block_id=AgentInputBlock().id,
input_default={"name": "input", "value": "key2"},
),
graph.Node(
block_id=FindInDictionaryBlock().id,
@@ -185,7 +185,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[2].id,
source_name="output",
source_name="result",
sink_name="key",
),
]