feat(agent): Implement more fault tolerant json_loads function (#7016)

* Implement syntax fault tolerant `json_loads` function using `demjson3`
   - Add `demjson3` dependency

* Replace `json.loads` by `json_loads` in places where malformed JSON may occur

* Move `json_utils.py` to `autogpt/core/utils`

* Add tests for `json_utils`

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
Krzysztof Czerwinski
2024-03-21 18:11:36 +01:00
committed by GitHub
parent bca50310f6
commit 76d6e61941
15 changed files with 217 additions and 103 deletions

View File

@@ -8,7 +8,6 @@ from autogpt.core.prompting import (
LanguageModelClassification,
PromptStrategy,
)
from autogpt.core.prompting.utils import json_loads
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessage,
ChatMessage,
@@ -16,6 +15,7 @@ from autogpt.core.resource.model_providers.schema import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)

View File

@@ -26,7 +26,7 @@ from autogpt.core.resource.model_providers.schema import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.json_utils.utilities import extract_dict_from_response
from autogpt.core.utils.json_utils import extract_dict_from_json, json_loads
from autogpt.prompts.utils import format_numbered_list, indent
@@ -386,7 +386,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
else f" '{response.content}'"
)
)
assistant_reply_dict = extract_dict_from_response(response.content)
assistant_reply_dict = extract_dict_from_json(response.content)
self.logger.debug(
"Validating object extracted from LLM response:\n"
f"{json.dumps(assistant_reply_dict, indent=4)}"
@@ -439,7 +439,7 @@ def extract_command(
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
assistant_reply_json["command"] = {
"name": assistant_reply.tool_calls[0].function.name,
"args": json.loads(assistant_reply.tool_calls[0].function.arguments),
"args": json_loads(assistant_reply.tool_calls[0].function.arguments),
}
try:
if not isinstance(assistant_reply_json, dict):

View File

@@ -4,13 +4,14 @@ from autogpt.core.configuration import SystemConfiguration, UserConfigurable
from autogpt.core.planning.schema import Task, TaskType
from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
from autogpt.core.prompting.utils import json_loads, to_numbered_list
from autogpt.core.prompting.utils import to_numbered_list
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)

View File

@@ -3,13 +3,13 @@ import logging
from autogpt.core.configuration import SystemConfiguration, UserConfigurable
from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
from autogpt.core.prompting.utils import json_loads
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)

View File

@@ -4,13 +4,14 @@ from autogpt.core.configuration import SystemConfiguration, UserConfigurable
from autogpt.core.planning.schema import Task
from autogpt.core.prompting import PromptStrategy
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
from autogpt.core.prompting.utils import json_loads, to_numbered_list
from autogpt.core.prompting.utils import to_numbered_list
from autogpt.core.resource.model_providers import (
AssistantChatMessage,
ChatMessage,
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,3 @@
import ast
import json
def to_numbered_list(
items: list[str], no_items_response: str = "", **template_args
) -> str:
@@ -11,19 +7,3 @@ def to_numbered_list(
)
else:
return no_items_response
def json_loads(json_str: str):
    """Parse the outermost `{...}` span of `json_str` into a Python object.

    Strict JSON parsing is attempted first. If that fails, the span is evaluated
    as a Python literal, which tolerates minor issues such as trailing commas
    and single-quoted strings.

    Args:
        json_str: Text containing a JSON-like object, possibly surrounded by
            other text.

    Returns:
        The parsed object.

    Raises:
        ValueError: If `json_str` contains no "{" or no "}", or if the span is
            not a valid literal.
        SyntaxError: If the span is neither valid JSON nor parseable as a
            Python literal.
    """
    # Discard anything before the first "{" and after the last "}".
    json_str = json_str[json_str.index("{") : json_str.rindex("}") + 1]
    try:
        return json.loads(json_str)
    except json.decoder.JSONDecodeError:
        # Can't rely on json.loads alone because the function API still
        # sometimes returns json strings with minor issues like trailing commas.
        return ast.literal_eval(json_str)

View File

@@ -38,6 +38,7 @@ from autogpt.core.resource.model_providers.schema import (
ModelTokenizer,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
_T = TypeVar("_T")
_P = ParamSpec("_P")
@@ -758,19 +759,18 @@ def _functions_compat_fix_kwargs(
def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
import json
import re
import uuid
logging.debug(f"Trying to extract tool calls from response:\n{response}")
if response[0] == "[":
tool_calls: list[AssistantToolCallDict] = json.loads(response)
tool_calls: list[AssistantToolCallDict] = json_loads(response)
else:
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
if not block:
raise ValueError("Could not find tool_calls block in response")
tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1))
for t in tool_calls:
t["id"] = str(uuid.uuid4())

View File

@@ -0,0 +1,92 @@
import io
import logging
import re
from typing import Any
import demjson3
logger = logging.getLogger(__name__)
def json_loads(json_str: str) -> Any:
    """Parse a JSON string, tolerating minor syntax issues:

    - Missing, extra and trailing commas
    - Extraneous newlines and whitespace outside of string literals
    - Inconsistent spacing after colons and commas
    - Missing closing brackets or braces
    - Numbers: binary, hex, octal, trailing and prefixed decimal points
    - Different encodings
    - Surrounding markdown code block
    - Comments

    Args:
        json_str: The JSON string to parse.

    Returns:
        The parsed JSON object, same as built-in json.loads.

    Raises:
        ValueError: If no value could be recovered from `json_str`.
    """
    # Remove possible code block
    pattern = r"```(?:json|JSON)*([\s\S]*?)```"
    match = re.search(pattern, json_str)

    if match:
        json_str = match.group(1).strip()

    # Collect the parser's error report so it can be logged rather than raised.
    error_buffer = io.StringIO()
    json_result = demjson3.decode(
        json_str, return_errors=True, write_errors=error_buffer
    )

    if error_buffer.getvalue():
        logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}")

    # NOTE(review): with return_errors=True, demjson3 is expected to always
    # return a result tuple, and to signal an unrecoverable parse through the
    # `syntax_error` / `undefined` sentinels in `.object` -- confirm against the
    # demjson3 docs. The `is None` check alone would let those sentinels leak
    # out to callers as if they were parsed data.
    if (
        json_result is None
        or json_result.object is demjson3.syntax_error
        or json_result.object is demjson3.undefined
    ):
        raise ValueError(f"Failed to parse JSON string: {json_str}")

    return json_result.object
