@@ -85,11 +85,17 @@ def __init__(
         self.use_pin_memory = use_pin_memory
         self.virtual_memory_safe_pct = 60  # we should not exceed this percentage of memory
 
-        self.s0 = torch.cuda.default_stream()  # comp stream
+        self.accelerator_type = (
+            torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+        )
+        # NOTE: xpu doesn't have a `default_stream` API, so use `current_stream` instead
+        self.s0 = (
+            torch.xpu.current_stream() if self.accelerator_type == "xpu" else torch.cuda.default_stream()
+        )  # comp stream
 
         # For streaming
         if self.use_streams:
-            self.s1 = torch.cuda.Stream()  # comms stream
+            self.s1 = torch.Stream() if self.accelerator_type == "xpu" else torch.cuda.Stream()  # comms stream
             self.fwd_stash = {}  # tensor_id => (activation, ev1)
             if max_fwd_stash_size < 1:
                 raise ValueError(f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}")
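Note for context: `torch.accelerator` is the device-generic namespace added in recent PyTorch releases, and `torch.Stream` is the device-generic stream type. A minimal sketch of what the branching above resolves to on different builds (the function name is illustrative only, not part of this patch):

```python
import torch

def pick_streams():
    # Fall back to CUDA on older PyTorch builds that predate torch.accelerator.
    device_type = (
        torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
    )
    if device_type == "xpu":
        # XPU exposes no default_stream(); current_stream() is the closest analogue.
        comp = torch.xpu.current_stream()
        comms = torch.Stream()  # device-generic stream on the current accelerator
    else:
        comp = torch.cuda.default_stream()
        comms = torch.cuda.Stream()
    return comp, comms
```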
@@ -136,7 +142,7 @@ def pack_tensor(activation: torch.Tensor) -> int:
-            # only offload hefty bois if they're activations on CUDA (our heuristic
-            # for that is to check if they're not params or buffers)!
+            # only offload hefty bois if they're activations on CUDA or XPU (our
+            # heuristic for that is to check if they're not params or buffers)!
             if (
-                activation.is_cuda
+                activation.device.type in ["cuda", "xpu"]
                 and num_bytes >= self.min_tensor_size_bytes
                 and (
                     not isinstance(activation, torch.nn.Parameter)
@@ -158,7 +164,7 @@ def pack_tensor(activation: torch.Tensor) -> int:
                 self.s1.wait_stream(self.s0)
 
             stream = self.s1 if self.use_streams else self.s0
-            with torch.cuda.stream(stream):
+            with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream):
                 cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
                 cpu_tensor.copy_(activation, non_blocking=True)
                 self.tracker[tensor_id] = (
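A side note on the copy itself: `non_blocking=True` only yields a truly asynchronous device-to-host transfer when the destination is pinned, which is why `use_pin_memory` feeds into `torch.empty_like` here. A standalone sketch of the pattern with illustrative names (the context-manager branch mirrors the diff: device-generic `torch.Stream` objects are themselves context managers, while `torch.cuda.Stream` needs the `torch.cuda.stream()` wrapper):

```python
import torch

def offload_to_cpu(activation: torch.Tensor, comms_stream) -> torch.Tensor:
    # Enqueue the copy on the comms stream so it can overlap with compute.
    ctx = comms_stream if activation.device.type == "xpu" else torch.cuda.stream(comms_stream)
    with ctx:
        # Pinned memory makes non_blocking=True genuinely async; with pageable
        # memory the copy silently degrades to a synchronous one.
        cpu_tensor = torch.empty_like(activation, pin_memory=True, device="cpu")
        cpu_tensor.copy_(activation, non_blocking=True)
    return cpu_tensor
```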
@@ -194,14 +200,14 @@ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
             if unpack_tensor_id not in self.tracker:
                 raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")
 
-            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
+            maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
             if modified:
-                gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
-                maybe_gpu_tensor = gpu_tensor
+                accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
+                maybe_accelerator_tensor = accelerator_tensor
 
             # clear tensor from tracking
             del self.tracker[unpack_tensor_id]
-            return maybe_gpu_tensor
+            return maybe_accelerator_tensor
 
         def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
             # backward pass - we are called with the tensor_id, which
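One subtlety the single-stream unpack relies on: `Tensor.to()` accepts a bare device-type string, so `.to(self.accelerator_type, ...)` moves the tensor to the current device of that type without hardcoding an index. A quick illustration (assumes a CUDA or XPU build):

```python
import torch

device_type = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"
t = torch.ones(4)  # a CPU tensor, standing in for one restored from the tracker
# A bare device-type string is a valid device argument; the index defaults to
# the current device of that accelerator, which is what the .to() above uses.
on_device = t.to(device_type, non_blocking=True)
```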
@@ -229,7 +235,7 @@ def wait_and_del_remaining_references() -> None:
             if unpack_tensor_id not in self.tracker:
                 raise ValueError(f"untracked tensor with id {unpack_tensor_id}")
 
-            maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id]
+            maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
             if modified:
                 # Get data on the current autograd node
                 graph_id = torch._C._current_graph_task_id()
@@ -243,19 +249,19 @@ def wait_and_del_remaining_references() -> None:
 
                 brought_back_from_cpu = True
                 if unpack_tensor_id in self.fwd_stash:
-                    maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0]
+                    maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0]
                     brought_back_from_cpu = False
                 else:
                     # Kick off the process to bring tensors back
-                    with torch.cuda.stream(self.s1):
-                        gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
-                        maybe_gpu_tensor = gpu_tensor
+                    with self.s1 if self.accelerator_type == "xpu" else torch.cuda.stream(self.s1):
+                        accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
+                        maybe_accelerator_tensor = accelerator_tensor
 
                     # Tell comp stream to wait for the info to be loaded before executing
                     self.s0.wait_stream(self.s1)
 
                 # Stash the tensor to keep memory alive until compute stream is complete
-                self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor
+                self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor
 
                 # Note: [Track views of the unpacked]
                 # Why do we get the use count of the unpacked tensor here? We want an
@@ -270,7 +276,7 @@ def wait_and_del_remaining_references() -> None:
                 # up as a view of the unpacked tensor.
                 # 3. The user abuses the system somehow and manually relies on the
                 #    unpacked tensor to exist after the backward node has executed.
-                storage_refcount = torch._C._storage_Use_Count(maybe_gpu_tensor.untyped_storage()._cdata)
+                storage_refcount = torch._C._storage_Use_Count(maybe_accelerator_tensor.untyped_storage()._cdata)
 
                 def hook(outputs, inputs):
                     # create events for the current node inputs/outputs if they were streamed in
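For readers following the event plumbing in `hook`: the idea is that the comms stream records an event once a copy is enqueued, and the compute stream waits on that event rather than blocking the host. A device-generic sketch of that handshake, assuming the generic `torch.Event`/`torch.Stream` APIs (a CUDA-only build would use `torch.cuda.Event` equivalently; the function name is ours):

```python
import torch

def fence_copy(comms_stream: torch.Stream, comp_stream: torch.Stream) -> None:
    # Mark the point on the comms stream where the copy was enqueued...
    ev = torch.Event()
    ev.record(comms_stream)
    # ...and make the compute stream wait for it, without a host-side sync.
    comp_stream.wait_event(ev)
```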
@@ -312,7 +318,7 @@ def hook(outputs, inputs):
 
             # clear tensor from tracking
             del self.tracker[unpack_tensor_id]
-            return maybe_gpu_tensor
+            return maybe_accelerator_tensor
 
         unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
         super().__init__(pack_tensor, unpack_tensor)
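Since the constructor ultimately hands `pack_tensor`/`unpack_tensor` to `torch.autograd.graph.saved_tensors_hooks` via `super().__init__`, the user-facing contract is unchanged by this patch; only the device plumbing is generalized. A hedged usage sketch, assuming the enclosing class is the activation-offloading context manager (class name inferred from context; adjust to the actual export):

```python
import torch

model = torch.nn.Linear(1024, 1024, device="cuda")
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

# Hypothetical usage: activations saved for backward inside the context are
# packed to (pinned) CPU memory and streamed back in during backward.
with OffloadActivations(use_streams=True):
    loss = model(x).sum()
loss.backward()
```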