mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-09 19:38:15 -05:00
Fix: Handle nested objects in array items for JSON schema conversion (#6993)
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
@@ -128,6 +128,17 @@ class _JSONSchemaToPydantic:
|
||||
|
||||
return self._model_cache[ref_name]
|
||||
|
||||
def _get_item_model_name(self, array_field_name: str, parent_model_name: str) -> str:
|
||||
"""Generate hash-based model names for array items to keep names short and unique."""
|
||||
import hashlib
|
||||
|
||||
# Create a short hash of the full path to ensure uniqueness
|
||||
full_path = f"{parent_model_name}_{array_field_name}"
|
||||
hash_suffix = hashlib.md5(full_path.encode()).hexdigest()[:6]
|
||||
|
||||
# Use field name as-is with hash suffix
|
||||
return f"{array_field_name}_{hash_suffix}"
|
||||
|
||||
def _process_definitions(self, root_schema: Dict[str, Any]) -> None:
|
||||
if "$defs" in root_schema:
|
||||
for model_name in root_schema["$defs"]:
|
||||
@@ -253,6 +264,11 @@ class _JSONSchemaToPydantic:
|
||||
item_schema = value.get("items", {"type": "string"})
|
||||
if "$ref" in item_schema:
|
||||
item_type = self.get_ref(item_schema["$ref"].split("/")[-1])
|
||||
elif item_schema.get("type") == "object" and "properties" in item_schema:
|
||||
# Handle array items that are objects with properties - create a nested model
|
||||
# Use hash-based naming to keep names short and unique
|
||||
item_model_name = self._get_item_model_name(key, model_name)
|
||||
item_type = self._json_schema_to_model(item_schema, item_model_name, root_schema)
|
||||
else:
|
||||
item_type_name = item_schema.get("type")
|
||||
if item_type_name is None:
|
||||
|
||||
@@ -834,3 +834,211 @@ def test_unknown_format_raises() -> None:
|
||||
converter = _JSONSchemaToPydantic()
|
||||
with pytest.raises(FormatNotSupportedError):
|
||||
converter.json_schema_to_pydantic(schema, "UnknownFormatModel")
|
||||
|
||||
|
||||
def test_array_items_with_object_schema_properties() -> None:
|
||||
"""Test that array items with object schemas create proper Pydantic models."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"users": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "email": {"type": "string"}, "age": {"type": "integer"}},
|
||||
"required": ["name", "email"],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
converter = _JSONSchemaToPydantic()
|
||||
Model = converter.json_schema_to_pydantic(schema, "UserListModel")
|
||||
|
||||
# Verify the users field has correct type annotation
|
||||
users_field = Model.model_fields["users"]
|
||||
from typing import Union, get_args, get_origin
|
||||
|
||||
# Extract inner type from Optional[List[...]]
|
||||
actual_list_type = users_field.annotation
|
||||
if get_origin(users_field.annotation) is Union:
|
||||
union_args = get_args(users_field.annotation)
|
||||
for arg in union_args:
|
||||
if get_origin(arg) is list:
|
||||
actual_list_type = arg
|
||||
break
|
||||
|
||||
assert get_origin(actual_list_type) is list
|
||||
inner_type = get_args(actual_list_type)[0]
|
||||
|
||||
# Verify array items are BaseModel subclasses, not dict
|
||||
assert inner_type is not dict
|
||||
assert hasattr(inner_type, "model_fields")
|
||||
|
||||
# Verify expected fields are present
|
||||
expected_fields = {"name", "email", "age"}
|
||||
actual_fields = set(inner_type.model_fields.keys())
|
||||
assert expected_fields.issubset(actual_fields)
|
||||
|
||||
# Test instantiation and field access
|
||||
test_data = {
|
||||
"users": [
|
||||
{"name": "Alice", "email": "alice@example.com", "age": 30},
|
||||
{"name": "Bob", "email": "bob@example.com"},
|
||||
]
|
||||
}
|
||||
|
||||
instance = Model(**test_data)
|
||||
assert len(instance.users) == 2 # type: ignore[attr-defined]
|
||||
|
||||
first_user = instance.users[0] # type: ignore[attr-defined]
|
||||
assert hasattr(first_user, "model_fields") # type: ignore[reportUnknownArgumentType]
|
||||
assert not isinstance(first_user, dict)
|
||||
|
||||
# Test attribute access (BaseModel behavior)
|
||||
assert first_user.name == "Alice" # type: ignore[attr-defined]
|
||||
assert first_user.email == "alice@example.com" # type: ignore[attr-defined]
|
||||
assert first_user.age == 30 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_nested_arrays_with_object_schemas() -> None:
|
||||
"""Test deeply nested arrays with object schemas create proper Pydantic models."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"companies": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"departments": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"employees": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"role": {"type": "string"},
|
||||
"skills": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["name", "role"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
converter = _JSONSchemaToPydantic()
|
||||
Model = converter.json_schema_to_pydantic(schema, "CompanyListModel")
|
||||
|
||||
# Verify companies field type annotation
|
||||
companies_field = Model.model_fields["companies"]
|
||||
from typing import Union, get_args, get_origin
|
||||
|
||||
# Extract companies inner type
|
||||
actual_list_type = companies_field.annotation
|
||||
if get_origin(companies_field.annotation) is Union:
|
||||
union_args = get_args(companies_field.annotation)
|
||||
for arg in union_args:
|
||||
if get_origin(arg) is list:
|
||||
actual_list_type = arg
|
||||
break
|
||||
|
||||
assert get_origin(actual_list_type) is list
|
||||
company_type = get_args(actual_list_type)[0]
|
||||
|
||||
# Verify companies are BaseModel subclasses
|
||||
assert company_type is not dict
|
||||
assert hasattr(company_type, "model_fields")
|
||||
assert "name" in company_type.model_fields
|
||||
assert "departments" in company_type.model_fields
|
||||
|
||||
# Verify departments field type annotation
|
||||
departments_field = company_type.model_fields["departments"]
|
||||
dept_list_type = departments_field.annotation
|
||||
if get_origin(dept_list_type) is Union:
|
||||
union_args = get_args(dept_list_type)
|
||||
for arg in union_args:
|
||||
if get_origin(arg) is list:
|
||||
dept_list_type = arg
|
||||
break
|
||||
|
||||
assert get_origin(dept_list_type) is list
|
||||
department_type = get_args(dept_list_type)[0]
|
||||
|
||||
# Verify departments are BaseModel subclasses
|
||||
assert department_type is not dict
|
||||
assert hasattr(department_type, "model_fields")
|
||||
assert "name" in department_type.model_fields
|
||||
assert "employees" in department_type.model_fields
|
||||
|
||||
# Verify employees field type annotation
|
||||
employees_field = department_type.model_fields["employees"]
|
||||
emp_list_type = employees_field.annotation
|
||||
if get_origin(emp_list_type) is Union:
|
||||
union_args = get_args(emp_list_type)
|
||||
for arg in union_args:
|
||||
if get_origin(arg) is list:
|
||||
emp_list_type = arg
|
||||
break
|
||||
|
||||
assert get_origin(emp_list_type) is list
|
||||
employee_type = get_args(emp_list_type)[0]
|
||||
|
||||
# Verify employees are BaseModel subclasses
|
||||
assert employee_type is not dict
|
||||
assert hasattr(employee_type, "model_fields")
|
||||
expected_emp_fields = {"name", "role", "skills"}
|
||||
actual_emp_fields = set(employee_type.model_fields.keys())
|
||||
assert expected_emp_fields.issubset(actual_emp_fields)
|
||||
|
||||
# Test instantiation with nested data
|
||||
test_data = {
|
||||
"companies": [
|
||||
{
|
||||
"name": "TechCorp",
|
||||
"departments": [
|
||||
{
|
||||
"name": "Engineering",
|
||||
"employees": [
|
||||
{"name": "Alice", "role": "Senior Developer", "skills": ["Python", "JavaScript", "Docker"]},
|
||||
{"name": "Bob", "role": "DevOps Engineer", "skills": ["Kubernetes", "AWS"]},
|
||||
],
|
||||
},
|
||||
{"name": "Marketing", "employees": [{"name": "Carol", "role": "Marketing Manager"}]},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
instance = Model(**test_data)
|
||||
assert len(instance.companies) == 1 # type: ignore[attr-defined]
|
||||
|
||||
company = instance.companies[0] # type: ignore[attr-defined]
|
||||
assert hasattr(company, "model_fields") # type: ignore[reportUnknownArgumentType]
|
||||
assert company.name == "TechCorp" # type: ignore[attr-defined]
|
||||
assert len(company.departments) == 2 # type: ignore[attr-defined]
|
||||
|
||||
engineering_dept = company.departments[0] # type: ignore[attr-defined]
|
||||
assert hasattr(engineering_dept, "model_fields") # type: ignore[reportUnknownArgumentType]
|
||||
assert engineering_dept.name == "Engineering" # type: ignore[attr-defined]
|
||||
assert len(engineering_dept.employees) == 2 # type: ignore[attr-defined]
|
||||
|
||||
alice = engineering_dept.employees[0] # type: ignore[attr-defined]
|
||||
assert hasattr(alice, "model_fields") # type: ignore[reportUnknownArgumentType]
|
||||
assert alice.name == "Alice" # type: ignore[attr-defined]
|
||||
assert alice.role == "Senior Developer" # type: ignore[attr-defined]
|
||||
assert alice.skills == ["Python", "JavaScript", "Docker"] # type: ignore[attr-defined]
|
||||
|
||||
Reference in New Issue
Block a user