
Commit 64aa064

yao-matrix and kashif authored

enable activation offloading on XPU (#3444)

Signed-off-by: Matrix Yao <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>

1 parent be93a0c commit 64aa064

2 files changed: +31 additions, -25 deletions
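The diff applies one pattern across both files: every hard-coded CUDA call is replaced with a lookup on the active accelerator. A minimal sketch of that pattern, assuming a PyTorch build recent enough to ship `torch.accelerator` (the fallback mirrors the commit's own `hasattr` guard):

import torch

# Resolve the active accelerator type ("cuda", "xpu", ...); older PyTorch
# builds without torch.accelerator fall back to "cuda" as before.
accelerator_type = (
    torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
)

# Device-agnostic placement replaces the old hard-coded .cuda() calls
x = torch.randn(2, 5, device=accelerator_type)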

tests/test_activation_offloading.py

Lines changed: 9 additions & 9 deletions

@@ -17,7 +17,7 @@
 import torch
 from torch import nn
 from transformers import AutoModelForCausalLM
-from transformers.testing_utils import require_peft, require_torch_accelerator
+from transformers.testing_utils import require_peft, require_torch_accelerator, torch_device
 from transformers.utils import is_peft_available

 from trl.models.activation_offloading import NoOpManager, OffloadActivations
@@ -33,7 +33,7 @@ class TestActivationOffloading(unittest.TestCase):
     def test_offloading_with_peft_models(self) -> None:
         """Test that activation offloading works with PEFT models."""
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
-        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
         peft_config = LoraConfig(
             lora_alpha=16,
             lora_dropout=0.1,
@@ -43,7 +43,7 @@ def test_offloading_with_peft_models(self) -> None:
         )

         model = get_peft_model(model, peft_config)
-        inp = torch.randint(0, 100, (2, 10), device="cuda")
+        inp = torch.randint(0, 100, (2, 10), device=torch_device)

         # First forward-backward pass without offloading
         torch.manual_seed(42)
@@ -79,8 +79,8 @@ def test_offloading_with_peft_models(self) -> None:
     @require_torch_accelerator
     def test_noop_manager_with_offloading(self):
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
-        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
-        inp = torch.randint(0, 100, (2, 10), device="cuda")
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        inp = torch.randint(0, 100, (2, 10), device=torch_device)

         # Run with offloading but disable for specific section
         with OffloadActivations():
@@ -112,9 +112,9 @@ def test_min_offload_size(self):
         model = nn.Sequential(
             nn.Linear(5, 5),  # Small layer that shouldn't be offloaded
             nn.Linear(5, 1000),  # Large layer that should be offloaded
-        ).cuda()
+        ).to(torch_device)

-        inp = torch.randn(2, 5, device="cuda")
+        inp = torch.randn(2, 5, device=torch_device)

         with OffloadActivations(min_offload_size=1000):
             out = model(inp)
@@ -127,10 +127,10 @@ def test_min_offload_size(self):
     def test_real_hf_model(self):
         """Test with an actual HuggingFace model"""
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
-        model = AutoModelForCausalLM.from_pretrained(model_id).cuda()
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)

         # Create small input
-        inp = torch.randint(0, 100, (2, 10), device="cuda")
+        inp = torch.randint(0, 100, (2, 10), device=torch_device)

         # Baseline without offloading
         torch.manual_seed(42)
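`torch_device` from `transformers.testing_utils` is a plain device string resolved from whatever hardware the test suite runs on, which is what lets these tests cover CUDA and XPU without branching. A short sketch of the resulting test pattern (same tiny checkpoint as above, assuming an accelerator is present):

import torch
from transformers import AutoModelForCausalLM
from transformers.testing_utils import torch_device

# torch_device resolves to "cuda", "xpu", or "cpu" depending on the machine
model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
).to(torch_device)
inp = torch.randint(0, 100, (2, 10), device=torch_device)
loss = model(inp, labels=inp).loss
loss.backward()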

trl/models/activation_offloading.py

Lines changed: 22 additions & 16 deletions

@@ -85,11 +85,17 @@ def __init__(
         self.use_pin_memory = use_pin_memory
         self.virtual_memory_safe_pct = 60  # we should not exceed this percentage of memory

-        self.s0 = torch.cuda.default_stream()  # comp stream
+        self.accelerator_type = (
+            torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+        )
+        # NOTE: xpu doesn't have `default_stream` API, use `current_stream` instead
+        self.s0 = (
+            torch.xpu.current_stream() if self.accelerator_type == "xpu" else torch.cuda.default_stream()
+        )  # comp stream

         # For streaming
         if self.use_streams:
-            self.s1 = torch.cuda.Stream()  # comms stream
+            self.s1 = torch.Stream() if self.accelerator_type == "xpu" else torch.cuda.Stream()  # comms stream
             self.fwd_stash = {}  # tensor_id => (activation, ev1)
             if max_fwd_stash_size < 1:
                 raise ValueError(f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}")
@@ -136,7 +142,7 @@ def pack_tensor(activation: torch.Tensor) -> int:
             # only offload hefty bois if they're activations on CUDA (our heuristic
             # for that is to check if they're not params or buffers)!
             if (
-                activation.is_cuda
+                activation.device.type in ["cuda", "xpu"]
                 and num_bytes >= self.min_tensor_size_bytes
                 and (
                     not isinstance(activation, torch.nn.Parameter)
@@ -158,7 +164,7 @@ def pack_tensor(activation: torch.Tensor) -> int:
                     self.s1.wait_stream(self.s0)

                 stream = self.s1 if self.use_streams else self.s0
-                with torch.cuda.stream(stream):
+                with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream):
                     cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
                     cpu_tensor.copy_(activation, non_blocking=True)
                     self.tracker[tensor_id] = (
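Stripped of the tracker bookkeeping, the pack path above is an async device-to-host copy into pinned memory, and the unpack paths below reverse it with a non-blocking copy back to the accelerator. A minimal round-trip sketch of that mechanism (illustrative only, not the class's own API; `torch.accelerator.synchronize` assumes a recent PyTorch):

import torch

device = torch.accelerator.current_accelerator().type
activation = torch.randn(2, 1024, device=device)

# Offload: pinned host memory is what makes the non-blocking copy truly async
cpu_tensor = torch.empty_like(activation, pin_memory=True, device="cpu")
cpu_tensor.copy_(activation, non_blocking=True)

# Bring back: async host-to-device copy, then synchronize before use
restored = cpu_tensor.to(device, non_blocking=True)
torch.accelerator.synchronize()
assert torch.equal(restored, activation)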
@@ -194,14 +200,14 @@ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
             if unpack_tensor_id not in self.tracker:
                 raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")

-            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
+            maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
             if modified:
-                gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
-                maybe_gpu_tensor = gpu_tensor
+                accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
+                maybe_accelerator_tensor = accelerator_tensor

             # clear tensor from tracking
             del self.tracker[unpack_tensor_id]
-            return maybe_gpu_tensor
+            return maybe_accelerator_tensor

         def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
             # backward pass - we are called with the tensor_id, which
@@ -229,7 +235,7 @@ def wait_and_del_remaining_references() -> None:
             if unpack_tensor_id not in self.tracker:
                 raise ValueError(f"untracked tensor with id {unpack_tensor_id}")

-            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
+            maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
             if modified:
                 # Get data on the current autograd node
                 graph_id = torch._C._current_graph_task_id()
@@ -243,19 +249,19 @@ def wait_and_del_remaining_references() -> None:

                 brought_back_from_cpu = True
                 if unpack_tensor_id in self.fwd_stash:
-                    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
+                    maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0]
                     brought_back_from_cpu = False
                 else:
                     # Kick off the process to bring tensors back
-                    with torch.cuda.stream(self.s1):
-                        gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
-                        maybe_gpu_tensor = gpu_tensor
+                    with self.s1 if self.accelerator_type == "xpu" else torch.cuda.stream(self.s1):
+                        accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
+                        maybe_accelerator_tensor = accelerator_tensor

                     # Tell comp stream to wait for the info to be loaded before executing
                     self.s0.wait_stream(self.s1)

                 # Stash the tensor to keep memory alive until compute stream is complete
-                self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor
+                self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor

             # Note: [Track views of the unpacked]
             # Why do we get the use count of the unpacked tensor here? We want an
@@ -270,7 +276,7 @@ def wait_and_del_remaining_references() -> None:
             # up as a view of the unpacked tensor.
             # 3. The user abuses the system somehow and manually relies on the
             # unpacked tensor to exist after the backward node has executed.
-            storage_refcount = torch._C._storage_Use_Count(maybe_gpu_tensor.untyped_storage()._cdata)
+            storage_refcount = torch._C._storage_Use_Count(maybe_accelerator_tensor.untyped_storage()._cdata)

             def hook(outputs, inputs):
                 # create events for the current node inputs/outputs if they were streamed in
@@ -312,7 +318,7 @@ def hook(outputs, inputs):

                 # clear tensor from tracking
                 del self.tracker[unpack_tensor_id]
-                return maybe_gpu_tensor
+                return maybe_accelerator_tensor

         unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
         super().__init__(pack_tensor, unpack_tensor)
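End to end, the tests above double as usage documentation. A rough sketch of the resulting device-agnostic flow (assuming an available accelerator, with the same tiny test checkpoint):

import torch
from transformers import AutoModelForCausalLM
from trl.models.activation_offloading import NoOpManager, OffloadActivations

device = torch.accelerator.current_accelerator().type  # "cuda" or "xpu"
model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
).to(device)
inp = torch.randint(0, 100, (2, 10), device=device)

# Offload activations above min_offload_size bytes to pinned CPU memory
with OffloadActivations(min_offload_size=1000):
    # A NoOpManager section opts out of offloading for the code it wraps
    with NoOpManager():
        _ = model(inp)
    loss = model(inp).logits.sum()
    loss.backward()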
