Skip to content

Commit 4fed7fb

Browse files
committed
[NPUW] Patch Phi-3 Sliding Window Attention Mask to work with different paddings in past and present tokens
1 parent 56d62ff commit 4fed7fb

File tree

1 file changed

+175
-0
lines changed

1 file changed

+175
-0
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,168 @@ class GemmaSlidingMask : public ov::pass::MatcherPass {
396396
}
397397
};
398398

399+
class Phi3SlidingMask : public ov::pass::MatcherPass {
400+
public:
401+
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::Phi3SlidingMask");
402+
403+
Phi3SlidingMask() {
404+
// Search for the Phi3 sliding mask pattern to extend it to work with right-padded
405+
// past tokens and left-padded present tokens.
406+
//
407+
// Mask creation is simply done via "less_equal" and "greater" operations between
408+
// row K range: [0,... mask_len] and column Q range: [current_pos_id,... mask_len].T
409+
// and sliding window length.
410+
// Due to broadcasting rules these two operation form two triangular masks.
411+
//
412+
// - "less_equal" forms a sliding window mask, more precisely, it has following expression:
413+
//
414+
// row range [0,... mask_len] <= column range [current_pos_id - sliding_window_size,
415+
// ...,
416+
// mask_len - sliding_window_size]
417+
//
418+
// forming, under example conditions, the mask below:
419+
// past tokens = 3
420+
// present tokens = 5 (starting with current_pos_id = 3)
421+
// sliding window len = 4
422+
// K0 K1 K2 K3 K4 K5 K6 K7
423+
// [ 0 1 2 3 4 5 6 7 ]
424+
// Q3[ 3 - 4 ] 0 0 0 0 0 0 0 0
425+
// Q4[ 4 - 4 ] 1 0 0 0 0 0 0 0
426+
// Q5[ 5 - 4 ] 1 1 0 0 0 0 0 0
427+
// Q6[ 6 - 4 ] 1 1 1 0 0 0 0 0
428+
// Q7[ 7 - 4 ] 1 1 1 1 0 0 0 0
429+
// where 1 at [i, j] means that j token should be forgotten as it can't fit into the sliding
430+
// window from the left of i-th token.
431+
//
432+
// - "greater" forms a similar to self-attention mask:
433+
//
434+
// row range [0,... mask_len] > column range [current_pos_id,
435+
// ...,
436+
// mask_len]
437+
//
438+
// forming, under example conditions, the mask below:
439+
// past tokens = 3
440+
// present tokens = 5 (starting with current_pos_id = 3)
441+
// K0 K1 K2 K3 K4 K5 K6 K7
442+
// [ 0 1 2 3 4 5 6 7 ]
443+
// Q3[ 3 ] 0 0 0 0 1 1 1 1
444+
// Q4[ 4 ] 0 0 0 0 0 1 1 1
445+
// Q5[ 5 ] 0 0 0 0 0 0 1 1
446+
// Q6[ 6 ] 0 0 0 0 0 0 0 1
447+
// Q7[ 7 ] 0 0 0 0 0 0 0 0
448+
// where 1 at [i, j] means that j token is a future token for i-th token, that we shouldn't attend to.
449+
//
450+
// Together, via "bitwise_or" this two masks forms the inverted sliding attention mask:
451+
// past tokens = 3
452+
// present tokens = 5 (starting with current_pos_id = 3)
453+
// sliding window len = 4
454+
// K0 K1 K2 K3 K4 K5 K6 K7
455+
// [ 0 1 2 3 4 5 6 7 ]
456+
// Q3[ 3 - 4 ] 0 0 0 0 1 1 1 1
457+
// Q4[ 4 - 4 ] 1 0 0 0 0 1 1 1
458+
// Q5[ 5 - 4 ] 1 1 0 0 0 0 1 1
459+
// Q6[ 6 - 4 ] 1 1 1 0 0 0 0 1
460+
// Q7[ 7 - 4 ] 1 1 1 1 0 0 0 0
461+
//
462+
// Issue with sliding attention mask appears when we work with static shapes and different
463+
// paddings for past and present tokens.
464+
// More precisely, issue appears with sliding window mask, as Q column range is created
465+
// from length of past key/values tensor (2175 for 2K case) as start point and the length
466+
// of attention mask (2176 for 2K) as an end point. This is okay for inverted
467+
// self-attention mask by means of "greater" operation, as our present tokens exactly
468+
// left-padded and located on the right in the attention mask.
469+
// However, for the sliding window mask created by means of "less_equal" operation, given
470+
// Q range will behave as if position ids of new Q tokens will start from 2175 and not from
471+
// 3 as in example above and therefore, 2175 - 2047 = 128 first tokens should be forgotten.
472+
// To fix it a new formula is suggested:
473+
// 1. (K range <= (Q_pos range - sliding window).T) | (K range > Q range.T)
474+
// 2. (K range <= (Q range - sliding window).T) & (K range >= len(past_key_values))
475+
// 3. Resulting mask = 1 | 2,
476+
// where K range and Q range are created by the same rules as before and Q_pos range is
477+
// a position_ids array.
478+
auto past_kv_len = opp::wrap_type<ov::op::v8::Gather>({opp::any_input(), opp::any_input(), opp::any_input()});
479+
auto pos_ids_param = opp::wrap_type<ov::op::v0::Parameter>();
480+
auto pos_ids_shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({pos_ids_param});
481+
auto pos_ids_len = opp::wrap_type<ov::op::v8::Gather>({pos_ids_shape_of, opp::any_input(), opp::any_input()});
482+
auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, pos_ids_len});
483+
auto query_range = opp::wrap_type<ov::op::v4::Range>({past_kv_len, full_ctx_len, opp::any_input()});
484+
auto column_shape = opp::wrap_type<ov::op::v0::Constant>();
485+
auto query_range_column = opp::wrap_type<ov::op::v1::Reshape>({query_range, column_shape});
486+
487+
auto zero_const = opp::wrap_type<ov::op::v0::Constant>();
488+
auto atten_mask_len =
489+
opp::wrap_type<ov::op::v8::Gather>({opp::any_input(), opp::any_input(), opp::any_input()});
490+
auto key_range = opp::wrap_type<ov::op::v4::Range>({zero_const, atten_mask_len, opp::any_input()});
491+
auto key_range_i64 = opp::wrap_type<ov::op::v0::Convert>({key_range});
492+
auto key_range_f32 = opp::wrap_type<ov::op::v0::Convert>({key_range_i64});
493+
494+
auto neg_window_size = opp::wrap_type<ov::op::v0::Constant>();
495+
auto query_left_bound_range = opp::wrap_type<ov::op::v1::Add>({query_range_column, neg_window_size});
496+
// False in mask means that we shouldn't forget this token
497+
auto forget_left_tokens_mask = opp::wrap_type<ov::op::v1::LessEqual>({key_range_f32, query_left_bound_range});
498+
// Basically it is a reference triangle self-attention mask that
499+
// forbids tokens to attend to future ones, but values are inverted:
500+
auto look_only_future_mask = opp::wrap_type<ov::op::v1::Greater>({key_range_f32, query_range_column});
501+
502+
auto inv_sliding_attention_mask =
503+
opp::wrap_type<ov::op::v13::BitwiseOr>({look_only_future_mask, forget_left_tokens_mask});
504+
505+
auto callback = [=](ov::pass::pattern::Matcher& m) {
506+
auto& node_to_output = m.get_pattern_value_map();
507+
auto node_past_kv_len = node_to_output.at(past_kv_len).get_node_shared_ptr();
508+
auto node_pos_ids_param = node_to_output.at(pos_ids_param).get_node_shared_ptr();
509+
auto node_key_range_f32 = node_to_output.at(key_range_f32).get_node_shared_ptr();
510+
auto node_query_left_bound_range = node_to_output.at(query_left_bound_range).get_node_shared_ptr();
511+
auto node_neg_window_size = node_to_output.at(neg_window_size).get_node_shared_ptr();
512+
auto node_forget_left_tokens_mask = node_to_output.at(forget_left_tokens_mask).get_node_shared_ptr();
513+
auto node_bitwise_or = node_to_output.at(inv_sliding_attention_mask).get_node_shared_ptr();
514+
515+
auto matched_past_kv_len = std::static_pointer_cast<ov::op::v8::Gather>(node_past_kv_len);
516+
auto matched_pos_ids = std::static_pointer_cast<ov::op::v0::Parameter>(node_pos_ids_param);
517+
auto matched_key_range_f32 = std::static_pointer_cast<ov::op::v0::Convert>(node_key_range_f32);
518+
auto matched_query_left_bound = std::static_pointer_cast<ov::op::v1::Add>(node_query_left_bound_range);
519+
auto matched_neg_window_size = std::static_pointer_cast<ov::op::v0::Constant>(node_neg_window_size);
520+
auto matched_forget_left_tokens_mask =
521+
std::static_pointer_cast<ov::op::v1::LessEqual>(node_forget_left_tokens_mask);
522+
auto matched_bitwise_or = std::static_pointer_cast<ov::op::v13::BitwiseOr>(node_bitwise_or);
523+
OPENVINO_ASSERT(matched_neg_window_size->get_output_size() == 1,
524+
"Sliding window size constant must be of size 1, but got " +
525+
std::to_string(matched_neg_window_size->get_output_size()));
526+
527+
// 1.(K range <= (Q_pos range - sliding window).T) | (K range > Q range.T)
528+
auto query_range_as_pos_ids = std::make_shared<ov::op::v0::Convert>(matched_pos_ids, ov::element::f32);
529+
std::vector<int64_t> vector_shape{-1, 1};
530+
auto vector_shape_const =
531+
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{2}, vector_shape);
532+
auto query_range_as_pos_ids_col =
533+
std::make_shared<ov::op::v1::Reshape>(query_range_as_pos_ids, vector_shape_const, false);
534+
auto query_range_as_pos_left_bound =
535+
std::make_shared<ov::op::v1::Add>(query_range_as_pos_ids_col, matched_neg_window_size);
536+
auto forget_left_mask_for_right_padding =
537+
std::make_shared<ov::op::v1::LessEqual>(matched_key_range_f32, query_range_as_pos_left_bound);
538+
matched_bitwise_or->input(1).replace_source_output(forget_left_mask_for_right_padding);
539+
540+
// 2. (K range <= (Q range - sliding window).T) & (K range >= shape(past_key_values, 2))
541+
auto past_kv_len_f32 = std::make_shared<ov::op::v0::Convert>(matched_past_kv_len, ov::element::f32);
542+
auto only_present_tokens_mask = std::make_shared<ov::op::v1::GreaterEqual>(matched_key_range_f32,
543+
past_kv_len_f32);
544+
// unsqueeze?
545+
auto bitwise_and = std::make_shared<ov::op::v13::BitwiseAnd>(matched_forget_left_tokens_mask,
546+
only_present_tokens_mask);
547+
548+
// 3. Result = 1 | 2
549+
auto target_inputs = matched_bitwise_or->output(0).get_target_inputs();
550+
auto final_inv_sliding_mask = std::make_shared<ov::op::v13::BitwiseOr>(matched_bitwise_or, bitwise_and);
551+
for (auto&& input : target_inputs) {
552+
input.replace_source_output(final_inv_sliding_mask);
553+
}
554+
return true;
555+
};
556+
register_matcher(std::make_shared<opp::Matcher>(inv_sliding_attention_mask, "Phi3SlidingMask"),
557+
std::move(callback));
558+
}
559+
};
560+
399561
namespace {
400562
uint32_t align_to(uint32_t value, uint32_t alignment) {
401563
return (value + alignment - 1) & ~(alignment - 1);
@@ -455,6 +617,13 @@ void decompose_GQA(std::shared_ptr<ov::Model> model, bool is_prefill_model) {
455617
rewr.add_matcher<GroupQueryAttentionDecomposition>(is_prefill_model);
456618
rewr.run_on_model(model);
457619
}
620+
621+
void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) {
622+
ov::pass::GraphRewrite rewr;
623+
rewr.add_matcher<Phi3SlidingMask>();
624+
rewr.run_on_model(model);
625+
model->validate_nodes_and_infer_types();
626+
}
458627
} // namespace
459628

460629
namespace {
@@ -1022,6 +1191,10 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
10221191
LOG_INFO("Two-model pipeline will be created.");
10231192
}
10241193

1194+
LOG_DEBUG("Try patch Phi-3 sliding window mask, if it exists.");
1195+
patch_phi3_sliding_mask(kvcache_model);
1196+
ov::save_model(kvcache_model, "swa_patched.xml");
1197+
10251198
LOG_DEBUG("Creating prefill model as clone of transformed kvcache one.");
10261199
auto prefill_model = kvcache_model->clone();
10271200
prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");
@@ -1117,6 +1290,8 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
11171290
axes,
11181291
m_max_lora_rank,
11191292
whisper_lhs_seq_size);
1293+
1294+
LOG_DEBUG("Try parametrize Gemma sliding window mask, if it exists.");
11201295
gemma_transformations(kvcache_model);
11211296

11221297
if (lm_head_model) {

0 commit comments

Comments
 (0)