@@ -26,6 +26,7 @@ struct FusedQkvRopeParams {
2626 int num_head;
2727 int kv_num_head;
2828
29+ bool use_neox_style = true ;
2930 bool transpose = true ;
3031 bool with_qkv_biases = false ;
3132 bool use_fp8 = false ;
@@ -187,7 +188,9 @@ class FusedQkvRope : public HpuFusedOperator {
187188
188189 ns_RoPESt2::ParamsV2 ropeParams;
189190 ropeParams.offset = 0 ;
190- ropeParams.mode = ROTARY_POS_EMBEDDING_MODE_BLOCKWISE;
191+ ropeParams.mode = params.use_neox_style
192+ ? ROTARY_POS_EMBEDDING_MODE_BLOCKWISE
193+ : ROTARY_POS_EMBEDDING_MODE_PAIRWISE;
191194 AddNodeRope<T>(inputs_q, outputs_q, ropeParams, guid_ + " rope_q" );
192195
193196 std::vector<synTensor> inputs_k;
@@ -239,7 +242,8 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
239242 const phi::Scalar& head_dim,
240243 const phi::Scalar& num_head,
241244 const phi::Scalar& total_batch,
242- const phi::Scalar& transpose) {
245+ const phi::Scalar& transpose,
246+ const phi::Scalar& use_neox_style) {
243247 int total_batch_ = total_batch.to <int >();
244248 std::vector<int64_t > src_dims = phi::vectorize<int64_t >(src.dims ());
245249 int bsz_seqlen = src_dims[0 ];
@@ -255,6 +259,7 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
255259 int head_dim_ = head_dim.to <int >();
256260 int num_head_ = num_head.to <int >();
257261 bool transpose_ = transpose.to <bool >();
262+ bool use_neox_style_ = use_neox_style.to <bool >();
258263 const int64_t fused_hidden_size =
259264 transpose_ ? qkv_weights_dims[0 ] : qkv_weights_dims[1 ];
260265 const int kv_num_head =
@@ -297,6 +302,8 @@ void FusedQkvRopeKernel(const Context& dev_ctx,
297302 params.num_head = num_head_;
298303 params.kv_num_head = kv_num_head;
299304 params.transpose = transpose_;
305+ params.use_neox_style = use_neox_style_;
306+
300307 if (qkv_biases) {
301308 params.with_qkv_biases = true ;
302309 }
@@ -333,7 +340,8 @@ void CallFusedQkvRopeKernel(
333340 const phi::Scalar& head_dim,
334341 const phi::Scalar& num_head,
335342 const phi::Scalar& total_batch,
336- const phi::Scalar& transpose) {
343+ const phi::Scalar& transpose,
344+ const phi::Scalar& use_neox_style) {
337345 if (src.dtype () == phi::DataType::FLOAT16) {
338346 custom_kernel::FusedQkvRopeKernel<phi::dtype::float16>(dev_ctx,
339347 src,
@@ -347,7 +355,8 @@ void CallFusedQkvRopeKernel(
347355 head_dim,
348356 num_head,
349357 total_batch,
350- transpose);
358+ transpose,
359+ use_neox_style);
351360 } else if (src.dtype () == phi::DataType::BFLOAT16) {
352361 custom_kernel::FusedQkvRopeKernel<phi::dtype::bfloat16>(dev_ctx,
353362 src,
@@ -361,7 +370,8 @@ void CallFusedQkvRopeKernel(
361370 head_dim,
362371 num_head,
363372 total_batch,
364- transpose);
373+ transpose,
374+ use_neox_style);
365375 } else {
366376 throw std::runtime_error (" Unsupported data type for FusedQkvRopeKernel" );
367377 }
@@ -375,7 +385,8 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
375385 int head_dim,
376386 int num_head,
377387 int total_batch,
378- bool transpose) {
388+ bool transpose,
389+ bool use_neox_style) {
379390 auto dev_ctx = static_cast <const phi::CustomContext*>(
380391 paddle::experimental::DeviceContextPool::Instance ().Get (src.place ()));
381392 auto src_tensor = static_cast <const phi::DenseTensor*>(src.impl ().get ());
@@ -422,7 +433,8 @@ std::vector<paddle::Tensor> FusedQkvRopeImpl(
422433 phi::Scalar (head_dim),
423434 phi::Scalar (num_head),
424435 phi::Scalar (total_batch),
425- phi::Scalar (transpose));
436+ phi::Scalar (transpose),
437+ phi::Scalar (use_neox_style));
426438 return {paddle::Tensor (query_states), paddle::Tensor (key_value_states)};
427439}
428440
@@ -434,7 +446,8 @@ std::vector<std::vector<int64_t>> FusedQkvRopeShape(
434446 int head_dim,
435447 int num_head,
436448 int total_batch,
437- bool transpose) {
449+ bool transpose,
450+ bool use_neox_style) {
438451 int64_t bsz = src_shape[0 ];
439452 int64_t seq_len = bsz / total_batch;
440453 int64_t fused_hidden_size =
@@ -459,7 +472,8 @@ PD_BUILD_OP(fused_qkv_rope)
459472 .Attrs({" head_dim: int" ,
460473 " num_head: int" ,
461474 " total_batch: int" ,
462- " transpose: bool" })
475+ " transpose: bool" ,
476+ " use_neox_style: bool" })
463477 .SetKernelFn(PD_KERNEL(FusedQkvRopeImpl))
464478 .SetInferShapeFn(PD_INFER_SHAPE(FusedQkvRopeShape))
465479 .SetInferDtypeFn(PD_INFER_DTYPE(FusedQkvRopeDtype));
@@ -474,7 +488,8 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
474488 int head_dim,
475489 int num_head,
476490 int total_batch,
477- bool transpose) {
491+ bool transpose,
492+ bool use_neox_style) {
478493 auto dev_ctx = static_cast <const phi::CustomContext*>(
479494 paddle::experimental::DeviceContextPool::Instance ().Get (src.place ()));
480495 auto src_tensor = static_cast <const phi::DenseTensor*>(src.impl ().get ());
@@ -528,7 +543,8 @@ std::vector<paddle::Tensor> FusedFp8QkvRopeImpl(
528543 phi::Scalar (head_dim),
529544 phi::Scalar (num_head),
530545 phi::Scalar (total_batch),
531- phi::Scalar (transpose));
546+ phi::Scalar (transpose),
547+ phi::Scalar (use_neox_style));
532548 return {paddle::Tensor (query_states), paddle::Tensor (key_value_states)};
533549}
534550
@@ -542,7 +558,8 @@ std::vector<std::vector<int64_t>> FusedFp8QkvRopeShape(
542558 int head_dim,
543559 int num_head,
544560 int total_batch,
545- bool transpose) {
561+ bool transpose,
562+ bool use_neox_style) {
546563 int64_t bsz = src_shape[0 ];
547564 int64_t seq_len = bsz / total_batch;
548565 int64_t fused_hidden_size =
@@ -562,7 +579,7 @@ std::vector<paddle::DataType> FusedFp8QkvRopeDtype(
562579 return {src_dtype, src_dtype};
563580}
564581
565- PD_BUILD_OP (fused_fp8_qkv_rope_t )
582+ PD_BUILD_OP (fused_fp8_qkv_rope )
566583 .Inputs({" src" ,
567584 " qkv_weights" ,
568585 paddle::Optional (" qkv_biases" ),
@@ -573,7 +590,8 @@ PD_BUILD_OP(fused_fp8_qkv_rope_t)
573590 .Attrs({" head_dim: int" ,
574591 " num_head: int" ,
575592 " total_batch: int" ,
576- " transpose: bool" })
593+ " transpose: bool" ,
594+ " use_neox_style: bool" })
577595 .SetKernelFn(PD_KERNEL(FusedFp8QkvRopeImpl))
578596 .SetInferShapeFn(PD_INFER_SHAPE(FusedFp8QkvRopeShape))
579597 .SetInferDtypeFn(PD_INFER_DTYPE(FusedFp8QkvRopeDtype));
0 commit comments