Correct and clean up JSON handling (#4655)

* Correct and clean up JSON handling

* Use ast for message history too

* Lint

* Add comments explaining why we use literal_eval

* Add descriptions to llm_response_format schema

* Parse responses in code blocks

* Be more careful when parsing in code blocks

* Lint
This commit is contained in:
Erik Peterson
2023-06-13 09:54:50 -07:00
committed by GitHub
parent 7bf39cbb72
commit 07d9b584f7
15 changed files with 135 additions and 730 deletions

View File

@@ -1,3 +1,4 @@
import json
import signal
import sys
from datetime import datetime
@@ -7,8 +8,7 @@ from colorama import Fore, Style
from autogpt.commands.command import CommandRegistry
from autogpt.config import Config
from autogpt.config.ai_config import AIConfig
from autogpt.json_utils.json_fix_llm import fix_json_using_multiple_techniques
from autogpt.json_utils.utilities import LLM_DEFAULT_RESPONSE_FORMAT, validate_json
from autogpt.json_utils.utilities import extract_json_from_response, validate_json
from autogpt.llm.base import ChatSequence
from autogpt.llm.chat import chat_with_ai, create_chat_completion
from autogpt.llm.providers.openai import OPEN_AI_CHAT_MODELS
@@ -144,7 +144,13 @@ class Agent:
self.config.fast_llm_model,
)
assistant_reply_json = fix_json_using_multiple_techniques(assistant_reply)
try:
assistant_reply_json = extract_json_from_response(assistant_reply)
validate_json(assistant_reply_json)
except json.JSONDecodeError as e:
logger.error(f"Exception while validating assistant reply JSON: {e}")
assistant_reply_json = {}
for plugin in self.config.plugins:
if not plugin.can_handle_post_planning():
continue
@@ -152,7 +158,6 @@ class Agent:
# Print Assistant thoughts
if assistant_reply_json != {}:
validate_json(assistant_reply_json, LLM_DEFAULT_RESPONSE_FORMAT)
# Get command name and arguments
try:
print_assistant_thoughts(

View File

@@ -4,11 +4,11 @@ import subprocess
from pathlib import Path
import docker
from confection import Config
from docker.errors import ImageNotFound
from autogpt.agent.agent import Agent
from autogpt.commands.command import command
from autogpt.config import Config
from autogpt.logs import logger
from autogpt.setup import CFG
from autogpt.workspace.workspace import Workspace

View File

@@ -1,121 +0,0 @@
"""This module contains functions to fix JSON strings using general programmatic approaches, suitable for addressing
common JSON formatting issues."""
from __future__ import annotations
import contextlib
import json
import re
from typing import Optional
from autogpt.config import Config
from autogpt.json_utils.utilities import extract_char_position
from autogpt.logs import logger
CFG = Config()
def fix_invalid_escape(json_to_load: str, error_message: str) -> str:
    """Strip invalid escape sequences from a JSON string.

    Args:
        json_to_load (str): The JSON string.
        error_message (str): The error message from the JSONDecodeError
            exception.

    Returns:
        str: The JSON string with invalid escape sequences removed.
    """
    # Keep deleting the offending backslash until the string parses or the
    # remaining error is no longer an escape problem.
    while error_message.startswith("Invalid \\escape"):
        bad_pos = extract_char_position(error_message)
        json_to_load = json_to_load[:bad_pos] + json_to_load[bad_pos + 1 :]
        try:
            json.loads(json_to_load)
        except json.JSONDecodeError as e:
            logger.debug("json loads error - fix invalid escape", e)
            error_message = str(e)
        else:
            return json_to_load
    return json_to_load
def balance_braces(json_string: str) -> Optional[str]:
    """Balance the curly braces in a JSON string.

    Appends missing closing braces and removes surplus trailing closing
    braces so the open/close counts match.

    Args:
        json_string (str): The JSON string.

    Returns:
        Optional[str]: The brace-balanced JSON string, or None when the
        balanced result still fails to parse.
    """
    open_braces_count = json_string.count("{")
    close_braces_count = json_string.count("}")

    while open_braces_count > close_braces_count:
        json_string += "}"
        close_braces_count += 1

    # Remove surplus closers one at a time. The original used
    # json_string.rstrip("}"), which strips EVERY trailing brace in a single
    # pass while only decrementing the counter by one, so inputs like
    # '{"a": 1}}' lost their legitimate closing brace too.
    while close_braces_count > open_braces_count:
        if json_string.endswith("}"):
            json_string = json_string[:-1]
        close_braces_count -= 1

    with contextlib.suppress(json.JSONDecodeError):
        json.loads(json_string)
        return json_string
def add_quotes_to_property_names(json_string: str) -> str:
    """Wrap bare (unquoted) property names in a JSON string with double quotes.

    Args:
        json_string (str): The JSON string.

    Returns:
        str: The JSON string with quotes added to property names.

    Raises:
        json.JSONDecodeError: If the corrected string still fails to parse.
    """

    def replace_func(match: re.Match) -> str:
        return f'"{match[1]}":'

    # Only matches word characters immediately followed by a colon, so
    # already-quoted names ("name":) are left untouched.
    property_name_pattern = re.compile(r"(\w+):")
    corrected_json_string = property_name_pattern.sub(replace_func, json_string)

    # Validate the result; a parse failure propagates to the caller unchanged
    # (the original wrapped this in a try/except that merely re-raised the
    # same exception, which was redundant).
    json.loads(corrected_json_string)
    return corrected_json_string
def correct_json(json_to_load: str) -> str:
    """Attempt to repair common JSON formatting errors in a string.

    Args:
        json_to_load (str): The JSON string.

    Returns:
        str: The (possibly repaired) JSON string; returned unchanged when no
        repair succeeds.
    """
    try:
        logger.debug("json", json_to_load)
        json.loads(json_to_load)
    except json.JSONDecodeError as e:
        logger.debug("json loads error", e)
        error_message = str(e)
        if error_message.startswith("Invalid \\escape"):
            json_to_load = fix_invalid_escape(json_to_load, error_message)
        if error_message.startswith(
            "Expecting property name enclosed in double quotes"
        ):
            json_to_load = add_quotes_to_property_names(json_to_load)
            try:
                json.loads(json_to_load)
            except json.JSONDecodeError as e:
                logger.debug("json loads error - add quotes", e)
                error_message = str(e)
            else:
                return json_to_load
        # Last resort: balance the braces and see if that made it parseable.
        if balanced_str := balance_braces(json_to_load):
            return balanced_str
        return json_to_load
    else:
        # Already valid: nothing to correct.
        return json_to_load

View File

@@ -1,239 +0,0 @@
"""This module contains functions to fix JSON strings generated by LLM models, such as ChatGPT, using the assistance
of the ChatGPT API or LLM models."""
from __future__ import annotations
import contextlib
import json
from typing import Any, Dict
from colorama import Fore
from regex import regex
from autogpt.config import Config
from autogpt.json_utils.json_fix_general import correct_json
from autogpt.llm.utils import call_ai_function
from autogpt.logs import logger
from autogpt.speech import say_text
JSON_SCHEMA = """
{
"command": {
"name": "command name",
"args": {
"arg name": "value"
}
},
"thoughts":
{
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user"
}
}
"""
CFG = Config()
def auto_fix_json(json_string: str, schema: str) -> str:
    """Ask the fast LLM to rewrite a JSON string so it parses and matches a schema.

    Args:
        json_string (str): The JSON string to fix.
        schema (str): The schema the fixed JSON must comply with.

    Returns:
        str: The repaired JSON string, or "failed" when the model's output
        still does not parse.
    """
    fn_signature = "def fix_json(json_string: str, schema:str=None) -> str:"
    # NOTE: the call arguments capture the raw input; only the logged copy
    # below gets wrapped in a code fence.
    fn_args = [f"'''{json_string}'''", f"'''{schema}'''"]
    fn_description = (
        "This function takes a JSON string and ensures that it"
        " is parseable and fully compliant with the provided schema. If an object"
        " or field specified in the schema isn't contained within the correct JSON,"
        " it is omitted. The function also escapes any double quotes within JSON"
        " string values to ensure that they are valid. If the JSON string contains"
        " any None or NaN values, they are replaced with null before being parsed."
    )

    # If it doesn't already start with a "`", add a fence for readability:
    if not json_string.startswith("`"):
        json_string = "```json\n" + json_string + "\n```"

    result_string = call_ai_function(
        fn_signature, fn_args, fn_description, model=CFG.fast_llm_model
    )
    logger.debug("------------ JSON FIX ATTEMPT ---------------")
    logger.debug(f"Original JSON: {json_string}")
    logger.debug("-----------")
    logger.debug(f"Fixed JSON: {result_string}")
    logger.debug("----------- END OF FIX ATTEMPT ----------------")

    try:
        json.loads(result_string)  # validity check only; return the raw text
    except json.JSONDecodeError:
        return "failed"
    return result_string
def fix_json_using_multiple_techniques(assistant_reply: str) -> Dict[Any, Any]:
    """Parse an assistant reply into JSON, applying several repair techniques.

    Args:
        assistant_reply (str): The raw reply text from the model.

    Returns:
        Dict[Any, Any]: The parsed JSON object, or {} when nothing worked.
    """
    assistant_reply = assistant_reply.strip()
    # Unwrap a ```json fenced code block if present.
    if assistant_reply.startswith("```json"):
        assistant_reply = assistant_reply[7:]
    if assistant_reply.endswith("```"):
        assistant_reply = assistant_reply[:-3]

    with contextlib.suppress(json.JSONDecodeError):
        return json.loads(assistant_reply)

    # Some replies carry a bare "json " marker before the payload.
    if assistant_reply.startswith("json "):
        assistant_reply = assistant_reply[5:]
    assistant_reply = assistant_reply.strip()

    with contextlib.suppress(json.JSONDecodeError):
        return json.loads(assistant_reply)

    # Fall back to the programmatic fixers.
    assistant_reply_json = fix_and_parse_json(assistant_reply)
    logger.debug("Assistant reply JSON: %s", str(assistant_reply_json))
    if assistant_reply_json == {}:
        assistant_reply_json = attempt_to_fix_json_by_finding_outermost_brackets(
            assistant_reply
        )
        logger.debug("Assistant reply JSON 2: %s", str(assistant_reply_json))

    if assistant_reply_json != {}:
        return assistant_reply_json

    logger.error(
        "Error: The following AI output couldn't be converted to a JSON:\n",
        assistant_reply,
    )
    if CFG.speak_mode:
        say_text("I have received an invalid JSON response from the OpenAI API.")
    return {}
def fix_and_parse_json(
    json_to_load: str, try_to_fix_with_gpt: bool = True
) -> Dict[Any, Any]:
    """Fix and parse a JSON string.

    Args:
        json_to_load (str): The JSON string.
        try_to_fix_with_gpt (bool, optional): Whether to fall back to an
            LLM-based fix when programmatic repair fails. Defaults to True.

    Returns:
        Dict[Any, Any]: The parsed JSON.
    """
    with contextlib.suppress(json.JSONDecodeError):
        # Stray tab characters are a frequent model artifact.
        json_to_load = json_to_load.replace("\t", "")
        return json.loads(json_to_load)

    with contextlib.suppress(json.JSONDecodeError):
        json_to_load = correct_json(json_to_load)
        return json.loads(json_to_load)

    # Models sometimes emit prose around the payload, e.g.
    #   I'm sorry, I don't understand. {"text": "...", "confidence": 0.0}
    # so try parsing just the outermost {...} span.
    try:
        start = json_to_load.index("{")
        candidate = json_to_load[start:]
        end = candidate.rindex("}")
        return json.loads(candidate[: end + 1])
    except (json.JSONDecodeError, ValueError) as e:
        return try_ai_fix(try_to_fix_with_gpt, e, json_to_load)
def try_ai_fix(
    try_to_fix_with_gpt: bool, exception: Exception, json_to_load: str
) -> Dict[Any, Any]:
    """Fall back to the AI to fix an unparseable JSON string.

    Args:
        try_to_fix_with_gpt (bool): Whether the AI fallback is allowed.
        exception (Exception): The parse error that triggered the fallback.
        json_to_load (str): The JSON string to repair.

    Raises:
        exception: Re-raised when try_to_fix_with_gpt is False.

    Returns:
        Dict[Any, Any]: The parsed JSON, or {} when the AI fix failed.
    """
    if not try_to_fix_with_gpt:
        raise exception

    if CFG.debug_mode:
        logger.warn(
            "Warning: Failed to parse AI output, attempting to fix."
            "\n If you see this warning frequently, it's likely that"
            " your prompt is confusing the AI. Try changing it up"
            " slightly."
        )

    # Try to fix the JSON using the ai_functions helper.
    ai_fixed_json = auto_fix_json(json_to_load, JSON_SCHEMA)
    if ai_fixed_json == "failed":
        # Returning {} lets the caller surface the error to the AI, which
        # usually results in it correcting its ways on the next turn.
        return {}
    return json.loads(ai_fixed_json)
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
    """Extract the outermost {...} object from a string and parse it.

    Args:
        json_string (str): Text that may contain a JSON object surrounded by
            extra prose.

    Returns:
        The parsed JSON object, or {} when no object could be recovered.
    """
    if CFG.speak_mode and CFG.debug_mode:
        say_text(
            "I have received an invalid JSON response from the OpenAI API. "
            "Trying to fix it now."
        )
    logger.error("Attempting to fix JSON by finding outermost brackets\n")
    try:
        # Recursive pattern matching balanced braces (needs the `regex` lib).
        json_pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}")
        json_match = json_pattern.search(json_string)

        if json_match:
            # Extract the valid JSON object from the string
            json_string = json_match.group(0)
            logger.typewriter_log(
                title="Apparently json was fixed.", title_color=Fore.GREEN
            )
            if CFG.speak_mode and CFG.debug_mode:
                say_text("Apparently json was fixed.")
        else:
            return {}

    except (json.JSONDecodeError, ValueError):
        if CFG.debug_mode:
            logger.error(f"Error: Invalid JSON: {json_string}\n")
        if CFG.speak_mode:
            say_text("Didn't work. I will have to ignore this response then.")
        logger.error("Error: Invalid JSON, setting it to empty JSON now.\n")
        # Bug fix: the original assigned json_string = {} and then fell
        # through to fix_and_parse_json({}), which crashes because
        # fix_and_parse_json calls str methods on its argument. Return the
        # empty dict directly instead.
        return {}

    return fix_and_parse_json(json_string)

