@@ -3707,7 +3707,7 @@ import triton.language as tl
37073707from helion.runtime import default_launcher as _default_launcher
37083708
37093709@triton.jit
3710- def _helion_swiglu (a_flat, b_flat, out_flat, a_flat_stride_0, b_flat_stride_0, out_flat_stride_0, total_elements, _BLOCK_SIZE_0: tl.constexpr):
3710+ def _helion_swiglu_fwd (a_flat, b_flat, out_flat, a_flat_stride_0, b_flat_stride_0, out_flat_stride_0, total_elements, _BLOCK_SIZE_0: tl.constexpr):
37113711 pid_0 = tl.program_id(0)
37123712 offset_0 = pid_0 * _BLOCK_SIZE_0
37133713 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
@@ -3721,7 +3721,7 @@ def _helion_swiglu(a_flat, b_flat, out_flat, a_flat_stride_0, b_flat_stride_0, o
37213721 v_4 = v_3 * b_vals
37223722 tl.store(out_flat + indices_0 * out_flat_stride_0, v_4, mask_0)
37233723
3724- def swiglu (a: Tensor, b: Tensor, *, _launcher=_default_launcher):
3724+ def swiglu_fwd (a: Tensor, b: Tensor, *, _launcher=_default_launcher):
37253725 """
37263726 Performs SwiGLU operation: SiLU(a) * b where SiLU is the Swish activation.
37273727
@@ -3742,9 +3742,63 @@ def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
37423742 b_flat = b.view(-1)
37433743 out_flat = out.view(-1)
37443744 _BLOCK_SIZE_0 = 16
3745- _launcher(_helion_swiglu , (triton.cdiv(total_elements, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, a_flat.stride(0), b_flat.stride(0), out_flat.stride(0), total_elements, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
3745+ _launcher(_helion_swiglu_fwd , (triton.cdiv(total_elements, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, a_flat.stride(0), b_flat.stride(0), out_flat.stride(0), total_elements, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
37463746 return out
37473747
3748+ --- assertExpectedJournal(TestExamples.test_swiglu_bwd)
3749+ from __future__ import annotations
3750+
3751+ import torch
3752+ import triton
3753+ import triton.language as tl
3754+ from torch._inductor.runtime.triton_compat import libdevice
3755+ from helion.runtime import default_launcher as _default_launcher
3756+
3757+ @triton.jit
3758+ def _helion_swiglu_bwd(x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, x1_size_0, dx1_flat_stride_0, dx2_flat_stride_0, gout_flat_stride_0, x1_flat_stride_0, x2_flat_stride_0, _BLOCK_SIZE_0: tl.constexpr):
3759+ pid_0 = tl.program_id(0)
3760+ offset_0 = pid_0 * _BLOCK_SIZE_0
3761+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
3762+ mask_0 = indices_0 < x1_size_0
3763+ load = tl.load(x1_flat + indices_0 * x1_flat_stride_0, mask_0, other=0)
3764+ v_0 = tl.cast(load, tl.float32)
3765+ load_1 = tl.load(gout_flat + indices_0 * gout_flat_stride_0, mask_0, other=0)
3766+ v_1 = tl.cast(load_1, tl.float32)
3767+ v_2 = tl.sigmoid(tl.cast(v_0, tl.float32))
3768+ v_3 = v_0 * v_2
3769+ v_4 = v_3 * v_1
3770+ v_5 = tl.cast(v_4, tl.bfloat16)
3771+ tl.store(dx2_flat + indices_0 * dx2_flat_stride_0, v_5, mask_0)
3772+ load_2 = tl.load(x2_flat + indices_0 * x2_flat_stride_0, mask_0, other=0)
3773+ v_6 = tl.cast(load_2, tl.float32)
3774+ v_7 = libdevice.exp(v_0)
3775+ v_8 = 1.0
3776+ v_9 = v_7 + v_8
3777+ v_10 = v_7 / v_9
3778+ v_11 = v_0 * v_7
3779+ v_12 = v_11 / v_9
3780+ v_13 = v_12 / v_9
3781+ v_14 = v_10 + v_13
3782+ v_15 = v_1 * v_6
3783+ v_16 = v_15 * v_14
3784+ v_17 = tl.cast(v_16, tl.bfloat16)
3785+ tl.store(dx1_flat + indices_0 * dx1_flat_stride_0, v_17, mask_0)
3786+
3787+ def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor, *, _launcher=_default_launcher):
3788+ """
3789+ Implement the backward formula for swiglu.
3790+ """
3791+ dx1 = torch.empty_like(x1)
3792+ dx2 = torch.empty_like(x2)
3793+ gout_flat = gout.view(-1)
3794+ x1_flat = x1.view(-1)
3795+ x2_flat = x2.view(-1)
3796+ dx1_flat = dx1.view(-1)
3797+ dx2_flat = dx2.view(-1)
3798+ _BLOCK_SIZE_0 = 1024
3799+ _launcher(_helion_swiglu_bwd, (triton.cdiv(x1.size(0), _BLOCK_SIZE_0),), x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, x1.size(0), dx1_flat.stride(0), dx2_flat.stride(0), gout_flat.stride(0), x1_flat.stride(0), x2_flat.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
3800+ return (dx1, dx2)
3801+
37483802--- assertExpectedJournal(TestExamples.test_template_via_closure0)
37493803from __future__ import annotations
37503804
0 commit comments