Skip to content

Commit 85f4ce5

Browse files
authored
[STABLE ABI] Eliminate Device.h (#4145)
1 parent f1a2a37 commit 85f4ce5

File tree

3 files changed

+6
-51
lines changed

3 files changed

+6
-51
lines changed

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ std::tuple<Tensor, Tensor> compute(
294294
auto B = logProbs.size(0);
295295
auto T = logProbs.size(1); // num frames
296296

297-
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torchaudio::stable::cpu_device());
297+
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torch::stable::DeviceType::CPU);
298298

299299
THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
300300
if (targets.scalar_type() == ScalarType::Long) {

src/libtorchaudio/stable/Device.h

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/libtorchaudio/stable/ops.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
features implemented here.
1010
*/
1111

12-
#include <libtorchaudio/stable/Device.h>
1312
#include <torch/csrc/stable/ops.h>
13+
#include <torch/csrc/stable/tensor.h>
1414

1515
#ifdef USE_CUDA
1616
#include <ATen/cuda/CUDAContext.h>
@@ -113,7 +113,7 @@ inline Tensor new_zeros(
113113
std::vector<int64_t> size,
114114
std::optional<c10::ScalarType> dtype = std::nullopt,
115115
std::optional<Layout> layout = std::nullopt,
116-
std::optional<Device> device = std::nullopt,
116+
std::optional<torch::stable::Device> device = std::nullopt,
117117
std::optional<bool> pin_memory = std::nullopt) {
118118
int32_t target_dtype{};
119119
if (dtype.has_value()) {
@@ -130,11 +130,11 @@ inline Tensor new_zeros(
130130
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout_));
131131
}
132132

133-
DeviceType device_type;
134-
DeviceIndex device_index = 0;
133+
int32_t device_type;
134+
torch::stable::DeviceIndex device_index = 0;
135135
if (device.has_value()) {
136136
auto device_ = device.value();
137-
device_type = device_.type();
137+
device_type = static_cast<int32_t>(device_.type());
138138
device_index = device_.index();
139139
} else {
140140
TORCH_ERROR_CODE_CHECK(

0 commit comments

Comments
 (0)