
Commit 5b6d940

maajidkhann, divya2108, and abhishek-iitmadras committed

Hot Fix to allow Matmul to use brg:sve_512 from OneDNN.

Signed-off-by: majidkhann <[email protected]>
Co-authored-by: Divya Kotadiya <[email protected]>
Co-authored-by: Abhishek Kumar <[email protected]>
1 parent b234f19 · commit 5b6d940
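In effect (a condensed reading, not an authoritative description): this hot fix comments out the AT_MKLDNN_ACL_ENABLED() build guard and the float/transpose check in addmm_impl_cpu_, and drops the use_mkldnn_matmul check in the batched path, so AArch64 builds hand eligible matmuls to oneDNN and let it pick its brgemm kernels (e.g. brg:sve_512) on SVE-capable CPUs. A sketch of what the patched addmm dispatch in the diff below reduces to, using identifiers from the hunks; the dispatched = true on success is an assumption, since the first hunk is truncated just before that line:

#if defined(__aarch64__)
  // With the ACL build guard and the float/transpose heuristic commented
  // out, any AArch64 build now tries oneDNN first whenever the result
  // matrix is transposed.
  if (transpose_c) {
    try {
      // oneDNN may select a brgemm kernel such as brg:sve_512 here.
      mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
      dispatched = true;  // assumed success path (truncated in the hunk)
    } catch (const std::exception& e) {
      // Any oneDNN failure falls back to the regular BLAS gemm path.
      TORCH_WARN("mkldnn_matmul failed, switching to BLAS gemm:", e.what());
      at::globalContext().setUserEnabledMkldnn(false);
    }
  }
#endif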

File tree

1 file changed: +6 −4 lines


aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 6 additions & 4 deletions
@@ -1512,14 +1512,15 @@ static void addmm_impl_cpu_(
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj());
 
   bool dispatched = false;
-#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
+// #if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
+#if defined(__aarch64__)
   // On AArch64 if LHS matrix in BLAS routine is transposed but RHS is not then
   // it is faster to call oneDNN matrix multiplication primitive with RHS*LHS
   // that will call then into Arm® Compute Library (ACL) GEMM kernel and also
   // additionally have support for running kernel with BF16 instructions
   if (transpose_c) {
     bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
-    if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
+    // if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
       try {
         mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
         // We have dispatched to ACL GEMM for single precision float
@@ -1529,7 +1530,7 @@ static void addmm_impl_cpu_(
         TORCH_WARN("mkldnn_matmul failed, switching to BLAS gemm:", e.what());
         at::globalContext().setUserEnabledMkldnn(false);
       }
-    }
+    //}
   }
 #endif
 
@@ -1776,7 +1777,8 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
   };
 
   bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]);
-  if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
+  // if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
+  if (apply_heur) {
     try {
       mkldnn_matmul(batch1, batch2, self_or_result, beta.to<float>(), alpha.to<float>());
       return;

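Note: whether oneDNN actually selects the brg:sve_512 implementation at runtime can be checked with oneDNN's standard verbose mode, e.g. running the workload with the environment variable ONEDNN_VERBOSE=1 and looking for brg:sve_512 in the emitted matmul primitive lines; that facility is part of oneDNN itself, not something this commit adds.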