     OnlineTrainerState,
     batch_generation,
     disable_dropout_in_model,
+    empty_cache,
     exact_div,
     first_true_indices,
     forward,
@@ -437,7 +438,7 @@ def repeat_generator():
                     logits = logitss[i : i + args.local_rollout_forward_batch_size]
                     logprob = selective_log_softmax(logits, response)
                     del logits
-                    torch.cuda.empty_cache()
+                    empty_cache()

                     if ref_policy is None:
                         with self.null_ref_context():
@@ -448,7 +449,7 @@ def repeat_generator():
                     ref_logits /= args.temperature + 1e-7
                     ref_logprob = selective_log_softmax(ref_logits, response)
                     del ref_output, ref_logits
-                    torch.cuda.empty_cache()
+                    empty_cache()

                     # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
                     postprocessed_response = response
@@ -484,7 +485,7 @@ def repeat_generator():
                 scores = torch.cat(scores, 0)
                 values = torch.cat(values, 0)
                 del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
-                torch.cuda.empty_cache()
+                empty_cache()
                 gc.collect()

                 # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
@@ -531,7 +532,7 @@ def repeat_generator():
                 returns = advantages + values
                 advantages = masked_whiten(advantages, ~padding_mask)
                 advantages = torch.masked_fill(advantages, padding_mask, 0)
-                torch.cuda.empty_cache()
+                empty_cache()

             # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
             for ppo_epoch_idx in range(args.num_ppo_epochs):
@@ -612,7 +613,7 @@ def repeat_generator():
                                 mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                             )
                             # fmt: on
-                            torch.cuda.empty_cache()
+                            empty_cache()
             with torch.no_grad():
                 mean_kl = kl.sum(1).mean()
                 mean_entropy = (-logprobs).sum(1).mean()
@@ -649,12 +650,12 @@ def repeat_generator():
                 self._save_checkpoint(model, trial=None)
                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
             del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
-            torch.cuda.empty_cache()
+            empty_cache()
             gc.collect()

             if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
                 self.generate_completions(sampling=True)
-                torch.cuda.empty_cache()
+                empty_cache()
             del (
                 query_responses,
                 responses,
@@ -674,7 +675,7 @@ def repeat_generator():
                 advantages,
                 returns,
             )
-            torch.cuda.empty_cache()
+            empty_cache()

         # HF trainer specifics
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
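
Note on the new helper: this diff replaces every CUDA-only `torch.cuda.empty_cache()` call with an `empty_cache()` utility imported at the top, whose definition is not shown in the hunks above. As a rough sketch of what such a device-agnostic helper could look like (hypothetical; the actual TRL utility may differ):

import torch


def empty_cache() -> None:
    # Hypothetical device-agnostic sketch; the actual TRL utility may differ.
    # Releases cached accelerator memory on whichever backend is present,
    # so the trainer code above can stay free of torch.cuda-specific calls.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif hasattr(torch, "mps") and torch.backends.mps.is_available():
        torch.mps.empty_cache()

Centralizing the call behind one helper lets the same PPO training loop release cache memory on CUDA, Intel XPU, or Apple MPS devices, and degrade to a no-op on CPU-only machines.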