diff --git a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp index e9b290771d..fb40f088a6 100644 --- a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp +++ b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp @@ -453,7 +453,11 @@ void _float8_linear_impl( TORCH_CHECK(weight.size(3) == block_n, "Float8 linear: unexpected weight shape"); int64_t N = Nc * block_n; TORCH_CHECK(K == Kc * block_k, "Float8 linear: weight and input shapes mismatch"); - auto [parallel_on_M, block_m, Mc, Mc_parallel] = get_m_blocking(M); + std::tuple m_blocking_info = get_m_blocking(M); + bool parallel_on_M = std::get<0>(m_blocking_info); + int64_t block_m = std::get<1>(m_blocking_info); + int64_t Mc = std::get<2>(m_blocking_info); + int64_t Mc_parallel = std::get<3>(m_blocking_info); int64_t num_parallel_blocks = Mc_parallel * Nc; // scales shape = [Nc, G, block_n]