mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Marketing sample migration to AGNext (#234)
This commit is contained in:
@@ -29,3 +29,6 @@ async def build_app(runtime: AgentRuntime) -> None:
|
||||
|
||||
runtime.register("GraphicDesigner", lambda: GraphicDesignerAgent(client=image_client))
|
||||
runtime.register("Auditor", lambda: AuditAgent(model_client=chat_client))
|
||||
|
||||
runtime.get("GraphicDesigner")
|
||||
runtime.get("Auditor")
|
||||
|
||||
@@ -30,4 +30,4 @@ class AuditAgent(TypeRoutedAgent):
|
||||
assert isinstance(completion.content, str)
|
||||
if "NOTFORME" in completion.content:
|
||||
return
|
||||
await self.publish_message(AuditorAlert(user_id=message.user_id, auditor_alert_message=completion.content))
|
||||
await self.publish_message(AuditorAlert(UserId=message.UserId, auditorAlertMessage=completion.content))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import openai
|
||||
@@ -21,11 +22,17 @@ class GraphicDesignerAgent(TypeRoutedAgent):
|
||||
|
||||
@message_handler
|
||||
async def handle_user_chat_input(self, message: ArticleCreated, cancellation_token: CancellationToken) -> None:
|
||||
response = await self._client.images.generate(
|
||||
model=self._model, prompt=message.article, response_format="b64_json"
|
||||
)
|
||||
assert len(response.data) > 0 and response.data[0].b64_json is not None
|
||||
image_base64 = response.data[0].b64_json
|
||||
image_uri = f"data:image/png;base64,{image_base64}"
|
||||
logger = logging.getLogger("graphic_designer")
|
||||
try:
|
||||
logger.info(f"Asking model to generate an image for the article '{message.article}'.")
|
||||
response = await self._client.images.generate(
|
||||
model=self._model, prompt=message.article, response_format="url"
|
||||
)
|
||||
logger.info(f"Image response: '{response.data[0]}'")
|
||||
assert len(response.data) > 0 and response.data[0].url is not None
|
||||
image_uri = response.data[0].url
|
||||
logger.info(f"Generated image for article. Got response: '{image_uri}'")
|
||||
|
||||
await self.publish_message(GraphicDesignCreated(user_id=message.user_id, image_uri=image_uri))
|
||||
await self.publish_message(GraphicDesignCreated(UserId=message.UserId, imageUri=image_uri))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate image for article. Error: {e}")
|
||||
|
||||
@@ -2,20 +2,20 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class ArticleCreated(BaseModel):
|
||||
user_id: str
|
||||
UserId: str
|
||||
article: str
|
||||
|
||||
|
||||
class GraphicDesignCreated(BaseModel):
|
||||
user_id: str
|
||||
image_uri: str
|
||||
UserId: str
|
||||
imageUri: str
|
||||
|
||||
|
||||
class AuditText(BaseModel):
|
||||
user_id: str
|
||||
UserId: str
|
||||
text: str
|
||||
|
||||
|
||||
class AuditorAlert(BaseModel):
|
||||
user_id: str
|
||||
auditor_alert_message: str
|
||||
UserId: str
|
||||
auditorAlertMessage: str
|
||||
|
||||
@@ -17,14 +17,14 @@ class Printer(TypeRoutedAgent):
|
||||
|
||||
@message_handler
|
||||
async def handle_graphic_design(self, message: GraphicDesignCreated, cancellation_token: CancellationToken) -> None:
|
||||
image = Image.from_uri(message.image_uri)
|
||||
image = Image.from_uri(message.imageUri)
|
||||
# Save image to random name in current directory
|
||||
image.image.save(os.path.join(os.getcwd(), f"{message.user_id}.png"))
|
||||
print(f"Received GraphicDesignCreated: user {message.user_id}, saved to {message.user_id}.png")
|
||||
image.image.save(os.path.join(os.getcwd(), f"{message.UserId}.png"))
|
||||
print(f"Received GraphicDesignCreated: user {message.UserId}, saved to {message.UserId}.png")
|
||||
|
||||
@message_handler
|
||||
async def handle_auditor_alert(self, message: AuditorAlert, cancellation_token: CancellationToken) -> None:
|
||||
print(f"Received AuditorAlert: {message.auditor_alert_message} for user {message.user_id}")
|
||||
print(f"Received AuditorAlert: {message.auditorAlertMessage} for user {message.UserId}")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@@ -35,11 +35,11 @@ async def main() -> None:
|
||||
ctx = runtime.start()
|
||||
|
||||
await runtime.publish_message(
|
||||
AuditText(text="Buy my product for a MASSIVE 50% discount.", user_id="user-1"), namespace="default"
|
||||
AuditText(text="Buy my product for a MASSIVE 50% discount.", UserId="user-1"), namespace="default"
|
||||
)
|
||||
|
||||
await runtime.publish_message(
|
||||
ArticleCreated(article="The best article ever written about trees and rocks", user_id="user-2"),
|
||||
ArticleCreated(article="The best article ever written about trees and rocks", UserId="user-2"),
|
||||
namespace="default",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,24 +1,40 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from agnext.core._serialization import MESSAGE_TYPE_REGISTRY
|
||||
from agnext.worker.worker_runtime import WorkerAgentRuntime
|
||||
from app import build_app
|
||||
from dotenv import load_dotenv
|
||||
from messages import ArticleCreated, AuditorAlert, AuditText, GraphicDesignCreated
|
||||
|
||||
agnext_logger = logging.getLogger("agnext")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
load_dotenv()
|
||||
runtime = WorkerAgentRuntime()
|
||||
await runtime.setup_channel(os.environ["AGENT_HOST"])
|
||||
MESSAGE_TYPE_REGISTRY.add_type(ArticleCreated)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(GraphicDesignCreated)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(AuditText)
|
||||
MESSAGE_TYPE_REGISTRY.add_type(AuditorAlert)
|
||||
agnext_logger.info("1")
|
||||
await runtime.setup_channel("localhost:5145")
|
||||
|
||||
agnext_logger.info("2")
|
||||
|
||||
await build_app(runtime)
|
||||
agnext_logger.info("3")
|
||||
|
||||
# just to keep the runtime running
|
||||
try:
|
||||
await asyncio.sleep(1000000)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
await runtime.close_channel()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
agnext_logger.setLevel(logging.DEBUG)
|
||||
agnext_logger.log(logging.DEBUG, "Starting worker")
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -28,7 +28,7 @@ import grpc
|
||||
from grpc.aio import StreamStreamCall
|
||||
from typing_extensions import Self
|
||||
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY
|
||||
from agnext.core import MESSAGE_TYPE_REGISTRY, agent_instantiation_context
|
||||
|
||||
from ..core import (
|
||||
Agent,
|
||||
@@ -129,10 +129,13 @@ class RuntimeConnection:
|
||||
async def from_connection_string(
|
||||
cls, connection_string: str, grpc_config: Mapping[str, Any] = DEFAULT_GRPC_CONFIG
|
||||
) -> Self:
|
||||
logger.info("Connecting to %s", connection_string)
|
||||
channel = grpc.aio.insecure_channel(
|
||||
connection_string, options=[("grpc.service_config", json.dumps(grpc_config))]
|
||||
)
|
||||
await channel.channel_ready()
|
||||
# logger.info("awaiting channel_ready")
|
||||
# await channel.channel_ready()
|
||||
# logger.info("channel_ready")
|
||||
instance = cls(channel)
|
||||
instance._connection_task = asyncio.create_task(
|
||||
instance._connect(channel, instance._send_queue, instance._recv_queue)
|
||||
@@ -155,6 +158,7 @@ class RuntimeConnection:
|
||||
|
||||
while True:
|
||||
try:
|
||||
logger.info("Waiting for message")
|
||||
message = await recv_stream.read() # type: ignore
|
||||
if message == grpc.aio.EOF: # type: ignore
|
||||
logger.info("EOF")
|
||||
@@ -162,15 +166,21 @@ class RuntimeConnection:
|
||||
message = cast(Message, message)
|
||||
logger.info("Received message: %s", message)
|
||||
await receive_queue.put(message)
|
||||
logger.info("Put message in queue")
|
||||
except Exception as e:
|
||||
print("=========================================================================")
|
||||
print(e)
|
||||
print("=========================================================================")
|
||||
del recv_stream
|
||||
recv_stream = stub.OpenChannel(QueueAsyncIterable(send_queue)) # type: ignore
|
||||
|
||||
async def send(self, message: Message) -> None:
|
||||
await self._send_queue.put(message)
|
||||
|
||||
async def recv(self) -> Message:
|
||||
logger.info("Getting message from queue")
|
||||
return await self._recv_queue.get()
|
||||
logger.info("Got message from queue")
|
||||
|
||||
|
||||
class WorkerAgentRuntime(AgentRuntime):
|
||||
@@ -191,7 +201,9 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
self._runtime_connection: RuntimeConnection | None = None
|
||||
|
||||
async def setup_channel(self, connection_string: str) -> None:
|
||||
logger.info(f"connecting to: {connection_string}")
|
||||
self._runtime_connection = await RuntimeConnection.from_connection_string(connection_string)
|
||||
logger.info("connection")
|
||||
if self._read_task is None:
|
||||
self._read_task = asyncio.create_task(self.run_read_loop())
|
||||
self._running = True
|
||||
@@ -203,42 +215,52 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
logger.info("Sent registerAgentType message for %s", agent_type)
|
||||
|
||||
async def run_read_loop(self) -> None:
|
||||
logger.info("Starting read loop")
|
||||
# TODO: catch exceptions and reconnect
|
||||
while self._running:
|
||||
message = await self._runtime_connection.recv() # type: ignore
|
||||
logger.info("Got message: %s", message)
|
||||
oneofcase = Message.WhichOneof(message, "message")
|
||||
match oneofcase:
|
||||
case "registerAgentType":
|
||||
logger.warn("Cant handle registerAgentType")
|
||||
case "request":
|
||||
# request: RpcRequest = message.request
|
||||
# source = AgentId(request.source.name, request.source.namespace)
|
||||
# target = AgentId(request.target.name, request.target.namespace)
|
||||
try:
|
||||
message = await self._runtime_connection.recv() # type: ignore
|
||||
logger.info("Got message: %s", message)
|
||||
oneofcase = Message.WhichOneof(message, "message")
|
||||
match oneofcase:
|
||||
case "registerAgentType":
|
||||
logger.warn("Cant handle registerAgentType")
|
||||
case "request":
|
||||
# request: RpcRequest = message.request
|
||||
# source = AgentId(request.source.name, request.source.namespace)
|
||||
# target = AgentId(request.target.name, request.target.namespace)
|
||||
|
||||
raise NotImplementedError("Sending messages is not yet implemented.")
|
||||
case "response":
|
||||
response: RpcResponse = message.response
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if len(response.error) > 0:
|
||||
future.set_exception(Exception(response.error))
|
||||
break
|
||||
future.set_result(response.result)
|
||||
case "event":
|
||||
event: Event = message.event
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
|
||||
namespace = event.namespace
|
||||
raise NotImplementedError("Sending messages is not yet implemented.")
|
||||
case "response":
|
||||
response: RpcResponse = message.response
|
||||
future = self._pending_requests.pop(response.request_id)
|
||||
if len(response.error) > 0:
|
||||
future.set_exception(Exception(response.error))
|
||||
break
|
||||
future.set_result(response.result)
|
||||
case "event":
|
||||
event: Event = message.event
|
||||
message = MESSAGE_TYPE_REGISTRY.deserialize(event.data, type_name=event.type)
|
||||
# namespace = event.namespace
|
||||
namespace = "default"
|
||||
|
||||
for agent_id in self._per_type_subscribers[(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))]:
|
||||
agent = self._get_agent(agent_id)
|
||||
try:
|
||||
await agent.on_message(message, CancellationToken())
|
||||
except Exception as e:
|
||||
event_logger.error("Error handling message", exc_info=e)
|
||||
logger.info("Got event: %s", message)
|
||||
for agent_id in self._per_type_subscribers[
|
||||
(namespace, MESSAGE_TYPE_REGISTRY.type_name(message))
|
||||
]:
|
||||
logger.info("Sending message to %s", agent_id)
|
||||
agent = self._get_agent(agent_id)
|
||||
try:
|
||||
await agent.on_message(message, CancellationToken())
|
||||
logger.info("%s handled event %s", agent_id, message)
|
||||
except Exception as e:
|
||||
event_logger.error("Error handling message", exc_info=e)
|
||||
|
||||
logger.warn("Cant handle event")
|
||||
case None:
|
||||
logger.warn("No message")
|
||||
logger.warn("Cant handle event")
|
||||
case None:
|
||||
logger.warn("No message")
|
||||
except Exception as e:
|
||||
logger.error("Error in read loop", exc_info=e)
|
||||
|
||||
async def close_channel(self) -> None:
|
||||
self._running = False
|
||||
@@ -367,7 +389,10 @@ class WorkerAgentRuntime(AgentRuntime):
|
||||
|
||||
agent_factory = self._agent_factories[agent_id.name]
|
||||
|
||||
token = agent_instantiation_context.set((self, agent_id))
|
||||
agent = self._invoke_agent_factory(agent_factory, agent_id)
|
||||
agent_instantiation_context.reset(token)
|
||||
|
||||
for message_type in agent.metadata["subscriptions"]:
|
||||
self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from agnext.components import TypeRoutedAgent, message_handler
|
||||
from agnext.core import CancellationToken, AgentId
|
||||
import logging
|
||||
import asyncio
|
||||
#import os
|
||||
import os
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -25,7 +25,7 @@ class ExampleAgent(TypeRoutedAgent):
|
||||
async def main() -> None:
|
||||
logger = logging.getLogger("main")
|
||||
runtime = WorkerAgentRuntime()
|
||||
await runtime.setup_channel("localhost:5438") #os.environ["AGENT_HOST"])
|
||||
await runtime.setup_channel(os.environ["AGENT_HOST"])
|
||||
|
||||
runtime.register("ExampleAgent", lambda: ExampleAgent())
|
||||
while True:
|
||||
|
||||
Reference in New Issue
Block a user