Skip to content

Commit f3dab79

Browse files
committed
Change testcase to FD style
1 parent 9f97167 commit f3dab79

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py
2+
3+
import unittest
4+
5+
import paddle
6+
7+
from custom_ops.batch_invariant_ops import set_batch_invariant_mode
8+
9+
10+
class TestBatchInvariantForMM(unittest.TestCase):
11+
def setUp(self):
12+
"""
13+
Initialize the test environment
14+
"""
15+
device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
16+
paddle.set_device(device)
17+
18+
def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
19+
a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D)
20+
b = paddle.linspace(-100, 100, D * D, dtype=dtype).reshape(D, D)
21+
22+
# Method 1: Matrix-vector multiplication (batch size 1)
23+
out1 = paddle.mm(a[:1], b)
24+
25+
# Method 2: Matrix-matrix multiplication, then slice (full batch)
26+
out2 = paddle.mm(a, b)[:1]
27+
28+
# Check if results are identical
29+
diff = (out1 - out2).abs().max()
30+
return diff.item() == 0, diff
31+
32+
def run_iters(self, iters=10, ass=False):
33+
for dtype in [paddle.float32, paddle.bfloat16]:
34+
is_deterministic = True
35+
difflist = []
36+
for i in range(iters):
37+
isd, df = self.test_batch_invariance(dtype=dtype)
38+
is_deterministic = is_deterministic and isd
39+
difflist.append(df)
40+
print(
41+
f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations"
42+
)
43+
if ass:
44+
assert max(difflist) == 0
45+
46+
def test_case(self):
47+
# Test with standard Paddle (likely to show differences)
48+
print("Standard Paddle:")
49+
with set_batch_invariant_mode(False):
50+
self.run_iters(ass=False)
51+
# Test with batch-invariant operations
52+
print("\nBatch-Invariant Mode:")
53+
with set_batch_invariant_mode(True):
54+
self.run_iters(ass=True)
55+
56+
57+
if __name__ == "__main__":
58+
unittest.main()

0 commit comments

Comments
 (0)