feat(blocks): Improve SmartDecisionBlock & AIStructuredResponseGeneratorBlock (#10198)

Main issues:
* `AIStructuredResponseGeneratorBlock` is not able to produce a list of
objects.
* `SmartDecisionMakerBlock` is not able to call tools with some optional
inputs.

### Changes 🏗️

* Allow persisting `null` / `None` values as execution output.
* Provide a `multiple_tool_calls` option for `SmartDecisionMakerBlock`.
* Provide a `list_result` option for `AIStructuredResponseGeneratorBlock`.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Run `SmartDecisionMakerBlock` & `AIStructuredResponseGeneratorBlock`
This commit is contained in:
Zamil Majdy
2025-06-20 07:14:02 -07:00
committed by GitHub
parent aab40fe225
commit 3df6dcd26b
8 changed files with 117 additions and 61 deletions

View File

@@ -163,7 +163,7 @@ class IfInputMatchesBlock(Block):
},
{
"input": 10,
"value": None,
"value": "None",
"yes_value": "Yes",
"no_value": "No",
},

View File

@@ -663,6 +663,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
description="Expected format of the response. If provided, the response will be validated against this format. "
"The keys should be the expected fields in the response, and the values should be the description of the field.",
)
list_result: bool = SchemaField(
title="List Result",
default=False,
description="Whether the response should be a list of objects in the expected format.",
)
model: LlmModel = SchemaField(
title="LLM Model",
default=LlmModel.GPT4O,
@@ -702,7 +707,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
class Output(BlockSchema):
response: dict[str, Any] = SchemaField(
response: dict[str, Any] | list[dict[str, Any]] = SchemaField(
description="The response object generated by the language model."
)
prompt: list = SchemaField(description="The prompt sent to the language model.")
@@ -793,13 +798,22 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
expected_format = [
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
]
format_prompt = ",\n ".join(expected_format)
if input_data.list_result:
format_prompt = (
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
)
else:
format_prompt = "\n ".join(expected_format)
sys_prompt = trim_prompt(
f"""
|Reply strictly only in the following JSON format:
|{{
| {format_prompt}
|}}
|
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
"""
)
prompt.append({"role": "system", "content": sys_prompt})
@@ -807,17 +821,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
if input_data.prompt:
prompt.append({"role": "user", "content": input_data.prompt})
def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
def validate_response(parsed: object) -> str | None:
try:
parsed = json.loads(resp)
if not isinstance(parsed, dict):
return {}, f"Expected a dictionary, but got {type(parsed)}"
return f"Expected a dictionary, but got {type(parsed)}"
miss_keys = set(input_data.expected_format.keys()) - set(parsed.keys())
if miss_keys:
return parsed, f"Missing keys: {miss_keys}"
return parsed, None
return f"Missing keys: {miss_keys}"
return None
except JSONDecodeError as e:
return {}, f"JSON decode error: {e}"
return f"JSON decode error: {e}"
logger.info(f"LLM request: {prompt}")
retry_prompt = ""
@@ -843,18 +856,29 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:
parsed_dict, parsed_error = parse_response(response_text)
if not parsed_error:
yield "response", {
k: (
json.loads(v)
if isinstance(v, str)
and v.startswith("[")
and v.endswith("]")
else (", ".join(v) if isinstance(v, list) else v)
response_obj = json.loads(response_text)
if input_data.list_result and isinstance(response_obj, dict):
if "results" in response_obj:
response_obj = response_obj.get("results", [])
elif len(response_obj) == 1:
response_obj = list(response_obj.values())
response_error = "\n".join(
[
validation_error
for response_item in (
response_obj
if isinstance(response_obj, list)
else [response_obj]
)
for k, v in parsed_dict.items()
}
if (validation_error := validate_response(response_item))
]
)
if not response_error:
yield "response", response_obj
yield "prompt", self.prompt
return
else:
@@ -871,7 +895,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
|And this is the error:
|--
|{parsed_error}
|{response_error}
|--
"""
)

View File

@@ -142,6 +142,12 @@ class SmartDecisionMakerBlock(Block):
advanced=False,
)
credentials: llm.AICredentials = llm.AICredentialsField()
multiple_tool_calls: bool = SchemaField(
title="Multiple Tool Calls",
default=False,
description="Whether to allow multiple tool calls in a single response.",
advanced=True,
)
sys_prompt: str = SchemaField(
title="System Prompt",
default="Thinking carefully step by step decide which function to call. "
@@ -150,7 +156,7 @@ class SmartDecisionMakerBlock(Block):
"matching the required jsonschema signature, no missing argument is allowed. "
"If you have already completed the task objective, you can end the task "
"by providing the end result of your work as a finish message. "
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
"Function parameters that has no default value and not optional typed has to be provided. ",
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[dict] = SchemaField(
@@ -273,29 +279,18 @@ class SmartDecisionMakerBlock(Block):
"name": SmartDecisionMakerBlock.cleanup(block.name),
"description": block.description,
}
sink_block_input_schema = block.input_schema
properties = {}
required = []
for link in links:
sink_block_input_schema = block.input_schema
description = (
sink_block_input_schema.model_fields[link.sink_name].description
if link.sink_name in sink_block_input_schema.model_fields
and sink_block_input_schema.model_fields[link.sink_name].description
else f"The {link.sink_name} of the tool"
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
properties[sink_name] = sink_block_input_schema.get_field_schema(
link.sink_name
)
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
"type": "string",
"description": description,
}
tool_function["parameters"] = {
"type": "object",
**block.input_schema.jsonschema(),
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
return {"type": "function", "function": tool_function}
@@ -335,25 +330,27 @@ class SmartDecisionMakerBlock(Block):
}
properties = {}
required = []
for link in links:
sink_block_input_schema = sink_node.input_default["input_schema"]
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
link.sink_name, {}
)
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
description = (
sink_block_input_schema["properties"][link.sink_name]["description"]
if "description"
in sink_block_input_schema["properties"][link.sink_name]
sink_block_properties["description"]
if "description" in sink_block_properties
else f"The {link.sink_name} of the tool"
)
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
properties[sink_name] = {
"type": "string",
"description": description,
"default": json.dumps(sink_block_properties.get("default", None)),
}
tool_function["parameters"] = {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
"strict": True,
}
@@ -430,6 +427,7 @@ class SmartDecisionMakerBlock(Block):
**kwargs,
) -> BlockOutput:
tool_functions = self._create_function_signature(node_id)
yield "tool_functions", json.dumps(tool_functions)
input_data.conversation_history = input_data.conversation_history or []
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
@@ -469,6 +467,10 @@ class SmartDecisionMakerBlock(Block):
)
prompt.extend(tool_output)
if input_data.multiple_tool_calls:
input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
else:
input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
values = input_data.prompt_values
if values:
@@ -495,7 +497,7 @@ class SmartDecisionMakerBlock(Block):
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=False,
parallel_tool_calls=True if input_data.multiple_tool_calls else None,
)
if not response.tool_calls:
@@ -506,8 +508,31 @@ class SmartDecisionMakerBlock(Block):
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
for arg_name, arg_value in tool_args.items():
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value
# Find the tool definition to get the expected arguments
tool_def = next(
(
tool
for tool in tool_functions
if tool["function"]["name"] == tool_name
),
None,
)
if (
tool_def
and "function" in tool_def
and "parameters" in tool_def["function"]
):
expected_args = tool_def["function"]["parameters"].get("properties", {})
else:
expected_args = tool_args.keys()
# Yield provided arguments and None for missing ones
for arg_name in expected_args:
if arg_name in tool_args:
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
else:
yield f"tools_^_{tool_name}_~_{arg_name}", None
response.prompt.append(response.raw_response)
yield "conversations", response.prompt

View File

@@ -118,7 +118,10 @@ class BlockSchema(BaseModel):
@classmethod
def validate_data(cls, data: BlockInput) -> str | None:
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
return json.validate_with_jsonschema(
schema=cls.jsonschema(),
data={k: v for k, v in data.items() if v is not None},
)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
@@ -471,7 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
)
async for output_name, output_data in self.run(
self.input_schema(**input_data), **kwargs
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
**kwargs,
):
if output_name == "error":
raise RuntimeError(output_data)

View File

@@ -556,18 +556,18 @@ async def upsert_execution_input(
async def upsert_execution_output(
node_exec_id: str,
output_name: str,
output_data: Any,
output_data: Any | None,
) -> None:
"""
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
"""
await AgentNodeExecutionInputOutput.prisma().create(
data=AgentNodeExecutionInputOutputCreateInput(
name=output_name,
data=Json(output_data),
referencedByOutputExecId=node_exec_id,
)
data = AgentNodeExecutionInputOutputCreateInput(
name=output_name,
referencedByOutputExecId=node_exec_id,
)
if output_data is not None:
data["data"] = Json(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data)
async def update_graph_execution_start_time(

View File

@@ -289,8 +289,9 @@ async def _enqueue_next_nodes(
next_input_name = node_link.sink_name
next_node_id = node_link.sink_id
output_name, _ = output
next_data = parse_execution_output(output, next_output_name)
if next_data is None:
if next_data is None and output_name != next_output_name:
return enqueued_executions
next_node = await db_client.get_node(next_node_id)

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "AgentNodeExecutionInputOutput" ALTER COLUMN "data" DROP NOT NULL;

View File

@@ -80,8 +80,8 @@ enum OnboardingStep {
}
model UserOnboarding {
id String @id @default(uuid())
createdAt DateTime @default(now())
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime? @updatedAt
completedSteps OnboardingStep[] @default([])
@@ -122,7 +122,7 @@ model AgentGraph {
forkedFromId String?
forkedFromVersion Int?
forkedFrom AgentGraph? @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
forkedFrom AgentGraph? @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
forks AgentGraph[] @relation("AgentGraphForks")
Nodes AgentNode[]
@@ -390,7 +390,7 @@ model AgentNodeExecutionInputOutput {
id String @id @default(uuid())
name String
data Json
data Json?
time DateTime @default(now())
// Prisma requires explicit back-references.