 #include <memory>
 #include <queue>
 
+#include "openvino/core/type/element_type_traits.hpp"
 #include "openvino/reference/divide.hpp"
 #include "openvino/reference/matmul.hpp"
 #include "openvino/reference/softmax.hpp"
@@ -28,10 +29,11 @@ template <typename T>
 class XAttentionBlockSelector {
 public:
     /** @param threshold Defines a threshold for introduced block sparsity - XAttention attempts to preserve the
-     * smallest subset of attention score matrix blocks so that the ratio of the attention score sum to the total sum of
-     * attention score matrix elements is no less than `threshold`. In other words, `threshold` defines a fraction of
-     * the attention score mass which is to be preserved by most "important" blocks. Valid range is 0.0-1.0, with 0.0
-     * corresponding to 0% of the blocks retained, and 1.0 corresponding to 100% of the blocks retained.
+     * smallest subset of causal attention score matrix blocks in each query-block row so that the ratio of their
+     * score sum to the total causal attention score sum of that row is no less than `threshold`. In other words,
+     * `threshold` defines the fraction of the per-row causal attention score mass to be preserved by the most
+     * "important" blocks; the diagonal and first-in-row blocks are always retained and counted towards this fraction.
+     * Valid range is 0.0-1.0: 0.0 retains only the diagonal and first-in-row blocks, 1.0 retains all causal blocks.
      * @param block_size The size of blocks into which the attention score matrix [num_heads, query_token_dimension,
      * key_token_dimension] will be subdivided for purposes of determining the subset of the most important blocks
      * according to `threshold`. This subdivision occurs on query and key dimensions of the attention score matrix with
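
To make the updated `threshold` semantics concrete, here is a minimal standalone sketch of the per-row retention rule on toy values. It is not the class API; the real selector operates per head on blocked attention scores, and the diagonal is taken as the last causal block of each query-block row.

```cpp
#include <cstddef>
#include <iostream>
#include <queue>
#include <set>
#include <utility>
#include <vector>

int main() {
    // One query-block row of causal block scores: index 0 is the first-in-row block,
    // the last index plays the role of the diagonal block.
    std::vector<double> row = {0.05, 0.10, 0.40, 0.05, 0.30, 0.10};
    double threshold = 0.8;

    // First-in-row and diagonal blocks are always retained.
    std::set<std::size_t> retained = {0, row.size() - 1};
    double total = 0.0, cumsum = 0.0;
    std::priority_queue<std::pair<double, std::size_t>> others;
    for (std::size_t k = 0; k < row.size(); ++k) {
        total += row[k];
        if (retained.count(k)) {
            cumsum += row[k];
        } else {
            others.push({row[k], k});
        }
    }
    // Greedily add the highest-scoring remaining blocks until the retained mass
    // reaches `threshold` of the total causal mass of the row.
    while (cumsum < threshold * total && !others.empty()) {
        auto [score, k] = others.top();
        others.pop();
        cumsum += score;
        retained.insert(k);
    }
    for (std::size_t k : retained) {
        std::cout << k << ' ';  // prints "0 2 4 5": 0.05 + 0.40 + 0.30 + 0.10 = 0.85 >= 0.8
    }
    std::cout << '\n';
}
```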
@@ -76,17 +78,17 @@ class XAttentionBlockSelector {
         OPENVINO_ASSERT(input_shape[1] / m_stride == out_shape[1]);
         OPENVINO_ASSERT(input_shape[2] * m_stride == out_shape[2]);
 
-        size_t num_stride_steps = input_shape[1] / m_stride;
+        size_t num_elts_in_strided_slice = input_shape[1] / m_stride;
         for (size_t head_idx = 0; head_idx < input_shape[0]; head_idx++) {
             size_t head_offset = head_idx * input_shape[1] * input_shape[2];
-            for (size_t slice_idx = 0; slice_idx < m_stride; slice_idx++) {
-                for (size_t stride_idx = 0; stride_idx < num_stride_steps; stride_idx++) {
+            for (size_t stride_num = 0; stride_num < m_stride; stride_num++) {
+                for (size_t intra_slice_step = 0; intra_slice_step < num_elts_in_strided_slice; intra_slice_step++) {
                     size_t input_offset = head_offset;
-                    size_t output_offset = head_offset + stride_idx * out_shape[2] + slice_idx * input_shape[2];
+                    size_t output_offset = head_offset + intra_slice_step * out_shape[2] + stride_num * input_shape[2];
                     if (is_antidiagonal) {
-                        input_offset += (input_shape[1] - 1 - slice_idx - stride_idx * m_stride) * input_shape[2];
+                        input_offset += (m_stride - 1 - stride_num + intra_slice_step * m_stride) * input_shape[2];
                     } else {
-                        input_offset += (slice_idx + stride_idx * m_stride) * input_shape[2];
+                        input_offset += (stride_num + intra_slice_step * m_stride) * input_shape[2];
                     }
                     std::memcpy(output_data + output_offset, input_data + input_offset, input_shape[2] * sizeof(T));
                 }
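
As a sanity check on the new index arithmetic (not part of the change itself), the following standalone sketch prints, for assumed toy values `m_stride = 2` and `input_shape[1] = 6`, which input row lands in which output row/segment under the diagonal and antidiagonal variants:

```cpp
#include <cstddef>
#include <iostream>

int main() {
    const std::size_t stride = 2;                    // m_stride
    const std::size_t num_rows = 6;                  // input_shape[1]
    const std::size_t rows_out = num_rows / stride;  // num_elts_in_strided_slice
    for (int antidiagonal = 0; antidiagonal <= 1; ++antidiagonal) {
        std::cout << (antidiagonal ? "antidiagonal:" : "diagonal:") << '\n';
        for (std::size_t stride_num = 0; stride_num < stride; ++stride_num) {
            for (std::size_t intra_slice_step = 0; intra_slice_step < rows_out; ++intra_slice_step) {
                // Same row-index expressions as in the updated loop body above.
                std::size_t in_row = antidiagonal
                                         ? (stride - 1 - stride_num + intra_slice_step * stride)
                                         : (stride_num + intra_slice_step * stride);
                std::cout << "  output row " << intra_slice_step << ", segment " << stride_num
                          << " <- input row " << in_row << '\n';
            }
        }
    }
}
```

With the updated formula, the antidiagonal variant reverses the row order only within each group of `m_stride` consecutive input rows, so output row `i` is still assembled from input rows `i * m_stride .. i * m_stride + m_stride - 1`.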
@@ -142,6 +144,28 @@ class XAttentionBlockSelector {
         }
     }
 
+    /** Applies the softmax causal mask along the last two dimensions of the rank-3 input tensor in-place.
+     * @param in_out_data Pointer to the softmax input values (logits).
+     * @param in_out_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens /
+     * stride, num_key_tokens / stride].
+     */
+    void apply_causal_mask_(T* in_out_data, const Shape& in_out_shape) {
+        OPENVINO_ASSERT(in_out_shape.size() == 3);
+        OPENVINO_ASSERT(in_out_shape[1] <= in_out_shape[2]);
+        size_t query_dim = in_out_shape[1];
+        size_t key_dim = in_out_shape[2];
+        for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
+            size_t head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
+            for (size_t query_dim_idx = 0; query_dim_idx < in_out_shape[1]; query_dim_idx++) {
+                size_t query_dim_offset = query_dim_idx * in_out_shape[2];
+                for (size_t key_dim_idx = key_dim - query_dim + query_dim_idx + 1; key_dim_idx < key_dim;
+                     key_dim_idx++) {
+                    in_out_data[head_offset + query_dim_offset + key_dim_idx] = -INFINITY;
+                }
+            }
+        }
+    }
+
     /** Performs a softmax operation on the last dimension of the rank-3 input tensor.
      * @param reshaped_qk_product_data Pointer to the reshaped query-key product input (attention logits pre-softmax).
      * @param reshaped_qk_product_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens /
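
The new `apply_causal_mask_` helper above masks every key position strictly past the bottom-right-aligned diagonal. A small standalone sketch (assumed toy shape [1, 3, 5], printing 'X' for positions the helper would set to -inf) illustrates the pattern; masking before the softmax (see the call added further below) drives those logits to zero probability, so non-causal positions contribute nothing to the blocked scores.

```cpp
#include <cstddef>
#include <iostream>

int main() {
    // query_dim corresponds to num_query_tokens / stride, key_dim to num_key_tokens / stride.
    const std::size_t query_dim = 3, key_dim = 5;
    for (std::size_t q = 0; q < query_dim; ++q) {
        for (std::size_t k = 0; k < key_dim; ++k) {
            // Same condition as the inner loop bound in apply_causal_mask_: key positions
            // strictly past the bottom-right-aligned diagonal are masked.
            bool masked = (k > key_dim - query_dim + q);
            std::cout << (masked ? 'X' : '.');
        }
        std::cout << '\n';
    }
    // Prints:
    // ...XX
    // ....X
    // .....
}
```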
@@ -203,9 +227,12 @@ class XAttentionBlockSelector {
         }
     }
 
-    /** Selects the elements of the input tensor along the last two dimensions, independently along the first dimension,
-     * so that the elements constitute a smallest subset constituting a sum portion no less than `threshold` of the
-     * total element sum.
+    /** Selects the elements of the input tensor along the last dimension, independently along the first two
+     * dimensions, so that the selected elements constitute a smallest subset amounting to a sum portion no less than
+     * `threshold` of the total "causal" element sum of that row. "Causal" is understood in the sense of the last two
+     * dimensions being treated as the query-block and key-block dimensions of an attention score matrix.
+     * "Non-causal" elements are disregarded when calculating the sum and are never preserved, while the "diagonal"
+     * and first-in-row elements are always preserved and their scores count towards the retained portion.
      * @param blocked_scores_data Pointer to the blocked score input.
      * @param blocked_attention_scores_shape Shape of the blocked score input tensor. Expected shape is [num_heads,
      * num_query_tokens / block_size, num_key_tokens / block_size]
@@ -217,6 +244,8 @@ class XAttentionBlockSelector {
                                      const Shape& blocked_attention_scores_shape) {
         OPENVINO_ASSERT(blocked_attention_scores_shape.size() ==
                         3);  // [num_heads, num_blocks_in_query, num_blocks_in_key]
+        //
+        OPENVINO_ASSERT(blocked_attention_scores_shape[1] <= blocked_attention_scores_shape[2]);
 
         auto retval = XAttentionRetainedBlockIndicesForAllHeads(blocked_attention_scores_shape[0]);
 
@@ -230,23 +259,40 @@ class XAttentionBlockSelector {
 
         for (size_t head_idx = 0; head_idx < blocked_attention_scores_shape[0]; head_idx++) {
             size_t head_offset = head_idx * blocked_attention_scores_shape[1] * blocked_attention_scores_shape[2];
-            std::priority_queue<IndexAndScore> indices_and_scores_queue;
-            double total_sum = 0.0;
             for (size_t q_block_idx = 0; q_block_idx < blocked_attention_scores_shape[1]; q_block_idx++) {
+                std::priority_queue<IndexAndScore> indices_and_scores_queue;
+                double total_sum = 0.0;
+                double cumsum = 0.0;
                 for (size_t k_block_idx = 0; k_block_idx < blocked_attention_scores_shape[2]; k_block_idx++) {
+                    if (k_block_idx >
+                        (blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) {
+                        // Disregard non-causal blocks entirely
+                        continue;
+                    }
                     size_t target_offset = head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx;
                     T current_score = *(blocked_attention_scores_data + target_offset);
-                    indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score});
                     total_sum += current_score;
+
+                    if ((k_block_idx ==
+                         (blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) ||
+                        k_block_idx == 0) {
+                        // First-in-row and diagonal blocks are always preserved, and their scores are included in
+                        // the cumulative sum. The remaining blocks in the row then only need to fill up the rest of
+                        // the attention mass fraction so that, together with the diagonal and first-in-row blocks,
+                        // they comprise the `threshold` portion of the entire causal attention mass in this row.
+                        retval[head_idx].insert({q_block_idx, k_block_idx});
+                        cumsum += current_score;
+                    } else {
+                        indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score});
+                    }
+                }
+                double required_sum = m_threshold * total_sum;
+                while (cumsum < required_sum && !indices_and_scores_queue.empty()) {
+                    auto index_and_largest_score = indices_and_scores_queue.top();
+                    indices_and_scores_queue.pop();
+                    cumsum += index_and_largest_score.score;
+                    retval[head_idx].insert(index_and_largest_score.idx);
                 }
-            }
-            double cumsum = 0.0;
-            double required_sum = m_threshold * total_sum;
-            while (cumsum < required_sum && !indices_and_scores_queue.empty()) {
-                auto index_and_largest_score = indices_and_scores_queue.top();
-                indices_and_scores_queue.pop();
-                cumsum += index_and_largest_score.score;
-                retval[head_idx].insert(index_and_largest_score.idx);
             }
         }
         return retval;
@@ -303,6 +349,8 @@ class XAttentionBlockSelector {
         q_buf.reset();
         k_buf.reset();
 
+        apply_causal_mask_(qk_buf.get(), transpose_matmul_scaled_shape);
+
         Shape attention_scores_shape = transpose_matmul_scaled_shape;
         auto attn_score_buf = allocate_buf(attention_scores_shape);
         softmax(qk_buf.get(), transpose_matmul_scaled_shape, attn_score_buf.get(), attention_scores_shape);