
Commit 56d62ff

Fix XAttention reference code for better alignment with the original (#32451)
The original PR adding XAttention reference code was not completely aligned with the original code. Fixed were:

* global block scoring was replaced with the K-row local scoring, as in the original code
* antidiagonal stride direction was reversed to conform to the original code and algorithm
* causal masking was introduced to effectively never select non-causal blocks and properly score the rest
* block selection now correctly always keeps the "diagonal" blocks and the entire column of oldest blocks at k_block_dim == 0
* block selection now correctly calculates the required attention mass, which should amount to a `threshold` portion of the total K-row block sum and always include diagonal/first-in-row blocks

Added more E2E tests that compare against the original code results, which are fixed in the test code as reference. Note that the original non-Triton code contained bugs, which had to be fixed in order to obtain correct references (see vshampor/x-attention@fdc5c34).
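
As an illustration of the block-index convention behind these fixes (numbers invented): with 8 key blocks and 4 query blocks, the "diagonal" block of query row q sits at key block 8 - 4 + q, so for q = 1 it is key block 5; key blocks 6 and 7 of that row are non-causal and are never selected, while key block 0 (the column of oldest blocks) and key block 5 are always retained.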
1 parent 8cd9781 commit 56d62ff

File tree

2 files changed: +494 -125 lines changed


src/core/reference/include/openvino/reference/xattention.hpp

Lines changed: 72 additions & 24 deletions
@@ -9,6 +9,7 @@
 #include <memory>
 #include <queue>
 
+#include "openvino/core/type/element_type_traits.hpp"
 #include "openvino/reference/divide.hpp"
 #include "openvino/reference/matmul.hpp"
 #include "openvino/reference/softmax.hpp"
@@ -28,10 +29,11 @@ template <typename T>
 class XAttentionBlockSelector {
 public:
     /** @param threshold Defines a threshold for introduced block sparsity - XAttention attempts to preserve the
-     * smallest subset of attention score matrix blocks so that the ratio of the attention score sum to the total sum of
-     * attention score matrix elements is no less than `threshold`. In other words, `threshold` defines a fraction of
-     * the attention score mass which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0
-     * corresponding to 0% of the blocks retained, and 1.0 corresponding to 100% of the blocks retained.
+     * smallest subset of causal non-diagonal attention score matrix blocks so that the ratio of their attention score
+     * sum to the total sum of causal non-diagonal attention score matrix blocks in the same K-row is no less than
+     * `threshold`. In other words, `threshold` defines a fraction of the block non-diagonal causal attention score mass
+     * which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0 corresponding to 0% of the
+     * non-diagonal causal blocks retained, and 1.0 corresponding to 100% of the non-diagonal causal blocks retained.
      * @param block_size The size of blocks into which the attention score matrix [num_heads, query_token_dimension,
      * key_token_dimension] will be subdivided for purposes of determining the subset of the most important blocks
      * according to `threshold`. This subdivision occurs on query and key dimensions of the attention score matrix with
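
As a made-up numerical illustration of the `threshold` semantics described in the updated comment: if the causal blocks of one K-row have scores summing to 10.0 and `threshold` is 0.9, the first-in-row and diagonal blocks of that row are kept unconditionally, and further blocks are added in decreasing score order until the retained blocks together account for at least 0.9 * 10.0 = 9.0 of the row's causal score mass.
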
@@ -76,17 +78,17 @@ class XAttentionBlockSelector {
         OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]);
         OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]);
 
-        size_t num_stride_steps = input_shape[1] / m_stride;
+        size_t num_elts_in_strided_slice = input_shape[1] / m_stride;
         for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) {
             size_t head_offset = head_idx * input_shape[1] * input_shape[2];
-            for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) {
-                for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) {
+            for (size_t stride_num = 0; stride_num < m_stride; stride_num++) {
+                for (size_t intra_slice_step = 0; intra_slice_step < num_elts_in_strided_slice; intra_slice_step++) {
                     size_t input_offset = head_offset;
-                    size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2];
+                    size_t output_offset = head_offset + intra_slice_step * out_shape[2] + stride_num * input_shape[2];
                     if (is_antidiagonal) {
-                        input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2];
+                        input_offset += (m_stride - 1 - stride_num + intra_slice_step * m_stride) * input_shape[2];
                     } else {
-                        input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2];
+                        input_offset += (stride_num + intra_slice_step * m_stride) * input_shape[2];
                     }
                     std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T));
                 }
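
For illustration, the following is a minimal standalone sketch (not part of the header; the single-head toy shape {1, 8, 1} and stride 4 are assumed purely for the example) of the antidiagonal index mapping introduced above:

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // Toy dimensions: one head, 8 query rows of width 1, stride 4 -> output is 2 rows of width 4.
    const size_t stride = 4;
    const size_t q_len = 8;                       // input_shape[1]
    const size_t row_width = 1;                   // input_shape[2]
    const size_t out_width = row_width * stride;  // out_shape[2]
    std::vector<int> input = {0, 1, 2, 3, 4, 5, 6, 7};  // one value per query row
    std::vector<int> output(q_len * row_width, -1);

    const size_t num_elts_in_strided_slice = q_len / stride;  // out_shape[1]
    for (size_t stride_num = 0; stride_num < stride; stride_num++) {
        for (size_t intra_slice_step = 0; intra_slice_step < num_elts_in_strided_slice; intra_slice_step++) {
            // Antidiagonal direction as in the fixed code: rows are taken in reverse order
            // within each consecutive group of `stride` query rows.
            size_t input_row = stride - 1 - stride_num + intra_slice_step * stride;
            size_t output_offset = intra_slice_step * out_width + stride_num * row_width;
            output[output_offset] = input[input_row * row_width];
        }
    }
    for (size_t r = 0; r < num_elts_in_strided_slice; r++) {
        for (size_t c = 0; c < out_width; c++) {
            std::printf("%d ", output[r * out_width + c]);
        }
        std::printf("\n");
    }
    return 0;
}

With these toy values the sketch prints "3 2 1 0" and then "7 6 5 4": each group of four adjacent query rows is emitted in reverse order, which is the stride-direction reversal mentioned in the commit message.
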
@@ -142,6 +144,28 @@ class XAttentionBlockSelector {
         }
     }
 
+    /** Applies the softmax causal mask along the last two dimensions of the rank-3 input tensor in-place.
+     * @param in_out_data Pointer to the softmax input values (logits).
+     * @param in_out_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens /
+     * stride, num_key_tokens / stride].
+     */
+    void apply_causal_mask_(T* in_out_data, const Shape& in_out_shape) {
+        OPENVINO_ASSERT(in_out_shape.size() == 3);
+        OPENVINO_ASSERT(in_out_shape[1] <= in_out_shape[2]);
+        size_t query_dim = in_out_shape[1];
+        size_t key_dim = in_out_shape[2];
+        for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
+            size_t head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
+            for (size_t query_dim_idx = 0; query_dim_idx < in_out_shape[1]; query_dim_idx++) {
+                size_t query_dim_offset = query_dim_idx * in_out_shape[2];
+                for (size_t key_dim_idx = key_dim - query_dim + query_dim_idx + 1; key_dim_idx < key_dim;
+                     key_dim_idx++) {
+                    in_out_data[head_offset + query_dim_offset + key_dim_idx] = -INFINITY;
+                }
+            }
+        }
+    }
+
     /** Performs a softmax operation on the last dimension of the rank-3 input tensor.
      * @param reshaped_qk_product_data Pointer to the reshaped query-key product input (attention logits pre-softmax).
      * @param reshaped_qk_product_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens /
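
To illustrate the masking pattern introduced above on a toy shape (made-up sizes) of [1, 3, 5]: query_dim = 3 and key_dim = 5, so row 0 masks key positions 3 and 4, row 1 masks key position 4, and row 2 masks nothing; in other words, each query row keeps exactly the keys up to and including its aligned position key_dim - query_dim + query_dim_idx.
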
@@ -203,9 +227,12 @@
         }
     }
 
-    /** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension,
-     * so that the elements constitute a smallest subset constituting a sum portion no less than `threshold` of the
-     * total element sum.
+    /** Selects the elements of the input tensor along the last dimension, independently along the first two dimensions,
+     * so that the selected elements constitute a smallest subset amounting to a sum portion no less than `threshold`
+     * of the total "causal" element sum. "Causal" is understood in the sense of the last two dimensions being
+     * treated as the query-block and key-block dimensions in the context of attention matrix scores. The
+     * first-in-row, the "diagonal" and "non-causal" elements are disregarded when calculating the sum. "Non-causal"
+     * elements are never preserved, while "diagonal" and first-in-row elements are always preserved.
      * @param blocked_scores_data Pointer to the blocked score input.
      * @param blocked_attention_scores_shape Shape of the blocked score input tensor. Expected shape is [num_heads,
      * num_query_tokens / block_size, num_key_tokens / block_size]
@@ -217,6 +244,8 @@
                                                        const Shape& blocked_attention_scores_shape) {
         OPENVINO_ASSERT(blocked_attention_scores_shape.size() ==
                         3);  // [num_heads, num_blocks_in_query, num_blocks_in_key]
+        //
+        OPENVINO_ASSERT(blocked_attention_scores_shape[1] <= blocked_attention_scores_shape[2]);
 
         auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]);
 
@@ -230,23 +259,40 @@
 
         for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) {
             size_t head_offset = head_idx * blocked_attention_scores_shape[1] * blocked_attention_scores_shape[2];
-            std::priority_queue<IndexAndScore> indices_and_scores_queue;
-            double total_sum = 0.0;
             for (size_t q_block_idx = 0; q_block_idx < blocked_attention_scores_shape[1]; q_block_idx++) {
+                std::priority_queue<IndexAndScore> indices_and_scores_queue;
+                double total_sum = 0.0;
+                double cumsum = 0.0;
                 for (size_t k_block_idx = 0; k_block_idx < blocked_attention_scores_shape[2]; k_block_idx++) {
+                    if (k_block_idx >
+                        (blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) {
+                        // Disregard non-causal blocks entirely
+                        continue;
+                    }
                     size_t target_offset = head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx;
                     T current_score = *(blocked_attention_scores_data + target_offset);
-                    indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score});
                     total_sum += current_score;
+
+                    if ((k_block_idx ==
+                         (blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) ||
+                        k_block_idx == 0) {
+                        // We preserve first-in-row and diagonal blocks always, and include their score in the
+                        // cumulative sum. The target for the rest of the blocks in row is to fill up the
+                        // rest of the attention mass fraction so that with the diagonal and first blocks they
+                        // comprise the `threshold` portion of the entire causal attention mass in this row
+                        retval[head_idx].insert({q_block_idx, k_block_idx});
+                        cumsum += current_score;
+                    } else {
+                        indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score});
+                    }
+                }
+                double required_sum = m_threshold * total_sum;
+                while (cumsum < required_sum && !indices_and_scores_queue.empty()) {
+                    auto index_and_largest_score = indices_and_scores_queue.top();
+                    indices_and_scores_queue.pop();
+                    cumsum += index_and_largest_score.score;
+                    retval[head_idx].insert(index_and_largest_score.idx);
                 }
-            }
-            double cumsum = 0.0;
-            double required_sum = m_threshold * total_sum;
-            while (cumsum < required_sum && !indices_and_scores_queue.empty()) {
-                auto index_and_largest_score = indices_and_scores_queue.top();
-                indices_and_scores_queue.pop();
-                cumsum += index_and_largest_score.score;
-                retval[head_idx].insert(index_and_largest_score.idx);
             }
         }
         return retval;
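
The per-K-row selection rule above can be illustrated with a minimal standalone sketch; the 0.8 threshold, the 2x4 block score matrix, and the local names below are invented for the example and are not part of the header:

#include <cstddef>
#include <cstdio>
#include <queue>
#include <set>
#include <utility>

int main() {
    const double threshold = 0.8;  // illustrative value
    const size_t num_q_blocks = 2, num_k_blocks = 4;
    // Toy blocked scores, scores[q][k]; (0, 3) is non-causal despite its large value.
    const double scores[2][4] = {{0.5, 0.1, 0.3, 9.0},
                                 {0.2, 0.6, 0.1, 0.4}};
    std::set<std::pair<size_t, size_t>> retained;

    for (size_t q = 0; q < num_q_blocks; q++) {
        const size_t diag_k = num_k_blocks - num_q_blocks + q;       // last causal ("diagonal") block in this row
        std::priority_queue<std::pair<double, size_t>> candidates;   // largest score on top
        double total_sum = 0.0;
        double cumsum = 0.0;
        for (size_t k = 0; k <= diag_k; k++) {  // blocks with k > diag_k are non-causal and skipped
            total_sum += scores[q][k];
            if (k == 0 || k == diag_k) {
                // First-in-row and diagonal blocks are always retained and pre-counted.
                retained.insert({q, k});
                cumsum += scores[q][k];
            } else {
                candidates.push({scores[q][k], k});
            }
        }
        const double required_sum = threshold * total_sum;
        while (cumsum < required_sum && !candidates.empty()) {
            cumsum += candidates.top().first;
            retained.insert({q, candidates.top().second});
            candidates.pop();
        }
    }
    for (const auto& qk : retained) {
        std::printf("(q=%zu, k=%zu)\n", qk.first, qk.second);
    }
    return 0;
}

With these numbers the retained set is (0,0), (0,2), (1,0), (1,1), (1,3): the large non-causal score at (0,3) is never considered, and row 1 pulls in block (1,1) because its always-kept blocks alone cover only 0.6 of the required 0.8 * 1.3 = 1.04 score mass.
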
@@ -303,6 +349,8 @@
         q_buf.reset();
         k_buf.reset();
 
+        apply_causal_mask_(qk_buf.get(), transpose_matmul_scaled_shape);
+
         Shape attention_scores_shape = transpose_matmul_scaled_shape;
         auto attn_score_buf = allocate_buf(attention_scores_shape);
         softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape);
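
Because the non-causal logits are set to -INFINITY before the softmax call shown here, those positions become exact zeros after the softmax, so the non-causal region contributes nothing to the attention scores that are subsequently pooled into blocks and passed to the selection step above.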
