Commit fdf310b

StanHatko authored and bigPYJ1151 committed
[HARDWARE][CPU] Add Option for Disabling Binding to Specific CPU Cores (vllm-project#27953)
Signed-off-by: Stan Hatko <[email protected]>
Co-authored-by: Li, Jiang <[email protected]>
1 parent 29d1a6b commit fdf310b

3 files changed: +15 -8 lines changed


docs/getting_started/installation/cpu.md

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ Currently, there are no pre-built CPU wheels.
 ## Related runtime environment variables
 
 - `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
-- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively.
+- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists, `auto` (by default), or `nobind` (to disable binding to individual CPU cores and to inherit user-defined OpenMP variables). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. If set to `nobind`, the number of OpenMP threads is determined by the standard `OMP_NUM_THREADS` environment variable.
 - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`.
 - `CPU_VISIBLE_MEMORY_NODES`: specify visible NUMA memory nodes for vLLM CPU workers, similar to ```CUDA_VISIBLE_DEVICES```. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. The variable provides more control for the auto thread-binding feature, such as masking nodes and changing nodes binding sequence.
 - `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
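
To make the `VLLM_CPU_OMP_THREADS_BIND` value formats described above concrete, here is a minimal sketch of how such a value could be expanded into per-rank CPU id lists. The helper `parse_omp_threads_bind` is hypothetical and is not part of vLLM. Support for comma-separated ids within a rank is an assumption; the documentation only shows ranges such as `0-31`.

```python
from __future__ import annotations


def parse_omp_threads_bind(spec: str) -> list[list[int]] | str:
    """Hypothetical helper (not vLLM code): expand a binding value.

    "auto" and "nobind" are returned unchanged. Any other value is treated
    as one CPU id list per rank, with ranks separated by "|".
    """
    if spec in ("auto", "nobind"):
        return spec

    per_rank: list[list[int]] = []
    for rank_spec in spec.split("|"):
        cpu_ids: list[int] = []
        # Assumption: ids within a rank may be comma-separated ranges.
        for part in rank_spec.split(","):
            if "-" in part:
                start, end = part.split("-")
                cpu_ids.extend(range(int(start), int(end) + 1))
            else:
                cpu_ids.append(int(part))
        per_rank.append(cpu_ids)
    return per_rank


if __name__ == "__main__":
    ranks = parse_omp_threads_bind("0-31|32-63")
    print(len(ranks), len(ranks[0]), ranks[1][0], ranks[1][-1])
```

With `0-31|32-63` the sketch yields two ranks of 32 CPU ids each, matching the two tensor parallel processes described in the documentation entry.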

vllm/platforms/cpu.py

Lines changed: 9 additions & 4 deletions
@@ -14,6 +14,7 @@
 import regex as re
 import torch
 
+from vllm import envs
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

@@ -151,7 +152,6 @@ def get_attn_backend_cls(

     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
-        import vllm.envs as envs
         from vllm.utils.mem_constants import GiB_bytes
 
         kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
@@ -289,11 +289,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
         # Note: to avoid the error 'nthreads cannot be larger than environment
-        # variable "NUMEXPR_MAX_THREADS" (64)'.
+        # variable "NUMEXPR_MAX_THREADS" (64)'.
         os.environ["NUMEXPR_MAX_THREADS"] = str(get_max_threads())
 
-        # Set default threads num for OpenMP parallel
-        os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
+        if envs.VLLM_CPU_OMP_THREADS_BIND != "nobind":
+            # Set default threads num for OpenMP parallel
+            os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
+        else:
+            # In this case, setting the OpenMP configuration via
+            # OMP_NUM_THREADS is up to the user.
+            logger.info("Disabling binding processes to CPU cores...")
 
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
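
The effect of the `OMP_NUM_THREADS` change in this hunk can be illustrated in isolation. The sketch below is not the vLLM implementation: `configure_omp_num_threads` is a hypothetical function, `default_threads` stands in for `torch.get_num_threads()`, and a plain mapping can stand in for `os.environ` so the behaviour is checkable without touching the real environment.

```python
import os
from typing import MutableMapping, Optional


def configure_omp_num_threads(
    bind_mode: str,
    default_threads: int,
    environ: Optional[MutableMapping[str, str]] = None,
) -> Optional[str]:
    """Hypothetical stand-in for the branch added in this commit.

    For any mode other than "nobind", OMP_NUM_THREADS is overwritten with
    the default thread count. For "nobind", whatever the user exported
    (possibly nothing) is left untouched.
    """
    env = os.environ if environ is None else environ
    if bind_mode != "nobind":
        env["OMP_NUM_THREADS"] = str(default_threads)
    return env.get("OMP_NUM_THREADS")


if __name__ == "__main__":
    # A user-provided value survives only in "nobind" mode.
    print(configure_omp_num_threads("nobind", 16, {"OMP_NUM_THREADS": "8"}))
    print(configure_omp_num_threads("auto", 16, {"OMP_NUM_THREADS": "8"}))
```

Under these assumptions, a user who sets `VLLM_CPU_OMP_THREADS_BIND=nobind` without also exporting `OMP_NUM_THREADS` leaves the thread count to the OpenMP runtime's own default.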

vllm/v1/worker/cpu_worker.py

Lines changed: 5 additions & 3 deletions
@@ -69,13 +69,15 @@ def init_device(self):
                 self.local_omp_cpuid = self._get_autobind_cpu_ids(
                     lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]
                 )
-            elif current_platform.get_cpu_architecture() == CpuArchEnum.X86:
+            elif cpu_arch == CpuArchEnum.X86:
                 # For x86 SMT-2, use 1 CPU per core
                 self.local_omp_cpuid = self._get_autobind_cpu_ids(
                     lambda cpus: cpus[-1:]
                 )
             else:
-                self.local_omp_cpuid = "all"
+                self.local_omp_cpuid = "nobind"
+        elif omp_cpuids == "nobind":
+            self.local_omp_cpuid = "nobind"
         else:
             local_dp_rank = self.parallel_config.data_parallel_rank_local
             omp_cpuids = omp_cpuids.split("|")
@@ -86,7 +88,7 @@ def init_device(self):
                 ]
             self.local_omp_cpuid = omp_cpuids[self.rank]
 
-        if self.local_omp_cpuid != "all":
+        if self.local_omp_cpuid != "nobind":
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
             if ret:
                 logger.info(ret)
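
The branch structure in `init_device` after this change can be summarised with a simplified sketch. It is an approximation under stated assumptions, not the worker code: `resolve_local_omp_cpuid` is a hypothetical function, the architecture-specific auto-binding is replaced by a caller-supplied `autobind_result` placeholder, and the data-parallel slicing that sits between the two hunks is omitted.

```python
def resolve_local_omp_cpuid(
    bind_value: str,
    rank: int,
    autobind_result: str = "nobind",
) -> str:
    """Hypothetical simplification of the selection logic in the diff.

    Returns the CPU id string for this rank, or the sentinel "nobind",
    in which case the caller should skip thread-affinity setup.
    """
    if bind_value == "auto":
        # Placeholder for the architecture-dependent auto-binding; the diff
        # falls back to "nobind" for architectures it does not handle.
        return autobind_result
    if bind_value == "nobind":
        return "nobind"
    # Explicit lists: one entry per rank, separated by "|".
    return bind_value.split("|")[rank]


def should_bind(local_omp_cpuid: str) -> bool:
    """Mirror of the final check: bind unless the sentinel is set."""
    return local_omp_cpuid != "nobind"


if __name__ == "__main__":
    for value in ("0-31|32-63", "nobind", "auto"):
        cpuid = resolve_local_omp_cpuid(value, rank=1)
        print(value, cpuid, should_bind(cpuid))
```

The sketch highlights that the commit renames the internal sentinel from `"all"` to `"nobind"`, so the user-facing option and the unsupported-architecture fallback share one code path that skips `init_cpu_threads_env`.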
