trl/trainer/sft_trainer.py: 90 changes (43 additions & 47 deletions)
@@ -115,11 +115,13 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
completion. If `"assistant_masks"` are present, they are used to set the labels to `-100` for tokens that are not
in the assistant part of the sequence. The collator returns a dictionary containing the following keys:
- `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch.
- `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
- `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch.
- `"labels"`: Tensor of labels, padded to the maximum length of the batch. If `completion_only_loss` is set to
`True`, tokens that are not in the completion are set to -100. If `assistant_masks` are present, tokens that are
not in the assistant part of the sequence are set to -100.
not in the assistant part of the sequence are set to -100. If `padding_free` is set to `False`, the following key
is also returned:
- `"attention_mask"`: Tensor of attention masks, padded to the maximum length of the batch.
If `padding_free` is set to `True`, the following key is also returned:
- `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch.

Args:
pad_token_id (`int`):
@@ -129,7 +131,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
that are not in the completion.
padding_free (`bool`, *optional*, defaults to `False`):
If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
generated accordingly.
generated accordingly and returned instead of the attention mask.
pad_to_multiple_of (`int`, *optional*):
If set, the sequences will be padded to a multiple of this value.
return_tensors (`str`, *optional*, defaults to `"pt"`):
@@ -146,8 +148,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'position_ids': tensor([[0, 1, 2],
[0, 1, 0]]),
'labels': tensor([[ 1, 2, 3],
[ 4, 5, -100]])}

@@ -161,16 +161,13 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'position_ids': tensor([[0, 1, 2],
[0, 1, 0]]),
'labels': tensor([[-100, 2, 3],
[-100, 5, -100]])}

>>> # With padding_free
>>> collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3, 4, 5]]),
'attention_mask': tensor([[1, 1, 1, 1, 1]]),
'position_ids': tensor([[0, 1, 2, 0, 1]]),
'labels': tensor([[1, 2, 3, 4, 5]])}
```
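
In padding-free mode the collator deliberately omits `attention_mask`: flash-attention kernels infer the sequence boundaries (`cu_seq_lens`) from the restarting position IDs, and an all-ones mask would make them treat the flattened batch as one long sequence. A rough sketch of that derivation, using a hypothetical helper (not TRL or flash-attn API) and assuming no right padding:

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Each 0 in the (1-D, unpadded) position IDs marks the start of a packed
    # sequence, e.g. tensor([0, 1, 2, 0, 1]) -> starts at indices 0 and 3.
    starts = torch.nonzero(position_ids == 0).flatten()
    total = position_ids.new_tensor([position_ids.numel()])
    return torch.cat([starts, total])  # tensor([0, 3, 5])
```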
@@ -179,33 +176,28 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
pad_token_id: int
completion_only_loss: bool = True
padding_free: bool = False
return_position_ids: bool = True
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"

def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
# Convert to tensor
input_ids = [torch.tensor(example["input_ids"]) for example in examples]
if "labels" in examples[0]:
labels = [torch.tensor(example["labels"]) for example in examples]
else:
labels = [torch.tensor(example["input_ids"]) for example in examples]

# Check if we have meaningful seq_lengths from packing (restarting sequences)
has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free

# For packing with position_ids, we should NOT create attention_mask as it causes
# FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask
if not has_packed_position_ids:
attention_mask = [torch.ones_like(ids) for ids in input_ids]

if self.return_position_ids:
# For padding-free, we should NOT create attention_mask as it causes FlashAttention to ignore position_ids and
# compute wrong cu_seq_lens from the all-1s mask
if self.padding_free:
if "seq_lengths" in examples[0]:
position_ids = self.get_position_ids_from_packed_seq_lengths(
[example["seq_lengths"] for example in examples]
)
else:
position_ids = [torch.arange(len(ids)) for ids in input_ids]
if "labels" in examples[0]:
labels = [torch.tensor(example["labels"]) for example in examples]
else:
labels = [torch.tensor(example["input_ids"]) for example in examples]
attention_mask = [torch.ones_like(ids) for ids in input_ids]
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = [torch.tensor(example["completion_mask"]) for example in examples]
if "assistant_masks" in examples[0]:
@@ -215,9 +207,8 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
output = {}
if self.padding_free:
input_ids = [torch.cat(input_ids, dim=0)]
if self.return_position_ids:
position_ids = [torch.cat(position_ids, dim=0)]
labels = [torch.cat(labels, dim=0)]
position_ids = [torch.cat(position_ids, dim=0)]
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = [torch.cat(completion_mask, dim=0)]
if "assistant_masks" in examples[0]:
@@ -230,18 +221,18 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
if not has_packed_position_ids:
attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
output["attention_mask"] = pad(
attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.return_position_ids:
output["position_ids"] = pad(
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"] = pad(
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.padding_free:
output["position_ids"] = pad(
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][output["position_ids"] == 0] = -100
[Member Author]: this is the most important line
else:
output["attention_mask"] = pad(
attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = pad(
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
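
The line flagged in the inline comment above is what keeps packed sequences independent during training: wherever the padded `position_ids` are 0, i.e. at the first token of every packed sequence and in the right padding (which is also padded with 0), the label becomes -100, so no loss is computed for predicting a token from the preceding, unrelated sequence. A small illustration:

```python
import torch

input_ids = torch.tensor([[1, 2, 3, 4, 5]])     # packed sequences [1, 2, 3] and [4, 5]
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
labels = input_ids.clone()

labels[position_ids == 0] = -100
print(labels)  # tensor([[-100, 2, 3, -100, 5]])
```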
@@ -721,6 +712,9 @@ def __init__(
use_flash_attention = model.config._attn_implementation in [
"flash_attention_2",
"flash_attention_3",
"kernels-community/flash-attn",
"kernels-community/vllm-flash-attn3",
"kernels-community/flash-attn3",
"kernels-community/vllm-flash-attn3",
]
if self.padding_free:
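
The membership check above gates padding-free and packed training on a flash-attention backend. One way to satisfy it is to pick the implementation when loading the model; a sketch with an illustrative checkpoint name:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",                      # illustrative checkpoint
    attn_implementation="flash_attention_2",  # any entry from the list above works
)
```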
@@ -733,13 +727,16 @@
)
if not use_flash_attention:
logger.warning(
"Padding-free training is enabled, but the attention implementation is not set to "
"'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
"'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
"other implementations may lead to unexpected behavior. To ensure compatibility, set "
"`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
"attention mechanism can handle flattened sequences."
"Padding-free training is enabled, but the attention implementation is not set to a supported "
"flash attention variant. Padding-free training flattens batches into a single sequence, and only "
"the following implementations are known to reliably support this: 'flash_attention_2', "
"'flash_attention_3', 'kernels-community/flash-attn', 'kernels-community/flash-attn3', or "
"'kernels-community/vllm-flash-attn3'. Using other implementations may lead to unexpected "
"behavior. To ensure compatibility, set `attn_implementation` in the model configuration to one "
"of these supported options or verify that your attention mechanism can handle flattened "
"sequences."
)

if args.per_device_train_batch_size == 1 and not args.packing:
logger.warning(
"You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "
@@ -777,8 +774,6 @@ def __init__(
pad_token_id=pad_token_id,
completion_only_loss=self.completion_only_loss,
padding_free=self.padding_free,
# Using position_ids without flash_attn hurts the training
return_position_ids=use_flash_attention,
pad_to_multiple_of=args.pad_to_multiple_of,
)
elif data_collator is None and self._is_vision_dataset:
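
With `return_position_ids` removed, whether the collator returns `position_ids` or `attention_mask` now follows from `padding_free` alone. A sketch of standalone construction mirroring the trainer's call (import path assumed from this file):

```python
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling

collator = DataCollatorForLanguageModeling(
    pad_token_id=0,
    completion_only_loss=True,
    padding_free=True,  # returns position_ids instead of attention_mask
    pad_to_multiple_of=8,
)
```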
@@ -792,12 +787,13 @@

if args.packing and args.packing_strategy == "bfd" and not use_flash_attention:
logger.warning(
"You are using packing, but the attention implementation is not set to 'flash_attention_2' or "
"'kernels-community/vllm-flash-attn3'. Packing flattens batches into a single sequence, and Flash "
"Attention is the only known attention mechanisms that reliably support this. Using other "
"implementations may lead to cross-contamination between batches. To avoid this, either disable "
"packing by setting `packing=False`, or set `attn_implementation='flash_attention_2'` or "
"`attn_implementation='kernels-community/vllm-flash-attn3'` in the model configuration."
"You are using packing, but the attention implementation is not set to a supported flash attention "
"variant. Packing gathers multiple samples into a single sequence, and only the following "
"implementations are known to reliably support this: 'flash_attention_2', 'flash_attention_3', "
"'kernels-community/flash-attn', 'kernels-community/flash-attn3', or "
"'kernels-community/vllm-flash-attn3'. Using other implementations may lead to cross-contamination "
"between samples. To avoid this, either disable packing by setting `packing=False`, or set "
"`attn_implementation` in the model configuration to one of these supported options."
)
if args.assistant_only_loss and not is_conversational(dataset_sample):
raise ValueError(
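
On the configuration side, BFD packing is driven by `SFTConfig`; a minimal sketch with illustrative values:

```python
from trl import SFTConfig

config = SFTConfig(
    output_dir="sft-out",           # illustrative
    packing=True,
    packing_strategy="bfd",         # best-fit decreasing, the strategy checked above
    per_device_train_batch_size=8,
)
```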