mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 23:28:07 -05:00
feat(agent): Implement more fault tolerant json_loads function (#7016)
* Implement syntax fault tolerant `json_loads` function using `demjson3`
  - Add `demjson3` dependency
* Replace `json.loads` with `json_loads` in places where malformed JSON may occur
* Move `json_utils.py` to `autogpt/core/utils`
* Add tests for `json_utils`
---------
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
committed by
GitHub
parent
bca50310f6
commit
76d6e61941
@@ -8,7 +8,6 @@ from autogpt.core.prompting import (
|
||||
LanguageModelClassification,
|
||||
PromptStrategy,
|
||||
)
|
||||
from autogpt.core.prompting.utils import json_loads
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
@@ -16,6 +15,7 @@ from autogpt.core.resource.model_providers.schema import (
|
||||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from autogpt.core.resource.model_providers.schema import (
|
||||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.json_utils.utilities import extract_dict_from_response
|
||||
from autogpt.core.utils.json_utils import extract_dict_from_json, json_loads
|
||||
from autogpt.prompts.utils import format_numbered_list, indent
|
||||
|
||||
|
||||
@@ -386,7 +386,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
else f" '{response.content}'"
|
||||
)
|
||||
)
|
||||
assistant_reply_dict = extract_dict_from_response(response.content)
|
||||
assistant_reply_dict = extract_dict_from_json(response.content)
|
||||
self.logger.debug(
|
||||
"Validating object extracted from LLM response:\n"
|
||||
f"{json.dumps(assistant_reply_dict, indent=4)}"
|
||||
@@ -439,7 +439,7 @@ def extract_command(
|
||||
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
|
||||
assistant_reply_json["command"] = {
|
||||
"name": assistant_reply.tool_calls[0].function.name,
|
||||
"args": json.loads(assistant_reply.tool_calls[0].function.arguments),
|
||||
"args": json_loads(assistant_reply.tool_calls[0].function.arguments),
|
||||
}
|
||||
try:
|
||||
if not isinstance(assistant_reply_json, dict):
|
||||
|
||||
@@ -4,13 +4,14 @@ from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
from autogpt.core.planning.schema import Task, TaskType
|
||||
from autogpt.core.prompting import PromptStrategy
|
||||
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
|
||||
from autogpt.core.prompting.utils import json_loads, to_numbered_list
|
||||
from autogpt.core.prompting.utils import to_numbered_list
|
||||
from autogpt.core.resource.model_providers import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,13 +3,13 @@ import logging
|
||||
from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
from autogpt.core.prompting import PromptStrategy
|
||||
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
|
||||
from autogpt.core.prompting.utils import json_loads
|
||||
from autogpt.core.resource.model_providers import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,13 +4,14 @@ from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
from autogpt.core.planning.schema import Task
|
||||
from autogpt.core.prompting import PromptStrategy
|
||||
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
|
||||
from autogpt.core.prompting.utils import json_loads, to_numbered_list
|
||||
from autogpt.core.prompting.utils import to_numbered_list
|
||||
from autogpt.core.resource.model_providers import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import ast
|
||||
import json
|
||||
|
||||
|
||||
def to_numbered_list(
|
||||
items: list[str], no_items_response: str = "", **template_args
|
||||
) -> str:
|
||||
@@ -11,19 +7,3 @@ def to_numbered_list(
|
||||
)
|
||||
else:
|
||||
return no_items_response
|
||||
|
||||
|
||||
def json_loads(json_str: str):
    """Parse the outermost ``{...}`` object embedded in ``json_str``.

    Everything before the first ``{`` and after the last ``}`` is discarded.
    The remaining text is parsed as strict JSON first; if that fails, it is
    parsed as a Python literal, which tolerates minor issues such as trailing
    commas that function-call responses sometimes contain.

    Args:
        json_str: Text containing a JSON object, possibly surrounded by noise.

    Returns:
        The parsed object.

    Raises:
        ValueError: If the text contains no ``{...}`` span, or the span can be
            parsed neither as JSON nor as a Python literal.
    """
    try:
        start = json_str.index("{")
        end = json_str.rindex("}") + 1
    except ValueError as e:
        raise ValueError(f"No JSON object found in: {json_str!r}") from e
    candidate = json_str[start:end]

    try:
        # Strict JSON first, so that true/false/null are understood.
        return json.loads(candidate)
    except json.decoder.JSONDecodeError as json_error:
        try:
            # literal_eval only evaluates literals (no code execution) and
            # accepts trailing commas and single-quoted strings.
            return ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError):
            raise ValueError(
                f"Failed to parse JSON string: {candidate!r}"
            ) from json_error
|
||||
|
||||
@@ -38,6 +38,7 @@ from autogpt.core.resource.model_providers.schema import (
|
||||
ModelTokenizer,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
@@ -758,19 +759,18 @@ def _functions_compat_fix_kwargs(
|
||||
|
||||
|
||||
def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
|
||||
logging.debug(f"Trying to extract tool calls from response:\n{response}")
|
||||
|
||||
if response[0] == "[":
|
||||
tool_calls: list[AssistantToolCallDict] = json.loads(response)
|
||||
tool_calls: list[AssistantToolCallDict] = json_loads(response)
|
||||
else:
|
||||
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
|
||||
if not block:
|
||||
raise ValueError("Could not find tool_calls block in response")
|
||||
tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
|
||||
tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1))
|
||||
|
||||
for t in tool_calls:
|
||||
t["id"] = str(uuid.uuid4())
|
||||
|
||||
92
autogpts/autogpt/autogpt/core/utils/json_utils.py
Normal file
92
autogpts/autogpt/autogpt/core/utils/json_utils.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import demjson3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def json_loads(json_str: str) -> Any:
    """Parse a JSON string, tolerating minor syntax issues:
    - Missing, extra and trailing commas
    - Extraneous newlines and whitespace outside of string literals
    - Inconsistent spacing after colons and commas
    - Missing closing brackets or braces
    - Numbers: binary, hex, octal, trailing and prefixed decimal points
    - Different encodings
    - Surrounding markdown code block
    - Comments

    Args:
        json_str: The JSON string to parse.

    Returns:
        The parsed JSON object, same as built-in json.loads.

    Raises:
        ValueError: If no value could be decoded from the string.
    """
    # Remove possible code block
    pattern = r"```(?:json|JSON)*([\s\S]*?)```"
    match = re.search(pattern, json_str)

    if match:
        json_str = match.group(1).strip()

    # Collect the decoder's error report in memory so it can be sent to the
    # debug log instead of being written to the console.
    error_buffer = io.StringIO()
    json_result = demjson3.decode(
        json_str, return_errors=True, write_errors=error_buffer
    )

    if error_buffer.getvalue():
        logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}")

    # With return_errors=True the decoder hands back a results tuple even when
    # it could not produce a value; in that case `.object` is one of demjson3's
    # sentinel objects, which must not leak to callers as if it were data.
    if (
        json_result is None
        or json_result.object is demjson3.syntax_error
        or json_result.object is demjson3.undefined
    ):
        raise ValueError(f"Failed to parse JSON string: {json_str}")

    return json_result.object
|
||||
|
||||
|
||||
def extract_dict_from_json(json_str: str) -> dict[str, Any]:
    """Locate a JSON object inside ``json_str`` and return it as a dict.

    The text inside the first markdown code fence is used when one is present;
    otherwise the span from the first ``{`` to the last ``}`` is used; otherwise
    the whole string is parsed as-is.

    Args:
        json_str: Text that contains (or is) a JSON object.

    Returns:
        The parsed dictionary.

    Raises:
        ValueError: If the parsed value is not a dict.
    """
    # Sometimes the response wraps the JSON in a ``` code block
    fenced = re.search(r"```(?:json|JSON)*([\s\S]*?)```", json_str)
    if fenced is not None:
        json_str = fenced.group(1).strip()
    else:
        # Otherwise the JSON may be embedded in surrounding text
        embedded = re.search(r"{[\s\S]*}", json_str)
        if embedded is not None:
            json_str = embedded.group()

    parsed = json_loads(json_str)
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(
        f"Response '''{json_str}''' evaluated to non-dict value {repr(parsed)}"
    )
|
||||
|
||||
|
||||
def extract_list_from_json(json_str: str) -> list[Any]:
    """Locate a JSON array inside ``json_str`` and return it as a list.

    The text inside the first markdown code fence is used when one is present;
    otherwise the span from the first ``[`` to the last ``]`` is used; otherwise
    the whole string is parsed as-is.

    Args:
        json_str: Text that contains (or is) a JSON array.

    Returns:
        The parsed list.

    Raises:
        ValueError: If the parsed value is not a list.
    """
    # Sometimes the response wraps the JSON in a ``` code block
    fenced = re.search(r"```(?:json|JSON)*([\s\S]*?)```", json_str)
    if fenced is not None:
        json_str = fenced.group(1).strip()
    else:
        # Otherwise the JSON may be embedded in surrounding text
        embedded = re.search(r"\[[\s\S]*\]", json_str)
        if embedded is not None:
            json_str = embedded.group()

    parsed = json_loads(json_str)
    if isinstance(parsed, list):
        return parsed
    raise ValueError(
        f"Response '''{json_str}''' evaluated to non-list value {repr(parsed)}"
    )
|
||||
@@ -1,55 +0,0 @@
|
||||
"""Utilities for the json_fixes package."""
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_dict_from_response(response_content: str) -> dict[str, Any]:
    """Locate a JSON object inside ``response_content`` and return it as a dict.

    The text inside the first markdown code fence is used when one is present;
    otherwise the span from the first ``{`` to the last ``}`` is used; otherwise
    the whole string is parsed as-is with ``json.loads``.

    Raises:
        ValueError: If the parsed value is not a dict (``json.loads`` itself
            raises ``json.JSONDecodeError`` on malformed input).
    """
    # Sometimes the response wraps the JSON in a ``` code block
    fenced = re.search(r"```(?:json|JSON)*([\s\S]*?)```", response_content)
    if fenced is not None:
        response_content = fenced.group(1).strip()
    else:
        # Otherwise the JSON may be embedded in surrounding text
        embedded = re.search(r"{[\s\S]*}", response_content)
        if embedded is not None:
            response_content = embedded.group()

    parsed = json.loads(response_content)
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(
        f"Response '''{response_content}''' evaluated to "
        f"non-dict value {repr(parsed)}"
    )
|
||||
|
||||
|
||||
def extract_list_from_response(response_content: str) -> list[Any]:
    """Locate a JSON array inside ``response_content`` and return it as a list.

    The text inside the first markdown code fence is used when one is present;
    otherwise the span from the first ``[`` to the last ``]`` is used; otherwise
    the whole string is parsed as-is with ``json.loads``.

    Raises:
        ValueError: If the parsed value is not a list (``json.loads`` itself
            raises ``json.JSONDecodeError`` on malformed input).
    """
    # Sometimes the response wraps the JSON in a ``` code block
    fenced = re.search(r"```(?:json|JSON)*([\s\S]*?)```", response_content)
    if fenced is not None:
        response_content = fenced.group(1).strip()
    else:
        # Otherwise the JSON may be embedded in surrounding text
        embedded = re.search(r"\[[\s\S]*\]", response_content)
        if embedded is not None:
            response_content = embedded.group()

    parsed = json.loads(response_content)
    if isinstance(parsed, list):
        return parsed
    raise ValueError(
        f"Response '''{response_content}''' evaluated to "
        f"non-list value {repr(parsed)}"
    )
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Text processing functions"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Iterator, Optional, TypeVar
|
||||
@@ -12,7 +13,7 @@ from autogpt.core.resource.model_providers import (
|
||||
ChatModelProvider,
|
||||
ModelTokenizer,
|
||||
)
|
||||
from autogpt.json_utils.utilities import extract_list_from_response
|
||||
from autogpt.core.utils.json_utils import extract_list_from_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -161,9 +162,7 @@ async def _process_text(
|
||||
temperature=0.5,
|
||||
max_tokens=max_result_tokens,
|
||||
completion_parser=lambda s: (
|
||||
extract_list_from_response(s.content)
|
||||
if output_type is not str
|
||||
else None
|
||||
extract_list_from_json(s.content) if output_type is not str else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
12
autogpts/autogpt/poetry.lock
generated
12
autogpts/autogpt/poetry.lock
generated
@@ -1616,6 +1616,16 @@ files = [
|
||||
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "demjson3"
|
||||
version = "3.0.6"
|
||||
description = "encoder, decoder, and lint/validator for JSON (JavaScript Object Notation) compliant with RFC 7159"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "demjson3-3.0.6.tar.gz", hash = "sha256:37c83b0c6eb08d25defc88df0a2a4875d58a7809a9650bd6eee7afd8053cdbac"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deprecated"
|
||||
version = "1.2.14"
|
||||
@@ -7248,4 +7258,4 @@ benchmark = ["agbenchmark"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "a09e20daaf94e05457ded6a9989585cd37edf96a036c5f9e505f0b2456403a25"
|
||||
content-hash = "9e28a0449253ec931297aa655fcd09da5e9f5d57bd73863419ce4f477018ef8a"
|
||||
|
||||
@@ -30,6 +30,7 @@ boto3 = "^1.33.6"
|
||||
charset-normalizer = "^3.1.0"
|
||||
click = "*"
|
||||
colorama = "^0.4.6"
|
||||
demjson3 = "^3.0.0"
|
||||
distro = "^1.8.0"
|
||||
docker = "*"
|
||||
duckduckgo-search = "^4.0.0"
|
||||
|
||||
93
autogpts/autogpt/tests/unit/test_json_utils.py
Normal file
93
autogpts/autogpt/tests/unit/test_json_utils.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
_JSON_FIXABLE: list[tuple[str, str]] = [
|
||||
# Missing comma
|
||||
('{"name": "John Doe" "age": 30,}', '{"name": "John Doe", "age": 30}'),
|
||||
("[1, 2 3]", "[1, 2, 3]"),
|
||||
# Trailing comma
|
||||
('{"name": "John Doe", "age": 30,}', '{"name": "John Doe", "age": 30}'),
|
||||
("[1, 2, 3,]", "[1, 2, 3]"),
|
||||
# Extra comma in object
|
||||
('{"name": "John Doe",, "age": 30}', '{"name": "John Doe", "age": 30}'),
|
||||
# Extra newlines
|
||||
('{"name": "John Doe",\n"age": 30}', '{"name": "John Doe", "age": 30}'),
|
||||
("[1, 2,\n3]", "[1, 2, 3]"),
|
||||
# Missing closing brace or bracket
|
||||
('{"name": "John Doe", "age": 30', '{"name": "John Doe", "age": 30}'),
|
||||
("[1, 2, 3", "[1, 2, 3]"),
|
||||
# Different numerals
|
||||
("[+1, ---2, .5, +-4.5, 123.]", "[1, -2, 0.5, -4.5, 123]"),
|
||||
('{"bin": 0b1001, "hex": 0x1A, "oct": 0o17}', '{"bin": 9, "hex": 26, "oct": 15}'),
|
||||
# Broken array
|
||||
(
|
||||
'[1, 2 3, "yes" true, false null, 25, {"obj": "var"}',
|
||||
'[1, 2, 3, "yes", true, false, null, 25, {"obj": "var"}]',
|
||||
),
|
||||
# Codeblock
|
||||
(
|
||||
'```json\n{"name": "John Doe", "age": 30}\n```',
|
||||
'{"name": "John Doe", "age": 30}',
|
||||
),
|
||||
# Multiple problems
|
||||
(
|
||||
'{"name":"John Doe" "age": 30\n "empty": "","address": '
|
||||
"// random comment\n"
|
||||
'{"city": "New York", "state": "NY"},'
|
||||
'"skills": ["Python" "C++", "Java",""],',
|
||||
'{"name": "John Doe", "age": 30, "empty": "", "address": '
|
||||
'{"city": "New York", "state": "NY"}, '
|
||||
'"skills": ["Python", "C++", "Java", ""]}',
|
||||
),
|
||||
# All good
|
||||
(
|
||||
'{"name": "John Doe", "age": 30, "address": '
|
||||
'{"city": "New York", "state": "NY"}, '
|
||||
'"skills": ["Python", "C++", "Java"]}',
|
||||
'{"name": "John Doe", "age": 30, "address": '
|
||||
'{"city": "New York", "state": "NY"}, '
|
||||
'"skills": ["Python", "C++", "Java"]}',
|
||||
),
|
||||
("true", "true"),
|
||||
("false", "false"),
|
||||
("null", "null"),
|
||||
("123.5", "123.5"),
|
||||
('"Hello, World!"', '"Hello, World!"'),
|
||||
("{}", "{}"),
|
||||
("[]", "[]"),
|
||||
]
|
||||
|
||||
_JSON_UNFIXABLE: list[tuple[str, str]] = [
|
||||
# Broken booleans and null
|
||||
("[TRUE, False, NULL]", "[true, false, null]"),
|
||||
# Missing values in array
|
||||
("[1, , 3]", "[1, 3]"),
|
||||
# Leading zeros (are treated as octal)
|
||||
("[0023, 015]", "[23, 15]"),
|
||||
# Missing quotes
|
||||
('{"name": John Doe}', '{"name": "John Doe"}'),
|
||||
# Missing opening braces or bracket
|
||||
('"name": "John Doe"}', '{"name": "John Doe"}'),
|
||||
("1, 2, 3]", "[1, 2, 3]"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(params=_JSON_FIXABLE)
def fixable_json(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Provide one ``(malformed, expected)`` pair from ``_JSON_FIXABLE`` per run."""
    case: tuple[str, str] = request.param
    return case
|
||||
|
||||
|
||||
@pytest.fixture(params=_JSON_UNFIXABLE)
def unfixable_json(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Provide one ``(malformed, intended)`` pair from ``_JSON_UNFIXABLE`` per run."""
    case: tuple[str, str] = request.param
    return case
|
||||
|
||||
|
||||
def test_json_loads_fixable(fixable_json: tuple[str, str]):
    """The tolerant parser must yield the same value as the well-formed JSON."""
    malformed, expected = fixable_json
    assert json_loads(malformed) == json.loads(expected)
|
||||
|
||||
|
||||
def test_json_loads_unfixable(unfixable_json: tuple[str, str]):
    """For these inputs the parser's result must differ from the intended value."""
    malformed, intended = unfixable_json
    assert json_loads(malformed) != json.loads(intended)
|
||||
@@ -14,7 +14,7 @@ from autogpt.app.utils import (
|
||||
get_latest_bulletin,
|
||||
set_env_config_value,
|
||||
)
|
||||
from autogpt.json_utils.utilities import extract_dict_from_response
|
||||
from autogpt.core.utils.json_utils import extract_dict_from_json
|
||||
from autogpt.utils import validate_yaml_file
|
||||
from tests.utils import skip_in_ci
|
||||
|
||||
@@ -199,34 +199,26 @@ def test_get_current_git_branch_failure(mock_repo):
|
||||
|
||||
def test_extract_json_from_response(valid_json_response: dict):
|
||||
emulated_response_from_openai = json.dumps(valid_json_response)
|
||||
assert (
|
||||
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
|
||||
)
|
||||
assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
|
||||
|
||||
|
||||
def test_extract_json_from_response_wrapped_in_code_block(valid_json_response: dict):
|
||||
emulated_response_from_openai = "```" + json.dumps(valid_json_response) + "```"
|
||||
assert (
|
||||
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
|
||||
)
|
||||
assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
|
||||
|
||||
|
||||
def test_extract_json_from_response_wrapped_in_code_block_with_language(
|
||||
valid_json_response: dict,
|
||||
):
|
||||
emulated_response_from_openai = "```json" + json.dumps(valid_json_response) + "```"
|
||||
assert (
|
||||
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
|
||||
)
|
||||
assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
|
||||
|
||||
|
||||
def test_extract_json_from_response_json_contained_in_string(valid_json_response: dict):
|
||||
emulated_response_from_openai = (
|
||||
"sentence1" + json.dumps(valid_json_response) + "sentence2"
|
||||
)
|
||||
assert (
|
||||
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
|
||||
)
|
||||
assert extract_dict_from_json(emulated_response_from_openai) == valid_json_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user