Skip to content

Commit 5b47474

Browse files
committed
add disptach_ffn_combine kernel
Signed-off-by: mojave2 <[email protected]>
1 parent 18b90b5 commit 5b47474

File tree

57 files changed

+9753
-48
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+9753
-48
lines changed

.gitmodules

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[submodule "csrc/third_party/catlass"]
2+
path = csrc/third_party/catlass
3+
url = https://gitee.com/ascend/catlass.git
4+
branch = catlass-v1-stable

csrc/build_aclnn.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
1111
exit 0
1212
elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
1313
# ASCEND910B (A2) series
14-
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
14+
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
1515
SOC_ARG="ascend910b"
1616
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
1717
# ASCEND910C (A3) series
18-
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention"
18+
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
1919
SOC_ARG="ascend910_93"
2020
else
2121
# others
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
2+
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} ops_srcs)
3+
4+
opbuild(OPS_SRC ${ops_srcs}
5+
OUT_DIR ${ASCEND_AUTOGEN_PATH}
6+
)
7+
8+
file(GLOB group_proto_src ${ASCEND_AUTOGEN_PATH}/group_proto/*.cc)
9+
10+
add_library(cust_op_proto SHARED
11+
$<$<TARGET_EXISTS:group_proto_src>:${group_proto_src}>
12+
${ops_srcs}
13+
${ASCEND_AUTOGEN_PATH}/op_proto.cc
14+
)
15+
target_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB)
16+
target_compile_options(cust_op_proto PRIVATE
17+
-fvisibility=hidden
18+
)
19+
if(ENABLE_CROSS_COMPILE)
20+
target_link_directories(cust_op_proto PRIVATE
21+
${CMAKE_COMPILE_COMPILER_LIBRARY}
22+
${CMAKE_COMPILE_RUNTIME_LIBRARY}
23+
)
24+
endif()
25+
target_link_libraries(cust_op_proto PRIVATE
26+
intf_pub
27+
exe_graph
28+
register
29+
tiling_api
30+
-Wl,--whole-archive
31+
rt2_registry
32+
-Wl,--no-whole-archive
33+
)
34+
set_target_properties(cust_op_proto PROPERTIES OUTPUT_NAME
35+
cust_opsproto_rt2.0
36+
)
37+
file(GLOB fallback_src ${ASCEND_AUTOGEN_PATH}/fallback_*.cpp)
38+
add_library(cust_optiling SHARED ${ops_srcs})
39+
if (${fallback_src})
40+
target_sources(cust_optiling PRIVATE ${fallback_src})
41+
endif()
42+
target_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB)
43+
target_compile_options(cust_optiling PRIVATE
44+
-fvisibility=hidden
45+
)
46+
if(ENABLE_CROSS_COMPILE)
47+
target_link_directories(cust_optiling PRIVATE
48+
${CMAKE_COMPILE_COMPILER_LIBRARY}
49+
${CMAKE_COMPILE_RUNTIME_LIBRARY}
50+
)
51+
endif()
52+
target_link_libraries(cust_optiling PRIVATE
53+
nnopbase
54+
intf_pub
55+
exe_graph
56+
register
57+
tiling_api
58+
-Wl,--whole-archive
59+
rt2_registry
60+
-Wl,--no-whole-archive
61+
)
62+
set_target_properties(cust_optiling PROPERTIES OUTPUT_NAME
63+
cust_opmaster_rt2.0
64+
)
65+
66+
file(GLOB_RECURSE pregen_file
67+
"${CMAKE_CURRENT_SOURCE_DIR}/op_api/*"
68+
)
69+
70+
file(COPY ${pregen_file} DESTINATION ${ASCEND_AUTOGEN_PATH})
71+
file(GLOB aclnn_src ${ASCEND_AUTOGEN_PATH}/aclnn*.cpp)
72+
file(GLOB aclnn_inc ${ASCEND_AUTOGEN_PATH}/aclnn_*.h)
73+
if(NOT ASCEND_PACK_SHARED_LIBRARY)
74+
add_library(cust_opapi SHARED ${aclnn_src})
75+
else()
76+
file(GLOB op_registry ${ASCEND_AUTOGEN_PATH}/custom_op_registry.cpp)
77+
add_library(cust_opapi SHARED ${aclnn_src} ${op_registry})
78+
target_compile_definitions(cust_opapi PRIVATE ACLNN_WITH_BINARY)
79+
endif()
80+
if(ENABLE_CROSS_COMPILE)
81+
target_link_directories(cust_opapi PRIVATE
82+
${CMAKE_COMPILE_COMPILER_LIBRARY}
83+
${CMAKE_COMPILE_RUNTIME_LIBRARY}
84+
)
85+
endif()
86+
if(NOT ASCEND_PACK_SHARED_LIBRARY)
87+
target_link_libraries(cust_opapi PRIVATE intf_pub ascendcl nnopbase)
88+
else()
89+
add_library(cust_op_proto_obj OBJECT
90+
$<$<TARGET_EXISTS:group_proto_src>:${group_proto_src}>
91+
${ops_srcs}
92+
${ASCEND_AUTOGEN_PATH}/op_proto.cc
93+
)
94+
target_compile_definitions(cust_op_proto_obj PRIVATE OP_PROTO_LIB)
95+
target_compile_options(cust_op_proto_obj PRIVATE
96+
-fvisibility=hidden
97+
)
98+
if(ENABLE_CROSS_COMPILE)
99+
target_link_directories(cust_op_proto_obj PRIVATE
100+
${CMAKE_COMPILE_COMPILER_LIBRARY}
101+
${CMAKE_COMPILE_RUNTIME_LIBRARY}
102+
)
103+
endif()
104+
target_link_libraries(cust_op_proto_obj PRIVATE
105+
intf_pub
106+
exe_graph
107+
register
108+
tiling_api
109+
-Wl,--whole-archive
110+
rt2_registry
111+
-Wl,--no-whole-archive
112+
)
113+
add_library(cust_optiling_obj OBJECT ${ops_srcs})
114+
target_compile_definitions(cust_optiling_obj PRIVATE OP_TILING_LIB)
115+
target_compile_options(cust_optiling_obj PRIVATE
116+
-fvisibility=hidden
117+
)
118+
if(ENABLE_CROSS_COMPILE)
119+
target_link_directories(cust_optiling_obj PRIVATE
120+
${CMAKE_COMPILE_COMPILER_LIBRARY}
121+
${CMAKE_COMPILE_RUNTIME_LIBRARY}
122+
)
123+
endif()
124+
target_link_libraries(cust_optiling_obj PRIVATE
125+
intf_pub
126+
exe_graph
127+
register
128+
tiling_api
129+
-Wl,--whole-archive
130+
rt2_registry
131+
-Wl,--no-whole-archive
132+
)
133+
target_compile_options(cust_opapi PRIVATE -DLOG_CPP)
134+
target_include_directories(cust_opapi INTERFACE ${CMAKE_SOURCE_DIR}/build_out/library/)
135+
target_link_libraries(cust_opapi PRIVATE intf_pub ascendcl nnopbase cust_optiling_obj cust_op_proto_obj ascend_opregistry ascend_kernels)
136+
add_dependencies(cust_opapi ascend_opregistry)
137+
endif()
138+
139+
target_include_directories(cust_opapi PRIVATE
140+
$ENV{ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform/
141+
$ENV{ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform/
142+
)
143+
include_directories($ENV{ASCEND_HOME_PATH}/../opp/vendors/CAM/op_impl/ai_core/tbe/CAM_impl/dynamic/)
144+
145+
add_custom_target(optiling_compat ALL
146+
COMMAND ln -sf lib/linux/${CMAKE_SYSTEM_PROCESSOR}/$<TARGET_FILE_NAME:cust_optiling>
147+
${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so
148+
)
149+
if(NOT ASCEND_PACK_SHARED_LIBRARY)
150+
install(TARGETS cust_op_proto
151+
LIBRARY DESTINATION packages/vendors/${vendor_name}/op_proto/lib/linux/${CMAKE_SYSTEM_PROCESSOR})
152+
install(FILES ${ASCEND_AUTOGEN_PATH}/op_proto.h
153+
DESTINATION packages/vendors/${vendor_name}/op_proto/inc)
154+
file(GLOB GROUP_PROTO_HEADERS ${ASCEND_AUTOGEN_PATH}/group_proto/*.h)
155+
if (GROUP_PROTO_HEADERS)
156+
install(FILES ${GROUP_PROTO_HEADERS}
157+
DESTINATION packages/vendors/${vendor_name}/op_proto/inc)
158+
endif()
159+
install(TARGETS cust_optiling
160+
LIBRARY DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling/lib/linux/${CMAKE_SYSTEM_PROCESSOR})
161+
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so
162+
DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling)
163+
install(TARGETS cust_opapi
164+
LIBRARY DESTINATION packages/vendors/${vendor_name}/op_api/lib)
165+
install(FILES ${aclnn_inc}
166+
DESTINATION packages/vendors/${vendor_name}/op_api/include)
167+
else()
168+
file(GLOB group_inc ${ASCEND_AUTOGEN_PATH}/group_proto/*.h)
169+
install(TARGETS cust_opapi
170+
LIBRARY DESTINATION op_api/lib)
171+
install(FILES ${ASCEND_AUTOGEN_PATH}/op_proto.h
172+
DESTINATION op_api/include)
173+
install(FILES ${group_inc}
174+
DESTINATION op_api/include)
175+
install(FILES ${aclnn_inc}
176+
DESTINATION op_api/include)
177+
endif()
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/**
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
3+
* This file is a part of the CANN Open Software.
4+
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
5+
* Please refer to the License for details. You may not use this file except in compliance with the License.
6+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
7+
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
11+
/*!
12+
* \file dispatch_ffn_combine_def.cpp
13+
* \brief
14+
*/
15+
#include "register/op_def_registry.h"
16+
17+
namespace ops {
18+
class DispatchFFNCombine : public OpDef {
19+
public:
20+
explicit DispatchFFNCombine(const char *name) : OpDef(name) {
21+
this->Input("a")
22+
.ParamType(REQUIRED)
23+
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
24+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
25+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
26+
this->Input("w1")
27+
.ParamType(REQUIRED)
28+
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
29+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
30+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
31+
.IgnoreContiguous();
32+
this->Input("w2")
33+
.ParamType(REQUIRED)
34+
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
35+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
36+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ})
37+
.IgnoreContiguous();
38+
this->Input("expertIdx")
39+
.ParamType(REQUIRED)
40+
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
41+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
42+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
43+
this->Input("scale1")
44+
.ParamType(REQUIRED)
45+
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
46+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
47+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
48+
this->Input("scale2")
49+
.ParamType(REQUIRED)
50+
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
51+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
52+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
53+
this->Input("probs")
54+
.ParamType(REQUIRED)
55+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
56+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
57+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
58+
59+
// 输出
60+
this->Output("out")
61+
.ParamType(REQUIRED)
62+
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
63+
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
64+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND,ge::FORMAT_ND});
65+
66+
this->Attr("group").AttrType(REQUIRED).String();
67+
this->Attr("M").AttrType(OPTIONAL).Int();
68+
this->Attr("transB").AttrType(OPTIONAL).Bool(false);
69+
this->Attr("weightNz").AttrType(OPTIONAL).Bool(false);
70+
71+
OpAICoreConfig aicore_config;
72+
aicore_config.DynamicCompileStaticFlag(true)
73+
.DynamicFormatFlag(true)
74+
.DynamicRankSupportFlag(true)
75+
.DynamicShapeSupportFlag(true)
76+
.NeedCheckSupportFlag(false)
77+
.PrecisionReduceFlag(true)
78+
.ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
79+
.ExtendCfgInfo("jitCompile.flag", "static_false")
80+
.ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
81+
this->AICore().AddConfig("ascend910_93", aicore_config);
82+
// this->AICore().AddConfig("ascend910b", aicore_config);
83+
this->MC2().HcclGroup("group");
84+
}
85+
};
86+
87+
OP_ADD(DispatchFFNCombine);
88+
} // namespace ops
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/**
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
3+
* This file is a part of the CANN Open Software.
4+
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
5+
* Please refer to the License for details. You may not use this file except in compliance with the License.
6+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
7+
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
11+
/*!
12+
* \file dispatch_ffn_proto.cpp
13+
* \brief
14+
*/
15+
#include <graph/utils/type_utils.h>
16+
#include <register/op_impl_registry.h>
17+
// #include "../../common/ophost/op_util.h"
18+
// #include "../../common/ophost/hcom_topo_info.h"
19+
// #include "log/ops_log.h"
20+
21+
using namespace ge;
22+
namespace ops {
23+
const size_t ATTR_GROUP = 0;
24+
const size_t ATTR_RANK_SIZE = 1;
25+
const size_t SUPPORT_DIM_SIZE = 2;
26+
27+
static ge::graphStatus InferShapeDispatchFFNCombine(gert::InferShapeContext* context) {
28+
return ge::GRAPH_SUCCESS;
29+
}
30+
31+
static ge::graphStatus InferDataTypeDispatchFFNCombine(gert::InferDataTypeContext* context) {
32+
// auto d_type = context->GetInputDataType(0);
33+
// context->SetOutputDataType(0, d_type);
34+
return ge::GRAPH_SUCCESS;
35+
}
36+
37+
IMPL_OP_INFERSHAPE(DispatchFFNCombine)
38+
.InferShape(InferShapeDispatchFFNCombine)
39+
.InferDataType(InferDataTypeDispatchFFNCombine);
40+
} // namespace ops

0 commit comments

Comments
 (0)