[helion] backward support for swiglu #756
@helion.kernel()
def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor) -> tuple[Tensor, Tensor]:
Please add this to run.py; there are two lists you need to update there.
Also, please run it with tritonbench and generate perf/accuracy numbers.
cc: @yf225
@oulgen do you have an example of doing that for a backward kernel? I can find a few examples for fwd but none for bwd.
Oh, found 'rms_norm-bwd' in run.py. Will follow it.
I ran this command:
python benchmarks/run.py --metrics speedup,accuracy --kernel swiglu-bwd
but I don't see the numbers for helion. Any ideas?
| (B, T, H) | liger_swiglu-speedup | liger_swiglu-accuracy | torch_compile_swiglu-speedup | torch_compile_swiglu-accuracy |
|---|---|---|---|---|
| (4, 1024, 4096) | 1.01139 | 1 | 1.03097 | 1 |
| (4, 2048, 4096) | 1.02854 | 1 | 1.00777 | 1 |
| (4, 4096, 4096) | 1.03631 | 1 | 1.03787 | 1 |
| (4, 8192, 4096) | 0.841614 | 1 | 1.04048 | 1 |
| average | 0.979463 | 1 | 1.02927 | 1 |
@yf225 can you help?
I ran the same command with tritonbench's latest main (I did `git pull` in helion/benchmarks/tritonbench), and helion shows up:
| (B, T, H) | liger_swiglu-speedup | liger_swiglu-accuracy | torch_compile_swiglu-speedup | torch_compile_swiglu-accuracy | helion_swiglu_tritonbench-speedup | helion_swiglu_tritonbench-accuracy |
|---|---|---|---|---|---|---|
| (4, 1024, 4096) | 0.994532 | 1 | 0.992842 | 1 | 1.00311 | 1 |
| (4, 2048, 4096) | 0.950479 | 1 | 0.973353 | 1 | 0.844336 | 1 |
| (4, 4096, 4096) | 0.982585 | 1 | 1.02047 | 1 | 0.851285 | 1 |
| (4, 8192, 4096) | 1.01794 | 1 | 1.04066 | 1 | 0.977584 | 1 |
| average | 0.986385 | 1 | 1.00683 | 1 | 0.919078 | 1 |
@shunting314 it could be a tritonbench version issue. Would you like to try again? Thanks!
I can see the results now after pulling the latest tritonbench:
| (B, T, H) | liger_swiglu-speedup | liger_swiglu-accuracy | torch_compile_swiglu-speedup | torch_compile_swiglu-accuracy | helion_swiglu_tritonbench-speedup | helion_swiglu_tritonbench-accuracy |
|---|---|---|---|---|---|---|
| (4, 1024, 4096) | 1.06699 | 1 | 1.02534 | 1 | 1.00823 | 1 |
| (4, 2048, 4096) | 1.02478 | 1 | 0.952649 | 1 | 1.03361 | 1 |
| (4, 4096, 4096) | 0.991505 | 1 | 1.03377 | 1 | 1.02527 | 1 |
| (4, 8192, 4096) | 0.925007 | 1 | 1.06515 | 1 | 1.03323 | 1 |
| average | 1.00207 | 1 | 1.01923 | 1 | 1.02509 | 1 |
Thanks @shunting314!
Add the backward formula of swiglu in examples/swiglu.py
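
For readers who want the formula spelled out, here is a minimal scalar sketch of what a SwiGLU backward computes. It assumes the forward pass is `out = silu(x1) * x2` with `silu(x) = x * sigmoid(x)`; this thread does not quote the forward kernel, so treat that convention as an assumption. The names `swiglu_fwd_ref`, `swiglu_bwd_ref`, and `numeric_grads` are hypothetical helpers for illustration, not code from the PR, and the sketch uses plain Python floats rather than Helion or PyTorch tensors.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def swiglu_fwd_ref(x1: float, x2: float) -> float:
    # ASSUMPTION: forward is silu(x1) * x2, where silu(x) = x * sigmoid(x).
    return x1 * sigmoid(x1) * x2


def swiglu_bwd_ref(gout: float, x1: float, x2: float) -> tuple[float, float]:
    """Gradients of the assumed forward w.r.t. x1 and x2, scaled by gout."""
    s = sigmoid(x1)
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    dsilu = s * (1.0 + x1 * (1.0 - s))
    dx1 = gout * x2 * dsilu  # chain rule through silu(x1)
    dx2 = gout * x1 * s      # silu(x1) is the coefficient of x2
    return dx1, dx2


def numeric_grads(gout: float, x1: float, x2: float, eps: float = 1e-6) -> tuple[float, float]:
    """Central finite differences of the assumed forward, for cross-checking."""
    d1 = (swiglu_fwd_ref(x1 + eps, x2) - swiglu_fwd_ref(x1 - eps, x2)) / (2 * eps)
    d2 = (swiglu_fwd_ref(x1, x2 + eps) - swiglu_fwd_ref(x1, x2 - eps)) / (2 * eps)
    return gout * d1, gout * d2
```

Comparing `swiglu_bwd_ref` against `numeric_grads` at a few points is a quick sanity check on the algebra before comparing a real kernel against autograd.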