Disallow unions in dataclass messages, move check to creation instead of usage (#499)

* Disallow unions in dataclass messages, move check to creation instead of usage

* make image serializable by pydantic

* fixup team one

* update lockfile

* fix

* fix dataclass checking bug

* fix mypy
This commit is contained in:
Jack Gerrits
2024-09-16 12:37:26 -04:00
committed by GitHub
parent 747054aec8
commit 561897b4ee
13 changed files with 198 additions and 35 deletions

View File

@@ -18,7 +18,7 @@ dependencies = [
"pillow",
"aiohttp",
"typing-extensions",
"pydantic>=1.10,<3",
"pydantic<3.0.0,>=2.0.0",
"grpcio~=1.62.0",
"protobuf~=4.25.1",
"tiktoken",

View File

@@ -1,9 +1,11 @@
import json
from dataclasses import asdict, dataclass
from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, runtime_checkable
from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, get_args, get_origin, runtime_checkable
from pydantic import BaseModel
from autogen_core.base._type_helpers import is_union
T = TypeVar("T")
@@ -27,7 +29,7 @@ class IsDataclass(Protocol):
def is_dataclass(cls: type[Any]) -> bool:
    """Return True when ``cls`` exposes dataclass field metadata.

    The check is a plain attribute probe for ``__dataclass_fields__``, the
    attribute that the ``@dataclass`` decorator attaches to a class.

    Args:
        cls: The object (normally a class or a field annotation) to inspect.

    Returns:
        ``True`` if ``cls`` has a ``__dataclass_fields__`` attribute,
        ``False`` otherwise.
    """
    # Single return path: a leftover Protocol-based isinstance check used to
    # precede this line and made the attribute probe unreachable.
    return hasattr(cls, "__dataclass_fields__")
def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
@@ -35,9 +37,54 @@ def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
return any(is_dataclass(f.type) for f in cls.__dataclass_fields__.values())
def contains_a_union(cls: type[IsDataclass]) -> bool:
    """Report whether any field of the dataclass ``cls`` is declared with a union type.

    Every entry in ``cls.__dataclass_fields__`` is visited and its declared
    ``type`` is passed to ``is_union``; the scan stops at the first match.
    """
    for field_info in cls.__dataclass_fields__.values():
        if is_union(field_info.type):
            return True
    return False
def has_nested_base_model(cls: type[IsDataclass]) -> bool:
    """Return True if any field annotation of ``cls`` is or contains a ``BaseModel``.

    Each dataclass field's declared type is handed to
    ``has_nested_base_model_in_type``, which looks at the type itself and,
    recursively, at the arguments of parameterized types (for example
    ``List[Model]``, ``list[list[Model]]``, ``Union[Model, str]`` or
    ``Model | str``).

    Args:
        cls: The dataclass type whose fields are inspected.

    Returns:
        ``True`` as soon as one field mentions a ``BaseModel`` subclass,
        ``False`` if none do.
    """
    # One recursive helper covers the direct, generic and union cases; the
    # per-case loops it replaces repeated the same checks, and a bare
    # issubclass() on the raw annotation raises for non-class annotations.
    return any(has_nested_base_model_in_type(f.type) for f in fields(cls))
def has_nested_base_model_in_type(tp: Any) -> bool:
    """Helper function to check if a type or its arguments is a BaseModel subclass.

    Args:
        tp: A type annotation. It may be a plain class, a parameterized
            generic or union, or any other annotation object.

    Returns:
        ``True`` if ``tp`` is a ``BaseModel`` subclass, or if any of its type
        arguments (searched recursively) is one; ``False`` otherwise.
    """
    origin = get_origin(tp)
    if origin is None:
        # Not a parameterized annotation: only a real class can be a
        # BaseModel subclass. Anything else (strings, TypeVars, ...) is not.
        return isinstance(tp, type) and issubclass(tp, BaseModel)

    # Parameterized annotation: inspect only its arguments. The alias itself
    # must not reach issubclass(): on interpreters where builtin aliases such
    # as list[int] report as instances of `type` (Python 3.9/3.10),
    # issubclass() raises TypeError for them.
    return any(has_nested_base_model_in_type(arg) for arg in get_args(tp))
DataclassT = TypeVar("DataclassT", bound=IsDataclass)
@@ -45,8 +92,16 @@ DataclassT = TypeVar("DataclassT", bound=IsDataclass)
JSON_DATA_CONTENT_TYPE = "application/json"
class DataclassJsonMessageSerializer(MessageSerializer[IsDataclass]):
def __init__(self, cls: type[IsDataclass]) -> None:
class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]):
def __init__(self, cls: type[DataclassT]) -> None:
if contains_a_union(cls):
raise ValueError("Dataclass has a union type, which is not supported. To use a union, use a Pydantic model")
if has_nested_dataclass(cls) or has_nested_base_model(cls):
raise ValueError(
"Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model"
)
self.cls = cls
@property
@@ -57,14 +112,11 @@ class DataclassJsonMessageSerializer(MessageSerializer[IsDataclass]):
def type_name(self) -> str:
return _type_name(self.cls)
def deserialize(self, payload: bytes) -> IsDataclass:
def deserialize(self, payload: bytes) -> DataclassT:
message_str = payload.decode("utf-8")
return self.cls(**json.loads(message_str))
def serialize(self, message: IsDataclass) -> bytes:
if has_nested_dataclass(type(message)) or has_nested_base_model(type(message)):
raise ValueError("Dataclass has nested dataclasses or base models, which are not supported")
def serialize(self, message: DataclassT) -> bytes:
return json.dumps(asdict(message)).encode("utf-8")