def extract_dict_from_json(json_str: str) -> dict[str, Any]:
    """Extract a JSON object from `json_str` and return it as a dict.

    The content of the first markdown code block is used if there is one;
    otherwise the span from the first "{" to the last "}" is used; otherwise
    the whole string is parsed as-is.

    Args:
        json_str: Text that is, or contains, a JSON object.

    Returns:
        The parsed dict.

    Raises:
        ValueError: If the parsed value is not a dict.
    """
    if fenced := re.search(r"```(?:json|JSON)*([\s\S]*?)```", json_str):
        # Sometimes the response includes the JSON in a code block with ```
        json_str = fenced.group(1).strip()
    elif braced := re.search(r"{[\s\S]*}", json_str):
        # The string may contain JSON.
        json_str = braced.group()

    result = json_loads(json_str)
    if isinstance(result, dict):
        return result

    raise ValueError(
        f"Response '''{json_str}''' evaluated to non-dict value {repr(result)}"
    )
def extract_list_from_json(json_str: str) -> list[Any]:
    """Extract a JSON array from `json_str` and return it as a list.

    The content of the first markdown code block is used if there is one;
    otherwise the span from the first "[" to the last "]" is used; otherwise
    the whole string is parsed as-is.

    Args:
        json_str: Text that is, or contains, a JSON array.

    Returns:
        The parsed list.

    Raises:
        ValueError: If the parsed value is not a list.
    """
    if fenced := re.search(r"```(?:json|JSON)*([\s\S]*?)```", json_str):
        # Sometimes the response includes the JSON in a code block with ```
        json_str = fenced.group(1).strip()
    elif bracketed := re.search(r"\[[\s\S]*\]", json_str):
        # The string may contain JSON.
        json_str = bracketed.group()

    result = json_loads(json_str)
    if isinstance(result, list):
        return result

    raise ValueError(
        f"Response '''{json_str}''' evaluated to non-list value {repr(result)}"
    )

View File

@@ -1,55 +0,0 @@
"""Utilities for the json_fixes package."""
import json
import logging
import re
from typing import Any
logger = logging.getLogger(__name__)
def extract_dict_from_response(response_content: str) -> dict[str, Any]:
    """Extract a JSON object from `response_content` and return it as a dict.

    The content of the first markdown code block is used if there is one;
    otherwise the span from the first "{" to the last "}" is used; otherwise
    the whole string is parsed as-is with `json.loads`.

    Args:
        response_content: Text that is, or contains, a JSON object.

    Returns:
        The parsed dict.

    Raises:
        ValueError: If the text is not valid JSON, or the parsed value is not
            a dict.
    """
    if fenced := re.search(r"```(?:json|JSON)*([\s\S]*?)```", response_content):
        # Sometimes the response includes the JSON in a code block with ```
        response_content = fenced.group(1).strip()
    elif braced := re.search(r"{[\s\S]*}", response_content):
        # The string may contain JSON.
        response_content = braced.group()

    result = json.loads(response_content)
    if isinstance(result, dict):
        return result

    raise ValueError(
        f"Response '''{response_content}''' evaluated to "
        f"non-dict value {repr(result)}"
    )
