Skip to content

Commit 5779f2e

Browse files
committed
修复文件冲突
Signed-off-by: wangqiankun <[email protected]>
1 parent 6735410 commit 5779f2e

File tree

3 files changed

+55
-610
lines changed

3 files changed

+55
-610
lines changed

csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,11 @@
1717
#include "tiling/platform/platform_ascendc.h"
1818
#include "tiling/hccl/hccl_tiling.h"
1919

20-
#define GM_ALIGN_SIZE 512
21-
#define ENABLE_TILING_CHECK
22-
2320
using namespace ge;
2421
namespace {
2522
constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8;
2623
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
24+
constexpr uint32_t GM_ALIGN_SIZE = 512;
2725
constexpr uint32_t TOKEN_DTYPE_BYTE_SIZE = 2;
2826
constexpr uint32_t USE_CORE_NUM = 24;
2927
constexpr uint32_t L1_TILE_BYTE_SIZE = 32 * 1024;
@@ -78,50 +76,50 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
7876

7977
uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank;
8078
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
81-
OP_TILING_CHECK(gmm1WeightStorageShape == nullptr, OP_LOGE(nodeName, "gmm1 weight shape is null."),
79+
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OP_LOGE(nodeName, "gmm1 weight shape is null."),
8280
return ge::GRAPH_FAILED);
8381
const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0);
84-
OP_TILING_CHECK(gmm1WeightDim0 != localExpertNum,
82+
OPS_ERR_IF(gmm1WeightDim0 != localExpertNum,
8583
OP_LOGE(nodeName, "gmm1Weight Dim0 must be expert number in current rank."),
8684
return ge::GRAPH_FAILED);
8785

8886
const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX);
89-
OP_TILING_CHECK(gmm1WeightScaleStorageShape == nullptr, OP_LOGE(nodeName, "gmm1 weight scale shape is null."),
87+
OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OP_LOGE(nodeName, "gmm1 weight scale shape is null."),
9088
return ge::GRAPH_FAILED);
91-
OP_TILING_CHECK(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
89+
OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
9290
OP_LOGE(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.",
9391
gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()),
9492
return ge::GRAPH_FAILED);
9593
const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0);
96-
OP_TILING_CHECK(gmm1WeightScaleDim0 != localExpertNum,
94+
OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum,
9795
OP_LOGE(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."),
9896
return ge::GRAPH_FAILED);
9997
const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1);
100-
OP_TILING_CHECK(gmm1WeightScaleDim1 != gmm1WeightDim2,
98+
OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2,
10199
OP_LOGE(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2),
102100
return ge::GRAPH_FAILED);
103101

