1717#include " tiling/platform/platform_ascendc.h"
1818#include " tiling/hccl/hccl_tiling.h"
1919
20- #define GM_ALIGN_SIZE 512
21- #define ENABLE_TILING_CHECK
22-
2320using namespace ge ;
2421namespace {
2522constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8 ;
2623constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024 ;
24+ constexpr uint32_t GM_ALIGN_SIZE = 512 ;
2725constexpr uint32_t TOKEN_DTYPE_BYTE_SIZE = 2 ;
2826constexpr uint32_t USE_CORE_NUM = 24 ;
2927constexpr uint32_t L1_TILE_BYTE_SIZE = 32 * 1024 ;
@@ -78,50 +76,50 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
7876
7977 uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank;
8078 const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape (INPUT_GMM1_WEIGHT_INDEX);
81- OP_TILING_CHECK (gmm1WeightStorageShape == nullptr , OP_LOGE (nodeName, " gmm1 weight shape is null." ),
79+ OPS_ERR_IF (gmm1WeightStorageShape == nullptr , OP_LOGE (nodeName, " gmm1 weight shape is null." ),
8280 return ge::GRAPH_FAILED);
8381 const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape ().GetDim (0 );
84- OP_TILING_CHECK (gmm1WeightDim0 != localExpertNum,
82+ OPS_ERR_IF (gmm1WeightDim0 != localExpertNum,
8583 OP_LOGE (nodeName, " gmm1Weight Dim0 must be expert number in current rank." ),
8684 return ge::GRAPH_FAILED);
8785
8886 const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape (INPUT_GMM1_WEIGHT_SCALE_INDEX);
89- OP_TILING_CHECK (gmm1WeightScaleStorageShape == nullptr , OP_LOGE (nodeName, " gmm1 weight scale shape is null." ),
87+ OPS_ERR_IF (gmm1WeightScaleStorageShape == nullptr , OP_LOGE (nodeName, " gmm1 weight scale shape is null." ),
9088 return ge::GRAPH_FAILED);
91- OP_TILING_CHECK (gmm1WeightScaleStorageShape->GetStorageShape ().GetDimNum () != TWO_DIMS,
89+ OPS_ERR_IF (gmm1WeightScaleStorageShape->GetStorageShape ().GetDimNum () != TWO_DIMS,
9290 OP_LOGE (nodeName, " gmm1 weight scale shape dims must be 2, but current dim num is %lu." ,
9391 gmm1WeightScaleStorageShape->GetStorageShape ().GetDimNum ()),
9492 return ge::GRAPH_FAILED);
9593 const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape ().GetDim (0 );
96- OP_TILING_CHECK (gmm1WeightScaleDim0 != localExpertNum,
94+ OPS_ERR_IF (gmm1WeightScaleDim0 != localExpertNum,
9795 OP_LOGE (nodeName, " gmm1WeightScale Dim0 must be expert number in current rank." ),
9896 return ge::GRAPH_FAILED);
9997 const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape ().GetDim (1 );
100- OP_TILING_CHECK (gmm1WeightScaleDim1 != gmm1WeightDim2,
98+ OPS_ERR_IF (gmm1WeightScaleDim1 != gmm1WeightDim2,
10199 OP_LOGE (nodeName, " gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2)." , gmm1WeightDim2),
102100 return ge::GRAPH_FAILED);
103101
104102 const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape (INPUT_GMM2_WEIGHT_INDEX);
105- OP_TILING_CHECK (gmm2WeightStorageShape == nullptr , OP_LOGE (nodeName, " gmm2 weight shape is null." ),
103+ OPS_ERR_IF (gmm2WeightStorageShape == nullptr , OP_LOGE (nodeName, " gmm2 weight shape is null." ),
106104 return ge::GRAPH_FAILED);
107105 const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape ().GetDim (0 );
108- OP_TILING_CHECK (gmm2WeightDim0 != localExpertNum,
106+ OPS_ERR_IF (gmm2WeightDim0 != localExpertNum,
109107 OP_LOGE (nodeName, " gmm2Weight Dim0 must be expert number in current rank." ),
110108 return ge::GRAPH_FAILED);
111109
112110 const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape (INPUT_GMM2_WEIGHT_SCALE_INDEX);
113- OP_TILING_CHECK (gmm2WeightScaleStorageShape == nullptr , OP_LOGE (nodeName, " gmm2 weight scale shape is null." ),
111+ OPS_ERR_IF (gmm2WeightScaleStorageShape == nullptr , OP_LOGE (nodeName, " gmm2 weight scale shape is null." ),
114112 return ge::GRAPH_FAILED);
115- OP_TILING_CHECK (gmm2WeightScaleStorageShape->GetStorageShape ().GetDimNum () != TWO_DIMS,
113+ OPS_ERR_IF (gmm2WeightScaleStorageShape->GetStorageShape ().GetDimNum () != TWO_DIMS,
116114 OP_LOGE (nodeName, " gmm2 weight scale shape dims must be 2, but current dim num is %lu." ,
117115 gmm2WeightScaleStorageShape->GetStorageShape ().GetDimNum ()),
118116 return ge::GRAPH_FAILED);
119117 const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape ().GetDim (0 );
120- OP_TILING_CHECK (gmm2WeightScaleDim0 != localExpertNum,
118+ OPS_ERR_IF (gmm2WeightScaleDim0 != localExpertNum,
121119 OP_LOGE (nodeName, " gmm2WeightScale Dim0 must be expert number in current rank." ),
122120 return ge::GRAPH_FAILED);
123121 const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape ().GetDim (1 );
124- OP_TILING_CHECK (gmm2WeightScaleDim1 != h, OP_LOGE (nodeName, " gmm2WeightScale Dim1 must be %u." , h),
122+ OPS_ERR_IF (gmm2WeightScaleDim1 != h, OP_LOGE (nodeName, " gmm2WeightScale Dim1 must be %u." , h),
125123 return ge::GRAPH_FAILED);
126124
127125 return ge::GRAPH_SUCCESS;
@@ -130,31 +128,31 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
130128static ge::graphStatus CheckData (const char *nodeName, DispatchGmmCombineDecodeTilingData &tilingData)
131129{
132130 uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo .bs ;
133- OP_TILING_CHECK (batchSize < MIN_BATCH_SIZE, OP_LOGE (nodeName, " batchSize(bs) must >= %d." , MIN_BATCH_SIZE),
131+ OPS_ERR_IF (batchSize < MIN_BATCH_SIZE, OP_LOGE (nodeName, " batchSize(bs) must >= %d." , MIN_BATCH_SIZE),
134132 return ge::GRAPH_FAILED);
135- OP_TILING_CHECK (batchSize > MAX_BATCH_SIZE, OP_LOGE (nodeName, " batchSize(bs) must <= %d." , MAX_BATCH_SIZE),
133+ OPS_ERR_IF (batchSize > MAX_BATCH_SIZE, OP_LOGE (nodeName, " batchSize(bs) must <= %d." , MAX_BATCH_SIZE),
136134 return ge::GRAPH_FAILED);
137135 uint32_t tokenLength = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo .h ;
138- OP_TILING_CHECK (
136+ OPS_ERR_IF (
139137 tokenLength < MIN_TOKEN_LENGTH || tokenLength > MAX_TOKEN_LENGTH,
140138 OP_LOGE (nodeName, " tokenLength(h) is invalid. Only support [%u, %u]." , MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH),
141139 return ge::GRAPH_FAILED);
142140 uint32_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo .gmm1HLen ;
143- OP_TILING_CHECK (
141+ OPS_ERR_IF (
144142 gmm1HLen < MIN_GMM1_HIDDEN || gmm1HLen > MAX_GMM1_HIDDEN,
145143 OP_LOGE (nodeName, " gmm1 hidden size is invalid. Only support [%u, %u]." , MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN),
146144 return ge::GRAPH_FAILED);
147145 uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo .k ;
148- OP_TILING_CHECK (topK > SUPPORT_TOP_K, OP_LOGE (nodeName, " topK(k) must <= %d." , SUPPORT_TOP_K),
146+ OPS_ERR_IF (topK > SUPPORT_TOP_K, OP_LOGE (nodeName, " topK(k) must <= %d." , SUPPORT_TOP_K),
149147 return ge::GRAPH_FAILED);
150148 uint32_t globalBatchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo .globalBs ;
151149 uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo .epRankSize ;
152150 if (globalBatchSize == 0 ) {
153151 globalBatchSize = epRankSize * batchSize;
154152 tilingData.disGmmDeqSwigluQuantGmmDeqComInfo .globalBs = globalBatchSize;
155153 } else {
156- OP_TILING_CHECK (globalBatchSize < 0 , OP_LOGE (nodeName, " globalBatchSize must >= 0." ), return ge::GRAPH_FAILED);
157- OP_TILING_CHECK (globalBatchSize % epRankSize > 0 ,
154+ OPS_ERR_IF (globalBatchSize < 0 , OP_LOGE (nodeName, " globalBatchSize must >= 0." ), return ge::GRAPH_FAILED);
155+ OPS_ERR_IF (globalBatchSize % epRankSize > 0 ,
158156 OP_LOGE (nodeName, " globalBatchSize must be divisible by epRankSize." ), return ge::GRAPH_FAILED);
159157 }
160158
@@ -165,7 +163,7 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
165163 DispatchGmmCombineDecodeTilingData &tilingData, std::string &groupEp)
166164{
167165 auto attrs = context->GetAttrs ();
168- OP_TILING_CHECK (attrs == nullptr , OP_LOGE (nodeName, " attrs is nullptr." ), return ge::GRAPH_FAILED);
166+ OPS_ERR_IF (attrs == nullptr , OP_LOGE (nodeName, " attrs is nullptr." ), return ge::GRAPH_FAILED);
169167
170168 auto groupEpPtr = attrs->GetAttrPointer <char >(static_cast <int >(ATTR_GROUP_EP_INDEX));
171169 auto epRankSizePtr = attrs->GetAttrPointer <int64_t >(ATTR_EP_RANK_SIZE_INDEX);
@@ -184,16 +182,16 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
184182 uint32_t moeExpertNumPerRank = moeExpertNum / (epRankSize - sharedExpertRankNum);
185183
186184#ifdef ENABLE_TILING_CHECK
187- OP_TILING_CHECK (epRankId < 0 , OP_LOGE (nodeName, " epRankId must >= 0." ), return ge::GRAPH_FAILED);
188- OP_TILING_CHECK (epRankId >= epRankSize, OP_LOGE (nodeName, " epRankId must < epRankSize." ), return ge::GRAPH_FAILED);
189- OP_TILING_CHECK (moeExpertNum > MAX_MOE_EXERT_NUM, OP_LOGE (nodeName, " moeExpertNum must <= %d." , MAX_MOE_EXERT_NUM),
185+ OPS_ERR_IF (epRankId < 0 , OP_LOGE (nodeName, " epRankId must >= 0." ), return ge::GRAPH_FAILED);
186+ OPS_ERR_IF (epRankId >= epRankSize, OP_LOGE (nodeName, " epRankId must < epRankSize." ), return ge::GRAPH_FAILED);
187+ OPS_ERR_IF (moeExpertNum > MAX_MOE_EXERT_NUM, OP_LOGE (nodeName, " moeExpertNum must <= %d." , MAX_MOE_EXERT_NUM),
190188 return ge::GRAPH_FAILED);
191- OP_TILING_CHECK (moeExpertNum <= 0 , OP_LOGE (nodeName, " moeExpertNum must > 0." ), return ge::GRAPH_FAILED);
192- OP_TILING_CHECK (sharedExpertNum != 1 , OP_LOGE (nodeName, " sharedExpertNum must be 1." ), return ge::GRAPH_FAILED);
193- OP_TILING_CHECK (moeExpertNum % (epRankSize - sharedExpertRankNum) != 0 ,
189+ OPS_ERR_IF (moeExpertNum <= 0 , OP_LOGE (nodeName, " moeExpertNum must > 0." ), return ge::GRAPH_FAILED);
190+ OPS_ERR_IF (sharedExpertNum != 1 , OP_LOGE (nodeName, " sharedExpertNum must be 1." ), return ge::GRAPH_FAILED);
191+ OPS_ERR_IF (moeExpertNum % (epRankSize - sharedExpertRankNum) != 0 ,
194192 OP_LOGE (nodeName, " moeExpertNum must be divisible by (epRankSize - sharedExpertRankNum)." ),
195193 return ge::GRAPH_FAILED);
196- OP_TILING_CHECK (moeExpertNumPerRank > RECV_AIV_NUM,
194+ OPS_ERR_IF (moeExpertNumPerRank > RECV_AIV_NUM,
197195 OP_LOGE (nodeName, " moeExpertNumPerRank must <= %d." , RECV_AIV_NUM), return ge::GRAPH_FAILED);
198196#endif
199197
@@ -226,7 +224,7 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
226224 DispatchGmmCombineDecodeTilingData &tilingData)
227225{
228226 size_t *workSpaces = context->GetWorkspaceSizes (1 );
229- OP_TILING_CHECK (workSpaces == nullptr , OP_LOGE (nodeName, " workSpaces is nullptr." ), return ge::GRAPH_FAILED);
227+ OPS_ERR_IF (workSpaces == nullptr , OP_LOGE (nodeName, " workSpaces is nullptr." ), return ge::GRAPH_FAILED);
230228 size_t maxTokenNum;
231229 uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo .epRankSize ;
232230 uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo .epRankId ;
@@ -267,12 +265,12 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
267265{
268266 const char *nodeName = context->GetNodeName ();
269267 DispatchGmmCombineDecodeTilingData *tilingData = context->GetTilingData <DispatchGmmCombineDecodeTilingData>();
270- OP_TILING_CHECK (tilingData == nullptr , OP_LOGE (nodeName, " tilingData is nullptr." ), return ge::GRAPH_FAILED);
268+ OPS_ERR_IF (tilingData == nullptr , OP_LOGE (nodeName, " tilingData is nullptr." ), return ge::GRAPH_FAILED);
271269 std::string groupEp = " " ;
272270
273271 const gert::StorageShape *xStorageShape = context->GetInputShape (INPUT_X_INDEX);
274- OP_TILING_CHECK (xStorageShape == nullptr , OP_LOGE (nodeName, " x shape is null." ), return ge::GRAPH_FAILED);
275- OP_TILING_CHECK (xStorageShape->GetStorageShape ().GetDimNum () != TWO_DIMS,
272+ OPS_ERR_IF (xStorageShape == nullptr , OP_LOGE (nodeName, " x shape is null." ), return ge::GRAPH_FAILED);
273+ OPS_ERR_IF (xStorageShape->GetStorageShape ().GetDimNum () != TWO_DIMS,
276274 OP_LOGE (nodeName, " x shape dims must be 2, but current dim num is %lu." ,
277275 xStorageShape->GetStorageShape ().GetDimNum ()),
278276 return ge::GRAPH_FAILED);
@@ -282,25 +280,25 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
282280 tilingData->disGmmDeqSwigluQuantGmmDeqComInfo .h = hiddenSize;
283281
284282 const gert::StorageShape *expertIdsStorageShape = context->GetInputShape (INPUT_EXPERT_IDS_INDEX);
285- OP_TILING_CHECK (expertIdsStorageShape == nullptr , OP_LOGE (nodeName, " expertIds shape is null." ),
283+ OPS_ERR_IF (expertIdsStorageShape == nullptr , OP_LOGE (nodeName, " expertIds shape is null." ),
286284 return ge::GRAPH_FAILED);
287- OP_TILING_CHECK (expertIdsStorageShape->GetStorageShape ().GetDimNum () != TWO_DIMS,
285+ OPS_ERR_IF (expertIdsStorageShape->GetStorageShape ().GetDimNum () != TWO_DIMS,
288286 OP_LOGE (nodeName, " expertIds shape dims must be 2, but current dim num is %lu." ,
289287 expertIdsStorageShape->GetStorageShape ().GetDimNum ()),
290288 return ge::GRAPH_FAILED);
291289 const int64_t topK = expertIdsStorageShape->GetStorageShape ().GetDim (1 );
292290 tilingData->disGmmDeqSwigluQuantGmmDeqComInfo .k = topK;
293- OP_TILING_CHECK (GetAttrAndSetTilingData (context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
291+ OPS_ERR_IF (GetAttrAndSetTilingData (context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
294292 OP_LOGE (nodeName, " Get attr and set tiling data failed." ), return ge::GRAPH_FAILED);
295293 const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape (INPUT_GMM1_WEIGHT_INDEX);
296- OP_TILING_CHECK (gmm1WeightStorageShape == nullptr , OP_LOGE (nodeName, " gmm1Weight shape is null." ),
294+ OPS_ERR_IF (gmm1WeightStorageShape == nullptr , OP_LOGE (nodeName, " gmm1Weight shape is null." ),
297295 return ge::GRAPH_FAILED);
298296 tilingData->disGmmDeqSwigluQuantGmmDeqComInfo .gmm1HLen = gmm1WeightStorageShape->GetOriginShape ().GetDim (TWO_DIMS);
299297#ifdef ENABLE_TILING_CHECK
300- OP_TILING_CHECK (CheckData (nodeName, *tilingData) != ge::GRAPH_SUCCESS, OP_LOGE (nodeName, " CheckData failed." ),
298+ OPS_ERR_IF (CheckData (nodeName, *tilingData) != ge::GRAPH_SUCCESS, OP_LOGE (nodeName, " CheckData failed." ),
301299 return ge::GRAPH_FAILED);
302300#endif
303- OP_TILING_CHECK (SetWorkSpace (context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
301+ OPS_ERR_IF (SetWorkSpace (context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
304302 OP_LOGE (nodeName, " Tiling set workspace failed." ), return ge::GRAPH_FAILED);
305303 SetHcommCfg (context, tilingData, groupEp);
306304 if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo .moeExpertNumPerRank == 1 ) {
0 commit comments