mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-15 21:44:56 -05:00
v1 of AutoGen Studio on AgentChat (#4097)
* add skeleton worflow manager * add test notebook * update test nb * add sample team spec * refactor requirements to agentchat and ext * add base provider to return agentchat agents from json spec * initial api refactor, update dbmanager * api refactor * refactor tests * ags api tutorial update * ui refactor * general refactor * minor refactor updates * backend api refaactor * ui refactor and update * implement v1 for streaming connection with ui updates * backend refactor * ui refactor * minor ui tweak * minor refactor and tweaks * general refactor * update tests * sync uv.lock with main * uv lock update
This commit is contained in:
246
python/packages/autogen-studio/tests/test_component_factory.py
Normal file
246
python/packages/autogen-studio/tests/test_component_factory.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import pytest
|
||||
from typing import List
|
||||
|
||||
from autogen_agentchat.agents import AssistantAgent
|
||||
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
|
||||
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
|
||||
from autogenstudio.datamodel import (
|
||||
AgentConfig, ModelConfig, TeamConfig, ToolConfig, TerminationConfig,
|
||||
ModelTypes, AgentTypes, TeamTypes, TerminationTypes, ToolTypes,
|
||||
ComponentType
|
||||
)
|
||||
from autogenstudio.database import ComponentFactory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def component_factory():
|
||||
return ComponentFactory()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tool_config():
|
||||
return ToolConfig(
|
||||
name="calculator",
|
||||
description="A simple calculator function",
|
||||
content="""
|
||||
def calculator(a: int, b: int, operation: str = '+') -> int:
|
||||
'''
|
||||
A simple calculator that performs basic operations
|
||||
'''
|
||||
if operation == '+':
|
||||
return a + b
|
||||
elif operation == '-':
|
||||
return a - b
|
||||
elif operation == '*':
|
||||
return a * b
|
||||
elif operation == '/':
|
||||
return a / b
|
||||
else:
|
||||
raise ValueError("Invalid operation")
|
||||
""",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentType.TOOL,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_model_config():
|
||||
return ModelConfig(
|
||||
model_type=ModelTypes.OPENAI,
|
||||
model="gpt-4",
|
||||
api_key="test-key",
|
||||
component_type=ComponentType.MODEL,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_config(sample_model_config: ModelConfig, sample_tool_config: ToolConfig):
|
||||
return AgentConfig(
|
||||
name="test_agent",
|
||||
agent_type=AgentTypes.ASSISTANT,
|
||||
system_message="You are a helpful assistant",
|
||||
model_client=sample_model_config,
|
||||
tools=[sample_tool_config],
|
||||
component_type=ComponentType.AGENT,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_termination_config():
|
||||
return TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=10,
|
||||
component_type=ComponentType.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_team_config(sample_agent_config: AgentConfig, sample_termination_config: TerminationConfig, sample_model_config: ModelConfig):
|
||||
return TeamConfig(
|
||||
name="test_team",
|
||||
team_type=TeamTypes.ROUND_ROBIN,
|
||||
participants=[sample_agent_config],
|
||||
termination_condition=sample_termination_config,
|
||||
model_client=sample_model_config,
|
||||
component_type=ComponentType.TEAM,
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_tool(component_factory: ComponentFactory, sample_tool_config: ToolConfig):
|
||||
# Test loading tool from ToolConfig
|
||||
tool = await component_factory.load_tool(sample_tool_config)
|
||||
assert isinstance(tool, FunctionTool)
|
||||
assert tool.name == "calculator"
|
||||
assert tool.description == "A simple calculator function"
|
||||
|
||||
# Test tool functionality
|
||||
result = tool._func(5, 3, '+')
|
||||
assert result == 8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_tool_invalid_config(component_factory: ComponentFactory):
|
||||
# Test with missing required fields
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_tool(ToolConfig(
|
||||
name="test",
|
||||
description="",
|
||||
content="",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentType.TOOL,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
||||
# Test with invalid Python code
|
||||
invalid_config = ToolConfig(
|
||||
name="invalid",
|
||||
description="Invalid function",
|
||||
content="def invalid_func(): return invalid syntax",
|
||||
tool_type=ToolTypes.PYTHON_FUNCTION,
|
||||
component_type=ComponentType.TOOL,
|
||||
version="1.0.0"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_tool(invalid_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_model(component_factory: ComponentFactory, sample_model_config: ModelConfig):
|
||||
# Test loading model from ModelConfig
|
||||
model = await component_factory.load_model(sample_model_config)
|
||||
assert model is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_agent(component_factory: ComponentFactory, sample_agent_config: AgentConfig):
|
||||
# Test loading agent from AgentConfig
|
||||
agent = await component_factory.load_agent(sample_agent_config)
|
||||
assert isinstance(agent, AssistantAgent)
|
||||
assert agent.name == "test_agent"
|
||||
assert len(agent._tools) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_termination(component_factory: ComponentFactory):
|
||||
# Test MaxMessageTermination
|
||||
max_msg_config = TerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentType.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(max_msg_config)
|
||||
assert isinstance(termination, MaxMessageTermination)
|
||||
assert termination._max_messages == 5
|
||||
|
||||
# Test StopMessageTermination
|
||||
stop_msg_config = TerminationConfig(
|
||||
termination_type=TerminationTypes.STOP_MESSAGE,
|
||||
component_type=ComponentType.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(stop_msg_config)
|
||||
assert isinstance(termination, StopMessageTermination)
|
||||
|
||||
# Test TextMentionTermination
|
||||
text_mention_config = TerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentType.TERMINATION,
|
||||
version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(text_mention_config)
|
||||
assert isinstance(termination, TextMentionTermination)
|
||||
assert termination._text == "DONE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_team(component_factory: ComponentFactory, sample_team_config: TeamConfig, sample_model_config: ModelConfig):
|
||||
# Test loading RoundRobinGroupChat team
|
||||
team = await component_factory.load_team(sample_team_config)
|
||||
assert isinstance(team, RoundRobinGroupChat)
|
||||
assert len(team._participants) == 1
|
||||
|
||||
# Test loading SelectorGroupChat team with multiple participants
|
||||
selector_team_config = TeamConfig(
|
||||
name="selector_team",
|
||||
team_type=TeamTypes.SELECTOR,
|
||||
participants=[ # Add two participants
|
||||
sample_team_config.participants[0], # First agent
|
||||
AgentConfig( # Second agent
|
||||
name="test_agent_2",
|
||||
agent_type=AgentTypes.ASSISTANT,
|
||||
system_message="You are another helpful assistant",
|
||||
model_client=sample_model_config,
|
||||
tools=sample_team_config.participants[0].tools,
|
||||
component_type=ComponentType.AGENT,
|
||||
version="1.0.0"
|
||||
)
|
||||
],
|
||||
termination_condition=sample_team_config.termination_condition,
|
||||
model_client=sample_model_config,
|
||||
component_type=ComponentType.TEAM,
|
||||
version="1.0.0"
|
||||
)
|
||||
team = await component_factory.load_team(selector_team_config)
|
||||
assert isinstance(team, SelectorGroupChat)
|
||||
assert len(team._participants) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_configs(component_factory: ComponentFactory):
|
||||
# Test invalid agent type
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_agent(AgentConfig(
|
||||
name="test",
|
||||
agent_type="InvalidAgent", # type: ignore
|
||||
system_message="test",
|
||||
component_type=ComponentType.AGENT,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
||||
# Test invalid team type
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_team(TeamConfig(
|
||||
name="test",
|
||||
team_type="InvalidTeam", # type: ignore
|
||||
participants=[],
|
||||
component_type=ComponentType.TEAM,
|
||||
version="1.0.0"
|
||||
))
|
||||
|
||||
# Test invalid termination type
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(TerminationConfig(
|
||||
termination_type="InvalidTermination", # type: ignore
|
||||
component_type=ComponentType.TERMINATION,
|
||||
version="1.0.0"
|
||||
))
|
||||
Reference in New Issue
Block a user