mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): Enable json parsing with typing & conversion (#8578)
This commit is contained in:
@@ -97,7 +97,7 @@ class ExecutionResult(BaseModel):
|
||||
def from_db(execution: AgentNodeExecution):
|
||||
if execution.executionData:
|
||||
# Execution that has been queued for execution will persist its data.
|
||||
input_data = json.loads(execution.executionData)
|
||||
input_data = json.loads(execution.executionData, target_type=dict[str, Any])
|
||||
else:
|
||||
# For incomplete execution, executionData will not be yet available.
|
||||
input_data: BlockInput = defaultdict()
|
||||
|
||||
@@ -56,8 +56,8 @@ class Node(BaseDbModel):
|
||||
obj = Node(
|
||||
id=node.id,
|
||||
block_id=node.AgentBlock.id,
|
||||
input_default=json.loads(node.constantInput),
|
||||
metadata=json.loads(node.metadata),
|
||||
input_default=json.loads(node.constantInput, target_type=dict[str, Any]),
|
||||
metadata=json.loads(node.metadata, target_type=dict[str, Any]),
|
||||
)
|
||||
obj.input_links = [Link.from_db(link) for link in node.Input or []]
|
||||
obj.output_links = [Link.from_db(link) for link in node.Output or []]
|
||||
@@ -80,10 +80,13 @@ class GraphExecution(BaseDbModel):
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
total_run_time = duration
|
||||
|
||||
if execution.stats:
|
||||
stats = json.loads(execution.stats)
|
||||
duration = stats.get("walltime", duration)
|
||||
total_run_time = stats.get("nodes_walltime", total_run_time)
|
||||
try:
|
||||
stats = json.loads(execution.stats or "{}", target_type=dict[str, Any])
|
||||
except ValueError:
|
||||
stats = {}
|
||||
|
||||
duration = stats.get("walltime", duration)
|
||||
total_run_time = stats.get("nodes_walltime", total_run_time)
|
||||
|
||||
return GraphExecution(
|
||||
id=execution.id,
|
||||
@@ -311,7 +314,9 @@ class Graph(BaseDbModel):
|
||||
def _process_node(node: AgentNode, hide_credentials: bool) -> Node:
|
||||
node_dict = node.model_dump()
|
||||
if hide_credentials and "constantInput" in node_dict:
|
||||
constant_input = json.loads(node_dict["constantInput"])
|
||||
constant_input = json.loads(
|
||||
node_dict["constantInput"], target_type=dict[str, Any]
|
||||
)
|
||||
constant_input = Graph._hide_credentials_in_input(constant_input)
|
||||
node_dict["constantInput"] = json.dumps(constant_input)
|
||||
return Node.from_db(AgentNode(**node_dict))
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import json
|
||||
from typing import Any, Type, TypeVar, overload
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from .type import type_match
|
||||
|
||||
|
||||
def to_dict(data) -> dict:
|
||||
return jsonable_encoder(data)
|
||||
@@ -11,4 +14,19 @@ def dumps(data) -> str:
|
||||
return json.dumps(jsonable_encoder(data))
|
||||
|
||||
|
||||
loads = json.loads
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
|
||||
|
||||
|
||||
@overload
|
||||
def loads(data: str, *args, **kwargs) -> Any: ...
|
||||
|
||||
|
||||
def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
|
||||
parsed = json.loads(data, *args, **kwargs)
|
||||
if target_type:
|
||||
return type_match(parsed, target_type)
|
||||
return parsed
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
from typing import Any, Type, TypeVar, get_args, get_origin
|
||||
from typing import Any, Type, TypeVar, cast, get_args, get_origin
|
||||
|
||||
|
||||
class ConversionError(Exception):
|
||||
class ConversionError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ def __convert_bool(value: Any) -> bool:
|
||||
return bool(value)
|
||||
|
||||
|
||||
def convert(value: Any, target_type: Type):
|
||||
def _try_convert(value: Any, target_type: Type, raise_on_mismatch: bool) -> Any:
|
||||
origin = get_origin(target_type)
|
||||
args = get_args(target_type)
|
||||
if origin is None:
|
||||
@@ -133,6 +133,8 @@ def convert(value: Any, target_type: Type):
|
||||
return {convert(v, args[0]) for v in value}
|
||||
else:
|
||||
return value
|
||||
elif raise_on_mismatch:
|
||||
raise TypeError(f"Value {value} is not of expected type {target_type}")
|
||||
else:
|
||||
# Need to convert value to the origin type
|
||||
if origin is list:
|
||||
@@ -175,3 +177,17 @@ def convert(value: Any, target_type: Type):
|
||||
return __convert_bool(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def type_match(value: Any, target_type: Type[T]) -> T:
|
||||
return cast(T, _try_convert(value, target_type, raise_on_mismatch=True))
|
||||
|
||||
|
||||
def convert(value: Any, target_type: Type[T]) -> T:
|
||||
try:
|
||||
return cast(T, _try_convert(value, target_type, raise_on_mismatch=False))
|
||||
except Exception as e:
|
||||
raise ConversionError(f"Failed to convert {value} to {target_type}") from e
|
||||
|
||||
Reference in New Issue
Block a user