Commit 02c328f

Add testcase for log_softmax
1 parent f3dab79 commit 02c328f

File tree

3 files changed: +128 -12 lines changed


custom_ops/batch_invariant_ops/batch_invariant_ops.py

Lines changed: 9 additions & 12 deletions
@@ -285,19 +285,19 @@ def _log_softmax_kernel(
     tl.store(output_row_start_ptr + col_idx, output, mask=mask)
 
 
-def log_softmax(input: paddle.Tensor, dim: int = -1) -> paddle.Tensor:
+def log_softmax(input: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
     """
     Compute log_softmax using Triton kernel.
 
     Args:
         input: Input tensor
-        dim: Dimension along which to compute log_softmax (only -1 or last dim supported)
+        axis: Dimension along which to compute log_softmax (only -1 or last dim supported)
     >> Stashed changes
     Returns:
         Tensor with log_softmax applied along the specified dimension
     """
-    # TODO:use axis not dim in paddle
-    if dim != -1 and dim != input.ndim - 1:
+    # print("You are using triton impl for log_softmax")
+    if axis != -1 and axis != input.ndim - 1:
         raise ValueError("This implementation only supports log_softmax along the last dimension")
 
     # Flatten all dimensions except the last one
@@ -477,10 +477,8 @@ def addmm_batch_invariant(bias, a, b, alpha=1.0, beta=1.0):
     return result
 
 
-def _log_softmax_batch_invariant(input, dim, _half_to_float):
-    # TODO:use axis not dim in Paddle
-    assert not _half_to_float, "not implemented"
-    return log_softmax(input, dim=dim)
+def _log_softmax_batch_invariant(input, axis):
+    return log_softmax(input, axis=axis)
 
 
 def mean_batch_invariant(input, dim, keepdim=False, dtype: paddle.dtype | None = None):
@@ -511,12 +509,12 @@ def enable_batch_invariant_mode():
 
     _original_ops["mm"] = paddle._C_ops.matmul
     _original_ops["addmm"] = paddle._C_ops.addmm
-    _original_ops["log_softmax"] = paddle.nn.functional.log_softmax
+    _original_ops["log_softmax"] = paddle._C_ops.log_softmax
     _original_ops["mean"] = paddle.mean
 
     paddle._C_ops.matmul = mm_batch_invariant
     paddle._C_ops.addmm = addmm_batch_invariant
-    paddle.nn.functional.log_softmax = _log_softmax_batch_invariant
+    paddle._C_ops.log_softmax = _log_softmax_batch_invariant
     paddle.mean = mean_batch_invariant
 
     _batch_invariant_MODE = True
@@ -532,7 +530,7 @@ def disable_batch_invariant_mode():
     if _original_ops["addmm"]:
         paddle._C_ops.addmm = _original_ops["addmm"]
     if _original_ops["log_softmax"]:
-        paddle.nn.functional.log_softmax = _original_ops["log_softmax"]
+        paddle._C_ops.log_softmax = _original_ops["log_softmax"]
     if _original_ops["mean"]:
         paddle.mean = _original_ops["mean"]
 
@@ -543,7 +541,6 @@ def disable_batch_invariant_mode():
 def set_batch_invariant_mode(enabled: bool = True):
     global _batch_invariant_MODE, _original_ops
     old_mode = _batch_invariant_MODE
-    # old_ops = _original_ops.copy()
     if enabled:
         enable_batch_invariant_mode()
     else:
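The hunks above patch paddle._C_ops.log_softmax rather than paddle.nn.functional.log_softmax. A minimal usage sketch of how the patched op would be exercised end to end, assuming the package path used by the new test (custom_ops.batch_invariant_ops) and assuming F.log_softmax dispatches to the patched _C_ops entry in dynamic-graph mode; this snippet is illustrative and not part of the commit:

import paddle
import paddle.nn.functional as F

from custom_ops.batch_invariant_ops import set_batch_invariant_mode

x = paddle.randn([4, 8], dtype="float32")

# Outside the context manager, the stock Paddle kernel runs.
ref = F.log_softmax(x, axis=-1)

# Inside it, enable_batch_invariant_mode() has swapped paddle._C_ops.log_softmax
# for the Triton-backed _log_softmax_batch_invariant, so this call is expected to
# route through the batch-invariant kernel (assumption: dygraph dispatch).
with set_batch_invariant_mode(True):
    out = F.log_softmax(x, axis=-1)

# The two results should agree to within normal floating-point tolerance.
print((ref - out).abs().max().item())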
Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
+# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py
+
+import random
+import unittest
+
+import paddle
+
+from custom_ops.batch_invariant_ops import set_batch_invariant_mode
+
+
+class TestBatchInvariantForLogsoftmax(unittest.TestCase):
+    def setUp(self):
+        """
+        Initialize the test environment
+        """
+        device = "gpu" if paddle.is_compiled_with_cuda() else "cpu"
+        paddle.set_device(device)
+
+    def create_softmax_trap_tensor(self, B, D, dtype):
+        """
+        Constructs a "trap" tensor designed to trigger batch-invariance issues in Softmax/LogSoftmax.
+        Inspired by https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/
+
+        Principle:
+        The goal is to make the result of `exp(a - max(a))` contain numbers spanning an extremely wide numerical range
+        (e.g., 1.0, 1e-5, 1e-10, and many numbers close to 0).
+        When summing these numbers with a parallel reduction, different summation orders (due to parallelism)
+        can produce different accumulated rounding errors, leading to a subtle difference between
+        batch (parallel) and single-sample (serial) computation results.
+        """
+        # 1. Determine the desired values after `exp` and work backwards to the required input offsets.
+        max_val = 20.0
+
+        # Offsets relative to max_val. These offsets result in values spanning vastly different orders of magnitude after exp.
+        trap_values = [
+            max_val,  # Corresponds to exp(a-max) -> 1.0
+            max_val - 4.6,  # Corresponds to exp(a-max) -> ~1e-2
+            max_val - 11.5,  # Corresponds to exp(a-max) -> ~1e-5
+            max_val - 23.0,  # Corresponds to exp(a-max) -> ~1e-10
+        ]
+
+        # 2. Create a background tensor filled with a very large negative number.
+        background_val = -1000.0
+        a = paddle.full((B, D), background_val, dtype=dtype)
+
+        # 3. Scatter these "trap" values at random positions in each row.
+        for i in range(B):
+            # Randomly shuffle the positions of the trap values for each row to increase non-determinism.
+            indices = random.sample(range(D), k=len(trap_values))
+            for j, val in enumerate(trap_values):
+                a[i, indices[j]] = val
+
+        return a
+
+    def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
+        a = self.create_softmax_trap_tensor(B, D, dtype)
+
+        # Method 1: log_softmax over a single row (batch size 1)
+        out1 = paddle.nn.functional.log_softmax(a[:1])
+
+        # Method 2: log_softmax over the full batch, then slice out the same row
+        out2 = paddle.nn.functional.log_softmax(a)[:1]
+
+        # Check if results are identical
+        diff = (out1 - out2).abs().max()
+        return diff.item() == 0, diff
+
+    def run_iters(self, iters=10, ass=False):
+        for dtype in [paddle.float32, paddle.bfloat16, paddle.float16]:
+            is_deterministic = True
+            difflist = []
+            for i in range(iters):
+                isd, df = self.test_batch_invariance(dtype=dtype)
+                is_deterministic = is_deterministic and isd
+                difflist.append(df)
+            print(
+                f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations"
+            )
+            if ass:
+                assert max(difflist) == 0
+
+    def test_case(self):
+        # Test with standard Paddle (likely to show differences)
+        print("Standard Paddle:")
+        with set_batch_invariant_mode(False):
+            self.run_iters(ass=False)
+        # Test with batch-invariant operations
+        print("\nBatch-Invariant Mode:")
+        with set_batch_invariant_mode(True):
+            self.run_iters(ass=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
+"""
+Even in standard Paddle we get deterministic results, so the stock log_softmax implementation may already be batch-invariant.
+
+Result:
+
+Standard Paddle:
+Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
+Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
+Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float16 in 10 iterations
+
+Batch-Invariant Mode:
+Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
+Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
+Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float16 in 10 iterations
+"""

tests/batch_invariant/test_batch_invariance_op_mm.py

Lines changed: 10 additions & 0 deletions
@@ -56,3 +56,13 @@ def test_case(self):
 
 if __name__ == "__main__":
     unittest.main()
+"""
+
+Standard Paddle:
+Batch Deterministic: False run-to-run max/min/diff 10.7294921875/10.7294921875/0.0 for paddle.float32 in 10 iterations
+Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
+
+Batch-Invariant Mode:
+Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations
+Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations
+"""
