Added user proxy. (#176)

* Added user proxy.

* Add dependency

---------

Co-authored-by: gagb <gagb@users.noreply.github.com>
This commit is contained in:
afourney
2024-07-03 17:13:24 -07:00
committed by GitHub
parent f82f3852d3
commit 8eb8a4b14d
3 changed files with 117 additions and 2 deletions

View File

@@ -0,0 +1,72 @@
import asyncio
import logging
# from typing import Any, Dict, List, Tuple, Union
from agnext.application import SingleThreadedAgentRuntime
from agnext.application.logging import EVENT_LOGGER_NAME
from agnext.components.models import (
AzureOpenAIChatCompletionClient,
ModelCapabilities,
)
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from team_one.agents.coder import Coder
from team_one.agents.orchestrator import RoundRobinOrchestrator
from team_one.agents.user_proxy import UserProxy
from team_one.messages import OrchestrationEvent, RequestReplyMessage
async def main() -> None:
# Create the runtime.
runtime = SingleThreadedAgentRuntime()
# Create the AzureOpenAI client, with AAD auth
token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
client = AzureOpenAIChatCompletionClient(
api_version="2024-02-15-preview",
azure_endpoint="https://aif-complex-tasks-west-us-3.openai.azure.com/",
model="gpt-4o-2024-05-13",
model_capabilities=ModelCapabilities(function_calling=True, json_output=True, vision=True),
azure_ad_token_provider=token_provider,
)
# Register agents.
coder = runtime.register_and_get_proxy(
"Coder",
lambda: Coder(model_client=client),
)
user_proxy = runtime.register_and_get_proxy(
"UserProxy",
lambda: UserProxy(),
)
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, user_proxy]))
run_context = runtime.start()
await runtime.send_message(RequestReplyMessage(), user_proxy.id)
await run_context.stop_when_idle()
class MyHandler(logging.Handler):
def __init__(self) -> None:
super().__init__()
def emit(self, record: logging.LogRecord) -> None:
try:
if isinstance(record.msg, OrchestrationEvent):
print(
f"""---------------------------------------------------------------------------
\033[91m{record.msg.source}:\033[0m
{record.msg.message}""",
flush=True,
)
except Exception:
self.handleError(record)
if __name__ == "__main__":
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.INFO)
my_handler = MyHandler()
logger.handlers = [my_handler]
asyncio.run(main())

View File

@@ -27,7 +27,8 @@ dependencies = [
"ruff==0.4.8",
"pytest",
"aiofiles",
"types-aiofiles"
"types-aiofiles",
"azure-identity"
]
[tool.hatch.envs.default.extra-scripts]
@@ -85,4 +86,4 @@ disallow_any_unimported = true
include = ["src", "tests", "examples"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false
reportMissingTypeStubs = false

View File

@@ -0,0 +1,42 @@
import asyncio
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import UserMessage
from agnext.core import CancellationToken
from team_one.messages import BroadcastMessage, RequestReplyMessage
class UserProxy(TypeRoutedAgent):
"""An agent that allows the user to play the role of an agent in the conversation."""
DEFAULT_DESCRIPTION = "A human user."
def __init__(
self,
description: str = DEFAULT_DESCRIPTION,
) -> None:
super().__init__(description)
@message_handler
async def handle_broadcast(self, message: BroadcastMessage, cancellation_token: CancellationToken) -> None:
"""Handle an incoming broadcast message."""
pass
@message_handler
async def handle_request_reply(self, message: RequestReplyMessage, cancellation_token: CancellationToken) -> None:
"""Respond to a reply request."""
# Make an inference to the model.
response = await self.ainput("User input ('exit' to quit): ")
response = response.strip()
await self.publish_message(
BroadcastMessage(
content=UserMessage(content=response, source=self.metadata["name"]), request_halt=(response == "exit")
)
)
async def ainput(self, prompt: str) -> str:
return await asyncio.to_thread(input, f"{prompt} ")