Initial impl of topics and subscriptions (#350)

* initial impl of topics and subscriptions

* Update python/src/agnext/core/_agent_runtime.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* add topic in context

* migrate

* migrate code for topics

* migrate team one

* edit notebooks

* formatting

* fix imports

* Build proto

* Fix circular import

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits
2024-08-20 14:41:24 -04:00
committed by GitHub
parent 4ba7e84721
commit e1a823fb6d
71 changed files with 685 additions and 495 deletions

View File

@@ -3,6 +3,7 @@ import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.coder import Coder, Executor
from team_one.agents.orchestrator import LedgerOrchestrator
from team_one.agents.user_proxy import UserProxy
@@ -15,18 +16,22 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register agents.
coder = await runtime.register_and_get_proxy(
await runtime.register(
"Coder",
lambda: Coder(model_client=create_completion_client_from_env()),
)
coder = AgentProxy(AgentId("Coder", "default"), runtime)
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
await runtime.register("Executor", lambda: Executor("A agent for executing code"))
executor = AgentProxy(AgentId("Executor", "default"), runtime)
user_proxy = await runtime.register_and_get_proxy(
await runtime.register(
"UserProxy",
lambda: UserProxy(description="The current user interacting with you."),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
# TODO: doesn't work for more than default key
await runtime.register(
"orchestrator",
lambda: LedgerOrchestrator(

View File

@@ -3,6 +3,7 @@ import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.coder import Coder, Executor
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
@@ -15,17 +16,20 @@ async def main() -> None:
runtime = SingleThreadedAgentRuntime()
# Register agents.
coder = await runtime.register_and_get_proxy(
await runtime.register(
"Coder",
lambda: Coder(model_client=create_completion_client_from_env()),
)
coder = AgentProxy(AgentId("Coder", "default"), runtime)
executor = await runtime.register_and_get_proxy("Executor", lambda: Executor("A agent for executing code"))
await runtime.register("Executor", lambda: Executor("A agent for executing code"))
executor = AgentProxy(AgentId("Executor", "default"), runtime)
user_proxy = await runtime.register_and_get_proxy(
await runtime.register(
"UserProxy",
lambda: UserProxy(),
lambda: UserProxy(description="The current user interacting with you."),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor, user_proxy]))

View File

@@ -3,6 +3,7 @@ import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.file_surfer import FileSurfer
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
@@ -18,14 +19,17 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
file_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"file_surfer",
lambda: FileSurfer(model_client=client),
)
user_proxy = await runtime.register_and_get_proxy(
file_surfer = AgentProxy(AgentId("file_surfer", "default"), runtime)
await runtime.register(
"UserProxy",
lambda: UserProxy(),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([file_surfer, user_proxy]))

View File

@@ -4,6 +4,7 @@ import logging
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components.models import UserMessage
from agnext.core import AgentId, AgentProxy, TopicId
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.reflex_agents import ReflexAgent
from team_one.messages import BroadcastMessage
@@ -13,14 +14,19 @@ from team_one.utils import LogHandler
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
fake1 = await runtime.register_and_get_proxy("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
fake2 = await runtime.register_and_get_proxy("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
fake3 = await runtime.register_and_get_proxy("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
await runtime.register_and_get("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
await runtime.register("fake_agent_1", lambda: ReflexAgent("First reflect agent"))
fake1 = AgentProxy(AgentId("fake_agent_1", "default"), runtime)
await runtime.register("fake_agent_2", lambda: ReflexAgent("Second reflect agent"))
fake2 = AgentProxy(AgentId("fake_agent_2", "default"), runtime)
await runtime.register("fake_agent_3", lambda: ReflexAgent("Third reflect agent"))
fake3 = AgentProxy(AgentId("fake_agent_3", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([fake1, fake2, fake3]))
task_message = UserMessage(content="Test Message", source="User")
run_context = runtime.start()
await runtime.publish_message(BroadcastMessage(task_message), namespace="default")
await runtime.publish_message(BroadcastMessage(task_message), topic_id=TopicId("default", "default"))
await run_context.stop_when_idle()

View File

@@ -4,6 +4,7 @@ import logging
# from typing import Any, Dict, List, Tuple, Union
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.coder import Coder
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
@@ -19,14 +20,17 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
coder = await runtime.register_and_get_proxy(
await runtime.register(
"Coder",
lambda: Coder(model_client=client),
)
user_proxy = await runtime.register_and_get_proxy(
coder = AgentProxy(AgentId("Coder", "default"), runtime)
await runtime.register(
"UserProxy",
lambda: UserProxy(),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))

View File

@@ -4,6 +4,7 @@ import os
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.core import AgentId, AgentProxy
from team_one.agents.multimodal_web_surfer import MultimodalWebSurfer
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
@@ -21,15 +22,17 @@ async def main() -> None:
client = create_completion_client_from_env()
# Register agents.
web_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime)
user_proxy = await runtime.register_and_get_proxy(
await runtime.register(
"UserProxy",
lambda: UserProxy(),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))

View File

@@ -5,7 +5,7 @@ from agnext.components.models import (
LLMMessage,
UserMessage,
)
from agnext.core import CancellationToken, MessageContext
from agnext.core import CancellationToken, MessageContext, TopicId
from team_one.messages import (
BroadcastMessage,
@@ -45,7 +45,8 @@ class BaseWorker(TeamOneBaseAgent):
self._chat_history.append(assistant_message)
user_message = UserMessage(content=response, source=self.metadata["type"])
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt))
topic_id = TopicId("default", self.id.key)
await self.publish_message(BroadcastMessage(content=user_message, request_halt=request_halt), topic_id=topic_id)
async def _generate_reply(self, cancellation_token: CancellationToken) -> Tuple[bool, UserContent]:
"""Returns (request_halt, response_message)"""

View File

@@ -2,7 +2,7 @@ import json
from typing import Any, Dict, List, Optional
from agnext.components.models import AssistantMessage, ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from agnext.core import AgentProxy
from agnext.core import AgentProxy, TopicId
from ..messages import BroadcastMessage, OrchestrationEvent, ResetMessage
from .base_orchestrator import BaseOrchestrator, logger
@@ -248,8 +248,10 @@ class LedgerOrchestrator(BaseOrchestrator):
synthesized_prompt = self._get_synthesize_prompt(
self._task, self._team_description, self._facts, self._plan
)
topic_id = TopicId("default", self.id.key)
await self.publish_message(
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"]))
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
topic_id=topic_id,
)
logger.info(
@@ -319,14 +321,17 @@ class LedgerOrchestrator(BaseOrchestrator):
# Reset everyone, then rebroadcast the new plan
self._chat_history = [self._chat_history[0]]
await self.publish_message(ResetMessage())
topic_id = TopicId("default", self.id.key)
await self.publish_message(ResetMessage(), topic_id=topic_id)
# Send everyone the NEW plan
synthesized_prompt = self._get_synthesize_prompt(
self._task, self._team_description, self._facts, self._plan
)
topic_id = TopicId("default", self.id.key)
await self.publish_message(
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"]))
BroadcastMessage(content=UserMessage(content=synthesized_prompt, source=self.metadata["type"])),
topic_id=topic_id,
)
logger.info(
@@ -351,8 +356,10 @@ class LedgerOrchestrator(BaseOrchestrator):
assistant_message = AssistantMessage(content=instruction, source=self.metadata["type"])
logger.info(OrchestrationEvent(f"{self.metadata['type']} (-> {next_agent_name})", instruction))
self._chat_history.append(assistant_message) # My copy
topic_id = TopicId("default", self.id.key)
await self.publish_message(
BroadcastMessage(content=user_message, request_halt=False)
BroadcastMessage(content=user_message, request_halt=False),
topic_id=topic_id,
) # Send to everyone else
return agent

View File

@@ -1,6 +1,6 @@
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import UserMessage
from agnext.core import MessageContext
from agnext.core import MessageContext, TopicId
from ..messages import BroadcastMessage, RequestReplyMessage
@@ -22,5 +22,6 @@ class ReflexAgent(TypeRoutedAgent):
content=f"Hello, world from {name}!",
source=name,
)
topic_id = TopicId("default", self.id.key)
await self.publish_message(BroadcastMessage(response_message))
await self.publish_message(BroadcastMessage(response_message), topic_id=topic_id)

View File

@@ -7,11 +7,14 @@ from math import ceil
import asyncio
import pytest
from agnext.core import AgentId
from agnext.core import AgentProxy
pytest_plugins = ('pytest_asyncio',)
from json import dumps
from team_one.utils import (
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER,
ENVIRON_KEY_CHAT_COMPLETION_PROVIDER,
ENVIRON_KEY_CHAT_COMPLETION_KWARGS_JSON,
create_completion_client_from_env
)
@@ -96,13 +99,14 @@ async def test_web_surfer() -> None:
# Register agents.
# Register agents.
web_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
web_surfer = AgentId("WebSurfer", "default")
run_context = runtime.start()
actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer.id, MultimodalWebSurfer)
actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer, MultimodalWebSurfer)
await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium")
# Test some basic navigations
@@ -138,7 +142,7 @@ async def test_web_surfer() -> None:
tool_resp = await make_browser_request(actual_surfer, TOOL_PAGE_DOWN)
assert (
f"The viewport shows {viewport_percentage}% of the webpage, and is positioned at the bottom of the page" in tool_resp
)
)
# Test Q&A and summarization -- we don't have a key so we expect it to fail #(but it means the code path is correct)
with pytest.raises(AuthenticationError):
@@ -160,15 +164,17 @@ async def test_web_surfer_oai() -> None:
client = create_completion_client_from_env()
# Register agents.
web_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime)
user_proxy = await runtime.register_and_get_proxy(
await runtime.register(
"UserProxy",
lambda: UserProxy(),
)
user_proxy = AgentProxy(AgentId("UserProxy", "default"), runtime)
await runtime.register("orchestrator", lambda: RoundRobinOrchestrator([web_surfer, user_proxy]))
run_context = runtime.start()
@@ -220,10 +226,12 @@ async def test_web_surfer_bing() -> None:
# Register agents.
# Register agents.
web_surfer = await runtime.register_and_get_proxy(
await runtime.register(
"WebSurfer",
lambda: MultimodalWebSurfer(),
)
web_surfer = AgentProxy(AgentId("WebSurfer", "default"), runtime)
run_context = runtime.start()
actual_surfer = await runtime.try_get_underlying_agent_instance(web_surfer.id, MultimodalWebSurfer)
await actual_surfer.init(model_client=client, downloads_folder=os.getcwd(), browser_channel="chromium")
@@ -235,7 +243,7 @@ async def test_web_surfer_bing() -> None:
assert f"{BING_QUERY}".strip() in metadata["meta_tags"]["og:url"]
assert f"{BING_QUERY}".strip() in metadata["meta_tags"]["og:title"]
assert f"I typed '{BING_QUERY}' into the browser search bar." in tool_resp.replace("\\","")
tool_resp = await make_browser_request(actual_surfer, TOOL_WEB_SEARCH, {"query": BING_QUERY + " Wikipedia"})
markdown = await actual_surfer._get_page_markdown() # type: ignore
assert "https://en.wikipedia.org/wiki/" in markdown