
Commit 816ac61

🪪 Update SFTTrainer to handle labels correctly and add configuration example in paper index (#4051)
Parent: 373a64a

2 files changed: +20 -2 lines


‎docs/source/paper_index.md‎

Lines changed: 13 additions & 0 deletions

@@ -470,6 +470,19 @@ training_args = SFTConfig(
 )
 ```
 
+To closely match the paper’s setup, you can use the following configuration (see Sec. 4.1). The authors also mention that the hyperparameters are not very sensitive (Sec. 4.3):
+
+```python
+SFTConfig(
+    loss_type="dft",
+    learning_rate=5e-5,
+    max_length=2048,
+    # Target batch size 256; achieved via per-device batch 8 * grad accumulation 32
+    per_device_train_batch_size=8,
+    gradient_accumulation_steps=32,
+)
+```
+
 ## Reinforce Leave-One-Out
 
 Papers relating to the [`RLOOTrainer`]
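
For context, here is a minimal sketch of how the configuration added above could be dropped into a training run with TRL's `SFTTrainer`. The model checkpoint and dataset (`Qwen/Qwen2.5-0.5B`, `trl-lib/Capybara`) are illustrative placeholders, not part of this commit:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Placeholder dataset for illustration only; swap in the paper's actual data.
dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(
    loss_type="dft",
    learning_rate=5e-5,
    max_length=2048,
    # Target batch size 256; achieved via per-device batch 8 * grad accumulation 32
    per_device_train_batch_size=8,
    gradient_accumulation_steps=32,
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",  # placeholder model id
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```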

‎trl/trainer/sft_trainer.py‎

Lines changed: 7 additions & 2 deletions

@@ -1093,6 +1093,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         Compute training loss and additionally compute token accuracies
         """
         mode = "train" if self.model.training else "eval"
+
+        # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used.
+        # This can be removed when this issue is fixed.
+        labels = inputs["labels"]
+
         # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
         inputs["use_cache"] = False
         (loss, outputs) = super().compute_loss(
@@ -1137,7 +1142,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
 
         # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
-        if "labels" in inputs and not self.args.use_liger_kernel:
+        if not self.args.use_liger_kernel:
             with torch.no_grad():
                 if "shift_labels" in inputs:
                     # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
@@ -1147,7 +1152,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
                     shift_labels = inputs["shift_labels"]
                 else:
                     shift_logits = outputs.logits[..., :-1, :].contiguous()
-                    shift_labels = inputs["labels"][..., 1:].contiguous()
+                    shift_labels = labels[..., 1:].contiguous()
 
                 # When using Prompt Tuning, skip the virtual tokens in logits before accuracy computation, since they do
                 # not correspond to actual input labels.
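
The token-accuracy code that consumes the set-aside `labels` is only partially visible in the hunks above. As a rough, self-contained sketch of the underlying idea (ignoring context parallelism, Liger kernels, and prompt tuning, and using a hypothetical helper name), the shift-by-one accuracy computation looks like this:

```python
import torch


def next_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
    """Sketch: fraction of positions where argmax(logits) predicts the next token.

    Mirrors the shift pattern in the diff: logits at position t are compared against
    the label at position t + 1, and masked labels (ignore_index) are skipped.
    """
    shift_logits = logits[..., :-1, :].contiguous()  # drop the prediction for the last position
    shift_labels = labels[..., 1:].contiguous()      # drop the first label

    predictions = shift_logits.argmax(dim=-1)
    mask = shift_labels != ignore_index              # ignore prompt/padding tokens

    correct = ((predictions == shift_labels) & mask).sum()
    total = mask.sum()
    return (correct / total).item() if total > 0 else 0.0


# Tiny usage example: batch of 2 sequences, length 5, vocab size 10
logits = torch.randn(2, 5, 10)
labels = torch.randint(0, 10, (2, 5))
labels[:, :2] = -100  # e.g. prompt tokens masked out of the loss
print(next_token_accuracy(logits, labels))
```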
