Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions tests/test_gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import os

import pytest
import torch
import torch.nn.functional as F
from datasets import load_dataset
Expand Down Expand Up @@ -247,7 +246,6 @@ def test_gkd_trainer(self):
self.assertIn("model.safetensors", os.listdir(self.tmp_dir + "/checkpoint-2"))

@require_liger_kernel
@pytest.mark.xfail(reason="Computing the Liger loss spikes GPU memory usage, causing the test to run OOM.")
def test_gkd_trainer_with_liger(self):
training_args = GKDConfig(
output_dir=self.tmp_dir,
Expand Down
15 changes: 11 additions & 4 deletions trl/trainer/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
student_outputs = base_student(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
output_hidden_states=True,
use_cache=False,
)

Expand All @@ -316,13 +315,15 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
teacher_outputs = base_teacher(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
output_hidden_states=True,
use_cache=False,
)

# hidden states (shifted)
student_hidden = student_outputs.last_hidden_state[:, :-1].contiguous()
teacher_hidden = teacher_outputs.last_hidden_state[:, :-1].contiguous()
student_hidden = student_outputs.last_hidden_state[:, :-1]
teacher_hidden = teacher_outputs.last_hidden_state[:, :-1]

# Release full outputs to free memory
del student_outputs, teacher_outputs

# labels mask and labels (shifted)
labels_mask = inputs["labels"] != -100
Expand All @@ -331,6 +332,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
)
true_labels = masked_input_ids[:, 1:].contiguous()

# Release intermediate tensors
del labels_mask, masked_input_ids

# heads
student_head = unwrapped_student.get_output_embeddings()
teacher_head = unwrapped_teacher.get_output_embeddings()
Expand All @@ -345,6 +349,9 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
student_bias=getattr(student_head, "bias", None),
teacher_bias=getattr(teacher_head, "bias", None),
)

# Release hidden states after loss computation
del student_hidden, teacher_hidden, true_labels
else:
# compute student output
student_outputs = model(
Expand Down
Loading