mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-23 19:17:55 -05:00
Disallow unions in dataclass messages, move check to creation instead of usage (#499)
* Disallow unions in dataclass messages, move check to creation instead of usage * make image serializable by pydantic * fixup team one * update lockfile * fix * fix dataclass checking bug * fix mypy
This commit is contained in:
@@ -18,7 +18,7 @@ dependencies = [
|
||||
"pillow",
|
||||
"aiohttp",
|
||||
"typing-extensions",
|
||||
"pydantic>=1.10,<3",
|
||||
"pydantic<3.0.0,>=2.0.0",
|
||||
"grpcio~=1.62.0",
|
||||
"protobuf~=4.25.1",
|
||||
"tiktoken",
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, runtime_checkable
|
||||
from dataclasses import asdict, dataclass, fields
|
||||
from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, get_args, get_origin, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogen_core.base._type_helpers import is_union
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -27,7 +29,7 @@ class IsDataclass(Protocol):
|
||||
|
||||
|
||||
def is_dataclass(cls: type[Any]) -> bool:
|
||||
return isinstance(cls, IsDataclass)
|
||||
return hasattr(cls, "__dataclass_fields__")
|
||||
|
||||
|
||||
def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
|
||||
@@ -35,9 +37,54 @@ def has_nested_dataclass(cls: type[IsDataclass]) -> bool:
|
||||
return any(is_dataclass(f.type) for f in cls.__dataclass_fields__.values())
|
||||
|
||||
|
||||
def contains_a_union(cls: type[IsDataclass]) -> bool:
|
||||
return any(is_union(f.type) for f in cls.__dataclass_fields__.values())
|
||||
|
||||
|
||||
def has_nested_base_model(cls: type[IsDataclass]) -> bool:
|
||||
# iterate fields and check if any of them are basebodels
|
||||
return any(issubclass(f.type, BaseModel) for f in cls.__dataclass_fields__.values())
|
||||
for f in fields(cls):
|
||||
field_type = f.type
|
||||
# Resolve forward references and other annotations
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
# If the field type is directly a subclass of BaseModel
|
||||
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
|
||||
return True
|
||||
|
||||
# If the field type is a generic type like List[BaseModel], Tuple[BaseModel, ...], etc.
|
||||
if origin is not None and args:
|
||||
for arg in args:
|
||||
# Recursively check the argument types
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
elif get_origin(arg) is not None:
|
||||
# Handle nested generics like List[List[BaseModel]]
|
||||
if has_nested_base_model_in_type(arg):
|
||||
return True
|
||||
# Handle Union types
|
||||
elif args:
|
||||
for arg in args:
|
||||
if isinstance(arg, type) and issubclass(arg, BaseModel):
|
||||
return True
|
||||
elif get_origin(arg) is not None:
|
||||
if has_nested_base_model_in_type(arg):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def has_nested_base_model_in_type(tp: Any) -> bool:
|
||||
"""Helper function to check if a type or its arguments is a BaseModel subclass."""
|
||||
origin = get_origin(tp)
|
||||
args = get_args(tp)
|
||||
|
||||
if isinstance(tp, type) and issubclass(tp, BaseModel):
|
||||
return True
|
||||
if origin is not None and args:
|
||||
for arg in args:
|
||||
if has_nested_base_model_in_type(arg):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
DataclassT = TypeVar("DataclassT", bound=IsDataclass)
|
||||
@@ -45,8 +92,16 @@ DataclassT = TypeVar("DataclassT", bound=IsDataclass)
|
||||
JSON_DATA_CONTENT_TYPE = "application/json"
|
||||
|
||||
|
||||
class DataclassJsonMessageSerializer(MessageSerializer[IsDataclass]):
|
||||
def __init__(self, cls: type[IsDataclass]) -> None:
|
||||
class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]):
|
||||
def __init__(self, cls: type[DataclassT]) -> None:
|
||||
if contains_a_union(cls):
|
||||
raise ValueError("Dataclass has a union type, which is not supported. To use a union, use a Pydantic model")
|
||||
|
||||
if has_nested_dataclass(cls) or has_nested_base_model(cls):
|
||||
raise ValueError(
|
||||
"Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model"
|
||||
)
|
||||
|
||||
self.cls = cls
|
||||
|
||||
@property
|
||||
@@ -57,14 +112,11 @@ class DataclassJsonMessageSerializer(MessageSerializer[IsDataclass]):
|
||||
def type_name(self) -> str:
|
||||
return _type_name(self.cls)
|
||||
|
||||
def deserialize(self, payload: bytes) -> IsDataclass:
|
||||
def deserialize(self, payload: bytes) -> DataclassT:
|
||||
message_str = payload.decode("utf-8")
|
||||
return self.cls(**json.loads(message_str))
|
||||
|
||||
def serialize(self, message: IsDataclass) -> bytes:
|
||||
if has_nested_dataclass(type(message)) or has_nested_base_model(type(message)):
|
||||
raise ValueError("Dataclass has nested dataclasses or base models, which are not supported")
|
||||
|
||||
def serialize(self, message: DataclassT) -> bytes:
|
||||
return json.dumps(asdict(message)).encode("utf-8")
|
||||
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ from ..base._agent_instantiation import AgentInstantiationContext
|
||||
from ..base._agent_metadata import AgentMetadata
|
||||
from ..base._agent_runtime import AgentRuntime
|
||||
from ..base._serialization import JSON_DATA_CONTENT_TYPE, MESSAGE_TYPE_REGISTRY, try_get_known_serializers_for_type
|
||||
from ..base._type_helpers import get_types
|
||||
from ..base.exceptions import CantHandleException
|
||||
from ._type_helpers import get_types
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -4,10 +4,13 @@ import base64
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from openai.types.chat import ChatCompletionContentPartImageParam
|
||||
from PIL import Image as PILImage
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationInfo
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@@ -39,6 +42,12 @@ class Image:
|
||||
def from_base64(cls, base64_str: str) -> Image:
|
||||
return cls(PILImage.open(BytesIO(base64.b64decode(base64_str))))
|
||||
|
||||
def to_base64(self) -> str:
|
||||
buffered = BytesIO()
|
||||
self.image.save(buffered, format="PNG")
|
||||
content = buffered.getvalue()
|
||||
return base64.b64encode(content).decode("utf-8")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file_path: Path) -> Image:
|
||||
return cls(PILImage.open(file_path))
|
||||
@@ -49,14 +58,35 @@ class Image:
|
||||
|
||||
@property
|
||||
def data_uri(self) -> str:
|
||||
buffered = BytesIO()
|
||||
self.image.save(buffered, format="PNG")
|
||||
content = buffered.getvalue()
|
||||
return _convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))
|
||||
return _convert_base64_to_data_uri(self.to_base64())
|
||||
|
||||
def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> ChatCompletionContentPartImageParam:
|
||||
return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
# Custom validation
|
||||
def validate(value: Any, validation_info: ValidationInfo) -> Image:
|
||||
if isinstance(value, dict):
|
||||
base_64 = value.get("data")
|
||||
if base_64 is None:
|
||||
raise ValueError("Expected 'data' key in the dictionary")
|
||||
return cls.from_base64(base_64)
|
||||
elif isinstance(value, cls):
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Expected dict or {cls.__name__} instance, got {type(value)}")
|
||||
|
||||
# Custom serialization
|
||||
def serialize(value: Image) -> dict[str, Any]:
|
||||
return {"data": value.to_base64()}
|
||||
|
||||
return core_schema.with_info_after_validator_function(
|
||||
validate,
|
||||
core_schema.any_schema(), # Accept any type; adjust if needed
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(serialize),
|
||||
)
|
||||
|
||||
|
||||
def _convert_base64_to_data_uri(base64_image: str) -> str:
|
||||
def _get_mime_type_from_data_uri(base64_image: str) -> str:
|
||||
|
||||
@@ -21,8 +21,8 @@ from typing import (
|
||||
from autogen_core.base import try_get_known_serializers_for_type
|
||||
|
||||
from ..base import MESSAGE_TYPE_REGISTRY, BaseAgent, MessageContext
|
||||
from ..base._type_helpers import AnyType, get_types
|
||||
from ..base.exceptions import CantHandleException
|
||||
from ._type_helpers import AnyType, get_types
|
||||
|
||||
logger = logging.getLogger("autogen_core")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from autogen_core.base import (
|
||||
@@ -7,6 +8,9 @@ from autogen_core.base import (
|
||||
Serialization,
|
||||
try_get_known_serializers_for_type,
|
||||
)
|
||||
from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer
|
||||
from autogen_core.components import Image
|
||||
from PIL import Image as PILImage
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -75,22 +79,33 @@ def test_dataclass() -> None:
|
||||
|
||||
def test_nesting_dataclass_dataclass() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_serializer(try_get_known_serializers_for_type(NestingDataclassMessage))
|
||||
|
||||
message = NestingDataclassMessage(message="hello", nested=DataclassMessage(message="world"))
|
||||
name = serde.type_name(message)
|
||||
with pytest.raises(ValueError):
|
||||
_json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
serde.add_serializer(try_get_known_serializers_for_type(NestingDataclassMessage))
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataclassNestedUnionSyntaxOldMessage:
|
||||
message: Union[str, int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataclassNestedUnionSyntaxNewMessage:
|
||||
message: str | int
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", [DataclassNestedUnionSyntaxOldMessage, DataclassNestedUnionSyntaxNewMessage])
|
||||
def test_nesting_union_old_syntax_dataclass(
|
||||
cls: type[DataclassNestedUnionSyntaxOldMessage | DataclassNestedUnionSyntaxNewMessage],
|
||||
) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_serializer = DataclassJsonMessageSerializer(cls)
|
||||
|
||||
|
||||
def test_nesting_dataclass_pydantic() -> None:
|
||||
serde = Serialization()
|
||||
serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticDataclassMessage))
|
||||
|
||||
message = NestingPydanticDataclassMessage(message="hello", nested=PydanticMessage(message="world"))
|
||||
name = serde.type_name(message)
|
||||
with pytest.raises(ValueError):
|
||||
_json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
|
||||
serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticDataclassMessage))
|
||||
|
||||
|
||||
def test_invalid_type() -> None:
|
||||
@@ -126,3 +141,22 @@ def test_custom_type() -> None:
|
||||
assert json == b'"hello"'
|
||||
deserialized = serde.deserialize(json, type_name="custom_str", data_content_type="str")
|
||||
assert deserialized == message
|
||||
|
||||
|
||||
def test_image_type() -> None:
|
||||
pil_image = PILImage.new("RGB", (100, 100))
|
||||
|
||||
image = Image(pil_image)
|
||||
|
||||
class PydanticImageMessage(BaseModel):
|
||||
image: Image
|
||||
|
||||
serializer = PydanticJsonMessageSerializer(PydanticImageMessage)
|
||||
|
||||
json = serializer.serialize(PydanticImageMessage(image=image))
|
||||
|
||||
deserialized = serializer.deserialize(json)
|
||||
|
||||
assert deserialized.image.image.size == (100, 100)
|
||||
assert deserialized.image.image.mode == "RGB"
|
||||
assert deserialized.image.image == image.image
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
from types import NoneType
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from autogen_core.base import MessageContext
|
||||
from autogen_core.base._serialization import has_nested_base_model
|
||||
from autogen_core.base._type_helpers import AnyType, get_types
|
||||
from autogen_core.components._routed_agent import message_handler
|
||||
from autogen_core.components._type_helpers import AnyType, get_types
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def test_get_types() -> None:
|
||||
@@ -38,3 +41,44 @@ class HandlerClass:
|
||||
@message_handler()
|
||||
async def handler(self, message: int, ctx: MessageContext) -> Any:
|
||||
return None
|
||||
|
||||
|
||||
def test_nested_data_model() -> None:
|
||||
class MyBaseModel(BaseModel):
|
||||
message: str
|
||||
|
||||
@dataclass
|
||||
class NestedBaseModel:
|
||||
nested: MyBaseModel
|
||||
|
||||
@dataclass
|
||||
class NestedBaseModelList:
|
||||
nested: List[MyBaseModel]
|
||||
|
||||
@dataclass
|
||||
class NestedBaseModelList2:
|
||||
nested: list[MyBaseModel]
|
||||
|
||||
@dataclass
|
||||
class NestedBaseModelList3:
|
||||
nested: list[list[MyBaseModel]]
|
||||
|
||||
@dataclass
|
||||
class NestedBaseModelList4:
|
||||
nested: list[list[list[list[list[list[MyBaseModel]]]]]]
|
||||
|
||||
@dataclass
|
||||
class NestedBaseModelUnion:
|
||||
nested: Union[MyBaseModel, str]
|
||||
|
||||
@dataclass
|
||||
class NestedBaseModelUnion2:
|
||||
nested: MyBaseModel | str
|
||||
|
||||
assert has_nested_base_model(NestedBaseModel)
|
||||
assert has_nested_base_model(NestedBaseModelList)
|
||||
assert has_nested_base_model(NestedBaseModelList2)
|
||||
assert has_nested_base_model(NestedBaseModelList3)
|
||||
assert has_nested_base_model(NestedBaseModelList4)
|
||||
assert has_nested_base_model(NestedBaseModelUnion)
|
||||
assert has_nested_base_model(NestedBaseModelUnion2)
|
||||
|
||||
@@ -30,7 +30,7 @@ async def main() -> None:
|
||||
|
||||
task_message = UserMessage(content="Test Message", source="User")
|
||||
runtime.start()
|
||||
await runtime.publish_message(BroadcastMessage(task_message), topic_id=DefaultTopicId())
|
||||
await runtime.publish_message(BroadcastMessage(content=task_message), topic_id=DefaultTopicId())
|
||||
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ dependencies = [
|
||||
"SpeechRecognition",
|
||||
"pathvalidate",
|
||||
"playwright",
|
||||
"pydantic<3.0.0,>=2.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -24,4 +24,4 @@ class ReflexAgent(RoutedAgent):
|
||||
)
|
||||
topic_id = TopicId("default", self.id.key)
|
||||
|
||||
await self.publish_message(BroadcastMessage(response_message), topic_id=topic_id)
|
||||
await self.publish_message(BroadcastMessage(content=response_message), topic_id=topic_id)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Union
|
||||
|
||||
from autogen_core.components import FunctionCall, Image
|
||||
from autogen_core.components.models import FunctionExecutionResult, LLMMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Convenience type
|
||||
UserContent = Union[str, List[Union[str, Image]]]
|
||||
@@ -11,8 +12,7 @@ FunctionExecutionContent = List[FunctionExecutionResult]
|
||||
SystemContent = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class BroadcastMessage:
|
||||
class BroadcastMessage(BaseModel):
|
||||
content: LLMMessage
|
||||
request_halt: bool = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user