
Commit 99587ae

[helion] backward support for swiglu
1 parent 8338452 commit 99587ae

File tree

examples/swiglu.py
test/test_examples.expected
test/test_examples.py

3 files changed: +143, -1 lines changed

examples/swiglu.py

Lines changed: 61 additions & 1 deletion
@@ -31,13 +31,14 @@
 
 if TYPE_CHECKING:
     from collections.abc import Callable
+    from typing import Any
 
 
 # %%
 # SwiGLU Kernel
 # -------------
 @helion.kernel()
-def swiglu(a: Tensor, b: Tensor) -> Tensor:
+def swiglu_fwd(a: Tensor, b: Tensor) -> Tensor:
     """
     Performs SwiGLU operation: SiLU(a) * b where SiLU is the Swish activation.
 
@@ -86,6 +87,65 @@ def swiglu(a: Tensor, b: Tensor) -> Tensor:
     return out
 
 
+@helion.kernel()
+def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor) -> tuple[Tensor, Tensor]:
+    """
+    Implement the backward formula for swiglu.
+    """
+    dx1 = torch.empty_like(x1)
+    dx2 = torch.empty_like(x2)
+
+    gout_flat = gout.view(-1)
+    x1_flat = x1.view(-1)
+    x2_flat = x2.view(-1)
+    dx1_flat = dx1.view(-1)
+    dx2_flat = dx2.view(-1)
+
+    for tile in hl.tile(x1.numel()):
+        x1_vals = x1_flat[tile].to(torch.float32)
+        gout_vals = gout_flat[tile].to(torch.float32)
+
+        # compute dx2
+        dx2_vals = x1_vals * torch.sigmoid(x1_vals) * gout_vals
+        dx2_flat[tile] = dx2_vals.to(x2.dtype)
+
+        # compute dx1
+        x2_vals = x2_flat[tile].to(torch.float32)
+        x1_exp = torch.exp(x1_vals)
+        x1_exp_plus1 = x1_exp + 1
+        dextra = x1_exp / x1_exp_plus1 + x1_vals * x1_exp / x1_exp_plus1 / x1_exp_plus1
+        dx1_vals = gout_vals * x2_vals * dextra
+        dx1_flat[tile] = dx1_vals.to(x1.dtype)
+
+    return dx1, dx2
+
+
+class SwigluFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        x1: Tensor,
+        x2: Tensor,
+    ) -> Tensor:
+        out = swiglu_fwd(x1, x2)
+        ctx.save_for_backward(x1, x2)
+        return out
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_out: Tensor,
+    ) -> tuple[Tensor, Tensor]:
+        x1, x2 = ctx.saved_tensors
+        dx1, dx2 = swiglu_bwd(grad_out, x1, x2)
+        return dx1, dx2
+
+
+def swiglu(a: Tensor, b: Tensor) -> Tensor:
+    """swiglu with forward + backward support."""
+    return SwigluFunction.apply(a, b)  # type: ignore[no-any-return]
+
+
 # %%
 # SwiGLU MLP Module (matches liger_kernel structure)
 # --------------------------------------------------
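
A minimal usage sketch of the new autograd-enabled wrapper (not part of this commit). The import path, device, and tolerances are assumptions: Helion kernels need a GPU, the example module is assumed to be importable as examples.swiglu from a repo checkout, and the loose tolerances only account for bfloat16 rounding:

import torch
import torch.nn.functional as F

from examples.swiglu import swiglu  # hypothetical import path

x1 = torch.randn(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
x2 = torch.randn(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
grad_out = torch.randn(1024, device="cuda", dtype=torch.bfloat16)

# forward runs swiglu_fwd; backward dispatches to swiglu_bwd via SwigluFunction
out = swiglu(x1, x2)
out.backward(grad_out)

# eager reference for comparison
x1_ref = x1.detach().clone().requires_grad_(True)
x2_ref = x2.detach().clone().requires_grad_(True)
ref = F.silu(x1_ref) * x2_ref
ref.backward(grad_out)

torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(x1.grad, x1_ref.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(x2.grad, x2_ref.grad, rtol=1e-2, atol=1e-2)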

test/test_examples.expected

Lines changed: 54 additions & 0 deletions
@@ -3745,6 +3745,60 @@ def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_swiglu, (triton.cdiv(total_elements, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, a_flat.stride(0), b_flat.stride(0), out_flat.stride(0), total_elements, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_swiglu_bwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_swiglu_bwd(x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, x1_size_0, dx1_flat_stride_0, dx2_flat_stride_0, gout_flat_stride_0, x1_flat_stride_0, x2_flat_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x1_size_0
+    load = tl.load(x1_flat + indices_0 * x1_flat_stride_0, mask_0, other=0)
+    v_0 = tl.cast(load, tl.float32)
+    load_1 = tl.load(gout_flat + indices_0 * gout_flat_stride_0, mask_0, other=0)
+    v_1 = tl.cast(load_1, tl.float32)
+    v_2 = tl.sigmoid(tl.cast(v_0, tl.float32))
+    v_3 = v_0 * v_2
+    v_4 = v_3 * v_1
+    v_5 = tl.cast(v_4, tl.bfloat16)
+    tl.store(dx2_flat + indices_0 * dx2_flat_stride_0, v_5, mask_0)
+    load_2 = tl.load(x2_flat + indices_0 * x2_flat_stride_0, mask_0, other=0)
+    v_6 = tl.cast(load_2, tl.float32)
+    v_7 = libdevice.exp(v_0)
+    v_8 = 1.0
+    v_9 = v_7 + v_8
+    v_10 = v_7 / v_9
+    v_11 = v_0 * v_7
+    v_12 = v_11 / v_9
+    v_13 = v_12 / v_9
+    v_14 = v_10 + v_13
+    v_15 = v_1 * v_6
+    v_16 = v_15 * v_14
+    v_17 = tl.cast(v_16, tl.bfloat16)
+    tl.store(dx1_flat + indices_0 * dx1_flat_stride_0, v_17, mask_0)
+
+def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor, *, _launcher=_default_launcher):
+    """
+    Implement the backward formula for swiglu.
+    """
+    dx1 = torch.empty_like(x1)
+    dx2 = torch.empty_like(x2)
+    gout_flat = gout.view(-1)
+    x1_flat = x1.view(-1)
+    x2_flat = x2.view(-1)
+    dx1_flat = dx1.view(-1)
+    dx2_flat = dx2.view(-1)
+    _BLOCK_SIZE_0 = 1024
+    _launcher(_helion_swiglu_bwd, (triton.cdiv(x1.size(0), _BLOCK_SIZE_0),), x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, x1.size(0), dx1_flat.stride(0), dx2_flat.stride(0), gout_flat.stride(0), x1_flat.stride(0), x2_flat.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return (dx1, dx2)
+
 --- assertExpectedJournal(TestExamples.test_template_via_closure0)
 from __future__ import annotations
 
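
In the generated kernel, v_10 + v_13 is the dextra term from swiglu_bwd lowered to Triton: exp(x) / (1 + exp(x)) + x * exp(x) / (1 + exp(x))^2, i.e. d SiLU(x) / dx, which also equals sigmoid(x) * (1 + x * (1 - sigmoid(x))). A small sanity sketch of that identity against autograd (plain PyTorch on CPU, not part of this commit; float64 chosen here only for tight tolerances):

import torch
import torch.nn.functional as F

x = torch.randn(256, dtype=torch.float64, requires_grad=True)
xd = x.detach()

# exp form, as written in swiglu_bwd and emitted as v_7 ... v_14
ex = torch.exp(xd)
dextra_exp = ex / (ex + 1) + xd * ex / (ex + 1) / (ex + 1)

# sigmoid form of the same derivative
sig = torch.sigmoid(xd)
dextra_sig = sig * (1 + xd * (1 - sig))

# autograd on SiLU(x) yields the same gradient
F.silu(x).sum().backward()

torch.testing.assert_close(dextra_exp, dextra_sig)
torch.testing.assert_close(dextra_exp, x.grad)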

test/test_examples.py

Lines changed: 28 additions & 0 deletions
@@ -4,6 +4,7 @@
 
 from packaging import version
 import torch
+import torch.nn.functional as F
 
 import helion
 from helion._testing import DEVICE
@@ -329,6 +330,33 @@ def test_rms_norm_fwd(self):
             )
         )
 
+    def test_swiglu_bwd(self):
+        """Test backward pass for swiglu."""
+        x1, x2 = [
+            torch.randn(1024, device=DEVICE, dtype=torch.bfloat16, requires_grad=True)
+            for _ in range(2)
+        ]
+
+        out = F.silu(x1) * x2
+
+        grad_out = torch.randn_like(out)
+        out.backward(grad_out)
+
+        args = (
+            grad_out,
+            x1,
+            x2,
+        )
+
+        self.assertExpectedJournal(
+            check_example(
+                "swiglu",
+                args,
+                (x1.grad, x2.grad),
+                fn_name="swiglu_bwd",
+            )
+        )
+
     def test_rms_norm_bwd(self):
         """Test backward pass for rms norm weight gradient."""
         batch_size, dim = 32, 64
