mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
process message before send (#1783)
* process message before send * rename
This commit is contained in:
@@ -46,7 +46,7 @@ class TransformChatHistory:
|
||||
"""
|
||||
Adds TransformChatHistory capability to the given agent.
|
||||
"""
|
||||
agent.register_hook(hookable_method="process_all_messages", hook=self._transform_messages)
|
||||
agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
|
||||
|
||||
def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
|
||||
@@ -61,7 +61,7 @@ class Teachability(AgentCapability):
|
||||
self.teachable_agent = agent
|
||||
|
||||
# Register a hook for processing the last message.
|
||||
agent.register_hook(hookable_method="process_last_message", hook=self.process_last_message)
|
||||
agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)
|
||||
|
||||
# Was an llm_config passed to the constructor?
|
||||
if self.llm_config is None:
|
||||
@@ -82,7 +82,7 @@ class Teachability(AgentCapability):
|
||||
"""Adds a few arbitrary memos to the DB."""
|
||||
self.memo_store.prepopulate()
|
||||
|
||||
def process_last_message(self, text):
|
||||
def process_last_received_message(self, text):
|
||||
"""
|
||||
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
|
||||
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
|
||||
|
||||
@@ -223,7 +223,11 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
|
||||
# New hookable methods should be added to this list as required to support new agent capabilities.
|
||||
self.hook_lists = {"process_last_message": [], "process_all_messages": []}
|
||||
self.hook_lists = {
|
||||
"process_last_received_message": [],
|
||||
"process_all_messages_before_reply": [],
|
||||
"process_message_before_send": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -467,6 +471,15 @@ class ConversableAgent(LLMAgent):
|
||||
self._oai_messages[conversation_id].append(oai_message)
|
||||
return True
|
||||
|
||||
def _process_message_before_send(
|
||||
self, message: Union[Dict, str], recipient: Agent, silent: bool
|
||||
) -> Union[Dict, str]:
|
||||
"""Process the message before sending it to the recipient."""
|
||||
hook_list = self.hook_lists["process_message_before_send"]
|
||||
for hook in hook_list:
|
||||
message = hook(message, recipient, silent)
|
||||
return message
|
||||
|
||||
def send(
|
||||
self,
|
||||
message: Union[Dict, str],
|
||||
@@ -509,6 +522,7 @@ class ConversableAgent(LLMAgent):
|
||||
Returns:
|
||||
ChatResult: a ChatResult object.
|
||||
"""
|
||||
message = self._process_message_before_send(message, recipient, silent)
|
||||
# When the agent composes and sends the message, the role of the message is "assistant"
|
||||
# unless it's "function".
|
||||
valid = self._append_oai_message(message, "assistant", recipient)
|
||||
@@ -561,6 +575,7 @@ class ConversableAgent(LLMAgent):
|
||||
Returns:
|
||||
ChatResult: an ChatResult object.
|
||||
"""
|
||||
message = self._process_message_before_send(message, recipient, silent)
|
||||
# When the agent composes and sends the message, the role of the message is "assistant"
|
||||
# unless it's "function".
|
||||
valid = self._append_oai_message(message, "assistant", recipient)
|
||||
@@ -1634,11 +1649,11 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
# Call the hookable method that gives registered hooks a chance to process all messages.
|
||||
# Message modifications do not affect the incoming messages or self._oai_messages.
|
||||
messages = self.process_all_messages(messages)
|
||||
messages = self.process_all_messages_before_reply(messages)
|
||||
|
||||
# Call the hookable method that gives registered hooks a chance to process the last message.
|
||||
# Message modifications do not affect the incoming messages or self._oai_messages.
|
||||
messages = self.process_last_message(messages)
|
||||
messages = self.process_last_received_message(messages)
|
||||
|
||||
for reply_func_tuple in self._reply_func_list:
|
||||
reply_func = reply_func_tuple["reply_func"]
|
||||
@@ -1695,11 +1710,11 @@ class ConversableAgent(LLMAgent):
|
||||
|
||||
# Call the hookable method that gives registered hooks a chance to process all messages.
|
||||
# Message modifications do not affect the incoming messages or self._oai_messages.
|
||||
messages = self.process_all_messages(messages)
|
||||
messages = self.process_all_messages_before_reply(messages)
|
||||
|
||||
# Call the hookable method that gives registered hooks a chance to process the last message.
|
||||
# Message modifications do not affect the incoming messages or self._oai_messages.
|
||||
messages = self.process_last_message(messages)
|
||||
messages = self.process_last_received_message(messages)
|
||||
|
||||
for reply_func_tuple in self._reply_func_list:
|
||||
reply_func = reply_func_tuple["reply_func"]
|
||||
@@ -2333,11 +2348,11 @@ class ConversableAgent(LLMAgent):
|
||||
assert hook not in hook_list, f"{hook} is already registered as a hook."
|
||||
hook_list.append(hook)
|
||||
|
||||
def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
|
||||
def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Calls any registered capability hooks to process all messages, potentially modifying the messages.
|
||||
"""
|
||||
hook_list = self.hook_lists["process_all_messages"]
|
||||
hook_list = self.hook_lists["process_all_messages_before_reply"]
|
||||
# If no hooks are registered, or if there are no messages to process, return the original message list.
|
||||
if len(hook_list) == 0 or messages is None:
|
||||
return messages
|
||||
@@ -2348,14 +2363,14 @@ class ConversableAgent(LLMAgent):
|
||||
processed_messages = hook(processed_messages)
|
||||
return processed_messages
|
||||
|
||||
def process_last_message(self, messages):
|
||||
def process_last_received_message(self, messages):
|
||||
"""
|
||||
Calls any registered capability hooks to use and potentially modify the text of the last message,
|
||||
as long as the last message is not a function call or exit command.
|
||||
"""
|
||||
|
||||
# If any required condition is not met, return the original message list.
|
||||
hook_list = self.hook_lists["process_last_message"]
|
||||
hook_list = self.hook_lists["process_last_received_message"]
|
||||
if len(hook_list) == 0:
|
||||
return messages # No hooks registered.
|
||||
if messages is None:
|
||||
|
||||
@@ -1074,6 +1074,24 @@ def test_max_turn():
|
||||
assert len(res.chat_history) <= 6
|
||||
|
||||
|
||||
def test_process_before_send():
|
||||
print_mock = unittest.mock.MagicMock()
|
||||
|
||||
def send_to_frontend(message, recipient, silent):
|
||||
if not silent:
|
||||
print(f"Message sent to {recipient.name}: {message}")
|
||||
print_mock(message=message)
|
||||
return message
|
||||
|
||||
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
|
||||
dummy_agent_2 = ConversableAgent(name="dummy_agent_2", llm_config=False, human_input_mode="NEVER")
|
||||
dummy_agent_1.register_hook("process_message_before_send", send_to_frontend)
|
||||
dummy_agent_1.send("hello", dummy_agent_2)
|
||||
print_mock.assert_called_once_with(message="hello")
|
||||
dummy_agent_1.send("silent hello", dummy_agent_2, silent=True)
|
||||
print_mock.assert_called_once_with(message="hello")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_trigger()
|
||||
# test_context()
|
||||
@@ -1081,4 +1099,5 @@ if __name__ == "__main__":
|
||||
# test_generate_code_execution_reply()
|
||||
# test_conversable_agent()
|
||||
# test_no_llm_config()
|
||||
test_max_turn()
|
||||
# test_max_turn()
|
||||
test_process_before_send()
|
||||
|
||||
Reference in New Issue
Block a user