[NPUW] Phi-3 2K accuracy fix #32426
Changes from all commits: 7628e2f, 923c145, 92ae4ca, 9d310cd, b142dc0, 58dfac2
|  | @@ -396,6 +396,198 @@ class GemmaSlidingMask : public ov::pass::MatcherPass { | |
| } | ||
| }; | ||
|  | ||
| class Phi3SlidingMask : public ov::pass::MatcherPass { | ||
| public: | ||
| OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::Phi3SlidingMask"); | ||
|  | ||
| Phi3SlidingMask() { | ||
| // Search for the Phi3 sliding mask pattern to extend it to work with right-padded | ||
| // past tokens and left-padded present tokens. | ||
| // | ||
| // Mask creation is simply done via "less_equal" and "greater" operations between the | ||
| // row K range [0,... mask_len], the column Q range [current_pos_id,... mask_len].T, | ||
| // and the sliding window length. | ||
| // Due to broadcasting rules, these two operations form two triangular masks. | ||
| // | ||
| // - "less_equal" forms a sliding window mask, more precisely, it has following expression: | ||
| // | ||
| // row range [0,... mask_len] <= column range [current_pos_id - sliding_window_size, | ||
| // ..., | ||
| // mask_len - sliding_window_size] | ||
| // | ||
| // forming, under example conditions, the mask below: | ||
| // past tokens = 3 | ||
| // present tokens = 5 (starting with current_pos_id = 3) | ||
| 
Comment on lines +419 to +420:
"Out of curiosity, how can we get 5 present tokens with 3 past tokens?"
Reply: "It can be done in a speculative decoding case, for example."
| // sliding window len = 4 | ||
| // K0 K1 K2 K3 K4 K5 K6 K7 | ||
| // [ 0 1 2 3 4 5 6 7 ] | ||
| // Q3[ 3 - 4 ] 0 0 0 0 0 0 0 0 | ||
| // Q4[ 4 - 4 ] 1 0 0 0 0 0 0 0 | ||
| // Q5[ 5 - 4 ] 1 1 0 0 0 0 0 0 | ||
| // Q6[ 6 - 4 ] 1 1 1 0 0 0 0 0 | ||
| // Q7[ 7 - 4 ] 1 1 1 1 0 0 0 0 | ||
| // where 1 at [i, j] means that the j-th token should be forgotten, as it can't fit into the | ||
| // sliding window to the left of the i-th token. | ||
| // | ||
| // - "greater" forms a similar to self-attention mask: | ||
| // | ||
| // row range [0,... mask_len] > column range [current_pos_id, | ||
| // ..., | ||
| // mask_len] | ||
| // | ||
| // forming, under example conditions, the mask below: | ||
| // past tokens = 3 | ||
| // present tokens = 5 (starting with current_pos_id = 3) | ||
| // K0 K1 K2 K3 K4 K5 K6 K7 | ||
| // [ 0 1 2 3 4 5 6 7 ] | ||
| // Q3[ 3 ] 0 0 0 0 1 1 1 1 | ||
| // Q4[ 4 ] 0 0 0 0 0 1 1 1 | ||
| // Q5[ 5 ] 0 0 0 0 0 0 1 1 | ||
| // Q6[ 6 ] 0 0 0 0 0 0 0 1 | ||
| // Q7[ 7 ] 0 0 0 0 0 0 0 0 | ||
| // where 1 at [i, j] means that the j-th token is a future token for the i-th token, which we shouldn't attend to. | ||
| // | ||
| // Together, combined via "bitwise_or", these two masks form the inverted sliding attention mask: | ||
| // past tokens = 3 | ||
| // present tokens = 5 (starting with current_pos_id = 3) | ||
| // sliding window len = 4 | ||
| // K0 K1 K2 K3 K4 K5 K6 K7 | ||
| // [ 0 1 2 3 4 5 6 7 ] | ||
| // Q3[ 3 - 4 ] 0 0 0 0 1 1 1 1 | ||
| // Q4[ 4 - 4 ] 1 0 0 0 0 1 1 1 | ||
| // Q5[ 5 - 4 ] 1 1 0 0 0 0 1 1 | ||
| // Q6[ 6 - 4 ] 1 1 1 0 0 0 0 1 | ||
| // Q7[ 7 - 4 ] 1 1 1 1 0 0 0 0 | ||
| // | ||
| // The issue with the sliding attention mask appears when we work with static shapes and | ||
| // different paddings for past and present tokens. | ||
| // More precisely, the issue appears with the sliding window mask, as the Q column range is | ||
| // created from the length of the past key/values tensor (2175 for the 2K case) as the start | ||
| // point and the length of the attention mask (2176 for 2K) as the end point. This is okay | ||
| // for the inverted self-attention mask produced by the "greater" operation, as our present | ||
| // tokens are exactly left-padded and located on the right in the attention mask. | ||
| // However, for the sliding window mask created by the "less_equal" operation, the given | ||
| // Q range behaves as if the position ids of the new Q tokens start from 2175 and not from | ||
| // 3 as in the example above; therefore, the first 2175 - 2047 = 128 tokens would be forgotten. | ||
| // To fix it, a new formula is suggested: | ||
| // 1. (K range <= (Q_pos range - sliding window).T) | (K range > Q range.T) | ||
| // 2. (K range <= (Q range - sliding window).T) & (K range >= len(past_key_values)) | ||
| // 3. Resulting mask = 1 | 2, | ||
| // where K range and Q range are created by the same rules as before and Q_pos range is | ||
| // a position_ids array. | ||
| // 4. We also clean the mask in places where paddings are used instead of real tokens via: | ||
| // Clean mask = 3 | !(attention_mask_input[past_kv_len:]).T | ||
| auto past_kv_len = opp::wrap_type<ov::op::v8::Gather>({opp::any_input(), opp::any_input(), opp::any_input()}); | ||
| auto pos_ids_param = opp::wrap_type<ov::op::v0::Parameter>(); | ||
| auto pos_ids_shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({pos_ids_param}); | ||
| auto pos_ids_len = opp::wrap_type<ov::op::v8::Gather>({pos_ids_shape_of, opp::any_input(), opp::any_input()}); | ||
| auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, pos_ids_len}); | ||
| auto query_range = opp::wrap_type<ov::op::v4::Range>({past_kv_len, full_ctx_len, opp::any_input()}); | ||
| auto column_shape = opp::wrap_type<ov::op::v0::Constant>(); | ||
| auto query_range_column = opp::wrap_type<ov::op::v1::Reshape>({query_range, column_shape}); | ||
|  | ||
| auto zero_const = opp::wrap_type<ov::op::v0::Constant>(); | ||
| auto atten_mask_param = opp::wrap_type<ov::op::v0::Parameter>(); | ||
| auto atten_mask_shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({atten_mask_param}); | ||
| auto atten_mask_len = | ||
| opp::wrap_type<ov::op::v8::Gather>({atten_mask_shape_of, opp::any_input(), opp::any_input()}); | ||
| auto key_range = opp::wrap_type<ov::op::v4::Range>({zero_const, atten_mask_len, opp::any_input()}); | ||
| auto key_range_i64 = opp::wrap_type<ov::op::v0::Convert>({key_range}); | ||
| auto key_range_f32 = opp::wrap_type<ov::op::v0::Convert>({key_range_i64}); | ||
|  | ||
| auto neg_window_size = opp::wrap_type<ov::op::v0::Constant>(); | ||
| auto query_left_bound_range = opp::wrap_type<ov::op::v1::Add>({query_range_column, neg_window_size}); | ||
| // False in the mask means that we shouldn't forget this token | ||
| auto forget_left_tokens_mask = opp::wrap_type<ov::op::v1::LessEqual>({key_range_f32, query_left_bound_range}); | ||
| // Basically it is a reference triangular self-attention mask that | ||
| // forbids tokens from attending to future ones, but the values are inverted: | ||
| auto look_only_future_mask = opp::wrap_type<ov::op::v1::Greater>({key_range_f32, query_range_column}); | ||
|  | ||
| auto inv_sliding_attention_mask = | ||
| opp::wrap_type<ov::op::v13::BitwiseOr>({look_only_future_mask, forget_left_tokens_mask}); | ||
|  | ||
| auto callback = [=](ov::pass::pattern::Matcher& m) { | ||
| auto& node_to_output = m.get_pattern_value_map(); | ||
| auto node_past_kv_len = node_to_output.at(past_kv_len).get_node_shared_ptr(); | ||
| auto node_pos_ids_param = node_to_output.at(pos_ids_param).get_node_shared_ptr(); | ||
| auto node_atten_mask_param = node_to_output.at(atten_mask_param).get_node_shared_ptr(); | ||
| auto node_atten_mask_len = node_to_output.at(atten_mask_len).get_node_shared_ptr(); | ||
| auto node_key_range_f32 = node_to_output.at(key_range_f32).get_node_shared_ptr(); | ||
| auto node_neg_window_size = node_to_output.at(neg_window_size).get_node_shared_ptr(); | ||
| auto node_forget_left_tokens_mask = node_to_output.at(forget_left_tokens_mask).get_node_shared_ptr(); | ||
| auto node_bitwise_or = node_to_output.at(inv_sliding_attention_mask).get_node_shared_ptr(); | ||
|  | ||
| auto matched_past_kv_len = std::static_pointer_cast<ov::op::v8::Gather>(node_past_kv_len); | ||
| auto matched_pos_ids_input = std::static_pointer_cast<ov::op::v0::Parameter>(node_pos_ids_param); | ||
| auto matched_atten_mask_input = std::static_pointer_cast<ov::op::v0::Parameter>(node_atten_mask_param); | ||
| auto matched_atten_mask_len = std::static_pointer_cast<ov::op::v8::Gather>(node_atten_mask_len); | ||
| auto matched_key_range_f32 = std::static_pointer_cast<ov::op::v0::Convert>(node_key_range_f32); | ||
| auto matched_neg_window_size = std::static_pointer_cast<ov::op::v0::Constant>(node_neg_window_size); | ||
| auto matched_forget_left_tokens_mask = | ||
| std::static_pointer_cast<ov::op::v1::LessEqual>(node_forget_left_tokens_mask); | ||
| auto matched_bitwise_or = std::static_pointer_cast<ov::op::v13::BitwiseOr>(node_bitwise_or); | ||
| OPENVINO_ASSERT(matched_neg_window_size->get_output_size() == 1, | ||
| "Sliding window size constant must be of size 1, but got " + | ||
| std::to_string(matched_neg_window_size->get_output_size())); | ||
|  | ||
| // 1. (K range <= (Q_pos range - sliding window).T) | (K range > Q range.T) | ||
| auto query_range_as_pos_ids = | ||
| std::make_shared<ov::op::v0::Convert>(matched_pos_ids_input, ov::element::f32); | ||
| std::vector<int64_t> vector_shape{-1, 1}; | ||
| auto vector_shape_const = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{2}, vector_shape); | ||
| auto query_range_as_pos_ids_col = | ||
| std::make_shared<ov::op::v1::Reshape>(query_range_as_pos_ids, vector_shape_const, false); | ||
| auto query_range_as_pos_left_bound = | ||
| std::make_shared<ov::op::v1::Add>(query_range_as_pos_ids_col, matched_neg_window_size); | ||
| auto forget_left_mask_for_right_padding = | ||
| std::make_shared<ov::op::v1::LessEqual>(matched_key_range_f32, query_range_as_pos_left_bound); | ||
| matched_bitwise_or->input(1).replace_source_output(forget_left_mask_for_right_padding); | ||
|  | ||
| // 2. (K range <= (Q range - sliding window).T) & (K range >= shape(past_key_values, 2)) | ||
| auto past_kv_len_f32 = std::make_shared<ov::op::v0::Convert>(matched_past_kv_len, ov::element::f32); | ||
| auto only_present_tokens_mask = | ||
| std::make_shared<ov::op::v1::GreaterEqual>(matched_key_range_f32, past_kv_len_f32); | ||
| auto bitwise_and = | ||
| std::make_shared<ov::op::v13::BitwiseAnd>(matched_forget_left_tokens_mask, only_present_tokens_mask); | ||
|  | ||
| // 3. Result = 1 | 2 | ||
| // Save target inputs first: | ||
| auto target_inputs = matched_bitwise_or->output(0).get_target_inputs(); | ||
| auto new_inv_sliding_mask = std::make_shared<ov::op::v13::BitwiseOr>(matched_bitwise_or, bitwise_and); | ||
|  | ||
| // 4. Removing extra padding via: 3 | !(attention_mask_input[past_kv_len:]).T | ||
| std::vector<int64_t> shape_rank_one{1}; | ||
| auto shape_rank_one_const = | ||
| std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, shape_rank_one); | ||
| auto past_len_reshaped = | ||
| std::make_shared<ov::op::v1::Reshape>(matched_past_kv_len, shape_rank_one_const, false); | ||
| auto atten_len_reshaped = | ||
| std::make_shared<ov::op::v1::Reshape>(matched_atten_mask_len, shape_rank_one_const, false); | ||
| auto const_one = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, 1); | ||
| auto present_atten_mask = std::make_shared<ov::op::v8::Slice>(matched_atten_mask_input, | ||
| past_len_reshaped, | ||
| atten_len_reshaped, | ||
| const_one, | ||
| const_one); | ||
| auto present_atten_mask_bool = | ||
| std::make_shared<ov::op::v0::Convert>(present_atten_mask, ov::element::boolean); | ||
| auto inv_present_atten_mask = std::make_shared<ov::op::v1::LogicalNot>(present_atten_mask_bool); | ||
| auto inv_present_atten_mask_col = | ||
| std::make_shared<ov::op::v1::Reshape>(inv_present_atten_mask, vector_shape_const, false); | ||
| auto clean_inv_sliding_mask = | ||
| std::make_shared<ov::op::v13::BitwiseOr>(new_inv_sliding_mask, inv_present_atten_mask_col); | ||
| for (auto&& input : target_inputs) { | ||
| input.replace_source_output(clean_inv_sliding_mask); | ||
| } | ||
|  | ||
| return true; | ||
| }; | ||
| register_matcher(std::make_shared<opp::Matcher>(inv_sliding_attention_mask, "Phi3SlidingMask"), | ||
| std::move(callback)); | ||
| } | ||
| }; | ||
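As a rough illustration of the mask arithmetic described in the Phi3SlidingMask comments above, the following standalone sketch emulates the formulas in plain C++. It does not use OpenVINO, and all sizes, the padded layout, and helper names (such as print_mask) are assumptions made only for this example. Part 1 reproduces the toy combined mask; Part 2 shows why the original "less_equal" term misbehaves under static shapes and how the patched term built from real position ids avoids it.

// A minimal sketch of the mask math only; not the actual OpenVINO pass.
#include <cstdio>
#include <vector>

// Helper assumed for this illustration.
static void print_mask(const std::vector<std::vector<int>>& m) {
    for (const auto& row : m) {
        for (int v : row) {
            std::printf("%d ", v);
        }
        std::printf("\n");
    }
}

int main() {
    // Part 1: dynamic-shape toy example from the comments
    // (past tokens = 3, present tokens = 5, sliding window = 4).
    const int past = 3, present = 5, window = 4, mask_len = past + present;
    std::vector<std::vector<int>> inv_mask(present, std::vector<int>(mask_len, 0));
    for (int i = 0; i < present; ++i) {
        const int q = past + i;  // Q range starts at current_pos_id = past
        for (int k = 0; k < mask_len; ++k) {
            const bool forget_left = (k <= q - window);  // "less_equal" branch
            const bool future = (k > q);                 // "greater" branch
            inv_mask[i][k] = (forget_left || future) ? 1 : 0;
        }
    }
    std::printf("toy inverted sliding mask (1 = do not attend):\n");
    print_mask(inv_mask);  // should match the 5x8 table in the comments

    // Part 2: static shapes, decoding a single new token.
    // Assumed layout: 8 attention-mask slots, static past_kv_len = 7
    // (only 3 real past tokens, right-padded), the new token sits in the
    // last slot and has real position id 3; sliding window = 4.
    const int static_len = 8, past_kv_len = 7, pos_id = 3, win = 4;
    std::printf("\n'forget left' term per K slot (original vs patched):\n");
    for (int k = 0; k < static_len; ++k) {
        const bool original = (k <= past_kv_len - win);  // Q starts at past_kv_len -> forgets real history
        const bool patched = (k <= pos_id - win);        // term 1 of the fix: uses the real position id
        std::printf("k=%d original=%d patched=%d\n", k, original ? 1 : 0, patched ? 1 : 0);
    }
    return 0;
}

Under these assumptions the original term marks the first four K slots as forgotten even though they hold the only real history, which mirrors the 2175 - 2047 = 128 tokens described in the comments, while the patched term built from the real position id leaves them intact.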
|  | ||
| namespace { | ||
| uint32_t align_to(uint32_t value, uint32_t alignment) { | ||
| return (value + alignment - 1) & ~(alignment - 1); | ||
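For reference, the rounding trick in align_to above assumes a power-of-two alignment. A self-contained, compile-time check of a couple of sample values (the constexpr copy and the numbers are only for illustration):

#include <cstdint>

// Illustrative constexpr copy of the same expression, so the values can be
// verified at compile time; the original function is a plain runtime helper.
constexpr uint32_t align_to_example(uint32_t value, uint32_t alignment) {
    return (value + alignment - 1) & ~(alignment - 1);
}

static_assert(align_to_example(2048, 64) == 2048, "already aligned");
static_assert(align_to_example(2050, 64) == 2112, "rounded up to the next multiple of 64");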
|  | @@ -455,6 +647,13 @@ void decompose_GQA(std::shared_ptr<ov::Model> model, bool is_prefill_model) { | |
| rewr.add_matcher<GroupQueryAttentionDecomposition>(is_prefill_model); | ||
| rewr.run_on_model(model); | ||
| } | ||
|  | ||
| void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) { | ||
| ov::pass::GraphRewrite rewr; | ||
| rewr.add_matcher<Phi3SlidingMask>(); | ||
| rewr.run_on_model(model); | ||
| model->validate_nodes_and_infer_types(); | ||
| } | ||
| } // namespace | ||
|  | ||
| namespace { | ||
|  | @@ -1022,6 +1221,9 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m | |
| LOG_INFO("Two-model pipeline will be created."); | ||
| } | ||
|  | ||
| LOG_DEBUG("Try patch Phi-3 sliding window mask, if it exists."); | ||
| patch_phi3_sliding_mask(kvcache_model); | ||
|  | ||
| LOG_DEBUG("Creating prefill model as clone of transformed kvcache one."); | ||
| auto prefill_model = kvcache_model->clone(); | ||
| prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill"); | ||
|  | @@ -1117,6 +1319,8 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m | |
| axes, | ||
| m_max_lora_rank, | ||
| whisper_lhs_seq_size); | ||
|  | ||
| LOG_DEBUG("Try parametrize Gemma sliding window mask, if it exists."); | ||
| gemma_transformations(kvcache_model); | ||
|  | ||
| if (lm_head_model) { | ||
|  | ||
"If you introduced a new pass, why are there changes to the existing code? Have you also moved stuff around? Since the changes are quite complex nowadays, I'd avoid unnecessary diffs to focus only on the relevant ones."
Reply: "Sure, will do!"
Reply: "Fixed, thanks!"