
Conversation

@qgallouedec (Member) commented on Sep 30, 2025

Before

Comparing runs with the same effective batch size (BSZ × GAS = 8):

  • BSZ 4, GAS 2
  • BSZ 2, GAS 4
  • BSZ 1, GAS 8
Screenshot 2025-09-29 at 10 10 46 PM
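As a quick sanity check (my addition, not part of the PR): all three compared runs keep the same effective batch size per optimizer step, so their loss curves should coincide; the divergence above is the bug being fixed.

```python
# Hypothetical check: every (per_device_train_batch_size, gradient_accumulation_steps)
# pair compared above yields the same effective batch size.
configs = [(4, 2), (2, 4), (1, 8)]  # (BSZ, GAS)
effective = {bsz * gas for bsz, gas in configs}
print(effective)  # a single value: 8
```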

After

Comparing the same (BSZ, GAS) configurations:

  • BSZ 4, GAS 2
  • BSZ 2, GAS 4
  • BSZ 1, GAS 8
Screenshot 2025-09-29 at 10 09 40 PM

And comparing attention implementations:

  • kernels-community/flash-attn
  • kernels-community/flash-attn3
  • kernels-community/vllm-flash-attn3
Screenshot 2025-09-29 at 10 24 56 PM

And comparing 1 vs. 2 GPUs (slightly different, not sure why):

Screenshot 2025-09-29 at 10 35 01 PM

Bonus: it's waaaay faster

Screenshot 2025-09-29 at 10 26 08 PM

To reproduce

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import os

os.environ["TRACKIO_SPACE_ID"] = "qgallouedec/trackio"
os.environ["TRACKIO_PROJECT"] = "fa_loss_bug"

dataset = load_dataset("trl-lib/Capybara", split="train[:5%]")

gas = 2
bsz = 4

trainer = SFTTrainer(
    model="HuggingFaceTB/SmolLM3-3B",
    args=SFTConfig(
        per_device_train_batch_size=bsz,
        gradient_accumulation_steps=gas,
        logging_steps=1,
        packing=True,
        use_liger_kernel=True,
        max_length=32768,
        model_init_kwargs={"torch_dtype": "bfloat16", "attn_implementation": "kernels-community/vllm-flash-attn3"},
        report_to="trackio",
        run_name=f"after_bsz{bsz}_gas{gas}_vllm_fa3",
    ),
    train_dataset=dataset,
)
trainer.train()

The fix, in the collator:

output["position_ids"] = pad(
    position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
# Mask labels at padded positions so they are ignored by the loss
output["labels"][output["position_ids"] == 0] = -100
@qgallouedec (Member, Author) commented on the `output["labels"]` line:

> this is the most important line
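A minimal sketch (my own illustration, not the TRL implementation) of what that line does: after right-padding `position_ids` with 0, every padded slot has position id 0, so setting labels to -100 (the ignore index) at those slots removes the pad tokens from the loss.

```python
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss

def pad_right(seq, target_len, value):
    """Right-pad a list to target_len with the given value."""
    return seq + [value] * (target_len - len(seq))

# One packed row holding two sequences (lengths 3 and 2), padded to length 8.
position_ids = [0, 1, 2, 0, 1]
labels       = [11, 12, 13, 21, 22]

position_ids = pad_right(position_ids, 8, 0)  # padding_value=0, padding_side="right"
labels       = pad_right(labels, 8, 0)

# The fix: mask the label wherever the position id is 0.
labels = [IGNORE_INDEX if p == 0 else l for p, l in zip(position_ids, labels)]
print(labels)  # → [-100, 12, 13, -100, 22, -100, -100, -100]
```

Note that this also masks the label of each sequence's first token (its position id is also 0); in a padding-free packed batch that token would otherwise be predicted from the previous sequence's last token, so masking it looks harmless and arguably desirable.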


@lewtun (Member) left a review:

Excellent fix, LGTM

@qgallouedec qgallouedec changed the title Fix FA loss ⚡ Fix Flash Attention x Padding-Free loss Sep 30, 2025
@qgallouedec qgallouedec merged commit ebb8899 into main Sep 30, 2025
11 of 12 checks passed
@qgallouedec qgallouedec deleted the fix-fa-loss branch September 30, 2025 18:01