mirror of
https://github.com/microsoft/autogen.git
synced 2026-05-13 03:00:55 -04:00
Add MagenticOneGroupChat to AGS (#4595)
* add magenticonegroupchat to ags * fix termination condition * typing order check * format error * fix M1 orchestrator handle tool mesages * add filesurfer and coder
This commit is contained in:
@@ -12,7 +12,17 @@ from autogen_core.components.models import (
|
||||
|
||||
from .... import TRACE_LOGGER_NAME
|
||||
from ....base import Response, TerminationCondition
|
||||
from ....messages import AgentMessage, ChatMessage, MultiModalMessage, StopMessage, TextMessage
|
||||
from ....messages import (
|
||||
AgentMessage,
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallResultMessage,
|
||||
)
|
||||
|
||||
from ....state import MagenticOneOrchestratorState
|
||||
from .._base_group_chat_manager import BaseGroupChatManager
|
||||
from .._events import (
|
||||
@@ -418,7 +428,12 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
||||
"""Convert the message thread to a context for the model."""
|
||||
context: List[LLMMessage] = []
|
||||
for m in self._message_thread:
|
||||
if m.source == self._name:
|
||||
if isinstance(m, ToolCallMessage | ToolCallResultMessage):
|
||||
# Ignore tool call messages.
|
||||
continue
|
||||
elif isinstance(m, StopMessage | HandoffMessage):
|
||||
context.append(UserMessage(content=m.content, source=m.source))
|
||||
elif m.source == self._name:
|
||||
assert isinstance(m, TextMessage)
|
||||
context.append(AssistantMessage(content=m.content, source=m.source))
|
||||
else:
|
||||
|
||||
@@ -8,9 +8,11 @@ import aiofiles
|
||||
import yaml
|
||||
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
|
||||
from autogen_agentchat.conditions import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
from autogen_ext.agents.web_surfer import MultimodalWebSurfer
|
||||
from autogen_ext.agents.file_surfer import FileSurfer
|
||||
from autogen_ext.agents.magentic_one import MagenticOneCoderAgent
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
|
||||
from ..datamodel.types import (
|
||||
@@ -32,8 +34,8 @@ from ..utils.utils import Version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat]
|
||||
AgentComponent = Union[AssistantAgent, MultimodalWebSurfer]
|
||||
TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat]
|
||||
AgentComponent = Union[AssistantAgent, MultimodalWebSurfer, UserProxyAgent, FileSurfer, MagenticOneCoderAgent]
|
||||
ModelComponent = Union[OpenAIChatCompletionClient]
|
||||
ToolComponent = Union[FunctionTool] # Will grow with more tool types
|
||||
TerminationComponent = Union[MaxMessageTermination, StopMessageTermination, TextMentionTermination]
|
||||
@@ -243,6 +245,15 @@ class ComponentFactory:
|
||||
termination_condition=termination,
|
||||
selector_prompt=selector_prompt,
|
||||
)
|
||||
elif config.team_type == TeamTypes.MAGENTIC_ONE:
|
||||
if not model_client:
|
||||
raise ValueError("MagenticOneGroupChat requires a model_client")
|
||||
return MagenticOneGroupChat(
|
||||
participants=participants,
|
||||
model_client=model_client,
|
||||
termination_condition=termination if termination is not None else None,
|
||||
max_turns=config.max_turns if config.max_turns is not None else 20,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported team type: {config.team_type}")
|
||||
|
||||
@@ -292,7 +303,16 @@ class ComponentFactory:
|
||||
use_ocr=config.use_ocr if config.use_ocr is not None else False,
|
||||
animate_actions=config.animate_actions if config.animate_actions is not None else False,
|
||||
)
|
||||
|
||||
elif config.agent_type == AgentTypes.FILE_SURFER:
|
||||
return FileSurfer(
|
||||
name=config.name,
|
||||
model_client=model_client,
|
||||
)
|
||||
elif config.agent_type == AgentTypes.MAGENTIC_ONE_CODER:
|
||||
return MagenticOneCoderAgent(
|
||||
name=config.name,
|
||||
model_client=model_client,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported agent type: {config.agent_type}")
|
||||
|
||||
|
||||
@@ -19,11 +19,14 @@ class AgentTypes(str, Enum):
|
||||
ASSISTANT = "AssistantAgent"
|
||||
USERPROXY = "UserProxyAgent"
|
||||
MULTIMODAL_WEBSURFER = "MultimodalWebSurfer"
|
||||
FILE_SURFER = "FileSurfer"
|
||||
MAGENTIC_ONE_CODER = "MagenticOneCoderAgent"
|
||||
|
||||
|
||||
class TeamTypes(str, Enum):
|
||||
ROUND_ROBIN = "RoundRobinGroupChat"
|
||||
SELECTOR = "SelectorGroupChat"
|
||||
MAGENTIC_ONE = "MagenticOneGroupChat"
|
||||
|
||||
|
||||
class TerminationTypes(str, Enum):
|
||||
@@ -103,6 +106,7 @@ class TeamConfig(BaseConfig):
|
||||
selector_prompt: Optional[str] = None
|
||||
termination_condition: Optional[TerminationConfig] = None
|
||||
component_type: ComponentTypes = ComponentTypes.TEAM
|
||||
max_turns: Optional[int] = None
|
||||
|
||||
|
||||
class TeamResult(BaseModel):
|
||||
|
||||
@@ -92,12 +92,15 @@ class WebSocketManager:
|
||||
await self._send_message(run_id, formatted_message)
|
||||
|
||||
# Save message if it's a content message
|
||||
if isinstance(message, (AgentMessage, ChatMessage)):
|
||||
if isinstance(message, TextMessage):
|
||||
await self._save_message(run_id, message)
|
||||
elif isinstance(message, MultiModalMessage):
|
||||
await self._save_message(run_id, message)
|
||||
# Capture final result if it's a TeamResult
|
||||
elif isinstance(message, TeamResult):
|
||||
final_result = message.model_dump()
|
||||
|
||||
elif isinstance(message, (AgentMessage, ChatMessage)):
|
||||
await self._save_message(run_id, message)
|
||||
if not cancellation_token.is_cancelled() and run_id not in self._closed_connections:
|
||||
if final_result:
|
||||
await self._update_run(run_id, RunStatus.COMPLETE, team_result=final_result)
|
||||
@@ -285,6 +288,7 @@ class WebSocketManager:
|
||||
Returns:
|
||||
Optional[dict]: Formatted message or None if formatting fails
|
||||
"""
|
||||
|
||||
try:
|
||||
if isinstance(message, MultiModalMessage):
|
||||
message_dump = message.model_dump()
|
||||
@@ -296,7 +300,8 @@ class WebSocketManager:
|
||||
},
|
||||
]
|
||||
return {"type": "message", "data": message_dump}
|
||||
elif isinstance(message, (AgentMessage, ChatMessage)):
|
||||
|
||||
elif isinstance(message, TextMessage):
|
||||
return {"type": "message", "data": message.model_dump()}
|
||||
|
||||
elif isinstance(message, TeamResult):
|
||||
@@ -305,6 +310,9 @@ class WebSocketManager:
|
||||
"data": message.model_dump(),
|
||||
"status": "complete",
|
||||
}
|
||||
elif isinstance(message, (AgentMessage, ChatMessage)):
|
||||
return {"type": "message", "data": message.model_dump()}
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Message formatting error: {e}")
|
||||
|
||||
@@ -114,9 +114,11 @@ export type ModelTypes = "OpenAIChatCompletionClient";
|
||||
export type AgentTypes =
|
||||
| "AssistantAgent"
|
||||
| "CodingAssistantAgent"
|
||||
| "MultimodalWebSurfer";
|
||||
| "MultimodalWebSurfer"
|
||||
| "FileSurfer"
|
||||
| "MagenticOneCoderAgent";
|
||||
|
||||
export type TeamTypes = "RoundRobinGroupChat" | "SelectorGroupChat";
|
||||
export type TeamTypes = "RoundRobinGroupChat" | "SelectorGroupChat" | "MagenticOneGroupChat";
|
||||
|
||||
// class ComponentType(str, Enum):
|
||||
// TEAM = "team"
|
||||
|
||||
@@ -85,7 +85,7 @@ export const TeamEditor: React.FC<TeamEditorProps> = ({
|
||||
throw new Error("Participants must be an array");
|
||||
}
|
||||
if (
|
||||
!["RoundRobinGroupChat", "SelectorGroupChat"].includes(parsed.team_type)
|
||||
!["RoundRobinGroupChat", "SelectorGroupChat", "MagenticOneGroupChat"].includes(parsed.team_type)
|
||||
) {
|
||||
throw new Error("Invalid team_type");
|
||||
}
|
||||
@@ -169,7 +169,7 @@ export const TeamEditor: React.FC<TeamEditorProps> = ({
|
||||
>
|
||||
<div className="mb-2 text-xs text-gray-500">
|
||||
Required fields: name (string), team_type ("RoundRobinGroupChat" |
|
||||
"SelectorGroupChat"), participants (array)
|
||||
"SelectorGroupChat" | "MagenticOneGroupChat"), participants (array)
|
||||
</div>
|
||||
|
||||
<div className="h-[500px] mb-4">
|
||||
|
||||
@@ -2,14 +2,22 @@ import pytest
|
||||
from typing import List
|
||||
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat
|
||||
from autogen_agentchat.conditions import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
|
||||
from autogenstudio.datamodel.types import (
|
||||
AgentConfig, ModelConfig, TeamConfig, ToolConfig, TerminationConfig,
|
||||
ModelTypes, AgentTypes, TeamTypes, TerminationTypes, ToolTypes,
|
||||
ComponentTypes
|
||||
AgentConfig,
|
||||
ModelConfig,
|
||||
TeamConfig,
|
||||
ToolConfig,
|
||||
TerminationConfig,
|
||||
ModelTypes,
|
||||
AgentTypes,
|
||||
TeamTypes,
|
||||
TerminationTypes,
|
||||
ToolTypes,
|
||||
ComponentTypes,
|
||||
)
|
||||
from autogenstudio.database import ComponentFactory
|
||||
|
||||
@@ -42,7 +50,7 @@ def calculator(a: int, b: int, operation: str = '+') -> int:
|
||||
""",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentTypes.TOOL,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
|
||||
@@ -53,7 +61,7 @@ def sample_model_config():
|
||||
model="gpt-4",
|
||||
api_key="test-key",
|
||||
component_type=ComponentTypes.MODEL,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
|
||||
@@ -66,7 +74,7 @@ def sample_agent_config(sample_model_config: ModelConfig, sample_tool_config: To
|
||||
model_client=sample_model_config,
|
||||
tools=[sample_tool_config],
|
||||
component_type=ComponentTypes.AGENT,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
|
||||
@@ -76,12 +84,14 @@ def sample_termination_config():
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=10,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_team_config(sample_agent_config: AgentConfig, sample_termination_config: TerminationConfig, sample_model_config: ModelConfig):
|
||||
def sample_team_config(
|
||||
sample_agent_config: AgentConfig, sample_termination_config: TerminationConfig, sample_model_config: ModelConfig
|
||||
):
|
||||
return TeamConfig(
|
||||
name="test_team",
|
||||
team_type=TeamTypes.ROUND_ROBIN,
|
||||
@@ -89,7 +99,8 @@ def sample_team_config(sample_agent_config: AgentConfig, sample_termination_conf
|
||||
termination_condition=sample_termination_config,
|
||||
model_client=sample_model_config,
|
||||
component_type=ComponentTypes.TEAM,
|
||||
version="1.0.0"
|
||||
max_turns=10,
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
|
||||
@@ -102,7 +113,7 @@ async def test_load_tool(component_factory: ComponentFactory, sample_tool_config
|
||||
assert tool.description == "A simple calculator function"
|
||||
|
||||
# Test tool functionality
|
||||
result = tool._func(5, 3, '+')
|
||||
result = tool._func(5, 3, "+")
|
||||
assert result == 8
|
||||
|
||||
|
||||
@@ -110,14 +121,16 @@ async def test_load_tool(component_factory: ComponentFactory, sample_tool_config
|
||||
async def test_load_tool_invalid_config(component_factory: ComponentFactory):
|
||||
# Test with missing required fields
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_tool(ToolConfig(
|
||||
name="test",
|
||||
description="",
|
||||
content="",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentTypes.TOOL,
|
||||
version="1.0.0"
|
||||
))
|
||||
await component_factory.load_tool(
|
||||
ToolConfig(
|
||||
name="test",
|
||||
description="",
|
||||
content="",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentTypes.TOOL,
|
||||
version="1.0.0",
|
||||
)
|
||||
)
|
||||
|
||||
# Test with invalid Python code
|
||||
invalid_config = ToolConfig(
|
||||
@@ -126,7 +139,7 @@ async def test_load_tool_invalid_config(component_factory: ComponentFactory):
|
||||
content="def invalid_func(): return invalid syntax",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentTypes.TOOL,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_tool(invalid_config)
|
||||
@@ -155,7 +168,7 @@ async def test_load_termination(component_factory: ComponentFactory):
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
termination = await component_factory.load_termination(max_msg_config)
|
||||
assert isinstance(termination, MaxMessageTermination)
|
||||
@@ -163,9 +176,7 @@ async def test_load_termination(component_factory: ComponentFactory):
|
||||
|
||||
# Test StopMessageTermination
|
||||
stop_msg_config = TerminationConfig(
|
||||
termination_type=TerminationTypes.STOP_MESSAGE,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
termination_type=TerminationTypes.STOP_MESSAGE, component_type=ComponentTypes.TERMINATION, version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(stop_msg_config)
|
||||
assert isinstance(termination, StopMessageTermination)
|
||||
@@ -175,7 +186,7 @@ async def test_load_termination(component_factory: ComponentFactory):
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
termination = await component_factory.load_termination(text_mention_config)
|
||||
assert isinstance(termination, TextMentionTermination)
|
||||
@@ -190,17 +201,17 @@ async def test_load_termination(component_factory: ComponentFactory):
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
),
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
version="1.0.0",
|
||||
),
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
termination = await component_factory.load_termination(and_combo_config)
|
||||
assert termination is not None
|
||||
@@ -214,71 +225,79 @@ async def test_load_termination(component_factory: ComponentFactory):
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
),
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
version="1.0.0",
|
||||
),
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
termination = await component_factory.load_termination(or_combo_config)
|
||||
assert termination is not None
|
||||
|
||||
# Test invalid combinations
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
conditions=[], # Empty conditions
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
))
|
||||
await component_factory.load_termination(
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
conditions=[], # Empty conditions
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
operator="invalid", # type: ignore
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
))
|
||||
await component_factory.load_termination(
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
operator="invalid", # type: ignore
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
)
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
)
|
||||
)
|
||||
|
||||
# Test missing operator
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
),
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
))
|
||||
await component_factory.load_termination(
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
),
|
||||
TerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
),
|
||||
],
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_team(component_factory: ComponentFactory, sample_team_config: TeamConfig, sample_model_config: ModelConfig):
|
||||
async def test_load_team(
|
||||
component_factory: ComponentFactory, sample_team_config: TeamConfig, sample_model_config: ModelConfig
|
||||
):
|
||||
# Test loading RoundRobinGroupChat team
|
||||
team = await component_factory.load_team(sample_team_config)
|
||||
assert isinstance(team, RoundRobinGroupChat)
|
||||
@@ -297,45 +316,77 @@ async def test_load_team(component_factory: ComponentFactory, sample_team_config
|
||||
model_client=sample_model_config,
|
||||
tools=sample_team_config.participants[0].tools,
|
||||
component_type=ComponentTypes.AGENT,
|
||||
version="1.0.0"
|
||||
)
|
||||
version="1.0.0",
|
||||
),
|
||||
],
|
||||
termination_condition=sample_team_config.termination_condition,
|
||||
model_client=sample_model_config,
|
||||
component_type=ComponentTypes.TEAM,
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
)
|
||||
team = await component_factory.load_team(selector_team_config)
|
||||
assert isinstance(team, SelectorGroupChat)
|
||||
assert len(team._participants) == 2
|
||||
|
||||
# Test loading MagenticOneGroupChat team
|
||||
magentic_one_config = TeamConfig(
|
||||
name="magentic_one_team",
|
||||
team_type=TeamTypes.MAGENTIC_ONE,
|
||||
participants=[ # Add two participants
|
||||
sample_team_config.participants[0], # First agent
|
||||
AgentConfig( # Second agent
|
||||
name="test_agent_2",
|
||||
agent_type=AgentTypes.ASSISTANT,
|
||||
system_message="You are another helpful assistant",
|
||||
model_client=sample_model_config,
|
||||
tools=sample_team_config.participants[0].tools,
|
||||
component_type=ComponentTypes.AGENT,
|
||||
max_turns=sample_team_config.max_turns,
|
||||
version="1.0.0",
|
||||
),
|
||||
],
|
||||
termination_condition=sample_team_config.termination_condition,
|
||||
model_client=sample_model_config,
|
||||
component_type=ComponentTypes.TEAM,
|
||||
version="1.0.0",
|
||||
)
|
||||
team = await component_factory.load_team(magentic_one_config)
|
||||
assert isinstance(team, MagenticOneGroupChat)
|
||||
assert len(team._participants) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_configs(component_factory: ComponentFactory):
|
||||
# Test invalid agent type
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_agent(AgentConfig(
|
||||
name="test",
|
||||
agent_type="InvalidAgent", # type: ignore
|
||||
system_message="test",
|
||||
component_type=ComponentTypes.AGENT,
|
||||
version="1.0.0"
|
||||
))
|
||||
await component_factory.load_agent(
|
||||
AgentConfig(
|
||||
name="test",
|
||||
agent_type="InvalidAgent", # type: ignore
|
||||
system_message="test",
|
||||
component_type=ComponentTypes.AGENT,
|
||||
version="1.0.0",
|
||||
)
|
||||
)
|
||||
|
||||
# Test invalid team type
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_team(TeamConfig(
|
||||
name="test",
|
||||
team_type="InvalidTeam", # type: ignore
|
||||
participants=[],
|
||||
component_type=ComponentTypes.TEAM,
|
||||
version="1.0.0"
|
||||
))
|
||||
await component_factory.load_team(
|
||||
TeamConfig(
|
||||
name="test",
|
||||
team_type="InvalidTeam", # type: ignore
|
||||
participants=[],
|
||||
component_type=ComponentTypes.TEAM,
|
||||
version="1.0.0",
|
||||
)
|
||||
)
|
||||
|
||||
# Test invalid termination type
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(TerminationConfig(
|
||||
termination_type="InvalidTermination", # type: ignore
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
))
|
||||
await component_factory.load_termination(
|
||||
TerminationConfig(
|
||||
termination_type="InvalidTermination", # type: ignore
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user