def extract_list_from_response(response_content: str) -> list[Any]:
    """Extract a JSON array from `response_content` and return it as a list.

    The content of the first markdown code block is used if there is one;
    otherwise the span from the first "[" to the last "]" is used; otherwise
    the whole string is parsed as-is with `json.loads`.

    Args:
        response_content: Text that is, or contains, a JSON array.

    Returns:
        The parsed list.

    Raises:
        ValueError: If the text is not valid JSON, or the parsed value is not
            a list.
    """
    if fenced := re.search(r"```(?:json|JSON)*([\s\S]*?)```", response_content):
        # Sometimes the response includes the JSON in a code block with ```
        response_content = fenced.group(1).strip()
    elif bracketed := re.search(r"\[[\s\S]*\]", response_content):
        # The string may contain JSON.
        response_content = bracketed.group()

    result = json.loads(response_content)
    if isinstance(result, list):
        return result

    raise ValueError(
        f"Response '''{response_content}''' evaluated to "
        f"non-list value {repr(result)}"
    )

View File

@@ -1,4 +1,5 @@
"""Text processing functions"""
import logging
import math
from typing import Iterator, Optional, TypeVar
@@ -12,7 +13,7 @@ from autogpt.core.resource.model_providers import (
ChatModelProvider,
ModelTokenizer,
)
from autogpt.json_utils.utilities import extract_list_from_response
from autogpt.core.utils.json_utils import extract_list_from_json
logger = logging.getLogger(__name__)
@@ -161,9 +162,7 @@ async def _process_text(
temperature=0.5,
max_tokens=max_result_tokens,
completion_parser=lambda s: (
extract_list_from_response(s.content)
if output_type is not str
else None
extract_list_from_json(s.content) if output_type is not str else None
),
)

View File

