Skip to content

[Bug] GatedDeltaNet backward error on Blackwell: 'error: operand #2 does not dominate this use' #638

@wangsiyu

Description

@wangsiyu

Checklist

  • I have checked FAQs and existing issues for similar problems
  • Please report this bug in English to ensure wider understanding and support

Describe the Bug

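When running the GatedDeltaNet (GDN) backward pass on a Blackwell GPU with Triton 3.5.0, compilation of `prepare_wy_repr_bwd_kernel` fails with the error below. The forward benchmark completes. The failure occurs in the `TritonGPUHoistTMEMAlloc` pass, where MLIR verification reports that an operand of a `tl.dot` does not dominate its use.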

flash-linear-attention/fla/ops/gated_delta_rule/chunk.py:331: UserWarning: head_first is deprecated and will be removed in a future version. Please use head_first=False for now instead.
  warnings.warn(
flash-linear-attention/fla/ops/gated_delta_rule/chunk.py:331: UserWarning: head_first is deprecated and will be removed in a future version. Please use head_first=False for now instead.
  warnings.warn(
flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py:208:39: error: operand #2 does not dominate this use
        b_dk += tl.dot(tl.trans(b_dA), b_kb)
                                      ^
flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py:208:16: note: operand defined here (op in the same block)
        b_dk += tl.dot(tl.trans(b_dA), b_kb)
               ^
module {
  tt.func public @prepare_wy_repr_bwd_kernel(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg13: i32) attributes {noinline = false} {
    %cst = arith.constant dense<128> : tensor<1x64xi64>
    %cst_0 = arith.constant dense<4096> : tensor<64x1xi64>
    %cst_1 = arith.constant dense<0> : tensor<1x64xi64>
    %cst_2 = arith.constant dense<64> : tensor<64x1xi64>
    %cst_3 = arith.constant dense<0> : tensor<64x1xi64>
    %cst_4 = arith.constant dense<2048> : tensor<1x64xi64>
    %cst_5 = arith.constant dense<0> : tensor<64xi64>
    %cst_6 = arith.constant dense<32> : tensor<64xi64>
    %c2_i32 = arith.constant 2 : i32
    %cst_7 = arith.constant dense<0.000000e+00> : tensor<64xf32>
    %c1_i32 = arith.constant 1 : i32
    %cst_8 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
    %c128_i32 = arith.constant 128 : i32
    %c0_i32 = arith.constant 0 : i32
    %c64_i32 = arith.constant 64 : i32
    %c32_i32 = arith.constant 32 : i32
    %0 = tt.get_program_id x : i32
    %1 = tt.get_program_id y : i32
    %2 = arith.divsi %1, %c32_i32 : i32
    %3 = arith.remsi %1, %c32_i32 : i32
    %4 = arith.muli %2, %arg13 : i32
    %5 = arith.muli %4, %c32_i32 : i32
    %6 = arith.addi %5, %3 : i32
    %7 = tt.addptr %arg2, %6 : !tt.ptr<f32>, i32
    %8 = arith.muli %0, %c64_i32 : i32
    %9 = arith.extsi %arg13 : i32 to i64
    %10 = arith.extsi %8 : i32 to i64
    %11 = tt.addptr %arg9, %6 : !tt.ptr<f32>, i32
    %12 = arith.muli %6, %c64_i32 : i32
    %13 = tt.addptr %arg4, %12 : !tt.ptr<bf16>, i32
    %14 = tt.splat %7 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %15 = tt.splat %10 : i64 -> tensor<64xi64>
    %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %17 = arith.extsi %16 : tensor<64xi32> to tensor<64xi64>
    %18 = arith.addi %15, %17 : tensor<64xi64>
    %19 = arith.muli %18, %cst_6 : tensor<64xi64>
    %20 = tt.addptr %14, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
    %21 = arith.cmpi sge, %18, %cst_5 : tensor<64xi64>
    %22 = tt.splat %9 : i64 -> tensor<64xi64>
    %23 = arith.cmpi slt, %18, %22 : tensor<64xi64>
    %24 = arith.andi %21, %23 : tensor<64xi1>
    %25 = tt.load %20, %24 : tensor<64x!tt.ptr<f32>>
    %26 = tt.splat %13 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
    %27 = tt.expand_dims %17 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
    %28 = tt.broadcast %27 : tensor<64x1xi64> -> tensor<64x64xi64>
    %29 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi64> -> tensor<1x64xi64>
    %30 = arith.muli %29, %cst_4 : tensor<1x64xi64>
    %31 = tt.broadcast %30 : tensor<1x64xi64> -> tensor<64x64xi64>
    %32 = arith.addi %28, %31 : tensor<64x64xi64>
    %33 = tt.addptr %26, %32 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
    %34 = arith.cmpi sge, %27, %cst_3 : tensor<64x1xi64>
    %35 = arith.cmpi slt, %27, %cst_2 : tensor<64x1xi64>
    %36 = arith.andi %34, %35 : tensor<64x1xi1>
    %37 = tt.broadcast %36 : tensor<64x1xi1> -> tensor<64x64xi1>
    %38 = arith.cmpi sge, %29, %cst_1 : tensor<1x64xi64>
    %39 = tt.splat %9 : i64 -> tensor<1x64xi64>
    %40 = arith.cmpi slt, %29, %39 : tensor<1x64xi64>
    %41 = arith.andi %38, %40 : tensor<1x64xi1>
    %42 = tt.broadcast %41 : tensor<1x64xi1> -> tensor<64x64xi1>
    %43 = arith.andi %37, %42 : tensor<64x64xi1>
    %44 = tt.load %33, %43 : tensor<64x64x!tt.ptr<bf16>>
    %45 = tt.addptr %arg3, %6 : !tt.ptr<f32>, i32
    %46 = tt.splat %45 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %47 = tt.addptr %46, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
    %48 = tt.load %47, %24 : tensor<64x!tt.ptr<f32>>
    %49 = math.exp %48 : tensor<64xf32>
    %50:3 = scf.for %arg14 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg15 = %cst_7, %arg16 = %cst_8, %arg17 = %cst_7) -> (tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>)  : i32 {
      %99 = arith.muli %6, %c128_i32 : i32
      %100 = tt.addptr %arg0, %99 : !tt.ptr<bf16>, i32
      %101 = arith.muli %arg14, %c64_i32 : i32
      %102 = arith.extsi %101 : i32 to i64
      %103 = tt.addptr %arg7, %99 : !tt.ptr<bf16>, i32
      %104 = tt.addptr %arg5, %99 : !tt.ptr<bf16>, i32
      %105 = tt.splat %100 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
      %106 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
      %107 = arith.muli %106, %cst_0 : tensor<64x1xi64>
      %108 = tt.broadcast %107 : tensor<64x1xi64> -> tensor<64x64xi64>
      %109 = tt.splat %102 : i64 -> tensor<64xi64>
      %110 = arith.addi %109, %17 : tensor<64xi64>
      %111 = tt.expand_dims %110 {axis = 0 : i32} : tensor<64xi64> -> tensor<1x64xi64>
      %112 = tt.broadcast %111 : tensor<1x64xi64> -> tensor<64x64xi64>
      %113 = arith.addi %108, %112 : tensor<64x64xi64>
      %114 = tt.addptr %105, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
      %115 = arith.cmpi sge, %106, %cst_3 : tensor<64x1xi64>
      %116 = tt.splat %9 : i64 -> tensor<64x1xi64>
      %117 = arith.cmpi slt, %106, %116 : tensor<64x1xi64>
      %118 = arith.andi %115, %117 : tensor<64x1xi1>
      %119 = tt.broadcast %118 : tensor<64x1xi1> -> tensor<64x64xi1>
      %120 = arith.cmpi sge, %111, %cst_1 : tensor<1x64xi64>
      %121 = arith.cmpi slt, %111, %cst : tensor<1x64xi64>
      %122 = arith.andi %120, %121 : tensor<1x64xi1>
      %123 = tt.broadcast %122 : tensor<1x64xi1> -> tensor<64x64xi1>
      %124 = arith.andi %119, %123 : tensor<64x64xi1>
      %125 = tt.load %114, %124 : tensor<64x64x!tt.ptr<bf16>>
      %126 = arith.mulf %25, %49 : tensor<64xf32>
      %127 = tt.expand_dims %126 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
      %128 = arith.extf %125 : tensor<64x64xbf16> to tensor<64x64xf32>
      %129 = tt.broadcast %127 : tensor<64x1xf32> -> tensor<64x64xf32>
      %130 = arith.mulf %128, %129 : tensor<64x64xf32>
      %131 = tt.splat %104 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
      %132 = tt.addptr %131, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
      %133 = tt.load %132, %124 : tensor<64x64x!tt.ptr<bf16>>
      %134 = tt.trans %130 {order = array<i32: 1, 0>} : tensor<64x64xf32> -> tensor<64x64xf32>
      %135 = arith.truncf %134 : tensor<64x64xf32> to tensor<64x64xbf16>
      %136 = tt.dot %133, %135, %arg16, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
      %137 = tt.dot %44, %133, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
      %138 = arith.mulf %137, %129 : tensor<64x64xf32>
      %139 = arith.mulf %137, %128 : tensor<64x64xf32>
      %140 = tt.expand_dims %49 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
      %141 = tt.broadcast %140 : tensor<64x1xf32> -> tensor<64x64xf32>
      %142 = arith.mulf %139, %141 : tensor<64x64xf32>
      %143 = "tt.reduce"(%142) <{axis = 1 : i32}> ({
      ^bb0(%arg18: f32, %arg19: f32):
        %151 = arith.addf %arg18, %arg19 : f32
        tt.reduce.return %151 : f32
      }) : (tensor<64x64xf32>) -> tensor<64xf32>
      %144 = arith.addf %arg15, %143 : tensor<64xf32>
      %145 = arith.mulf %137, %130 : tensor<64x64xf32>
      %146 = "tt.reduce"(%145) <{axis = 1 : i32}> ({
      ^bb0(%arg18: f32, %arg19: f32):
        %151 = arith.addf %arg18, %arg19 : f32
        tt.reduce.return %151 : f32
      }) : (tensor<64x64xf32>) -> tensor<64xf32>
      %147 = arith.addf %arg17, %146 : tensor<64xf32>
      %148 = arith.truncf %138 : tensor<64x64xf32> to tensor<64x64xbf16>
      %149 = tt.splat %103 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
      %150 = tt.addptr %149, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
      tt.store %150, %148, %124 : tensor<64x64x!tt.ptr<bf16>>
      scf.yield %144, %136, %147 : tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>
    }
    %51:2 = scf.for %arg14 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg15 = %50#0, %arg16 = %50#1) -> (tensor<64xf32>, tensor<64x64xf32>)  : i32 {
      %99 = arith.muli %6, %c128_i32 : i32
      %100 = tt.addptr %arg1, %99 : !tt.ptr<bf16>, i32
      %101 = arith.muli %arg14, %c64_i32 : i32
      %102 = arith.extsi %101 : i32 to i64
      %103 = tt.addptr %arg8, %99 : !tt.ptr<bf16>, i32
      %104 = tt.addptr %arg6, %99 : !tt.ptr<bf16>, i32
      %105 = tt.splat %100 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
      %106 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
      %107 = arith.muli %106, %cst_0 : tensor<64x1xi64>
      %108 = tt.broadcast %107 : tensor<64x1xi64> -> tensor<64x64xi64>
      %109 = tt.splat %102 : i64 -> tensor<64xi64>
      %110 = arith.addi %109, %17 : tensor<64xi64>
      %111 = tt.expand_dims %110 {axis = 0 : i32} : tensor<64xi64> -> tensor<1x64xi64>
      %112 = tt.broadcast %111 : tensor<1x64xi64> -> tensor<64x64xi64>
      %113 = arith.addi %108, %112 : tensor<64x64xi64>
      %114 = tt.addptr %105, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
      %115 = arith.cmpi sge, %106, %cst_3 : tensor<64x1xi64>
      %116 = tt.splat %9 : i64 -> tensor<64x1xi64>
      %117 = arith.cmpi slt, %106, %116 : tensor<64x1xi64>
      %118 = arith.andi %115, %117 : tensor<64x1xi1>
      %119 = tt.broadcast %118 : tensor<64x1xi1> -> tensor<64x64xi1>
      %120 = arith.cmpi sge, %111, %cst_1 : tensor<1x64xi64>
      %121 = arith.cmpi slt, %111, %cst : tensor<1x64xi64>
      %122 = arith.andi %120, %121 : tensor<1x64xi1>
      %123 = tt.broadcast %122 : tensor<1x64xi1> -> tensor<64x64xi1>
      %124 = arith.andi %119, %123 : tensor<64x64xi1>
      %125 = tt.load %114, %124 : tensor<64x64x!tt.ptr<bf16>>
      %126 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
      %127 = arith.extf %125 : tensor<64x64xbf16> to tensor<64x64xf32>
      %128 = tt.broadcast %126 : tensor<64x1xf32> -> tensor<64x64xf32>
      %129 = arith.mulf %127, %128 : tensor<64x64xf32>
      %130 = arith.truncf %129 : tensor<64x64xf32> to tensor<64x64xbf16>
      %131 = tt.splat %104 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
      %132 = tt.addptr %131, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
      %133 = tt.load %132, %124 : tensor<64x64x!tt.ptr<bf16>>
      %134 = tt.trans %130 {order = array<i32: 1, 0>} : tensor<64x64xbf16> -> tensor<64x64xbf16>
      %135 = tt.dot %133, %134, %arg16, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
      %136 = tt.dot %44, %133, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
      %137 = arith.mulf %136, %128 : tensor<64x64xf32>
      %138 = arith.mulf %136, %127 : tensor<64x64xf32>
      %139 = "tt.reduce"(%138) <{axis = 1 : i32}> ({
      ^bb0(%arg17: f32, %arg18: f32):
        %144 = arith.addf %arg17, %arg18 : f32
        tt.reduce.return %144 : f32
      }) : (tensor<64x64xf32>) -> tensor<64xf32>
      %140 = arith.addf %arg15, %139 : tensor<64xf32>
      %141 = arith.truncf %137 : tensor<64x64xf32> to tensor<64x64xbf16>
      %142 = tt.splat %103 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
      %143 = tt.addptr %142, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
      tt.store %143, %141, %124 : tensor<64x64x!tt.ptr<bf16>>
      scf.yield %140, %135 : tensor<64xf32>, tensor<64x64xf32>
    }
    %52 = tt.splat %8 : i32 -> tensor<64xi32>
    %53 = arith.addi %52, %16 : tensor<64xi32>
    %54 = tt.splat %arg13 : i32 -> tensor<64xi32>
    %55 = arith.cmpi slt, %53, %54 : tensor<64xi32>
    %56 = tt.expand_dims %53 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
    %57 = tt.expand_dims %53 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
    %58 = tt.broadcast %56 : tensor<64x1xi32> -> tensor<64x64xi32>
    %59 = tt.broadcast %57 : tensor<1x64xi32> -> tensor<64x64xi32>
    %60 = arith.cmpi sgt, %58, %59 : tensor<64x64xi32>
    %61 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xi1> -> tensor<64x1xi1>
    %62 = tt.expand_dims %55 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
    %63 = tt.broadcast %61 : tensor<64x1xi1> -> tensor<64x64xi1>
    %64 = tt.broadcast %62 : tensor<1x64xi1> -> tensor<64x64xi1>
    %65 = arith.andi %63, %64 : tensor<64x64xi1>
    %66 = arith.andi %60, %65 : tensor<64x64xi1>
    %67 = arith.select %66, %51#1, %cst_8 : tensor<64x64xi1>, tensor<64x64xf32>
    %68 = arith.truncf %67 : tensor<64x64xf32> to tensor<64x64xbf16>
    %69 = tt.dot %68, %44, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
    %70 = arith.truncf %69 : tensor<64x64xf32> to tensor<64x64xbf16>
    %71 = tt.dot %44, %70, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
    %72 = tt.expand_dims %48 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
    %73 = tt.expand_dims %48 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
    %74 = tt.broadcast %72 : tensor<64x1xf32> -> tensor<64x64xf32>
    %75 = tt.broadcast %73 : tensor<1x64xf32> -> tensor<64x64xf32>
    %76 = arith.subf %74, %75 : tensor<64x64xf32>
    %77 = math.exp %76 : tensor<64x64xf32>
    %78 = arith.mulf %71, %77 : tensor<64x64xf32>
    %79 = arith.subf %cst_8, %78 : tensor<64x64xf32>
    %80 = arith.select %66, %79, %cst_8 : tensor<64x64xi1>, tensor<64x64xf32>
    %81 = arith.truncf %80 : tensor<64x64xf32> to tensor<64x64xbf16>
    %82:2 = scf.for %arg14 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg15 = %51#0, %arg16 = %cst_8) -> (tensor<64xf32>, tensor<64x64xf32>)  : i32 {
      %99 = arith.muli %6, %c128_i32 : i32
      %100 = tt.addptr %arg0, %99 : !tt.ptr<bf16>, i32
      %101 = arith.muli %arg14, %c64_i32 : i32
      %102 = arith.extsi %101 : i32 to i64
      %103 = tt.addptr %arg7, %99 : !tt.ptr<bf16>, i32
      %104 = tt.splat %100 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
      %105 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
      %106 = arith.muli %105, %cst_0 : tensor<64x1xi64>
      %107 = tt.broadcast %106 : tensor<64x1xi64> -> tensor<64x64xi64>
      %108 = tt.splat %102 : i64 -> tensor<64xi64>
      %109 = arith.addi %108, %17 : tensor<64xi64>
      %110 = tt.expand_dims %109 {axis = 0 : i32} : tensor<64xi64> -> tensor<1x64xi64>
      %111 = tt.broadcast %110 : tensor<1x64xi64> -> tensor<64x64xi64>
      %112 = arith.addi %107, %111 : tensor<64x64xi64>
      %113 = tt.addptr %104, %112 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
      %114 = arith.cmpi sge, %105, %cst_3 : tensor<64x1xi64>
      %115 = tt.splat %9 : i64 -> tensor<64x1xi64>
      %116 = arith.cmpi slt, %105, %115 : tensor<64x1xi64>
      %117 = arith.andi %114, %116 : tensor<64x1xi1>
      %118 = tt.broadcast %117 : tensor<64x1xi1> -> tensor<64x64xi1>
      %119 = arith.cmpi sge, %110, %cst_1 : tensor<1x64xi64>
      %120 = arith.cmpi slt, %110, %cst : tensor<1x64xi64>
      %121 = arith.andi %119, %120 : tensor<1x64xi1>
      %122 = tt.broadcast %121 : tensor<1x64xi1> -> tensor<64x64xi1>
      %123 = arith.andi %118, %122 : tensor<64x64xi1>
      %124 = tt.load %113, %123 : tensor<64x64x!tt.ptr<bf16>>
      %125 = tt.splat %103 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
      %126 = tt.addptr %125, %112 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
      %127 = tt.load %126, %123 : tensor<64x64x!tt.ptr<bf16>>
      %128 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
      %129 = arith.extf %124 : tensor<64x64xbf16> to tensor<64x64xf32>
      %130 = tt.broadcast %128 : tensor<64x1xf32> -> tensor<64x64xf32>
      %131 = arith.mulf %129, %130 : tensor<64x64xf32>
      %132 = arith.truncf %131 : tensor<64x64xf32> to tensor<64x64xbf16>
      %133 = tt.trans %124 {order = array<i32: 1, 0>} : tensor<64x64xbf16> -> tensor<64x64xbf16>
      %134 = tt.dot %132, %133, %arg16, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
      %135 = tt.dot %81, %124, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
      %136 = arith.mulf %135, %129 : tensor<64x64xf32>
      %137 = "tt.reduce"(%136) <{axis = 1 : i32}> ({
      ^bb0(%arg17: f32, %arg18: f32):
        %151 = arith.addf %arg17, %arg18 : f32
        tt.reduce.return %151 : f32
      }) : (tensor<64x64xf32>) -> tensor<64xf32>
      %138 = arith.addf %arg15, %137 : tensor<64xf32>
      %139 = tt.trans %81 {order = array<i32: 1, 0>} : tensor<64x64xbf16> -> tensor<64x64xbf16>
      %140 = arith.extf %127 : tensor<64x64xbf16> to tensor<64x64xf32>
      %141 = tt.dot %139, %132, %140, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
      %142 = arith.mulf %135, %130 : tensor<64x64xf32>
      %143 = arith.addf %141, %142 : tensor<64x64xf32>
      %144 = tt.addptr %arg11, %99 : !tt.ptr<bf16>, i32
      %145 = tt.splat %144 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
      %146 = tt.addptr %145, %112 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
      %147 = tt.load %146, %123 : tensor<64x64x!tt.ptr<bf16>>
      %148 = arith.extf %147 : tensor<64x64xbf16> to tensor<64x64xf32>
      %149 = arith.addf %143, %148 : tensor<64x64xf32>
      %150 = arith.truncf %149 : tensor<64x64xf32> to tensor<64x64xbf16>
      tt.store %126, %150, %123 : tensor<64x64x!tt.ptr<bf16>>
      scf.yield %138, %134 : tensor<64xf32>, tensor<64x64xf32>
    }
    %83 = tt.splat %11 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %84 = tt.addptr %83, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
    tt.store %84, %82#0, %24 : tensor<64x!tt.ptr<f32>>
    %85 = arith.extf %81 : tensor<64x64xbf16> to tensor<64x64xf32>
    %86 = arith.mulf %85, %82#1 : tensor<64x64xf32>
    %87 = tt.addptr %arg10, %6 : !tt.ptr<f32>, i32
    %88 = "tt.reduce"(%86) <{axis = 1 : i32}> ({
    ^bb0(%arg14: f32, %arg15: f32):
      %99 = arith.addf %arg14, %arg15 : f32
      tt.reduce.return %99 : f32
    }) : (tensor<64x64xf32>) -> tensor<64xf32>
    %89 = "tt.reduce"(%86) <{axis = 0 : i32}> ({
    ^bb0(%arg14: f32, %arg15: f32):
      %99 = arith.addf %arg14, %arg15 : f32
      tt.reduce.return %99 : f32
    }) : (tensor<64x64xf32>) -> tensor<64xf32>
    %90 = arith.subf %88, %89 : tensor<64xf32>
    %91 = arith.addf %50#2, %90 : tensor<64xf32>
    %92 = tt.addptr %arg12, %6 : !tt.ptr<f32>, i32
    %93 = tt.splat %92 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %94 = tt.addptr %93, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
    %95 = tt.load %94, %24 : tensor<64x!tt.ptr<f32>>
    %96 = arith.addf %91, %95 : tensor<64xf32>
    %97 = tt.splat %87 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %98 = tt.addptr %97, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
    tt.store %98, %96, %24 : tensor<64x!tt.ptr<f32>>
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(convert-triton-to-tritongpu{enable-source-remat=false num-ctas=1 num-warps=4 target=cuda:100 threads-per-warp=32}, tritongpu-coalesce, tritongpu-F32DotTC, triton-nvidia-gpu-plan-cta, tritongpu-remove-layout-conversions, tritongpu-optimize-thread-locality, tritongpu-accelerate-matmul, tritongpu-remove-layout-conversions, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, triton-nvidia-optimize-descriptor-encoding, triton-loop-aware-cse, tritongpu-fuse-nested-loops, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-licm, tritongpu-optimize-accumulator-init, tritongpu-hoist-tmem-alloc{hoist-out-of-if=false}, tritongpu-promote-lhs-to-tmem, tritongpu-assign-latencies{num-stages=2}, tritongpu-schedule-loops, tritongpu-automatic-warp-specialization{num-stages=2}, tritongpu-pipeline{dump-intermediate-steps=false num-stages=2}, tritongpu-combine-tensor-select-and-if, tritongpu-hoist-tmem-alloc{hoist-out-of-if=true}, triton-nvidia-gpu-remove-tmem-tokens, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-loop-aware-cse, tritongpu-prefetch, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, tritongpu-coalesce-async-copy, triton-nvidia-optimize-tmem-layouts, tritongpu-remove-layout-conversions, triton-nvidia-interleave-tmem, tritongpu-reduce-data-duplication, tritongpu-reorder-instructions, triton-loop-aware-cse, symbol-dce, triton-nvidia-tma-lowering, triton-nvidia-gpu-fence-insertion{compute-capability=100}, triton-nvidia-mma-lowering, sccp, cse, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true})",
      disable_threading: false,
      verify_each: true
    }
  }
#-}
flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py:100:0: error: Failures have been detected while processing an MLIR pass pipeline
flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py:100:0: note: Pipeline failed while executing [`TritonGPUHoistTMEMAlloc` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
attention-fwd:
   seq_len  chunk_gated_delta_rule
0   8192.0                1.044992
Traceback (most recent call last):
  File "benchmark_gdn.py", line 134, in <module>
    benchmark.run(print_data=True,)
  File "/usr/local/lib/python3.12/dist-packages/triton/testing.py", line 392, in run
    result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/testing.py", line 339, in _run
    ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "benchmark_gdn.py", line 130, in benchmark
    res_time = triton.testing.do_bench(fn, warmup=5, rep=5)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/testing.py", line 149, in do_bench
    fn()
  File "benchmark_gdn.py", line 129, in fn
    def fn(): return o.backward(do, retain_graph=True)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 841, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 315, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "flash-linear-attention/fla/utils.py", line 175, in wrapper
    return fn(*contiguous_args, **contiguous_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 573, in decorate_bwd
    return bwd(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "flash-linear-attention/fla/ops/gated_delta_rule/chunk.py", line 224, in backward
    dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "flash-linear-attention/fla/ops/gated_delta_rule/chunk.py", line 147, in chunk_gated_delta_rule_bwd
    dk, dv, db, dg = prepare_wy_repr_bwd(
                     ^^^^^^^^^^^^^^^^^^^^
  File "flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py", line 293, in prepare_wy_repr_bwd
    prepare_wy_repr_bwd_kernel[(NT, B * H)](
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 419, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 452, in run
    return self.fn.run(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 236, in run
    used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 200, in check_disk_cache
    bench_fn()
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 227, in benchmark
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 162, in _bench
    return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/testing.py", line 149, in do_bench
    fn()
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 148, in kernel_call
    self.fn.run(
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 733, in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 861, in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 320, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 515, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 319, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

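Steps to Reproduce the Bug

Note: the log above came from a slightly different run (seq_len=8192, `do_bench(fn, warmup=5, rep=5)`), while the script below uses seq_len=1024 and default `do_bench` settings. The failure happens while compiling `prepare_wy_repr_bwd_kernel`, where the sequence length is a runtime argument (`%arg13: i32` in the IR). It should therefore reproduce at either length.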

import torch
import torch.nn.functional as F
import triton.testing
from fla.ops.gated_delta_rule import chunk_gated_delta_rule

# Problem size shared by every benchmark configuration below.
bs = 1
num_groups = 32   # heads used for q and k
num_heads = 32    # heads used for v, g and beta
head_k_dim = 128
head_v_dim = 128

# One throw-away backward pass on a 1-element CUDA tensor -- presumably to
# initialise CUDA and the autograd engine before anything is timed.
torch.empty(1, device='cuda', requires_grad=True).backward()


@triton.testing.perf_report([
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=[1024],
        x_log=True,
        line_arg="provider",
        line_vals=[
            "chunk_gated_delta_rule",
        ],
        line_names=[
            "chunk_gated_delta_rule",
        ],
        ylabel="latency (ms)",
        plot_name=f"attention-{mode}",
        args={"mode": mode},
    )
    # One report per pass: 'fwd' times the forward call, 'bwd' the backward.
    for mode in ['fwd', 'bwd']
])
def benchmark(seq_len, provider, mode):
    """Time ``chunk_gated_delta_rule`` at a single sequence length.

    Args:
        seq_len: Number of tokens per sequence (the benchmark's x-axis value).
        provider: Implementation to time; only ``"chunk_gated_delta_rule"``
            is supported.
        mode: ``"bwd"`` runs the forward pass once and then times
            ``o.backward(do)`` on the retained graph; any other value times
            the forward call.

    Returns:
        The timing returned by ``triton.testing.do_bench`` (milliseconds).

    Raises:
        ValueError: If ``provider`` is not recognised.
    """
    if provider != "chunk_gated_delta_rule":
        raise ValueError(f"Unknown provider: {provider}")

    # Allocate directly on the GPU in the target dtype instead of creating
    # each tensor on the CPU, casting it and copying it over.
    dev = 'cuda'
    bf16 = torch.bfloat16
    q = torch.randn(bs, seq_len, num_groups, head_k_dim,
                    device=dev, dtype=bf16, requires_grad=True)
    k = torch.randn(bs, seq_len, num_groups, head_k_dim,
                    device=dev, dtype=bf16, requires_grad=True)
    v = torch.randn(bs, seq_len, num_heads, head_v_dim,
                    device=dev, dtype=bf16, requires_grad=True)
    # g is applied through exp(...) inside the kernels, i.e. it acts as a
    # log-space decay that should be <= 0. Raw randn values are positive half
    # the time (decays > 1); logsigmoid maps them into (-inf, 0).
    # g and beta stay float32, as in the original repro (the failing kernel
    # takes them as f32 pointers).
    g = F.logsigmoid(
        torch.randn(bs, seq_len, num_heads, device=dev)
    ).requires_grad_()
    # beta scales the delta-rule update; squash it into (0, 1) instead of
    # using unbounded randn values (NOTE(review): confirm range against fla).
    beta = torch.randn(bs, seq_len, num_heads, device=dev).sigmoid().requires_grad_()

    kwargs = {
        "q": q,
        "k": k,
        "v": v,
        "g": g,
        "beta": beta,
        "initial_state": None,
        "output_final_state": False,
        "cu_seqlens": None,
        "head_first": False,
        "use_qk_l2norm_in_kernel": True,
    }

    def forward():
        return chunk_gated_delta_rule(**kwargs)[0]

    if mode == 'bwd':
        o = forward()
        do = torch.randn_like(o)

        def timed():
            # retain_graph=True lets do_bench call backward repeatedly on
            # the same forward graph.
            return o.backward(do, retain_graph=True)
    else:
        timed = forward

    return triton.testing.do_bench(timed)

# Run every configured report and print the timing tables (no plots).
benchmark.run(show_plots=False, print_data=True)

Expected Behavior

The backward pass should compile and run without crashing 🙂

Environment Information

Torch: 2.9.0 or 2.8.0
Triton: 3.5.0
GPU: Blackwell (compile target `cuda:100`, i.e. sm_100, per the reproducer pipeline)

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), triton-bug, wontfix (This will not be worked on)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions