14 changes: 14 additions & 0 deletions benchmarks/run.py
@@ -128,6 +128,11 @@ class RunResult:
"examples.swiglu",
"swiglu_tritonbench",
),
"swiglu-bwd": (
"tritonbench.operators.swiglu.operator",
"examples.swiglu",
"swiglu_tritonbench",
),
"jsd": (
"tritonbench.operators.jsd.operator",
"examples.jsd",
@@ -440,6 +445,15 @@ class RunResult:
"helion_swiglu_tritonbench-speedup": "helion_speedup",
"helion_swiglu_tritonbench-accuracy": "helion_accuracy",
},
"swiglu-bwd": {
"torch_swiglu": "baseline",
"liger_swiglu-speedup": "triton_speedup",
"liger_swiglu-accuracy": "triton_accuracy",
"torch_compile_swiglu-speedup": "torch_compile_speedup",
"torch_compile_swiglu-accuracy": "torch_compile_accuracy",
"helion_swiglu_tritonbench-speedup": "helion_speedup",
"helion_swiglu_tritonbench-accuracy": "helion_accuracy",
},
"jsd": {
"torch_jsd": "baseline",
"liger_jsd-speedup": "triton_speedup",
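With both entries in place (the operator mapping above and the metric mapping below it), the new operator is exercised the same way as the forward kernel, e.g. with `python benchmarks/run.py --metrics speedup,accuracy --kernel swiglu-bwd`, the command used in the review thread further down.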
62 changes: 61 additions & 1 deletion examples/swiglu.py
@@ -36,6 +36,7 @@

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any


# %%
@@ -45,7 +46,7 @@

# %%
@helion.kernel()
def swiglu(a: Tensor, b: Tensor) -> Tensor:
def swiglu_fwd(a: Tensor, b: Tensor) -> Tensor:
"""
Performs SwiGLU operation: SiLU(a) * b where SiLU is the Swish activation.

@@ -94,6 +95,65 @@ def swiglu(a: Tensor, b: Tensor) -> Tensor:
    return out


@helion.kernel()
def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor) -> tuple[Tensor, Tensor]:
Review thread on `def swiglu_bwd(...)`:

@oulgen (Contributor) commented on Oct 1, 2025:

please add this to run.py; there are two lists you need to update there.

also please run with tritonbench and generate perf/accuracy numbers.
cc: @yf225

@shunting314 (Contributor, Author) commented on Oct 1, 2025:

@oulgen do you have an example of doing that for a backward kernel? I can find a few examples for fwd but not bwd.

@shunting314 (Contributor, Author):

oh, found 'rms_norm-bwd' in run.py; will follow it.

@shunting314 (Contributor, Author):

I ran this command:

python benchmarks/run.py --metrics speedup,accuracy --kernel swiglu-bwd

but I don't see the numbers for helion. Any ideas?

      (B, T, H)    liger_swiglu-speedup    liger_swiglu-accuracy    torch_compile_swiglu-speedup    torch_compile_swiglu-accuracy
---------------  ----------------------  -----------------------  ------------------------------  -------------------------------
(4, 1024, 4096)                1.01139                         1                         1.03097                                1
(4, 2048, 4096)                1.02854                         1                         1.00777                                1
(4, 4096, 4096)                1.03631                         1                         1.03787                                1
(4, 8192, 4096)                0.841614                        1                         1.04048                                1
        average                0.979463                        1                         1.02927                                1

@oulgen

@oulgen (Contributor) commented:

@yf225 can you help?

@yf225 (Contributor) commented on Oct 21, 2025:

I ran with the same command and with tritonbench's latest main (I did git pull in helion/benchmarks/tritonbench), and helion shows up:

      (B, T, H)    liger_swiglu-speedup    liger_swiglu-accuracy    torch_compile_swiglu-speedup    torch_compile_swiglu-accuracy    helion_swiglu_tritonbench-speedup    helion_swiglu_tritonbench-accuracy
---------------  ----------------------  -----------------------  ------------------------------  -------------------------------  -----------------------------------  ------------------------------------
(4, 1024, 4096)                0.994532                        1                        0.992842                                1                             1.00311                                      1
(4, 2048, 4096)                0.950479                        1                        0.973353                                1                             0.844336                                     1
(4, 4096, 4096)                0.982585                        1                        1.02047                                 1                             0.851285                                     1
(4, 8192, 4096)                1.01794                         1                        1.04066                                 1                             0.977584                                     1
        average                0.986385                        1                        1.00683                                 1                             0.919078                                     1

@shunting314 it could be a tritonbench version issue; would you like to try again? thanks!

@shunting314 (Contributor, Author):

I can see the result now after pulling from tritonbench:

      (B, T, H)    liger_swiglu-speedup    liger_swiglu-accuracy    torch_compile_swiglu-speedup    torch_compile_swiglu-accuracy    helion_swiglu_tritonbench-speedup    helion_swiglu_tritonbench-accuracy
---------------  ----------------------  -----------------------  ------------------------------  -------------------------------  -----------------------------------  ------------------------------------
(4, 1024, 4096)                1.06699                         1                        1.02534                                 1                              1.00823                                     1
(4, 2048, 4096)                1.02478                         1                        0.952649                                1                              1.03361                                     1
(4, 4096, 4096)                0.991505                        1                        1.03377                                 1                              1.02527                                     1
(4, 8192, 4096)                0.925007                        1                        1.06515                                 1                              1.03323                                     1
        average                1.00207                         1                        1.01923                                 1                              1.02509                                     1

"""
Implement the backward formula for swiglu.
"""
dx1 = torch.empty_like(x1)
dx2 = torch.empty_like(x2)

gout_flat = gout.view(-1)
x1_flat = x1.view(-1)
x2_flat = x2.view(-1)
dx1_flat = dx1.view(-1)
dx2_flat = dx2.view(-1)

for tile in hl.tile(x1.numel()):
x1_vals = x1_flat[tile].to(torch.float32)
gout_vals = gout_flat[tile].to(torch.float32)

# compute dx2
dx2_vals = x1_vals * torch.sigmoid(x1_vals) * gout_vals
dx2_flat[tile] = dx2_vals.to(x2.dtype)

# compute dx1
x2_vals = x2_flat[tile].to(torch.float32)
x1_exp = torch.exp(x1_vals)
x1_exp_plus1 = x1_exp + 1
dextra = x1_exp / x1_exp_plus1 + x1_vals * x1_exp / x1_exp_plus1 / x1_exp_plus1
dx1_vals = gout_vals * x2_vals * dextra
dx1_flat[tile] = dx1_vals.to(x1.dtype)

return dx1, dx2


class SwigluFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,  # noqa: ANN401
        x1: Tensor,
        x2: Tensor,
    ) -> Tensor:
        out = swiglu_fwd(x1, x2)
        ctx.save_for_backward(x1, x2)
        return out

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: Any,  # noqa: ANN401
        grad_out: Tensor,
    ) -> tuple[Tensor, Tensor]:
        x1, x2 = ctx.saved_tensors
        dx1, dx2 = swiglu_bwd(grad_out, x1, x2)
        return dx1, dx2


def swiglu(a: Tensor, b: Tensor) -> Tensor:
    """swiglu with forward + backward support."""
    return SwigluFunction.apply(a, b)  # type: ignore[no-any-return]


# %%
# SwiGLU MLP Module (matches liger_kernel structure)
# --------------------------------------------------
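For reference, a minimal standalone sketch (plain PyTorch, no Helion) that checks the backward formulas used in `swiglu_bwd` against autograd; the helper name `swiglu_bwd_reference` is made up for illustration and is not part of this PR. It spells out that `dextra` is the derivative of `silu(x1) = x1 * sigmoid(x1)` written with `e = exp(x1)`, matching the kernel's math.

```python
# Sketch: verify the swiglu backward formulas against autograd.
import torch
import torch.nn.functional as F


def swiglu_bwd_reference(gout: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor):
    # dx2 = gout * silu(x1)
    dx2 = gout * x1 * torch.sigmoid(x1)
    # d(silu)/dx1 = sigmoid(x1) + x1 * sigmoid(x1) * (1 - sigmoid(x1)),
    # written with e = exp(x1), exactly as the kernel's `dextra`.
    e = torch.exp(x1)
    dextra = e / (e + 1) + x1 * e / (e + 1) / (e + 1)
    dx1 = gout * x2 * dextra
    return dx1, dx2


x1 = torch.randn(1024, dtype=torch.float32, requires_grad=True)
x2 = torch.randn(1024, dtype=torch.float32, requires_grad=True)
out = F.silu(x1) * x2
gout = torch.randn_like(out)
out.backward(gout)

dx1, dx2 = swiglu_bwd_reference(gout, x1.detach(), x2.detach())
torch.testing.assert_close(dx1, x1.grad)
torch.testing.assert_close(dx2, x2.grad)
```

This mirrors what `test_swiglu_bwd` below does, except the test compares the Helion kernel (rather than a pure-PyTorch rewrite) against autograd.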
82 changes: 79 additions & 3 deletions test/test_examples.expected
@@ -5952,7 +5952,7 @@ import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_swiglu(a_flat, b_flat, out_flat, _BLOCK_SIZE_0: tl.constexpr):
def _helion_swiglu_fwd(a_flat, b_flat, out_flat, _BLOCK_SIZE_0: tl.constexpr):
    # src[swiglu.py:N]: for tile_idx in hl.tile(total_elements):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
@@ -5972,7 +5972,7 @@ def _helion_swiglu(a_flat, b_flat, out_flat, _BLOCK_SIZE_0: tl.constexpr):
    # src[swiglu.py:N]: out_flat[tile_idx] = result
    tl.store(out_flat + indices_0 * 1, v_4, None)

def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
def swiglu_fwd(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
"""
Performs SwiGLU operation: SiLU(a) * b where SiLU is the Swish activation.

@@ -6006,10 +6006,86 @@ def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
    # src[swiglu.py:N]: # Load input values and convert to float32 for computation
    # src[swiglu.py:N]: a_vals = a_flat[tile_idx].to(torch.float32)
    # src[swiglu.py:N-N]: ...
    _launcher(_helion_swiglu, (triton.cdiv(1048576, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
    _launcher(_helion_swiglu_fwd, (triton.cdiv(1048576, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
    # src[swiglu.py:N]: return out
    return out

--- assertExpectedJournal(TestExamples.test_swiglu_bwd)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_swiglu_bwd(x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0: tl.constexpr):
    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0 * _BLOCK_SIZE_0
    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    # src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32)
    load = tl.load(x1_flat + indices_0 * 1, None)
    v_0 = tl.cast(load, tl.float32)
    # src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32)
    load_1 = tl.load(gout_flat + indices_0 * 1, None)
    v_1 = tl.cast(load_1, tl.float32)
    # src[swiglu.py:N]: dx2_vals = x1_vals * torch.sigmoid(x1_vals) * gout_vals
    v_2 = tl.sigmoid(tl.cast(v_0, tl.float32))
    v_3 = v_0 * v_2
    v_4 = v_3 * v_1
    # src[swiglu.py:N]: dx2_flat[tile] = dx2_vals.to(x2.dtype)
    v_5 = tl.cast(v_4, tl.bfloat16)
    tl.store(dx2_flat + indices_0 * 1, v_5, None)
    # src[swiglu.py:N]: x2_vals = x2_flat[tile].to(torch.float32)
    load_2 = tl.load(x2_flat + indices_0 * 1, None)
    v_6 = tl.cast(load_2, tl.float32)
    # src[swiglu.py:N]: x1_exp = torch.exp(x1_vals)
    v_7 = libdevice.exp(v_0)
    # src[swiglu.py:N]: x1_exp_plus1 = x1_exp + 1
    v_8 = 1.0
    v_9 = v_7 + v_8
    # src[swiglu.py:N]: dextra = x1_exp / x1_exp_plus1 + x1_vals * x1_exp / x1_exp_plus1 / x1_exp_plus1
    v_10 = v_7 / v_9
    v_11 = v_0 * v_7
    v_12 = v_11 / v_9
    v_13 = v_12 / v_9
    v_14 = v_10 + v_13
    # src[swiglu.py:N]: dx1_vals = gout_vals * x2_vals * dextra
    v_15 = v_1 * v_6
    v_16 = v_15 * v_14
    # src[swiglu.py:N]: dx1_flat[tile] = dx1_vals.to(x1.dtype)
    v_17 = tl.cast(v_16, tl.bfloat16)
    tl.store(dx1_flat + indices_0 * 1, v_17, None)

def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor, *, _launcher=_default_launcher):
"""
Implement the backward formula for swiglu.
"""
# src[swiglu.py:N]: dx1 = torch.empty_like(x1)
dx1 = torch.empty_like(x1)
# src[swiglu.py:N]: dx2 = torch.empty_like(x2)
dx2 = torch.empty_like(x2)
# src[swiglu.py:N]: gout_flat = gout.view(-1)
gout_flat = gout.view(-1)
# src[swiglu.py:N]: x1_flat = x1.view(-1)
x1_flat = x1.view(-1)
# src[swiglu.py:N]: x2_flat = x2.view(-1)
x2_flat = x2.view(-1)
# src[swiglu.py:N]: dx1_flat = dx1.view(-1)
dx1_flat = dx1.view(-1)
# src[swiglu.py:N]: dx2_flat = dx2.view(-1)
dx2_flat = dx2.view(-1)
# src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
_BLOCK_SIZE_0 = 1024
# src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
# src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32)
# src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32)
# src[swiglu.py:N-N]: ...
_launcher(_helion_swiglu_bwd, (triton.cdiv(1024, _BLOCK_SIZE_0),), x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
# src[swiglu.py:N]: return dx1, dx2
return (dx1, dx2)

--- assertExpectedJournal(TestExamples.test_template_via_closure0)
from __future__ import annotations

29 changes: 29 additions & 0 deletions test/test_examples.py
@@ -5,6 +5,7 @@

from packaging import version
import torch
import torch.nn.functional as F

import helion
from helion import _compat
@@ -494,6 +495,33 @@ def test_rms_norm_fwd(self):
            )
        )

    def test_swiglu_bwd(self):
        """Test backward pass for swiglu."""
        x1, x2 = [
            torch.randn(1024, device=DEVICE, dtype=torch.bfloat16, requires_grad=True)
            for _ in range(2)
        ]

        out = F.silu(x1) * x2

        grad_out = torch.randn_like(out)
        out.backward(grad_out)

        args = (
            grad_out,
            x1,
            x2,
        )

        self.assertExpectedJournal(
            check_example(
                "swiglu",
                args,
                (x1.grad, x2.grad),
                fn_name="swiglu_bwd",
            )
        )

    def test_rms_norm_bwd(self):
        """Test backward pass for rms norm weight gradient."""
        batch_size, dim = 32, 64
@@ -1208,6 +1236,7 @@ def test_swiglu(self):
"swiglu",
args,
torch.nn.functional.silu(args[0]) * args[1],
fn_name="swiglu_fwd",
block_sizes=[16],
num_warps=4,
num_stages=3,