23 changes: 23 additions & 0 deletions tests/test_data_utils.py
@@ -396,6 +396,29 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example):
         assert isinstance(result["label"], bool)
         assert result["label"] == example["label"]

+    def test_apply_chat_template_with_chat_template_kwargs(self):
+        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
+
+        example = {
+            "prompt": [{"role": "user", "content": "What color is the sky?"}],
+            # with this tokenizer, when you pass enable_thinking=False, it will add "<think>\n\n</think>\n\n"
+            "chat_template_kwargs": {"enable_thinking": False},
+        }
+        result = apply_chat_template(example, tokenizer)
+
+        # docstyle-ignore
+        expected = textwrap.dedent("""\
+        <|im_start|>user
+        What color is the sky?<|im_end|>
+        <|im_start|>assistant
+        <think>
+
+        </think>
+
+        """)
+
+        assert result["prompt"] == expected
+
     def test_apply_chat_template_with_tools(self):
         tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

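For reference, a minimal standalone sketch of the behavior this test exercises (not part of the diff); it assumes `apply_chat_template` is imported from `trl.data_utils`, as in the test module above, and that the tiny Qwen3 tokenizer's chat template supports `enable_thinking`:

```python
from transformers import AutoTokenizer

from trl.data_utils import apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    # Per-example kwargs forwarded to tokenizer.apply_chat_template by the data_utils change below
    "chat_template_kwargs": {"enable_thinking": False},
}

result = apply_chat_template(example, tokenizer)
# For this tokenizer, enable_thinking=False appends an empty "<think>\n\n</think>\n\n" block
print(result["prompt"])
```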
45 changes: 38 additions & 7 deletions trl/data_utils.py
@@ -143,7 +143,13 @@ def apply_chat_template(

     # Apply the chat template to the whole conversation
     if "messages" in example:
-        messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False, **template_kwargs)
+        messages = tokenizer.apply_chat_template(
+            example["messages"],
+            tools=tools,
+            tokenize=False,
+            **example.get("chat_template_kwargs", {}),
+            **template_kwargs,
+        )

     # Apply the chat template to the prompt, adding the generation prompt
     if "prompt" in example:
@@ -162,14 +168,19 @@
             continue_final_message=continue_final_message,
             tokenize=False,
             add_generation_prompt=add_generation_prompt,
+            **example.get("chat_template_kwargs", {}),
             **template_kwargs,
         )

     # Apply the chat template to the entire prompt + completion
     if "prompt" in example:  # explicit prompt and prompt-completion case
         if "chosen" in example:
             prompt_chosen = tokenizer.apply_chat_template(
-                example["prompt"] + example["chosen"], tools=tools, tokenize=False, **template_kwargs
+                example["prompt"] + example["chosen"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
             # DeepSeek-R1 inserts a <tool_call> token when using `add_generation_prompt`, which can cause discrepancies
             # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
@@ -179,24 +190,42 @@
             chosen = prompt_chosen[len(prompt) :]
         if "rejected" in example and "prompt" in example:  # explicit prompt
             prompt_rejected = tokenizer.apply_chat_template(
-                example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
+                example["prompt"] + example["rejected"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
             # Handle DeepSeek-R1 <tool_call> token, see the above comment for details
             prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
             rejected = prompt_rejected[len(prompt) :]
         if "completion" in example:
             prompt_completion = tokenizer.apply_chat_template(
-                example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
+                example["prompt"] + example["completion"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )
             # Handle DeepSeek-R1 <tool_call> token, see the above comment for details
             prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
             completion = prompt_completion[len(prompt) :]
     else:  # implicit prompt case
         if "chosen" in example:
-            chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False, **template_kwargs)
+            chosen = tokenizer.apply_chat_template(
+                example["chosen"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
+            )
         if "rejected" in example:
             rejected = tokenizer.apply_chat_template(
-                example["rejected"], tools=tools, tokenize=False, **template_kwargs
+                example["rejected"],
+                tools=tools,
+                tokenize=False,
+                **example.get("chat_template_kwargs", {}),
+                **template_kwargs,
             )

     # Extract the completion by removing the prompt part from the prompt-completion string
@@ -239,7 +268,9 @@ def maybe_apply_chat_template(
                 - Unpaired preference dataset: `"prompt"`, `"completion"`, and `"label"`.

             For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
-            messages, where each message is a dictionary with keys `"role"` and `"content"`.
+            messages, where each message is a dictionary with keys `"role"` and `"content"`. Additionally, the example
+            may contain a `"chat_template_kwargs"` key, which is a dictionary of additional keyword arguments to pass
+            to the chat template renderer.
         tokenizer (`PreTrainedTokenizerBase`):
             Tokenizer to apply the chat template with.
         tools (`list[Union[dict, Callable]]`, *optional*):
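Outside the diff, a small illustration of the common-prefix extraction that the DeepSeek-R1 `<tool_call>` comments above refer to; the two strings are made up and only stand in for real template renderings:

```python
from itertools import takewhile

# Hypothetical renderings: the prompt-only rendering ends with an extra token inserted by
# add_generation_prompt that never appears in the prompt+completion rendering.
prompt_only = "User: hi\nAssistant:<tool_call>"
prompt_completion = "User: hi\nAssistant: Hello there!"

# Keep the longest common prefix of the two renderings as the prompt ...
prompt = "".join(x for x, _ in takewhile(lambda pair: pair[0] == pair[1], zip(prompt_only, prompt_completion)))
# ... and treat whatever follows it in the combined rendering as the completion.
completion = prompt_completion[len(prompt) :]

assert prompt == "User: hi\nAssistant:"
assert completion == " Hello there!"
```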
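Finally, a sketch of the docstring's new `"chat_template_kwargs"` key at the dataset level; it assumes every row carries the column and reuses the same tiny test checkpoint as above:

```python
from datasets import Dataset
from transformers import AutoTokenizer

from trl.data_utils import maybe_apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

# Each row carries a "chat_template_kwargs" column whose dict is forwarded, per example,
# to the chat template renderer in addition to any kwargs passed to apply_chat_template itself.
dataset = Dataset.from_list(
    [
        {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chat_template_kwargs": {"enable_thinking": False},
        },
        {
            "prompt": [{"role": "user", "content": "What color is grass?"}],
            "chat_template_kwargs": {"enable_thinking": False},
        },
    ]
)

dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
print(dataset[0]["prompt"])  # rendered string prompt ending with an empty think block
```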