View File

@@ -9,8 +9,8 @@ from ..base._agent_instantiation import AgentInstantiationContext
from ..base._agent_metadata import AgentMetadata
from ..base._agent_runtime import AgentRuntime
from ..base._serialization import JSON_DATA_CONTENT_TYPE, MESSAGE_TYPE_REGISTRY, try_get_known_serializers_for_type
from ..base._type_helpers import get_types
from ..base.exceptions import CantHandleException
from ._type_helpers import get_types
T = TypeVar("T")

View File

@@ -4,10 +4,13 @@ import base64
import re
from io import BytesIO
from pathlib import Path
from typing import Any
import aiohttp
from openai.types.chat import ChatCompletionContentPartImageParam
from PIL import Image as PILImage
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationInfo
from pydantic_core import core_schema
from typing_extensions import Literal
@@ -39,6 +42,12 @@ class Image:
def from_base64(cls, base64_str: str) -> Image:
return cls(PILImage.open(BytesIO(base64.b64decode(base64_str))))
def to_base64(self) -> str:
    """Encode the wrapped image as PNG and return the bytes as base64 text."""
    with BytesIO() as png_buffer:
        # Render into memory so nothing touches the filesystem.
        self.image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
@classmethod
def from_file(cls, file_path: Path) -> Image:
return cls(PILImage.open(file_path))
@@ -49,14 +58,35 @@ class Image:
@property
def data_uri(self) -> str:
buffered = BytesIO()
self.image.save(buffered, format="PNG")
content = buffered.getvalue()
return _convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))
return _convert_base64_to_data_uri(self.to_base64())
def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> ChatCompletionContentPartImageParam:
return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}}
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
    """Build the pydantic core schema used to validate and serialize this class.

    Validation accepts either an existing instance of ``cls`` or a dict whose
    ``"data"`` entry is a base64 string (decoded via ``cls.from_base64``).
    Serialization produces ``{"data": <base64 string>}`` via ``to_base64``.
    """

    def _to_image(value: Any, validation_info: ValidationInfo) -> Image:
        # Dict payloads carry the base64-encoded image under "data".
        if isinstance(value, dict):
            encoded = value.get("data")
            if encoded is None:
                raise ValueError("Expected 'data' key in the dictionary")
            return cls.from_base64(encoded)
        # Instances of this class pass through unchanged.
        if isinstance(value, cls):
            return value
        raise TypeError(f"Expected dict or {cls.__name__} instance, got {type(value)}")

    def _from_image(value: Image) -> dict[str, Any]:
        return {"data": value.to_base64()}

    # The inner schema accepts any input; all real checking happens in _to_image.
    accept_anything = core_schema.any_schema()
    serializer_schema = core_schema.plain_serializer_function_ser_schema(_from_image)
    return core_schema.with_info_after_validator_function(
        _to_image,
        accept_anything,
        serialization=serializer_schema,
    )
def _convert_base64_to_data_uri(base64_image: str) -> str:
def _get_mime_type_from_data_uri(base64_image: str) -> str:

View File

