Skip to content

Commit b2838bd

Browse files
authored
Fix set_value bug and change kernel (#1835)
1 parent aad45d5 commit b2838bd

File tree

9 files changed

+8
-6
lines changed

9 files changed

+8
-6
lines changed

backends/metax_gpu/kernels/cuda_kernels/c_embedding_grad_kernel_register.cu

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ PD_CUSTOM_KERNEL_REGISTER(c_embedding_grad,
2323
phi::CEmbeddingGradKernel,
2424
float,
2525
double,
26+
phi::dtype::bfloat16,
2627
phi::dtype::float16,
2728
phi::dtype::complex<float>,
2829
phi::dtype::complex<double>) {}

backends/metax_gpu/kernels/cuda_kernels/c_embedding_kernel_register.cu

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ PD_CUSTOM_KERNEL_REGISTER(c_embedding,
2323
phi::CEmbeddingKernel,
2424
float,
2525
double,
26+
phi::dtype::bfloat16,
2627
phi::dtype::float16,
2728
phi::dtype::complex<float>,
2829
phi::dtype::complex<double>) {}

backends/metax_gpu/kernels/cuda_kernels/cast_kernel_register.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
#include "paddle/phi/kernels/cast_kernel.h"
1717

1818
PD_CUSTOM_KERNEL_REGISTER(cast,
19-
iluvatar_gpu,
19+
metax_gpu,
2020
ALL_LAYOUT,
2121
phi::CastKernel,
2222
float,

backends/metax_gpu/kernels/cuda_kernels/set_value_kernel_register.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
PD_CUSTOM_KERNEL_REGISTER(set_value,
2121
metax_gpu,
2222
ALL_LAYOUT,
23-
phi::SetValueKernelV2,
23+
phi::SetValueKernel,
2424
float,
2525
double,
2626
int,
@@ -36,7 +36,7 @@ PD_CUSTOM_KERNEL_REGISTER(set_value,
3636
PD_CUSTOM_KERNEL_REGISTER(set_value_with_tensor,
3737
metax_gpu,
3838
ALL_LAYOUT,
39-
phi::SetTensorValueKernelV2,
39+
phi::SetTensorValueKernel,
4040
float,
4141
double,
4242
int,

backends/metax_gpu/kernels/metax_kerenl/blha_get_max_len_register.cu renamed to backends/metax_gpu/kernels/metax_kernel/blha_get_max_len_register.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "kernels/flash_attn_utils.h"
16-
#include "kernels/metax_kerenl/block_attn.h"
16+
#include "kernels/metax_kernel/block_attn.h"
1717
#include "paddle/phi/backends/context_pool.h"
1818
#include "paddle/phi/core/dense_tensor.h"
1919
#include "paddle/phi/core/kernel_registry.h"

backends/metax_gpu/kernels/metax_kerenl/block_attn.h renamed to backends/metax_gpu/kernels/metax_kernel/block_attn.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
#pragma once
1616

1717
#include "kernels/funcs/quant_dequant.h"
18-
#include "kernels/metax_kerenl/mmha_util.cu.h"
18+
#include "kernels/metax_kernel/mmha_util.cu.h"
1919
#include "paddle/common/flags.h"
2020
#include "paddle/phi/backends/gpu/gpu_context.h"
2121
#include "paddle/phi/common/memory_utils.h"

backends/metax_gpu/kernels/metax_kerenl/mmha_util.cu.h renamed to backends/metax_gpu/kernels/metax_kernel/mmha_util.cu.h

File renamed without changes.

backends/metax_gpu/kernels/metax_kerenl/quant_dequant.h renamed to backends/metax_gpu/kernels/metax_kernel/quant_dequant.h

File renamed without changes.

backends/metax_gpu/tests/unittest/test_greater_equal_op_metax.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ def test_api_fp16(self):
3030
limit = paddle.to_tensor([3, 2], dtype="float16")
3131
out = paddle.greater_equal(x=label, y=limit)
3232
if core.is_compiled_with_cuda():
33-
place = paddle.CustomPlace("iluvatar_gpu", 0)
33+
place = paddle.CustomPlace("metax_gpu", 0)
3434
exe = static.Executor(place)
3535
(res,) = exe.run(fetch_list=[out])
3636
self.assertEqual((res == np.array([True, True])).all(), True)

0 commit comments

Comments
 (0)