Skip to content

Commit 5b126e4

Browse files
authored
Adding demo for running softmax kernel on Google colab (#944)
1 parent 9a30bd1 commit 5b126e4

File tree

5 files changed

+248
-4
lines changed

5 files changed

+248
-4
lines changed

.github/workflows/test.yml

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,68 @@ jobs:
134134
# -rf: print failed tests
135135
# --timeout: max allowed time for each test
136136
pytest -rf --timeout=60
137+
138+
test-notebooks:
139+
name: test-notebooks-cu128-py3.12-pytorch-2.9-a10g
140+
141+
container:
142+
image: nvidia/cuda:12.8.1-devel-ubuntu24.04
143+
options: --gpus all
144+
145+
runs-on: linux.g5.4xlarge.nvidia.gpu
146+
147+
defaults:
148+
run:
149+
shell: bash -l {0}
150+
151+
steps:
152+
- name: Run NVIDIA command
153+
run: |
154+
echo "Detected NVIDIA image"
155+
nvidia-smi || echo "nvidia-smi not found"
156+
157+
- name: Check out code
158+
uses: actions/checkout@v5
159+
160+
- name: Install uv
161+
uses: astral-sh/setup-uv@v7
162+
with:
163+
python-version: "3.12"
164+
enable-cache: true
165+
166+
- name: Create virtual environment
167+
run: |
168+
uv venv --python 3.12
169+
170+
- name: Install pip in venv
171+
run: |
172+
source .venv/bin/activate
173+
uv pip install pip
174+
175+
- name: Get current month
176+
id: date
177+
run: echo "month=$(date +'%Y-%m')" >> $GITHUB_OUTPUT
178+
179+
- name: Cache dependencies
180+
id: cache
181+
uses: actions/cache@v4
182+
with:
183+
path: |
184+
~/.cache/uv
185+
~/.venv
186+
key: notebooks-3.12-cu128-${{ hashFiles('.github/workflows/test.yml') }}-${{ steps.date.outputs.month }}
187+
188+
- name: Install notebook execution tools
189+
run: |
190+
source .venv/bin/activate
191+
# Install jupyter for executing notebooks
192+
uv pip install jupyter nbconvert pytest numpy
193+
194+
- name: Run Notebook Tests
195+
run: |
196+
source .venv/bin/activate
197+
# Execute notebook using jupyter nbconvert
198+
# The notebook's subprocess pip install will install torch and helion
199+
jupyter nbconvert --to notebook --execute --inplace \
200+
--ExecutePreprocessor.timeout=600 \
201+
notebooks/softmax.ipynb

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
# About
1313

14-
📚 **[View Documentation](https://helionlang.com)** 📚 | 🎥 **[Watch Talk](https://youtu.be/MBOPzfl1JBo?si=DwAhgL-bpH1kFSt3)** 🎥
14+
📚 **[View Documentation](https://helionlang.com)** 📚 | 🎥 **[Watch Talk](https://youtu.be/MBOPzfl1JBo?si=DwAhgL-bpH1kFSt3)** 🎥 | 🚀 **[Try In Colab](https://colab.research.google.com/github/pytorch/helion/blob/main/notebooks/softmax.ipynb)** 🚀
1515

1616
**Helion** is a Python-embedded domain-specific language (DSL) for
1717
authoring machine learning kernels, designed to compile down to [Triton],
@@ -66,7 +66,6 @@ portable between different hardware. Helion automates and autotunes over:
6666
* Persistent kernel strategies.
6767
* Warp specialization choices, unrolling, and more.
6868

69-
7069
## Example
7170

7271
A minimal matrix multiplication kernel in Helion looks like this:

docs/index.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ portable between different hardware. Helion automates and autotunes over:
6666
* Persistent kernel strategies.
6767
* Warp specialization choices, unrolling, and more.
6868

69+
## Try Helion Now
70+
71+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pytorch/helion/blob/main/notebooks/softmax.ipynb)
72+
73+
Try our [interactive demo notebook](https://github.com/pytorch/helion/blob/main/notebooks/softmax.ipynb) to see Helion in action! The notebook demonstrates softmax kernel implementations and runs directly in Google Colab on a GPU.
74+
6975
## Example
7076

7177
A minimal matrix multiplication kernel in Helion looks like this:

notebooks/softmax.ipynb

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"%pip install \"torch==2.9.*\" --index-url https://download.pytorch.org/whl/cu126\n",
10+
"%pip install helion\n"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": null,
16+
"metadata": {},
17+
"outputs": [],
18+
"source": [
19+
"\"\"\"\n",
20+
"Helion Softmax Kernel Examples\n",
21+
"==============================\n",
22+
"This example demonstrates multiple Helion kernel implementations of the softmax function,\n",
23+
"including a simple wrapper around PyTorch's softmax, and a numerically optimized two-pass version.\n",
24+
"The example also includes a check function to compare these kernels against PyTorch's\n",
25+
"built-in softmax for correctness.\n",
26+
"\"\"\"\n",
27+
"\n",
28+
"# %%\n",
29+
"from __future__ import annotations\n",
30+
"import torch\n",
31+
"import helion\n",
32+
"from helion._testing import run_example\n",
33+
"import helion.language as hl\n",
34+
"\n",
35+
"\n",
36+
"# %%\n",
37+
"@helion.kernel(autotune_effort=\"quick\")\n",
38+
"def softmax(x: torch.Tensor) -> torch.Tensor:\n",
39+
" \"\"\"\n",
40+
" Simple Helion kernel wrapping PyTorch's softmax function.\n",
41+
" Args:\n",
42+
" x (torch.Tensor): Input tensor of shape [n, m].\n",
43+
" Returns:\n",
44+
" torch.Tensor: Softmax output tensor of the same shape.\n",
45+
" \"\"\"\n",
46+
" n, _m = x.size()\n",
47+
" out = torch.empty_like(x)\n",
48+
" for tile_n in hl.tile(n):\n",
49+
" out[tile_n, :] = torch.nn.functional.softmax(x[tile_n, :], dim=1)\n",
50+
" return out\n",
51+
"\n",
52+
"\n",
53+
"# %%\n",
54+
"def check(m: int, n: int) -> None:\n",
55+
" \"\"\"\n",
56+
" Runs correctness checks comparing Helion softmax kernels against PyTorch's softmax.\n",
57+
" Args:\n",
58+
" m (int): Number of rows in input tensor.\n",
59+
" n (int): Number of columns in input tensor.\n",
60+
" \"\"\"\n",
61+
" x = torch.randn([m, n], device=\"cuda\", dtype=torch.float16)\n",
62+
" run_example(softmax, lambda x: torch.nn.functional.softmax(x, dim=1), (x,))\n",
63+
"\n",
64+
"\n",
65+
"# %%\n",
66+
"def main() -> None:\n",
67+
" \"\"\"\n",
68+
" Main function to run the softmax kernel correctness check with example input size.\n",
69+
" \"\"\"\n",
70+
" check(4096, 2560)\n",
71+
"\n",
72+
"\n",
73+
"# %%\n",
74+
"if __name__ == \"__main__\":\n",
75+
" main()\n"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"metadata": {},
82+
"outputs": [],
83+
"source": [
84+
"\"\"\"\n",
85+
"Helion Softmax Kernel Examples\n",
86+
"==============================\n",
87+
"This example demonstrates multiple Helion kernel implementations of the softmax function,\n",
88+
"including a simple wrapper around PyTorch's softmax, and a numerically optimized two-pass version.\n",
89+
"The example also includes a check function to compare these kernels against PyTorch's\n",
90+
"built-in softmax for correctness.\n",
91+
"\"\"\"\n",
92+
"\n",
93+
"# %%\n",
94+
"from __future__ import annotations\n",
95+
"import torch\n",
96+
"import helion\n",
97+
"from helion._testing import run_example\n",
98+
"import helion.language as hl\n",
99+
"\n",
100+
"\n",
101+
"# %%\n",
102+
"@helion.kernel(autotune_effort=\"quick\")\n",
103+
"def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:\n",
104+
" \"\"\"\n",
105+
" Numerically optimized Helion kernel performing softmax in two passes.\n",
106+
" Args:\n",
107+
" x (torch.Tensor): Input tensor of shape [m, n].\n",
108+
" Returns:\n",
109+
" torch.Tensor: Softmax output tensor of the same shape.\n",
110+
" \"\"\"\n",
111+
" m, n = x.size()\n",
112+
" out = torch.empty_like(x)\n",
113+
" block_size_m = hl.register_block_size(m)\n",
114+
" block_size_n = hl.register_block_size(n)\n",
115+
" for tile_m in hl.tile(m, block_size=block_size_m):\n",
116+
" mi = hl.full([tile_m], float(\"-inf\"), dtype=torch.float32)\n",
117+
" di = hl.zeros([tile_m], dtype=torch.float32)\n",
118+
" for tile_n in hl.tile(n, block_size=block_size_n):\n",
119+
" values = x[tile_m, tile_n]\n",
120+
" local_amax = torch.amax(values, dim=1)\n",
121+
" mi_next = torch.maximum(mi, local_amax)\n",
122+
" di = di * torch.exp(mi - mi_next) + torch.exp(\n",
123+
" values - mi_next[:, None]\n",
124+
" ).sum(dim=1)\n",
125+
" mi = mi_next\n",
126+
" for tile_n in hl.tile(n, block_size=block_size_n):\n",
127+
" values = x[tile_m, tile_n]\n",
128+
" out[tile_m, tile_n] = torch.exp(values - mi[:, None]) / di[:, None]\n",
129+
" return out\n",
130+
"\n",
131+
"\n",
132+
"# %%\n",
133+
"def check(m: int, n: int) -> None:\n",
134+
" \"\"\"\n",
135+
" Runs correctness checks comparing Helion softmax kernels against PyTorch's softmax.\n",
136+
" Args:\n",
137+
" m (int): Number of rows in input tensor.\n",
138+
" n (int): Number of columns in input tensor.\n",
139+
" \"\"\"\n",
140+
" x = torch.randn([m, n], device=\"cuda\", dtype=torch.float16)\n",
141+
" run_example(softmax_two_pass, lambda x: torch.nn.functional.softmax(x, dim=1), (x,))\n",
142+
"\n",
143+
"\n",
144+
"# %%\n",
145+
"def main() -> None:\n",
146+
" \"\"\"\n",
147+
" Main function to run the softmax kernel correctness check with example input size.\n",
148+
" \"\"\"\n",
149+
" check(4096, 2560)\n",
150+
"\n",
151+
"\n",
152+
"# %%\n",
153+
"if __name__ == \"__main__\":\n",
154+
" main()\n"
155+
]
156+
}
157+
],
158+
"metadata": {
159+
"accelerator": "GPU",
160+
"colab": {
161+
"gpuType": "T4",
162+
"provenance": []
163+
},
164+
"kernelspec": {
165+
"display_name": "Python 3",
166+
"name": "python3"
167+
},
168+
"language_info": {
169+
"name": "python"
170+
}
171+
},
172+
"nbformat": 4,
173+
"nbformat_minor": 0
174+
}

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ src = ["helion"]
4343
docstring-code-format = true
4444
quote-style = "double"
4545
line-ending = "lf"
46-
exclude = [".github/*"]
46+
exclude = [".github/*", "notebooks/**/*.ipynb"]
4747

4848
[tool.ruff.lint]
4949
select = [
@@ -64,7 +64,7 @@ ignore = [
6464
]
6565
extend-safe-fixes = ["TC", "UP045", "RUF013", "RSE102"]
6666
preview = true
67-
exclude = ["test/data/*", ".github/*"]
67+
exclude = ["test/data/*", ".github/*", "notebooks/**/*.ipynb"]
6868

6969
[tool.ruff.lint.per-file-ignores]
7070
"test/*" = ["ANN"]

0 commit comments

Comments
 (0)