From 5bd91fb15ee94b3d60c65d7981084f38ecca12aa Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Fri, 27 Dec 2024 12:21:39 -0500 Subject: [PATCH] Accept agent type in more places (#4829) * Accept agenttype in more places * remove type hint --- .../src/autogen_core/_default_subscription.py | 3 ++- .../src/autogen_core/_type_prefix_subscription.py | 8 ++++++-- .../autogen-core/src/autogen_core/_type_subscription.py | 8 ++++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_default_subscription.py b/python/packages/autogen-core/src/autogen_core/_default_subscription.py index d47c62109..4c5251c85 100644 --- a/python/packages/autogen-core/src/autogen_core/_default_subscription.py +++ b/python/packages/autogen-core/src/autogen_core/_default_subscription.py @@ -1,5 +1,6 @@ from typing import Callable, Type, TypeVar, overload +from ._agent_type import AgentType from ._base_agent import BaseAgent, subscription_factory from ._subscription_context import SubscriptionInstantiationContext from ._type_subscription import TypeSubscription @@ -16,7 +17,7 @@ class DefaultSubscription(TypeSubscription): agent_type (str, optional): The agent type to use for the subscription. Defaults to None, in which case it will attempt to detect the agent type based on the instantiation context. """ - def __init__(self, topic_type: str = "default", agent_type: str | None = None): + def __init__(self, topic_type: str = "default", agent_type: str | AgentType | None = None): if agent_type is None: try: agent_type = SubscriptionInstantiationContext.agent_type().type diff --git a/python/packages/autogen-core/src/autogen_core/_type_prefix_subscription.py b/python/packages/autogen-core/src/autogen_core/_type_prefix_subscription.py index 4c88f12a4..9e0d52683 100644 --- a/python/packages/autogen-core/src/autogen_core/_type_prefix_subscription.py +++ b/python/packages/autogen-core/src/autogen_core/_type_prefix_subscription.py @@ -1,6 +1,7 @@ import uuid from ._agent_id import AgentId +from ._agent_type import AgentType from ._subscription import Subscription from ._topic import TopicId from .exceptions import CantHandleException @@ -30,9 +31,12 @@ class TypePrefixSubscription(Subscription): agent_type (str): Agent type to handle this subscription """ - def __init__(self, topic_type_prefix: str, agent_type: str): + def __init__(self, topic_type_prefix: str, agent_type: str | AgentType): self._topic_type_prefix = topic_type_prefix - self._agent_type = agent_type + if isinstance(agent_type, AgentType): + self._agent_type = agent_type.type + else: + self._agent_type = agent_type self._id = str(uuid.uuid4()) @property diff --git a/python/packages/autogen-core/src/autogen_core/_type_subscription.py b/python/packages/autogen-core/src/autogen_core/_type_subscription.py index 4ad815418..14d599b56 100644 --- a/python/packages/autogen-core/src/autogen_core/_type_subscription.py +++ b/python/packages/autogen-core/src/autogen_core/_type_subscription.py @@ -1,6 +1,7 @@ import uuid from ._agent_id import AgentId +from ._agent_type import AgentType from ._subscription import Subscription from ._topic import TopicId from .exceptions import CantHandleException @@ -29,9 +30,12 @@ class TypeSubscription(Subscription): agent_type (str): Agent type to handle this subscription """ - def __init__(self, topic_type: str, agent_type: str): + def __init__(self, topic_type: str, agent_type: str | AgentType): self._topic_type = topic_type - self._agent_type = agent_type + if isinstance(agent_type, AgentType): + self._agent_type = agent_type.type + else: + self._agent_type = agent_type self._id = str(uuid.uuid4()) @property