process message before send (#1783)

* process message before send

* rename
This commit is contained in:
Chi Wang
2024-02-25 07:57:27 -08:00
committed by GitHub
parent 085bf6cf3d
commit 8ec1c3e0b3
4 changed files with 47 additions and 13 deletions

View File

@@ -46,7 +46,7 @@ class TransformChatHistory:
"""
Adds TransformChatHistory capability to the given agent.
"""
agent.register_hook(hookable_method="process_all_messages", hook=self._transform_messages)
agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
"""

View File

@@ -61,7 +61,7 @@ class Teachability(AgentCapability):
self.teachable_agent = agent
# Register a hook for processing the last message.
agent.register_hook(hookable_method="process_last_message", hook=self.process_last_message)
agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)
# Was an llm_config passed to the constructor?
if self.llm_config is None:
@@ -82,7 +82,7 @@ class Teachability(AgentCapability):
"""Adds a few arbitrary memos to the DB."""
self.memo_store.prepopulate()
def process_last_message(self, text):
def process_last_received_message(self, text):
"""
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.

View File

@@ -223,7 +223,11 @@ class ConversableAgent(LLMAgent):
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists = {"process_last_message": [], "process_all_messages": []}
self.hook_lists = {
"process_last_received_message": [],
"process_all_messages_before_reply": [],
"process_message_before_send": [],
}
@property
def name(self) -> str:
@@ -467,6 +471,15 @@ class ConversableAgent(LLMAgent):
self._oai_messages[conversation_id].append(oai_message)
return True
def _process_message_before_send(
self, message: Union[Dict, str], recipient: Agent, silent: bool
) -> Union[Dict, str]:
"""Process the message before sending it to the recipient."""
hook_list = self.hook_lists["process_message_before_send"]
for hook in hook_list:
message = hook(message, recipient, silent)
return message
def send(
self,
message: Union[Dict, str],
@@ -509,6 +522,7 @@ class ConversableAgent(LLMAgent):
Returns:
ChatResult: a ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
@@ -561,6 +575,7 @@ class ConversableAgent(LLMAgent):
Returns:
ChatResult: an ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient)
@@ -1634,11 +1649,11 @@ class ConversableAgent(LLMAgent):
# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)
messages = self.process_all_messages_before_reply(messages)
# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
messages = self.process_last_received_message(messages)
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
@@ -1695,11 +1710,11 @@ class ConversableAgent(LLMAgent):
# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)
messages = self.process_all_messages_before_reply(messages)
# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
messages = self.process_last_received_message(messages)
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
@@ -2333,11 +2348,11 @@ class ConversableAgent(LLMAgent):
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)
def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to process all messages, potentially modifying the messages.
"""
hook_list = self.hook_lists["process_all_messages"]
hook_list = self.hook_lists["process_all_messages_before_reply"]
# If no hooks are registered, or if there are no messages to process, return the original message list.
if len(hook_list) == 0 or messages is None:
return messages
@@ -2348,14 +2363,14 @@ class ConversableAgent(LLMAgent):
processed_messages = hook(processed_messages)
return processed_messages
def process_last_message(self, messages):
def process_last_received_message(self, messages):
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
as long as the last message is not a function call or exit command.
"""
# If any required condition is not met, return the original message list.
hook_list = self.hook_lists["process_last_message"]
hook_list = self.hook_lists["process_last_received_message"]
if len(hook_list) == 0:
return messages # No hooks registered.
if messages is None:

View File

@@ -1074,6 +1074,24 @@ def test_max_turn():
assert len(res.chat_history) <= 6
def test_process_before_send():
print_mock = unittest.mock.MagicMock()
def send_to_frontend(message, recipient, silent):
if not silent:
print(f"Message sent to {recipient.name}: {message}")
print_mock(message=message)
return message
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
dummy_agent_2 = ConversableAgent(name="dummy_agent_2", llm_config=False, human_input_mode="NEVER")
dummy_agent_1.register_hook("process_message_before_send", send_to_frontend)
dummy_agent_1.send("hello", dummy_agent_2)
print_mock.assert_called_once_with(message="hello")
dummy_agent_1.send("silent hello", dummy_agent_2, silent=True)
print_mock.assert_called_once_with(message="hello")
if __name__ == "__main__":
# test_trigger()
# test_context()
@@ -1081,4 +1099,5 @@ if __name__ == "__main__":
# test_generate_code_execution_reply()
# test_conversable_agent()
# test_no_llm_config()
test_max_turn()
# test_max_turn()
test_process_before_send()