
Commit 0662cfa

Add ROCm support: installation guide and FlashAttention compatibility for AMD GPUs (#3925)
* add runtime for rocm
* add requirements for rocm
* edit block size for rocm
* edit requirements
* add *.txt
* rm index-url
* add ROCm-safe tiling to avoid LDS OOR
* treat ROCm as bf16-supported
* add installation process for rocm
* use official lmdeploy URL and simplify ROCm installation guide with one-liner command
* style: apply pre-commit fixes in flashattention.py
* chore: apply pre-commit auto fixes (trailing whitespace, eof, quotes, line endings)
* remove pynvml
1 parent 11b9726 commit 0662cfa

File tree

docs/en/get_started/installation.md
docs/zh_cn/get_started/installation.md
lmdeploy/pytorch/kernels/cuda/flashattention.py
lmdeploy/utils.py
requirements/runtime_rocm.txt
requirements_rocm.txt

6 files changed: +91 −10 lines

docs/en/get_started/installation.md

Lines changed: 21 additions & 0 deletions
@@ -55,3 +55,24 @@ pip install https://github.com/InternLM/lmdeploy/archive/refs/tags/v0.10.0.zip
 ```
 
 If you want to build LMDeploy with support for Ascend, Cambricon, or MACA, install LMDeploy with the corresponding `LMDEPLOY_TARGET_DEVICE` environment variable.
+
+LMDeploy also supports installation on AMD GPUs with ROCm.
+
+```shell
+# The recommended way is to use the official ROCm PyTorch Docker image with pre-installed dependencies:
+docker run -it \
+    --cap-add=SYS_PTRACE \
+    --security-opt seccomp=unconfined \
+    --device=/dev/kfd \
+    --device=/dev/dri \
+    --group-add video \
+    --ipc=host \
+    --network=host \
+    --shm-size 32G \
+    -v /root:/workspace \
+    rocm/pytorch:latest
+
+
+# Once inside the container, install LMDeploy with ROCm support:
+LMDEPLOY_TARGET_DEVICE=rocm pip install git+https://github.com/InternLM/lmdeploy.git
+```
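
After the one-liner completes, a quick check from Python confirms the ROCm build is active. This is a minimal sketch, assumed to run inside the container above; on ROCm builds of PyTorch, `torch.version.hip` is set and the `torch.cuda` namespace is backed by HIP:

```python
# Minimal post-install sanity check (a sketch; run inside the ROCm container).
import torch

# ROCm builds of PyTorch set torch.version.hip; it is None on CUDA builds.
assert getattr(torch.version, 'hip', None) is not None, 'not a ROCm build of PyTorch'
print('HIP version:', torch.version.hip)

# On ROCm, the torch.cuda API is backed by HIP, so AMD GPUs are counted here.
print('visible GPUs:', torch.cuda.device_count())

import lmdeploy
print('lmdeploy version:', lmdeploy.__version__)
```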

docs/zh_cn/get_started/installation.md

Lines changed: 21 additions & 0 deletions
@@ -55,3 +55,24 @@ pip install https://github.com/InternLM/lmdeploy/archive/refs/tags/v0.10.0.zip
 ```
 
 If you want to build LMDeploy with support for Ascend, Cambricon, or MACA, install it with the corresponding `LMDEPLOY_TARGET_DEVICE` environment variable.
+
+LMDeploy also supports installation on AMD GPUs with ROCm.
+
+```shell
+# The recommended way is to use the official ROCm PyTorch Docker image with pre-installed dependencies:
+docker run -it \
+    --cap-add=SYS_PTRACE \
+    --security-opt seccomp=unconfined \
+    --device=/dev/kfd \
+    --device=/dev/dri \
+    --group-add video \
+    --ipc=host \
+    --network=host \
+    --shm-size 32G \
+    -v /root:/workspace \
+    rocm/pytorch:latest
+
+
+# Once inside the container, install LMDeploy with ROCm support:
+LMDEPLOY_TARGET_DEVICE=rocm pip install git+https://github.com/InternLM/lmdeploy.git
+```

lmdeploy/pytorch/kernels/cuda/flashattention.py

Lines changed: 22 additions & 10 deletions
@@ -425,6 +425,14 @@ def _kernel_meta_sm12x(BLOCK_DK: int, shared_kv: bool):
     return BLOCK_M, BLOCK_N, num_warps, num_stages
 
 
+def _kernel_meta_rocm(BLOCK_DK: int, shared_kv: bool):
+    BLOCK_N = 32
+    BLOCK_M = 32 if BLOCK_DK > 128 else 64
+    num_warps = 4
+    num_stages = 1
+    return BLOCK_M, BLOCK_N, num_warps, num_stages
+
+
 def flash_attention_fwd(
     q_states: Tensor,
     k_states: Tensor,
@@ -491,17 +499,21 @@ def grid(args):
     shared_kv = k_states.data_ptr() == v_states.data_ptr() and BLOCK_DK == BLOCK_DV
 
     num_warps = 4
-    if _nv_cap[0] < 8:
-        BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm7x(BLOCK_DK)
-    elif _nv_cap[0] < 9:
-        if _nv_cap[1] in [6, 9]:
-            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm86(BLOCK_DK, shared_kv)
-        else:
-            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DK, shared_kv)
-    elif _nv_cap[0] < 10:
-        BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DK, shared_kv)
+    hip_mode = getattr(torch.version, 'hip', None) is not None
+    if hip_mode:
+        BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_rocm(BLOCK_DK, shared_kv)
     else:
-        BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm12x(BLOCK_DK, shared_kv)
+        if _nv_cap[0] < 8:
+            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm7x(BLOCK_DK)
+        elif _nv_cap[0] < 9:
+            if _nv_cap[1] in [6, 9]:
+                BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm86(BLOCK_DK, shared_kv)
+            else:
+                BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm8x(BLOCK_DK, shared_kv)
+        elif _nv_cap[0] < 10:
+            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm9x(BLOCK_DK, shared_kv)
+        else:
+            BLOCK_M, BLOCK_N, num_warps, num_stages = _kernel_meta_sm12x(BLOCK_DK, shared_kv)
 
     BLOCK_M = min(128, BLOCK_M)
     _flash_prefill_fwd_kernel[grid](
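
The ROCm tile shapes are intentionally small: per the commit message, the tighter tiling avoids LDS (shared local memory) out-of-range failures on AMD GPUs. A rough back-of-envelope estimate shows why; this is an illustration only, assuming one fp16 Q, K, and V tile resident in LDS at once, not Triton's actual allocation, which also depends on `num_stages` and compiler scheduling:

```python
# Rough LDS estimate per workgroup, assuming one fp16 Q, K and V tile
# resident at once (illustrative only, not Triton's real allocation).
def approx_lds_bytes(block_m, block_n, block_dk, block_dv, elem_size=2):
    q_tile = block_m * block_dk  # query tile
    k_tile = block_n * block_dk  # key tile
    v_tile = block_n * block_dv  # value tile
    return (q_tile + k_tile + v_tile) * elem_size

# ROCm meta for BLOCK_DK > 128 uses BLOCK_M = BLOCK_N = 32:
print(approx_lds_bytes(32, 32, 256, 256) // 1024, 'KiB')  # 48 KiB, within a 64 KiB LDS
# 64x64 tiles at the same head dim would already exceed a 64 KiB budget:
print(approx_lds_bytes(64, 64, 256, 256) // 1024, 'KiB')  # 96 KiB
```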

lmdeploy/utils.py

Lines changed: 2 additions & 0 deletions
@@ -389,6 +389,8 @@ def is_bf16_supported(device_type: str = 'cuda'):
         return True
     elif device_type == 'camb':
         return True
+    elif device_type == 'rocm':
+        return True
     else:
         return False
 
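With this branch in place, `is_bf16_supported('rocm')` returns True, so dtype selection can keep bf16 on AMD GPUs instead of falling back to fp16. A small usage sketch, assuming only that the helper is importable from `lmdeploy.utils` as shown in the diff:

```python
from lmdeploy.utils import is_bf16_supported

# Before this commit the helper fell through to False for 'rocm',
# which forced an fp16 fallback; now bf16 is selected.
dtype = 'bfloat16' if is_bf16_supported('rocm') else 'float16'
print(dtype)  # bfloat16
```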

requirements/runtime_rocm.txt

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+accelerate>=0.29.3
+einops
+fastapi
+fire
+mmengine-lite
+numpy<2.0.0
+openai
+outlines
+partial_json_parser
+peft<=0.14.0
+pillow
+protobuf
+pydantic>2.0.0
+pyzmq
+ray
+safetensors
+sentencepiece
+shortuuid
+tiktoken
+transformers
+uvicorn

requirements_rocm.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+-r requirements/build.txt
+-r requirements/runtime_rocm.txt
+-r requirements/lite.txt
+-r requirements/serve.txt
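
The top-level requirements_rocm.txt mirrors the existing per-target aggregate files, so a build driven by `LMDEPLOY_TARGET_DEVICE=rocm` can resolve requirements by target name. A sketch of that selection logic follows; the helper is hypothetical and the real setup.py wiring may differ:

```python
import os

def parse_requirements(path):
    # Hypothetical helper (the real setup.py may differ): read a pip
    # requirements file, following '-r <file>' includes recursively.
    reqs = []
    base = os.path.dirname(path)
    with open(path) as f:
        for ln in f:
            ln = ln.strip()
            if not ln or ln.startswith('#'):
                continue
            if ln.startswith('-r '):
                reqs += parse_requirements(os.path.join(base, ln[3:].strip()))
            else:
                reqs.append(ln)
    return reqs

target = os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda')
requires = parse_requirements(f'requirements_{target}.txt')  # requirements_rocm.txt
```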
