Skip to content

Conversation

@shunting314
Copy link
Contributor

Add the backward formula of swiglu in examples/swiglu.py

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 1, 2025


@helion.kernel()
def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor) -> tuple[Tensor, Tensor]:
Copy link
Contributor

@oulgen oulgen Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add this to run.py there are two lists you need to update there

also please run with triton bench and generate perf/accuracy numbers
cc: @yf225

Copy link
Contributor Author

@shunting314 shunting314 Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@oulgen do you have an example to do that for a backward kernel? I can find a few examples for fwd but not bwd

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh found 'rms_norm-bwd' in the run.py. will follow it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran this command:

python benchmarks/run.py --metrics speedup,accuracy --kernel swiglu-bwd

but don't see the number for helion. Any ideas?

      (B, T, H)    liger_swiglu-speedup    liger_swiglu-accuracy    torch_compile_swiglu-speedup    torch_compile_swiglu-accuracy
---------------  ----------------------  -----------------------  ------------------------------  -------------------------------
(4, 1024, 4096)                1.01139                         1                         1.03097                                1
(4, 2048, 4096)                1.02854                         1                         1.00777                                1
(4, 4096, 4096)                1.03631                         1                         1.03787                                1
(4, 8192, 4096)                0.841614                        1                         1.04048                                1
        average                0.979463                        1                         1.02927                                1

@oulgen

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yf225 can you help?

Copy link
Contributor

@yf225 yf225 Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran with the same command and with tritonbench's latest main (I did git pull in helion/benchmarks/tritonbench), and helion shows up:

      (B, T, H)    liger_swiglu-speedup    liger_swiglu-accuracy    torch_compile_swiglu-speedup    torch_compile_swiglu-accuracy    helion_swiglu_tritonbench-speedup    helion_swiglu_tritonbench-accuracy
---------------  ----------------------  -----------------------  ------------------------------  -------------------------------  -----------------------------------  ------------------------------------
(4, 1024, 4096)                0.994532                        1                        0.992842                                1                             1.00311                                      1
(4, 2048, 4096)                0.950479                        1                        0.973353                                1                             0.844336                                     1
(4, 4096, 4096)                0.982585                        1                        1.02047                                 1                             0.851285                                     1
(4, 8192, 4096)                1.01794                         1                        1.04066                                 1                             0.977584                                     1
        average                0.986385                        1                        1.00683                                 1                             0.919078                                     1

@shunting314 it could be a tritonbench version issue - wonder would you like to try again? thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see the result now after pull from tritonbench

      (B, T, H)    liger_swiglu-speedup    liger_swiglu-accuracy    torch_compile_swiglu-speedup    torch_compile_swiglu-accuracy    helion_swiglu_tritonbench-speedup    helion_swiglu_tritonbench-accuracy
---------------  ----------------------  -----------------------  ------------------------------  -------------------------------  -----------------------------------  ------------------------------------
(4, 1024, 4096)                1.06699                         1                        1.02534                                 1                              1.00823                                     1
(4, 2048, 4096)                1.02478                         1                        0.952649                                1                              1.03361                                     1
(4, 4096, 4096)                0.991505                        1                        1.03377                                 1                              1.02527                                     1
(4, 8192, 4096)                0.925007                        1                        1.06515                                 1                              1.03323                                     1
        average                1.00207                         1                        1.01923                                 1                              1.02509                                     1
        ```

@shunting314 shunting314 force-pushed the swiglu-bwd branch 2 times, most recently from abbe582 to 5d29b48 Compare October 1, 2025 23:18
@shunting314 shunting314 requested a review from oulgen October 1, 2025 23:18
@shunting314
Copy link
Contributor Author

ping @oulgen @yf225 for review

Copy link
Contributor

@yf225 yf225 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @shunting314 !

@shunting314
Copy link
Contributor Author

CI failure is unrelated.

I don't seem to have permission to merge. @oulgen @yf225 can you help?

@jansel jansel merged commit 8b38f31 into pytorch:main Oct 24, 2025
13 of 14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants