This commit is contained in:
Zamil Majdy
2025-03-07 15:51:42 +07:00
parent 44409c035f
commit 131d9d2e84
5 changed files with 51 additions and 17 deletions

View File

@@ -186,7 +186,7 @@ class SmartDecisionMakerBlock(Block):
@classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
# conversation_history & last_tool_output validation is handled differently
return super().get_missing_links(
missing_links = super().get_missing_links(
data,
[
link
@@ -196,6 +196,19 @@ class SmartDecisionMakerBlock(Block):
],
)
# Avoid executing the block if the last_tool_output is connected to a static
# link, like StoreValueBlock or AgentInputBlock.
if any(link.sink_name == "conversation_history" for link in links) and any(
link.sink_name == "last_tool_output" and link.is_static
for link in links
):
raise ValueError(
"Last Tool Output can't be connected to a static (dashed line) "
"link like the output of `StoreValue` or `AgentInput` block"
)
return missing_links
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
if missing_input := super().get_missing_input(data):

View File

@@ -156,6 +156,10 @@ class CountdownTimerBlock(Block):
days: Union[int, str] = SchemaField(
advanced=False, description="Duration in days", default=0
)
repeat: int = SchemaField(
description="Number of times to repeat the timer",
default=1,
)
class Output(BlockSchema):
output_message: Any = SchemaField(
@@ -187,5 +191,6 @@ class CountdownTimerBlock(Block):
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
time.sleep(total_seconds)
yield "output_message", input_data.input_message
for _ in range(input_data.repeat):
time.sleep(total_seconds)
yield "output_message", input_data.input_message

View File

@@ -189,7 +189,7 @@ async def upsert_execution_input(
input_name: str,
input_data: Any,
node_exec_id: str | None = None,
) -> tuple[str, BlockInput]:
) -> tuple[ExecutionResult, BlockInput]:
"""
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Input.
If there is no AgentNodeExecution that has no `input_name` as input, create new one.
@@ -226,7 +226,7 @@ async def upsert_execution_input(
"referencedByInputExecId": existing_execution.id,
}
)
return existing_execution.id, {
return ExecutionResult.from_db(existing_execution), {
**{
input_data.name: type.convert(input_data.data, Type[Any])
for input_data in existing_execution.Input or []
@@ -243,7 +243,7 @@ async def upsert_execution_input(
"Input": {"create": {"name": input_name, "data": json_input_data}},
}
)
return result.id, {input_name: input_data}
return ExecutionResult.from_db(result), {input_name: input_data}
else:
raise ValueError(
@@ -535,7 +535,7 @@ async def get_output_from_links(
links: dict[str, tuple[str, str]], graph_eid: str
) -> BlockInput:
"""
Get the latest output from the inbound static links of a node.
Get the latest output from the graph links.
Args:
links: dict[node_id, (source_name, sink_name)] of the links to get the output from.
graph_eid: the id of the graph execution to get the output from.
@@ -561,7 +561,6 @@ async def get_output_from_links(
if value := execution.output_data.get(source_name):
latest_output[sink_name] = value[-1]
print(">>>>>>>>> from links", links, "latest_output", latest_output)
return latest_output

View File

@@ -109,7 +109,7 @@ class LogMetadata:
logger.exception(msg, extra={"json_fields": {**self.metadata, **extra}})
def _wrap(self, msg: str, **extra):
return f"{self.prefix} {msg} {extra or ""}"
return f"{self.prefix} {msg} {extra or ''}"
T = TypeVar("T")
@@ -292,6 +292,19 @@ def _enqueue_next_nodes(
data=data,
)
def validate_next_exec(
next_node_exec_id: str, next_node: Node, next_node_input: BlockInput
) -> tuple[BlockInput | None, str]:
try:
return validate_exec(next_node, next_node_input)
except Exception as e:
db_client.upsert_execution_output(next_node_exec_id, "error", str(e))
execution = db_client.update_execution_status(
next_node_exec_id, ExecutionStatus.FAILED
)
db_client.send_execution_update(execution)
return None, str(e)
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
enqueued_executions = []
next_output_name = node_link.source_name
@@ -309,12 +322,14 @@ def _enqueue_next_nodes(
# Or the same input to be consumed multiple times.
with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec_id, next_node_input = db_client.upsert_execution_input(
next_node_exec, next_node_input = db_client.upsert_execution_input(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
input_name=next_input_name,
input_data=next_data,
)
next_node_exec_id = next_node_exec.node_exec_id
db_client.send_execution_update(next_node_exec)
# Complete missing static input pins data using the last execution input.
static_links = {
@@ -331,7 +346,9 @@ def _enqueue_next_nodes(
next_node_input[name] = next_node_input.get(name, value)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
next_node_input, validation_msg = validate_next_exec(
next_node_exec_id, next_node, next_node_input
)
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
# Incomplete input data, skip queueing the execution.
@@ -365,7 +382,7 @@ def _enqueue_next_nodes(
for input_name, input_value in static_output.items():
idata[input_name] = idata.get(input_name, input_value)
idata, msg = validate_exec(next_node, idata)
idata, msg = validate_next_exec(next_node_exec_id, next_node, idata)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
if not idata:
log_metadata.info(f"Enqueueing static-link skipped: {suffix}")

View File

@@ -154,11 +154,11 @@ async def test_input_pin_always_waited(server: SpinTestServer):
even when default value on that pin is defined, the value has to be ignored.
Test scenario:
StoreValueBlock
StoreValueBlock1
\\ input
>------- FindInDictionaryBlock | input_default: key: "", input: {}
// key
AgentInputBlock
StoreValueBlock2
"""
logger.info("Starting test_input_pin_always_waited")
nodes = [
@@ -167,8 +167,8 @@ async def test_input_pin_always_waited(server: SpinTestServer):
input_default={"input": {"key1": "value1", "key2": "value2"}},
),
graph.Node(
block_id=AgentInputBlock().id,
input_default={"name": "input", "value": "key2"},
block_id=StoreValueBlock().id,
input_default={"input": "key2"},
),
graph.Node(
block_id=FindInDictionaryBlock().id,
@@ -185,7 +185,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[2].id,
source_name="result",
source_name="output",
sink_name="key",
),
]