Skip to content

Commit d5505d9

Browse files
committed
fix 3d matmul tag
1 parent 0aa06ae commit d5505d9

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_matmul_primitive.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,15 @@ DnnlMemoryDescPtr DnnlMatMulPrimitive::makeTransposedWeightDescriptor(const Dnnl
160160
const auto& weiDesc = srcDesc->getDnnlDesc();
161161
auto wDims = weiDesc.get_dims();
162162
std::swap(wDims[wDims.size() - 1], wDims[wDims.size() - 2]);
163+
const auto wDataType = weiDesc.get_data_type();
164+
if (wDims.size() == 3 && !weightsNonTransposed) {
165+
const auto format3D = dnnl::memory::format_tag::acb;
166+
const auto transposed3DWeiDesc = dnnl::memory::desc{wDims, wDataType, format3D};
167+
return DnnlExtensionUtils::makeDescriptor(transposed3DWeiDesc);
168+
}
163169

164170
const dnnl::memory::dims wDims2D = reshapeDownToRank<2>(wDims);
165171
const auto format = weightsNonTransposed ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
166-
const auto wDataType = weiDesc.get_data_type();
167172
const auto transposedWeiDesc = dnnl::memory::desc{wDims2D, wDataType, format};
168173

169174
const auto reshapedWeiDesc = transposedWeiDesc.reshape(dstDesc->getDnnlDesc().get_dims());

0 commit comments

Comments
 (0)