Skip to content

Commit 54b64f9

Browse files
committed
Eliminate c10/cuda/CUDAException.h
1 parent d54b203 commit 54b64f9

File tree

4 files changed

+6
-10
lines changed

4 files changed

+6
-10
lines changed

src/libtorchaudio/cuda_utils.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55

66
#include <cuda_runtime_api.h>
77

8-
// TODO: replace TA_CUDA_CHECK with STD_CUDA_CHECK after
9-
// https://github.com/pytorch/pytorch/pull/169385 has landed.
10-
#define TA_CUDA_CHECK(...) __VA_ARGS__
11-
128
namespace libtorchaudio::cuda {
139

1410
inline cudaStream_t getCurrentCUDAStream(

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#include <libtorchaudio/cuda_utils.h>
22
#include <libtorchaudio/utils.h>
33
#include <torch/csrc/stable/library.h>
4+
#include <torch/csrc/stable/macros.h>
45
#include <torch/headeronly/core/Dispatch_v2.h>
56
#include <torch/headeronly/core/ScalarType.h>
6-
#include <c10/cuda/CUDAException.h>
77

88
#include <cub/cub.cuh>
99
#include <limits.h>
@@ -207,7 +207,7 @@ void forced_align_impl(
207207
backPtrBufferLen,
208208
torchaudio::packed_accessor32<scalar_t, 2>(alphas),
209209
torchaudio::packed_accessor32<int8_t, 2>(backPtrBuffer));
210-
C10_CUDA_KERNEL_LAUNCH_CHECK();
210+
STD_CUDA_KERNEL_LAUNCH_CHECK();
211211
++backPtrBufferLen;
212212
if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
213213
libtorchaudio::cuda::synchronize(cpuDataTranferStream, device_index);
@@ -219,7 +219,7 @@ void forced_align_impl(
219219
// Copy ASYNC from GPU to CPU
220220
int64_t offset =
221221
static_cast<int64_t>(t + 1 - backPtrBufferLen) * S * sizeof(int8_t);
222-
C10_CUDA_CHECK(cudaMemcpyAsync(
222+
STD_CUDA_CHECK(cudaMemcpyAsync(
223223
static_cast<int8_t*>(backPtrCpu.data_ptr()) + offset,
224224
bufferCopy.data_ptr(),
225225
backPtrBufferLen * S * sizeof(int8_t),

src/libtorchaudio/iir_cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include <libtorchaudio/utils.h>
22
#include <torch/csrc/stable/accelerator.h>
3+
#include <torch/csrc/stable/macros.h>
34
#include <torch/headeronly/core/Dispatch_v2.h>
45
#include <torch/headeronly/core/ScalarType.h>
5-
#include <c10/cuda/CUDAException.h>
66

77
using torch::headeronly::ScalarType;
88
using torch::stable::Tensor;
@@ -74,7 +74,7 @@ Tensor cuda_lfilter_core_loop(
7474
torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
7575
torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
7676
torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out)));
77-
C10_CUDA_KERNEL_LAUNCH_CHECK();
77+
STD_CUDA_KERNEL_LAUNCH_CHECK();
7878
}), AT_FLOATING_TYPES);
7979
return padded_out;
8080
}

src/libtorchaudio/stable/ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ inline Tensor cuda(const Tensor& self, int32_t cuda_index) {
106106
inline Tensor new_zeros(
107107
const Tensor& self,
108108
std::vector<int64_t> size,
109-
std::optional<c10::ScalarType> dtype = std::nullopt,
109+
std::optional<torch::headeronly::ScalarType> dtype = std::nullopt,
110110
std::optional<Layout> layout = std::nullopt,
111111
std::optional<Device> device = std::nullopt,
112112
std::optional<bool> pin_memory = std::nullopt) {

0 commit comments

Comments (0)