Skip to content

Commit a6a3a6d

Browse files
committed
Add temporary shim functions.
1 parent 12e9577 commit a6a3a6d

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

src/libtorchaudio/cuda_utils.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <libtorchaudio/shim_temporary.h>
34
#include <torch/csrc/stable/c/shim.h>
45
#include <torch/csrc/stable/device.h>
56

@@ -20,24 +21,24 @@ inline cudaStream_t getCurrentCUDAStream(
2021
inline void setCurrentCUDAStream(
2122
cudaStream_t stream,
2223
torch::stable::DeviceIndex device_index = -1) {
23-
TORCH_ERROR_CODE_CHECK(
24-
torch_set_current_cuda_stream(static_cast<void*>(stream), device_index));
24+
TORCH_ERROR_CODE_CHECK(tmp_torch_set_current_cuda_stream(
25+
static_cast<void*>(stream), device_index));
2526
}
2627

2728
inline cudaStream_t getStreamFromPool(
2829
const bool isHighPriority = false,
2930
torch::stable::DeviceIndex device_index = -1) {
3031
void* stream_ptr = nullptr;
31-
TORCH_ERROR_CODE_CHECK(torch_get_cuda_stream_from_pool(
32+
TORCH_ERROR_CODE_CHECK(tmp_torch_get_cuda_stream_from_pool(
3233
isHighPriority, device_index, &stream_ptr));
3334
return static_cast<cudaStream_t>(stream_ptr);
3435
}
3536

3637
inline void synchronize(
3738
cudaStream_t stream,
3839
torch::stable::DeviceIndex device_index = -1) {
39-
TORCH_ERROR_CODE_CHECK(
40-
torch_cuda_stream_synchronize(static_cast<void*>(stream), device_index));
40+
TORCH_ERROR_CODE_CHECK(tmp_torch_cuda_stream_synchronize(
41+
static_cast<void*>(stream), device_index));
4142
}
4243

4344
} // namespace libtorchaudio::cuda

src/libtorchaudio/shim_temporary.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#pragma once
2+
// TODO: remove this file once https://github.com/pytorch/pytorch/pull/169376
3+
// has landed.
4+
5+
#include <c10/cuda/CUDAStream.h>
6+
#include <torch/csrc/inductor/aoti_torch/utils.h>
7+
#include <torch/csrc/stable/c/shim.h>
8+
9+
inline AOTITorchError tmp_torch_set_current_cuda_stream(
10+
void* stream,
11+
int32_t device_index) {
12+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
13+
at::cuda::setCurrentCUDAStream(at::cuda::getStreamFromExternal(
14+
static_cast<cudaStream_t>(stream), device_index));
15+
});
16+
}
17+
18+
inline AOTITorchError tmp_torch_get_cuda_stream_from_pool(
19+
const bool isHighPriority,
20+
int32_t device_index,
21+
void** ret_stream) {
22+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
23+
*(cudaStream_t*)(ret_stream) =
24+
at::cuda::getStreamFromPool(isHighPriority, device_index);
25+
});
26+
}
27+
28+
inline AOTITorchError tmp_torch_cuda_stream_synchronize(
29+
void* stream,
30+
int32_t device_index) {
31+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
32+
at::cuda::getStreamFromExternal(
33+
static_cast<cudaStream_t>(stream), device_index)
34+
.synchronize();
35+
});
36+
}

0 commit comments

Comments
 (0)