Commit 8e33ea7

Merge branch 'main' into add_corrupted_req_metric
2 parents 980b085 + 02af36d commit 8e33ea7

126 files changed: +3393, -1352 lines changed


.buildkite/test-amd.yaml

Lines changed: 1 addition & 1 deletion

@@ -561,7 +561,7 @@ steps:
 
 - label: Model Executor Test # 23min
   timeout_in_minutes: 35
-  mirror_hardwares: [amdexperimental]
+  mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_1
   # grade: Blocking
   source_file_dependencies:

cmake/cpu_extension.cmake

Lines changed: 17 additions & 4 deletions

@@ -212,11 +212,24 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
   # Build ACL with scons
   include(ProcessorCount)
   ProcessorCount(_NPROC)
+  set(_scons_cmd
+    scons -j${_NPROC}
+    Werror=0 debug=0 neon=1 examples=0 embed_kernels=0 os=linux
+    arch=armv8.2-a build=native benchmark_examples=0 fixed_format_kernels=1
+    multi_isa=1 openmp=1 cppthreads=0
+  )
+
+  # locate PyTorch's libgomp (e.g. site-packages/torch.libs/libgomp-947d5fa1.so.1.0.0)
+  # and create a local shim dir with it
+  include("${CMAKE_CURRENT_LIST_DIR}/utils.cmake")
+  vllm_prepare_torch_gomp_shim(VLLM_TORCH_GOMP_SHIM_DIR)
+
+  if(NOT VLLM_TORCH_GOMP_SHIM_DIR STREQUAL "")
+    list(APPEND _scons_cmd extra_link_flags=-L${VLLM_TORCH_GOMP_SHIM_DIR})
+  endif()
+
   execute_process(
-    COMMAND scons -j${_NPROC}
-      Werror=0 debug=0 neon=1 examples=0 embed_kernels=0 os=linux
-      arch=armv8.2-a build=native benchmark_examples=0 fixed_format_kernels=1
-      multi_isa=1 openmp=1 cppthreads=0
+    COMMAND ${_scons_cmd}
     WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
     RESULT_VARIABLE _acl_rc
   )
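
The net effect of this hunk: the scons invocation is assembled in `_scons_cmd` first so that, when the PyTorch wheel ships its own libgomp, ACL links against that copy via `extra_link_flags`. A rough sketch of the resulting command line (job count and shim path are illustrative, not taken from the commit):

```bash
# Illustrative expansion of ${_scons_cmd} after the shim flag is appended.
scons -j8 \
  Werror=0 debug=0 neon=1 examples=0 embed_kernels=0 os=linux \
  arch=armv8.2-a build=native benchmark_examples=0 fixed_format_kernels=1 \
  multi_isa=1 openmp=1 cppthreads=0 \
  extra_link_flags=-L/path/to/build/gomp_shim
```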

cmake/utils.cmake

Lines changed: 38 additions & 0 deletions

@@ -129,6 +129,44 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
   set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
 endfunction()
 
+# Find libgomp that gets shipped with PyTorch wheel and create a shim dir with:
+# libgomp.so -> libgomp-<hash>.so...
+# libgomp.so.1 -> libgomp-<hash>.so...
+# OUTPUT: TORCH_GOMP_SHIM_DIR ("" if not found)
+function(vllm_prepare_torch_gomp_shim TORCH_GOMP_SHIM_DIR)
+  set(${TORCH_GOMP_SHIM_DIR} "" PARENT_SCOPE)
+
+  # Use run_python to locate vendored libgomp; never throw on failure.
+  run_python(_VLLM_TORCH_GOMP_PATH
+    "
+import os, glob
+try:
+    import torch
+    torch_pkg = os.path.dirname(torch.__file__)
+    site_root = os.path.dirname(torch_pkg)
+    torch_libs = os.path.join(site_root, 'torch.libs')
+    print(glob.glob(os.path.join(torch_libs, 'libgomp-*.so*'))[0])
+except:
+    print('')
+"
+    "failed to probe torch.libs for libgomp")
+
+  if(_VLLM_TORCH_GOMP_PATH STREQUAL "" OR NOT EXISTS "${_VLLM_TORCH_GOMP_PATH}")
+    return()
+  endif()
+
+  # Create shim under the build tree
+  set(_shim "${CMAKE_BINARY_DIR}/gomp_shim")
+  file(MAKE_DIRECTORY "${_shim}")
+
+  execute_process(COMMAND ${CMAKE_COMMAND} -E rm -f "${_shim}/libgomp.so")
+  execute_process(COMMAND ${CMAKE_COMMAND} -E rm -f "${_shim}/libgomp.so.1")
+  execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink "${_VLLM_TORCH_GOMP_PATH}" "${_shim}/libgomp.so")
+  execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink "${_VLLM_TORCH_GOMP_PATH}" "${_shim}/libgomp.so.1")
+
+  set(${TORCH_GOMP_SHIM_DIR} "${_shim}" PARENT_SCOPE)
+endfunction()
+
 # Macro for converting a `gencode` version number to a cmake version number.
 macro(string_to_ver OUT_VER IN_STR)
   string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
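
For readers who want to check what this helper will pick up on their machine, the probe-and-symlink logic can be reproduced by hand. A minimal sketch, assuming a wheel-based torch install that vendors libgomp under `torch.libs` (conda or system installs may not have that directory):

```bash
# Hand-rolled equivalent of vllm_prepare_torch_gomp_shim, for inspection only.
TORCH_PKG=$(python3 -c 'import os, torch; print(os.path.dirname(torch.__file__))')
GOMP=$(ls "$(dirname "$TORCH_PKG")"/torch.libs/libgomp-*.so* 2>/dev/null | head -n1)
if [ -n "$GOMP" ]; then
  mkdir -p gomp_shim
  ln -sf "$GOMP" gomp_shim/libgomp.so
  ln -sf "$GOMP" gomp_shim/libgomp.so.1
  echo "shim dir: $(pwd)/gomp_shim"
else
  echo "no vendored libgomp found; the CMake helper would return an empty path"
fi
```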

docker/Dockerfile.cpu

Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc
 ######################### BUILD IMAGE #########################
 FROM base AS vllm-build
 
-ARG max_jobs=2
+ARG max_jobs=32
 ENV MAX_JOBS=${max_jobs}
 
 ARG GIT_REPO_CHECK=0
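
The default build parallelism goes from 2 to 32 jobs. Because `max_jobs` is a Dockerfile `ARG`, it can still be dialed down at build time on smaller machines; a hedged example (the image tag is a placeholder):

```bash
# Override the new default when building the CPU image on a small builder.
docker build -f docker/Dockerfile.cpu \
  --build-arg max_jobs=8 \
  -t vllm-cpu:dev .
```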

docker/Dockerfile.rocm_base

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="0e60e394"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="eef23c7f"
+ARG AITER_BRANCH="9716b1b8"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base
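
Only the pinned AITER commit changes here. Since `AITER_BRANCH` is an `ARG`, a different commit can be selected without editing the Dockerfile; a sketch (the tag name is illustrative):

```bash
# Build the ROCm base image against an explicit AITER commit.
DOCKER_BUILDKIT=1 docker build \
  -f docker/Dockerfile.rocm_base \
  --build-arg AITER_BRANCH=9716b1b8 \
  -t rocm/vllm-dev:base .
```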

docs/deployment/docker.md

Lines changed: 2 additions & 2 deletions

@@ -41,11 +41,11 @@ You can add any other [engine-args](../configuration/engine_args.md) you need af
 create a custom Dockerfile on top of the base image with an extra layer that installs them:
 
 ```Dockerfile
-FROM vllm/vllm-openai:v0.9.0
+FROM vllm/vllm-openai:v0.11.0
 
 # e.g. install the `audio` optional dependencies
 # NOTE: Make sure the version of vLLM matches the base image!
-RUN uv pip install --system vllm[audio]==0.9.0
+RUN uv pip install --system vllm[audio]==0.11.0
 ```
 
 !!! tip
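
The example is bumped from v0.9.0 to v0.11.0 in both the base image tag and the matching `vllm[audio]` pin. For completeness, building and running the customized image might look like the following sketch (tag, model, and port are illustrative, not from the docs):

```bash
# Build the custom image from the Dockerfile above, then serve a model with it.
docker build -t vllm-openai-audio:v0.11.0 .
docker run --rm --gpus all -p 8000:8000 \
  vllm-openai-audio:v0.11.0 \
  --model openai/whisper-large-v3
```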

docs/features/reasoning_outputs.md

Lines changed: 4 additions & 3 deletions

@@ -14,11 +14,12 @@ vLLM currently supports the following reasoning models:
 | [DeepSeek-V3.1](https://huggingface.co/collections/deepseek-ai/deepseek-v31-68a491bed32bd77e7fca048f) | `deepseek_v3` | `json`, `regex` ||
 | [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` ||
 | [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` ||
-| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` ||
+| [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` ||
+| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `json`, `regex` ||
 | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` |||
+| [MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2) | `minimax_m2_append_think` | `json`, `regex` ||
 | [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` ||
-| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `json`, `regex` ||
-| [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` ||
+| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` ||
 
 !!! note
     IBM Granite 3.2 and DeepSeek-V3.1 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
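
The table gains GLM-4.5, Hunyuan A13B, and MiniMax-M2 entries and is re-sorted alphabetically. Using one of the newly listed parsers follows the usual pattern; a sketch (the model ID is illustrative, the parser name comes from the table):

```bash
# Serve a GLM-4.5 family model with its reasoning parser enabled.
vllm serve zai-org/GLM-4.5-Air --reasoning-parser glm45
```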

docs/features/tool_calling.md

Lines changed: 1 addition & 1 deletion

@@ -321,7 +321,7 @@ Supported models:
 Flags:
 
 * For non-reasoning: `--tool-call-parser hunyuan_a13b`
-* For reasoning: `--tool-call-parser hunyuan_a13b --reasoning-parser hunyuan_a13b --enable_reasoning`
+* For reasoning: `--tool-call-parser hunyuan_a13b --reasoning-parser hunyuan_a13b`
 
 ### LongCat-Flash-Chat Models (`longcat`)
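
With `--enable_reasoning` removed, the two parser flags are all that is needed. A sketch of a full serve command (the model ID is illustrative; `--enable-auto-tool-choice` is the usual companion flag for tool calling):

```bash
# Hunyuan A13B with tool calling and reasoning, per the updated flags.
vllm serve tencent/Hunyuan-A13B-Instruct \
  --enable-auto-tool-choice \
  --tool-call-parser hunyuan_a13b \
  --reasoning-parser hunyuan_a13b
```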

docs/getting_started/installation/gpu.rocm.inc.md

Lines changed: 74 additions & 49 deletions

@@ -1,6 +1,6 @@
 # --8<-- [start:installation]
 
-vLLM supports AMD GPUs with ROCm 6.3 or above.
+vLLM supports AMD GPUs with ROCm 6.3 or above, and torch 2.8.0 and above.
 
 !!! tip
     [Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm.
@@ -28,57 +28,63 @@ Currently, there are no pre-built ROCm wheels.
 # --8<-- [end:pre-built-wheels]
 # --8<-- [start:build-wheel-from-source]
 
+!!! tip
+    - If you found that the following installation step does not work for you, please refer to [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base). Dockerfile is a form of installation steps.
+
 0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
 
     - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html)
     - [PyTorch](https://pytorch.org/)
 
-    For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3.
+    For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm7.0_ubuntu22.04_py3.10_pytorch_release_2.8.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3.
 
     Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example:
 
     ```bash
     # Install PyTorch
     pip uninstall torch -y
-    pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/rocm6.4
+    pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0
     ```
 
-1. Install [Triton for ROCm](https://github.com/triton-lang/triton)
+1. Install [Triton for ROCm](https://github.com/ROCm/triton.git)
 
-    Install ROCm's Triton (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md)
+    Install ROCm's Triton following the instructions from [ROCm/triton](https://github.com/ROCm/triton.git)
 
     ```bash
     python3 -m pip install ninja cmake wheel pybind11
     pip uninstall -y triton
-    git clone https://github.com/triton-lang/triton.git
+    git clone https://github.com/ROCm/triton.git
     cd triton
-    git checkout e5be006
+    # git checkout $TRITON_BRANCH
+    git checkout f9e5bf54
     if [ ! -f setup.py ]; then cd python; fi
     python3 setup.py install
     cd ../..
     ```
 
     !!! note
-        If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent.
+        - The validated `$TRITON_BRANCH` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base).
+        - If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent.
 
-2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/Dao-AILab/flash-attention)
+2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/Dao-AILab/flash-attention.git)
 
-    Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention#amd-rocm-support)
-    Alternatively, wheels intended for vLLM use can be accessed under the releases.
+    Install ROCm's flash attention (v2.8.0) following the instructions from [ROCm/flash-attention](https://github.com/Dao-AILab/flash-attention#amd-rocm-support)
 
-    For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
+    For example, for ROCm 7.0, suppose your gfx arch is `gfx942`. To get your gfx architecture, run `rocminfo |grep gfx`.
 
     ```bash
     git clone https://github.com/Dao-AILab/flash-attention.git
     cd flash-attention
-    git checkout 1a7f4dfa
+    # git checkout $FA_BRANCH
+    git checkout 0e60e394
     git submodule update --init
-    GPU_ARCHS="gfx90a" python3 setup.py install
+    GPU_ARCHS="gfx942" python3 setup.py install
     cd ..
     ```
 
     !!! note
-        You might need to downgrade the "ninja" version to 1.10 as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
+        - The validated `$FA_BRANCH` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base).
+
 
 3. If you choose to build AITER yourself to use a certain branch or commit, you can build AITER using the following steps:
 
@@ -92,11 +98,13 @@ Currently, there are no pre-built ROCm wheels.
         ```
 
     !!! note
-        You will need to config the `$AITER_BRANCH_OR_COMMIT` for your purpose.
+        - You will need to config the `$AITER_BRANCH_OR_COMMIT` for your purpose.
+        - The validated `$AITER_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base).
+
 
-4. Build vLLM. For example, vLLM on ROCM 6.3 can be built with the following steps:
+4. Build vLLM. For example, vLLM on ROCM 7.0 can be built with the following steps:
 
-    ??? console "Commands"
+    ???+ console "Commands"
 
         ```bash
         pip install --upgrade pip
@@ -109,31 +117,48 @@ Currently, there are no pre-built ROCm wheels.
             scipy \
            huggingface-hub[cli,hf_transfer] \
            setuptools_scm
-        pip install "numpy<2"
        pip install -r requirements/rocm.txt
 
-        # Build vLLM for MI210/MI250/MI300.
-        export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+        # To build for a single architecture (e.g., MI300) for faster installation (recommended):
+        export PYTORCH_ROCM_ARCH="gfx942"
+
+        # To build vLLM for multiple arch MI210/MI250/MI300, use this instead
+        # export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+
        python3 setup.py develop
        ```
 
     This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation.
 
     !!! tip
-        - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm-up step before collecting perf numbers.
-        - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
-        - To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention.
        - The ROCm version of PyTorch, ideally, should match the ROCm driver version.
 
     !!! tip
        - For MI300x (gfx942) users, to achieve optimal performance, please refer to [MI300x tuning guide](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html) for performance optimization and tuning tips on system and workflow level.
-          For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization).
+          For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/vllm-optimization.html).
 
 # --8<-- [end:build-wheel-from-source]
 # --8<-- [start:pre-built-images]
 
 The [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offers a prebuilt, optimized
 docker image designed for validating inference performance on the AMD Instinct™ MI300X accelerator.
+AMD also offers nightly prebuilt docker image from [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev), which has vLLM and all its dependencies installed.
+
+???+ console "Commands"
+    ```bash
+    docker pull rocm/vllm-dev:nightly # to get the latest image
+    docker run -it --rm \
+        --network=host \
+        --group-add=video \
+        --ipc=host \
+        --cap-add=SYS_PTRACE \
+        --security-opt seccomp=unconfined \
+        --device /dev/kfd \
+        --device /dev/dri \
+        -v <path/to/your/models>:/app/models \
+        -e HF_HOME="/app/models" \
+        rocm/vllm-dev:nightly
+    ```
 
 !!! tip
     Please check [LLM inference performance validation on AMD Instinct MI300X](https://rocm.docs.amd.com/en/latest/how-to/performance-validation/mi300x/vllm-benchmark.html)
@@ -144,29 +169,29 @@ docker image designed for validating inference performance on the AMD Instinct
 
 Building the Docker image from source is the recommended way to use vLLM with ROCm.
 
-#### (Optional) Build an image with ROCm software stack
+??? info "(Optional) Build an image with ROCm software stack"
 
-Build a docker image from [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base) which setup ROCm software stack needed by the vLLM.
-**This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.**
-If you choose to build this rocm_base image yourself, the steps are as follows.
+    Build a docker image from [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base) which setup ROCm software stack needed by the vLLM.
+    **This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.**
+    If you choose to build this rocm_base image yourself, the steps are as follows.
 
-It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
+    It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
 
-```json
-{
-    "features": {
-        "buildkit": true
+    ```json
+    {
+        "features": {
+            "buildkit": true
+        }
     }
-}
-```
+    ```
 
-To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:
+    To build vllm on ROCm 7.0 for MI200 and MI300 series, you can use the default:
 
-```bash
-DOCKER_BUILDKIT=1 docker build \
-    -f docker/Dockerfile.rocm_base \
-    -t rocm/vllm-dev:base .
-```
+    ```bash
+    DOCKER_BUILDKIT=1 docker build \
+        -f docker/Dockerfile.rocm_base \
+        -t rocm/vllm-dev:base .
+    ```
 
 #### Build an image with vLLM
 
@@ -181,24 +206,24 @@ It is important that the user kicks off the docker build using buildkit. Either
 }
 ```
 
-[docker/Dockerfile.rocm](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm) uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2, in older vLLM branches.
+[docker/Dockerfile.rocm](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm) uses ROCm 7.0 by default, but also supports ROCm 5.7, 6.0, 6.1, 6.2, 6.3, and 6.4, in older vLLM branches.
 It provides flexibility to customize the build of docker image using the following arguments:
 
 - `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base)
 - `ARG_PYTORCH_ROCM_ARCH`: Allows to override the gfx architecture values from the base docker image
 
 Their values can be passed in when running `docker build` with `--build-arg` options.
 
-To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:
+To build vllm on ROCm 7.0 for MI200 and MI300 series, you can use the default:
 
-```bash
-DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm -t vllm-rocm .
-```
+???+ console "Commands"
+    ```bash
+    DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile.rocm -t vllm-rocm .
+    ```
 
 To run the above docker image `vllm-rocm`, use the below command:
 
-??? console "Command"
-
+???+ console "Commands"
     ```bash
     docker run -it \
         --network=host \
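
After following the updated build-from-source steps, a quick smoke test helps confirm the toolchain lines up with the new ROCm 7.0 / torch 2.8 baseline; a sketch (outputs depend on your hardware and wheel versions):

```bash
# Sanity checks after a ROCm source build.
rocminfo | grep gfx    # the arch you built for, e.g. gfx942
python3 -c "import torch; print(torch.__version__, torch.version.hip)"
python3 -c "import vllm; print(vllm.__version__)"
```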

docs/getting_started/installation/python_env_setup.inc.md

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment using the following commands:
+On NVIDIA CUDA only, it's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment using the following commands:
 
 ```bash
 uv venv --python 3.12 --seed
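
For context, the commands that typically follow the `uv venv` call above look like this (a sketch, not part of the commit; the activation path assumes uv's default `.venv` layout on Linux):

```bash
source .venv/bin/activate
uv pip install vllm
```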
