Skip to content

Commit 6581aac

Browse files

Update input shapes for example kernels (#845)

1 parent: f27abd1 · commit: 6581aac

File tree

13 files changed

+26
-43
lines changed

13 files changed

+26
-43
lines changed

examples/add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def main() -> None:
6969
"""
7070
Main entry point that runs the add kernel verification with 1024x1024 tensors.
7171
"""
72-
check(1024, 1024)
72+
check(10240, 10240)
7373

7474

7575
if __name__ == "__main__":

examples/cross_entropy.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,11 @@ def cross_entropy(
7979
def main() -> None:
8080
"""
8181
Main entry point that runs the cross entropy kernel verification.
82-
Tests with a batch size of 128 and vocabulary size of 1000.
8382
"""
84-
# Test with moderate size
85-
n, v = 128, 1000
86-
logits = torch.randn(n, v, device="cuda", dtype=torch.float32)
87-
labels = torch.randint(0, v, (n,), device="cuda", dtype=torch.long)
83+
batch_size, seq_len, vocab_size = 8, 2048, 131072
84+
n = batch_size * seq_len
85+
logits = torch.randn(n, vocab_size, device="cuda", dtype=torch.float32)
86+
labels = torch.randint(0, vocab_size, (n,), device="cuda", dtype=torch.long)
8887

8988
run_example(
9089
cross_entropy,

examples/exp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ def check(n: int) -> None:
132132
# -----------
133133
def main() -> None:
134134
"""
135-
Main entry point that runs the exp kernel verification with a tensor of size 1M elements.
135+
Main entry point that runs the exp kernel verification.
136136
"""
137-
check(1024 * 1024)
137+
check(10240 * 10240)
138138

139139

140140
if __name__ == "__main__":

examples/geglu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def main() -> None:
280280
print("Testing GEGLU kernel...")
281281

282282
# Test GEGLU kernel with different shapes
283-
kernel_test_shapes = [(8, 128, 1024), (4, 1024, 2048)]
283+
kernel_test_shapes = [(8, 2048, 4096), (8, 4096, 8192)]
284284

285285
for shape in kernel_test_shapes:
286286
print(f"Testing GEGLU kernel shape: {shape}")
@@ -291,8 +291,8 @@ def main() -> None:
291291

292292
# Test GEGLU MLP with transformer-typical sizes
293293
mlp_test_configs = [
294-
(2, 128, 512, 2048), # Small transformer
295-
(8, 1024, 4096, 11008), # LLaMA-style config
294+
(8, 2048, 4096, 11008),
295+
(8, 4096, 8192, 11008),
296296
]
297297

298298
for batch_size, seq_len, hidden_size, intermediate_size in mlp_test_configs:

examples/int4_gemm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,8 @@ def main() -> None:
163163
"""
164164
Main function to run tests with different matrix sizes.
165165
"""
166-
check(256, 512, 256)
167-
check(512, 512, 512)
168-
check(1024, 1024, 1024)
166+
check(4, 8192, 7168)
167+
check(8192, 8192, 8192)
169168

170169

171170
# %%

examples/jsd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def main() -> None:
326326
ignore_index = -100
327327
use_labels = False
328328

329-
for V in [2**i for i in range(12, 18)]:
329+
for V in [2**i for i in range(16, 18)]:
330330
print(
331331
f"Testing JSD: B={B}, T={T}, V={V}, beta={beta}, ignore_index={ignore_index}, labels={use_labels}"
332332
)

examples/kl_div.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,8 @@ def main() -> None:
244244
log_target = False
245245
eps = 1e-10
246246

247-
# Test with vocabulary sizes from tritonbench (2^12 to 2^17)
248-
for V in [2**i for i in range(12, 18)]:
247+
# Test with vocabulary sizes from tritonbench (2^16 to 2^17)
248+
for V in [2**i for i in range(16, 18)]:
249249
print(
250250
f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
251251
)

examples/layer_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,8 @@ def main() -> None:
278278
built-in layer_norm function using the run_example utility.
279279
- Prints comparison results and checks for correctness within specified tolerances.
280280
"""
281-
batch_size = 32
282-
dim = 64
281+
batch_size = 4096
282+
dim = 10240
283283
device = "cuda"
284284

285285
# Test forward pass only

examples/rms_norm.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -240,17 +240,9 @@ def check(m: int, n: int) -> None:
240240
def main() -> None:
241241
"""
242242
Main entry point that runs the RMS norm kernel verification with different tensor sizes.
243-
244-
Tests with configurations:
245-
- 32x64
246-
- 128x256
247-
- 1024x1024
248-
- 2048x1024
249243
"""
250-
check(32, 64)
251-
check(128, 256)
252-
check(1024, 1024)
253-
check(2048, 1024)
244+
check(2048, 4096)
245+
check(2048, 8192)
254246

255247

256248
if __name__ == "__main__":

examples/softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def main() -> None:
111111
"""
112112
Main function to run the softmax kernel correctness check with example input size.
113113
"""
114-
check(1024, 1024)
114+
check(4096, 2560)
115115

116116

117117
# %%

0 commit comments

Comments (0)