
Commit 4955e98

mengluy0125 authored and facebook-github-bot committed
Add simplified se_block kernel (#989)
Summary: We add a Helion kernel to compute 2 * x * sigmoid(x @ w).

Differential Revision: D84968671
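
For context, a minimal usage sketch of the new example. The shapes, dtype, device, and import path below are illustrative assumptions; se_block and se_block_pytorch are the functions defined in examples/se_block.py in this commit.

    # Illustrative sketch: assumes a CUDA device and running from the repo root.
    import torch
    from examples.se_block import se_block, se_block_pytorch

    x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
    w = torch.randn(1024, 1024, device="cuda", dtype=torch.float16, requires_grad=True)

    out = se_block(x, w)            # Helion forward kernel via the autograd wrapper
    ref = se_block_pytorch(x, w)    # eager reference: 2 * x * sigmoid(x @ w)
    print(torch.allclose(out, ref, rtol=1e-2, atol=1e-2))

    out.sum().backward()            # dispatches to se_block_bwd_dx / se_block_bwd_dw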
1 parent c5e930b commit 4955e98

File tree

3 files changed: +496 -0 lines changed


examples/se_block.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
"""
Helion SE Block Example
============================
This example demonstrates a Helion kernel implementation of SE Block.
"""

# %%
from __future__ import annotations

import torch
from torch import Tensor

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl


# %%
@helion.kernel(
    # static_shapes=True gives a performance boost for matmuls
    static_shapes=True,
)
def se_block_fwd(x: Tensor, w: Tensor) -> tuple[Tensor, Tensor]:
    """
    Performs 2 * x * sigmoid(x @ w)
    Args:
        x: 2D tensor of shape [m, n].
        w: 2D tensor of shape [n, n].
    Returns:
        out: Resulting matrix of shape [m, n].
        s: sigmoid(x @ w) of shape [m, n].
    """
    m, n = x.size()

    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    s = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        for tile_n in hl.tile(n):
            s[tile_m, tile_n] = torch.sigmoid(x[tile_m, :] @ w[:, tile_n])
            acc = 2.0 * x[tile_m, tile_n] * s[tile_m, tile_n]
            out[tile_m, tile_n] = acc

    return out, s


# %%
@helion.kernel(static_shapes=True)
def se_block_bwd_dx(grad_out: Tensor, x: Tensor, w: Tensor, s: Tensor) -> Tensor:
    """
    Compute gradient for x.
    grad_x = 2 * grad_out * s + (2 * grad_out * x * s * (1 - s)) @ w.T

    Args:
        grad_out: Gradient w.r.t. output [m, n]
        x: Input tensor [m, n]
        w: Weight matrix [n, n]
        s: sigmoid(x @ w) from forward pass [m, n]

    Returns:
        grad_x: Gradient w.r.t. x [m, n]
    """
    m, n = x.size()

    grad_x = torch.empty([m, n], dtype=torch.float32, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        # 2 * grad_out * s
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        acc += 2.0 * grad_out[tile_m, tile_n] * s[tile_m, tile_n]

        for tile_k in hl.tile(n):
            # 2 * grad_out * x * s * (1-s) for tile_k
            grad_to_w = (
                2.0
                * grad_out[tile_m, tile_k].to(torch.float32)
                * x[tile_m, tile_k].to(torch.float32)
                * s[tile_m, tile_k].to(torch.float32)
                * (1.0 - s[tile_m, tile_k].to(torch.float32))
            )
            # grad_to_w @ w.T[tile_k, tile_n] = grad_to_w @ w[tile_n, tile_k].T
            acc += grad_to_w @ w[tile_n, tile_k].to(torch.float32).T

        grad_x[tile_m, tile_n] = acc.to(x.dtype)

    return grad_x


# %%
@helion.kernel(static_shapes=True)
def se_block_bwd_dw(grad_out: Tensor, x: Tensor, s: Tensor) -> Tensor:
    """
    Compute gradient for w.
    grad_w = x.T @ (2 * grad_out * x * s * (1 - s))

    Args:
        grad_out: Gradient w.r.t. output [m, n]
        x: Input tensor [m, n]
        s: sigmoid(x @ w) from forward pass [m, n]

    Returns:
        grad_w: Gradient w.r.t. w [n, n]
    """
    m, n = x.size()

    grad_w = torch.zeros([n, n], dtype=torch.float32, device=x.device)

    for tile_n1, tile_n2 in hl.tile([n, n]):
        acc_w = hl.zeros([tile_n1, tile_n2], dtype=torch.float32)
        for tile_m in hl.tile(m):
            # 2 * grad_out * x * s * (1-s)
            grad_to_w = (
                2.0
                * grad_out[tile_m, tile_n2].to(torch.float32)
                * x[tile_m, tile_n2].to(torch.float32)
                * s[tile_m, tile_n2].to(torch.float32)
                * (1.0 - s[tile_m, tile_n2].to(torch.float32))
            )
            # x[tile_m, tile_n1].T @ grad_to_w[tile_m, tile_n2]
            acc_w += x[tile_m, tile_n1].to(torch.float32).T @ grad_to_w

        grad_w[tile_n1, tile_n2] = acc_w.to(x.dtype)

    return grad_w


# %%
# Reference Implementation
# --------------------
def se_block_pytorch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """
    PyTorch reference implementation of se_block.

    Args:
        x, w: Input tensors

    Returns:
        tensor of 2 * x * sigmoid(x @ w)
    """
    return 2 * x * torch.sigmoid(x @ w)


# %%
# Autograd Function
# ------------------
class SEBlockFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx: object,
        x: torch.Tensor,
        w: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass for se block."""
        out, s = se_block_fwd(x, w)
        ctx.save_for_backward(x, w, s)  # type: ignore[attr-defined]
        return out

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: object,
        grad_out: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Backward pass for se block."""
        x, w, s = ctx.saved_tensors  # type: ignore[attr-defined]

        grad_x = se_block_bwd_dx(grad_out, x, w, s)
        grad_w = se_block_bwd_dw(grad_out, x, s)

        return grad_x, grad_w


def se_block(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """
    SE Block with autograd support.

    Args:
        x: Input tensor [m, n]
        w: Weight matrix [n, n]

    Returns:
        Output tensor [m, n]
    """
    return SEBlockFunction.apply(x, w)  # type: ignore[no-any-return]


def check(m: int, n: int) -> None:
    """
    Checks the correctness against PyTorch.
    Args:
        m (int): Number of rows in matrix x.
        n (int): Number of columns in matrix x.
    """
    x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
    w = torch.randn([n, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
    for bwd in [True, False]:
        run_example(se_block, se_block_pytorch, (x, w), bwd=bwd)


# %%
def main() -> None:
    """
    Main function to run correctness checks.
    """
    check(1024, 1024)


# %%
if __name__ == "__main__":
    main()

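For reference, the formulas implemented in se_block_bwd_dx and se_block_bwd_dw follow from applying the chain rule to out = 2 * x * s with s = sigmoid(x @ w); a brief derivation sketch (g denotes grad_out, $\odot$ the elementwise product):

$$
\begin{aligned}
s &= \sigma(x w), \qquad \frac{\partial s}{\partial (x w)} = s \odot (1 - s),\\
\mathrm{grad}_x &= 2\, g \odot s \;+\; \bigl(2\, g \odot x \odot s \odot (1 - s)\bigr)\, w^{\top},\\
\mathrm{grad}_w &= x^{\top} \bigl(2\, g \odot x \odot s \odot (1 - s)\bigr).
\end{aligned}
$$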