diff --git a/src/transformers/cli/serving/utils.py b/src/transformers/cli/serving/utils.py
index fb165b4f4c63..90e514b448c7 100644
--- a/src/transformers/cli/serving/utils.py
+++ b/src/transformers/cli/serving/utils.py
@@ -930,6 +930,13 @@ def get_processor_inputs_from_messages(messages: list[dict], modality: Modality)
         for message in messages:
             parsed = {"role": message["role"], "content": []}
 
+            # Forward tool-use fields so apply_chat_template can handle multi-turn tool conversations.
+            # Explicit nulls (e.g. `"tool_calls": null` sent by clients that serialize every field) are
+            # dropped: a chat template that only checks for the key's presence may fail on a None value.
+            for field in ("tool_calls", "tool_call_id"):
+                if message.get(field) is not None:
+                    parsed[field] = message[field]
+
             content = message.get("content")
             if modality == Modality.LLM:
                 if isinstance(content, str):
diff --git a/tests/cli/test_serve.py b/tests/cli/test_serve.py
index 9be3dbeb99ff..17a8d6a128c5 100644
--- a/tests/cli/test_serve.py
+++ b/tests/cli/test_serve.py
@@ -209,6 +209,60 @@ def test_vlm_multi_turn(self):
             self.assertIsInstance(msg["content"], list)
             self.assertEqual(msg["content"][0]["type"], "text")
 
+    def test_llm_tool_use_fields_forwarded(self):
+        """Tool-use fields (tool_calls, tool_call_id) should be forwarded to processor inputs."""
+
+        get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
+
+        tool_calls = [
+            {"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
+        ]
+        messages = [
+            {"role": "user", "content": "What's the weather in Paris?"},
+            {"role": "assistant", "tool_calls": tool_calls},
+            {"role": "tool", "content": "22°C, sunny", "tool_call_id": "call_1"},
+        ]
+        result = get_processor_inputs_from_messages(messages, Modality.LLM)
+        self.assertEqual(len(result), 3)
+        self.assertEqual(result[1]["tool_calls"], tool_calls)
+        self.assertNotIn("tool_calls", result[0])
+        self.assertEqual(result[2]["tool_call_id"], "call_1")
+        self.assertNotIn("tool_call_id", result[0])
+
+    def test_vlm_tool_use_fields_forwarded(self):
+        """Tool-use fields should be forwarded for VLM modality as well."""
+
+        get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
+
+        tool_calls = [{"id": "call_1", "type": "function", "function": {"name": "describe", "arguments": "{}"}}]
+        messages = [
+            {"role": "user", "content": "Describe this"},
+            {"role": "assistant", "tool_calls": tool_calls},
+            {"role": "tool", "content": "A landscape photo", "tool_call_id": "call_1"},
+        ]
+        result = get_processor_inputs_from_messages(messages, Modality.VLM)
+        self.assertEqual(len(result), 3)
+        self.assertEqual(result[1]["tool_calls"], tool_calls)
+        self.assertNotIn("tool_calls", result[0])
+        self.assertEqual(result[2]["tool_call_id"], "call_1")
+        self.assertNotIn("tool_call_id", result[0])
+
+    def test_null_tool_use_fields_dropped(self):
+        """Explicit null tool-use fields should be omitted rather than forwarded as None."""
+
+        get_processor_inputs_from_messages = BaseHandler.get_processor_inputs_from_messages
+
+        messages = [
+            {"role": "user", "content": "Hi", "tool_call_id": None},
+            {"role": "assistant", "content": "Hello!", "tool_calls": None},
+        ]
+        for modality in (Modality.LLM, Modality.VLM):
+            with self.subTest(modality=modality):
+                result = get_processor_inputs_from_messages(messages, modality)
+                self.assertEqual(len(result), 2)
+                self.assertNotIn("tool_call_id", result[0])
+                self.assertNotIn("tool_calls", result[1])
+
 
 class TestGenerativeModelList(unittest.TestCase):
     def test_lists_only_generative_models(self):