
Commit 978cbd2

SongChiYoung and ekzhu authored
FIX/mistral could not receive name field (#6503)
## Why are these changes needed?

Mistral could not receive the `name` field, so this adds a model transformer for Mistral.

## Related issue number

Closes #6147

## Checks

- [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [x] I've made sure all auto checks have passed.

Co-authored-by: Eric Zhu <[email protected]>
1 parent 177211b commit 978cbd2

File tree

3 files changed: +87 -2 lines changed

python/packages/autogen-core/src/autogen_core/models/_model_client.py

Lines changed: 25 additions & 0 deletions
@@ -37,9 +37,15 @@ class ModelFamily:
     CLAUDE_3_5_HAIKU = "claude-3-5-haiku"
     CLAUDE_3_5_SONNET = "claude-3-5-sonnet"
     CLAUDE_3_7_SONNET = "claude-3-7-sonnet"
+    CODESRAL = "codestral"
+    OPEN_CODESRAL_MAMBA = "open-codestral-mamba"
+    MISTRAL = "mistral"
+    MINISTRAL = "ministral"
+    PIXTRAL = "pixtral"
     UNKNOWN = "unknown"

     ANY: TypeAlias = Literal[
+        # openai_models
         "gpt-41",
         "gpt-45",
         "gpt-4o",
@@ -49,16 +55,25 @@ class ModelFamily:
         "gpt-4",
         "gpt-35",
         "r1",
+        # google_models
         "gemini-1.5-flash",
         "gemini-1.5-pro",
         "gemini-2.0-flash",
         "gemini-2.5-pro",
+        # anthropic_models
         "claude-3-haiku",
         "claude-3-sonnet",
         "claude-3-opus",
         "claude-3-5-haiku",
         "claude-3-5-sonnet",
         "claude-3-7-sonnet",
+        # mistral_models
+        "codestral",
+        "open-codestral-mamba",
+        "mistral",
+        "ministral",
+        "pixtral",
+        # unknown
         "unknown",
     ]

@@ -98,6 +113,16 @@ def is_openai(family: str) -> bool:
             ModelFamily.GPT_35,
         )

+    @staticmethod
+    def is_mistral(family: str) -> bool:
+        return family in (
+            ModelFamily.CODESRAL,
+            ModelFamily.OPEN_CODESRAL_MAMBA,
+            ModelFamily.MISTRAL,
+            ModelFamily.MINISTRAL,
+            ModelFamily.PIXTRAL,
+        )
+

 @deprecated("Use the ModelInfo class instead ModelCapabilities.")
 class ModelCapabilities(TypedDict, total=False):
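
The new family constants and the `ModelFamily.is_mistral` helper are what the transformer registration in the message-transform module below keys on. A minimal usage sketch follows; the routing loop is illustrative only, not library code:

```python
from autogen_core.models import ModelFamily

# Illustrative check: which transformer map a given family would route to.
for family in (ModelFamily.MISTRAL, ModelFamily.PIXTRAL, ModelFamily.GPT_4O):
    if ModelFamily.is_mistral(family):
        print(f"{family}: Mistral transformer map (no 'name' on user messages)")
    else:
        print(f"{family}: base OpenAI transformer map")
```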

python/packages/autogen-ext/src/autogen_ext/models/openai/_message_transform.py

Lines changed: 49 additions & 2 deletions
@@ -275,7 +275,6 @@ def _set_pass_message_when_whitespace(message: LLMMessage, context: Dict[str, An

 base_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = [
     _assert_valid_name,
-    _set_name,
     _set_role("user"),
 ]

@@ -293,13 +292,15 @@ def _set_pass_message_when_whitespace(message: LLMMessage, context: Dict[str, An
 single_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
     base_user_transformer_funcs
     + [
+        _set_name,
         _set_prepend_text_content,
     ]
 )

 multimodal_user_transformer_funcs: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
     base_user_transformer_funcs
     + [
+        _set_name,
         _set_multimodal_content,
     ]
 )
@@ -334,6 +335,19 @@ def _set_pass_message_when_whitespace(message: LLMMessage, context: Dict[str, An


 # === Specific message param functions ===
+single_user_transformer_funcs_mistral: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
+    base_user_transformer_funcs
+    + [
+        _set_prepend_text_content,
+    ]
+)
+
+multimodal_user_transformer_funcs_mistral: List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]] = (
+    base_user_transformer_funcs
+    + [
+        _set_multimodal_content,
+    ]
+)


 # === Transformer maps ===
@@ -359,6 +373,8 @@ def user_condition(message: LLMMessage, context: Dict[str, Any]) -> str:
     "tools": tools_assistant_transformer_funcs,
     "thought": thought_assistant_transformer_funcs,
 }
+
+
 assistant_transformer_constructors: Dict[str, Callable[..., Any]] = {
     "text": ChatCompletionAssistantMessageParam,
     "tools": ChatCompletionAssistantMessageParam,
@@ -403,6 +419,12 @@ def assistant_condition(message: LLMMessage, context: Dict[str, Any]) -> str:
 }


+user_transformer_funcs_mistral: Dict[str, List[Callable[[LLMMessage, Dict[str, Any]], Dict[str, Any]]]] = {
+    "text": single_user_transformer_funcs_mistral,
+    "multimodal": multimodal_user_transformer_funcs_mistral,
+}
+
+
 def function_execution_result_message(message: LLMMessage, context: Dict[str, Any]) -> TrasformerReturnType:
     assert isinstance(message, FunctionExecutionResultMessage)
     return [
@@ -466,6 +488,24 @@ def function_execution_result_message(message: LLMMessage, context: Dict[str, An
     FunctionExecutionResultMessage: function_execution_result_message,
 }

+__MISTRAL_TRANSFORMER_MAP: TransformerMap = {
+    SystemMessage: build_transformer_func(
+        funcs=system_message_transformers + [_set_empty_to_whitespace],
+        message_param_func=ChatCompletionSystemMessageParam,
+    ),
+    UserMessage: build_conditional_transformer_func(
+        funcs_map=user_transformer_funcs_mistral,
+        message_param_func_map=user_transformer_constructors,
+        condition_func=user_condition,
+    ),
+    AssistantMessage: build_conditional_transformer_func(
+        funcs_map=assistant_transformer_funcs,
+        message_param_func_map=assistant_transformer_constructors,
+        condition_func=assistant_condition,
+    ),
+    FunctionExecutionResultMessage: function_execution_result_message,
+}
+

 # set openai models to use the transformer map
 total_models = get_args(ModelFamily.ANY)
@@ -475,7 +515,11 @@ def function_execution_result_message(message: LLMMessage, context: Dict[str, An

 __gemini_models = [model for model in total_models if ModelFamily.is_gemini(model)]

-__unknown_models = list(set(total_models) - set(__openai_models) - set(__claude_models) - set(__gemini_models))
+__mistral_models = [model for model in total_models if ModelFamily.is_mistral(model)]
+
+__unknown_models = list(
+    set(total_models) - set(__openai_models) - set(__claude_models) - set(__gemini_models) - set(__mistral_models)
+)

 for model in __openai_models:
     register_transformer("openai", model, __BASE_TRANSFORMER_MAP)
@@ -486,6 +530,9 @@ def function_execution_result_message(message: LLMMessage, context: Dict[str, An
 for model in __gemini_models:
     register_transformer("openai", model, __GEMINI_TRANSFORMER_MAP)

+for model in __mistral_models:
+    register_transformer("openai", model, __MISTRAL_TRANSFORMER_MAP)
+
 for model in __unknown_models:
     register_transformer("openai", model, __BASE_TRANSFORMER_MAP)
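
To make the change above concrete: each transformer pipeline is a list of functions whose returned dicts are merged into the final message parameters, so moving `_set_name` out of the base list and omitting it from the Mistral-specific pipelines is enough to keep `name` out of requests to Mistral-family models. A self-contained sketch of that idea (simplified from the real transformers, which also take a context argument):

```python
from typing import Any, Callable, Dict, List

LLMMsg = Dict[str, Any]  # simplified stand-in for autogen's LLMMessage types


def _set_role(role: str) -> Callable[[LLMMsg], Dict[str, Any]]:
    def fn(message: LLMMsg) -> Dict[str, Any]:
        return {"role": role}

    return fn


def _set_content(message: LLMMsg) -> Dict[str, Any]:
    return {"content": message["content"]}


def _set_name(message: LLMMsg) -> Dict[str, Any]:
    return {"name": message["source"]}


def run_pipeline(funcs: List[Callable[[LLMMsg], Dict[str, Any]]], message: LLMMsg) -> Dict[str, Any]:
    # Each function contributes a partial dict; their union is the message param.
    out: Dict[str, Any] = {}
    for fn in funcs:
        out.update(fn(message))
    return out


base_funcs = [_set_role("user"), _set_content]
openai_funcs = base_funcs + [_set_name]  # mirrors single_user_transformer_funcs
mistral_funcs = list(base_funcs)         # mirrors single_user_transformer_funcs_mistral

msg = {"content": "foo", "source": "user"}
print(run_pipeline(openai_funcs, msg))   # {'role': 'user', 'content': 'foo', 'name': 'user'}
print(run_pipeline(mistral_funcs, msg))  # {'role': 'user', 'content': 'foo'}  -- no 'name'
```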

python/packages/autogen-ext/tests/models/test_openai_model_client.py

Lines changed: 13 additions & 0 deletions
@@ -2485,4 +2485,17 @@ async def test_multimodal_message_test(
     _ = await ocr_agent.run(task=multi_modal_message)


+@pytest.mark.asyncio
+async def test_mistral_remove_name() -> None:
+    # Test that the name parameter is removed from the message
+    # when the model is Mistral
+    message = UserMessage(content="foo", source="user")
+    params = to_oai_type(message, prepend_name=False, model="mistral-7b", model_family=ModelFamily.MISTRAL)
+    assert ("name" in params[0]) is False
+
+    # when the model is gpt-4o, the name parameter is not removed
+    params = to_oai_type(message, prepend_name=False, model="gpt-4o", model_family=ModelFamily.GPT_4O)
+    assert ("name" in params[0]) is True
+
+
 # TODO: add integration tests for Azure OpenAI using AAD token.
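
For users, the practical effect is that a Mistral-served model declared with the Mistral family no longer gets a `name` field in its user messages. A hedged configuration sketch, assuming the usual `OpenAIChatCompletionClient` and `ModelInfo` arguments; the endpoint URL, model name, API key, and capability flags below are placeholders rather than values from this PR:

```python
from autogen_core.models import ModelFamily, ModelInfo, UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient

# Placeholder endpoint/model/key: point the OpenAI-compatible client at a
# Mistral deployment and declare the Mistral family so the Mistral transformer
# map (which omits "name") is selected.
client = OpenAIChatCompletionClient(
    model="mistral-small-latest",          # placeholder model name
    base_url="https://api.mistral.ai/v1",  # placeholder OpenAI-compatible endpoint
    api_key="YOUR_API_KEY",                # placeholder
    model_info=ModelInfo(
        family=ModelFamily.MISTRAL,
        vision=False,
        function_calling=True,
        json_output=True,
        structured_output=False,
    ),
)

# Inside an async context, the request built for this message will not include
# the "name" field that Mistral rejected (#6147):
# result = await client.create([UserMessage(content="Hello", source="user")])
```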