View File

@@ -5,11 +5,25 @@
"thoughts": {
"type": "object",
"properties": {
"text": {"type": "string"},
"reasoning": {"type": "string"},
"plan": {"type": "string"},
"criticism": {"type": "string"},
"speak": {"type": "string"}
"text": {
"type": "string",
"description": "thoughts"
},
"reasoning": {
"type": "string"
},
"plan": {
"type": "string",
"description": "- short bulleted\n- list that conveys\n- long-term plan"
},
"criticism": {
"type": "string",
"description": "constructive self-criticism"
},
"speak": {
"type": "string",
"description": "thoughts summary to say to user"
}
},
"required": ["text", "reasoning", "plan", "criticism", "speak"],
"additionalProperties": false

View File

@@ -1,7 +1,8 @@
"""Utilities for the json_fixes package."""
import ast
import json
import os.path
import re
from typing import Any
from jsonschema import Draft7Validator
@@ -12,37 +13,47 @@ CFG = Config()
LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1"
def extract_char_position(error_message: str) -> int:
"""Extract the character position from the JSONDecodeError message.
def extract_json_from_response(response_content: str) -> dict:
# Sometimes the response includes the JSON in a code block with ```
if response_content.startswith("```") and response_content.endswith("```"):
# Discard the first and last ```, then re-join in case the response naturally included ```
response_content = "```".join(response_content.split("```")[1:-1])
Args:
error_message (str): The error message from the JSONDecodeError
exception.
Returns:
int: The character position.
"""
char_pattern = re.compile(r"\(char (\d+)\)")
if match := char_pattern.search(error_message):
return int(match[1])
else:
raise ValueError("Character position not found in the error message.")
# response content comes from OpenAI as a Python `str(content_dict)`, literal_eval reverses this
try:
return ast.literal_eval(response_content)
except BaseException as e:
logger.error(f"Error parsing JSON response with literal_eval {e}")
# TODO: How to raise an error here without causing the program to exit?
return {}
def validate_json(json_object: object, schema_name: str) -> dict | None:
def llm_response_schema(
    schema_name: str = LLM_DEFAULT_RESPONSE_FORMAT,
) -> dict[str, Any]:
    """Load the named LLM response JSON schema from this package's directory."""
    schema_path = os.path.join(os.path.dirname(__file__), f"{schema_name}.json")
    with open(schema_path, "r") as f:
        return json.load(f)
def validate_json(
json_object: object, schema_name: str = LLM_DEFAULT_RESPONSE_FORMAT
) -> bool:
"""
:type schema_name: object
:param schema_name: str
:type json_object: object
Returns:
bool: Whether the json_object is valid or not
"""
scheme_file = os.path.join(os.path.dirname(__file__), f"{schema_name}.json")
with open(scheme_file, "r") as f:
schema = json.load(f)
schema = llm_response_schema(schema_name)
validator = Draft7Validator(schema)
if errors := sorted(validator.iter_errors(json_object), key=lambda e: e.path):
logger.error("The JSON object is invalid.")
for error in errors:
logger.error(f"JSON Validation Error: {error}")
if CFG.debug_mode:
logger.error(
json.dumps(json_object, indent=4)
@@ -51,10 +62,11 @@ def validate_json(json_object: object, schema_name: str) -> dict | None:
for error in errors:
logger.error(f"Error: {error.message}")
else:
logger.debug("The JSON object is valid.")
return False
return json_object
logger.debug("The JSON object is valid.")
return True
def validate_json_string(json_string: str, schema_name: str) -> dict | None:
@@ -66,7 +78,9 @@ def validate_json_string(json_string: str, schema_name: str) -> dict | None:
try:
json_loaded = json.loads(json_string)
return validate_json(json_loaded, schema_name)
if not validate_json(json_loaded, schema_name):
return None
return json_loaded
except:
return None

View File

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
from autogpt.config import Config
from autogpt.json_utils.utilities import (
LLM_DEFAULT_RESPONSE_FORMAT,
extract_json_from_response,
is_string_valid_json,
)
from autogpt.llm.base import ChatSequence, Message, MessageRole, MessageType
@@ -153,13 +154,14 @@ class MessageHistory:
# Remove "thoughts" dictionary from "content"
try:
content_dict = json.loads(event.content)
content_dict = extract_json_from_response(event.content)
if "thoughts" in content_dict:
del content_dict["thoughts"]
event.content = json.dumps(content_dict)
except json.decoder.JSONDecodeError:
except json.JSONDecodeError as e:
logger.error(f"Error: Invalid JSON: {e}")
if cfg.debug_mode:
logger.error(f"Error: Invalid JSON: {event.content}\n")
logger.error(f"{event.content}")
elif event.role.lower() == "system":
event.role = "your computer"

View File

@@ -1,7 +1,8 @@
""" A module for generating custom prompt strings."""
import json
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from autogpt.json_utils.utilities import llm_response_schema
if TYPE_CHECKING:
from autogpt.commands.command import CommandRegistry
@@ -25,16 +26,6 @@ class PromptGenerator:
self.command_registry: CommandRegistry | None = None
self.name = "Bob"
self.role = "AI"
self.response_format = {
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user",
},
"command": {"name": "command name", "args": {"arg name": "value"}},
}
def add_constraint(self, constraint: str) -> None:
"""
@@ -144,7 +135,6 @@ class PromptGenerator:
Returns:
str: The generated prompt string.
"""
formatted_response_format = json.dumps(self.response_format, indent=4)
return (
f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n"
"Commands:\n"
@@ -152,7 +142,6 @@ class PromptGenerator:
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
"Performance Evaluation:\n"
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
"You should only respond in JSON format as described below \nResponse"
f" Format: \n{formatted_response_format} \nEnsure the response can be"
" parsed by Python json.loads"
"Respond with only valid JSON conforming to the following schema: \n"
f"{llm_response_schema()}\n"
)

View File

@@ -11,7 +11,7 @@ from autogpt.utils import clean_input
CFG = Config()
DEFAULT_TRIGGERING_PROMPT = "Determine exactly one command to use, and respond using the format specified above:"
DEFAULT_TRIGGERING_PROMPT = "Determine exactly one command to use, and respond using the JSON schema specified previously:"
def build_default_prompt_generator() -> PromptGenerator:

View File

@@ -14,6 +14,5 @@ performance_evaluations: [
'Continuously review and analyze your actions to ensure you are performing to the best of your abilities.',
'Constructively self-criticize your big-picture behavior constantly.',
'Reflect on past decisions and strategies to refine your approach.',
'Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.',
'Write all code to a file.'
'Every command has a cost, so be smart and efficient. Aim to complete tasks in the least number of steps.'
]

View File

@@ -49,4 +49,4 @@
"max_level_beaten": null
}
}
}
}

