Closed
Description
Summary
For multi-turn PPO training, we’d like more hyperparameter customization and support for new PPO features in the training implementation.
Feature 1: More Diverse KL Divergence Approximation
Allow users to choose how the KL divergence is computed in `compute_kl_divergence` (in `tunix/rl/common.py`).
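- `kl`: unbiased, high-variance forward KL estimate → `log(p) - log(q)`
- `mse_kl`: biased, low-variance → `0.5 * (log(p) - log(q))^2`
- `low_var_kl`: unbiased, low-variance control variate → `(r - 1) - log(r)`, where `r = exp(ref_logp - logp)`

Here `p` is the current policy (log-probability `logp`) and `q` is the reference policy (log-probability `ref_logp`).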
(Ref: J. Schulman’s blog: http://joschu.net/blog/kl-approx.html)
(Example: https://github.com/Yuxuan-Zhang-Dexter/tunix/blob/334e6b22b8c5bd30da9167b1004272c04107b137/tunix/rl/common.py#L102-L142)
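A minimal sketch of the three estimators, written in plain Python on a single pair of per-token log-probabilities for readability. The function name `kl_penalty` and its `kind` argument are hypothetical and are not the existing `compute_kl_divergence` signature; a JAX version would apply the same formulas elementwise with `jnp`.

```python
import math


def kl_penalty(logp: float, ref_logp: float, kind: str = "kl") -> float:
    """Per-token KL estimate between the policy (logp) and the reference (ref_logp).

    Hypothetical helper: `kind` selects one of the three proposed estimators.
    """
    log_ratio = logp - ref_logp  # log(p) - log(q)
    if kind == "kl":
        # Unbiased but high-variance.
        return log_ratio
    if kind == "mse_kl":
        # Biased but low-variance; always >= 0.
        return 0.5 * log_ratio ** 2
    if kind == "low_var_kl":
        # Unbiased control-variate estimate with r = exp(ref_logp - logp); always >= 0.
        r = math.exp(ref_logp - logp)
        return (r - 1.0) - math.log(r)
    raise ValueError(f"unknown KL estimator: {kind!r}")
```

For example, `kl_penalty(-1.0, -2.0, "mse_kl")` gives `0.5`. When the policy and reference agree (`logp == ref_logp`), all three estimators return 0.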
Feature 2: Entropy in Policy Loss
- Add entropy regularization to the policy loss to encourage exploration, controlled by an entropy coefficient hyperparameter.
- Return the logits so that the entropy can be computed.
(Example: https://github.com/lmgame-org/GRL/blob/4ab7250210a9c2ac6c156a33ea04a1e848d878d8/grl/trainer/tunix_agent_trainer_exp.py#L898-L1030)
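A minimal sketch of entropy regularization from returned logits. It assumes the policy loss is minimized, so the entropy bonus is subtracted, and it uses a plain (unmasked) mean over tokens for brevity; `entropy_from_logits`, `policy_loss_with_entropy`, and `entropy_coeff` are hypothetical names.

```python
import math
from typing import Sequence


def entropy_from_logits(logits: Sequence[float]) -> float:
    """Entropy (in nats) of the categorical distribution softmax(logits)."""
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    log_z = m + math.log(z)  # logsumexp(logits)
    # H = -sum_i p_i * log p_i, with log p_i = logit_i - logsumexp(logits)
    return -sum((e / z) * (x - log_z) for e, x in zip(exps, logits))


def policy_loss_with_entropy(
    pg_loss: float,
    token_logits: Sequence[Sequence[float]],
    entropy_coeff: float,
) -> float:
    """Subtract an entropy bonus so that minimizing the loss rewards exploration."""
    mean_entropy = sum(entropy_from_logits(t) for t in token_logits) / len(token_logits)
    return pg_loss - entropy_coeff * mean_entropy
```

A uniform distribution over 4 tokens has entropy `log(4)`; a sharply peaked one has entropy near 0, so a larger `entropy_coeff` pushes the policy toward flatter token distributions. In the real loss the mean would normally be taken over completion tokens only, using the completion mask.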
Feature 3: Asymmetric & Dual Clipping in Policy Loss
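- Support asymmetric clipping (`epsilon_low` / `epsilon_high`) instead of symmetric clipping only.
- Add an optional “dual-clipping” parameter (e.g., `epsilon_c`) that bounds the policy loss for negative-advantage tokens, so that a very large importance ratio cannot produce an arbitrarily large update.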
(Example: https://github.com/lmgame-org/GRL/blob/4ab7250210a9c2ac6c156a33ea04a1e848d878d8/grl/trainer/tunix_agent_trainer_exp.py#L968-L998)
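A per-token sketch of asymmetric clipping plus dual clipping, written as a loss to minimize. It assumes `epsilon_c` is used as the dual-clip constant c (a ratio bound greater than 1, e.g. 3.0) rather than as an offset from 1; the function name `ppo_token_loss` is hypothetical.

```python
from typing import Optional


def ppo_token_loss(
    ratio: float,
    advantage: float,
    epsilon_low: float = 0.2,
    epsilon_high: float = 0.2,
    epsilon_c: Optional[float] = None,
) -> float:
    """Clipped PPO loss for one token; ratio = exp(logp - old_logp)."""
    if epsilon_c is not None and epsilon_c <= 1.0:
        raise ValueError("dual-clip constant epsilon_c must be > 1")
    unclipped = -advantage * ratio
    # Asymmetric clip range: [1 - epsilon_low, 1 + epsilon_high].
    clipped_ratio = min(max(ratio, 1.0 - epsilon_low), 1.0 + epsilon_high)
    clipped = -advantage * clipped_ratio
    # Standard PPO: the pessimistic (larger) of the two losses.
    loss = max(unclipped, clipped)
    if epsilon_c is not None and advantage < 0:
        # Dual clip: for negative advantages, cap the loss at -advantage * c.
        loss = min(loss, -advantage * epsilon_c)
    return loss
```

With `advantage = -1.0` and `ratio = 10.0`, plain PPO yields a loss of `10.0`, while `epsilon_c=3.0` caps it at `3.0`. Setting `epsilon_high > epsilon_low` lets positive-advantage tokens raise their probability further before being clipped.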
Feature 4: Support Custom `completion_mask` in `process_ids`
In `tunix/rl/common.py`, the `process_ids` function always calls `make_completion_mask`, which generates a new mask based on the first `eos` token. This discards any pre-built completion mask passed in.
Multi-turn PPO training uses full trajectories that often contain multiple `eos` tokens. Because the default mask stops at the first `eos`, every later turn is masked out, which effectively truncates the trajectory and discards the multi-turn context.
- Add an optional `completion_mask` parameter to `process_ids` (and to all functions that rely on it).
- Use the passed-in mask if one is provided; call `make_completion_mask` only when none is provided.
(Example: https://github.com/Yuxuan-Zhang-Dexter/tunix/blob/ppo_feature_dev/tunix/rl/common.py)
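A simplified sketch of the proposed fallback logic on plain token lists. `first_eos_mask` is a hypothetical stand-in for `make_completion_mask` (assumed here to keep tokens up to and including the first `eos`), and `resolve_completion_mask` is a hypothetical helper, not the tunix API.

```python
from typing import List, Optional, Sequence


def first_eos_mask(completion_ids: Sequence[int], eos_id: int) -> List[int]:
    """Hypothetical default: 1 for tokens up to and including the first eos, 0 after."""
    mask = []
    seen_eos = False
    for tok in completion_ids:
        mask.append(0 if seen_eos else 1)
        if tok == eos_id:
            seen_eos = True
    return mask


def resolve_completion_mask(
    completion_ids: Sequence[int],
    eos_id: int,
    completion_mask: Optional[Sequence[int]] = None,
) -> List[int]:
    """Prefer a caller-provided mask; fall back to the first-eos mask otherwise."""
    if completion_mask is not None:
        if len(completion_mask) != len(completion_ids):
            raise ValueError("completion_mask must match completion_ids in length")
        return list(completion_mask)
    return first_eos_mask(completion_ids, eos_id)
```

For a two-turn trajectory `[5, 6, 2, 7, 8, 2, 0]` with `eos_id=2`, the default mask is `[1, 1, 1, 0, 0, 0, 0]`, so the second turn is lost. A pre-built mask such as `[1, 1, 1, 0, 1, 1, 0]` (here token `7` treated as a non-trainable environment token) is kept as-is.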
Feature 5: Add More Metrics and Organize Logs by Category in W&B
- Log richer metrics:
  - `old_value` → min, mean, max
  - `advantage` → min, mean, max
  - `return` → min, mean, max
  - `entropy`, `vf_loss`, etc.
- Organize metrics into logical groups, e.g.:
  - `train/old_value/min`, `train/old_value/mean`, `train/old_value/max`
  - `train/advantage/mean`, `train/advantage/min`, `train/advantage/max`
  - `eval/return/mean`, `actor/entropy`, `critic/vf_loss`, etc.
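A sketch of building grouped metric keys from batch statistics; the helper names `summarize` and `build_train_metrics` are hypothetical. The resulting dict could be passed to `wandb.log`, where slash-separated key prefixes (`train/`, `actor/`, `critic/`) show up as separate sections in the W&B UI.

```python
from typing import Dict, Sequence


def summarize(prefix: str, values: Sequence[float]) -> Dict[str, float]:
    """Return {prefix/min, prefix/mean, prefix/max} for a batch of values."""
    return {
        f"{prefix}/min": min(values),
        f"{prefix}/mean": sum(values) / len(values),
        f"{prefix}/max": max(values),
    }


def build_train_metrics(
    old_values: Sequence[float],
    advantages: Sequence[float],
    entropy: float,
    vf_loss: float,
) -> Dict[str, float]:
    """Collect per-step training metrics under category prefixes."""
    metrics: Dict[str, float] = {}
    metrics.update(summarize("train/old_value", old_values))
    metrics.update(summarize("train/advantage", advantages))
    metrics["actor/entropy"] = entropy
    metrics["critic/vf_loss"] = vf_loss
    return metrics
```

Keeping the category as the first path component means new metrics (for example `eval/return/mean` from an evaluation loop) land in the right group without any extra dashboard configuration.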