mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 07:08:09 -05:00
feat(blocks): Improve SmartDecisionBlock & AIStructuredResponseGeneratorBlock (#10198)
Main issues: * `AIStructuredResponseGeneratorBlock` is not able to produce a list of objects. * `SmartDecisionBlock` is not able to call tools with some optional inputs. ### Changes 🏗️ * Allow persisting `null` / `None` value as execution output. * Provide `multiple_tool_calls` option for `SmartDecisionBlock`. * Provide `list_result` option for `AIStructuredResponseGeneratorBlock` ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Run `SmartDecisionBlock` & `AIStructuredResponseGeneratorBlock`
This commit is contained in:
@@ -163,7 +163,7 @@ class IfInputMatchesBlock(Block):
|
||||
},
|
||||
{
|
||||
"input": 10,
|
||||
"value": None,
|
||||
"value": "None",
|
||||
"yes_value": "Yes",
|
||||
"no_value": "No",
|
||||
},
|
||||
|
||||
@@ -663,6 +663,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
description="Expected format of the response. If provided, the response will be validated against this format. "
|
||||
"The keys should be the expected fields in the response, and the values should be the description of the field.",
|
||||
)
|
||||
list_result: bool = SchemaField(
|
||||
title="List Result",
|
||||
default=False,
|
||||
description="Whether the response should be a list of objects in the expected format.",
|
||||
)
|
||||
model: LlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
default=LlmModel.GPT4O,
|
||||
@@ -702,7 +707,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: dict[str, Any] = SchemaField(
|
||||
response: dict[str, Any] | list[dict[str, Any]] = SchemaField(
|
||||
description="The response object generated by the language model."
|
||||
)
|
||||
prompt: list = SchemaField(description="The prompt sent to the language model.")
|
||||
@@ -793,13 +798,22 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
expected_format = [
|
||||
f'"{k}": "{v}"' for k, v in input_data.expected_format.items()
|
||||
]
|
||||
format_prompt = ",\n ".join(expected_format)
|
||||
if input_data.list_result:
|
||||
format_prompt = (
|
||||
f'"results": [\n {{\n {", ".join(expected_format)}\n }}\n]'
|
||||
)
|
||||
else:
|
||||
format_prompt = "\n ".join(expected_format)
|
||||
|
||||
sys_prompt = trim_prompt(
|
||||
f"""
|
||||
|Reply strictly only in the following JSON format:
|
||||
|{{
|
||||
| {format_prompt}
|
||||
|}}
|
||||
|
|
||||
|Ensure the response is valid JSON. Do not include any additional text outside of the JSON.
|
||||
|If you cannot provide all the keys, provide an empty string for the values you cannot answer.
|
||||
"""
|
||||
)
|
||||
prompt.append({"role": "system", "content": sys_prompt})
|
||||
@@ -807,17 +821,16 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
if input_data.prompt:
|
||||
prompt.append({"role": "user", "content": input_data.prompt})
|
||||
|
||||
def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
|
||||
def validate_response(parsed: object) -> str | None:
|
||||
try:
|
||||
parsed = json.loads(resp)
|
||||
if not isinstance(parsed, dict):
|
||||
return {}, f"Expected a dictionary, but got {type(parsed)}"
|
||||
return f"Expected a dictionary, but got {type(parsed)}"
|
||||
miss_keys = set(input_data.expected_format.keys()) - set(parsed.keys())
|
||||
if miss_keys:
|
||||
return parsed, f"Missing keys: {miss_keys}"
|
||||
return parsed, None
|
||||
return f"Missing keys: {miss_keys}"
|
||||
return None
|
||||
except JSONDecodeError as e:
|
||||
return {}, f"JSON decode error: {e}"
|
||||
return f"JSON decode error: {e}"
|
||||
|
||||
logger.info(f"LLM request: {prompt}")
|
||||
retry_prompt = ""
|
||||
@@ -843,18 +856,29 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
|
||||
|
||||
if input_data.expected_format:
|
||||
parsed_dict, parsed_error = parse_response(response_text)
|
||||
if not parsed_error:
|
||||
yield "response", {
|
||||
k: (
|
||||
json.loads(v)
|
||||
if isinstance(v, str)
|
||||
and v.startswith("[")
|
||||
and v.endswith("]")
|
||||
else (", ".join(v) if isinstance(v, list) else v)
|
||||
|
||||
response_obj = json.loads(response_text)
|
||||
|
||||
if input_data.list_result and isinstance(response_obj, dict):
|
||||
if "results" in response_obj:
|
||||
response_obj = response_obj.get("results", [])
|
||||
elif len(response_obj) == 1:
|
||||
response_obj = list(response_obj.values())
|
||||
|
||||
response_error = "\n".join(
|
||||
[
|
||||
validation_error
|
||||
for response_item in (
|
||||
response_obj
|
||||
if isinstance(response_obj, list)
|
||||
else [response_obj]
|
||||
)
|
||||
for k, v in parsed_dict.items()
|
||||
}
|
||||
if (validation_error := validate_response(response_item))
|
||||
]
|
||||
)
|
||||
|
||||
if not response_error:
|
||||
yield "response", response_obj
|
||||
yield "prompt", self.prompt
|
||||
return
|
||||
else:
|
||||
@@ -871,7 +895,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
|
||||
|
|
||||
|And this is the error:
|
||||
|--
|
||||
|{parsed_error}
|
||||
|{response_error}
|
||||
|--
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -142,6 +142,12 @@ class SmartDecisionMakerBlock(Block):
|
||||
advanced=False,
|
||||
)
|
||||
credentials: llm.AICredentials = llm.AICredentialsField()
|
||||
multiple_tool_calls: bool = SchemaField(
|
||||
title="Multiple Tool Calls",
|
||||
default=False,
|
||||
description="Whether to allow multiple tool calls in a single response.",
|
||||
advanced=True,
|
||||
)
|
||||
sys_prompt: str = SchemaField(
|
||||
title="System Prompt",
|
||||
default="Thinking carefully step by step decide which function to call. "
|
||||
@@ -150,7 +156,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
"matching the required jsonschema signature, no missing argument is allowed. "
|
||||
"If you have already completed the task objective, you can end the task "
|
||||
"by providing the end result of your work as a finish message. "
|
||||
"Only provide EXACTLY one function call, multiple tool calls is strictly prohibited.",
|
||||
"Function parameters that has no default value and not optional typed has to be provided. ",
|
||||
description="The system prompt to provide additional context to the model.",
|
||||
)
|
||||
conversation_history: list[dict] = SchemaField(
|
||||
@@ -273,29 +279,18 @@ class SmartDecisionMakerBlock(Block):
|
||||
"name": SmartDecisionMakerBlock.cleanup(block.name),
|
||||
"description": block.description,
|
||||
}
|
||||
|
||||
sink_block_input_schema = block.input_schema
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = block.input_schema
|
||||
description = (
|
||||
sink_block_input_schema.model_fields[link.sink_name].description
|
||||
if link.sink_name in sink_block_input_schema.model_fields
|
||||
and sink_block_input_schema.model_fields[link.sink_name].description
|
||||
else f"The {link.sink_name} of the tool"
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
properties[sink_name] = sink_block_input_schema.get_field_schema(
|
||||
link.sink_name
|
||||
)
|
||||
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
**block.input_schema.jsonschema(),
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
@@ -335,25 +330,27 @@ class SmartDecisionMakerBlock(Block):
|
||||
}
|
||||
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
for link in links:
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
||||
link.sink_name, {}
|
||||
)
|
||||
sink_name = SmartDecisionMakerBlock.cleanup(link.sink_name)
|
||||
description = (
|
||||
sink_block_input_schema["properties"][link.sink_name]["description"]
|
||||
if "description"
|
||||
in sink_block_input_schema["properties"][link.sink_name]
|
||||
sink_block_properties["description"]
|
||||
if "description" in sink_block_properties
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[SmartDecisionMakerBlock.cleanup(link.sink_name)] = {
|
||||
properties[sink_name] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
"default": json.dumps(sink_block_properties.get("default", None)),
|
||||
}
|
||||
|
||||
tool_function["parameters"] = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
"additionalProperties": False,
|
||||
"strict": True,
|
||||
}
|
||||
@@ -430,6 +427,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
tool_functions = self._create_function_signature(node_id)
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
input_data.conversation_history = input_data.conversation_history or []
|
||||
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
|
||||
@@ -469,6 +467,10 @@ class SmartDecisionMakerBlock(Block):
|
||||
)
|
||||
|
||||
prompt.extend(tool_output)
|
||||
if input_data.multiple_tool_calls:
|
||||
input_data.sys_prompt += "\nYou can call a tool (different tools) multiple times in a single response."
|
||||
else:
|
||||
input_data.sys_prompt += "\nOnly provide EXACTLY one function call, multiple tool calls is strictly prohibited."
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
@@ -495,7 +497,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
max_tokens=input_data.max_tokens,
|
||||
tools=tool_functions,
|
||||
ollama_host=input_data.ollama_host,
|
||||
parallel_tool_calls=False,
|
||||
parallel_tool_calls=True if input_data.multiple_tool_calls else None,
|
||||
)
|
||||
|
||||
if not response.tool_calls:
|
||||
@@ -506,8 +508,31 @@ class SmartDecisionMakerBlock(Block):
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
|
||||
for arg_name, arg_value in tool_args.items():
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", arg_value
|
||||
# Find the tool definition to get the expected arguments
|
||||
tool_def = next(
|
||||
(
|
||||
tool
|
||||
for tool in tool_functions
|
||||
if tool["function"]["name"] == tool_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
and "parameters" in tool_def["function"]
|
||||
):
|
||||
expected_args = tool_def["function"]["parameters"].get("properties", {})
|
||||
else:
|
||||
expected_args = tool_args.keys()
|
||||
|
||||
# Yield provided arguments and None for missing ones
|
||||
for arg_name in expected_args:
|
||||
if arg_name in tool_args:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", tool_args[arg_name]
|
||||
else:
|
||||
yield f"tools_^_{tool_name}_~_{arg_name}", None
|
||||
|
||||
response.prompt.append(response.raw_response)
|
||||
yield "conversations", response.prompt
|
||||
|
||||
@@ -118,7 +118,10 @@ class BlockSchema(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def validate_data(cls, data: BlockInput) -> str | None:
|
||||
return json.validate_with_jsonschema(schema=cls.jsonschema(), data=data)
|
||||
return json.validate_with_jsonschema(
|
||||
schema=cls.jsonschema(),
|
||||
data={k: v for k, v in data.items() if v is not None},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
@@ -471,7 +474,8 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
)
|
||||
|
||||
async for output_name, output_data in self.run(
|
||||
self.input_schema(**input_data), **kwargs
|
||||
self.input_schema(**{k: v for k, v in input_data.items() if v is not None}),
|
||||
**kwargs,
|
||||
):
|
||||
if output_name == "error":
|
||||
raise RuntimeError(output_data)
|
||||
|
||||
@@ -556,18 +556,18 @@ async def upsert_execution_input(
|
||||
async def upsert_execution_output(
|
||||
node_exec_id: str,
|
||||
output_name: str,
|
||||
output_data: Any,
|
||||
output_data: Any | None,
|
||||
) -> None:
|
||||
"""
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
await AgentNodeExecutionInputOutput.prisma().create(
|
||||
data=AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
data=Json(output_data),
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
)
|
||||
data = AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
)
|
||||
if output_data is not None:
|
||||
data["data"] = Json(output_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(
|
||||
|
||||
@@ -289,8 +289,9 @@ async def _enqueue_next_nodes(
|
||||
next_input_name = node_link.sink_name
|
||||
next_node_id = node_link.sink_id
|
||||
|
||||
output_name, _ = output
|
||||
next_data = parse_execution_output(output, next_output_name)
|
||||
if next_data is None:
|
||||
if next_data is None and output_name != next_output_name:
|
||||
return enqueued_executions
|
||||
next_node = await db_client.get_node(next_node_id)
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentNodeExecutionInputOutput" ALTER COLUMN "data" DROP NOT NULL;
|
||||
@@ -80,8 +80,8 @@ enum OnboardingStep {
|
||||
}
|
||||
|
||||
model UserOnboarding {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime? @updatedAt
|
||||
|
||||
completedSteps OnboardingStep[] @default([])
|
||||
@@ -122,7 +122,7 @@ model AgentGraph {
|
||||
|
||||
forkedFromId String?
|
||||
forkedFromVersion Int?
|
||||
forkedFrom AgentGraph? @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
|
||||
forkedFrom AgentGraph? @relation("AgentGraphForks", fields: [forkedFromId, forkedFromVersion], references: [id, version])
|
||||
forks AgentGraph[] @relation("AgentGraphForks")
|
||||
|
||||
Nodes AgentNode[]
|
||||
@@ -390,7 +390,7 @@ model AgentNodeExecutionInputOutput {
|
||||
id String @id @default(uuid())
|
||||
|
||||
name String
|
||||
data Json
|
||||
data Json?
|
||||
time DateTime @default(now())
|
||||
|
||||
// Prisma requires explicit back-references.
|
||||
|
||||
Reference in New Issue
Block a user