@@ -513,7 +513,6 @@ class Phi3SlidingMask : public ov::pass::MatcherPass {
513513            auto  node_atten_mask_param = node_to_output.at (atten_mask_param).get_node_shared_ptr ();
514514            auto  node_atten_mask_len = node_to_output.at (atten_mask_len).get_node_shared_ptr ();
515515            auto  node_key_range_f32 = node_to_output.at (key_range_f32).get_node_shared_ptr ();
516-             auto  node_query_left_bound_range = node_to_output.at (query_left_bound_range).get_node_shared_ptr ();
517516            auto  node_neg_window_size = node_to_output.at (neg_window_size).get_node_shared_ptr ();
518517            auto  node_forget_left_tokens_mask = node_to_output.at (forget_left_tokens_mask).get_node_shared_ptr ();
519518            auto  node_bitwise_or = node_to_output.at (inv_sliding_attention_mask).get_node_shared_ptr ();
@@ -523,7 +522,6 @@ class Phi3SlidingMask : public ov::pass::MatcherPass {
523522            auto  matched_atten_mask_input = std::static_pointer_cast<ov::op::v0::Parameter>(node_atten_mask_param);
524523            auto  matched_atten_mask_len = std::static_pointer_cast<ov::op::v8::Gather>(node_atten_mask_len);
525524            auto  matched_key_range_f32 = std::static_pointer_cast<ov::op::v0::Convert>(node_key_range_f32);
526-             auto  matched_query_left_bound = std::static_pointer_cast<ov::op::v1::Add>(node_query_left_bound_range);
527525            auto  matched_neg_window_size = std::static_pointer_cast<ov::op::v0::Constant>(node_neg_window_size);
528526            auto  matched_forget_left_tokens_mask =
529527                std::static_pointer_cast<ov::op::v1::LessEqual>(node_forget_left_tokens_mask);
@@ -559,18 +557,19 @@ class Phi3SlidingMask : public ov::pass::MatcherPass {
559557            auto  new_inv_sliding_mask = std::make_shared<ov::op::v13::BitwiseOr>(matched_bitwise_or, bitwise_and);
560558
561559            //  4. Removing extra padding via : 3 | ~(attention_mask_input[past_kv_len:]).T
562-             std::vector<int64_t > shape_1{1 };
563-             auto  shape_1_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{1 }, shape_1);
564-             auto  matched_past_len_shape_1 =
565-                 std::make_shared<ov::op::v1::Reshape>(matched_past_kv_len, shape_1_const, false );
566-             auto  matched_atten_len_shape_1 =
567-                 std::make_shared<ov::op::v1::Reshape>(matched_atten_mask_len, shape_1_const, false );
568-             auto  const_1 = std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{1 }, 1 );
560+             std::vector<int64_t > shape_rank_one{1 };
561+             auto  shape_rank_one_const =
562+                 std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{1 }, shape_rank_one);
563+             auto  past_len_reshaped =
564+                 std::make_shared<ov::op::v1::Reshape>(matched_past_kv_len, shape_rank_one_const, false );
565+             auto  atten_len_reshaped =
566+                 std::make_shared<ov::op::v1::Reshape>(matched_atten_mask_len, shape_rank_one_const, false );
567+             auto  const_one = std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{1 }, 1 );
569568            auto  present_atten_mask = std::make_shared<ov::op::v8::Slice>(matched_atten_mask_input,
570-                                                                           matched_past_len_shape_1 ,
571-                                                                           matched_atten_len_shape_1 ,
572-                                                                           const_1 ,
573-                                                                           const_1 );
569+                                                                           past_len_reshaped ,
570+                                                                           atten_len_reshaped ,
571+                                                                           const_one ,
572+                                                                           const_one );
574573            auto  present_atten_mask_bool =
575574                std::make_shared<ov::op::v0::Convert>(present_atten_mask, ov::element::boolean);
576575            auto  inv_present_atten_mask = std::make_shared<ov::op::v13::BitwiseNot>(present_atten_mask_bool);
0 commit comments