diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index b74c0b05..87676893 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -427,7 +427,6 @@ async def train(  # type: ignore[override]
         *,
         # Core training parameters
         learning_rate: float = 5e-6,
-        beta: float = 0.0,
         # KL-penalized advantage adjustment
         kl_penalty_coef: float = 0.0,
         kl_penalty_reference_step: int | None = None,
@@ -470,7 +469,6 @@
             model: The trainable model to train.
             trajectory_groups: Batches of trajectories to train on.
             learning_rate: Learning rate for training. Defaults to 5e-6.
-            beta: KL penalty coefficient added to the loss. Defaults to 0.0.
             kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
                 Tokens diverging more from the reference get reduced advantages.
                 Defaults to 0.0 (disabled).
@@ -527,7 +525,7 @@
 
         # Build config objects from explicit kwargs
         config = TrainConfig(
-            learning_rate=learning_rate, beta=beta, kl_penalty_coef=kl_penalty_coef
+            learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
         )
         dev_config: dev.TrainConfig = {
             "advantage_balance": advantage_balance,
diff --git a/src/art/loss.py b/src/art/loss.py
index a22cca3f..5a73d7b7 100644
--- a/src/art/loss.py
+++ b/src/art/loss.py
@@ -14,7 +14,6 @@ class Loss(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     mean_policy_loss: torch.Tensor
-    mean_kl: torch.Tensor
     mean_entropy: torch.Tensor | None
     policy_loss_sum: torch.Tensor
     probs_corr: torch.Tensor
@@ -124,16 +123,8 @@
     logprob_diff = old_logprobs - original_logprobs
     prob_ratio = torch.exp(logprob_diff)
     policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach()
-    if ref_logprobs is not None:
-        kl_div = (
-            torch.exp(ref_logprobs - new_logprobs) - (ref_logprobs - new_logprobs) - 1.0
-        )
-    else:
-        kl_div = torch.zeros_like(policy_loss)
     policy_loss = policy_loss * weights * assistant_mask
-    kl_div = kl_div * weights * assistant_mask
     mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
-    mean_kl = kl_div.sum() / (assistant_mask.sum() + 1e-6)
     # Compute mean entropy for the current step
     if entropies is not None:
         shifted_entropies = shift_tensor(entropies, 0.0)
@@ -144,7 +135,6 @@
         mean_entropy = None
     return Loss(
         mean_policy_loss=mean_policy_loss,
-        mean_kl=mean_kl,
         mean_entropy=mean_entropy,
         policy_loss_sum=policy_loss.sum(),
         probs_corr=probs_corr,
diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py
index 480a03be..85c36d1f 100644
--- a/src/art/megatron/train.py
+++ b/src/art/megatron/train.py
@@ -250,7 +250,7 @@ def print0(*values: Any) -> None:
             )
             probs_corr = loss.probs_corr.item()
             print0("Correlation between old and new probabilities:", probs_corr)
-            loss = loss.mean_policy_loss + config.beta * loss.mean_kl
+            loss = loss.mean_policy_loss
             loss.backward()
             # Reduce LoRA grads
             start = time.perf_counter()
diff --git a/src/art/preprocessing/inputs.py b/src/art/preprocessing/inputs.py
index 9e5a7a54..cd4d40be 100644
--- a/src/art/preprocessing/inputs.py
+++ b/src/art/preprocessing/inputs.py
@@ -41,9 +41,7 @@ def create_train_inputs(
             [None] if warmup else packed_tensors["image_grid_thw"][offset : offset + 1]
         ),
         config=(
-            config.model_copy(
-                update={"learning_rate": 1e-9, "beta": 0.0, "kl_penalty_coef": 0.0}
-            )
+            config.model_copy(update={"learning_rate": 1e-9, "kl_penalty_coef": 0.0})
             if warmup
             else config
         ),
diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py
index abf67f69..dea0198e 100644
--- a/src/art/serverless/backend.py
+++ b/src/art/serverless/backend.py
@@ -149,7 +149,6 @@ async def train(  # type: ignore[override]
         *,
         # Core training parameters
         learning_rate: float = 5e-6,
-        beta: float = 0.0,
         # RL algorithm settings
         ppo: bool = False,
         epsilon: float | None = None,
@@ -179,7 +178,6 @@
             model: The trainable model to train.
             trajectory_groups: Batches of trajectories to train on.
             learning_rate: Learning rate for training. Defaults to 5e-6.
-            beta: KL penalty coefficient. Defaults to 0.0.
             ppo: Whether to use PPO clipping. Defaults to False.
             epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
             epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
@@ -212,7 +210,7 @@
         groups_list = list(trajectory_groups)
 
         # Build config objects from explicit kwargs
-        config = TrainConfig(learning_rate=learning_rate, beta=beta)
+        config = TrainConfig(learning_rate=learning_rate)
         dev_config: dev.TrainConfig = {
             "advantage_balance": advantage_balance,
             "importance_sampling_level": importance_sampling_level,
diff --git a/src/art/types.py b/src/art/types.py
index 017f05c7..088041ad 100644
--- a/src/art/types.py
+++ b/src/art/types.py
@@ -16,7 +16,6 @@
 
 class TrainConfig(pydantic.BaseModel):
     learning_rate: float = 5e-6
-    beta: float = 0.0
     kl_penalty_coef: float = 0.0
 
 
diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py
index 34dbc5cd..fcb7e287 100644
--- a/src/art/unsloth/train.py
+++ b/src/art/unsloth/train.py
@@ -138,7 +138,7 @@ def compute_loss(
         )
         if return_new_logprobs:
             return torch.nn.functional.pad(new_logprobs[:, :-1], (1, 0), value=0.0)
-        if config.beta > 0.0 or config.kl_penalty_coef > 0.0:
+        if config.kl_penalty_coef > 0.0:
             ref_adapter = _config.get("kl_ref_adapter_path")
             ref_logprobs, _ = calculate_logprobs(
                 dtype_for_autocasting,
@@ -173,11 +173,9 @@
         trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item())
         if loss.mean_entropy is not None:
             trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())
-        if config.beta > 0.0:
-            trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item())
         if loss.kl_policy_ref is not None:
             trainer._metrics["train"]["kl_policy_ref"].append(loss.kl_policy_ref.item())
-        return loss.mean_policy_loss + config.beta * loss.mean_kl
+        return loss.mean_policy_loss
 
     return compute_loss
 
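Reviewer note: the branch deleted from loss_fn computed the k3-style KL estimator, exp(r) - r - 1.0 with r = ref_logprobs - new_logprobs, masked and normalized exactly like the policy loss. Below is a standalone sketch of the removed quantity, useful for a before/after parity check; tensor names mirror loss_fn, and the [batch, seq] shapes are an assumption.

import torch

def mean_kl_removed(
    ref_logprobs: torch.Tensor,    # reference-policy logprobs, assumed [batch, seq]
    new_logprobs: torch.Tensor,    # current-policy logprobs, same shape
    weights: torch.Tensor,         # per-token loss weights, same shape
    assistant_mask: torch.Tensor,  # 1.0 on assistant tokens, 0.0 elsewhere
) -> torch.Tensor:
    # k3 estimator: exp(r) - r - 1 >= 0 for all r, exactly 0 where the policies agree.
    r = ref_logprobs - new_logprobs
    kl_div = torch.exp(r) - r - 1.0
    # Same masking and mean normalization the deleted lines applied.
    kl_div = kl_div * weights * assistant_mask
    return kl_div.sum() / (assistant_mask.sum() + 1e-6)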
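Migration note: beta was keyword-only on both train() signatures, so any caller still passing it will now fail with a TypeError (unexpected keyword argument), and TrainConfig simply no longer carries the field. A minimal before/after sketch for code constructing TrainConfig directly; the values are illustrative, not recommendations:

import pydantic

# TrainConfig as it stands after this patch (mirrors src/art/types.py).
class TrainConfig(pydantic.BaseModel):
    learning_rate: float = 5e-6
    kl_penalty_coef: float = 0.0

# Before: TrainConfig(learning_rate=5e-6, beta=0.04, kl_penalty_coef=0.1)
# After:  drop beta; KL regularization now goes through kl_penalty_coef alone
# (the KL-penalized advantage adjustment), not a loss-level penalty term.
config = TrainConfig(learning_rate=5e-6, kl_penalty_coef=0.1)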