Skip to content

Commit 8ec1c3e

Browse files
authored
process message before send (#1783)
* process message before send * rename
1 parent 085bf6c commit 8ec1c3e

File tree

4 files changed

+47
-13
lines changed

4 files changed

+47
-13
lines changed

autogen/agentchat/contrib/capabilities/context_handling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def add_to_agent(self, agent: ConversableAgent):
4646
"""
4747
Adds TransformChatHistory capability to the given agent.
4848
"""
49-
agent.register_hook(hookable_method="process_all_messages", hook=self._transform_messages)
49+
agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)
5050

5151
def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
5252
"""

autogen/agentchat/contrib/capabilities/teachability.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def add_to_agent(self, agent: ConversableAgent):
6161
self.teachable_agent = agent
6262

6363
# Register a hook for processing the last message.
64-
agent.register_hook(hookable_method="process_last_message", hook=self.process_last_message)
64+
agent.register_hook(hookable_method="process_last_received_message", hook=self.process_last_received_message)
6565

6666
# Was an llm_config passed to the constructor?
6767
if self.llm_config is None:
@@ -82,7 +82,7 @@ def prepopulate_db(self):
8282
"""Adds a few arbitrary memos to the DB."""
8383
self.memo_store.prepopulate()
8484

85-
def process_last_message(self, text):
85+
def process_last_received_message(self, text):
8686
"""
8787
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
8888
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.

autogen/agentchat/conversable_agent.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,11 @@ def __init__(
223223

224224
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
225225
# New hookable methods should be added to this list as required to support new agent capabilities.
226-
self.hook_lists = {"process_last_message": [], "process_all_messages": []}
226+
self.hook_lists = {
227+
"process_last_received_message": [],
228+
"process_all_messages_before_reply": [],
229+
"process_message_before_send": [],
230+
}
227231

228232
@property
229233
def name(self) -> str:
@@ -467,6 +471,15 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
467471
self._oai_messages[conversation_id].append(oai_message)
468472
return True
469473

474+
def _process_message_before_send(
475+
self, message: Union[Dict, str], recipient: Agent, silent: bool
476+
) -> Union[Dict, str]:
477+
"""Process the message before sending it to the recipient."""
478+
hook_list = self.hook_lists["process_message_before_send"]
479+
for hook in hook_list:
480+
message = hook(message, recipient, silent)
481+
return message
482+
470483
def send(
471484
self,
472485
message: Union[Dict, str],
@@ -509,6 +522,7 @@ def send(
509522
Returns:
510523
ChatResult: a ChatResult object.
511524
"""
525+
message = self._process_message_before_send(message, recipient, silent)
512526
# When the agent composes and sends the message, the role of the message is "assistant"
513527
# unless it's "function".
514528
valid = self._append_oai_message(message, "assistant", recipient)
@@ -561,6 +575,7 @@ async def a_send(
561575
Returns:
562576
ChatResult: an ChatResult object.
563577
"""
578+
message = self._process_message_before_send(message, recipient, silent)
564579
# When the agent composes and sends the message, the role of the message is "assistant"
565580
# unless it's "function".
566581
valid = self._append_oai_message(message, "assistant", recipient)
@@ -1634,11 +1649,11 @@ def generate_reply(
16341649

16351650
# Call the hookable method that gives registered hooks a chance to process all messages.
16361651
# Message modifications do not affect the incoming messages or self._oai_messages.
1637-
messages = self.process_all_messages(messages)
1652+
messages = self.process_all_messages_before_reply(messages)
16381653

16391654
# Call the hookable method that gives registered hooks a chance to process the last message.
16401655
# Message modifications do not affect the incoming messages or self._oai_messages.
1641-
messages = self.process_last_message(messages)
1656+
messages = self.process_last_received_message(messages)
16421657

16431658
for reply_func_tuple in self._reply_func_list:
16441659
reply_func = reply_func_tuple["reply_func"]
@@ -1695,11 +1710,11 @@ async def a_generate_reply(
16951710

16961711
# Call the hookable method that gives registered hooks a chance to process all messages.
16971712
# Message modifications do not affect the incoming messages or self._oai_messages.
1698-
messages = self.process_all_messages(messages)
1713+
messages = self.process_all_messages_before_reply(messages)
16991714

17001715
# Call the hookable method that gives registered hooks a chance to process the last message.
17011716
# Message modifications do not affect the incoming messages or self._oai_messages.
1702-
messages = self.process_last_message(messages)
1717+
messages = self.process_last_received_message(messages)
17031718

17041719
for reply_func_tuple in self._reply_func_list:
17051720
reply_func = reply_func_tuple["reply_func"]
@@ -2333,11 +2348,11 @@ def register_hook(self, hookable_method: str, hook: Callable):
23332348
assert hook not in hook_list, f"{hook} is already registered as a hook."
23342349
hook_list.append(hook)
23352350

2336-
def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
2351+
def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
23372352
"""
23382353
Calls any registered capability hooks to process all messages, potentially modifying the messages.
23392354
"""
2340-
hook_list = self.hook_lists["process_all_messages"]
2355+
hook_list = self.hook_lists["process_all_messages_before_reply"]
23412356
# If no hooks are registered, or if there are no messages to process, return the original message list.
23422357
if len(hook_list) == 0 or messages is None:
23432358
return messages
@@ -2348,14 +2363,14 @@ def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
23482363
processed_messages = hook(processed_messages)
23492364
return processed_messages
23502365

2351-
def process_last_message(self, messages):
2366+
def process_last_received_message(self, messages):
23522367
"""
23532368
Calls any registered capability hooks to use and potentially modify the text of the last message,
23542369
as long as the last message is not a function call or exit command.
23552370
"""
23562371

23572372
# If any required condition is not met, return the original message list.
2358-
hook_list = self.hook_lists["process_last_message"]
2373+
hook_list = self.hook_lists["process_last_received_message"]
23592374
if len(hook_list) == 0:
23602375
return messages # No hooks registered.
23612376
if messages is None:

test/agentchat/test_conversable_agent.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,11 +1074,30 @@ def test_max_turn():
10741074
assert len(res.chat_history) <= 6
10751075

10761076

1077+
def test_process_before_send():
1078+
print_mock = unittest.mock.MagicMock()
1079+
1080+
def send_to_frontend(message, recipient, silent):
1081+
if not silent:
1082+
print(f"Message sent to {recipient.name}: {message}")
1083+
print_mock(message=message)
1084+
return message
1085+
1086+
dummy_agent_1 = ConversableAgent(name="dummy_agent_1", llm_config=False, human_input_mode="NEVER")
1087+
dummy_agent_2 = ConversableAgent(name="dummy_agent_2", llm_config=False, human_input_mode="NEVER")
1088+
dummy_agent_1.register_hook("process_message_before_send", send_to_frontend)
1089+
dummy_agent_1.send("hello", dummy_agent_2)
1090+
print_mock.assert_called_once_with(message="hello")
1091+
dummy_agent_1.send("silent hello", dummy_agent_2, silent=True)
1092+
print_mock.assert_called_once_with(message="hello")
1093+
1094+
10771095
if __name__ == "__main__":
10781096
# test_trigger()
10791097
# test_context()
10801098
# test_max_consecutive_auto_reply()
10811099
# test_generate_code_execution_reply()
10821100
# test_conversable_agent()
10831101
# test_no_llm_config()
1084-
test_max_turn()
1102+
# test_max_turn()
1103+
test_process_before_send()

0 commit comments

Comments
 (0)