Skip to content

Commit 1335f78

Browse files
committed
[STABLE ABI] Port overdrive
1 parent ee1a135 commit 1335f78

File tree

1 file changed

+63
-37
lines changed

1 file changed

+63
-37
lines changed

src/libtorchaudio/overdrive.cpp

Lines changed: 63 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -1,52 +1,78 @@
1-
#include <torch/script.h>
2-
#include <torch/torch.h>
1+
#include <torch/csrc/stable/library.h>
2+
#include <torch/csrc/stable/ops.h>
3+
#include <torch/csrc/stable/tensor.h>
4+
#include <torch/headeronly/core/Dispatch_v2.h>
5+
#include <torch/headeronly/core/TensorAccessor.h>
36

47
namespace {
58

9+
using torch::stable::Tensor;
10+
11+
template <typename T, size_t N>
12+
using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
13+
14+
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
15+
// after Tensor::accessor is supported in stable ABI
16+
// Builds a HeaderOnlyTensorAccessor<T, N> over `t` from its raw data pointer,
// its sizes and its strides.
//
// No dtype or rank validation is visible here: the data pointer is
// reinterpret_cast to T* as-is, so the caller is responsible for passing a T
// and N that match the tensor.
//
// NOTE(review): the returned accessor keeps raw pointers obtained from `t`,
// which is a by-value parameter. Presumably those pointers stay valid for as
// long as the caller's tensor is alive -- confirm against the stable-ABI
// Tensor lifetime rules.
template <typename T, size_t N>
17+
inline TensorAccessor<T, N> accessor(Tensor t) {
18+
return TensorAccessor<T, N>(
19+
reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
20+
}
21+
622
// Core loop of the overdrive op for one scalar type.
//
// Dimension 0 of `waveform_accessor` is iterated as channels and dimension 1
// as frames. For every channel c and every frame f, in increasing f:
//   last_out[c]  = temp[c][f] - last_in[c] + 0.995 * last_out[c]
//   last_in[c]   = temp[c][f]
//   output[c][f] = waveform[c][f] * 0.5 + last_out[c] * 0.75
//
// Channels are distributed over parallel_for with a grain size of 1. Frames
// within a channel are processed sequentially because each step reads the
// last_in/last_out values written by the previous step. `last_in_accessor`,
// `last_out_accessor` and `output_waveform_accessor` are written in place.
//
// Diff view: lines preceded by an `N-` marker are the removed version
// (at::TensorAccessor parameters, at::parallel_for); lines preceded by an
// `N+` marker are the replacement (file-local TensorAccessor alias,
// torch::stable::parallel_for). The arithmetic is the same in both.
template <typename scalar_t>
723
void overdrive_cpu_kernel(
8-
at::TensorAccessor<scalar_t, 2> waveform_accessor,
9-
at::TensorAccessor<scalar_t, 2> temp_accessor,
10-
at::TensorAccessor<scalar_t, 1> last_in_accessor,
11-
at::TensorAccessor<scalar_t, 1> last_out_accessor,
12-
at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
24+
TensorAccessor<scalar_t, 2> waveform_accessor,
25+
TensorAccessor<scalar_t, 2> temp_accessor,
26+
TensorAccessor<scalar_t, 1> last_in_accessor,
27+
TensorAccessor<scalar_t, 1> last_out_accessor,
28+
TensorAccessor<scalar_t, 2> output_waveform_accessor) {
1329
int64_t n_frames = waveform_accessor.size(1);
1430
int64_t n_channels = waveform_accessor.size(0);
1531

16-
at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) {
17-
for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
18-
for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
19-
last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
20-
last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
21-
last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
22-
output_waveform_accessor[i_channel][i_frame] =
23-
waveform_accessor[i_channel][i_frame] * 0.5 +
24-
last_out_accessor[i_channel] * 0.75;
25-
}
26-
}
27-
});
32+
torch::stable::parallel_for(
33+
0, n_channels, 1, [&](int64_t begin, int64_t end) {
34+
for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
35+
for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
36+
last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
37+
last_in_accessor[i_channel] +
38+
0.995 * last_out_accessor[i_channel];
39+
last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
40+
output_waveform_accessor[i_channel][i_frame] =
41+
waveform_accessor[i_channel][i_frame] * 0.5 +
42+
last_out_accessor[i_channel] * 0.75;
43+
}
44+
}
45+
});
2846
}
2947

