diff --git a/benchmarks/run.py b/benchmarks/run.py
index 639f4017d..f1379b125 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -128,6 +128,11 @@ class RunResult:
         "examples.swiglu",
         "swiglu_tritonbench",
     ),
+    "swiglu-bwd": (
+        "tritonbench.operators.swiglu.operator",
+        "examples.swiglu",
+        "swiglu_tritonbench",
+    ),
     "jsd": (
         "tritonbench.operators.jsd.operator",
         "examples.jsd",
@@ -440,6 +445,15 @@ class RunResult:
         "helion_swiglu_tritonbench-speedup": "helion_speedup",
         "helion_swiglu_tritonbench-accuracy": "helion_accuracy",
     },
+    "swiglu-bwd": {
+        "torch_swiglu": "baseline",
+        "liger_swiglu-speedup": "triton_speedup",
+        "liger_swiglu-accuracy": "triton_accuracy",
+        "torch_compile_swiglu-speedup": "torch_compile_speedup",
+        "torch_compile_swiglu-accuracy": "torch_compile_accuracy",
+        "helion_swiglu_tritonbench-speedup": "helion_speedup",
+        "helion_swiglu_tritonbench-accuracy": "helion_accuracy",
+    },
     "jsd": {
         "torch_jsd": "baseline",
         "liger_jsd-speedup": "triton_speedup",
diff --git a/examples/swiglu.py b/examples/swiglu.py
index c038304d5..f597d41ea 100644
--- a/examples/swiglu.py
+++ b/examples/swiglu.py
@@ -36,6 +36,7 @@
 
 if TYPE_CHECKING:
     from collections.abc import Callable
+    from typing import Any
 
 
 # %%
@@ -45,7 +46,7 @@
 
 # %%
 @helion.kernel()
-def swiglu(a: Tensor, b: Tensor) -> Tensor:
+def swiglu_fwd(a: Tensor, b: Tensor) -> Tensor:
     """
     Performs SwiGLU operation: SiLU(a) * b where SiLU is the Swish activation.
 
@@ -94,6 +95,65 @@ def swiglu(a: Tensor, b: Tensor) -> Tensor:
     return out
 
 
+@helion.kernel()
+def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor) -> tuple[Tensor, Tensor]:
+    """
+    Implement the backward formula for swiglu.
+    """
+    dx1 = torch.empty_like(x1)
+    dx2 = torch.empty_like(x2)
+
+    gout_flat = gout.view(-1)
+    x1_flat = x1.view(-1)
+    x2_flat = x2.view(-1)
+    dx1_flat = dx1.view(-1)
+    dx2_flat = dx2.view(-1)
+
+    for tile in hl.tile(x1.numel()):
+        x1_vals = x1_flat[tile].to(torch.float32)
+        gout_vals = gout_flat[tile].to(torch.float32)
+
+        # compute dx2
+        dx2_vals = x1_vals * torch.sigmoid(x1_vals) * gout_vals
+        dx2_flat[tile] = dx2_vals.to(x2.dtype)
+
+        # compute dx1
+        x2_vals = x2_flat[tile].to(torch.float32)
+        x1_exp = torch.exp(x1_vals)
+        x1_exp_plus1 = x1_exp + 1
+        dextra = x1_exp / x1_exp_plus1 + x1_vals * x1_exp / x1_exp_plus1 / x1_exp_plus1
+        dx1_vals = gout_vals * x2_vals * dextra
+        dx1_flat[tile] = dx1_vals.to(x1.dtype)
+
+    return dx1, dx2
+
+
+class SwigluFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        x1: Tensor,
+        x2: Tensor,
+    ) -> Tensor:
+        out = swiglu_fwd(x1, x2)
+        ctx.save_for_backward(x1, x2)
+        return out
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_out: Tensor,
+    ) -> tuple[Tensor, Tensor]:
+        x1, x2 = ctx.saved_tensors
+        dx1, dx2 = swiglu_bwd(grad_out, x1, x2)
+        return dx1, dx2
+
+
+def swiglu(a: Tensor, b: Tensor) -> Tensor:
+    """swiglu with forward + backward support."""
+    return SwigluFunction.apply(a, b)  # type: ignore[no-any-return]
+
+
 # %%
 # SwiGLU MLP Module (matches liger_kernel structure)
 # --------------------------------------------------
diff --git a/test/test_examples.expected b/test/test_examples.expected
index ac3c91cff..247e24655 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -5952,7 +5952,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_swiglu(a_flat, b_flat, out_flat, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_swiglu_fwd(a_flat, b_flat, out_flat, _BLOCK_SIZE_0: tl.constexpr):
     # src[swiglu.py:N]: for tile_idx in hl.tile(total_elements):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
@@ -5972,7 +5972,7 @@ def _helion_swiglu(a_flat, b_flat, out_flat, _BLOCK_SIZE_0: tl.constexpr):
     # src[swiglu.py:N]: out_flat[tile_idx] = result
     tl.store(out_flat + indices_0 * 1, v_4, None)
 
-def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
+def swiglu_fwd(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     """
     Performs SwiGLU operation: SiLU(a) * b where SiLU is the Swish activation.
 
@@ -6006,10 +6006,86 @@ def swiglu(a: Tensor, b: Tensor, *, _launcher=_default_launcher):
     # src[swiglu.py:N]: # Load input values and convert to float32 for computation
     # src[swiglu.py:N]: a_vals = a_flat[tile_idx].to(torch.float32)
     # src[swiglu.py:N-N]: ...
-    _launcher(_helion_swiglu, (triton.cdiv(1048576, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _launcher(_helion_swiglu_fwd, (triton.cdiv(1048576, _BLOCK_SIZE_0),), a_flat, b_flat, out_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     # src[swiglu.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestExamples.test_swiglu_bwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_swiglu_bwd(x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0: tl.constexpr):
+    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    # src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32)
+    load = tl.load(x1_flat + indices_0 * 1, None)
+    v_0 = tl.cast(load, tl.float32)
+    # src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32)
+    load_1 = tl.load(gout_flat + indices_0 * 1, None)
+    v_1 = tl.cast(load_1, tl.float32)
+    # src[swiglu.py:N]: dx2_vals = x1_vals * torch.sigmoid(x1_vals) * gout_vals
+    v_2 = tl.sigmoid(tl.cast(v_0, tl.float32))
+    v_3 = v_0 * v_2
+    v_4 = v_3 * v_1
+    # src[swiglu.py:N]: dx2_flat[tile] = dx2_vals.to(x2.dtype)
+    v_5 = tl.cast(v_4, tl.bfloat16)
+    tl.store(dx2_flat + indices_0 * 1, v_5, None)
+    # src[swiglu.py:N]: x2_vals = x2_flat[tile].to(torch.float32)
+    load_2 = tl.load(x2_flat + indices_0 * 1, None)
+    v_6 = tl.cast(load_2, tl.float32)
+    # src[swiglu.py:N]: x1_exp = torch.exp(x1_vals)
+    v_7 = libdevice.exp(v_0)
+    # src[swiglu.py:N]: x1_exp_plus1 = x1_exp + 1
+    v_8 = 1.0
+    v_9 = v_7 + v_8
+    # src[swiglu.py:N]: dextra = x1_exp / x1_exp_plus1 + x1_vals * x1_exp / x1_exp_plus1 / x1_exp_plus1
+    v_10 = v_7 / v_9
+    v_11 = v_0 * v_7
+    v_12 = v_11 / v_9
+    v_13 = v_12 / v_9
+    v_14 = v_10 + v_13
+    # src[swiglu.py:N]: dx1_vals = gout_vals * x2_vals * dextra
+    v_15 = v_1 * v_6
+    v_16 = v_15 * v_14
+    # src[swiglu.py:N]: dx1_flat[tile] = dx1_vals.to(x1.dtype)
+    v_17 = tl.cast(v_16, tl.bfloat16)
+    tl.store(dx1_flat + indices_0 * 1, v_17, None)
+
+def swiglu_bwd(gout: Tensor, x1: Tensor, x2: Tensor, *, _launcher=_default_launcher):
+    """
+    Implement the backward formula for swiglu.
+    """
+    # src[swiglu.py:N]: dx1 = torch.empty_like(x1)
+    dx1 = torch.empty_like(x1)
+    # src[swiglu.py:N]: dx2 = torch.empty_like(x2)
+    dx2 = torch.empty_like(x2)
+    # src[swiglu.py:N]: gout_flat = gout.view(-1)
+    gout_flat = gout.view(-1)
+    # src[swiglu.py:N]: x1_flat = x1.view(-1)
+    x1_flat = x1.view(-1)
+    # src[swiglu.py:N]: x2_flat = x2.view(-1)
+    x2_flat = x2.view(-1)
+    # src[swiglu.py:N]: dx1_flat = dx1.view(-1)
+    dx1_flat = dx1.view(-1)
+    # src[swiglu.py:N]: dx2_flat = dx2.view(-1)
+    dx2_flat = dx2.view(-1)
+    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
+    _BLOCK_SIZE_0 = 1024
+    # src[swiglu.py:N]: for tile in hl.tile(x1.numel()):
+    # src[swiglu.py:N]: x1_vals = x1_flat[tile].to(torch.float32)
+    # src[swiglu.py:N]: gout_vals = gout_flat[tile].to(torch.float32)
+    # src[swiglu.py:N-N]: ...
+    _launcher(_helion_swiglu_bwd, (triton.cdiv(1024, _BLOCK_SIZE_0),), x1_flat, gout_flat, dx2_flat, x2_flat, dx1_flat, _BLOCK_SIZE_0, num_warps=4, num_stages=2)
+    # src[swiglu.py:N]: return dx1, dx2
+    return (dx1, dx2)
+
 --- assertExpectedJournal(TestExamples.test_template_via_closure0)
 from __future__ import annotations
 
diff --git a/test/test_examples.py b/test/test_examples.py
index 6cda8b8a4..88420e5b4 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -5,6 +5,7 @@
 
 from packaging import version
 import torch
+import torch.nn.functional as F
 
 import helion
 from helion import _compat
@@ -494,6 +495,33 @@ def test_rms_norm_fwd(self):
             )
         )
 
+    def test_swiglu_bwd(self):
+        """Test backward pass for swiglu."""
+        x1, x2 = [
+            torch.randn(1024, device=DEVICE, dtype=torch.bfloat16, requires_grad=True)
+            for _ in range(2)
+        ]
+
+        out = F.silu(x1) * x2
+
+        grad_out = torch.randn_like(out)
+        out.backward(grad_out)
+
+        args = (
+            grad_out,
+            x1,
+            x2,
+        )
+
+        self.assertExpectedJournal(
+            check_example(
+                "swiglu",
+                args,
+                (x1.grad, x2.grad),
+                fn_name="swiglu_bwd",
+            )
+        )
+
     def test_rms_norm_bwd(self):
         """Test backward pass for rms norm weight gradient."""
         batch_size, dim = 32, 64
@@ -1208,6 +1236,7 @@ def test_swiglu(self):
                 "swiglu",
                 args,
                 torch.nn.functional.silu(args[0]) * args[1],
+                fn_name="swiglu_fwd",
                 block_sizes=[16],
                 num_warps=4,
                 num_stages=3,
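Note on the backward math: with out = SiLU(x1) * x2 and SiLU(x) = x * sigmoid(x), the gradients are dx2 = SiLU(x1) * gout and dx1 = gout * x2 * d/dx1[SiLU(x1)], where d/dx[x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) = e^x / (1 + e^x) + x * e^x / (1 + e^x)^2, which is exactly the `dextra` term in `swiglu_bwd`. Below is a minimal plain-PyTorch sketch that re-derives these formulas and checks them against autograd; it has no Helion dependency, and `swiglu_bwd_reference` is a hypothetical helper name used only for this illustration.

import torch
import torch.nn.functional as F

def swiglu_bwd_reference(gout, x1, x2):
    # dx2 = SiLU(x1) * gout
    dx2 = x1 * torch.sigmoid(x1) * gout
    # dx1 = gout * x2 * dSiLU/dx1, written with exp() as in examples/swiglu.py
    x1_exp = torch.exp(x1)
    x1_exp_plus1 = x1_exp + 1
    dextra = x1_exp / x1_exp_plus1 + x1 * x1_exp / x1_exp_plus1 / x1_exp_plus1
    dx1 = gout * x2 * dextra
    return dx1, dx2

x1 = torch.randn(1024, dtype=torch.float64, requires_grad=True)
x2 = torch.randn(1024, dtype=torch.float64, requires_grad=True)
out = F.silu(x1) * x2
gout = torch.randn_like(out)
out.backward(gout)

dx1, dx2 = swiglu_bwd_reference(gout, x1.detach(), x2.detach())
torch.testing.assert_close(dx1, x1.grad)
torch.testing.assert_close(dx2, x2.grad)

Using float64 here keeps the comparison against autograd tight; the kernel itself computes in float32 and casts back to the input dtype (bfloat16 in the test).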
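For completeness, a hedged usage sketch of the new autograd-enabled entry point. It assumes a CUDA device and that examples/swiglu.py is importable as examples.swiglu (the module path the benchmark mapping above uses); this is illustrative, not part of the diff.

import torch
from examples.swiglu import swiglu  # forward + backward wrapper added in this PR

x1 = torch.randn(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
x2 = torch.randn(1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = swiglu(x1, x2)       # forward runs the Helion swiglu_fwd kernel
out.sum().backward()       # backward dispatches to swiglu_bwd via SwigluFunction
print(x1.grad.shape, x2.grad.shape)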