104102
const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX);
105-
OP_TILING_CHECK(gmm2WeightStorageShape == nullptr, OP_LOGE(nodeName, "gmm2 weight shape is null."),
103+
OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OP_LOGE(nodeName, "gmm2 weight shape is null."),
106104
return ge::GRAPH_FAILED);
107105
const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0);
108-
OP_TILING_CHECK(gmm2WeightDim0 != localExpertNum,
106+
OPS_ERR_IF(gmm2WeightDim0 != localExpertNum,
109107
OP_LOGE(nodeName, "gmm2Weight Dim0 must be expert number in current rank."),
110108
return ge::GRAPH_FAILED);
111109

112110
const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX);
113-
OP_TILING_CHECK(gmm2WeightScaleStorageShape == nullptr, OP_LOGE(nodeName, "gmm2 weight scale shape is null."),
111+
OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OP_LOGE(nodeName, "gmm2 weight scale shape is null."),
114112
return ge::GRAPH_FAILED);
115-
OP_TILING_CHECK(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
113+
OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
116114
OP_LOGE(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.",
117115
gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()),
118116
return ge::GRAPH_FAILED);
119117
const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0);
120-
OP_TILING_CHECK(gmm2WeightScaleDim0 != localExpertNum,
118+
OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum,
121119
OP_LOGE(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."),
122120
return ge::GRAPH_FAILED);
123121
const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1);
124-
OP_TILING_CHECK(gmm2WeightScaleDim1 != h, OP_LOGE(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
122+
OPS_ERR_IF(gmm2WeightScaleDim1 != h, OP_LOGE(nodeName, "gmm2WeightScale Dim1 must be %u.", h),
125123
return ge::GRAPH_FAILED);
126124

127125
return ge::GRAPH_SUCCESS;
@@ -130,31 +128,31 @@ static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char
130128
static ge::graphStatus CheckData(const char *nodeName, DispatchGmmCombineDecodeTilingData &tilingData)
131129
{
132130
uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs;
133-
OP_TILING_CHECK(batchSize < MIN_BATCH_SIZE, OP_LOGE(nodeName, "batchSize(bs) must >= %d.", MIN_BATCH_SIZE),
131+
OPS_ERR_IF(batchSize < MIN_BATCH_SIZE, OP_LOGE(nodeName, "batchSize(bs) must >= %d.", MIN_BATCH_SIZE),
134132
return ge::GRAPH_FAILED);
135-
OP_TILING_CHECK(batchSize > MAX_BATCH_SIZE, OP_LOGE(nodeName, "batchSize(bs) must <= %d.", MAX_BATCH_SIZE),
133+
OPS_ERR_IF(batchSize > MAX_BATCH_SIZE, OP_LOGE(nodeName, "batchSize(bs) must <= %d.", MAX_BATCH_SIZE),
136134
return ge::GRAPH_FAILED);
137135
uint32_t tokenLength = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h;
138-
OP_TILING_CHECK(
136+
OPS_ERR_IF(
139137
tokenLength < MIN_TOKEN_LENGTH || tokenLength > MAX_TOKEN_LENGTH,
140138
OP_LOGE(nodeName, "tokenLength(h) is invalid. Only support [%u, %u].", MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH),
141139
return ge::GRAPH_FAILED);
142140
uint32_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen;
143-
OP_TILING_CHECK(
141+
OPS_ERR_IF(
144142
gmm1HLen < MIN_GMM1_HIDDEN || gmm1HLen > MAX_GMM1_HIDDEN,
145143
OP_LOGE(nodeName, "gmm1 hidden size is invalid. Only support [%u, %u].", MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN),
146144
return ge::GRAPH_FAILED);
147145
uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k;
148-
OP_TILING_CHECK(topK > SUPPORT_TOP_K, OP_LOGE(nodeName, "topK(k) must <= %d.", SUPPORT_TOP_K),
146+
OPS_ERR_IF(topK > SUPPORT_TOP_K, OP_LOGE(nodeName, "topK(k) must <= %d.", SUPPORT_TOP_K),
149147
return ge::GRAPH_FAILED);
150148
uint32_t globalBatchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs;
151149
uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
152150
if (globalBatchSize == 0) {
153151
globalBatchSize = epRankSize * batchSize;
154152
tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = globalBatchSize;
155153
} else {
156-
OP_TILING_CHECK(globalBatchSize < 0, OP_LOGE(nodeName, "globalBatchSize must >= 0."), return ge::GRAPH_FAILED);
157-
OP_TILING_CHECK(globalBatchSize % epRankSize > 0,
154+
OPS_ERR_IF(globalBatchSize < 0, OP_LOGE(nodeName, "globalBatchSize must >= 0."), return ge::GRAPH_FAILED);
155+
OPS_ERR_IF(globalBatchSize % epRankSize > 0,
158156
OP_LOGE(nodeName, "globalBatchSize must be divisible by epRankSize."), return ge::GRAPH_FAILED);
159157
}
160158

@@ -165,7 +163,7 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
165163
DispatchGmmCombineDecodeTilingData &tilingData, std::string &groupEp)
166164
{
167165
auto attrs = context->GetAttrs();
168-
OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
166+
OPS_ERR_IF(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED);
169167

170168
auto groupEpPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_EP_INDEX));
171169
auto epRankSizePtr = attrs->GetAttrPointer<int64_t>(ATTR_EP_RANK_SIZE_INDEX);
@@ -184,16 +182,16 @@ static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, con
184182
uint32_t moeExpertNumPerRank = moeExpertNum / (epRankSize - sharedExpertRankNum);
185183

186184
#ifdef ENABLE_TILING_CHECK
187-
OP_TILING_CHECK(epRankId < 0, OP_LOGE(nodeName, "epRankId must >= 0."), return ge::GRAPH_FAILED);
188-
OP_TILING_CHECK(epRankId >= epRankSize, OP_LOGE(nodeName, "epRankId must < epRankSize."), return ge::GRAPH_FAILED);
189-
OP_TILING_CHECK(moeExpertNum > MAX_MOE_EXERT_NUM, OP_LOGE(nodeName, "moeExpertNum must <= %d.", MAX_MOE_EXERT_NUM),
185+
OPS_ERR_IF(epRankId < 0, OP_LOGE(nodeName, "epRankId must >= 0."), return ge::GRAPH_FAILED);
186+
OPS_ERR_IF(epRankId >= epRankSize, OP_LOGE(nodeName, "epRankId must < epRankSize."), return ge::GRAPH_FAILED);
187+
OPS_ERR_IF(moeExpertNum > MAX_MOE_EXERT_NUM, OP_LOGE(nodeName, "moeExpertNum must <= %d.", MAX_MOE_EXERT_NUM),
190188
return ge::GRAPH_FAILED);
191-
OP_TILING_CHECK(moeExpertNum <= 0, OP_LOGE(nodeName, "moeExpertNum must > 0."), return ge::GRAPH_FAILED);
192-
OP_TILING_CHECK(sharedExpertNum != 1, OP_LOGE(nodeName, "sharedExpertNum must be 1."), return ge::GRAPH_FAILED);
193-
OP_TILING_CHECK(moeExpertNum % (epRankSize - sharedExpertRankNum) != 0,
189+
OPS_ERR_IF(moeExpertNum <= 0, OP_LOGE(nodeName, "moeExpertNum must > 0."), return ge::GRAPH_FAILED);
190+
OPS_ERR_IF(sharedExpertNum != 1, OP_LOGE(nodeName, "sharedExpertNum must be 1."), return ge::GRAPH_FAILED);
191+
OPS_ERR_IF(moeExpertNum % (epRankSize - sharedExpertRankNum) != 0,
194192
OP_LOGE(nodeName, "moeExpertNum must be divisible by (epRankSize - sharedExpertRankNum)."),
195193
return ge::GRAPH_FAILED);
196-
OP_TILING_CHECK(moeExpertNumPerRank > RECV_AIV_NUM,
194+
OPS_ERR_IF(moeExpertNumPerRank > RECV_AIV_NUM,
197195
OP_LOGE(nodeName, "moeExpertNumPerRank must <= %d.", RECV_AIV_NUM), return ge::GRAPH_FAILED);
198196
#endif
199197

@@ -226,7 +224,7 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no
226224
DispatchGmmCombineDecodeTilingData &tilingData)
227225
{
228226
size_t *workSpaces = context->GetWorkspaceSizes(1);
229-
OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
227+
OPS_ERR_IF(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED);
230228
size_t maxTokenNum;
231229
uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize;
232230
uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId;
@@ -267,12 +265,12 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
267265
{
268266
const char *nodeName = context->GetNodeName();
269267
DispatchGmmCombineDecodeTilingData *tilingData = context->GetTilingData<DispatchGmmCombineDecodeTilingData>();
270-
OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
268+
OPS_ERR_IF(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
271269
std::string groupEp = "";
272270

273271
const gert::StorageShape *xStorageShape = context->GetInputShape(INPUT_X_INDEX);
274-
OP_TILING_CHECK(xStorageShape == nullptr, OP_LOGE(nodeName, "x shape is null."), return ge::GRAPH_FAILED);
275-
OP_TILING_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
272+
OPS_ERR_IF(xStorageShape == nullptr, OP_LOGE(nodeName, "x shape is null."), return ge::GRAPH_FAILED);
273+
OPS_ERR_IF(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
276274
OP_LOGE(nodeName, "x shape dims must be 2, but current dim num is %lu.",
277275
xStorageShape->GetStorageShape().GetDimNum()),
278276
return ge::GRAPH_FAILED);
@@ -282,25 +280,25 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex
282280
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h = hiddenSize;
283281

284282
const gert::StorageShape *expertIdsStorageShape = context->GetInputShape(INPUT_EXPERT_IDS_INDEX);
285-
OP_TILING_CHECK(expertIdsStorageShape == nullptr, OP_LOGE(nodeName, "expertIds shape is null."),
283+
OPS_ERR_IF(expertIdsStorageShape == nullptr, OP_LOGE(nodeName, "expertIds shape is null."),
286284
return ge::GRAPH_FAILED);
287-
OP_TILING_CHECK(expertIdsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
285+
OPS_ERR_IF(expertIdsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
288286
OP_LOGE(nodeName, "expertIds shape dims must be 2, but current dim num is %lu.",
289287
expertIdsStorageShape->GetStorageShape().GetDimNum()),
290288
return ge::GRAPH_FAILED);
291289
const int64_t topK = expertIdsStorageShape->GetStorageShape().GetDim(1);
292290
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK;
293-
OP_TILING_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
291+
OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS,
294292
OP_LOGE(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED);
295293
const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX);
296-
OP_TILING_CHECK(gmm1WeightStorageShape == nullptr, OP_LOGE(nodeName, "gmm1Weight shape is null."),
294+
OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OP_LOGE(nodeName, "gmm1Weight shape is null."),
297295
return ge::GRAPH_FAILED);
298296
tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS);
299297
#ifdef ENABLE_TILING_CHECK
300-
OP_TILING_CHECK(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OP_LOGE(nodeName, "CheckData failed."),
298+
OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OP_LOGE(nodeName, "CheckData failed."),
301299
return ge::GRAPH_FAILED);
302300
#endif
303-
OP_TILING_CHECK(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
301+
OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
304302
OP_LOGE(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
305303
SetHcommCfg(context, tilingData, groupEp);
306304
if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) {

0 commit comments

Comments
 (0)