"""
Helion SE Block Example
============================
This example demonstrates a Helion kernel implementation of an SE block,
computing 2 * x * sigmoid(x @ w) with custom forward and backward kernels.
"""

# %%
from __future__ import annotations

import torch
from torch import Tensor

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl


# %%
@helion.kernel(
    # static_shapes=True gives a performance boost for matmuls
    static_shapes=True,
)
def se_block_fwd(x: Tensor, w: Tensor) -> tuple[Tensor, Tensor]:
    """
    Performs 2 * x * sigmoid(x @ w).
    Args:
        x: 2D tensor of shape [m, n].
        w: 2D tensor of shape [n, n].
    Returns:
        out: Resulting matrix of shape [m, n].
        s: sigmoid(x @ w) of shape [m, n].
    """
    m, n = x.size()

    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    s = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        for tile_n in hl.tile(n):
            # A row block of x times a column block of w, then the sigmoid gate.
            s[tile_m, tile_n] = torch.sigmoid(x[tile_m, :] @ w[:, tile_n])
            # Elementwise gating: 2 * x * sigmoid(x @ w) for this tile.
            acc = 2.0 * x[tile_m, tile_n] * s[tile_m, tile_n]
            out[tile_m, tile_n] = acc

    return out, s


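# %%
# Backward Derivation
# -------------------
# With s = sigmoid(x @ w) and out = 2 * x * s, the chain rule gives
#
#   d out / d x (direct term) = 2 * s
#   d s / d (x @ w)           = s * (1 - s)
#
# so for an upstream gradient g = grad_out:
#
#   grad_x = 2 * g * s + (2 * g * x * s * (1 - s)) @ w.T
#   grad_w = x.T @ (2 * g * x * s * (1 - s))
#
# The forward kernel returns s so the backward kernels below can reuse it
# instead of recomputing x @ w; each kernel evaluates one of these two
# expressions tile by tile.

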
# %%
@helion.kernel(static_shapes=True)
def se_block_bwd_dx(grad_out: Tensor, x: Tensor, w: Tensor, s: Tensor) -> Tensor:
    """
    Compute gradient for x.
    grad_x = 2 * grad_out * s + (2 * grad_out * x * s * (1 - s)) @ w.T

    Args:
        grad_out: Gradient w.r.t output [m, n]
        x: Input tensor [m, n]
        w: Weight matrix [n, n]
        s: sigmoid(x @ w) from forward pass [m, n]

    Returns:
        grad_x: Gradient w.r.t x [m, n]
    """
    m, n = x.size()

    grad_x = torch.empty([m, n], dtype=torch.float32, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        # 2 * grad_out * s
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        acc += 2.0 * grad_out[tile_m, tile_n] * s[tile_m, tile_n]

        for tile_k in hl.tile(n):
            # 2 * grad_out * x * s * (1 - s) for tile_k
            grad_to_w = (
                2.0
                * grad_out[tile_m, tile_k].to(torch.float32)
                * x[tile_m, tile_k].to(torch.float32)
                * s[tile_m, tile_k].to(torch.float32)
                * (1.0 - s[tile_m, tile_k].to(torch.float32))
            )
            # grad_to_w @ w.T[tile_k, tile_n] = grad_to_w @ w[tile_n, tile_k].T
            acc += grad_to_w @ w[tile_n, tile_k].to(torch.float32).T

        grad_x[tile_m, tile_n] = acc.to(x.dtype)

    return grad_x


# %%
@helion.kernel(static_shapes=True)
def se_block_bwd_dw(grad_out: Tensor, x: Tensor, s: Tensor) -> Tensor:
    """
    Compute gradient for w.
    grad_w = x.T @ (2 * grad_out * x * s * (1 - s))

    Args:
        grad_out: Gradient w.r.t output [m, n]
        x: Input tensor [m, n]
        s: sigmoid(x @ w) from forward pass [m, n]

    Returns:
        grad_w: Gradient w.r.t w [n, n]
    """
    m, n = x.size()

    grad_w = torch.zeros([n, n], dtype=torch.float32, device=x.device)

    for tile_n1, tile_n2 in hl.tile([n, n]):
        acc_w = hl.zeros([tile_n1, tile_n2], dtype=torch.float32)
        for tile_m in hl.tile(m):
            # 2 * grad_out * x * s * (1 - s)
            grad_to_w = (
                2.0
                * grad_out[tile_m, tile_n2].to(torch.float32)
                * x[tile_m, tile_n2].to(torch.float32)
                * s[tile_m, tile_n2].to(torch.float32)
                * (1.0 - s[tile_m, tile_n2].to(torch.float32))
            )
            # x[tile_m, tile_n1].T @ grad_to_w[tile_m, tile_n2]
            acc_w += x[tile_m, tile_n1].to(torch.float32).T @ grad_to_w

        grad_w[tile_n1, tile_n2] = acc_w.to(x.dtype)

    return grad_w


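# %%
# Numerical note: both backward kernels keep their tile accumulators in float32
# (hl.zeros(..., dtype=torch.float32)) and upcast the float16 operands before
# the elementwise products and matmuls, which keeps the reductions over tiles
# numerically stable; results are cast back with .to(x.dtype) only at the final
# store.

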
# %%
# Reference Implementation
# ------------------------
def se_block_pytorch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """
    PyTorch reference implementation of the SE block.

    Args:
        x, w: Input tensors

    Returns:
        Tensor of 2 * x * sigmoid(x @ w)
    """
    return 2 * x * torch.sigmoid(x @ w)


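# %%
# As a quick manual sanity check (independent of run_example below), the
# backward kernels can be compared against autograd applied to this reference,
# e.g. with hypothetical float32 copies x32 / w32 of the inputs:
#
#     x32 = x.detach().float().requires_grad_(True)
#     w32 = w.detach().float().requires_grad_(True)
#     se_block_pytorch(x32, w32).sum().backward()
#     # x32.grad / w32.grad give reference gradients to compare the Helion
#     # kernels against, up to float16 rounding.

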
# %%
# Autograd Function
# ------------------
class SEBlockFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx: object,
        x: torch.Tensor,
        w: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass for the SE block."""
        out, s = se_block_fwd(x, w)
        ctx.save_for_backward(x, w, s)  # type: ignore[attr-defined]
        return out

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: object,
        grad_out: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Backward pass for the SE block."""
        x, w, s = ctx.saved_tensors  # type: ignore[attr-defined]

        grad_x = se_block_bwd_dx(grad_out, x, w, s)
        grad_w = se_block_bwd_dw(grad_out, x, s)

        return grad_x, grad_w


def se_block(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """
    SE block with autograd support.

    Args:
        x: Input tensor [m, n]
        w: Weight matrix [n, n]

    Returns:
        Output tensor [m, n]
    """
    return SEBlockFunction.apply(x, w)  # type: ignore[no-any-return]

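
# %%
# Usage sketch: se_block behaves like any differentiable PyTorch op. The shapes
# and dtype below mirror check() / main(); DEVICE comes from helion._testing.
#
#     x = torch.randn(1024, 1024, device=DEVICE, dtype=torch.float16, requires_grad=True)
#     w = torch.randn(1024, 1024, device=DEVICE, dtype=torch.float16, requires_grad=True)
#     out = se_block(x, w)
#     out.sum().backward()  # dispatches to se_block_bwd_dx / se_block_bwd_dw
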

def check(m: int, n: int) -> None:
    """
    Checks correctness against the PyTorch reference implementation.
    Args:
        m (int): Number of rows in matrix x.
        n (int): Number of columns in matrix x.
    """
    x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
    w = torch.randn([n, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
    for bwd in [True, False]:
        run_example(se_block, se_block_pytorch, (x, w), bwd=bwd)


# %%
def main() -> None:
    """
    Main function to run correctness checks.
    """
    check(1024, 1024)


# %%
if __name__ == "__main__":
    main()