Labels: bug, triton-bug, wontfix
Description
Checklist
- I have checked FAQs and existing issues for similar problems
- Please report this bug in English to ensure wider understanding and support
Describe the Bug
When executing the GDN backward pass on Blackwell with Triton 3.5.0, the following error occurs.
flash-linear-attention/fla/ops/gated_delta_rule/chunk.py:331: UserWarning: head_first is deprecated and will be removed in a future version. Please use head_first=False for now instead.
warnings.warn(
flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py:208:39: error: operand #2 does not dominate this use
b_dk += tl.dot(tl.trans(b_dA), b_kb)
^
flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py:208:16: note: operand defined here (op in the same block)
b_dk += tl.dot(tl.trans(b_dA), b_kb)
^
module {
tt.func public @prepare_wy_repr_bwd_kernel(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg12: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg13: i32) attributes {noinline = false} {
%cst = arith.constant dense<128> : tensor<1x64xi64>
%cst_0 = arith.constant dense<4096> : tensor<64x1xi64>
%cst_1 = arith.constant dense<0> : tensor<1x64xi64>
%cst_2 = arith.constant dense<64> : tensor<64x1xi64>
%cst_3 = arith.constant dense<0> : tensor<64x1xi64>
%cst_4 = arith.constant dense<2048> : tensor<1x64xi64>
%cst_5 = arith.constant dense<0> : tensor<64xi64>
%cst_6 = arith.constant dense<32> : tensor<64xi64>
%c2_i32 = arith.constant 2 : i32
%cst_7 = arith.constant dense<0.000000e+00> : tensor<64xf32>
%c1_i32 = arith.constant 1 : i32
%cst_8 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
%c128_i32 = arith.constant 128 : i32
%c0_i32 = arith.constant 0 : i32
%c64_i32 = arith.constant 64 : i32
%c32_i32 = arith.constant 32 : i32
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
%2 = arith.divsi %1, %c32_i32 : i32
%3 = arith.remsi %1, %c32_i32 : i32
%4 = arith.muli %2, %arg13 : i32
%5 = arith.muli %4, %c32_i32 : i32
%6 = arith.addi %5, %3 : i32
%7 = tt.addptr %arg2, %6 : !tt.ptr<f32>, i32
%8 = arith.muli %0, %c64_i32 : i32
%9 = arith.extsi %arg13 : i32 to i64
%10 = arith.extsi %8 : i32 to i64
%11 = tt.addptr %arg9, %6 : !tt.ptr<f32>, i32
%12 = arith.muli %6, %c64_i32 : i32
%13 = tt.addptr %arg4, %12 : !tt.ptr<bf16>, i32
%14 = tt.splat %7 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%15 = tt.splat %10 : i64 -> tensor<64xi64>
%16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%17 = arith.extsi %16 : tensor<64xi32> to tensor<64xi64>
%18 = arith.addi %15, %17 : tensor<64xi64>
%19 = arith.muli %18, %cst_6 : tensor<64xi64>
%20 = tt.addptr %14, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
%21 = arith.cmpi sge, %18, %cst_5 : tensor<64xi64>
%22 = tt.splat %9 : i64 -> tensor<64xi64>
%23 = arith.cmpi slt, %18, %22 : tensor<64xi64>
%24 = arith.andi %21, %23 : tensor<64xi1>
%25 = tt.load %20, %24 : tensor<64x!tt.ptr<f32>>
%26 = tt.splat %13 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%27 = tt.expand_dims %17 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
%28 = tt.broadcast %27 : tensor<64x1xi64> -> tensor<64x64xi64>
%29 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi64> -> tensor<1x64xi64>
%30 = arith.muli %29, %cst_4 : tensor<1x64xi64>
%31 = tt.broadcast %30 : tensor<1x64xi64> -> tensor<64x64xi64>
%32 = arith.addi %28, %31 : tensor<64x64xi64>
%33 = tt.addptr %26, %32 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
%34 = arith.cmpi sge, %27, %cst_3 : tensor<64x1xi64>
%35 = arith.cmpi slt, %27, %cst_2 : tensor<64x1xi64>
%36 = arith.andi %34, %35 : tensor<64x1xi1>
%37 = tt.broadcast %36 : tensor<64x1xi1> -> tensor<64x64xi1>
%38 = arith.cmpi sge, %29, %cst_1 : tensor<1x64xi64>
%39 = tt.splat %9 : i64 -> tensor<1x64xi64>
%40 = arith.cmpi slt, %29, %39 : tensor<1x64xi64>
%41 = arith.andi %38, %40 : tensor<1x64xi1>
%42 = tt.broadcast %41 : tensor<1x64xi1> -> tensor<64x64xi1>
%43 = arith.andi %37, %42 : tensor<64x64xi1>
%44 = tt.load %33, %43 : tensor<64x64x!tt.ptr<bf16>>
%45 = tt.addptr %arg3, %6 : !tt.ptr<f32>, i32
%46 = tt.splat %45 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%47 = tt.addptr %46, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
%48 = tt.load %47, %24 : tensor<64x!tt.ptr<f32>>
%49 = math.exp %48 : tensor<64xf32>
%50:3 = scf.for %arg14 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg15 = %cst_7, %arg16 = %cst_8, %arg17 = %cst_7) -> (tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>) : i32 {
%99 = arith.muli %6, %c128_i32 : i32
%100 = tt.addptr %arg0, %99 : !tt.ptr<bf16>, i32
%101 = arith.muli %arg14, %c64_i32 : i32
%102 = arith.extsi %101 : i32 to i64
%103 = tt.addptr %arg7, %99 : !tt.ptr<bf16>, i32
%104 = tt.addptr %arg5, %99 : !tt.ptr<bf16>, i32
%105 = tt.splat %100 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%106 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
%107 = arith.muli %106, %cst_0 : tensor<64x1xi64>
%108 = tt.broadcast %107 : tensor<64x1xi64> -> tensor<64x64xi64>
%109 = tt.splat %102 : i64 -> tensor<64xi64>
%110 = arith.addi %109, %17 : tensor<64xi64>
%111 = tt.expand_dims %110 {axis = 0 : i32} : tensor<64xi64> -> tensor<1x64xi64>
%112 = tt.broadcast %111 : tensor<1x64xi64> -> tensor<64x64xi64>
%113 = arith.addi %108, %112 : tensor<64x64xi64>
%114 = tt.addptr %105, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
%115 = arith.cmpi sge, %106, %cst_3 : tensor<64x1xi64>
%116 = tt.splat %9 : i64 -> tensor<64x1xi64>
%117 = arith.cmpi slt, %106, %116 : tensor<64x1xi64>
%118 = arith.andi %115, %117 : tensor<64x1xi1>
%119 = tt.broadcast %118 : tensor<64x1xi1> -> tensor<64x64xi1>
%120 = arith.cmpi sge, %111, %cst_1 : tensor<1x64xi64>
%121 = arith.cmpi slt, %111, %cst : tensor<1x64xi64>
%122 = arith.andi %120, %121 : tensor<1x64xi1>
%123 = tt.broadcast %122 : tensor<1x64xi1> -> tensor<64x64xi1>
%124 = arith.andi %119, %123 : tensor<64x64xi1>
%125 = tt.load %114, %124 : tensor<64x64x!tt.ptr<bf16>>
%126 = arith.mulf %25, %49 : tensor<64xf32>
%127 = tt.expand_dims %126 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
%128 = arith.extf %125 : tensor<64x64xbf16> to tensor<64x64xf32>
%129 = tt.broadcast %127 : tensor<64x1xf32> -> tensor<64x64xf32>
%130 = arith.mulf %128, %129 : tensor<64x64xf32>
%131 = tt.splat %104 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%132 = tt.addptr %131, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
%133 = tt.load %132, %124 : tensor<64x64x!tt.ptr<bf16>>
%134 = tt.trans %130 {order = array<i32: 1, 0>} : tensor<64x64xf32> -> tensor<64x64xf32>
%135 = arith.truncf %134 : tensor<64x64xf32> to tensor<64x64xbf16>
%136 = tt.dot %133, %135, %arg16, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
%137 = tt.dot %44, %133, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
%138 = arith.mulf %137, %129 : tensor<64x64xf32>
%139 = arith.mulf %137, %128 : tensor<64x64xf32>
%140 = tt.expand_dims %49 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
%141 = tt.broadcast %140 : tensor<64x1xf32> -> tensor<64x64xf32>
%142 = arith.mulf %139, %141 : tensor<64x64xf32>
%143 = "tt.reduce"(%142) <{axis = 1 : i32}> ({
^bb0(%arg18: f32, %arg19: f32):
%151 = arith.addf %arg18, %arg19 : f32
tt.reduce.return %151 : f32
}) : (tensor<64x64xf32>) -> tensor<64xf32>
%144 = arith.addf %arg15, %143 : tensor<64xf32>
%145 = arith.mulf %137, %130 : tensor<64x64xf32>
%146 = "tt.reduce"(%145) <{axis = 1 : i32}> ({
^bb0(%arg18: f32, %arg19: f32):
%151 = arith.addf %arg18, %arg19 : f32
tt.reduce.return %151 : f32
}) : (tensor<64x64xf32>) -> tensor<64xf32>
%147 = arith.addf %arg17, %146 : tensor<64xf32>
%148 = arith.truncf %138 : tensor<64x64xf32> to tensor<64x64xbf16>
%149 = tt.splat %103 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%150 = tt.addptr %149, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
tt.store %150, %148, %124 : tensor<64x64x!tt.ptr<bf16>>
scf.yield %144, %136, %147 : tensor<64xf32>, tensor<64x64xf32>, tensor<64xf32>
}
%51:2 = scf.for %arg14 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg15 = %50#0, %arg16 = %50#1) -> (tensor<64xf32>, tensor<64x64xf32>) : i32 {
%99 = arith.muli %6, %c128_i32 : i32
%100 = tt.addptr %arg1, %99 : !tt.ptr<bf16>, i32
%101 = arith.muli %arg14, %c64_i32 : i32
%102 = arith.extsi %101 : i32 to i64
%103 = tt.addptr %arg8, %99 : !tt.ptr<bf16>, i32
%104 = tt.addptr %arg6, %99 : !tt.ptr<bf16>, i32
%105 = tt.splat %100 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%106 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
%107 = arith.muli %106, %cst_0 : tensor<64x1xi64>
%108 = tt.broadcast %107 : tensor<64x1xi64> -> tensor<64x64xi64>
%109 = tt.splat %102 : i64 -> tensor<64xi64>
%110 = arith.addi %109, %17 : tensor<64xi64>
%111 = tt.expand_dims %110 {axis = 0 : i32} : tensor<64xi64> -> tensor<1x64xi64>
%112 = tt.broadcast %111 : tensor<1x64xi64> -> tensor<64x64xi64>
%113 = arith.addi %108, %112 : tensor<64x64xi64>
%114 = tt.addptr %105, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
%115 = arith.cmpi sge, %106, %cst_3 : tensor<64x1xi64>
%116 = tt.splat %9 : i64 -> tensor<64x1xi64>
%117 = arith.cmpi slt, %106, %116 : tensor<64x1xi64>
%118 = arith.andi %115, %117 : tensor<64x1xi1>
%119 = tt.broadcast %118 : tensor<64x1xi1> -> tensor<64x64xi1>
%120 = arith.cmpi sge, %111, %cst_1 : tensor<1x64xi64>
%121 = arith.cmpi slt, %111, %cst : tensor<1x64xi64>
%122 = arith.andi %120, %121 : tensor<1x64xi1>
%123 = tt.broadcast %122 : tensor<1x64xi1> -> tensor<64x64xi1>
%124 = arith.andi %119, %123 : tensor<64x64xi1>
%125 = tt.load %114, %124 : tensor<64x64x!tt.ptr<bf16>>
%126 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
%127 = arith.extf %125 : tensor<64x64xbf16> to tensor<64x64xf32>
%128 = tt.broadcast %126 : tensor<64x1xf32> -> tensor<64x64xf32>
%129 = arith.mulf %127, %128 : tensor<64x64xf32>
%130 = arith.truncf %129 : tensor<64x64xf32> to tensor<64x64xbf16>
%131 = tt.splat %104 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%132 = tt.addptr %131, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
%133 = tt.load %132, %124 : tensor<64x64x!tt.ptr<bf16>>
%134 = tt.trans %130 {order = array<i32: 1, 0>} : tensor<64x64xbf16> -> tensor<64x64xbf16>
%135 = tt.dot %133, %134, %arg16, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
%136 = tt.dot %44, %133, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
%137 = arith.mulf %136, %128 : tensor<64x64xf32>
%138 = arith.mulf %136, %127 : tensor<64x64xf32>
%139 = "tt.reduce"(%138) <{axis = 1 : i32}> ({
^bb0(%arg17: f32, %arg18: f32):
%144 = arith.addf %arg17, %arg18 : f32
tt.reduce.return %144 : f32
}) : (tensor<64x64xf32>) -> tensor<64xf32>
%140 = arith.addf %arg15, %139 : tensor<64xf32>
%141 = arith.truncf %137 : tensor<64x64xf32> to tensor<64x64xbf16>
%142 = tt.splat %103 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%143 = tt.addptr %142, %113 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
tt.store %143, %141, %124 : tensor<64x64x!tt.ptr<bf16>>
scf.yield %140, %135 : tensor<64xf32>, tensor<64x64xf32>
}
%52 = tt.splat %8 : i32 -> tensor<64xi32>
%53 = arith.addi %52, %16 : tensor<64xi32>
%54 = tt.splat %arg13 : i32 -> tensor<64xi32>
%55 = arith.cmpi slt, %53, %54 : tensor<64xi32>
%56 = tt.expand_dims %53 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
%57 = tt.expand_dims %53 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
%58 = tt.broadcast %56 : tensor<64x1xi32> -> tensor<64x64xi32>
%59 = tt.broadcast %57 : tensor<1x64xi32> -> tensor<64x64xi32>
%60 = arith.cmpi sgt, %58, %59 : tensor<64x64xi32>
%61 = tt.expand_dims %55 {axis = 1 : i32} : tensor<64xi1> -> tensor<64x1xi1>
%62 = tt.expand_dims %55 {axis = 0 : i32} : tensor<64xi1> -> tensor<1x64xi1>
%63 = tt.broadcast %61 : tensor<64x1xi1> -> tensor<64x64xi1>
%64 = tt.broadcast %62 : tensor<1x64xi1> -> tensor<64x64xi1>
%65 = arith.andi %63, %64 : tensor<64x64xi1>
%66 = arith.andi %60, %65 : tensor<64x64xi1>
%67 = arith.select %66, %51#1, %cst_8 : tensor<64x64xi1>, tensor<64x64xf32>
%68 = arith.truncf %67 : tensor<64x64xf32> to tensor<64x64xbf16>
%69 = tt.dot %68, %44, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
%70 = arith.truncf %69 : tensor<64x64xf32> to tensor<64x64xbf16>
%71 = tt.dot %44, %70, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
%72 = tt.expand_dims %48 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
%73 = tt.expand_dims %48 {axis = 0 : i32} : tensor<64xf32> -> tensor<1x64xf32>
%74 = tt.broadcast %72 : tensor<64x1xf32> -> tensor<64x64xf32>
%75 = tt.broadcast %73 : tensor<1x64xf32> -> tensor<64x64xf32>
%76 = arith.subf %74, %75 : tensor<64x64xf32>
%77 = math.exp %76 : tensor<64x64xf32>
%78 = arith.mulf %71, %77 : tensor<64x64xf32>
%79 = arith.subf %cst_8, %78 : tensor<64x64xf32>
%80 = arith.select %66, %79, %cst_8 : tensor<64x64xi1>, tensor<64x64xf32>
%81 = arith.truncf %80 : tensor<64x64xf32> to tensor<64x64xbf16>
%82:2 = scf.for %arg14 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg15 = %51#0, %arg16 = %cst_8) -> (tensor<64xf32>, tensor<64x64xf32>) : i32 {
%99 = arith.muli %6, %c128_i32 : i32
%100 = tt.addptr %arg0, %99 : !tt.ptr<bf16>, i32
%101 = arith.muli %arg14, %c64_i32 : i32
%102 = arith.extsi %101 : i32 to i64
%103 = tt.addptr %arg7, %99 : !tt.ptr<bf16>, i32
%104 = tt.splat %100 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%105 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
%106 = arith.muli %105, %cst_0 : tensor<64x1xi64>
%107 = tt.broadcast %106 : tensor<64x1xi64> -> tensor<64x64xi64>
%108 = tt.splat %102 : i64 -> tensor<64xi64>
%109 = arith.addi %108, %17 : tensor<64xi64>
%110 = tt.expand_dims %109 {axis = 0 : i32} : tensor<64xi64> -> tensor<1x64xi64>
%111 = tt.broadcast %110 : tensor<1x64xi64> -> tensor<64x64xi64>
%112 = arith.addi %107, %111 : tensor<64x64xi64>
%113 = tt.addptr %104, %112 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
%114 = arith.cmpi sge, %105, %cst_3 : tensor<64x1xi64>
%115 = tt.splat %9 : i64 -> tensor<64x1xi64>
%116 = arith.cmpi slt, %105, %115 : tensor<64x1xi64>
%117 = arith.andi %114, %116 : tensor<64x1xi1>
%118 = tt.broadcast %117 : tensor<64x1xi1> -> tensor<64x64xi1>
%119 = arith.cmpi sge, %110, %cst_1 : tensor<1x64xi64>
%120 = arith.cmpi slt, %110, %cst : tensor<1x64xi64>
%121 = arith.andi %119, %120 : tensor<1x64xi1>
%122 = tt.broadcast %121 : tensor<1x64xi1> -> tensor<64x64xi1>
%123 = arith.andi %118, %122 : tensor<64x64xi1>
%124 = tt.load %113, %123 : tensor<64x64x!tt.ptr<bf16>>
%125 = tt.splat %103 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%126 = tt.addptr %125, %112 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
%127 = tt.load %126, %123 : tensor<64x64x!tt.ptr<bf16>>
%128 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
%129 = arith.extf %124 : tensor<64x64xbf16> to tensor<64x64xf32>
%130 = tt.broadcast %128 : tensor<64x1xf32> -> tensor<64x64xf32>
%131 = arith.mulf %129, %130 : tensor<64x64xf32>
%132 = arith.truncf %131 : tensor<64x64xf32> to tensor<64x64xbf16>
%133 = tt.trans %124 {order = array<i32: 1, 0>} : tensor<64x64xbf16> -> tensor<64x64xbf16>
%134 = tt.dot %132, %133, %arg16, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
%135 = tt.dot %81, %124, %cst_8, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
%136 = arith.mulf %135, %129 : tensor<64x64xf32>
%137 = "tt.reduce"(%136) <{axis = 1 : i32}> ({
^bb0(%arg17: f32, %arg18: f32):
%151 = arith.addf %arg17, %arg18 : f32
tt.reduce.return %151 : f32
}) : (tensor<64x64xf32>) -> tensor<64xf32>
%138 = arith.addf %arg15, %137 : tensor<64xf32>
%139 = tt.trans %81 {order = array<i32: 1, 0>} : tensor<64x64xbf16> -> tensor<64x64xbf16>
%140 = arith.extf %127 : tensor<64x64xbf16> to tensor<64x64xf32>
%141 = tt.dot %139, %132, %140, inputPrecision = tf32 : tensor<64x64xbf16> * tensor<64x64xbf16> -> tensor<64x64xf32>
%142 = arith.mulf %135, %130 : tensor<64x64xf32>
%143 = arith.addf %141, %142 : tensor<64x64xf32>
%144 = tt.addptr %arg11, %99 : !tt.ptr<bf16>, i32
%145 = tt.splat %144 : !tt.ptr<bf16> -> tensor<64x64x!tt.ptr<bf16>>
%146 = tt.addptr %145, %112 : tensor<64x64x!tt.ptr<bf16>>, tensor<64x64xi64>
%147 = tt.load %146, %123 : tensor<64x64x!tt.ptr<bf16>>
%148 = arith.extf %147 : tensor<64x64xbf16> to tensor<64x64xf32>
%149 = arith.addf %143, %148 : tensor<64x64xf32>
%150 = arith.truncf %149 : tensor<64x64xf32> to tensor<64x64xbf16>
tt.store %126, %150, %123 : tensor<64x64x!tt.ptr<bf16>>
scf.yield %138, %134 : tensor<64xf32>, tensor<64x64xf32>
}
%83 = tt.splat %11 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%84 = tt.addptr %83, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
tt.store %84, %82#0, %24 : tensor<64x!tt.ptr<f32>>
%85 = arith.extf %81 : tensor<64x64xbf16> to tensor<64x64xf32>
%86 = arith.mulf %85, %82#1 : tensor<64x64xf32>
%87 = tt.addptr %arg10, %6 : !tt.ptr<f32>, i32
%88 = "tt.reduce"(%86) <{axis = 1 : i32}> ({
^bb0(%arg14: f32, %arg15: f32):
%99 = arith.addf %arg14, %arg15 : f32
tt.reduce.return %99 : f32
}) : (tensor<64x64xf32>) -> tensor<64xf32>
%89 = "tt.reduce"(%86) <{axis = 0 : i32}> ({
^bb0(%arg14: f32, %arg15: f32):
%99 = arith.addf %arg14, %arg15 : f32
tt.reduce.return %99 : f32
}) : (tensor<64x64xf32>) -> tensor<64xf32>
%90 = arith.subf %88, %89 : tensor<64xf32>
%91 = arith.addf %50#2, %90 : tensor<64xf32>
%92 = tt.addptr %arg12, %6 : !tt.ptr<f32>, i32
%93 = tt.splat %92 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%94 = tt.addptr %93, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
%95 = tt.load %94, %24 : tensor<64x!tt.ptr<f32>>
%96 = arith.addf %91, %95 : tensor<64xf32>
%97 = tt.splat %87 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
%98 = tt.addptr %97, %19 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
tt.store %98, %96, %24 : tensor<64x!tt.ptr<f32>>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(convert-triton-to-tritongpu{enable-source-remat=false num-ctas=1 num-warps=4 target=cuda:100 threads-per-warp=32}, tritongpu-coalesce, tritongpu-F32DotTC, triton-nvidia-gpu-plan-cta, tritongpu-remove-layout-conversions, tritongpu-optimize-thread-locality, tritongpu-accelerate-matmul, tritongpu-remove-layout-conversions, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, triton-nvidia-optimize-descriptor-encoding, triton-loop-aware-cse, tritongpu-fuse-nested-loops, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-licm, tritongpu-optimize-accumulator-init, tritongpu-hoist-tmem-alloc{hoist-out-of-if=false}, tritongpu-promote-lhs-to-tmem, tritongpu-assign-latencies{num-stages=2}, tritongpu-schedule-loops, tritongpu-automatic-warp-specialization{num-stages=2}, tritongpu-pipeline{dump-intermediate-steps=false num-stages=2}, tritongpu-combine-tensor-select-and-if, tritongpu-hoist-tmem-alloc{hoist-out-of-if=true}, triton-nvidia-gpu-remove-tmem-tokens, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-loop-aware-cse, tritongpu-prefetch, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, tritongpu-coalesce-async-copy, triton-nvidia-optimize-tmem-layouts, tritongpu-remove-layout-conversions, triton-nvidia-interleave-tmem, tritongpu-reduce-data-duplication, tritongpu-reorder-instructions, triton-loop-aware-cse, symbol-dce, triton-nvidia-tma-lowering, triton-nvidia-gpu-fence-insertion{compute-capability=100}, triton-nvidia-mma-lowering, sccp, cse, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true})",
disable_threading: false,
verify_each: true
}
}
#-}
flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py:100:0: error: Failures have been detected while processing an MLIR pass pipeline
flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py:100:0: note: Pipeline failed while executing [`TritonGPUHoistTMEMAlloc` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
attention-fwd:
seq_len chunk_gated_delta_rule
0 8192.0 1.044992
Traceback (most recent call last):
File "benchmark_gdn.py", line 134, in <module>
benchmark.run(print_data=True,)
File "/usr/local/lib/python3.12/dist-packages/triton/testing.py", line 392, in run
result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/testing.py", line 339, in _run
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "benchmark_gdn.py", line 130, in benchmark
res_time = triton.testing.do_bench(fn, warmup=5, rep=5)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/testing.py", line 149, in do_bench
fn()
File "benchmark_gdn.py", line 129, in fn
def fn(): return o.backward(do, retain_graph=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 625, in backward
torch.autograd.backward(
File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 354, in backward
_engine_run_backward(
File "/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py", line 841, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 315, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "flash-linear-attention/fla/utils.py", line 175, in wrapper
return fn(*contiguous_args, **contiguous_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/amp/autocast_mode.py", line 573, in decorate_bwd
return bwd(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "flash-linear-attention/fla/ops/gated_delta_rule/chunk.py", line 224, in backward
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "flash-linear-attention/fla/ops/gated_delta_rule/chunk.py", line 147, in chunk_gated_delta_rule_bwd
dk, dv, db, dg = prepare_wy_repr_bwd(
^^^^^^^^^^^^^^^^^^^^
File "flash-linear-attention/fla/ops/gated_delta_rule/wy_fast.py", line 293, in prepare_wy_repr_bwd
prepare_wy_repr_bwd_kernel[(NT, B * H)](
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 419, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 452, in run
return self.fn.run(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 236, in run
used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 200, in check_disk_cache
bench_fn()
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 227, in benchmark
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 162, in _bench
return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/testing.py", line 149, in do_bench
fn()
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/autotuner.py", line 148, in kernel_call
self.fn.run(
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 733, in run
kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/runtime/jit.py", line 861, in _do_compile
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py", line 320, in compile
next_module = compile_ir(module, metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 515, in <lambda>
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/compiler.py", line 319, in make_ttgir
pm.run(mod)
RuntimeError: PassManager::run failed
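For reference, the diagnostic points at a loop-carried accumulator that is updated through tl.dot with a transposed operand. Below is a minimal, self-contained sketch of that pattern (hypothetical shapes and names, not the FLA kernel itself); it is not guaranteed to trigger the same miscompilation, but it shows the kind of code the "operand does not dominate this use" error refers to.

import torch
import triton
import triton.language as tl


@triton.jit
def loop_dot_accum_kernel(a_ptr, b_ptr, out_ptr, T, BT: tl.constexpr, K: tl.constexpr):
    # Loop-carried fp32 accumulator updated via tl.dot(tl.trans(...), ...),
    # mirroring the shape of `b_dk += tl.dot(tl.trans(b_dA), b_kb)` in wy_fast.py.
    acc = tl.zeros([K, K], dtype=tl.float32)
    o_k = tl.arange(0, K)
    for i in range(0, tl.cdiv(T, BT)):
        o_t = i * BT + tl.arange(0, BT)
        mask = o_t[:, None] < T
        a = tl.load(a_ptr + o_t[:, None] * K + o_k[None, :], mask=mask, other=0.0)
        b = tl.load(b_ptr + o_t[:, None] * K + o_k[None, :], mask=mask, other=0.0)
        acc += tl.dot(tl.trans(a), b)
    tl.store(out_ptr + o_k[:, None] * K + o_k[None, :], acc)


# usage sketch: single program, bf16 inputs, fp32 accumulator
T, BT, K = 256, 64, 64
a = torch.randn(T, K, device='cuda', dtype=torch.bfloat16)
b = torch.randn(T, K, device='cuda', dtype=torch.bfloat16)
out = torch.empty(K, K, device='cuda', dtype=torch.float32)
loop_dot_accum_kernel[(1,)](a, b, out, T, BT=BT, K=K)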
Steps to Reproduce the Bug
import torch
import torch.nn.functional as F
import triton.testing
from fla.ops.gated_delta_rule import chunk_gated_delta_rule

bs = 1
num_groups, num_heads, head_k_dim, head_v_dim = 32, 32, 128, 128
torch.empty(1, device='cuda', requires_grad=True).backward()


@triton.testing.perf_report([
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=[1024],
        x_log=True,
        line_arg="provider",
        line_vals=[
            "chunk_gated_delta_rule",
        ],
        line_names=[
            "chunk_gated_delta_rule",
        ],
        ylabel="latency (ms)",
        plot_name=f"attention-{mode}",
        args={"mode": mode},
    )
    for mode in ['fwd', 'bwd']
])
def benchmark(seq_len, provider, mode):
    if provider == "chunk_gated_delta_rule":
        kwargs = {
            "q": torch.randn(bs, seq_len, num_groups, head_k_dim).bfloat16().cuda().requires_grad_(),
            "k": torch.randn(bs, seq_len, num_groups, head_k_dim).bfloat16().cuda().requires_grad_(),
            "v": torch.randn(bs, seq_len, num_heads, head_v_dim).bfloat16().cuda().requires_grad_(),
            "g": torch.randn(bs, seq_len, num_heads).cuda().requires_grad_(),
            "beta": torch.randn(bs, seq_len, num_heads).cuda().requires_grad_(),
            "initial_state": None,
            "output_final_state": False,
            "cu_seqlens": None,
            "head_first": False,
            "use_qk_l2norm_in_kernel": True,
        }

        def fn(): return chunk_gated_delta_rule(**kwargs)[0]
    else:
        raise ValueError(f"Unknown provider: {provider}")

    if mode == 'bwd':
        o = fn()
        do = torch.randn_like(o)

        def fn(): return o.backward(do, retain_graph=True)

    res_time = triton.testing.do_bench(fn)
    return res_time


benchmark.run(print_data=True,)
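For quicker iteration, the same call can be exercised without the perf_report harness. This is a sketch derived from the script above (same shapes, dtypes, and kwargs):

import torch
from fla.ops.gated_delta_rule import chunk_gated_delta_rule

bs, seq_len = 1, 1024
num_groups, num_heads, head_k_dim, head_v_dim = 32, 32, 128, 128

q = torch.randn(bs, seq_len, num_groups, head_k_dim).bfloat16().cuda().requires_grad_()
k = torch.randn(bs, seq_len, num_groups, head_k_dim).bfloat16().cuda().requires_grad_()
v = torch.randn(bs, seq_len, num_heads, head_v_dim).bfloat16().cuda().requires_grad_()
g = torch.randn(bs, seq_len, num_heads).cuda().requires_grad_()
beta = torch.randn(bs, seq_len, num_heads).cuda().requires_grad_()

o = chunk_gated_delta_rule(
    q=q, k=k, v=v, g=g, beta=beta,
    initial_state=None, output_final_state=False, cu_seqlens=None,
    head_first=False, use_qk_l2norm_in_kernel=True,
)[0]
# The forward pass completes; the crash happens while compiling
# prepare_wy_repr_bwd_kernel during backward.
o.backward(torch.randn_like(o))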
Expected Behavior
Not to crash 🙂
Environment Information
Torch: 2.9.0 or 2.8.0
Triton: 3.5.0
GPU: Blackwell
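The exact versions can be printed with a small snippet like this (standard torch/triton attributes, nothing specific to this issue):

import torch
import triton

print("torch  :", torch.__version__)
print("triton :", triton.__version__)
print("cuda   :", torch.version.cuda)
print("device :", torch.cuda.get_device_name(0), torch.cuda.get_device_capability(0))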