From 88dcea859d0e32683a52dd7b71bcedb79234a54d Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Mon, 3 Nov 2025 15:02:27 +0800 Subject: [PATCH 01/16] Move batch invariant pkg to Fastdeploy --- custom_ops/batch_invariant_ops/__init__.py | 25 + .../batch_invariant_ops.py | 571 ++++++++++++++++++ .../test_batch_invariance.py | 46 ++ 3 files changed, 642 insertions(+) create mode 100644 custom_ops/batch_invariant_ops/__init__.py create mode 100644 custom_ops/batch_invariant_ops/batch_invariant_ops.py create mode 100644 custom_ops/batch_invariant_ops/test_batch_invariance.py diff --git a/custom_ops/batch_invariant_ops/__init__.py b/custom_ops/batch_invariant_ops/__init__.py new file mode 100644 index 00000000000..07cc85b2de4 --- /dev/null +++ b/custom_ops/batch_invariant_ops/__init__.py @@ -0,0 +1,25 @@ +from .batch_invariant_ops import ( + AttentionBlockSize, + disable_batch_invariant_mode, + enable_batch_invariant_mode, + get_batch_invariant_attention_block_size, + is_batch_invariant_mode_enabled, + log_softmax, + matmul_persistent, + mean_dim, + set_batch_invariant_mode, +) + +__version__ = "0.1.0" + +__all__ = [ + "set_batch_invariant_mode", + "is_batch_invariant_mode_enabled", + "disable_batch_invariant_mode", + "enable_batch_invariant_mode", + "matmul_persistent", + "log_softmax", + "mean_dim", + "get_batch_invariant_attention_block_size", + "AttentionBlockSize", +] \ No newline at end of file diff --git a/custom_ops/batch_invariant_ops/batch_invariant_ops.py b/custom_ops/batch_invariant_ops/batch_invariant_ops.py new file mode 100644 index 00000000000..c1880364a59 --- /dev/null +++ b/custom_ops/batch_invariant_ops/batch_invariant_ops.py @@ -0,0 +1,571 @@ +# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py + +import contextlib +from collections import namedtuple +from collections.abc import Callable +from typing import Any, Dict + +import paddle +import triton +import triton.language as tl +# paddle.compat.use_torch_proxy_guard() + +__all__ = ["set_batch_invariant_mode", "is_batch_invariant_mode_enabled", "disable_batch_invariant_mode", "enable_batch_invariant_mode"] + + +def _matmul_launch_metadata( + grid: Callable[..., Any], kernel: Any, args: Dict[str, Any] +) -> Dict[str, Any]: + ret = {} + m, n, k = args["M"], args["N"], args["K"] + ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]" + if "tiles_per_update" in args: + ret["name"] = ( + f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]" + ) + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k + ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n) + return ret + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, # + bias_ptr, + M, + N, + K, # + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: 
tl.constexpr, # + A_LARGE: tl.constexpr, + B_LARGE: tl.constexpr, + C_LARGE: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + if A_LARGE: + offs_am = offs_am.to(tl.int64) + if B_LARGE: + offs_bn = offs_bn.to(tl.int64) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + if A_LARGE or B_LARGE: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + else: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if C_LARGE: + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_cn + bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32) + accumulator += bias + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store(c_ptrs, c, mask=c_mask) + +def get_compute_units(): + """ + Returns the number of streaming multiprocessors (SMs) or equivalent compute units + for the available accelerator. Assigns the value to NUM_SMS. + """ + NUM_SMS = None + + if paddle.is_compiled_with_cuda(): + try: + paddle.device.get_device()#Triton + Paddle may can't get the device + device_properties = paddle.cuda.get_device_properties(0) + NUM_SMS = device_properties.multi_processor_count + except Exception: + print("Could not get CUDA device properties. Falling back to CPU threads.") + NUM_SMS = os.cpu_count() + else: + print("No CUDA device available. Using CPU.") + # For CPU, use the number of CPU cores + NUM_SMS = os.cpu_count() + + return NUM_SMS + + +def matmul_persistent(a: paddle.Tensor, b: paddle.Tensor, bias: paddle.Tensor | None = None): + # Check constraints. 
+ assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + assert bias is None or bias.dim() == 1, ( + "Currently assuming bias is 1D, let Horace know if you run into this" + ) + + + NUM_SMS = get_compute_units() + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. In PaddlePaddle, we create on the same device as input tensor + # Simply create the tensor without specifying device, Paddle will handle it + c = paddle.empty((M, N), dtype=dtype) + + # 1D launch kernel where each block gets its own program. + def grid(META): + return ( + min( + NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]) + ), + ) + + configs = { + paddle.bfloat16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + paddle.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + paddle.float32: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + } + # print(a.device, b.device, c.device) + matmul_kernel_persistent[grid]( + a, + b, + c, # + bias, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + NUM_SMS=NUM_SMS, # + A_LARGE=int(a.numel() > 2**31), + B_LARGE=int(b.numel() > 2**31), + C_LARGE=int(c.numel() > 2**31), + HAS_BIAS=int(bias is not None), + # The Triton compiler (when used with Paddle) cannot handle these variables as booleans. Explicitly cast to int so the compiler can process them. + **configs[dtype], + ) + return c + + +@triton.jit +def _log_softmax_kernel( + input_ptr, + output_ptr, + input_row_stride, + output_row_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + Compute log_softmax along the last dimension of a 2D tensor. + Each block handles one row of the input tensor. 
+ """ + # Get the row index for this block + row_idx = tl.program_id(0).to(tl.int64) + + # Compute base pointers for input and output rows + row_start_ptr = input_ptr + row_idx * input_row_stride + output_row_start_ptr = output_ptr + row_idx * output_row_stride + + # Step 1: Find maximum value in the row for numerical stability + max_val = -float("inf") + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf")) + + # Update maximum + max_val = tl.max(tl.maximum(vals, max_val)) + + # Step 2: Compute sum of exp(x - max_val) + sum_exp = 0.0 + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + + # Compute exp(x - max_val) and accumulate + exp_vals = tl.exp(vals - max_val) + sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0)) + + # Compute log(sum_exp) + log_sum_exp = tl.log(sum_exp) + + # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask) + + # Compute log_softmax + output = vals - max_val - log_sum_exp + + # Store results + tl.store(output_row_start_ptr + col_idx, output, mask=mask) + + +def log_softmax(input: paddle.Tensor, dim: int = -1) -> paddle.Tensor: + """ + Compute log_softmax using Triton kernel. + + Args: + input: Input tensor + dim: Dimension along which to compute log_softmax (only -1 or last dim supported) + >> Stashed changes + Returns: + Tensor with log_softmax applied along the specified dimension + """ + #TODO:use axis not dim in paddle + if dim != -1 and dim != input.ndim - 1: + raise ValueError("This implementation only supports log_softmax along the last dimension") + + # Flatten all dimensions except the last one + original_shape = input.shape + input_2d = input.reshape(-1, input.shape[-1]) + input_2d = input_2d.contiguous() + + n_rows, n_cols = input_2d.shape + + # Allocate output tensor + output = paddle.empty_like(input_2d) + + # Choose block size based on the number of columns + BLOCK_SIZE = 1024 + + # Launch kernel with one block per row + grid = (n_rows,) + _log_softmax_kernel[grid]( + input_2d, + output, + input_2d.stride(0), + output.stride(0), + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ) + # Reshape output back to original shape + return output.reshape(original_shape) + + +@triton.jit +def mean_kernel( + input_ptr, + output_ptr, + input_stride0, + input_stride1, + input_stride2, + output_stride0, + output_stride1, + M, # size before reduction dim + N, # size of reduction dim + K, # size after reduction dim + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel for computing mean along a single dimension. + Input is viewed as (M, N, K) where N is the dimension being reduced. 
+ """ + # Program ID gives us which output element we're computing + pid = tl.program_id(0) + + # Compute output indices + m_idx = pid // K + k_idx = pid % K + + # Bounds check + if m_idx >= M or k_idx >= K: + return + + # Accumulate sum across reduction dimension + acc = 0.0 + for n_start in range(0, N, BLOCK_SIZE): + n_offsets = n_start + tl.arange(0, BLOCK_SIZE) + mask = n_offsets < N + + # Calculate input indices + input_idx = m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2 + + # Load and accumulate + vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0) + acc += tl.sum(vals) + + # Compute mean and store + mean_val = acc / N + output_idx = m_idx * output_stride0 + k_idx * output_stride1 + tl.store(output_ptr + output_idx, mean_val) + + +def mean_dim( + input: paddle.Tensor, dim: int, keepdim: bool = False, dtype: paddle.dtype | None = None +) -> paddle.Tensor: + """ + Triton implementation of paddle.mean with single dimension reduction. + + Args: + input: Input tensor + dim: Single dimension along which to compute mean + keepdim: Whether to keep the reduced dimension + dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs) + + Returns: + Tensor with mean values along specified dimension + """ + # Validate inputs + assert input.is_cuda, "Input must be a CUDA tensor" + assert -input.ndim <= dim < input.ndim, ( + f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" + ) + + # Handle negative dim + if dim < 0: + dim = dim + input.ndim + + # Handle dtype + if dtype is None: + if input.dtype in [paddle.int8, paddle.int16, paddle.int32, paddle.int64]: + dtype = paddle.float32 + else: + dtype = input.dtype + + # Convert input to appropriate dtype if needed + if input.dtype != dtype: + input = input.to(dtype) + + # Get input shape and strides + shape = list(input.shape) + + # Calculate dimensions for kernel + M = 1 + for i in range(dim): + M *= shape[i] + + N = shape[dim] + + K = 1 + for i in range(dim + 1, len(shape)): + K *= shape[i] + + # Reshape input to 3D view (M, N, K) + input_3d = input.reshape(M, N, K) + + # Create output shape + if keepdim: + output_shape = shape.copy() + output_shape[dim] = 1 + else: + output_shape = shape[:dim] + shape[dim + 1 :] + + # Create output tensor + output = paddle.empty(output_shape, dtype=dtype) + + # Reshape output for kernel + if keepdim: + output_2d = output.reshape(M, 1, K).squeeze(1) + else: + output_2d = output.reshape(M, K) + + # Launch kernel + grid = (M * K,) + BLOCK_SIZE = 1024 + + mean_kernel[grid]( + input_3d, + output_2d, + input_3d.stride(0), + input_3d.stride(1), + input_3d.stride(2), + output_2d.stride(0), + output_2d.stride(1) if output_2d.ndim > 1 else 0, + M, + N, + K, + BLOCK_SIZE, + ) + + return output + + +def mm_batch_invariant(a, b, transpose_x=False, transpose_y=False): + if transpose_x: + a = a.T + if transpose_y: + b = b.T + return matmul_persistent(a, b) + + +def addmm_batch_invariant(bias, a, b, alpha=1.0, beta=1.0): + #TODO:check API + result = matmul_persistent(a, b, bias=bias) + return result + + +def _log_softmax_batch_invariant(input, dim, _half_to_float): + #TODO:use axis not dim in Paddle + assert not _half_to_float, "not implemented" + return log_softmax(input, dim=dim) + + +def mean_batch_invariant(input, dim, keepdim=False, dtype: paddle.dtype | None = None): + assert dtype is None or dtype == paddle.float32, f"unsupported dtype: {dtype}" + if len(dim) == 1: + return mean_dim(input, dim[0], keepdim=keepdim) + else: + assert input.dtype in 
{paddle.float16, paddle.bfloat16, paddle.float32}, ( + "only float types supported for now" + ) + n_elems = 1 + for d in dim: + n_elems *= input.shape[d] + return paddle.sum(input, dim=dim, keepdim=keepdim, dtype=paddle.float32) / n_elems + + +_original_ops = { + 'mm': None, + 'addmm': None, + '_log_softmax': None, + 'mean_dim': None +} + +_batch_invariant_MODE = False + + +def is_batch_invariant_mode_enabled(): + return _batch_invariant_MODE + + +def enable_batch_invariant_mode(): + global _batch_invariant_MODE, _original_ops + if _batch_invariant_MODE: + return + + _original_ops['mm'] = paddle._C_ops.matmul + _original_ops['addmm'] = paddle._C_ops.addmm + _original_ops['log_softmax'] = paddle.nn.functional.log_softmax + _original_ops['mean'] = paddle.mean + + paddle._C_ops.matmul = mm_batch_invariant + paddle._C_ops.addmm = addmm_batch_invariant + paddle.nn.functional.log_softmax = _log_softmax_batch_invariant + paddle.mean = mean_batch_invariant + + _batch_invariant_MODE = True + + +def disable_batch_invariant_mode(): + global _batch_invariant_MODE, _original_ops + if not _batch_invariant_MODE: + return + + if _original_ops['mm']: + paddle._C_ops.matmul = _original_ops['mm'] + if _original_ops['addmm']: + paddle._C_ops.addmm = _original_ops['addmm'] + if _original_ops['log_softmax']: + paddle.nn.functional.log_softmax = _original_ops['log_softmax'] + if _original_ops['mean']: + paddle.mean = _original_ops['mean'] + + _batch_invariant_MODE = False + + +@contextlib.contextmanager +def set_batch_invariant_mode(enabled: bool = True): + global _batch_invariant_MODE, _original_ops + old_mode = _batch_invariant_MODE + # old_ops = _original_ops.copy() + if enabled: + enable_batch_invariant_mode() + else: + disable_batch_invariant_mode() + yield + if old_mode: + enable_batch_invariant_mode() + else: + disable_batch_invariant_mode() + + +AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"]) + + +def get_batch_invariant_attention_block_size() -> AttentionBlockSize: + return AttentionBlockSize(block_m=16, block_n=16) diff --git a/custom_ops/batch_invariant_ops/test_batch_invariance.py b/custom_ops/batch_invariant_ops/test_batch_invariance.py new file mode 100644 index 00000000000..01aefc488e0 --- /dev/null +++ b/custom_ops/batch_invariant_ops/test_batch_invariance.py @@ -0,0 +1,46 @@ +# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py + +import paddle +from batch_invariant_ops import set_batch_invariant_mode + +device = "gpu" if paddle.is_compiled_with_cuda() else "cpu" +paddle.set_device(device) + +# Just to get the logging out of the way haha +with set_batch_invariant_mode(True): + pass + +def test_batch_invariance(dtype=paddle.float32): + B, D = 2048, 4096 + a = paddle.linspace(-100, 100, B*D, dtype=dtype).reshape(B, D) + b = paddle.linspace(-100, 100, D*D, dtype=dtype).reshape(D, D) + + # Method 1: Matrix-vector multiplication (batch size 1) + out1 = paddle.mm(a[:1], b) + + # Method 2: Matrix-matrix multiplication, then slice (full batch) + out2 = paddle.mm(a, b)[:1] + + # Check if results are identical + diff = (out1 - out2).abs().max() + return diff.item() == 0, diff + +def run_iters(iters=10): + for dtype in [ paddle.float32 , paddle.bfloat16 ]: + is_deterministic = True + difflist = [] + for i in range (iters): + isd, df = test_batch_invariance(dtype) + is_deterministic = is_deterministic and isd + difflist.append(df) + print( f"Batch Deterministic: {is_deterministic} run-to-run 
max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations") + + +# Test with standard Paddle (likely to show differences) +print("Standard Paddle:") +with set_batch_invariant_mode(False): + run_iters() +# Test with batch-invariant operations +print("\nBatch-Invariant Mode:") +with set_batch_invariant_mode(True): + run_iters() From 77d804e7ddbd191a5337497c6f48f32ce4be1bbc Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Mon, 3 Nov 2025 19:57:39 +0800 Subject: [PATCH 02/16] fix problem and pre-commit --- .../batch_invariant_ops.py | 93 +++++++++---------- 1 file changed, 42 insertions(+), 51 deletions(-) diff --git a/custom_ops/batch_invariant_ops/batch_invariant_ops.py b/custom_ops/batch_invariant_ops/batch_invariant_ops.py index c1880364a59..7b83e7f7e6a 100644 --- a/custom_ops/batch_invariant_ops/batch_invariant_ops.py +++ b/custom_ops/batch_invariant_ops/batch_invariant_ops.py @@ -1,6 +1,7 @@ # Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py import contextlib +import os from collections import namedtuple from collections.abc import Callable from typing import Any, Dict @@ -8,21 +9,23 @@ import paddle import triton import triton.language as tl -# paddle.compat.use_torch_proxy_guard() -__all__ = ["set_batch_invariant_mode", "is_batch_invariant_mode_enabled", "disable_batch_invariant_mode", "enable_batch_invariant_mode"] +paddle.compat.enable_torch_proxy() +__all__ = [ + "set_batch_invariant_mode", + "is_batch_invariant_mode_enabled", + "disable_batch_invariant_mode", + "enable_batch_invariant_mode", +] -def _matmul_launch_metadata( - grid: Callable[..., Any], kernel: Any, args: Dict[str, Any] -) -> Dict[str, Any]: + +def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any, args: Dict[str, Any]) -> Dict[str, Any]: ret = {} m, n, k = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]" if "tiles_per_update" in args: - ret["name"] = ( - f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]" - ) + ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]" if "c_ptr" in args: bytes_per_elem = args["c_ptr"].element_size() else: @@ -122,20 +125,24 @@ def matmul_kernel_persistent( c = accumulator.to(c_ptr.dtype.element_ty) tl.store(c_ptrs, c, mask=c_mask) + def get_compute_units(): """ Returns the number of streaming multiprocessors (SMs) or equivalent compute units for the available accelerator. Assigns the value to NUM_SMS. """ NUM_SMS = None - + if paddle.is_compiled_with_cuda(): try: - paddle.device.get_device()#Triton + Paddle may can't get the device + paddle.device.get_device() # Triton + Paddle may can't get the device device_properties = paddle.cuda.get_device_properties(0) NUM_SMS = device_properties.multi_processor_count except Exception: print("Could not get CUDA device properties. Falling back to CPU threads.") + # TODO: Paddle lacks a torch.get_num_threads() equivalent for the *configured* thread count. + # Using os.cpu_count() (total logical cores) as a fallback, which may not be correct. + # Must check downstream logic to determine if this impacts correctness. NUM_SMS = os.cpu_count() else: print("No CUDA device available. Using CPU.") @@ -149,10 +156,7 @@ def matmul_persistent(a: paddle.Tensor, b: paddle.Tensor, bias: paddle.Tensor | # Check constraints. 
assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.dtype == b.dtype, "Incompatible dtypes" - assert bias is None or bias.dim() == 1, ( - "Currently assuming bias is 1D, let Horace know if you run into this" - ) - + assert bias is None or bias.dim() == 1, "Currently assuming bias is 1D, let Horace know if you run into this" NUM_SMS = get_compute_units() M, K = a.shape @@ -164,11 +168,7 @@ def matmul_persistent(a: paddle.Tensor, b: paddle.Tensor, bias: paddle.Tensor | # 1D launch kernel where each block gets its own program. def grid(META): - return ( - min( - NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]) - ), - ) + return (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])),) configs = { paddle.bfloat16: { @@ -296,7 +296,7 @@ def log_softmax(input: paddle.Tensor, dim: int = -1) -> paddle.Tensor: Returns: Tensor with log_softmax applied along the specified dimension """ - #TODO:use axis not dim in paddle + # TODO:use axis not dim in paddle if dim != -1 and dim != input.ndim - 1: raise ValueError("This implementation only supports log_softmax along the last dimension") @@ -392,9 +392,7 @@ def mean_dim( """ # Validate inputs assert input.is_cuda, "Input must be a CUDA tensor" - assert -input.ndim <= dim < input.ndim, ( - f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" - ) + assert -input.ndim <= dim < input.ndim, f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" # Handle negative dim if dim < 0: @@ -474,13 +472,13 @@ def mm_batch_invariant(a, b, transpose_x=False, transpose_y=False): def addmm_batch_invariant(bias, a, b, alpha=1.0, beta=1.0): - #TODO:check API + # TODO:check API result = matmul_persistent(a, b, bias=bias) return result def _log_softmax_batch_invariant(input, dim, _half_to_float): - #TODO:use axis not dim in Paddle + # TODO:use axis not dim in Paddle assert not _half_to_float, "not implemented" return log_softmax(input, dim=dim) @@ -490,21 +488,14 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: paddle.dtype | None = if len(dim) == 1: return mean_dim(input, dim[0], keepdim=keepdim) else: - assert input.dtype in {paddle.float16, paddle.bfloat16, paddle.float32}, ( - "only float types supported for now" - ) + assert input.dtype in {paddle.float16, paddle.bfloat16, paddle.float32}, "only float types supported for now" n_elems = 1 for d in dim: n_elems *= input.shape[d] return paddle.sum(input, dim=dim, keepdim=keepdim, dtype=paddle.float32) / n_elems -_original_ops = { - 'mm': None, - 'addmm': None, - '_log_softmax': None, - 'mean_dim': None -} +_original_ops = {"mm": None, "addmm": None, "_log_softmax": None, "mean_dim": None} _batch_invariant_MODE = False @@ -517,17 +508,17 @@ def enable_batch_invariant_mode(): global _batch_invariant_MODE, _original_ops if _batch_invariant_MODE: return - - _original_ops['mm'] = paddle._C_ops.matmul - _original_ops['addmm'] = paddle._C_ops.addmm - _original_ops['log_softmax'] = paddle.nn.functional.log_softmax - _original_ops['mean'] = paddle.mean - + + _original_ops["mm"] = paddle._C_ops.matmul + _original_ops["addmm"] = paddle._C_ops.addmm + _original_ops["log_softmax"] = paddle.nn.functional.log_softmax + _original_ops["mean"] = paddle.mean + paddle._C_ops.matmul = mm_batch_invariant paddle._C_ops.addmm = addmm_batch_invariant paddle.nn.functional.log_softmax = _log_softmax_batch_invariant paddle.mean = mean_batch_invariant - + _batch_invariant_MODE = True @@ -535,16 +526,16 @@ def 
disable_batch_invariant_mode(): global _batch_invariant_MODE, _original_ops if not _batch_invariant_MODE: return - - if _original_ops['mm']: - paddle._C_ops.matmul = _original_ops['mm'] - if _original_ops['addmm']: - paddle._C_ops.addmm = _original_ops['addmm'] - if _original_ops['log_softmax']: - paddle.nn.functional.log_softmax = _original_ops['log_softmax'] - if _original_ops['mean']: - paddle.mean = _original_ops['mean'] - + + if _original_ops["mm"]: + paddle._C_ops.matmul = _original_ops["mm"] + if _original_ops["addmm"]: + paddle._C_ops.addmm = _original_ops["addmm"] + if _original_ops["log_softmax"]: + paddle.nn.functional.log_softmax = _original_ops["log_softmax"] + if _original_ops["mean"]: + paddle.mean = _original_ops["mean"] + _batch_invariant_MODE = False From 9f971678483066fe3456907903da6e6d0345fc85 Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Mon, 3 Nov 2025 20:05:57 +0800 Subject: [PATCH 03/16] move test --- .../batch_invariant}/test_batch_invariance.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) rename {custom_ops/batch_invariant_ops => tests/batch_invariant}/test_batch_invariance.py (70%) diff --git a/custom_ops/batch_invariant_ops/test_batch_invariance.py b/tests/batch_invariant/test_batch_invariance.py similarity index 70% rename from custom_ops/batch_invariant_ops/test_batch_invariance.py rename to tests/batch_invariant/test_batch_invariance.py index 01aefc488e0..4fa4044d182 100644 --- a/custom_ops/batch_invariant_ops/test_batch_invariance.py +++ b/tests/batch_invariant/test_batch_invariance.py @@ -1,7 +1,8 @@ # Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py import paddle -from batch_invariant_ops import set_batch_invariant_mode + +from custom_ops.batch_invariant_ops import set_batch_invariant_mode device = "gpu" if paddle.is_compiled_with_cuda() else "cpu" paddle.set_device(device) @@ -10,10 +11,11 @@ with set_batch_invariant_mode(True): pass + def test_batch_invariance(dtype=paddle.float32): B, D = 2048, 4096 - a = paddle.linspace(-100, 100, B*D, dtype=dtype).reshape(B, D) - b = paddle.linspace(-100, 100, D*D, dtype=dtype).reshape(D, D) + a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D) + b = paddle.linspace(-100, 100, D * D, dtype=dtype).reshape(D, D) # Method 1: Matrix-vector multiplication (batch size 1) out1 = paddle.mm(a[:1], b) @@ -25,15 +27,18 @@ def test_batch_invariance(dtype=paddle.float32): diff = (out1 - out2).abs().max() return diff.item() == 0, diff + def run_iters(iters=10): - for dtype in [ paddle.float32 , paddle.bfloat16 ]: + for dtype in [paddle.float32, paddle.bfloat16]: is_deterministic = True difflist = [] - for i in range (iters): + for i in range(iters): isd, df = test_batch_invariance(dtype) is_deterministic = is_deterministic and isd difflist.append(df) - print( f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations") + print( + f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations" + ) # Test with standard Paddle (likely to show differences) From f3dab794939e93af535371285f1a35e0e686c858 Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Tue, 4 Nov 2025 14:45:07 +0800 Subject: [PATCH 04/16] Change testcase to FD style --- .../test_batch_invariance_op_mm.py | 58 
+++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/batch_invariant/test_batch_invariance_op_mm.py diff --git a/tests/batch_invariant/test_batch_invariance_op_mm.py b/tests/batch_invariant/test_batch_invariance_op_mm.py new file mode 100644 index 00000000000..22fb9c42c82 --- /dev/null +++ b/tests/batch_invariant/test_batch_invariance_op_mm.py @@ -0,0 +1,58 @@ +# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py + +import unittest + +import paddle + +from custom_ops.batch_invariant_ops import set_batch_invariant_mode + + +class TestBatchInvariantForMM(unittest.TestCase): + def setUp(self): + """ + Initialize the test environment + """ + device = "gpu" if paddle.is_compiled_with_cuda() else "cpu" + paddle.set_device(device) + + def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32): + a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D) + b = paddle.linspace(-100, 100, D * D, dtype=dtype).reshape(D, D) + + # Method 1: Matrix-vector multiplication (batch size 1) + out1 = paddle.mm(a[:1], b) + + # Method 2: Matrix-matrix multiplication, then slice (full batch) + out2 = paddle.mm(a, b)[:1] + + # Check if results are identical + diff = (out1 - out2).abs().max() + return diff.item() == 0, diff + + def run_iters(self, iters=10, ass=False): + for dtype in [paddle.float32, paddle.bfloat16]: + is_deterministic = True + difflist = [] + for i in range(iters): + isd, df = self.test_batch_invariance(dtype=dtype) + is_deterministic = is_deterministic and isd + difflist.append(df) + print( + f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations" + ) + if ass: + assert max(difflist) == 0 + + def test_case(self): + # Test with standard Paddle (likely to show differences) + print("Standard Paddle:") + with set_batch_invariant_mode(False): + self.run_iters(ass=False) + # Test with batch-invariant operations + print("\nBatch-Invariant Mode:") + with set_batch_invariant_mode(True): + self.run_iters(ass=True) + + +if __name__ == "__main__": + unittest.main() From 02c328fe4f89a0c8aafe6abecdd72b8409497097 Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Tue, 4 Nov 2025 17:34:27 +0800 Subject: [PATCH 05/16] Add testcase for log_softmax --- .../batch_invariant_ops.py | 21 ++-- .../test_batch_invariance_op_logsoftmax.py | 109 ++++++++++++++++++ .../test_batch_invariance_op_mm.py | 10 ++ 3 files changed, 128 insertions(+), 12 deletions(-) create mode 100644 tests/batch_invariant/test_batch_invariance_op_logsoftmax.py diff --git a/custom_ops/batch_invariant_ops/batch_invariant_ops.py b/custom_ops/batch_invariant_ops/batch_invariant_ops.py index 7b83e7f7e6a..ccfac6e0c0b 100644 --- a/custom_ops/batch_invariant_ops/batch_invariant_ops.py +++ b/custom_ops/batch_invariant_ops/batch_invariant_ops.py @@ -285,19 +285,19 @@ def _log_softmax_kernel( tl.store(output_row_start_ptr + col_idx, output, mask=mask) -def log_softmax(input: paddle.Tensor, dim: int = -1) -> paddle.Tensor: +def log_softmax(input: paddle.Tensor, axis: int = -1) -> paddle.Tensor: """ Compute log_softmax using Triton kernel. 
Args: input: Input tensor - dim: Dimension along which to compute log_softmax (only -1 or last dim supported) + axis: Dimension along which to compute log_softmax (only -1 or last dim supported) >> Stashed changes Returns: Tensor with log_softmax applied along the specified dimension """ - # TODO:use axis not dim in paddle - if dim != -1 and dim != input.ndim - 1: + # print("You are using triton impl for log_softmax") + if axis != -1 and axis != input.ndim - 1: raise ValueError("This implementation only supports log_softmax along the last dimension") # Flatten all dimensions except the last one @@ -477,10 +477,8 @@ def addmm_batch_invariant(bias, a, b, alpha=1.0, beta=1.0): return result -def _log_softmax_batch_invariant(input, dim, _half_to_float): - # TODO:use axis not dim in Paddle - assert not _half_to_float, "not implemented" - return log_softmax(input, dim=dim) +def _log_softmax_batch_invariant(input, axis): + return log_softmax(input, axis=axis) def mean_batch_invariant(input, dim, keepdim=False, dtype: paddle.dtype | None = None): @@ -511,12 +509,12 @@ def enable_batch_invariant_mode(): _original_ops["mm"] = paddle._C_ops.matmul _original_ops["addmm"] = paddle._C_ops.addmm - _original_ops["log_softmax"] = paddle.nn.functional.log_softmax + _original_ops["log_softmax"] = paddle._C_ops.log_softmax _original_ops["mean"] = paddle.mean paddle._C_ops.matmul = mm_batch_invariant paddle._C_ops.addmm = addmm_batch_invariant - paddle.nn.functional.log_softmax = _log_softmax_batch_invariant + paddle._C_ops.log_softmax = _log_softmax_batch_invariant paddle.mean = mean_batch_invariant _batch_invariant_MODE = True @@ -532,7 +530,7 @@ def disable_batch_invariant_mode(): if _original_ops["addmm"]: paddle._C_ops.addmm = _original_ops["addmm"] if _original_ops["log_softmax"]: - paddle.nn.functional.log_softmax = _original_ops["log_softmax"] + paddle._C_ops.log_softmax = _original_ops["log_softmax"] if _original_ops["mean"]: paddle.mean = _original_ops["mean"] @@ -543,7 +541,6 @@ def disable_batch_invariant_mode(): def set_batch_invariant_mode(enabled: bool = True): global _batch_invariant_MODE, _original_ops old_mode = _batch_invariant_MODE - # old_ops = _original_ops.copy() if enabled: enable_batch_invariant_mode() else: diff --git a/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py b/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py new file mode 100644 index 00000000000..ee14756c501 --- /dev/null +++ b/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py @@ -0,0 +1,109 @@ +# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py + +import random +import unittest + +import paddle + +from custom_ops.batch_invariant_ops import set_batch_invariant_mode + + +class TestBatchInvariantForLogsoftmax(unittest.TestCase): + def setUp(self): + """ + Initialize the test environment + """ + device = "gpu" if paddle.is_compiled_with_cuda() else "cpu" + paddle.set_device(device) + + def create_softmax_trap_tensor(self, B, D, dtype): + """ + Constructs a "trap" tensor designed to trigger batch-invariance issues in Softmax/LogSoftmax. + Inspired by https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ + + Principle: + The goal is to make the result of `exp(a - max(a))` contain numbers spanning an extremely wide numerical range + (e.g., 1.0, 1e-5, 1e-10, and many numbers close to 0). 
+ When summing these numbers using parallel reduction, different summation orders (due to parallelism) + can produce different accumulated rounding errors, leading to a subtle difference between + batch (parallel) and single-sample (serial) computation results. + """ + # 1. Determine the desired values after `exp` and calculate the required input values using log(). + max_val = 20.0 + + # Offsets relative to max_val. These offsets result in values spanning vastly different orders of magnitude after exp. + trap_values = [ + max_val, # Corresponds to exp(a-max) -> 1.0 + max_val - 4.6, # Corresponds to exp(a-max) -> ~1e-2 + max_val - 11.5, # Corresponds to exp(a-max) -> ~1e-5 + max_val - 23.0, # Corresponds to exp(a-max) -> ~1e-10 + ] + + # 2. Create a background tensor filled with a very large negative number. + background_val = -1000.0 + a = paddle.full((B, D), background_val, dtype=dtype) + + # 3. Scatter these "trap" values at random positions in each row. + for i in range(B): + # Randomly shuffle the positions of the trap values for each row to increase non-determinism. + indices = random.sample(range(D), k=len(trap_values)) + for j, val in enumerate(trap_values): + a[i, indices[j]] = val + + return a + + def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32): + a = self.create_softmax_trap_tensor(B, D, dtype) + + # Method 1: Matrix-vector multiplication (batch size 1) + out1 = paddle.nn.functional.log_softmax(a[:1]) + + # Method 2: Matrix-matrix multiplication, then slice (full batch) + out2 = paddle.nn.functional.log_softmax(a)[:1] + + # Check if results are identical + diff = (out1 - out2).abs().max() + return diff.item() == 0, diff + + def run_iters(self, iters=10, ass=False): + for dtype in [paddle.float32, paddle.bfloat16, paddle.float16]: + is_deterministic = True + difflist = [] + for i in range(iters): + isd, df = self.test_batch_invariance(dtype=dtype) + is_deterministic = is_deterministic and isd + difflist.append(df) + print( + f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations" + ) + if ass: + assert max(difflist) == 0 + + def test_case(self): + # Test with standard Paddle (likely to show differences) + print("Standard Paddle:") + with set_batch_invariant_mode(False): + self.run_iters(ass=False) + # Test with batch-invariant operations + print("\nBatch-Invariant Mode:") + with set_batch_invariant_mode(True): + self.run_iters(ass=True) + + +if __name__ == "__main__": + unittest.main() + """ + Even in Standard Paddle, we can achieve deterministic results, so maybe the standard implementation is already batch-invariant? 
+ + Result: + + Standard Paddle: + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float16 in 10 iterations + + Batch-Invariant Mode: + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float16 in 10 iterations + """ diff --git a/tests/batch_invariant/test_batch_invariance_op_mm.py b/tests/batch_invariant/test_batch_invariance_op_mm.py index 22fb9c42c82..102584d14f0 100644 --- a/tests/batch_invariant/test_batch_invariance_op_mm.py +++ b/tests/batch_invariant/test_batch_invariance_op_mm.py @@ -56,3 +56,13 @@ def test_case(self): if __name__ == "__main__": unittest.main() + """ + + Standard Paddle: + Batch Deterministic: False run-to-run max/min/diff 10.7294921875/10.7294921875/0.0 for paddle.float32 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations + + Batch-Invariant Mode: + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations + """ From 1d2021c585d6d68e978049b8164285e37ca18411 Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Tue, 4 Nov 2025 19:52:39 +0800 Subject: [PATCH 06/16] Add testcase for mean --- .../batch_invariant_ops.py | 24 ++++--- .../test_batch_invariance_op_mean.py | 66 +++++++++++++++++++ 2 files changed, 82 insertions(+), 8 deletions(-) create mode 100644 tests/batch_invariant/test_batch_invariance_op_mean.py diff --git a/custom_ops/batch_invariant_ops/batch_invariant_ops.py b/custom_ops/batch_invariant_ops/batch_invariant_ops.py index ccfac6e0c0b..bd43240ea82 100644 --- a/custom_ops/batch_invariant_ops/batch_invariant_ops.py +++ b/custom_ops/batch_invariant_ops/batch_invariant_ops.py @@ -481,16 +481,24 @@ def _log_softmax_batch_invariant(input, axis): return log_softmax(input, axis=axis) -def mean_batch_invariant(input, dim, keepdim=False, dtype: paddle.dtype | None = None): +def mean_batch_invariant(input, axis, keepdim=False, dtype: paddle.dtype | None = None, out=None): assert dtype is None or dtype == paddle.float32, f"unsupported dtype: {dtype}" - if len(dim) == 1: - return mean_dim(input, dim[0], keepdim=keepdim) + if type(axis) is int: + result = mean_dim(input, axis, keepdim=keepdim) + elif len(axis) == 1: # axis: int | Sequence[int] + result = mean_dim(input, axis[0], keepdim=keepdim) else: assert input.dtype in {paddle.float16, paddle.bfloat16, paddle.float32}, "only float types supported for now" n_elems = 1 - for d in dim: + for d in axis: n_elems *= input.shape[d] - return paddle.sum(input, dim=dim, keepdim=keepdim, dtype=paddle.float32) / n_elems + result = paddle.sum(input, axis=axis, keepdim=keepdim, dtype=paddle.float32) / n_elems + + # Handle out parameter if provided + if out is not None: + out.copy_(result) + return out + return result _original_ops = {"mm": None, "addmm": None, "_log_softmax": None, "mean_dim": None} @@ -510,12 +518,12 @@ def enable_batch_invariant_mode(): _original_ops["mm"] = paddle._C_ops.matmul _original_ops["addmm"] = paddle._C_ops.addmm 
_original_ops["log_softmax"] = paddle._C_ops.log_softmax - _original_ops["mean"] = paddle.mean + _original_ops["mean"] = paddle._C_ops.mean paddle._C_ops.matmul = mm_batch_invariant paddle._C_ops.addmm = addmm_batch_invariant paddle._C_ops.log_softmax = _log_softmax_batch_invariant - paddle.mean = mean_batch_invariant + paddle._C_ops.mean = mean_batch_invariant _batch_invariant_MODE = True @@ -532,7 +540,7 @@ def disable_batch_invariant_mode(): if _original_ops["log_softmax"]: paddle._C_ops.log_softmax = _original_ops["log_softmax"] if _original_ops["mean"]: - paddle.mean = _original_ops["mean"] + paddle._C_ops.mean = _original_ops["mean"] _batch_invariant_MODE = False diff --git a/tests/batch_invariant/test_batch_invariance_op_mean.py b/tests/batch_invariant/test_batch_invariance_op_mean.py new file mode 100644 index 00000000000..46726f05f6e --- /dev/null +++ b/tests/batch_invariant/test_batch_invariance_op_mean.py @@ -0,0 +1,66 @@ +# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py + +import unittest + +import paddle + +from custom_ops.batch_invariant_ops import set_batch_invariant_mode + + +class TestBatchInvariantForMean(unittest.TestCase): + def setUp(self): + """ + Initialize the test environment + """ + device = "gpu" if paddle.is_compiled_with_cuda() else "cpu" + paddle.set_device(device) + + def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32): + a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D) + + # Method 1: Matrix-vector multiplication (batch size 1) + out1 = paddle.mean(a[:1], axis=-1) + + # Method 2: Matrix-matrix multiplication, then slice (full batch) + out2 = paddle.mean(a, axis=-1)[:1] + + # Check if results are identical + diff = (out1 - out2).abs().max() + return diff.item() == 0, diff + + def run_iters(self, iters=10, ass=False): + for dtype in [paddle.float32, paddle.bfloat16]: + is_deterministic = True + difflist = [] + for i in range(iters): + isd, df = self.test_batch_invariance(dtype=dtype) + is_deterministic = is_deterministic and isd + difflist.append(df) + print( + f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations" + ) + if ass: + assert max(difflist) == 0 + + def test_case(self): + # Test with standard Paddle (likely to show differences) + print("Standard Paddle:") + with set_batch_invariant_mode(False): + self.run_iters(ass=False) + # Test with batch-invariant operations + print("\nBatch-Invariant Mode:") + with set_batch_invariant_mode(True): + self.run_iters(ass=True) + + +if __name__ == "__main__": + unittest.main() + """ + Standard Paddle: + Batch Deterministic: False run-to-run max/min/diff 7.62939453125e-06/7.62939453125e-06/0.0 for paddle.float32 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations + + Batch-Invariant Mode: + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations + """ From 471b075c54711b2bee49c526539b568ee5d1e74f Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Tue, 4 Nov 2025 20:05:47 +0800 Subject: [PATCH 07/16] Add testcase for addmm --- .../test_batch_invariance_op_addmm.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 
tests/batch_invariant/test_batch_invariance_op_addmm.py diff --git a/tests/batch_invariant/test_batch_invariance_op_addmm.py b/tests/batch_invariant/test_batch_invariance_op_addmm.py new file mode 100644 index 00000000000..e8b33ae3bc4 --- /dev/null +++ b/tests/batch_invariant/test_batch_invariance_op_addmm.py @@ -0,0 +1,67 @@ +# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py + +import unittest + +import paddle + +from custom_ops.batch_invariant_ops import set_batch_invariant_mode + + +class TestBatchInvariantForAddmm(unittest.TestCase): + def setUp(self): + """ + Initialize the test environment + """ + device = "gpu" if paddle.is_compiled_with_cuda() else "cpu" + paddle.set_device(device) + + def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32): + a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D) + b = paddle.linspace(-100, 100, D * D, dtype=dtype).reshape(D, D) + + # Method 1: Matrix-vector multiplication (batch size 1) + out1 = paddle.addmm(a[:1].squeeze(0), a[:1], b) + + # Method 2: Matrix-matrix multiplication, then slice (full batch) + out2 = paddle.addmm(a[:1].squeeze(0), a, b)[:1] + + # Check if results are identical + diff = (out1 - out2).abs().max() + return diff.item() == 0, diff + + def run_iters(self, iters=10, ass=False): + for dtype in [paddle.float32, paddle.bfloat16]: + is_deterministic = True + difflist = [] + for i in range(iters): + isd, df = self.test_batch_invariance(dtype=dtype) + is_deterministic = is_deterministic and isd + difflist.append(df) + print( + f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations" + ) + if ass: + assert max(difflist) == 0 + + def test_case(self): + # Test with standard Paddle (likely to show differences) + print("Standard Paddle:") + with set_batch_invariant_mode(False): + self.run_iters(ass=False) + # Test with batch-invariant operations + print("\nBatch-Invariant Mode:") + with set_batch_invariant_mode(True): + self.run_iters(ass=True) + + +if __name__ == "__main__": + unittest.main() + """ + Standard Paddle: + Batch Deterministic: False run-to-run max/min/diff 10.7294921875/10.7294921875/0.0 for paddle.float32 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations + + Batch-Invariant Mode: + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.float32 in 10 iterations + Batch Deterministic: True run-to-run max/min/diff 0.0/0.0/0.0 for paddle.bfloat16 in 10 iterations + """ From daf47cd0cc9e32313d15720500490d41c6ac8058 Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Wed, 5 Nov 2025 11:09:35 +0800 Subject: [PATCH 08/16] fix pre-commit --- custom_ops/batch_invariant_ops/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_ops/batch_invariant_ops/__init__.py b/custom_ops/batch_invariant_ops/__init__.py index 07cc85b2de4..fe7b9ff8cef 100644 --- a/custom_ops/batch_invariant_ops/__init__.py +++ b/custom_ops/batch_invariant_ops/__init__.py @@ -22,4 +22,4 @@ "mean_dim", "get_batch_invariant_attention_block_size", "AttentionBlockSize", -] \ No newline at end of file +] From c85f38be1eac7fdc66f9fc174ff92ff3d865f8dd Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Wed, 5 Nov 2025 11:31:09 +0800 Subject: [PATCH 09/16] API check v0.9 --- .../batch_invariant_ops.py | 25 
+++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/custom_ops/batch_invariant_ops/batch_invariant_ops.py b/custom_ops/batch_invariant_ops/batch_invariant_ops.py index bd43240ea82..2af01bff5aa 100644 --- a/custom_ops/batch_invariant_ops/batch_invariant_ops.py +++ b/custom_ops/batch_invariant_ops/batch_invariant_ops.py @@ -471,28 +471,31 @@ def mm_batch_invariant(a, b, transpose_x=False, transpose_y=False): return matmul_persistent(a, b) -def addmm_batch_invariant(bias, a, b, alpha=1.0, beta=1.0): - # TODO:check API - result = matmul_persistent(a, b, bias=bias) +def addmm_batch_invariant( + input: paddle.Tensor, x: paddle.Tensor, y: paddle.Tensor, beta: float = 1.0, alpha: float = 1.0 +) -> paddle.Tensor: + result = matmul_persistent(a=x, b=y, bias=input) return result -def _log_softmax_batch_invariant(input, axis): - return log_softmax(input, axis=axis) +def _log_softmax_batch_invariant(x: paddle.Tensor, axis: int = -1) -> paddle.Tensor: + return log_softmax(input=x, axis=axis) -def mean_batch_invariant(input, axis, keepdim=False, dtype: paddle.dtype | None = None, out=None): +def mean_batch_invariant( + x: paddle.Tensor, axis: list[int] = [], keepdim: bool = False, dtype: paddle.dtype | None = None, out=None +) -> paddle.Tensor: assert dtype is None or dtype == paddle.float32, f"unsupported dtype: {dtype}" if type(axis) is int: - result = mean_dim(input, axis, keepdim=keepdim) + result = mean_dim(x, axis, keepdim=keepdim) elif len(axis) == 1: # axis: int | Sequence[int] - result = mean_dim(input, axis[0], keepdim=keepdim) + result = mean_dim(x, axis[0], keepdim=keepdim) else: - assert input.dtype in {paddle.float16, paddle.bfloat16, paddle.float32}, "only float types supported for now" + assert x.dtype in {paddle.float16, paddle.bfloat16, paddle.float32}, "only float types supported for now" n_elems = 1 for d in axis: - n_elems *= input.shape[d] - result = paddle.sum(input, axis=axis, keepdim=keepdim, dtype=paddle.float32) / n_elems + n_elems *= x.shape[d] + result = paddle.sum(x, axis=axis, keepdim=keepdim, dtype=paddle.float32) / n_elems # Handle out parameter if provided if out is not None: From a573a54e77534cbed33c33fa8bfda7c674c547d1 Mon Sep 17 00:00:00 2001 From: littledgg <1658565283@qq.com> Date: Wed, 5 Nov 2025 17:54:42 +0800 Subject: [PATCH 10/16] move to layers and add comment about log_softmax --- .../layers}/batch_invariant_ops/__init__.py | 0 .../batch_invariant_ops.py | 0 .../batch_invariant/test_batch_invariance.py | 51 ------------------- .../test_batch_invariance_op_addmm.py | 4 +- .../test_batch_invariance_op_logsoftmax.py | 18 ++++++- .../test_batch_invariance_op_mean.py | 4 +- .../test_batch_invariance_op_mm.py | 4 +- 7 files changed, 26 insertions(+), 55 deletions(-) rename {custom_ops => fastdeploy/model_executor/layers}/batch_invariant_ops/__init__.py (100%) rename {custom_ops => fastdeploy/model_executor/layers}/batch_invariant_ops/batch_invariant_ops.py (100%) delete mode 100644 tests/batch_invariant/test_batch_invariance.py diff --git a/custom_ops/batch_invariant_ops/__init__.py b/fastdeploy/model_executor/layers/batch_invariant_ops/__init__.py similarity index 100% rename from custom_ops/batch_invariant_ops/__init__.py rename to fastdeploy/model_executor/layers/batch_invariant_ops/__init__.py diff --git a/custom_ops/batch_invariant_ops/batch_invariant_ops.py b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py similarity index 100% rename from custom_ops/batch_invariant_ops/batch_invariant_ops.py rename to 
fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py diff --git a/tests/batch_invariant/test_batch_invariance.py b/tests/batch_invariant/test_batch_invariance.py deleted file mode 100644 index 4fa4044d182..00000000000 --- a/tests/batch_invariant/test_batch_invariance.py +++ /dev/null @@ -1,51 +0,0 @@ -# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/test_batch_invariance.py - -import paddle - -from custom_ops.batch_invariant_ops import set_batch_invariant_mode - -device = "gpu" if paddle.is_compiled_with_cuda() else "cpu" -paddle.set_device(device) - -# Just to get the logging out of the way haha -with set_batch_invariant_mode(True): - pass - - -def test_batch_invariance(dtype=paddle.float32): - B, D = 2048, 4096 - a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D) - b = paddle.linspace(-100, 100, D * D, dtype=dtype).reshape(D, D) - - # Method 1: Matrix-vector multiplication (batch size 1) - out1 = paddle.mm(a[:1], b) - - # Method 2: Matrix-matrix multiplication, then slice (full batch) - out2 = paddle.mm(a, b)[:1] - - # Check if results are identical - diff = (out1 - out2).abs().max() - return diff.item() == 0, diff - - -def run_iters(iters=10): - for dtype in [paddle.float32, paddle.bfloat16]: - is_deterministic = True - difflist = [] - for i in range(iters): - isd, df = test_batch_invariance(dtype) - is_deterministic = is_deterministic and isd - difflist.append(df) - print( - f"Batch Deterministic: {is_deterministic} run-to-run max/min/diff {max(difflist)}/{min(difflist)}/{max(difflist)-min(difflist)} for {dtype} in {iters} iterations" - ) - - -# Test with standard Paddle (likely to show differences) -print("Standard Paddle:") -with set_batch_invariant_mode(False): - run_iters() -# Test with batch-invariant operations -print("\nBatch-Invariant Mode:") -with set_batch_invariant_mode(True): - run_iters() diff --git a/tests/batch_invariant/test_batch_invariance_op_addmm.py b/tests/batch_invariant/test_batch_invariance_op_addmm.py index e8b33ae3bc4..59bcfd147a4 100644 --- a/tests/batch_invariant/test_batch_invariance_op_addmm.py +++ b/tests/batch_invariant/test_batch_invariance_op_addmm.py @@ -4,7 +4,9 @@ import paddle -from custom_ops.batch_invariant_ops import set_batch_invariant_mode +from fastdeploy.model_executor.layers.batch_invariant_ops import ( + set_batch_invariant_mode, +) class TestBatchInvariantForAddmm(unittest.TestCase): diff --git a/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py b/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py index ee14756c501..1bc77669a76 100644 --- a/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py +++ b/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py @@ -5,7 +5,9 @@ import paddle -from custom_ops.batch_invariant_ops import set_batch_invariant_mode +from fastdeploy.model_executor.layers.batch_invariant_ops import ( + set_batch_invariant_mode, +) class TestBatchInvariantForLogsoftmax(unittest.TestCase): @@ -95,6 +97,20 @@ def test_case(self): """ Even in Standard Paddle, we can achieve deterministic results, so maybe the standard implementation is already batch-invariant? + After reviewing the four implementations called by the dispatcher function `SoftmaxForwardCUDAKernelDriverImpl` (dispatched by 'D') + in `paddle/phi/kernels/gpudnn/softmax_gpudnn.h`: + + 1. SwitchWarpSoftmaxForward (one Warp processes 1-2 rows) + 2. LaunchKeMatrixSoftmaxForwardKernel (one Block processes one row) + 3. 
LaunchSoftmaxForwardCudnnKernel (the Cudnn implementation)
+    4. LaunchNormalSoftmaxForward (in one Block, threads with the same threadIdx.x [a "thread column"] cooperate to process one row)
+
+    Excluding the Cudnn implementation, the other three custom implementations are almost certainly batch-invariant (this still needs to be double-checked).
+    The determinism of the Cudnn implementation is uncertain.
+
+    However, in practice, this test case (D=4096) is dispatched to the Cudnn implementation,
+    while Qwen-3 8B is dispatched to the LaunchKeMatrixSoftmaxForwardKernel implementation.
+
     Result:
diff --git a/tests/batch_invariant/test_batch_invariance_op_mean.py b/tests/batch_invariant/test_batch_invariance_op_mean.py
index 46726f05f6e..cf89593fcf4 100644
--- a/tests/batch_invariant/test_batch_invariance_op_mean.py
+++ b/tests/batch_invariant/test_batch_invariance_op_mean.py
@@ -4,7 +4,9 @@
 
 import paddle
 
-from custom_ops.batch_invariant_ops import set_batch_invariant_mode
+from fastdeploy.model_executor.layers.batch_invariant_ops import (
+    set_batch_invariant_mode,
+)
 
 
 class TestBatchInvariantForMean(unittest.TestCase):
diff --git a/tests/batch_invariant/test_batch_invariance_op_mm.py b/tests/batch_invariant/test_batch_invariance_op_mm.py
index 102584d14f0..4e77b29a1bc 100644
--- a/tests/batch_invariant/test_batch_invariance_op_mm.py
+++ b/tests/batch_invariant/test_batch_invariance_op_mm.py
@@ -4,7 +4,9 @@
 
 import paddle
 
-from custom_ops.batch_invariant_ops import set_batch_invariant_mode
+from fastdeploy.model_executor.layers.batch_invariant_ops import (
+    set_batch_invariant_mode,
+)
 
 
 class TestBatchInvariantForMM(unittest.TestCase):

From f25a8c75e215501d98cc7d7245778367c6d21fae Mon Sep 17 00:00:00 2001
From: Jundong Liu <61149469+littledgg@users.noreply.github.com>
Date: Wed, 12 Nov 2025 16:46:30 +0800
Subject: [PATCH 11/16] Update
 fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is version-control residue left in the comments carried over from the original code; it should indeed be removed.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../layers/batch_invariant_ops/batch_invariant_ops.py           | 1 -
 1 file changed, 1 deletion(-)

diff --git a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
index 2af01bff5aa..9ea3fa05b3e 100644
--- a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
+++ b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
@@ -292,7 +292,6 @@ def log_softmax(input: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
     Args:
         input: Input tensor
         axis: Dimension along which to compute log_softmax (only -1 or last dim supported)
-        >> Stashed changes
     Returns:
         Tensor with log_softmax applied along the specified dimension
     """

From 5b44576287e5993da03c1e927cb43d369f26c91b Mon Sep 17 00:00:00 2001
From: Jundong Liu <61149469+littledgg@users.noreply.github.com>
Date: Wed, 12 Nov 2025 16:56:17 +0800
Subject: [PATCH 12/16] Update
 tests/batch_invariant/test_batch_invariance_op_mean.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 tests/batch_invariant/test_batch_invariance_op_mean.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/batch_invariant/test_batch_invariance_op_mean.py b/tests/batch_invariant/test_batch_invariance_op_mean.py
index cf89593fcf4..11ea6c73e34 100644
--- a/tests/batch_invariant/test_batch_invariance_op_mean.py
+++ 
@@ -20,7 +20,7 @@ def setUp(self):
     def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
         a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D)

-        # Method 1: Matrix-vector multiplication (batch size 1)
+        # Method 1: Mean reduction over last axis (batch size 1)
         out1 = paddle.mean(a[:1], axis=-1)

         # Method 2: Matrix-matrix multiplication, then slice (full batch)

From 2026212c7369af6f295d9e5efcf9d69d426e52e7 Mon Sep 17 00:00:00 2001
From: Jundong Liu <61149469+littledgg@users.noreply.github.com>
Date: Wed, 12 Nov 2025 16:56:36 +0800
Subject: [PATCH 13/16] Update tests/batch_invariant/test_batch_invariance_op_logsoftmax.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 tests/batch_invariant/test_batch_invariance_op_logsoftmax.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py b/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py
index 1bc77669a76..193112921e6 100644
--- a/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py
+++ b/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py
@@ -57,7 +57,7 @@ def create_softmax_trap_tensor(self, B, D, dtype):
     def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float32):
         a = self.create_softmax_trap_tensor(B, D, dtype)

-        # Method 1: Matrix-vector multiplication (batch size 1)
+        # Method 1: log_softmax on batch size 1 (first row)
         out1 = paddle.nn.functional.log_softmax(a[:1])

         # Method 2: Matrix-matrix multiplication, then slice (full batch)

From 48c0ed1c43b5005d09f2fdeabc57e0ff81685047 Mon Sep 17 00:00:00 2001
From: Jundong Liu <61149469+littledgg@users.noreply.github.com>
Date: Wed, 12 Nov 2025 17:03:56 +0800
Subject: [PATCH 14/16] Update fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../layers/batch_invariant_ops/batch_invariant_ops.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
index 9ea3fa05b3e..9691b741083 100644
--- a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
+++ b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
@@ -473,7 +473,8 @@ def mm_batch_invariant(a, b, transpose_x=False, transpose_y=False):
 def addmm_batch_invariant(
     input: paddle.Tensor, x: paddle.Tensor, y: paddle.Tensor, beta: float = 1.0, alpha: float = 1.0
 ) -> paddle.Tensor:
-    result = matmul_persistent(a=x, b=y, bias=input)
+    matmul_result = matmul_persistent(a=x, b=y)
+    result = beta * input + alpha * matmul_result
     return result

From 99b836300f961500c87c726c7f71994d6051ea4f Mon Sep 17 00:00:00 2001
From: littledgg <1658565283@qq.com>
Date: Wed, 12 Nov 2025 17:15:47 +0800
Subject: [PATCH 15/16] change comment after copilot fix

---
 .../layers/batch_invariant_ops/batch_invariant_ops.py | 6 ++++--
 tests/batch_invariant/test_batch_invariance_op_addmm.py | 4 ++--
 .../batch_invariant/test_batch_invariance_op_logsoftmax.py | 2 +-
 tests/batch_invariant/test_batch_invariance_op_mean.py | 2 +-
 4 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
index 9691b741083..0bb8178b234 100644
--- a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
+++ b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
@@ -140,7 +140,7 @@ def get_compute_units():
             NUM_SMS = device_properties.multi_processor_count
         except Exception:
             print("Could not get CUDA device properties. Falling back to CPU threads.")
-            # TODO: Paddle lacks a torch.get_num_threads() equivalent for the *configured* thread count.
+            # TODO(liujundong): Paddle lacks a torch.get_num_threads() equivalent for the *configured* thread count.
             # Using os.cpu_count() (total logical cores) as a fallback, which may not be correct.
             # Must check downstream logic to determine if this impacts correctness.
             NUM_SMS = os.cpu_count()
@@ -474,7 +474,9 @@ def addmm_batch_invariant(
     input: paddle.Tensor, x: paddle.Tensor, y: paddle.Tensor, beta: float = 1.0, alpha: float = 1.0
 ) -> paddle.Tensor:
     matmul_result = matmul_persistent(a=x, b=y)
-    result = beta * input + alpha * matmul_result
+    result = (
+        beta * input + alpha * matmul_result
+    )  # TODO(liujundong): paddle._C_ops.addmm has more parameters; this may affect performance
     return result
diff --git a/tests/batch_invariant/test_batch_invariance_op_addmm.py b/tests/batch_invariant/test_batch_invariance_op_addmm.py
index 59bcfd147a4..45b5caac58f 100644
--- a/tests/batch_invariant/test_batch_invariance_op_addmm.py
+++ b/tests/batch_invariant/test_batch_invariance_op_addmm.py
@@ -21,10 +21,10 @@ def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float
         a = paddle.linspace(-100, 100, B * D, dtype=dtype).reshape(B, D)
         b = paddle.linspace(-100, 100, D * D, dtype=dtype).reshape(D, D)

-        # Method 1: Matrix-vector multiplication (batch size 1)
+        # Method 1: Matrix-vector multiplication and add (batch size 1)
         out1 = paddle.addmm(a[:1].squeeze(0), a[:1], b)

-        # Method 2: Matrix-matrix multiplication, then slice (full batch)
+        # Method 2: Matrix-matrix multiplication and add, then slice (full batch)
         out2 = paddle.addmm(a[:1].squeeze(0), a, b)[:1]

         # Check if results are identical
diff --git a/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py b/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py
index 193112921e6..5a7b6c7a5c6 100644
--- a/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py
+++ b/tests/batch_invariant/test_batch_invariance_op_logsoftmax.py
@@ -60,7 +60,7 @@ def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float
         # Method 1: log_softmax on batch size 1 (first row)
         out1 = paddle.nn.functional.log_softmax(a[:1])

-        # Method 2: Matrix-matrix multiplication, then slice (full batch)
+        # Method 2: log_softmax on full batch, then slice (first row)
         out2 = paddle.nn.functional.log_softmax(a)[:1]

         # Check if results are identical
diff --git a/tests/batch_invariant/test_batch_invariance_op_mean.py b/tests/batch_invariant/test_batch_invariance_op_mean.py
index 11ea6c73e34..fc7796e700f 100644
--- a/tests/batch_invariant/test_batch_invariance_op_mean.py
+++ b/tests/batch_invariant/test_batch_invariance_op_mean.py
@@ -23,7 +23,7 @@ def test_batch_invariance(self, B: int = 2048, D: int = 4096, dtype=paddle.float
         # Method 1: Mean reduction over last axis (batch size 1)
         out1 = paddle.mean(a[:1], axis=-1)

-        # Method 2: Matrix-matrix multiplication, then slice (full batch)
+        # Method 2: Mean reduction over last axis (full batch)
         out2 = paddle.mean(a, axis=-1)[:1]

         # Check if results are identical

From 36e8d86afad5880dd88d09f201fee0fa3ed53cbf Mon Sep 17 00:00:00 2001
From: littledgg <1658565283@qq.com>
Date: Thu, 13 Nov 2025 15:30:13 +0800
Subject: [PATCH 16/16] fix bug about addmm

---
 .../batch_invariant_ops/batch_invariant_ops.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
index 0bb8178b234..1afec4829a7 100644
--- a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
+++ b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py
@@ -473,10 +473,14 @@ def mm_batch_invariant(a, b, transpose_x=False, transpose_y=False):
 def addmm_batch_invariant(
     input: paddle.Tensor, x: paddle.Tensor, y: paddle.Tensor, beta: float = 1.0, alpha: float = 1.0
 ) -> paddle.Tensor:
-    matmul_result = matmul_persistent(a=x, b=y)
-    result = (
-        beta * input + alpha * matmul_result
-    )  # TODO(liujundong): paddle._C_ops.addmm has more parameters; this may affect performance
+    """
+    We need to achieve `Out = alpha * (x @ y) + beta * input`,
+    but matmul_persistent only computes `x @ y + input` (it follows aten::addmm in torch; paddle._C_ops.addmm takes more parameters).
+    So we rewrite `alpha * (x @ y) + beta * input` as `alpha * [ (x @ y) + (beta / alpha) * input ]`
+    to minimize the impact on performance.
+    """
+    matmul_result = matmul_persistent(a=x, b=y, bias=input * beta / alpha)
+    result = alpha * matmul_result
     return result
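
Note: the final addmm_batch_invariant relies on the identity alpha * (x @ y) + beta * input ==
alpha * [ (x @ y) + (beta / alpha) * input ], which only holds when alpha != 0. The snippet below is
not part of the patch series; it is a minimal sketch that checks the identity with stock Paddle ops
(using paddle.addmm as the reference), assuming paddle is installed. The shapes and the alpha/beta
values are arbitrary illustrative choices.

    # Minimal sketch: confirm the algebraic rewrite used by addmm_batch_invariant.
    import paddle

    paddle.seed(0)
    x = paddle.randn([4, 8])
    y = paddle.randn([8, 16])
    bias = paddle.randn([16])
    alpha, beta = 2.0, 0.5  # alpha must be non-zero: the rewrite divides by alpha

    # Reference: Out = alpha * (x @ y) + beta * bias
    reference = paddle.addmm(bias, x, y, beta=beta, alpha=alpha)

    # Rewritten form used by the patch: alpha * ((x @ y) + (beta / alpha) * bias)
    rewritten = alpha * (paddle.matmul(x, y) + (beta / alpha) * bias)

    # The two forms agree up to floating-point rounding.
    print(float(paddle.max(paddle.abs(reference - rewritten))))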