mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Numerous fixes for agbench (#170)
* Shift to new runtime API * Add pretty printing * Reformat * Fix linting errors
This commit is contained in:
@@ -1,30 +1,36 @@
|
||||
import asyncio
|
||||
#from typing import Any, Dict, List, Tuple, Union
|
||||
import logging
|
||||
|
||||
# from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from agnext.application import SingleThreadedAgentRuntime
|
||||
from agnext.components.models import (
|
||||
AzureOpenAIChatCompletionClient,
|
||||
LLMMessage,
|
||||
ModelCapabilities,
|
||||
UserMessage,
|
||||
)
|
||||
from agnext.components.code_executor import LocalCommandLineCodeExecutor
|
||||
from agnext.application.logging import EVENT_LOGGER_NAME
|
||||
|
||||
from team_one.agents.coder import Coder, Executor
|
||||
from team_one.agents.orchestrator import RoundRobinOrchestrator
|
||||
from team_one.messages import BroadcastMessage
|
||||
from team_one.messages import BroadcastMessage, OrchestrationEvent
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create the runtime.
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Create the AzureOpenAI client, with AAD auth
|
||||
#token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
|
||||
# token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
|
||||
client = AzureOpenAIChatCompletionClient(
|
||||
api_version="2024-02-15-preview",
|
||||
azure_endpoint="https://aif-complex-tasks-west-us-3.openai.azure.com/",
|
||||
model="gpt-4o-2024-05-13",
|
||||
model_capabilities=ModelCapabilities(function_calling=True, json_output=True, vision=True),
|
||||
#azure_ad_token_provider=token_provider
|
||||
model_capabilities=ModelCapabilities(
|
||||
function_calling=True, json_output=True, vision=True
|
||||
),
|
||||
# azure_ad_token_provider=token_provider
|
||||
)
|
||||
|
||||
# Register agents.
|
||||
@@ -34,7 +40,9 @@ async def main() -> None:
|
||||
)
|
||||
executor = runtime.register_and_get_proxy(
|
||||
"Executor",
|
||||
lambda: Executor("A agent for executing code", executor=LocalCommandLineCodeExecutor())
|
||||
lambda: Executor(
|
||||
"A agent for executing code", executor=LocalCommandLineCodeExecutor()
|
||||
),
|
||||
)
|
||||
|
||||
runtime.register("orchestrator", lambda: RoundRobinOrchestrator([coder, executor]))
|
||||
@@ -43,7 +51,7 @@ async def main() -> None:
|
||||
with open("prompt.txt", "rt") as fh:
|
||||
prompt = fh.read()
|
||||
|
||||
entry_point = "__ENTRY_POINT__"
|
||||
entry_point = "__ENTRY_POINT__"
|
||||
|
||||
task = f"""
|
||||
The following python code imports the `run_tests` function from unit_tests.py, and runs
|
||||
@@ -64,16 +72,32 @@ run_tests({entry_point})
|
||||
```
|
||||
""".strip()
|
||||
|
||||
run_context = runtime.start()
|
||||
|
||||
await runtime.publish_message(BroadcastMessage(content=UserMessage(content=task, source="human")), namespace="default")
|
||||
await runtime.publish_message(
|
||||
BroadcastMessage(content=UserMessage(content=task, source="human")),
|
||||
namespace="default",
|
||||
)
|
||||
|
||||
await run_context.stop_when_idle()
|
||||
|
||||
|
||||
class MyHandler(logging.Handler):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
try:
|
||||
if isinstance(record.msg, OrchestrationEvent):
|
||||
print(f"[{record.msg.timestamp}]: {record.msg.message}", flush=True)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
# Run the runtime until the task is completed.
|
||||
await runtime.process_until_idle()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("agnext").setLevel(logging.DEBUG)
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.setLevel(logging.INFO)
|
||||
my_handler = MyHandler()
|
||||
logger.handlers = [my_handler]
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user