204 changes: 204 additions & 0 deletions src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp
Contributor:
If you introduced a new pass, why are there changes to the existing code? Have you also moved stuff around?
Since the changes are quite complex nowadays, I'd avoid unnecessary diffs to focus only on the relevant ones.

Contributor Author:
Sure, will do!

Contributor Author:
Fixed, thanks!

@@ -396,6 +396,198 @@ class GemmaSlidingMask : public ov::pass::MatcherPass {
}
};

class Phi3SlidingMask : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::Phi3SlidingMask");

Phi3SlidingMask() {
// Search for the Phi3 sliding mask pattern to extend it to work with right-padded
// past tokens and left-padded present tokens.
//
// The mask is built via "less_equal" and "greater" operations between the
// row K range: [0,... mask_len], the column Q range: [current_pos_id,... mask_len].T
// and the sliding window length.
// Due to broadcasting rules these two operations form two triangular masks.
//
// - "less_equal" forms a sliding window mask; more precisely, it has the following expression:
//
// row range [0,... mask_len] <= column range [current_pos_id - sliding_window_size,
// ...,
// mask_len - sliding_window_size]
//
// forming, under example conditions, the mask below:
// past tokens = 3
// present tokens = 5 (starting with current_pos_id = 3)
Comment on lines +419 to +420
Contributor:
out of curiosity, how can we get 5 present tokens with 3 past tokens?

Contributor Author (@AsyaPronina, Oct 15, 2025):
It can be done in the speculative decoding case, for example.
However, the 3 and 5 case was described as a general, simple case for dynamic shapes.

// sliding window len = 4
// K0 K1 K2 K3 K4 K5 K6 K7
// [ 0 1 2 3 4 5 6 7 ]
// Q3[ 3 - 4 ] 0 0 0 0 0 0 0 0
// Q4[ 4 - 4 ] 1 0 0 0 0 0 0 0
// Q5[ 5 - 4 ] 1 1 0 0 0 0 0 0
// Q6[ 6 - 4 ] 1 1 1 0 0 0 0 0
// Q7[ 7 - 4 ] 1 1 1 1 0 0 0 0
// where 1 at [i, j] means that the j-th token should be forgotten, as it doesn't fit into the
// sliding window to the left of the i-th token.
//
// - "greater" forms a similar to self-attention mask:
//
// row range [0,... mask_len] > column range [current_pos_id,
// ...,
// mask_len]
//
// forming, under example conditions, the mask below:
// past tokens = 3
// present tokens = 5 (starting with current_pos_id = 3)
// K0 K1 K2 K3 K4 K5 K6 K7
// [ 0 1 2 3 4 5 6 7 ]
// Q3[ 3 ] 0 0 0 0 1 1 1 1
// Q4[ 4 ] 0 0 0 0 0 1 1 1
// Q5[ 5 ] 0 0 0 0 0 0 1 1
// Q6[ 6 ] 0 0 0 0 0 0 0 1
// Q7[ 7 ] 0 0 0 0 0 0 0 0
// where 1 at [i, j] means that the j-th token is a future token for the i-th token, which we shouldn't attend to.
//
// Together, via "bitwise_or" this two masks forms the inverted sliding attention mask:
// past tokens = 3
// present tokens = 5 (starting with current_pos_id = 3)
// sliding window len = 4
// K0 K1 K2 K3 K4 K5 K6 K7
// [ 0 1 2 3 4 5 6 7 ]
// Q3[ 3 - 4 ] 0 0 0 0 1 1 1 1
// Q4[ 4 - 4 ] 1 0 0 0 0 1 1 1
// Q5[ 5 - 4 ] 1 1 0 0 0 0 1 1
// Q6[ 6 - 4 ] 1 1 1 0 0 0 0 1
// Q7[ 7 - 4 ] 1 1 1 1 0 0 0 0
//
// The issue with the sliding attention mask appears when we work with static shapes and
// different paddings for past and present tokens.
// More precisely, the issue is with the sliding window mask, as the Q column range is created
// with the length of the past key/values tensor (2175 for the 2K case) as a start point and
// the length of the attention mask (2176 for 2K) as an end point. This is okay for the
// inverted self-attention mask built via the "greater" operation, as our present tokens are
// exactly left-padded and located on the right in the attention mask.
// However, for the sliding window mask built via the "less_equal" operation, such a Q range
// behaves as if the position ids of the new Q tokens started from 2175 and not from 3 as in
// the example above, and therefore the first 2175 - 2047 = 128 tokens would be forgotten.
// To fix this, a new formula is suggested:
// 1. (K range <= (Q_pos range - sliding window).T) | (K range > Q range.T)
// 2. (K range <= (Q range - sliding window).T) & (K range >= len(past_key_values))
// 3. Resulting mask = 1 | 2,
// where K range and Q range are created by the same rules as before and Q_pos range is
// the position_ids array.
// 4. We also clean the mask in places where padding is used instead of real tokens via:
// Clean mask = 3 | !(attention_mask_input[past_kv_len:]).T
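// A quick worked check with the 2K numbers above (illustrative only): the original Q range
// starts at past_kv_len = 2175, while the real position id of a new token may be, e.g., 3,
// so (K range <= Q range - sliding window) would wrongly forget K <= 2175 - 2047 = 128.
// In the new formula, term 1 compares K against the real position ids (3 - 2047 < 0, so no
// past token is forgotten yet), while term 2 keeps the buffer-index comparison but restricts
// it to the present region K >= len(past_key_values), where buffer indices and the Q range
// are consistent with each other again.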
auto past_kv_len = opp::wrap_type<ov::op::v8::Gather>({opp::any_input(), opp::any_input(), opp::any_input()});
auto pos_ids_param = opp::wrap_type<ov::op::v0::Parameter>();
auto pos_ids_shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({pos_ids_param});
auto pos_ids_len = opp::wrap_type<ov::op::v8::Gather>({pos_ids_shape_of, opp::any_input(), opp::any_input()});
auto full_ctx_len = opp::wrap_type<ov::op::v1::Add>({past_kv_len, pos_ids_len});
auto query_range = opp::wrap_type<ov::op::v4::Range>({past_kv_len, full_ctx_len, opp::any_input()});
auto column_shape = opp::wrap_type<ov::op::v0::Constant>();
auto query_range_column = opp::wrap_type<ov::op::v1::Reshape>({query_range, column_shape});

auto zero_const = opp::wrap_type<ov::op::v0::Constant>();
auto atten_mask_param = opp::wrap_type<ov::op::v0::Parameter>();
auto atten_mask_shape_of = opp::wrap_type<ov::op::v3::ShapeOf>({atten_mask_param});
auto atten_mask_len =
opp::wrap_type<ov::op::v8::Gather>({atten_mask_shape_of, opp::any_input(), opp::any_input()});
auto key_range = opp::wrap_type<ov::op::v4::Range>({zero_const, atten_mask_len, opp::any_input()});
auto key_range_i64 = opp::wrap_type<ov::op::v0::Convert>({key_range});
auto key_range_f32 = opp::wrap_type<ov::op::v0::Convert>({key_range_i64});

auto neg_window_size = opp::wrap_type<ov::op::v0::Constant>();
auto query_left_bound_range = opp::wrap_type<ov::op::v1::Add>({query_range_column, neg_window_size});
// False in the mask means that we shouldn't forget this token
auto forget_left_tokens_mask = opp::wrap_type<ov::op::v1::LessEqual>({key_range_f32, query_left_bound_range});
// Basically it is a reference triangular self-attention mask that
// forbids tokens from attending to future ones, but with inverted values:
auto look_only_future_mask = opp::wrap_type<ov::op::v1::Greater>({key_range_f32, query_range_column});

auto inv_sliding_attention_mask =
opp::wrap_type<ov::op::v13::BitwiseOr>({look_only_future_mask, forget_left_tokens_mask});

auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();
auto node_past_kv_len = node_to_output.at(past_kv_len).get_node_shared_ptr();
auto node_pos_ids_param = node_to_output.at(pos_ids_param).get_node_shared_ptr();
auto node_atten_mask_param = node_to_output.at(atten_mask_param).get_node_shared_ptr();
auto node_atten_mask_len = node_to_output.at(atten_mask_len).get_node_shared_ptr();
auto node_key_range_f32 = node_to_output.at(key_range_f32).get_node_shared_ptr();
auto node_neg_window_size = node_to_output.at(neg_window_size).get_node_shared_ptr();
auto node_forget_left_tokens_mask = node_to_output.at(forget_left_tokens_mask).get_node_shared_ptr();
auto node_bitwise_or = node_to_output.at(inv_sliding_attention_mask).get_node_shared_ptr();

auto matched_past_kv_len = std::static_pointer_cast<ov::op::v8::Gather>(node_past_kv_len);
auto matched_pos_ids_input = std::static_pointer_cast<ov::op::v0::Parameter>(node_pos_ids_param);
auto matched_atten_mask_input = std::static_pointer_cast<ov::op::v0::Parameter>(node_atten_mask_param);
auto matched_atten_mask_len = std::static_pointer_cast<ov::op::v8::Gather>(node_atten_mask_len);
auto matched_key_range_f32 = std::static_pointer_cast<ov::op::v0::Convert>(node_key_range_f32);
auto matched_neg_window_size = std::static_pointer_cast<ov::op::v0::Constant>(node_neg_window_size);
auto matched_forget_left_tokens_mask =
std::static_pointer_cast<ov::op::v1::LessEqual>(node_forget_left_tokens_mask);
auto matched_bitwise_or = std::static_pointer_cast<ov::op::v13::BitwiseOr>(node_bitwise_or);
OPENVINO_ASSERT(matched_neg_window_size->get_output_size() == 1,
"Sliding window size constant must be of size 1, but got " +
std::to_string(matched_neg_window_size->get_output_size()));

// 1. (K range <= (Q_pos range - sliding window).T) | (K range > Q range.T)
auto query_range_as_pos_ids =
std::make_shared<ov::op::v0::Convert>(matched_pos_ids_input, ov::element::f32);
std::vector<int64_t> vector_shape{-1, 1};
auto vector_shape_const =
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{2}, vector_shape);
auto query_range_as_pos_ids_col =
std::make_shared<ov::op::v1::Reshape>(query_range_as_pos_ids, vector_shape_const, false);
auto query_range_as_pos_left_bound =
std::make_shared<ov::op::v1::Add>(query_range_as_pos_ids_col, matched_neg_window_size);
auto forget_left_mask_for_right_padding =
std::make_shared<ov::op::v1::LessEqual>(matched_key_range_f32, query_range_as_pos_left_bound);
matched_bitwise_or->input(1).replace_source_output(forget_left_mask_for_right_padding);

// 2. (K range <= (Q range - sliding window).T) & (K range >= shape(past_key_values, 2))
auto past_kv_len_f32 = std::make_shared<ov::op::v0::Convert>(matched_past_kv_len, ov::element::f32);
auto only_present_tokens_mask =
std::make_shared<ov::op::v1::GreaterEqual>(matched_key_range_f32, past_kv_len_f32);
auto bitwise_and =
std::make_shared<ov::op::v13::BitwiseAnd>(matched_forget_left_tokens_mask, only_present_tokens_mask);

// 3. Result = 1 | 2
// Save target inputs first:
auto target_inputs = matched_bitwise_or->output(0).get_target_inputs();
auto new_inv_sliding_mask = std::make_shared<ov::op::v13::BitwiseOr>(matched_bitwise_or, bitwise_and);

// 4. Removing extra padding via: 3 | !(attention_mask_input[past_kv_len:]).T
std::vector<int64_t> shape_rank_one{1};
auto shape_rank_one_const =
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, shape_rank_one);
auto past_len_reshaped =
std::make_shared<ov::op::v1::Reshape>(matched_past_kv_len, shape_rank_one_const, false);
auto atten_len_reshaped =
std::make_shared<ov::op::v1::Reshape>(matched_atten_mask_len, shape_rank_one_const, false);
auto const_one = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, 1);
auto present_atten_mask = std::make_shared<ov::op::v8::Slice>(matched_atten_mask_input,
past_len_reshaped,
atten_len_reshaped,
const_one,
const_one);
auto present_atten_mask_bool =
std::make_shared<ov::op::v0::Convert>(present_atten_mask, ov::element::boolean);
auto inv_present_atten_mask = std::make_shared<ov::op::v1::LogicalNot>(present_atten_mask_bool);
auto inv_present_atten_mask_col =
std::make_shared<ov::op::v1::Reshape>(inv_present_atten_mask, vector_shape_const, false);
auto clean_inv_sliding_mask =
std::make_shared<ov::op::v13::BitwiseOr>(new_inv_sliding_mask, inv_present_atten_mask_col);
for (auto&& input : target_inputs) {
input.replace_source_output(clean_inv_sliding_mask);
}

return true;
};
register_matcher(std::make_shared<opp::Matcher>(inv_sliding_attention_mask, "Phi3SlidingMask"),
std::move(callback));
}
};
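
The toy mask tables in the comment above can be reproduced with a few lines of plain C++. The sketch below is illustrative only (it is not part of the pass) and assumes the same sizes as the example: 3 past tokens, 5 present tokens, and a sliding window of 4:

    #include <cstdio>

    int main() {
        const int past = 3, present = 5, window = 4;
        const int mask_len = past + present;  // K range: [0, ..., mask_len)
        for (int qi = 0; qi < present; ++qi) {
            const int q = past + qi;          // Q range: [past, ..., mask_len)
            std::printf("Q%d: ", q);
            for (int k = 0; k < mask_len; ++k) {
                const bool forget_left = (k <= q - window);  // "less_equal" sliding-window term
                const bool look_future = (k > q);            // "greater" future-token term
                std::printf("%d ", (forget_left || look_future) ? 1 : 0);  // "bitwise_or"
            }
            std::printf("\n");
        }
        return 0;
    }

Each printed row matches the corresponding row of the combined inverted sliding attention mask table above.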

namespace {
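// Rounds `value` up to the nearest multiple of `alignment` (assumed to be a power of two),
// e.g. align_to(100, 64) == 128.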
uint32_t align_to(uint32_t value, uint32_t alignment) {
return (value + alignment - 1) & ~(alignment - 1);
@@ -455,6 +647,13 @@ void decompose_GQA(std::shared_ptr<ov::Model> model, bool is_prefill_model) {
rewr.add_matcher<GroupQueryAttentionDecomposition>(is_prefill_model);
rewr.run_on_model(model);
}

void patch_phi3_sliding_mask(const std::shared_ptr<ov::Model>& model) {
ov::pass::GraphRewrite rewr;
rewr.add_matcher<Phi3SlidingMask>();
rewr.run_on_model(model);
model->validate_nodes_and_infer_types();
}
} // namespace

namespace {
@@ -1022,6 +1221,9 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
LOG_INFO("Two-model pipeline will be created.");
}

LOG_DEBUG("Try patch Phi-3 sliding window mask, if it exists.");
patch_phi3_sliding_mask(kvcache_model);

LOG_DEBUG("Creating prefill model as clone of transformed kvcache one.");
auto prefill_model = kvcache_model->clone();
prefill_model->set_friendly_name(kvcache_model->get_friendly_name() + "_prefill");
@@ -1117,6 +1319,8 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
axes,
m_max_lora_rank,
whisper_lhs_seq_size);

LOG_DEBUG("Try parametrize Gemma sliding window mask, if it exists.");
gemma_transformations(kvcache_model);

if (lm_head_model) {