steps_per_generation in GRPOTrainer #4200
I am looking through the GRPOTrainer code. Can anyone help me understand why steps_per_generation is involved in the calculation of repeat_count here, in the construction of the RepeatSampler?

trl/trl/trainer/grpo_trainer.py Line 695 in e086f07

It also seems to be used as a multiplier for the batch size:

trl/trl/trainer/grpo_config.py Line 650 in e086f07
trl/trl/trainer/grpo_trainer.py Line 694 in e086f07

However, my understanding is that each batch gets repeated repeat_count times in the Sampler, based on the following snippets:

Line 1752 in e086f07
Line 1759 in e086f07

So it seems to me that steps_per_generation contributes to both the batch size and the number of times the batch is repeated. If this understanding is correct, why is this? If not, what did I miss? Thanks!
Replies: 2 comments 2 replies
steps_per_generation controls how often you generate. If it's 1, it generates every step. If it's 2, it generates enough completions for 2 steps, once every 2 steps. And so on.
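To make that schedule concrete, here is a minimal sketch in plain Python. It is not TRL code: the function name generation_schedule is hypothetical, and it assumes num_iterations=1 so that only steps_per_generation decides when generation happens.

```python
def generation_schedule(total_steps, steps_per_generation):
    """Toy model: return (step, generates) pairs for a training loop.

    Assumption: a step generates new completions only when its index is a
    multiple of steps_per_generation; every other step reuses buffered ones.
    """
    return [(step, step % steps_per_generation == 0) for step in range(total_steps)]


schedule = generation_schedule(total_steps=6, steps_per_generation=2)
for step, generates in schedule:
    action = "generate, then train" if generates else "train on buffered completions"
    print(f"step {step}: {action}")
```

With steps_per_generation=2, every other step generates and the steps in between only train.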
Thank you so much for taking the time to answer! Unfortunately, I still don't quite understand the details I see in the code. I think my main point of confusion is why repeat_count is not simply equal to num_iterations (in the first snippet I cited).

Say we want steps_per_generation=2. It makes sense to me that the number of prompts per batch should therefore be multiplied by 2 (second and third snippets), and that we then divide that batch up over 2 iterations via _prepare_inputs. However, it's unclear to me why we also multiply by 2 the number of iterations for which that doubled batch of prompts gets used (first and last snippets). What is the reasoning for this? Or am I wrong that this is what is happening? Thanks again.
Yes, I get that the implementation details may not be super easy to understand. Maybe the important point here is that we don't really need the repetition in the sampling, because when the trainer doesn't need to generate, it just ignores the sampled data. See https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1007-L1008
So if you have steps_per_generation=2 then
And so on
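One way to picture this is the toy simulation below. It is a simplified model and not the actual GRPOTrainer implementation: the function simulate and its bookkeeping are hypothetical, and it assumes that the sampler yields each generation batch num_iterations * steps_per_generation times, that the trainer generates only on the first of those steps, and that it then serves one buffered slice per step.

```python
def simulate(num_prompt_batches, steps_per_generation, num_iterations=1):
    """Toy model of repeated sampling plus buffered generation.

    Returns one tuple per training step:
    (step, sampled_batch_id, sampled_batch_was_used, (batch_id, slice_index)).
    """
    repeat_count = num_iterations * steps_per_generation

    # Assumed sampler behaviour: every generation batch is yielded repeat_count times.
    sampled = [b for b in range(num_prompt_batches) for _ in range(repeat_count)]

    log = []
    buffer = None
    for step, batch_id in enumerate(sampled):
        generates = step % repeat_count == 0
        if generates:
            # "Generate" for the whole batch, then split it into one slice per step.
            buffer = [(batch_id, s) for s in range(steps_per_generation)]
        # On non-generating steps the sampled batch is ignored; only the buffer is read.
        current_slice = buffer[step % steps_per_generation]
        log.append((step, batch_id, generates, current_slice))
    return log


log = simulate(num_prompt_batches=2, steps_per_generation=2)
for step, batch_id, used, current_slice in log:
    status = "used for generation" if used else "ignored"
    print(f"step {step}: sampled batch {batch_id} ({status}), trains on slice {current_slice}")
```

In this model the repeated copies are never read on non-generating steps. They only keep the sampler aligned with the trainer, so that the batch drawn on each generating step is a new one and no prompts are skipped.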