30-
// CPU entry point for the overdrive core loop.
//
// Dispatches on waveform.scalar_type() over the floating types and calls
// overdrive_cpu_kernel<scalar_t>, viewing `waveform`, `temp` and
// `output_waveform` through 2-D accessors and `last_in` and `last_out`
// through 1-D accessors.
//
// Diff view: the `N-` lines are the removed version, which took at::Tensor&
// arguments, used AT_DISPATCH_FLOATING_TYPES and returned void. The `N+`
// lines are the replacement, which takes stable-ABI Tensor arguments by
// value, uses THO_DISPATCH_V2 with AT_FLOATING_TYPES, and returns
// (last_in, last_out, output_waveform) after the kernel has written to them.
// Both versions share the final closing brace.
//
// NOTE(review): no check that the five tensors share waveform's dtype, or
// that they have the ranks the accessors assume, is visible in this function
// -- presumably the caller guarantees it; confirm.
void overdrive_core_loop_cpu(
31-
at::Tensor& waveform,
32-
at::Tensor& temp,
33-
at::Tensor& last_in,
34-
at::Tensor& last_out,
35-
at::Tensor& output_waveform) {
36-
AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
37-
overdrive_cpu_kernel<scalar_t>(
38-
waveform.accessor<scalar_t, 2>(),
39-
temp.accessor<scalar_t, 2>(),
40-
last_in.accessor<scalar_t, 1>(),
41-
last_out.accessor<scalar_t, 1>(),
42-
output_waveform.accessor<scalar_t, 2>());
43-
}));
48+
std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(
49+
Tensor waveform,
50+
Tensor temp,
51+
Tensor last_in,
52+
Tensor last_out,
53+
Tensor output_waveform) {
54+
THO_DISPATCH_V2(
55+
waveform.scalar_type(),
56+
"overdrive_cpu",
57+
AT_WRAP([&] {
58+
overdrive_cpu_kernel<scalar_t>(
59+
accessor<scalar_t, 2>(waveform),
60+
accessor<scalar_t, 2>(temp),
61+
accessor<scalar_t, 1>(last_in),
62+
accessor<scalar_t, 1>(last_out),
63+
accessor<scalar_t, 2>(output_waveform));
64+
}),
65+
AT_FLOATING_TYPES);
66+
return std::make_tuple(last_in, last_out, output_waveform);
4467
}
4568

4669
} // namespace
4770

48-
// Note: We want to avoid using "catch-all" kernel.
49-
// The following registration should be replaced with CPU specific registration.
50-
// Operator registration for torchaudio's _overdrive_core_loop.
//
// Diff view: the `N-` lines are the removed registration, a single
// TORCH_LIBRARY_FRAGMENT that defined the op by binding the function pointer
// directly. The `N+` lines replace it with two steps:
//   1. STABLE_TORCH_LIBRARY_FRAGMENT declares the schema only.
//   2. STABLE_TORCH_LIBRARY_IMPL binds overdrive_core_loop_cpu, wrapped in
//      TORCH_BOX, under the CPU key.
// Both versions share the final closing brace.
//
// The schema annotates last_in, last_out and output_waveform as mutated
// aliases (a!, b!, c!) and returns those same three aliases.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
51-
m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu);
71+
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
72+
m.def(
73+
"_overdrive_core_loop(Tensor waveform, Tensor temp, Tensor(a!) last_in, Tensor(b!) last_out, Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
74+
}
75+
76+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
77+
m.impl("_overdrive_core_loop", TORCH_BOX(&overdrive_core_loop_cpu));
5278
}

0 commit comments

Comments
 (0)