
Commit 05bc43e

ucalyptus, kashif, and lewtun authored
feat: Implement Two-Sided Clipping for GRPO Trainer (#3434)
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: lewtun <[email protected]>
1 parent d3dc8ff commit 05bc43e

File tree

tests/test_grpo_trainer.py
trl/trainer/grpo_config.py
trl/trainer/grpo_trainer.py

3 files changed: +99 -1 lines changed


tests/test_grpo_trainer.py

Lines changed: 81 additions & 0 deletions
@@ -1147,3 +1147,84 @@ def test_training_num_generations_larger_than_batch_size(self):
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    @staticmethod
+    def _make_delta_trainer(tmp_dir, tokenizer, dataset):
+        """Helper method to create a GRPOTrainer with specific delta clipping parameters."""
+        cfg = GRPOConfig(
+            output_dir=tmp_dir,
+            epsilon=0.20,
+            delta=2.0,
+            epsilon_high=0.20,
+            beta=0.0,
+            loss_type="bnpo",
+            max_completion_length=2,
+            report_to="none",
+        )
+        return GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            args=cfg,
+            reward_funcs=lambda x, **y: [1.0] * len(x),
+            train_dataset=dataset,
+            processing_class=tokenizer,
+        )
+
+    @staticmethod
+    def _delta_inputs(device):
+        """Helper method to create standard inputs for delta clipping tests."""
+        return {
+            "prompt_ids": torch.tensor([[101]], device=device),
+            "prompt_mask": torch.tensor([[1]], device=device),
+            "completion_ids": torch.tensor([[2000, 2001]], device=device),
+            "completion_mask": torch.tensor([[1, 1]], device=device),
+        }
+
+    @parameterized.expand(
+        [
+            # name, advantage, old_prob, new_prob, expected_loss
+            ("pos_ratio_in_clip", 2.0, 0.50, 0.55, -2.2),
+            ("pos_ratio_above_clip", 2.0, 0.40, 0.60, -2.4),
+            ("neg_ratio_in_clip", -2.0, 0.50, 0.45, 1.8),
+            ("neg_ratio_below_clip", -2.0, 0.50, 0.35, 1.6),
+            ("neg_ratio_above_delta", -2.0, 0.20, 0.50, 4.0),
+            ("neg_between_clip_delta", -2.0, 0.40, 0.60, 3.0),
+        ]
+    )
+    def test_two_sided_clipping_loss(self, name, advantage, old_prob, new_prob, expected_loss):
+        """Test two-sided GRPO clipping logic with different scenarios.
+
+        Args:
+            name: Test case name
+            advantage: Advantage value for the scenario
+            old_prob: Old policy probability
+            new_prob: New policy probability
+            expected_loss: Expected loss value
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+            dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train[:1]")
+
+            trainer = self._make_delta_trainer(tmp_dir, tokenizer, dataset)
+
+            inputs = self._delta_inputs(trainer.accelerator.device)
+            inputs.update(
+                {
+                    "advantages": torch.tensor([advantage], device=trainer.accelerator.device),
+                    "old_per_token_logps": torch.log(torch.tensor([[old_prob]], device=trainer.accelerator.device)),
+                }
+            )
+
+            # Mock _get_per_token_logps to return predefined new log probabilities
+            with patch.object(trainer, "_get_per_token_logps") as mock_logps_func:
+                mock_logps_func.return_value = torch.log(torch.tensor([[new_prob]], device=trainer.accelerator.device))
+
+                # Compute loss and verify
+                loss = trainer.compute_loss(trainer.model, inputs)
+                self.assertAlmostEqual(
+                    loss.item(),
+                    expected_loss,
+                    delta=1e-5,
+                    msg=f"Scenario {name} failed: expected {expected_loss}, got {loss.item()}",
+                )
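Side note (an illustration, not part of the commit): the expected_loss values in the parameterized cases above follow directly from the clipping arithmetic under the settings used in _make_delta_trainer (epsilon = epsilon_high = 0.2, delta = 2.0, beta = 0.0, per-token reduction). A minimal standalone sketch of that per-token computation:

import torch

def two_sided_per_token_loss(advantage, old_prob, new_prob, eps_low=0.2, eps_high=0.2, delta=2.0):
    # Mirrors the per-token math exercised by the test: coef_1 is the probability
    # ratio exp(new_logp - old_logp), coef_2 is the usual PPO-style clipped ratio,
    # and delta adds an extra upper bound on the unclipped term.
    coef_1 = torch.tensor(new_prob / old_prob)
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
    loss1 = torch.clamp(coef_1, max=delta) * advantage
    loss2 = coef_2 * advantage
    return -torch.min(loss1, loss2).item()

# "neg_ratio_above_delta": the ratio 0.50 / 0.20 = 2.5 is capped at delta = 2.0,
# so the loss is -(2.0 * -2.0) = 4.0, matching the expected value in the table.
print(two_sided_per_token_loss(-2.0, 0.20, 0.50))  # 4.0
print(two_sided_per_token_loss(2.0, 0.40, 0.60))   # -2.4 ("pos_ratio_above_clip")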

trl/trainer/grpo_config.py

Lines changed: 10 additions & 0 deletions
@@ -136,6 +136,8 @@ class GRPOConfig(TrainingArguments):
             Number of iterations per batch (denoted as μ in the algorithm).
         epsilon (`float`, *optional*, defaults to `0.2`):
             Epsilon value for clipping.
+        delta (`float`, *optional*, defaults to `None`):
+            Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291).
         epsilon_high (`float` or `None`, *optional*, defaults to `None`):
             Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
             specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
@@ -389,6 +391,12 @@
         default=0.2,
         metadata={"help": "Epsilon value for clipping."},
     )
+    delta: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "If set to a float value (e.g., 2.0), enables the upper clipping bound in two-sided GRPO loss. If None (default), the standard GRPO clipping is used. Recommended to be > 1 + epsilon when enabled."
+        },
+    )
     epsilon_high: Optional[float] = field(
         default=None,
         metadata={
@@ -536,3 +544,5 @@ def __post_init__(self):
                 "current global eval batch size, the valid values for the number of generations are: "
                 f"{possible_values}."
             )
+        if self.delta is not None and self.use_liger_loss:
+            raise ValueError("Liger loss does not support two-sided GRPO loss yet.")
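Usage note (an illustration, not part of the diff): with the new field, two-sided clipping is switched on simply by passing delta to GRPOConfig next to the existing epsilon options. The tiny model, dataset, and constant reward below are borrowed from the tests above; a real setup would substitute its own.

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

config = GRPOConfig(
    output_dir="grpo-two-sided",
    epsilon=0.2,       # lower clipping bound: 1 - epsilon
    epsilon_high=0.2,  # upper clipping bound: 1 + epsilon_high
    delta=2.0,         # extra cap on the unclipped ratio; recommended > 1 + epsilon
    report_to="none",
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    args=config,
    reward_funcs=lambda completions, **kwargs: [1.0] * len(completions),
    train_dataset=dataset,
)
trainer.train()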

trl/trainer/grpo_trainer.py

Lines changed: 8 additions & 1 deletion
@@ -1347,7 +1347,14 @@ def _compute_loss(self, model, inputs):
         )
         coef_1 = torch.exp(per_token_logps - old_per_token_logps)
         coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
-        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+
+        if self.args.delta is not None:
+            # Use clamp instead of min to handle tensor-float comparison
+            per_token_loss1 = torch.clamp(coef_1, max=self.args.delta) * advantages.unsqueeze(1)
+        else:
+            # Original GRPO clipping (only lower bound implicitly applied by the final min)
+            per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+
         per_token_loss2 = coef_2 * advantages.unsqueeze(1)
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if self.beta != 0.0:
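Read together with the surrounding lines, the per-token objective this computes is the two-sided form from the INTELLECT-2 report. Writing r_t for the probability ratio coef_1 and A_t for the advantage, a sketch of the resulting loss is

\ell_t = -\min\big( \min(r_t, \delta)\, A_t,\ \operatorname{clip}(r_t,\, 1-\epsilon_{\text{low}},\, 1+\epsilon_{\text{high}})\, A_t \big)

When delta is None the first argument stays r_t A_t and this reduces to the original GRPO clipped objective. With delta chosen above 1 + epsilon_high as recommended, the extra cap only bites when the advantage is negative and the ratio exceeds delta, bounding how strongly such a token can be pushed down in a single update.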
