diff --git a/python/packages/autogen-core/pyproject.toml b/python/packages/autogen-core/pyproject.toml index d4c2a3ec7..da9d71176 100644 --- a/python/packages/autogen-core/pyproject.toml +++ b/python/packages/autogen-core/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "pillow", "aiohttp", "typing-extensions", - "pydantic>=1.10,<3", + "pydantic<3.0.0,>=2.0.0", "grpcio~=1.62.0", "protobuf~=4.25.1", "tiktoken", diff --git a/python/packages/autogen-core/src/autogen_core/base/_serialization.py b/python/packages/autogen-core/src/autogen_core/base/_serialization.py index 44392d61e..41e7fdebe 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_serialization.py +++ b/python/packages/autogen-core/src/autogen_core/base/_serialization.py @@ -1,9 +1,11 @@ import json -from dataclasses import asdict, dataclass -from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, runtime_checkable +from dataclasses import asdict, dataclass, fields +from typing import Any, ClassVar, Dict, List, Protocol, TypeVar, cast, get_args, get_origin, runtime_checkable from pydantic import BaseModel +from autogen_core.base._type_helpers import is_union + T = TypeVar("T") @@ -27,7 +29,7 @@ class IsDataclass(Protocol): def is_dataclass(cls: type[Any]) -> bool: - return isinstance(cls, IsDataclass) + return hasattr(cls, "__dataclass_fields__") def has_nested_dataclass(cls: type[IsDataclass]) -> bool: @@ -35,9 +37,54 @@ def has_nested_dataclass(cls: type[IsDataclass]) -> bool: return any(is_dataclass(f.type) for f in cls.__dataclass_fields__.values()) +def contains_a_union(cls: type[IsDataclass]) -> bool: + return any(is_union(f.type) for f in cls.__dataclass_fields__.values()) + + def has_nested_base_model(cls: type[IsDataclass]) -> bool: - # iterate fields and check if any of them are basebodels - return any(issubclass(f.type, BaseModel) for f in cls.__dataclass_fields__.values()) + for f in fields(cls): + field_type = f.type + # Resolve forward references and other annotations + origin = get_origin(field_type) + args = get_args(field_type) + + # If the field type is directly a subclass of BaseModel + if isinstance(field_type, type) and issubclass(field_type, BaseModel): + return True + + # If the field type is a generic type like List[BaseModel], Tuple[BaseModel, ...], etc. + if origin is not None and args: + for arg in args: + # Recursively check the argument types + if isinstance(arg, type) and issubclass(arg, BaseModel): + return True + elif get_origin(arg) is not None: + # Handle nested generics like List[List[BaseModel]] + if has_nested_base_model_in_type(arg): + return True + # Handle Union types + elif args: + for arg in args: + if isinstance(arg, type) and issubclass(arg, BaseModel): + return True + elif get_origin(arg) is not None: + if has_nested_base_model_in_type(arg): + return True + return False + + +def has_nested_base_model_in_type(tp: Any) -> bool: + """Helper function to check if a type or its arguments is a BaseModel subclass.""" + origin = get_origin(tp) + args = get_args(tp) + + if isinstance(tp, type) and issubclass(tp, BaseModel): + return True + if origin is not None and args: + for arg in args: + if has_nested_base_model_in_type(arg): + return True + return False DataclassT = TypeVar("DataclassT", bound=IsDataclass) @@ -45,8 +92,16 @@ DataclassT = TypeVar("DataclassT", bound=IsDataclass) JSON_DATA_CONTENT_TYPE = "application/json" -class DataclassJsonMessageSerializer(MessageSerializer[IsDataclass]): - def __init__(self, cls: type[IsDataclass]) -> None: +class DataclassJsonMessageSerializer(MessageSerializer[DataclassT]): + def __init__(self, cls: type[DataclassT]) -> None: + if contains_a_union(cls): + raise ValueError("Dataclass has a union type, which is not supported. To use a union, use a Pydantic model") + + if has_nested_dataclass(cls) or has_nested_base_model(cls): + raise ValueError( + "Dataclass has nested dataclasses or base models, which are not supported. To use nested types, use a Pydantic model" + ) + self.cls = cls @property @@ -57,14 +112,11 @@ class DataclassJsonMessageSerializer(MessageSerializer[IsDataclass]): def type_name(self) -> str: return _type_name(self.cls) - def deserialize(self, payload: bytes) -> IsDataclass: + def deserialize(self, payload: bytes) -> DataclassT: message_str = payload.decode("utf-8") return self.cls(**json.loads(message_str)) - def serialize(self, message: IsDataclass) -> bytes: - if has_nested_dataclass(type(message)) or has_nested_base_model(type(message)): - raise ValueError("Dataclass has nested dataclasses or base models, which are not supported") - + def serialize(self, message: DataclassT) -> bytes: return json.dumps(asdict(message)).encode("utf-8") diff --git a/python/packages/autogen-core/src/autogen_core/components/_type_helpers.py b/python/packages/autogen-core/src/autogen_core/base/_type_helpers.py similarity index 100% rename from python/packages/autogen-core/src/autogen_core/components/_type_helpers.py rename to python/packages/autogen-core/src/autogen_core/base/_type_helpers.py diff --git a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py index 68f14632c..f020919d4 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py @@ -9,8 +9,8 @@ from ..base._agent_instantiation import AgentInstantiationContext from ..base._agent_metadata import AgentMetadata from ..base._agent_runtime import AgentRuntime from ..base._serialization import JSON_DATA_CONTENT_TYPE, MESSAGE_TYPE_REGISTRY, try_get_known_serializers_for_type +from ..base._type_helpers import get_types from ..base.exceptions import CantHandleException -from ._type_helpers import get_types T = TypeVar("T") diff --git a/python/packages/autogen-core/src/autogen_core/components/_image.py b/python/packages/autogen-core/src/autogen_core/components/_image.py index bf01e5442..5d5e6146e 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_image.py +++ b/python/packages/autogen-core/src/autogen_core/components/_image.py @@ -4,10 +4,13 @@ import base64 import re from io import BytesIO from pathlib import Path +from typing import Any import aiohttp from openai.types.chat import ChatCompletionContentPartImageParam from PIL import Image as PILImage +from pydantic import BaseModel, GetCoreSchemaHandler, ValidationInfo +from pydantic_core import core_schema from typing_extensions import Literal @@ -39,6 +42,12 @@ class Image: def from_base64(cls, base64_str: str) -> Image: return cls(PILImage.open(BytesIO(base64.b64decode(base64_str)))) + def to_base64(self) -> str: + buffered = BytesIO() + self.image.save(buffered, format="PNG") + content = buffered.getvalue() + return base64.b64encode(content).decode("utf-8") + @classmethod def from_file(cls, file_path: Path) -> Image: return cls(PILImage.open(file_path)) @@ -49,14 +58,35 @@ class Image: @property def data_uri(self) -> str: - buffered = BytesIO() - self.image.save(buffered, format="PNG") - content = buffered.getvalue() - return _convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8")) + return _convert_base64_to_data_uri(self.to_base64()) def to_openai_format(self, detail: Literal["auto", "low", "high"] = "auto") -> ChatCompletionContentPartImageParam: return {"type": "image_url", "image_url": {"url": self.data_uri, "detail": detail}} + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + # Custom validation + def validate(value: Any, validation_info: ValidationInfo) -> Image: + if isinstance(value, dict): + base_64 = value.get("data") + if base_64 is None: + raise ValueError("Expected 'data' key in the dictionary") + return cls.from_base64(base_64) + elif isinstance(value, cls): + return value + else: + raise TypeError(f"Expected dict or {cls.__name__} instance, got {type(value)}") + + # Custom serialization + def serialize(value: Image) -> dict[str, Any]: + return {"data": value.to_base64()} + + return core_schema.with_info_after_validator_function( + validate, + core_schema.any_schema(), # Accept any type; adjust if needed + serialization=core_schema.plain_serializer_function_ser_schema(serialize), + ) + def _convert_base64_to_data_uri(base64_image: str) -> str: def _get_mime_type_from_data_uri(base64_image: str) -> str: diff --git a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py index e5f47a7f3..5d3e8149c 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py @@ -21,8 +21,8 @@ from typing import ( from autogen_core.base import try_get_known_serializers_for_type from ..base import MESSAGE_TYPE_REGISTRY, BaseAgent, MessageContext +from ..base._type_helpers import AnyType, get_types from ..base.exceptions import CantHandleException -from ._type_helpers import AnyType, get_types logger = logging.getLogger("autogen_core") diff --git a/python/packages/autogen-core/tests/test_serialization.py b/python/packages/autogen-core/tests/test_serialization.py index 7ca3fa722..9029c905c 100644 --- a/python/packages/autogen-core/tests/test_serialization.py +++ b/python/packages/autogen-core/tests/test_serialization.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Union import pytest from autogen_core.base import ( @@ -7,6 +8,9 @@ from autogen_core.base import ( Serialization, try_get_known_serializers_for_type, ) +from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer +from autogen_core.components import Image +from PIL import Image as PILImage from pydantic import BaseModel @@ -75,22 +79,33 @@ def test_dataclass() -> None: def test_nesting_dataclass_dataclass() -> None: serde = Serialization() - serde.add_serializer(try_get_known_serializers_for_type(NestingDataclassMessage)) - - message = NestingDataclassMessage(message="hello", nested=DataclassMessage(message="world")) - name = serde.type_name(message) with pytest.raises(ValueError): - _json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) + serde.add_serializer(try_get_known_serializers_for_type(NestingDataclassMessage)) + + +@dataclass +class DataclassNestedUnionSyntaxOldMessage: + message: Union[str, int] + + +@dataclass +class DataclassNestedUnionSyntaxNewMessage: + message: str | int + + +@pytest.mark.parametrize("cls", [DataclassNestedUnionSyntaxOldMessage, DataclassNestedUnionSyntaxNewMessage]) +def test_nesting_union_old_syntax_dataclass( + cls: type[DataclassNestedUnionSyntaxOldMessage | DataclassNestedUnionSyntaxNewMessage], +) -> None: + with pytest.raises(ValueError): + _serializer = DataclassJsonMessageSerializer(cls) def test_nesting_dataclass_pydantic() -> None: serde = Serialization() - serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticDataclassMessage)) - message = NestingPydanticDataclassMessage(message="hello", nested=PydanticMessage(message="world")) - name = serde.type_name(message) with pytest.raises(ValueError): - _json = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) + serde.add_serializer(try_get_known_serializers_for_type(NestingPydanticDataclassMessage)) def test_invalid_type() -> None: @@ -126,3 +141,22 @@ def test_custom_type() -> None: assert json == b'"hello"' deserialized = serde.deserialize(json, type_name="custom_str", data_content_type="str") assert deserialized == message + + +def test_image_type() -> None: + pil_image = PILImage.new("RGB", (100, 100)) + + image = Image(pil_image) + + class PydanticImageMessage(BaseModel): + image: Image + + serializer = PydanticJsonMessageSerializer(PydanticImageMessage) + + json = serializer.serialize(PydanticImageMessage(image=image)) + + deserialized = serializer.deserialize(json) + + assert deserialized.image.image.size == (100, 100) + assert deserialized.image.image.mode == "RGB" + assert deserialized.image.image == image.image diff --git a/python/packages/autogen-core/tests/test_types.py b/python/packages/autogen-core/tests/test_types.py index 9037e3d3f..1dbc02c4f 100644 --- a/python/packages/autogen-core/tests/test_types.py +++ b/python/packages/autogen-core/tests/test_types.py @@ -1,9 +1,12 @@ +from dataclasses import dataclass from types import NoneType -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from autogen_core.base import MessageContext +from autogen_core.base._serialization import has_nested_base_model +from autogen_core.base._type_helpers import AnyType, get_types from autogen_core.components._routed_agent import message_handler -from autogen_core.components._type_helpers import AnyType, get_types +from pydantic import BaseModel def test_get_types() -> None: @@ -38,3 +41,44 @@ class HandlerClass: @message_handler() async def handler(self, message: int, ctx: MessageContext) -> Any: return None + + +def test_nested_data_model() -> None: + class MyBaseModel(BaseModel): + message: str + + @dataclass + class NestedBaseModel: + nested: MyBaseModel + + @dataclass + class NestedBaseModelList: + nested: List[MyBaseModel] + + @dataclass + class NestedBaseModelList2: + nested: list[MyBaseModel] + + @dataclass + class NestedBaseModelList3: + nested: list[list[MyBaseModel]] + + @dataclass + class NestedBaseModelList4: + nested: list[list[list[list[list[list[MyBaseModel]]]]]] + + @dataclass + class NestedBaseModelUnion: + nested: Union[MyBaseModel, str] + + @dataclass + class NestedBaseModelUnion2: + nested: MyBaseModel | str + + assert has_nested_base_model(NestedBaseModel) + assert has_nested_base_model(NestedBaseModelList) + assert has_nested_base_model(NestedBaseModelList2) + assert has_nested_base_model(NestedBaseModelList3) + assert has_nested_base_model(NestedBaseModelList4) + assert has_nested_base_model(NestedBaseModelUnion) + assert has_nested_base_model(NestedBaseModelUnion2) diff --git a/python/packages/team-one/examples/example_reflexagents.py b/python/packages/team-one/examples/example_reflexagents.py index 3346a0926..90a3db258 100644 --- a/python/packages/team-one/examples/example_reflexagents.py +++ b/python/packages/team-one/examples/example_reflexagents.py @@ -30,7 +30,7 @@ async def main() -> None: task_message = UserMessage(content="Test Message", source="User") runtime.start() - await runtime.publish_message(BroadcastMessage(task_message), topic_id=DefaultTopicId()) + await runtime.publish_message(BroadcastMessage(content=task_message), topic_id=DefaultTopicId()) await runtime.stop_when_idle() diff --git a/python/packages/team-one/pyproject.toml b/python/packages/team-one/pyproject.toml index d84884467..848f7bc22 100644 --- a/python/packages/team-one/pyproject.toml +++ b/python/packages/team-one/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "SpeechRecognition", "pathvalidate", "playwright", + "pydantic<3.0.0,>=2.0.0", ] [project.optional-dependencies] diff --git a/python/packages/team-one/src/team_one/agents/reflex_agents.py b/python/packages/team-one/src/team_one/agents/reflex_agents.py index a5eee1c1e..52439541e 100644 --- a/python/packages/team-one/src/team_one/agents/reflex_agents.py +++ b/python/packages/team-one/src/team_one/agents/reflex_agents.py @@ -24,4 +24,4 @@ class ReflexAgent(RoutedAgent): ) topic_id = TopicId("default", self.id.key) - await self.publish_message(BroadcastMessage(response_message), topic_id=topic_id) + await self.publish_message(BroadcastMessage(content=response_message), topic_id=topic_id) diff --git a/python/packages/team-one/src/team_one/messages.py b/python/packages/team-one/src/team_one/messages.py index 4eee7bf34..b43a1a5e7 100644 --- a/python/packages/team-one/src/team_one/messages.py +++ b/python/packages/team-one/src/team_one/messages.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Union from autogen_core.components import FunctionCall, Image from autogen_core.components.models import FunctionExecutionResult, LLMMessage +from pydantic import BaseModel # Convenience type UserContent = Union[str, List[Union[str, Image]]] @@ -11,8 +12,7 @@ FunctionExecutionContent = List[FunctionExecutionResult] SystemContent = str -@dataclass -class BroadcastMessage: +class BroadcastMessage(BaseModel): content: LLMMessage request_halt: bool = False diff --git a/python/uv.lock b/python/uv.lock index a9ee125f3..ea0d59d01 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -350,7 +350,7 @@ requires-dist = [ { name = "opentelemetry-api", specifier = "~=1.27.0" }, { name = "pillow" }, { name = "protobuf", specifier = "~=4.25.1" }, - { name = "pydantic", specifier = ">=1.10,<3" }, + { name = "pydantic", specifier = ">=2.0.0,<3.0.0" }, { name = "tiktoken" }, { name = "typing-extensions" }, ] @@ -4205,6 +4205,7 @@ dependencies = [ { name = "pdfminer-six" }, { name = "playwright" }, { name = "puremagic" }, + { name = "pydantic" }, { name = "pydub" }, { name = "python-pptx" }, { name = "requests" }, @@ -4242,6 +4243,7 @@ requires-dist = [ { name = "pdfminer-six" }, { name = "playwright" }, { name = "puremagic" }, + { name = "pydantic", specifier = ">=2.0.0,<3.0.0" }, { name = "pydub" }, { name = "python-pptx" }, { name = "requests" },