diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py index ea13e7617..20acd2b08 100644 --- a/autogen/agentchat/contrib/gpt_assistant_agent.py +++ b/autogen/agentchat/contrib/gpt_assistant_agent.py @@ -165,9 +165,7 @@ class GPTAssistantAgent(ConversableAgent): # lazily create threads self._openai_threads = {} self._unread_index = defaultdict(int) - self.register_reply(Agent, GPTAssistantAgent._invoke_assistant) - self.register_reply(Agent, GPTAssistantAgent.check_termination_and_human_reply) - self.register_reply(Agent, GPTAssistantAgent.a_check_termination_and_human_reply) + self.register_reply(Agent, GPTAssistantAgent._invoke_assistant, position=2) def _invoke_assistant( self, diff --git a/autogen/agentchat/contrib/multimodal_conversable_agent.py b/autogen/agentchat/contrib/multimodal_conversable_agent.py index 2355c630f..2a016bcff 100644 --- a/autogen/agentchat/contrib/multimodal_conversable_agent.py +++ b/autogen/agentchat/contrib/multimodal_conversable_agent.py @@ -53,16 +53,8 @@ class MultimodalConversableAgent(ConversableAgent): ) # Override the `generate_oai_reply` - def _replace_reply_func(arr, x, y): - for item in arr: - if item["reply_func"] is x: - item["reply_func"] = y - - _replace_reply_func( - self._reply_func_list, ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply - ) - _replace_reply_func( - self._reply_func_list, + self.replace_reply_func(ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply) + self.replace_reply_func( ConversableAgent.a_generate_oai_reply, MultimodalConversableAgent.a_generate_oai_reply, ) diff --git a/autogen/agentchat/contrib/web_surfer.py b/autogen/agentchat/contrib/web_surfer.py index 9b7320f09..6cd71dc63 100644 --- a/autogen/agentchat/contrib/web_surfer.py +++ b/autogen/agentchat/contrib/web_surfer.py @@ -79,8 +79,7 @@ class WebSurferAgent(ConversableAgent): if inner_llm_config not in [None, False]: self._register_functions() - self._reply_func_list = [] - self.register_reply([Agent, None], WebSurferAgent.generate_surfer_reply) + self.register_reply([Agent, None], WebSurferAgent.generate_surfer_reply, remove_other_reply_funcs=True) self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply) self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply) self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 0a7faf06e..2a2de8b55 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -277,6 +277,7 @@ class ConversableAgent(LLMAgent): reset_config: Optional[Callable] = None, *, ignore_async_in_sync_chat: bool = False, + remove_other_reply_funcs: bool = False, ): """Register a reply function. @@ -302,20 +303,15 @@ class ConversableAgent(LLMAgent): Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`. reply_func (Callable): the reply function. The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. - position: the position of the reply function in the reply function list. - config: the config to be passed to the reply function, see below. - reset_config: the function to reset the config, see below. - ignore_async_in_sync_chat: whether to ignore the async reply function in sync chats. If `False`, an exception - will be raised if an async reply function is registered and a chat is initialized with a sync - function. - ```python - def reply_func( - recipient: ConversableAgent, - messages: Optional[List[Dict]] = None, - sender: Optional[Agent] = None, - config: Optional[Any] = None, - ) -> Tuple[bool, Union[str, Dict, None]]: - ``` + + ```python + def reply_func( + recipient: ConversableAgent, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[Any] = None, + ) -> Tuple[bool, Union[str, Dict, None]]: + ``` position (int): the position of the reply function in the reply function list. The function registered later will be checked earlier by default. To change the order, set the position to a positive integer. @@ -323,9 +319,15 @@ class ConversableAgent(LLMAgent): When an agent is reset, the config will be reset to the original value. reset_config (Callable): the function to reset the config. The function returns None. Signature: ```def reset_config(config: Any)``` + ignore_async_in_sync_chat (bool): whether to ignore the async reply function in sync chats. If `False`, an exception + will be raised if an async reply function is registered and a chat is initialized with a sync + function. + remove_other_reply_funcs (bool): whether to remove other reply functions when registering this reply function. """ if not isinstance(trigger, (type, str, Agent, Callable, list)): raise ValueError("trigger must be a class, a string, an agent, a callable or a list.") + if remove_other_reply_funcs: + self._reply_func_list.clear() self._reply_func_list.insert( position, { @@ -338,6 +340,17 @@ class ConversableAgent(LLMAgent): }, ) + def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable): + """Replace a registered reply function with a new one. + + Args: + old_reply_func (Callable): the old reply function to be replaced. + new_reply_func (Callable): the new reply function to replace the old one. + """ + for f in self._reply_func_list: + if f["reply_func"] == old_reply_func: + f["reply_func"] = new_reply_func + @staticmethod def _summary_from_nested_chats( chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any @@ -839,7 +852,7 @@ class ConversableAgent(LLMAgent): } async_reply_functions = [f for f in reply_functions if inspect.iscoroutinefunction(f)] - if async_reply_functions != []: + if async_reply_functions: msg = ( "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: " + ", ".join([f.__name__ for f in async_reply_functions])