Commit aad45d5

[INTEL_HPU] add use_neox_style (#1834)

1 parent 2febc48 commit aad45d5
File tree

3 files changed  +59  -25 lines changed

backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc

Lines changed: 25 additions & 10 deletions
@@ -31,6 +31,7 @@ struct FusedBlockAttentionParams {
   int num_head;
   int num_kv_head;
 
+  bool use_neox_style = true;
   bool with_qkv_biases = false;
   bool transpose = true;
 };
@@ -218,7 +219,9 @@ class FusedMHABlockAttention : public HpuFusedOperator {
 
     ns_RoPESt2::ParamsV2 ropeParams;
     ropeParams.offset = 0;
-    ropeParams.mode = ROTARY_POS_EMBEDDING_MODE_BLOCKWISE;
+    ropeParams.mode = params.use_neox_style
+                          ? ROTARY_POS_EMBEDDING_MODE_BLOCKWISE
+                          : ROTARY_POS_EMBEDDING_MODE_PAIRWISE;
     AddNodeRope<T>(inputs_q, outputs_q, ropeParams, guid_ + "rope_q");
 
     std::vector<synTensor> inputs_k;
@@ -916,7 +919,9 @@ class FusedGQABlockAttention : public HpuFusedOperator {
 
     ns_RoPESt2::ParamsV2 ropeParams;
     ropeParams.offset = 0;
-    ropeParams.mode = ROTARY_POS_EMBEDDING_MODE_BLOCKWISE;
+    ropeParams.mode = params.use_neox_style
+                          ? ROTARY_POS_EMBEDDING_MODE_BLOCKWISE
+                          : ROTARY_POS_EMBEDDING_MODE_PAIRWISE;
     AddNodeRope<T>(inputs_q, outputs_q, ropeParams, guid_ + "rope_q");
 
     std::vector<synTensor> inputs_k;
@@ -1495,14 +1500,16 @@ void FusedBlockAttentionKernel(
     const phi::Scalar& head_dim,
     const phi::Scalar& num_head,
     const phi::Scalar& scaling_factor,
-    const phi::Scalar& transpose) {
+    const phi::Scalar& transpose,
+    const phi::Scalar& use_neox_style) {
   std::vector<int64_t> src_dims = phi::vectorize<int64_t>(src.dims());
   std::vector<int64_t> qkv_weights_dims =
       phi::vectorize<int64_t>(qkv_weights.dims());
 
   int head_dim_ = head_dim.to<int>();
   int num_head_ = num_head.to<int>();
   bool transpose_ = transpose.to<bool>();
+  bool use_neox_style_ = use_neox_style.to<bool>();
   const int64_t fused_hidden_size =
       transpose_ ? qkv_weights_dims[0] : qkv_weights_dims[1];
   const int num_kv_head =
@@ -1553,6 +1560,7 @@ void FusedBlockAttentionKernel(
   params.index_reduce_params.mode = INDEX_REDUCE_AMAX;
   params.index_reduce_params.include_self = true;
   params.index_reduce_params.axis = 0;
+  params.use_neox_style = use_neox_style_;
   params.transpose = transpose_;
   params.head_dim = head_dim_;
   params.num_head = num_head_;
@@ -1603,7 +1611,8 @@ void CallFusedBlockAttentionKernel(
     const phi::Scalar& head_dim,
     const phi::Scalar& num_head,
     const phi::Scalar& scaling_factor,
-    const phi::Scalar& transpose) {
+    const phi::Scalar& transpose,
+    const phi::Scalar& use_neox_style) {
   if (src.dtype() == phi::DataType::FLOAT16) {
     custom_kernel::FusedBlockAttentionKernel<phi::dtype::float16>(
         dev_ctx,
@@ -1624,7 +1633,8 @@ void CallFusedBlockAttentionKernel(
         head_dim,
         num_head,
         scaling_factor,
-        transpose);
+        transpose,
+        use_neox_style);
   } else if (src.dtype() == phi::DataType::BFLOAT16) {
     custom_kernel::FusedBlockAttentionKernel<phi::dtype::bfloat16>(
         dev_ctx,
@@ -1645,7 +1655,8 @@ void CallFusedBlockAttentionKernel(
         head_dim,
         num_head,
         scaling_factor,
-        transpose);
+        transpose,
+        use_neox_style);
   } else {
     throw std::runtime_error(
         "Unsupported data type for FusedBlockAttentionKernel");
@@ -1669,7 +1680,8 @@ std::vector<paddle::Tensor> FusedBlockAttentionForward(
     int head_dim,
     int num_head,
     float scaling_factor,
-    bool transpose) {
+    bool transpose,
+    bool use_neox_style) {
   auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(src.place()));
   auto src_tensor = static_cast<const phi::DenseTensor*>(src.impl().get());
@@ -1730,7 +1742,8 @@ std::vector<paddle::Tensor> FusedBlockAttentionForward(
       phi::Scalar(head_dim),
       phi::Scalar(num_head),
       phi::Scalar(scaling_factor),
-      phi::Scalar(transpose));
+      phi::Scalar(transpose),
+      phi::Scalar(use_neox_style));
   return {paddle::Tensor(out_linear)};
 }
 
@@ -1751,7 +1764,8 @@ std::vector<std::vector<int64_t>> FusedBlockAttentionShape(
     int head_dim,
     int num_head,
     float scaling_factor,
-    bool transpose) {
+    bool transpose,
+    bool use_neox_style) {
   int64_t batch_size = src_shape[0];
   int64_t out_features = linear_weights_shape[1];
   return {{batch_size, 1, out_features}};
@@ -1792,7 +1806,8 @@ PD_BUILD_OP(fused_block_attention)
     .Attrs({"head_dim: int",
             "num_head: int",
             "scaling_factor: float",
-            "transpose: bool"})
+            "transpose: bool",
+            "use_neox_style: bool"})
     .SetKernelFn(PD_KERNEL(FusedBlockAttentionForward))
     .SetInferShapeFn(PD_INFER_SHAPE(FusedBlockAttentionShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(FusedBlockAttentionDtype));
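
For readers unfamiliar with the two RoPE modes: the new use_neox_style attribute selects which rotary-embedding layout the fused kernel applies to Q and K. The reference sketch below (NumPy, not part of the commit) contrasts the two conventions, assuming ROTARY_POS_EMBEDDING_MODE_BLOCKWISE corresponds to the NeoX "rotate-half" layout and ROTARY_POS_EMBEDDING_MODE_PAIRWISE to the GPT-J interleaved layout, as the flag name and its default of true suggest.

# Reference sketch (not part of the commit) of the two RoPE layouts the flag selects.
import numpy as np

def rope_neox(x, cos, sin):
    # NeoX / blockwise: rotate the two halves of the head dimension together.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = np.concatenate([-x2, x1], axis=-1)
    return (x * np.concatenate([cos, cos], axis=-1)
            + rotated * np.concatenate([sin, sin], axis=-1))

def rope_pairwise(x, cos, sin):
    # GPT-J / pairwise: rotate adjacent (even, odd) pairs of the head dimension.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

head_dim, seq = 8, 4
pos = np.arange(seq)[:, None]
inv_freq = 1.0 / (10000 ** (np.arange(0, head_dim, 2) / head_dim))
angles = pos * inv_freq                      # [seq, head_dim // 2]
cos, sin = np.cos(angles), np.sin(angles)
q = np.random.randn(seq, head_dim)
print(rope_neox(q, cos, sin).shape, rope_pairwise(q, cos, sin).shape)

Both apply the same per-pair rotation; they differ only in which elements of the head dimension form a rotation pair, so the flag has to match the convention the model was trained with.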

backends/intel_hpu/custom_ops/llama_infer/fused_qkv_rope.cc

Lines changed: 32 additions & 14 deletions
@@ -26,6 +26,7 @@ struct FusedQkvRopeParams {
   int num_head;
   int kv_num_head;
 
+  bool use_neox_style = true;
   bool transpose = true;
   bool with_qkv_biases = false;
   bool use_fp8 = false;
@@ -187,7 +188,9 @@ class FusedQkvRope : public HpuFusedOperator {
 
     ns_RoPESt2::ParamsV2 ropeParams;
     ropeParams.offset = 0;
-    ropeParams.mode = ROTARY_POS_EMBEDDING_MODE_BLOCKWISE;
+    ropeParams.mode = params.use_neox_style
+                          ? ROTARY_POS_EMBEDDING_MODE_BLOCKWISE
+                          : ROTARY_POS_EMBEDDING_MODE_PAIRWISE;
     AddNodeRope<T>(inputs_q, outputs_q, ropeParams, guid_ + "rope_q");
 
     std::vector<synTensor> inputs_k;
@@ -239,7 +242,8 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
                         const phi::Scalar& head_dim,
                         const phi::Scalar& num_head,
                         const phi::Scalar& total_batch,
-                        const phi::Scalar& transpose) {
+                        const phi::Scalar& transpose,
+                        const phi::Scalar& use_neox_style) {
   int total_batch_ = total_batch.to<int>();
   std::vector<int64_t> src_dims = phi::vectorize<int64_t>(src.dims());
   int bsz_seqlen = src_dims[0];
@@ -255,6 +259,7 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
   int head_dim_ = head_dim.to<int>();
   int num_head_ = num_head.to<int>();
   bool transpose_ = transpose.to<bool>();
+  bool use_neox_style_ = use_neox_style.to<bool>();
   const int64_t fused_hidden_size =
       transpose_ ? qkv_weights_dims[0] : qkv_weights_dims[1];
   const int kv_num_head =
@@ -297,6 +302,8 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
   params.num_head = num_head_;
   params.kv_num_head = kv_num_head;
   params.transpose = transpose_;
+  params.use_neox_style = use_neox_style_;
+
   if (qkv_biases) {
     params.with_qkv_biases = true;
   }
@@ -333,7 +340,8 @@ void CallFusedQkvRopeKernel(
     const phi::Scalar& head_dim,
     const phi::Scalar& num_head,
     const phi::Scalar& total_batch,
-    const phi::Scalar& transpose) {
+    const phi::Scalar& transpose,
+    const phi::Scalar& use_neox_style) {
   if (src.dtype() == phi::DataType::FLOAT16) {
     custom_kernel::FusedQkvRopeKernel<phi::dtype::float16>(dev_ctx,
                                                            src,
@@ -347,7 +355,8 @@ void CallFusedQkvRopeKernel(
                                                            head_dim,
                                                            num_head,
                                                            total_batch,
-                                                           transpose);
+                                                           transpose,
+                                                           use_neox_style);
   } else if (src.dtype() == phi::DataType::BFLOAT16) {
     custom_kernel::FusedQkvRopeKernel<phi::dtype::bfloat16>(dev_ctx,
                                                             src,
@@ -361,7 +370,8 @@ void CallFusedQkvRopeKernel(
                                                             head_dim,
                                                             num_head,
                                                             total_batch,
-                                                            transpose);
+                                                            transpose,
+                                                            use_neox_style);
   } else {
     throw std::runtime_error("Unsupported data type for FusedQkvRopeKernel");
   }
@@ -375,7 +385,8 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
     int head_dim,
     int num_head,
     int total_batch,
-    bool transpose) {
+    bool transpose,
+    bool use_neox_style) {
   auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(src.place()));
   auto src_tensor = static_cast<const phi::DenseTensor*>(src.impl().get());
@@ -422,7 +433,8 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
       phi::Scalar(head_dim),
       phi::Scalar(num_head),
       phi::Scalar(total_batch),
-      phi::Scalar(transpose));
+      phi::Scalar(transpose),
+      phi::Scalar(use_neox_style));
   return {paddle::Tensor(query_states), paddle::Tensor(key_value_states)};
 }
 
@@ -434,7 +446,8 @@ std::vector<std::vector<int64_t>> FusedQkvRopeShape(
     int head_dim,
     int num_head,
     int total_batch,
-    bool transpose) {
+    bool transpose,
+    bool use_neox_style) {
   int64_t bsz = src_shape[0];
   int64_t seq_len = bsz / total_batch;
   int64_t fused_hidden_size =
@@ -459,7 +472,8 @@ PD_BUILD_OP(fused_qkv_rope)
     .Attrs({"head_dim: int",
             "num_head: int",
             "total_batch: int",
-            "transpose: bool"})
+            "transpose: bool",
+            "use_neox_style: bool"})
     .SetKernelFn(PD_KERNEL(FusedQkvRopeImpl))
     .SetInferShapeFn(PD_INFER_SHAPE(FusedQkvRopeShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(FusedQkvRopeDtype));
@@ -474,7 +488,8 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
     int head_dim,
     int num_head,
     int total_batch,
-    bool transpose) {
+    bool transpose,
+    bool use_neox_style) {
   auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(src.place()));
   auto src_tensor = static_cast<const phi::DenseTensor*>(src.impl().get());
@@ -528,7 +543,8 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
       phi::Scalar(head_dim),
      phi::Scalar(num_head),
      phi::Scalar(total_batch),
-     phi::Scalar(transpose));
+     phi::Scalar(transpose),
+     phi::Scalar(use_neox_style));
   return {paddle::Tensor(query_states), paddle::Tensor(key_value_states)};
 }
 
@@ -542,7 +558,8 @@ std::vector<std::vector<int64_t>> FusedFp8QkvRopeShape(
     int head_dim,
     int num_head,
     int total_batch,
-    bool transpose) {
+    bool transpose,
+    bool use_neox_style) {
   int64_t bsz = src_shape[0];
   int64_t seq_len = bsz / total_batch;
   int64_t fused_hidden_size =
@@ -562,7 +579,7 @@ std::vector<paddle::DataType> FusedFp8QkvRopeDtype(
   return {src_dtype, src_dtype};
 }
 
-PD_BUILD_OP(fused_fp8_qkv_rope_t)
+PD_BUILD_OP(fused_fp8_qkv_rope)
     .Inputs({"src",
              "qkv_weights",
              paddle::Optional("qkv_biases"),
@@ -573,7 +590,8 @@ PD_BUILD_OP(fused_fp8_qkv_rope_t)
     .Attrs({"head_dim: int",
             "num_head: int",
             "total_batch: int",
-            "transpose: bool"})
+            "transpose: bool",
+            "use_neox_style: bool"})
    .SetKernelFn(PD_KERNEL(FusedFp8QkvRopeImpl))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedFp8QkvRopeShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedFp8QkvRopeDtype));
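
As a usage note (an assumption about how callers adapt, not something in this commit): a checkpoint whose Q/K projections were exported for the pairwise (GPT-J) layout can still go through the NeoX/blockwise path by permuting the rotary dimension, and vice versa; otherwise use_neox_style simply has to match the checkpoint's convention. A small self-contained check of that equivalence:

# Sketch: pairwise RoPE equals NeoX RoPE applied to a permuted vector (per head).
import numpy as np

head_dim = 8
angles = np.random.rand(head_dim // 2)          # one position's rotation angles
cos, sin = np.cos(angles), np.sin(angles)
x = np.random.randn(head_dim)

# Pairwise (GPT-J): rotate adjacent (even, odd) pairs.
pair = np.empty_like(x)
pair[0::2] = x[0::2] * cos - x[1::2] * sin
pair[1::2] = x[0::2] * sin + x[1::2] * cos

# NeoX (blockwise / rotate-half) applied to the permuted vector.
perm = np.concatenate([np.arange(0, head_dim, 2), np.arange(1, head_dim, 2)])
y = x[perm]
half = head_dim // 2
neox = np.concatenate([y[:half] * cos - y[half:] * sin,
                       y[half:] * cos + y[:half] * sin])

assert np.allclose(neox, pair[perm])            # same rotation, different element layout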

backends/intel_hpu/custom_ops/python/paddlenlp_ops/llama_block_atten.py

Lines changed: 2 additions & 1 deletion
@@ -211,6 +211,7 @@ def prepare_block_metadata_ref(
 def rebuild_padding_v2(
     tmp_out,
     batch_ids,
+    total_batch,
     seq_lens_encoder,
     is_prompt=None,
 ):
@@ -219,7 +220,7 @@ def rebuild_padding_v2(
     output_data = paddle.zeros((max_batch, dim_emb))
 
     if is_prompt is True:  # context
-        tmp_out = tmp_out.reshape([max_batch, -1, dim_emb])
+        tmp_out = tmp_out.reshape([total_batch, -1, dim_emb])
         j = 0
         for i in range(max_batch):
             if seq_lens_encoder[i].item() > 0:
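
The reshape fix matters when fewer sequences than max_batch are active during the prompt (context) phase: the encoder output has total_batch * seq_len rows, which only folds cleanly into one row block per active sequence. A toy illustration (the shapes are invented, not from the commit):

# Toy shapes: 2 active sequences of 4 tokens in a system sized for max_batch = 8.
import paddle

max_batch, total_batch, seq_len, dim_emb = 8, 2, 4, 16
tmp_out = paddle.randn([total_batch * seq_len, dim_emb])   # encoder output rows
ctx = tmp_out.reshape([total_batch, -1, dim_emb])           # -> [2, 4, 16], one block per sequence
print(ctx.shape)
# Reshaping with max_batch (8) would give [8, 1, 16] here, silently splitting each
# sequence's tokens across batch slots (or raising an error when the row count is
# not divisible by max_batch).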
