Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,6 @@ async def train( # type: ignore[override]
*,
# Core training parameters
learning_rate: float = 5e-6,
beta: float = 0.0,
# KL-penalized advantage adjustment
kl_penalty_coef: float = 0.0,
kl_penalty_reference_step: int | None = None,
Expand Down Expand Up @@ -470,7 +469,6 @@ async def train( # type: ignore[override]
model: The trainable model to train.
trajectory_groups: Batches of trajectories to train on.
learning_rate: Learning rate for training. Defaults to 5e-6.
beta: KL penalty coefficient added to the loss. Defaults to 0.0.
kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
Tokens diverging more from the reference get reduced advantages.
Defaults to 0.0 (disabled).
Expand Down Expand Up @@ -527,7 +525,7 @@ async def train( # type: ignore[override]

# Build config objects from explicit kwargs
config = TrainConfig(
learning_rate=learning_rate, beta=beta, kl_penalty_coef=kl_penalty_coef
learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
Expand Down
10 changes: 0 additions & 10 deletions src/art/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
class Loss(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
mean_policy_loss: torch.Tensor
mean_kl: torch.Tensor
mean_entropy: torch.Tensor | None
policy_loss_sum: torch.Tensor
probs_corr: torch.Tensor
Expand Down Expand Up @@ -124,16 +123,8 @@ def loss_fn(
logprob_diff = old_logprobs - original_logprobs
prob_ratio = torch.exp(logprob_diff)
policy_loss *= torch.clamp(prob_ratio, max=upper_bound).detach()
if ref_logprobs is not None:
kl_div = (
torch.exp(ref_logprobs - new_logprobs) - (ref_logprobs - new_logprobs) - 1.0
)
else:
kl_div = torch.zeros_like(policy_loss)
policy_loss = policy_loss * weights * assistant_mask
kl_div = kl_div * weights * assistant_mask
mean_policy_loss = policy_loss.sum() / (assistant_mask.sum() + 1e-6)
mean_kl = kl_div.sum() / (assistant_mask.sum() + 1e-6)
# Compute mean entropy for the current step
if entropies is not None:
shifted_entropies = shift_tensor(entropies, 0.0)
Expand All @@ -144,7 +135,6 @@ def loss_fn(
mean_entropy = None
return Loss(
mean_policy_loss=mean_policy_loss,
mean_kl=mean_kl,
mean_entropy=mean_entropy,
policy_loss_sum=policy_loss.sum(),
probs_corr=probs_corr,
Expand Down
2 changes: 1 addition & 1 deletion src/art/megatron/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def print0(*values: Any) -> None:
)
probs_corr = loss.probs_corr.item()
print0("Correlation between old and new probabilities:", probs_corr)
loss = loss.mean_policy_loss + config.beta * loss.mean_kl
loss = loss.mean_policy_loss
loss.backward()
# Reduce LoRA grads
start = time.perf_counter()
Expand Down
4 changes: 1 addition & 3 deletions src/art/preprocessing/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def create_train_inputs(
[None] if warmup else packed_tensors["image_grid_thw"][offset : offset + 1]
),
config=(
config.model_copy(
update={"learning_rate": 1e-9, "beta": 0.0, "kl_penalty_coef": 0.0}
)
config.model_copy(update={"learning_rate": 1e-9, "kl_penalty_coef": 0.0})
if warmup
else config
),
Expand Down
4 changes: 1 addition & 3 deletions src/art/serverless/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ async def train( # type: ignore[override]
*,
# Core training parameters
learning_rate: float = 5e-6,
beta: float = 0.0,
# RL algorithm settings
ppo: bool = False,
epsilon: float | None = None,
Expand Down Expand Up @@ -179,7 +178,6 @@ async def train( # type: ignore[override]
model: The trainable model to train.
trajectory_groups: Batches of trajectories to train on.
learning_rate: Learning rate for training. Defaults to 5e-6.
beta: KL penalty coefficient. Defaults to 0.0.
ppo: Whether to use PPO clipping. Defaults to False.
epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
Expand Down Expand Up @@ -212,7 +210,7 @@ async def train( # type: ignore[override]
groups_list = list(trajectory_groups)

# Build config objects from explicit kwargs
config = TrainConfig(learning_rate=learning_rate, beta=beta)
config = TrainConfig(learning_rate=learning_rate)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
"importance_sampling_level": importance_sampling_level,
Expand Down
1 change: 0 additions & 1 deletion src/art/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

class TrainConfig(pydantic.BaseModel):
learning_rate: float = 5e-6
beta: float = 0.0
kl_penalty_coef: float = 0.0


Expand Down
6 changes: 2 additions & 4 deletions src/art/unsloth/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def compute_loss(
)
if return_new_logprobs:
return torch.nn.functional.pad(new_logprobs[:, :-1], (1, 0), value=0.0)
if config.beta > 0.0 or config.kl_penalty_coef > 0.0:
if config.kl_penalty_coef > 0.0:
ref_adapter = _config.get("kl_ref_adapter_path")
ref_logprobs, _ = calculate_logprobs(
dtype_for_autocasting,
Expand Down Expand Up @@ -173,11 +173,9 @@ def compute_loss(
trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item())
if loss.mean_entropy is not None:
trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())
if config.beta > 0.0:
trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item())
if loss.kl_policy_ref is not None:
trainer._metrics["train"]["kl_policy_ref"].append(loss.kl_policy_ref.item())
return loss.mean_policy_loss + config.beta * loss.mean_kl
return loss.mean_policy_loss

return compute_loss

Expand Down