Add message context to message handler (#367)

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jack Gerrits
2024-08-16 23:14:09 -04:00
committed by GitHub
parent bc26ec3de4
commit 853b00b0f0
49 changed files with 267 additions and 235 deletions

View File

@@ -48,10 +48,10 @@ from agnext.components.models import (
SystemMessage,
UserMessage,
)
from agnext.core import CancellationToken
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from agnext.core import MessageContext
from common.utils import get_chat_completion_client_from_envs
logger = logging.getLogger(__name__)
@@ -116,7 +116,7 @@ class MathSolver(TypeRoutedAgent):
self._max_round = max_round
@message_handler
async def handle_response(self, message: IntermediateSolverResponse, cancellation_token: CancellationToken) -> None:
async def handle_response(self, message: IntermediateSolverResponse, ctx: MessageContext) -> None:
if message.solver_name not in self._neighbor_names:
return
# Add only neighbor's response to the buffer.
@@ -143,7 +143,7 @@ class MathSolver(TypeRoutedAgent):
self._buffer.pop((message.session_id, message.round))
@message_handler
async def handle_request(self, message: SolverRequest, cancellation_token: CancellationToken) -> None:
async def handle_request(self, message: SolverRequest, ctx: MessageContext) -> None:
# Save the question.
self._questions[message.session_id] = message.question
# Add the question to the memory.
@@ -186,7 +186,7 @@ class MathAggregator(TypeRoutedAgent):
self._responses: Dict[str, List[FinalSolverResponse]] = {}
@message_handler
async def handle_question(self, message: Question, cancellation_token: CancellationToken) -> None:
async def handle_question(self, message: Question, ctx: MessageContext) -> None:
prompt = (
f"Can you solve the following math problem?\n{message.content}\n"
"Explain your reasoning. Your final answer should be a single numerical number, "
@@ -196,9 +196,7 @@ class MathAggregator(TypeRoutedAgent):
await self.publish_message(SolverRequest(content=prompt, session_id=session_id, question=message.content))
@message_handler
async def handle_final_solver_response(
self, message: FinalSolverResponse, cancellation_token: CancellationToken
) -> None:
async def handle_final_solver_response(self, message: FinalSolverResponse, ctx: MessageContext) -> None:
self._responses.setdefault(message.session_id, []).append(message)
if len(self._responses[message.session_id]) == self._num_solvers:
# Find the majority answer.