 import psutil
 import pytest
 from transformers import AutoModelForCausalLM
-from transformers.testing_utils import require_torch_multi_gpu
+from transformers.testing_utils import require_torch_multi_accelerator, torch_device
 
 from trl.extras.vllm_client import VLLMClient
 from trl.scripts.vllm_serve import chunk_list
 
-from .testing_utils import require_3_gpus
+from .testing_utils import require_3_accelerators
 
 
 class TestChunkList(unittest.TestCase):
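Note: `require_3_accelerators` is defined in the local tests/testing_utils module and is not shown in this diff. As a hedged sketch only (the real helper may be implemented differently), a decorator with that contract could look like the following, counting devices on whichever backend the host exposes:

import unittest

import torch


def available_accelerators() -> int:
    # Count devices on whichever backend is present; the XPU branch is
    # guarded because older torch builds do not ship torch.xpu.
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.device_count()
    return 0


def require_3_accelerators(test_case):
    # Skip-style decorator in the spirit of transformers'
    # require_torch_multi_accelerator, but demanding three visible devices.
    return unittest.skipUnless(
        available_accelerators() >= 3, "test requires at least 3 accelerators"
    )(test_case)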
@@ -55,15 +55,16 @@ def test_any_dtype(self):
 
 
 @pytest.mark.slow
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
 class TestVLLMClientServer(unittest.TestCase):
     model_id = "Qwen/Qwen2.5-1.5B"
 
     @classmethod
     def setUpClass(cls):
-        # We want the server to run on GPU 1, so we set CUDA_VISIBLE_DEVICES to "1"
+        # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
         env = os.environ.copy()
-        env["CUDA_VISIBLE_DEVICES"] = "1"  # Restrict to GPU 1
+        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
+        env[VISIBLE_DEVICES] = "1"  # Restrict to accelerator 1
 
         # Start the server process
         cls.server_process = subprocess.Popen(
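The masking added above generalizes the old CUDA-only logic: NVIDIA GPUs are hidden via CUDA_VISIBLE_DEVICES, while Intel XPUs use Level Zero's ZE_AFFINITY_MASK; both accept a comma-separated list of device indices. A minimal standalone sketch of the same pattern (spawn_on_devices is illustrative, not part of the test file):

import os
import subprocess

from transformers.testing_utils import torch_device


def spawn_on_devices(cmd: list, devices: str) -> subprocess.Popen:
    # Launch cmd with accelerator visibility restricted to the given
    # comma-separated indices, e.g. "1" or "1,2".
    env = os.environ.copy()
    visible = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
    env[visible] = devices
    return subprocess.Popen(cmd, env=env)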
@@ -107,7 +108,7 @@ def test_generate_with_params(self):
         self.assertLessEqual(len(seq), 32)
 
     def test_update_model_params(self):
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
         self.client.update_model_params(model)
 
     def test_reset_prefix_cache(self):
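Here `torch_device` (exported by transformers.testing_utils) resolves to the backend available on the host, e.g. "cuda" or "xpu", so the weights land on whichever accelerator is present instead of being pinned to CUDA:

from transformers import AutoModelForCausalLM
from transformers.testing_utils import torch_device

# device_map accepts a device string, so this loads onto CUDA, XPU, or CPU
# depending on what torch_device resolves to on this machine.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B", device_map=torch_device)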
@@ -132,15 +133,16 @@ def tearDownClass(cls):
 
 
 @pytest.mark.slow
-@require_3_gpus
+@require_3_accelerators
 class TestVLLMClientServerTP(unittest.TestCase):
     model_id = "Qwen/Qwen2.5-1.5B"
 
     @classmethod
     def setUpClass(cls):
-        # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
+        # We want the server to run on accelerators 1 and 2, so we set VISIBLE_DEVICES to "1,2"
         env = os.environ.copy()
-        env["CUDA_VISIBLE_DEVICES"] = "1,2"  # Restrict to GPU 1 and 2
+        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
+        env[VISIBLE_DEVICES] = "1,2"  # Restrict to accelerators 1 and 2
 
         # Start the server process
         cls.server_process = subprocess.Popen(
@@ -169,7 +171,7 @@ def test_generate(self):
         self.assertTrue(all(isinstance(tok, int) for tok in seq))
 
     def test_update_model_params(self):
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
         self.client.update_model_params(model)
 
     def test_reset_prefix_cache(self):
@@ -194,15 +196,16 @@ def tearDownClass(cls):
 
 
 @pytest.mark.slow
-@require_3_gpus
+@require_3_accelerators
 class TestVLLMClientServerDP(unittest.TestCase):
     model_id = "Qwen/Qwen2.5-1.5B"
 
     @classmethod
     def setUpClass(cls):
-        # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
+        # We want the server to run on accelerators 1 and 2, so we set VISIBLE_DEVICES to "1,2"
         env = os.environ.copy()
-        env["CUDA_VISIBLE_DEVICES"] = "1,2"  # Restrict to GPU 1 and 2
+        VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
+        env[VISIBLE_DEVICES] = "1,2"  # Restrict to accelerators 1 and 2
 
         # Start the server process
         cls.server_process = subprocess.Popen(
@@ -230,7 +233,7 @@ def test_generate(self):
         self.assertTrue(all(isinstance(tok, int) for tok in seq))
 
     def test_update_model_params(self):
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
         self.client.update_model_params(model)
 
     def test_reset_prefix_cache(self):