In the scaled_dot_product_attention function, change the behavior so that if the input is 3D, the output is also 3D #601
base: master
Conversation
modified: paconvert/api_matcher.py
modified: tests/test_scaled_dot_product_attention.py
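The behavior change under discussion can be sketched in NumPy. This is a hypothetical reference implementation, not the actual api_matcher.py code: when query/key/value are 3D, a singleton heads axis is inserted before attention and removed afterwards, so the output keeps the input's 3D shape.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    # Hypothetical sketch of the PR's behavior: a 3D input
    # (batch, seq, head_dim) yields a 3D output of the same shape.
    was_3d = q.ndim == 3
    if was_3d:
        # Insert a singleton num_heads axis -> (batch, 1, seq, head_dim)
        q, k, v = (np.expand_dims(t, 1) for t in (q, k, v))
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v
    # Drop the singleton axis again so a 3D input gives a 3D output
    return out.squeeze(1) if was_3d else out

q = k = v = np.random.rand(8, 128, 64)
print(scaled_dot_product_attention(q, k, v).shape)  # (8, 128, 64)
```

A 4D input passes through unchanged, so the sketch preserves the existing 4D behavior as well.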
LGTM
modified: tests/test_scaled_dot_product_attention.py
import numpy as np
import torch
np.random.seed(100)
x = np.random.rand(8, 128, 64)
query = torch.tensor(x, dtype=torch.float16)
bfloat16 is currently not supported on CPU here, so this case cannot run.
Backend selection for scaled_dot_product_attention is currently rather messy here.
It is easy to end up with an unreasonable backend being selected, causing the test to fail. Getting it to pass may require carefully designing the cases; a better approach would be to improve the backend selection logic of scaled_dot_product_attention so that an unreasonable backend is not always chosen.
OK, got it.
Let's modify the cases first and see how to test this 3D functionality while avoiding the backend selection problem, then merge this PR.
modified: tests/test_scaled_dot_product_attention.py
Case 10 takes the mask into account for now.
PR Docs
PaddlePaddle/Paddle#73804
PR APIs
New parameters
Feature improvements
Bug fixes
LayerList.insert, bernoulli, Tensor.data, LSTMCell, fused_rms_norm, softmax, atleast, and to_tensor, as well as the functions listed under bug fixes, do not require changes to PaConvert or the Docs.