
Commit 05ab01d

(fix) timm: ROCm 7.0 compatibility for Attention2d modules
ROCm 7.0 enforces GEMM paths for 1x1 convolutions, which require strict memory contiguity. As a result, passing non-contiguous tensors (produced by reshape/permute/slice operations) to the Attention2d and MultiQueryAttention2d modules fails with "HIP error: invalid argument".

Changes:
- Add a contiguity check in Attention2d.forward()
- Add a contiguity check in MultiQueryAttention2d.forward()
- Force .contiguous() only when the tensor is non-contiguous

Fixes #2613

Signed-off-by: Emilien Macchi <[email protected]>
1 parent ae4d1bb commit 05ab01d
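
For background on the failure mode described in the commit message, the following is a minimal, timm-independent PyTorch sketch (shapes are illustrative) of how reshape/permute/slice operations produce non-contiguous tensors, and of the copy-only-when-needed guard this commit applies:

    import torch

    x = torch.randn(2, 64, 14, 14)                # NCHW feature map, contiguous as allocated
    print(x.is_contiguous())                      # True

    permuted = x.permute(0, 2, 3, 1)              # NHWC view: same storage, reordered strides
    sliced = x[:, ::2]                            # strided channel slice
    print(permuted.is_contiguous(), sliced.is_contiguous())   # False False

    # The guard pattern used in this commit: copy only when the layout is
    # non-contiguous, so already-contiguous inputs pay no extra cost.
    if not permuted.is_contiguous():
        permuted = permuted.contiguous()          # materializes a dense copy
    print(permuted.is_contiguous())               # True

On ROCm 7.0 it is these view-produced layouts that trip the GEMM path used for 1x1 convolutions; the copy restores a layout the kernel accepts.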

File tree

timm/layers/attention2d.py
timm/layers/helpers.py
timm/layers/norm.py

3 files changed: +21 -8 lines changed


timm/layers/attention2d.py

Lines changed: 9 additions & 1 deletion
@@ -6,7 +6,7 @@

 from .config import use_fused_attn
 from .create_conv2d import create_conv2d
-from .helpers import to_2tuple
+from .helpers import to_2tuple, is_contiguous
 from .pool2d_same import create_pool2d


@@ -271,6 +271,10 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """Run layer computation."""
         B, C, H, W = s = x.shape

+        # Force memory contiguity to satisfy GEMM constraints for 1x1 convolutions
+        if not is_contiguous(x):
+            x = x.contiguous()
+
         q = self.query(x)
         # desired q shape: [b, h, k, n x n] - [b, l, h, k]
         q = self._reshape_projected_query(q, self.num_heads, self.key_dim)
@@ -351,6 +355,10 @@ def __init__(
     def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         B, C, H, W = x.shape

+        # Force memory contiguity to satisfy GEMM constraints for 1x1 convolutions
+        if not is_contiguous(x):
+            x = x.contiguous()
+
         if self.head_first:
             q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
         else:
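
To make the guarded forward concrete, below is a compact stand-in mirroring the head_first branch above: a 1x1-conv QKV projection preceded by the contiguity check. The class name, default sizes, and the simplified attention math are illustrative and not the timm implementation:

    import torch
    import torch.nn as nn

    class Attention2dSketch(nn.Module):
        """Illustrative stand-in for the guarded forward pattern (not the timm module)."""

        def __init__(self, dim: int = 64, num_heads: int = 4):
            super().__init__()
            self.num_heads = num_heads
            self.dim_head = dim // num_heads
            self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)   # 1x1 conv hits the GEMM path
            self.proj = nn.Conv2d(dim, dim, kernel_size=1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            B, C, H, W = x.shape
            # Guard added by this commit: copy only when the layout is non-contiguous
            if not x.is_contiguous():
                x = x.contiguous()
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
            attn = (q.transpose(-1, -2) @ k * self.dim_head ** -0.5).softmax(dim=-1)
            out = (v @ attn.transpose(-1, -2)).reshape(B, -1, H, W)
            return self.proj(out)

    # A non-contiguous NCHW view, the kind of input described in the commit message
    x = torch.randn(2, 14, 14, 64).permute(0, 3, 1, 2)
    print(x.is_contiguous())                     # False
    print(Attention2dSketch()(x).shape)          # torch.Size([2, 64, 14, 14])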

timm/layers/helpers.py

Lines changed: 10 additions & 0 deletions
@@ -4,6 +4,7 @@
 """
 from itertools import repeat
 import collections.abc
+import torch


 # From PyTorch internals
@@ -41,3 +42,12 @@ def extend_tuple(x, n):
     if pad_n <= 0:
         return x[:n]
     return x + (x[-1],) * pad_n
+
+
+def is_contiguous(tensor: torch.Tensor) -> bool:
+    """Check tensor contiguity with proper handling for TorchScript compilation."""
+    # jit is oh so lovely :/
+    if torch.jit.is_scripting():
+        return tensor.is_contiguous()
+    else:
+        return tensor.is_contiguous(memory_format=torch.contiguous_format)
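
For reference, a quick check of the promoted helper; this assumes a timm checkout that includes this commit, so that is_contiguous is importable from timm.layers.helpers. The torch.jit.is_scripting() branch is what keeps the helper usable from scripted code, where it falls back to the plain is_contiguous() call:

    import torch
    from timm.layers.helpers import is_contiguous   # available after this commit

    x = torch.randn(1, 8, 4, 4)
    print(is_contiguous(x))                          # True: freshly allocated NCHW tensor
    print(is_contiguous(x.permute(0, 1, 3, 2)))      # False: permuted view, strides reordered

    # Under TorchScript the helper takes the branch that avoids the
    # memory_format keyword and returns the same result.
    scripted = torch.jit.script(is_contiguous)
    print(scripted(x.permute(0, 1, 3, 2)))           # False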

timm/layers/norm.py

Lines changed: 2 additions & 7 deletions
@@ -21,6 +21,7 @@
     fast_simple_norm,
     simple_norm,
 )
+from .helpers import is_contiguous

 try:
     from torch.nn.functional import rms_norm
@@ -155,12 +156,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-def _is_contiguous(tensor: torch.Tensor) -> bool:
-    # jit is oh so lovely :/
-    if torch.jit.is_scripting():
-        return tensor.is_contiguous()
-    else:
-        return tensor.is_contiguous(memory_format=torch.contiguous_format)


 def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
@@ -191,7 +186,7 @@ def __init__(self, num_channels: int, eps: float = 1e-6):
         super().__init__(num_channels, eps=eps)

     def forward(self, x) -> torch.Tensor:
-        if _is_contiguous(x):
+        if is_contiguous(x):
             x = F.layer_norm(
                 x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
         else:
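
The norm layer touched above gates its fast path on the same check. Below is a minimal sketch of that pattern: a generic channels-first LayerNorm, not the timm class, with a simplified manual fallback standing in for timm's channels-first kernel:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LayerNorm2dSketch(nn.LayerNorm):
        """Channels-first LayerNorm with a contiguity-gated fast path (illustrative only)."""

        def __init__(self, num_channels: int, eps: float = 1e-6):
            super().__init__(num_channels, eps=eps)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if x.is_contiguous(memory_format=torch.contiguous_format):
                # Fast path: permute NCHW -> NHWC, run fused layer_norm, permute back
                return F.layer_norm(
                    x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps,
                ).permute(0, 3, 1, 2)
            # Simplified fallback for non-contiguous inputs: normalize over channels by hand
            mean = x.mean(dim=1, keepdim=True)
            var = x.var(dim=1, unbiased=False, keepdim=True)
            x = (x - mean) / torch.sqrt(var + self.eps)
            return x * self.weight[None, :, None, None] + self.bias[None, :, None, None]

    norm = LayerNorm2dSketch(32)
    dense = torch.randn(2, 32, 8, 8)
    strided = dense[:, :, ::2, :]                    # non-contiguous view takes the fallback
    print(torch.allclose(norm(strided.contiguous()), norm(strided), atol=1e-5))   # True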

0 commit comments
