
Commit 9e19810

[helion] backward support for swiglu
1 parent 5b126e4 commit 9e19810

File tree: 4 files changed, +180 −1 lines

benchmarks/run.py

Lines changed: 14 additions & 0 deletions
@@ -128,6 +128,11 @@ class RunResult:
         "examples.swiglu",
         "swiglu_tritonbench",
     ),
+    "swiglu-bwd": (
+        "tritonbench.operators.swiglu.operator",
+        "examples.swiglu",
+        "swiglu_tritonbench",
+    ),
     "jsd": (
         "tritonbench.operators.jsd.operator",
         "examples.jsd",
@@ -440,6 +445,15 @@ class RunResult:
         "helion_swiglu_tritonbench-speedup": "helion_speedup",
         "helion_swiglu_tritonbench-accuracy": "helion_accuracy",
     },
+    "swiglu-bwd": {
+        "torch_swiglu": "baseline",
+        "liger_swiglu-speedup": "triton_speedup",
+        "liger_swiglu-accuracy": "triton_accuracy",
+        "torch_compile_swiglu-speedup": "torch_compile_speedup",
+        "torch_compile_swiglu-accuracy": "torch_compile_accuracy",
+        "helion_swiglu_tritonbench-speedup": "helion_speedup",
+        "helion_swiglu_tritonbench-accuracy": "helion_accuracy",
+    },
     "jsd": {
         "torch_jsd": "baseline",
         "liger_jsd-speedup": "triton_speedup",

examples/swiglu.py

Lines changed: 61 additions & 1 deletion
@@ -36,6 +36,7 @@
 
 if TYPE_CHECKING:
     from collections.abc import Callable
+    from typing import Any
 
 
 # %%
@@ -45,7 +46,7 @@
 
 # %%
 @helion.kernel()
-def swiglu(a: Tensor, b: Tensor) -> Tensor:
+def swiglu_fwd(a: Tensor, b: Tensor) -> Tensor:
     """
     Performs SwiGLU operation: SiLU(a) * b where SiLU is the Swish activation.
 
@@ -94,6 +95,65 @@ def swiglu(a: Tensor, b: Tensor) -> Tensor:
     return out
 
 
+@helion.kernel()
+def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor) -> tuple[Tensor, Tensor]:
+    """
+    Implement the backward formula for swiglu.
+    """
+    dx1 = torch.empty_like(x1)
+    dx2 = torch.empty_like(x2)
+
+    gout_flat = gout.view(-1)
+    x1_flat = x1.view(-1)
+    x2_flat = x2.view(-1)
+    dx1_flat = dx1.view(-1)
+    dx2_flat = dx2.view(-1)
+
+    for tile in hl.tile(x1.numel()):
+        x1_vals = x1_flat[tile].to(torch.float32)
+        gout_vals = gout_flat[tile].to(torch.float32)
+
+        # compute dx2
+        dx2_vals = x1_vals * torch.sigmoid(x1_vals) * gout_vals
+        dx2_flat[tile] = dx2_vals.to(x2.dtype)
+
+        # compute dx1
+        x2_vals = x2_flat[tile].to(torch.float32)
+        x1_exp = torch.exp(x1_vals)
+        x1_exp_plus1 = x1_exp + 1
+        dextra = x1_exp / x1_exp_plus1 + x1_vals * x1_exp / x1_exp_plus1 / x1_exp_plus1
+        dx1_vals = gout_vals * x2_vals * dextra
+        dx1_flat[tile] = dx1_vals.to(x1.dtype)
+
+    return dx1, dx2
+
+
+class SwigluFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        x1: Tensor,
+        x2: Tensor,
+    ) -> Tensor:
+        out = swiglu_fwd(x1, x2)
+        ctx.save_for_backward(x1, x2)
+        return out
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_out: Tensor,
+    ) -> tuple[Tensor, Tensor]:
+        x1, x2 = ctx.saved_tensors
+        dx1, dx2 = swiglu_bwd(grad_out, x1, x2)
+        return dx1, dx2
+
+
+def swiglu(a: Tensor, b: Tensor) -> Tensor:
+    """swiglu with forward + backward support."""
+    return SwigluFunction.apply(a, b)  # type: ignore[no-any-return]
+
+
 # %%
 # SwiGLU MLP Module (matches liger_kernel structure)
 # --------------------------------------------------
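
For reference, the backward kernel's `dextra` term is the derivative of SiLU written in exponential form. With out = SiLU(x1) · x2 and σ denoting the sigmoid, the two gradients the kernel materializes are (a plain derivation, not part of the commit):

\frac{\partial\,\mathrm{out}}{\partial x_2} = \mathrm{SiLU}(x_1) = x_1\,\sigma(x_1)

\frac{\partial\,\mathrm{out}}{\partial x_1}
  = x_2\left(\frac{e^{x_1}}{1+e^{x_1}} + \frac{x_1\,e^{x_1}}{(1+e^{x_1})^2}\right)
  = x_2\,\sigma(x_1)\bigl(1 + x_1\,(1-\sigma(x_1))\bigr)

Both are then multiplied by the incoming gradient `gout`, matching `dx2_vals` and `dx1_vals` above. Note that the exp form can overflow to inf for very large positive x1 (roughly x1 > 88 in float32), where the equivalent sigmoid form on the right is the safer rewrite; for typical activation magnitudes the two agree.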

test/test_examples.expected

Lines changed: 76 additions & 0 deletions
@@ -6010,6 +6010,82 @@ def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     # src[swiglu.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestExamples.test_swiglu_bwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_swiglu_bwd(x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0: tl.constexpr):
+    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32)
+    load = tl.load(x1_flat + indices_0 * 1, None)
+    v_0 = tl.cast(load, tl.float32)
+    # src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32)
+    load_1 = tl.load(gout_flat + indices_0 * 1, None)
+    v_1 = tl.cast(load_1, tl.float32)
+    # src[swiglu.py:N]: dx2_vals = x1_vals * torch.sigmoid(x1_vals) * gout_vals
+    v_2 = tl.sigmoid(tl.cast(v_0, tl.float32))
+    v_3 = v_0 * v_2
+    v_4 = v_3 * v_1
+    # src[swiglu.py:N]: dx2_flat[tile] = dx2_vals.to(x2.dtype)
+    v_5 = tl.cast(v_4, tl.bfloat16)
+    tl.store(dx2_flat + indices_0 * 1, v_5, None)
+    # src[swiglu.py:N]: x2_vals = x2_flat[tile].to(torch.float32)
+    load_2 = tl.load(x2_flat + indices_0 * 1, None)
+    v_6 = tl.cast(load_2, tl.float32)
+    # src[swiglu.py:N]: x1_exp = torch.exp(x1_vals)
+    v_7 = libdevice.exp(v_0)
+    # src[swiglu.py:N]: x1_exp_plus1 = x1_exp + 1
+    v_8 = 1.0
+    v_9 = v_7 + v_8
+    # src[swiglu.py:N]: dextra = x1_exp / x1_exp_plus1 + x1_vals * x1_exp / x1_exp_plus1 / x1_exp_plus1
+    v_10 = v_7 / v_9
+    v_11 = v_0 * v_7
+    v_12 = v_11 / v_9
+    v_13 = v_12 / v_9
+    v_14 = v_10 + v_13
+    # src[swiglu.py:N]: dx1_vals = gout_vals * x2_vals * dextra
+    v_15 = v_1 * v_6
+    v_16 = v_15 * v_14
+    # src[swiglu.py:N]: dx1_flat[tile] = dx1_vals.to(x1.dtype)
+    v_17 = tl.cast(v_16, tl.bfloat16)
+    tl.store(dx1_flat + indices_0 * 1, v_17, None)
+
+def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor, *, _launcher=_default_launcher):
+    """
+    Implement the backward formula for swiglu.
+    """
+    # src[swiglu.py:N]: dx1 = torch.empty_like(x1)
+    dx1 = torch.empty_like(x1)
+    # src[swiglu.py:N]: dx2 = torch.empty_like(x2)
+    dx2 = torch.empty_like(x2)
+    # src[swiglu.py:N]: gout_flat = gout.view(-1)
+    gout_flat = gout.view(-1)
+    # src[swiglu.py:N]: x1_flat = x1.view(-1)
+    x1_flat = x1.view(-1)
+    # src[swiglu.py:N]: x2_flat = x2.view(-1)
+    x2_flat = x2.view(-1)
+    # src[swiglu.py:N]: dx1_flat = dx1.view(-1)
+    dx1_flat = dx1.view(-1)
+    # src[swiglu.py:N]: dx2_flat = dx2.view(-1)
+    dx2_flat = dx2.view(-1)
+    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
+    _BLOCK_SIZE_0 = 1024
+    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
+    # src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32)
+    # src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32)
+    # src[swiglu.py:N-N]: ...
+    _launcher(_helion_swiglu_bwd, (triton.cdiv(1024, _BLOCK_SIZE_0),), x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
+    # src[swiglu.py:N]: return dx1, dx2
+    return (dx1, dx2)
+
 --- assertExpectedJournal(TestExamples.test_template_via_closure0)
 from __future__ import annotations
 

test/test_examples.py

Lines changed: 29 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 from packaging import version
 import torch
+import torch.nn.functional as F
 
 import helion
 from helion import _compat
@@ -494,6 +495,33 @@ def test_rms_norm_fwd(self):
             )
         )
 
+    def test_swiglu_bwd(self):
+        """Test backward pass for swiglu."""
+        x1, x2 = [
+            torch.randn(1024, device=DEVICE, dtype=torch.bfloat16, requires_grad=True)
+            for _ in range(2)
+        ]
+
+        out = F.silu(x1) * x2
+
+        grad_out = torch.randn_like(out)
+        out.backward(grad_out)
+
+        args = (
+            grad_out,
+            x1,
+            x2,
+        )
+
+        self.assertExpectedJournal(
+            check_example(
+                "swiglu",
+                args,
+                (x1.grad, x2.grad),
+                fn_name="swiglu_bwd",
+            )
+        )
+
     def test_rms_norm_bwd(self):
         """Test backward pass for rms norm weight gradient."""
         batch_size, dim = 32, 64
@@ -1208,6 +1236,7 @@ def test_swiglu(self):
                 "swiglu",
                 args,
                 torch.nn.functional.silu(args[0]) * args[1],
+                fn_name="swiglu_fwd",
                 block_sizes=[16],
                 num_warps=4,
                 num_stages=3,
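
Outside the expected-journal harness, the autograd path added in examples/swiglu.py can be exercised end to end. The sketch below is a minimal sanity check, not part of the commit; it assumes a CUDA device and that examples.swiglu is importable from the repo root, and the tolerances are illustrative.

# Minimal sketch: compare the Helion-backed autograd path against eager PyTorch.
import torch
import torch.nn.functional as F

from examples.swiglu import swiglu  # autograd-enabled wrapper added in this commit

x1 = torch.randn(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
x2 = torch.randn(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = swiglu(x1, x2)          # forward runs swiglu_fwd
grad_out = torch.randn_like(out)
out.backward(grad_out)        # backward runs swiglu_bwd

# Eager reference on detached copies of the same inputs.
x1_ref = x1.detach().clone().requires_grad_(True)
x2_ref = x2.detach().clone().requires_grad_(True)
(F.silu(x1_ref) * x2_ref).backward(grad_out)

torch.testing.assert_close(x1.grad, x1_ref.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(x2.grad, x2_ref.grad, rtol=1e-2, atol=1e-2)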
