@@ -396,6 +396,168 @@ class GemmaSlidingMask : public ov::pass::MatcherPass {
396396 }
397397};
398398
399+ class Phi3SlidingMask : public ov ::pass::MatcherPass {
400+ public:
401+ OPENVINO_MATCHER_PASS_RTTI (" npuw::LLMCompiledModel::Phi3SlidingMask" );
402+
403+ Phi3SlidingMask () {
404+ // Search for the Phi3 sliding mask pattern to extend it to work with right-padded
405+ // past tokens and left-padded present tokens.
406+ //
407+ // Mask creation is simply done via "less_equal" and "greater" operations between
408+ // row K range: [0,... mask_len] and column Q range: [current_pos_id,... mask_len].T
409+ // and sliding window length.
410+ // Due to broadcasting rules these two operation form two triangular masks.
411+ //
412+ // - "less_equal" forms a sliding window mask, more precisely, it has following expression:
413+ //
414+ // row range [0,... mask_len] <= column range [current_pos_id - sliding_window_size,
415+ // ...,
416+ // mask_len - sliding_window_size]
417+ //
418+ // forming, under example conditions, the mask below:
419+ // past tokens = 3
420+ // present tokens = 5 (starting with current_pos_id = 3)
421+ // sliding window len = 4
422+ // K0 K1 K2 K3 K4 K5 K6 K7
423+ // [ 0 1 2 3 4 5 6 7 ]
424+ // Q3[ 3 - 4 ] 0 0 0 0 0 0 0 0
425+ // Q4[ 4 - 4 ] 1 0 0 0 0 0 0 0
426+ // Q5[ 5 - 4 ] 1 1 0 0 0 0 0 0
427+ // Q6[ 6 - 4 ] 1 1 1 0 0 0 0 0
428+ // Q7[ 7 - 4 ] 1 1 1 1 0 0 0 0
429+ // where 1 at [i, j] means that j token should be forgotten as it can't fit into the sliding
430+ // window from the left of i-th token.
431+ //
432+ // - "greater" forms a similar to self-attention mask:
433+ //
434+ // row range [0,... mask_len] > column range [current_pos_id,
435+ // ...,
436+ // mask_len]
437+ //
438+ // forming, under example conditions, the mask below:
439+ // past tokens = 3
440+ // present tokens = 5 (starting with current_pos_id = 3)
441+ // K0 K1 K2 K3 K4 K5 K6 K7
442+ // [ 0 1 2 3 4 5 6 7 ]
443+ // Q3[ 3 ] 0 0 0 0 1 1 1 1
444+ // Q4[ 4 ] 0 0 0 0 0 1 1 1
445+ // Q5[ 5 ] 0 0 0 0 0 0 1 1
446+ // Q6[ 6 ] 0 0 0 0 0 0 0 1
447+ // Q7[ 7 ] 0 0 0 0 0 0 0 0
448+ // where 1 at [i, j] means that j token is a future token for i-th token, that we shouldn't attend to.
449+ //
450+ // Together, via "bitwise_or" this two masks forms the inverted sliding attention mask:
451+ // past tokens = 3
452+ // present tokens = 5 (starting with current_pos_id = 3)
453+ // sliding window len = 4
454+ // K0 K1 K2 K3 K4 K5 K6 K7
455+ // [ 0 1 2 3 4 5 6 7 ]
456+ // Q3[ 3 - 4 ] 0 0 0 0 1 1 1 1
457+ // Q4[ 4 - 4 ] 1 0 0 0 0 1 1 1
458+ // Q5[ 5 - 4 ] 1 1 0 0 0 0 1 1
459+ // Q6[ 6 - 4 ] 1 1 1 0 0 0 0 1
460+ // Q7[ 7 - 4 ] 1 1 1 1 0 0 0 0
461+ //
462+ // Issue with sliding attention mask appears when we work with static shapes and different
463+ // paddings for past and present tokens.
464+ // More precisely, issue appears with sliding window mask, as Q column range is created
465+ // from length of past key/values tensor (2175 for 2K case) as start point and the length
466+ // of attention mask (2176 for 2K) as an end point. This is okay for inverted
467+ // self-attention mask by means of "greater" operation, as our present tokens exactly
468+ // left-padded and located on the right in the attention mask.
469+ // However, for the sliding window mask created by means of "less_equal" operation, given
470+ // Q range will behave as if position ids of new Q tokens will start from 2175 and not from
471+ // 3 as in example above and therefore, 2175 - 2047 = 128 first tokens should be forgotten.
472+ // To fix it a new formula is suggested:
473+ // 1. (K range <= (Q_pos range - sliding window).T) | (K range > Q range.T)
474+ // 2. (K range <= (Q range - sliding window).T) & (K range >= len(past_key_values))
475+ // 3. Resulting mask = 1 | 2,
476+ // where K range and Q range are created by the same rules as before and Q_pos range is
477+ // a position_ids array.
478+ auto past_kv_len = opp::wrap_type<ov::op::v8::Gather>({opp::any_input (), opp::any_input (), opp::any_input ()});
479+ auto pos_ids_param = opp::wrap_type<ov::op::v0::Parameter>();
480+ auto pos_ids_shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({pos_ids_param});
481+ auto pos_ids_len = opp::wrap_type<ov::op::v8::Gather>({pos_ids_shape_of, opp::any_input (), opp::any_input ()});
482+ auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, pos_ids_len});
483+ auto query_range = opp::wrap_type<ov::op::v4::Range>({past_kv_len, full_ctx_len, opp::any_input ()});
484+ auto column_shape = opp::wrap_type<ov::op::v0::Constant>();
485+ auto query_range_column = opp::wrap_type<ov::op::v1::Reshape>({query_range, column_shape});
486+
487+ auto zero_const = opp::wrap_type<ov::op::v0::Constant>();
488+ auto atten_mask_len =
489+ opp::wrap_type<ov::op::v8::Gather>({opp::any_input (), opp::any_input (), opp::any_input ()});
490+ auto key_range = opp::wrap_type<ov::op::v4::Range>({zero_const, atten_mask_len, opp::any_input ()});
491+ auto key_range_i64 = opp::wrap_type<ov::op::v0::Convert>({key_range});
492+ auto key_range_f32 = opp::wrap_type<ov::op::v0::Convert>({key_range_i64});
493+
494+ auto neg_window_size = opp::wrap_type<ov::op::v0::Constant>();
495+ auto query_left_bound_range = opp::wrap_type<ov::op::v1::Add>({query_range_column, neg_window_size});
496+ // False in mask means that we shouldn't forget this token
497+ auto forget_left_tokens_mask = opp::wrap_type<ov::op::v1::LessEqual>({key_range_f32, query_left_bound_range});
498+ // Basically it is a reference triangle self-attention mask that
499+ // forbids tokens to attend to future ones, but values are inverted:
500+ auto look_only_future_mask = opp::wrap_type<ov::op::v1::Greater>({key_range_f32, query_range_column});
501+
502+ auto inv_sliding_attention_mask =
503+ opp::wrap_type<ov::op::v13::BitwiseOr>({look_only_future_mask, forget_left_tokens_mask});
504+
505+ auto callback = [=](ov::pass::pattern::Matcher& m) {
506+ auto & node_to_output = m.get_pattern_value_map ();
507+ auto node_past_kv_len = node_to_output.at (past_kv_len).get_node_shared_ptr ();
508+ auto node_pos_ids_param = node_to_output.at (pos_ids_param).get_node_shared_ptr ();
509+ auto node_key_range_f32 = node_to_output.at (key_range_f32).get_node_shared_ptr ();
510+ auto node_query_left_bound_range = node_to_output.at (query_left_bound_range).get_node_shared_ptr ();
511+ auto node_neg_window_size = node_to_output.at (neg_window_size).get_node_shared_ptr ();
512+ auto node_forget_left_tokens_mask = node_to_output.at (forget_left_tokens_mask).get_node_shared_ptr ();
513+ auto node_bitwise_or = node_to_output.at (inv_sliding_attention_mask).get_node_shared_ptr ();
514+
515+ auto matched_past_kv_len = std::static_pointer_cast<ov::op::v8::Gather>(node_past_kv_len);
516+ auto matched_pos_ids = std::static_pointer_cast<ov::op::v0::Parameter>(node_pos_ids_param);
517+ auto matched_key_range_f32 = std::static_pointer_cast<ov::op::v0::Convert>(node_key_range_f32);
518+ auto matched_query_left_bound = std::static_pointer_cast<ov::op::v1::Add>(node_query_left_bound_range);
519+ auto matched_neg_window_size = std::static_pointer_cast<ov::op::v0::Constant>(node_neg_window_size);
520+ auto matched_forget_left_tokens_mask =
521+ std::static_pointer_cast<ov::op::v1::LessEqual>(node_forget_left_tokens_mask);
522+ auto matched_bitwise_or = std::static_pointer_cast<ov::op::v13::BitwiseOr>(node_bitwise_or);
523+ OPENVINO_ASSERT (matched_neg_window_size->get_output_size () == 1 ,
524+ " Sliding window size constant must be of size 1, but got " +
525+ std::to_string (matched_neg_window_size->get_output_size ()));
526+
527+ // 1.(K range <= (Q_pos range - sliding window).T) | (K range > Q range.T)
528+ auto query_range_as_pos_ids = std::make_shared<ov::op::v0::Convert>(matched_pos_ids, ov::element::f32 );
529+ std::vector<int64_t > vector_shape{-1 , 1 };
530+ auto vector_shape_const =
531+ std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{2 }, vector_shape);
532+ auto query_range_as_pos_ids_col =
533+ std::make_shared<ov::op::v1::Reshape>(query_range_as_pos_ids, vector_shape_const, false );
534+ auto query_range_as_pos_left_bound =
535+ std::make_shared<ov::op::v1::Add>(query_range_as_pos_ids_col, matched_neg_window_size);
536+ auto forget_left_mask_for_right_padding =
537+ std::make_shared<ov::op::v1::LessEqual>(matched_key_range_f32, query_range_as_pos_left_bound);
538+ matched_bitwise_or->input (1 ).replace_source_output (forget_left_mask_for_right_padding);
539+
540+ // 2. (K range <= (Q range - sliding window).T) & (K range >= shape(past_key_values, 2))
541+ auto past_kv_len_f32 = std::make_shared<ov::op::v0::Convert>(matched_past_kv_len, ov::element::f32 );
542+ auto only_present_tokens_mask = std::make_shared<ov::op::v1::GreaterEqual>(matched_key_range_f32,
543+ past_kv_len_f32);
544+ // unsqueeze?
545+ auto bitwise_and = std::make_shared<ov::op::v13::BitwiseAnd>(matched_forget_left_tokens_mask,
546+ only_present_tokens_mask);
547+
548+ // 3. Result = 1 | 2
549+ auto target_inputs = matched_bitwise_or->output (0 ).get_target_inputs ();
550+ auto final_inv_sliding_mask = std::make_shared<ov::op::v13::BitwiseOr>(matched_bitwise_or, bitwise_and);
551+ for (auto && input : target_inputs) {
552+ input.replace_source_output (final_inv_sliding_mask);
553+ }
554+ return true ;
555+ };
556+ register_matcher (std::make_shared<opp::Matcher>(inv_sliding_attention_mask, " Phi3SlidingMask" ),
557+ std::move (callback));
558+ }
559+ };
560+
399561namespace {
400562uint32_t align_to (uint32_t value, uint32_t alignment) {
401563 return (value + alignment - 1 ) & ~(alignment - 1 );
@@ -455,6 +617,13 @@ void decompose_GQA(std::shared_ptr<ov::Model> model, bool is_prefill_model) {
455617 rewr.add_matcher <GroupQueryAttentionDecomposition>(is_prefill_model);
456618 rewr.run_on_model (model);
457619}
620+
621+ void patch_phi3_sliding_mask (const std::shared_ptr<ov::Model>& model) {
622+ ov::pass::GraphRewrite rewr;
623+ rewr.add_matcher <Phi3SlidingMask>();
624+ rewr.run_on_model (model);
625+ model->validate_nodes_and_infer_types ();
626+ }
458627} // namespace
459628
460629namespace {
@@ -1022,6 +1191,10 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
10221191 LOG_INFO (" Two-model pipeline will be created." );
10231192 }
10241193
1194+ LOG_DEBUG (" Try patch Phi-3 sliding window mask, if it exists." );
1195+ patch_phi3_sliding_mask (kvcache_model);
1196+ ov::save_model (kvcache_model, " swa_patched.xml" );
1197+
10251198 LOG_DEBUG (" Creating prefill model as clone of transformed kvcache one." );
10261199 auto prefill_model = kvcache_model->clone ();
10271200 prefill_model->set_friendly_name (kvcache_model->get_friendly_name () + " _prefill" );
@@ -1117,6 +1290,8 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
11171290 axes,
11181291 m_max_lora_rank,
11191292 whisper_lhs_seq_size);
1293+
1294+ LOG_DEBUG (" Try parametrize Gemma sliding window mask, if it exists." );
11201295 gemma_transformations (kvcache_model);
11211296
11221297 if (lm_head_model) {
0 commit comments