@@ -275,7 +275,6 @@ def _set_pass_message_when_whitespace(message: LLMMessage, context: Dict[str, An
275
275
276
276
base_user_transformer_funcs : List [Callable [[LLMMessage , Dict [str , Any ]], Dict [str , Any ]]] = [
277
277
_assert_valid_name ,
278
- _set_name ,
279
278
_set_role ("user" ),
280
279
]
281
280
@@ -293,13 +292,15 @@ def _set_pass_message_when_whitespace(message: LLMMessage, context: Dict[str, An
293
292
single_user_transformer_funcs : List [Callable [[LLMMessage , Dict [str , Any ]], Dict [str , Any ]]] = (
294
293
base_user_transformer_funcs
295
294
+ [
295
+ _set_name ,
296
296
_set_prepend_text_content ,
297
297
]
298
298
)
299
299
300
300
multimodal_user_transformer_funcs : List [Callable [[LLMMessage , Dict [str , Any ]], Dict [str , Any ]]] = (
301
301
base_user_transformer_funcs
302
302
+ [
303
+ _set_name ,
303
304
_set_multimodal_content ,
304
305
]
305
306
)
@@ -334,6 +335,19 @@ def _set_pass_message_when_whitespace(message: LLMMessage, context: Dict[str, An
334
335
335
336
336
337
# === Specific message param functions ===
338
+ single_user_transformer_funcs_mistral : List [Callable [[LLMMessage , Dict [str , Any ]], Dict [str , Any ]]] = (
339
+ base_user_transformer_funcs
340
+ + [
341
+ _set_prepend_text_content ,
342
+ ]
343
+ )
344
+
345
+ multimodal_user_transformer_funcs_mistral : List [Callable [[LLMMessage , Dict [str , Any ]], Dict [str , Any ]]] = (
346
+ base_user_transformer_funcs
347
+ + [
348
+ _set_multimodal_content ,
349
+ ]
350
+ )
337
351
338
352
339
353
# === Transformer maps ===
@@ -359,6 +373,8 @@ def user_condition(message: LLMMessage, context: Dict[str, Any]) -> str:
359
373
"tools" : tools_assistant_transformer_funcs ,
360
374
"thought" : thought_assistant_transformer_funcs ,
361
375
}
376
+
377
+
362
378
assistant_transformer_constructors : Dict [str , Callable [..., Any ]] = {
363
379
"text" : ChatCompletionAssistantMessageParam ,
364
380
"tools" : ChatCompletionAssistantMessageParam ,
@@ -403,6 +419,12 @@ def assistant_condition(message: LLMMessage, context: Dict[str, Any]) -> str:
403
419
}
404
420
405
421
422
+ user_transformer_funcs_mistral : Dict [str , List [Callable [[LLMMessage , Dict [str , Any ]], Dict [str , Any ]]]] = {
423
+ "text" : single_user_transformer_funcs_mistral ,
424
+ "multimodal" : multimodal_user_transformer_funcs_mistral ,
425
+ }
426
+
427
+
406
428
def function_execution_result_message (message : LLMMessage , context : Dict [str , Any ]) -> TrasformerReturnType :
407
429
assert isinstance (message , FunctionExecutionResultMessage )
408
430
return [
@@ -466,6 +488,24 @@ def function_execution_result_message(message: LLMMessage, context: Dict[str, An
466
488
FunctionExecutionResultMessage : function_execution_result_message ,
467
489
}
468
490
491
+ __MISTRAL_TRANSFORMER_MAP : TransformerMap = {
492
+ SystemMessage : build_transformer_func (
493
+ funcs = system_message_transformers + [_set_empty_to_whitespace ],
494
+ message_param_func = ChatCompletionSystemMessageParam ,
495
+ ),
496
+ UserMessage : build_conditional_transformer_func (
497
+ funcs_map = user_transformer_funcs_mistral ,
498
+ message_param_func_map = user_transformer_constructors ,
499
+ condition_func = user_condition ,
500
+ ),
501
+ AssistantMessage : build_conditional_transformer_func (
502
+ funcs_map = assistant_transformer_funcs ,
503
+ message_param_func_map = assistant_transformer_constructors ,
504
+ condition_func = assistant_condition ,
505
+ ),
506
+ FunctionExecutionResultMessage : function_execution_result_message ,
507
+ }
508
+
469
509
470
510
# set openai models to use the transformer map
471
511
total_models = get_args (ModelFamily .ANY )
@@ -475,7 +515,11 @@ def function_execution_result_message(message: LLMMessage, context: Dict[str, An
475
515
476
516
__gemini_models = [model for model in total_models if ModelFamily .is_gemini (model )]
477
517
478
- __unknown_models = list (set (total_models ) - set (__openai_models ) - set (__claude_models ) - set (__gemini_models ))
518
+ __mistral_models = [model for model in total_models if ModelFamily .is_mistral (model )]
519
+
520
+ __unknown_models = list (
521
+ set (total_models ) - set (__openai_models ) - set (__claude_models ) - set (__gemini_models ) - set (__mistral_models )
522
+ )
479
523
480
524
for model in __openai_models :
481
525
register_transformer ("openai" , model , __BASE_TRANSFORMER_MAP )
@@ -486,6 +530,9 @@ def function_execution_result_message(message: LLMMessage, context: Dict[str, An
486
530
for model in __gemini_models :
487
531
register_transformer ("openai" , model , __GEMINI_TRANSFORMER_MAP )
488
532
533
+ for model in __mistral_models :
534
+ register_transformer ("openai" , model , __MISTRAL_TRANSFORMER_MAP )
535
+
489
536
for model in __unknown_models :
490
537
register_transformer ("openai" , model , __BASE_TRANSFORMER_MAP )
491
538
0 commit comments