diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py
index 8fe8a24bd50..eff62650bb4 100644
--- a/tests/test_data_utils.py
+++ b/tests/test_data_utils.py
@@ -396,6 +396,29 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example):
assert isinstance(result["label"], bool)
assert result["label"] == example["label"]
+ def test_apply_chat_template_with_chat_template_kwargs(self):
+ tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
+
+ example = {
+ "prompt": [{"role": "user", "content": "What color is the sky?"}],
+ # with this tokenizer, when you pass enable_thinking=False, it will add "<think>\n\n</think>\n\n"
+ "chat_template_kwargs": {"enable_thinking": False},
+ }
+ result = apply_chat_template(example, tokenizer)
+
+ # docstyle-ignore
+ expected = textwrap.dedent("""\
+ <|im_start|>user
+ What color is the sky?<|im_end|>
+ <|im_start|>assistant
+ <think>
+
+ </think>
+
+ """)
+
+ assert result["prompt"] == expected
+
def test_apply_chat_template_with_tools(self):
tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
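A minimal usage sketch of the per-example `chat_template_kwargs` key (illustration only, not part of the patch); it mirrors the new test and assumes the same `trl-internal-testing/tiny-Qwen3ForCausalLM` tokenizer:

```python
# Usage sketch: per-example chat-template kwargs are forwarded to
# tokenizer.apply_chat_template alongside the function-level **template_kwargs.
from transformers import AutoTokenizer

from trl.data_utils import apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    # Qwen3-style templates render an empty <think>...</think> block when thinking is disabled
    "chat_template_kwargs": {"enable_thinking": False},
}

result = apply_chat_template(example, tokenizer)
print(result["prompt"])  # ends with "<|im_start|>assistant\n<think>\n\n</think>\n\n"
```

The per-example kwargs are unpacked before the function-level `**template_kwargs`, so both are passed through to `tokenizer.apply_chat_template` for each dataset format handled below.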
diff --git a/trl/data_utils.py b/trl/data_utils.py
index 75e7a76f979..454dd24af15 100644
--- a/trl/data_utils.py
+++ b/trl/data_utils.py
@@ -143,7 +143,13 @@ def apply_chat_template(
# Apply the chat template to the whole conversation
if "messages" in example:
- messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False, **template_kwargs)
+ messages = tokenizer.apply_chat_template(
+ example["messages"],
+ tools=tools,
+ tokenize=False,
+ **example.get("chat_template_kwargs", {}),
+ **template_kwargs,
+ )
# Apply the chat template to the prompt, adding the generation prompt
if "prompt" in example:
@@ -162,6 +168,7 @@ def apply_chat_template(
continue_final_message=continue_final_message,
tokenize=False,
add_generation_prompt=add_generation_prompt,
+ **example.get("chat_template_kwargs", {}),
**template_kwargs,
)
@@ -169,7 +176,11 @@ def apply_chat_template(
if "prompt" in example: # explicit prompt and prompt-completion case
if "chosen" in example:
prompt_chosen = tokenizer.apply_chat_template(
- example["prompt"] + example["chosen"], tools=tools, tokenize=False, **template_kwargs
+ example["prompt"] + example["chosen"],
+ tools=tools,
+ tokenize=False,
+ **example.get("chat_template_kwargs", {}),
+ **template_kwargs,
)
# DeepSeek-R1 inserts a token when using `add_generation_prompt`, which can cause discrepancies
# between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
@@ -179,24 +190,42 @@ def apply_chat_template(
chosen = prompt_chosen[len(prompt) :]
if "rejected" in example and "prompt" in example: # explicit prompt
prompt_rejected = tokenizer.apply_chat_template(
- example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
+ example["prompt"] + example["rejected"],
+ tools=tools,
+ tokenize=False,
+ **example.get("chat_template_kwargs", {}),
+ **template_kwargs,
)
# Handle DeepSeek-R1 token, see the above comment for details
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
rejected = prompt_rejected[len(prompt) :]
if "completion" in example:
prompt_completion = tokenizer.apply_chat_template(
- example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
+ example["prompt"] + example["completion"],
+ tools=tools,
+ tokenize=False,
+ **example.get("chat_template_kwargs", {}),
+ **template_kwargs,
)
# Handle DeepSeek-R1 token, see the above comment for details
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
completion = prompt_completion[len(prompt) :]
else: # implicit prompt case
if "chosen" in example:
- chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False, **template_kwargs)
+ chosen = tokenizer.apply_chat_template(
+ example["chosen"],
+ tools=tools,
+ tokenize=False,
+ **example.get("chat_template_kwargs", {}),
+ **template_kwargs,
+ )
if "rejected" in example:
rejected = tokenizer.apply_chat_template(
- example["rejected"], tools=tools, tokenize=False, **template_kwargs
+ example["rejected"],
+ tools=tools,
+ tokenize=False,
+ **example.get("chat_template_kwargs", {}),
+ **template_kwargs,
)
# Extract the completion by removing the prompt part from the prompt-completion string
@@ -239,7 +268,9 @@ def maybe_apply_chat_template(
- Unpaired preference dataset: `"prompt"`, `"completion"`, and `"label"`.
For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
- messages, where each message is a dictionary with keys `"role"` and `"content"`.
+ messages, where each message is a dictionary with keys `"role"` and `"content"`. Additionally, the example
+ may contain a `"chat_template_kwargs"` key, which is a dictionary of additional keyword arguments to pass
+ to the chat template renderer.
tokenizer (`PreTrainedTokenizerBase`):
Tokenizer to apply the chat template with.
tools (`list[Union[dict, Callable]]`, *optional*):