Skip to content

Commit 12e9577

Browse files
committed
Eliminate ATen/cuda/CUDAContext.h and c10/cuda/CUDAGuard.h
1 parent 7049cae commit 12e9577

File tree

3 files changed

+38
-16
lines changed

3 files changed

+38
-16
lines changed

src/libtorchaudio/cuda_utils.h

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
3+
#include <torch/csrc/stable/c/shim.h>
44
#include <torch/csrc/stable/device.h>
55

66
#include <cuda_runtime_api.h>
@@ -17,9 +17,27 @@ inline cudaStream_t getCurrentCUDAStream(
1717
return static_cast<cudaStream_t>(stream_ptr);
1818
}
1919

20-
// A strip-down version of at::cuda::stream_synchronize
21-
inline void stream_synchronize(cudaStream_t stream) {
22-
TA_CUDA_CHECK(cudaStreamSynchronize(stream));
20+
inline void setCurrentCUDAStream(
21+
cudaStream_t stream,
22+
torch::stable::DeviceIndex device_index = -1) {
23+
TORCH_ERROR_CODE_CHECK(
24+
torch_set_current_cuda_stream(static_cast<void*>(stream), device_index));
25+
}
26+
27+
inline cudaStream_t getStreamFromPool(
28+
const bool isHighPriority = false,
29+
torch::stable::DeviceIndex device_index = -1) {
30+
void* stream_ptr = nullptr;
31+
TORCH_ERROR_CODE_CHECK(torch_get_cuda_stream_from_pool(
32+
isHighPriority, device_index, &stream_ptr));
33+
return static_cast<cudaStream_t>(stream_ptr);
34+
}
35+
36+
inline void synchronize(
37+
cudaStream_t stream,
38+
torch::stable::DeviceIndex device_index = -1) {
39+
TORCH_ERROR_CODE_CHECK(
40+
torch_cuda_stream_synchronize(static_cast<void*>(stream), device_index));
2341
}
2442

2543
} // namespace libtorchaudio::cuda

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
#include <libtorchaudio/cuda_utils.h>
12
#include <libtorchaudio/utils.h>
23
#include <torch/csrc/stable/library.h>
34
#include <torch/headeronly/core/Dispatch_v2.h>
45
#include <torch/headeronly/core/ScalarType.h>
5-
#include <ATen/cuda/CUDAContext.h>
6+
#include <c10/cuda/CUDAException.h>
67

78
#include <cub/cub.cuh>
89
#include <limits.h>
@@ -120,8 +121,9 @@ void forced_align_impl(
120121
const Tensor& targets,
121122
const int64_t blank,
122123
Tensor& paths) {
123-
auto defaultStream = at::cuda::getCurrentCUDAStream();
124-
auto cpuDataTranferStream = at::cuda::getStreamFromPool();
124+
auto device_index = logProbs.get_device_index();
125+
auto defaultStream = libtorchaudio::cuda::getCurrentCUDAStream(device_index);
126+
auto cpuDataTranferStream = libtorchaudio::cuda::getStreamFromPool(false, device_index);
125127
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
126128
using target_t = typename std::
127129
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
@@ -208,12 +210,14 @@ void forced_align_impl(
208210
C10_CUDA_KERNEL_LAUNCH_CHECK();
209211
++backPtrBufferLen;
210212
if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
211-
cpuDataTranferStream.synchronize();
213+
//cpuDataTranferStream.synchronize();
214+
libtorchaudio::cuda::synchronize(cpuDataTranferStream, device_index);
212215
// GPU -> GPU copy
213216
bufferCopy = torch::stable::clone(backPtrBuffer);
214217
STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()")
215-
defaultStream.synchronize();
216-
at::cuda::setCurrentCUDAStream(cpuDataTranferStream);
218+
//defaultStream.synchronize();
219+
libtorchaudio::cuda::synchronize(defaultStream, device_index);
220+
libtorchaudio::cuda::setCurrentCUDAStream(cpuDataTranferStream, device_index);
217221
// Copy ASYNC from GPU to CPU
218222
int64_t offset =
219223
static_cast<int64_t>(t + 1 - backPtrBufferLen) * S * sizeof(int8_t);
@@ -223,11 +227,12 @@ void forced_align_impl(
223227
backPtrBufferLen * S * sizeof(int8_t),
224228
cudaMemcpyDeviceToHost,
225229
cpuDataTranferStream));
226-
at::cuda::setCurrentCUDAStream(defaultStream);
230+
libtorchaudio::cuda::setCurrentCUDAStream(defaultStream, device_index);
227231
backPtrBufferLen = 0;
228232
}
229233
}
230-
cpuDataTranferStream.synchronize();
234+
//cpuDataTranferStream.synchronize();
235+
libtorchaudio::cuda::synchronize(cpuDataTranferStream, device_index);
231236
auto alphasCpu = torchaudio::stable::cpu(alphas);
232237
auto alphasCpu_a = torchaudio::accessor<scalar_t, 2>(alphasCpu);
233238
int curIdxOffset = ((T - 1) % 2);

src/libtorchaudio/iir_cuda.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include <libtorchaudio/utils.h>
2+
#include <torch/csrc/stable/accelerator.h>
23
#include <torch/headeronly/core/Dispatch_v2.h>
34
#include <torch/headeronly/core/ScalarType.h>
4-
#include <c10/cuda/CUDAGuard.h>
5-
#include <c10/core/DeviceGuard.h>
5+
#include <c10/cuda/CUDAException.h>
66

77
using torch::headeronly::ScalarType;
88
using torch::stable::Tensor;
@@ -64,8 +64,7 @@ Tensor cuda_lfilter_core_loop(
6464

6565
STD_TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2));
6666

67-
const at::cuda::OptionalCUDAGuard device_guard(in.get_device_index());
68-
67+
const torch::stable::accelerator::DeviceGuard device_guard(in.get_device_index());
6968
const dim3 threads(256);
7069
const dim3 blocks((N * C + threads.x - 1) / threads.x);
7170

0 commit comments

Comments (0)