diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index 8fe8a24bd50..eff62650bb4 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -396,6 +396,29 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example):
         assert isinstance(result["label"], bool)
         assert result["label"] == example["label"]
 
+    def test_apply_chat_template_with_chat_template_kwargs(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
+
+        example = {
+            "prompt": [{"role": "user", "content": "What color is the sky?"}],
+            # with this tokenizer, when you pass enable_thinking=False, it will add "<think>\n\n</think>\n\n"
+            "chat_template_kwargs": {"enable_thinking": False},
+        }
+        result = apply_chat_template(example, tokenizer)
+
+        # docstyle-ignore
+        expected = textwrap.dedent("""\
+        <|im_start|>user
+        What color is the sky?<|im_end|>
+        <|im_start|>assistant
+        <think>
+
+        </think>
+
+        """)
+
+        assert result["prompt"] == expected
+
     def test_apply_chat_template_with_tools(self):
         tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
 
diff --git a/trl/data_utils.py b/trl/data_utils.py
index 75e7a76f979..454dd24af15 100644
--- a/trl/data_utils.py
+++ b/trl/data_utils.py
@@ -143,7 +143,13 @@ def apply_chat_template(
 
     # Apply the chat template to the whole conversation
    if "messages" in example:
-        messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False, **template_kwargs)
+        messages = tokenizer.apply_chat_template(
+            example["messages"],
+            tools=tools,
+            tokenize=False,
+            **example.get("chat_template_kwargs", {}),
+            **template_kwargs,
+        )
 
     # Apply the chat template to the prompt, adding the generation prompt
     if "prompt" in example:
@@ -162,6 +168,7 @@ def apply_chat_template(
             continue_final_message=continue_final_message,
             tokenize=False,
             add_generation_prompt=add_generation_prompt,
+            **example.get("chat_template_kwargs", {}),
             **template_kwargs,
         )
 
@@ -169,7 +176,11 @@ def apply_chat_template(
     if "prompt" in example:  # explicit prompt and prompt-completion case
         if "chosen" in example:
             prompt_chosen = tokenizer.apply_chat_template(
-                example["prompt"] + example["chosen"], tools=tools, tokenize=False, **template_kwargs
+                example["prompt"] + example["chosen"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
             # DeepSeek-R1 inserts a <think> token when using `add_generation_prompt`, which can cause discrepancies
             # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
@@ -179,24 +190,42 @@ def apply_chat_template(
             chosen = prompt_chosen[len(prompt) :]
         if "rejected" in example and "prompt" in example:  # explicit prompt
             prompt_rejected = tokenizer.apply_chat_template(
-                example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
+                example["prompt"] + example["rejected"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
             # Handle DeepSeek-R1 <think> token, see the above comment for details
             prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
             rejected = prompt_rejected[len(prompt) :]
         if "completion" in example:
             prompt_completion = tokenizer.apply_chat_template(
-                example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
+                example["prompt"] + example["completion"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
             # Handle DeepSeek-R1 <think> token, see the above comment for details
             prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
             completion = prompt_completion[len(prompt) :]
     else:  # implicit prompt case
         if "chosen" in example:
-            chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False, **template_kwargs)
+            chosen = tokenizer.apply_chat_template(
+                example["chosen"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
+            )
         if "rejected" in example:
             rejected = tokenizer.apply_chat_template(
-                example["rejected"], tools=tools, tokenize=False, **template_kwargs
+                example["rejected"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
 
     # Extract the completion by removing the prompt part from the prompt-completion string
@@ -239,7 +268,9 @@ def maybe_apply_chat_template(
             - Unpaired preference dataset: `"prompt"`, `"completion"`, and `"label"`.
 
             For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
-            messages, where each message is a dictionary with keys `"role"` and `"content"`.
+            messages, where each message is a dictionary with keys `"role"` and `"content"`. Additionally, the example
+            may contain a `"chat_template_kwargs"` key, which is a dictionary of additional keyword arguments to pass
+            to the chat template renderer.
         tokenizer (`PreTrainedTokenizerBase`):
             Tokenizer to apply the chat template with.
         tools (`list[Union[dict, Callable]]`, *optional*):
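
A minimal usage sketch of the new per-example `chat_template_kwargs` column. The model name and the `enable_thinking` flag are only illustrative; any keyword argument accepted by the tokenizer's chat template can be passed, and each row may carry its own kwargs.

```python
from datasets import Dataset
from transformers import AutoTokenizer

from trl.data_utils import maybe_apply_chat_template

# Illustrative model choice; Qwen3's chat template understands `enable_thinking`
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Each row can carry its own kwargs for the chat template renderer
dataset = Dataset.from_list(
    [
        {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chat_template_kwargs": {"enable_thinking": False},
        }
    ]
)

# The per-example kwargs are forwarded to tokenizer.apply_chat_template
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
print(dataset[0]["prompt"])
```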