View File

@@ -1,71 +0,0 @@
from unittest import TestCase
from autogpt.json_utils.json_fix_llm import fix_and_parse_json
class TestParseJson(TestCase):
    """Unit tests for fix_and_parse_json covering valid and deliberately
    broken JSON input, with the GPT fallback disabled."""

    def test_valid_json(self):
        """Test that a valid JSON string is parsed correctly."""
        json_str = '{"name": "John", "age": 30, "city": "New York"}'
        obj = fix_and_parse_json(json_str)
        self.assertEqual(obj, {"name": "John", "age": 30, "city": "New York"})

    def test_invalid_json_minor(self):
        """Test that an invalid JSON string can not be fixed without gpt"""
        # Trailing comma makes this strictly invalid JSON.
        json_str = '{"name": "John", "age": 30, "city": "New York",}'
        with self.assertRaises(Exception):
            fix_and_parse_json(json_str, try_to_fix_with_gpt=False)

    def test_invalid_json_major_with_gpt(self):
        """Test that an invalid JSON string raises an error when try_to_fix_with_gpt is False"""
        json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
        with self.assertRaises(Exception):
            fix_and_parse_json(json_str, try_to_fix_with_gpt=False)

    def test_invalid_json_major_without_gpt(self):
        """Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
        json_str = 'BEGIN: "name": "John" - "age": 30 - "city": "New York" :END'
        # Assert that this raises an exception:
        with self.assertRaises(Exception):
            fix_and_parse_json(json_str, try_to_fix_with_gpt=False)

    def test_invalid_json_leading_sentence_with_gpt(self):
        """Test that a REALLY invalid JSON string raises an error when try_to_fix_with_gpt is False"""
        # Prose before the braces followed by an otherwise-valid object.
        json_str = """I suggest we start by browsing the repository to find any issues that we can fix.
{
"command": {
"name": "browse_website",
"args":{
"url": "https://github.com/Torantulino/Auto-GPT"
}
},
"thoughts":
{
"text": "I suggest we start browsing the repository to find any issues that we can fix.",
"reasoning": "Browsing the repository will give us an idea of the current state of the codebase and identify any issues that we can address to improve the repo.",
"plan": "- Look through the repository to find any issues.\n- Investigate any issues to determine what needs to be fixed\n- Identify possible solutions to fix the issues\n- Open Pull Requests with fixes",
"criticism": "I should be careful while browsing so as not to accidentally introduce any new bugs or issues.",
"speak": "I will start browsing the repository to find any issues we can fix."
}
}"""
        # The object the sentence-prefixed JSON should decode to once fixed.
        good_obj = {
            "command": {
                "name": "browse_website",
                "args": {"url": "https://github.com/Torantulino/Auto-GPT"},
            },
            "thoughts": {
                "text": "I suggest we start browsing the repository to find any issues that we can fix.",
                "reasoning": "Browsing the repository will give us an idea of the current state of the codebase and identify any issues that we can address to improve the repo.",
                "plan": "- Look through the repository to find any issues.\n- Investigate any issues to determine what needs to be fixed\n- Identify possible solutions to fix the issues\n- Open Pull Requests with fixes",
                "criticism": "I should be careful while browsing so as not to accidentally introduce any new bugs or issues.",
                "speak": "I will start browsing the repository to find any issues we can fix.",
            },
        }
        # # Assert that this can be fixed with GPT
        # self.assertEqual(fix_and_parse_json(json_str), good_obj)
        # Assert that trying to fix this without GPT raises an exception
        with self.assertRaises(Exception):
            fix_and_parse_json(json_str, try_to_fix_with_gpt=False)

View File

@@ -1,114 +0,0 @@
# Generated by CodiumAI
from autogpt.json_utils.json_fix_llm import (
fix_and_parse_json,
fix_json_using_multiple_techniques,
)
"""
Code Analysis
Objective:
- The objective of the function is to fix a given JSON string to make it parseable and fully compliant with two techniques.
Inputs:
- The function takes in a string called 'assistant_reply', which is the JSON string to be fixed.
Flow:
- The function first calls the 'fix_and_parse_json' function to parse and print the Assistant response.
- If the parsed JSON is an empty dictionary, the function calls the 'attempt_to_fix_json_by_finding_outermost_brackets' function to fix the JSON string.
- If the parsed JSON is not an empty dictionary, the function returns the parsed JSON.
- If the parsed JSON is an empty dictionary and cannot be fixed, the function logs an error and returns an empty dictionary.
Outputs:
- The main output of the function is a dictionary containing the fixed JSON string.
Additional aspects:
- The function uses two techniques to fix the JSON string: parsing and finding outermost brackets.
- The function logs an error if the JSON string cannot be fixed and returns an empty dictionary.
- The function uses the 'CFG' object to determine whether to speak the error message or not.
"""
class TestFixJsonUsingMultipleTechniques:
    """Pytest-style tests for fix_json_using_multiple_techniques, mocking the
    AI fallback (try_ai_fix) so no model calls are made."""

    # Tests that the function successfully fixes and parses a JSON string that is already compliant with both techniques.
    def test_fix_and_parse_json_happy_path(self):
        # Happy path test case where the JSON string is already compliant with both techniques
        json_string = '{"text": "Hello world", "confidence": 0.9}'
        expected_output = {"text": "Hello world", "confidence": 0.9}
        assert fix_json_using_multiple_techniques(json_string) == expected_output

    # Tests that the function successfully fixes and parses a JSON string that contains only whitespace characters.
    # @requires_api_key("OPEN_API_KEY")
    def test_fix_and_parse_json_whitespace(self, mocker):
        # Happy path test case where the JSON string contains only whitespace characters
        json_string = " \n\t "
        # mock try_ai_fix to avoid calling the AI model:
        mocker.patch("autogpt.json_utils.json_fix_llm.try_ai_fix", return_value={})
        expected_output = {}
        assert fix_json_using_multiple_techniques(json_string) == expected_output

    # Tests that the function successfully converts a string with arrays to an array
    def test_fix_and_parse_json_array(self):
        # Happy path test case where the JSON string contains an array of string
        json_string = '[ "Add type hints", "Move docstrings", "Consider using" ]'
        expected_output = ["Add type hints", "Move docstrings", "Consider using"]
        assert fix_json_using_multiple_techniques(json_string) == expected_output

    # Tests that the function returns an empty dictionary when the JSON string is not parseable and cannot be fixed using either technique.
    # @requires_api_key("OPEN_API_KEY")
    def test_fix_and_parse_json_can_not(self, mocker):
        # Edge case test case where the JSON string is not parseable and cannot be fixed using either technique
        json_string = "This is not a JSON string"
        # mock try_ai_fix to avoid calling the AI model:
        mocker.patch("autogpt.json_utils.json_fix_llm.try_ai_fix", return_value={})
        expected_output = {}
        # Use the actual function name in the test
        result = fix_json_using_multiple_techniques(json_string)
        assert result == expected_output

    # Tests that the function returns an empty dictionary when the JSON string is empty.
    # @requires_api_key("OPEN_API_KEY")
    def test_fix_and_parse_json_empty_string(self, mocker):
        # Arrange
        json_string = ""
        # Act
        # mock try_ai_fix to avoid calling the AI model:
        mocker.patch("autogpt.json_utils.json_fix_llm.try_ai_fix", return_value={})
        result = fix_and_parse_json(json_string)
        # Assert
        assert result == {}

    # Tests that the function successfully fixes and parses a JSON string that contains escape characters.
    def test_fix_and_parse_json_escape_characters(self):
        # Arrange
        json_string = '{"text": "This is a \\"test\\" string."}'
        # Act
        result = fix_json_using_multiple_techniques(json_string)
        # Assert
        assert result == {"text": 'This is a "test" string.'}

    # Tests that the function successfully fixes and parses a JSON string that contains nested objects or arrays.
    def test_fix_and_parse_json_nested_objects(self):
        # Arrange
        json_string = '{"person": {"name": "John", "age": 30}, "hobbies": ["reading", "swimming"]}'
        # Act
        result = fix_json_using_multiple_techniques(json_string)
        # Assert
        assert result == {
            "person": {"name": "John", "age": 30},
            "hobbies": ["reading", "swimming"],
        }

View File

@@ -1,128 +0,0 @@
from unittest.mock import patch
import pytest
from openai.error import APIError, RateLimitError
from autogpt.llm import utils as llm_utils
@pytest.fixture(params=[RateLimitError, APIError])
def error(request):
if request.param == APIError:
return request.param("Error", http_status=502)
else:
return request.param("Error")
def error_factory(error_instance, error_count, retry_count, warn_user=True):
class RaisesError:
def __init__(self):
self.count = 0
@llm_utils.retry_openai_api(
num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
)
def __call__(self):
self.count += 1
if self.count <= error_count:
raise error_instance
return self.count
return RaisesError()
def test_retry_open_api_no_error(capsys):
@llm_utils.retry_openai_api()
def f():
return 1
result = f()
assert result == 1
output = capsys.readouterr()
assert output.out == ""
assert output.err == ""
@pytest.mark.parametrize(
"error_count, retry_count, failure",
[(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
)
def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
call_count = min(error_count, retry_count) + 1
raises = error_factory(error, error_count, retry_count)
if failure:
with pytest.raises(type(error)):
raises()
else:
result = raises()
assert result == call_count
assert raises.count == call_count
output = capsys.readouterr()
if error_count and retry_count:
if type(error) == RateLimitError:
assert "Reached rate limit, passing..." in output.out
assert "Please double check" in output.out
if type(error) == APIError:
assert "API Bad gateway" in output.out
else:
assert output.out == ""
def test_retry_open_api_rate_limit_no_warn(capsys):
error_count = 2
retry_count = 10
raises = error_factory(RateLimitError, error_count, retry_count, warn_user=False)
result = raises()
call_count = min(error_count, retry_count) + 1
assert result == call_count
assert raises.count == call_count
output = capsys.readouterr()
assert "Reached rate limit, passing..." in output.out
assert "Please double check" not in output.out
def test_retry_openapi_other_api_error(capsys):
error_count = 2
retry_count = 10
raises = error_factory(APIError("Error", http_status=500), error_count, retry_count)
with pytest.raises(APIError):
raises()
call_count = 1
assert raises.count == call_count
output = capsys.readouterr()
assert output.out == ""
def test_check_model(api_manager):
"""
Test if check_model() returns original model when valid.
Test if check_model() returns gpt-3.5-turbo when model is invalid.
"""
with patch("openai.Model.list") as mock_list_models:
# Test when correct model is returned
mock_list_models.return_value = {"data": [{"id": "gpt-4"}]}
result = llm_utils.check_model("gpt-4", "smart_llm_model")
assert result == "gpt-4"
# Reset api manager models
api_manager.models = None
# Test when incorrect model is returned
mock_list_models.return_value = {"data": [{"id": "gpt-3.5-turbo"}]}
result = llm_utils.check_model("gpt-4", "fast_llm_model")
assert result == "gpt-3.5-turbo"
# Reset api manager models
api_manager.models = None

View File

@@ -1,8 +1,10 @@
import os
from unittest.mock import patch
import pytest
import requests
from autogpt.json_utils.utilities import extract_json_from_response, validate_json
from autogpt.utils import (
get_bulletin_from_web,
get_current_git_branch,
@@ -13,6 +15,37 @@ from autogpt.utils import (
from tests.utils import skip_in_ci
@pytest.fixture
def valid_json_response() -> dict:
return {
"thoughts": {
"text": "My task is complete. I will use the 'task_complete' command to shut down.",
"reasoning": "I will use the 'task_complete' command because it allows me to shut down and signal that my task is complete.",
"plan": "I will use the 'task_complete' command with the reason 'Task complete: retrieved Tesla's revenue in 2022.' to shut down.",
"criticism": "I need to ensure that I have completed all necessary tasks before shutting down.",
"speak": "",
},
"command": {
"name": "task_complete",
"args": {"reason": "Task complete: retrieved Tesla's revenue in 2022."},
},
}
@pytest.fixture
def invalid_json_response() -> dict:
return {
"thoughts": {
"text": "My task is complete. I will use the 'task_complete' command to shut down.",
"reasoning": "I will use the 'task_complete' command because it allows me to shut down and signal that my task is complete.",
"plan": "I will use the 'task_complete' command with the reason 'Task complete: retrieved Tesla's revenue in 2022.' to shut down.",
"criticism": "I need to ensure that I have completed all necessary tasks before shutting down.",
"speak": "",
},
"command": {"name": "", "args": {}},
}
def test_validate_yaml_file_valid():
with open("valid_test_file.yaml", "w") as f:
f.write("setting: value")
@@ -150,3 +183,25 @@ def test_get_current_git_branch_failure(mock_repo):
branch_name = get_current_git_branch()
assert branch_name == ""
def test_validate_json_valid(valid_json_response):
assert validate_json(valid_json_response)
def test_validate_json_invalid(invalid_json_response):
assert not validate_json(valid_json_response)
def test_extract_json_from_response(valid_json_response: dict):
emulated_response_from_openai = str(valid_json_response)
assert (
extract_json_from_response(emulated_response_from_openai) == valid_json_response
)
def test_extract_json_from_response_wrapped_in_code_block(valid_json_response: dict):
emulated_response_from_openai = "```" + str(valid_json_response) + "```"
assert (
extract_json_from_response(emulated_response_from_openai) == valid_json_response
)