Skip to content

Commit 3405b09

Browse files
committed
Use TORCH_BOX in rnnt
1 parent 1335f78 commit 3405b09

File tree

2 files changed

+9
-19
lines changed

2 files changed

+9
-19
lines changed

src/libtorchaudio/overdrive.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ template <typename T, size_t N>
1212
using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
1313

1414
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
15-
// after Tensor::accessor is supported in stable ABI
15+
// after Tensor::accessor is supported in stable ABI.
1616
template <typename T, size_t N>
1717
inline TensorAccessor<T, N> accessor(Tensor t) {
1818
return TensorAccessor<T, N>(

src/libtorchaudio/rnnt/gpu/compute.cu

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ using torch::headeronly::ScalarType;
1414

1515
// Entry point into RNNT Loss
1616
std::tuple<Tensor, Tensor> compute(
17-
const Tensor& logits,
18-
const Tensor& targets,
19-
const Tensor& logit_lengths,
20-
const Tensor& target_lengths,
17+
Tensor logits,
18+
Tensor targets,
19+
Tensor logit_lengths,
20+
Tensor target_lengths,
2121
int64_t blank,
2222
double clamp,
2323
bool fused_log_softmax = true) {
@@ -148,23 +148,13 @@ std::tuple<Tensor, Tensor> compute(
148148
return std::make_tuple(costs, gradients);
149149
}
150150

151-
void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
152-
STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
153-
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
154-
std::tuple<Tensor, Tensor> res = compute(
155-
/*logits*/torch::stable::detail::to<Tensor>(stack[0]),
156-
/*targets*/torch::stable::detail::to<Tensor>(stack[1]),
157-
/*logit_lengths*/torch::stable::detail::to<Tensor>(stack[2]),
158-
/*target_lengths*/torch::stable::detail::to<Tensor>(stack[3]),
159-
/*blank*/float(torch::stable::detail::to<int64_t>(stack[4])),
160-
/*clamp*/torch::stable::detail::to<double>(stack[5]),
161-
/*fused_log_softmax*/torch::stable::detail::to<bool>(stack[6]));
162-
stack[0] = torch::stable::detail::from(std::get<0>(res));
163-
stack[1] = torch::stable::detail::from(std::get<1>(res));
151+
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
152+
m.def(
153+
"rnnt_loss_forward(Tensor logits, Tensor targets, Tensor logit_lengths, Tensor target_lengths, int blank, double clamp, bool fused_log_softmax) -> (Tensor, Tensor)");
164154
}
165155

166156
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
167-
m.impl("rnnt_loss_forward", &boxed_rnnt_loss);
157+
m.impl("rnnt_loss_forward", TORCH_BOX(&compute));
168158
}
169159

170160
} // namespace gpu

0 commit comments

Comments
 (0)