Add message context to message handler (#367)

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits
2024-08-16 23:14:09 -04:00
committed by GitHub
parent bc26ec3de4
commit 853b00b0f0
49 changed files with 267 additions and 235 deletions

View File

@@ -17,7 +17,7 @@ from typing import Dict, List
from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.components.models import ChatCompletionClient, SystemMessage, UserMessage
from agnext.core import CancellationToken
from agnext.core import MessageContext
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@@ -60,7 +60,7 @@ class ReferenceAgent(TypeRoutedAgent):
self._model_client = model_client
@message_handler
async def handle_task(self, message: ReferenceAgentTask, cancellation_token: CancellationToken) -> None:
async def handle_task(self, message: ReferenceAgentTask, ctx: MessageContext) -> None:
"""Handle a task message. This method sends the task to the model and publishes the result."""
task_message = UserMessage(content=message.task, source=self.metadata["type"])
response = await self._model_client.create(self._system_messages + [task_message])
@@ -86,14 +86,14 @@ class AggregatorAgent(TypeRoutedAgent):
self._session_results: Dict[str, List[ReferenceAgentTaskResult]] = {}
@message_handler
async def handle_task(self, message: AggregatorTask, cancellation_token: CancellationToken) -> None:
async def handle_task(self, message: AggregatorTask, ctx: MessageContext) -> None:
"""Handle a task message. This method publishes the task to the reference agents."""
session_id = str(uuid.uuid4())
ref_task = ReferenceAgentTask(session_id=session_id, task=message.task)
await self.publish_message(ref_task)
@message_handler
async def handle_result(self, message: ReferenceAgentTaskResult, cancellation_token: CancellationToken) -> None:
async def handle_result(self, message: ReferenceAgentTaskResult, ctx: MessageContext) -> None:
"""Handle a task result message. Once all results are received, this method
aggregates the results and publishes the final result."""
self._session_results.setdefault(message.session_id, []).append(message)