diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py
index 4a0d458440c..b7de475c699 100644
--- a/tests/test_gkd_trainer.py
+++ b/tests/test_gkd_trainer.py
@@ -14,7 +14,6 @@
 
 import os
 
-import pytest
 import torch
 import torch.nn.functional as F
 from datasets import load_dataset
@@ -247,7 +246,6 @@ def test_gkd_trainer(self):
         self.assertIn("model.safetensors", os.listdir(self.tmp_dir + "/checkpoint-2"))
 
     @require_liger_kernel
-    @pytest.mark.xfail(reason="Computing the Liger loss spikes GPU memory usage, causing the test to run OOM.")
     def test_gkd_trainer_with_liger(self):
         training_args = GKDConfig(
             output_dir=self.tmp_dir,
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 72844a2cf72..b618d789616 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -300,7 +300,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             student_outputs = base_student(
                 input_ids=inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
-                output_hidden_states=True,
                 use_cache=False,
             )
 
@@ -316,13 +315,15 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
                 teacher_outputs = base_teacher(
                     input_ids=inputs["input_ids"],
                     attention_mask=inputs["attention_mask"],
-                    output_hidden_states=True,
                     use_cache=False,
                 )
 
             # hidden states (shifted)
-            student_hidden = student_outputs.last_hidden_state[:, :-1].contiguous()
-            teacher_hidden = teacher_outputs.last_hidden_state[:, :-1].contiguous()
+            student_hidden = student_outputs.last_hidden_state[:, :-1]
+            teacher_hidden = teacher_outputs.last_hidden_state[:, :-1]
+
+            # Release full outputs to free memory
+            del student_outputs, teacher_outputs
 
             # labels mask and labels (shifted)
             labels_mask = inputs["labels"] != -100
@@ -331,6 +332,9 @@
             )
             true_labels = masked_input_ids[:, 1:].contiguous()
 
+            # Release intermediate tensors
+            del labels_mask, masked_input_ids
+
             # heads
             student_head = unwrapped_student.get_output_embeddings()
             teacher_head = unwrapped_teacher.get_output_embeddings()
@@ -345,6 +349,9 @@
                 student_bias=getattr(student_head, "bias", None),
                 teacher_bias=getattr(teacher_head, "bias", None),
             )
+
+            # Release hidden states after loss computation
+            del student_hidden, teacher_hidden, true_labels
         else:
             # compute student output
             student_outputs = model(
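
The patch lowers peak memory in two ways: the student and teacher forwards no longer pass output_hidden_states=True (only last_hidden_state is consumed, so there is no need to retain every layer's hidden states), and large intermediates are dropped with del as soon as they are no longer needed, letting PyTorch's caching allocator reuse their blocks for subsequent allocations. The standalone sketch below, which is not part of the patch, illustrates the second pattern; peak_during, without_del, with_del, and the buffer size N are illustrative names and values only.

import torch

def peak_during(fn):
    # Run fn and report peak CUDA memory in MiB (NaN on CPU-only machines,
    # where the script still runs but nothing meaningful is measured).
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        fn()
        return torch.cuda.max_memory_allocated() / 2**20
    fn()
    return float("nan")

device = "cuda" if torch.cuda.is_available() else "cpu"
N = 4096  # each N x N float32 buffer is ~64 MiB

def without_del():
    a = torch.randn(N, N, device=device)
    b = a * 2  # second buffer
    c = b * 2  # third buffer allocated while a and b are still referenced
    return c.sum()

def with_del():
    a = torch.randn(N, N, device=device)
    b = a * 2
    del a      # a's block goes back to the caching allocator...
    c = b * 2  # ...and can be reused here, so only two buffers are ever live
    del b
    return c.sum()

print("peak without del (MiB):", peak_during(without_del))
print("peak with del    (MiB):", peak_during(with_del))

Note that del only drops a Python reference; the underlying memory is freed once no other reference (a live view, or activations saved for backward) remains, which is why the patch deletes each name only after its last use.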