
Commit d3dc8ff

yao-matrix and kashif authored
use device agnostic empty_cache in ppo & rloo (#3439)
Signed-off-by: Matrix Yao <[email protected]> Co-authored-by: Kashif Rasul <[email protected]>
1 parent 21738c3 commit d3dc8ff

File tree

2 files changed, +16 -15 lines

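The change routes every cache-clearing call through a single device-agnostic helper instead of hard-coding torch.cuda.empty_cache(), so the same trainers can release allocator memory on non-CUDA accelerators as well. The real utility lives in trl/trainer/utils.py; the sketch below is only an illustration of what such a dispatcher could look like, assuming CUDA, Intel XPU, and Apple MPS backends, and is not the library's verbatim implementation.

import torch


def empty_cache() -> None:
    # Release cached allocator blocks on whichever accelerator backend is active.
    # Hypothetical sketch: TRL's own helper in trl/trainer/utils.py may differ.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()  # Intel GPUs
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()  # Apple Silicon
    # Nothing to clear on CPU-only setups.

Call sites then simply swap torch.cuda.empty_cache() for empty_cache(), as the diffs below show.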

trl/trainer/ppo_trainer.py

Lines changed: 9 additions & 8 deletions
@@ -54,6 +54,7 @@
     OnlineTrainerState,
     batch_generation,
     disable_dropout_in_model,
+    empty_cache,
     exact_div,
     first_true_indices,
     forward,
@@ -437,7 +438,7 @@ def repeat_generator():
                     logits = logitss[i : i + args.local_rollout_forward_batch_size]
                     logprob = selective_log_softmax(logits, response)
                     del logits
-                    torch.cuda.empty_cache()
+                    empty_cache()
 
                     if ref_policy is None:
                         with self.null_ref_context():
@@ -448,7 +449,7 @@ def repeat_generator():
                     ref_logits /= args.temperature + 1e-7
                     ref_logprob = selective_log_softmax(ref_logits, response)
                     del ref_output, ref_logits
-                    torch.cuda.empty_cache()
+                    empty_cache()
 
                     # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                     postprocessed_response = response
@@ -484,7 +485,7 @@ def repeat_generator():
                 scores = torch.cat(scores, 0)
                 values = torch.cat(values, 0)
                 del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
-                torch.cuda.empty_cache()
+                empty_cache()
                 gc.collect()
 
                 # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
@@ -531,7 +532,7 @@ def repeat_generator():
                 returns = advantages + values
                 advantages = masked_whiten(advantages, ~padding_mask)
                 advantages = torch.masked_fill(advantages, padding_mask, 0)
-                torch.cuda.empty_cache()
+                empty_cache()
 
             # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
             for ppo_epoch_idx in range(args.num_ppo_epochs):
@@ -612,7 +613,7 @@ def repeat_generator():
                                 mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                             )
                             # fmt: on
-                            torch.cuda.empty_cache()
+                            empty_cache()
             with torch.no_grad():
                 mean_kl = kl.sum(1).mean()
                 mean_entropy = (-logprobs).sum(1).mean()
@@ -649,12 +650,12 @@ def repeat_generator():
                 self._save_checkpoint(model, trial=None)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
             del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
-            torch.cuda.empty_cache()
+            empty_cache()
             gc.collect()
 
             if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
                 self.generate_completions(sampling=True)
-                torch.cuda.empty_cache()
+                empty_cache()
             del (
                 query_responses,
                 responses,
@@ -674,7 +675,7 @@ def repeat_generator():
                 advantages,
                 returns,
             )
-            torch.cuda.empty_cache()
+            empty_cache()
 
         # HF trainer specifics
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)

trl/trainer/rloo_trainer.py

Lines changed: 7 additions & 7 deletions
@@ -60,7 +60,7 @@
     truncate_response,
 )
 from .rloo_config import RLOOConfig
-from .utils import generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment
+from .utils import empty_cache, generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment
 
 
 if is_wandb_available():
@@ -333,14 +333,14 @@ def repeat_generator():
                     logits = logitss[i : i + args.local_rollout_forward_batch_size]
                     logprob = selective_log_softmax(logits, response)
                     del logits
-                    torch.cuda.empty_cache()
+                    empty_cache()
 
                     ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
                     ref_logits = ref_output.logits[:, context_length - 1 : -1]
                     ref_logits /= args.temperature + 1e-7
                     ref_logprob = selective_log_softmax(ref_logits, response)
                     del ref_output, ref_logits
-                    torch.cuda.empty_cache()
+                    empty_cache()
 
                     # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                     postprocessed_response = response
@@ -381,7 +381,7 @@ def repeat_generator():
                 sequence_lengths = torch.cat(sequence_lengths, 0)
                 scores = torch.cat(scores, 0)
                 del (logprob, ref_logprob, score)
-                torch.cuda.empty_cache()
+                empty_cache()
                 gc.collect()
 
                 # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
@@ -439,7 +439,7 @@ def repeat_generator():
                 if args.normalize_advantage:
                     advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
 
-                torch.cuda.empty_cache()
+                empty_cache()
 
             # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
             for ppo_epoch_idx in range(args.num_ppo_epochs):
@@ -515,7 +515,7 @@ def repeat_generator():
                         mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
                     )
                     # fmt: on
-                    torch.cuda.empty_cache()
+                    empty_cache()
 
             # Compute metrics
             with torch.no_grad():
@@ -552,7 +552,7 @@ def repeat_generator():
             if self.control.should_save:
                 self._save_checkpoint(model, trial=None)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-            torch.cuda.empty_cache()
+            empty_cache()
             gc.collect()
 
             if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
