Feature Description

Goal: fast double-precision matrix-vector multiplication on GPUs whose FP64 throughput is deliberately cut down (e.g. consumer gaming cards), for matrices with far more rows than columns. This product is used heavily in the IDR(s) algorithm for solving linear systems and is one of its main performance bottlenecks.

The two test snippets below are the PyTorch and Paddle implementations respectively. PyTorch can cleverly sidestep the FP64 throughput limit, while Paddle cannot.
```python
import time

import torch

default_device = torch.device('cuda')
torch.set_default_device(default_device)
torch.set_default_dtype(torch.float64)
size = 6400000

# Method 1: torch's built-in mv -- slow
a = torch.randn((size, 4))
b = torch.randn((4,))
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    c = torch.mv(a, b)
torch.cuda.synchronize()
end = time.time()
print('c = torch.mv(a,b) elapsed:', end - start)

# Method 2: store a transposed and express the same product with einsum -- fast
a = torch.randn((4, size))
b = torch.randn((4,))
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    c = torch.einsum('ij,i->j', a, b)
torch.cuda.synchronize()
end = time.time()
print('c = torch.einsum elapsed:', end - start)
```
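The two methods compute the same mathematical result, c = A·b: with A stored transposed, `einsum('ij,i->j', ...)` reduces over the short axis of length 4 instead of over the long axis. A minimal NumPy sketch (standing in for the GPU tensors, with a small `size` for illustration) checks this identity:

```python
import numpy as np

rng = np.random.default_rng(0)
size = 1000  # small stand-in for the 6.4e6 rows used above

a = rng.standard_normal((size, 4))       # tall-skinny matrix, row-major
b = rng.standard_normal(4)

c_mv = a @ b                             # method 1: ordinary mat-vec
a_t = np.ascontiguousarray(a.T)          # method 2: store A transposed
c_einsum = np.einsum('ij,i->j', a_t, b)  # reduce over the short axis

print(np.allclose(c_mv, c_einsum))       # the two layouts agree
```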
```python
import time

import paddle

default_device = 'gpu'
paddle.device.set_device(default_device)
paddle.set_default_dtype(paddle.float64)
size = 6400000

# Method 1: paddle's built-in mv -- slow
a = paddle.randn((size, 4))
b = paddle.randn((4,))
paddle.device.synchronize()
start = time.time()
for i in range(1000):
    c = paddle.mv(a, b)
paddle.device.synchronize()
end = time.time()
print('c = paddle.mv(a,b) elapsed:', end - start)

# Method 2: transpose a and use einsum -- still slow, far slower than the
# equivalent PyTorch implementation
a = paddle.randn((4, size))
b = paddle.randn((4,))
paddle.device.synchronize()
start = time.time()
for i in range(1000):
    c = paddle.einsum('ij,i->j', a, b)
paddle.device.synchronize()
end = time.time()
print('c = paddle.einsum elapsed:', end - start)
```
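A possible interim workaround on the Paddle side (an untested assumption, not a confirmed fix: `paddle.matmul` may select a different kernel than `paddle.einsum` here) is to express the same reduction as a vector-matrix product, `c = b @ a`, on the transposed `(4, size)` layout. The identity behind it, sketched with NumPy:

```python
import numpy as np

rng = np.random.default_rng(1)
size = 1000
a_t = rng.standard_normal((4, size))  # matrix stored transposed, as in method 2
b = rng.standard_normal(4)

# einsum('ij,i->j', a_t, b) is exactly the vector-matrix product b @ a_t
c_einsum = np.einsum('ij,i->j', a_t, b)
c_vecmat = b @ a_t                    # candidate replacement, e.g. paddle.matmul(b, a)

print(np.allclose(c_einsum, c_vecmat))
```

Whether this actually avoids the slow FP64 path depends on which kernels Paddle dispatches to, so it would need to be benchmarked alongside the snippets above.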
Alternatives
No response