feat(backend): Enable json parsing with typing & conversion (#8578)

This commit is contained in:
Zamil Majdy
2024-11-15 14:28:59 +04:00
committed by GitHub
parent 6a1cea4c4e
commit 8987fdd48c
4 changed files with 51 additions and 12 deletions

View File

@@ -97,7 +97,7 @@ class ExecutionResult(BaseModel):
def from_db(execution: AgentNodeExecution):
if execution.executionData:
# Execution that has been queued for execution will persist its data.
input_data = json.loads(execution.executionData)
input_data = json.loads(execution.executionData, target_type=dict[str, Any])
else:
# For incomplete execution, executionData will not be yet available.
input_data: BlockInput = defaultdict()

View File

@@ -56,8 +56,8 @@ class Node(BaseDbModel):
obj = Node(
id=node.id,
block_id=node.AgentBlock.id,
input_default=json.loads(node.constantInput),
metadata=json.loads(node.metadata),
input_default=json.loads(node.constantInput, target_type=dict[str, Any]),
metadata=json.loads(node.metadata, target_type=dict[str, Any]),
)
obj.input_links = [Link.from_db(link) for link in node.Input or []]
obj.output_links = [Link.from_db(link) for link in node.Output or []]
@@ -80,10 +80,13 @@ class GraphExecution(BaseDbModel):
duration = (end_time - start_time).total_seconds()
total_run_time = duration
if execution.stats:
stats = json.loads(execution.stats)
duration = stats.get("walltime", duration)
total_run_time = stats.get("nodes_walltime", total_run_time)
try:
stats = json.loads(execution.stats or "{}", target_type=dict[str, Any])
except ValueError:
stats = {}
duration = stats.get("walltime", duration)
total_run_time = stats.get("nodes_walltime", total_run_time)
return GraphExecution(
id=execution.id,
@@ -311,7 +314,9 @@ class Graph(BaseDbModel):
def _process_node(node: AgentNode, hide_credentials: bool) -> Node:
node_dict = node.model_dump()
if hide_credentials and "constantInput" in node_dict:
constant_input = json.loads(node_dict["constantInput"])
constant_input = json.loads(
node_dict["constantInput"], target_type=dict[str, Any]
)
constant_input = Graph._hide_credentials_in_input(constant_input)
node_dict["constantInput"] = json.dumps(constant_input)
return Node.from_db(AgentNode(**node_dict))

View File

@@ -1,7 +1,10 @@
import json
from typing import Any, Type, TypeVar, overload
from fastapi.encoders import jsonable_encoder
from .type import type_match
def to_dict(data) -> dict:
return jsonable_encoder(data)
@@ -11,4 +14,19 @@ def dumps(data) -> str:
return json.dumps(jsonable_encoder(data))
loads = json.loads
T = TypeVar("T")
@overload
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...
@overload
def loads(data: str, *args, **kwargs) -> Any: ...
def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
parsed = json.loads(data, *args, **kwargs)
if target_type:
return type_match(parsed, target_type)
return parsed

View File

@@ -1,8 +1,8 @@
import json
from typing import Any, Type, TypeVar, get_args, get_origin
from typing import Any, Type, TypeVar, cast, get_args, get_origin
class ConversionError(Exception):
class ConversionError(ValueError):
pass
@@ -102,7 +102,7 @@ def __convert_bool(value: Any) -> bool:
return bool(value)
def convert(value: Any, target_type: Type):
def _try_convert(value: Any, target_type: Type, raise_on_mismatch: bool) -> Any:
origin = get_origin(target_type)
args = get_args(target_type)
if origin is None:
@@ -133,6 +133,8 @@ def convert(value: Any, target_type: Type):
return {convert(v, args[0]) for v in value}
else:
return value
elif raise_on_mismatch:
raise TypeError(f"Value {value} is not of expected type {target_type}")
else:
# Need to convert value to the origin type
if origin is list:
@@ -175,3 +177,17 @@ def convert(value: Any, target_type: Type):
return __convert_bool(value)
else:
return value
T = TypeVar("T")
def type_match(value: Any, target_type: Type[T]) -> T:
return cast(T, _try_convert(value, target_type, raise_on_mismatch=True))
def convert(value: Any, target_type: Type[T]) -> T:
try:
return cast(T, _try_convert(value, target_type, raise_on_mismatch=False))
except Exception as e:
raise ConversionError(f"Failed to convert {value} to {target_type}") from e