Commit 05ab01d
committed
(fix) timm: ROCm 7.0 compatibility for Attention2d modules
ROCm 7.0 enforces GEMM paths for 1x1 convolutions, requiring strict
memory contiguity. This change causes HIP error: invalid argument
when non-contiguous tensors (from reshape/permute/slice operations)
are passed to Attention2d and MultiQueryAttention2d modules.
Changes:
- Add contiguity checks in Attention2d.forward()
- Add contiguity checks in MultiQueryAttention2d.forward()
- Force .contiguous() only when tensor is non-contiguous
Fixes #2613
Signed-off-by: Emilien Macchi <[email protected]>1 parent ae4d1bb commit 05ab01d
3 files changed
+21
-8
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
9 | | - | |
| 9 | + | |
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
| |||
271 | 271 | | |
272 | 272 | | |
273 | 273 | | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
274 | 278 | | |
275 | 279 | | |
276 | 280 | | |
| |||
351 | 355 | | |
352 | 356 | | |
353 | 357 | | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
354 | 362 | | |
355 | 363 | | |
356 | 364 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
4 | 4 | | |
5 | 5 | | |
6 | 6 | | |
| 7 | + | |
7 | 8 | | |
8 | 9 | | |
9 | 10 | | |
| |||
41 | 42 | | |
42 | 43 | | |
43 | 44 | | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
| 24 | + | |
24 | 25 | | |
25 | 26 | | |
26 | 27 | | |
| |||
155 | 156 | | |
156 | 157 | | |
157 | 158 | | |
158 | | - | |
159 | | - | |
160 | | - | |
161 | | - | |
162 | | - | |
163 | | - | |
164 | 159 | | |
165 | 160 | | |
166 | 161 | | |
| |||
191 | 186 | | |
192 | 187 | | |
193 | 188 | | |
194 | | - | |
| 189 | + | |
195 | 190 | | |
196 | 191 | | |
197 | 192 | | |
| |||
0 commit comments