
Commit f1c0b89

Generalize examples with the DEVICE variable (#915)

1 parent: 3f6d43d

38 files changed: +185 -135 lines
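The change swaps hard-coded device="cuda" arguments throughout the examples for a shared DEVICE constant imported from helion._testing, so the same scripts can run on whichever accelerator the test harness selects. The commit does not show that helper's definition; below is a minimal sketch of what such a constant could look like, assuming a CUDA-then-XPU-then-CPU preference (the actual helion implementation may differ):

import torch

def _pick_device() -> torch.device:
    # Prefer CUDA, then Intel XPU, then fall back to CPU so the
    # examples still import and run everywhere.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")

DEVICE = _pick_device()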

examples/add.py

Lines changed: 3 additions & 2 deletions

@@ -15,6 +15,7 @@
 import torch

 import helion
+from helion._testing import DEVICE
 from helion._testing import run_example
 import helion.language as hl

@@ -64,8 +65,8 @@ def check(m: int, n: int) -> None:
         m: First dimension of the test tensors
         n: Second dimension of the test tensors
     """
-    x = torch.randn([m, n], device="cuda", dtype=torch.float16)
-    y = torch.randn([m, n], device="cuda", dtype=torch.float16)
+    x = torch.randn([m, n], device=DEVICE, dtype=torch.float16)
+    y = torch.randn([m, n], device=DEVICE, dtype=torch.float16)
     run_example(add, torch.add, (x, y))

examples/all_gather_matmul.py

Lines changed: 4 additions & 1 deletion

@@ -20,6 +20,7 @@
 import torch.distributed._symmetric_memory as symm_mem

 import helion
+from helion._testing import DEVICE
 import helion.language as hl


@@ -201,7 +202,7 @@ def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
     a_shared = symm_mem.empty(
         M // world_size, K, dtype=torch.bfloat16, device=device
     ).normal_()
-    b = torch.randn((K, N), device="cuda", dtype=torch.bfloat16).T.contiguous().T
+    b = torch.randn((K, N), device=DEVICE, dtype=torch.bfloat16).T.contiguous().T
     a_out, c = helion_all_gather_matmul(a_shared, b)
     golden_a = a_shared.clone()
     dist_group = dist.group.WORLD
@@ -239,4 +240,6 @@ def main() -> None:
     --rdzv-backend c10d --rdzv-endpoint localhost:0 \
     --no_python python3 examples/all_gather_matmul.py
     """
+    # TODO(adam-smnk): generalize to XPU
+    assert DEVICE.type == "cuda", "Requires CUDA device"
     main()

examples/all_reduce.py

Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,7 @@
 from torch.utils.cpp_extension import load_inline

 import helion
+from helion._testing import DEVICE
 import helion.language as hl

 # %%
@@ -273,4 +274,6 @@ def main() -> None:
     --rdzv-backend c10d --rdzv-endpoint localhost:0 \
     --no_python python3 examples/all_reduce.py
     """
+    # TODO(adam-smnk): generalize to XPU
+    assert DEVICE.type == "cuda", "Requires CUDA device"
     main()
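Both distributed examples gain the same two lines because torch's symmetric-memory collectives currently require CUDA. As a sketch, the guarded entry point ends up shaped like this (main() stands in for each example's own driver):

from helion._testing import DEVICE

if __name__ == "__main__":
    # Fail fast on non-CUDA backends rather than crashing inside the
    # symmetric-memory collectives; see the TODO about XPU support.
    assert DEVICE.type == "cuda", "Requires CUDA device"
    main()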

examples/attention.py

Lines changed: 2 additions & 1 deletion

@@ -21,6 +21,7 @@
 from torch.nn.attention.flex_attention import flex_attention

 import helion
+from helion._testing import DEVICE
 from helion._testing import run_example
 import helion.language as hl

@@ -165,7 +166,7 @@ def main() -> None:
     Main entry point that runs the attention kernel test with specific parameters.
     Tests with batch size 2, 32 heads, 1024 sequence length, and 64-dimensional heads using float16.
     """
-    test(2, 32, 1024, 64, torch.float16)
+    test(2, 32, 1024, 64, torch.float16, device=DEVICE)


 if __name__ == "__main__":

examples/bf16xint16_gemm.py

Lines changed: 5 additions & 4 deletions

@@ -14,6 +14,7 @@
 from torch import Tensor

 import helion
+from helion._testing import DEVICE
 import helion.language as hl


@@ -137,17 +138,17 @@ def check(m: int, k: int, n: int) -> None:
         k (int): Shared dimension.
         n (int): Number of cols.
     """
-    x = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
-    w = torch.randint(-(2**15), 2**15 - 1, (k, n), device="cuda", dtype=torch.int16)
+    x = torch.randn([m, k], device=DEVICE, dtype=torch.bfloat16)
+    w = torch.randint(-(2**15), 2**15 - 1, (k, n), device=DEVICE, dtype=torch.int16)

     result = bf16xint16_gemm(x, w, transpose=False)
     expected = reference_bf16xint16_pytorch(x, w, transpose=False)
     torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)

     x_int16 = torch.randint(
-        -(2**15), 2**15 - 1, (m, k), device="cuda", dtype=torch.int16
+        -(2**15), 2**15 - 1, (m, k), device=DEVICE, dtype=torch.int16
     )
-    w_bf16 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+    w_bf16 = torch.randn([k, n], device=DEVICE, dtype=torch.bfloat16)

     result = bf16xint16_gemm(x_int16, w_bf16, transpose=True)
     expected = reference_bf16xint16_pytorch(x_int16, w_bf16, transpose=True)

examples/bmm.py

Lines changed: 3 additions & 2 deletions

@@ -16,6 +16,7 @@
 import torch

 import helion
+from helion._testing import DEVICE
 from helion._testing import run_example
 import helion.language as hl

@@ -70,8 +71,8 @@ def check(b: int, m: int, k: int, n: int) -> None:
         k: Second dimension of the first matrix / First dimension of the second matrix
         n: Second dimension of the second matrix
     """
-    x = torch.randn([b, m, k], device="cuda", dtype=torch.float16)
-    y = torch.randn([b, k, n], device="cuda", dtype=torch.float16)
+    x = torch.randn([b, m, k], device=DEVICE, dtype=torch.float16)
+    y = torch.randn([b, k, n], device=DEVICE, dtype=torch.float16)
     run_example(bmm, torch.bmm, (x, y))
examples/concatenate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616

1717
import helion
18+
from helion._testing import DEVICE
1819
from helion._testing import run_example
1920
import helion.language as hl
2021

@@ -67,8 +68,8 @@ def main() -> None:
6768
Main entry point that runs the concatenation kernel verification.
6869
Tests with two tensors of shapes [1500, 400] and [1500, 600].
6970
"""
70-
x = torch.randn([1500, 400], device="cuda")
71-
y = torch.randn([1500, 600], device="cuda")
71+
x = torch.randn([1500, 400], device=DEVICE)
72+
y = torch.randn([1500, 600], device=DEVICE)
7273
run_example(concat2d_dim1, lambda x, y: torch.cat([x, y], dim=1), (x, y))
7374

7475

examples/cross_entropy.py

Lines changed: 3 additions & 2 deletions

@@ -15,6 +15,7 @@
 import torch

 import helion
+from helion._testing import DEVICE
 from helion._testing import run_example
 import helion.language as hl

@@ -89,8 +90,8 @@ def main() -> None:
     """
     batch_size, seq_len, vocab_size = 8, 2048, 131072
     n = batch_size * seq_len
-    logits = torch.randn(n, vocab_size, device="cuda", dtype=torch.float32)
-    labels = torch.randint(0, vocab_size, (n,), device="cuda", dtype=torch.long)
+    logits = torch.randn(n, vocab_size, device=DEVICE, dtype=torch.float32)
+    labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=torch.long)

     run_example(
         cross_entropy,

examples/embedding.py

Lines changed: 3 additions & 2 deletions

@@ -17,6 +17,7 @@
 import torch

 import helion
+from helion._testing import DEVICE
 from helion._testing import run_example
 import helion.language as hl

@@ -88,8 +89,8 @@ def main() -> None:
     Tests with a batch of indices and an embedding table of size 16x64.
     """
     num_embeddings, embedding_dim = 16, 64
-    x = torch.randint(0, num_embeddings, [256, 32], device="cuda", dtype=torch.int32)
-    weight = torch.randn([num_embeddings, embedding_dim], device="cuda")
+    x = torch.randint(0, num_embeddings, [256, 32], device=DEVICE, dtype=torch.int32)
+    weight = torch.randn([num_embeddings, embedding_dim], device=DEVICE)
     run_example(
         embedding, torch.nn.functional.embedding, (x, weight), atol=0.0, rtol=0.0
     )

examples/exp.py

Lines changed: 2 additions & 1 deletion

@@ -17,6 +17,7 @@
 import torch

 import helion
+from helion._testing import DEVICE
 from helion._testing import run_example
 import helion.language as hl

@@ -134,7 +135,7 @@ def check(n: int) -> None:
     Args:
         n: Size of the test tensor
     """
-    x = torch.randn(n, device="cuda", dtype=torch.float32, requires_grad=True)
+    x = torch.randn(n, device=DEVICE, dtype=torch.float32, requires_grad=True)
     run_example(exp, torch.exp, (x,), bwd=True)
