Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-09 15:17:59 -05:00)
fix(blocks/ai): Make AI List Generator block more reliable (#11317)
- Resolves #11305

### Changes 🏗️

Make `AIListGeneratorBlock` more reliable:
- Leverage `AIStructuredResponseGenerator`'s robust prompt/retry/validate logic
- Use JSON format instead of Python list format
- Add `force_json_output` toggle
- Fix output instructions in prompt (only string values allowed)

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Works without `force_json_output`
  - [x] Works with `force_json_output`
  - [x] Retry mechanism works as intended
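The gist of the change, as a minimal runnable sketch (not part of the commit): the block used to ask the model for a Python-list literal and parse it with `ast.literal_eval`; it now requests a JSON object with a `list` field, whose prompting, retries, and validation are delegated to `AIStructuredResponseGeneratorBlock`. The raw strings below are illustrative stand-ins for model output.

```python
import ast
import json

# Old contract: the model returns a Python list literal as plain text.
# Any stray prose or code fence around it breaks ast.literal_eval.
raw_old = "['Zylora Prime', 'Kharon-9', 'Vortexia']"
old_list = ast.literal_eval(raw_old)

# New contract: the model returns a JSON object with a "list" field;
# shape validation and retries are handled upstream by the
# structured-response generator.
raw_new = '{"list": ["Zylora Prime", "Kharon-9", "Vortexia"]}'
new_list = json.loads(raw_new)["list"]

assert old_list == new_list
```

When `force_json_output` is enabled, the diff additionally prepends an optional `reasoning` field to the expected format, so the model can still reason inside the JSON object before emitting the list.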
committed by GitHub
parent 27d886f05c
commit 749be06599
@@ -1,6 +1,5 @@
 # This file contains a lot of prompt block strings that would trigger "line too long"
 # flake8: noqa: E501
-import ast
 import logging
 import re
 import secrets
@@ -1633,6 +1632,17 @@ class AIListGeneratorBlock(AIBlockBase):
             ge=1,
             le=5,
         )
+        force_json_output: bool = SchemaField(
+            title="Restrict LLM to pure JSON output",
+            default=False,
+            description=(
+                "Whether to force the LLM to produce a JSON-only response. "
+                "This can increase the block's reliability, "
+                "but may also reduce the quality of the response "
+                "because it prohibits the LLM from reasoning "
+                "before providing its JSON response."
+            ),
+        )
         max_tokens: int | None = SchemaField(
             advanced=True,
             default=None,
@@ -1645,7 +1655,7 @@ class AIListGeneratorBlock(AIBlockBase):
         )

     class Output(BlockSchemaOutput):
-        generated_list: List[str] = SchemaField(description="The generated list.")
+        generated_list: list[str] = SchemaField(description="The generated list.")
         list_item: str = SchemaField(
             description="Each individual item in the list.",
         )
@@ -1654,7 +1664,7 @@ class AIListGeneratorBlock(AIBlockBase):
     def __init__(self):
         super().__init__(
             id="9c0b0450-d199-458b-a731-072189dd6593",
-            description="Generate a Python list based on the given prompt using a Large Language Model (LLM).",
+            description="Generate a list of values based on the given prompt using a Large Language Model (LLM).",
             categories={BlockCategory.AI, BlockCategory.TEXT},
             input_schema=AIListGeneratorBlock.Input,
             output_schema=AIListGeneratorBlock.Output,
@@ -1671,6 +1681,7 @@ class AIListGeneratorBlock(AIBlockBase):
                 "model": LlmModel.GPT4O,
                 "credentials": TEST_CREDENTIALS_INPUT,
                 "max_retries": 3,
+                "force_json_output": False,
             },
             test_credentials=TEST_CREDENTIALS,
             test_output=[
@@ -1687,7 +1698,13 @@ class AIListGeneratorBlock(AIBlockBase):
             ],
             test_mock={
                 "llm_call": lambda input_data, credentials: {
-                    "response": "['Zylora Prime', 'Kharon-9', 'Vortexia', 'Oceara', 'Draknos']"
+                    "list": [
+                        "Zylora Prime",
+                        "Kharon-9",
+                        "Vortexia",
+                        "Oceara",
+                        "Draknos",
+                    ]
                 },
             },
         )
@@ -1696,7 +1713,7 @@ class AIListGeneratorBlock(AIBlockBase):
         self,
         input_data: AIStructuredResponseGeneratorBlock.Input,
         credentials: APIKeyCredentials,
-    ) -> dict[str, str]:
+    ) -> dict[str, Any]:
         llm_block = AIStructuredResponseGeneratorBlock()
         response = await llm_block.run_once(
             input_data, "response", credentials=credentials
@@ -1704,72 +1721,23 @@ class AIListGeneratorBlock(AIBlockBase):
         self.merge_llm_stats(llm_block)
         return response

-    @staticmethod
-    def string_to_list(string):
-        """
-        Converts a string representation of a list into an actual Python list object.
-        """
-        logger.debug(f"Converting string to list. Input string: {string}")
-        try:
-            # Use ast.literal_eval to safely evaluate the string
-            python_list = ast.literal_eval(string)
-            if isinstance(python_list, list):
-                logger.debug(f"Successfully converted string to list: {python_list}")
-                return python_list
-            else:
-                logger.error(f"The provided string '{string}' is not a valid list")
-                raise ValueError(f"The provided string '{string}' is not a valid list.")
-        except (SyntaxError, ValueError) as e:
-            logger.error(f"Failed to convert string to list: {e}")
-            raise ValueError("Invalid list format. Could not convert to list.")
-
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
-        logger.debug(f"Starting AIListGeneratorBlock.run with input data: {input_data}")
-
-        # Check for API key
-        api_key_check = credentials.api_key.get_secret_value()
-        if not api_key_check:
-            raise ValueError("No LLM API key provided.")
+        # Create a proper expected format for the structured response generator
+        expected_format = {
+            "list": "A JSON array containing the generated string values"
+        }
+        if input_data.force_json_output:
+            # Add reasoning field for better performance
+            expected_format = {
+                "reasoning": "... (optional)",
+                **expected_format,
+            }

-        # Prepare the system prompt
-        sys_prompt = """You are a Python list generator. Your task is to generate a Python list based on the user's prompt.
-        |Respond ONLY with a valid python list.
-        |The list can contain strings, numbers, or nested lists as appropriate.
-        |Do not include any explanations or additional text.
-        |
-        |Valid Example string formats:
-        |
-        |Example 1:
-        |```
-        |['1', '2', '3', '4']
-        |```
-        |
-        |Example 2:
-        |```
-        |[['1', '2'], ['3', '4'], ['5', '6']]
-        |```
-        |
-        |Example 3:
-        |```
-        |['1', ['2', '3'], ['4', ['5', '6']]]
-        |```
-        |
-        |Example 4:
-        |```
-        |['a', 'b', 'c']
-        |```
-        |
-        |Example 5:
-        |```
-        |['1', '2.5', 'string', 'True', ['False', 'None']]
-        |```
-        |
-        |Do not include any explanations or additional text, just respond with the list in the format specified above.
-        |Do not include code fences or any other formatting, just the raw list.
-        """
-        # If a focus is provided, add it to the prompt
+        # Build the prompt
         if input_data.focus:
             prompt = f"Generate a list with the following focus:\n<focus>\n\n{input_data.focus}</focus>"
         else:
@@ -1777,7 +1745,7 @@ class AIListGeneratorBlock(AIBlockBase):
             if input_data.source_data:
                 prompt = "Extract the main focus of the source data to a list.\ni.e if the source data is a news website, the focus would be the news stories rather than the social links in the footer."
             else:
-                # No focus or source data provided, generat a random list
+                # No focus or source data provided, generate a random list
                 prompt = "Generate a random list."

         # If the source data is provided, add it to the prompt
@@ -1787,63 +1755,56 @@ class AIListGeneratorBlock(AIBlockBase):
         else:
             prompt += "\n\nInvent the data to generate the list from."

-        for attempt in range(input_data.max_retries):
-            try:
-                logger.debug("Calling LLM")
-                llm_response = await self.llm_call(
-                    AIStructuredResponseGeneratorBlock.Input(
-                        sys_prompt=sys_prompt,
-                        prompt=prompt,
-                        credentials=input_data.credentials,
-                        model=input_data.model,
-                        expected_format={},  # Do not use structured response
-                        ollama_host=input_data.ollama_host,
-                    ),
-                    credentials=credentials,
-                )
+        # Use the structured response generator to handle all the complexity
+        response_obj = await self.llm_call(
+            AIStructuredResponseGeneratorBlock.Input(
+                sys_prompt=self.SYSTEM_PROMPT,
+                prompt=prompt,
+                credentials=input_data.credentials,
+                model=input_data.model,
+                expected_format=expected_format,
+                force_json_output=input_data.force_json_output,
+                retry=input_data.max_retries,
+                max_tokens=input_data.max_tokens,
+                ollama_host=input_data.ollama_host,
+            ),
+            credentials=credentials,
+        )
+        logger.debug(f"Response object: {response_obj}")

-                logger.debug(f"LLM response: {llm_response}")
+        # Extract the list from the response object
+        if isinstance(response_obj, dict) and "list" in response_obj:
+            parsed_list = response_obj["list"]
+        else:
+            # Fallback - treat the whole response as the list
+            parsed_list = response_obj

-                # Extract Response string
-                response_string = llm_response["response"]
-                logger.debug(f"Response string: {response_string}")
+        # Validate that we got a list
+        if not isinstance(parsed_list, list):
+            raise ValueError(
+                f"Expected a list, but got {type(parsed_list).__name__}: {parsed_list}"
+            )

-                # Convert the string to a Python list
-                logger.debug("Converting string to Python list")
-                parsed_list = self.string_to_list(response_string)
-                logger.debug(f"Parsed list: {parsed_list}")
+        logger.debug(f"Parsed list: {parsed_list}")

-                # If we reach here, we have a valid Python list
-                logger.debug("Successfully generated a valid Python list")
-                yield "generated_list", parsed_list
-                yield "prompt", self.prompt
+        # Yield the results
+        yield "generated_list", parsed_list
+        yield "prompt", self.prompt

-                # Yield each item in the list
-                for item in parsed_list:
-                    yield "list_item", item
-                return
+        # Yield each item in the list
+        for item in parsed_list:
+            yield "list_item", item

-            except Exception as e:
-                logger.error(f"Error in attempt {attempt + 1}: {str(e)}")
-                if attempt == input_data.max_retries - 1:
-                    logger.error(
-                        f"Failed to generate a valid Python list after {input_data.max_retries} attempts"
-                    )
-                    raise RuntimeError(
-                        f"Failed to generate a valid Python list after {input_data.max_retries} attempts. Last error: {str(e)}"
-                    )
-                else:
-                    # Add a retry prompt
-                    logger.debug("Preparing retry prompt")
-                    prompt = f"""
-                    The previous attempt failed due to `{e}`
-                    Generate a valid Python list based on the original prompt.
-                    Remember to respond ONLY with a valid Python list as per the format specified earlier.
-                    Original prompt:
-                    ```{prompt}```
-
-                    Respond only with the list in the format specified with no commentary or apologies.
-                    """
-                    logger.debug(f"Retry prompt: {prompt}")
-
-        logger.debug("AIListGeneratorBlock.run completed")
+    SYSTEM_PROMPT = trim_prompt(
+        """
+        |You are a JSON array generator. Your task is to generate a JSON array of string values based on the user's prompt.
+        |
+        |The 'list' field should contain a JSON array with the generated string values.
+        |The array can contain ONLY strings.
+        |
+        |Valid JSON array formats include:
+        |• ["string1", "string2", "string3"]
+        |
+        |Ensure you provide a proper JSON array with only string values in the 'list' field.
+        """
+    )
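The `|` margins in `SYSTEM_PROMPT` follow the file's `trim_prompt` convention. A minimal sketch of the presumed behavior (the real helper lives in the same module and may differ in details): everything up to and including each line's `|` marker is stripped, so multi-line prompts can be indented naturally in source.

```python
def trim_prompt(text: str) -> str:
    # Presumed behavior of the module's helper: drop surrounding blank
    # lines and strip everything up to and including the `|` margin marker.
    lines = []
    for line in text.strip().splitlines():
        marker = line.find("|")
        lines.append(line[marker + 1 :] if marker != -1 else line)
    return "\n".join(lines)

print(trim_prompt("""
    |You are a JSON array generator.
    |The array can contain ONLY strings.
"""))
# You are a JSON array generator.
# The array can contain ONLY strings.
```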
@@ -6,12 +6,12 @@ from backend.data.block import Block, get_blocks
 from backend.util.test import execute_block_test


-@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
+@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b().name)
 async def test_available_blocks(block: Type[Block]):
     await execute_block_test(block())


-@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
+@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b().name)
 async def test_block_ids_valid(block: Type[Block]):
     # add the tests here to check they are uuid4
     import uuid
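On the `ids=lambda b: b.name` to `b().name` fix: `get_blocks()` evidently yields block classes rather than instances, and a usable `name` only exists once a block is instantiated. A minimal sketch of the distinction, assuming `name` resolves per instance (the stub below is hypothetical, not the real base class):

```python
class Block:
    @property
    def name(self) -> str:
        # Resolved per instance; on the class itself, `Block.name` is the
        # property object, not a usable test-ID string.
        return type(self).__name__

class AIListGeneratorBlock(Block):
    pass

assert AIListGeneratorBlock().name == "AIListGeneratorBlock"  # instance: OK
assert isinstance(AIListGeneratorBlock.name, property)        # class: not a str
```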
@@ -365,37 +365,22 @@ class TestLLMStatsTracking:
         assert outputs["response"] == "AI response to conversation"

     @pytest.mark.asyncio
-    async def test_ai_list_generator_with_retries(self):
-        """Test that AIListGeneratorBlock correctly tracks stats with retries."""
+    async def test_ai_list_generator_basic_functionality(self):
+        """Test that AIListGeneratorBlock correctly works with structured responses."""
         import backend.blocks.llm as llm

         block = llm.AIListGeneratorBlock()

-        # Counter to track calls
-        call_count = 0
-
+        # Mock the llm_call to return a structured response
         async def mock_llm_call(input_data, credentials):
-            nonlocal call_count
-            call_count += 1
-
-            # Update stats
-            if hasattr(block, "execution_stats") and block.execution_stats:
-                block.execution_stats.input_token_count += 40
-                block.execution_stats.output_token_count += 20
-                block.execution_stats.llm_call_count += 1
-            else:
-                block.execution_stats = NodeExecutionStats(
-                    input_token_count=40,
-                    output_token_count=20,
-                    llm_call_count=1,
-                )
-
-            if call_count == 1:
-                # First call returns invalid format
-                return {"response": "not a valid list"}
-            else:
-                # Second call returns valid list
-                return {"response": "['item1', 'item2', 'item3']"}
+            # Update stats to simulate LLM call
+            block.execution_stats = NodeExecutionStats(
+                input_token_count=50,
+                output_token_count=30,
+                llm_call_count=1,
+            )
+            # Return a structured response with the expected format
+            return {"list": ["item1", "item2", "item3"]}

         block.llm_call = mock_llm_call  # type: ignore
@@ -413,14 +398,20 @@ class TestLLMStatsTracking:
         ):
             outputs[output_name] = output_data

-        # Check stats - should have 2 calls
-        assert call_count == 2
-        assert block.execution_stats.input_token_count == 80  # 40 * 2
-        assert block.execution_stats.output_token_count == 40  # 20 * 2
-        assert block.execution_stats.llm_call_count == 2
+        # Check stats
+        assert block.execution_stats.input_token_count == 50
+        assert block.execution_stats.output_token_count == 30
+        assert block.execution_stats.llm_call_count == 1

         # Check output
         assert outputs["generated_list"] == ["item1", "item2", "item3"]
+        # Check that individual items were yielded
+        # Note: outputs dict will only contain the last value for each key
+        # So we need to check that the list_item output exists
+        assert "list_item" in outputs
+        # The list_item output should be the last item in the list
+        assert outputs["list_item"] == "item3"
+        assert "prompt" in outputs

     @pytest.mark.asyncio
     async def test_merge_llm_stats(self):