@@ -1616,6 +1616,16 @@ files = [
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
]
[[package]]
name = "demjson3"
version = "3.0.6"
description = "encoder, decoder, and lint/validator for JSON (JavaScript Object Notation) compliant with RFC 7159"
optional = false
python-versions = "*"
files = [
{file = "demjson3-3.0.6.tar.gz", hash = "sha256:37c83b0c6eb08d25defc88df0a2a4875d58a7809a9650bd6eee7afd8053cdbac"},
]
[[package]]
name = "deprecated"
version = "1.2.14"
@@ -7248,4 +7258,4 @@ benchmark = ["agbenchmark"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "a09e20daaf94e05457ded6a9989585cd37edf96a036c5f9e505f0b2456403a25"
content-hash = "9e28a0449253ec931297aa655fcd09da5e9f5d57bd73863419ce4f477018ef8a"

View File

@@ -30,6 +30,7 @@ boto3 = "^1.33.6"
charset-normalizer = "^3.1.0"
click = "*"
colorama = "^0.4.6"
demjson3 = "^3.0.0"
distro = "^1.8.0"
docker = "*"
duckduckgo-search = "^4.0.0"

View File

@@ -0,0 +1,93 @@
import json
import pytest
from autogpt.core.utils.json_utils import json_loads
# Pairs of (input, expected): test_json_loads_fixable asserts that
# json_loads(input) equals json.loads(expected).
_JSON_FIXABLE: list[tuple[str, str]] = [
# Missing comma
('{"name": "John Doe" "age": 30,}', '{"name": "John Doe", "age": 30}'),
("[1, 2 3]", "[1, 2, 3]"),
# Trailing comma
('{"name": "John Doe", "age": 30,}', '{"name": "John Doe", "age": 30}'),
("[1, 2, 3,]", "[1, 2, 3]"),
# Extra comma in object
('{"name": "John Doe",, "age": 30}', '{"name": "John Doe", "age": 30}'),
# Extra newlines
('{"name": "John Doe",\n"age": 30}', '{"name": "John Doe", "age": 30}'),
("[1, 2,\n3]", "[1, 2, 3]"),
# Missing closing brace or bracket
('{"name": "John Doe", "age": 30', '{"name": "John Doe", "age": 30}'),
("[1, 2, 3", "[1, 2, 3]"),
# Different numerals
("[+1, ---2, .5, +-4.5, 123.]", "[1, -2, 0.5, -4.5, 123]"),
('{"bin": 0b1001, "hex": 0x1A, "oct": 0o17}', '{"bin": 9, "hex": 26, "oct": 15}'),
# Broken array
(
'[1, 2 3, "yes" true, false null, 25, {"obj": "var"}',
'[1, 2, 3, "yes", true, false, null, 25, {"obj": "var"}]',
),
# Codeblock
(
'```json\n{"name": "John Doe", "age": 30}\n```',
'{"name": "John Doe", "age": 30}',
),
# Multiple problems
(
'{"name":"John Doe" "age": 30\n "empty": "","address": '
"// random comment\n"
'{"city": "New York", "state": "NY"},'
'"skills": ["Python" "C++", "Java",""],',
'{"name": "John Doe", "age": 30, "empty": "", "address": '
'{"city": "New York", "state": "NY"}, '
'"skills": ["Python", "C++", "Java", ""]}',
),
# All good
(
'{"name": "John Doe", "age": 30, "address": '
'{"city": "New York", "state": "NY"}, '
'"skills": ["Python", "C++", "Java"]}',
'{"name": "John Doe", "age": 30, "address": '
'{"city": "New York", "state": "NY"}, '
'"skills": ["Python", "C++", "Java"]}',
),
("true", "true"),
("false", "false"),
("null", "null"),
("123.5", "123.5"),
('"Hello, World!"', '"Hello, World!"'),
("{}", "{}"),
("[]", "[]"),
]
# Pairs of (input, intended): test_json_loads_unfixable asserts that
# json_loads(input) does NOT equal json.loads(intended), i.e. these defects
# are not repaired into the intended value.
_JSON_UNFIXABLE: list[tuple[str, str]] = [
# Broken booleans and null
("[TRUE, False, NULL]", "[true, false, null]"),
# Missing values in array
("[1, , 3]", "[1, 3]"),
# Leading zeros (are treated as octal)
("[0023, 015]", "[23, 15]"),
# Missing quotes
('{"name": John Doe}', '{"name": "John Doe"}'),
# Missing opening braces or bracket
('"name": "John Doe"}', '{"name": "John Doe"}'),
("1, 2, 3]", "[1, 2, 3]"),
]
@pytest.fixture(params=_JSON_FIXABLE)
def fixable_json(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Provide one `(input, expected)` pair from `_JSON_FIXABLE` per test run."""
    case: tuple[str, str] = request.param
    return case
@pytest.fixture(params=_JSON_UNFIXABLE)
def unfixable_json(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Provide one `(input, intended)` pair from `_JSON_UNFIXABLE` per test run."""
    case: tuple[str, str] = request.param
    return case
def test_json_loads_fixable(fixable_json: tuple[str, str]):
    """`json_loads` must parse the malformed input to the expected value."""
    malformed, expected = fixable_json
    assert json_loads(malformed) == json.loads(expected)
def test_json_loads_unfixable(unfixable_json: tuple[str, str]):
    """`json_loads` must not produce the intended value for these inputs."""
    malformed, intended = unfixable_json
    assert json_loads(malformed) != json.loads(intended)

View File

@@ -14,7 +14,7 @@ from autogpt.app.utils import (
get_latest_bulletin,
set_env_config_value,
)
from autogpt.json_utils.utilities import extract_dict_from_response
from autogpt.core.utils.json_utils import extract_dict_from_json
from autogpt.utils import validate_yaml_file
from tests.utils import skip_in_ci
@@ -199,34 +199,26 @@ def test_get_current_git_branch_failure(mock_repo):
def test_extract_json_from_response(valid_json_response: dict):
emulated_response_from_openai = json.dumps(valid_json_response)
assert (
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
)
assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
def test_extract_json_from_response_wrapped_in_code_block(valid_json_response: dict):
emulated_response_from_openai = "```" + json.dumps(valid_json_response) + "```"
assert (
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
)
assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
def test_extract_json_from_response_wrapped_in_code_block_with_language(
valid_json_response: dict,
):
emulated_response_from_openai = "```json" + json.dumps(valid_json_response) + "```"
assert (
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
)
assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
def test_extract_json_from_response_json_contained_in_string(valid_json_response: dict):
emulated_response_from_openai = (
"sentence1" + json.dumps(valid_json_response) + "sentence2"
)
assert (
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
)
assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
@pytest.fixture