@@ -422,6 +422,8 @@ static std::pair<VectorDims, VectorDims> makeDummyInputDims(const Shape& in0,
422422 } else {
423423 inDims1[idx1] = inDims0[idx0];
424424 }
425+ } else if (inDims0[idx0] != Shape::UNDEFINED_DIM && inDims1[idx1] != Shape::UNDEFINED_DIM) {
426+ inDims1[idx1] = inDims0[idx0];
425427 }
426428 }
427429 };
@@ -519,13 +521,16 @@ DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const MatM
519521 if (srcDesc->getShape ().isDynamic () || weiDesc->getShape ().isDynamic ()) {
520522 const auto & srcShape = srcDesc->getShape ();
521523 const auto & weiShape = weiDesc->getShape ();
522- const auto & [inDymmyDims, weiDymmyDims] =
524+ auto [inDymmyDims, weiDymmyDims] =
523525 makeDummyInputDims (srcShape, weiShape, dstDesc->getShape (), attrs.transposeA , attrs.transposeB );
524526 const auto & outDymmyDims = makeDummyOutputDims (inDymmyDims,
525527 weiDymmyDims,
526528 attrs.transposeA ,
527529 attrs.transposeB ,
528530 dstDesc->getShape ().getRank ());
531+ if (attrs.weightsNonTransposed ) {
532+ std::swap (weiDymmyDims[weiDymmyDims.size () - 1 ], weiDymmyDims[weiDymmyDims.size () - 2 ]);
533+ }
529534 srcDesc = std::make_shared<DnnlBlockedMemoryDesc>(srcDesc->getPrecision (), Shape (inDymmyDims));
530535 weiDesc = std::make_shared<DnnlBlockedMemoryDesc>(weiDesc->getPrecision (), Shape (weiDymmyDims));
531536 dstDesc = std::make_shared<DnnlBlockedMemoryDesc>(dstDesc->getPrecision (), Shape (outDymmyDims));
0 commit comments