-
Couldn't load subscription status.
- Fork 53
Add simplified se_block kernel #989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mengluy0125
wants to merge
1
commit into
pytorch:main
Choose a base branch
from
mengluy0125:export-D84968671
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,213 @@ | ||
| """ | ||
| Helion SE Block Example | ||
| ============================ | ||
| This example demonstrates a Helion kernel implementation of SE Block. | ||
| """ | ||
|
|
||
| # %% | ||
| from __future__ import annotations | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
| import helion | ||
| from helion._testing import DEVICE | ||
| from helion._testing import run_example | ||
| import helion.language as hl | ||
|
|
||
|
|
||
| # %% | ||
| @helion.kernel( | ||
| # static_shapes=True gives a performance boost for matmuls | ||
| static_shapes=True, | ||
| ) | ||
| def se_block_fwd(x: Tensor, w: Tensor) -> tuple[Tensor, Tensor]: | ||
| """ | ||
| Performs 2 * x * sigmoid(x @ w) | ||
| Args: | ||
| x: 2D tensor of shape [m, n]. | ||
| w: 2D tensor of shape [n, n]. | ||
| Returns: | ||
| out: Resulting matrix of shape [m, n]. | ||
| s: sigmoid(x @ w) of shape [m, n]. | ||
| """ | ||
| m, n = x.size() | ||
|
|
||
| out = torch.empty([m, n], dtype=x.dtype, device=x.device) | ||
| s = torch.empty([m, n], dtype=x.dtype, device=x.device) | ||
|
|
||
| for tile_m in hl.tile(m): | ||
| for tile_n in hl.tile(n): | ||
| # Compute sigmoid in float32 | ||
| sigmoid_result = torch.sigmoid(x[tile_m, :] @ w[:, tile_n]) | ||
| s[tile_m, tile_n] = sigmoid_result | ||
| # Compute output: 2 * x * sigmoid, cast to input dtype | ||
| acc = 2.0 * x[tile_m, tile_n].to(torch.float32) * sigmoid_result | ||
| out[tile_m, tile_n] = acc.to(x.dtype) | ||
|
|
||
| return out, s | ||
|
|
||
|
|
||
| # %% | ||
| @helion.kernel(static_shapes=True) | ||
| def se_block_bwd_dx(grad_out: Tensor, x: Tensor, w: Tensor, s: Tensor) -> Tensor: | ||
| """ | ||
| Compute gradient for x. | ||
| grad_x = 2 * grad_out * s + (2 * grad_out * x * s * (1 - s)) @ w.T | ||
|
|
||
| Args: | ||
| grad_out: Gradient w.r.t output [m, n] | ||
| x: Input tensor [m, n] | ||
| w: Weight matrix [n, n] | ||
| s: sigmoid(x @ w) from forward pass [m, n] | ||
|
|
||
| Returns: | ||
| grad_x: Gradient w.r.t x [m, n] | ||
| """ | ||
| m, n = x.size() | ||
|
|
||
| grad_x = torch.empty([m, n], dtype=torch.float32, device=x.device) | ||
|
|
||
| for tile_m, tile_n in hl.tile([m, n]): | ||
| # 2 * grad_out * s | ||
| acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) | ||
| acc += 2.0 * grad_out[tile_m, tile_n] * s[tile_m, tile_n] | ||
|
|
||
| for tile_k in hl.tile(n): | ||
| # 2 * grad_out * x * s * (1-s) for tile_k | ||
| grad_to_w = ( | ||
| 2.0 | ||
| * grad_out[tile_m, tile_k].to(torch.float32) | ||
| * x[tile_m, tile_k].to(torch.float32) | ||
| * s[tile_m, tile_k].to(torch.float32) | ||
| * (1.0 - s[tile_m, tile_k].to(torch.float32)) | ||
| ) | ||
| # grad_to_w @ w.T[tile_k, tile_n] = grad_to_w @ w[tile_n, tile_k].T | ||
| acc += grad_to_w @ w[tile_n, tile_k].to(torch.float32).T | ||
|
|
||
| grad_x[tile_m, tile_n] = acc.to(x.dtype) | ||
|
|
||
| return grad_x | ||
|
|
||
|
|
||
| # %% | ||
| @helion.kernel(static_shapes=True) | ||
| def se_block_bwd_dw(grad_out: Tensor, x: Tensor, s: Tensor) -> Tensor: | ||
| """ | ||
| Compute gradient for w. | ||
| grad_w = x.T @ (2 * grad_out * x * s * (1 - s)) | ||
|
|
||
| Args: | ||
| grad_out: Gradient w.r.t output [m, n] | ||
| x: Input tensor [m, n] | ||
| s: sigmoid(x @ w) from forward pass [m, n] | ||
|
|
||
| Returns: | ||
| grad_w: Gradient w.r.t w [n, n] | ||
| """ | ||
| m, n = x.size() | ||
|
|
||
| grad_w = torch.zeros([n, n], dtype=torch.float32, device=x.device) | ||
|
|
||
| for tile_n1, tile_n2 in hl.tile([n, n]): | ||
| acc_w = hl.zeros([tile_n1, tile_n2], dtype=torch.float32) | ||
| for tile_m in hl.tile(m): | ||
| # 2 * grad_out * x * s * (1-s) | ||
| grad_to_w = ( | ||
| 2.0 | ||
| * grad_out[tile_m, tile_n2].to(torch.float32) | ||
| * x[tile_m, tile_n2].to(torch.float32) | ||
| * s[tile_m, tile_n2].to(torch.float32) | ||
| * (1.0 - s[tile_m, tile_n2].to(torch.float32)) | ||
| ) | ||
| # x[tile_m, tile_n1].T @ grad_to_w[tile_m, tile_n2] | ||
| acc_w += x[tile_m, tile_n1].to(torch.float32).T @ grad_to_w | ||
|
|
||
| grad_w[tile_n1, tile_n2] = acc_w.to(x.dtype) | ||
|
|
||
| return grad_w | ||
|
|
||
|
|
||
| # %% | ||
| # Reference Implementation | ||
| # -------------------- | ||
| def se_block_pytorch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| PyTorch reference implementation se_block. | ||
|
|
||
| Args: | ||
| x, w: Input tensors | ||
|
|
||
| Returns: | ||
| tensor of 2 * x * sigmoid(x @ w) | ||
| """ | ||
| return 2 * x * torch.sigmoid(x @ w) | ||
|
|
||
|
|
||
| # %% | ||
| # Autograd Function | ||
| # ------------------ | ||
| class SEBlockFunction(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward( # type: ignore[override] | ||
| ctx: object, | ||
| x: torch.Tensor, | ||
| w: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """Forward pass for se block.""" | ||
| out, s = se_block_fwd(x, w) | ||
| ctx.save_for_backward(x, w, s) # type: ignore[attr-defined] | ||
| return out | ||
|
|
||
| @staticmethod | ||
| def backward( # type: ignore[override] | ||
| ctx: object, | ||
| grad_out: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Backward pass for se block.""" | ||
| x, w, s = ctx.saved_tensors # type: ignore[attr-defined] | ||
|
|
||
| grad_x = se_block_bwd_dx(grad_out, x, w, s) | ||
| grad_w = se_block_bwd_dw(grad_out, x, s) | ||
|
|
||
| return grad_x, grad_w | ||
|
|
||
|
|
||
| def se_block(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| SE Block with autograd support. | ||
|
|
||
| Args: | ||
| x: Input tensor [m, n] | ||
| w: Weight matrix [n, n] | ||
|
|
||
| Returns: | ||
| Output tensor [m, n] | ||
| """ | ||
| return SEBlockFunction.apply(x, w) # type: ignore[no-any-return] | ||
|
|
||
|
|
||
| def check(m: int, n: int) -> None: | ||
| """ | ||
| Checks the correctness against PyTorch. | ||
| Args: | ||
| m (int): Number of rows in matrix x. | ||
| n (int): Number of columns in matrix x. | ||
| """ | ||
| x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True) | ||
| w = torch.randn([n, n], device=DEVICE, dtype=torch.float16, requires_grad=True) | ||
| for bwd in [True, False]: | ||
| run_example(se_block, se_block_pytorch, (x, w), bwd=bwd) | ||
|
|
||
|
|
||
| # %% | ||
| def main() -> None: | ||
| """ | ||
| Main function to run correctness checks. | ||
| """ | ||
| check(1024, 1024) | ||
|
|
||
|
|
||
| # %% | ||
| if __name__ == "__main__": | ||
| main() | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious what is "SE" short for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Squeeze and Excitation Net. Basically it performs excitation on embedding, similar as Squeeze and Excitation Net as those used in https://arxiv.org/abs/1709.01507, the idea is to enhance signal/noise ratio to preserve useful information.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks! maybe we can add this explanation to this docstring as well to help clarify