@@ -21,8 +21,8 @@ from typing import (
from autogen_core.base import try_get_known_serializers_for_type
from ..base import MESSAGE_TYPE_REGISTRY, BaseAgent, MessageContext
from ..base._type_helpers import AnyType, get_types
from ..base.exceptions import CantHandleException
from ._type_helpers import AnyType, get_types
logger = logging.getLogger("autogen_core")

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Union
import pytest
from autogen_core.base import (
@@ -7,6 +8,9 @@ from autogen_core.base import (
Serialization,
try_get_known_serializers_for_type,
)
from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer
from autogen_core.components import Image
from PIL import Image as PILImage
from pydantic import BaseModel
@@ -75,22 +79,33 @@ def test_dataclass() -> None:
def test_nesting_dataclass_dataclass() -> None:
serde = Serialization()
serde.add_serializer(try_get_known_serializers_for_type(NestingDataclassMessage))
message = NestingDataclassMessage(message="hello", nested=DataclassMessage(message="world"))
name = serde.type_name(message)
with pytest.raises(ValueError):
_json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
serde.add_serializer(try_get_known_serializers_for_type(NestingDataclassMessage))
# Fixture: union field spelled with typing.Union.
@dataclass
class DataclassNestedUnionSyntaxOldMessage:
    message: Union[str, int]


# Fixture: union field spelled with the ``X | Y`` operator.
@dataclass
class DataclassNestedUnionSyntaxNewMessage:
    message: str | int


@pytest.mark.parametrize(
    "cls",
    [
        DataclassNestedUnionSyntaxOldMessage,
        DataclassNestedUnionSyntaxNewMessage,
    ],
)
def test_nesting_union_old_syntax_dataclass(
    cls: type[DataclassNestedUnionSyntaxOldMessage | DataclassNestedUnionSyntaxNewMessage],
) -> None:
    """Constructing the dataclass serializer for a union-typed field must raise ValueError."""
    with pytest.raises(ValueError):
        DataclassJsonMessageSerializer(cls)
def test_nesting_dataclass_pydantic() -> None:
serde = Serialization()
serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticDataclassMessage))
message = NestingPydanticDataclassMessage(message="hello", nested=PydanticMessage(message="world"))
name = serde.type_name(message)
with pytest.raises(ValueError):
_json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticDataclassMessage))
def test_invalid_type() -> None:
@@ -126,3 +141,22 @@ def test_custom_type() -> None:
assert json == b'"hello"'
deserialized = serde.deserialize(json, type_name="custom_str", data_content_type="str")
assert deserialized == message
def test_image_type() -> None:
    """Round-trip a pydantic message holding an ``Image`` through the JSON serializer."""

    class PydanticImageMessage(BaseModel):
        image: Image

    original = Image(PILImage.new("RGB", (100, 100)))
    serializer = PydanticJsonMessageSerializer(PydanticImageMessage)

    payload = serializer.serialize(PydanticImageMessage(image=original))
    restored = serializer.deserialize(payload)

    # The decoded picture must match the source in size, mode and content.
    restored_pil = restored.image.image
    assert restored_pil.size == (100, 100)
    assert restored_pil.mode == "RGB"
    assert restored_pil == original.image

View File

@@ -1,9 +1,12 @@
from dataclasses import dataclass
from types import NoneType
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union
from autogen_core.base import MessageContext
from autogen_core.base._serialization import has_nested_base_model
from autogen_core.base._type_helpers import AnyType, get_types
from autogen_core.components._routed_agent import message_handler
from autogen_core.components._type_helpers import AnyType, get_types
from pydantic import BaseModel
def test_get_types() -> None:
@@ -38,3 +41,44 @@ class HandlerClass:
@message_handler()
async def handler(self, message: int, ctx: MessageContext) -> Any:
return None
def test_nested_data_model() -> None:
    """``has_nested_base_model`` must detect a BaseModel in direct, list and union fields."""

    class MyBaseModel(BaseModel):
        message: str

    # Direct reference.
    @dataclass
    class NestedBaseModel:
        nested: MyBaseModel

    # typing.List and builtin list, at increasing nesting depth.
    @dataclass
    class NestedBaseModelList:
        nested: List[MyBaseModel]

    @dataclass
    class NestedBaseModelList2:
        nested: list[MyBaseModel]

    @dataclass
    class NestedBaseModelList3:
        nested: list[list[MyBaseModel]]

    @dataclass
    class NestedBaseModelList4:
        nested: list[list[list[list[list[list[MyBaseModel]]]]]]

    # Both union spellings.
    @dataclass
    class NestedBaseModelUnion:
        nested: Union[MyBaseModel, str]

    @dataclass
    class NestedBaseModelUnion2:
        nested: MyBaseModel | str

    candidates = (
        NestedBaseModel,
        NestedBaseModelList,
        NestedBaseModelList2,
        NestedBaseModelList3,
        NestedBaseModelList4,
        NestedBaseModelUnion,
        NestedBaseModelUnion2,
    )
    for message_type in candidates:
        assert has_nested_base_model(message_type)

View File

@@ -30,7 +30,7 @@ async def main() -> None:
task_message = UserMessage(content="Test Message", source="User")
runtime.start()
await runtime.publish_message(BroadcastMessage(task_message), topic_id=DefaultTopicId())
await runtime.publish_message(BroadcastMessage(content=task_message), topic_id=DefaultTopicId())
await runtime.stop_when_idle()

View File

@@ -33,6 +33,7 @@ dependencies = [
"SpeechRecognition",
"pathvalidate",
"playwright",
"pydantic<3.0.0,>=2.0.0",
]
[project.optional-dependencies]

View File

@@ -24,4 +24,4 @@ class ReflexAgent(RoutedAgent):
)
topic_id = TopicId("default", self.id.key)
await self.publish_message(BroadcastMessage(response_message), topic_id=topic_id)
await self.publish_message(BroadcastMessage(content=response_message), topic_id=topic_id)

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Union
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult, LLMMessage
from pydantic import BaseModel
# Convenience type
UserContent = Union[str, List[Union[str, Image]]]
@@ -11,8 +12,7 @@ FunctionExecutionContent = List[FunctionExecutionResult]
SystemContent = str
@dataclass
class BroadcastMessage:
class BroadcastMessage(BaseModel):
content: LLMMessage
request_halt: bool = False