@@ -223,7 +223,11 @@ def __init__(
223
223
224
224
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
225
225
# New hookable methods should be added to this list as required to support new agent capabilities.
226
- self .hook_lists = {"process_last_message" : [], "process_all_messages" : []}
226
+ self .hook_lists = {
227
+ "process_last_received_message" : [],
228
+ "process_all_messages_before_reply" : [],
229
+ "process_message_before_send" : [],
230
+ }
227
231
228
232
@property
229
233
def name (self ) -> str :
@@ -467,6 +471,15 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
467
471
self ._oai_messages [conversation_id ].append (oai_message )
468
472
return True
469
473
474
+ def _process_message_before_send (
475
+ self , message : Union [Dict , str ], recipient : Agent , silent : bool
476
+ ) -> Union [Dict , str ]:
477
+ """Process the message before sending it to the recipient."""
478
+ hook_list = self .hook_lists ["process_message_before_send" ]
479
+ for hook in hook_list :
480
+ message = hook (message , recipient , silent )
481
+ return message
482
+
470
483
def send (
471
484
self ,
472
485
message : Union [Dict , str ],
@@ -509,6 +522,7 @@ def send(
509
522
Returns:
510
523
ChatResult: a ChatResult object.
511
524
"""
525
+ message = self ._process_message_before_send (message , recipient , silent )
512
526
# When the agent composes and sends the message, the role of the message is "assistant"
513
527
# unless it's "function".
514
528
valid = self ._append_oai_message (message , "assistant" , recipient )
@@ -561,6 +575,7 @@ async def a_send(
561
575
Returns:
562
576
ChatResult: an ChatResult object.
563
577
"""
578
+ message = self ._process_message_before_send (message , recipient , silent )
564
579
# When the agent composes and sends the message, the role of the message is "assistant"
565
580
# unless it's "function".
566
581
valid = self ._append_oai_message (message , "assistant" , recipient )
@@ -1634,11 +1649,11 @@ def generate_reply(
1634
1649
1635
1650
# Call the hookable method that gives registered hooks a chance to process all messages.
1636
1651
# Message modifications do not affect the incoming messages or self._oai_messages.
1637
- messages = self .process_all_messages (messages )
1652
+ messages = self .process_all_messages_before_reply (messages )
1638
1653
1639
1654
# Call the hookable method that gives registered hooks a chance to process the last message.
1640
1655
# Message modifications do not affect the incoming messages or self._oai_messages.
1641
- messages = self .process_last_message (messages )
1656
+ messages = self .process_last_received_message (messages )
1642
1657
1643
1658
for reply_func_tuple in self ._reply_func_list :
1644
1659
reply_func = reply_func_tuple ["reply_func" ]
@@ -1695,11 +1710,11 @@ async def a_generate_reply(
1695
1710
1696
1711
# Call the hookable method that gives registered hooks a chance to process all messages.
1697
1712
# Message modifications do not affect the incoming messages or self._oai_messages.
1698
- messages = self .process_all_messages (messages )
1713
+ messages = self .process_all_messages_before_reply (messages )
1699
1714
1700
1715
# Call the hookable method that gives registered hooks a chance to process the last message.
1701
1716
# Message modifications do not affect the incoming messages or self._oai_messages.
1702
- messages = self .process_last_message (messages )
1717
+ messages = self .process_last_received_message (messages )
1703
1718
1704
1719
for reply_func_tuple in self ._reply_func_list :
1705
1720
reply_func = reply_func_tuple ["reply_func" ]
@@ -2333,11 +2348,11 @@ def register_hook(self, hookable_method: str, hook: Callable):
2333
2348
assert hook not in hook_list , f"{ hook } is already registered as a hook."
2334
2349
hook_list .append (hook )
2335
2350
2336
- def process_all_messages (self , messages : List [Dict ]) -> List [Dict ]:
2351
+ def process_all_messages_before_reply (self , messages : List [Dict ]) -> List [Dict ]:
2337
2352
"""
2338
2353
Calls any registered capability hooks to process all messages, potentially modifying the messages.
2339
2354
"""
2340
- hook_list = self .hook_lists ["process_all_messages " ]
2355
+ hook_list = self .hook_lists ["process_all_messages_before_reply " ]
2341
2356
# If no hooks are registered, or if there are no messages to process, return the original message list.
2342
2357
if len (hook_list ) == 0 or messages is None :
2343
2358
return messages
@@ -2348,14 +2363,14 @@ def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
2348
2363
processed_messages = hook (processed_messages )
2349
2364
return processed_messages
2350
2365
2351
- def process_last_message (self , messages ):
2366
+ def process_last_received_message (self , messages ):
2352
2367
"""
2353
2368
Calls any registered capability hooks to use and potentially modify the text of the last message,
2354
2369
as long as the last message is not a function call or exit command.
2355
2370
"""
2356
2371
2357
2372
# If any required condition is not met, return the original message list.
2358
- hook_list = self .hook_lists ["process_last_message " ]
2373
+ hook_list = self .hook_lists ["process_last_received_message " ]
2359
2374
if len (hook_list ) == 0 :
2360
2375
return messages # No hooks registered.
2361
2376
if messages is None :
0 commit comments